author     Dimitry Andric <dim@FreeBSD.org>          2018-07-28 10:51:19 +0000
committer  Dimitry Andric <dim@FreeBSD.org>          2018-07-28 10:51:19 +0000
commit     eb11fae6d08f479c0799db45860a98af528fa6e7 (patch)
tree       44d492a50c8c1a7eb8e2d17ea3360ec4d066f042 /lib/Transforms
parent     b8a2042aa938069e862750553db0e4d82d25822c (diff)
Diffstat (limited to 'lib/Transforms')
223 files changed, 24900 insertions, 13392 deletions
diff --git a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp new file mode 100644 index 000000000000..b622d018478a --- /dev/null +++ b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -0,0 +1,257 @@ +//===- AggressiveInstCombine.cpp ------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the aggressive expression pattern combiner classes. +// Currently, it handles expression patterns for: +// * Truncate instruction +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" +#include "AggressiveInstCombineInternal.h" +#include "llvm-c/Initialization.h" +#include "llvm-c/Transforms/Scalar.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Utils/Local.h" +using namespace llvm; +using namespace PatternMatch; + +#define DEBUG_TYPE "aggressive-instcombine" + +namespace { +/// Contains expression pattern combiner logic. +/// This class provides both the logic to combine expression patterns and +/// combine them. It differs from InstCombiner class in that each pattern +/// combiner runs only once as opposed to InstCombine's multi-iteration, +/// which allows pattern combiner to have higher complexity than the O(1) +/// required by the instruction combiner. +class AggressiveInstCombinerLegacyPass : public FunctionPass { +public: + static char ID; // Pass identification, replacement for typeid + + AggressiveInstCombinerLegacyPass() : FunctionPass(ID) { + initializeAggressiveInstCombinerLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override; + + /// Run all expression pattern optimizations on the given /p F function. + /// + /// \param F function to optimize. + /// \returns true if the IR is changed. + bool runOnFunction(Function &F) override; +}; +} // namespace + +/// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and +/// the bit indexes (Mask) needed by a masked compare. If we're matching a chain +/// of 'and' ops, then we also need to capture the fact that we saw an +/// "and X, 1", so that's an extra return value for that case. +struct MaskOps { + Value *Root; + APInt Mask; + bool MatchAndChain; + bool FoundAnd1; + + MaskOps(unsigned BitWidth, bool MatchAnds) : + Root(nullptr), Mask(APInt::getNullValue(BitWidth)), + MatchAndChain(MatchAnds), FoundAnd1(false) {} +}; + +/// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a +/// chain of 'and' or 'or' instructions looking for shift ops of a common source +/// value. 
Examples: +/// or (or (or X, (X >> 3)), (X >> 5)), (X >> 8) +/// returns { X, 0x129 } +/// and (and (X >> 1), 1), (X >> 4) +/// returns { X, 0x12 } +static bool matchAndOrChain(Value *V, MaskOps &MOps) { + Value *Op0, *Op1; + if (MOps.MatchAndChain) { + // Recurse through a chain of 'and' operands. This requires an extra check + // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere + // in the chain to know that all of the high bits are cleared. + if (match(V, m_And(m_Value(Op0), m_One()))) { + MOps.FoundAnd1 = true; + return matchAndOrChain(Op0, MOps); + } + if (match(V, m_And(m_Value(Op0), m_Value(Op1)))) + return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps); + } else { + // Recurse through a chain of 'or' operands. + if (match(V, m_Or(m_Value(Op0), m_Value(Op1)))) + return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps); + } + + // We need a shift-right or a bare value representing a compare of bit 0 of + // the original source operand. + Value *Candidate; + uint64_t BitIndex = 0; + if (!match(V, m_LShr(m_Value(Candidate), m_ConstantInt(BitIndex)))) + Candidate = V; + + // Initialize result source operand. + if (!MOps.Root) + MOps.Root = Candidate; + + // The shift constant is out-of-range? This code hasn't been simplified. + if (BitIndex >= MOps.Mask.getBitWidth()) + return false; + + // Fill in the mask bit derived from the shift constant. + MOps.Mask.setBit(BitIndex); + return MOps.Root == Candidate; +} + +/// Match patterns that correspond to "any-bits-set" and "all-bits-set". +/// These will include a chain of 'or' or 'and'-shifted bits from a +/// common source value: +/// and (or (lshr X, C), ...), 1 --> (X & CMask) != 0 +/// and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask +/// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns +/// that differ only with a final 'not' of the result. We expect that final +/// 'not' to be folded with the compare that we create here (invert predicate). +static bool foldAnyOrAllBitsSet(Instruction &I) { + // The 'any-bits-set' ('or' chain) pattern is simpler to match because the + // final "and X, 1" instruction must be the final op in the sequence. + bool MatchAllBitsSet; + if (match(&I, m_c_And(m_OneUse(m_And(m_Value(), m_Value())), m_Value()))) + MatchAllBitsSet = true; + else if (match(&I, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One()))) + MatchAllBitsSet = false; + else + return false; + + MaskOps MOps(I.getType()->getScalarSizeInBits(), MatchAllBitsSet); + if (MatchAllBitsSet) { + if (!matchAndOrChain(cast<BinaryOperator>(&I), MOps) || !MOps.FoundAnd1) + return false; + } else { + if (!matchAndOrChain(cast<BinaryOperator>(&I)->getOperand(0), MOps)) + return false; + } + + // The pattern was found. Create a masked compare that replaces all of the + // shift and logic ops. + IRBuilder<> Builder(&I); + Constant *Mask = ConstantInt::get(I.getType(), MOps.Mask); + Value *And = Builder.CreateAnd(MOps.Root, Mask); + Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(And, Mask) : + Builder.CreateIsNotNull(And); + Value *Zext = Builder.CreateZExt(Cmp, I.getType()); + I.replaceAllUsesWith(Zext); + return true; +} + +/// This is the entry point for folds that could be implemented in regular +/// InstCombine, but they are separated because they are not expected to +/// occur frequently and/or have more than a constant-length pattern match. 
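
A concrete illustration of the foldAnyOrAllBitsSet transform above (hypothetical C++ source, not part of this commit; the fold itself operates on LLVM IR, and the mask values follow the doc comments):

// Illustrative source-level shape for the "any-bits-set" ('or' chain) fold.
// The pass rewrites the whole shift/or chain into one masked compare.
bool anyOf(unsigned X) {
  return ((X >> 3) | (X >> 5) | (X >> 8) | X) & 1;   // tests bits {0,3,5,8}
}
bool anyOfFolded(unsigned X) {
  return (X & 0x129u) != 0;                          // 0x129 = bits 0,3,5,8
}
// "all-bits-set" ('and' chain): requires an "and X, 1" somewhere in the chain.
bool allOf(unsigned X) {
  return ((X >> 1) & (X >> 4) & X) & 1;              // tests bits {0,1,4}
}
bool allOfFolded(unsigned X) {
  return (X & 0x13u) == 0x13u;                       // 0x13 = bits 0,1,4
}
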
+static bool foldUnusualPatterns(Function &F, DominatorTree &DT) { + bool MadeChange = false; + for (BasicBlock &BB : F) { + // Ignore unreachable basic blocks. + if (!DT.isReachableFromEntry(&BB)) + continue; + // Do not delete instructions under here and invalidate the iterator. + // Walk the block backwards for efficiency. We're matching a chain of + // use->defs, so we're more likely to succeed by starting from the bottom. + // Also, we want to avoid matching partial patterns. + // TODO: It would be more efficient if we removed dead instructions + // iteratively in this loop rather than waiting until the end. + for (Instruction &I : make_range(BB.rbegin(), BB.rend())) + MadeChange |= foldAnyOrAllBitsSet(I); + } + + // We're done with transforms, so remove dead instructions. + if (MadeChange) + for (BasicBlock &BB : F) + SimplifyInstructionsInBlock(&BB); + + return MadeChange; +} + +/// This is the entry point for all transforms. Pass manager differences are +/// handled in the callers of this function. +static bool runImpl(Function &F, TargetLibraryInfo &TLI, DominatorTree &DT) { + bool MadeChange = false; + const DataLayout &DL = F.getParent()->getDataLayout(); + TruncInstCombine TIC(TLI, DL, DT); + MadeChange |= TIC.run(F); + MadeChange |= foldUnusualPatterns(F, DT); + return MadeChange; +} + +void AggressiveInstCombinerLegacyPass::getAnalysisUsage( + AnalysisUsage &AU) const { + AU.setPreservesCFG(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addPreserved<AAResultsWrapperPass>(); + AU.addPreserved<BasicAAWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); +} + +bool AggressiveInstCombinerLegacyPass::runOnFunction(Function &F) { + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + return runImpl(F, TLI, DT); +} + +PreservedAnalyses AggressiveInstCombinePass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + if (!runImpl(F, TLI, DT)) { + // No changes, all analyses are preserved. + return PreservedAnalyses::all(); + } + // Mark all the analyses that instcombine updates as preserved. 
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  PA.preserve<AAManager>();
+  PA.preserve<GlobalsAA>();
+  return PA;
+}
+
+char AggressiveInstCombinerLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(AggressiveInstCombinerLegacyPass,
+                      "aggressive-instcombine",
+                      "Combine pattern based expressions", false, false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(AggressiveInstCombinerLegacyPass, "aggressive-instcombine",
+                    "Combine pattern based expressions", false, false)
+
+// Initialization Routines
+void llvm::initializeAggressiveInstCombine(PassRegistry &Registry) {
+  initializeAggressiveInstCombinerLegacyPassPass(Registry);
+}
+
+void LLVMInitializeAggressiveInstCombiner(LLVMPassRegistryRef R) {
+  initializeAggressiveInstCombinerLegacyPassPass(*unwrap(R));
+}
+
+FunctionPass *llvm::createAggressiveInstCombinerPass() {
+  return new AggressiveInstCombinerLegacyPass();
+}
+
+void LLVMAddAggressiveInstCombinerPass(LLVMPassManagerRef PM) {
+  unwrap(PM)->add(createAggressiveInstCombinerPass());
+}
diff --git a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
new file mode 100644
index 000000000000..199374cdabf3
--- /dev/null
+++ b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
@@ -0,0 +1,121 @@
+//===- AggressiveInstCombineInternal.h --------------------------*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the instruction pattern combiner classes.
+// Currently, it handles pattern expressions for:
+//  * Truncate instruction
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/Pass.h"
+using namespace llvm;
+
+//===----------------------------------------------------------------------===//
+// TruncInstCombine - looks for expression dags dominated by trunc instructions
+// and, for each eligible dag, creates a reduced bit-width expression, replaces
+// the old expression with this new one and removes the old one.
+// An eligible expression dag is one that:
+//  1. Contains only supported instructions.
+//  2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
+//  3. Can be evaluated into a type with a reduced legal bit-width (or the
+//     Trunc type).
+//  4. All instructions in the dag must not have users outside the dag.
+//     The only exception is for {ZExt, SExt}Inst with an operand type equal
+//     to the new reduced type chosen in (3).
+//
+// The motivation for this optimization is that evaluating an expression using
+// a smaller bit-width is preferable, especially for vectorization, where we
+// can fit more values in one vectorized instruction. In addition, this
+// optimization may decrease the number of cast instructions, but will not
+// increase it.
+//===----------------------------------------------------------------------===//
+
+namespace llvm {
+  class DataLayout;
+  class DominatorTree;
+  class TargetLibraryInfo;
+
+class TruncInstCombine {
+  TargetLibraryInfo &TLI;
+  const DataLayout &DL;
+  const DominatorTree &DT;
+
+  /// List of all TruncInst instructions to be processed.
+  SmallVector<TruncInst *, 4> Worklist;
+
+  /// Current processed TruncInst instruction.
+  TruncInst *CurrentTruncInst;
+
+  /// Information per each instruction in the expression dag.
+  struct Info {
+    /// Number of LSBs that are needed to generate a valid expression.
+    unsigned ValidBitWidth = 0;
+    /// Minimum number of LSBs needed to generate the ValidBitWidth.
+    unsigned MinBitWidth = 0;
+    /// The reduced value generated to replace the old instruction.
+    Value *NewValue = nullptr;
+  };
+  /// An ordered map representing the expression dag post-dominated by the
+  /// currently processed TruncInst. It maps each instruction in the dag to
+  /// its Info structure. The map is ordered such that each instruction
+  /// appears before all other instructions in the dag that use it.
+  MapVector<Instruction *, Info> InstInfoMap;
+
+public:
+  TruncInstCombine(TargetLibraryInfo &TLI, const DataLayout &DL,
+                   const DominatorTree &DT)
+      : TLI(TLI), DL(DL), DT(DT), CurrentTruncInst(nullptr) {}
+
+  /// Perform TruncInst pattern optimization on given function.
+  bool run(Function &F);
+
+private:
+  /// Build the expression dag dominated by the \p CurrentTruncInst and append
+  /// it to the InstInfoMap container.
+  ///
+  /// \return true only if we succeeded in generating an eligible sub
+  /// expression dag.
+  bool buildTruncExpressionDag();
+
+  /// Calculate the minimal allowed bit-width of the chain ending with the
+  /// currently visited truncate's operand.
+  ///
+  /// \return the minimum number of bits to which the chain ending with the
+  /// truncate's operand can be shrunk.
+  unsigned getMinBitWidth();
+
+  /// Build an expression dag dominated by the currently processed TruncInst
+  /// and check whether it is eligible to be reduced to a smaller type.
+  ///
+  /// \return the scalar version of the new type to be used for the reduced
+  /// expression dag, or nullptr if the expression dag is not eligible
+  /// to be reduced.
+  Type *getBestTruncatedType();
+
+  /// Given a \p V value and a \p SclTy scalar type, return the generated
+  /// reduced value of \p V based on the type \p SclTy.
+  ///
+  /// \param V value to be reduced.
+  /// \param SclTy scalar version of the new type to reduce to.
+  /// \return the new reduced value.
+  Value *getReducedOperand(Value *V, Type *SclTy);
+
+  /// Create a new expression dag using the reduced \p SclTy type and replace
+  /// the old expression dag with it. Also erase all instructions in the old
+  /// dag, except those that are still needed outside the dag.
+  ///
+  /// \param SclTy scalar version of the new type to reduce the expression
+  /// dag into.
+  void ReduceExpressionDag(Type *SclTy);
+};
+} // end namespace llvm.
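
To make the dag reduction above concrete, here is a hypothetical example (not from this commit) of code this transform targets; on a target where i16 is legal, the 32-bit intermediate computation can be evaluated in 16 bits:

#include <cstdint>
// Illustrative only: the zext leaves and the final trunc bound a dag whose
// mul/add only need their low 16 bits.
uint16_t widenMulAdd(uint8_t a, uint8_t b) {
  uint32_t wa = a, wb = b;       // leaves: zext i8 -> i32
  uint32_t s = wa * wb + 42u;    // interior dag: mul i32, add i32
  return (uint16_t)s;            // root: trunc i32 -> i16
}
// TruncInstCombine conceptually rewrites the dag as if it were
//   zext i8 -> i16, mul i16, add i16
// and the final trunc disappears, since only the low 16 bits are demanded.
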
diff --git a/lib/Transforms/AggressiveInstCombine/CMakeLists.txt b/lib/Transforms/AggressiveInstCombine/CMakeLists.txt
new file mode 100644
index 000000000000..386314801e38
--- /dev/null
+++ b/lib/Transforms/AggressiveInstCombine/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_llvm_library(LLVMAggressiveInstCombine
+  AggressiveInstCombine.cpp
+  TruncInstCombine.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms
+  ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/AggressiveInstCombine
+
+  DEPENDS
+  intrinsics_gen
+  )
diff --git a/lib/Transforms/AggressiveInstCombine/LLVMBuild.txt b/lib/Transforms/AggressiveInstCombine/LLVMBuild.txt
new file mode 100644
index 000000000000..c05844f33de9
--- /dev/null
+++ b/lib/Transforms/AggressiveInstCombine/LLVMBuild.txt
@@ -0,0 +1,22 @@
+;===- ./lib/Transforms/AggressiveInstCombine/LLVMBuild.txt -----*- Conf -*--===;
+;
+; The LLVM Compiler Infrastructure
+;
+; This file is distributed under the University of Illinois Open Source
+; License. See LICENSE.TXT for details.
+;
+;===------------------------------------------------------------------------===;
+;
+; This is an LLVMBuild description file for the components in this subdirectory.
+;
+; For more information on the LLVMBuild system, please see:
+;
+;   http://llvm.org/docs/LLVMBuild.html
+;
+;===------------------------------------------------------------------------===;
+
+[component_0]
+type = Library
+name = AggressiveInstCombine
+parent = Transforms
+required_libraries = Analysis Core Support TransformUtils
diff --git a/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
new file mode 100644
index 000000000000..8289b2d68f8a
--- /dev/null
+++ b/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
@@ -0,0 +1,418 @@
+//===- TruncInstCombine.cpp -----------------------------------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// TruncInstCombine - looks for expression dags post-dominated by TruncInst
+// and, for each eligible dag, creates a reduced bit-width expression, replaces
+// the old expression with this new one and removes the old expression.
+// An eligible expression dag is one that:
+//  1. Contains only supported instructions.
+//  2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
+//  3. Can be evaluated into a type with a reduced legal bit-width.
+//  4. All instructions in the dag must not have users outside the dag.
+//     The only exception is for {ZExt, SExt}Inst with an operand type equal
+//     to the new reduced type evaluated in (3).
+//
+// The motivation for this optimization is that evaluating an expression using
+// a smaller bit-width is preferable, especially for vectorization, where we
+// can fit more values in one vectorized instruction. In addition, this
+// optimization may decrease the number of cast instructions, but will not
+// increase it.
+//
+//===----------------------------------------------------------------------===//
+
+#include "AggressiveInstCombineInternal.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+using namespace llvm;
+
+#define DEBUG_TYPE "aggressive-instcombine"
+
+/// Given an instruction and a container, fill the container with all the
+/// operands of that instruction that are relevant to the Trunc expression dag
+/// optimization.
+static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
+  unsigned Opc = I->getOpcode();
+  switch (Opc) {
+  case Instruction::Trunc:
+  case Instruction::ZExt:
+  case Instruction::SExt:
+    // These CastInsts are considered leaves of the evaluated expression;
+    // thus, their operands are not relevant.
+    break;
+  case Instruction::Add:
+  case Instruction::Sub:
+  case Instruction::Mul:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+    Ops.push_back(I->getOperand(0));
+    Ops.push_back(I->getOperand(1));
+    break;
+  default:
+    llvm_unreachable("Unreachable!");
+  }
+}
+
+bool TruncInstCombine::buildTruncExpressionDag() {
+  SmallVector<Value *, 8> Worklist;
+  SmallVector<Instruction *, 8> Stack;
+  // Clear old expression dag.
+  InstInfoMap.clear();
+
+  Worklist.push_back(CurrentTruncInst->getOperand(0));
+
+  while (!Worklist.empty()) {
+    Value *Curr = Worklist.back();
+
+    if (isa<Constant>(Curr)) {
+      Worklist.pop_back();
+      continue;
+    }
+
+    auto *I = dyn_cast<Instruction>(Curr);
+    if (!I)
+      return false;
+
+    if (!Stack.empty() && Stack.back() == I) {
+      // Already handled all instruction operands; we can remove it from both
+      // the Worklist and the Stack, and add it to the instruction info map.
+      Worklist.pop_back();
+      Stack.pop_back();
+      // Insert I into the Info map.
+      InstInfoMap.insert(std::make_pair(I, Info()));
+      continue;
+    }
+
+    if (InstInfoMap.count(I)) {
+      Worklist.pop_back();
+      continue;
+    }
+
+    // Add the instruction to the stack before we start handling its operands.
+    Stack.push_back(I);
+
+    unsigned Opc = I->getOpcode();
+    switch (Opc) {
+    case Instruction::Trunc:
+    case Instruction::ZExt:
+    case Instruction::SExt:
+      // trunc(trunc(x)) -> trunc(x)
+      // trunc(ext(x))   -> ext(x)   if the source type is smaller than the new dest
+      // trunc(ext(x))   -> trunc(x) if the source type is larger than the new dest
+      break;
+    case Instruction::Add:
+    case Instruction::Sub:
+    case Instruction::Mul:
+    case Instruction::And:
+    case Instruction::Or:
+    case Instruction::Xor: {
+      SmallVector<Value *, 2> Operands;
+      getRelevantOperands(I, Operands);
+      for (Value *Operand : Operands)
+        Worklist.push_back(Operand);
+      break;
+    }
+    default:
+      // TODO: Can handle more cases here:
+      // 1. select, shufflevector, extractelement, insertelement
+      // 2. udiv, urem
+      // 3. shl, lshr, ashr
+      // 4. phi node (and loop handling)
+      // ...
+      return false;
+    }
+  }
+  return true;
+}
+
+unsigned TruncInstCombine::getMinBitWidth() {
+  SmallVector<Value *, 8> Worklist;
+  SmallVector<Instruction *, 8> Stack;
+
+  Value *Src = CurrentTruncInst->getOperand(0);
+  Type *DstTy = CurrentTruncInst->getType();
+  unsigned TruncBitWidth = DstTy->getScalarSizeInBits();
+  unsigned OrigBitWidth =
+      CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();
+
+  if (isa<Constant>(Src))
+    return TruncBitWidth;
+
+  Worklist.push_back(Src);
+  InstInfoMap[cast<Instruction>(Src)].ValidBitWidth = TruncBitWidth;
+
+  while (!Worklist.empty()) {
+    Value *Curr = Worklist.back();
+
+    if (isa<Constant>(Curr)) {
+      Worklist.pop_back();
+      continue;
+    }
+
+    // Otherwise, it must be an instruction.
+    auto *I = cast<Instruction>(Curr);
+
+    auto &Info = InstInfoMap[I];
+
+    SmallVector<Value *, 2> Operands;
+    getRelevantOperands(I, Operands);
+
+    if (!Stack.empty() && Stack.back() == I) {
+      // Already handled all instruction operands; we can remove it from both
+      // the Worklist and the Stack, and update MinBitWidth.
+      Worklist.pop_back();
+      Stack.pop_back();
+      for (auto *Operand : Operands)
+        if (auto *IOp = dyn_cast<Instruction>(Operand))
+          Info.MinBitWidth =
+              std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth);
+      continue;
+    }
+
+    // Add the instruction to the stack before we start handling its operands.
+    Stack.push_back(I);
+    unsigned ValidBitWidth = Info.ValidBitWidth;
+
+    // Update the minimum bit-width before handling its operands. This is
+    // required when the instruction is part of a loop.
+    Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth);
+
+    for (auto *Operand : Operands)
+      if (auto *IOp = dyn_cast<Instruction>(Operand)) {
+        // If we already calculated the minimum bit-width for this valid
+        // bit-width, or for a smaller valid bit-width, then just keep the
+        // answer we already calculated.
+        unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;
+        if (IOpBitwidth >= ValidBitWidth)
+          continue;
+        InstInfoMap[IOp].ValidBitWidth = std::max(ValidBitWidth, IOpBitwidth);
+        Worklist.push_back(IOp);
+      }
+  }
+  unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth;
+  assert(MinBitWidth >= TruncBitWidth);
+
+  if (MinBitWidth > TruncBitWidth) {
+    // In this case, reducing an expression with a vector type might generate
+    // a new vector type, which is not preferable as it might result in
+    // generating sub-optimal code.
+    if (DstTy->isVectorTy())
+      return OrigBitWidth;
+    // Use the smallest integer type in the range [MinBitWidth, OrigBitWidth).
+    Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth);
+    // Update the minimum bit-width with the new destination type's bit-width
+    // if we succeeded in finding such a type; otherwise, with the original
+    // bit-width.
+    MinBitWidth = Ty ? Ty->getScalarSizeInBits() : OrigBitWidth;
+  } else { // MinBitWidth == TruncBitWidth
+    // In this case the expression can be evaluated with the trunc
+    // instruction's destination type, and the trunc instruction can be
+    // omitted. However, we should not perform the evaluation if the original
+    // type is a legal scalar type and the target type is illegal.
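
A minimal sketch of the width-selection rule implemented here, assuming a target whose legal integer widths are 8/16/32/64 (the real code queries DataLayout::getSmallestLegalIntType; this hypothetical helper models the MinBitWidth > TruncBitWidth branch and folds in the caller's MinBitWidth >= OrigBitWidth profitability check):

// Hypothetical helper mirroring the logic above; not part of this commit.
static unsigned chooseReducedWidth(unsigned MinBitWidth, unsigned OrigBitWidth) {
  const unsigned LegalWidths[] = {8, 16, 32, 64}; // assumed DataLayout "n8:16:32:64"
  for (unsigned W : LegalWidths)
    if (W >= MinBitWidth && W < OrigBitWidth)
      return W;          // smallest legal width that still covers MinBitWidth
  return OrigBitWidth;   // nothing profitable found: leave the dag unchanged
}
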
+    bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth);
+    bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth);
+    if (!DstTy->isVectorTy() && FromLegal && !ToLegal)
+      return OrigBitWidth;
+  }
+  return MinBitWidth;
+}
+
+Type *TruncInstCombine::getBestTruncatedType() {
+  if (!buildTruncExpressionDag())
+    return nullptr;
+
+  // We don't want to duplicate instructions, since that isn't profitable.
+  // Thus, we can't shrink something that has multiple users, unless all users
+  // are post-dominated by the trunc instruction, i.e., were visited during
+  // the expression evaluation.
+  unsigned DesiredBitWidth = 0;
+  for (auto Itr : InstInfoMap) {
+    Instruction *I = Itr.first;
+    if (I->hasOneUse())
+      continue;
+    bool IsExtInst = (isa<ZExtInst>(I) || isa<SExtInst>(I));
+    for (auto *U : I->users())
+      if (auto *UI = dyn_cast<Instruction>(U))
+        if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) {
+          if (!IsExtInst)
+            return nullptr;
+          // If this is an extension from the dest type, we can eliminate it,
+          // even if it has multiple users. Thus, update the DesiredBitWidth
+          // and validate that all extension instructions agree on the same
+          // DesiredBitWidth.
+          unsigned ExtInstBitWidth =
+              I->getOperand(0)->getType()->getScalarSizeInBits();
+          if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth)
+            return nullptr;
+          DesiredBitWidth = ExtInstBitWidth;
+        }
+  }
+
+  unsigned OrigBitWidth =
+      CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();
+
+  // Calculate the minimum bit-width allowed for shrinking the currently
+  // visited truncate's operand.
+  unsigned MinBitWidth = getMinBitWidth();
+
+  // Check that we can shrink to a smaller bit-width than the original one,
+  // and that it matches the DesiredBitWidth, if such exists.
+  if (MinBitWidth >= OrigBitWidth ||
+      (DesiredBitWidth && DesiredBitWidth != MinBitWidth))
+    return nullptr;
+
+  return IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth);
+}
+
+/// Given a reduced scalar type \p Ty and a \p V value, return a reduced type
+/// for \p V according to its type: if it is a vector type, return the vector
+/// version of \p Ty; otherwise return \p Ty.
+static Type *getReducedType(Value *V, Type *Ty) {
+  assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type");
+  if (auto *VTy = dyn_cast<VectorType>(V->getType()))
+    return VectorType::get(Ty, VTy->getNumElements());
+  return Ty;
+}
+
+Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) {
+  Type *Ty = getReducedType(V, SclTy);
+  if (auto *C = dyn_cast<Constant>(V)) {
+    C = ConstantExpr::getIntegerCast(C, Ty, false);
+    // If we got a constantexpr back, try to simplify it with DL info.
+    if (Constant *FoldedC = ConstantFoldConstant(C, DL, &TLI))
+      C = FoldedC;
+    return C;
+  }
+
+  auto *I = cast<Instruction>(V);
+  Info Entry = InstInfoMap.lookup(I);
+  assert(Entry.NewValue);
+  return Entry.NewValue;
+}
+
+void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
+  for (auto &Itr : InstInfoMap) { // Forward
+    Instruction *I = Itr.first;
+    TruncInstCombine::Info &NodeInfo = Itr.second;
+
+    assert(!NodeInfo.NewValue && "Instruction has been evaluated");
+
+    IRBuilder<> Builder(I);
+    Value *Res = nullptr;
+    unsigned Opc = I->getOpcode();
+    switch (Opc) {
+    case Instruction::Trunc:
+    case Instruction::ZExt:
+    case Instruction::SExt: {
+      Type *Ty = getReducedType(I, SclTy);
+      // If the source type of the cast is the type we're trying for, then we
+      // can just return the source. There's no need to insert it because it
+      // is not new.
+ if (I->getOperand(0)->getType() == Ty) { + assert(!isa<TruncInst>(I) && "Cannot reach here with TruncInst"); + NodeInfo.NewValue = I->getOperand(0); + continue; + } + // Otherwise, must be the same type of cast, so just reinsert a new one. + // This also handles the case of zext(trunc(x)) -> zext(x). + Res = Builder.CreateIntCast(I->getOperand(0), Ty, + Opc == Instruction::SExt); + + // Update Worklist entries with new value if needed. + // There are three possible changes to the Worklist: + // 1. Update Old-TruncInst -> New-TruncInst. + // 2. Remove Old-TruncInst (if New node is not TruncInst). + // 3. Add New-TruncInst (if Old node was not TruncInst). + auto Entry = find(Worklist, I); + if (Entry != Worklist.end()) { + if (auto *NewCI = dyn_cast<TruncInst>(Res)) + *Entry = NewCI; + else + Worklist.erase(Entry); + } else if (auto *NewCI = dyn_cast<TruncInst>(Res)) + Worklist.push_back(NewCI); + break; + } + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: { + Value *LHS = getReducedOperand(I->getOperand(0), SclTy); + Value *RHS = getReducedOperand(I->getOperand(1), SclTy); + Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS); + break; + } + default: + llvm_unreachable("Unhandled instruction"); + } + + NodeInfo.NewValue = Res; + if (auto *ResI = dyn_cast<Instruction>(Res)) + ResI->takeName(I); + } + + Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy); + Type *DstTy = CurrentTruncInst->getType(); + if (Res->getType() != DstTy) { + IRBuilder<> Builder(CurrentTruncInst); + Res = Builder.CreateIntCast(Res, DstTy, false); + if (auto *ResI = dyn_cast<Instruction>(Res)) + ResI->takeName(CurrentTruncInst); + } + CurrentTruncInst->replaceAllUsesWith(Res); + + // Erase old expression dag, which was replaced by the reduced expression dag. + // We iterate backward, which means we visit the instruction before we visit + // any of its operands, this way, when we get to the operand, we already + // removed the instructions (from the expression dag) that uses it. + CurrentTruncInst->eraseFromParent(); + for (auto I = InstInfoMap.rbegin(), E = InstInfoMap.rend(); I != E; ++I) { + // We still need to check that the instruction has no users before we erase + // it, because {SExt, ZExt}Inst Instruction might have other users that was + // not reduced, in such case, we need to keep that instruction. + if (I->first->use_empty()) + I->first->eraseFromParent(); + } +} + +bool TruncInstCombine::run(Function &F) { + bool MadeIRChange = false; + + // Collect all TruncInst in the function into the Worklist for evaluating. + for (auto &BB : F) { + // Ignore unreachable basic block. + if (!DT.isReachableFromEntry(&BB)) + continue; + for (auto &I : BB) + if (auto *CI = dyn_cast<TruncInst>(&I)) + Worklist.push_back(CI); + } + + // Process all TruncInst in the Worklist, for each instruction: + // 1. Check if it dominates an eligible expression dag to be reduced. + // 2. Create a reduced expression dag and replace the old one with it. 
+ while (!Worklist.empty()) { + CurrentTruncInst = Worklist.pop_back_val(); + + if (Type *NewDstSclTy = getBestTruncatedType()) { + LLVM_DEBUG( + dbgs() << "ICE: TruncInstCombine reducing type of expression dag " + "dominated by: " + << CurrentTruncInst << '\n'); + ReduceExpressionDag(NewDstSclTy); + MadeIRChange = true; + } + } + + return MadeIRChange; +} diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt index 67bdeb27212d..74db9e53304d 100644 --- a/lib/Transforms/CMakeLists.txt +++ b/lib/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(Utils) add_subdirectory(Instrumentation) +add_subdirectory(AggressiveInstCombine) add_subdirectory(InstCombine) add_subdirectory(Scalar) add_subdirectory(IPO) diff --git a/lib/Transforms/Coroutines/CMakeLists.txt b/lib/Transforms/Coroutines/CMakeLists.txt index 1c635bd9db08..80a052a2d45d 100644 --- a/lib/Transforms/Coroutines/CMakeLists.txt +++ b/lib/Transforms/Coroutines/CMakeLists.txt @@ -4,7 +4,7 @@ add_llvm_library(LLVMCoroutines CoroEarly.cpp CoroElide.cpp CoroFrame.cpp - CoroSplit.cpp + CoroSplit.cpp DEPENDS intrinsics_gen diff --git a/lib/Transforms/Coroutines/CoroEarly.cpp b/lib/Transforms/Coroutines/CoroEarly.cpp index ba05896af150..ac47a06281a5 100644 --- a/lib/Transforms/Coroutines/CoroEarly.cpp +++ b/lib/Transforms/Coroutines/CoroEarly.cpp @@ -27,10 +27,12 @@ namespace { class Lowerer : public coro::LowererBase { IRBuilder<> Builder; PointerType *const AnyResumeFnPtrTy; + Constant *NoopCoro = nullptr; void lowerResumeOrDestroy(CallSite CS, CoroSubFnInst::ResumeKind); void lowerCoroPromise(CoroPromiseInst *Intrin); void lowerCoroDone(IntrinsicInst *II); + void lowerCoroNoop(IntrinsicInst *II); public: Lowerer(Module &M) @@ -103,6 +105,41 @@ void Lowerer::lowerCoroDone(IntrinsicInst *II) { II->eraseFromParent(); } +void Lowerer::lowerCoroNoop(IntrinsicInst *II) { + if (!NoopCoro) { + LLVMContext &C = Builder.getContext(); + Module &M = *II->getModule(); + + // Create a noop.frame struct type. + StructType *FrameTy = StructType::create(C, "NoopCoro.Frame"); + auto *FramePtrTy = FrameTy->getPointerTo(); + auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy, + /*IsVarArgs=*/false); + auto *FnPtrTy = FnTy->getPointerTo(); + FrameTy->setBody({FnPtrTy, FnPtrTy}); + + // Create a Noop function that does nothing. + Function *NoopFn = + Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage, + "NoopCoro.ResumeDestroy", &M); + NoopFn->setCallingConv(CallingConv::Fast); + auto *Entry = BasicBlock::Create(C, "entry", NoopFn); + ReturnInst::Create(C, Entry); + + // Create a constant struct for the frame. + Constant* Values[] = {NoopFn, NoopFn}; + Constant* NoopCoroConst = ConstantStruct::get(FrameTy, Values); + NoopCoro = new GlobalVariable(M, NoopCoroConst->getType(), /*isConstant=*/true, + GlobalVariable::PrivateLinkage, NoopCoroConst, + "NoopCoro.Frame.Const"); + } + + Builder.SetInsertPoint(II); + auto *NoopCoroVoidPtr = Builder.CreateBitCast(NoopCoro, Int8Ptr); + II->replaceAllUsesWith(NoopCoroVoidPtr); + II->eraseFromParent(); +} + // Prior to CoroSplit, calls to coro.begin needs to be marked as NoDuplicate, // as CoroSplit assumes there is exactly one coro.begin. 
After CoroSplit, // NoDuplicate attribute will be removed from coro.begin otherwise, it will @@ -138,6 +175,9 @@ bool Lowerer::lowerEarlyIntrinsics(Function &F) { if (cast<CoroEndInst>(&I)->isFallthrough()) CS.setCannotDuplicate(); break; + case Intrinsic::coro_noop: + lowerCoroNoop(cast<IntrinsicInst>(&I)); + break; case Intrinsic::coro_id: // Mark a function that comes out of the frontend that has a coro.id // with a coroutine attribute. @@ -192,10 +232,10 @@ struct CoroEarly : public FunctionPass { // This pass has work to do only if we find intrinsics we are going to lower // in the module. bool doInitialization(Module &M) override { - if (coro::declaresIntrinsics(M, {"llvm.coro.id", "llvm.coro.destroy", - "llvm.coro.done", "llvm.coro.end", - "llvm.coro.free", "llvm.coro.promise", - "llvm.coro.resume", "llvm.coro.suspend"})) + if (coro::declaresIntrinsics( + M, {"llvm.coro.id", "llvm.coro.destroy", "llvm.coro.done", + "llvm.coro.end", "llvm.coro.noop", "llvm.coro.free", + "llvm.coro.promise", "llvm.coro.resume", "llvm.coro.suspend"})) L = llvm::make_unique<Lowerer>(M); return false; } diff --git a/lib/Transforms/Coroutines/CoroElide.cpp b/lib/Transforms/Coroutines/CoroElide.cpp index 42fd6d746145..dfe05c4b2a5e 100644 --- a/lib/Transforms/Coroutines/CoroElide.cpp +++ b/lib/Transforms/Coroutines/CoroElide.cpp @@ -14,6 +14,7 @@ #include "CoroInternal.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" #include "llvm/Pass.h" #include "llvm/Support/ErrorHandling.h" @@ -35,8 +36,8 @@ struct Lowerer : coro::LowererBase { Lowerer(Module &M) : LowererBase(M) {} void elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA); - bool shouldElide() const; - bool processCoroId(CoroIdInst *, AAResults &AA); + bool shouldElide(Function *F, DominatorTree &DT) const; + bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT); }; } // end anonymous namespace @@ -77,7 +78,6 @@ static bool operandReferences(CallInst *CI, AllocaInst *Frame, AAResults &AA) { // call implies that the function does not references anything on the stack. static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) { Function &F = *Frame->getFunction(); - MemoryLocation Mem(Frame); for (Instruction &I : instructions(F)) if (auto *Call = dyn_cast<CallInst>(&I)) if (Call->isTailCall() && operandReferences(Call, Frame, AA)) { @@ -142,33 +142,54 @@ void Lowerer::elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA) { removeTailCallAttribute(Frame, AA); } -bool Lowerer::shouldElide() const { +bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const { // If no CoroAllocs, we cannot suppress allocation, so elision is not // possible. if (CoroAllocs.empty()) return false; // Check that for every coro.begin there is a coro.destroy directly - // referencing the SSA value of that coro.begin. If the value escaped, then - // coro.destroy would have been referencing a memory location storing that - // value and not the virtual register. - - SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins; + // referencing the SSA value of that coro.begin along a non-exceptional path. + // If the value escaped, then coro.destroy would have been referencing a + // memory location storing that value and not the virtual register. + + // First gather all of the non-exceptional terminators for the function. 
+ SmallPtrSet<Instruction *, 8> Terminators; + for (BasicBlock &B : *F) { + auto *TI = B.getTerminator(); + if (TI->getNumSuccessors() == 0 && !TI->isExceptional() && + !isa<UnreachableInst>(TI)) + Terminators.insert(TI); + } + // Filter out the coro.destroy that lie along exceptional paths. + SmallPtrSet<CoroSubFnInst *, 4> DAs; for (CoroSubFnInst *DA : DestroyAddr) { + for (Instruction *TI : Terminators) { + if (DT.dominates(DA, TI)) { + DAs.insert(DA); + break; + } + } + } + + // Find all the coro.begin referenced by coro.destroy along happy paths. + SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins; + for (CoroSubFnInst *DA : DAs) { if (auto *CB = dyn_cast<CoroBeginInst>(DA->getFrame())) ReferencedCoroBegins.insert(CB); else return false; } - // If size of the set is the same as total number of CoroBegins, means we - // found a coro.free or coro.destroy mentioning a coro.begin and we can + // If size of the set is the same as total number of coro.begin, that means we + // found a coro.free or coro.destroy referencing each coro.begin, so we can // perform heap elision. return ReferencedCoroBegins.size() == CoroBegins.size(); } -bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA) { +bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA, + DominatorTree &DT) { CoroBegins.clear(); CoroAllocs.clear(); CoroFrees.clear(); @@ -214,7 +235,7 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA) { replaceWithConstant(ResumeAddrConstant, ResumeAddr); - bool ShouldElide = shouldElide(); + bool ShouldElide = shouldElide(CoroId->getFunction(), DT); auto *DestroyAddrConstant = ConstantExpr::getExtractValue( Resumers, @@ -294,14 +315,16 @@ struct CoroElide : FunctionPass { return Changed; AAResults &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); for (auto *CII : L->CoroIds) - Changed |= L->processCoroId(CII, AA); + Changed |= L->processCoroId(CII, AA, DT); return Changed; } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); } StringRef getPassName() const override { return "Coroutine Elision"; } }; diff --git a/lib/Transforms/Coroutines/CoroFrame.cpp b/lib/Transforms/Coroutines/CoroFrame.cpp index 6334256bf03a..cf63b678b618 100644 --- a/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/lib/Transforms/Coroutines/CoroFrame.cpp @@ -19,6 +19,8 @@ #include "CoroInternal.h" #include "llvm/ADT/BitVector.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Config/llvm-config.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" @@ -27,7 +29,6 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/circular_raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -48,7 +49,7 @@ public: BlockToIndexMapping(Function &F) { for (BasicBlock &BB : F) V.push_back(&BB); - std::sort(V.begin(), V.end()); + llvm::sort(V.begin(), V.end()); } size_t blockToIndex(BasicBlock *BB) const { @@ -105,8 +106,8 @@ struct SuspendCrossingInfo { assert(Block[UseIndex].Consumes[DefIndex] && "use must consume def"); bool const Result = Block[UseIndex].Kills[DefIndex]; - DEBUG(dbgs() << UseBB->getName() << " => " << DefBB->getName() - << " answer is " << Result << "\n"); + LLVM_DEBUG(dbgs() << UseBB->getName() << " => " << DefBB->getName() + << " answer is " << Result << "\n"); return Result; } @@ -194,8 
+195,8 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) bool Changed; do { - DEBUG(dbgs() << "iteration " << ++Iteration); - DEBUG(dbgs() << "==============\n"); + LLVM_DEBUG(dbgs() << "iteration " << ++Iteration); + LLVM_DEBUG(dbgs() << "==============\n"); Changed = false; for (size_t I = 0; I < N; ++I) { @@ -239,20 +240,20 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) Changed |= (S.Kills != SavedKills) || (S.Consumes != SavedConsumes); if (S.Kills != SavedKills) { - DEBUG(dbgs() << "\nblock " << I << " follower " << SI->getName() - << "\n"); - DEBUG(dump("S.Kills", S.Kills)); - DEBUG(dump("SavedKills", SavedKills)); + LLVM_DEBUG(dbgs() << "\nblock " << I << " follower " << SI->getName() + << "\n"); + LLVM_DEBUG(dump("S.Kills", S.Kills)); + LLVM_DEBUG(dump("SavedKills", SavedKills)); } if (S.Consumes != SavedConsumes) { - DEBUG(dbgs() << "\nblock " << I << " follower " << SI << "\n"); - DEBUG(dump("S.Consume", S.Consumes)); - DEBUG(dump("SavedCons", SavedConsumes)); + LLVM_DEBUG(dbgs() << "\nblock " << I << " follower " << SI << "\n"); + LLVM_DEBUG(dump("S.Consume", S.Consumes)); + LLVM_DEBUG(dump("SavedCons", SavedConsumes)); } } } } while (Changed); - DEBUG(dump()); + LLVM_DEBUG(dump()); } #undef DEBUG_TYPE // "coro-suspend-crossing" @@ -263,8 +264,9 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) namespace { class Spill { - Value *Def; - Instruction *User; + Value *Def = nullptr; + Instruction *User = nullptr; + unsigned FieldNo = 0; public: Spill(Value *Def, llvm::User *U) : Def(Def), User(cast<Instruction>(U)) {} @@ -272,6 +274,20 @@ public: Value *def() const { return Def; } Instruction *user() const { return User; } BasicBlock *userBlock() const { return User->getParent(); } + + // Note that field index is stored in the first SpillEntry for a particular + // definition. Subsequent mentions of a defintion do not have fieldNo + // assigned. This works out fine as the users of Spills capture the info about + // the definition the first time they encounter it. Consider refactoring + // SpillInfo into two arrays to normalize the spill representation. + unsigned fieldIndex() const { + assert(FieldNo && "Accessing unassigned field"); + return FieldNo; + } + void setFieldIndex(unsigned FieldNumber) { + assert(!FieldNo && "Reassigning field number"); + FieldNo = FieldNumber; + } }; } // namespace @@ -294,6 +310,57 @@ static void dump(StringRef Title, SpillInfo const &Spills) { } #endif +namespace { +// We cannot rely solely on natural alignment of a type when building a +// coroutine frame and if the alignment specified on the Alloca instruction +// differs from the natural alignment of the alloca type we will need to insert +// padding. +struct PaddingCalculator { + const DataLayout &DL; + LLVMContext &Context; + unsigned StructSize = 0; + + PaddingCalculator(LLVMContext &Context, DataLayout const &DL) + : DL(DL), Context(Context) {} + + // Replicate the logic from IR/DataLayout.cpp to match field offset + // computation for LLVM structs. + void addType(Type *Ty) { + unsigned TyAlign = DL.getABITypeAlignment(Ty); + if ((StructSize & (TyAlign - 1)) != 0) + StructSize = alignTo(StructSize, TyAlign); + + StructSize += DL.getTypeAllocSize(Ty); // Consume space for this data item. 
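
To see why the padding computation above is needed, consider a hypothetical coroutine frame on a 64-bit target (field names are illustrative, not from this commit, and the sketch assumes the frame allocation itself is sufficiently aligned):

// An "alloca [16 x i8], align 32" spilled into the frame after the two
// resume/destroy function pointers would land at offset 16 without help,
// because the natural alignment of [16 x i8] is only 1:
struct FrameNaive {
  void (*Resume)(void *);    // offset 0
  void (*Destroy)(void *);   // offset 8
  char Buf[16];              // offset 16 - violates the requested align 32
};
// PaddingCalculator inserts an explicit padding field instead:
struct FramePadded {
  void (*Resume)(void *);    // offset 0
  void (*Destroy)(void *);   // offset 8
  char Pad[16];              // padding so Buf starts at offset 32
  char Buf[16];              // offset 32 - honors align 32
};
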
+ } + + void addTypes(SmallVectorImpl<Type *> const &Types) { + for (auto *Ty : Types) + addType(Ty); + } + + unsigned computePadding(Type *Ty, unsigned ForcedAlignment) { + unsigned TyAlign = DL.getABITypeAlignment(Ty); + auto Natural = alignTo(StructSize, TyAlign); + auto Forced = alignTo(StructSize, ForcedAlignment); + + // Return how many bytes of padding we need to insert. + if (Natural != Forced) + return std::max(Natural, Forced) - StructSize; + + // Rely on natural alignment. + return 0; + } + + // If padding required, return the padding field type to insert. + ArrayType *getPaddingType(Type *Ty, unsigned ForcedAlignment) { + if (auto Padding = computePadding(Ty, ForcedAlignment)) + return ArrayType::get(Type::getInt8Ty(Context), Padding); + + return nullptr; + } +}; +} // namespace + // Build a struct that will keep state for an active coroutine. // struct f.frame { // ResumeFnTy ResumeFnAddr; @@ -305,6 +372,8 @@ static void dump(StringRef Title, SpillInfo const &Spills) { static StructType *buildFrameType(Function &F, coro::Shape &Shape, SpillInfo &Spills) { LLVMContext &C = F.getContext(); + const DataLayout &DL = F.getParent()->getDataLayout(); + PaddingCalculator Padder(C, DL); SmallString<32> Name(F.getName()); Name.append(".Frame"); StructType *FrameTy = StructType::create(C, Name); @@ -322,8 +391,10 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, Type::getIntNTy(C, IndexBits)}; Value *CurrentDef = nullptr; + Padder.addTypes(Types); + // Create an entry for every spilled value. - for (auto const &S : Spills) { + for (auto &S : Spills) { if (CurrentDef == S.def()) continue; @@ -333,12 +404,22 @@ static StructType *buildFrameType(Function &F, coro::Shape &Shape, continue; Type *Ty = nullptr; - if (auto *AI = dyn_cast<AllocaInst>(CurrentDef)) + if (auto *AI = dyn_cast<AllocaInst>(CurrentDef)) { Ty = AI->getAllocatedType(); - else + if (unsigned AllocaAlignment = AI->getAlignment()) { + // If alignment is specified in alloca, see if we need to insert extra + // padding. + if (auto PaddingTy = Padder.getPaddingType(Ty, AllocaAlignment)) { + Types.push_back(PaddingTy); + Padder.addType(PaddingTy); + } + } + } else { Ty = CurrentDef->getType(); - + } + S.setFieldIndex(Types.size()); Types.push_back(Ty); + Padder.addType(Ty); } FrameTy->setBody(Types); @@ -399,7 +480,7 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { Value *CurrentValue = nullptr; BasicBlock *CurrentBlock = nullptr; Value *CurrentReload = nullptr; - unsigned Index = coro::Shape::LastKnownField; + unsigned Index = 0; // Proper field number will be read from field definition. // We need to keep track of any allocas that need "spilling" // since they will live in the coroutine frame now, all access to them @@ -414,6 +495,7 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { // Create a load instruction to reload the spilled value from the coroutine // frame. 
auto CreateReload = [&](Instruction *InsertBefore) { + assert(Index && "accessing unassigned field number"); Builder.SetInsertPoint(InsertBefore); auto *G = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, Index, CurrentValue->getName() + @@ -431,7 +513,7 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { CurrentBlock = nullptr; CurrentReload = nullptr; - ++Index; + Index = E.fieldIndex(); if (auto *AI = dyn_cast<AllocaInst>(CurrentValue)) { // Spilled AllocaInst will be replaced with GEP from the coroutine frame @@ -739,6 +821,8 @@ static void moveSpillUsesAfterCoroBegin(Function &F, SpillInfo const &Spills, for (User *U : CurrentValue->users()) { Instruction *I = cast<Instruction>(U); if (!DT.dominates(CoroBegin, I)) { + LLVM_DEBUG(dbgs() << "will move: " << *I << "\n"); + // TODO: Make this more robust. Currently if we run into a situation // where simple instruction move won't work we panic and // report_fatal_error. @@ -748,7 +832,6 @@ static void moveSpillUsesAfterCoroBegin(Function &F, SpillInfo const &Spills, " dominated by CoroBegin"); } - DEBUG(dbgs() << "will move: " << *I << "\n"); NeedsMoving.push_back(I); } } @@ -823,7 +906,7 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { break; // Rewrite materializable instructions to be materialized at the use point. - DEBUG(dump("Materializations", Spills)); + LLVM_DEBUG(dump("Materializations", Spills)); rewriteMaterializableInstructions(Builder, Spills); Spills.clear(); } @@ -853,7 +936,7 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { Spills.emplace_back(&I, U); } } - DEBUG(dump("Spills", Spills)); + LLVM_DEBUG(dump("Spills", Spills)); moveSpillUsesAfterCoroBegin(F, Spills, Shape.CoroBegin); Shape.FrameTy = buildFrameType(F, Shape, Spills); Shape.FramePtr = insertSpills(Spills, Shape); diff --git a/lib/Transforms/Coroutines/CoroInternal.h b/lib/Transforms/Coroutines/CoroInternal.h index 1eac88dbac3a..8e690d649cf5 100644 --- a/lib/Transforms/Coroutines/CoroInternal.h +++ b/lib/Transforms/Coroutines/CoroInternal.h @@ -76,7 +76,6 @@ struct LLVM_LIBRARY_VISIBILITY Shape { DestroyField, PromiseField, IndexField, - LastKnownField = IndexField }; StructType *FrameTy; diff --git a/lib/Transforms/Coroutines/CoroSplit.cpp b/lib/Transforms/Coroutines/CoroSplit.cpp index 8712ca4823c6..49acc5e93a39 100644 --- a/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/lib/Transforms/Coroutines/CoroSplit.cpp @@ -28,6 +28,7 @@ #include "llvm/ADT/Twine.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -59,7 +60,6 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <cassert> #include <cstddef> @@ -250,7 +250,7 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape, auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType()); Function *NewF = - Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage, + Function::Create(FnTy, GlobalValue::LinkageTypes::ExternalLinkage, F.getName() + Suffix, M); NewF->addParamAttr(0, Attribute::NonNull); NewF->addParamAttr(0, Attribute::NoAlias); @@ -265,6 +265,7 @@ static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape, SmallVector<ReturnInst *, 4> Returns; 
CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns); + NewF->setLinkage(GlobalValue::LinkageTypes::InternalLinkage); // Remove old returns. for (ReturnInst *Return : Returns) @@ -440,16 +441,14 @@ static void scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock, DenseMap<Value *, Value *> &ResolvedValues) { auto *PrevBB = Prev->getParent(); - auto *I = &*NewBlock->begin(); - while (auto PN = dyn_cast<PHINode>(I)) { - auto V = PN->getIncomingValueForBlock(PrevBB); + for (PHINode &PN : NewBlock->phis()) { + auto V = PN.getIncomingValueForBlock(PrevBB); // See if we already resolved it. auto VI = ResolvedValues.find(V); if (VI != ResolvedValues.end()) V = VI->second; // Remember the value. - ResolvedValues[PN] = V; - I = I->getNextNode(); + ResolvedValues[&PN] = V; } } @@ -655,13 +654,28 @@ getNotRelocatableInstructions(CoroBeginInst *CoroBegin, // set. do { Instruction *Current = Work.pop_back_val(); + LLVM_DEBUG(dbgs() << "CoroSplit: Will not relocate: " << *Current << "\n"); DoNotRelocate.insert(Current); for (Value *U : Current->operands()) { auto *I = dyn_cast<Instruction>(U); if (!I) continue; - if (isa<AllocaInst>(U)) + + if (auto *A = dyn_cast<AllocaInst>(I)) { + // Stores to alloca instructions that occur before the coroutine frame + // is allocated should not be moved; the stored values may be used by + // the coroutine frame allocator. The operands to those stores must also + // remain in place. + for (const auto &User : A->users()) + if (auto *SI = dyn_cast<llvm::StoreInst>(User)) + if (RelocBlocks.count(SI->getParent()) != 0 && + DoNotRelocate.count(SI) == 0) { + Work.push_back(SI); + DoNotRelocate.insert(SI); + } continue; + } + if (DoNotRelocate.count(I) == 0) { Work.push_back(I); DoNotRelocate.insert(I); @@ -836,8 +850,8 @@ struct CoroSplit : public CallGraphSCCPass { for (Function *F : Coroutines) { Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR); StringRef Value = Attr.getValueAsString(); - DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName() - << "' state: " << Value << "\n"); + LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName() + << "' state: " << Value << "\n"); if (Value == UNPREPARED_FOR_SPLIT) { prepareForSplit(*F, CG); continue; diff --git a/lib/Transforms/Coroutines/Coroutines.cpp b/lib/Transforms/Coroutines/Coroutines.cpp index 10411c1bd65d..731faeb5dce4 100644 --- a/lib/Transforms/Coroutines/Coroutines.cpp +++ b/lib/Transforms/Coroutines/Coroutines.cpp @@ -11,12 +11,14 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Coroutines.h" #include "CoroInstr.h" #include "CoroInternal.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" @@ -31,10 +33,8 @@ #include "llvm/IR/Type.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Transforms/Coroutines.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" -#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstddef> #include <utility> @@ -125,9 +125,10 @@ static bool isCoroutineIntrinsicName(StringRef Name) { static const char *const CoroIntrinsics[] = { "llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.destroy", "llvm.coro.done", "llvm.coro.end", "llvm.coro.frame", - 
"llvm.coro.free", "llvm.coro.id", "llvm.coro.param", - "llvm.coro.promise", "llvm.coro.resume", "llvm.coro.save", - "llvm.coro.size", "llvm.coro.subfn.addr", "llvm.coro.suspend", + "llvm.coro.free", "llvm.coro.id", "llvm.coro.noop", + "llvm.coro.param", "llvm.coro.promise", "llvm.coro.resume", + "llvm.coro.save", "llvm.coro.size", "llvm.coro.subfn.addr", + "llvm.coro.suspend", }; return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1; } diff --git a/lib/Transforms/IPO/AlwaysInliner.cpp b/lib/Transforms/IPO/AlwaysInliner.cpp index a4bbc99b1f90..3b735ddd192e 100644 --- a/lib/Transforms/IPO/AlwaysInliner.cpp +++ b/lib/Transforms/IPO/AlwaysInliner.cpp @@ -50,7 +50,8 @@ PreservedAnalyses AlwaysInlinerPass::run(Module &M, ModuleAnalysisManager &) { for (CallSite CS : Calls) // FIXME: We really shouldn't be able to fail to inline at this point! // We should do something to log or check the inline failures here. - Changed |= InlineFunction(CS, IFI); + Changed |= + InlineFunction(CS, IFI, /*CalleeAAR=*/nullptr, InsertLifetime); // Remember to try and delete this function afterward. This both avoids // re-walking the rest of the module and avoids dealing with any iterator @@ -129,7 +130,7 @@ Pass *llvm::createAlwaysInlinerLegacyPass(bool InsertLifetime) { return new AlwaysInlinerLegacyPass(InsertLifetime); } -/// \brief Get the inline cost for the always-inliner. +/// Get the inline cost for the always-inliner. /// /// The always inliner *only* handles functions which are marked with the /// attribute to force inlining. As such, it is dramatically simpler and avoids diff --git a/lib/Transforms/IPO/ArgumentPromotion.cpp b/lib/Transforms/IPO/ArgumentPromotion.cpp index b25cbcad3b9d..f2c2b55b1c5b 100644 --- a/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -220,8 +220,8 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, NF->setSubprogram(F->getSubprogram()); F->setSubprogram(nullptr); - DEBUG(dbgs() << "ARG PROMOTION: Promoting to:" << *NF << "\n" - << "From: " << *F); + LLVM_DEBUG(dbgs() << "ARG PROMOTION: Promoting to:" << *NF << "\n" + << "From: " << *F); // Recompute the parameter attributes list based on the new arguments for // the function. @@ -426,8 +426,8 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, I2->setName(I->getName() + ".val"); LI->replaceAllUsesWith(&*I2); LI->eraseFromParent(); - DEBUG(dbgs() << "*** Promoted load of argument '" << I->getName() - << "' in function '" << F->getName() << "'\n"); + LLVM_DEBUG(dbgs() << "*** Promoted load of argument '" << I->getName() + << "' in function '" << F->getName() << "'\n"); } else { GetElementPtrInst *GEP = cast<GetElementPtrInst>(I->user_back()); IndicesVector Operands; @@ -453,8 +453,8 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, NewName += ".val"; TheArg->setName(NewName); - DEBUG(dbgs() << "*** Promoted agg argument '" << TheArg->getName() - << "' of function '" << NF->getName() << "'\n"); + LLVM_DEBUG(dbgs() << "*** Promoted agg argument '" << TheArg->getName() + << "' of function '" << NF->getName() << "'\n"); // All of the uses must be load instructions. Replace them all with // the argument specified by ArgNo. @@ -688,11 +688,11 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, // to do. 
if (ToPromote.find(Operands) == ToPromote.end()) { if (MaxElements > 0 && ToPromote.size() == MaxElements) { - DEBUG(dbgs() << "argpromotion not promoting argument '" - << Arg->getName() - << "' because it would require adding more " - << "than " << MaxElements - << " arguments to the function.\n"); + LLVM_DEBUG(dbgs() << "argpromotion not promoting argument '" + << Arg->getName() + << "' because it would require adding more " + << "than " << MaxElements + << " arguments to the function.\n"); // We limit aggregate promotion to only promoting up to a fixed number // of elements of the aggregate. return false; @@ -738,7 +738,7 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, return true; } -/// \brief Checks if a type could have padding bytes. +/// Checks if a type could have padding bytes. static bool isDenselyPacked(Type *type, const DataLayout &DL) { // There is no size information, so be conservative. if (!type->isSized()) return false; @@ -772,7 +772,7 @@ static bool isDenselyPacked(Type *type, const DataLayout &DL) { return true; } -/// \brief Checks if the padding bytes of an argument could be accessed. +/// Checks if the padding bytes of an argument could be accessed. static bool canPaddingBeAccessed(Argument *arg) { assert(arg->hasByValAttr()); @@ -817,6 +817,12 @@ promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter, unsigned MaxElements, Optional<function_ref<void(CallSite OldCS, CallSite NewCS)>> ReplaceCallSite) { + // Don't perform argument promotion for naked functions; otherwise we can end + // up removing parameters that are seemingly 'not used' as they are referred + // to in the assembly. + if (F->hasFnAttribute(Attribute::Naked)) + return nullptr; + // Make sure that it is local to this module.
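// (Editorial sketch, not part of this patch: why naked functions are skipped.
// Assuming Clang on x86-64 SysV, the parameter below has no IR-level use --
// only the inline assembly reads it, via %rdi -- so argument promotion could
// mistakenly treat it as dead.)
__attribute__((naked)) int add_one(int x) {
  __asm__("leal 1(%rdi), %eax\n\tret");
}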
if (!F->hasLocalLinkage()) return nullptr; @@ -847,10 +853,20 @@ promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter, if (CS.getInstruction() == nullptr || !CS.isCallee(&U)) return nullptr; + // Can't change signature of musttail callee + if (CS.isMustTailCall()) + return nullptr; + if (CS.getInstruction()->getParent()->getParent() == F) isSelfRecursive = true; } + // Can't change signature of musttail caller + // FIXME: Support promoting whole chain of musttail functions + for (BasicBlock &BB : *F) + if (BB.getTerminatingMustTailCall()) + return nullptr; + const DataLayout &DL = F->getParent()->getDataLayout(); AAResults &AAR = AARGetter(*F); @@ -885,11 +901,11 @@ promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter, if (isSafeToPromote) { if (StructType *STy = dyn_cast<StructType>(AgTy)) { if (MaxElements > 0 && STy->getNumElements() > MaxElements) { - DEBUG(dbgs() << "argpromotion disable promoting argument '" - << PtrArg->getName() - << "' because it would require adding more" - << " than " << MaxElements - << " arguments to the function.\n"); + LLVM_DEBUG(dbgs() << "argpromotion disable promoting argument '" + << PtrArg->getName() + << "' because it would require adding more" + << " than " << MaxElements + << " arguments to the function.\n"); continue; } @@ -963,7 +979,7 @@ PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C, return FAM.getResult<AAManager>(F); }; - Function *NewF = promoteArguments(&OldF, AARGetter, 3u, None); + Function *NewF = promoteArguments(&OldF, AARGetter, MaxElements, None); if (!NewF) continue; LocalChange = true; diff --git a/lib/Transforms/IPO/BarrierNoopPass.cpp b/lib/Transforms/IPO/BarrierNoopPass.cpp index 6af104362594..05fc3dd6950c 100644 --- a/lib/Transforms/IPO/BarrierNoopPass.cpp +++ b/lib/Transforms/IPO/BarrierNoopPass.cpp @@ -23,7 +23,7 @@ using namespace llvm; namespace { -/// \brief A nonce module pass used to place a barrier in a pass manager. +/// A nonce module pass used to place a barrier in a pass manager. /// /// There is no mechanism for ending a CGSCC pass manager once one is started. /// This prevents extension points from having clear deterministic ordering diff --git a/lib/Transforms/IPO/BlockExtractor.cpp b/lib/Transforms/IPO/BlockExtractor.cpp new file mode 100644 index 000000000000..ff5ee817da49 --- /dev/null +++ b/lib/Transforms/IPO/BlockExtractor.cpp @@ -0,0 +1,176 @@ +//===- BlockExtractor.cpp - Extracts blocks into their own functions ------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass extracts the specified basic blocks from the module into their +// own functions. 
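// (Editorial usage sketch with hypothetical names: each line of the file
// given via -extract-blocks-file pairs a function with one of its basic
// blocks, separated by a single space, e.g.
//   main entry
//   foo bb3
// driven as: opt -extract-blocks -extract-blocks-file=blocks.txt in.ll -o out.ll)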
+// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/CodeExtractor.h" +using namespace llvm; + +#define DEBUG_TYPE "block-extractor" + +STATISTIC(NumExtracted, "Number of basic blocks extracted"); + +static cl::opt<std::string> BlockExtractorFile( + "extract-blocks-file", cl::value_desc("filename"), + cl::desc("A file containing list of basic blocks to extract"), cl::Hidden); + +cl::opt<bool> BlockExtractorEraseFuncs("extract-blocks-erase-funcs", + cl::desc("Erase the existing functions"), + cl::Hidden); + +namespace { +class BlockExtractor : public ModulePass { + SmallVector<BasicBlock *, 16> Blocks; + bool EraseFunctions; + SmallVector<std::pair<std::string, std::string>, 32> BlocksByName; + +public: + static char ID; + BlockExtractor(const SmallVectorImpl<BasicBlock *> &BlocksToExtract, + bool EraseFunctions) + : ModulePass(ID), Blocks(BlocksToExtract.begin(), BlocksToExtract.end()), + EraseFunctions(EraseFunctions) { + if (!BlockExtractorFile.empty()) + loadFile(); + } + BlockExtractor() : BlockExtractor(SmallVector<BasicBlock *, 0>(), false) {} + bool runOnModule(Module &M) override; + +private: + void loadFile(); + void splitLandingPadPreds(Function &F); +}; +} // end anonymous namespace + +char BlockExtractor::ID = 0; +INITIALIZE_PASS(BlockExtractor, "extract-blocks", + "Extract basic blocks from module", false, false) + +ModulePass *llvm::createBlockExtractorPass() { return new BlockExtractor(); } +ModulePass *llvm::createBlockExtractorPass( + const SmallVectorImpl<BasicBlock *> &BlocksToExtract, bool EraseFunctions) { + return new BlockExtractor(BlocksToExtract, EraseFunctions); +} + +/// Gets all of the blocks specified in the input file. +void BlockExtractor::loadFile() { + auto ErrOrBuf = MemoryBuffer::getFile(BlockExtractorFile); + if (ErrOrBuf.getError()) + report_fatal_error("BlockExtractor couldn't load the file."); + // Read the file. + auto &Buf = *ErrOrBuf; + SmallVector<StringRef, 16> Lines; + Buf->getBuffer().split(Lines, '\n', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + for (const auto &Line : Lines) { + auto FBPair = Line.split(' '); + BlocksByName.push_back({FBPair.first, FBPair.second}); + } +} + +/// Extracts the landing pads to make sure all of them have only one +/// predecessor. +void BlockExtractor::splitLandingPadPreds(Function &F) { + for (BasicBlock &BB : F) { + for (Instruction &I : BB) { + if (!isa<InvokeInst>(&I)) + continue; + InvokeInst *II = cast<InvokeInst>(&I); + BasicBlock *Parent = II->getParent(); + BasicBlock *LPad = II->getUnwindDest(); + + // Look through the landing pad's predecessors. If one of them ends in an + // 'invoke', then we want to split the landing pad. + bool Split = false; + for (auto PredBB : predecessors(LPad)) { + if (PredBB->isLandingPad() && PredBB != Parent && + isa<InvokeInst>(Parent->getTerminator())) { + Split = true; + break; + } + } + + if (!Split) + continue; + + SmallVector<BasicBlock *, 2> NewBBs; + SplitLandingPadPredecessors(LPad, Parent, ".1", ".2", NewBBs); + } + } +} + +bool BlockExtractor::runOnModule(Module &M) { + + bool Changed = false; + + // Get all the functions. 
+ SmallVector<Function *, 4> Functions; + for (Function &F : M) { + splitLandingPadPreds(F); + Functions.push_back(&F); + } + + // Get all the blocks specified in the input file. + for (const auto &BInfo : BlocksByName) { + Function *F = M.getFunction(BInfo.first); + if (!F) + report_fatal_error("Invalid function name specified in the input file"); + auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) { + return BB.getName().equals(BInfo.second); + }); + if (Res == F->end()) + report_fatal_error("Invalid block name specified in the input file"); + Blocks.push_back(&*Res); + } + + // Extract basic blocks. + for (BasicBlock *BB : Blocks) { + // Check if the module contains BB. + if (BB->getParent()->getParent() != &M) + report_fatal_error("Invalid basic block"); + LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting " + << BB->getParent()->getName() << ":" << BB->getName() + << "\n"); + SmallVector<BasicBlock *, 2> BlocksToExtractVec; + BlocksToExtractVec.push_back(BB); + if (const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) + BlocksToExtractVec.push_back(II->getUnwindDest()); + CodeExtractor(BlocksToExtractVec).extractCodeRegion(); + ++NumExtracted; + Changed = true; + } + + // Erase the functions. + if (EraseFunctions || BlockExtractorEraseFuncs) { + for (Function *F : Functions) { + LLVM_DEBUG(dbgs() << "BlockExtractor: Trying to delete " << F->getName() + << "\n"); + F->deleteBody(); + } + // Set linkage as ExternalLinkage to avoid erasing unreachable functions. + for (Function &F : M) + F.setLinkage(GlobalValue::ExternalLinkage); + Changed = true; + } + + return Changed; +} diff --git a/lib/Transforms/IPO/CMakeLists.txt b/lib/Transforms/IPO/CMakeLists.txt index 397561746f86..4772baf5976c 100644 --- a/lib/Transforms/IPO/CMakeLists.txt +++ b/lib/Transforms/IPO/CMakeLists.txt @@ -2,6 +2,7 @@ add_llvm_library(LLVMipo AlwaysInliner.cpp ArgumentPromotion.cpp BarrierNoopPass.cpp + BlockExtractor.cpp CalledValuePropagation.cpp ConstantMerge.cpp CrossDSOCFI.cpp @@ -27,8 +28,10 @@ add_llvm_library(LLVMipo PassManagerBuilder.cpp PruneEH.cpp SampleProfile.cpp + SCCP.cpp StripDeadPrototypes.cpp StripSymbols.cpp + SyntheticCountsPropagation.cpp ThinLTOBitcodeWriter.cpp WholeProgramDevirt.cpp diff --git a/lib/Transforms/IPO/CalledValuePropagation.cpp b/lib/Transforms/IPO/CalledValuePropagation.cpp index c5f6336aa2be..d642445b35de 100644 --- a/lib/Transforms/IPO/CalledValuePropagation.cpp +++ b/lib/Transforms/IPO/CalledValuePropagation.cpp @@ -69,12 +69,15 @@ public: CVPLatticeVal() : LatticeState(Undefined) {} CVPLatticeVal(CVPLatticeStateTy LatticeState) : LatticeState(LatticeState) {} - CVPLatticeVal(std::set<Function *, Compare> &&Functions) - : LatticeState(FunctionSet), Functions(Functions) {} + CVPLatticeVal(std::vector<Function *> &&Functions) + : LatticeState(FunctionSet), Functions(std::move(Functions)) { + assert(std::is_sorted(this->Functions.begin(), this->Functions.end(), + Compare())); + } /// Get a reference to the functions held by this lattice value. The number /// of functions will be zero for states other than FunctionSet. - const std::set<Function *, Compare> &getFunctions() const { + const std::vector<Function *> &getFunctions() const { return Functions; } @@ -99,7 +102,8 @@ private: /// MaxFunctionsPerValue. Since most LLVM values are expected to be in /// uninteresting states (i.e., overdefined), CVPLatticeVal objects should be /// small and efficiently copyable. 
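// (Editorial sketch of the sorted-vector-as-set idiom adopted here: both
// operands stay sorted, so merging two lattice values is one linear pass.)
#include <algorithm>
#include <vector>
static std::vector<int> mergeSorted(const std::vector<int> &X,
                                    const std::vector<int> &Y) {
  std::vector<int> U;
  // Inputs must already be sorted; the output is their sorted union.
  std::set_union(X.begin(), X.end(), Y.begin(), Y.end(),
                 std::back_inserter(U));
  return U;
}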
- std::set<Function *, Compare> Functions; + // FIXME: This could be a TinyPtrVector and/or merge with LatticeState. + std::vector<Function *> Functions; }; /// The custom lattice function used by the generic sparse propagation solver. @@ -150,11 +154,10 @@ public: return getOverdefinedVal(); if (X == getUndefVal() && Y == getUndefVal()) return getUndefVal(); - std::set<Function *, CVPLatticeVal::Compare> Union; + std::vector<Function *> Union; std::set_union(X.getFunctions().begin(), X.getFunctions().end(), Y.getFunctions().begin(), Y.getFunctions().end(), - std::inserter(Union, Union.begin()), - CVPLatticeVal::Compare{}); + std::back_inserter(Union), CVPLatticeVal::Compare{}); if (Union.size() > MaxFunctionsPerValue) return getOverdefinedVal(); return CVPLatticeVal(std::move(Union)); @@ -265,6 +268,10 @@ private: // If we can't track the function's return values, there's nothing to do. if (!F || !canTrackReturnsInterprocedurally(F)) { + // Void return: no need to create or update CVPLattice state, as no one + // can use it. + if (I->getType()->isVoidTy()) + return; ChangedValues[RegI] = getOverdefinedVal(); return; } @@ -280,6 +287,12 @@ private: ChangedValues[RegFormal] = MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual)); } + + // Void return: no need to create or update CVPLattice state, as no one can + // use it. + if (I->getType()->isVoidTy()) + return; + ChangedValues[RegI] = MergeValues(SS.getValueState(RegI), SS.getValueState(RetF)); } @@ -377,8 +390,7 @@ static bool runCVP(Module &M) { CVPLatticeVal LV = Solver.getExistingValueState(RegI); if (!LV.isFunctionSet() || LV.getFunctions().empty()) continue; - MDNode *Callees = MDB.createCallees(SmallVector<Function *, 4>( - LV.getFunctions().begin(), LV.getFunctions().end())); + MDNode *Callees = MDB.createCallees(LV.getFunctions()); C->setMetadata(LLVMContext::MD_callees, Callees); Changed = true; } diff --git a/lib/Transforms/IPO/CrossDSOCFI.cpp b/lib/Transforms/IPO/CrossDSOCFI.cpp index 886029ea58d5..666f6cc37bfd 100644 --- a/lib/Transforms/IPO/CrossDSOCFI.cpp +++ b/lib/Transforms/IPO/CrossDSOCFI.cpp @@ -162,9 +162,6 @@ void CrossDSOCFI::buildCFICheck(Module &M) { } bool CrossDSOCFI::runOnModule(Module &M) { - if (skipModule(M)) - return false; - VeryLikelyWeights = MDBuilder(M.getContext()).createBranchWeights((1U << 20) - 1, 1); if (M.getModuleFlag("Cross-DSO CFI") == nullptr) diff --git a/lib/Transforms/IPO/DeadArgumentElimination.cpp b/lib/Transforms/IPO/DeadArgumentElimination.cpp index 5446541550e5..31e771da3bd3 100644 --- a/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -240,8 +240,11 @@ bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) { I2->takeName(&*I); } - // Patch the pointer to LLVM function in debug info descriptor. - NF->setSubprogram(Fn.getSubprogram()); + // Clone metadata from the old function, including the debug info descriptor. + SmallVector<std::pair<unsigned, MDNode *>, 1> MDs; + Fn.getAllMetadata(MDs); + for (auto MD : MDs) + NF->addMetadata(MD.first, *MD.second); // Fix up any BlockAddresses that refer to the function. Fn.replaceAllUsesWith(ConstantExpr::getBitCast(NF, Fn.getType())); @@ -507,25 +510,43 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { // MaybeLive. Initialized to a list of RetCount empty lists.
RetUses MaybeLiveRetUses(RetCount); - for (Function::const_iterator BB = F.begin(), E = F.end(); BB != E; ++BB) - if (const ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator())) + bool HasMustTailCalls = false; + + for (Function::const_iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { + if (const ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator())) { if (RI->getNumOperands() != 0 && RI->getOperand(0)->getType() != F.getFunctionType()->getReturnType()) { // We don't support old style multiple return values. MarkLive(F); return; } + } + + // If we have any returns of `musttail` results - the signature can't + // change + if (BB->getTerminatingMustTailCall() != nullptr) + HasMustTailCalls = true; + } + + if (HasMustTailCalls) { + LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - " << F.getName() + << " has musttail calls\n"); + } if (!F.hasLocalLinkage() && (!ShouldHackArguments || F.isIntrinsic())) { MarkLive(F); return; } - DEBUG(dbgs() << "DeadArgumentEliminationPass - Inspecting callers for fn: " - << F.getName() << "\n"); + LLVM_DEBUG( + dbgs() << "DeadArgumentEliminationPass - Inspecting callers for fn: " + << F.getName() << "\n"); // Keep track of the number of live retvals, so we can skip checks once all // of them turn out to be live. unsigned NumLiveRetVals = 0; + + bool HasMustTailCallers = false; + // Loop all uses of the function. for (const Use &U : F.uses()) { // If the function is PASSED IN as an argument, its address has been @@ -536,6 +557,11 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { return; } + // The number of arguments for `musttail` call must match the number of + // arguments of the caller + if (CS.isMustTailCall()) + HasMustTailCallers = true; + // If this use is anything other than a call site, the function is alive. const Instruction *TheCall = CS.getInstruction(); if (!TheCall) { // Not a direct call site? @@ -580,12 +606,17 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { } } + if (HasMustTailCallers) { + LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - " << F.getName() + << " has musttail callers\n"); + } + // Now we've inspected all callers, record the liveness of our return values. for (unsigned i = 0; i != RetCount; ++i) MarkValue(CreateRet(&F, i), RetValLiveness[i], MaybeLiveRetUses[i]); - DEBUG(dbgs() << "DeadArgumentEliminationPass - Inspecting args for fn: " - << F.getName() << "\n"); + LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - Inspecting args for fn: " + << F.getName() << "\n"); // Now, check all of our arguments. unsigned i = 0; @@ -593,12 +624,19 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { for (Function::const_arg_iterator AI = F.arg_begin(), E = F.arg_end(); AI != E; ++AI, ++i) { Liveness Result; - if (F.getFunctionType()->isVarArg()) { + if (F.getFunctionType()->isVarArg() || HasMustTailCallers || + HasMustTailCalls) { // Variadic functions will already have a va_arg function expanded inside // them, making them potentially very sensitive to ABI changes resulting // from removing arguments entirely, so don't. For example AArch64 handles // register and stack HFAs very differently, and this is reflected in the // IR which has already been generated. + // + // `musttail` calls to this function restrict argument removal attempts. + // The signature of the caller must match the signature of the function. 
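// (Editorial IR illustration of the restriction above, hedged: given
//   define i32 @caller(i32 %a, i32 %b) {
//     %r = musttail call i32 @callee(i32 %a, i32 %b)
//     ret i32 %r
//   }
// the prototypes of @caller and @callee must stay in lock step, so neither
// side's dead arguments can be removed independently.)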
+ // + // `musttail` calls in this function prevent us from changing its + // signature. Result = Live; } else { // See what the effect of this use is (recording any uses that cause @@ -637,8 +675,8 @@ void DeadArgumentEliminationPass::MarkValue(const RetOrArg &RA, Liveness L, /// mark any values that are used as this function's parameters or by its return /// values (according to Uses) live as well. void DeadArgumentEliminationPass::MarkLive(const Function &F) { - DEBUG(dbgs() << "DeadArgumentEliminationPass - Intrinsically live fn: " - << F.getName() << "\n"); + LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - Intrinsically live fn: " + << F.getName() << "\n"); // Mark the function as live. LiveFunctions.insert(&F); // Mark all arguments as live. @@ -659,8 +697,8 @@ void DeadArgumentEliminationPass::MarkLive(const RetOrArg &RA) { if (!LiveValues.insert(RA).second) return; // We were already marked Live. - DEBUG(dbgs() << "DeadArgumentEliminationPass - Marking " - << RA.getDescription() << " live\n"); + LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - Marking " + << RA.getDescription() << " live\n"); PropagateLiveness(RA); } @@ -718,9 +756,9 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { HasLiveReturnedArg |= PAL.hasParamAttribute(i, Attribute::Returned); } else { ++NumArgumentsEliminated; - DEBUG(dbgs() << "DeadArgumentEliminationPass - Removing argument " << i - << " (" << I->getName() << ") from " << F->getName() - << "\n"); + LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - Removing argument " + << i << " (" << I->getName() << ") from " + << F->getName() << "\n"); } } @@ -763,8 +801,8 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { NewRetIdxs[i] = RetTypes.size() - 1; } else { ++NumRetValsEliminated; - DEBUG(dbgs() << "DeadArgumentEliminationPass - Removing return value " - << i << " from " << F->getName() << "\n"); + LLVM_DEBUG( dbgs() << "DeadArgumentEliminationPass - Removing return value " << i << " from " << F->getName() << "\n"); } } if (RetTypes.size() > 1) { @@ -803,10 +842,14 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { AttributeSet RetAttrs = AttributeSet::get(F->getContext(), RAttrs); + // Strip allocsize attributes. They might refer to the deleted arguments. + AttributeSet FnAttrs = PAL.getFnAttributes().removeAttribute( + F->getContext(), Attribute::AllocSize); + // Reconstruct the AttributesList based on the vector we constructed. assert(ArgAttrVec.size() == Params.size()); - AttributeList NewPAL = AttributeList::get( - F->getContext(), PAL.getFnAttributes(), RetAttrs, ArgAttrVec); + AttributeList NewPAL = + AttributeList::get(F->getContext(), FnAttrs, RetAttrs, ArgAttrVec); // Create the new function type based on the recomputed parameters. FunctionType *NFTy = FunctionType::get(NRetTy, Params, FTy->isVarArg()); @@ -875,8 +918,14 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { // Reconstruct the AttributesList based on the vector we constructed. assert(ArgAttrVec.size() == Args.size()); + + // Again, be sure to remove any allocsize attributes, since their indices + // may now be incorrect.
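// (Editorial example, hedged: allocsize encodes argument *positions*, e.g.
//   declare i8* @my_alloc(i32 %n, i32 %sz) allocsize(0, 1)
// Deleting dead arguments shifts those positions, so a retained attribute
// could end up naming the wrong parameter or one past the end.)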
+ AttributeSet FnAttrs = CallPAL.getFnAttributes().removeAttribute( + F->getContext(), Attribute::AllocSize); + AttributeList NewCallPAL = AttributeList::get( - F->getContext(), CallPAL.getFnAttributes(), RetAttrs, ArgAttrVec); + F->getContext(), FnAttrs, RetAttrs, ArgAttrVec); SmallVector<OperandBundleDef, 1> OpBundles; CS.getOperandBundlesAsDefs(OpBundles); @@ -1017,8 +1066,11 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { BB.getInstList().erase(RI); } - // Patch the pointer to LLVM function in debug info descriptor. - NF->setSubprogram(F->getSubprogram()); + // Clone metadata from the old function, including the debug info descriptor. + SmallVector<std::pair<unsigned, MDNode *>, 1> MDs; + F->getAllMetadata(MDs); + for (auto MD : MDs) + NF->addMetadata(MD.first, *MD.second); // Now that the old function is dead, delete it. F->eraseFromParent(); @@ -1034,7 +1086,7 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M, // removed. We can do this if they never call va_start. This loop cannot be // fused with the next loop, because deleting a function invalidates // information computed while surveying other functions. - DEBUG(dbgs() << "DeadArgumentEliminationPass - Deleting dead varargs\n"); + LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - Deleting dead varargs\n"); for (Module::iterator I = M.begin(), E = M.end(); I != E; ) { Function &F = *I++; if (F.getFunctionType()->isVarArg()) @@ -1045,7 +1097,7 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M, // We assume all arguments are dead unless proven otherwise (allowing us to // determine that dead arguments passed into recursive functions are dead). // - DEBUG(dbgs() << "DeadArgumentEliminationPass - Determining liveness\n"); + LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - Determining liveness\n"); for (auto &F : M) SurveyFunction(F); diff --git a/lib/Transforms/IPO/ExtractGV.cpp b/lib/Transforms/IPO/ExtractGV.cpp index 042cacb70ad0..d45a88323910 100644 --- a/lib/Transforms/IPO/ExtractGV.cpp +++ b/lib/Transforms/IPO/ExtractGV.cpp @@ -51,7 +51,7 @@ static void makeVisible(GlobalValue &GV, bool Delete) { } namespace { - /// @brief A pass to extract specific global values and their dependencies. + /// A pass to extract specific global values and their dependencies.
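  /// (Editorial note, hedged: this pass is the machinery behind the
  /// llvm-extract tool, e.g. `llvm-extract -func=main in.bc -o out.bc` keeps
  /// @main and what it needs while other definitions become declarations.)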
class GVExtractorPass : public ModulePass { SetVector<GlobalValue *> Named; bool deleteStuff; diff --git a/lib/Transforms/IPO/ForceFunctionAttrs.cpp b/lib/Transforms/IPO/ForceFunctionAttrs.cpp index 325a5d77aadb..37273f975417 100644 --- a/lib/Transforms/IPO/ForceFunctionAttrs.cpp +++ b/lib/Transforms/IPO/ForceFunctionAttrs.cpp @@ -42,8 +42,10 @@ static Attribute::AttrKind parseAttrKind(StringRef Kind) { .Case("nonlazybind", Attribute::NonLazyBind) .Case("noredzone", Attribute::NoRedZone) .Case("noreturn", Attribute::NoReturn) + .Case("nocf_check", Attribute::NoCfCheck) .Case("norecurse", Attribute::NoRecurse) .Case("nounwind", Attribute::NoUnwind) + .Case("optforfuzzing", Attribute::OptForFuzzing) .Case("optnone", Attribute::OptimizeNone) .Case("optsize", Attribute::OptimizeForSize) .Case("readnone", Attribute::ReadNone) @@ -51,6 +53,7 @@ static Attribute::AttrKind parseAttrKind(StringRef Kind) { .Case("argmemonly", Attribute::ArgMemOnly) .Case("returns_twice", Attribute::ReturnsTwice) .Case("safestack", Attribute::SafeStack) + .Case("shadowcallstack", Attribute::ShadowCallStack) .Case("sanitize_address", Attribute::SanitizeAddress) .Case("sanitize_hwaddress", Attribute::SanitizeHWAddress) .Case("sanitize_memory", Attribute::SanitizeMemory) @@ -72,8 +75,8 @@ static void addForcedAttributes(Function &F) { auto Kind = parseAttrKind(KV.second); if (Kind == Attribute::None) { - DEBUG(dbgs() << "ForcedAttribute: " << KV.second - << " unknown or not handled!\n"); + LLVM_DEBUG(dbgs() << "ForcedAttribute: " << KV.second + << " unknown or not handled!\n"); continue; } if (F.hasFnAttribute(Kind)) diff --git a/lib/Transforms/IPO/FunctionAttrs.cpp b/lib/Transforms/IPO/FunctionAttrs.cpp index 5352e32479bb..2797da6c0abd 100644 --- a/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/lib/Transforms/IPO/FunctionAttrs.cpp @@ -18,7 +18,6 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -74,6 +73,7 @@ STATISTIC(NumReadOnlyArg, "Number of arguments marked readonly"); STATISTIC(NumNoAlias, "Number of function returns marked noalias"); STATISTIC(NumNonNullReturn, "Number of function returns marked nonnull"); STATISTIC(NumNoRecurse, "Number of functions marked as norecurse"); +STATISTIC(NumNoUnwind, "Number of functions marked as nounwind"); // FIXME: This is disabled by default to avoid exposing security vulnerabilities // in C/C++ code compiled by clang: @@ -83,6 +83,10 @@ static cl::opt<bool> EnableNonnullArgPropagation( cl::desc("Try to propagate nonnull argument attributes from callsites to " "caller functions.")); +static cl::opt<bool> DisableNoUnwindInference( + "disable-nounwind-inference", cl::Hidden, + cl::desc("Stop inferring nounwind attribute during function-attrs pass")); + namespace { using SCCNodeSet = SmallSetVector<Function *, 8>; @@ -401,7 +405,7 @@ static Attribute::AttrKind determinePointerReadAttrs(Argument *A, const SmallPtrSet<Argument *, 8> &SCCNodes) { SmallVector<Use *, 32> Worklist; - SmallSet<Use *, 32> Visited; + SmallPtrSet<Use *, 32> Visited; // inalloca arguments are always clobbered by the call. 
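// (Editorial IR illustration, hedged: an inalloca argument is stack memory
// the caller materializes for the callee, e.g.
//   %m = alloca inalloca <{ i32 }>
//   call void @f(<{ i32 }>* inalloca %m)
// so the call itself may read and overwrite it.)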
if (A->hasInAllocaAttr()) @@ -1008,7 +1012,8 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) { if (!Speculative) { // Mark the function eagerly since we may discover a function // which prevents us from speculating about the entire SCC - DEBUG(dbgs() << "Eagerly marking " << F->getName() << " as nonnull\n"); + LLVM_DEBUG(dbgs() << "Eagerly marking " << F->getName() + << " as nonnull\n"); F->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); ++NumNonNullReturn; MadeChange = true; @@ -1027,7 +1032,7 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) { !F->getReturnType()->isPointerTy()) continue; - DEBUG(dbgs() << "SCC marking " << F->getName() << " as nonnull\n"); + LLVM_DEBUG(dbgs() << "SCC marking " << F->getName() << " as nonnull\n"); F->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); ++NumNonNullReturn; MadeChange = true; @@ -1037,49 +1042,214 @@ static bool addNonNullAttrs(const SCCNodeSet &SCCNodes) { return MadeChange; } -/// Remove the convergent attribute from all functions in the SCC if every -/// callsite within the SCC is not convergent (except for calls to functions -/// within the SCC). Returns true if changes were made. -static bool removeConvergentAttrs(const SCCNodeSet &SCCNodes) { - // For every function in SCC, ensure that either - // * it is not convergent, or - // * we can remove its convergent attribute. - bool HasConvergentFn = false; +namespace { + +/// Collects a set of attribute inference requests and performs them all in one +/// go on a single SCC Node. Inference involves scanning function bodies +/// looking for instructions that violate attribute assumptions. +/// As soon as all the bodies are fine we are free to set the attribute. +/// Customization of inference for individual attributes is performed by +/// providing a handful of predicates for each attribute. +class AttributeInferer { +public: + /// Describes a request for inference of a single attribute. + struct InferenceDescriptor { + + /// Returns true if this function does not have to be handled. + /// General intent for this predicate is to provide an optimization + /// for functions that do not need this attribute inference at all + /// (say, for functions that already have the attribute). + std::function<bool(const Function &)> SkipFunction; + + /// Returns true if this instruction violates attribute assumptions. + std::function<bool(Instruction &)> InstrBreaksAttribute; + + /// Sets the inferred attribute for this function. + std::function<void(Function &)> SetAttribute; + + /// Attribute we derive. + Attribute::AttrKind AKind; + + /// If true, only "exact" definitions can be used to infer this attribute. + /// See GlobalValue::isDefinitionExact. + bool RequiresExactDefinition; + + InferenceDescriptor(Attribute::AttrKind AK, + std::function<bool(const Function &)> SkipFunc, + std::function<bool(Instruction &)> InstrScan, + std::function<void(Function &)> SetAttr, + bool ReqExactDef) + : SkipFunction(SkipFunc), InstrBreaksAttribute(InstrScan), + SetAttribute(SetAttr), AKind(AK), + RequiresExactDefinition(ReqExactDef) {} + }; + +private: + SmallVector<InferenceDescriptor, 4> InferenceDescriptors; + +public: + void registerAttrInference(InferenceDescriptor AttrInference) { + InferenceDescriptors.push_back(AttrInference); + } + + bool run(const SCCNodeSet &SCCNodes); +}; + +/// Perform all the requested attribute inference actions according to the +/// attribute predicates stored before. 
+bool AttributeInferer::run(const SCCNodeSet &SCCNodes) { + SmallVector<InferenceDescriptor, 4> InferInSCC = InferenceDescriptors; + // Go through all the functions in SCC and check corresponding attribute + // assumptions for each of them. Attributes that are invalid for this SCC + // will be removed from InferInSCC. for (Function *F : SCCNodes) { - if (!F->isConvergent()) continue; - HasConvergentFn = true; - // Can't remove convergent from function declarations. - if (F->isDeclaration()) return false; + // No attributes whose assumptions are still valid - done. + if (InferInSCC.empty()) + return false; - // Can't remove convergent if any of our functions has a convergent call to a - // function not in the SCC. - for (Instruction &I : instructions(*F)) { - CallSite CS(&I); - // Bail if CS is a convergent call to a function not in the SCC. - if (CS && CS.isConvergent() && - SCCNodes.count(CS.getCalledFunction()) == 0) + // Check if our attributes ever need scanning/can be scanned. + llvm::erase_if(InferInSCC, [F](const InferenceDescriptor &ID) { + if (ID.SkipFunction(*F)) return false; + + // Remove from further inference (invalidate) when visiting a function + // that has no instructions to scan/has an unsuitable definition. + return F->isDeclaration() || + (ID.RequiresExactDefinition && !F->hasExactDefinition()); + }); + + // For each attribute still in InferInSCC that doesn't explicitly skip F, + // set up the F instructions scan to verify assumptions of the attribute. + SmallVector<InferenceDescriptor, 4> InferInThisFunc; + llvm::copy_if( + InferInSCC, std::back_inserter(InferInThisFunc), + [F](const InferenceDescriptor &ID) { return !ID.SkipFunction(*F); }); + + if (InferInThisFunc.empty()) + continue; + + // Start instruction scan. + for (Instruction &I : instructions(*F)) { + llvm::erase_if(InferInThisFunc, [&](const InferenceDescriptor &ID) { + if (!ID.InstrBreaksAttribute(I)) + return false; + // Remove attribute from further inference on any other functions + // because attribute assumptions have just been violated. + llvm::erase_if(InferInSCC, [&ID](const InferenceDescriptor &D) { + return D.AKind == ID.AKind; + }); + // Remove attribute from the rest of current instruction scan. + return true; + }); + + if (InferInThisFunc.empty()) + break; } } - // If the SCC doesn't have any convergent functions, we have nothing to do. - if (!HasConvergentFn) return false; + if (InferInSCC.empty()) + return false; - // If we got here, all of the calls the SCC makes to functions not in the SCC - // are non-convergent. Therefore all of the SCC's functions can also be made - // non-convergent. We'll remove the attr from the callsites in - // InstCombineCalls. - for (Function *F : SCCNodes) { - if (!F->isConvergent()) continue; + bool Changed = false; + for (Function *F : SCCNodes) + // At this point InferInSCC contains only functions that were either: + // - explicitly skipped from scan/inference, or + // - verified to have no instructions that break attribute assumptions. + // Hence we just go and force the attribute for all non-skipped functions. + for (auto &ID : InferInSCC) { + if (ID.SkipFunction(*F)) + continue; + Changed = true; + ID.SetAttribute(*F); + } + return Changed; +} - DEBUG(dbgs() << "Removing convergent attr from fn " << F->getName() - << "\n"); - F->setNotConvergent(); +} // end anonymous namespace + +/// Helper for non-Convergent inference predicate InstrBreaksAttribute. 
+static bool InstrBreaksNonConvergent(Instruction &I, + const SCCNodeSet &SCCNodes) { + const CallSite CS(&I); + // Breaks non-convergent assumption if CS is a convergent call to a function + // not in the SCC. + return CS && CS.isConvergent() && SCCNodes.count(CS.getCalledFunction()) == 0; +} + +/// Helper for NoUnwind inference predicate InstrBreaksAttribute. +static bool InstrBreaksNonThrowing(Instruction &I, const SCCNodeSet &SCCNodes) { + if (!I.mayThrow()) + return false; + if (const auto *CI = dyn_cast<CallInst>(&I)) { + if (Function *Callee = CI->getCalledFunction()) { + // I is a may-throw call to a function inside our SCC. This doesn't + // invalidate our current working assumption that the SCC is no-throw; we + // just have to scan that other function. + if (SCCNodes.count(Callee) > 0) + return false; + } } return true; } +/// Infer attributes from all functions in the SCC by scanning every +/// instruction for compliance to the attribute assumptions. Currently it +/// does: +/// - removal of Convergent attribute +/// - addition of NoUnwind attribute +/// +/// Returns true if any changes to function attributes were made. +static bool inferAttrsFromFunctionBodies(const SCCNodeSet &SCCNodes) { + + AttributeInferer AI; + + // Request to remove the convergent attribute from all functions in the SCC + // if every callsite within the SCC is not convergent (except for calls + // to functions within the SCC). + // Note: Removal of the attr from the callsites will happen in + // InstCombineCalls separately. + AI.registerAttrInference(AttributeInferer::InferenceDescriptor{ + Attribute::Convergent, + // Skip non-convergent functions. + [](const Function &F) { return !F.isConvergent(); }, + // Instructions that break non-convergent assumption. + [SCCNodes](Instruction &I) { + return InstrBreaksNonConvergent(I, SCCNodes); + }, + [](Function &F) { + LLVM_DEBUG(dbgs() << "Removing convergent attr from fn " << F.getName() + << "\n"); + F.setNotConvergent(); + }, + /* RequiresExactDefinition= */ false}); + + if (!DisableNoUnwindInference) + // Request to infer nounwind attribute for all the functions in the SCC if + // every callsite within the SCC is not throwing (except for calls to + // functions within the SCC). Note that nounwind attribute suffers from + // derefinement - results may change depending on how functions are + // optimized. Thus it can be inferred only from exact definitions. + AI.registerAttrInference(AttributeInferer::InferenceDescriptor{ + Attribute::NoUnwind, + // Skip non-throwing functions. + [](const Function &F) { return F.doesNotThrow(); }, + // Instructions that break non-throwing assumption. + [SCCNodes](Instruction &I) { + return InstrBreaksNonThrowing(I, SCCNodes); + }, + [](Function &F) { + LLVM_DEBUG(dbgs() + << "Adding nounwind attr to fn " << F.getName() << "\n"); + F.setDoesNotThrow(); + ++NumNoUnwind; + }, + /* RequiresExactDefinition= */ true}); + + // Perform all the requested attribute inference actions. 
+ return AI.run(SCCNodes); +} + static bool setDoesNotRecurse(Function &F) { if (F.doesNotRecurse()) return false; @@ -1136,7 +1306,8 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, bool HasUnknownCall = false; for (LazyCallGraph::Node &N : C) { Function &F = N.getFunction(); - if (F.hasFnAttribute(Attribute::OptimizeNone)) { + if (F.hasFnAttribute(Attribute::OptimizeNone) || + F.hasFnAttribute(Attribute::Naked)) { // Treat any function we're trying not to optimize as if it were an // indirect call and omit it from the node set used below. HasUnknownCall = true; @@ -1167,7 +1338,7 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, if (!HasUnknownCall) { Changed |= addNoAliasAttrs(SCCNodes); Changed |= addNonNullAttrs(SCCNodes); - Changed |= removeConvergentAttrs(SCCNodes); + Changed |= inferAttrsFromFunctionBodies(SCCNodes); Changed |= addNoRecurseAttrs(SCCNodes); } @@ -1221,7 +1392,8 @@ static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) { bool ExternalNode = false; for (CallGraphNode *I : SCC) { Function *F = I->getFunction(); - if (!F || F->hasFnAttribute(Attribute::OptimizeNone)) { + if (!F || F->hasFnAttribute(Attribute::OptimizeNone) || + F->hasFnAttribute(Attribute::Naked)) { // External node or function we're trying not to optimize - we both avoid // transform them and avoid leveraging information they provide. ExternalNode = true; @@ -1244,7 +1416,7 @@ static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) { if (!ExternalNode) { Changed |= addNoAliasAttrs(SCCNodes); Changed |= addNonNullAttrs(SCCNodes); - Changed |= removeConvergentAttrs(SCCNodes); + Changed |= inferAttrsFromFunctionBodies(SCCNodes); Changed |= addNoRecurseAttrs(SCCNodes); } diff --git a/lib/Transforms/IPO/FunctionImport.cpp b/lib/Transforms/IPO/FunctionImport.cpp index b6d6201cd23b..15808a073894 100644 --- a/lib/Transforms/IPO/FunctionImport.cpp +++ b/lib/Transforms/IPO/FunctionImport.cpp @@ -18,8 +18,8 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringSet.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/IR/AutoUpgrade.h" #include "llvm/IR/Constants.h" @@ -61,6 +61,7 @@ using namespace llvm; #define DEBUG_TYPE "function-import" STATISTIC(NumImportedFunctions, "Number of functions imported"); +STATISTIC(NumImportedGlobalVars, "Number of global variables imported"); STATISTIC(NumImportedModules, "Number of modules imported from"); STATISTIC(NumDeadSymbols, "Number of dead stripped symbols in index"); STATISTIC(NumLiveSymbols, "Number of live symbols in index"); @@ -70,6 +71,10 @@ static cl::opt<unsigned> ImportInstrLimit( "import-instr-limit", cl::init(100), cl::Hidden, cl::value_desc("N"), cl::desc("Only import functions with less than N instructions")); +static cl::opt<int> ImportCutoff( + "import-cutoff", cl::init(-1), cl::Hidden, cl::value_desc("N"), + cl::desc("Only import first N functions if N>=0 (default -1)")); + static cl::opt<float> ImportInstrFactor("import-instr-evolution-factor", cl::init(0.7), cl::Hidden, cl::value_desc("x"), @@ -131,7 +136,7 @@ static cl::opt<bool> static std::unique_ptr<Module> loadFile(const std::string &FileName, LLVMContext &Context) { SMDiagnostic Err; - DEBUG(dbgs() << "Loading '" << FileName << "'\n"); + LLVM_DEBUG(dbgs() << "Loading '" << FileName << "'\n"); // Metadata isn't loaded until functions are imported, to minimize // the memory overhead. 
std::unique_ptr<Module> Result = @@ -163,6 +168,9 @@ selectCallee(const ModuleSummaryIndex &Index, CalleeSummaryList, [&](const std::unique_ptr<GlobalValueSummary> &SummaryPtr) { auto *GVSummary = SummaryPtr.get(); + if (!Index.isGlobalValueLive(GVSummary)) + return false; + // For SamplePGO, in computeImportForFunction the OriginalId // may have been used to locate the callee summary list (See // comment there). @@ -231,10 +239,37 @@ updateValueInfoForIndirectCalls(const ModuleSummaryIndex &Index, ValueInfo VI) { // it, rather than needing to perform this mapping on each walk. auto GUID = Index.getGUIDFromOriginalID(VI.getGUID()); if (GUID == 0) - return nullptr; + return ValueInfo(); return Index.getValueInfo(GUID); } +static void computeImportForReferencedGlobals( + const FunctionSummary &Summary, const GVSummaryMapTy &DefinedGVSummaries, + FunctionImporter::ImportMapTy &ImportList, + StringMap<FunctionImporter::ExportSetTy> *ExportLists) { + for (auto &VI : Summary.refs()) { + if (DefinedGVSummaries.count(VI.getGUID())) { + LLVM_DEBUG( + dbgs() << "Ref ignored! Target already in destination module.\n"); + continue; + } + + LLVM_DEBUG(dbgs() << " ref -> " << VI << "\n"); + + for (auto &RefSummary : VI.getSummaryList()) + if (RefSummary->getSummaryKind() == GlobalValueSummary::GlobalVarKind && + // Don't try to import regular LTO summaries added to dummy module. + !RefSummary->modulePath().empty() && + !GlobalValue::isInterposableLinkage(RefSummary->linkage()) && + RefSummary->refs().empty()) { + ImportList[RefSummary->modulePath()].insert(VI.getGUID()); + if (ExportLists) + (*ExportLists)[RefSummary->modulePath()].insert(VI.getGUID()); + break; + } + } +} + /// Compute the list of functions to import for a given caller. Mark these /// imported functions and the symbols they reference in their source module as /// exported from their source module. @@ -243,18 +278,28 @@ static void computeImportForFunction( const unsigned Threshold, const GVSummaryMapTy &DefinedGVSummaries, SmallVectorImpl<EdgeInfo> &Worklist, FunctionImporter::ImportMapTy &ImportList, - StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr) { + StringMap<FunctionImporter::ExportSetTy> *ExportLists, + FunctionImporter::ImportThresholdsTy &ImportThresholds) { + computeImportForReferencedGlobals(Summary, DefinedGVSummaries, ImportList, + ExportLists); + static int ImportCount = 0; for (auto &Edge : Summary.calls()) { ValueInfo VI = Edge.first; - DEBUG(dbgs() << " edge -> " << VI.getGUID() << " Threshold:" << Threshold - << "\n"); + LLVM_DEBUG(dbgs() << " edge -> " << VI << " Threshold:" << Threshold + << "\n"); + + if (ImportCutoff >= 0 && ImportCount >= ImportCutoff) { + LLVM_DEBUG(dbgs() << "ignored! import-cutoff value of " << ImportCutoff + << " reached.\n"); + continue; + } VI = updateValueInfoForIndirectCalls(Index, VI); if (!VI) continue; if (DefinedGVSummaries.count(VI.getGUID())) { - DEBUG(dbgs() << "ignored! Target already in destination module.\n"); + LLVM_DEBUG(dbgs() << "ignored! Target already in destination module.\n"); continue; } @@ -269,20 +314,87 @@ static void computeImportForFunction( }; const auto NewThreshold = - Threshold * GetBonusMultiplier(Edge.second.Hotness); - - auto *CalleeSummary = selectCallee(Index, VI.getSummaryList(), NewThreshold, - Summary.modulePath()); - if (!CalleeSummary) { - DEBUG(dbgs() << "ignored! 
No qualifying callee with summary found.\n"); - continue; - } + Threshold * GetBonusMultiplier(Edge.second.getHotness()); + + auto IT = ImportThresholds.insert( + std::make_pair(VI.getGUID(), std::make_pair(NewThreshold, nullptr))); + bool PreviouslyVisited = !IT.second; + auto &ProcessedThreshold = IT.first->second.first; + auto &CalleeSummary = IT.first->second.second; + + const FunctionSummary *ResolvedCalleeSummary = nullptr; + if (CalleeSummary) { + assert(PreviouslyVisited); + // Since the traversal of the call graph is DFS, we can revisit a function + // a second time with a higher threshold. In this case, it is added back + // to the worklist with the new threshold (so that its own callee chains + // can be considered with the higher threshold). + if (NewThreshold <= ProcessedThreshold) { + LLVM_DEBUG( + dbgs() << "ignored! Target was already imported with Threshold " + << ProcessedThreshold << "\n"); + continue; + } + // Update with new larger threshold. + ProcessedThreshold = NewThreshold; + ResolvedCalleeSummary = cast<FunctionSummary>(CalleeSummary); + } else { + // If we already rejected importing a callee at the same or higher + // threshold, don't waste time calling selectCallee. + if (PreviouslyVisited && NewThreshold <= ProcessedThreshold) { + LLVM_DEBUG( + dbgs() << "ignored! Target was already rejected with Threshold " + << ProcessedThreshold << "\n"); + continue; + } - // "Resolve" the summary - const auto *ResolvedCalleeSummary = cast<FunctionSummary>(CalleeSummary->getBaseObject()); + CalleeSummary = selectCallee(Index, VI.getSummaryList(), NewThreshold, + Summary.modulePath()); + if (!CalleeSummary) { + // Update with new larger threshold if this was a retry (otherwise + // we would have already inserted with NewThreshold above). + if (PreviouslyVisited) + ProcessedThreshold = NewThreshold; + LLVM_DEBUG( + dbgs() << "ignored! No qualifying callee with summary found.\n"); + continue; + } - assert(ResolvedCalleeSummary->instCount() <= NewThreshold && - "selectCallee() didn't honor the threshold"); + // "Resolve" the summary + CalleeSummary = CalleeSummary->getBaseObject(); + ResolvedCalleeSummary = cast<FunctionSummary>(CalleeSummary); + + assert(ResolvedCalleeSummary->instCount() <= NewThreshold && + "selectCallee() didn't honor the threshold"); + + auto ExportModulePath = ResolvedCalleeSummary->modulePath(); + auto ILI = ImportList[ExportModulePath].insert(VI.getGUID()); + // We previously decided to import this GUID definition if it was already + // inserted in the set of imports from the exporting module. + bool PreviouslyImported = !ILI.second; + + // Make exports in the source module. + if (ExportLists) { + auto &ExportList = (*ExportLists)[ExportModulePath]; + ExportList.insert(VI.getGUID()); + if (!PreviouslyImported) { + // This is the first time this function was exported from its source + // module, so mark all functions and globals it references as exported + // to the outside if they are defined in the same source module. + // For efficiency, we unconditionally add all the referenced GUIDs + // to the ExportList for this module, and will prune out any not + // defined in the module later in a single pass. 
+ for (auto &Edge : ResolvedCalleeSummary->calls()) { + auto CalleeGUID = Edge.first.getGUID(); + ExportList.insert(CalleeGUID); + } + for (auto &Ref : ResolvedCalleeSummary->refs()) { + auto GUID = Ref.getGUID(); + ExportList.insert(GUID); + } + } + } + } auto GetAdjustedThreshold = [](unsigned Threshold, bool IsHotCallsite) { // Adjust the threshold for next level of imported functions. @@ -293,44 +405,11 @@ static void computeImportForFunction( return Threshold * ImportInstrFactor; }; - bool IsHotCallsite = Edge.second.Hotness == CalleeInfo::HotnessType::Hot; + bool IsHotCallsite = + Edge.second.getHotness() == CalleeInfo::HotnessType::Hot; const auto AdjThreshold = GetAdjustedThreshold(Threshold, IsHotCallsite); - auto ExportModulePath = ResolvedCalleeSummary->modulePath(); - auto &ProcessedThreshold = ImportList[ExportModulePath][VI.getGUID()]; - /// Since the traversal of the call graph is DFS, we can revisit a function - /// a second time with a higher threshold. In this case, it is added back to - /// the worklist with the new threshold. - if (ProcessedThreshold && ProcessedThreshold >= AdjThreshold) { - DEBUG(dbgs() << "ignored! Target was already seen with Threshold " - << ProcessedThreshold << "\n"); - continue; - } - bool PreviouslyImported = ProcessedThreshold != 0; - // Mark this function as imported in this module, with the current Threshold - ProcessedThreshold = AdjThreshold; - - // Make exports in the source module. - if (ExportLists) { - auto &ExportList = (*ExportLists)[ExportModulePath]; - ExportList.insert(VI.getGUID()); - if (!PreviouslyImported) { - // This is the first time this function was exported from its source - // module, so mark all functions and globals it references as exported - // to the outside if they are defined in the same source module. - // For efficiency, we unconditionally add all the referenced GUIDs - // to the ExportList for this module, and will prune out any not - // defined in the module later in a single pass. - for (auto &Edge : ResolvedCalleeSummary->calls()) { - auto CalleeGUID = Edge.first.getGUID(); - ExportList.insert(CalleeGUID); - } - for (auto &Ref : ResolvedCalleeSummary->refs()) { - auto GUID = Ref.getGUID(); - ExportList.insert(GUID); - } - } - } + ImportCount++; // Insert the newly imported function to the worklist. Worklist.emplace_back(ResolvedCalleeSummary, AdjThreshold, VI.getGUID()); @@ -347,12 +426,18 @@ static void ComputeImportForModule( // Worklist contains the list of function imported in this module, for which // we will analyse the callees and may import further down the callgraph. SmallVector<EdgeInfo, 128> Worklist; + FunctionImporter::ImportThresholdsTy ImportThresholds; // Populate the worklist with the import for the functions in the current // module for (auto &GVSummary : DefinedGVSummaries) { +#ifndef NDEBUG + // FIXME: Change the GVSummaryMapTy to hold ValueInfo instead of GUID + // so this map look up (and possibly others) can be avoided. 
+ auto VI = Index.getValueInfo(GVSummary.first); +#endif if (!Index.isGlobalValueLive(GVSummary.second)) { - DEBUG(dbgs() << "Ignores Dead GUID: " << GVSummary.first << "\n"); + LLVM_DEBUG(dbgs() << "Ignores Dead GUID: " << VI << "\n"); continue; } auto *FuncSummary = @@ -360,10 +445,10 @@ static void ComputeImportForModule( if (!FuncSummary) // Skip import for global variables continue; - DEBUG(dbgs() << "Initialize import for " << GVSummary.first << "\n"); + LLVM_DEBUG(dbgs() << "Initialize import for " << VI << "\n"); computeImportForFunction(*FuncSummary, Index, ImportInstrLimit, DefinedGVSummaries, Worklist, ImportList, - ExportLists); + ExportLists, ImportThresholds); } // Process the newly imported functions and add callees to the worklist. @@ -371,20 +456,37 @@ static void ComputeImportForModule( auto FuncInfo = Worklist.pop_back_val(); auto *Summary = std::get<0>(FuncInfo); auto Threshold = std::get<1>(FuncInfo); - auto GUID = std::get<2>(FuncInfo); - - // Check if we later added this summary with a higher threshold. - // If so, skip this entry. - auto ExportModulePath = Summary->modulePath(); - auto &LatestProcessedThreshold = ImportList[ExportModulePath][GUID]; - if (LatestProcessedThreshold > Threshold) - continue; computeImportForFunction(*Summary, Index, Threshold, DefinedGVSummaries, - Worklist, ImportList, ExportLists); + Worklist, ImportList, ExportLists, + ImportThresholds); + } +} + +#ifndef NDEBUG +static bool isGlobalVarSummary(const ModuleSummaryIndex &Index, + GlobalValue::GUID G) { + if (const auto &VI = Index.getValueInfo(G)) { + auto SL = VI.getSummaryList(); + if (!SL.empty()) + return SL[0]->getSummaryKind() == GlobalValueSummary::GlobalVarKind; } + return false; } +static GlobalValue::GUID getGUID(GlobalValue::GUID G) { return G; } + +template <class T> +static unsigned numGlobalVarSummaries(const ModuleSummaryIndex &Index, + T &Cont) { + unsigned NumGVS = 0; + for (auto &V : Cont) + if (isGlobalVarSummary(Index, getGUID(V))) + ++NumGVS; + return NumGVS; +} +#endif + /// Compute all the import and export for every module using the Index. void llvm::ComputeCrossModuleImport( const ModuleSummaryIndex &Index, @@ -394,8 +496,8 @@ void llvm::ComputeCrossModuleImport( // For each module that has function defined, compute the import/export lists. for (auto &DefinedGVSummaries : ModuleToDefinedGVSummaries) { auto &ImportList = ImportLists[DefinedGVSummaries.first()]; - DEBUG(dbgs() << "Computing import for Module '" - << DefinedGVSummaries.first() << "'\n"); + LLVM_DEBUG(dbgs() << "Computing import for Module '" + << DefinedGVSummaries.first() << "'\n"); ComputeImportForModule(DefinedGVSummaries.second, Index, ImportList, &ExportLists); } @@ -417,32 +519,41 @@ void llvm::ComputeCrossModuleImport( } #ifndef NDEBUG - DEBUG(dbgs() << "Import/Export lists for " << ImportLists.size() - << " modules:\n"); + LLVM_DEBUG(dbgs() << "Import/Export lists for " << ImportLists.size() + << " modules:\n"); for (auto &ModuleImports : ImportLists) { auto ModName = ModuleImports.first(); auto &Exports = ExportLists[ModName]; - DEBUG(dbgs() << "* Module " << ModName << " exports " << Exports.size() - << " functions. Imports from " << ModuleImports.second.size() - << " modules.\n"); + unsigned NumGVS = numGlobalVarSummaries(Index, Exports); + LLVM_DEBUG(dbgs() << "* Module " << ModName << " exports " + << Exports.size() - NumGVS << " functions and " << NumGVS + << " vars. 
Imports from " << ModuleImports.second.size() + << " modules.\n"); for (auto &Src : ModuleImports.second) { auto SrcModName = Src.first(); - DEBUG(dbgs() << " - " << Src.second.size() << " functions imported from " - << SrcModName << "\n"); + unsigned NumGVSPerMod = numGlobalVarSummaries(Index, Src.second); + LLVM_DEBUG(dbgs() << " - " << Src.second.size() - NumGVSPerMod + << " functions imported from " << SrcModName << "\n"); + LLVM_DEBUG(dbgs() << " - " << NumGVSPerMod + << " global vars imported from " << SrcModName << "\n"); } } #endif } #ifndef NDEBUG -static void dumpImportListForModule(StringRef ModulePath, +static void dumpImportListForModule(const ModuleSummaryIndex &Index, + StringRef ModulePath, FunctionImporter::ImportMapTy &ImportList) { - DEBUG(dbgs() << "* Module " << ModulePath << " imports from " - << ImportList.size() << " modules.\n"); + LLVM_DEBUG(dbgs() << "* Module " << ModulePath << " imports from " + << ImportList.size() << " modules.\n"); for (auto &Src : ImportList) { auto SrcModName = Src.first(); - DEBUG(dbgs() << " - " << Src.second.size() << " functions imported from " - << SrcModName << "\n"); + unsigned NumGVSPerMod = numGlobalVarSummaries(Index, Src.second); + LLVM_DEBUG(dbgs() << " - " << Src.second.size() - NumGVSPerMod + << " functions imported from " << SrcModName << "\n"); + LLVM_DEBUG(dbgs() << " - " << NumGVSPerMod << " vars imported from " + << SrcModName << "\n"); } } #endif @@ -457,11 +568,11 @@ void llvm::ComputeCrossModuleImportForModule( Index.collectDefinedFunctionsForModule(ModulePath, FunctionSummaryMap); // Compute the import list for this module. - DEBUG(dbgs() << "Computing import for Module '" << ModulePath << "'\n"); + LLVM_DEBUG(dbgs() << "Computing import for Module '" << ModulePath << "'\n"); ComputeImportForModule(FunctionSummaryMap, Index, ImportList); #ifndef NDEBUG - dumpImportListForModule(ModulePath, ImportList); + dumpImportListForModule(Index, ModulePath, ImportList); #endif } @@ -483,18 +594,18 @@ void llvm::ComputeCrossModuleImportForModuleFromIndex( // e.g. record required linkage changes. if (Summary->modulePath() == ModulePath) continue; - // Doesn't matter what value we plug in to the map, just needs an entry - // to provoke importing by thinBackend. - ImportList[Summary->modulePath()][GUID] = 1; + // Add an entry to provoke importing by thinBackend. + ImportList[Summary->modulePath()].insert(GUID); } #ifndef NDEBUG - dumpImportListForModule(ModulePath, ImportList); + dumpImportListForModule(Index, ModulePath, ImportList); #endif } void llvm::computeDeadSymbols( ModuleSummaryIndex &Index, - const DenseSet<GlobalValue::GUID> &GUIDPreservedSymbols) { + const DenseSet<GlobalValue::GUID> &GUIDPreservedSymbols, + function_ref<PrevailingType(GlobalValue::GUID)> isPrevailing) { assert(!Index.withGlobalValueDeadStripping()); if (!ComputeDead) return; @@ -513,17 +624,18 @@ void llvm::computeDeadSymbols( } // Add values flagged in the index as live roots to the worklist. - for (const auto &Entry : Index) + for (const auto &Entry : Index) { + auto VI = Index.getValueInfo(Entry); for (auto &S : Entry.second.SummaryList) if (S->isLive()) { - DEBUG(dbgs() << "Live root: " << Entry.first << "\n"); - Worklist.push_back(ValueInfo(&Entry)); + LLVM_DEBUG(dbgs() << "Live root: " << VI << "\n"); + Worklist.push_back(VI); ++LiveSymbols; break; } + } // Make value live and add it to the worklist if it was not live before. 
- // FIXME: we should only make the prevailing copy live here
 auto visit = [&](ValueInfo VI) {
 // FIXME: If we knew which edges were created for indirect call profiles,
 // we could skip them here. Any that are live should be reached via
@@ -539,6 +651,28 @@ void llvm::computeDeadSymbols(
 for (auto &S : VI.getSummaryList())
 if (S->isLive())
 return;
+
+ // We only keep live symbols that are known to be non-prevailing if any are
+ // available_externally. Those symbols are discarded later in the
+ // EliminateAvailableExternally pass and setting them to not-live breaks
+ // downstream users of liveness information (PR36483).
+ if (isPrevailing(VI.getGUID()) == PrevailingType::No) {
+ bool AvailableExternally = false;
+ bool Interposable = false;
+ for (auto &S : VI.getSummaryList()) {
+ if (S->linkage() == GlobalValue::AvailableExternallyLinkage)
+ AvailableExternally = true;
+ else if (GlobalValue::isInterposableLinkage(S->linkage()))
+ Interposable = true;
+ }
+
+ if (!AvailableExternally)
+ return;
+
+ if (Interposable)
+ report_fatal_error("Interposable and available_externally symbol");
+ }
+
 for (auto &S : VI.getSummaryList())
 S->setLive(true);
 ++LiveSymbols;
@@ -549,6 +683,8 @@ void llvm::computeDeadSymbols(
 auto VI = Worklist.pop_back_val();
 for (auto &Summary : VI.getSummaryList()) {
 GlobalValueSummary *Base = Summary->getBaseObject();
+ // Set base value live in case it is an alias.
+ Base->setLive(true);
 for (auto Ref : Base->refs())
 visit(Ref);
 if (auto *FS = dyn_cast<FunctionSummary>(Base))
@@ -559,8 +695,8 @@ void llvm::computeDeadSymbols(
 Index.setWithGlobalValueDeadStripping();
 unsigned DeadSymbols = Index.size() - LiveSymbols;
- DEBUG(dbgs() << LiveSymbols << " symbols Live, and " << DeadSymbols
- << " symbols Dead \n");
+ LLVM_DEBUG(dbgs() << LiveSymbols << " symbols Live, and " << DeadSymbols
+ << " symbols Dead \n");
 NumDeadSymbols += DeadSymbols;
 NumLiveSymbols += LiveSymbols;
}
@@ -581,47 +717,66 @@ void llvm::gatherImportedSummariesForModule(
 const auto &DefinedGVSummaries =
 ModuleToDefinedGVSummaries.lookup(ILI.first());
 for (auto &GI : ILI.second) {
- const auto &DS = DefinedGVSummaries.find(GI.first);
+ const auto &DS = DefinedGVSummaries.find(GI);
 assert(DS != DefinedGVSummaries.end() &&
 "Expected a defined summary for imported global value");
- SummariesForIndex[GI.first] = DS->second;
+ SummariesForIndex[GI] = DS->second;
 }
 }
}

/// Emit the files \p ModulePath will import from into \p OutputFilename.
-std::error_code
-llvm::EmitImportsFiles(StringRef ModulePath, StringRef OutputFilename,
- const FunctionImporter::ImportMapTy &ModuleImports) {
+std::error_code llvm::EmitImportsFiles(
+ StringRef ModulePath, StringRef OutputFilename,
+ const std::map<std::string, GVSummaryMapTy> &ModuleToSummariesForIndex) {
 std::error_code EC;
 raw_fd_ostream ImportsOS(OutputFilename, EC, sys::fs::OpenFlags::F_None);
 if (EC)
 return EC;
- for (auto &ILI : ModuleImports)
- ImportsOS << ILI.first() << "\n";
+ for (auto &ILI : ModuleToSummariesForIndex)
+ // The ModuleToSummariesForIndex map includes an entry for the current
+ // Module (needed for writing out the index files). We don't want to
+ // include it in the imports file, however, so filter it out.
+ if (ILI.first != ModulePath)
+ ImportsOS << ILI.first << "\n";
 return std::error_code();
}

+bool llvm::convertToDeclaration(GlobalValue &GV) {
+ LLVM_DEBUG(dbgs() << "Converting to a declaration: `" << GV.getName()
+ << "\n");
+ if (Function *F = dyn_cast<Function>(&GV)) {
+ F->deleteBody();
+ F->clearMetadata();
+ F->setComdat(nullptr);
+ } else if (GlobalVariable *V = dyn_cast<GlobalVariable>(&GV)) {
+ V->setInitializer(nullptr);
+ V->setLinkage(GlobalValue::ExternalLinkage);
+ V->clearMetadata();
+ V->setComdat(nullptr);
+ } else {
+ GlobalValue *NewGV;
+ if (GV.getValueType()->isFunctionTy())
+ NewGV =
+ Function::Create(cast<FunctionType>(GV.getValueType()),
+ GlobalValue::ExternalLinkage, "", GV.getParent());
+ else
+ NewGV =
+ new GlobalVariable(*GV.getParent(), GV.getValueType(),
+ /*isConstant*/ false, GlobalValue::ExternalLinkage,
+ /*init*/ nullptr, "",
+ /*insertbefore*/ nullptr, GV.getThreadLocalMode(),
+ GV.getType()->getAddressSpace());
+ NewGV->takeName(&GV);
+ GV.replaceAllUsesWith(NewGV);
+ return false;
+ }
+ return true;
+}
+
/// Fixup WeakForLinker linkages in \p TheModule based on summary analysis.
void llvm::thinLTOResolveWeakForLinkerModule(
 Module &TheModule, const GVSummaryMapTy &DefinedGlobals) {
- auto ConvertToDeclaration = [](GlobalValue &GV) {
- DEBUG(dbgs() << "Converting to a declaration: `" << GV.getName() << "\n");
- if (Function *F = dyn_cast<Function>(&GV)) {
- F->deleteBody();
- F->clearMetadata();
- } else if (GlobalVariable *V = dyn_cast<GlobalVariable>(&GV)) {
- V->setInitializer(nullptr);
- V->setLinkage(GlobalValue::ExternalLinkage);
- V->clearMetadata();
- } else
- // For now we don't resolve or drop aliases. Once we do we'll
- // need to add support here for creating either a function or
- // variable declaration, and return the new GlobalValue* for
- // the caller to use.
- llvm_unreachable("Expected function or variable");
- };
-
 auto updateLinkage = [&](GlobalValue &GV) {
 // See if the global summary analysis computed a new resolved linkage.
 const auto &GS = DefinedGlobals.find(GV.getGUID());
@@ -651,11 +806,23 @@ void llvm::thinLTOResolveWeakForLinkerModule(
 // interposable property and possibly get inlined. Simply drop
 // the definition in that case.
 if (GlobalValue::isAvailableExternallyLinkage(NewLinkage) &&
- GlobalValue::isInterposableLinkage(GV.getLinkage()))
- ConvertToDeclaration(GV);
- else {
- DEBUG(dbgs() << "ODR fixing up linkage for `" << GV.getName() << "` from "
- << GV.getLinkage() << " to " << NewLinkage << "\n");
+ GlobalValue::isInterposableLinkage(GV.getLinkage())) {
+ if (!convertToDeclaration(GV))
+ // FIXME: Change this to collect replaced GVs and later erase
+ // them from the parent module once thinLTOResolveWeakForLinkerGUID is
+ // changed to enable this for aliases.
+ llvm_unreachable("Expected GV to be converted");
+ } else {
+ // If the original symbol has global unnamed addr and linkonce_odr linkage,
+ // it should be an auto hide symbol. Add hidden visibility to the symbol to
+ // preserve the property.
+ if (GV.hasLinkOnceODRLinkage() && GV.hasGlobalUnnamedAddr() &&
+ NewLinkage == GlobalValue::WeakODRLinkage)
+ GV.setVisibility(GlobalValue::HiddenVisibility);
+
+ LLVM_DEBUG(dbgs() << "ODR fixing up linkage for `" << GV.getName()
+ << "` from " << GV.getLinkage() << " to " << NewLinkage
+ << "\n");
 GV.setLinkage(NewLinkage);
 }
 // Remove declarations from comdats, including available_externally
@@ -732,9 +899,9 @@ static Function *replaceAliasWithAliasee(Module *SrcModule, GlobalAlias *GA) {
 // index.
Expected<bool> FunctionImporter::importFunctions( Module &DestModule, const FunctionImporter::ImportMapTy &ImportList) { - DEBUG(dbgs() << "Starting import for Module " - << DestModule.getModuleIdentifier() << "\n"); - unsigned ImportedCount = 0; + LLVM_DEBUG(dbgs() << "Starting import for Module " + << DestModule.getModuleIdentifier() << "\n"); + unsigned ImportedCount = 0, ImportedGVCount = 0; IRMover Mover(DestModule); // Do the actual import of functions now, one Module at a time @@ -766,9 +933,9 @@ Expected<bool> FunctionImporter::importFunctions( continue; auto GUID = F.getGUID(); auto Import = ImportGUIDs.count(GUID); - DEBUG(dbgs() << (Import ? "Is" : "Not") << " importing function " << GUID - << " " << F.getName() << " from " - << SrcModule->getSourceFileName() << "\n"); + LLVM_DEBUG(dbgs() << (Import ? "Is" : "Not") << " importing function " + << GUID << " " << F.getName() << " from " + << SrcModule->getSourceFileName() << "\n"); if (Import) { if (Error Err = F.materialize()) return std::move(Err); @@ -788,13 +955,13 @@ Expected<bool> FunctionImporter::importFunctions( continue; auto GUID = GV.getGUID(); auto Import = ImportGUIDs.count(GUID); - DEBUG(dbgs() << (Import ? "Is" : "Not") << " importing global " << GUID - << " " << GV.getName() << " from " - << SrcModule->getSourceFileName() << "\n"); + LLVM_DEBUG(dbgs() << (Import ? "Is" : "Not") << " importing global " + << GUID << " " << GV.getName() << " from " + << SrcModule->getSourceFileName() << "\n"); if (Import) { if (Error Err = GV.materialize()) return std::move(Err); - GlobalsToImport.insert(&GV); + ImportedGVCount += GlobalsToImport.insert(&GV); } } for (GlobalAlias &GA : SrcModule->aliases()) { @@ -802,9 +969,9 @@ Expected<bool> FunctionImporter::importFunctions( continue; auto GUID = GA.getGUID(); auto Import = ImportGUIDs.count(GUID); - DEBUG(dbgs() << (Import ? "Is" : "Not") << " importing alias " << GUID - << " " << GA.getName() << " from " - << SrcModule->getSourceFileName() << "\n"); + LLVM_DEBUG(dbgs() << (Import ? "Is" : "Not") << " importing alias " + << GUID << " " << GA.getName() << " from " + << SrcModule->getSourceFileName() << "\n"); if (Import) { if (Error Err = GA.materialize()) return std::move(Err); @@ -813,9 +980,9 @@ Expected<bool> FunctionImporter::importFunctions( if (Error Err = Base->materialize()) return std::move(Err); auto *Fn = replaceAliasWithAliasee(SrcModule.get(), &GA); - DEBUG(dbgs() << "Is importing aliasee fn " << Base->getGUID() - << " " << Base->getName() << " from " - << SrcModule->getSourceFileName() << "\n"); + LLVM_DEBUG(dbgs() << "Is importing aliasee fn " << Base->getGUID() + << " " << Base->getName() << " from " + << SrcModule->getSourceFileName() << "\n"); if (EnableImportMetadata) { // Add 'thinlto_src_module' metadata for statistics and debugging. 
Fn->setMetadata(
@@ -851,10 +1018,15 @@ Expected<bool> FunctionImporter::importFunctions(
 NumImportedModules++;
 }

- NumImportedFunctions += ImportedCount;
+ NumImportedFunctions += (ImportedCount - ImportedGVCount);
+ NumImportedGlobalVars += ImportedGVCount;

- DEBUG(dbgs() << "Imported " << ImportedCount << " functions for Module "
- << DestModule.getModuleIdentifier() << "\n");
+ LLVM_DEBUG(dbgs() << "Imported " << ImportedCount - ImportedGVCount
+ << " functions for Module "
+ << DestModule.getModuleIdentifier() << "\n");
+ LLVM_DEBUG(dbgs() << "Imported " << ImportedGVCount
+ << " global variables for Module "
+ << DestModule.getModuleIdentifier() << "\n");

 return ImportedCount;
}
diff --git a/lib/Transforms/IPO/GlobalOpt.cpp b/lib/Transforms/IPO/GlobalOpt.cpp
index 4bb2984e3b47..1af7e6894777 100644
--- a/lib/Transforms/IPO/GlobalOpt.cpp
+++ b/lib/Transforms/IPO/GlobalOpt.cpp
@@ -17,14 +17,16 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
-#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Transforms/Utils/Local.h"
 #include "llvm/BinaryFormat/Dwarf.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
@@ -55,6 +57,7 @@
 #include "llvm/Pass.h"
 #include "llvm/Support/AtomicOrdering.h"
 #include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
@@ -63,7 +66,6 @@
 #include "llvm/Transforms/Utils/CtorUtils.h"
 #include "llvm/Transforms/Utils/Evaluator.h"
 #include "llvm/Transforms/Utils/GlobalStatus.h"
-#include "llvm/Transforms/Utils/Local.h"
 #include <cassert>
 #include <cstdint>
 #include <utility>
@@ -88,6 +90,21 @@ STATISTIC(NumNestRemoved , "Number of nest attributes removed");
 STATISTIC(NumAliasesResolved, "Number of global aliases resolved");
 STATISTIC(NumAliasesRemoved, "Number of global aliases eliminated");
 STATISTIC(NumCXXDtorsRemoved, "Number of global C++ destructors removed");
+STATISTIC(NumInternalFunc, "Number of internal functions");
+STATISTIC(NumColdCC, "Number of functions marked coldcc");
+
+static cl::opt<bool>
+ EnableColdCCStressTest("enable-coldcc-stress-test",
+ cl::desc("Enable stress test of coldcc by adding "
+ "calling conv to all internal functions."),
+ cl::init(false), cl::Hidden);
+
+static cl::opt<int> ColdCCRelFreq(
+ "coldcc-rel-freq", cl::Hidden, cl::init(2), cl::ZeroOrMore,
+ cl::desc(
+ "Maximum block frequency, expressed as a percentage of caller's "
+ "entry frequency, for a call site to be considered cold for enabling "
+ "coldcc"));

/// Is this global variable possibly used by a leak checker as a root? If so,
/// we might not really want to eliminate the stores to it.
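To make the ColdCCRelFreq option above concrete: with its default of 2, a call site is treated as cold only when its block frequency is under 2% of the caller's entry frequency. A minimal sketch of that comparison, using plain integers in place of BlockFrequencyInfo and BranchProbability (the helper name and the numbers are illustrative, not part of the pass):

#include <cstdint>

// Cross-multiplied form of: CallSiteFreq < CallerEntryFreq * (RelFreq / 100).
static bool isColdAtFreq(uint64_t CallSiteFreq, uint64_t CallerEntryFreq,
                         uint64_t RelFreqPercent = 2) {
  return CallSiteFreq * 100 < CallerEntryFreq * RelFreqPercent;
}

// With an entry frequency of 1000: a site at frequency 19 is cold
// (1900 < 2000), while a site at frequency 20 is not (2000 < 2000 is false).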
@@ -483,7 +500,6 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { StartAlignment = DL.getABITypeAlignment(GV->getType()); if (StructType *STy = dyn_cast<StructType>(Ty)) { - uint64_t FragmentOffset = 0; unsigned NumElements = STy->getNumElements(); NewGlobals.reserve(NumElements); const StructLayout &Layout = *DL.getStructLayout(STy); @@ -509,10 +525,9 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { NGV->setAlignment(NewAlign); // Copy over the debug info for the variable. - FragmentOffset = alignTo(FragmentOffset, NewAlign); - uint64_t Size = DL.getTypeSizeInBits(NGV->getValueType()); - transferSRADebugInfo(GV, NGV, FragmentOffset, Size, NumElements); - FragmentOffset += Size; + uint64_t Size = DL.getTypeAllocSizeInBits(NGV->getValueType()); + uint64_t FragmentOffsetInBits = Layout.getElementOffsetInBits(i); + transferSRADebugInfo(GV, NGV, FragmentOffsetInBits, Size, NumElements); } } else if (SequentialType *STy = dyn_cast<SequentialType>(Ty)) { unsigned NumElements = STy->getNumElements(); @@ -522,7 +537,7 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { auto ElTy = STy->getElementType(); uint64_t EltSize = DL.getTypeAllocSize(ElTy); unsigned EltAlign = DL.getABITypeAlignment(ElTy); - uint64_t FragmentSizeInBits = DL.getTypeSizeInBits(ElTy); + uint64_t FragmentSizeInBits = DL.getTypeAllocSizeInBits(ElTy); for (unsigned i = 0, e = NumElements; i != e; ++i) { Constant *In = Init->getAggregateElement(i); assert(In && "Couldn't get element of initializer?"); @@ -551,7 +566,7 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { if (NewGlobals.empty()) return nullptr; - DEBUG(dbgs() << "PERFORMING GLOBAL SRA ON: " << *GV << "\n"); + LLVM_DEBUG(dbgs() << "PERFORMING GLOBAL SRA ON: " << *GV << "\n"); Constant *NullInt =Constant::getNullValue(Type::getInt32Ty(GV->getContext())); @@ -621,7 +636,13 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { /// reprocessing them. static bool AllUsesOfValueWillTrapIfNull(const Value *V, SmallPtrSetImpl<const PHINode*> &PHIs) { - for (const User *U : V->users()) + for (const User *U : V->users()) { + if (const Instruction *I = dyn_cast<Instruction>(U)) { + // If null pointer is considered valid, then all uses are non-trapping. + // Non address-space 0 globals have already been pruned by the caller. + if (NullPointerIsDefined(I->getFunction())) + return false; + } if (isa<LoadInst>(U)) { // Will trap. } else if (const StoreInst *SI = dyn_cast<StoreInst>(U)) { @@ -655,7 +676,7 @@ static bool AllUsesOfValueWillTrapIfNull(const Value *V, //cerr << "NONTRAPPING USE: " << *U; return false; } - + } return true; } @@ -682,6 +703,10 @@ static bool OptimizeAwayTrappingUsesOfValue(Value *V, Constant *NewV) { bool Changed = false; for (auto UI = V->user_begin(), E = V->user_end(); UI != E; ) { Instruction *I = cast<Instruction>(*UI++); + // Uses are non-trapping if null pointer is considered valid. + // Non address-space 0 globals are already pruned by the caller. 
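 // (NullPointerIsDefined(F) is roughly: F carries the "null-pointer-is-valid"
 // attribute, or the queried address space is non-zero, where address 0 may
 // legitimately be dereferenceable; in either case a null access cannot be
 // assumed to trap, hence the early bail-out that follows.)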
+ if (NullPointerIsDefined(I->getFunction())) + return false; if (LoadInst *LI = dyn_cast<LoadInst>(I)) { LI->setOperand(0, NewV); Changed = true; @@ -783,7 +808,8 @@ static bool OptimizeAwayTrappingUsesOfLoads(GlobalVariable *GV, Constant *LV, } if (Changed) { - DEBUG(dbgs() << "OPTIMIZED LOADS FROM STORED ONCE POINTER: " << *GV << "\n"); + LLVM_DEBUG(dbgs() << "OPTIMIZED LOADS FROM STORED ONCE POINTER: " << *GV + << "\n"); ++NumGlobUses; } @@ -797,7 +823,7 @@ static bool OptimizeAwayTrappingUsesOfLoads(GlobalVariable *GV, Constant *LV, CleanupConstantGlobalUsers(GV, nullptr, DL, TLI); } if (GV->use_empty()) { - DEBUG(dbgs() << " *** GLOBAL NOW DEAD!\n"); + LLVM_DEBUG(dbgs() << " *** GLOBAL NOW DEAD!\n"); Changed = true; GV->eraseFromParent(); ++NumDeleted; @@ -833,7 +859,8 @@ static GlobalVariable * OptimizeGlobalAddressOfMalloc(GlobalVariable *GV, CallInst *CI, Type *AllocTy, ConstantInt *NElements, const DataLayout &DL, TargetLibraryInfo *TLI) { - DEBUG(errs() << "PROMOTING GLOBAL: " << *GV << " CALL = " << *CI << '\n'); + LLVM_DEBUG(errs() << "PROMOTING GLOBAL: " << *GV << " CALL = " << *CI + << '\n'); Type *GlobalType; if (NElements->getZExtValue() == 1) @@ -1269,7 +1296,8 @@ static void RewriteUsesOfLoadForHeapSRoA(LoadInst *Load, static GlobalVariable *PerformHeapAllocSRoA(GlobalVariable *GV, CallInst *CI, Value *NElems, const DataLayout &DL, const TargetLibraryInfo *TLI) { - DEBUG(dbgs() << "SROA HEAP ALLOC: " << *GV << " MALLOC = " << *CI << '\n'); + LLVM_DEBUG(dbgs() << "SROA HEAP ALLOC: " << *GV << " MALLOC = " << *CI + << '\n'); Type *MAT = getMallocAllocatedType(CI, TLI); StructType *STy = cast<StructType>(MAT); @@ -1566,7 +1594,10 @@ static bool optimizeOnceStoredGlobal(GlobalVariable *GV, Value *StoredOnceVal, // users of the loaded value (often calls and loads) that would trap if the // value was null. if (GV->getInitializer()->getType()->isPointerTy() && - GV->getInitializer()->isNullValue()) { + GV->getInitializer()->isNullValue() && + !NullPointerIsDefined( + nullptr /* F */, + GV->getInitializer()->getType()->getPointerAddressSpace())) { if (Constant *SOVC = dyn_cast<Constant>(StoredOnceVal)) { if (GV->getInitializer()->getType() != SOVC->getType()) SOVC = ConstantExpr::getBitCast(SOVC, GV->getInitializer()->getType()); @@ -1608,7 +1639,7 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { if (!isa<LoadInst>(U) && !isa<StoreInst>(U)) return false; - DEBUG(dbgs() << " *** SHRINKING TO BOOL: " << *GV << "\n"); + LLVM_DEBUG(dbgs() << " *** SHRINKING TO BOOL: " << *GV << "\n"); // Create the new global, initializing it to false. 
GlobalVariable *NewGV = new GlobalVariable(Type::getInt1Ty(GV->getContext()), @@ -1652,15 +1683,11 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { // val * (ValOther - ValInit) + ValInit: // DW_OP_deref DW_OP_constu <ValMinus> // DW_OP_mul DW_OP_constu <ValInit> DW_OP_plus DW_OP_stack_value - E = DIExpression::get(NewGV->getContext(), - {dwarf::DW_OP_deref, - dwarf::DW_OP_constu, - ValMinus, - dwarf::DW_OP_mul, - dwarf::DW_OP_constu, - ValInit, - dwarf::DW_OP_plus, - dwarf::DW_OP_stack_value}); + SmallVector<uint64_t, 12> Ops = { + dwarf::DW_OP_deref, dwarf::DW_OP_constu, ValMinus, + dwarf::DW_OP_mul, dwarf::DW_OP_constu, ValInit, + dwarf::DW_OP_plus}; + E = DIExpression::prependOpcodes(E, Ops, DIExpression::WithStackValue); DIGlobalVariableExpression *DGVE = DIGlobalVariableExpression::get(NewGV->getContext(), DGV, E); NewGV->addDebugInfo(DGVE); @@ -1732,8 +1759,8 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { return true; } -static bool deleteIfDead(GlobalValue &GV, - SmallSet<const Comdat *, 8> &NotDiscardableComdats) { +static bool deleteIfDead( + GlobalValue &GV, SmallPtrSetImpl<const Comdat *> &NotDiscardableComdats) { GV.removeDeadConstantUsers(); if (!GV.isDiscardableIfUnused() && !GV.isDeclaration()) @@ -1751,7 +1778,7 @@ static bool deleteIfDead(GlobalValue &GV, if (!Dead) return false; - DEBUG(dbgs() << "GLOBAL DEAD: " << GV << "\n"); + LLVM_DEBUG(dbgs() << "GLOBAL DEAD: " << GV << "\n"); GV.eraseFromParent(); ++NumDeleted; return true; @@ -1917,7 +1944,7 @@ static bool processInternalGlobal( LookupDomTree)) { const DataLayout &DL = GV->getParent()->getDataLayout(); - DEBUG(dbgs() << "LOCALIZING GLOBAL: " << *GV << "\n"); + LLVM_DEBUG(dbgs() << "LOCALIZING GLOBAL: " << *GV << "\n"); Instruction &FirstI = const_cast<Instruction&>(*GS.AccessingFunction ->getEntryBlock().begin()); Type *ElemTy = GV->getValueType(); @@ -1938,7 +1965,7 @@ static bool processInternalGlobal( // If the global is never loaded (but may be stored to), it is dead. // Delete it now. if (!GS.IsLoaded) { - DEBUG(dbgs() << "GLOBAL NEVER LOADED: " << *GV << "\n"); + LLVM_DEBUG(dbgs() << "GLOBAL NEVER LOADED: " << *GV << "\n"); bool Changed; if (isLeakCheckerRoot(GV)) { @@ -1960,7 +1987,7 @@ static bool processInternalGlobal( } if (GS.StoredType <= GlobalStatus::InitializerStored) { - DEBUG(dbgs() << "MARKING CONSTANT: " << *GV << "\n"); + LLVM_DEBUG(dbgs() << "MARKING CONSTANT: " << *GV << "\n"); GV->setConstant(true); // Clean up any obviously simplifiable users now. @@ -1968,8 +1995,8 @@ static bool processInternalGlobal( // If the global is dead now, just nuke it. 
if (GV->use_empty()) {
- DEBUG(dbgs() << " *** Marking constant allowed us to simplify "
- << "all users and delete global!\n");
+ LLVM_DEBUG(dbgs() << " *** Marking constant allowed us to simplify "
+ << "all users and delete global!\n");
 GV->eraseFromParent();
 ++NumDeleted;
 return true;
@@ -1997,8 +2024,8 @@ static bool processInternalGlobal(
 CleanupConstantGlobalUsers(GV, GV->getInitializer(), DL, TLI);
 if (GV->use_empty()) {
- DEBUG(dbgs() << " *** Substituting initializer allowed us to "
- << "simplify all users and delete global!\n");
+ LLVM_DEBUG(dbgs() << " *** Substituting initializer allowed us to "
+ << "simplify all users and delete global!\n");
 GV->eraseFromParent();
 ++NumDeleted;
 }
@@ -2097,20 +2124,142 @@ static void RemoveNestAttribute(Function *F) {
/// idea here is that we don't want to mess with the convention if the user
/// explicitly requested something with performance implications like coldcc,
/// GHC, or anyregcc.
-static bool isProfitableToMakeFastCC(Function *F) {
+static bool hasChangeableCC(Function *F) {
 CallingConv::ID CC = F->getCallingConv();
+
 // FIXME: Is it worth transforming x86_stdcallcc and x86_fastcallcc?
- return CC == CallingConv::C || CC == CallingConv::X86_ThisCall;
+ if (CC != CallingConv::C && CC != CallingConv::X86_ThisCall)
+ return false;
+
+ // FIXME: Change CC for the whole chain of musttail calls when possible.
+ //
+ // Can't change CC of the function that either has musttail calls, or is a
+ // musttail callee itself
+ for (User *U : F->users()) {
+ if (isa<BlockAddress>(U))
+ continue;
+ CallInst* CI = dyn_cast<CallInst>(U);
+ if (!CI)
+ continue;
+
+ if (CI->isMustTailCall())
+ return false;
+ }
+
+ for (BasicBlock &BB : *F)
+ if (BB.getTerminatingMustTailCall())
+ return false;
+
+ return true;
+}
+
+/// Return true if the block containing the call site has a BlockFrequency of
+/// less than ColdCCRelFreq% of the entry block.
+static bool isColdCallSite(CallSite CS, BlockFrequencyInfo &CallerBFI) {
+ const BranchProbability ColdProb(ColdCCRelFreq, 100);
+ auto CallSiteBB = CS.getInstruction()->getParent();
+ auto CallSiteFreq = CallerBFI.getBlockFreq(CallSiteBB);
+ auto CallerEntryFreq =
+ CallerBFI.getBlockFreq(&(CS.getCaller()->getEntryBlock()));
+ return CallSiteFreq < CallerEntryFreq * ColdProb;
+}
+
+// This function checks if the input function F is cold at all call sites. It
+// also looks at each call site's containing function, returning false if the
+// caller function contains other non-cold calls. The input vector AllCallsCold
+// contains a list of functions that only have call sites in cold blocks.
+static bool
+isValidCandidateForColdCC(Function &F,
+ function_ref<BlockFrequencyInfo &(Function &)> GetBFI,
+ const std::vector<Function *> &AllCallsCold) {
+
+ if (F.user_empty())
+ return false;
+
+ for (User *U : F.users()) {
+ if (isa<BlockAddress>(U))
+ continue;
+
+ CallSite CS(cast<Instruction>(U));
+ Function *CallerFunc = CS.getInstruction()->getParent()->getParent();
+ BlockFrequencyInfo &CallerBFI = GetBFI(*CallerFunc);
+ if (!isColdCallSite(CS, CallerBFI))
+ return false;
+ auto It = std::find(AllCallsCold.begin(), AllCallsCold.end(), CallerFunc);
+ if (It == AllCallsCold.end())
+ return false;
+ }
+ return true;
+}
+
+static void changeCallSitesToColdCC(Function *F) {
+ for (User *U : F->users()) {
+ if (isa<BlockAddress>(U))
+ continue;
+ CallSite CS(cast<Instruction>(U));
+ CS.setCallingConv(CallingConv::Cold);
+ }
+}
+
+// This function iterates over all the call instructions in the input Function
+// and checks that all call sites are in cold blocks and are allowed to use the
+// coldcc calling convention.
+static bool
+hasOnlyColdCalls(Function &F,
+ function_ref<BlockFrequencyInfo &(Function &)> GetBFI) {
+ for (BasicBlock &BB : F) {
+ for (Instruction &I : BB) {
+ if (CallInst *CI = dyn_cast<CallInst>(&I)) {
+ CallSite CS(cast<Instruction>(CI));
+ // Skip over inline asm instructions since they aren't function calls.
+ if (CI->isInlineAsm())
+ continue;
+ Function *CalledFn = CI->getCalledFunction();
+ if (!CalledFn)
+ return false;
+ if (!CalledFn->hasLocalLinkage())
+ return false;
+ // Skip over intrinsics since they won't remain as function calls.
+ if (CalledFn->getIntrinsicID() != Intrinsic::not_intrinsic)
+ continue;
+ // Check if it's valid to use coldcc calling convention.
+ if (!hasChangeableCC(CalledFn) || CalledFn->isVarArg() ||
+ CalledFn->hasAddressTaken())
+ return false;
+ BlockFrequencyInfo &CallerBFI = GetBFI(F);
+ if (!isColdCallSite(CS, CallerBFI))
+ return false;
+ }
+ }
+ }
+ return true;
}

static bool
OptimizeFunctions(Module &M, TargetLibraryInfo *TLI,
+ function_ref<TargetTransformInfo &(Function &)> GetTTI,
+ function_ref<BlockFrequencyInfo &(Function &)> GetBFI,
 function_ref<DominatorTree &(Function &)> LookupDomTree,
- SmallSet<const Comdat *, 8> &NotDiscardableComdats) {
+ SmallPtrSetImpl<const Comdat *> &NotDiscardableComdats) {
+ bool Changed = false;
+
+ std::vector<Function *> AllCallsCold;
+ for (Module::iterator FI = M.begin(), E = M.end(); FI != E;) {
+ Function *F = &*FI++;
+ if (hasOnlyColdCalls(*F, GetBFI))
+ AllCallsCold.push_back(F);
+ }
+
 // Optimize functions.
 for (Module::iterator FI = M.begin(), E = M.end(); FI != E; ) {
 Function *F = &*FI++;
+
+ // Don't perform global opt pass on naked functions; we don't want fast
+ // calling conventions for naked functions.
+ if (F->hasFnAttribute(Attribute::Naked))
+ continue;
+
 // Functions without names cannot be referenced outside this module.
 if (!F->hasName() && !F->isDeclaration() && !F->hasLocalLinkage())
 F->setLinkage(GlobalValue::InternalLinkage);
@@ -2142,7 +2291,25 @@ OptimizeFunctions(Module &M, TargetLibraryInfo *TLI,
 if (!F->hasLocalLinkage())
 continue;
-
- if (isProfitableToMakeFastCC(F) && !F->isVarArg() &&
+
+ if (hasChangeableCC(F) && !F->isVarArg() && !F->hasAddressTaken()) {
+ NumInternalFunc++;
+ TargetTransformInfo &TTI = GetTTI(*F);
+ // Change the calling convention to coldcc if either stress testing is
+ // enabled or the target would like to use coldcc on functions which are
+ // cold at all call sites and the callers contain no other non-coldcc
+ // calls.
+ if (EnableColdCCStressTest ||
+ (isValidCandidateForColdCC(*F, GetBFI, AllCallsCold) &&
+ TTI.useColdCCForColdCall(*F))) {
+ F->setCallingConv(CallingConv::Cold);
+ changeCallSitesToColdCC(F);
+ Changed = true;
+ NumColdCC++;
+ }
+ }
+
+ if (hasChangeableCC(F) && !F->isVarArg() &&
 !F->hasAddressTaken()) {
 // If this function has a calling convention worth changing, is not a
 // varargs function, and is only called directly, promote it to use the
@@ -2168,7 +2335,7 @@ OptimizeFunctions(Module &M, TargetLibraryInfo *TLI,
static bool
OptimizeGlobalVars(Module &M, TargetLibraryInfo *TLI,
 function_ref<DominatorTree &(Function &)> LookupDomTree,
- SmallSet<const Comdat *, 8> &NotDiscardableComdats) {
+ SmallPtrSetImpl<const Comdat *> &NotDiscardableComdats) {
 bool Changed = false;

 for (Module::global_iterator GVI = M.global_begin(), E = M.global_end();
@@ -2254,6 +2421,131 @@ static void CommitValueTo(Constant *Val, Constant *Addr) {
 GV->setInitializer(EvaluateStoreInto(GV->getInitializer(), Val, CE, 2));
}

+/// Given a map of address -> value, where addresses are expected to be some form
+/// of either a global or a constant GEP, set the initializer for the address to
+/// be the value. This performs mostly the same function as CommitValueTo()
+/// and EvaluateStoreInto() but is optimized to be more efficient for the common
+/// case where the set of addresses are GEPs sharing the same underlying global,
+/// processing the GEPs in batches rather than individually.
+///
+/// To give an example, consider the following C++ code adapted from the clang
+/// regression tests:
+/// struct S {
+/// int n = 10;
+/// int m = 2 * n;
+/// S(int a) : n(a) {}
+/// };
+///
+/// template<typename T>
+/// struct U {
+/// T *r = &q;
+/// T q = 42;
+/// U *p = this;
+/// };
+///
+/// U<S> e;
+///
+/// The global static constructor for 'e' will need to initialize 'r' and 'p' of
+/// the outer struct, while also initializing the inner 'q' struct's 'n' and 'm'
+/// members. This batch algorithm will simply use the general CommitValueTo()
+/// method to handle the complex nested S struct initialization of 'q', before
+/// processing the outermost members in a single batch. Using CommitValueTo() to
+/// handle members in the outer struct is inefficient when the struct/array is
+/// very large as we end up creating and destroying constant arrays for each
+/// initialization.
+/// For the above case, we expect the following IR to be generated:
+///
+/// %struct.U = type { %struct.S*, %struct.S, %struct.U* }
+/// %struct.S = type { i32, i32 }
+/// @e = global %struct.U { %struct.S* gep inbounds (%struct.U, %struct.U* @e,
+/// i64 0, i32 1),
+/// %struct.S { i32 42, i32 84 }, %struct.U* @e }
+/// The %struct.S { i32 42, i32 84 } inner initializer is treated as a complex
+/// constant expression, while the other two elements of @e are "simple".
+static void BatchCommitValueTo(const DenseMap<Constant*, Constant*> &Mem) {
+ SmallVector<std::pair<GlobalVariable*, Constant*>, 32> GVs;
+ SmallVector<std::pair<ConstantExpr*, Constant*>, 32> ComplexCEs;
+ SmallVector<std::pair<ConstantExpr*, Constant*>, 32> SimpleCEs;
+ SimpleCEs.reserve(Mem.size());
+
+ for (const auto &I : Mem) {
+ if (auto *GV = dyn_cast<GlobalVariable>(I.first)) {
+ GVs.push_back(std::make_pair(GV, I.second));
+ } else {
+ ConstantExpr *GEP = cast<ConstantExpr>(I.first);
+ // We don't handle the deeply recursive case using the batch method.
+ if (GEP->getNumOperands() > 3) + ComplexCEs.push_back(std::make_pair(GEP, I.second)); + else + SimpleCEs.push_back(std::make_pair(GEP, I.second)); + } + } + + // The algorithm below doesn't handle cases like nested structs, so use the + // slower fully general method if we have to. + for (auto ComplexCE : ComplexCEs) + CommitValueTo(ComplexCE.second, ComplexCE.first); + + for (auto GVPair : GVs) { + assert(GVPair.first->hasInitializer()); + GVPair.first->setInitializer(GVPair.second); + } + + if (SimpleCEs.empty()) + return; + + // We cache a single global's initializer elements in the case where the + // subsequent address/val pair uses the same one. This avoids throwing away and + // rebuilding the constant struct/vector/array just because one element is + // modified at a time. + SmallVector<Constant *, 32> Elts; + Elts.reserve(SimpleCEs.size()); + GlobalVariable *CurrentGV = nullptr; + + auto commitAndSetupCache = [&](GlobalVariable *GV, bool Update) { + Constant *Init = GV->getInitializer(); + Type *Ty = Init->getType(); + if (Update) { + if (CurrentGV) { + assert(CurrentGV && "Expected a GV to commit to!"); + Type *CurrentInitTy = CurrentGV->getInitializer()->getType(); + // We have a valid cache that needs to be committed. + if (StructType *STy = dyn_cast<StructType>(CurrentInitTy)) + CurrentGV->setInitializer(ConstantStruct::get(STy, Elts)); + else if (ArrayType *ArrTy = dyn_cast<ArrayType>(CurrentInitTy)) + CurrentGV->setInitializer(ConstantArray::get(ArrTy, Elts)); + else + CurrentGV->setInitializer(ConstantVector::get(Elts)); + } + if (CurrentGV == GV) + return; + // Need to clear and set up cache for new initializer. + CurrentGV = GV; + Elts.clear(); + unsigned NumElts; + if (auto *STy = dyn_cast<StructType>(Ty)) + NumElts = STy->getNumElements(); + else + NumElts = cast<SequentialType>(Ty)->getNumElements(); + for (unsigned i = 0, e = NumElts; i != e; ++i) + Elts.push_back(Init->getAggregateElement(i)); + } + }; + + for (auto CEPair : SimpleCEs) { + ConstantExpr *GEP = CEPair.first; + Constant *Val = CEPair.second; + + GlobalVariable *GV = cast<GlobalVariable>(GEP->getOperand(0)); + commitAndSetupCache(GV, GV != CurrentGV); + ConstantInt *CI = cast<ConstantInt>(GEP->getOperand(2)); + Elts[CI->getZExtValue()] = Val; + } + // The last initializer in the list needs to be committed, others + // will be committed on a new initializer being processed. + commitAndSetupCache(CurrentGV, true); +} + /// Evaluate static constructors in the function, if we can. Return true if we /// can, false otherwise. static bool EvaluateStaticConstructor(Function *F, const DataLayout &DL, @@ -2268,11 +2560,10 @@ static bool EvaluateStaticConstructor(Function *F, const DataLayout &DL, ++NumCtorsEvaluated; // We succeeded at evaluation: commit the result. 
- DEBUG(dbgs() << "FULLY EVALUATED GLOBAL CTOR FUNCTION '" - << F->getName() << "' to " << Eval.getMutatedMemory().size() - << " stores.\n"); - for (const auto &I : Eval.getMutatedMemory()) - CommitValueTo(I.second, I.first); + LLVM_DEBUG(dbgs() << "FULLY EVALUATED GLOBAL CTOR FUNCTION '" + << F->getName() << "' to " + << Eval.getMutatedMemory().size() << " stores.\n"); + BatchCommitValueTo(Eval.getMutatedMemory()); for (GlobalVariable *GV : Eval.getInvariants()) GV->setConstant(true); } @@ -2287,7 +2578,7 @@ static int compareNames(Constant *const *A, Constant *const *B) { } static void setUsedInitializer(GlobalVariable &V, - const SmallPtrSet<GlobalValue *, 8> &Init) { + const SmallPtrSetImpl<GlobalValue *> &Init) { if (Init.empty()) { V.eraseFromParent(); return; @@ -2440,7 +2731,7 @@ static bool hasUsesToReplace(GlobalAlias &GA, const LLVMUsed &U, static bool OptimizeGlobalAliases(Module &M, - SmallSet<const Comdat *, 8> &NotDiscardableComdats) { + SmallPtrSetImpl<const Comdat *> &NotDiscardableComdats) { bool Changed = false; LLVMUsed Used(M); @@ -2460,7 +2751,7 @@ OptimizeGlobalAliases(Module &M, continue; } - // If the aliasee may change at link time, nothing can be done - bail out. + // If the alias can change at link time, nothing can be done - bail out. if (J->isInterposable()) continue; @@ -2486,6 +2777,7 @@ OptimizeGlobalAliases(Module &M, // Give the aliasee the name, linkage and other attributes of the alias. Target->takeName(&*J); Target->setLinkage(J->getLinkage()); + Target->setDSOLocal(J->isDSOLocal()); Target->setVisibility(J->getVisibility()); Target->setDLLStorageClass(J->getDLLStorageClass()); @@ -2619,8 +2911,10 @@ static bool OptimizeEmptyGlobalCXXDtors(Function *CXAAtExitFn) { static bool optimizeGlobalsInModule( Module &M, const DataLayout &DL, TargetLibraryInfo *TLI, + function_ref<TargetTransformInfo &(Function &)> GetTTI, + function_ref<BlockFrequencyInfo &(Function &)> GetBFI, function_ref<DominatorTree &(Function &)> LookupDomTree) { - SmallSet<const Comdat *, 8> NotDiscardableComdats; + SmallPtrSet<const Comdat *, 8> NotDiscardableComdats; bool Changed = false; bool LocalChange = true; while (LocalChange) { @@ -2641,8 +2935,8 @@ static bool optimizeGlobalsInModule( NotDiscardableComdats.insert(C); // Delete functions that are trivially dead, ccc -> fastcc - LocalChange |= - OptimizeFunctions(M, TLI, LookupDomTree, NotDiscardableComdats); + LocalChange |= OptimizeFunctions(M, TLI, GetTTI, GetBFI, LookupDomTree, + NotDiscardableComdats); // Optimize global_ctors list. 
LocalChange |= optimizeGlobalCtorsList(M, [&](Function *F) { @@ -2679,7 +2973,15 @@ PreservedAnalyses GlobalOptPass::run(Module &M, ModuleAnalysisManager &AM) { auto LookupDomTree = [&FAM](Function &F) -> DominatorTree &{ return FAM.getResult<DominatorTreeAnalysis>(F); }; - if (!optimizeGlobalsInModule(M, DL, &TLI, LookupDomTree)) + auto GetTTI = [&FAM](Function &F) -> TargetTransformInfo & { + return FAM.getResult<TargetIRAnalysis>(F); + }; + + auto GetBFI = [&FAM](Function &F) -> BlockFrequencyInfo & { + return FAM.getResult<BlockFrequencyAnalysis>(F); + }; + + if (!optimizeGlobalsInModule(M, DL, &TLI, GetTTI, GetBFI, LookupDomTree)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); } @@ -2702,12 +3004,22 @@ struct GlobalOptLegacyPass : public ModulePass { auto LookupDomTree = [this](Function &F) -> DominatorTree & { return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); }; - return optimizeGlobalsInModule(M, DL, TLI, LookupDomTree); + auto GetTTI = [this](Function &F) -> TargetTransformInfo & { + return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + }; + + auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & { + return this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); + }; + + return optimizeGlobalsInModule(M, DL, TLI, GetTTI, GetBFI, LookupDomTree); } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<BlockFrequencyInfoWrapperPass>(); } }; @@ -2718,6 +3030,8 @@ char GlobalOptLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(GlobalOptLegacyPass, "globalopt", "Global Variable Optimizer", false, false) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(GlobalOptLegacyPass, "globalopt", "Global Variable Optimizer", false, false) diff --git a/lib/Transforms/IPO/IPO.cpp b/lib/Transforms/IPO/IPO.cpp index d5d35ee89e0e..dce9ee076bc5 100644 --- a/lib/Transforms/IPO/IPO.cpp +++ b/lib/Transforms/IPO/IPO.cpp @@ -40,7 +40,7 @@ void llvm::initializeIPO(PassRegistry &Registry) { initializeInferFunctionAttrsLegacyPassPass(Registry); initializeInternalizeLegacyPassPass(Registry); initializeLoopExtractorPass(Registry); - initializeBlockExtractorPassPass(Registry); + initializeBlockExtractorPass(Registry); initializeSingleLoopExtractorPass(Registry); initializeLowerTypeTestsPass(Registry); initializeMergeFunctionsPass(Registry); @@ -48,6 +48,7 @@ void llvm::initializeIPO(PassRegistry &Registry) { initializePostOrderFunctionAttrsLegacyPassPass(Registry); initializeReversePostOrderFunctionAttrsLegacyPassPass(Registry); initializePruneEHPass(Registry); + initializeIPSCCPLegacyPassPass(Registry); initializeStripDeadPrototypesLegacyPassPass(Registry); initializeStripSymbolsPass(Registry); initializeStripDebugDeclarePass(Registry); diff --git a/lib/Transforms/IPO/InlineSimple.cpp b/lib/Transforms/IPO/InlineSimple.cpp index b259a0abd63c..82bba1e5c93b 100644 --- a/lib/Transforms/IPO/InlineSimple.cpp +++ b/lib/Transforms/IPO/InlineSimple.cpp @@ -31,7 +31,7 @@ using namespace llvm; namespace { -/// \brief Actual inliner pass implementation. +/// Actual inliner pass implementation. 
/// /// The common implementation of the inlining logic is shared between this /// inliner pass and the always inliner pass. The two passes use different cost diff --git a/lib/Transforms/IPO/Inliner.cpp b/lib/Transforms/IPO/Inliner.cpp index 4449c87ddefa..3da0c2e83eb8 100644 --- a/lib/Transforms/IPO/Inliner.cpp +++ b/lib/Transforms/IPO/Inliner.cpp @@ -35,6 +35,7 @@ #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" @@ -59,7 +60,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/ImportedFunctionsInliningStatistics.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include <algorithm> #include <cassert> @@ -208,8 +208,8 @@ static void mergeInlinedArrayAllocas( // Otherwise, we *can* reuse it, RAUW AI into AvailableAlloca and declare // success! - DEBUG(dbgs() << " ***MERGED ALLOCA: " << *AI - << "\n\t\tINTO: " << *AvailableAlloca << '\n'); + LLVM_DEBUG(dbgs() << " ***MERGED ALLOCA: " << *AI + << "\n\t\tINTO: " << *AvailableAlloca << '\n'); // Move affected dbg.declare calls immediately after the new alloca to // avoid the situation when a dbg.declare precedes its alloca. @@ -379,14 +379,14 @@ shouldInline(CallSite CS, function_ref<InlineCost(CallSite CS)> GetInlineCost, Function *Caller = CS.getCaller(); if (IC.isAlways()) { - DEBUG(dbgs() << " Inlining: cost=always" - << ", Call: " << *CS.getInstruction() << "\n"); + LLVM_DEBUG(dbgs() << " Inlining: cost=always" + << ", Call: " << *CS.getInstruction() << "\n"); return IC; } if (IC.isNever()) { - DEBUG(dbgs() << " NOT Inlining: cost=never" - << ", Call: " << *CS.getInstruction() << "\n"); + LLVM_DEBUG(dbgs() << " NOT Inlining: cost=never" + << ", Call: " << *CS.getInstruction() << "\n"); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", Call) << NV("Callee", Callee) << " not inlined into " @@ -397,9 +397,9 @@ shouldInline(CallSite CS, function_ref<InlineCost(CallSite CS)> GetInlineCost, } if (!IC) { - DEBUG(dbgs() << " NOT Inlining: cost=" << IC.getCost() - << ", thres=" << IC.getThreshold() - << ", Call: " << *CS.getInstruction() << "\n"); + LLVM_DEBUG(dbgs() << " NOT Inlining: cost=" << IC.getCost() + << ", thres=" << IC.getThreshold() + << ", Call: " << *CS.getInstruction() << "\n"); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "TooCostly", Call) << NV("Callee", Callee) << " not inlined into " @@ -412,9 +412,9 @@ shouldInline(CallSite CS, function_ref<InlineCost(CallSite CS)> GetInlineCost, int TotalSecondaryCost = 0; if (shouldBeDeferred(Caller, CS, IC, TotalSecondaryCost, GetInlineCost)) { - DEBUG(dbgs() << " NOT Inlining: " << *CS.getInstruction() - << " Cost = " << IC.getCost() - << ", outer Cost = " << TotalSecondaryCost << '\n'); + LLVM_DEBUG(dbgs() << " NOT Inlining: " << *CS.getInstruction() + << " Cost = " << IC.getCost() + << ", outer Cost = " << TotalSecondaryCost << '\n'); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "IncreaseCostInOtherContexts", Call) @@ -428,9 +428,9 @@ shouldInline(CallSite CS, function_ref<InlineCost(CallSite CS)> GetInlineCost, return None; } - DEBUG(dbgs() << " Inlining: cost=" << IC.getCost() - << ", thres=" << IC.getThreshold() - << ", Call: " << *CS.getInstruction() << '\n'); + LLVM_DEBUG(dbgs() << " Inlining: cost=" 
<< IC.getCost()
+ << ", thres=" << IC.getThreshold()
+ << ", Call: " << *CS.getInstruction() << '\n');
 return IC;
}

@@ -470,12 +470,12 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG,
 function_ref<AAResults &(Function &)> AARGetter,
 ImportedFunctionsInliningStatistics &ImportedFunctionsStats) {
 SmallPtrSet<Function *, 8> SCCFunctions;
- DEBUG(dbgs() << "Inliner visiting SCC:");
+ LLVM_DEBUG(dbgs() << "Inliner visiting SCC:");
 for (CallGraphNode *Node : SCC) {
 Function *F = Node->getFunction();
 if (F)
 SCCFunctions.insert(F);
- DEBUG(dbgs() << " " << (F ? F->getName() : "INDIRECTNODE"));
+ LLVM_DEBUG(dbgs() << " " << (F ? F->getName() : "INDIRECTNODE"));
 }

 // Scan through and identify all call sites ahead of time so that we only
@@ -524,7 +524,7 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG,
 }
 }

- DEBUG(dbgs() << ": " << CallSites.size() << " call sites.\n");
+ LLVM_DEBUG(dbgs() << ": " << CallSites.size() << " call sites.\n");

 // If there are no calls in this function, exit early.
 if (CallSites.empty())
@@ -593,7 +593,7 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG,
 // size. This happens because IPSCCP propagates the result out of the
 // call and then we're left with the dead call.
 if (IsTriviallyDead) {
- DEBUG(dbgs() << " -> Deleting dead call: " << *Instr << "\n");
+ LLVM_DEBUG(dbgs() << " -> Deleting dead call: " << *Instr << "\n");
 // Update the call graph by deleting the edge from Callee to Caller.
 CG[Caller]->removeCallEdgeFor(CS);
 Instr->eraseFromParent();
@@ -657,8 +657,8 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG,
 // callgraph references to the node, we cannot delete it yet, this
 // could invalidate the CGSCC iterator.
 CG[Callee]->getNumReferences() == 0) {
- DEBUG(dbgs() << " -> Deleting dead function: " << Callee->getName()
- << "\n");
+ LLVM_DEBUG(dbgs() << " -> Deleting dead function: "
+ << Callee->getName() << "\n");
 CallGraphNode *CalleeNode = CG[Callee];

 // Remove any call graph edges from the callee to its callees.
@@ -793,6 +793,14 @@ bool LegacyInlinerBase::removeDeadFunctions(CallGraph &CG,
 return true;
}

+InlinerPass::~InlinerPass() {
+ if (ImportedFunctionsStats) {
+ assert(InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No);
+ ImportedFunctionsStats->dump(InlinerFunctionImportStats ==
+ InlinerFunctionImportStatsOpts::Verbose);
+ }
+}
+
PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
 CGSCCAnalysisManager &AM, LazyCallGraph &CG,
 CGSCCUpdateResult &UR) {
@@ -804,6 +812,13 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
 Module &M = *InitialC.begin()->getFunction().getParent();
 ProfileSummaryInfo *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(M);

+ if (!ImportedFunctionsStats &&
+ InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) {
+ ImportedFunctionsStats =
+ llvm::make_unique<ImportedFunctionsInliningStatistics>();
+ ImportedFunctionsStats->setModuleInfo(M);
+ }
+
 // We use a single common worklist for calls across the entire SCC. We
 // process these in-order and append new calls introduced during inlining to
 // the end.
@@ -830,8 +845,14 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
 // incrementally making a single function grow in a super linear fashion.
 SmallVector<std::pair<CallSite, int>, 16> Calls;

+ FunctionAnalysisManager &FAM =
+ AM.getResult<FunctionAnalysisManagerCGSCCProxy>(InitialC, CG)
+ .getManager();
+
 // Populate the initial list of calls in this SCC.
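 // (Each Calls entry pairs a call site with an inline-history index, where -1
 // means "no history". A minimal sketch of that bookkeeping with simplified
 // types follows; Function is opaque here, and the container and helper names
 // are illustrative rather than the pass's exact internals.)
 #include <utility>
 #include <vector>
 struct Function;
 using InlineHistoryTy = std::vector<std::pair<Function *, int>>;

 // Walk the parent chain; if Callee already appears on this path, inlining it
 // again could expand the same cycle forever, so the candidate is skipped.
 static bool historyIncludes(Function *Callee, int Id,
                             const InlineHistoryTy &History) {
   for (; Id != -1; Id = History[Id].second)
     if (History[Id].first == Callee)
       return true;
   return false;
 }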
for (auto &N : InitialC) { + auto &ORE = + FAM.getResult<OptimizationRemarkEmitterAnalysis>(N.getFunction()); // We want to generally process call sites top-down in order for // simplifications stemming from replacing the call with the returned value // after inlining to be visible to subsequent inlining decisions. @@ -839,9 +860,20 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // Instead we should do an actual RPO walk of the function body. for (Instruction &I : instructions(N.getFunction())) if (auto CS = CallSite(&I)) - if (Function *Callee = CS.getCalledFunction()) + if (Function *Callee = CS.getCalledFunction()) { if (!Callee->isDeclaration()) Calls.push_back({CS, -1}); + else if (!isa<IntrinsicInst>(I)) { + using namespace ore; + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NoDefinition", &I) + << NV("Callee", Callee) << " will not be inlined into " + << NV("Caller", CS.getCaller()) + << " because its definition is unavailable" + << setIsVerbose(); + }); + } + } } if (Calls.empty()) return PreservedAnalyses::all(); @@ -879,7 +911,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, if (F.hasFnAttribute(Attribute::OptimizeNone)) continue; - DEBUG(dbgs() << "Inlining calls in: " << F.getName() << "\n"); + LLVM_DEBUG(dbgs() << "Inlining calls in: " << F.getName() << "\n"); // Get a FunctionAnalysisManager via a proxy for this particular node. We // do this each time we visit a node as the SCC may have changed and as @@ -931,9 +963,9 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // and thus hidden from the full inline history. if (CG.lookupSCC(*CG.lookup(Callee)) == C && UR.InlinedInternalEdges.count({&N, C})) { - DEBUG(dbgs() << "Skipping inlining internal SCC edge from a node " - "previously split out of this SCC by inlining: " - << F.getName() << " -> " << Callee.getName() << "\n"); + LLVM_DEBUG(dbgs() << "Skipping inlining internal SCC edge from a node " + "previously split out of this SCC by inlining: " + << F.getName() << " -> " << Callee.getName() << "\n"); continue; } @@ -992,6 +1024,9 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, Calls.push_back({CS, NewHistoryID}); } + if (InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) + ImportedFunctionsStats->recordInline(F, Callee); + // Merge the attributes based on the inlining. AttributeFuncs::mergeAttributesForInlining(F, Callee); @@ -1052,7 +1087,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // change. 
LazyCallGraph::SCC *OldC = C; C = &updateCGAndAnalysisManagerForFunctionPass(CG, *C, N, AM, UR); - DEBUG(dbgs() << "Updated inlining SCC: " << *C << "\n"); + LLVM_DEBUG(dbgs() << "Updated inlining SCC: " << *C << "\n"); RC = &C->getOuterRefSCC(); // If this causes an SCC to split apart into multiple smaller SCCs, there @@ -1070,8 +1105,8 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, if (C != OldC && llvm::any_of(InlinedCallees, [&](Function *Callee) { return CG.lookupSCC(*CG.lookup(*Callee)) == OldC; })) { - DEBUG(dbgs() << "Inlined an internal call edge and split an SCC, " - "retaining this to avoid infinite inlining.\n"); + LLVM_DEBUG(dbgs() << "Inlined an internal call edge and split an SCC, " + "retaining this to avoid infinite inlining.\n"); UR.InlinedInternalEdges.insert({&N, OldC}); } InlinedCallees.clear(); diff --git a/lib/Transforms/IPO/Internalize.cpp b/lib/Transforms/IPO/Internalize.cpp index 26db1465bb26..a6542d28dfd8 100644 --- a/lib/Transforms/IPO/Internalize.cpp +++ b/lib/Transforms/IPO/Internalize.cpp @@ -192,7 +192,7 @@ bool InternalizePass::internalizeModule(Module &M, CallGraph *CG) { ExternalNode->removeOneAbstractEdgeTo((*CG)[&I]); ++NumFunctions; - DEBUG(dbgs() << "Internalizing func " << I.getName() << "\n"); + LLVM_DEBUG(dbgs() << "Internalizing func " << I.getName() << "\n"); } // Never internalize the llvm.used symbol. It is used to implement @@ -221,7 +221,7 @@ bool InternalizePass::internalizeModule(Module &M, CallGraph *CG) { Changed = true; ++NumGlobals; - DEBUG(dbgs() << "Internalized gvar " << GV.getName() << "\n"); + LLVM_DEBUG(dbgs() << "Internalized gvar " << GV.getName() << "\n"); } // Mark all aliases that are not in the api as internal as well. @@ -231,7 +231,7 @@ bool InternalizePass::internalizeModule(Module &M, CallGraph *CG) { Changed = true; ++NumAliases; - DEBUG(dbgs() << "Internalized alias " << GA.getName() << "\n"); + LLVM_DEBUG(dbgs() << "Internalized alias " << GA.getName() << "\n"); } return Changed; diff --git a/lib/Transforms/IPO/LLVMBuild.txt b/lib/Transforms/IPO/LLVMBuild.txt index a8b0f32fd785..54ce23876e66 100644 --- a/lib/Transforms/IPO/LLVMBuild.txt +++ b/lib/Transforms/IPO/LLVMBuild.txt @@ -20,4 +20,4 @@ type = Library name = IPO parent = Transforms library_name = ipo -required_libraries = Analysis BitReader BitWriter Core InstCombine IRReader Linker Object ProfileData Scalar Support TransformUtils Vectorize Instrumentation +required_libraries = AggressiveInstCombine Analysis BitReader BitWriter Core InstCombine IRReader Linker Object ProfileData Scalar Support TransformUtils Vectorize Instrumentation diff --git a/lib/Transforms/IPO/LoopExtractor.cpp b/lib/Transforms/IPO/LoopExtractor.cpp index 36b6bdba2cd0..8c86f7cb806a 100644 --- a/lib/Transforms/IPO/LoopExtractor.cpp +++ b/lib/Transforms/IPO/LoopExtractor.cpp @@ -23,6 +23,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/CodeExtractor.h" #include <fstream> @@ -158,155 +159,3 @@ bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &LPM) { Pass *llvm::createSingleLoopExtractorPass() { return new SingleLoopExtractor(); } - - -// BlockFile - A file which contains a list of blocks that should not be -// extracted. 
-static cl::opt<std::string> -BlockFile("extract-blocks-file", cl::value_desc("filename"), - cl::desc("A file containing list of basic blocks to not extract"), - cl::Hidden); - -namespace { - /// BlockExtractorPass - This pass is used by bugpoint to extract all blocks - /// from the module into their own functions except for those specified by the - /// BlocksToNotExtract list. - class BlockExtractorPass : public ModulePass { - void LoadFile(const char *Filename); - void SplitLandingPadPreds(Function *F); - - std::vector<BasicBlock*> BlocksToNotExtract; - std::vector<std::pair<std::string, std::string> > BlocksToNotExtractByName; - public: - static char ID; // Pass identification, replacement for typeid - BlockExtractorPass() : ModulePass(ID) { - if (!BlockFile.empty()) - LoadFile(BlockFile.c_str()); - } - - bool runOnModule(Module &M) override; - }; -} - -char BlockExtractorPass::ID = 0; -INITIALIZE_PASS(BlockExtractorPass, "extract-blocks", - "Extract Basic Blocks From Module (for bugpoint use)", - false, false) - -// createBlockExtractorPass - This pass extracts all blocks (except those -// specified in the argument list) from the functions in the module. -// -ModulePass *llvm::createBlockExtractorPass() { - return new BlockExtractorPass(); -} - -void BlockExtractorPass::LoadFile(const char *Filename) { - // Load the BlockFile... - std::ifstream In(Filename); - if (!In.good()) { - errs() << "WARNING: BlockExtractor couldn't load file '" << Filename - << "'!\n"; - return; - } - while (In) { - std::string FunctionName, BlockName; - In >> FunctionName; - In >> BlockName; - if (!BlockName.empty()) - BlocksToNotExtractByName.push_back( - std::make_pair(FunctionName, BlockName)); - } -} - -/// SplitLandingPadPreds - The landing pad needs to be extracted with the invoke -/// instruction. The critical edge breaker will refuse to break critical edges -/// to a landing pad. So do them here. After this method runs, all landing pads -/// should have only one predecessor. -void BlockExtractorPass::SplitLandingPadPreds(Function *F) { - for (Function::iterator I = F->begin(), E = F->end(); I != E; ++I) { - InvokeInst *II = dyn_cast<InvokeInst>(I); - if (!II) continue; - BasicBlock *Parent = II->getParent(); - BasicBlock *LPad = II->getUnwindDest(); - - // Look through the landing pad's predecessors. If one of them ends in an - // 'invoke', then we want to split the landing pad. - bool Split = false; - for (pred_iterator - PI = pred_begin(LPad), PE = pred_end(LPad); PI != PE; ++PI) { - BasicBlock *BB = *PI; - if (BB->isLandingPad() && BB != Parent && - isa<InvokeInst>(Parent->getTerminator())) { - Split = true; - break; - } - } - - if (!Split) continue; - - SmallVector<BasicBlock*, 2> NewBBs; - SplitLandingPadPredecessors(LPad, Parent, ".1", ".2", NewBBs); - } -} - -bool BlockExtractorPass::runOnModule(Module &M) { - if (skipModule(M)) - return false; - - std::set<BasicBlock*> TranslatedBlocksToNotExtract; - for (unsigned i = 0, e = BlocksToNotExtract.size(); i != e; ++i) { - BasicBlock *BB = BlocksToNotExtract[i]; - Function *F = BB->getParent(); - - // Map the corresponding function in this module. - Function *MF = M.getFunction(F->getName()); - assert(MF->getFunctionType() == F->getFunctionType() && "Wrong function?"); - - // Figure out which index the basic block is in its function. 
- Function::iterator BBI = MF->begin(); - std::advance(BBI, std::distance(F->begin(), Function::iterator(BB))); - TranslatedBlocksToNotExtract.insert(&*BBI); - } - - while (!BlocksToNotExtractByName.empty()) { - // There's no way to find BBs by name without looking at every BB inside - // every Function. Fortunately, this is always empty except when used by - // bugpoint in which case correctness is more important than performance. - - std::string &FuncName = BlocksToNotExtractByName.back().first; - std::string &BlockName = BlocksToNotExtractByName.back().second; - - for (Function &F : M) { - if (F.getName() != FuncName) continue; - - for (BasicBlock &BB : F) { - if (BB.getName() != BlockName) continue; - - TranslatedBlocksToNotExtract.insert(&BB); - } - } - - BlocksToNotExtractByName.pop_back(); - } - - // Now that we know which blocks to not extract, figure out which ones we WANT - // to extract. - std::vector<BasicBlock*> BlocksToExtract; - for (Function &F : M) { - SplitLandingPadPreds(&F); - for (BasicBlock &BB : F) - if (!TranslatedBlocksToNotExtract.count(&BB)) - BlocksToExtract.push_back(&BB); - } - - for (BasicBlock *BlockToExtract : BlocksToExtract) { - SmallVector<BasicBlock*, 2> BlocksToExtractVec; - BlocksToExtractVec.push_back(BlockToExtract); - if (const InvokeInst *II = - dyn_cast<InvokeInst>(BlockToExtract->getTerminator())) - BlocksToExtractVec.push_back(II->getUnwindDest()); - CodeExtractor(BlocksToExtractVec).extractCodeRegion(); - } - - return !BlocksToExtract.empty(); -} diff --git a/lib/Transforms/IPO/LowerTypeTests.cpp b/lib/Transforms/IPO/LowerTypeTests.cpp index 8db7e1e142d2..4f7571884707 100644 --- a/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/lib/Transforms/IPO/LowerTypeTests.cpp @@ -8,6 +8,8 @@ //===----------------------------------------------------------------------===// // // This pass lowers type metadata and calls to the llvm.type.test intrinsic. +// It also ensures that globals are properly laid out for the +// llvm.icall.branch.funnel intrinsic. // See http://llvm.org/docs/TypeMetadata.html for more information. 
// //===----------------------------------------------------------------------===// @@ -25,6 +27,7 @@ #include "llvm/ADT/TinyPtrVector.h" #include "llvm/ADT/Triple.h" #include "llvm/Analysis/TypeMetadataUtils.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" @@ -291,6 +294,33 @@ public: } }; +struct ICallBranchFunnel final + : TrailingObjects<ICallBranchFunnel, GlobalTypeMember *> { + static ICallBranchFunnel *create(BumpPtrAllocator &Alloc, CallInst *CI, + ArrayRef<GlobalTypeMember *> Targets, + unsigned UniqueId) { + auto *Call = static_cast<ICallBranchFunnel *>( + Alloc.Allocate(totalSizeToAlloc<GlobalTypeMember *>(Targets.size()), + alignof(ICallBranchFunnel))); + Call->CI = CI; + Call->UniqueId = UniqueId; + Call->NTargets = Targets.size(); + std::uninitialized_copy(Targets.begin(), Targets.end(), + Call->getTrailingObjects<GlobalTypeMember *>()); + return Call; + } + + CallInst *CI; + ArrayRef<GlobalTypeMember *> targets() const { + return makeArrayRef(getTrailingObjects<GlobalTypeMember *>(), NTargets); + } + + unsigned UniqueId; + +private: + size_t NTargets; +}; + class LowerTypeTestsModule { Module &M; @@ -372,6 +402,7 @@ class LowerTypeTestsModule { const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout); Value *lowerTypeTestCall(Metadata *TypeId, CallInst *CI, const TypeIdLowering &TIL); + void buildBitSetsFromGlobalVariables(ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals); unsigned getJumpTableEntrySize(); @@ -383,19 +414,32 @@ class LowerTypeTestsModule { void buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions); void buildBitSetsFromFunctionsNative(ArrayRef<Metadata *> TypeIds, - ArrayRef<GlobalTypeMember *> Functions); + ArrayRef<GlobalTypeMember *> Functions); void buildBitSetsFromFunctionsWASM(ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions); - void buildBitSetsFromDisjointSet(ArrayRef<Metadata *> TypeIds, - ArrayRef<GlobalTypeMember *> Globals); + void + buildBitSetsFromDisjointSet(ArrayRef<Metadata *> TypeIds, + ArrayRef<GlobalTypeMember *> Globals, + ArrayRef<ICallBranchFunnel *> ICallBranchFunnels); - void replaceWeakDeclarationWithJumpTablePtr(Function *F, Constant *JT); + void replaceWeakDeclarationWithJumpTablePtr(Function *F, Constant *JT, bool IsDefinition); void moveInitializerToModuleConstructor(GlobalVariable *GV); void findGlobalVariableUsersOf(Constant *C, SmallSetVector<GlobalVariable *, 8> &Out); void createJumpTable(Function *F, ArrayRef<GlobalTypeMember *> Functions); + /// replaceCfiUses - Go through the uses list for this definition + /// and make each use point to "V" instead of "this" when the use is outside + /// the block. 'This's use list is expected to have at least one element. + /// Unlike replaceAllUsesWith this function skips blockaddr and direct call + /// uses. + void replaceCfiUses(Function *Old, Value *New, bool IsDefinition); + + /// replaceDirectCalls - Go through the uses list for this definition and + /// replace each use, which is a direct function call. 
+ void replaceDirectCalls(Value *Old, Value *New); + public: LowerTypeTestsModule(Module &M, ModuleSummaryIndex *ExportSummary, const ModuleSummaryIndex *ImportSummary); @@ -427,8 +471,6 @@ struct LowerTypeTests : public ModulePass { } bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; if (UseCommandLine) return LowerTypeTestsModule::runForTesting(M); return LowerTypeTestsModule(M, ExportSummary, ImportSummary).lower();
@@ -729,10 +771,12 @@ void LowerTypeTestsModule::buildBitSetsFromGlobalVariables( // Compute the amount of padding required. uint64_t Padding = NextPowerOf2(InitSize - 1) - InitSize; - // Cap at 128 was found experimentally to have a good data/instruction - // overhead tradeoff. - if (Padding > 128) - Padding = alignTo(InitSize, 128) - InitSize; + // Experiments with different caps on Chromium, on both x64 and ARM64, + // have shown that a 32-byte cap generates the smallest binary on + // both platforms, while the various caps yield similar performance. + // (see https://lists.llvm.org/pipermail/llvm-dev/2018-July/124694.html) + if (Padding > 32) + Padding = alignTo(InitSize, 32) - InitSize; GlobalInits.push_back( ConstantAggregateZero::get(ArrayType::get(Int8Ty, Padding)));
@@ -936,14 +980,23 @@ void LowerTypeTestsModule::importTypeTest(CallInst *CI) { void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) { assert(F->getType()->getAddressSpace() == 0); - // Declaration of a local function - nothing to do. - if (F->isDeclarationForLinker() && isDefinition) - return; - GlobalValue::VisibilityTypes Visibility = F->getVisibility(); std::string Name = F->getName(); - Function *FDecl; + if (F->isDeclarationForLinker() && isDefinition) { + // Non-dso_local functions may be overridden at run time, + // so don't short-circuit them. + if (F->isDSOLocal()) { + Function *RealF = Function::Create(F->getFunctionType(), + GlobalValue::ExternalLinkage, + Name + ".cfi", &M); + RealF->setVisibility(GlobalVariable::HiddenVisibility); + replaceDirectCalls(F, RealF); + } + return; + } + + Function *FDecl; if (F->isDeclarationForLinker() && !isDefinition) { // Declaration of an external function. FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage,
@@ -952,10 +1005,25 @@ void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) { } else if (isDefinition) { F->setName(Name + ".cfi"); F->setLinkage(GlobalValue::ExternalLinkage); - F->setVisibility(GlobalValue::HiddenVisibility); FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage, Name, &M); FDecl->setVisibility(Visibility); + Visibility = GlobalValue::HiddenVisibility; + + // Delete aliases pointing to this function; they'll be re-created in the + // merged output. + SmallVector<GlobalAlias*, 4> ToErase; + for (auto &U : F->uses()) { + if (auto *A = dyn_cast<GlobalAlias>(U.getUser())) { + Function *AliasDecl = Function::Create( + F->getFunctionType(), GlobalValue::ExternalLinkage, "", &M); + AliasDecl->takeName(A); + A->replaceAllUsesWith(AliasDecl); + ToErase.push_back(A); + } + } + for (auto *A : ToErase) + A->eraseFromParent(); } else { // Function definition without type metadata, where some other translation // unit contained a declaration with type metadata. This normally happens
@@ -966,9 +1034,13 @@ void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) { } if (F->isWeakForLinker()) - replaceWeakDeclarationWithJumpTablePtr(F, FDecl); + replaceWeakDeclarationWithJumpTablePtr(F, FDecl, isDefinition); else - F->replaceAllUsesWith(FDecl); + replaceCfiUses(F, FDecl, isDefinition); + + // Set visibility late because it's used in replaceCfiUses() to determine + // whether uses need to be replaced. + F->setVisibility(Visibility); } void LowerTypeTestsModule::lowerTypeTestCalls(
@@ -980,7 +1052,7 @@ for (Metadata *TypeId : TypeIds) { // Build the bitset. BitSetInfo BSI = buildBitSet(TypeId, GlobalLayout); - DEBUG({ + LLVM_DEBUG({ if (auto MDS = dyn_cast<MDString>(TypeId)) dbgs() << MDS->getString() << ": "; else
@@ -1150,7 +1222,7 @@ void LowerTypeTestsModule::findGlobalVariableUsersOf( // Replace all uses of F with (F ? JT : 0). void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr( - Function *F, Constant *JT) { + Function *F, Constant *JT, bool IsDefinition) { // The target expression cannot appear in a constant initializer on most // (all?) targets. Switch to a runtime initializer. SmallSetVector<GlobalVariable *, 8> GlobalVarUsers;
@@ -1163,7 +1235,7 @@ void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr( Function *PlaceholderFn = Function::Create(cast<FunctionType>(F->getValueType()), GlobalValue::ExternalWeakLinkage, "", &M); - F->replaceAllUsesWith(PlaceholderFn); + replaceCfiUses(F, PlaceholderFn, IsDefinition); Constant *Target = ConstantExpr::getSelect( ConstantExpr::getICmp(CmpInst::ICMP_NE, F,
@@ -1226,12 +1298,6 @@ void LowerTypeTestsModule::createJumpTable( createJumpTableEntry(AsmOS, ConstraintOS, JumpTableArch, AsmArgs, cast<Function>(Functions[I]->getGlobal())); - // Try to emit the jump table at the end of the text segment. - // Jump table must come after __cfi_check in the cross-dso mode. - // FIXME: this magic section name seems to do the trick. - F->setSection(ObjectFormat == Triple::MachO - ? "__TEXT,__text,regular,pure_instructions" - : ".text.cfi"); // Align the whole table by entry size. F->setAlignment(getJumpTableEntrySize()); // Skip prologue.
@@ -1248,6 +1314,8 @@ void LowerTypeTestsModule::createJumpTable( // by Clang for -march=armv7. F->addFnAttr("target-cpu", "cortex-a8"); } + // Make sure we don't emit .eh_frame for this function.
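// (Its body is nothing but inline-asm trampolines, so suppressing unwind info
// here just keeps a useless entry out of .eh_frame.)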
+ F->addFnAttr(Attribute::NoUnwind); BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", F); IRBuilder<> IRB(BB); @@ -1389,9 +1457,9 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( } if (!IsDefinition) { if (F->isWeakForLinker()) - replaceWeakDeclarationWithJumpTablePtr(F, CombinedGlobalElemPtr); + replaceWeakDeclarationWithJumpTablePtr(F, CombinedGlobalElemPtr, IsDefinition); else - F->replaceAllUsesWith(CombinedGlobalElemPtr); + replaceCfiUses(F, CombinedGlobalElemPtr, IsDefinition); } else { assert(F->getType()->getAddressSpace() == 0); @@ -1401,10 +1469,10 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( FAlias->takeName(F); if (FAlias->hasName()) F->setName(FAlias->getName() + ".cfi"); - F->replaceUsesExceptBlockAddr(FAlias); + replaceCfiUses(F, FAlias, IsDefinition); + if (!F->hasLocalLinkage()) + F->setVisibility(GlobalVariable::HiddenVisibility); } - if (!F->isDeclarationForLinker()) - F->setLinkage(GlobalValue::InternalLinkage); } createJumpTable(JumpTableFn, Functions); @@ -1447,7 +1515,8 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsWASM( } void LowerTypeTestsModule::buildBitSetsFromDisjointSet( - ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals) { + ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals, + ArrayRef<ICallBranchFunnel *> ICallBranchFunnels) { DenseMap<Metadata *, uint64_t> TypeIdIndices; for (unsigned I = 0; I != TypeIds.size(); ++I) TypeIdIndices[TypeIds[I]] = I; @@ -1456,15 +1525,25 @@ void LowerTypeTestsModule::buildBitSetsFromDisjointSet( // the type identifier. std::vector<std::set<uint64_t>> TypeMembers(TypeIds.size()); unsigned GlobalIndex = 0; + DenseMap<GlobalTypeMember *, uint64_t> GlobalIndices; for (GlobalTypeMember *GTM : Globals) { for (MDNode *Type : GTM->types()) { // Type = { offset, type identifier } - unsigned TypeIdIndex = TypeIdIndices[Type->getOperand(1)]; - TypeMembers[TypeIdIndex].insert(GlobalIndex); + auto I = TypeIdIndices.find(Type->getOperand(1)); + if (I != TypeIdIndices.end()) + TypeMembers[I->second].insert(GlobalIndex); } + GlobalIndices[GTM] = GlobalIndex; GlobalIndex++; } + for (ICallBranchFunnel *JT : ICallBranchFunnels) { + TypeMembers.emplace_back(); + std::set<uint64_t> &TMSet = TypeMembers.back(); + for (GlobalTypeMember *T : JT->targets()) + TMSet.insert(GlobalIndices[T]); + } + // Order the sets of indices by size. The GlobalLayoutBuilder works best // when given small index sets first. std::stable_sort( @@ -1514,7 +1593,7 @@ LowerTypeTestsModule::LowerTypeTestsModule( } bool LowerTypeTestsModule::runForTesting(Module &M) { - ModuleSummaryIndex Summary; + ModuleSummaryIndex Summary(/*HaveGVs=*/false); // Handle the command-line summary arguments. This code is for testing // purposes only, so we handle errors directly. 
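For context on replaceWeakDeclarationWithJumpTablePtr above: references to a weak declaration F cannot point unconditionally at the jump table, since F may resolve to null at link time. Consistent with the "(F ? JT : 0)" comment in that hunk, the constant it builds is a guarded select of this shape (sketch; the exact arguments are trimmed in the diff context above and may differ slightly):

  Constant *Target = ConstantExpr::getSelect(
      ConstantExpr::getICmp(CmpInst::ICMP_NE, F,
                            Constant::getNullValue(F->getType())),
      JT, Constant::getNullValue(F->getType()));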
@@ -1549,11 +1628,71 @@ bool LowerTypeTestsModule::runForTesting(Module &M) { return Changed; } +static bool isDirectCall(Use& U) { + auto *Usr = dyn_cast<CallInst>(U.getUser()); + if (Usr) { + CallSite CS(Usr); + if (CS.isCallee(&U)) + return true; + } + return false; +} + +void LowerTypeTestsModule::replaceCfiUses(Function *Old, Value *New, bool IsDefinition) { + SmallSetVector<Constant *, 4> Constants; + auto UI = Old->use_begin(), E = Old->use_end(); + for (; UI != E;) { + Use &U = *UI; + ++UI; + + // Skip block addresses + if (isa<BlockAddress>(U.getUser())) + continue; + + // Skip direct calls to externally defined or non-dso_local functions + if (isDirectCall(U) && (Old->isDSOLocal() || !IsDefinition)) + continue; + + // Must handle Constants specially, we cannot call replaceUsesOfWith on a + // constant because they are uniqued. + if (auto *C = dyn_cast<Constant>(U.getUser())) { + if (!isa<GlobalValue>(C)) { + // Save unique users to avoid processing operand replacement + // more than once. + Constants.insert(C); + continue; + } + } + + U.set(New); + } + + // Process operand replacement of saved constants. + for (auto *C : Constants) + C->handleOperandChange(Old, New); +} + +void LowerTypeTestsModule::replaceDirectCalls(Value *Old, Value *New) { + auto UI = Old->use_begin(), E = Old->use_end(); + for (; UI != E;) { + Use &U = *UI; + ++UI; + + if (!isDirectCall(U)) + continue; + + U.set(New); + } +} + bool LowerTypeTestsModule::lower() { Function *TypeTestFunc = M.getFunction(Intrinsic::getName(Intrinsic::type_test)); - if ((!TypeTestFunc || TypeTestFunc->use_empty()) && !ExportSummary && - !ImportSummary) + Function *ICallBranchFunnelFunc = + M.getFunction(Intrinsic::getName(Intrinsic::icall_branch_funnel)); + if ((!TypeTestFunc || TypeTestFunc->use_empty()) && + (!ICallBranchFunnelFunc || ICallBranchFunnelFunc->use_empty()) && + !ExportSummary && !ImportSummary) return false; if (ImportSummary) { @@ -1565,6 +1704,10 @@ bool LowerTypeTestsModule::lower() { } } + if (ICallBranchFunnelFunc && !ICallBranchFunnelFunc->use_empty()) + report_fatal_error( + "unexpected call to llvm.icall.branch.funnel during import phase"); + SmallVector<Function *, 8> Defs; SmallVector<Function *, 8> Decls; for (auto &F : M) { @@ -1589,8 +1732,8 @@ bool LowerTypeTestsModule::lower() { // Equivalence class set containing type identifiers and the globals that // reference them. This is used to partition the set of type identifiers in // the module into disjoint sets. - using GlobalClassesTy = - EquivalenceClasses<PointerUnion<GlobalTypeMember *, Metadata *>>; + using GlobalClassesTy = EquivalenceClasses< + PointerUnion3<GlobalTypeMember *, Metadata *, ICallBranchFunnel *>>; GlobalClassesTy GlobalClasses; // Verify the type metadata and build a few data structures to let us @@ -1602,33 +1745,61 @@ bool LowerTypeTestsModule::lower() { // identifiers. BumpPtrAllocator Alloc; struct TIInfo { - unsigned Index; + unsigned UniqueId; std::vector<GlobalTypeMember *> RefGlobals; }; DenseMap<Metadata *, TIInfo> TypeIdInfo; - unsigned I = 0; + unsigned CurUniqueId = 0; SmallVector<MDNode *, 2> Types; + // Cross-DSO CFI emits jumptable entries for exported functions as well as + // address taken functions in case they are address taken in other modules. 
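// (Clang sets the "Cross-DSO CFI" module flag when compiling with
// -fsanitize-cfi-cross-dso.)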
+ const bool CrossDsoCfi = M.getModuleFlag("Cross-DSO CFI") != nullptr; + struct ExportedFunctionInfo { CfiFunctionLinkage Linkage; MDNode *FuncMD; // {name, linkage, type[, type...]} }; DenseMap<StringRef, ExportedFunctionInfo> ExportedFunctions; if (ExportSummary) { + // A set of all functions that are address taken by a live global object. + DenseSet<GlobalValue::GUID> AddressTaken; + for (auto &I : *ExportSummary) + for (auto &GVS : I.second.SummaryList) + if (GVS->isLive()) + for (auto &Ref : GVS->refs()) + AddressTaken.insert(Ref.getGUID()); + NamedMDNode *CfiFunctionsMD = M.getNamedMetadata("cfi.functions"); if (CfiFunctionsMD) { for (auto FuncMD : CfiFunctionsMD->operands()) { assert(FuncMD->getNumOperands() >= 2); StringRef FunctionName = cast<MDString>(FuncMD->getOperand(0))->getString(); - if (!ExportSummary->isGUIDLive(GlobalValue::getGUID( - GlobalValue::dropLLVMManglingEscape(FunctionName)))) - continue; CfiFunctionLinkage Linkage = static_cast<CfiFunctionLinkage>( cast<ConstantAsMetadata>(FuncMD->getOperand(1)) ->getValue() ->getUniqueInteger() .getZExtValue()); + const GlobalValue::GUID GUID = GlobalValue::getGUID( + GlobalValue::dropLLVMManglingEscape(FunctionName)); + // Do not emit jumptable entries for functions that are not-live and + // have no live references (and are not exported with cross-DSO CFI.) + if (!ExportSummary->isGUIDLive(GUID)) + continue; + if (!AddressTaken.count(GUID)) { + if (!CrossDsoCfi || Linkage != CFL_Definition) + continue; + + bool Exported = false; + if (auto VI = ExportSummary->getValueInfo(GUID)) + for (auto &GVS : VI.getSummaryList()) + if (GVS->isLive() && !GlobalValue::isLocalLinkage(GVS->linkage())) + Exported = true; + + if (!Exported) + continue; + } auto P = ExportedFunctions.insert({FunctionName, {Linkage, FuncMD}}); if (!P.second && P.first->second.Linkage != CFL_Definition) P.first->second = {Linkage, FuncMD}; @@ -1656,6 +1827,11 @@ bool LowerTypeTestsModule::lower() { F->clearMetadata(); } + // Update the linkage for extern_weak declarations when a definition + // exists. + if (Linkage == CFL_Definition && F->hasExternalWeakLinkage()) + F->setLinkage(GlobalValue::ExternalLinkage); + // If the function in the full LTO module is a declaration, replace its // type metadata with the type metadata we found in cfi.functions. That // metadata is presumed to be more accurate than the metadata attached @@ -1673,28 +1849,37 @@ bool LowerTypeTestsModule::lower() { } } + DenseMap<GlobalObject *, GlobalTypeMember *> GlobalTypeMembers; for (GlobalObject &GO : M.global_objects()) { if (isa<GlobalVariable>(GO) && GO.isDeclarationForLinker()) continue; Types.clear(); GO.getMetadata(LLVMContext::MD_type, Types); - if (Types.empty()) - continue; bool IsDefinition = !GO.isDeclarationForLinker(); bool IsExported = false; - if (isa<Function>(GO) && ExportedFunctions.count(GO.getName())) { - IsDefinition |= ExportedFunctions[GO.getName()].Linkage == CFL_Definition; - IsExported = true; + if (Function *F = dyn_cast<Function>(&GO)) { + if (ExportedFunctions.count(F->getName())) { + IsDefinition |= ExportedFunctions[F->getName()].Linkage == CFL_Definition; + IsExported = true; + // TODO: The logic here checks only that the function is address taken, + // not that the address takers are live. This can be updated to check + // their liveness and emit fewer jumptable entries once monolithic LTO + // builds also emit summaries. 
+ } else if (!F->hasAddressTaken()) { + if (!CrossDsoCfi || !IsDefinition || F->hasLocalLinkage()) + continue; + } } auto *GTM = GlobalTypeMember::create(Alloc, &GO, IsDefinition, IsExported, Types); + GlobalTypeMembers[&GO] = GTM; for (MDNode *Type : Types) { verifyTypeMDNode(&GO, Type); auto &Info = TypeIdInfo[Type->getOperand(1)]; - Info.Index = ++I; + Info.UniqueId = ++CurUniqueId; Info.RefGlobals.push_back(GTM); } } @@ -1731,6 +1916,44 @@ bool LowerTypeTestsModule::lower() { } } + if (ICallBranchFunnelFunc) { + for (const Use &U : ICallBranchFunnelFunc->uses()) { + if (Arch != Triple::x86_64) + report_fatal_error( + "llvm.icall.branch.funnel not supported on this target"); + + auto CI = cast<CallInst>(U.getUser()); + + std::vector<GlobalTypeMember *> Targets; + if (CI->getNumArgOperands() % 2 != 1) + report_fatal_error("number of arguments should be odd"); + + GlobalClassesTy::member_iterator CurSet; + for (unsigned I = 1; I != CI->getNumArgOperands(); I += 2) { + int64_t Offset; + auto *Base = dyn_cast<GlobalObject>(GetPointerBaseWithConstantOffset( + CI->getOperand(I), Offset, M.getDataLayout())); + if (!Base) + report_fatal_error( + "Expected branch funnel operand to be global value"); + + GlobalTypeMember *GTM = GlobalTypeMembers[Base]; + Targets.push_back(GTM); + GlobalClassesTy::member_iterator NewSet = + GlobalClasses.findLeader(GlobalClasses.insert(GTM)); + if (I == 1) + CurSet = NewSet; + else + CurSet = GlobalClasses.unionSets(CurSet, NewSet); + } + + GlobalClasses.unionSets( + CurSet, GlobalClasses.findLeader( + GlobalClasses.insert(ICallBranchFunnel::create( + Alloc, CI, Targets, ++CurUniqueId)))); + } + } + if (ExportSummary) { DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID; for (auto &P : TypeIdInfo) { @@ -1764,54 +1987,124 @@ bool LowerTypeTestsModule::lower() { continue; ++NumTypeIdDisjointSets; - unsigned MaxIndex = 0; + unsigned MaxUniqueId = 0; for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(I); MI != GlobalClasses.member_end(); ++MI) { - if ((*MI).is<Metadata *>()) - MaxIndex = std::max(MaxIndex, TypeIdInfo[MI->get<Metadata *>()].Index); + if (auto *MD = MI->dyn_cast<Metadata *>()) + MaxUniqueId = std::max(MaxUniqueId, TypeIdInfo[MD].UniqueId); + else if (auto *BF = MI->dyn_cast<ICallBranchFunnel *>()) + MaxUniqueId = std::max(MaxUniqueId, BF->UniqueId); } - Sets.emplace_back(I, MaxIndex); + Sets.emplace_back(I, MaxUniqueId); } - std::sort(Sets.begin(), Sets.end(), - [](const std::pair<GlobalClassesTy::iterator, unsigned> &S1, - const std::pair<GlobalClassesTy::iterator, unsigned> &S2) { - return S1.second < S2.second; - }); + llvm::sort(Sets.begin(), Sets.end(), + [](const std::pair<GlobalClassesTy::iterator, unsigned> &S1, + const std::pair<GlobalClassesTy::iterator, unsigned> &S2) { + return S1.second < S2.second; + }); // For each disjoint set we found... for (const auto &S : Sets) { // Build the list of type identifiers in this disjoint set. std::vector<Metadata *> TypeIds; std::vector<GlobalTypeMember *> Globals; + std::vector<ICallBranchFunnel *> ICallBranchFunnels; for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(S.first); MI != GlobalClasses.member_end(); ++MI) { - if ((*MI).is<Metadata *>()) + if (MI->is<Metadata *>()) TypeIds.push_back(MI->get<Metadata *>()); - else + else if (MI->is<GlobalTypeMember *>()) Globals.push_back(MI->get<GlobalTypeMember *>()); + else + ICallBranchFunnels.push_back(MI->get<ICallBranchFunnel *>()); } - // Order type identifiers by global index for determinism. 
This ordering is - // stable as there is a one-to-one mapping between metadata and indices. - std::sort(TypeIds.begin(), TypeIds.end(), [&](Metadata *M1, Metadata *M2) { - return TypeIdInfo[M1].Index < TypeIdInfo[M2].Index; + // Order type identifiers by unique ID for determinism. This ordering is + // stable as there is a one-to-one mapping between metadata and unique IDs. + llvm::sort(TypeIds.begin(), TypeIds.end(), [&](Metadata *M1, Metadata *M2) { + return TypeIdInfo[M1].UniqueId < TypeIdInfo[M2].UniqueId; }); + // Same for the branch funnels. + llvm::sort(ICallBranchFunnels.begin(), ICallBranchFunnels.end(), + [&](ICallBranchFunnel *F1, ICallBranchFunnel *F2) { + return F1->UniqueId < F2->UniqueId; + }); + // Build bitsets for this disjoint set. - buildBitSetsFromDisjointSet(TypeIds, Globals); + buildBitSetsFromDisjointSet(TypeIds, Globals, ICallBranchFunnels); } allocateByteArrays(); + // Parse alias data to replace stand-in function declarations for aliases + // with an alias to the intended target. + if (ExportSummary) { + if (NamedMDNode *AliasesMD = M.getNamedMetadata("aliases")) { + for (auto AliasMD : AliasesMD->operands()) { + assert(AliasMD->getNumOperands() >= 4); + StringRef AliasName = + cast<MDString>(AliasMD->getOperand(0))->getString(); + StringRef Aliasee = cast<MDString>(AliasMD->getOperand(1))->getString(); + + if (!ExportedFunctions.count(Aliasee) || + ExportedFunctions[Aliasee].Linkage != CFL_Definition || + !M.getNamedAlias(Aliasee)) + continue; + + GlobalValue::VisibilityTypes Visibility = + static_cast<GlobalValue::VisibilityTypes>( + cast<ConstantAsMetadata>(AliasMD->getOperand(2)) + ->getValue() + ->getUniqueInteger() + .getZExtValue()); + bool Weak = + static_cast<bool>(cast<ConstantAsMetadata>(AliasMD->getOperand(3)) + ->getValue() + ->getUniqueInteger() + .getZExtValue()); + + auto *Alias = GlobalAlias::create("", M.getNamedAlias(Aliasee)); + Alias->setVisibility(Visibility); + if (Weak) + Alias->setLinkage(GlobalValue::WeakAnyLinkage); + + if (auto *F = M.getFunction(AliasName)) { + Alias->takeName(F); + F->replaceAllUsesWith(Alias); + F->eraseFromParent(); + } else { + Alias->setName(AliasName); + } + } + } + } + + // Emit .symver directives for exported functions, if they exist. 
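// (Each ".symver <name>, <alias>" pair recorded in the module-level "symvers"
// metadata is re-emitted as module inline asm below, so the versioned alias
// survives into the merged object.)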
+ if (ExportSummary) { + if (NamedMDNode *SymversMD = M.getNamedMetadata("symvers")) { + for (auto Symver : SymversMD->operands()) { + assert(Symver->getNumOperands() >= 2); + StringRef SymbolName = + cast<MDString>(Symver->getOperand(0))->getString(); + StringRef Alias = cast<MDString>(Symver->getOperand(1))->getString(); + + if (!ExportedFunctions.count(SymbolName)) + continue; + + M.appendModuleInlineAsm( + (llvm::Twine(".symver ") + SymbolName + ", " + Alias).str()); + } + } + } + return true; } PreservedAnalyses LowerTypeTestsPass::run(Module &M, ModuleAnalysisManager &AM) { - bool Changed = LowerTypeTestsModule(M, /*ExportSummary=*/nullptr, - /*ImportSummary=*/nullptr) - .lower(); + bool Changed = LowerTypeTestsModule(M, ExportSummary, ImportSummary).lower(); if (!Changed) return PreservedAnalyses::all(); return PreservedAnalyses::none(); diff --git a/lib/Transforms/IPO/MergeFunctions.cpp b/lib/Transforms/IPO/MergeFunctions.cpp index 76b90391fbb1..139941127dee 100644 --- a/lib/Transforms/IPO/MergeFunctions.cpp +++ b/lib/Transforms/IPO/MergeFunctions.cpp @@ -90,7 +90,7 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Argument.h" @@ -407,10 +407,10 @@ bool MergeFunctions::runOnModule(Module &M) { std::vector<WeakTrackingVH> Worklist; Deferred.swap(Worklist); - DEBUG(doSanityCheck(Worklist)); + LLVM_DEBUG(doSanityCheck(Worklist)); - DEBUG(dbgs() << "size of module: " << M.size() << '\n'); - DEBUG(dbgs() << "size of worklist: " << Worklist.size() << '\n'); + LLVM_DEBUG(dbgs() << "size of module: " << M.size() << '\n'); + LLVM_DEBUG(dbgs() << "size of worklist: " << Worklist.size() << '\n'); // Insert functions and merge them. for (WeakTrackingVH &I : Worklist) { @@ -421,7 +421,7 @@ bool MergeFunctions::runOnModule(Module &M) { Changed |= insert(F); } } - DEBUG(dbgs() << "size of FnTree: " << FnTree.size() << '\n'); + LLVM_DEBUG(dbgs() << "size of FnTree: " << FnTree.size() << '\n'); } while (!Deferred.empty()); FnTree.clear(); @@ -498,19 +498,20 @@ static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) { // parameter debug info, from the entry block. void MergeFunctions::eraseInstsUnrelatedToPDI( std::vector<Instruction *> &PDIUnrelatedWL) { - DEBUG(dbgs() << " Erasing instructions (in reverse order of appearance in " - "entry block) unrelated to parameter debug info from entry " - "block: {\n"); + LLVM_DEBUG( + dbgs() << " Erasing instructions (in reverse order of appearance in " + "entry block) unrelated to parameter debug info from entry " + "block: {\n"); while (!PDIUnrelatedWL.empty()) { Instruction *I = PDIUnrelatedWL.back(); - DEBUG(dbgs() << " Deleting Instruction: "); - DEBUG(I->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Deleting Instruction: "); + LLVM_DEBUG(I->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); I->eraseFromParent(); PDIUnrelatedWL.pop_back(); } - DEBUG(dbgs() << " } // Done erasing instructions unrelated to parameter " - "debug info from entry block. \n"); + LLVM_DEBUG(dbgs() << " } // Done erasing instructions unrelated to parameter " + "debug info from entry block. \n"); } // Reduce G to its entry block. 
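For orientation in the MergeFunctions hunks below: when two functions compare equal, the loser G is reduced to a thunk that forwards to the keeper F. The body writeThunk emits has roughly this shape, a simplified sketch of the builder calls using the file-local createCast helper shown above (H is the function receiving the thunk body, attribute copying omitted):

  IRBuilder<> Builder(BB);
  SmallVector<Value *, 16> Args;
  unsigned I = 0;
  FunctionType *FFTy = F->getFunctionType();
  // Cast each incoming argument to the parameter type F expects.
  for (Argument &AI : H->args())
    Args.push_back(createCast(Builder, &AI, FFTy->getParamType(I++)));
  CallInst *CI = Builder.CreateCall(F, Args);
  CI->setTailCall();
  CI->setCallingConv(F->getCallingConv());
  if (H->getReturnType()->isVoidTy())
    Builder.CreateRetVoid();
  else
    Builder.CreateRet(createCast(Builder, CI, H->getReturnType()));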
@@ -543,99 +544,113 @@ void MergeFunctions::filterInstsUnrelatedToPDI( for (BasicBlock::iterator BI = GEntryBlock->begin(), BIE = GEntryBlock->end(); BI != BIE; ++BI) { if (auto *DVI = dyn_cast<DbgValueInst>(&*BI)) { - DEBUG(dbgs() << " Deciding: "); - DEBUG(BI->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Deciding: "); + LLVM_DEBUG(BI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); DILocalVariable *DILocVar = DVI->getVariable(); if (DILocVar->isParameter()) { - DEBUG(dbgs() << " Include (parameter): "); - DEBUG(BI->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Include (parameter): "); + LLVM_DEBUG(BI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); PDIRelated.insert(&*BI); } else { - DEBUG(dbgs() << " Delete (!parameter): "); - DEBUG(BI->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Delete (!parameter): "); + LLVM_DEBUG(BI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); } } else if (auto *DDI = dyn_cast<DbgDeclareInst>(&*BI)) { - DEBUG(dbgs() << " Deciding: "); - DEBUG(BI->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Deciding: "); + LLVM_DEBUG(BI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); DILocalVariable *DILocVar = DDI->getVariable(); if (DILocVar->isParameter()) { - DEBUG(dbgs() << " Parameter: "); - DEBUG(DILocVar->print(dbgs())); + LLVM_DEBUG(dbgs() << " Parameter: "); + LLVM_DEBUG(DILocVar->print(dbgs())); AllocaInst *AI = dyn_cast_or_null<AllocaInst>(DDI->getAddress()); if (AI) { - DEBUG(dbgs() << " Processing alloca users: "); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Processing alloca users: "); + LLVM_DEBUG(dbgs() << "\n"); for (User *U : AI->users()) { if (StoreInst *SI = dyn_cast<StoreInst>(U)) { if (Value *Arg = SI->getValueOperand()) { if (dyn_cast<Argument>(Arg)) { - DEBUG(dbgs() << " Include: "); - DEBUG(AI->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Include: "); + LLVM_DEBUG(AI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); PDIRelated.insert(AI); - DEBUG(dbgs() << " Include (parameter): "); - DEBUG(SI->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Include (parameter): "); + LLVM_DEBUG(SI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); PDIRelated.insert(SI); - DEBUG(dbgs() << " Include: "); - DEBUG(BI->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Include: "); + LLVM_DEBUG(BI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); PDIRelated.insert(&*BI); } else { - DEBUG(dbgs() << " Delete (!parameter): "); - DEBUG(SI->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Delete (!parameter): "); + LLVM_DEBUG(SI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); } } } else { - DEBUG(dbgs() << " Defer: "); - DEBUG(U->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Defer: "); + LLVM_DEBUG(U->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); } } } else { - DEBUG(dbgs() << " Delete (alloca NULL): "); - DEBUG(BI->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Delete (alloca NULL): "); + LLVM_DEBUG(BI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); } } else { - DEBUG(dbgs() << " Delete (!parameter): "); - DEBUG(BI->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Delete (!parameter): "); + LLVM_DEBUG(BI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); } } else if (dyn_cast<TerminatorInst>(BI) == GEntryBlock->getTerminator()) { - DEBUG(dbgs() << " Will Include Terminator: "); - DEBUG(BI->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Will Include Terminator: "); + 
LLVM_DEBUG(BI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); PDIRelated.insert(&*BI); } else { - DEBUG(dbgs() << " Defer: "); - DEBUG(BI->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Defer: "); + LLVM_DEBUG(BI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); } } - DEBUG(dbgs() - << " Report parameter debug info related/related instructions: {\n"); + LLVM_DEBUG( + dbgs() + << " Report parameter debug info related/related instructions: {\n"); for (BasicBlock::iterator BI = GEntryBlock->begin(), BE = GEntryBlock->end(); BI != BE; ++BI) { Instruction *I = &*BI; if (PDIRelated.find(I) == PDIRelated.end()) { - DEBUG(dbgs() << " !PDIRelated: "); - DEBUG(I->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " !PDIRelated: "); + LLVM_DEBUG(I->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); PDIUnrelatedWL.push_back(I); } else { - DEBUG(dbgs() << " PDIRelated: "); - DEBUG(I->print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " PDIRelated: "); + LLVM_DEBUG(I->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); } } - DEBUG(dbgs() << " }\n"); + LLVM_DEBUG(dbgs() << " }\n"); +} + +// Don't merge tiny functions using a thunk, since it can just end up +// making the function larger. +static bool isThunkProfitable(Function * F) { + if (F->size() == 1) { + if (F->front().size() <= 2) { + LLVM_DEBUG(dbgs() << "isThunkProfitable: " << F->getName() + << " is too small to bother creating a thunk for\n"); + return false; + } + } + return true; } // Replace G with a simple tail call to bitcast(F). Also (unless @@ -647,51 +662,19 @@ void MergeFunctions::filterInstsUnrelatedToPDI( // For better debugability, under MergeFunctionsPDI, we do not modify G's // call sites to point to F even when within the same translation unit. void MergeFunctions::writeThunk(Function *F, Function *G) { - if (!G->isInterposable() && !MergeFunctionsPDI) { - if (G->hasGlobalUnnamedAddr()) { - // G might have been a key in our GlobalNumberState, and it's illegal - // to replace a key in ValueMap<GlobalValue *> with a non-global. - GlobalNumbers.erase(G); - // If G's address is not significant, replace it entirely. - Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType()); - G->replaceAllUsesWith(BitcastF); - } else { - // Redirect direct callers of G to F. (See note on MergeFunctionsPDI - // above). - replaceDirectCallers(G, F); - } - } - - // If G was internal then we may have replaced all uses of G with F. If so, - // stop here and delete G. There's no need for a thunk. (See note on - // MergeFunctionsPDI above). - if (G->hasLocalLinkage() && G->use_empty() && !MergeFunctionsPDI) { - G->eraseFromParent(); - return; - } - - // Don't merge tiny functions using a thunk, since it can just end up - // making the function larger. 
- if (F->size() == 1) { - if (F->front().size() <= 2) { - DEBUG(dbgs() << "writeThunk: " << F->getName() - << " is too small to bother creating a thunk for\n"); - return; - } - } - BasicBlock *GEntryBlock = nullptr; std::vector<Instruction *> PDIUnrelatedWL; BasicBlock *BB = nullptr; Function *NewG = nullptr; if (MergeFunctionsPDI) { - DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) Do not create a new " - "function as thunk; retain original: " - << G->getName() << "()\n"); + LLVM_DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) Do not create a new " + "function as thunk; retain original: " + << G->getName() << "()\n"); GEntryBlock = &G->getEntryBlock(); - DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) filter parameter related " - "debug info for " - << G->getName() << "() {\n"); + LLVM_DEBUG( + dbgs() << "writeThunk: (MergeFunctionsPDI) filter parameter related " + "debug info for " + << G->getName() << "() {\n"); filterInstsUnrelatedToPDI(GEntryBlock, PDIUnrelatedWL); GEntryBlock->getTerminator()->eraseFromParent(); BB = GEntryBlock; @@ -730,13 +713,15 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { CI->setDebugLoc(CIDbgLoc); RI->setDebugLoc(RIDbgLoc); } else { - DEBUG(dbgs() << "writeThunk: (MergeFunctionsPDI) No DISubprogram for " - << G->getName() << "()\n"); + LLVM_DEBUG( + dbgs() << "writeThunk: (MergeFunctionsPDI) No DISubprogram for " + << G->getName() << "()\n"); } eraseTail(G); eraseInstsUnrelatedToPDI(PDIUnrelatedWL); - DEBUG(dbgs() << "} // End of parameter related debug info filtering for: " - << G->getName() << "()\n"); + LLVM_DEBUG( + dbgs() << "} // End of parameter related debug info filtering for: " + << G->getName() << "()\n"); } else { NewG->copyAttributesFrom(G); NewG->takeName(G); @@ -745,7 +730,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { G->eraseFromParent(); } - DEBUG(dbgs() << "writeThunk: " << H->getName() << '\n'); + LLVM_DEBUG(dbgs() << "writeThunk: " << H->getName() << '\n'); ++NumThunksWritten; } @@ -754,6 +739,10 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { if (F->isInterposable()) { assert(G->isInterposable()); + if (!isThunkProfitable(F)) { + return; + } + // Make them both thunks to the same internal function. Function *H = Function::Create(F->getFunctionType(), F->getLinkage(), "", F->getParent()); @@ -770,11 +759,41 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { F->setAlignment(MaxAlignment); F->setLinkage(GlobalValue::PrivateLinkage); ++NumDoubleWeak; + ++NumFunctionsMerged; } else { + // For better debugability, under MergeFunctionsPDI, we do not modify G's + // call sites to point to F even when within the same translation unit. + if (!G->isInterposable() && !MergeFunctionsPDI) { + if (G->hasGlobalUnnamedAddr()) { + // G might have been a key in our GlobalNumberState, and it's illegal + // to replace a key in ValueMap<GlobalValue *> with a non-global. + GlobalNumbers.erase(G); + // If G's address is not significant, replace it entirely. + Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType()); + G->replaceAllUsesWith(BitcastF); + } else { + // Redirect direct callers of G to F. (See note on MergeFunctionsPDI + // above). + replaceDirectCallers(G, F); + } + } + + // If G was internal then we may have replaced all uses of G with F. If so, + // stop here and delete G. There's no need for a thunk. (See note on + // MergeFunctionsPDI above). 
+ if (G->hasLocalLinkage() && G->use_empty() && !MergeFunctionsPDI) { + G->eraseFromParent(); + ++NumFunctionsMerged; + return; + } + + if (!isThunkProfitable(F)) { + return; + } + writeThunk(F, G); + ++NumFunctionsMerged; } - - ++NumFunctionsMerged; } /// Replace function F by function G. @@ -806,7 +825,8 @@ bool MergeFunctions::insert(Function *NewFunction) { if (Result.second) { assert(FNodesInTree.count(NewFunction) == 0); FNodesInTree.insert({NewFunction, Result.first}); - DEBUG(dbgs() << "Inserting as unique: " << NewFunction->getName() << '\n'); + LLVM_DEBUG(dbgs() << "Inserting as unique: " << NewFunction->getName() + << '\n'); return false; } @@ -827,8 +847,8 @@ bool MergeFunctions::insert(Function *NewFunction) { assert(OldF.getFunc() != F && "Must have swapped the functions."); } - DEBUG(dbgs() << " " << OldF.getFunc()->getName() - << " == " << NewFunction->getName() << '\n'); + LLVM_DEBUG(dbgs() << " " << OldF.getFunc()->getName() + << " == " << NewFunction->getName() << '\n'); Function *DeleteF = NewFunction; mergeTwoFunctions(OldF.getFunc(), DeleteF); @@ -840,7 +860,7 @@ bool MergeFunctions::insert(Function *NewFunction) { void MergeFunctions::remove(Function *F) { auto I = FNodesInTree.find(F); if (I != FNodesInTree.end()) { - DEBUG(dbgs() << "Deferred " << F->getName()<< ".\n"); + LLVM_DEBUG(dbgs() << "Deferred " << F->getName() << ".\n"); FnTree.erase(I->second); // I->second has been invalidated, remove it from the FNodesInTree map to // preserve the invariant. @@ -854,7 +874,7 @@ void MergeFunctions::remove(Function *F) { void MergeFunctions::removeUsers(Value *V) { std::vector<Value *> Worklist; Worklist.push_back(V); - SmallSet<Value*, 8> Visited; + SmallPtrSet<Value*, 8> Visited; Visited.insert(V); while (!Worklist.empty()) { Value *V = Worklist.back(); diff --git a/lib/Transforms/IPO/PartialInlining.cpp b/lib/Transforms/IPO/PartialInlining.cpp index a9cfd8ded6fb..4907e4b30519 100644 --- a/lib/Transforms/IPO/PartialInlining.cpp +++ b/lib/Transforms/IPO/PartialInlining.cpp @@ -202,10 +202,8 @@ struct PartialInlinerImpl { std::function<AssumptionCache &(Function &)> *GetAC, std::function<TargetTransformInfo &(Function &)> *GTTI, Optional<function_ref<BlockFrequencyInfo &(Function &)>> GBFI, - ProfileSummaryInfo *ProfSI, - std::function<OptimizationRemarkEmitter &(Function &)> *GORE) - : GetAssumptionCache(GetAC), GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI), - GetORE(GORE) {} + ProfileSummaryInfo *ProfSI) + : GetAssumptionCache(GetAC), GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {} bool run(Module &M); // Main part of the transformation that calls helper functions to find @@ -217,7 +215,7 @@ struct PartialInlinerImpl { // outline function due to code size. std::pair<bool, Function *> unswitchFunction(Function *F); - // This class speculatively clones the the function to be partial inlined. + // This class speculatively clones the function to be partial inlined. // At the end of partial inlining, the remaining callsites to the cloned // function that are not partially inlined will be fixed up to reference // the original function, and the cloned function will be erased. @@ -271,7 +269,6 @@ private: std::function<TargetTransformInfo &(Function &)> *GetTTI; Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI; ProfileSummaryInfo *PSI; - std::function<OptimizationRemarkEmitter &(Function &)> *GetORE; // Return the frequency of the OutlininingBB relative to F's entry point. // The result is no larger than 1 and is represented using BP. 
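The PartialInliner hunks below drop the shared GetORE lambda and instead construct OptimizationRemarkEmitter objects right where they are consumed, which keeps each remark attributed to the correct function; the pattern, as it appears in the changed code, is simply:

  OptimizationRemarkEmitter ORE(F);                     // remarks on the callee
  OptimizationRemarkEmitter CallerORE(CS.getCaller());  // remarks on the caller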
@@ -282,7 +279,8 @@ private: // Return true if the callee of CS should be partially inlined with // profit. bool shouldPartialInline(CallSite CS, FunctionCloner &Cloner, - BlockFrequency WeightedOutliningRcost); + BlockFrequency WeightedOutliningRcost, + OptimizationRemarkEmitter &ORE); // Try to inline DuplicateFunction (cloned from F with call to // the OutlinedFunction into its callers. Return true @@ -337,7 +335,7 @@ private: std::unique_ptr<FunctionOutliningInfo> computeOutliningInfo(Function *F); std::unique_ptr<FunctionOutliningMultiRegionInfo> - computeOutliningColdRegionsInfo(Function *F); + computeOutliningColdRegionsInfo(Function *F, OptimizationRemarkEmitter &ORE); }; struct PartialInlinerLegacyPass : public ModulePass { @@ -362,7 +360,6 @@ struct PartialInlinerLegacyPass : public ModulePass { &getAnalysis<TargetTransformInfoWrapperPass>(); ProfileSummaryInfo *PSI = getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); - std::unique_ptr<OptimizationRemarkEmitter> UPORE; std::function<AssumptionCache &(Function &)> GetAssumptionCache = [&ACT](Function &F) -> AssumptionCache & { @@ -374,14 +371,7 @@ struct PartialInlinerLegacyPass : public ModulePass { return TTIWP->getTTI(F); }; - std::function<OptimizationRemarkEmitter &(Function &)> GetORE = - [&UPORE](Function &F) -> OptimizationRemarkEmitter & { - UPORE.reset(new OptimizationRemarkEmitter(&F)); - return *UPORE.get(); - }; - - return PartialInlinerImpl(&GetAssumptionCache, &GetTTI, NoneType::None, PSI, - &GetORE) + return PartialInlinerImpl(&GetAssumptionCache, &GetTTI, NoneType::None, PSI) .run(M); } }; @@ -389,7 +379,8 @@ struct PartialInlinerLegacyPass : public ModulePass { } // end anonymous namespace std::unique_ptr<FunctionOutliningMultiRegionInfo> -PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F) { +PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F, + OptimizationRemarkEmitter &ORE) { BasicBlock *EntryBlock = &F->front(); DominatorTree DT(*F); @@ -403,8 +394,6 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F) { } else BFI = &(*GetBFI)(*F); - auto &ORE = (*GetORE)(*F); - // Return if we don't have profiling information. 
if (!PSI->hasInstrumentationProfile()) return std::unique_ptr<FunctionOutliningMultiRegionInfo>(); @@ -414,8 +403,7 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F) { auto IsSingleEntry = [](SmallVectorImpl<BasicBlock *> &BlockList) { BasicBlock *Dom = BlockList.front(); - return BlockList.size() > 1 && - std::distance(pred_begin(Dom), pred_end(Dom)) == 1; + return BlockList.size() > 1 && pred_size(Dom) == 1; }; auto IsSingleExit = @@ -567,10 +555,6 @@ PartialInlinerImpl::computeOutliningInfo(Function *F) { return is_contained(successors(BB), Succ); }; - auto SuccSize = [](BasicBlock *BB) { - return std::distance(succ_begin(BB), succ_end(BB)); - }; - auto IsReturnBlock = [](BasicBlock *BB) { TerminatorInst *TI = BB->getTerminator(); return isa<ReturnInst>(TI); @@ -607,7 +591,7 @@ PartialInlinerImpl::computeOutliningInfo(Function *F) { if (OutliningInfo->GetNumInlinedBlocks() >= MaxNumInlineBlocks) break; - if (SuccSize(CurrEntry) != 2) + if (succ_size(CurrEntry) != 2) break; BasicBlock *Succ1 = *succ_begin(CurrEntry); @@ -681,7 +665,7 @@ PartialInlinerImpl::computeOutliningInfo(Function *F) { // peeling off dominating blocks from the outlining region: while (OutliningInfo->GetNumInlinedBlocks() < MaxNumInlineBlocks) { BasicBlock *Cand = OutliningInfo->NonReturnBlock; - if (SuccSize(Cand) != 2) + if (succ_size(Cand) != 2) break; if (HasNonEntryPred(Cand)) @@ -766,19 +750,19 @@ PartialInlinerImpl::getOutliningCallBBRelativeFreq(FunctionCloner &Cloner) { bool PartialInlinerImpl::shouldPartialInline( CallSite CS, FunctionCloner &Cloner, - BlockFrequency WeightedOutliningRcost) { + BlockFrequency WeightedOutliningRcost, + OptimizationRemarkEmitter &ORE) { using namespace ore; - if (SkipCostAnalysis) - return true; - Instruction *Call = CS.getInstruction(); Function *Callee = CS.getCalledFunction(); assert(Callee == Cloner.ClonedFunc); + if (SkipCostAnalysis) + return isInlineViable(*Callee); + Function *Caller = CS.getCaller(); auto &CalleeTTI = (*GetTTI)(*Callee); - auto &ORE = (*GetORE)(*Caller); InlineCost IC = getInlineCost(CS, getInlineParams(), CalleeTTI, *GetAssumptionCache, GetBFI, PSI, &ORE); @@ -1270,14 +1254,14 @@ std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function *F) { if (F->user_begin() == F->user_end()) return {false, nullptr}; - auto &ORE = (*GetORE)(*F); + OptimizationRemarkEmitter ORE(F); // Only try to outline cold regions if we have a profile summary, which // implies we have profiling information. if (PSI->hasProfileSummary() && F->hasProfileData() && !DisableMultiRegionPartialInline) { std::unique_ptr<FunctionOutliningMultiRegionInfo> OMRI = - computeOutliningColdRegionsInfo(F); + computeOutliningColdRegionsInfo(F, ORE); if (OMRI) { FunctionCloner Cloner(F, OMRI.get(), ORE); @@ -1357,11 +1341,11 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { // inlining the function with outlining (The inliner uses the size increase to // model the cost of inlining a callee). 
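// If the outlined region costs less than the size increase the inliner would
// charge for it, outlining cannot pay for itself; emit a remark and give up.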
if (!SkipCostAnalysis && Cloner.OutlinedRegionCost < SizeCost) { - auto &ORE = (*GetORE)(*Cloner.OrigFunc); + OptimizationRemarkEmitter OrigFuncORE(Cloner.OrigFunc); DebugLoc DLoc; BasicBlock *Block; std::tie(DLoc, Block) = getOneDebugLoc(Cloner.ClonedFunc); - ORE.emit([&]() { + OrigFuncORE.emit([&]() { return OptimizationRemarkAnalysis(DEBUG_TYPE, "OutlineRegionTooSmall", DLoc, Block) << ore::NV("Function", Cloner.OrigFunc) @@ -1384,7 +1368,8 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { if (CalleeEntryCount) computeCallsiteToProfCountMap(Cloner.ClonedFunc, CallSiteToProfCountMap); - uint64_t CalleeEntryCountV = (CalleeEntryCount ? *CalleeEntryCount : 0); + uint64_t CalleeEntryCountV = + (CalleeEntryCount ? CalleeEntryCount.getCount() : 0); bool AnyInline = false; for (User *User : Users) { @@ -1393,11 +1378,10 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { if (IsLimitReached()) continue; - - if (!shouldPartialInline(CS, Cloner, WeightedRcost)) + OptimizationRemarkEmitter CallerORE(CS.getCaller()); + if (!shouldPartialInline(CS, Cloner, WeightedRcost, CallerORE)) continue; - auto &ORE = (*GetORE)(*CS.getCaller()); // Construct remark before doing the inlining, as after successful inlining // the callsite is removed. OptimizationRemark OR(DEBUG_TYPE, "PartiallyInlined", CS.getInstruction()); @@ -1412,7 +1396,7 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { : nullptr))) continue; - ORE.emit(OR); + CallerORE.emit(OR); // Now update the entry count: if (CalleeEntryCountV && CallSiteToProfCountMap.count(User)) { @@ -1433,9 +1417,10 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { if (AnyInline) { Cloner.IsFunctionInlined = true; if (CalleeEntryCount) - Cloner.OrigFunc->setEntryCount(CalleeEntryCountV); - auto &ORE = (*GetORE)(*Cloner.OrigFunc); - ORE.emit([&]() { + Cloner.OrigFunc->setEntryCount( + CalleeEntryCount.setCount(CalleeEntryCountV)); + OptimizationRemarkEmitter OrigFuncORE(Cloner.OrigFunc); + OrigFuncORE.emit([&]() { return OptimizationRemark(DEBUG_TYPE, "PartiallyInlined", Cloner.OrigFunc) << "Partially inlined into at least one caller"; }); @@ -1517,14 +1502,9 @@ PreservedAnalyses PartialInlinerPass::run(Module &M, return FAM.getResult<TargetIRAnalysis>(F); }; - std::function<OptimizationRemarkEmitter &(Function &)> GetORE = - [&FAM](Function &F) -> OptimizationRemarkEmitter & { - return FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); - }; - ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); - if (PartialInlinerImpl(&GetAssumptionCache, &GetTTI, {GetBFI}, PSI, &GetORE) + if (PartialInlinerImpl(&GetAssumptionCache, &GetTTI, {GetBFI}, PSI) .run(M)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); diff --git a/lib/Transforms/IPO/PassManagerBuilder.cpp b/lib/Transforms/IPO/PassManagerBuilder.cpp index 3855e6245d8e..5ced6481996a 100644 --- a/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -29,14 +29,18 @@ #include "llvm/IR/Verifier.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ManagedStatic.h" +#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/ForceFunctionAttrs.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/Transforms/IPO/InferFunctionAttrs.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Scalar.h" #include 
"llvm/Transforms/Scalar/GVN.h" +#include "llvm/Transforms/Scalar/InstSimplifyPass.h" #include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Vectorize.h" using namespace llvm; @@ -92,6 +96,10 @@ static cl::opt<bool> EnableLoopInterchange( "enable-loopinterchange", cl::init(false), cl::Hidden, cl::desc("Enable the new, experimental LoopInterchange Pass")); +static cl::opt<bool> EnableUnrollAndJam("enable-unroll-and-jam", + cl::init(false), cl::Hidden, + cl::desc("Enable Unroll And Jam Pass")); + static cl::opt<bool> EnablePrepareForThinLTO("prepare-for-thinlto", cl::init(false), cl::Hidden, cl::desc("Enable preparation for ThinLTO.")); @@ -135,10 +143,10 @@ static cl::opt<bool> cl::Hidden, cl::desc("Disable shrink-wrap library calls")); -static cl::opt<bool> - EnableSimpleLoopUnswitch("enable-simple-loop-unswitch", cl::init(false), - cl::Hidden, - cl::desc("Enable the simple loop unswitch pass.")); +static cl::opt<bool> EnableSimpleLoopUnswitch( + "enable-simple-loop-unswitch", cl::init(false), cl::Hidden, + cl::desc("Enable the simple loop unswitch pass. Also enables independent " + "cleanup passes integrated into the loop pass manager pipeline.")); static cl::opt<bool> EnableGVNSink( "enable-gvn-sink", cl::init(false), cl::Hidden, @@ -318,6 +326,8 @@ void PassManagerBuilder::addFunctionSimplificationPasses( MPM.add(createCorrelatedValuePropagationPass()); // Propagate conditionals MPM.add(createCFGSimplificationPass()); // Merge & remove BBs // Combine silly seq's + if (OptLevel > 2) + MPM.add(createAggressiveInstCombinerPass()); addInstructionCombiningPass(MPM); if (SizeLevel == 0 && !DisableLibCallsShrinkWrap) MPM.add(createLibCallsShrinkWrapPass()); @@ -330,6 +340,15 @@ void PassManagerBuilder::addFunctionSimplificationPasses( MPM.add(createTailCallEliminationPass()); // Eliminate tail calls MPM.add(createCFGSimplificationPass()); // Merge & remove BBs MPM.add(createReassociatePass()); // Reassociate expressions + + // Begin the loop pass pipeline. + if (EnableSimpleLoopUnswitch) { + // The simple loop unswitch pass relies on separate cleanup passes. Schedule + // them first so when we re-process a loop they run before other loop + // passes. + MPM.add(createLoopInstSimplifyPass()); + MPM.add(createLoopSimplifyCFGPass()); + } // Rotate Loop - disable header duplication at -Oz MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1)); MPM.add(createLICMPass()); // Hoist loop invariants @@ -337,20 +356,26 @@ void PassManagerBuilder::addFunctionSimplificationPasses( MPM.add(createSimpleLoopUnswitchLegacyPass()); else MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3, DivergentTarget)); + // FIXME: We break the loop pass pipeline here in order to do full + // simplify-cfg. Eventually loop-simplifycfg should be enhanced to replace the + // need for this. MPM.add(createCFGSimplificationPass()); addInstructionCombiningPass(MPM); + // We resume loop passes creating a second loop pipeline here. MPM.add(createIndVarSimplifyPass()); // Canonicalize indvars MPM.add(createLoopIdiomPass()); // Recognize idioms like memset. addExtensionsToPM(EP_LateLoopOptimizations, MPM); MPM.add(createLoopDeletionPass()); // Delete dead loops if (EnableLoopInterchange) { + // FIXME: These are function passes and break the loop pass pipeline. 
MPM.add(createLoopInterchangePass()); // Interchange loops MPM.add(createCFGSimplificationPass()); } if (!DisableUnrollLoops) MPM.add(createSimpleLoopUnrollPass(OptLevel)); // Unroll small loops addExtensionsToPM(EP_LoopOptimizerEnd, MPM); + // This ends the loop pass pipelines. if (OptLevel > 1) { MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds @@ -431,7 +456,7 @@ void PassManagerBuilder::populateModulePassManager( // This has to be done after we add the extensions to the pass manager // as there could be passes (e.g. Adddress sanitizer) which introduce // new unnamed globals. - if (PrepareForThinLTO) + if (PrepareForLTO || PrepareForThinLTO) MPM.add(createNameAnonGlobalPass()); return; } @@ -648,6 +673,13 @@ void PassManagerBuilder::populateModulePassManager( addInstructionCombiningPass(MPM); if (!DisableUnrollLoops) { + if (EnableUnrollAndJam) { + // Unroll and Jam. We do this before unroll but need to be in a separate + // loop pass manager in order for the outer loop to be processed by + // unroll and jam before the inner loop is unrolled. + MPM.add(createLoopUnrollAndJamPass(OptLevel)); + } + MPM.add(createLoopUnrollPass(OptLevel)); // Unroll small loops // LoopUnroll may generate some redundency to cleanup. @@ -683,7 +715,7 @@ void PassManagerBuilder::populateModulePassManager( // result too early. MPM.add(createLoopSinkPass()); // Get rid of LCSSA nodes. - MPM.add(createInstructionSimplifierPass()); + MPM.add(createInstSimplifyLegacyPass()); // This hoists/decomposes div/rem ops. It should run after other sink/hoist // passes to avoid re-sinking, but before SimplifyCFG because it can allow @@ -695,6 +727,10 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createCFGSimplificationPass()); addExtensionsToPM(EP_OptimizerLast, MPM); + + // Rename anon globals to be able to handle them in the summary + if (PrepareForLTO) + MPM.add(createNameAnonGlobalPass()); } void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { @@ -765,6 +801,8 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { // simplification opportunities, and both can propagate functions through // function pointers. When this happens, we often have to resolve varargs // calls, etc, so let instcombine do this. 
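// (This mirrors the OptLevel > 2 scheduling added to
// addFunctionSimplificationPasses above.)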
+ if (OptLevel > 2) + PM.add(createAggressiveInstCombinerPass()); addInstructionCombiningPass(PM); addExtensionsToPM(EP_Peephole, PM); @@ -865,6 +903,8 @@ void PassManagerBuilder::addLateLTOOptimizationPasses( void PassManagerBuilder::populateThinLTOPassManager( legacy::PassManagerBase &PM) { PerformThinLTO = true; + if (LibraryInfo) + PM.add(new TargetLibraryInfoWrapperPass(*LibraryInfo)); if (VerifyInput) PM.add(createVerifierPass()); diff --git a/lib/Transforms/IPO/PruneEH.cpp b/lib/Transforms/IPO/PruneEH.cpp index 46b088189040..27d791857314 100644 --- a/lib/Transforms/IPO/PruneEH.cpp +++ b/lib/Transforms/IPO/PruneEH.cpp @@ -19,6 +19,7 @@ #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/EHPersonalities.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" @@ -27,7 +28,6 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" -#include "llvm/Transforms/Utils/Local.h" #include <algorithm> using namespace llvm; diff --git a/lib/Transforms/IPO/SCCP.cpp b/lib/Transforms/IPO/SCCP.cpp new file mode 100644 index 000000000000..cc53c4b8c46f --- /dev/null +++ b/lib/Transforms/IPO/SCCP.cpp @@ -0,0 +1,58 @@ +#include "llvm/Transforms/IPO/SCCP.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Scalar/SCCP.h" + +using namespace llvm; + +PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) { + const DataLayout &DL = M.getDataLayout(); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); + if (!runIPSCCP(M, DL, &TLI)) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); +} + +namespace { + +//===--------------------------------------------------------------------===// +// +/// IPSCCP Class - This class implements interprocedural Sparse Conditional +/// Constant Propagation. +/// +class IPSCCPLegacyPass : public ModulePass { +public: + static char ID; + + IPSCCPLegacyPass() : ModulePass(ID) { + initializeIPSCCPLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + const DataLayout &DL = M.getDataLayout(); + const TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + return runIPSCCP(M, DL, TLI); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + } +}; + +} // end anonymous namespace + +char IPSCCPLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(IPSCCPLegacyPass, "ipsccp", + "Interprocedural Sparse Conditional Constant Propagation", + false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(IPSCCPLegacyPass, "ipsccp", + "Interprocedural Sparse Conditional Constant Propagation", + false, false) + +// createIPSCCPPass - This is the public interface to this file. 
+ModulePass *llvm::createIPSCCPPass() { return new IPSCCPLegacyPass(); } diff --git a/lib/Transforms/IPO/SampleProfile.cpp b/lib/Transforms/IPO/SampleProfile.cpp index a69c009e1a54..dcd24595f7ea 100644 --- a/lib/Transforms/IPO/SampleProfile.cpp +++ b/lib/Transforms/IPO/SampleProfile.cpp @@ -22,7 +22,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/SampleProfile.h" +#include "llvm/Transforms/IPO/SampleProfile.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" @@ -37,6 +37,8 @@ #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -85,7 +87,7 @@ using namespace llvm; using namespace sampleprof; - +using ProfileCount = Function::ProfileCount; #define DEBUG_TYPE "sample-profile" // Command line option to specify the file to read samples from. This is @@ -109,10 +111,10 @@ static cl::opt<unsigned> SampleProfileSampleCoverage( cl::desc("Emit a warning if less than N% of samples in the input profile " "are matched to the IR.")); -static cl::opt<double> SampleProfileHotThreshold( - "sample-profile-inline-hot-threshold", cl::init(0.1), cl::value_desc("N"), - cl::desc("Inlined functions that account for more than N% of all samples " - "collected in the parent function, will be inlined again.")); +static cl::opt<bool> NoWarnSampleUnused( + "no-warn-sample-unused", cl::init(false), cl::Hidden, + cl::desc("Use this option to turn off/on warnings about function with " + "samples but without debug information to use those samples. ")); namespace { @@ -130,10 +132,13 @@ public: bool markSamplesUsed(const FunctionSamples *FS, uint32_t LineOffset, uint32_t Discriminator, uint64_t Samples); unsigned computeCoverage(unsigned Used, unsigned Total) const; - unsigned countUsedRecords(const FunctionSamples *FS) const; - unsigned countBodyRecords(const FunctionSamples *FS) const; + unsigned countUsedRecords(const FunctionSamples *FS, + ProfileSummaryInfo *PSI) const; + unsigned countBodyRecords(const FunctionSamples *FS, + ProfileSummaryInfo *PSI) const; uint64_t getTotalUsedSamples() const { return TotalUsedSamples; } - uint64_t countBodySamples(const FunctionSamples *FS) const; + uint64_t countBodySamples(const FunctionSamples *FS, + ProfileSummaryInfo *PSI) const; void clear() { SampleCoverage.clear(); @@ -170,7 +175,7 @@ private: uint64_t TotalUsedSamples = 0; }; -/// \brief Sample profile pass. +/// Sample profile pass. /// /// This pass reads profile data from the file specified by /// -sample-profile-file and annotates every affected function with the @@ -186,7 +191,8 @@ public: IsThinLTOPreLink(IsThinLTOPreLink) {} bool doInitialization(Module &M); - bool runOnModule(Module &M, ModuleAnalysisManager *AM); + bool runOnModule(Module &M, ModuleAnalysisManager *AM, + ProfileSummaryInfo *_PSI); void dump() { Reader->dump(); } @@ -217,28 +223,27 @@ protected: void buildEdges(Function &F); bool propagateThroughEdges(Function &F, bool UpdateBlockCount); void computeDominanceAndLoopInfo(Function &F); - unsigned getOffset(const DILocation *DIL) const; void clearFunctionData(); - /// \brief Map basic blocks to their computed weights. + /// Map basic blocks to their computed weights. 
/// /// The weight of a basic block is defined to be the maximum /// of all the instruction weights in that block. BlockWeightMap BlockWeights; - /// \brief Map edges to their computed weights. + /// Map edges to their computed weights. /// /// Edge weights are computed by propagating basic block weights in /// SampleProfile::propagateWeights. EdgeWeightMap EdgeWeights; - /// \brief Set of visited blocks during propagation. + /// Set of visited blocks during propagation. SmallPtrSet<const BasicBlock *, 32> VisitedBlocks; - /// \brief Set of visited edges during propagation. + /// Set of visited edges during propagation. SmallSet<Edge, 32> VisitedEdges; - /// \brief Equivalence classes for block weights. + /// Equivalence classes for block weights. /// /// Two blocks BB1 and BB2 are in the same equivalence class if they /// dominate and post-dominate each other, and they are in the same loop @@ -252,47 +257,50 @@ protected: /// is one-to-one mapping. StringMap<Function *> SymbolMap; - /// \brief Dominance, post-dominance and loop information. + /// Dominance, post-dominance and loop information. std::unique_ptr<DominatorTree> DT; - std::unique_ptr<PostDomTreeBase<BasicBlock>> PDT; + std::unique_ptr<PostDominatorTree> PDT; std::unique_ptr<LoopInfo> LI; std::function<AssumptionCache &(Function &)> GetAC; std::function<TargetTransformInfo &(Function &)> GetTTI; - /// \brief Predecessors for each basic block in the CFG. + /// Predecessors for each basic block in the CFG. BlockEdgeMap Predecessors; - /// \brief Successors for each basic block in the CFG. + /// Successors for each basic block in the CFG. BlockEdgeMap Successors; SampleCoverageTracker CoverageTracker; - /// \brief Profile reader object. + /// Profile reader object. std::unique_ptr<SampleProfileReader> Reader; - /// \brief Samples collected for the body of this function. + /// Samples collected for the body of this function. FunctionSamples *Samples = nullptr; - /// \brief Name of the profile file to load. + /// Name of the profile file to load. std::string Filename; - /// \brief Flag indicating whether the profile input loaded successfully. + /// Flag indicating whether the profile input loaded successfully. bool ProfileIsValid = false; - /// \brief Flag indicating if the pass is invoked in ThinLTO compile phase. + /// Flag indicating if the pass is invoked in ThinLTO compile phase. /// /// In this phase, in annotation, we should not promote indirect calls. /// Instead, we will mark GUIDs that needs to be annotated to the function. bool IsThinLTOPreLink; - /// \brief Total number of samples collected in this profile. + /// Profile Summary Info computed from sample profile. + ProfileSummaryInfo *PSI = nullptr; + + /// Total number of samples collected in this profile. /// /// This is the sum of all the samples collected in all the functions executed /// at runtime. uint64_t TotalCollectedSamples = 0; - /// \brief Optimization Remark Emitter used to emit diagnostic remarks. + /// Optimization Remark Emitter used to emit diagnostic remarks. OptimizationRemarkEmitter *ORE = nullptr; }; @@ -326,6 +334,7 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addRequired<ProfileSummaryInfoWrapperPass>(); } private: @@ -336,7 +345,7 @@ private: } // end anonymous namespace -/// Return true if the given callsite is hot wrt to its caller. +/// Return true if the given callsite is hot wrt to hot cutoff threshold. 
///
/// Functions that were inlined in the original binary will be represented
/// in the inline stack in the sample profile. If the profile shows that
///
@@ -344,28 +353,17 @@ private:
/// frequently), then we will recreate the inline decision and apply the
/// profile from the inlined callsite.
///
-/// To decide whether an inlined callsite is hot, we compute the fraction
-/// of samples used by the callsite with respect to the total number of samples
-/// collected in the caller.
-///
-/// If that fraction is larger than the default given by
-/// SampleProfileHotThreshold, the callsite will be inlined again.
-static bool callsiteIsHot(const FunctionSamples *CallerFS,
- const FunctionSamples *CallsiteFS) {
+/// To decide whether an inlined callsite is hot, we compare the callsite
+/// sample count with the hot cutoff computed by ProfileSummaryInfo; it is
+/// regarded as hot if the count is above the cutoff value.
+static bool callsiteIsHot(const FunctionSamples *CallsiteFS,
+ ProfileSummaryInfo *PSI) {
if (!CallsiteFS)
return false; // The callsite was not inlined in the original binary.
- uint64_t ParentTotalSamples = CallerFS->getTotalSamples();
- if (ParentTotalSamples == 0)
- return false; // Avoid division by zero.
-
+ assert(PSI && "PSI is expected to be non null");
uint64_t CallsiteTotalSamples = CallsiteFS->getTotalSamples();
- if (CallsiteTotalSamples == 0)
- return false; // Callsite is trivially cold.
-
- double PercentSamples =
- (double)CallsiteTotalSamples / (double)ParentTotalSamples * 100.0;
- return PercentSamples >= SampleProfileHotThreshold;
+ return PSI->isHotCount(CallsiteTotalSamples);
}
/// Mark as used the sample record for the given function samples at
@@ -388,7 +386,8 @@ bool SampleCoverageTracker::markSamplesUsed(const FunctionSamples *FS,
///
/// This count does not include records from cold inlined callsites.
unsigned
-SampleCoverageTracker::countUsedRecords(const FunctionSamples *FS) const {
+SampleCoverageTracker::countUsedRecords(const FunctionSamples *FS,
+ ProfileSummaryInfo *PSI) const {
auto I = SampleCoverage.find(FS);
// The size of the coverage map for FS represents the number of records
@@ -401,8 +400,8 @@ SampleCoverageTracker::countUsedRecords(const FunctionSamples *FS) const {
for (const auto &I : FS->getCallsiteSamples())
for (const auto &J : I.second) {
const FunctionSamples *CalleeSamples = &J.second;
- if (callsiteIsHot(FS, CalleeSamples))
- Count += countUsedRecords(CalleeSamples);
+ if (callsiteIsHot(CalleeSamples, PSI))
+ Count += countUsedRecords(CalleeSamples, PSI);
}
return Count;
@@ -412,15 +411,16 @@ SampleCoverageTracker::countUsedRecords(const FunctionSamples *FS) const {
///
/// This count does not include records from cold inlined callsites.
unsigned
-SampleCoverageTracker::countBodyRecords(const FunctionSamples *FS) const {
+SampleCoverageTracker::countBodyRecords(const FunctionSamples *FS,
+ ProfileSummaryInfo *PSI) const {
unsigned Count = FS->getBodySamples().size();
// Only count records in hot callsites.
for (const auto &I : FS->getCallsiteSamples())
for (const auto &J : I.second) {
const FunctionSamples *CalleeSamples = &J.second;
- if (callsiteIsHot(FS, CalleeSamples))
- Count += countBodyRecords(CalleeSamples);
+ if (callsiteIsHot(CalleeSamples, PSI))
+ Count += countBodyRecords(CalleeSamples, PSI);
}
return Count;
@@ -430,7 +430,8 @@ SampleCoverageTracker::countBodyRecords(const FunctionSamples *FS) const {
///
/// This count does not include samples from cold inlined callsites.
uint64_t -SampleCoverageTracker::countBodySamples(const FunctionSamples *FS) const { +SampleCoverageTracker::countBodySamples(const FunctionSamples *FS, + ProfileSummaryInfo *PSI) const { uint64_t Total = 0; for (const auto &I : FS->getBodySamples()) Total += I.second.getSamples(); @@ -439,8 +440,8 @@ SampleCoverageTracker::countBodySamples(const FunctionSamples *FS) const { for (const auto &I : FS->getCallsiteSamples()) for (const auto &J : I.second) { const FunctionSamples *CalleeSamples = &J.second; - if (callsiteIsHot(FS, CalleeSamples)) - Total += countBodySamples(CalleeSamples); + if (callsiteIsHot(CalleeSamples, PSI)) + Total += countBodySamples(CalleeSamples, PSI); } return Total; @@ -473,15 +474,8 @@ void SampleProfileLoader::clearFunctionData() { CoverageTracker.clear(); } -/// Returns the line offset to the start line of the subprogram. -/// We assume that a single function will not exceed 65535 LOC. -unsigned SampleProfileLoader::getOffset(const DILocation *DIL) const { - return (DIL->getLine() - DIL->getScope()->getSubprogram()->getLine()) & - 0xffff; -} - #ifndef NDEBUG -/// \brief Print the weight of edge \p E on stream \p OS. +/// Print the weight of edge \p E on stream \p OS. /// /// \param OS Stream to emit the output to. /// \param E Edge to print. @@ -490,7 +484,7 @@ void SampleProfileLoader::printEdgeWeight(raw_ostream &OS, Edge E) { << "]: " << EdgeWeights[E] << "\n"; } -/// \brief Print the equivalence class of block \p BB on stream \p OS. +/// Print the equivalence class of block \p BB on stream \p OS. /// /// \param OS Stream to emit the output to. /// \param BB Block to print. @@ -501,7 +495,7 @@ void SampleProfileLoader::printBlockEquivalence(raw_ostream &OS, << "]: " << ((Equiv) ? EquivalenceClass[BB]->getName() : "NONE") << "\n"; } -/// \brief Print the weight of block \p BB on stream \p OS. +/// Print the weight of block \p BB on stream \p OS. /// /// \param OS Stream to emit the output to. /// \param BB Block to print. @@ -513,7 +507,7 @@ void SampleProfileLoader::printBlockWeight(raw_ostream &OS, } #endif -/// \brief Get the weight for an instruction. +/// Get the weight for an instruction. /// /// The "weight" of an instruction \p Inst is the number of samples /// collected on that instruction at runtime. To retrieve it, we @@ -549,7 +543,7 @@ ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) { return 0; const DILocation *DIL = DLoc; - uint32_t LineOffset = getOffset(DIL); + uint32_t LineOffset = FunctionSamples::getOffset(DIL); uint32_t Discriminator = DIL->getBaseDiscriminator(); ErrorOr<uint64_t> R = FS->findSamplesAt(LineOffset, Discriminator); if (R) { @@ -569,16 +563,16 @@ ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) { return Remark; }); } - DEBUG(dbgs() << " " << DLoc.getLine() << "." - << DIL->getBaseDiscriminator() << ":" << Inst - << " (line offset: " << LineOffset << "." - << DIL->getBaseDiscriminator() << " - weight: " << R.get() - << ")\n"); + LLVM_DEBUG(dbgs() << " " << DLoc.getLine() << "." + << DIL->getBaseDiscriminator() << ":" << Inst + << " (line offset: " << LineOffset << "." + << DIL->getBaseDiscriminator() << " - weight: " << R.get() + << ")\n"); } return R; } -/// \brief Compute the weight of a basic block. +/// Compute the weight of a basic block. /// /// The weight of basic block \p BB is the maximum weight of all the /// instructions in BB. @@ -599,7 +593,7 @@ ErrorOr<uint64_t> SampleProfileLoader::getBlockWeight(const BasicBlock *BB) { return HasWeight ? 
ErrorOr<uint64_t>(Max) : std::error_code(); } -/// \brief Compute and store the weights of every basic block. +/// Compute and store the weights of every basic block. /// /// This populates the BlockWeights map by computing /// the weights of every basic block in the CFG. @@ -607,7 +601,7 @@ ErrorOr<uint64_t> SampleProfileLoader::getBlockWeight(const BasicBlock *BB) { /// \param F The function to query. bool SampleProfileLoader::computeBlockWeights(Function &F) { bool Changed = false; - DEBUG(dbgs() << "Block weights\n"); + LLVM_DEBUG(dbgs() << "Block weights\n"); for (const auto &BB : F) { ErrorOr<uint64_t> Weight = getBlockWeight(&BB); if (Weight) { @@ -615,13 +609,13 @@ bool SampleProfileLoader::computeBlockWeights(Function &F) { VisitedBlocks.insert(&BB); Changed = true; } - DEBUG(printBlockWeight(dbgs(), &BB)); + LLVM_DEBUG(printBlockWeight(dbgs(), &BB)); } return Changed; } -/// \brief Get the FunctionSamples for a call instruction. +/// Get the FunctionSamples for a call instruction. /// /// The FunctionSamples of a call/invoke instruction \p Inst is the inlined /// instance in which that call instruction is calling to. It contains @@ -649,8 +643,11 @@ SampleProfileLoader::findCalleeFunctionSamples(const Instruction &Inst) const { if (FS == nullptr) return nullptr; - return FS->findFunctionSamplesAt( - LineLocation(getOffset(DIL), DIL->getBaseDiscriminator()), CalleeName); + std::string CalleeGUID; + CalleeName = getRepInFormat(CalleeName, Reader->getFormat(), CalleeGUID); + return FS->findFunctionSamplesAt(LineLocation(FunctionSamples::getOffset(DIL), + DIL->getBaseDiscriminator()), + CalleeName); } /// Returns a vector of FunctionSamples that are the indirect call targets @@ -670,7 +667,7 @@ SampleProfileLoader::findIndirectCallFunctionSamples( if (FS == nullptr) return R; - uint32_t LineOffset = getOffset(DIL); + uint32_t LineOffset = FunctionSamples::getOffset(DIL); uint32_t Discriminator = DIL->getBaseDiscriminator(); auto T = FS->findCallTargetMapAt(LineOffset, Discriminator); @@ -678,23 +675,23 @@ SampleProfileLoader::findIndirectCallFunctionSamples( if (T) for (const auto &T_C : T.get()) Sum += T_C.second; - if (const FunctionSamplesMap *M = FS->findFunctionSamplesMapAt( - LineLocation(getOffset(DIL), DIL->getBaseDiscriminator()))) { + if (const FunctionSamplesMap *M = FS->findFunctionSamplesMapAt(LineLocation( + FunctionSamples::getOffset(DIL), DIL->getBaseDiscriminator()))) { if (M->empty()) return R; for (const auto &NameFS : *M) { Sum += NameFS.second.getEntrySamples(); R.push_back(&NameFS.second); } - std::sort(R.begin(), R.end(), - [](const FunctionSamples *L, const FunctionSamples *R) { - return L->getEntrySamples() > R->getEntrySamples(); - }); + llvm::sort(R.begin(), R.end(), + [](const FunctionSamples *L, const FunctionSamples *R) { + return L->getEntrySamples() > R->getEntrySamples(); + }); } return R; } -/// \brief Get the FunctionSamples for an instruction. +/// Get the FunctionSamples for an instruction. /// /// The FunctionSamples of an instruction \p Inst is the inlined instance /// in which that instruction is coming from. 
We traverse the inline stack @@ -710,20 +707,7 @@ SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { if (!DIL) return Samples; - const DILocation *PrevDIL = DIL; - for (DIL = DIL->getInlinedAt(); DIL; DIL = DIL->getInlinedAt()) { - S.push_back(std::make_pair( - LineLocation(getOffset(DIL), DIL->getBaseDiscriminator()), - PrevDIL->getScope()->getSubprogram()->getLinkageName())); - PrevDIL = DIL; - } - if (S.size() == 0) - return Samples; - const FunctionSamples *FS = Samples; - for (int i = S.size() - 1; i >= 0 && FS != nullptr; i--) { - FS = FS->findFunctionSamplesAt(S[i].first, S[i].second); - } - return FS; + return Samples->findFunctionSamples(DIL); } bool SampleProfileLoader::inlineCallInstruction(Instruction *I) { @@ -759,7 +743,7 @@ bool SampleProfileLoader::inlineCallInstruction(Instruction *I) { return false; } -/// \brief Iteratively inline hot callsites of a function. +/// Iteratively inline hot callsites of a function. /// /// Iteratively traverse all callsites of the function \p F, and find if /// the corresponding inlined instance exists and is hot in profile. If @@ -776,6 +760,7 @@ bool SampleProfileLoader::inlineHotFunctions( Function &F, DenseSet<GlobalValue::GUID> &InlinedGUIDs) { DenseSet<Instruction *> PromotedInsns; bool Changed = false; + bool isCompact = (Reader->getFormat() == SPF_Compact_Binary); while (true) { bool LocalChanged = false; SmallVector<Instruction *, 10> CIS; @@ -787,7 +772,7 @@ bool SampleProfileLoader::inlineHotFunctions( if ((isa<CallInst>(I) || isa<InvokeInst>(I)) && !isa<IntrinsicInst>(I) && (FS = findCalleeFunctionSamples(I))) { Candidates.push_back(&I); - if (callsiteIsHot(Samples, FS)) + if (callsiteIsHot(FS, PSI)) Hot = true; } } @@ -807,8 +792,8 @@ bool SampleProfileLoader::inlineHotFunctions( for (const auto *FS : findIndirectCallFunctionSamples(*I, Sum)) { if (IsThinLTOPreLink) { FS->findInlinedFunctions(InlinedGUIDs, F.getParent(), - Samples->getTotalSamples() * - SampleProfileHotThreshold / 100); + PSI->getOrCompHotCountThreshold(), + isCompact); continue; } auto CalleeFunctionName = FS->getName(); @@ -817,7 +802,9 @@ bool SampleProfileLoader::inlineHotFunctions( // clone the caller first, and inline the cloned caller if it is // recursive. As llvm does not inline recursive calls, we will // simply ignore it instead of handling it explicitly. - if (CalleeFunctionName == F.getName()) + std::string FGUID; + auto Fname = getRepInFormat(F.getName(), Reader->getFormat(), FGUID); + if (CalleeFunctionName == Fname) continue; const char *Reason = "Callee function not available"; @@ -836,9 +823,9 @@ bool SampleProfileLoader::inlineHotFunctions( inlineCallInstruction(DI)) LocalChanged = true; } else { - DEBUG(dbgs() - << "\nFailed to promote indirect call to " - << CalleeFunctionName << " because " << Reason << "\n"); + LLVM_DEBUG(dbgs() + << "\nFailed to promote indirect call to " + << CalleeFunctionName << " because " << Reason << "\n"); } } } else if (CalledFunction && CalledFunction->getSubprogram() && @@ -847,8 +834,8 @@ bool SampleProfileLoader::inlineHotFunctions( LocalChanged = true; } else if (IsThinLTOPreLink) { findCalleeFunctionSamples(*I)->findInlinedFunctions( - InlinedGUIDs, F.getParent(), - Samples->getTotalSamples() * SampleProfileHotThreshold / 100); + InlinedGUIDs, F.getParent(), PSI->getOrCompHotCountThreshold(), + isCompact); } } if (LocalChanged) { @@ -860,7 +847,7 @@ bool SampleProfileLoader::inlineHotFunctions( return Changed; } -/// \brief Find equivalence classes for the given block. 
+/// Find equivalence classes for the given block. /// /// This finds all the blocks that are guaranteed to execute the same /// number of times as \p BB1. To do this, it traverses all the @@ -917,7 +904,7 @@ void SampleProfileLoader::findEquivalencesFor( } } -/// \brief Find equivalence classes. +/// Find equivalence classes. /// /// Since samples may be missing from blocks, we can fill in the gaps by setting /// the weights of all the blocks in the same equivalence class to the same @@ -928,14 +915,14 @@ void SampleProfileLoader::findEquivalencesFor( /// \param F The function to query. void SampleProfileLoader::findEquivalenceClasses(Function &F) { SmallVector<BasicBlock *, 8> DominatedBBs; - DEBUG(dbgs() << "\nBlock equivalence classes\n"); + LLVM_DEBUG(dbgs() << "\nBlock equivalence classes\n"); // Find equivalence sets based on dominance and post-dominance information. for (auto &BB : F) { BasicBlock *BB1 = &BB; // Compute BB1's equivalence class once. if (EquivalenceClass.count(BB1)) { - DEBUG(printBlockEquivalence(dbgs(), BB1)); + LLVM_DEBUG(printBlockEquivalence(dbgs(), BB1)); continue; } @@ -956,7 +943,7 @@ void SampleProfileLoader::findEquivalenceClasses(Function &F) { DT->getDescendants(BB1, DominatedBBs); findEquivalencesFor(BB1, DominatedBBs, PDT.get()); - DEBUG(printBlockEquivalence(dbgs(), BB1)); + LLVM_DEBUG(printBlockEquivalence(dbgs(), BB1)); } // Assign weights to equivalence classes. @@ -965,17 +952,18 @@ void SampleProfileLoader::findEquivalenceClasses(Function &F) { // the same number of times. Since we know that the head block in // each equivalence class has the largest weight, assign that weight // to all the blocks in that equivalence class. - DEBUG(dbgs() << "\nAssign the same weight to all blocks in the same class\n"); + LLVM_DEBUG( + dbgs() << "\nAssign the same weight to all blocks in the same class\n"); for (auto &BI : F) { const BasicBlock *BB = &BI; const BasicBlock *EquivBB = EquivalenceClass[BB]; if (BB != EquivBB) BlockWeights[BB] = BlockWeights[EquivBB]; - DEBUG(printBlockWeight(dbgs(), BB)); + LLVM_DEBUG(printBlockWeight(dbgs(), BB)); } } -/// \brief Visit the given edge to decide if it has a valid weight. +/// Visit the given edge to decide if it has a valid weight. /// /// If \p E has not been visited before, we copy to \p UnknownEdge /// and increment the count of unknown edges. @@ -996,7 +984,7 @@ uint64_t SampleProfileLoader::visitEdge(Edge E, unsigned *NumUnknownEdges, return EdgeWeights[E]; } -/// \brief Propagate weights through incoming/outgoing edges. +/// Propagate weights through incoming/outgoing edges. /// /// If the weight of a basic block is known, and there is only one edge /// with an unknown weight, we can calculate the weight of that edge. @@ -1012,7 +1000,7 @@ uint64_t SampleProfileLoader::visitEdge(Edge E, unsigned *NumUnknownEdges, bool SampleProfileLoader::propagateThroughEdges(Function &F, bool UpdateBlockCount) { bool Changed = false; - DEBUG(dbgs() << "\nPropagation through edges\n"); + LLVM_DEBUG(dbgs() << "\nPropagation through edges\n"); for (const auto &BI : F) { const BasicBlock *BB = &BI; const BasicBlock *EC = EquivalenceClass[BB]; @@ -1084,9 +1072,9 @@ bool SampleProfileLoader::propagateThroughEdges(Function &F, if (TotalWeight > BBWeight) { BBWeight = TotalWeight; Changed = true; - DEBUG(dbgs() << "All edge weights for " << BB->getName() - << " known. Set weight for block: "; - printBlockWeight(dbgs(), BB);); + LLVM_DEBUG(dbgs() << "All edge weights for " << BB->getName() + << " known. 
Set weight for block: "; + printBlockWeight(dbgs(), BB);); } } else if (NumTotalEdges == 1 && EdgeWeights[SingleEdge] < BlockWeights[EC]) { @@ -1113,8 +1101,8 @@ bool SampleProfileLoader::propagateThroughEdges(Function &F, EdgeWeights[UnknownEdge] = BlockWeights[OtherEC]; VisitedEdges.insert(UnknownEdge); Changed = true; - DEBUG(dbgs() << "Set weight for edge: "; - printEdgeWeight(dbgs(), UnknownEdge)); + LLVM_DEBUG(dbgs() << "Set weight for edge: "; + printEdgeWeight(dbgs(), UnknownEdge)); } } else if (VisitedBlocks.count(EC) && BlockWeights[EC] == 0) { // If a block Weights 0, all its in/out edges should weight 0. @@ -1140,8 +1128,8 @@ bool SampleProfileLoader::propagateThroughEdges(Function &F, EdgeWeights[SelfReferentialEdge] = 0; VisitedEdges.insert(SelfReferentialEdge); Changed = true; - DEBUG(dbgs() << "Set self-referential edge weight to: "; - printEdgeWeight(dbgs(), SelfReferentialEdge)); + LLVM_DEBUG(dbgs() << "Set self-referential edge weight to: "; + printEdgeWeight(dbgs(), SelfReferentialEdge)); } if (UpdateBlockCount && !VisitedBlocks.count(EC) && TotalWeight > 0) { BlockWeights[EC] = TotalWeight; @@ -1154,7 +1142,7 @@ bool SampleProfileLoader::propagateThroughEdges(Function &F, return Changed; } -/// \brief Build in/out edge lists for each basic block in the CFG. +/// Build in/out edge lists for each basic block in the CFG. /// /// We are interested in unique edges. If a block B1 has multiple /// edges to another block B2, we only add a single B1->B2 edge. @@ -1190,17 +1178,17 @@ static SmallVector<InstrProfValueData, 2> SortCallTargets( SmallVector<InstrProfValueData, 2> R; for (auto I = M.begin(); I != M.end(); ++I) R.push_back({Function::getGUID(I->getKey()), I->getValue()}); - std::sort(R.begin(), R.end(), - [](const InstrProfValueData &L, const InstrProfValueData &R) { - if (L.Count == R.Count) - return L.Value > R.Value; - else - return L.Count > R.Count; - }); + llvm::sort(R.begin(), R.end(), + [](const InstrProfValueData &L, const InstrProfValueData &R) { + if (L.Count == R.Count) + return L.Value > R.Value; + else + return L.Count > R.Count; + }); return R; } -/// \brief Propagate weights into edges +/// Propagate weights into edges /// /// The following rules are applied to every block BB in the CFG: /// @@ -1265,7 +1253,7 @@ void SampleProfileLoader::propagateWeights(Function &F) { // Generate MD_prof metadata for every branch instruction using the // edge weights computed during propagation. - DEBUG(dbgs() << "\nPropagation complete. Setting branch weights\n"); + LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch weights\n"); LLVMContext &Ctx = F.getContext(); MDBuilder MDB(Ctx); for (auto &BI : F) { @@ -1281,7 +1269,7 @@ void SampleProfileLoader::propagateWeights(Function &F) { if (!DLoc) continue; const DILocation *DIL = DLoc; - uint32_t LineOffset = getOffset(DIL); + uint32_t LineOffset = FunctionSamples::getOffset(DIL); uint32_t Discriminator = DIL->getBaseDiscriminator(); const FunctionSamples *FS = findFunctionSamples(I); @@ -1311,10 +1299,10 @@ void SampleProfileLoader::propagateWeights(Function &F) { continue; DebugLoc BranchLoc = TI->getDebugLoc(); - DEBUG(dbgs() << "\nGetting weights for branch at line " - << ((BranchLoc) ? Twine(BranchLoc.getLine()) - : Twine("<UNKNOWN LOCATION>")) - << ".\n"); + LLVM_DEBUG(dbgs() << "\nGetting weights for branch at line " + << ((BranchLoc) ? 
Twine(BranchLoc.getLine()) + : Twine("<UNKNOWN LOCATION>")) + << ".\n"); SmallVector<uint32_t, 4> Weights; uint32_t MaxWeight = 0; Instruction *MaxDestInst; @@ -1322,12 +1310,12 @@ void SampleProfileLoader::propagateWeights(Function &F) { BasicBlock *Succ = TI->getSuccessor(I); Edge E = std::make_pair(BB, Succ); uint64_t Weight = EdgeWeights[E]; - DEBUG(dbgs() << "\t"; printEdgeWeight(dbgs(), E)); + LLVM_DEBUG(dbgs() << "\t"; printEdgeWeight(dbgs(), E)); // Use uint32_t saturated arithmetic to adjust the incoming weights, // if needed. Sample counts in profiles are 64-bit unsigned values, // but internally branch weights are expressed as 32-bit values. if (Weight > std::numeric_limits<uint32_t>::max()) { - DEBUG(dbgs() << " (saturated due to uint32_t overflow)"); + LLVM_DEBUG(dbgs() << " (saturated due to uint32_t overflow)"); Weight = std::numeric_limits<uint32_t>::max(); } // Weight is added by one to avoid propagation errors introduced by @@ -1348,7 +1336,7 @@ void SampleProfileLoader::propagateWeights(Function &F) { // annotation is done twice. If the first annotation already set the // weights, the second pass does not need to set it. if (MaxWeight > 0 && !TI->extractProfTotalWeight(TempWeight)) { - DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n"); + LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n"); TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); ORE->emit([&]() { @@ -1357,12 +1345,12 @@ void SampleProfileLoader::propagateWeights(Function &F) { << ore::NV("CondBranchesLoc", BranchLoc); }); } else { - DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n"); + LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n"); } } } -/// \brief Get the line number for the function header. +/// Get the line number for the function header. /// /// This looks up function \p F in the current compilation unit and /// retrieves the line number where the function is defined. This is @@ -1377,6 +1365,9 @@ unsigned SampleProfileLoader::getFunctionLoc(Function &F) { if (DISubprogram *S = F.getSubprogram()) return S->getLine(); + if (NoWarnSampleUnused) + return 0; + // If the start of \p F is missing, emit a diagnostic to inform the user // about the missed opportunity. F.getContext().diagnose(DiagnosticInfoSampleProfile( @@ -1390,14 +1381,13 @@ void SampleProfileLoader::computeDominanceAndLoopInfo(Function &F) { DT.reset(new DominatorTree); DT->recalculate(F); - PDT.reset(new PostDomTreeBase<BasicBlock>()); - PDT->recalculate(F); + PDT.reset(new PostDominatorTree(F)); LI.reset(new LoopInfo); LI->analyze(*DT); } -/// \brief Generate branch weight metadata for all branches in \p F. +/// Generate branch weight metadata for all branches in \p F. /// /// Branch weights are computed out of instruction samples using a /// propagation heuristic. Propagation proceeds in 3 phases: @@ -1452,8 +1442,8 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { if (getFunctionLoc(F) == 0) return false; - DEBUG(dbgs() << "Line number for the first instruction in " << F.getName() - << ": " << getFunctionLoc(F) << "\n"); + LLVM_DEBUG(dbgs() << "Line number for the first instruction in " + << F.getName() << ": " << getFunctionLoc(F) << "\n"); DenseSet<GlobalValue::GUID> InlinedGUIDs; Changed |= inlineHotFunctions(F, InlinedGUIDs); @@ -1467,7 +1457,9 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { // Sets the GUIDs that are inlined in the profiled binary. 
This is used // for ThinLink to make correct liveness analysis, and also make the IR // match the profiled binary before annotation. - F.setEntryCount(Samples->getHeadSamples() + 1, &InlinedGUIDs); + F.setEntryCount( + ProfileCount(Samples->getHeadSamples() + 1, Function::PCT_Real), + &InlinedGUIDs); // Compute dominance and loop info needed for propagation. computeDominanceAndLoopInfo(F); @@ -1481,8 +1473,8 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { // If coverage checking was requested, compute it now. if (SampleProfileRecordCoverage) { - unsigned Used = CoverageTracker.countUsedRecords(Samples); - unsigned Total = CoverageTracker.countBodyRecords(Samples); + unsigned Used = CoverageTracker.countUsedRecords(Samples, PSI); + unsigned Total = CoverageTracker.countBodyRecords(Samples, PSI); unsigned Coverage = CoverageTracker.computeCoverage(Used, Total); if (Coverage < SampleProfileRecordCoverage) { F.getContext().diagnose(DiagnosticInfoSampleProfile( @@ -1495,7 +1487,7 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { if (SampleProfileSampleCoverage) { uint64_t Used = CoverageTracker.getTotalUsedSamples(); - uint64_t Total = CoverageTracker.countBodySamples(Samples); + uint64_t Total = CoverageTracker.countBodySamples(Samples, PSI); unsigned Coverage = CoverageTracker.computeCoverage(Used, Total); if (Coverage < SampleProfileSampleCoverage) { F.getContext().diagnose(DiagnosticInfoSampleProfile( @@ -1514,6 +1506,7 @@ INITIALIZE_PASS_BEGIN(SampleProfileLoaderLegacyPass, "sample-profile", "Sample Profile loader", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) INITIALIZE_PASS_END(SampleProfileLoaderLegacyPass, "sample-profile", "Sample Profile loader", false, false) @@ -1538,10 +1531,15 @@ ModulePass *llvm::createSampleProfileLoaderPass(StringRef Name) { return new SampleProfileLoaderLegacyPass(Name); } -bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM) { +bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, + ProfileSummaryInfo *_PSI) { if (!ProfileIsValid) return false; + PSI = _PSI; + if (M.getProfileSummary() == nullptr) + M.setProfileSummary(Reader->getSummary().getMD(M.getContext())); + // Compute the total number of samples collected in this profile. for (const auto &I : Reader->getProfiles()) TotalCollectedSamples += I.second.getTotalSamples(); @@ -1572,22 +1570,22 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM) { clearFunctionData(); retval |= runOnFunction(F, AM); } - if (M.getProfileSummary() == nullptr) - M.setProfileSummary(Reader->getSummary().getMD(M.getContext())); return retval; } bool SampleProfileLoaderLegacyPass::runOnModule(Module &M) { ACT = &getAnalysis<AssumptionCacheTracker>(); TTIWP = &getAnalysis<TargetTransformInfoWrapperPass>(); - return SampleLoader.runOnModule(M, nullptr); + ProfileSummaryInfo *PSI = + getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + return SampleLoader.runOnModule(M, nullptr, PSI); } bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) { // Initialize the entry count to -1, which will be treated conservatively // by getEntryCount as the same as unknown (None). If we have samples this // will be overwritten in emitAnnotations. 
- F.setEntryCount(-1);
+ F.setEntryCount(ProfileCount(-1, Function::PCT_Real));
std::unique_ptr<OptimizationRemarkEmitter> OwnedORE;
if (AM) {
auto &FAM =
@@ -1622,7 +1620,8 @@ PreservedAnalyses SampleProfileLoaderPass::run(Module &M,
SampleLoader.doInitialization(M);
- if (!SampleLoader.runOnModule(M, &AM))
+ ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M);
+ if (!SampleLoader.runOnModule(M, &AM, PSI))
return PreservedAnalyses::all();
return PreservedAnalyses::none();
diff --git a/lib/Transforms/IPO/StripSymbols.cpp b/lib/Transforms/IPO/StripSymbols.cpp
index de1b51e206ff..c9afb060a91a 100644
--- a/lib/Transforms/IPO/StripSymbols.cpp
+++ b/lib/Transforms/IPO/StripSymbols.cpp
@@ -21,6 +21,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Transforms/Utils/Local.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DerivedTypes.h"
@@ -30,7 +31,6 @@
#include "llvm/IR/ValueSymbolTable.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/IPO.h"
-#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
namespace {
diff --git a/lib/Transforms/IPO/SyntheticCountsPropagation.cpp b/lib/Transforms/IPO/SyntheticCountsPropagation.cpp
new file mode 100644
index 000000000000..3c5ad37bced1
--- /dev/null
+++ b/lib/Transforms/IPO/SyntheticCountsPropagation.cpp
@@ -0,0 +1,140 @@
+//=- SyntheticCountsPropagation.cpp - Propagate function counts --*- C++ -*-=//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a transformation that synthesizes entry counts for
+// functions and attaches !prof metadata to functions with the synthesized
+// counts. The presence of !prof metadata with counter name set to
+// 'synthesized_function_entry_count' indicates that the value of the counter
+// is an estimation of the likely execution count of the function. This
+// transform is applied only in non-PGO mode, as functions get 'real'
+// profile-based function entry counts in PGO mode.
+//
+// The transformation works by first assigning some initial values to the entry
+// counts of all functions and then doing a top-down traversal of the
+// callgraph-scc to propagate the counts. For each function, the set of
+// callsites and their relative block frequencies is gathered. The relative
+// block frequency is multiplied by the entry count of the caller and added to
+// the callee's entry count. For non-trivial SCCs, the new counts are computed
+// from the previous counts and updated in one shot.
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/SyntheticCountsPropagation.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/SyntheticCountsUtils.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using Scaled64 = ScaledNumber<uint64_t>; +using ProfileCount = Function::ProfileCount; + +#define DEBUG_TYPE "synthetic-counts-propagation" + +/// Initial synthetic count assigned to functions. +static cl::opt<int> + InitialSyntheticCount("initial-synthetic-count", cl::Hidden, cl::init(10), + cl::ZeroOrMore, + cl::desc("Initial value of synthetic entry count.")); + +/// Initial synthetic count assigned to inline functions. +static cl::opt<int> InlineSyntheticCount( + "inline-synthetic-count", cl::Hidden, cl::init(15), cl::ZeroOrMore, + cl::desc("Initial synthetic entry count for inline functions.")); + +/// Initial synthetic count assigned to cold functions. +static cl::opt<int> ColdSyntheticCount( + "cold-synthetic-count", cl::Hidden, cl::init(5), cl::ZeroOrMore, + cl::desc("Initial synthetic entry count for cold functions.")); + +// Assign initial synthetic entry counts to functions. +static void +initializeCounts(Module &M, function_ref<void(Function *, uint64_t)> SetCount) { + auto MayHaveIndirectCalls = [](Function &F) { + for (auto *U : F.users()) { + if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) + return true; + } + return false; + }; + + for (Function &F : M) { + uint64_t InitialCount = InitialSyntheticCount; + if (F.isDeclaration()) + continue; + if (F.hasFnAttribute(Attribute::AlwaysInline) || + F.hasFnAttribute(Attribute::InlineHint)) { + // Use a higher value for inline functions to account for the fact that + // these are usually beneficial to inline. + InitialCount = InlineSyntheticCount; + } else if (F.hasLocalLinkage() && !MayHaveIndirectCalls(F)) { + // Local functions without inline hints get counts only through + // propagation. + InitialCount = 0; + } else if (F.hasFnAttribute(Attribute::Cold) || + F.hasFnAttribute(Attribute::NoInline)) { + // Use a lower value for noinline and cold functions. + InitialCount = ColdSyntheticCount; + } + SetCount(&F, InitialCount); + } +} + +PreservedAnalyses SyntheticCountsPropagation::run(Module &M, + ModuleAnalysisManager &MAM) { + FunctionAnalysisManager &FAM = + MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + DenseMap<Function *, uint64_t> Counts; + // Set initial entry counts. + initializeCounts(M, [&](Function *F, uint64_t Count) { Counts[F] = Count; }); + + // Compute the relative block frequency for a call edge. Use scaled numbers + // and not integers since the relative block frequency could be less than 1. 
+ auto GetCallSiteRelFreq = [&](const CallGraphNode::CallRecord &Edge) {
+ Optional<Scaled64> Res = None;
+ if (!Edge.first)
+ return Res;
+ assert(isa<Instruction>(Edge.first));
+ CallSite CS(cast<Instruction>(Edge.first));
+ Function *Caller = CS.getCaller();
+ auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*Caller);
+ BasicBlock *CSBB = CS.getInstruction()->getParent();
+ Scaled64 EntryFreq(BFI.getEntryFreq(), 0);
+ Scaled64 BBFreq(BFI.getBlockFreq(CSBB).getFrequency(), 0);
+ BBFreq /= EntryFreq;
+ return Optional<Scaled64>(BBFreq);
+ };
+
+ CallGraph CG(M);
+ // Propagate the entry counts on the callgraph.
+ SyntheticCountsUtils<const CallGraph *>::propagate(
+ &CG, GetCallSiteRelFreq,
+ [&](const CallGraphNode *N) { return Counts[N->getFunction()]; },
+ [&](const CallGraphNode *N, uint64_t New) {
+ auto F = N->getFunction();
+ if (!F || F->isDeclaration())
+ return;
+ Counts[F] += New;
+ });
+
+ // Set the counts as metadata.
+ for (auto Entry : Counts)
+ Entry.first->setEntryCount(
+ ProfileCount(Entry.second, Function::PCT_Synthetic));
+
+ return PreservedAnalyses::all();
+}
diff --git a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
index caffc03339c4..8fe7ae1282cc 100644
--- a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
+++ b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp
@@ -18,11 +18,13 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/Object/ModuleSymbolTable.h"
#include "llvm/Pass.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/FunctionAttrs.h"
+#include "llvm/Transforms/IPO/FunctionImport.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
using namespace llvm;
@@ -128,8 +130,7 @@ void promoteTypeIds(Module &M, StringRef ModuleId) {
}
GO.addMetadata(
LLVMContext::MD_type,
- *MDNode::get(M.getContext(),
- ArrayRef<Metadata *>{MD->getOperand(0), I->second}));
+ *MDNode::get(M.getContext(), {MD->getOperand(0), I->second}));
}
}
}
@@ -169,46 +170,17 @@ void simplifyExternals(Module &M) {
}
}
-void filterModule(
- Module *M, function_ref<bool(const GlobalValue *)> ShouldKeepDefinition) {
- for (Module::alias_iterator I = M->alias_begin(), E = M->alias_end();
- I != E;) {
- GlobalAlias *GA = &*I++;
- if (ShouldKeepDefinition(GA))
- continue;
-
- GlobalObject *GO;
- if (GA->getValueType()->isFunctionTy())
- GO = Function::Create(cast<FunctionType>(GA->getValueType()),
- GlobalValue::ExternalLinkage, "", M);
- else
- GO = new GlobalVariable(
- *M, GA->getValueType(), false, GlobalValue::ExternalLinkage,
- nullptr, "", nullptr,
- GA->getThreadLocalMode(), GA->getType()->getAddressSpace());
- GO->takeName(GA);
- GA->replaceAllUsesWith(GO);
- GA->eraseFromParent();
- }
-
- for (Function &F : *M) {
- if (ShouldKeepDefinition(&F))
- continue;
-
- F.deleteBody();
- F.setComdat(nullptr);
- F.clearMetadata();
- }
-
- for (GlobalVariable &GV : M->globals()) {
- if (ShouldKeepDefinition(&GV))
- continue;
-
- GV.setInitializer(nullptr);
- GV.setLinkage(GlobalValue::ExternalLinkage);
- GV.setComdat(nullptr);
- GV.clearMetadata();
- }
+static void
+filterModule(Module *M,
+ function_ref<bool(const GlobalValue *)> ShouldKeepDefinition) {
+ std::vector<GlobalValue *> V;
+ for (GlobalValue &GV : M->global_values())
+ if (!ShouldKeepDefinition(&GV))
+ V.push_back(&GV);
+
+ for (GlobalValue *GV : V)
+ if (!convertToDeclaration(*GV))
+ GV->eraseFromParent();
}
void
forEachVirtualFunction(Constant *C, function_ref<void(Function *)> Fn) { @@ -228,13 +200,19 @@ void splitAndWriteThinLTOBitcode( function_ref<AAResults &(Function &)> AARGetter, Module &M) { std::string ModuleId = getUniqueModuleId(&M); if (ModuleId.empty()) { - // We couldn't generate a module ID for this module, just write it out as a - // regular LTO module. - WriteBitcodeToFile(&M, OS); + // We couldn't generate a module ID for this module, write it out as a + // regular LTO module with an index for summary-based dead stripping. + ProfileSummaryInfo PSI(M); + M.addModuleFlag(Module::Error, "ThinLTO", uint32_t(0)); + ModuleSummaryIndex Index = buildModuleSummaryIndex(M, nullptr, &PSI); + WriteBitcodeToFile(M, OS, /*ShouldPreserveUseListOrder=*/false, &Index); + if (ThinLinkOS) // We don't have a ThinLTO part, but still write the module to the // ThinLinkOS if requested so that the expected output file is produced. - WriteBitcodeToFile(&M, *ThinLinkOS); + WriteBitcodeToFile(M, *ThinLinkOS, /*ShouldPreserveUseListOrder=*/false, + &Index); + return; } @@ -243,10 +221,8 @@ void splitAndWriteThinLTOBitcode( // Returns whether a global has attached type metadata. Such globals may // participate in CFI or whole-program devirtualization, so they need to // appear in the merged module instead of the thin LTO module. - auto HasTypeMetadata = [&](const GlobalObject *GO) { - SmallVector<MDNode *, 1> MDs; - GO->getMetadata(LLVMContext::MD_type, MDs); - return !MDs.empty(); + auto HasTypeMetadata = [](const GlobalObject *GO) { + return GO->hasMetadata(LLVMContext::MD_type); }; // Collect the set of virtual functions that are eligible for virtual constant @@ -287,7 +263,7 @@ void splitAndWriteThinLTOBitcode( ValueToValueMapTy VMap; std::unique_ptr<Module> MergedM( - CloneModule(&M, VMap, [&](const GlobalValue *GV) -> bool { + CloneModule(M, VMap, [&](const GlobalValue *GV) -> bool { if (const auto *C = GV->getComdat()) if (MergedMComdats.count(C)) return true; @@ -298,6 +274,7 @@ void splitAndWriteThinLTOBitcode( return false; })); StripDebugInfo(*MergedM); + MergedM->setModuleInlineAsm(""); for (Function &F : *MergedM) if (!F.isDeclaration()) { @@ -328,13 +305,13 @@ void splitAndWriteThinLTOBitcode( promoteInternals(*MergedM, M, ModuleId, CfiFunctions); promoteInternals(M, *MergedM, ModuleId, CfiFunctions); + auto &Ctx = MergedM->getContext(); SmallVector<MDNode *, 8> CfiFunctionMDs; for (auto V : CfiFunctions) { Function &F = *cast<Function>(V); SmallVector<MDNode *, 2> Types; F.getMetadata(LLVMContext::MD_type, Types); - auto &Ctx = MergedM->getContext(); SmallVector<Metadata *, 4> Elts; Elts.push_back(MDString::get(Ctx, F.getName())); CfiFunctionLinkage Linkage; @@ -357,6 +334,47 @@ void splitAndWriteThinLTOBitcode( NMD->addOperand(MD); } + SmallVector<MDNode *, 8> FunctionAliases; + for (auto &A : M.aliases()) { + if (!isa<Function>(A.getAliasee())) + continue; + + auto *F = cast<Function>(A.getAliasee()); + + Metadata *Elts[] = { + MDString::get(Ctx, A.getName()), + MDString::get(Ctx, F->getName()), + ConstantAsMetadata::get( + ConstantInt::get(Type::getInt8Ty(Ctx), A.getVisibility())), + ConstantAsMetadata::get( + ConstantInt::get(Type::getInt8Ty(Ctx), A.isWeakForLinker())), + }; + + FunctionAliases.push_back(MDTuple::get(Ctx, Elts)); + } + + if (!FunctionAliases.empty()) { + NamedMDNode *NMD = MergedM->getOrInsertNamedMetadata("aliases"); + for (auto MD : FunctionAliases) + NMD->addOperand(MD); + } + + SmallVector<MDNode *, 8> Symvers; + ModuleSymbolTable::CollectAsmSymvers(M, [&](StringRef 
Name, StringRef Alias) { + Function *F = M.getFunction(Name); + if (!F || F->use_empty()) + return; + + Symvers.push_back(MDTuple::get( + Ctx, {MDString::get(Ctx, Name), MDString::get(Ctx, Alias)})); + }); + + if (!Symvers.empty()) { + NamedMDNode *NMD = MergedM->getOrInsertNamedMetadata("symvers"); + for (auto MD : Symvers) + NMD->addOperand(MD); + } + simplifyExternals(*MergedM); // FIXME: Try to re-use BSI and PFI from the original module here. @@ -376,10 +394,9 @@ void splitAndWriteThinLTOBitcode( // be used in the backends, and use that in the minimized bitcode // produced for the full link. ModuleHash ModHash = {{0}}; - W.writeModule(&M, /*ShouldPreserveUseListOrder=*/false, &Index, + W.writeModule(M, /*ShouldPreserveUseListOrder=*/false, &Index, /*GenerateHash=*/true, &ModHash); - W.writeModule(MergedM.get(), /*ShouldPreserveUseListOrder=*/false, - &MergedMIndex); + W.writeModule(*MergedM, /*ShouldPreserveUseListOrder=*/false, &MergedMIndex); W.writeSymtab(); W.writeStrtab(); OS << Buffer; @@ -391,8 +408,8 @@ void splitAndWriteThinLTOBitcode( Buffer.clear(); BitcodeWriter W2(Buffer); StripDebugInfo(M); - W2.writeThinLinkBitcode(&M, Index, ModHash); - W2.writeModule(MergedM.get(), /*ShouldPreserveUseListOrder=*/false, + W2.writeThinLinkBitcode(M, Index, ModHash); + W2.writeModule(*MergedM, /*ShouldPreserveUseListOrder=*/false, &MergedMIndex); W2.writeSymtab(); W2.writeStrtab(); @@ -402,10 +419,8 @@ void splitAndWriteThinLTOBitcode( // Returns whether this module needs to be split because it uses type metadata. bool requiresSplit(Module &M) { - SmallVector<MDNode *, 1> MDs; for (auto &GO : M.global_objects()) { - GO.getMetadata(LLVMContext::MD_type, MDs); - if (!MDs.empty()) + if (GO.hasMetadata(LLVMContext::MD_type)) return true; } @@ -425,13 +440,13 @@ void writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS, // be used in the backends, and use that in the minimized bitcode // produced for the full link. ModuleHash ModHash = {{0}}; - WriteBitcodeToFile(&M, OS, /*ShouldPreserveUseListOrder=*/false, Index, + WriteBitcodeToFile(M, OS, /*ShouldPreserveUseListOrder=*/false, Index, /*GenerateHash=*/true, &ModHash); // If a minimized bitcode module was requested for the thin link, only // the information that is needed by thin link will be written in the // given OS. if (ThinLinkOS && Index) - WriteThinLinkBitcodeToFile(&M, *ThinLinkOS, *Index, ModHash); + WriteThinLinkBitcodeToFile(M, *ThinLinkOS, *Index, ModHash); } class WriteThinLTOBitcode : public ModulePass { diff --git a/lib/Transforms/IPO/WholeProgramDevirt.cpp b/lib/Transforms/IPO/WholeProgramDevirt.cpp index 5fbb001216a3..d65da2504db4 100644 --- a/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -111,6 +111,12 @@ static cl::opt<std::string> ClWriteSummary( cl::desc("Write summary to given YAML file after running pass"), cl::Hidden); +static cl::opt<unsigned> + ClThreshold("wholeprogramdevirt-branch-funnel-threshold", cl::Hidden, + cl::init(10), cl::ZeroOrMore, + cl::desc("Maximum number of call targets per " + "call site to enable branch funnels")); + // Find the minimum offset that we may store a value of size Size bits at. If // IsAfter is set, look for an offset before the object, otherwise look for an // offset after the object. 
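The ClThreshold option introduced above is what caps branch funnel creation: tryICallBranchFunnel, later in this patch, gives up on a vtable slot whose target list exceeds it. A minimal sketch of that gating, assuming a stand-in VirtualCallTarget type and an illustrative function body (only the option text and the size check are taken from the patch):

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/Support/CommandLine.h"
    using namespace llvm;

    struct VirtualCallTarget; // stand-in for the pass-local type

    // Mirrors the patch's option under a different flag name; registering
    // the same name twice in a real LLVM build would conflict with the pass.
    static cl::opt<unsigned>
        FunnelThreshold("branch-funnel-threshold-sketch", cl::Hidden,
                        cl::init(10),
                        cl::desc("Maximum number of call targets per "
                                 "call site to enable branch funnels"));

    void tryBranchFunnelSketch(MutableArrayRef<VirtualCallTarget> TargetsForSlot) {
      // A funnel lowers to a compare-and-tail-call chain over all targets,
      // so it stops paying off once the target list grows past the cap.
      if (TargetsForSlot.size() > FunnelThreshold)
        return;
      // ... build the llvm.icall.branch.funnel thunk for these targets ...
    }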
@@ -281,24 +287,11 @@ struct VirtualCallSite { DebugLoc DLoc = CS->getDebugLoc(); BasicBlock *Block = CS.getParent(); - // In the new pass manager, we can request the optimization - // remark emitter pass on a per-function-basis, which the - // OREGetter will do for us. - // In the old pass manager, this is harder, so we just build - // a optimization remark emitter on the fly, when we need it. - std::unique_ptr<OptimizationRemarkEmitter> OwnedORE; - OptimizationRemarkEmitter *ORE; - if (OREGetter) - ORE = &OREGetter(F); - else { - OwnedORE = make_unique<OptimizationRemarkEmitter>(F); - ORE = OwnedORE.get(); - } - using namespace ore; - ORE->emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block) - << NV("Optimization", OptName) << ": devirtualized a call to " - << NV("FunctionName", TargetName)); + OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block) + << NV("Optimization", OptName) + << ": devirtualized a call to " + << NV("FunctionName", TargetName)); } void replaceAndErase( @@ -329,12 +322,17 @@ struct CallSiteInfo { /// cases we are directly operating on the call sites at the IR level. std::vector<VirtualCallSite> CallSites; + /// Whether all call sites represented by this CallSiteInfo, including those + /// in summaries, have been devirtualized. This starts off as true because a + /// default constructed CallSiteInfo represents no call sites. + bool AllCallSitesDevirted = true; + // These fields are used during the export phase of ThinLTO and reflect // information collected from function summaries. /// Whether any function summary contains an llvm.assume(llvm.type.test) for /// this slot. - bool SummaryHasTypeTestAssumeUsers; + bool SummaryHasTypeTestAssumeUsers = false; /// CFI-specific: a vector containing the list of function summaries that use /// the llvm.type.checked.load intrinsic and therefore will require @@ -350,8 +348,22 @@ struct CallSiteInfo { !SummaryTypeCheckedLoadUsers.empty(); } - /// As explained in the comment for SummaryTypeCheckedLoadUsers. - void markDevirt() { SummaryTypeCheckedLoadUsers.clear(); } + void markSummaryHasTypeTestAssumeUsers() { + SummaryHasTypeTestAssumeUsers = true; + AllCallSitesDevirted = false; + } + + void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) { + SummaryTypeCheckedLoadUsers.push_back(FS); + AllCallSitesDevirted = false; + } + + void markDevirt() { + AllCallSitesDevirted = true; + + // As explained in the comment for SummaryTypeCheckedLoadUsers. + SummaryTypeCheckedLoadUsers.clear(); + } }; // Call site information collected for a specific VTableSlot. 
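The new AllCallSitesDevirted flag threads one invariant through several of the hunks above: a default-constructed CallSiteInfo covers no call sites, so the flag starts true; recording any IR call site or summary user flips it to false; and only markDevirt() sets it back, dropping the pending type-checked-load users at the same time. A condensed, self-contained sketch of just that bookkeeping (FunctionSummary is stubbed; the member names match the patch):

    #include <vector>

    struct FunctionSummary; // stand-in for llvm::FunctionSummary

    struct CallSiteInfoSketch {
      // True while no recorded call site still needs this slot devirtualized.
      bool AllCallSitesDevirted = true;
      bool SummaryHasTypeTestAssumeUsers = false;
      std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers;

      void markSummaryHasTypeTestAssumeUsers() {
        SummaryHasTypeTestAssumeUsers = true;
        AllCallSitesDevirted = false; // an outstanding summary user appeared
      }
      void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) {
        SummaryTypeCheckedLoadUsers.push_back(FS);
        AllCallSitesDevirted = false;
      }
      void markDevirt() {
        AllCallSitesDevirted = true;
        // Devirtualized users no longer need an llvm.type.test resolution.
        SummaryTypeCheckedLoadUsers.clear();
      }
    };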
@@ -386,7 +398,9 @@ CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) { void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses) { - findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses}); + auto &CSI = findCallSiteInfo(CS); + CSI.AllCallSitesDevirted = false; + CSI.CallSites.push_back({VTable, CS, NumUnsafeUses}); } struct DevirtModule { @@ -451,6 +465,12 @@ struct DevirtModule { VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res); + void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT, + bool &IsExported); + void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot, + VTableSlotInfo &SlotInfo, + WholeProgramDevirtResolution *Res, VTableSlot Slot); + bool tryEvaluateFunctionsWithArgs( MutableArrayRef<VirtualCallTarget> TargetsForSlot, ArrayRef<uint64_t> Args); @@ -484,6 +504,8 @@ struct DevirtModule { StringRef Name, IntegerType *IntTy, uint32_t Storage); + Constant *getMemberAddr(const TypeMemberInfo *M); + void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne, Constant *UniqueMemberAddr); bool tryUniqueRetValOpt(unsigned BitWidth, @@ -539,7 +561,16 @@ struct WholeProgramDevirt : public ModulePass { if (skipModule(M)) return false; - auto OREGetter = function_ref<OptimizationRemarkEmitter &(Function *)>(); + // In the new pass manager, we can request the optimization + // remark emitter pass on a per-function-basis, which the + // OREGetter will do for us. + // In the old pass manager, this is harder, so we just build + // an optimization remark emitter on the fly, when we need it. + std::unique_ptr<OptimizationRemarkEmitter> ORE; + auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & { + ORE = make_unique<OptimizationRemarkEmitter>(F); + return *ORE; + }; if (UseCommandLine) return DevirtModule::runForTesting(M, LegacyAARGetter(*this), OREGetter); @@ -580,7 +611,8 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M, auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & { return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); }; - if (!DevirtModule(M, AARGetter, OREGetter, nullptr, nullptr).run()) + if (!DevirtModule(M, AARGetter, OREGetter, ExportSummary, ImportSummary) + .run()) return PreservedAnalyses::all(); return PreservedAnalyses::none(); } @@ -588,7 +620,7 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M, bool DevirtModule::runForTesting( Module &M, function_ref<AAResults &(Function &)> AARGetter, function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) { - ModuleSummaryIndex Summary; + ModuleSummaryIndex Summary(/*HaveGVs=*/false); // Handle the command-line summary arguments. This code is for testing // purposes only, so we handle errors directly. 
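The legacy-pass hunk above swaps the previously empty function_ref for a lambda that builds an OptimizationRemarkEmitter on demand, since the old pass manager has no per-function analysis to hand one out the way FAM.getResult<OptimizationRemarkEmitterAnalysis> does in the new one. Isolated from the patch, the pattern looks roughly like this (headers added for completeness; the enclosing function is illustrative):

    #include "llvm/ADT/STLExtras.h"
    #include "llvm/Analysis/OptimizationRemarkEmitter.h"
    #include "llvm/IR/Function.h"
    #include <memory>
    using namespace llvm;

    void emitRemarksFor(Function &F) {
      // The unique_ptr lives in the enclosing scope, so the reference the
      // lambda returns stays valid until the next request rebuilds it.
      std::unique_ptr<OptimizationRemarkEmitter> ORE;
      auto OREGetter = [&](Function *Fn) -> OptimizationRemarkEmitter & {
        ORE = make_unique<OptimizationRemarkEmitter>(Fn);
        return *ORE;
      };
      OREGetter(&F); // constructs the emitter lazily, only when needed
    }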
@@ -730,10 +762,9 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, if (VCallSite.NumUnsafeUses) --*VCallSite.NumUnsafeUses; } - if (CSInfo.isExported()) { + if (CSInfo.isExported()) IsExported = true; - CSInfo.markDevirt(); - } + CSInfo.markDevirt(); }; Apply(SlotInfo.CSInfo); for (auto &P : SlotInfo.ConstCSInfo) @@ -789,6 +820,133 @@ bool DevirtModule::trySingleImplDevirt( return true; } +void DevirtModule::tryICallBranchFunnel( + MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, + WholeProgramDevirtResolution *Res, VTableSlot Slot) { + Triple T(M.getTargetTriple()); + if (T.getArch() != Triple::x86_64) + return; + + if (TargetsForSlot.size() > ClThreshold) + return; + + bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted; + if (!HasNonDevirt) + for (auto &P : SlotInfo.ConstCSInfo) + if (!P.second.AllCallSitesDevirted) { + HasNonDevirt = true; + break; + } + + if (!HasNonDevirt) + return; + + FunctionType *FT = + FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true); + Function *JT; + if (isa<MDString>(Slot.TypeID)) { + JT = Function::Create(FT, Function::ExternalLinkage, + getGlobalName(Slot, {}, "branch_funnel"), &M); + JT->setVisibility(GlobalValue::HiddenVisibility); + } else { + JT = Function::Create(FT, Function::InternalLinkage, "branch_funnel", &M); + } + JT->addAttribute(1, Attribute::Nest); + + std::vector<Value *> JTArgs; + JTArgs.push_back(JT->arg_begin()); + for (auto &T : TargetsForSlot) { + JTArgs.push_back(getMemberAddr(T.TM)); + JTArgs.push_back(T.Fn); + } + + BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr); + Constant *Intr = + Intrinsic::getDeclaration(&M, llvm::Intrinsic::icall_branch_funnel, {}); + + auto *CI = CallInst::Create(Intr, JTArgs, "", BB); + CI->setTailCallKind(CallInst::TCK_MustTail); + ReturnInst::Create(M.getContext(), nullptr, BB); + + bool IsExported = false; + applyICallBranchFunnel(SlotInfo, JT, IsExported); + if (IsExported) + Res->TheKind = WholeProgramDevirtResolution::BranchFunnel; +} + +void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, + Constant *JT, bool &IsExported) { + auto Apply = [&](CallSiteInfo &CSInfo) { + if (CSInfo.isExported()) + IsExported = true; + if (CSInfo.AllCallSitesDevirted) + return; + for (auto &&VCallSite : CSInfo.CallSites) { + CallSite CS = VCallSite.CS; + + // Jump tables are only profitable if the retpoline mitigation is enabled. + Attribute FSAttr = CS.getCaller()->getFnAttribute("target-features"); + if (FSAttr.hasAttribute(Attribute::None) || + !FSAttr.getValueAsString().contains("+retpoline")) + continue; + + if (RemarksEnabled) + VCallSite.emitRemark("branch-funnel", JT->getName(), OREGetter); + + // Pass the address of the vtable in the nest register, which is r10 on + // x86_64. 
+ std::vector<Type *> NewArgs; + NewArgs.push_back(Int8PtrTy); + for (Type *T : CS.getFunctionType()->params()) + NewArgs.push_back(T); + PointerType *NewFT = PointerType::getUnqual( + FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs, + CS.getFunctionType()->isVarArg())); + + IRBuilder<> IRB(CS.getInstruction()); + std::vector<Value *> Args; + Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy)); + for (unsigned I = 0; I != CS.getNumArgOperands(); ++I) + Args.push_back(CS.getArgOperand(I)); + + CallSite NewCS; + if (CS.isCall()) + NewCS = IRB.CreateCall(IRB.CreateBitCast(JT, NewFT), Args); + else + NewCS = IRB.CreateInvoke( + IRB.CreateBitCast(JT, NewFT), + cast<InvokeInst>(CS.getInstruction())->getNormalDest(), + cast<InvokeInst>(CS.getInstruction())->getUnwindDest(), Args); + NewCS.setCallingConv(CS.getCallingConv()); + + AttributeList Attrs = CS.getAttributes(); + std::vector<AttributeSet> NewArgAttrs; + NewArgAttrs.push_back(AttributeSet::get( + M.getContext(), ArrayRef<Attribute>{Attribute::get( + M.getContext(), Attribute::Nest)})); + for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I) + NewArgAttrs.push_back(Attrs.getParamAttributes(I)); + NewCS.setAttributes( + AttributeList::get(M.getContext(), Attrs.getFnAttributes(), + Attrs.getRetAttributes(), NewArgAttrs)); + + CS->replaceAllUsesWith(NewCS.getInstruction()); + CS->eraseFromParent(); + + // This use is no longer unsafe. + if (VCallSite.NumUnsafeUses) + --*VCallSite.NumUnsafeUses; + } + // Don't mark as devirtualized because there may be callers compiled without + // retpoline mitigation, which would mean that they are lowered to + // llvm.type.test and therefore require an llvm.type.test resolution for the + // type identifier. + }; + Apply(SlotInfo.CSInfo); + for (auto &P : SlotInfo.ConstCSInfo) + Apply(P.second); +} + bool DevirtModule::tryEvaluateFunctionsWithArgs( MutableArrayRef<VirtualCallTarget> TargetsForSlot, ArrayRef<uint64_t> Args) { @@ -909,7 +1067,7 @@ Constant *DevirtModule::importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, // We only need to set metadata if the global is newly created, in which // case it would not have hidden visibility. - if (GV->getMetadata(LLVMContext::MD_absolute_symbol)) + if (GV->hasMetadata(LLVMContext::MD_absolute_symbol)) return C; auto SetAbsRange = [&](uint64_t Min, uint64_t Max) { @@ -941,6 +1099,12 @@ void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, CSInfo.markDevirt(); } +Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) { + Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy); + return ConstantExpr::getGetElementPtr(Int8Ty, C, + ConstantInt::get(Int64Ty, M->Offset)); +} + bool DevirtModule::tryUniqueRetValOpt( unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res, @@ -960,12 +1124,7 @@ bool DevirtModule::tryUniqueRetValOpt( // checked for a uniform return value in tryUniformRetValOpt. 
assert(UniqueMember); - Constant *UniqueMemberAddr = - ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy); - UniqueMemberAddr = ConstantExpr::getGetElementPtr( - Int8Ty, UniqueMemberAddr, - ConstantInt::get(Int64Ty, UniqueMember->Offset)); - + Constant *UniqueMemberAddr = getMemberAddr(UniqueMember); if (CSInfo.isExported()) { Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal; Res->Info = IsOne; @@ -1352,6 +1511,14 @@ void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) { break; } } + + if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) { + auto *JT = M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"), + Type::getVoidTy(M.getContext())); + bool IsExported = false; + applyICallBranchFunnel(SlotInfo, JT, IsExported); + assert(!IsExported); + } } void DevirtModule::removeRedundantTypeTests() { @@ -1421,14 +1588,13 @@ bool DevirtModule::run() { // FIXME: Only add live functions. for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) { for (Metadata *MD : MetadataByGUID[VF.GUID]) { - CallSlots[{MD, VF.Offset}].CSInfo.SummaryHasTypeTestAssumeUsers = - true; + CallSlots[{MD, VF.Offset}] + .CSInfo.markSummaryHasTypeTestAssumeUsers(); } } for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) { for (Metadata *MD : MetadataByGUID[VF.GUID]) { - CallSlots[{MD, VF.Offset}] - .CSInfo.SummaryTypeCheckedLoadUsers.push_back(FS); + CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS); } } for (const FunctionSummary::ConstVCall &VC : @@ -1436,7 +1602,7 @@ bool DevirtModule::run() { for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { CallSlots[{MD, VC.VFunc.Offset}] .ConstCSInfo[VC.Args] - .SummaryHasTypeTestAssumeUsers = true; + .markSummaryHasTypeTestAssumeUsers(); } } for (const FunctionSummary::ConstVCall &VC : @@ -1444,7 +1610,7 @@ bool DevirtModule::run() { for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { CallSlots[{MD, VC.VFunc.Offset}] .ConstCSInfo[VC.Args] - .SummaryTypeCheckedLoadUsers.push_back(FS); + .addSummaryTypeCheckedLoadUser(FS); } } } @@ -1468,9 +1634,12 @@ bool DevirtModule::run() { cast<MDString>(S.first.TypeID)->getString()) .WPDRes[S.first.ByteOffset]; - if (!trySingleImplDevirt(TargetsForSlot, S.second, Res) && - tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first)) - DidVirtualConstProp = true; + if (!trySingleImplDevirt(TargetsForSlot, S.second, Res)) { + DidVirtualConstProp |= + tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first); + + tryICallBranchFunnel(TargetsForSlot, S.second, Res, S.first); + } // Collect functions devirtualized at least for one call site for stats. if (RemarksEnabled) @@ -1499,23 +1668,10 @@ bool DevirtModule::run() { for (const auto &DT : DevirtTargets) { Function *F = DT.second; - // In the new pass manager, we can request the optimization - // remark emitter pass on a per-function-basis, which the - // OREGetter will do for us. - // In the old pass manager, this is harder, so we just build - // a optimization remark emitter on the fly, when we need it. 
- std::unique_ptr<OptimizationRemarkEmitter> OwnedORE; - OptimizationRemarkEmitter *ORE; - if (OREGetter) - ORE = &OREGetter(F); - else { - OwnedORE = make_unique<OptimizationRemarkEmitter>(F); - ORE = OwnedORE.get(); - } - using namespace ore; - ORE->emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F) - << "devirtualized " << NV("FunctionName", F->getName())); + OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F) + << "devirtualized " + << NV("FunctionName", F->getName())); } } diff --git a/lib/Transforms/InstCombine/CMakeLists.txt b/lib/Transforms/InstCombine/CMakeLists.txt index 5cbe804ce3ec..8a3a58e9ecc9 100644 --- a/lib/Transforms/InstCombine/CMakeLists.txt +++ b/lib/Transforms/InstCombine/CMakeLists.txt @@ -1,3 +1,7 @@ +set(LLVM_TARGET_DEFINITIONS InstCombineTables.td) +tablegen(LLVM InstCombineTables.inc -gen-searchable-tables) +add_public_tablegen_target(InstCombineTableGen) + add_llvm_library(LLVMInstCombine InstructionCombining.cpp InstCombineAddSub.cpp diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 688897644848..aa31e0d850dd 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -511,7 +511,8 @@ Value *FAddCombine::performFactorization(Instruction *I) { } Value *FAddCombine::simplify(Instruction *I) { - assert(I->isFast() && "Expected 'fast' instruction"); + assert(I->hasAllowReassoc() && I->hasNoSignedZeros() && + "Expected 'reassoc'+'nsz' instruction"); // Currently we are not able to handle vector type. if (I->getType()->isVectorTy()) @@ -855,48 +856,6 @@ Value *FAddCombine::createAddendVal(const FAddend &Opnd, bool &NeedNeg) { return createFMul(OpndVal, Coeff.getValue(Instr->getType())); } -/// \brief Return true if we can prove that: -/// (sub LHS, RHS) === (sub nsw LHS, RHS) -/// This basically requires proving that the add in the original type would not -/// overflow to change the sign bit or have a carry out. -/// TODO: Handle this for Vectors. -bool InstCombiner::willNotOverflowSignedSub(const Value *LHS, - const Value *RHS, - const Instruction &CxtI) const { - // If LHS and RHS each have at least two sign bits, the subtraction - // cannot overflow. - if (ComputeNumSignBits(LHS, 0, &CxtI) > 1 && - ComputeNumSignBits(RHS, 0, &CxtI) > 1) - return true; - - KnownBits LHSKnown = computeKnownBits(LHS, 0, &CxtI); - - KnownBits RHSKnown = computeKnownBits(RHS, 0, &CxtI); - - // Subtraction of two 2's complement numbers having identical signs will - // never overflow. - if ((LHSKnown.isNegative() && RHSKnown.isNegative()) || - (LHSKnown.isNonNegative() && RHSKnown.isNonNegative())) - return true; - - // TODO: implement logic similar to checkRippleForAdd - return false; -} - -/// \brief Return true if we can prove that: -/// (sub LHS, RHS) === (sub nuw LHS, RHS) -bool InstCombiner::willNotOverflowUnsignedSub(const Value *LHS, - const Value *RHS, - const Instruction &CxtI) const { - // If the LHS is negative and the RHS is non-negative, no unsigned wrap. - KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, &CxtI); - KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, &CxtI); - if (LHSKnown.isNegative() && RHSKnown.isNonNegative()) - return true; - - return false; -} - // Checks if any operand is negative and we can convert add to sub. 
// This function checks for following negative patterns // ADD(XOR(OR(Z, NOT(C)), C)), 1) == NEG(AND(Z, C)) @@ -964,7 +923,7 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { if (!match(Op1, m_Constant(Op1C))) return nullptr; - if (Instruction *NV = foldOpWithConstantIntoOperand(Add)) + if (Instruction *NV = foldBinOpIntoSelectOrPhi(Add)) return NV; Value *X; @@ -1031,17 +990,148 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { return nullptr; } -Instruction *InstCombiner::visitAdd(BinaryOperator &I) { - bool Changed = SimplifyAssociativeOrCommutative(I); - if (Value *V = SimplifyVectorOp(I)) - return replaceInstUsesWith(I, V); +// Matches multiplication expression Op * C where C is a constant. Returns the +// constant value in C and the other operand in Op. Returns true if such a +// match is found. +static bool MatchMul(Value *E, Value *&Op, APInt &C) { + const APInt *AI; + if (match(E, m_Mul(m_Value(Op), m_APInt(AI)))) { + C = *AI; + return true; + } + if (match(E, m_Shl(m_Value(Op), m_APInt(AI)))) { + C = APInt(AI->getBitWidth(), 1); + C <<= *AI; + return true; + } + return false; +} + +// Matches remainder expression Op % C where C is a constant. Returns the +// constant value in C and the other operand in Op. Returns the signedness of +// the remainder operation in IsSigned. Returns true if such a match is +// found. +static bool MatchRem(Value *E, Value *&Op, APInt &C, bool &IsSigned) { + const APInt *AI; + IsSigned = false; + if (match(E, m_SRem(m_Value(Op), m_APInt(AI)))) { + IsSigned = true; + C = *AI; + return true; + } + if (match(E, m_URem(m_Value(Op), m_APInt(AI)))) { + C = *AI; + return true; + } + if (match(E, m_And(m_Value(Op), m_APInt(AI))) && (*AI + 1).isPowerOf2()) { + C = *AI + 1; + return true; + } + return false; +} +// Matches division expression Op / C with the given signedness as indicated +// by IsSigned, where C is a constant. Returns the constant value in C and the +// other operand in Op. Returns true if such a match is found. +static bool MatchDiv(Value *E, Value *&Op, APInt &C, bool IsSigned) { + const APInt *AI; + if (IsSigned && match(E, m_SDiv(m_Value(Op), m_APInt(AI)))) { + C = *AI; + return true; + } + if (!IsSigned) { + if (match(E, m_UDiv(m_Value(Op), m_APInt(AI)))) { + C = *AI; + return true; + } + if (match(E, m_LShr(m_Value(Op), m_APInt(AI)))) { + C = APInt(AI->getBitWidth(), 1); + C <<= *AI; + return true; + } + } + return false; +} + +// Returns whether C0 * C1 with the given signedness overflows. +static bool MulWillOverflow(APInt &C0, APInt &C1, bool IsSigned) { + bool overflow; + if (IsSigned) + (void)C0.smul_ov(C1, overflow); + else + (void)C0.umul_ov(C1, overflow); + return overflow; +} + +// Simplifies X % C0 + (( X / C0 ) % C1) * C0 to X % (C0 * C1), where (C0 * C1) +// does not overflow. 
+Value *InstCombiner::SimplifyAddWithRemainder(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); - if (Value *V = - SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), - SQ.getWithInstruction(&I))) + Value *X, *MulOpV; + APInt C0, MulOpC; + bool IsSigned; + // Match I = X % C0 + MulOpV * C0 + if (((MatchRem(LHS, X, C0, IsSigned) && MatchMul(RHS, MulOpV, MulOpC)) || + (MatchRem(RHS, X, C0, IsSigned) && MatchMul(LHS, MulOpV, MulOpC))) && + C0 == MulOpC) { + Value *RemOpV; + APInt C1; + bool Rem2IsSigned; + // Match MulOpC = RemOpV % C1 + if (MatchRem(MulOpV, RemOpV, C1, Rem2IsSigned) && + IsSigned == Rem2IsSigned) { + Value *DivOpV; + APInt DivOpC; + // Match RemOpV = X / C0 + if (MatchDiv(RemOpV, DivOpV, DivOpC, IsSigned) && X == DivOpV && + C0 == DivOpC && !MulWillOverflow(C0, C1, IsSigned)) { + Value *NewDivisor = + ConstantInt::get(X->getType()->getContext(), C0 * C1); + return IsSigned ? Builder.CreateSRem(X, NewDivisor, "srem") + : Builder.CreateURem(X, NewDivisor, "urem"); + } + } + } + + return nullptr; +} + +/// Fold +/// (1 << NBits) - 1 +/// Into: +/// ~(-(1 << NBits)) +/// Because a 'not' is better for bit-tracking analysis and other transforms +/// than an 'add'. The new shl is always nsw, and is nuw if old `and` was. +static Instruction *canonicalizeLowbitMask(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *NBits; + if (!match(&I, m_Add(m_OneUse(m_Shl(m_One(), m_Value(NBits))), m_AllOnes()))) + return nullptr; + + Constant *MinusOne = Constant::getAllOnesValue(NBits->getType()); + Value *NotMask = Builder.CreateShl(MinusOne, NBits, "notmask"); + // Be wary of constant folding. + if (auto *BOp = dyn_cast<BinaryOperator>(NotMask)) { + // Always NSW. But NUW propagates from `add`. + BOp->setHasNoSignedWrap(); + BOp->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + } + + return BinaryOperator::CreateNot(NotMask, I.getName()); +} + +Instruction *InstCombiner::visitAdd(BinaryOperator &I) { + if (Value *V = SimplifyAddInst(I.getOperand(0), I.getOperand(1), + I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldShuffledBinop(I)) + return X; + // (A*B)+(A*C) -> A*(B+C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); @@ -1051,6 +1141,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // FIXME: This should be moved into the above helper function to allow these // transforms for general constant or constant splat vectors. + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); Type *Ty = I.getType(); if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { Value *XorLHS = nullptr; ConstantInt *XorRHS = nullptr; @@ -1123,6 +1214,14 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Value *V = checkForNegativeOperand(I, Builder)) return replaceInstUsesWith(I, V); + // (A + 1) + ~B --> A - B + // ~B + (A + 1) --> A - B + if (match(&I, m_c_BinOp(m_Add(m_Value(A), m_One()), m_Not(m_Value(B))))) + return BinaryOperator::CreateSub(A, B); + + // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1) + if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V); + // A+B --> A|B iff A and B have no bits set in common. 
if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) return BinaryOperator::CreateOr(LHS, RHS); @@ -1253,26 +1352,15 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } // (add (xor A, B) (and A, B)) --> (or A, B) - if (match(LHS, m_Xor(m_Value(A), m_Value(B))) && - match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) - return BinaryOperator::CreateOr(A, B); - // (add (and A, B) (xor A, B)) --> (or A, B) - if (match(RHS, m_Xor(m_Value(A), m_Value(B))) && - match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) + if (match(&I, m_c_BinOp(m_Xor(m_Value(A), m_Value(B)), + m_c_And(m_Deferred(A), m_Deferred(B))))) return BinaryOperator::CreateOr(A, B); // (add (or A, B) (and A, B)) --> (add A, B) - if (match(LHS, m_Or(m_Value(A), m_Value(B))) && - match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) { - I.setOperand(0, A); - I.setOperand(1, B); - return &I; - } - // (add (and A, B) (or A, B)) --> (add A, B) - if (match(RHS, m_Or(m_Value(A), m_Value(B))) && - match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) { + if (match(&I, m_c_BinOp(m_Or(m_Value(A), m_Value(B)), + m_c_And(m_Deferred(A), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; @@ -1281,6 +1369,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // TODO(jingyue): Consider willNotOverflowSignedAdd and // willNotOverflowUnsignedAdd to reduce the number of invocations of // computeKnownBits. + bool Changed = false; if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) { Changed = true; I.setHasNoSignedWrap(true); @@ -1290,39 +1379,35 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { I.setHasNoUnsignedWrap(true); } + if (Instruction *V = canonicalizeLowbitMask(I, Builder)) + return V; + return Changed ? &I : nullptr; } Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { - bool Changed = SimplifyAssociativeOrCommutative(I); - Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) - return replaceInstUsesWith(I, V); - - if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), + if (Value *V = SimplifyFAddInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (isa<Constant>(RHS)) - if (Instruction *FoldedFAdd = foldOpWithConstantIntoOperand(I)) - return FoldedFAdd; + if (SimplifyAssociativeOrCommutative(I)) + return &I; - // -A + B --> B - A - // -A + -B --> -(A + B) - if (Value *LHSV = dyn_castFNegVal(LHS)) { - Instruction *RI = BinaryOperator::CreateFSub(RHS, LHSV); - RI->copyFastMathFlags(&I); - return RI; - } + if (Instruction *X = foldShuffledBinop(I)) + return X; - // A + -B --> A - B - if (!isa<Constant>(RHS)) - if (Value *V = dyn_castFNegVal(RHS)) { - Instruction *RI = BinaryOperator::CreateFSub(LHS, V); - RI->copyFastMathFlags(&I); - return RI; - } + if (Instruction *FoldedFAdd = foldBinOpIntoSelectOrPhi(I)) + return FoldedFAdd; + + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + Value *X; + // (-X) + Y --> Y - X + if (match(LHS, m_FNeg(m_Value(X)))) + return BinaryOperator::CreateFSubFMF(RHS, X, &I); + // Y + (-X) --> Y - X + if (match(RHS, m_FNeg(m_Value(X)))) + return BinaryOperator::CreateFSubFMF(LHS, X, &I); // Check for (fadd double (sitofp x), y), see if we can merge this into an // integer add followed by a promotion. 
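The identity behind SimplifyAddWithRemainder in the hunks above, X % C0 + ((X / C0) % C1) * C0 == X % (C0 * C1), is easy to sanity-check numerically. A standalone sketch with hypothetical values, using unsigned arithmetic and choosing C0 * C1 small enough that the MulWillOverflow guard would pass:

#include <cassert>

int main() {
  // X = 23, C0 = 4, C1 = 3: 23 % 4 = 3, (23 / 4) % 3 = 2, so the left side
  // is 3 + 2 * 4 = 11, which matches 23 % 12 = 11 on the right side.
  unsigned X = 23, C0 = 4, C1 = 3;
  assert(X % C0 + ((X / C0) % C1) * C0 == X % (C0 * C1));
  return 0;
}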
@@ -1386,12 +1471,12 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { if (Value *V = SimplifySelectsFeedingBinaryOp(I, LHS, RHS)) return replaceInstUsesWith(I, V); - if (I.isFast()) { + if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { if (Value *V = FAddCombine(Builder).simplify(&I)) return replaceInstUsesWith(I, V); } - return Changed ? &I : nullptr; + return nullptr; } /// Optimize pointer differences into the same array into a size. Consider: @@ -1481,21 +1566,20 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, } Instruction *InstCombiner::visitSub(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifySubInst(I.getOperand(0), I.getOperand(1), + I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = - SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), - SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; // (A*B)-(A*C) -> A*(B-C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); // If this is a 'B = x-(-A)', change to B = x+A. + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = dyn_castNegVal(Op1)) { BinaryOperator *Res = BinaryOperator::CreateAdd(Op0, V); @@ -1519,12 +1603,28 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (match(Op0, m_AllOnes())) return BinaryOperator::CreateNot(Op1); + // (~X) - (~Y) --> Y - X + Value *X, *Y; + if (match(Op0, m_Not(m_Value(X))) && match(Op1, m_Not(m_Value(Y)))) + return BinaryOperator::CreateSub(Y, X); + if (Constant *C = dyn_cast<Constant>(Op0)) { + bool IsNegate = match(C, m_ZeroInt()); Value *X; - // C - zext(bool) -> bool ? C - 1 : C - if (match(Op1, m_ZExt(m_Value(X))) && - X->getType()->getScalarSizeInBits() == 1) + if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { + // 0 - (zext bool) --> sext bool + // C - (zext bool) --> bool ? C - 1 : C + if (IsNegate) + return CastInst::CreateSExtOrBitCast(X, I.getType()); return SelectInst::Create(X, SubOne(C), C); + } + if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { + // 0 - (sext bool) --> zext bool + // C - (sext bool) --> bool ? 
C + 1 : C + if (IsNegate) + return CastInst::CreateZExtOrBitCast(X, I.getType()); + return SelectInst::Create(X, AddOne(C), C); + } // C - ~X == X + (1+C) if (match(Op1, m_Not(m_Value(X)))) @@ -1544,16 +1644,6 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Constant *C2; if (match(Op1, m_Add(m_Value(X), m_Constant(C2)))) return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); - - // Fold (sub 0, (zext bool to B)) --> (sext bool to B) - if (C->isNullValue() && match(Op1, m_ZExt(m_Value(X)))) - if (X->getType()->isIntOrIntVectorTy(1)) - return CastInst::CreateSExtOrBitCast(X, Op1->getType()); - - // Fold (sub 0, (sext bool to B)) --> (zext bool to B) - if (C->isNullValue() && match(Op1, m_SExt(m_Value(X)))) - if (X->getType()->isIntOrIntVectorTy(1)) - return CastInst::CreateZExtOrBitCast(X, Op1->getType()); } const APInt *Op0C; @@ -1575,6 +1665,22 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1); return BinaryOperator::CreateLShr(X, ShAmtOp); } + + if (Op1->hasOneUse()) { + Value *LHS, *RHS; + SelectPatternFlavor SPF = matchSelectPattern(Op1, LHS, RHS).Flavor; + if (SPF == SPF_ABS || SPF == SPF_NABS) { + // This is a negate of an ABS/NABS pattern. Just swap the operands + // of the select. + SelectInst *SI = cast<SelectInst>(Op1); + Value *TrueVal = SI->getTrueValue(); + Value *FalseVal = SI->getFalseValue(); + SI->setTrueValue(FalseVal); + SI->setFalseValue(TrueVal); + // Don't swap prof metadata, we didn't change the branch behavior. + return replaceInstUsesWith(I, SI); + } + } } // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known @@ -1678,6 +1784,27 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) return replaceInstUsesWith(I, Res); + // Canonicalize a shifty way to code absolute value to the common pattern. + // There are 2 potential commuted variants. + // We're relying on the fact that we only do this transform when the shift has + // exactly 2 uses and the xor has exactly 1 use (otherwise, we might increase + // instructions). + Value *A; + const APInt *ShAmt; + Type *Ty = I.getType(); + if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) && + Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 && + match(Op0, m_OneUse(m_c_Xor(m_Specific(A), m_Specific(Op1))))) { + // B = ashr i32 A, 31 ; smear the sign bit + // sub (xor A, B), B ; flip bits if negative and subtract -1 (add 1) + // --> (A < 0) ? -A : A + Value *Cmp = Builder.CreateICmpSLT(A, ConstantInt::getNullValue(Ty)); + // Copy the nuw/nsw flags from the sub to the negate. + Value *Neg = Builder.CreateNeg(A, "", I.hasNoUnsignedWrap(), + I.hasNoSignedWrap()); + return SelectInst::Create(Cmp, Neg, A); + } + bool Changed = false; if (!I.hasNoSignedWrap() && willNotOverflowSignedSub(Op0, Op1, I)) { Changed = true; @@ -1692,21 +1819,32 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { } Instruction *InstCombiner::visitFSub(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) - return replaceInstUsesWith(I, V); - - if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), + if (Value *V = SimplifyFSubInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; + + // Subtraction from -0.0 is the canonical form of fneg. 
// fsub nsz 0, X ==> fsub nsz -0.0, X - if (I.getFastMathFlags().noSignedZeros() && match(Op0, m_Zero())) { - // Subtraction from -0.0 is the canonical form of fneg. - Instruction *NewI = BinaryOperator::CreateFNeg(Op1); - NewI->copyFastMathFlags(&I); - return NewI; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (I.hasNoSignedZeros() && match(Op0, m_PosZeroFP())) + return BinaryOperator::CreateFNegFMF(Op1, &I); + + // If Op0 is not -0.0 or we can ignore -0.0: Z - (X - Y) --> Z + (Y - X) + // Canonicalize to fadd to make analysis easier. + // This can also help codegen because fadd is commutative. + // Note that if this fsub was really an fneg, the fadd with -0.0 will get + // killed later. We still limit that particular transform with 'hasOneUse' + // because an fneg is assumed better/cheaper than a generic fsub. + Value *X, *Y; + if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) { + if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { + Value *NewSub = Builder.CreateFSubFMF(Y, X, &I); + return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I); + } } if (isa<Constant>(Op0)) @@ -1714,34 +1852,34 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (Instruction *NV = FoldOpIntoSelect(I, SI)) return NV; - // If this is a 'B = x-(-A)', change to B = x+A, potentially looking - // through FP extensions/truncations along the way. - if (Value *V = dyn_castFNegVal(Op1)) { - Instruction *NewI = BinaryOperator::CreateFAdd(Op0, V); - NewI->copyFastMathFlags(&I); - return NewI; - } - if (FPTruncInst *FPTI = dyn_cast<FPTruncInst>(Op1)) { - if (Value *V = dyn_castFNegVal(FPTI->getOperand(0))) { - Value *NewTrunc = Builder.CreateFPTrunc(V, I.getType()); - Instruction *NewI = BinaryOperator::CreateFAdd(Op0, NewTrunc); - NewI->copyFastMathFlags(&I); - return NewI; - } - } else if (FPExtInst *FPEI = dyn_cast<FPExtInst>(Op1)) { - if (Value *V = dyn_castFNegVal(FPEI->getOperand(0))) { - Value *NewExt = Builder.CreateFPExt(V, I.getType()); - Instruction *NewI = BinaryOperator::CreateFAdd(Op0, NewExt); - NewI->copyFastMathFlags(&I); - return NewI; - } + // X - C --> X + (-C) + // But don't transform constant expressions because there's an inverse fold + // for X + (-Y) --> X - Y. 
+ Constant *C; + if (match(Op1, m_Constant(C)) && !isa<ConstantExpr>(Op1)) + return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); + + // X - (-Y) --> X + Y + if (match(Op1, m_FNeg(m_Value(Y)))) + return BinaryOperator::CreateFAddFMF(Op0, Y, &I); + + // Similar to above, but look through a cast of the negated value: + // X - (fptrunc(-Y)) --> X + fptrunc(Y) + if (match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) { + Value *TruncY = Builder.CreateFPTrunc(Y, I.getType()); + return BinaryOperator::CreateFAddFMF(Op0, TruncY, &I); + } + // X - (fpext(-Y)) --> X + fpext(Y) + if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) { + Value *ExtY = Builder.CreateFPExt(Y, I.getType()); + return BinaryOperator::CreateFAddFMF(Op0, ExtY, &I); } // Handle specials cases for FSub with selects feeding the operation if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) return replaceInstUsesWith(I, V); - if (I.isFast()) { + if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { if (Value *V = FAddCombine(Builder).simplify(&I)) return replaceInstUsesWith(I, V); } diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 2364202e5b69..372bc41f780e 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -14,10 +14,10 @@ #include "InstCombineInternal.h" #include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Transforms/Utils/Local.h" using namespace llvm; using namespace PatternMatch; @@ -75,7 +75,7 @@ static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS, return Builder.CreateFCmp(Pred, LHS, RHS); } -/// \brief Transform BITWISE_OP(BSWAP(A),BSWAP(B)) or +/// Transform BITWISE_OP(BSWAP(A),BSWAP(B)) or /// BITWISE_OP(BSWAP(A), Constant) to BSWAP(BITWISE_OP(A, B)) /// \param I Binary operator to transform. /// \return Pointer to node that must replace the original binary operator, or @@ -305,17 +305,21 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre } /// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E). -/// Return the set of pattern classes (from MaskedICmpType) that both LHS and -/// RHS satisfy. -static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, - Value *&D, Value *&E, ICmpInst *LHS, - ICmpInst *RHS, - ICmpInst::Predicate &PredL, - ICmpInst::Predicate &PredR) { +/// Return the pattern classes (from MaskedICmpType) for the left hand side and +/// the right hand side as a pair. +/// LHS and RHS are the left hand side and the right hand side ICmps and PredL +/// and PredR are their predicates, respectively. +static +Optional<std::pair<unsigned, unsigned>> +getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, + Value *&D, Value *&E, ICmpInst *LHS, + ICmpInst *RHS, + ICmpInst::Predicate &PredL, + ICmpInst::Predicate &PredR) { // vectors are not (yet?) supported. Don't support pointers either. if (!LHS->getOperand(0)->getType()->isIntegerTy() || !RHS->getOperand(0)->getType()->isIntegerTy()) - return 0; + return None; // Here comes the tricky part: // LHS might be of the form L11 & L12 == X, X == L21 & L22, @@ -346,7 +350,7 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, // Bail if LHS was a icmp that can't be decomposed into an equality. 
if (!ICmpInst::isEquality(PredL)) - return 0; + return None; Value *R1 = RHS->getOperand(0); Value *R2 = RHS->getOperand(1); @@ -360,7 +364,7 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, A = R12; D = R11; } else { - return 0; + return None; } E = R2; R1 = nullptr; @@ -388,7 +392,7 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, // Bail if RHS was a icmp that can't be decomposed into an equality. if (!ICmpInst::isEquality(PredR)) - return 0; + return None; // Look for ANDs on the right side of the RHS icmp. if (!Ok) { @@ -408,11 +412,11 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, E = R1; Ok = true; } else { - return 0; + return None; } } if (!Ok) - return 0; + return None; if (L11 == A) { B = L12; @@ -430,7 +434,174 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, unsigned LeftType = getMaskedICmpType(A, B, C, PredL); unsigned RightType = getMaskedICmpType(A, D, E, PredR); - return LeftType & RightType; + return Optional<std::pair<unsigned, unsigned>>(std::make_pair(LeftType, RightType)); +} + +/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) into a single +/// (icmp(A & X) ==/!= Y), where the left-hand side is of type Mask_NotAllZeros +/// and the right hand side is of type BMask_Mixed. For example, +/// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8). +static Value * foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( + ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, + Value *A, Value *B, Value *C, Value *D, Value *E, + ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, + llvm::InstCombiner::BuilderTy &Builder) { + // We are given the canonical form: + // (icmp ne (A & B), 0) & (icmp eq (A & D), E). + // where D & E == E. + // + // If IsAnd is false, we get it in negated form: + // (icmp eq (A & B), 0) | (icmp ne (A & D), E) -> + // !((icmp ne (A & B), 0) & (icmp eq (A & D), E)). + // + // We currently handle the case of B, C, D, E are constant. + // + ConstantInt *BCst = dyn_cast<ConstantInt>(B); + if (!BCst) + return nullptr; + ConstantInt *CCst = dyn_cast<ConstantInt>(C); + if (!CCst) + return nullptr; + ConstantInt *DCst = dyn_cast<ConstantInt>(D); + if (!DCst) + return nullptr; + ConstantInt *ECst = dyn_cast<ConstantInt>(E); + if (!ECst) + return nullptr; + + ICmpInst::Predicate NewCC = IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; + + // Update E to the canonical form when D is a power of two and RHS is + // canonicalized as, + // (icmp ne (A & D), 0) -> (icmp eq (A & D), D) or + // (icmp ne (A & D), D) -> (icmp eq (A & D), 0). + if (PredR != NewCC) + ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst)); + + // If B or D is zero, skip because if LHS or RHS can be trivially folded by + // other folding rules and this pattern won't apply any more. + if (BCst->getValue() == 0 || DCst->getValue() == 0) + return nullptr; + + // If B and D don't intersect, ie. (B & D) == 0, no folding because we can't + // deduce anything from it. + // For example, + // (icmp ne (A & 12), 0) & (icmp eq (A & 3), 1) -> no folding. + if ((BCst->getValue() & DCst->getValue()) == 0) + return nullptr; + + // If the following two conditions are met: + // + // 1. mask B covers only a single bit that's not covered by mask D, that is, + // (B & (B ^ D)) is a power of 2 (in other words, B minus the intersection of + // B and D has only one bit set) and, + // + // 2. 
RHS (and E) indicates that the rest of B's bits are zero (in other + // words, the intersection of B and D is zero), that is, ((B & D) & E) == 0 + // + // then that single bit in B must be one and thus the whole expression can be + // folded to + // (A & (B | D)) == (B & (B ^ D)) | E. + // + // For example, + // (icmp ne (A & 12), 0) & (icmp eq (A & 7), 1) -> (icmp eq (A & 15), 9) + // (icmp ne (A & 15), 0) & (icmp eq (A & 7), 0) -> (icmp eq (A & 15), 8) + if ((((BCst->getValue() & DCst->getValue()) & ECst->getValue()) == 0) && + (BCst->getValue() & (BCst->getValue() ^ DCst->getValue())).isPowerOf2()) { + APInt BorD = BCst->getValue() | DCst->getValue(); + APInt BandBxorDorE = (BCst->getValue() & (BCst->getValue() ^ DCst->getValue())) | + ECst->getValue(); + Value *NewMask = ConstantInt::get(BCst->getType(), BorD); + Value *NewMaskedValue = ConstantInt::get(BCst->getType(), BandBxorDorE); + Value *NewAnd = Builder.CreateAnd(A, NewMask); + return Builder.CreateICmp(NewCC, NewAnd, NewMaskedValue); + } + + auto IsSubSetOrEqual = [](ConstantInt *C1, ConstantInt *C2) { + return (C1->getValue() & C2->getValue()) == C1->getValue(); + }; + auto IsSuperSetOrEqual = [](ConstantInt *C1, ConstantInt *C2) { + return (C1->getValue() & C2->getValue()) == C2->getValue(); + }; + + // In the following, we consider only the cases where B is a superset of D, B + // is a subset of D, or B == D because otherwise there's at least one bit + // covered by B but not D, in which case we can't deduce much from it, so + // no folding (aside from the single must-be-one bit case right above.) + // For example, + // (icmp ne (A & 14), 0) & (icmp eq (A & 3), 1) -> no folding. + if (!IsSubSetOrEqual(BCst, DCst) && !IsSuperSetOrEqual(BCst, DCst)) + return nullptr; + + // At this point, either B is a superset of D, B is a subset of D or B == D. + + // If E is zero, if B is a subset of (or equal to) D, LHS and RHS contradict + // and the whole expression becomes false (or true if negated), otherwise, no + // folding. + // For example, + // (icmp ne (A & 3), 0) & (icmp eq (A & 7), 0) -> false. + // (icmp ne (A & 15), 0) & (icmp eq (A & 3), 0) -> no folding. + if (ECst->isZero()) { + if (IsSubSetOrEqual(BCst, DCst)) + return ConstantInt::get(LHS->getType(), !IsAnd); + return nullptr; + } + + // At this point, B, D, E aren't zero and (B & D) == B, (B & D) == D or B == + // D. If B is a superset of (or equal to) D, since E is not zero, LHS is + // subsumed by RHS (RHS implies LHS.) So the whole expression becomes + // RHS. For example, + // (icmp ne (A & 255), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8). + // (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8). + if (IsSuperSetOrEqual(BCst, DCst)) + return RHS; + // Otherwise, B is a subset of D. If B and E have a common bit set, + // ie. (B & E) != 0, then LHS is subsumed by RHS. For example. + // (icmp ne (A & 12), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8). + assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code"); + if ((BCst->getValue() & ECst->getValue()) != 0) + return RHS; + // Otherwise, LHS and RHS contradict and the whole expression becomes false + // (or true if negated.) For example, + // (icmp ne (A & 7), 0) & (icmp eq (A & 15), 8) -> false. + // (icmp ne (A & 6), 0) & (icmp eq (A & 15), 8) -> false. 
+ return ConstantInt::get(LHS->getType(), !IsAnd); +} + +/// Try to fold (icmp(A & B) ==/!= 0) &/| (icmp(A & D) ==/!= E) into a single +/// (icmp(A & X) ==/!= Y), where the left-hand side and the right hand side +/// aren't of the common mask pattern type. +static Value *foldLogOpOfMaskedICmpsAsymmetric( + ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, + Value *A, Value *B, Value *C, Value *D, Value *E, + ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, + unsigned LHSMask, unsigned RHSMask, + llvm::InstCombiner::BuilderTy &Builder) { + assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) && + "Expected equality predicates for masked type of icmps."); + // Handle Mask_NotAllZeros-BMask_Mixed cases. + // (icmp ne/eq (A & B), C) &/| (icmp eq/ne (A & D), E), or + // (icmp eq/ne (A & B), C) &/| (icmp ne/eq (A & D), E) + // which gets swapped to + // (icmp ne/eq (A & D), E) &/| (icmp eq/ne (A & B), C). + if (!IsAnd) { + LHSMask = conjugateICmpMask(LHSMask); + RHSMask = conjugateICmpMask(RHSMask); + } + if ((LHSMask & Mask_NotAllZeros) && (RHSMask & BMask_Mixed)) { + if (Value *V = foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( + LHS, RHS, IsAnd, A, B, C, D, E, + PredL, PredR, Builder)) { + return V; + } + } else if ((LHSMask & BMask_Mixed) && (RHSMask & Mask_NotAllZeros)) { + if (Value *V = foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( + RHS, LHS, IsAnd, A, D, E, B, C, + PredR, PredL, Builder)) { + return V; + } + } + return nullptr; } /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) @@ -439,13 +610,24 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, llvm::InstCombiner::BuilderTy &Builder) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); - unsigned Mask = + Optional<std::pair<unsigned, unsigned>> MaskPair = getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR); - if (Mask == 0) + if (!MaskPair) return nullptr; - assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) && "Expected equality predicates for masked type of icmps."); + unsigned LHSMask = MaskPair->first; + unsigned RHSMask = MaskPair->second; + unsigned Mask = LHSMask & RHSMask; + if (Mask == 0) { + // Even if the two sides don't share a common pattern, check if folding can + // still happen. + if (Value *V = foldLogOpOfMaskedICmpsAsymmetric( + LHS, RHS, IsAnd, A, B, C, D, E, PredL, PredR, LHSMask, RHSMask, + Builder)) + return V; + return nullptr; + } // In full generality: // (icmp (A & B) Op C) | (icmp (A & D) Op E) @@ -939,8 +1121,8 @@ Value *InstCombiner::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) return nullptr; // FCmp canonicalization ensures that (fcmp ord/uno X, X) and - // (fcmp ord/uno X, C) will be transformed to (fcmp X, 0.0). - if (match(LHS1, m_Zero()) && LHS1 == RHS1) + // (fcmp ord/uno X, C) will be transformed to (fcmp X, +0.0). + if (match(LHS1, m_PosZeroFP()) && match(RHS1, m_PosZeroFP())) // Ignore the constants because they are obviously not NANs: // (fcmp ord x, 0.0) & (fcmp ord y, 0.0) -> (fcmp ord x, y) // (fcmp uno x, 0.0) | (fcmp uno y, 0.0) -> (fcmp uno x, y) @@ -1106,8 +1288,8 @@ static Instruction *foldAndToXor(BinaryOperator &I, // Operand complexity canonicalization guarantees that the 'or' is Op0. 
// (A | B) & ~(A & B) --> A ^ B // (A | B) & ~(B & A) --> A ^ B - if (match(Op0, m_Or(m_Value(A), m_Value(B))) && - match(Op1, m_Not(m_c_And(m_Specific(A), m_Specific(B))))) + if (match(&I, m_BinOp(m_Or(m_Value(A), m_Value(B)), + m_Not(m_c_And(m_Deferred(A), m_Deferred(B)))))) return BinaryOperator::CreateXor(A, B); // (A | ~B) & (~A | B) --> ~(A ^ B) @@ -1115,8 +1297,8 @@ static Instruction *foldAndToXor(BinaryOperator &I, // (~B | A) & (~A | B) --> ~(A ^ B) // (~B | A) & (B | ~A) --> ~(A ^ B) if (Op0->hasOneUse() || Op1->hasOneUse()) - if (match(Op0, m_c_Or(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_c_Or(m_Not(m_Specific(A)), m_Specific(B)))) + if (match(&I, m_BinOp(m_c_Or(m_Value(A), m_Not(m_Value(B))), + m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) return BinaryOperator::CreateNot(Builder.CreateXor(A, B)); return nullptr; @@ -1148,18 +1330,86 @@ static Instruction *foldOrToXor(BinaryOperator &I, return nullptr; } +/// Return true if a constant shift amount is always less than the specified +/// bit-width. If not, the shift could create poison in the narrower type. +static bool canNarrowShiftAmt(Constant *C, unsigned BitWidth) { + if (auto *ScalarC = dyn_cast<ConstantInt>(C)) + return ScalarC->getZExtValue() < BitWidth; + + if (C->getType()->isVectorTy()) { + // Check each element of a constant vector. + unsigned NumElts = C->getType()->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = C->getAggregateElement(i); + if (!Elt) + return false; + if (isa<UndefValue>(Elt)) + continue; + auto *CI = dyn_cast<ConstantInt>(Elt); + if (!CI || CI->getZExtValue() >= BitWidth) + return false; + } + return true; + } + + // The constant is a constant expression or unknown. + return false; +} + +/// Try to use narrower ops (sink zext ops) for an 'and' with binop operand and +/// a common zext operand: and (binop (zext X), C), (zext X). +Instruction *InstCombiner::narrowMaskedBinOp(BinaryOperator &And) { + // This transform could also apply to {or, and, xor}, but there are better + // folds for those cases, so we don't expect those patterns here. AShr is not + // handled because it should always be transformed to LShr in this sequence. + // The subtract transform is different because it has a constant on the left. + // Add/mul commute the constant to RHS; sub with constant RHS becomes add. + Value *Op0 = And.getOperand(0), *Op1 = And.getOperand(1); + Constant *C; + if (!match(Op0, m_OneUse(m_Add(m_Specific(Op1), m_Constant(C)))) && + !match(Op0, m_OneUse(m_Mul(m_Specific(Op1), m_Constant(C)))) && + !match(Op0, m_OneUse(m_LShr(m_Specific(Op1), m_Constant(C)))) && + !match(Op0, m_OneUse(m_Shl(m_Specific(Op1), m_Constant(C)))) && + !match(Op0, m_OneUse(m_Sub(m_Constant(C), m_Specific(Op1))))) + return nullptr; + + Value *X; + if (!match(Op1, m_ZExt(m_Value(X))) || Op1->hasNUsesOrMore(3)) + return nullptr; + + Type *Ty = And.getType(); + if (!isa<VectorType>(Ty) && !shouldChangeType(Ty, X->getType())) + return nullptr; + + // If we're narrowing a shift, the shift amount must be safe (less than the + // width) in the narrower type. If the shift amount is greater, instsimplify + // usually handles that case, but we can't guarantee/assert it. 
+ Instruction::BinaryOps Opc = cast<BinaryOperator>(Op0)->getOpcode(); + if (Opc == Instruction::LShr || Opc == Instruction::Shl) + if (!canNarrowShiftAmt(C, X->getType()->getScalarSizeInBits())) + return nullptr; + + // and (sub C, (zext X)), (zext X) --> zext (and (sub C', X), X) + // and (binop (zext X), C), (zext X) --> zext (and (binop X, C'), X) + Value *NewC = ConstantExpr::getTrunc(C, X->getType()); + Value *NewBO = Opc == Instruction::Sub ? Builder.CreateBinOp(Opc, NewC, X) + : Builder.CreateBinOp(Opc, X, NewC); + return new ZExtInst(Builder.CreateAnd(NewBO, X), Ty); +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. Instruction *InstCombiner::visitAnd(BinaryOperator &I) { - bool Changed = SimplifyAssociativeOrCommutative(I); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyAndInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyAndInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldShuffledBinop(I)) + return X; // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. @@ -1177,6 +1427,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (Value *V = SimplifyBSwap(I, Builder)) return replaceInstUsesWith(I, V); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); const APInt *C; if (match(Op1, m_APInt(C))) { Value *X, *Y; @@ -1289,9 +1540,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { } } - if (isa<Constant>(Op1)) - if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) - return FoldedLogic; + if (Instruction *Z = narrowMaskedBinOp(I)) + return Z; + + if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I)) + return FoldedLogic; if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder)) return DeMorgan; @@ -1397,7 +1650,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { A->getType()->isIntOrIntVectorTy(1)) return SelectInst::Create(A, Op0, Constant::getNullValue(I.getType())); - return Changed ? &I : nullptr; + return nullptr; } /// Given an OR instruction, check to see if this is a bswap idiom. If so, @@ -1424,7 +1677,18 @@ Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) { bool OrOfAnds = match(Op0, m_And(m_Value(), m_Value())) && match(Op1, m_And(m_Value(), m_Value())); - if (!OrOfOrs && !OrOfShifts && !OrOfAnds) + // (A << B) | (C & D) -> bswap if possible. + // The bigger pattern here is ((A & C1) << C2) | ((B >> C2) & C1), which is a + // part of the bswap idiom for specific values of C1, C2 (e.g. C1 = 16711935, + // C2 = 8 for i32). + // This pattern can occur when the operands of the 'or' are not canonicalized + // for some reason (not having only one use, for example). 
+ bool OrOfAndAndSh = (match(Op0, m_LogicalShift(m_Value(), m_Value())) && + match(Op1, m_And(m_Value(), m_Value()))) || + (match(Op0, m_And(m_Value(), m_Value())) && + match(Op1, m_LogicalShift(m_Value(), m_Value()))); + + if (!OrOfOrs && !OrOfShifts && !OrOfAnds && !OrOfAndAndSh) return nullptr; SmallVector<Instruction*, 4> Insts; @@ -1448,7 +1712,6 @@ static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { return false; // One element must be all ones, and the other must be all zeros. - // FIXME: Allow undef elements. if (!((match(EltC1, m_Zero()) && match(EltC2, m_AllOnes())) || (match(EltC2, m_Zero()) && match(EltC1, m_AllOnes())))) return false; @@ -1755,14 +2018,15 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. Instruction *InstCombiner::visitOr(BinaryOperator &I) { - bool Changed = SimplifyAssociativeOrCommutative(I); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyOrInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyOrInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldShuffledBinop(I)) + return X; // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. @@ -1780,14 +2044,14 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Value *V = SimplifyBSwap(I, Builder)) return replaceInstUsesWith(I, V); - if (isa<Constant>(Op1)) - if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) - return FoldedLogic; + if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I)) + return FoldedLogic; // Given an OR instruction, check to see if this is a bswap. if (Instruction *BSwap = MatchBSwap(I)) return BSwap; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); { Value *A; const APInt *C; @@ -2027,7 +2291,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { } } - return Changed ? &I : nullptr; + return nullptr; } /// A ^ B can be specified using other logic ops in a variety of patterns. 
We @@ -2045,10 +2309,8 @@ static Instruction *foldXorToXor(BinaryOperator &I, // (A & B) ^ (B | A) -> A ^ B // (A | B) ^ (A & B) -> A ^ B // (A | B) ^ (B & A) -> A ^ B - if ((match(Op0, m_And(m_Value(A), m_Value(B))) && - match(Op1, m_c_Or(m_Specific(A), m_Specific(B)))) || - (match(Op0, m_Or(m_Value(A), m_Value(B))) && - match(Op1, m_c_And(m_Specific(A), m_Specific(B))))) { + if (match(&I, m_c_Xor(m_And(m_Value(A), m_Value(B)), + m_c_Or(m_Deferred(A), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; @@ -2058,10 +2320,8 @@ static Instruction *foldXorToXor(BinaryOperator &I, // (~B | A) ^ (~A | B) -> A ^ B // (~A | B) ^ (A | ~B) -> A ^ B // (B | ~A) ^ (A | ~B) -> A ^ B - if ((match(Op0, m_Or(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_c_Or(m_Not(m_Specific(A)), m_Specific(B)))) || - (match(Op0, m_Or(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_c_Or(m_Specific(A), m_Not(m_Specific(B)))))) { + if (match(&I, m_Xor(m_c_Or(m_Value(A), m_Not(m_Value(B))), + m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; @@ -2071,10 +2331,8 @@ static Instruction *foldXorToXor(BinaryOperator &I, // (~B & A) ^ (~A & B) -> A ^ B // (~A & B) ^ (A & ~B) -> A ^ B // (B & ~A) ^ (A & ~B) -> A ^ B - if ((match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op1, m_c_And(m_Not(m_Specific(A)), m_Specific(B)))) || - (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_c_And(m_Specific(A), m_Not(m_Specific(B)))))) { + if (match(&I, m_Xor(m_c_And(m_Value(A), m_Not(m_Value(B))), + m_c_And(m_Not(m_Deferred(A)), m_Deferred(B))))) { I.setOperand(0, A); I.setOperand(1, B); return &I; @@ -2113,6 +2371,34 @@ Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) { } } + // TODO: This can be generalized to compares of non-signbits using + // decomposeBitTestICmp(). It could be enhanced more by using (something like) + // foldLogOpOfMaskedICmps(). + ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); + Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); + if ((LHS->hasOneUse() || RHS->hasOneUse()) && + LHS0->getType() == RHS0->getType()) { + // (X > -1) ^ (Y > -1) --> (X ^ Y) < 0 + // (X < 0) ^ (Y < 0) --> (X ^ Y) < 0 + if ((PredL == CmpInst::ICMP_SGT && match(LHS1, m_AllOnes()) && + PredR == CmpInst::ICMP_SGT && match(RHS1, m_AllOnes())) || + (PredL == CmpInst::ICMP_SLT && match(LHS1, m_Zero()) && + PredR == CmpInst::ICMP_SLT && match(RHS1, m_Zero()))) { + Value *Zero = ConstantInt::getNullValue(LHS0->getType()); + return Builder.CreateICmpSLT(Builder.CreateXor(LHS0, RHS0), Zero); + } + // (X > -1) ^ (Y < 0) --> (X ^ Y) > -1 + // (X < 0) ^ (Y > -1) --> (X ^ Y) > -1 + if ((PredL == CmpInst::ICMP_SGT && match(LHS1, m_AllOnes()) && + PredR == CmpInst::ICMP_SLT && match(RHS1, m_Zero())) || + (PredL == CmpInst::ICMP_SLT && match(LHS1, m_Zero()) && + PredR == CmpInst::ICMP_SGT && match(RHS1, m_AllOnes()))) { + Value *MinusOne = ConstantInt::getAllOnesValue(LHS0->getType()); + return Builder.CreateICmpSGT(Builder.CreateXor(LHS0, RHS0), MinusOne); + } + } + // Instead of trying to imitate the folds for and/or, decompose this 'xor' // into those logic ops. That is, try to turn this into an and-of-icmps // because we have many folds for that pattern. 
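The sign-bit folds added to foldXorOfICmps above rest on the observation that xoring two sign tests is the same as testing the sign of the xor: the two comparisons disagree exactly when X and Y have opposite signs, which is exactly when the sign bit of X ^ Y is set. A hypothetical spot-check of that equivalence, not part of the patch:

#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
  // (X > -1) ^ (Y > -1) is true iff exactly one of X, Y is non-negative,
  // i.e. their sign bits differ, i.e. (X ^ Y) < 0.
  for (int32_t X : {-7, -1, 0, 5})
    for (int32_t Y : {-3, 0, 9})
      assert(((X > -1) ^ (Y > -1)) == ((X ^ Y) < 0));
  return 0;
}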
@@ -2140,18 +2426,63 @@ Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) { return nullptr; } +/// If we have a masked merge, in the canonical form of: +/// (assuming that A only has one use.) +/// | A | |B| +/// ((x ^ y) & M) ^ y +/// | D | +/// * If M is inverted: +/// | D | +/// ((x ^ y) & ~M) ^ y +/// We can canonicalize by swapping the final xor operand +/// to eliminate the 'not' of the mask. +/// ((x ^ y) & M) ^ x +/// * If M is a constant, and D has one use, we transform to 'and' / 'or' ops +/// because that shortens the dependency chain and improves analysis: +/// (x & M) | (y & ~M) +static Instruction *visitMaskedMerge(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *B, *X, *D; + Value *M; + if (!match(&I, m_c_Xor(m_Value(B), + m_OneUse(m_c_And( + m_CombineAnd(m_c_Xor(m_Deferred(B), m_Value(X)), + m_Value(D)), + m_Value(M)))))) + return nullptr; + + Value *NotM; + if (match(M, m_Not(m_Value(NotM)))) { + // De-invert the mask and swap the value in B part. + Value *NewA = Builder.CreateAnd(D, NotM); + return BinaryOperator::CreateXor(NewA, X); + } + + Constant *C; + if (D->hasOneUse() && match(M, m_Constant(C))) { + // Unfold. + Value *LHS = Builder.CreateAnd(X, C); + Value *NotC = Builder.CreateNot(C); + Value *RHS = Builder.CreateAnd(B, NotC); + return BinaryOperator::CreateOr(LHS, RHS); + } + + return nullptr; +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. Instruction *InstCombiner::visitXor(BinaryOperator &I) { - bool Changed = SimplifyAssociativeOrCommutative(I); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyXorInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyXorInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Instruction *NewXor = foldXorToXor(I, Builder)) return NewXor; @@ -2168,6 +2499,11 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Value *V = SimplifyBSwap(I, Builder)) return replaceInstUsesWith(I, V); + // A^B --> A|B iff A and B have no bits set in common. + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (haveNoCommonBitsSet(Op0, Op1, DL, &AC, &I, &DT)) + return BinaryOperator::CreateOr(Op0, Op1); + // Apply DeMorgan's Law for 'nand' / 'nor' logic with an inverted operand. Value *X, *Y; @@ -2186,6 +2522,9 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { return BinaryOperator::CreateAnd(X, NotY); } + if (Instruction *Xor = visitMaskedMerge(I, Builder)) + return Xor; + // Is this a 'not' (~) fed by a binary operator? BinaryOperator *NotVal; if (match(&I, m_Not(m_BinOp(NotVal)))) { @@ -2206,6 +2545,10 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } + // ~(X - Y) --> ~X + Y + if (match(NotVal, m_OneUse(m_Sub(m_Value(X), m_Value(Y))))) + return BinaryOperator::CreateAdd(Builder.CreateNot(X), Y); + // ~(~X >>s Y) --> (X >>s Y) if (match(NotVal, m_AShr(m_Not(m_Value(X)), m_Value(Y)))) return BinaryOperator::CreateAShr(X, Y); @@ -2214,16 +2557,18 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { // the 'not' by inverting the constant and using the opposite shift type. 
// Canonicalization rules ensure that only a negative constant uses 'ashr', // but we must check that in case that transform has not fired yet. - const APInt *C; - if (match(NotVal, m_AShr(m_APInt(C), m_Value(Y))) && C->isNegative()) { + Constant *C; + if (match(NotVal, m_AShr(m_Constant(C), m_Value(Y))) && + match(C, m_Negative())) { // ~(C >>s Y) --> ~C >>u Y (when inverting the replicated sign bits) - Constant *NotC = ConstantInt::get(I.getType(), ~(*C)); + Constant *NotC = ConstantExpr::getNot(C); return BinaryOperator::CreateLShr(NotC, Y); } - if (match(NotVal, m_LShr(m_APInt(C), m_Value(Y))) && C->isNonNegative()) { + if (match(NotVal, m_LShr(m_Constant(C), m_Value(Y))) && + match(C, m_NonNegative())) { // ~(C >>u Y) --> ~C >>s Y (when inverting the replicated sign bits) - Constant *NotC = ConstantInt::get(I.getType(), ~(*C)); + Constant *NotC = ConstantExpr::getNot(C); return BinaryOperator::CreateAShr(NotC, Y); } } @@ -2305,9 +2650,8 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } - if (isa<Constant>(Op1)) - if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) - return FoldedLogic; + if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I)) + return FoldedLogic; { Value *A, *B; @@ -2397,25 +2741,59 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Instruction *CastedXor = foldCastedBitwiseLogic(I)) return CastedXor; - // Canonicalize the shifty way to code absolute value to the common pattern. + // Canonicalize a shifty way to code absolute value to the common pattern. // There are 4 potential commuted variants. Move the 'ashr' candidate to Op1. // We're relying on the fact that we only do this transform when the shift has // exactly 2 uses and the add has exactly 1 use (otherwise, we might increase // instructions). - if (Op0->getNumUses() == 2) + if (Op0->hasNUses(2)) std::swap(Op0, Op1); const APInt *ShAmt; Type *Ty = I.getType(); if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) && - Op1->getNumUses() == 2 && *ShAmt == Ty->getScalarSizeInBits() - 1 && + Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 && match(Op0, m_OneUse(m_c_Add(m_Specific(A), m_Specific(Op1))))) { // B = ashr i32 A, 31 ; smear the sign bit // xor (add A, B), B ; add -1 and flip bits if negative // --> (A < 0) ? -A : A Value *Cmp = Builder.CreateICmpSLT(A, ConstantInt::getNullValue(Ty)); - return SelectInst::Create(Cmp, Builder.CreateNeg(A), A); + // Copy the nuw/nsw flags from the add to the negate. + auto *Add = cast<BinaryOperator>(Op0); + Value *Neg = Builder.CreateNeg(A, "", Add->hasNoUnsignedWrap(), + Add->hasNoSignedWrap()); + return SelectInst::Create(Cmp, Neg, A); + } + + // Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max: + // + // %notx = xor i32 %x, -1 + // %cmp1 = icmp sgt i32 %notx, %y + // %smax = select i1 %cmp1, i32 %notx, i32 %y + // %res = xor i32 %smax, -1 + // => + // %noty = xor i32 %y, -1 + // %cmp2 = icmp slt %x, %noty + // %res = select i1 %cmp2, i32 %x, i32 %noty + // + // Same is applicable for smin/umax/umin. + { + Value *LHS, *RHS; + SelectPatternFlavor SPF = matchSelectPattern(Op0, LHS, RHS).Flavor; + if (Op0->hasOneUse() && SelectPatternResult::isMinOrMax(SPF) && + match(Op1, m_AllOnes())) { + + Value *X; + if (match(RHS, m_Not(m_Value(X)))) + std::swap(RHS, LHS); + + if (match(LHS, m_Not(m_Value(X)))) { + Value *NotY = Builder.CreateNot(RHS); + return SelectInst::Create( + Builder.CreateICmp(getInverseMinMaxPred(SPF), X, NotY), X, NotY); + } + } } - return Changed ? 
&I : nullptr; + return nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp index 40e52ee755e5..cbfbd8a53993 100644 --- a/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -24,6 +24,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -57,7 +58,6 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SimplifyLibCalls.h" #include <algorithm> #include <cassert> @@ -73,11 +73,11 @@ using namespace PatternMatch; STATISTIC(NumSimplified, "Number of library calls simplified"); -static cl::opt<unsigned> UnfoldElementAtomicMemcpyMaxElements( - "unfold-element-atomic-memcpy-max-elements", - cl::init(16), - cl::desc("Maximum number of elements in atomic memcpy the optimizer is " - "allowed to unfold")); +static cl::opt<unsigned> GuardWideningWindow( + "instcombine-guard-widening-window", + cl::init(3), + cl::desc("How wide an instruction window to bypass looking for " + "another guard")); /// Return the specified type promoted as it would be to pass though a va_arg /// area. @@ -106,97 +106,24 @@ static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) { return ConstantVector::get(BoolVec); } -Instruction * -InstCombiner::SimplifyElementUnorderedAtomicMemCpy(AtomicMemCpyInst *AMI) { - // Try to unfold this intrinsic into sequence of explicit atomic loads and - // stores. - // First check that number of elements is compile time constant. - auto *LengthCI = dyn_cast<ConstantInt>(AMI->getLength()); - if (!LengthCI) - return nullptr; - - // Check that there are not too many elements. - uint64_t LengthInBytes = LengthCI->getZExtValue(); - uint32_t ElementSizeInBytes = AMI->getElementSizeInBytes(); - uint64_t NumElements = LengthInBytes / ElementSizeInBytes; - if (NumElements >= UnfoldElementAtomicMemcpyMaxElements) - return nullptr; - - // Only expand if there are elements to copy. - if (NumElements > 0) { - // Don't unfold into illegal integers - uint64_t ElementSizeInBits = ElementSizeInBytes * 8; - if (!getDataLayout().isLegalInteger(ElementSizeInBits)) - return nullptr; - - // Cast source and destination to the correct type. Intrinsic input - // arguments are usually represented as i8*. Often operands will be - // explicitly casted to i8* and we can just strip those casts instead of - // inserting new ones. However it's easier to rely on other InstCombine - // rules which will cover trivial cases anyway. 
- Value *Src = AMI->getRawSource(); - Value *Dst = AMI->getRawDest(); - Type *ElementPointerType = - Type::getIntNPtrTy(AMI->getContext(), ElementSizeInBits, - Src->getType()->getPointerAddressSpace()); - - Value *SrcCasted = Builder.CreatePointerCast(Src, ElementPointerType, - "memcpy_unfold.src_casted"); - Value *DstCasted = Builder.CreatePointerCast(Dst, ElementPointerType, - "memcpy_unfold.dst_casted"); - - for (uint64_t i = 0; i < NumElements; ++i) { - // Get current element addresses - ConstantInt *ElementIdxCI = - ConstantInt::get(AMI->getContext(), APInt(64, i)); - Value *SrcElementAddr = - Builder.CreateGEP(SrcCasted, ElementIdxCI, "memcpy_unfold.src_addr"); - Value *DstElementAddr = - Builder.CreateGEP(DstCasted, ElementIdxCI, "memcpy_unfold.dst_addr"); - - // Load from the source. Transfer alignment information and mark load as - // unordered atomic. - LoadInst *Load = Builder.CreateLoad(SrcElementAddr, "memcpy_unfold.val"); - Load->setOrdering(AtomicOrdering::Unordered); - // We know alignment of the first element. It is also guaranteed by the - // verifier that element size is less or equal than first element - // alignment and both of this values are powers of two. This means that - // all subsequent accesses are at least element size aligned. - // TODO: We can infer better alignment but there is no evidence that this - // will matter. - Load->setAlignment(i == 0 ? AMI->getParamAlignment(1) - : ElementSizeInBytes); - Load->setDebugLoc(AMI->getDebugLoc()); - - // Store loaded value via unordered atomic store. - StoreInst *Store = Builder.CreateStore(Load, DstElementAddr); - Store->setOrdering(AtomicOrdering::Unordered); - Store->setAlignment(i == 0 ? AMI->getParamAlignment(0) - : ElementSizeInBytes); - Store->setDebugLoc(AMI->getDebugLoc()); - } +Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { + unsigned DstAlign = getKnownAlignment(MI->getRawDest(), DL, MI, &AC, &DT); + unsigned CopyDstAlign = MI->getDestAlignment(); + if (CopyDstAlign < DstAlign){ + MI->setDestAlignment(DstAlign); + return MI; } - // Set the number of elements of the copy to 0, it will be deleted on the - // next iteration. - AMI->setLength(Constant::getNullValue(LengthCI->getType())); - return AMI; -} - -Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { - unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, MI, &AC, &DT); - unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, MI, &AC, &DT); - unsigned MinAlign = std::min(DstAlign, SrcAlign); - unsigned CopyAlign = MI->getAlignment(); - - if (CopyAlign < MinAlign) { - MI->setAlignment(ConstantInt::get(MI->getAlignmentType(), MinAlign, false)); + unsigned SrcAlign = getKnownAlignment(MI->getRawSource(), DL, MI, &AC, &DT); + unsigned CopySrcAlign = MI->getSourceAlignment(); + if (CopySrcAlign < SrcAlign) { + MI->setSourceAlignment(SrcAlign); return MI; } // If MemCpyInst length is 1/2/4/8 bytes then replace memcpy with // load/store. - ConstantInt *MemOpLength = dyn_cast<ConstantInt>(MI->getArgOperand(2)); + ConstantInt *MemOpLength = dyn_cast<ConstantInt>(MI->getLength()); if (!MemOpLength) return nullptr; // Source and destination pointer types are always "i8*" for intrinsic. See @@ -222,7 +149,9 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { // If the memcpy has metadata describing the members, see if we can get the // TBAA tag describing our copy. 
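   // A plain !tbaa tag can be reused on the new load and store directly; a
   // !tbaa.struct tag is only usable when its single (offset, size, tag)
   // entry covers the whole copied range, which the checks below verify.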
MDNode *CopyMD = nullptr; - if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) { + if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa)) { + CopyMD = M; + } else if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) { if (M->getNumOperands() == 3 && M->getOperand(0) && mdconst::hasa<ConstantInt>(M->getOperand(0)) && mdconst::extract<ConstantInt>(M->getOperand(0))->isZero() && @@ -234,15 +163,11 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { CopyMD = cast<MDNode>(M->getOperand(2)); } - // If the memcpy/memmove provides better alignment info than we can - // infer, use it. - SrcAlign = std::max(SrcAlign, CopyAlign); - DstAlign = std::max(DstAlign, CopyAlign); - Value *Src = Builder.CreateBitCast(MI->getArgOperand(1), NewSrcPtrTy); Value *Dest = Builder.CreateBitCast(MI->getArgOperand(0), NewDstPtrTy); - LoadInst *L = Builder.CreateLoad(Src, MI->isVolatile()); - L->setAlignment(SrcAlign); + LoadInst *L = Builder.CreateLoad(Src); + // Alignment from the mem intrinsic will be better, so use it. + L->setAlignment(CopySrcAlign); if (CopyMD) L->setMetadata(LLVMContext::MD_tbaa, CopyMD); MDNode *LoopMemParallelMD = @@ -250,23 +175,34 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { if (LoopMemParallelMD) L->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); - StoreInst *S = Builder.CreateStore(L, Dest, MI->isVolatile()); - S->setAlignment(DstAlign); + StoreInst *S = Builder.CreateStore(L, Dest); + // Alignment from the mem intrinsic will be better, so use it. + S->setAlignment(CopyDstAlign); if (CopyMD) S->setMetadata(LLVMContext::MD_tbaa, CopyMD); if (LoopMemParallelMD) S->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); + if (auto *MT = dyn_cast<MemTransferInst>(MI)) { + // non-atomics can be volatile + L->setVolatile(MT->isVolatile()); + S->setVolatile(MT->isVolatile()); + } + if (isa<AtomicMemTransferInst>(MI)) { + // atomics have to be unordered + L->setOrdering(AtomicOrdering::Unordered); + S->setOrdering(AtomicOrdering::Unordered); + } + // Set the size of the copy to 0, it will be deleted on the next iteration. - MI->setArgOperand(2, Constant::getNullValue(MemOpLength->getType())); + MI->setLength(Constant::getNullValue(MemOpLength->getType())); return MI; } -Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { +Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) { unsigned Alignment = getKnownAlignment(MI->getDest(), DL, MI, &AC, &DT); - if (MI->getAlignment() < Alignment) { - MI->setAlignment(ConstantInt::get(MI->getAlignmentType(), - Alignment, false)); + if (MI->getDestAlignment() < Alignment) { + MI->setDestAlignment(Alignment); return MI; } @@ -276,7 +212,7 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { if (!LenC || !FillC || !FillC->getType()->isIntegerTy(8)) return nullptr; uint64_t Len = LenC->getLimitedValue(); - Alignment = MI->getAlignment(); + Alignment = MI->getDestAlignment(); assert(Len && "0-sized memory setting should be removed already."); // memset(s,c,n) -> store s, c (for n=1,2,4,8) @@ -296,6 +232,8 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { StoreInst *S = Builder.CreateStore(ConstantInt::get(ITy, Fill), Dest, MI->isVolatile()); S->setAlignment(Alignment); + if (isa<AtomicMemSetInst>(MI)) + S->setOrdering(AtomicOrdering::Unordered); // Set the size of the copy to 0, it will be deleted on the next iteration. 
MI->setLength(Constant::getNullValue(LenC->getType())); @@ -563,55 +501,6 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, return Builder.CreateAShr(Vec, ShiftVec); } -static Value *simplifyX86muldq(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { - Value *Arg0 = II.getArgOperand(0); - Value *Arg1 = II.getArgOperand(1); - Type *ResTy = II.getType(); - assert(Arg0->getType()->getScalarSizeInBits() == 32 && - Arg1->getType()->getScalarSizeInBits() == 32 && - ResTy->getScalarSizeInBits() == 64 && "Unexpected muldq/muludq types"); - - // muldq/muludq(undef, undef) -> zero (matches generic mul behavior) - if (isa<UndefValue>(Arg0) || isa<UndefValue>(Arg1)) - return ConstantAggregateZero::get(ResTy); - - // Constant folding. - // PMULDQ = (mul(vXi64 sext(shuffle<0,2,..>(Arg0)), - // vXi64 sext(shuffle<0,2,..>(Arg1)))) - // PMULUDQ = (mul(vXi64 zext(shuffle<0,2,..>(Arg0)), - // vXi64 zext(shuffle<0,2,..>(Arg1)))) - if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1)) - return nullptr; - - unsigned NumElts = ResTy->getVectorNumElements(); - assert(Arg0->getType()->getVectorNumElements() == (2 * NumElts) && - Arg1->getType()->getVectorNumElements() == (2 * NumElts) && - "Unexpected muldq/muludq types"); - - unsigned IntrinsicID = II.getIntrinsicID(); - bool IsSigned = (Intrinsic::x86_sse41_pmuldq == IntrinsicID || - Intrinsic::x86_avx2_pmul_dq == IntrinsicID || - Intrinsic::x86_avx512_pmul_dq_512 == IntrinsicID); - - SmallVector<unsigned, 16> ShuffleMask; - for (unsigned i = 0; i != NumElts; ++i) - ShuffleMask.push_back(i * 2); - - auto *LHS = Builder.CreateShuffleVector(Arg0, Arg0, ShuffleMask); - auto *RHS = Builder.CreateShuffleVector(Arg1, Arg1, ShuffleMask); - - if (IsSigned) { - LHS = Builder.CreateSExt(LHS, ResTy); - RHS = Builder.CreateSExt(RHS, ResTy); - } else { - LHS = Builder.CreateZExt(LHS, ResTy); - RHS = Builder.CreateZExt(RHS, ResTy); - } - - return Builder.CreateMul(LHS, RHS); -} - static Value *simplifyX86pack(IntrinsicInst &II, bool IsSigned) { Value *Arg0 = II.getArgOperand(0); Value *Arg1 = II.getArgOperand(1); @@ -687,6 +576,105 @@ static Value *simplifyX86pack(IntrinsicInst &II, bool IsSigned) { return ConstantVector::get(Vals); } +// Replace X86-specific intrinsics with generic floor-ceil where applicable. 
+static Value *simplifyX86round(IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + ConstantInt *Arg = nullptr; + Intrinsic::ID IntrinsicID = II.getIntrinsicID(); + + if (IntrinsicID == Intrinsic::x86_sse41_round_ss || + IntrinsicID == Intrinsic::x86_sse41_round_sd) + Arg = dyn_cast<ConstantInt>(II.getArgOperand(2)); + else if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) + Arg = dyn_cast<ConstantInt>(II.getArgOperand(4)); + else + Arg = dyn_cast<ConstantInt>(II.getArgOperand(1)); + if (!Arg) + return nullptr; + unsigned RoundControl = Arg->getZExtValue(); + + Arg = nullptr; + unsigned SAE = 0; + if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_512 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_512) + Arg = dyn_cast<ConstantInt>(II.getArgOperand(4)); + else if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) + Arg = dyn_cast<ConstantInt>(II.getArgOperand(5)); + else + SAE = 4; + if (!SAE) { + if (!Arg) + return nullptr; + SAE = Arg->getZExtValue(); + } + + if (SAE != 4 || (RoundControl != 2 /*ceil*/ && RoundControl != 1 /*floor*/)) + return nullptr; + + Value *Src, *Dst, *Mask; + bool IsScalar = false; + if (IntrinsicID == Intrinsic::x86_sse41_round_ss || + IntrinsicID == Intrinsic::x86_sse41_round_sd || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) { + IsScalar = true; + if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) { + Mask = II.getArgOperand(3); + Value *Zero = Constant::getNullValue(Mask->getType()); + Mask = Builder.CreateAnd(Mask, 1); + Mask = Builder.CreateICmp(ICmpInst::ICMP_NE, Mask, Zero); + Dst = II.getArgOperand(2); + } else + Dst = II.getArgOperand(0); + Src = Builder.CreateExtractElement(II.getArgOperand(1), (uint64_t)0); + } else { + Src = II.getArgOperand(0); + if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_128 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_256 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_512 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_128 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_256 || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_512) { + Dst = II.getArgOperand(2); + Mask = II.getArgOperand(3); + } else { + Dst = Src; + Mask = ConstantInt::getAllOnesValue( + Builder.getIntNTy(Src->getType()->getVectorNumElements())); + } + } + + Intrinsic::ID ID = (RoundControl == 2) ? 
Intrinsic::ceil : Intrinsic::floor; + Value *Res = Builder.CreateIntrinsic(ID, {Src}, &II); + if (!IsScalar) { + if (auto *C = dyn_cast<Constant>(Mask)) + if (C->isAllOnesValue()) + return Res; + auto *MaskTy = VectorType::get( + Builder.getInt1Ty(), cast<IntegerType>(Mask->getType())->getBitWidth()); + Mask = Builder.CreateBitCast(Mask, MaskTy); + unsigned Width = Src->getType()->getVectorNumElements(); + if (MaskTy->getVectorNumElements() > Width) { + uint32_t Indices[4]; + for (unsigned i = 0; i != Width; ++i) + Indices[i] = i; + Mask = Builder.CreateShuffleVector(Mask, Mask, + makeArrayRef(Indices, Width)); + } + return Builder.CreateSelect(Mask, Res, Dst); + } + if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss || + IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) { + Dst = Builder.CreateExtractElement(Dst, (uint64_t)0); + Res = Builder.CreateSelect(Mask, Res, Dst); + Dst = II.getArgOperand(0); + } + return Builder.CreateInsertElement(Dst, Res, (uint64_t)0); +} + static Value *simplifyX86movmsk(const IntrinsicInst &II) { Value *Arg = II.getArgOperand(0); Type *ResTy = II.getType(); @@ -1145,36 +1133,6 @@ static Value *simplifyX86vpcom(const IntrinsicInst &II, return nullptr; } -// Emit a select instruction and appropriate bitcasts to help simplify -// masked intrinsics. -static Value *emitX86MaskSelect(Value *Mask, Value *Op0, Value *Op1, - InstCombiner::BuilderTy &Builder) { - unsigned VWidth = Op0->getType()->getVectorNumElements(); - - // If the mask is all ones we don't need the select. But we need to check - // only the bit thats will be used in case VWidth is less than 8. - if (auto *C = dyn_cast<ConstantInt>(Mask)) - if (C->getValue().zextOrTrunc(VWidth).isAllOnesValue()) - return Op0; - - auto *MaskTy = VectorType::get(Builder.getInt1Ty(), - cast<IntegerType>(Mask->getType())->getBitWidth()); - Mask = Builder.CreateBitCast(Mask, MaskTy); - - // If we have less than 8 elements, then the starting mask was an i8 and - // we need to extract down to the right number of elements. - if (VWidth < 8) { - uint32_t Indices[4]; - for (unsigned i = 0; i != VWidth; ++i) - Indices[i] = i; - Mask = Builder.CreateShuffleVector(Mask, Mask, - makeArrayRef(Indices, VWidth), - "extract"); - } - - return Builder.CreateSelect(Mask, Op0, Op1); -} - static Value *simplifyMinnumMaxnum(const IntrinsicInst &II) { Value *Arg0 = II.getArgOperand(0); Value *Arg1 = II.getArgOperand(1); @@ -1308,6 +1266,40 @@ static Instruction *simplifyMaskedGather(IntrinsicInst &II, InstCombiner &IC) { return nullptr; } +/// This function transforms launder.invariant.group and strip.invariant.group +/// like: +/// launder(launder(%x)) -> launder(%x) (the result is not the argument) +/// launder(strip(%x)) -> launder(%x) +/// strip(strip(%x)) -> strip(%x) (the result is not the argument) +/// strip(launder(%x)) -> strip(%x) +/// This is legal because it preserves the most recent information about +/// the presence or absence of invariant.group. +static Instruction *simplifyInvariantGroupIntrinsic(IntrinsicInst &II, + InstCombiner &IC) { + auto *Arg = II.getArgOperand(0); + auto *StrippedArg = Arg->stripPointerCasts(); + auto *StrippedInvariantGroupsArg = Arg->stripPointerCastsAndInvariantGroups(); + if (StrippedArg == StrippedInvariantGroupsArg) + return nullptr; // No launders/strips to remove. 
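+  // Rebuild a single launder/strip of the fully stripped pointer; any
+  // address-space or pointer-type mismatch is patched up with casts below.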
+ + Value *Result = nullptr; + + if (II.getIntrinsicID() == Intrinsic::launder_invariant_group) + Result = IC.Builder.CreateLaunderInvariantGroup(StrippedInvariantGroupsArg); + else if (II.getIntrinsicID() == Intrinsic::strip_invariant_group) + Result = IC.Builder.CreateStripInvariantGroup(StrippedInvariantGroupsArg); + else + llvm_unreachable( + "simplifyInvariantGroupIntrinsic only handles launder and strip"); + if (Result->getType()->getPointerAddressSpace() != + II.getType()->getPointerAddressSpace()) + Result = IC.Builder.CreateAddrSpaceCast(Result, II.getType()); + if (Result->getType() != II.getType()) + Result = IC.Builder.CreateBitCast(Result, II.getType()); + + return cast<Instruction>(Result); +} + static Instruction *simplifyMaskedScatter(IntrinsicInst &II, InstCombiner &IC) { // If the mask is all zeros, a scatter does nothing. auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); @@ -1498,6 +1490,68 @@ static APFloat fmed3AMDGCN(const APFloat &Src0, const APFloat &Src1, return maxnum(Src0, Src1); } +/// Convert a table lookup to shufflevector if the mask is constant. +/// This could benefit tbl1 if the mask is { 7,6,5,4,3,2,1,0 }, in +/// which case we could lower the shufflevector with rev64 instructions +/// as it's actually a byte reverse. +static Value *simplifyNeonTbl1(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + // Bail out if the mask is not a constant. + auto *C = dyn_cast<Constant>(II.getArgOperand(1)); + if (!C) + return nullptr; + + auto *VecTy = cast<VectorType>(II.getType()); + unsigned NumElts = VecTy->getNumElements(); + + // Only perform this transformation for <8 x i8> vector types. + if (!VecTy->getElementType()->isIntegerTy(8) || NumElts != 8) + return nullptr; + + uint32_t Indexes[8]; + + for (unsigned I = 0; I < NumElts; ++I) { + Constant *COp = C->getAggregateElement(I); + + if (!COp || !isa<ConstantInt>(COp)) + return nullptr; + + Indexes[I] = cast<ConstantInt>(COp)->getLimitedValue(); + + // Make sure the mask indices are in range. + if (Indexes[I] >= NumElts) + return nullptr; + } + + auto *ShuffleMask = ConstantDataVector::get(II.getContext(), + makeArrayRef(Indexes)); + auto *V1 = II.getArgOperand(0); + auto *V2 = Constant::getNullValue(V1->getType()); + return Builder.CreateShuffleVector(V1, V2, ShuffleMask); +} + +/// Convert a vector load intrinsic into a simple llvm load instruction. +/// This is beneficial when the underlying object being addressed comes +/// from a constant, since we get constant-folding for free. +static Value *simplifyNeonVld1(const IntrinsicInst &II, + unsigned MemAlign, + InstCombiner::BuilderTy &Builder) { + auto *IntrAlign = dyn_cast<ConstantInt>(II.getArgOperand(1)); + + if (!IntrAlign) + return nullptr; + + unsigned Alignment = IntrAlign->getLimitedValue() < MemAlign ? + MemAlign : IntrAlign->getLimitedValue(); + + if (!isPowerOf2_32(Alignment)) + return nullptr; + + auto *BCastInst = Builder.CreateBitCast(II.getArgOperand(0), + PointerType::get(II.getType(), 0)); + return Builder.CreateAlignedLoad(BCastInst, Alignment); +} + // Returns true iff the 2 intrinsics have the same operands, limiting the // comparison to the first NumOperands. static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E, @@ -1820,7 +1874,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Intrinsics cannot occur in an invoke, so handle them here instead of in // visitCallSite. 
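   // AnyMemIntrinsic covers both the plain mem* intrinsics and their
   // element-wise unordered-atomic counterparts, so the folds below apply
   // uniformly to both.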
- if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(II)) { + if (auto *MI = dyn_cast<AnyMemIntrinsic>(II)) { bool Changed = false; // memmove/cpy/set of zero bytes is a noop. @@ -1837,17 +1891,21 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } // No other transformations apply to volatile transfers. - if (MI->isVolatile()) - return nullptr; + if (auto *M = dyn_cast<MemIntrinsic>(MI)) + if (M->isVolatile()) + return nullptr; // If we have a memmove and the source operation is a constant global, // then the source and dest pointers can't alias, so we can change this // into a call to memcpy. - if (MemMoveInst *MMI = dyn_cast<MemMoveInst>(MI)) { + if (auto *MMI = dyn_cast<AnyMemMoveInst>(MI)) { if (GlobalVariable *GVSrc = dyn_cast<GlobalVariable>(MMI->getSource())) if (GVSrc->isConstant()) { Module *M = CI.getModule(); - Intrinsic::ID MemCpyID = Intrinsic::memcpy; + Intrinsic::ID MemCpyID = + isa<AtomicMemMoveInst>(MMI) + ? Intrinsic::memcpy_element_unordered_atomic + : Intrinsic::memcpy; Type *Tys[3] = { CI.getArgOperand(0)->getType(), CI.getArgOperand(1)->getType(), CI.getArgOperand(2)->getType() }; @@ -1856,7 +1914,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } } - if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { + if (AnyMemTransferInst *MTI = dyn_cast<AnyMemTransferInst>(MI)) { // memmove(x,x,size) -> noop. if (MTI->getSource() == MTI->getDest()) return eraseInstFromFunction(CI); @@ -1864,26 +1922,17 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // If we can determine a pointer alignment that is bigger than currently // set, update the alignment. - if (isa<MemTransferInst>(MI)) { - if (Instruction *I = SimplifyMemTransfer(MI)) + if (auto *MTI = dyn_cast<AnyMemTransferInst>(MI)) { + if (Instruction *I = SimplifyAnyMemTransfer(MTI)) return I; - } else if (MemSetInst *MSI = dyn_cast<MemSetInst>(MI)) { - if (Instruction *I = SimplifyMemSet(MSI)) + } else if (auto *MSI = dyn_cast<AnyMemSetInst>(MI)) { + if (Instruction *I = SimplifyAnyMemSet(MSI)) return I; } if (Changed) return II; } - if (auto *AMI = dyn_cast<AtomicMemCpyInst>(II)) { - if (Constant *C = dyn_cast<Constant>(AMI->getLength())) - if (C->isNullValue()) - return eraseInstFromFunction(*AMI); - - if (Instruction *I = SimplifyElementUnorderedAtomicMemCpy(AMI)) - return I; - } - if (Instruction *I = SimplifyNVVMIntrinsic(II, *this)) return I; @@ -1925,7 +1974,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return simplifyMaskedGather(*II, *this); case Intrinsic::masked_scatter: return simplifyMaskedScatter(*II, *this); - + case Intrinsic::launder_invariant_group: + case Intrinsic::strip_invariant_group: + if (auto *SkippedBarrier = simplifyInvariantGroupIntrinsic(*II, *this)) + return replaceInstUsesWith(*II, SkippedBarrier); + break; case Intrinsic::powi: if (ConstantInt *Power = dyn_cast<ConstantInt>(II->getArgOperand(1))) { // 0 and 1 are handled in instsimplify @@ -1991,8 +2044,24 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { II->setArgOperand(1, Arg0); return II; } + + // FIXME: Simplifications should be in instsimplify. 
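+    // The negation fold below uses minnum(-X, -Y) == -maxnum(X, Y);
+    // e.g. minnum(-3.0, -1.0) == -3.0 == -maxnum(3.0, 1.0). fneg never
+    // creates or removes a NaN, so the NaN semantics are unchanged.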
if (Value *V = simplifyMinnumMaxnum(*II)) return replaceInstUsesWith(*II, V); + + Value *X, *Y; + if (match(Arg0, m_FNeg(m_Value(X))) && match(Arg1, m_FNeg(m_Value(Y))) && + (Arg0->hasOneUse() || Arg1->hasOneUse())) { + // If both operands are negated, invert the call and negate the result: + // minnum(-X, -Y) --> -(maxnum(X, Y)) + // maxnum(-X, -Y) --> -(minnum(X, Y)) + Intrinsic::ID NewIID = II->getIntrinsicID() == Intrinsic::maxnum ? + Intrinsic::minnum : Intrinsic::maxnum; + Value *NewCall = Builder.CreateIntrinsic(NewIID, { X, Y }, II); + Instruction *FNeg = BinaryOperator::CreateFNeg(NewCall); + FNeg->copyIRFlags(II); + return FNeg; + } break; } case Intrinsic::fmuladd: { @@ -2013,37 +2082,34 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *Src0 = II->getArgOperand(0); Value *Src1 = II->getArgOperand(1); - // Canonicalize constants into the RHS. + // Canonicalize constant multiply operand to Src1. if (isa<Constant>(Src0) && !isa<Constant>(Src1)) { II->setArgOperand(0, Src1); II->setArgOperand(1, Src0); std::swap(Src0, Src1); } - Value *LHS = nullptr; - Value *RHS = nullptr; - // fma fneg(x), fneg(y), z -> fma x, y, z - if (match(Src0, m_FNeg(m_Value(LHS))) && - match(Src1, m_FNeg(m_Value(RHS)))) { - II->setArgOperand(0, LHS); - II->setArgOperand(1, RHS); + Value *X, *Y; + if (match(Src0, m_FNeg(m_Value(X))) && match(Src1, m_FNeg(m_Value(Y)))) { + II->setArgOperand(0, X); + II->setArgOperand(1, Y); return II; } // fma fabs(x), fabs(x), z -> fma x, x, z - if (match(Src0, m_Intrinsic<Intrinsic::fabs>(m_Value(LHS))) && - match(Src1, m_Intrinsic<Intrinsic::fabs>(m_Value(RHS))) && LHS == RHS) { - II->setArgOperand(0, LHS); - II->setArgOperand(1, RHS); + if (match(Src0, m_FAbs(m_Value(X))) && + match(Src1, m_FAbs(m_Specific(X)))) { + II->setArgOperand(0, X); + II->setArgOperand(1, X); return II; } // fma x, 1, z -> fadd x, z if (match(Src1, m_FPOne())) { - Instruction *RI = BinaryOperator::CreateFAdd(Src0, II->getArgOperand(2)); - RI->copyFastMathFlags(II); - return RI; + auto *FAdd = BinaryOperator::CreateFAdd(Src0, II->getArgOperand(2)); + FAdd->copyFastMathFlags(II); + return FAdd; } break; @@ -2067,17 +2133,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::rint: case Intrinsic::trunc: { Value *ExtSrc; - if (match(II->getArgOperand(0), m_FPExt(m_Value(ExtSrc))) && - II->getArgOperand(0)->hasOneUse()) { - // fabs (fpext x) -> fpext (fabs x) - Value *F = Intrinsic::getDeclaration(II->getModule(), II->getIntrinsicID(), - { ExtSrc->getType() }); - CallInst *NewFabs = Builder.CreateCall(F, ExtSrc); - NewFabs->copyFastMathFlags(II); - NewFabs->takeName(II); - return new FPExtInst(NewFabs, II->getType()); + if (match(II->getArgOperand(0), m_OneUse(m_FPExt(m_Value(ExtSrc))))) { + // Narrow the call: intrinsic (fpext x) -> fpext (intrinsic x) + Value *NarrowII = Builder.CreateIntrinsic(II->getIntrinsicID(), + { ExtSrc }, II); + return new FPExtInst(NarrowII, II->getType()); } - break; } case Intrinsic::cos: @@ -2085,7 +2146,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *SrcSrc; Value *Src = II->getArgOperand(0); if (match(Src, m_FNeg(m_Value(SrcSrc))) || - match(Src, m_Intrinsic<Intrinsic::fabs>(m_Value(SrcSrc)))) { + match(Src, m_FAbs(m_Value(SrcSrc)))) { // cos(-x) -> cos(x) // cos(fabs(x)) -> cos(x) II->setArgOperand(0, SrcSrc); @@ -2298,6 +2359,22 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::x86_sse41_round_ps: + case Intrinsic::x86_sse41_round_pd: + case 
Intrinsic::x86_avx_round_ps_256: + case Intrinsic::x86_avx_round_pd_256: + case Intrinsic::x86_avx512_mask_rndscale_ps_128: + case Intrinsic::x86_avx512_mask_rndscale_ps_256: + case Intrinsic::x86_avx512_mask_rndscale_ps_512: + case Intrinsic::x86_avx512_mask_rndscale_pd_128: + case Intrinsic::x86_avx512_mask_rndscale_pd_256: + case Intrinsic::x86_avx512_mask_rndscale_pd_512: + case Intrinsic::x86_avx512_mask_rndscale_ss: + case Intrinsic::x86_avx512_mask_rndscale_sd: + if (Value *V = simplifyX86round(*II, Builder)) + return replaceInstUsesWith(*II, V); + break; + case Intrinsic::x86_mmx_pmovmskb: case Intrinsic::x86_sse_movmsk_ps: case Intrinsic::x86_sse2_movmsk_pd: @@ -2355,16 +2432,16 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return II; break; } - case Intrinsic::x86_avx512_mask_cmp_pd_128: - case Intrinsic::x86_avx512_mask_cmp_pd_256: - case Intrinsic::x86_avx512_mask_cmp_pd_512: - case Intrinsic::x86_avx512_mask_cmp_ps_128: - case Intrinsic::x86_avx512_mask_cmp_ps_256: - case Intrinsic::x86_avx512_mask_cmp_ps_512: { + case Intrinsic::x86_avx512_cmp_pd_128: + case Intrinsic::x86_avx512_cmp_pd_256: + case Intrinsic::x86_avx512_cmp_pd_512: + case Intrinsic::x86_avx512_cmp_ps_128: + case Intrinsic::x86_avx512_cmp_ps_256: + case Intrinsic::x86_avx512_cmp_ps_512: { // Folding cmp(sub(a,b),0) -> cmp(a,b) and cmp(0,sub(a,b)) -> cmp(b,a) Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); - bool Arg0IsZero = match(Arg0, m_Zero()); + bool Arg0IsZero = match(Arg0, m_PosZeroFP()); if (Arg0IsZero) std::swap(Arg0, Arg1); Value *A, *B; @@ -2376,7 +2453,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // The compare intrinsic uses the above assumptions and therefore // doesn't require additional flags. if ((match(Arg0, m_OneUse(m_FSub(m_Value(A), m_Value(B)))) && - match(Arg1, m_Zero()) && isa<Instruction>(Arg0) && + match(Arg1, m_PosZeroFP()) && isa<Instruction>(Arg0) && cast<Instruction>(Arg0)->getFastMathFlags().noInfs())) { if (Arg0IsZero) std::swap(A, B); @@ -2387,17 +2464,17 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } - case Intrinsic::x86_avx512_mask_add_ps_512: - case Intrinsic::x86_avx512_mask_div_ps_512: - case Intrinsic::x86_avx512_mask_mul_ps_512: - case Intrinsic::x86_avx512_mask_sub_ps_512: - case Intrinsic::x86_avx512_mask_add_pd_512: - case Intrinsic::x86_avx512_mask_div_pd_512: - case Intrinsic::x86_avx512_mask_mul_pd_512: - case Intrinsic::x86_avx512_mask_sub_pd_512: + case Intrinsic::x86_avx512_add_ps_512: + case Intrinsic::x86_avx512_div_ps_512: + case Intrinsic::x86_avx512_mul_ps_512: + case Intrinsic::x86_avx512_sub_ps_512: + case Intrinsic::x86_avx512_add_pd_512: + case Intrinsic::x86_avx512_div_pd_512: + case Intrinsic::x86_avx512_mul_pd_512: + case Intrinsic::x86_avx512_sub_pd_512: // If the rounding mode is CUR_DIRECTION(4) we can turn these into regular // IR operations. 
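     // The rounding operand is now at index 2; the unmasked avx512 intrinsics
     // no longer take the passthru and mask operands that the mask_* forms did.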
- if (auto *R = dyn_cast<ConstantInt>(II->getArgOperand(4))) { + if (auto *R = dyn_cast<ConstantInt>(II->getArgOperand(2))) { if (R->getValue() == 4) { Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); @@ -2405,27 +2482,24 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *V; switch (II->getIntrinsicID()) { default: llvm_unreachable("Case stmts out of sync!"); - case Intrinsic::x86_avx512_mask_add_ps_512: - case Intrinsic::x86_avx512_mask_add_pd_512: + case Intrinsic::x86_avx512_add_ps_512: + case Intrinsic::x86_avx512_add_pd_512: V = Builder.CreateFAdd(Arg0, Arg1); break; - case Intrinsic::x86_avx512_mask_sub_ps_512: - case Intrinsic::x86_avx512_mask_sub_pd_512: + case Intrinsic::x86_avx512_sub_ps_512: + case Intrinsic::x86_avx512_sub_pd_512: V = Builder.CreateFSub(Arg0, Arg1); break; - case Intrinsic::x86_avx512_mask_mul_ps_512: - case Intrinsic::x86_avx512_mask_mul_pd_512: + case Intrinsic::x86_avx512_mul_ps_512: + case Intrinsic::x86_avx512_mul_pd_512: V = Builder.CreateFMul(Arg0, Arg1); break; - case Intrinsic::x86_avx512_mask_div_ps_512: - case Intrinsic::x86_avx512_mask_div_pd_512: + case Intrinsic::x86_avx512_div_ps_512: + case Intrinsic::x86_avx512_div_pd_512: V = Builder.CreateFDiv(Arg0, Arg1); break; } - // Create a select for the masking. - V = emitX86MaskSelect(II->getArgOperand(3), V, II->getArgOperand(2), - Builder); return replaceInstUsesWith(*II, V); } } @@ -2499,32 +2573,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx512_mask_min_ss_round: case Intrinsic::x86_avx512_mask_max_sd_round: case Intrinsic::x86_avx512_mask_min_sd_round: - case Intrinsic::x86_avx512_mask_vfmadd_ss: - case Intrinsic::x86_avx512_mask_vfmadd_sd: - case Intrinsic::x86_avx512_maskz_vfmadd_ss: - case Intrinsic::x86_avx512_maskz_vfmadd_sd: - case Intrinsic::x86_avx512_mask3_vfmadd_ss: - case Intrinsic::x86_avx512_mask3_vfmadd_sd: - case Intrinsic::x86_avx512_mask3_vfmsub_ss: - case Intrinsic::x86_avx512_mask3_vfmsub_sd: - case Intrinsic::x86_avx512_mask3_vfnmsub_ss: - case Intrinsic::x86_avx512_mask3_vfnmsub_sd: - case Intrinsic::x86_fma_vfmadd_ss: - case Intrinsic::x86_fma_vfmsub_ss: - case Intrinsic::x86_fma_vfnmadd_ss: - case Intrinsic::x86_fma_vfnmsub_ss: - case Intrinsic::x86_fma_vfmadd_sd: - case Intrinsic::x86_fma_vfmsub_sd: - case Intrinsic::x86_fma_vfnmadd_sd: - case Intrinsic::x86_fma_vfnmsub_sd: case Intrinsic::x86_sse_cmp_ss: case Intrinsic::x86_sse_min_ss: case Intrinsic::x86_sse_max_ss: case Intrinsic::x86_sse2_cmp_sd: case Intrinsic::x86_sse2_min_sd: case Intrinsic::x86_sse2_max_sd: - case Intrinsic::x86_sse41_round_ss: - case Intrinsic::x86_sse41_round_sd: case Intrinsic::x86_xop_vfrcz_ss: case Intrinsic::x86_xop_vfrcz_sd: { unsigned VWidth = II->getType()->getVectorNumElements(); @@ -2537,6 +2591,19 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; } + case Intrinsic::x86_sse41_round_ss: + case Intrinsic::x86_sse41_round_sd: { + unsigned VWidth = II->getType()->getVectorNumElements(); + APInt UndefElts(VWidth, 0); + APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) { + if (V != II) + return replaceInstUsesWith(*II, V); + return II; + } else if (Value *V = simplifyX86round(*II, Builder)) + return replaceInstUsesWith(*II, V); + break; + } // Constant fold ashr( <A x Bi>, Ci ). // Constant fold lshr( <A x Bi>, Ci ). 
@@ -2647,26 +2714,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, V); break; - case Intrinsic::x86_sse2_pmulu_dq: - case Intrinsic::x86_sse41_pmuldq: - case Intrinsic::x86_avx2_pmul_dq: - case Intrinsic::x86_avx2_pmulu_dq: - case Intrinsic::x86_avx512_pmul_dq_512: - case Intrinsic::x86_avx512_pmulu_dq_512: { - if (Value *V = simplifyX86muldq(*II, Builder)) - return replaceInstUsesWith(*II, V); - - unsigned VWidth = II->getType()->getVectorNumElements(); - APInt UndefElts(VWidth, 0); - APInt DemandedElts = APInt::getAllOnesValue(VWidth); - if (Value *V = SimplifyDemandedVectorElts(II, DemandedElts, UndefElts)) { - if (V != II) - return replaceInstUsesWith(*II, V); - return II; - } - break; - } - case Intrinsic::x86_sse2_packssdw_128: case Intrinsic::x86_sse2_packsswb_128: case Intrinsic::x86_avx2_packssdw: @@ -2687,7 +2734,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, V); break; - case Intrinsic::x86_pclmulqdq: { + case Intrinsic::x86_pclmulqdq: + case Intrinsic::x86_pclmulqdq_256: + case Intrinsic::x86_pclmulqdq_512: { if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(2))) { unsigned Imm = C->getZExtValue(); @@ -2695,27 +2744,28 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); unsigned VWidth = Arg0->getType()->getVectorNumElements(); - APInt DemandedElts(VWidth, 0); APInt UndefElts1(VWidth, 0); - DemandedElts = (Imm & 0x01) ? 2 : 1; - if (Value *V = SimplifyDemandedVectorElts(Arg0, DemandedElts, + APInt DemandedElts1 = APInt::getSplat(VWidth, + APInt(2, (Imm & 0x01) ? 2 : 1)); + if (Value *V = SimplifyDemandedVectorElts(Arg0, DemandedElts1, UndefElts1)) { II->setArgOperand(0, V); MadeChange = true; } APInt UndefElts2(VWidth, 0); - DemandedElts = (Imm & 0x10) ? 2 : 1; - if (Value *V = SimplifyDemandedVectorElts(Arg1, DemandedElts, + APInt DemandedElts2 = APInt::getSplat(VWidth, + APInt(2, (Imm & 0x10) ? 2 : 1)); + if (Value *V = SimplifyDemandedVectorElts(Arg1, DemandedElts2, UndefElts2)) { II->setArgOperand(1, V); MadeChange = true; } - // If both input elements are undef, the result is undef. - if (UndefElts1[(Imm & 0x01) ? 1 : 0] || - UndefElts2[(Imm & 0x10) ? 1 : 0]) + // If either input elements are undef, the result is zero. 
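+      // (Zero rather than undef: an undef lane may be chosen to be zero,
+      //  and a carry-less multiply with a zero operand is always zero.)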
+ if (DemandedElts1.isSubsetOf(UndefElts1) || + DemandedElts2.isSubsetOf(UndefElts2)) return replaceInstUsesWith(*II, ConstantAggregateZero::get(II->getType())); @@ -2916,32 +2966,22 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx2_permd: case Intrinsic::x86_avx2_permps: + case Intrinsic::x86_avx512_permvar_df_256: + case Intrinsic::x86_avx512_permvar_df_512: + case Intrinsic::x86_avx512_permvar_di_256: + case Intrinsic::x86_avx512_permvar_di_512: + case Intrinsic::x86_avx512_permvar_hi_128: + case Intrinsic::x86_avx512_permvar_hi_256: + case Intrinsic::x86_avx512_permvar_hi_512: + case Intrinsic::x86_avx512_permvar_qi_128: + case Intrinsic::x86_avx512_permvar_qi_256: + case Intrinsic::x86_avx512_permvar_qi_512: + case Intrinsic::x86_avx512_permvar_sf_512: + case Intrinsic::x86_avx512_permvar_si_512: if (Value *V = simplifyX86vpermv(*II, Builder)) return replaceInstUsesWith(*II, V); break; - case Intrinsic::x86_avx512_mask_permvar_df_256: - case Intrinsic::x86_avx512_mask_permvar_df_512: - case Intrinsic::x86_avx512_mask_permvar_di_256: - case Intrinsic::x86_avx512_mask_permvar_di_512: - case Intrinsic::x86_avx512_mask_permvar_hi_128: - case Intrinsic::x86_avx512_mask_permvar_hi_256: - case Intrinsic::x86_avx512_mask_permvar_hi_512: - case Intrinsic::x86_avx512_mask_permvar_qi_128: - case Intrinsic::x86_avx512_mask_permvar_qi_256: - case Intrinsic::x86_avx512_mask_permvar_qi_512: - case Intrinsic::x86_avx512_mask_permvar_sf_256: - case Intrinsic::x86_avx512_mask_permvar_sf_512: - case Intrinsic::x86_avx512_mask_permvar_si_256: - case Intrinsic::x86_avx512_mask_permvar_si_512: - if (Value *V = simplifyX86vpermv(*II, Builder)) { - // We simplified the permuting, now create a select for the masking. - V = emitX86MaskSelect(II->getArgOperand(3), V, II->getArgOperand(2), - Builder); - return replaceInstUsesWith(*II, V); - } - break; - case Intrinsic::x86_avx_maskload_ps: case Intrinsic::x86_avx_maskload_pd: case Intrinsic::x86_avx_maskload_ps_256: @@ -3042,7 +3082,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; - case Intrinsic::arm_neon_vld1: + case Intrinsic::arm_neon_vld1: { + unsigned MemAlign = getKnownAlignment(II->getArgOperand(0), + DL, II, &AC, &DT); + if (Value *V = simplifyNeonVld1(*II, MemAlign, Builder)) + return replaceInstUsesWith(*II, V); + break; + } + case Intrinsic::arm_neon_vld2: case Intrinsic::arm_neon_vld3: case Intrinsic::arm_neon_vld4: @@ -3069,6 +3116,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::arm_neon_vtbl1: + case Intrinsic::aarch64_neon_tbl1: + if (Value *V = simplifyNeonTbl1(*II, Builder)) + return replaceInstUsesWith(*II, V); + break; + case Intrinsic::arm_neon_vmulls: case Intrinsic::arm_neon_vmullu: case Intrinsic::aarch64_neon_smull: @@ -3107,6 +3160,23 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::arm_neon_aesd: + case Intrinsic::arm_neon_aese: + case Intrinsic::aarch64_crypto_aesd: + case Intrinsic::aarch64_crypto_aese: { + Value *DataArg = II->getArgOperand(0); + Value *KeyArg = II->getArgOperand(1); + + // Try to use the builtin XOR in AESE and AESD to eliminate a prior XOR + Value *Data, *Key; + if (match(KeyArg, m_ZeroInt()) && + match(DataArg, m_Xor(m_Value(Data), m_Value(Key)))) { + II->setArgOperand(0, Data); + II->setArgOperand(1, Key); + return II; + } + break; + } case Intrinsic::amdgcn_rcp: { Value *Src = II->getArgOperand(0); @@ -3264,6 +3334,18 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) 
{ break; } + case Intrinsic::amdgcn_cvt_pknorm_i16: + case Intrinsic::amdgcn_cvt_pknorm_u16: + case Intrinsic::amdgcn_cvt_pk_i16: + case Intrinsic::amdgcn_cvt_pk_u16: { + Value *Src0 = II->getArgOperand(0); + Value *Src1 = II->getArgOperand(1); + + if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1)) + return replaceInstUsesWith(*II, UndefValue::get(II->getType())); + + break; + } case Intrinsic::amdgcn_ubfe: case Intrinsic::amdgcn_sbfe: { // Decompose simple cases into standard shifts. @@ -3370,6 +3452,24 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *Src1 = II->getArgOperand(1); Value *Src2 = II->getArgOperand(2); + // Checking for NaN before canonicalization provides better fidelity when + // mapping other operations onto fmed3 since the order of operands is + // unchanged. + CallInst *NewCall = nullptr; + if (match(Src0, m_NaN()) || isa<UndefValue>(Src0)) { + NewCall = Builder.CreateMinNum(Src1, Src2); + } else if (match(Src1, m_NaN()) || isa<UndefValue>(Src1)) { + NewCall = Builder.CreateMinNum(Src0, Src2); + } else if (match(Src2, m_NaN()) || isa<UndefValue>(Src2)) { + NewCall = Builder.CreateMaxNum(Src0, Src1); + } + + if (NewCall) { + NewCall->copyFastMathFlags(II); + NewCall->takeName(II); + return replaceInstUsesWith(*II, NewCall); + } + bool Swap = false; // Canonicalize constants to RHS operands. // @@ -3396,13 +3496,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return II; } - if (match(Src2, m_NaN()) || isa<UndefValue>(Src2)) { - CallInst *NewCall = Builder.CreateMinNum(Src0, Src1); - NewCall->copyFastMathFlags(II); - NewCall->takeName(II); - return replaceInstUsesWith(*II, NewCall); - } - if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) { if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) { if (const ConstantFP *C2 = dyn_cast<ConstantFP>(Src2)) { @@ -3536,13 +3629,32 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // amdgcn.kill(i1 1) is a no-op return eraseInstFromFunction(CI); } + case Intrinsic::amdgcn_update_dpp: { + Value *Old = II->getArgOperand(0); + + auto BC = dyn_cast<ConstantInt>(II->getArgOperand(5)); + auto RM = dyn_cast<ConstantInt>(II->getArgOperand(3)); + auto BM = dyn_cast<ConstantInt>(II->getArgOperand(4)); + if (!BC || !RM || !BM || + BC->isZeroValue() || + RM->getZExtValue() != 0xF || + BM->getZExtValue() != 0xF || + isa<UndefValue>(Old)) + break; + + // If bound_ctrl = 1, row mask = bank mask = 0xf we can omit old value. + II->setOperand(0, UndefValue::get(Old->getType())); + return II; + } case Intrinsic::stackrestore: { // If the save is right next to the restore, remove the restore. This can // happen when variable allocas are DCE'd. if (IntrinsicInst *SS = dyn_cast<IntrinsicInst>(II->getArgOperand(0))) { if (SS->getIntrinsicID() == Intrinsic::stacksave) { - if (&*++SS->getIterator() == II) + // Skip over debug info. + if (SS->getNextNonDebugInstruction() == II) { return eraseInstFromFunction(CI); + } } } @@ -3597,9 +3709,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; case Intrinsic::assume: { Value *IIOperand = II->getArgOperand(0); - // Remove an assume if it is immediately followed by an identical assume. - if (match(II->getNextNode(), - m_Intrinsic<Intrinsic::assume>(m_Specific(IIOperand)))) + // Remove an assume if it is followed by an identical assume. + // TODO: Do we need this? Unless there are conflicting assumptions, the + // computeKnownBits(IIOperand) below here eliminates redundant assumes. 
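+    // Stepping with getNextNonDebugInstruction skips llvm.dbg.* intrinsics,
+    // so the presence of debug info does not change whether the duplicate
+    // assume is found.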
+ Instruction *Next = II->getNextNonDebugInstruction(); + if (match(Next, m_Intrinsic<Intrinsic::assume>(m_Specific(IIOperand)))) return eraseInstFromFunction(CI); // Canonicalize assume(a && b) -> assume(a); assume(b); @@ -3686,8 +3800,16 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } case Intrinsic::experimental_guard: { - // Is this guard followed by another guard? + // Is this guard followed by another guard? We scan forward over a small + // fixed window of instructions to handle common cases with conditions + // computed between guards. Instruction *NextInst = II->getNextNode(); + for (unsigned i = 0; i < GuardWideningWindow; i++) { + // Note: Using context-free form to avoid compile time blow up + if (!isSafeToSpeculativelyExecute(NextInst)) + break; + NextInst = NextInst->getNextNode(); + } Value *NextCond = nullptr; if (match(NextInst, m_Intrinsic<Intrinsic::experimental_guard>(m_Value(NextCond)))) { @@ -3698,6 +3820,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return eraseInstFromFunction(*NextInst); // Otherwise canonicalize guard(a); guard(b) -> guard(a & b). + Instruction* MoveI = II->getNextNode(); + while (MoveI != NextInst) { + auto *Temp = MoveI; + MoveI = MoveI->getNextNode(); + Temp->moveBefore(II); + } II->setArgOperand(0, Builder.CreateAnd(CurrCond, NextCond)); return eraseInstFromFunction(*NextInst); } @@ -3710,7 +3838,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Fence instruction simplification Instruction *InstCombiner::visitFenceInst(FenceInst &FI) { // Remove identical consecutive fences. - if (auto *NFI = dyn_cast<FenceInst>(FI.getNextNode())) + Instruction *Next = FI.getNextNonDebugInstruction(); + if (auto *NFI = dyn_cast<FenceInst>(Next)) if (FI.isIdenticalTo(NFI)) return eraseInstFromFunction(FI); return nullptr; @@ -3887,8 +4016,8 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { // Remove the convergent attr on calls when the callee is not convergent. if (CS.isConvergent() && !CalleeF->isConvergent() && !CalleeF->isIntrinsic()) { - DEBUG(dbgs() << "Removing convergent attr from instr " - << CS.getInstruction() << "\n"); + LLVM_DEBUG(dbgs() << "Removing convergent attr from instr " + << CS.getInstruction() << "\n"); CS.setNotConvergent(); return CS.getInstruction(); } @@ -3919,7 +4048,9 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { } } - if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) { + if ((isa<ConstantPointerNull>(Callee) && + !NullPointerIsDefined(CS.getInstruction()->getFunction())) || + isa<UndefValue>(Callee)) { // If CS does not return void then replaceAllUsesWith undef. // This allows ValueHandlers and custom metadata to adjust itself. if (!CS.getInstruction()->getType()->isVoidTy()) @@ -3986,10 +4117,19 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { if (!Callee) return false; - // The prototype of a thunk is a lie. Don't directly call such a function. + // If this is a call to a thunk function, don't remove the cast. Thunks are + // used to transparently forward all incoming parameters and outgoing return + // values, so it's important to leave the cast in place. if (Callee->hasFnAttribute("thunk")) return false; + // If this is a musttail call, the callee's prototype must match the caller's + // prototype with the exception of pointee types. The code below doesn't + // implement that, so we can't do this transform. + // TODO: Do the transform if it only requires adding pointer casts. 
+ if (CS.isMustTailCall()) + return false; + Instruction *Caller = CS.getInstruction(); const AttributeList &CallerPAL = CS.getAttributes(); diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index 178c8eaf2502..e8ea7396a96a 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -16,6 +16,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DIBuilder.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" using namespace llvm; @@ -256,7 +257,7 @@ Instruction::CastOps InstCombiner::isEliminableCastPair(const CastInst *CI1, return Instruction::CastOps(Res); } -/// @brief Implement the transforms common to all CastInst visitors. +/// Implement the transforms common to all CastInst visitors. Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { Value *Src = CI.getOperand(0); @@ -265,14 +266,27 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { if (Instruction::CastOps NewOpc = isEliminableCastPair(CSrc, &CI)) { // The first cast (CSrc) is eliminable so we need to fix up or replace // the second cast (CI). CSrc will then have a good chance of being dead. - return CastInst::Create(NewOpc, CSrc->getOperand(0), CI.getType()); + auto *Ty = CI.getType(); + auto *Res = CastInst::Create(NewOpc, CSrc->getOperand(0), Ty); + // Point debug users of the dying cast to the new one. + if (CSrc->hasOneUse()) + replaceAllDbgUsesWith(*CSrc, *Res, CI, DT); + return Res; } } - // If we are casting a select, then fold the cast into the select. - if (auto *SI = dyn_cast<SelectInst>(Src)) - if (Instruction *NV = FoldOpIntoSelect(CI, SI)) - return NV; + if (auto *Sel = dyn_cast<SelectInst>(Src)) { + // We are casting a select. Try to fold the cast into the select, but only + // if the select does not have a compare instruction with matching operand + // types. Creating a select with operands that are different sizes than its + // condition may inhibit other folds and lead to worse codegen. + auto *Cmp = dyn_cast<CmpInst>(Sel->getCondition()); + if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType()) + if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) { + replaceAllDbgUsesWith(*Sel, *NV, CI, DT); + return NV; + } + } // If we are casting a PHI, then fold the cast into the PHI. if (auto *PN = dyn_cast<PHINode>(Src)) { @@ -287,6 +301,33 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { return nullptr; } +/// Constants and extensions/truncates from the destination type are always +/// free to be evaluated in that type. This is a helper for canEvaluate*. +static bool canAlwaysEvaluateInType(Value *V, Type *Ty) { + if (isa<Constant>(V)) + return true; + Value *X; + if ((match(V, m_ZExtOrSExt(m_Value(X))) || match(V, m_Trunc(m_Value(X)))) && + X->getType() == Ty) + return true; + + return false; +} + +/// Filter out values that we can not evaluate in the destination type for free. +/// This is a helper for canEvaluate*. +static bool canNotEvaluateInType(Value *V, Type *Ty) { + assert(!isa<Constant>(V) && "Constant should already be handled."); + if (!isa<Instruction>(V)) + return true; + // We don't extend or shrink something that has multiple uses -- doing so + // would require duplicating the instruction which isn't profitable. 
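+  // (The other users still need the value at its original width, so both
+  //  the original instruction and the re-typed copy would remain live.)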
+ if (!V->hasOneUse()) + return true; + + return false; +} + /// Return true if we can evaluate the specified expression tree as type Ty /// instead of its larger type, and arrive with the same value. /// This is used by code that tries to eliminate truncates. @@ -300,27 +341,14 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { /// static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, Instruction *CxtI) { - // We can always evaluate constants in another type. - if (isa<Constant>(V)) + if (canAlwaysEvaluateInType(V, Ty)) return true; + if (canNotEvaluateInType(V, Ty)) + return false; - Instruction *I = dyn_cast<Instruction>(V); - if (!I) return false; - + auto *I = cast<Instruction>(V); Type *OrigTy = V->getType(); - - // If this is an extension from the dest type, we can eliminate it, even if it - // has multiple uses. - if ((isa<ZExtInst>(I) || isa<SExtInst>(I)) && - I->getOperand(0)->getType() == Ty) - return true; - - // We can't extend or shrink something that has multiple uses: doing so would - // require duplicating the instruction in general, which isn't profitable. - if (!I->hasOneUse()) return false; - - unsigned Opc = I->getOpcode(); - switch (Opc) { + switch (I->getOpcode()) { case Instruction::Add: case Instruction::Sub: case Instruction::Mul: @@ -336,13 +364,12 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, // UDiv and URem can be truncated if all the truncated bits are zero. uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (BitWidth < OrigBitWidth) { - APInt Mask = APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth); - if (IC.MaskedValueIsZero(I->getOperand(0), Mask, 0, CxtI) && - IC.MaskedValueIsZero(I->getOperand(1), Mask, 0, CxtI)) { - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && - canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); - } + assert(BitWidth < OrigBitWidth && "Unexpected bitwidths!"); + APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); + if (IC.MaskedValueIsZero(I->getOperand(0), Mask, 0, CxtI) && + IC.MaskedValueIsZero(I->getOperand(1), Mask, 0, CxtI)) { + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); } break; } @@ -365,9 +392,9 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, if (match(I->getOperand(1), m_APInt(Amt))) { uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (IC.MaskedValueIsZero(I->getOperand(0), - APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth), 0, CxtI) && - Amt->getLimitedValue(BitWidth) < BitWidth) { + if (Amt->getLimitedValue(BitWidth) < BitWidth && + IC.MaskedValueIsZero(I->getOperand(0), + APInt::getBitsSetFrom(OrigBitWidth, BitWidth), 0, CxtI)) { return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); } } @@ -644,20 +671,6 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { if (Instruction *Result = commonCastTransforms(CI)) return Result; - // Test if the trunc is the user of a select which is part of a - // minimum or maximum operation. If so, don't do any more simplification. - // Even simplifying demanded bits can break the canonical form of a - // min/max. 
- Value *LHS, *RHS; - if (SelectInst *SI = dyn_cast<SelectInst>(CI.getOperand(0))) - if (matchSelectPattern(SI, LHS, RHS).Flavor != SPF_UNKNOWN) - return nullptr; - - // See if we can simplify any instructions used by the input whose sole - // purpose is to compute bits we don't care about. - if (SimplifyDemandedInstructionBits(CI)) - return &CI; - Value *Src = CI.getOperand(0); Type *DestTy = CI.getType(), *SrcTy = Src->getType(); @@ -670,13 +683,29 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // If this cast is a truncate, evaluting in a different type always // eliminates the cast, so it is always a win. - DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type" - " to avoid cast: " << CI << '\n'); + LLVM_DEBUG( + dbgs() << "ICE: EvaluateInDifferentType converting expression type" + " to avoid cast: " + << CI << '\n'); Value *Res = EvaluateInDifferentType(Src, DestTy, false); assert(Res->getType() == DestTy); return replaceInstUsesWith(CI, Res); } + // Test if the trunc is the user of a select which is part of a + // minimum or maximum operation. If so, don't do any more simplification. + // Even simplifying demanded bits can break the canonical form of a + // min/max. + Value *LHS, *RHS; + if (SelectInst *SI = dyn_cast<SelectInst>(CI.getOperand(0))) + if (matchSelectPattern(SI, LHS, RHS).Flavor != SPF_UNKNOWN) + return nullptr; + + // See if we can simplify any instructions used by the input whose sole + // purpose is to compute bits we don't care about. + if (SimplifyDemandedInstructionBits(CI)) + return &CI; + // Canonicalize trunc x to i1 -> (icmp ne (and x, 1), 0), likewise for vector. if (DestTy->getScalarSizeInBits() == 1) { Constant *One = ConstantInt::get(SrcTy, 1); @@ -916,23 +945,14 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, InstCombiner &IC, Instruction *CxtI) { BitsToClear = 0; - if (isa<Constant>(V)) - return true; - - Instruction *I = dyn_cast<Instruction>(V); - if (!I) return false; - - // If the input is a truncate from the destination type, we can trivially - // eliminate it. - if (isa<TruncInst>(I) && I->getOperand(0)->getType() == Ty) + if (canAlwaysEvaluateInType(V, Ty)) return true; + if (canNotEvaluateInType(V, Ty)) + return false; - // We can't extend or shrink something that has multiple uses: doing so would - // require duplicating the instruction in general, which isn't profitable. - if (!I->hasOneUse()) return false; - - unsigned Opc = I->getOpcode(), Tmp; - switch (Opc) { + auto *I = cast<Instruction>(V); + unsigned Tmp; + switch (I->getOpcode()) { case Instruction::ZExt: // zext(zext(x)) -> zext(x). case Instruction::SExt: // zext(sext(x)) -> sext(x). case Instruction::Trunc: // zext(trunc(x)) -> trunc(x) or zext(x) @@ -961,7 +981,7 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, 0, CxtI)) { // If this is an And instruction and all of the BitsToClear are // known to be zero we can reset BitsToClear. - if (Opc == Instruction::And) + if (I->getOpcode() == Instruction::And) BitsToClear = 0; return true; } @@ -1052,11 +1072,18 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { "Can't clear more bits than in SrcTy"); // Okay, we can transform this! Insert the new expression now. 
- DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type" - " to avoid zero extend: " << CI << '\n'); + LLVM_DEBUG( + dbgs() << "ICE: EvaluateInDifferentType converting expression type" + " to avoid zero extend: " + << CI << '\n'); Value *Res = EvaluateInDifferentType(Src, DestTy, false); assert(Res->getType() == DestTy); + // Preserve debug values referring to Src if the zext is its last use. + if (auto *SrcOp = dyn_cast<Instruction>(Src)) + if (SrcOp->hasOneUse()) + replaceAllDbgUsesWith(*SrcOp, *Res, CI, DT); + uint32_t SrcBitsKept = SrcTy->getScalarSizeInBits()-BitsToClear; uint32_t DestBitSize = DestTy->getScalarSizeInBits(); @@ -1168,22 +1195,19 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { if (!Op1->getType()->isIntOrIntVectorTy()) return nullptr; - if (Constant *Op1C = dyn_cast<Constant>(Op1)) { + if ((Pred == ICmpInst::ICMP_SLT && match(Op1, m_ZeroInt())) || + (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes()))) { // (x <s 0) ? -1 : 0 -> ashr x, 31 -> all ones if negative // (x >s -1) ? -1 : 0 -> not (ashr x, 31) -> all ones if positive - if ((Pred == ICmpInst::ICMP_SLT && Op1C->isNullValue()) || - (Pred == ICmpInst::ICMP_SGT && Op1C->isAllOnesValue())) { + Value *Sh = ConstantInt::get(Op0->getType(), + Op0->getType()->getScalarSizeInBits() - 1); + Value *In = Builder.CreateAShr(Op0, Sh, Op0->getName() + ".lobit"); + if (In->getType() != CI.getType()) + In = Builder.CreateIntCast(In, CI.getType(), true /*SExt*/); - Value *Sh = ConstantInt::get(Op0->getType(), - Op0->getType()->getScalarSizeInBits()-1); - Value *In = Builder.CreateAShr(Op0, Sh, Op0->getName() + ".lobit"); - if (In->getType() != CI.getType()) - In = Builder.CreateIntCast(In, CI.getType(), true /*SExt*/); - - if (Pred == ICmpInst::ICMP_SGT) - In = Builder.CreateNot(In, In->getName() + ".not"); - return replaceInstUsesWith(CI, In); - } + if (Pred == ICmpInst::ICMP_SGT) + In = Builder.CreateNot(In, In->getName() + ".not"); + return replaceInstUsesWith(CI, In); } if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { @@ -1254,21 +1278,12 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) { static bool canEvaluateSExtd(Value *V, Type *Ty) { assert(V->getType()->getScalarSizeInBits() < Ty->getScalarSizeInBits() && "Can't sign extend type to a smaller type"); - // If this is a constant, it can be trivially promoted. - if (isa<Constant>(V)) + if (canAlwaysEvaluateInType(V, Ty)) return true; + if (canNotEvaluateInType(V, Ty)) + return false; - Instruction *I = dyn_cast<Instruction>(V); - if (!I) return false; - - // If this is a truncate from the dest type, we can trivially eliminate it. - if (isa<TruncInst>(I) && I->getOperand(0)->getType() == Ty) - return true; - - // We can't extend or shrink something that has multiple uses: doing so would - // require duplicating the instruction in general, which isn't profitable. - if (!I->hasOneUse()) return false; - + auto *I = cast<Instruction>(V); switch (I->getOpcode()) { case Instruction::SExt: // sext(sext(x)) -> sext(x) case Instruction::ZExt: // sext(zext(x)) -> zext(x) @@ -1335,8 +1350,10 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) && canEvaluateSExtd(Src, DestTy)) { // Okay, we can transform this! Insert the new expression now. 
- DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type" - " to avoid sign extend: " << CI << '\n'); + LLVM_DEBUG( + dbgs() << "ICE: EvaluateInDifferentType converting expression type" + " to avoid sign extend: " + << CI << '\n'); Value *Res = EvaluateInDifferentType(Src, DestTy, true); assert(Res->getType() == DestTy); @@ -1401,45 +1418,83 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { /// Return a Constant* for the specified floating-point constant if it fits /// in the specified FP type without changing its value. -static Constant *fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { +static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { bool losesInfo; APFloat F = CFP->getValueAPF(); (void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo); - if (!losesInfo) - return ConstantFP::get(CFP->getContext(), F); + return !losesInfo; +} + +static Type *shrinkFPConstant(ConstantFP *CFP) { + if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext())) + return nullptr; // No constant folding of this. + // See if the value can be truncated to half and then reextended. + if (fitsInFPType(CFP, APFloat::IEEEhalf())) + return Type::getHalfTy(CFP->getContext()); + // See if the value can be truncated to float and then reextended. + if (fitsInFPType(CFP, APFloat::IEEEsingle())) + return Type::getFloatTy(CFP->getContext()); + if (CFP->getType()->isDoubleTy()) + return nullptr; // Won't shrink. + if (fitsInFPType(CFP, APFloat::IEEEdouble())) + return Type::getDoubleTy(CFP->getContext()); + // Don't try to shrink to various long double types. return nullptr; } -/// Look through floating-point extensions until we get the source value. -static Value *lookThroughFPExtensions(Value *V) { - while (auto *FPExt = dyn_cast<FPExtInst>(V)) - V = FPExt->getOperand(0); +// Determine if this is a vector of ConstantFPs and if so, return the minimal +// type we can safely truncate all elements to. +// TODO: Make these support undef elements. +static Type *shrinkFPConstantVector(Value *V) { + auto *CV = dyn_cast<Constant>(V); + if (!CV || !CV->getType()->isVectorTy()) + return nullptr; + + Type *MinType = nullptr; + + unsigned NumElts = CV->getType()->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i)); + if (!CFP) + return nullptr; + + Type *T = shrinkFPConstant(CFP); + if (!T) + return nullptr; + + // If we haven't found a type yet or this type has a larger mantissa than + // our previous type, this is our new minimal type. + if (!MinType || T->getFPMantissaWidth() > MinType->getFPMantissaWidth()) + MinType = T; + } + + // Make a vector type from the minimal type. + return VectorType::get(MinType, NumElts); +} + +/// Find the minimum FP type we can safely truncate to. +static Type *getMinimumFPType(Value *V) { + if (auto *FPExt = dyn_cast<FPExtInst>(V)) + return FPExt->getOperand(0)->getType(); // If this value is a constant, return the constant in the smallest FP type // that can accurately represent it. This allows us to turn // (float)((double)X+2.0) into x+2.0f. - if (auto *CFP = dyn_cast<ConstantFP>(V)) { - if (CFP->getType() == Type::getPPC_FP128Ty(V->getContext())) - return V; // No constant folding of this. - // See if the value can be truncated to half and then reextended. - if (Value *V = fitsInFPType(CFP, APFloat::IEEEhalf())) - return V; - // See if the value can be truncated to float and then reextended. 
- if (Value *V = fitsInFPType(CFP, APFloat::IEEEsingle())) - return V; - if (CFP->getType()->isDoubleTy()) - return V; // Won't shrink. - if (Value *V = fitsInFPType(CFP, APFloat::IEEEdouble())) - return V; - // Don't try to shrink to various long double types. - } - - return V; + if (auto *CFP = dyn_cast<ConstantFP>(V)) + if (Type *T = shrinkFPConstant(CFP)) + return T; + + // Try to shrink a vector of FP constants. + if (Type *T = shrinkFPConstantVector(V)) + return T; + + return V->getType(); } -Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { - if (Instruction *I = commonCastTransforms(CI)) +Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { + if (Instruction *I = commonCastTransforms(FPT)) return I; + // If we have fptrunc(OpI (fpextend x), (fpextend y)), we would like to // simplify this expression to avoid one or more of the trunc/extend // operations if we can do so without changing the numerical results. @@ -1447,15 +1502,16 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // The exact manner in which the widths of the operands interact to limit // what we can and cannot do safely varies from operation to operation, and // is explained below in the various case statements. - BinaryOperator *OpI = dyn_cast<BinaryOperator>(CI.getOperand(0)); + Type *Ty = FPT.getType(); + BinaryOperator *OpI = dyn_cast<BinaryOperator>(FPT.getOperand(0)); if (OpI && OpI->hasOneUse()) { - Value *LHSOrig = lookThroughFPExtensions(OpI->getOperand(0)); - Value *RHSOrig = lookThroughFPExtensions(OpI->getOperand(1)); + Type *LHSMinType = getMinimumFPType(OpI->getOperand(0)); + Type *RHSMinType = getMinimumFPType(OpI->getOperand(1)); unsigned OpWidth = OpI->getType()->getFPMantissaWidth(); - unsigned LHSWidth = LHSOrig->getType()->getFPMantissaWidth(); - unsigned RHSWidth = RHSOrig->getType()->getFPMantissaWidth(); + unsigned LHSWidth = LHSMinType->getFPMantissaWidth(); + unsigned RHSWidth = RHSMinType->getFPMantissaWidth(); unsigned SrcWidth = std::max(LHSWidth, RHSWidth); - unsigned DstWidth = CI.getType()->getFPMantissaWidth(); + unsigned DstWidth = Ty->getFPMantissaWidth(); switch (OpI->getOpcode()) { default: break; case Instruction::FAdd: @@ -1479,12 +1535,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // could be tightened for those cases, but they are rare (the main // case of interest here is (float)((double)float + float)). if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) { - if (LHSOrig->getType() != CI.getType()) - LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType()); - if (RHSOrig->getType() != CI.getType()) - RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType()); - Instruction *RI = - BinaryOperator::Create(OpI->getOpcode(), LHSOrig, RHSOrig); + Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty); + Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty); + Instruction *RI = BinaryOperator::Create(OpI->getOpcode(), LHS, RHS); RI->copyFastMathFlags(OpI); return RI; } @@ -1496,14 +1549,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // rounding can possibly occur; we can safely perform the operation // in the destination format if it can represent both sources. 
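// Editor's illustrative check of the FAdd/FSub condition above: with float
// sources (24-bit mantissa) and a double intermediate (53 >= 2*24+1), adding
// in double and truncating is bit-identical to adding directly in float, so
// the fpext/fptrunc pair can be removed. (Assumes round-to-nearest and no
// x87 extended-precision evaluation.)
#include <cassert>

int main() {
  float Vals[] = {1.0f, 3.0e-7f, 1.234567f, -2.5f, 1.0e20f};
  for (float A : Vals)
    for (float B : Vals) {
      float Wide = float(double(A) + double(B)); // (float)((double)A + (double)B)
      float Narrow = A + B;
      assert(Wide == Narrow);
    }
  return 0;
}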
if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) { - if (LHSOrig->getType() != CI.getType()) - LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType()); - if (RHSOrig->getType() != CI.getType()) - RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType()); - Instruction *RI = - BinaryOperator::CreateFMul(LHSOrig, RHSOrig); - RI->copyFastMathFlags(OpI); - return RI; + Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty); + Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty); + return BinaryOperator::CreateFMulFMF(LHS, RHS, OpI); } break; case Instruction::FDiv: @@ -1514,72 +1562,48 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // condition used here is a good conservative first pass. // TODO: Tighten bound via rigorous analysis of the unbalanced case. if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) { - if (LHSOrig->getType() != CI.getType()) - LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType()); - if (RHSOrig->getType() != CI.getType()) - RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType()); - Instruction *RI = - BinaryOperator::CreateFDiv(LHSOrig, RHSOrig); - RI->copyFastMathFlags(OpI); - return RI; + Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty); + Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty); + return BinaryOperator::CreateFDivFMF(LHS, RHS, OpI); } break; - case Instruction::FRem: + case Instruction::FRem: { // Remainder is straightforward. Remainder is always exact, so the // type of OpI doesn't enter into things at all. We simply evaluate // in whichever source type is larger, then convert to the // destination type. if (SrcWidth == OpWidth) break; - if (LHSWidth < SrcWidth) - LHSOrig = Builder.CreateFPExt(LHSOrig, RHSOrig->getType()); - else if (RHSWidth <= SrcWidth) - RHSOrig = Builder.CreateFPExt(RHSOrig, LHSOrig->getType()); - if (LHSOrig != OpI->getOperand(0) || RHSOrig != OpI->getOperand(1)) { - Value *ExactResult = Builder.CreateFRem(LHSOrig, RHSOrig); - if (Instruction *RI = dyn_cast<Instruction>(ExactResult)) - RI->copyFastMathFlags(OpI); - return CastInst::CreateFPCast(ExactResult, CI.getType()); + Value *LHS, *RHS; + if (LHSWidth == SrcWidth) { + LHS = Builder.CreateFPTrunc(OpI->getOperand(0), LHSMinType); + RHS = Builder.CreateFPTrunc(OpI->getOperand(1), LHSMinType); + } else { + LHS = Builder.CreateFPTrunc(OpI->getOperand(0), RHSMinType); + RHS = Builder.CreateFPTrunc(OpI->getOperand(1), RHSMinType); } + + Value *ExactResult = Builder.CreateFRemFMF(LHS, RHS, OpI); + return CastInst::CreateFPCast(ExactResult, Ty); + } } // (fptrunc (fneg x)) -> (fneg (fptrunc x)) if (BinaryOperator::isFNeg(OpI)) { - Value *InnerTrunc = Builder.CreateFPTrunc(OpI->getOperand(1), - CI.getType()); - Instruction *RI = BinaryOperator::CreateFNeg(InnerTrunc); - RI->copyFastMathFlags(OpI); - return RI; + Value *InnerTrunc = Builder.CreateFPTrunc(OpI->getOperand(1), Ty); + return BinaryOperator::CreateFNegFMF(InnerTrunc, OpI); } } - // (fptrunc (select cond, R1, Cst)) --> - // (select cond, (fptrunc R1), (fptrunc Cst)) - // - // - but only if this isn't part of a min/max operation, else we'll - // ruin min/max canonical form which is to have the select and - // compare's operands be of the same type with no casts to look through. 
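// Editor's illustrative sketch of the (fptrunc (fneg x)) -> (fneg (fptrunc x))
// fold above: rounding is sign-symmetric, so negating before or after the
// truncation produces the same float.
#include <cassert>

int main() {
  double Vals[] = {0.1, -1.0e30, 3.14159265358979, -0.0, 123456789.0};
  for (double X : Vals) {
    float A = float(-X); // fptrunc (fneg X)
    float B = -float(X); // fneg (fptrunc X)
    assert(A == B);
  }
  return 0;
}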
- Value *LHS, *RHS; - SelectInst *SI = dyn_cast<SelectInst>(CI.getOperand(0)); - if (SI && - (isa<ConstantFP>(SI->getOperand(1)) || - isa<ConstantFP>(SI->getOperand(2))) && - matchSelectPattern(SI, LHS, RHS).Flavor == SPF_UNKNOWN) { - Value *LHSTrunc = Builder.CreateFPTrunc(SI->getOperand(1), CI.getType()); - Value *RHSTrunc = Builder.CreateFPTrunc(SI->getOperand(2), CI.getType()); - return SelectInst::Create(SI->getOperand(0), LHSTrunc, RHSTrunc); - } - - IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI.getOperand(0)); - if (II) { + if (auto *II = dyn_cast<IntrinsicInst>(FPT.getOperand(0))) { switch (II->getIntrinsicID()) { default: break; - case Intrinsic::fabs: case Intrinsic::ceil: + case Intrinsic::fabs: case Intrinsic::floor: + case Intrinsic::nearbyint: case Intrinsic::rint: case Intrinsic::round: - case Intrinsic::nearbyint: case Intrinsic::trunc: { Value *Src = II->getArgOperand(0); if (!Src->hasOneUse()) @@ -1590,30 +1614,26 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { // truncating. if (II->getIntrinsicID() != Intrinsic::fabs) { FPExtInst *FPExtSrc = dyn_cast<FPExtInst>(Src); - if (!FPExtSrc || FPExtSrc->getOperand(0)->getType() != CI.getType()) + if (!FPExtSrc || FPExtSrc->getSrcTy() != Ty) break; } // Do unary FP operation on smaller type. // (fptrunc (fabs x)) -> (fabs (fptrunc x)) - Value *InnerTrunc = Builder.CreateFPTrunc(Src, CI.getType()); - Type *IntrinsicType[] = { CI.getType() }; - Function *Overload = Intrinsic::getDeclaration( - CI.getModule(), II->getIntrinsicID(), IntrinsicType); - + Value *InnerTrunc = Builder.CreateFPTrunc(Src, Ty); + Function *Overload = Intrinsic::getDeclaration(FPT.getModule(), + II->getIntrinsicID(), Ty); SmallVector<OperandBundleDef, 1> OpBundles; II->getOperandBundlesAsDefs(OpBundles); - - Value *Args[] = { InnerTrunc }; - CallInst *NewCI = CallInst::Create(Overload, Args, - OpBundles, II->getName()); + CallInst *NewCI = CallInst::Create(Overload, { InnerTrunc }, OpBundles, + II->getName()); NewCI->copyFastMathFlags(II); return NewCI; } } } - if (Instruction *I = shrinkInsertElt(CI, Builder)) + if (Instruction *I = shrinkInsertElt(FPT, Builder)) return I; return nullptr; @@ -1718,7 +1738,7 @@ Instruction *InstCombiner::visitIntToPtr(IntToPtrInst &CI) { return nullptr; } -/// @brief Implement the transforms for cast of pointer (bitcast/ptrtoint) +/// Implement the transforms for cast of pointer (bitcast/ptrtoint) Instruction *InstCombiner::commonPointerCastTransforms(CastInst &CI) { Value *Src = CI.getOperand(0); @@ -1751,7 +1771,7 @@ Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) { Type *Ty = CI.getType(); unsigned AS = CI.getPointerAddressSpace(); - if (Ty->getScalarSizeInBits() == DL.getPointerSizeInBits(AS)) + if (Ty->getScalarSizeInBits() == DL.getIndexSizeInBits(AS)) return commonPointerCastTransforms(CI); Type *PtrTy = DL.getIntPtrType(CI.getContext(), AS); @@ -2004,13 +2024,13 @@ static Instruction *foldBitCastBitwiseLogic(BitCastInst &BitCast, !match(BitCast.getOperand(0), m_OneUse(m_BinOp(BO))) || !BO->isBitwiseLogicOp()) return nullptr; - + // FIXME: This transform is restricted to vector types to avoid backend // problems caused by creating potentially illegal operations. If a fix-up is // added to handle that situation, we can remove this check. 
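// Editor's illustrative sketch of the intrinsic case above: fabs only clears
// the sign bit and rounding is sign-symmetric, so (fptrunc (fabs x)) equals
// (fabs (fptrunc x)); the other rounding intrinsics are only narrowed when
// the source was an fpext from the destination type.
#include <cassert>
#include <cmath>

int main() {
  double Vals[] = {-0.1, 0.1, -1.0e30, 3.14159265358979, -0.0};
  for (double X : Vals) {
    float A = float(std::fabs(X)); // fptrunc (fabs X)
    float B = std::fabs(float(X)); // fabs (fptrunc X)
    assert(A == B);
  }
  return 0;
}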
if (!DestTy->isVectorTy() || !BO->getType()->isVectorTy()) return nullptr; - + Value *X; if (match(BO->getOperand(0), m_OneUse(m_BitCast(m_Value(X)))) && X->getType() == DestTy && !isa<Constant>(X)) { diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 3bc7fae77cb1..6de92a4842ab 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -682,7 +682,7 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base, // 4. Emit GEPs to get the original pointers. // 5. Remove the original instructions. Type *IndexType = IntegerType::get( - Base->getContext(), DL.getPointerTypeSizeInBits(Start->getType())); + Base->getContext(), DL.getIndexTypeSizeInBits(Start->getType())); DenseMap<Value *, Value *> NewInsts; NewInsts[Base] = ConstantInt::getNullValue(IndexType); @@ -723,7 +723,7 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base, } auto *Op = NewInsts[GEP->getOperand(0)]; - if (isa<ConstantInt>(Op) && dyn_cast<ConstantInt>(Op)->isZero()) + if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero()) NewInsts[GEP] = Index; else NewInsts[GEP] = Builder.CreateNSWAdd( @@ -790,7 +790,7 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base, static std::pair<Value *, Value *> getAsConstantIndexedAddress(Value *V, const DataLayout &DL) { Type *IndexType = IntegerType::get(V->getContext(), - DL.getPointerTypeSizeInBits(V->getType())); + DL.getIndexTypeSizeInBits(V->getType())); Constant *Index = ConstantInt::getNullValue(IndexType); while (true) { @@ -1893,11 +1893,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, APInt ShiftedC = C.ashr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } - if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { - // This is the same code as the SGT case, but assert the pre-condition - // that is needed for this to work with equality predicates. - assert(C.ashr(*ShiftAmt).shl(*ShiftAmt) == C && - "Compare known true or false was not folded"); + if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && + C.ashr(*ShiftAmt).shl(*ShiftAmt) == C) { APInt ShiftedC = C.ashr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } @@ -1926,11 +1923,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, APInt ShiftedC = C.lshr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } - if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { - // This is the same code as the UGT case, but assert the pre-condition - // that is needed for this to work with equality predicates. 
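// Editor's illustrative sketch of the equality fold above: when the shift
// cannot wrap, (X << S) == C can become X == (C >> S), but only under the
// guard that C survives the round-trip through the shift; the patch turns
// the old assert into exactly that guard. Values are illustrative.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t S = 3, C = 0x50;
  assert(((C >> S) << S) == C); // the guard: low S bits of C are clear
  for (uint32_t X = 0; X < 0x10000; ++X) { // range where X << S cannot wrap
    bool Orig = (X << S) == C;
    bool Folded = X == (C >> S);
    assert(Orig == Folded);
  }
  return 0;
}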
- assert(C.lshr(*ShiftAmt).shl(*ShiftAmt) == C && - "Compare known true or false was not folded"); + if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && + C.lshr(*ShiftAmt).shl(*ShiftAmt) == C) { APInt ShiftedC = C.lshr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } @@ -2463,6 +2457,45 @@ Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, return nullptr; } +Instruction *InstCombiner::foldICmpBitCastConstant(ICmpInst &Cmp, + BitCastInst *Bitcast, + const APInt &C) { + // Folding: icmp <pred> iN X, C + // where X = bitcast <M x iK> (shufflevector <M x iK> %vec, undef, SC)) to iN + // and C is a splat of a K-bit pattern + // and SC is a constant vector = <C', C', C', ..., C'> + // Into: + // %E = extractelement <M x iK> %vec, i32 C' + // icmp <pred> iK %E, trunc(C) + if (!Bitcast->getType()->isIntegerTy() || + !Bitcast->getSrcTy()->isIntOrIntVectorTy()) + return nullptr; + + Value *BCIOp = Bitcast->getOperand(0); + Value *Vec = nullptr; // 1st vector arg of the shufflevector + Constant *Mask = nullptr; // Mask arg of the shufflevector + if (match(BCIOp, + m_ShuffleVector(m_Value(Vec), m_Undef(), m_Constant(Mask)))) { + // Check whether every element of Mask is the same constant + if (auto *Elem = dyn_cast_or_null<ConstantInt>(Mask->getSplatValue())) { + auto *VecTy = cast<VectorType>(BCIOp->getType()); + auto *EltTy = cast<IntegerType>(VecTy->getElementType()); + auto Pred = Cmp.getPredicate(); + if (C.isSplat(EltTy->getBitWidth())) { + // Fold the icmp based on the value of C + // If C is M copies of an iK sized bit pattern, + // then: + // => %E = extractelement <N x iK> %vec, i32 Elem + // icmp <pred> iK %SplatVal, <pattern> + Value *Extract = Builder.CreateExtractElement(Vec, Elem); + Value *NewC = ConstantInt::get(EltTy, C.trunc(EltTy->getBitWidth())); + return new ICmpInst(Pred, Extract, NewC); + } + } + } + return nullptr; +} + /// Try to fold integer comparisons with a constant operand: icmp Pred X, C /// where X is some kind of instruction. Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { @@ -2537,6 +2570,11 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { return I; } + if (auto *BCI = dyn_cast<BitCastInst>(Cmp.getOperand(0))) { + if (Instruction *I = foldICmpBitCastConstant(Cmp, BCI, *C)) + return I; + } + if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, *C)) return I; @@ -2828,6 +2866,160 @@ Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) { return nullptr; } +/// Some comparisons can be simplified. +/// In this case, we are looking for comparisons that look like +/// a check for a lossy truncation. +/// Folds: +/// x & (-1 >> y) SrcPred x to x DstPred (-1 >> y) +/// The Mask can be a constant, too. +/// For some predicates, the operands are commutative. +/// For others, x can only be on a specific side. 
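// Editor's illustrative sketch of foldICmpBitCastConstant above: when a wide
// value is the bitcast of a splatted vector and the constant is a splat of
// the same element width, comparing the whole value equals comparing a single
// element against the truncated constant. memcpy stands in for the bitcast.
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  for (uint32_t V = 0; V < 256; ++V) {
    uint8_t Elt = uint8_t(V);
    uint8_t Vec[4] = {Elt, Elt, Elt, Elt}; // shufflevector splat
    uint32_t Whole;
    std::memcpy(&Whole, Vec, 4);           // bitcast <4 x i8> to i32
    bool WideCmp = Whole == 0x41414141u;   // icmp eq i32 ..., splat pattern
    bool EltCmp = Elt == 0x41u;            // icmp eq i8 (extractelement), trunc(C)
    assert(WideCmp == EltCmp);
  }
  return 0;
}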
+static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate SrcPred; + Value *X, *M; + auto m_Mask = m_CombineOr(m_LShr(m_AllOnes(), m_Value()), m_LowBitMask()); + if (!match(&I, m_c_ICmp(SrcPred, + m_c_And(m_CombineAnd(m_Mask, m_Value(M)), m_Value(X)), + m_Deferred(X)))) + return nullptr; + + ICmpInst::Predicate DstPred; + switch (SrcPred) { + case ICmpInst::Predicate::ICMP_EQ: + // x & (-1 >> y) == x -> x u<= (-1 >> y) + DstPred = ICmpInst::Predicate::ICMP_ULE; + break; + case ICmpInst::Predicate::ICMP_NE: + // x & (-1 >> y) != x -> x u> (-1 >> y) + DstPred = ICmpInst::Predicate::ICMP_UGT; + break; + case ICmpInst::Predicate::ICMP_UGT: + // x u> x & (-1 >> y) -> x u> (-1 >> y) + assert(X == I.getOperand(0) && "instsimplify took care of commut. variant"); + DstPred = ICmpInst::Predicate::ICMP_UGT; + break; + case ICmpInst::Predicate::ICMP_UGE: + // x & (-1 >> y) u>= x -> x u<= (-1 >> y) + assert(X == I.getOperand(1) && "instsimplify took care of commut. variant"); + DstPred = ICmpInst::Predicate::ICMP_ULE; + break; + case ICmpInst::Predicate::ICMP_ULT: + // x & (-1 >> y) u< x -> x u> (-1 >> y) + assert(X == I.getOperand(1) && "instsimplify took care of commut. variant"); + DstPred = ICmpInst::Predicate::ICMP_UGT; + break; + case ICmpInst::Predicate::ICMP_ULE: + // x u<= x & (-1 >> y) -> x u<= (-1 >> y) + assert(X == I.getOperand(0) && "instsimplify took care of commut. variant"); + DstPred = ICmpInst::Predicate::ICMP_ULE; + break; + case ICmpInst::Predicate::ICMP_SGT: + // x s> x & (-1 >> y) -> x s> (-1 >> y) + if (X != I.getOperand(0)) // X must be on LHS of comparison! + return nullptr; // Ignore the other case. + DstPred = ICmpInst::Predicate::ICMP_SGT; + break; + case ICmpInst::Predicate::ICMP_SGE: + // x & (-1 >> y) s>= x -> x s<= (-1 >> y) + if (X != I.getOperand(1)) // X must be on RHS of comparison! + return nullptr; // Ignore the other case. + DstPred = ICmpInst::Predicate::ICMP_SLE; + break; + case ICmpInst::Predicate::ICMP_SLT: + // x & (-1 >> y) s< x -> x s> (-1 >> y) + if (X != I.getOperand(1)) // X must be on RHS of comparison! + return nullptr; // Ignore the other case. + DstPred = ICmpInst::Predicate::ICMP_SGT; + break; + case ICmpInst::Predicate::ICMP_SLE: + // x s<= x & (-1 >> y) -> x s<= (-1 >> y) + if (X != I.getOperand(0)) // X must be on LHS of comparison! + return nullptr; // Ignore the other case. + DstPred = ICmpInst::Predicate::ICMP_SLE; + break; + default: + llvm_unreachable("All possible folds are handled."); + } + + return Builder.CreateICmp(DstPred, X, M); +} + +/// Some comparisons can be simplified. +/// In this case, we are looking for comparisons that look like +/// a check for a lossy signed truncation. +/// Folds: (MaskedBits is a constant.) +/// ((%x << MaskedBits) a>> MaskedBits) SrcPred %x +/// Into: +/// (add %x, (1 << (KeptBits-1))) DstPred (1 << KeptBits) +/// Where KeptBits = bitwidth(%x) - MaskedBits +static Value * +foldICmpWithTruncSignExtendedVal(ICmpInst &I, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate SrcPred; + Value *X; + const APInt *C0, *C1; // FIXME: non-splats, potentially with undef. + // We are ok with 'shl' having multiple uses, but 'ashr' must be one-use. + if (!match(&I, m_c_ICmp(SrcPred, + m_OneUse(m_AShr(m_Shl(m_Value(X), m_APInt(C0)), + m_APInt(C1))), + m_Deferred(X)))) + return nullptr; + + // Potential handling of non-splats: for each element: + // * if both are undef, replace with constant 0. 
+ // Because (1<<0) is OK and is 1, and ((1<<0)>>1) is also OK and is 0. + // * if both are not undef, and are different, bailout. + // * else, only one is undef, then pick the non-undef one. + + // The shift amount must be equal. + if (*C0 != *C1) + return nullptr; + const APInt &MaskedBits = *C0; + assert(MaskedBits != 0 && "shift by zero should be folded away already."); + + ICmpInst::Predicate DstPred; + switch (SrcPred) { + case ICmpInst::Predicate::ICMP_EQ: + // ((%x << MaskedBits) a>> MaskedBits) == %x + // => + // (add %x, (1 << (KeptBits-1))) u< (1 << KeptBits) + DstPred = ICmpInst::Predicate::ICMP_ULT; + break; + case ICmpInst::Predicate::ICMP_NE: + // ((%x << MaskedBits) a>> MaskedBits) != %x + // => + // (add %x, (1 << (KeptBits-1))) u>= (1 << KeptBits) + DstPred = ICmpInst::Predicate::ICMP_UGE; + break; + // FIXME: are more folds possible? + default: + return nullptr; + } + + auto *XType = X->getType(); + const unsigned XBitWidth = XType->getScalarSizeInBits(); + const APInt BitWidth = APInt(XBitWidth, XBitWidth); + assert(BitWidth.ugt(MaskedBits) && "shifts should leave some bits untouched"); + + // KeptBits = bitwidth(%x) - MaskedBits + const APInt KeptBits = BitWidth - MaskedBits; + assert(KeptBits.ugt(0) && KeptBits.ult(BitWidth) && "unreachable"); + // ICmpCst = (1 << KeptBits) + const APInt ICmpCst = APInt(XBitWidth, 1).shl(KeptBits); + assert(ICmpCst.isPowerOf2()); + // AddCst = (1 << (KeptBits-1)) + const APInt AddCst = ICmpCst.lshr(1); + assert(AddCst.ult(ICmpCst) && AddCst.isPowerOf2()); + + // T0 = add %x, AddCst + Value *T0 = Builder.CreateAdd(X, ConstantInt::get(XType, AddCst)); + // T1 = T0 DstPred ICmpCst + Value *T1 = Builder.CreateICmp(DstPred, T0, ConstantInt::get(XType, ICmpCst)); + + return T1; +} + /// Try to fold icmp (binop), X or icmp X, (binop). /// TODO: A large part of this logic is duplicated in InstSimplify's /// simplifyICmpWithBinOp(). We should be able to share that and avoid the code @@ -3011,17 +3203,22 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow. if (A == Op1 && NoOp0WrapProblem) return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B); - // icmp X, (X-Y) -> icmp Y, 0 for equalities or if there is no overflow. if (C == Op0 && NoOp1WrapProblem) return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType())); + // (A - B) >u A --> A <u B + if (A == Op1 && Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_ULT, A, B); + // C <u (C - D) --> C <u D + if (C == Op0 && Pred == ICmpInst::ICMP_ULT) + return new ICmpInst(ICmpInst::ICMP_ULT, C, D); + // icmp (Y-X), (Z-X) -> icmp Y, Z for equalities or if there is no overflow. if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem && // Try not to increase register pressure. BO0->hasOneUse() && BO1->hasOneUse()) return new ICmpInst(Pred, A, C); - // icmp (X-Y), (X-Z) -> icmp Z, Y for equalities or if there is no overflow. if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem && // Try not to increase register pressure. 
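// Editor's illustrative check of the two new icmp folds above. First:
// masking with a low-bit mask M = (-1 >> y) is a no-op exactly when X u<= M.
// Second: X survives shl+ashr by M bits exactly when it fits in the kept
// K bits as a signed value, which the fold re-tests as a single unsigned
// range check after re-centering by (1 << (K-1)). (Casts of out-of-range
// values to int8_t are implementation-defined before C++20 but modular on
// mainstream compilers, matching LLVM's i8 semantics.)
#include <cassert>
#include <cstdint>

int main() {
  // Fold 1: x & (-1 >> y) == x  <=>  x u<= (-1 >> y)
  for (uint32_t Y = 0; Y < 8; ++Y) {
    uint8_t Mask = uint8_t(0xFFu >> Y);
    for (uint32_t V = 0; V < 256; ++V) {
      uint8_t X = uint8_t(V);
      assert(((X & Mask) == X) == (X <= Mask));
    }
  }
  // Fold 2: ((x << M) a>> M) == x  <=>  (x + (1 << (K-1))) u< (1 << K)
  const unsigned M = 3, K = 8 - M; // 5 kept bits
  for (int V = -128; V < 128; ++V) {
    int8_t X = int8_t(V);
    int8_t Shl = int8_t(uint8_t(X) << M); // shl i8 X, 3 (modular)
    int8_t RoundTrip = int8_t(Shl >> M);  // ashr i8 ..., 3
    bool Orig = RoundTrip == X;
    bool Folded = uint8_t(X + (1 << (K - 1))) < (1u << K);
    assert(Orig == Folded);
  }
  return 0;
}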
@@ -3032,8 +3229,8 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { if (NoOp0WrapProblem && ICmpInst::isSigned(Pred)) { Value *X; if (match(BO0, m_Neg(m_Value(X)))) - if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) - if (!RHSC->isMinValue(/*isSigned=*/true)) + if (Constant *RHSC = dyn_cast<Constant>(Op1)) + if (RHSC->isNotMinSignedValue()) return new ICmpInst(I.getSwappedPredicate(), X, ConstantExpr::getNeg(RHSC)); } @@ -3160,6 +3357,12 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { } } + if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder)) + return replaceInstUsesWith(I, V); + + if (Value *V = foldICmpWithTruncSignExtendedVal(I, Builder)) + return replaceInstUsesWith(I, V); + return nullptr; } @@ -3414,8 +3617,15 @@ Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) { // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the // integer type is the same size as the pointer type. + const auto& CompatibleSizes = [&](Type* SrcTy, Type* DestTy) -> bool { + if (isa<VectorType>(SrcTy)) { + SrcTy = cast<VectorType>(SrcTy)->getElementType(); + DestTy = cast<VectorType>(DestTy)->getElementType(); + } + return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth(); + }; if (LHSCI->getOpcode() == Instruction::PtrToInt && - DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth()) { + CompatibleSizes(SrcTy, DestTy)) { Value *RHSOp = nullptr; if (auto *RHSC = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) { Value *RHSCIOp = RHSC->getOperand(0); @@ -3618,7 +3828,7 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, return false; } -/// \brief Recognize and process idiom involving test for multiplication +/// Recognize and process idiom involving test for multiplication /// overflow. /// /// The caller has matched a pattern of the form: @@ -3799,7 +4009,8 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, // mul.with.overflow and adjust properly mask/size. if (MulVal->hasNUsesOrMore(2)) { Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value"); - for (User *U : MulVal->users()) { + for (auto UI = MulVal->user_begin(), UE = MulVal->user_end(); UI != UE;) { + User *U = *UI++; if (U == &I || U == OtherVal) continue; if (TruncInst *TI = dyn_cast<TruncInst>(U)) { @@ -3890,48 +4101,33 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) { } } -/// \brief Check if the order of \p Op0 and \p Op1 as operand in an ICmpInst +/// Check if the order of \p Op0 and \p Op1 as operands in an ICmpInst /// should be swapped. /// The decision is based on how many times these two operands are reused /// as subtract operands and their positions in those instructions. -/// The rational is that several architectures use the same instruction for -/// both subtract and cmp, thus it is better if the order of those operands +/// The rationale is that several architectures use the same instruction for +/// both subtract and cmp. Thus, it is better if the order of those operands /// match. /// \return true if Op0 and Op1 should be swapped. -static bool swapMayExposeCSEOpportunities(const Value * Op0, - const Value * Op1) { - // Filter out pointer value as those cannot appears directly in subtract. +static bool swapMayExposeCSEOpportunities(const Value *Op0, const Value *Op1) { + // Filter out pointer values as those cannot appear directly in subtract. // FIXME: we may want to go through inttoptrs or bitcasts. 
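// Editor's illustrative sketch of the widened neg-compare fold above:
// (0 - X) <s C becomes X >s -C, which is only sound when -C exists; C equal
// to the minimum signed value has no negation, hence the
// isNotMinSignedValue() guard. X == -128 is also excluded below, mirroring
// the no-wrap requirement on the sub.
#include <cassert>
#include <cstdint>

int main() {
  const int8_t C = 5;
  for (int V = -127; V < 128; ++V) { // skip -128: 0 - X would wrap
    int8_t X = int8_t(V);
    bool Orig = int8_t(-X) < C;   // icmp slt (sub 0, X), C
    bool Folded = X > int8_t(-C); // icmp sgt X, -C
    assert(Orig == Folded);
  }
  // With C == INT8_MIN the two forms disagree (try X == 0), so no fold there.
  return 0;
}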
if (Op0->getType()->isPointerTy()) return false; - // Count every uses of both Op0 and Op1 in a subtract. - // Each time Op0 is the first operand, count -1: swapping is bad, the - // subtract has already the same layout as the compare. - // Each time Op0 is the second operand, count +1: swapping is good, the - // subtract has a different layout as the compare. - // At the end, if the benefit is greater than 0, Op0 should come second to - // expose more CSE opportunities. - int GlobalSwapBenefits = 0; + // If a subtract already has the same operands as a compare, swapping would be + // bad. If a subtract has the same operands as a compare but in reverse order, + // then swapping is good. + int GoodToSwap = 0; for (const User *U : Op0->users()) { - const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(U); - if (!BinOp || BinOp->getOpcode() != Instruction::Sub) - continue; - // If Op0 is the first argument, this is not beneficial to swap the - // arguments. - int LocalSwapBenefits = -1; - unsigned Op1Idx = 1; - if (BinOp->getOperand(Op1Idx) == Op0) { - Op1Idx = 0; - LocalSwapBenefits = 1; - } - if (BinOp->getOperand(Op1Idx) != Op1) - continue; - GlobalSwapBenefits += LocalSwapBenefits; + if (match(U, m_Sub(m_Specific(Op1), m_Specific(Op0)))) + GoodToSwap++; + else if (match(U, m_Sub(m_Specific(Op0), m_Specific(Op1)))) + GoodToSwap--; } - return GlobalSwapBenefits > 0; + return GoodToSwap > 0; } -/// \brief Check that one use is in the same block as the definition and all +/// Check that one use is in the same block as the definition and all /// other uses are in blocks dominated by a given block. /// /// \param DI Definition @@ -3976,7 +4172,7 @@ static bool isChainSelectCmpBranch(const SelectInst *SI) { return true; } -/// \brief True when a select result is replaced by one of its operands +/// True when a select result is replaced by one of its operands /// in select-icmp sequence. This will eventually result in the elimination /// of the select. /// @@ -4052,7 +4248,7 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { // Get scalar or pointer size. unsigned BitWidth = Ty->isIntOrIntVectorTy() ? Ty->getScalarSizeInBits() - : DL.getTypeSizeInBits(Ty->getScalarType()); + : DL.getIndexTypeSizeInBits(Ty->getScalarType()); if (!BitWidth) return nullptr; @@ -4082,13 +4278,13 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { computeUnsignedMinMaxValuesFromKnownBits(Op1Known, Op1Min, Op1Max); } - // If Min and Max are known to be the same, then SimplifyDemandedBits - // figured out that the LHS is a constant. Constant fold this now, so that + // If Min and Max are known to be the same, then SimplifyDemandedBits figured + // out that the LHS or RHS is a constant. Constant fold this now, so that // code below can assume that Min != Max. if (!isa<Constant>(Op0) && Op0Min == Op0Max) - return new ICmpInst(Pred, ConstantInt::get(Op0->getType(), Op0Min), Op1); + return new ICmpInst(Pred, ConstantExpr::getIntegerValue(Ty, Op0Min), Op1); if (!isa<Constant>(Op1) && Op1Min == Op1Max) - return new ICmpInst(Pred, Op0, ConstantInt::get(Op1->getType(), Op1Min)); + return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min)); // Based on the range information we know about the LHS, see if we can // simplify this comparison. For example, (x&4) < 8 is always true. @@ -4520,6 +4716,34 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return New; } + // Zero-equality and sign-bit checks are preserved through sitofp + bitcast. 
+ Value *X; + if (match(Op0, m_BitCast(m_SIToFP(m_Value(X))))) { + // icmp eq (bitcast (sitofp X)), 0 --> icmp eq X, 0 + // icmp ne (bitcast (sitofp X)), 0 --> icmp ne X, 0 + // icmp slt (bitcast (sitofp X)), 0 --> icmp slt X, 0 + // icmp sgt (bitcast (sitofp X)), 0 --> icmp sgt X, 0 + if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_SLT || + Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT) && + match(Op1, m_Zero())) + return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); + + // icmp slt (bitcast (sitofp X)), 1 --> icmp slt X, 1 + if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_One())) + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), 1)); + + // icmp sgt (bitcast (sitofp X)), -1 --> icmp sgt X, -1 + if (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes())) + return new ICmpInst(Pred, X, ConstantInt::getAllOnesValue(X->getType())); + } + + // Zero-equality checks are preserved through unsigned floating-point casts: + // icmp eq (bitcast (uitofp X)), 0 --> icmp eq X, 0 + // icmp ne (bitcast (uitofp X)), 0 --> icmp ne X, 0 + if (match(Op0, m_BitCast(m_UIToFP(m_Value(X))))) + if (I.isEquality() && match(Op1, m_Zero())) + return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); + // Test to see if the operands of the icmp are casted versions of other // values. If the ptr->ptr cast can be stripped off both arguments, we do so // now. @@ -4642,6 +4866,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (match(Op1, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op0 == X) return foldICmpAddOpConst(X, Cst, I.getSwappedPredicate()); } + return Changed ? &I : nullptr; } @@ -4928,11 +5153,11 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { // If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand, // then canonicalize the operand to 0.0. 
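// Editor's illustrative sketch of the sitofp/bitcast folds above: int-to-FP
// conversion can round away low bits but always preserves the sign bit and
// maps zero (and only zero) to +0.0, so sign-bit and zero-equality tests on
// the raw float bits reduce to the same tests on the original integer.
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  int32_t Vals[] = {0, 1, -1, 42, -42, INT32_MIN, INT32_MAX};
  for (int32_t X : Vals) {
    float F = float(X);        // sitofp i32 X to float
    uint32_t Bits;
    std::memcpy(&Bits, &F, 4); // bitcast float to i32
    assert((int32_t(Bits) < 0) == (X < 0)); // icmp slt ..., 0
    assert((Bits == 0) == (X == 0));        // icmp eq ..., 0
  }
  return 0;
}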
if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) { - if (!match(Op0, m_Zero()) && isKnownNeverNaN(Op0)) { + if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0)) { I.setOperand(0, ConstantFP::getNullValue(Op0->getType())); return &I; } - if (!match(Op1, m_Zero()) && isKnownNeverNaN(Op1)) { + if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1)) { I.setOperand(1, ConstantFP::getNullValue(Op0->getType())); return &I; } diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index f1f66d86cb73..58ef3d41415c 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -20,6 +20,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetFolder.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" @@ -40,7 +41,6 @@ #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" -#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> @@ -122,17 +122,17 @@ static inline Value *peekThroughBitcast(Value *V, bool OneUseOnly = false) { return V; } -/// \brief Add one to a Constant +/// Add one to a Constant static inline Constant *AddOne(Constant *C) { return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1)); } -/// \brief Subtract one from a Constant +/// Subtract one from a Constant static inline Constant *SubOne(Constant *C) { return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1)); } -/// \brief Return true if the specified value is free to invert (apply ~ to). +/// Return true if the specified value is free to invert (apply ~ to). /// This happens in cases where the ~ can be eliminated. If WillInvertAllUses /// is true, work under the assumption that the caller intends to remove all /// uses of V and only keep uses of ~V. @@ -178,7 +178,7 @@ static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { return false; } -/// \brief Specific patterns of overflow check idioms that we match. +/// Specific patterns of overflow check idioms that we match. enum OverflowCheckFlavor { OCF_UNSIGNED_ADD, OCF_SIGNED_ADD, @@ -190,7 +190,7 @@ enum OverflowCheckFlavor { OCF_INVALID }; -/// \brief Returns the OverflowCheckFlavor corresponding to a overflow_with_op +/// Returns the OverflowCheckFlavor corresponding to a overflow_with_op /// intrinsic. static inline OverflowCheckFlavor IntrinsicIDToOverflowCheckFlavor(unsigned ID) { @@ -212,7 +212,62 @@ IntrinsicIDToOverflowCheckFlavor(unsigned ID) { } } -/// \brief The core instruction combiner logic. +/// Some binary operators require special handling to avoid poison and undefined +/// behavior. If a constant vector has undef elements, replace those undefs with +/// identity constants if possible because those are always safe to execute. +/// If no identity constant exists, replace undef with some other safe constant. +static inline Constant *getSafeVectorConstantForBinop( + BinaryOperator::BinaryOps Opcode, Constant *In, bool IsRHSConstant) { + assert(In->getType()->isVectorTy() && "Not expecting scalars here"); + + Type *EltTy = In->getType()->getVectorElementType(); + auto *SafeC = ConstantExpr::getBinOpIdentity(Opcode, EltTy, IsRHSConstant); + if (!SafeC) { + // TODO: Should this be available as a constant utility function? 
It is + // similar to getBinOpAbsorber(). + if (IsRHSConstant) { + switch (Opcode) { + case Instruction::SRem: // X % 1 = 0 + case Instruction::URem: // X %u 1 = 0 + SafeC = ConstantInt::get(EltTy, 1); + break; + case Instruction::FRem: // X % 1.0 (doesn't simplify, but it is safe) + SafeC = ConstantFP::get(EltTy, 1.0); + break; + default: + llvm_unreachable("Only rem opcodes have no identity constant for RHS"); + } + } else { + switch (Opcode) { + case Instruction::Shl: // 0 << X = 0 + case Instruction::LShr: // 0 >>u X = 0 + case Instruction::AShr: // 0 >> X = 0 + case Instruction::SDiv: // 0 / X = 0 + case Instruction::UDiv: // 0 /u X = 0 + case Instruction::SRem: // 0 % X = 0 + case Instruction::URem: // 0 %u X = 0 + case Instruction::Sub: // 0 - X (doesn't simplify, but it is safe) + case Instruction::FSub: // 0.0 - X (doesn't simplify, but it is safe) + case Instruction::FDiv: // 0.0 / X (doesn't simplify, but it is safe) + case Instruction::FRem: // 0.0 % X = 0 + SafeC = Constant::getNullValue(EltTy); + break; + default: + llvm_unreachable("Expected to find identity constant for opcode"); + } + } + } + assert(SafeC && "Must have safe constant for binop"); + unsigned NumElts = In->getType()->getVectorNumElements(); + SmallVector<Constant *, 16> Out(NumElts); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *C = In->getAggregateElement(i); + Out[i] = isa<UndefValue>(C) ? SafeC : C; + } + return ConstantVector::get(Out); +} + +/// The core instruction combiner logic. /// /// This class provides both the logic to recursively visit instructions and /// combine them. @@ -220,10 +275,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner : public InstVisitor<InstCombiner, Instruction *> { // FIXME: These members shouldn't be public. public: - /// \brief A worklist of the instructions that need to be simplified. + /// A worklist of the instructions that need to be simplified. InstCombineWorklist &Worklist; - /// \brief An IRBuilder that automatically inserts new instructions into the + /// An IRBuilder that automatically inserts new instructions into the /// worklist. using BuilderTy = IRBuilder<TargetFolder, IRBuilderCallbackInserter>; BuilderTy &Builder; @@ -261,7 +316,7 @@ public: ExpensiveCombines(ExpensiveCombines), AA(AA), AC(AC), TLI(TLI), DT(DT), DL(DL), SQ(DL, &TLI, &DT, &AC), ORE(ORE), LI(LI) {} - /// \brief Run the combiner over the entire worklist until it is empty. + /// Run the combiner over the entire worklist until it is empty. /// /// \returns true if the IR is changed. bool run(); @@ -289,8 +344,6 @@ public: Instruction *visitSub(BinaryOperator &I); Instruction *visitFSub(BinaryOperator &I); Instruction *visitMul(BinaryOperator &I); - Value *foldFMulConst(Instruction *FMulOrDiv, Constant *C, - Instruction *InsertBefore); Instruction *visitFMul(BinaryOperator &I); Instruction *visitURem(BinaryOperator &I); Instruction *visitSRem(BinaryOperator &I); @@ -378,7 +431,6 @@ private: bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const; bool shouldChangeType(Type *From, Type *To) const; Value *dyn_castNegVal(Value *V) const; - Value *dyn_castFNegVal(Value *V, bool NoSignedZero = false) const; Type *FindElementAtOffset(PointerType *PtrTy, int64_t Offset, SmallVectorImpl<Value *> &NewIndices); @@ -393,7 +445,7 @@ private: /// if it cannot already be eliminated by some other transformation. 
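// Editor's illustrative sketch of getSafeVectorConstantForBinop above: an
// undef divisor lane could be chosen as 0, making the whole vector divide
// immediate UB, so undef lanes are replaced by a value that is always safe
// to execute (here the identity 1 for a division's RHS, which leaves the
// defined lanes unchanged). Values are illustrative only.
#include <cassert>
#include <cstdint>

int main() {
  uint32_t X[4] = {8, 12, 20, 7};
  uint32_t Div[4] = {2, 4, 1 /* was undef */, 7};
  uint32_t Q[4];
  for (int I = 0; I != 4; ++I)
    Q[I] = X[I] / Div[I]; // every lane is now safe to execute
  assert(Q[0] == 4 && Q[1] == 3 && Q[2] == 20 && Q[3] == 1);
  return 0;
}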
bool shouldOptimizeCast(CastInst *CI); - /// \brief Try to optimize a sequence of instructions checking if an operation + /// Try to optimize a sequence of instructions checking if an operation /// on LHS and RHS overflows. /// /// If this overflow check is done via one of the overflow check intrinsics, @@ -445,11 +497,22 @@ private: } bool willNotOverflowSignedSub(const Value *LHS, const Value *RHS, - const Instruction &CxtI) const; + const Instruction &CxtI) const { + return computeOverflowForSignedSub(LHS, RHS, &CxtI) == + OverflowResult::NeverOverflows; + } + bool willNotOverflowUnsignedSub(const Value *LHS, const Value *RHS, - const Instruction &CxtI) const; + const Instruction &CxtI) const { + return computeOverflowForUnsignedSub(LHS, RHS, &CxtI) == + OverflowResult::NeverOverflows; + } + bool willNotOverflowSignedMul(const Value *LHS, const Value *RHS, - const Instruction &CxtI) const; + const Instruction &CxtI) const { + return computeOverflowForSignedMul(LHS, RHS, &CxtI) == + OverflowResult::NeverOverflows; + } bool willNotOverflowUnsignedMul(const Value *LHS, const Value *RHS, const Instruction &CxtI) const { @@ -462,6 +525,7 @@ private: Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask); Instruction *foldCastedBitwiseLogic(BinaryOperator &I); Instruction *narrowBinOp(TruncInst &Trunc); + Instruction *narrowMaskedBinOp(BinaryOperator &And); Instruction *narrowRotate(TruncInst &Trunc); Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); @@ -490,7 +554,7 @@ private: Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, bool JoinedByAnd, Instruction &CxtI); public: - /// \brief Inserts an instruction \p New before instruction \p Old + /// Inserts an instruction \p New before instruction \p Old /// /// Also adds the new instruction to the worklist and returns \p New so that /// it is suitable for use as the return from the visitation patterns. @@ -503,13 +567,13 @@ public: return New; } - /// \brief Same as InsertNewInstBefore, but also sets the debug loc. + /// Same as InsertNewInstBefore, but also sets the debug loc. Instruction *InsertNewInstWith(Instruction *New, Instruction &Old) { New->setDebugLoc(Old.getDebugLoc()); return InsertNewInstBefore(New, Old); } - /// \brief A combiner-aware RAUW-like routine. + /// A combiner-aware RAUW-like routine. /// /// This method is to be used when an instruction is found to be dead, /// replaceable with another preexisting expression. Here we add all uses of @@ -527,8 +591,8 @@ public: if (&I == V) V = UndefValue::get(I.getType()); - DEBUG(dbgs() << "IC: Replacing " << I << "\n" - << " with " << *V << '\n'); + LLVM_DEBUG(dbgs() << "IC: Replacing " << I << "\n" + << " with " << *V << '\n'); I.replaceAllUsesWith(V); return &I; @@ -544,13 +608,13 @@ public: return InsertValueInst::Create(Struct, Result, 0); } - /// \brief Combiner aware instruction erasure. + /// Combiner aware instruction erasure. /// /// When dealing with an instruction that has side effects or produces a void /// value, we can't rely on DCE to delete the instruction. Instead, visit /// methods should return the value returned by this function. 
Instruction *eraseInstFromFunction(Instruction &I) { - DEBUG(dbgs() << "IC: ERASE " << I << '\n'); + LLVM_DEBUG(dbgs() << "IC: ERASE " << I << '\n'); assert(I.use_empty() && "Cannot erase instruction that is used!"); salvageDebugInfo(I); @@ -599,6 +663,12 @@ public: return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT); } + OverflowResult computeOverflowForSignedMul(const Value *LHS, + const Value *RHS, + const Instruction *CxtI) const { + return llvm::computeOverflowForSignedMul(LHS, RHS, DL, &AC, CxtI, &DT); + } + OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, const Value *RHS, const Instruction *CxtI) const { @@ -611,15 +681,26 @@ public: return llvm::computeOverflowForSignedAdd(LHS, RHS, DL, &AC, CxtI, &DT); } + OverflowResult computeOverflowForUnsignedSub(const Value *LHS, + const Value *RHS, + const Instruction *CxtI) const { + return llvm::computeOverflowForUnsignedSub(LHS, RHS, DL, &AC, CxtI, &DT); + } + + OverflowResult computeOverflowForSignedSub(const Value *LHS, const Value *RHS, + const Instruction *CxtI) const { + return llvm::computeOverflowForSignedSub(LHS, RHS, DL, &AC, CxtI, &DT); + } + /// Maximum size of array considered when transforming. uint64_t MaxArraySizeForCombine; private: - /// \brief Performs a few simplifications for operators which are associative + /// Performs a few simplifications for operators which are associative /// or commutative. bool SimplifyAssociativeOrCommutative(BinaryOperator &I); - /// \brief Tries to simplify binary operations which some other binary + /// Tries to simplify binary operations which some other binary /// operation distributes over. /// /// It does this by either by factorizing out common terms (eg "(A*B)+(A*C)" @@ -628,6 +709,13 @@ private: /// value, or null if it didn't simplify. Value *SimplifyUsingDistributiveLaws(BinaryOperator &I); + /// Tries to simplify add operations using the definition of remainder. + /// + /// The definition of remainder is X % C = X - (X / C ) * C. The add + /// expression X % C0 + (( X / C0 ) % C1) * C0 can be simplified to + /// X % (C0 * C1) + Value *SimplifyAddWithRemainder(BinaryOperator &I); + // Binary Op helper for select operations where the expression can be // efficiently reorganized. Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS, @@ -647,7 +735,7 @@ private: ConstantInt *&Less, ConstantInt *&Equal, ConstantInt *&Greater); - /// \brief Attempts to replace V with a simpler value based on the demanded + /// Attempts to replace V with a simpler value based on the demanded /// bits. Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBits &Known, unsigned Depth, Instruction *CxtI); @@ -669,15 +757,19 @@ private: Instruction *Shr, const APInt &ShrOp1, Instruction *Shl, const APInt &ShlOp1, const APInt &DemandedMask, KnownBits &Known); - /// \brief Tries to simplify operands to an integer instruction based on its + /// Tries to simplify operands to an integer instruction based on its /// demanded bits. bool SimplifyDemandedInstructionBits(Instruction &Inst); + Value *simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, + APInt DemandedElts, + int DmaskIdx = -1); + Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, unsigned Depth = 0); - Value *SimplifyVectorOp(BinaryOperator &Inst); - + /// Canonicalize the position of binops relative to shufflevector. 
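// Editor's illustrative check of the SimplifyAddWithRemainder identity
// documented above: X % C0 + ((X / C0) % C1) * C0 == X % (C0 * C1) for
// unsigned X, since the first term recovers the low digit base C0 and the
// second recovers the next digit base C1.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C0 = 6, C1 = 7;
  for (uint32_t X = 0; X < 10000; ++X)
    assert(X % C0 + ((X / C0) % C1) * C0 == X % (C0 * C1));
  return 0;
}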
+ Instruction *foldShuffledBinop(BinaryOperator &Inst); /// Given a binary operator, cast instruction, or select which has a PHI node /// as operand #0, see if we can fold the instruction into the PHI (which is @@ -691,11 +783,11 @@ private: Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI); /// This is a convenience wrapper function for the above two functions. - Instruction *foldOpWithConstantIntoOperand(BinaryOperator &I); + Instruction *foldBinOpIntoSelectOrPhi(BinaryOperator &I); Instruction *foldAddWithConstant(BinaryOperator &Add); - /// \brief Try to rotate an operation below a PHI node, using PHI nodes for + /// Try to rotate an operation below a PHI node, using PHI nodes for /// its operands. Instruction *FoldPHIArgOpIntoPHI(PHINode &PN); Instruction *FoldPHIArgBinOpIntoPHI(PHINode &PN); @@ -735,6 +827,8 @@ private: Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select, ConstantInt *C); + Instruction *foldICmpBitCastConstant(ICmpInst &Cmp, BitCastInst *Bitcast, + const APInt &C); Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc, const APInt &C); Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, @@ -789,13 +883,12 @@ private: Instruction *MatchBSwap(BinaryOperator &I); bool SimplifyStoreAtEndOfBlock(StoreInst &SI); - Instruction *SimplifyElementUnorderedAtomicMemCpy(AtomicMemCpyInst *AMI); - Instruction *SimplifyMemTransfer(MemIntrinsic *MI); - Instruction *SimplifyMemSet(MemSetInst *MI); + Instruction *SimplifyAnyMemTransfer(AnyMemTransferInst *MI); + Instruction *SimplifyAnyMemSet(AnyMemSetInst *MI); Value *EvaluateInDifferentType(Value *V, Type *Ty, bool isSigned); - /// \brief Returns a value X such that Val = X * Scale, or null if none. + /// Returns a value X such that Val = X * Scale, or null if none. /// /// If the multiplication is known not to overflow then NoSignedWrap is set. 
Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap); diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index d4f06e18b957..742caf649007 100644 --- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/Loads.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/IntrinsicInst.h" @@ -23,7 +24,6 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" using namespace llvm; using namespace PatternMatch; @@ -270,7 +270,7 @@ void PointerReplacer::findLoadAndReplace(Instruction &I) { auto *Inst = dyn_cast<Instruction>(&*U); if (!Inst) return; - DEBUG(dbgs() << "Found pointer user: " << *U << '\n'); + LLVM_DEBUG(dbgs() << "Found pointer user: " << *U << '\n'); if (isa<LoadInst>(Inst)) { for (auto P : Path) replace(P); @@ -405,8 +405,8 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { Copy->getSource(), AI.getAlignment(), DL, &AI, &AC, &DT); if (AI.getAlignment() <= SourceAlign && isDereferenceableForAllocaSize(Copy->getSource(), &AI, DL)) { - DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); - DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); + LLVM_DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); + LLVM_DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); for (unsigned i = 0, e = ToDelete.size(); i != e; ++i) eraseInstFromFunction(*ToDelete[i]); Constant *TheSrc = cast<Constant>(Copy->getSource()); @@ -437,10 +437,10 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { // Are we allowed to form a atomic load or store of this type? static bool isSupportedAtomicType(Type *Ty) { - return Ty->isIntegerTy() || Ty->isPointerTy() || Ty->isFloatingPointTy(); + return Ty->isIntOrPtrTy() || Ty->isFloatingPointTy(); } -/// \brief Helper to combine a load to a new type. +/// Helper to combine a load to a new type. /// /// This just does the work of combining a load to a new type. It handles /// metadata, etc., and returns the new instruction. 
The \c NewTy should be the @@ -453,15 +453,20 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT const Twine &Suffix = "") { assert((!LI.isAtomic() || isSupportedAtomicType(NewTy)) && "can't fold an atomic load to requested type"); - + Value *Ptr = LI.getPointerOperand(); unsigned AS = LI.getPointerAddressSpace(); SmallVector<std::pair<unsigned, MDNode *>, 8> MD; LI.getAllMetadata(MD); + Value *NewPtr = nullptr; + if (!(match(Ptr, m_BitCast(m_Value(NewPtr))) && + NewPtr->getType()->getPointerElementType() == NewTy && + NewPtr->getType()->getPointerAddressSpace() == AS)) + NewPtr = IC.Builder.CreateBitCast(Ptr, NewTy->getPointerTo(AS)); + LoadInst *NewLoad = IC.Builder.CreateAlignedLoad( - IC.Builder.CreateBitCast(Ptr, NewTy->getPointerTo(AS)), - LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix); + NewPtr, LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix); NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); MDBuilder MDB(NewLoad->getContext()); for (const auto &MDPair : MD) { @@ -507,7 +512,7 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT return NewLoad; } -/// \brief Combine a store to a new type. +/// Combine a store to a new type. /// /// Returns the newly created store instruction. static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value *V) { @@ -584,7 +589,7 @@ static bool isMinMaxWithLoads(Value *V) { match(L2, m_Load(m_Specific(LHS)))); } -/// \brief Combine loads to match the type of their uses' value after looking +/// Combine loads to match the type of their uses' value after looking /// through intervening bitcasts. /// /// The core idea here is that if the result of a load is used in an operation, @@ -959,23 +964,26 @@ static Instruction *replaceGEPIdxWithZero(InstCombiner &IC, Value *Ptr, } static bool canSimplifyNullStoreOrGEP(StoreInst &SI) { - if (SI.getPointerAddressSpace() != 0) + if (NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace())) return false; auto *Ptr = SI.getPointerOperand(); if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Ptr)) Ptr = GEPI->getOperand(0); - return isa<ConstantPointerNull>(Ptr); + return (isa<ConstantPointerNull>(Ptr) && + !NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace())); } static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) { if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Op)) { const Value *GEPI0 = GEPI->getOperand(0); - if (isa<ConstantPointerNull>(GEPI0) && GEPI->getPointerAddressSpace() == 0) + if (isa<ConstantPointerNull>(GEPI0) && + !NullPointerIsDefined(LI.getFunction(), GEPI->getPointerAddressSpace())) return true; } if (isa<UndefValue>(Op) || - (isa<ConstantPointerNull>(Op) && LI.getPointerAddressSpace() == 0)) + (isa<ConstantPointerNull>(Op) && + !NullPointerIsDefined(LI.getFunction(), LI.getPointerAddressSpace()))) return true; return false; } @@ -1071,14 +1079,16 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // load (select (cond, null, P)) -> load P if (isa<ConstantPointerNull>(SI->getOperand(1)) && - LI.getPointerAddressSpace() == 0) { + !NullPointerIsDefined(SI->getFunction(), + LI.getPointerAddressSpace())) { LI.setOperand(0, SI->getOperand(2)); return &LI; } // load (select (cond, P, null)) -> load P if (isa<ConstantPointerNull>(SI->getOperand(2)) && - LI.getPointerAddressSpace() == 0) { + !NullPointerIsDefined(SI->getFunction(), + LI.getPointerAddressSpace())) { LI.setOperand(0, SI->getOperand(1)); return &LI; } @@ -1087,7 +1097,7 
@@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { return nullptr; } -/// \brief Look for extractelement/insertvalue sequence that acts like a bitcast. +/// Look for extractelement/insertvalue sequence that acts like a bitcast. /// /// \returns underlying value that was "cast", or nullptr otherwise. /// @@ -1142,7 +1152,7 @@ static Value *likeBitCastFromVector(InstCombiner &IC, Value *V) { return U; } -/// \brief Combine stores to match the type of value being stored. +/// Combine stores to match the type of value being stored. /// /// The core idea here is that the memory does not have any intrinsic type and /// where we can we should match the type of a store to the type of value being diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 541dde6c47d2..63761d427235 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -33,6 +33,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include "llvm/Transforms/Utils/BuildLibCalls.h" #include <cassert> #include <cstddef> #include <cstdint> @@ -94,115 +95,52 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, return MadeChange ? V : nullptr; } -/// True if the multiply can not be expressed in an int this size. -static bool MultiplyOverflows(const APInt &C1, const APInt &C2, APInt &Product, - bool IsSigned) { - bool Overflow; - if (IsSigned) - Product = C1.smul_ov(C2, Overflow); - else - Product = C1.umul_ov(C2, Overflow); - - return Overflow; -} - -/// \brief True if C2 is a multiple of C1. Quotient contains C2/C1. -static bool IsMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, - bool IsSigned) { - assert(C1.getBitWidth() == C2.getBitWidth() && - "Inconsistent width of constants!"); - - // Bail if we will divide by zero. - if (C2.isMinValue()) - return false; - - // Bail if we would divide INT_MIN by -1. - if (IsSigned && C1.isMinSignedValue() && C2.isAllOnesValue()) - return false; - - APInt Remainder(C1.getBitWidth(), /*Val=*/0ULL, IsSigned); - if (IsSigned) - APInt::sdivrem(C1, C2, Quotient, Remainder); - else - APInt::udivrem(C1, C2, Quotient, Remainder); - - return Remainder.isMinValue(); -} - -/// \brief A helper routine of InstCombiner::visitMul(). +/// A helper routine of InstCombiner::visitMul(). /// -/// If C is a vector of known powers of 2, then this function returns -/// a new vector obtained from C replacing each element with its logBase2. +/// If C is a scalar/vector of known powers of 2, then this function returns +/// a new scalar/vector obtained from logBase2 of C. /// Return a null pointer otherwise. 
-static Constant *getLogBase2Vector(ConstantDataVector *CV) { +static Constant *getLogBase2(Type *Ty, Constant *C) { const APInt *IVal; - SmallVector<Constant *, 4> Elts; + if (match(C, m_APInt(IVal)) && IVal->isPowerOf2()) + return ConstantInt::get(Ty, IVal->logBase2()); + + if (!Ty->isVectorTy()) + return nullptr; - for (unsigned I = 0, E = CV->getNumElements(); I != E; ++I) { - Constant *Elt = CV->getElementAsConstant(I); + SmallVector<Constant *, 4> Elts; + for (unsigned I = 0, E = Ty->getVectorNumElements(); I != E; ++I) { + Constant *Elt = C->getAggregateElement(I); + if (!Elt) + return nullptr; + if (isa<UndefValue>(Elt)) { + Elts.push_back(UndefValue::get(Ty->getScalarType())); + continue; + } if (!match(Elt, m_APInt(IVal)) || !IVal->isPowerOf2()) return nullptr; - Elts.push_back(ConstantInt::get(Elt->getType(), IVal->logBase2())); + Elts.push_back(ConstantInt::get(Ty->getScalarType(), IVal->logBase2())); } return ConstantVector::get(Elts); } -/// \brief Return true if we can prove that: -/// (mul LHS, RHS) === (mul nsw LHS, RHS) -bool InstCombiner::willNotOverflowSignedMul(const Value *LHS, - const Value *RHS, - const Instruction &CxtI) const { - // Multiplying n * m significant bits yields a result of n + m significant - // bits. If the total number of significant bits does not exceed the - // result bit width (minus 1), there is no overflow. - // This means if we have enough leading sign bits in the operands - // we can guarantee that the result does not overflow. - // Ref: "Hacker's Delight" by Henry Warren - unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); - - // Note that underestimating the number of sign bits gives a more - // conservative answer. - unsigned SignBits = - ComputeNumSignBits(LHS, 0, &CxtI) + ComputeNumSignBits(RHS, 0, &CxtI); - - // First handle the easy case: if we have enough sign bits there's - // definitely no overflow. - if (SignBits > BitWidth + 1) - return true; - - // There are two ambiguous cases where there can be no overflow: - // SignBits == BitWidth + 1 and - // SignBits == BitWidth - // The second case is difficult to check, therefore we only handle the - // first case. - if (SignBits == BitWidth + 1) { - // It overflows only when both arguments are negative and the true - // product is exactly the minimum negative number. - // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 - // For simplicity we just check if at least one side is not negative. 
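To make the sign-bit argument in the removed helper concrete: with i16 operands carrying 17 sign bits in total, the only overflowing product is exactly the minimum signed value, as in the 0xff00 * 0xff80 example quoted above. A minimal standalone C++ sketch of that one case (illustration only, not part of this patch):

    #include <cstdint>
    #include <cstdio>

    int main() {
      int16_t A = (int16_t)0xff00; // -256, 8 sign bits
      int16_t B = (int16_t)0xff80; // -128, 9 sign bits (8 + 9 == 17 total)
      int32_t Wide = (int32_t)A * (int32_t)B; // true product: 32768
      int16_t Narrow = (int16_t)Wide;         // wraps to -32768 (0x8000) on
                                              // two's-complement targets
      printf("wide=%d narrow=%d\n", Wide, Narrow);
      return 0;
    }
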
- KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, &CxtI); - KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, &CxtI); - if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative()) - return true; - } - return false; -} - Instruction *InstCombiner::visitMul(BinaryOperator &I) { - bool Changed = SimplifyAssociativeOrCommutative(I); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyMulInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyMulInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (SimplifyAssociativeOrCommutative(I)) + return &I; + + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); // X * -1 == 0 - X + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (match(Op1, m_AllOnes())) { BinaryOperator *BO = BinaryOperator::CreateNeg(Op0, I.getName()); if (I.hasNoSignedWrap()) @@ -231,16 +169,8 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } if (match(&I, m_Mul(m_Value(NewOp), m_Constant(C1)))) { - Constant *NewCst = nullptr; - if (match(C1, m_APInt(IVal)) && IVal->isPowerOf2()) - // Replace X*(2^C) with X << C, where C is either a scalar or a splat. - NewCst = ConstantInt::get(NewOp->getType(), IVal->logBase2()); - else if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(C1)) - // Replace X*(2^C) with X << C, where C is a vector of known - // constant powers of 2. - NewCst = getLogBase2Vector(CV); - - if (NewCst) { + // Replace X*(2^C) with X << C, where C is either a scalar or a vector. + if (Constant *NewCst = getLogBase2(NewOp->getType(), C1)) { unsigned Width = NewCst->getType()->getPrimitiveSizeInBits(); BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst); @@ -282,34 +212,37 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I)) + return FoldedMul; + // Simplify mul instructions with a constant RHS. if (isa<Constant>(Op1)) { - if (Instruction *FoldedMul = foldOpWithConstantIntoOperand(I)) - return FoldedMul; - // Canonicalize (X+C1)*CI -> X*CI+C1*CI. - { - Value *X; - Constant *C1; - if (match(Op0, m_OneUse(m_Add(m_Value(X), m_Constant(C1))))) { - Value *Mul = Builder.CreateMul(C1, Op1); - // Only go forward with the transform if C1*CI simplifies to a tidier - // constant. - if (!match(Mul, m_Mul(m_Value(), m_Value()))) - return BinaryOperator::CreateAdd(Builder.CreateMul(X, Op1), Mul); - } + Value *X; + Constant *C1; + if (match(Op0, m_OneUse(m_Add(m_Value(X), m_Constant(C1))))) { + Value *Mul = Builder.CreateMul(C1, Op1); + // Only go forward with the transform if C1*CI simplifies to a tidier + // constant. 
+ if (!match(Mul, m_Mul(m_Value(), m_Value()))) + return BinaryOperator::CreateAdd(Builder.CreateMul(X, Op1), Mul); } } - if (Value *Op0v = dyn_castNegVal(Op0)) { // -X * -Y = X*Y - if (Value *Op1v = dyn_castNegVal(Op1)) { - BinaryOperator *BO = BinaryOperator::CreateMul(Op0v, Op1v); - if (I.hasNoSignedWrap() && - match(Op0, m_NSWSub(m_Value(), m_Value())) && - match(Op1, m_NSWSub(m_Value(), m_Value()))) - BO->setHasNoSignedWrap(); - return BO; - } + // -X * C --> X * -C + Value *X, *Y; + Constant *Op1C; + if (match(Op0, m_Neg(m_Value(X))) && match(Op1, m_Constant(Op1C))) + return BinaryOperator::CreateMul(X, ConstantExpr::getNeg(Op1C)); + + // -X * -Y --> X * Y + if (match(Op0, m_Neg(m_Value(X))) && match(Op1, m_Neg(m_Value(Y)))) { + auto *NewMul = BinaryOperator::CreateMul(X, Y); + if (I.hasNoSignedWrap() && + cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap() && + cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap()) + NewMul->setHasNoSignedWrap(); + return NewMul; } // (X / Y) * Y = X - (X % Y) @@ -371,28 +304,24 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } - // If one of the operands of the multiply is a cast from a boolean value, then - // we know the bool is either zero or one, so this is a 'masking' multiply. - // X * Y (where Y is 0 or 1) -> X & (0-Y) - if (!I.getType()->isVectorTy()) { - // -2 is "-1 << 1" so it is all bits set except the low one. - APInt Negative2(I.getType()->getPrimitiveSizeInBits(), (uint64_t)-2, true); - - Value *BoolCast = nullptr, *OtherOp = nullptr; - if (MaskedValueIsZero(Op0, Negative2, 0, &I)) { - BoolCast = Op0; - OtherOp = Op1; - } else if (MaskedValueIsZero(Op1, Negative2, 0, &I)) { - BoolCast = Op1; - OtherOp = Op0; - } - - if (BoolCast) { - Value *V = Builder.CreateSub(Constant::getNullValue(I.getType()), - BoolCast); - return BinaryOperator::CreateAnd(V, OtherOp); - } - } + // (bool X) * Y --> X ? Y : 0 + // Y * (bool X) --> X ? Y : 0 + if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(X, Op1, ConstantInt::get(I.getType(), 0)); + if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(X, Op0, ConstantInt::get(I.getType(), 0)); + + // (lshr X, 31) * Y --> (ashr X, 31) & Y + // Y * (lshr X, 31) --> (ashr X, 31) & Y + // TODO: We are not checking one-use because the elimination of the multiply + // is better for analysis? + // TODO: Should we canonicalize to '(X < 0) ? Y : 0' instead? That would be + // more similar to what we're doing above. + const APInt *C; + if (match(Op0, m_LShr(m_Value(X), m_APInt(C))) && *C == C->getBitWidth() - 1) + return BinaryOperator::CreateAnd(Builder.CreateAShr(X, *C), Op1); + if (match(Op1, m_LShr(m_Value(X), m_APInt(C))) && *C == C->getBitWidth() - 1) + return BinaryOperator::CreateAnd(Builder.CreateAShr(X, *C), Op0); // Check for (mul (sext x), y), see if we can merge this into an // integer mul followed by a sext. @@ -466,6 +395,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + bool Changed = false; if (!I.hasNoSignedWrap() && willNotOverflowSignedMul(Op0, Op1, I)) { Changed = true; I.setHasNoSignedWrap(true); @@ -479,286 +409,103 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { return Changed ? &I : nullptr; } -/// Detect pattern log2(Y * 0.5) with corresponding fast math flags. 
-static void detectLog2OfHalf(Value *&Op, Value *&Y, IntrinsicInst *&Log2) { - if (!Op->hasOneUse()) - return; - - IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op); - if (!II) - return; - if (II->getIntrinsicID() != Intrinsic::log2 || !II->isFast()) - return; - Log2 = II; - - Value *OpLog2Of = II->getArgOperand(0); - if (!OpLog2Of->hasOneUse()) - return; - - Instruction *I = dyn_cast<Instruction>(OpLog2Of); - if (!I) - return; - - if (I->getOpcode() != Instruction::FMul || !I->isFast()) - return; - - if (match(I->getOperand(0), m_SpecificFP(0.5))) - Y = I->getOperand(1); - else if (match(I->getOperand(1), m_SpecificFP(0.5))) - Y = I->getOperand(0); -} - -static bool isFiniteNonZeroFp(Constant *C) { - if (C->getType()->isVectorTy()) { - for (unsigned I = 0, E = C->getType()->getVectorNumElements(); I != E; - ++I) { - ConstantFP *CFP = dyn_cast_or_null<ConstantFP>(C->getAggregateElement(I)); - if (!CFP || !CFP->getValueAPF().isFiniteNonZero()) - return false; - } - return true; - } - - return isa<ConstantFP>(C) && - cast<ConstantFP>(C)->getValueAPF().isFiniteNonZero(); -} - -static bool isNormalFp(Constant *C) { - if (C->getType()->isVectorTy()) { - for (unsigned I = 0, E = C->getType()->getVectorNumElements(); I != E; - ++I) { - ConstantFP *CFP = dyn_cast_or_null<ConstantFP>(C->getAggregateElement(I)); - if (!CFP || !CFP->getValueAPF().isNormal()) - return false; - } - return true; - } - - return isa<ConstantFP>(C) && cast<ConstantFP>(C)->getValueAPF().isNormal(); -} - -/// Helper function of InstCombiner::visitFMul(BinaryOperator(). It returns -/// true iff the given value is FMul or FDiv with one and only one operand -/// being a normal constant (i.e. not Zero/NaN/Infinity). -static bool isFMulOrFDivWithConstant(Value *V) { - Instruction *I = dyn_cast<Instruction>(V); - if (!I || (I->getOpcode() != Instruction::FMul && - I->getOpcode() != Instruction::FDiv)) - return false; - - Constant *C0 = dyn_cast<Constant>(I->getOperand(0)); - Constant *C1 = dyn_cast<Constant>(I->getOperand(1)); - - if (C0 && C1) - return false; - - return (C0 && isFiniteNonZeroFp(C0)) || (C1 && isFiniteNonZeroFp(C1)); -} +Instruction *InstCombiner::visitFMul(BinaryOperator &I) { + if (Value *V = SimplifyFMulInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); -/// foldFMulConst() is a helper routine of InstCombiner::visitFMul(). -/// The input \p FMulOrDiv is a FMul/FDiv with one and only one operand -/// being a constant (i.e. isFMulOrFDivWithConstant(FMulOrDiv) == true). -/// This function is to simplify "FMulOrDiv * C" and returns the -/// resulting expression. Note that this function could return NULL in -/// case the constants cannot be folded into a normal floating-point. -Value *InstCombiner::foldFMulConst(Instruction *FMulOrDiv, Constant *C, - Instruction *InsertBefore) { - assert(isFMulOrFDivWithConstant(FMulOrDiv) && "V is invalid"); - - Value *Opnd0 = FMulOrDiv->getOperand(0); - Value *Opnd1 = FMulOrDiv->getOperand(1); - - Constant *C0 = dyn_cast<Constant>(Opnd0); - Constant *C1 = dyn_cast<Constant>(Opnd1); - - BinaryOperator *R = nullptr; - - // (X * C0) * C => X * (C0*C) - if (FMulOrDiv->getOpcode() == Instruction::FMul) { - Constant *F = ConstantExpr::getFMul(C1 ? C1 : C0, C); - if (isNormalFp(F)) - R = BinaryOperator::CreateFMul(C1 ? Opnd0 : Opnd1, F); - } else { - if (C0) { - // (C0 / X) * C => (C0 * C) / X - if (FMulOrDiv->hasOneUse()) { - // It would otherwise introduce another div. 
- Constant *F = ConstantExpr::getFMul(C0, C); - if (isNormalFp(F)) - R = BinaryOperator::CreateFDiv(F, Opnd1); - } - } else { - // (X / C1) * C => X * (C/C1) if C/C1 is not a denormal - Constant *F = ConstantExpr::getFDiv(C, C1); - if (isNormalFp(F)) { - R = BinaryOperator::CreateFMul(Opnd0, F); - } else { - // (X / C1) * C => X / (C1/C) - Constant *F = ConstantExpr::getFDiv(C1, C); - if (isNormalFp(F)) - R = BinaryOperator::CreateFDiv(Opnd0, F); - } - } - } + if (SimplifyAssociativeOrCommutative(I)) + return &I; - if (R) { - R->setFast(true); - InsertNewInstWith(R, *InsertBefore); - } + if (Instruction *X = foldShuffledBinop(I)) + return X; - return R; -} + if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I)) + return FoldedMul; -Instruction *InstCombiner::visitFMul(BinaryOperator &I) { - bool Changed = SimplifyAssociativeOrCommutative(I); + // X * -1.0 --> -X Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (match(Op1, m_SpecificFP(-1.0))) + return BinaryOperator::CreateFNegFMF(Op0, &I); - if (Value *V = SimplifyVectorOp(I)) - return replaceInstUsesWith(I, V); + // -X * -Y --> X * Y + Value *X, *Y; + if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) + return BinaryOperator::CreateFMulFMF(X, Y, &I); - if (isa<Constant>(Op0)) - std::swap(Op0, Op1); + // -X * C --> X * -C + Constant *C; + if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_Constant(C))) + return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); - if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), - SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + // Sink negation: -X * Y --> -(X * Y) + if (match(Op0, m_OneUse(m_FNeg(m_Value(X))))) + return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op1, &I), &I); - bool AllowReassociate = I.isFast(); + // Sink negation: Y * -X --> -(X * Y) + if (match(Op1, m_OneUse(m_FNeg(m_Value(X))))) + return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op0, &I), &I); - // Simplify mul instructions with a constant RHS. - if (isa<Constant>(Op1)) { - if (Instruction *FoldedMul = foldOpWithConstantIntoOperand(I)) - return FoldedMul; - - // (fmul X, -1.0) --> (fsub -0.0, X) - if (match(Op1, m_SpecificFP(-1.0))) { - Constant *NegZero = ConstantFP::getNegativeZero(Op1->getType()); - Instruction *RI = BinaryOperator::CreateFSub(NegZero, Op0); - RI->copyFastMathFlags(&I); - return RI; - } - - Constant *C = cast<Constant>(Op1); - if (AllowReassociate && isFiniteNonZeroFp(C)) { - // Let MDC denote an expression in one of these forms: - // X * C, C/X, X/C, where C is a constant. 
- // - // Try to simplify "MDC * Constant" - if (isFMulOrFDivWithConstant(Op0)) - if (Value *V = foldFMulConst(cast<Instruction>(Op0), C, &I)) - return replaceInstUsesWith(I, V); - - // (MDC +/- C1) * C => (MDC * C) +/- (C1 * C) - Instruction *FAddSub = dyn_cast<Instruction>(Op0); - if (FAddSub && - (FAddSub->getOpcode() == Instruction::FAdd || - FAddSub->getOpcode() == Instruction::FSub)) { - Value *Opnd0 = FAddSub->getOperand(0); - Value *Opnd1 = FAddSub->getOperand(1); - Constant *C0 = dyn_cast<Constant>(Opnd0); - Constant *C1 = dyn_cast<Constant>(Opnd1); - bool Swap = false; - if (C0) { - std::swap(C0, C1); - std::swap(Opnd0, Opnd1); - Swap = true; - } + // fabs(X) * fabs(X) -> X * X + if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X)))) + return BinaryOperator::CreateFMulFMF(X, X, &I); - if (C1 && isFiniteNonZeroFp(C1) && isFMulOrFDivWithConstant(Opnd0)) { - Value *M1 = ConstantExpr::getFMul(C1, C); - Value *M0 = isNormalFp(cast<Constant>(M1)) ? - foldFMulConst(cast<Instruction>(Opnd0), C, &I) : - nullptr; - if (M0 && M1) { - if (Swap && FAddSub->getOpcode() == Instruction::FSub) - std::swap(M0, M1); - - Instruction *RI = (FAddSub->getOpcode() == Instruction::FAdd) - ? BinaryOperator::CreateFAdd(M0, M1) - : BinaryOperator::CreateFSub(M0, M1); - RI->copyFastMathFlags(&I); - return RI; - } - } - } - } - } + // (select A, B, C) * (select A, D, E) --> select A, (B*D), (C*E) + if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) + return replaceInstUsesWith(I, V); - if (Op0 == Op1) { - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) { - // sqrt(X) * sqrt(X) -> X - if (AllowReassociate && II->getIntrinsicID() == Intrinsic::sqrt) - return replaceInstUsesWith(I, II->getOperand(0)); - - // fabs(X) * fabs(X) -> X * X - if (II->getIntrinsicID() == Intrinsic::fabs) { - Instruction *FMulVal = BinaryOperator::CreateFMul(II->getOperand(0), - II->getOperand(0), - I.getName()); - FMulVal->copyFastMathFlags(&I); - return FMulVal; + if (I.hasAllowReassoc()) { + // Reassociate constant RHS with another constant to form constant + // expression. + if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) { + Constant *C1; + if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) { + // (C1 / X) * C --> (C * C1) / X + Constant *CC1 = ConstantExpr::getFMul(C, C1); + if (CC1->isNormalFP()) + return BinaryOperator::CreateFDivFMF(CC1, X, &I); } - } - } - - // Under unsafe algebra do: - // X * log2(0.5*Y) = X*log2(Y) - X - if (AllowReassociate) { - Value *OpX = nullptr; - Value *OpY = nullptr; - IntrinsicInst *Log2; - detectLog2OfHalf(Op0, OpY, Log2); - if (OpY) { - OpX = Op1; - } else { - detectLog2OfHalf(Op1, OpY, Log2); - if (OpY) { - OpX = Op0; + if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) { + // (X / C1) * C --> X * (C / C1) + Constant *CDivC1 = ConstantExpr::getFDiv(C, C1); + if (CDivC1->isNormalFP()) + return BinaryOperator::CreateFMulFMF(X, CDivC1, &I); + + // If the constant was a denormal, try reassociating differently. 
+ // (X / C1) * C --> X / (C1 / C) + Constant *C1DivC = ConstantExpr::getFDiv(C1, C); + if (Op0->hasOneUse() && C1DivC->isNormalFP()) + return BinaryOperator::CreateFDivFMF(X, C1DivC, &I); } - } - // if pattern detected emit alternate sequence - if (OpX && OpY) { - BuilderTy::FastMathFlagGuard Guard(Builder); - Builder.setFastMathFlags(Log2->getFastMathFlags()); - Log2->setArgOperand(0, OpY); - Value *FMulVal = Builder.CreateFMul(OpX, Log2); - Value *FSub = Builder.CreateFSub(FMulVal, OpX); - FSub->takeName(&I); - return replaceInstUsesWith(I, FSub); - } - } - // Handle symmetric situation in a 2-iteration loop - Value *Opnd0 = Op0; - Value *Opnd1 = Op1; - for (int i = 0; i < 2; i++) { - bool IgnoreZeroSign = I.hasNoSignedZeros(); - if (BinaryOperator::isFNeg(Opnd0, IgnoreZeroSign)) { - BuilderTy::FastMathFlagGuard Guard(Builder); - Builder.setFastMathFlags(I.getFastMathFlags()); - - Value *N0 = dyn_castFNegVal(Opnd0, IgnoreZeroSign); - Value *N1 = dyn_castFNegVal(Opnd1, IgnoreZeroSign); - - // -X * -Y => X*Y - if (N1) { - Value *FMul = Builder.CreateFMul(N0, N1); - FMul->takeName(&I); - return replaceInstUsesWith(I, FMul); + // We do not need to match 'fadd C, X' and 'fsub X, C' because they are + // canonicalized to 'fadd X, C'. Distributing the multiply may allow + // further folds and (X * C) + C2 is 'fma'. + if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) { + // (X + C1) * C --> (X * C) + (C * C1) + Constant *CC1 = ConstantExpr::getFMul(C, C1); + Value *XC = Builder.CreateFMulFMF(X, C, &I); + return BinaryOperator::CreateFAddFMF(XC, CC1, &I); } - - if (Opnd0->hasOneUse()) { - // -X * Y => -(X*Y) (Promote negation as high as possible) - Value *T = Builder.CreateFMul(N0, Opnd1); - Value *Neg = Builder.CreateFNeg(T); - Neg->takeName(&I); - return replaceInstUsesWith(I, Neg); + if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) { + // (C1 - X) * C --> (C * C1) - (X * C) + Constant *CC1 = ConstantExpr::getFMul(C, C1); + Value *XC = Builder.CreateFMulFMF(X, C, &I); + return BinaryOperator::CreateFSubFMF(CC1, XC, &I); } } - // Handle specials cases for FMul with selects feeding the operation - if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) - return replaceInstUsesWith(I, V); + // sqrt(X) * sqrt(Y) -> sqrt(X * Y) + // nnan disallows the possibility of returning a number if both operands are + // negative (in that case, we should return NaN). + if (I.hasNoNaNs() && + match(Op0, m_OneUse(m_Intrinsic<Intrinsic::sqrt>(m_Value(X)))) && + match(Op1, m_OneUse(m_Intrinsic<Intrinsic::sqrt>(m_Value(Y))))) { + Value *XY = Builder.CreateFMulFMF(X, Y, &I); + Value *Sqrt = Builder.CreateIntrinsic(Intrinsic::sqrt, { XY }, &I); + return replaceInstUsesWith(I, Sqrt); + } // (X*Y) * X => (X*X) * Y where Y != X // The purpose is two-fold: @@ -767,34 +514,40 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { // latency of the instruction Y is amortized by the expression of X*X, // and therefore Y is in a "less critical" position compared to what it // was before the transformation. 
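As a scalar illustration of the (X*Y)*X factoring rationale described in the comment above (a sketch assuming reassociation-allowed fast-math; not part of the patch):

    // (X * Y) * X  ==>  (X * X) * Y: the X*X term no longer depends on Y, so
    // it can be CSE'd or hoisted, and Y moves off the critical path. Only
    // legal when reassociation is allowed, since FP multiply is not
    // associative in general.
    float factorSquare(float X, float Y) {
      float XX = X * X;
      return XX * Y;
    }
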
- if (AllowReassociate) { - Value *Opnd0_0, *Opnd0_1; - if (Opnd0->hasOneUse() && - match(Opnd0, m_FMul(m_Value(Opnd0_0), m_Value(Opnd0_1)))) { - Value *Y = nullptr; - if (Opnd0_0 == Opnd1 && Opnd0_1 != Opnd1) - Y = Opnd0_1; - else if (Opnd0_1 == Opnd1 && Opnd0_0 != Opnd1) - Y = Opnd0_0; - - if (Y) { - BuilderTy::FastMathFlagGuard Guard(Builder); - Builder.setFastMathFlags(I.getFastMathFlags()); - Value *T = Builder.CreateFMul(Opnd1, Opnd1); - Value *R = Builder.CreateFMul(T, Y); - R->takeName(&I); - return replaceInstUsesWith(I, R); - } - } + if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) && + Op1 != Y) { + Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I); + return BinaryOperator::CreateFMulFMF(XX, Y, &I); + } + if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) && + Op0 != Y) { + Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I); + return BinaryOperator::CreateFMulFMF(XX, Y, &I); } + } - if (!isa<Constant>(Op1)) - std::swap(Opnd0, Opnd1); - else - break; + // log2(X * 0.5) * Y = log2(X) * Y - Y + if (I.isFast()) { + IntrinsicInst *Log2 = nullptr; + if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::log2>( + m_OneUse(m_FMul(m_Value(X), m_SpecificFP(0.5))))))) { + Log2 = cast<IntrinsicInst>(Op0); + Y = Op1; + } + if (match(Op1, m_OneUse(m_Intrinsic<Intrinsic::log2>( + m_OneUse(m_FMul(m_Value(X), m_SpecificFP(0.5))))))) { + Log2 = cast<IntrinsicInst>(Op1); + Y = Op0; + } + if (Log2) { + Log2->setArgOperand(0, X); + Log2->copyFastMathFlags(&I); + Value *LogXTimesY = Builder.CreateFMulFMF(Log2, Y, &I); + return BinaryOperator::CreateFSubFMF(LogXTimesY, Y, &I); + } } - return Changed ? &I : nullptr; + return nullptr; } /// Fold a divide or remainder with a select instruction divisor when one of the @@ -835,9 +588,9 @@ bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) { Type *CondTy = SelectCond->getType(); while (BBI != BBFront) { --BBI; - // If we found a call to a function, we can't assume it will return, so + // If we found an instruction that we can't assume will return, so // information from below it cannot be propagated above it. - if (isa<CallInst>(BBI) && !isa<IntrinsicInst>(BBI)) + if (!isGuaranteedToTransferExecutionToSuccessor(&*BBI)) break; // Replace uses of the select or its condition with the known values. @@ -867,12 +620,44 @@ bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) { return true; } +/// True if the multiply can not be expressed in an int this size. +static bool multiplyOverflows(const APInt &C1, const APInt &C2, APInt &Product, + bool IsSigned) { + bool Overflow; + Product = IsSigned ? C1.smul_ov(C2, Overflow) : C1.umul_ov(C2, Overflow); + return Overflow; +} + +/// True if C1 is a multiple of C2. Quotient contains C1/C2. +static bool isMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, + bool IsSigned) { + assert(C1.getBitWidth() == C2.getBitWidth() && "Constant widths not equal"); + + // Bail if we will divide by zero. + if (C2.isNullValue()) + return false; + + // Bail if we would divide INT_MIN by -1. + if (IsSigned && C1.isMinSignedValue() && C2.isAllOnesValue()) + return false; + + APInt Remainder(C1.getBitWidth(), /*Val=*/0ULL, IsSigned); + if (IsSigned) + APInt::sdivrem(C1, C2, Quotient, Remainder); + else + APInt::udivrem(C1, C2, Quotient, Remainder); + + return Remainder.isMinValue(); +} + /// This function implements the transforms common to both integer division /// instructions (udiv and sdiv). It is called by the visitors to those integer /// division instructions. 
-/// @brief Common integer divide transforms +/// Common integer divide transforms Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + bool IsSigned = I.getOpcode() == Instruction::SDiv; + Type *Ty = I.getType(); // The RHS is known non-zero. if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) { @@ -885,94 +670,87 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { if (simplifyDivRemOfSelectWithZeroOp(I)) return &I; - if (Instruction *LHS = dyn_cast<Instruction>(Op0)) { - const APInt *C2; - if (match(Op1, m_APInt(C2))) { - Value *X; - const APInt *C1; - bool IsSigned = I.getOpcode() == Instruction::SDiv; - - // (X / C1) / C2 -> X / (C1*C2) - if ((IsSigned && match(LHS, m_SDiv(m_Value(X), m_APInt(C1)))) || - (!IsSigned && match(LHS, m_UDiv(m_Value(X), m_APInt(C1))))) { - APInt Product(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); - if (!MultiplyOverflows(*C1, *C2, Product, IsSigned)) - return BinaryOperator::Create(I.getOpcode(), X, - ConstantInt::get(I.getType(), Product)); - } - - if ((IsSigned && match(LHS, m_NSWMul(m_Value(X), m_APInt(C1)))) || - (!IsSigned && match(LHS, m_NUWMul(m_Value(X), m_APInt(C1))))) { - APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + const APInt *C2; + if (match(Op1, m_APInt(C2))) { + Value *X; + const APInt *C1; + + // (X / C1) / C2 -> X / (C1*C2) + if ((IsSigned && match(Op0, m_SDiv(m_Value(X), m_APInt(C1)))) || + (!IsSigned && match(Op0, m_UDiv(m_Value(X), m_APInt(C1))))) { + APInt Product(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + if (!multiplyOverflows(*C1, *C2, Product, IsSigned)) + return BinaryOperator::Create(I.getOpcode(), X, + ConstantInt::get(Ty, Product)); + } - // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1. - if (IsMultiple(*C2, *C1, Quotient, IsSigned)) { - BinaryOperator *BO = BinaryOperator::Create( - I.getOpcode(), X, ConstantInt::get(X->getType(), Quotient)); - BO->setIsExact(I.isExact()); - return BO; - } + if ((IsSigned && match(Op0, m_NSWMul(m_Value(X), m_APInt(C1)))) || + (!IsSigned && match(Op0, m_NUWMul(m_Value(X), m_APInt(C1))))) { + APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); - // (X * C1) / C2 -> X * (C1 / C2) if C1 is a multiple of C2. - if (IsMultiple(*C1, *C2, Quotient, IsSigned)) { - BinaryOperator *BO = BinaryOperator::Create( - Instruction::Mul, X, ConstantInt::get(X->getType(), Quotient)); - BO->setHasNoUnsignedWrap( - !IsSigned && - cast<OverflowingBinaryOperator>(LHS)->hasNoUnsignedWrap()); - BO->setHasNoSignedWrap( - cast<OverflowingBinaryOperator>(LHS)->hasNoSignedWrap()); - return BO; - } + // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1. + if (isMultiple(*C2, *C1, Quotient, IsSigned)) { + auto *NewDiv = BinaryOperator::Create(I.getOpcode(), X, + ConstantInt::get(Ty, Quotient)); + NewDiv->setIsExact(I.isExact()); + return NewDiv; } - if ((IsSigned && match(LHS, m_NSWShl(m_Value(X), m_APInt(C1))) && - *C1 != C1->getBitWidth() - 1) || - (!IsSigned && match(LHS, m_NUWShl(m_Value(X), m_APInt(C1))))) { - APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); - APInt C1Shifted = APInt::getOneBitSet( - C1->getBitWidth(), static_cast<unsigned>(C1->getLimitedValue())); - - // (X << C1) / C2 -> X / (C2 >> C1) if C2 is a multiple of C1. 
- if (IsMultiple(*C2, C1Shifted, Quotient, IsSigned)) { - BinaryOperator *BO = BinaryOperator::Create( - I.getOpcode(), X, ConstantInt::get(X->getType(), Quotient)); - BO->setIsExact(I.isExact()); - return BO; - } + // (X * C1) / C2 -> X * (C1 / C2) if C1 is a multiple of C2. + if (isMultiple(*C1, *C2, Quotient, IsSigned)) { + auto *Mul = BinaryOperator::Create(Instruction::Mul, X, + ConstantInt::get(Ty, Quotient)); + auto *OBO = cast<OverflowingBinaryOperator>(Op0); + Mul->setHasNoUnsignedWrap(!IsSigned && OBO->hasNoUnsignedWrap()); + Mul->setHasNoSignedWrap(OBO->hasNoSignedWrap()); + return Mul; + } + } - // (X << C1) / C2 -> X * (C2 >> C1) if C1 is a multiple of C2. - if (IsMultiple(C1Shifted, *C2, Quotient, IsSigned)) { - BinaryOperator *BO = BinaryOperator::Create( - Instruction::Mul, X, ConstantInt::get(X->getType(), Quotient)); - BO->setHasNoUnsignedWrap( - !IsSigned && - cast<OverflowingBinaryOperator>(LHS)->hasNoUnsignedWrap()); - BO->setHasNoSignedWrap( - cast<OverflowingBinaryOperator>(LHS)->hasNoSignedWrap()); - return BO; - } + if ((IsSigned && match(Op0, m_NSWShl(m_Value(X), m_APInt(C1))) && + *C1 != C1->getBitWidth() - 1) || + (!IsSigned && match(Op0, m_NUWShl(m_Value(X), m_APInt(C1))))) { + APInt Quotient(C1->getBitWidth(), /*Val=*/0ULL, IsSigned); + APInt C1Shifted = APInt::getOneBitSet( + C1->getBitWidth(), static_cast<unsigned>(C1->getLimitedValue())); + + // (X << C1) / C2 -> X / (C2 >> C1) if C2 is a multiple of 1 << C1. + if (isMultiple(*C2, C1Shifted, Quotient, IsSigned)) { + auto *BO = BinaryOperator::Create(I.getOpcode(), X, + ConstantInt::get(Ty, Quotient)); + BO->setIsExact(I.isExact()); + return BO; } - if (!C2->isNullValue()) // avoid X udiv 0 - if (Instruction *FoldedDiv = foldOpWithConstantIntoOperand(I)) - return FoldedDiv; + // (X << C1) / C2 -> X * ((1 << C1) / C2) if 1 << C1 is a multiple of C2. + if (isMultiple(C1Shifted, *C2, Quotient, IsSigned)) { + auto *Mul = BinaryOperator::Create(Instruction::Mul, X, + ConstantInt::get(Ty, Quotient)); + auto *OBO = cast<OverflowingBinaryOperator>(Op0); + Mul->setHasNoUnsignedWrap(!IsSigned && OBO->hasNoUnsignedWrap()); + Mul->setHasNoSignedWrap(OBO->hasNoSignedWrap()); + return Mul; + } } + + if (!C2->isNullValue()) // avoid X udiv 0 + if (Instruction *FoldedDiv = foldBinOpIntoSelectOrPhi(I)) + return FoldedDiv; } if (match(Op0, m_One())) { - assert(!I.getType()->isIntOrIntVectorTy(1) && "i1 divide not removed?"); - if (I.getOpcode() == Instruction::SDiv) { + assert(!Ty->isIntOrIntVectorTy(1) && "i1 divide not removed?"); + if (IsSigned) { // If Op1 is 0 then it's undefined behaviour, if Op1 is 1 then the // result is one, if Op1 is -1 then the result is minus one, otherwise // it's zero. Value *Inc = Builder.CreateAdd(Op1, Op0); - Value *Cmp = Builder.CreateICmpULT(Inc, ConstantInt::get(I.getType(), 3)); - return SelectInst::Create(Cmp, Op1, ConstantInt::get(I.getType(), 0)); + Value *Cmp = Builder.CreateICmpULT(Inc, ConstantInt::get(Ty, 3)); + return SelectInst::Create(Cmp, Op1, ConstantInt::get(Ty, 0)); } else { // If Op1 is 0 then it's undefined behaviour. If Op1 is 1 then the // result is one, otherwise it's zero. 
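The branch-free select emitted here for a 1-dividend sdiv can be modeled in plain C++ as follows (a sketch of the identity, assuming a nonzero divisor; not part of the patch):

    #include <cassert>
    #include <cstdint>

    // 1 sdiv X: the quotient is 1 for X==1, -1 for X==-1, and 0 otherwise.
    // (X + 1) u< 3 holds exactly for X in {-1, 0, 1}, and for those X the
    // quotient equals X itself (X == 0 is UB for sdiv, so any result is fine).
    static int32_t oneSDiv(int32_t X) {
      return (uint32_t)X + 1u < 3u ? X : 0;
    }

    int main() {
      assert(oneSDiv(1) == 1 && oneSDiv(-1) == -1 && oneSDiv(7) == 0);
      return 0;
    }
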
- return new ZExtInst(Builder.CreateICmpEQ(Op1, Op0), I.getType()); + return new ZExtInst(Builder.CreateICmpEQ(Op1, Op0), Ty); } } @@ -981,12 +759,28 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { return &I; // (X - (X rem Y)) / Y -> X / Y; usually originates as ((X / Y) * Y) / Y - Value *X = nullptr, *Z = nullptr; - if (match(Op0, m_Sub(m_Value(X), m_Value(Z)))) { // (X - Z) / Y; Y = Op1 - bool isSigned = I.getOpcode() == Instruction::SDiv; - if ((isSigned && match(Z, m_SRem(m_Specific(X), m_Specific(Op1)))) || - (!isSigned && match(Z, m_URem(m_Specific(X), m_Specific(Op1))))) + Value *X, *Z; + if (match(Op0, m_Sub(m_Value(X), m_Value(Z)))) // (X - Z) / Y; Y = Op1 + if ((IsSigned && match(Z, m_SRem(m_Specific(X), m_Specific(Op1)))) || + (!IsSigned && match(Z, m_URem(m_Specific(X), m_Specific(Op1))))) return BinaryOperator::Create(I.getOpcode(), X, Op1); + + // (X << Y) / X -> 1 << Y + Value *Y; + if (IsSigned && match(Op0, m_NSWShl(m_Specific(Op1), m_Value(Y)))) + return BinaryOperator::CreateNSWShl(ConstantInt::get(Ty, 1), Y); + if (!IsSigned && match(Op0, m_NUWShl(m_Specific(Op1), m_Value(Y)))) + return BinaryOperator::CreateNUWShl(ConstantInt::get(Ty, 1), Y); + + // X / (X * Y) -> 1 / Y if the multiplication does not overflow. + if (match(Op1, m_c_Mul(m_Specific(Op0), m_Value(Y)))) { + bool HasNSW = cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap(); + bool HasNUW = cast<OverflowingBinaryOperator>(Op1)->hasNoUnsignedWrap(); + if ((IsSigned && HasNSW) || (!IsSigned && HasNUW)) { + I.setOperand(0, ConstantInt::get(Ty, 1)); + I.setOperand(1, Y); + return &I; + } } return nullptr; @@ -1000,7 +794,7 @@ using FoldUDivOperandCb = Instruction *(*)(Value *Op0, Value *Op1, const BinaryOperator &I, InstCombiner &IC); -/// \brief Used to maintain state for visitUDivOperand(). +/// Used to maintain state for visitUDivOperand(). struct UDivFoldAction { /// Informs visitUDiv() how to fold this operand. This can be zero if this /// action joins two actions together. 
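Two of the new divide folds added above can be sanity-checked with ordinary unsigned arithmetic; a small standalone sketch (assumes the no-wrap preconditions stated in the code comments; not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t X = 5, Y = 3;
      assert((X << Y) / X == (1u << Y)); // (X << Y) / X  ==>  1 << Y
      assert(X / (X * Y) == 1u / Y);     // X / (X * Y)   ==>  1 / Y
      return 0;
    }
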
@@ -1028,23 +822,15 @@ struct UDivFoldAction { // X udiv 2^C -> X >> C static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1, const BinaryOperator &I, InstCombiner &IC) { - const APInt &C = cast<Constant>(Op1)->getUniqueInteger(); - BinaryOperator *LShr = BinaryOperator::CreateLShr( - Op0, ConstantInt::get(Op0->getType(), C.logBase2())); + Constant *C1 = getLogBase2(Op0->getType(), cast<Constant>(Op1)); + if (!C1) + llvm_unreachable("Failed to constant fold udiv -> logbase2"); + BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, C1); if (I.isExact()) LShr->setIsExact(); return LShr; } -// X udiv C, where C >= signbit -static Instruction *foldUDivNegCst(Value *Op0, Value *Op1, - const BinaryOperator &I, InstCombiner &IC) { - Value *ICI = IC.Builder.CreateICmpULT(Op0, cast<ConstantInt>(Op1)); - - return SelectInst::Create(ICI, Constant::getNullValue(I.getType()), - ConstantInt::get(I.getType(), 1)); -} - // X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2) // X udiv (zext (C1 << N)), where C1 is "1<<C2" --> X >> (N+C2) static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, @@ -1053,12 +839,14 @@ static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, if (!match(Op1, m_ZExt(m_Value(ShiftLeft)))) ShiftLeft = Op1; - const APInt *CI; + Constant *CI; Value *N; - if (!match(ShiftLeft, m_Shl(m_APInt(CI), m_Value(N)))) + if (!match(ShiftLeft, m_Shl(m_Constant(CI), m_Value(N)))) llvm_unreachable("match should never fail here!"); - if (*CI != 1) - N = IC.Builder.CreateAdd(N, ConstantInt::get(N->getType(), CI->logBase2())); + Constant *Log2Base = getLogBase2(N->getType(), CI); + if (!Log2Base) + llvm_unreachable("getLogBase2 should never fail here!"); + N = IC.Builder.CreateAdd(N, Log2Base); if (Op1 != ShiftLeft) N = IC.Builder.CreateZExt(N, Op1->getType()); BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, N); @@ -1067,7 +855,7 @@ static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, return LShr; } -// \brief Recursively visits the possible right hand operands of a udiv +// Recursively visits the possible right hand operands of a udiv // instruction, seeing through select instructions, to determine if we can // replace the udiv with something simpler. If we find that an operand is not // able to simplify the udiv, we abort the entire transformation. 
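A scalar model of the two shift-based udiv folds handled by foldUDivPow2Cst and foldUDivShl (sketch with example constants; not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t X = 1000, N = 2;
      assert(X / 16 == X >> 4);              // X udiv 2^C       ==>  X lshr C
      assert(X / (8u << N) == X >> (N + 3)); // X udiv (8 << N)  ==>  X lshr (N+3)
      return 0;
    }
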
@@ -1081,13 +869,6 @@ static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I, return Actions.size(); } - if (ConstantInt *C = dyn_cast<ConstantInt>(Op1)) - // X udiv C, where C >= signbit - if (C->getValue().isNegative()) { - Actions.push_back(UDivFoldAction(foldUDivNegCst, C)); - return Actions.size(); - } - // X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2) if (match(Op1, m_Shl(m_Power2(), m_Value())) || match(Op1, m_ZExt(m_Shl(m_Power2(), m_Value())))) { @@ -1148,40 +929,65 @@ static Instruction *narrowUDivURem(BinaryOperator &I, } Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyUDivInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyUDivInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; // Handle the integer div common cases if (Instruction *Common = commonIDivTransforms(I)) return Common; - // (x lshr C1) udiv C2 --> x udiv (C2 << C1) - { - Value *X; - const APInt *C1, *C2; - if (match(Op0, m_LShr(m_Value(X), m_APInt(C1))) && - match(Op1, m_APInt(C2))) { - bool Overflow; - APInt C2ShlC1 = C2->ushl_ov(*C1, Overflow); - if (!Overflow) { - bool IsExact = I.isExact() && match(Op0, m_Exact(m_Value())); - BinaryOperator *BO = BinaryOperator::CreateUDiv( - X, ConstantInt::get(X->getType(), C2ShlC1)); - if (IsExact) - BO->setIsExact(); - return BO; - } + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *X; + const APInt *C1, *C2; + if (match(Op0, m_LShr(m_Value(X), m_APInt(C1))) && match(Op1, m_APInt(C2))) { + // (X lshr C1) udiv C2 --> X udiv (C2 << C1) + bool Overflow; + APInt C2ShlC1 = C2->ushl_ov(*C1, Overflow); + if (!Overflow) { + bool IsExact = I.isExact() && match(Op0, m_Exact(m_Value())); + BinaryOperator *BO = BinaryOperator::CreateUDiv( + X, ConstantInt::get(X->getType(), C2ShlC1)); + if (IsExact) + BO->setIsExact(); + return BO; } } + // Op0 / C where C is large (negative) --> zext (Op0 >= C) + // TODO: Could use isKnownNegative() to handle non-constant values. + Type *Ty = I.getType(); + if (match(Op1, m_Negative())) { + Value *Cmp = Builder.CreateICmpUGE(Op0, Op1); + return CastInst::CreateZExtOrBitCast(Cmp, Ty); + } + // Op0 / (sext i1 X) --> zext (Op0 == -1) (if X is 0, the div is undefined) + if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { + Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::getAllOnesValue(Ty)); + return CastInst::CreateZExtOrBitCast(Cmp, Ty); + } + if (Instruction *NarrowDiv = narrowUDivURem(I, Builder)) return NarrowDiv; + // If the udiv operands are non-overflowing multiplies with a common operand, + // then eliminate the common factor: + // (A * B) / (A * X) --> B / X (and commuted variants) + // TODO: The code would be reduced if we had m_c_NUWMul pattern matching. + // TODO: If -reassociation handled this generally, we could remove this. 
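The no-unsigned-wrap requirement on the common-factor fold implemented below is easy to motivate with a wrapped 8-bit counterexample (standalone sketch, not part of the patch):

    #include <cstdint>
    #include <cstdio>

    int main() {
      uint8_t A = 16, B = 17, X = 1;
      // Without nuw, A*B wraps: 16*17 == 272 ==> 16 (mod 256), so the original
      // quotient is 16/16 == 1, while the "folded" B/X would be 17.
      uint8_t Orig = (uint8_t)(A * B) / (uint8_t)(A * X);
      uint8_t Folded = B / X;
      printf("orig=%u folded=%u\n", Orig, Folded); // 1 vs 17: fold is unsound here
      return 0;
    }
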
+ Value *A, *B; + if (match(Op0, m_NUWMul(m_Value(A), m_Value(B)))) { + if (match(Op1, m_NUWMul(m_Specific(A), m_Value(X))) || + match(Op1, m_NUWMul(m_Value(X), m_Specific(A)))) + return BinaryOperator::CreateUDiv(B, X); + if (match(Op1, m_NUWMul(m_Specific(B), m_Value(X))) || + match(Op1, m_NUWMul(m_Value(X), m_Specific(B)))) + return BinaryOperator::CreateUDiv(A, X); + } + // (LHS udiv (select (select (...)))) -> (LHS >> (select (select (...)))) SmallVector<UDivFoldAction, 6> UDivActions; if (visitUDivOperand(Op0, Op1, I, UDivActions)) @@ -1217,24 +1023,27 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { } Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifySDivInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifySDivInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; // Handle the integer div common cases if (Instruction *Common = commonIDivTransforms(I)) return Common; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *X; + // sdiv Op0, -1 --> -Op0 + // sdiv Op0, (sext i1 X) --> -Op0 (because if X is 0, the op is undefined) + if (match(Op1, m_AllOnes()) || + (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))) + return BinaryOperator::CreateNeg(Op0); + const APInt *Op1C; if (match(Op1, m_APInt(Op1C))) { - // sdiv X, -1 == -X - if (Op1C->isAllOnesValue()) - return BinaryOperator::CreateNeg(Op0); - // sdiv exact X, C --> ashr exact X, log2(C) if (I.isExact() && Op1C->isNonNegative() && Op1C->isPowerOf2()) { Value *ShAmt = ConstantInt::get(Op1->getType(), Op1C->exactLogBase2()); @@ -1298,166 +1107,148 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { return nullptr; } -/// CvtFDivConstToReciprocal tries to convert X/C into X*1/C if C not a special -/// FP value and: -/// 1) 1/C is exact, or -/// 2) reciprocal is allowed. -/// If the conversion was successful, the simplified expression "X * 1/C" is -/// returned; otherwise, nullptr is returned. -static Instruction *CvtFDivConstToReciprocal(Value *Dividend, Constant *Divisor, - bool AllowReciprocal) { - if (!isa<ConstantFP>(Divisor)) // TODO: handle vectors. +/// Remove negation and try to convert division into multiplication. +static Instruction *foldFDivConstantDivisor(BinaryOperator &I) { + Constant *C; + if (!match(I.getOperand(1), m_Constant(C))) return nullptr; - const APFloat &FpVal = cast<ConstantFP>(Divisor)->getValueAPF(); - APFloat Reciprocal(FpVal.getSemantics()); - bool Cvt = FpVal.getExactInverse(&Reciprocal); + // -X / C --> X / -C + Value *X; + if (match(I.getOperand(0), m_FNeg(m_Value(X)))) + return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); - if (!Cvt && AllowReciprocal && FpVal.isFiniteNonZero()) { - Reciprocal = APFloat(FpVal.getSemantics(), 1.0f); - (void)Reciprocal.divide(FpVal, APFloat::rmNearestTiesToEven); - Cvt = !Reciprocal.isDenormal(); - } + // If the constant divisor has an exact inverse, this is always safe. If not, + // then we can still create a reciprocal if fast-math-flags allow it and the + // constant is a regular number (not zero, infinite, or denormal). 
+ if (!(C->hasExactInverseFP() || (I.hasAllowReciprocal() && C->isNormalFP()))) + return nullptr; - if (!Cvt) + // Disallow denormal constants because we don't know what would happen + // on all targets. + // TODO: Use Intrinsic::canonicalize or let function attributes tell us that + // denorms are flushed? + auto *RecipC = ConstantExpr::getFDiv(ConstantFP::get(I.getType(), 1.0), C); + if (!RecipC->isNormalFP()) return nullptr; - ConstantFP *R; - R = ConstantFP::get(Dividend->getType()->getContext(), Reciprocal); - return BinaryOperator::CreateFMul(Dividend, R); + // X / C --> X * (1 / C) + return BinaryOperator::CreateFMulFMF(I.getOperand(0), RecipC, &I); } -Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); +/// Remove negation and try to reassociate constant math. +static Instruction *foldFDivConstantDividend(BinaryOperator &I) { + Constant *C; + if (!match(I.getOperand(0), m_Constant(C))) + return nullptr; - if (Value *V = SimplifyVectorOp(I)) - return replaceInstUsesWith(I, V); + // C / -X --> -C / X + Value *X; + if (match(I.getOperand(1), m_FNeg(m_Value(X)))) + return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + + if (!I.hasAllowReassoc() || !I.hasAllowReciprocal()) + return nullptr; + + // Try to reassociate C / X expressions where X includes another constant. + Constant *C2, *NewC = nullptr; + if (match(I.getOperand(1), m_FMul(m_Value(X), m_Constant(C2)))) { + // C / (X * C2) --> (C / C2) / X + NewC = ConstantExpr::getFDiv(C, C2); + } else if (match(I.getOperand(1), m_FDiv(m_Value(X), m_Constant(C2)))) { + // C / (X / C2) --> (C * C2) / X + NewC = ConstantExpr::getFMul(C, C2); + } + // Disallow denormal constants because we don't know what would happen + // on all targets. + // TODO: Use Intrinsic::canonicalize or let function attributes tell us that + // denorms are flushed? 
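The exact-inverse criterion used by these fdiv folds can be seen with a couple of constants: powers of two invert exactly in binary floating point, while most other values do not (standalone sketch, not part of the patch):

    #include <cstdio>

    int main() {
      double X = 10.0;
      // 1/4 is exact, so X / 4.0 == X * 0.25 for every X: always safe.
      printf("%.17g %.17g\n", X / 4.0, X * 0.25);
      // 1/3 rounds, so X * (1.0 / 3.0) may differ from X / 3.0 in the last
      // ULP; this rewrite needs the allow-reciprocal fast-math flag.
      printf("%.17g %.17g\n", X / 3.0, X * (1.0 / 3.0));
      return 0;
    }
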
+ if (!NewC || !NewC->isNormalFP()) + return nullptr; + + return BinaryOperator::CreateFDivFMF(NewC, X, &I); +} - if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(), +Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { + if (Value *V = SimplifyFDivInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; + + if (Instruction *R = foldFDivConstantDivisor(I)) + return R; + + if (Instruction *R = foldFDivConstantDividend(I)) + return R; + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (isa<Constant>(Op0)) if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; - bool AllowReassociate = I.isFast(); - bool AllowReciprocal = I.hasAllowReciprocal(); - - if (Constant *Op1C = dyn_cast<Constant>(Op1)) { + if (isa<Constant>(Op1)) if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; - if (AllowReassociate) { - Constant *C1 = nullptr; - Constant *C2 = Op1C; - Value *X; - Instruction *Res = nullptr; - - if (match(Op0, m_FMul(m_Value(X), m_Constant(C1)))) { - // (X*C1)/C2 => X * (C1/C2) - // - Constant *C = ConstantExpr::getFDiv(C1, C2); - if (isNormalFp(C)) - Res = BinaryOperator::CreateFMul(X, C); - } else if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) { - // (X/C1)/C2 => X /(C2*C1) [=> X * 1/(C2*C1) if reciprocal is allowed] - Constant *C = ConstantExpr::getFMul(C1, C2); - if (isNormalFp(C)) { - Res = CvtFDivConstToReciprocal(X, C, AllowReciprocal); - if (!Res) - Res = BinaryOperator::CreateFDiv(X, C); - } - } - - if (Res) { - Res->setFastMathFlags(I.getFastMathFlags()); - return Res; - } + if (I.hasAllowReassoc() && I.hasAllowReciprocal()) { + Value *X, *Y; + if (match(Op0, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))) && + (!isa<Constant>(Y) || !isa<Constant>(Op1))) { + // (X / Y) / Z => X / (Y * Z) + Value *YZ = Builder.CreateFMulFMF(Y, Op1, &I); + return BinaryOperator::CreateFDivFMF(X, YZ, &I); } - - // X / C => X * 1/C - if (Instruction *T = CvtFDivConstToReciprocal(Op0, Op1C, AllowReciprocal)) { - T->copyFastMathFlags(&I); - return T; + if (match(Op1, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))) && + (!isa<Constant>(Y) || !isa<Constant>(Op0))) { + // Z / (X / Y) => (Y * Z) / X + Value *YZ = Builder.CreateFMulFMF(Y, Op0, &I); + return BinaryOperator::CreateFDivFMF(YZ, X, &I); } - - return nullptr; } - if (AllowReassociate && isa<Constant>(Op0)) { - Constant *C1 = cast<Constant>(Op0), *C2; - Constant *Fold = nullptr; + if (I.hasAllowReassoc() && Op0->hasOneUse() && Op1->hasOneUse()) { + // sin(X) / cos(X) -> tan(X) + // cos(X) / sin(X) -> 1/tan(X) (cotangent) Value *X; - bool CreateDiv = true; - - // C1 / (X*C2) => (C1/C2) / X - if (match(Op1, m_FMul(m_Value(X), m_Constant(C2)))) - Fold = ConstantExpr::getFDiv(C1, C2); - else if (match(Op1, m_FDiv(m_Value(X), m_Constant(C2)))) { - // C1 / (X/C2) => (C1*C2) / X - Fold = ConstantExpr::getFMul(C1, C2); - } else if (match(Op1, m_FDiv(m_Constant(C2), m_Value(X)))) { - // C1 / (C2/X) => (C1/C2) * X - Fold = ConstantExpr::getFDiv(C1, C2); - CreateDiv = false; - } - - if (Fold && isNormalFp(Fold)) { - Instruction *R = CreateDiv ? 
BinaryOperator::CreateFDiv(Fold, X) - : BinaryOperator::CreateFMul(X, Fold); - R->setFastMathFlags(I.getFastMathFlags()); - return R; + bool IsTan = match(Op0, m_Intrinsic<Intrinsic::sin>(m_Value(X))) && + match(Op1, m_Intrinsic<Intrinsic::cos>(m_Specific(X))); + bool IsCot = + !IsTan && match(Op0, m_Intrinsic<Intrinsic::cos>(m_Value(X))) && + match(Op1, m_Intrinsic<Intrinsic::sin>(m_Specific(X))); + + if ((IsTan || IsCot) && hasUnaryFloatFn(&TLI, I.getType(), LibFunc_tan, + LibFunc_tanf, LibFunc_tanl)) { + IRBuilder<> B(&I); + IRBuilder<>::FastMathFlagGuard FMFGuard(B); + B.setFastMathFlags(I.getFastMathFlags()); + AttributeList Attrs = CallSite(Op0).getCalledFunction()->getAttributes(); + Value *Res = emitUnaryFloatFnCall(X, TLI.getName(LibFunc_tan), B, Attrs); + if (IsCot) + Res = B.CreateFDiv(ConstantFP::get(I.getType(), 1.0), Res); + return replaceInstUsesWith(I, Res); } - return nullptr; } - if (AllowReassociate) { - Value *X, *Y; - Value *NewInst = nullptr; - Instruction *SimpR = nullptr; - - if (Op0->hasOneUse() && match(Op0, m_FDiv(m_Value(X), m_Value(Y)))) { - // (X/Y) / Z => X / (Y*Z) - if (!isa<Constant>(Y) || !isa<Constant>(Op1)) { - NewInst = Builder.CreateFMul(Y, Op1); - if (Instruction *RI = dyn_cast<Instruction>(NewInst)) { - FastMathFlags Flags = I.getFastMathFlags(); - Flags &= cast<Instruction>(Op0)->getFastMathFlags(); - RI->setFastMathFlags(Flags); - } - SimpR = BinaryOperator::CreateFDiv(X, NewInst); - } - } else if (Op1->hasOneUse() && match(Op1, m_FDiv(m_Value(X), m_Value(Y)))) { - // Z / (X/Y) => Z*Y / X - if (!isa<Constant>(Y) || !isa<Constant>(Op0)) { - NewInst = Builder.CreateFMul(Op0, Y); - if (Instruction *RI = dyn_cast<Instruction>(NewInst)) { - FastMathFlags Flags = I.getFastMathFlags(); - Flags &= cast<Instruction>(Op1)->getFastMathFlags(); - RI->setFastMathFlags(Flags); - } - SimpR = BinaryOperator::CreateFDiv(NewInst, X); - } - } - - if (NewInst) { - if (Instruction *T = dyn_cast<Instruction>(NewInst)) - T->setDebugLoc(I.getDebugLoc()); - SimpR->setFastMathFlags(I.getFastMathFlags()); - return SimpR; - } + // -X / -Y -> X / Y + Value *X, *Y; + if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) { + I.setOperand(0, X); + I.setOperand(1, Y); + return &I; } - Value *LHS; - Value *RHS; - - // -x / -y -> x / y - if (match(Op0, m_FNeg(m_Value(LHS))) && match(Op1, m_FNeg(m_Value(RHS)))) { - I.setOperand(0, LHS); - I.setOperand(1, RHS); + // X / (X * Y) --> 1.0 / Y + // Reassociate to (X / X -> 1.0) is legal when NaNs are not allowed. + // We can ignore the possibility that X is infinity because INF/INF is NaN. + if (I.hasNoNaNs() && I.hasAllowReassoc() && + match(Op1, m_c_FMul(m_Specific(Op0), m_Value(Y)))) { + I.setOperand(0, ConstantFP::get(I.getType(), 1.0)); + I.setOperand(1, Y); return &I; } @@ -1467,7 +1258,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { /// This function implements the transforms common to both integer remainder /// instructions (urem and srem). It is called by the visitors to those integer /// remainder instructions. 
-/// @brief Common integer remainder transforms +/// Common integer remainder transforms Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -1509,13 +1300,12 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { } Instruction *InstCombiner::visitURem(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyURemInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyURemInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Instruction *common = commonIRemTransforms(I)) return common; @@ -1524,47 +1314,55 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { return NarrowRem; // X urem Y -> X and Y-1, where Y is a power of 2, + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = I.getType(); if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) { - Constant *N1 = Constant::getAllOnesValue(I.getType()); + Constant *N1 = Constant::getAllOnesValue(Ty); Value *Add = Builder.CreateAdd(Op1, N1); return BinaryOperator::CreateAnd(Op0, Add); } // 1 urem X -> zext(X != 1) - if (match(Op0, m_One())) { - Value *Cmp = Builder.CreateICmpNE(Op1, Op0); - Value *Ext = Builder.CreateZExt(Cmp, I.getType()); - return replaceInstUsesWith(I, Ext); - } + if (match(Op0, m_One())) + return CastInst::CreateZExtOrBitCast(Builder.CreateICmpNE(Op1, Op0), Ty); // X urem C -> X < C ? X : X - C, where C >= signbit. - const APInt *DivisorC; - if (match(Op1, m_APInt(DivisorC)) && DivisorC->isNegative()) { + if (match(Op1, m_Negative())) { Value *Cmp = Builder.CreateICmpULT(Op0, Op1); Value *Sub = Builder.CreateSub(Op0, Op1); return SelectInst::Create(Cmp, Op0, Sub); } + // If the divisor is a sext of a boolean, then the divisor must be max + // unsigned value (-1). Therefore, the remainder is Op0 unless Op0 is also + // max unsigned value. In that case, the remainder is 0: + // urem Op0, (sext i1 X) --> (Op0 == -1) ? 
0 : Op0 + Value *X; + if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { + Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::getAllOnesValue(Ty)); + return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Op0); + } + return nullptr; } Instruction *InstCombiner::visitSRem(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifySRemInst(I.getOperand(0), I.getOperand(1), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Value *V = SimplifySRemInst(Op0, Op1, SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; // Handle the integer rem common cases if (Instruction *Common = commonIRemTransforms(I)) return Common; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); { const APInt *Y; // X % -Y -> X % Y - if (match(Op1, m_APInt(Y)) && Y->isNegative() && !Y->isMinSignedValue()) { + if (match(Op1, m_Negative(Y)) && !Y->isMinSignedValue()) { Worklist.AddValue(I.getOperand(1)); I.setOperand(1, ConstantInt::get(I.getType(), -*Y)); return &I; @@ -1622,14 +1420,13 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { } Instruction *InstCombiner::visitFRem(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = SimplifyVectorOp(I)) - return replaceInstUsesWith(I, V); - - if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(), + if (Value *V = SimplifyFRemInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; + return nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index 7ee018dbc49b..e54a1dd05a24 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -15,14 +15,18 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Transforms/Utils/Local.h" using namespace llvm; using namespace llvm::PatternMatch; #define DEBUG_TYPE "instcombine" +static cl::opt<unsigned> +MaxNumPhis("instcombine-max-num-phis", cl::init(512), + cl::desc("Maximum number phis to handle in intptr/ptrint folding")); + /// The PHI arguments will be folded into a single operation with a PHI node /// as input. The debug location of the single operation will be the merged /// locations of the original PHI node arguments. @@ -176,8 +180,12 @@ Instruction *InstCombiner::FoldIntegerTypedPHI(PHINode &PN) { assert(AvailablePtrVals.size() == PN.getNumIncomingValues() && "Not enough available ptr typed incoming values"); PHINode *MatchingPtrPHI = nullptr; + unsigned NumPhis = 0; for (auto II = BB->begin(), EI = BasicBlock::iterator(BB->getFirstNonPHI()); - II != EI; II++) { + II != EI; II++, NumPhis++) { + // FIXME: consider handling this in AggressiveInstCombine + if (NumPhis > MaxNumPhis) + return nullptr; PHINode *PtrPHI = dyn_cast<PHINode>(II); if (!PtrPHI || PtrPHI == &PN || PtrPHI->getType() != IntToPtr->getType()) continue; @@ -1008,10 +1016,9 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // extracted out of it. First, sort the users by their offset and size. 
array_pod_sort(PHIUsers.begin(), PHIUsers.end()); - DEBUG(dbgs() << "SLICING UP PHI: " << FirstPhi << '\n'; - for (unsigned i = 1, e = PHIsToSlice.size(); i != e; ++i) - dbgs() << "AND USER PHI #" << i << ": " << *PHIsToSlice[i] << '\n'; - ); + LLVM_DEBUG(dbgs() << "SLICING UP PHI: " << FirstPhi << '\n'; + for (unsigned i = 1, e = PHIsToSlice.size(); i != e; ++i) dbgs() + << "AND USER PHI #" << i << ": " << *PHIsToSlice[i] << '\n';); // PredValues - This is a temporary used when rewriting PHI nodes. It is // hoisted out here to avoid construction/destruction thrashing. @@ -1092,8 +1099,8 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { } PredValues.clear(); - DEBUG(dbgs() << " Made element PHI for offset " << Offset << ": " - << *EltPHI << '\n'); + LLVM_DEBUG(dbgs() << " Made element PHI for offset " << Offset << ": " + << *EltPHI << '\n'); ExtractedVals[LoweredPHIRecord(PN, Offset, Ty)] = EltPHI; } diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 6f26f7f5cd19..4867808478a3 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -47,93 +47,51 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" -static SelectPatternFlavor -getInverseMinMaxSelectPattern(SelectPatternFlavor SPF) { - switch (SPF) { - default: - llvm_unreachable("unhandled!"); - - case SPF_SMIN: - return SPF_SMAX; - case SPF_UMIN: - return SPF_UMAX; - case SPF_SMAX: - return SPF_SMIN; - case SPF_UMAX: - return SPF_UMIN; - } -} - -static CmpInst::Predicate getCmpPredicateForMinMax(SelectPatternFlavor SPF, - bool Ordered=false) { - switch (SPF) { - default: - llvm_unreachable("unhandled!"); - - case SPF_SMIN: - return ICmpInst::ICMP_SLT; - case SPF_UMIN: - return ICmpInst::ICMP_ULT; - case SPF_SMAX: - return ICmpInst::ICMP_SGT; - case SPF_UMAX: - return ICmpInst::ICMP_UGT; - case SPF_FMINNUM: - return Ordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT; - case SPF_FMAXNUM: - return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT; - } -} - -static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy &Builder, - SelectPatternFlavor SPF, Value *A, - Value *B) { - CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF); - assert(CmpInst::isIntPredicate(Pred)); +static Value *createMinMax(InstCombiner::BuilderTy &Builder, + SelectPatternFlavor SPF, Value *A, Value *B) { + CmpInst::Predicate Pred = getMinMaxPred(SPF); + assert(CmpInst::isIntPredicate(Pred) && "Expected integer predicate"); return Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); } -/// If one of the constants is zero (we know they can't both be) and we have an -/// icmp instruction with zero, and we have an 'and' with the non-constant value -/// and a power of two we can turn the select into a shift on the result of the -/// 'and'. /// This folds: -/// select (icmp eq (and X, C1)), C2, C3 -/// iff C1 is a power 2 and the difference between C2 and C3 is a power of 2. +/// select (icmp eq (and X, C1)), TC, FC +/// iff C1 is a power 2 and the difference between TC and FC is a power-of-2. 
/// To something like: -/// (shr (and (X, C1)), (log2(C1) - log2(C2-C3))) + C3 +/// (shr (and (X, C1)), (log2(C1) - log2(TC-FC))) + FC /// Or: -/// (shl (and (X, C1)), (log2(C2-C3) - log2(C1))) + C3 -/// With some variations depending if C3 is larger than C2, or the shift +/// (shl (and (X, C1)), (log2(TC-FC) - log2(C1))) + FC +/// With some variations depending if FC is larger than TC, or the shift /// isn't needed, or the bit widths don't match. -static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, - APInt TrueVal, APInt FalseVal, +static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, InstCombiner::BuilderTy &Builder) { - assert(SelType->isIntOrIntVectorTy() && "Not an integer select?"); + const APInt *SelTC, *SelFC; + if (!match(Sel.getTrueValue(), m_APInt(SelTC)) || + !match(Sel.getFalseValue(), m_APInt(SelFC))) + return nullptr; // If this is a vector select, we need a vector compare. - if (SelType->isVectorTy() != IC->getType()->isVectorTy()) + Type *SelType = Sel.getType(); + if (SelType->isVectorTy() != Cmp->getType()->isVectorTy()) return nullptr; Value *V; APInt AndMask; bool CreateAnd = false; - ICmpInst::Predicate Pred = IC->getPredicate(); + ICmpInst::Predicate Pred = Cmp->getPredicate(); if (ICmpInst::isEquality(Pred)) { - if (!match(IC->getOperand(1), m_Zero())) + if (!match(Cmp->getOperand(1), m_Zero())) return nullptr; - V = IC->getOperand(0); - + V = Cmp->getOperand(0); const APInt *AndRHS; if (!match(V, m_And(m_Value(), m_Power2(AndRHS)))) return nullptr; AndMask = *AndRHS; - } else if (decomposeBitTestICmp(IC->getOperand(0), IC->getOperand(1), + } else if (decomposeBitTestICmp(Cmp->getOperand(0), Cmp->getOperand(1), Pred, V, AndMask)) { assert(ICmpInst::isEquality(Pred) && "Not equality test?"); - if (!AndMask.isPowerOf2()) return nullptr; @@ -142,39 +100,58 @@ static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, return nullptr; } - // If both select arms are non-zero see if we have a select of the form - // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic - // for 'x ? 2^n : 0' and fix the thing up at the end. - APInt Offset(TrueVal.getBitWidth(), 0); - if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) { - if ((TrueVal - FalseVal).isPowerOf2()) - Offset = FalseVal; - else if ((FalseVal - TrueVal).isPowerOf2()) - Offset = TrueVal; - else + // In general, when both constants are non-zero, we would need an offset to + // replace the select. This would require more instructions than we started + // with. But there's one special-case that we handle here because it can + // simplify/reduce the instructions. + APInt TC = *SelTC; + APInt FC = *SelFC; + if (!TC.isNullValue() && !FC.isNullValue()) { + // If the select constants differ by exactly one bit and that's the same + // bit that is masked and checked by the select condition, the select can + // be replaced by bitwise logic to set/clear one bit of the constant result. + if (TC.getBitWidth() != AndMask.getBitWidth() || (TC ^ FC) != AndMask) return nullptr; - - // Adjust TrueVal and FalseVal to the offset. - TrueVal -= Offset; - FalseVal -= Offset; + if (CreateAnd) { + // If we have to create an 'and', then we must kill the cmp to not + // increase the instruction count. + if (!Cmp->hasOneUse()) + return nullptr; + V = Builder.CreateAnd(V, ConstantInt::get(SelType, AndMask)); + } + bool ExtraBitInTC = TC.ugt(FC); + if (Pred == ICmpInst::ICMP_EQ) { + // If the masked bit in V is clear, clear or set the bit in the result: + // (V & AndMaskC) == 0 ? 
TC : FC --> (V & AndMaskC) ^ TC + // (V & AndMaskC) == 0 ? TC : FC --> (V & AndMaskC) | TC + Constant *C = ConstantInt::get(SelType, TC); + return ExtraBitInTC ? Builder.CreateXor(V, C) : Builder.CreateOr(V, C); + } + if (Pred == ICmpInst::ICMP_NE) { + // If the masked bit in V is set, set or clear the bit in the result: + // (V & AndMaskC) != 0 ? TC : FC --> (V & AndMaskC) | FC + // (V & AndMaskC) != 0 ? TC : FC --> (V & AndMaskC) ^ FC + Constant *C = ConstantInt::get(SelType, FC); + return ExtraBitInTC ? Builder.CreateOr(V, C) : Builder.CreateXor(V, C); + } + llvm_unreachable("Only expecting equality predicates"); } - // Make sure one of the select arms is a power of 2. - if (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2()) + // Make sure one of the select arms is a power-of-2. + if (!TC.isPowerOf2() && !FC.isPowerOf2()) return nullptr; // Determine which shift is needed to transform result of the 'and' into the // desired result. - const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal; + const APInt &ValC = !TC.isNullValue() ? TC : FC; unsigned ValZeros = ValC.logBase2(); unsigned AndZeros = AndMask.logBase2(); - if (CreateAnd) { - // Insert the AND instruction on the input to the truncate. + // Insert the 'and' instruction on the input to the truncate. + if (CreateAnd) V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask)); - } - // If types don't match we can still convert the select by introducing a zext + // If types don't match, we can still convert the select by introducing a zext // or a trunc of the 'and'. if (ValZeros > AndZeros) { V = Builder.CreateZExtOrTrunc(V, SelType); @@ -182,19 +159,17 @@ static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, } else if (ValZeros < AndZeros) { V = Builder.CreateLShr(V, AndZeros - ValZeros); V = Builder.CreateZExtOrTrunc(V, SelType); - } else + } else { V = Builder.CreateZExtOrTrunc(V, SelType); + } // Okay, now we know that everything is set up, we just don't know whether we // have a icmp_ne or icmp_eq and whether the true or false val is the zero. - bool ShouldNotVal = !TrueVal.isNullValue(); + bool ShouldNotVal = !TC.isNullValue(); ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; if (ShouldNotVal) V = Builder.CreateXor(V, ValC); - // Apply an offset if needed. - if (!Offset.isNullValue()) - V = Builder.CreateAdd(V, ConstantInt::get(V->getType(), Offset)); return V; } @@ -300,12 +275,13 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, TI->getType()); } - // Only handle binary operators with one-use here. As with the cast case - // above, it may be possible to relax the one-use constraint, but that needs - // be examined carefully since it may not reduce the total number of - // instructions. - BinaryOperator *BO = dyn_cast<BinaryOperator>(TI); - if (!BO || !TI->hasOneUse() || !FI->hasOneUse()) + // Only handle binary operators (including two-operand getelementptr) with + // one-use here. As with the cast case above, it may be possible to relax the + // one-use constraint, but that needs be examined carefully since it may not + // reduce the total number of instructions. + if (TI->getNumOperands() != 2 || FI->getNumOperands() != 2 || + (!isa<BinaryOperator>(TI) && !isa<GetElementPtrInst>(TI)) || + !TI->hasOneUse() || !FI->hasOneUse()) return nullptr; // Figure out if the operations have any operands in common. @@ -342,7 +318,18 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, SI.getName() + ".v", &SI); Value *Op0 = MatchIsOpZero ? 
MatchOp : NewSI; Value *Op1 = MatchIsOpZero ? NewSI : MatchOp; - return BinaryOperator::Create(BO->getOpcode(), Op0, Op1); + if (auto *BO = dyn_cast<BinaryOperator>(TI)) { + return BinaryOperator::Create(BO->getOpcode(), Op0, Op1); + } + if (auto *TGEP = dyn_cast<GetElementPtrInst>(TI)) { + auto *FGEP = cast<GetElementPtrInst>(FI); + Type *ElementType = TGEP->getResultElementType(); + return TGEP->isInBounds() && FGEP->isInBounds() + ? GetElementPtrInst::CreateInBounds(ElementType, Op0, {Op1}) + : GetElementPtrInst::Create(ElementType, Op0, {Op1}); + } + llvm_unreachable("Expected BinaryOperator or GEP"); + return nullptr; } static bool isSelect01(const APInt &C1I, const APInt &C2I) { @@ -424,6 +411,47 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } /// We want to turn: +/// (select (icmp eq (and X, Y), 0), (and (lshr X, Z), 1), 1) +/// into: +/// zext (icmp ne i32 (and X, (or Y, (shl 1, Z))), 0) +/// Note: +/// Z may be 0 if lshr is missing. +/// Worst-case scenario is that we will replace 5 instructions with 5 different +/// instructions, but we got rid of select. +static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp, + Value *TVal, Value *FVal, + InstCombiner::BuilderTy &Builder) { + if (!(Cmp->hasOneUse() && Cmp->getOperand(0)->hasOneUse() && + Cmp->getPredicate() == ICmpInst::ICMP_EQ && + match(Cmp->getOperand(1), m_Zero()) && match(FVal, m_One()))) + return nullptr; + + // The TrueVal has general form of: and %B, 1 + Value *B; + if (!match(TVal, m_OneUse(m_And(m_Value(B), m_One())))) + return nullptr; + + // Where %B may be optionally shifted: lshr %X, %Z. + Value *X, *Z; + const bool HasShift = match(B, m_OneUse(m_LShr(m_Value(X), m_Value(Z)))); + if (!HasShift) + X = B; + + Value *Y; + if (!match(Cmp->getOperand(0), m_c_And(m_Specific(X), m_Value(Y)))) + return nullptr; + + // ((X & Y) == 0) ? ((X >> Z) & 1) : 1 --> (X & (Y | (1 << Z))) != 0 + // ((X & Y) == 0) ? (X & 1) : 1 --> (X & (Y | 1)) != 0 + Constant *One = ConstantInt::get(SelType, 1); + Value *MaskB = HasShift ? Builder.CreateShl(One, Z) : One; + Value *FullMask = Builder.CreateOr(Y, MaskB); + Value *MaskedX = Builder.CreateAnd(X, FullMask); + Value *ICmpNeZero = Builder.CreateIsNotNull(MaskedX); + return new ZExtInst(ICmpNeZero, SelType); +} + +/// We want to turn: /// (select (icmp eq (and X, C1), 0), Y, (or Y, C2)) /// into: /// (or (shl (and X, C1), C3), Y) @@ -526,6 +554,59 @@ static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, return Builder.CreateOr(V, Y); } +/// Transform patterns such as: (a > b) ? a - b : 0 +/// into: ((a > b) ? a : b) - b) +/// This produces a canonical max pattern that is more easily recognized by the +/// backend and converted into saturated subtraction instructions if those +/// exist. +/// There are 8 commuted/swapped variants of this pattern. +/// TODO: Also support a - UMIN(a,b) patterns. +static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI, + const Value *TrueVal, + const Value *FalseVal, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + if (!ICmpInst::isUnsigned(Pred)) + return nullptr; + + // (b > a) ? 0 : a - b -> (b <= a) ? a - b : 0 + if (match(TrueVal, m_Zero())) { + Pred = ICmpInst::getInversePredicate(Pred); + std::swap(TrueVal, FalseVal); + } + if (!match(FalseVal, m_Zero())) + return nullptr; + + Value *A = ICI->getOperand(0); + Value *B = ICI->getOperand(1); + if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_ULT) { + // (b < a) ? 
a - b : 0 -> (a > b) ? a - b : 0 + std::swap(A, B); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + assert((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) && + "Unexpected isUnsigned predicate!"); + + // Account for swapped form of subtraction: ((a > b) ? b - a : 0). + bool IsNegative = false; + if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A)))) + IsNegative = true; + else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B)))) + return nullptr; + + // If sub is used anywhere else, we wouldn't be able to eliminate it + // afterwards. + if (!TrueVal->hasOneUse()) + return nullptr; + + // All checks passed, convert to canonical unsigned saturated subtraction + // form: sub(max()). + // (a > b) ? a - b : 0 -> ((a > b) ? a : b) - b) + Value *Max = Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); + return IsNegative ? Builder.CreateSub(B, Max) : Builder.CreateSub(Max, B); +} + /// Attempt to fold a cttz/ctlz followed by a icmp plus select into a single /// call to cttz/ctlz with flag 'is_zero_undef' cleared. /// @@ -687,23 +768,18 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, // Canonicalize the compare predicate based on whether we have min or max. Value *LHS, *RHS; - ICmpInst::Predicate NewPred; SelectPatternResult SPR = matchSelectPattern(&Sel, LHS, RHS); - switch (SPR.Flavor) { - case SPF_SMIN: NewPred = ICmpInst::ICMP_SLT; break; - case SPF_UMIN: NewPred = ICmpInst::ICMP_ULT; break; - case SPF_SMAX: NewPred = ICmpInst::ICMP_SGT; break; - case SPF_UMAX: NewPred = ICmpInst::ICMP_UGT; break; - default: return nullptr; - } + if (!SelectPatternResult::isMinOrMax(SPR.Flavor)) + return nullptr; // Is this already canonical? + ICmpInst::Predicate CanonicalPred = getMinMaxPred(SPR.Flavor); if (Cmp.getOperand(0) == LHS && Cmp.getOperand(1) == RHS && - Cmp.getPredicate() == NewPred) + Cmp.getPredicate() == CanonicalPred) return nullptr; // Create the canonical compare and plug it into the select. - Sel.setCondition(Builder.CreateICmp(NewPred, LHS, RHS)); + Sel.setCondition(Builder.CreateICmp(CanonicalPred, LHS, RHS)); // If the select operands did not change, we're done. if (Sel.getTrueValue() == LHS && Sel.getFalseValue() == RHS) @@ -718,6 +794,89 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, return &Sel; } +/// There are many select variants for each of ABS/NABS. +/// In matchSelectPattern(), there are different compare constants, compare +/// predicates/operands and select operands. +/// In isKnownNegation(), there are different formats of negated operands. +/// Canonicalize all these variants to 1 pattern. +/// This makes CSE more likely. +static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp, + InstCombiner::BuilderTy &Builder) { + if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1))) + return nullptr; + + // Choose a sign-bit check for the compare (likely simpler for codegen). + // ABS: (X <s 0) ? -X : X + // NABS: (X <s 0) ? X : -X + Value *LHS, *RHS; + SelectPatternFlavor SPF = matchSelectPattern(&Sel, LHS, RHS).Flavor; + if (SPF != SelectPatternFlavor::SPF_ABS && + SPF != SelectPatternFlavor::SPF_NABS) + return nullptr; + + Value *TVal = Sel.getTrueValue(); + Value *FVal = Sel.getFalseValue(); + assert(isKnownNegation(TVal, FVal) && + "Unexpected result from matchSelectPattern"); + + // The compare may use the negated abs()/nabs() operand, or it may use + // negation in non-canonical form such as: sub A, B. 
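The canonical shapes canonicalizeAbsNabs drives everything toward can be written out in plain C++ (a sketch, not the pass itself): compare against zero with slt and negate as 0 - X.

#include <cassert>

int canonicalAbs(int X) { return X < 0 ? 0 - X : X; }  // ABS
int canonicalNabs(int X) { return X < 0 ? X : 0 - X; } // NABS

int main() {
  // Non-canonical variants such as "X > -1 ? X : -X" compute the same
  // values; funneling them all into one shape is what makes CSE more likely.
  for (int X = -100; X <= 100; ++X) {
    assert(canonicalAbs(X) == (X > -1 ? X : -X));
    assert(canonicalNabs(X) == -canonicalAbs(X));
  }
  return 0;
}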
+ bool CmpUsesNegatedOp = match(Cmp.getOperand(0), m_Neg(m_Specific(TVal))) || + match(Cmp.getOperand(0), m_Neg(m_Specific(FVal))); + + bool CmpCanonicalized = !CmpUsesNegatedOp && + match(Cmp.getOperand(1), m_ZeroInt()) && + Cmp.getPredicate() == ICmpInst::ICMP_SLT; + bool RHSCanonicalized = match(RHS, m_Neg(m_Specific(LHS))); + + // Is this already canonical? + if (CmpCanonicalized && RHSCanonicalized) + return nullptr; + + // If RHS is used by other instructions except compare and select, don't + // canonicalize it to not increase the instruction count. + if (!(RHS->hasOneUse() || (RHS->hasNUses(2) && CmpUsesNegatedOp))) + return nullptr; + + // Create the canonical compare: icmp slt LHS 0. + if (!CmpCanonicalized) { + Cmp.setPredicate(ICmpInst::ICMP_SLT); + Cmp.setOperand(1, ConstantInt::getNullValue(Cmp.getOperand(0)->getType())); + if (CmpUsesNegatedOp) + Cmp.setOperand(0, LHS); + } + + // Create the canonical RHS: RHS = sub (0, LHS). + if (!RHSCanonicalized) { + assert(RHS->hasOneUse() && "RHS use number is not right"); + RHS = Builder.CreateNeg(LHS); + if (TVal == LHS) { + Sel.setFalseValue(RHS); + FVal = RHS; + } else { + Sel.setTrueValue(RHS); + TVal = RHS; + } + } + + // If the select operands do not change, we're done. + if (SPF == SelectPatternFlavor::SPF_NABS) { + if (TVal == LHS) + return &Sel; + assert(FVal == LHS && "Unexpected results from matchSelectPattern"); + } else { + if (FVal == LHS) + return &Sel; + assert(TVal == LHS && "Unexpected results from matchSelectPattern"); + } + + // We are swapping the select operands, so swap the metadata too. + Sel.setTrueValue(FVal); + Sel.setFalseValue(TVal); + Sel.swapProfMetadata(); + return &Sel; +} + /// Visit a SelectInst that has an ICmpInst as its first operand. Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { @@ -727,59 +886,18 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder)) return NewSel; + if (Instruction *NewAbs = canonicalizeAbsNabs(SI, *ICI, Builder)) + return NewAbs; + bool Changed = adjustMinMax(SI, *ICI); + if (Value *V = foldSelectICmpAnd(SI, ICI, Builder)) + return replaceInstUsesWith(SI, V); + + // NOTE: if we wanted to, this is where to detect integer MIN/MAX ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); - - // Transform (X >s -1) ? C1 : C2 --> ((X >>s 31) & (C2 - C1)) + C1 - // and (X <s 0) ? C2 : C1 --> ((X >>s 31) & (C2 - C1)) + C1 - // FIXME: Type and constness constraints could be lifted, but we have to - // watch code size carefully. We should consider xor instead of - // sub/add when we decide to do that. - // TODO: Merge this with foldSelectICmpAnd somehow. - if (CmpLHS->getType()->isIntOrIntVectorTy() && - CmpLHS->getType() == TrueVal->getType()) { - const APInt *C1, *C2; - if (match(TrueVal, m_APInt(C1)) && match(FalseVal, m_APInt(C2))) { - ICmpInst::Predicate Pred = ICI->getPredicate(); - Value *X; - APInt Mask; - if (decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask, false)) { - if (Mask.isSignMask()) { - assert(X == CmpLHS && "Expected to use the compare input directly"); - assert(ICmpInst::isEquality(Pred) && "Expected equality predicate"); - - if (Pred == ICmpInst::ICMP_NE) - std::swap(C1, C2); - - // This shift results in either -1 or 0. - Value *AShr = Builder.CreateAShr(X, Mask.getBitWidth() - 1); - - // Check if we can express the operation with a single or. 
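The transform being deleted in the hunk above, (X >s -1) ? C1 : C2 --> ((X >>s 31) & (C2 - C1)) + C1, is a classic branchless select. A quick standalone check of the identity (assumes >> of a negative int32_t is an arithmetic shift, which C++ only guarantees since C++20; C1 and C2 are arbitrary):

#include <cassert>
#include <cstdint>

int main() {
  const int32_t C1 = 7, C2 = 42;
  for (int32_t X : {-5, -1, 0, 1, 100}) {
    int32_t Mask = X >> 31;                   // 0 if X > -1, otherwise -1
    int32_t Folded = (Mask & (C2 - C1)) + C1; // selects C1 or C2
    assert(Folded == (X > -1 ? C1 : C2));
  }
  return 0;
}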
- if (C2->isAllOnesValue()) - return replaceInstUsesWith(SI, Builder.CreateOr(AShr, *C1)); - - Value *And = Builder.CreateAnd(AShr, *C2 - *C1); - return replaceInstUsesWith(SI, Builder.CreateAdd(And, - ConstantInt::get(And->getType(), *C1))); - } - } - } - } - - { - const APInt *TrueValC, *FalseValC; - if (match(TrueVal, m_APInt(TrueValC)) && - match(FalseVal, m_APInt(FalseValC))) - if (Value *V = foldSelectICmpAnd(SI.getType(), ICI, *TrueValC, - *FalseValC, Builder)) - return replaceInstUsesWith(SI, V); - } - - // NOTE: if we wanted to, this is where to detect integer MIN/MAX - if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) { if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) { // Transform (X == C) ? X : Y -> (X == C) ? C : Y @@ -842,16 +960,22 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, } } + if (Instruction *V = + foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder)) + return V; + if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); + if (Value *V = canonicalizeSaturatedSubtract(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + return Changed ? &SI : nullptr; } - /// SI is a select whose condition is a PHI node (but the two may be in /// different blocks). See if the true/false values (V) are live in all of the /// predecessor blocks of the PHI. For example, cases like this can't be mapped: @@ -900,7 +1024,7 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, if (C == A || C == B) { // MAX(MAX(A, B), B) -> MAX(A, B) // MIN(MIN(a, b), a) -> MIN(a, b) - if (SPF1 == SPF2) + if (SPF1 == SPF2 && SelectPatternResult::isMinOrMax(SPF1)) return replaceInstUsesWith(Outer, Inner); // MAX(MIN(a, b), a) -> a @@ -992,10 +1116,10 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, if (!NotC) NotC = Builder.CreateNot(C); - Value *NewInner = generateMinMaxSelectPattern( - Builder, getInverseMinMaxSelectPattern(SPF1), NotA, NotB); - Value *NewOuter = Builder.CreateNot(generateMinMaxSelectPattern( - Builder, getInverseMinMaxSelectPattern(SPF2), NewInner, NotC)); + Value *NewInner = createMinMax(Builder, getInverseMinMaxFlavor(SPF1), NotA, + NotB); + Value *NewOuter = Builder.CreateNot( + createMinMax(Builder, getInverseMinMaxFlavor(SPF2), NewInner, NotC)); return replaceInstUsesWith(Outer, NewOuter); } @@ -1075,6 +1199,11 @@ static Instruction *foldAddSubSelect(SelectInst &SI, } Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { + Constant *C; + if (!match(Sel.getTrueValue(), m_Constant(C)) && + !match(Sel.getFalseValue(), m_Constant(C))) + return nullptr; + Instruction *ExtInst; if (!match(Sel.getTrueValue(), m_Instruction(ExtInst)) && !match(Sel.getFalseValue(), m_Instruction(ExtInst))) @@ -1084,20 +1213,18 @@ Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { if (ExtOpcode != Instruction::ZExt && ExtOpcode != Instruction::SExt) return nullptr; - // TODO: Handle larger types? That requires adjusting FoldOpIntoSelect too. + // If we are extending from a boolean type or if we can create a select that + // has the same size operands as its condition, try to narrow the select. 
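The inverse-flavor rewrites used in foldSPFofSPF above (and the MAX(~a, ~b) -> ~MIN(a, b) fold later in this file) rest on one identity: bitwise NOT reverses integer order, so it swaps min and max. A tiny check, signed case only:

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  // ~x == -x - 1 is strictly decreasing, hence order-reversing.
  for (int32_t A : {-7, -1, 0, 3, 100})
    for (int32_t B : {-9, 0, 5, 77}) {
      assert(~std::min(A, B) == std::max(~A, ~B));
      assert(~std::max(A, B) == std::min(~A, ~B));
    }
  return 0;
}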
Value *X = ExtInst->getOperand(0); Type *SmallType = X->getType(); - if (!SmallType->isIntOrIntVectorTy(1)) - return nullptr; - - Constant *C; - if (!match(Sel.getTrueValue(), m_Constant(C)) && - !match(Sel.getFalseValue(), m_Constant(C))) + Value *Cond = Sel.getCondition(); + auto *Cmp = dyn_cast<CmpInst>(Cond); + if (!SmallType->isIntOrIntVectorTy(1) && + (!Cmp || Cmp->getOperand(0)->getType() != SmallType)) return nullptr; // If the constant is the same after truncation to the smaller type and // extension to the original type, we can narrow the select. - Value *Cond = Sel.getCondition(); Type *SelType = Sel.getType(); Constant *TruncC = ConstantExpr::getTrunc(C, SmallType); Constant *ExtC = ConstantExpr::getCast(ExtOpcode, TruncC, SelType); @@ -1289,6 +1416,63 @@ static Instruction *foldSelectCmpXchg(SelectInst &SI) { return nullptr; } +/// Reduce a sequence of min/max with a common operand. +static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, + Value *RHS, + InstCombiner::BuilderTy &Builder) { + assert(SelectPatternResult::isMinOrMax(SPF) && "Expected a min/max"); + // TODO: Allow FP min/max with nnan/nsz. + if (!LHS->getType()->isIntOrIntVectorTy()) + return nullptr; + + // Match 3 of the same min/max ops. Example: umin(umin(), umin()). + Value *A, *B, *C, *D; + SelectPatternResult L = matchSelectPattern(LHS, A, B); + SelectPatternResult R = matchSelectPattern(RHS, C, D); + if (SPF != L.Flavor || L.Flavor != R.Flavor) + return nullptr; + + // Look for a common operand. The use checks are different than usual because + // a min/max pattern typically has 2 uses of each op: 1 by the cmp and 1 by + // the select. + Value *MinMaxOp = nullptr; + Value *ThirdOp = nullptr; + if (!LHS->hasNUsesOrMore(3) && RHS->hasNUsesOrMore(3)) { + // If the LHS is only used in this chain and the RHS is used outside of it, + // reuse the RHS min/max because that will eliminate the LHS. + if (D == A || C == A) { + // min(min(a, b), min(c, a)) --> min(min(c, a), b) + // min(min(a, b), min(a, d)) --> min(min(a, d), b) + MinMaxOp = RHS; + ThirdOp = B; + } else if (D == B || C == B) { + // min(min(a, b), min(c, b)) --> min(min(c, b), a) + // min(min(a, b), min(b, d)) --> min(min(b, d), a) + MinMaxOp = RHS; + ThirdOp = A; + } + } else if (!RHS->hasNUsesOrMore(3)) { + // Reuse the LHS. This will eliminate the RHS. + if (D == A || D == B) { + // min(min(a, b), min(c, a)) --> min(min(a, b), c) + // min(min(a, b), min(c, b)) --> min(min(a, b), c) + MinMaxOp = LHS; + ThirdOp = C; + } else if (C == A || C == B) { + // min(min(a, b), min(b, d)) --> min(min(a, b), d) + // min(min(a, b), min(c, b)) --> min(min(a, b), d) + MinMaxOp = LHS; + ThirdOp = D; + } + } + if (!MinMaxOp || !ThirdOp) + return nullptr; + + CmpInst::Predicate P = getMinMaxPred(SPF); + Value *CmpABC = Builder.CreateICmp(P, MinMaxOp, ThirdOp); + return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp); +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -1489,7 +1673,37 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // NOTE: if we wanted to, this is where to detect MIN/MAX } - // NOTE: if we wanted to, this is where to detect ABS + + // Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need + // fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. We + // also require nnan because we do not want to unintentionally change the + // sign of a NaN value. 
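The fabs() canonicalization described in the comment above, restated as a standalone C++ check. Non-NaN inputs stand in for the nnan requirement; note that -0.0 == +0.0 under ==, which is exactly why the fneg variants additionally need nsz.

#include <cassert>
#include <cmath>

int main() {
  // (X <= +/-0.0) ? (0.0 - X) : X --> fabs(X)
  for (double X : {-3.5, -0.0, 0.0, 2.25}) {
    double Sel = (X <= 0.0) ? (0.0 - X) : X;
    assert(Sel == std::fabs(X));
  }
  return 0;
}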
+ Value *X = FCI->getOperand(0); + FCmpInst::Predicate Pred = FCI->getPredicate(); + if (match(FCI->getOperand(1), m_AnyZeroFP()) && FCI->hasNoNaNs()) { + // (X <= +/-0.0) ? (0.0 - X) : X --> fabs(X) + // (X > +/-0.0) ? X : (0.0 - X) --> fabs(X) + if ((X == FalseVal && Pred == FCmpInst::FCMP_OLE && + match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(X)))) || + (X == TrueVal && Pred == FCmpInst::FCMP_OGT && + match(FalseVal, m_FSub(m_PosZeroFP(), m_Specific(X))))) { + Value *Fabs = Builder.CreateIntrinsic(Intrinsic::fabs, { X }, FCI); + return replaceInstUsesWith(SI, Fabs); + } + // With nsz: + // (X < +/-0.0) ? -X : X --> fabs(X) + // (X <= +/-0.0) ? -X : X --> fabs(X) + // (X > +/-0.0) ? X : -X --> fabs(X) + // (X >= +/-0.0) ? X : -X --> fabs(X) + if (FCI->hasNoSignedZeros() && + ((X == FalseVal && match(TrueVal, m_FNeg(m_Specific(X))) && + (Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE)) || + (X == TrueVal && match(FalseVal, m_FNeg(m_Specific(X))) && + (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE)))) { + Value *Fabs = Builder.CreateIntrinsic(Intrinsic::fabs, { X }, FCI); + return replaceInstUsesWith(SI, Fabs); + } + } } // See if we are selecting two values based on a comparison of the two values. @@ -1532,7 +1746,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { (LHS->getType()->isFPOrFPVectorTy() && ((CmpLHS != LHS && CmpLHS != RHS) || (CmpRHS != LHS && CmpRHS != RHS)))) { - CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, SPR.Ordered); + CmpInst::Predicate Pred = getMinMaxPred(SPF, SPR.Ordered); Value *Cmp; if (CmpInst::isIntPredicate(Pred)) { @@ -1551,6 +1765,20 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *NewCast = Builder.CreateCast(CastOp, NewSI, SelType); return replaceInstUsesWith(SI, NewCast); } + + // MAX(~a, ~b) -> ~MIN(a, b) + // MIN(~a, ~b) -> ~MAX(a, b) + Value *A, *B; + if (match(LHS, m_Not(m_Value(A))) && match(RHS, m_Not(m_Value(B))) && + (LHS->getNumUses() <= 2 || RHS->getNumUses() <= 2)) { + CmpInst::Predicate InvertedPred = getInverseMinMaxPred(SPF); + Value *InvertedCmp = Builder.CreateICmp(InvertedPred, A, B); + Value *NewSel = Builder.CreateSelect(InvertedCmp, A, B); + return BinaryOperator::CreateNot(NewSel); + } + + if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder)) + return I; } if (SPF) { @@ -1570,28 +1798,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return R; } - // MAX(~a, ~b) -> ~MIN(a, b) - if ((SPF == SPF_SMAX || SPF == SPF_UMAX) && - IsFreeToInvert(LHS, LHS->hasNUses(2)) && - IsFreeToInvert(RHS, RHS->hasNUses(2))) { - // For this transform to be profitable, we need to eliminate at least two - // 'not' instructions if we're going to add one 'not' instruction. - int NumberOfNots = - (LHS->hasNUses(2) && match(LHS, m_Not(m_Value()))) + - (RHS->hasNUses(2) && match(RHS, m_Not(m_Value()))) + - (SI.hasOneUse() && match(*SI.user_begin(), m_Not(m_Value()))); - - if (NumberOfNots >= 2) { - Value *NewLHS = Builder.CreateNot(LHS); - Value *NewRHS = Builder.CreateNot(RHS); - Value *NewCmp = SPF == SPF_SMAX ? Builder.CreateICmpSLT(NewLHS, NewRHS) - : Builder.CreateICmpULT(NewLHS, NewRHS); - Value *NewSI = - Builder.CreateNot(Builder.CreateSelect(NewCmp, NewLHS, NewRHS)); - return replaceInstUsesWith(SI, NewSI); - } - } - // TODO. // ABS(-X) -> ABS(X) } @@ -1643,11 +1849,25 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } + auto canMergeSelectThroughBinop = [](BinaryOperator *BO) { + // The select might be preventing a division by 0. 
+ switch (BO->getOpcode()) { + default: + return true; + case Instruction::SRem: + case Instruction::URem: + case Instruction::SDiv: + case Instruction::UDiv: + return false; + } + }; + // Try to simplify a binop sandwiched between 2 selects with the same // condition. // select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z) BinaryOperator *TrueBO; - if (match(TrueVal, m_OneUse(m_BinOp(TrueBO)))) { + if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) && + canMergeSelectThroughBinop(TrueBO)) { if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) { if (TrueBOSI->getCondition() == CondVal) { TrueBO->setOperand(0, TrueBOSI->getTrueValue()); @@ -1666,7 +1886,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W)) BinaryOperator *FalseBO; - if (match(FalseVal, m_OneUse(m_BinOp(FalseBO)))) { + if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) && + canMergeSelectThroughBinop(FalseBO)) { if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) { if (FalseBOSI->getCondition() == CondVal) { FalseBO->setOperand(0, FalseBOSI->getFalseValue()); diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 44bbb84686ab..34f8037e519f 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -87,8 +87,7 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, // Equal shift amounts in opposite directions become bitwise 'and': // lshr (shl X, C), C --> and X, C' // shl (lshr X, C), C --> and X, C' - unsigned InnerShAmt = InnerShiftConst->getZExtValue(); - if (InnerShAmt == OuterShAmt) + if (*InnerShiftConst == OuterShAmt) return true; // If the 2nd shift is bigger than the 1st, we can fold: @@ -98,7 +97,8 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, // Also, check that the inner shift is valid (less than the type width) or // we'll crash trying to produce the bit mask for the 'and'. unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits(); - if (InnerShAmt > OuterShAmt && InnerShAmt < TypeWidth) { + if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) { + unsigned InnerShAmt = InnerShiftConst->getZExtValue(); unsigned MaskShift = IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt; APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift; @@ -135,7 +135,7 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, ConstantInt *CI = nullptr; if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) || (!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) { - if (CI->getZExtValue() == NumBits) { + if (CI->getValue() == NumBits) { // TODO: Check that the input bits are already zero with MaskedValueIsZero #if 0 // If this is a truncate of a logical shr, we can truncate it to a smaller @@ -356,8 +356,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // cast of lshr(shl(x,c1),c2) as well as other more complex cases. 
if (I.getOpcode() != Instruction::AShr && canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) { - DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" - " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); + LLVM_DEBUG( + dbgs() << "ICE: GetShiftedValue propagating shift through expression" + " to eliminate shift:\n IN: " + << *Op0 << "\n SH: " << I << "\n"); return replaceInstUsesWith( I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL)); @@ -370,7 +372,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, assert(!Op1C->uge(TypeBits) && "Shift over the type width should have been removed already"); - if (Instruction *FoldedShift = foldOpWithConstantIntoOperand(I)) + if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I)) return FoldedShift; // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2)) @@ -586,23 +588,23 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, } Instruction *InstCombiner::visitShl(BinaryOperator &I) { - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), + I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = - SimplifyShlInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), - SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Instruction *V = commonShiftTransforms(I)) return V; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = I.getType(); const APInt *ShAmtAPInt; if (match(Op1, m_APInt(ShAmtAPInt))) { unsigned ShAmt = ShAmtAPInt->getZExtValue(); - unsigned BitWidth = I.getType()->getScalarSizeInBits(); - Type *Ty = I.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); // shl (zext X), ShAmt --> zext (shl X, ShAmt) // This is only valid if X would have zeros shifted out. @@ -620,11 +622,8 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } - // Be careful about hiding shl instructions behind bit masks. They are used - // to represent multiplies by a constant, and it is important that simple - // arithmetic expressions are still recognizable by scalar evolution. - // The inexact versions are deferred to DAGCombine, so we don't hide shl - // behind a bit mask. + // FIXME: we do not yet transform non-exact shr's. The backend (DAGCombine) + // needs a few fixes for the rotate pattern recognition first. const APInt *ShOp1; if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1))))) { unsigned ShrAmt = ShOp1->getZExtValue(); @@ -668,6 +667,15 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { } } + // Transform (x >> y) << y to x & (-1 << y) + // Valid for any type of right-shift. 
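This new visitShl fold and its visitLShr mirror in the next hunk ((x << y) >> y --> x & (-1 >> y)) are both round-trip-to-mask identities. A standalone check of the logical-shift cases (-1 is spelled UINT32_MAX here):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X : {0u, 1u, 0xDEADBEEFu, UINT32_MAX})
    for (uint32_t Y = 0; Y < 32; ++Y) {
      assert(((X >> Y) << Y) == (X & (UINT32_MAX << Y))); // clears low bits
      assert(((X << Y) >> Y) == (X & (UINT32_MAX >> Y))); // clears high bits
    }
  return 0;
}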
+ Value *X; + if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) { + Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); + Value *Mask = Builder.CreateShl(AllOnes, Op1); + return BinaryOperator::CreateAnd(Mask, X); + } + Constant *C1; if (match(Op1, m_Constant(C1))) { Constant *C2; @@ -685,17 +693,17 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { } Instruction *InstCombiner::visitLShr(BinaryOperator &I) { - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = - SimplifyLShrInst(Op0, Op1, I.isExact(), SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Instruction *R = commonShiftTransforms(I)) return R; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); const APInt *ShAmtAPInt; if (match(Op1, m_APInt(ShAmtAPInt))) { @@ -800,25 +808,34 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { return &I; } } + + // Transform (x << y) >> y to x & (-1 >> y) + Value *X; + if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) { + Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); + Value *Mask = Builder.CreateLShr(AllOnes, Op1); + return BinaryOperator::CreateAnd(Mask, X); + } + return nullptr; } Instruction *InstCombiner::visitAShr(BinaryOperator &I) { - if (Value *V = SimplifyVectorOp(I)) + if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = - SimplifyAShrInst(Op0, Op1, I.isExact(), SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); + if (Instruction *X = foldShuffledBinop(I)) + return X; if (Instruction *R = commonShiftTransforms(I)) return R; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); unsigned BitWidth = Ty->getScalarSizeInBits(); const APInt *ShAmtAPInt; - if (match(Op1, m_APInt(ShAmtAPInt))) { + if (match(Op1, m_APInt(ShAmtAPInt)) && ShAmtAPInt->ult(BitWidth)) { unsigned ShAmt = ShAmtAPInt->getZExtValue(); // If the shift amount equals the difference in width of the destination @@ -832,7 +849,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However, // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. const APInt *ShOp1; - if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1)))) { + if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1))) && + ShOp1->ult(BitWidth)) { unsigned ShlAmt = ShOp1->getZExtValue(); if (ShlAmt < ShAmt) { // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1) @@ -850,7 +868,8 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { } } - if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1)))) { + if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1))) && + ShOp1->ult(BitWidth)) { unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); // Oversized arithmetic shifts replicate the sign bit. 
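The sign-replication fact stated above is what justifies the clamp on the next line: combining (X >>s C1) >>s C2 can use min(C1 + C2, BitWidth - 1), because once every bit is a copy of the sign bit, shifting further changes nothing. A check (assumes >> on int32_t is arithmetic, guaranteed only since C++20):

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  const unsigned BitWidth = 32;
  for (int32_t X : {INT32_MIN, -77, -1, 0, 1, INT32_MAX})
    for (unsigned C1 = 0; C1 < 32; ++C1)
      for (unsigned C2 = 0; C2 < 32; ++C2) {
        unsigned AmtSum = std::min(C1 + C2, BitWidth - 1);
        assert(((X >> C1) >> C2) == (X >> AmtSum));
      }
  return 0;
}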
AmtSum = std::min(AmtSum, BitWidth - 1); diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index a2e757cb4273..425f5ce384be 100644 --- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -23,6 +23,17 @@ using namespace llvm::PatternMatch; #define DEBUG_TYPE "instcombine" +namespace { + +struct AMDGPUImageDMaskIntrinsic { + unsigned Intr; +}; + +#define GET_AMDGPUImageDMaskIntrinsicTable_IMPL +#include "InstCombineTables.inc" + +} // end anonymous namespace + /// Check to see if the specified operand of the specified instruction is a /// constant integer. If so, check to see if there are any bits set in the /// constant that are not demanded. If so, shrink the constant and return true. @@ -333,7 +344,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBits InputKnown(SrcBitWidth); if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1)) return I; - Known = Known.zextOrTrunc(BitWidth); + Known = InputKnown.zextOrTrunc(BitWidth); // Any top bits are known to be zero. if (BitWidth > SrcBitWidth) Known.Zero.setBitsFrom(SrcBitWidth); @@ -545,6 +556,27 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } break; } + case Instruction::UDiv: { + // UDiv doesn't demand low bits that are zero in the divisor. + const APInt *SA; + if (match(I->getOperand(1), m_APInt(SA))) { + // If the shift is exact, then it does demand the low bits. + if (cast<UDivOperator>(I)->isExact()) + break; + + // FIXME: Take the demanded mask of the result into account. + unsigned RHSTrailingZeros = SA->countTrailingZeros(); + APInt DemandedMaskIn = + APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros); + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1)) + return I; + + // Propagate zero bits from the input. + Known.Zero.setHighBits(std::min( + BitWidth, LHSKnown.Zero.countLeadingOnes() + RHSTrailingZeros)); + } + break; + } case Instruction::SRem: if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) { // X % -1 demands all the bits because we don't want to introduce @@ -888,6 +920,110 @@ InstCombiner::simplifyShrShlDemandedBits(Instruction *Shr, const APInt &ShrOp1, return nullptr; } +/// Implement SimplifyDemandedVectorElts for amdgcn buffer and image intrinsics. +Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, + APInt DemandedElts, + int DMaskIdx) { + unsigned VWidth = II->getType()->getVectorNumElements(); + if (VWidth == 1) + return nullptr; + + ConstantInt *NewDMask = nullptr; + + if (DMaskIdx < 0) { + // Pretend that a prefix of elements is demanded to simplify the code + // below. 
+ DemandedElts = (1 << DemandedElts.getActiveBits()) - 1; + } else { + ConstantInt *DMask = dyn_cast<ConstantInt>(II->getArgOperand(DMaskIdx)); + if (!DMask) + return nullptr; // non-constant dmask is not supported by codegen + + unsigned DMaskVal = DMask->getZExtValue() & 0xf; + + // Mask off values that are undefined because the dmask doesn't cover them + DemandedElts &= (1 << countPopulation(DMaskVal)) - 1; + + unsigned NewDMaskVal = 0; + unsigned OrigLoadIdx = 0; + for (unsigned SrcIdx = 0; SrcIdx < 4; ++SrcIdx) { + const unsigned Bit = 1 << SrcIdx; + if (!!(DMaskVal & Bit)) { + if (!!DemandedElts[OrigLoadIdx]) + NewDMaskVal |= Bit; + OrigLoadIdx++; + } + } + + if (DMaskVal != NewDMaskVal) + NewDMask = ConstantInt::get(DMask->getType(), NewDMaskVal); + } + + // TODO: Handle 3 vectors when supported in code gen. + unsigned NewNumElts = PowerOf2Ceil(DemandedElts.countPopulation()); + if (!NewNumElts) + return UndefValue::get(II->getType()); + + if (NewNumElts >= VWidth && DemandedElts.isMask()) { + if (NewDMask) + II->setArgOperand(DMaskIdx, NewDMask); + return nullptr; + } + + // Determine the overload types of the original intrinsic. + auto IID = II->getIntrinsicID(); + SmallVector<Intrinsic::IITDescriptor, 16> Table; + getIntrinsicInfoTableEntries(IID, Table); + ArrayRef<Intrinsic::IITDescriptor> TableRef = Table; + + FunctionType *FTy = II->getCalledFunction()->getFunctionType(); + SmallVector<Type *, 6> OverloadTys; + Intrinsic::matchIntrinsicType(FTy->getReturnType(), TableRef, OverloadTys); + for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) + Intrinsic::matchIntrinsicType(FTy->getParamType(i), TableRef, OverloadTys); + + // Get the new return type overload of the intrinsic. + Module *M = II->getParent()->getParent()->getParent(); + Type *EltTy = II->getType()->getVectorElementType(); + Type *NewTy = (NewNumElts == 1) ? EltTy : VectorType::get(EltTy, NewNumElts); + + OverloadTys[0] = NewTy; + Function *NewIntrin = Intrinsic::getDeclaration(M, IID, OverloadTys); + + SmallVector<Value *, 16> Args; + for (unsigned I = 0, E = II->getNumArgOperands(); I != E; ++I) + Args.push_back(II->getArgOperand(I)); + + if (NewDMask) + Args[DMaskIdx] = NewDMask; + + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(II); + + CallInst *NewCall = Builder.CreateCall(NewIntrin, Args); + NewCall->takeName(II); + NewCall->copyMetadata(*II); + + if (NewNumElts == 1) { + return Builder.CreateInsertElement(UndefValue::get(II->getType()), NewCall, + DemandedElts.countTrailingZeros()); + } + + SmallVector<uint32_t, 8> EltMask; + unsigned NewLoadIdx = 0; + for (unsigned OrigLoadIdx = 0; OrigLoadIdx < VWidth; ++OrigLoadIdx) { + if (!!DemandedElts[OrigLoadIdx]) + EltMask.push_back(NewLoadIdx++); + else + EltMask.push_back(NewNumElts); + } + + Value *Shuffle = + Builder.CreateShuffleVector(NewCall, UndefValue::get(NewTy), EltMask); + + return Shuffle; +} + /// The specified value produces a vector with any number of elements. /// DemandedElts contains the set of elements that are actually used by the /// caller. This method analyzes which elements of the operand are undef and @@ -1187,7 +1323,6 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, break; } - // div/rem demand all inputs, because they don't want divide by zero. 
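A worked instance of the dmask narrowing loop above, with hypothetical constants: dmask 0b1011 loads texture channels {0, 1, 3} as result elements {0, 1, 2}; if only result elements 0 and 2 are demanded, channels 0 and 3 survive.

#include <cassert>

int main() {
  unsigned DMaskVal = 0b1011, DemandedElts = 0b101;
  unsigned NewDMaskVal = 0, OrigLoadIdx = 0;
  for (unsigned SrcIdx = 0; SrcIdx < 4; ++SrcIdx) {
    unsigned Bit = 1u << SrcIdx;
    if (DMaskVal & Bit) {                   // channel is loaded at all...
      if (DemandedElts & (1u << OrigLoadIdx))
        NewDMaskVal |= Bit;                 // ...and its element is demanded
      ++OrigLoadIdx;
    }
  }
  assert(NewDMaskVal == 0b1001); // channels 0 and 3 remain
  return 0;
}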
TmpV = SimplifyDemandedVectorElts(I->getOperand(0), InputDemandedElts, UndefElts2, Depth + 1); if (TmpV) { @@ -1247,8 +1382,6 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, IntrinsicInst *II = dyn_cast<IntrinsicInst>(I); if (!II) break; switch (II->getIntrinsicID()) { - default: break; - case Intrinsic::x86_xop_vfrcz_ss: case Intrinsic::x86_xop_vfrcz_sd: // The instructions for these intrinsics are speced to zero upper bits not @@ -1273,8 +1406,6 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // Unary scalar-as-vector operations that work column-wise. case Intrinsic::x86_sse_rcp_ss: case Intrinsic::x86_sse_rsqrt_ss: - case Intrinsic::x86_sse_sqrt_ss: - case Intrinsic::x86_sse2_sqrt_sd: TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, UndefElts, Depth + 1); if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } @@ -1366,18 +1497,6 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, case Intrinsic::x86_avx512_mask_sub_sd_round: case Intrinsic::x86_avx512_mask_max_sd_round: case Intrinsic::x86_avx512_mask_min_sd_round: - case Intrinsic::x86_fma_vfmadd_ss: - case Intrinsic::x86_fma_vfmsub_ss: - case Intrinsic::x86_fma_vfnmadd_ss: - case Intrinsic::x86_fma_vfnmsub_ss: - case Intrinsic::x86_fma_vfmadd_sd: - case Intrinsic::x86_fma_vfmsub_sd: - case Intrinsic::x86_fma_vfnmadd_sd: - case Intrinsic::x86_fma_vfnmsub_sd: - case Intrinsic::x86_avx512_mask_vfmadd_ss: - case Intrinsic::x86_avx512_mask_vfmadd_sd: - case Intrinsic::x86_avx512_maskz_vfmadd_ss: - case Intrinsic::x86_avx512_maskz_vfmadd_sd: TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, UndefElts, Depth + 1); if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } @@ -1404,68 +1523,6 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, break; - case Intrinsic::x86_avx512_mask3_vfmadd_ss: - case Intrinsic::x86_avx512_mask3_vfmadd_sd: - case Intrinsic::x86_avx512_mask3_vfmsub_ss: - case Intrinsic::x86_avx512_mask3_vfmsub_sd: - case Intrinsic::x86_avx512_mask3_vfnmsub_ss: - case Intrinsic::x86_avx512_mask3_vfnmsub_sd: - // These intrinsics get the passthru bits from operand 2. - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(2), DemandedElts, - UndefElts, Depth + 1); - if (TmpV) { II->setArgOperand(2, TmpV); MadeChange = true; } - - // If lowest element of a scalar op isn't used then use Arg2. - if (!DemandedElts[0]) { - Worklist.Add(II); - return II->getArgOperand(2); - } - - // Only lower element is used for operand 0 and 1. - DemandedElts = 1; - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, - UndefElts2, Depth + 1); - if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts, - UndefElts3, Depth + 1); - if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } - - // Lower element is undefined if all three lower elements are undefined. - // Consider things like undef&0. The result is known zero, not undef. 
- if (!UndefElts2[0] || !UndefElts3[0]) - UndefElts.clearBit(0); - - break; - - case Intrinsic::x86_sse2_pmulu_dq: - case Intrinsic::x86_sse41_pmuldq: - case Intrinsic::x86_avx2_pmul_dq: - case Intrinsic::x86_avx2_pmulu_dq: - case Intrinsic::x86_avx512_pmul_dq_512: - case Intrinsic::x86_avx512_pmulu_dq_512: { - Value *Op0 = II->getArgOperand(0); - Value *Op1 = II->getArgOperand(1); - unsigned InnerVWidth = Op0->getType()->getVectorNumElements(); - assert((VWidth * 2) == InnerVWidth && "Unexpected input size"); - - APInt InnerDemandedElts(InnerVWidth, 0); - for (unsigned i = 0; i != VWidth; ++i) - if (DemandedElts[i]) - InnerDemandedElts.setBit(i * 2); - - UndefElts2 = APInt(InnerVWidth, 0); - TmpV = SimplifyDemandedVectorElts(Op0, InnerDemandedElts, UndefElts2, - Depth + 1); - if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } - - UndefElts3 = APInt(InnerVWidth, 0); - TmpV = SimplifyDemandedVectorElts(Op1, InnerDemandedElts, UndefElts3, - Depth + 1); - if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } - - break; - } - case Intrinsic::x86_sse2_packssdw_128: case Intrinsic::x86_sse2_packsswb_128: case Intrinsic::x86_sse2_packuswb_128: @@ -1554,124 +1611,12 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, break; case Intrinsic::amdgcn_buffer_load: case Intrinsic::amdgcn_buffer_load_format: - case Intrinsic::amdgcn_image_sample: - case Intrinsic::amdgcn_image_sample_cl: - case Intrinsic::amdgcn_image_sample_d: - case Intrinsic::amdgcn_image_sample_d_cl: - case Intrinsic::amdgcn_image_sample_l: - case Intrinsic::amdgcn_image_sample_b: - case Intrinsic::amdgcn_image_sample_b_cl: - case Intrinsic::amdgcn_image_sample_lz: - case Intrinsic::amdgcn_image_sample_cd: - case Intrinsic::amdgcn_image_sample_cd_cl: - - case Intrinsic::amdgcn_image_sample_c: - case Intrinsic::amdgcn_image_sample_c_cl: - case Intrinsic::amdgcn_image_sample_c_d: - case Intrinsic::amdgcn_image_sample_c_d_cl: - case Intrinsic::amdgcn_image_sample_c_l: - case Intrinsic::amdgcn_image_sample_c_b: - case Intrinsic::amdgcn_image_sample_c_b_cl: - case Intrinsic::amdgcn_image_sample_c_lz: - case Intrinsic::amdgcn_image_sample_c_cd: - case Intrinsic::amdgcn_image_sample_c_cd_cl: - - case Intrinsic::amdgcn_image_sample_o: - case Intrinsic::amdgcn_image_sample_cl_o: - case Intrinsic::amdgcn_image_sample_d_o: - case Intrinsic::amdgcn_image_sample_d_cl_o: - case Intrinsic::amdgcn_image_sample_l_o: - case Intrinsic::amdgcn_image_sample_b_o: - case Intrinsic::amdgcn_image_sample_b_cl_o: - case Intrinsic::amdgcn_image_sample_lz_o: - case Intrinsic::amdgcn_image_sample_cd_o: - case Intrinsic::amdgcn_image_sample_cd_cl_o: - - case Intrinsic::amdgcn_image_sample_c_o: - case Intrinsic::amdgcn_image_sample_c_cl_o: - case Intrinsic::amdgcn_image_sample_c_d_o: - case Intrinsic::amdgcn_image_sample_c_d_cl_o: - case Intrinsic::amdgcn_image_sample_c_l_o: - case Intrinsic::amdgcn_image_sample_c_b_o: - case Intrinsic::amdgcn_image_sample_c_b_cl_o: - case Intrinsic::amdgcn_image_sample_c_lz_o: - case Intrinsic::amdgcn_image_sample_c_cd_o: - case Intrinsic::amdgcn_image_sample_c_cd_cl_o: - - case Intrinsic::amdgcn_image_getlod: { - if (VWidth == 1 || !DemandedElts.isMask()) - return nullptr; - - // TODO: Handle 3 vectors when supported in code gen. 
- unsigned NewNumElts = PowerOf2Ceil(DemandedElts.countTrailingOnes()); - if (NewNumElts == VWidth) - return nullptr; - - Module *M = II->getParent()->getParent()->getParent(); - Type *EltTy = V->getType()->getVectorElementType(); - - Type *NewTy = (NewNumElts == 1) ? EltTy : - VectorType::get(EltTy, NewNumElts); - - auto IID = II->getIntrinsicID(); - - bool IsBuffer = IID == Intrinsic::amdgcn_buffer_load || - IID == Intrinsic::amdgcn_buffer_load_format; - - Function *NewIntrin = IsBuffer ? - Intrinsic::getDeclaration(M, IID, NewTy) : - // Samplers have 3 mangled types. - Intrinsic::getDeclaration(M, IID, - { NewTy, II->getArgOperand(0)->getType(), - II->getArgOperand(1)->getType()}); - - SmallVector<Value *, 5> Args; - for (unsigned I = 0, E = II->getNumArgOperands(); I != E; ++I) - Args.push_back(II->getArgOperand(I)); - - IRBuilderBase::InsertPointGuard Guard(Builder); - Builder.SetInsertPoint(II); - - CallInst *NewCall = Builder.CreateCall(NewIntrin, Args); - NewCall->takeName(II); - NewCall->copyMetadata(*II); - - if (!IsBuffer) { - ConstantInt *DMask = dyn_cast<ConstantInt>(NewCall->getArgOperand(3)); - if (DMask) { - unsigned DMaskVal = DMask->getZExtValue() & 0xf; - - unsigned PopCnt = 0; - unsigned NewDMask = 0; - for (unsigned I = 0; I < 4; ++I) { - const unsigned Bit = 1 << I; - if (!!(DMaskVal & Bit)) { - if (++PopCnt > NewNumElts) - break; + return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts); + default: { + if (getAMDGPUImageDMaskIntrinsic(II->getIntrinsicID())) + return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts, 0); - NewDMask |= Bit; - } - } - - NewCall->setArgOperand(3, ConstantInt::get(DMask->getType(), NewDMask)); - } - } - - - if (NewNumElts == 1) { - return Builder.CreateInsertElement(UndefValue::get(V->getType()), - NewCall, static_cast<uint64_t>(0)); - } - - SmallVector<uint32_t, 8> EltMask; - for (unsigned I = 0; I < VWidth; ++I) - EltMask.push_back(I); - - Value *Shuffle = Builder.CreateShuffleVector( - NewCall, UndefValue::get(NewTy), EltMask); - - MadeChange = true; - return Shuffle; + break; } } break; diff --git a/lib/Transforms/InstCombine/InstCombineTables.td b/lib/Transforms/InstCombine/InstCombineTables.td new file mode 100644 index 000000000000..98b2adc442fa --- /dev/null +++ b/lib/Transforms/InstCombine/InstCombineTables.td @@ -0,0 +1,11 @@ +include "llvm/TableGen/SearchableTable.td" +include "llvm/IR/Intrinsics.td" + +def AMDGPUImageDMaskIntrinsicTable : GenericTable { + let FilterClass = "AMDGPUImageDMaskIntrinsic"; + let Fields = ["Intr"]; + + let PrimaryKey = ["Intr"]; + let PrimaryKeyName = "getAMDGPUImageDMaskIntrinsic"; + let PrimaryKeyEarlyOut = 1; +} diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index aeac8910af6b..2560feb37d66 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -1140,6 +1140,216 @@ static bool isShuffleExtractingFromLHS(ShuffleVectorInst &SVI, return true; } +/// These are the ingredients in an alternate form binary operator as described +/// below. +struct BinopElts { + BinaryOperator::BinaryOps Opcode; + Value *Op0; + Value *Op1; + BinopElts(BinaryOperator::BinaryOps Opc = (BinaryOperator::BinaryOps)0, + Value *V0 = nullptr, Value *V1 = nullptr) : + Opcode(Opc), Op0(V0), Op1(V1) {} + operator bool() const { return Opcode != 0; } +}; + +/// Binops may be transformed into binops with different opcodes and operands. 
+/// Reverse the usual canonicalization to enable folds with the non-canonical +/// form of the binop. If a transform is possible, return the elements of the +/// new binop. If not, return invalid elements. +static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) { + Value *BO0 = BO->getOperand(0), *BO1 = BO->getOperand(1); + Type *Ty = BO->getType(); + switch (BO->getOpcode()) { + case Instruction::Shl: { + // shl X, C --> mul X, (1 << C) + Constant *C; + if (match(BO1, m_Constant(C))) { + Constant *ShlOne = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C); + return { Instruction::Mul, BO0, ShlOne }; + } + break; + } + case Instruction::Or: { + // or X, C --> add X, C (when X and C have no common bits set) + const APInt *C; + if (match(BO1, m_APInt(C)) && MaskedValueIsZero(BO0, *C, DL)) + return { Instruction::Add, BO0, BO1 }; + break; + } + default: + break; + } + return {}; +} + +static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { + assert(Shuf.isSelect() && "Must have select-equivalent shuffle"); + + // Are we shuffling together some value and that same value after it has been + // modified by a binop with a constant? + Value *Op0 = Shuf.getOperand(0), *Op1 = Shuf.getOperand(1); + Constant *C; + bool Op0IsBinop; + if (match(Op0, m_BinOp(m_Specific(Op1), m_Constant(C)))) + Op0IsBinop = true; + else if (match(Op1, m_BinOp(m_Specific(Op0), m_Constant(C)))) + Op0IsBinop = false; + else + return nullptr; + + // The identity constant for a binop leaves a variable operand unchanged. For + // a vector, this is a splat of something like 0, -1, or 1. + // If there's no identity constant for this binop, we're done. + auto *BO = cast<BinaryOperator>(Op0IsBinop ? Op0 : Op1); + BinaryOperator::BinaryOps BOpcode = BO->getOpcode(); + Constant *IdC = ConstantExpr::getBinOpIdentity(BOpcode, Shuf.getType(), true); + if (!IdC) + return nullptr; + + // Shuffle identity constants into the lanes that return the original value. + // Example: shuf (mul X, {-1,-2,-3,-4}), X, {0,5,6,3} --> mul X, {-1,1,1,-4} + // Example: shuf X, (add X, {-1,-2,-3,-4}), {0,1,6,7} --> add X, {0,0,-3,-4} + // The existing binop constant vector remains in the same operand position. + Constant *Mask = Shuf.getMask(); + Constant *NewC = Op0IsBinop ? ConstantExpr::getShuffleVector(C, IdC, Mask) : + ConstantExpr::getShuffleVector(IdC, C, Mask); + + bool MightCreatePoisonOrUB = + Mask->containsUndefElement() && + (Instruction::isIntDivRem(BOpcode) || Instruction::isShift(BOpcode)); + if (MightCreatePoisonOrUB) + NewC = getSafeVectorConstantForBinop(BOpcode, NewC, true); + + // shuf (bop X, C), X, M --> bop X, C' + // shuf X, (bop X, C), M --> bop X, C' + Value *X = Op0IsBinop ? Op1 : Op0; + Instruction *NewBO = BinaryOperator::Create(BOpcode, X, NewC); + NewBO->copyIRFlags(BO); + + // An undef shuffle mask element may propagate as an undef constant element in + // the new binop. That would produce poison where the original code might not. + // If we already made a safe constant, then there's no danger. + if (Mask->containsUndefElement() && !MightCreatePoisonOrUB) + NewBO->dropPoisonGeneratingFlags(); + return NewBO; +} + +/// Try to fold shuffles that are the equivalent of a vector select. 
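The two rewrites getAlternateBinop performs, checked as scalar identities (arbitrary test values):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X : {0u, 3u, 0x1234u, 0xDEADBEEFu}) {
    // shl X, C --> mul X, (1 << C): both sides wrap mod 2^32.
    for (uint32_t C = 0; C < 32; ++C)
      assert((X << C) == X * (1u << C));
    // or X, C --> add X, C, valid when X and C have no common set bits.
    uint32_t D = ~X & 0x0F0F0F0Fu; // disjoint from X by construction
    assert((X | D) == (X + D));    // no carries, so 'or' equals 'add'
  }
  return 0;
}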
+static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf, + InstCombiner::BuilderTy &Builder, + const DataLayout &DL) { + if (!Shuf.isSelect()) + return nullptr; + + if (Instruction *I = foldSelectShuffleWith1Binop(Shuf)) + return I; + + BinaryOperator *B0, *B1; + if (!match(Shuf.getOperand(0), m_BinOp(B0)) || + !match(Shuf.getOperand(1), m_BinOp(B1))) + return nullptr; + + Value *X, *Y; + Constant *C0, *C1; + bool ConstantsAreOp1; + if (match(B0, m_BinOp(m_Value(X), m_Constant(C0))) && + match(B1, m_BinOp(m_Value(Y), m_Constant(C1)))) + ConstantsAreOp1 = true; + else if (match(B0, m_BinOp(m_Constant(C0), m_Value(X))) && + match(B1, m_BinOp(m_Constant(C1), m_Value(Y)))) + ConstantsAreOp1 = false; + else + return nullptr; + + // We need matching binops to fold the lanes together. + BinaryOperator::BinaryOps Opc0 = B0->getOpcode(); + BinaryOperator::BinaryOps Opc1 = B1->getOpcode(); + bool DropNSW = false; + if (ConstantsAreOp1 && Opc0 != Opc1) { + // TODO: We drop "nsw" if shift is converted into multiply because it may + // not be correct when the shift amount is BitWidth - 1. We could examine + // each vector element to determine if it is safe to keep that flag. + if (Opc0 == Instruction::Shl || Opc1 == Instruction::Shl) + DropNSW = true; + if (BinopElts AltB0 = getAlternateBinop(B0, DL)) { + assert(isa<Constant>(AltB0.Op1) && "Expecting constant with alt binop"); + Opc0 = AltB0.Opcode; + C0 = cast<Constant>(AltB0.Op1); + } else if (BinopElts AltB1 = getAlternateBinop(B1, DL)) { + assert(isa<Constant>(AltB1.Op1) && "Expecting constant with alt binop"); + Opc1 = AltB1.Opcode; + C1 = cast<Constant>(AltB1.Op1); + } + } + + if (Opc0 != Opc1) + return nullptr; + + // The opcodes must be the same. Use a new name to make that clear. + BinaryOperator::BinaryOps BOpc = Opc0; + + // Select the constant elements needed for the single binop. + Constant *Mask = Shuf.getMask(); + Constant *NewC = ConstantExpr::getShuffleVector(C0, C1, Mask); + + // We are moving a binop after a shuffle. When a shuffle has an undefined + // mask element, the result is undefined, but it is not poison or undefined + // behavior. That is not necessarily true for div/rem/shift. + bool MightCreatePoisonOrUB = + Mask->containsUndefElement() && + (Instruction::isIntDivRem(BOpc) || Instruction::isShift(BOpc)); + if (MightCreatePoisonOrUB) + NewC = getSafeVectorConstantForBinop(BOpc, NewC, ConstantsAreOp1); + + Value *V; + if (X == Y) { + // Remove a binop and the shuffle by rearranging the constant: + // shuffle (op V, C0), (op V, C1), M --> op V, C' + // shuffle (op C0, V), (op C1, V), M --> op C', V + V = X; + } else { + // If there are 2 different variable operands, we must create a new shuffle + // (select) first, so check uses to ensure that we don't end up with more + // instructions than we started with. + if (!B0->hasOneUse() && !B1->hasOneUse()) + return nullptr; + + // If we use the original shuffle mask and op1 is *variable*, we would be + // putting an undef into operand 1 of div/rem/shift. This is either UB or + // poison. We do not have to guard against UB when *constants* are op1 + // because safe constants guarantee that we do not overflow sdiv/srem (and + // there's no danger for other opcodes). + // TODO: To allow this case, create a new shuffle mask with no undefs. + if (MightCreatePoisonOrUB && !ConstantsAreOp1) + return nullptr; + + // Note: In general, we do not create new shuffles in InstCombine because we + // do not know if a target can lower an arbitrary shuffle optimally. 
In this + // case, the shuffle uses the existing mask, so there is no additional risk. + + // Select the variable vectors first, then perform the binop: + // shuffle (op X, C0), (op Y, C1), M --> op (shuffle X, Y, M), C' + // shuffle (op C0, X), (op C1, Y), M --> op C', (shuffle X, Y, M) + V = Builder.CreateShuffleVector(X, Y, Mask); + } + + Instruction *NewBO = ConstantsAreOp1 ? BinaryOperator::Create(BOpc, V, NewC) : + BinaryOperator::Create(BOpc, NewC, V); + + // Flags are intersected from the 2 source binops. But there are 2 exceptions: + // 1. If we changed an opcode, poison conditions might have changed. + // 2. If the shuffle had undef mask elements, the new binop might have undefs + // where the original code did not. But if we already made a safe constant, + // then there's no danger. + NewBO->copyIRFlags(B0); + NewBO->andIRFlags(B1); + if (DropNSW) + NewBO->setHasNoSignedWrap(false); + if (Mask->containsUndefElement() && !MightCreatePoisonOrUB) + NewBO->dropPoisonGeneratingFlags(); + return NewBO; +} + Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { Value *LHS = SVI.getOperand(0); Value *RHS = SVI.getOperand(1); @@ -1150,6 +1360,9 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { LHS, RHS, SVI.getMask(), SVI.getType(), SQ.getWithInstruction(&SVI))) return replaceInstUsesWith(SVI, V); + if (Instruction *I = foldSelectShuffle(SVI, Builder, DL)) + return I; + bool MadeChange = false; unsigned VWidth = SVI.getType()->getVectorNumElements(); diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index b332e75c7feb..12fcc8752ea9 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -34,6 +34,8 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm-c/Initialization.h" +#include "llvm-c/Transforms/InstCombine.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -55,6 +57,7 @@ #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -72,6 +75,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" @@ -93,8 +97,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombine.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" -#include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> #include <cstdint> @@ -144,12 +146,20 @@ Value *InstCombiner::EmitGEPOffset(User *GEP) { /// We don't want to convert from a legal to an illegal type or from a smaller /// to a larger illegal type. A width of '1' is always treated as a legal type /// because i1 is a fundamental type in IR, and there are many specialized -/// optimizations for i1 types. +/// optimizations for i1 types. Widths of 8, 16 or 32 are equally treated as +/// legal to convert to, in order to open up more combining opportunities. 
+/// NOTE: this treats i8, i16 and i32 specially, due to them being so common +/// from frontend languages. bool InstCombiner::shouldChangeType(unsigned FromWidth, unsigned ToWidth) const { bool FromLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth); bool ToLegal = ToWidth == 1 || DL.isLegalInteger(ToWidth); + // Convert to widths of 8, 16 or 32 even if they are not legal types. Only + // shrink types, to prevent infinite loops. + if (ToWidth < FromWidth && (ToWidth == 8 || ToWidth == 16 || ToWidth == 32)) + return true; + // If this is a legal integer from type, and the result would be an illegal // type, don't do the transformation. if (FromLegal && !ToLegal) @@ -396,28 +406,23 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { // Transform: "(A op C1) op (B op C2)" ==> "(A op B) op (C1 op C2)" // if C1 and C2 are constants. + Value *A, *B; + Constant *C1, *C2; if (Op0 && Op1 && Op0->getOpcode() == Opcode && Op1->getOpcode() == Opcode && - isa<Constant>(Op0->getOperand(1)) && - isa<Constant>(Op1->getOperand(1)) && - Op0->hasOneUse() && Op1->hasOneUse()) { - Value *A = Op0->getOperand(0); - Constant *C1 = cast<Constant>(Op0->getOperand(1)); - Value *B = Op1->getOperand(0); - Constant *C2 = cast<Constant>(Op1->getOperand(1)); - - Constant *Folded = ConstantExpr::get(Opcode, C1, C2); - BinaryOperator *New = BinaryOperator::Create(Opcode, A, B); - if (isa<FPMathOperator>(New)) { + match(Op0, m_OneUse(m_BinOp(m_Value(A), m_Constant(C1)))) && + match(Op1, m_OneUse(m_BinOp(m_Value(B), m_Constant(C2))))) { + BinaryOperator *NewBO = BinaryOperator::Create(Opcode, A, B); + if (isa<FPMathOperator>(NewBO)) { FastMathFlags Flags = I.getFastMathFlags(); Flags &= Op0->getFastMathFlags(); Flags &= Op1->getFastMathFlags(); - New->setFastMathFlags(Flags); + NewBO->setFastMathFlags(Flags); } - InsertNewInstWith(New, I); - New->takeName(Op1); - I.setOperand(0, New); - I.setOperand(1, Folded); + InsertNewInstWith(NewBO, I); + NewBO->takeName(Op1); + I.setOperand(0, NewBO); + I.setOperand(1, ConstantExpr::get(Opcode, C1, C2)); // Conservatively clear the optional flags, since they may not be // preserved by the reassociation. ClearSubclassDataAfterReassociation(I); @@ -434,72 +439,38 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { /// Return whether "X LOp (Y ROp Z)" is always equal to /// "(X LOp Y) ROp (X LOp Z)". -static bool LeftDistributesOverRight(Instruction::BinaryOps LOp, +static bool leftDistributesOverRight(Instruction::BinaryOps LOp, Instruction::BinaryOps ROp) { - switch (LOp) { - default: - return false; + // X & (Y | Z) <--> (X & Y) | (X & Z) + // X & (Y ^ Z) <--> (X & Y) ^ (X & Z) + if (LOp == Instruction::And) + return ROp == Instruction::Or || ROp == Instruction::Xor; - case Instruction::And: - // And distributes over Or and Xor. - switch (ROp) { - default: - return false; - case Instruction::Or: - case Instruction::Xor: - return true; - } + // X | (Y & Z) <--> (X | Y) & (X | Z) + if (LOp == Instruction::Or) + return ROp == Instruction::And; - case Instruction::Mul: - // Multiplication distributes over addition and subtraction. - switch (ROp) { - default: - return false; - case Instruction::Add: - case Instruction::Sub: - return true; - } + // X * (Y + Z) <--> (X * Y) + (X * Z) + // X * (Y - Z) <--> (X * Y) - (X * Z) + if (LOp == Instruction::Mul) + return ROp == Instruction::Add || ROp == Instruction::Sub; - case Instruction::Or: - // Or distributes over And. 
- switch (ROp) { - default: - return false; - case Instruction::And: - return true; - } - } + return false; } /// Return whether "(X LOp Y) ROp Z" is always equal to /// "(X ROp Z) LOp (Y ROp Z)". -static bool RightDistributesOverLeft(Instruction::BinaryOps LOp, +static bool rightDistributesOverLeft(Instruction::BinaryOps LOp, Instruction::BinaryOps ROp) { if (Instruction::isCommutative(ROp)) - return LeftDistributesOverRight(ROp, LOp); + return leftDistributesOverRight(ROp, LOp); + + // (X {&|^} Y) >> Z <--> (X >> Z) {&|^} (Y >> Z) for all shifts. + return Instruction::isBitwiseLogicOp(LOp) && Instruction::isShift(ROp); - switch (LOp) { - default: - return false; - // (X >> Z) & (Y >> Z) -> (X&Y) >> Z for all shifts. - // (X >> Z) | (Y >> Z) -> (X|Y) >> Z for all shifts. - // (X >> Z) ^ (Y >> Z) -> (X^Y) >> Z for all shifts. - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: - switch (ROp) { - default: - return false; - case Instruction::Shl: - case Instruction::LShr: - case Instruction::AShr: - return true; - } - } // TODO: It would be nice to handle division, aka "(X + Y)/Z = X/Z + Y/Z", // but this requires knowing that the addition does not overflow and other // such subtleties. - return false; } /// This function returns identity value for given opcode, which can be used to @@ -511,37 +482,27 @@ static Value *getIdentityValue(Instruction::BinaryOps Opcode, Value *V) { return ConstantExpr::getBinOpIdentity(Opcode, V->getType()); } -/// This function factors binary ops which can be combined using distributive -/// laws. This function tries to transform 'Op' based TopLevelOpcode to enable -/// factorization e.g for ADD(SHL(X , 2), MUL(X, 5)), When this function called -/// with TopLevelOpcode == Instruction::Add and Op = SHL(X, 2), transforms -/// SHL(X, 2) to MUL(X, 4) i.e. returns Instruction::Mul with LHS set to 'X' and -/// RHS to 4. +/// This function predicates factorization using distributive laws. By default, +/// it just returns the 'Op' inputs. But for special-cases like +/// 'add(shl(X, 5), ...)', this function will have TopOpcode == Instruction::Add +/// and Op = shl(X, 5). The 'shl' is treated as the more general 'mul X, 32' to +/// allow more factorization opportunities. static Instruction::BinaryOps -getBinOpsForFactorization(Instruction::BinaryOps TopLevelOpcode, - BinaryOperator *Op, Value *&LHS, Value *&RHS) { +getBinOpsForFactorization(Instruction::BinaryOps TopOpcode, BinaryOperator *Op, + Value *&LHS, Value *&RHS) { assert(Op && "Expected a binary operator"); - LHS = Op->getOperand(0); RHS = Op->getOperand(1); - - switch (TopLevelOpcode) { - default: - return Op->getOpcode(); - - case Instruction::Add: - case Instruction::Sub: - if (Op->getOpcode() == Instruction::Shl) { - if (Constant *CST = dyn_cast<Constant>(Op->getOperand(1))) { - // The multiplier is really 1 << CST. - RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), CST); - return Instruction::Mul; - } + if (TopOpcode == Instruction::Add || TopOpcode == Instruction::Sub) { + Constant *C; + if (match(Op, m_Shl(m_Value(), m_Constant(C)))) { + // X << C --> X * (1 << C) + RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), C); + return Instruction::Mul; } - return Op->getOpcode(); + // TODO: We can add other conversions e.g. shr => div etc. } - - // TODO: We can add other conversions e.g. shr => div etc. 
+ return Op->getOpcode(); } /// This tries to simplify binary operations by factorizing out common terms @@ -560,7 +521,7 @@ Value *InstCombiner::tryFactorization(BinaryOperator &I, bool InnerCommutative = Instruction::isCommutative(InnerOpcode); // Does "X op' (Y op Z)" always equal "(X op' Y) op (X op' Z)"? - if (LeftDistributesOverRight(InnerOpcode, TopLevelOpcode)) + if (leftDistributesOverRight(InnerOpcode, TopLevelOpcode)) // Does the instruction have the form "(A op' B) op (A op' D)" or, in the // commutative case, "(A op' B) op (C op' A)"? if (A == C || (InnerCommutative && A == D)) { @@ -579,7 +540,7 @@ Value *InstCombiner::tryFactorization(BinaryOperator &I, } // Does "(X op Y) op' Z" always equal "(X op' Z) op (Y op' Z)"? - if (!SimplifiedInst && RightDistributesOverLeft(TopLevelOpcode, InnerOpcode)) + if (!SimplifiedInst && rightDistributesOverLeft(TopLevelOpcode, InnerOpcode)) // Does the instruction have the form "(A op' B) op (C op' B)" or, in the // commutative case, "(A op' B) op (B op' D)"? if (B == D || (InnerCommutative && B == C)) { @@ -664,21 +625,19 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { // term. if (Op0) if (Value *Ident = getIdentityValue(LHSOpcode, RHS)) - if (Value *V = - tryFactorization(I, LHSOpcode, A, B, RHS, Ident)) + if (Value *V = tryFactorization(I, LHSOpcode, A, B, RHS, Ident)) return V; // The instruction has the form "(B) op (C op' D)". Try to factorize common // term. if (Op1) if (Value *Ident = getIdentityValue(RHSOpcode, LHS)) - if (Value *V = - tryFactorization(I, RHSOpcode, LHS, Ident, C, D)) + if (Value *V = tryFactorization(I, RHSOpcode, LHS, Ident, C, D)) return V; } // Expansion. - if (Op0 && RightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) { + if (Op0 && rightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) { // The instruction has the form "(A op' B) op C". See if expanding it out // to "(A op C) op' (B op C)" results in simplifications. Value *A = Op0->getOperand(0), *B = Op0->getOperand(1), *C = RHS; @@ -715,7 +674,7 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { } } - if (Op1 && LeftDistributesOverRight(TopLevelOpcode, Op1->getOpcode())) { + if (Op1 && leftDistributesOverRight(TopLevelOpcode, Op1->getOpcode())) { // The instruction has the form "A op (B op' C)". See if expanding it out // to "(A op B) op' (A op C)" results in simplifications. Value *A = LHS, *B = Op1->getOperand(0), *C = Op1->getOperand(1); @@ -817,23 +776,6 @@ Value *InstCombiner::dyn_castNegVal(Value *V) const { return nullptr; } -/// Given a 'fsub' instruction, return the RHS of the instruction if the LHS is -/// a constant negative zero (which is the 'negate' form). -Value *InstCombiner::dyn_castFNegVal(Value *V, bool IgnoreZeroSign) const { - if (BinaryOperator::isFNeg(V, IgnoreZeroSign)) - return BinaryOperator::getFNegArgument(V); - - // Constants can be considered to be negated values if they can be folded. 
- if (ConstantFP *C = dyn_cast<ConstantFP>(V)) - return ConstantExpr::getFNeg(C); - - if (ConstantDataVector *C = dyn_cast<ConstantDataVector>(V)) - if (C->getType()->getElementType()->isFloatingPointTy()) - return ConstantExpr::getFNeg(C); - - return nullptr; -} - static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, InstCombiner::BuilderTy &Builder) { if (auto *Cast = dyn_cast<CastInst>(&I)) @@ -1081,8 +1023,9 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { return replaceInstUsesWith(I, NewPN); } -Instruction *InstCombiner::foldOpWithConstantIntoOperand(BinaryOperator &I) { - assert(isa<Constant>(I.getOperand(1)) && "Unexpected operand type"); +Instruction *InstCombiner::foldBinOpIntoSelectOrPhi(BinaryOperator &I) { + if (!isa<Constant>(I.getOperand(1))) + return nullptr; if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) { if (Instruction *NewSel = FoldOpIntoSelect(I, Sel)) @@ -1107,7 +1050,7 @@ Type *InstCombiner::FindElementAtOffset(PointerType *PtrTy, int64_t Offset, // Start with the index over the outer type. Note that the type size // might be zero (even if the offset isn't zero) if the indexed type // is something like [0 x {int, int}] - Type *IntPtrTy = DL.getIntPtrType(PtrTy); + Type *IndexTy = DL.getIndexType(PtrTy); int64_t FirstIdx = 0; if (int64_t TySize = DL.getTypeAllocSize(Ty)) { FirstIdx = Offset/TySize; @@ -1122,7 +1065,7 @@ Type *InstCombiner::FindElementAtOffset(PointerType *PtrTy, int64_t Offset, assert((uint64_t)Offset < (uint64_t)TySize && "Out of range offset"); } - NewIndices.push_back(ConstantInt::get(IntPtrTy, FirstIdx)); + NewIndices.push_back(ConstantInt::get(IndexTy, FirstIdx)); // Index into the types. If we fail, set OrigBase to null. while (Offset) { @@ -1144,7 +1087,7 @@ Type *InstCombiner::FindElementAtOffset(PointerType *PtrTy, int64_t Offset, } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) { uint64_t EltSize = DL.getTypeAllocSize(AT->getElementType()); assert(EltSize && "Cannot index into a zero-sized array"); - NewIndices.push_back(ConstantInt::get(IntPtrTy,Offset/EltSize)); + NewIndices.push_back(ConstantInt::get(IndexTy,Offset/EltSize)); Offset %= EltSize; Ty = AT->getElementType(); } else { @@ -1408,22 +1351,7 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { } while (true); } -/// \brief Creates node of binary operation with the same attributes as the -/// specified one but with other operands. -static Value *CreateBinOpAsGiven(BinaryOperator &Inst, Value *LHS, Value *RHS, - InstCombiner::BuilderTy &B) { - Value *BO = B.CreateBinOp(Inst.getOpcode(), LHS, RHS); - // If LHS and RHS are constant, BO won't be a binary operator. - if (BinaryOperator *NewBO = dyn_cast<BinaryOperator>(BO)) - NewBO->copyIRFlags(&Inst); - return BO; -} - -/// \brief Makes transformation of binary operation specific for vector types. -/// \param Inst Binary operator to transform. -/// \return Pointer to node that must replace the original binary operator, or -/// null pointer if no transformation was made. -Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { +Instruction *InstCombiner::foldShuffledBinop(BinaryOperator &Inst) { if (!Inst.getType()->isVectorTy()) return nullptr; // It may not be safe to reorder shuffles and things like div, urem, etc. 
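The body of foldShuffledBinop follows in the next hunk. The reordering it performs, and the div/rem hazard the comment above warns about, can be seen lane by lane in a small sketch (assumed values, scalars standing in for vector lanes):

#include <cassert>

int main() {
  int V1[2] = {10, 20}, V2[2] = {2, 5};
  int M[2] = {1, 0}; // shuffle mask

  // Op(shuffle(V1, M), shuffle(V2, M)) == shuffle(Op(V1, V2), M), lane by lane.
  int Sum[2] = {V1[0] + V2[0], V1[1] + V2[1]};
  for (int i = 0; i < 2; ++i)
    assert(V1[M[i]] + V2[M[i]] == Sum[M[i]]);

  // The hazard for div/rem: running the binop first evaluates lanes the
  // shuffle would have discarded, so a 0 (or undef) divisor in an unused
  // lane can introduce UB that the original code never had.
  return 0;
}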
@@ -1437,58 +1365,71 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { assert(cast<VectorType>(LHS->getType())->getNumElements() == VWidth); assert(cast<VectorType>(RHS->getType())->getNumElements() == VWidth); + auto createBinOpShuffle = [&](Value *X, Value *Y, Constant *M) { + Value *XY = Builder.CreateBinOp(Inst.getOpcode(), X, Y); + if (auto *BO = dyn_cast<BinaryOperator>(XY)) + BO->copyIRFlags(&Inst); + return new ShuffleVectorInst(XY, UndefValue::get(XY->getType()), M); + }; + // If both arguments of the binary operation are shuffles that use the same - // mask and shuffle within a single vector, move the shuffle after the binop: - // Op(shuffle(v1, m), shuffle(v2, m)) -> shuffle(Op(v1, v2), m) - auto *LShuf = dyn_cast<ShuffleVectorInst>(LHS); - auto *RShuf = dyn_cast<ShuffleVectorInst>(RHS); - if (LShuf && RShuf && LShuf->getMask() == RShuf->getMask() && - isa<UndefValue>(LShuf->getOperand(1)) && - isa<UndefValue>(RShuf->getOperand(1)) && - LShuf->getOperand(0)->getType() == RShuf->getOperand(0)->getType()) { - Value *NewBO = CreateBinOpAsGiven(Inst, LShuf->getOperand(0), - RShuf->getOperand(0), Builder); - return Builder.CreateShuffleVector( - NewBO, UndefValue::get(NewBO->getType()), LShuf->getMask()); + // mask and shuffle within a single vector, move the shuffle after the binop. + Value *V1, *V2; + Constant *Mask; + if (match(LHS, m_ShuffleVector(m_Value(V1), m_Undef(), m_Constant(Mask))) && + match(RHS, m_ShuffleVector(m_Value(V2), m_Undef(), m_Specific(Mask))) && + V1->getType() == V2->getType() && + (LHS->hasOneUse() || RHS->hasOneUse() || LHS == RHS)) { + // Op(shuffle(V1, Mask), shuffle(V2, Mask)) -> shuffle(Op(V1, V2), Mask) + return createBinOpShuffle(V1, V2, Mask); } - // If one argument is a shuffle within one vector, the other is a constant, - // try moving the shuffle after the binary operation. - ShuffleVectorInst *Shuffle = nullptr; - Constant *C1 = nullptr; - if (isa<ShuffleVectorInst>(LHS)) Shuffle = cast<ShuffleVectorInst>(LHS); - if (isa<ShuffleVectorInst>(RHS)) Shuffle = cast<ShuffleVectorInst>(RHS); - if (isa<Constant>(LHS)) C1 = cast<Constant>(LHS); - if (isa<Constant>(RHS)) C1 = cast<Constant>(RHS); - if (Shuffle && C1 && - (isa<ConstantVector>(C1) || isa<ConstantDataVector>(C1)) && - isa<UndefValue>(Shuffle->getOperand(1)) && - Shuffle->getType() == Shuffle->getOperand(0)->getType()) { - SmallVector<int, 16> ShMask = Shuffle->getShuffleMask(); - // Find constant C2 that has property: - // shuffle(C2, ShMask) = C1 - // If such constant does not exist (example: ShMask=<0,0> and C1=<1,2>) - // reorder is not possible. - SmallVector<Constant*, 16> C2M(VWidth, - UndefValue::get(C1->getType()->getScalarType())); + // If one argument is a shuffle within one vector and the other is a constant, + // try moving the shuffle after the binary operation. This canonicalization + // intends to move shuffles closer to other shuffles and binops closer to + // other binops, so they can be folded. It may also enable demanded elements + // transforms. + Constant *C; + if (match(&Inst, m_c_BinOp( + m_OneUse(m_ShuffleVector(m_Value(V1), m_Undef(), m_Constant(Mask))), + m_Constant(C))) && + V1->getType() == Inst.getType()) { + // Find constant NewC that has property: + // shuffle(NewC, ShMask) = C + // If such constant does not exist (example: ShMask=<0,0> and C=<1,2>) + // reorder is not possible. A 1-to-1 mapping is not required. 
Example: + // ShMask = <1,1,2,2> and C = <5,5,6,6> --> NewC = <undef,5,6,undef> + SmallVector<int, 16> ShMask; + ShuffleVectorInst::getShuffleMask(Mask, ShMask); + SmallVector<Constant *, 16> + NewVecC(VWidth, UndefValue::get(C->getType()->getScalarType())); bool MayChange = true; for (unsigned I = 0; I < VWidth; ++I) { if (ShMask[I] >= 0) { assert(ShMask[I] < (int)VWidth); - if (!isa<UndefValue>(C2M[ShMask[I]])) { + Constant *CElt = C->getAggregateElement(I); + Constant *NewCElt = NewVecC[ShMask[I]]; + if (!CElt || (!isa<UndefValue>(NewCElt) && NewCElt != CElt)) { MayChange = false; break; } - C2M[ShMask[I]] = C1->getAggregateElement(I); + NewVecC[ShMask[I]] = CElt; } } if (MayChange) { - Constant *C2 = ConstantVector::get(C2M); - Value *NewLHS = isa<Constant>(LHS) ? C2 : Shuffle->getOperand(0); - Value *NewRHS = isa<Constant>(LHS) ? Shuffle->getOperand(0) : C2; - Value *NewBO = CreateBinOpAsGiven(Inst, NewLHS, NewRHS, Builder); - return Builder.CreateShuffleVector(NewBO, - UndefValue::get(Inst.getType()), Shuffle->getMask()); + Constant *NewC = ConstantVector::get(NewVecC); + // It may not be safe to execute a binop on a vector with undef elements + // because the entire instruction can be folded to undef or create poison + // that did not exist in the original code. + bool ConstOp1 = isa<Constant>(Inst.getOperand(1)); + if (Inst.isIntDivRem() || (Inst.isShift() && ConstOp1)) + NewC = getSafeVectorConstantForBinop(Inst.getOpcode(), NewC, ConstOp1); + + // Op(shuffle(V1, Mask), C) -> shuffle(Op(V1, NewC), Mask) + // Op(C, shuffle(V1, Mask)) -> shuffle(Op(NewC, V1), Mask) + Value *NewLHS = isa<Constant>(LHS) ? NewC : V1; + Value *NewRHS = isa<Constant>(LHS) ? V1 : NewC; + return createBinOpShuffle(NewLHS, NewRHS, Mask); } } @@ -1497,9 +1438,9 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end()); - - if (Value *V = SimplifyGEPInst(GEP.getSourceElementType(), Ops, - SQ.getWithInstruction(&GEP))) + Type *GEPType = GEP.getType(); + Type *GEPEltType = GEP.getSourceElementType(); + if (Value *V = SimplifyGEPInst(GEPEltType, Ops, SQ.getWithInstruction(&GEP))) return replaceInstUsesWith(GEP, V); Value *PtrOp = GEP.getOperand(0); @@ -1507,8 +1448,11 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Eliminate unneeded casts for indices, and replace indices which displace // by multiples of a zero size type with zero. bool MadeChange = false; - Type *IntPtrTy = - DL.getIntPtrType(GEP.getPointerOperandType()->getScalarType()); + + // Index width may not be the same width as pointer width. + // Data layout chooses the right type based on supported integer types. + Type *NewScalarIndexTy = + DL.getIndexType(GEP.getPointerOperandType()->getScalarType()); gep_type_iterator GTI = gep_type_begin(GEP); for (User::op_iterator I = GEP.op_begin() + 1, E = GEP.op_end(); I != E; @@ -1517,10 +1461,11 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (GTI.isStruct()) continue; - // Index type should have the same width as IntPtr Type *IndexTy = (*I)->getType(); - Type *NewIndexType = IndexTy->isVectorTy() ? - VectorType::get(IntPtrTy, IndexTy->getVectorNumElements()) : IntPtrTy; + Type *NewIndexType = + IndexTy->isVectorTy() + ? 
VectorType::get(NewScalarIndexTy, IndexTy->getVectorNumElements())
+ : NewScalarIndexTy;
// If the element type has zero size then any index over it is equivalent
// to an index of zero, so replace it with zero if it is not zero already.
@@ -1543,8 +1488,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
return &GEP;
// Check to see if the inputs to the PHI node are getelementptr instructions.
- if (PHINode *PN = dyn_cast<PHINode>(PtrOp)) {
- GetElementPtrInst *Op1 = dyn_cast<GetElementPtrInst>(PN->getOperand(0));
+ if (auto *PN = dyn_cast<PHINode>(PtrOp)) {
+ auto *Op1 = dyn_cast<GetElementPtrInst>(PN->getOperand(0));
if (!Op1)
return nullptr;
@@ -1560,7 +1505,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
int DI = -1;
for (auto I = PN->op_begin()+1, E = PN->op_end(); I !=E; ++I) {
- GetElementPtrInst *Op2 = dyn_cast<GetElementPtrInst>(*I);
+ auto *Op2 = dyn_cast<GetElementPtrInst>(*I);
if (!Op2 || Op1->getNumOperands() != Op2->getNumOperands())
return nullptr;
@@ -1602,7 +1547,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
if (J > 0) {
if (J == 1) {
CurTy = Op1->getSourceElementType();
- } else if (CompositeType *CT = dyn_cast<CompositeType>(CurTy)) {
+ } else if (auto *CT = dyn_cast<CompositeType>(CurTy)) {
CurTy = CT->getTypeAtIndex(Op1->getOperand(J));
} else {
CurTy = nullptr;
@@ -1617,7 +1562,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
if (DI != -1 && !PN->hasOneUse())
return nullptr;
- GetElementPtrInst *NewGEP = cast<GetElementPtrInst>(Op1->clone());
+ auto *NewGEP = cast<GetElementPtrInst>(Op1->clone());
if (DI == -1) {
// All the GEPs feeding the PHI are identical. Clone one down into our
// BB so that it can be merged with the current GEP.
@@ -1652,15 +1597,64 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
// Combine Indices - If the source pointer to this getelementptr instruction
// is a getelementptr instruction, combine the indices of the two
// getelementptr instructions into a single instruction.
- if (GEPOperator *Src = dyn_cast<GEPOperator>(PtrOp)) {
+ if (auto *Src = dyn_cast<GEPOperator>(PtrOp)) {
if (!shouldMergeGEPs(*cast<GEPOperator>(&GEP), *Src))
return nullptr;
+ // Try to reassociate loop invariant GEP chains to enable LICM.
+ if (LI && Src->getNumOperands() == 2 && GEP.getNumOperands() == 2 &&
+ Src->hasOneUse()) {
+ if (Loop *L = LI->getLoopFor(GEP.getParent())) {
+ Value *GO1 = GEP.getOperand(1);
+ Value *SO1 = Src->getOperand(1);
+ // Reassociate the two GEPs if SO1 is variant in the loop and GO1 is
+ // invariant: this breaks the dependence between GEPs and allows LICM
+ // to hoist the invariant part out of the loop.
+ if (L->isLoopInvariant(GO1) && !L->isLoopInvariant(SO1)) {
+ // We have to be careful here.
+ // We have something like:
+ // %src = getelementptr <ty>, <ty>* %base, <ty> %idx
+ // %gep = getelementptr <ty>, <ty>* %src, <ty> %idx2
+ // If we just swap idx & idx2 then we could inadvertently
+ // change %src from a vector to a scalar, or vice versa.
+ // Cases:
+ // 1) %base a scalar & idx a scalar & idx2 a vector
+ // => Swapping idx & idx2 turns %src into a vector type.
+ // 2) %base a scalar & idx a vector & idx2 a scalar
+ // => Swapping idx & idx2 turns %src into a scalar type.
+ // 3) %base, %idx, and %idx2 are scalars
+ // => %src & %gep are scalars
+ // => swapping idx & idx2 is safe
+ // 4) %base a vector
+ // => %src is a vector
+ // => swapping idx & idx2 is safe.
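In plain pointer arithmetic, the payoff of the swap in the scalar cases is that the invariant half of the addressing becomes computable outside the loop; a small check of that reassociation (names and values are illustrative):

#include <cassert>

int main() {
  int buf[16] = {0};
  int *base = buf;
  long inv = 3;              // loop-invariant index
  int *hoisted = base + inv; // computable once, before the loop
  for (long i = 0; i < 4; ++i)
    assert((base + i) + inv == hoisted + i); // same address after the swap
  return 0;
}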
+ auto *SO0 = Src->getOperand(0); + auto *SO0Ty = SO0->getType(); + if (!isa<VectorType>(GEPType) || // case 3 + isa<VectorType>(SO0Ty)) { // case 4 + Src->setOperand(1, GO1); + GEP.setOperand(1, SO1); + return &GEP; + } else { + // Case 1 or 2 + // -- have to recreate %src & %gep + // put NewSrc at same location as %src + Builder.SetInsertPoint(cast<Instruction>(PtrOp)); + auto *NewSrc = cast<GetElementPtrInst>( + Builder.CreateGEP(SO0, GO1, Src->getName())); + NewSrc->setIsInBounds(Src->isInBounds()); + auto *NewGEP = GetElementPtrInst::Create(nullptr, NewSrc, {SO1}); + NewGEP->setIsInBounds(GEP.isInBounds()); + return NewGEP; + } + } + } + } + // Note that if our source is a gep chain itself then we wait for that // chain to be resolved before we perform this transformation. This // avoids us creating a TON of code in some cases. - if (GEPOperator *SrcGEP = - dyn_cast<GEPOperator>(Src->getOperand(0))) + if (auto *SrcGEP = dyn_cast<GEPOperator>(Src->getOperand(0))) if (SrcGEP->getNumOperands() == 2 && shouldMergeGEPs(*Src, *SrcGEP)) return nullptr; // Wait until our source is folded to completion. @@ -1723,9 +1717,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (GEP.getNumIndices() == 1) { unsigned AS = GEP.getPointerAddressSpace(); if (GEP.getOperand(1)->getType()->getScalarSizeInBits() == - DL.getPointerSizeInBits(AS)) { - Type *Ty = GEP.getSourceElementType(); - uint64_t TyAllocSize = DL.getTypeAllocSize(Ty); + DL.getIndexSizeInBits(AS)) { + uint64_t TyAllocSize = DL.getTypeAllocSize(GEPEltType); bool Matched = false; uint64_t C; @@ -1752,22 +1745,20 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { Operator *Index = cast<Operator>(V); Value *PtrToInt = Builder.CreatePtrToInt(PtrOp, Index->getType()); Value *NewSub = Builder.CreateSub(PtrToInt, Index->getOperand(1)); - return CastInst::Create(Instruction::IntToPtr, NewSub, GEP.getType()); + return CastInst::Create(Instruction::IntToPtr, NewSub, GEPType); } // Canonicalize (gep i8* X, (ptrtoint Y)-(ptrtoint X)) // to (bitcast Y) Value *Y; if (match(V, m_Sub(m_PtrToInt(m_Value(Y)), - m_PtrToInt(m_Specific(GEP.getOperand(0)))))) { - return CastInst::CreatePointerBitCastOrAddrSpaceCast(Y, - GEP.getType()); - } + m_PtrToInt(m_Specific(GEP.getOperand(0)))))) + return CastInst::CreatePointerBitCastOrAddrSpaceCast(Y, GEPType); } } } // We do not handle pointer-vector geps here. - if (GEP.getType()->isVectorTy()) + if (GEPType->isVectorTy()) return nullptr; // Handle gep(bitcast x) and gep(gep x, 0, 0, 0). @@ -1776,7 +1767,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (StrippedPtr != PtrOp) { bool HasZeroPointerIndex = false; - if (ConstantInt *C = dyn_cast<ConstantInt>(GEP.getOperand(1))) + if (auto *C = dyn_cast<ConstantInt>(GEP.getOperand(1))) HasZeroPointerIndex = C->isZero(); // Transform: GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ... @@ -1787,8 +1778,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // // This occurs when the program declares an array extern like "int X[];" if (HasZeroPointerIndex) { - if (ArrayType *CATy = - dyn_cast<ArrayType>(GEP.getSourceElementType())) { + if (auto *CATy = dyn_cast<ArrayType>(GEPEltType)) { // GEP (bitcast i8* X to [0 x i8]*), i32 0, ... ? if (CATy->getElementType() == StrippedPtrTy->getElementType()) { // -> GEP i8* X, ... 
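The gep i8* X, (ptrtoint Y)-(ptrtoint X) --> bitcast Y canonicalization seen earlier in this hunk is the byte-level identity below, assuming ordinary flat address arithmetic:

#include <cassert>
#include <cstdint>

int main() {
  char storage[8];
  char *X = storage, *Y = storage + 5;
  intptr_t Diff = (intptr_t)Y - (intptr_t)X; // (ptrtoint Y) - (ptrtoint X)
  assert(X + Diff == Y); // the gep lands exactly on Y: only a cast remains
  return 0;
}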
@@ -1804,11 +1794,10 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // -> // %0 = GEP i8 addrspace(1)* X, ... // addrspacecast i8 addrspace(1)* %0 to i8* - return new AddrSpaceCastInst(Builder.Insert(Res), GEP.getType()); + return new AddrSpaceCastInst(Builder.Insert(Res), GEPType); } - if (ArrayType *XATy = - dyn_cast<ArrayType>(StrippedPtrTy->getElementType())){ + if (auto *XATy = dyn_cast<ArrayType>(StrippedPtrTy->getElementType())) { // GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ... ? if (CATy->getElementType() == XATy->getElementType()) { // -> GEP [10 x i8]* X, i32 0, ... @@ -1836,7 +1825,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { nullptr, StrippedPtr, Idx, GEP.getName()) : Builder.CreateGEP(nullptr, StrippedPtr, Idx, GEP.getName()); - return new AddrSpaceCastInst(NewGEP, GEP.getType()); + return new AddrSpaceCastInst(NewGEP, GEPType); } } } @@ -1844,12 +1833,11 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Transform things like: // %t = getelementptr i32* bitcast ([2 x i32]* %str to i32*), i32 %V // into: %t1 = getelementptr [2 x i32]* %str, i32 0, i32 %V; bitcast - Type *SrcElTy = StrippedPtrTy->getElementType(); - Type *ResElTy = GEP.getSourceElementType(); - if (SrcElTy->isArrayTy() && - DL.getTypeAllocSize(SrcElTy->getArrayElementType()) == - DL.getTypeAllocSize(ResElTy)) { - Type *IdxType = DL.getIntPtrType(GEP.getType()); + Type *SrcEltTy = StrippedPtrTy->getElementType(); + if (SrcEltTy->isArrayTy() && + DL.getTypeAllocSize(SrcEltTy->getArrayElementType()) == + DL.getTypeAllocSize(GEPEltType)) { + Type *IdxType = DL.getIndexType(GEPType); Value *Idx[2] = { Constant::getNullValue(IdxType), GEP.getOperand(1) }; Value *NewGEP = GEP.isInBounds() @@ -1858,28 +1846,28 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { : Builder.CreateGEP(nullptr, StrippedPtr, Idx, GEP.getName()); // V and GEP are both pointer types --> BitCast - return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, - GEP.getType()); + return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, GEPType); } // Transform things like: // %V = mul i64 %N, 4 // %t = getelementptr i8* bitcast (i32* %arr to i8*), i32 %V // into: %t1 = getelementptr i32* %arr, i32 %N; bitcast - if (ResElTy->isSized() && SrcElTy->isSized()) { + if (GEPEltType->isSized() && SrcEltTy->isSized()) { // Check that changing the type amounts to dividing the index by a scale // factor. - uint64_t ResSize = DL.getTypeAllocSize(ResElTy); - uint64_t SrcSize = DL.getTypeAllocSize(SrcElTy); + uint64_t ResSize = DL.getTypeAllocSize(GEPEltType); + uint64_t SrcSize = DL.getTypeAllocSize(SrcEltTy); if (ResSize && SrcSize % ResSize == 0) { Value *Idx = GEP.getOperand(1); unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); uint64_t Scale = SrcSize / ResSize; - // Earlier transforms ensure that the index has type IntPtrType, which - // considerably simplifies the logic by eliminating implicit casts. - assert(Idx->getType() == DL.getIntPtrType(GEP.getType()) && - "Index not cast to pointer width?"); + // Earlier transforms ensure that the index has the right type + // according to Data Layout, which considerably simplifies the + // logic by eliminating implicit casts. 
+ assert(Idx->getType() == DL.getIndexType(GEPType) && + "Index type does not match the Data Layout preferences"); bool NSW; if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) { @@ -1895,7 +1883,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // The NewGEP must be pointer typed, so must the old one -> BitCast return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, - GEP.getType()); + GEPType); } } } @@ -1904,39 +1892,40 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // getelementptr i8* bitcast ([100 x double]* X to i8*), i32 %tmp // (where tmp = 8*tmp2) into: // getelementptr [100 x double]* %arr, i32 0, i32 %tmp2; bitcast - if (ResElTy->isSized() && SrcElTy->isSized() && SrcElTy->isArrayTy()) { + if (GEPEltType->isSized() && SrcEltTy->isSized() && + SrcEltTy->isArrayTy()) { // Check that changing to the array element type amounts to dividing the // index by a scale factor. - uint64_t ResSize = DL.getTypeAllocSize(ResElTy); + uint64_t ResSize = DL.getTypeAllocSize(GEPEltType); uint64_t ArrayEltSize = - DL.getTypeAllocSize(SrcElTy->getArrayElementType()); + DL.getTypeAllocSize(SrcEltTy->getArrayElementType()); if (ResSize && ArrayEltSize % ResSize == 0) { Value *Idx = GEP.getOperand(1); unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); uint64_t Scale = ArrayEltSize / ResSize; - // Earlier transforms ensure that the index has type IntPtrType, which - // considerably simplifies the logic by eliminating implicit casts. - assert(Idx->getType() == DL.getIntPtrType(GEP.getType()) && - "Index not cast to pointer width?"); + // Earlier transforms ensure that the index has the right type + // according to the Data Layout, which considerably simplifies + // the logic by eliminating implicit casts. + assert(Idx->getType() == DL.getIndexType(GEPType) && + "Index type does not match the Data Layout preferences"); bool NSW; if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) { // Successfully decomposed Idx as NewIdx * Scale, form a new GEP. // If the multiplication NewIdx * Scale may overflow then the new // GEP may not be "inbounds". - Value *Off[2] = { - Constant::getNullValue(DL.getIntPtrType(GEP.getType())), - NewIdx}; + Type *IndTy = DL.getIndexType(GEPType); + Value *Off[2] = {Constant::getNullValue(IndTy), NewIdx}; Value *NewGEP = GEP.isInBounds() && NSW ? Builder.CreateInBoundsGEP( - SrcElTy, StrippedPtr, Off, GEP.getName()) - : Builder.CreateGEP(SrcElTy, StrippedPtr, Off, + SrcEltTy, StrippedPtr, Off, GEP.getName()) + : Builder.CreateGEP(SrcEltTy, StrippedPtr, Off, GEP.getName()); // The NewGEP must be pointer typed, so must the old one -> BitCast return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, - GEP.getType()); + GEPType); } } } @@ -1946,34 +1935,53 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // addrspacecast between types is canonicalized as a bitcast, then an // addrspacecast. To take advantage of the below bitcast + struct GEP, look // through the addrspacecast. - if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(PtrOp)) { + Value *ASCStrippedPtrOp = PtrOp; + if (auto *ASC = dyn_cast<AddrSpaceCastInst>(PtrOp)) { // X = bitcast A addrspace(1)* to B addrspace(1)* // Y = addrspacecast A addrspace(1)* to B addrspace(2)* // Z = gep Y, <...constant indices...> // Into an addrspacecasted GEP of the struct. 
- if (BitCastInst *BC = dyn_cast<BitCastInst>(ASC->getOperand(0))) - PtrOp = BC; + if (auto *BC = dyn_cast<BitCastInst>(ASC->getOperand(0))) + ASCStrippedPtrOp = BC; } - /// See if we can simplify: - /// X = bitcast A* to B* - /// Y = gep X, <...constant indices...> - /// into a gep of the original struct. This is important for SROA and alias - /// analysis of unions. If "A" is also a bitcast, wait for A/X to be merged. - if (BitCastInst *BCI = dyn_cast<BitCastInst>(PtrOp)) { - Value *Operand = BCI->getOperand(0); - PointerType *OpType = cast<PointerType>(Operand->getType()); - unsigned OffsetBits = DL.getPointerTypeSizeInBits(GEP.getType()); - APInt Offset(OffsetBits, 0); - if (!isa<BitCastInst>(Operand) && - GEP.accumulateConstantOffset(DL, Offset)) { + if (auto *BCI = dyn_cast<BitCastInst>(ASCStrippedPtrOp)) { + Value *SrcOp = BCI->getOperand(0); + PointerType *SrcType = cast<PointerType>(BCI->getSrcTy()); + Type *SrcEltType = SrcType->getElementType(); + + // GEP directly using the source operand if this GEP is accessing an element + // of a bitcasted pointer to vector or array of the same dimensions: + // gep (bitcast <c x ty>* X to [c x ty]*), Y, Z --> gep X, Y, Z + // gep (bitcast [c x ty]* X to <c x ty>*), Y, Z --> gep X, Y, Z + auto areMatchingArrayAndVecTypes = [](Type *ArrTy, Type *VecTy) { + return ArrTy->getArrayElementType() == VecTy->getVectorElementType() && + ArrTy->getArrayNumElements() == VecTy->getVectorNumElements(); + }; + if (GEP.getNumOperands() == 3 && + ((GEPEltType->isArrayTy() && SrcEltType->isVectorTy() && + areMatchingArrayAndVecTypes(GEPEltType, SrcEltType)) || + (GEPEltType->isVectorTy() && SrcEltType->isArrayTy() && + areMatchingArrayAndVecTypes(SrcEltType, GEPEltType)))) { + GEP.setOperand(0, SrcOp); + GEP.setSourceElementType(SrcEltType); + return &GEP; + } + // See if we can simplify: + // X = bitcast A* to B* + // Y = gep X, <...constant indices...> + // into a gep of the original struct. This is important for SROA and alias + // analysis of unions. If "A" is also a bitcast, wait for A/X to be merged. + unsigned OffsetBits = DL.getIndexTypeSizeInBits(GEPType); + APInt Offset(OffsetBits, 0); + if (!isa<BitCastInst>(SrcOp) && GEP.accumulateConstantOffset(DL, Offset)) { // If this GEP instruction doesn't move the pointer, just replace the GEP // with a bitcast of the real input to the dest type. if (!Offset) { // If the bitcast is of an allocation, and the allocation will be // converted to match the type of the cast, don't touch this. - if (isa<AllocaInst>(Operand) || isAllocationFn(Operand, &TLI)) { + if (isa<AllocaInst>(SrcOp) || isAllocationFn(SrcOp, &TLI)) { // See if the bitcast simplifies, if so, don't nuke this GEP yet. if (Instruction *I = visitBitCast(*BCI)) { if (I != BCI) { @@ -1985,43 +1993,43 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { } } - if (Operand->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) - return new AddrSpaceCastInst(Operand, GEP.getType()); - return new BitCastInst(Operand, GEP.getType()); + if (SrcType->getPointerAddressSpace() != GEP.getAddressSpace()) + return new AddrSpaceCastInst(SrcOp, GEPType); + return new BitCastInst(SrcOp, GEPType); } // Otherwise, if the offset is non-zero, we need to find out if there is a // field at Offset in 'A's type. If so, we can pull the cast through the // GEP. 
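GEP.accumulateConstantOffset, used above to decide whether the cast can be pulled through, boils down to summing scaled constant indices; a hand-computed equivalent for a two-element struct array (types and layout assumed typical):

#include <cassert>
#include <cstddef>

struct Pair { int a; int b; };

int main() {
  // gep %Pair, %Pair* P, i64 1, i32 1 accumulates a constant byte offset:
  size_t Offset = 1 * sizeof(Pair) + offsetof(Pair, b);
  Pair P[2];
  assert((char *)&P[1].b == (char *)P + Offset);
  return 0;
}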
SmallVector<Value*, 8> NewIndices; - if (FindElementAtOffset(OpType, Offset.getSExtValue(), NewIndices)) { + if (FindElementAtOffset(SrcType, Offset.getSExtValue(), NewIndices)) { Value *NGEP = GEP.isInBounds() - ? Builder.CreateInBoundsGEP(nullptr, Operand, NewIndices) - : Builder.CreateGEP(nullptr, Operand, NewIndices); + ? Builder.CreateInBoundsGEP(nullptr, SrcOp, NewIndices) + : Builder.CreateGEP(nullptr, SrcOp, NewIndices); - if (NGEP->getType() == GEP.getType()) + if (NGEP->getType() == GEPType) return replaceInstUsesWith(GEP, NGEP); NGEP->takeName(&GEP); if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) - return new AddrSpaceCastInst(NGEP, GEP.getType()); - return new BitCastInst(NGEP, GEP.getType()); + return new AddrSpaceCastInst(NGEP, GEPType); + return new BitCastInst(NGEP, GEPType); } } } if (!GEP.isInBounds()) { - unsigned PtrWidth = - DL.getPointerSizeInBits(PtrOp->getType()->getPointerAddressSpace()); - APInt BasePtrOffset(PtrWidth, 0); + unsigned IdxWidth = + DL.getIndexSizeInBits(PtrOp->getType()->getPointerAddressSpace()); + APInt BasePtrOffset(IdxWidth, 0); Value *UnderlyingPtrOp = PtrOp->stripAndAccumulateInBoundsConstantOffsets(DL, BasePtrOffset); if (auto *AI = dyn_cast<AllocaInst>(UnderlyingPtrOp)) { if (GEP.accumulateConstantOffset(DL, BasePtrOffset) && BasePtrOffset.isNonNegative()) { - APInt AllocSize(PtrWidth, DL.getTypeAllocSize(AI->getAllocatedType())); + APInt AllocSize(IdxWidth, DL.getTypeAllocSize(AI->getAllocatedType())); if (BasePtrOffset.ule(AllocSize)) { return GetElementPtrInst::CreateInBounds( PtrOp, makeArrayRef(Ops).slice(1), GEP.getName()); @@ -2198,7 +2206,7 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { return nullptr; } -/// \brief Move the call to free before a NULL test. +/// Move the call to free before a NULL test. /// /// Check if this free is accessed after its argument has been test /// against NULL (property 0). @@ -2562,6 +2570,7 @@ static bool isCatchAll(EHPersonality Personality, Constant *TypeInfo) { case EHPersonality::MSVC_Win64SEH: case EHPersonality::MSVC_CXX: case EHPersonality::CoreCLR: + case EHPersonality::Wasm_CXX: return TypeInfo->isNullValue(); } llvm_unreachable("invalid enum"); @@ -2889,6 +2898,7 @@ Instruction *InstCombiner::visitLandingPadInst(LandingPadInst &LI) { /// block. static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { assert(I->hasOneUse() && "Invariants didn't hold!"); + BasicBlock *SrcBlock = I->getParent(); // Cannot move control-flow-involving, volatile loads, vaarg, etc. if (isa<PHINode>(I) || I->isEHPad() || I->mayHaveSideEffects() || @@ -2918,10 +2928,20 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { if (Scan->mayWriteToMemory()) return false; } - BasicBlock::iterator InsertPos = DestBlock->getFirstInsertionPt(); I->moveBefore(&*InsertPos); ++NumSunkInst; + + // Also sink all related debug uses from the source basic block. Otherwise we + // get debug use before the def. + SmallVector<DbgInfoIntrinsic *, 1> DbgUsers; + findDbgUsers(DbgUsers, I); + for (auto *DII : DbgUsers) { + if (DII->getParent() == SrcBlock) { + DII->moveBefore(&*InsertPos); + LLVM_DEBUG(dbgs() << "SINK: " << *DII << '\n'); + } + } return true; } @@ -2932,7 +2952,7 @@ bool InstCombiner::run() { // Check to see if we can DCE the instruction. 
if (isInstructionTriviallyDead(I, &TLI)) {
- DEBUG(dbgs() << "IC: DCE: " << *I << '\n');
+ LLVM_DEBUG(dbgs() << "IC: DCE: " << *I << '\n');
eraseInstFromFunction(*I);
++NumDeadInst;
MadeIRChange = true;
continue;
}
@@ -2946,7 +2966,8 @@ bool InstCombiner::run() {
if (!I->use_empty() &&
(I->getNumOperands() == 0 || isa<Constant>(I->getOperand(0)))) {
if (Constant *C = ConstantFoldInstruction(I, DL, &TLI)) {
- DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *I << '\n');
+ LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *I
+ << '\n');
// Add operands to the worklist.
replaceInstUsesWith(*I, C);
@@ -2965,8 +2986,8 @@ bool InstCombiner::run() {
KnownBits Known = computeKnownBits(I, /*Depth*/0, I);
if (Known.isConstant()) {
Constant *C = ConstantInt::get(Ty, Known.getConstant());
- DEBUG(dbgs() << "IC: ConstFold (all bits known) to: " << *C <<
- " from: " << *I << '\n');
+ LLVM_DEBUG(dbgs() << "IC: ConstFold (all bits known) to: " << *C
+ << " from: " << *I << '\n');
// Add operands to the worklist.
replaceInstUsesWith(*I, C);
@@ -3005,7 +3026,7 @@ bool InstCombiner::run() {
if (UserIsSuccessor && UserParent->getUniquePredecessor()) {
// Okay, the CFG is simple enough, try to sink this instruction.
if (TryToSinkInstruction(I, UserParent)) {
- DEBUG(dbgs() << "IC: Sink: " << *I << '\n');
+ LLVM_DEBUG(dbgs() << "IC: Sink: " << *I << '\n');
MadeIRChange = true;
// We'll add uses of the sunk instruction below, but since sinking
// can expose opportunities for its *operands*, add them to the
// worklist.
@@ -3025,15 +3046,15 @@ bool InstCombiner::run() {
#ifndef NDEBUG
std::string OrigI;
#endif
- DEBUG(raw_string_ostream SS(OrigI); I->print(SS); OrigI = SS.str(););
- DEBUG(dbgs() << "IC: Visiting: " << OrigI << '\n');
+ LLVM_DEBUG(raw_string_ostream SS(OrigI); I->print(SS); OrigI = SS.str(););
+ LLVM_DEBUG(dbgs() << "IC: Visiting: " << OrigI << '\n');
if (Instruction *Result = visit(*I)) {
++NumCombined;
// Should we replace the old instruction with a new one?
if (Result != I) {
- DEBUG(dbgs() << "IC: Old = " << *I << '\n'
- << " New = " << *Result << '\n');
+ LLVM_DEBUG(dbgs() << "IC: Old = " << *I << '\n'
+ << " New = " << *Result << '\n');
if (I->getDebugLoc())
Result->setDebugLoc(I->getDebugLoc());
@@ -3060,8 +3081,8 @@ bool InstCombiner::run() {
eraseInstFromFunction(*I);
} else {
- DEBUG(dbgs() << "IC: Mod = " << OrigI << '\n'
- << " New = " << *I << '\n');
+ LLVM_DEBUG(dbgs() << "IC: Mod = " << OrigI << '\n'
+ << " New = " << *I << '\n');
// If the instruction was modified, it's possible that it is now dead.
// if so, remove it.
@@ -3112,7 +3133,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL,
// DCE instruction if trivially dead.
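The all-bits-known fold in InstCombiner::run above replaces an instruction once every result bit is computable. A toy two-mask model of that invariant (this is not LLVM's KnownBits class, only the bookkeeping it performs for and/or with constants):

#include <cassert>
#include <cstdint>

struct Known { uint64_t Zero, One; }; // bit known 0 / bit known 1

int main() {
  // x is completely unknown; (x | 0xF0) & 0xF0 has every bit known.
  Known X{0, 0};
  Known Or{X.Zero & ~0xF0ull, X.One | 0xF0ull};    // or with constant 0xF0
  Known And{Or.Zero | ~0xF0ull, Or.One & 0xF0ull}; // and with constant 0xF0
  assert((And.Zero | And.One) == ~0ull); // all bits known: fold to a constant
  assert(And.One == 0xF0ull);            // the constant is 0xF0
  return 0;
}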
if (isInstructionTriviallyDead(Inst, TLI)) { ++NumDeadInst; - DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n'); + LLVM_DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n'); salvageDebugInfo(*Inst); Inst->eraseFromParent(); MadeIRChange = true; @@ -3123,8 +3144,8 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, if (!Inst->use_empty() && (Inst->getNumOperands() == 0 || isa<Constant>(Inst->getOperand(0)))) if (Constant *C = ConstantFoldInstruction(Inst, DL, TLI)) { - DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " - << *Inst << '\n'); + LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *Inst + << '\n'); Inst->replaceAllUsesWith(C); ++NumConstProp; if (isInstructionTriviallyDead(Inst, TLI)) @@ -3146,9 +3167,9 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, FoldRes = C; if (FoldRes != C) { - DEBUG(dbgs() << "IC: ConstFold operand of: " << *Inst - << "\n Old = " << *C - << "\n New = " << *FoldRes << '\n'); + LLVM_DEBUG(dbgs() << "IC: ConstFold operand of: " << *Inst + << "\n Old = " << *C + << "\n New = " << *FoldRes << '\n'); U = FoldRes; MadeIRChange = true; } @@ -3191,7 +3212,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, return MadeIRChange; } -/// \brief Populate the IC worklist from a function, and prune any dead basic +/// Populate the IC worklist from a function, and prune any dead basic /// blocks discovered in the process. /// /// This also does basic constant propagation and other forward fixing to make @@ -3251,8 +3272,8 @@ static bool combineInstructionsOverFunction( int Iteration = 0; while (true) { ++Iteration; - DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " - << F.getName() << "\n"); + LLVM_DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " + << F.getName() << "\n"); MadeIRChange |= prepareICWorklistFromFunction(F, DL, &TLI, Worklist); @@ -3348,3 +3369,7 @@ void LLVMInitializeInstCombine(LLVMPassRegistryRef R) { FunctionPass *llvm::createInstructionCombiningPass(bool ExpensiveCombines) { return new InstructionCombiningPass(ExpensiveCombines); } + +void LLVMAddInstructionCombiningPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createInstructionCombiningPass()); +} diff --git a/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/lib/Transforms/Instrumentation/AddressSanitizer.cpp index 8e39f24d819c..b3f659194558 100644 --- a/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -16,7 +16,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" @@ -25,6 +25,7 @@ #include "llvm/ADT/Twine.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/BinaryFormat/MachO.h" #include "llvm/IR/Argument.h" @@ -71,7 +72,6 @@ #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/ASanStackFrameLayout.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> @@ -107,10 +107,18 @@ static const uint64_t kMIPS64_ShadowOffset64 = 1ULL << 37; static const uint64_t 
kAArch64_ShadowOffset64 = 1ULL << 36; static const uint64_t kFreeBSD_ShadowOffset32 = 1ULL << 30; static const uint64_t kFreeBSD_ShadowOffset64 = 1ULL << 46; +static const uint64_t kNetBSD_ShadowOffset32 = 1ULL << 30; static const uint64_t kNetBSD_ShadowOffset64 = 1ULL << 46; static const uint64_t kPS4CPU_ShadowOffset64 = 1ULL << 40; static const uint64_t kWindowsShadowOffset32 = 3ULL << 28; +static const uint64_t kMyriadShadowScale = 5; +static const uint64_t kMyriadMemoryOffset32 = 0x80000000ULL; +static const uint64_t kMyriadMemorySize32 = 0x20000000ULL; +static const uint64_t kMyriadTagShift = 29; +static const uint64_t kMyriadDDRTag = 4; +static const uint64_t kMyriadCacheBitMask32 = 0x40000000ULL; + // The shadow memory space is dynamically allocated. static const uint64_t kWindowsShadowOffset64 = kDynamicShadowSentinel; @@ -145,7 +153,7 @@ static const char *const kAsanHandleNoReturnName = "__asan_handle_no_return"; static const int kMaxAsanStackMallocSizeClass = 10; static const char *const kAsanStackMallocNameTemplate = "__asan_stack_malloc_"; static const char *const kAsanStackFreeNameTemplate = "__asan_stack_free_"; -static const char *const kAsanGenPrefix = "__asan_gen_"; +static const char *const kAsanGenPrefix = "___asan_gen_"; static const char *const kODRGenPrefix = "__odr_asan_gen_"; static const char *const kSanCovGenPrefix = "__sancov_gen_"; static const char *const kAsanSetShadowPrefix = "__asan_set_shadow_"; @@ -485,18 +493,17 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, bool IsSystemZ = TargetTriple.getArch() == Triple::systemz; bool IsX86 = TargetTriple.getArch() == Triple::x86; bool IsX86_64 = TargetTriple.getArch() == Triple::x86_64; - bool IsMIPS32 = TargetTriple.getArch() == Triple::mips || - TargetTriple.getArch() == Triple::mipsel; - bool IsMIPS64 = TargetTriple.getArch() == Triple::mips64 || - TargetTriple.getArch() == Triple::mips64el; + bool IsMIPS32 = TargetTriple.isMIPS32(); + bool IsMIPS64 = TargetTriple.isMIPS64(); bool IsArmOrThumb = TargetTriple.isARM() || TargetTriple.isThumb(); bool IsAArch64 = TargetTriple.getArch() == Triple::aarch64; bool IsWindows = TargetTriple.isOSWindows(); bool IsFuchsia = TargetTriple.isOSFuchsia(); + bool IsMyriad = TargetTriple.getVendor() == llvm::Triple::Myriad; ShadowMapping Mapping; - Mapping.Scale = kDefaultShadowScale; + Mapping.Scale = IsMyriad ? kMyriadShadowScale : kDefaultShadowScale; if (ClMappingScale.getNumOccurrences() > 0) { Mapping.Scale = ClMappingScale; } @@ -508,11 +515,18 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, Mapping.Offset = kMIPS32_ShadowOffset32; else if (IsFreeBSD) Mapping.Offset = kFreeBSD_ShadowOffset32; + else if (IsNetBSD) + Mapping.Offset = kNetBSD_ShadowOffset32; else if (IsIOS) // If we're targeting iOS and x86, the binary is built for iOS simulator. Mapping.Offset = IsX86 ? 
kIOSSimShadowOffset32 : kIOSShadowOffset32; else if (IsWindows) Mapping.Offset = kWindowsShadowOffset32; + else if (IsMyriad) { + uint64_t ShadowOffset = (kMyriadMemoryOffset32 + kMyriadMemorySize32 - + (kMyriadMemorySize32 >> Mapping.Scale)); + Mapping.Offset = ShadowOffset - (kMyriadMemoryOffset32 >> Mapping.Scale); + } else Mapping.Offset = kDefaultShadowOffset32; } else { // LongSize == 64 @@ -589,9 +603,10 @@ struct AddressSanitizer : public FunctionPass { explicit AddressSanitizer(bool CompileKernel = false, bool Recover = false, bool UseAfterScope = false) - : FunctionPass(ID), CompileKernel(CompileKernel || ClEnableKasan), - Recover(Recover || ClRecover), - UseAfterScope(UseAfterScope || ClUseAfterScope) { + : FunctionPass(ID), UseAfterScope(UseAfterScope || ClUseAfterScope) { + this->Recover = ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover; + this->CompileKernel = ClEnableKasan.getNumOccurrences() > 0 ? + ClEnableKasan : CompileKernel; initializeAddressSanitizerPass(*PassRegistry::getPassRegistry()); } @@ -717,8 +732,7 @@ public: explicit AddressSanitizerModule(bool CompileKernel = false, bool Recover = false, bool UseGlobalsGC = true) - : ModulePass(ID), CompileKernel(CompileKernel || ClEnableKasan), - Recover(Recover || ClRecover), + : ModulePass(ID), UseGlobalsGC(UseGlobalsGC && ClUseGlobalsGC), // Not a typo: ClWithComdat is almost completely pointless without // ClUseGlobalsGC (because then it only works on modules without @@ -727,7 +741,12 @@ public: // argument is designed as workaround. Therefore, disable both // ClWithComdat and ClUseGlobalsGC unless the frontend says it's ok to // do globals-gc. - UseCtorComdat(UseGlobalsGC && ClWithComdat) {} + UseCtorComdat(UseGlobalsGC && ClWithComdat) { + this->Recover = ClRecover.getNumOccurrences() > 0 ? + ClRecover : Recover; + this->CompileKernel = ClEnableKasan.getNumOccurrences() > 0 ? + ClEnableKasan : CompileKernel; + } bool runOnModule(Module &M) override; StringRef getPassName() const override { return "AddressSanitizerModule"; } @@ -869,7 +888,7 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { processStaticAllocas(); if (ClDebugStack) { - DEBUG(dbgs() << F); + LLVM_DEBUG(dbgs() << F); } return true; } @@ -888,13 +907,13 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { void createDynamicAllocasInitStorage(); // ----------------------- Visitors. - /// \brief Collect all Ret instructions. + /// Collect all Ret instructions. void visitReturnInst(ReturnInst &RI) { RetVec.push_back(&RI); } - /// \brief Collect all Resume instructions. + /// Collect all Resume instructions. void visitResumeInst(ResumeInst &RI) { RetVec.push_back(&RI); } - /// \brief Collect all CatchReturnInst instructions. + /// Collect all CatchReturnInst instructions. void visitCleanupReturnInst(CleanupReturnInst &CRI) { RetVec.push_back(&CRI); } void unpoisonDynamicAllocasBeforeInst(Instruction *InstBefore, @@ -942,7 +961,7 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { // requested memory, but also left, partial and right redzones. void handleDynamicAllocaCall(AllocaInst *AI); - /// \brief Collect Alloca instructions we want (and can) handle. + /// Collect Alloca instructions we want (and can) handle. 
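Plugging the Myriad constants from this file into the getShadowMapping formula above places the shadow region immediately after the 0x20000000-byte DDR window; a quick arithmetic check (the layout reading is an inference from the constants):

#include <cassert>
#include <cstdint>

int main() {
  const uint64_t MemOffset = 0x80000000ULL, MemSize = 0x20000000ULL, Scale = 5;
  uint64_t ShadowStart = MemOffset + MemSize - (MemSize >> Scale);
  uint64_t MapOffset = ShadowStart - (MemOffset >> Scale);
  // Shadow of the first DDR byte is the start of the shadow region.
  assert((MemOffset >> Scale) + MapOffset == ShadowStart);
  // The whole DDR range maps inside [ShadowStart, MemOffset + MemSize).
  uint64_t LastShadow = ((MemOffset + MemSize - 1) >> Scale) + MapOffset;
  assert(LastShadow < MemOffset + MemSize);
  return 0;
}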
void visitAllocaInst(AllocaInst &AI) { if (!ASan.isInterestingAlloca(AI)) { if (AI.isStaticAlloca()) { @@ -963,7 +982,7 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { AllocaVec.push_back(&AI); } - /// \brief Collect lifetime intrinsic calls to check for use-after-scope + /// Collect lifetime intrinsic calls to check for use-after-scope /// errors. void visitIntrinsicInst(IntrinsicInst &II) { Intrinsic::ID ID = II.getIntrinsicID(); @@ -1081,7 +1100,7 @@ static size_t TypeSizeToSizeIndex(uint32_t TypeSize) { return Res; } -// \brief Create a constant for Str so that we can pass it to the run-time lib. +// Create a constant for Str so that we can pass it to the run-time lib. static GlobalVariable *createPrivateGlobalForString(Module &M, StringRef Str, bool AllowMerging) { Constant *StrConst = ConstantDataArray::getString(M.getContext(), Str); @@ -1095,7 +1114,7 @@ static GlobalVariable *createPrivateGlobalForString(Module &M, StringRef Str, return GV; } -/// \brief Create a global describing a source location. +/// Create a global describing a source location. static GlobalVariable *createPrivateGlobalForSourceLoc(Module &M, LocationMetadata MD) { Constant *LocData[] = { @@ -1111,7 +1130,7 @@ static GlobalVariable *createPrivateGlobalForSourceLoc(Module &M, return GV; } -/// \brief Check if \p G has been created by a trusted compiler pass. +/// Check if \p G has been created by a trusted compiler pass. static bool GlobalWasGeneratedByCompiler(GlobalVariable *G) { // Do not instrument asan globals. if (G->getName().startswith(kAsanGenPrefix) || @@ -1487,6 +1506,8 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, uint32_t TypeSize, bool IsWrite, Value *SizeArgument, bool UseCalls, uint32_t Exp) { + bool IsMyriad = TargetTriple.getVendor() == llvm::Triple::Myriad; + IRBuilder<> IRB(InsertBefore); Value *AddrLong = IRB.CreatePointerCast(Addr, IntptrTy); size_t AccessSizeIndex = TypeSizeToSizeIndex(TypeSize); @@ -1501,6 +1522,23 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, return; } + if (IsMyriad) { + // Strip the cache bit and do range check. + // AddrLong &= ~kMyriadCacheBitMask32 + AddrLong = IRB.CreateAnd(AddrLong, ~kMyriadCacheBitMask32); + // Tag = AddrLong >> kMyriadTagShift + Value *Tag = IRB.CreateLShr(AddrLong, kMyriadTagShift); + // Tag == kMyriadDDRTag + Value *TagCheck = + IRB.CreateICmpEQ(Tag, ConstantInt::get(IntptrTy, kMyriadDDRTag)); + + TerminatorInst *TagCheckTerm = SplitBlockAndInsertIfThen( + TagCheck, InsertBefore, false, MDBuilder(*C).createBranchWeights(1, 100000)); + assert(cast<BranchInst>(TagCheckTerm)->isUnconditional()); + IRB.SetInsertPoint(TagCheckTerm); + InsertBefore = TagCheckTerm; + } + Type *ShadowTy = IntegerType::get(*C, std::max(8U, TypeSize >> Mapping.Scale)); Type *ShadowPtrTy = PointerType::get(ShadowTy, 0); @@ -1609,7 +1647,7 @@ void AddressSanitizerModule::createInitializerPoisonCalls( bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { Type *Ty = G->getValueType(); - DEBUG(dbgs() << "GLOBAL: " << *G << "\n"); + LLVM_DEBUG(dbgs() << "GLOBAL: " << *G << "\n"); if (GlobalsMD.get(G).IsBlacklisted) return false; if (!Ty->isSized()) return false; @@ -1646,12 +1684,17 @@ bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { return false; } - // Callbacks put into the CRT initializer/terminator sections - // should not be instrumented. 
+ // On COFF, if the section name contains '$', it is highly likely that the + // user is using section sorting to create an array of globals similar to + // the way initialization callbacks are registered in .init_array and + // .CRT$XCU. The ATL also registers things in .ATL$__[azm]. Adding redzones + // to such globals is counterproductive, because the intent is that they + // will form an array, and out-of-bounds accesses are expected. // See https://github.com/google/sanitizers/issues/305 // and http://msdn.microsoft.com/en-US/en-en/library/bb918180(v=vs.120).aspx - if (Section.startswith(".CRT")) { - DEBUG(dbgs() << "Ignoring a global initializer callback: " << *G << "\n"); + if (TargetTriple.isOSBinFormatCOFF() && Section.contains('$')) { + LLVM_DEBUG(dbgs() << "Ignoring global in sorted section (contains '$'): " + << *G << "\n"); return false; } @@ -1668,7 +1711,7 @@ bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { // them. if (ParsedSegment == "__OBJC" || (ParsedSegment == "__DATA" && ParsedSection.startswith("__objc_"))) { - DEBUG(dbgs() << "Ignoring ObjC runtime global: " << *G << "\n"); + LLVM_DEBUG(dbgs() << "Ignoring ObjC runtime global: " << *G << "\n"); return false; } // See https://github.com/google/sanitizers/issues/32 @@ -1680,13 +1723,13 @@ bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { // Therefore there's no point in placing redzones into __DATA,__cfstring. // Moreover, it causes the linker to crash on OS X 10.7 if (ParsedSegment == "__DATA" && ParsedSection == "__cfstring") { - DEBUG(dbgs() << "Ignoring CFString: " << *G << "\n"); + LLVM_DEBUG(dbgs() << "Ignoring CFString: " << *G << "\n"); return false; } // The linker merges the contents of cstring_literals and removes the // trailing zeroes. if (ParsedSegment == "__TEXT" && (TAA & MachO::S_CSTRING_LITERALS)) { - DEBUG(dbgs() << "Ignoring a cstring literal: " << *G << "\n"); + LLVM_DEBUG(dbgs() << "Ignoring a cstring literal: " << *G << "\n"); return false; } } @@ -2153,11 +2196,21 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool if (ClInitializers && MD.IsDynInit) HasDynamicallyInitializedGlobals = true; - DEBUG(dbgs() << "NEW GLOBAL: " << *NewGlobal << "\n"); + LLVM_DEBUG(dbgs() << "NEW GLOBAL: " << *NewGlobal << "\n"); Initializers[i] = Initializer; } + // Add instrumented globals to llvm.compiler.used list to avoid LTO from + // ConstantMerge'ing them. + SmallVector<GlobalValue *, 16> GlobalsToAddToUsedList; + for (size_t i = 0; i < n; i++) { + GlobalVariable *G = NewGlobals[i]; + if (G->getName().empty()) continue; + GlobalsToAddToUsedList.push_back(G); + } + appendToCompilerUsed(M, ArrayRef<GlobalValue *>(GlobalsToAddToUsedList)); + std::string ELFUniqueModuleId = (UseGlobalsGC && TargetTriple.isOSBinFormatELF()) ? getUniqueModuleId(&M) : ""; @@ -2177,7 +2230,7 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool if (HasDynamicallyInitializedGlobals) createInitializerPoisonCalls(M, ModuleName); - DEBUG(dbgs() << M); + LLVM_DEBUG(dbgs() << M); return true; } @@ -2247,7 +2300,6 @@ void AddressSanitizer::initializeCallbacks(Module &M) { for (size_t AccessIsWrite = 0; AccessIsWrite <= 1; AccessIsWrite++) { const std::string TypeStr = AccessIsWrite ? "store" : "load"; const std::string ExpStr = Exp ? "exp_" : ""; - const std::string SuffixStr = CompileKernel ? "N" : "_n"; const std::string EndingStr = Recover ? 
"_noabort" : ""; SmallVector<Type *, 3> Args2 = {IntptrTy, IntptrTy}; @@ -2259,8 +2311,7 @@ void AddressSanitizer::initializeCallbacks(Module &M) { } AsanErrorCallbackSized[AccessIsWrite][Exp] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanReportErrorTemplate + ExpStr + TypeStr + SuffixStr + - EndingStr, + kAsanReportErrorTemplate + ExpStr + TypeStr + "_n" + EndingStr, FunctionType::get(IRB.getVoidTy(), Args2, false))); AsanMemoryAccessCallbackSized[AccessIsWrite][Exp] = @@ -2420,7 +2471,7 @@ bool AddressSanitizer::runOnFunction(Function &F) { // Leave if the function doesn't need instrumentation. if (!F.hasFnAttribute(Attribute::SanitizeAddress)) return FunctionModified; - DEBUG(dbgs() << "ASAN instrumenting:\n" << F << "\n"); + LLVM_DEBUG(dbgs() << "ASAN instrumenting:\n" << F << "\n"); initializeCallbacks(*F.getParent()); DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); @@ -2435,7 +2486,7 @@ bool AddressSanitizer::runOnFunction(Function &F) { // We want to instrument every address only once per basic block (unless there // are calls between uses). - SmallSet<Value *, 16> TempsToInstrument; + SmallPtrSet<Value *, 16> TempsToInstrument; SmallVector<Instruction *, 16> ToInstrument; SmallVector<Instruction *, 8> NoReturnCalls; SmallVector<BasicBlock *, 16> AllBlocks; @@ -2494,7 +2545,6 @@ bool AddressSanitizer::runOnFunction(Function &F) { } bool UseCalls = - CompileKernel || (ClInstrumentationWithCallsThreshold >= 0 && ToInstrument.size() > (unsigned)ClInstrumentationWithCallsThreshold); const DataLayout &DL = F.getParent()->getDataLayout(); @@ -2534,8 +2584,8 @@ bool AddressSanitizer::runOnFunction(Function &F) { if (NumInstrumented > 0 || ChangedStack || !NoReturnCalls.empty()) FunctionModified = true; - DEBUG(dbgs() << "ASAN done instrumenting: " << FunctionModified << " " - << F << "\n"); + LLVM_DEBUG(dbgs() << "ASAN done instrumenting: " << FunctionModified << " " + << F << "\n"); return FunctionModified; } @@ -2710,7 +2760,7 @@ void FunctionStackPoisoner::copyArgsPassedByValToAllocas() { Arg.replaceAllUsesWith(AI); uint64_t AllocSize = DL.getTypeAllocSize(Ty); - IRB.CreateMemCpy(AI, &Arg, AllocSize, Align); + IRB.CreateMemCpy(AI, Align, &Arg, Align, AllocSize); } } } @@ -2851,7 +2901,7 @@ void FunctionStackPoisoner::processStaticAllocas() { } auto DescriptionString = ComputeASanStackFrameDescription(SVD); - DEBUG(dbgs() << DescriptionString << " --- " << L.FrameSize << "\n"); + LLVM_DEBUG(dbgs() << DescriptionString << " --- " << L.FrameSize << "\n"); uint64_t LocalStackSize = L.FrameSize; bool DoStackMalloc = ClUseAfterReturn && !ASan.CompileKernel && LocalStackSize <= kMaxStackMallocSize; @@ -3086,7 +3136,8 @@ AllocaInst *FunctionStackPoisoner::findAllocaForValue(Value *V) { } else if (GetElementPtrInst *EP = dyn_cast<GetElementPtrInst>(V)) { Res = findAllocaForValue(EP->getPointerOperand()); } else { - DEBUG(dbgs() << "Alloca search canceled on unknown instruction: " << *V << "\n"); + LLVM_DEBUG(dbgs() << "Alloca search canceled on unknown instruction: " << *V + << "\n"); } if (Res) AllocaForValue[V] = Res; return Res; diff --git a/lib/Transforms/Instrumentation/BoundsChecking.cpp b/lib/Transforms/Instrumentation/BoundsChecking.cpp index be9a22a8681b..e13db08e263c 100644 --- a/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -11,6 +11,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/ScalarEvolution.h" #include 
"llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/BasicBlock.h" @@ -59,11 +60,11 @@ template <typename GetTrapBBT> static bool instrumentMemAccess(Value *Ptr, Value *InstVal, const DataLayout &DL, TargetLibraryInfo &TLI, ObjectSizeOffsetEvaluator &ObjSizeEval, - BuilderTy &IRB, - GetTrapBBT GetTrapBB) { + BuilderTy &IRB, GetTrapBBT GetTrapBB, + ScalarEvolution &SE) { uint64_t NeededSize = DL.getTypeStoreSize(InstVal->getType()); - DEBUG(dbgs() << "Instrument " << *Ptr << " for " << Twine(NeededSize) - << " bytes\n"); + LLVM_DEBUG(dbgs() << "Instrument " << *Ptr << " for " << Twine(NeededSize) + << " bytes\n"); SizeOffsetEvalType SizeOffset = ObjSizeEval.compute(Ptr); @@ -79,6 +80,10 @@ static bool instrumentMemAccess(Value *Ptr, Value *InstVal, Type *IntTy = DL.getIntPtrType(Ptr->getType()); Value *NeededSizeVal = ConstantInt::get(IntTy, NeededSize); + auto SizeRange = SE.getUnsignedRange(SE.getSCEV(Size)); + auto OffsetRange = SE.getUnsignedRange(SE.getSCEV(Offset)); + auto NeededSizeRange = SE.getUnsignedRange(SE.getSCEV(NeededSizeVal)); + // three checks are required to ensure safety: // . Offset >= 0 (since the offset is given from the base ptr) // . Size >= Offset (unsigned) @@ -87,10 +92,17 @@ static bool instrumentMemAccess(Value *Ptr, Value *InstVal, // optimization: if Size >= 0 (signed), skip 1st check // FIXME: add NSW/NUW here? -- we dont care if the subtraction overflows Value *ObjSize = IRB.CreateSub(Size, Offset); - Value *Cmp2 = IRB.CreateICmpULT(Size, Offset); - Value *Cmp3 = IRB.CreateICmpULT(ObjSize, NeededSizeVal); + Value *Cmp2 = SizeRange.getUnsignedMin().uge(OffsetRange.getUnsignedMax()) + ? ConstantInt::getFalse(Ptr->getContext()) + : IRB.CreateICmpULT(Size, Offset); + Value *Cmp3 = SizeRange.sub(OffsetRange) + .getUnsignedMin() + .uge(NeededSizeRange.getUnsignedMax()) + ? 
                 ConstantInt::getFalse(Ptr->getContext())
+          : IRB.CreateICmpULT(ObjSize, NeededSizeVal);
   Value *Or = IRB.CreateOr(Cmp2, Cmp3);
-  if (!SizeCI || SizeCI->getValue().slt(0)) {
+  if ((!SizeCI || SizeCI->getValue().slt(0)) &&
+      !SizeRange.getSignedMin().isNonNegative()) {
     Value *Cmp1 = IRB.CreateICmpSLT(Offset, ConstantInt::get(IntTy, 0));
     Or = IRB.CreateOr(Cmp1, Or);
   }
@@ -123,7 +135,8 @@ static bool instrumentMemAccess(Value *Ptr, Value *InstVal,
   return true;
 }
-static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI) {
+static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI,
+                              ScalarEvolution &SE) {
   const DataLayout &DL = F.getParent()->getDataLayout();
   ObjectSizeOffsetEvaluator ObjSizeEval(DL, &TLI, F.getContext(),
                                         /*RoundToAlign=*/true);
@@ -168,19 +181,19 @@ static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI) {
     BuilderTy IRB(Inst->getParent(), BasicBlock::iterator(Inst), TargetFolder(DL));
     if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
       MadeChange |= instrumentMemAccess(LI->getPointerOperand(), LI, DL, TLI,
-                                        ObjSizeEval, IRB, GetTrapBB);
+                                        ObjSizeEval, IRB, GetTrapBB, SE);
     } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
       MadeChange |=
           instrumentMemAccess(SI->getPointerOperand(), SI->getValueOperand(),
-                              DL, TLI, ObjSizeEval, IRB, GetTrapBB);
+                              DL, TLI, ObjSizeEval, IRB, GetTrapBB, SE);
     } else if (AtomicCmpXchgInst *AI = dyn_cast<AtomicCmpXchgInst>(Inst)) {
       MadeChange |=
           instrumentMemAccess(AI->getPointerOperand(), AI->getCompareOperand(),
-                              DL, TLI, ObjSizeEval, IRB, GetTrapBB);
+                              DL, TLI, ObjSizeEval, IRB, GetTrapBB, SE);
     } else if (AtomicRMWInst *AI = dyn_cast<AtomicRMWInst>(Inst)) {
       MadeChange |=
           instrumentMemAccess(AI->getPointerOperand(), AI->getValOperand(), DL,
-                              TLI, ObjSizeEval, IRB, GetTrapBB);
+                              TLI, ObjSizeEval, IRB, GetTrapBB, SE);
     } else {
       llvm_unreachable("unknown Instruction type");
     }
@@ -190,8 +203,9 @@ static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI) {
 PreservedAnalyses BoundsCheckingPass::run(Function &F, FunctionAnalysisManager &AM) {
   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+  auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
-  if (!addBoundsChecking(F, TLI))
+  if (!addBoundsChecking(F, TLI, SE))
     return PreservedAnalyses::all();
   return PreservedAnalyses::none();
@@ -207,11 +221,13 @@ struct BoundsCheckingLegacyPass : public FunctionPass {
   bool runOnFunction(Function &F) override {
     auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
-    return addBoundsChecking(F, TLI);
+    auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+    return addBoundsChecking(F, TLI, SE);
   }
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<ScalarEvolutionWrapperPass>();
   }
 };
 } // namespace
diff --git a/lib/Transforms/Instrumentation/CFGMST.h b/lib/Transforms/Instrumentation/CFGMST.h
index 075e5672cff8..cc9b149d0b6a 100644
--- a/lib/Transforms/Instrumentation/CFGMST.h
+++ b/lib/Transforms/Instrumentation/CFGMST.h
@@ -31,7 +31,7 @@
 namespace llvm {
-/// \brief An union-find based Minimum Spanning Tree for CFG
+/// A union-find based Minimum Spanning Tree for CFG
 ///
 /// Implements a Union-find algorithm to compute Minimum Spanning Tree
 /// for a given CFG.
@@ -97,7 +97,7 @@ public:
   // Edges with large weight will be put into MST first so they are less likely
   // to be instrumented.
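Stepping back to the BoundsChecking change above: the new ScalarEvolution parameter lets the pass constant-fold checks whose failure SCEV can already rule out, so provably safe accesses no longer pay for a compare and branch. A minimal sketch of that reasoning, using the same API calls as the patch (the helper name is hypothetical):

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Value.h"

using namespace llvm;

// Returns true when "Size < Offset" (one of the out-of-bounds conditions)
// is provably false, so the corresponding runtime check can be dropped.
static bool canElideSizeCheck(ScalarEvolution &SE, Value *Size, Value *Offset) {
  ConstantRange SizeRange = SE.getUnsignedRange(SE.getSCEV(Size));
  ConstantRange OffsetRange = SE.getUnsignedRange(SE.getSCEV(Offset));
  // If the smallest value Size can take is >= the largest value Offset can
  // take, the unsigned comparison can never trip.
  return SizeRange.getUnsignedMin().uge(OffsetRange.getUnsignedMax());
}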
void buildEdges() { - DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n"); + LLVM_DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n"); const BasicBlock *Entry = &(F.getEntryBlock()); uint64_t EntryWeight = (BFI != nullptr ? BFI->getEntryFreq() : 2); @@ -107,8 +107,8 @@ public: // Add a fake edge to the entry. EntryIncoming = &addEdge(nullptr, Entry, EntryWeight); - DEBUG(dbgs() << " Edge: from fake node to " << Entry->getName() - << " w = " << EntryWeight << "\n"); + LLVM_DEBUG(dbgs() << " Edge: from fake node to " << Entry->getName() + << " w = " << EntryWeight << "\n"); // Special handling for single BB functions. if (succ_empty(Entry)) { @@ -138,8 +138,8 @@ public: Weight = BPI->getEdgeProbability(&*BB, TargetBB).scale(scaleFactor); auto *E = &addEdge(&*BB, TargetBB, Weight); E->IsCritical = Critical; - DEBUG(dbgs() << " Edge: from " << BB->getName() << " to " - << TargetBB->getName() << " w=" << Weight << "\n"); + LLVM_DEBUG(dbgs() << " Edge: from " << BB->getName() << " to " + << TargetBB->getName() << " w=" << Weight << "\n"); // Keep track of entry/exit edges: if (&*BB == Entry) { @@ -164,8 +164,8 @@ public: MaxExitOutWeight = BBWeight; ExitOutgoing = ExitO; } - DEBUG(dbgs() << " Edge: from " << BB->getName() << " to fake exit" - << " w = " << BBWeight << "\n"); + LLVM_DEBUG(dbgs() << " Edge: from " << BB->getName() << " to fake exit" + << " w = " << BBWeight << "\n"); } } diff --git a/lib/Transforms/Instrumentation/CGProfile.cpp b/lib/Transforms/Instrumentation/CGProfile.cpp new file mode 100644 index 000000000000..9606b3da2475 --- /dev/null +++ b/lib/Transforms/Instrumentation/CGProfile.cpp @@ -0,0 +1,100 @@ +//===-- CGProfile.cpp -----------------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Instrumentation/CGProfile.h" + +#include "llvm/ADT/MapVector.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/PassManager.h" +#include "llvm/ProfileData/InstrProf.h" +#include "llvm/Transforms/Instrumentation.h" + +#include <array> + +using namespace llvm; + +PreservedAnalyses CGProfilePass::run(Module &M, ModuleAnalysisManager &MAM) { + MapVector<std::pair<Function *, Function *>, uint64_t> Counts; + FunctionAnalysisManager &FAM = + MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + InstrProfSymtab Symtab; + auto UpdateCounts = [&](TargetTransformInfo &TTI, Function *F, + Function *CalledF, uint64_t NewCount) { + if (!CalledF || !TTI.isLoweredToCall(CalledF)) + return; + uint64_t &Count = Counts[std::make_pair(F, CalledF)]; + Count = SaturatingAdd(Count, NewCount); + }; + // Ignore error here. Indirect calls are ignored if this fails. 
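// (If Symtab.create(M) fails above, Symtab.getFunction(VD.Value) in the
// indirect-call loop below returns null, and UpdateCounts then drops the
// edge, so the failure costs only indirect-call edges, never direct ones.)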
+ (void)(bool)Symtab.create(M); + for (auto &F : M) { + if (F.isDeclaration()) + continue; + auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); + if (BFI.getEntryFreq() == 0) + continue; + TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); + for (auto &BB : F) { + Optional<uint64_t> BBCount = BFI.getBlockProfileCount(&BB); + if (!BBCount) + continue; + for (auto &I : BB) { + CallSite CS(&I); + if (!CS) + continue; + if (CS.isIndirectCall()) { + InstrProfValueData ValueData[8]; + uint32_t ActualNumValueData; + uint64_t TotalC; + if (!getValueProfDataFromInst(*CS.getInstruction(), + IPVK_IndirectCallTarget, 8, ValueData, + ActualNumValueData, TotalC)) + continue; + for (const auto &VD : + ArrayRef<InstrProfValueData>(ValueData, ActualNumValueData)) { + UpdateCounts(TTI, &F, Symtab.getFunction(VD.Value), VD.Count); + } + continue; + } + UpdateCounts(TTI, &F, CS.getCalledFunction(), *BBCount); + } + } + } + + addModuleFlags(M, Counts); + + return PreservedAnalyses::all(); +} + +void CGProfilePass::addModuleFlags( + Module &M, + MapVector<std::pair<Function *, Function *>, uint64_t> &Counts) const { + if (Counts.empty()) + return; + + LLVMContext &Context = M.getContext(); + MDBuilder MDB(Context); + std::vector<Metadata *> Nodes; + + for (auto E : Counts) { + SmallVector<Metadata *, 3> Vals; + Vals.push_back(ValueAsMetadata::get(E.first.first)); + Vals.push_back(ValueAsMetadata::get(E.first.second)); + Vals.push_back(MDB.createConstant( + ConstantInt::get(Type::getInt64Ty(Context), E.second))); + Nodes.push_back(MDNode::get(Context, Vals)); + } + + M.addModuleFlag(Module::Append, "CG Profile", MDNode::get(Context, Nodes)); +} diff --git a/lib/Transforms/Instrumentation/CMakeLists.txt b/lib/Transforms/Instrumentation/CMakeLists.txt index 66fdcb3ccc49..5d0084823190 100644 --- a/lib/Transforms/Instrumentation/CMakeLists.txt +++ b/lib/Transforms/Instrumentation/CMakeLists.txt @@ -1,6 +1,7 @@ add_llvm_library(LLVMInstrumentation AddressSanitizer.cpp BoundsChecking.cpp + CGProfile.cpp DataFlowSanitizer.cpp GCOVProfiling.cpp MemorySanitizer.cpp diff --git a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp index 09bcbb282653..bb0e4379d1a8 100644 --- a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp +++ b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp @@ -56,6 +56,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" @@ -90,7 +91,6 @@ #include "llvm/Support/SpecialCaseList.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> #include <cstddef> @@ -211,6 +211,72 @@ class DFSanABIList { } }; +/// TransformedFunction is used to express the result of transforming one +/// function type into another. This struct is immutable. It holds metadata +/// useful for updating calls of the old function to the new type. +struct TransformedFunction { + TransformedFunction(FunctionType* OriginalType, + FunctionType* TransformedType, + std::vector<unsigned> ArgumentIndexMapping) + : OriginalType(OriginalType), + TransformedType(TransformedType), + ArgumentIndexMapping(ArgumentIndexMapping) {} + + // Disallow copies. 
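// Shape of the "CG Profile" flag assembled by addModuleFlags above, in
// textual IR, for a module in which a() calls b() 42 times (a sketch with
// hypothetical function names):
//   !llvm.module.flags = !{!0}
//   !0 = !{i32 5, !"CG Profile", !1}        ; i32 5 == Module::Append
//   !1 = !{!2}
//   !2 = !{void ()* @a, void ()* @b, i64 42}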
+ TransformedFunction(const TransformedFunction&) = delete; + TransformedFunction& operator=(const TransformedFunction&) = delete; + + // Allow moves. + TransformedFunction(TransformedFunction&&) = default; + TransformedFunction& operator=(TransformedFunction&&) = default; + + /// Type of the function before the transformation. + FunctionType* const OriginalType; + + /// Type of the function after the transformation. + FunctionType* const TransformedType; + + /// Transforming a function may change the position of arguments. This + /// member records the mapping from each argument's old position to its new + /// position. Argument positions are zero-indexed. If the transformation + /// from F to F' made the first argument of F into the third argument of F', + /// then ArgumentIndexMapping[0] will equal 2. + const std::vector<unsigned> ArgumentIndexMapping; +}; + +/// Given function attributes from a call site for the original function, +/// return function attributes appropriate for a call to the transformed +/// function. +AttributeList TransformFunctionAttributes( + const TransformedFunction& TransformedFunction, + LLVMContext& Ctx, AttributeList CallSiteAttrs) { + + // Construct a vector of AttributeSet for each function argument. + std::vector<llvm::AttributeSet> ArgumentAttributes( + TransformedFunction.TransformedType->getNumParams()); + + // Copy attributes from the parameter of the original function to the + // transformed version. 'ArgumentIndexMapping' holds the mapping from + // old argument position to new. + for (unsigned i=0, ie = TransformedFunction.ArgumentIndexMapping.size(); + i < ie; ++i) { + unsigned TransformedIndex = TransformedFunction.ArgumentIndexMapping[i]; + ArgumentAttributes[TransformedIndex] = CallSiteAttrs.getParamAttributes(i); + } + + // Copy annotations on varargs arguments. + for (unsigned i = TransformedFunction.OriginalType->getNumParams(), + ie = CallSiteAttrs.getNumAttrSets(); i<ie; ++i) { + ArgumentAttributes.push_back(CallSiteAttrs.getParamAttributes(i)); + } + + return AttributeList::get( + Ctx, + CallSiteAttrs.getFnAttributes(), + CallSiteAttrs.getRetAttributes(), + llvm::makeArrayRef(ArgumentAttributes)); +} + class DataFlowSanitizer : public ModulePass { friend struct DFSanFunction; friend class DFSanVisitor; @@ -294,7 +360,7 @@ class DataFlowSanitizer : public ModulePass { bool isInstrumented(const GlobalAlias *GA); FunctionType *getArgsFunctionType(FunctionType *T); FunctionType *getTrampolineFunctionType(FunctionType *T); - FunctionType *getCustomFunctionType(FunctionType *T); + TransformedFunction getCustomFunctionType(FunctionType *T); InstrumentedABI getInstrumentedABI(); WrapperKind getWrapperKind(Function *F); void addGlobalNamePrefix(GlobalValue *GV); @@ -437,17 +503,25 @@ FunctionType *DataFlowSanitizer::getTrampolineFunctionType(FunctionType *T) { return FunctionType::get(T->getReturnType(), ArgTypes, false); } -FunctionType *DataFlowSanitizer::getCustomFunctionType(FunctionType *T) { +TransformedFunction DataFlowSanitizer::getCustomFunctionType(FunctionType *T) { SmallVector<Type *, 4> ArgTypes; - for (FunctionType::param_iterator i = T->param_begin(), e = T->param_end(); - i != e; ++i) { + + // Some parameters of the custom function being constructed are + // parameters of T. Record the mapping from parameters of T to + // parameters of the custom function, so that parameter attributes + // at call sites can be updated. 
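// Illustration (not part of the patch): if T is
//   int (int (*)(int), int)
// then the function-pointer parameter expands into a trampoline pointer
// plus an i8* context pointer, so before the shadow arguments are appended
// the custom parameter list reads
//   (trampoline_ptr, i8* ctx, int)
// and ArgumentIndexMapping == {0, 2}: parameter 0 keeps its position (now
// occupying two slots) while parameter 1 shifts past the context pointer.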
+ std::vector<unsigned> ArgumentIndexMapping; + for (unsigned i = 0, ie = T->getNumParams(); i != ie; ++i) { + Type* param_type = T->getParamType(i); FunctionType *FT; - if (isa<PointerType>(*i) && (FT = dyn_cast<FunctionType>(cast<PointerType>( - *i)->getElementType()))) { + if (isa<PointerType>(param_type) && (FT = dyn_cast<FunctionType>( + cast<PointerType>(param_type)->getElementType()))) { + ArgumentIndexMapping.push_back(ArgTypes.size()); ArgTypes.push_back(getTrampolineFunctionType(FT)->getPointerTo()); ArgTypes.push_back(Type::getInt8PtrTy(*Ctx)); } else { - ArgTypes.push_back(*i); + ArgumentIndexMapping.push_back(ArgTypes.size()); + ArgTypes.push_back(param_type); } } for (unsigned i = 0, e = T->getNumParams(); i != e; ++i) @@ -457,14 +531,15 @@ FunctionType *DataFlowSanitizer::getCustomFunctionType(FunctionType *T) { Type *RetType = T->getReturnType(); if (!RetType->isVoidTy()) ArgTypes.push_back(ShadowPtrTy); - return FunctionType::get(T->getReturnType(), ArgTypes, T->isVarArg()); + return TransformedFunction( + T, FunctionType::get(T->getReturnType(), ArgTypes, T->isVarArg()), + ArgumentIndexMapping); } bool DataFlowSanitizer::doInitialization(Module &M) { Triple TargetTriple(M.getTargetTriple()); bool IsX86_64 = TargetTriple.getArch() == Triple::x86_64; - bool IsMIPS64 = TargetTriple.getArch() == Triple::mips64 || - TargetTriple.getArch() == Triple::mips64el; + bool IsMIPS64 = TargetTriple.isMIPS64(); bool IsAArch64 = TargetTriple.getArch() == Triple::aarch64 || TargetTriple.getArch() == Triple::aarch64_be; @@ -783,9 +858,17 @@ bool DataFlowSanitizer::runOnModule(Module &M) { FunctionType *NewFT = getInstrumentedABI() == IA_Args ? getArgsFunctionType(FT) : FT; + + // If the function being wrapped has local linkage, then preserve the + // function's linkage in the wrapper function. + GlobalValue::LinkageTypes wrapperLinkage = + F.hasLocalLinkage() + ? 
F.getLinkage() + : GlobalValue::LinkOnceODRLinkage; + Function *NewF = buildWrapperFunction( &F, std::string("dfsw$") + std::string(F.getName()), - GlobalValue::LinkOnceODRLinkage, NewFT); + wrapperLinkage, NewFT); if (getInstrumentedABI() == IA_TLS) NewF->removeAttributes(AttributeList::FunctionIndex, ReadOnlyNoneAttrs); @@ -1382,20 +1465,19 @@ void DFSanVisitor::visitMemTransferInst(MemTransferInst &I) { Value *LenShadow = IRB.CreateMul( I.getLength(), ConstantInt::get(I.getLength()->getType(), DFSF.DFS.ShadowWidth / 8)); - Value *AlignShadow; - if (ClPreserveAlignment) { - AlignShadow = IRB.CreateMul(I.getAlignmentCst(), - ConstantInt::get(I.getAlignmentCst()->getType(), - DFSF.DFS.ShadowWidth / 8)); - } else { - AlignShadow = ConstantInt::get(I.getAlignmentCst()->getType(), - DFSF.DFS.ShadowWidth / 8); - } Type *Int8Ptr = Type::getInt8PtrTy(*DFSF.DFS.Ctx); DestShadow = IRB.CreateBitCast(DestShadow, Int8Ptr); SrcShadow = IRB.CreateBitCast(SrcShadow, Int8Ptr); - IRB.CreateCall(I.getCalledValue(), {DestShadow, SrcShadow, LenShadow, - AlignShadow, I.getVolatileCst()}); + auto *MTI = cast<MemTransferInst>( + IRB.CreateCall(I.getCalledValue(), + {DestShadow, SrcShadow, LenShadow, I.getVolatileCst()})); + if (ClPreserveAlignment) { + MTI->setDestAlignment(I.getDestAlignment() * (DFSF.DFS.ShadowWidth / 8)); + MTI->setSourceAlignment(I.getSourceAlignment() * (DFSF.DFS.ShadowWidth / 8)); + } else { + MTI->setDestAlignment(DFSF.DFS.ShadowWidth / 8); + MTI->setSourceAlignment(DFSF.DFS.ShadowWidth / 8); + } } void DFSanVisitor::visitReturnInst(ReturnInst &RI) { @@ -1460,11 +1542,11 @@ void DFSanVisitor::visitCallSite(CallSite CS) { // wrapper. if (CallInst *CI = dyn_cast<CallInst>(CS.getInstruction())) { FunctionType *FT = F->getFunctionType(); - FunctionType *CustomFT = DFSF.DFS.getCustomFunctionType(FT); + TransformedFunction CustomFn = DFSF.DFS.getCustomFunctionType(FT); std::string CustomFName = "__dfsw_"; CustomFName += F->getName(); - Constant *CustomF = - DFSF.DFS.Mod->getOrInsertFunction(CustomFName, CustomFT); + Constant *CustomF = DFSF.DFS.Mod->getOrInsertFunction( + CustomFName, CustomFn.TransformedType); if (Function *CustomFn = dyn_cast<Function>(CustomF)) { CustomFn->copyAttributesFrom(F); @@ -1532,7 +1614,8 @@ void DFSanVisitor::visitCallSite(CallSite CS) { CallInst *CustomCI = IRB.CreateCall(CustomF, Args); CustomCI->setCallingConv(CI->getCallingConv()); - CustomCI->setAttributes(CI->getAttributes()); + CustomCI->setAttributes(TransformFunctionAttributes(CustomFn, + CI->getContext(), CI->getAttributes())); // Update the parameter attributes of the custom call instruction to // zero extend the shadow parameters. 
This is required for targets diff --git a/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp b/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp index 6864d295525c..33f220a893df 100644 --- a/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp +++ b/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp @@ -23,6 +23,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" @@ -33,7 +34,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" using namespace llvm; @@ -537,7 +537,7 @@ void EfficiencySanitizer::createDestructor(Module &M, Constant *ToolInfoArg) { bool EfficiencySanitizer::initOnModule(Module &M) { Triple TargetTriple(M.getTargetTriple()); - if (TargetTriple.getArch() == Triple::mips64 || TargetTriple.getArch() == Triple::mips64el) + if (TargetTriple.isMIPS64()) ShadowParams = ShadowParams40; else ShadowParams = ShadowParams47; diff --git a/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/lib/Transforms/Instrumentation/GCOVProfiling.cpp index 67ca8172b0d5..acd27c2e226f 100644 --- a/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -17,11 +17,13 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/UniqueVector.h" #include "llvm/Analysis/EHPersonalities.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/IRBuilder.h" @@ -35,8 +37,8 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/Path.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/GCOVProfiler.h" #include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Instrumentation/GCOVProfiler.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include <algorithm> #include <memory> @@ -84,7 +86,7 @@ public: ReversedVersion[3] = Options.Version[0]; ReversedVersion[4] = '\0'; } - bool runOnModule(Module &M); + bool runOnModule(Module &M, const TargetLibraryInfo &TLI); private: // Create the .gcno files for the Module based on DebugInfo. 
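The TargetLibraryInfo now threaded through runOnModule exists so the libgcov hook declarations can carry the integer-extension attribute that some ABIs (PowerPC64 and SystemZ, for example) require on i32 arguments. A minimal sketch of the pattern the hunks below apply, assuming a hypothetical hook name:

#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"

using namespace llvm;

Constant *declareHook(Module &M, const TargetLibraryInfo &TLI) {
  FunctionType *FTy = FunctionType::get(
      Type::getVoidTy(M.getContext()), {Type::getInt32Ty(M.getContext())},
      /*isVarArg=*/false);
  auto *Res = M.getOrInsertFunction("llvm_gcda_hypothetical_hook", FTy);
  if (Function *F = dyn_cast<Function>(Res))
    // Returns zeroext/signext where the target extends i32 parameters, and
    // Attribute::None (falsy) everywhere else.
    if (auto AK = TLI.getExtAttrForI32Param(/*Signed=*/false))
      F->addParamAttr(0, AK);
  return Res;
}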
@@ -130,6 +132,7 @@ private: SmallVector<uint32_t, 4> FileChecksums; Module *M; + const TargetLibraryInfo *TLI; LLVMContext *Ctx; SmallVector<std::unique_ptr<GCOVFunction>, 16> Funcs; }; @@ -145,7 +148,14 @@ public: } StringRef getPassName() const override { return "GCOV Profiler"; } - bool runOnModule(Module &M) override { return Profiler.runOnModule(M); } + bool runOnModule(Module &M) override { + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + return Profiler.runOnModule(M, TLI); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + } private: GCOVProfiler Profiler; @@ -153,8 +163,13 @@ private: } char GCOVProfilerLegacyPass::ID = 0; -INITIALIZE_PASS(GCOVProfilerLegacyPass, "insert-gcov-profiling", - "Insert instrumentation for GCOV profiling", false, false) +INITIALIZE_PASS_BEGIN( + GCOVProfilerLegacyPass, "insert-gcov-profiling", + "Insert instrumentation for GCOV profiling", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END( + GCOVProfilerLegacyPass, "insert-gcov-profiling", + "Insert instrumentation for GCOV profiling", false, false) ModulePass *llvm::createGCOVProfilerPass(const GCOVOptions &Options) { return new GCOVProfilerLegacyPass(Options); @@ -272,7 +287,7 @@ namespace { write(Len); write(Number); - std::sort( + llvm::sort( SortedLinesByFile.begin(), SortedLinesByFile.end(), [](StringMapEntry<GCOVLines> *LHS, StringMapEntry<GCOVLines> *RHS) { return LHS->getKey() < RHS->getKey(); @@ -315,7 +330,7 @@ namespace { ReturnBlock(1, os) { this->os = os; - DEBUG(dbgs() << "Function: " << getFunctionName(SP) << "\n"); + LLVM_DEBUG(dbgs() << "Function: " << getFunctionName(SP) << "\n"); uint32_t i = 0; for (auto &BB : *F) { @@ -383,7 +398,7 @@ namespace { for (int i = 0, e = Blocks.size() + 1; i != e; ++i) { write(0); // No flags on our blocks. } - DEBUG(dbgs() << Blocks.size() << " blocks.\n"); + LLVM_DEBUG(dbgs() << Blocks.size() << " blocks.\n"); // Emit edges between blocks. 
if (Blocks.empty()) return; @@ -396,8 +411,8 @@ namespace { write(Block.OutEdges.size() * 2 + 1); write(Block.Number); for (int i = 0, e = Block.OutEdges.size(); i != e; ++i) { - DEBUG(dbgs() << Block.Number << " -> " << Block.OutEdges[i]->Number - << "\n"); + LLVM_DEBUG(dbgs() << Block.Number << " -> " + << Block.OutEdges[i]->Number << "\n"); write(Block.OutEdges[i]->Number); write(0); // no flags } @@ -461,8 +476,9 @@ std::string GCOVProfiler::mangleName(const DICompileUnit *CU, return CurPath.str(); } -bool GCOVProfiler::runOnModule(Module &M) { +bool GCOVProfiler::runOnModule(Module &M, const TargetLibraryInfo &TLI) { this->M = &M; + this->TLI = &TLI; Ctx = &M.getContext(); if (Options.EmitNotes) emitProfileNotes(); @@ -475,7 +491,8 @@ PreservedAnalyses GCOVProfilerPass::run(Module &M, GCOVProfiler Profiler(GCOVOpts); - if (!Profiler.runOnModule(M)) + auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); + if (!Profiler.runOnModule(M, TLI)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); @@ -503,11 +520,11 @@ static bool functionHasLines(Function &F) { return false; } -static bool isUsingFuncletBasedEH(Function &F) { +static bool isUsingScopeBasedEH(Function &F) { if (!F.hasPersonalityFn()) return false; EHPersonality Personality = classifyEHPersonality(F.getPersonalityFn()); - return isFuncletEHPersonality(Personality); + return isScopedEHPersonality(Personality); } static bool shouldKeepInEntry(BasicBlock::iterator It) { @@ -550,8 +567,8 @@ void GCOVProfiler::emitProfileNotes() { DISubprogram *SP = F.getSubprogram(); if (!SP) continue; if (!functionHasLines(F)) continue; - // TODO: Functions using funclet-based EH are currently not supported. - if (isUsingFuncletBasedEH(F)) continue; + // TODO: Functions using scope-based EH are currently not supported. + if (isUsingScopeBasedEH(F)) continue; // gcov expects every function to start with an entry block that has a // single successor, so split the entry block to make sure of that. @@ -629,8 +646,8 @@ bool GCOVProfiler::emitProfileArcs() { DISubprogram *SP = F.getSubprogram(); if (!SP) continue; if (!functionHasLines(F)) continue; - // TODO: Functions using funclet-based EH are currently not supported. - if (isUsingFuncletBasedEH(F)) continue; + // TODO: Functions using scope-based EH are currently not supported. 
+ if (isUsingScopeBasedEH(F)) continue; if (!Result) Result = true; unsigned Edges = 0; @@ -807,7 +824,12 @@ Constant *GCOVProfiler::getStartFileFunc() { Type::getInt32Ty(*Ctx), // uint32_t checksum }; FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), Args, false); - return M->getOrInsertFunction("llvm_gcda_start_file", FTy); + auto *Res = M->getOrInsertFunction("llvm_gcda_start_file", FTy); + if (Function *FunRes = dyn_cast<Function>(Res)) + if (auto AK = TLI->getExtAttrForI32Param(false)) + FunRes->addParamAttr(2, AK); + return Res; + } Constant *GCOVProfiler::getIncrementIndirectCounterFunc() { @@ -830,7 +852,15 @@ Constant *GCOVProfiler::getEmitFunctionFunc() { Type::getInt32Ty(*Ctx), // uint32_t cfg_checksum }; FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), Args, false); - return M->getOrInsertFunction("llvm_gcda_emit_function", FTy); + auto *Res = M->getOrInsertFunction("llvm_gcda_emit_function", FTy); + if (Function *FunRes = dyn_cast<Function>(Res)) + if (auto AK = TLI->getExtAttrForI32Param(false)) { + FunRes->addParamAttr(0, AK); + FunRes->addParamAttr(2, AK); + FunRes->addParamAttr(3, AK); + FunRes->addParamAttr(4, AK); + } + return Res; } Constant *GCOVProfiler::getEmitArcsFunc() { @@ -839,7 +869,11 @@ Constant *GCOVProfiler::getEmitArcsFunc() { Type::getInt64PtrTy(*Ctx), // uint64_t *counters }; FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), Args, false); - return M->getOrInsertFunction("llvm_gcda_emit_arcs", FTy); + auto *Res = M->getOrInsertFunction("llvm_gcda_emit_arcs", FTy); + if (Function *FunRes = dyn_cast<Function>(Res)) + if (auto AK = TLI->getExtAttrForI32Param(false)) + FunRes->addParamAttr(0, AK); + return Res; } Constant *GCOVProfiler::getSummaryInfoFunc() { @@ -886,46 +920,205 @@ Function *GCOVProfiler::insertCounterWriteout( Constant *SummaryInfo = getSummaryInfoFunc(); Constant *EndFile = getEndFileFunc(); - NamedMDNode *CU_Nodes = M->getNamedMetadata("llvm.dbg.cu"); - if (CU_Nodes) { - for (unsigned i = 0, e = CU_Nodes->getNumOperands(); i != e; ++i) { - auto *CU = cast<DICompileUnit>(CU_Nodes->getOperand(i)); + NamedMDNode *CUNodes = M->getNamedMetadata("llvm.dbg.cu"); + if (!CUNodes) { + Builder.CreateRetVoid(); + return WriteoutF; + } - // Skip module skeleton (and module) CUs. - if (CU->getDWOId()) - continue; + // Collect the relevant data into a large constant data structure that we can + // walk to write out everything. + StructType *StartFileCallArgsTy = StructType::create( + {Builder.getInt8PtrTy(), Builder.getInt8PtrTy(), Builder.getInt32Ty()}); + StructType *EmitFunctionCallArgsTy = StructType::create( + {Builder.getInt32Ty(), Builder.getInt8PtrTy(), Builder.getInt32Ty(), + Builder.getInt8Ty(), Builder.getInt32Ty()}); + StructType *EmitArcsCallArgsTy = StructType::create( + {Builder.getInt32Ty(), Builder.getInt64Ty()->getPointerTo()}); + StructType *FileInfoTy = + StructType::create({StartFileCallArgsTy, Builder.getInt32Ty(), + EmitFunctionCallArgsTy->getPointerTo(), + EmitArcsCallArgsTy->getPointerTo()}); + + Constant *Zero32 = Builder.getInt32(0); + // Build an explicit array of two zeros for use in ConstantExpr GEP building. + Constant *TwoZero32s[] = {Zero32, Zero32}; + + SmallVector<Constant *, 8> FileInfos; + for (int i : llvm::seq<int>(0, CUNodes->getNumOperands())) { + auto *CU = cast<DICompileUnit>(CUNodes->getOperand(i)); - std::string FilenameGcda = mangleName(CU, GCovFileType::GCDA); - uint32_t CfgChecksum = FileChecksums.empty() ? 
0 : FileChecksums[i]; - Builder.CreateCall(StartFile, - {Builder.CreateGlobalStringPtr(FilenameGcda), - Builder.CreateGlobalStringPtr(ReversedVersion), - Builder.getInt32(CfgChecksum)}); - for (unsigned j = 0, e = CountersBySP.size(); j != e; ++j) { - auto *SP = cast_or_null<DISubprogram>(CountersBySP[j].second); - uint32_t FuncChecksum = Funcs.empty() ? 0 : Funcs[j]->getFuncChecksum(); - Builder.CreateCall( - EmitFunction, - {Builder.getInt32(j), - Options.FunctionNamesInData - ? Builder.CreateGlobalStringPtr(getFunctionName(SP)) - : Constant::getNullValue(Builder.getInt8PtrTy()), - Builder.getInt32(FuncChecksum), - Builder.getInt8(Options.UseCfgChecksum), - Builder.getInt32(CfgChecksum)}); - - GlobalVariable *GV = CountersBySP[j].first; - unsigned Arcs = - cast<ArrayType>(GV->getValueType())->getNumElements(); - Builder.CreateCall(EmitArcs, {Builder.getInt32(Arcs), - Builder.CreateConstGEP2_64(GV, 0, 0)}); - } - Builder.CreateCall(SummaryInfo, {}); - Builder.CreateCall(EndFile, {}); + // Skip module skeleton (and module) CUs. + if (CU->getDWOId()) + continue; + + std::string FilenameGcda = mangleName(CU, GCovFileType::GCDA); + uint32_t CfgChecksum = FileChecksums.empty() ? 0 : FileChecksums[i]; + auto *StartFileCallArgs = ConstantStruct::get( + StartFileCallArgsTy, {Builder.CreateGlobalStringPtr(FilenameGcda), + Builder.CreateGlobalStringPtr(ReversedVersion), + Builder.getInt32(CfgChecksum)}); + + SmallVector<Constant *, 8> EmitFunctionCallArgsArray; + SmallVector<Constant *, 8> EmitArcsCallArgsArray; + for (int j : llvm::seq<int>(0, CountersBySP.size())) { + auto *SP = cast_or_null<DISubprogram>(CountersBySP[j].second); + uint32_t FuncChecksum = Funcs.empty() ? 0 : Funcs[j]->getFuncChecksum(); + EmitFunctionCallArgsArray.push_back(ConstantStruct::get( + EmitFunctionCallArgsTy, + {Builder.getInt32(j), + Options.FunctionNamesInData + ? Builder.CreateGlobalStringPtr(getFunctionName(SP)) + : Constant::getNullValue(Builder.getInt8PtrTy()), + Builder.getInt32(FuncChecksum), + Builder.getInt8(Options.UseCfgChecksum), + Builder.getInt32(CfgChecksum)})); + + GlobalVariable *GV = CountersBySP[j].first; + unsigned Arcs = cast<ArrayType>(GV->getValueType())->getNumElements(); + EmitArcsCallArgsArray.push_back(ConstantStruct::get( + EmitArcsCallArgsTy, + {Builder.getInt32(Arcs), ConstantExpr::getInBoundsGetElementPtr( + GV->getValueType(), GV, TwoZero32s)})); } + // Create global arrays for the two emit calls. 
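// The constant data laid out here, viewed as the C structs that the loops
// emitted below walk (field names are illustrative; the types mirror the
// StructType::create calls above):
//   struct StartFileCallArgs    { char *Filename; char *Version;
//                                 uint32_t Checksum; };
//   struct EmitFunctionCallArgs { uint32_t Ident; char *FunctionName;
//                                 uint32_t FuncChecksum;
//                                 uint8_t UseExtraChecksum;
//                                 uint32_t CfgChecksum; };
//   struct EmitArcsCallArgs     { uint32_t NumCounters;
//                                 uint64_t *Counters; };
//   struct FileInfo             { StartFileCallArgs StartFile;
//                                 uint32_t NumCounters;
//                                 EmitFunctionCallArgs *Functions;
//                                 EmitArcsCallArgs *Arcs; };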
+ int CountersSize = CountersBySP.size(); + assert(CountersSize == (int)EmitFunctionCallArgsArray.size() && + "Mismatched array size!"); + assert(CountersSize == (int)EmitArcsCallArgsArray.size() && + "Mismatched array size!"); + auto *EmitFunctionCallArgsArrayTy = + ArrayType::get(EmitFunctionCallArgsTy, CountersSize); + auto *EmitFunctionCallArgsArrayGV = new GlobalVariable( + *M, EmitFunctionCallArgsArrayTy, /*isConstant*/ true, + GlobalValue::InternalLinkage, + ConstantArray::get(EmitFunctionCallArgsArrayTy, + EmitFunctionCallArgsArray), + Twine("__llvm_internal_gcov_emit_function_args.") + Twine(i)); + auto *EmitArcsCallArgsArrayTy = + ArrayType::get(EmitArcsCallArgsTy, CountersSize); + EmitFunctionCallArgsArrayGV->setUnnamedAddr( + GlobalValue::UnnamedAddr::Global); + auto *EmitArcsCallArgsArrayGV = new GlobalVariable( + *M, EmitArcsCallArgsArrayTy, /*isConstant*/ true, + GlobalValue::InternalLinkage, + ConstantArray::get(EmitArcsCallArgsArrayTy, EmitArcsCallArgsArray), + Twine("__llvm_internal_gcov_emit_arcs_args.") + Twine(i)); + EmitArcsCallArgsArrayGV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); + + FileInfos.push_back(ConstantStruct::get( + FileInfoTy, + {StartFileCallArgs, Builder.getInt32(CountersSize), + ConstantExpr::getInBoundsGetElementPtr(EmitFunctionCallArgsArrayTy, + EmitFunctionCallArgsArrayGV, + TwoZero32s), + ConstantExpr::getInBoundsGetElementPtr( + EmitArcsCallArgsArrayTy, EmitArcsCallArgsArrayGV, TwoZero32s)})); } + // If we didn't find anything to actually emit, bail on out. + if (FileInfos.empty()) { + Builder.CreateRetVoid(); + return WriteoutF; + } + + // To simplify code, we cap the number of file infos we write out to fit + // easily in a 32-bit signed integer. This gives consistent behavior between + // 32-bit and 64-bit systems without requiring (potentially very slow) 64-bit + // operations on 32-bit systems. It also seems unreasonable to try to handle + // more than 2 billion files. + if ((int64_t)FileInfos.size() > (int64_t)INT_MAX) + FileInfos.resize(INT_MAX); + + // Create a global for the entire data structure so we can walk it more + // easily. + auto *FileInfoArrayTy = ArrayType::get(FileInfoTy, FileInfos.size()); + auto *FileInfoArrayGV = new GlobalVariable( + *M, FileInfoArrayTy, /*isConstant*/ true, GlobalValue::InternalLinkage, + ConstantArray::get(FileInfoArrayTy, FileInfos), + "__llvm_internal_gcov_emit_file_info"); + FileInfoArrayGV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); + + // Create the CFG for walking this data structure. + auto *FileLoopHeader = + BasicBlock::Create(*Ctx, "file.loop.header", WriteoutF); + auto *CounterLoopHeader = + BasicBlock::Create(*Ctx, "counter.loop.header", WriteoutF); + auto *FileLoopLatch = BasicBlock::Create(*Ctx, "file.loop.latch", WriteoutF); + auto *ExitBB = BasicBlock::Create(*Ctx, "exit", WriteoutF); + + // We always have at least one file, so just branch to the header. + Builder.CreateBr(FileLoopHeader); + + // The index into the files structure is our loop induction variable. 
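// The control flow being built below, in outline (block names are the ones
// created above):
//   entry:               br %file.loop.header
//   file.loop.header:    %i = phi [0, entry], [%i.next, %file.loop.latch]
//                        call start_file(FileInfos[%i].StartFile...)
//                        br (0 < NumCounters) ? %counter.loop.header
//                                             : %file.loop.latch
//   counter.loop.header: %j = phi [0, header], [%j.next, itself]
//                        call emit_function(...); call emit_arcs(...)
//                        br (%j.next < NumCounters) ? itself : latch
//   file.loop.latch:     call summary_info(); call end_file()
//                        br (%i.next < NumFiles) ? header : %exit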
+ Builder.SetInsertPoint(FileLoopHeader); + PHINode *IV = + Builder.CreatePHI(Builder.getInt32Ty(), /*NumReservedValues*/ 2); + IV->addIncoming(Builder.getInt32(0), BB); + auto *FileInfoPtr = + Builder.CreateInBoundsGEP(FileInfoArrayGV, {Builder.getInt32(0), IV}); + auto *StartFileCallArgsPtr = Builder.CreateStructGEP(FileInfoPtr, 0); + auto *StartFileCall = Builder.CreateCall( + StartFile, + {Builder.CreateLoad(Builder.CreateStructGEP(StartFileCallArgsPtr, 0)), + Builder.CreateLoad(Builder.CreateStructGEP(StartFileCallArgsPtr, 1)), + Builder.CreateLoad(Builder.CreateStructGEP(StartFileCallArgsPtr, 2))}); + if (auto AK = TLI->getExtAttrForI32Param(false)) + StartFileCall->addParamAttr(2, AK); + auto *NumCounters = + Builder.CreateLoad(Builder.CreateStructGEP(FileInfoPtr, 1)); + auto *EmitFunctionCallArgsArray = + Builder.CreateLoad(Builder.CreateStructGEP(FileInfoPtr, 2)); + auto *EmitArcsCallArgsArray = + Builder.CreateLoad(Builder.CreateStructGEP(FileInfoPtr, 3)); + auto *EnterCounterLoopCond = + Builder.CreateICmpSLT(Builder.getInt32(0), NumCounters); + Builder.CreateCondBr(EnterCounterLoopCond, CounterLoopHeader, FileLoopLatch); + + Builder.SetInsertPoint(CounterLoopHeader); + auto *JV = Builder.CreatePHI(Builder.getInt32Ty(), /*NumReservedValues*/ 2); + JV->addIncoming(Builder.getInt32(0), FileLoopHeader); + auto *EmitFunctionCallArgsPtr = + Builder.CreateInBoundsGEP(EmitFunctionCallArgsArray, {JV}); + auto *EmitFunctionCall = Builder.CreateCall( + EmitFunction, + {Builder.CreateLoad(Builder.CreateStructGEP(EmitFunctionCallArgsPtr, 0)), + Builder.CreateLoad(Builder.CreateStructGEP(EmitFunctionCallArgsPtr, 1)), + Builder.CreateLoad(Builder.CreateStructGEP(EmitFunctionCallArgsPtr, 2)), + Builder.CreateLoad(Builder.CreateStructGEP(EmitFunctionCallArgsPtr, 3)), + Builder.CreateLoad( + Builder.CreateStructGEP(EmitFunctionCallArgsPtr, 4))}); + if (auto AK = TLI->getExtAttrForI32Param(false)) { + EmitFunctionCall->addParamAttr(0, AK); + EmitFunctionCall->addParamAttr(2, AK); + EmitFunctionCall->addParamAttr(3, AK); + EmitFunctionCall->addParamAttr(4, AK); + } + auto *EmitArcsCallArgsPtr = + Builder.CreateInBoundsGEP(EmitArcsCallArgsArray, {JV}); + auto *EmitArcsCall = Builder.CreateCall( + EmitArcs, + {Builder.CreateLoad(Builder.CreateStructGEP(EmitArcsCallArgsPtr, 0)), + Builder.CreateLoad(Builder.CreateStructGEP(EmitArcsCallArgsPtr, 1))}); + if (auto AK = TLI->getExtAttrForI32Param(false)) + EmitArcsCall->addParamAttr(0, AK); + auto *NextJV = Builder.CreateAdd(JV, Builder.getInt32(1)); + auto *CounterLoopCond = Builder.CreateICmpSLT(NextJV, NumCounters); + Builder.CreateCondBr(CounterLoopCond, CounterLoopHeader, FileLoopLatch); + JV->addIncoming(NextJV, CounterLoopHeader); + + Builder.SetInsertPoint(FileLoopLatch); + Builder.CreateCall(SummaryInfo, {}); + Builder.CreateCall(EndFile, {}); + auto *NextIV = Builder.CreateAdd(IV, Builder.getInt32(1)); + auto *FileLoopCond = + Builder.CreateICmpSLT(NextIV, Builder.getInt32(FileInfos.size())); + Builder.CreateCondBr(FileLoopCond, FileLoopHeader, ExitBB); + IV->addIncoming(NextIV, FileLoopLatch); + + Builder.SetInsertPoint(ExitBB); Builder.CreateRetVoid(); + return WriteoutF; } diff --git a/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp index 8e2833d22032..d62598bb5d4f 100644 --- a/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -22,10 +22,7 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include 
"llvm/IR/DerivedTypes.h" -#include "llvm/IR/MDBuilder.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/IR/Function.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/InstVisitor.h" @@ -34,6 +31,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" @@ -41,8 +39,11 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" +#include "llvm/Transforms/Utils/PromoteMemToReg.h" using namespace llvm; @@ -51,10 +52,15 @@ using namespace llvm; static const char *const kHwasanModuleCtorName = "hwasan.module_ctor"; static const char *const kHwasanInitName = "__hwasan_init"; +static const char *const kHwasanShadowMemoryDynamicAddress = + "__hwasan_shadow_memory_dynamic_address"; + // Accesses sizes are powers of two: 1, 2, 4, 8, 16. static const size_t kNumberOfAccessSizes = 5; -static const size_t kShadowScale = 4; +static const size_t kDefaultShadowScale = 4; +static const uint64_t kDynamicShadowSentinel = + std::numeric_limits<uint64_t>::max(); static const unsigned kPointerTagShift = 56; static cl::opt<std::string> ClMemoryAccessCallbackPrefix( @@ -85,17 +91,57 @@ static cl::opt<bool> ClRecover( cl::desc("Enable recovery mode (continue-after-error)."), cl::Hidden, cl::init(false)); +static cl::opt<bool> ClInstrumentStack("hwasan-instrument-stack", + cl::desc("instrument stack (allocas)"), + cl::Hidden, cl::init(true)); + +static cl::opt<bool> ClUARRetagToZero( + "hwasan-uar-retag-to-zero", + cl::desc("Clear alloca tags before returning from the function to allow " + "non-instrumented and instrumented function calls mix. When set " + "to false, allocas are retagged before returning from the " + "function to detect use after return."), + cl::Hidden, cl::init(true)); + +static cl::opt<bool> ClGenerateTagsWithCalls( + "hwasan-generate-tags-with-calls", + cl::desc("generate new tags with runtime library calls"), cl::Hidden, + cl::init(false)); + +static cl::opt<int> ClMatchAllTag( + "hwasan-match-all-tag", + cl::desc("don't report bad accesses via pointers with this tag"), + cl::Hidden, cl::init(-1)); + +static cl::opt<bool> ClEnableKhwasan( + "hwasan-kernel", + cl::desc("Enable KernelHWAddressSanitizer instrumentation"), + cl::Hidden, cl::init(false)); + +// These flags allow to change the shadow mapping and control how shadow memory +// is accessed. The shadow mapping looks like: +// Shadow = (Mem >> scale) + offset + +static cl::opt<unsigned long long> ClMappingOffset( + "hwasan-mapping-offset", + cl::desc("HWASan shadow mapping offset [EXPERIMENTAL]"), cl::Hidden, + cl::init(0)); + namespace { -/// \brief An instrumentation pass implementing detection of addressability bugs +/// An instrumentation pass implementing detection of addressability bugs /// using tagged pointers. class HWAddressSanitizer : public FunctionPass { public: // Pass identification, replacement for typeid. 
static char ID; - HWAddressSanitizer(bool Recover = false) - : FunctionPass(ID), Recover(Recover || ClRecover) {} + explicit HWAddressSanitizer(bool CompileKernel = false, bool Recover = false) + : FunctionPass(ID) { + this->Recover = ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover; + this->CompileKernel = ClEnableKhwasan.getNumOccurrences() > 0 ? + ClEnableKhwasan : CompileKernel; + } StringRef getPassName() const override { return "HWAddressSanitizer"; } @@ -103,6 +149,11 @@ public: bool doInitialization(Module &M) override; void initializeCallbacks(Module &M); + + void maybeInsertDynamicShadowAtFunctionEntry(Function &F); + + void untagPointerOperand(Instruction *I, Value *Addr); + Value *memToShadow(Value *Shadow, Type *Ty, IRBuilder<> &IRB); void instrumentMemAccessInline(Value *PtrLong, bool IsWrite, unsigned AccessSizeIndex, Instruction *InsertBefore); @@ -111,16 +162,54 @@ public: uint64_t *TypeSize, unsigned *Alignment, Value **MaybeMask); + bool isInterestingAlloca(const AllocaInst &AI); + bool tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag); + Value *tagPointer(IRBuilder<> &IRB, Type *Ty, Value *PtrLong, Value *Tag); + Value *untagPointer(IRBuilder<> &IRB, Value *PtrLong); + bool instrumentStack(SmallVectorImpl<AllocaInst *> &Allocas, + SmallVectorImpl<Instruction *> &RetVec); + Value *getNextTagWithCall(IRBuilder<> &IRB); + Value *getStackBaseTag(IRBuilder<> &IRB); + Value *getAllocaTag(IRBuilder<> &IRB, Value *StackTag, AllocaInst *AI, + unsigned AllocaNo); + Value *getUARTag(IRBuilder<> &IRB, Value *StackTag); + private: LLVMContext *C; + Triple TargetTriple; + + /// This struct defines the shadow mapping using the rule: + /// shadow = (mem >> Scale) + Offset. + /// If InGlobal is true, then + /// extern char __hwasan_shadow[]; + /// shadow = (mem >> Scale) + &__hwasan_shadow + struct ShadowMapping { + int Scale; + uint64_t Offset; + bool InGlobal; + + void init(Triple &TargetTriple); + unsigned getAllocaAlignment() const { return 1U << Scale; } + }; + ShadowMapping Mapping; + Type *IntptrTy; + Type *Int8Ty; + bool CompileKernel; bool Recover; Function *HwasanCtorFunction; Function *HwasanMemoryAccessCallback[2][kNumberOfAccessSizes]; Function *HwasanMemoryAccessCallbackSized[2]; + + Function *HwasanTagMemoryFunc; + Function *HwasanGenerateTagFunc; + + Constant *ShadowGlobal; + + Value *LocalDynamicShadow = nullptr; }; } // end anonymous namespace @@ -129,34 +218,44 @@ char HWAddressSanitizer::ID = 0; INITIALIZE_PASS_BEGIN( HWAddressSanitizer, "hwasan", - "HWAddressSanitizer: detect memory bugs using tagged addressing.", false, false) + "HWAddressSanitizer: detect memory bugs using tagged addressing.", false, + false) INITIALIZE_PASS_END( HWAddressSanitizer, "hwasan", - "HWAddressSanitizer: detect memory bugs using tagged addressing.", false, false) + "HWAddressSanitizer: detect memory bugs using tagged addressing.", false, + false) -FunctionPass *llvm::createHWAddressSanitizerPass(bool Recover) { - return new HWAddressSanitizer(Recover); +FunctionPass *llvm::createHWAddressSanitizerPass(bool CompileKernel, + bool Recover) { + assert(!CompileKernel || Recover); + return new HWAddressSanitizer(CompileKernel, Recover); } -/// \brief Module-level initialization. +/// Module-level initialization. /// /// inserts a call to __hwasan_init to the module's constructor list. 
bool HWAddressSanitizer::doInitialization(Module &M) { - DEBUG(dbgs() << "Init " << M.getName() << "\n"); + LLVM_DEBUG(dbgs() << "Init " << M.getName() << "\n"); auto &DL = M.getDataLayout(); - Triple TargetTriple(M.getTargetTriple()); + TargetTriple = Triple(M.getTargetTriple()); + + Mapping.init(TargetTriple); C = &(M.getContext()); IRBuilder<> IRB(*C); IntptrTy = IRB.getIntPtrTy(DL); - - std::tie(HwasanCtorFunction, std::ignore) = - createSanitizerCtorAndInitFunctions(M, kHwasanModuleCtorName, - kHwasanInitName, - /*InitArgTypes=*/{}, - /*InitArgs=*/{}); - appendToGlobalCtors(M, HwasanCtorFunction, 0); + Int8Ty = IRB.getInt8Ty(); + + HwasanCtorFunction = nullptr; + if (!CompileKernel) { + std::tie(HwasanCtorFunction, std::ignore) = + createSanitizerCtorAndInitFunctions(M, kHwasanModuleCtorName, + kHwasanInitName, + /*InitArgTypes=*/{}, + /*InitArgs=*/{}); + appendToGlobalCtors(M, HwasanCtorFunction, 0); + } return true; } @@ -168,7 +267,7 @@ void HWAddressSanitizer::initializeCallbacks(Module &M) { HwasanMemoryAccessCallbackSized[AccessIsWrite] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - ClMemoryAccessCallbackPrefix + TypeStr + EndingStr, + ClMemoryAccessCallbackPrefix + TypeStr + "N" + EndingStr, FunctionType::get(IRB.getVoidTy(), {IntptrTy, IntptrTy}, false))); for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; @@ -180,16 +279,50 @@ void HWAddressSanitizer::initializeCallbacks(Module &M) { FunctionType::get(IRB.getVoidTy(), {IntptrTy}, false))); } } + + HwasanTagMemoryFunc = checkSanitizerInterfaceFunction(M.getOrInsertFunction( + "__hwasan_tag_memory", IRB.getVoidTy(), IntptrTy, Int8Ty, IntptrTy)); + HwasanGenerateTagFunc = checkSanitizerInterfaceFunction( + M.getOrInsertFunction("__hwasan_generate_tag", Int8Ty)); + + if (Mapping.InGlobal) + ShadowGlobal = M.getOrInsertGlobal("__hwasan_shadow", + ArrayType::get(IRB.getInt8Ty(), 0)); +} + +void HWAddressSanitizer::maybeInsertDynamicShadowAtFunctionEntry(Function &F) { + // Generate code only when dynamic addressing is needed. + if (Mapping.Offset != kDynamicShadowSentinel) + return; + + IRBuilder<> IRB(&F.front().front()); + if (Mapping.InGlobal) { + // An empty inline asm with input reg == output reg. + // An opaque pointer-to-int cast, basically. + InlineAsm *Asm = InlineAsm::get( + FunctionType::get(IntptrTy, {ShadowGlobal->getType()}, false), + StringRef(""), StringRef("=r,0"), + /*hasSideEffects=*/false); + LocalDynamicShadow = IRB.CreateCall(Asm, {ShadowGlobal}, ".hwasan.shadow"); + } else { + Value *GlobalDynamicAddress = F.getParent()->getOrInsertGlobal( + kHwasanShadowMemoryDynamicAddress, IntptrTy); + LocalDynamicShadow = IRB.CreateLoad(GlobalDynamicAddress); + } } Value *HWAddressSanitizer::isInterestingMemoryAccess(Instruction *I, - bool *IsWrite, - uint64_t *TypeSize, - unsigned *Alignment, - Value **MaybeMask) { + bool *IsWrite, + uint64_t *TypeSize, + unsigned *Alignment, + Value **MaybeMask) { // Skip memory accesses inserted by another instrumentation. if (I->getMetadata("nosanitize")) return nullptr; + // Do not instrument the load fetching the dynamic shadow address. 
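// (LocalDynamicShadow is the value produced by
// maybeInsertDynamicShadowAtFunctionEntry above: either the empty "=r,0"
// inline asm, which launders &__hwasan_shadow through a register so later
// code treats the base as opaque, or a load from
// __hwasan_shadow_memory_dynamic_address; in the latter case that load
// itself must not be instrumented, hence this check.)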
+ if (LocalDynamicShadow == I) + return nullptr; + Value *PtrOperand = nullptr; const DataLayout &DL = I->getModule()->getDataLayout(); if (LoadInst *LI = dyn_cast<LoadInst>(I)) { @@ -219,7 +352,7 @@ Value *HWAddressSanitizer::isInterestingMemoryAccess(Instruction *I, } if (PtrOperand) { - // Do not instrument acesses from different address spaces; we cannot deal + // Do not instrument accesses from different address spaces; we cannot deal // with them. Type *PtrTy = cast<PointerType>(PtrOperand->getType()->getScalarType()); if (PtrTy->getPointerAddressSpace() != 0) @@ -236,41 +369,103 @@ Value *HWAddressSanitizer::isInterestingMemoryAccess(Instruction *I, return PtrOperand; } +static unsigned getPointerOperandIndex(Instruction *I) { + if (LoadInst *LI = dyn_cast<LoadInst>(I)) + return LI->getPointerOperandIndex(); + if (StoreInst *SI = dyn_cast<StoreInst>(I)) + return SI->getPointerOperandIndex(); + if (AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(I)) + return RMW->getPointerOperandIndex(); + if (AtomicCmpXchgInst *XCHG = dyn_cast<AtomicCmpXchgInst>(I)) + return XCHG->getPointerOperandIndex(); + report_fatal_error("Unexpected instruction"); + return -1; +} + static size_t TypeSizeToSizeIndex(uint32_t TypeSize) { size_t Res = countTrailingZeros(TypeSize / 8); assert(Res < kNumberOfAccessSizes); return Res; } +void HWAddressSanitizer::untagPointerOperand(Instruction *I, Value *Addr) { + if (TargetTriple.isAArch64()) + return; + + IRBuilder<> IRB(I); + Value *AddrLong = IRB.CreatePointerCast(Addr, IntptrTy); + Value *UntaggedPtr = + IRB.CreateIntToPtr(untagPointer(IRB, AddrLong), Addr->getType()); + I->setOperand(getPointerOperandIndex(I), UntaggedPtr); +} + +Value *HWAddressSanitizer::memToShadow(Value *Mem, Type *Ty, IRBuilder<> &IRB) { + // Mem >> Scale + Value *Shadow = IRB.CreateLShr(Mem, Mapping.Scale); + if (Mapping.Offset == 0) + return Shadow; + // (Mem >> Scale) + Offset + Value *ShadowBase; + if (LocalDynamicShadow) + ShadowBase = LocalDynamicShadow; + else + ShadowBase = ConstantInt::get(Ty, Mapping.Offset); + return IRB.CreateAdd(Shadow, ShadowBase); +} + void HWAddressSanitizer::instrumentMemAccessInline(Value *PtrLong, bool IsWrite, unsigned AccessSizeIndex, Instruction *InsertBefore) { IRBuilder<> IRB(InsertBefore); - Value *PtrTag = IRB.CreateTrunc(IRB.CreateLShr(PtrLong, kPointerTagShift), IRB.getInt8Ty()); - Value *AddrLong = - IRB.CreateAnd(PtrLong, ConstantInt::get(PtrLong->getType(), - ~(0xFFULL << kPointerTagShift))); - Value *ShadowLong = IRB.CreateLShr(AddrLong, kShadowScale); - Value *MemTag = IRB.CreateLoad(IRB.CreateIntToPtr(ShadowLong, IRB.getInt8PtrTy())); + Value *PtrTag = IRB.CreateTrunc(IRB.CreateLShr(PtrLong, kPointerTagShift), + IRB.getInt8Ty()); + Value *AddrLong = untagPointer(IRB, PtrLong); + Value *ShadowLong = memToShadow(AddrLong, PtrLong->getType(), IRB); + Value *MemTag = + IRB.CreateLoad(IRB.CreateIntToPtr(ShadowLong, IRB.getInt8PtrTy())); Value *TagMismatch = IRB.CreateICmpNE(PtrTag, MemTag); + int matchAllTag = ClMatchAllTag.getNumOccurrences() > 0 ? + ClMatchAllTag : (CompileKernel ? 0xFF : -1); + if (matchAllTag != -1) { + Value *TagNotIgnored = IRB.CreateICmpNE(PtrTag, + ConstantInt::get(PtrTag->getType(), matchAllTag)); + TagMismatch = IRB.CreateAnd(TagMismatch, TagNotIgnored); + } + TerminatorInst *CheckTerm = SplitBlockAndInsertIfThen(TagMismatch, InsertBefore, !Recover, MDBuilder(*C).createBranchWeights(1, 100000)); IRB.SetInsertPoint(CheckTerm); - // The signal handler will find the data address in x0. 
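Review note: pulling the pieces of this hunk together, the inline check that instrumentMemAccessInline assembles reduces to a handful of integer operations. A hedged C++ restatement, assuming kPointerTagShift == 56 and the kernel match-all tag of 0xFF, both as used above (the function name is illustrative):

    #include <cstdint>

    // Plain-C++ shape of the emitted check; returns true when the code
    // should trap (or report and continue, under Recover).
    bool tagMismatchSketch(uint64_t PtrLong, const uint8_t *ShadowBase,
                           int Scale, int MatchAllTag /* -1, or 0xFF */) {
      uint8_t PtrTag = uint8_t(PtrLong >> 56);       // top byte of pointer
      uint64_t Addr  = PtrLong & ~(0xFFULL << 56);   // untagPointer (user)
      uint8_t MemTag = ShadowBase[Addr >> Scale];    // memToShadow + load
      bool Mismatch  = PtrTag != MemTag;
      if (MatchAllTag != -1)                         // KHWASan: 0xFF passes
        Mismatch = Mismatch && PtrTag != uint8_t(MatchAllTag);
      return Mismatch;
    }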
- InlineAsm *Asm = InlineAsm::get( - FunctionType::get(IRB.getVoidTy(), {PtrLong->getType()}, false), - "hlt #" + - itostr(0x100 + Recover * 0x20 + IsWrite * 0x10 + AccessSizeIndex), - "{x0}", - /*hasSideEffects=*/true); + const int64_t AccessInfo = Recover * 0x20 + IsWrite * 0x10 + AccessSizeIndex; + InlineAsm *Asm; + switch (TargetTriple.getArch()) { + case Triple::x86_64: + // The signal handler will find the data address in rdi. + Asm = InlineAsm::get( + FunctionType::get(IRB.getVoidTy(), {PtrLong->getType()}, false), + "int3\nnopl " + itostr(0x40 + AccessInfo) + "(%rax)", + "{rdi}", + /*hasSideEffects=*/true); + break; + case Triple::aarch64: + case Triple::aarch64_be: + // The signal handler will find the data address in x0. + Asm = InlineAsm::get( + FunctionType::get(IRB.getVoidTy(), {PtrLong->getType()}, false), + "brk #" + itostr(0x900 + AccessInfo), + "{x0}", + /*hasSideEffects=*/true); + break; + default: + report_fatal_error("unsupported architecture"); + } IRB.CreateCall(Asm, PtrLong); } bool HWAddressSanitizer::instrumentMemAccess(Instruction *I) { - DEBUG(dbgs() << "Instrumenting: " << *I << "\n"); + LLVM_DEBUG(dbgs() << "Instrumenting: " << *I << "\n"); bool IsWrite = false; unsigned Alignment = 0; uint64_t TypeSize = 0; @@ -288,7 +483,7 @@ bool HWAddressSanitizer::instrumentMemAccess(Instruction *I) { Value *AddrLong = IRB.CreatePointerCast(Addr, IntptrTy); if (isPowerOf2_64(TypeSize) && (TypeSize / 8 <= (1UL << (kNumberOfAccessSizes - 1))) && - (Alignment >= (1UL << kShadowScale) || Alignment == 0 || + (Alignment >= (1UL << Mapping.Scale) || Alignment == 0 || Alignment >= TypeSize / 8)) { size_t AccessSizeIndex = TypeSizeToSizeIndex(TypeSize); if (ClInstrumentWithCalls) { @@ -301,10 +496,197 @@ bool HWAddressSanitizer::instrumentMemAccess(Instruction *I) { IRB.CreateCall(HwasanMemoryAccessCallbackSized[IsWrite], {AddrLong, ConstantInt::get(IntptrTy, TypeSize / 8)}); } + untagPointerOperand(I, Addr); return true; } +static uint64_t getAllocaSizeInBytes(const AllocaInst &AI) { + uint64_t ArraySize = 1; + if (AI.isArrayAllocation()) { + const ConstantInt *CI = dyn_cast<ConstantInt>(AI.getArraySize()); + assert(CI && "non-constant array size"); + ArraySize = CI->getZExtValue(); + } + Type *Ty = AI.getAllocatedType(); + uint64_t SizeInBytes = AI.getModule()->getDataLayout().getTypeAllocSize(Ty); + return SizeInBytes * ArraySize; +} + +bool HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, + Value *Tag) { + size_t Size = (getAllocaSizeInBytes(*AI) + Mapping.getAllocaAlignment() - 1) & + ~(Mapping.getAllocaAlignment() - 1); + + Value *JustTag = IRB.CreateTrunc(Tag, IRB.getInt8Ty()); + if (ClInstrumentWithCalls) { + IRB.CreateCall(HwasanTagMemoryFunc, + {IRB.CreatePointerCast(AI, IntptrTy), JustTag, + ConstantInt::get(IntptrTy, Size)}); + } else { + size_t ShadowSize = Size >> Mapping.Scale; + Value *ShadowPtr = IRB.CreateIntToPtr( + memToShadow(IRB.CreatePointerCast(AI, IntptrTy), AI->getType(), IRB), + IRB.getInt8PtrTy()); + // If this memset is not inlined, it will be intercepted in the hwasan + // runtime library. That's OK, because the interceptor skips the checks if + // the address is in the shadow region. + // FIXME: the interceptor is not as fast as real memset. Consider lowering + // llvm.memset right here into either a sequence of stores, or a call to + // hwasan_tag_memory. 
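Review note: the trap immediates above pack all report parameters into one small integer, so a signal handler can recover them from the faulting instruction alone. A decoding sketch; the field layout follows the AccessInfo expression in this hunk, while the struct and function names are hypothetical:

    #include <cstdint>

    struct HwasanAccessInfoSketch {
      bool Recover;        // bit 5 (0x20)
      bool IsWrite;        // bit 4 (0x10)
      unsigned SizeIndex;  // bits 0..3; access size in bytes = 1 << SizeIndex
    };

    // Imm is the brk immediate minus 0x900 on aarch64, or the nopl
    // displacement minus 0x40 on x86_64, per the two cases above.
    HwasanAccessInfoSketch decodeAccessInfo(unsigned Imm) {
      return {(Imm & 0x20u) != 0, (Imm & 0x10u) != 0, Imm & 0xFu};
    }

Keeping the faulting pointer in a fixed register (x0 or rdi) means the handler needs no unwinding to find the bad address.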
+ IRB.CreateMemSet(ShadowPtr, JustTag, ShadowSize, /*Align=*/1); + } + return true; +} + +static unsigned RetagMask(unsigned AllocaNo) { + // A list of 8-bit numbers that have at most one run of non-zero bits. + // x = x ^ (mask << 56) can be encoded as a single armv8 instruction for these + // masks. + // The list does not include the value 255, which is used for UAR. + static unsigned FastMasks[] = { + 0, 1, 2, 3, 4, 6, 7, 8, 12, 14, 15, 16, 24, + 28, 30, 31, 32, 48, 56, 60, 62, 63, 64, 96, 112, 120, + 124, 126, 127, 128, 192, 224, 240, 248, 252, 254}; + return FastMasks[AllocaNo % (sizeof(FastMasks) / sizeof(FastMasks[0]))]; +} + +Value *HWAddressSanitizer::getNextTagWithCall(IRBuilder<> &IRB) { + return IRB.CreateZExt(IRB.CreateCall(HwasanGenerateTagFunc), IntptrTy); +} + +Value *HWAddressSanitizer::getStackBaseTag(IRBuilder<> &IRB) { + if (ClGenerateTagsWithCalls) + return nullptr; + // FIXME: use addressofreturnaddress (but implement it in aarch64 backend + // first). + Module *M = IRB.GetInsertBlock()->getParent()->getParent(); + auto GetStackPointerFn = + Intrinsic::getDeclaration(M, Intrinsic::frameaddress); + Value *StackPointer = IRB.CreateCall( + GetStackPointerFn, {Constant::getNullValue(IRB.getInt32Ty())}); + + // Extract some entropy from the stack pointer for the tags. + // Take bits 20..28 (ASLR entropy) and xor with bits 0..8 (these differ + // between functions). + Value *StackPointerLong = IRB.CreatePointerCast(StackPointer, IntptrTy); + Value *StackTag = + IRB.CreateXor(StackPointerLong, IRB.CreateLShr(StackPointerLong, 20), + "hwasan.stack.base.tag"); + return StackTag; +} + +Value *HWAddressSanitizer::getAllocaTag(IRBuilder<> &IRB, Value *StackTag, + AllocaInst *AI, unsigned AllocaNo) { + if (ClGenerateTagsWithCalls) + return getNextTagWithCall(IRB); + return IRB.CreateXor(StackTag, + ConstantInt::get(IntptrTy, RetagMask(AllocaNo))); +} + +Value *HWAddressSanitizer::getUARTag(IRBuilder<> &IRB, Value *StackTag) { + if (ClUARRetagToZero) + return ConstantInt::get(IntptrTy, 0); + if (ClGenerateTagsWithCalls) + return getNextTagWithCall(IRB); + return IRB.CreateXor(StackTag, ConstantInt::get(IntptrTy, 0xFFU)); +} + +// Add a tag to an address. +Value *HWAddressSanitizer::tagPointer(IRBuilder<> &IRB, Type *Ty, + Value *PtrLong, Value *Tag) { + Value *TaggedPtrLong; + if (CompileKernel) { + // Kernel addresses have 0xFF in the most significant byte. + Value *ShiftedTag = IRB.CreateOr( + IRB.CreateShl(Tag, kPointerTagShift), + ConstantInt::get(IntptrTy, (1ULL << kPointerTagShift) - 1)); + TaggedPtrLong = IRB.CreateAnd(PtrLong, ShiftedTag); + } else { + // Userspace can simply do OR (tag << 56); + Value *ShiftedTag = IRB.CreateShl(Tag, kPointerTagShift); + TaggedPtrLong = IRB.CreateOr(PtrLong, ShiftedTag); + } + return IRB.CreateIntToPtr(TaggedPtrLong, Ty); +} + +// Remove tag from an address. +Value *HWAddressSanitizer::untagPointer(IRBuilder<> &IRB, Value *PtrLong) { + Value *UntaggedPtrLong; + if (CompileKernel) { + // Kernel addresses have 0xFF in the most significant byte. + UntaggedPtrLong = IRB.CreateOr(PtrLong, + ConstantInt::get(PtrLong->getType(), 0xFFULL << kPointerTagShift)); + } else { + // Userspace addresses have 0x00. 
+ UntaggedPtrLong = IRB.CreateAnd(PtrLong, + ConstantInt::get(PtrLong->getType(), ~(0xFFULL << kPointerTagShift))); + } + return UntaggedPtrLong; +} + +bool HWAddressSanitizer::instrumentStack( + SmallVectorImpl<AllocaInst *> &Allocas, + SmallVectorImpl<Instruction *> &RetVec) { + Function *F = Allocas[0]->getParent()->getParent(); + Instruction *InsertPt = &*F->getEntryBlock().begin(); + IRBuilder<> IRB(InsertPt); + + Value *StackTag = getStackBaseTag(IRB); + + // Ideally, we want to calculate tagged stack base pointer, and rewrite all + // alloca addresses using that. Unfortunately, offsets are not known yet + // (unless we use ASan-style mega-alloca). Instead we keep the base tag in a + // temp, shift-OR it into each alloca address and xor with the retag mask. + // This generates one extra instruction per alloca use. + for (unsigned N = 0; N < Allocas.size(); ++N) { + auto *AI = Allocas[N]; + IRB.SetInsertPoint(AI->getNextNode()); + + // Replace uses of the alloca with tagged address. + Value *Tag = getAllocaTag(IRB, StackTag, AI, N); + Value *AILong = IRB.CreatePointerCast(AI, IntptrTy); + Value *Replacement = tagPointer(IRB, AI->getType(), AILong, Tag); + std::string Name = + AI->hasName() ? AI->getName().str() : "alloca." + itostr(N); + Replacement->setName(Name + ".hwasan"); + + for (auto UI = AI->use_begin(), UE = AI->use_end(); UI != UE;) { + Use &U = *UI++; + if (U.getUser() != AILong) + U.set(Replacement); + } + + tagAlloca(IRB, AI, Tag); + + for (auto RI : RetVec) { + IRB.SetInsertPoint(RI); + + // Re-tag alloca memory with the special UAR tag. + Value *Tag = getUARTag(IRB, StackTag); + tagAlloca(IRB, AI, Tag); + } + } + + return true; +} + +bool HWAddressSanitizer::isInterestingAlloca(const AllocaInst &AI) { + return (AI.getAllocatedType()->isSized() && + // FIXME: instrument dynamic allocas, too + AI.isStaticAlloca() && + // alloca() may be called with 0 size, ignore it. + getAllocaSizeInBytes(AI) > 0 && + // We are only interested in allocas not promotable to registers. + // Promotable allocas are common under -O0. + !isAllocaPromotable(&AI) && + // inalloca allocas are not treated as static, and we don't want + // dynamic alloca instrumentation for them as well. + !AI.isUsedWithInAlloca() && + // swifterror allocas are register promoted by ISel + !AI.isSwiftError()); +} + bool HWAddressSanitizer::runOnFunction(Function &F) { if (&F == HwasanCtorFunction) return false; @@ -312,14 +694,35 @@ bool HWAddressSanitizer::runOnFunction(Function &F) { if (!F.hasFnAttribute(Attribute::SanitizeHWAddress)) return false; - DEBUG(dbgs() << "Function: " << F.getName() << "\n"); + LLVM_DEBUG(dbgs() << "Function: " << F.getName() << "\n"); initializeCallbacks(*F.getParent()); + assert(!LocalDynamicShadow); + maybeInsertDynamicShadowAtFunctionEntry(F); + bool Changed = false; SmallVector<Instruction*, 16> ToInstrument; + SmallVector<AllocaInst*, 8> AllocasToInstrument; + SmallVector<Instruction*, 8> RetVec; for (auto &BB : F) { for (auto &Inst : BB) { + if (ClInstrumentStack) + if (AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) { + // Realign all allocas. We don't want small uninteresting allocas to + // hide in instrumented alloca's padding. + if (AI->getAlignment() < Mapping.getAllocaAlignment()) + AI->setAlignment(Mapping.getAllocaAlignment()); + // Instrument some of them. 
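Review note on the stack hunks above (tagAlloca through isInterestingAlloca): fast-path tags come from cheap xors over frame-address entropy rather than calls into the runtime. A hedged sketch of the three derivations; FP stands in for the frameaddress intrinsic result, and RetagMask is the table lookup defined in the hunk:

    #include <cstdint>

    unsigned RetagMask(unsigned AllocaNo);  // as defined in the hunk above

    // Base tag: mix ASLR entropy (bits 20..28 of the frame address) into
    // the low bits, which already differ between functions.
    uint64_t stackBaseTagSketch(uint64_t FP) { return FP ^ (FP >> 20); }

    // Per-alloca tag: one xor with a mask armv8 encodes in one instruction.
    uint64_t allocaTagSketch(uint64_t StackTag, unsigned N) {
      return StackTag ^ RetagMask(N);
    }

    // Use-after-return tag, re-applied to the memory on every return path.
    uint64_t uarTagSketch(uint64_t StackTag) { return StackTag ^ 0xFFu; }

Note that 255 is exactly the value the FastMasks table excludes, so a stale pointer tagged with any alloca tag of the frame can never match memory that has been re-tagged with the UAR tag.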
+ if (isInterestingAlloca(*AI)) + AllocasToInstrument.push_back(AI); + continue; + } + + if (isa<ReturnInst>(Inst) || isa<ResumeInst>(Inst) || + isa<CleanupReturnInst>(Inst)) + RetVec.push_back(&Inst); + Value *MaybeMask = nullptr; bool IsWrite; unsigned Alignment; @@ -331,8 +734,30 @@ bool HWAddressSanitizer::runOnFunction(Function &F) { } } + if (!AllocasToInstrument.empty()) + Changed |= instrumentStack(AllocasToInstrument, RetVec); + for (auto Inst : ToInstrument) Changed |= instrumentMemAccess(Inst); + LocalDynamicShadow = nullptr; + return Changed; } + +void HWAddressSanitizer::ShadowMapping::init(Triple &TargetTriple) { + const bool IsAndroid = TargetTriple.isAndroid(); + const bool IsAndroidWithIfuncSupport = + IsAndroid && !TargetTriple.isAndroidVersionLT(21); + + Scale = kDefaultShadowScale; + + if (ClEnableKhwasan || ClInstrumentWithCalls || !IsAndroidWithIfuncSupport) + Offset = 0; + else + Offset = kDynamicShadowSentinel; + if (ClMappingOffset.getNumOccurrences() > 0) + Offset = ClMappingOffset; + + InGlobal = IsAndroidWithIfuncSupport; +} diff --git a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index 49b8a67a6c14..27fb0e4393af 100644 --- a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -45,7 +45,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" -#include "llvm/Transforms/PGOInstrumentation.h" +#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/CallPromotionUtils.h" #include <cassert> @@ -223,12 +223,12 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite( uint64_t TotalCount, uint32_t NumCandidates) { std::vector<PromotionCandidate> Ret; - DEBUG(dbgs() << " \nWork on callsite #" << NumOfPGOICallsites << *Inst - << " Num_targets: " << ValueDataRef.size() - << " Num_candidates: " << NumCandidates << "\n"); + LLVM_DEBUG(dbgs() << " \nWork on callsite #" << NumOfPGOICallsites << *Inst + << " Num_targets: " << ValueDataRef.size() + << " Num_candidates: " << NumCandidates << "\n"); NumOfPGOICallsites++; if (ICPCSSkip != 0 && NumOfPGOICallsites <= ICPCSSkip) { - DEBUG(dbgs() << " Skip: User options.\n"); + LLVM_DEBUG(dbgs() << " Skip: User options.\n"); return Ret; } @@ -236,11 +236,11 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite( uint64_t Count = ValueDataRef[I].Count; assert(Count <= TotalCount); uint64_t Target = ValueDataRef[I].Value; - DEBUG(dbgs() << " Candidate " << I << " Count=" << Count - << " Target_func: " << Target << "\n"); + LLVM_DEBUG(dbgs() << " Candidate " << I << " Count=" << Count + << " Target_func: " << Target << "\n"); if (ICPInvokeOnly && dyn_cast<CallInst>(Inst)) { - DEBUG(dbgs() << " Not promote: User options.\n"); + LLVM_DEBUG(dbgs() << " Not promote: User options.\n"); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "UserOptions", Inst) << " Not promote: User options"; @@ -248,7 +248,7 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite( break; } if (ICPCallOnly && dyn_cast<InvokeInst>(Inst)) { - DEBUG(dbgs() << " Not promote: User option.\n"); + LLVM_DEBUG(dbgs() << " Not promote: User option.\n"); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "UserOptions", Inst) << " Not promote: User options"; @@ -256,7 +256,7 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite( break; } if (ICPCutOff != 0 
&& NumOfPGOICallPromotion >= ICPCutOff) { - DEBUG(dbgs() << " Not promote: Cutoff reached.\n"); + LLVM_DEBUG(dbgs() << " Not promote: Cutoff reached.\n"); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "CutOffReached", Inst) << " Not promote: Cutoff reached"; @@ -266,7 +266,7 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite( Function *TargetFunction = Symtab->getFunction(Target); if (TargetFunction == nullptr) { - DEBUG(dbgs() << " Not promote: Cannot find the target\n"); + LLVM_DEBUG(dbgs() << " Not promote: Cannot find the target\n"); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "UnableToFindTarget", Inst) << "Cannot promote indirect call: target not found"; @@ -387,7 +387,7 @@ static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, InstrProfSymtab Symtab; if (Error E = Symtab.create(M, InLTO)) { std::string SymtabFailure = toString(std::move(E)); - DEBUG(dbgs() << "Failed to create symtab: " << SymtabFailure << "\n"); + LLVM_DEBUG(dbgs() << "Failed to create symtab: " << SymtabFailure << "\n"); (void)SymtabFailure; return false; } @@ -412,12 +412,12 @@ static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, ICallPromotionFunc ICallPromotion(F, &M, &Symtab, SamplePGO, *ORE); bool FuncChanged = ICallPromotion.processFunction(PSI); if (ICPDUMPAFTER && FuncChanged) { - DEBUG(dbgs() << "\n== IR Dump After =="; F.print(dbgs())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << "\n== IR Dump After =="; F.print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); } Changed |= FuncChanged; if (ICPCutOff != 0 && NumOfPGOICallPromotion >= ICPCutOff) { - DEBUG(dbgs() << " Stop: Cutoff reached.\n"); + LLVM_DEBUG(dbgs() << " Stop: Cutoff reached.\n"); break; } } diff --git a/lib/Transforms/Instrumentation/InstrProfiling.cpp b/lib/Transforms/Instrumentation/InstrProfiling.cpp index 9b70f95480e4..22076f04d6ad 100644 --- a/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -13,7 +13,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/InstrProfiling.h" +#include "llvm/Transforms/Instrumentation/InstrProfiling.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -271,8 +271,8 @@ public: break; } - DEBUG(dbgs() << Promoted << " counters promoted for loop (depth=" - << L.getLoopDepth() << ")\n"); + LLVM_DEBUG(dbgs() << Promoted << " counters promoted for loop (depth=" + << L.getLoopDepth() << ")\n"); return Promoted != 0; } @@ -430,9 +430,24 @@ void InstrProfiling::promoteCounterLoadStores(Function *F) { } } -bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) { - bool MadeChange = false; +/// Check if the module contains uses of any profiling intrinsics. 
+static bool containsProfilingIntrinsics(Module &M) { + if (auto *F = M.getFunction( + Intrinsic::getName(llvm::Intrinsic::instrprof_increment))) + if (!F->use_empty()) + return true; + if (auto *F = M.getFunction( + Intrinsic::getName(llvm::Intrinsic::instrprof_increment_step))) + if (!F->use_empty()) + return true; + if (auto *F = M.getFunction( + Intrinsic::getName(llvm::Intrinsic::instrprof_value_profile))) + if (!F->use_empty()) + return true; + return false; +} +bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) { this->M = &M; this->TLI = &TLI; NamesVar = nullptr; @@ -443,6 +458,15 @@ bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) { MemOPSizeRangeLast); TT = Triple(M.getTargetTriple()); + // Emit the runtime hook even if no counters are present. + bool MadeChange = emitRuntimeHook(); + + // Improve compile time by avoiding linear scans when there is no work. + GlobalVariable *CoverageNamesVar = + M.getNamedGlobal(getCoverageUnusedNamesVarName()); + if (!containsProfilingIntrinsics(M) && !CoverageNamesVar) + return MadeChange; + // We did not know how many value sites there would be inside // the instrumented function. This is counting the number of instrumented // target value sites to enter it as field in the profile data variable. @@ -464,8 +488,7 @@ bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) { for (Function &F : M) MadeChange |= lowerIntrinsics(&F); - if (GlobalVariable *CoverageNamesVar = - M.getNamedGlobal(getCoverageUnusedNamesVarName())) { + if (CoverageNamesVar) { lowerCoverageData(CoverageNamesVar); MadeChange = true; } @@ -476,7 +499,6 @@ bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) { emitVNodes(); emitNameData(); emitRegistration(); - emitRuntimeHook(); emitUses(); emitInitialization(); return true; @@ -669,6 +691,7 @@ static bool needsRuntimeRegistrationOfSectionRange(const Module &M) { // Use linker script magic to get data/cnts/name start/end. if (Triple(M.getTargetTriple()).isOSLinux() || Triple(M.getTargetTriple()).isOSFreeBSD() || + Triple(M.getTargetTriple()).isOSFuchsia() || Triple(M.getTargetTriple()).isPS4CPU()) return false; @@ -892,15 +915,15 @@ void InstrProfiling::emitRegistration() { IRB.CreateRetVoid(); } -void InstrProfiling::emitRuntimeHook() { +bool InstrProfiling::emitRuntimeHook() { // We expect the linker to be invoked with -u<hook_var> flag for linux, // for which case there is no need to emit the user function. if (Triple(M->getTargetTriple()).isOSLinux()) - return; + return false; // If the module's provided its own runtime, we don't need to do anything. if (M->getGlobalVariable(getInstrProfRuntimeHookVarName())) - return; + return false; // Declare an external variable that will pull in the runtime initialization. auto *Int32Ty = Type::getInt32Ty(M->getContext()); @@ -925,6 +948,7 @@ void InstrProfiling::emitRuntimeHook() { // Mark the user variable as used so that it isn't stripped out. 
UsedVars.push_back(User); + return true; } void InstrProfiling::emitUses() { diff --git a/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/lib/Transforms/Instrumentation/MemorySanitizer.cpp index b3c39b5b1665..4bcef6972786 100644 --- a/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -101,6 +101,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -138,7 +139,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include <algorithm> #include <cassert> @@ -163,7 +163,7 @@ static const unsigned kRetvalTLSSize = 800; // Accesses sizes are powers of two: 1, 2, 4, 8. static const size_t kNumberOfAccessSizes = 4; -/// \brief Track origins of uninitialized values. +/// Track origins of uninitialized values. /// /// Adds a section to MemorySanitizer report that points to the allocation /// (stack or heap) the uninitialized bits came from originally. @@ -199,6 +199,18 @@ static cl::opt<bool> ClHandleICmpExact("msan-handle-icmp-exact", cl::desc("exact handling of relational integer ICmp"), cl::Hidden, cl::init(false)); +// When compiling the Linux kernel, we sometimes see false positives related to +// MSan being unable to understand that inline assembly calls may initialize +// local variables. +// This flag makes the compiler conservatively unpoison every memory location +// passed into an assembly call. Note that this may cause false positives. +// Because it's impossible to figure out the array sizes, we can only unpoison +// the first sizeof(type) bytes for each type* pointer. +static cl::opt<bool> ClHandleAsmConservative( + "msan-handle-asm-conservative", + cl::desc("conservative handling of inline assembly"), cl::Hidden, + cl::init(false)); + // This flag controls whether we check the shadow of the address // operand of load or store. Such bugs are very rare, since load from // a garbage address typically results in SEGV, but still happen @@ -234,6 +246,24 @@ static cl::opt<bool> ClWithComdat("msan-with-comdat", cl::desc("Place MSan constructors in comdat sections"), cl::Hidden, cl::init(false)); +// These options allow to specify custom memory map parameters +// See MemoryMapParams for details. +static cl::opt<unsigned long long> ClAndMask("msan-and-mask", + cl::desc("Define custom MSan AndMask"), + cl::Hidden, cl::init(0)); + +static cl::opt<unsigned long long> ClXorMask("msan-xor-mask", + cl::desc("Define custom MSan XorMask"), + cl::Hidden, cl::init(0)); + +static cl::opt<unsigned long long> ClShadowBase("msan-shadow-base", + cl::desc("Define custom MSan ShadowBase"), + cl::Hidden, cl::init(0)); + +static cl::opt<unsigned long long> ClOriginBase("msan-origin-base", + cl::desc("Define custom MSan OriginBase"), + cl::Hidden, cl::init(0)); + static const char *const kMsanModuleCtorName = "msan.module_ctor"; static const char *const kMsanInitName = "__msan_init"; @@ -360,7 +390,7 @@ static const PlatformMemoryMapParams NetBSD_X86_MemoryMapParams = { namespace { -/// \brief An instrumentation pass implementing detection of uninitialized +/// An instrumentation pass implementing detection of uninitialized /// reads. 
/// /// MemorySanitizer: instrument the code in module to find @@ -368,7 +398,7 @@ namespace { class MemorySanitizer : public FunctionPass { public: // Pass identification, replacement for typeid. - static char ID; + static char ID; MemorySanitizer(int TrackOrigins = 0, bool Recover = false) : FunctionPass(ID), @@ -392,8 +422,9 @@ private: friend struct VarArgPowerPC64Helper; void initializeCallbacks(Module &M); + void createUserspaceApi(Module &M); - /// \brief Track origins (allocation points) of uninitialized values. + /// Track origins (allocation points) of uninitialized values. int TrackOrigins; bool Recover; @@ -401,60 +432,67 @@ private: Type *IntptrTy; Type *OriginTy; - /// \brief Thread-local shadow storage for function parameters. + /// Thread-local shadow storage for function parameters. GlobalVariable *ParamTLS; - /// \brief Thread-local origin storage for function parameters. + /// Thread-local origin storage for function parameters. GlobalVariable *ParamOriginTLS; - /// \brief Thread-local shadow storage for function return value. + /// Thread-local shadow storage for function return value. GlobalVariable *RetvalTLS; - /// \brief Thread-local origin storage for function return value. + /// Thread-local origin storage for function return value. GlobalVariable *RetvalOriginTLS; - /// \brief Thread-local shadow storage for in-register va_arg function + /// Thread-local shadow storage for in-register va_arg function /// parameters (x86_64-specific). GlobalVariable *VAArgTLS; - /// \brief Thread-local shadow storage for va_arg overflow area + /// Thread-local shadow storage for va_arg overflow area /// (x86_64-specific). GlobalVariable *VAArgOverflowSizeTLS; - /// \brief Thread-local space used to pass origin value to the UMR reporting + /// Thread-local space used to pass origin value to the UMR reporting /// function. GlobalVariable *OriginTLS; - /// \brief The run-time callback to print a warning. - Value *WarningFn = nullptr; + /// Are the instrumentation callbacks set up? + bool CallbacksInitialized = false; + + /// The run-time callback to print a warning. + Value *WarningFn; // These arrays are indexed by log2(AccessSize). Value *MaybeWarningFn[kNumberOfAccessSizes]; Value *MaybeStoreOriginFn[kNumberOfAccessSizes]; - /// \brief Run-time helper that generates a new origin value for a stack + /// Run-time helper that generates a new origin value for a stack /// allocation. Value *MsanSetAllocaOrigin4Fn; - /// \brief Run-time helper that poisons stack on function entry. + /// Run-time helper that poisons stack on function entry. Value *MsanPoisonStackFn; - /// \brief Run-time helper that records a store (or any event) of an + /// Run-time helper that records a store (or any event) of an /// uninitialized value and returns an updated origin id encoding this info. Value *MsanChainOriginFn; - /// \brief MSan runtime replacements for memmove, memcpy and memset. + /// MSan runtime replacements for memmove, memcpy and memset. Value *MemmoveFn, *MemcpyFn, *MemsetFn; - /// \brief Memory map parameters used in application-to-shadow calculation. + /// Memory map parameters used in application-to-shadow calculation. const MemoryMapParams *MapParams; + /// Custom memory map parameters used when -msan-shadow-base or + // -msan-origin-base is provided. + MemoryMapParams CustomMapParams; + MDNode *ColdCallWeights; - /// \brief Branch weights for origin store. + /// Branch weights for origin store. 
MDNode *OriginStoreWeights; - /// \brief An empty volatile inline asm that prevents callback merge. + /// An empty volatile inline asm that prevents callback merge. InlineAsm *EmptyAsm; Function *MsanCtorFunction; @@ -476,7 +514,7 @@ FunctionPass *llvm::createMemorySanitizerPass(int TrackOrigins, bool Recover) { return new MemorySanitizer(TrackOrigins, Recover); } -/// \brief Create a non-const global initialized with the given string. +/// Create a non-const global initialized with the given string. /// /// Creates a writable global for Str so that we can pass it to the /// run-time lib. Runtime uses first 4 bytes of the string to store the @@ -488,12 +526,8 @@ static GlobalVariable *createPrivateNonConstGlobalForString(Module &M, GlobalValue::PrivateLinkage, StrConst, ""); } -/// \brief Insert extern declaration of runtime-provided functions and globals. -void MemorySanitizer::initializeCallbacks(Module &M) { - // Only do this once. - if (WarningFn) - return; - +/// Insert declarations for userspace-specific functions and globals. +void MemorySanitizer::createUserspaceApi(Module &M) { IRBuilder<> IRB(*C); // Create the callback. // FIXME: this function should have "Cold" calling conv, @@ -502,6 +536,38 @@ void MemorySanitizer::initializeCallbacks(Module &M) { : "__msan_warning_noreturn"; WarningFn = M.getOrInsertFunction(WarningFnName, IRB.getVoidTy()); + // Create the global TLS variables. + RetvalTLS = new GlobalVariable( + M, ArrayType::get(IRB.getInt64Ty(), kRetvalTLSSize / 8), false, + GlobalVariable::ExternalLinkage, nullptr, "__msan_retval_tls", nullptr, + GlobalVariable::InitialExecTLSModel); + + RetvalOriginTLS = new GlobalVariable( + M, OriginTy, false, GlobalVariable::ExternalLinkage, nullptr, + "__msan_retval_origin_tls", nullptr, GlobalVariable::InitialExecTLSModel); + + ParamTLS = new GlobalVariable( + M, ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), false, + GlobalVariable::ExternalLinkage, nullptr, "__msan_param_tls", nullptr, + GlobalVariable::InitialExecTLSModel); + + ParamOriginTLS = new GlobalVariable( + M, ArrayType::get(OriginTy, kParamTLSSize / 4), false, + GlobalVariable::ExternalLinkage, nullptr, "__msan_param_origin_tls", + nullptr, GlobalVariable::InitialExecTLSModel); + + VAArgTLS = new GlobalVariable( + M, ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), false, + GlobalVariable::ExternalLinkage, nullptr, "__msan_va_arg_tls", nullptr, + GlobalVariable::InitialExecTLSModel); + VAArgOverflowSizeTLS = new GlobalVariable( + M, IRB.getInt64Ty(), false, GlobalVariable::ExternalLinkage, nullptr, + "__msan_va_arg_overflow_size_tls", nullptr, + GlobalVariable::InitialExecTLSModel); + OriginTLS = new GlobalVariable( + M, IRB.getInt32Ty(), false, GlobalVariable::ExternalLinkage, nullptr, + "__msan_origin_tls", nullptr, GlobalVariable::InitialExecTLSModel); + for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; AccessSizeIndex++) { unsigned AccessSize = 1 << AccessSizeIndex; @@ -522,6 +588,17 @@ void MemorySanitizer::initializeCallbacks(Module &M) { MsanPoisonStackFn = M.getOrInsertFunction("__msan_poison_stack", IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy); +} + +/// Insert extern declaration of runtime-provided functions and globals. +void MemorySanitizer::initializeCallbacks(Module &M) { + // Only do this once. + if (CallbacksInitialized) + return; + + IRBuilder<> IRB(*C); + // Initialize callbacks that are common for kernel and userspace + // instrumentation. 
MsanChainOriginFn = M.getOrInsertFunction( "__msan_chain_origin", IRB.getInt32Ty(), IRB.getInt32Ty()); MemmoveFn = M.getOrInsertFunction( @@ -533,98 +610,81 @@ void MemorySanitizer::initializeCallbacks(Module &M) { MemsetFn = M.getOrInsertFunction( "__msan_memset", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy); - - // Create globals. - RetvalTLS = new GlobalVariable( - M, ArrayType::get(IRB.getInt64Ty(), kRetvalTLSSize / 8), false, - GlobalVariable::ExternalLinkage, nullptr, "__msan_retval_tls", nullptr, - GlobalVariable::InitialExecTLSModel); - RetvalOriginTLS = new GlobalVariable( - M, OriginTy, false, GlobalVariable::ExternalLinkage, nullptr, - "__msan_retval_origin_tls", nullptr, GlobalVariable::InitialExecTLSModel); - - ParamTLS = new GlobalVariable( - M, ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), false, - GlobalVariable::ExternalLinkage, nullptr, "__msan_param_tls", nullptr, - GlobalVariable::InitialExecTLSModel); - ParamOriginTLS = new GlobalVariable( - M, ArrayType::get(OriginTy, kParamTLSSize / 4), false, - GlobalVariable::ExternalLinkage, nullptr, "__msan_param_origin_tls", - nullptr, GlobalVariable::InitialExecTLSModel); - - VAArgTLS = new GlobalVariable( - M, ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), false, - GlobalVariable::ExternalLinkage, nullptr, "__msan_va_arg_tls", nullptr, - GlobalVariable::InitialExecTLSModel); - VAArgOverflowSizeTLS = new GlobalVariable( - M, IRB.getInt64Ty(), false, GlobalVariable::ExternalLinkage, nullptr, - "__msan_va_arg_overflow_size_tls", nullptr, - GlobalVariable::InitialExecTLSModel); - OriginTLS = new GlobalVariable( - M, IRB.getInt32Ty(), false, GlobalVariable::ExternalLinkage, nullptr, - "__msan_origin_tls", nullptr, GlobalVariable::InitialExecTLSModel); - // We insert an empty inline asm after __msan_report* to avoid callback merge. EmptyAsm = InlineAsm::get(FunctionType::get(IRB.getVoidTy(), false), StringRef(""), StringRef(""), /*hasSideEffects=*/true); + + createUserspaceApi(M); + CallbacksInitialized = true; } -/// \brief Module-level initialization. +/// Module-level initialization. /// /// inserts a call to __msan_init to the module's constructor list. 
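Review note: the __msan_*_tls globals that createUserspaceApi now owns implement MSan's shadow argument-passing convention. A hedged runtime-side sketch of that contract; the 800-byte size mirrors kRetvalTLSSize earlier in this file and is assumed for the parameter array as well:

    #include <cstdint>

    // Thread-local mirrors of the globals the pass declares (sketch).
    extern "C" {
    extern __thread uint64_t __msan_param_tls[800 / 8];   // argument shadows
    extern __thread uint64_t __msan_retval_tls[800 / 8];  // return-value shadow
    extern __thread uint32_t __msan_origin_tls;           // origin for reports
    }

    // Contract: the caller stores each argument's shadow at its ArgOffset
    // before the call; the callee's prologue (getShadowPtrForArgument)
    // reads back the same slot. All-zero words mean "fully initialized".

Splitting these declarations out of initializeCallbacks cleanly separates the userspace-only API from the callbacks that, per the comment above, are common to kernel and userspace instrumentation.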
bool MemorySanitizer::doInitialization(Module &M) { auto &DL = M.getDataLayout(); - Triple TargetTriple(M.getTargetTriple()); - switch (TargetTriple.getOS()) { - case Triple::FreeBSD: - switch (TargetTriple.getArch()) { - case Triple::x86_64: - MapParams = FreeBSD_X86_MemoryMapParams.bits64; - break; - case Triple::x86: - MapParams = FreeBSD_X86_MemoryMapParams.bits32; - break; - default: - report_fatal_error("unsupported architecture"); - } - break; - case Triple::NetBSD: - switch (TargetTriple.getArch()) { - case Triple::x86_64: - MapParams = NetBSD_X86_MemoryMapParams.bits64; - break; - default: - report_fatal_error("unsupported architecture"); - } - break; - case Triple::Linux: - switch (TargetTriple.getArch()) { - case Triple::x86_64: - MapParams = Linux_X86_MemoryMapParams.bits64; - break; - case Triple::x86: - MapParams = Linux_X86_MemoryMapParams.bits32; - break; - case Triple::mips64: - case Triple::mips64el: - MapParams = Linux_MIPS_MemoryMapParams.bits64; - break; - case Triple::ppc64: - case Triple::ppc64le: - MapParams = Linux_PowerPC_MemoryMapParams.bits64; - break; - case Triple::aarch64: - case Triple::aarch64_be: - MapParams = Linux_ARM_MemoryMapParams.bits64; - break; - default: - report_fatal_error("unsupported architecture"); - } - break; - default: - report_fatal_error("unsupported operating system"); + bool ShadowPassed = ClShadowBase.getNumOccurrences() > 0; + bool OriginPassed = ClOriginBase.getNumOccurrences() > 0; + // Check the overrides first + if (ShadowPassed || OriginPassed) { + CustomMapParams.AndMask = ClAndMask; + CustomMapParams.XorMask = ClXorMask; + CustomMapParams.ShadowBase = ClShadowBase; + CustomMapParams.OriginBase = ClOriginBase; + MapParams = &CustomMapParams; + } else { + Triple TargetTriple(M.getTargetTriple()); + switch (TargetTriple.getOS()) { + case Triple::FreeBSD: + switch (TargetTriple.getArch()) { + case Triple::x86_64: + MapParams = FreeBSD_X86_MemoryMapParams.bits64; + break; + case Triple::x86: + MapParams = FreeBSD_X86_MemoryMapParams.bits32; + break; + default: + report_fatal_error("unsupported architecture"); + } + break; + case Triple::NetBSD: + switch (TargetTriple.getArch()) { + case Triple::x86_64: + MapParams = NetBSD_X86_MemoryMapParams.bits64; + break; + default: + report_fatal_error("unsupported architecture"); + } + break; + case Triple::Linux: + switch (TargetTriple.getArch()) { + case Triple::x86_64: + MapParams = Linux_X86_MemoryMapParams.bits64; + break; + case Triple::x86: + MapParams = Linux_X86_MemoryMapParams.bits32; + break; + case Triple::mips64: + case Triple::mips64el: + MapParams = Linux_MIPS_MemoryMapParams.bits64; + break; + case Triple::ppc64: + case Triple::ppc64le: + MapParams = Linux_PowerPC_MemoryMapParams.bits64; + break; + case Triple::aarch64: + case Triple::aarch64_be: + MapParams = Linux_ARM_MemoryMapParams.bits64; + break; + default: + report_fatal_error("unsupported architecture"); + } + break; + default: + report_fatal_error("unsupported operating system"); + } } C = &(M.getContext()); @@ -661,7 +721,7 @@ bool MemorySanitizer::doInitialization(Module &M) { namespace { -/// \brief A helper class that handles instrumentation of VarArg +/// A helper class that handles instrumentation of VarArg /// functions on a particular platform. /// /// Implementations are expected to insert the instrumentation @@ -672,16 +732,16 @@ namespace { struct VarArgHelper { virtual ~VarArgHelper() = default; - /// \brief Visit a CallSite. + /// Visit a CallSite. 
virtual void visitCallSite(CallSite &CS, IRBuilder<> &IRB) = 0; - /// \brief Visit a va_start call. + /// Visit a va_start call. virtual void visitVAStartInst(VAStartInst &I) = 0; - /// \brief Visit a va_copy call. + /// Visit a va_copy call. virtual void visitVACopyInst(VACopyInst &I) = 0; - /// \brief Finalize function instrumentation. + /// Finalize function instrumentation. /// /// This method is called after visiting all interesting (see above) /// instructions in a function. @@ -715,6 +775,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { ValueMap<Value*, Value*> ShadowMap, OriginMap; std::unique_ptr<VarArgHelper> VAHelper; const TargetLibraryInfo *TLI; + BasicBlock *ActualFnStart; // The following flags disable parts of MSan instrumentation based on // blacklist contents and command-line options. @@ -747,9 +808,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { CheckReturnValue = SanitizeFunction && (F.getName() == "main"); TLI = &MS.getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - DEBUG(if (!InsertChecks) - dbgs() << "MemorySanitizer is not inserting checks into '" - << F.getName() << "'\n"); + MS.initializeCallbacks(*F.getParent()); + ActualFnStart = &F.getEntryBlock(); + + LLVM_DEBUG(if (!InsertChecks) dbgs() + << "MemorySanitizer is not inserting checks into '" + << F.getName() << "'\n"); } Value *updateOrigin(Value *V, IRBuilder<> &IRB) { @@ -766,7 +830,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return IRB.CreateOr(Origin, IRB.CreateShl(Origin, kOriginSize * 8)); } - /// \brief Fill memory range with the given origin value. + /// Fill memory range with the given origin value. void paintOrigin(IRBuilder<> &IRB, Value *Origin, Value *OriginPtr, unsigned Size, unsigned Alignment) { const DataLayout &DL = F.getParent()->getDataLayout(); @@ -849,13 +913,11 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { unsigned Alignment = SI->getAlignment(); unsigned OriginAlignment = std::max(kMinOriginAlignment, Alignment); std::tie(ShadowPtr, OriginPtr) = - getShadowOriginPtr(Addr, IRB, ShadowTy, Alignment); + getShadowOriginPtr(Addr, IRB, ShadowTy, Alignment, /*isStore*/ true); StoreInst *NewSI = IRB.CreateAlignedStore(Shadow, ShadowPtr, Alignment); - DEBUG(dbgs() << " STORE: " << *NewSI << "\n"); - - if (ClCheckAccessAddress) - insertShadowCheck(Addr, NewSI); + LLVM_DEBUG(dbgs() << " STORE: " << *NewSI << "\n"); + (void)NewSI; if (SI->isAtomic()) SI->setOrdering(addReleaseOrdering(SI->getOrdering())); @@ -866,25 +928,31 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } } + /// Helper function to insert a warning at IRB's current insert point. + void insertWarningFn(IRBuilder<> &IRB, Value *Origin) { + if (!Origin) + Origin = (Value *)IRB.getInt32(0); + if (MS.TrackOrigins) { + IRB.CreateStore(Origin, MS.OriginTLS); + } + IRB.CreateCall(MS.WarningFn, {}); + IRB.CreateCall(MS.EmptyAsm, {}); + // FIXME: Insert UnreachableInst if !MS.Recover? + // This may invalidate some of the following checks and needs to be done + // at the very end. 
+ } + void materializeOneCheck(Instruction *OrigIns, Value *Shadow, Value *Origin, bool AsCall) { IRBuilder<> IRB(OrigIns); - DEBUG(dbgs() << " SHAD0 : " << *Shadow << "\n"); + LLVM_DEBUG(dbgs() << " SHAD0 : " << *Shadow << "\n"); Value *ConvertedShadow = convertToShadowTyNoVec(Shadow, IRB); - DEBUG(dbgs() << " SHAD1 : " << *ConvertedShadow << "\n"); + LLVM_DEBUG(dbgs() << " SHAD1 : " << *ConvertedShadow << "\n"); Constant *ConstantShadow = dyn_cast_or_null<Constant>(ConvertedShadow); if (ConstantShadow) { if (ClCheckConstantShadow && !ConstantShadow->isZeroValue()) { - if (MS.TrackOrigins) { - IRB.CreateStore(Origin ? (Value *)Origin : (Value *)IRB.getInt32(0), - MS.OriginTLS); - } - IRB.CreateCall(MS.WarningFn, {}); - IRB.CreateCall(MS.EmptyAsm, {}); - // FIXME: Insert UnreachableInst if !MS.Recover? - // This may invalidate some of the following checks and needs to be done - // at the very end. + insertWarningFn(IRB, Origin); } return; } @@ -908,13 +976,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /* Unreachable */ !MS.Recover, MS.ColdCallWeights); IRB.SetInsertPoint(CheckTerm); - if (MS.TrackOrigins) { - IRB.CreateStore(Origin ? (Value *)Origin : (Value *)IRB.getInt32(0), - MS.OriginTLS); - } - IRB.CreateCall(MS.WarningFn, {}); - IRB.CreateCall(MS.EmptyAsm, {}); - DEBUG(dbgs() << " CHECK: " << *Cmp << "\n"); + insertWarningFn(IRB, Origin); + LLVM_DEBUG(dbgs() << " CHECK: " << *Cmp << "\n"); } } @@ -925,13 +988,11 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *Origin = ShadowData.Origin; materializeOneCheck(OrigIns, Shadow, Origin, InstrumentWithCalls); } - DEBUG(dbgs() << "DONE:\n" << F); + LLVM_DEBUG(dbgs() << "DONE:\n" << F); } - /// \brief Add MemorySanitizer instrumentation to a function. + /// Add MemorySanitizer instrumentation to a function. bool runOnFunction() { - MS.initializeCallbacks(*F.getParent()); - // In the presence of unreachable blocks, we may see Phi nodes with // incoming nodes from such blocks. Since InstVisitor skips unreachable // blocks, such nodes will not have any shadow value associated with them. @@ -941,7 +1002,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // Iterate all BBs in depth-first order and create shadow instructions // for all instructions (where applicable). // For PHI nodes we create dummy shadow PHIs which will be finalized later. - for (BasicBlock *BB : depth_first(&F.getEntryBlock())) + for (BasicBlock *BB : depth_first(ActualFnStart)) visit(*BB); // Finalize PHI nodes. @@ -961,22 +1022,22 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { InstrumentationList.size() + StoreList.size() > (unsigned)ClInstrumentationWithCallThreshold; - // Delayed instrumentation of StoreInst. - // This may add new checks to be inserted later. - materializeStores(InstrumentWithCalls); - // Insert shadow value checks. materializeChecks(InstrumentWithCalls); + // Delayed instrumentation of StoreInst. + // This may not add new address checks. + materializeStores(InstrumentWithCalls); + return true; } - /// \brief Compute the shadow type that corresponds to a given Value. + /// Compute the shadow type that corresponds to a given Value. Type *getShadowTy(Value *V) { return getShadowTy(V->getType()); } - /// \brief Compute the shadow type that corresponds to a given Type. + /// Compute the shadow type that corresponds to a given Type. 
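Review note: after this refactor, insertWarningFn is the single place that emits the report sequence, and materializeOneCheck wraps it in a cold branch (or calls it unconditionally when the converted shadow is a provably non-zero constant). The emitted shape, as a hedged sketch; the extern declarations match callbacks created earlier in this file:

    #include <cstdint>

    extern "C" {
    extern __thread uint32_t __msan_origin_tls;
    void __msan_warning_noreturn(void);
    }

    // Moral equivalent of one materialized check on flattened shadow S.
    inline void msanCheckSketch(uint64_t S, uint32_t Origin, bool TrackOrigins) {
      if (__builtin_expect(S != 0, 0)) {   // ColdCallWeights branch
        if (TrackOrigins)
          __msan_origin_tls = Origin;      // consumed by the runtime report
        __msan_warning_noreturn();         // __msan_warning when recovering
        __asm__ volatile("");              // EmptyAsm: prevents call merging
      }
    }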
Type *getShadowTy(Type *OrigTy) { if (!OrigTy->isSized()) { return nullptr; @@ -1000,21 +1061,21 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { for (unsigned i = 0, n = ST->getNumElements(); i < n; i++) Elements.push_back(getShadowTy(ST->getElementType(i))); StructType *Res = StructType::get(*MS.C, Elements, ST->isPacked()); - DEBUG(dbgs() << "getShadowTy: " << *ST << " ===> " << *Res << "\n"); + LLVM_DEBUG(dbgs() << "getShadowTy: " << *ST << " ===> " << *Res << "\n"); return Res; } uint32_t TypeSize = DL.getTypeSizeInBits(OrigTy); return IntegerType::get(*MS.C, TypeSize); } - /// \brief Flatten a vector type. + /// Flatten a vector type. Type *getShadowTyNoVec(Type *ty) { if (VectorType *vt = dyn_cast<VectorType>(ty)) return IntegerType::get(*MS.C, vt->getBitWidth()); return ty; } - /// \brief Convert a shadow value to it's flattened variant. + /// Convert a shadow value to it's flattened variant. Value *convertToShadowTyNoVec(Value *V, IRBuilder<> &IRB) { Type *Ty = V->getType(); Type *NoVecTy = getShadowTyNoVec(Ty); @@ -1022,7 +1083,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return IRB.CreateBitCast(V, NoVecTy); } - /// \brief Compute the integer shadow offset that corresponds to a given + /// Compute the integer shadow offset that corresponds to a given /// application address. /// /// Offset = (Addr & ~AndMask) ^ XorMask @@ -1041,18 +1102,18 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return OffsetLong; } - /// \brief Compute the shadow and origin addresses corresponding to a given + /// Compute the shadow and origin addresses corresponding to a given /// application address. /// /// Shadow = ShadowBase + Offset /// Origin = (OriginBase + Offset) & ~3ULL - std::pair<Value *, Value *> getShadowOriginPtrUserspace( - Value *Addr, IRBuilder<> &IRB, Type *ShadowTy, unsigned Alignment, - Instruction **FirstInsn) { + std::pair<Value *, Value *> getShadowOriginPtrUserspace(Value *Addr, + IRBuilder<> &IRB, + Type *ShadowTy, + unsigned Alignment) { Value *ShadowOffset = getShadowPtrOffset(Addr, IRB); Value *ShadowLong = ShadowOffset; uint64_t ShadowBase = MS.MapParams->ShadowBase; - *FirstInsn = dyn_cast<Instruction>(ShadowLong); if (ShadowBase != 0) { ShadowLong = IRB.CreateAdd(ShadowLong, @@ -1080,58 +1141,60 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { std::pair<Value *, Value *> getShadowOriginPtr(Value *Addr, IRBuilder<> &IRB, Type *ShadowTy, - unsigned Alignment) { - Instruction *FirstInsn = nullptr; + unsigned Alignment, + bool isStore) { std::pair<Value *, Value *> ret = - getShadowOriginPtrUserspace(Addr, IRB, ShadowTy, Alignment, &FirstInsn); + getShadowOriginPtrUserspace(Addr, IRB, ShadowTy, Alignment); return ret; } - /// \brief Compute the shadow address for a given function argument. + /// Compute the shadow address for a given function argument. /// /// Shadow = ParamTLS+ArgOffset. Value *getShadowPtrForArgument(Value *A, IRBuilder<> &IRB, int ArgOffset) { Value *Base = IRB.CreatePointerCast(MS.ParamTLS, MS.IntptrTy); - Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); + if (ArgOffset) + Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); return IRB.CreateIntToPtr(Base, PointerType::get(getShadowTy(A), 0), "_msarg"); } - /// \brief Compute the origin address for a given function argument. + /// Compute the origin address for a given function argument. 
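Review note: the Offset/Shadow/Origin formulas above deserve one concrete evaluation. A sketch plugging in this tree's Linux x86_64 parameters (AndMask 0, XorMask 0x500000000000, ShadowBase 0, OriginBase 0x100000000000; if I am misreading the tables, treat the constants as illustrative). The -msan-and-mask family of flags added earlier in this diff simply substitutes these four numbers:

    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint64_t AndMask = 0, XorMask = 0x500000000000ULL;
      const uint64_t ShadowBase = 0, OriginBase = 0x100000000000ULL;

      uint64_t Addr   = 0x700000001000ULL;              // some app address
      uint64_t Offset = (Addr & ~AndMask) ^ XorMask;    // getShadowPtrOffset
      uint64_t Shadow = ShadowBase + Offset;            // 0x200000001000
      uint64_t Origin = (OriginBase + Offset) & ~3ULL;  // 4-byte aligned slot

      std::printf("shadow %#llx origin %#llx\n",
                  (unsigned long long)Shadow, (unsigned long long)Origin);
      return 0;
    }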
Value *getOriginPtrForArgument(Value *A, IRBuilder<> &IRB, int ArgOffset) { if (!MS.TrackOrigins) return nullptr; Value *Base = IRB.CreatePointerCast(MS.ParamOriginTLS, MS.IntptrTy); - Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); + if (ArgOffset) + Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); return IRB.CreateIntToPtr(Base, PointerType::get(MS.OriginTy, 0), "_msarg_o"); } - /// \brief Compute the shadow address for a retval. + /// Compute the shadow address for a retval. Value *getShadowPtrForRetval(Value *A, IRBuilder<> &IRB) { return IRB.CreatePointerCast(MS.RetvalTLS, PointerType::get(getShadowTy(A), 0), "_msret"); } - /// \brief Compute the origin address for a retval. + /// Compute the origin address for a retval. Value *getOriginPtrForRetval(IRBuilder<> &IRB) { // We keep a single origin for the entire retval. Might be too optimistic. return MS.RetvalOriginTLS; } - /// \brief Set SV to be the shadow value for V. + /// Set SV to be the shadow value for V. void setShadow(Value *V, Value *SV) { assert(!ShadowMap.count(V) && "Values may only have one shadow"); ShadowMap[V] = PropagateShadow ? SV : getCleanShadow(V); } - /// \brief Set Origin to be the origin value for V. + /// Set Origin to be the origin value for V. void setOrigin(Value *V, Value *Origin) { if (!MS.TrackOrigins) return; assert(!OriginMap.count(V) && "Values may only have one origin"); - DEBUG(dbgs() << "ORIGIN: " << *V << " ==> " << *Origin << "\n"); + LLVM_DEBUG(dbgs() << "ORIGIN: " << *V << " ==> " << *Origin << "\n"); OriginMap[V] = Origin; } @@ -1142,7 +1205,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return Constant::getNullValue(ShadowTy); } - /// \brief Create a clean shadow value for a given value. + /// Create a clean shadow value for a given value. /// /// Clean shadow (all zeroes) means all bits of the value are defined /// (initialized). @@ -1150,7 +1213,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return getCleanShadow(V->getType()); } - /// \brief Create a dirty shadow of a given shadow type. + /// Create a dirty shadow of a given shadow type. Constant *getPoisonedShadow(Type *ShadowTy) { assert(ShadowTy); if (isa<IntegerType>(ShadowTy) || isa<VectorType>(ShadowTy)) @@ -1169,7 +1232,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { llvm_unreachable("Unexpected shadow type"); } - /// \brief Create a dirty shadow for a given value. + /// Create a dirty shadow for a given value. Constant *getPoisonedShadow(Value *V) { Type *ShadowTy = getShadowTy(V); if (!ShadowTy) @@ -1177,12 +1240,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return getPoisonedShadow(ShadowTy); } - /// \brief Create a clean (zero) origin. + /// Create a clean (zero) origin. Value *getCleanOrigin() { return Constant::getNullValue(MS.OriginTy); } - /// \brief Get the shadow value for a given Value. + /// Get the shadow value for a given Value. /// /// This function either returns the value set earlier with setShadow, /// or extracts if from ParamTLS (for function arguments). @@ -1194,7 +1257,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // For instructions the shadow is already stored in the map. 
Value *Shadow = ShadowMap[V]; if (!Shadow) { - DEBUG(dbgs() << "No shadow: " << *V << "\n" << *(I->getParent())); + LLVM_DEBUG(dbgs() << "No shadow: " << *V << "\n" << *(I->getParent())); (void)I; assert(Shadow && "No shadow for a value"); } @@ -1202,7 +1265,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } if (UndefValue *U = dyn_cast<UndefValue>(V)) { Value *AllOnes = PoisonUndef ? getPoisonedShadow(V) : getCleanShadow(V); - DEBUG(dbgs() << "Undef: " << *U << " ==> " << *AllOnes << "\n"); + LLVM_DEBUG(dbgs() << "Undef: " << *U << " ==> " << *AllOnes << "\n"); (void)U; return AllOnes; } @@ -1212,12 +1275,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (*ShadowPtr) return *ShadowPtr; Function *F = A->getParent(); - IRBuilder<> EntryIRB(F->getEntryBlock().getFirstNonPHI()); + IRBuilder<> EntryIRB(ActualFnStart->getFirstNonPHI()); unsigned ArgOffset = 0; const DataLayout &DL = F->getParent()->getDataLayout(); for (auto &FArg : F->args()) { if (!FArg.getType()->isSized()) { - DEBUG(dbgs() << "Arg is not sized\n"); + LLVM_DEBUG(dbgs() << "Arg is not sized\n"); continue; } unsigned Size = @@ -1237,7 +1300,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { ArgAlign = DL.getABITypeAlignment(EltType); } Value *CpShadowPtr = - getShadowOriginPtr(V, EntryIRB, EntryIRB.getInt8Ty(), ArgAlign) + getShadowOriginPtr(V, EntryIRB, EntryIRB.getInt8Ty(), ArgAlign, + /*isStore*/ true) .first; if (Overflow) { // ParamTLS overflow. @@ -1246,9 +1310,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Size, ArgAlign); } else { unsigned CopyAlign = std::min(ArgAlign, kShadowTLSAlignment); - Value *Cpy = - EntryIRB.CreateMemCpy(CpShadowPtr, Base, Size, CopyAlign); - DEBUG(dbgs() << " ByValCpy: " << *Cpy << "\n"); + Value *Cpy = EntryIRB.CreateMemCpy(CpShadowPtr, CopyAlign, Base, + CopyAlign, Size); + LLVM_DEBUG(dbgs() << " ByValCpy: " << *Cpy << "\n"); (void)Cpy; } *ShadowPtr = getCleanShadow(V); @@ -1261,8 +1325,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { EntryIRB.CreateAlignedLoad(Base, kShadowTLSAlignment); } } - DEBUG(dbgs() << " ARG: " << FArg << " ==> " << - **ShadowPtr << "\n"); + LLVM_DEBUG(dbgs() + << " ARG: " << FArg << " ==> " << **ShadowPtr << "\n"); if (MS.TrackOrigins && !Overflow) { Value *OriginPtr = getOriginPtrForArgument(&FArg, EntryIRB, ArgOffset); @@ -1280,12 +1344,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return getCleanShadow(V); } - /// \brief Get the shadow for i-th argument of the instruction I. + /// Get the shadow for i-th argument of the instruction I. Value *getShadow(Instruction *I, int i) { return getShadow(I->getOperand(i)); } - /// \brief Get the origin for a value. + /// Get the origin for a value. Value *getOrigin(Value *V) { if (!MS.TrackOrigins) return nullptr; if (!PropagateShadow) return getCleanOrigin(); @@ -1301,12 +1365,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return Origin; } - /// \brief Get the origin for i-th argument of the instruction I. + /// Get the origin for i-th argument of the instruction I. Value *getOrigin(Instruction *I, int i) { return getOrigin(I->getOperand(i)); } - /// \brief Remember the place where a shadow check should be inserted. + /// Remember the place where a shadow check should be inserted. 
/// /// This location will be later instrumented with a check that will print a /// UMR warning in runtime if the shadow value is not 0. @@ -1322,7 +1386,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { ShadowOriginAndInsertPoint(Shadow, Origin, OrigIns)); } - /// \brief Remember the place where a shadow check should be inserted. + /// Remember the place where a shadow check should be inserted. /// /// This location will be later instrumented with a check that will print a /// UMR warning in runtime if the value is not fully defined. @@ -1382,7 +1446,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { InstVisitor<MemorySanitizerVisitor>::visit(I); } - /// \brief Instrument LoadInst + /// Instrument LoadInst /// /// Loads the corresponding shadow and (optionally) origin. /// Optionally, checks that the load address is fully defined. @@ -1396,7 +1460,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { unsigned Alignment = I.getAlignment(); if (PropagateShadow) { std::tie(ShadowPtr, OriginPtr) = - getShadowOriginPtr(Addr, IRB, ShadowTy, Alignment); + getShadowOriginPtr(Addr, IRB, ShadowTy, Alignment, /*isStore*/ false); setShadow(&I, IRB.CreateAlignedLoad(ShadowPtr, Alignment, "_msld")); } else { setShadow(&I, getCleanShadow(&I)); @@ -1418,12 +1482,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } } - /// \brief Instrument StoreInst + /// Instrument StoreInst /// /// Stores the corresponding shadow and (optionally) origin. /// Optionally, checks that the store address is fully defined. void visitStoreInst(StoreInst &I) { StoreList.push_back(&I); + if (ClCheckAccessAddress) + insertShadowCheck(I.getPointerOperand(), &I); } void handleCASOrRMW(Instruction &I) { @@ -1431,8 +1497,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRBuilder<> IRB(&I); Value *Addr = I.getOperand(0); - Value *ShadowPtr = - getShadowOriginPtr(Addr, IRB, I.getType(), /*Alignment*/ 1).first; + Value *ShadowPtr = getShadowOriginPtr(Addr, IRB, I.getType(), + /*Alignment*/ 1, /*isStore*/ true) + .first; if (ClCheckAccessAddress) insertShadowCheck(Addr, &I); @@ -1536,7 +1603,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void visitFPExtInst(CastInst& I) { handleShadowOr(I); } void visitFPTruncInst(CastInst& I) { handleShadowOr(I); } - /// \brief Propagate shadow for bitwise AND. + /// Propagate shadow for bitwise AND. /// /// This code is exact, i.e. if, for example, a bit in the left argument /// is defined and 0, then neither the value not definedness of the @@ -1585,7 +1652,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOriginForNaryOp(I); } - /// \brief Default propagation of shadow and/or origin. + /// Default propagation of shadow and/or origin. /// /// This class implements the general case of shadow propagation, used in all /// cases where we don't know and/or don't care about what the operation @@ -1611,7 +1678,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Combiner(MemorySanitizerVisitor *MSV, IRBuilder<> &IRB) : IRB(IRB), MSV(MSV) {} - /// \brief Add a pair of shadow and origin values to the mix. + /// Add a pair of shadow and origin values to the mix. 
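Review note: the Combiner's default rule, in the elided body of Add, is to my reading a bitwise OR of the operand shadows after casting them to a common type: any possibly-uninitialized input bit conservatively poisons the whole result. A tiny worked instance with no LLVM dependencies:

    #include <cstdint>
    #include <cassert>

    int main() {
      // Convention: a 1 bit in the shadow means "may be uninitialized".
      uint32_t Sa = 0x000000FFu;   // low byte of operand a is unknown
      uint32_t Sb = 0x00000000u;   // operand b fully initialized
      uint32_t Sres = Sa | Sb;     // combined shadow for, e.g., a ^ b
      assert(Sres == 0x000000FFu); // result's low byte stays unknown
      return 0;
    }

For xor this rule is exact; for arithmetic like add it is only an approximation (carries can spread uncertainty), which is the "don't know and/or don't care" case the class comment describes.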
Combiner &Add(Value *OpShadow, Value *OpOrigin) { if (CombineShadow) { assert(OpShadow); @@ -1641,14 +1708,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return *this; } - /// \brief Add an application value to the mix. + /// Add an application value to the mix. Combiner &Add(Value *V) { Value *OpShadow = MSV->getShadow(V); Value *OpOrigin = MSV->MS.TrackOrigins ? MSV->getOrigin(V) : nullptr; return Add(OpShadow, OpOrigin); } - /// \brief Set the current combined values as the given instruction's shadow + /// Set the current combined values as the given instruction's shadow /// and origin. void Done(Instruction *I) { if (CombineShadow) { @@ -1666,7 +1733,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { using ShadowAndOriginCombiner = Combiner<true>; using OriginCombiner = Combiner<false>; - /// \brief Propagate origin for arbitrary operation. + /// Propagate origin for arbitrary operation. void setOriginForNaryOp(Instruction &I) { if (!MS.TrackOrigins) return; IRBuilder<> IRB(&I); @@ -1684,7 +1751,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Ty->getPrimitiveSizeInBits(); } - /// \brief Cast between two shadow types, extending or truncating as + /// Cast between two shadow types, extending or truncating as /// necessary. Value *CreateShadowCast(IRBuilder<> &IRB, Value *V, Type *dstTy, bool Signed = false) { @@ -1706,7 +1773,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // TODO: handle struct types. } - /// \brief Cast an application value to the type of its own shadow. + /// Cast an application value to the type of its own shadow. Value *CreateAppToShadowCast(IRBuilder<> &IRB, Value *V) { Type *ShadowTy = getShadowTy(V); if (V->getType() == ShadowTy) @@ -1717,7 +1784,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return IRB.CreateBitCast(V, ShadowTy); } - /// \brief Propagate shadow for arbitrary operation. + /// Propagate shadow for arbitrary operation. void handleShadowOr(Instruction &I) { IRBuilder<> IRB(&I); ShadowAndOriginCombiner SC(this, IRB); @@ -1726,7 +1793,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { SC.Done(&I); } - // \brief Handle multiplication by constant. + // Handle multiplication by constant. // // Handle a special case of multiplication by constant that may have one or // more zeros in the lower bits. This makes corresponding number of lower bits @@ -1788,7 +1855,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void visitSub(BinaryOperator &I) { handleShadowOr(I); } void visitXor(BinaryOperator &I) { handleShadowOr(I); } - void handleDiv(Instruction &I) { + void handleIntegerDiv(Instruction &I) { IRBuilder<> IRB(&I); // Strict on the second argument. 
     insertShadowCheck(I.getOperand(1), &I);
@@ -1796,14 +1863,17 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOrigin(&I, getOrigin(&I, 0));
   }
 
-  void visitUDiv(BinaryOperator &I) { handleDiv(I); }
-  void visitSDiv(BinaryOperator &I) { handleDiv(I); }
-  void visitFDiv(BinaryOperator &I) { handleDiv(I); }
-  void visitURem(BinaryOperator &I) { handleDiv(I); }
-  void visitSRem(BinaryOperator &I) { handleDiv(I); }
-  void visitFRem(BinaryOperator &I) { handleDiv(I); }
+  void visitUDiv(BinaryOperator &I) { handleIntegerDiv(I); }
+  void visitSDiv(BinaryOperator &I) { handleIntegerDiv(I); }
+  void visitURem(BinaryOperator &I) { handleIntegerDiv(I); }
+  void visitSRem(BinaryOperator &I) { handleIntegerDiv(I); }
+
+  // Floating point division is side-effect free. We cannot require that the
+  // divisor is fully initialized, so we must propagate shadow. See PR37523.
+  void visitFDiv(BinaryOperator &I) { handleShadowOr(I); }
+  void visitFRem(BinaryOperator &I) { handleShadowOr(I); }
 
-  /// \brief Instrument == and != comparisons.
+  /// Instrument == and != comparisons.
   ///
   /// Sometimes the comparison result is known even if some of the bits of the
   /// arguments are not.
@@ -1841,7 +1911,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
-  /// \brief Build the lowest possible value of V, taking into account V's
+  /// Build the lowest possible value of V, taking into account V's
   /// uninitialized bits.
   Value *getLowestPossibleValue(IRBuilder<> &IRB, Value *A, Value *Sa,
                                 bool isSigned) {
@@ -1858,7 +1928,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     }
   }
 
-  /// \brief Build the highest possible value of V, taking into account V's
+  /// Build the highest possible value of V, taking into account V's
   /// uninitialized bits.
   Value *getHighestPossibleValue(IRBuilder<> &IRB, Value *A, Value *Sa,
                                  bool isSigned) {
@@ -1875,7 +1945,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     }
   }
 
-  /// \brief Instrument relational comparisons.
+  /// Instrument relational comparisons.
   ///
   /// This function does exact shadow propagation for all relational
   /// comparisons of integers, pointers and vectors of those.
@@ -1908,7 +1978,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
-  /// \brief Instrument signed relational comparisons.
+  /// Instrument signed relational comparisons.
   ///
   /// Handle sign bit tests: x<0, x>=0, x<=-1, x>-1 by propagating the highest
   /// bit of the shadow. Everything else is delegated to handleShadowOr().
@@ -1992,7 +2062,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
   void visitAShr(BinaryOperator &I) { handleShift(I); }
   void visitLShr(BinaryOperator &I) { handleShift(I); }
 
-  /// \brief Instrument llvm.memmove
+  /// Instrument llvm.memmove
   ///
   /// At this point we don't know if llvm.memmove will be inlined or not.
   /// If we don't instrument it and it gets inlined,
@@ -2045,7 +2115,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     VAHelper->visitVACopyInst(I);
   }
 
-  /// \brief Handle vector store-like intrinsics.
+  /// Handle vector store-like intrinsics.
   ///
   /// Instrument intrinsics that look like a simple SIMD store: writes memory,
   /// has 1 pointer argument and 1 vector argument, returns void.
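The getLowestPossibleValue/getHighestPossibleValue helpers above bound a partially initialized value: bits set in the shadow are unknown, so for an unsigned value the extremes come from clearing or setting exactly those bits, and a relational compare is fully defined when the two candidate ranges cannot disagree. A minimal standalone model of the unsigned case (plain integers rather than the pass's IR-level code; the signed variant additionally accounts for the sign bit):

    #include <cstdint>
    #include <cstdio>

    // Shadow bits (Sa) mark unknown positions of A, so the smallest
    // candidate clears them and the largest sets them.
    static uint64_t lowest(uint64_t A, uint64_t Sa)  { return A & ~Sa; }
    static uint64_t highest(uint64_t A, uint64_t Sa) { return A | Sa; }

    int main() {
      uint64_t A = 0x10, Sa = 0x03; // some value in 16..19
      uint64_t B = 0x40, Sb = 0x00; // fully defined 64
      // A < B holds for every candidate pair, so the compare result is
      // defined even though A itself is not.
      bool AlwaysLess = highest(A, Sa) < lowest(B, Sb);
      std::printf("A<B is %s\n", AlwaysLess ? "always true" : "not provable");
    }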
@@ -2057,8 +2127,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
 
     // We don't know the pointer alignment (could be unaligned SSE store!).
     // Have to assume the worst case.
-    std::tie(ShadowPtr, OriginPtr) =
-        getShadowOriginPtr(Addr, IRB, Shadow->getType(), /*Alignment*/ 1);
+    std::tie(ShadowPtr, OriginPtr) = getShadowOriginPtr(
+        Addr, IRB, Shadow->getType(), /*Alignment*/ 1, /*isStore*/ true);
     IRB.CreateAlignedStore(Shadow, ShadowPtr, 1);
 
     if (ClCheckAccessAddress)
@@ -2069,7 +2139,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     return true;
   }
 
-  /// \brief Handle vector load-like intrinsics.
+  /// Handle vector load-like intrinsics.
   ///
   /// Instrument intrinsics that look like a simple SIMD load: reads memory,
   /// has 1 pointer argument, returns a vector.
@@ -2084,7 +2154,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       // Have to assume the worst case.
       unsigned Alignment = 1;
       std::tie(ShadowPtr, OriginPtr) =
-          getShadowOriginPtr(Addr, IRB, ShadowTy, Alignment);
+          getShadowOriginPtr(Addr, IRB, ShadowTy, Alignment, /*isStore*/ false);
       setShadow(&I, IRB.CreateAlignedLoad(ShadowPtr, Alignment, "_msld"));
     } else {
       setShadow(&I, getCleanShadow(&I));
@@ -2102,7 +2172,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     return true;
   }
 
-  /// \brief Handle (SIMD arithmetic)-like intrinsics.
+  /// Handle (SIMD arithmetic)-like intrinsics.
   ///
   /// Instrument intrinsics with any number of arguments of the same type,
   /// equal to the return type. The type should be simple (no aggregates or
@@ -2132,7 +2202,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     return true;
   }
 
-  /// \brief Heuristically instrument unknown intrinsics.
+  /// Heuristically instrument unknown intrinsics.
   ///
   /// The main purpose of this code is to do something reasonable with all
   /// random intrinsics we might encounter, most importantly - SIMD intrinsics.
@@ -2182,7 +2252,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOrigin(&I, getOrigin(Op));
   }
 
-  // \brief Instrument vector convert intrinsic.
+  // Instrument vector convert intrinsic.
   //
   // This function instruments intrinsics like cvtsi2ss:
   // %Out = int_xxx_cvtyyy(%ConvertOp)
@@ -2285,7 +2355,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     return IRB.CreateSExt(S2, T);
   }
 
-  // \brief Instrument vector shift intrinsic.
+  // Instrument vector shift intrinsic.
   //
   // This function instruments intrinsics like int_x86_avx2_psll_w.
   // Intrinsic shifts %In by %ShiftSize bits.
@@ -2310,14 +2380,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
-  // \brief Get an X86_MMX-sized vector type.
+  // Get an X86_MMX-sized vector type.
   Type *getMMXVectorTy(unsigned EltSizeInBits) {
     const unsigned X86_MMXSizeInBits = 64;
     return VectorType::get(IntegerType::get(*MS.C, EltSizeInBits),
                            X86_MMXSizeInBits / EltSizeInBits);
   }
 
-  // \brief Returns a signed counterpart for an (un)signed-saturate-and-pack
+  // Returns a signed counterpart for an (un)signed-saturate-and-pack
   // intrinsic.
   Intrinsic::ID getSignedPackIntrinsic(Intrinsic::ID id) {
     switch (id) {
@@ -2348,7 +2418,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     }
   }
 
-  // \brief Instrument vector pack intrinsic.
+  // Instrument vector pack intrinsic.
  //
  // This function instruments intrinsics like x86_mmx_packsswb, that
  // packs elements of 2 input vectors into half as many bits with saturation.
@@ -2391,7 +2461,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
-  // \brief Instrument sum-of-absolute-differences intrinsic.
+  // Instrument sum-of-absolute-differences intrinsic.
   void handleVectorSadIntrinsic(IntrinsicInst &I) {
     const unsigned SignificantBitsPerResultElement = 16;
     bool isX86_MMX = I.getOperand(0)->getType()->isX86_MMXTy();
@@ -2410,7 +2480,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
-  // \brief Instrument multiply-add intrinsic.
+  // Instrument multiply-add intrinsic.
   void handleVectorPmaddIntrinsic(IntrinsicInst &I,
                                   unsigned EltSizeInBits = 0) {
     bool isX86_MMX = I.getOperand(0)->getType()->isX86_MMXTy();
@@ -2425,7 +2495,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
-  // \brief Instrument compare-packed intrinsic.
+  // Instrument compare-packed intrinsic.
   // Basically, an or followed by sext(icmp ne 0) to end up with all-zeros or
   // all-ones shadow.
   void handleVectorComparePackedIntrinsic(IntrinsicInst &I) {
@@ -2438,7 +2508,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
-  // \brief Instrument compare-scalar intrinsic.
+  // Instrument compare-scalar intrinsic.
   // This handles both cmp* intrinsics which return the result in the first
   // element of a vector, and comi* which return the result as i32.
   void handleVectorCompareScalarIntrinsic(IntrinsicInst &I) {
@@ -2453,7 +2523,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     IRBuilder<> IRB(&I);
     Value* Addr = I.getArgOperand(0);
     Type *Ty = IRB.getInt32Ty();
-    Value *ShadowPtr = getShadowOriginPtr(Addr, IRB, Ty, /*Alignment*/ 1).first;
+    Value *ShadowPtr =
+        getShadowOriginPtr(Addr, IRB, Ty, /*Alignment*/ 1, /*isStore*/ true)
+            .first;
 
     IRB.CreateStore(getCleanShadow(Ty),
                     IRB.CreatePointerCast(ShadowPtr, Ty->getPointerTo()));
@@ -2471,7 +2543,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     unsigned Alignment = 1;
     Value *ShadowPtr, *OriginPtr;
     std::tie(ShadowPtr, OriginPtr) =
-        getShadowOriginPtr(Addr, IRB, Ty, Alignment);
+        getShadowOriginPtr(Addr, IRB, Ty, Alignment, /*isStore*/ false);
 
     if (ClCheckAccessAddress)
       insertShadowCheck(Addr, &I);
@@ -2482,11 +2554,98 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     insertShadowCheck(Shadow, Origin, &I);
   }
 
+  void handleMaskedStore(IntrinsicInst &I) {
+    IRBuilder<> IRB(&I);
+    Value *V = I.getArgOperand(0);
+    Value *Addr = I.getArgOperand(1);
+    unsigned Align = cast<ConstantInt>(I.getArgOperand(2))->getZExtValue();
+    Value *Mask = I.getArgOperand(3);
+    Value *Shadow = getShadow(V);
+
+    Value *ShadowPtr;
+    Value *OriginPtr;
+    std::tie(ShadowPtr, OriginPtr) = getShadowOriginPtr(
+        Addr, IRB, Shadow->getType(), Align, /*isStore*/ true);
+
+    if (ClCheckAccessAddress) {
+      insertShadowCheck(Addr, &I);
+      // Uninitialized mask is kind of like uninitialized address, but not as
+      // scary.
+ insertShadowCheck(Mask, &I); + } + + IRB.CreateMaskedStore(Shadow, ShadowPtr, Align, Mask); + + if (MS.TrackOrigins) { + auto &DL = F.getParent()->getDataLayout(); + paintOrigin(IRB, getOrigin(V), OriginPtr, + DL.getTypeStoreSize(Shadow->getType()), + std::max(Align, kMinOriginAlignment)); + } + } + + bool handleMaskedLoad(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + Value *Addr = I.getArgOperand(0); + unsigned Align = cast<ConstantInt>(I.getArgOperand(1))->getZExtValue(); + Value *Mask = I.getArgOperand(2); + Value *PassThru = I.getArgOperand(3); + + Type *ShadowTy = getShadowTy(&I); + Value *ShadowPtr, *OriginPtr; + if (PropagateShadow) { + std::tie(ShadowPtr, OriginPtr) = + getShadowOriginPtr(Addr, IRB, ShadowTy, Align, /*isStore*/ false); + setShadow(&I, IRB.CreateMaskedLoad(ShadowPtr, Align, Mask, + getShadow(PassThru), "_msmaskedld")); + } else { + setShadow(&I, getCleanShadow(&I)); + } + + if (ClCheckAccessAddress) { + insertShadowCheck(Addr, &I); + insertShadowCheck(Mask, &I); + } + + if (MS.TrackOrigins) { + if (PropagateShadow) { + // Choose between PassThru's and the loaded value's origins. + Value *MaskedPassThruShadow = IRB.CreateAnd( + getShadow(PassThru), IRB.CreateSExt(IRB.CreateNeg(Mask), ShadowTy)); + + Value *Acc = IRB.CreateExtractElement( + MaskedPassThruShadow, ConstantInt::get(IRB.getInt32Ty(), 0)); + for (int i = 1, N = PassThru->getType()->getVectorNumElements(); i < N; + ++i) { + Value *More = IRB.CreateExtractElement( + MaskedPassThruShadow, ConstantInt::get(IRB.getInt32Ty(), i)); + Acc = IRB.CreateOr(Acc, More); + } + + Value *Origin = IRB.CreateSelect( + IRB.CreateICmpNE(Acc, Constant::getNullValue(Acc->getType())), + getOrigin(PassThru), IRB.CreateLoad(OriginPtr)); + + setOrigin(&I, Origin); + } else { + setOrigin(&I, getCleanOrigin()); + } + } + return true; + } + + void visitIntrinsicInst(IntrinsicInst &I) { switch (I.getIntrinsicID()) { case Intrinsic::bswap: handleBswap(I); break; + case Intrinsic::masked_store: + handleMaskedStore(I); + break; + case Intrinsic::masked_load: + handleMaskedLoad(I); + break; case Intrinsic::x86_sse_stmxcsr: handleStmxcsr(I); break; @@ -2501,20 +2660,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { case Intrinsic::x86_avx512_cvttss2usi: case Intrinsic::x86_avx512_cvttsd2usi64: case Intrinsic::x86_avx512_cvttsd2usi: - case Intrinsic::x86_avx512_cvtusi2sd: case Intrinsic::x86_avx512_cvtusi2ss: case Intrinsic::x86_avx512_cvtusi642sd: case Intrinsic::x86_avx512_cvtusi642ss: case Intrinsic::x86_sse2_cvtsd2si64: case Intrinsic::x86_sse2_cvtsd2si: case Intrinsic::x86_sse2_cvtsd2ss: - case Intrinsic::x86_sse2_cvtsi2sd: - case Intrinsic::x86_sse2_cvtsi642sd: - case Intrinsic::x86_sse2_cvtss2sd: case Intrinsic::x86_sse2_cvttsd2si64: case Intrinsic::x86_sse2_cvttsd2si: - case Intrinsic::x86_sse_cvtsi2ss: - case Intrinsic::x86_sse_cvtsi642ss: case Intrinsic::x86_sse_cvtss2si64: case Intrinsic::x86_sse_cvtss2si: case Intrinsic::x86_sse_cvttss2si64: @@ -2715,7 +2868,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // outputs as clean. Note that any side effects of the inline asm that are // not immediately visible in its constraints are not handled. 
if (Call->isInlineAsm()) { - visitInstruction(I); + if (ClHandleAsmConservative) + visitAsmInstruction(I); + else + visitInstruction(I); return; } @@ -2738,13 +2894,13 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRBuilder<> IRB(&I); unsigned ArgOffset = 0; - DEBUG(dbgs() << " CallSite: " << I << "\n"); + LLVM_DEBUG(dbgs() << " CallSite: " << I << "\n"); for (CallSite::arg_iterator ArgIt = CS.arg_begin(), End = CS.arg_end(); ArgIt != End; ++ArgIt) { Value *A = *ArgIt; unsigned i = ArgIt - CS.arg_begin(); if (!A->getType()->isSized()) { - DEBUG(dbgs() << "Arg " << i << " is not sized: " << I << "\n"); + LLVM_DEBUG(dbgs() << "Arg " << i << " is not sized: " << I << "\n"); continue; } unsigned Size = 0; @@ -2754,8 +2910,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // __msan_param_tls. Value *ArgShadow = getShadow(A); Value *ArgShadowBase = getShadowPtrForArgument(A, IRB, ArgOffset); - DEBUG(dbgs() << " Arg#" << i << ": " << *A << - " Shadow: " << *ArgShadow << "\n"); + LLVM_DEBUG(dbgs() << " Arg#" << i << ": " << *A + << " Shadow: " << *ArgShadow << "\n"); bool ArgIsInitialized = false; const DataLayout &DL = F.getParent()->getDataLayout(); if (CS.paramHasAttr(i, Attribute::ByVal)) { @@ -2765,10 +2921,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (ArgOffset + Size > kParamTLSSize) break; unsigned ParamAlignment = CS.getParamAlignment(i); unsigned Alignment = std::min(ParamAlignment, kShadowTLSAlignment); - Value *AShadowPtr = - getShadowOriginPtr(A, IRB, IRB.getInt8Ty(), Alignment).first; + Value *AShadowPtr = getShadowOriginPtr(A, IRB, IRB.getInt8Ty(), + Alignment, /*isStore*/ false) + .first; - Store = IRB.CreateMemCpy(ArgShadowBase, AShadowPtr, Size, Alignment); + Store = IRB.CreateMemCpy(ArgShadowBase, Alignment, AShadowPtr, + Alignment, Size); } else { Size = DL.getTypeAllocSize(A->getType()); if (ArgOffset + Size > kParamTLSSize) break; @@ -2782,10 +2940,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { getOriginPtrForArgument(A, IRB, ArgOffset)); (void)Store; assert(Size != 0 && Store != nullptr); - DEBUG(dbgs() << " Param:" << *Store << "\n"); + LLVM_DEBUG(dbgs() << " Param:" << *Store << "\n"); ArgOffset += alignTo(Size, 8); } - DEBUG(dbgs() << " done with call args\n"); + LLVM_DEBUG(dbgs() << " done with call args\n"); FunctionType *FT = cast<FunctionType>(CS.getCalledValue()->getType()->getContainedType(0)); @@ -2888,8 +3046,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRB.CreateCall(MS.MsanPoisonStackFn, {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len}); } else { - Value *ShadowBase = - getShadowOriginPtr(&I, IRB, IRB.getInt8Ty(), I.getAlignment()).first; + Value *ShadowBase = getShadowOriginPtr(&I, IRB, IRB.getInt8Ty(), + I.getAlignment(), /*isStore*/ true) + .first; Value *PoisonValue = IRB.getInt8(PoisonStack ? 
ClPoisonStackPattern : 0); IRB.CreateMemSet(ShadowBase, PoisonValue, Len, I.getAlignment()); @@ -2991,24 +3150,24 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void visitExtractValueInst(ExtractValueInst &I) { IRBuilder<> IRB(&I); Value *Agg = I.getAggregateOperand(); - DEBUG(dbgs() << "ExtractValue: " << I << "\n"); + LLVM_DEBUG(dbgs() << "ExtractValue: " << I << "\n"); Value *AggShadow = getShadow(Agg); - DEBUG(dbgs() << " AggShadow: " << *AggShadow << "\n"); + LLVM_DEBUG(dbgs() << " AggShadow: " << *AggShadow << "\n"); Value *ResShadow = IRB.CreateExtractValue(AggShadow, I.getIndices()); - DEBUG(dbgs() << " ResShadow: " << *ResShadow << "\n"); + LLVM_DEBUG(dbgs() << " ResShadow: " << *ResShadow << "\n"); setShadow(&I, ResShadow); setOriginForNaryOp(I); } void visitInsertValueInst(InsertValueInst &I) { IRBuilder<> IRB(&I); - DEBUG(dbgs() << "InsertValue: " << I << "\n"); + LLVM_DEBUG(dbgs() << "InsertValue: " << I << "\n"); Value *AggShadow = getShadow(I.getAggregateOperand()); Value *InsShadow = getShadow(I.getInsertedValueOperand()); - DEBUG(dbgs() << " AggShadow: " << *AggShadow << "\n"); - DEBUG(dbgs() << " InsShadow: " << *InsShadow << "\n"); + LLVM_DEBUG(dbgs() << " AggShadow: " << *AggShadow << "\n"); + LLVM_DEBUG(dbgs() << " InsShadow: " << *InsShadow << "\n"); Value *Res = IRB.CreateInsertValue(AggShadow, InsShadow, I.getIndices()); - DEBUG(dbgs() << " Res: " << *Res << "\n"); + LLVM_DEBUG(dbgs() << " Res: " << *Res << "\n"); setShadow(&I, Res); setOriginForNaryOp(I); } @@ -3023,25 +3182,58 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } void visitResumeInst(ResumeInst &I) { - DEBUG(dbgs() << "Resume: " << I << "\n"); + LLVM_DEBUG(dbgs() << "Resume: " << I << "\n"); // Nothing to do here. } void visitCleanupReturnInst(CleanupReturnInst &CRI) { - DEBUG(dbgs() << "CleanupReturn: " << CRI << "\n"); + LLVM_DEBUG(dbgs() << "CleanupReturn: " << CRI << "\n"); // Nothing to do here. } void visitCatchReturnInst(CatchReturnInst &CRI) { - DEBUG(dbgs() << "CatchReturn: " << CRI << "\n"); + LLVM_DEBUG(dbgs() << "CatchReturn: " << CRI << "\n"); // Nothing to do here. } + void visitAsmInstruction(Instruction &I) { + // Conservative inline assembly handling: check for poisoned shadow of + // asm() arguments, then unpoison the result and all the memory locations + // pointed to by those arguments. + CallInst *CI = dyn_cast<CallInst>(&I); + + for (size_t i = 0, n = CI->getNumOperands(); i < n; i++) { + Value *Operand = CI->getOperand(i); + if (Operand->getType()->isSized()) + insertShadowCheck(Operand, &I); + } + setShadow(&I, getCleanShadow(&I)); + setOrigin(&I, getCleanOrigin()); + IRBuilder<> IRB(&I); + IRB.SetInsertPoint(I.getNextNode()); + for (size_t i = 0, n = CI->getNumOperands(); i < n; i++) { + Value *Operand = CI->getOperand(i); + Type *OpType = Operand->getType(); + if (!OpType->isPointerTy()) + continue; + Type *ElType = OpType->getPointerElementType(); + if (!ElType->isSized()) + continue; + Value *ShadowPtr, *OriginPtr; + std::tie(ShadowPtr, OriginPtr) = getShadowOriginPtr( + Operand, IRB, ElType, /*Alignment*/ 1, /*isStore*/ true); + Value *CShadow = getCleanShadow(ElType); + IRB.CreateStore( + CShadow, + IRB.CreatePointerCast(ShadowPtr, CShadow->getType()->getPointerTo())); + } + } + void visitInstruction(Instruction &I) { // Everything else: stop propagating and check for poisoned shadow. 
if (ClDumpStrictInstructions) dumpInst(I); - DEBUG(dbgs() << "DEFAULT: " << I << "\n"); + LLVM_DEBUG(dbgs() << "DEFAULT: " << I << "\n"); for (size_t i = 0, n = I.getNumOperands(); i < n; i++) { Value *Operand = I.getOperand(i); if (Operand->getType()->isSized()) @@ -3052,7 +3244,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } }; -/// \brief AMD64-specific implementation of VarArgHelper. +/// AMD64-specific implementation of VarArgHelper. struct VarArgAMD64Helper : public VarArgHelper { // An unfortunate workaround for asymmetric lowering of va_arg stuff. // See a comment in visitCallSite for more details. @@ -3116,10 +3308,12 @@ struct VarArgAMD64Helper : public VarArgHelper { getShadowPtrForVAArgument(RealTy, IRB, OverflowOffset); OverflowOffset += alignTo(ArgSize, 8); Value *ShadowPtr, *OriginPtr; - std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr( - A, IRB, IRB.getInt8Ty(), kShadowTLSAlignment); + std::tie(ShadowPtr, OriginPtr) = + MSV.getShadowOriginPtr(A, IRB, IRB.getInt8Ty(), kShadowTLSAlignment, + /*isStore*/ false); - IRB.CreateMemCpy(ShadowBase, ShadowPtr, ArgSize, kShadowTLSAlignment); + IRB.CreateMemCpy(ShadowBase, kShadowTLSAlignment, ShadowPtr, + kShadowTLSAlignment, ArgSize); } else { ArgKind AK = classifyArgument(A); if (AK == AK_GeneralPurpose && GpOffset >= AMD64GpEndOffset) @@ -3157,7 +3351,7 @@ struct VarArgAMD64Helper : public VarArgHelper { IRB.CreateStore(OverflowSize, MS.VAArgOverflowSizeTLS); } - /// \brief Compute the shadow address for a given va_arg. + /// Compute the shadow address for a given va_arg. Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, int ArgOffset) { Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); @@ -3172,7 +3366,8 @@ struct VarArgAMD64Helper : public VarArgHelper { Value *ShadowPtr, *OriginPtr; unsigned Alignment = 8; std::tie(ShadowPtr, OriginPtr) = - MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment); + MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment, + /*isStore*/ true); // Unpoison the whole __va_list_tag. // FIXME: magic ABI constants. @@ -3200,13 +3395,13 @@ struct VarArgAMD64Helper : public VarArgHelper { if (!VAStartInstrumentationList.empty()) { // If there is a va_start in this function, make a backup copy of // va_arg_tls somewhere in the function entry block. - IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI()); + IRBuilder<> IRB(MSV.ActualFnStart->getFirstNonPHI()); VAArgOverflowSize = IRB.CreateLoad(MS.VAArgOverflowSizeTLS); Value *CopySize = IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, AMD64FpEndOffset), VAArgOverflowSize); VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); - IRB.CreateMemCpy(VAArgTLSCopy, MS.VAArgTLS, CopySize, 8); + IRB.CreateMemCpy(VAArgTLSCopy, 8, MS.VAArgTLS, 8, CopySize); } // Instrument va_start. 
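Much of the mechanical churn in this file is the CreateMemCpy signature change visible above: the IRBuilder of this LLVM era takes separate destination and source alignments instead of one shared value. A hedged sketch of the call-site migration, with Dst, Src and Size as stand-in values rather than anything from the pass:

    #include "llvm/IR/IRBuilder.h"

    using namespace llvm;

    // Illustrates the old/new CreateMemCpy forms side by side; kAlign
    // mirrors the single alignment the old calls applied to both operands.
    static void copyShadow(IRBuilder<> &IRB, Value *Dst, Value *Src,
                           Value *Size) {
      const unsigned kAlign = 8;
      // Pre-patch form:  IRB.CreateMemCpy(Dst, Src, Size, kAlign);
      // Post-patch form: destination and source alignments are separate
      // parameters, even when, as here, they happen to be equal.
      IRB.CreateMemCpy(Dst, kAlign, Src, kAlign, Size);
    }

The split matters because the two buffers can legitimately have different alignments; forcing one value either over-promised alignment or pessimized the copy.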
@@ -3219,33 +3414,33 @@ struct VarArgAMD64Helper : public VarArgHelper { Value *RegSaveAreaPtrPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, 16)), - Type::getInt64PtrTy(*MS.C)); + PointerType::get(Type::getInt64PtrTy(*MS.C), 0)); Value *RegSaveAreaPtr = IRB.CreateLoad(RegSaveAreaPtrPtr); Value *RegSaveAreaShadowPtr, *RegSaveAreaOriginPtr; unsigned Alignment = 16; std::tie(RegSaveAreaShadowPtr, RegSaveAreaOriginPtr) = MSV.getShadowOriginPtr(RegSaveAreaPtr, IRB, IRB.getInt8Ty(), - Alignment); - IRB.CreateMemCpy(RegSaveAreaShadowPtr, VAArgTLSCopy, AMD64FpEndOffset, - Alignment); + Alignment, /*isStore*/ true); + IRB.CreateMemCpy(RegSaveAreaShadowPtr, Alignment, VAArgTLSCopy, Alignment, + AMD64FpEndOffset); Value *OverflowArgAreaPtrPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, 8)), - Type::getInt64PtrTy(*MS.C)); + PointerType::get(Type::getInt64PtrTy(*MS.C), 0)); Value *OverflowArgAreaPtr = IRB.CreateLoad(OverflowArgAreaPtrPtr); Value *OverflowArgAreaShadowPtr, *OverflowArgAreaOriginPtr; std::tie(OverflowArgAreaShadowPtr, OverflowArgAreaOriginPtr) = MSV.getShadowOriginPtr(OverflowArgAreaPtr, IRB, IRB.getInt8Ty(), - Alignment); + Alignment, /*isStore*/ true); Value *SrcPtr = IRB.CreateConstGEP1_32(IRB.getInt8Ty(), VAArgTLSCopy, AMD64FpEndOffset); - IRB.CreateMemCpy(OverflowArgAreaShadowPtr, SrcPtr, VAArgOverflowSize, - Alignment); + IRB.CreateMemCpy(OverflowArgAreaShadowPtr, Alignment, SrcPtr, Alignment, + VAArgOverflowSize); } } }; -/// \brief MIPS64-specific implementation of VarArgHelper. +/// MIPS64-specific implementation of VarArgHelper. struct VarArgMIPS64Helper : public VarArgHelper { Function &F; MemorySanitizer &MS; @@ -3286,7 +3481,7 @@ struct VarArgMIPS64Helper : public VarArgHelper { IRB.CreateStore(TotalVAArgSize, MS.VAArgOverflowSizeTLS); } - /// \brief Compute the shadow address for a given va_arg. + /// Compute the shadow address for a given va_arg. 
Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, int ArgOffset) { Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); @@ -3301,8 +3496,8 @@ struct VarArgMIPS64Helper : public VarArgHelper { Value *VAListTag = I.getArgOperand(0); Value *ShadowPtr, *OriginPtr; unsigned Alignment = 8; - std::tie(ShadowPtr, OriginPtr) = - MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment); + std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr( + VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true); IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), /* size */ 8, Alignment, false); } @@ -3313,8 +3508,8 @@ struct VarArgMIPS64Helper : public VarArgHelper { Value *VAListTag = I.getArgOperand(0); Value *ShadowPtr, *OriginPtr; unsigned Alignment = 8; - std::tie(ShadowPtr, OriginPtr) = - MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment); + std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr( + VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true); IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), /* size */ 8, Alignment, false); } @@ -3322,7 +3517,7 @@ struct VarArgMIPS64Helper : public VarArgHelper { void finalizeInstrumentation() override { assert(!VAArgSize && !VAArgTLSCopy && "finalizeInstrumentation called twice"); - IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI()); + IRBuilder<> IRB(MSV.ActualFnStart->getFirstNonPHI()); VAArgSize = IRB.CreateLoad(MS.VAArgOverflowSizeTLS); Value *CopySize = IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, 0), VAArgSize); @@ -3331,7 +3526,7 @@ struct VarArgMIPS64Helper : public VarArgHelper { // If there is a va_start in this function, make a backup copy of // va_arg_tls somewhere in the function entry block. VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); - IRB.CreateMemCpy(VAArgTLSCopy, MS.VAArgTLS, CopySize, 8); + IRB.CreateMemCpy(VAArgTLSCopy, 8, MS.VAArgTLS, 8, CopySize); } // Instrument va_start. @@ -3341,20 +3536,21 @@ struct VarArgMIPS64Helper : public VarArgHelper { IRBuilder<> IRB(OrigInst->getNextNode()); Value *VAListTag = OrigInst->getArgOperand(0); Value *RegSaveAreaPtrPtr = - IRB.CreateIntToPtr(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), - Type::getInt64PtrTy(*MS.C)); + IRB.CreateIntToPtr(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), + PointerType::get(Type::getInt64PtrTy(*MS.C), 0)); Value *RegSaveAreaPtr = IRB.CreateLoad(RegSaveAreaPtrPtr); Value *RegSaveAreaShadowPtr, *RegSaveAreaOriginPtr; unsigned Alignment = 8; std::tie(RegSaveAreaShadowPtr, RegSaveAreaOriginPtr) = MSV.getShadowOriginPtr(RegSaveAreaPtr, IRB, IRB.getInt8Ty(), - Alignment); - IRB.CreateMemCpy(RegSaveAreaShadowPtr, VAArgTLSCopy, CopySize, Alignment); + Alignment, /*isStore*/ true); + IRB.CreateMemCpy(RegSaveAreaShadowPtr, Alignment, VAArgTLSCopy, Alignment, + CopySize); } } }; -/// \brief AArch64-specific implementation of VarArgHelper. +/// AArch64-specific implementation of VarArgHelper. 
struct VarArgAArch64Helper : public VarArgHelper { static const unsigned kAArch64GrArgSize = 64; static const unsigned kAArch64VrArgSize = 128; @@ -3461,8 +3657,8 @@ struct VarArgAArch64Helper : public VarArgHelper { Value *VAListTag = I.getArgOperand(0); Value *ShadowPtr, *OriginPtr; unsigned Alignment = 8; - std::tie(ShadowPtr, OriginPtr) = - MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment); + std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr( + VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true); IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), /* size */ 32, Alignment, false); } @@ -3473,8 +3669,8 @@ struct VarArgAArch64Helper : public VarArgHelper { Value *VAListTag = I.getArgOperand(0); Value *ShadowPtr, *OriginPtr; unsigned Alignment = 8; - std::tie(ShadowPtr, OriginPtr) = - MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment); + std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr( + VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true); IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), /* size */ 32, Alignment, false); } @@ -3506,13 +3702,13 @@ struct VarArgAArch64Helper : public VarArgHelper { if (!VAStartInstrumentationList.empty()) { // If there is a va_start in this function, make a backup copy of // va_arg_tls somewhere in the function entry block. - IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI()); + IRBuilder<> IRB(MSV.ActualFnStart->getFirstNonPHI()); VAArgOverflowSize = IRB.CreateLoad(MS.VAArgOverflowSizeTLS); Value *CopySize = IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, AArch64VAEndOffset), VAArgOverflowSize); VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); - IRB.CreateMemCpy(VAArgTLSCopy, MS.VAArgTLS, CopySize, 8); + IRB.CreateMemCpy(VAArgTLSCopy, 8, MS.VAArgTLS, 8, CopySize); } Value *GrArgSize = ConstantInt::get(MS.IntptrTy, kAArch64GrArgSize); @@ -3563,14 +3759,14 @@ struct VarArgAArch64Helper : public VarArgHelper { Value *GrRegSaveAreaShadowPtr = MSV.getShadowOriginPtr(GrRegSaveAreaPtr, IRB, IRB.getInt8Ty(), - /*Alignment*/ 8) + /*Alignment*/ 8, /*isStore*/ true) .first; Value *GrSrcPtr = IRB.CreateInBoundsGEP(IRB.getInt8Ty(), VAArgTLSCopy, GrRegSaveAreaShadowPtrOff); Value *GrCopySize = IRB.CreateSub(GrArgSize, GrRegSaveAreaShadowPtrOff); - IRB.CreateMemCpy(GrRegSaveAreaShadowPtr, GrSrcPtr, GrCopySize, 8); + IRB.CreateMemCpy(GrRegSaveAreaShadowPtr, 8, GrSrcPtr, 8, GrCopySize); // Again, but for FP/SIMD values. Value *VrRegSaveAreaShadowPtrOff = @@ -3578,7 +3774,7 @@ struct VarArgAArch64Helper : public VarArgHelper { Value *VrRegSaveAreaShadowPtr = MSV.getShadowOriginPtr(VrRegSaveAreaPtr, IRB, IRB.getInt8Ty(), - /*Alignment*/ 8) + /*Alignment*/ 8, /*isStore*/ true) .first; Value *VrSrcPtr = IRB.CreateInBoundsGEP( @@ -3588,25 +3784,25 @@ struct VarArgAArch64Helper : public VarArgHelper { VrRegSaveAreaShadowPtrOff); Value *VrCopySize = IRB.CreateSub(VrArgSize, VrRegSaveAreaShadowPtrOff); - IRB.CreateMemCpy(VrRegSaveAreaShadowPtr, VrSrcPtr, VrCopySize, 8); + IRB.CreateMemCpy(VrRegSaveAreaShadowPtr, 8, VrSrcPtr, 8, VrCopySize); // And finally for remaining arguments. 
Value *StackSaveAreaShadowPtr = MSV.getShadowOriginPtr(StackSaveAreaPtr, IRB, IRB.getInt8Ty(), - /*Alignment*/ 16) + /*Alignment*/ 16, /*isStore*/ true) .first; Value *StackSrcPtr = IRB.CreateInBoundsGEP(IRB.getInt8Ty(), VAArgTLSCopy, IRB.getInt32(AArch64VAEndOffset)); - IRB.CreateMemCpy(StackSaveAreaShadowPtr, StackSrcPtr, - VAArgOverflowSize, 16); + IRB.CreateMemCpy(StackSaveAreaShadowPtr, 16, StackSrcPtr, 16, + VAArgOverflowSize); } } }; -/// \brief PowerPC64-specific implementation of VarArgHelper. +/// PowerPC64-specific implementation of VarArgHelper. struct VarArgPowerPC64Helper : public VarArgHelper { Function &F; MemorySanitizer &MS; @@ -3657,9 +3853,10 @@ struct VarArgPowerPC64Helper : public VarArgHelper { VAArgOffset - VAArgBase); Value *AShadowPtr, *AOriginPtr; std::tie(AShadowPtr, AOriginPtr) = MSV.getShadowOriginPtr( - A, IRB, IRB.getInt8Ty(), kShadowTLSAlignment); + A, IRB, IRB.getInt8Ty(), kShadowTLSAlignment, /*isStore*/ false); - IRB.CreateMemCpy(Base, AShadowPtr, ArgSize, kShadowTLSAlignment); + IRB.CreateMemCpy(Base, kShadowTLSAlignment, AShadowPtr, + kShadowTLSAlignment, ArgSize); } VAArgOffset += alignTo(ArgSize, 8); } else { @@ -3704,7 +3901,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper { IRB.CreateStore(TotalVAArgSize, MS.VAArgOverflowSizeTLS); } - /// \brief Compute the shadow address for a given va_arg. + /// Compute the shadow address for a given va_arg. Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, int ArgOffset) { Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); @@ -3719,8 +3916,8 @@ struct VarArgPowerPC64Helper : public VarArgHelper { Value *VAListTag = I.getArgOperand(0); Value *ShadowPtr, *OriginPtr; unsigned Alignment = 8; - std::tie(ShadowPtr, OriginPtr) = - MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment); + std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr( + VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true); IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), /* size */ 8, Alignment, false); } @@ -3730,8 +3927,8 @@ struct VarArgPowerPC64Helper : public VarArgHelper { Value *VAListTag = I.getArgOperand(0); Value *ShadowPtr, *OriginPtr; unsigned Alignment = 8; - std::tie(ShadowPtr, OriginPtr) = - MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment); + std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr( + VAListTag, IRB, IRB.getInt8Ty(), Alignment, /*isStore*/ true); // Unpoison the whole __va_list_tag. // FIXME: magic ABI constants. IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), @@ -3741,7 +3938,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper { void finalizeInstrumentation() override { assert(!VAArgSize && !VAArgTLSCopy && "finalizeInstrumentation called twice"); - IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI()); + IRBuilder<> IRB(MSV.ActualFnStart->getFirstNonPHI()); VAArgSize = IRB.CreateLoad(MS.VAArgOverflowSizeTLS); Value *CopySize = IRB.CreateAdd(ConstantInt::get(MS.IntptrTy, 0), VAArgSize); @@ -3750,7 +3947,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper { // If there is a va_start in this function, make a backup copy of // va_arg_tls somewhere in the function entry block. VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); - IRB.CreateMemCpy(VAArgTLSCopy, MS.VAArgTLS, CopySize, 8); + IRB.CreateMemCpy(VAArgTLSCopy, 8, MS.VAArgTLS, 8, CopySize); } // Instrument va_start. 
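Each VarArg helper above digs the register save area pointer out of __va_list_tag with integer arithmetic at a fixed ABI offset (a "magic ABI constant", as the FIXMEs say); the PointerType::get(...) change merely spells the loaded slot's type as an explicit i64**. A hedged sketch of that access pattern against the LLVM-7-era IRBuilder, where loadRegSaveAreaPtr and the AMD64 offset of 16 are illustrative, not pass code:

    #include "llvm/IR/IRBuilder.h"
    #include <cstdint>

    using namespace llvm;

    // Forms an i64** at VAListTag + Offset and loads the i64* stored there,
    // i.e. the register save area pointer inside the va_list.
    static Value *loadRegSaveAreaPtr(IRBuilder<> &IRB, Value *VAListTag,
                                     Type *IntptrTy, LLVMContext &C,
                                     uint64_t Offset = 16) {
      Value *RegSaveAreaPtrPtr = IRB.CreateIntToPtr(
          IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, IntptrTy),
                        ConstantInt::get(IntptrTy, Offset)),
          PointerType::get(Type::getInt64PtrTy(C), 0)); // an i64**
      return IRB.CreateLoad(RegSaveAreaPtrPtr);         // the i64* field
    }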
@@ -3760,20 +3957,21 @@ struct VarArgPowerPC64Helper : public VarArgHelper { IRBuilder<> IRB(OrigInst->getNextNode()); Value *VAListTag = OrigInst->getArgOperand(0); Value *RegSaveAreaPtrPtr = - IRB.CreateIntToPtr(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), - Type::getInt64PtrTy(*MS.C)); + IRB.CreateIntToPtr(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), + PointerType::get(Type::getInt64PtrTy(*MS.C), 0)); Value *RegSaveAreaPtr = IRB.CreateLoad(RegSaveAreaPtrPtr); Value *RegSaveAreaShadowPtr, *RegSaveAreaOriginPtr; unsigned Alignment = 8; std::tie(RegSaveAreaShadowPtr, RegSaveAreaOriginPtr) = MSV.getShadowOriginPtr(RegSaveAreaPtr, IRB, IRB.getInt8Ty(), - Alignment); - IRB.CreateMemCpy(RegSaveAreaShadowPtr, VAArgTLSCopy, CopySize, Alignment); + Alignment, /*isStore*/ true); + IRB.CreateMemCpy(RegSaveAreaShadowPtr, Alignment, VAArgTLSCopy, Alignment, + CopySize); } } }; -/// \brief A no-op implementation of VarArgHelper. +/// A no-op implementation of VarArgHelper. struct VarArgNoOpHelper : public VarArgHelper { VarArgNoOpHelper(Function &F, MemorySanitizer &MS, MemorySanitizerVisitor &MSV) {} @@ -3796,8 +3994,7 @@ static VarArgHelper *CreateVarArgHelper(Function &Func, MemorySanitizer &Msan, Triple TargetTriple(Func.getParent()->getTargetTriple()); if (TargetTriple.getArch() == Triple::x86_64) return new VarArgAMD64Helper(Func, Msan, Visitor); - else if (TargetTriple.getArch() == Triple::mips64 || - TargetTriple.getArch() == Triple::mips64el) + else if (TargetTriple.isMIPS64()) return new VarArgMIPS64Helper(Func, Msan, Visitor); else if (TargetTriple.getArch() == Triple::aarch64) return new VarArgAArch64Helper(Func, Msan, Visitor); diff --git a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index cb4b3a9c2545..307b7eaa2196 100644 --- a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -48,7 +48,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/PGOInstrumentation.h" +#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" #include "CFGMST.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" @@ -119,6 +119,7 @@ #include <vector> using namespace llvm; +using ProfileCount = Function::ProfileCount; #define DEBUG_TYPE "pgo-instrumentation" @@ -223,8 +224,8 @@ static cl::opt<bool> EmitBranchProbability("pgo-emit-branch-prob", cl::init(false), cl::Hidden, cl::desc("When this option is on, the annotated " "branch probability will be emitted as " - " optimization remarks: -Rpass-analysis=" - "pgo-instr-use")); + "optimization remarks: -{Rpass|" + "pass-remarks}=pgo-instrumentation")); // Command line option to turn on CFG dot dump after profile annotation. // Defined in Analysis/BlockFrequencyInfo.cpp: -pgo-view-counts @@ -448,7 +449,7 @@ ModulePass *llvm::createPGOInstrumentationUseLegacyPass(StringRef Filename) { namespace { -/// \brief An MST based instrumentation for PGO +/// An MST based instrumentation for PGO /// /// Implements a Minimum Spanning Tree (MST) based instrumentation for PGO /// in the function level. 
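The MST-based scheme referenced above only places counters on edges outside a spanning tree of the CFG; counts for the tree edges then follow from flow conservation (inflow equals outflow at every block). A toy diamond CFG (A->B, A->C, B->D, C->D) with made-up numbers shows the recovery:

    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint64_t EntryA  = 100; // entry count (from a pseudo entry edge)
      const uint64_t CountAC = 30;  // the only real counter placed
      uint64_t CountAB = EntryA - CountAC; // A's outflow equals its inflow
      uint64_t CountBD = CountAB;          // B has one in- and one out-edge
      uint64_t CountCD = CountAC;          // likewise for C
      std::printf("A->B=%llu B->D=%llu C->D=%llu\n",
                  (unsigned long long)CountAB, (unsigned long long)CountBD,
                  (unsigned long long)CountCD);
    }

Placing counters on the fewest, coldest edges is why the MST is built over expected edge frequencies.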
@@ -545,7 +546,7 @@ public: computeCFGHash(); if (!ComdatMembers.empty()) renameComdatFunction(); - DEBUG(dumpInfo("after CFGMST")); + LLVM_DEBUG(dumpInfo("after CFGMST")); NumOfPGOBB += MST.BBInfos.size(); for (auto &E : MST.AllEdges) { @@ -595,12 +596,12 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() { FunctionHash = (uint64_t)SIVisitor.getNumOfSelectInsts() << 56 | (uint64_t)ValueSites[IPVK_IndirectCallTarget].size() << 48 | (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC(); - DEBUG(dbgs() << "Function Hash Computation for " << F.getName() << ":\n" - << " CRC = " << JC.getCRC() - << ", Selects = " << SIVisitor.getNumOfSelectInsts() - << ", Edges = " << MST.AllEdges.size() - << ", ICSites = " << ValueSites[IPVK_IndirectCallTarget].size() - << ", Hash = " << FunctionHash << "\n";); + LLVM_DEBUG(dbgs() << "Function Hash Computation for " << F.getName() << ":\n" + << " CRC = " << JC.getCRC() + << ", Selects = " << SIVisitor.getNumOfSelectInsts() + << ", Edges = " << MST.AllEdges.size() << ", ICSites = " + << ValueSites[IPVK_IndirectCallTarget].size() + << ", Hash = " << FunctionHash << "\n";); } // Check if we can safely rename this Comdat function. @@ -701,8 +702,8 @@ BasicBlock *FuncPGOInstrumentation<Edge, BBInfo>::getInstrBB(Edge *E) { // For a critical edge, we have to split. Instrument the newly // created BB. NumOfPGOSplit++; - DEBUG(dbgs() << "Split critical edge: " << getBBInfo(SrcBB).Index << " --> " - << getBBInfo(DestBB).Index << "\n"); + LLVM_DEBUG(dbgs() << "Split critical edge: " << getBBInfo(SrcBB).Index + << " --> " << getBBInfo(DestBB).Index << "\n"); unsigned SuccNum = GetSuccessorNumber(SrcBB, DestBB); BasicBlock *InstrBB = SplitCriticalEdge(TI, SuccNum); assert(InstrBB && "Critical edge is not split"); @@ -752,8 +753,8 @@ static void instrumentOneFunc( for (auto &I : FuncInfo.ValueSites[IPVK_IndirectCallTarget]) { CallSite CS(I); Value *Callee = CS.getCalledValue(); - DEBUG(dbgs() << "Instrument one indirect call: CallSite Index = " - << NumIndirectCallSites << "\n"); + LLVM_DEBUG(dbgs() << "Instrument one indirect call: CallSite Index = " + << NumIndirectCallSites << "\n"); IRBuilder<> Builder(I); assert(Builder.GetInsertPoint() != I->getParent()->end() && "Cannot get the Instrumentation point"); @@ -861,7 +862,7 @@ public: // Set the branch weights based on the count values. void setBranchWeights(); - // Annotate the value profile call sites all all value kind. + // Annotate the value profile call sites for all value kind. void annotateValueSites(); // Annotate the value profile call sites for one value kind. 
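computeCFGHash above packs independent structure counts into disjoint bit fields over a 32-bit CRC, so a stale profile is rejected if the selects, indirect-call sites, or edge structure of the function change. A standalone model with toy inputs (the real pass feeds the MST edge encoding through a JamCRC instance):

    #include <cstdint>
    #include <cstdio>

    int main() {
      uint64_t NumSelects = 3, NumIndirectCallSites = 2, NumEdges = 17;
      uint32_t CRC = 0xDEADBEEF; // stand-in for JC.getCRC()
      // 8 bits for selects, 8 for indirect-call sites, 16 for edges,
      // 32 for the CRC; oversized counts wrap into the neighboring
      // field, an accepted imprecision.
      uint64_t FunctionHash = (NumSelects << 56) |
                              (NumIndirectCallSites << 48) |
                              (NumEdges << 32) | CRC;
      std::printf("hash=%016llx\n", (unsigned long long)FunctionHash);
    }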
@@ -1041,14 +1042,14 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader) { std::vector<uint64_t> &CountFromProfile = ProfileRecord.Counts; NumOfPGOFunc++; - DEBUG(dbgs() << CountFromProfile.size() << " counts\n"); + LLVM_DEBUG(dbgs() << CountFromProfile.size() << " counts\n"); uint64_t ValueSum = 0; for (unsigned I = 0, S = CountFromProfile.size(); I < S; I++) { - DEBUG(dbgs() << " " << I << ": " << CountFromProfile[I] << "\n"); + LLVM_DEBUG(dbgs() << " " << I << ": " << CountFromProfile[I] << "\n"); ValueSum += CountFromProfile[I]; } - DEBUG(dbgs() << "SUM = " << ValueSum << "\n"); + LLVM_DEBUG(dbgs() << "SUM = " << ValueSum << "\n"); getBBInfo(nullptr).UnknownCountOutEdge = 2; getBBInfo(nullptr).UnknownCountInEdge = 2; @@ -1128,7 +1129,7 @@ void PGOUseFunc::populateCounters() { } } - DEBUG(dbgs() << "Populate counts in " << NumPasses << " passes.\n"); + LLVM_DEBUG(dbgs() << "Populate counts in " << NumPasses << " passes.\n"); #ifndef NDEBUG // Assert every BB has a valid counter. for (auto &BB : F) { @@ -1139,7 +1140,7 @@ void PGOUseFunc::populateCounters() { } #endif uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue; - F.setEntryCount(FuncEntryCount); + F.setEntryCount(ProfileCount(FuncEntryCount, Function::PCT_Real)); uint64_t FuncMaxCount = FuncEntryCount; for (auto &BB : F) { auto BI = findBBInfo(&BB); @@ -1153,13 +1154,13 @@ void PGOUseFunc::populateCounters() { FuncInfo.SIVisitor.annotateSelects(F, this, &CountPosition); assert(CountPosition == ProfileCountSize); - DEBUG(FuncInfo.dumpInfo("after reading profile.")); + LLVM_DEBUG(FuncInfo.dumpInfo("after reading profile.")); } // Assign the scaled count values to the BB with multiple out edges. void PGOUseFunc::setBranchWeights() { // Generate MD_prof metadata for every branch instruction. - DEBUG(dbgs() << "\nSetting branch weights.\n"); + LLVM_DEBUG(dbgs() << "\nSetting branch weights.\n"); for (auto &BB : F) { TerminatorInst *TI = BB.getTerminator(); if (TI->getNumSuccessors() < 2) @@ -1200,7 +1201,7 @@ static bool isIndirectBrTarget(BasicBlock *BB) { } void PGOUseFunc::annotateIrrLoopHeaderWeights() { - DEBUG(dbgs() << "\nAnnotating irreducible loop header weights.\n"); + LLVM_DEBUG(dbgs() << "\nAnnotating irreducible loop header weights.\n"); // Find irr loop headers for (auto &BB : F) { // As a heuristic also annotate indrectbr targets as they have a high chance @@ -1333,9 +1334,9 @@ void PGOUseFunc::annotateValueSites(uint32_t Kind) { } for (auto &I : ValueSites) { - DEBUG(dbgs() << "Read one value site profile (kind = " << Kind - << "): Index = " << ValueSiteIndex << " out of " - << NumValueSites << "\n"); + LLVM_DEBUG(dbgs() << "Read one value site profile (kind = " << Kind + << "): Index = " << ValueSiteIndex << " out of " + << NumValueSites << "\n"); annotateValueSite(*M, *I, ProfileRecord, static_cast<InstrProfValueKind>(Kind), ValueSiteIndex, Kind == IPVK_MemOPSize ? MaxNumMemOPAnnotations @@ -1431,7 +1432,7 @@ static bool annotateAllFunctions( Module &M, StringRef ProfileFileName, function_ref<BranchProbabilityInfo *(Function &)> LookupBPI, function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) { - DEBUG(dbgs() << "Read in profile counters: "); + LLVM_DEBUG(dbgs() << "Read in profile counters: "); auto &Ctx = M.getContext(); // Read the counter array from file. auto ReaderOrErr = IndexedInstrProfReader::create(ProfileFileName); @@ -1517,12 +1518,13 @@ static bool annotateAllFunctions( // inconsistent MST between prof-gen and prof-use. 
for (auto &F : HotFunctions) { F->addFnAttr(Attribute::InlineHint); - DEBUG(dbgs() << "Set inline attribute to function: " << F->getName() - << "\n"); + LLVM_DEBUG(dbgs() << "Set inline attribute to function: " << F->getName() + << "\n"); } for (auto &F : ColdFunctions) { F->addFnAttr(Attribute::Cold); - DEBUG(dbgs() << "Set cold attribute to function: " << F->getName() << "\n"); + LLVM_DEBUG(dbgs() << "Set cold attribute to function: " << F->getName() + << "\n"); } return true; } @@ -1585,22 +1587,25 @@ void llvm::setProfMetadata(Module *M, Instruction *TI, for (const auto &ECI : EdgeCounts) Weights.push_back(scaleBranchCount(ECI, Scale)); - DEBUG(dbgs() << "Weight is: "; - for (const auto &W : Weights) { dbgs() << W << " "; } - dbgs() << "\n";); + LLVM_DEBUG(dbgs() << "Weight is: "; for (const auto &W + : Weights) { + dbgs() << W << " "; + } dbgs() << "\n";); TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); if (EmitBranchProbability) { std::string BrCondStr = getBranchCondString(TI); if (BrCondStr.empty()) return; - unsigned WSum = - std::accumulate(Weights.begin(), Weights.end(), 0, - [](unsigned w1, unsigned w2) { return w1 + w2; }); + uint64_t WSum = + std::accumulate(Weights.begin(), Weights.end(), (uint64_t)0, + [](uint64_t w1, uint64_t w2) { return w1 + w2; }); uint64_t TotalCount = - std::accumulate(EdgeCounts.begin(), EdgeCounts.end(), 0, + std::accumulate(EdgeCounts.begin(), EdgeCounts.end(), (uint64_t)0, [](uint64_t c1, uint64_t c2) { return c1 + c2; }); - BranchProbability BP(Weights[0], WSum); + Scale = calculateCountScale(WSum); + BranchProbability BP(scaleBranchCount(Weights[0], Scale), + scaleBranchCount(WSum, Scale)); std::string BranchProbStr; raw_string_ostream OS(BranchProbStr); OS << BP; diff --git a/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp b/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp index 95eb3680403a..2c71e75dadcc 100644 --- a/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp +++ b/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp @@ -25,6 +25,8 @@ #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/DomTreeUpdater.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" @@ -44,7 +46,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Instrumentation.h" -#include "llvm/Transforms/PGOInstrumentation.h" +#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include <cassert> #include <cstdint> @@ -112,6 +114,7 @@ private: AU.addRequired<BlockFrequencyInfoWrapperPass>(); AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); } }; } // end anonymous namespace @@ -133,8 +136,8 @@ namespace { class MemOPSizeOpt : public InstVisitor<MemOPSizeOpt> { public: MemOPSizeOpt(Function &Func, BlockFrequencyInfo &BFI, - OptimizationRemarkEmitter &ORE) - : Func(Func), BFI(BFI), ORE(ORE), Changed(false) { + OptimizationRemarkEmitter &ORE, DominatorTree *DT) + : Func(Func), BFI(BFI), ORE(ORE), DT(DT), Changed(false) { ValueDataArray = llvm::make_unique<InstrProfValueData[]>(MemOPMaxVersion + 2); // Get the MemOPSize range information from option MemOPSizeRange, @@ -151,8 +154,9 @@ public: if (perform(MI)) { Changed = true; ++NumOfPGOMemOPOpt; - DEBUG(dbgs() << "MemOP call: " << 
MI->getCalledFunction()->getName() - << "is Transformed.\n"); + LLVM_DEBUG(dbgs() << "MemOP call: " + << MI->getCalledFunction()->getName() + << "is Transformed.\n"); } } } @@ -169,6 +173,7 @@ private: Function &Func; BlockFrequencyInfo &BFI; OptimizationRemarkEmitter &ORE; + DominatorTree *DT; bool Changed; std::vector<MemIntrinsic *> WorkList; // Start of the previse range. @@ -245,9 +250,9 @@ bool MemOPSizeOpt::perform(MemIntrinsic *MI) { } ArrayRef<InstrProfValueData> VDs(ValueDataArray.get(), NumVals); - DEBUG(dbgs() << "Read one memory intrinsic profile with count " << ActualCount - << "\n"); - DEBUG( + LLVM_DEBUG(dbgs() << "Read one memory intrinsic profile with count " + << ActualCount << "\n"); + LLVM_DEBUG( for (auto &VD : VDs) { dbgs() << " (" << VD.Value << "," << VD.Count << ")\n"; }); @@ -260,8 +265,8 @@ bool MemOPSizeOpt::perform(MemIntrinsic *MI) { TotalCount = ActualCount; if (MemOPScaleCount) - DEBUG(dbgs() << "Scale counts: numerator = " << ActualCount - << " denominator = " << SavedTotalCount << "\n"); + LLVM_DEBUG(dbgs() << "Scale counts: numerator = " << ActualCount + << " denominator = " << SavedTotalCount << "\n"); // Keeping track of the count of the default case: uint64_t RemainCount = TotalCount; @@ -310,9 +315,9 @@ bool MemOPSizeOpt::perform(MemIntrinsic *MI) { uint64_t SumForOpt = TotalCount - RemainCount; - DEBUG(dbgs() << "Optimize one memory intrinsic call to " << Version - << " Versions (covering " << SumForOpt << " out of " - << TotalCount << ")\n"); + LLVM_DEBUG(dbgs() << "Optimize one memory intrinsic call to " << Version + << " Versions (covering " << SumForOpt << " out of " + << TotalCount << ")\n"); // mem_op(..., size) // ==> @@ -331,19 +336,20 @@ bool MemOPSizeOpt::perform(MemIntrinsic *MI) { // merge_bb: BasicBlock *BB = MI->getParent(); - DEBUG(dbgs() << "\n\n== Basic Block Before ==\n"); - DEBUG(dbgs() << *BB << "\n"); + LLVM_DEBUG(dbgs() << "\n\n== Basic Block Before ==\n"); + LLVM_DEBUG(dbgs() << *BB << "\n"); auto OrigBBFreq = BFI.getBlockFreq(BB); - BasicBlock *DefaultBB = SplitBlock(BB, MI); + BasicBlock *DefaultBB = SplitBlock(BB, MI, DT); BasicBlock::iterator It(*MI); ++It; assert(It != DefaultBB->end()); - BasicBlock *MergeBB = SplitBlock(DefaultBB, &(*It)); + BasicBlock *MergeBB = SplitBlock(DefaultBB, &(*It), DT); MergeBB->setName("MemOP.Merge"); BFI.setBlockFreq(MergeBB, OrigBBFreq.getFrequency()); DefaultBB->setName("MemOP.Default"); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); auto &Ctx = Func.getContext(); IRBuilder<> IRB(BB); BB->getTerminator()->eraseFromParent(); @@ -358,7 +364,11 @@ bool MemOPSizeOpt::perform(MemIntrinsic *MI) { annotateValueSite(*Func.getParent(), *MI, VDs.slice(Version), SavedRemainCount, IPVK_MemOPSize, NumVals); - DEBUG(dbgs() << "\n\n== Basic Block After==\n"); + LLVM_DEBUG(dbgs() << "\n\n== Basic Block After==\n"); + + std::vector<DominatorTree::UpdateType> Updates; + if (DT) + Updates.reserve(2 * SizeIds.size()); for (uint64_t SizeId : SizeIds) { BasicBlock *CaseBB = BasicBlock::Create( @@ -374,13 +384,20 @@ bool MemOPSizeOpt::perform(MemIntrinsic *MI) { IRBuilder<> IRBCase(CaseBB); IRBCase.CreateBr(MergeBB); SI->addCase(CaseSizeId, CaseBB); - DEBUG(dbgs() << *CaseBB << "\n"); + if (DT) { + Updates.push_back({DominatorTree::Insert, CaseBB, MergeBB}); + Updates.push_back({DominatorTree::Insert, BB, CaseBB}); + } + LLVM_DEBUG(dbgs() << *CaseBB << "\n"); } + DTU.applyUpdates(Updates); + Updates.clear(); + setProfMetadata(Func.getParent(), SI, CaseCounts, MaxCount); - DEBUG(dbgs() << *BB 
<< "\n"); - DEBUG(dbgs() << *DefaultBB << "\n"); - DEBUG(dbgs() << *MergeBB << "\n"); + LLVM_DEBUG(dbgs() << *BB << "\n"); + LLVM_DEBUG(dbgs() << *DefaultBB << "\n"); + LLVM_DEBUG(dbgs() << *MergeBB << "\n"); ORE.emit([&]() { using namespace ore; @@ -396,13 +413,14 @@ bool MemOPSizeOpt::perform(MemIntrinsic *MI) { } // namespace static bool PGOMemOPSizeOptImpl(Function &F, BlockFrequencyInfo &BFI, - OptimizationRemarkEmitter &ORE) { + OptimizationRemarkEmitter &ORE, + DominatorTree *DT) { if (DisableMemOPOPT) return false; if (F.hasFnAttribute(Attribute::OptimizeForSize)) return false; - MemOPSizeOpt MemOPSizeOpt(F, BFI, ORE); + MemOPSizeOpt MemOPSizeOpt(F, BFI, ORE, DT); MemOPSizeOpt.perform(); return MemOPSizeOpt.isChanged(); } @@ -411,7 +429,9 @@ bool PGOMemOPSizeOptLegacyPass::runOnFunction(Function &F) { BlockFrequencyInfo &BFI = getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(); auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - return PGOMemOPSizeOptImpl(F, BFI, ORE); + auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr; + return PGOMemOPSizeOptImpl(F, BFI, ORE, DT); } namespace llvm { @@ -421,11 +441,13 @@ PreservedAnalyses PGOMemOPSizeOpt::run(Function &F, FunctionAnalysisManager &FAM) { auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); - bool Changed = PGOMemOPSizeOptImpl(F, BFI, ORE); + auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F); + bool Changed = PGOMemOPSizeOptImpl(F, BFI, ORE, DT); if (!Changed) return PreservedAnalyses::all(); auto PA = PreservedAnalyses(); PA.preserve<GlobalsAA>(); + PA.preserve<DominatorTreeAnalysis>(); return PA; } } // namespace llvm diff --git a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index d950e2e730f2..a4dd48c8dd6a 100644 --- a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -35,7 +35,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -243,6 +242,7 @@ private: GlobalVariable *Function8bitCounterArray; // for inline-8bit-counters. GlobalVariable *FunctionPCsArray; // for pc-table. SmallVector<GlobalValue *, 20> GlobalsToAppendToUsed; + SmallVector<GlobalValue *, 20> GlobalsToAppendToCompilerUsed; SanitizerCoverageOptions Options; }; @@ -405,6 +405,7 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { // so we need to prevent them from being dead stripped. if (TargetTriple.isOSBinFormatMachO()) appendToUsed(M, GlobalsToAppendToUsed); + appendToCompilerUsed(M, GlobalsToAppendToCompilerUsed); return true; } @@ -480,6 +481,8 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { if (F.getName() == "__local_stdio_printf_options" || F.getName() == "__local_stdio_scanf_options") return false; + if (isa<UnreachableInst>(F.getEntryBlock().getTerminator())) + return false; // Don't instrument functions using SEH for now. Splitting basic blocks like // we do for coverage breaks WinEHPrepare. // FIXME: Remove this when SEH no longer uses landingpad pattern matching. 
@@ -592,11 +595,15 @@ void SanitizerCoverageModule::CreateFunctionLocalArrays( if (Options.Inline8bitCounters) { Function8bitCounterArray = CreateFunctionLocalArrayInSection( AllBlocks.size(), F, Int8Ty, SanCovCountersSectionName); - GlobalsToAppendToUsed.push_back(Function8bitCounterArray); + GlobalsToAppendToCompilerUsed.push_back(Function8bitCounterArray); + MDNode *MD = MDNode::get(F.getContext(), ValueAsMetadata::get(&F)); + Function8bitCounterArray->addMetadata(LLVMContext::MD_associated, *MD); } if (Options.PCTable) { FunctionPCsArray = CreatePCArray(F, AllBlocks); - GlobalsToAppendToUsed.push_back(FunctionPCsArray); + GlobalsToAppendToCompilerUsed.push_back(FunctionPCsArray); + MDNode *MD = MDNode::get(F.getContext(), ValueAsMetadata::get(&F)); + FunctionPCsArray->addMetadata(LLVMContext::MD_associated, *MD); } } @@ -659,11 +666,11 @@ void SanitizerCoverageModule::InjectTraceForSwitch( C = ConstantExpr::getCast(CastInst::ZExt, It.getCaseValue(), Int64Ty); Initializers.push_back(C); } - std::sort(Initializers.begin() + 2, Initializers.end(), - [](const Constant *A, const Constant *B) { - return cast<ConstantInt>(A)->getLimitedValue() < - cast<ConstantInt>(B)->getLimitedValue(); - }); + llvm::sort(Initializers.begin() + 2, Initializers.end(), + [](const Constant *A, const Constant *B) { + return cast<ConstantInt>(A)->getLimitedValue() < + cast<ConstantInt>(B)->getLimitedValue(); + }); ArrayType *ArrayOfInt64Ty = ArrayType::get(Int64Ty, Initializers.size()); GlobalVariable *GV = new GlobalVariable( *CurModule, ArrayOfInt64Ty, false, GlobalVariable::InternalLinkage, diff --git a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index ec6904486e10..fa1e5a157a0f 100644 --- a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -19,13 +19,14 @@ // The rest is handled by the run-time library. //===----------------------------------------------------------------------===// -#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" @@ -44,7 +45,6 @@ #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/EscapeEnumerator.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" using namespace llvm; @@ -339,7 +339,7 @@ bool ThreadSanitizer::addrPointsToConstantData(Value *Addr) { void ThreadSanitizer::chooseInstructionsToInstrument( SmallVectorImpl<Instruction *> &Local, SmallVectorImpl<Instruction *> &All, const DataLayout &DL) { - SmallSet<Value*, 8> WriteTargets; + SmallPtrSet<Value*, 8> WriteTargets; // Iterate from the end. for (Instruction *I : reverse(Local)) { if (StoreInst *Store = dyn_cast<StoreInst>(I)) { @@ -502,7 +502,7 @@ bool ThreadSanitizer::instrumentLoadOrStore(Instruction *I, if (Idx < 0) return false; if (IsWrite && isVtableAccess(I)) { - DEBUG(dbgs() << " VPTR : " << *I << "\n"); + LLVM_DEBUG(dbgs() << " VPTR : " << *I << "\n"); Value *StoredValue = cast<StoreInst>(I)->getValueOperand(); // StoredValue may be a vector type if we are storing several vptrs at once. 
// In this case, just take the first element of the vector since this is diff --git a/lib/Transforms/LLVMBuild.txt b/lib/Transforms/LLVMBuild.txt index 95482ad20225..f061c6d9285e 100644 --- a/lib/Transforms/LLVMBuild.txt +++ b/lib/Transforms/LLVMBuild.txt @@ -16,7 +16,7 @@ ;===------------------------------------------------------------------------===; [common] -subdirectories = Coroutines IPO InstCombine Instrumentation Scalar Utils Vectorize ObjCARC +subdirectories = AggressiveInstCombine Coroutines IPO InstCombine Instrumentation Scalar Utils Vectorize ObjCARC [component_0] type = Group diff --git a/lib/Transforms/ObjCARC/BlotMapVector.h b/lib/Transforms/ObjCARC/BlotMapVector.h index 5518b49c4095..9ade14c1177a 100644 --- a/lib/Transforms/ObjCARC/BlotMapVector.h +++ b/lib/Transforms/ObjCARC/BlotMapVector.h @@ -18,7 +18,7 @@ namespace llvm { -/// \brief An associative container with fast insertion-order (deterministic) +/// An associative container with fast insertion-order (deterministic) /// iteration over its elements. Plus the special blot operation. template <class KeyT, class ValueT> class BlotMapVector { /// Map keys to indices in Vector. diff --git a/lib/Transforms/ObjCARC/DependencyAnalysis.h b/lib/Transforms/ObjCARC/DependencyAnalysis.h index 8cc1232b18ca..0f13b02c806f 100644 --- a/lib/Transforms/ObjCARC/DependencyAnalysis.h +++ b/lib/Transforms/ObjCARC/DependencyAnalysis.h @@ -38,7 +38,7 @@ namespace objcarc { class ProvenanceAnalysis; /// \enum DependenceKind -/// \brief Defines different dependence kinds among various ARC constructs. +/// Defines different dependence kinds among various ARC constructs. /// /// There are several kinds of dependence-like concepts in use here. /// diff --git a/lib/Transforms/ObjCARC/ObjCARC.h b/lib/Transforms/ObjCARC/ObjCARC.h index cd9b3d96a14f..1dbe72c7569f 100644 --- a/lib/Transforms/ObjCARC/ObjCARC.h +++ b/lib/Transforms/ObjCARC/ObjCARC.h @@ -28,13 +28,13 @@ #include "llvm/Analysis/ObjCARCAnalysisUtils.h" #include "llvm/Analysis/ObjCARCInstKind.h" #include "llvm/Analysis/Passes.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" #include "llvm/Transforms/ObjCARC.h" -#include "llvm/Transforms/Utils/Local.h" namespace llvm { class raw_ostream; @@ -43,7 +43,7 @@ class raw_ostream; namespace llvm { namespace objcarc { -/// \brief Erase the given instruction. +/// Erase the given instruction. /// /// Many ObjC calls return their argument verbatim, /// so if it's such a call and the return value has users, replace them with the @@ -82,6 +82,26 @@ static inline const Instruction *getreturnRVOperand(const Instruction &Inst, return dyn_cast<InvokeInst>(Opnd); } +/// Return the list of PHI nodes that are equivalent to PN. +template<class PHINodeTy, class VectorTy> +void getEquivalentPHIs(PHINodeTy &PN, VectorTy &PHIList) { + auto *BB = PN.getParent(); + for (auto &P : BB->phis()) { + if (&P == &PN) // Do not add PN to the list. 
+ continue; + unsigned I = 0, E = PN.getNumIncomingValues(); + for (; I < E; ++I) { + auto *BB = PN.getIncomingBlock(I); + auto *PNOpnd = PN.getIncomingValue(I)->stripPointerCasts(); + auto *POpnd = P.getIncomingValueForBlock(BB)->stripPointerCasts(); + if (PNOpnd != POpnd) + break; + } + if (I == E) + PHIList.push_back(&P); + } +} + } // end namespace objcarc } // end namespace llvm diff --git a/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp b/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp index b2c62a0e8eeb..8d3ef8fde534 100644 --- a/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp @@ -36,7 +36,7 @@ using namespace llvm::objcarc; #define DEBUG_TYPE "objc-arc-ap-elim" namespace { - /// \brief Autorelease pool elimination. + /// Autorelease pool elimination. class ObjCARCAPElim : public ModulePass { void getAnalysisUsage(AnalysisUsage &AU) const override; bool runOnModule(Module &M) override; @@ -103,10 +103,12 @@ bool ObjCARCAPElim::OptimizeBB(BasicBlock *BB) { // zap the pair. if (Push && cast<CallInst>(Inst)->getArgOperand(0) == Push) { Changed = true; - DEBUG(dbgs() << "ObjCARCAPElim::OptimizeBB: Zapping push pop " - "autorelease pair:\n" - " Pop: " << *Inst << "\n" - << " Push: " << *Push << "\n"); + LLVM_DEBUG(dbgs() << "ObjCARCAPElim::OptimizeBB: Zapping push pop " + "autorelease pair:\n" + " Pop: " + << *Inst << "\n" + << " Push: " << *Push + << "\n"); Inst->eraseFromParent(); Push->eraseFromParent(); } diff --git a/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/lib/Transforms/ObjCARC/ObjCARCContract.cpp index c4e61218f3f3..1f1ea9f58739 100644 --- a/lib/Transforms/ObjCARC/ObjCARCContract.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCContract.cpp @@ -31,6 +31,7 @@ #include "ObjCARC.h" #include "ProvenanceAnalysis.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/EHPersonalities.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/Operator.h" @@ -50,7 +51,7 @@ STATISTIC(NumStoreStrongs, "Number objc_storeStrong calls formed"); //===----------------------------------------------------------------------===// namespace { - /// \brief Late ARC optimizations + /// Late ARC optimizations /// /// These change the IR in a way that makes it difficult to be analyzed by /// ObjCARCOpt, so it's run late. @@ -74,11 +75,12 @@ namespace { SmallPtrSet<CallInst *, 8> StoreStrongCalls; /// Returns true if we eliminated Inst. 
- bool tryToPeepholeInstruction(Function &F, Instruction *Inst, - inst_iterator &Iter, - SmallPtrSetImpl<Instruction *> &DepInsts, - SmallPtrSetImpl<const BasicBlock *> &Visited, - bool &TailOkForStoreStrong); + bool tryToPeepholeInstruction( + Function &F, Instruction *Inst, inst_iterator &Iter, + SmallPtrSetImpl<Instruction *> &DepInsts, + SmallPtrSetImpl<const BasicBlock *> &Visited, + bool &TailOkForStoreStrong, + const DenseMap<BasicBlock *, ColorVector> &BlockColors); bool optimizeRetainCall(Function &F, Instruction *Retain); @@ -88,8 +90,9 @@ namespace { SmallPtrSetImpl<Instruction *> &DependingInstructions, SmallPtrSetImpl<const BasicBlock *> &Visited); - void tryToContractReleaseIntoStoreStrong(Instruction *Release, - inst_iterator &Iter); + void tryToContractReleaseIntoStoreStrong( + Instruction *Release, inst_iterator &Iter, + const DenseMap<BasicBlock *, ColorVector> &BlockColors); void getAnalysisUsage(AnalysisUsage &AU) const override; bool doInitialization(Module &M) override; @@ -129,16 +132,18 @@ bool ObjCARCContract::optimizeRetainCall(Function &F, Instruction *Retain) { Changed = true; ++NumPeeps; - DEBUG(dbgs() << "Transforming objc_retain => " - "objc_retainAutoreleasedReturnValue since the operand is a " - "return value.\nOld: "<< *Retain << "\n"); + LLVM_DEBUG( + dbgs() << "Transforming objc_retain => " + "objc_retainAutoreleasedReturnValue since the operand is a " + "return value.\nOld: " + << *Retain << "\n"); // We do not have to worry about tail calls/does not throw since // retain/retainRV have the same properties. Constant *Decl = EP.get(ARCRuntimeEntryPointKind::RetainRV); cast<CallInst>(Retain)->setCalledFunction(Decl); - DEBUG(dbgs() << "New: " << *Retain << "\n"); + LLVM_DEBUG(dbgs() << "New: " << *Retain << "\n"); return true; } @@ -177,16 +182,19 @@ bool ObjCARCContract::contractAutorelease( Changed = true; ++NumPeeps; - DEBUG(dbgs() << " Fusing retain/autorelease!\n" - " Autorelease:" << *Autorelease << "\n" - " Retain: " << *Retain << "\n"); + LLVM_DEBUG(dbgs() << " Fusing retain/autorelease!\n" + " Autorelease:" + << *Autorelease + << "\n" + " Retain: " + << *Retain << "\n"); Constant *Decl = EP.get(Class == ARCInstKind::AutoreleaseRV ? ARCRuntimeEntryPointKind::RetainAutoreleaseRV : ARCRuntimeEntryPointKind::RetainAutorelease); Retain->setCalledFunction(Decl); - DEBUG(dbgs() << " New RetainAutorelease: " << *Retain << "\n"); + LLVM_DEBUG(dbgs() << " New RetainAutorelease: " << *Retain << "\n"); EraseInstruction(Autorelease); return true; @@ -303,6 +311,24 @@ findRetainForStoreStrongContraction(Value *New, StoreInst *Store, return Retain; } +/// Create a call instruction with the correct funclet token. Should be used +/// instead of calling CallInst::Create directly. +static CallInst * +createCallInst(Value *Func, ArrayRef<Value *> Args, const Twine &NameStr, + Instruction *InsertBefore, + const DenseMap<BasicBlock *, ColorVector> &BlockColors) { + SmallVector<OperandBundleDef, 1> OpBundles; + if (!BlockColors.empty()) { + const ColorVector &CV = BlockColors.find(InsertBefore->getParent())->second; + assert(CV.size() == 1 && "non-unique color for block!"); + Instruction *EHPad = CV.front()->getFirstNonPHI(); + if (EHPad->isEHPad()) + OpBundles.emplace_back("funclet", EHPad); + } + + return CallInst::Create(Func, Args, OpBundles, NameStr, InsertBefore); +} + /// Attempt to merge an objc_release with a store, load, and objc_retain to form /// an objc_storeStrong. 
An objc_storeStrong: /// @@ -330,8 +356,9 @@ findRetainForStoreStrongContraction(Value *New, StoreInst *Store, /// (4). /// 2. We need to make sure that any re-orderings of (1), (2), (3), (4) are /// safe. -void ObjCARCContract::tryToContractReleaseIntoStoreStrong(Instruction *Release, - inst_iterator &Iter) { +void ObjCARCContract::tryToContractReleaseIntoStoreStrong( + Instruction *Release, inst_iterator &Iter, + const DenseMap<BasicBlock *, ColorVector> &BlockColors) { // See if we are releasing something that we just loaded. auto *Load = dyn_cast<LoadInst>(GetArgRCIdentityRoot(Release)); if (!Load || !Load->isSimple()) @@ -365,7 +392,7 @@ void ObjCARCContract::tryToContractReleaseIntoStoreStrong(Instruction *Release, Changed = true; ++NumStoreStrongs; - DEBUG( + LLVM_DEBUG( llvm::dbgs() << " Contracting retain, release into objc_storeStrong.\n" << " Old:\n" << " Store: " << *Store << "\n" @@ -383,7 +410,7 @@ void ObjCARCContract::tryToContractReleaseIntoStoreStrong(Instruction *Release, if (Args[1]->getType() != I8X) Args[1] = new BitCastInst(Args[1], I8X, "", Store); Constant *Decl = EP.get(ARCRuntimeEntryPointKind::StoreStrong); - CallInst *StoreStrong = CallInst::Create(Decl, Args, "", Store); + CallInst *StoreStrong = createCallInst(Decl, Args, "", Store, BlockColors); StoreStrong->setDoesNotThrow(); StoreStrong->setDebugLoc(Store->getDebugLoc()); @@ -392,7 +419,8 @@ void ObjCARCContract::tryToContractReleaseIntoStoreStrong(Instruction *Release, // we can set the tail flag once we know it's safe. StoreStrongCalls.insert(StoreStrong); - DEBUG(llvm::dbgs() << " New Store Strong: " << *StoreStrong << "\n"); + LLVM_DEBUG(llvm::dbgs() << " New Store Strong: " << *StoreStrong + << "\n"); if (&*Iter == Retain) ++Iter; if (&*Iter == Store) ++Iter; @@ -407,7 +435,8 @@ bool ObjCARCContract::tryToPeepholeInstruction( Function &F, Instruction *Inst, inst_iterator &Iter, SmallPtrSetImpl<Instruction *> &DependingInsts, SmallPtrSetImpl<const BasicBlock *> &Visited, - bool &TailOkForStoreStrongs) { + bool &TailOkForStoreStrongs, + const DenseMap<BasicBlock *, ColorVector> &BlockColors) { // Only these library routines return their argument. In particular, // objc_retainBlock does not necessarily return its argument. ARCInstKind Class = GetBasicARCInstKind(Inst); @@ -449,15 +478,16 @@ bool ObjCARCContract::tryToPeepholeInstruction( } while (IsNoopInstruction(&*BBI)); if (&*BBI == GetArgRCIdentityRoot(Inst)) { - DEBUG(dbgs() << "Adding inline asm marker for the return value " - "optimization.\n"); + LLVM_DEBUG(dbgs() << "Adding inline asm marker for the return value " + "optimization.\n"); Changed = true; InlineAsm *IA = InlineAsm::get( FunctionType::get(Type::getVoidTy(Inst->getContext()), /*isVarArg=*/false), RVInstMarker->getString(), /*Constraints=*/"", /*hasSideEffects=*/true); - CallInst::Create(IA, "", Inst); + + createCallInst(IA, None, "", Inst, BlockColors); } decline_rv_optimization: return false; @@ -471,8 +501,8 @@ bool ObjCARCContract::tryToPeepholeInstruction( Changed = true; new StoreInst(Null, CI->getArgOperand(0), CI); - DEBUG(dbgs() << "OBJCARCContract: Old = " << *CI << "\n" - << " New = " << *Null << "\n"); + LLVM_DEBUG(dbgs() << "OBJCARCContract: Old = " << *CI << "\n" + << " New = " << *Null << "\n"); CI->replaceAllUsesWith(Null); CI->eraseFromParent(); @@ -482,7 +512,7 @@ bool ObjCARCContract::tryToPeepholeInstruction( case ARCInstKind::Release: // Try to form an objc store strong from our release. If we fail, there is // nothing further to do below, so continue. 
- tryToContractReleaseIntoStoreStrong(Inst, Iter); + tryToContractReleaseIntoStoreStrong(Inst, Iter, BlockColors); return true; case ARCInstKind::User: // Be conservative if the function has any alloca instructions. @@ -518,7 +548,12 @@ bool ObjCARCContract::runOnFunction(Function &F) { PA.setAA(&getAnalysis<AAResultsWrapperPass>().getAAResults()); - DEBUG(llvm::dbgs() << "**** ObjCARC Contract ****\n"); + DenseMap<BasicBlock *, ColorVector> BlockColors; + if (F.hasPersonalityFn() && + isScopedEHPersonality(classifyEHPersonality(F.getPersonalityFn()))) + BlockColors = colorEHFunclets(F); + + LLVM_DEBUG(llvm::dbgs() << "**** ObjCARC Contract ****\n"); // Track whether it's ok to mark objc_storeStrong calls with the "tail" // keyword. Be conservative if the function has variadic arguments. @@ -536,12 +571,12 @@ bool ObjCARCContract::runOnFunction(Function &F) { for (inst_iterator I = inst_begin(&F), E = inst_end(&F); I != E;) { Instruction *Inst = &*I++; - DEBUG(dbgs() << "Visiting: " << *Inst << "\n"); + LLVM_DEBUG(dbgs() << "Visiting: " << *Inst << "\n"); // First try to peephole Inst. If there is nothing further we can do in // terms of undoing objc-arc-expand, process the next inst. if (tryToPeepholeInstruction(F, Inst, I, DependingInstructions, Visited, - TailOkForStoreStrongs)) + TailOkForStoreStrongs, BlockColors)) continue; // Otherwise, try to undo objc-arc-expand. @@ -568,35 +603,48 @@ bool ObjCARCContract::runOnFunction(Function &F) { // trivially dominate itself, which would lead us to rewriting its // argument in terms of its return value, which would lead to // infinite loops in GetArgRCIdentityRoot. - if (DT->isReachableFromEntry(U) && DT->dominates(Inst, U)) { - Changed = true; - Instruction *Replacement = Inst; - Type *UseTy = U.get()->getType(); - if (PHINode *PHI = dyn_cast<PHINode>(U.getUser())) { - // For PHI nodes, insert the bitcast in the predecessor block. - unsigned ValNo = PHINode::getIncomingValueNumForOperand(OperandNo); - BasicBlock *BB = PHI->getIncomingBlock(ValNo); - if (Replacement->getType() != UseTy) - Replacement = new BitCastInst(Replacement, UseTy, "", - &BB->back()); - // While we're here, rewrite all edges for this PHI, rather - // than just one use at a time, to minimize the number of - // bitcasts we emit. - for (unsigned i = 0, e = PHI->getNumIncomingValues(); i != e; ++i) - if (PHI->getIncomingBlock(i) == BB) { - // Keep the UI iterator valid. - if (UI != UE && - &PHI->getOperandUse( - PHINode::getOperandNumForIncomingValue(i)) == &*UI) - ++UI; - PHI->setIncomingValue(i, Replacement); - } - } else { - if (Replacement->getType() != UseTy) - Replacement = new BitCastInst(Replacement, UseTy, "", - cast<Instruction>(U.getUser())); - U.set(Replacement); + if (!DT->isReachableFromEntry(U) || !DT->dominates(Inst, U)) + continue; + + Changed = true; + Instruction *Replacement = Inst; + Type *UseTy = U.get()->getType(); + if (PHINode *PHI = dyn_cast<PHINode>(U.getUser())) { + // For PHI nodes, insert the bitcast in the predecessor block. + unsigned ValNo = PHINode::getIncomingValueNumForOperand(OperandNo); + BasicBlock *IncomingBB = PHI->getIncomingBlock(ValNo); + if (Replacement->getType() != UseTy) { + // A catchswitch is both a pad and a terminator, meaning a basic + // block with a catchswitch has no insertion point. Keep going up + // the dominator tree until we find a non-catchswitch. 
+ BasicBlock *InsertBB = IncomingBB; + while (isa<CatchSwitchInst>(InsertBB->getFirstNonPHI())) { + InsertBB = DT->getNode(InsertBB)->getIDom()->getBlock(); + } + + assert(DT->dominates(Inst, &InsertBB->back()) && + "Invalid insertion point for bitcast"); + Replacement = + new BitCastInst(Replacement, UseTy, "", &InsertBB->back()); } + + // While we're here, rewrite all edges for this PHI, rather + // than just one use at a time, to minimize the number of + // bitcasts we emit. + for (unsigned i = 0, e = PHI->getNumIncomingValues(); i != e; ++i) + if (PHI->getIncomingBlock(i) == IncomingBB) { + // Keep the UI iterator valid. + if (UI != UE && + &PHI->getOperandUse( + PHINode::getOperandNumForIncomingValue(i)) == &*UI) + ++UI; + PHI->setIncomingValue(i, Replacement); + } + } else { + if (Replacement->getType() != UseTy) + Replacement = new BitCastInst(Replacement, UseTy, "", + cast<Instruction>(U.getUser())); + U.set(Replacement); } } }; @@ -618,8 +666,17 @@ bool ObjCARCContract::runOnFunction(Function &F) { else if (isa<GlobalAlias>(Arg) && !cast<GlobalAlias>(Arg)->isInterposable()) Arg = cast<GlobalAlias>(Arg)->getAliasee(); - else + else { + // If Arg is a PHI node, get PHIs that are equivalent to it and replace + // their uses. + if (PHINode *PN = dyn_cast<PHINode>(Arg)) { + SmallVector<Value *, 1> PHIList; + getEquivalentPHIs(*PN, PHIList); + for (Value *PHI : PHIList) + ReplaceArgUses(PHI); + } break; + } } // Replace bitcast users of Arg that are dominated by Inst. diff --git a/lib/Transforms/ObjCARC/ObjCARCExpand.cpp b/lib/Transforms/ObjCARC/ObjCARCExpand.cpp index bb6a0a0e73db..6a345ef56e1b 100644 --- a/lib/Transforms/ObjCARC/ObjCARCExpand.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCExpand.cpp @@ -47,7 +47,7 @@ using namespace llvm; using namespace llvm::objcarc; namespace { - /// \brief Early ARC transformations. + /// Early ARC transformations. class ObjCARCExpand : public FunctionPass { void getAnalysisUsage(AnalysisUsage &AU) const override; bool doInitialization(Module &M) override; @@ -91,12 +91,13 @@ bool ObjCARCExpand::runOnFunction(Function &F) { bool Changed = false; - DEBUG(dbgs() << "ObjCARCExpand: Visiting Function: " << F.getName() << "\n"); + LLVM_DEBUG(dbgs() << "ObjCARCExpand: Visiting Function: " << F.getName() + << "\n"); for (inst_iterator I = inst_begin(&F), E = inst_end(&F); I != E; ++I) { Instruction *Inst = &*I; - DEBUG(dbgs() << "ObjCARCExpand: Visiting: " << *Inst << "\n"); + LLVM_DEBUG(dbgs() << "ObjCARCExpand: Visiting: " << *Inst << "\n"); switch (GetBasicARCInstKind(Inst)) { case ARCInstKind::Retain: @@ -111,8 +112,10 @@ bool ObjCARCExpand::runOnFunction(Function &F) { // emitted here. We'll redo them in the contract pass. 
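The getEquivalentPHIs helper added to ObjCARC.h above treats two PHIs as equivalent when they sit in the same block and every incoming value matches after stripPointerCasts — duplicates that ARC's bitcast-heavy IR produces routinely. Its use in this pass, lightly abridged (Arg and the ReplaceArgUses lambda come from the surrounding function):

if (PHINode *PN = dyn_cast<PHINode>(Arg)) {
  SmallVector<Value *, 1> PHIList;
  // Same parent block, cast-equivalent incoming value per predecessor.
  getEquivalentPHIs(*PN, PHIList);
  for (Value *PHI : PHIList)
    ReplaceArgUses(PHI); // rewrite uses of the equivalent PHIs as well
}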
Changed = true; Value *Value = cast<CallInst>(Inst)->getArgOperand(0); - DEBUG(dbgs() << "ObjCARCExpand: Old = " << *Inst << "\n" - " New = " << *Value << "\n"); + LLVM_DEBUG(dbgs() << "ObjCARCExpand: Old = " << *Inst + << "\n" + " New = " + << *Value << "\n"); Inst->replaceAllUsesWith(Value); break; } @@ -121,7 +124,7 @@ bool ObjCARCExpand::runOnFunction(Function &F) { } } - DEBUG(dbgs() << "ObjCARCExpand: Finished List.\n\n"); + LLVM_DEBUG(dbgs() << "ObjCARCExpand: Finished List.\n\n"); return Changed; } diff --git a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp index 99ed6863c22e..21e2848030fc 100644 --- a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp @@ -38,6 +38,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/ObjCARCAliasAnalysis.h" #include "llvm/Analysis/ObjCARCAnalysisUtils.h" #include "llvm/Analysis/ObjCARCInstKind.h" @@ -76,7 +77,7 @@ using namespace llvm::objcarc; /// \defgroup ARCUtilities Utility declarations/definitions specific to ARC. /// @{ -/// \brief This is similar to GetRCIdentityRoot but it stops as soon +/// This is similar to GetRCIdentityRoot but it stops as soon /// as it finds a value with multiple uses. static const Value *FindSingleUseIdentifiedObject(const Value *Arg) { // ConstantData (like ConstantPointerNull and UndefValue) is used across @@ -174,7 +175,7 @@ STATISTIC(NumReleasesAfterOpt, namespace { - /// \brief Per-BasicBlock state. + /// Per-BasicBlock state. class BBState { /// The number of unique control paths from the entry which can reach this /// block. @@ -422,7 +423,7 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, BBState &BBInfo) { // Dump the pointers we are tracking. OS << " TopDown State:\n"; if (!BBInfo.hasTopDownPtrs()) { - DEBUG(dbgs() << " NONE!\n"); + LLVM_DEBUG(dbgs() << " NONE!\n"); } else { for (auto I = BBInfo.top_down_ptr_begin(), E = BBInfo.top_down_ptr_end(); I != E; ++I) { @@ -442,7 +443,7 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, BBState &BBInfo) { OS << " BottomUp State:\n"; if (!BBInfo.hasBottomUpPtrs()) { - DEBUG(dbgs() << " NONE!\n"); + LLVM_DEBUG(dbgs() << " NONE!\n"); } else { for (auto I = BBInfo.bottom_up_ptr_begin(), E = BBInfo.bottom_up_ptr_end(); I != E; ++I) { @@ -465,7 +466,7 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, BBState &BBInfo) { namespace { - /// \brief The main ARC optimization pass. + /// The main ARC optimization pass. 
class ObjCARCOpt : public FunctionPass { bool Changed; ProvenanceAnalysis PA; @@ -612,8 +613,8 @@ ObjCARCOpt::OptimizeRetainRVCall(Function &F, Instruction *RetainRV) { Changed = true; ++NumPeeps; - DEBUG(dbgs() << "Erasing autoreleaseRV,retainRV pair: " << *I << "\n" - << "Erasing " << *RetainRV << "\n"); + LLVM_DEBUG(dbgs() << "Erasing autoreleaseRV,retainRV pair: " << *I << "\n" + << "Erasing " << *RetainRV << "\n"); EraseInstruction(&*I); EraseInstruction(RetainRV); @@ -625,14 +626,15 @@ ObjCARCOpt::OptimizeRetainRVCall(Function &F, Instruction *RetainRV) { Changed = true; ++NumPeeps; - DEBUG(dbgs() << "Transforming objc_retainAutoreleasedReturnValue => " - "objc_retain since the operand is not a return value.\n" - "Old = " << *RetainRV << "\n"); + LLVM_DEBUG(dbgs() << "Transforming objc_retainAutoreleasedReturnValue => " + "objc_retain since the operand is not a return value.\n" + "Old = " + << *RetainRV << "\n"); Constant *NewDecl = EP.get(ARCRuntimeEntryPointKind::Retain); cast<CallInst>(RetainRV)->setCalledFunction(NewDecl); - DEBUG(dbgs() << "New = " << *RetainRV << "\n"); + LLVM_DEBUG(dbgs() << "New = " << *RetainRV << "\n"); return false; } @@ -652,6 +654,11 @@ void ObjCARCOpt::OptimizeAutoreleaseRVCall(Function &F, SmallVector<const Value *, 2> Users; Users.push_back(Ptr); + + // Add PHIs that are equivalent to Ptr to Users. + if (const PHINode *PN = dyn_cast<PHINode>(Ptr)) + getEquivalentPHIs(*PN, Users); + do { Ptr = Users.pop_back_val(); for (const User *U : Ptr->users()) { @@ -665,10 +672,12 @@ void ObjCARCOpt::OptimizeAutoreleaseRVCall(Function &F, Changed = true; ++NumPeeps; - DEBUG(dbgs() << "Transforming objc_autoreleaseReturnValue => " - "objc_autorelease since its operand is not used as a return " - "value.\n" - "Old = " << *AutoreleaseRV << "\n"); + LLVM_DEBUG( + dbgs() << "Transforming objc_autoreleaseReturnValue => " + "objc_autorelease since its operand is not used as a return " + "value.\n" + "Old = " + << *AutoreleaseRV << "\n"); CallInst *AutoreleaseRVCI = cast<CallInst>(AutoreleaseRV); Constant *NewDecl = EP.get(ARCRuntimeEntryPointKind::Autorelease); @@ -676,23 +685,53 @@ void ObjCARCOpt::OptimizeAutoreleaseRVCall(Function &F, AutoreleaseRVCI->setTailCall(false); // Never tail call objc_autorelease. Class = ARCInstKind::Autorelease; - DEBUG(dbgs() << "New: " << *AutoreleaseRV << "\n"); + LLVM_DEBUG(dbgs() << "New: " << *AutoreleaseRV << "\n"); +} + +namespace { +Instruction * +CloneCallInstForBB(CallInst &CI, BasicBlock &BB, + const DenseMap<BasicBlock *, ColorVector> &BlockColors) { + SmallVector<OperandBundleDef, 1> OpBundles; + for (unsigned I = 0, E = CI.getNumOperandBundles(); I != E; ++I) { + auto Bundle = CI.getOperandBundleAt(I); + // Funclets will be reassociated in the future. + if (Bundle.getTagID() == LLVMContext::OB_funclet) + continue; + OpBundles.emplace_back(Bundle); + } + + if (!BlockColors.empty()) { + const ColorVector &CV = BlockColors.find(&BB)->second; + assert(CV.size() == 1 && "non-unique color for block!"); + Instruction *EHPad = CV.front()->getFirstNonPHI(); + if (EHPad->isEHPad()) + OpBundles.emplace_back("funclet", EHPad); + } + + return CallInst::Create(&CI, OpBundles); +} } /// Visit each call, one at a time, and make simplifications without doing any /// additional analysis. void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { - DEBUG(dbgs() << "\n== ObjCARCOpt::OptimizeIndividualCalls ==\n"); + LLVM_DEBUG(dbgs() << "\n== ObjCARCOpt::OptimizeIndividualCalls ==\n"); // Reset all the flags in preparation for recomputing them. 
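createCallInst in the contract pass and CloneCallInstForBB above implement the same WinEH rule: a call materialized inside a funclet must name its enclosing EH pad through a "funclet" operand bundle so later EH preparation can attribute the call to the right funclet. The core of the pattern, assuming BlockColors came from colorEHFunclets(F) and Callee/Args/InsertPt describe the call being built:

SmallVector<OperandBundleDef, 1> OpBundles;
if (!BlockColors.empty()) {
  const ColorVector &CV = BlockColors.find(InsertPt->getParent())->second;
  assert(CV.size() == 1 && "non-unique color for block!");
  Instruction *EHPad = CV.front()->getFirstNonPHI();
  if (EHPad->isEHPad())
    OpBundles.emplace_back("funclet", EHPad); // tag the call with its funclet
}
CallInst *NewCall = CallInst::Create(Callee, Args, OpBundles, "", InsertPt);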
UsedInThisFunction = 0; + DenseMap<BasicBlock *, ColorVector> BlockColors; + if (F.hasPersonalityFn() && + isScopedEHPersonality(classifyEHPersonality(F.getPersonalityFn()))) + BlockColors = colorEHFunclets(F); + // Visit all objc_* calls in F. for (inst_iterator I = inst_begin(&F), E = inst_end(&F); I != E; ) { Instruction *Inst = &*I++; ARCInstKind Class = GetBasicARCInstKind(Inst); - DEBUG(dbgs() << "Visiting: Class: " << Class << "; " << *Inst << "\n"); + LLVM_DEBUG(dbgs() << "Visiting: Class: " << Class << "; " << *Inst << "\n"); switch (Class) { default: break; @@ -708,7 +747,7 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { case ARCInstKind::NoopCast: Changed = true; ++NumNoops; - DEBUG(dbgs() << "Erasing no-op cast: " << *Inst << "\n"); + LLVM_DEBUG(dbgs() << "Erasing no-op cast: " << *Inst << "\n"); EraseInstruction(Inst); continue; @@ -726,8 +765,10 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { Constant::getNullValue(Ty), CI); Value *NewValue = UndefValue::get(CI->getType()); - DEBUG(dbgs() << "A null pointer-to-weak-pointer is undefined behavior." - "\nOld = " << *CI << "\nNew = " << *NewValue << "\n"); + LLVM_DEBUG( + dbgs() << "A null pointer-to-weak-pointer is undefined behavior." + "\nOld = " + << *CI << "\nNew = " << *NewValue << "\n"); CI->replaceAllUsesWith(NewValue); CI->eraseFromParent(); continue; @@ -746,8 +787,10 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { CI); Value *NewValue = UndefValue::get(CI->getType()); - DEBUG(dbgs() << "A null pointer-to-weak-pointer is undefined behavior." - "\nOld = " << *CI << "\nNew = " << *NewValue << "\n"); + LLVM_DEBUG( + dbgs() << "A null pointer-to-weak-pointer is undefined behavior." + "\nOld = " + << *CI << "\nNew = " << *NewValue << "\n"); CI->replaceAllUsesWith(NewValue); CI->eraseFromParent(); @@ -782,9 +825,10 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { NewCall->setMetadata(MDKindCache.get(ARCMDKindID::ImpreciseRelease), MDNode::get(C, None)); - DEBUG(dbgs() << "Replacing autorelease{,RV}(x) with objc_release(x) " - "since x is otherwise unused.\nOld: " << *Call << "\nNew: " - << *NewCall << "\n"); + LLVM_DEBUG( + dbgs() << "Replacing autorelease{,RV}(x) with objc_release(x) " + "since x is otherwise unused.\nOld: " + << *Call << "\nNew: " << *NewCall << "\n"); EraseInstruction(Call); Inst = NewCall; @@ -796,8 +840,10 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { // a tail keyword. if (IsAlwaysTail(Class)) { Changed = true; - DEBUG(dbgs() << "Adding tail keyword to function since it can never be " - "passed stack args: " << *Inst << "\n"); + LLVM_DEBUG( + dbgs() << "Adding tail keyword to function since it can never be " + "passed stack args: " + << *Inst << "\n"); cast<CallInst>(Inst)->setTailCall(); } @@ -805,16 +851,16 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { // semantics of ARC truly do not do so. if (IsNeverTail(Class)) { Changed = true; - DEBUG(dbgs() << "Removing tail keyword from function: " << *Inst << - "\n"); + LLVM_DEBUG(dbgs() << "Removing tail keyword from function: " << *Inst + << "\n"); cast<CallInst>(Inst)->setTailCall(false); } // Set nounwind as needed. if (IsNoThrow(Class)) { Changed = true; - DEBUG(dbgs() << "Found no throw class. Setting nounwind on: " << *Inst - << "\n"); + LLVM_DEBUG(dbgs() << "Found no throw class. 
Setting nounwind on: " + << *Inst << "\n"); cast<CallInst>(Inst)->setDoesNotThrow(); } @@ -829,8 +875,8 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { if (IsNullOrUndef(Arg)) { Changed = true; ++NumNoops; - DEBUG(dbgs() << "ARC calls with null are no-ops. Erasing: " << *Inst - << "\n"); + LLVM_DEBUG(dbgs() << "ARC calls with null are no-ops. Erasing: " << *Inst + << "\n"); EraseInstruction(Inst); continue; } @@ -922,22 +968,24 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { Value *Incoming = GetRCIdentityRoot(PN->getIncomingValue(i)); if (!IsNullOrUndef(Incoming)) { - CallInst *Clone = cast<CallInst>(CInst->clone()); Value *Op = PN->getIncomingValue(i); Instruction *InsertPos = &PN->getIncomingBlock(i)->back(); + CallInst *Clone = cast<CallInst>(CloneCallInstForBB( + *CInst, *InsertPos->getParent(), BlockColors)); if (Op->getType() != ParamTy) Op = new BitCastInst(Op, ParamTy, "", InsertPos); Clone->setArgOperand(0, Op); Clone->insertBefore(InsertPos); - DEBUG(dbgs() << "Cloning " - << *CInst << "\n" - "And inserting clone at " << *InsertPos << "\n"); + LLVM_DEBUG(dbgs() << "Cloning " << *CInst + << "\n" + "And inserting clone at " + << *InsertPos << "\n"); Worklist.push_back(std::make_pair(Clone, Incoming)); } } // Erase the original call. - DEBUG(dbgs() << "Erasing: " << *CInst << "\n"); + LLVM_DEBUG(dbgs() << "Erasing: " << *CInst << "\n"); EraseInstruction(CInst); continue; } @@ -1114,7 +1162,7 @@ bool ObjCARCOpt::VisitInstructionBottomUp( ARCInstKind Class = GetARCInstKind(Inst); const Value *Arg = nullptr; - DEBUG(dbgs() << " Class: " << Class << "\n"); + LLVM_DEBUG(dbgs() << " Class: " << Class << "\n"); switch (Class) { case ARCInstKind::Release: { @@ -1137,7 +1185,7 @@ bool ObjCARCOpt::VisitInstructionBottomUp( // Don't do retain+release tracking for ARCInstKind::RetainRV, because // it's better to let it remain as the first instruction after a call. if (Class != ARCInstKind::RetainRV) { - DEBUG(dbgs() << " Matching with: " << *Inst << "\n"); + LLVM_DEBUG(dbgs() << " Matching with: " << *Inst << "\n"); Retains[Inst] = S.GetRRInfo(); } S.ClearSequenceProgress(); @@ -1179,7 +1227,7 @@ bool ObjCARCOpt::VisitInstructionBottomUp( bool ObjCARCOpt::VisitBottomUp(BasicBlock *BB, DenseMap<const BasicBlock *, BBState> &BBStates, BlotMapVector<Value *, RRInfo> &Retains) { - DEBUG(dbgs() << "\n== ObjCARCOpt::VisitBottomUp ==\n"); + LLVM_DEBUG(dbgs() << "\n== ObjCARCOpt::VisitBottomUp ==\n"); bool NestingDetected = false; BBState &MyStates = BBStates[BB]; @@ -1202,8 +1250,9 @@ bool ObjCARCOpt::VisitBottomUp(BasicBlock *BB, } } - DEBUG(dbgs() << "Before:\n" << BBStates[BB] << "\n" - << "Performing Dataflow:\n"); + LLVM_DEBUG(dbgs() << "Before:\n" + << BBStates[BB] << "\n" + << "Performing Dataflow:\n"); // Visit all the instructions, bottom-up. 
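The mechanical DEBUG -> LLVM_DEBUG rename that dominates this import keeps the old semantics: the body is compiled out of NDEBUG builds and, in assertion builds, prints only when debug output for the file's DEBUG_TYPE is enabled. A self-contained sketch of the idiom:

#include "llvm/IR/Instruction.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "objc-arc-opts"
using namespace llvm;

void visit(Instruction *Inst) {
  // Emitted only under -debug, or -debug-only=objc-arc-opts.
  LLVM_DEBUG(dbgs() << "Visiting: " << *Inst << "\n");
}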
for (BasicBlock::iterator I = BB->end(), E = BB->begin(); I != E; --I) { @@ -1213,7 +1262,7 @@ bool ObjCARCOpt::VisitBottomUp(BasicBlock *BB, if (isa<InvokeInst>(Inst)) continue; - DEBUG(dbgs() << " Visiting " << *Inst << "\n"); + LLVM_DEBUG(dbgs() << " Visiting " << *Inst << "\n"); NestingDetected |= VisitInstructionBottomUp(Inst, BB, Retains, MyStates); } @@ -1228,7 +1277,7 @@ bool ObjCARCOpt::VisitBottomUp(BasicBlock *BB, NestingDetected |= VisitInstructionBottomUp(II, BB, Retains, MyStates); } - DEBUG(dbgs() << "\nFinal State:\n" << BBStates[BB] << "\n"); + LLVM_DEBUG(dbgs() << "\nFinal State:\n" << BBStates[BB] << "\n"); return NestingDetected; } @@ -1241,7 +1290,7 @@ ObjCARCOpt::VisitInstructionTopDown(Instruction *Inst, ARCInstKind Class = GetARCInstKind(Inst); const Value *Arg = nullptr; - DEBUG(dbgs() << " Class: " << Class << "\n"); + LLVM_DEBUG(dbgs() << " Class: " << Class << "\n"); switch (Class) { case ARCInstKind::RetainBlock: @@ -1267,7 +1316,7 @@ ObjCARCOpt::VisitInstructionTopDown(Instruction *Inst, if (S.MatchWithRelease(MDKindCache, Inst)) { // If we succeed, copy S's RRInfo into the Release -> {Retain Set // Map}. Then we clear S. - DEBUG(dbgs() << " Matching with: " << *Inst << "\n"); + LLVM_DEBUG(dbgs() << " Matching with: " << *Inst << "\n"); Releases[Inst] = S.GetRRInfo(); S.ClearSequenceProgress(); } @@ -1307,7 +1356,7 @@ bool ObjCARCOpt::VisitTopDown(BasicBlock *BB, DenseMap<const BasicBlock *, BBState> &BBStates, DenseMap<Value *, RRInfo> &Releases) { - DEBUG(dbgs() << "\n== ObjCARCOpt::VisitTopDown ==\n"); + LLVM_DEBUG(dbgs() << "\n== ObjCARCOpt::VisitTopDown ==\n"); bool NestingDetected = false; BBState &MyStates = BBStates[BB]; @@ -1329,20 +1378,21 @@ ObjCARCOpt::VisitTopDown(BasicBlock *BB, } } - DEBUG(dbgs() << "Before:\n" << BBStates[BB] << "\n" - << "Performing Dataflow:\n"); + LLVM_DEBUG(dbgs() << "Before:\n" + << BBStates[BB] << "\n" + << "Performing Dataflow:\n"); // Visit all the instructions, top-down. for (Instruction &Inst : *BB) { - DEBUG(dbgs() << " Visiting " << Inst << "\n"); + LLVM_DEBUG(dbgs() << " Visiting " << Inst << "\n"); NestingDetected |= VisitInstructionTopDown(&Inst, Releases, MyStates); } - DEBUG(dbgs() << "\nState Before Checking for CFG Hazards:\n" - << BBStates[BB] << "\n\n"); + LLVM_DEBUG(dbgs() << "\nState Before Checking for CFG Hazards:\n" + << BBStates[BB] << "\n\n"); CheckForCFGHazards(BB, BBStates, MyStates); - DEBUG(dbgs() << "Final State:\n" << BBStates[BB] << "\n"); + LLVM_DEBUG(dbgs() << "Final State:\n" << BBStates[BB] << "\n"); return NestingDetected; } @@ -1465,7 +1515,7 @@ void ObjCARCOpt::MoveCalls(Value *Arg, RRInfo &RetainsToMove, Type *ArgTy = Arg->getType(); Type *ParamTy = PointerType::getUnqual(Type::getInt8Ty(ArgTy->getContext())); - DEBUG(dbgs() << "== ObjCARCOpt::MoveCalls ==\n"); + LLVM_DEBUG(dbgs() << "== ObjCARCOpt::MoveCalls ==\n"); // Insert the new retain and release calls. for (Instruction *InsertPt : ReleasesToMove.ReverseInsertPts) { @@ -1476,8 +1526,10 @@ void ObjCARCOpt::MoveCalls(Value *Arg, RRInfo &RetainsToMove, Call->setDoesNotThrow(); Call->setTailCall(); - DEBUG(dbgs() << "Inserting new Retain: " << *Call << "\n" - "At insertion point: " << *InsertPt << "\n"); + LLVM_DEBUG(dbgs() << "Inserting new Retain: " << *Call + << "\n" + "At insertion point: " + << *InsertPt << "\n"); } for (Instruction *InsertPt : RetainsToMove.ReverseInsertPts) { Value *MyArg = ArgTy == ParamTy ? 
Arg : @@ -1491,20 +1543,22 @@ void ObjCARCOpt::MoveCalls(Value *Arg, RRInfo &RetainsToMove, if (ReleasesToMove.IsTailCallRelease) Call->setTailCall(); - DEBUG(dbgs() << "Inserting new Release: " << *Call << "\n" - "At insertion point: " << *InsertPt << "\n"); + LLVM_DEBUG(dbgs() << "Inserting new Release: " << *Call + << "\n" + "At insertion point: " + << *InsertPt << "\n"); } // Delete the original retain and release calls. for (Instruction *OrigRetain : RetainsToMove.Calls) { Retains.blot(OrigRetain); DeadInsts.push_back(OrigRetain); - DEBUG(dbgs() << "Deleting retain: " << *OrigRetain << "\n"); + LLVM_DEBUG(dbgs() << "Deleting retain: " << *OrigRetain << "\n"); } for (Instruction *OrigRelease : ReleasesToMove.Calls) { Releases.erase(OrigRelease); DeadInsts.push_back(OrigRelease); - DEBUG(dbgs() << "Deleting release: " << *OrigRelease << "\n"); + LLVM_DEBUG(dbgs() << "Deleting release: " << *OrigRelease << "\n"); } } @@ -1538,6 +1592,7 @@ bool ObjCARCOpt::PairUpRetainsAndReleases( assert(It != Retains.end()); const RRInfo &NewRetainRRI = It->second; KnownSafeTD &= NewRetainRRI.KnownSafe; + CFGHazardAfflicted |= NewRetainRRI.CFGHazardAfflicted; for (Instruction *NewRetainRelease : NewRetainRRI.Calls) { auto Jt = Releases.find(NewRetainRelease); if (Jt == Releases.end()) @@ -1710,7 +1765,7 @@ bool ObjCARCOpt::PerformCodePlacement( DenseMap<const BasicBlock *, BBState> &BBStates, BlotMapVector<Value *, RRInfo> &Retains, DenseMap<Value *, RRInfo> &Releases, Module *M) { - DEBUG(dbgs() << "\n== ObjCARCOpt::PerformCodePlacement ==\n"); + LLVM_DEBUG(dbgs() << "\n== ObjCARCOpt::PerformCodePlacement ==\n"); bool AnyPairsCompletelyEliminated = false; SmallVector<Instruction *, 8> DeadInsts; @@ -1724,7 +1779,7 @@ bool ObjCARCOpt::PerformCodePlacement( Instruction *Retain = cast<Instruction>(V); - DEBUG(dbgs() << "Visiting: " << *Retain << "\n"); + LLVM_DEBUG(dbgs() << "Visiting: " << *Retain << "\n"); Value *Arg = GetArgRCIdentityRoot(Retain); @@ -1769,7 +1824,7 @@ bool ObjCARCOpt::PerformCodePlacement( /// Weak pointer optimizations. void ObjCARCOpt::OptimizeWeakCalls(Function &F) { - DEBUG(dbgs() << "\n== ObjCARCOpt::OptimizeWeakCalls ==\n"); + LLVM_DEBUG(dbgs() << "\n== ObjCARCOpt::OptimizeWeakCalls ==\n"); // First, do memdep-style RLE and S2L optimizations. We can't use memdep // itself because it uses AliasAnalysis and we need to do provenance @@ -1777,7 +1832,7 @@ void ObjCARCOpt::OptimizeWeakCalls(Function &F) { for (inst_iterator I = inst_begin(&F), E = inst_end(&F); I != E; ) { Instruction *Inst = &*I++; - DEBUG(dbgs() << "Visiting: " << *Inst << "\n"); + LLVM_DEBUG(dbgs() << "Visiting: " << *Inst << "\n"); ARCInstKind Class = GetBasicARCInstKind(Inst); if (Class != ARCInstKind::LoadWeak && @@ -2036,7 +2091,7 @@ void ObjCARCOpt::OptimizeReturns(Function &F) { if (!F.getReturnType()->isPointerTy()) return; - DEBUG(dbgs() << "\n== ObjCARCOpt::OptimizeReturns ==\n"); + LLVM_DEBUG(dbgs() << "\n== ObjCARCOpt::OptimizeReturns ==\n"); SmallPtrSet<Instruction *, 4> DependingInstructions; SmallPtrSet<const BasicBlock *, 4> Visited; @@ -2045,7 +2100,7 @@ void ObjCARCOpt::OptimizeReturns(Function &F) { if (!Ret) continue; - DEBUG(dbgs() << "Visiting: " << *Ret << "\n"); + LLVM_DEBUG(dbgs() << "Visiting: " << *Ret << "\n"); const Value *Arg = GetRCIdentityRoot(Ret->getOperand(0)); @@ -2083,8 +2138,8 @@ void ObjCARCOpt::OptimizeReturns(Function &F) { // If so, we can zap the retain and autorelease. 
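MoveCalls above materializes the moved retain/release pair through the pass's ARCRuntimeEntryPoints cache (EP) rather than by name lookup, and immediately pins down the attributes ARC guarantees for these calls. The insertion idiom, with MyArg and InsertPt taken from the surrounding loop:

Constant *Decl = EP.get(ARCRuntimeEntryPointKind::Retain);
CallInst *Call = CallInst::Create(Decl, MyArg, "", InsertPt);
Call->setDoesNotThrow(); // objc_retain never unwinds
Call->setTailCall();     // retains inserted here may always be tail calls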
Changed = true; ++NumRets; - DEBUG(dbgs() << "Erasing: " << *Retain << "\nErasing: " - << *Autorelease << "\n"); + LLVM_DEBUG(dbgs() << "Erasing: " << *Retain << "\nErasing: " << *Autorelease + << "\n"); EraseInstruction(Retain); EraseInstruction(Autorelease); } @@ -2144,8 +2199,9 @@ bool ObjCARCOpt::runOnFunction(Function &F) { Changed = false; - DEBUG(dbgs() << "<<< ObjCARCOpt: Visiting Function: " << F.getName() << " >>>" - "\n"); + LLVM_DEBUG(dbgs() << "<<< ObjCARCOpt: Visiting Function: " << F.getName() + << " >>>" + "\n"); PA.setAA(&getAnalysis<AAResultsWrapperPass>().getAAResults()); @@ -2193,7 +2249,7 @@ bool ObjCARCOpt::runOnFunction(Function &F) { } #endif - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << "\n"); return Changed; } diff --git a/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp b/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp index f89fc8eb62aa..3004fffb9745 100644 --- a/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp +++ b/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp @@ -115,14 +115,6 @@ static bool IsStoredObjCPointer(const Value *P) { bool ProvenanceAnalysis::relatedCheck(const Value *A, const Value *B, const DataLayout &DL) { - // Skip past provenance pass-throughs. - A = GetUnderlyingObjCPtr(A, DL); - B = GetUnderlyingObjCPtr(B, DL); - - // Quick check. - if (A == B) - return true; - // Ask regular AliasAnalysis, for a first approximation. switch (AA->alias(A, B)) { case NoAlias: @@ -171,6 +163,13 @@ bool ProvenanceAnalysis::relatedCheck(const Value *A, const Value *B, bool ProvenanceAnalysis::related(const Value *A, const Value *B, const DataLayout &DL) { + A = GetUnderlyingObjCPtrCached(A, DL, UnderlyingObjCPtrCache); + B = GetUnderlyingObjCPtrCached(B, DL, UnderlyingObjCPtrCache); + + // Quick check. + if (A == B) + return true; + // Begin by inserting a conservative value into the map. If the insertion // fails, we have the answer already. If it succeeds, leave it there until we // compute the real answer to guard against recursive queries. diff --git a/lib/Transforms/ObjCARC/ProvenanceAnalysis.h b/lib/Transforms/ObjCARC/ProvenanceAnalysis.h index 5e676167a6a1..1276f564a022 100644 --- a/lib/Transforms/ObjCARC/ProvenanceAnalysis.h +++ b/lib/Transforms/ObjCARC/ProvenanceAnalysis.h @@ -28,6 +28,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/IR/ValueHandle.h" #include <utility> namespace llvm { @@ -39,7 +40,7 @@ class Value; namespace objcarc { -/// \brief This is similar to BasicAliasAnalysis, and it uses many of the same +/// This is similar to BasicAliasAnalysis, and it uses many of the same /// techniques, except it uses special ObjC-specific reasoning about pointer /// relationships. 
/// @@ -56,6 +57,8 @@ class ProvenanceAnalysis { CachedResultsTy CachedResults; + DenseMap<const Value *, WeakTrackingVH> UnderlyingObjCPtrCache; + bool relatedCheck(const Value *A, const Value *B, const DataLayout &DL); bool relatedSelect(const SelectInst *A, const Value *B); bool relatedPHI(const PHINode *A, const Value *B); @@ -73,6 +76,7 @@ public: void clear() { CachedResults.clear(); + UnderlyingObjCPtrCache.clear(); } }; diff --git a/lib/Transforms/ObjCARC/PtrState.cpp b/lib/Transforms/ObjCARC/PtrState.cpp index e1774b88fd35..8a7b6a74fae2 100644 --- a/lib/Transforms/ObjCARC/PtrState.cpp +++ b/lib/Transforms/ObjCARC/PtrState.cpp @@ -126,22 +126,23 @@ bool RRInfo::Merge(const RRInfo &Other) { //===----------------------------------------------------------------------===// void PtrState::SetKnownPositiveRefCount() { - DEBUG(dbgs() << " Setting Known Positive.\n"); + LLVM_DEBUG(dbgs() << " Setting Known Positive.\n"); KnownPositiveRefCount = true; } void PtrState::ClearKnownPositiveRefCount() { - DEBUG(dbgs() << " Clearing Known Positive.\n"); + LLVM_DEBUG(dbgs() << " Clearing Known Positive.\n"); KnownPositiveRefCount = false; } void PtrState::SetSeq(Sequence NewSeq) { - DEBUG(dbgs() << " Old: " << GetSeq() << "; New: " << NewSeq << "\n"); + LLVM_DEBUG(dbgs() << " Old: " << GetSeq() << "; New: " << NewSeq + << "\n"); Seq = NewSeq; } void PtrState::ResetSequenceProgress(Sequence NewSeq) { - DEBUG(dbgs() << " Resetting sequence progress.\n"); + LLVM_DEBUG(dbgs() << " Resetting sequence progress.\n"); SetSeq(NewSeq); Partial = false; RRI.clear(); @@ -184,7 +185,8 @@ bool BottomUpPtrState::InitBottomUp(ARCMDKindCache &Cache, Instruction *I) { // simple and avoids adding overhead for the non-nested case. bool NestingDetected = false; if (GetSeq() == S_Release || GetSeq() == S_MovableRelease) { - DEBUG(dbgs() << " Found nested releases (i.e. a release pair)\n"); + LLVM_DEBUG( + dbgs() << " Found nested releases (i.e. a release pair)\n"); NestingDetected = true; } @@ -234,8 +236,8 @@ bool BottomUpPtrState::HandlePotentialAlterRefCount(Instruction *Inst, if (!CanAlterRefCount(Inst, Ptr, PA, Class)) return false; - DEBUG(dbgs() << " CanAlterRefCount: Seq: " << S << "; " << *Ptr - << "\n"); + LLVM_DEBUG(dbgs() << " CanAlterRefCount: Seq: " << S << "; " + << *Ptr << "\n"); switch (S) { case S_Use: SetSeq(S_CanRelease); @@ -266,6 +268,11 @@ void BottomUpPtrState::HandlePotentialUse(BasicBlock *BB, Instruction *Inst, if (isa<InvokeInst>(Inst)) { const auto IP = BB->getFirstInsertionPt(); InsertAfter = IP == BB->end() ? std::prev(BB->end()) : IP; + if (isa<CatchSwitchInst>(InsertAfter)) + // A catchswitch must be the only non-phi instruction in its basic + // block, so attempting to insert an instruction into such a block would + // produce invalid IR. + SetCFGHazardAfflicted(true); } else { InsertAfter = std::next(Inst->getIterator()); } @@ -277,26 +284,26 @@ void BottomUpPtrState::HandlePotentialUse(BasicBlock *BB, Instruction *Inst, case S_Release: case S_MovableRelease: if (CanUse(Inst, Ptr, PA, Class)) { - DEBUG(dbgs() << " CanUse: Seq: " << GetSeq() << "; " << *Ptr - << "\n"); + LLVM_DEBUG(dbgs() << " CanUse: Seq: " << GetSeq() << "; " + << *Ptr << "\n"); SetSeqAndInsertReverseInsertPt(S_Use); } else if (Seq == S_Release && IsUser(Class)) { - DEBUG(dbgs() << " PreciseReleaseUse: Seq: " << GetSeq() << "; " - << *Ptr << "\n"); + LLVM_DEBUG(dbgs() << " PreciseReleaseUse: Seq: " << GetSeq() + << "; " << *Ptr << "\n"); // Non-movable releases depend on any possible objc pointer use. 
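ProvenanceAnalysis now hoists the GetUnderlyingObjCPtr walk out of relatedCheck and memoizes it in UnderlyingObjCPtrCache, keyed by Value with WeakTrackingVH results: a cached entry follows RAUW to the replacement value and goes null, rather than dangling, if the value is deleted. A stripped-down sketch of the same caching shape (GetUnderlyingObjCPtr is the real helper from llvm/Analysis/ObjCARCAnalysisUtils.h; the wrapper below is illustrative):

#include "llvm/ADT/DenseMap.h"
#include "llvm/Analysis/ObjCARCAnalysisUtils.h"
#include "llvm/IR/ValueHandle.h"
using namespace llvm;

static DenseMap<const Value *, WeakTrackingVH> Cache;

const Value *getUnderlyingCached(const Value *V, const DataLayout &DL) {
  auto &Slot = Cache[V];
  if (!Slot) // never computed, or the cached value has since been deleted
    Slot = const_cast<Value *>(objcarc::GetUnderlyingObjCPtr(V, DL));
  return Slot;
}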
SetSeqAndInsertReverseInsertPt(S_Stop); } else if (const auto *Call = getreturnRVOperand(*Inst, Class)) { if (CanUse(Call, Ptr, PA, GetBasicARCInstKind(Call))) { - DEBUG(dbgs() << " ReleaseUse: Seq: " << GetSeq() << "; " - << *Ptr << "\n"); + LLVM_DEBUG(dbgs() << " ReleaseUse: Seq: " << GetSeq() << "; " + << *Ptr << "\n"); SetSeqAndInsertReverseInsertPt(S_Stop); } } break; case S_Stop: if (CanUse(Inst, Ptr, PA, Class)) { - DEBUG(dbgs() << " PreciseStopUse: Seq: " << GetSeq() << "; " - << *Ptr << "\n"); + LLVM_DEBUG(dbgs() << " PreciseStopUse: Seq: " << GetSeq() + << "; " << *Ptr << "\n"); SetSeq(S_Use); } break; @@ -377,8 +384,8 @@ bool TopDownPtrState::HandlePotentialAlterRefCount(Instruction *Inst, Class != ARCInstKind::IntrinsicUser) return false; - DEBUG(dbgs() << " CanAlterRefCount: Seq: " << GetSeq() << "; " << *Ptr - << "\n"); + LLVM_DEBUG(dbgs() << " CanAlterRefCount: Seq: " << GetSeq() << "; " + << *Ptr << "\n"); ClearKnownPositiveRefCount(); switch (GetSeq()) { case S_Retain: @@ -410,8 +417,8 @@ void TopDownPtrState::HandlePotentialUse(Instruction *Inst, const Value *Ptr, case S_CanRelease: if (!CanUse(Inst, Ptr, PA, Class)) return; - DEBUG(dbgs() << " CanUse: Seq: " << GetSeq() << "; " << *Ptr - << "\n"); + LLVM_DEBUG(dbgs() << " CanUse: Seq: " << GetSeq() << "; " + << *Ptr << "\n"); SetSeq(S_Use); return; case S_Retain: diff --git a/lib/Transforms/ObjCARC/PtrState.h b/lib/Transforms/ObjCARC/PtrState.h index e1e95afcf76b..f5b9b853d8e3 100644 --- a/lib/Transforms/ObjCARC/PtrState.h +++ b/lib/Transforms/ObjCARC/PtrState.h @@ -36,7 +36,7 @@ class ProvenanceAnalysis; /// \enum Sequence /// -/// \brief A sequence of states that a pointer may go through in which an +/// A sequence of states that a pointer may go through in which an /// objc_retain and objc_release are actually needed. enum Sequence { S_None, @@ -51,7 +51,7 @@ enum Sequence { raw_ostream &operator<<(raw_ostream &OS, const Sequence S) LLVM_ATTRIBUTE_UNUSED; -/// \brief Unidirectional information about either a +/// Unidirectional information about either a /// retain-decrement-use-release sequence or release-use-decrement-retain /// reverse sequence. struct RRInfo { @@ -97,7 +97,7 @@ struct RRInfo { bool Merge(const RRInfo &Other); }; -/// \brief This class summarizes several per-pointer runtime properties which +/// This class summarizes several per-pointer runtime properties which /// are propagated through the flow graph. class PtrState { protected: diff --git a/lib/Transforms/Scalar/ADCE.cpp b/lib/Transforms/Scalar/ADCE.cpp index 1e683db50206..ce09a477b5f5 100644 --- a/lib/Transforms/Scalar/ADCE.cpp +++ b/lib/Transforms/Scalar/ADCE.cpp @@ -174,8 +174,8 @@ class AggressiveDeadCodeElimination { /// marked live. void markLiveBranchesFromControlDependences(); - /// Remove instructions not marked live, return if any any instruction - /// was removed. + /// Remove instructions not marked live, return if any instruction was + /// removed. bool removeDeadInstructions(); /// Identify connected sections of the control flow graph which have @@ -298,8 +298,8 @@ void AggressiveDeadCodeElimination::initialize() { auto &Info = BlockInfo[BB]; // Real function return if (isa<ReturnInst>(Info.Terminator)) { - DEBUG(dbgs() << "post-dom root child is a return: " << BB->getName() - << '\n';); + LLVM_DEBUG(dbgs() << "post-dom root child is a return: " << BB->getName() + << '\n';); continue; } @@ -356,7 +356,7 @@ void AggressiveDeadCodeElimination::markLiveInstructions() { // where we need to mark the inputs as live. 
while (!Worklist.empty()) { Instruction *LiveInst = Worklist.pop_back_val(); - DEBUG(dbgs() << "work live: "; LiveInst->dump();); + LLVM_DEBUG(dbgs() << "work live: "; LiveInst->dump();); for (Use &OI : LiveInst->operands()) if (Instruction *Inst = dyn_cast<Instruction>(OI)) @@ -378,7 +378,7 @@ void AggressiveDeadCodeElimination::markLive(Instruction *I) { if (Info.Live) return; - DEBUG(dbgs() << "mark live: "; I->dump()); + LLVM_DEBUG(dbgs() << "mark live: "; I->dump()); Info.Live = true; Worklist.push_back(I); @@ -402,7 +402,7 @@ void AggressiveDeadCodeElimination::markLive(Instruction *I) { void AggressiveDeadCodeElimination::markLive(BlockInfoType &BBInfo) { if (BBInfo.Live) return; - DEBUG(dbgs() << "mark block live: " << BBInfo.BB->getName() << '\n'); + LLVM_DEBUG(dbgs() << "mark block live: " << BBInfo.BB->getName() << '\n'); BBInfo.Live = true; if (!BBInfo.CFLive) { BBInfo.CFLive = true; @@ -463,7 +463,7 @@ void AggressiveDeadCodeElimination::markLiveBranchesFromControlDependences() { if (BlocksWithDeadTerminators.empty()) return; - DEBUG({ + LLVM_DEBUG({ dbgs() << "new live blocks:\n"; for (auto *BB : NewLiveBlocks) dbgs() << "\t" << BB->getName() << '\n'; @@ -487,7 +487,7 @@ void AggressiveDeadCodeElimination::markLiveBranchesFromControlDependences() { // Dead terminators which control live blocks are now marked live. for (auto *BB : IDFBlocks) { - DEBUG(dbgs() << "live control in: " << BB->getName() << '\n'); + LLVM_DEBUG(dbgs() << "live control in: " << BB->getName() << '\n'); markLive(BB->getTerminator()); } } @@ -501,7 +501,7 @@ bool AggressiveDeadCodeElimination::removeDeadInstructions() { // Updates control and dataflow around dead blocks updateDeadRegions(); - DEBUG({ + LLVM_DEBUG({ for (Instruction &I : instructions(F)) { // Check if the instruction is alive. if (isLive(&I)) @@ -555,7 +555,7 @@ bool AggressiveDeadCodeElimination::removeDeadInstructions() { // A dead region is the set of dead blocks with a common live post-dominator. void AggressiveDeadCodeElimination::updateDeadRegions() { - DEBUG({ + LLVM_DEBUG({ dbgs() << "final dead terminator blocks: " << '\n'; for (auto *BB : BlocksWithDeadTerminators) dbgs() << '\t' << BB->getName() @@ -607,8 +607,9 @@ void AggressiveDeadCodeElimination::updateDeadRegions() { // It might have happened that the same successor appeared multiple times // and the CFG edge wasn't really removed. 
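markLiveInstructions above is the classic backwards liveness worklist: seed with instructions that must stay for their side effects, then transitively pull in everything their operands compute. Schematically, with InstInfo standing in for the pass's per-instruction state table:

SmallVector<Instruction *, 128> Worklist;
auto markLive = [&](Instruction *I) {
  if (InstInfo[I].Live)
    return;               // already known live
  InstInfo[I].Live = true;
  Worklist.push_back(I);  // revisit later to mark its inputs
};
while (!Worklist.empty()) {
  Instruction *LiveInst = Worklist.pop_back_val();
  for (Use &OI : LiveInst->operands())
    if (auto *Op = dyn_cast<Instruction>(OI))
      markLive(Op);
}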
if (Succ != PreferredSucc->BB) { - DEBUG(dbgs() << "ADCE: (Post)DomTree edge enqueued for deletion" - << BB->getName() << " -> " << Succ->getName() << "\n"); + LLVM_DEBUG(dbgs() << "ADCE: (Post)DomTree edge enqueued for deletion" + << BB->getName() << " -> " << Succ->getName() + << "\n"); DeletedEdges.push_back({DominatorTree::Delete, BB, Succ}); } } @@ -652,7 +653,7 @@ void AggressiveDeadCodeElimination::makeUnconditional(BasicBlock *BB, InstInfo[PredTerm].Live = true; return; } - DEBUG(dbgs() << "making unconditional " << BB->getName() << '\n'); + LLVM_DEBUG(dbgs() << "making unconditional " << BB->getName() << '\n'); NumBranchesRemoved += 1; IRBuilder<> Builder(PredTerm); auto *NewTerm = Builder.CreateBr(Target); diff --git a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp index 99480f12da9e..fa7bcec677f7 100644 --- a/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -98,8 +98,8 @@ static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV, const SCEV *DiffAlign = SE->getMulExpr(DiffAlignDiv, AlignSCEV); const SCEV *DiffUnitsSCEV = SE->getMinusSCEV(DiffAlign, DiffSCEV); - DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is " << - *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n"); + LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is " + << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n"); if (const SCEVConstant *ConstDUSCEV = dyn_cast<SCEVConstant>(DiffUnitsSCEV)) { @@ -139,12 +139,12 @@ static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV, // address. This address is displaced by the provided offset. DiffSCEV = SE->getMinusSCEV(DiffSCEV, OffSCEV); - DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to " << - *AlignSCEV << " and offset " << *OffSCEV << - " using diff " << *DiffSCEV << "\n"); + LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to " + << *AlignSCEV << " and offset " << *OffSCEV + << " using diff " << *DiffSCEV << "\n"); unsigned NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE); - DEBUG(dbgs() << "\tnew alignment: " << NewAlignment << "\n"); + LLVM_DEBUG(dbgs() << "\tnew alignment: " << NewAlignment << "\n"); if (NewAlignment) { return NewAlignment; @@ -160,8 +160,8 @@ static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV, const SCEV *DiffStartSCEV = DiffARSCEV->getStart(); const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE); - DEBUG(dbgs() << "\ttrying start/inc alignment using start " << - *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n"); + LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start " + << *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n"); // Now compute the new alignment using the displacement to the value in the // first iteration, and also the alignment using the per-iteration delta. 
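The ADCE hunk above only enqueues {DominatorTree::Delete, BB, Succ}; the point of the queue is that the dominator and post-dominator trees are updated in one batch after the dead region is rewritten, not eagerly per edge. A sketch of the shape this takes, assuming DT and PDT are the pass's tree pointers and the applyUpdates calls happen once the CFG edits are done:

SmallVector<DominatorTree::UpdateType, 4> DeletedEdges;
// ...one entry pushed per CFG edge actually removed...
DeletedEdges.push_back({DominatorTree::Delete, BB, Succ});
DT->applyUpdates(DeletedEdges);  // batched incremental update
PDT->applyUpdates(DeletedEdges); // the same edge list serves both trees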
@@ -170,26 +170,26 @@ static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV, NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE); unsigned NewIncAlignment = getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE); - DEBUG(dbgs() << "\tnew start alignment: " << NewAlignment << "\n"); - DEBUG(dbgs() << "\tnew inc alignment: " << NewIncAlignment << "\n"); + LLVM_DEBUG(dbgs() << "\tnew start alignment: " << NewAlignment << "\n"); + LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << NewIncAlignment << "\n"); if (!NewAlignment || !NewIncAlignment) { return 0; } else if (NewAlignment > NewIncAlignment) { if (NewAlignment % NewIncAlignment == 0) { - DEBUG(dbgs() << "\tnew start/inc alignment: " << - NewIncAlignment << "\n"); + LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewIncAlignment + << "\n"); return NewIncAlignment; } } else if (NewIncAlignment > NewAlignment) { if (NewIncAlignment % NewAlignment == 0) { - DEBUG(dbgs() << "\tnew start/inc alignment: " << - NewAlignment << "\n"); + LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment + << "\n"); return NewAlignment; } } else if (NewIncAlignment == NewAlignment) { - DEBUG(dbgs() << "\tnew start/inc alignment: " << - NewAlignment << "\n"); + LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment + << "\n"); return NewAlignment; } } @@ -339,55 +339,24 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) { unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE); - // For memory transfers, we need a common alignment for both the - // source and destination. If we have a new alignment for this - // instruction, but only for one operand, save it. If we reach the - // other operand through another assumption later, then we may - // change the alignment at that point. + LLVM_DEBUG(dbgs() << "\tmem inst: " << NewDestAlignment << "\n";); + if (NewDestAlignment > MI->getDestAlignment()) { + MI->setDestAlignment(NewDestAlignment); + ++NumMemIntAlignChanged; + } + + // For memory transfers, there is also a source alignment that + // can be set. if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE); - DenseMap<MemTransferInst *, unsigned>::iterator DI = - NewDestAlignments.find(MTI); - unsigned AltDestAlignment = (DI == NewDestAlignments.end()) ? - 0 : DI->second; - - DenseMap<MemTransferInst *, unsigned>::iterator SI = - NewSrcAlignments.find(MTI); - unsigned AltSrcAlignment = (SI == NewSrcAlignments.end()) ? - 0 : SI->second; - - DEBUG(dbgs() << "\tmem trans: " << NewDestAlignment << " " << - AltDestAlignment << " " << NewSrcAlignment << - " " << AltSrcAlignment << "\n"); - - // Of these four alignments, pick the largest possible... 
- unsigned NewAlignment = 0; - if (NewDestAlignment <= std::max(NewSrcAlignment, AltSrcAlignment)) - NewAlignment = std::max(NewAlignment, NewDestAlignment); - if (AltDestAlignment <= std::max(NewSrcAlignment, AltSrcAlignment)) - NewAlignment = std::max(NewAlignment, AltDestAlignment); - if (NewSrcAlignment <= std::max(NewDestAlignment, AltDestAlignment)) - NewAlignment = std::max(NewAlignment, NewSrcAlignment); - if (AltSrcAlignment <= std::max(NewDestAlignment, AltDestAlignment)) - NewAlignment = std::max(NewAlignment, AltSrcAlignment); - - if (NewAlignment > MI->getAlignment()) { - MI->setAlignment(ConstantInt::get(Type::getInt32Ty( - MI->getParent()->getContext()), NewAlignment)); + LLVM_DEBUG(dbgs() << "\tmem trans: " << NewSrcAlignment << "\n";); + + if (NewSrcAlignment > MTI->getSourceAlignment()) { + MTI->setSourceAlignment(NewSrcAlignment); ++NumMemIntAlignChanged; } - - NewDestAlignments.insert(std::make_pair(MTI, NewDestAlignment)); - NewSrcAlignments.insert(std::make_pair(MTI, NewSrcAlignment)); - } else if (NewDestAlignment > MI->getAlignment()) { - assert((!isa<MemIntrinsic>(MI) || isa<MemSetInst>(MI)) && - "Unknown memory intrinsic"); - - MI->setAlignment(ConstantInt::get(Type::getInt32Ty( - MI->getParent()->getContext()), NewDestAlignment)); - ++NumMemIntAlignChanged; } } @@ -421,9 +390,6 @@ bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC, SE = SE_; DT = DT_; - NewDestAlignments.clear(); - NewSrcAlignments.clear(); - bool Changed = false; for (auto &AssumeVH : AC.assumptions()) if (AssumeVH) diff --git a/lib/Transforms/Scalar/BDCE.cpp b/lib/Transforms/Scalar/BDCE.cpp index 851efa000f65..3a8ef073cb48 100644 --- a/lib/Transforms/Scalar/BDCE.cpp +++ b/lib/Transforms/Scalar/BDCE.cpp @@ -20,6 +20,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/Pass.h" @@ -99,7 +100,7 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) { // For live instructions that have all dead bits, first make them dead by // replacing all uses with something else. Then, if they don't need to // remain live (because they have side effects, etc.) we can remove them. 
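The rewritten AlignmentFromAssumptions logic above reflects an IR-level change: memory intrinsics now carry separate destination and source alignments, so the pass can raise each side independently instead of tracking candidate pairs and computing one conservative shared value. The resulting idiom (NewDestAlign/NewSrcAlign stand for the SCEV-derived alignments from getNewAlignment above):

if (NewDestAlign > MI->getDestAlignment()) {
  MI->setDestAlignment(NewDestAlign);
  ++NumMemIntAlignChanged;
}
if (auto *MTI = dyn_cast<MemTransferInst>(MI))
  if (NewSrcAlign > MTI->getSourceAlignment()) {
    MTI->setSourceAlignment(NewSrcAlign);
    ++NumMemIntAlignChanged;
  }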
- DEBUG(dbgs() << "BDCE: Trivializing: " << I << " (all bits dead)\n"); + LLVM_DEBUG(dbgs() << "BDCE: Trivializing: " << I << " (all bits dead)\n"); clearAssumptionsOfUsers(&I, DB); @@ -114,6 +115,7 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) { if (!DB.isInstructionDead(&I)) continue; + salvageDebugInfo(I); Worklist.push_back(&I); I.dropAllReferences(); Changed = true; diff --git a/lib/Transforms/Scalar/CMakeLists.txt b/lib/Transforms/Scalar/CMakeLists.txt index 0562d3882f8b..fce37d4bffb8 100644 --- a/lib/Transforms/Scalar/CMakeLists.txt +++ b/lib/Transforms/Scalar/CMakeLists.txt @@ -20,6 +20,7 @@ add_llvm_library(LLVMScalarOpts InductiveRangeCheckElimination.cpp IndVarSimplify.cpp InferAddressSpaces.cpp + InstSimplifyPass.cpp JumpThreading.cpp LICM.cpp LoopAccessAnalysisPrinter.cpp @@ -38,6 +39,7 @@ add_llvm_library(LLVMScalarOpts LoopSimplifyCFG.cpp LoopStrengthReduce.cpp LoopUnrollPass.cpp + LoopUnrollAndJamPass.cpp LoopUnswitch.cpp LoopVersioningLICM.cpp LowerAtomic.cpp diff --git a/lib/Transforms/Scalar/CallSiteSplitting.cpp b/lib/Transforms/Scalar/CallSiteSplitting.cpp index 207243231aad..5ebfbf8a879b 100644 --- a/lib/Transforms/Scalar/CallSiteSplitting.cpp +++ b/lib/Transforms/Scalar/CallSiteSplitting.cpp @@ -59,12 +59,14 @@ #include "llvm/Transforms/Scalar/CallSiteSplitting.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/Cloning.h" using namespace llvm; using namespace PatternMatch; @@ -73,9 +75,16 @@ using namespace PatternMatch; STATISTIC(NumCallSiteSplit, "Number of call-site split"); -static void addNonNullAttribute(Instruction *CallI, Instruction *NewCallI, - Value *Op) { - CallSite CS(NewCallI); +/// Only allow instructions before a call, if their CodeSize cost is below +/// DuplicationThreshold. Those instructions need to be duplicated in all +/// split blocks. +static cl::opt<unsigned> + DuplicationThreshold("callsite-splitting-duplication-threshold", cl::Hidden, + cl::desc("Only allow instructions before a call, if " + "their cost is below DuplicationThreshold"), + cl::init(5)); + +static void addNonNullAttribute(CallSite CS, Value *Op) { unsigned ArgNo = 0; for (auto &I : CS.args()) { if (&*I == Op) @@ -84,13 +93,16 @@ static void addNonNullAttribute(Instruction *CallI, Instruction *NewCallI, } } -static void setConstantInArgument(Instruction *CallI, Instruction *NewCallI, - Value *Op, Constant *ConstValue) { - CallSite CS(NewCallI); +static void setConstantInArgument(CallSite CS, Value *Op, + Constant *ConstValue) { unsigned ArgNo = 0; for (auto &I : CS.args()) { - if (&*I == Op) + if (&*I == Op) { + // It is possible we have already added the non-null attribute to the + // parameter by using an earlier constraining condition. + CS.removeParamAttr(ArgNo, Attribute::NonNull); CS.setArgument(ArgNo, ConstValue); + } ++ArgNo; } } @@ -111,11 +123,13 @@ static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallSite CS) { return false; } +typedef std::pair<ICmpInst *, unsigned> ConditionTy; +typedef SmallVector<ConditionTy, 2> ConditionsTy; + /// If From has a conditional jump to To, add the condition to Conditions, /// if it is relevant to any argument at CS. 
-static void
-recordCondition(const CallSite &CS, BasicBlock *From, BasicBlock *To,
-                SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
+static void recordCondition(CallSite CS, BasicBlock *From, BasicBlock *To,
+                            ConditionsTy &Conditions) {
   auto *BI = dyn_cast<BranchInst>(From->getTerminator());
   if (!BI || !BI->isConditional())
     return;
@@ -134,40 +148,33 @@ recordCondition(const CallSite &CS, BasicBlock *From, BasicBlock *To,
 }
 
 /// Record ICmp conditions relevant to any argument in CS following Pred's
-/// single successors. If there are conflicting conditions along a path, like
+/// single predecessors. If there are conflicting conditions along a path, like
 /// x == 1 and x == 0, the first condition will be used.
-static void
-recordConditions(const CallSite &CS, BasicBlock *Pred,
-                 SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
+static void recordConditions(CallSite CS, BasicBlock *Pred,
+                             ConditionsTy &Conditions) {
   recordCondition(CS, Pred, CS.getInstruction()->getParent(), Conditions);
   BasicBlock *From = Pred;
   BasicBlock *To = Pred;
-  SmallPtrSet<BasicBlock *, 4> Visited = {From};
+  SmallPtrSet<BasicBlock *, 4> Visited;
   while (!Visited.count(From->getSinglePredecessor()) &&
          (From = From->getSinglePredecessor())) {
     recordCondition(CS, From, To, Conditions);
+    Visited.insert(From);
     To = From;
   }
 }
 
-static Instruction *
-addConditions(CallSite &CS,
-              SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
-  if (Conditions.empty())
-    return nullptr;
-
-  Instruction *NewCI = CS.getInstruction()->clone();
+static void addConditions(CallSite CS, const ConditionsTy &Conditions) {
   for (auto &Cond : Conditions) {
     Value *Arg = Cond.first->getOperand(0);
     Constant *ConstVal = cast<Constant>(Cond.first->getOperand(1));
     if (Cond.second == ICmpInst::ICMP_EQ)
-      setConstantInArgument(CS.getInstruction(), NewCI, Arg, ConstVal);
+      setConstantInArgument(CS, Arg, ConstVal);
    else if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) {
       assert(Cond.second == ICmpInst::ICMP_NE);
-      addNonNullAttribute(CS.getInstruction(), NewCI, Arg);
+      addNonNullAttribute(CS, Arg);
     }
   }
-  return NewCI;
 }
 
 static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) {
@@ -176,38 +183,90 @@ static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) {
   return Preds;
 }
 
-static bool canSplitCallSite(CallSite CS) {
+static bool canSplitCallSite(CallSite CS, TargetTransformInfo &TTI) {
   // FIXME: As of now we handle only CallInst. InvokeInst could be handled
   // without too much effort.
   Instruction *Instr = CS.getInstruction();
   if (!isa<CallInst>(Instr))
     return false;
 
-  // Allow splitting a call-site only when there is no instruction before the
-  // call-site in the basic block. Based on this constraint, we only clone the
-  // call instruction, and we do not move a call-site across any other
-  // instruction.
   BasicBlock *CallSiteBB = Instr->getParent();
-  if (Instr != CallSiteBB->getFirstNonPHIOrDbg())
-    return false;
-
   // Need 2 predecessors and cannot split an edge from an IndirectBrInst.
   SmallVector<BasicBlock *, 2> Preds(predecessors(CallSiteBB));
   if (Preds.size() != 2 || isa<IndirectBrInst>(Preds[0]->getTerminator()) ||
       isa<IndirectBrInst>(Preds[1]->getTerminator()))
     return false;
 
-  return CallSiteBB->canSplitPredecessors();
+  // BasicBlock::canSplitPredecessors is more aggressive, so checking for
+  // BasicBlock::isEHPad as well.
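recordConditions in the hunk above collects at most one condition per edge while climbing through blocks that each have a unique predecessor; the Visited set exists purely as a cycle guard, and this change fixes it to start empty and record blocks as they are processed instead of pre-seeding it with the starting block. A rough standalone model of that walk, using a hypothetical Block type in place of BasicBlock and a callback in place of recordCondition:

#include <functional>
#include <set>

// Hypothetical stand-in for BasicBlock: SinglePred is the unique
// predecessor, or nullptr when there are zero or several predecessors.
struct Block {
  Block *SinglePred = nullptr;
};

// Climb From's chain of unique predecessors, reporting each traversed
// edge once, and stop when the chain ends or would revisit a block.
static void walkSinglePredChain(
    Block *From, const std::function<void(Block *, Block *)> &RecordEdge) {
  std::set<Block *> Visited;
  Block *To = From;
  while (!Visited.count(From->SinglePred) && (From = From->SinglePred)) {
    RecordEdge(From, To); // conditions on the From -> To edge
    Visited.insert(From);
    To = From;
  }
}

int main() {
  Block A, B, C;
  C.SinglePred = &B;
  B.SinglePred = &A;
  A.SinglePred = &C; // a cycle; the Visited set terminates the walk
  int Edges = 0;
  walkSinglePredChain(&C, [&](Block *, Block *) { ++Edges; });
  return Edges == 3 ? 0 : 1; // B->C, A->B, C->A, then the walk stops
}

Without the guard, the assignment inside the loop condition would spin around the cycle forever, which is exactly what the SmallPtrSet prevents in the pass.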
+  if (!CallSiteBB->canSplitPredecessors() || CallSiteBB->isEHPad())
+    return false;
+
+  // Allow splitting a call-site only when the CodeSize cost of the
+  // instructions before the call is less than DuplicationThreshold. The
+  // instructions before the call will be duplicated in the split blocks and
+  // corresponding uses will be updated.
+  unsigned Cost = 0;
+  for (auto &InstBeforeCall :
+       llvm::make_range(CallSiteBB->begin(), Instr->getIterator())) {
+    Cost += TTI.getInstructionCost(&InstBeforeCall,
+                                   TargetTransformInfo::TCK_CodeSize);
+    if (Cost >= DuplicationThreshold)
+      return false;
+  }
+
+  return true;
+}
+
+static Instruction *cloneInstForMustTail(Instruction *I, Instruction *Before,
+                                         Value *V) {
+  Instruction *Copy = I->clone();
+  Copy->setName(I->getName());
+  Copy->insertBefore(Before);
+  if (V)
+    Copy->setOperand(0, V);
+  return Copy;
+}
 
-/// Return true if the CS is split into its new predecessors which are directly
-/// hooked to each of its original predecessors pointed by PredBB1 and PredBB2.
-/// CallInst1 and CallInst2 will be the new call-sites placed in the new
-/// predecessors split for PredBB1 and PredBB2, respectively.
+/// Copy mandatory `musttail` return sequence that follows original `CI`, and
+/// link it up to `NewCI` value instead:
+///
+///   * (optional) `bitcast NewCI to ...`
+///   * `ret bitcast or NewCI`
+///
+/// Insert this sequence right before `SplitBB`'s terminator, which will be
+/// cleaned up later in `splitCallSite` below.
+static void copyMustTailReturn(BasicBlock *SplitBB, Instruction *CI,
+                               Instruction *NewCI) {
+  bool IsVoid = SplitBB->getParent()->getReturnType()->isVoidTy();
+  auto II = std::next(CI->getIterator());
+
+  BitCastInst* BCI = dyn_cast<BitCastInst>(&*II);
+  if (BCI)
+    ++II;
+
+  ReturnInst* RI = dyn_cast<ReturnInst>(&*II);
+  assert(RI && "`musttail` call must be followed by `ret` instruction");
+
+  TerminatorInst *TI = SplitBB->getTerminator();
+  Value *V = NewCI;
+  if (BCI)
+    V = cloneInstForMustTail(BCI, TI, V);
+  cloneInstForMustTail(RI, TI, IsVoid ? nullptr : V);
+
+  // FIXME: remove TI here, `DuplicateInstructionsInSplitBetween` has a bug
+  // that prevents doing this now.
+}
+
+/// For each (predecessor, conditions from predecessors) pair, it will split the
+/// basic block containing the call site, hook it up to the predecessor and
+/// replace the call instruction with new call instructions, which contain
+/// constraints based on the conditions from their predecessors.
 /// For example, in the IR below with an OR condition, the call-site can
-/// be split. Assuming PredBB1=Header and PredBB2=TBB, CallInst1 will be the
-/// call-site placed between Header and Tail, and CallInst2 will be the
-/// call-site between TBB and Tail.
+/// be split. In this case, Preds for Tail is [(Header, a == null),
+/// (TBB, a != null, b == null)]. Tail is replaced by 2 split blocks, containing
+/// CallInst1, which has constraints based on the conditions from Header and
+/// CallInst2, which has constraints based on the conditions coming from TBB.
 ///
 /// From :
 ///
@@ -240,60 +299,112 @@ static bool canSplitCallSite(CallSite CS) {
 /// Note that in case any arguments at the call-site are constrained by its
 /// predecessors, new call-sites with more constrained arguments will be
 /// created in createCallSitesOnPredicatedArgument().
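The DuplicationThreshold gate in canSplitCallSite above is a plain running-sum check: every instruction sitting before the call must be duplicated into each split block, so their combined code-size cost is accumulated and the split is refused as soon as the budget is reached. A standalone sketch of that gating logic, with an invented per-instruction cost standing in for TTI.getInstructionCost with TCK_CodeSize:

#include <vector>

// Hypothetical instruction record; Cost models a code-size estimate.
struct InstInfo {
  unsigned Cost;
};

// Sum the cost of the instructions that would have to be duplicated and
// bail out as soon as the budget is exhausted, as canSplitCallSite does.
static bool fitsDuplicationBudget(const std::vector<InstInfo> &BeforeCall,
                                  unsigned Threshold) {
  unsigned Cost = 0;
  for (const InstInfo &I : BeforeCall) {
    Cost += I.Cost;
    if (Cost >= Threshold)
      return false;
  }
  return true;
}

int main() {
  std::vector<InstInfo> Cheap = {{1}, {1}, {1}};
  std::vector<InstInfo> Expensive = {{4}, {2}};
  bool A = fitsDuplicationBudget(Cheap, 5);     // total 3 < 5: split allowed
  bool B = fitsDuplicationBudget(Expensive, 5); // hits 6 >= 5: refused
  return (A && !B) ? 0 : 1;
}

Note the early exit: with the default threshold of 5 the walk never looks at more than a handful of instructions, which keeps the check cheap even in large blocks.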
-static void splitCallSite(CallSite CS, BasicBlock *PredBB1, BasicBlock *PredBB2,
-                          Instruction *CallInst1, Instruction *CallInst2) {
+static void splitCallSite(
+    CallSite CS,
+    const SmallVectorImpl<std::pair<BasicBlock *, ConditionsTy>> &Preds,
+    DominatorTree *DT) {
   Instruction *Instr = CS.getInstruction();
   BasicBlock *TailBB = Instr->getParent();
-  assert(Instr == (TailBB->getFirstNonPHIOrDbg()) && "Unexpected call-site");
-
-  BasicBlock *SplitBlock1 =
-      SplitBlockPredecessors(TailBB, PredBB1, ".predBB1.split");
-  BasicBlock *SplitBlock2 =
-      SplitBlockPredecessors(TailBB, PredBB2, ".predBB2.split");
-
-  assert((SplitBlock1 && SplitBlock2) && "Unexpected new basic block split.");
-
-  if (!CallInst1)
-    CallInst1 = Instr->clone();
-  if (!CallInst2)
-    CallInst2 = Instr->clone();
-
-  CallInst1->insertBefore(&*SplitBlock1->getFirstInsertionPt());
-  CallInst2->insertBefore(&*SplitBlock2->getFirstInsertionPt());
-
-  CallSite CS1(CallInst1);
-  CallSite CS2(CallInst2);
-
-  // Handle PHIs used as arguments in the call-site.
-  for (auto &PI : *TailBB) {
-    PHINode *PN = dyn_cast<PHINode>(&PI);
-    if (!PN)
-      break;
-    unsigned ArgNo = 0;
-    for (auto &CI : CS.args()) {
-      if (&*CI == PN) {
-        CS1.setArgument(ArgNo, PN->getIncomingValueForBlock(SplitBlock1));
-        CS2.setArgument(ArgNo, PN->getIncomingValueForBlock(SplitBlock2));
+  bool IsMustTailCall = CS.isMustTailCall();
+
+  PHINode *CallPN = nullptr;
+
+  // `musttail` calls must be followed by an optional `bitcast` and a `ret`.
+  // The split blocks will be terminated right after that, so there are no
+  // users for this phi in `TailBB`.
+  if (!IsMustTailCall && !Instr->use_empty())
+    CallPN = PHINode::Create(Instr->getType(), Preds.size(), "phi.call");
+
+  LLVM_DEBUG(dbgs() << "split call-site : " << *Instr << " into \n");
+
+  assert(Preds.size() == 2 && "The ValueToValueMaps array has size 2.");
+  // ValueToValueMapTy is neither copy nor moveable, so we use a simple array
+  // here.
+  ValueToValueMapTy ValueToValueMaps[2];
+  for (unsigned i = 0; i < Preds.size(); i++) {
+    BasicBlock *PredBB = Preds[i].first;
+    BasicBlock *SplitBlock = DuplicateInstructionsInSplitBetween(
+        TailBB, PredBB, &*std::next(Instr->getIterator()), ValueToValueMaps[i],
+        DT);
+    assert(SplitBlock && "Unexpected new basic block split.");
+
+    Instruction *NewCI =
+        &*std::prev(SplitBlock->getTerminator()->getIterator());
+    CallSite NewCS(NewCI);
+    addConditions(NewCS, Preds[i].second);
+
+    // Handle PHIs used as arguments in the call-site.
+    for (PHINode &PN : TailBB->phis()) {
+      unsigned ArgNo = 0;
+      for (auto &CI : CS.args()) {
+        if (&*CI == &PN) {
+          NewCS.setArgument(ArgNo, PN.getIncomingValueForBlock(SplitBlock));
+        }
+        ++ArgNo;
       }
-      ++ArgNo;
     }
+    LLVM_DEBUG(dbgs() << "    " << *NewCI << " in " << SplitBlock->getName()
+                      << "\n");
+    if (CallPN)
+      CallPN->addIncoming(NewCI, SplitBlock);
+
+    // Clone and place bitcast and return instructions before `TI`
+    if (IsMustTailCall)
+      copyMustTailReturn(SplitBlock, Instr, NewCI);
+  }
+
+  NumCallSiteSplit++;
+
+  // FIXME: remove TI in `copyMustTailReturn`
+  if (IsMustTailCall) {
+    // Remove superfluous `br` terminators from the end of the Split blocks
+    // NOTE: Removing the terminator removes the SplitBlock from the TailBB's
+    // predecessors. Therefore we must get the complete list of Splits before
+    // attempting removal.
+    SmallVector<BasicBlock *, 2> Splits(predecessors((TailBB)));
+    assert(Splits.size() == 2 && "Expected exactly 2 splits!");
+    for (unsigned i = 0; i < Splits.size(); i++)
+      Splits[i]->getTerminator()->eraseFromParent();
+
+    // Erase the tail block once done with musttail patching
+    TailBB->eraseFromParent();
+    return;
+  }
+
+  auto *OriginalBegin = &*TailBB->begin();
   // Replace users of the original call with a PHI merging call-sites split.
-  if (Instr->getNumUses()) {
-    PHINode *PN = PHINode::Create(Instr->getType(), 2, "phi.call",
-                                  TailBB->getFirstNonPHI());
-    PN->addIncoming(CallInst1, SplitBlock1);
-    PN->addIncoming(CallInst2, SplitBlock2);
-    Instr->replaceAllUsesWith(PN);
+  if (CallPN) {
+    CallPN->insertBefore(OriginalBegin);
+    Instr->replaceAllUsesWith(CallPN);
+  }
+
+  // Remove instructions moved to split blocks from TailBB, from the duplicated
+  // call instruction to the beginning of the basic block. If an instruction
+  // has any uses, add a new PHI node to combine the values coming from the
+  // split blocks. The new PHI nodes are placed before the first original
+  // instruction, so we do not end up deleting them. By using reverse-order, we
+  // do not introduce unnecessary PHI nodes for def-use chains from the call
+  // instruction to the beginning of the block.
+  auto I = Instr->getReverseIterator();
+  while (I != TailBB->rend()) {
+    Instruction *CurrentI = &*I++;
+    if (!CurrentI->use_empty()) {
+      // If an existing PHI has users after the call, there is no need to create
+      // a new one.
+      if (isa<PHINode>(CurrentI))
+        continue;
+      PHINode *NewPN = PHINode::Create(CurrentI->getType(), Preds.size());
+      for (auto &Mapping : ValueToValueMaps)
+        NewPN->addIncoming(Mapping[CurrentI],
+                           cast<Instruction>(Mapping[CurrentI])->getParent());
+      NewPN->insertBefore(&*TailBB->begin());
+      CurrentI->replaceAllUsesWith(NewPN);
+    }
+    CurrentI->eraseFromParent();
+    // We are done once we handled the first original instruction in TailBB.
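The reverse-order walk above is what keeps the PHI count minimal: deleting from the call backwards means an instruction whose only consumers sat between it and the call has already lost those uses by the time it is visited, so no merge PHI is created for it. A toy model of that ordering argument, with a hypothetical ToyInst standing in for Instruction and index sets standing in for LLVM's use lists:

#include <set>
#include <vector>

// Toy block: instruction I's Users holds the indices of later instructions
// in the same region that consume its value; External marks values that
// are still needed after the region being deleted.
struct ToyInst {
  std::set<int> Users;
  bool External = false;
};

// Delete instructions [0, Call] in reverse order, counting how many still
// have a live user at deletion time (those are the ones needing a PHI).
static int phisNeeded(std::vector<ToyInst> &Block, int Call) {
  int Phis = 0;
  for (int I = Call; I >= 0; --I) {
    if (!Block[I].Users.empty() || Block[I].External)
      ++Phis; // value still consumed, must be merged via a PHI
    for (ToyInst &Other : Block)
      Other.Users.erase(I); // deleting I drops it as a user of its operands
  }
  return Phis;
}

int main() {
  // %a feeds %b feeds the call; nothing is used after the call.
  std::vector<ToyInst> Block(3);
  Block[0].Users = {1}; // %a used only by %b
  Block[1].Users = {2}; // %b used only by the call
  return phisNeeded(Block, 2) == 0 ? 0 : 1; // reverse order: zero PHIs
}

Running the same chain front to back would see %a and %b still referenced when visited and create a PHI for each, which is precisely what the reverse traversal avoids.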
+ if (CurrentI == OriginalBegin) + break; } - DEBUG(dbgs() << "split call-site : " << *Instr << " into \n"); - DEBUG(dbgs() << " " << *CallInst1 << " in " << SplitBlock1->getName() - << "\n"); - DEBUG(dbgs() << " " << *CallInst2 << " in " << SplitBlock2->getName() - << "\n"); - Instr->eraseFromParent(); - NumCallSiteSplit++; } // Return true if the call-site has an argument which is a PHI with only @@ -324,45 +435,59 @@ static bool isPredicatedOnPHI(CallSite CS) { return false; } -static bool tryToSplitOnPHIPredicatedArgument(CallSite CS) { +static bool tryToSplitOnPHIPredicatedArgument(CallSite CS, DominatorTree *DT) { if (!isPredicatedOnPHI(CS)) return false; auto Preds = getTwoPredecessors(CS.getInstruction()->getParent()); - splitCallSite(CS, Preds[0], Preds[1], nullptr, nullptr); + SmallVector<std::pair<BasicBlock *, ConditionsTy>, 2> PredsCS = { + {Preds[0], {}}, {Preds[1], {}}}; + splitCallSite(CS, PredsCS, DT); return true; } -static bool tryToSplitOnPredicatedArgument(CallSite CS) { +static bool tryToSplitOnPredicatedArgument(CallSite CS, DominatorTree *DT) { auto Preds = getTwoPredecessors(CS.getInstruction()->getParent()); if (Preds[0] == Preds[1]) return false; - SmallVector<std::pair<ICmpInst *, unsigned>, 2> C1, C2; - recordConditions(CS, Preds[0], C1); - recordConditions(CS, Preds[1], C2); + SmallVector<std::pair<BasicBlock *, ConditionsTy>, 2> PredsCS; + for (auto *Pred : make_range(Preds.rbegin(), Preds.rend())) { + ConditionsTy Conditions; + recordConditions(CS, Pred, Conditions); + PredsCS.push_back({Pred, Conditions}); + } - Instruction *CallInst1 = addConditions(CS, C1); - Instruction *CallInst2 = addConditions(CS, C2); - if (!CallInst1 && !CallInst2) + if (std::all_of(PredsCS.begin(), PredsCS.end(), + [](const std::pair<BasicBlock *, ConditionsTy> &P) { + return P.second.empty(); + })) return false; - splitCallSite(CS, Preds[1], Preds[0], CallInst2, CallInst1); + splitCallSite(CS, PredsCS, DT); return true; } -static bool tryToSplitCallSite(CallSite CS) { - if (!CS.arg_size() || !canSplitCallSite(CS)) +static bool tryToSplitCallSite(CallSite CS, TargetTransformInfo &TTI, + DominatorTree *DT) { + if (!CS.arg_size() || !canSplitCallSite(CS, TTI)) return false; - return tryToSplitOnPredicatedArgument(CS) || - tryToSplitOnPHIPredicatedArgument(CS); + return tryToSplitOnPredicatedArgument(CS, DT) || + tryToSplitOnPHIPredicatedArgument(CS, DT); } -static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI) { +static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI, + TargetTransformInfo &TTI, DominatorTree *DT) { bool Changed = false; for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE;) { BasicBlock &BB = *BI++; - for (BasicBlock::iterator II = BB.begin(), IE = BB.end(); II != IE;) { + auto II = BB.getFirstNonPHIOrDbg()->getIterator(); + auto IE = BB.getTerminator()->getIterator(); + // Iterate until we reach the terminator instruction. tryToSplitCallSite + // can replace BB's terminator in case BB is a successor of itself. In that + // case, IE will be invalidated and we also have to check the current + // terminator. 
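Spelled out generically, the loop that follows relies on two habits for mutating a list while walking it: advance the iterator before the visit so the current element may be invalidated behind the cursor, and recompute the stopping element on every step because a successful split can replace the block's terminator. A sketch of the same pattern over a std::list, with a hypothetical Visit callback standing in for tryToSplitCallSite:

#include <functional>
#include <list>

// Visit every element except the last one (the stand-in for the block
// terminator). The iterator is advanced before Visit runs, and the last
// element is re-read each step in case Visit replaced it.
template <typename T>
static void visitUntilTerminator(std::list<T> &Items,
                                 const std::function<void(T &)> &Visit) {
  if (Items.empty())
    return;
  auto II = Items.begin();
  while (II != Items.end() && &*II != &Items.back()) {
    T &Cur = *II++; // advance first; Visit may erase elements behind us
    Visit(Cur);
  }
}

int main() {
  std::list<int> L = {1, 2, 3, 4};
  int Sum = 0;
  visitUntilTerminator<int>(L, [&](int &V) { Sum += V; });
  return Sum == 6 ? 0 : 1; // visits 1, 2, 3; the "terminator" 4 is skipped
}

The pass needs the re-check because a block that is its own successor gets a brand new terminator when it is split, invalidating any end iterator captured before the loop.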
+    while (II != IE && &*II != BB.getTerminator()) {
       Instruction *I = &*II++;
       CallSite CS(cast<Value>(I));
       if (!CS || isa<IntrinsicInst>(I) || isInstructionTriviallyDead(I, &TLI))
@@ -371,7 +496,17 @@ static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI) {
       continue;
       Function *Callee = CS.getCalledFunction();
       if (!Callee || Callee->isDeclaration())
         continue;
-      Changed |= tryToSplitCallSite(CS);
+
+      // Successful musttail call-site splits result in erased CI and erased BB.
+      // Check if such a path is possible before attempting the splitting.
+      bool IsMustTail = CS.isMustTailCall();
+
+      Changed |= tryToSplitCallSite(CS, TTI, DT);
+
+      // There are no interesting instructions after this. The call site
+      // itself might have been erased on splitting.
+      if (IsMustTail)
+        break;
     }
   }
   return Changed;
@@ -386,6 +521,8 @@ struct CallSiteSplittingLegacyPass : public FunctionPass {
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.addPreserved<DominatorTreeWrapperPass>();
     FunctionPass::getAnalysisUsage(AU);
   }
 
@@ -394,7 +531,10 @@ struct CallSiteSplittingLegacyPass : public FunctionPass {
       return false;
 
     auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
-    return doCallSiteSplitting(F, TLI);
+    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+    auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
+    return doCallSiteSplitting(F, TLI, TTI,
+                               DTWP ? &DTWP->getDomTree() : nullptr);
   }
 };
 } // namespace
@@ -403,6 +543,7 @@ char CallSiteSplittingLegacyPass::ID = 0;
 INITIALIZE_PASS_BEGIN(CallSiteSplittingLegacyPass, "callsite-splitting",
                       "Call-site splitting", false, false)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_END(CallSiteSplittingLegacyPass, "callsite-splitting",
                     "Call-site splitting", false, false)
 FunctionPass *llvm::createCallSiteSplittingPass() {
@@ -412,9 +553,12 @@ FunctionPass *llvm::createCallSiteSplittingPass() {
 PreservedAnalyses CallSiteSplittingPass::run(Function &F,
                                              FunctionAnalysisManager &AM) {
   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
 
-  if (!doCallSiteSplitting(F, TLI))
+  if (!doCallSiteSplitting(F, TLI, TTI, DT))
     return PreservedAnalyses::all();
   PreservedAnalyses PA;
+  PA.preserve<DominatorTreeAnalysis>();
   return PA;
 }
diff --git a/lib/Transforms/Scalar/ConstantHoisting.cpp b/lib/Transforms/Scalar/ConstantHoisting.cpp
index e4b08c5ed305..3a675b979017 100644
--- a/lib/Transforms/Scalar/ConstantHoisting.cpp
+++ b/lib/Transforms/Scalar/ConstantHoisting.cpp
@@ -43,8 +43,10 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Transforms/Utils/Local.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/InstrTypes.h"
@@ -59,8 +61,6 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
-#include "llvm/Transforms/Utils/Local.h"
-#include "llvm/IR/DebugInfoMetadata.h"
 #include <algorithm>
 #include <cassert>
 #include <cstdint>
@@ -84,7 +84,7 @@ static cl::opt<bool> ConstHoistWithBlockFrequency(
 
 namespace {
 
-/// \brief The constant hoisting pass.
+/// The constant hoisting pass. class ConstantHoistingLegacyPass : public FunctionPass { public: static char ID; // Pass identification, replacement for typeid @@ -127,13 +127,13 @@ FunctionPass *llvm::createConstantHoistingPass() { return new ConstantHoistingLegacyPass(); } -/// \brief Perform the constant hoisting optimization for the given function. +/// Perform the constant hoisting optimization for the given function. bool ConstantHoistingLegacyPass::runOnFunction(Function &Fn) { if (skipFunction(Fn)) return false; - DEBUG(dbgs() << "********** Begin Constant Hoisting **********\n"); - DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n'); + LLVM_DEBUG(dbgs() << "********** Begin Constant Hoisting **********\n"); + LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n'); bool MadeChange = Impl.runImpl(Fn, getAnalysis<TargetTransformInfoWrapperPass>().getTTI(Fn), @@ -144,16 +144,16 @@ bool ConstantHoistingLegacyPass::runOnFunction(Function &Fn) { Fn.getEntryBlock()); if (MadeChange) { - DEBUG(dbgs() << "********** Function after Constant Hoisting: " - << Fn.getName() << '\n'); - DEBUG(dbgs() << Fn); + LLVM_DEBUG(dbgs() << "********** Function after Constant Hoisting: " + << Fn.getName() << '\n'); + LLVM_DEBUG(dbgs() << Fn); } - DEBUG(dbgs() << "********** End Constant Hoisting **********\n"); + LLVM_DEBUG(dbgs() << "********** End Constant Hoisting **********\n"); return MadeChange; } -/// \brief Find the constant materialization insertion point. +/// Find the constant materialization insertion point. Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst, unsigned Idx) const { // If the operand is a cast instruction, then we have to materialize the @@ -187,7 +187,7 @@ Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst, return IDom->getBlock()->getTerminator(); } -/// \brief Given \p BBs as input, find another set of BBs which collectively +/// Given \p BBs as input, find another set of BBs which collectively /// dominates \p BBs and have the minimal sum of frequencies. Return the BB /// set found in \p BBs. static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, @@ -289,7 +289,7 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, } } -/// \brief Find an insertion point that dominates all uses. +/// Find an insertion point that dominates all uses. SmallPtrSet<Instruction *, 8> ConstantHoistingPass::findConstantInsertionPoint( const ConstantInfo &ConstInfo) const { assert(!ConstInfo.RebasedConstants.empty() && "Invalid constant info entry."); @@ -335,7 +335,7 @@ SmallPtrSet<Instruction *, 8> ConstantHoistingPass::findConstantInsertionPoint( return InsertPts; } -/// \brief Record constant integer ConstInt for instruction Inst at operand +/// Record constant integer ConstInt for instruction Inst at operand /// index Idx. /// /// The operand at index Idx is not necessarily the constant integer itself. 
It @@ -364,18 +364,17 @@ void ConstantHoistingPass::collectConstantCandidates( Itr->second = ConstCandVec.size() - 1; } ConstCandVec[Itr->second].addUser(Inst, Idx, Cost); - DEBUG(if (isa<ConstantInt>(Inst->getOperand(Idx))) - dbgs() << "Collect constant " << *ConstInt << " from " << *Inst + LLVM_DEBUG(if (isa<ConstantInt>(Inst->getOperand(Idx))) dbgs() + << "Collect constant " << *ConstInt << " from " << *Inst << " with cost " << Cost << '\n'; - else - dbgs() << "Collect constant " << *ConstInt << " indirectly from " - << *Inst << " via " << *Inst->getOperand(Idx) << " with cost " - << Cost << '\n'; - ); + else dbgs() << "Collect constant " << *ConstInt + << " indirectly from " << *Inst << " via " + << *Inst->getOperand(Idx) << " with cost " << Cost + << '\n';); } } -/// \brief Check the operand for instruction Inst at index Idx. +/// Check the operand for instruction Inst at index Idx. void ConstantHoistingPass::collectConstantCandidates( ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx) { Value *Opnd = Inst->getOperand(Idx); @@ -416,7 +415,7 @@ void ConstantHoistingPass::collectConstantCandidates( } } -/// \brief Scan the instruction for expensive integer constants and record them +/// Scan the instruction for expensive integer constants and record them /// in the constant candidate vector. void ConstantHoistingPass::collectConstantCandidates( ConstCandMapType &ConstCandMap, Instruction *Inst) { @@ -436,7 +435,7 @@ void ConstantHoistingPass::collectConstantCandidates( } // end of for all operands } -/// \brief Collect all integer constants in the function that cannot be folded +/// Collect all integer constants in the function that cannot be folded /// into an instruction itself. void ConstantHoistingPass::collectConstantCandidates(Function &Fn) { ConstCandMapType ConstCandMap; @@ -501,20 +500,21 @@ ConstantHoistingPass::maximizeConstantsInRange(ConstCandVecType::iterator S, return NumUses; } - DEBUG(dbgs() << "== Maximize constants in range ==\n"); + LLVM_DEBUG(dbgs() << "== Maximize constants in range ==\n"); int MaxCost = -1; for (auto ConstCand = S; ConstCand != E; ++ConstCand) { auto Value = ConstCand->ConstInt->getValue(); Type *Ty = ConstCand->ConstInt->getType(); int Cost = 0; NumUses += ConstCand->Uses.size(); - DEBUG(dbgs() << "= Constant: " << ConstCand->ConstInt->getValue() << "\n"); + LLVM_DEBUG(dbgs() << "= Constant: " << ConstCand->ConstInt->getValue() + << "\n"); for (auto User : ConstCand->Uses) { unsigned Opcode = User.Inst->getOpcode(); unsigned OpndIdx = User.OpndIdx; Cost += TTI->getIntImmCost(Opcode, OpndIdx, Value, Ty); - DEBUG(dbgs() << "Cost: " << Cost << "\n"); + LLVM_DEBUG(dbgs() << "Cost: " << Cost << "\n"); for (auto C2 = S; C2 != E; ++C2) { Optional<APInt> Diff = calculateOffsetDiff( @@ -524,24 +524,24 @@ ConstantHoistingPass::maximizeConstantsInRange(ConstCandVecType::iterator S, const int ImmCosts = TTI->getIntImmCodeSizeCost(Opcode, OpndIdx, Diff.getValue(), Ty); Cost -= ImmCosts; - DEBUG(dbgs() << "Offset " << Diff.getValue() << " " - << "has penalty: " << ImmCosts << "\n" - << "Adjusted cost: " << Cost << "\n"); + LLVM_DEBUG(dbgs() << "Offset " << Diff.getValue() << " " + << "has penalty: " << ImmCosts << "\n" + << "Adjusted cost: " << Cost << "\n"); } } } - DEBUG(dbgs() << "Cumulative cost: " << Cost << "\n"); + LLVM_DEBUG(dbgs() << "Cumulative cost: " << Cost << "\n"); if (Cost > MaxCost) { MaxCost = Cost; MaxCostItr = ConstCand; - DEBUG(dbgs() << "New candidate: " << MaxCostItr->ConstInt->getValue() - << "\n"); + LLVM_DEBUG(dbgs() << 
"New candidate: " << MaxCostItr->ConstInt->getValue() + << "\n"); } } return NumUses; } -/// \brief Find the base constant within the given range and rebase all other +/// Find the base constant within the given range and rebase all other /// constants with respect to the base constant. void ConstantHoistingPass::findAndMakeBaseConstant( ConstCandVecType::iterator S, ConstCandVecType::iterator E) { @@ -567,12 +567,12 @@ void ConstantHoistingPass::findAndMakeBaseConstant( ConstantVec.push_back(std::move(ConstInfo)); } -/// \brief Finds and combines constant candidates that can be easily +/// Finds and combines constant candidates that can be easily /// rematerialized with an add from a common base constant. void ConstantHoistingPass::findBaseConstants() { // Sort the constants by value and type. This invalidates the mapping! - std::sort(ConstCandVec.begin(), ConstCandVec.end(), - [](const ConstantCandidate &LHS, const ConstantCandidate &RHS) { + llvm::sort(ConstCandVec.begin(), ConstCandVec.end(), + [](const ConstantCandidate &LHS, const ConstantCandidate &RHS) { if (LHS.ConstInt->getType() != RHS.ConstInt->getType()) return LHS.ConstInt->getType()->getBitWidth() < RHS.ConstInt->getType()->getBitWidth(); @@ -601,7 +601,7 @@ void ConstantHoistingPass::findBaseConstants() { findAndMakeBaseConstant(MinValItr, ConstCandVec.end()); } -/// \brief Updates the operand at Idx in instruction Inst with the result of +/// Updates the operand at Idx in instruction Inst with the result of /// instruction Mat. If the instruction is a PHI node then special /// handling for duplicate values form the same incoming basic block is /// required. @@ -629,7 +629,7 @@ static bool updateOperand(Instruction *Inst, unsigned Idx, Instruction *Mat) { return true; } -/// \brief Emit materialization code for all rebased constants and update their +/// Emit materialization code for all rebased constants and update their /// users. void ConstantHoistingPass::emitBaseConstants(Instruction *Base, Constant *Offset, @@ -641,19 +641,20 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base, Mat = BinaryOperator::Create(Instruction::Add, Base, Offset, "const_mat", InsertionPt); - DEBUG(dbgs() << "Materialize constant (" << *Base->getOperand(0) - << " + " << *Offset << ") in BB " - << Mat->getParent()->getName() << '\n' << *Mat << '\n'); + LLVM_DEBUG(dbgs() << "Materialize constant (" << *Base->getOperand(0) + << " + " << *Offset << ") in BB " + << Mat->getParent()->getName() << '\n' + << *Mat << '\n'); Mat->setDebugLoc(ConstUser.Inst->getDebugLoc()); } Value *Opnd = ConstUser.Inst->getOperand(ConstUser.OpndIdx); // Visit constant integer. if (isa<ConstantInt>(Opnd)) { - DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n'); + LLVM_DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n'); if (!updateOperand(ConstUser.Inst, ConstUser.OpndIdx, Mat) && Offset) Mat->eraseFromParent(); - DEBUG(dbgs() << "To : " << *ConstUser.Inst << '\n'); + LLVM_DEBUG(dbgs() << "To : " << *ConstUser.Inst << '\n'); return; } @@ -669,13 +670,13 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base, ClonedCastInst->insertAfter(CastInst); // Use the same debug location as the original cast instruction. 
ClonedCastInst->setDebugLoc(CastInst->getDebugLoc()); - DEBUG(dbgs() << "Clone instruction: " << *CastInst << '\n' - << "To : " << *ClonedCastInst << '\n'); + LLVM_DEBUG(dbgs() << "Clone instruction: " << *CastInst << '\n' + << "To : " << *ClonedCastInst << '\n'); } - DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n'); + LLVM_DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n'); updateOperand(ConstUser.Inst, ConstUser.OpndIdx, ClonedCastInst); - DEBUG(dbgs() << "To : " << *ConstUser.Inst << '\n'); + LLVM_DEBUG(dbgs() << "To : " << *ConstUser.Inst << '\n'); return; } @@ -689,20 +690,20 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base, // Use the same debug location as the instruction we are about to update. ConstExprInst->setDebugLoc(ConstUser.Inst->getDebugLoc()); - DEBUG(dbgs() << "Create instruction: " << *ConstExprInst << '\n' - << "From : " << *ConstExpr << '\n'); - DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n'); + LLVM_DEBUG(dbgs() << "Create instruction: " << *ConstExprInst << '\n' + << "From : " << *ConstExpr << '\n'); + LLVM_DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n'); if (!updateOperand(ConstUser.Inst, ConstUser.OpndIdx, ConstExprInst)) { ConstExprInst->eraseFromParent(); if (Offset) Mat->eraseFromParent(); } - DEBUG(dbgs() << "To : " << *ConstUser.Inst << '\n'); + LLVM_DEBUG(dbgs() << "To : " << *ConstUser.Inst << '\n'); return; } } -/// \brief Hoist and hide the base constant behind a bitcast and emit +/// Hoist and hide the base constant behind a bitcast and emit /// materialization code for derived constants. bool ConstantHoistingPass::emitBaseConstants() { bool MadeChange = false; @@ -720,9 +721,9 @@ bool ConstantHoistingPass::emitBaseConstants() { Base->setDebugLoc(IP->getDebugLoc()); - DEBUG(dbgs() << "Hoist constant (" << *ConstInfo.BaseConstant - << ") to BB " << IP->getParent()->getName() << '\n' - << *Base << '\n'); + LLVM_DEBUG(dbgs() << "Hoist constant (" << *ConstInfo.BaseConstant + << ") to BB " << IP->getParent()->getName() << '\n' + << *Base << '\n'); // Emit materialization code for all rebased constants. unsigned Uses = 0; @@ -765,7 +766,7 @@ bool ConstantHoistingPass::emitBaseConstants() { return MadeChange; } -/// \brief Check all cast instructions we made a copy of and remove them if they +/// Check all cast instructions we made a copy of and remove them if they /// have no more users. void ConstantHoistingPass::deleteDeadCastInst() const { for (auto const &I : ClonedCastMap) @@ -773,7 +774,7 @@ void ConstantHoistingPass::deleteDeadCastInst() const { I.first->eraseFromParent(); } -/// \brief Optimize expensive integer constants in the given function. +/// Optimize expensive integer constants in the given function. 
+/// Optimize expensive integer constants in the given function.
bool ConstantHoistingPass::runImpl(Function &Fn, TargetTransformInfo &TTI, DominatorTree &DT, BlockFrequencyInfo *BFI, BasicBlock &Entry) { diff --git a/lib/Transforms/Scalar/ConstantProp.cpp b/lib/Transforms/Scalar/ConstantProp.cpp index 4fa27891a974..46915889ce7c 100644 --- a/lib/Transforms/Scalar/ConstantProp.cpp +++ b/lib/Transforms/Scalar/ConstantProp.cpp @@ -21,12 +21,12 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Constant.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include <set> using namespace llvm; diff --git a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 8f468ebf8949..ea148b728a10 100644 --- a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -19,6 +19,7 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyValueInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -28,11 +29,11 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" @@ -43,7 +44,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <utility> @@ -52,12 +52,14 @@ using namespace llvm; #define DEBUG_TYPE "correlated-value-propagation" STATISTIC(NumPhis, "Number of phis propagated"); +STATISTIC(NumPhiCommon, "Number of phis deleted via common incoming value"); STATISTIC(NumSelects, "Number of selects propagated"); STATISTIC(NumMemAccess, "Number of memory access targets propagated"); STATISTIC(NumCmps, "Number of comparisons propagated"); STATISTIC(NumReturns, "Number of return values propagated"); STATISTIC(NumDeadCases, "Number of switch cases removed"); STATISTIC(NumSDivs, "Number of sdiv converted to udiv"); +STATISTIC(NumUDivs, "Number of udivs whose width was decreased"); STATISTIC(NumAShrs, "Number of ashr converted to lshr"); STATISTIC(NumSRems, "Number of srem converted to urem"); STATISTIC(NumOverflows, "Number of overflow checks removed"); @@ -77,8 +79,10 @@ namespace { bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<LazyValueInfoWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); } }; @@ -88,6 +92,7 @@ char CorrelatedValuePropagation::ID = 0; INITIALIZE_PASS_BEGIN(CorrelatedValuePropagation, "correlated-propagation", "Value Propagation", false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) INITIALIZE_PASS_END(CorrelatedValuePropagation, "correlated-propagation", "Value Propagation", false, false) @@ -101,14 +106,14 @@ static bool processSelect(SelectInst *S, LazyValueInfo *LVI) { if 
(S->getType()->isVectorTy()) return false; if (isa<Constant>(S->getOperand(0))) return false; - Constant *C = LVI->getConstant(S->getOperand(0), S->getParent(), S); + Constant *C = LVI->getConstant(S->getCondition(), S->getParent(), S); if (!C) return false; ConstantInt *CI = dyn_cast<ConstantInt>(C); if (!CI) return false; - Value *ReplaceWith = S->getOperand(1); - Value *Other = S->getOperand(2); + Value *ReplaceWith = S->getTrueValue(); + Value *Other = S->getFalseValue(); if (!CI->isOne()) std::swap(ReplaceWith, Other); if (ReplaceWith == S) ReplaceWith = UndefValue::get(S->getType()); @@ -120,7 +125,63 @@ static bool processSelect(SelectInst *S, LazyValueInfo *LVI) { return true; } -static bool processPHI(PHINode *P, LazyValueInfo *LVI, +/// Try to simplify a phi with constant incoming values that match the edge +/// values of a non-constant value on all other edges: +/// bb0: +/// %isnull = icmp eq i8* %x, null +/// br i1 %isnull, label %bb2, label %bb1 +/// bb1: +/// br label %bb2 +/// bb2: +/// %r = phi i8* [ %x, %bb1 ], [ null, %bb0 ] +/// --> +/// %r = %x +static bool simplifyCommonValuePhi(PHINode *P, LazyValueInfo *LVI, + DominatorTree *DT) { + // Collect incoming constants and initialize possible common value. + SmallVector<std::pair<Constant *, unsigned>, 4> IncomingConstants; + Value *CommonValue = nullptr; + for (unsigned i = 0, e = P->getNumIncomingValues(); i != e; ++i) { + Value *Incoming = P->getIncomingValue(i); + if (auto *IncomingConstant = dyn_cast<Constant>(Incoming)) { + IncomingConstants.push_back(std::make_pair(IncomingConstant, i)); + } else if (!CommonValue) { + // The potential common value is initialized to the first non-constant. + CommonValue = Incoming; + } else if (Incoming != CommonValue) { + // There can be only one non-constant common value. + return false; + } + } + + if (!CommonValue || IncomingConstants.empty()) + return false; + + // The common value must be valid in all incoming blocks. + BasicBlock *ToBB = P->getParent(); + if (auto *CommonInst = dyn_cast<Instruction>(CommonValue)) + if (!DT->dominates(CommonInst, ToBB)) + return false; + + // We have a phi with exactly 1 variable incoming value and 1 or more constant + // incoming values. See if all constant incoming values can be mapped back to + // the same incoming variable value. + for (auto &IncomingConstant : IncomingConstants) { + Constant *C = IncomingConstant.first; + BasicBlock *IncomingBB = P->getIncomingBlock(IncomingConstant.second); + if (C != LVI->getConstantOnEdge(CommonValue, IncomingBB, ToBB, P)) + return false; + } + + // All constant incoming values map to the same variable along the incoming + // edges of the phi. The phi is unnecessary. + P->replaceAllUsesWith(CommonValue); + P->eraseFromParent(); + ++NumPhiCommon; + return true; +} + +static bool processPHI(PHINode *P, LazyValueInfo *LVI, DominatorTree *DT, const SimplifyQuery &SQ) { bool Changed = false; @@ -168,7 +229,7 @@ static bool processPHI(PHINode *P, LazyValueInfo *LVI, V = SI->getTrueValue(); } - DEBUG(dbgs() << "CVP: Threading PHI over " << *SI << '\n'); + LLVM_DEBUG(dbgs() << "CVP: Threading PHI over " << *SI << '\n'); } P->setIncomingValue(i, V); @@ -181,6 +242,9 @@ static bool processPHI(PHINode *P, LazyValueInfo *LVI, Changed = true; } + if (!Changed) + Changed = simplifyCommonValuePhi(P, LVI, DT); + if (Changed) ++NumPhis; @@ -243,7 +307,7 @@ static bool processCmp(CmpInst *C, LazyValueInfo *LVI) { /// that cannot fire no matter what the incoming edge can safely be removed. 
If /// a case fires on every incoming edge then the entire switch can be removed /// and replaced with a branch to the case destination. -static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) { +static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI, DominatorTree *DT) { Value *Cond = SI->getCondition(); BasicBlock *BB = SI->getParent(); @@ -258,6 +322,10 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) { // Analyse each switch case in turn. bool Changed = false; + DenseMap<BasicBlock*, int> SuccessorsCount; + for (auto *Succ : successors(BB)) + SuccessorsCount[Succ]++; + for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) { ConstantInt *Case = CI->getCaseValue(); @@ -292,7 +360,8 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) { if (State == LazyValueInfo::False) { // This case never fires - remove it. - CI->getCaseSuccessor()->removePredecessor(BB); + BasicBlock *Succ = CI->getCaseSuccessor(); + Succ->removePredecessor(BB); CI = SI->removeCase(CI); CE = SI->case_end(); @@ -302,6 +371,8 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) { ++NumDeadCases; Changed = true; + if (--SuccessorsCount[Succ] == 0) + DT->deleteEdge(BB, Succ); continue; } if (State == LazyValueInfo::True) { @@ -318,10 +389,14 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) { ++CI; } - if (Changed) + if (Changed) { // If the switch has been simplified to the point where it can be replaced // by a branch then do so now. - ConstantFoldTerminator(BB); + DeferredDominance DDT(*DT); + ConstantFoldTerminator(BB, /*DeleteDeadConditions = */ false, + /*TLI = */ nullptr, &DDT); + DDT.flush(); + } return Changed; } @@ -430,9 +505,50 @@ static bool hasPositiveOperands(BinaryOperator *SDI, LazyValueInfo *LVI) { return true; } +/// Try to shrink a udiv/urem's width down to the smallest power of two that's +/// sufficient to contain its operands. +static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) { + assert(Instr->getOpcode() == Instruction::UDiv || + Instr->getOpcode() == Instruction::URem); + if (Instr->getType()->isVectorTy()) + return false; + + // Find the smallest power of two bitwidth that's sufficient to hold Instr's + // operands. + auto OrigWidth = Instr->getType()->getIntegerBitWidth(); + ConstantRange OperandRange(OrigWidth, /*isFullset=*/false); + for (Value *Operand : Instr->operands()) { + OperandRange = OperandRange.unionWith( + LVI->getConstantRange(Operand, Instr->getParent())); + } + // Don't shrink below 8 bits wide. + unsigned NewWidth = std::max<unsigned>( + PowerOf2Ceil(OperandRange.getUnsignedMax().getActiveBits()), 8); + // NewWidth might be greater than OrigWidth if OrigWidth is not a power of + // two. 
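The width computation introduced above is self-contained arithmetic: take the number of bits the largest operand value can occupy, round that up to a power of two, clamp it to a minimum of 8, and transform only when the result is strictly narrower than the original type. A standalone sketch with hand-rolled helpers in place of llvm::PowerOf2Ceil and APInt::getActiveBits:

#include <cstdint>

// Smallest power of two >= V, for V > 0 (stand-in for llvm::PowerOf2Ceil).
static unsigned powerOf2Ceil(unsigned V) {
  unsigned P = 1;
  while (P < V)
    P <<= 1;
  return P;
}

// Number of significant bits in V (stand-in for APInt::getActiveBits).
static unsigned activeBits(uint64_t V) {
  unsigned N = 0;
  for (; V; V >>= 1)
    ++N;
  return N;
}

// Mirror of the NewWidth computation in processUDivOrURem; returning
// OrigWidth encodes "do not transform" (the NewWidth >= OrigWidth case).
static unsigned shrunkenWidth(uint64_t MaxOperandValue, unsigned OrigWidth) {
  unsigned NewWidth = powerOf2Ceil(activeBits(MaxOperandValue));
  if (NewWidth < 8)
    NewWidth = 8; // never shrink below the 8-bit floor used by the pass
  return NewWidth < OrigWidth ? NewWidth : OrigWidth;
}

int main() {
  // Operands proven to fit in 10 bits inside a 64-bit udiv: use 16 bits.
  return shrunkenWidth(1023, 64) == 16 ? 0 : 1;
}

The truncate/div/zext sequence built right after this check then performs the division at the narrow width and widens only the final result.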
+ if (NewWidth >= OrigWidth) + return false; + + ++NumUDivs; + auto *TruncTy = Type::getIntNTy(Instr->getContext(), NewWidth); + auto *LHS = CastInst::Create(Instruction::Trunc, Instr->getOperand(0), TruncTy, + Instr->getName() + ".lhs.trunc", Instr); + auto *RHS = CastInst::Create(Instruction::Trunc, Instr->getOperand(1), TruncTy, + Instr->getName() + ".rhs.trunc", Instr); + auto *BO = + BinaryOperator::Create(Instr->getOpcode(), LHS, RHS, Instr->getName(), Instr); + auto *Zext = CastInst::Create(Instruction::ZExt, BO, Instr->getType(), + Instr->getName() + ".zext", Instr); + if (BO->getOpcode() == Instruction::UDiv) + BO->setIsExact(Instr->isExact()); + + Instr->replaceAllUsesWith(Zext); + Instr->eraseFromParent(); + return true; +} + static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) { - if (SDI->getType()->isVectorTy() || - !hasPositiveOperands(SDI, LVI)) + if (SDI->getType()->isVectorTy() || !hasPositiveOperands(SDI, LVI)) return false; ++NumSRems; @@ -440,6 +556,10 @@ static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) { SDI->getName(), SDI); SDI->replaceAllUsesWith(BO); SDI->eraseFromParent(); + + // Try to process our new urem. + processUDivOrURem(BO, LVI); + return true; } @@ -449,8 +569,7 @@ static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) { /// conditions, this can sometimes prove conditions instcombine can't by /// exploiting range information. static bool processSDiv(BinaryOperator *SDI, LazyValueInfo *LVI) { - if (SDI->getType()->isVectorTy() || - !hasPositiveOperands(SDI, LVI)) + if (SDI->getType()->isVectorTy() || !hasPositiveOperands(SDI, LVI)) return false; ++NumSDivs; @@ -460,6 +579,9 @@ static bool processSDiv(BinaryOperator *SDI, LazyValueInfo *LVI) { SDI->replaceAllUsesWith(BO); SDI->eraseFromParent(); + // Try to simplify our new udiv. 
+ processUDivOrURem(BO, LVI); + return true; } @@ -559,7 +681,8 @@ static Constant *getConstantAt(Value *V, Instruction *At, LazyValueInfo *LVI) { ConstantInt::getFalse(C->getContext()); } -static bool runImpl(Function &F, LazyValueInfo *LVI, const SimplifyQuery &SQ) { +static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT, + const SimplifyQuery &SQ) { bool FnChanged = false; // Visiting in a pre-order depth-first traversal causes us to simplify early // blocks before querying later blocks (which require us to analyze early @@ -575,7 +698,7 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, const SimplifyQuery &SQ) { BBChanged |= processSelect(cast<SelectInst>(II), LVI); break; case Instruction::PHI: - BBChanged |= processPHI(cast<PHINode>(II), LVI, SQ); + BBChanged |= processPHI(cast<PHINode>(II), LVI, DT, SQ); break; case Instruction::ICmp: case Instruction::FCmp: @@ -595,6 +718,10 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, const SimplifyQuery &SQ) { case Instruction::SDiv: BBChanged |= processSDiv(cast<BinaryOperator>(II), LVI); break; + case Instruction::UDiv: + case Instruction::URem: + BBChanged |= processUDivOrURem(cast<BinaryOperator>(II), LVI); + break; case Instruction::AShr: BBChanged |= processAShr(cast<BinaryOperator>(II), LVI); break; @@ -607,7 +734,7 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, const SimplifyQuery &SQ) { Instruction *Term = BB->getTerminator(); switch (Term->getOpcode()) { case Instruction::Switch: - BBChanged |= processSwitch(cast<SwitchInst>(Term), LVI); + BBChanged |= processSwitch(cast<SwitchInst>(Term), LVI, DT); break; case Instruction::Ret: { auto *RI = cast<ReturnInst>(Term); @@ -636,18 +763,22 @@ bool CorrelatedValuePropagation::runOnFunction(Function &F) { return false; LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); - return runImpl(F, LVI, getBestSimplifyQuery(*this, F)); + DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + + return runImpl(F, LVI, DT, getBestSimplifyQuery(*this, F)); } PreservedAnalyses CorrelatedValuePropagationPass::run(Function &F, FunctionAnalysisManager &AM) { - LazyValueInfo *LVI = &AM.getResult<LazyValueAnalysis>(F); - bool Changed = runImpl(F, LVI, getBestSimplifyQuery(AM, F)); + DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F); + + bool Changed = runImpl(F, LVI, DT, getBestSimplifyQuery(AM, F)); if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserve<GlobalsAA>(); + PA.preserve<DominatorTreeAnalysis>(); return PA; } diff --git a/lib/Transforms/Scalar/DCE.cpp b/lib/Transforms/Scalar/DCE.cpp index fa4806e884c3..6078967a0f94 100644 --- a/lib/Transforms/Scalar/DCE.cpp +++ b/lib/Transforms/Scalar/DCE.cpp @@ -20,11 +20,11 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" using namespace llvm; #define DEBUG_TYPE "dce" @@ -50,6 +50,7 @@ namespace { for (BasicBlock::iterator DI = BB.begin(); DI != BB.end(); ) { Instruction *Inst = &*DI++; if (isInstructionTriviallyDead(Inst, TLI)) { + salvageDebugInfo(*Inst); Inst->eraseFromParent(); Changed = true; ++DIEEliminated; @@ -76,6 +77,8 @@ static bool DCEInstruction(Instruction *I, SmallSetVector<Instruction *, 16> &WorkList, const TargetLibraryInfo *TLI) { if 
(isInstructionTriviallyDead(I, TLI)) { + salvageDebugInfo(*I); + // Null out all of the instruction's operands to see if any operand becomes // dead as we go. for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { diff --git a/lib/Transforms/Scalar/DeadStoreElimination.cpp b/lib/Transforms/Scalar/DeadStoreElimination.cpp index e703014bb0e6..dd1a2a6adb82 100644 --- a/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -30,6 +30,7 @@ #include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" @@ -56,11 +57,10 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> -#include <cstdint> #include <cstddef> +#include <cstdint> #include <iterator> #include <map> #include <utility> @@ -115,6 +115,9 @@ deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI, Instruction *DeadInst = NowDeadInsts.pop_back_val(); ++NumFastOther; + // Try to preserve debug information attached to the dead instruction. + salvageDebugInfo(*DeadInst); + // This instruction is dead, zap it, in stages. Start by removing it from // MemDep, which needs to know the operands and needs it to be in the // function. @@ -146,7 +149,8 @@ deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI, /// Does this instruction write some memory? This only returns true for things /// that we can analyze with other helpers below. -static bool hasMemoryWrite(Instruction *I, const TargetLibraryInfo &TLI) { +static bool hasAnalyzableMemoryWrite(Instruction *I, + const TargetLibraryInfo &TLI) { if (isa<StoreInst>(I)) return true; if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { @@ -156,6 +160,9 @@ static bool hasMemoryWrite(Instruction *I, const TargetLibraryInfo &TLI) { case Intrinsic::memset: case Intrinsic::memmove: case Intrinsic::memcpy: + case Intrinsic::memcpy_element_unordered_atomic: + case Intrinsic::memmove_element_unordered_atomic: + case Intrinsic::memset_element_unordered_atomic: case Intrinsic::init_trampoline: case Intrinsic::lifetime_end: return true; @@ -180,43 +187,45 @@ static bool hasMemoryWrite(Instruction *I, const TargetLibraryInfo &TLI) { /// Return a Location stored to by the specified instruction. If isRemovable /// returns true, this function and getLocForRead completely describe the memory /// operations for this instruction. -static MemoryLocation getLocForWrite(Instruction *Inst, AliasAnalysis &AA) { +static MemoryLocation getLocForWrite(Instruction *Inst) { + if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) return MemoryLocation::get(SI); - if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(Inst)) { + if (auto *MI = dyn_cast<AnyMemIntrinsic>(Inst)) { // memcpy/memmove/memset. MemoryLocation Loc = MemoryLocation::getForDest(MI); return Loc; } - IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst); - if (!II) - return MemoryLocation(); - - switch (II->getIntrinsicID()) { - default: - return MemoryLocation(); // Unhandled intrinsic. - case Intrinsic::init_trampoline: - // FIXME: We don't know the size of the trampoline, so we can't really - // handle it here. 
- return MemoryLocation(II->getArgOperand(0)); - case Intrinsic::lifetime_end: { - uint64_t Len = cast<ConstantInt>(II->getArgOperand(0))->getZExtValue(); - return MemoryLocation(II->getArgOperand(1), Len); - } + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + switch (II->getIntrinsicID()) { + default: + return MemoryLocation(); // Unhandled intrinsic. + case Intrinsic::init_trampoline: + return MemoryLocation(II->getArgOperand(0)); + case Intrinsic::lifetime_end: { + uint64_t Len = cast<ConstantInt>(II->getArgOperand(0))->getZExtValue(); + return MemoryLocation(II->getArgOperand(1), Len); + } + } } + if (auto CS = CallSite(Inst)) + // All the supported TLI functions so far happen to have dest as their + // first argument. + return MemoryLocation(CS.getArgument(0)); + return MemoryLocation(); } -/// Return the location read by the specified "hasMemoryWrite" instruction if -/// any. +/// Return the location read by the specified "hasAnalyzableMemoryWrite" +/// instruction if any. static MemoryLocation getLocForRead(Instruction *Inst, const TargetLibraryInfo &TLI) { - assert(hasMemoryWrite(Inst, TLI) && "Unknown instruction case"); + assert(hasAnalyzableMemoryWrite(Inst, TLI) && "Unknown instruction case"); // The only instructions that both read and write are the mem transfer // instructions (memcpy/memmove). - if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(Inst)) + if (auto *MTI = dyn_cast<AnyMemTransferInst>(Inst)) return MemoryLocation::getForSource(MTI); return MemoryLocation(); } @@ -230,7 +239,7 @@ static bool isRemovable(Instruction *I) { if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { switch (II->getIntrinsicID()) { - default: llvm_unreachable("doesn't pass 'hasMemoryWrite' predicate"); + default: llvm_unreachable("doesn't pass 'hasAnalyzableMemoryWrite' predicate"); case Intrinsic::lifetime_end: // Never remove dead lifetime_end's, e.g. because it is followed by a // free. @@ -243,9 +252,14 @@ static bool isRemovable(Instruction *I) { case Intrinsic::memcpy: // Don't remove volatile memory intrinsics. return !cast<MemIntrinsic>(II)->isVolatile(); + case Intrinsic::memcpy_element_unordered_atomic: + case Intrinsic::memmove_element_unordered_atomic: + case Intrinsic::memset_element_unordered_atomic: + return true; } } + // note: only get here for calls with analyzable writes - i.e. libcalls if (auto CS = CallSite(I)) return CS.getInstruction()->use_empty(); @@ -264,6 +278,8 @@ static bool isShortenableAtTheEnd(Instruction *I) { default: return false; case Intrinsic::memset: case Intrinsic::memcpy: + case Intrinsic::memcpy_element_unordered_atomic: + case Intrinsic::memset_element_unordered_atomic: // Do shorten memory intrinsics. // FIXME: Add memmove if it's also safe to transform. return true; @@ -280,35 +296,27 @@ static bool isShortenableAtTheEnd(Instruction *I) { static bool isShortenableAtTheBeginning(Instruction *I) { // FIXME: Handle only memset for now. Supporting memcpy/memmove should be // easily done by offsetting the source address. - IntrinsicInst *II = dyn_cast<IntrinsicInst>(I); - return II && II->getIntrinsicID() == Intrinsic::memset; + return isa<AnyMemSetInst>(I); } /// Return the pointer that is being written to. 
static Value *getStoredPointerOperand(Instruction *I) { - if (StoreInst *SI = dyn_cast<StoreInst>(I)) - return SI->getPointerOperand(); - if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) - return MI->getDest(); - - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { - switch (II->getIntrinsicID()) { - default: llvm_unreachable("Unexpected intrinsic!"); - case Intrinsic::init_trampoline: - return II->getArgOperand(0); - } - } - - CallSite CS(I); - // All the supported functions so far happen to have dest as their first - // argument. - return CS.getArgument(0); + //TODO: factor this to reuse getLocForWrite + MemoryLocation Loc = getLocForWrite(I); + assert(Loc.Ptr && + "unable to find pointer written for analyzable instruction?"); + // TODO: most APIs don't expect const Value * + return const_cast<Value*>(Loc.Ptr); } static uint64_t getPointerSize(const Value *V, const DataLayout &DL, - const TargetLibraryInfo &TLI) { + const TargetLibraryInfo &TLI, + const Function *F) { uint64_t Size; - if (getObjectSize(V, Size, DL, &TLI)) + ObjectSizeOpts Opts; + Opts.NullIsUnknownSize = NullPointerIsDefined(F); + + if (getObjectSize(V, Size, DL, &TLI, Opts)) return Size; return MemoryLocation::UnknownSize; } @@ -338,7 +346,9 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, const TargetLibraryInfo &TLI, int64_t &EarlierOff, int64_t &LaterOff, Instruction *DepWrite, - InstOverlapIntervalsTy &IOL) { + InstOverlapIntervalsTy &IOL, + AliasAnalysis &AA, + const Function *F) { // If we don't know the sizes of either access, then we can't do a comparison. if (Later.Size == MemoryLocation::UnknownSize || Earlier.Size == MemoryLocation::UnknownSize) @@ -349,7 +359,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, // If the start pointers are the same, we just have to compare sizes to see if // the later store was larger than the earlier store. - if (P1 == P2) { + if (P1 == P2 || AA.isMustAlias(P1, P2)) { // Make sure that the Later size is >= the Earlier size. if (Later.Size >= Earlier.Size) return OW_Complete; @@ -367,7 +377,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, return OW_Unknown; // If the "Later" store is to a recognizable object, get its size. - uint64_t ObjectSize = getPointerSize(UO2, DL, TLI); + uint64_t ObjectSize = getPointerSize(UO2, DL, TLI, F); if (ObjectSize != MemoryLocation::UnknownSize) if (ObjectSize == Later.Size && ObjectSize >= Earlier.Size) return OW_Complete; @@ -415,9 +425,10 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, // Insert our part of the overlap into the map. auto &IM = IOL[DepWrite]; - DEBUG(dbgs() << "DSE: Partial overwrite: Earlier [" << EarlierOff << ", " << - int64_t(EarlierOff + Earlier.Size) << ") Later [" << - LaterOff << ", " << int64_t(LaterOff + Later.Size) << ")\n"); + LLVM_DEBUG(dbgs() << "DSE: Partial overwrite: Earlier [" << EarlierOff + << ", " << int64_t(EarlierOff + Earlier.Size) + << ") Later [" << LaterOff << ", " + << int64_t(LaterOff + Later.Size) << ")\n"); // Make sure that we only insert non-overlapping intervals and combine // adjacent intervals. 
The intervals are stored in the map with the ending @@ -454,11 +465,11 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, ILI = IM.begin(); if (ILI->second <= EarlierOff && ILI->first >= int64_t(EarlierOff + Earlier.Size)) { - DEBUG(dbgs() << "DSE: Full overwrite from partials: Earlier [" << - EarlierOff << ", " << - int64_t(EarlierOff + Earlier.Size) << - ") Composite Later [" << - ILI->second << ", " << ILI->first << ")\n"); + LLVM_DEBUG(dbgs() << "DSE: Full overwrite from partials: Earlier [" + << EarlierOff << ", " + << int64_t(EarlierOff + Earlier.Size) + << ") Composite Later [" << ILI->second << ", " + << ILI->first << ")\n"); ++NumCompletePartials; return OW_Complete; } @@ -469,10 +480,11 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, if (EnablePartialStoreMerging && LaterOff >= EarlierOff && int64_t(EarlierOff + Earlier.Size) > LaterOff && uint64_t(LaterOff - EarlierOff) + Later.Size <= Earlier.Size) { - DEBUG(dbgs() << "DSE: Partial overwrite an earlier load [" << EarlierOff - << ", " << int64_t(EarlierOff + Earlier.Size) - << ") by a later store [" << LaterOff << ", " - << int64_t(LaterOff + Later.Size) << ")\n"); + LLVM_DEBUG(dbgs() << "DSE: Partial overwrite an earlier load [" + << EarlierOff << ", " + << int64_t(EarlierOff + Earlier.Size) + << ") by a later store [" << LaterOff << ", " + << int64_t(LaterOff + Later.Size) << ")\n"); // TODO: Maybe come up with a better name? return OW_PartialEarlierWithFullLater; } @@ -514,8 +526,8 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, /// memory region into an identical pointer) then it doesn't actually make its /// input dead in the traditional sense. Consider this case: /// -/// memcpy(A <- B) -/// memcpy(A <- A) +/// memmove(A <- B) +/// memmove(A <- A) /// /// In this case, the second store to A does not make the first store to A dead. /// The usual situation isn't an explicit A<-A store like this (which can be @@ -531,24 +543,35 @@ static bool isPossibleSelfRead(Instruction *Inst, // Self reads can only happen for instructions that read memory. Get the // location read. MemoryLocation InstReadLoc = getLocForRead(Inst, TLI); - if (!InstReadLoc.Ptr) return false; // Not a reading instruction. + if (!InstReadLoc.Ptr) + return false; // Not a reading instruction. // If the read and written loc obviously don't alias, it isn't a read. - if (AA.isNoAlias(InstReadLoc, InstStoreLoc)) return false; - - // Okay, 'Inst' may copy over itself. However, we can still remove a the - // DepWrite instruction if we can prove that it reads from the same location - // as Inst. This handles useful cases like: - // memcpy(A <- B) - // memcpy(A <- B) - // Here we don't know if A/B may alias, but we do know that B/B are must - // aliases, so removing the first memcpy is safe (assuming it writes <= # - // bytes as the second one. - MemoryLocation DepReadLoc = getLocForRead(DepWrite, TLI); - - if (DepReadLoc.Ptr && AA.isMustAlias(InstReadLoc.Ptr, DepReadLoc.Ptr)) + if (AA.isNoAlias(InstReadLoc, InstStoreLoc)) return false; + if (isa<AnyMemCpyInst>(Inst)) { + // LLVM's memcpy overlap semantics are not fully fleshed out (see PR11763), + // but in practice memcpy(A <- B) either means that A and B are disjoint or + // are equal (i.e. there are no partial overlaps).
Given that, if we have: + // + // memcpy/memmove(A <- B) // DepWrite + // memcpy(A <- B) // Inst + // + // with Inst reading/writing a size >= that of DepWrite, we can reason as + // follows: + // + // - If A == B then both copies are no-ops, so DepWrite can be + // removed. + // - If A != B then A and B are disjoint locations in Inst. Since + // Inst.size >= DepWrite.size, A and B are disjoint in DepWrite too. + // Therefore DepWrite can be removed. + MemoryLocation DepReadLoc = getLocForRead(DepWrite, TLI); + + if (DepReadLoc.Ptr && AA.isMustAlias(InstReadLoc.Ptr, DepReadLoc.Ptr)) + return false; + } + // If DepWrite doesn't read memory or if we can't prove it is a must alias, // then it can't be considered dead. return true; @@ -650,7 +673,8 @@ static bool handleFree(CallInst *F, AliasAnalysis *AA, MD->getPointerDependencyFrom(Loc, false, InstPt->getIterator(), BB); while (Dep.isDef() || Dep.isClobber()) { Instruction *Dependency = Dep.getInst(); - if (!hasMemoryWrite(Dependency, *TLI) || !isRemovable(Dependency)) + if (!hasAnalyzableMemoryWrite(Dependency, *TLI) || + !isRemovable(Dependency)) break; Value *DepPointer = @@ -660,8 +684,9 @@ static bool handleFree(CallInst *F, AliasAnalysis *AA, if (!AA->isMustAlias(F->getArgOperand(0), DepPointer)) break; - DEBUG(dbgs() << "DSE: Dead Store to soon to be freed memory:\n DEAD: " - << *Dependency << '\n'); + LLVM_DEBUG( + dbgs() << "DSE: Dead Store to soon to be freed memory:\n DEAD: " + << *Dependency << '\n'); // DCE instructions only used to calculate that store. BasicBlock::iterator BBI(Dependency); @@ -690,7 +715,8 @@ static bool handleFree(CallInst *F, AliasAnalysis *AA, static void removeAccessedObjects(const MemoryLocation &LoadedLoc, SmallSetVector<Value *, 16> &DeadStackObjects, const DataLayout &DL, AliasAnalysis *AA, - const TargetLibraryInfo *TLI) { + const TargetLibraryInfo *TLI, + const Function *F) { const Value *UnderlyingPointer = GetUnderlyingObject(LoadedLoc.Ptr, DL); // A constant can't be in the dead pointer set. @@ -707,7 +733,7 @@ static void removeAccessedObjects(const MemoryLocation &LoadedLoc, // Remove objects that could alias LoadedLoc. DeadStackObjects.remove_if([&](Value *I) { // See if the loaded location could alias the stack location. - MemoryLocation StackLoc(I, getPointerSize(I, DL, *TLI)); + MemoryLocation StackLoc(I, getPointerSize(I, DL, *TLI, F)); return !AA->isNoAlias(StackLoc, LoadedLoc); }); } @@ -754,7 +780,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, --BBI; // If we find a store, check to see if it points into a dead stack value.
- if (hasMemoryWrite(&*BBI, *TLI) && isRemovable(&*BBI)) { + if (hasAnalyzableMemoryWrite(&*BBI, *TLI) && isRemovable(&*BBI)) { // See through pointer-to-pointer bitcasts SmallVector<Value *, 4> Pointers; GetUnderlyingObjects(getStoredPointerOperand(&*BBI), Pointers, DL); @@ -770,15 +796,16 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, if (AllDead) { Instruction *Dead = &*BBI; - DEBUG(dbgs() << "DSE: Dead Store at End of Block:\n DEAD: " - << *Dead << "\n Objects: "; - for (SmallVectorImpl<Value *>::iterator I = Pointers.begin(), - E = Pointers.end(); I != E; ++I) { - dbgs() << **I; - if (std::next(I) != E) - dbgs() << ", "; - } - dbgs() << '\n'); + LLVM_DEBUG(dbgs() << "DSE: Dead Store at End of Block:\n DEAD: " + << *Dead << "\n Objects: "; + for (SmallVectorImpl<Value *>::iterator I = Pointers.begin(), + E = Pointers.end(); + I != E; ++I) { + dbgs() << **I; + if (std::next(I) != E) + dbgs() << ", "; + } dbgs() + << '\n'); // DCE instructions only used to calculate that store. deleteDeadInstruction(Dead, &BBI, *MD, *TLI, IOL, InstrOrdering, &DeadStackObjects); @@ -790,8 +817,8 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, // Remove any dead non-memory-mutating instructions. if (isInstructionTriviallyDead(&*BBI, TLI)) { - DEBUG(dbgs() << "DSE: Removing trivially dead instruction:\n DEAD: " - << *&*BBI << '\n'); + LLVM_DEBUG(dbgs() << "DSE: Removing trivially dead instruction:\n DEAD: " + << *&*BBI << '\n'); deleteDeadInstruction(&*BBI, &BBI, *MD, *TLI, IOL, InstrOrdering, &DeadStackObjects); ++NumFastOther; MadeChange = true; @@ -820,7 +847,8 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, // the call is live. DeadStackObjects.remove_if([&](Value *I) { // See if the call site touches the value. - return isRefSet(AA->getModRefInfo(CS, I, getPointerSize(I, DL, *TLI))); + return isRefSet(AA->getModRefInfo(CS, I, getPointerSize(I, DL, *TLI, + BB.getParent()))); }); // If all of the allocas were clobbered by the call then we're not going @@ -848,8 +876,6 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, LoadedLoc = MemoryLocation::get(L); } else if (VAArgInst *V = dyn_cast<VAArgInst>(BBI)) { LoadedLoc = MemoryLocation::get(V); - } else if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(BBI)) { - LoadedLoc = MemoryLocation::getForSource(MTI); } else if (!BBI->mayReadFromMemory()) { // Instruction doesn't read memory. Note that stores that weren't removed // above will hit this case. @@ -861,7 +887,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, // Remove any allocas from the DeadPointer set that are loaded, as this // makes any stores above the access live. - removeAccessedObjects(LoadedLoc, DeadStackObjects, DL, AA, TLI); + removeAccessedObjects(LoadedLoc, DeadStackObjects, DL, AA, TLI, BB.getParent()); // If all of the allocas were clobbered by the access then we're not going // to find anything else to process. 
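A note on the isPossibleSelfRead hunks earlier in this diff: the new AnyMemCpyInst branch leans on the rule that a memcpy's source and destination are either disjoint or identical, so an earlier copy that reads from the same source as a later, at-least-as-large copy to the same destination is dead. The following standalone C++ sketch restates that decision under a deliberately simplified model; Loc, mustAlias, and laterCopyKillsEarlier are illustrative names, not LLVM's AliasAnalysis or MemoryLocation APIs:

    #include <cstdint>

    // Simplified stand-in for a memory location: an underlying object plus a
    // byte count. Equal Base is taken to mean "must alias" in this sketch.
    struct Loc {
      const void *Base;
      uint64_t Size;
    };

    static bool mustAlias(const Loc &A, const Loc &B) { return A.Base == B.Base; }

    // DepRead is the source location of an earlier memcpy/memmove(Dst <- Src);
    // InstRead is the source location of a later memcpy(Dst <- Src). If the
    // later copy reads the same source (must-alias) and covers at least as many
    // bytes, the earlier copy is dead: either Dst == Src (both copies are
    // no-ops) or Dst and Src are disjoint over the later copy's larger range,
    // hence disjoint over the earlier copy's range too.
    static bool laterCopyKillsEarlier(const Loc &DepRead, const Loc &InstRead) {
      return InstRead.Size >= DepRead.Size && mustAlias(InstRead, DepRead);
    }

The pass reaches the same conclusion via AA.isMustAlias on the two read locations; the sketch only makes the size reasoning explicit.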
@@ -881,8 +907,8 @@ static bool tryToShorten(Instruction *EarlierWrite, int64_t &EarlierOffset, // Power of 2 vector writes are probably always a bad idea to optimize // as any store/memset/memcpy is likely using vector instructions so // shortening it to a non-vector size is likely to be slower - MemIntrinsic *EarlierIntrinsic = cast<MemIntrinsic>(EarlierWrite); - unsigned EarlierWriteAlign = EarlierIntrinsic->getAlignment(); + auto *EarlierIntrinsic = cast<AnyMemIntrinsic>(EarlierWrite); + unsigned EarlierWriteAlign = EarlierIntrinsic->getDestAlignment(); if (!IsOverwriteEnd) LaterOffset = int64_t(LaterOffset + LaterSize); @@ -890,15 +916,23 @@ static bool tryToShorten(Instruction *EarlierWrite, int64_t &EarlierOffset, !((EarlierWriteAlign != 0) && LaterOffset % EarlierWriteAlign == 0)) return false; - DEBUG(dbgs() << "DSE: Remove Dead Store:\n OW " - << (IsOverwriteEnd ? "END" : "BEGIN") << ": " << *EarlierWrite - << "\n KILLER (offset " << LaterOffset << ", " << EarlierSize - << ")\n"); - int64_t NewLength = IsOverwriteEnd ? LaterOffset - EarlierOffset : EarlierSize - (LaterOffset - EarlierOffset); + if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(EarlierWrite)) { + // When shortening an atomic memory intrinsic, the newly shortened + // length must remain an integer multiple of the element size. + const uint32_t ElementSize = AMI->getElementSizeInBytes(); + if (0 != NewLength % ElementSize) + return false; + } + + LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n OW " + << (IsOverwriteEnd ? "END" : "BEGIN") << ": " + << *EarlierWrite << "\n KILLER (offset " << LaterOffset + << ", " << EarlierSize << ")\n"); + Value *EarlierWriteLength = EarlierIntrinsic->getLength(); Value *TrimmedLength = ConstantInt::get(EarlierWriteLength->getType(), NewLength); @@ -966,7 +1000,7 @@ static bool removePartiallyOverlappedStores(AliasAnalysis *AA, bool Changed = false; for (auto OI : IOL) { Instruction *EarlierWrite = OI.first; - MemoryLocation Loc = getLocForWrite(EarlierWrite, *AA); + MemoryLocation Loc = getLocForWrite(EarlierWrite); assert(isRemovable(EarlierWrite) && "Expect only removable instruction"); assert(Loc.Size != MemoryLocation::UnknownSize && "Unexpected mem loc"); @@ -1002,8 +1036,9 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI, if (SI->getPointerOperand() == DepLoad->getPointerOperand() && isRemovable(SI) && memoryIsNotModifiedBetween(DepLoad, SI, AA)) { - DEBUG(dbgs() << "DSE: Remove Store Of Load from same pointer:\n LOAD: " - << *DepLoad << "\n STORE: " << *SI << '\n'); + LLVM_DEBUG( + dbgs() << "DSE: Remove Store Of Load from same pointer:\n LOAD: " + << *DepLoad << "\n STORE: " << *SI << '\n'); deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, InstrOrdering); ++NumRedundantStores; @@ -1019,7 +1054,7 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI, if (UnderlyingPointer && isCallocLikeFn(UnderlyingPointer, TLI) && memoryIsNotModifiedBetween(UnderlyingPointer, SI, AA)) { - DEBUG( + LLVM_DEBUG( dbgs() << "DSE: Remove null store to the calloc'ed object:\n DEAD: " << *Inst << "\n OBJECT: " << *UnderlyingPointer << '\n'); @@ -1067,7 +1102,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, } // Check to see if Inst writes to memory. If not, continue. - if (!hasMemoryWrite(Inst, *TLI)) + if (!hasAnalyzableMemoryWrite(Inst, *TLI)) continue; // eliminateNoopStore will update in iterator, if necessary.
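One detail worth calling out from the tryToShorten hunks above: shortening one of the element-wise atomic intrinsics is only legal when the trimmed length stays a whole number of elements. A minimal sketch of that guard; isValidShortenedLength is a hypothetical helper (the pass performs the equivalent check inline):

    #include <cstdint>

    // Returns true if trimming a write down to NewLength bytes is acceptable.
    // The *.element.unordered.atomic intrinsics operate element by element, so
    // a partial element can never be written.
    static bool isValidShortenedLength(int64_t NewLength, uint32_t ElementSize,
                                       bool IsElementWiseAtomic) {
      if (NewLength < 0)
        return false;
      if (!IsElementWiseAtomic)
        return true; // plain memset/memcpy may be trimmed to any byte count
      return NewLength % ElementSize == 0;
    }

In the pass itself the check sits between computing NewLength and emitting the trimmed intrinsic, so a rejected length simply leaves the store untouched.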
@@ -1085,7 +1120,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, continue; // Figure out what location is being stored to. - MemoryLocation Loc = getLocForWrite(Inst, *AA); + MemoryLocation Loc = getLocForWrite(Inst); // If we didn't get a useful location, fail. if (!Loc.Ptr) @@ -1107,7 +1142,9 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, // // Find out what memory location the dependent instruction stores. Instruction *DepWrite = InstDep.getInst(); - MemoryLocation DepLoc = getLocForWrite(DepWrite, *AA); + if (!hasAnalyzableMemoryWrite(DepWrite, *TLI)) + break; + MemoryLocation DepLoc = getLocForWrite(DepWrite); // If we didn't get a useful location, or if it isn't a size, bail out. if (!DepLoc.Ptr) break; @@ -1145,12 +1182,12 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, if (isRemovable(DepWrite) && !isPossibleSelfRead(Inst, Loc, DepWrite, *TLI, *AA)) { int64_t InstWriteOffset, DepWriteOffset; - OverwriteResult OR = - isOverwrite(Loc, DepLoc, DL, *TLI, DepWriteOffset, InstWriteOffset, - DepWrite, IOL); + OverwriteResult OR = isOverwrite(Loc, DepLoc, DL, *TLI, DepWriteOffset, + InstWriteOffset, DepWrite, IOL, *AA, + BB.getParent()); if (OR == OW_Complete) { - DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " - << *DepWrite << "\n KILLER: " << *Inst << '\n'); + LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *DepWrite + << "\n KILLER: " << *Inst << '\n'); // Delete the store and now-dead instructions that feed it. deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, &InstrOrdering); @@ -1176,7 +1213,8 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, auto *Earlier = dyn_cast<StoreInst>(DepWrite); auto *Later = dyn_cast<StoreInst>(Inst); if (Earlier && isa<ConstantInt>(Earlier->getValueOperand()) && - Later && isa<ConstantInt>(Later->getValueOperand())) { + Later && isa<ConstantInt>(Later->getValueOperand()) && + memoryIsNotModifiedBetween(Earlier, Later, AA)) { // If the store we find is: // a) partially overwritten by the store to 'Loc' // b) the later store is fully contained in the earlier one and @@ -1207,9 +1245,9 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, // store, shifted appropriately. APInt Merged = (EarlierValue & ~Mask) | (LaterValue << LShiftAmount); - DEBUG(dbgs() << "DSE: Merge Stores:\n Earlier: " << *DepWrite - << "\n Later: " << *Inst - << "\n Merged Value: " << Merged << '\n'); + LLVM_DEBUG(dbgs() << "DSE: Merge Stores:\n Earlier: " << *DepWrite + << "\n Later: " << *Inst + << "\n Merged Value: " << Merged << '\n'); auto *SI = new StoreInst( ConstantInt::get(Earlier->getValueOperand()->getType(), Merged), diff --git a/lib/Transforms/Scalar/DivRemPairs.cpp b/lib/Transforms/Scalar/DivRemPairs.cpp index e383af89a384..e1bc590c5c9a 100644 --- a/lib/Transforms/Scalar/DivRemPairs.cpp +++ b/lib/Transforms/Scalar/DivRemPairs.cpp @@ -13,6 +13,8 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/DivRemPairs.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -48,7 +50,10 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, // Insert all divide and remainder instructions into maps keyed by their // operands and opcode (signed or unsigned). 
- DenseMap<DivRemMapKey, Instruction *> DivMap, RemMap; + DenseMap<DivRemMapKey, Instruction *> DivMap; + // Use a MapVector for RemMap so that instructions are moved/inserted in a + // deterministic order. + MapVector<DivRemMapKey, Instruction *> RemMap; for (auto &BB : F) { for (auto &I : BB) { if (I.getOpcode() == Instruction::SDiv) @@ -67,14 +72,14 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, // rare than division. for (auto &RemPair : RemMap) { // Find the matching division instruction from the division map. - Instruction *DivInst = DivMap[RemPair.getFirst()]; + Instruction *DivInst = DivMap[RemPair.first]; if (!DivInst) continue; // We have a matching pair of div/rem instructions. If one dominates the // other, hoist and/or replace one. NumPairs++; - Instruction *RemInst = RemPair.getSecond(); + Instruction *RemInst = RemPair.second; bool IsSigned = DivInst->getOpcode() == Instruction::SDiv; bool HasDivRemOp = TTI.hasDivRemOp(DivInst->getType(), IsSigned); diff --git a/lib/Transforms/Scalar/EarlyCSE.cpp b/lib/Transforms/Scalar/EarlyCSE.cpp index 5798e1c4ee99..565745d12e99 100644 --- a/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/lib/Transforms/Scalar/EarlyCSE.cpp @@ -27,6 +27,7 @@ #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -49,10 +50,10 @@ #include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugCounter.h" #include "llvm/Support/RecyclingAllocator.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <deque> #include <memory> @@ -70,13 +71,16 @@ STATISTIC(NumCSELoad, "Number of load instructions CSE'd"); STATISTIC(NumCSECall, "Number of call instructions CSE'd"); STATISTIC(NumDSE, "Number of trivial dead stores removed"); +DEBUG_COUNTER(CSECounter, "early-cse", + "Controls which instructions are removed"); + //===----------------------------------------------------------------------===// // SimpleValue //===----------------------------------------------------------------------===// namespace { -/// \brief Struct representing the available values in the scoped hash table. +/// Struct representing the available values in the scoped hash table. struct SimpleValue { Instruction *Inst; @@ -151,12 +155,15 @@ unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) { SelectPatternFlavor SPF = matchSelectPattern(Inst, A, B).Flavor; // TODO: We should also detect FP min/max. if (SPF == SPF_SMIN || SPF == SPF_SMAX || - SPF == SPF_UMIN || SPF == SPF_UMAX || - SPF == SPF_ABS || SPF == SPF_NABS) { + SPF == SPF_UMIN || SPF == SPF_UMAX) { if (A > B) std::swap(A, B); return hash_combine(Inst->getOpcode(), SPF, A, B); } + if (SPF == SPF_ABS || SPF == SPF_NABS) { + // ABS/NABS always puts the input in A and its negation in B. 
+ return hash_combine(Inst->getOpcode(), SPF, A, B); + } if (CastInst *CI = dyn_cast<CastInst>(Inst)) return hash_combine(CI->getOpcode(), CI->getType(), CI->getOperand(0)); @@ -226,8 +233,13 @@ bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) { LSPF == SPF_ABS || LSPF == SPF_NABS) { Value *RHSA, *RHSB; SelectPatternFlavor RSPF = matchSelectPattern(RHSI, RHSA, RHSB).Flavor; - return (LSPF == RSPF && ((LHSA == RHSA && LHSB == RHSB) || - (LHSA == RHSB && LHSB == RHSA))); + if (LSPF == RSPF) { + // Abs results are placed in a defined order by matchSelectPattern. + if (LSPF == SPF_ABS || LSPF == SPF_NABS) + return LHSA == RHSA && LHSB == RHSB; + return ((LHSA == RHSA && LHSB == RHSB) || + (LHSA == RHSB && LHSB == RHSA)); + } } return false; @@ -239,7 +251,7 @@ bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) { namespace { -/// \brief Struct representing the available call values in the scoped hash +/// Struct representing the available call values in the scoped hash /// table. struct CallValue { Instruction *Inst; @@ -305,7 +317,7 @@ bool DenseMapInfo<CallValue>::isEqual(CallValue LHS, CallValue RHS) { namespace { -/// \brief A simple and fast domtree-based CSE pass. +/// A simple and fast domtree-based CSE pass. /// /// This pass does a simple depth-first walk over the dominator tree, /// eliminating trivially redundant instructions and using instsimplify to @@ -329,7 +341,7 @@ public: ScopedHashTable<SimpleValue, Value *, DenseMapInfo<SimpleValue>, AllocatorTy>; - /// \brief A scoped hash table of the current values of all of our simple + /// A scoped hash table of the current values of all of our simple /// scalar expressions. /// /// As we walk down the domtree, we look to see if instructions are in this: @@ -337,8 +349,8 @@ public: /// that dominated values can succeed in their lookup. ScopedHTType AvailableValues; - /// A scoped hash table of the current values of previously encounted memory - /// locations. + /// A scoped hash table of the current values of previously encountered + /// memory locations. /// /// This allows us to get efficient access to dominating loads or stores when /// we have a fully redundant load. In addition to the most recent load, we @@ -356,13 +368,12 @@ public: unsigned Generation = 0; int MatchingId = -1; bool IsAtomic = false; - bool IsInvariant = false; LoadValue() = default; LoadValue(Instruction *Inst, unsigned Generation, unsigned MatchingId, - bool IsAtomic, bool IsInvariant) + bool IsAtomic) : DefInst(Inst), Generation(Generation), MatchingId(MatchingId), - IsAtomic(IsAtomic), IsInvariant(IsInvariant) {} + IsAtomic(IsAtomic) {} }; using LoadMapAllocator = @@ -373,8 +384,19 @@ public: LoadMapAllocator>; LoadHTType AvailableLoads; + + // A scoped hash table mapping memory locations (represented as typed + // addresses) to generation numbers at which that memory location became + // (henceforth indefinitely) invariant. + using InvariantMapAllocator = + RecyclingAllocator<BumpPtrAllocator, + ScopedHashTableVal<MemoryLocation, unsigned>>; + using InvariantHTType = + ScopedHashTable<MemoryLocation, unsigned, DenseMapInfo<MemoryLocation>, + InvariantMapAllocator>; + InvariantHTType AvailableInvariants; - /// \brief A scoped hash table of the current values of read-only call + /// A scoped hash table of the current values of read-only call /// values. /// /// It uses the same generation count as loads. 
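To make the AvailableInvariants table declared above concrete: it maps a memory location to the generation at which that location became invariant, and a cached value is reusable whenever the location was already invariant at the generation where the value was recorded. A rough sketch of the two operations, using std::map in place of the scoped hash table and an opaque pointer in place of MemoryLocation (both simplifications):

    #include <map>

    using Location = const void *; // stand-in for llvm::MemoryLocation

    static std::map<Location, unsigned> AvailableInvariants;

    // Open an invariance scope at generation Gen; keep the earliest generation
    // if one was already recorded (mirrors "don't restart the scope").
    static void startInvariantScope(Location Loc, unsigned Gen) {
      AvailableInvariants.try_emplace(Loc, Gen);
    }

    // A value recorded at generation GenAt can be reused if the location was
    // already invariant by then: no later write can have clobbered it.
    static bool isInvariantSince(Location Loc, unsigned GenAt) {
      auto It = AvailableInvariants.find(Loc);
      return It != AvailableInvariants.end() && It->second <= GenAt;
    }

The real table is scoped to the dominator-tree walk, so entries opened in a child node disappear when the walk pops back out; the std::map here ignores that detail.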
@@ -382,10 +404,10 @@ public: ScopedHashTable<CallValue, std::pair<Instruction *, unsigned>>; CallHTType AvailableCalls; - /// \brief This is the current generation of the memory value. + /// This is the current generation of the memory value. unsigned CurrentGeneration = 0; - /// \brief Set up the EarlyCSE runner for a particular function. + /// Set up the EarlyCSE runner for a particular function. EarlyCSE(const DataLayout &DL, const TargetLibraryInfo &TLI, const TargetTransformInfo &TTI, DominatorTree &DT, AssumptionCache &AC, MemorySSA *MSSA) @@ -401,15 +423,16 @@ private: class NodeScope { public: NodeScope(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads, - CallHTType &AvailableCalls) - : Scope(AvailableValues), LoadScope(AvailableLoads), - CallScope(AvailableCalls) {} + InvariantHTType &AvailableInvariants, CallHTType &AvailableCalls) + : Scope(AvailableValues), LoadScope(AvailableLoads), + InvariantScope(AvailableInvariants), CallScope(AvailableCalls) {} NodeScope(const NodeScope &) = delete; NodeScope &operator=(const NodeScope &) = delete; private: ScopedHTType::ScopeTy Scope; LoadHTType::ScopeTy LoadScope; + InvariantHTType::ScopeTy InvariantScope; CallHTType::ScopeTy CallScope; }; @@ -420,10 +443,13 @@ private: class StackNode { public: StackNode(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads, - CallHTType &AvailableCalls, unsigned cg, DomTreeNode *n, - DomTreeNode::iterator child, DomTreeNode::iterator end) + InvariantHTType &AvailableInvariants, CallHTType &AvailableCalls, + unsigned cg, DomTreeNode *n, DomTreeNode::iterator child, + DomTreeNode::iterator end) : CurrentGeneration(cg), ChildGeneration(cg), Node(n), ChildIter(child), - EndIter(end), Scopes(AvailableValues, AvailableLoads, AvailableCalls) + EndIter(end), + Scopes(AvailableValues, AvailableLoads, AvailableInvariants, + AvailableCalls) {} StackNode(const StackNode &) = delete; StackNode &operator=(const StackNode &) = delete; @@ -455,7 +481,7 @@ private: bool Processed = false; }; - /// \brief Wrapper class to handle memory instructions, including loads, + /// Wrapper class to handle memory instructions, including loads, /// stores and intrinsic loads and stores defined by the target. class ParseMemoryInst { public: @@ -532,12 +558,7 @@ private: Value *getPointerOperand() const { if (IsTargetMemInst) return Info.PtrVal; - if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { - return LI->getPointerOperand(); - } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { - return SI->getPointerOperand(); - } - return nullptr; + return getLoadStorePointerOperand(Inst); } bool mayReadFromMemory() const { @@ -558,6 +579,9 @@ private: bool processNode(DomTreeNode *Node); + bool handleBranchCondition(Instruction *CondInst, const BranchInst *BI, + const BasicBlock *BB, const BasicBlock *Pred); + Value *getOrCreateResult(Value *Inst, Type *ExpectedType) const { if (auto *LI = dyn_cast<LoadInst>(Inst)) return LI; @@ -568,6 +592,10 @@ private: ExpectedType); } + /// Return true if the instruction is known to only operate on memory + /// provably invariant in the given "generation". 
+ bool isOperatingOnInvariantMemAt(Instruction *I, unsigned GenAt); + bool isSameMemGeneration(unsigned EarlierGeneration, unsigned LaterGeneration, Instruction *EarlierInst, Instruction *LaterInst); @@ -661,6 +689,79 @@ bool EarlyCSE::isSameMemGeneration(unsigned EarlierGeneration, return MSSA->dominates(LaterDef, EarlierMA); } +bool EarlyCSE::isOperatingOnInvariantMemAt(Instruction *I, unsigned GenAt) { + // A location loaded from with an invariant_load is assumed to *never* change + // within the visible scope of the compilation. + if (auto *LI = dyn_cast<LoadInst>(I)) + if (LI->getMetadata(LLVMContext::MD_invariant_load)) + return true; + + auto MemLocOpt = MemoryLocation::getOrNone(I); + if (!MemLocOpt) + // "target" intrinsic forms of loads aren't currently known to + // MemoryLocation::get. TODO + return false; + MemoryLocation MemLoc = *MemLocOpt; + if (!AvailableInvariants.count(MemLoc)) + return false; + + // Is the generation at which this became invariant older than the + // current one? + return AvailableInvariants.lookup(MemLoc) <= GenAt; +} + +bool EarlyCSE::handleBranchCondition(Instruction *CondInst, + const BranchInst *BI, const BasicBlock *BB, + const BasicBlock *Pred) { + assert(BI->isConditional() && "Should be a conditional branch!"); + assert(BI->getCondition() == CondInst && "Wrong condition?"); + assert(BI->getSuccessor(0) == BB || BI->getSuccessor(1) == BB); + auto *TorF = (BI->getSuccessor(0) == BB) + ? ConstantInt::getTrue(BB->getContext()) + : ConstantInt::getFalse(BB->getContext()); + auto MatchBinOp = [](Instruction *I, unsigned Opcode) { + if (BinaryOperator *BOp = dyn_cast<BinaryOperator>(I)) + return BOp->getOpcode() == Opcode; + return false; + }; + // If the condition is an AND operation, we can propagate its operands into the + // true branch. If it is an OR operation, we can propagate them into the false + // branch. + unsigned PropagateOpcode = + (BI->getSuccessor(0) == BB) ? Instruction::And : Instruction::Or; + + bool MadeChanges = false; + SmallVector<Instruction *, 4> WorkList; + SmallPtrSet<Instruction *, 4> Visited; + WorkList.push_back(CondInst); + while (!WorkList.empty()) { + Instruction *Curr = WorkList.pop_back_val(); + + AvailableValues.insert(Curr, TorF); + LLVM_DEBUG(dbgs() << "EarlyCSE CVP: Add conditional value for '" + << Curr->getName() << "' as " << *TorF << " in " + << BB->getName() << "\n"); + if (!DebugCounter::shouldExecute(CSECounter)) { + LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); + } else { + // Replace all dominated uses with the known value. + if (unsigned Count = replaceDominatedUsesWith(Curr, TorF, DT, + BasicBlockEdge(Pred, BB))) { + NumCSECVP += Count; + MadeChanges = true; + } + } + + if (MatchBinOp(Curr, PropagateOpcode)) + for (auto &Op : cast<BinaryOperator>(Curr)->operands()) + if (Instruction *OPI = dyn_cast<Instruction>(Op)) + if (SimpleValue::canHandle(OPI) && Visited.insert(OPI).second) + WorkList.push_back(OPI); + } + + return MadeChanges; +} + bool EarlyCSE::processNode(DomTreeNode *Node) { bool Changed = false; BasicBlock *BB = Node->getBlock(); @@ -684,22 +785,8 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { auto *BI = dyn_cast<BranchInst>(Pred->getTerminator()); if (BI && BI->isConditional()) { auto *CondInst = dyn_cast<Instruction>(BI->getCondition()); - if (CondInst && SimpleValue::canHandle(CondInst)) { - assert(BI->getSuccessor(0) == BB || BI->getSuccessor(1) == BB); - auto *TorF = (BI->getSuccessor(0) == BB) - ? 
ConstantInt::getTrue(BB->getContext()) - : ConstantInt::getFalse(BB->getContext()); - AvailableValues.insert(CondInst, TorF); - DEBUG(dbgs() << "EarlyCSE CVP: Add conditional value for '" - << CondInst->getName() << "' as " << *TorF << " in " - << BB->getName() << "\n"); - // Replace all dominated uses with the known value. - if (unsigned Count = replaceDominatedUsesWith( - CondInst, TorF, DT, BasicBlockEdge(Pred, BB))) { - Changed = true; - NumCSECVP += Count; - } - } + if (CondInst && SimpleValue::canHandle(CondInst)) + Changed |= handleBranchCondition(CondInst, BI, BB, Pred); } } @@ -716,7 +803,12 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // Dead instructions should just be removed. if (isInstructionTriviallyDead(Inst, &TLI)) { - DEBUG(dbgs() << "EarlyCSE DCE: " << *Inst << '\n'); + LLVM_DEBUG(dbgs() << "EarlyCSE DCE: " << *Inst << '\n'); + if (!DebugCounter::shouldExecute(CSECounter)) { + LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); + continue; + } + salvageDebugInfo(*Inst); removeMSSA(Inst); Inst->eraseFromParent(); Changed = true; @@ -732,31 +824,44 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { auto *CondI = dyn_cast<Instruction>(cast<CallInst>(Inst)->getArgOperand(0)); if (CondI && SimpleValue::canHandle(CondI)) { - DEBUG(dbgs() << "EarlyCSE considering assumption: " << *Inst << '\n'); + LLVM_DEBUG(dbgs() << "EarlyCSE considering assumption: " << *Inst + << '\n'); AvailableValues.insert(CondI, ConstantInt::getTrue(BB->getContext())); } else - DEBUG(dbgs() << "EarlyCSE skipping assumption: " << *Inst << '\n'); + LLVM_DEBUG(dbgs() << "EarlyCSE skipping assumption: " << *Inst << '\n'); continue; } // Skip sideeffect intrinsics, for the same reason as assume intrinsics. if (match(Inst, m_Intrinsic<Intrinsic::sideeffect>())) { - DEBUG(dbgs() << "EarlyCSE skipping sideeffect: " << *Inst << '\n'); + LLVM_DEBUG(dbgs() << "EarlyCSE skipping sideeffect: " << *Inst << '\n'); continue; } - // Skip invariant.start intrinsics since they only read memory, and we can - // forward values across it. Also, we dont need to consume the last store - // since the semantics of invariant.start allow us to perform DSE of the - // last store, if there was a store following invariant.start. Consider: + // We can skip all invariant.start intrinsics since they only read memory, + // and we can forward values across it. For invariant starts without + // invariant ends, we can use the fact that the invariantness never ends to + // start a scope in the current generation, which is true for all future + // generations. Also, we don't need to consume the last store since the + // semantics of invariant.start allow us to perform DSE of the last + // store, if there was a store following invariant.start. Consider: + // // store 30, i8* p // invariant.start(p) // store 40, i8* p // We can DSE the store to 30, since the store 40 to invariant location p // causes undefined behaviour. - if (match(Inst, m_Intrinsic<Intrinsic::invariant_start>())) { + // If there are any uses, the scope might end. 
+ if (!Inst->use_empty()) + continue; + auto *CI = cast<CallInst>(Inst); + MemoryLocation MemLoc = MemoryLocation::getForArgument(CI, 1, TLI); + // Don't start a scope if we already have a better one pushed. + if (!AvailableInvariants.count(MemLoc)) + AvailableInvariants.insert(MemLoc, CurrentGeneration); continue; + } if (match(Inst, m_Intrinsic<Intrinsic::experimental_guard>())) { if (auto *CondI = @@ -767,7 +872,8 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // Is the condition known to be true? if (isa<ConstantInt>(KnownCond) && cast<ConstantInt>(KnownCond)->isOne()) { - DEBUG(dbgs() << "EarlyCSE removing guard: " << *Inst << '\n'); + LLVM_DEBUG(dbgs() + << "EarlyCSE removing guard: " << *Inst << '\n'); removeMSSA(Inst); Inst->eraseFromParent(); Changed = true; @@ -792,29 +898,39 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // If the instruction can be simplified (e.g. X+0 = X) then replace it with // its simpler value. if (Value *V = SimplifyInstruction(Inst, SQ)) { - DEBUG(dbgs() << "EarlyCSE Simplify: " << *Inst << " to: " << *V << '\n'); - bool Killed = false; - if (!Inst->use_empty()) { - Inst->replaceAllUsesWith(V); - Changed = true; - } - if (isInstructionTriviallyDead(Inst, &TLI)) { - removeMSSA(Inst); - Inst->eraseFromParent(); - Changed = true; - Killed = true; + LLVM_DEBUG(dbgs() << "EarlyCSE Simplify: " << *Inst << " to: " << *V + << '\n'); + if (!DebugCounter::shouldExecute(CSECounter)) { + LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); + } else { + bool Killed = false; + if (!Inst->use_empty()) { + Inst->replaceAllUsesWith(V); + Changed = true; + } + if (isInstructionTriviallyDead(Inst, &TLI)) { + removeMSSA(Inst); + Inst->eraseFromParent(); + Changed = true; + Killed = true; + } + if (Changed) + ++NumSimplify; + if (Killed) + continue; } - if (Changed) - ++NumSimplify; - if (Killed) - continue; } // If this is a simple instruction that we can value number, process it. if (SimpleValue::canHandle(Inst)) { // See if the instruction has an available value. If so, use it. if (Value *V = AvailableValues.lookup(Inst)) { - DEBUG(dbgs() << "EarlyCSE CSE: " << *Inst << " to: " << *V << '\n'); + LLVM_DEBUG(dbgs() << "EarlyCSE CSE: " << *Inst << " to: " << *V + << '\n'); + if (!DebugCounter::shouldExecute(CSECounter)) { + LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); + continue; + } if (auto *I = dyn_cast<Instruction>(V)) I->andIRFlags(Inst); Inst->replaceAllUsesWith(V); @@ -840,6 +956,17 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { ++CurrentGeneration; } + if (MemInst.isInvariantLoad()) { + // If we pass an invariant load, we know that memory location is + // indefinitely constant from the moment of first dereferenceability. + // We conservatively treat the invariant_load as that moment. If we + // pass an invariant load after already establishing a scope, don't + // restart it since we want to preserve the earliest point seen. + auto MemLoc = MemoryLocation::get(Inst); + if (!AvailableInvariants.count(MemLoc)) + AvailableInvariants.insert(MemLoc, CurrentGeneration); + } + // If we have an available version of this load, and if it is the right // generation or the load is known to be from an invariant location, // replace this instruction. @@ -854,13 +981,17 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { !MemInst.isVolatile() && MemInst.isUnordered() && // We can't replace an atomic load with one which isn't also atomic.
InVal.IsAtomic >= MemInst.isAtomic() && - (InVal.IsInvariant || MemInst.isInvariantLoad() || + (isOperatingOnInvariantMemAt(Inst, InVal.Generation) || isSameMemGeneration(InVal.Generation, CurrentGeneration, InVal.DefInst, Inst))) { Value *Op = getOrCreateResult(InVal.DefInst, Inst->getType()); if (Op != nullptr) { - DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << *Inst - << " to: " << *InVal.DefInst << '\n'); + LLVM_DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << *Inst + << " to: " << *InVal.DefInst << '\n'); + if (!DebugCounter::shouldExecute(CSECounter)) { + LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); + continue; + } if (!Inst->use_empty()) Inst->replaceAllUsesWith(Op); removeMSSA(Inst); @@ -875,7 +1006,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { AvailableLoads.insert( MemInst.getPointerOperand(), LoadValue(Inst, CurrentGeneration, MemInst.getMatchingId(), - MemInst.isAtomic(), MemInst.isInvariantLoad())); + MemInst.isAtomic())); LastStore = nullptr; continue; } @@ -898,8 +1029,12 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { if (InVal.first != nullptr && isSameMemGeneration(InVal.second, CurrentGeneration, InVal.first, Inst)) { - DEBUG(dbgs() << "EarlyCSE CSE CALL: " << *Inst - << " to: " << *InVal.first << '\n'); + LLVM_DEBUG(dbgs() << "EarlyCSE CSE CALL: " << *Inst + << " to: " << *InVal.first << '\n'); + if (!DebugCounter::shouldExecute(CSECounter)) { + LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); + continue; + } if (!Inst->use_empty()) Inst->replaceAllUsesWith(InVal.first); removeMSSA(Inst); @@ -938,8 +1073,9 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { InVal.MatchingId == MemInst.getMatchingId() && // We don't yet handle removing stores with ordering of any kind. !MemInst.isVolatile() && MemInst.isUnordered() && - isSameMemGeneration(InVal.Generation, CurrentGeneration, - InVal.DefInst, Inst)) { + (isOperatingOnInvariantMemAt(Inst, InVal.Generation) || + isSameMemGeneration(InVal.Generation, CurrentGeneration, + InVal.DefInst, Inst))) { // It is okay to have a LastStore to a different pointer here if MemorySSA // tells us that the load and store are from the same memory generation. 
// In that case, LastStore should keep its present value since we're @@ -949,7 +1085,11 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { MemInst.getPointerOperand() || MSSA) && "can't have an intervening store if not using MemorySSA!"); - DEBUG(dbgs() << "EarlyCSE DSE (writeback): " << *Inst << '\n'); + LLVM_DEBUG(dbgs() << "EarlyCSE DSE (writeback): " << *Inst << '\n'); + if (!DebugCounter::shouldExecute(CSECounter)) { + LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); + continue; + } removeMSSA(Inst); Inst->eraseFromParent(); Changed = true; @@ -980,13 +1120,17 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { !LastStoreMemInst.isVolatile() && "Violated invariant"); if (LastStoreMemInst.isMatchingMemLoc(MemInst)) { - DEBUG(dbgs() << "EarlyCSE DEAD STORE: " << *LastStore - << " due to: " << *Inst << '\n'); - removeMSSA(LastStore); - LastStore->eraseFromParent(); - Changed = true; - ++NumDSE; - LastStore = nullptr; + LLVM_DEBUG(dbgs() << "EarlyCSE DEAD STORE: " << *LastStore + << " due to: " << *Inst << '\n'); + if (!DebugCounter::shouldExecute(CSECounter)) { + LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); + } else { + removeMSSA(LastStore); + LastStore->eraseFromParent(); + Changed = true; + ++NumDSE; + LastStore = nullptr; + } } // fallthrough - we can exploit information about this store } @@ -999,7 +1143,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { AvailableLoads.insert( MemInst.getPointerOperand(), LoadValue(Inst, CurrentGeneration, MemInst.getMatchingId(), - MemInst.isAtomic(), /*IsInvariant=*/false)); + MemInst.isAtomic())); // Remember that this was the last unordered store we saw for DSE. We // don't yet handle DSE on ordered or volatile stores since we don't @@ -1031,8 +1175,9 @@ bool EarlyCSE::run() { // Process the root node. nodesToProcess.push_back(new StackNode( - AvailableValues, AvailableLoads, AvailableCalls, CurrentGeneration, - DT.getRootNode(), DT.getRootNode()->begin(), DT.getRootNode()->end())); + AvailableValues, AvailableLoads, AvailableInvariants, AvailableCalls, + CurrentGeneration, DT.getRootNode(), + DT.getRootNode()->begin(), DT.getRootNode()->end())); // Save the current generation. unsigned LiveOutGeneration = CurrentGeneration; @@ -1056,9 +1201,9 @@ bool EarlyCSE::run() { // Push the next child onto the stack. DomTreeNode *child = NodeToProcess->nextChild(); nodesToProcess.push_back( - new StackNode(AvailableValues, AvailableLoads, AvailableCalls, - NodeToProcess->childGeneration(), child, child->begin(), - child->end())); + new StackNode(AvailableValues, AvailableLoads, AvailableInvariants, + AvailableCalls, NodeToProcess->childGeneration(), + child, child->begin(), child->end())); } else { // It has been processed, and there are no more children to process, // so delete it and pop it off the stack. @@ -1097,7 +1242,7 @@ PreservedAnalyses EarlyCSEPass::run(Function &F, namespace { -/// \brief A simple and fast domtree-based CSE pass. +/// A simple and fast domtree-based CSE pass. 
/// /// This pass does a simple depth-first walk over the dominator tree, /// eliminating trivially redundant instructions and using instsimplify to diff --git a/lib/Transforms/Scalar/FlattenCFGPass.cpp b/lib/Transforms/Scalar/FlattenCFGPass.cpp index 063df779a30b..117b19fb8a42 100644 --- a/lib/Transforms/Scalar/FlattenCFGPass.cpp +++ b/lib/Transforms/Scalar/FlattenCFGPass.cpp @@ -12,10 +12,10 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/CFG.h" #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" using namespace llvm; #define DEBUG_TYPE "flattencfg" diff --git a/lib/Transforms/Scalar/Float2Int.cpp b/lib/Transforms/Scalar/Float2Int.cpp index b105ece8dc7c..f2828e80bc58 100644 --- a/lib/Transforms/Scalar/Float2Int.cpp +++ b/lib/Transforms/Scalar/Float2Int.cpp @@ -138,7 +138,7 @@ void Float2IntPass::findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots) { // Helper - mark I as having been traversed, having range R. void Float2IntPass::seen(Instruction *I, ConstantRange R) { - DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n"); + LLVM_DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n"); auto IT = SeenInsts.find(I); if (IT != SeenInsts.end()) IT->second = std::move(R); @@ -359,7 +359,7 @@ bool Float2IntPass::validateAndTransform() { for (User *U : I->users()) { Instruction *UI = dyn_cast<Instruction>(U); if (!UI || SeenInsts.find(UI) == SeenInsts.end()) { - DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n"); + LLVM_DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n"); Fail = true; break; } @@ -380,7 +380,7 @@ bool Float2IntPass::validateAndTransform() { // lower limits, plus one so it can be signed. unsigned MinBW = std::max(R.getLower().getMinSignedBits(), R.getUpper().getMinSignedBits()) + 1; - DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n"); + LLVM_DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n"); // If we've run off the realms of the exactly representable integers, // the floating point result will differ from an integer approximation. @@ -391,11 +391,12 @@ bool Float2IntPass::validateAndTransform() { unsigned MaxRepresentableBits = APFloat::semanticsPrecision(ConvertedToTy->getFltSemantics()) - 1; if (MinBW > MaxRepresentableBits) { - DEBUG(dbgs() << "F2I: Value not guaranteed to be representable!\n"); + LLVM_DEBUG(dbgs() << "F2I: Value not guaranteed to be representable!\n"); continue; } if (MinBW > 64) { - DEBUG(dbgs() << "F2I: Value requires more than 64 bits to represent!\n"); + LLVM_DEBUG( + dbgs() << "F2I: Value requires more than 64 bits to represent!\n"); continue; } @@ -490,7 +491,7 @@ void Float2IntPass::cleanup() { } bool Float2IntPass::runImpl(Function &F) { - DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n"); + LLVM_DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n"); // Clear out all state. 
ECs = EquivalenceClasses<Instruction*>(); SeenInsts.clear(); diff --git a/lib/Transforms/Scalar/GVN.cpp b/lib/Transforms/Scalar/GVN.cpp index e2c1eaf58e43..1e0a22cb14b3 100644 --- a/lib/Transforms/Scalar/GVN.cpp +++ b/lib/Transforms/Scalar/GVN.cpp @@ -38,7 +38,9 @@ #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/PHITransAddr.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Config/llvm-config.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" @@ -69,7 +71,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include "llvm/Transforms/Utils/VNCoercion.h" #include <algorithm> @@ -765,6 +766,15 @@ static Value *ConstructSSAForLoadSet(LoadInst *LI, if (SSAUpdate.HasValueForBlock(BB)) continue; + // If the value is the load that we will be eliminating, and the block it's + // available in is the block that the load is in, then don't add it, as + // SSAUpdater will resolve the value to the relevant phi, which may let it + // avoid phi construction entirely if there's actually only one value. + if (BB == LI->getParent() && + ((AV.AV.isSimpleValue() && AV.AV.getSimpleValue() == LI) || + (AV.AV.isCoercedLoadValue() && AV.AV.getCoercedLoadValue() == LI))) + continue; + SSAUpdate.AddAvailableValue(BB, AV.MaterializeAdjustedValue(LI, gvn)); } @@ -783,9 +793,10 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *LI, if (Res->getType() != LoadTy) { Res = getStoreValueForLoad(Res, Offset, LoadTy, InsertPt, DL); - DEBUG(dbgs() << "GVN COERCED NONLOCAL VAL:\nOffset: " << Offset << " " - << *getSimpleValue() << '\n' - << *Res << '\n' << "\n\n\n"); + LLVM_DEBUG(dbgs() << "GVN COERCED NONLOCAL VAL:\nOffset: " << Offset + << " " << *getSimpleValue() << '\n' + << *Res << '\n' + << "\n\n\n"); } } else if (isCoercedLoadValue()) { LoadInst *Load = getCoercedLoadValue(); @@ -799,20 +810,21 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *LI, // but then all of the operations based on it would need to be // rehashed. Just leave the dead load around.
gvn.getMemDep().removeInstruction(Load); - DEBUG(dbgs() << "GVN COERCED NONLOCAL LOAD:\nOffset: " << Offset << " " - << *getCoercedLoadValue() << '\n' - << *Res << '\n' - << "\n\n\n"); + LLVM_DEBUG(dbgs() << "GVN COERCED NONLOCAL LOAD:\nOffset: " << Offset + << " " << *getCoercedLoadValue() << '\n' + << *Res << '\n' + << "\n\n\n"); } } else if (isMemIntrinValue()) { Res = getMemInstValueForLoad(getMemIntrinValue(), Offset, LoadTy, InsertPt, DL); - DEBUG(dbgs() << "GVN COERCED NONLOCAL MEM INTRIN:\nOffset: " << Offset - << " " << *getMemIntrinValue() << '\n' - << *Res << '\n' << "\n\n\n"); + LLVM_DEBUG(dbgs() << "GVN COERCED NONLOCAL MEM INTRIN:\nOffset: " << Offset + << " " << *getMemIntrinValue() << '\n' + << *Res << '\n' + << "\n\n\n"); } else { assert(isUndefValue() && "Should be UndefVal"); - DEBUG(dbgs() << "GVN COERCED NONLOCAL Undef:\n";); + LLVM_DEBUG(dbgs() << "GVN COERCED NONLOCAL Undef:\n";); return UndefValue::get(LoadTy); } assert(Res && "failed to materialize?"); @@ -825,7 +837,7 @@ static bool isLifetimeStart(const Instruction *Inst) { return false; } -/// \brief Try to locate the three instruction involved in a missed +/// Try to locate the three instructions involved in a missed /// load-elimination case that is due to an intervening store. static void reportMayClobberedLoad(LoadInst *LI, MemDepResult DepInfo, DominatorTree *DT, @@ -914,13 +926,11 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, } } // Nothing known about this clobber, have to be conservative - DEBUG( - // fast print dep, using operator<< on instruction is too slow. - dbgs() << "GVN: load "; - LI->printAsOperand(dbgs()); - Instruction *I = DepInfo.getInst(); - dbgs() << " is clobbered by " << *I << '\n'; - ); + LLVM_DEBUG( + // fast print dep, using operator<< on instruction is too slow. + dbgs() << "GVN: load "; LI->printAsOperand(dbgs()); + Instruction *I = DepInfo.getInst(); + dbgs() << " is clobbered by " << *I << '\n';); if (ORE->allowExtraAnalysis(DEBUG_TYPE)) reportMayClobberedLoad(LI, DepInfo, DT, ORE); @@ -978,12 +988,10 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, } // Unknown def - must be conservative - DEBUG( - // fast print dep, using operator<< on instruction is too slow. - dbgs() << "GVN: load "; - LI->printAsOperand(dbgs()); - dbgs() << " has unknown def " << *DepInst << '\n'; - ); + LLVM_DEBUG( + // fast print dep, using operator<< on instruction is too slow. + dbgs() << "GVN: load "; LI->printAsOperand(dbgs()); + dbgs() << " has unknown def " << *DepInst << '\n';); return false; } @@ -1065,7 +1073,7 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, // It is illegal to move the array access to any point above the guard, // because if the index is out of bounds we should deoptimize rather than // access the array. - // Check that there is no guard in this block above our intruction. + // Check that there is no guard in this block above our instruction. if (!IsSafeToSpeculativelyExecute) { auto It = FirstImplicitControlFlowInsts.find(TmpBB); if (It != FirstImplicitControlFlowInsts.end()) { @@ -1113,9 +1121,9 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, // If any predecessor block is an EH pad that does not allow non-PHI // instructions before the terminator, we can't PRE the load.
if (Pred->getTerminator()->isEHPad()) { - DEBUG(dbgs() - << "COULD NOT PRE LOAD BECAUSE OF AN EH PAD PREDECESSOR '" - << Pred->getName() << "': " << *LI << '\n'); + LLVM_DEBUG( + dbgs() << "COULD NOT PRE LOAD BECAUSE OF AN EH PAD PREDECESSOR '" + << Pred->getName() << "': " << *LI << '\n'); return false; } @@ -1125,15 +1133,16 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, if (Pred->getTerminator()->getNumSuccessors() != 1) { if (isa<IndirectBrInst>(Pred->getTerminator())) { - DEBUG(dbgs() << "COULD NOT PRE LOAD BECAUSE OF INDBR CRITICAL EDGE '" - << Pred->getName() << "': " << *LI << '\n'); + LLVM_DEBUG( + dbgs() << "COULD NOT PRE LOAD BECAUSE OF INDBR CRITICAL EDGE '" + << Pred->getName() << "': " << *LI << '\n'); return false; } if (LoadBB->isEHPad()) { - DEBUG(dbgs() - << "COULD NOT PRE LOAD BECAUSE OF AN EH PAD CRITICAL EDGE '" - << Pred->getName() << "': " << *LI << '\n'); + LLVM_DEBUG( + dbgs() << "COULD NOT PRE LOAD BECAUSE OF AN EH PAD CRITICAL EDGE '" + << Pred->getName() << "': " << *LI << '\n'); return false; } @@ -1161,8 +1170,8 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, BasicBlock *NewPred = splitCriticalEdges(OrigPred, LoadBB); assert(!PredLoads.count(OrigPred) && "Split edges shouldn't be in map!"); PredLoads[NewPred] = nullptr; - DEBUG(dbgs() << "Split critical edge " << OrigPred->getName() << "->" - << LoadBB->getName() << '\n'); + LLVM_DEBUG(dbgs() << "Split critical edge " << OrigPred->getName() << "->" + << LoadBB->getName() << '\n'); } // Check if the load can safely be moved to all the unavailable predecessors. @@ -1186,8 +1195,8 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, // If we couldn't find or insert a computation of this phi translated value, // we fail PRE. if (!LoadPtr) { - DEBUG(dbgs() << "COULDN'T INSERT PHI TRANSLATED VALUE OF: " - << *LI->getPointerOperand() << "\n"); + LLVM_DEBUG(dbgs() << "COULDN'T INSERT PHI TRANSLATED VALUE OF: " + << *LI->getPointerOperand() << "\n"); CanDoPRE = false; break; } @@ -1208,10 +1217,10 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, // Okay, we can eliminate this load by inserting a reload in the predecessor // and using PHI construction to get the value in the other predecessors, do // it. - DEBUG(dbgs() << "GVN REMOVING PRE LOAD: " << *LI << '\n'); - DEBUG(if (!NewInsts.empty()) - dbgs() << "INSERTED " << NewInsts.size() << " INSTS: " - << *NewInsts.back() << '\n'); + LLVM_DEBUG(dbgs() << "GVN REMOVING PRE LOAD: " << *LI << '\n'); + LLVM_DEBUG(if (!NewInsts.empty()) dbgs() + << "INSERTED " << NewInsts.size() << " INSTS: " << *NewInsts.back() + << '\n'); // Assign value numbers to the new instructions. for (Instruction *I : NewInsts) { @@ -1262,7 +1271,7 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, ValuesPerBlock.push_back(AvailableValueInBlock::get(UnavailablePred, NewLoad)); MD->invalidateCachedPointerInfo(LoadPtr); - DEBUG(dbgs() << "GVN INSERTED " << *NewLoad << '\n'); + LLVM_DEBUG(dbgs() << "GVN INSERTED " << *NewLoad << '\n'); } // Perform PHI construction. @@ -1320,11 +1329,8 @@ bool GVN::processNonLocalLoad(LoadInst *LI) { // clobber in the current block. Reject this early. 
if (NumDeps == 1 && !Deps[0].getResult().isDef() && !Deps[0].getResult().isClobber()) { - DEBUG( - dbgs() << "GVN: non-local load "; - LI->printAsOperand(dbgs()); - dbgs() << " has unknown dependencies\n"; - ); + LLVM_DEBUG(dbgs() << "GVN: non-local load "; LI->printAsOperand(dbgs()); + dbgs() << " has unknown dependencies\n";); return false; } @@ -1353,7 +1359,7 @@ bool GVN::processNonLocalLoad(LoadInst *LI) { // load, then it is fully redundant and we can use PHI insertion to compute // its value. Insert PHIs and remove the fully redundant value now. if (UnavailableBlocks.empty()) { - DEBUG(dbgs() << "GVN REMOVING NONLOCAL LOAD: " << *LI << '\n'); + LLVM_DEBUG(dbgs() << "GVN REMOVING NONLOCAL LOAD: " << *LI << '\n'); // Perform PHI construction. Value *V = ConstructSSAForLoadSet(LI, ValuesPerBlock, *this); @@ -1506,12 +1512,10 @@ bool GVN::processLoad(LoadInst *L) { // Only handle the local case below if (!Dep.isDef() && !Dep.isClobber()) { // This might be a NonFuncLocal or an Unknown - DEBUG( - // fast print dep, using operator<< on instruction is too slow. - dbgs() << "GVN: load "; - L->printAsOperand(dbgs()); - dbgs() << " has unknown dependence\n"; - ); + LLVM_DEBUG( + // fast print dep, using operator<< on instruction is too slow. + dbgs() << "GVN: load "; L->printAsOperand(dbgs()); + dbgs() << " has unknown dependence\n";); return false; } @@ -1695,8 +1699,8 @@ bool GVN::replaceOperandsWithConsts(Instruction *Instr) const { if (it != ReplaceWithConstMap.end()) { assert(!isa<Constant>(Operand) && "Replacing constants with constants is invalid"); - DEBUG(dbgs() << "GVN replacing: " << *Operand << " with " << *it->second - << " in instruction " << *Instr << '\n'); + LLVM_DEBUG(dbgs() << "GVN replacing: " << *Operand << " with " + << *it->second << " in instruction " << *Instr << '\n'); Instr->setOperand(OpNum, it->second); Changed = true; } @@ -2038,7 +2042,7 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, unsigned Iteration = 0; while (ShouldContinue) { - DEBUG(dbgs() << "GVN iteration: " << Iteration << "\n"); + LLVM_DEBUG(dbgs() << "GVN iteration: " << Iteration << "\n"); ShouldContinue = iterateOnFunction(F); Changed |= ShouldContinue; ++Iteration; @@ -2104,9 +2108,10 @@ bool GVN::processBlock(BasicBlock *BB) { const Instruction *MaybeFirstICF = FirstImplicitControlFlowInsts.lookup(BB); for (auto *I : InstrsToErase) { assert(I->getParent() == BB && "Removing instruction from wrong block?"); - DEBUG(dbgs() << "GVN removed: " << *I << '\n'); + LLVM_DEBUG(dbgs() << "GVN removed: " << *I << '\n'); + salvageDebugInfo(*I); if (MD) MD->removeInstruction(I); - DEBUG(verifyRemoved(I)); + LLVM_DEBUG(verifyRemoved(I)); if (MaybeFirstICF == I) { // We have erased the first ICF in block. The map needs to be updated. InvalidateImplicitCF = true; @@ -2288,7 +2293,7 @@ bool GVN::performScalarPRE(Instruction *CurInst) { PREInstr = CurInst->clone(); if (!performScalarPREInsertion(PREInstr, PREPred, CurrentBlock, ValNo)) { // If we failed insertion, make sure we remove the instruction. 
- DEBUG(verifyRemoved(PREInstr)); + LLVM_DEBUG(verifyRemoved(PREInstr)); PREInstr->deleteValue(); return false; } @@ -2326,10 +2331,10 @@ bool GVN::performScalarPRE(Instruction *CurInst) { VN.erase(CurInst); removeFromLeaderTable(ValNo, CurInst, CurrentBlock); - DEBUG(dbgs() << "GVN PRE removed: " << *CurInst << '\n'); + LLVM_DEBUG(dbgs() << "GVN PRE removed: " << *CurInst << '\n'); if (MD) MD->removeInstruction(CurInst); - DEBUG(verifyRemoved(CurInst)); + LLVM_DEBUG(verifyRemoved(CurInst)); bool InvalidateImplicitCF = FirstImplicitControlFlowInsts.lookup(CurInst->getParent()) == CurInst; // FIXME: Intended to be markInstructionForDeletion(CurInst), but it causes diff --git a/lib/Transforms/Scalar/GVNHoist.cpp b/lib/Transforms/Scalar/GVNHoist.cpp index c0cd1ea74a74..6d2b25cf6013 100644 --- a/lib/Transforms/Scalar/GVNHoist.cpp +++ b/lib/Transforms/Scalar/GVNHoist.cpp @@ -48,6 +48,7 @@ #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/PostDominators.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" @@ -72,7 +73,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVN.h" -#include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> #include <iterator> @@ -534,7 +534,7 @@ private: if (NewBB == DBB && !MSSA->isLiveOnEntryDef(D)) if (auto *UD = dyn_cast<MemoryUseOrDef>(D)) - if (firstInBB(NewPt, UD->getMemoryInst())) + if (!firstInBB(UD->getMemoryInst(), NewPt)) // Cannot move the load or store to NewPt above its definition in D. return false; @@ -570,7 +570,7 @@ private: // The idea is inspired by: // "Partial Redundancy Elimination in SSA Form" // ROBERT KENNEDY, SUN CHAN, SHIN-MING LIU, RAYMOND LO, PENG TU and FRED CHOW - // They use similar idea in the forward graph to to find fully redundant and + // They use a similar idea in the forward graph to find fully redundant and // partially redundant expressions; here it is used in the inverse graph to // find fully anticipable instructions at merge point (post-dominator in // the inverse CFG). @@ -578,7 +578,7 @@ private: // Returns true when the values are flowing out to each edge. bool valueAnticipable(CHIArgs C, TerminatorInst *TI) const { - if (TI->getNumSuccessors() > (unsigned)std::distance(C.begin(), C.end())) + if (TI->getNumSuccessors() > (unsigned)size(C)) return false; // Not enough args in this CHI. for (auto CHI : C) { @@ -622,7 +622,7 @@ private: // Iterate in reverse order to keep lower ranked values on the top. for (std::pair<VNType, Instruction *> &VI : reverse(it1->second)) { // Get the value of instruction I - DEBUG(dbgs() << "\nPushing on stack: " << *VI.second); + LLVM_DEBUG(dbgs() << "\nPushing on stack: " << *VI.second); RenameStack[VI.first].push_back(VI.second); } } @@ -636,7 +636,7 @@ private: if (P == CHIBBs.end()) { continue; } - DEBUG(dbgs() << "\nLooking at CHIs in: " << Pred->getName();); + LLVM_DEBUG(dbgs() << "\nLooking at CHIs in: " << Pred->getName();); // A CHI is found (BB -> Pred is an edge in the CFG) // Pop the stack until Top(V) = Ve. auto &VCHI = P->second; @@ -648,12 +648,12 @@ private: // track in a CHI. In the PDom walk, there can be values in the // stack which are not control dependent e.g., nested loop.
if (si != RenameStack.end() && si->second.size() && - DT->dominates(Pred, si->second.back()->getParent())) { + DT->properlyDominates(Pred, si->second.back()->getParent())) { C.Dest = BB; // Assign the edge C.I = si->second.pop_back_val(); // Assign the argument - DEBUG(dbgs() << "\nCHI Inserted in BB: " << C.Dest->getName() - << *C.I << ", VN: " << C.VN.first << ", " - << C.VN.second); + LLVM_DEBUG(dbgs() + << "\nCHI Inserted in BB: " << C.Dest->getName() << *C.I + << ", VN: " << C.VN.first << ", " << C.VN.second); } // Move to next CHI of a different value It = std::find_if(It, VCHI.end(), @@ -748,11 +748,11 @@ private: // TODO: Remove fully-redundant expressions. // Get instruction from the Map, assume that all the Instructions // with same VNs have same rank (this is an approximation). - std::sort(Ranks.begin(), Ranks.end(), - [this, &Map](const VNType &r1, const VNType &r2) { - return (rank(*Map.lookup(r1).begin()) < - rank(*Map.lookup(r2).begin())); - }); + llvm::sort(Ranks.begin(), Ranks.end(), + [this, &Map](const VNType &r1, const VNType &r2) { + return (rank(*Map.lookup(r1).begin()) < + rank(*Map.lookup(r2).begin())); + }); // - Sort VNs according to their rank, and start with lowest ranked VN // - Take a VN and for each instruction with same VN @@ -798,8 +798,8 @@ private: // Ignore spurious PDFs. if (DT->properlyDominates(IDFB, V[i]->getParent())) { OutValue[IDFB].push_back(C); - DEBUG(dbgs() << "\nInsertion a CHI for BB: " << IDFB->getName() - << ", for Insn: " << *V[i]); + LLVM_DEBUG(dbgs() << "\nInsertion a CHI for BB: " << IDFB->getName() + << ", for Insn: " << *V[i]); } } } @@ -1200,6 +1200,7 @@ INITIALIZE_PASS_BEGIN(GVNHoistLegacyPass, "gvn-hoist", INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_END(GVNHoistLegacyPass, "gvn-hoist", "Early GVN Hoisting of Expressions", false, false) diff --git a/lib/Transforms/Scalar/GVNSink.cpp b/lib/Transforms/Scalar/GVNSink.cpp index bf92e43c4715..28c5940db1e0 100644 --- a/lib/Transforms/Scalar/GVNSink.cpp +++ b/lib/Transforms/Scalar/GVNSink.cpp @@ -48,6 +48,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" @@ -71,7 +72,6 @@ #include "llvm/Transforms/Scalar/GVN.h" #include "llvm/Transforms/Scalar/GVNExpression.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> #include <cstddef> @@ -239,7 +239,7 @@ public: SmallVector<std::pair<BasicBlock *, Value *>, 4> Ops; for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) Ops.push_back({PN->getIncomingBlock(I), PN->getIncomingValue(I)}); - std::sort(Ops.begin(), Ops.end()); + llvm::sort(Ops.begin(), Ops.end()); for (auto &P : Ops) { Blocks.push_back(P.first); Values.push_back(P.second); @@ -361,7 +361,7 @@ public: for (auto &U : I->uses()) op_push_back(U.getUser()); - std::sort(op_begin(), op_end()); + llvm::sort(op_begin(), op_end()); } void setMemoryUseOrder(unsigned MUO) { MemoryUseOrder = MUO; } @@ -561,7 +561,8 @@ public: GVNSink() = default; bool run(Function &F) { - DEBUG(dbgs() << "GVNSink: running on function @" << F.getName() << "\n"); + LLVM_DEBUG(dbgs() << "GVNSink: 
running on function @" << F.getName() + << "\n"); unsigned NumSunk = 0; ReversePostOrderTraversal<Function*> RPOT(&F); @@ -592,12 +593,8 @@ private: /// Create a ModelledPHI for each PHI in BB, adding to PHIs. void analyzeInitialPHIs(BasicBlock *BB, ModelledPHISet &PHIs, SmallPtrSetImpl<Value *> &PHIContents) { - for (auto &I : *BB) { - auto *PN = dyn_cast<PHINode>(&I); - if (!PN) - return; - - auto MPHI = ModelledPHI(PN); + for (PHINode &PN : BB->phis()) { + auto MPHI = ModelledPHI(&PN); PHIs.insert(MPHI); for (auto *V : MPHI.getValues()) PHIContents.insert(V); @@ -633,15 +630,15 @@ Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking( LockstepReverseIterator &LRI, unsigned &InstNum, unsigned &MemoryInstNum, ModelledPHISet &NeededPHIs, SmallPtrSetImpl<Value *> &PHIContents) { auto Insts = *LRI; - DEBUG(dbgs() << " -- Analyzing instruction set: [\n"; for (auto *I - : Insts) { + LLVM_DEBUG(dbgs() << " -- Analyzing instruction set: [\n"; for (auto *I + : Insts) { I->dump(); } dbgs() << " ]\n";); DenseMap<uint32_t, unsigned> VNums; for (auto *I : Insts) { uint32_t N = VN.lookupOrAdd(I); - DEBUG(dbgs() << " VN=" << Twine::utohexstr(N) << " for" << *I << "\n"); + LLVM_DEBUG(dbgs() << " VN=" << Twine::utohexstr(N) << " for" << *I << "\n"); if (N == ~0U) return None; VNums[N]++; @@ -753,8 +750,8 @@ Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking( } unsigned GVNSink::sinkBB(BasicBlock *BBEnd) { - DEBUG(dbgs() << "GVNSink: running on basic block "; - BBEnd->printAsOperand(dbgs()); dbgs() << "\n"); + LLVM_DEBUG(dbgs() << "GVNSink: running on basic block "; + BBEnd->printAsOperand(dbgs()); dbgs() << "\n"); SmallVector<BasicBlock *, 4> Preds; for (auto *B : predecessors(BBEnd)) { auto *T = B->getTerminator(); @@ -765,7 +762,7 @@ unsigned GVNSink::sinkBB(BasicBlock *BBEnd) { } if (Preds.size() < 2) return 0; - std::sort(Preds.begin(), Preds.end()); + llvm::sort(Preds.begin(), Preds.end()); unsigned NumOrigPreds = Preds.size(); // We can only sink instructions through unconditional branches. @@ -798,23 +795,23 @@ unsigned GVNSink::sinkBB(BasicBlock *BBEnd) { Candidates.begin(), Candidates.end(), [](const SinkingInstructionCandidate &A, const SinkingInstructionCandidate &B) { return A > B; }); - DEBUG(dbgs() << " -- Sinking candidates:\n"; for (auto &C - : Candidates) dbgs() - << " " << C << "\n";); + LLVM_DEBUG(dbgs() << " -- Sinking candidates:\n"; for (auto &C + : Candidates) dbgs() + << " " << C << "\n";); // Pick the top candidate, as long it is positive! if (Candidates.empty() || Candidates.front().Cost <= 0) return 0; auto C = Candidates.front(); - DEBUG(dbgs() << " -- Sinking: " << C << "\n"); + LLVM_DEBUG(dbgs() << " -- Sinking: " << C << "\n"); BasicBlock *InsertBB = BBEnd; if (C.Blocks.size() < NumOrigPreds) { - DEBUG(dbgs() << " -- Splitting edge to "; BBEnd->printAsOperand(dbgs()); - dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " -- Splitting edge to "; + BBEnd->printAsOperand(dbgs()); dbgs() << "\n"); InsertBB = SplitBlockPredecessors(BBEnd, C.Blocks, ".gvnsink.split"); if (!InsertBB) { - DEBUG(dbgs() << " -- FAILED to split edge!\n"); + LLVM_DEBUG(dbgs() << " -- FAILED to split edge!\n"); // Edge couldn't be split. 
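The std::sort to llvm::sort substitutions scattered through this diff are behavior-preserving in normal builds: llvm::sort, from llvm/ADT/STLExtras.h, forwards to std::sort, but under EXPENSIVE_CHECKS it shuffles the range first, so a comparator that is not a strict weak ordering, or output that silently depends on input order, shows up as a reproducible test failure instead of latent non-determinism. A sketch with a hypothetical caller:

  #include "llvm/ADT/STLExtras.h"
  #include "llvm/ADT/SmallVector.h"
  #include "llvm/IR/BasicBlock.h"

  void canonicalize(llvm::SmallVectorImpl<llvm::BasicBlock *> &Preds) {
    // Drop-in replacement for std::sort; with EXPENSIVE_CHECKS the input is
    // randomized before sorting, flushing out under-specified comparators.
    llvm::sort(Preds.begin(), Preds.end());
  }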
return 0; } diff --git a/lib/Transforms/Scalar/GuardWidening.cpp b/lib/Transforms/Scalar/GuardWidening.cpp index c4aeccb85ca7..ad1598d7b8bf 100644 --- a/lib/Transforms/Scalar/GuardWidening.cpp +++ b/lib/Transforms/Scalar/GuardWidening.cpp @@ -40,9 +40,11 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/GuardWidening.h" +#include <functional> #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/ConstantRange.h" @@ -53,6 +55,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; @@ -62,9 +65,14 @@ namespace { class GuardWideningImpl { DominatorTree &DT; - PostDominatorTree &PDT; + PostDominatorTree *PDT; LoopInfo &LI; + /// Together, these describe the region of interest. This might be all of + /// the blocks within a function, or only a given loop's blocks and preheader. + DomTreeNode *Root; + std::function<bool(BasicBlock*)> BlockFilter; + /// The set of guards whose conditions have been widened into dominating /// guards. SmallVector<IntrinsicInst *, 16> EliminatedGuards; @@ -205,39 +213,15 @@ class GuardWideningImpl { } public: - explicit GuardWideningImpl(DominatorTree &DT, PostDominatorTree &PDT, - LoopInfo &LI) - : DT(DT), PDT(PDT), LI(LI) {} + + explicit GuardWideningImpl(DominatorTree &DT, PostDominatorTree *PDT, + LoopInfo &LI, DomTreeNode *Root, + std::function<bool(BasicBlock*)> BlockFilter) + : DT(DT), PDT(PDT), LI(LI), Root(Root), BlockFilter(BlockFilter) {} /// The entry point for this pass. bool run(); }; - -struct GuardWideningLegacyPass : public FunctionPass { - static char ID; - GuardWideningPass Impl; - - GuardWideningLegacyPass() : FunctionPass(ID) { - initializeGuardWideningLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - return GuardWideningImpl( - getAnalysis<DominatorTreeWrapperPass>().getDomTree(), - getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(), - getAnalysis<LoopInfoWrapperPass>().getLoopInfo()).run(); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<PostDominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - } -}; - } bool GuardWideningImpl::run() { @@ -246,9 +230,12 @@ bool GuardWideningImpl::run() { DenseMap<BasicBlock *, SmallVector<IntrinsicInst *, 8>> GuardsInBlock; bool Changed = false; - for (auto DFI = df_begin(DT.getRootNode()), DFE = df_end(DT.getRootNode()); + for (auto DFI = df_begin(Root), DFE = df_end(Root); DFI != DFE; ++DFI) { auto *BB = (*DFI)->getBlock(); + if (!BlockFilter(BB)) + continue; + auto &CurrentList = GuardsInBlock[BB]; for (auto &I : *BB) @@ -259,6 +246,7 @@ bool GuardWideningImpl::run() { Changed |= eliminateGuardViaWidening(II, DFI, GuardsInBlock); } + assert(EliminatedGuards.empty() || Changed); for (auto *II : EliminatedGuards) if (!WidenedGuards.count(II)) II->eraseFromParent(); @@ -278,6 +266,8 @@ bool GuardWideningImpl::eliminateGuardViaWidening( // for the most profit. 
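Once the candidate search that follows settles on the best dominating guard, the transform itself is small: the dominated condition is ANDed into the dominating guard, and the dominated guard is neutralized by setting its condition to true (it is erased later in run()). A simplified sketch of that final step under a hypothetical helper name; the pass's real widenGuard additionally hoists the condition's operands and merges range checks:

  #include "llvm/IR/Constants.h"
  #include "llvm/IR/IRBuilder.h"
  #include "llvm/IR/IntrinsicInst.h"

  // guard(A) dominating guard(B) becomes guard(A && B) plus guard(true).
  void widenInto(llvm::IntrinsicInst *Dominating,
                 llvm::IntrinsicInst *Dominated) {
    llvm::IRBuilder<> B(Dominating);
    llvm::Value *Wide = B.CreateAnd(Dominating->getArgOperand(0),
                                    Dominated->getArgOperand(0));
    Dominating->setArgOperand(0, Wide);
    Dominated->setArgOperand(
        0, llvm::ConstantInt::getTrue(Dominated->getContext()));
  }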
for (unsigned i = 0, e = DFSI.getPathLength(); i != e; ++i) { auto *CurBB = DFSI.getPath(i)->getBlock(); + if (!BlockFilter(CurBB)) + break; auto *CurLoop = LI.getLoopFor(CurBB); assert(GuardsInBlock.count(CurBB) && "Must have been populated by now!"); const auto &GuardsInCurBB = GuardsInBlock.find(CurBB)->second; @@ -312,9 +302,9 @@ bool GuardWideningImpl::eliminateGuardViaWidening( for (auto *Candidate : make_range(I, E)) { auto Score = computeWideningScore(GuardInst, GuardInstLoop, Candidate, CurLoop); - DEBUG(dbgs() << "Score between " << *GuardInst->getArgOperand(0) - << " and " << *Candidate->getArgOperand(0) << " is " - << scoreTypeToString(Score) << "\n"); + LLVM_DEBUG(dbgs() << "Score between " << *GuardInst->getArgOperand(0) + << " and " << *Candidate->getArgOperand(0) << " is " + << scoreTypeToString(Score) << "\n"); if (Score > BestScoreSoFar) { BestScoreSoFar = Score; BestSoFar = Candidate; @@ -323,15 +313,16 @@ bool GuardWideningImpl::eliminateGuardViaWidening( } if (BestScoreSoFar == WS_IllegalOrNegative) { - DEBUG(dbgs() << "Did not eliminate guard " << *GuardInst << "\n"); + LLVM_DEBUG(dbgs() << "Did not eliminate guard " << *GuardInst << "\n"); return false; } assert(BestSoFar != GuardInst && "Should have never visited same guard!"); assert(DT.dominates(BestSoFar, GuardInst) && "Should be!"); - DEBUG(dbgs() << "Widening " << *GuardInst << " into " << *BestSoFar - << " with score " << scoreTypeToString(BestScoreSoFar) << "\n"); + LLVM_DEBUG(dbgs() << "Widening " << *GuardInst << " into " << *BestSoFar + << " with score " << scoreTypeToString(BestScoreSoFar) + << "\n"); widenGuard(BestSoFar, GuardInst->getArgOperand(0)); GuardInst->setArgOperand(0, ConstantInt::getTrue(GuardInst->getContext())); EliminatedGuards.push_back(GuardInst); @@ -345,6 +336,8 @@ GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( bool HoistingOutOfLoop = false; if (DominatingGuardLoop != DominatedGuardLoop) { + // Be conservative and don't widen into a sibling loop. TODO: If the + // sibling is colder, we should consider allowing this. if (DominatingGuardLoop && !DominatingGuardLoop->contains(DominatedGuardLoop)) return WS_IllegalOrNegative; @@ -355,9 +348,14 @@ GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( if (!isAvailableAt(DominatedGuard->getArgOperand(0), DominatingGuard)) return WS_IllegalOrNegative; - bool HoistingOutOfIf = - !PDT.dominates(DominatedGuard->getParent(), DominatingGuard->getParent()); - + // If the guard was conditional executed, it may never be reached + // dynamically. There are two potential downsides to hoisting it out of the + // conditionally executed region: 1) we may spuriously deopt without need and + // 2) we have the extra cost of computing the guard condition in the common + // case. At the moment, we really only consider the second in our heuristic + // here. TODO: evaluate cost model for spurious deopt + // NOTE: As written, this also lets us hoist right over another guard which + // is essentially just another spelling for control flow. if (isWideningCondProfitable(DominatedGuard->getArgOperand(0), DominatingGuard->getArgOperand(0))) return HoistingOutOfLoop ? WS_VeryPositive : WS_Positive; @@ -365,7 +363,26 @@ GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( if (HoistingOutOfLoop) return WS_Positive; - return HoistingOutOfIf ? WS_IllegalOrNegative : WS_Neutral; + // Returns true if we might be hoisting above explicit control flow. 
Note + // that this completely ignores implicit control flow (guards, calls which + // throw, etc...). That choice appears arbitrary. + auto MaybeHoistingOutOfIf = [&]() { + auto *DominatingBlock = DominatingGuard->getParent(); + auto *DominatedBlock = DominatedGuard->getParent(); + + // Same Block? + if (DominatedBlock == DominatingBlock) + return false; + // Obvious successor (common loop header/preheader case) + if (DominatedBlock == DominatingBlock->getUniqueSuccessor()) + return false; + // TODO: diamond, triangle cases + if (!PDT) return true; + return !PDT->dominates(DominatedGuard->getParent(), + DominatingGuard->getParent()); + }; + + return MaybeHoistingOutOfIf() ? WS_IllegalOrNegative : WS_Neutral; } bool GuardWideningImpl::isAvailableAt(Value *V, Instruction *Loc, @@ -581,9 +598,9 @@ bool GuardWideningImpl::combineRangeChecks( // CurrentChecks.size() will typically be 3 here, but so far there has been // no need to hard-code that fact. - std::sort(CurrentChecks.begin(), CurrentChecks.end(), - [&](const GuardWideningImpl::RangeCheck &LHS, - const GuardWideningImpl::RangeCheck &RHS) { + llvm::sort(CurrentChecks.begin(), CurrentChecks.end(), + [&](const GuardWideningImpl::RangeCheck &LHS, + const GuardWideningImpl::RangeCheck &RHS) { return LHS.getOffsetValue().slt(RHS.getOffsetValue()); }); @@ -651,19 +668,6 @@ bool GuardWideningImpl::combineRangeChecks( return RangeChecksOut.size() != OldCount; } -PreservedAnalyses GuardWideningPass::run(Function &F, - FunctionAnalysisManager &AM) { - auto &DT = AM.getResult<DominatorTreeAnalysis>(F); - auto &LI = AM.getResult<LoopAnalysis>(F); - auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); - if (!GuardWideningImpl(DT, PDT, LI).run()) - return PreservedAnalyses::all(); - - PreservedAnalyses PA; - PA.preserveSet<CFGAnalyses>(); - return PA; -} - #ifndef NDEBUG StringRef GuardWideningImpl::scoreTypeToString(WideningScore WS) { switch (WS) { @@ -681,7 +685,82 @@ StringRef GuardWideningImpl::scoreTypeToString(WideningScore WS) { } #endif +PreservedAnalyses GuardWideningPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); + if (!GuardWideningImpl(DT, &PDT, LI, DT.getRootNode(), + [](BasicBlock*) { return true; } ).run()) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; +} + +namespace { +struct GuardWideningLegacyPass : public FunctionPass { + static char ID; + + GuardWideningLegacyPass() : FunctionPass(ID) { + initializeGuardWideningLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); + return GuardWideningImpl(DT, &PDT, LI, DT.getRootNode(), + [](BasicBlock*) { return true; } ).run(); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<PostDominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + } +}; + +/// Same as above, but restricted to a single loop at a time. 
Can be +/// scheduled with other loop passes w/o breaking out of LPM +struct LoopGuardWideningLegacyPass : public LoopPass { + static char ID; + + LoopGuardWideningLegacyPass() : LoopPass(ID) { + initializeLoopGuardWideningLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) + return false; + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>(); + auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr; + BasicBlock *RootBB = L->getLoopPredecessor(); + if (!RootBB) + RootBB = L->getHeader(); + auto BlockFilter = [&](BasicBlock *BB) { + return BB == RootBB || L->contains(BB); + }; + return GuardWideningImpl(DT, PDT, LI, + DT.getNode(RootBB), BlockFilter).run(); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + getLoopAnalysisUsage(AU); + AU.addPreserved<PostDominatorTreeWrapperPass>(); + } +}; +} + char GuardWideningLegacyPass::ID = 0; +char LoopGuardWideningLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(GuardWideningLegacyPass, "guard-widening", "Widen guards", false, false) @@ -691,6 +770,20 @@ INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_END(GuardWideningLegacyPass, "guard-widening", "Widen guards", false, false) +INITIALIZE_PASS_BEGIN(LoopGuardWideningLegacyPass, "loop-guard-widening", + "Widen guards (within a single loop, as a loop pass)", + false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(LoopGuardWideningLegacyPass, "loop-guard-widening", + "Widen guards (within a single loop, as a loop pass)", + false, false) + FunctionPass *llvm::createGuardWideningPass() { return new GuardWideningLegacyPass(); } + +Pass *llvm::createLoopGuardWideningPass() { + return new LoopGuardWideningLegacyPass(); +} diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp index 74d6014d3e3d..8656e88b79cb 100644 --- a/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -43,6 +43,7 @@ #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/ConstantRange.h" @@ -77,7 +78,6 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SimplifyIndVar.h" #include <cassert> @@ -210,8 +210,8 @@ bool IndVarSimplify::isValidRewrite(Value *FromVal, Value *ToVal) { if (FromBase == ToBase) return true; - DEBUG(dbgs() << "INDVARS: GEP rewrite bail out " - << *FromBase << " != " << *ToBase << "\n"); + LLVM_DEBUG(dbgs() << "INDVARS: GEP rewrite bail out " << *FromBase + << " != " << *ToBase << "\n"); return false; } @@ -485,9 +485,8 @@ void IndVarSimplify::rewriteNonIntegerIVs(Loop *L) { BasicBlock *Header = L->getHeader(); SmallVector<WeakTrackingVH, 8> PHIs; - for (BasicBlock::iterator I = Header->begin(); - PHINode *PN = dyn_cast<PHINode>(I); ++I) - PHIs.push_back(PN); + for (PHINode &PN : Header->phis()) + 
PHIs.push_back(&PN); for (unsigned i = 0, e = PHIs.size(); i != e; ++i) if (PHINode *PN = dyn_cast_or_null<PHINode>(&*PHIs[i])) @@ -654,8 +653,9 @@ void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { Value *ExitVal = expandSCEVIfNeeded(Rewriter, ExitValue, L, Inst, PN->getType()); - DEBUG(dbgs() << "INDVARS: RLEV: AfterLoopVal = " << *ExitVal << '\n' - << " LoopVal = " << *Inst << "\n"); + LLVM_DEBUG(dbgs() << "INDVARS: RLEV: AfterLoopVal = " << *ExitVal + << '\n' + << " LoopVal = " << *Inst << "\n"); if (!isValidRewrite(Inst, ExitVal)) { DeadInsts.push_back(ExitVal); @@ -724,13 +724,12 @@ void IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { assert(LoopHeader && "Invalid loop"); for (auto *ExitBB : ExitBlocks) { - BasicBlock::iterator BBI = ExitBB->begin(); // If there are no more PHI nodes in this exit block, then no more // values defined inside the loop are used on this path. - while (auto *PN = dyn_cast<PHINode>(BBI++)) { - for (unsigned IncomingValIdx = 0, E = PN->getNumIncomingValues(); - IncomingValIdx != E; ++IncomingValIdx) { - auto *IncomingBB = PN->getIncomingBlock(IncomingValIdx); + for (PHINode &PN : ExitBB->phis()) { + for (unsigned IncomingValIdx = 0, E = PN.getNumIncomingValues(); + IncomingValIdx != E; ++IncomingValIdx) { + auto *IncomingBB = PN.getIncomingBlock(IncomingValIdx); // We currently only support loop exits from loop header. If the // incoming block is not loop header, we need to recursively check @@ -755,8 +754,7 @@ void IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { if (!L->isLoopInvariant(Cond)) continue; - auto *ExitVal = - dyn_cast<PHINode>(PN->getIncomingValue(IncomingValIdx)); + auto *ExitVal = dyn_cast<PHINode>(PN.getIncomingValue(IncomingValIdx)); // Only deal with PHIs. if (!ExitVal) @@ -771,8 +769,8 @@ void IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { if (PreheaderIdx != -1) { assert(ExitVal->getParent() == LoopHeader && "ExitVal must be in loop header"); - PN->setIncomingValue(IncomingValIdx, - ExitVal->getIncomingValue(PreheaderIdx)); + PN.setIncomingValue(IncomingValIdx, + ExitVal->getIncomingValue(PreheaderIdx)); } } } @@ -1087,7 +1085,7 @@ Instruction *WidenIV::cloneBitwiseIVUser(NarrowIVDefUse DU) { Instruction *NarrowDef = DU.NarrowDef; Instruction *WideDef = DU.WideDef; - DEBUG(dbgs() << "Cloning bitwise IVUser: " << *NarrowUse << "\n"); + LLVM_DEBUG(dbgs() << "Cloning bitwise IVUser: " << *NarrowUse << "\n"); // Replace NarrowDef operands with WideDef. Otherwise, we don't know anything // about the narrow operand yet so must insert a [sz]ext. It is probably loop @@ -1118,7 +1116,7 @@ Instruction *WidenIV::cloneArithmeticIVUser(NarrowIVDefUse DU, Instruction *NarrowDef = DU.NarrowDef; Instruction *WideDef = DU.WideDef; - DEBUG(dbgs() << "Cloning arithmetic IVUser: " << *NarrowUse << "\n"); + LLVM_DEBUG(dbgs() << "Cloning arithmetic IVUser: " << *NarrowUse << "\n"); unsigned IVOpIdx = (NarrowUse->getOperand(0) == NarrowDef) ? 0 : 1; @@ -1318,8 +1316,8 @@ WidenIV::WidenedRecTy WidenIV::getWideRecurrence(NarrowIVDefUse DU) { /// This IV user cannot be widen. Replace this use of the original narrow IV /// with a truncation of the new wide IV to isolate and eliminate the narrow IV. 
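The IndVarSimplify hunks above switch to BasicBlock::phis(), which iterates exactly the PHI nodes at the head of a block and removes the old dyn_cast-until-the-first-non-PHI bookkeeping. The idiom in isolation, with a hypothetical consumer:

  #include "llvm/IR/BasicBlock.h"
  #include "llvm/IR/Instructions.h"

  unsigned countIncomingFrom(llvm::BasicBlock &BB, llvm::BasicBlock *Pred) {
    unsigned N = 0;
    // phis() stops at the first non-PHI instruction automatically.
    for (llvm::PHINode &PN : BB.phis())
      if (PN.getBasicBlockIndex(Pred) >= 0)
        ++N;
    return N;
  }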
static void truncateIVUse(NarrowIVDefUse DU, DominatorTree *DT, LoopInfo *LI) { - DEBUG(dbgs() << "INDVARS: Truncate IV " << *DU.WideDef - << " for user " << *DU.NarrowUse << "\n"); + LLVM_DEBUG(dbgs() << "INDVARS: Truncate IV " << *DU.WideDef << " for user " + << *DU.NarrowUse << "\n"); IRBuilder<> Builder( getInsertPointForUses(DU.NarrowUse, DU.NarrowDef, DT, LI)); Value *Trunc = Builder.CreateTrunc(DU.WideDef, DU.NarrowDef->getType()); @@ -1399,8 +1397,8 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { Value *Trunc = Builder.CreateTrunc(WidePhi, DU.NarrowDef->getType()); UsePhi->replaceAllUsesWith(Trunc); DeadInsts.emplace_back(UsePhi); - DEBUG(dbgs() << "INDVARS: Widen lcssa phi " << *UsePhi - << " to " << *WidePhi << "\n"); + LLVM_DEBUG(dbgs() << "INDVARS: Widen lcssa phi " << *UsePhi << " to " + << *WidePhi << "\n"); } return nullptr; } @@ -1431,15 +1429,16 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { // A wider extend was hidden behind a narrower one. This may induce // another round of IV widening in which the intermediate IV becomes // dead. It should be very rare. - DEBUG(dbgs() << "INDVARS: New IV " << *WidePhi - << " not wide enough to subsume " << *DU.NarrowUse << "\n"); + LLVM_DEBUG(dbgs() << "INDVARS: New IV " << *WidePhi + << " not wide enough to subsume " << *DU.NarrowUse + << "\n"); DU.NarrowUse->replaceUsesOfWith(DU.NarrowDef, DU.WideDef); NewDef = DU.NarrowUse; } } if (NewDef != DU.NarrowUse) { - DEBUG(dbgs() << "INDVARS: eliminating " << *DU.NarrowUse - << " replaced by " << *DU.WideDef << "\n"); + LLVM_DEBUG(dbgs() << "INDVARS: eliminating " << *DU.NarrowUse + << " replaced by " << *DU.WideDef << "\n"); ++NumElimExt; DU.NarrowUse->replaceAllUsesWith(NewDef); DeadInsts.emplace_back(DU.NarrowUse); @@ -1494,8 +1493,9 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { // absolutely guarantee it. Hence the following failsafe check. In rare cases // where it fails, we simply throw away the newly created wide use. if (WideAddRec.first != SE->getSCEV(WideUse)) { - DEBUG(dbgs() << "Wide use expression mismatch: " << *WideUse - << ": " << *SE->getSCEV(WideUse) << " != " << *WideAddRec.first << "\n"); + LLVM_DEBUG(dbgs() << "Wide use expression mismatch: " << *WideUse << ": " + << *SE->getSCEV(WideUse) << " != " << *WideAddRec.first + << "\n"); DeadInsts.emplace_back(WideUse); return nullptr; } @@ -1600,7 +1600,7 @@ PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) { WideInc->setDebugLoc(OrigInc->getDebugLoc()); } - DEBUG(dbgs() << "Wide IV: " << *WidePhi << "\n"); + LLVM_DEBUG(dbgs() << "Wide IV: " << *WidePhi << "\n"); ++NumWidened; // Traverse the def-use chain using a worklist starting at the original IV. @@ -2234,12 +2234,12 @@ linearFunctionTestReplace(Loop *L, else P = ICmpInst::ICMP_EQ; - DEBUG(dbgs() << "INDVARS: Rewriting loop exit condition to:\n" - << " LHS:" << *CmpIndVar << '\n' - << " op:\t" - << (P == ICmpInst::ICMP_NE ? "!=" : "==") << "\n" - << " RHS:\t" << *ExitCnt << "\n" - << " IVCount:\t" << *IVCount << "\n"); + LLVM_DEBUG(dbgs() << "INDVARS: Rewriting loop exit condition to:\n" + << " LHS:" << *CmpIndVar << '\n' + << " op:\t" << (P == ICmpInst::ICMP_NE ? 
"!=" : "==") + << "\n" + << " RHS:\t" << *ExitCnt << "\n" + << " IVCount:\t" << *IVCount << "\n"); IRBuilder<> Builder(BI); @@ -2275,7 +2275,7 @@ linearFunctionTestReplace(Loop *L, NewLimit = Start + Count; ExitCnt = ConstantInt::get(CmpIndVar->getType(), NewLimit); - DEBUG(dbgs() << " Widen RHS:\t" << *ExitCnt << "\n"); + LLVM_DEBUG(dbgs() << " Widen RHS:\t" << *ExitCnt << "\n"); } else { // We try to extend trip count first. If that doesn't work we truncate IV. // Zext(trunc(IV)) == IV implies equivalence of the following two: diff --git a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 5c4d55bfbb2b..e2f29705f2dd 100644 --- a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -43,6 +43,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/InductiveRangeCheckElimination.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" @@ -52,6 +53,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -179,10 +181,7 @@ public: OS << " Step: "; Step->print(OS); OS << " End: "; - if (End) - End->print(OS); - else - OS << "(null)"; + End->print(OS); OS << "\n CheckUse: "; getCheckUse()->getUser()->print(OS); OS << " Operand: " << getCheckUse()->getOperandNo() << "\n"; @@ -196,7 +195,7 @@ public: Use *getCheckUse() const { return CheckUse; } /// Represents an signed integer range [Range.getBegin(), Range.getEnd()). If - /// R.getEnd() sle R.getBegin(), then R denotes the empty range. + /// R.getEnd() le R.getBegin(), then R denotes the empty range. class Range { const SCEV *Begin; @@ -238,17 +237,31 @@ public: /// checks, and hence don't end up in \p Checks. 
static void extractRangeChecksFromBranch(BranchInst *BI, Loop *L, ScalarEvolution &SE, - BranchProbabilityInfo &BPI, + BranchProbabilityInfo *BPI, SmallVectorImpl<InductiveRangeCheck> &Checks); }; -class InductiveRangeCheckElimination : public LoopPass { +class InductiveRangeCheckElimination { + ScalarEvolution &SE; + BranchProbabilityInfo *BPI; + DominatorTree &DT; + LoopInfo &LI; + +public: + InductiveRangeCheckElimination(ScalarEvolution &SE, + BranchProbabilityInfo *BPI, DominatorTree &DT, + LoopInfo &LI) + : SE(SE), BPI(BPI), DT(DT), LI(LI) {} + + bool run(Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop); +}; + +class IRCELegacyPass : public LoopPass { public: static char ID; - InductiveRangeCheckElimination() : LoopPass(ID) { - initializeInductiveRangeCheckEliminationPass( - *PassRegistry::getPassRegistry()); + IRCELegacyPass() : LoopPass(ID) { + initializeIRCELegacyPassPass(*PassRegistry::getPassRegistry()); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -261,14 +274,14 @@ public: } // end anonymous namespace -char InductiveRangeCheckElimination::ID = 0; +char IRCELegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(InductiveRangeCheckElimination, "irce", +INITIALIZE_PASS_BEGIN(IRCELegacyPass, "irce", "Inductive range check elimination", false, false) INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_END(InductiveRangeCheckElimination, "irce", - "Inductive range check elimination", false, false) +INITIALIZE_PASS_END(IRCELegacyPass, "irce", "Inductive range check elimination", + false, false) StringRef InductiveRangeCheck::rangeCheckKindToStr( InductiveRangeCheck::RangeCheckKind RCK) { @@ -299,13 +312,8 @@ InductiveRangeCheck::RangeCheckKind InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE, Value *&Index, Value *&Length, bool &IsSigned) { - auto IsNonNegativeAndNotLoopVarying = [&SE, L](Value *V) { - const SCEV *S = SE.getSCEV(V); - if (isa<SCEVCouldNotCompute>(S)) - return false; - - return SE.getLoopDisposition(S, L) == ScalarEvolution::LoopInvariant && - SE.isKnownNonNegative(S); + auto IsLoopInvariant = [&SE, L](Value *V) { + return SE.isLoopInvariant(SE.getSCEV(V), L); }; ICmpInst::Predicate Pred = ICI->getPredicate(); @@ -337,7 +345,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, return RANGE_CHECK_LOWER; } - if (IsNonNegativeAndNotLoopVarying(LHS)) { + if (IsLoopInvariant(LHS)) { Index = RHS; Length = LHS; return RANGE_CHECK_UPPER; @@ -349,7 +357,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, LLVM_FALLTHROUGH; case ICmpInst::ICMP_UGT: IsSigned = false; - if (IsNonNegativeAndNotLoopVarying(LHS)) { + if (IsLoopInvariant(LHS)) { Index = RHS; Length = LHS; return RANGE_CHECK_BOTH; @@ -394,8 +402,23 @@ void InductiveRangeCheck::extractRangeChecksFromCond( if (!IsAffineIndex) return; + const SCEV *End = nullptr; + // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L". + // We can potentially do much better here. + if (Length) + End = SE.getSCEV(Length); + else { + assert(RCKind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!"); + // So far we can only reach this point for Signed range check. This may + // change in future. In this case we will need to pick Unsigned max for the + // unsigned range check. 
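For orientation before the IRCE changes below: the loops this pass targets carry a monotonic induction variable and a likely-taken, loop-invariant bound check. A hypothetical source-level sketch of the shape being matched:

  // for (int i = 0; i < n; ++i) {
  //   if (0 <= i && i < len) {  // inductive range check, almost always true
  //     a[i] = f(i);            // fast path
  //   } else {
  //     deoptimize();           // cold path
  //   }
  // }
  //
  // IRCE computes the sub-range of i where the check provably passes and
  // splits the iteration space into pre-loop / main loop / post-loop, so the
  // main loop runs with the check folded away.

One consequence of the hunk that follows is that a lower-bound-only check no longer leaves End null: on an i32 index, "0 <= I" is modeled as "0 <= I < SINT_MAX", so every InductiveRangeCheck carries a usable End and downstream code can drop its null checks on getEnd().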
+ unsigned BitWidth = cast<IntegerType>(IndexAddRec->getType())->getBitWidth(); + const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); + End = SIntMax; + } + InductiveRangeCheck IRC; - IRC.End = Length ? SE.getSCEV(Length) : nullptr; + IRC.End = End; IRC.Begin = IndexAddRec->getStart(); IRC.Step = IndexAddRec->getStepRecurrence(SE); IRC.CheckUse = &ConditionUse; @@ -405,15 +428,15 @@ void InductiveRangeCheck::extractRangeChecksFromCond( } void InductiveRangeCheck::extractRangeChecksFromBranch( - BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo &BPI, + BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo *BPI, SmallVectorImpl<InductiveRangeCheck> &Checks) { if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch()) return; BranchProbability LikelyTaken(15, 16); - if (!SkipProfitabilityChecks && - BPI.getEdgeProbability(BI->getParent(), (unsigned)0) < LikelyTaken) + if (!SkipProfitabilityChecks && BPI && + BPI->getEdgeProbability(BI->getParent(), (unsigned)0) < LikelyTaken) return; SmallPtrSet<Value *, 8> Visited; @@ -504,9 +527,8 @@ struct LoopStructure { } static Optional<LoopStructure> parseLoopStructure(ScalarEvolution &, - BranchProbabilityInfo &BPI, - Loop &, - const char *&); + BranchProbabilityInfo *BPI, + Loop &, const char *&); }; /// This class is used to constrain loops to run within a given iteration space. @@ -573,7 +595,7 @@ class LoopConstrainer { // Create the appropriate loop structure needed to describe a cloned copy of // `Original`. The clone is described by `VM`. Loop *createClonedLoopStructure(Loop *Original, Loop *Parent, - ValueToValueMapTy &VM); + ValueToValueMapTy &VM, bool IsSubloop); // Rewrite the iteration space of the loop denoted by (LS, Preheader). The // iteration space of the rewritten loop ends at ExitLoopAt. The start of the @@ -625,8 +647,8 @@ class LoopConstrainer { LLVMContext &Ctx; ScalarEvolution &SE; DominatorTree &DT; - LPPassManager &LPM; LoopInfo &LI; + function_ref<void(Loop *, bool)> LPMAddNewLoop; // Information about the original loop we started out with. Loop &OriginalLoop; @@ -646,12 +668,13 @@ class LoopConstrainer { LoopStructure MainLoopStructure; public: - LoopConstrainer(Loop &L, LoopInfo &LI, LPPassManager &LPM, + LoopConstrainer(Loop &L, LoopInfo &LI, + function_ref<void(Loop *, bool)> LPMAddNewLoop, const LoopStructure &LS, ScalarEvolution &SE, DominatorTree &DT, InductiveRangeCheck::Range R) : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), - SE(SE), DT(DT), LPM(LPM), LI(LI), OriginalLoop(L), Range(R), - MainLoopStructure(LS) {} + SE(SE), DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), + Range(R), MainLoopStructure(LS) {} // Entry point for the algorithm. Returns true on success. bool run(); @@ -666,56 +689,141 @@ void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block, PN->setIncomingBlock(i, ReplaceBy); } -static bool CanBeMax(ScalarEvolution &SE, const SCEV *S, bool Signed) { - APInt Max = Signed ? - APInt::getSignedMaxValue(cast<IntegerType>(S->getType())->getBitWidth()) : - APInt::getMaxValue(cast<IntegerType>(S->getType())->getBitWidth()); - return SE.getSignedRange(S).contains(Max) && - SE.getUnsignedRange(S).contains(Max); +static bool CannotBeMaxInLoop(const SCEV *BoundSCEV, Loop *L, + ScalarEvolution &SE, bool Signed) { + unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); + APInt Max = Signed ? APInt::getSignedMaxValue(BitWidth) : + APInt::getMaxValue(BitWidth); + auto Predicate = Signed ? 
ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; + return SE.isAvailableAtLoopEntry(BoundSCEV, L) && + SE.isLoopEntryGuardedByCond(L, Predicate, BoundSCEV, + SE.getConstant(Max)); } -static bool SumCanReachMax(ScalarEvolution &SE, const SCEV *S1, const SCEV *S2, - bool Signed) { - // S1 < INT_MAX - S2 ===> S1 + S2 < INT_MAX. - assert(SE.isKnownNonNegative(S2) && - "We expected the 2nd arg to be non-negative!"); - const SCEV *Max = SE.getConstant( - Signed ? APInt::getSignedMaxValue( - cast<IntegerType>(S1->getType())->getBitWidth()) - : APInt::getMaxValue( - cast<IntegerType>(S1->getType())->getBitWidth())); - const SCEV *CapForS1 = SE.getMinusSCEV(Max, S2); - return !SE.isKnownPredicate(Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, - S1, CapForS1); +/// Given a loop with an deccreasing induction variable, is it possible to +/// safely calculate the bounds of a new loop using the given Predicate. +static bool isSafeDecreasingBound(const SCEV *Start, + const SCEV *BoundSCEV, const SCEV *Step, + ICmpInst::Predicate Pred, + unsigned LatchBrExitIdx, + Loop *L, ScalarEvolution &SE) { + if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT && + Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT) + return false; + + if (!SE.isAvailableAtLoopEntry(BoundSCEV, L)) + return false; + + assert(SE.isKnownNegative(Step) && "expecting negative step"); + + LLVM_DEBUG(dbgs() << "irce: isSafeDecreasingBound with:\n"); + LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n"); + LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n"); + LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n"); + LLVM_DEBUG(dbgs() << "irce: Pred: " << ICmpInst::getPredicateName(Pred) + << "\n"); + LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n"); + + bool IsSigned = ICmpInst::isSigned(Pred); + // The predicate that we need to check that the induction variable lies + // within bounds. + ICmpInst::Predicate BoundPred = + IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT; + + if (LatchBrExitIdx == 1) + return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV); + + assert(LatchBrExitIdx == 0 && + "LatchBrExitIdx should be either 0 or 1"); + + const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType())); + unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); + APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth) : + APInt::getMinValue(BitWidth); + const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne); + + const SCEV *MinusOne = + SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType())); + + return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) && + SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit); + } -static bool CanBeMin(ScalarEvolution &SE, const SCEV *S, bool Signed) { - APInt Min = Signed ? - APInt::getSignedMinValue(cast<IntegerType>(S->getType())->getBitWidth()) : - APInt::getMinValue(cast<IntegerType>(S->getType())->getBitWidth()); - return SE.getSignedRange(S).contains(Min) && - SE.getUnsignedRange(S).contains(Min); +/// Given a loop with an increasing induction variable, is it possible to +/// safely calculate the bounds of a new loop using the given Predicate. 
+static bool isSafeIncreasingBound(const SCEV *Start, + const SCEV *BoundSCEV, const SCEV *Step, + ICmpInst::Predicate Pred, + unsigned LatchBrExitIdx, + Loop *L, ScalarEvolution &SE) { + if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT && + Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT) + return false; + + if (!SE.isAvailableAtLoopEntry(BoundSCEV, L)) + return false; + + LLVM_DEBUG(dbgs() << "irce: isSafeIncreasingBound with:\n"); + LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n"); + LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n"); + LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n"); + LLVM_DEBUG(dbgs() << "irce: Pred: " << ICmpInst::getPredicateName(Pred) + << "\n"); + LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n"); + + bool IsSigned = ICmpInst::isSigned(Pred); + // The predicate that we need to check that the induction variable lies + // within bounds. + ICmpInst::Predicate BoundPred = + IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT; + + if (LatchBrExitIdx == 1) + return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV); + + assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1"); + + const SCEV *StepMinusOne = + SE.getMinusSCEV(Step, SE.getOne(Step->getType())); + unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); + APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth) : + APInt::getMaxValue(BitWidth); + const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne); + + return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start, + SE.getAddExpr(BoundSCEV, Step)) && + SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit)); } -static bool SumCanReachMin(ScalarEvolution &SE, const SCEV *S1, const SCEV *S2, - bool Signed) { - // S1 > INT_MIN - S2 ===> S1 + S2 > INT_MIN. - assert(SE.isKnownNonPositive(S2) && - "We expected the 2nd arg to be non-positive!"); - const SCEV *Max = SE.getConstant( - Signed ? APInt::getSignedMinValue( - cast<IntegerType>(S1->getType())->getBitWidth()) - : APInt::getMinValue( - cast<IntegerType>(S1->getType())->getBitWidth())); - const SCEV *CapForS1 = SE.getMinusSCEV(Max, S2); - return !SE.isKnownPredicate(Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, - S1, CapForS1); +static bool CannotBeMinInLoop(const SCEV *BoundSCEV, Loop *L, + ScalarEvolution &SE, bool Signed) { + unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); + APInt Min = Signed ? APInt::getSignedMinValue(BitWidth) : + APInt::getMinValue(BitWidth); + auto Predicate = Signed ? 
ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; + return SE.isAvailableAtLoopEntry(BoundSCEV, L) && + SE.isLoopEntryGuardedByCond(L, Predicate, BoundSCEV, + SE.getConstant(Min)); +} + +static bool isKnownNonNegativeInLoop(const SCEV *BoundSCEV, const Loop *L, + ScalarEvolution &SE) { + const SCEV *Zero = SE.getZero(BoundSCEV->getType()); + return SE.isAvailableAtLoopEntry(BoundSCEV, L) && + SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SGE, BoundSCEV, Zero); +} + +static bool isKnownNegativeInLoop(const SCEV *BoundSCEV, const Loop *L, + ScalarEvolution &SE) { + const SCEV *Zero = SE.getZero(BoundSCEV->getType()); + return SE.isAvailableAtLoopEntry(BoundSCEV, L) && + SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SLT, BoundSCEV, Zero); } Optional<LoopStructure> LoopStructure::parseLoopStructure(ScalarEvolution &SE, - BranchProbabilityInfo &BPI, - Loop &L, const char *&FailureReason) { + BranchProbabilityInfo *BPI, Loop &L, + const char *&FailureReason) { if (!L.isLoopSimplifyForm()) { FailureReason = "loop not in LoopSimplify form"; return None; @@ -750,7 +858,8 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0; BranchProbability ExitProbability = - BPI.getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx); + BPI ? BPI->getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx) + : BranchProbability::getZero(); if (!SkipProfitabilityChecks && ExitProbability > BranchProbability(1, MaxExitProbReciprocal)) { @@ -816,43 +925,29 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap; }; - // Here we check whether the suggested AddRec is an induction variable that - // can be handled (i.e. with known constant step), and if yes, calculate its - // step and identify whether it is increasing or decreasing. - auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing, - ConstantInt *&StepCI) { - if (!AR->isAffine()) - return false; - - // Currently we only work with induction variables that have been proved to - // not wrap. This restriction can potentially be lifted in the future. - - if (!HasNoSignedWrap(AR)) - return false; - - if (const SCEVConstant *StepExpr = - dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) { - StepCI = StepExpr->getValue(); - assert(!StepCI->isZero() && "Zero step?"); - IsIncreasing = !StepCI->isNegative(); - return true; - } - - return false; - }; - // `ICI` is interpreted as taking the backedge if the *next* value of the // induction variable satisfies some constraint. 
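To make the overflow reasoning in the new isSafeIncreasingBound above concrete, consider a 32-bit signed IV with Step = 1 and LatchBrExitIdx == 0 (a worked instance, not additional code in the patch):

  // StepMinusOne = 0, so Limit = SINT_MAX - 0 = 0x7fffffff, and the bound is
  // accepted only if loop entry already guarantees both
  //   Start < BoundSCEV + Step   and   BoundSCEV < Limit,
  // which ensures the adjusted limit BoundSCEV + Step cannot wrap past
  // SINT_MAX. This replaces the old CanBeMin/SumCanReachMax value-range
  // queries with facts proved from the loop's guarding conditions.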
const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV); - bool IsIncreasing = false; - bool IsSignedPredicate = true; - ConstantInt *StepCI; - if (!IsInductionVar(IndVarBase, IsIncreasing, StepCI)) { + if (!IndVarBase->isAffine()) { + FailureReason = "LHS in icmp not induction variable"; + return None; + } + const SCEV* StepRec = IndVarBase->getStepRecurrence(SE); + if (!isa<SCEVConstant>(StepRec)) { FailureReason = "LHS in icmp not induction variable"; return None; } + ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue(); + if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) { + FailureReason = "LHS in icmp needs nsw for equality predicates"; + return None; + } + + assert(!StepCI->isZero() && "Zero step?"); + bool IsIncreasing = !StepCI->isNegative(); + bool IsSignedPredicate = ICmpInst::isSigned(Pred); const SCEV *StartNext = IndVarBase->getStart(); const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE)); const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); @@ -870,22 +965,29 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, // If both parts are known non-negative, it is profitable to use // unsigned comparison in increasing loop. This allows us to make the // comparison check against "RightSCEV + 1" more optimistic. - if (SE.isKnownNonNegative(IndVarStart) && - SE.isKnownNonNegative(RightSCEV)) + if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) && + isKnownNonNegativeInLoop(RightSCEV, &L, SE)) Pred = ICmpInst::ICMP_ULT; else Pred = ICmpInst::ICMP_SLT; - else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && - !CanBeMin(SE, RightSCEV, /* IsSignedPredicate */ true)) { + else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) { // while (true) { while (true) { // if (++i == len) ---> if (++i > len - 1) // break; break; // ... ... // } } - // TODO: Insert ICMP_UGT if both are non-negative? - Pred = ICmpInst::ICMP_SGT; - RightSCEV = SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); - DecreasedRightValueByOne = true; + if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) && + CannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/false)) { + Pred = ICmpInst::ICMP_UGT; + RightSCEV = SE.getMinusSCEV(RightSCEV, + SE.getOne(RightSCEV->getType())); + DecreasedRightValueByOne = true; + } else if (CannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/true)) { + Pred = ICmpInst::ICMP_SGT; + RightSCEV = SE.getMinusSCEV(RightSCEV, + SE.getOne(RightSCEV->getType())); + DecreasedRightValueByOne = true; + } } } @@ -899,36 +1001,18 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, return None; } - IsSignedPredicate = - Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT; - + IsSignedPredicate = ICmpInst::isSigned(Pred); if (!IsSignedPredicate && !AllowUnsignedLatchCondition) { FailureReason = "unsigned latch conditions are explicitly prohibited"; return None; } - // The predicate that we need to check that the induction variable lies - // within bounds. - ICmpInst::Predicate BoundPred = - IsSignedPredicate ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT; - + if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred, + LatchBrExitIdx, &L, SE)) { + FailureReason = "Unsafe loop bounds"; + return None; + } if (LatchBrExitIdx == 0) { - const SCEV *StepMinusOne = SE.getMinusSCEV(Step, - SE.getOne(Step->getType())); - if (SumCanReachMax(SE, RightSCEV, StepMinusOne, IsSignedPredicate)) { - // TODO: this restriction is easily removable -- we just have to - // remember that the icmp was an slt and not an sle. 
- FailureReason = "limit may overflow when coercing le to lt"; - return None; - } - - if (!SE.isLoopEntryGuardedByCond( - &L, BoundPred, IndVarStart, - SE.getAddExpr(RightSCEV, Step))) { - FailureReason = "Induction variable start not bounded by upper limit"; - return None; - } - // We need to increase the right value unless we have already decreased // it virtually when we replaced EQ with SGT. if (!DecreasedRightValueByOne) { @@ -936,10 +1020,6 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, RightValue = B.CreateAdd(RightValue, One); } } else { - if (!SE.isLoopEntryGuardedByCond(&L, BoundPred, IndVarStart, RightSCEV)) { - FailureReason = "Induction variable start not bounded by upper limit"; - return None; - } assert(!DecreasedRightValueByOne && "Right value can be decreased only for LatchBrExitIdx == 0!"); } @@ -955,17 +1035,22 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, // that both operands are non-negative, because it will only pessimize // our check against "RightSCEV - 1". Pred = ICmpInst::ICMP_SGT; - else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && - !CanBeMax(SE, RightSCEV, /* IsSignedPredicate */ true)) { + else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) { // while (true) { while (true) { // if (--i == len) ---> if (--i < len + 1) // break; break; // ... ... // } } - // TODO: Insert ICMP_ULT if both are non-negative? - Pred = ICmpInst::ICMP_SLT; - RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); - IncreasedRightValueByOne = true; + if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) && + CannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) { + Pred = ICmpInst::ICMP_ULT; + RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); + IncreasedRightValueByOne = true; + } else if (CannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) { + Pred = ICmpInst::ICMP_SLT; + RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); + IncreasedRightValueByOne = true; + } } } @@ -988,27 +1073,13 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, return None; } - // The predicate that we need to check that the induction variable lies - // within bounds. - ICmpInst::Predicate BoundPred = - IsSignedPredicate ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT; + if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred, + LatchBrExitIdx, &L, SE)) { + FailureReason = "Unsafe bounds"; + return None; + } if (LatchBrExitIdx == 0) { - const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType())); - if (SumCanReachMin(SE, RightSCEV, StepPlusOne, IsSignedPredicate)) { - // TODO: this restriction is easily removable -- we just have to - // remember that the icmp was an sgt and not an sge. - FailureReason = "limit may overflow when coercing ge to gt"; - return None; - } - - if (!SE.isLoopEntryGuardedByCond( - &L, BoundPred, IndVarStart, - SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())))) { - FailureReason = "Induction variable start not bounded by lower limit"; - return None; - } - // We need to decrease the right value unless we have already increased // it virtually when we replaced EQ with SLT. 
if (!IncreasedRightValueByOne) { @@ -1016,10 +1087,6 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, RightValue = B.CreateSub(RightValue, One); } } else { - if (!SE.isLoopEntryGuardedByCond(&L, BoundPred, IndVarStart, RightSCEV)) { - FailureReason = "Induction variable start not bounded by lower limit"; - return None; - } assert(!IncreasedRightValueByOne && "Right value can be increased only for LatchBrExitIdx == 0!"); } @@ -1174,13 +1241,9 @@ void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, if (OriginalLoop.contains(SBB)) continue; // not an exit block - for (Instruction &I : *SBB) { - auto *PN = dyn_cast<PHINode>(&I); - if (!PN) - break; - - Value *OldIncoming = PN->getIncomingValueForBlock(OriginalBB); - PN->addIncoming(GetClonedValue(OldIncoming), ClonedBB); + for (PHINode &PN : SBB->phis()) { + Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB); + PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB); } } } @@ -1327,16 +1390,12 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of // each of the PHI nodes in the loop header. This feeds into the initial // value of the same PHI nodes if/when we continue execution. - for (Instruction &I : *LS.Header) { - auto *PN = dyn_cast<PHINode>(&I); - if (!PN) - break; - - PHINode *NewPHI = PHINode::Create(PN->getType(), 2, PN->getName() + ".copy", + for (PHINode &PN : LS.Header->phis()) { + PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy", BranchToContinuation); - NewPHI->addIncoming(PN->getIncomingValueForBlock(Preheader), Preheader); - NewPHI->addIncoming(PN->getIncomingValueForBlock(LS.Latch), + NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader); + NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch), RRI.ExitSelector); RRI.PHIValuesAtPseudoExit.push_back(NewPHI); } @@ -1348,12 +1407,8 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( // The latch exit now has a branch from `RRI.ExitSelector' instead of // `LS.Latch'. The PHI nodes need to be updated to reflect that. 
- for (Instruction &I : *LS.LatchExit) { - if (PHINode *PN = dyn_cast<PHINode>(&I)) - replacePHIBlock(PN, LS.Latch, RRI.ExitSelector); - else - break; - } + for (PHINode &PN : LS.LatchExit->phis()) + replacePHIBlock(&PN, LS.Latch, RRI.ExitSelector); return RRI; } @@ -1362,15 +1417,10 @@ void LoopConstrainer::rewriteIncomingValuesForPHIs( LoopStructure &LS, BasicBlock *ContinuationBlock, const LoopConstrainer::RewrittenRangeInfo &RRI) const { unsigned PHIIndex = 0; - for (Instruction &I : *LS.Header) { - auto *PN = dyn_cast<PHINode>(&I); - if (!PN) - break; - - for (unsigned i = 0, e = PN->getNumIncomingValues(); i < e; ++i) - if (PN->getIncomingBlock(i) == ContinuationBlock) - PN->setIncomingValue(i, RRI.PHIValuesAtPseudoExit[PHIIndex++]); - } + for (PHINode &PN : LS.Header->phis()) + for (unsigned i = 0, e = PN.getNumIncomingValues(); i < e; ++i) + if (PN.getIncomingBlock(i) == ContinuationBlock) + PN.setIncomingValue(i, RRI.PHIValuesAtPseudoExit[PHIIndex++]); LS.IndVarStart = RRI.IndVarEnd; } @@ -1381,14 +1431,9 @@ BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS, BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header); BranchInst::Create(LS.Header, Preheader); - for (Instruction &I : *LS.Header) { - auto *PN = dyn_cast<PHINode>(&I); - if (!PN) - break; - - for (unsigned i = 0, e = PN->getNumIncomingValues(); i < e; ++i) - replacePHIBlock(PN, OldPreheader, Preheader); - } + for (PHINode &PN : LS.Header->phis()) + for (unsigned i = 0, e = PN.getNumIncomingValues(); i < e; ++i) + replacePHIBlock(&PN, OldPreheader, Preheader); return Preheader; } @@ -1403,13 +1448,14 @@ void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) { } Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent, - ValueToValueMapTy &VM) { + ValueToValueMapTy &VM, + bool IsSubloop) { Loop &New = *LI.AllocateLoop(); if (Parent) Parent->addChildLoop(&New); else LI.addTopLevelLoop(&New); - LPM.addLoop(New); + LPMAddNewLoop(&New, IsSubloop); // Add all of the blocks in Original to the new loop. for (auto *BB : Original->blocks()) @@ -1418,7 +1464,7 @@ Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent, // Add all of the subloops to the new loop. for (Loop *SubLoop : *Original) - createClonedLoopStructure(SubLoop, &New, VM); + createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true); return &New; } @@ -1436,7 +1482,7 @@ bool LoopConstrainer::run() { bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate; Optional<SubRanges> MaybeSR = calculateSubRanges(IsSignedPredicate); if (!MaybeSR.hasValue()) { - DEBUG(dbgs() << "irce: could not compute subranges\n"); + LLVM_DEBUG(dbgs() << "irce: could not compute subranges\n"); return false; } @@ -1468,19 +1514,22 @@ bool LoopConstrainer::run() { if (Increasing) ExitPreLoopAtSCEV = *SR.LowLimit; else { - if (CanBeMin(SE, *SR.HighLimit, IsSignedPredicate)) { - DEBUG(dbgs() << "irce: could not prove no-overflow when computing " - << "preloop exit limit. HighLimit = " << *(*SR.HighLimit) - << "\n"); + if (CannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE, + IsSignedPredicate)) + ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); + else { + LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing " + << "preloop exit limit. 
HighLimit = " + << *(*SR.HighLimit) << "\n"); return false; } - ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); } if (!isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt, SE)) { - DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" - << " preloop exit limit " << *ExitPreLoopAtSCEV - << " at block " << InsertPt->getParent()->getName() << "\n"); + LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" + << " preloop exit limit " << *ExitPreLoopAtSCEV + << " at block " << InsertPt->getParent()->getName() + << "\n"); return false; } @@ -1494,19 +1543,22 @@ bool LoopConstrainer::run() { if (Increasing) ExitMainLoopAtSCEV = *SR.HighLimit; else { - if (CanBeMin(SE, *SR.LowLimit, IsSignedPredicate)) { - DEBUG(dbgs() << "irce: could not prove no-overflow when computing " - << "mainloop exit limit. LowLimit = " << *(*SR.LowLimit) - << "\n"); + if (CannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE, + IsSignedPredicate)) + ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); + else { + LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing " + << "mainloop exit limit. LowLimit = " + << *(*SR.LowLimit) << "\n"); return false; } - ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); } if (!isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt, SE)) { - DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" - << " main loop exit limit " << *ExitMainLoopAtSCEV - << " at block " << InsertPt->getParent()->getName() << "\n"); + LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" + << " main loop exit limit " << *ExitMainLoopAtSCEV + << " at block " << InsertPt->getParent()->getName() + << "\n"); return false; } @@ -1568,13 +1620,15 @@ bool LoopConstrainer::run() { // LI when LoopSimplifyForm is generated. Loop *PreL = nullptr, *PostL = nullptr; if (!PreLoop.Blocks.empty()) { - PreL = createClonedLoopStructure( - &OriginalLoop, OriginalLoop.getParentLoop(), PreLoop.Map); + PreL = createClonedLoopStructure(&OriginalLoop, + OriginalLoop.getParentLoop(), PreLoop.Map, + /* IsSubLoop */ false); } if (!PostLoop.Blocks.empty()) { - PostL = createClonedLoopStructure( - &OriginalLoop, OriginalLoop.getParentLoop(), PostLoop.Map); + PostL = + createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(), + PostLoop.Map, /* IsSubLoop */ false); } // This function canonicalizes the loop into Loop-Simplify and LCSSA forms. @@ -1640,32 +1694,34 @@ InductiveRangeCheck::computeSafeIterationSpace( unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth(); const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); - // Substract Y from X so that it does not go through border of the IV + // Subtract Y from X so that it does not go through border of the IV // iteration space. Mathematically, it is equivalent to: // - // ClampedSubstract(X, Y) = min(max(X - Y, INT_MIN), INT_MAX). [1] + // ClampedSubtract(X, Y) = min(max(X - Y, INT_MIN), INT_MAX). [1] // - // In [1], 'X - Y' is a mathematical substraction (result is not bounded to + // In [1], 'X - Y' is a mathematical subtraction (result is not bounded to // any width of bit grid). But after we take min/max, the result is // guaranteed to be within [INT_MIN, INT_MAX]. // // In [1], INT_MAX and INT_MIN are respectively signed and unsigned max/min // values, depending on type of latch condition that defines IV iteration // space. 
- auto ClampedSubstract = [&](const SCEV *X, const SCEV *Y) { - assert(SE.isKnownNonNegative(X) && - "We can only substract from values in [0; SINT_MAX]!"); + auto ClampedSubtract = [&](const SCEV *X, const SCEV *Y) { + // FIXME: The current implementation assumes that X is in [0, SINT_MAX]. + // This is required to ensure that SINT_MAX - X does not overflow signed and + // that X - Y does not overflow unsigned if Y is negative. Can we lift this + // restriction and make it work for negative X either? if (IsLatchSigned) { // X is a number from signed range, Y is interpreted as signed. // Even if Y is SINT_MAX, (X - Y) does not reach SINT_MIN. So the only // thing we should care about is that we didn't cross SINT_MAX. - // So, if Y is positive, we substract Y safely. + // So, if Y is positive, we subtract Y safely. // Rule 1: Y > 0 ---> Y. - // If 0 <= -Y <= (SINT_MAX - X), we substract Y safely. + // If 0 <= -Y <= (SINT_MAX - X), we subtract Y safely. // Rule 2: Y >=s (X - SINT_MAX) ---> Y. - // If 0 <= (SINT_MAX - X) < -Y, we can only substract (X - SINT_MAX). + // If 0 <= (SINT_MAX - X) < -Y, we can only subtract (X - SINT_MAX). // Rule 3: Y <s (X - SINT_MAX) ---> (X - SINT_MAX). - // It gives us smax(Y, X - SINT_MAX) to substract in all cases. + // It gives us smax(Y, X - SINT_MAX) to subtract in all cases. const SCEV *XMinusSIntMax = SE.getMinusSCEV(X, SIntMax); return SE.getMinusSCEV(X, SE.getSMaxExpr(Y, XMinusSIntMax), SCEV::FlagNSW); @@ -1673,29 +1729,45 @@ InductiveRangeCheck::computeSafeIterationSpace( // X is a number from unsigned range, Y is interpreted as signed. // Even if Y is SINT_MIN, (X - Y) does not reach UINT_MAX. So the only // thing we should care about is that we didn't cross zero. - // So, if Y is negative, we substract Y safely. + // So, if Y is negative, we subtract Y safely. // Rule 1: Y <s 0 ---> Y. - // If 0 <= Y <= X, we substract Y safely. + // If 0 <= Y <= X, we subtract Y safely. // Rule 2: Y <=s X ---> Y. - // If 0 <= X < Y, we should stop at 0 and can only substract X. + // If 0 <= X < Y, we should stop at 0 and can only subtract X. // Rule 3: Y >s X ---> X. - // It gives us smin(X, Y) to substract in all cases. + // It gives us smin(X, Y) to subtract in all cases. return SE.getMinusSCEV(X, SE.getSMinExpr(X, Y), SCEV::FlagNUW); }; const SCEV *M = SE.getMinusSCEV(C, A); const SCEV *Zero = SE.getZero(M->getType()); - const SCEV *Begin = ClampedSubstract(Zero, M); - const SCEV *L = nullptr; - // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L". - // We can potentially do much better here. - if (const SCEV *EndLimit = getEnd()) - L = EndLimit; - else { - assert(Kind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!"); - L = SIntMax; - } - const SCEV *End = ClampedSubstract(L, M); + // This function returns SCEV equal to 1 if X is non-negative 0 otherwise. + auto SCEVCheckNonNegative = [&](const SCEV *X) { + const Loop *L = IndVar->getLoop(); + const SCEV *One = SE.getOne(X->getType()); + // Can we trivially prove that X is a non-negative or negative value? + if (isKnownNonNegativeInLoop(X, L, SE)) + return One; + else if (isKnownNegativeInLoop(X, L, SE)) + return Zero; + // If not, we will have to figure it out during the execution. + // Function smax(smin(X, 0), -1) + 1 equals to 1 if X >= 0 and 0 if X < 0. 
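The branch-free non-negativity test in the comment just above is easy to sanity-check with plain integers; a hypothetical standalone model (isNonNegative is not part of the patch):

    #include <algorithm>
    #include <cassert>

    // smax(smin(X, 0), -1) + 1 yields 1 when X >= 0 and 0 when X < 0.
    static long isNonNegative(long X) {
      return std::max(std::min(X, 0L), -1L) + 1;
    }

    int main() {
      assert(isNonNegative(42) == 1);
      assert(isNonNegative(0) == 1);
      assert(isNonNegative(-7) == 0);
      return 0;
    }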
+ const SCEV *NegOne = SE.getNegativeSCEV(One); + return SE.getAddExpr(SE.getSMaxExpr(SE.getSMinExpr(X, Zero), NegOne), One); + }; + // FIXME: Current implementation of ClampedSubtract implicitly assumes that + // X is non-negative (in sense of a signed value). We need to re-implement + // this function in a way that it will correctly handle negative X as well. + // We use it twice: for X = 0 everything is fine, but for X = getEnd() we can + // end up with a negative X and produce wrong results. So currently we ensure + // that if getEnd() is negative then both ends of the safe range are zero. + // Note that this may pessimize elimination of unsigned range checks against + // negative values. + const SCEV *REnd = getEnd(); + const SCEV *EndIsNonNegative = SCEVCheckNonNegative(REnd); + + const SCEV *Begin = SE.getMulExpr(ClampedSubtract(Zero, M), EndIsNonNegative); + const SCEV *End = SE.getMulExpr(ClampedSubtract(REnd, M), EndIsNonNegative); return InductiveRangeCheck::Range(Begin, End); } @@ -1757,26 +1829,56 @@ IntersectUnsignedRange(ScalarEvolution &SE, return Ret; } -bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { +PreservedAnalyses IRCEPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &U) { + Function *F = L.getHeader()->getParent(); + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); + auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F); + InductiveRangeCheckElimination IRCE(AR.SE, BPI, AR.DT, AR.LI); + auto LPMAddNewLoop = [&U](Loop *NL, bool IsSubloop) { + if (!IsSubloop) + U.addSiblingLoops(NL); + }; + bool Changed = IRCE.run(&L, LPMAddNewLoop); + if (!Changed) + return PreservedAnalyses::all(); + + return getLoopPassPreservedAnalyses(); +} + +bool IRCELegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { if (skipLoop(L)) return false; + ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + BranchProbabilityInfo &BPI = + getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI); + auto LPMAddNewLoop = [&LPM](Loop *NL, bool /* IsSubLoop */) { + LPM.addLoop(*NL); + }; + return IRCE.run(L, LPMAddNewLoop); +} + +bool InductiveRangeCheckElimination::run( + Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop) { if (L->getBlocks().size() >= LoopSizeCutoff) { - DEBUG(dbgs() << "irce: giving up constraining loop, too large\n";); + LLVM_DEBUG(dbgs() << "irce: giving up constraining loop, too large\n"); return false; } BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) { - DEBUG(dbgs() << "irce: loop has no preheader, leaving\n"); + LLVM_DEBUG(dbgs() << "irce: loop has no preheader, leaving\n"); return false; } LLVMContext &Context = Preheader->getContext(); SmallVector<InductiveRangeCheck, 16> RangeChecks; - ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - BranchProbabilityInfo &BPI = - getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); for (auto BBI : L->getBlocks()) if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator())) @@ -1794,7 +1896,7 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { IRC.print(OS); }; - DEBUG(PrintRecognizedRangeChecks(dbgs())); + LLVM_DEBUG(PrintRecognizedRangeChecks(dbgs())); if (PrintRangeChecks) PrintRecognizedRangeChecks(errs()); @@ -1803,8 +1905,8 @@ 
bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { Optional<LoopStructure> MaybeLoopStructure = LoopStructure::parseLoopStructure(SE, BPI, *L, FailureReason); if (!MaybeLoopStructure.hasValue()) { - DEBUG(dbgs() << "irce: could not parse loop structure: " << FailureReason - << "\n";); + LLVM_DEBUG(dbgs() << "irce: could not parse loop structure: " + << FailureReason << "\n";); return false; } LoopStructure LS = MaybeLoopStructure.getValue(); @@ -1842,9 +1944,8 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { if (!SafeIterRange.hasValue()) return false; - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - LoopConstrainer LC(*L, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), LPM, - LS, SE, DT, SafeIterRange.getValue()); + LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT, + SafeIterRange.getValue()); bool Changed = LC.run(); if (Changed) { @@ -1855,7 +1956,7 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { L->print(dbgs()); }; - DEBUG(PrintConstrainedLoopInfo()); + LLVM_DEBUG(PrintConstrainedLoopInfo()); if (PrintChangedLoops) PrintConstrainedLoopInfo(); @@ -1874,5 +1975,5 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { } Pass *llvm::createInductiveRangeCheckEliminationPass() { - return new InductiveRangeCheckElimination; + return new IRCELegacyPass(); } diff --git a/lib/Transforms/Scalar/InferAddressSpaces.cpp b/lib/Transforms/Scalar/InferAddressSpaces.cpp index 7d66c0f73821..fbbc09eb487f 100644 --- a/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -97,6 +97,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -121,7 +122,6 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <cassert> #include <iterator> @@ -140,7 +140,7 @@ namespace { using ValueToAddrSpaceMapTy = DenseMap<const Value *, unsigned>; -/// \brief InferAddressSpaces +/// InferAddressSpaces class InferAddressSpaces : public FunctionPass { /// Target specific address space which uses of should be replaced if /// possible. @@ -260,7 +260,10 @@ bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II, switch (II->getIntrinsicID()) { case Intrinsic::amdgcn_atomic_inc: - case Intrinsic::amdgcn_atomic_dec:{ + case Intrinsic::amdgcn_atomic_dec: + case Intrinsic::amdgcn_ds_fadd: + case Intrinsic::amdgcn_ds_fmin: + case Intrinsic::amdgcn_ds_fmax: { const ConstantInt *IsVolatile = dyn_cast<ConstantInt>(II->getArgOperand(4)); if (!IsVolatile || !IsVolatile->isZero()) return false; @@ -289,6 +292,9 @@ void InferAddressSpaces::collectRewritableIntrinsicOperands( case Intrinsic::objectsize: case Intrinsic::amdgcn_atomic_inc: case Intrinsic::amdgcn_atomic_dec: + case Intrinsic::amdgcn_ds_fadd: + case Intrinsic::amdgcn_ds_fmin: + case Intrinsic::amdgcn_ds_fmax: appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0), PostorderStack, Visited); break; @@ -647,13 +653,13 @@ void InferAddressSpaces::inferAddressSpaces( // Tries to update the address space of the stack top according to the // address spaces of its operands. 
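The loop below is a classic optimistic worklist fixed point: re-infer the address space of a value whenever one of its operands changes, and requeue its users. A language-level sketch of the pattern under simplifying assumptions (Node, infer, and users are hypothetical stand-ins, not the pass's actual types):

    #include <map>
    #include <vector>

    // Iterate until no inferred value changes; every change requeues the
    // value's users so their address spaces are recomputed too.
    template <typename Node, typename InferFn, typename UsersFn>
    void propagateToFixedPoint(std::vector<Node> Worklist,
                               std::map<Node, unsigned> &State,
                               InferFn infer, UsersFn users) {
      while (!Worklist.empty()) {
        Node N = Worklist.back();
        Worklist.pop_back();
        unsigned NewAS = infer(N, State);
        auto It = State.find(N);
        if (It != State.end() && It->second == NewAS)
          continue; // Unchanged, so users need not be revisited.
        State[N] = NewAS;
        for (Node U : users(N))
          Worklist.push_back(U);
      }
    }

Termination relies on the inference only moving each value a bounded number of times, which the pass's join-style updateAddressSpace provides.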
- DEBUG(dbgs() << "Updating the address space of\n " << *V << '\n'); + LLVM_DEBUG(dbgs() << "Updating the address space of\n " << *V << '\n'); Optional<unsigned> NewAS = updateAddressSpace(*V, *InferredAddrSpace); if (!NewAS.hasValue()) continue; // If any updates are made, grabs its users to the worklist because // their address spaces can also be possibly updated. - DEBUG(dbgs() << " to " << NewAS.getValue() << '\n'); + LLVM_DEBUG(dbgs() << " to " << NewAS.getValue() << '\n'); (*InferredAddrSpace)[V] = NewAS.getValue(); for (Value *User : V->users()) { @@ -779,7 +785,7 @@ static bool handleMemIntrinsicPtrUse(MemIntrinsic *MI, Value *OldV, if (auto *MSI = dyn_cast<MemSetInst>(MI)) { B.CreateMemSet(NewV, MSI->getValue(), - MSI->getLength(), MSI->getAlignment(), + MSI->getLength(), MSI->getDestAlignment(), false, // isVolatile TBAA, ScopeMD, NoAliasMD); } else if (auto *MTI = dyn_cast<MemTransferInst>(MI)) { @@ -795,14 +801,16 @@ static bool handleMemIntrinsicPtrUse(MemIntrinsic *MI, Value *OldV, if (isa<MemCpyInst>(MTI)) { MDNode *TBAAStruct = MTI->getMetadata(LLVMContext::MD_tbaa_struct); - B.CreateMemCpy(Dest, Src, MTI->getLength(), - MTI->getAlignment(), + B.CreateMemCpy(Dest, MTI->getDestAlignment(), + Src, MTI->getSourceAlignment(), + MTI->getLength(), false, // isVolatile TBAA, TBAAStruct, ScopeMD, NoAliasMD); } else { assert(isa<MemMoveInst>(MTI)); - B.CreateMemMove(Dest, Src, MTI->getLength(), - MTI->getAlignment(), + B.CreateMemMove(Dest, MTI->getDestAlignment(), + Src, MTI->getSourceAlignment(), + MTI->getLength(), false, // isVolatile TBAA, ScopeMD, NoAliasMD); } @@ -893,15 +901,15 @@ bool InferAddressSpaces::rewriteWithNewAddressSpaces( if (NewV == nullptr) continue; - DEBUG(dbgs() << "Replacing the uses of " << *V - << "\n with\n " << *NewV << '\n'); + LLVM_DEBUG(dbgs() << "Replacing the uses of " << *V << "\n with\n " + << *NewV << '\n'); if (Constant *C = dyn_cast<Constant>(V)) { Constant *Replace = ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), C->getType()); if (C != Replace) { - DEBUG(dbgs() << "Inserting replacement const cast: " - << Replace << ": " << *Replace << '\n'); + LLVM_DEBUG(dbgs() << "Inserting replacement const cast: " << Replace + << ": " << *Replace << '\n'); C->replaceAllUsesWith(Replace); V = Replace; } diff --git a/lib/Transforms/Utils/SimplifyInstructions.cpp b/lib/Transforms/Scalar/InstSimplifyPass.cpp index f3d4f2ef38d7..05cd48d83267 100644 --- a/lib/Transforms/Utils/SimplifyInstructions.cpp +++ b/lib/Transforms/Scalar/InstSimplifyPass.cpp @@ -1,4 +1,4 @@ -//===------ SimplifyInstructions.cpp - Remove redundant instructions ------===// +//===- InstSimplifyPass.cpp -----------------------------------------------===// // // The LLVM Compiler Infrastructure // @@ -6,15 +6,8 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// -// -// This is a utility pass used for testing the InstructionSimplify analysis. -// The analysis is applied to every instruction, and if it simplifies then the -// instruction is replaced by the simplification. If you are looking for a pass -// that performs serious instruction folding, use the instcombine pass instead. 
-// -//===----------------------------------------------------------------------===// -#include "llvm/Transforms/Utils/SimplifyInstructions.h" +#include "llvm/Transforms/Scalar/InstSimplifyPass.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" @@ -27,7 +20,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/Type.h" #include "llvm/Pass.h" -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -84,58 +77,57 @@ static bool runImpl(Function &F, const SimplifyQuery &SQ, } namespace { - struct InstSimplifier : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - InstSimplifier() : FunctionPass(ID) { - initializeInstSimplifierPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - } - - /// runOnFunction - Remove instructions that simplify. - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; - - const DominatorTree *DT = - &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - AssumptionCache *AC = - &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - OptimizationRemarkEmitter *ORE = - &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - const DataLayout &DL = F.getParent()->getDataLayout(); - const SimplifyQuery SQ(DL, TLI, DT, AC); - return runImpl(F, SQ, ORE); - } - }; -} - -char InstSimplifier::ID = 0; -INITIALIZE_PASS_BEGIN(InstSimplifier, "instsimplify", +struct InstSimplifyLegacyPass : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + InstSimplifyLegacyPass() : FunctionPass(ID) { + initializeInstSimplifyLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); + } + + /// runOnFunction - Remove instructions that simplify. 
+ bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + const DominatorTree *DT = + &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + const TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + AssumptionCache *AC = + &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + OptimizationRemarkEmitter *ORE = + &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); + const DataLayout &DL = F.getParent()->getDataLayout(); + const SimplifyQuery SQ(DL, TLI, DT, AC); + return runImpl(F, SQ, ORE); + } +}; +} // namespace + +char InstSimplifyLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(InstSimplifyLegacyPass, "instsimplify", "Remove redundant instructions", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) -INITIALIZE_PASS_END(InstSimplifier, "instsimplify", +INITIALIZE_PASS_END(InstSimplifyLegacyPass, "instsimplify", "Remove redundant instructions", false, false) -char &llvm::InstructionSimplifierID = InstSimplifier::ID; // Public interface to the simplify instructions pass. -FunctionPass *llvm::createInstructionSimplifierPass() { - return new InstSimplifier(); +FunctionPass *llvm::createInstSimplifyLegacyPass() { + return new InstSimplifyLegacyPass(); } -PreservedAnalyses InstSimplifierPass::run(Function &F, - FunctionAnalysisManager &AM) { +PreservedAnalyses InstSimplifyPass::run(Function &F, + FunctionAnalysisManager &AM) { auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp index 1476f7850cf0..1d66472f93c8 100644 --- a/lib/Transforms/Scalar/JumpThreading.cpp +++ b/lib/Transforms/Scalar/JumpThreading.cpp @@ -30,6 +30,7 @@ #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -64,7 +65,6 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> @@ -131,10 +131,11 @@ namespace { bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { - if (PrintLVIAfterJumpThreading) - AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); AU.addRequired<LazyValueInfoWrapperPass>(); + AU.addPreserved<LazyValueInfoWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); } @@ -148,6 +149,7 @@ char JumpThreading::ID = 0; INITIALIZE_PASS_BEGIN(JumpThreading, "jump-threading", "Jump Threading", false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) @@ -164,7 +166,7 @@ JumpThreadingPass::JumpThreadingPass(int T) { } // Update branch probability information 
according to conditional -// branch probablity. This is usually made possible for cloned branches +// branch probability. This is usually made possible for cloned branches // in inline instances by the context specific profile in the caller. // For instance, // @@ -278,8 +280,12 @@ bool JumpThreading::runOnFunction(Function &F) { if (skipFunction(F)) return false; auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + // Get DT analysis before LVI. When LVI is initialized it conditionally adds + // DT if it's available. + auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + DeferredDominance DDT(*DT); std::unique_ptr<BlockFrequencyInfo> BFI; std::unique_ptr<BranchProbabilityInfo> BPI; bool HasProfileData = F.hasProfileData(); @@ -289,12 +295,11 @@ bool JumpThreading::runOnFunction(Function &F) { BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); } - bool Changed = Impl.runImpl(F, TLI, LVI, AA, HasProfileData, std::move(BFI), - std::move(BPI)); + bool Changed = Impl.runImpl(F, TLI, LVI, AA, &DDT, HasProfileData, + std::move(BFI), std::move(BPI)); if (PrintLVIAfterJumpThreading) { dbgs() << "LVI for function '" << F.getName() << "':\n"; - LVI->printLVI(F, getAnalysis<DominatorTreeWrapperPass>().getDomTree(), - dbgs()); + LVI->printLVI(F, *DT, dbgs()); } return Changed; } @@ -302,8 +307,12 @@ bool JumpThreading::runOnFunction(Function &F) { PreservedAnalyses JumpThreadingPass::run(Function &F, FunctionAnalysisManager &AM) { auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + // Get DT analysis before LVI. When LVI is initialized it conditionally adds + // DT if it's available. + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &LVI = AM.getResult<LazyValueAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); + DeferredDominance DDT(DT); std::unique_ptr<BlockFrequencyInfo> BFI; std::unique_ptr<BranchProbabilityInfo> BPI; @@ -313,25 +322,28 @@ PreservedAnalyses JumpThreadingPass::run(Function &F, BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); } - bool Changed = runImpl(F, &TLI, &LVI, &AA, HasProfileData, std::move(BFI), - std::move(BPI)); + bool Changed = runImpl(F, &TLI, &LVI, &AA, &DDT, HasProfileData, + std::move(BFI), std::move(BPI)); if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserve<GlobalsAA>(); + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<LazyValueAnalysis>(); return PA; } bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, LazyValueInfo *LVI_, AliasAnalysis *AA_, - bool HasProfileData_, + DeferredDominance *DDT_, bool HasProfileData_, std::unique_ptr<BlockFrequencyInfo> BFI_, std::unique_ptr<BranchProbabilityInfo> BPI_) { - DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n"); + LLVM_DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n"); TLI = TLI_; LVI = LVI_; AA = AA_; + DDT = DDT_; BFI.reset(); BPI.reset(); // When profile data is available, we need to update edge weights after @@ -345,69 +357,66 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, BFI = std::move(BFI_); } - // Remove unreachable blocks from function as they may result in infinite - // loop. We do threading if we found something profitable. Jump threading a - // branch can create other opportunities. If these opportunities form a cycle - // i.e. if any jump threading is undoing previous threading in the path, then - // we will loop forever. 
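Several hunks in this function and below queue dominator-tree changes instead of recomputing the tree eagerly. A distilled sketch of the pattern, assuming the DeferredDominance helper this patch relies on is available from llvm/IR/Dominators.h (block names hypothetical):

    #include "llvm/IR/Dominators.h"

    #include <vector>

    using namespace llvm;

    // Rewire the CFG edge Pred->BB to Pred->NewBB and queue the matching
    // dominator updates; DeferredDominance applies them lazily on flush().
    static void recordRewiredEdge(DeferredDominance &DDT, BasicBlock *Pred,
                                  BasicBlock *BB, BasicBlock *NewBB) {
      std::vector<DominatorTree::UpdateType> Updates;
      Updates.push_back({DominatorTree::Insert, Pred, NewBB});
      Updates.push_back({DominatorTree::Delete, Pred, BB});
      DDT.applyUpdates(Updates);
    }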
We take care of this issue by not jump threading for - // back edges. This works for normal cases but not for unreachable blocks as - // they may have cycle with no back edge. - bool EverChanged = false; - EverChanged |= removeUnreachableBlocks(F, LVI); + // JumpThreading must not processes blocks unreachable from entry. It's a + // waste of compute time and can potentially lead to hangs. + SmallPtrSet<BasicBlock *, 16> Unreachable; + DominatorTree &DT = DDT->flush(); + for (auto &BB : F) + if (!DT.isReachableFromEntry(&BB)) + Unreachable.insert(&BB); FindLoopHeaders(F); + bool EverChanged = false; bool Changed; do { Changed = false; - for (Function::iterator I = F.begin(), E = F.end(); I != E;) { - BasicBlock *BB = &*I; - // Thread all of the branches we can over this block. - while (ProcessBlock(BB)) + for (auto &BB : F) { + if (Unreachable.count(&BB)) + continue; + while (ProcessBlock(&BB)) // Thread all of the branches we can over BB. Changed = true; + // Stop processing BB if it's the entry or is now deleted. The following + // routines attempt to eliminate BB and locating a suitable replacement + // for the entry is non-trivial. + if (&BB == &F.getEntryBlock() || DDT->pendingDeletedBB(&BB)) + continue; - ++I; - - // If the block is trivially dead, zap it. This eliminates the successor - // edges which simplifies the CFG. - if (pred_empty(BB) && - BB != &BB->getParent()->getEntryBlock()) { - DEBUG(dbgs() << " JT: Deleting dead block '" << BB->getName() - << "' with terminator: " << *BB->getTerminator() << '\n'); - LoopHeaders.erase(BB); - LVI->eraseBlock(BB); - DeleteDeadBlock(BB); + if (pred_empty(&BB)) { + // When ProcessBlock makes BB unreachable it doesn't bother to fix up + // the instructions in it. We must remove BB to prevent invalid IR. + LLVM_DEBUG(dbgs() << " JT: Deleting dead block '" << BB.getName() + << "' with terminator: " << *BB.getTerminator() + << '\n'); + LoopHeaders.erase(&BB); + LVI->eraseBlock(&BB); + DeleteDeadBlock(&BB, DDT); Changed = true; continue; } - BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); - - // Can't thread an unconditional jump, but if the block is "almost - // empty", we can replace uses of it with uses of the successor and make - // this dead. - // We should not eliminate the loop header or latch either, because - // eliminating a loop header or latch might later prevent LoopSimplify - // from transforming nested loops into simplified form. We will rely on - // later passes in backend to clean up empty blocks. + // ProcessBlock doesn't thread BBs with unconditional TIs. However, if BB + // is "almost empty", we attempt to merge BB with its sole successor. + auto *BI = dyn_cast<BranchInst>(BB.getTerminator()); if (BI && BI->isUnconditional() && - BB != &BB->getParent()->getEntryBlock() && - // If the terminator is the only non-phi instruction, try to nuke it. - BB->getFirstNonPHIOrDbg()->isTerminator() && !LoopHeaders.count(BB) && - !LoopHeaders.count(BI->getSuccessor(0))) { - // FIXME: It is always conservatively correct to drop the info - // for a block even if it doesn't get erased. This isn't totally - // awesome, but it allows us to use AssertingVH to prevent nasty - // dangling pointer issues within LazyValueInfo. - LVI->eraseBlock(BB); - if (TryToSimplifyUncondBranchFromEmptyBlock(BB)) - Changed = true; + // The terminator must be the only non-phi instruction in BB. + BB.getFirstNonPHIOrDbg()->isTerminator() && + // Don't alter Loop headers and latches to ensure another pass can + // detect and transform nested loops later. 
+ !LoopHeaders.count(&BB) && !LoopHeaders.count(BI->getSuccessor(0)) && + TryToSimplifyUncondBranchFromEmptyBlock(&BB, DDT)) { + // BB is valid for cleanup here because we passed in DDT. F remains + // BB's parent until a DDT->flush() event. + LVI->eraseBlock(&BB); + Changed = true; } } EverChanged |= Changed; } while (Changed); LoopHeaders.clear(); + DDT->flush(); + LVI->enableDT(); return EverChanged; } @@ -600,6 +609,10 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( // "X < 4" and "X < 3" is known true but "X < 4" itself is not available. // Perhaps getConstantOnEdge should be smart enough to do this? + if (DDT->pending()) + LVI->disableDT(); + else + LVI->enableDT(); for (BasicBlock *P : predecessors(BB)) { // If the value is known by LazyValueInfo to be a constant in a // predecessor, use that information to try to thread this block. @@ -613,6 +626,10 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( /// If I is a PHI node, then we know the incoming values for any constants. if (PHINode *PN = dyn_cast<PHINode>(I)) { + if (DDT->pending()) + LVI->disableDT(); + else + LVI->enableDT(); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { Value *InVal = PN->getIncomingValue(i); if (Constant *KC = getKnownConstant(InVal, Preference)) { @@ -630,11 +647,9 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( } // Handle Cast instructions. Only see through Cast when the source operand is - // PHI or Cmp and the source type is i1 to save the compilation time. + // PHI or Cmp to save the compilation time. if (CastInst *CI = dyn_cast<CastInst>(I)) { Value *Source = CI->getOperand(0); - if (!Source->getType()->isIntegerTy(1)) - return false; if (!isa<PHINode>(Source) && !isa<CmpInst>(Source)) return false; ComputeValueKnownInPredecessors(Source, BB, Result, Preference, CxtI); @@ -738,20 +753,36 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( CmpInst::Predicate Pred = Cmp->getPredicate(); PHINode *PN = dyn_cast<PHINode>(CmpLHS); + if (!PN) + PN = dyn_cast<PHINode>(CmpRHS); if (PN && PN->getParent() == BB) { const DataLayout &DL = PN->getModule()->getDataLayout(); // We can do this simplification if any comparisons fold to true or false. // See if any do. + if (DDT->pending()) + LVI->disableDT(); + else + LVI->enableDT(); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { BasicBlock *PredBB = PN->getIncomingBlock(i); - Value *LHS = PN->getIncomingValue(i); - Value *RHS = CmpRHS->DoPHITranslation(BB, PredBB); - + Value *LHS, *RHS; + if (PN == CmpLHS) { + LHS = PN->getIncomingValue(i); + RHS = CmpRHS->DoPHITranslation(BB, PredBB); + } else { + LHS = CmpLHS->DoPHITranslation(BB, PredBB); + RHS = PN->getIncomingValue(i); + } Value *Res = SimplifyCmpInst(Pred, LHS, RHS, {DL}); if (!Res) { if (!isa<Constant>(RHS)) continue; + // getPredicateOnEdge call will make no sense if LHS is defined in BB. + auto LHSInst = dyn_cast<Instruction>(LHS); + if (LHSInst && LHSInst->getParent() == BB) + continue; + LazyValueInfo::Tristate ResT = LVI->getPredicateOnEdge(Pred, LHS, cast<Constant>(RHS), PredBB, BB, @@ -775,6 +806,10 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( if (!isa<Instruction>(CmpLHS) || cast<Instruction>(CmpLHS)->getParent() != BB) { + if (DDT->pending()) + LVI->disableDT(); + else + LVI->enableDT(); for (BasicBlock *P : predecessors(BB)) { // If the value is known by LazyValueInfo to be a constant in a // predecessor, use that information to try to thread this block. 
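The hunks that follow insert the same guard before every LazyValueInfo query; distilled into one helper for clarity (syncLVIWithDomTree is hypothetical, the calls are the patch's own):

    #include "llvm/Analysis/LazyValueInfo.h"
    #include "llvm/IR/Dominators.h"

    using namespace llvm;

    // LVI may consult the DominatorTree internally. While DeferredDominance
    // still holds queued CFG updates the tree is stale and must not be used.
    static void syncLVIWithDomTree(LazyValueInfo *LVI, DeferredDominance *DDT) {
      if (DDT->pending())
        LVI->disableDT();
      else
        LVI->enableDT();
    }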
@@ -803,6 +838,10 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( match(CmpLHS, m_Add(m_Value(AddLHS), m_ConstantInt(AddConst)))) { if (!isa<Instruction>(AddLHS) || cast<Instruction>(AddLHS)->getParent() != BB) { + if (DDT->pending()) + LVI->disableDT(); + else + LVI->enableDT(); for (BasicBlock *P : predecessors(BB)) { // If the value is known by LazyValueInfo to be a ConstantRange in // a predecessor, use that information to try to thread this @@ -884,6 +923,10 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( } // If all else fails, see if LVI can figure out a constant value for us. + if (DDT->pending()) + LVI->disableDT(); + else + LVI->enableDT(); Constant *CI = LVI->getConstant(V, BB, CxtI); if (Constant *KC = getKnownConstant(CI, Preference)) { for (BasicBlock *Pred : predecessors(BB)) @@ -903,10 +946,10 @@ static unsigned GetBestDestForJumpOnUndef(BasicBlock *BB) { unsigned MinSucc = 0; BasicBlock *TestBB = BBTerm->getSuccessor(MinSucc); // Compute the successor with the minimum number of predecessors. - unsigned MinNumPreds = std::distance(pred_begin(TestBB), pred_end(TestBB)); + unsigned MinNumPreds = pred_size(TestBB); for (unsigned i = 1, e = BBTerm->getNumSuccessors(); i != e; ++i) { TestBB = BBTerm->getSuccessor(i); - unsigned NumPreds = std::distance(pred_begin(TestBB), pred_end(TestBB)); + unsigned NumPreds = pred_size(TestBB); if (NumPreds < MinNumPreds) { MinSucc = i; MinNumPreds = NumPreds; @@ -931,8 +974,8 @@ static bool hasAddressTakenAndUsed(BasicBlock *BB) { bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { // If the block is trivially dead, just return and let the caller nuke it. // This simplifies other transformations. - if (pred_empty(BB) && - BB != &BB->getParent()->getEntryBlock()) + if (DDT->pendingDeletedBB(BB) || + (pred_empty(BB) && BB != &BB->getParent()->getEntryBlock())) return false; // If this block has a single predecessor, and if that pred has a single @@ -948,7 +991,7 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { LoopHeaders.insert(BB); LVI->eraseBlock(SinglePred); - MergeBasicBlockIntoOnlyPred(BB); + MergeBasicBlockIntoOnlyPred(BB, nullptr, DDT); // Now that BB is merged into SinglePred (i.e. SinglePred Code followed by // BB code within one basic block `BB`), we need to invalidate the LVI @@ -977,9 +1020,7 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { // Invalidate LVI information for BB if the LVI is not provably true for // all of BB. - if (any_of(*BB, [](Instruction &I) { - return !isGuaranteedToTransferExecutionToSuccessor(&I); - })) + if (!isGuaranteedToTransferExecutionToSuccessor(BB)) LVI->eraseBlock(BB); return true; } @@ -1031,18 +1072,23 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { // successors to branch to. Let GetBestDestForJumpOnUndef decide. if (isa<UndefValue>(Condition)) { unsigned BestSucc = GetBestDestForJumpOnUndef(BB); + std::vector<DominatorTree::UpdateType> Updates; // Fold the branch/switch. 
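The code that follows performs the fold described by the comment; the mechanics, condensed into a sketch without the dominator bookkeeping (foldToSingleSuccessor is hypothetical):

    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/Instructions.h"

    using namespace llvm;

    // Keep only the edge BB->Dest: every other successor is told it lost a
    // predecessor, then the old terminator is replaced by a plain branch.
    static void foldToSingleSuccessor(BasicBlock *BB, BasicBlock *Dest) {
      TerminatorInst *Term = BB->getTerminator();
      for (unsigned i = 0, e = Term->getNumSuccessors(); i != e; ++i)
        if (Term->getSuccessor(i) != Dest)
          Term->getSuccessor(i)->removePredecessor(BB, true);
      BranchInst::Create(Dest, Term); // New branch is inserted before Term.
      Term->eraseFromParent();
    }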
TerminatorInst *BBTerm = BB->getTerminator(); + Updates.reserve(BBTerm->getNumSuccessors()); for (unsigned i = 0, e = BBTerm->getNumSuccessors(); i != e; ++i) { if (i == BestSucc) continue; - BBTerm->getSuccessor(i)->removePredecessor(BB, true); + BasicBlock *Succ = BBTerm->getSuccessor(i); + Succ->removePredecessor(BB, true); + Updates.push_back({DominatorTree::Delete, BB, Succ}); } - DEBUG(dbgs() << " In block '" << BB->getName() - << "' folding undef terminator: " << *BBTerm << '\n'); + LLVM_DEBUG(dbgs() << " In block '" << BB->getName() + << "' folding undef terminator: " << *BBTerm << '\n'); BranchInst::Create(BBTerm->getSuccessor(BestSucc), BBTerm); BBTerm->eraseFromParent(); + DDT->applyUpdates(Updates); return true; } @@ -1050,10 +1096,11 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { // terminator to an unconditional branch. This can occur due to threading in // other blocks. if (getKnownConstant(Condition, Preference)) { - DEBUG(dbgs() << " In block '" << BB->getName() - << "' folding terminator: " << *BB->getTerminator() << '\n'); + LLVM_DEBUG(dbgs() << " In block '" << BB->getName() + << "' folding terminator: " << *BB->getTerminator() + << '\n'); ++NumFolds; - ConstantFoldTerminator(BB, true); + ConstantFoldTerminator(BB, true, nullptr, DDT); return true; } @@ -1080,13 +1127,18 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { // threading is concerned. assert(CondBr->isConditional() && "Threading on unconditional terminator"); + if (DDT->pending()) + LVI->disableDT(); + else + LVI->enableDT(); LazyValueInfo::Tristate Ret = LVI->getPredicateAt(CondCmp->getPredicate(), CondCmp->getOperand(0), CondConst, CondBr); if (Ret != LazyValueInfo::Unknown) { unsigned ToRemove = Ret == LazyValueInfo::True ? 1 : 0; unsigned ToKeep = Ret == LazyValueInfo::True ? 0 : 1; - CondBr->getSuccessor(ToRemove)->removePredecessor(BB, true); + BasicBlock *ToRemoveSucc = CondBr->getSuccessor(ToRemove); + ToRemoveSucc->removePredecessor(BB, true); BranchInst::Create(CondBr->getSuccessor(ToKeep), CondBr); CondBr->eraseFromParent(); if (CondCmp->use_empty()) @@ -1104,6 +1156,7 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { ConstantInt::getFalse(CondCmp->getType()); ReplaceFoldableUses(CondCmp, CI); } + DDT->deleteEdge(BB, ToRemoveSucc); return true; } @@ -1125,8 +1178,8 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { // TODO: There are other places where load PRE would be profitable, such as // more complex comparisons. - if (LoadInst *LI = dyn_cast<LoadInst>(SimplifyValue)) - if (SimplifyPartiallyRedundantLoad(LI)) + if (LoadInst *LoadI = dyn_cast<LoadInst>(SimplifyValue)) + if (SimplifyPartiallyRedundantLoad(LoadI)) return true; // Before threading, try to propagate profile data backwards: @@ -1182,9 +1235,12 @@ bool JumpThreadingPass::ProcessImpliedCondition(BasicBlock *BB) { Optional<bool> Implication = isImpliedCondition(PBI->getCondition(), Cond, DL, CondIsTrue); if (Implication) { - BI->getSuccessor(*Implication ? 1 : 0)->removePredecessor(BB); - BranchInst::Create(BI->getSuccessor(*Implication ? 0 : 1), BI); + BasicBlock *KeepSucc = BI->getSuccessor(*Implication ? 0 : 1); + BasicBlock *RemoveSucc = BI->getSuccessor(*Implication ? 
1 : 0); + RemoveSucc->removePredecessor(BB); + BranchInst::Create(KeepSucc, BI); BI->eraseFromParent(); + DDT->deleteEdge(BB, RemoveSucc); return true; } CurrentBB = CurrentPred; @@ -1202,17 +1258,17 @@ static bool isOpDefinedInBlock(Value *Op, BasicBlock *BB) { return false; } -/// SimplifyPartiallyRedundantLoad - If LI is an obviously partially redundant -/// load instruction, eliminate it by replacing it with a PHI node. This is an -/// important optimization that encourages jump threading, and needs to be run -/// interlaced with other jump threading tasks. -bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { +/// SimplifyPartiallyRedundantLoad - If LoadI is an obviously partially +/// redundant load instruction, eliminate it by replacing it with a PHI node. +/// This is an important optimization that encourages jump threading, and needs +/// to be run interlaced with other jump threading tasks. +bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LoadI) { // Don't hack volatile and ordered loads. - if (!LI->isUnordered()) return false; + if (!LoadI->isUnordered()) return false; // If the load is defined in a block with exactly one predecessor, it can't be // partially redundant. - BasicBlock *LoadBB = LI->getParent(); + BasicBlock *LoadBB = LoadI->getParent(); if (LoadBB->getSinglePredecessor()) return false; @@ -1222,7 +1278,7 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { if (LoadBB->isEHPad()) return false; - Value *LoadedPtr = LI->getOperand(0); + Value *LoadedPtr = LoadI->getOperand(0); // If the loaded operand is defined in the LoadBB and its not a phi, // it can't be available in predecessors. @@ -1231,26 +1287,27 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { // Scan a few instructions up from the load, to see if it is obviously live at // the entry to its block. - BasicBlock::iterator BBIt(LI); + BasicBlock::iterator BBIt(LoadI); bool IsLoadCSE; if (Value *AvailableVal = FindAvailableLoadedValue( - LI, LoadBB, BBIt, DefMaxInstsToScan, AA, &IsLoadCSE)) { + LoadI, LoadBB, BBIt, DefMaxInstsToScan, AA, &IsLoadCSE)) { // If the value of the load is locally available within the block, just use // it. This frequently occurs for reg2mem'd allocas. if (IsLoadCSE) { - LoadInst *NLI = cast<LoadInst>(AvailableVal); - combineMetadataForCSE(NLI, LI); + LoadInst *NLoadI = cast<LoadInst>(AvailableVal); + combineMetadataForCSE(NLoadI, LoadI); }; // If the returned value is the load itself, replace with an undef. This can // only happen in dead loops. - if (AvailableVal == LI) AvailableVal = UndefValue::get(LI->getType()); - if (AvailableVal->getType() != LI->getType()) - AvailableVal = - CastInst::CreateBitOrPointerCast(AvailableVal, LI->getType(), "", LI); - LI->replaceAllUsesWith(AvailableVal); - LI->eraseFromParent(); + if (AvailableVal == LoadI) + AvailableVal = UndefValue::get(LoadI->getType()); + if (AvailableVal->getType() != LoadI->getType()) + AvailableVal = CastInst::CreateBitOrPointerCast( + AvailableVal, LoadI->getType(), "", LoadI); + LoadI->replaceAllUsesWith(AvailableVal); + LoadI->eraseFromParent(); return true; } @@ -1263,7 +1320,7 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { // If all of the loads and stores that feed the value have the same AA tags, // then we can propagate them onto any newly inserted loads. 
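Carrying the AA tags mentioned above onto a cloned load is a two-call affair; a tiny sketch (copyAATags is hypothetical, both metadata calls appear in the hunks below):

    #include "llvm/IR/Instructions.h"
    #include "llvm/IR/Metadata.h"

    using namespace llvm;

    // Copy alias-analysis metadata from the original load onto its PRE'd
    // clone so the clone keeps the same aliasing guarantees.
    static void copyAATags(const LoadInst *OrigLoad, LoadInst *NewLoad) {
      AAMDNodes AATags;
      OrigLoad->getAAMetadata(AATags);
      if (AATags)
        NewLoad->setAAMetadata(AATags);
    }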
AAMDNodes AATags; - LI->getAAMetadata(AATags); + LoadI->getAAMetadata(AATags); SmallPtrSet<BasicBlock*, 8> PredsScanned; @@ -1285,16 +1342,17 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { Value *PredAvailable = nullptr; // NOTE: We don't CSE load that is volatile or anything stronger than // unordered, that should have been checked when we entered the function. - assert(LI->isUnordered() && "Attempting to CSE volatile or atomic loads"); + assert(LoadI->isUnordered() && + "Attempting to CSE volatile or atomic loads"); // If this is a load on a phi pointer, phi-translate it and search // for available load/store to the pointer in predecessors. Value *Ptr = LoadedPtr->DoPHITranslation(LoadBB, PredBB); PredAvailable = FindAvailablePtrLoadStore( - Ptr, LI->getType(), LI->isAtomic(), PredBB, BBIt, DefMaxInstsToScan, - AA, &IsLoadCSE, &NumScanedInst); + Ptr, LoadI->getType(), LoadI->isAtomic(), PredBB, BBIt, + DefMaxInstsToScan, AA, &IsLoadCSE, &NumScanedInst); // If PredBB has a single predecessor, continue scanning through the - // single precessor. + // single predecessor. BasicBlock *SinglePredBB = PredBB; while (!PredAvailable && SinglePredBB && BBIt == SinglePredBB->begin() && NumScanedInst < DefMaxInstsToScan) { @@ -1302,7 +1360,7 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { if (SinglePredBB) { BBIt = SinglePredBB->end(); PredAvailable = FindAvailablePtrLoadStore( - Ptr, LI->getType(), LI->isAtomic(), SinglePredBB, BBIt, + Ptr, LoadI->getType(), LoadI->isAtomic(), SinglePredBB, BBIt, (DefMaxInstsToScan - NumScanedInst), AA, &IsLoadCSE, &NumScanedInst); } @@ -1334,15 +1392,15 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { // If the value is unavailable in one of predecessors, we will end up // inserting a new instruction into them. It is only valid if all the - // instructions before LI are guaranteed to pass execution to its successor, - // or if LI is safe to speculate. + // instructions before LoadI are guaranteed to pass execution to its + // successor, or if LoadI is safe to speculate. // TODO: If this logic becomes more complex, and we will perform PRE insertion // farther than to a predecessor, we need to reuse the code from GVN's PRE. // It requires domination tree analysis, so for this simple case it is an // overkill. 
if (PredsScanned.size() != AvailablePreds.size() && - !isSafeToSpeculativelyExecute(LI)) - for (auto I = LoadBB->begin(); &*I != LI; ++I) + !isSafeToSpeculativelyExecute(LoadI)) + for (auto I = LoadBB->begin(); &*I != LoadI; ++I) if (!isGuaranteedToTransferExecutionToSuccessor(&*I)) return false; @@ -1381,11 +1439,12 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { if (UnavailablePred) { assert(UnavailablePred->getTerminator()->getNumSuccessors() == 1 && "Can't handle critical edge here!"); - LoadInst *NewVal = new LoadInst( - LoadedPtr->DoPHITranslation(LoadBB, UnavailablePred), - LI->getName() + ".pr", false, LI->getAlignment(), LI->getOrdering(), - LI->getSyncScopeID(), UnavailablePred->getTerminator()); - NewVal->setDebugLoc(LI->getDebugLoc()); + LoadInst *NewVal = + new LoadInst(LoadedPtr->DoPHITranslation(LoadBB, UnavailablePred), + LoadI->getName() + ".pr", false, LoadI->getAlignment(), + LoadI->getOrdering(), LoadI->getSyncScopeID(), + UnavailablePred->getTerminator()); + NewVal->setDebugLoc(LoadI->getDebugLoc()); if (AATags) NewVal->setAAMetadata(AATags); @@ -1398,10 +1457,10 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { // Create a PHI node at the start of the block for the PRE'd load value. pred_iterator PB = pred_begin(LoadBB), PE = pred_end(LoadBB); - PHINode *PN = PHINode::Create(LI->getType(), std::distance(PB, PE), "", + PHINode *PN = PHINode::Create(LoadI->getType(), std::distance(PB, PE), "", &LoadBB->front()); - PN->takeName(LI); - PN->setDebugLoc(LI->getDebugLoc()); + PN->takeName(LoadI); + PN->setDebugLoc(LoadI->getDebugLoc()); // Insert new entries into the PHI for each predecessor. A single block may // have multiple entries here. @@ -1419,19 +1478,19 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { // AvailablePreds vector as we go so that all of the PHI entries for this // predecessor use the same bitcast. Value *&PredV = I->second; - if (PredV->getType() != LI->getType()) - PredV = CastInst::CreateBitOrPointerCast(PredV, LI->getType(), "", + if (PredV->getType() != LoadI->getType()) + PredV = CastInst::CreateBitOrPointerCast(PredV, LoadI->getType(), "", P->getTerminator()); PN->addIncoming(PredV, I->first); } - for (LoadInst *PredLI : CSELoads) { - combineMetadataForCSE(PredLI, LI); + for (LoadInst *PredLoadI : CSELoads) { + combineMetadataForCSE(PredLoadI, LoadI); } - LI->replaceAllUsesWith(PN); - LI->eraseFromParent(); + LoadI->replaceAllUsesWith(PN); + LoadI->eraseFromParent(); return true; } @@ -1454,6 +1513,9 @@ FindMostPopularDest(BasicBlock *BB, if (PredToDest.second) DestPopularity[PredToDest.second]++; + if (DestPopularity.empty()) + return nullptr; + // Find the most popular dest. 
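FindMostPopularDest, referenced below, is simple vote counting. A standalone model of the logic, including the empty-map guard this patch adds (mostPopular and T are hypothetical stand-ins for the real function and BasicBlock *):

    #include <map>
    #include <utility>
    #include <vector>

    // Return the destination that appears most often, or a null value when
    // no predecessor has a known destination.
    template <typename T>
    T mostPopular(const std::vector<std::pair<T, T>> &PredToDest) {
      std::map<T, unsigned> Popularity;
      for (const auto &PD : PredToDest)
        if (PD.second)
          ++Popularity[PD.second];
      if (Popularity.empty())
        return T(); // Mirrors the new nullptr guard above.
      T Best = T();
      unsigned BestCount = 0;
      for (const auto &Entry : Popularity)
        if (Entry.second > BestCount) {
          Best = Entry.first;
          BestCount = Entry.second;
        }
      return Best;
    }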
DenseMap<BasicBlock*, unsigned>::iterator DPI = DestPopularity.begin(); BasicBlock *MostPopularDest = DPI->first; @@ -1513,12 +1575,12 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, assert(!PredValues.empty() && "ComputeValueKnownInPredecessors returned true with no values"); - DEBUG(dbgs() << "IN BB: " << *BB; - for (const auto &PredValue : PredValues) { - dbgs() << " BB '" << BB->getName() << "': FOUND condition = " - << *PredValue.first - << " for pred '" << PredValue.second->getName() << "'.\n"; - }); + LLVM_DEBUG(dbgs() << "IN BB: " << *BB; + for (const auto &PredValue : PredValues) { + dbgs() << " BB '" << BB->getName() + << "': FOUND condition = " << *PredValue.first + << " for pred '" << PredValue.second->getName() << "'.\n"; + }); // Decide what we want to thread through. Convert our list of known values to // a list of known destinations for each pred. This also discards duplicate @@ -1588,20 +1650,24 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, // not thread. By doing so, we do not need to duplicate the current block and // also miss potential opportunities in case we dont/cant duplicate. if (OnlyDest && OnlyDest != MultipleDestSentinel) { - if (PredWithKnownDest == - (size_t)std::distance(pred_begin(BB), pred_end(BB))) { + if (PredWithKnownDest == (size_t)pred_size(BB)) { bool SeenFirstBranchToOnlyDest = false; + std::vector <DominatorTree::UpdateType> Updates; + Updates.reserve(BB->getTerminator()->getNumSuccessors() - 1); for (BasicBlock *SuccBB : successors(BB)) { - if (SuccBB == OnlyDest && !SeenFirstBranchToOnlyDest) + if (SuccBB == OnlyDest && !SeenFirstBranchToOnlyDest) { SeenFirstBranchToOnlyDest = true; // Don't modify the first branch. - else + } else { SuccBB->removePredecessor(BB, true); // This is unreachable successor. + Updates.push_back({DominatorTree::Delete, BB, SuccBB}); + } } // Finally update the terminator. TerminatorInst *Term = BB->getTerminator(); BranchInst::Create(OnlyDest, Term); Term->eraseFromParent(); + DDT->applyUpdates(Updates); // If the condition is now dead due to the removal of the old terminator, // erase it. @@ -1629,8 +1695,20 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, // threadable destination (the common case) we can avoid this. BasicBlock *MostPopularDest = OnlyDest; - if (MostPopularDest == MultipleDestSentinel) + if (MostPopularDest == MultipleDestSentinel) { + // Remove any loop headers from the Dest list, ThreadEdge conservatively + // won't process them, but we might have other destination that are eligible + // and we still want to process. + erase_if(PredToDestList, + [&](const std::pair<BasicBlock *, BasicBlock *> &PredToDest) { + return LoopHeaders.count(PredToDest.second) != 0; + }); + + if (PredToDestList.empty()) + return false; + MostPopularDest = FindMostPopularDest(BB, PredToDestList); + } // Now that we know what the most popular destination is, factor all // predecessors that will jump to it into a single predecessor. @@ -1800,11 +1878,10 @@ static void AddPHINodeEntriesForMappedBlock(BasicBlock *PHIBB, BasicBlock *OldPred, BasicBlock *NewPred, DenseMap<Instruction*, Value*> &ValueMap) { - for (BasicBlock::iterator PNI = PHIBB->begin(); - PHINode *PN = dyn_cast<PHINode>(PNI); ++PNI) { + for (PHINode &PN : PHIBB->phis()) { // Ok, we have a PHI node. Figure out what the incoming value was for the // DestBlock. 
- Value *IV = PN->getIncomingValueForBlock(OldPred); + Value *IV = PN.getIncomingValueForBlock(OldPred); // Remap the value if necessary. if (Instruction *Inst = dyn_cast<Instruction>(IV)) { @@ -1813,7 +1890,7 @@ static void AddPHINodeEntriesForMappedBlock(BasicBlock *PHIBB, IV = I->second; } - PN->addIncoming(IV, NewPred); + PN.addIncoming(IV, NewPred); } } @@ -1825,15 +1902,15 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, BasicBlock *SuccBB) { // If threading to the same block as we come from, we would infinite loop. if (SuccBB == BB) { - DEBUG(dbgs() << " Not threading across BB '" << BB->getName() - << "' - would thread to self!\n"); + LLVM_DEBUG(dbgs() << " Not threading across BB '" << BB->getName() + << "' - would thread to self!\n"); return false; } // If threading this would thread across a loop header, don't thread the edge. // See the comments above FindLoopHeaders for justifications and caveats. if (LoopHeaders.count(BB) || LoopHeaders.count(SuccBB)) { - DEBUG({ + LLVM_DEBUG({ bool BBIsHeader = LoopHeaders.count(BB); bool SuccIsHeader = LoopHeaders.count(SuccBB); dbgs() << " Not threading across " @@ -1847,8 +1924,8 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, unsigned JumpThreadCost = getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold); if (JumpThreadCost > BBDupThreshold) { - DEBUG(dbgs() << " Not threading BB '" << BB->getName() - << "' - Cost is too high: " << JumpThreadCost << "\n"); + LLVM_DEBUG(dbgs() << " Not threading BB '" << BB->getName() + << "' - Cost is too high: " << JumpThreadCost << "\n"); return false; } @@ -1857,17 +1934,21 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, if (PredBBs.size() == 1) PredBB = PredBBs[0]; else { - DEBUG(dbgs() << " Factoring out " << PredBBs.size() - << " common predecessors.\n"); + LLVM_DEBUG(dbgs() << " Factoring out " << PredBBs.size() + << " common predecessors.\n"); PredBB = SplitBlockPreds(BB, PredBBs, ".thr_comm"); } // And finally, do it! - DEBUG(dbgs() << " Threading edge from '" << PredBB->getName() << "' to '" - << SuccBB->getName() << "' with cost: " << JumpThreadCost - << ", across block:\n " - << *BB << "\n"); - + LLVM_DEBUG(dbgs() << " Threading edge from '" << PredBB->getName() + << "' to '" << SuccBB->getName() + << "' with cost: " << JumpThreadCost + << ", across block:\n " << *BB << "\n"); + + if (DDT->pending()) + LVI->disableDT(); + else + LVI->enableDT(); LVI->threadEdge(PredBB, BB, SuccBB); // We are going to have to map operands from the original BB block to the new @@ -1917,15 +1998,32 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, // PHI nodes for NewBB now. AddPHINodeEntriesForMappedBlock(SuccBB, BB, NewBB, ValueMapping); + // Update the terminator of PredBB to jump to NewBB instead of BB. This + // eliminates predecessors from BB, which requires us to simplify any PHI + // nodes in BB. + TerminatorInst *PredTerm = PredBB->getTerminator(); + for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) + if (PredTerm->getSuccessor(i) == BB) { + BB->removePredecessor(PredBB, true); + PredTerm->setSuccessor(i, NewBB); + } + + // Enqueue required DT updates. + DDT->applyUpdates({{DominatorTree::Insert, NewBB, SuccBB}, + {DominatorTree::Insert, PredBB, NewBB}, + {DominatorTree::Delete, PredBB, BB}}); + // If there were values defined in BB that are used outside the block, then we // now have to update all uses of the value to use either the original value, // the cloned value, or some PHI derived value. 
This can require arbitrary // PHI insertion, of which we are prepared to do, clean these up now. SSAUpdater SSAUpdate; SmallVector<Use*, 16> UsesToRename; + for (Instruction &I : *BB) { - // Scan all uses of this instruction to see if it is used outside of its - // block, and if so, record them in UsesToRename. + // Scan all uses of this instruction to see if their uses are no longer + // dominated by the previous def and if so, record them in UsesToRename. + // Also, skip phi operands from PredBB - we'll remove them anyway. for (Use &U : I.uses()) { Instruction *User = cast<Instruction>(U.getUser()); if (PHINode *UserPN = dyn_cast<PHINode>(User)) { @@ -1940,8 +2038,7 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, // If there are no uses outside the block, we're done with this instruction. if (UsesToRename.empty()) continue; - - DEBUG(dbgs() << "JT: Renaming non-local uses of: " << I << "\n"); + LLVM_DEBUG(dbgs() << "JT: Renaming non-local uses of: " << I << "\n"); // We found a use of I outside of BB. Rename all uses of I that are outside // its block to be uses of the appropriate PHI node etc. See ValuesInBlocks @@ -1952,19 +2049,9 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, while (!UsesToRename.empty()) SSAUpdate.RewriteUse(*UsesToRename.pop_back_val()); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << "\n"); } - // Ok, NewBB is good to go. Update the terminator of PredBB to jump to - // NewBB instead of BB. This eliminates predecessors from BB, which requires - // us to simplify any PHI nodes in BB. - TerminatorInst *PredTerm = PredBB->getTerminator(); - for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) - if (PredTerm->getSuccessor(i) == BB) { - BB->removePredecessor(PredBB, true); - PredTerm->setSuccessor(i, NewBB); - } - // At this point, the IR is fully up to date and consistent. Do a quick scan // over the new instructions and zap any that are constants or dead. This // frequently happens because of phi translation. @@ -1984,20 +2071,42 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, BasicBlock *JumpThreadingPass::SplitBlockPreds(BasicBlock *BB, ArrayRef<BasicBlock *> Preds, const char *Suffix) { + SmallVector<BasicBlock *, 2> NewBBs; + // Collect the frequencies of all predecessors of BB, which will be used to - // update the edge weight on BB->SuccBB. - BlockFrequency PredBBFreq(0); + // update the edge weight of the result of splitting predecessors. + DenseMap<BasicBlock *, BlockFrequency> FreqMap; if (HasProfileData) for (auto Pred : Preds) - PredBBFreq += BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, BB); + FreqMap.insert(std::make_pair( + Pred, BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, BB))); + + // In the case when BB is a LandingPad block we create 2 new predecessors + // instead of just one. 
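The branch the next hunk introduces exists because landing pads cannot be split like ordinary blocks. A condensed sketch of that choice (splitPreds is hypothetical; both split utilities are existing BasicBlockUtils APIs):

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"
    #include "llvm/Transforms/Utils/BasicBlockUtils.h"

    #include <string>

    using namespace llvm;

    // Landing pads need the dedicated splitter, which can produce two new
    // predecessor blocks; ordinary blocks get exactly one.
    static void splitPreds(BasicBlock *BB, ArrayRef<BasicBlock *> Preds,
                           const char *Suffix,
                           SmallVectorImpl<BasicBlock *> &NewBBs) {
      if (BB->isLandingPad()) {
        std::string NewName = std::string(Suffix) + ".split-lp";
        SplitLandingPadPredecessors(BB, Preds, Suffix, NewName.c_str(), NewBBs);
      } else {
        NewBBs.push_back(SplitBlockPredecessors(BB, Preds, Suffix));
      }
    }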
+ if (BB->isLandingPad()) { + std::string NewName = std::string(Suffix) + ".split-lp"; + SplitLandingPadPredecessors(BB, Preds, Suffix, NewName.c_str(), NewBBs); + } else { + NewBBs.push_back(SplitBlockPredecessors(BB, Preds, Suffix)); + } - BasicBlock *PredBB = SplitBlockPredecessors(BB, Preds, Suffix); + std::vector<DominatorTree::UpdateType> Updates; + Updates.reserve((2 * Preds.size()) + NewBBs.size()); + for (auto NewBB : NewBBs) { + BlockFrequency NewBBFreq(0); + Updates.push_back({DominatorTree::Insert, NewBB, BB}); + for (auto Pred : predecessors(NewBB)) { + Updates.push_back({DominatorTree::Delete, Pred, BB}); + Updates.push_back({DominatorTree::Insert, Pred, NewBB}); + if (HasProfileData) // Update frequencies between Pred -> NewBB. + NewBBFreq += FreqMap.lookup(Pred); + } + if (HasProfileData) // Apply the summed frequency to NewBB. + BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); + } - // Set the block frequency of the newly created PredBB, which is the sum of - // frequencies of Preds. - if (HasProfileData) - BFI->setBlockFreq(PredBB, PredBBFreq.getFrequency()); - return PredBB; + DDT->applyUpdates(Updates); + return NewBBs[0]; } bool JumpThreadingPass::doesBlockHaveProfileData(BasicBlock *BB) { @@ -2126,42 +2235,49 @@ bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred( // cause us to transform this into an irreducible loop, don't do this. // See the comments above FindLoopHeaders for justifications and caveats. if (LoopHeaders.count(BB)) { - DEBUG(dbgs() << " Not duplicating loop header '" << BB->getName() - << "' into predecessor block '" << PredBBs[0]->getName() - << "' - it might create an irreducible loop!\n"); + LLVM_DEBUG(dbgs() << " Not duplicating loop header '" << BB->getName() + << "' into predecessor block '" << PredBBs[0]->getName() + << "' - it might create an irreducible loop!\n"); return false; } unsigned DuplicationCost = getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold); if (DuplicationCost > BBDupThreshold) { - DEBUG(dbgs() << " Not duplicating BB '" << BB->getName() - << "' - Cost is too high: " << DuplicationCost << "\n"); + LLVM_DEBUG(dbgs() << " Not duplicating BB '" << BB->getName() + << "' - Cost is too high: " << DuplicationCost << "\n"); return false; } // And finally, do it! Start by factoring the predecessors if needed. + std::vector<DominatorTree::UpdateType> Updates; BasicBlock *PredBB; if (PredBBs.size() == 1) PredBB = PredBBs[0]; else { - DEBUG(dbgs() << " Factoring out " << PredBBs.size() - << " common predecessors.\n"); + LLVM_DEBUG(dbgs() << " Factoring out " << PredBBs.size() + << " common predecessors.\n"); PredBB = SplitBlockPreds(BB, PredBBs, ".thr_comm"); } + Updates.push_back({DominatorTree::Delete, PredBB, BB}); // Okay, we decided to do this! Clone all the instructions in BB onto the end // of PredBB. - DEBUG(dbgs() << " Duplicating block '" << BB->getName() << "' into end of '" - << PredBB->getName() << "' to eliminate branch on phi. Cost: " - << DuplicationCost << " block is:" << *BB << "\n"); + LLVM_DEBUG(dbgs() << " Duplicating block '" << BB->getName() + << "' into end of '" << PredBB->getName() + << "' to eliminate branch on phi. Cost: " + << DuplicationCost << " block is:" << *BB << "\n"); // Unless PredBB ends with an unconditional branch, split the edge so that we // can just clone the bits from BB into the end of the new PredBB. 
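Splitting the PredBB->BB edge, as the code below does, implies exactly three dominator updates. A small sketch combining the split with its bookkeeping (splitEdgeAndRecord is hypothetical; the update triple matches the hunk below):

    #include "llvm/IR/Dominators.h"
    #include "llvm/Transforms/Utils/BasicBlockUtils.h"

    #include <vector>

    using namespace llvm;

    // Split Pred->BB, returning the new middle block and queueing the three
    // edge updates the CFG change implies.
    static BasicBlock *
    splitEdgeAndRecord(BasicBlock *Pred, BasicBlock *BB,
                       std::vector<DominatorTree::UpdateType> &Updates) {
      BasicBlock *Mid = SplitEdge(Pred, BB);
      Updates.push_back({DominatorTree::Insert, Pred, Mid});
      Updates.push_back({DominatorTree::Insert, Mid, BB});
      Updates.push_back({DominatorTree::Delete, Pred, BB});
      return Mid;
    }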
BranchInst *OldPredBranch = dyn_cast<BranchInst>(PredBB->getTerminator()); if (!OldPredBranch || !OldPredBranch->isUnconditional()) { - PredBB = SplitEdge(PredBB, BB); + BasicBlock *OldPredBB = PredBB; + PredBB = SplitEdge(OldPredBB, BB); + Updates.push_back({DominatorTree::Insert, OldPredBB, PredBB}); + Updates.push_back({DominatorTree::Insert, PredBB, BB}); + Updates.push_back({DominatorTree::Delete, OldPredBB, BB}); OldPredBranch = cast<BranchInst>(PredBB->getTerminator()); } @@ -2203,6 +2319,10 @@ bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred( // Otherwise, insert the new instruction into the block. New->setName(BI->getName()); PredBB->getInstList().insert(OldPredBranch->getIterator(), New); + // Update Dominance from simplified New instruction operands. + for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i) + if (BasicBlock *SuccBB = dyn_cast<BasicBlock>(New->getOperand(i))) + Updates.push_back({DominatorTree::Insert, PredBB, SuccBB}); } } @@ -2238,7 +2358,7 @@ bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred( if (UsesToRename.empty()) continue; - DEBUG(dbgs() << "JT: Renaming non-local uses of: " << I << "\n"); + LLVM_DEBUG(dbgs() << "JT: Renaming non-local uses of: " << I << "\n"); // We found a use of I outside of BB. Rename all uses of I that are outside // its block to be uses of the appropriate PHI node etc. See ValuesInBlocks @@ -2249,7 +2369,7 @@ bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred( while (!UsesToRename.empty()) SSAUpdate.RewriteUse(*UsesToRename.pop_back_val()); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << "\n"); } // PredBB no longer jumps to BB, remove entries in the PHI node for the edge @@ -2258,6 +2378,7 @@ bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred( // Remove the unconditional branch at the end of the PredBB block. OldPredBranch->eraseFromParent(); + DDT->applyUpdates(Updates); ++NumDupes; return true; @@ -2300,6 +2421,10 @@ bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) { // Now check if one of the select values would allow us to constant fold the // terminator in BB. We don't do the transform if both sides fold, those // cases will be threaded in any case. + if (DDT->pending()) + LVI->disableDT(); + else + LVI->enableDT(); LazyValueInfo::Tristate LHSFolds = LVI->getPredicateOnEdge(CondCmp->getPredicate(), SI->getOperand(1), CondRHS, Pred, BB, CondCmp); @@ -2330,6 +2455,8 @@ bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) { // The select is now dead. SI->eraseFromParent(); + DDT->applyUpdates({{DominatorTree::Insert, NewBB, BB}, + {DominatorTree::Insert, Pred, NewBB}}); // Update any other PHI nodes in BB. for (BasicBlock::iterator BI = BB->begin(); PHINode *Phi = dyn_cast<PHINode>(BI); ++BI) @@ -2395,7 +2522,7 @@ bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) { break; } } else if (SelectInst *SelectI = dyn_cast<SelectInst>(U.getUser())) { - // Look for a Select in BB that uses PN as condtion. + // Look for a Select in BB that uses PN as condition. if (isUnfoldCandidate(SelectI, U.get())) { SI = SelectI; break; @@ -2408,11 +2535,25 @@ bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) { // Expand the select. 
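// [Editorial aside, not part of the patch] What "expanding" the select below
// means, as a hedged before/after IR sketch (names illustrative). The select
// collapses into a conditional branch plus a PHI, which is exactly what
// SplitBlockAndInsertIfThen plus the PHINode built below produce:
//
//   before:
//     %v = select i1 %c, i32 %a, i32 %b
//
//   after:
//     bb:                              ; original block, now ends in a branch
//       br i1 %c, label %bb.then, label %bb.split
//     bb.then:
//       br label %bb.split
//     bb.split:
//       %v = phi i32 [ %a, %bb.then ], [ %b, %bb ]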
TerminatorInst *Term = SplitBlockAndInsertIfThen(SI->getCondition(), SI, false); + BasicBlock *SplitBB = SI->getParent(); + BasicBlock *NewBB = Term->getParent(); PHINode *NewPN = PHINode::Create(SI->getType(), 2, "", SI); NewPN->addIncoming(SI->getTrueValue(), Term->getParent()); NewPN->addIncoming(SI->getFalseValue(), BB); SI->replaceAllUsesWith(NewPN); SI->eraseFromParent(); + // NewBB and SplitBB are newly created blocks which require insertion. + std::vector<DominatorTree::UpdateType> Updates; + Updates.reserve((2 * SplitBB->getTerminator()->getNumSuccessors()) + 3); + Updates.push_back({DominatorTree::Insert, BB, SplitBB}); + Updates.push_back({DominatorTree::Insert, BB, NewBB}); + Updates.push_back({DominatorTree::Insert, NewBB, SplitBB}); + // BB's successors were moved to SplitBB, update DDT accordingly. + for (auto *Succ : successors(SplitBB)) { + Updates.push_back({DominatorTree::Delete, BB, Succ}); + Updates.push_back({DominatorTree::Insert, SplitBB, Succ}); + } + DDT->applyUpdates(Updates); return true; } return false; @@ -2499,8 +2640,8 @@ bool JumpThreadingPass::ThreadGuard(BasicBlock *BB, IntrinsicInst *Guard, if (!TrueDestIsSafe && !FalseDestIsSafe) return false; - BasicBlock *UnguardedBlock = TrueDestIsSafe ? TrueDest : FalseDest; - BasicBlock *GuardedBlock = FalseDestIsSafe ? TrueDest : FalseDest; + BasicBlock *PredUnguardedBlock = TrueDestIsSafe ? TrueDest : FalseDest; + BasicBlock *PredGuardedBlock = FalseDestIsSafe ? TrueDest : FalseDest; ValueToValueMapTy UnguardedMapping, GuardedMapping; Instruction *AfterGuard = Guard->getNextNode(); @@ -2509,18 +2650,29 @@ bool JumpThreadingPass::ThreadGuard(BasicBlock *BB, IntrinsicInst *Guard, return false; // Duplicate all instructions before the guard and the guard itself to the // branch where implication is not proved. - GuardedBlock = DuplicateInstructionsInSplitBetween( - BB, GuardedBlock, AfterGuard, GuardedMapping); + BasicBlock *GuardedBlock = DuplicateInstructionsInSplitBetween( + BB, PredGuardedBlock, AfterGuard, GuardedMapping); assert(GuardedBlock && "Could not create the guarded block?"); // Duplicate all instructions before the guard in the unguarded branch. // Since we have successfully duplicated the guarded block and this block // has fewer instructions, we expect it to succeed. - UnguardedBlock = DuplicateInstructionsInSplitBetween(BB, UnguardedBlock, - Guard, UnguardedMapping); + BasicBlock *UnguardedBlock = DuplicateInstructionsInSplitBetween( + BB, PredUnguardedBlock, Guard, UnguardedMapping); assert(UnguardedBlock && "Could not create the unguarded block?"); - DEBUG(dbgs() << "Moved guard " << *Guard << " to block " - << GuardedBlock->getName() << "\n"); - + LLVM_DEBUG(dbgs() << "Moved guard " << *Guard << " to block " + << GuardedBlock->getName() << "\n"); + // DuplicateInstructionsInSplitBetween inserts a new block "BB.split" between + // PredBB and BB. We need to perform two inserts and one delete for each of + // the above calls to update Dominators. + DDT->applyUpdates( + {// Guarded block split. + {DominatorTree::Delete, PredGuardedBlock, BB}, + {DominatorTree::Insert, PredGuardedBlock, GuardedBlock}, + {DominatorTree::Insert, GuardedBlock, BB}, + // Unguarded block split. + {DominatorTree::Delete, PredUnguardedBlock, BB}, + {DominatorTree::Insert, PredUnguardedBlock, UnguardedBlock}, + {DominatorTree::Insert, UnguardedBlock, BB}}); // Some instructions before the guard may still have uses. For them, we need // to create Phi nodes merging their copies in both guarded and unguarded // branches. 
Those instructions that have no uses can be just removed. diff --git a/lib/Transforms/Scalar/LICM.cpp b/lib/Transforms/Scalar/LICM.cpp index 4ea935793b80..ff66632f0391 100644 --- a/lib/Transforms/Scalar/LICM.cpp +++ b/lib/Transforms/Scalar/LICM.cpp @@ -47,6 +47,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" @@ -64,7 +65,6 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include <algorithm> @@ -97,7 +97,7 @@ static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE); static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, - const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, + const Loop *CurLoop, LoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE, bool FreeInLoop); static bool isSafeToExecuteUnconditionally(Instruction &Inst, const DominatorTree *DT, @@ -170,7 +170,8 @@ struct LegacyLICMPass : public LoopPass { /// loop preheaders be inserted into the CFG... /// void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); if (EnableMSSALoopDependency) AU.addRequired<MemorySSAWrapperPass>(); @@ -220,7 +221,10 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, return PreservedAnalyses::all(); auto PA = getLoopPassPreservedAnalyses(); - PA.preserveSet<CFGAnalyses>(); + + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<LoopAnalysis>(); + return PA; } @@ -392,7 +396,8 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, // If the instruction is dead, we would try to sink it because it isn't // used in the loop, instead, just delete it. if (isInstructionTriviallyDead(&I, TLI)) { - DEBUG(dbgs() << "LICM deleting dead inst: " << I << '\n'); + LLVM_DEBUG(dbgs() << "LICM deleting dead inst: " << I << '\n'); + salvageDebugInfo(I); ++II; CurAST->deleteValue(&I); I.eraseFromParent(); @@ -445,101 +450,78 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, BasicBlock *BB = DTN->getBlock(); // Only need to process the contents of this block if it is not part of a // subloop (which would already have been processed). - if (!inSubLoop(BB, CurLoop, LI)) - for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E;) { - Instruction &I = *II++; - // Try constant folding this instruction. If all the operands are - // constants, it is technically hoistable, but it would be better to - // just fold it. - if (Constant *C = ConstantFoldInstruction( - &I, I.getModule()->getDataLayout(), TLI)) { - DEBUG(dbgs() << "LICM folding inst: " << I << " --> " << *C << '\n'); - CurAST->copyValue(&I, C); - I.replaceAllUsesWith(C); - if (isInstructionTriviallyDead(&I, TLI)) { - CurAST->deleteValue(&I); - I.eraseFromParent(); - } - Changed = true; - continue; - } + if (inSubLoop(BB, CurLoop, LI)) + continue; - // Attempt to remove floating point division out of the loop by - // converting it to a reciprocal multiplication. 
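// [Editorial aside, not part of the patch] The reciprocal rewrite this hunk
// relocates (its logic is unchanged), as a hedged IR sketch: with the 'arcp'
// fast-math flag (I.hasAllowReciprocal()) and a loop-invariant divisor %d,
//
//   %q = fdiv arcp double %x, %d
//
// is rewritten into a division that can be hoisted to the preheader plus a
// cheap in-loop multiply:
//
//   %r = fdiv arcp double 1.000000e+00, %d   ; loop-invariant, hoisted
//   %q = fmul arcp double %x, %r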
- if (I.getOpcode() == Instruction::FDiv && - CurLoop->isLoopInvariant(I.getOperand(1)) && - I.hasAllowReciprocal()) { - auto Divisor = I.getOperand(1); - auto One = llvm::ConstantFP::get(Divisor->getType(), 1.0); - auto ReciprocalDivisor = BinaryOperator::CreateFDiv(One, Divisor); - ReciprocalDivisor->setFastMathFlags(I.getFastMathFlags()); - ReciprocalDivisor->insertBefore(&I); - - auto Product = - BinaryOperator::CreateFMul(I.getOperand(0), ReciprocalDivisor); - Product->setFastMathFlags(I.getFastMathFlags()); - Product->insertAfter(&I); - I.replaceAllUsesWith(Product); + // Keep track of whether the prefix of instructions visited so far are such + // that the next instruction visited is guaranteed to execute if the loop + // is entered. + bool IsMustExecute = CurLoop->getHeader() == BB; + + for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E;) { + Instruction &I = *II++; + // Try constant folding this instruction. If all the operands are + // constants, it is technically hoistable, but it would be better to + // just fold it. + if (Constant *C = ConstantFoldInstruction( + &I, I.getModule()->getDataLayout(), TLI)) { + LLVM_DEBUG(dbgs() << "LICM folding inst: " << I << " --> " << *C + << '\n'); + CurAST->copyValue(&I, C); + I.replaceAllUsesWith(C); + if (isInstructionTriviallyDead(&I, TLI)) { + CurAST->deleteValue(&I); I.eraseFromParent(); - - hoist(*ReciprocalDivisor, DT, CurLoop, SafetyInfo, ORE); - Changed = true; - continue; } + Changed = true; + continue; + } + + // Try hoisting the instruction out to the preheader. We can only do + // this if all of the operands of the instruction are loop invariant and + // if it is safe to hoist the instruction. + // + if (CurLoop->hasLoopInvariantOperands(&I) && + canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, SafetyInfo, ORE) && + (IsMustExecute || + isSafeToExecuteUnconditionally( + I, DT, CurLoop, SafetyInfo, ORE, + CurLoop->getLoopPreheader()->getTerminator()))) { + Changed |= hoist(I, DT, CurLoop, SafetyInfo, ORE); + continue; + } - // Try hoisting the instruction out to the preheader. We can only do - // this if all of the operands of the instruction are loop invariant and - // if it is safe to hoist the instruction. - // - if (CurLoop->hasLoopInvariantOperands(&I) && - canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, SafetyInfo, ORE) && - isSafeToExecuteUnconditionally( - I, DT, CurLoop, SafetyInfo, ORE, - CurLoop->getLoopPreheader()->getTerminator())) - Changed |= hoist(I, DT, CurLoop, SafetyInfo, ORE); + // Attempt to remove floating point division out of the loop by + // converting it to a reciprocal multiplication. + if (I.getOpcode() == Instruction::FDiv && + CurLoop->isLoopInvariant(I.getOperand(1)) && + I.hasAllowReciprocal()) { + auto Divisor = I.getOperand(1); + auto One = llvm::ConstantFP::get(Divisor->getType(), 1.0); + auto ReciprocalDivisor = BinaryOperator::CreateFDiv(One, Divisor); + ReciprocalDivisor->setFastMathFlags(I.getFastMathFlags()); + ReciprocalDivisor->insertBefore(&I); + + auto Product = + BinaryOperator::CreateFMul(I.getOperand(0), ReciprocalDivisor); + Product->setFastMathFlags(I.getFastMathFlags()); + Product->insertAfter(&I); + I.replaceAllUsesWith(Product); + I.eraseFromParent(); + + hoist(*ReciprocalDivisor, DT, CurLoop, SafetyInfo, ORE); + Changed = true; + continue; } + + if (IsMustExecute) + IsMustExecute = isGuaranteedToTransferExecutionToSuccessor(&I); + } } return Changed; } -/// Computes loop safety information, checks loop body & header -/// for the possibility of may throw exception. 
-/// -void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) { - assert(CurLoop != nullptr && "CurLoop cant be null"); - BasicBlock *Header = CurLoop->getHeader(); - // Setting default safety values. - SafetyInfo->MayThrow = false; - SafetyInfo->HeaderMayThrow = false; - // Iterate over header and compute safety info. - for (BasicBlock::iterator I = Header->begin(), E = Header->end(); - (I != E) && !SafetyInfo->HeaderMayThrow; ++I) - SafetyInfo->HeaderMayThrow |= - !isGuaranteedToTransferExecutionToSuccessor(&*I); - - SafetyInfo->MayThrow = SafetyInfo->HeaderMayThrow; - // Iterate over loop instructions and compute safety info. - // Skip header as it has been computed and stored in HeaderMayThrow. - // The first block in loopinfo.Blocks is guaranteed to be the header. - assert(Header == *CurLoop->getBlocks().begin() && - "First block must be header"); - for (Loop::block_iterator BB = std::next(CurLoop->block_begin()), - BBE = CurLoop->block_end(); - (BB != BBE) && !SafetyInfo->MayThrow; ++BB) - for (BasicBlock::iterator I = (*BB)->begin(), E = (*BB)->end(); - (I != E) && !SafetyInfo->MayThrow; ++I) - SafetyInfo->MayThrow |= !isGuaranteedToTransferExecutionToSuccessor(&*I); - - // Compute funclet colors if we might sink/hoist in a function with a funclet - // personality routine. - Function *Fn = CurLoop->getHeader()->getParent(); - if (Fn->hasPersonalityFn()) - if (Constant *PersonalityFn = Fn->getPersonalityFn()) - if (isFuncletEHPersonality(classifyEHPersonality(PersonalityFn))) - SafetyInfo->BlockColors = colorEHFunclets(*Fn); -} - // Return true if LI is invariant within scope of the loop. LI is invariant if // CurLoop is dominated by an invariant.start representing the same memory // location and size as the memory location LI loads from, and also the @@ -708,7 +690,7 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, /// This is true when all incoming values are that instruction. /// This pattern occurs most often with LCSSA PHI nodes. /// -static bool isTriviallyReplacablePHI(const PHINode &PN, const Instruction &I) { +static bool isTriviallyReplaceablePHI(const PHINode &PN, const Instruction &I) { for (const Value *IncValue : PN.incoming_values()) if (IncValue != &I) return false; @@ -838,12 +820,12 @@ CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN, return New; } -static Instruction *sinkThroughTriviallyReplacablePHI( +static Instruction *sinkThroughTriviallyReplaceablePHI( PHINode *TPN, Instruction *I, LoopInfo *LI, SmallDenseMap<BasicBlock *, Instruction *, 32> &SunkCopies, const LoopSafetyInfo *SafetyInfo, const Loop *CurLoop) { - assert(isTriviallyReplacablePHI(*TPN, *I) && - "Expect only trivially replacalbe PHI"); + assert(isTriviallyReplaceablePHI(*TPN, *I) && + "Expect only trivially replaceable PHI"); BasicBlock *ExitBlock = TPN->getParent(); Instruction *New; auto It = SunkCopies.find(ExitBlock); @@ -855,10 +837,16 @@ static Instruction *sinkThroughTriviallyReplacablePHI( return New; } -static bool canSplitPredecessors(PHINode *PN) { +static bool canSplitPredecessors(PHINode *PN, LoopSafetyInfo *SafetyInfo) { BasicBlock *BB = PN->getParent(); if (!BB->canSplitPredecessors()) return false; + // It's not impossible to split EHPad blocks, but if BlockColors already exist + // it require updating BlockColors for all offspring blocks accordingly. By + // skipping such corner case, we can make updating BlockColors after splitting + // predecessor fairly simple. 
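// [Editorial aside, not part of the patch] For context, BlockColors maps each
// basic block to the EH funclet(s) containing it; the removed code above
// populated it only for funclet-based personalities, roughly:
//
//   Function *Fn = CurLoop->getHeader()->getParent();
//   if (Fn->hasPersonalityFn())
//     if (Constant *PersonalityFn = Fn->getPersonalityFn())
//       if (isFuncletEHPersonality(classifyEHPersonality(PersonalityFn)))
//         SafetyInfo->BlockColors = colorEHFunclets(*Fn);
//
// Splitting an EH pad would invalidate the colors of all of its offspring
// blocks, so the check below conservatively refuses that case, and the
// simple straight-line case later just copies the predecessor's color.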
+ if (!SafetyInfo->BlockColors.empty() && BB->getFirstNonPHI()->isEHPad()) + return false; for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { BasicBlock *BBPred = *PI; if (isa<IndirectBrInst>(BBPred->getTerminator())) @@ -868,7 +856,8 @@ static bool canSplitPredecessors(PHINode *PN) { } static void splitPredecessorsOfLoopExit(PHINode *PN, DominatorTree *DT, - LoopInfo *LI, const Loop *CurLoop) { + LoopInfo *LI, const Loop *CurLoop, + LoopSafetyInfo *SafetyInfo) { #ifndef NDEBUG SmallVector<BasicBlock *, 32> ExitBlocks; CurLoop->getUniqueExitBlocks(ExitBlocks); @@ -879,7 +868,7 @@ static void splitPredecessorsOfLoopExit(PHINode *PN, DominatorTree *DT, assert(ExitBlockSet.count(ExitBB) && "Expect the PHI is in an exit block."); // Split predecessors of the loop exit to make instructions in the loop are - // exposed to exit blocks through trivially replacable PHIs while keeping the + // exposed to exit blocks through trivially replaceable PHIs while keeping the // loop in the canonical form where each predecessor of each exit block should // be contained within the loop. For example, this will convert the loop below // from @@ -891,7 +880,7 @@ static void splitPredecessorsOfLoopExit(PHINode *PN, DominatorTree *DT, // %v2 = // br %LE, %LB1 // LE: - // %p = phi [%v1, %LB1], [%v2, %LB2] <-- non-trivially replacable + // %p = phi [%v1, %LB1], [%v2, %LB2] <-- non-trivially replaceable // // to // @@ -902,21 +891,35 @@ static void splitPredecessorsOfLoopExit(PHINode *PN, DominatorTree *DT, // %v2 = // br %LE.split2, %LB1 // LE.split: - // %p1 = phi [%v1, %LB1] <-- trivially replacable + // %p1 = phi [%v1, %LB1] <-- trivially replaceable // br %LE // LE.split2: - // %p2 = phi [%v2, %LB2] <-- trivially replacable + // %p2 = phi [%v2, %LB2] <-- trivially replaceable // br %LE // LE: // %p = phi [%p1, %LE.split], [%p2, %LE.split2] // + auto &BlockColors = SafetyInfo->BlockColors; SmallSetVector<BasicBlock *, 8> PredBBs(pred_begin(ExitBB), pred_end(ExitBB)); while (!PredBBs.empty()) { BasicBlock *PredBB = *PredBBs.begin(); assert(CurLoop->contains(PredBB) && "Expect all predecessors are in the loop"); - if (PN->getBasicBlockIndex(PredBB) >= 0) - SplitBlockPredecessors(ExitBB, PredBB, ".split.loop.exit", DT, LI, true); + if (PN->getBasicBlockIndex(PredBB) >= 0) { + BasicBlock *NewPred = SplitBlockPredecessors( + ExitBB, PredBB, ".split.loop.exit", DT, LI, true); + // Since we do not allow splitting EH-block with BlockColors in + // canSplitPredecessors(), we can simply assign predecessor's color to + // the new block. + if (!BlockColors.empty()) { + // Grab a reference to the ColorVector to be inserted before getting the + // reference to the vector we are copying because inserting the new + // element in BlockColors might cause the map to be reallocated. + ColorVector &ColorsForNewBlock = BlockColors[NewPred]; + ColorVector &ColorsForOldBlock = BlockColors[PredBB]; + ColorsForNewBlock = ColorsForOldBlock; + } + } PredBBs.remove(PredBB); } } @@ -927,9 +930,9 @@ static void splitPredecessorsOfLoopExit(PHINode *PN, DominatorTree *DT, /// position, and may either delete it or move it to outside of the loop. 
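// [Editorial aside, not part of the patch] "Trivially replaceable" has a
// precise meaning here: every incoming value of the LCSSA PHI is the single
// instruction being sunk, so once a copy of that instruction is placed in
// the exit block the PHI can simply be replaced by the copy. Hedged sketch:
//
//   exit:
//     %p = phi i32 [ %inst, %lb1 ], [ %inst, %lb2 ]  ; trivially replaceable
//
// After sinking a clone of %inst into %exit, PN->replaceAllUsesWith(clone)
// is valid with no further PHI surgery.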
/// static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, - const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, + const Loop *CurLoop, LoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE, bool FreeInLoop) { - DEBUG(dbgs() << "LICM sinking instruction: " << I << "\n"); + LLVM_DEBUG(dbgs() << "LICM sinking instruction: " << I << "\n"); ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "InstSunk", &I) << "sinking " << ore::NV("Inst", &I); @@ -972,15 +975,15 @@ static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, } VisitedUsers.insert(PN); - if (isTriviallyReplacablePHI(*PN, I)) + if (isTriviallyReplaceablePHI(*PN, I)) continue; - if (!canSplitPredecessors(PN)) + if (!canSplitPredecessors(PN, SafetyInfo)) return Changed; // Split predecessors of the PHI so that we can make users trivially - // replacable. - splitPredecessorsOfLoopExit(PN, DT, LI, CurLoop); + // replaceable. + splitPredecessorsOfLoopExit(PN, DT, LI, CurLoop, SafetyInfo); // Should rebuild the iterators, as they may be invalidated by // splitPredecessorsOfLoopExit(). @@ -1014,9 +1017,9 @@ static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, PHINode *PN = cast<PHINode>(User); assert(ExitBlockSet.count(PN->getParent()) && "The LCSSA PHI is not in an exit block!"); - // The PHI must be trivially replacable. - Instruction *New = sinkThroughTriviallyReplacablePHI(PN, &I, LI, SunkCopies, - SafetyInfo, CurLoop); + // The PHI must be trivially replaceable. + Instruction *New = sinkThroughTriviallyReplaceablePHI(PN, &I, LI, SunkCopies, + SafetyInfo, CurLoop); PN->replaceAllUsesWith(New); PN->eraseFromParent(); Changed = true; @@ -1031,8 +1034,8 @@ static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE) { auto *Preheader = CurLoop->getLoopPreheader(); - DEBUG(dbgs() << "LICM hoisting to " << Preheader->getName() << ": " << I - << "\n"); + LLVM_DEBUG(dbgs() << "LICM hoisting to " << Preheader->getName() << ": " << I + << "\n"); ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "Hoisted", &I) << "hoisting " << ore::NV("Inst", &I); @@ -1221,7 +1224,7 @@ bool llvm::promoteLoopAccessesToScalars( Value *SomePtr = *PointerMustAliases.begin(); BasicBlock *Preheader = CurLoop->getLoopPreheader(); - // It isn't safe to promote a load/store from the loop if the load/store is + // It is not safe to promote a load/store from the loop if the load/store is // conditional. For example, turning: // // for () { if (c) *P += 1; } @@ -1350,7 +1353,7 @@ bool llvm::promoteLoopAccessesToScalars( // If a store dominates all exit blocks, it is safe to sink. // As explained above, if an exit block was executed, a dominating - // store must have been been executed at least once, so we are not + // store must have been executed at least once, so we are not // introducing stores on paths that did not have them. // Note that this only looks at explicit exit blocks. If we ever // start sinking stores into unwind edges (see above), this will break. @@ -1412,8 +1415,8 @@ bool llvm::promoteLoopAccessesToScalars( return false; // Otherwise, this is safe to promote, lets do it! 
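// [Editorial aside, not part of the patch] What promotion does here, as a
// hedged sketch: a must-alias pointer that is loaded and stored on every
// safe path is kept in a register for the duration of the loop, with the
// load hoisted to the preheader and the store sunk to the exit:
//
//   loop:                                 preheader:
//     %v = load i32, i32* %p                %v0 = load i32, i32* %p
//     %a = add i32 %v, 1          ==>     loop:
//     store i32 %a, i32* %p                 %v = phi i32 [ %v0, %preheader ],
//     br i1 %c, label %loop, ...                         [ %a, %loop ]
//                                           %a = add i32 %v, 1
//                                           br i1 %c, label %loop, label %exit
//                                         exit:
//                                           store i32 %a, i32* %p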
- DEBUG(dbgs() << "LICM: Promoting value stored to in loop: " << *SomePtr - << '\n'); + LLVM_DEBUG(dbgs() << "LICM: Promoting value stored to in loop: " << *SomePtr + << '\n'); ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "PromoteLoopAccessesToScalar", LoopUses[0]) diff --git a/lib/Transforms/Scalar/LLVMBuild.txt b/lib/Transforms/Scalar/LLVMBuild.txt index 8a99df86b84a..ffe35f041b35 100644 --- a/lib/Transforms/Scalar/LLVMBuild.txt +++ b/lib/Transforms/Scalar/LLVMBuild.txt @@ -20,4 +20,4 @@ type = Library name = Scalar parent = Transforms library_name = ScalarOpts -required_libraries = Analysis Core InstCombine Support TransformUtils +required_libraries = AggressiveInstCombine Analysis Core InstCombine Support TransformUtils diff --git a/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/lib/Transforms/Scalar/LoopDataPrefetch.cpp index 7f7c6de76450..3b41b5d96c86 100644 --- a/lib/Transforms/Scalar/LoopDataPrefetch.cpp +++ b/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -71,7 +71,7 @@ public: private: bool runOnLoop(Loop *L); - /// \brief Check if the the stride of the accesses is large enough to + /// Check if the stride of the accesses is large enough to /// warrant a prefetch. bool isStrideLargeEnough(const SCEVAddRecExpr *AR); @@ -244,9 +244,9 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { if (ItersAhead > getMaxPrefetchIterationsAhead()) return MadeChange; - DEBUG(dbgs() << "Prefetching " << ItersAhead - << " iterations ahead (loop size: " << LoopSize << ") in " - << L->getHeader()->getParent()->getName() << ": " << *L); + LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead + << " iterations ahead (loop size: " << LoopSize << ") in " + << L->getHeader()->getParent()->getName() << ": " << *L); SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>, 16> PrefLoads; for (const auto BB : L->blocks()) { @@ -275,7 +275,7 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { if (!LSCEVAddRec) continue; - // Check if the the stride of the accesses is large enough to warrant a + // Check if the stride of the accesses is large enough to warrant a // prefetch. if (!isStrideLargeEnough(LSCEVAddRec)) continue; @@ -320,8 +320,8 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { ConstantInt::get(I32, MemI->mayReadFromMemory() ? 0 : 1), ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)}); ++NumPrefetches; - DEBUG(dbgs() << " Access: " << *PtrValue << ", SCEV: " << *LSCEV - << "\n"); + LLVM_DEBUG(dbgs() << " Access: " << *PtrValue << ", SCEV: " << *LSCEV + << "\n"); ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "Prefetched", MemI) << "prefetched memory access"; diff --git a/lib/Transforms/Scalar/LoopDeletion.cpp b/lib/Transforms/Scalar/LoopDeletion.cpp index 82604a8842bf..d412025d7e94 100644 --- a/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/lib/Transforms/Scalar/LoopDeletion.cpp @@ -49,11 +49,10 @@ static bool isLoopDead(Loop *L, ScalarEvolution &SE, // must pass through a PHI in the exit block, meaning that this check is // sufficient to guarantee that no loop-variant values are used outside // of the loop. - BasicBlock::iterator BI = ExitBlock->begin(); bool AllEntriesInvariant = true; bool AllOutgoingValuesSame = true; - while (PHINode *P = dyn_cast<PHINode>(BI)) { - Value *incoming = P->getIncomingValueForBlock(ExitingBlocks[0]); + for (PHINode &P : ExitBlock->phis()) { + Value *incoming = P.getIncomingValueForBlock(ExitingBlocks[0]); // Make sure all exiting blocks produce the same incoming value for the exit // block. 
If there are different incoming values for different exiting @@ -61,7 +60,7 @@ static bool isLoopDead(Loop *L, ScalarEvolution &SE, // be used. AllOutgoingValuesSame = all_of(makeArrayRef(ExitingBlocks).slice(1), [&](BasicBlock *BB) { - return incoming == P->getIncomingValueForBlock(BB); + return incoming == P.getIncomingValueForBlock(BB); }); if (!AllOutgoingValuesSame) @@ -72,8 +71,6 @@ static bool isLoopDead(Loop *L, ScalarEvolution &SE, AllEntriesInvariant = false; break; } - - ++BI; } if (Changed) @@ -145,14 +142,15 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, // of trouble. BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader || !L->hasDedicatedExits()) { - DEBUG(dbgs() - << "Deletion requires Loop with preheader and dedicated exits.\n"); + LLVM_DEBUG( + dbgs() + << "Deletion requires Loop with preheader and dedicated exits.\n"); return LoopDeletionResult::Unmodified; } // We can't remove loops that contain subloops. If the subloops were dead, // they would already have been removed in earlier executions of this pass. if (L->begin() != L->end()) { - DEBUG(dbgs() << "Loop contains subloops.\n"); + LLVM_DEBUG(dbgs() << "Loop contains subloops.\n"); return LoopDeletionResult::Unmodified; } @@ -160,13 +158,11 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, BasicBlock *ExitBlock = L->getUniqueExitBlock(); if (ExitBlock && isLoopNeverExecuted(L)) { - DEBUG(dbgs() << "Loop is proven to never execute, delete it!"); + LLVM_DEBUG(dbgs() << "Loop is proven to never execute, delete it!"); // Set incoming value to undef for phi nodes in the exit block. - BasicBlock::iterator BI = ExitBlock->begin(); - while (PHINode *P = dyn_cast<PHINode>(BI)) { - for (unsigned i = 0; i < P->getNumIncomingValues(); i++) - P->setIncomingValue(i, UndefValue::get(P->getType())); - BI++; + for (PHINode &P : ExitBlock->phis()) { + std::fill(P.incoming_values().begin(), P.incoming_values().end(), + UndefValue::get(P.getType())); } deleteDeadLoop(L, &DT, &SE, &LI); ++NumDeleted; @@ -183,13 +179,13 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, // block will be branched to, or trying to preserve the branching logic in // a loop invariant manner. if (!ExitBlock) { - DEBUG(dbgs() << "Deletion requires single exit block\n"); + LLVM_DEBUG(dbgs() << "Deletion requires single exit block\n"); return LoopDeletionResult::Unmodified; } // Finally, we have to check that the loop really is dead. bool Changed = false; if (!isLoopDead(L, SE, ExitingBlocks, ExitBlock, Changed, Preheader)) { - DEBUG(dbgs() << "Loop is not invariant, cannot delete.\n"); + LLVM_DEBUG(dbgs() << "Loop is not invariant, cannot delete.\n"); return Changed ? LoopDeletionResult::Modified : LoopDeletionResult::Unmodified; } @@ -198,12 +194,12 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, // They could be infinite, in which case we'd be changing program behavior. const SCEV *S = SE.getMaxBackedgeTakenCount(L); if (isa<SCEVCouldNotCompute>(S)) { - DEBUG(dbgs() << "Could not compute SCEV MaxBackedgeTakenCount.\n"); + LLVM_DEBUG(dbgs() << "Could not compute SCEV MaxBackedgeTakenCount.\n"); return Changed ? 
LoopDeletionResult::Modified : LoopDeletionResult::Unmodified; } - DEBUG(dbgs() << "Loop is invariant, delete it!"); + LLVM_DEBUG(dbgs() << "Loop is invariant, delete it!"); deleteDeadLoop(L, &DT, &SE, &LI); ++NumDeleted; @@ -214,8 +210,8 @@ PreservedAnalyses LoopDeletionPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &Updater) { - DEBUG(dbgs() << "Analyzing Loop for deletion: "); - DEBUG(L.dump()); + LLVM_DEBUG(dbgs() << "Analyzing Loop for deletion: "); + LLVM_DEBUG(L.dump()); std::string LoopName = L.getName(); auto Result = deleteLoopIfDead(&L, AR.DT, AR.SE, AR.LI); if (Result == LoopDeletionResult::Unmodified) @@ -260,8 +256,8 @@ bool LoopDeletionLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - DEBUG(dbgs() << "Analyzing Loop for deletion: "); - DEBUG(L->dump()); + LLVM_DEBUG(dbgs() << "Analyzing Loop for deletion: "); + LLVM_DEBUG(L->dump()); LoopDeletionResult Result = deleteLoopIfDead(L, DT, SE, LI); diff --git a/lib/Transforms/Scalar/LoopDistribute.cpp b/lib/Transforms/Scalar/LoopDistribute.cpp index 0d7e3db901cb..06083a4f5086 100644 --- a/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/lib/Transforms/Scalar/LoopDistribute.cpp @@ -111,7 +111,7 @@ STATISTIC(NumLoopsDistributed, "Number of loops distributed"); namespace { -/// \brief Maintains the set of instructions of the loop for a partition before +/// Maintains the set of instructions of the loop for a partition before /// cloning. After cloning, it hosts the new loop. class InstPartition { using InstructionSet = SmallPtrSet<Instruction *, 8>; @@ -122,20 +122,20 @@ public: Set.insert(I); } - /// \brief Returns whether this partition contains a dependence cycle. + /// Returns whether this partition contains a dependence cycle. bool hasDepCycle() const { return DepCycle; } - /// \brief Adds an instruction to this partition. + /// Adds an instruction to this partition. void add(Instruction *I) { Set.insert(I); } - /// \brief Collection accessors. + /// Collection accessors. InstructionSet::iterator begin() { return Set.begin(); } InstructionSet::iterator end() { return Set.end(); } InstructionSet::const_iterator begin() const { return Set.begin(); } InstructionSet::const_iterator end() const { return Set.end(); } bool empty() const { return Set.empty(); } - /// \brief Moves this partition into \p Other. This partition becomes empty + /// Moves this partition into \p Other. This partition becomes empty /// after this. void moveTo(InstPartition &Other) { Other.Set.insert(Set.begin(), Set.end()); @@ -143,7 +143,7 @@ public: Other.DepCycle |= DepCycle; } - /// \brief Populates the partition with a transitive closure of all the + /// Populates the partition with a transitive closure of all the /// instructions that the seeded instructions dependent on. void populateUsedSet() { // FIXME: We currently don't use control-dependence but simply include all @@ -166,7 +166,7 @@ public: } } - /// \brief Clones the original loop. + /// Clones the original loop. /// /// Updates LoopInfo and DominatorTree using the information that block \p /// LoopDomBB dominates the loop. @@ -179,27 +179,27 @@ public: return ClonedLoop; } - /// \brief The cloned loop. If this partition is mapped to the original loop, + /// The cloned loop. If this partition is mapped to the original loop, /// this is null. 
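// [Editorial aside, not part of the patch] The point of the partitions
// managed here, in one hedged example: distribution splits a loop mixing a
// dependence cycle with independent, vectorizable work,
//
//   for (i) { A[i + 1] = A[i] + x;  B[i] = C[i] * y; }
//
// into two loops so that the second becomes vectorizable:
//
//   for (i) A[i + 1] = A[i] + x;   // cyclic partition, stays scalar
//   for (i) B[i] = C[i] * y;       // non-cyclic partition, vectorizable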
const Loop *getClonedLoop() const { return ClonedLoop; } - /// \brief Returns the loop where this partition ends up after distribution. + /// Returns the loop where this partition ends up after distribution. /// If this partition is mapped to the original loop then use the block from /// the loop. const Loop *getDistributedLoop() const { return ClonedLoop ? ClonedLoop : OrigLoop; } - /// \brief The VMap that is populated by cloning and then used in + /// The VMap that is populated by cloning and then used in /// remapinstruction to remap the cloned instructions. ValueToValueMapTy &getVMap() { return VMap; } - /// \brief Remaps the cloned instructions using VMap. + /// Remaps the cloned instructions using VMap. void remapInstructions() { remapInstructionsInBlocks(ClonedLoopBlocks, VMap); } - /// \brief Based on the set of instructions selected for this partition, + /// Based on the set of instructions selected for this partition, /// removes the unnecessary ones. void removeUnusedInsts() { SmallVector<Instruction *, 8> Unused; @@ -239,30 +239,30 @@ public: } private: - /// \brief Instructions from OrigLoop selected for this partition. + /// Instructions from OrigLoop selected for this partition. InstructionSet Set; - /// \brief Whether this partition contains a dependence cycle. + /// Whether this partition contains a dependence cycle. bool DepCycle; - /// \brief The original loop. + /// The original loop. Loop *OrigLoop; - /// \brief The cloned loop. If this partition is mapped to the original loop, + /// The cloned loop. If this partition is mapped to the original loop, /// this is null. Loop *ClonedLoop = nullptr; - /// \brief The blocks of ClonedLoop including the preheader. If this + /// The blocks of ClonedLoop including the preheader. If this /// partition is mapped to the original loop, this is empty. SmallVector<BasicBlock *, 8> ClonedLoopBlocks; - /// \brief These gets populated once the set of instructions have been + /// These gets populated once the set of instructions have been /// finalized. If this partition is mapped to the original loop, these are not /// set. ValueToValueMapTy VMap; }; -/// \brief Holds the set of Partitions. It populates them, merges them and then +/// Holds the set of Partitions. It populates them, merges them and then /// clones the loops. class InstPartitionContainer { using InstToPartitionIdT = DenseMap<Instruction *, int>; @@ -271,10 +271,10 @@ public: InstPartitionContainer(Loop *L, LoopInfo *LI, DominatorTree *DT) : L(L), LI(LI), DT(DT) {} - /// \brief Returns the number of partitions. + /// Returns the number of partitions. unsigned getSize() const { return PartitionContainer.size(); } - /// \brief Adds \p Inst into the current partition if that is marked to + /// Adds \p Inst into the current partition if that is marked to /// contain cycles. Otherwise start a new partition for it. void addToCyclicPartition(Instruction *Inst) { // If the current partition is non-cyclic. Start a new one. @@ -284,7 +284,7 @@ public: PartitionContainer.back().add(Inst); } - /// \brief Adds \p Inst into a partition that is not marked to contain + /// Adds \p Inst into a partition that is not marked to contain /// dependence cycles. /// // Initially we isolate memory instructions into as many partitions as @@ -293,7 +293,7 @@ public: PartitionContainer.emplace_back(Inst, L); } - /// \brief Merges adjacent non-cyclic partitions. + /// Merges adjacent non-cyclic partitions. /// /// The idea is that we currently only want to isolate the non-vectorizable /// partition. 
We could later allow more distribution among these partition @@ -303,7 +303,7 @@ public: [](const InstPartition *P) { return !P->hasDepCycle(); }); } - /// \brief If a partition contains only conditional stores, we won't vectorize + /// If a partition contains only conditional stores, we won't vectorize /// it. Try to merge it with a previous cyclic partition. void mergeNonIfConvertible() { mergeAdjacentPartitionsIf([&](const InstPartition *Partition) { @@ -323,14 +323,14 @@ public: }); } - /// \brief Merges the partitions according to various heuristics. + /// Merges the partitions according to various heuristics. void mergeBeforePopulating() { mergeAdjacentNonCyclic(); if (!DistributeNonIfConvertible) mergeNonIfConvertible(); } - /// \brief Merges partitions in order to ensure that no loads are duplicated. + /// Merges partitions in order to ensure that no loads are duplicated. /// /// We can't duplicate loads because that could potentially reorder them. /// LoopAccessAnalysis provides dependency information with the context that @@ -362,9 +362,11 @@ public: std::tie(LoadToPart, NewElt) = LoadToPartition.insert(std::make_pair(Inst, PartI)); if (!NewElt) { - DEBUG(dbgs() << "Merging partitions due to this load in multiple " - << "partitions: " << PartI << ", " - << LoadToPart->second << "\n" << *Inst << "\n"); + LLVM_DEBUG(dbgs() + << "Merging partitions due to this load in multiple " + << "partitions: " << PartI << ", " << LoadToPart->second + << "\n" + << *Inst << "\n"); auto PartJ = I; do { @@ -398,7 +400,7 @@ public: return true; } - /// \brief Sets up the mapping between instructions to partitions. If the + /// Sets up the mapping between instructions to partitions. If the /// instruction is duplicated across multiple partitions, set the entry to -1. void setupPartitionIdOnInstructions() { int PartitionID = 0; @@ -416,14 +418,14 @@ public: } } - /// \brief Populates the partition with everything that the seeding + /// Populates the partition with everything that the seeding /// instructions require. void populateUsedSet() { for (auto &P : PartitionContainer) P.populateUsedSet(); } - /// \brief This performs the main chunk of the work of cloning the loops for + /// This performs the main chunk of the work of cloning the loops for /// the partitions. void cloneLoops() { BasicBlock *OrigPH = L->getLoopPreheader(); @@ -470,13 +472,13 @@ public: Curr->getDistributedLoop()->getExitingBlock()); } - /// \brief Removes the dead instructions from the cloned loops. + /// Removes the dead instructions from the cloned loops. void removeUnusedInsts() { for (auto &Partition : PartitionContainer) Partition.removeUnusedInsts(); } - /// \brief For each memory pointer, it computes the partitionId the pointer is + /// For each memory pointer, it computes the partitionId the pointer is /// used in. /// /// This returns an array of int where the I-th entry corresponds to I-th @@ -543,10 +545,10 @@ public: private: using PartitionContainerT = std::list<InstPartition>; - /// \brief List of partitions. + /// List of partitions. PartitionContainerT PartitionContainer; - /// \brief Mapping from Instruction to partition Id. If the instruction + /// Mapping from Instruction to partition Id. If the instruction /// belongs to multiple partitions the entry contains -1. 
InstToPartitionIdT InstToPartitionId; @@ -554,7 +556,7 @@ private: LoopInfo *LI; DominatorTree *DT; - /// \brief The control structure to merge adjacent partitions if both satisfy + /// The control structure to merge adjacent partitions if both satisfy /// the \p Predicate. template <class UnaryPredicate> void mergeAdjacentPartitionsIf(UnaryPredicate Predicate) { @@ -575,7 +577,7 @@ private: } }; -/// \brief For each memory instruction, this class maintains difference of the +/// For each memory instruction, this class maintains difference of the /// number of unsafe dependences that start out from this instruction minus /// those that end here. /// @@ -602,7 +604,7 @@ public: const SmallVectorImpl<Dependence> &Dependences) { Accesses.append(Instructions.begin(), Instructions.end()); - DEBUG(dbgs() << "Backward dependences:\n"); + LLVM_DEBUG(dbgs() << "Backward dependences:\n"); for (auto &Dep : Dependences) if (Dep.isPossiblyBackward()) { // Note that the designations source and destination follow the program @@ -611,7 +613,7 @@ public: ++Accesses[Dep.Source].NumUnsafeDependencesStartOrEnd; --Accesses[Dep.Destination].NumUnsafeDependencesStartOrEnd; - DEBUG(Dep.print(dbgs(), 2, Instructions)); + LLVM_DEBUG(Dep.print(dbgs(), 2, Instructions)); } } @@ -619,7 +621,7 @@ private: AccessesType Accesses; }; -/// \brief The actual class performing the per-loop work. +/// The actual class performing the per-loop work. class LoopDistributeForLoop { public: LoopDistributeForLoop(Loop *L, Function *F, LoopInfo *LI, DominatorTree *DT, @@ -628,12 +630,13 @@ public: setForced(); } - /// \brief Try to distribute an inner-most loop. + /// Try to distribute an inner-most loop. bool processLoop(std::function<const LoopAccessInfo &(Loop &)> &GetLAA) { assert(L->empty() && "Only process inner loops."); - DEBUG(dbgs() << "\nLDist: In \"" << L->getHeader()->getParent()->getName() - << "\" checking " << *L << "\n"); + LLVM_DEBUG(dbgs() << "\nLDist: In \"" + << L->getHeader()->getParent()->getName() + << "\" checking " << *L << "\n"); if (!L->getExitBlock()) return fail("MultipleExitBlocks", "multiple exit blocks"); @@ -705,7 +708,7 @@ public: for (auto *Inst : DefsUsedOutside) Partitions.addToNewNonCyclicPartition(Inst); - DEBUG(dbgs() << "Seeded partitions:\n" << Partitions); + LLVM_DEBUG(dbgs() << "Seeded partitions:\n" << Partitions); if (Partitions.getSize() < 2) return fail("CantIsolateUnsafeDeps", "cannot isolate unsafe dependencies"); @@ -713,20 +716,20 @@ public: // Run the merge heuristics: Merge non-cyclic adjacent partitions since we // should be able to vectorize these together. Partitions.mergeBeforePopulating(); - DEBUG(dbgs() << "\nMerged partitions:\n" << Partitions); + LLVM_DEBUG(dbgs() << "\nMerged partitions:\n" << Partitions); if (Partitions.getSize() < 2) return fail("CantIsolateUnsafeDeps", "cannot isolate unsafe dependencies"); // Now, populate the partitions with non-memory operations. Partitions.populateUsedSet(); - DEBUG(dbgs() << "\nPopulated partitions:\n" << Partitions); + LLVM_DEBUG(dbgs() << "\nPopulated partitions:\n" << Partitions); // In order to preserve original lexical order for loads, keep them in the // partition that we set up in the MemoryInstructionDependences loop. 
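// [Editorial aside, not part of the patch] Why duplicated loads force the
// merge below: LoopAccessAnalysis validated dependences for the original
// program order, so cloning a load into two distributed loops could let the
// second copy observe stores from every iteration of the first loop. Hedged
// sketch, assuming A and B may alias:
//
//   for (i) { v = A[i]; B[i] = v; C[i] = v; }
//
// Splitting this so that each new loop re-executes "v = A[i]" would make the
// C[i] loop read A after all B writes have run, a reordering the analysis
// never approved; keeping all users of the load in one partition avoids it.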
if (Partitions.mergeToAvoidDuplicatedLoads()) { - DEBUG(dbgs() << "\nPartitions merged to ensure unique loads:\n" - << Partitions); + LLVM_DEBUG(dbgs() << "\nPartitions merged to ensure unique loads:\n" + << Partitions); if (Partitions.getSize() < 2) return fail("CantIsolateUnsafeDeps", "cannot isolate unsafe dependencies"); @@ -740,7 +743,7 @@ public: return fail("TooManySCEVRuntimeChecks", "too many SCEV run-time checks needed.\n"); - DEBUG(dbgs() << "\nDistributing loop: " << *L << "\n"); + LLVM_DEBUG(dbgs() << "\nDistributing loop: " << *L << "\n"); // We're done forming the partitions set up the reverse mapping from // instructions to partitions. Partitions.setupPartitionIdOnInstructions(); @@ -759,8 +762,8 @@ public: RtPtrChecking); if (!Pred.isAlwaysTrue() || !Checks.empty()) { - DEBUG(dbgs() << "\nPointers:\n"); - DEBUG(LAI->getRuntimePointerChecking()->printChecks(dbgs(), Checks)); + LLVM_DEBUG(dbgs() << "\nPointers:\n"); + LLVM_DEBUG(LAI->getRuntimePointerChecking()->printChecks(dbgs(), Checks)); LoopVersioning LVer(*LAI, L, LI, DT, SE, false); LVer.setAliasChecks(std::move(Checks)); LVer.setSCEVChecks(LAI->getPSE().getUnionPredicate()); @@ -775,12 +778,12 @@ public: // Now, we remove the instruction from each loop that don't belong to that // partition. Partitions.removeUnusedInsts(); - DEBUG(dbgs() << "\nAfter removing unused Instrs:\n"); - DEBUG(Partitions.printBlocks()); + LLVM_DEBUG(dbgs() << "\nAfter removing unused Instrs:\n"); + LLVM_DEBUG(Partitions.printBlocks()); if (LDistVerify) { LI->verify(*DT); - DT->verifyDomTree(); + assert(DT->verify(DominatorTree::VerificationLevel::Fast)); } ++NumLoopsDistributed; @@ -793,12 +796,12 @@ public: return true; } - /// \brief Provide diagnostics then \return with false. + /// Provide diagnostics then \return with false. bool fail(StringRef RemarkName, StringRef Message) { LLVMContext &Ctx = F->getContext(); bool Forced = isForced().getValueOr(false); - DEBUG(dbgs() << "Skipping; " << Message << "\n"); + LLVM_DEBUG(dbgs() << "Skipping; " << Message << "\n"); // With Rpass-missed report that distribution failed. ORE->emit([&]() { @@ -826,7 +829,7 @@ public: return false; } - /// \brief Return if distribution forced to be enabled/disabled for the loop. + /// Return if distribution forced to be enabled/disabled for the loop. /// /// If the optional has a value, it indicates whether distribution was forced /// to be enabled (true) or disabled (false). If the optional has no value @@ -834,7 +837,7 @@ public: const Optional<bool> &isForced() const { return IsForced; } private: - /// \brief Filter out checks between pointers from the same partition. + /// Filter out checks between pointers from the same partition. /// /// \p PtrToPartition contains the partition number for pointers. Partition /// number -1 means that the pointer is used in multiple partitions. In this @@ -873,7 +876,7 @@ private: return Checks; } - /// \brief Check whether the loop metadata is forcing distribution to be + /// Check whether the loop metadata is forcing distribution to be /// enabled/disabled. void setForced() { Optional<const MDOperand *> Value = @@ -896,7 +899,7 @@ private: ScalarEvolution *SE; OptimizationRemarkEmitter *ORE; - /// \brief Indicates whether distribution is forced to be enabled/disabled for + /// Indicates whether distribution is forced to be enabled/disabled for /// the loop. 
/// /// If the optional has a value, it indicates whether distribution was forced @@ -939,7 +942,7 @@ static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT, namespace { -/// \brief The pass class. +/// The pass class. class LoopDistributeLegacy : public FunctionPass { public: static char ID; diff --git a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 21551f0a0825..d8692198f7a3 100644 --- a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -37,7 +37,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar/LoopIdiomRecognize.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -57,6 +56,7 @@ #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -87,8 +87,8 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopIdiomRecognize.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include <algorithm> #include <cassert> @@ -188,8 +188,9 @@ private: PHINode *CntPhi, Value *Var); bool recognizeAndInsertCTLZ(); void transformLoopToCountable(BasicBlock *PreCondBB, Instruction *CntInst, - PHINode *CntPhi, Value *Var, const DebugLoc DL, - bool ZeroCheck, bool IsCntPhiUsedOutsideLoop); + PHINode *CntPhi, Value *Var, Instruction *DefX, + const DebugLoc &DL, bool ZeroCheck, + bool IsCntPhiUsedOutsideLoop); /// @} }; @@ -310,9 +311,9 @@ bool LoopIdiomRecognize::runOnCountableLoop() { SmallVector<BasicBlock *, 8> ExitBlocks; CurLoop->getUniqueExitBlocks(ExitBlocks); - DEBUG(dbgs() << "loop-idiom Scanning: F[" - << CurLoop->getHeader()->getParent()->getName() << "] Loop %" - << CurLoop->getHeader()->getName() << "\n"); + LLVM_DEBUG(dbgs() << "loop-idiom Scanning: F[" + << CurLoop->getHeader()->getParent()->getName() + << "] Loop %" << CurLoop->getHeader()->getName() << "\n"); bool MadeChange = false; @@ -756,8 +757,8 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, MSIs.insert(MSI); bool NegStride = SizeInBytes == -Stride; return processLoopStridedStore(Pointer, (unsigned)SizeInBytes, - MSI->getAlignment(), SplatValue, MSI, MSIs, Ev, - BECount, NegStride, /*IsLoopMemset=*/true); + MSI->getDestAlignment(), SplatValue, MSI, MSIs, + Ev, BECount, NegStride, /*IsLoopMemset=*/true); } /// mayLoopAccessLocation - Return true if the specified loop might access the @@ -936,8 +937,9 @@ bool LoopIdiomRecognize::processLoopStridedStore( NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes}); } - DEBUG(dbgs() << " Formed memset: " << *NewCall << "\n" - << " from store to: " << *Ev << " at: " << *TheStore << "\n"); + LLVM_DEBUG(dbgs() << " Formed memset: " << *NewCall << "\n" + << " from store to: " << *Ev << " at: " << *TheStore + << "\n"); NewCall->setDebugLoc(TheStore->getDebugLoc()); // Okay, the memset has been formed. 
Zap the original store and anything that @@ -1037,16 +1039,17 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntPtrTy, Preheader->getTerminator()); - unsigned Align = std::min(SI->getAlignment(), LI->getAlignment()); CallInst *NewCall = nullptr; // Check whether to generate an unordered atomic memcpy: - // If the load or store are atomic, then they must neccessarily be unordered + // If the load or store are atomic, then they must necessarily be unordered // by previous checks. if (!SI->isAtomic() && !LI->isAtomic()) - NewCall = Builder.CreateMemCpy(StoreBasePtr, LoadBasePtr, NumBytes, Align); + NewCall = Builder.CreateMemCpy(StoreBasePtr, SI->getAlignment(), + LoadBasePtr, LI->getAlignment(), NumBytes); else { // We cannot allow unaligned ops for unordered load/store, so reject // anything where the alignment isn't at least the element size. + unsigned Align = std::min(SI->getAlignment(), LI->getAlignment()); if (Align < StoreSize) return false; @@ -1066,9 +1069,10 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, } NewCall->setDebugLoc(SI->getDebugLoc()); - DEBUG(dbgs() << " Formed memcpy: " << *NewCall << "\n" - << " from load ptr=" << *LoadEv << " at: " << *LI << "\n" - << " from store ptr=" << *StoreEv << " at: " << *SI << "\n"); + LLVM_DEBUG(dbgs() << " Formed memcpy: " << *NewCall << "\n" + << " from load ptr=" << *LoadEv << " at: " << *LI << "\n" + << " from store ptr=" << *StoreEv << " at: " << *SI + << "\n"); // Okay, the memcpy has been formed. Zap the original store and anything that // feeds into it. @@ -1084,9 +1088,9 @@ bool LoopIdiomRecognize::avoidLIRForMultiBlockLoop(bool IsMemset, bool IsLoopMemset) { if (ApplyCodeSizeHeuristics && CurLoop->getNumBlocks() > 1) { if (!CurLoop->getParentLoop() && (!IsMemset || !IsLoopMemset)) { - DEBUG(dbgs() << " " << CurLoop->getHeader()->getParent()->getName() - << " : LIR " << (IsMemset ? "Memset" : "Memcpy") - << " avoided: multi-block top-level loop\n"); + LLVM_DEBUG(dbgs() << " " << CurLoop->getHeader()->getParent()->getName() + << " : LIR " << (IsMemset ? 
"Memset" : "Memcpy") + << " avoided: multi-block top-level loop\n"); return true; } } @@ -1195,14 +1199,13 @@ static bool detectPopcountIdiom(Loop *CurLoop, BasicBlock *PreCondBB, VarX1 = DefX2->getOperand(0); SubOneOp = dyn_cast<BinaryOperator>(DefX2->getOperand(1)); } - if (!SubOneOp) + if (!SubOneOp || SubOneOp->getOperand(0) != VarX1) return false; - Instruction *SubInst = cast<Instruction>(SubOneOp); - ConstantInt *Dec = dyn_cast<ConstantInt>(SubInst->getOperand(1)); + ConstantInt *Dec = dyn_cast<ConstantInt>(SubOneOp->getOperand(1)); if (!Dec || - !((SubInst->getOpcode() == Instruction::Sub && Dec->isOne()) || - (SubInst->getOpcode() == Instruction::Add && + !((SubOneOp->getOpcode() == Instruction::Sub && Dec->isOne()) || + (SubOneOp->getOpcode() == Instruction::Add && Dec->isMinusOne()))) { return false; } @@ -1314,7 +1317,8 @@ static bool detectCTLZIdiom(Loop *CurLoop, PHINode *&PhiX, return false; // step 2: detect instructions corresponding to "x.next = x >> 1" - if (!DefX || DefX->getOpcode() != Instruction::AShr) + if (!DefX || (DefX->getOpcode() != Instruction::AShr && + DefX->getOpcode() != Instruction::LShr)) return false; ConstantInt *Shft = dyn_cast<ConstantInt>(DefX->getOperand(1)); if (!Shft || !Shft->isOne()) @@ -1372,13 +1376,13 @@ bool LoopIdiomRecognize::recognizeAndInsertCTLZ() { bool IsCntPhiUsedOutsideLoop = false; for (User *U : CntPhi->users()) - if (!CurLoop->contains(dyn_cast<Instruction>(U))) { + if (!CurLoop->contains(cast<Instruction>(U))) { IsCntPhiUsedOutsideLoop = true; break; } bool IsCntInstUsedOutsideLoop = false; for (User *U : CntInst->users()) - if (!CurLoop->contains(dyn_cast<Instruction>(U))) { + if (!CurLoop->contains(cast<Instruction>(U))) { IsCntInstUsedOutsideLoop = true; break; } @@ -1395,16 +1399,27 @@ bool LoopIdiomRecognize::recognizeAndInsertCTLZ() { // parent function RunOnLoop. BasicBlock *PH = CurLoop->getLoopPreheader(); Value *InitX = PhiX->getIncomingValueForBlock(PH); - // If we check X != 0 before entering the loop we don't need a zero - // check in CTLZ intrinsic, but only if Cnt Phi is not used outside of the - // loop (if it is used we count CTLZ(X >> 1)). - if (!IsCntPhiUsedOutsideLoop) - if (BasicBlock *PreCondBB = PH->getSinglePredecessor()) - if (BranchInst *PreCondBr = - dyn_cast<BranchInst>(PreCondBB->getTerminator())) { - if (matchCondition(PreCondBr, PH) == InitX) - ZeroCheck = true; - } + + // Make sure the initial value can't be negative otherwise the ashr in the + // loop might never reach zero which would make the loop infinite. + if (DefX->getOpcode() == Instruction::AShr && !isKnownNonNegative(InitX, *DL)) + return false; + + // If we are using the count instruction outside the loop, make sure we + // have a zero check as a precondition. Without the check the loop would run + // one iteration for before any check of the input value. This means 0 and 1 + // would have identical behavior in the original loop and thus + if (!IsCntPhiUsedOutsideLoop) { + auto *PreCondBB = PH->getSinglePredecessor(); + if (!PreCondBB) + return false; + auto *PreCondBI = dyn_cast<BranchInst>(PreCondBB->getTerminator()); + if (!PreCondBI) + return false; + if (matchCondition(PreCondBI, PH) != InitX) + return false; + ZeroCheck = true; + } // Check if CTLZ intrinsic is profitable. 
Assume it is always profitable // if we delete the loop (the loop has only 6 instructions): @@ -1415,17 +1430,16 @@ bool LoopIdiomRecognize::recognizeAndInsertCTLZ() { // %inc = add nsw %i.0, 1 // br i1 %tobool - IRBuilder<> Builder(PH->getTerminator()); - SmallVector<const Value *, 2> Ops = - {InitX, ZeroCheck ? Builder.getTrue() : Builder.getFalse()}; - ArrayRef<const Value *> Args(Ops); + const Value *Args[] = + {InitX, ZeroCheck ? ConstantInt::getTrue(InitX->getContext()) + : ConstantInt::getFalse(InitX->getContext())}; if (CurLoop->getHeader()->size() != 6 && TTI->getIntrinsicCost(Intrinsic::ctlz, InitX->getType(), Args) > TargetTransformInfo::TCC_Basic) return false; - const DebugLoc DL = DefX->getDebugLoc(); - transformLoopToCountable(PH, CntInst, CntPhi, InitX, DL, ZeroCheck, + transformLoopToCountable(PH, CntInst, CntPhi, InitX, DefX, + DefX->getDebugLoc(), ZeroCheck, IsCntPhiUsedOutsideLoop); return true; } @@ -1461,7 +1475,7 @@ bool LoopIdiomRecognize::recognizePopcount() { if (!EntryBI || EntryBI->isConditional()) return false; - // It should have a precondition block where the generated popcount instrinsic + // It should have a precondition block where the generated popcount intrinsic // function can be inserted. auto *PreCondBB = PH->getSinglePredecessor(); if (!PreCondBB) @@ -1539,8 +1553,9 @@ static CallInst *createCTLZIntrinsic(IRBuilder<> &IRBuilder, Value *Val, /// If CntInst and DefX are not used in LOOP_BODY they will be removed. void LoopIdiomRecognize::transformLoopToCountable( BasicBlock *Preheader, Instruction *CntInst, PHINode *CntPhi, Value *InitX, - const DebugLoc DL, bool ZeroCheck, bool IsCntPhiUsedOutsideLoop) { - BranchInst *PreheaderBr = dyn_cast<BranchInst>(Preheader->getTerminator()); + Instruction *DefX, const DebugLoc &DL, bool ZeroCheck, + bool IsCntPhiUsedOutsideLoop) { + BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator()); // Step 1: Insert the CTLZ instruction at the end of the preheader block // Count = BitWidth - CTLZ(InitX); @@ -1550,10 +1565,16 @@ void LoopIdiomRecognize::transformLoopToCountable( Builder.SetCurrentDebugLocation(DL); Value *CTLZ, *Count, *CountPrev, *NewCount, *InitXNext; - if (IsCntPhiUsedOutsideLoop) - InitXNext = Builder.CreateAShr(InitX, - ConstantInt::get(InitX->getType(), 1)); - else + if (IsCntPhiUsedOutsideLoop) { + if (DefX->getOpcode() == Instruction::AShr) + InitXNext = + Builder.CreateAShr(InitX, ConstantInt::get(InitX->getType(), 1)); + else if (DefX->getOpcode() == Instruction::LShr) + InitXNext = + Builder.CreateLShr(InitX, ConstantInt::get(InitX->getType(), 1)); + else + llvm_unreachable("Unexpected opcode!"); + } else InitXNext = InitX; CTLZ = createCTLZIntrinsic(Builder, InitXNext, DL, ZeroCheck); Count = Builder.CreateSub( @@ -1588,7 +1609,7 @@ void LoopIdiomRecognize::transformLoopToCountable( // ... 
// Br: loop if (Dec != 0) BasicBlock *Body = *(CurLoop->block_begin()); - auto *LbBr = dyn_cast<BranchInst>(Body->getTerminator()); + auto *LbBr = cast<BranchInst>(Body->getTerminator()); ICmpInst *LbCond = cast<ICmpInst>(LbBr->getCondition()); Type *Ty = Count->getType(); @@ -1625,8 +1646,8 @@ void LoopIdiomRecognize::transformLoopToPopcount(BasicBlock *PreCondBB, Instruction *CntInst, PHINode *CntPhi, Value *Var) { BasicBlock *PreHead = CurLoop->getLoopPreheader(); - auto *PreCondBr = dyn_cast<BranchInst>(PreCondBB->getTerminator()); - const DebugLoc DL = CntInst->getDebugLoc(); + auto *PreCondBr = cast<BranchInst>(PreCondBB->getTerminator()); + const DebugLoc &DL = CntInst->getDebugLoc(); // Assuming before transformation, the loop is following: // if (x) // the precondition @@ -1675,7 +1696,7 @@ void LoopIdiomRecognize::transformLoopToPopcount(BasicBlock *PreCondBB, } // Step 3: Note that the population count is exactly the trip count of the - // loop in question, which enable us to to convert the loop from noncountable + // loop in question, which enables us to convert the loop from a noncountable // loop into a countable one. The benefit is twofold: // // - If the loop only counts population, the entire loop becomes dead after @@ -1696,7 +1717,7 @@ void LoopIdiomRecognize::transformLoopToPopcount(BasicBlock *PreCondBB, // do { cnt++; x &= x-1; t--) } while (t > 0); BasicBlock *Body = *(CurLoop->block_begin()); { - auto *LbBr = dyn_cast<BranchInst>(Body->getTerminator()); + auto *LbBr = cast<BranchInst>(Body->getTerminator()); ICmpInst *LbCond = cast<ICmpInst>(LbBr->getCondition()); Type *Ty = TripCnt->getType(); diff --git a/lib/Transforms/Scalar/LoopInstSimplify.cpp b/lib/Transforms/Scalar/LoopInstSimplify.cpp index 40d468a084d4..71859efbf4bd 100644 --- a/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -20,8 +20,10 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/DataLayout.h" @@ -34,7 +36,6 @@ #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include <algorithm> #include <utility> @@ -45,118 +46,116 @@ using namespace llvm; STATISTIC(NumSimplified, "Number of redundant instructions simplified"); -static bool SimplifyLoopInst(Loop *L, DominatorTree *DT, LoopInfo *LI, - AssumptionCache *AC, - const TargetLibraryInfo *TLI) { - SmallVector<BasicBlock *, 8> ExitBlocks; - L->getUniqueExitBlocks(ExitBlocks); - array_pod_sort(ExitBlocks.begin(), ExitBlocks.end()); - +static bool simplifyLoopInst(Loop &L, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, + const TargetLibraryInfo &TLI) { + const DataLayout &DL = L.getHeader()->getModule()->getDataLayout(); + SimplifyQuery SQ(DL, &TLI, &DT, &AC); + + // On the first pass over the loop body we try to simplify every instruction. + // On subsequent passes, we can restrict this to only simplifying instructions + // where the inputs have been updated. We end up needing two sets: one + // containing the instructions we are simplifying in *this* pass, and one for + // the instructions we will want to simplify in the *next* pass.
We use + // pointers so we can swap between two stably allocated sets. SmallPtrSet<const Instruction *, 8> S1, S2, *ToSimplify = &S1, *Next = &S2; - // The bit we are stealing from the pointer represents whether this basic - // block is the header of a subloop, in which case we only process its phis. - using WorklistItem = PointerIntPair<BasicBlock *, 1>; - SmallVector<WorklistItem, 16> VisitStack; - SmallPtrSet<BasicBlock *, 32> Visited; - - bool Changed = false; - bool LocalChanged; - do { - LocalChanged = false; - - VisitStack.clear(); - Visited.clear(); + // Track the PHI nodes that have already been visited during each iteration so + // that we can identify when it is necessary to iterate. + SmallPtrSet<PHINode *, 4> VisitedPHIs; - VisitStack.push_back(WorklistItem(L->getHeader(), false)); + // While simplifying we may discover dead code or cause code to become dead. + // Keep track of all such instructions and we will delete them at the end. + SmallVector<Instruction *, 8> DeadInsts; - while (!VisitStack.empty()) { - WorklistItem Item = VisitStack.pop_back_val(); - BasicBlock *BB = Item.getPointer(); - bool IsSubloopHeader = Item.getInt(); - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + // First we want to create an RPO traversal of the loop body. By processing in + // RPO we can ensure that definitions are processed prior to uses (for non PHI + // uses) in all cases. This ensures we maximize the simplifications in each + // iteration over the loop and minimizes the possible causes for continuing to + // iterate. + LoopBlocksRPO RPOT(&L); + RPOT.perform(&LI); - // Simplify instructions in the current basic block. - for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) { - Instruction *I = &*BI++; - - // The first time through the loop ToSimplify is empty and we try to - // simplify all instructions. On later iterations ToSimplify is not - // empty and we only bother simplifying instructions that are in it. - if (!ToSimplify->empty() && !ToSimplify->count(I)) + bool Changed = false; + for (;;) { + for (BasicBlock *BB : RPOT) { + for (Instruction &I : *BB) { + if (auto *PI = dyn_cast<PHINode>(&I)) + VisitedPHIs.insert(PI); + + if (I.use_empty()) { + if (isInstructionTriviallyDead(&I, &TLI)) + DeadInsts.push_back(&I); continue; - - // Don't bother simplifying unused instructions. - if (!I->use_empty()) { - Value *V = SimplifyInstruction(I, {DL, TLI, DT, AC}); - if (V && LI->replacementPreservesLCSSAForm(I, V)) { - // Mark all uses for resimplification next time round the loop. - for (User *U : I->users()) - Next->insert(cast<Instruction>(U)); - - I->replaceAllUsesWith(V); - LocalChanged = true; - ++NumSimplified; - } - } - if (RecursivelyDeleteTriviallyDeadInstructions(I, TLI)) { - // RecursivelyDeleteTriviallyDeadInstruction can remove more than one - // instruction, so simply incrementing the iterator does not work. - // When instructions get deleted re-iterate instead. - BI = BB->begin(); - BE = BB->end(); - LocalChanged = true; } - if (IsSubloopHeader && !isa<PHINode>(I)) - break; - } + // We special case the first iteration which we can detect due to the + // empty `ToSimplify` set. + bool IsFirstIteration = ToSimplify->empty(); - // Add all successors to the worklist, except for loop exit blocks and the - // bodies of subloops. We visit the headers of loops so that we can - // process - // their phis, but we contract the rest of the subloop body and only - // follow - // edges leading back to the original loop. 
- for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; - ++SI) { - BasicBlock *SuccBB = *SI; - if (!Visited.insert(SuccBB).second) + if (!IsFirstIteration && !ToSimplify->count(&I)) continue; - const Loop *SuccLoop = LI->getLoopFor(SuccBB); - if (SuccLoop && SuccLoop->getHeader() == SuccBB && - L->contains(SuccLoop)) { - VisitStack.push_back(WorklistItem(SuccBB, true)); - - SmallVector<BasicBlock *, 8> SubLoopExitBlocks; - SuccLoop->getExitBlocks(SubLoopExitBlocks); - - for (unsigned i = 0; i < SubLoopExitBlocks.size(); ++i) { - BasicBlock *ExitBB = SubLoopExitBlocks[i]; - if (LI->getLoopFor(ExitBB) == L && Visited.insert(ExitBB).second) - VisitStack.push_back(WorklistItem(ExitBB, false)); - } - + Value *V = SimplifyInstruction(&I, SQ.getWithInstruction(&I)); + if (!V || !LI.replacementPreservesLCSSAForm(&I, V)) continue; - } - bool IsExitBlock = - std::binary_search(ExitBlocks.begin(), ExitBlocks.end(), SuccBB); - if (IsExitBlock) - continue; + for (Value::use_iterator UI = I.use_begin(), UE = I.use_end(); + UI != UE;) { + Use &U = *UI++; + auto *UserI = cast<Instruction>(U.getUser()); + U.set(V); + + // If the instruction is used by a PHI node we have already processed + // we'll need to iterate on the loop body to converge, so add it to + // the next set. + if (auto *UserPI = dyn_cast<PHINode>(UserI)) + if (VisitedPHIs.count(UserPI)) { + Next->insert(UserPI); + continue; + } + + // If we are only simplifying targeted instructions and the user is an + // instruction in the loop body, add it to our set of targeted + // instructions. Because we process defs before uses (outside of PHIs) + // we won't have visited it yet. + // + // We also skip any uses outside of the loop being simplified. Those + // should always be PHI nodes due to LCSSA form, and we don't want to + // try to simplify those away. + assert((L.contains(UserI) || isa<PHINode>(UserI)) && + "Uses outside the loop should be PHI nodes due to LCSSA!"); + if (!IsFirstIteration && L.contains(UserI)) + ToSimplify->insert(UserI); + } - VisitStack.push_back(WorklistItem(SuccBB, false)); + assert(I.use_empty() && "Should always have replaced all uses!"); + if (isInstructionTriviallyDead(&I, &TLI)) + DeadInsts.push_back(&I); + ++NumSimplified; + Changed = true; } } - // Place the list of instructions to simplify on the next loop iteration - // into ToSimplify. - std::swap(ToSimplify, Next); - Next->clear(); + // Delete any dead instructions found thus far now that we've finished an + // iteration over all instructions in all the loop blocks. + if (!DeadInsts.empty()) { + Changed = true; + RecursivelyDeleteTriviallyDeadInstructions(DeadInsts, &TLI); + } + + // If we never found a PHI that needs to be simplified in the next + // iteration, we're done. + if (Next->empty()) + break; - Changed |= LocalChanged; - } while (LocalChanged); + // Otherwise, put the next set in place for the next iteration and reset it + // and the visited PHIs for that iteration. + std::swap(Next, ToSimplify); + Next->clear(); + VisitedPHIs.clear(); + DeadInsts.clear(); + } return Changed; } @@ -174,21 +173,20 @@ public: bool runOnLoop(Loop *L, LPPassManager &LPM) override { if (skipLoop(L)) return false; - DominatorTreeWrapperPass *DTWP = - getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - DominatorTree *DT = DTWP ? 
&DTWP->getDomTree() : nullptr; - LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - AssumptionCache *AC = - &getAnalysis<AssumptionCacheTracker>().getAssumptionCache( + DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + AssumptionCache &AC = + getAnalysis<AssumptionCacheTracker>().getAssumptionCache( *L->getHeader()->getParent()); - const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + const TargetLibraryInfo &TLI = + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - return SimplifyLoopInst(L, DT, LI, AC, TLI); + return simplifyLoopInst(*L, DT, LI, AC, TLI); } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.setPreservesCFG(); getLoopAnalysisUsage(AU); @@ -200,7 +198,7 @@ public: PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { - if (!SimplifyLoopInst(&L, &AR.DT, &AR.LI, &AR.AC, &AR.TLI)) + if (!simplifyLoopInst(L, AR.DT, AR.LI, AR.AC, AR.TLI)) return PreservedAnalyses::all(); auto PA = getLoopPassPreservedAnalyses(); diff --git a/lib/Transforms/Scalar/LoopInterchange.cpp b/lib/Transforms/Scalar/LoopInterchange.cpp index 4f8dafef230a..2978165ed8a9 100644 --- a/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/lib/Transforms/Scalar/LoopInterchange.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/DependenceAnalysis.h" @@ -40,6 +41,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include <cassert> @@ -50,6 +52,8 @@ using namespace llvm; #define DEBUG_TYPE "loop-interchange" +STATISTIC(LoopsInterchanged, "Number of loops interchanged"); + static cl::opt<int> LoopInterchangeCostThreshold( "loop-interchange-threshold", cl::init(0), cl::Hidden, cl::desc("Interchange if you gain more than this number")); @@ -73,8 +77,8 @@ static const unsigned MaxLoopNestDepth = 10; static void printDepMatrix(CharMatrix &DepMatrix) { for (auto &Row : DepMatrix) { for (auto D : Row) - DEBUG(dbgs() << D << " "); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << D << " "); + LLVM_DEBUG(dbgs() << "\n"); } } #endif @@ -103,8 +107,8 @@ static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, } } - DEBUG(dbgs() << "Found " << MemInstr.size() - << " Loads and Stores to analyze\n"); + LLVM_DEBUG(dbgs() << "Found " << MemInstr.size() + << " Loads and Stores to analyze\n"); ValueVector::iterator I, IE, J, JE; @@ -121,11 +125,11 @@ static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, // Track Output, Flow, and Anti dependencies. if (auto D = DI->depends(Src, Dst, true)) { assert(D->isOrdered() && "Expected an output, flow or anti dep."); - DEBUG(StringRef DepType = - D->isFlow() ? "flow" : D->isAnti() ? "anti" : "output"; - dbgs() << "Found " << DepType - << " dependency between Src and Dst\n" - << " Src:" << *Src << "\n Dst:" << *Dst << '\n'); + LLVM_DEBUG(StringRef DepType = + D->isFlow() ? "flow" : D->isAnti() ? 
"anti" : "output"; + dbgs() << "Found " << DepType + << " dependency between Src and Dst\n" + << " Src:" << *Src << "\n Dst:" << *Dst << '\n'); unsigned Levels = D->getLevels(); char Direction; for (unsigned II = 1; II <= Levels; ++II) { @@ -165,17 +169,14 @@ static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, DepMatrix.push_back(Dep); if (DepMatrix.size() > MaxMemInstrCount) { - DEBUG(dbgs() << "Cannot handle more than " << MaxMemInstrCount - << " dependencies inside loop\n"); + LLVM_DEBUG(dbgs() << "Cannot handle more than " << MaxMemInstrCount + << " dependencies inside loop\n"); return false; } } } } - // We don't have a DepMatrix to check legality return false. - if (DepMatrix.empty()) - return false; return true; } @@ -271,9 +272,9 @@ static bool isLegalToInterChangeLoops(CharMatrix &DepMatrix, } static void populateWorklist(Loop &L, SmallVector<LoopVector, 8> &V) { - DEBUG(dbgs() << "Calling populateWorklist on Func: " - << L.getHeader()->getParent()->getName() << " Loop: %" - << L.getHeader()->getName() << '\n'); + LLVM_DEBUG(dbgs() << "Calling populateWorklist on Func: " + << L.getHeader()->getParent()->getName() << " Loop: %" + << L.getHeader()->getName() << '\n'); LoopVector LoopList; Loop *CurrentLoop = &L; const std::vector<Loop *> *Vec = &CurrentLoop->getSubLoops(); @@ -404,7 +405,9 @@ public: /// Interchange OuterLoop and InnerLoop. bool transform(); - void restructureLoops(Loop *InnerLoop, Loop *OuterLoop); + void restructureLoops(Loop *NewInner, Loop *NewOuter, + BasicBlock *OrigInnerPreHeader, + BasicBlock *OrigOuterPreHeader); void removeChildLoop(Loop *OuterLoop, Loop *InnerLoop); private: @@ -453,6 +456,9 @@ struct LoopInterchange : public FunctionPass { AU.addRequiredID(LoopSimplifyID); AU.addRequiredID(LCSSAID); AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); + + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); } bool runOnFunction(Function &F) override { @@ -462,8 +468,7 @@ struct LoopInterchange : public FunctionPass { SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); DI = &getAnalysis<DependenceAnalysisWrapperPass>().getDI(); - auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - DT = DTWP ? 
&DTWP->getDomTree() : nullptr; + DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); @@ -473,7 +478,7 @@ struct LoopInterchange : public FunctionPass { for (Loop *L : *LI) populateWorklist(*L, Worklist); - DEBUG(dbgs() << "Worklist size = " << Worklist.size() << "\n"); + LLVM_DEBUG(dbgs() << "Worklist size = " << Worklist.size() << "\n"); bool Changed = true; while (!Worklist.empty()) { LoopVector LoopList = Worklist.pop_back_val(); @@ -486,15 +491,15 @@ struct LoopInterchange : public FunctionPass { for (Loop *L : LoopList) { const SCEV *ExitCountOuter = SE->getBackedgeTakenCount(L); if (ExitCountOuter == SE->getCouldNotCompute()) { - DEBUG(dbgs() << "Couldn't compute backedge count\n"); + LLVM_DEBUG(dbgs() << "Couldn't compute backedge count\n"); return false; } if (L->getNumBackEdges() != 1) { - DEBUG(dbgs() << "NumBackEdges is not equal to 1\n"); + LLVM_DEBUG(dbgs() << "NumBackEdges is not equal to 1\n"); return false; } if (!L->getExitingBlock()) { - DEBUG(dbgs() << "Loop doesn't have unique exit block\n"); + LLVM_DEBUG(dbgs() << "Loop doesn't have unique exit block\n"); return false; } } @@ -511,53 +516,38 @@ struct LoopInterchange : public FunctionPass { bool Changed = false; unsigned LoopNestDepth = LoopList.size(); if (LoopNestDepth < 2) { - DEBUG(dbgs() << "Loop doesn't contain minimum nesting level.\n"); + LLVM_DEBUG(dbgs() << "Loop doesn't contain minimum nesting level.\n"); return false; } if (LoopNestDepth > MaxLoopNestDepth) { - DEBUG(dbgs() << "Cannot handle loops of depth greater than " - << MaxLoopNestDepth << "\n"); + LLVM_DEBUG(dbgs() << "Cannot handle loops of depth greater than " + << MaxLoopNestDepth << "\n"); return false; } if (!isComputableLoopNest(LoopList)) { - DEBUG(dbgs() << "Not valid loop candidate for interchange\n"); + LLVM_DEBUG(dbgs() << "Not valid loop candidate for interchange\n"); return false; } - DEBUG(dbgs() << "Processing LoopList of size = " << LoopNestDepth << "\n"); + LLVM_DEBUG(dbgs() << "Processing LoopList of size = " << LoopNestDepth + << "\n"); CharMatrix DependencyMatrix; Loop *OuterMostLoop = *(LoopList.begin()); if (!populateDependencyMatrix(DependencyMatrix, LoopNestDepth, OuterMostLoop, DI)) { - DEBUG(dbgs() << "Populating dependency matrix failed\n"); + LLVM_DEBUG(dbgs() << "Populating dependency matrix failed\n"); return false; } #ifdef DUMP_DEP_MATRICIES - DEBUG(dbgs() << "Dependence before interchange\n"); + LLVM_DEBUG(dbgs() << "Dependence before interchange\n"); printDepMatrix(DependencyMatrix); #endif - BasicBlock *OuterMostLoopLatch = OuterMostLoop->getLoopLatch(); - BranchInst *OuterMostLoopLatchBI = - dyn_cast<BranchInst>(OuterMostLoopLatch->getTerminator()); - if (!OuterMostLoopLatchBI) - return false; - - // Since we currently do not handle LCSSA PHI's any failure in loop - // condition will now branch to LoopNestExit. - // TODO: This should be removed once we handle LCSSA PHI nodes. - // Get the Outermost loop exit. 
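As a reminder of what the transform being gated here does: interchange swaps two adjacent loops in a nest, typically so that the innermost loop gets the unit-stride memory accesses. A schematic example, assuming row-major C++ arrays:

    // Before: the inner loop walks a column of A, so consecutive iterations
    // touch addresses a whole row apart and cache locality is poor.
    for (int i = 0; i < N; ++i)
      for (int j = 0; j < M; ++j)
        A[j][i] += 1;

    // After interchange: the inner loop walks a row and the accesses are
    // contiguous.
    for (int j = 0; j < M; ++j)
      for (int i = 0; i < N; ++i)
        A[j][i] += 1;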
- BasicBlock *LoopNestExit; - if (OuterMostLoopLatchBI->getSuccessor(0) == OuterMostLoop->getHeader()) - LoopNestExit = OuterMostLoopLatchBI->getSuccessor(1); - else - LoopNestExit = OuterMostLoopLatchBI->getSuccessor(0); - - if (isa<PHINode>(LoopNestExit->begin())) { - DEBUG(dbgs() << "PHI Nodes in loop nest exit is not handled for now " "since on failure all loops branch to loop nest exit.\n"); + BasicBlock *LoopNestExit = OuterMostLoop->getExitBlock(); + if (!LoopNestExit) { + LLVM_DEBUG(dbgs() << "OuterMostLoop needs a unique exit block\n"); return false; } @@ -573,9 +563,8 @@ struct LoopInterchange : public FunctionPass { // Update the DependencyMatrix interChangeDependencies(DependencyMatrix, i, i - 1); - DT->recalculate(F); #ifdef DUMP_DEP_MATRICIES - DEBUG(dbgs() << "Dependence after interchange\n"); + LLVM_DEBUG(dbgs() << "Dependence after interchange\n"); printDepMatrix(DependencyMatrix); #endif Changed |= Interchanged; @@ -586,21 +575,21 @@ struct LoopInterchange : public FunctionPass { bool processLoop(LoopVector LoopList, unsigned InnerLoopId, unsigned OuterLoopId, BasicBlock *LoopNestExit, std::vector<std::vector<char>> &DependencyMatrix) { - DEBUG(dbgs() << "Processing Inner Loop Id = " << InnerLoopId - << " and OuterLoopId = " << OuterLoopId << "\n"); + LLVM_DEBUG(dbgs() << "Processing Inner Loop Id = " << InnerLoopId + << " and OuterLoopId = " << OuterLoopId << "\n"); Loop *InnerLoop = LoopList[InnerLoopId]; Loop *OuterLoop = LoopList[OuterLoopId]; LoopInterchangeLegality LIL(OuterLoop, InnerLoop, SE, LI, DT, PreserveLCSSA, ORE); if (!LIL.canInterchangeLoops(InnerLoopId, OuterLoopId, DependencyMatrix)) { - DEBUG(dbgs() << "Not interchanging Loops. Cannot prove legality\n"); + LLVM_DEBUG(dbgs() << "Not interchanging loops. Cannot prove legality.\n"); return false; } - DEBUG(dbgs() << "Loops are legal to interchange\n"); + LLVM_DEBUG(dbgs() << "Loops are legal to interchange\n"); LoopInterchangeProfitability LIP(OuterLoop, InnerLoop, SE, ORE); if (!LIP.isProfitable(InnerLoopId, OuterLoopId, DependencyMatrix)) { - DEBUG(dbgs() << "Interchanging loops not profitable\n"); + LLVM_DEBUG(dbgs() << "Interchanging loops not profitable.\n"); return false; } @@ -614,7 +603,8 @@ struct LoopInterchange : public FunctionPass { LoopInterchangeTransform LIT(OuterLoop, InnerLoop, SE, LI, DT, LoopNestExit, LIL.hasInnerLoopReduction()); LIT.transform(); - DEBUG(dbgs() << "Loops interchanged\n"); + LLVM_DEBUG(dbgs() << "Loops interchanged.\n"); + LoopsInterchanged++; return true; } }; @@ -631,13 +621,13 @@ bool LoopInterchangeLegality::areAllUsesReductions(Instruction *Ins, Loop *L) { bool LoopInterchangeLegality::containsUnsafeInstructionsInHeader( BasicBlock *BB) { - for (auto I = BB->begin(), E = BB->end(); I != E; ++I) { + for (Instruction &I : *BB) { // Load corresponding to reduction PHI's are safe while concluding if // tightly nested. - if (LoadInst *L = dyn_cast<LoadInst>(I)) { + if (LoadInst *L = dyn_cast<LoadInst>(&I)) { if (!areAllUsesReductions(L, InnerLoop)) return true; - } else if (I->mayHaveSideEffects() || I->mayReadFromMemory()) + } else if (I.mayHaveSideEffects() || I.mayReadFromMemory()) return true; } return false; } @@ -645,13 +635,13 @@ bool LoopInterchangeLegality::containsUnsafeInstructionsInLatch( BasicBlock *BB) { - for (auto I = BB->begin(), E = BB->end(); I != E; ++I) { + for (Instruction &I : *BB) { // Stores corresponding to reductions are safe while concluding if tightly // nested.
- if (StoreInst *L = dyn_cast<StoreInst>(I)) { + if (StoreInst *L = dyn_cast<StoreInst>(&I)) { if (!isa<PHINode>(L->getOperand(0))) return true; - } else if (I->mayHaveSideEffects() || I->mayReadFromMemory()) + } else if (I.mayHaveSideEffects() || I.mayReadFromMemory()) return true; } return false; @@ -662,7 +652,7 @@ bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) { BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch(); - DEBUG(dbgs() << "Checking if loops are tightly nested\n"); + LLVM_DEBUG(dbgs() << "Checking if loops are tightly nested\n"); // A perfectly nested loop will not have any branch in between the outer and // inner block i.e. outer header will branch to either inner preheader and @@ -676,14 +666,14 @@ bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) { if (Succ != InnerLoopPreHeader && Succ != OuterLoopLatch) return false; - DEBUG(dbgs() << "Checking instructions in Loop header and Loop latch\n"); + LLVM_DEBUG(dbgs() << "Checking instructions in Loop header and Loop latch\n"); // We do not have any basic block in between now make sure the outer header // and outer loop latch doesn't contain any unsafe instructions. if (containsUnsafeInstructionsInHeader(OuterLoopHeader) || containsUnsafeInstructionsInLatch(OuterLoopLatch)) return false; - DEBUG(dbgs() << "Loops are perfectly nested\n"); + LLVM_DEBUG(dbgs() << "Loops are perfectly nested\n"); // We have a perfect loop nest. return true; } @@ -717,16 +707,15 @@ bool LoopInterchangeLegality::findInductionAndReductions( SmallVector<PHINode *, 8> &Reductions) { if (!L->getLoopLatch() || !L->getLoopPredecessor()) return false; - for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) { + for (PHINode &PHI : L->getHeader()->phis()) { RecurrenceDescriptor RD; InductionDescriptor ID; - PHINode *PHI = cast<PHINode>(I); - if (InductionDescriptor::isInductionPHI(PHI, L, SE, ID)) - Inductions.push_back(PHI); - else if (RecurrenceDescriptor::isReductionPHI(PHI, L, RD)) - Reductions.push_back(PHI); + if (InductionDescriptor::isInductionPHI(&PHI, L, SE, ID)) + Inductions.push_back(&PHI); + else if (RecurrenceDescriptor::isReductionPHI(&PHI, L, RD)) + Reductions.push_back(&PHI); else { - DEBUG( + LLVM_DEBUG( dbgs() << "Failed to recognize PHI as an induction or reduction.\n"); return false; } @@ -735,12 +724,11 @@ bool LoopInterchangeLegality::findInductionAndReductions( } static bool containsSafePHI(BasicBlock *Block, bool isOuterLoopExitBlock) { - for (auto I = Block->begin(); isa<PHINode>(I); ++I) { - PHINode *PHI = cast<PHINode>(I); + for (PHINode &PHI : Block->phis()) { // Reduction lcssa phi will have only 1 incoming block that from loop latch. 
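Much of the churn in these hunks is the mechanical DEBUG -> LLVM_DEBUG rename that swept the LLVM tree around this time; the macro's semantics are unchanged. For reference:

    #define DEBUG_TYPE "loop-interchange"
    // Compiled out entirely in NDEBUG builds; in asserts builds the body
    // runs only under -debug or -debug-only=loop-interchange.
    LLVM_DEBUG(dbgs() << "Checking if loops are tightly nested\n");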
- if (PHI->getNumIncomingValues() > 1) + if (PHI.getNumIncomingValues() > 1) return false; - Instruction *Ins = dyn_cast<Instruction>(PHI->getIncomingValue(0)); + Instruction *Ins = dyn_cast<Instruction>(PHI.getIncomingValue(0)); if (!Ins) return false; // Incoming value for lcssa phi's in outer loop exit can only be inner loop @@ -751,35 +739,38 @@ static bool containsSafePHI(BasicBlock *Block, bool isOuterLoopExitBlock) { return true; } -static BasicBlock *getLoopLatchExitBlock(BasicBlock *LatchBlock, - BasicBlock *LoopHeader) { - if (BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator())) { - assert(BI->getNumSuccessors() == 2 && - "Branch leaving loop latch must have 2 successors"); - for (BasicBlock *Succ : BI->successors()) { - if (Succ == LoopHeader) - continue; - return Succ; - } - } - return nullptr; -} - // This function indicates the current limitations in the transform as a result // of which we do not proceed. bool LoopInterchangeLegality::currentLimitations() { BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); - BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); BasicBlock *InnerLoopLatch = InnerLoop->getLoopLatch(); - BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch(); - BasicBlock *OuterLoopHeader = OuterLoop->getHeader(); + + // The transform currently expects the loop latches to also be the exiting + // blocks. + if (InnerLoop->getExitingBlock() != InnerLoopLatch || + OuterLoop->getExitingBlock() != OuterLoop->getLoopLatch() || + !isa<BranchInst>(InnerLoopLatch->getTerminator()) || + !isa<BranchInst>(OuterLoop->getLoopLatch()->getTerminator())) { + LLVM_DEBUG( + dbgs() << "Loops where the latch is not the exiting block are not" + << " supported currently.\n"); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "ExitingNotLatch", + OuterLoop->getStartLoc(), + OuterLoop->getHeader()) + << "Loops where the latch is not the exiting block cannot be" + " interchanged currently."; + }); + return true; + } PHINode *InnerInductionVar; SmallVector<PHINode *, 8> Inductions; SmallVector<PHINode *, 8> Reductions; if (!findInductionAndReductions(InnerLoop, Inductions, Reductions)) { - DEBUG(dbgs() << "Only inner loops with induction or reduction PHI nodes " << "are supported currently.\n"); + LLVM_DEBUG( dbgs() << "Only inner loops with induction or reduction PHI nodes " << "are supported currently.\n"); ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedPHIInner", InnerLoop->getStartLoc(), @@ -792,8 +783,9 @@ bool LoopInterchangeLegality::currentLimitations() { // TODO: Currently we handle only loops with 1 induction variable. if (Inductions.size() != 1) { - DEBUG(dbgs() << "We currently only support loops with 1 induction variable." - << "Failed to interchange due to current limitation\n"); + LLVM_DEBUG( + dbgs() << "We currently only support loops with 1 induction variable."
+ << "Failed to interchange due to current limitation\n"); ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "MultiInductionInner", InnerLoop->getStartLoc(), @@ -809,8 +801,9 @@ bool LoopInterchangeLegality::currentLimitations() { InnerInductionVar = Inductions.pop_back_val(); Reductions.clear(); if (!findInductionAndReductions(OuterLoop, Inductions, Reductions)) { - DEBUG(dbgs() << "Only outer loops with induction or reduction PHI nodes " - << "are supported currently.\n"); + LLVM_DEBUG( + dbgs() << "Only outer loops with induction or reduction PHI nodes " + << "are supported currently.\n"); ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedPHIOuter", OuterLoop->getStartLoc(), @@ -824,8 +817,8 @@ bool LoopInterchangeLegality::currentLimitations() { // Outer loop cannot have reduction because then loops will not be tightly // nested. if (!Reductions.empty()) { - DEBUG(dbgs() << "Outer loops with reductions are not supported " - << "currently.\n"); + LLVM_DEBUG(dbgs() << "Outer loops with reductions are not supported " + << "currently.\n"); ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "ReductionsOuter", OuterLoop->getStartLoc(), @@ -837,8 +830,8 @@ bool LoopInterchangeLegality::currentLimitations() { } // TODO: Currently we handle only loops with 1 induction variable. if (Inductions.size() != 1) { - DEBUG(dbgs() << "Loops with more than 1 induction variables are not " - << "supported currently.\n"); + LLVM_DEBUG(dbgs() << "Loops with more than 1 induction variables are not " + << "supported currently.\n"); ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "MultiIndutionOuter", OuterLoop->getStartLoc(), @@ -851,7 +844,7 @@ bool LoopInterchangeLegality::currentLimitations() { // TODO: Triangular loops are not handled for now. if (!isLoopStructureUnderstood(InnerInductionVar)) { - DEBUG(dbgs() << "Loop structure not understood by pass\n"); + LLVM_DEBUG(dbgs() << "Loop structure not understood by pass\n"); ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedStructureInner", InnerLoop->getStartLoc(), @@ -862,23 +855,10 @@ bool LoopInterchangeLegality::currentLimitations() { } // TODO: We only handle LCSSA PHI's corresponding to reduction for now. 
- BasicBlock *LoopExitBlock = - getLoopLatchExitBlock(OuterLoopLatch, OuterLoopHeader); - if (!LoopExitBlock || !containsSafePHI(LoopExitBlock, true)) { - DEBUG(dbgs() << "Can only handle LCSSA PHIs in outer loops currently.\n"); - ORE->emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "NoLCSSAPHIOuter", - OuterLoop->getStartLoc(), - OuterLoop->getHeader()) - << "Only outer loops with LCSSA PHIs can be interchange " - "currently."; - }); - return true; - } - - LoopExitBlock = getLoopLatchExitBlock(InnerLoopLatch, InnerLoopHeader); - if (!LoopExitBlock || !containsSafePHI(LoopExitBlock, false)) { - DEBUG(dbgs() << "Can only handle LCSSA PHIs in inner loops currently.\n"); + BasicBlock *InnerExit = InnerLoop->getExitBlock(); + if (!containsSafePHI(InnerExit, false)) { + LLVM_DEBUG( + dbgs() << "Can only handle LCSSA PHIs in inner loops currently.\n"); ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "NoLCSSAPHIOuterInner", InnerLoop->getStartLoc(), @@ -908,8 +888,9 @@ bool LoopInterchangeLegality::currentLimitations() { dyn_cast<Instruction>(InnerInductionVar->getIncomingValue(0)); if (!InnerIndexVarInc) { - DEBUG(dbgs() << "Did not find an instruction to increment the induction " << "variable.\n"); + LLVM_DEBUG( + dbgs() << "Did not find an instruction to increment the induction " + << "variable.\n"); ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "NoIncrementInInner", InnerLoop->getStartLoc(), @@ -924,7 +905,8 @@ // instruction. bool FoundInduction = false; - for (const Instruction &I : llvm::reverse(*InnerLoopLatch)) { + for (const Instruction &I : + llvm::reverse(InnerLoopLatch->instructionsWithoutDebug())) { if (isa<BranchInst>(I) || isa<CmpInst>(I) || isa<TruncInst>(I) || isa<ZExtInst>(I)) continue; // We found an instruction. If this is not induction variable then it is not // safe to split this loop latch. if (!I.isIdenticalTo(InnerIndexVarInc)) { - DEBUG(dbgs() << "Found unsupported instructions between induction " << "variable increment and branch.\n"); + LLVM_DEBUG(dbgs() << "Found unsupported instructions between induction " << "variable increment and branch.\n"); ORE->emit([&]() { return OptimizationRemarkMissed( DEBUG_TYPE, "UnsupportedInsBetweenInduction", @@ -950,7 +932,7 @@ // The loop latch ended and we didn't find the induction variable return as // current limitation. if (!FoundInduction) { - DEBUG(dbgs() << "Did not find the induction variable.\n"); + LLVM_DEBUG(dbgs() << "Did not find the induction variable.\n"); ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "NoIndutionVariable", InnerLoop->getStartLoc(), @@ -962,13 +944,50 @@ return false; } +// We currently support LCSSA PHI nodes in the outer loop exit, if their +// incoming values do not come from the outer loop latch or if the +// outer loop latch has a single predecessor. In that case, the value will +// be available if both the inner and outer loop conditions are true, which +// will still be true after interchanging. If we have multiple predecessors, +// that may not be the case, e.g. because the outer loop latch may be executed +// if the inner loop is not executed.
+static bool areLoopExitPHIsSupported(Loop *OuterLoop, Loop *InnerLoop) { + BasicBlock *LoopNestExit = OuterLoop->getUniqueExitBlock(); + for (PHINode &PHI : LoopNestExit->phis()) { + // FIXME: We currently are not able to detect floating point reductions + // and have to use floating point PHIs as a proxy to prevent + // interchanging in the presence of floating point reductions. + if (PHI.getType()->isFloatingPointTy()) + return false; + for (unsigned i = 0; i < PHI.getNumIncomingValues(); i++) { + Instruction *IncomingI = dyn_cast<Instruction>(PHI.getIncomingValue(i)); + if (!IncomingI || IncomingI->getParent() != OuterLoop->getLoopLatch()) + continue; + + // The incoming value is defined in the outer loop latch. Currently we + // only support that in case the outer loop latch has a single predecessor. + // This guarantees that the outer loop latch is executed if and only if + // the inner loop is executed (because tightlyNested() guarantees that the + // outer loop header only branches to the inner loop or the outer loop + // latch). + // FIXME: We could weaken this logic and allow multiple predecessors, + // if the values are produced outside the loop latch. We would need + // additional logic to update the PHI nodes in the exit block as + // well. + if (OuterLoop->getLoopLatch()->getUniquePredecessor() == nullptr) + return false; + } + } + return true; +} + bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId, unsigned OuterLoopId, CharMatrix &DepMatrix) { if (!isLegalToInterChangeLoops(DepMatrix, InnerLoopId, OuterLoopId)) { - DEBUG(dbgs() << "Failed interchange InnerLoopId = " << InnerLoopId - << " and OuterLoopId = " << OuterLoopId - << " due to dependence\n"); + LLVM_DEBUG(dbgs() << "Failed interchange InnerLoopId = " << InnerLoopId + << " and OuterLoopId = " << OuterLoopId + << " due to dependence\n"); ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "Dependence", InnerLoop->getStartLoc(), @@ -977,16 +996,23 @@ bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId, }); return false; } - // Check if outer and inner loop contain legal instructions only. for (auto *BB : OuterLoop->blocks()) - for (Instruction &I : *BB) + for (Instruction &I : BB->instructionsWithoutDebug()) if (CallInst *CI = dyn_cast<CallInst>(&I)) { // readnone functions do not prevent interchanging. if (CI->doesNotReadMemory()) continue; - DEBUG(dbgs() << "Loops with call instructions cannot be interchanged " - << "safely."); + LLVM_DEBUG( + dbgs() << "Loops with call instructions cannot be interchanged " + << "safely."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "CallInst", + CI->getDebugLoc(), + CI->getParent()) + << "Cannot interchange loops due to call instruction."; + }); + return false; } @@ -1015,13 +1041,13 @@ bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId, // TODO: The loops could not be interchanged due to current limitations in the // transform module. if (currentLimitations()) { - DEBUG(dbgs() << "Not legal because of current transform limitation\n"); + LLVM_DEBUG(dbgs() << "Not legal because of current transform limitation\n"); return false; } // Check if the loops are tightly nested. 
if (!tightlyNested(OuterLoop, InnerLoop)) { - DEBUG(dbgs() << "Loops not tightly nested\n"); + LLVM_DEBUG(dbgs() << "Loops not tightly nested\n"); ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "NotTightlyNested", InnerLoop->getStartLoc(), @@ -1032,6 +1058,17 @@ bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId, return false; } + if (!areLoopExitPHIsSupported(OuterLoop, InnerLoop)) { + LLVM_DEBUG(dbgs() << "Found unsupported PHI nodes in outer loop exit.\n"); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedExitPHI", + OuterLoop->getStartLoc(), + OuterLoop->getHeader()) + << "Found unsupported PHI node in loop exit."; + }); + return false; + } + return true; } @@ -1100,7 +1137,8 @@ static bool isProfitableForVectorization(unsigned InnerLoopId, } // If outer loop has dependence and inner loop is loop independent then it is // profitable to interchange to enable parallelism. - return true; + // If there are no dependences, interchanging will not improve anything. + return !DepMatrix.empty(); } bool LoopInterchangeProfitability::isProfitable(unsigned InnerLoopId, @@ -1115,7 +1153,7 @@ bool LoopInterchangeProfitability::isProfitable(unsigned InnerLoopId, // of induction variables in the instruction and allows reordering if number // of bad orders is more than good. int Cost = getInstrOrderCost(); - DEBUG(dbgs() << "Cost = " << Cost << "\n"); + LLVM_DEBUG(dbgs() << "Cost = " << Cost << "\n"); if (Cost < -LoopInterchangeCostThreshold) return true; @@ -1138,33 +1176,88 @@ bool LoopInterchangeProfitability::isProfitable(unsigned InnerLoopId, void LoopInterchangeTransform::removeChildLoop(Loop *OuterLoop, Loop *InnerLoop) { - for (Loop::iterator I = OuterLoop->begin(), E = OuterLoop->end(); I != E; - ++I) { - if (*I == InnerLoop) { - OuterLoop->removeChildLoop(I); + for (Loop *L : *OuterLoop) + if (L == InnerLoop) { + OuterLoop->removeChildLoop(L); return; } - } llvm_unreachable("Couldn't find loop"); } -void LoopInterchangeTransform::restructureLoops(Loop *InnerLoop, - Loop *OuterLoop) { +/// Update LoopInfo, after interchanging. NewInner and NewOuter refer to the +/// new inner and outer loop after interchanging: NewInner is the original +/// outer loop and NewOuter is the original inner loop. +/// +/// Before interchanging, we have the following structure +/// Outer preheader +// Outer header +// Inner preheader +// Inner header +// Inner body +// Inner latch +// outer bbs +// Outer latch +// +// After interchanging: +// Inner preheader +// Inner header +// Outer preheader +// Outer header +// Inner body +// outer bbs +// Outer latch +// Inner latch +void LoopInterchangeTransform::restructureLoops( + Loop *NewInner, Loop *NewOuter, BasicBlock *OrigInnerPreHeader, + BasicBlock *OrigOuterPreHeader) { Loop *OuterLoopParent = OuterLoop->getParentLoop(); + // The original inner loop preheader moves from the new inner loop to + // the parent loop, if there is one. + NewInner->removeBlockFromLoop(OrigInnerPreHeader); + LI->changeLoopFor(OrigInnerPreHeader, OuterLoopParent); + + // Switch the loop levels. if (OuterLoopParent) { // Remove the loop from its parent loop. 
- removeChildLoop(OuterLoopParent, OuterLoop); - removeChildLoop(OuterLoop, InnerLoop); - OuterLoopParent->addChildLoop(InnerLoop); + removeChildLoop(OuterLoopParent, NewInner); + removeChildLoop(NewInner, NewOuter); + OuterLoopParent->addChildLoop(NewOuter); } else { - removeChildLoop(OuterLoop, InnerLoop); - LI->changeTopLevelLoop(OuterLoop, InnerLoop); + removeChildLoop(NewInner, NewOuter); + LI->changeTopLevelLoop(NewInner, NewOuter); + } + while (!NewOuter->empty()) + NewInner->addChildLoop(NewOuter->removeChildLoop(NewOuter->begin())); + NewOuter->addChildLoop(NewInner); + + // BBs from the original inner loop. + SmallVector<BasicBlock *, 8> OrigInnerBBs(NewOuter->blocks()); + + // Add BBs from the original outer loop to the original inner loop (excluding + // BBs already in inner loop) + for (BasicBlock *BB : NewInner->blocks()) + if (LI->getLoopFor(BB) == NewInner) + NewOuter->addBlockEntry(BB); + + // Now remove inner loop header and latch from the new inner loop and move + // other BBs (the loop body) to the new inner loop. + BasicBlock *OuterHeader = NewOuter->getHeader(); + BasicBlock *OuterLatch = NewOuter->getLoopLatch(); + for (BasicBlock *BB : OrigInnerBBs) { + // Nothing will change for BBs in child loops. + if (LI->getLoopFor(BB) != NewOuter) + continue; + // Remove the new outer loop header and latch from the new inner loop. + if (BB == OuterHeader || BB == OuterLatch) + NewInner->removeBlockFromLoop(BB); + else + LI->changeLoopFor(BB, NewInner); } - while (!InnerLoop->empty()) - OuterLoop->addChildLoop(InnerLoop->removeChildLoop(InnerLoop->begin())); - - InnerLoop->addChildLoop(OuterLoop); + // The preheader of the original outer loop becomes part of the new + // outer loop. + NewOuter->addBlockEntry(OrigOuterPreHeader); + LI->changeLoopFor(OrigOuterPreHeader, NewOuter); } bool LoopInterchangeTransform::transform() { @@ -1173,10 +1266,10 @@ bool LoopInterchangeTransform::transform() { if (InnerLoop->getSubLoops().empty()) { BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); - DEBUG(dbgs() << "Calling Split Inner Loop\n"); + LLVM_DEBUG(dbgs() << "Calling Split Inner Loop\n"); PHINode *InductionPHI = getInductionVariable(InnerLoop, SE); if (!InductionPHI) { - DEBUG(dbgs() << "Failed to find the point to split loop latch \n"); + LLVM_DEBUG(dbgs() << "Failed to find the point to split loop latch \n"); return false; } @@ -1185,8 +1278,7 @@ bool LoopInterchangeTransform::transform() { else InnerIndexVar = dyn_cast<Instruction>(InductionPHI->getIncomingValue(0)); - // Ensure that InductionPHI is the first Phi node as required by - // splitInnerLoopHeader + // Ensure that InductionPHI is the first Phi node. if (&InductionPHI->getParent()->front() != InductionPHI) InductionPHI->moveBefore(&InductionPHI->getParent()->front()); @@ -1194,20 +1286,20 @@ bool LoopInterchangeTransform::transform() { // incremented/decremented. // TODO: This splitting logic may not work always. Fix this. splitInnerLoopLatch(InnerIndexVar); - DEBUG(dbgs() << "splitInnerLoopLatch done\n"); + LLVM_DEBUG(dbgs() << "splitInnerLoopLatch done\n"); // Splits the inner loops phi nodes out into a separate basic block. 
- splitInnerLoopHeader(); - DEBUG(dbgs() << "splitInnerLoopHeader done\n"); + BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); + SplitBlock(InnerLoopHeader, InnerLoopHeader->getFirstNonPHI(), DT, LI); + LLVM_DEBUG(dbgs() << "splitting InnerLoopHeader done\n"); } Transformed |= adjustLoopLinks(); if (!Transformed) { - DEBUG(dbgs() << "adjustLoopLinks failed\n"); + LLVM_DEBUG(dbgs() << "adjustLoopLinks failed\n"); return false; } - restructureLoops(InnerLoop, OuterLoop); return true; } @@ -1217,38 +1309,6 @@ void LoopInterchangeTransform::splitInnerLoopLatch(Instruction *Inc) { InnerLoopLatch = SplitBlock(InnerLoopLatchPred, Inc, DT, LI); } -void LoopInterchangeTransform::splitInnerLoopHeader() { - // Split the inner loop header out. Here make sure that the reduction PHI's - // stay in the innerloop body. - BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); - BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); - if (InnerLoopHasReduction) { - // Note: The induction PHI must be the first PHI for this to work - BasicBlock *New = InnerLoopHeader->splitBasicBlock( - ++(InnerLoopHeader->begin()), InnerLoopHeader->getName() + ".split"); - if (LI) - if (Loop *L = LI->getLoopFor(InnerLoopHeader)) - L->addBasicBlockToLoop(New, *LI); - - // Adjust Reduction PHI's in the block. - SmallVector<PHINode *, 8> PHIVec; - for (auto I = New->begin(); isa<PHINode>(I); ++I) { - PHINode *PHI = dyn_cast<PHINode>(I); - Value *V = PHI->getIncomingValueForBlock(InnerLoopPreHeader); - PHI->replaceAllUsesWith(V); - PHIVec.push_back((PHI)); - } - for (PHINode *P : PHIVec) { - P->eraseFromParent(); - } - } else { - SplitBlock(InnerLoopHeader, InnerLoopHeader->getFirstNonPHI(), DT, LI); - } - - DEBUG(dbgs() << "Output of splitInnerLoopHeader InnerLoopHeaderSucc & " - "InnerLoopHeader\n"); -} - /// \brief Move all instructions except the terminator from FromBB right before /// InsertBefore static void moveBBContents(BasicBlock *FromBB, Instruction *InsertBefore) { @@ -1262,18 +1322,40 @@ static void moveBBContents(BasicBlock *FromBB, Instruction *InsertBefore) { void LoopInterchangeTransform::updateIncomingBlock(BasicBlock *CurrBlock, BasicBlock *OldPred, BasicBlock *NewPred) { - for (auto I = CurrBlock->begin(); isa<PHINode>(I); ++I) { - PHINode *PHI = cast<PHINode>(I); - unsigned Num = PHI->getNumIncomingValues(); + for (PHINode &PHI : CurrBlock->phis()) { + unsigned Num = PHI.getNumIncomingValues(); for (unsigned i = 0; i < Num; ++i) { - if (PHI->getIncomingBlock(i) == OldPred) - PHI->setIncomingBlock(i, NewPred); + if (PHI.getIncomingBlock(i) == OldPred) + PHI.setIncomingBlock(i, NewPred); + } + } +} + +/// Update BI to jump to NewBB instead of OldBB. Records updates to +/// the dominator tree in DTUpdates, if DT should be preserved. 
+static void updateSuccessor(BranchInst *BI, BasicBlock *OldBB, + BasicBlock *NewBB, + std::vector<DominatorTree::UpdateType> &DTUpdates) { + assert(llvm::count_if(BI->successors(), + [OldBB](BasicBlock *BB) { return BB == OldBB; }) < 2 && + "BI must jump to OldBB at most once."); + for (unsigned i = 0, e = BI->getNumSuccessors(); i < e; ++i) { + if (BI->getSuccessor(i) == OldBB) { + BI->setSuccessor(i, NewBB); + + DTUpdates.push_back( + {DominatorTree::UpdateKind::Insert, BI->getParent(), NewBB}); + DTUpdates.push_back( + {DominatorTree::UpdateKind::Delete, BI->getParent(), OldBB}); + break; } } } bool LoopInterchangeTransform::adjustLoopBranches() { - DEBUG(dbgs() << "adjustLoopBranches called\n"); + LLVM_DEBUG(dbgs() << "adjustLoopBranches called\n"); + std::vector<DominatorTree::UpdateType> DTUpdates; + // Adjust the loop preheader BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); BasicBlock *OuterLoopHeader = OuterLoop->getHeader(); @@ -1313,27 +1395,18 @@ bool LoopInterchangeTransform::adjustLoopBranches() { return false; // Adjust Loop Preheader and headers - - unsigned NumSucc = OuterLoopPredecessorBI->getNumSuccessors(); - for (unsigned i = 0; i < NumSucc; ++i) { - if (OuterLoopPredecessorBI->getSuccessor(i) == OuterLoopPreHeader) - OuterLoopPredecessorBI->setSuccessor(i, InnerLoopPreHeader); - } - - NumSucc = OuterLoopHeaderBI->getNumSuccessors(); - for (unsigned i = 0; i < NumSucc; ++i) { - if (OuterLoopHeaderBI->getSuccessor(i) == OuterLoopLatch) - OuterLoopHeaderBI->setSuccessor(i, LoopExit); - else if (OuterLoopHeaderBI->getSuccessor(i) == InnerLoopPreHeader) - OuterLoopHeaderBI->setSuccessor(i, InnerLoopHeaderSuccessor); - } + updateSuccessor(OuterLoopPredecessorBI, OuterLoopPreHeader, + InnerLoopPreHeader, DTUpdates); + updateSuccessor(OuterLoopHeaderBI, OuterLoopLatch, LoopExit, DTUpdates); + updateSuccessor(OuterLoopHeaderBI, InnerLoopPreHeader, + InnerLoopHeaderSuccessor, DTUpdates); // Adjust reduction PHI's now that the incoming block has changed. updateIncomingBlock(InnerLoopHeaderSuccessor, InnerLoopHeader, OuterLoopHeader); - BranchInst::Create(OuterLoopPreHeader, InnerLoopHeaderBI); - InnerLoopHeaderBI->eraseFromParent(); + updateSuccessor(InnerLoopHeaderBI, InnerLoopHeaderSuccessor, + OuterLoopPreHeader, DTUpdates); // -------------Adjust loop latches----------- if (InnerLoopLatchBI->getSuccessor(0) == InnerLoopHeader) @@ -1341,19 +1414,15 @@ bool LoopInterchangeTransform::adjustLoopBranches() { else InnerLoopLatchSuccessor = InnerLoopLatchBI->getSuccessor(0); - NumSucc = InnerLoopLatchPredecessorBI->getNumSuccessors(); - for (unsigned i = 0; i < NumSucc; ++i) { - if (InnerLoopLatchPredecessorBI->getSuccessor(i) == InnerLoopLatch) - InnerLoopLatchPredecessorBI->setSuccessor(i, InnerLoopLatchSuccessor); - } + updateSuccessor(InnerLoopLatchPredecessorBI, InnerLoopLatch, + InnerLoopLatchSuccessor, DTUpdates); // Adjust PHI nodes in InnerLoopLatchSuccessor. Update all uses of PHI with // the value and remove this PHI node from inner loop. 
SmallVector<PHINode *, 8> LcssaVec; - for (auto I = InnerLoopLatchSuccessor->begin(); isa<PHINode>(I); ++I) { - PHINode *LcssaPhi = cast<PHINode>(I); - LcssaVec.push_back(LcssaPhi); - } + for (PHINode &P : InnerLoopLatchSuccessor->phis()) + LcssaVec.push_back(&P); + for (PHINode *P : LcssaVec) { Value *Incoming = P->getIncomingValueForBlock(InnerLoopLatch); P->replaceAllUsesWith(Incoming); @@ -1365,19 +1434,52 @@ bool LoopInterchangeTransform::adjustLoopBranches() { else OuterLoopLatchSuccessor = OuterLoopLatchBI->getSuccessor(0); - if (InnerLoopLatchBI->getSuccessor(1) == InnerLoopLatchSuccessor) - InnerLoopLatchBI->setSuccessor(1, OuterLoopLatchSuccessor); - else - InnerLoopLatchBI->setSuccessor(0, OuterLoopLatchSuccessor); + updateSuccessor(InnerLoopLatchBI, InnerLoopLatchSuccessor, + OuterLoopLatchSuccessor, DTUpdates); + updateSuccessor(OuterLoopLatchBI, OuterLoopLatchSuccessor, InnerLoopLatch, + DTUpdates); updateIncomingBlock(OuterLoopLatchSuccessor, OuterLoopLatch, InnerLoopLatch); - if (OuterLoopLatchBI->getSuccessor(0) == OuterLoopLatchSuccessor) { - OuterLoopLatchBI->setSuccessor(0, InnerLoopLatch); - } else { - OuterLoopLatchBI->setSuccessor(1, InnerLoopLatch); + DT->applyUpdates(DTUpdates); + restructureLoops(OuterLoop, InnerLoop, InnerLoopPreHeader, + OuterLoopPreHeader); + + // Now update the reduction PHIs in the inner and outer loop headers. + SmallVector<PHINode *, 4> InnerLoopPHIs, OuterLoopPHIs; + for (PHINode &PHI : drop_begin(InnerLoopHeader->phis(), 1)) + InnerLoopPHIs.push_back(cast<PHINode>(&PHI)); + for (PHINode &PHI : drop_begin(OuterLoopHeader->phis(), 1)) + OuterLoopPHIs.push_back(cast<PHINode>(&PHI)); + + for (PHINode *PHI : OuterLoopPHIs) + PHI->moveBefore(InnerLoopHeader->getFirstNonPHI()); + + // Move the PHI nodes from the inner loop header to the outer loop header. + // We have to deal with one kind of PHI nodes: + // 1) PHI nodes that are part of inner loop-only reductions. + // We only have to move the PHI node and update the incoming blocks. + for (PHINode *PHI : InnerLoopPHIs) { + PHI->moveBefore(OuterLoopHeader->getFirstNonPHI()); + for (BasicBlock *InBB : PHI->blocks()) { + if (InnerLoop->contains(InBB)) + continue; + + assert(!isa<PHINode>(PHI->getIncomingValueForBlock(InBB)) && + "Unexpected incoming PHI node, reductions in outer loop are not " + "supported yet"); + PHI->replaceAllUsesWith(PHI->getIncomingValueForBlock(InBB)); + PHI->eraseFromParent(); + break; + } } + // Update the incoming blocks for moved PHI nodes. 
+ updateIncomingBlock(OuterLoopHeader, InnerLoopPreHeader, OuterLoopPreHeader); + updateIncomingBlock(OuterLoopHeader, InnerLoopLatch, OuterLoopLatch); + updateIncomingBlock(InnerLoopHeader, OuterLoopPreHeader, InnerLoopPreHeader); + updateIncomingBlock(InnerLoopHeader, OuterLoopLatch, InnerLoopLatch); + return true; } diff --git a/lib/Transforms/Scalar/LoopLoadElimination.cpp b/lib/Transforms/Scalar/LoopLoadElimination.cpp index dfa5ec1f354d..19bd9ebcc15b 100644 --- a/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -25,7 +25,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -52,6 +52,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/LoopVersioning.h" #include <algorithm> #include <cassert> @@ -79,7 +80,7 @@ STATISTIC(NumLoopLoadEliminted, "Number of loads eliminated by LLE"); namespace { -/// \brief Represent a store-to-forwarding candidate. +/// Represent a store-to-forwarding candidate. struct StoreToLoadForwardingCandidate { LoadInst *Load; StoreInst *Store; @@ -87,7 +88,7 @@ struct StoreToLoadForwardingCandidate { StoreToLoadForwardingCandidate(LoadInst *Load, StoreInst *Store) : Load(Load), Store(Store) {} - /// \brief Return true if the dependence from the store to the load has a + /// Return true if the dependence from the store to the load has a /// distance of one. E.g. A[i+1] = A[i] bool isDependenceDistanceOfOne(PredicatedScalarEvolution &PSE, Loop *L) const { @@ -136,7 +137,7 @@ struct StoreToLoadForwardingCandidate { } // end anonymous namespace -/// \brief Check if the store dominates all latches, so as long as there is no +/// Check if the store dominates all latches, so as long as there is no /// intervening store this value will be loaded in the next iteration. static bool doesStoreDominatesAllLatches(BasicBlock *StoreBlock, Loop *L, DominatorTree *DT) { @@ -147,21 +148,21 @@ static bool doesStoreDominatesAllLatches(BasicBlock *StoreBlock, Loop *L, }); } -/// \brief Return true if the load is not executed on all paths in the loop. +/// Return true if the load is not executed on all paths in the loop. static bool isLoadConditional(LoadInst *Load, Loop *L) { return Load->getParent() != L->getHeader(); } namespace { -/// \brief The per-loop class that does most of the work. +/// The per-loop class that does most of the work. class LoadEliminationForLoop { public: LoadEliminationForLoop(Loop *L, LoopInfo *LI, const LoopAccessInfo &LAI, DominatorTree *DT) : L(L), LI(LI), LAI(LAI), DT(DT), PSE(LAI.getPSE()) {} - /// \brief Look through the loop-carried and loop-independent dependences in + /// Look through the loop-carried and loop-independent dependences in /// this loop and find store->load dependences. /// /// Note that no candidate is returned if LAA has failed to analyze the loop @@ -178,7 +179,7 @@ public: // forward and backward dependences qualify. Disqualify loads that have // other unknown dependences. 
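The "dependence distance of one" candidates this file collects correspond to source loops of the following shape (schematic, modeled on the A[i+1] = A[i] example in the comments above):

    // The load A[i] in iteration i reads exactly the value the store
    // A[i+1] wrote in iteration i-1, so LLE can carry the stored value
    // across the backedge in a PHI node instead of reloading it.
    for (int i = 0; i < n; ++i)
      A[i + 1] = A[i] + 1;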
- SmallSet<Instruction *, 4> LoadsWithUnknownDepedence; + SmallPtrSet<Instruction *, 4> LoadsWithUnknownDepedence; for (const auto &Dep : *Deps) { Instruction *Source = Dep.getSource(LAI); @@ -222,14 +223,14 @@ public: return Candidates; } - /// \brief Return the index of the instruction according to program order. + /// Return the index of the instruction according to program order. unsigned getInstrIndex(Instruction *Inst) { auto I = InstOrder.find(Inst); assert(I != InstOrder.end() && "No index for instruction"); return I->second; } - /// \brief If a load has multiple candidates associated (i.e. different + /// If a load has multiple candidates associated (i.e. different /// stores), it means that it could be forwarding from multiple stores /// depending on control flow. Remove these candidates. /// @@ -284,22 +285,24 @@ public: Candidates.remove_if([&](const StoreToLoadForwardingCandidate &Cand) { if (LoadToSingleCand[Cand.Load] != &Cand) { - DEBUG(dbgs() << "Removing from candidates: \n" << Cand - << " The load may have multiple stores forwarding to " - << "it\n"); + LLVM_DEBUG( + dbgs() << "Removing from candidates: \n" + << Cand + << " The load may have multiple stores forwarding to " + << "it\n"); return true; } return false; }); } - /// \brief Given two pointers operations by their RuntimePointerChecking + /// Given two pointer operations, identified by their RuntimePointerChecking /// indices, return true if they require an alias check. /// /// We need a check if one is a pointer for a candidate load and the other is /// a pointer for a possibly intervening store. bool needsChecking(unsigned PtrIdx1, unsigned PtrIdx2, - const SmallSet<Value *, 4> &PtrsWrittenOnFwdingPath, + const SmallPtrSet<Value *, 4> &PtrsWrittenOnFwdingPath, const std::set<Value *> &CandLoadPtrs) { Value *Ptr1 = LAI.getRuntimePointerChecking()->getPointerInfo(PtrIdx1).PointerValue; @@ -309,11 +312,11 @@ public: (PtrsWrittenOnFwdingPath.count(Ptr2) && CandLoadPtrs.count(Ptr1))); } - /// \brief Return pointers that are possibly written to on the path from a + /// Return pointers that are possibly written to on the path from a /// forwarding store to a load. /// /// These pointers need to be alias-checked against the forwarding candidates. - SmallSet<Value *, 4> findPointersWrittenOnForwardingPath( + SmallPtrSet<Value *, 4> findPointersWrittenOnForwardingPath( const SmallVectorImpl<StoreToLoadForwardingCandidate> &Candidates) { // From FirstStore to LastLoad neither of the elimination candidate loads // should overlap with any of the stores. @@ -351,7 +354,7 @@ public: // We're looking for stores after the first forwarding store until the end // of the loop, then from the beginning of the loop until the last // forwarded-to load. Collect the pointer for the stores. - SmallSet<Value *, 4> PtrsWrittenOnFwdingPath; + SmallPtrSet<Value *, 4> PtrsWrittenOnFwdingPath; auto InsertStorePtr = [&](Instruction *I) { if (auto *S = dyn_cast<StoreInst>(I)) @@ -366,16 +369,16 @@ public: return PtrsWrittenOnFwdingPath; } - /// \brief Determine the pointer alias checks to prove that there are no + /// Determine the pointer alias checks to prove that there are no /// intervening stores. SmallVector<RuntimePointerChecking::PointerCheck, 4> collectMemchecks( const SmallVectorImpl<StoreToLoadForwardingCandidate> &Candidates) { - SmallSet<Value *, 4> PtrsWrittenOnFwdingPath = + SmallPtrSet<Value *, 4> PtrsWrittenOnFwdingPath = findPointersWrittenOnForwardingPath(Candidates); // Collect the pointers of the candidate loads.
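A note on the SmallSet -> SmallPtrSet rename running through this file: for pointer keys, SmallSet<T *, N> is a thin partial specialization that simply inherits from SmallPtrSet, so the change clarifies the spelling rather than the behavior. For example:

    #include "llvm/ADT/SmallPtrSet.h"
    #include "llvm/IR/Instruction.h"

    // Keeps up to 4 pointers in inline storage before spilling to the heap;
    // membership tests hash the pointer value itself.
    llvm::SmallPtrSet<llvm::Instruction *, 4> Loads;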
- // FIXME: SmallSet does not work with std::inserter. + // FIXME: SmallPtrSet does not work with std::inserter. std::set<Value *> CandLoadPtrs; transform(Candidates, std::inserter(CandLoadPtrs, CandLoadPtrs.begin()),
@@ -394,13 +397,14 @@ public: return false; }); - DEBUG(dbgs() << "\nPointer Checks (count: " << Checks.size() << "):\n"); - DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks)); + LLVM_DEBUG(dbgs() << "\nPointer Checks (count: " << Checks.size() + << "):\n"); + LLVM_DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks)); return Checks; }
- /// \brief Perform the transformation for a candidate. + /// Perform the transformation for a candidate. void propagateStoredValueToLoadUsers(const StoreToLoadForwardingCandidate &Cand, SCEVExpander &SEE) {
@@ -436,11 +440,11 @@ public: Cand.Load->replaceAllUsesWith(PHI); } - /// \brief Top-level driver for each loop: find store->load forwarding + /// Top-level driver for each loop: find store->load forwarding /// candidates, add run-time checks and perform transformation. bool processLoop() { - DEBUG(dbgs() << "\nIn \"" << L->getHeader()->getParent()->getName() - << "\" checking " << *L << "\n"); + LLVM_DEBUG(dbgs() << "\nIn \"" << L->getHeader()->getParent()->getName() + << "\" checking " << *L << "\n"); // Look for store-to-load forwarding cases across the // backedge. E.g.:
@@ -479,7 +483,7 @@ public: SmallVector<StoreToLoadForwardingCandidate, 4> Candidates; unsigned NumForwarding = 0; for (const StoreToLoadForwardingCandidate Cand : StoreToLoadDependences) { - DEBUG(dbgs() << "Candidate " << Cand); + LLVM_DEBUG(dbgs() << "Candidate " << Cand); // Make sure that the stored value is available everywhere in the loop in // the next iteration.
@@ -498,9 +502,10 @@ continue; ++NumForwarding; - DEBUG(dbgs() - << NumForwarding - << ". Valid store-to-load forwarding across the loop backedge\n"); + LLVM_DEBUG( + dbgs() + << NumForwarding + << ". Valid store-to-load forwarding across the loop backedge\n"); Candidates.push_back(Cand); } if (Candidates.empty())
@@ -513,25 +518,26 @@ // Too many checks are likely to outweigh the benefits of forwarding. if (Checks.size() > Candidates.size() * CheckPerElim) { - DEBUG(dbgs() << "Too many run-time checks needed.\n"); + LLVM_DEBUG(dbgs() << "Too many run-time checks needed.\n"); return false; } if (LAI.getPSE().getUnionPredicate().getComplexity() > LoadElimSCEVCheckThreshold) { - DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n"); + LLVM_DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n"); return false; } if (!Checks.empty() || !LAI.getPSE().getUnionPredicate().isAlwaysTrue()) { if (L->getHeader()->getParent()->optForSize()) { - DEBUG(dbgs() << "Versioning is needed but not allowed when optimizing " - "for size.\n"); + LLVM_DEBUG( + dbgs() << "Versioning is needed but not allowed when optimizing " + "for size.\n"); return false; } if (!L->isLoopSimplifyForm()) { - DEBUG(dbgs() << "Loop is not is loop-simplify form"); + LLVM_DEBUG(dbgs() << "Loop is not in loop-simplify form"); return false; }
@@ -558,7 +564,7 @@ public: private: Loop *L; - /// \brief Maps the load/store instructions to their index according to + /// Maps the load/store instructions to their index according to /// program order. DenseMap<Instruction *, unsigned> InstOrder;
@@ -599,7 +605,7 @@ eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, namespace { -/// \brief The pass.
Most of the work is delegated to the per-loop /// LoadEliminationForLoop class. class LoopLoadElimination : public FunctionPass { public: diff --git a/lib/Transforms/Scalar/LoopPredication.cpp b/lib/Transforms/Scalar/LoopPredication.cpp index 2e4c7b19e476..561ceea1d880 100644 --- a/lib/Transforms/Scalar/LoopPredication.cpp +++ b/lib/Transforms/Scalar/LoopPredication.cpp @@ -155,7 +155,7 @@ // When S = -1 (i.e. reverse iterating loop), the transformation is supported // when: // * The loop has a single latch with the condition of the form: -// B(X) = X <pred> latchLimit, where <pred> is u> or s>. +// B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=. // * The guard condition is of the form // G(X) = X - 1 u< guardLimit // @@ -171,9 +171,14 @@ // guardStart u< guardLimit && latchLimit u>= 1. // Similarly for sgt condition the widened condition is: // guardStart u< guardLimit && latchLimit s>= 1. +// For uge condition the widened condition is: +// guardStart u< guardLimit && latchLimit u> 1. +// For sge condition the widened condition is: +// guardStart u< guardLimit && latchLimit s> 1. //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopPredication.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -198,6 +203,20 @@ static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation", static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop", cl::Hidden, cl::init(true)); + +static cl::opt<bool> + SkipProfitabilityChecks("loop-predication-skip-profitability-checks", + cl::Hidden, cl::init(false)); + +// This is the scale factor for the latch probability. We use this during +// profitability analysis to find other exiting blocks that have a much higher +// probability of exiting the loop instead of loop exiting via latch. +// This value should be greater than 1 for a sane profitability check. +static cl::opt<float> LatchExitProbabilityScale( + "loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0), + cl::desc("scale factor for the latch probability. Value should be greater " + "than 1. Lower values are ignored")); + namespace { class LoopPredication { /// Represents an induction variable check: @@ -217,6 +236,7 @@ class LoopPredication { }; ScalarEvolution *SE; + BranchProbabilityInfo *BPI; Loop *L; const DataLayout *DL; @@ -250,6 +270,12 @@ class LoopPredication { IRBuilder<> &Builder); bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); + // If the loop always exits through another block in the loop, we should not + // predicate based on the latch check. For example, the latch check can be a + // very coarse grained check and there can be more fine grained exit checks + // within the loop. We identify such unprofitable loops through BPI. + bool isLoopProfitableToPredicate(); + // When the IV type is wider than the range operand type, we can still do loop // predication, by generating SCEVs for the range and latch that are of the // same type. We achieve this by generating a SCEV truncate expression for the @@ -266,8 +292,10 @@ class LoopPredication { // Return the loopLatchCheck corresponding to the RangeCheckType if safe to do // so. 
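// Illustrative sketch, not part of this commit (allGuardsPass is a
// hypothetical brute-force check, not LLVM API): the reverse-iterating case
// that the new uge/sge support covers. Assume a latch "i s>= 2" (so
// latchLimit == 2) and a guard "i - 1 u< Len"; guardStart is the first
// guarded value, i.e. Start - 1.
static bool allGuardsPass(unsigned Start, unsigned Len) {
  for (int I = (int)Start;; --I) {
    if (!((unsigned)(I - 1) < Len)) // the per-iteration guard G(X)
      return false;
    if (!(I - 1 >= 2)) // sge latch on the decremented IV
      break;
  }
  return true;
}
// Per the sge rule above, the loop-invariant widened form is
//   guardStart u< guardLimit && latchLimit s> 1,
// here (Start - 1) u< Len && 2 s> 1. E.g. Start = 10, Len = 100: the widened
// check holds, and allGuardsPass(10, 100) is indeed true; Start = 0 makes
// guardStart wrap to UINT_MAX, the widened check fails, and the guard would
// in fact trip on the first iteration.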
Optional<LoopICmp> generateLoopLatchCheck(Type *RangeCheckType); + public: - LoopPredication(ScalarEvolution *SE) : SE(SE){}; + LoopPredication(ScalarEvolution *SE, BranchProbabilityInfo *BPI) + : SE(SE), BPI(BPI){}; bool runOnLoop(Loop *L); }; @@ -279,6 +307,7 @@ public: } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<BranchProbabilityInfoWrapperPass>(); getLoopAnalysisUsage(AU); } @@ -286,7 +315,9 @@ public: if (skipLoop(L)) return false; auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - LoopPredication LP(SE); + BranchProbabilityInfo &BPI = + getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); + LoopPredication LP(SE, &BPI); return LP.runOnLoop(L); } }; @@ -296,6 +327,7 @@ char LoopPredicationLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication", "Loop predication", false, false) +INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication", "Loop predication", false, false) @@ -307,7 +339,11 @@ Pass *llvm::createLoopPredicationPass() { PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U) { - LoopPredication LP(&AR.SE); + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); + Function *F = L.getHeader()->getParent(); + auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F); + LoopPredication LP(&AR.SE, BPI); if (!LP.runOnLoop(&L)) return PreservedAnalyses::all(); @@ -375,11 +411,11 @@ LoopPredication::generateLoopLatchCheck(Type *RangeCheckType) { if (!NewLatchCheck.IV) return None; NewLatchCheck.Limit = SE->getTruncateExpr(LatchCheck.Limit, RangeCheckType); - DEBUG(dbgs() << "IV of type: " << *LatchType - << "can be represented as range check type:" << *RangeCheckType - << "\n"); - DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n"); - DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n"); + LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType + << "can be represented as range check type:" + << *RangeCheckType << "\n"); + LLVM_DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n"); + LLVM_DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n"); return NewLatchCheck; } @@ -412,30 +448,15 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( SE->getMinusSCEV(LatchStart, SE->getOne(Ty))); if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) || !CanExpand(LatchLimit) || !CanExpand(RHS)) { - DEBUG(dbgs() << "Can't expand limit check!\n"); + LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); return None; } - ICmpInst::Predicate LimitCheckPred; - switch (LatchCheck.Pred) { - case ICmpInst::ICMP_ULT: - LimitCheckPred = ICmpInst::ICMP_ULE; - break; - case ICmpInst::ICMP_ULE: - LimitCheckPred = ICmpInst::ICMP_ULT; - break; - case ICmpInst::ICMP_SLT: - LimitCheckPred = ICmpInst::ICMP_SLE; - break; - case ICmpInst::ICMP_SLE: - LimitCheckPred = ICmpInst::ICMP_SLT; - break; - default: - llvm_unreachable("Unsupported loop latch!"); - } + auto LimitCheckPred = + ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred); - DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n"); - DEBUG(dbgs() << "RHS: " << *RHS << "\n"); - DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n"); + LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n"); + LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n"); + LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n"); 
Instruction *InsertAt = Preheader->getTerminator(); auto *LimitCheck = @@ -454,16 +475,16 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( const SCEV *LatchLimit = LatchCheck.Limit; if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) || !CanExpand(LatchLimit)) { - DEBUG(dbgs() << "Can't expand limit check!\n"); + LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); return None; } // The decrement of the latch check IV should be the same as the // rangeCheckIV. auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE); if (RangeCheck.IV != PostDecLatchCheckIV) { - DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: " - << *PostDecLatchCheckIV - << " and RangeCheckIV: " << *RangeCheck.IV << "\n"); + LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: " + << *PostDecLatchCheckIV + << " and RangeCheckIV: " << *RangeCheck.IV << "\n"); return None; } @@ -472,9 +493,8 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( // latchLimit <pred> 1. // See the header comment for reasoning of the checks. Instruction *InsertAt = Preheader->getTerminator(); - auto LimitCheckPred = ICmpInst::isSigned(LatchCheck.Pred) - ? ICmpInst::ICMP_SGE - : ICmpInst::ICMP_UGE; + auto LimitCheckPred = + ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred); auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT, GuardStart, GuardLimit, InsertAt); auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, @@ -488,8 +508,8 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, IRBuilder<> &Builder) { - DEBUG(dbgs() << "Analyzing ICmpInst condition:\n"); - DEBUG(ICI->dump()); + LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n"); + LLVM_DEBUG(ICI->dump()); // parseLoopStructure guarantees that the latch condition is: // ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=. @@ -497,34 +517,34 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, // i u< guardLimit auto RangeCheck = parseLoopICmp(ICI); if (!RangeCheck) { - DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); + LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); return None; } - DEBUG(dbgs() << "Guard check:\n"); - DEBUG(RangeCheck->dump()); + LLVM_DEBUG(dbgs() << "Guard check:\n"); + LLVM_DEBUG(RangeCheck->dump()); if (RangeCheck->Pred != ICmpInst::ICMP_ULT) { - DEBUG(dbgs() << "Unsupported range check predicate(" << RangeCheck->Pred - << ")!\n"); + LLVM_DEBUG(dbgs() << "Unsupported range check predicate(" + << RangeCheck->Pred << ")!\n"); return None; } auto *RangeCheckIV = RangeCheck->IV; if (!RangeCheckIV->isAffine()) { - DEBUG(dbgs() << "Range check IV is not affine!\n"); + LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n"); return None; } auto *Step = RangeCheckIV->getStepRecurrence(*SE); // We cannot just compare with latch IV step because the latch and range IVs // may have different types. 
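// Illustrative sketch, not part of this commit: the effect of
// ICmpInst::getFlippedStrictnessPredicate on the predicates used here is to
// toggle strict <-> non-strict while preserving sign and direction. A
// standalone equivalent for these eight cases (hypothetical helper):
enum LatchPred { ULT, ULE, SLT, SLE, UGT, UGE, SGT, SGE };
static LatchPred flipStrictness(LatchPred P) {
  switch (P) {
  case ULT: return ULE; case ULE: return ULT;
  case SLT: return SLE; case SLE: return SLT;
  case UGT: return UGE; case UGE: return UGT;
  case SGT: return SGE; case SGE: return SGT;
  }
  return P;
}
// This subsumes both the hand-written switch that the incrementing case used
// to carry and the hardcoded SGE/UGE in the decrementing case: a sgt latch
// still yields an sge limit check, while the newly supported sge latch yields
// the strict sgt check that the file header's uge/sge rules require.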
if (!isSupportedStep(Step)) { - DEBUG(dbgs() << "Range check and latch have IVs different steps!\n"); + LLVM_DEBUG(dbgs() << "Range check and latch have IVs different steps!\n"); return None; } auto *Ty = RangeCheckIV->getType(); auto CurrLatchCheckOpt = generateLoopLatchCheck(Ty); if (!CurrLatchCheckOpt) { - DEBUG(dbgs() << "Failed to generate a loop latch check " - "corresponding to range type: " - << *Ty << "\n"); + LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check " + "corresponding to range type: " + << *Ty << "\n"); return None; } @@ -535,7 +555,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() && "Range and latch steps should be of same type!"); if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) { - DEBUG(dbgs() << "Range and latch have different step values!\n"); + LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n"); return None; } @@ -551,14 +571,14 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, SCEVExpander &Expander) { - DEBUG(dbgs() << "Processing guard:\n"); - DEBUG(Guard->dump()); + LLVM_DEBUG(dbgs() << "Processing guard:\n"); + LLVM_DEBUG(Guard->dump()); IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator())); // The guard condition is expected to be in form of: // cond1 && cond2 && cond3 ... - // Iterate over subconditions looking for for icmp conditions which can be + // Iterate over subconditions looking for icmp conditions which can be // widened across loop iterations. Widening these conditions remember the // resulting list of subconditions in Checks vector. SmallVector<Value *, 4> Worklist(1, Guard->getOperand(0)); @@ -605,7 +625,7 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, LastCheck = Builder.CreateAnd(LastCheck, Check); Guard->setOperand(0, LastCheck); - DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); + LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); return true; } @@ -614,7 +634,7 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { BasicBlock *LoopLatch = L->getLoopLatch(); if (!LoopLatch) { - DEBUG(dbgs() << "The loop doesn't have a single latch!\n"); + LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n"); return None; } @@ -625,7 +645,7 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { if (!match(LoopLatch->getTerminator(), m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), TrueDest, FalseDest))) { - DEBUG(dbgs() << "Failed to match the latch terminator!\n"); + LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n"); return None; } assert((TrueDest == L->getHeader() || FalseDest == L->getHeader()) && @@ -635,20 +655,20 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { auto Result = parseLoopICmp(Pred, LHS, RHS); if (!Result) { - DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); + LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); return None; } // Check affine first, so if it's not we don't try to compute the step // recurrence. 
if (!Result->IV->isAffine()) { - DEBUG(dbgs() << "The induction variable is not affine!\n"); + LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n"); return None; } auto *Step = Result->IV->getStepRecurrence(*SE); if (!isSupportedStep(Step)) { - DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n"); + LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n"); return None; } @@ -658,13 +678,14 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE; } else { assert(Step->isAllOnesValue() && "Step should be -1!"); - return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT; + return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT && + Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE; } }; if (IsUnsupportedPredicate(Step, Result->Pred)) { - DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred - << ")!\n"); + LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred + << ")!\n"); return None; } return Result; @@ -700,11 +721,65 @@ bool LoopPredication::isSafeToTruncateWideIVType(Type *RangeCheckType) { Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize; } +bool LoopPredication::isLoopProfitableToPredicate() { + if (SkipProfitabilityChecks || !BPI) + return true; + + SmallVector<std::pair<const BasicBlock *, const BasicBlock *>, 8> ExitEdges; + L->getExitEdges(ExitEdges); + // If there is only one exiting edge in the loop, it is always profitable to + // predicate the loop. + if (ExitEdges.size() == 1) + return true; + + // Calculate the exiting probabilities of all exiting edges from the loop, + // starting with the LatchExitProbability. + // Heuristic for profitability: If any of the exiting blocks' probability of + // exiting the loop is larger than exiting through the latch block, it's not + // profitable to predicate the loop. + auto *LatchBlock = L->getLoopLatch(); + assert(LatchBlock && "Should have a single latch at this point!"); + auto *LatchTerm = LatchBlock->getTerminator(); + assert(LatchTerm->getNumSuccessors() == 2 && + "expected to be an exiting block with 2 succs!"); + unsigned LatchBrExitIdx = + LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0; + BranchProbability LatchExitProbability = + BPI->getEdgeProbability(LatchBlock, LatchBrExitIdx); + + // Protect against degenerate inputs provided by the user. Providing a value + // less than one, can invert the definition of profitable loop predication. + float ScaleFactor = LatchExitProbabilityScale; + if (ScaleFactor < 1) { + LLVM_DEBUG( + dbgs() + << "Ignored user setting for loop-predication-latch-probability-scale: " + << LatchExitProbabilityScale << "\n"); + LLVM_DEBUG(dbgs() << "The value is set to 1.0\n"); + ScaleFactor = 1.0; + } + const auto LatchProbabilityThreshold = + LatchExitProbability * ScaleFactor; + + for (const auto &ExitEdge : ExitEdges) { + BranchProbability ExitingBlockProbability = + BPI->getEdgeProbability(ExitEdge.first, ExitEdge.second); + // Some exiting edge has higher probability than the latch exiting edge. + // No longer profitable to predicate. + if (ExitingBlockProbability > LatchProbabilityThreshold) + return false; + } + // Using BPI, we have concluded that the most probable way to exit from the + // loop is through the latch (or there's no profile information and all + // exits are equally likely). 
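+  // Illustrative numbers, not part of this commit: with the default
+  // loop-predication-latch-probability-scale of 2.0, a latch exit
+  // probability of 20% gives a threshold of 40%. An early exit taken 50%
+  // of the time exceeds the threshold, so the loop is left unpredicated;
+  // if instead every other exit sits at, say, 5%, the latch dominates and
+  // predication proceeds.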
+ return true; +} + bool LoopPredication::runOnLoop(Loop *Loop) { L = Loop; - DEBUG(dbgs() << "Analyzing "); - DEBUG(L->dump()); + LLVM_DEBUG(dbgs() << "Analyzing "); + LLVM_DEBUG(L->dump()); Module *M = L->getHeader()->getModule(); @@ -725,9 +800,13 @@ bool LoopPredication::runOnLoop(Loop *Loop) { return false; LatchCheck = *LatchCheckOpt; - DEBUG(dbgs() << "Latch check:\n"); - DEBUG(LatchCheck.dump()); + LLVM_DEBUG(dbgs() << "Latch check:\n"); + LLVM_DEBUG(LatchCheck.dump()); + if (!isLoopProfitableToPredicate()) { + LLVM_DEBUG(dbgs() << "Loop not profitable to predicate!\n"); + return false; + } // Collect all the guards into a vector and process later, so as not // to invalidate the instruction iterator. SmallVector<IntrinsicInst *, 4> Guards; diff --git a/lib/Transforms/Scalar/LoopRerollPass.cpp b/lib/Transforms/Scalar/LoopRerollPass.cpp index d1a54b877950..9a99e5925572 100644 --- a/lib/Transforms/Scalar/LoopRerollPass.cpp +++ b/lib/Transforms/Scalar/LoopRerollPass.cpp @@ -17,7 +17,7 @@ #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -28,6 +28,7 @@ #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -51,8 +52,8 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include <cassert> #include <cstddef> @@ -69,10 +70,6 @@ using namespace llvm; STATISTIC(NumRerolledLoops, "Number of rerolled loops"); static cl::opt<unsigned> -MaxInc("max-reroll-increment", cl::init(2048), cl::Hidden, - cl::desc("The maximum increment for loop rerolling")); - -static cl::opt<unsigned> NumToleratedFailedMatches("reroll-num-tolerated-failed-matches", cl::init(400), cl::Hidden, cl::desc("The maximum number of failures to tolerate" @@ -188,7 +185,7 @@ namespace { bool PreserveLCSSA; using SmallInstructionVector = SmallVector<Instruction *, 16>; - using SmallInstructionSet = SmallSet<Instruction *, 16>; + using SmallInstructionSet = SmallPtrSet<Instruction *, 16>; // Map between induction variable and its increment DenseMap<Instruction *, int64_t> IVToIncMap; @@ -397,8 +394,8 @@ namespace { /// Stage 3: Assuming validate() returned true, perform the /// replacement. - /// @param IterCount The maximum iteration count of L. - void replace(const SCEV *IterCount); + /// @param BackedgeTakenCount The backedge-taken count of L. 
+ void replace(const SCEV *BackedgeTakenCount); protected: using UsesTy = MapVector<Instruction *, BitVector>; @@ -428,8 +425,7 @@ namespace { bool instrDependsOn(Instruction *I, UsesTy::iterator Start, UsesTy::iterator End); - void replaceIV(Instruction *Inst, Instruction *IV, const SCEV *IterCount); - void updateNonLoopCtrlIncr(); + void replaceIV(DAGRootSet &DRS, const SCEV *Start, const SCEV *IncrExpr); LoopReroll *Parent; @@ -482,8 +478,8 @@ namespace { void collectPossibleIVs(Loop *L, SmallInstructionVector &PossibleIVs); void collectPossibleReductions(Loop *L, ReductionTracker &Reductions); - bool reroll(Instruction *IV, Loop *L, BasicBlock *Header, const SCEV *IterCount, - ReductionTracker &Reductions); + bool reroll(Instruction *IV, Loop *L, BasicBlock *Header, + const SCEV *BackedgeTakenCount, ReductionTracker &Reductions); }; } // end anonymous namespace @@ -510,48 +506,6 @@ static bool hasUsesOutsideLoop(Instruction *I, Loop *L) { return false; } -static const SCEVConstant *getIncrmentFactorSCEV(ScalarEvolution *SE, - const SCEV *SCEVExpr, - Instruction &IV) { - const SCEVMulExpr *MulSCEV = dyn_cast<SCEVMulExpr>(SCEVExpr); - - // If StepRecurrence of a SCEVExpr is a constant (c1 * c2, c2 = sizeof(ptr)), - // Return c1. - if (!MulSCEV && IV.getType()->isPointerTy()) - if (const SCEVConstant *IncSCEV = dyn_cast<SCEVConstant>(SCEVExpr)) { - const PointerType *PTy = cast<PointerType>(IV.getType()); - Type *ElTy = PTy->getElementType(); - const SCEV *SizeOfExpr = - SE->getSizeOfExpr(SE->getEffectiveSCEVType(IV.getType()), ElTy); - if (IncSCEV->getValue()->getValue().isNegative()) { - const SCEV *NewSCEV = - SE->getUDivExpr(SE->getNegativeSCEV(SCEVExpr), SizeOfExpr); - return dyn_cast<SCEVConstant>(SE->getNegativeSCEV(NewSCEV)); - } else { - return dyn_cast<SCEVConstant>(SE->getUDivExpr(SCEVExpr, SizeOfExpr)); - } - } - - if (!MulSCEV) - return nullptr; - - // If StepRecurrence of a SCEVExpr is a c * sizeof(x), where c is constant, - // Return c. - const SCEVConstant *CIncSCEV = nullptr; - for (const SCEV *Operand : MulSCEV->operands()) { - if (const SCEVConstant *Constant = dyn_cast<SCEVConstant>(Operand)) { - CIncSCEV = Constant; - } else if (const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(Operand)) { - Type *AllocTy; - if (!Unknown->isSizeOf(AllocTy)) - break; - } else { - return nullptr; - } - } - return CIncSCEV; -} - // Check if an IV is only used to control the loop. There are two cases: // 1. 
It only has one use which is loop increment, and the increment is only // used by comparison and the PHI (could has sext with nsw in between), and the @@ -632,25 +586,17 @@ void LoopReroll::collectPossibleIVs(Loop *L, continue; if (!PHISCEV->isAffine()) continue; - const SCEVConstant *IncSCEV = nullptr; - if (I->getType()->isPointerTy()) - IncSCEV = - getIncrmentFactorSCEV(SE, PHISCEV->getStepRecurrence(*SE), *I); - else - IncSCEV = dyn_cast<SCEVConstant>(PHISCEV->getStepRecurrence(*SE)); + auto IncSCEV = dyn_cast<SCEVConstant>(PHISCEV->getStepRecurrence(*SE)); if (IncSCEV) { - const APInt &AInt = IncSCEV->getValue()->getValue().abs(); - if (IncSCEV->getValue()->isZero() || AInt.uge(MaxInc)) - continue; IVToIncMap[&*I] = IncSCEV->getValue()->getSExtValue(); - DEBUG(dbgs() << "LRR: Possible IV: " << *I << " = " << *PHISCEV - << "\n"); + LLVM_DEBUG(dbgs() << "LRR: Possible IV: " << *I << " = " << *PHISCEV + << "\n"); if (isLoopControlIV(L, &*I)) { assert(!LoopControlIV && "Found two loop control only IV"); LoopControlIV = &(*I); - DEBUG(dbgs() << "LRR: Possible loop control only IV: " << *I << " = " - << *PHISCEV << "\n"); + LLVM_DEBUG(dbgs() << "LRR: Possible loop control only IV: " << *I + << " = " << *PHISCEV << "\n"); } else PossibleIVs.push_back(&*I); } @@ -717,8 +663,8 @@ void LoopReroll::collectPossibleReductions(Loop *L, if (!SLR.valid()) continue; - DEBUG(dbgs() << "LRR: Possible reduction: " << *I << " (with " << - SLR.size() << " chained instructions)\n"); + LLVM_DEBUG(dbgs() << "LRR: Possible reduction: " << *I << " (with " + << SLR.size() << " chained instructions)\n"); Reductions.addSLR(SLR); } } @@ -856,7 +802,8 @@ collectPossibleRoots(Instruction *Base, std::map<int64_t,Instruction*> &Roots) { BaseUsers.push_back(II); continue; } else { - DEBUG(dbgs() << "LRR: Aborting due to non-instruction: " << *I << "\n"); + LLVM_DEBUG(dbgs() << "LRR: Aborting due to non-instruction: " << *I + << "\n"); return false; } } @@ -878,7 +825,7 @@ collectPossibleRoots(Instruction *Base, std::map<int64_t,Instruction*> &Roots) { // away. if (BaseUsers.size()) { if (Roots.find(0) != Roots.end()) { - DEBUG(dbgs() << "LRR: Multiple roots found for base - aborting!\n"); + LLVM_DEBUG(dbgs() << "LRR: Multiple roots found for base - aborting!\n"); return false; } Roots[0] = Base; @@ -894,9 +841,9 @@ collectPossibleRoots(Instruction *Base, std::map<int64_t,Instruction*> &Roots) { if (KV.first == 0) continue; if (!KV.second->hasNUses(NumBaseUses)) { - DEBUG(dbgs() << "LRR: Aborting - Root and Base #users not the same: " - << "#Base=" << NumBaseUses << ", #Root=" << - KV.second->getNumUses() << "\n"); + LLVM_DEBUG(dbgs() << "LRR: Aborting - Root and Base #users not the same: " + << "#Base=" << NumBaseUses + << ", #Root=" << KV.second->getNumUses() << "\n"); return false; } } @@ -1024,13 +971,14 @@ bool LoopReroll::DAGRootTracker::findRoots() { // Ensure all sets have the same size. 
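// Illustrative sketch, not part of this commit (sumUnrolled is hypothetical):
// the shape collectPossibleRoots() is matching. In a loop unrolled by three,
// with the IV stepping by 3:
static int sumUnrolled(const int *A, int N /* multiple of 3 */) {
  int S = 0;
  for (int I = 0; I < N; I += 3) {
    S += A[I];     // BaseInst, offset 0
    S += A[I + 1]; // root at offset 1
    S += A[I + 2]; // root at offset 2
  }
  return S;
}
// Here each root set has Roots.size() == 2, so Scale = Roots.size() + 1 == 3;
// findRoots() aborts if the sets differ in size or if Scale exceeds
// IL_MaxRerollIterations.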
if (RootSets.empty()) { - DEBUG(dbgs() << "LRR: Aborting because no root sets found!\n"); + LLVM_DEBUG(dbgs() << "LRR: Aborting because no root sets found!\n"); return false; } for (auto &V : RootSets) { if (V.Roots.empty() || V.Roots.size() != RootSets[0].Roots.size()) { - DEBUG(dbgs() - << "LRR: Aborting because not all root sets have the same size\n"); + LLVM_DEBUG( + dbgs() + << "LRR: Aborting because not all root sets have the same size\n"); return false; } } @@ -1038,13 +986,14 @@ bool LoopReroll::DAGRootTracker::findRoots() { Scale = RootSets[0].Roots.size() + 1; if (Scale > IL_MaxRerollIterations) { - DEBUG(dbgs() << "LRR: Aborting - too many iterations found. " - << "#Found=" << Scale << ", #Max=" << IL_MaxRerollIterations - << "\n"); + LLVM_DEBUG(dbgs() << "LRR: Aborting - too many iterations found. " + << "#Found=" << Scale + << ", #Max=" << IL_MaxRerollIterations << "\n"); return false; } - DEBUG(dbgs() << "LRR: Successfully found roots: Scale=" << Scale << "\n"); + LLVM_DEBUG(dbgs() << "LRR: Successfully found roots: Scale=" << Scale + << "\n"); return true; } @@ -1078,7 +1027,7 @@ bool LoopReroll::DAGRootTracker::collectUsedInstructions(SmallInstructionSet &Po // While we're here, check the use sets are the same size. if (V.size() != VBase.size()) { - DEBUG(dbgs() << "LRR: Aborting - use sets are different sizes\n"); + LLVM_DEBUG(dbgs() << "LRR: Aborting - use sets are different sizes\n"); return false; } @@ -1235,17 +1184,17 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { // set. for (auto &KV : Uses) { if (KV.second.count() != 1 && !isIgnorableInst(KV.first)) { - DEBUG(dbgs() << "LRR: Aborting - instruction is not used in 1 iteration: " - << *KV.first << " (#uses=" << KV.second.count() << ")\n"); + LLVM_DEBUG( + dbgs() << "LRR: Aborting - instruction is not used in 1 iteration: " + << *KV.first << " (#uses=" << KV.second.count() << ")\n"); return false; } } - DEBUG( - for (auto &KV : Uses) { - dbgs() << "LRR: " << KV.second.find_first() << "\t" << *KV.first << "\n"; - } - ); + LLVM_DEBUG(for (auto &KV + : Uses) { + dbgs() << "LRR: " << KV.second.find_first() << "\t" << *KV.first << "\n"; + }); for (unsigned Iter = 1; Iter < Scale; ++Iter) { // In addition to regular aliasing information, we need to look for @@ -1304,8 +1253,8 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { if (TryIt == Uses.end() || TryIt == RootIt || instrDependsOn(TryIt->first, RootIt, TryIt)) { - DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst << - " vs. " << *RootInst << "\n"); + LLVM_DEBUG(dbgs() << "LRR: iteration root match failed at " + << *BaseInst << " vs. " << *RootInst << "\n"); return false; } @@ -1341,8 +1290,8 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { // root instruction, does not also belong to the base set or the set of // some other root instruction. if (RootIt->second.count() > 1) { - DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst << - " vs. " << *RootInst << " (prev. case overlap)\n"); + LLVM_DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst + << " vs. " << *RootInst << " (prev. case overlap)\n"); return false; } @@ -1352,8 +1301,9 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { if (RootInst->mayReadFromMemory()) for (auto &K : AST) { if (K.aliasesUnknownInst(RootInst, *AA)) { - DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst << - " vs. 
" << *RootInst << " (depends on future store)\n"); + LLVM_DEBUG(dbgs() << "LRR: iteration root match failed at " + << *BaseInst << " vs. " << *RootInst + << " (depends on future store)\n"); return false; } } @@ -1366,9 +1316,9 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { !isSafeToSpeculativelyExecute(BaseInst)) || (!isUnorderedLoadStore(RootInst) && !isSafeToSpeculativelyExecute(RootInst)))) { - DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst << - " vs. " << *RootInst << - " (side effects prevent reordering)\n"); + LLVM_DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst + << " vs. " << *RootInst + << " (side effects prevent reordering)\n"); return false; } @@ -1419,8 +1369,9 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { BaseInst->getOperand(!j) == Op2) { Swapped = true; } else { - DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst - << " vs. " << *RootInst << " (operand " << j << ")\n"); + LLVM_DEBUG(dbgs() + << "LRR: iteration root match failed at " << *BaseInst + << " vs. " << *RootInst << " (operand " << j << ")\n"); return false; } } @@ -1433,8 +1384,8 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { hasUsesOutsideLoop(BaseInst, L)) || (!PossibleRedLastSet.count(RootInst) && hasUsesOutsideLoop(RootInst, L))) { - DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst << - " vs. " << *RootInst << " (uses outside loop)\n"); + LLVM_DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst + << " vs. " << *RootInst << " (uses outside loop)\n"); return false; } @@ -1451,20 +1402,32 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { "Mismatched set sizes!"); } - DEBUG(dbgs() << "LRR: Matched all iteration increments for " << - *IV << "\n"); + LLVM_DEBUG(dbgs() << "LRR: Matched all iteration increments for " << *IV + << "\n"); return true; } -void LoopReroll::DAGRootTracker::replace(const SCEV *IterCount) { +void LoopReroll::DAGRootTracker::replace(const SCEV *BackedgeTakenCount) { BasicBlock *Header = L->getHeader(); + + // Compute the start and increment for each BaseInst before we start erasing + // instructions. + SmallVector<const SCEV *, 8> StartExprs; + SmallVector<const SCEV *, 8> IncrExprs; + for (auto &DRS : RootSets) { + const SCEVAddRecExpr *IVSCEV = + cast<SCEVAddRecExpr>(SE->getSCEV(DRS.BaseInst)); + StartExprs.push_back(IVSCEV->getStart()); + IncrExprs.push_back(SE->getMinusSCEV(SE->getSCEV(DRS.Roots[0]), IVSCEV)); + } + // Remove instructions associated with non-base iterations. for (BasicBlock::reverse_iterator J = Header->rbegin(), JE = Header->rend(); J != JE;) { unsigned I = Uses[&*J].find_first(); if (I > 0 && I < IL_All) { - DEBUG(dbgs() << "LRR: removing: " << *J << "\n"); + LLVM_DEBUG(dbgs() << "LRR: removing: " << *J << "\n"); J++->eraseFromParent(); continue; } @@ -1472,74 +1435,47 @@ void LoopReroll::DAGRootTracker::replace(const SCEV *IterCount) { ++J; } - bool HasTwoIVs = LoopControlIV && LoopControlIV != IV; + // Rewrite each BaseInst using SCEV. + for (size_t i = 0, e = RootSets.size(); i != e; ++i) + // Insert the new induction variable. + replaceIV(RootSets[i], StartExprs[i], IncrExprs[i]); - if (HasTwoIVs) { - updateNonLoopCtrlIncr(); - replaceIV(LoopControlIV, LoopControlIV, IterCount); - } else - // We need to create a new induction variable for each different BaseInst. - for (auto &DRS : RootSets) - // Insert the new induction variable. 
- replaceIV(DRS.BaseInst, IV, IterCount); + { // Limit the lifetime of SCEVExpander. + BranchInst *BI = cast<BranchInst>(Header->getTerminator()); + const DataLayout &DL = Header->getModule()->getDataLayout(); + SCEVExpander Expander(*SE, DL, "reroll"); + auto Zero = SE->getZero(BackedgeTakenCount->getType()); + auto One = SE->getOne(BackedgeTakenCount->getType()); + auto NewIVSCEV = SE->getAddRecExpr(Zero, One, L, SCEV::FlagAnyWrap); + Value *NewIV = + Expander.expandCodeFor(NewIVSCEV, BackedgeTakenCount->getType(), + Header->getFirstNonPHIOrDbg()); + // FIXME: This arithmetic can overflow. + auto TripCount = SE->getAddExpr(BackedgeTakenCount, One); + auto ScaledTripCount = SE->getMulExpr( + TripCount, SE->getConstant(BackedgeTakenCount->getType(), Scale)); + auto ScaledBECount = SE->getMinusSCEV(ScaledTripCount, One); + Value *TakenCount = + Expander.expandCodeFor(ScaledBECount, BackedgeTakenCount->getType(), + Header->getFirstNonPHIOrDbg()); + Value *Cond = + new ICmpInst(BI, CmpInst::ICMP_EQ, NewIV, TakenCount, "exitcond"); + BI->setCondition(Cond); + + if (BI->getSuccessor(1) != Header) + BI->swapSuccessors(); + } SimplifyInstructionsInBlock(Header, TLI); DeleteDeadPHIs(Header, TLI); } -// For non-loop-control IVs, we only need to update the last increment -// with right amount, then we are done. -void LoopReroll::DAGRootTracker::updateNonLoopCtrlIncr() { - const SCEV *NewInc = nullptr; - for (auto *LoopInc : LoopIncs) { - GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LoopInc); - const SCEVConstant *COp = nullptr; - if (GEP && LoopInc->getOperand(0)->getType()->isPointerTy()) { - COp = dyn_cast<SCEVConstant>(SE->getSCEV(LoopInc->getOperand(1))); - } else { - COp = dyn_cast<SCEVConstant>(SE->getSCEV(LoopInc->getOperand(0))); - if (!COp) - COp = dyn_cast<SCEVConstant>(SE->getSCEV(LoopInc->getOperand(1))); - } - - assert(COp && "Didn't find constant operand of LoopInc!\n"); - - const APInt &AInt = COp->getValue()->getValue(); - const SCEV *ScaleSCEV = SE->getConstant(COp->getType(), Scale); - if (AInt.isNegative()) { - NewInc = SE->getNegativeSCEV(COp); - NewInc = SE->getUDivExpr(NewInc, ScaleSCEV); - NewInc = SE->getNegativeSCEV(NewInc); - } else - NewInc = SE->getUDivExpr(COp, ScaleSCEV); - - LoopInc->setOperand(1, dyn_cast<SCEVConstant>(NewInc)->getValue()); - } -} - -void LoopReroll::DAGRootTracker::replaceIV(Instruction *Inst, - Instruction *InstIV, - const SCEV *IterCount) { +void LoopReroll::DAGRootTracker::replaceIV(DAGRootSet &DRS, + const SCEV *Start, + const SCEV *IncrExpr) { BasicBlock *Header = L->getHeader(); - int64_t Inc = IVToIncMap[InstIV]; - bool NeedNewIV = InstIV == LoopControlIV; - bool Negative = !NeedNewIV && Inc < 0; - - const SCEVAddRecExpr *RealIVSCEV = cast<SCEVAddRecExpr>(SE->getSCEV(Inst)); - const SCEV *Start = RealIVSCEV->getStart(); - - if (NeedNewIV) - Start = SE->getConstant(Start->getType(), 0); - - const SCEV *SizeOfExpr = nullptr; - const SCEV *IncrExpr = - SE->getConstant(RealIVSCEV->getType(), Negative ? 
-1 : 1); - if (auto *PTy = dyn_cast<PointerType>(Inst->getType())) { - Type *ElTy = PTy->getElementType(); - SizeOfExpr = - SE->getSizeOfExpr(SE->getEffectiveSCEVType(Inst->getType()), ElTy); - IncrExpr = SE->getMulExpr(IncrExpr, SizeOfExpr); - } + Instruction *Inst = DRS.BaseInst; + const SCEV *NewIVSCEV = SE->getAddRecExpr(Start, IncrExpr, L, SCEV::FlagAnyWrap); @@ -1552,54 +1488,6 @@ void LoopReroll::DAGRootTracker::replaceIV(Instruction *Inst, for (auto &KV : Uses) if (KV.second.find_first() == 0) KV.first->replaceUsesOfWith(Inst, NewIV); - - if (BranchInst *BI = dyn_cast<BranchInst>(Header->getTerminator())) { - // FIXME: Why do we need this check? - if (Uses[BI].find_first() == IL_All) { - const SCEV *ICSCEV = RealIVSCEV->evaluateAtIteration(IterCount, *SE); - - if (NeedNewIV) - ICSCEV = SE->getMulExpr(IterCount, - SE->getConstant(IterCount->getType(), Scale)); - - // Iteration count SCEV minus or plus 1 - const SCEV *MinusPlus1SCEV = - SE->getConstant(ICSCEV->getType(), Negative ? -1 : 1); - if (Inst->getType()->isPointerTy()) { - assert(SizeOfExpr && "SizeOfExpr is not initialized"); - MinusPlus1SCEV = SE->getMulExpr(MinusPlus1SCEV, SizeOfExpr); - } - - const SCEV *ICMinusPlus1SCEV = SE->getMinusSCEV(ICSCEV, MinusPlus1SCEV); - // Iteration count minus 1 - Instruction *InsertPtr = nullptr; - if (isa<SCEVConstant>(ICMinusPlus1SCEV)) { - InsertPtr = BI; - } else { - BasicBlock *Preheader = L->getLoopPreheader(); - if (!Preheader) - Preheader = InsertPreheaderForLoop(L, DT, LI, PreserveLCSSA); - InsertPtr = Preheader->getTerminator(); - } - - if (!isa<PointerType>(NewIV->getType()) && NeedNewIV && - (SE->getTypeSizeInBits(NewIV->getType()) < - SE->getTypeSizeInBits(ICMinusPlus1SCEV->getType()))) { - IRBuilder<> Builder(BI); - Builder.SetCurrentDebugLocation(BI->getDebugLoc()); - NewIV = Builder.CreateSExt(NewIV, ICMinusPlus1SCEV->getType()); - } - Value *ICMinusPlus1 = Expander.expandCodeFor( - ICMinusPlus1SCEV, NewIV->getType(), InsertPtr); - - Value *Cond = - new ICmpInst(BI, CmpInst::ICMP_EQ, NewIV, ICMinusPlus1, "exitcond"); - BI->setCondition(Cond); - - if (BI->getSuccessor(1) != Header) - BI->swapSuccessors(); - } - } } } @@ -1617,17 +1505,17 @@ bool LoopReroll::ReductionTracker::validateSelected() { int Iter = PossibleRedIter[J]; if (Iter != PrevIter && Iter != PrevIter + 1 && !PossibleReds[i].getReducedValue()->isAssociative()) { - DEBUG(dbgs() << "LRR: Out-of-order non-associative reduction: " << - J << "\n"); + LLVM_DEBUG(dbgs() << "LRR: Out-of-order non-associative reduction: " + << J << "\n"); return false; } if (Iter != PrevIter) { if (Count != BaseCount) { - DEBUG(dbgs() << "LRR: Iteration " << PrevIter << - " reduction use count " << Count << - " is not equal to the base use count " << - BaseCount << "\n"); + LLVM_DEBUG(dbgs() + << "LRR: Iteration " << PrevIter << " reduction use count " + << Count << " is not equal to the base use count " + << BaseCount << "\n"); return false; } @@ -1716,15 +1604,15 @@ void LoopReroll::ReductionTracker::replaceSelected() { // f(%iv) or part of some f(%iv.i). If all of that is true (and all reductions // have been validated), then we reroll the loop. 
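// Illustrative sketch, not part of this commit (function names are
// hypothetical): the overall effect of a successful reroll, and the
// trip-count arithmetic that the new replace() sets up:
static void beforeReroll(int *B, const int *A, int N /* multiple of 3 */) {
  for (int I = 0; I < N; I += 3) {
    B[I] = A[I] * 2;
    B[I + 1] = A[I + 1] * 2;
    B[I + 2] = A[I + 2] * 2;
  }
}
static void afterReroll(int *B, const int *A, int N) {
  for (int I = 0; I < N; ++I) // new canonical IV {0,+,1}
    B[I] = A[I] * 2;
}
// With Scale == 3 the original backedge-taken count N/3 - 1 becomes
//   ScaledBECount = (BackedgeTakenCount + 1) * Scale - 1 = N - 1,
// which is why reroll() now threads the backedge-taken count (rather than
// the old IterCount) through to DAGRootTracker::replace(), and why the FIXME
// there notes that the multiply can overflow for counts near the type's
// maximum.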
bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header, - const SCEV *IterCount, + const SCEV *BackedgeTakenCount, ReductionTracker &Reductions) { DAGRootTracker DAGRoots(this, L, IV, SE, AA, TLI, DT, LI, PreserveLCSSA, IVToIncMap, LoopControlIV); if (!DAGRoots.findRoots()) return false; - DEBUG(dbgs() << "LRR: Found all root induction increments for: " << - *IV << "\n"); + LLVM_DEBUG(dbgs() << "LRR: Found all root induction increments for: " << *IV + << "\n"); if (!DAGRoots.validate(Reductions)) return false; @@ -1734,7 +1622,7 @@ bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header, // making changes! Reductions.replaceSelected(); - DAGRoots.replace(IterCount); + DAGRoots.replace(BackedgeTakenCount); ++NumRerolledLoops; return true; @@ -1752,9 +1640,9 @@ bool LoopReroll::runOnLoop(Loop *L, LPPassManager &LPM) { PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); BasicBlock *Header = L->getHeader(); - DEBUG(dbgs() << "LRR: F[" << Header->getParent()->getName() << - "] Loop %" << Header->getName() << " (" << - L->getNumBlocks() << " block(s))\n"); + LLVM_DEBUG(dbgs() << "LRR: F[" << Header->getParent()->getName() << "] Loop %" + << Header->getName() << " (" << L->getNumBlocks() + << " block(s))\n"); // For now, we'll handle only single BB loops. if (L->getNumBlocks() > 1) @@ -1763,10 +1651,10 @@ bool LoopReroll::runOnLoop(Loop *L, LPPassManager &LPM) { if (!SE->hasLoopInvariantBackedgeTakenCount(L)) return false; - const SCEV *LIBETC = SE->getBackedgeTakenCount(L); - const SCEV *IterCount = SE->getAddExpr(LIBETC, SE->getOne(LIBETC->getType())); - DEBUG(dbgs() << "\n Before Reroll:\n" << *(L->getHeader()) << "\n"); - DEBUG(dbgs() << "LRR: iteration count = " << *IterCount << "\n"); + const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); + LLVM_DEBUG(dbgs() << "\n Before Reroll:\n" << *(L->getHeader()) << "\n"); + LLVM_DEBUG(dbgs() << "LRR: backedge-taken count = " << *BackedgeTakenCount + << "\n"); // First, we need to find the induction variable with respect to which we can // reroll (there may be several possible options). @@ -1776,7 +1664,7 @@ bool LoopReroll::runOnLoop(Loop *L, LPPassManager &LPM) { collectPossibleIVs(L, PossibleIVs); if (PossibleIVs.empty()) { - DEBUG(dbgs() << "LRR: No possible IVs found\n"); + LLVM_DEBUG(dbgs() << "LRR: No possible IVs found\n"); return false; } @@ -1787,11 +1675,11 @@ bool LoopReroll::runOnLoop(Loop *L, LPPassManager &LPM) { // For each possible IV, collect the associated possible set of 'root' nodes // (i+1, i+2, etc.). for (Instruction *PossibleIV : PossibleIVs) - if (reroll(PossibleIV, L, Header, IterCount, Reductions)) { + if (reroll(PossibleIV, L, Header, BackedgeTakenCount, Reductions)) { Changed = true; break; } - DEBUG(dbgs() << "\n After Reroll:\n" << *(L->getHeader()) << "\n"); + LLVM_DEBUG(dbgs() << "\n After Reroll:\n" << *(L->getHeader()) << "\n"); // Trip count of L has changed so SE must be re-evaluated. 
if (Changed) diff --git a/lib/Transforms/Scalar/LoopRotation.cpp b/lib/Transforms/Scalar/LoopRotation.cpp index a91f53ba663f..eeaad39dc1d1 100644 --- a/lib/Transforms/Scalar/LoopRotation.cpp +++ b/lib/Transforms/Scalar/LoopRotation.cpp @@ -13,33 +13,15 @@ #include "llvm/Transforms/Scalar/LoopRotation.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" -#include "llvm/Analysis/CodeMetrics.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/CFG.h" -#include "llvm/IR/DebugInfoMetadata.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Module.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopRotationUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" -#include "llvm/Transforms/Utils/SSAUpdater.h" -#include "llvm/Transforms/Utils/ValueMapper.h" using namespace llvm; #define DEBUG_TYPE "loop-rotate" @@ -48,595 +30,6 @@ static cl::opt<unsigned> DefaultRotationThreshold( "rotation-max-header-size", cl::init(16), cl::Hidden, cl::desc("The default maximum header size for automatic loop rotation")); -STATISTIC(NumRotated, "Number of loops rotated"); - -namespace { -/// A simple loop rotation transformation. -class LoopRotate { - const unsigned MaxHeaderSize; - LoopInfo *LI; - const TargetTransformInfo *TTI; - AssumptionCache *AC; - DominatorTree *DT; - ScalarEvolution *SE; - const SimplifyQuery &SQ; - -public: - LoopRotate(unsigned MaxHeaderSize, LoopInfo *LI, - const TargetTransformInfo *TTI, AssumptionCache *AC, - DominatorTree *DT, ScalarEvolution *SE, const SimplifyQuery &SQ) - : MaxHeaderSize(MaxHeaderSize), LI(LI), TTI(TTI), AC(AC), DT(DT), SE(SE), - SQ(SQ) {} - bool processLoop(Loop *L); - -private: - bool rotateLoop(Loop *L, bool SimplifiedLatch); - bool simplifyLoopLatch(Loop *L); -}; -} // end anonymous namespace - -/// RewriteUsesOfClonedInstructions - We just cloned the instructions from the -/// old header into the preheader. If there were uses of the values produced by -/// these instruction that were outside of the loop, we have to insert PHI nodes -/// to merge the two values. Do this now. -static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, - BasicBlock *OrigPreheader, - ValueToValueMapTy &ValueMap, - SmallVectorImpl<PHINode*> *InsertedPHIs) { - // Remove PHI node entries that are no longer live. - BasicBlock::iterator I, E = OrigHeader->end(); - for (I = OrigHeader->begin(); PHINode *PN = dyn_cast<PHINode>(I); ++I) - PN->removeIncomingValue(PN->getBasicBlockIndex(OrigPreheader)); - - // Now fix up users of the instructions in OrigHeader, inserting PHI nodes - // as necessary. - SSAUpdater SSA(InsertedPHIs); - for (I = OrigHeader->begin(); I != E; ++I) { - Value *OrigHeaderVal = &*I; - - // If there are no uses of the value (e.g. because it returns void), there - // is nothing to rewrite. 
- if (OrigHeaderVal->use_empty()) - continue; - - Value *OrigPreHeaderVal = ValueMap.lookup(OrigHeaderVal); - - // The value now exits in two versions: the initial value in the preheader - // and the loop "next" value in the original header. - SSA.Initialize(OrigHeaderVal->getType(), OrigHeaderVal->getName()); - SSA.AddAvailableValue(OrigHeader, OrigHeaderVal); - SSA.AddAvailableValue(OrigPreheader, OrigPreHeaderVal); - - // Visit each use of the OrigHeader instruction. - for (Value::use_iterator UI = OrigHeaderVal->use_begin(), - UE = OrigHeaderVal->use_end(); - UI != UE;) { - // Grab the use before incrementing the iterator. - Use &U = *UI; - - // Increment the iterator before removing the use from the list. - ++UI; - - // SSAUpdater can't handle a non-PHI use in the same block as an - // earlier def. We can easily handle those cases manually. - Instruction *UserInst = cast<Instruction>(U.getUser()); - if (!isa<PHINode>(UserInst)) { - BasicBlock *UserBB = UserInst->getParent(); - - // The original users in the OrigHeader are already using the - // original definitions. - if (UserBB == OrigHeader) - continue; - - // Users in the OrigPreHeader need to use the value to which the - // original definitions are mapped. - if (UserBB == OrigPreheader) { - U = OrigPreHeaderVal; - continue; - } - } - - // Anything else can be handled by SSAUpdater. - SSA.RewriteUse(U); - } - - // Replace MetadataAsValue(ValueAsMetadata(OrigHeaderVal)) uses in debug - // intrinsics. - SmallVector<DbgValueInst *, 1> DbgValues; - llvm::findDbgValues(DbgValues, OrigHeaderVal); - for (auto &DbgValue : DbgValues) { - // The original users in the OrigHeader are already using the original - // definitions. - BasicBlock *UserBB = DbgValue->getParent(); - if (UserBB == OrigHeader) - continue; - - // Users in the OrigPreHeader need to use the value to which the - // original definitions are mapped and anything else can be handled by - // the SSAUpdater. To avoid adding PHINodes, check if the value is - // available in UserBB, if not substitute undef. - Value *NewVal; - if (UserBB == OrigPreheader) - NewVal = OrigPreHeaderVal; - else if (SSA.HasValueForBlock(UserBB)) - NewVal = SSA.GetValueInMiddleOfBlock(UserBB); - else - NewVal = UndefValue::get(OrigHeaderVal->getType()); - DbgValue->setOperand(0, - MetadataAsValue::get(OrigHeaderVal->getContext(), - ValueAsMetadata::get(NewVal))); - } - } -} - -/// Propagate dbg.value intrinsics through the newly inserted Phis. -static void insertDebugValues(BasicBlock *OrigHeader, - SmallVectorImpl<PHINode*> &InsertedPHIs) { - ValueToValueMapTy DbgValueMap; - - // Map existing PHI nodes to their dbg.values. - for (auto &I : *OrigHeader) { - if (auto DbgII = dyn_cast<DbgInfoIntrinsic>(&I)) { - if (auto *Loc = dyn_cast_or_null<PHINode>(DbgII->getVariableLocation())) - DbgValueMap.insert({Loc, DbgII}); - } - } - - // Then iterate through the new PHIs and look to see if they use one of the - // previously mapped PHIs. If so, insert a new dbg.value intrinsic that will - // propagate the info through the new PHI. 
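// Illustrative sketch, not part of this commit (useOutside is hypothetical):
// why RewriteUsesOfClonedInstructions() is needed. A header-defined value
// used after the loop exists in two copies once the header is cloned into
// the preheader:
static int useOutside(const int *A, int N) {
  int Last = 0;
  for (int I = 0; I < N; ++I)
    Last = A[I]; // defined in the (old) header
  return Last;   // used outside the loop: SSAUpdater merges the preheader
                 // clone and the in-loop definition with a new PHI
}
// insertDebugValues() then clones dbg.value intrinsics onto those new PHIs
// so variable locations stay accurate in the rotated loop body.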
- LLVMContext &C = OrigHeader->getContext(); - for (auto PHI : InsertedPHIs) { - for (auto VI : PHI->operand_values()) { - auto V = DbgValueMap.find(VI); - if (V != DbgValueMap.end()) { - auto *DbgII = cast<DbgInfoIntrinsic>(V->second); - Instruction *NewDbgII = DbgII->clone(); - auto PhiMAV = MetadataAsValue::get(C, ValueAsMetadata::get(PHI)); - NewDbgII->setOperand(0, PhiMAV); - BasicBlock *Parent = PHI->getParent(); - NewDbgII->insertBefore(Parent->getFirstNonPHIOrDbgOrLifetime()); - } - } - } -} - -/// Rotate loop LP. Return true if the loop is rotated. -/// -/// \param SimplifiedLatch is true if the latch was just folded into the final -/// loop exit. In this case we may want to rotate even though the new latch is -/// now an exiting branch. This rotation would have happened had the latch not -/// been simplified. However, if SimplifiedLatch is false, then we avoid -/// rotating loops in which the latch exits to avoid excessive or endless -/// rotation. LoopRotate should be repeatable and converge to a canonical -/// form. This property is satisfied because simplifying the loop latch can only -/// happen once across multiple invocations of the LoopRotate pass. -bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { - // If the loop has only one block then there is not much to rotate. - if (L->getBlocks().size() == 1) - return false; - - BasicBlock *OrigHeader = L->getHeader(); - BasicBlock *OrigLatch = L->getLoopLatch(); - - BranchInst *BI = dyn_cast<BranchInst>(OrigHeader->getTerminator()); - if (!BI || BI->isUnconditional()) - return false; - - // If the loop header is not one of the loop exiting blocks then - // either this loop is already rotated or it is not - // suitable for loop rotation transformations. - if (!L->isLoopExiting(OrigHeader)) - return false; - - // If the loop latch already contains a branch that leaves the loop then the - // loop is already rotated. - if (!OrigLatch) - return false; - - // Rotate if either the loop latch does *not* exit the loop, or if the loop - // latch was just simplified. - if (L->isLoopExiting(OrigLatch) && !SimplifiedLatch) - return false; - - // Check size of original header and reject loop if it is very big or we can't - // duplicate blocks inside it. - { - SmallPtrSet<const Value *, 32> EphValues; - CodeMetrics::collectEphemeralValues(L, AC, EphValues); - - CodeMetrics Metrics; - Metrics.analyzeBasicBlock(OrigHeader, *TTI, EphValues); - if (Metrics.notDuplicatable) { - DEBUG(dbgs() << "LoopRotation: NOT rotating - contains non-duplicatable" - << " instructions: "; - L->dump()); - return false; - } - if (Metrics.convergent) { - DEBUG(dbgs() << "LoopRotation: NOT rotating - contains convergent " - "instructions: "; - L->dump()); - return false; - } - if (Metrics.NumInsts > MaxHeaderSize) - return false; - } - - // Now, this loop is suitable for rotation. - BasicBlock *OrigPreheader = L->getLoopPreheader(); - - // If the loop could not be converted to canonical form, it must have an - // indirectbr in it, just give up. - if (!OrigPreheader) - return false; - - // Anything ScalarEvolution may know about this loop or the PHI nodes - // in its header will soon be invalidated. - if (SE) - SE->forgetLoop(L); - - DEBUG(dbgs() << "LoopRotation: rotating "; L->dump()); - - // Find new Loop header. NewHeader is a Header's one and only successor - // that is inside loop. Header's other successor is outside the - // loop. Otherwise loop is not suitable for rotation. 
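// Illustrative sketch, not part of this commit (names are hypothetical): the
// source-level shape of rotateLoop(). The exiting test moves from the header
// to the latch, with a guarding copy of the test left in the preheader:
static int rotateBefore(const int *A, int N) {
  int S = 0, I = 0;
  while (I < N) { // header test exits the loop
    S += A[I];
    ++I;
  }
  return S;
}
static int rotateAfter(const int *A, int N) {
  int S = 0, I = 0;
  if (I < N) {       // cloned header test guards loop entry
    do {
      S += A[I];
      ++I;
    } while (I < N); // the latch now exits; the loop is "rotated"
  }
  return S;
}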
- BasicBlock *Exit = BI->getSuccessor(0); - BasicBlock *NewHeader = BI->getSuccessor(1); - if (L->contains(Exit)) - std::swap(Exit, NewHeader); - assert(NewHeader && "Unable to determine new loop header"); - assert(L->contains(NewHeader) && !L->contains(Exit) && - "Unable to determine loop header and exit blocks"); - - // This code assumes that the new header has exactly one predecessor. - // Remove any single-entry PHI nodes in it. - assert(NewHeader->getSinglePredecessor() && - "New header doesn't have one pred!"); - FoldSingleEntryPHINodes(NewHeader); - - // Begin by walking OrigHeader and populating ValueMap with an entry for - // each Instruction. - BasicBlock::iterator I = OrigHeader->begin(), E = OrigHeader->end(); - ValueToValueMapTy ValueMap; - - // For PHI nodes, the value available in OldPreHeader is just the - // incoming value from OldPreHeader. - for (; PHINode *PN = dyn_cast<PHINode>(I); ++I) - ValueMap[PN] = PN->getIncomingValueForBlock(OrigPreheader); - - // For the rest of the instructions, either hoist to the OrigPreheader if - // possible or create a clone in the OldPreHeader if not. - TerminatorInst *LoopEntryBranch = OrigPreheader->getTerminator(); - - // Record all debug intrinsics preceding LoopEntryBranch to avoid duplication. - using DbgIntrinsicHash = - std::pair<std::pair<Value *, DILocalVariable *>, DIExpression *>; - auto makeHash = [](DbgInfoIntrinsic *D) -> DbgIntrinsicHash { - return {{D->getVariableLocation(), D->getVariable()}, D->getExpression()}; - }; - SmallDenseSet<DbgIntrinsicHash, 8> DbgIntrinsics; - for (auto I = std::next(OrigPreheader->rbegin()), E = OrigPreheader->rend(); - I != E; ++I) { - if (auto *DII = dyn_cast<DbgInfoIntrinsic>(&*I)) - DbgIntrinsics.insert(makeHash(DII)); - else - break; - } - - while (I != E) { - Instruction *Inst = &*I++; - - // If the instruction's operands are invariant and it doesn't read or write - // memory, then it is safe to hoist. Doing this doesn't change the order of - // execution in the preheader, but does prevent the instruction from - // executing in each iteration of the loop. This means it is safe to hoist - // something that might trap, but isn't safe to hoist something that reads - // memory (without proving that the loop doesn't write). - if (L->hasLoopInvariantOperands(Inst) && !Inst->mayReadFromMemory() && - !Inst->mayWriteToMemory() && !isa<TerminatorInst>(Inst) && - !isa<DbgInfoIntrinsic>(Inst) && !isa<AllocaInst>(Inst)) { - Inst->moveBefore(LoopEntryBranch); - continue; - } - - // Otherwise, create a duplicate of the instruction. - Instruction *C = Inst->clone(); - - // Eagerly remap the operands of the instruction. - RemapInstruction(C, ValueMap, - RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); - - // Avoid inserting the same intrinsic twice. - if (auto *DII = dyn_cast<DbgInfoIntrinsic>(C)) - if (DbgIntrinsics.count(makeHash(DII))) { - C->deleteValue(); - continue; - } - - // With the operands remapped, see if the instruction constant folds or is - // otherwise simplifyable. This commonly occurs because the entry from PHI - // nodes allows icmps and other instructions to fold. - Value *V = SimplifyInstruction(C, SQ); - if (V && LI->replacementPreservesLCSSAForm(C, V)) { - // If so, then delete the temporary instruction and stick the folded value - // in the map. - ValueMap[Inst] = V; - if (!C->mayHaveSideEffects()) { - C->deleteValue(); - C = nullptr; - } - } else { - ValueMap[Inst] = C; - } - if (C) { - // Otherwise, stick the new instruction into the new block! 
- C->setName(Inst->getName()); - C->insertBefore(LoopEntryBranch); - - if (auto *II = dyn_cast<IntrinsicInst>(C)) - if (II->getIntrinsicID() == Intrinsic::assume) - AC->registerAssumption(II); - } - } - - // Along with all the other instructions, we just cloned OrigHeader's - // terminator into OrigPreHeader. Fix up the PHI nodes in each of OrigHeader's - // successors by duplicating their incoming values for OrigHeader. - TerminatorInst *TI = OrigHeader->getTerminator(); - for (BasicBlock *SuccBB : TI->successors()) - for (BasicBlock::iterator BI = SuccBB->begin(); - PHINode *PN = dyn_cast<PHINode>(BI); ++BI) - PN->addIncoming(PN->getIncomingValueForBlock(OrigHeader), OrigPreheader); - - // Now that OrigPreHeader has a clone of OrigHeader's terminator, remove - // OrigPreHeader's old terminator (the original branch into the loop), and - // remove the corresponding incoming values from the PHI nodes in OrigHeader. - LoopEntryBranch->eraseFromParent(); - - - SmallVector<PHINode*, 2> InsertedPHIs; - // If there were any uses of instructions in the duplicated block outside the - // loop, update them, inserting PHI nodes as required - RewriteUsesOfClonedInstructions(OrigHeader, OrigPreheader, ValueMap, - &InsertedPHIs); - - // Attach dbg.value intrinsics to the new phis if that phi uses a value that - // previously had debug metadata attached. This keeps the debug info - // up-to-date in the loop body. - if (!InsertedPHIs.empty()) - insertDebugValues(OrigHeader, InsertedPHIs); - - // NewHeader is now the header of the loop. - L->moveToHeader(NewHeader); - assert(L->getHeader() == NewHeader && "Latch block is our new header"); - - // Inform DT about changes to the CFG. - if (DT) { - // The OrigPreheader branches to the NewHeader and Exit now. Then, inform - // the DT about the removed edge to the OrigHeader (that got removed). - SmallVector<DominatorTree::UpdateType, 3> Updates; - Updates.push_back({DominatorTree::Insert, OrigPreheader, Exit}); - Updates.push_back({DominatorTree::Insert, OrigPreheader, NewHeader}); - Updates.push_back({DominatorTree::Delete, OrigPreheader, OrigHeader}); - DT->applyUpdates(Updates); - } - - // At this point, we've finished our major CFG changes. As part of cloning - // the loop into the preheader we've simplified instructions and the - // duplicated conditional branch may now be branching on a constant. If it is - // branching on a constant and if that constant means that we enter the loop, - // then we fold away the cond branch to an uncond branch. This simplifies the - // loop in cases important for nested loops, and it also means we don't have - // to split as many edges. - BranchInst *PHBI = cast<BranchInst>(OrigPreheader->getTerminator()); - assert(PHBI->isConditional() && "Should be clone of BI condbr!"); - if (!isa<ConstantInt>(PHBI->getCondition()) || - PHBI->getSuccessor(cast<ConstantInt>(PHBI->getCondition())->isZero()) != - NewHeader) { - // The conditional branch can't be folded, handle the general case. - // Split edges as necessary to preserve LoopSimplify form. - - // Right now OrigPreHeader has two successors, NewHeader and ExitBlock, and - // thus is not a preheader anymore. - // Split the edge to form a real preheader. - BasicBlock *NewPH = SplitCriticalEdge( - OrigPreheader, NewHeader, - CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA()); - NewPH->setName(NewHeader->getName() + ".lr.ph"); - - // Preserve canonical loop form, which means that 'Exit' should have only - // one predecessor. 
Note that Exit could be an exit block for multiple - // nested loops, causing both of the edges to now be critical and need to - // be split. - SmallVector<BasicBlock *, 4> ExitPreds(pred_begin(Exit), pred_end(Exit)); - bool SplitLatchEdge = false; - for (BasicBlock *ExitPred : ExitPreds) { - // We only need to split loop exit edges. - Loop *PredLoop = LI->getLoopFor(ExitPred); - if (!PredLoop || PredLoop->contains(Exit)) - continue; - if (isa<IndirectBrInst>(ExitPred->getTerminator())) - continue; - SplitLatchEdge |= L->getLoopLatch() == ExitPred; - BasicBlock *ExitSplit = SplitCriticalEdge( - ExitPred, Exit, - CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA()); - ExitSplit->moveBefore(Exit); - } - assert(SplitLatchEdge && - "Despite splitting all preds, failed to split latch exit?"); - } else { - // We can fold the conditional branch in the preheader, this makes things - // simpler. The first step is to remove the extra edge to the Exit block. - Exit->removePredecessor(OrigPreheader, true /*preserve LCSSA*/); - BranchInst *NewBI = BranchInst::Create(NewHeader, PHBI); - NewBI->setDebugLoc(PHBI->getDebugLoc()); - PHBI->eraseFromParent(); - - // With our CFG finalized, update DomTree if it is available. - if (DT) DT->deleteEdge(OrigPreheader, Exit); - } - - assert(L->getLoopPreheader() && "Invalid loop preheader after loop rotation"); - assert(L->getLoopLatch() && "Invalid loop latch after loop rotation"); - - // Now that the CFG and DomTree are in a consistent state again, try to merge - // the OrigHeader block into OrigLatch. This will succeed if they are - // connected by an unconditional branch. This is just a cleanup so the - // emitted code isn't too gross in this common case. - MergeBlockIntoPredecessor(OrigHeader, DT, LI); - - DEBUG(dbgs() << "LoopRotation: into "; L->dump()); - - ++NumRotated; - return true; -} - -/// Determine whether the instructions in this range may be safely and cheaply -/// speculated. This is not an important enough situation to develop complex -/// heuristics. We handle a single arithmetic instruction along with any type -/// conversions. -static bool shouldSpeculateInstrs(BasicBlock::iterator Begin, - BasicBlock::iterator End, Loop *L) { - bool seenIncrement = false; - bool MultiExitLoop = false; - - if (!L->getExitingBlock()) - MultiExitLoop = true; - - for (BasicBlock::iterator I = Begin; I != End; ++I) { - - if (!isSafeToSpeculativelyExecute(&*I)) - return false; - - if (isa<DbgInfoIntrinsic>(I)) - continue; - - switch (I->getOpcode()) { - default: - return false; - case Instruction::GetElementPtr: - // GEPs are cheap if all indices are constant. - if (!cast<GEPOperator>(I)->hasAllConstantIndices()) - return false; - // fall-thru to increment case - LLVM_FALLTHROUGH; - case Instruction::Add: - case Instruction::Sub: - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: - case Instruction::Shl: - case Instruction::LShr: - case Instruction::AShr: { - Value *IVOpnd = - !isa<Constant>(I->getOperand(0)) - ? I->getOperand(0) - : !isa<Constant>(I->getOperand(1)) ? I->getOperand(1) : nullptr; - if (!IVOpnd) - return false; - - // If increment operand is used outside of the loop, this speculation - // could cause extra live range interference. 
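The function resumes just below with the multi-exit check. As a toy distillation of the overall policy it implements, at most one arithmetic increment plus any number of type conversions, everything else rejected (the constant-index GEP case and the safety checks are elided here):

#include <cstdio>
#include <vector>

enum class Op { Add, Sub, Trunc, ZExt, SExt, Other };

// Accept a straight-line range only if it is one "increment" plus casts.
bool shouldSpeculate(const std::vector<Op> &Range) {
  bool SeenIncrement = false;
  for (Op O : Range) {
    switch (O) {
    case Op::Add:
    case Op::Sub:
      if (SeenIncrement)
        return false; // a second increment makes speculation unprofitable
      SeenIncrement = true;
      break;
    case Op::Trunc:
    case Op::ZExt:
    case Op::SExt:
      break;          // type conversions are ignored
    default:
      return false;   // anything else is rejected outright
    }
  }
  return true;
}

int main() {
  std::printf("%d\n", shouldSpeculate({Op::ZExt, Op::Add})); // 1
  std::printf("%d\n", shouldSpeculate({Op::Add, Op::Add}));  // 0
  std::printf("%d\n", shouldSpeculate({Op::Other}));         // 0
  return 0;
}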
- if (MultiExitLoop) { - for (User *UseI : IVOpnd->users()) { - auto *UserInst = cast<Instruction>(UseI); - if (!L->contains(UserInst)) - return false; - } - } - - if (seenIncrement) - return false; - seenIncrement = true; - break; - } - case Instruction::Trunc: - case Instruction::ZExt: - case Instruction::SExt: - // ignore type conversions - break; - } - } - return true; -} - -/// Fold the loop tail into the loop exit by speculating the loop tail -/// instructions. Typically, this is a single post-increment. In the case of a -/// simple 2-block loop, hoisting the increment can be much better than -/// duplicating the entire loop header. In the case of loops with early exits, -/// rotation will not work anyway, but simplifyLoopLatch will put the loop in -/// canonical form so downstream passes can handle it. -/// -/// I don't believe this invalidates SCEV. -bool LoopRotate::simplifyLoopLatch(Loop *L) { - BasicBlock *Latch = L->getLoopLatch(); - if (!Latch || Latch->hasAddressTaken()) - return false; - - BranchInst *Jmp = dyn_cast<BranchInst>(Latch->getTerminator()); - if (!Jmp || !Jmp->isUnconditional()) - return false; - - BasicBlock *LastExit = Latch->getSinglePredecessor(); - if (!LastExit || !L->isLoopExiting(LastExit)) - return false; - - BranchInst *BI = dyn_cast<BranchInst>(LastExit->getTerminator()); - if (!BI) - return false; - - if (!shouldSpeculateInstrs(Latch->begin(), Jmp->getIterator(), L)) - return false; - - DEBUG(dbgs() << "Folding loop latch " << Latch->getName() << " into " - << LastExit->getName() << "\n"); - - // Hoist the instructions from Latch into LastExit. - LastExit->getInstList().splice(BI->getIterator(), Latch->getInstList(), - Latch->begin(), Jmp->getIterator()); - - unsigned FallThruPath = BI->getSuccessor(0) == Latch ? 0 : 1; - BasicBlock *Header = Jmp->getSuccessor(0); - assert(Header == L->getHeader() && "expected a backward branch"); - - // Remove Latch from the CFG so that LastExit becomes the new Latch. - BI->setSuccessor(FallThruPath, Header); - Latch->replaceSuccessorsPhiUsesWith(LastExit); - Jmp->eraseFromParent(); - - // Nuke the Latch block. - assert(Latch->empty() && "unable to evacuate Latch"); - LI->removeBlock(Latch); - if (DT) - DT->eraseNode(Latch); - Latch->eraseFromParent(); - return true; -} - -/// Rotate \c L, and return true if any modification was made. -bool LoopRotate::processLoop(Loop *L) { - // Save the loop metadata. - MDNode *LoopMD = L->getLoopID(); - - // Simplify the loop latch before attempting to rotate the header - // upward. Rotation may not be needed if the loop tail can be folded into the - // loop exit. - bool SimplifiedLatch = simplifyLoopLatch(L); - - bool MadeChange = rotateLoop(L, SimplifiedLatch); - assert((!MadeChange || L->isLoopExiting(L->getLoopLatch())) && - "Loop latch should be exiting after loop-rotate."); - - // Restore the loop metadata. - // NB! We presume LoopRotation DOESN'T ADD its own metadata. - if ((MadeChange || SimplifiedLatch) && LoopMD) - L->setLoopID(LoopMD); - - return MadeChange || SimplifiedLatch; -} - LoopRotatePass::LoopRotatePass(bool EnableHeaderDuplication) : EnableHeaderDuplication(EnableHeaderDuplication) {} @@ -646,10 +39,10 @@ PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM, int Threshold = EnableHeaderDuplication ? 
DefaultRotationThreshold : 0; const DataLayout &DL = L.getHeader()->getModule()->getDataLayout(); const SimplifyQuery SQ = getBestSimplifyQuery(AR, DL); - LoopRotate LR(Threshold, &AR.LI, &AR.TTI, &AR.AC, &AR.DT, &AR.SE, - SQ); - bool Changed = LR.processLoop(&L); + bool Changed = LoopRotation(&L, &AR.LI, &AR.TTI, &AR.AC, &AR.DT, &AR.SE, SQ, + false, Threshold, false); + if (!Changed) return PreservedAnalyses::all(); @@ -691,8 +84,8 @@ public: auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); auto *SE = SEWP ? &SEWP->getSE() : nullptr; const SimplifyQuery SQ = getBestSimplifyQuery(*this, F); - LoopRotate LR(MaxHeaderSize, LI, TTI, AC, DT, SE, SQ); - return LR.processLoop(L); + return LoopRotation(L, LI, TTI, AC, DT, SE, SQ, false, MaxHeaderSize, + false); } }; } diff --git a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index 35c05e84fd68..2b83d3dc5f1b 100644 --- a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -30,13 +30,16 @@ #include "llvm/IR/Dominators.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; #define DEBUG_TYPE "loop-simplifycfg" -static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI) { +static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI, + ScalarEvolution &SE) { bool Changed = false; // Copy blocks into a temporary array to avoid iterator invalidation issues // as we remove them. @@ -53,11 +56,10 @@ static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI) { if (!Pred || !Pred->getSingleSuccessor() || LI.getLoopFor(Pred) != &L) continue; - // Pred is going to disappear, so we need to update the loop info. - if (L.getHeader() == Pred) - L.moveToHeader(Succ); - LI.removeBlock(Pred); - MergeBasicBlockIntoOnlyPred(Succ, &DT); + // Merge Succ into Pred and delete it. 
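The hunk continues below: the blocks are merged, and a new SE.forgetLoop(&L) call then drops ScalarEvolution's cached facts about the loop, since trip counts computed against the old CFG may no longer hold. A generic cached-analysis analogue of why such a call is needed (invented names, not the ScalarEvolution API):

#include <cstdio>
#include <map>
#include <string>

struct CachedAnalysis {
  std::map<std::string, int> Cache; // loop name -> cached "trip count"

  int compute(const std::string &L) {
    auto It = Cache.find(L);
    if (It != Cache.end())
      return It->second; // stale if the loop changed since it was cached
    int Result = 100;    // stand-in for a real, expensive computation
    Cache[L] = Result;
    return Result;
  }
  void forget(const std::string &L) { Cache.erase(L); }
};

int main() {
  CachedAnalysis SE;
  SE.compute("loop");
  // ... the transform mutates the loop's CFG here ...
  SE.forget("loop"); // without this, the next query reuses stale data
  std::printf("%d\n", SE.compute("loop"));
  return 0;
}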
+ MergeBlockIntoPredecessor(Succ, &DT, &LI); + + SE.forgetLoop(&L); Changed = true; } @@ -67,7 +69,7 @@ static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI) { PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { - if (!simplifyLoopCFG(L, AR.DT, AR.LI)) + if (!simplifyLoopCFG(L, AR.DT, AR.LI, AR.SE)) return PreservedAnalyses::all(); return getLoopPassPreservedAnalyses(); @@ -87,7 +89,8 @@ public: DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - return simplifyLoopCFG(*L, DT, LI); + ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + return simplifyLoopCFG(*L, DT, LI, SE); } void getAnalysisUsage(AnalysisUsage &AU) const override { diff --git a/lib/Transforms/Scalar/LoopSink.cpp b/lib/Transforms/Scalar/LoopSink.cpp index 430a7085d93f..760177c9c5e9 100644 --- a/lib/Transforms/Scalar/LoopSink.cpp +++ b/lib/Transforms/Scalar/LoopSink.cpp @@ -42,6 +42,7 @@ #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" @@ -49,7 +50,6 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; @@ -200,17 +200,19 @@ static bool sinkInstruction(Loop &L, Instruction &I, SmallVector<BasicBlock *, 2> SortedBBsToSinkInto; SortedBBsToSinkInto.insert(SortedBBsToSinkInto.begin(), BBsToSinkInto.begin(), BBsToSinkInto.end()); - std::sort(SortedBBsToSinkInto.begin(), SortedBBsToSinkInto.end(), - [&](BasicBlock *A, BasicBlock *B) { - return *LoopBlockNumber.find(A) < *LoopBlockNumber.find(B); - }); + llvm::sort(SortedBBsToSinkInto.begin(), SortedBBsToSinkInto.end(), + [&](BasicBlock *A, BasicBlock *B) { + return LoopBlockNumber.find(A)->second < + LoopBlockNumber.find(B)->second; + }); BasicBlock *MoveBB = *SortedBBsToSinkInto.begin(); // FIXME: Optimize the efficiency for cloned value replacement. The current // implementation is O(SortedBBsToSinkInto.size() * I.num_uses()). - for (BasicBlock *N : SortedBBsToSinkInto) { - if (N == MoveBB) - continue; + for (BasicBlock *N : makeArrayRef(SortedBBsToSinkInto).drop_front(1)) { + assert(LoopBlockNumber.find(N)->second > + LoopBlockNumber.find(MoveBB)->second && + "BBs not sorted!"); // Clone I and replace its uses. 
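Two things change in the sort hunk above: std::sort becomes llvm::sort (which shuffles its input in expensive-checks builds to flush out non-deterministic comparators), and the comparator now compares find(...)->second, the stable per-block number, rather than the dereferenced map entries, whose ordering depended on BasicBlock pointer keys. A standalone sketch of the ordering step before the cloning code resumes below:

#include <algorithm>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

int main() {
  // Block "numbers" stand in for positions within the loop.
  std::map<std::string, int> LoopBlockNumber{
      {"bb3", 3}, {"bb1", 1}, {"bb7", 7}};
  std::vector<std::string> BBsToSinkInto{"bb7", "bb1", "bb3"};

  // Sort by the stable number, never by pointer identity.
  std::sort(BBsToSinkInto.begin(), BBsToSinkInto.end(),
            [&](const std::string &A, const std::string &B) {
              return LoopBlockNumber.find(A)->second <
                     LoopBlockNumber.find(B)->second;
            });

  // The first block receives the moved instruction; the rest get clones.
  std::printf("move to %s\n", BBsToSinkInto.front().c_str());
  for (size_t i = 1; i < BBsToSinkInto.size(); ++i)
    std::printf("clone into %s\n", BBsToSinkInto[i].c_str());
  return 0;
}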
Instruction *IC = I.clone(); IC->setName(I.getName()); @@ -224,11 +226,11 @@ static bool sinkInstruction(Loop &L, Instruction &I, } // Replaces uses of I with IC in blocks dominated by N replaceDominatedUsesWith(&I, IC, DT, N); - DEBUG(dbgs() << "Sinking a clone of " << I << " To: " << N->getName() - << '\n'); + LLVM_DEBUG(dbgs() << "Sinking a clone of " << I << " To: " << N->getName() + << '\n'); NumLoopSunkCloned++; } - DEBUG(dbgs() << "Sinking " << I << " To: " << MoveBB->getName() << '\n'); + LLVM_DEBUG(dbgs() << "Sinking " << I << " To: " << MoveBB->getName() << '\n'); NumLoopSunk++; I.moveBefore(&*MoveBB->getFirstInsertionPt()); diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 953854c8b7b7..fa83b48210bc 100644 --- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -75,6 +75,8 @@ #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ScalarEvolutionNormalization.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Config/llvm-config.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -105,8 +107,8 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> #include <cstddef> @@ -121,7 +123,7 @@ using namespace llvm; #define DEBUG_TYPE "loop-reduce" -/// MaxIVUsers is an arbitrary threshold that provides an early opportunitiy for +/// MaxIVUsers is an arbitrary threshold that provides an early opportunity for /// bail out. This threshold is far beyond the number of users that LSR can /// conceivably solve, so it should not affect generated code, but catches the /// worst cases before LSR burns too much compile time and stack space. @@ -185,6 +187,8 @@ struct MemAccessTy { unsigned AS = UnknownAddressSpace) { return MemAccessTy(Type::getVoidTy(Ctx), AS); } + + Type *getType() { return MemTy; } }; /// This class holds data which is used to order reuse candidates. @@ -327,7 +331,7 @@ struct Formula { /// #2 enforces that 1 * reg is reg. /// #3 ensures invariant regs with respect to current loop can be combined /// together in LSR codegen. - /// This invariant can be temporarly broken while building a formula. + /// This invariant can be temporarily broken while building a formula. /// However, every formula inserted into the LSRInstance must be in canonical /// form. SmallVector<const SCEV *, 4> BaseRegs; @@ -442,7 +446,7 @@ void Formula::initialMatch(const SCEV *S, Loop *L, ScalarEvolution &SE) { canonicalize(*L); } -/// \brief Check whether or not this formula statisfies the canonical +/// Check whether or not this formula satisfies the canonical /// representation. /// \see Formula::BaseRegs. bool Formula::isCanonical(const Loop &L) const { @@ -470,7 +474,7 @@ bool Formula::isCanonical(const Loop &L) const { return I == BaseRegs.end(); } -/// \brief Helper method to morph a formula into its canonical representation. +/// Helper method to morph a formula into its canonical representation. /// \see Formula::BaseRegs. /// Every formula having more than one base register, must use the ScaledReg /// field. 
Otherwise, we would have to do special cases everywhere in LSR @@ -505,7 +509,7 @@ void Formula::canonicalize(const Loop &L) { } } -/// \brief Get rid of the scale in the formula. +/// Get rid of the scale in the formula. /// In other words, this method morphes reg1 + 1*reg2 into reg1 + reg2. /// \return true if it was possible to get rid of the scale, false otherwise. /// \note After this operation the formula may not be in the canonical form. @@ -818,7 +822,7 @@ static bool isAddressUse(const TargetTransformInfo &TTI, /// Return the type of the memory being accessed. static MemAccessTy getAccessType(const TargetTransformInfo &TTI, - Instruction *Inst) { + Instruction *Inst, Value *OperandVal) { MemAccessTy AccessTy(Inst->getType(), MemAccessTy::UnknownAddressSpace); if (const StoreInst *SI = dyn_cast<StoreInst>(Inst)) { AccessTy.MemTy = SI->getOperand(0)->getType(); @@ -832,7 +836,14 @@ static MemAccessTy getAccessType(const TargetTransformInfo &TTI, } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { switch (II->getIntrinsicID()) { case Intrinsic::prefetch: + case Intrinsic::memset: AccessTy.AddrSpace = II->getArgOperand(0)->getType()->getPointerAddressSpace(); + AccessTy.MemTy = OperandVal->getType(); + break; + case Intrinsic::memmove: + case Intrinsic::memcpy: + AccessTy.AddrSpace = OperandVal->getType()->getPointerAddressSpace(); + AccessTy.MemTy = OperandVal->getType(); break; default: { MemIntrinsicInfo IntrInfo; @@ -857,12 +868,11 @@ static MemAccessTy getAccessType(const TargetTransformInfo &TTI, /// Return true if this AddRec is already a phi in its loop. static bool isExistingPhi(const SCEVAddRecExpr *AR, ScalarEvolution &SE) { - for (BasicBlock::iterator I = AR->getLoop()->getHeader()->begin(); - PHINode *PN = dyn_cast<PHINode>(I); ++I) { - if (SE.isSCEVable(PN->getType()) && - (SE.getEffectiveSCEVType(PN->getType()) == + for (PHINode &PN : AR->getLoop()->getHeader()->phis()) { + if (SE.isSCEVable(PN.getType()) && + (SE.getEffectiveSCEVType(PN.getType()) == SE.getEffectiveSCEVType(AR->getType())) && - SE.getSCEV(PN) == AR) + SE.getSCEV(&PN) == AR) return true; } return false; @@ -938,7 +948,7 @@ static bool isHighCostExpansion(const SCEV *S, return true; } -/// If any of the instructions is the specified set are trivially dead, delete +/// If any of the instructions in the specified set are trivially dead, delete /// them and see if this makes any of their operands subsequently dead. static bool DeleteTriviallyDeadInstructions(SmallVectorImpl<WeakTrackingVH> &DeadInsts) { @@ -971,7 +981,7 @@ class LSRUse; } // end anonymous namespace -/// \brief Check if the addressing mode defined by \p F is completely +/// Check if the addressing mode defined by \p F is completely /// folded in \p LU at isel time. /// This includes address-mode folding and special icmp tricks. 
/// This function returns true if \p LU can accommodate what \p F @@ -1041,12 +1051,14 @@ private: void RateRegister(const SCEV *Reg, SmallPtrSetImpl<const SCEV *> &Regs, const Loop *L, - ScalarEvolution &SE, DominatorTree &DT); + ScalarEvolution &SE, DominatorTree &DT, + const TargetTransformInfo &TTI); void RatePrimaryRegister(const SCEV *Reg, SmallPtrSetImpl<const SCEV *> &Regs, const Loop *L, ScalarEvolution &SE, DominatorTree &DT, - SmallPtrSetImpl<const SCEV *> *LoserRegs); + SmallPtrSetImpl<const SCEV *> *LoserRegs, + const TargetTransformInfo &TTI); }; /// An operand value in an instruction which is to be replaced with some @@ -1195,7 +1207,8 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, void Cost::RateRegister(const SCEV *Reg, SmallPtrSetImpl<const SCEV *> &Regs, const Loop *L, - ScalarEvolution &SE, DominatorTree &DT) { + ScalarEvolution &SE, DominatorTree &DT, + const TargetTransformInfo &TTI) { if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Reg)) { // If this is an addrec for another loop, it should be an invariant // with respect to L since L is the innermost loop (at least @@ -1216,13 +1229,28 @@ void Cost::RateRegister(const SCEV *Reg, ++C.NumRegs; return; } - C.AddRecCost += 1; /// TODO: This should be a function of the stride. + + unsigned LoopCost = 1; + if (TTI.shouldFavorPostInc()) { + const SCEV *LoopStep = AR->getStepRecurrence(SE); + if (isa<SCEVConstant>(LoopStep)) { + // Check if a post-indexed load/store can be used. + if (TTI.isIndexedLoadLegal(TTI.MIM_PostInc, AR->getType()) || + TTI.isIndexedStoreLegal(TTI.MIM_PostInc, AR->getType())) { + const SCEV *LoopStart = AR->getStart(); + if (!isa<SCEVConstant>(LoopStart) && + SE.isLoopInvariant(LoopStart, L)) + LoopCost = 0; + } + } + } + C.AddRecCost += LoopCost; // Add the step value register, if it needs one. // TODO: The non-affine case isn't precisely modeled here. if (!AR->isAffine() || !isa<SCEVConstant>(AR->getOperand(1))) { if (!Regs.count(AR->getOperand(1))) { - RateRegister(AR->getOperand(1), Regs, L, SE, DT); + RateRegister(AR->getOperand(1), Regs, L, SE, DT, TTI); if (isLoser()) return; } @@ -1250,13 +1278,14 @@ void Cost::RatePrimaryRegister(const SCEV *Reg, SmallPtrSetImpl<const SCEV *> &Regs, const Loop *L, ScalarEvolution &SE, DominatorTree &DT, - SmallPtrSetImpl<const SCEV *> *LoserRegs) { + SmallPtrSetImpl<const SCEV *> *LoserRegs, + const TargetTransformInfo &TTI) { if (LoserRegs && LoserRegs->count(Reg)) { Lose(); return; } if (Regs.insert(Reg).second) { - RateRegister(Reg, Regs, L, SE, DT); + RateRegister(Reg, Regs, L, SE, DT, TTI); if (LoserRegs && isLoser()) LoserRegs->insert(Reg); } @@ -1280,7 +1309,7 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, Lose(); return; } - RatePrimaryRegister(ScaledReg, Regs, L, SE, DT, LoserRegs); + RatePrimaryRegister(ScaledReg, Regs, L, SE, DT, LoserRegs, TTI); if (isLoser()) return; } @@ -1289,7 +1318,7 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, Lose(); return; } - RatePrimaryRegister(BaseReg, Regs, L, SE, DT, LoserRegs); + RatePrimaryRegister(BaseReg, Regs, L, SE, DT, LoserRegs, TTI); if (isLoser()) return; } @@ -1344,14 +1373,15 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, // If ICmpZero formula ends with not 0, it could not be replaced by // just add or sub. We'll need to compare final result of AddRec. - // That means we'll need an additional instruction. + // That means we'll need an additional instruction. 
But if the target can + // macro-fuse a compare with a branch, don't count this extra instruction. // For -10 + {0, +, 1}: // i = i + 1; // cmp i, 10 // // For {-10, +, 1}: // i = i + 1; - if (LU.Kind == LSRUse::ICmpZero && !F.hasZeroEnd()) + if (LU.Kind == LSRUse::ICmpZero && !F.hasZeroEnd() && !TTI.canMacroFuseCmp()) C.Insns++; // Each new AddRec adds 1 instruction to calculation. C.Insns += (C.AddRecCost - PrevAddRecCost); @@ -1457,7 +1487,7 @@ bool LSRUse::HasFormulaWithSameRegs(const Formula &F) const { SmallVector<const SCEV *, 4> Key = F.BaseRegs; if (F.ScaledReg) Key.push_back(F.ScaledReg); // Unstable sort by host order ok, because this is only used for uniquifying. - std::sort(Key.begin(), Key.end()); + llvm::sort(Key.begin(), Key.end()); return Uniquifier.count(Key); } @@ -1481,7 +1511,7 @@ bool LSRUse::InsertFormula(const Formula &F, const Loop &L) { SmallVector<const SCEV *, 4> Key = F.BaseRegs; if (F.ScaledReg) Key.push_back(F.ScaledReg); // Unstable sort by host order ok, because this is only used for uniquifying. - std::sort(Key.begin(), Key.end()); + llvm::sort(Key.begin(), Key.end()); if (!Uniquifier.insert(Key).second) return false; @@ -2385,24 +2415,27 @@ LSRInstance::OptimizeLoopTermCond() { C->getValue().isMinSignedValue()) goto decline_post_inc; // Check for possible scaled-address reuse. - MemAccessTy AccessTy = getAccessType(TTI, UI->getUser()); - int64_t Scale = C->getSExtValue(); - if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr, - /*BaseOffset=*/0, - /*HasBaseReg=*/false, Scale, - AccessTy.AddrSpace)) - goto decline_post_inc; - Scale = -Scale; - if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr, - /*BaseOffset=*/0, - /*HasBaseReg=*/false, Scale, - AccessTy.AddrSpace)) - goto decline_post_inc; + if (isAddressUse(TTI, UI->getUser(), UI->getOperandValToReplace())) { + MemAccessTy AccessTy = getAccessType( + TTI, UI->getUser(), UI->getOperandValToReplace()); + int64_t Scale = C->getSExtValue(); + if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr, + /*BaseOffset=*/0, + /*HasBaseReg=*/false, Scale, + AccessTy.AddrSpace)) + goto decline_post_inc; + Scale = -Scale; + if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr, + /*BaseOffset=*/0, + /*HasBaseReg=*/false, Scale, + AccessTy.AddrSpace)) + goto decline_post_inc; + } } } - DEBUG(dbgs() << " Change loop exiting icmp to use postinc iv: " - << *Cond << '\n'); + LLVM_DEBUG(dbgs() << " Change loop exiting icmp to use postinc iv: " + << *Cond << '\n'); // It's possible for the setcc instruction to be anywhere in the loop, and // possible for it to have multiple users. If it is not immediately before @@ -2643,7 +2676,7 @@ void LSRInstance::CollectInterestingTypesAndFactors() { if (Types.size() == 1) Types.clear(); - DEBUG(print_factors_and_types(dbgs())); + LLVM_DEBUG(print_factors_and_types(dbgs())); } /// Helper for CollectChains that finds an IV operand (computed by an AddRec in @@ -2667,7 +2700,7 @@ findIVOperand(User::op_iterator OI, User::op_iterator OE, return OI; } -/// IVChain logic must consistenctly peek base TruncInst operands, so wrap it in +/// IVChain logic must consistently peek base TruncInst operands, so wrap it in /// a convenient helper. 
static Value *getWideOperand(Value *Oper) { if (TruncInst *Trunc = dyn_cast<TruncInst>(Oper)) @@ -2774,10 +2807,9 @@ isProfitableChain(IVChain &Chain, SmallPtrSetImpl<Instruction*> &Users, return false; if (!Users.empty()) { - DEBUG(dbgs() << "Chain: " << *Chain.Incs[0].UserInst << " users:\n"; - for (Instruction *Inst : Users) { - dbgs() << " " << *Inst << "\n"; - }); + LLVM_DEBUG(dbgs() << "Chain: " << *Chain.Incs[0].UserInst << " users:\n"; + for (Instruction *Inst + : Users) { dbgs() << " " << *Inst << "\n"; }); return false; } assert(!Chain.Incs.empty() && "empty IV chains are not allowed"); @@ -2830,8 +2862,8 @@ isProfitableChain(IVChain &Chain, SmallPtrSetImpl<Instruction*> &Users, // the stride. cost -= NumReusedIncrements; - DEBUG(dbgs() << "Chain: " << *Chain.Incs[0].UserInst << " Cost: " << cost - << "\n"); + LLVM_DEBUG(dbgs() << "Chain: " << *Chain.Incs[0].UserInst << " Cost: " << cost + << "\n"); return cost < 0; } @@ -2884,7 +2916,7 @@ void LSRInstance::ChainInstruction(Instruction *UserInst, Instruction *IVOper, if (isa<PHINode>(UserInst)) return; if (NChains >= MaxChains && !StressIVChain) { - DEBUG(dbgs() << "IV Chain Limit\n"); + LLVM_DEBUG(dbgs() << "IV Chain Limit\n"); return; } LastIncExpr = OperExpr; @@ -2897,11 +2929,11 @@ void LSRInstance::ChainInstruction(Instruction *UserInst, Instruction *IVOper, IVChainVec.push_back(IVChain(IVInc(UserInst, IVOper, LastIncExpr), OperExprBase)); ChainUsersVec.resize(NChains); - DEBUG(dbgs() << "IV Chain#" << ChainIdx << " Head: (" << *UserInst - << ") IV=" << *LastIncExpr << "\n"); + LLVM_DEBUG(dbgs() << "IV Chain#" << ChainIdx << " Head: (" << *UserInst + << ") IV=" << *LastIncExpr << "\n"); } else { - DEBUG(dbgs() << "IV Chain#" << ChainIdx << " Inc: (" << *UserInst - << ") IV+" << *LastIncExpr << "\n"); + LLVM_DEBUG(dbgs() << "IV Chain#" << ChainIdx << " Inc: (" << *UserInst + << ") IV+" << *LastIncExpr << "\n"); // Add this IV user to the end of the chain. IVChainVec[ChainIdx].add(IVInc(UserInst, IVOper, LastIncExpr)); } @@ -2971,7 +3003,7 @@ void LSRInstance::ChainInstruction(Instruction *UserInst, Instruction *IVOper, /// loop latch. This will discover chains on side paths, but requires /// maintaining multiple copies of the Chains state. void LSRInstance::CollectChains() { - DEBUG(dbgs() << "Collecting IV Chains.\n"); + LLVM_DEBUG(dbgs() << "Collecting IV Chains.\n"); SmallVector<ChainUsers, 8> ChainUsersVec; SmallVector<BasicBlock *,8> LatchPath; @@ -3013,15 +3045,14 @@ void LSRInstance::CollectChains() { } // Continue walking down the instructions. } // Continue walking down the domtree. // Visit phi backedges to determine if the chain can generate the IV postinc. - for (BasicBlock::iterator I = L->getHeader()->begin(); - PHINode *PN = dyn_cast<PHINode>(I); ++I) { - if (!SE.isSCEVable(PN->getType())) + for (PHINode &PN : L->getHeader()->phis()) { + if (!SE.isSCEVable(PN.getType())) continue; Instruction *IncV = - dyn_cast<Instruction>(PN->getIncomingValueForBlock(L->getLoopLatch())); + dyn_cast<Instruction>(PN.getIncomingValueForBlock(L->getLoopLatch())); if (IncV) - ChainInstruction(PN, IncV, ChainUsersVec); + ChainInstruction(&PN, IncV, ChainUsersVec); } // Remove any unprofitable chains. 
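The profitability test shown above ends with return cost < 0: a chain is kept only when the increments it lets LSR reuse outweigh what its users still cost. Most of the accumulation is elided from this hunk, so the following is only a hedged sketch of the shape of the computation, with invented weights; only the "cost -= NumReusedIncrements" credit and the final comparison are taken from the code above.

#include <cstdio>

bool isProfitableChain(int NumIncrements, int NumReusedIncrements) {
  int cost = NumIncrements;    // invented charge: one unit per increment
  cost -= NumReusedIncrements; // reuse across the chain pays it back
  return cost < 0;
}

int main() {
  std::printf("%d\n", isProfitableChain(2, 3)); // 1: reuse wins, keep chain
  std::printf("%d\n", isProfitableChain(3, 1)); // 0: chain not kept
  return 0;
}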
unsigned ChainIdx = 0; @@ -3041,10 +3072,10 @@ void LSRInstance::CollectChains() { void LSRInstance::FinalizeChain(IVChain &Chain) { assert(!Chain.Incs.empty() && "empty IV chains are not allowed"); - DEBUG(dbgs() << "Final Chain: " << *Chain.Incs[0].UserInst << "\n"); + LLVM_DEBUG(dbgs() << "Final Chain: " << *Chain.Incs[0].UserInst << "\n"); for (const IVInc &Inc : Chain) { - DEBUG(dbgs() << " Inc: " << *Inc.UserInst << "\n"); + LLVM_DEBUG(dbgs() << " Inc: " << *Inc.UserInst << "\n"); auto UseI = find(Inc.UserInst->operands(), Inc.IVOperand); assert(UseI != Inc.UserInst->op_end() && "cannot find IV operand"); IVIncSet.insert(UseI); @@ -3061,7 +3092,7 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst, if (IncConst->getAPInt().getMinSignedBits() > 64) return false; - MemAccessTy AccessTy = getAccessType(TTI, UserInst); + MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand); int64_t IncOffset = IncConst->getValue()->getSExtValue(); if (!isAlwaysFoldable(TTI, LSRUse::Address, AccessTy, /*BaseGV=*/nullptr, IncOffset, /*HaseBaseReg=*/false)) @@ -3101,11 +3132,11 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter, } if (IVOpIter == IVOpEnd) { // Gracefully give up on this chain. - DEBUG(dbgs() << "Concealed chain head: " << *Head.UserInst << "\n"); + LLVM_DEBUG(dbgs() << "Concealed chain head: " << *Head.UserInst << "\n"); return; } - DEBUG(dbgs() << "Generate chain at: " << *IVSrc << "\n"); + LLVM_DEBUG(dbgs() << "Generate chain at: " << *IVSrc << "\n"); Type *IVTy = IVSrc->getType(); Type *IntTy = SE.getEffectiveSCEVType(IVTy); const SCEV *LeftOverExpr = nullptr; @@ -3152,12 +3183,11 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter, // If LSR created a new, wider phi, we may also replace its postinc. We only // do this if we also found a wide value for the head of the chain. 
if (isa<PHINode>(Chain.tailUserInst())) { - for (BasicBlock::iterator I = L->getHeader()->begin(); - PHINode *Phi = dyn_cast<PHINode>(I); ++I) { - if (!isCompatibleIVType(Phi, IVSrc)) + for (PHINode &Phi : L->getHeader()->phis()) { + if (!isCompatibleIVType(&Phi, IVSrc)) continue; Instruction *PostIncV = dyn_cast<Instruction>( - Phi->getIncomingValueForBlock(L->getLoopLatch())); + Phi.getIncomingValueForBlock(L->getLoopLatch())); if (!PostIncV || (SE.getSCEV(PostIncV) != SE.getSCEV(IVSrc))) continue; Value *IVOper = IVSrc; @@ -3168,7 +3198,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter, Builder.SetCurrentDebugLocation(PostIncV->getDebugLoc()); IVOper = Builder.CreatePointerCast(IVSrc, PostIncTy, "lsr.chain"); } - Phi->replaceUsesOfWith(PostIncV, IVOper); + Phi.replaceUsesOfWith(PostIncV, IVOper); DeadInsts.emplace_back(PostIncV); } } @@ -3182,7 +3212,7 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { find(UserInst->operands(), U.getOperandValToReplace()); assert(UseI != UserInst->op_end() && "cannot find IV operand"); if (IVIncSet.count(UseI)) { - DEBUG(dbgs() << "Use is in profitable chain: " << **UseI << '\n'); + LLVM_DEBUG(dbgs() << "Use is in profitable chain: " << **UseI << '\n'); continue; } @@ -3190,7 +3220,7 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { MemAccessTy AccessTy; if (isAddressUse(TTI, UserInst, U.getOperandValToReplace())) { Kind = LSRUse::Address; - AccessTy = getAccessType(TTI, UserInst); + AccessTy = getAccessType(TTI, UserInst, U.getOperandValToReplace()); } const SCEV *S = IU.getExpr(U); @@ -3258,7 +3288,7 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { } } - DEBUG(print_fixups(dbgs())); + LLVM_DEBUG(print_fixups(dbgs())); } /// Insert a formula for the given expression into the given use, separating out @@ -3467,12 +3497,45 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C, return S; } -/// \brief Helper function for LSRInstance::GenerateReassociations. +/// Return true if the SCEV represents a value that may end up as a +/// post-increment operation. +static bool mayUsePostIncMode(const TargetTransformInfo &TTI, + LSRUse &LU, const SCEV *S, const Loop *L, + ScalarEvolution &SE) { + if (LU.Kind != LSRUse::Address || + !LU.AccessTy.getType()->isIntOrIntVectorTy()) + return false; + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S); + if (!AR) + return false; + const SCEV *LoopStep = AR->getStepRecurrence(SE); + if (!isa<SCEVConstant>(LoopStep)) + return false; + if (LU.AccessTy.getType()->getScalarSizeInBits() != + LoopStep->getType()->getScalarSizeInBits()) + return false; + // Check if a post-indexed load/store can be used. + if (TTI.isIndexedLoadLegal(TTI.MIM_PostInc, AR->getType()) || + TTI.isIndexedStoreLegal(TTI.MIM_PostInc, AR->getType())) { + const SCEV *LoopStart = AR->getStart(); + if (!isa<SCEVConstant>(LoopStart) && SE.isLoopInvariant(LoopStart, L)) + return true; + } + return false; +} + +/// Helper function for LSRInstance::GenerateReassociations. void LSRInstance::GenerateReassociationsImpl(LSRUse &LU, unsigned LUIdx, const Formula &Base, unsigned Depth, size_t Idx, bool IsScaledReg) { const SCEV *BaseReg = IsScaledReg ? Base.ScaledReg : Base.BaseRegs[Idx]; + // Don't generate reassociations for the base register of a value that + // may generate a post-increment operator. The reason is that the + // reassociations cause extra base+register formula to be created, + // and possibly chosen, but the post-increment is more efficient. 
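The guard this new comment describes follows immediately below. Informally, post-indexed addressing folds the pointer bump into the memory access itself, so keeping the plain {start,+,step} recurrence beats any reassociated base+offset formula. At the source level the pattern being favored looks like this (on a target with post-increment modes, the loop body compiles to one load-and-advance step):

#include <cstdio>

int sum(const int *P, int N) {
  int S = 0;
  while (N--)
    S += *P++; // load through P, then bump P: one post-indexed access
  return S;
}

int main() {
  int A[] = {1, 2, 3};
  std::printf("%d\n", sum(A, 3)); // 6
  return 0;
}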
+ if (TTI.shouldFavorPostInc() && mayUsePostIncMode(TTI, LU, BaseReg, L, SE)) + return; SmallVector<const SCEV *, 8> AddOps; const SCEV *Remainder = CollectSubexprs(BaseReg, nullptr, AddOps, L, SE); if (Remainder) @@ -3545,7 +3608,12 @@ void LSRInstance::GenerateReassociationsImpl(LSRUse &LU, unsigned LUIdx, if (InsertFormula(LU, LUIdx, F)) // If that formula hadn't been seen before, recurse to find more like // it. - GenerateReassociations(LU, LUIdx, LU.Formulae.back(), Depth + 1); + // Add check on Log16(AddOps.size()) - same as Log2_32(AddOps.size()) >> 2) + // Because just Depth is not enough to bound compile time. + // This means that every time AddOps.size() is greater 16^x we will add + // x to Depth. + GenerateReassociations(LU, LUIdx, LU.Formulae.back(), + Depth + 1 + (Log2_32(AddOps.size()) >> 2)); } } @@ -3599,7 +3667,7 @@ void LSRInstance::GenerateCombinations(LSRUse &LU, unsigned LUIdx, } } -/// \brief Helper function for LSRInstance::GenerateSymbolicOffsets. +/// Helper function for LSRInstance::GenerateSymbolicOffsets. void LSRInstance::GenerateSymbolicOffsetsImpl(LSRUse &LU, unsigned LUIdx, const Formula &Base, size_t Idx, bool IsScaledReg) { @@ -3631,7 +3699,7 @@ void LSRInstance::GenerateSymbolicOffsets(LSRUse &LU, unsigned LUIdx, /* IsScaledReg */ true); } -/// \brief Helper function for LSRInstance::GenerateConstantOffsets. +/// Helper function for LSRInstance::GenerateConstantOffsets. void LSRInstance::GenerateConstantOffsetsImpl( LSRUse &LU, unsigned LUIdx, const Formula &Base, const SmallVectorImpl<int64_t> &Worklist, size_t Idx, bool IsScaledReg) { @@ -3941,10 +4009,11 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { if (Imms.size() == 1) continue; - DEBUG(dbgs() << "Generating cross-use offsets for " << *Reg << ':'; - for (const auto &Entry : Imms) - dbgs() << ' ' << Entry.first; - dbgs() << '\n'); + LLVM_DEBUG(dbgs() << "Generating cross-use offsets for " << *Reg << ':'; + for (const auto &Entry + : Imms) dbgs() + << ' ' << Entry.first; + dbgs() << '\n'); // Examine each offset. 
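Returning to the depth bound added in GenerateReassociationsImpl above: the "Log16" comment holds because Log2_32(n) >> 2 equals floor(log16(n)), so one extra level of depth is charged each time AddOps.size() crosses a power of 16. A quick standalone check of that identity (the cross-use offset scan resumes below):

#include <cstdio>

unsigned log2u32(unsigned N) { // floor(log2), as LLVM's Log2_32 computes
  unsigned L = 0;
  while (N >>= 1)
    ++L;
  return L;
}

int main() {
  for (unsigned N : {1u, 15u, 16u, 255u, 256u, 4095u, 4096u})
    std::printf("n=%u  extra depth=%u\n", N, log2u32(N) >> 2);
  // n=15 -> 0, n=16 -> 1, n=256 -> 2, n=4096 -> 3
  return 0;
}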
for (ImmMapTy::const_iterator J = Imms.begin(), JE = Imms.end(); @@ -3956,7 +4025,8 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { if (!isa<SCEVConstant>(OrigReg) && UsedByIndicesMap[Reg].count() == 1) { - DEBUG(dbgs() << "Skipping cross-use reuse for " << *OrigReg << '\n'); + LLVM_DEBUG(dbgs() << "Skipping cross-use reuse for " << *OrigReg + << '\n'); continue; } @@ -4041,6 +4111,9 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { NewF.BaseOffset = (uint64_t)NewF.BaseOffset + Imm; if (!isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, NewF)) { + if (TTI.shouldFavorPostInc() && + mayUsePostIncMode(TTI, LU, OrigReg, this->L, SE)) + continue; if (!TTI.isLegalAddImmediate((uint64_t)NewF.UnfoldedOffset + Imm)) continue; NewF = F; @@ -4102,9 +4175,9 @@ LSRInstance::GenerateAllReuseFormulae() { GenerateCrossUseConstantOffsets(); - DEBUG(dbgs() << "\n" - "After generating reuse formulae:\n"; - print_uses(dbgs())); + LLVM_DEBUG(dbgs() << "\n" + "After generating reuse formulae:\n"; + print_uses(dbgs())); } /// If there are multiple formulae with the same set of registers used @@ -4126,7 +4199,8 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() { for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { LSRUse &LU = Uses[LUIdx]; - DEBUG(dbgs() << "Filtering for use "; LU.print(dbgs()); dbgs() << '\n'); + LLVM_DEBUG(dbgs() << "Filtering for use "; LU.print(dbgs()); + dbgs() << '\n'); bool Any = false; for (size_t FIdx = 0, NumForms = LU.Formulae.size(); @@ -4150,8 +4224,8 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() { // as the basis of rediscovering the desired formula that uses an AddRec // corresponding to the existing phi. Once all formulae have been // generated, these initial losers may be pruned. - DEBUG(dbgs() << " Filtering loser "; F.print(dbgs()); - dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " Filtering loser "; F.print(dbgs()); + dbgs() << "\n"); } else { SmallVector<const SCEV *, 4> Key; @@ -4164,7 +4238,7 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() { Key.push_back(F.ScaledReg); // Unstable sort by host order ok, because this is only used for // uniquifying. - std::sort(Key.begin(), Key.end()); + llvm::sort(Key.begin(), Key.end()); std::pair<BestFormulaeTy::const_iterator, bool> P = BestFormulae.insert(std::make_pair(Key, FIdx)); @@ -4178,10 +4252,10 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() { CostBest.RateFormula(TTI, Best, Regs, VisitedRegs, L, SE, DT, LU); if (CostF.isLess(CostBest, TTI)) std::swap(F, Best); - DEBUG(dbgs() << " Filtering out formula "; F.print(dbgs()); - dbgs() << "\n" - " in favor of formula "; Best.print(dbgs()); - dbgs() << '\n'); + LLVM_DEBUG(dbgs() << " Filtering out formula "; F.print(dbgs()); + dbgs() << "\n" + " in favor of formula "; + Best.print(dbgs()); dbgs() << '\n'); } #ifndef NDEBUG ChangedFormulae = true; @@ -4200,11 +4274,11 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() { BestFormulae.clear(); } - DEBUG(if (ChangedFormulae) { - dbgs() << "\n" - "After filtering out undesirable candidates:\n"; - print_uses(dbgs()); - }); + LLVM_DEBUG(if (ChangedFormulae) { + dbgs() << "\n" + "After filtering out undesirable candidates:\n"; + print_uses(dbgs()); + }); } // This is a rough guess that seems to work fairly well. @@ -4233,11 +4307,11 @@ size_t LSRInstance::EstimateSearchSpaceComplexity() const { /// register pressure); remove it to simplify the system. 
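All of the narrowing routines below fire off the same EstimateSearchSpaceComplexity() >= ComplexityLimit test. The estimate's body is not part of this hunk, but its shape is a saturating product of per-use formula counts, which is why deleting formulae from any single use shrinks the space multiplicatively. A hedged standalone sketch:

#include <cstdio>
#include <vector>

// Saturating product of per-use formula counts; once it crosses Limit the
// exact value no longer matters, only "too complex".
size_t estimateComplexity(const std::vector<size_t> &FormulaeCounts,
                          size_t Limit) {
  size_t Power = 1;
  for (size_t N : FormulaeCounts) {
    Power *= N;
    if (Power >= Limit)
      return Limit;
  }
  return Power;
}

int main() {
  const size_t Limit = 65535;
  std::printf("%zu\n", estimateComplexity({8, 8, 8}, Limit));    // 512
  std::printf("%zu\n", estimateComplexity({64, 64, 64}, Limit)); // 65535
  return 0;
}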
void LSRInstance::NarrowSearchSpaceByDetectingSupersets() { if (EstimateSearchSpaceComplexity() >= ComplexityLimit) { - DEBUG(dbgs() << "The search space is too complex.\n"); + LLVM_DEBUG(dbgs() << "The search space is too complex.\n"); - DEBUG(dbgs() << "Narrowing the search space by eliminating formulae " - "which use a superset of registers used by other " - "formulae.\n"); + LLVM_DEBUG(dbgs() << "Narrowing the search space by eliminating formulae " + "which use a superset of registers used by other " + "formulae.\n"); for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { LSRUse &LU = Uses[LUIdx]; @@ -4255,7 +4329,8 @@ void LSRInstance::NarrowSearchSpaceByDetectingSupersets() { NewF.BaseRegs.erase(NewF.BaseRegs.begin() + (I - F.BaseRegs.begin())); if (LU.HasFormulaWithSameRegs(NewF)) { - DEBUG(dbgs() << " Deleting "; F.print(dbgs()); dbgs() << '\n'); + LLVM_DEBUG(dbgs() << " Deleting "; F.print(dbgs()); + dbgs() << '\n'); LU.DeleteFormula(F); --i; --e; @@ -4270,8 +4345,8 @@ void LSRInstance::NarrowSearchSpaceByDetectingSupersets() { NewF.BaseRegs.erase(NewF.BaseRegs.begin() + (I - F.BaseRegs.begin())); if (LU.HasFormulaWithSameRegs(NewF)) { - DEBUG(dbgs() << " Deleting "; F.print(dbgs()); - dbgs() << '\n'); + LLVM_DEBUG(dbgs() << " Deleting "; F.print(dbgs()); + dbgs() << '\n'); LU.DeleteFormula(F); --i; --e; @@ -4286,8 +4361,7 @@ void LSRInstance::NarrowSearchSpaceByDetectingSupersets() { LU.RecomputeRegs(LUIdx, RegUses); } - DEBUG(dbgs() << "After pre-selection:\n"; - print_uses(dbgs())); + LLVM_DEBUG(dbgs() << "After pre-selection:\n"; print_uses(dbgs())); } } @@ -4297,9 +4371,10 @@ void LSRInstance::NarrowSearchSpaceByCollapsingUnrolledCode() { if (EstimateSearchSpaceComplexity() < ComplexityLimit) return; - DEBUG(dbgs() << "The search space is too complex.\n" - "Narrowing the search space by assuming that uses separated " - "by a constant offset will use the same registers.\n"); + LLVM_DEBUG( + dbgs() << "The search space is too complex.\n" + "Narrowing the search space by assuming that uses separated " + "by a constant offset will use the same registers.\n"); // This is especially useful for unrolled loops. @@ -4317,7 +4392,7 @@ void LSRInstance::NarrowSearchSpaceByCollapsingUnrolledCode() { LU.Kind, LU.AccessTy)) continue; - DEBUG(dbgs() << " Deleting use "; LU.print(dbgs()); dbgs() << '\n'); + LLVM_DEBUG(dbgs() << " Deleting use "; LU.print(dbgs()); dbgs() << '\n'); LUThatHas->AllFixupsOutsideLoop &= LU.AllFixupsOutsideLoop; @@ -4325,7 +4400,7 @@ void LSRInstance::NarrowSearchSpaceByCollapsingUnrolledCode() { for (LSRFixup &Fixup : LU.Fixups) { Fixup.Offset += F.BaseOffset; LUThatHas->pushFixup(Fixup); - DEBUG(dbgs() << "New fixup has offset " << Fixup.Offset << '\n'); + LLVM_DEBUG(dbgs() << "New fixup has offset " << Fixup.Offset << '\n'); } // Delete formulae from the new use which are no longer legal. 
@@ -4334,8 +4409,7 @@ void LSRInstance::NarrowSearchSpaceByCollapsingUnrolledCode() { Formula &F = LUThatHas->Formulae[i]; if (!isLegalUse(TTI, LUThatHas->MinOffset, LUThatHas->MaxOffset, LUThatHas->Kind, LUThatHas->AccessTy, F)) { - DEBUG(dbgs() << " Deleting "; F.print(dbgs()); - dbgs() << '\n'); + LLVM_DEBUG(dbgs() << " Deleting "; F.print(dbgs()); dbgs() << '\n'); LUThatHas->DeleteFormula(F); --i; --e; @@ -4354,7 +4428,7 @@ void LSRInstance::NarrowSearchSpaceByCollapsingUnrolledCode() { } } - DEBUG(dbgs() << "After pre-selection:\n"; print_uses(dbgs())); + LLVM_DEBUG(dbgs() << "After pre-selection:\n"; print_uses(dbgs())); } /// Call FilterOutUndesirableDedicatedRegisters again, if necessary, now that @@ -4362,15 +4436,14 @@ void LSRInstance::NarrowSearchSpaceByCollapsingUnrolledCode() { /// eliminate. void LSRInstance::NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters(){ if (EstimateSearchSpaceComplexity() >= ComplexityLimit) { - DEBUG(dbgs() << "The search space is too complex.\n"); + LLVM_DEBUG(dbgs() << "The search space is too complex.\n"); - DEBUG(dbgs() << "Narrowing the search space by re-filtering out " - "undesirable dedicated registers.\n"); + LLVM_DEBUG(dbgs() << "Narrowing the search space by re-filtering out " + "undesirable dedicated registers.\n"); FilterOutUndesirableDedicatedRegisters(); - DEBUG(dbgs() << "After pre-selection:\n"; - print_uses(dbgs())); + LLVM_DEBUG(dbgs() << "After pre-selection:\n"; print_uses(dbgs())); } } @@ -4381,15 +4454,16 @@ void LSRInstance::NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters(){ /// The benefit is that it is more likely to find out a better solution /// from a formulae set with more Scale and ScaledReg variations than /// a formulae set with the same Scale and ScaledReg. The picking winner -/// reg heurstic will often keep the formulae with the same Scale and +/// reg heuristic will often keep the formulae with the same Scale and /// ScaledReg and filter others, and we want to avoid that if possible. void LSRInstance::NarrowSearchSpaceByFilterFormulaWithSameScaledReg() { if (EstimateSearchSpaceComplexity() < ComplexityLimit) return; - DEBUG(dbgs() << "The search space is too complex.\n" - "Narrowing the search space by choosing the best Formula " - "from the Formulae with the same Scale and ScaledReg.\n"); + LLVM_DEBUG( + dbgs() << "The search space is too complex.\n" + "Narrowing the search space by choosing the best Formula " + "from the Formulae with the same Scale and ScaledReg.\n"); // Map the "Scale * ScaledReg" pair to the best formula of current LSRUse. using BestFormulaeTy = DenseMap<std::pair<const SCEV *, int64_t>, size_t>; @@ -4403,7 +4477,8 @@ void LSRInstance::NarrowSearchSpaceByFilterFormulaWithSameScaledReg() { for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { LSRUse &LU = Uses[LUIdx]; - DEBUG(dbgs() << "Filtering for use "; LU.print(dbgs()); dbgs() << '\n'); + LLVM_DEBUG(dbgs() << "Filtering for use "; LU.print(dbgs()); + dbgs() << '\n'); // Return true if Formula FA is better than Formula FB. 
auto IsBetterThan = [&](Formula &FA, Formula &FB) { @@ -4447,10 +4522,10 @@ void LSRInstance::NarrowSearchSpaceByFilterFormulaWithSameScaledReg() { Formula &Best = LU.Formulae[P.first->second]; if (IsBetterThan(F, Best)) std::swap(F, Best); - DEBUG(dbgs() << " Filtering out formula "; F.print(dbgs()); - dbgs() << "\n" - " in favor of formula "; - Best.print(dbgs()); dbgs() << '\n'); + LLVM_DEBUG(dbgs() << " Filtering out formula "; F.print(dbgs()); + dbgs() << "\n" + " in favor of formula "; + Best.print(dbgs()); dbgs() << '\n'); #ifndef NDEBUG ChangedFormulae = true; #endif @@ -4466,7 +4541,7 @@ void LSRInstance::NarrowSearchSpaceByFilterFormulaWithSameScaledReg() { BestFormulae.clear(); } - DEBUG(if (ChangedFormulae) { + LLVM_DEBUG(if (ChangedFormulae) { dbgs() << "\n" "After filtering out undesirable candidates:\n"; print_uses(dbgs()); @@ -4525,7 +4600,7 @@ void LSRInstance::NarrowSearchSpaceByDeletingCostlyFormulas() { // Used in each formula of a solution (in example above this is reg(c)). // We can skip them in calculations. SmallPtrSet<const SCEV *, 4> UniqRegs; - DEBUG(dbgs() << "The search space is too complex.\n"); + LLVM_DEBUG(dbgs() << "The search space is too complex.\n"); // Map each register to probability of not selecting DenseMap <const SCEV *, float> RegNumMap; @@ -4545,7 +4620,8 @@ void LSRInstance::NarrowSearchSpaceByDeletingCostlyFormulas() { RegNumMap.insert(std::make_pair(Reg, PNotSel)); } - DEBUG(dbgs() << "Narrowing the search space by deleting costly formulas\n"); + LLVM_DEBUG( + dbgs() << "Narrowing the search space by deleting costly formulas\n"); // Delete formulas where registers number expectation is high. for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { @@ -4587,26 +4663,25 @@ void LSRInstance::NarrowSearchSpaceByDeletingCostlyFormulas() { MinIdx = i; } } - DEBUG(dbgs() << " The formula "; LU.Formulae[MinIdx].print(dbgs()); - dbgs() << " with min reg num " << FMinRegNum << '\n'); + LLVM_DEBUG(dbgs() << " The formula "; LU.Formulae[MinIdx].print(dbgs()); + dbgs() << " with min reg num " << FMinRegNum << '\n'); if (MinIdx != 0) std::swap(LU.Formulae[MinIdx], LU.Formulae[0]); while (LU.Formulae.size() != 1) { - DEBUG(dbgs() << " Deleting "; LU.Formulae.back().print(dbgs()); - dbgs() << '\n'); + LLVM_DEBUG(dbgs() << " Deleting "; LU.Formulae.back().print(dbgs()); + dbgs() << '\n'); LU.Formulae.pop_back(); } LU.RecomputeRegs(LUIdx, RegUses); assert(LU.Formulae.size() == 1 && "Should be exactly 1 min regs formula"); Formula &F = LU.Formulae[0]; - DEBUG(dbgs() << " Leaving only "; F.print(dbgs()); dbgs() << '\n'); + LLVM_DEBUG(dbgs() << " Leaving only "; F.print(dbgs()); dbgs() << '\n'); // When we choose the formula, the regs become unique. UniqRegs.insert(F.BaseRegs.begin(), F.BaseRegs.end()); if (F.ScaledReg) UniqRegs.insert(F.ScaledReg); } - DEBUG(dbgs() << "After pre-selection:\n"; - print_uses(dbgs())); + LLVM_DEBUG(dbgs() << "After pre-selection:\n"; print_uses(dbgs())); } /// Pick a register which seems likely to be profitable, and then in any use @@ -4619,7 +4694,7 @@ void LSRInstance::NarrowSearchSpaceByPickingWinnerRegs() { while (EstimateSearchSpaceComplexity() >= ComplexityLimit) { // Ok, we have too many of formulae on our hands to conveniently handle. // Use a rough heuristic to thin out the list. 
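A standalone sketch of the winner-register narrowing that this last routine performs (toy containers stand in for LSRUse and the register sets, and the real pass repeats the whole round, with a Taken set, until the estimate drops under the limit):

#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <vector>

int main() {
  // Each use owns candidate "formulae"; a formula here is just the set of
  // register names it references.
  std::vector<std::vector<std::set<std::string>>> Uses = {
      {{"a", "b"}, {"c"}}, {{"a"}, {"b", "c"}}, {{"a"}}};

  // Count, per register, how many uses reference it at all.
  std::map<std::string, int> NumUses;
  for (auto &LU : Uses) {
    std::set<std::string> Regs;
    for (auto &F : LU)
      Regs.insert(F.begin(), F.end());
    for (auto &R : Regs)
      ++NumUses[R];
  }

  // The register used by the most uses is the likeliest reuse candidate.
  std::string Best;
  for (auto &P : NumUses)
    if (Best.empty() || P.second > NumUses[Best])
      Best = P.first;
  std::printf("winner: %s\n", Best.c_str()); // "a": referenced by all 3

  // In each use, keep only formulae that reference the winner (never
  // emptying a use outright).
  for (auto &LU : Uses) {
    std::vector<std::set<std::string>> Kept;
    for (auto &F : LU)
      if (F.count(Best))
        Kept.push_back(F);
    if (!Kept.empty())
      LU = Kept;
  }
  std::printf("use0 formulae left: %zu\n", Uses[0].size()); // 1
  return 0;
}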
- DEBUG(dbgs() << "The search space is too complex.\n"); + LLVM_DEBUG(dbgs() << "The search space is too complex.\n"); // Pick the register which is used by the most LSRUses, which is likely // to be a good reuse register candidate. @@ -4640,8 +4715,8 @@ void LSRInstance::NarrowSearchSpaceByPickingWinnerRegs() { } } - DEBUG(dbgs() << "Narrowing the search space by assuming " << *Best - << " will yield profitable reuse.\n"); + LLVM_DEBUG(dbgs() << "Narrowing the search space by assuming " << *Best + << " will yield profitable reuse.\n"); Taken.insert(Best); // In any use with formulae which references this register, delete formulae @@ -4654,7 +4729,7 @@ void LSRInstance::NarrowSearchSpaceByPickingWinnerRegs() { for (size_t i = 0, e = LU.Formulae.size(); i != e; ++i) { Formula &F = LU.Formulae[i]; if (!F.referencesReg(Best)) { - DEBUG(dbgs() << " Deleting "; F.print(dbgs()); dbgs() << '\n'); + LLVM_DEBUG(dbgs() << " Deleting "; F.print(dbgs()); dbgs() << '\n'); LU.DeleteFormula(F); --e; --i; @@ -4668,8 +4743,7 @@ void LSRInstance::NarrowSearchSpaceByPickingWinnerRegs() { LU.RecomputeRegs(LUIdx, RegUses); } - DEBUG(dbgs() << "After pre-selection:\n"; - print_uses(dbgs())); + LLVM_DEBUG(dbgs() << "After pre-selection:\n"; print_uses(dbgs())); } } @@ -4751,11 +4825,11 @@ void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution, if (F.getNumRegs() == 1 && Workspace.size() == 1) VisitedRegs.insert(F.ScaledReg ? F.ScaledReg : F.BaseRegs[0]); } else { - DEBUG(dbgs() << "New best at "; NewCost.print(dbgs()); - dbgs() << ".\n Regs:"; - for (const SCEV *S : NewRegs) - dbgs() << ' ' << *S; - dbgs() << '\n'); + LLVM_DEBUG(dbgs() << "New best at "; NewCost.print(dbgs()); + dbgs() << ".\n Regs:"; for (const SCEV *S + : NewRegs) dbgs() + << ' ' << *S; + dbgs() << '\n'); SolutionCost = NewCost; Solution = Workspace; @@ -4780,22 +4854,22 @@ void LSRInstance::Solve(SmallVectorImpl<const Formula *> &Solution) const { SolveRecurse(Solution, SolutionCost, Workspace, CurCost, CurRegs, VisitedRegs); if (Solution.empty()) { - DEBUG(dbgs() << "\nNo Satisfactory Solution\n"); + LLVM_DEBUG(dbgs() << "\nNo Satisfactory Solution\n"); return; } // Ok, we've now made all our decisions. - DEBUG(dbgs() << "\n" - "The chosen solution requires "; SolutionCost.print(dbgs()); - dbgs() << ":\n"; - for (size_t i = 0, e = Uses.size(); i != e; ++i) { - dbgs() << " "; - Uses[i].print(dbgs()); - dbgs() << "\n" - " "; - Solution[i]->print(dbgs()); - dbgs() << '\n'; - }); + LLVM_DEBUG(dbgs() << "\n" + "The chosen solution requires "; + SolutionCost.print(dbgs()); dbgs() << ":\n"; + for (size_t i = 0, e = Uses.size(); i != e; ++i) { + dbgs() << " "; + Uses[i].print(dbgs()); + dbgs() << "\n" + " "; + Solution[i]->print(dbgs()); + dbgs() << '\n'; + }); assert(Solution.size() == Uses.size() && "Malformed solution!"); } @@ -4996,7 +5070,7 @@ Value *LSRInstance::Expand(const LSRUse &LU, const LSRFixup &LF, // Unless the addressing mode will not be folded. 
if (!Ops.empty() && LU.Kind == LSRUse::Address && isAMCompletelyFolded(TTI, LU, F)) { - Value *FullV = Rewriter.expandCodeFor(SE.getAddExpr(Ops), Ty); + Value *FullV = Rewriter.expandCodeFor(SE.getAddExpr(Ops), nullptr); Ops.clear(); Ops.push_back(SE.getUnknown(FullV)); } @@ -5269,7 +5343,8 @@ LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, for (const IVStrideUse &U : IU) { if (++NumUsers > MaxIVUsers) { (void)U; - DEBUG(dbgs() << "LSR skipping loop, too many IV Users in " << U << "\n"); + LLVM_DEBUG(dbgs() << "LSR skipping loop, too many IV Users in " << U + << "\n"); return; } // Bail out if we have a PHI on an EHPad that gets a value from a @@ -5302,9 +5377,9 @@ LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, } #endif // DEBUG - DEBUG(dbgs() << "\nLSR on loop "; - L->getHeader()->printAsOperand(dbgs(), /*PrintType=*/false); - dbgs() << ":\n"); + LLVM_DEBUG(dbgs() << "\nLSR on loop "; + L->getHeader()->printAsOperand(dbgs(), /*PrintType=*/false); + dbgs() << ":\n"); // First, perform some low-level loop optimizations. OptimizeShadowIV(); @@ -5315,7 +5390,7 @@ LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, // Skip nested loops until we can model them better with formulae. if (!L->empty()) { - DEBUG(dbgs() << "LSR skipping outer loop " << *L << "\n"); + LLVM_DEBUG(dbgs() << "LSR skipping outer loop " << *L << "\n"); return; } @@ -5325,9 +5400,11 @@ LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, CollectFixupsAndInitialFormulae(); CollectLoopInvariantFixupsAndFormulae(); - assert(!Uses.empty() && "IVUsers reported at least one use"); - DEBUG(dbgs() << "LSR found " << Uses.size() << " uses:\n"; - print_uses(dbgs())); + if (Uses.empty()) + return; + + LLVM_DEBUG(dbgs() << "LSR found " << Uses.size() << " uses:\n"; + print_uses(dbgs())); // Now use the reuse data to generate a bunch of interesting ways // to formulate the values needed for the uses. diff --git a/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp new file mode 100644 index 000000000000..86c99aed4417 --- /dev/null +++ b/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp @@ -0,0 +1,447 @@ +//===- LoopUnrollAndJam.cpp - Loop unroll and jam pass --------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass implements an unroll and jam pass. Most of the work is done by +// Utils/UnrollLoopAndJam.cpp. 
+//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LoopUnrollAndJamPass.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/DependenceAnalysis.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/UnrollLoop.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <string> + +using namespace llvm; + +#define DEBUG_TYPE "loop-unroll-and-jam" + +static cl::opt<bool> + AllowUnrollAndJam("allow-unroll-and-jam", cl::Hidden, + cl::desc("Allows loops to be unroll-and-jammed.")); + +static cl::opt<unsigned> UnrollAndJamCount( + "unroll-and-jam-count", cl::Hidden, + cl::desc("Use this unroll count for all loops including those with " + "unroll_and_jam_count pragma values, for testing purposes")); + +static cl::opt<unsigned> UnrollAndJamThreshold( + "unroll-and-jam-threshold", cl::init(60), cl::Hidden, + cl::desc("Threshold to use for inner loop when doing unroll and jam.")); + +static cl::opt<unsigned> PragmaUnrollAndJamThreshold( + "pragma-unroll-and-jam-threshold", cl::init(1024), cl::Hidden, + cl::desc("Unrolled size limit for loops with an unroll_and_jam(full) or " + "unroll_count pragma.")); + +// Returns the loop hint metadata node with the given name (for example, +// "llvm.loop.unroll.count"). If no such metadata node exists, then nullptr is +// returned. +static MDNode *GetUnrollMetadataForLoop(const Loop *L, StringRef Name) { + if (MDNode *LoopID = L->getLoopID()) + return GetUnrollMetadata(LoopID, Name); + return nullptr; +} + +// Returns true if the loop has any metadata starting with Prefix. For example a +// Prefix of "llvm.loop.unroll." returns true if we have any unroll metadata. +static bool HasAnyUnrollPragma(const Loop *L, StringRef Prefix) { + if (MDNode *LoopID = L->getLoopID()) { + // First operand should refer to the loop id itself. 
+ assert(LoopID->getNumOperands() > 0 && "requires at least one operand"); + assert(LoopID->getOperand(0) == LoopID && "invalid loop id"); + + for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) { + MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); + if (!MD) + continue; + + MDString *S = dyn_cast<MDString>(MD->getOperand(0)); + if (!S) + continue; + + if (S->getString().startswith(Prefix)) + return true; + } + } + return false; +} + +// Returns true if the loop has an unroll_and_jam(enable) pragma. +static bool HasUnrollAndJamEnablePragma(const Loop *L) { + return GetUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.enable"); +} + +// Returns true if the loop has an unroll_and_jam(disable) pragma. +static bool HasUnrollAndJamDisablePragma(const Loop *L) { + return GetUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.disable"); +} + +// If loop has an unroll_and_jam_count pragma return the (necessarily +// positive) value from the pragma. Otherwise return 0. +static unsigned UnrollAndJamCountPragmaValue(const Loop *L) { + MDNode *MD = GetUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.count"); + if (MD) { + assert(MD->getNumOperands() == 2 && + "Unroll count hint metadata should have two operands."); + unsigned Count = + mdconst::extract<ConstantInt>(MD->getOperand(1))->getZExtValue(); + assert(Count >= 1 && "Unroll count must be positive."); + return Count; + } + return 0; +} + +// Returns loop size estimation for unrolled loop. +static uint64_t +getUnrollAndJammedLoopSize(unsigned LoopSize, + TargetTransformInfo::UnrollingPreferences &UP) { + assert(LoopSize >= UP.BEInsns && "LoopSize should not be less than BEInsns!"); + return static_cast<uint64_t>(LoopSize - UP.BEInsns) * UP.Count + UP.BEInsns; +} + +// Calculates unroll and jam count and writes it to UP.Count. Returns true if +// unroll count was set explicitly. +static bool computeUnrollAndJamCount( + Loop *L, Loop *SubLoop, const TargetTransformInfo &TTI, DominatorTree &DT, + LoopInfo *LI, ScalarEvolution &SE, + const SmallPtrSetImpl<const Value *> &EphValues, + OptimizationRemarkEmitter *ORE, unsigned OuterTripCount, + unsigned OuterTripMultiple, unsigned OuterLoopSize, unsigned InnerTripCount, + unsigned InnerLoopSize, TargetTransformInfo::UnrollingPreferences &UP) { + // Check for explicit Count from the "unroll-and-jam-count" option. + bool UserUnrollCount = UnrollAndJamCount.getNumOccurrences() > 0; + if (UserUnrollCount) { + UP.Count = UnrollAndJamCount; + UP.Force = true; + if (UP.AllowRemainder && + getUnrollAndJammedLoopSize(OuterLoopSize, UP) < UP.Threshold && + getUnrollAndJammedLoopSize(InnerLoopSize, UP) < + UP.UnrollAndJamInnerLoopThreshold) + return true; + } + + // Check for unroll_and_jam pragmas + unsigned PragmaCount = UnrollAndJamCountPragmaValue(L); + if (PragmaCount > 0) { + UP.Count = PragmaCount; + UP.Runtime = true; + UP.Force = true; + if ((UP.AllowRemainder || (OuterTripMultiple % PragmaCount == 0)) && + getUnrollAndJammedLoopSize(OuterLoopSize, UP) < UP.Threshold && + getUnrollAndJammedLoopSize(InnerLoopSize, UP) < + UP.UnrollAndJamInnerLoopThreshold) + return true; + } + + // Use computeUnrollCount from the loop unroller to get a sensible count + // for the unrolling the outer loop. This uses UP.Threshold / + // UP.PartialThreshold / UP.MaxCount to come up with sensible loop values. + // We have already checked that the loop has no unroll.* pragmas. 
+ unsigned MaxTripCount = 0; + bool UseUpperBound = false; + bool ExplicitUnroll = computeUnrollCount( + L, TTI, DT, LI, SE, EphValues, ORE, OuterTripCount, MaxTripCount, + OuterTripMultiple, OuterLoopSize, UP, UseUpperBound); + if (ExplicitUnroll || UseUpperBound) { + // If the user explicitly set the loop as unrolled, dont UnJ it. Leave it + // for the unroller instead. + UP.Count = 0; + return false; + } + + bool PragmaEnableUnroll = HasUnrollAndJamEnablePragma(L); + ExplicitUnroll = PragmaCount > 0 || PragmaEnableUnroll || UserUnrollCount; + + // If the loop has an unrolling pragma, we want to be more aggressive with + // unrolling limits. + if (ExplicitUnroll && OuterTripCount != 0) + UP.UnrollAndJamInnerLoopThreshold = PragmaUnrollAndJamThreshold; + + if (!UP.AllowRemainder && getUnrollAndJammedLoopSize(InnerLoopSize, UP) >= + UP.UnrollAndJamInnerLoopThreshold) { + UP.Count = 0; + return false; + } + + // If the inner loop count is known and small, leave the entire loop nest to + // be the unroller + if (!ExplicitUnroll && InnerTripCount && + InnerLoopSize * InnerTripCount < UP.Threshold) { + UP.Count = 0; + return false; + } + + // We have a sensible limit for the outer loop, now adjust it for the inner + // loop and UP.UnrollAndJamInnerLoopThreshold. + while (UP.Count != 0 && UP.AllowRemainder && + getUnrollAndJammedLoopSize(InnerLoopSize, UP) >= + UP.UnrollAndJamInnerLoopThreshold) + UP.Count--; + + if (!ExplicitUnroll) { + // Check for situations where UnJ is likely to be unprofitable. Including + // subloops with more than 1 block. + if (SubLoop->getBlocks().size() != 1) { + UP.Count = 0; + return false; + } + + // Limit to loops where there is something to gain from unrolling and + // jamming the loop. In this case, look for loads that are invariant in the + // outer loop and can become shared. + unsigned NumInvariant = 0; + for (BasicBlock *BB : SubLoop->getBlocks()) { + for (Instruction &I : *BB) { + if (auto *Ld = dyn_cast<LoadInst>(&I)) { + Value *V = Ld->getPointerOperand(); + const SCEV *LSCEV = SE.getSCEVAtScope(V, L); + if (SE.isLoopInvariant(LSCEV, L)) + NumInvariant++; + } + } + } + if (NumInvariant == 0) { + UP.Count = 0; + return false; + } + } + + return ExplicitUnroll; +} + +static LoopUnrollResult +tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, + ScalarEvolution &SE, const TargetTransformInfo &TTI, + AssumptionCache &AC, DependenceInfo &DI, + OptimizationRemarkEmitter &ORE, int OptLevel) { + // Quick checks of the correct loop form + if (!L->isLoopSimplifyForm() || L->getSubLoops().size() != 1) + return LoopUnrollResult::Unmodified; + Loop *SubLoop = L->getSubLoops()[0]; + if (!SubLoop->isLoopSimplifyForm()) + return LoopUnrollResult::Unmodified; + + BasicBlock *Latch = L->getLoopLatch(); + BasicBlock *Exit = L->getExitingBlock(); + BasicBlock *SubLoopLatch = SubLoop->getLoopLatch(); + BasicBlock *SubLoopExit = SubLoop->getExitingBlock(); + + if (Latch != Exit || SubLoopLatch != SubLoopExit) + return LoopUnrollResult::Unmodified; + + TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( + L, SE, TTI, OptLevel, None, None, None, None, None, None); + if (AllowUnrollAndJam.getNumOccurrences() > 0) + UP.UnrollAndJam = AllowUnrollAndJam; + if (UnrollAndJamThreshold.getNumOccurrences() > 0) + UP.UnrollAndJamInnerLoopThreshold = UnrollAndJamThreshold; + // Exit early if unrolling is disabled. 
+ if (!UP.UnrollAndJam || UP.UnrollAndJamInnerLoopThreshold == 0) + return LoopUnrollResult::Unmodified; + + LLVM_DEBUG(dbgs() << "Loop Unroll and Jam: F[" + << L->getHeader()->getParent()->getName() << "] Loop %" + << L->getHeader()->getName() << "\n"); + + // A loop with any unroll pragma (enabling/disabling/count/etc) is left for + // the unroller, so long as it does not explicitly have unroll_and_jam + // metadata. This means #pragma nounroll will disable unroll and jam as well + // as unrolling + if (HasUnrollAndJamDisablePragma(L) || + (HasAnyUnrollPragma(L, "llvm.loop.unroll.") && + !HasAnyUnrollPragma(L, "llvm.loop.unroll_and_jam."))) { + LLVM_DEBUG(dbgs() << " Disabled due to pragma.\n"); + return LoopUnrollResult::Unmodified; + } + + if (!isSafeToUnrollAndJam(L, SE, DT, DI)) { + LLVM_DEBUG(dbgs() << " Disabled due to not being safe.\n"); + return LoopUnrollResult::Unmodified; + } + + // Approximate the loop size and collect useful info + unsigned NumInlineCandidates; + bool NotDuplicatable; + bool Convergent; + SmallPtrSet<const Value *, 32> EphValues; + CodeMetrics::collectEphemeralValues(L, &AC, EphValues); + unsigned InnerLoopSize = + ApproximateLoopSize(SubLoop, NumInlineCandidates, NotDuplicatable, + Convergent, TTI, EphValues, UP.BEInsns); + unsigned OuterLoopSize = + ApproximateLoopSize(L, NumInlineCandidates, NotDuplicatable, Convergent, + TTI, EphValues, UP.BEInsns); + LLVM_DEBUG(dbgs() << " Outer Loop Size: " << OuterLoopSize << "\n"); + LLVM_DEBUG(dbgs() << " Inner Loop Size: " << InnerLoopSize << "\n"); + if (NotDuplicatable) { + LLVM_DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable " + "instructions.\n"); + return LoopUnrollResult::Unmodified; + } + if (NumInlineCandidates != 0) { + LLVM_DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n"); + return LoopUnrollResult::Unmodified; + } + if (Convergent) { + LLVM_DEBUG( + dbgs() << " Not unrolling loop with convergent instructions.\n"); + return LoopUnrollResult::Unmodified; + } + + // Find trip count and trip multiple + unsigned OuterTripCount = SE.getSmallConstantTripCount(L, Latch); + unsigned OuterTripMultiple = SE.getSmallConstantTripMultiple(L, Latch); + unsigned InnerTripCount = SE.getSmallConstantTripCount(SubLoop, SubLoopLatch); + + // Decide if, and by how much, to unroll + bool IsCountSetExplicitly = computeUnrollAndJamCount( + L, SubLoop, TTI, DT, LI, SE, EphValues, &ORE, OuterTripCount, + OuterTripMultiple, OuterLoopSize, InnerTripCount, InnerLoopSize, UP); + if (UP.Count <= 1) + return LoopUnrollResult::Unmodified; + // Unroll factor (Count) must be less or equal to TripCount. + if (OuterTripCount && UP.Count > OuterTripCount) + UP.Count = OuterTripCount; + + LoopUnrollResult UnrollResult = + UnrollAndJamLoop(L, UP.Count, OuterTripCount, OuterTripMultiple, + UP.UnrollRemainder, LI, &SE, &DT, &AC, &ORE); + + // If loop has an unroll count pragma or unrolled by explicitly set count + // mark loop as unrolled to prevent unrolling beyond that requested. 
+ if (UnrollResult != LoopUnrollResult::FullyUnrolled && IsCountSetExplicitly) + L->setLoopAlreadyUnrolled(); + + return UnrollResult; +} + +namespace { + +class LoopUnrollAndJam : public LoopPass { +public: + static char ID; // Pass ID, replacement for typeid + unsigned OptLevel; + + LoopUnrollAndJam(int OptLevel = 2) : LoopPass(ID), OptLevel(OptLevel) { + initializeLoopUnrollAndJamPass(*PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) + return false; + + Function &F = *L->getHeader()->getParent(); + + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + const TargetTransformInfo &TTI = + getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto &DI = getAnalysis<DependenceAnalysisWrapperPass>().getDI(); + // For the old PM, we can't use OptimizationRemarkEmitter as an analysis + // pass. Function analyses need to be preserved across loop transformations + // but ORE cannot be preserved (see comment before the pass definition). + OptimizationRemarkEmitter ORE(&F); + + LoopUnrollResult Result = + tryToUnrollAndJamLoop(L, DT, LI, SE, TTI, AC, DI, ORE, OptLevel); + + if (Result == LoopUnrollResult::FullyUnrolled) + LPM.markLoopAsDeleted(*L); + + return Result != LoopUnrollResult::Unmodified; + } + + /// This transformation requires natural loop information & requires that + /// loop preheaders be inserted into the CFG... + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addRequired<DependenceAnalysisWrapperPass>(); + getLoopAnalysisUsage(AU); + } +}; + +} // end anonymous namespace + +char LoopUnrollAndJam::ID = 0; + +INITIALIZE_PASS_BEGIN(LoopUnrollAndJam, "loop-unroll-and-jam", + "Unroll and Jam loops", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DependenceAnalysisWrapperPass) +INITIALIZE_PASS_END(LoopUnrollAndJam, "loop-unroll-and-jam", + "Unroll and Jam loops", false, false) + +Pass *llvm::createLoopUnrollAndJamPass(int OptLevel) { + return new LoopUnrollAndJam(OptLevel); +} + +PreservedAnalyses LoopUnrollAndJamPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); + Function *F = L.getHeader()->getParent(); + + auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F); + // FIXME: This should probably be optional rather than required. 
+ if (!ORE) + report_fatal_error( + "LoopUnrollAndJamPass: OptimizationRemarkEmitterAnalysis not cached at " + "a higher level"); + + DependenceInfo DI(F, &AR.AA, &AR.SE, &AR.LI); + + LoopUnrollResult Result = tryToUnrollAndJamLoop( + &L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, DI, *ORE, OptLevel); + + if (Result == LoopUnrollResult::Unmodified) + return PreservedAnalyses::all(); + + return getLoopPassPreservedAnalyses(); +} diff --git a/lib/Transforms/Scalar/LoopUnrollPass.cpp b/lib/Transforms/Scalar/LoopUnrollPass.cpp index 15e7da5e1a7a..634215c9770f 100644 --- a/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -53,6 +53,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/UnrollLoop.h" @@ -164,7 +165,7 @@ static const unsigned NoThreshold = std::numeric_limits<unsigned>::max(); /// Gather the various unrolling parameters based on the defaults, compiler /// flags, TTI overrides and user specified parameters. -static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( +TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( Loop *L, ScalarEvolution &SE, const TargetTransformInfo &TTI, int OptLevel, Optional<unsigned> UserThreshold, Optional<unsigned> UserCount, Optional<bool> UserAllowPartial, Optional<bool> UserRuntime, @@ -191,6 +192,8 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( UP.Force = false; UP.UpperBound = false; UP.AllowPeeling = true; + UP.UnrollAndJam = false; + UP.UnrollAndJamInnerLoopThreshold = 60; // Override with any target specific settings TTI.getUnrollingPreferences(L, SE, UP); @@ -285,17 +288,17 @@ struct UnrolledInstStateKeyInfo { }; struct EstimatedUnrollCost { - /// \brief The estimated cost after unrolling. + /// The estimated cost after unrolling. unsigned UnrolledCost; - /// \brief The estimated dynamic cost of executing the instructions in the + /// The estimated dynamic cost of executing the instructions in the /// rolled form. unsigned RolledDynamicCost; }; } // end anonymous namespace -/// \brief Figure out if the loop is worth full unrolling. +/// Figure out if the loop is worth full unrolling. /// /// Complete loop unrolling can make some loads constant, and we need to know /// if that would expose any further optimization opportunities. This routine @@ -308,10 +311,10 @@ struct EstimatedUnrollCost { /// \returns Optional value, holding the RolledDynamicCost and UnrolledCost. If /// the analysis failed (no benefits expected from the unrolling, or the loop is /// too big to analyze), the returned value is None. -static Optional<EstimatedUnrollCost> -analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, - ScalarEvolution &SE, const TargetTransformInfo &TTI, - unsigned MaxUnrolledLoopSize) { +static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost( + const Loop *L, unsigned TripCount, DominatorTree &DT, ScalarEvolution &SE, + const SmallPtrSetImpl<const Value *> &EphValues, + const TargetTransformInfo &TTI, unsigned MaxUnrolledLoopSize) { // We want to be able to scale offsets by the trip count and add more offsets // to them without checking for overflows, and we already don't want to // analyze *massive* trip counts, so we force the max to be reasonably small. 
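// The kind of opportunity this analysis simulates (an assumed example): in
//
//   static const int T[4] = {1, 2, 3, 4};
//   int s = 0;
//   for (int i = 0; i < 4; ++i)
//     s += T[i];
//
// full unrolling makes every T[i] a load from a constant index, so the whole
// body folds away to s = 10, which the static size estimate alone would miss.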
@@ -405,9 +408,9 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, // First accumulate the cost of this instruction. if (!Cost.IsFree) { UnrolledCost += TTI.getUserCost(I); - DEBUG(dbgs() << "Adding cost of instruction (iteration " << Iteration - << "): "); - DEBUG(I->dump()); + LLVM_DEBUG(dbgs() << "Adding cost of instruction (iteration " + << Iteration << "): "); + LLVM_DEBUG(I->dump()); } // We must count the cost of every operand which is not free, @@ -442,14 +445,14 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, assert(L->isLCSSAForm(DT) && "Must have loops in LCSSA form to track live-out values."); - DEBUG(dbgs() << "Starting LoopUnroll profitability analysis...\n"); + LLVM_DEBUG(dbgs() << "Starting LoopUnroll profitability analysis...\n"); // Simulate execution of each iteration of the loop counting instructions, // which would be simplified. // Since the same load will take different values on different iterations, // we literally have to go through all loop's iterations. for (unsigned Iteration = 0; Iteration < TripCount; ++Iteration) { - DEBUG(dbgs() << " Analyzing iteration " << Iteration << "\n"); + LLVM_DEBUG(dbgs() << " Analyzing iteration " << Iteration << "\n"); // Prepare for the iteration by collecting any simplified entry or backedge // inputs. @@ -490,7 +493,9 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, // it. We don't change the actual IR, just count optimization // opportunities. for (Instruction &I : *BB) { - if (isa<DbgInfoIntrinsic>(I)) + // These won't get into the final code - don't even try calculating the + // cost for them. + if (isa<DbgInfoIntrinsic>(I) || EphValues.count(&I)) continue; // Track this instruction's expected baseline cost when executing the @@ -512,8 +517,13 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, // Can't properly model a cost of a call. // FIXME: With a proper cost model we should be able to do it. - if(isa<CallInst>(&I)) - return None; + if (auto *CI = dyn_cast<CallInst>(&I)) { + const Function *Callee = CI->getCalledFunction(); + if (!Callee || TTI.isLoweredToCall(Callee)) { + LLVM_DEBUG(dbgs() << "Can't analyze cost of loop with call\n"); + return None; + } + } // If the instruction might have a side-effect recursively account for // the cost of it and all the instructions leading up to it. @@ -522,10 +532,10 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, // If unrolled body turns out to be too big, bail out. if (UnrolledCost > MaxUnrolledLoopSize) { - DEBUG(dbgs() << " Exceeded threshold.. exiting.\n" - << " UnrolledCost: " << UnrolledCost - << ", MaxUnrolledLoopSize: " << MaxUnrolledLoopSize - << "\n"); + LLVM_DEBUG(dbgs() << " Exceeded threshold.. exiting.\n" + << " UnrolledCost: " << UnrolledCost + << ", MaxUnrolledLoopSize: " << MaxUnrolledLoopSize + << "\n"); return None; } } @@ -578,8 +588,8 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, // If we found no optimization opportunities on the first iteration, we // won't find them on later ones too. if (UnrolledCost == RolledDynamicCost) { - DEBUG(dbgs() << " No opportunities found.. exiting.\n" - << " UnrolledCost: " << UnrolledCost << "\n"); + LLVM_DEBUG(dbgs() << " No opportunities found.. 
exiting.\n" + << " UnrolledCost: " << UnrolledCost << "\n"); return None; } } @@ -600,20 +610,17 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, } } - DEBUG(dbgs() << "Analysis finished:\n" - << "UnrolledCost: " << UnrolledCost << ", " - << "RolledDynamicCost: " << RolledDynamicCost << "\n"); + LLVM_DEBUG(dbgs() << "Analysis finished:\n" + << "UnrolledCost: " << UnrolledCost << ", " + << "RolledDynamicCost: " << RolledDynamicCost << "\n"); return {{UnrolledCost, RolledDynamicCost}}; } /// ApproximateLoopSize - Approximate the size of the loop. -static unsigned ApproximateLoopSize(const Loop *L, unsigned &NumCalls, - bool &NotDuplicatable, bool &Convergent, - const TargetTransformInfo &TTI, - AssumptionCache *AC, unsigned BEInsns) { - SmallPtrSet<const Value *, 32> EphValues; - CodeMetrics::collectEphemeralValues(L, AC, EphValues); - +unsigned llvm::ApproximateLoopSize( + const Loop *L, unsigned &NumCalls, bool &NotDuplicatable, bool &Convergent, + const TargetTransformInfo &TTI, + const SmallPtrSetImpl<const Value *> &EphValues, unsigned BEInsns) { CodeMetrics Metrics; for (BasicBlock *BB : L->blocks()) Metrics.analyzeBasicBlock(BB, TTI, EphValues); @@ -706,10 +713,11 @@ static uint64_t getUnrolledLoopSize( // Returns true if unroll count was set explicitly. // Calculates unroll count and writes it to UP.Count. -static bool computeUnrollCount( +bool llvm::computeUnrollCount( Loop *L, const TargetTransformInfo &TTI, DominatorTree &DT, LoopInfo *LI, - ScalarEvolution &SE, OptimizationRemarkEmitter *ORE, unsigned &TripCount, - unsigned MaxTripCount, unsigned &TripMultiple, unsigned LoopSize, + ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues, + OptimizationRemarkEmitter *ORE, unsigned &TripCount, unsigned MaxTripCount, + unsigned &TripMultiple, unsigned LoopSize, TargetTransformInfo::UnrollingPreferences &UP, bool &UseUpperBound) { // Check for explicit Count. // 1st priority is unroll count set by "unroll-count" option. @@ -729,7 +737,7 @@ static bool computeUnrollCount( UP.Runtime = true; UP.AllowExpensiveTripCount = true; UP.Force = true; - if (UP.AllowRemainder && + if ((UP.AllowRemainder || (TripMultiple % PragmaCount == 0)) && getUnrolledLoopSize(LoopSize, UP) < PragmaUnrollThreshold) return true; } @@ -746,8 +754,8 @@ static bool computeUnrollCount( if (ExplicitUnroll && TripCount != 0) { // If the loop has an unrolling pragma, we want to be more aggressive with - // unrolling limits. Set thresholds to at least the PragmaThreshold value - // which is larger than the default limits. + // unrolling limits. Set thresholds to at least the PragmaUnrollThreshold + // value which is larger than the default limits. UP.Threshold = std::max<unsigned>(UP.Threshold, PragmaUnrollThreshold); UP.PartialThreshold = std::max<unsigned>(UP.PartialThreshold, PragmaUnrollThreshold); @@ -763,7 +771,7 @@ static bool computeUnrollCount( // compute the former when the latter is zero. unsigned ExactTripCount = TripCount; assert((ExactTripCount == 0 || MaxTripCount == 0) && - "ExtractTripCound and MaxTripCount cannot both be non zero."); + "ExtractTripCount and MaxTripCount cannot both be non zero."); unsigned FullUnrollTripCount = ExactTripCount ? ExactTripCount : MaxTripCount; UP.Count = FullUnrollTripCount; if (FullUnrollTripCount && FullUnrollTripCount <= UP.FullUnrollMaxCount) { @@ -779,7 +787,7 @@ static bool computeUnrollCount( // helps to remove a significant number of instructions. // To check that, run additional analysis on the loop. 
if (Optional<EstimatedUnrollCost> Cost = analyzeLoopUnrollCost( - L, FullUnrollTripCount, DT, SE, TTI, + L, FullUnrollTripCount, DT, SE, EphValues, TTI, UP.Threshold * UP.MaxPercentThresholdBoost / 100)) { unsigned Boost = getFullUnrollBoostingFactor(*Cost, UP.MaxPercentThresholdBoost); @@ -794,7 +802,7 @@ static bool computeUnrollCount( } // 4th priority is loop peeling - computePeelCount(L, LoopSize, UP, TripCount); + computePeelCount(L, LoopSize, UP, TripCount, SE); if (UP.PeelCount) { UP.Runtime = false; UP.Count = 1; @@ -802,12 +810,12 @@ static bool computeUnrollCount( } // 5th priority is partial unrolling. - // Try partial unroll only when TripCount could be staticaly calculated. + // Try partial unroll only when TripCount could be statically calculated. if (TripCount) { UP.Partial |= ExplicitUnroll; if (!UP.Partial) { - DEBUG(dbgs() << " will not try to unroll partially because " - << "-unroll-allow-partial not given\n"); + LLVM_DEBUG(dbgs() << " will not try to unroll partially because " + << "-unroll-allow-partial not given\n"); UP.Count = 0; return false; } @@ -894,8 +902,9 @@ static bool computeUnrollCount( // Reduce count based on the type of unrolling and the threshold values. UP.Runtime |= PragmaEnableUnroll || PragmaCount > 0 || UserUnrollCount; if (!UP.Runtime) { - DEBUG(dbgs() << " will not try to unroll loop with runtime trip count " - << "-unroll-runtime not given\n"); + LLVM_DEBUG( + dbgs() << " will not try to unroll loop with runtime trip count " + << "-unroll-runtime not given\n"); UP.Count = 0; return false; } @@ -915,12 +924,13 @@ static bool computeUnrollCount( if (!UP.AllowRemainder && UP.Count != 0 && (TripMultiple % UP.Count) != 0) { while (UP.Count != 0 && TripMultiple % UP.Count != 0) UP.Count >>= 1; - DEBUG(dbgs() << "Remainder loop is restricted (that could architecture " - "specific or because the loop contains a convergent " - "instruction), so unroll count must divide the trip " - "multiple, " - << TripMultiple << ". Reducing unroll count from " - << OrigCount << " to " << UP.Count << ".\n"); + LLVM_DEBUG( + dbgs() << "Remainder loop is restricted (that could architecture " + "specific or because the loop contains a convergent " + "instruction), so unroll count must divide the trip " + "multiple, " + << TripMultiple << ". 
Reducing unroll count from " << OrigCount + << " to " << UP.Count << ".\n"); using namespace ore; @@ -942,7 +952,8 @@ static bool computeUnrollCount( if (UP.Count > UP.MaxCount) UP.Count = UP.MaxCount; - DEBUG(dbgs() << " partially unrolling with count: " << UP.Count << "\n"); + LLVM_DEBUG(dbgs() << " partially unrolling with count: " << UP.Count + << "\n"); if (UP.Count < 2) UP.Count = 0; return ExplicitUnroll; @@ -955,12 +966,13 @@ static LoopUnrollResult tryToUnrollLoop( Optional<unsigned> ProvidedCount, Optional<unsigned> ProvidedThreshold, Optional<bool> ProvidedAllowPartial, Optional<bool> ProvidedRuntime, Optional<bool> ProvidedUpperBound, Optional<bool> ProvidedAllowPeeling) { - DEBUG(dbgs() << "Loop Unroll: F[" << L->getHeader()->getParent()->getName() - << "] Loop %" << L->getHeader()->getName() << "\n"); + LLVM_DEBUG(dbgs() << "Loop Unroll: F[" + << L->getHeader()->getParent()->getName() << "] Loop %" + << L->getHeader()->getName() << "\n"); if (HasUnrollDisablePragma(L)) return LoopUnrollResult::Unmodified; if (!L->isLoopSimplifyForm()) { - DEBUG( + LLVM_DEBUG( dbgs() << " Not unrolling loop which is not in loop-simplify form.\n"); return LoopUnrollResult::Unmodified; } @@ -975,16 +987,21 @@ static LoopUnrollResult tryToUnrollLoop( // Exit early if unrolling is disabled. if (UP.Threshold == 0 && (!UP.Partial || UP.PartialThreshold == 0)) return LoopUnrollResult::Unmodified; - unsigned LoopSize = ApproximateLoopSize( - L, NumInlineCandidates, NotDuplicatable, Convergent, TTI, &AC, UP.BEInsns); - DEBUG(dbgs() << " Loop Size = " << LoopSize << "\n"); + + SmallPtrSet<const Value *, 32> EphValues; + CodeMetrics::collectEphemeralValues(L, &AC, EphValues); + + unsigned LoopSize = + ApproximateLoopSize(L, NumInlineCandidates, NotDuplicatable, Convergent, + TTI, EphValues, UP.BEInsns); + LLVM_DEBUG(dbgs() << " Loop Size = " << LoopSize << "\n"); if (NotDuplicatable) { - DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable" - << " instructions.\n"); + LLVM_DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable" + << " instructions.\n"); return LoopUnrollResult::Unmodified; } if (NumInlineCandidates != 0) { - DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n"); + LLVM_DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n"); return LoopUnrollResult::Unmodified; } @@ -1030,7 +1047,7 @@ static LoopUnrollResult tryToUnrollLoop( // loop tests remains the same compared to the non-unrolled version, whereas // the generic upper bound unrolling keeps all but the last loop test so the // number of loop tests goes up which may end up being worse on targets with - // constriained branch predictor resources so is controlled by an option.) + // constrained branch predictor resources so is controlled by an option.) // In addition we only unroll small upper bounds. if (!(UP.UpperBound || MaxOrZero) || MaxTripCount > UnrollMaxUpperBound) { MaxTripCount = 0; @@ -1040,9 +1057,9 @@ static LoopUnrollResult tryToUnrollLoop( // computeUnrollCount() decides whether it is beneficial to use upper bound to // fully unroll the loop. 
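// Illustration (assumed code) of the upper-bound case decided below: a loop
// that runs at most 4 times but with an unknown exact trip count, e.g.
//
//   for (int i = 0; i < n && i < 4; ++i) body(i);
//
// can be fully unrolled into guarded copies, keeping the intermediate tests:
//
//   if (n > 0) { body(0); if (n > 1) { body(1);
//     if (n > 2) { body(2); if (n > 3) body(3); } } }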
bool UseUpperBound = false; - bool IsCountSetExplicitly = - computeUnrollCount(L, TTI, DT, LI, SE, &ORE, TripCount, MaxTripCount, - TripMultiple, LoopSize, UP, UseUpperBound); + bool IsCountSetExplicitly = computeUnrollCount( + L, TTI, DT, LI, SE, EphValues, &ORE, TripCount, MaxTripCount, + TripMultiple, LoopSize, UP, UseUpperBound); if (!UP.Count) return LoopUnrollResult::Unmodified; // Unroll factor (Count) must be less or equal to TripCount. diff --git a/lib/Transforms/Scalar/LoopUnswitch.cpp b/lib/Transforms/Scalar/LoopUnswitch.cpp index bd468338a1d0..b12586758925 100644 --- a/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -28,7 +28,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" @@ -39,6 +39,7 @@ #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" @@ -66,7 +67,6 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> @@ -298,9 +298,9 @@ bool LUAnalysisCache::countLoop(const Loop *L, const TargetTransformInfo &TTI, MaxSize -= Props.SizeEstimation * Props.CanBeUnswitchedCount; if (Metrics.notDuplicatable) { - DEBUG(dbgs() << "NOT unswitching loop %" - << L->getHeader()->getName() << ", contents cannot be " - << "duplicated!\n"); + LLVM_DEBUG(dbgs() << "NOT unswitching loop %" << L->getHeader()->getName() + << ", contents cannot be " + << "duplicated!\n"); return false; } } @@ -635,6 +635,12 @@ bool LoopUnswitch::processCurrentLoop() { return true; } + // Do not do non-trivial unswitch while optimizing for size. + // FIXME: Use Function::optForSize(). + if (OptimizeForSize || + loopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize)) + return false; + // Run through the instructions in the loop, keeping track of three things: // // - That we do not unswitch loops containing convergent operations, as we @@ -666,12 +672,6 @@ bool LoopUnswitch::processCurrentLoop() { } } - // Do not do non-trivial unswitch while optimizing for size. - // FIXME: Use Function::optForSize(). - if (OptimizeForSize || - loopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize)) - return false; - for (IntrinsicInst *Guard : Guards) { Value *LoopCond = FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed).first; @@ -856,20 +856,20 @@ bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val, TerminatorInst *TI) { // Check to see if it would be profitable to unswitch current loop. if (!BranchesInfo.CostAllowsUnswitching()) { - DEBUG(dbgs() << "NOT unswitching loop %" - << currentLoop->getHeader()->getName() - << " at non-trivial condition '" << *Val - << "' == " << *LoopCond << "\n" - << ". Cost too high.\n"); + LLVM_DEBUG(dbgs() << "NOT unswitching loop %" + << currentLoop->getHeader()->getName() + << " at non-trivial condition '" << *Val + << "' == " << *LoopCond << "\n" + << ". 
Cost too high.\n"); return false; } if (hasBranchDivergence && getAnalysis<DivergenceAnalysis>().isDivergent(LoopCond)) { - DEBUG(dbgs() << "NOT unswitching loop %" - << currentLoop->getHeader()->getName() - << " at non-trivial condition '" << *Val - << "' == " << *LoopCond << "\n" - << ". Condition is divergent.\n"); + LLVM_DEBUG(dbgs() << "NOT unswitching loop %" + << currentLoop->getHeader()->getName() + << " at non-trivial condition '" << *Val + << "' == " << *LoopCond << "\n" + << ". Condition is divergent.\n"); return false; } @@ -910,6 +910,7 @@ void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, BranchInst *OldBranch, TerminatorInst *TI) { assert(OldBranch->isUnconditional() && "Preheader is not split correctly"); + assert(TrueDest != FalseDest && "Branch targets should be different"); // Insert a conditional branch on LIC to the two preheaders. The original // code is the true version and the new code is the false version. Value *BranchVal = LIC; @@ -942,9 +943,9 @@ void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, if (DT) { // First, add both successors. SmallVector<DominatorTree::UpdateType, 3> Updates; - if (TrueDest != OldBranchParent) + if (TrueDest != OldBranchSucc) Updates.push_back({DominatorTree::Insert, OldBranchParent, TrueDest}); - if (FalseDest != OldBranchParent) + if (FalseDest != OldBranchSucc) Updates.push_back({DominatorTree::Insert, OldBranchParent, FalseDest}); // If both of the new successors are different from the old one, inform the // DT that the edge was deleted. @@ -970,11 +971,15 @@ void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, BasicBlock *ExitBlock, TerminatorInst *TI) { - DEBUG(dbgs() << "loop-unswitch: Trivial-Unswitch loop %" - << loopHeader->getName() << " [" << L->getBlocks().size() - << " blocks] in Function " - << L->getHeader()->getParent()->getName() << " on cond: " << *Val - << " == " << *Cond << "\n"); + LLVM_DEBUG(dbgs() << "loop-unswitch: Trivial-Unswitch loop %" + << loopHeader->getName() << " [" << L->getBlocks().size() + << " blocks] in Function " + << L->getHeader()->getParent()->getName() + << " on cond: " << *Val << " == " << *Cond << "\n"); + // We are going to make essential changes to CFG. This may invalidate cached + // information for L or one of its parent loops in SCEV. + if (auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>()) + SEWP->getSE().forgetTopmostLoop(L); // First step, split the preheader, so that we know that there is a safe place // to insert the conditional branch. We will change loopPreheader to have a @@ -1038,7 +1043,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { // until it finds the trivial condition candidate (condition that is not a // constant). Since unswitching generates branches with constant conditions, // this scenario could be very common in practice. 
- SmallSet<BasicBlock*, 8> Visited; + SmallPtrSet<BasicBlock*, 8> Visited; while (true) { // If we exit loop or reach a previous visited block, then @@ -1196,13 +1201,15 @@ void LoopUnswitch::SplitExitEdges(Loop *L, void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, Loop *L, TerminatorInst *TI) { Function *F = loopHeader->getParent(); - DEBUG(dbgs() << "loop-unswitch: Unswitching loop %" - << loopHeader->getName() << " [" << L->getBlocks().size() - << " blocks] in Function " << F->getName() - << " when '" << *Val << "' == " << *LIC << "\n"); + LLVM_DEBUG(dbgs() << "loop-unswitch: Unswitching loop %" + << loopHeader->getName() << " [" << L->getBlocks().size() + << " blocks] in Function " << F->getName() << " when '" + << *Val << "' == " << *LIC << "\n"); + // We are going to make essential changes to CFG. This may invalidate cached + // information for L or one of its parent loops in SCEV. if (auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>()) - SEWP->getSE().forgetLoop(L); + SEWP->getSE().forgetTopmostLoop(L); LoopBlocks.clear(); NewBlocks.clear(); @@ -1274,12 +1281,11 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, // If the successor of the exit block had PHI nodes, add an entry for // NewExit. - for (BasicBlock::iterator I = ExitSucc->begin(); - PHINode *PN = dyn_cast<PHINode>(I); ++I) { - Value *V = PN->getIncomingValueForBlock(ExitBlocks[i]); + for (PHINode &PN : ExitSucc->phis()) { + Value *V = PN.getIncomingValueForBlock(ExitBlocks[i]); ValueToValueMapTy::iterator It = VMap.find(V); if (It != VMap.end()) V = It->second; - PN->addIncoming(V, NewExit); + PN.addIncoming(V, NewExit); } if (LandingPadInst *LPad = NewExit->getLandingPadInst()) { @@ -1356,7 +1362,7 @@ static void RemoveFromWorklist(Instruction *I, static void ReplaceUsesOfWith(Instruction *I, Value *V, std::vector<Instruction*> &Worklist, Loop *L, LPPassManager *LPM) { - DEBUG(dbgs() << "Replace with '" << *V << "': " << *I << "\n"); + LLVM_DEBUG(dbgs() << "Replace with '" << *V << "': " << *I << "\n"); // Add uses to the worklist, which may be dead now. for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) @@ -1496,10 +1502,9 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, BranchInst::Create(Abort, OldSISucc, ConstantInt::getTrue(Context), NewSISucc); // Release the PHI operands for this edge. - for (BasicBlock::iterator II = NewSISucc->begin(); - PHINode *PN = dyn_cast<PHINode>(II); ++II) - PN->setIncomingValue(PN->getBasicBlockIndex(Switch), - UndefValue::get(PN->getType())); + for (PHINode &PN : NewSISucc->phis()) + PN.setIncomingValue(PN.getBasicBlockIndex(Switch), + UndefValue::get(PN.getType())); // Tell the domtree about the new block. We don't fully update the // domtree here -- instead we force it to do a full recomputation // after the pass is complete -- but we do need to inform it of @@ -1526,7 +1531,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { // Simple DCE. if (isInstructionTriviallyDead(I)) { - DEBUG(dbgs() << "Remove dead instruction '" << *I << "\n"); + LLVM_DEBUG(dbgs() << "Remove dead instruction '" << *I << "\n"); // Add uses to the worklist, which may be dead now. for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) @@ -1559,8 +1564,8 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { if (!SinglePred) continue; // Nothing to do. 
assert(SinglePred == Pred && "CFG broken"); - DEBUG(dbgs() << "Merging blocks: " << Pred->getName() << " <- " - << Succ->getName() << "\n"); + LLVM_DEBUG(dbgs() << "Merging blocks: " << Pred->getName() << " <- " + << Succ->getName() << "\n"); // Resolve any single entry PHI nodes in Succ. while (PHINode *PN = dyn_cast<PHINode>(Succ->begin())) diff --git a/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/lib/Transforms/Scalar/LoopVersioningLICM.cpp index 53b25e688e82..06e86081e8a0 100644 --- a/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -68,6 +68,7 @@ #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" @@ -85,6 +86,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/LoopVersioning.h" #include <cassert> @@ -111,7 +113,7 @@ static cl::opt<unsigned> LVLoopDepthThreshold( "LoopVersioningLICM's threshold for maximum allowed loop nest/depth"), cl::init(2), cl::Hidden); -/// \brief Create MDNode for input string. +/// Create MDNode for input string. static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, unsigned V) { LLVMContext &Context = TheLoop->getHeader()->getContext(); Metadata *MDs[] = { @@ -120,7 +122,7 @@ static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, unsigned V) { return MDNode::get(Context, MDs); } -/// \brief Set input string into loop metadata by keeping other values intact. +/// Set input string into loop metadata by keeping other values intact. void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V) { SmallVector<Metadata *, 4> MDs(1); @@ -166,6 +168,7 @@ struct LoopVersioningLICM : public LoopPass { AU.addRequired<ScalarEvolutionWrapperPass>(); AU.addPreserved<AAResultsWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); } StringRef getPassName() const override { return "Loop Versioning for LICM"; } @@ -178,6 +181,7 @@ struct LoopVersioningLICM : public LoopPass { LoadAndStoreCounter = 0; InvariantCounter = 0; IsReadOnlyLoop = true; + ORE = nullptr; CurAST.reset(); } @@ -207,7 +211,7 @@ private: Loop *CurLoop = nullptr; // AliasSet information for the current loop. - std::unique_ptr<AliasSetTracker> CurAST; + std::unique_ptr<AliasSetTracker> CurAST; // Maximum loop nest threshold unsigned LoopDepthThreshold; @@ -224,6 +228,9 @@ private: // Read only loop marker. bool IsReadOnlyLoop = true; + // OptimizationRemarkEmitter + OptimizationRemarkEmitter *ORE; + bool isLegalForVersioning(); bool legalLoopStructure(); bool legalLoopInstructions(); @@ -235,58 +242,57 @@ private: } // end anonymous namespace -/// \brief Check loop structure and confirms it's good for LoopVersioningLICM. +/// Check loop structure and confirms it's good for LoopVersioningLICM. bool LoopVersioningLICM::legalLoopStructure() { // Loop must be in loop simplify form. if (!CurLoop->isLoopSimplifyForm()) { - DEBUG( - dbgs() << " loop is not in loop-simplify form.\n"); + LLVM_DEBUG(dbgs() << " loop is not in loop-simplify form.\n"); return false; } // Loop should be innermost loop, if not return false. 
if (!CurLoop->getSubLoops().empty()) { - DEBUG(dbgs() << " loop is not innermost\n"); + LLVM_DEBUG(dbgs() << " loop is not innermost\n"); return false; } // Loop should have a single backedge, if not return false. if (CurLoop->getNumBackEdges() != 1) { - DEBUG(dbgs() << " loop has multiple backedges\n"); + LLVM_DEBUG(dbgs() << " loop has multiple backedges\n"); return false; } // Loop must have a single exiting block, if not return false. if (!CurLoop->getExitingBlock()) { - DEBUG(dbgs() << " loop has multiple exiting block\n"); + LLVM_DEBUG(dbgs() << " loop has multiple exiting block\n"); return false; } // We only handle bottom-tested loop, i.e. loop in which the condition is // checked at the end of each iteration. With that we can assume that all // instructions in the loop are executed the same number of times. if (CurLoop->getExitingBlock() != CurLoop->getLoopLatch()) { - DEBUG(dbgs() << " loop is not bottom tested\n"); + LLVM_DEBUG(dbgs() << " loop is not bottom tested\n"); return false; } // Parallel loops must not have aliasing loop-invariant memory accesses. // Hence we don't need to version anything in this case. if (CurLoop->isAnnotatedParallel()) { - DEBUG(dbgs() << " Parallel loop is not worth versioning\n"); + LLVM_DEBUG(dbgs() << " Parallel loop is not worth versioning\n"); return false; } // Loop depth more then LoopDepthThreshold are not allowed if (CurLoop->getLoopDepth() > LoopDepthThreshold) { - DEBUG(dbgs() << " loop depth is more then threshold\n"); + LLVM_DEBUG(dbgs() << " loop depth is more then threshold\n"); return false; } // We need to be able to compute the loop trip count in order // to generate the bound checks. const SCEV *ExitCount = SE->getBackedgeTakenCount(CurLoop); if (ExitCount == SE->getCouldNotCompute()) { - DEBUG(dbgs() << " loop does not has trip count\n"); + LLVM_DEBUG(dbgs() << " loop does not has trip count\n"); return false; } return true; } -/// \brief Check memory accesses in loop and confirms it's good for +/// Check memory accesses in loop and confirms it's good for /// LoopVersioningLICM. bool LoopVersioningLICM::legalLoopMemoryAccesses() { bool HasMayAlias = false; @@ -328,24 +334,24 @@ bool LoopVersioningLICM::legalLoopMemoryAccesses() { } // Ensure types should be of same type. if (!TypeSafety) { - DEBUG(dbgs() << " Alias tracker type safety failed!\n"); + LLVM_DEBUG(dbgs() << " Alias tracker type safety failed!\n"); return false; } // Ensure loop body shouldn't be read only. if (!HasMod) { - DEBUG(dbgs() << " No memory modified in loop body\n"); + LLVM_DEBUG(dbgs() << " No memory modified in loop body\n"); return false; } // Make sure alias set has may alias case. // If there no alias memory ambiguity, return false. if (!HasMayAlias) { - DEBUG(dbgs() << " No ambiguity in memory access.\n"); + LLVM_DEBUG(dbgs() << " No ambiguity in memory access.\n"); return false; } return true; } -/// \brief Check loop instructions safe for Loop versioning. +/// Check loop instructions safe for Loop versioning. /// It returns true if it's safe else returns false. /// Consider following: /// 1) Check all load store in loop body are non atomic & non volatile. 
@@ -355,12 +361,12 @@ bool LoopVersioningLICM::instructionSafeForVersioning(Instruction *I) { assert(I != nullptr && "Null instruction found!"); // Check function call safety if (isa<CallInst>(I) && !AA->doesNotAccessMemory(CallSite(I))) { - DEBUG(dbgs() << " Unsafe call site found.\n"); + LLVM_DEBUG(dbgs() << " Unsafe call site found.\n"); return false; } // Avoid loops with possiblity of throw if (I->mayThrow()) { - DEBUG(dbgs() << " May throw instruction found in loop body\n"); + LLVM_DEBUG(dbgs() << " May throw instruction found in loop body\n"); return false; } // If current instruction is load instructions @@ -368,7 +374,7 @@ bool LoopVersioningLICM::instructionSafeForVersioning(Instruction *I) { if (I->mayReadFromMemory()) { LoadInst *Ld = dyn_cast<LoadInst>(I); if (!Ld || !Ld->isSimple()) { - DEBUG(dbgs() << " Found a non-simple load.\n"); + LLVM_DEBUG(dbgs() << " Found a non-simple load.\n"); return false; } LoadAndStoreCounter++; @@ -382,7 +388,7 @@ bool LoopVersioningLICM::instructionSafeForVersioning(Instruction *I) { else if (I->mayWriteToMemory()) { StoreInst *St = dyn_cast<StoreInst>(I); if (!St || !St->isSimple()) { - DEBUG(dbgs() << " Found a non-simple store.\n"); + LLVM_DEBUG(dbgs() << " Found a non-simple store.\n"); return false; } LoadAndStoreCounter++; @@ -396,59 +402,87 @@ bool LoopVersioningLICM::instructionSafeForVersioning(Instruction *I) { return true; } -/// \brief Check loop instructions and confirms it's good for +/// Check loop instructions and confirms it's good for /// LoopVersioningLICM. bool LoopVersioningLICM::legalLoopInstructions() { // Resetting counters. LoadAndStoreCounter = 0; InvariantCounter = 0; IsReadOnlyLoop = true; + using namespace ore; // Iterate over loop blocks and instructions of each block and check // instruction safety. for (auto *Block : CurLoop->getBlocks()) for (auto &Inst : *Block) { // If instruction is unsafe just return false. - if (!instructionSafeForVersioning(&Inst)) + if (!instructionSafeForVersioning(&Inst)) { + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "IllegalLoopInst", &Inst) + << " Unsafe Loop Instruction"; + }); return false; + } } // Get LoopAccessInfo from current loop. LAI = &LAA->getInfo(CurLoop); // Check LoopAccessInfo for need of runtime check. if (LAI->getRuntimePointerChecking()->getChecks().empty()) { - DEBUG(dbgs() << " LAA: Runtime check not found !!\n"); + LLVM_DEBUG(dbgs() << " LAA: Runtime check not found !!\n"); return false; } // Number of runtime-checks should be less then RuntimeMemoryCheckThreshold if (LAI->getNumRuntimePointerChecks() > VectorizerParams::RuntimeMemoryCheckThreshold) { - DEBUG(dbgs() << " LAA: Runtime checks are more than threshold !!\n"); + LLVM_DEBUG( + dbgs() << " LAA: Runtime checks are more than threshold !!\n"); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "RuntimeCheck", + CurLoop->getStartLoc(), + CurLoop->getHeader()) + << "Number of runtime checks " + << NV("RuntimeChecks", LAI->getNumRuntimePointerChecks()) + << " exceeds threshold " + << NV("Threshold", VectorizerParams::RuntimeMemoryCheckThreshold); + }); return false; } // Loop should have at least one invariant load or store instruction. if (!InvariantCounter) { - DEBUG(dbgs() << " Invariant not found !!\n"); + LLVM_DEBUG(dbgs() << " Invariant not found !!\n"); return false; } // Read only loop not allowed. 
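// Worked numbers for the profitability check a few lines below (threshold
// value assumed): with InvariantThreshold = 25 (percent), a loop with
// LoadAndStoreCounter = 10 and InvariantCounter = 2 gives 2 * 100 = 200 <
// 25 * 10 = 250, i.e. only 20% invariant accesses, so versioning is skipped.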
if (IsReadOnlyLoop) { - DEBUG(dbgs() << " Found a read-only loop!\n"); + LLVM_DEBUG(dbgs() << " Found a read-only loop!\n"); return false; } // Profitablity check: // Check invariant threshold, should be in limit. if (InvariantCounter * 100 < InvariantThreshold * LoadAndStoreCounter) { - DEBUG(dbgs() - << " Invariant load & store are less then defined threshold\n"); - DEBUG(dbgs() << " Invariant loads & stores: " - << ((InvariantCounter * 100) / LoadAndStoreCounter) << "%\n"); - DEBUG(dbgs() << " Invariant loads & store threshold: " - << InvariantThreshold << "%\n"); + LLVM_DEBUG( + dbgs() + << " Invariant load & store are less then defined threshold\n"); + LLVM_DEBUG(dbgs() << " Invariant loads & stores: " + << ((InvariantCounter * 100) / LoadAndStoreCounter) + << "%\n"); + LLVM_DEBUG(dbgs() << " Invariant loads & store threshold: " + << InvariantThreshold << "%\n"); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "InvariantThreshold", + CurLoop->getStartLoc(), + CurLoop->getHeader()) + << "Invariant load & store " + << NV("LoadAndStoreCounter", + ((InvariantCounter * 100) / LoadAndStoreCounter)) + << " are less then defined threshold " + << NV("Threshold", InvariantThreshold); + }); return false; } return true; } -/// \brief It checks loop is already visited or not. +/// It checks loop is already visited or not. /// check loop meta data, if loop revisited return true /// else false. bool LoopVersioningLICM::isLoopAlreadyVisited() { @@ -459,42 +493,64 @@ bool LoopVersioningLICM::isLoopAlreadyVisited() { return false; } -/// \brief Checks legality for LoopVersioningLICM by considering following: +/// Checks legality for LoopVersioningLICM by considering following: /// a) loop structure legality b) loop instruction legality /// c) loop memory access legality. /// Return true if legal else returns false. bool LoopVersioningLICM::isLegalForVersioning() { - DEBUG(dbgs() << "Loop: " << *CurLoop); + using namespace ore; + LLVM_DEBUG(dbgs() << "Loop: " << *CurLoop); // Make sure not re-visiting same loop again. if (isLoopAlreadyVisited()) { - DEBUG( + LLVM_DEBUG( dbgs() << " Revisiting loop in LoopVersioningLICM not allowed.\n\n"); return false; } // Check loop structure leagality. if (!legalLoopStructure()) { - DEBUG( + LLVM_DEBUG( dbgs() << " Loop structure not suitable for LoopVersioningLICM\n\n"); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "IllegalLoopStruct", + CurLoop->getStartLoc(), + CurLoop->getHeader()) + << " Unsafe Loop structure"; + }); return false; } // Check loop instruction leagality. if (!legalLoopInstructions()) { - DEBUG(dbgs() - << " Loop instructions not suitable for LoopVersioningLICM\n\n"); + LLVM_DEBUG( + dbgs() + << " Loop instructions not suitable for LoopVersioningLICM\n\n"); return false; } // Check loop memory access leagality. if (!legalLoopMemoryAccesses()) { - DEBUG(dbgs() - << " Loop memory access not suitable for LoopVersioningLICM\n\n"); + LLVM_DEBUG( + dbgs() + << " Loop memory access not suitable for LoopVersioningLICM\n\n"); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "IllegalLoopMemoryAccess", + CurLoop->getStartLoc(), + CurLoop->getHeader()) + << " Unsafe Loop memory access"; + }); return false; } // Loop versioning is feasible, return true. 
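// Conceptually (an illustrative sketch, not code from this patch), a loop
// found legal here is later versioned along these lines, so LICM can hoist
// out of the aggressively annotated copy:
//
//   if (/* runtime checks: pointer ranges do not overlap */) {
//     // versioned loop: memory ops carry no-alias metadata
//   } else {
//     // original loop, unchanged
//   }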
- DEBUG(dbgs() << " Loop Versioning found to be beneficial\n\n"); + LLVM_DEBUG(dbgs() << " Loop Versioning found to be beneficial\n\n"); + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "IsLegalForVersioning", + CurLoop->getStartLoc(), CurLoop->getHeader()) + << " Versioned loop for LICM." + << " Number of runtime checks we had to insert " + << NV("RuntimeChecks", LAI->getNumRuntimePointerChecks()); + }); return true; } -/// \brief Update loop with aggressive aliasing assumptions. +/// Update loop with aggressive aliasing assumptions. /// It marks no-alias to any pairs of memory operations by assuming /// loop should not have any must-alias memory accesses pairs. /// During LoopVersioningLICM legality we ignore loops having must @@ -542,6 +598,7 @@ bool LoopVersioningLICM::runOnLoop(Loop *L, LPPassManager &LPM) { AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); + ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); LAI = nullptr; // Set Current Loop CurLoop = L; @@ -592,6 +649,7 @@ INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(LoopVersioningLICM, "loop-versioning-licm", "Loop Versioning For LICM", false, false) diff --git a/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp index 46f8a3564265..68bfa0030395 100644 --- a/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ b/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -357,7 +357,7 @@ PreservedAnalyses LowerExpectIntrinsicPass::run(Function &F, } namespace { -/// \brief Legacy pass for lowering expect intrinsics out of the IR. +/// Legacy pass for lowering expect intrinsics out of the IR. 
/// /// When this pass is run over a function it uses expect intrinsics which feed /// branches and switches to provide branch weight metadata for those diff --git a/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 9c870b42a747..3b74421a47a0 100644 --- a/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -25,6 +25,7 @@ #include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" @@ -55,7 +56,6 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> #include <cstdint> @@ -263,7 +263,7 @@ public: void addMemSet(int64_t OffsetFromFirst, MemSetInst *MSI) { int64_t Size = cast<ConstantInt>(MSI->getLength())->getZExtValue(); - addRange(OffsetFromFirst, Size, MSI->getDest(), MSI->getAlignment(), MSI); + addRange(OffsetFromFirst, Size, MSI->getDest(), MSI->getDestAlignment(), MSI); } void addRange(int64_t Start, int64_t Size, Value *Ptr, @@ -479,10 +479,10 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, AMemSet = Builder.CreateMemSet(StartPtr, ByteVal, Range.End-Range.Start, Alignment); - DEBUG(dbgs() << "Replace stores:\n"; - for (Instruction *SI : Range.TheStores) - dbgs() << *SI << '\n'; - dbgs() << "With: " << *AMemSet << '\n'); + LLVM_DEBUG(dbgs() << "Replace stores:\n"; for (Instruction *SI + : Range.TheStores) dbgs() + << *SI << '\n'; + dbgs() << "With: " << *AMemSet << '\n'); if (!Range.TheStores.empty()) AMemSet->setDebugLoc(Range.TheStores[0]->getDebugLoc()); @@ -498,16 +498,25 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, return AMemSet; } -static unsigned findCommonAlignment(const DataLayout &DL, const StoreInst *SI, - const LoadInst *LI) { +static unsigned findStoreAlignment(const DataLayout &DL, const StoreInst *SI) { unsigned StoreAlign = SI->getAlignment(); if (!StoreAlign) StoreAlign = DL.getABITypeAlignment(SI->getOperand(0)->getType()); + return StoreAlign; +} + +static unsigned findLoadAlignment(const DataLayout &DL, const LoadInst *LI) { unsigned LoadAlign = LI->getAlignment(); if (!LoadAlign) LoadAlign = DL.getABITypeAlignment(LI->getType()); + return LoadAlign; +} - return std::min(StoreAlign, LoadAlign); +static unsigned findCommonAlignment(const DataLayout &DL, const StoreInst *SI, + const LoadInst *LI) { + unsigned StoreAlign = findStoreAlignment(DL, SI); + unsigned LoadAlign = findLoadAlignment(DL, LI); + return MinAlign(StoreAlign, LoadAlign); } // This method try to lift a store instruction before position P. @@ -522,7 +531,7 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P, return false; // Keep track of the arguments of all instruction we plan to lift - // so we can make sure to lift them as well if apropriate. + // so we can make sure to lift them as well if appropriate. 
DenseSet<Instruction*> Args; if (auto *Ptr = dyn_cast<Instruction>(SI->getPointerOperand())) if (Ptr->getParent() == SI->getParent()) @@ -594,7 +603,7 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P, // We made it, we need to lift for (auto *I : llvm::reverse(ToLift)) { - DEBUG(dbgs() << "Lifting " << *I << " before " << *P << "\n"); + LLVM_DEBUG(dbgs() << "Lifting " << *I << " before " << *P << "\n"); I->moveBefore(P); } @@ -656,22 +665,23 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { if (!AA.isNoAlias(MemoryLocation::get(SI), LoadLoc)) UseMemMove = true; - unsigned Align = findCommonAlignment(DL, SI, LI); uint64_t Size = DL.getTypeStoreSize(T); IRBuilder<> Builder(P); Instruction *M; if (UseMemMove) - M = Builder.CreateMemMove(SI->getPointerOperand(), - LI->getPointerOperand(), Size, - Align, SI->isVolatile()); + M = Builder.CreateMemMove( + SI->getPointerOperand(), findStoreAlignment(DL, SI), + LI->getPointerOperand(), findLoadAlignment(DL, LI), Size, + SI->isVolatile()); else - M = Builder.CreateMemCpy(SI->getPointerOperand(), - LI->getPointerOperand(), Size, - Align, SI->isVolatile()); + M = Builder.CreateMemCpy( + SI->getPointerOperand(), findStoreAlignment(DL, SI), + LI->getPointerOperand(), findLoadAlignment(DL, LI), Size, + SI->isVolatile()); - DEBUG(dbgs() << "Promoting " << *LI << " to " << *SI - << " => " << *M << "\n"); + LLVM_DEBUG(dbgs() << "Promoting " << *LI << " to " << *SI << " => " + << *M << "\n"); MD->removeInstruction(SI); SI->eraseFromParent(); @@ -760,7 +770,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { auto *M = Builder.CreateMemSet(SI->getPointerOperand(), ByteVal, Size, Align, SI->isVolatile()); - DEBUG(dbgs() << "Promoting " << *SI << " to " << *M << "\n"); + LLVM_DEBUG(dbgs() << "Promoting " << *SI << " to " << *M << "\n"); MD->removeInstruction(SI); SI->eraseFromParent(); @@ -1047,20 +1057,17 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, // If all checks passed, then we can transform M. - // Make sure to use the lesser of the alignment of the source and the dest - // since we're changing where we're reading from, but don't want to increase - // the alignment past what can be read from or written to. // TODO: Is this worth it if we're creating a less aligned memcpy? For // example we could be moving from movaps -> movq on x86. - unsigned Align = std::min(MDep->getAlignment(), M->getAlignment()); - IRBuilder<> Builder(M); if (UseMemMove) - Builder.CreateMemMove(M->getRawDest(), MDep->getRawSource(), M->getLength(), - Align, M->isVolatile()); + Builder.CreateMemMove(M->getRawDest(), M->getDestAlignment(), + MDep->getRawSource(), MDep->getSourceAlignment(), + M->getLength(), M->isVolatile()); else - Builder.CreateMemCpy(M->getRawDest(), MDep->getRawSource(), M->getLength(), - Align, M->isVolatile()); + Builder.CreateMemCpy(M->getRawDest(), M->getDestAlignment(), + MDep->getRawSource(), MDep->getSourceAlignment(), + M->getLength(), M->isVolatile()); // Remove the instruction we're replacing. MD->removeInstruction(M); @@ -1106,7 +1113,7 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, // If Dest is aligned, and SrcSize is constant, use the minimum alignment // of the sum. 
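The memset/memcpy hunk continuing just below derives the alignment of the tail region Dest + SrcSize as MinAlign(SrcSize, DestAlign). The arithmetic behind that line, as a hedged standalone check (the names here are placeholders, not the pass's variables): if DestAlign divides the base address, the advanced pointer is only guaranteed the largest power of two dividing both DestAlign and the offset.

#include <cassert>
#include <cstdint>

static uint64_t minAlignSketch(uint64_t A, uint64_t B) {
  return (A | B) & (1 + ~(A | B)); // largest common power-of-two divisor
}

int main() {
  const uint64_t DestAlign = 16; // assumed alignment of the destination
  const uint64_t SrcSize = 4;    // constant byte offset of the memset tail
  // 16 divides Dest but not Dest + 4; 4 divides both, so the tail store
  // may only claim 4-byte alignment.
  assert(minAlignSketch(SrcSize, DestAlign) == 4);
  return 0;
}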
const unsigned DestAlign = - std::max(MemSet->getAlignment(), MemCpy->getAlignment()); + std::max(MemSet->getDestAlignment(), MemCpy->getDestAlignment()); if (DestAlign > 1) if (ConstantInt *SrcSizeC = dyn_cast<ConstantInt>(SrcSize)) Align = MinAlign(SrcSizeC->getZExtValue(), DestAlign); @@ -1166,7 +1173,7 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, IRBuilder<> Builder(MemCpy); Builder.CreateMemSet(MemCpy->getRawDest(), MemSet->getOperand(1), - CopySize, MemCpy->getAlignment()); + CopySize, MemCpy->getDestAlignment()); return true; } @@ -1192,7 +1199,7 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M) { if (Value *ByteVal = isBytewiseValue(GV->getInitializer())) { IRBuilder<> Builder(M); Builder.CreateMemSet(M->getRawDest(), ByteVal, M->getLength(), - M->getAlignment(), false); + M->getDestAlignment(), false); MD->removeInstruction(M); M->eraseFromParent(); ++NumCpyToSet; @@ -1221,8 +1228,11 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M) { // d) memcpy from a just-memset'd source can be turned into memset. if (DepInfo.isClobber()) { if (CallInst *C = dyn_cast<CallInst>(DepInfo.getInst())) { + // FIXME: Can we pass in either of dest/src alignment here instead + // of conservatively taking the minimum? + unsigned Align = MinAlign(M->getDestAlignment(), M->getSourceAlignment()); if (performCallSlotOptzn(M, M->getDest(), M->getSource(), - CopySize->getZExtValue(), M->getAlignment(), + CopySize->getZExtValue(), Align, C)) { MD->removeInstruction(M); M->eraseFromParent(); @@ -1284,8 +1294,8 @@ bool MemCpyOptPass::processMemMove(MemMoveInst *M) { MemoryLocation::getForSource(M))) return false; - DEBUG(dbgs() << "MemCpyOptPass: Optimizing memmove -> memcpy: " << *M - << "\n"); + LLVM_DEBUG(dbgs() << "MemCpyOptPass: Optimizing memmove -> memcpy: " << *M + << "\n"); // If not, then we know we can transform this. Type *ArgTys[3] = { M->getRawDest()->getType(), @@ -1337,7 +1347,7 @@ bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) { // source of the memcpy to the alignment we need. If we fail, we bail out. AssumptionCache &AC = LookupAssumptionCache(); DominatorTree &DT = LookupDomTree(); - if (MDep->getAlignment() < ByValAlign && + if (MDep->getSourceAlignment() < ByValAlign && getOrEnforceKnownAlignment(MDep->getSource(), ByValAlign, DL, CS.getInstruction(), &AC, &DT) < ByValAlign) return false; @@ -1367,9 +1377,9 @@ bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) { TmpCast = new BitCastInst(MDep->getSource(), ByValArg->getType(), "tmpcast", CS.getInstruction()); - DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy to byval:\n" - << " " << *MDep << "\n" - << " " << *CS.getInstruction() << "\n"); + LLVM_DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy to byval:\n" + << " " << *MDep << "\n" + << " " << *CS.getInstruction() << "\n"); // Otherwise we're good! Update the byval argument. CS.setArgument(ArgNo, TmpCast); @@ -1381,10 +1391,19 @@ bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) { bool MemCpyOptPass::iterateOnFunction(Function &F) { bool MadeChange = false; + DominatorTree &DT = LookupDomTree(); + // Walk all instruction in the function. for (BasicBlock &BB : F) { + // Skip unreachable blocks. For example processStore assumes that an + // instruction in a BB can't be dominated by a later instruction in the + // same BB (which is a scenario that can happen for an unreachable BB that + // has itself as a predecessor). 
+    if (!DT.isReachableFromEntry(&BB))
+      continue;
+
     for (BasicBlock::iterator BI = BB.begin(), BE = BB.end(); BI != BE;) {
-      // Avoid invalidating the iterator.
+      // Avoid invalidating the iterator.
       Instruction *I = &*BI++;
 
       bool RepeatInstruction = false;
diff --git a/lib/Transforms/Scalar/MergeICmps.cpp b/lib/Transforms/Scalar/MergeICmps.cpp
index 9869a3fb96fa..ff0183a8ea2d 100644
--- a/lib/Transforms/Scalar/MergeICmps.cpp
+++ b/lib/Transforms/Scalar/MergeICmps.cpp
@@ -71,30 +71,30 @@ struct BCEAtom {
 };
 
 // If this value is a load from a constant offset w.r.t. a base address, and
-// there are no othe rusers of the load or address, returns the base address and
+// there are no other users of the load or address, returns the base address and
 // the offset.
 BCEAtom visitICmpLoadOperand(Value *const Val) {
   BCEAtom Result;
   if (auto *const LoadI = dyn_cast<LoadInst>(Val)) {
-    DEBUG(dbgs() << "load\n");
+    LLVM_DEBUG(dbgs() << "load\n");
     if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) {
-      DEBUG(dbgs() << "used outside of block\n");
+      LLVM_DEBUG(dbgs() << "used outside of block\n");
       return {};
     }
     if (LoadI->isVolatile()) {
-      DEBUG(dbgs() << "volatile\n");
+      LLVM_DEBUG(dbgs() << "volatile\n");
       return {};
     }
     Value *const Addr = LoadI->getOperand(0);
     if (auto *const GEP = dyn_cast<GetElementPtrInst>(Addr)) {
-      DEBUG(dbgs() << "GEP\n");
+      LLVM_DEBUG(dbgs() << "GEP\n");
      if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) {
-        DEBUG(dbgs() << "used outside of block\n");
+        LLVM_DEBUG(dbgs() << "used outside of block\n");
        return {};
      }
      const auto &DL = GEP->getModule()->getDataLayout();
      if (!isDereferenceablePointer(GEP, DL)) {
-        DEBUG(dbgs() << "not dereferenceable\n");
+        LLVM_DEBUG(dbgs() << "not dereferenceable\n");
        // We need to make sure that we can do comparison in any order, so we
        // require memory to be unconditionally dereferenceable.
        return {};
@@ -110,6 +110,10 @@ BCEAtom visitICmpLoadOperand(Value *const Val) {
 }
 
 // A basic block with a comparison between two BCE atoms.
+// The block might do extra work besides the atom comparison, in which case
+// doesOtherWork() returns true. Under some conditions, the block can be
+// split into the atom comparison part and the "other work" part
+// (see canSplit()).
 // Note: the terminology is misleading: the comparison is symmetric, so there
 // is no real {l/r}hs. What we want though is to have the same base on the
 // left (resp. right), so that we can detect consecutive loads. To ensure this
@@ -127,7 +131,7 @@ class BCECmpBlock {
     return Lhs_.Base() != nullptr && Rhs_.Base() != nullptr;
   }
 
-  // Assert the the block is consistent: If valid, it should also have
+  // Assert the block is consistent: If valid, it should also have
   // non-null members besides Lhs_ and Rhs_.
   void AssertConsistent() const {
     if (IsValid()) {
@@ -144,37 +148,95 @@ class BCECmpBlock {
   // Returns true if the block does other work besides comparison.
   bool doesOtherWork() const;
 
+  // Returns true if the non-BCE-cmp instructions can be separated from BCE-cmp
+  // instructions in the block.
+  bool canSplit() const;
+
+  // Return true if all the relevant instructions in the BCE-cmp-block can
+  // be sunk below this instruction. By doing this, we know we can separate the
+  // BCE-cmp-block instructions from the non-BCE-cmp-block instructions in the
+  // block.
+  bool canSinkBCECmpInst(const Instruction *, DenseSet<Instruction *> &) const;
+
+  // We can separate the BCE-cmp-block instructions and the non-BCE-cmp-block
+  // instructions. Split the old block and move all non-BCE-cmp-insts into the
+  // new parent block.
+  void split(BasicBlock *NewParent) const;
+
   // The basic block where this comparison happens.
   BasicBlock *BB = nullptr;
   // The ICMP for this comparison.
   ICmpInst *CmpI = nullptr;
   // The terminating branch.
   BranchInst *BranchI = nullptr;
+  // The block requires splitting.
+  bool RequireSplit = false;
 
- private:
+private:
   BCEAtom Lhs_;
   BCEAtom Rhs_;
   int SizeBits_ = 0;
 };
 
+bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst,
+                                    DenseSet<Instruction *> &BlockInsts) const {
+  // If this instruction has side effects and it's in the middle of the BCE
+  // cmp block instructions, then bail for now.
+  // TODO: use alias analysis to tell whether there is real interference.
+  if (Inst->mayHaveSideEffects())
+    return false;
+  // Make sure this instruction does not use any of the BCE cmp block
+  // instructions as operand.
+  for (auto BI : BlockInsts) {
+    if (is_contained(Inst->operands(), BI))
+      return false;
+  }
+  return true;
+}
+
+void BCECmpBlock::split(BasicBlock *NewParent) const {
+  DenseSet<Instruction *> BlockInsts(
+      {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI});
+  llvm::SmallVector<Instruction *, 4> OtherInsts;
+  for (Instruction &Inst : *BB) {
+    if (BlockInsts.count(&Inst))
+      continue;
+    assert(canSinkBCECmpInst(&Inst, BlockInsts) && "Split unsplittable block");
+    // This is a non-BCE-cmp-block instruction. And it can be separated
+    // from the BCE-cmp-block instruction.
+    OtherInsts.push_back(&Inst);
+  }
+
+  // Do the actual splitting.
+  for (Instruction *Inst : reverse(OtherInsts)) {
+    Inst->moveBefore(&*NewParent->begin());
+  }
+}
+
+bool BCECmpBlock::canSplit() const {
+  DenseSet<Instruction *> BlockInsts(
+      {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI});
+  for (Instruction &Inst : *BB) {
+    if (!BlockInsts.count(&Inst)) {
+      if (!canSinkBCECmpInst(&Inst, BlockInsts))
+        return false;
+    }
+  }
+  return true;
+}
+
 bool BCECmpBlock::doesOtherWork() const {
   AssertConsistent();
+  // All the instructions we care about in the BCE cmp block.
+  DenseSet<Instruction *> BlockInsts(
+      {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI});
   // TODO(courbet): Can we allow some other things ? This is very conservative.
-  // We might be able to get away with anything does does not have any side
+  // We might be able to get away with anything that does not have any side
   // effects outside of the basic block.
   // Note: The GEPs and/or loads are not necessarily in the same block.
   for (const Instruction &Inst : *BB) {
-    if (const auto *const GEP = dyn_cast<GetElementPtrInst>(&Inst)) {
-      if (!(Lhs_.GEP == GEP || Rhs_.GEP == GEP)) return true;
-    } else if (const auto *const L = dyn_cast<LoadInst>(&Inst)) {
-      if (!(Lhs_.LoadI == L || Rhs_.LoadI == L)) return true;
-    } else if (const auto *const C = dyn_cast<ICmpInst>(&Inst)) {
-      if (C != CmpI) return true;
-    } else if (const auto *const Br = dyn_cast<BranchInst>(&Inst)) {
-      if (Br != BranchI) return true;
-    } else {
+    if (!BlockInsts.count(&Inst))
       return true;
-    }
   }
   return false;
 }
@@ -183,10 +245,19 @@ bool BCECmpBlock::doesOtherWork() const {
 // BCE atoms, returns the comparison.
 BCECmpBlock visitICmp(const ICmpInst *const CmpI,
                       const ICmpInst::Predicate ExpectedPredicate) {
+  // The comparison can only be used once:
+  //  - For intermediate blocks, as a branch condition.
+  //  - For the final block, as an incoming value for the Phi.
+ // If there are any other uses of the comparison, we cannot merge it with + // other comparisons as we would create an orphan use of the value. + if (!CmpI->hasOneUse()) { + LLVM_DEBUG(dbgs() << "cmp has several uses\n"); + return {}; + } if (CmpI->getPredicate() == ExpectedPredicate) { - DEBUG(dbgs() << "cmp " - << (ExpectedPredicate == ICmpInst::ICMP_EQ ? "eq" : "ne") - << "\n"); + LLVM_DEBUG(dbgs() << "cmp " + << (ExpectedPredicate == ICmpInst::ICMP_EQ ? "eq" : "ne") + << "\n"); auto Lhs = visitICmpLoadOperand(CmpI->getOperand(0)); if (!Lhs.Base()) return {}; auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1)); @@ -204,7 +275,7 @@ BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, if (Block->empty()) return {}; auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator()); if (!BranchI) return {}; - DEBUG(dbgs() << "branch\n"); + LLVM_DEBUG(dbgs() << "branch\n"); if (BranchI->isUnconditional()) { // In this case, we expect an incoming value which is the result of the // comparison. This is the last link in the chain of comparisons (note @@ -212,7 +283,7 @@ BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, // can be reordered). auto *const CmpI = dyn_cast<ICmpInst>(Val); if (!CmpI) return {}; - DEBUG(dbgs() << "icmp\n"); + LLVM_DEBUG(dbgs() << "icmp\n"); auto Result = visitICmp(CmpI, ICmpInst::ICMP_EQ); Result.CmpI = CmpI; Result.BranchI = BranchI; @@ -221,12 +292,12 @@ BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, // In this case, we expect a constant incoming value (the comparison is // chained). const auto *const Const = dyn_cast<ConstantInt>(Val); - DEBUG(dbgs() << "const\n"); + LLVM_DEBUG(dbgs() << "const\n"); if (!Const->isZero()) return {}; - DEBUG(dbgs() << "false\n"); + LLVM_DEBUG(dbgs() << "false\n"); auto *const CmpI = dyn_cast<ICmpInst>(BranchI->getCondition()); if (!CmpI) return {}; - DEBUG(dbgs() << "icmp\n"); + LLVM_DEBUG(dbgs() << "icmp\n"); assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch"); BasicBlock *const FalseBlock = BranchI->getSuccessor(1); auto Result = visitICmp( @@ -238,6 +309,18 @@ BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, return {}; } +static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons, + BCECmpBlock &Comparison) { + LLVM_DEBUG(dbgs() << "Block '" << Comparison.BB->getName() + << "': Found cmp of " << Comparison.SizeBits() + << " bits between " << Comparison.Lhs().Base() << " + " + << Comparison.Lhs().Offset << " and " + << Comparison.Rhs().Base() << " + " + << Comparison.Rhs().Offset << "\n"); + LLVM_DEBUG(dbgs() << "\n"); + Comparisons.push_back(Comparison); +} + // A chain of comparisons. class BCECmpChain { public: @@ -263,9 +346,9 @@ class BCECmpChain { // Merges the given comparison blocks into one memcmp block and update // branches. Comparisons are assumed to be continguous. If NextBBInChain is // null, the merged block will link to the phi block. 
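Stepping back from the chain mechanics for a moment: the transformation this file builds toward, shown as a hand-written source-level illustration (not output of the pass, and valid only because the struct below has no padding bytes), is collapsing a chain of adjacent-field equality compares into one memcmp.

#include <cassert>
#include <cstring>

struct Triple { int a; int b; int c; }; // contiguous, padding-free fields

// The chain of BCE compare blocks the pass recognizes...
static bool chainedEq(const Triple &X, const Triple &Y) {
  return X.a == Y.a && X.b == Y.b && X.c == Y.c;
}

// ...and the single-libcall form it emits via emitMemCmp().
static bool mergedEq(const Triple &X, const Triple &Y) {
  return std::memcmp(&X, &Y, sizeof(Triple)) == 0;
}

int main() {
  Triple P{1, 2, 3}, Q{1, 2, 3}, R{1, 2, 4};
  assert(chainedEq(P, Q) == mergedEq(P, Q));
  assert(chainedEq(P, R) == mergedEq(P, R));
  return 0;
}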
- static void mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, - BasicBlock *const NextBBInChain, PHINode &Phi, - const TargetLibraryInfo *const TLI); + void mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, + BasicBlock *const NextBBInChain, PHINode &Phi, + const TargetLibraryInfo *const TLI); PHINode &Phi_; std::vector<BCECmpBlock> Comparisons_; @@ -275,24 +358,47 @@ class BCECmpChain { BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi) : Phi_(Phi) { + assert(!Blocks.empty() && "a chain should have at least one block"); // Now look inside blocks to check for BCE comparisons. std::vector<BCECmpBlock> Comparisons; - for (BasicBlock *Block : Blocks) { + for (size_t BlockIdx = 0; BlockIdx < Blocks.size(); ++BlockIdx) { + BasicBlock *const Block = Blocks[BlockIdx]; + assert(Block && "invalid block"); BCECmpBlock Comparison = visitCmpBlock(Phi.getIncomingValueForBlock(Block), Block, Phi.getParent()); Comparison.BB = Block; if (!Comparison.IsValid()) { - DEBUG(dbgs() << "skip: not a valid BCECmpBlock\n"); + LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n"); return; } if (Comparison.doesOtherWork()) { - DEBUG(dbgs() << "block does extra work besides compare\n"); - if (Comparisons.empty()) { // First block. - // TODO(courbet): The first block can do other things, and we should - // split them apart in a separate block before the comparison chain. - // Right now we just discard it and make the chain shorter. - DEBUG(dbgs() - << "ignoring first block that does extra work besides compare\n"); + LLVM_DEBUG(dbgs() << "block '" << Comparison.BB->getName() + << "' does extra work besides compare\n"); + if (Comparisons.empty()) { + // This is the initial block in the chain, in case this block does other + // work, we can try to split the block and move the irrelevant + // instructions to the predecessor. + // + // If this is not the initial block in the chain, splitting it wont + // work. + // + // As once split, there will still be instructions before the BCE cmp + // instructions that do other work in program order, i.e. within the + // chain before sorting. Unless we can abort the chain at this point + // and start anew. + // + // NOTE: we only handle block with single predecessor for now. + if (Comparison.canSplit()) { + LLVM_DEBUG(dbgs() + << "Split initial block '" << Comparison.BB->getName() + << "' that does extra work besides compare\n"); + Comparison.RequireSplit = true; + enqueueBlock(Comparisons, Comparison); + } else { + LLVM_DEBUG(dbgs() + << "ignoring initial block '" << Comparison.BB->getName() + << "' that does extra work besides compare\n"); + } continue; } // TODO(courbet): Right now we abort the whole chain. We could be @@ -320,13 +426,13 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi) // We could still merge bb1 and bb2 though. return; } - DEBUG(dbgs() << "*Found cmp of " << Comparison.SizeBits() - << " bits between " << Comparison.Lhs().Base() << " + " - << Comparison.Lhs().Offset << " and " - << Comparison.Rhs().Base() << " + " << Comparison.Rhs().Offset - << "\n"); - DEBUG(dbgs() << "\n"); - Comparisons.push_back(Comparison); + enqueueBlock(Comparisons, Comparison); + } + + // It is possible we have no suitable comparison to merge. 
+ if (Comparisons.empty()) { + LLVM_DEBUG(dbgs() << "chain with no BCE basic blocks, no merge\n"); + return; } EntryBlock_ = Comparisons[0].BB; Comparisons_ = std::move(Comparisons); @@ -336,10 +442,10 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi) #endif // MERGEICMPS_DOT_ON // Reorder blocks by LHS. We can do that without changing the // semantics because we are only accessing dereferencable memory. - std::sort(Comparisons_.begin(), Comparisons_.end(), - [](const BCECmpBlock &a, const BCECmpBlock &b) { - return a.Lhs() < b.Lhs(); - }); + llvm::sort(Comparisons_.begin(), Comparisons_.end(), + [](const BCECmpBlock &a, const BCECmpBlock &b) { + return a.Lhs() < b.Lhs(); + }); #ifdef MERGEICMPS_DOT_ON errs() << "AFTER REORDERING:\n\n"; dump(); @@ -389,10 +495,24 @@ bool BCECmpChain::simplify(const TargetLibraryInfo *const TLI) { Phi_.removeIncomingValue(Comparison.BB, false); } + // If entry block is part of the chain, we need to make the first block + // of the chain the new entry block of the function. + BasicBlock *Entry = &Comparisons_[0].BB->getParent()->getEntryBlock(); + for (size_t I = 1; I < Comparisons_.size(); ++I) { + if (Entry == Comparisons_[I].BB) { + BasicBlock *NEntryBB = BasicBlock::Create(Entry->getContext(), "", + Entry->getParent(), Entry); + BranchInst::Create(Entry, NEntryBB); + break; + } + } + // Point the predecessors of the chain to the first comparison block (which is - // the new entry point). - if (EntryBlock_ != Comparisons_[0].BB) + // the new entry point) and update the entry block of the chain. + if (EntryBlock_ != Comparisons_[0].BB) { EntryBlock_->replaceAllUsesWith(Comparisons_[0].BB); + EntryBlock_ = Comparisons_[0].BB; + } // Effectively merge blocks. int NumMerged = 1; @@ -424,7 +544,15 @@ void BCECmpChain::mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, LLVMContext &Context = BB->getContext(); if (Comparisons.size() >= 2) { - DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons\n"); + // If there is one block that requires splitting, we do it now, i.e. + // just before we know we will collapse the chain. The instructions + // can be executed before any of the instructions in the chain. + auto C = std::find_if(Comparisons.begin(), Comparisons.end(), + [](const BCECmpBlock &B) { return B.RequireSplit; }); + if (C != Comparisons.end()) + C->split(EntryBlock_); + + LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons\n"); const auto TotalSize = std::accumulate(Comparisons.begin(), Comparisons.end(), 0, [](int Size, const BCECmpBlock &C) { @@ -445,7 +573,8 @@ void BCECmpChain::mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, IRBuilder<> Builder(BB); const auto &DL = Phi.getModule()->getDataLayout(); Value *const MemCmpCall = emitMemCmp( - FirstComparison.Lhs().GEP, FirstComparison.Rhs().GEP, ConstantInt::get(DL.getIntPtrType(Context), TotalSize), + FirstComparison.Lhs().GEP, FirstComparison.Rhs().GEP, + ConstantInt::get(DL.getIntPtrType(Context), TotalSize), Builder, DL, TLI); Value *const MemCmpIsZero = Builder.CreateICmpEQ( MemCmpCall, ConstantInt::get(Type::getInt32Ty(Context), 0)); @@ -468,17 +597,17 @@ void BCECmpChain::mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, } else { assert(Comparisons.size() == 1); // There are no blocks to merge, but we still need to update the branches. 
- DEBUG(dbgs() << "Only one comparison, updating branches\n"); + LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n"); if (NextBBInChain) { if (FirstComparison.BranchI->isConditional()) { - DEBUG(dbgs() << "conditional -> conditional\n"); + LLVM_DEBUG(dbgs() << "conditional -> conditional\n"); // Just update the "true" target, the "false" target should already be // the phi block. assert(FirstComparison.BranchI->getSuccessor(1) == Phi.getParent()); FirstComparison.BranchI->setSuccessor(0, NextBBInChain); Phi.addIncoming(ConstantInt::getFalse(Context), BB); } else { - DEBUG(dbgs() << "unconditional -> conditional\n"); + LLVM_DEBUG(dbgs() << "unconditional -> conditional\n"); // Replace the unconditional branch by a conditional one. FirstComparison.BranchI->eraseFromParent(); IRBuilder<> Builder(BB); @@ -488,14 +617,14 @@ void BCECmpChain::mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, } } else { if (FirstComparison.BranchI->isConditional()) { - DEBUG(dbgs() << "conditional -> unconditional\n"); + LLVM_DEBUG(dbgs() << "conditional -> unconditional\n"); // Replace the conditional branch by an unconditional one. FirstComparison.BranchI->eraseFromParent(); IRBuilder<> Builder(BB); Builder.CreateBr(Phi.getParent()); Phi.addIncoming(FirstComparison.CmpI, BB); } else { - DEBUG(dbgs() << "unconditional -> unconditional\n"); + LLVM_DEBUG(dbgs() << "unconditional -> unconditional\n"); Phi.addIncoming(FirstComparison.CmpI, BB); } } @@ -507,27 +636,28 @@ std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, int NumBlocks) { // Walk up from the last block to find other blocks. std::vector<BasicBlock *> Blocks(NumBlocks); + assert(LastBlock && "invalid last block"); BasicBlock *CurBlock = LastBlock; for (int BlockIndex = NumBlocks - 1; BlockIndex > 0; --BlockIndex) { if (CurBlock->hasAddressTaken()) { // Somebody is jumping to the block through an address, all bets are // off. - DEBUG(dbgs() << "skip: block " << BlockIndex - << " has its address taken\n"); + LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex + << " has its address taken\n"); return {}; } Blocks[BlockIndex] = CurBlock; auto *SinglePredecessor = CurBlock->getSinglePredecessor(); if (!SinglePredecessor) { // The block has two or more predecessors. - DEBUG(dbgs() << "skip: block " << BlockIndex - << " has two or more predecessors\n"); + LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex + << " has two or more predecessors\n"); return {}; } if (Phi.getBasicBlockIndex(SinglePredecessor) < 0) { // The block does not link back to the phi. - DEBUG(dbgs() << "skip: block " << BlockIndex - << " does not link back to the phi\n"); + LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex + << " does not link back to the phi\n"); return {}; } CurBlock = SinglePredecessor; @@ -537,9 +667,9 @@ std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, } bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI) { - DEBUG(dbgs() << "processPhi()\n"); + LLVM_DEBUG(dbgs() << "processPhi()\n"); if (Phi.getNumIncomingValues() <= 1) { - DEBUG(dbgs() << "skip: only one incoming value in phi\n"); + LLVM_DEBUG(dbgs() << "skip: only one incoming value in phi\n"); return false; } // We are looking for something that has the following structure: @@ -552,7 +682,7 @@ bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI) { // - The last basic block (bb4 here) must branch unconditionally to bb_phi. // It's the only block that contributes a non-constant value to the Phi. 
// - All other blocks (b1, b2, b3) must have exactly two successors, one of - // them being the the phi block. + // them being the phi block. // - All intermediate blocks (bb2, bb3) must have only one predecessor. // - Blocks cannot do other work besides the comparison, see doesOtherWork() @@ -563,18 +693,31 @@ bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI) { if (isa<ConstantInt>(Phi.getIncomingValue(I))) continue; if (LastBlock) { // There are several non-constant values. - DEBUG(dbgs() << "skip: several non-constant values\n"); + LLVM_DEBUG(dbgs() << "skip: several non-constant values\n"); + return false; + } + if (!isa<ICmpInst>(Phi.getIncomingValue(I)) || + cast<ICmpInst>(Phi.getIncomingValue(I))->getParent() != + Phi.getIncomingBlock(I)) { + // Non-constant incoming value is not from a cmp instruction or not + // produced by the last block. We could end up processing the value + // producing block more than once. + // + // This is an uncommon case, so we bail. + LLVM_DEBUG( + dbgs() + << "skip: non-constant value not from cmp or not from last block.\n"); return false; } LastBlock = Phi.getIncomingBlock(I); } if (!LastBlock) { // There is no non-constant block. - DEBUG(dbgs() << "skip: no non-constant block\n"); + LLVM_DEBUG(dbgs() << "skip: no non-constant block\n"); return false; } if (LastBlock->getSingleSuccessor() != Phi.getParent()) { - DEBUG(dbgs() << "skip: last block non-phi successor\n"); + LLVM_DEBUG(dbgs() << "skip: last block non-phi successor\n"); return false; } @@ -584,7 +727,7 @@ bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI) { BCECmpChain CmpChain(Blocks, Phi); if (CmpChain.size() < 2) { - DEBUG(dbgs() << "skip: only one compare block\n"); + LLVM_DEBUG(dbgs() << "skip: only one compare block\n"); return false; } @@ -619,12 +762,16 @@ class MergeICmps : public FunctionPass { PreservedAnalyses MergeICmps::runImpl(Function &F, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI) { - DEBUG(dbgs() << "MergeICmpsPass: " << F.getName() << "\n"); + LLVM_DEBUG(dbgs() << "MergeICmpsPass: " << F.getName() << "\n"); // We only try merging comparisons if the target wants to expand memcmp later. // The rationale is to avoid turning small chains into memcmp calls. if (!TTI->enableMemCmpExpansion(true)) return PreservedAnalyses::all(); + // If we don't have memcmp avaiable we can't emit calls to it. + if (!TLI->has(LibFunc_memcmp)) + return PreservedAnalyses::all(); + bool MadeChange = false; for (auto BBIt = ++F.begin(); BBIt != F.end(); ++BBIt) { diff --git a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index f2f615cb9b0f..3464b759280f 100644 --- a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// // //! \file -//! \brief This pass performs merges of loads and stores on both sides of a +//! This pass performs merges of loads and stores on both sides of a // diamond (hammock). It hoists the loads and sinks the stores. 
// // The algorithm iteratively hoists two loads to the same address out of a @@ -80,7 +80,6 @@ #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/Loads.h" -#include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Metadata.h" #include "llvm/Support/Debug.h" @@ -97,7 +96,6 @@ namespace { // MergedLoadStoreMotion Pass //===----------------------------------------------------------------------===// class MergedLoadStoreMotion { - MemoryDependenceResults *MD = nullptr; AliasAnalysis *AA = nullptr; // The mergeLoad/Store algorithms could have Size0 * Size1 complexity, @@ -107,14 +105,9 @@ class MergedLoadStoreMotion { const int MagicCompileTimeControl = 250; public: - bool run(Function &F, MemoryDependenceResults *MD, AliasAnalysis &AA); + bool run(Function &F, AliasAnalysis &AA); private: - /// - /// \brief Remove instruction from parent and update memory dependence - /// analysis. - /// - void removeInstruction(Instruction *Inst); BasicBlock *getDiamondTail(BasicBlock *BB); bool isDiamondHead(BasicBlock *BB); // Routines for sinking stores @@ -128,23 +121,7 @@ private: } // end anonymous namespace /// -/// \brief Remove instruction from parent and update memory dependence analysis. -/// -void MergedLoadStoreMotion::removeInstruction(Instruction *Inst) { - // Notify the memory dependence analysis. - if (MD) { - MD->removeInstruction(Inst); - if (auto *LI = dyn_cast<LoadInst>(Inst)) - MD->invalidateCachedPointerInfo(LI->getPointerOperand()); - if (Inst->getType()->isPtrOrPtrVectorTy()) { - MD->invalidateCachedPointerInfo(Inst); - } - } - Inst->eraseFromParent(); -} - -/// -/// \brief Return tail block of a diamond. +/// Return tail block of a diamond. /// BasicBlock *MergedLoadStoreMotion::getDiamondTail(BasicBlock *BB) { assert(isDiamondHead(BB) && "Basic block is not head of a diamond"); @@ -152,7 +129,7 @@ BasicBlock *MergedLoadStoreMotion::getDiamondTail(BasicBlock *BB) { } /// -/// \brief True when BB is the head of a diamond (hammock) +/// True when BB is the head of a diamond (hammock) /// bool MergedLoadStoreMotion::isDiamondHead(BasicBlock *BB) { if (!BB) @@ -179,7 +156,7 @@ bool MergedLoadStoreMotion::isDiamondHead(BasicBlock *BB) { /// -/// \brief True when instruction is a sink barrier for a store +/// True when instruction is a sink barrier for a store /// located in Loc /// /// Whenever an instruction could possibly read or modify the @@ -197,13 +174,13 @@ bool MergedLoadStoreMotion::isStoreSinkBarrierInRange(const Instruction &Start, } /// -/// \brief Check if \p BB contains a store to the same address as \p SI +/// Check if \p BB contains a store to the same address as \p SI /// /// \return The store in \p when it is safe to sink. Otherwise return Null. /// StoreInst *MergedLoadStoreMotion::canSinkFromBlock(BasicBlock *BB1, StoreInst *Store0) { - DEBUG(dbgs() << "can Sink? : "; Store0->dump(); dbgs() << "\n"); + LLVM_DEBUG(dbgs() << "can Sink? 
: "; Store0->dump(); dbgs() << "\n"); BasicBlock *BB0 = Store0->getParent(); for (Instruction &Inst : reverse(*BB1)) { auto *Store1 = dyn_cast<StoreInst>(&Inst); @@ -222,7 +199,7 @@ StoreInst *MergedLoadStoreMotion::canSinkFromBlock(BasicBlock *BB1, } /// -/// \brief Create a PHI node in BB for the operands of S0 and S1 +/// Create a PHI node in BB for the operands of S0 and S1 /// PHINode *MergedLoadStoreMotion::getPHIOperand(BasicBlock *BB, StoreInst *S0, StoreInst *S1) { @@ -236,13 +213,11 @@ PHINode *MergedLoadStoreMotion::getPHIOperand(BasicBlock *BB, StoreInst *S0, &BB->front()); NewPN->addIncoming(Opd1, S0->getParent()); NewPN->addIncoming(Opd2, S1->getParent()); - if (MD && NewPN->getType()->isPtrOrPtrVectorTy()) - MD->invalidateCachedPointerInfo(NewPN); return NewPN; } /// -/// \brief Merge two stores to same address and sink into \p BB +/// Merge two stores to same address and sink into \p BB /// /// Also sinks GEP instruction computing the store address /// @@ -254,9 +229,9 @@ bool MergedLoadStoreMotion::sinkStore(BasicBlock *BB, StoreInst *S0, if (A0 && A1 && A0->isIdenticalTo(A1) && A0->hasOneUse() && (A0->getParent() == S0->getParent()) && A1->hasOneUse() && (A1->getParent() == S1->getParent()) && isa<GetElementPtrInst>(A0)) { - DEBUG(dbgs() << "Sink Instruction into BB \n"; BB->dump(); - dbgs() << "Instruction Left\n"; S0->dump(); dbgs() << "\n"; - dbgs() << "Instruction Right\n"; S1->dump(); dbgs() << "\n"); + LLVM_DEBUG(dbgs() << "Sink Instruction into BB \n"; BB->dump(); + dbgs() << "Instruction Left\n"; S0->dump(); dbgs() << "\n"; + dbgs() << "Instruction Right\n"; S1->dump(); dbgs() << "\n"); // Hoist the instruction. BasicBlock::iterator InsertPt = BB->getFirstInsertionPt(); // Intersect optional metadata. @@ -275,19 +250,19 @@ bool MergedLoadStoreMotion::sinkStore(BasicBlock *BB, StoreInst *S0, // New PHI operand? Use it. if (PHINode *NewPN = getPHIOperand(BB, S0, S1)) SNew->setOperand(0, NewPN); - removeInstruction(S0); - removeInstruction(S1); + S0->eraseFromParent(); + S1->eraseFromParent(); A0->replaceAllUsesWith(ANew); - removeInstruction(A0); + A0->eraseFromParent(); A1->replaceAllUsesWith(ANew); - removeInstruction(A1); + A1->eraseFromParent(); return true; } return false; } /// -/// \brief True when two stores are equivalent and can sink into the footer +/// True when two stores are equivalent and can sink into the footer /// /// Starting from a diamond tail block, iterate over the instructions in one /// predecessor block and try to match a store in the second predecessor. @@ -310,7 +285,8 @@ bool MergedLoadStoreMotion::mergeStores(BasicBlock *T) { return false; // No. More than 2 predecessors. 
// #Instructions in Succ1 for Compile Time Control - int Size1 = Pred1->size(); + auto InstsNoDbg = Pred1->instructionsWithoutDebug(); + int Size1 = std::distance(InstsNoDbg.begin(), InstsNoDbg.end()); int NStores = 0; for (BasicBlock::reverse_iterator RBI = Pred0->rbegin(), RBE = Pred0->rend(); @@ -338,19 +314,17 @@ bool MergedLoadStoreMotion::mergeStores(BasicBlock *T) { break; RBI = Pred0->rbegin(); RBE = Pred0->rend(); - DEBUG(dbgs() << "Search again\n"; Instruction *I = &*RBI; I->dump()); + LLVM_DEBUG(dbgs() << "Search again\n"; Instruction *I = &*RBI; I->dump()); } } return MergedStores; } -bool MergedLoadStoreMotion::run(Function &F, MemoryDependenceResults *MD, - AliasAnalysis &AA) { - this->MD = MD; +bool MergedLoadStoreMotion::run(Function &F, AliasAnalysis &AA) { this->AA = &AA; bool Changed = false; - DEBUG(dbgs() << "Instruction Merger\n"); + LLVM_DEBUG(dbgs() << "Instruction Merger\n"); // Merge unconditional branches, allowing PRE to catch more // optimization opportunities. @@ -376,15 +350,13 @@ public: } /// - /// \brief Run the transformation for each function + /// Run the transformation for each function /// bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; MergedLoadStoreMotion Impl; - auto *MDWP = getAnalysisIfAvailable<MemoryDependenceWrapperPass>(); - return Impl.run(F, MDWP ? &MDWP->getMemDep() : nullptr, - getAnalysis<AAResultsWrapperPass>().getAAResults()); + return Impl.run(F, getAnalysis<AAResultsWrapperPass>().getAAResults()); } private: @@ -392,7 +364,6 @@ private: AU.setPreservesCFG(); AU.addRequired<AAResultsWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<MemoryDependenceWrapperPass>(); } }; @@ -400,7 +371,7 @@ char MergedLoadStoreMotionLegacyPass::ID = 0; } // anonymous namespace /// -/// \brief createMergedLoadStoreMotionPass - The public interface to this file. +/// createMergedLoadStoreMotionPass - The public interface to this file. 
///
FunctionPass *llvm::createMergedLoadStoreMotionPass() {
  return new MergedLoadStoreMotionLegacyPass();
@@ -408,7 +379,6 @@ FunctionPass *llvm::createMergedLoadStoreMotionPass() {
 
 INITIALIZE_PASS_BEGIN(MergedLoadStoreMotionLegacyPass, "mldst-motion",
                       "MergedLoadStoreMotion", false, false)
-INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
 INITIALIZE_PASS_END(MergedLoadStoreMotionLegacyPass, "mldst-motion",
                     "MergedLoadStoreMotion", false, false)
@@ -416,14 +386,12 @@ INITIALIZE_PASS_END(MergedLoadStoreMotionLegacyPass, "mldst-motion",
 PreservedAnalyses
 MergedLoadStoreMotionPass::run(Function &F, FunctionAnalysisManager &AM) {
   MergedLoadStoreMotion Impl;
-  auto *MD = AM.getCachedResult<MemoryDependenceAnalysis>(F);
   auto &AA = AM.getResult<AAManager>(F);
-  if (!Impl.run(F, MD, AA))
+  if (!Impl.run(F, AA))
     return PreservedAnalyses::all();
 
   PreservedAnalyses PA;
   PA.preserveSet<CFGAnalyses>();
   PA.preserve<GlobalsAA>();
-  PA.preserve<MemoryDependenceAnalysis>();
   return PA;
 }
diff --git a/lib/Transforms/Scalar/NaryReassociate.cpp b/lib/Transforms/Scalar/NaryReassociate.cpp
index b026c8d692c3..7106ea216ad6 100644
--- a/lib/Transforms/Scalar/NaryReassociate.cpp
+++ b/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -83,6 +83,7 @@
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
@@ -105,7 +106,6 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Transforms/Scalar.h"
-#include "llvm/Transforms/Utils/Local.h"
 #include <cassert>
 #include <cstdint>
@@ -240,10 +240,17 @@ bool NaryReassociatePass::doOneIteration(Function &F) {
       Changed = true;
       SE->forgetValue(&*I);
       I->replaceAllUsesWith(NewI);
-      // If SeenExprs constains I's WeakTrackingVH, that entry will be
-      // replaced with
-      // nullptr.
+      WeakVH NewIExist = NewI;
+      // If SeenExprs/NewIExist contains I's WeakTrackingVH/WeakVH, that
+      // entry will be replaced with nullptr if deleted.
       RecursivelyDeleteTriviallyDeadInstructions(&*I, TLI);
+      if (!NewIExist) {
+        // Rare occasion where the new instruction (NewI) has been removed,
+        // probably because parts of the input code were dead from the
+        // beginning; reset the iterator and start over from the beginning.
+        I = BB->begin();
+        continue;
+      }
       I = NewI->getIterator();
     }
     // Add the rewritten instruction to SeenExprs; the original instruction
@@ -429,6 +436,9 @@ NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP,
 Instruction *NaryReassociatePass::tryReassociateBinaryOp(BinaryOperator *I) {
   Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
+  // There is no need to reassociate 0.
+ if (SE->getSCEV(I)->isZero()) + return nullptr; if (auto *NewI = tryReassociateBinaryOp(LHS, RHS, I)) return NewI; if (auto *NewI = tryReassociateBinaryOp(RHS, LHS, I)) diff --git a/lib/Transforms/Scalar/NewGVN.cpp b/lib/Transforms/Scalar/NewGVN.cpp index 9ebf2d769356..2eb887c986be 100644 --- a/lib/Transforms/Scalar/NewGVN.cpp +++ b/lib/Transforms/Scalar/NewGVN.cpp @@ -77,6 +77,7 @@ #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" @@ -105,7 +106,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVNExpression.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PredicateInfo.h" #include "llvm/Transforms/Utils/VNCoercion.h" #include <algorithm> @@ -221,13 +221,13 @@ private: Components.resize(Components.size() + 1); auto &Component = Components.back(); Component.insert(I); - DEBUG(dbgs() << "Component root is " << *I << "\n"); + LLVM_DEBUG(dbgs() << "Component root is " << *I << "\n"); InComponent.insert(I); ValueToComponent[I] = ComponentID; // Pop a component off the stack and label it. while (!Stack.empty() && Root.lookup(Stack.back()) >= OurDFS) { auto *Member = Stack.back(); - DEBUG(dbgs() << "Component member is " << *Member << "\n"); + LLVM_DEBUG(dbgs() << "Component member is " << *Member << "\n"); Component.insert(Member); InComponent.insert(Member); ValueToComponent[Member] = ComponentID; @@ -366,9 +366,8 @@ public: // True if this class has no memory members. bool definesNoMemory() const { return StoreCount == 0 && memory_empty(); } - // Return true if two congruence classes are equivalent to each other. This - // means - // that every field but the ID number and the dead field are equivalent. + // Return true if two congruence classes are equivalent to each other. This + // means that every field but the ID number and the dead field are equivalent. bool isEquivalentTo(const CongruenceClass *Other) const { if (!Other) return false; @@ -383,10 +382,12 @@ public: if (!DefiningExpr || !Other->DefiningExpr || *DefiningExpr != *Other->DefiningExpr) return false; - // We need some ordered set - std::set<Value *> AMembers(Members.begin(), Members.end()); - std::set<Value *> BMembers(Members.begin(), Members.end()); - return AMembers == BMembers; + + if (Members.size() != Other->Members.size()) + return false; + + return all_of(Members, + [&](const Value *V) { return Other->Members.count(V); }); } private: @@ -860,7 +861,7 @@ private: // Debug counter info. When verifying, we have to reset the value numbering // debug counter to the same state it started in to get the same results. - std::pair<int, int> StartingVNCounter; + int64_t StartingVNCounter; }; } // end anonymous namespace @@ -958,7 +959,8 @@ static bool isCopyOfAPHI(const Value *V) { // order. The BlockInstRange numbers are generated in an RPO walk of the basic // blocks. 
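The isEquivalentTo hunk above also fixes a latent bug visible in the removed lines: AMembers and BMembers were both built from the same Members container, so the final comparison could never fail. The replacement idiom, a size check plus per-element membership, in a standalone sketch (hypothetical element type, not the congruence-class machinery):

#include <cassert>
#include <unordered_set>
#include <vector>

// Unordered equality of two duplicate-free member collections: sizes match
// and every element of A is present in B.
static bool sameMembers(const std::vector<int> &A,
                        const std::unordered_set<int> &B) {
  if (A.size() != B.size())
    return false;
  for (int V : A)
    if (B.count(V) == 0)
      return false;
  return true;
}

int main() {
  assert(sameMembers({1, 2, 3}, {3, 2, 1}));
  assert(!sameMembers({1, 2, 3}, {1, 2, 4}));
  return 0;
}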
void NewGVN::sortPHIOps(MutableArrayRef<ValPair> Ops) const { - std::sort(Ops.begin(), Ops.end(), [&](const ValPair &P1, const ValPair &P2) { + llvm::sort(Ops.begin(), Ops.end(), + [&](const ValPair &P1, const ValPair &P2) { return BlockInstRange.lookup(P1.second).first < BlockInstRange.lookup(P2.second).first; }); @@ -1067,8 +1069,8 @@ const Expression *NewGVN::checkSimplificationResults(Expression *E, return nullptr; if (auto *C = dyn_cast<Constant>(V)) { if (I) - DEBUG(dbgs() << "Simplified " << *I << " to " - << " constant " << *C << "\n"); + LLVM_DEBUG(dbgs() << "Simplified " << *I << " to " + << " constant " << *C << "\n"); NumGVNOpsSimplified++; assert(isa<BasicExpression>(E) && "We should always have had a basic expression here"); @@ -1076,8 +1078,8 @@ const Expression *NewGVN::checkSimplificationResults(Expression *E, return createConstantExpression(C); } else if (isa<Argument>(V) || isa<GlobalVariable>(V)) { if (I) - DEBUG(dbgs() << "Simplified " << *I << " to " - << " variable " << *V << "\n"); + LLVM_DEBUG(dbgs() << "Simplified " << *I << " to " + << " variable " << *V << "\n"); deleteExpression(E); return createVariableExpression(V); } @@ -1100,8 +1102,8 @@ const Expression *NewGVN::checkSimplificationResults(Expression *E, } if (I) - DEBUG(dbgs() << "Simplified " << *I << " to " - << " expression " << *CC->getDefiningExpr() << "\n"); + LLVM_DEBUG(dbgs() << "Simplified " << *I << " to " + << " expression " << *CC->getDefiningExpr() << "\n"); NumGVNOpsSimplified++; deleteExpression(E); return CC->getDefiningExpr(); @@ -1257,7 +1259,7 @@ bool NewGVN::someEquivalentDominates(const Instruction *Inst, // This must be an instruction because we are only called from phi nodes // in the case that the value it needs to check against is an instruction. - // The most likely candiates for dominance are the leader and the next leader. + // The most likely candidates for dominance are the leader and the next leader. // The leader or nextleader will dominate in all cases where there is an // equivalent that is higher up in the dom tree. 
// We can't *only* check them, however, because the @@ -1421,8 +1423,8 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr, if (Offset >= 0) { if (auto *C = dyn_cast<Constant>( lookupOperandLeader(DepSI->getValueOperand()))) { - DEBUG(dbgs() << "Coercing load from store " << *DepSI << " to constant " - << *C << "\n"); + LLVM_DEBUG(dbgs() << "Coercing load from store " << *DepSI + << " to constant " << *C << "\n"); return createConstantExpression( getConstantStoreValueForLoad(C, Offset, LoadType, DL)); } @@ -1437,8 +1439,8 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr, if (auto *C = dyn_cast<Constant>(lookupOperandLeader(DepLI))) if (auto *PossibleConstant = getConstantLoadValueForLoad(C, Offset, LoadType, DL)) { - DEBUG(dbgs() << "Coercing load from load " << *LI << " to constant " - << *PossibleConstant << "\n"); + LLVM_DEBUG(dbgs() << "Coercing load from load " << *LI + << " to constant " << *PossibleConstant << "\n"); return createConstantExpression(PossibleConstant); } } @@ -1447,8 +1449,8 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr, if (Offset >= 0) { if (auto *PossibleConstant = getConstantMemInstValueForLoad(DepMI, Offset, LoadType, DL)) { - DEBUG(dbgs() << "Coercing load from meminst " << *DepMI - << " to constant " << *PossibleConstant << "\n"); + LLVM_DEBUG(dbgs() << "Coercing load from meminst " << *DepMI + << " to constant " << *PossibleConstant << "\n"); return createConstantExpression(PossibleConstant); } } @@ -1529,7 +1531,7 @@ NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) const { if (!PI) return nullptr; - DEBUG(dbgs() << "Found predicate info from instruction !\n"); + LLVM_DEBUG(dbgs() << "Found predicate info from instruction !\n"); auto *PWC = dyn_cast<PredicateWithCondition>(PI); if (!PWC) @@ -1569,7 +1571,7 @@ NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) const { return nullptr; if (CopyOf != Cmp->getOperand(0) && CopyOf != Cmp->getOperand(1)) { - DEBUG(dbgs() << "Copy is not of any condition operands!\n"); + LLVM_DEBUG(dbgs() << "Copy is not of any condition operands!\n"); return nullptr; } Value *FirstOp = lookupOperandLeader(Cmp->getOperand(0)); @@ -1584,11 +1586,11 @@ NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) const { SwappedOps ? Cmp->getSwappedPredicate() : Cmp->getPredicate(); if (isa<PredicateAssume>(PI)) { - // If the comparison is true when the operands are equal, then we know the - // operands are equal, because assumes must always be true. - if (CmpInst::isTrueWhenEqual(Predicate)) { + // If we assume the operands are equal, then they are equal. + if (Predicate == CmpInst::ICMP_EQ) { addPredicateUsers(PI, I); - addAdditionalUsers(Cmp->getOperand(0), I); + addAdditionalUsers(SwappedOps ? Cmp->getOperand(1) : Cmp->getOperand(0), + I); return createVariableOrConstant(FirstOp); } } @@ -1622,7 +1624,7 @@ NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) const { const Expression *NewGVN::performSymbolicCallEvaluation(Instruction *I) const { auto *CI = cast<CallInst>(I); if (auto *II = dyn_cast<IntrinsicInst>(I)) { - // Instrinsics with the returned attribute are copies of arguments. + // Intrinsics with the returned attribute are copies of arguments. 
if (auto *ReturnedValue = II->getReturnedArgOperand()) { if (II->getIntrinsicID() == Intrinsic::ssa_copy) if (const auto *Result = performSymbolicPredicateInfoEvaluation(I)) @@ -1652,10 +1654,11 @@ bool NewGVN::setMemoryClass(const MemoryAccess *From, CongruenceClass *NewClass) { assert(NewClass && "Every MemoryAccess should be getting mapped to a non-null class"); - DEBUG(dbgs() << "Setting " << *From); - DEBUG(dbgs() << " equivalent to congruence class "); - DEBUG(dbgs() << NewClass->getID() << " with current MemoryAccess leader "); - DEBUG(dbgs() << *NewClass->getMemoryLeader() << "\n"); + LLVM_DEBUG(dbgs() << "Setting " << *From); + LLVM_DEBUG(dbgs() << " equivalent to congruence class "); + LLVM_DEBUG(dbgs() << NewClass->getID() + << " with current MemoryAccess leader "); + LLVM_DEBUG(dbgs() << *NewClass->getMemoryLeader() << "\n"); auto LookupResult = MemoryAccessToClass.find(From); bool Changed = false; @@ -1673,11 +1676,11 @@ bool NewGVN::setMemoryClass(const MemoryAccess *From, OldClass->setMemoryLeader(nullptr); } else { OldClass->setMemoryLeader(getNextMemoryLeader(OldClass)); - DEBUG(dbgs() << "Memory class leader change for class " - << OldClass->getID() << " to " - << *OldClass->getMemoryLeader() - << " due to removal of a memory member " << *From - << "\n"); + LLVM_DEBUG(dbgs() << "Memory class leader change for class " + << OldClass->getID() << " to " + << *OldClass->getMemoryLeader() + << " due to removal of a memory member " << *From + << "\n"); markMemoryLeaderChangeTouched(OldClass); } } @@ -1705,7 +1708,7 @@ bool NewGVN::isCycleFree(const Instruction *I) const { if (ICS == ICS_Unknown) { SCCFinder.Start(I); auto &SCC = SCCFinder.getComponentFor(I); - // It's cycle free if it's size 1 or or the SCC is *only* phi nodes. + // It's cycle free if it's size 1 or the SCC is *only* phi nodes. if (SCC.size() == 1) InstCycleState.insert({I, ICS_CycleFree}); else { @@ -1753,12 +1756,13 @@ NewGVN::performSymbolicPHIEvaluation(ArrayRef<ValPair> PHIOps, // If it has undef at this point, it means there are no-non-undef arguments, // and thus, the value of the phi node must be undef. 
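The assume handling above narrows isTrueWhenEqual(Predicate) to an explicit ICMP_EQ test. The distinction matters because a predicate such as uge is true when the operands are equal yet does not imply equality, so an assumed "a uge b" must not rewrite b to a. A plain-integer check of that reasoning (illustration, not IR):

#include <cassert>

int main() {
  int A = 7, B = 3;
  // "A >= B" would satisfy isTrueWhenEqual, and the assumed condition holds:
  assert(A >= B);
  // ...but the operands are not equal, so propagating A for B is unsound.
  assert(A != B);
  // Only an assumed equality (ICMP_EQ) licenses the replacement.
  return 0;
}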
if (HasUndef) { - DEBUG(dbgs() << "PHI Node " << *I - << " has no non-undef arguments, valuing it as undef\n"); + LLVM_DEBUG( + dbgs() << "PHI Node " << *I + << " has no non-undef arguments, valuing it as undef\n"); return createConstantExpression(UndefValue::get(I->getType())); } - DEBUG(dbgs() << "No arguments of PHI node " << *I << " are live\n"); + LLVM_DEBUG(dbgs() << "No arguments of PHI node " << *I << " are live\n"); deleteExpression(E); return createDeadExpression(); } @@ -1797,8 +1801,8 @@ NewGVN::performSymbolicPHIEvaluation(ArrayRef<ValPair> PHIOps, InstrToDFSNum(AllSameValue) > InstrToDFSNum(I)) return E; NumGVNPhisAllSame++; - DEBUG(dbgs() << "Simplified PHI node " << *I << " to " << *AllSameValue - << "\n"); + LLVM_DEBUG(dbgs() << "Simplified PHI node " << *I << " to " << *AllSameValue + << "\n"); deleteExpression(E); return createVariableOrConstant(AllSameValue); } @@ -2091,7 +2095,7 @@ void NewGVN::markUsersTouched(Value *V) { } void NewGVN::addMemoryUsers(const MemoryAccess *To, MemoryAccess *U) const { - DEBUG(dbgs() << "Adding memory user " << *U << " to " << *To << "\n"); + LLVM_DEBUG(dbgs() << "Adding memory user " << *U << " to " << *To << "\n"); MemoryToUsers[To].insert(U); } @@ -2207,13 +2211,13 @@ Value *NewGVN::getNextValueLeader(CongruenceClass *CC) const { // // - I must be moving to NewClass from OldClass // - The StoreCount of OldClass and NewClass is expected to have been updated -// for I already if it is is a store. +// for I already if it is a store. // - The OldClass memory leader has not been updated yet if I was the leader. void NewGVN::moveMemoryToNewCongruenceClass(Instruction *I, MemoryAccess *InstMA, CongruenceClass *OldClass, CongruenceClass *NewClass) { - // If the leader is I, and we had a represenative MemoryAccess, it should + // If the leader is I, and we had a representative MemoryAccess, it should // be the MemoryAccess of OldClass. 
assert((!InstMA || !OldClass->getMemoryLeader() || OldClass->getLeader() != I || @@ -2227,8 +2231,9 @@ void NewGVN::moveMemoryToNewCongruenceClass(Instruction *I, (isa<StoreInst>(I) && NewClass->getStoreCount() == 1)); NewClass->setMemoryLeader(InstMA); // Mark it touched if we didn't just create a singleton - DEBUG(dbgs() << "Memory class leader change for class " << NewClass->getID() - << " due to new memory instruction becoming leader\n"); + LLVM_DEBUG(dbgs() << "Memory class leader change for class " + << NewClass->getID() + << " due to new memory instruction becoming leader\n"); markMemoryLeaderChangeTouched(NewClass); } setMemoryClass(InstMA, NewClass); @@ -2236,10 +2241,10 @@ void NewGVN::moveMemoryToNewCongruenceClass(Instruction *I, if (OldClass->getMemoryLeader() == InstMA) { if (!OldClass->definesNoMemory()) { OldClass->setMemoryLeader(getNextMemoryLeader(OldClass)); - DEBUG(dbgs() << "Memory class leader change for class " - << OldClass->getID() << " to " - << *OldClass->getMemoryLeader() - << " due to removal of old leader " << *InstMA << "\n"); + LLVM_DEBUG(dbgs() << "Memory class leader change for class " + << OldClass->getID() << " to " + << *OldClass->getMemoryLeader() + << " due to removal of old leader " << *InstMA << "\n"); markMemoryLeaderChangeTouched(OldClass); } else OldClass->setMemoryLeader(nullptr); @@ -2276,9 +2281,10 @@ void NewGVN::moveValueToNewCongruenceClass(Instruction *I, const Expression *E, NewClass->setStoredValue(SE->getStoredValue()); markValueLeaderChangeTouched(NewClass); // Shift the new class leader to be the store - DEBUG(dbgs() << "Changing leader of congruence class " - << NewClass->getID() << " from " << *NewClass->getLeader() - << " to " << *SI << " because store joined class\n"); + LLVM_DEBUG(dbgs() << "Changing leader of congruence class " + << NewClass->getID() << " from " + << *NewClass->getLeader() << " to " << *SI + << " because store joined class\n"); // If we changed the leader, we have to mark it changed because we don't // know what it will do to symbolic evaluation. NewClass->setLeader(SI); @@ -2298,8 +2304,8 @@ void NewGVN::moveValueToNewCongruenceClass(Instruction *I, const Expression *E, // See if we destroyed the class or need to swap leaders. if (OldClass->empty() && OldClass != TOPClass) { if (OldClass->getDefiningExpr()) { - DEBUG(dbgs() << "Erasing expression " << *OldClass->getDefiningExpr() - << " from table\n"); + LLVM_DEBUG(dbgs() << "Erasing expression " << *OldClass->getDefiningExpr() + << " from table\n"); // We erase it as an exact expression to make sure we don't just erase an // equivalent one. auto Iter = ExpressionToClass.find_as( @@ -2316,8 +2322,8 @@ void NewGVN::moveValueToNewCongruenceClass(Instruction *I, const Expression *E, // When the leader changes, the value numbering of // everything may change due to symbolization changes, so we need to // reprocess. - DEBUG(dbgs() << "Value class leader change for class " << OldClass->getID() - << "\n"); + LLVM_DEBUG(dbgs() << "Value class leader change for class " + << OldClass->getID() << "\n"); ++NumGVNLeaderChanges; // Destroy the stored value if there are no more stores to represent it. 
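Running through every hunk of this commit is the mechanical DEBUG to LLVM_DEBUG rename. As a hedged, minimal model of the shape of such a macro (the real definition lives in llvm/Support/Debug.h and additionally filters on DEBUG_TYPE and the -debug-only flag; this standalone version models only the gating):

#include <cstdio>

#ifndef NDEBUG
static bool DebugFlag = true; // stand-in for the -debug command-line flag
#define MY_DEBUG(X)                                                            \
  do {                                                                         \
    if (DebugFlag) {                                                           \
      X;                                                                       \
    }                                                                          \
  } while (false)
#else
// In release builds the statement argument is never evaluated.
#define MY_DEBUG(X)                                                            \
  do {                                                                         \
  } while (false)
#endif

int main() {
  MY_DEBUG(std::printf("only printed when debugging is enabled\n"));
  return 0;
}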
// Note that this is basically clean up for the expression removal that @@ -2380,12 +2386,14 @@ void NewGVN::performCongruenceFinding(Instruction *I, const Expression *E) { "VariableExpression should have been handled already"); EClass = NewClass; - DEBUG(dbgs() << "Created new congruence class for " << *I - << " using expression " << *E << " at " << NewClass->getID() - << " and leader " << *(NewClass->getLeader())); + LLVM_DEBUG(dbgs() << "Created new congruence class for " << *I + << " using expression " << *E << " at " + << NewClass->getID() << " and leader " + << *(NewClass->getLeader())); if (NewClass->getStoredValue()) - DEBUG(dbgs() << " and stored value " << *(NewClass->getStoredValue())); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " and stored value " + << *(NewClass->getStoredValue())); + LLVM_DEBUG(dbgs() << "\n"); } else { EClass = lookupResult.first->second; if (isa<ConstantExpression>(E)) @@ -2403,8 +2411,8 @@ void NewGVN::performCongruenceFinding(Instruction *I, const Expression *E) { bool ClassChanged = IClass != EClass; bool LeaderChanged = LeaderChanges.erase(I); if (ClassChanged || LeaderChanged) { - DEBUG(dbgs() << "New class " << EClass->getID() << " for expression " << *E - << "\n"); + LLVM_DEBUG(dbgs() << "New class " << EClass->getID() << " for expression " + << *E << "\n"); if (ClassChanged) { moveValueToNewCongruenceClass(I, E, IClass, EClass); markPhiOfOpsChanged(E); @@ -2442,13 +2450,15 @@ void NewGVN::updateReachableEdge(BasicBlock *From, BasicBlock *To) { if (ReachableEdges.insert({From, To}).second) { // If this block wasn't reachable before, all instructions are touched. if (ReachableBlocks.insert(To).second) { - DEBUG(dbgs() << "Block " << getBlockName(To) << " marked reachable\n"); + LLVM_DEBUG(dbgs() << "Block " << getBlockName(To) + << " marked reachable\n"); const auto &InstRange = BlockInstRange.lookup(To); TouchedInstructions.set(InstRange.first, InstRange.second); } else { - DEBUG(dbgs() << "Block " << getBlockName(To) - << " was reachable, but new edge {" << getBlockName(From) - << "," << getBlockName(To) << "} to it found\n"); + LLVM_DEBUG(dbgs() << "Block " << getBlockName(To) + << " was reachable, but new edge {" + << getBlockName(From) << "," << getBlockName(To) + << "} to it found\n"); // We've made an edge reachable to an existing block, which may // impact predicates. 
Otherwise, only mark the phi nodes as touched, as @@ -2495,12 +2505,12 @@ void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) { BasicBlock *FalseSucc = BR->getSuccessor(1); if (CondEvaluated && (CI = dyn_cast<ConstantInt>(CondEvaluated))) { if (CI->isOne()) { - DEBUG(dbgs() << "Condition for Terminator " << *TI - << " evaluated to true\n"); + LLVM_DEBUG(dbgs() << "Condition for Terminator " << *TI + << " evaluated to true\n"); updateReachableEdge(B, TrueSucc); } else if (CI->isZero()) { - DEBUG(dbgs() << "Condition for Terminator " << *TI - << " evaluated to false\n"); + LLVM_DEBUG(dbgs() << "Condition for Terminator " << *TI + << " evaluated to false\n"); updateReachableEdge(B, FalseSucc); } } else { @@ -2685,8 +2695,8 @@ Value *NewGVN::findLeaderForInst(Instruction *TransInst, auto *FoundVal = findPHIOfOpsLeader(E, OrigInst, PredBB); if (!FoundVal) { ExpressionToPhiOfOps[E].insert(OrigInst); - DEBUG(dbgs() << "Cannot find phi of ops operand for " << *TransInst - << " in block " << getBlockName(PredBB) << "\n"); + LLVM_DEBUG(dbgs() << "Cannot find phi of ops operand for " << *TransInst + << " in block " << getBlockName(PredBB) << "\n"); return nullptr; } if (auto *SI = dyn_cast<StoreInst>(FoundVal)) @@ -2723,116 +2733,143 @@ NewGVN::makePossiblePHIOfOps(Instruction *I, MemAccess->getDefiningAccess()->getBlock() == I->getParent()) return nullptr; - SmallPtrSet<const Value *, 10> VisitedOps; // Convert op of phis to phi of ops - for (auto *Op : I->operand_values()) { + SmallPtrSet<const Value *, 10> VisitedOps; + SmallVector<Value *, 4> Ops(I->operand_values()); + BasicBlock *SamePHIBlock = nullptr; + PHINode *OpPHI = nullptr; + if (!DebugCounter::shouldExecute(PHIOfOpsCounter)) + return nullptr; + for (auto *Op : Ops) { if (!isa<PHINode>(Op)) { auto *ValuePHI = RealToTemp.lookup(Op); if (!ValuePHI) continue; - DEBUG(dbgs() << "Found possible dependent phi of ops\n"); + LLVM_DEBUG(dbgs() << "Found possible dependent phi of ops\n"); Op = ValuePHI; } - auto *OpPHI = cast<PHINode>(Op); + OpPHI = cast<PHINode>(Op); + if (!SamePHIBlock) { + SamePHIBlock = getBlockForValue(OpPHI); + } else if (SamePHIBlock != getBlockForValue(OpPHI)) { + LLVM_DEBUG( + dbgs() + << "PHIs for operands are not all in the same block, aborting\n"); + return nullptr; + } // No point in doing this for one-operand phis. - if (OpPHI->getNumOperands() == 1) + if (OpPHI->getNumOperands() == 1) { + OpPHI = nullptr; continue; - if (!DebugCounter::shouldExecute(PHIOfOpsCounter)) - return nullptr; - SmallVector<ValPair, 4> Ops; - SmallPtrSet<Value *, 4> Deps; - auto *PHIBlock = getBlockForValue(OpPHI); - RevisitOnReachabilityChange[PHIBlock].reset(InstrToDFSNum(I)); - for (unsigned PredNum = 0; PredNum < OpPHI->getNumOperands(); ++PredNum) { - auto *PredBB = OpPHI->getIncomingBlock(PredNum); - Value *FoundVal = nullptr; - // We could just skip unreachable edges entirely but it's tricky to do - // with rewriting existing phi nodes. - if (ReachableEdges.count({PredBB, PHIBlock})) { - // Clone the instruction, create an expression from it that is - // translated back into the predecessor, and see if we have a leader. - Instruction *ValueOp = I->clone(); - if (MemAccess) - TempToMemory.insert({ValueOp, MemAccess}); - bool SafeForPHIOfOps = true; - VisitedOps.clear(); - for (auto &Op : ValueOp->operands()) { - auto *OrigOp = &*Op; - // When these operand changes, it could change whether there is a - // leader for us or not, so we have to add additional users. 
- if (isa<PHINode>(Op)) { - Op = Op->DoPHITranslation(PHIBlock, PredBB); - if (Op != OrigOp && Op != I) - Deps.insert(Op); - } else if (auto *ValuePHI = RealToTemp.lookup(Op)) { - if (getBlockForValue(ValuePHI) == PHIBlock) - Op = ValuePHI->getIncomingValueForBlock(PredBB); - } - // If we phi-translated the op, it must be safe. - SafeForPHIOfOps = - SafeForPHIOfOps && - (Op != OrigOp || OpIsSafeForPHIOfOps(Op, PHIBlock, VisitedOps)); + } + } + + if (!OpPHI) + return nullptr; + + SmallVector<ValPair, 4> PHIOps; + SmallPtrSet<Value *, 4> Deps; + auto *PHIBlock = getBlockForValue(OpPHI); + RevisitOnReachabilityChange[PHIBlock].reset(InstrToDFSNum(I)); + for (unsigned PredNum = 0; PredNum < OpPHI->getNumOperands(); ++PredNum) { + auto *PredBB = OpPHI->getIncomingBlock(PredNum); + Value *FoundVal = nullptr; + SmallPtrSet<Value *, 4> CurrentDeps; + // We could just skip unreachable edges entirely but it's tricky to do + // with rewriting existing phi nodes. + if (ReachableEdges.count({PredBB, PHIBlock})) { + // Clone the instruction, create an expression from it that is + // translated back into the predecessor, and see if we have a leader. + Instruction *ValueOp = I->clone(); + if (MemAccess) + TempToMemory.insert({ValueOp, MemAccess}); + bool SafeForPHIOfOps = true; + VisitedOps.clear(); + for (auto &Op : ValueOp->operands()) { + auto *OrigOp = &*Op; + // When these operand changes, it could change whether there is a + // leader for us or not, so we have to add additional users. + if (isa<PHINode>(Op)) { + Op = Op->DoPHITranslation(PHIBlock, PredBB); + if (Op != OrigOp && Op != I) + CurrentDeps.insert(Op); + } else if (auto *ValuePHI = RealToTemp.lookup(Op)) { + if (getBlockForValue(ValuePHI) == PHIBlock) + Op = ValuePHI->getIncomingValueForBlock(PredBB); } - // FIXME: For those things that are not safe we could generate - // expressions all the way down, and see if this comes out to a - // constant. For anything where that is true, and unsafe, we should - // have made a phi-of-ops (or value numbered it equivalent to something) - // for the pieces already. - FoundVal = !SafeForPHIOfOps ? nullptr - : findLeaderForInst(ValueOp, Visited, - MemAccess, I, PredBB); - ValueOp->deleteValue(); - if (!FoundVal) - return nullptr; - } else { - DEBUG(dbgs() << "Skipping phi of ops operand for incoming block " - << getBlockName(PredBB) - << " because the block is unreachable\n"); - FoundVal = UndefValue::get(I->getType()); - RevisitOnReachabilityChange[PHIBlock].set(InstrToDFSNum(I)); + // If we phi-translated the op, it must be safe. 
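The restructured phi-of-ops code above clones I once per incoming edge of the operand phi and rewrites each cloned operand to the value it has on that edge, using Value::DoPHITranslation. A simplified sketch of just that translation step; the helper name is hypothetical, and the safety and leader checks of the real pass are omitted:

    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/Instruction.h"

    // Rewrite the operands of a cloned instruction as they would be seen on
    // the edge PredBB -> PhiBlock: a phi defined in PhiBlock is replaced by
    // its incoming value for PredBB; any other operand is left untouched.
    static void translateIntoPred(llvm::Instruction *Clone,
                                  llvm::BasicBlock *PhiBlock,
                                  llvm::BasicBlock *PredBB) {
      for (llvm::Use &U : Clone->operands())
        U.set(U->DoPHITranslation(PhiBlock, PredBB));
    }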
+ SafeForPHIOfOps = + SafeForPHIOfOps && + (Op != OrigOp || OpIsSafeForPHIOfOps(Op, PHIBlock, VisitedOps)); } - - Ops.push_back({FoundVal, PredBB}); - DEBUG(dbgs() << "Found phi of ops operand " << *FoundVal << " in " - << getBlockName(PredBB) << "\n"); - } - for (auto Dep : Deps) - addAdditionalUsers(Dep, I); - sortPHIOps(Ops); - auto *E = performSymbolicPHIEvaluation(Ops, I, PHIBlock); - if (isa<ConstantExpression>(E) || isa<VariableExpression>(E)) { - DEBUG(dbgs() - << "Not creating real PHI of ops because it simplified to existing " - "value or constant\n"); - return E; - } - auto *ValuePHI = RealToTemp.lookup(I); - bool NewPHI = false; - if (!ValuePHI) { - ValuePHI = - PHINode::Create(I->getType(), OpPHI->getNumOperands(), "phiofops"); - addPhiOfOps(ValuePHI, PHIBlock, I); - NewPHI = true; - NumGVNPHIOfOpsCreated++; - } - if (NewPHI) { - for (auto PHIOp : Ops) - ValuePHI->addIncoming(PHIOp.first, PHIOp.second); - } else { - unsigned int i = 0; - for (auto PHIOp : Ops) { - ValuePHI->setIncomingValue(i, PHIOp.first); - ValuePHI->setIncomingBlock(i, PHIOp.second); - ++i; + // FIXME: For those things that are not safe we could generate + // expressions all the way down, and see if this comes out to a + // constant. For anything where that is true, and unsafe, we should + // have made a phi-of-ops (or value numbered it equivalent to something) + // for the pieces already. + FoundVal = !SafeForPHIOfOps ? nullptr + : findLeaderForInst(ValueOp, Visited, + MemAccess, I, PredBB); + ValueOp->deleteValue(); + if (!FoundVal) { + // We failed to find a leader for the current ValueOp, but this might + // change in case of the translated operands change. + if (SafeForPHIOfOps) + for (auto Dep : CurrentDeps) + addAdditionalUsers(Dep, I); + + return nullptr; } + Deps.insert(CurrentDeps.begin(), CurrentDeps.end()); + } else { + LLVM_DEBUG(dbgs() << "Skipping phi of ops operand for incoming block " + << getBlockName(PredBB) + << " because the block is unreachable\n"); + FoundVal = UndefValue::get(I->getType()); + RevisitOnReachabilityChange[PHIBlock].set(InstrToDFSNum(I)); } - RevisitOnReachabilityChange[PHIBlock].set(InstrToDFSNum(I)); - DEBUG(dbgs() << "Created phi of ops " << *ValuePHI << " for " << *I - << "\n"); + PHIOps.push_back({FoundVal, PredBB}); + LLVM_DEBUG(dbgs() << "Found phi of ops operand " << *FoundVal << " in " + << getBlockName(PredBB) << "\n"); + } + for (auto Dep : Deps) + addAdditionalUsers(Dep, I); + sortPHIOps(PHIOps); + auto *E = performSymbolicPHIEvaluation(PHIOps, I, PHIBlock); + if (isa<ConstantExpression>(E) || isa<VariableExpression>(E)) { + LLVM_DEBUG( + dbgs() + << "Not creating real PHI of ops because it simplified to existing " + "value or constant\n"); return E; } - return nullptr; + auto *ValuePHI = RealToTemp.lookup(I); + bool NewPHI = false; + if (!ValuePHI) { + ValuePHI = + PHINode::Create(I->getType(), OpPHI->getNumOperands(), "phiofops"); + addPhiOfOps(ValuePHI, PHIBlock, I); + NewPHI = true; + NumGVNPHIOfOpsCreated++; + } + if (NewPHI) { + for (auto PHIOp : PHIOps) + ValuePHI->addIncoming(PHIOp.first, PHIOp.second); + } else { + TempToBlock[ValuePHI] = PHIBlock; + unsigned int i = 0; + for (auto PHIOp : PHIOps) { + ValuePHI->setIncomingValue(i, PHIOp.first); + ValuePHI->setIncomingBlock(i, PHIOp.second); + ++i; + } + } + RevisitOnReachabilityChange[PHIBlock].set(InstrToDFSNum(I)); + LLVM_DEBUG(dbgs() << "Created phi of ops " << *ValuePHI << " for " << *I + << "\n"); + + return E; } // The algorithm initially places the values of the routine in the TOP @@ -2902,8 
+2939,9 @@ void NewGVN::initializeCongruenceClasses(Function &F) {
 
 void NewGVN::cleanupTables() {
   for (unsigned i = 0, e = CongruenceClasses.size(); i != e; ++i) {
-    DEBUG(dbgs() << "Congruence class " << CongruenceClasses[i]->getID()
-                 << " has " << CongruenceClasses[i]->size() << " members\n");
+    LLVM_DEBUG(dbgs() << "Congruence class " << CongruenceClasses[i]->getID()
+                      << " has " << CongruenceClasses[i]->size()
+                      << " members\n");
     // Make sure we delete the congruence class (probably worth switching to
     // a unique_ptr at some point).
     delete CongruenceClasses[i];
@@ -2973,7 +3011,7 @@ std::pair<unsigned, unsigned> NewGVN::assignDFSNumbers(BasicBlock *B,
     // we change its DFS number so that it doesn't get value numbered.
     if (isInstructionTriviallyDead(&I, TLI)) {
       InstrDFS[&I] = 0;
-      DEBUG(dbgs() << "Skipping trivially dead instruction " << I << "\n");
+      LLVM_DEBUG(dbgs() << "Skipping trivially dead instruction " << I << "\n");
       markInstructionForDeletion(&I);
       continue;
     }
@@ -3039,9 +3077,10 @@ void NewGVN::valueNumberMemoryPhi(MemoryPhi *MP) {
       [&AllSameValue](const MemoryAccess *V) { return V == AllSameValue; });
 
   if (AllEqual)
-    DEBUG(dbgs() << "Memory Phi value numbered to " << *AllSameValue << "\n");
+    LLVM_DEBUG(dbgs() << "Memory Phi value numbered to " << *AllSameValue
+                      << "\n");
   else
-    DEBUG(dbgs() << "Memory Phi value numbered to itself\n");
+    LLVM_DEBUG(dbgs() << "Memory Phi value numbered to itself\n");
   // If it's equal to something, it's in that class. Otherwise, it has to be in
   // a class where it is the leader (other things may be equivalent to it, but
   // it needs to start off in its own class, which means it must have been the
@@ -3060,7 +3099,7 @@ void NewGVN::valueNumberMemoryPhi(MemoryPhi *MP) {
 // Value number a single instruction, symbolically evaluating, performing
 // congruence finding, and updating mappings.
 void NewGVN::valueNumberInstruction(Instruction *I) {
-  DEBUG(dbgs() << "Processing instruction " << *I << "\n");
+  LLVM_DEBUG(dbgs() << "Processing instruction " << *I << "\n");
   if (!I->isTerminator()) {
     const Expression *Symbolized = nullptr;
     SmallPtrSet<Value *, 2> Visited;
@@ -3246,7 +3285,7 @@ void NewGVN::verifyMemoryCongruency() const {
 // and redoing the iteration to see if anything changed.
 void NewGVN::verifyIterationSettled(Function &F) {
 #ifndef NDEBUG
-  DEBUG(dbgs() << "Beginning iteration verification\n");
+  LLVM_DEBUG(dbgs() << "Beginning iteration verification\n");
   if (DebugCounter::isCounterSet(VNCounter))
     DebugCounter::setCounterValue(VNCounter, StartingVNCounter);
 
@@ -3364,9 +3403,9 @@ void NewGVN::iterateTouchedInstructions() {
       // If it's not reachable, erase any touched instructions and move on.
if (!BlockReachable) { TouchedInstructions.reset(CurrInstRange.first, CurrInstRange.second); - DEBUG(dbgs() << "Skipping instructions in block " - << getBlockName(CurrBlock) - << " because it is unreachable\n"); + LLVM_DEBUG(dbgs() << "Skipping instructions in block " + << getBlockName(CurrBlock) + << " because it is unreachable\n"); continue; } updateProcessedCount(CurrBlock); @@ -3376,7 +3415,7 @@ void NewGVN::iterateTouchedInstructions() { TouchedInstructions.reset(InstrNum); if (auto *MP = dyn_cast<MemoryPhi>(V)) { - DEBUG(dbgs() << "Processing MemoryPhi " << *MP << "\n"); + LLVM_DEBUG(dbgs() << "Processing MemoryPhi " << *MP << "\n"); valueNumberMemoryPhi(MP); } else if (auto *I = dyn_cast<Instruction>(V)) { valueNumberInstruction(I); @@ -3422,10 +3461,10 @@ bool NewGVN::runGVN() { for (auto &B : RPOT) { auto *Node = DT->getNode(B); if (Node->getChildren().size() > 1) - std::sort(Node->begin(), Node->end(), - [&](const DomTreeNode *A, const DomTreeNode *B) { - return RPOOrdering[A] < RPOOrdering[B]; - }); + llvm::sort(Node->begin(), Node->end(), + [&](const DomTreeNode *A, const DomTreeNode *B) { + return RPOOrdering[A] < RPOOrdering[B]; + }); } // Now a standard depth first ordering of the domtree is equivalent to RPO. @@ -3446,8 +3485,8 @@ bool NewGVN::runGVN() { // Initialize the touched instructions to include the entry block. const auto &InstRange = BlockInstRange.lookup(&F.getEntryBlock()); TouchedInstructions.set(InstRange.first, InstRange.second); - DEBUG(dbgs() << "Block " << getBlockName(&F.getEntryBlock()) - << " marked reachable\n"); + LLVM_DEBUG(dbgs() << "Block " << getBlockName(&F.getEntryBlock()) + << " marked reachable\n"); ReachableBlocks.insert(&F.getEntryBlock()); iterateTouchedInstructions(); @@ -3472,8 +3511,8 @@ bool NewGVN::runGVN() { }; for (auto &BB : make_filter_range(F, UnreachableBlockPred)) { - DEBUG(dbgs() << "We believe block " << getBlockName(&BB) - << " is unreachable\n"); + LLVM_DEBUG(dbgs() << "We believe block " << getBlockName(&BB) + << " is unreachable\n"); deleteInstructionsInBlock(&BB); Changed = true; } @@ -3695,7 +3734,7 @@ static void patchAndReplaceAllUsesWith(Instruction *I, Value *Repl) { } void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) { - DEBUG(dbgs() << " BasicBlock Dead:" << *BB); + LLVM_DEBUG(dbgs() << " BasicBlock Dead:" << *BB); ++NumGVNBlocksDeleted; // Delete the instructions backwards, as it has a reduced likelihood of having @@ -3722,12 +3761,12 @@ void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) { } void NewGVN::markInstructionForDeletion(Instruction *I) { - DEBUG(dbgs() << "Marking " << *I << " for deletion\n"); + LLVM_DEBUG(dbgs() << "Marking " << *I << " for deletion\n"); InstructionsToErase.insert(I); } void NewGVN::replaceInstruction(Instruction *I, Value *V) { - DEBUG(dbgs() << "Replacing " << *I << " with " << *V << "\n"); + LLVM_DEBUG(dbgs() << "Replacing " << *I << " with " << *V << "\n"); patchAndReplaceAllUsesWith(I, V); // We save the actual erasing to avoid invalidating memory // dependencies until we are done with everything. 
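The std::sort to llvm::sort substitution seen in runGVN above (and again later in this patch) is the other mechanical change. llvm::sort, from llvm/ADT/STLExtras.h, forwards to std::sort, but in EXPENSIVE_CHECKS builds it shuffles the range first so that any dependence on the unspecified relative order of equal elements shows up as nondeterministic output in tests. The call shape is identical; a sketch with a hypothetical container:

    #include "llvm/ADT/STLExtras.h"
    #include <utility>
    #include <vector>

    static void sortByDFSNumber(std::vector<std::pair<unsigned, int>> &Nodes) {
      // Drop-in replacement for std::sort; under EXPENSIVE_CHECKS the range
      // is randomly shuffled first to flush out unstable comparators.
      llvm::sort(Nodes.begin(), Nodes.end(),
                 [](const std::pair<unsigned, int> &A,
                    const std::pair<unsigned, int> &B) {
                   return A.first < B.first;
                 });
    }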
@@ -3853,9 +3892,10 @@ bool NewGVN::eliminateInstructions(Function &F) { auto ReplaceUnreachablePHIArgs = [&](PHINode *PHI, BasicBlock *BB) { for (auto &Operand : PHI->incoming_values()) if (!ReachableEdges.count({PHI->getIncomingBlock(Operand), BB})) { - DEBUG(dbgs() << "Replacing incoming value of " << PHI << " for block " - << getBlockName(PHI->getIncomingBlock(Operand)) - << " with undef due to it being unreachable\n"); + LLVM_DEBUG(dbgs() << "Replacing incoming value of " << PHI + << " for block " + << getBlockName(PHI->getIncomingBlock(Operand)) + << " with undef due to it being unreachable\n"); Operand.set(UndefValue::get(PHI->getType())); } }; @@ -3887,7 +3927,8 @@ bool NewGVN::eliminateInstructions(Function &F) { // Map to store the use counts DenseMap<const Value *, unsigned int> UseCounts; for (auto *CC : reverse(CongruenceClasses)) { - DEBUG(dbgs() << "Eliminating in congruence class " << CC->getID() << "\n"); + LLVM_DEBUG(dbgs() << "Eliminating in congruence class " << CC->getID() + << "\n"); // Track the equivalent store info so we can decide whether to try // dead store elimination. SmallVector<ValueDFS, 8> PossibleDeadStores; @@ -3925,8 +3966,8 @@ bool NewGVN::eliminateInstructions(Function &F) { MembersLeft.insert(Member); continue; } - DEBUG(dbgs() << "Found replacement " << *(Leader) << " for " << *Member - << "\n"); + LLVM_DEBUG(dbgs() << "Found replacement " << *(Leader) << " for " + << *Member << "\n"); auto *I = cast<Instruction>(Member); assert(Leader != I && "About to accidentally remove our leader"); replaceInstruction(I, Leader); @@ -3947,7 +3988,7 @@ bool NewGVN::eliminateInstructions(Function &F) { convertClassToDFSOrdered(*CC, DFSOrderedSet, UseCounts, ProbablyDead); // Sort the whole thing. - std::sort(DFSOrderedSet.begin(), DFSOrderedSet.end()); + llvm::sort(DFSOrderedSet.begin(), DFSOrderedSet.end()); for (auto &VD : DFSOrderedSet) { int MemberDFSIn = VD.DFSIn; int MemberDFSOut = VD.DFSOut; @@ -3966,24 +4007,24 @@ bool NewGVN::eliminateInstructions(Function &F) { // remove from temp instruction list. AllTempInstructions.erase(PN); auto *DefBlock = getBlockForValue(Def); - DEBUG(dbgs() << "Inserting fully real phi of ops" << *Def - << " into block " - << getBlockName(getBlockForValue(Def)) << "\n"); + LLVM_DEBUG(dbgs() << "Inserting fully real phi of ops" << *Def + << " into block " + << getBlockName(getBlockForValue(Def)) << "\n"); PN->insertBefore(&DefBlock->front()); Def = PN; NumGVNPHIOfOpsEliminations++; } if (EliminationStack.empty()) { - DEBUG(dbgs() << "Elimination Stack is empty\n"); + LLVM_DEBUG(dbgs() << "Elimination Stack is empty\n"); } else { - DEBUG(dbgs() << "Elimination Stack Top DFS numbers are (" - << EliminationStack.dfs_back().first << "," - << EliminationStack.dfs_back().second << ")\n"); + LLVM_DEBUG(dbgs() << "Elimination Stack Top DFS numbers are (" + << EliminationStack.dfs_back().first << "," + << EliminationStack.dfs_back().second << ")\n"); } - DEBUG(dbgs() << "Current DFS numbers are (" << MemberDFSIn << "," - << MemberDFSOut << ")\n"); + LLVM_DEBUG(dbgs() << "Current DFS numbers are (" << MemberDFSIn << "," + << MemberDFSOut << ")\n"); // First, we see if we are out of scope or empty. 
If so, // and there equivalences, we try to replace the top of // stack with equivalences (if it's on the stack, it must @@ -4058,14 +4099,16 @@ bool NewGVN::eliminateInstructions(Function &F) { Value *DominatingLeader = EliminationStack.back(); auto *II = dyn_cast<IntrinsicInst>(DominatingLeader); - if (II && II->getIntrinsicID() == Intrinsic::ssa_copy) + bool isSSACopy = II && II->getIntrinsicID() == Intrinsic::ssa_copy; + if (isSSACopy) DominatingLeader = II->getOperand(0); // Don't replace our existing users with ourselves. if (U->get() == DominatingLeader) continue; - DEBUG(dbgs() << "Found replacement " << *DominatingLeader << " for " - << *U->get() << " in " << *(U->getUser()) << "\n"); + LLVM_DEBUG(dbgs() + << "Found replacement " << *DominatingLeader << " for " + << *U->get() << " in " << *(U->getUser()) << "\n"); // If we replaced something in an instruction, handle the patching of // metadata. Skip this if we are replacing predicateinfo with its @@ -4081,7 +4124,9 @@ bool NewGVN::eliminateInstructions(Function &F) { // It's about to be alive again. if (LeaderUseCount == 0 && isa<Instruction>(DominatingLeader)) ProbablyDead.erase(cast<Instruction>(DominatingLeader)); - if (LeaderUseCount == 0 && II) + // Copy instructions, however, are still dead because we use their + // operand as the leader. + if (LeaderUseCount == 0 && isSSACopy) ProbablyDead.insert(II); ++LeaderUseCount; AnythingReplaced = true; @@ -4106,7 +4151,7 @@ bool NewGVN::eliminateInstructions(Function &F) { // If we have possible dead stores to look at, try to eliminate them. if (CC->getStoreCount() > 0) { convertClassToLoadsAndStores(*CC, PossibleDeadStores); - std::sort(PossibleDeadStores.begin(), PossibleDeadStores.end()); + llvm::sort(PossibleDeadStores.begin(), PossibleDeadStores.end()); ValueDFSStack EliminationStack; for (auto &VD : PossibleDeadStores) { int MemberDFSIn = VD.DFSIn; @@ -4129,8 +4174,8 @@ bool NewGVN::eliminateInstructions(Function &F) { (void)Leader; assert(DT->dominates(Leader->getParent(), Member->getParent())); // Member is dominater by Leader, and thus dead - DEBUG(dbgs() << "Marking dead store " << *Member - << " that is dominated by " << *Leader << "\n"); + LLVM_DEBUG(dbgs() << "Marking dead store " << *Member + << " that is dominated by " << *Leader << "\n"); markInstructionForDeletion(Member); CC->erase(Member); ++NumGVNDeadStores; diff --git a/lib/Transforms/Scalar/PlaceSafepoints.cpp b/lib/Transforms/Scalar/PlaceSafepoints.cpp index 2d0cb6fbf211..8f30bccf48f1 100644 --- a/lib/Transforms/Scalar/PlaceSafepoints.cpp +++ b/lib/Transforms/Scalar/PlaceSafepoints.cpp @@ -55,6 +55,7 @@ #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IntrinsicInst.h" @@ -65,7 +66,6 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/Local.h" #define DEBUG_TYPE "safepoint-placement" @@ -323,7 +323,7 @@ bool PlaceBackedgeSafepointsImpl::runOnLoop(Loop *L) { // avoiding the runtime cost of the actual safepoint. 
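Stepping back to the eliminateInstructions hunks above: the elimination stack can answer "does the leader on top dominate this member?" purely from the DFSIn/DFSOut numbers printed in those debug messages, because with numbers taken from one depth-first walk of the dominator tree, A dominates B exactly when A's interval contains B's. A self-contained sketch of that containment test:

    struct DFSInterval {
      int In;  // number assigned when the DFS enters the node
      int Out; // number assigned when the DFS leaves the node
    };

    // A dominates B exactly when B's interval nests inside A's.
    static bool dominates(const DFSInterval &A, const DFSInterval &B) {
      return A.In <= B.In && A.Out >= B.Out;
    }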
     if (!AllBackedges) {
       if (mustBeFiniteCountedLoop(L, SE, Pred)) {
-        DEBUG(dbgs() << "skipping safepoint placement in finite loop\n");
+        LLVM_DEBUG(dbgs() << "skipping safepoint placement in finite loop\n");
         FiniteExecution++;
         continue;
       }
@@ -332,7 +332,9 @@ bool PlaceBackedgeSafepointsImpl::runOnLoop(Loop *L) {
       // Note: This is only semantically legal since we won't do any further
       // IPO or inlining before the actual call insertion. If we hadn't, we
       // might later lose this call safepoint.
-      DEBUG(dbgs() << "skipping safepoint placement due to unconditional call\n");
+      LLVM_DEBUG(
+          dbgs()
+          << "skipping safepoint placement due to unconditional call\n");
       CallInLoop++;
       continue;
     }
@@ -348,7 +350,7 @@ bool PlaceBackedgeSafepointsImpl::runOnLoop(Loop *L) {
     // variables) and branches to the true header
     TerminatorInst *Term = Pred->getTerminator();
 
-    DEBUG(dbgs() << "[LSP] terminator instruction: " << *Term);
+    LLVM_DEBUG(dbgs() << "[LSP] terminator instruction: " << *Term);
 
     PollLocations.push_back(Term);
   }
@@ -522,7 +524,7 @@ bool PlaceSafepoints::runOnFunction(Function &F) {
     };
     // We need the order of the list to be stable so that naming ends up
     // stable when we split edges. This makes test cases much easier to write.
-    std::sort(PollLocations.begin(), PollLocations.end(), OrderByBBName);
+    llvm::sort(PollLocations.begin(), PollLocations.end(), OrderByBBName);
 
     // We can sometimes end up with duplicate poll locations. This happens if
     // a single loop is visited more than once. The fact this happens seems
diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp
index 88dcaf0f8a36..c81ac70d99e6 100644
--- a/lib/Transforms/Scalar/Reassociate.cpp
+++ b/lib/Transforms/Scalar/Reassociate.cpp
@@ -31,6 +31,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/BasicBlock.h"
@@ -42,6 +43,7 @@
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/PatternMatch.h"
@@ -55,7 +57,6 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
-#include "llvm/Transforms/Utils/Local.h"
 #include <algorithm>
 #include <cassert>
 #include <utility>
@@ -168,8 +169,8 @@ void ReassociatePass::BuildRankMap(Function &F,
   // Assign distinct ranks to function arguments.
   for (auto &Arg : F.args()) {
     ValueRankMap[&Arg] = ++Rank;
-    DEBUG(dbgs() << "Calculated Rank[" << Arg.getName() << "] = " << Rank
-                 << "\n");
+    LLVM_DEBUG(dbgs() << "Calculated Rank[" << Arg.getName() << "] = " << Rank
+                      << "\n");
   }
 
   // Traverse basic blocks in ReversePostOrder
@@ -200,17 +201,17 @@ unsigned ReassociatePass::getRank(Value *V) {
   // for PHI nodes, we cannot have infinite recursion here, because there
   // cannot be loops in the value graph that do not go through PHI nodes.
   unsigned Rank = 0, MaxRank = RankMap[I->getParent()];
-  for (unsigned i = 0, e = I->getNumOperands();
-       i != e && Rank != MaxRank; ++i)
+  for (unsigned i = 0, e = I->getNumOperands(); i != e && Rank != MaxRank; ++i)
     Rank = std::max(Rank, getRank(I->getOperand(i)));
 
   // If this is a not or neg instruction, do not count it for rank. This
   // assures us that X and ~X will have the same rank.
- if (!BinaryOperator::isNot(I) && !BinaryOperator::isNeg(I) && - !BinaryOperator::isFNeg(I)) + if (!BinaryOperator::isNot(I) && !BinaryOperator::isNeg(I) && + !BinaryOperator::isFNeg(I)) ++Rank; - DEBUG(dbgs() << "Calculated Rank[" << V->getName() << "] = " << Rank << "\n"); + LLVM_DEBUG(dbgs() << "Calculated Rank[" << V->getName() << "] = " << Rank + << "\n"); return ValueRankMap[I] = Rank; } @@ -445,7 +446,7 @@ using RepeatedValue = std::pair<Value*, APInt>; /// type and thus make the expression bigger. static bool LinearizeExprTree(BinaryOperator *I, SmallVectorImpl<RepeatedValue> &Ops) { - DEBUG(dbgs() << "LINEARIZE: " << *I << '\n'); + LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n'); unsigned Bitwidth = I->getType()->getScalarType()->getPrimitiveSizeInBits(); unsigned Opcode = I->getOpcode(); assert(I->isAssociative() && I->isCommutative() && @@ -494,14 +495,14 @@ static bool LinearizeExprTree(BinaryOperator *I, for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx) { // Visit operands. Value *Op = I->getOperand(OpIdx); APInt Weight = P.second; // Number of paths to this operand. - DEBUG(dbgs() << "OPERAND: " << *Op << " (" << Weight << ")\n"); + LLVM_DEBUG(dbgs() << "OPERAND: " << *Op << " (" << Weight << ")\n"); assert(!Op->use_empty() && "No uses, so how did we get to it?!"); // If this is a binary operation of the right kind with only one use then // add its operands to the expression. if (BinaryOperator *BO = isReassociableOp(Op, Opcode)) { assert(Visited.insert(Op).second && "Not first visit!"); - DEBUG(dbgs() << "DIRECT ADD: " << *Op << " (" << Weight << ")\n"); + LLVM_DEBUG(dbgs() << "DIRECT ADD: " << *Op << " (" << Weight << ")\n"); Worklist.push_back(std::make_pair(BO, Weight)); continue; } @@ -514,7 +515,8 @@ static bool LinearizeExprTree(BinaryOperator *I, if (!Op->hasOneUse()) { // This value has uses not accounted for by the expression, so it is // not safe to modify. Mark it as being a leaf. - DEBUG(dbgs() << "ADD USES LEAF: " << *Op << " (" << Weight << ")\n"); + LLVM_DEBUG(dbgs() + << "ADD USES LEAF: " << *Op << " (" << Weight << ")\n"); LeafOrder.push_back(Op); Leaves[Op] = Weight; continue; @@ -540,7 +542,7 @@ static bool LinearizeExprTree(BinaryOperator *I, // to the expression, then no longer consider it to be a leaf and add // its operands to the expression. if (BinaryOperator *BO = isReassociableOp(Op, Opcode)) { - DEBUG(dbgs() << "UNLEAF: " << *Op << " (" << It->second << ")\n"); + LLVM_DEBUG(dbgs() << "UNLEAF: " << *Op << " (" << It->second << ")\n"); Worklist.push_back(std::make_pair(BO, It->second)); Leaves.erase(It); continue; @@ -573,9 +575,10 @@ static bool LinearizeExprTree(BinaryOperator *I, if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op)) if ((Opcode == Instruction::Mul && BinaryOperator::isNeg(BO)) || (Opcode == Instruction::FMul && BinaryOperator::isFNeg(BO))) { - DEBUG(dbgs() << "MORPH LEAF: " << *Op << " (" << Weight << ") TO "); + LLVM_DEBUG(dbgs() + << "MORPH LEAF: " << *Op << " (" << Weight << ") TO "); BO = LowerNegateToMultiply(BO); - DEBUG(dbgs() << *BO << '\n'); + LLVM_DEBUG(dbgs() << *BO << '\n'); Worklist.push_back(std::make_pair(BO, Weight)); Changed = true; continue; @@ -583,7 +586,7 @@ static bool LinearizeExprTree(BinaryOperator *I, // Failed to morph into an expression of the right type. This really is // a leaf. 
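LinearizeExprTree, whose debug output is migrated in this hunk, flattens a tree of identical associative operations into (leaf, weight) pairs, where a weight counts how many paths lead from the root to that leaf; as the ADD LEAF / UNLEAF messages show, a value stops being a leaf again if it turns out to be reassociable after all. A toy model of the leaf-collection idea on a plain expression tree; the APInt weight arithmetic, one-use checks, and leaf morphing of the real pass are elided:

    #include <map>
    #include <string>

    // Toy LinearizeExprTree: walk a nested '+' tree and collect each leaf
    // together with the number of paths that reach it, so (((a+b)+a)+c)
    // becomes {a: 2, b: 1, c: 1}.
    struct Node {
      bool IsAdd = false;
      const Node *L = nullptr, *R = nullptr; // children when IsAdd is true
      std::string Name;                      // leaf name otherwise
    };

    static void linearize(const Node *N, unsigned Weight,
                          std::map<std::string, unsigned> &Leaves) {
      if (!N->IsAdd) {
        Leaves[N->Name] += Weight; // repeated leaves accumulate weight
        return;
      }
      linearize(N->L, Weight, Leaves);
      linearize(N->R, Weight, Leaves);
    }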
- DEBUG(dbgs() << "ADD LEAF: " << *Op << " (" << Weight << ")\n"); + LLVM_DEBUG(dbgs() << "ADD LEAF: " << *Op << " (" << Weight << ")\n"); assert(!isReassociableOp(Op, Opcode) && "Value was morphed?"); LeafOrder.push_back(Op); Leaves[Op] = Weight; @@ -675,9 +678,9 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, if (NewLHS == OldRHS && NewRHS == OldLHS) { // The order of the operands was reversed. Swap them. - DEBUG(dbgs() << "RA: " << *Op << '\n'); + LLVM_DEBUG(dbgs() << "RA: " << *Op << '\n'); Op->swapOperands(); - DEBUG(dbgs() << "TO: " << *Op << '\n'); + LLVM_DEBUG(dbgs() << "TO: " << *Op << '\n'); MadeChange = true; ++NumChanged; break; @@ -685,7 +688,7 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, // The new operation differs non-trivially from the original. Overwrite // the old operands with the new ones. - DEBUG(dbgs() << "RA: " << *Op << '\n'); + LLVM_DEBUG(dbgs() << "RA: " << *Op << '\n'); if (NewLHS != OldLHS) { BinaryOperator *BO = isReassociableOp(OldLHS, Opcode); if (BO && !NotRewritable.count(BO)) @@ -698,7 +701,7 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, NodesToRewrite.push_back(BO); Op->setOperand(1, NewRHS); } - DEBUG(dbgs() << "TO: " << *Op << '\n'); + LLVM_DEBUG(dbgs() << "TO: " << *Op << '\n'); ExpressionChanged = Op; MadeChange = true; @@ -711,7 +714,7 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, // while the right-hand side will be the current element of Ops. Value *NewRHS = Ops[i].Op; if (NewRHS != Op->getOperand(1)) { - DEBUG(dbgs() << "RA: " << *Op << '\n'); + LLVM_DEBUG(dbgs() << "RA: " << *Op << '\n'); if (NewRHS == Op->getOperand(0)) { // The new right-hand side was already present as the left operand. If // we are lucky then swapping the operands will sort out both of them. @@ -724,7 +727,7 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, Op->setOperand(1, NewRHS); ExpressionChanged = Op; } - DEBUG(dbgs() << "TO: " << *Op << '\n'); + LLVM_DEBUG(dbgs() << "TO: " << *Op << '\n'); MadeChange = true; ++NumChanged; } @@ -756,9 +759,9 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, NewOp = NodesToRewrite.pop_back_val(); } - DEBUG(dbgs() << "RA: " << *Op << '\n'); + LLVM_DEBUG(dbgs() << "RA: " << *Op << '\n'); Op->setOperand(0, NewOp); - DEBUG(dbgs() << "TO: " << *Op << '\n'); + LLVM_DEBUG(dbgs() << "TO: " << *Op << '\n'); ExpressionChanged = Op; MadeChange = true; ++NumChanged; @@ -781,6 +784,18 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, if (ExpressionChanged == I) break; + + // Discard any debug info related to the expressions that has changed (we + // can leave debug infor related to the root, since the result of the + // expression tree should be the same even after reassociation). + SmallVector<DbgInfoIntrinsic *, 1> DbgUsers; + findDbgUsers(DbgUsers, ExpressionChanged); + for (auto *DII : DbgUsers) { + Value *Undef = UndefValue::get(ExpressionChanged->getType()); + DII->setOperand(0, MetadataAsValue::get(DII->getContext(), + ValueAsMetadata::get(Undef))); + } + ExpressionChanged->moveBefore(I); ExpressionChanged = cast<BinaryOperator>(*ExpressionChanged->user_begin()); } while (true); @@ -798,7 +813,7 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, /// pushing the negates through adds. These will be revisited to see if /// additional opportunities have been exposed. 
static Value *NegateValue(Value *V, Instruction *BI, - SetVector<AssertingVH<Instruction>> &ToRedo) { + ReassociatePass::OrderedSet &ToRedo) { if (auto *C = dyn_cast<Constant>(V)) return C->getType()->isFPOrFPVectorTy() ? ConstantExpr::getFNeg(C) : ConstantExpr::getNeg(C); @@ -912,8 +927,8 @@ static bool ShouldBreakUpSubtract(Instruction *Sub) { /// If we have (X-Y), and if either X is an add, or if this is only used by an /// add, transform this into (X+(0-Y)) to promote better reassociation. -static BinaryOperator * -BreakUpSubtract(Instruction *Sub, SetVector<AssertingVH<Instruction>> &ToRedo) { +static BinaryOperator *BreakUpSubtract(Instruction *Sub, + ReassociatePass::OrderedSet &ToRedo) { // Convert a subtract into an add and a neg instruction. This allows sub // instructions to be commuted with other add instructions. // @@ -929,7 +944,7 @@ BreakUpSubtract(Instruction *Sub, SetVector<AssertingVH<Instruction>> &ToRedo) { Sub->replaceAllUsesWith(New); New->setDebugLoc(Sub->getDebugLoc()); - DEBUG(dbgs() << "Negated: " << *New << '\n'); + LLVM_DEBUG(dbgs() << "Negated: " << *New << '\n'); return New; } @@ -1415,7 +1430,8 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, ++NumFound; } while (i != Ops.size() && Ops[i].Op == TheOp); - DEBUG(dbgs() << "\nFACTORING [" << NumFound << "]: " << *TheOp << '\n'); + LLVM_DEBUG(dbgs() << "\nFACTORING [" << NumFound << "]: " << *TheOp + << '\n'); ++NumFactor; // Insert a new multiply. @@ -1553,7 +1569,8 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, // If any factor occurred more than one time, we can pull it out. if (MaxOcc > 1) { - DEBUG(dbgs() << "\nFACTORING [" << MaxOcc << "]: " << *MaxOccVal << '\n'); + LLVM_DEBUG(dbgs() << "\nFACTORING [" << MaxOcc << "]: " << *MaxOccVal + << '\n'); ++NumFactor; // Create a new instruction that uses the MaxOccVal twice. If we don't do @@ -1622,7 +1639,7 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, return nullptr; } -/// \brief Build up a vector of value/power pairs factoring a product. +/// Build up a vector of value/power pairs factoring a product. /// /// Given a series of multiplication operands, build a vector of factors and /// the powers each is raised to when forming the final product. Sort them in @@ -1687,7 +1704,7 @@ static bool collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops, return true; } -/// \brief Build a tree of multiplies, computing the product of Ops. +/// Build a tree of multiplies, computing the product of Ops. static Value *buildMultiplyTree(IRBuilder<> &Builder, SmallVectorImpl<Value*> &Ops) { if (Ops.size() == 1) @@ -1704,7 +1721,7 @@ static Value *buildMultiplyTree(IRBuilder<> &Builder, return LHS; } -/// \brief Build a minimal multiplication DAG for (a^x)*(b^y)*(c^z)*... +/// Build a minimal multiplication DAG for (a^x)*(b^y)*(c^z)*... /// /// Given a vector of values raised to various powers, where no two values are /// equal and the powers are sorted in decreasing order, compute the minimal @@ -1859,8 +1876,8 @@ Value *ReassociatePass::OptimizeExpression(BinaryOperator *I, // Remove dead instructions and if any operands are trivially dead add them to // Insts so they will be removed as well. 
-void ReassociatePass::RecursivelyEraseDeadInsts( - Instruction *I, SetVector<AssertingVH<Instruction>> &Insts) { +void ReassociatePass::RecursivelyEraseDeadInsts(Instruction *I, + OrderedSet &Insts) { assert(isInstructionTriviallyDead(I) && "Trivially dead instructions only!"); SmallVector<Value *, 4> Ops(I->op_begin(), I->op_end()); ValueRankMap.erase(I); @@ -1876,7 +1893,7 @@ void ReassociatePass::RecursivelyEraseDeadInsts( /// Zap the given instruction, adding interesting operands to the work list. void ReassociatePass::EraseInst(Instruction *I) { assert(isInstructionTriviallyDead(I) && "Trivially dead instructions only!"); - DEBUG(dbgs() << "Erasing dead inst: "; I->dump()); + LLVM_DEBUG(dbgs() << "Erasing dead inst: "; I->dump()); SmallVector<Value*, 8> Ops(I->op_begin(), I->op_end()); // Erase the dead instruction. @@ -1893,7 +1910,14 @@ void ReassociatePass::EraseInst(Instruction *I) { while (Op->hasOneUse() && Op->user_back()->getOpcode() == Opcode && Visited.insert(Op).second) Op = Op->user_back(); - RedoInsts.insert(Op); + + // The instruction we're going to push may be coming from a + // dead block, and Reassociate skips the processing of unreachable + // blocks because it's a waste of time and also because it can + // lead to infinite loop due to LLVM's non-standard definition + // of dominance. + if (ValueRankMap.find(Op) != ValueRankMap.end()) + RedoInsts.insert(Op); } MadeChange = true; @@ -2120,7 +2144,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { ValueEntry(getRank(E.first), E.first)); } - DEBUG(dbgs() << "RAIn:\t"; PrintOps(I, Ops); dbgs() << '\n'); + LLVM_DEBUG(dbgs() << "RAIn:\t"; PrintOps(I, Ops); dbgs() << '\n'); // Now that we have linearized the tree to a list and have gathered all of // the operands and their ranks, sort the operands by their rank. Use a @@ -2138,7 +2162,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { return; // This expression tree simplified to something that isn't a tree, // eliminate it. - DEBUG(dbgs() << "Reassoc to scalar: " << *V << '\n'); + LLVM_DEBUG(dbgs() << "Reassoc to scalar: " << *V << '\n'); I->replaceAllUsesWith(V); if (Instruction *VI = dyn_cast<Instruction>(V)) if (I->getDebugLoc()) @@ -2169,7 +2193,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { } } - DEBUG(dbgs() << "RAOut:\t"; PrintOps(I, Ops); dbgs() << '\n'); + LLVM_DEBUG(dbgs() << "RAOut:\t"; PrintOps(I, Ops); dbgs() << '\n'); if (Ops.size() == 1) { if (Ops[0].Op == I) @@ -2321,7 +2345,7 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { // Make a copy of all the instructions to be redone so we can remove dead // instructions. - SetVector<AssertingVH<Instruction>> ToRedo(RedoInsts); + OrderedSet ToRedo(RedoInsts); // Iterate over all instructions to be reevaluated and remove trivially dead // instructions. If any operand of the trivially dead instruction becomes // dead mark it for deletion as well. Continue this process until all @@ -2337,7 +2361,8 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { // Now that we have removed dead instructions, we can reoptimize the // remaining instructions. 
while (!RedoInsts.empty()) { - Instruction *I = RedoInsts.pop_back_val(); + Instruction *I = RedoInsts.front(); + RedoInsts.erase(RedoInsts.begin()); if (isInstructionTriviallyDead(I)) EraseInst(I); else diff --git a/lib/Transforms/Scalar/Reg2Mem.cpp b/lib/Transforms/Scalar/Reg2Mem.cpp index 96295683314c..018feb035a4f 100644 --- a/lib/Transforms/Scalar/Reg2Mem.cpp +++ b/lib/Transforms/Scalar/Reg2Mem.cpp @@ -17,6 +17,7 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/Statistic.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Function.h" @@ -25,7 +26,7 @@ #include "llvm/IR/Module.h" #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils.h" #include <list> using namespace llvm; diff --git a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index c44edbed8ed9..391e43f79121 100644 --- a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -28,6 +28,7 @@ #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -64,7 +65,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> #include <cassert> @@ -476,6 +476,12 @@ findBaseDefiningValueOfVector(Value *I) { if (auto *BC = dyn_cast<BitCastInst>(I)) return findBaseDefiningValue(BC->getOperand(0)); + // We assume that functions in the source language only return base + // pointers. This should probably be generalized via attributes to support + // both source language and internal functions. + if (isa<CallInst>(I) || isa<InvokeInst>(I)) + return BaseDefiningValueResult(I, true); + // A PHI or Select is a base defining value. The outer findBasePointer // algorithm is responsible for constructing a base value for this BDV. 
assert((isa<SelectInst>(I) || isa<PHINode>(I)) && @@ -610,8 +616,8 @@ static Value *findBaseDefiningValueCached(Value *I, DefiningValueMapTy &Cache) { Value *&Cached = Cache[I]; if (!Cached) { Cached = findBaseDefiningValue(I).BDV; - DEBUG(dbgs() << "fBDV-cached: " << I->getName() << " -> " - << Cached->getName() << "\n"); + LLVM_DEBUG(dbgs() << "fBDV-cached: " << I->getName() << " -> " + << Cached->getName() << "\n"); } assert(Cache[I] != nullptr); return Cached; @@ -842,9 +848,9 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { } #ifndef NDEBUG - DEBUG(dbgs() << "States after initialization:\n"); + LLVM_DEBUG(dbgs() << "States after initialization:\n"); for (auto Pair : States) { - DEBUG(dbgs() << " " << Pair.second << " for " << *Pair.first << "\n"); + LLVM_DEBUG(dbgs() << " " << Pair.second << " for " << *Pair.first << "\n"); } #endif @@ -917,9 +923,9 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { } #ifndef NDEBUG - DEBUG(dbgs() << "States after meet iteration:\n"); + LLVM_DEBUG(dbgs() << "States after meet iteration:\n"); for (auto Pair : States) { - DEBUG(dbgs() << " " << Pair.second << " for " << *Pair.first << "\n"); + LLVM_DEBUG(dbgs() << " " << Pair.second << " for " << *Pair.first << "\n"); } #endif @@ -960,7 +966,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { auto MakeBaseInstPlaceholder = [](Instruction *I) -> Instruction* { if (isa<PHINode>(I)) { BasicBlock *BB = I->getParent(); - int NumPreds = std::distance(pred_begin(BB), pred_end(BB)); + int NumPreds = pred_size(BB); assert(NumPreds > 0 && "how did we reach here"); std::string Name = suffixed_name_or(I, ".base", "base_phi"); return PHINode::Create(I->getType(), NumPreds, Name, I); @@ -1118,10 +1124,11 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { assert(BDV && Base); assert(!isKnownBaseResult(BDV) && "why did it get added?"); - DEBUG(dbgs() << "Updating base value cache" - << " for: " << BDV->getName() << " from: " - << (Cache.count(BDV) ? Cache[BDV]->getName().str() : "none") - << " to: " << Base->getName() << "\n"); + LLVM_DEBUG( + dbgs() << "Updating base value cache" + << " for: " << BDV->getName() << " from: " + << (Cache.count(BDV) ? Cache[BDV]->getName().str() : "none") + << " to: " << Base->getName() << "\n"); if (Cache.count(BDV)) { assert(isKnownBaseResult(Base) && @@ -1369,7 +1376,7 @@ public: assert(OldI != NewI && "Disallowed at construction?!"); assert((!IsDeoptimize || !New) && - "Deoptimize instrinsics are not replaced!"); + "Deoptimize intrinsics are not replaced!"); Old = nullptr; New = nullptr; @@ -1379,7 +1386,7 @@ public: if (IsDeoptimize) { // Note: we've inserted instructions, so the call to llvm.deoptimize may - // not necessarilly be followed by the matching return. + // not necessarily be followed by the matching return. 
auto *RI = cast<ReturnInst>(OldI->getParent()->getTerminator()); new UnreachableInst(RI->getContext(), RI); RI->eraseFromParent(); @@ -1805,7 +1812,7 @@ static void relocationViaAlloca( SmallVector<Instruction *, 20> Uses; // PERF: trade a linear scan for repeated reallocation - Uses.reserve(std::distance(Def->user_begin(), Def->user_end())); + Uses.reserve(Def->getNumUses()); for (User *U : Def->users()) { if (!isa<ConstantExpr>(U)) { // If the def has a ConstantExpr use, then the def is either a @@ -1817,7 +1824,7 @@ static void relocationViaAlloca( } } - std::sort(Uses.begin(), Uses.end()); + llvm::sort(Uses.begin(), Uses.end()); auto Last = std::unique(Uses.begin(), Uses.end()); Uses.erase(Last, Uses.end()); @@ -1977,7 +1984,7 @@ chainToBasePointerCost(SmallVectorImpl<Instruction*> &Chain, Cost += 2; } else { - llvm_unreachable("unsupported instruciton type during rematerialization"); + llvm_unreachable("unsupported instruction type during rematerialization"); } } @@ -2024,7 +2031,7 @@ static void rematerializeLiveValues(CallSite CS, SmallVector<Value *, 32> LiveValuesToBeDeleted; for (Value *LiveValue: Info.LiveSet) { - // For each live pointer find it's defining chain + // For each live pointer find its defining chain SmallVector<Instruction *, 3> ChainToBase; assert(Info.PointerToBase.count(LiveValue)); Value *RootOfChain = @@ -2461,22 +2468,8 @@ static void stripNonValidDataFromBody(Function &F) { continue; } - if (const MDNode *MD = I.getMetadata(LLVMContext::MD_tbaa)) { - assert(MD->getNumOperands() < 5 && "unrecognized metadata shape!"); - bool IsImmutableTBAA = - MD->getNumOperands() == 4 && - mdconst::extract<ConstantInt>(MD->getOperand(3))->getValue() == 1; - - if (!IsImmutableTBAA) - continue; // no work to do, MD_tbaa is already marked mutable - - MDNode *Base = cast<MDNode>(MD->getOperand(0)); - MDNode *Access = cast<MDNode>(MD->getOperand(1)); - uint64_t Offset = - mdconst::extract<ConstantInt>(MD->getOperand(2))->getZExtValue(); - - MDNode *MutableTBAA = - Builder.createTBAAStructTagNode(Base, Access, Offset); + if (MDNode *Tag = I.getMetadata(LLVMContext::MD_tbaa)) { + MDNode *MutableTBAA = Builder.createMutableTBAAAccessTag(Tag); I.setMetadata(LLVMContext::MD_tbaa, MutableTBAA); } @@ -2537,30 +2530,31 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, return false; }; + + // Delete any unreachable statepoints so that we don't have unrewritten + // statepoints surviving this pass. This makes testing easier and the + // resulting IR less confusing to human readers. + DeferredDominance DD(DT); + bool MadeChange = removeUnreachableBlocks(F, nullptr, &DD); + DD.flush(); + // Gather all the statepoints which need rewritten. Be careful to only // consider those in reachable code since we need to ask dominance queries // when rewriting. We'll delete the unreachable ones in a moment. SmallVector<CallSite, 64> ParsePointNeeded; - bool HasUnreachableStatepoint = false; for (Instruction &I : instructions(F)) { // TODO: only the ones with the flag set! if (NeedsRewrite(I)) { - if (DT.isReachableFromEntry(I.getParent())) - ParsePointNeeded.push_back(CallSite(&I)); - else - HasUnreachableStatepoint = true; + // NOTE removeUnreachableBlocks() is stronger than + // DominatorTree::isReachableFromEntry(). In other words + // removeUnreachableBlocks can remove some blocks for which + // isReachableFromEntry() returns true. 
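Two notes on the RewriteStatepointsForGC hunks above. First, the hand-rolled surgery on the TBAA tag's operands is replaced by MDBuilder::createMutableTBAAAccessTag, which rebuilds the access tag with the immutability bit cleared; a sketch of the new form in isolation:

    #include "llvm/IR/Instruction.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/MDBuilder.h"
    #include "llvm/IR/Metadata.h"

    // Clear the 'immutable' flag on an instruction's TBAA access tag, if any.
    static void makeTBAAMutable(llvm::Instruction &I) {
      if (llvm::MDNode *Tag = I.getMetadata(llvm::LLVMContext::MD_tbaa)) {
        llvm::MDBuilder Builder(I.getContext());
        I.setMetadata(llvm::LLVMContext::MD_tbaa,
                      Builder.createMutableTBAAAccessTag(Tag));
      }
    }

Second, runOnFunction now deletes unreachable blocks up front, through a DeferredDominance that batches the dominator-tree updates, instead of filtering unreachable statepoints during the scan; that reordering is what lets the scan assert reachability rather than track a HasUnreachableStatepoint flag.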
+ assert(DT.isReachableFromEntry(I.getParent()) && + "no unreachable blocks expected"); + ParsePointNeeded.push_back(CallSite(&I)); } } - bool MadeChange = false; - - // Delete any unreachable statepoints so that we don't have unrewritten - // statepoints surviving this pass. This makes testing easier and the - // resulting IR less confusing to human readers. Rather than be fancy, we - // just reuse a utility function which removes the unreachable blocks. - if (HasUnreachableStatepoint) - MadeChange |= removeUnreachableBlocks(F); - // Return early if no work to do. if (ParsePointNeeded.empty()) return MadeChange; diff --git a/lib/Transforms/Scalar/SCCP.cpp b/lib/Transforms/Scalar/SCCP.cpp index 66608ec631f6..5e3ddeda2d49 100644 --- a/lib/Transforms/Scalar/SCCP.cpp +++ b/lib/Transforms/Scalar/SCCP.cpp @@ -17,7 +17,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO/SCCP.h" #include "llvm/Transforms/Scalar/SCCP.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -30,6 +29,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueLattice.h" #include "llvm/Analysis/ValueLatticeUtils.h" #include "llvm/IR/BasicBlock.h" @@ -54,9 +54,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <utility> #include <vector> @@ -71,8 +69,6 @@ STATISTIC(NumDeadBlocks , "Number of basic blocks unreachable"); STATISTIC(IPNumInstRemoved, "Number of instructions removed by IPSCCP"); STATISTIC(IPNumArgsElimed ,"Number of arguments constant propagated by IPSCCP"); STATISTIC(IPNumGlobalConst, "Number of globals found to be constant by IPSCCP"); -STATISTIC(IPNumRangeInfoUsed, "Number of times constant range info was used by" - "IPSCCP"); namespace { @@ -223,6 +219,10 @@ class SCCPSolver : public InstVisitor<SCCPSolver> { /// represented here for efficient lookup. SmallPtrSet<Function *, 16> MRVFunctionsTracked; + /// MustTailFunctions - Each function here is a callee of non-removable + /// musttail call site. + SmallPtrSet<Function *, 16> MustTailCallees; + /// TrackingIncomingArguments - This is the set of functions for whose /// arguments we make optimistic assumptions about and try to prove as /// constants. @@ -257,7 +257,7 @@ public: bool MarkBlockExecutable(BasicBlock *BB) { if (!BBExecutable.insert(BB).second) return false; - DEBUG(dbgs() << "Marking Block Executable: " << BB->getName() << '\n'); + LLVM_DEBUG(dbgs() << "Marking Block Executable: " << BB->getName() << '\n'); BBWorkList.push_back(BB); // Add the block to the work list! return true; } @@ -289,6 +289,18 @@ public: TrackedRetVals.insert(std::make_pair(F, LatticeVal())); } + /// AddMustTailCallee - If the SCCP solver finds that this function is called + /// from non-removable musttail call site. + void AddMustTailCallee(Function *F) { + MustTailCallees.insert(F); + } + + /// Returns true if the given function is called from non-removable musttail + /// call site. 
+ bool isMustTailCallee(Function *F) { + return MustTailCallees.count(F); + } + void AddArgumentTrackedFunction(Function *F) { TrackingIncomingArguments.insert(F); } @@ -313,6 +325,10 @@ public: return BBExecutable.count(BB); } + // isEdgeFeasible - Return true if the control flow edge from the 'From' basic + // block to the 'To' basic block is currently feasible. + bool isEdgeFeasible(BasicBlock *From, BasicBlock *To); + std::vector<LatticeVal> getStructLatticeValueFor(Value *V) const { std::vector<LatticeVal> StructValues; auto *STy = dyn_cast<StructType>(V->getType()); @@ -325,20 +341,13 @@ public: return StructValues; } - ValueLatticeElement getLatticeValueFor(Value *V) { + const LatticeVal &getLatticeValueFor(Value *V) const { assert(!V->getType()->isStructTy() && "Should use getStructLatticeValueFor"); - std::pair<DenseMap<Value*, ValueLatticeElement>::iterator, bool> - PI = ParamState.insert(std::make_pair(V, ValueLatticeElement())); - ValueLatticeElement &LV = PI.first->second; - if (PI.second) { - DenseMap<Value*, LatticeVal>::const_iterator I = ValueState.find(V); - assert(I != ValueState.end() && - "V not found in ValueState nor Paramstate map!"); - LV = I->second.toValueLattice(); - } - - return LV; + DenseMap<Value *, LatticeVal>::const_iterator I = ValueState.find(V); + assert(I != ValueState.end() && + "V not found in ValueState nor Paramstate map!"); + return I->second; } /// getTrackedRetVals - Get the inferred return value map. @@ -358,6 +367,12 @@ public: return MRVFunctionsTracked; } + /// getMustTailCallees - Get the set of functions which are called + /// from non-removable musttail call sites. + const SmallPtrSet<Function *, 16> getMustTailCallees() { + return MustTailCallees; + } + /// markOverdefined - Mark the specified value overdefined. This /// works with both scalars and structs. void markOverdefined(Value *V) { @@ -393,55 +408,57 @@ private: // markConstant - Make a value be marked as "constant". If the value // is not already a constant, add it to the instruction work list so that // the users of the instruction are updated later. - void markConstant(LatticeVal &IV, Value *V, Constant *C) { - if (!IV.markConstant(C)) return; - DEBUG(dbgs() << "markConstant: " << *C << ": " << *V << '\n'); + bool markConstant(LatticeVal &IV, Value *V, Constant *C) { + if (!IV.markConstant(C)) return false; + LLVM_DEBUG(dbgs() << "markConstant: " << *C << ": " << *V << '\n'); pushToWorkList(IV, V); + return true; } - void markConstant(Value *V, Constant *C) { + bool markConstant(Value *V, Constant *C) { assert(!V->getType()->isStructTy() && "structs should use mergeInValue"); - markConstant(ValueState[V], V, C); + return markConstant(ValueState[V], V, C); } void markForcedConstant(Value *V, Constant *C) { assert(!V->getType()->isStructTy() && "structs should use mergeInValue"); LatticeVal &IV = ValueState[V]; IV.markForcedConstant(C); - DEBUG(dbgs() << "markForcedConstant: " << *C << ": " << *V << '\n'); + LLVM_DEBUG(dbgs() << "markForcedConstant: " << *C << ": " << *V << '\n'); pushToWorkList(IV, V); } // markOverdefined - Make a value be marked as "overdefined". If the // value is not already overdefined, add it to the overdefined instruction // work list so that the users of the instruction are updated later. 
- void markOverdefined(LatticeVal &IV, Value *V) { - if (!IV.markOverdefined()) return; - - DEBUG(dbgs() << "markOverdefined: "; - if (auto *F = dyn_cast<Function>(V)) - dbgs() << "Function '" << F->getName() << "'\n"; - else - dbgs() << *V << '\n'); + bool markOverdefined(LatticeVal &IV, Value *V) { + if (!IV.markOverdefined()) return false; + + LLVM_DEBUG(dbgs() << "markOverdefined: "; + if (auto *F = dyn_cast<Function>(V)) dbgs() + << "Function '" << F->getName() << "'\n"; + else dbgs() << *V << '\n'); // Only instructions go on the work list pushToWorkList(IV, V); + return true; } - void mergeInValue(LatticeVal &IV, Value *V, LatticeVal MergeWithV) { + bool mergeInValue(LatticeVal &IV, Value *V, LatticeVal MergeWithV) { if (IV.isOverdefined() || MergeWithV.isUnknown()) - return; // Noop. + return false; // Noop. if (MergeWithV.isOverdefined()) return markOverdefined(IV, V); if (IV.isUnknown()) return markConstant(IV, V, MergeWithV.getConstant()); if (IV.getConstant() != MergeWithV.getConstant()) return markOverdefined(IV, V); + return false; } - void mergeInValue(Value *V, LatticeVal MergeWithV) { + bool mergeInValue(Value *V, LatticeVal MergeWithV) { assert(!V->getType()->isStructTy() && "non-structs should use markConstant"); - mergeInValue(ValueState[V], V, MergeWithV); + return mergeInValue(ValueState[V], V, MergeWithV); } /// getValueState - Return the LatticeVal object that corresponds to the @@ -512,32 +529,27 @@ private: /// markEdgeExecutable - Mark a basic block as executable, adding it to the BB /// work list if it is not already executable. - void markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest) { + bool markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest) { if (!KnownFeasibleEdges.insert(Edge(Source, Dest)).second) - return; // This edge is already known to be executable! + return false; // This edge is already known to be executable! if (!MarkBlockExecutable(Dest)) { // If the destination is already executable, we just made an *edge* // feasible that wasn't before. Revisit the PHI nodes in the block // because they have potentially new operands. - DEBUG(dbgs() << "Marking Edge Executable: " << Source->getName() - << " -> " << Dest->getName() << '\n'); + LLVM_DEBUG(dbgs() << "Marking Edge Executable: " << Source->getName() + << " -> " << Dest->getName() << '\n'); - PHINode *PN; - for (BasicBlock::iterator I = Dest->begin(); - (PN = dyn_cast<PHINode>(I)); ++I) - visitPHINode(*PN); + for (PHINode &PN : Dest->phis()) + visitPHINode(PN); } + return true; } // getFeasibleSuccessors - Return a vector of booleans to indicate which // successors are reachable from a given terminator instruction. void getFeasibleSuccessors(TerminatorInst &TI, SmallVectorImpl<bool> &Succs); - // isEdgeFeasible - Return true if the control flow edge from the 'From' basic - // block to the 'To' basic block is currently feasible. - bool isEdgeFeasible(BasicBlock *From, BasicBlock *To); - // OperandChangedState - This method is invoked on all of the users of an // instruction that was just changed state somehow. Based on this // information, we need to update the specified user of this instruction. @@ -594,7 +606,7 @@ private: void visitInstruction(Instruction &I) { // All the instructions we don't do any special handling for just // go to overdefined. 
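The markConstant/markOverdefined/mergeInValue helpers above now return a bool meaning "did the lattice state change?", so callers can propagate work-list updates without re-deriving it, and the (void) casts seen below keep the void-returning visitors compiling. The lattice itself has three levels, unknown, constant, and overdefined, and a merge only ever moves a value toward overdefined. A toy model of those rules, independent of LLVM:

    // Toy SCCP lattice: Unknown -> Constant(C) -> Overdefined. mergeIn()
    // returns true when the state actually changed, mirroring the
    // bool-returning markConstant/markOverdefined/mergeInValue above.
    struct ToyLattice {
      enum Kind { Unknown, Constant, Overdefined } K = Unknown;
      int C = 0; // meaningful only when K == Constant

      bool mergeIn(const ToyLattice &O) {
        if (K == Overdefined || O.K == Unknown)
          return false; // nothing can change
        if (O.K == Overdefined || (K == Constant && C != O.C)) {
          K = Overdefined; // other side overdefined, or constants conflict
          return true;
        }
        if (K == Unknown) {
          K = Constant;
          C = O.C;
          return true; // learned our first constant
        }
        return false; // same constant on both sides
      }
    };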
- DEBUG(dbgs() << "SCCP: Don't know how to handle: " << I << '\n'); + LLVM_DEBUG(dbgs() << "SCCP: Don't know how to handle: " << I << '\n'); markOverdefined(&I); } }; @@ -681,68 +693,17 @@ void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI, return; } - DEBUG(dbgs() << "Unknown terminator instruction: " << TI << '\n'); + LLVM_DEBUG(dbgs() << "Unknown terminator instruction: " << TI << '\n'); llvm_unreachable("SCCP: Don't know how to handle this terminator!"); } // isEdgeFeasible - Return true if the control flow edge from the 'From' basic // block to the 'To' basic block is currently feasible. bool SCCPSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To) { - assert(BBExecutable.count(To) && "Dest should always be alive!"); - - // Make sure the source basic block is executable!! - if (!BBExecutable.count(From)) return false; - - // Check to make sure this edge itself is actually feasible now. - TerminatorInst *TI = From->getTerminator(); - if (auto *BI = dyn_cast<BranchInst>(TI)) { - if (BI->isUnconditional()) - return true; - - LatticeVal BCValue = getValueState(BI->getCondition()); - - // Overdefined condition variables mean the branch could go either way, - // undef conditions mean that neither edge is feasible yet. - ConstantInt *CI = BCValue.getConstantInt(); - if (!CI) - return !BCValue.isUnknown(); - - // Constant condition variables mean the branch can only go a single way. - return BI->getSuccessor(CI->isZero()) == To; - } - - // Unwinding instructions successors are always executable. - if (TI->isExceptional()) - return true; - - if (auto *SI = dyn_cast<SwitchInst>(TI)) { - if (SI->getNumCases() < 1) - return true; - - LatticeVal SCValue = getValueState(SI->getCondition()); - ConstantInt *CI = SCValue.getConstantInt(); - - if (!CI) - return !SCValue.isUnknown(); - - return SI->findCaseValue(CI)->getCaseSuccessor() == To; - } - - // In case of indirect branch and its address is a blockaddress, we mark - // the target as executable. - if (auto *IBR = dyn_cast<IndirectBrInst>(TI)) { - LatticeVal IBRValue = getValueState(IBR->getAddress()); - BlockAddress *Addr = IBRValue.getBlockAddress(); - - if (!Addr) - return !IBRValue.isUnknown(); - - // At this point, the indirectbr is branching on a blockaddress. - return Addr->getBasicBlock() == To; - } - - DEBUG(dbgs() << "Unknown terminator instruction: " << *TI << '\n'); - llvm_unreachable("SCCP: Don't know how to handle this terminator!"); + // Check if we've called markEdgeExecutable on the edge yet. (We could + // be more aggressive and try to consider edges which haven't been marked + // yet, but there isn't any need.) + return KnownFeasibleEdges.count(Edge(From, To)); } // visit Implementations - Something changed in this instruction, either an @@ -766,7 +727,7 @@ void SCCPSolver::visitPHINode(PHINode &PN) { // If this PN returns a struct, just mark the result overdefined. // TODO: We could do a lot better than this if code actually uses this. if (PN.getType()->isStructTy()) - return markOverdefined(&PN); + return (void)markOverdefined(&PN); if (getValueState(&PN).isOverdefined()) return; // Quick exit @@ -774,7 +735,7 @@ void SCCPSolver::visitPHINode(PHINode &PN) { // Super-extra-high-degree PHI nodes are unlikely to ever be marked constant, // and slow us down a lot. Just mark them overdefined. if (PN.getNumIncomingValues() > 64) - return markOverdefined(&PN); + return (void)markOverdefined(&PN); // Look at all of the executable operands of the PHI node. 
If any of them // are overdefined, the PHI becomes overdefined as well. If they are all @@ -790,7 +751,7 @@ void SCCPSolver::visitPHINode(PHINode &PN) { continue; if (IV.isOverdefined()) // PHI node becomes overdefined! - return markOverdefined(&PN); + return (void)markOverdefined(&PN); if (!OperandVal) { // Grab the first value. OperandVal = IV.getConstant(); @@ -804,7 +765,7 @@ void SCCPSolver::visitPHINode(PHINode &PN) { // Check to see if there are two different constants merging; if so, the PHI // node is overdefined. if (IV.getConstant() != OperandVal) - return markOverdefined(&PN); + return (void)markOverdefined(&PN); } // If we exited the loop, this means that the PHI node only has constant @@ -872,11 +833,11 @@ void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) { // If this returns a struct, mark all elements overdefined; we don't track // structs in structs. if (EVI.getType()->isStructTy()) - return markOverdefined(&EVI); + return (void)markOverdefined(&EVI); // If this is extracting from more than one level of struct, we don't know. if (EVI.getNumIndices() != 1) - return markOverdefined(&EVI); + return (void)markOverdefined(&EVI); Value *AggVal = EVI.getAggregateOperand(); if (AggVal->getType()->isStructTy()) { @@ -885,19 +846,19 @@ void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) { mergeInValue(getValueState(&EVI), &EVI, EltVal); } else { // Otherwise, must be extracting from an array. - return markOverdefined(&EVI); + return (void)markOverdefined(&EVI); } } void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) { auto *STy = dyn_cast<StructType>(IVI.getType()); if (!STy) - return markOverdefined(&IVI); + return (void)markOverdefined(&IVI); // If this has more than one index, we can't handle it; drive all results to // undef. if (IVI.getNumIndices() != 1) - return markOverdefined(&IVI); + return (void)markOverdefined(&IVI); Value *Aggr = IVI.getAggregateOperand(); unsigned Idx = *IVI.idx_begin(); @@ -926,7 +887,7 @@ void SCCPSolver::visitSelectInst(SelectInst &I) { // If this select returns a struct, just mark the result overdefined. // TODO: We could do a lot better than this if code actually uses this. if (I.getType()->isStructTy()) - return markOverdefined(&I); + return (void)markOverdefined(&I); LatticeVal CondValue = getValueState(I.getCondition()); if (CondValue.isUnknown()) @@ -947,12 +908,12 @@ void SCCPSolver::visitSelectInst(SelectInst &I) { // select ?, C, C -> C. if (TVal.isConstant() && FVal.isConstant() && TVal.getConstant() == FVal.getConstant()) - return markConstant(&I, FVal.getConstant()); + return (void)markConstant(&I, FVal.getConstant()); if (TVal.isUnknown()) // select ?, undef, X -> X. - return mergeInValue(&I, FVal); + return (void)mergeInValue(&I, FVal); if (FVal.isUnknown()) // select ?, X, undef -> X. - return mergeInValue(&I, TVal); + return (void)mergeInValue(&I, TVal); markOverdefined(&I); } @@ -970,7 +931,7 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) { // X op Y -> undef. if (isa<UndefValue>(C)) return; - return markConstant(IV, &I, C); + return (void)markConstant(IV, &I, C); } // If something is undef, wait for it to resolve. @@ -983,7 +944,7 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) { // overdefined, and we can replace it with zero. 
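The visitPHINode code above implements the standard constant-propagation merge: unknown operands are skipped until they resolve, any overdefined operand immediately forces the PHI overdefined, and otherwise every executable operand must agree on a single constant. Continuing the hypothetical MiniLattice sketch from earlier (the UDiv/SDiv special case the diff is entering resumes below):

#include <optional>
#include <vector>

// Fold a PHI whose executable operands are lattice values. Unknown operands
// are skipped (they may still resolve to the same constant); any overdefined
// operand wins immediately.
std::optional<int> foldPhi(const std::vector<MiniLattice> &Ops,
                           bool &Overdefined) {
  Overdefined = false;
  std::optional<int> Agreed;
  for (const MiniLattice &Op : Ops) {
    if (Op.St == MiniLattice::Unknown)
      continue;                       // wait for it to resolve
    if (Op.St == MiniLattice::Overdefined) {
      Overdefined = true;             // PHI node becomes overdefined
      return std::nullopt;
    }
    if (!Agreed) {
      Agreed = Op.Val;                // grab the first constant
    } else if (*Agreed != Op.Val) {
      Overdefined = true;             // two different constants merge
      return std::nullopt;
    }
  }
  return Agreed; // all executable operands agree (or none resolved yet)
}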
if (I.getOpcode() == Instruction::UDiv || I.getOpcode() == Instruction::SDiv) if (V1State.isConstant() && V1State.getConstant()->isNullValue()) - return markConstant(IV, &I, V1State.getConstant()); + return (void)markConstant(IV, &I, V1State.getConstant()); // If this is: // -> AND/MUL with 0 @@ -1006,12 +967,12 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) { // X and 0 = 0 // X * 0 = 0 if (NonOverdefVal->getConstant()->isNullValue()) - return markConstant(IV, &I, NonOverdefVal->getConstant()); + return (void)markConstant(IV, &I, NonOverdefVal->getConstant()); } else { // X or -1 = -1 if (ConstantInt *CI = NonOverdefVal->getConstantInt()) if (CI->isMinusOne()) - return markConstant(IV, &I, NonOverdefVal->getConstant()); + return (void)markConstant(IV, &I, NonOverdefVal->getConstant()); } } } @@ -1021,22 +982,36 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) { // Handle ICmpInst instruction. void SCCPSolver::visitCmpInst(CmpInst &I) { - LatticeVal V1State = getValueState(I.getOperand(0)); - LatticeVal V2State = getValueState(I.getOperand(1)); - LatticeVal &IV = ValueState[&I]; if (IV.isOverdefined()) return; - if (V1State.isConstant() && V2State.isConstant()) { - Constant *C = ConstantExpr::getCompare( - I.getPredicate(), V1State.getConstant(), V2State.getConstant()); + Value *Op1 = I.getOperand(0); + Value *Op2 = I.getOperand(1); + + // For parameters, use ParamState which includes constant range info if + // available. + auto V1Param = ParamState.find(Op1); + ValueLatticeElement V1State = (V1Param != ParamState.end()) + ? V1Param->second + : getValueState(Op1).toValueLattice(); + + auto V2Param = ParamState.find(Op2); + ValueLatticeElement V2State = V2Param != ParamState.end() + ? V2Param->second + : getValueState(Op2).toValueLattice(); + + Constant *C = V1State.getCompare(I.getPredicate(), I.getType(), V2State); + if (C) { if (isa<UndefValue>(C)) return; - return markConstant(IV, &I, C); + LatticeVal CV; + CV.markConstant(C); + mergeInValue(&I, CV); + return; } // If operands are still unknown, wait for it to resolve. - if (!V1State.isOverdefined() && !V2State.isOverdefined()) + if (!V1State.isOverdefined() && !V2State.isOverdefined() && !IV.isConstant()) return; markOverdefined(&I); @@ -1056,7 +1031,7 @@ void SCCPSolver::visitGetElementPtrInst(GetElementPtrInst &I) { return; // Operands are not resolved yet. if (State.isOverdefined()) - return markOverdefined(&I); + return (void)markOverdefined(&I); assert(State.isConstant() && "Unknown state!"); Operands.push_back(State.getConstant()); @@ -1094,7 +1069,7 @@ void SCCPSolver::visitStoreInst(StoreInst &SI) { void SCCPSolver::visitLoadInst(LoadInst &I) { // If this load is of a struct, just mark the result overdefined. if (I.getType()->isStructTy()) - return markOverdefined(&I); + return (void)markOverdefined(&I); LatticeVal PtrVal = getValueState(I.getOperand(0)); if (PtrVal.isUnknown()) return; // The pointer is not resolved yet! @@ -1103,13 +1078,17 @@ void SCCPSolver::visitLoadInst(LoadInst &I) { if (IV.isOverdefined()) return; if (!PtrVal.isConstant() || I.isVolatile()) - return markOverdefined(IV, &I); + return (void)markOverdefined(IV, &I); Constant *Ptr = PtrVal.getConstant(); // load null is undefined. - if (isa<ConstantPointerNull>(Ptr) && I.getPointerAddressSpace() == 0) - return; + if (isa<ConstantPointerNull>(Ptr)) { + if (NullPointerIsDefined(I.getFunction(), I.getPointerAddressSpace())) + return (void)markOverdefined(IV, &I); + else + return; + } // Transform load (constant global) into the value loaded. 
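The visitCmpInst rewrite above is the main semantic change in this hunk: compares now consult ParamState, a ValueLatticeElement that can carry a constant range for a function argument, so an icmp can fold even when neither operand is a single constant. A hypothetical standalone illustration of folding a signed less-than over closed intervals (LLVM's ConstantRange plays this role in the real code; the load-folding code the diff just introduced continues below):

#include <cstdint>
#include <optional>

// Hypothetical closed interval [Lo, Hi] of signed values.
struct Interval { int64_t Lo, Hi; };

// Try to fold "A slt B" from range information alone: if every value in A is
// below every value in B the compare is always true; if every value in A is
// at or above B's maximum it is always false; otherwise it stays unknown
// (overdefined in the solver).
std::optional<bool> foldSLT(Interval A, Interval B) {
  if (A.Hi < B.Lo)
    return true;        // ranges fully ordered: always true
  if (A.Lo >= B.Hi)
    return false;       // ranges fully ordered the other way: always false
  return std::nullopt;  // ranges overlap: cannot decide
}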
if (auto *GV = dyn_cast<GlobalVariable>(Ptr)) { @@ -1128,7 +1107,7 @@ void SCCPSolver::visitLoadInst(LoadInst &I) { if (Constant *C = ConstantFoldLoadFromConstPtr(Ptr, I.getType(), DL)) { if (isa<UndefValue>(C)) return; - return markConstant(IV, &I, C); + return (void)markConstant(IV, &I, C); } // Otherwise we cannot say for certain what value this load will produce. @@ -1160,7 +1139,7 @@ CallOverdefined: if (State.isUnknown()) return; // Operands are not resolved yet. if (State.isOverdefined()) - return markOverdefined(I); + return (void)markOverdefined(I); assert(State.isConstant() && "Unknown state!"); Operands.push_back(State.getConstant()); } @@ -1174,12 +1153,12 @@ CallOverdefined: // call -> undef. if (isa<UndefValue>(C)) return; - return markConstant(I, C); + return (void)markConstant(I, C); } } // Otherwise, we don't know anything about this call, mark it overdefined. - return markOverdefined(I); + return (void)markOverdefined(I); } // If this is a local function that doesn't have its address taken, mark its @@ -1207,8 +1186,16 @@ CallOverdefined: } else { // Most other parts of the Solver still only use the simpler value // lattice, so we propagate changes for parameters to both lattices. - getParamState(&*AI).mergeIn(getValueState(*CAI).toValueLattice(), DL); - mergeInValue(&*AI, getValueState(*CAI)); + LatticeVal ConcreteArgument = getValueState(*CAI); + bool ParamChanged = + getParamState(&*AI).mergeIn(ConcreteArgument.toValueLattice(), DL); + bool ValueChanged = mergeInValue(&*AI, ConcreteArgument); + // Add argument to work list, if the state of a parameter changes but + // ValueState does not change (because it is already overdefined there), + // We have to take changes in ParamState into account, as it is used + // when evaluating Cmp instructions. + if (!ValueChanged && ParamChanged) + pushToWorkList(ValueState[&*AI], &*AI); } } } @@ -1242,7 +1229,7 @@ void SCCPSolver::Solve() { while (!OverdefinedInstWorkList.empty()) { Value *I = OverdefinedInstWorkList.pop_back_val(); - DEBUG(dbgs() << "\nPopped off OI-WL: " << *I << '\n'); + LLVM_DEBUG(dbgs() << "\nPopped off OI-WL: " << *I << '\n'); // "I" got into the work list because it either made the transition from // bottom to constant, or to overdefined. @@ -1260,7 +1247,7 @@ void SCCPSolver::Solve() { while (!InstWorkList.empty()) { Value *I = InstWorkList.pop_back_val(); - DEBUG(dbgs() << "\nPopped off I-WL: " << *I << '\n'); + LLVM_DEBUG(dbgs() << "\nPopped off I-WL: " << *I << '\n'); // "I" got into the work list because it made the transition from undef to // constant. @@ -1280,7 +1267,7 @@ void SCCPSolver::Solve() { BasicBlock *BB = BBWorkList.back(); BBWorkList.pop_back(); - DEBUG(dbgs() << "\nPopped off BBWL: " << *BB << '\n'); + LLVM_DEBUG(dbgs() << "\nPopped off BBWL: " << *BB << '\n'); // Notify all instructions in this basic block that they are newly // executable. @@ -1501,7 +1488,11 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { break; case Instruction::ICmp: // X == undef -> undef. Other comparisons get more complicated. - if (cast<ICmpInst>(&I)->isEquality()) + Op0LV = getValueState(I.getOperand(0)); + Op1LV = getValueState(I.getOperand(1)); + + if ((Op0LV.isUnknown() || Op1LV.isUnknown()) && + cast<ICmpInst>(&I)->isEquality()) break; markOverdefined(&I); return true; @@ -1546,11 +1537,14 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { } // Otherwise, it is a branch on a symbolic value which is currently - // considered to be undef. Handle this by forcing the input value to the - // branch to false. 
- markForcedConstant(BI->getCondition(), - ConstantInt::getFalse(TI->getContext())); - return true; + // considered to be undef. Make sure some edge is executable, so a + // branch on "undef" always flows somewhere. + // FIXME: Distinguish between dead code and an LLVM "undef" value. + BasicBlock *DefaultSuccessor = TI->getSuccessor(1); + if (markEdgeExecutable(&BB, DefaultSuccessor)) + return true; + + continue; } if (auto *IBR = dyn_cast<IndirectBrInst>(TI)) { @@ -1571,11 +1565,15 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { } // Otherwise, it is a branch on a symbolic value which is currently - // considered to be undef. Handle this by forcing the input value to the - // branch to the first successor. - markForcedConstant(IBR->getAddress(), - BlockAddress::get(IBR->getSuccessor(0))); - return true; + // considered to be undef. Make sure some edge is executable, so a + // branch on "undef" always flows somewhere. + // FIXME: IndirectBr on "undef" doesn't actually need to go anywhere: + // we can assume the branch has undefined behavior instead. + BasicBlock *DefaultSuccessor = IBR->getSuccessor(0); + if (markEdgeExecutable(&BB, DefaultSuccessor)) + return true; + + continue; } if (auto *SI = dyn_cast<SwitchInst>(TI)) { @@ -1590,56 +1588,19 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { return true; } - markForcedConstant(SI->getCondition(), SI->case_begin()->getCaseValue()); - return true; - } - } - - return false; -} - -static bool tryToReplaceWithConstantRange(SCCPSolver &Solver, Value *V) { - bool Changed = false; - - // Currently we only use range information for integer values. - if (!V->getType()->isIntegerTy()) - return false; - - const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); - if (!IV.isConstantRange()) - return false; + // Otherwise, it is a branch on a symbolic value which is currently + // considered to be undef. Make sure some edge is executable, so a + // branch on "undef" always flows somewhere. + // FIXME: Distinguish between dead code and an LLVM "undef" value. 
+ BasicBlock *DefaultSuccessor = SI->case_begin()->getCaseSuccessor(); + if (markEdgeExecutable(&BB, DefaultSuccessor)) + return true; - for (auto UI = V->uses().begin(), E = V->uses().end(); UI != E;) { - const Use &U = *UI++; - auto *Icmp = dyn_cast<ICmpInst>(U.getUser()); - if (!Icmp || !Solver.isBlockExecutable(Icmp->getParent())) continue; - - auto getIcmpLatticeValue = [&](Value *Op) { - if (auto *C = dyn_cast<Constant>(Op)) - return ValueLatticeElement::get(C); - return Solver.getLatticeValueFor(Op); - }; - - ValueLatticeElement A = getIcmpLatticeValue(Icmp->getOperand(0)); - ValueLatticeElement B = getIcmpLatticeValue(Icmp->getOperand(1)); - - Constant *C = nullptr; - if (A.satisfiesPredicate(Icmp->getPredicate(), B)) - C = ConstantInt::getTrue(Icmp->getType()); - else if (A.satisfiesPredicate(Icmp->getInversePredicate(), B)) - C = ConstantInt::getFalse(Icmp->getType()); - - if (C) { - Icmp->replaceAllUsesWith(C); - DEBUG(dbgs() << "Replacing " << *Icmp << " with " << *C - << ", because of range information " << A << " " << B - << "\n"); - Icmp->eraseFromParent(); - Changed = true; } } - return Changed; + + return false; } static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { @@ -1659,22 +1620,31 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { } Const = ConstantStruct::get(ST, ConstVals); } else { - const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); + const LatticeVal &IV = Solver.getLatticeValueFor(V); if (IV.isOverdefined()) return false; - if (IV.isConstantRange()) { - if (IV.getConstantRange().isSingleElement()) - Const = - ConstantInt::get(V->getType(), IV.asConstantInteger().getValue()); - else - return false; - } else - Const = - IV.isConstant() ? IV.getConstant() : UndefValue::get(V->getType()); + Const = IV.isConstant() ? IV.getConstant() : UndefValue::get(V->getType()); } assert(Const && "Constant is nullptr here!"); - DEBUG(dbgs() << " Constant: " << *Const << " = " << *V << '\n'); + + // Replacing `musttail` instructions with constant breaks `musttail` invariant + // unless the call itself can be removed + CallInst *CI = dyn_cast<CallInst>(V); + if (CI && CI->isMustTailCall() && !CI->isSafeToRemove()) { + CallSite CS(CI); + Function *F = CS.getCalledFunction(); + + // Don't zap returns of the callee + if (F) + Solver.AddMustTailCallee(F); + + LLVM_DEBUG(dbgs() << " Can\'t treat the result of musttail call : " << *CI + << " as a constant\n"); + return false; + } + + LLVM_DEBUG(dbgs() << " Constant: " << *Const << " = " << *V << '\n'); // Replaces all of the uses of a variable with uses of the constant. V->replaceAllUsesWith(Const); @@ -1685,7 +1655,7 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { // and return true if the function was modified. static bool runSCCP(Function &F, const DataLayout &DL, const TargetLibraryInfo *TLI) { - DEBUG(dbgs() << "SCCP on function '" << F.getName() << "'\n"); + LLVM_DEBUG(dbgs() << "SCCP on function '" << F.getName() << "'\n"); SCCPSolver Solver(DL, TLI); // Mark the first block of the function as being executable. 
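The musttail guard added above exists because a musttail call's result must flow unchanged into the enclosing ret; replacing that result with a constant would break the contract unless the call itself can be deleted. A condensed sketch of the same check, assuming the CallInst API the patch itself uses:

#include "llvm/IR/Instructions.h"
using namespace llvm;

// Minimal sketch of the guard above: constant-propagating the result of a
// musttail call would leave "ret %v" no longer returning the call's value,
// which breaks the musttail contract, so bail out unless the call is dead.
static bool canReplaceWithConstant(Instruction *I) {
  if (auto *CI = dyn_cast<CallInst>(I))
    if (CI->isMustTailCall() && !CI->isSafeToRemove())
      return false; // keep the call; also don't zap the callee's returns
  return true;
}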
@@ -1699,7 +1669,7 @@ static bool runSCCP(Function &F, const DataLayout &DL, bool ResolvedUndefs = true; while (ResolvedUndefs) { Solver.Solve(); - DEBUG(dbgs() << "RESOLVING UNDEFs\n"); + LLVM_DEBUG(dbgs() << "RESOLVING UNDEFs\n"); ResolvedUndefs = Solver.ResolvedUndefsIn(F); } @@ -1711,7 +1681,7 @@ static bool runSCCP(Function &F, const DataLayout &DL, for (BasicBlock &BB : F) { if (!Solver.isBlockExecutable(&BB)) { - DEBUG(dbgs() << " BasicBlock Dead:" << BB); + LLVM_DEBUG(dbgs() << " BasicBlock Dead:" << BB); ++NumDeadBlocks; NumInstRemoved += removeAllNonTerminatorAndEHPadInstructions(&BB); @@ -1748,6 +1718,7 @@ PreservedAnalyses SCCPPass::run(Function &F, FunctionAnalysisManager &AM) { auto PA = PreservedAnalyses(); PA.preserve<GlobalsAA>(); + PA.preserveSet<CFGAnalyses>(); return PA; } @@ -1770,6 +1741,7 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); + AU.setPreservesCFG(); } // runOnFunction - Run the Sparse Conditional Constant Propagation @@ -1804,14 +1776,30 @@ static void findReturnsToZap(Function &F, if (!Solver.isArgumentTrackedFunction(&F)) return; - for (BasicBlock &BB : F) + // There is a non-removable musttail call site of this function. Zapping + // returns is not allowed. + if (Solver.isMustTailCallee(&F)) { + LLVM_DEBUG(dbgs() << "Can't zap returns of the function : " << F.getName() + << " due to present musttail call of it\n"); + return; + } + + for (BasicBlock &BB : F) { + if (CallInst *CI = BB.getTerminatingMustTailCall()) { + LLVM_DEBUG(dbgs() << "Can't zap return of the block due to present " + << "musttail call : " << *CI << "\n"); + (void)CI; + return; + } + if (auto *RI = dyn_cast<ReturnInst>(BB.getTerminator())) if (!isa<UndefValue>(RI->getOperand(0))) ReturnsToZap.push_back(RI); + } } -static bool runIPSCCP(Module &M, const DataLayout &DL, - const TargetLibraryInfo *TLI) { +bool llvm::runIPSCCP(Module &M, const DataLayout &DL, + const TargetLibraryInfo *TLI) { SCCPSolver Solver(DL, TLI); // Loop over all functions, marking arguments to those with their addresses @@ -1851,13 +1839,17 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, // Solve for constants. bool ResolvedUndefs = true; + Solver.Solve(); while (ResolvedUndefs) { - Solver.Solve(); - - DEBUG(dbgs() << "RESOLVING UNDEFS\n"); + LLVM_DEBUG(dbgs() << "RESOLVING UNDEFS\n"); ResolvedUndefs = false; for (Function &F : M) - ResolvedUndefs |= Solver.ResolvedUndefsIn(F); + if (Solver.ResolvedUndefsIn(F)) { + // We run Solve() after we resolved an undef in a function, because + // we might deduce a fact that eliminates an undef in another function. 
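The comment above captures the point of the new interprocedural driver: resolving an undef in one function can expose facts that eliminate undefs elsewhere, so the solver re-runs immediately rather than batching all resolutions per round. Distilled to its shape (hypothetical Solver interface mirroring the calls in the patch; the actual loop body follows below):

// Reach a fixed point, then force progress on undefs, re-solving after each
// successful resolution so new facts propagate across function boundaries.
template <typename SolverT, typename ModuleT>
void solveModule(SolverT &Solver, ModuleT &M) {
  Solver.Solve();                       // initial fixed point
  bool ResolvedUndefs = true;
  while (ResolvedUndefs) {
    ResolvedUndefs = false;
    for (auto &F : M)
      if (Solver.ResolvedUndefsIn(F)) { // forced an undef somewhere
        Solver.Solve();                 // propagate before the next attempt
        ResolvedUndefs = true;
      }
  }
}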
+ Solver.Solve(); + ResolvedUndefs = true; + } } bool MadeChanges = false; @@ -1877,18 +1869,12 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, ++IPNumArgsElimed; continue; } - - if (!AI->use_empty() && tryToReplaceWithConstantRange(Solver, &*AI)) - ++IPNumRangeInfoUsed; } for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { if (!Solver.isBlockExecutable(&*BB)) { - DEBUG(dbgs() << " BasicBlock Dead:" << *BB); - + LLVM_DEBUG(dbgs() << " BasicBlock Dead:" << *BB); ++NumDeadBlocks; - NumInstRemoved += - changeToUnreachable(BB->getFirstNonPHI(), /*UseLLVMTrap=*/false); MadeChanges = true; @@ -1902,7 +1888,7 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, if (Inst->getType()->isVoidTy()) continue; if (tryToReplaceWithConstant(Solver, Inst)) { - if (!isa<CallInst>(Inst) && !isa<TerminatorInst>(Inst)) + if (Inst->isSafeToRemove()) Inst->eraseFromParent(); // Hey, we just changed something! MadeChanges = true; @@ -1911,6 +1897,17 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, } } + // Change dead blocks to unreachable. We do it after replacing constants in + // all executable blocks, because changeToUnreachable may remove PHI nodes + // in executable blocks we found values for. The function's entry block is + // not part of BlocksToErase, so we have to handle it separately. + for (BasicBlock *BB : BlocksToErase) + NumInstRemoved += + changeToUnreachable(BB->getFirstNonPHI(), /*UseLLVMTrap=*/false); + if (!Solver.isBlockExecutable(&F.front())) + NumInstRemoved += changeToUnreachable(F.front().getFirstNonPHI(), + /*UseLLVMTrap=*/false); + // Now that all instructions in the function are constant folded, erase dead // blocks, because we can now use ConstantFoldTerminator to get rid of // in-edges. @@ -1930,31 +1927,33 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, bool Folded = ConstantFoldTerminator(I->getParent()); if (!Folded) { - // The constant folder may not have been able to fold the terminator - // if this is a branch or switch on undef. Fold it manually as a - // branch to the first successor. -#ifndef NDEBUG - if (auto *BI = dyn_cast<BranchInst>(I)) { - assert(BI->isConditional() && isa<UndefValue>(BI->getCondition()) && - "Branch should be foldable!"); - } else if (auto *SI = dyn_cast<SwitchInst>(I)) { - assert(isa<UndefValue>(SI->getCondition()) && "Switch should fold"); + // If the branch can't be folded, we must have forced an edge + // for an indeterminate value. Force the terminator to fold + // to that edge. + Constant *C; + BasicBlock *Dest; + if (SwitchInst *SI = dyn_cast<SwitchInst>(I)) { + Dest = SI->case_begin()->getCaseSuccessor(); + C = SI->case_begin()->getCaseValue(); + } else if (BranchInst *BI = dyn_cast<BranchInst>(I)) { + Dest = BI->getSuccessor(1); + C = ConstantInt::getFalse(BI->getContext()); + } else if (IndirectBrInst *IBR = dyn_cast<IndirectBrInst>(I)) { + Dest = IBR->getSuccessor(0); + C = BlockAddress::get(IBR->getSuccessor(0)); } else { - llvm_unreachable("Didn't fold away reference to block!"); + llvm_unreachable("Unexpected terminator instruction"); } -#endif + assert(Solver.isEdgeFeasible(I->getParent(), Dest) && + "Didn't find feasible edge?"); + (void)Dest; - // Make this an uncond branch to the first successor. - TerminatorInst *TI = I->getParent()->getTerminator(); - BranchInst::Create(TI->getSuccessor(0), TI); - - // Remove entries in successor phi nodes to remove edges. 
- for (unsigned i = 1, e = TI->getNumSuccessors(); i != e; ++i) - TI->getSuccessor(i)->removePredecessor(TI->getParent()); - - // Remove the old terminator. - TI->eraseFromParent(); + I->setOperand(0, C); + Folded = ConstantFoldTerminator(I->getParent()); } + assert(Folded && + "Expect TermInst on constantint or blockaddress to be folded"); + (void) Folded; } // Finally, delete the basic block. @@ -2005,7 +2004,8 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, GlobalVariable *GV = I->first; assert(!I->second.isOverdefined() && "Overdefined values should have been taken out of the map!"); - DEBUG(dbgs() << "Found that GV '" << GV->getName() << "' is constant!\n"); + LLVM_DEBUG(dbgs() << "Found that GV '" << GV->getName() + << "' is constant!\n"); while (!GV->use_empty()) { StoreInst *SI = cast<StoreInst>(GV->user_back()); SI->eraseFromParent(); @@ -2016,55 +2016,3 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, return MadeChanges; } - -PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) { - const DataLayout &DL = M.getDataLayout(); - auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); - if (!runIPSCCP(M, DL, &TLI)) - return PreservedAnalyses::all(); - return PreservedAnalyses::none(); -} - -namespace { - -//===--------------------------------------------------------------------===// -// -/// IPSCCP Class - This class implements interprocedural Sparse Conditional -/// Constant Propagation. -/// -class IPSCCPLegacyPass : public ModulePass { -public: - static char ID; - - IPSCCPLegacyPass() : ModulePass(ID) { - initializeIPSCCPLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - const DataLayout &DL = M.getDataLayout(); - const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - return runIPSCCP(M, DL, TLI); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } -}; - -} // end anonymous namespace - -char IPSCCPLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(IPSCCPLegacyPass, "ipsccp", - "Interprocedural Sparse Conditional Constant Propagation", - false, false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(IPSCCPLegacyPass, "ipsccp", - "Interprocedural Sparse Conditional Constant Propagation", - false, false) - -// createIPSCCPPass - This is the public interface to this file. -ModulePass *llvm::createIPSCCPPass() { return new IPSCCPLegacyPass(); } diff --git a/lib/Transforms/Scalar/SROA.cpp b/lib/Transforms/Scalar/SROA.cpp index bfe3754f0769..6c3f012c6280 100644 --- a/lib/Transforms/Scalar/SROA.cpp +++ b/lib/Transforms/Scalar/SROA.cpp @@ -42,6 +42,8 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/PtrUseVisitor.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Config/llvm-config.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/ConstantFolder.h" @@ -79,7 +81,6 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> #include <cassert> @@ -124,14 +125,9 @@ static cl::opt<bool> SROARandomShuffleSlices("sroa-random-shuffle-slices", static cl::opt<bool> SROAStrictInbounds("sroa-strict-inbounds", cl::init(false), cl::Hidden); -/// Hidden option to allow more aggressive splitting. 
-static cl::opt<bool> -SROASplitNonWholeAllocaSlices("sroa-split-nonwhole-alloca-slices", - cl::init(false), cl::Hidden); - namespace { -/// \brief A custom IRBuilder inserter which prefixes all names, but only in +/// A custom IRBuilder inserter which prefixes all names, but only in /// Assert builds. class IRBuilderPrefixedInserter : public IRBuilderDefaultInserter { std::string Prefix; @@ -151,23 +147,23 @@ protected: } }; -/// \brief Provide a type for IRBuilder that drops names in release builds. +/// Provide a type for IRBuilder that drops names in release builds. using IRBuilderTy = IRBuilder<ConstantFolder, IRBuilderPrefixedInserter>; -/// \brief A used slice of an alloca. +/// A used slice of an alloca. /// /// This structure represents a slice of an alloca used by some instruction. It /// stores both the begin and end offsets of this use, a pointer to the use /// itself, and a flag indicating whether we can classify the use as splittable /// or not when forming partitions of the alloca. class Slice { - /// \brief The beginning offset of the range. + /// The beginning offset of the range. uint64_t BeginOffset = 0; - /// \brief The ending offset, not included in the range. + /// The ending offset, not included in the range. uint64_t EndOffset = 0; - /// \brief Storage for both the use of this slice and whether it can be + /// Storage for both the use of this slice and whether it can be /// split. PointerIntPair<Use *, 1, bool> UseAndIsSplittable; @@ -189,7 +185,7 @@ public: bool isDead() const { return getUse() == nullptr; } void kill() { UseAndIsSplittable.setPointer(nullptr); } - /// \brief Support for ordering ranges. + /// Support for ordering ranges. /// /// This provides an ordering over ranges such that start offsets are /// always increasing, and within equal start offsets, the end offsets are @@ -207,7 +203,7 @@ public: return false; } - /// \brief Support comparison with a single offset to allow binary searches. + /// Support comparison with a single offset to allow binary searches. friend LLVM_ATTRIBUTE_UNUSED bool operator<(const Slice &LHS, uint64_t RHSOffset) { return LHS.beginOffset() < RHSOffset; @@ -233,7 +229,7 @@ template <> struct isPodLike<Slice> { static const bool value = true; }; } // end namespace llvm -/// \brief Representation of the alloca slices. +/// Representation of the alloca slices. /// /// This class represents the slices of an alloca which are formed by its /// various uses. If a pointer escapes, we can't fully build a representation @@ -242,16 +238,16 @@ template <> struct isPodLike<Slice> { static const bool value = true; }; /// starting at a particular offset before splittable slices. class llvm::sroa::AllocaSlices { public: - /// \brief Construct the slices of a particular alloca. + /// Construct the slices of a particular alloca. AllocaSlices(const DataLayout &DL, AllocaInst &AI); - /// \brief Test whether a pointer to the allocation escapes our analysis. + /// Test whether a pointer to the allocation escapes our analysis. /// /// If this is true, the slices are never fully built and should be /// ignored. bool isEscaped() const { return PointerEscapingInstr; } - /// \brief Support for iterating over the slices. + /// Support for iterating over the slices. /// @{ using iterator = SmallVectorImpl<Slice>::iterator; using range = iterator_range<iterator>; @@ -266,10 +262,10 @@ public: const_iterator end() const { return Slices.end(); } /// @} - /// \brief Erase a range of slices. + /// Erase a range of slices. 
void erase(iterator Start, iterator Stop) { Slices.erase(Start, Stop); } - /// \brief Insert new slices for this alloca. + /// Insert new slices for this alloca. /// /// This moves the slices into the alloca's slices collection, and re-sorts /// everything so that the usual ordering properties of the alloca's slices @@ -278,7 +274,7 @@ public: int OldSize = Slices.size(); Slices.append(NewSlices.begin(), NewSlices.end()); auto SliceI = Slices.begin() + OldSize; - std::sort(SliceI, Slices.end()); + llvm::sort(SliceI, Slices.end()); std::inplace_merge(Slices.begin(), SliceI, Slices.end()); } @@ -287,10 +283,10 @@ public: class partition_iterator; iterator_range<partition_iterator> partitions(); - /// \brief Access the dead users for this alloca. + /// Access the dead users for this alloca. ArrayRef<Instruction *> getDeadUsers() const { return DeadUsers; } - /// \brief Access the dead operands referring to this alloca. + /// Access the dead operands referring to this alloca. /// /// These are operands which cannot actually be used to refer to the /// alloca as they are outside its range and the user doesn't correct for @@ -316,11 +312,11 @@ private: friend class AllocaSlices::SliceBuilder; #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) - /// \brief Handle to alloca instruction to simplify method interfaces. + /// Handle to alloca instruction to simplify method interfaces. AllocaInst &AI; #endif - /// \brief The instruction responsible for this alloca not having a known set + /// The instruction responsible for this alloca not having a known set /// of slices. /// /// When an instruction (potentially) escapes the pointer to the alloca, we /// alloca. This will be null if the alloca slices are analyzed successfully. Instruction *PointerEscapingInstr; - /// \brief The slices of the alloca. + /// The slices of the alloca. /// /// We store a vector of the slices formed by uses of the alloca here. This /// vector is sorted by increasing begin offset, and then the unsplittable /// details. SmallVector<Slice, 8> Slices; - /// \brief Instructions which will become dead if we rewrite the alloca. + /// Instructions which will become dead if we rewrite the alloca. /// /// Note that these are not separated by slice. This is because we expect an /// alloca to be completely rewritten or not rewritten at all. If rewritten, /// they come from outside of the allocated space. SmallVector<Instruction *, 8> DeadUsers; - /// \brief Operands which will become dead if we rewrite the alloca. + /// Operands which will become dead if we rewrite the alloca. /// /// These are operands that in their particular use can be replaced with /// undef when we rewrite the alloca. These show up in out-of-bounds inputs @@ -355,7 +351,7 @@ private: SmallVector<Use *, 8> DeadOperands; }; -/// \brief A partition of the slices. +/// A partition of the slices. /// /// An ephemeral representation for a range of slices which can be viewed as /// a partition of the alloca. This range represents a span of the alloca's @@ -371,32 +367,32 @@ private: using iterator = AllocaSlices::iterator; - /// \brief The beginning and ending offsets of the alloca for this + /// The beginning and ending offsets of the alloca for this /// partition. uint64_t BeginOffset, EndOffset; - /// \brief The start and end iterators of this partition. + /// The start and end iterators of this partition. 
iterator SI, SJ; - /// \brief A collection of split slice tails overlapping the partition. + /// A collection of split slice tails overlapping the partition. SmallVector<Slice *, 4> SplitTails; - /// \brief Raw constructor builds an empty partition starting and ending at + /// Raw constructor builds an empty partition starting and ending at /// the given iterator. Partition(iterator SI) : SI(SI), SJ(SI) {} public: - /// \brief The start offset of this partition. + /// The start offset of this partition. /// /// All of the contained slices start at or after this offset. uint64_t beginOffset() const { return BeginOffset; } - /// \brief The end offset of this partition. + /// The end offset of this partition. /// /// All of the contained slices end at or before this offset. uint64_t endOffset() const { return EndOffset; } - /// \brief The size of the partition. + /// The size of the partition. /// /// Note that this can never be zero. uint64_t size() const { @@ -404,7 +400,7 @@ public: return EndOffset - BeginOffset; } - /// \brief Test whether this partition contains no slices, and merely spans + /// Test whether this partition contains no slices, and merely spans /// a region occupied by split slices. bool empty() const { return SI == SJ; } @@ -421,7 +417,7 @@ public: iterator end() const { return SJ; } /// @} - /// \brief Get the sequence of split slice tails. + /// Get the sequence of split slice tails. /// /// These tails are of slices which start before this partition but are /// split and overlap into the partition. We accumulate these while forming @@ -429,7 +425,7 @@ public: ArrayRef<Slice *> splitSliceTails() const { return SplitTails; } }; -/// \brief An iterator over partitions of the alloca's slices. +/// An iterator over partitions of the alloca's slices. /// /// This iterator implements the core algorithm for partitioning the alloca's /// slices. It is a forward iterator as we don't support backtracking for @@ -443,18 +439,18 @@ class AllocaSlices::partition_iterator Partition> { friend class AllocaSlices; - /// \brief Most of the state for walking the partitions is held in a class + /// Most of the state for walking the partitions is held in a class /// with a nice interface for examining them. Partition P; - /// \brief We need to keep the end of the slices to know when to stop. + /// We need to keep the end of the slices to know when to stop. AllocaSlices::iterator SE; - /// \brief We also need to keep track of the maximum split end offset seen. + /// We also need to keep track of the maximum split end offset seen. /// FIXME: Do we really? uint64_t MaxSplitSliceEndOffset = 0; - /// \brief Sets the partition to be empty at given iterator, and sets the + /// Sets the partition to be empty at given iterator, and sets the /// end iterator. partition_iterator(AllocaSlices::iterator SI, AllocaSlices::iterator SE) : P(SI), SE(SE) { @@ -464,7 +460,7 @@ class AllocaSlices::partition_iterator advance(); } - /// \brief Advance the iterator to the next partition. + /// Advance the iterator to the next partition. /// /// Requires that the iterator not be at the end of the slices. void advance() { @@ -619,7 +615,7 @@ public: Partition &operator*() { return P; } }; -/// \brief A forward range over the partitions of the alloca's slices. +/// A forward range over the partitions of the alloca's slices. /// /// This accesses an iterator range over the partitions of the alloca's /// slices. 
It computes these partitions on the fly based on the overlapping @@ -643,7 +639,7 @@ static Value *foldSelectInst(SelectInst &SI) { return nullptr; } -/// \brief A helper that folds a PHI node or a select. +/// A helper that folds a PHI node or a select. static Value *foldPHINodeOrSelectInst(Instruction &I) { if (PHINode *PN = dyn_cast<PHINode>(&I)) { // If PN merges together the same value, return that value. @@ -652,7 +648,7 @@ static Value *foldPHINodeOrSelectInst(Instruction &I) { return foldSelectInst(cast<SelectInst>(I)); } -/// \brief Builder for the alloca slices. +/// Builder for the alloca slices. /// /// This class builds a set of alloca slices by recursively visiting the uses /// of an alloca and making a slice for each load and store at each offset. @@ -668,7 +664,7 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> { SmallDenseMap<Instruction *, unsigned> MemTransferSliceMap; SmallDenseMap<Instruction *, uint64_t> PHIOrSelectSizes; - /// \brief Set to de-duplicate dead instructions found in the use walk. + /// Set to de-duplicate dead instructions found in the use walk. SmallPtrSet<Instruction *, 4> VisitedDeadInsts; public: @@ -687,11 +683,12 @@ private: // Completely skip uses which have a zero size or start either before or // past the end of the allocation. if (Size == 0 || Offset.uge(AllocSize)) { - DEBUG(dbgs() << "WARNING: Ignoring " << Size << " byte use @" << Offset - << " which has zero size or starts outside of the " - << AllocSize << " byte alloca:\n" - << " alloca: " << AS.AI << "\n" - << " use: " << I << "\n"); + LLVM_DEBUG(dbgs() << "WARNING: Ignoring " << Size << " byte use @" + << Offset + << " which has zero size or starts outside of the " + << AllocSize << " byte alloca:\n" + << " alloca: " << AS.AI << "\n" + << " use: " << I << "\n"); return markAsDead(I); } @@ -706,10 +703,11 @@ private: // them, and so have to record at least the information here. assert(AllocSize >= BeginOffset); // Established above. if (Size > AllocSize - BeginOffset) { - DEBUG(dbgs() << "WARNING: Clamping a " << Size << " byte use @" << Offset - << " to remain within the " << AllocSize << " byte alloca:\n" - << " alloca: " << AS.AI << "\n" - << " use: " << I << "\n"); + LLVM_DEBUG(dbgs() << "WARNING: Clamping a " << Size << " byte use @" + << Offset << " to remain within the " << AllocSize + << " byte alloca:\n" + << " alloca: " << AS.AI << "\n" + << " use: " << I << "\n"); EndOffset = AllocSize; } @@ -802,18 +800,18 @@ private: uint64_t Size = DL.getTypeStoreSize(ValOp->getType()); // If this memory access can be shown to *statically* extend outside the - // bounds of of the allocation, it's behavior is undefined, so simply + // bounds of the allocation, its behavior is undefined, so simply // ignore it. Note that this is more strict than the generic clamping // behavior of insertUse. We also try to handle cases which might run the // risk of overflow. // FIXME: We should instead consider the pointer to have escaped if this // function is being instrumented for addressing bugs or race conditions. 
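The store check that follows deliberately avoids the naive "Offset + Size > AllocSize" form, because that addition can wrap in 64 bits and wrongly accept a huge access; checking Size first makes the subtraction safe. The same overflow-safe test in isolation (a plain-uint64_t sketch, not the APInt-based original):

#include <cstdint>

// Overflow-safe "does [Offset, Offset+Size) stay inside AllocSize bytes?".
// Checking Size <= AllocSize first guarantees AllocSize - Size cannot wrap.
bool fitsInAlloca(uint64_t Offset, uint64_t Size, uint64_t AllocSize) {
  return Size <= AllocSize && Offset <= AllocSize - Size;
}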
if (Size > AllocSize || Offset.ugt(AllocSize - Size)) { - DEBUG(dbgs() << "WARNING: Ignoring " << Size << " byte store @" << Offset - << " which extends past the end of the " << AllocSize - << " byte alloca:\n" - << " alloca: " << AS.AI << "\n" - << " use: " << SI << "\n"); + LLVM_DEBUG(dbgs() << "WARNING: Ignoring " << Size << " byte store @" + << Offset << " which extends past the end of the " + << AllocSize << " byte alloca:\n" + << " alloca: " << AS.AI << "\n" + << " use: " << SI << "\n"); return markAsDead(SI); } @@ -1027,7 +1025,7 @@ private: void visitSelectInst(SelectInst &SI) { visitPHINodeOrSelectInst(SI); } - /// \brief Disable SROA entirely if there are unhandled users of the alloca. + /// Disable SROA entirely if there are unhandled users of the alloca. void visitInstruction(Instruction &I) { PI.setAborted(&I); } }; @@ -1062,7 +1060,7 @@ AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI) // Sort the uses. This arranges for the offsets to be in ascending order, // and the sizes to be in descending order. - std::sort(Slices.begin(), Slices.end()); + llvm::sort(Slices.begin(), Slices.end()); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -1240,7 +1238,7 @@ static bool isSafePHIToSpeculate(PHINode &PN) { } static void speculatePHINodeLoads(PHINode &PN) { - DEBUG(dbgs() << " original: " << PN << "\n"); + LLVM_DEBUG(dbgs() << " original: " << PN << "\n"); Type *LoadTy = cast<PointerType>(PN.getType())->getElementType(); IRBuilderTy PHIBuilder(&PN); @@ -1263,10 +1261,21 @@ static void speculatePHINodeLoads(PHINode &PN) { } // Inject loads into all of the pred blocks. + DenseMap<BasicBlock*, Value*> InjectedLoads; for (unsigned Idx = 0, Num = PN.getNumIncomingValues(); Idx != Num; ++Idx) { BasicBlock *Pred = PN.getIncomingBlock(Idx); - TerminatorInst *TI = Pred->getTerminator(); Value *InVal = PN.getIncomingValue(Idx); + + // A PHI node is allowed to have multiple (duplicated) entries for the same + // basic block, as long as the value is the same. So if we already injected + // a load in the predecessor, then we should reuse the same load for all + // duplicated entries. + if (Value* V = InjectedLoads.lookup(Pred)) { + NewPN->addIncoming(V, Pred); + continue; + } + + TerminatorInst *TI = Pred->getTerminator(); IRBuilderTy PredBuilder(TI); LoadInst *Load = PredBuilder.CreateLoad( @@ -1276,9 +1285,10 @@ static void speculatePHINodeLoads(PHINode &PN) { if (AATags) Load->setAAMetadata(AATags); NewPN->addIncoming(Load, Pred); + InjectedLoads[Pred] = Load; } - DEBUG(dbgs() << " speculated to: " << *NewPN << "\n"); + LLVM_DEBUG(dbgs() << " speculated to: " << *NewPN << "\n"); PN.eraseFromParent(); } @@ -1318,7 +1328,7 @@ static bool isSafeSelectToSpeculate(SelectInst &SI) { } static void speculateSelectInstLoads(SelectInst &SI) { - DEBUG(dbgs() << " original: " << SI << "\n"); + LLVM_DEBUG(dbgs() << " original: " << SI << "\n"); IRBuilderTy IRB(&SI); Value *TV = SI.getTrueValue(); @@ -1349,14 +1359,14 @@ static void speculateSelectInstLoads(SelectInst &SI) { Value *V = IRB.CreateSelect(SI.getCondition(), TL, FL, LI->getName() + ".sroa.speculated"); - DEBUG(dbgs() << " speculated to: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " speculated to: " << *V << "\n"); LI->replaceAllUsesWith(V); LI->eraseFromParent(); } SI.eraseFromParent(); } -/// \brief Build a GEP out of a base pointer and indices. +/// Build a GEP out of a base pointer and indices. 
/// /// This will return the BasePtr if that is valid, or build a new GEP /// instruction using the IRBuilder if GEP-ing is needed. @@ -1374,7 +1384,7 @@ static Value *buildGEP(IRBuilderTy &IRB, Value *BasePtr, NamePrefix + "sroa_idx"); } -/// \brief Get a natural GEP off of the BasePtr walking through Ty toward +/// Get a natural GEP off of the BasePtr walking through Ty toward /// TargetTy without changing the offset of the pointer. /// /// This routine assumes we've already established a properly offset GEP with @@ -1423,7 +1433,7 @@ static Value *getNaturalGEPWithType(IRBuilderTy &IRB, const DataLayout &DL, return buildGEP(IRB, BasePtr, Indices, NamePrefix); } -/// \brief Recursively compute indices for a natural GEP. +/// Recursively compute indices for a natural GEP. /// /// This is the recursive step for getNaturalGEPWithOffset that walks down the /// element types adding appropriate indices for the GEP. @@ -1491,7 +1501,7 @@ static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL, Indices, NamePrefix); } -/// \brief Get a natural GEP from a base pointer to a particular offset and +/// Get a natural GEP from a base pointer to a particular offset and /// resulting in a particular type. /// /// The goal is to produce a "natural" looking GEP that works with the existing @@ -1526,7 +1536,7 @@ static Value *getNaturalGEPWithOffset(IRBuilderTy &IRB, const DataLayout &DL, Indices, NamePrefix); } -/// \brief Compute an adjusted pointer from Ptr by Offset bytes where the +/// Compute an adjusted pointer from Ptr by Offset bytes where the /// resulting pointer has PointerTy. /// /// This tries very hard to compute a "natural" GEP which arrives at the offset @@ -1635,7 +1645,7 @@ static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr, return Ptr; } -/// \brief Compute the adjusted alignment for a load or store from an offset. +/// Compute the adjusted alignment for a load or store from an offset. static unsigned getAdjustedAlignment(Instruction *I, uint64_t Offset, const DataLayout &DL) { unsigned Alignment; @@ -1656,7 +1666,7 @@ static unsigned getAdjustedAlignment(Instruction *I, uint64_t Offset, return MinAlign(Alignment, Offset); } -/// \brief Test whether we can convert a value from the old to the new type. +/// Test whether we can convert a value from the old to the new type. /// /// This predicate should be used to guard calls to convertValue in order to /// ensure that we only try to convert viable values. The strategy is that we @@ -1707,7 +1717,7 @@ static bool canConvertValue(const DataLayout &DL, Type *OldTy, Type *NewTy) { return true; } -/// \brief Generic routine to convert an SSA value to a value of a different +/// Generic routine to convert an SSA value to a value of a different /// type. /// /// This will try various different casting techniques, such as bitcasts, @@ -1759,7 +1769,7 @@ static Value *convertValue(const DataLayout &DL, IRBuilderTy &IRB, Value *V, return IRB.CreateBitCast(V, NewTy); } -/// \brief Test whether the given slice use can be promoted to a vector. +/// Test whether the given slice use can be promoted to a vector. /// /// This function is called to test each entry in a partition which is slated /// for a single slice. @@ -1830,7 +1840,7 @@ static bool isVectorPromotionViableForSlice(Partition &P, const Slice &S, return true; } -/// \brief Test whether the given alloca partitioning and range of slices can be +/// Test whether the given alloca partitioning and range of slices can be /// promoted to a vector. 
/// /// This is a quick test to check whether we can rewrite a particular alloca @@ -1896,7 +1906,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { "All non-integer types eliminated!"); return RHSTy->getNumElements() < LHSTy->getNumElements(); }; - std::sort(CandidateTys.begin(), CandidateTys.end(), RankVectorTypes); + llvm::sort(CandidateTys.begin(), CandidateTys.end(), RankVectorTypes); CandidateTys.erase( std::unique(CandidateTys.begin(), CandidateTys.end(), RankVectorTypes), CandidateTys.end()); @@ -1943,7 +1953,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { return nullptr; } -/// \brief Test whether a slice of an alloca is valid for integer widening. +/// Test whether a slice of an alloca is valid for integer widening. /// /// This implements the necessary checking for the \c isIntegerWideningViable /// test below on a single slice of the alloca. @@ -1970,6 +1980,10 @@ static bool isIntegerWideningViableForSlice(const Slice &S, // We can't handle loads that extend past the allocated memory. if (DL.getTypeStoreSize(LI->getType()) > Size) return false; + // So far, AllocaSliceRewriter does not support widening split slice tails + // in rewriteIntegerLoad. + if (S.beginOffset() < AllocBeginOffset) + return false; // Note that we don't count vector loads or stores as whole-alloca // operations which enable integer widening because we would prefer to use // vector widening instead. @@ -1991,6 +2005,10 @@ static bool isIntegerWideningViableForSlice(const Slice &S, // We can't handle stores that extend past the allocated memory. if (DL.getTypeStoreSize(ValueTy) > Size) return false; + // So far, AllocaSliceRewriter does not support widening split slice tails + // in rewriteIntegerStore. + if (S.beginOffset() < AllocBeginOffset) + return false; // Note that we don't count vector loads or stores as whole-alloca // operations which enable integer widening because we would prefer to use // vector widening instead. @@ -2021,7 +2039,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S, return true; } -/// \brief Test whether the given alloca partition's integer operations can be +/// Test whether the given alloca partition's integer operations can be /// widened to promotable ones. 
/// /// This is a quick test to check whether we can rewrite the integer loads and @@ -2072,7 +2090,7 @@ static bool isIntegerWideningViable(Partition &P, Type *AllocaTy, static Value *extractInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *V, IntegerType *Ty, uint64_t Offset, const Twine &Name) { - DEBUG(dbgs() << " start: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " start: " << *V << "\n"); IntegerType *IntTy = cast<IntegerType>(V->getType()); assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) && "Element extends past full value"); @@ -2081,13 +2099,13 @@ static Value *extractInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *V, ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset); if (ShAmt) { V = IRB.CreateLShr(V, ShAmt, Name + ".shift"); - DEBUG(dbgs() << " shifted: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " shifted: " << *V << "\n"); } assert(Ty->getBitWidth() <= IntTy->getBitWidth() && "Cannot extract to a larger integer!"); if (Ty != IntTy) { V = IRB.CreateTrunc(V, Ty, Name + ".trunc"); - DEBUG(dbgs() << " trunced: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " trunced: " << *V << "\n"); } return V; } @@ -2098,10 +2116,10 @@ static Value *insertInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *Old, IntegerType *Ty = cast<IntegerType>(V->getType()); assert(Ty->getBitWidth() <= IntTy->getBitWidth() && "Cannot insert a larger integer!"); - DEBUG(dbgs() << " start: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " start: " << *V << "\n"); if (Ty != IntTy) { V = IRB.CreateZExt(V, IntTy, Name + ".ext"); - DEBUG(dbgs() << " extended: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " extended: " << *V << "\n"); } assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) && "Element store outside of alloca store"); @@ -2110,15 +2128,15 @@ static Value *insertInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *Old, ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset); if (ShAmt) { V = IRB.CreateShl(V, ShAmt, Name + ".shift"); - DEBUG(dbgs() << " shifted: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " shifted: " << *V << "\n"); } if (ShAmt || Ty->getBitWidth() < IntTy->getBitWidth()) { APInt Mask = ~Ty->getMask().zext(IntTy->getBitWidth()).shl(ShAmt); Old = IRB.CreateAnd(Old, Mask, Name + ".mask"); - DEBUG(dbgs() << " masked: " << *Old << "\n"); + LLVM_DEBUG(dbgs() << " masked: " << *Old << "\n"); V = IRB.CreateOr(Old, V, Name + ".insert"); - DEBUG(dbgs() << " inserted: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " inserted: " << *V << "\n"); } return V; } @@ -2135,7 +2153,7 @@ static Value *extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, if (NumElements == 1) { V = IRB.CreateExtractElement(V, IRB.getInt32(BeginIndex), Name + ".extract"); - DEBUG(dbgs() << " extract: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " extract: " << *V << "\n"); return V; } @@ -2145,7 +2163,7 @@ static Value *extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, Mask.push_back(IRB.getInt32(i)); V = IRB.CreateShuffleVector(V, UndefValue::get(V->getType()), ConstantVector::get(Mask), Name + ".extract"); - DEBUG(dbgs() << " shuffle: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " shuffle: " << *V << "\n"); return V; } @@ -2159,7 +2177,7 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V, // Single element to insert. 
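The extractInteger and insertInteger helpers above are plain shift-and-mask surgery: extract shifts down and truncates, insert masks out the target bits and ORs the new value in, with the shift amount derived from the DataLayout so big-endian targets land in the right byte. The same arithmetic on a bare uint64_t, as a sketch assuming Sh < 64 and W + Sh <= 64 (the insertVector hunk the diff is in resumes below):

#include <cstdint>

// Splice a W-bit value V into Old starting at bit position Sh: mask out the
// destination bits, then OR the shifted value in.
uint64_t insertBits(uint64_t Old, uint64_t V, unsigned W, unsigned Sh) {
  uint64_t Mask = (W == 64 ? ~0ULL : ((1ULL << W) - 1)) << Sh;
  return (Old & ~Mask) | ((V << Sh) & Mask);
}

// The matching extract: shift down, then truncate to W bits.
uint64_t extractBits(uint64_t V, unsigned W, unsigned Sh) {
  return (V >> Sh) & (W == 64 ? ~0ULL : ((1ULL << W) - 1));
}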
V = IRB.CreateInsertElement(Old, V, IRB.getInt32(BeginIndex), Name + ".insert"); - DEBUG(dbgs() << " insert: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " insert: " << *V << "\n"); return V; } @@ -2184,7 +2202,7 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V, Mask.push_back(UndefValue::get(IRB.getInt32Ty())); V = IRB.CreateShuffleVector(V, UndefValue::get(V->getType()), ConstantVector::get(Mask), Name + ".expand"); - DEBUG(dbgs() << " shuffle: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " shuffle: " << *V << "\n"); Mask.clear(); for (unsigned i = 0; i != VecTy->getNumElements(); ++i) @@ -2192,11 +2210,11 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V, V = IRB.CreateSelect(ConstantVector::get(Mask), V, Old, Name + "blend"); - DEBUG(dbgs() << " blend: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " blend: " << *V << "\n"); return V; } -/// \brief Visitor to rewrite instructions using p particular slice of an alloca +/// Visitor to rewrite instructions using a particular slice of an alloca /// to use a new alloca. /// /// Also implements the rewriting to vector-based accesses when the partition @@ -2295,9 +2313,9 @@ public: IsSplittable = I->isSplittable(); IsSplit = BeginOffset < NewAllocaBeginOffset || EndOffset > NewAllocaEndOffset; - DEBUG(dbgs() << " rewriting " << (IsSplit ? "split " : "")); - DEBUG(AS.printSlice(dbgs(), I, "")); - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << " rewriting " << (IsSplit ? "split " : "")); + LLVM_DEBUG(AS.printSlice(dbgs(), I, "")); + LLVM_DEBUG(dbgs() << "\n"); // Compute the intersecting offset range. assert(BeginOffset < NewAllocaEndOffset); @@ -2327,7 +2345,7 @@ private: // Every instruction which can end up as a user must have a rewrite rule. bool visitInstruction(Instruction &I) { - DEBUG(dbgs() << " !!!! Cannot rewrite: " << I << "\n"); + LLVM_DEBUG(dbgs() << " !!!! Cannot rewrite: " << I << "\n"); llvm_unreachable("No rewrite rule for this instruction!"); } @@ -2369,7 +2387,7 @@ private: ); } - /// \brief Compute suitable alignment to access this slice of the *new* + /// Compute suitable alignment to access this slice of the *new* /// alloca. /// /// You can optionally pass a type to this routine and if that type's ABI @@ -2431,10 +2449,13 @@ private: } bool visitLoadInst(LoadInst &LI) { - DEBUG(dbgs() << " original: " << LI << "\n"); + LLVM_DEBUG(dbgs() << " original: " << LI << "\n"); Value *OldOp = LI.getOperand(0); assert(OldOp == OldPtr); + AAMDNodes AATags; + LI.getAAMetadata(AATags); + unsigned AS = LI.getPointerAddressSpace(); Type *TargetTy = IsSplit ? 
Type::getIntNTy(LI.getContext(), SliceSize * 8) @@ -2453,6 +2474,8 @@ private: TargetTy->isIntegerTy()))) { LoadInst *NewLI = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), LI.isVolatile(), LI.getName()); + if (AATags) + NewLI->setAAMetadata(AATags); if (LI.isVolatile()) NewLI->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); @@ -2488,6 +2511,8 @@ private: LoadInst *NewLI = IRB.CreateAlignedLoad(getNewAllocaSlicePtr(IRB, LTy), getSliceAlign(TargetTy), LI.isVolatile(), LI.getName()); + if (AATags) + NewLI->setAAMetadata(AATags); if (LI.isVolatile()) NewLI->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); @@ -2524,11 +2549,12 @@ private: Pass.DeadInsts.insert(&LI); deleteIfTriviallyDead(OldOp); - DEBUG(dbgs() << " to: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " to: " << *V << "\n"); return !LI.isVolatile() && !IsPtrAdjusted; } - bool rewriteVectorizedStoreInst(Value *V, StoreInst &SI, Value *OldOp) { + bool rewriteVectorizedStoreInst(Value *V, StoreInst &SI, Value *OldOp, + AAMDNodes AATags) { if (V->getType() != VecTy) { unsigned BeginIndex = getIndex(NewBeginOffset); unsigned EndIndex = getIndex(NewEndOffset); @@ -2546,14 +2572,15 @@ private: V = insertVector(IRB, Old, V, BeginIndex, "vec"); } StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment()); + if (AATags) + Store->setAAMetadata(AATags); Pass.DeadInsts.insert(&SI); - (void)Store; - DEBUG(dbgs() << " to: " << *Store << "\n"); + LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); return true; } - bool rewriteIntegerStore(Value *V, StoreInst &SI) { + bool rewriteIntegerStore(Value *V, StoreInst &SI, AAMDNodes AATags) { assert(IntTy && "We cannot extract an integer from the alloca"); assert(!SI.isVolatile()); if (DL.getTypeSizeInBits(V->getType()) != IntTy->getBitWidth()) { @@ -2567,16 +2594,21 @@ private: V = convertValue(DL, IRB, V, NewAllocaTy); StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment()); Store->copyMetadata(SI, LLVMContext::MD_mem_parallel_loop_access); + if (AATags) + Store->setAAMetadata(AATags); Pass.DeadInsts.insert(&SI); - DEBUG(dbgs() << " to: " << *Store << "\n"); + LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); return true; } bool visitStoreInst(StoreInst &SI) { - DEBUG(dbgs() << " original: " << SI << "\n"); + LLVM_DEBUG(dbgs() << " original: " << SI << "\n"); Value *OldOp = SI.getOperand(1); assert(OldOp == OldPtr); + AAMDNodes AATags; + SI.getAAMetadata(AATags); + Value *V = SI.getValueOperand(); // Strip all inbounds GEPs and pointer casts to try to dig out any root @@ -2598,9 +2630,9 @@ private: } if (VecTy) - return rewriteVectorizedStoreInst(V, SI, OldOp); + return rewriteVectorizedStoreInst(V, SI, OldOp, AATags); if (IntTy && V->getType()->isIntegerTy()) - return rewriteIntegerStore(V, SI); + return rewriteIntegerStore(V, SI, AATags); const bool IsStorePastEnd = DL.getTypeStoreSize(V->getType()) > SliceSize; StoreInst *NewSI; @@ -2631,16 +2663,18 @@ private: SI.isVolatile()); } NewSI->copyMetadata(SI, LLVMContext::MD_mem_parallel_loop_access); + if (AATags) + NewSI->setAAMetadata(AATags); if (SI.isVolatile()) NewSI->setAtomic(SI.getOrdering(), SI.getSyncScopeID()); Pass.DeadInsts.insert(&SI); deleteIfTriviallyDead(OldOp); - DEBUG(dbgs() << " to: " << *NewSI << "\n"); + LLVM_DEBUG(dbgs() << " to: " << *NewSI << "\n"); return NewSI->getPointerOperand() == &NewAI && !SI.isVolatile(); } - /// \brief Compute an integer value from splatting an i8 across the given + /// Compute an integer value from splatting an i8 across the given /// number of bytes. 
/// /// Note that this routine assumes an i8 is a byte. If that isn't true, don't @@ -2667,25 +2701,27 @@ private: return V; } - /// \brief Compute a vector splat for a given element value. + /// Compute a vector splat for a given element value. Value *getVectorSplat(Value *V, unsigned NumElements) { V = IRB.CreateVectorSplat(NumElements, V, "vsplat"); - DEBUG(dbgs() << " splat: " << *V << "\n"); + LLVM_DEBUG(dbgs() << " splat: " << *V << "\n"); return V; } bool visitMemSetInst(MemSetInst &II) { - DEBUG(dbgs() << " original: " << II << "\n"); + LLVM_DEBUG(dbgs() << " original: " << II << "\n"); assert(II.getRawDest() == OldPtr); + AAMDNodes AATags; + II.getAAMetadata(AATags); + // If the memset has a variable size, it cannot be split, just adjust the // pointer to the new alloca. if (!isa<Constant>(II.getLength())) { assert(!IsSplit); assert(NewBeginOffset == BeginOffset); II.setDest(getNewAllocaSlicePtr(IRB, OldPtr->getType())); - Type *CstTy = II.getAlignmentCst()->getType(); - II.setAlignment(ConstantInt::get(CstTy, getSliceAlign())); + II.setDestAlignment(getSliceAlign()); deleteIfTriviallyDead(OldPtr); return false; @@ -2710,8 +2746,9 @@ private: CallInst *New = IRB.CreateMemSet( getNewAllocaSlicePtr(IRB, OldPtr->getType()), II.getValue(), Size, getSliceAlign(), II.isVolatile()); - (void)New; - DEBUG(dbgs() << " to: " << *New << "\n"); + if (AATags) + New->setAAMetadata(AATags); + LLVM_DEBUG(dbgs() << " to: " << *New << "\n"); return false; } @@ -2773,10 +2810,11 @@ private: V = convertValue(DL, IRB, V, AllocaTy); } - Value *New = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment(), - II.isVolatile()); - (void)New; - DEBUG(dbgs() << " to: " << *New << "\n"); + StoreInst *New = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment(), + II.isVolatile()); + if (AATags) + New->setAAMetadata(AATags); + LLVM_DEBUG(dbgs() << " to: " << *New << "\n"); return !II.isVolatile(); } @@ -2784,7 +2822,10 @@ private: // Rewriting of memory transfer instructions can be a bit tricky. We break // them into two categories: split intrinsics and unsplit intrinsics. - DEBUG(dbgs() << " original: " << II << "\n"); + LLVM_DEBUG(dbgs() << " original: " << II << "\n"); + + AAMDNodes AATags; + II.getAAMetadata(AATags); bool IsDest = &II.getRawDestUse() == OldUse; assert((IsDest && II.getRawDest() == OldPtr) || @@ -2801,18 +2842,16 @@ private: // update both source and dest of a single call. if (!IsSplittable) { Value *AdjustedPtr = getNewAllocaSlicePtr(IRB, OldPtr->getType()); - if (IsDest) + if (IsDest) { II.setDest(AdjustedPtr); - else + II.setDestAlignment(SliceAlign); + } + else { II.setSource(AdjustedPtr); - - if (II.getAlignment() > SliceAlign) { - Type *CstTy = II.getAlignmentCst()->getType(); - II.setAlignment( - ConstantInt::get(CstTy, MinAlign(II.getAlignment(), SliceAlign))); + II.setSourceAlignment(SliceAlign); } - DEBUG(dbgs() << " to: " << II << "\n"); + LLVM_DEBUG(dbgs() << " to: " << II << "\n"); deleteIfTriviallyDead(OldPtr); return false; } @@ -2862,8 +2901,10 @@ private: // Compute the relative offset for the other pointer within the transfer. unsigned IntPtrWidth = DL.getPointerSizeInBits(OtherAS); APInt OtherOffset(IntPtrWidth, NewBeginOffset - BeginOffset); - unsigned OtherAlign = MinAlign(II.getAlignment() ? II.getAlignment() : 1, - OtherOffset.zextOrTrunc(64).getZExtValue()); + unsigned OtherAlign = + IsDest ? II.getSourceAlignment() : II.getDestAlignment(); + OtherAlign = MinAlign(OtherAlign ? 
OtherAlign : 1, + OtherOffset.zextOrTrunc(64).getZExtValue()); if (EmitMemCpy) { // Compute the other pointer, folding as much as possible to produce @@ -2875,11 +2916,25 @@ private: Type *SizeTy = II.getLength()->getType(); Constant *Size = ConstantInt::get(SizeTy, NewEndOffset - NewBeginOffset); - CallInst *New = IRB.CreateMemCpy( - IsDest ? OurPtr : OtherPtr, IsDest ? OtherPtr : OurPtr, Size, - MinAlign(SliceAlign, OtherAlign), II.isVolatile()); - (void)New; - DEBUG(dbgs() << " to: " << *New << "\n"); + Value *DestPtr, *SrcPtr; + unsigned DestAlign, SrcAlign; + // Note: IsDest is true iff we're copying into the new alloca slice + if (IsDest) { + DestPtr = OurPtr; + DestAlign = SliceAlign; + SrcPtr = OtherPtr; + SrcAlign = OtherAlign; + } else { + DestPtr = OtherPtr; + DestAlign = OtherAlign; + SrcPtr = OurPtr; + SrcAlign = SliceAlign; + } + CallInst *New = IRB.CreateMemCpy(DestPtr, DestAlign, SrcPtr, SrcAlign, + Size, II.isVolatile()); + if (AATags) + New->setAAMetadata(AATags); + LLVM_DEBUG(dbgs() << " to: " << *New << "\n"); return false; } @@ -2927,8 +2982,11 @@ private: uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; Src = extractInteger(DL, IRB, Src, SubIntTy, Offset, "extract"); } else { - Src = - IRB.CreateAlignedLoad(SrcPtr, SrcAlign, II.isVolatile(), "copyload"); + LoadInst *Load = IRB.CreateAlignedLoad(SrcPtr, SrcAlign, II.isVolatile(), + "copyload"); + if (AATags) + Load->setAAMetadata(AATags); + Src = Load; } if (VecTy && !IsWholeAlloca && IsDest) { @@ -2946,15 +3004,16 @@ private: StoreInst *Store = cast<StoreInst>( IRB.CreateAlignedStore(Src, DstPtr, DstAlign, II.isVolatile())); - (void)Store; - DEBUG(dbgs() << " to: " << *Store << "\n"); + if (AATags) + Store->setAAMetadata(AATags); + LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); return !II.isVolatile(); } bool visitIntrinsicInst(IntrinsicInst &II) { assert(II.getIntrinsicID() == Intrinsic::lifetime_start || II.getIntrinsicID() == Intrinsic::lifetime_end); - DEBUG(dbgs() << " original: " << II << "\n"); + LLVM_DEBUG(dbgs() << " original: " << II << "\n"); assert(II.getArgOperand(1) == OldPtr); // Record this instruction for deletion. @@ -2982,13 +3041,13 @@ private: New = IRB.CreateLifetimeEnd(Ptr, Size); (void)New; - DEBUG(dbgs() << " to: " << *New << "\n"); + LLVM_DEBUG(dbgs() << " to: " << *New << "\n"); return true; } bool visitPHINode(PHINode &PN) { - DEBUG(dbgs() << " original: " << PN << "\n"); + LLVM_DEBUG(dbgs() << " original: " << PN << "\n"); assert(BeginOffset >= NewAllocaBeginOffset && "PHIs are unsplittable"); assert(EndOffset <= NewAllocaEndOffset && "PHIs are unsplittable"); @@ -3007,7 +3066,7 @@ private: // Replace the operands which were using the old pointer. std::replace(PN.op_begin(), PN.op_end(), cast<Value>(OldPtr), NewPtr); - DEBUG(dbgs() << " to: " << PN << "\n"); + LLVM_DEBUG(dbgs() << " to: " << PN << "\n"); deleteIfTriviallyDead(OldPtr); // PHIs can't be promoted on their own, but often can be speculated. 
We @@ -3018,7 +3077,7 @@ private: } bool visitSelectInst(SelectInst &SI) { - DEBUG(dbgs() << " original: " << SI << "\n"); + LLVM_DEBUG(dbgs() << " original: " << SI << "\n"); assert((SI.getTrueValue() == OldPtr || SI.getFalseValue() == OldPtr) && "Pointer isn't an operand!"); assert(BeginOffset >= NewAllocaBeginOffset && "Selects are unsplittable"); @@ -3031,7 +3090,7 @@ private: if (SI.getOperand(2) == OldPtr) SI.setOperand(2, NewPtr); - DEBUG(dbgs() << " to: " << SI << "\n"); + LLVM_DEBUG(dbgs() << " to: " << SI << "\n"); deleteIfTriviallyDead(OldPtr); // Selects can't be promoted on their own, but often can be speculated. We @@ -3044,7 +3103,7 @@ private: namespace { -/// \brief Visitor to rewrite aggregate loads and stores as scalar. +/// Visitor to rewrite aggregate loads and stores as scalar. /// /// This pass aggressively rewrites all aggregate loads and stores on /// a particular pointer (or any pointer derived from it which we can identify) @@ -3067,7 +3126,7 @@ public: /// Rewrite loads and stores through a pointer and all pointers derived from /// it. bool rewrite(Instruction &I) { - DEBUG(dbgs() << " Rewriting FCA loads and stores...\n"); + LLVM_DEBUG(dbgs() << " Rewriting FCA loads and stores...\n"); enqueueUsers(I); bool Changed = false; while (!Queue.empty()) { @@ -3089,7 +3148,7 @@ private: // Conservative default is to not rewrite anything. bool visitInstruction(Instruction &I) { return false; } - /// \brief Generic recursive split emission class. + /// Generic recursive split emission class. template <typename Derived> class OpSplitter { protected: /// The builder used to form new instructions. @@ -3113,7 +3172,7 @@ private: : IRB(InsertionPoint), GEPIndices(1, IRB.getInt32(0)), Ptr(Ptr) {} public: - /// \brief Generic recursive split emission routine. + /// Generic recursive split emission routine. /// /// This method recursively splits an aggregate op (load or store) into /// scalar or vector ops. It splits recursively until it hits a single value @@ -3165,8 +3224,10 @@ private: }; struct LoadOpSplitter : public OpSplitter<LoadOpSplitter> { - LoadOpSplitter(Instruction *InsertionPoint, Value *Ptr) - : OpSplitter<LoadOpSplitter>(InsertionPoint, Ptr) {} + AAMDNodes AATags; + + LoadOpSplitter(Instruction *InsertionPoint, Value *Ptr, AAMDNodes AATags) + : OpSplitter<LoadOpSplitter>(InsertionPoint, Ptr), AATags(AATags) {} /// Emit a leaf load of a single value. This is called at the leaves of the /// recursive emission to actually load values. @@ -3175,9 +3236,11 @@ private: // Load the single value and insert it using the indices. Value *GEP = IRB.CreateInBoundsGEP(nullptr, Ptr, GEPIndices, Name + ".gep"); - Value *Load = IRB.CreateLoad(GEP, Name + ".load"); + LoadInst *Load = IRB.CreateLoad(GEP, Name + ".load"); + if (AATags) + Load->setAAMetadata(AATags); Agg = IRB.CreateInsertValue(Agg, Load, Indices, Name + ".insert"); - DEBUG(dbgs() << " to: " << *Load << "\n"); + LLVM_DEBUG(dbgs() << " to: " << *Load << "\n"); } }; @@ -3187,8 +3250,10 @@ private: return false; // We have an aggregate being loaded, split it apart. 
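For readers tracing the OpSplitter machinery here, the recursion is easier to see in isolation. Below is a minimal standalone sketch (toy type representation and hypothetical names, not this commit's code or the LLVM API) of how one aggregate load is decomposed into per-leaf scalar accesses identified by index paths, which the real LoadOpSplitter then reassembles with insertvalue:

    // Toy model of the emitSplitOps recursion, assuming aggregates are just
    // element lists (the real code also handles vector and array leaves).
    #include <cstdio>
    #include <vector>

    struct Ty {
      const char *Name;             // leaf type name, or nullptr for aggregates
      std::vector<Ty *> Elements;   // sub-element types for structs/arrays
    };

    static void emitSplitOps(const Ty *T, std::vector<unsigned> &Indices) {
      if (T->Elements.empty()) {    // single value: emit one scalar "load"
        std::printf("load %s at index path", T->Name);
        for (unsigned I : Indices)
          std::printf(" %u", I);
        std::printf("\n");
        return;
      }
      for (unsigned I = 0; I != T->Elements.size(); ++I) {
        Indices.push_back(I);       // descend into element I
        emitSplitOps(T->Elements[I], Indices);
        Indices.pop_back();
      }
    }

    int main() {
      Ty I32{"i32", {}}, F32{"float", {}};
      Ty Pair{nullptr, {&I32, &F32}};  // { i32, float }
      Ty Agg{nullptr, {&Pair, &I32}};  // { { i32, float }, i32 }
      std::vector<unsigned> Indices;
      emitSplitOps(&Agg, Indices);     // one leaf load per scalar member
    }

The same index-path walk drives StoreOpSplitter, with extractvalue feeding leaf stores instead.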
- DEBUG(dbgs() << " original: " << LI << "\n"); - LoadOpSplitter Splitter(&LI, *U); + LLVM_DEBUG(dbgs() << " original: " << LI << "\n"); + AAMDNodes AATags; + LI.getAAMetadata(AATags); + LoadOpSplitter Splitter(&LI, *U, AATags); Value *V = UndefValue::get(LI.getType()); Splitter.emitSplitOps(LI.getType(), V, LI.getName() + ".fca"); LI.replaceAllUsesWith(V); @@ -3197,8 +3262,9 @@ private: } struct StoreOpSplitter : public OpSplitter<StoreOpSplitter> { - StoreOpSplitter(Instruction *InsertionPoint, Value *Ptr) - : OpSplitter<StoreOpSplitter>(InsertionPoint, Ptr) {} + StoreOpSplitter(Instruction *InsertionPoint, Value *Ptr, AAMDNodes AATags) + : OpSplitter<StoreOpSplitter>(InsertionPoint, Ptr), AATags(AATags) {} + AAMDNodes AATags; /// Emit a leaf store of a single value. This is called at the leaves of the /// recursive emission to actually produce stores. @@ -3212,9 +3278,10 @@ private: IRB.CreateExtractValue(Agg, Indices, Name + ".extract"); Value *InBoundsGEP = IRB.CreateInBoundsGEP(nullptr, Ptr, GEPIndices, Name + ".gep"); - Value *Store = IRB.CreateStore(ExtractValue, InBoundsGEP); - (void)Store; - DEBUG(dbgs() << " to: " << *Store << "\n"); + StoreInst *Store = IRB.CreateStore(ExtractValue, InBoundsGEP); + if (AATags) + Store->setAAMetadata(AATags); + LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); } }; @@ -3226,8 +3293,10 @@ private: return false; // We have an aggregate being stored, split it apart. - DEBUG(dbgs() << " original: " << SI << "\n"); - StoreOpSplitter Splitter(&SI, *U); + LLVM_DEBUG(dbgs() << " original: " << SI << "\n"); + AAMDNodes AATags; + SI.getAAMetadata(AATags); + StoreOpSplitter Splitter(&SI, *U, AATags); Splitter.emitSplitOps(V->getType(), V, V->getName() + ".fca"); SI.eraseFromParent(); return true; @@ -3256,7 +3325,7 @@ private: } // end anonymous namespace -/// \brief Strip aggregate type wrapping. +/// Strip aggregate type wrapping. /// /// This removes no-op aggregate types wrapping an underlying type. It will /// strip as many layers of types as it can without changing either the type @@ -3286,7 +3355,7 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) { return stripAggregateTypeWrapping(DL, InnerTy); } -/// \brief Try to find a partition of the aggregate type passed in for a given +/// Try to find a partition of the aggregate type passed in for a given /// offset and size. /// /// This recurses through the aggregate type and tries to compute a subtype @@ -3392,7 +3461,7 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, return SubTy; } -/// \brief Pre-split loads and stores to simplify rewriting. +/// Pre-split loads and stores to simplify rewriting. /// /// We want to break up the splittable load+store pairs as much as /// possible. This is important to do as a preprocessing step, as once we @@ -3423,7 +3492,7 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, /// /// \returns true if any changes are made. bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { - DEBUG(dbgs() << "Pre-splitting loads and stores\n"); + LLVM_DEBUG(dbgs() << "Pre-splitting loads and stores\n"); // Track the loads and stores which are candidates for pre-splitting here, in // the order they first appear during the partition scan. These give stable @@ -3455,7 +3524,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { // maybe it would make it more principled? 
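The pre-splitting bookkeeping above records, per candidate instruction, the relative byte offsets at which it must be cut. A small self-contained sketch (hypothetical names, simplified from the pass) of how such a sorted split list turns one access into contiguous parts:

    #include <cstdint>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Turn an access of TotalSize bytes plus its sorted, relative split
    // offsets into contiguous (offset, size) parts, mirroring how a split
    // load or store is emitted piece by piece.
    static std::vector<std::pair<uint64_t, uint64_t>>
    splitParts(uint64_t TotalSize, const std::vector<uint64_t> &Splits) {
      std::vector<std::pair<uint64_t, uint64_t>> Parts;
      uint64_t PartOffset = 0;
      for (uint64_t Split : Splits) {
        Parts.push_back({PartOffset, Split - PartOffset});
        PartOffset = Split;
      }
      Parts.push_back({PartOffset, TotalSize - PartOffset}); // trailing part
      return Parts;
    }

    int main() {
      // A 16-byte access with splits at +4 and +8 becomes [0,4), [4,8), [8,16).
      for (const auto &P : splitParts(16, {4, 8}))
        std::printf("part at +%llu, %llu bytes\n",
                    (unsigned long long)P.first, (unsigned long long)P.second);
    }

A load and store only pre-split cleanly together when their two split lists agree, which is exactly the mismatch check performed below.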
SmallPtrSet<LoadInst *, 8> UnsplittableLoads; - DEBUG(dbgs() << " Searching for candidate loads and stores\n"); + LLVM_DEBUG(dbgs() << " Searching for candidate loads and stores\n"); for (auto &P : AS.partitions()) { for (Slice &S : P) { Instruction *I = cast<Instruction>(S.getUse()->getUser()); @@ -3510,7 +3579,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { } // Record the initial split. - DEBUG(dbgs() << " Candidate: " << *I << "\n"); + LLVM_DEBUG(dbgs() << " Candidate: " << *I << "\n"); auto &Offsets = SplitOffsetsMap[I]; assert(Offsets.Splits.empty() && "Should not have splits the first time we see an instruction!"); @@ -3570,10 +3639,11 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { if (LoadOffsets.Splits == StoreOffsets.Splits) return false; - DEBUG(dbgs() - << " Mismatched splits for load and store:\n" - << " " << *LI << "\n" - << " " << *SI << "\n"); + LLVM_DEBUG( + dbgs() + << " Mismatched splits for load and store:\n" + << " " << *LI << "\n" + << " " << *SI << "\n"); // We've found a store and load that we need to split // with mismatched relative splits. Just give up on them @@ -3646,7 +3716,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { Instruction *BasePtr = cast<Instruction>(LI->getPointerOperand()); IRB.SetInsertPoint(LI); - DEBUG(dbgs() << " Splitting load: " << *LI << "\n"); + LLVM_DEBUG(dbgs() << " Splitting load: " << *LI << "\n"); uint64_t PartOffset = 0, PartSize = Offsets.Splits.front(); int Idx = 0, Size = Offsets.Splits.size(); @@ -3656,7 +3726,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { auto *PartPtrTy = PartTy->getPointerTo(AS); LoadInst *PLoad = IRB.CreateAlignedLoad( getAdjustedPtr(IRB, DL, BasePtr, - APInt(DL.getPointerSizeInBits(AS), PartOffset), + APInt(DL.getIndexSizeInBits(AS), PartOffset), PartPtrTy, BasePtr->getName() + "."), getAdjustedAlignment(LI, PartOffset, DL), /*IsVolatile*/ false, LI->getName()); @@ -3671,9 +3741,9 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize, &PLoad->getOperandUse(PLoad->getPointerOperandIndex()), /*IsSplittable*/ false)); - DEBUG(dbgs() << " new slice [" << NewSlices.back().beginOffset() - << ", " << NewSlices.back().endOffset() << "): " << *PLoad - << "\n"); + LLVM_DEBUG(dbgs() << " new slice [" << NewSlices.back().beginOffset() + << ", " << NewSlices.back().endOffset() + << "): " << *PLoad << "\n"); // See if we've handled all the splits. 
if (Idx >= Size) @@ -3693,14 +3763,15 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { StoreInst *SI = cast<StoreInst>(LU); if (!Stores.empty() && SplitOffsetsMap.count(SI)) { DeferredStores = true; - DEBUG(dbgs() << " Deferred splitting of store: " << *SI << "\n"); + LLVM_DEBUG(dbgs() << " Deferred splitting of store: " << *SI + << "\n"); continue; } Value *StoreBasePtr = SI->getPointerOperand(); IRB.SetInsertPoint(SI); - DEBUG(dbgs() << " Splitting store of load: " << *SI << "\n"); + LLVM_DEBUG(dbgs() << " Splitting store of load: " << *SI << "\n"); for (int Idx = 0, Size = SplitLoads.size(); Idx < Size; ++Idx) { LoadInst *PLoad = SplitLoads[Idx]; @@ -3712,11 +3783,11 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { StoreInst *PStore = IRB.CreateAlignedStore( PLoad, getAdjustedPtr(IRB, DL, StoreBasePtr, - APInt(DL.getPointerSizeInBits(AS), PartOffset), + APInt(DL.getIndexSizeInBits(AS), PartOffset), PartPtrTy, StoreBasePtr->getName() + "."), getAdjustedAlignment(SI, PartOffset, DL), /*IsVolatile*/ false); PStore->copyMetadata(*LI, LLVMContext::MD_mem_parallel_loop_access); - DEBUG(dbgs() << " +" << PartOffset << ":" << *PStore << "\n"); + LLVM_DEBUG(dbgs() << " +" << PartOffset << ":" << *PStore << "\n"); } // We want to immediately iterate on any allocas impacted by splitting @@ -3765,7 +3836,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { Value *LoadBasePtr = LI->getPointerOperand(); Instruction *StoreBasePtr = cast<Instruction>(SI->getPointerOperand()); - DEBUG(dbgs() << " Splitting store: " << *SI << "\n"); + LLVM_DEBUG(dbgs() << " Splitting store: " << *SI << "\n"); // Check whether we have an already split load. auto SplitLoadsMapI = SplitLoadsMap.find(LI); @@ -3775,7 +3846,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { assert(SplitLoads->size() == Offsets.Splits.size() + 1 && "Too few split loads for the number of splits in the store!"); } else { - DEBUG(dbgs() << " of load: " << *LI << "\n"); + LLVM_DEBUG(dbgs() << " of load: " << *LI << "\n"); } uint64_t PartOffset = 0, PartSize = Offsets.Splits.front(); @@ -3794,7 +3865,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { auto AS = LI->getPointerAddressSpace(); PLoad = IRB.CreateAlignedLoad( getAdjustedPtr(IRB, DL, LoadBasePtr, - APInt(DL.getPointerSizeInBits(AS), PartOffset), + APInt(DL.getIndexSizeInBits(AS), PartOffset), LoadPartPtrTy, LoadBasePtr->getName() + "."), getAdjustedAlignment(LI, PartOffset, DL), /*IsVolatile*/ false, LI->getName()); @@ -3806,7 +3877,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { StoreInst *PStore = IRB.CreateAlignedStore( PLoad, getAdjustedPtr(IRB, DL, StoreBasePtr, - APInt(DL.getPointerSizeInBits(AS), PartOffset), + APInt(DL.getIndexSizeInBits(AS), PartOffset), StorePartPtrTy, StoreBasePtr->getName() + "."), getAdjustedAlignment(SI, PartOffset, DL), /*IsVolatile*/ false); @@ -3815,11 +3886,11 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize, &PStore->getOperandUse(PStore->getPointerOperandIndex()), /*IsSplittable*/ false)); - DEBUG(dbgs() << " new slice [" << NewSlices.back().beginOffset() - << ", " << NewSlices.back().endOffset() << "): " << *PStore - << "\n"); + LLVM_DEBUG(dbgs() << " new slice [" << NewSlices.back().beginOffset() + << ", " << NewSlices.back().endOffset() + << "): " << *PStore << "\n"); if (!SplitLoads) { - DEBUG(dbgs() << " of split 
load: " << *PLoad << "\n"); + LLVM_DEBUG(dbgs() << " of split load: " << *PLoad << "\n"); } // See if we've finished all the splits. @@ -3874,10 +3945,10 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { // sequence. AS.insert(NewSlices); - DEBUG(dbgs() << " Pre-split slices:\n"); + LLVM_DEBUG(dbgs() << " Pre-split slices:\n"); #ifndef NDEBUG for (auto I = AS.begin(), E = AS.end(); I != E; ++I) - DEBUG(AS.print(dbgs(), I, " ")); + LLVM_DEBUG(AS.print(dbgs(), I, " ")); #endif // Finally, don't try to promote any allocas that new require re-splitting. @@ -3891,7 +3962,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { return true; } -/// \brief Rewrite an alloca partition's users. +/// Rewrite an alloca partition's users. /// /// This routine drives both of the rewriting goals of the SROA pass. It tries /// to rewrite uses of an alloca partition to be conducive for SSA value @@ -3934,10 +4005,10 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, // exact same type as the original, and with the same access offsets. In that // case, re-use the existing alloca, but still run through the rewriter to // perform phi and select speculation. + // P.beginOffset() can be non-zero even with the same type in a case with + // out-of-bounds access (e.g. @PR35657 function in SROA/basictest.ll). AllocaInst *NewAI; - if (SliceTy == AI.getAllocatedType()) { - assert(P.beginOffset() == 0 && - "Non-zero begin offset but same alloca type"); + if (SliceTy == AI.getAllocatedType() && P.beginOffset() == 0) { NewAI = &AI; // FIXME: We should be able to bail at this point with "nothing changed". // FIXME: We might want to defer PHI speculation until after here. @@ -3958,12 +4029,14 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, NewAI = new AllocaInst( SliceTy, AI.getType()->getAddressSpace(), nullptr, Alignment, AI.getName() + ".sroa." + Twine(P.begin() - AS.begin()), &AI); + // Copy the old AI debug location over to the new one. + NewAI->setDebugLoc(AI.getDebugLoc()); ++NumNewAllocas; } - DEBUG(dbgs() << "Rewriting alloca partition " - << "[" << P.beginOffset() << "," << P.endOffset() - << ") to: " << *NewAI << "\n"); + LLVM_DEBUG(dbgs() << "Rewriting alloca partition " + << "[" << P.beginOffset() << "," << P.endOffset() + << ") to: " << *NewAI << "\n"); // Track the high watermark on the worklist as it is only relevant for // promoted allocas. We will reset it to this point if the alloca is not in @@ -4040,7 +4113,7 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, return NewAI; } -/// \brief Walks the slices of an alloca and form partitions based on them, +/// Walks the slices of an alloca and form partitions based on them, /// rewriting each of their uses. bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { if (AS.begin() == AS.end()) @@ -4063,7 +4136,7 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { uint64_t AllocaSize = DL.getTypeAllocSize(AI.getAllocatedType()); const uint64_t MaxBitVectorSize = 1024; - if (SROASplitNonWholeAllocaSlices && AllocaSize <= MaxBitVectorSize) { + if (AllocaSize <= MaxBitVectorSize) { // If a byte boundary is included in any load or store, a slice starting or // ending at the boundary is not splittable. 
SmallBitVector SplittableOffset(AllocaSize + 1, true); @@ -4106,7 +4179,7 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { } if (!IsSorted) - std::sort(AS.begin(), AS.end()); + llvm::sort(AS.begin(), AS.end()); /// Describes the allocas introduced by rewritePartition in order to migrate /// the debug info. @@ -4201,7 +4274,7 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { return Changed; } -/// \brief Clobber a use with undef, deleting the used value if it becomes dead. +/// Clobber a use with undef, deleting the used value if it becomes dead. void SROA::clobberUse(Use &U) { Value *OldV = U; // Replace the use with an undef value. @@ -4216,13 +4289,13 @@ void SROA::clobberUse(Use &U) { } } -/// \brief Analyze an alloca for SROA. +/// Analyze an alloca for SROA. /// /// This analyzes the alloca to ensure we can reason about it, builds /// the slices of the alloca, and then hands it off to be split and /// rewritten as needed. bool SROA::runOnAlloca(AllocaInst &AI) { - DEBUG(dbgs() << "SROA alloca: " << AI << "\n"); + LLVM_DEBUG(dbgs() << "SROA alloca: " << AI << "\n"); ++NumAllocasAnalyzed; // Special case dead allocas, as they're trivial. @@ -4246,7 +4319,7 @@ bool SROA::runOnAlloca(AllocaInst &AI) { // Build the slices using a recursive instruction-visiting builder. AllocaSlices AS(DL, AI); - DEBUG(AS.print(dbgs())); + LLVM_DEBUG(AS.print(dbgs())); if (AS.isEscaped()) return Changed; @@ -4274,18 +4347,18 @@ bool SROA::runOnAlloca(AllocaInst &AI) { Changed |= splitAlloca(AI, AS); - DEBUG(dbgs() << " Speculating PHIs\n"); + LLVM_DEBUG(dbgs() << " Speculating PHIs\n"); while (!SpeculatablePHIs.empty()) speculatePHINodeLoads(*SpeculatablePHIs.pop_back_val()); - DEBUG(dbgs() << " Speculating Selects\n"); + LLVM_DEBUG(dbgs() << " Speculating Selects\n"); while (!SpeculatableSelects.empty()) speculateSelectInstLoads(*SpeculatableSelects.pop_back_val()); return Changed; } -/// \brief Delete the dead instructions accumulated in this run. +/// Delete the dead instructions accumulated in this run. /// /// Recursively deletes the dead instructions we've accumulated. This is done /// at the very end to maximize locality of the recursive delete and to @@ -4299,7 +4372,7 @@ bool SROA::deleteDeadInstructions( bool Changed = false; while (!DeadInsts.empty()) { Instruction *I = DeadInsts.pop_back_val(); - DEBUG(dbgs() << "Deleting dead instruction: " << *I << "\n"); + LLVM_DEBUG(dbgs() << "Deleting dead instruction: " << *I << "\n"); // If the instruction is an alloca, find the possible dbg.declare connected // to it, and remove it too. We must do this before calling RAUW or we will @@ -4327,7 +4400,7 @@ bool SROA::deleteDeadInstructions( return Changed; } -/// \brief Promote the allocas, using the best available technique. +/// Promote the allocas, using the best available technique. /// /// This attempts to promote whatever allocas have been identified as viable in /// the PromotableAllocas list. If that list is empty, there is nothing to do. 
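As context for the promotable-allocas list mentioned above: mem2reg can only take an alloca straight to SSA registers when its uses are simple enough. A rough standalone model of that gate follows (toy types, deliberately simplified; the real test is LLVM's isAllocaPromotable, which additionally reasons about GEPs, bitcasts, and lifetime intrinsics):

    #include <vector>

    enum class UserKind { Load, Store, Other };

    struct ToyUse {
      UserKind Kind;
      bool IsVolatile;
      bool PtrEscapes;  // address passed to a call, stored away, indexed into
    };

    // An alloca can be rewritten into SSA registers only when every use is a
    // simple, non-volatile load or store of the whole value and the address
    // itself never escapes.
    static bool isPromotable(const std::vector<ToyUse> &Uses) {
      for (const ToyUse &U : Uses)
        if (U.Kind == UserKind::Other || U.IsVolatile || U.PtrEscapes)
          return false;
      return true;
    }

    int main() {
      std::vector<ToyUse> Uses = {{UserKind::Store, false, false},
                                  {UserKind::Load, false, false}};
      return isPromotable(Uses) ? 0 : 1;  // promotable: hand off to mem2reg
    }

Everything SROA does before this point, splitting, pre-splitting, and rewriting slices, exists to turn allocas that fail such a check into ones that pass it.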
@@ -4338,7 +4411,7 @@ bool SROA::promoteAllocas(Function &F) { NumPromoted += PromotableAllocas.size(); - DEBUG(dbgs() << "Promoting allocas with mem2reg...\n"); + LLVM_DEBUG(dbgs() << "Promoting allocas with mem2reg...\n"); PromoteMemToReg(PromotableAllocas, *DT, AC); PromotableAllocas.clear(); return true; @@ -4346,7 +4419,7 @@ bool SROA::promoteAllocas(Function &F) { PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT, AssumptionCache &RunAC) { - DEBUG(dbgs() << "SROA function: " << F.getName() << "\n"); + LLVM_DEBUG(dbgs() << "SROA function: " << F.getName() << "\n"); C = &F.getContext(); DT = &RunDT; AC = &RunAC; diff --git a/lib/Transforms/Scalar/Scalar.cpp b/lib/Transforms/Scalar/Scalar.cpp index 3b99ddff2e06..526487d3477e 100644 --- a/lib/Transforms/Scalar/Scalar.cpp +++ b/lib/Transforms/Scalar/Scalar.cpp @@ -45,6 +45,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeScalarizerPass(Registry); initializeDSELegacyPassPass(Registry); initializeGuardWideningLegacyPassPass(Registry); + initializeLoopGuardWideningLegacyPassPass(Registry); initializeGVNLegacyPassPass(Registry); initializeNewGVNLegacyPassPass(Registry); initializeEarlyCSELegacyPassPass(Registry); @@ -52,9 +53,10 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeGVNHoistLegacyPassPass(Registry); initializeGVNSinkLegacyPassPass(Registry); initializeFlattenCFGPassPass(Registry); - initializeInductiveRangeCheckEliminationPass(Registry); + initializeIRCELegacyPassPass(Registry); initializeIndVarSimplifyLegacyPassPass(Registry); initializeInferAddressSpacesPass(Registry); + initializeInstSimplifyLegacyPassPass(Registry); initializeJumpThreadingPass(Registry); initializeLegacyLICMPassPass(Registry); initializeLegacyLoopSinkPassPass(Registry); @@ -68,6 +70,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeLoopStrengthReducePass(Registry); initializeLoopRerollPass(Registry); initializeLoopUnrollPass(Registry); + initializeLoopUnrollAndJamPass(Registry); initializeLoopUnswitchPass(Registry); initializeLoopVersioningLICMPass(Registry); initializeLoopIdiomRecognizeLegacyPassPass(Registry); @@ -83,7 +86,6 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeRegToMemPass(Registry); initializeRewriteStatepointsForGCLegacyPassPass(Registry); initializeSCCPLegacyPassPass(Registry); - initializeIPSCCPLegacyPassPass(Registry); initializeSROALegacyPassPass(Registry); initializeCFGSimplifyPassPass(Registry); initializeStructurizeCFGPass(Registry); @@ -104,6 +106,10 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializePostInlineEntryExitInstrumenterPass(Registry); } +void LLVMAddLoopSimplifyCFGPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopSimplifyCFGPass()); +} + void LLVMInitializeScalarOpts(LLVMPassRegistryRef R) { initializeScalarOpts(*unwrap(R)); } @@ -148,10 +154,6 @@ void LLVMAddIndVarSimplifyPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createIndVarSimplifyPass()); } -void LLVMAddInstructionCombiningPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createInstructionCombiningPass()); -} - void LLVMAddJumpThreadingPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createJumpThreadingPass()); } @@ -180,14 +182,14 @@ void LLVMAddLoopRerollPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createLoopRerollPass()); } -void LLVMAddLoopSimplifyCFGPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLoopSimplifyCFGPass()); -} - void LLVMAddLoopUnrollPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createLoopUnrollPass()); } +void 
LLVMAddLoopUnrollAndJamPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopUnrollAndJamPass()); +} + void LLVMAddLoopUnswitchPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createLoopUnswitchPass()); } @@ -200,14 +202,6 @@ void LLVMAddPartiallyInlineLibCallsPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createPartiallyInlineLibCallsPass()); } -void LLVMAddLowerSwitchPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLowerSwitchPass()); -} - -void LLVMAddPromoteMemoryToRegisterPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createPromoteMemoryToRegisterPass()); -} - void LLVMAddReassociatePass(LLVMPassManagerRef PM) { unwrap(PM)->add(createReassociatePass()); } diff --git a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index 8fa9ffb6d014..967f4a42a8fb 100644 --- a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -165,8 +165,8 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -190,7 +190,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> #include <string> @@ -213,7 +212,7 @@ static cl::opt<bool> namespace { -/// \brief A helper class for separating a constant offset from a GEP index. +/// A helper class for separating a constant offset from a GEP index. /// /// In real programs, a GEP index may be more complicated than a simple addition /// of something and a constant integer which can be trivially split. For @@ -340,16 +339,15 @@ private: const DominatorTree *DT; }; -/// \brief A pass that tries to split every GEP in the function into a variadic +/// A pass that tries to split every GEP in the function into a variadic /// base and a constant offset. It is a FunctionPass because searching for the /// constant offset may inspect other basic blocks. class SeparateConstOffsetFromGEP : public FunctionPass { public: static char ID; - SeparateConstOffsetFromGEP(const TargetMachine *TM = nullptr, - bool LowerGEP = false) - : FunctionPass(ID), TM(TM), LowerGEP(LowerGEP) { + SeparateConstOffsetFromGEP(bool LowerGEP = false) + : FunctionPass(ID), LowerGEP(LowerGEP) { initializeSeparateConstOffsetFromGEPPass(*PassRegistry::getPassRegistry()); } @@ -450,7 +448,6 @@ private: const DataLayout *DL = nullptr; DominatorTree *DT = nullptr; ScalarEvolution *SE; - const TargetMachine *TM; LoopInfo *LI; TargetLibraryInfo *TLI; @@ -480,10 +477,8 @@ INITIALIZE_PASS_END( "Split GEPs to a variadic base and a constant offset for better CSE", false, false) -FunctionPass * -llvm::createSeparateConstOffsetFromGEPPass(const TargetMachine *TM, - bool LowerGEP) { - return new SeparateConstOffsetFromGEP(TM, LowerGEP); +FunctionPass *llvm::createSeparateConstOffsetFromGEPPass(bool LowerGEP) { + return new SeparateConstOffsetFromGEP(LowerGEP); } bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended, @@ -502,6 +497,8 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended, Value *LHS = BO->getOperand(0), *RHS = BO->getOperand(1); // Do not trace into "or" unless it is equivalent to "add". 
If LHS and RHS // don't have common bits, (LHS | RHS) is equivalent to (LHS + RHS). + // FIXME: this does not appear to be covered by any tests + // (with x86/aarch64 backends at least) if (BO->getOpcode() == Instruction::Or && !haveNoCommonBitsSet(LHS, RHS, DL, nullptr, BO, DT)) return false; @@ -590,6 +587,10 @@ APInt ConstantOffsetExtractor::find(Value *V, bool SignExtended, // Trace into subexpressions for more hoisting opportunities. if (CanTraceInto(SignExtended, ZeroExtended, BO, NonNegative)) ConstantOffset = findInEitherOperand(BO, SignExtended, ZeroExtended); + } else if (isa<TruncInst>(V)) { + ConstantOffset = + find(U->getOperand(0), SignExtended, ZeroExtended, NonNegative) + .trunc(BitWidth); } else if (isa<SExtInst>(V)) { ConstantOffset = find(U->getOperand(0), /* SignExtended */ true, ZeroExtended, NonNegative).sext(BitWidth); @@ -654,8 +655,9 @@ ConstantOffsetExtractor::distributeExtsAndCloneChain(unsigned ChainIndex) { } if (CastInst *Cast = dyn_cast<CastInst>(U)) { - assert((isa<SExtInst>(Cast) || isa<ZExtInst>(Cast)) && - "We only traced into two types of CastInst: sext and zext"); + assert( + (isa<SExtInst>(Cast) || isa<ZExtInst>(Cast) || isa<TruncInst>(Cast)) && + "Only following instructions can be traced: sext, zext & trunc"); ExtInsts.push_back(Cast); UserChain[ChainIndex] = nullptr; return distributeExtsAndCloneChain(ChainIndex - 1); @@ -706,7 +708,7 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) { BinaryOperator::BinaryOps NewOp = BO->getOpcode(); if (BO->getOpcode() == Instruction::Or) { // Rebuild "or" as "add", because "or" may be invalid for the new - // epxression. + // expression. // // For instance, given // a | (b + 5) where a and b + 5 have no common bits, @@ -943,6 +945,10 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { if (!NeedsExtraction) return Changed; + + TargetTransformInfo &TTI = + getAnalysis<TargetTransformInfoWrapperPass>().getTTI(*GEP->getFunction()); + // If LowerGEP is disabled, before really splitting the GEP, check whether the // backend supports the addressing mode we are about to produce. If no, this // splitting probably won't be beneficial. @@ -951,9 +957,6 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { // of variable indices. Therefore, we don't check for addressing modes in that // case. if (!LowerGEP) { - TargetTransformInfo &TTI = - getAnalysis<TargetTransformInfoWrapperPass>().getTTI( - *GEP->getParent()->getParent()); unsigned AddrSpace = GEP->getPointerAddressSpace(); if (!TTI.isLegalAddressingMode(GEP->getResultElementType(), /*BaseGV=*/nullptr, AccumulativeByteOffset, @@ -1016,7 +1019,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { if (LowerGEP) { // As currently BasicAA does not analyze ptrtoint/inttoptr, do not lower to // arithmetic operations if the target uses alias analysis in codegen. - if (TM && TM->getSubtargetImpl(*GEP->getParent()->getParent())->useAA()) + if (TTI.useAA()) lowerToSingleIndexGEPs(GEP, AccumulativeByteOffset); else lowerToArithmetics(GEP, AccumulativeByteOffset); @@ -1065,12 +1068,13 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { DL->getTypeAllocSize(GEP->getResultElementType())); Type *IntPtrTy = DL->getIntPtrType(GEP->getType()); if (AccumulativeByteOffset % ElementTypeSizeOfGEP == 0) { - // Very likely. As long as %gep is natually aligned, the byte offset we + // Very likely. 
As long as %gep is naturally aligned, the byte offset we // extracted should be a multiple of sizeof(*%gep). int64_t Index = AccumulativeByteOffset / ElementTypeSizeOfGEP; NewGEP = GetElementPtrInst::Create(GEP->getResultElementType(), NewGEP, ConstantInt::get(IntPtrTy, Index, true), GEP->getName(), GEP); + NewGEP->copyMetadata(*GEP); // Inherit the inbounds attribute of the original GEP. cast<GetElementPtrInst>(NewGEP)->setIsInBounds(GEPWasInBounds); } else { @@ -1095,6 +1099,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { Type::getInt8Ty(GEP->getContext()), NewGEP, ConstantInt::get(IntPtrTy, AccumulativeByteOffset, true), "uglygep", GEP); + NewGEP->copyMetadata(*GEP); // Inherit the inbounds attribute of the original GEP. cast<GetElementPtrInst>(NewGEP)->setIsInBounds(GEPWasInBounds); if (GEP->getType() != I8PtrTy) @@ -1293,7 +1298,7 @@ void SeparateConstOffsetFromGEP::swapGEPOperand(GetElementPtrInst *First, // We changed p+o+c to p+c+o, p+c may not be inbound anymore. const DataLayout &DAL = First->getModule()->getDataLayout(); - APInt Offset(DAL.getPointerSizeInBits( + APInt Offset(DAL.getIndexSizeInBits( cast<PointerType>(First->getType())->getAddressSpace()), 0); Value *NewBase = diff --git a/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index 3d0fca0bc3a5..34510cb40732 100644 --- a/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -1,4 +1,4 @@ -//===- SimpleLoopUnswitch.cpp - Hoist loop-invariant control flow ---------===// +///===- SimpleLoopUnswitch.cpp - Hoist loop-invariant control flow ---------===// // // The LLVM Compiler Infrastructure // @@ -17,10 +17,14 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -66,180 +70,65 @@ static cl::opt<int> UnswitchThreshold("unswitch-threshold", cl::init(50), cl::Hidden, cl::desc("The cost threshold for unswitching a loop.")); -static void replaceLoopUsesWithConstant(Loop &L, Value &LIC, - Constant &Replacement) { - assert(!isa<Constant>(LIC) && "Why are we unswitching on a constant?"); - - // Replace uses of LIC in the loop with the given constant. - for (auto UI = LIC.use_begin(), UE = LIC.use_end(); UI != UE;) { - // Grab the use and walk past it so we can clobber it in the use list. - Use *U = &*UI++; - Instruction *UserI = dyn_cast<Instruction>(U->getUser()); - if (!UserI || !L.contains(UserI)) - continue; - - // Replace this use within the loop body. - *U = &Replacement; - } -} - -/// Update the IDom for a basic block whose predecessor set has changed. -/// -/// This routine is designed to work when the domtree update is relatively -/// localized by leveraging a known common dominator, often a loop header. -/// -/// FIXME: Should consider hand-rolling a slightly more efficient non-DFS -/// approach here as we can do that easily by persisting the candidate IDom's -/// dominating set between each predecessor. +/// Collect all of the loop invariant input values transitively used by the +/// homogeneous instruction graph from a given root. 
/// -/// FIXME: Longer term, many uses of this can be replaced by an incremental -/// domtree update strategy that starts from a known dominating block and -/// rebuilds that subtree. -static bool updateIDomWithKnownCommonDominator(BasicBlock *BB, - BasicBlock *KnownDominatingBB, - DominatorTree &DT) { - assert(pred_begin(BB) != pred_end(BB) && - "This routine does not handle unreachable blocks!"); - - BasicBlock *OrigIDom = DT[BB]->getIDom()->getBlock(); - - BasicBlock *IDom = *pred_begin(BB); - assert(DT.dominates(KnownDominatingBB, IDom) && - "Bad known dominating block!"); - - // Walk all of the other predecessors finding the nearest common dominator - // until all predecessors are covered or we reach the loop header. The loop - // header necessarily dominates all loop exit blocks in loop simplified form - // so we can early-exit the moment we hit that block. - for (auto PI = std::next(pred_begin(BB)), PE = pred_end(BB); - PI != PE && IDom != KnownDominatingBB; ++PI) { - assert(DT.dominates(KnownDominatingBB, *PI) && - "Bad known dominating block!"); - IDom = DT.findNearestCommonDominator(IDom, *PI); - } +/// This essentially walks from a root recursively through loop variant operands +/// which have the exact same opcode and finds all inputs which are loop +/// invariant. For some operations these can be re-associated and unswitched out +/// of the loop entirely. +static TinyPtrVector<Value *> +collectHomogenousInstGraphLoopInvariants(Loop &L, Instruction &Root, + LoopInfo &LI) { + assert(!L.isLoopInvariant(&Root) && + "Only need to walk the graph if root itself is not invariant."); + TinyPtrVector<Value *> Invariants; + + // Build a worklist and recurse through operators collecting invariants. + SmallVector<Instruction *, 4> Worklist; + SmallPtrSet<Instruction *, 8> Visited; + Worklist.push_back(&Root); + Visited.insert(&Root); + do { + Instruction &I = *Worklist.pop_back_val(); + for (Value *OpV : I.operand_values()) { + // Skip constants as unswitching isn't interesting for them. + if (isa<Constant>(OpV)) + continue; - if (IDom == OrigIDom) - return false; + // Add it to our result if loop invariant. + if (L.isLoopInvariant(OpV)) { + Invariants.push_back(OpV); + continue; + } - DT.changeImmediateDominator(BB, IDom); - return true; -} + // If not an instruction with the same opcode, nothing we can do. + Instruction *OpI = dyn_cast<Instruction>(OpV); + if (!OpI || OpI->getOpcode() != Root.getOpcode()) + continue; -// Note that we don't currently use the IDFCalculator here for two reasons: -// 1) It computes dominator tree levels for the entire function on each run -// of 'compute'. While this isn't terrible, given that we expect to update -// relatively small subtrees of the domtree, it isn't necessarily the right -// tradeoff. -// 2) The interface doesn't fit this usage well. It doesn't operate in -// append-only, and builds several sets that we don't need. -// -// FIXME: Neither of these issues are a big deal and could be addressed with -// some amount of refactoring of IDFCalculator. That would allow us to share -// the core logic here (which is solving the same core problem). -static void appendDomFrontier(DomTreeNode *Node, - SmallSetVector<BasicBlock *, 4> &Worklist, - SmallVectorImpl<DomTreeNode *> &DomNodes, - SmallPtrSetImpl<BasicBlock *> &DomSet) { - assert(DomNodes.empty() && "Must start with no dominator nodes."); - assert(DomSet.empty() && "Must start with an empty dominator set."); - - // First flatten this subtree into sequence of nodes by doing a pre-order - // walk. 
- DomNodes.push_back(Node); - // We intentionally re-evaluate the size as each node can add new children. - // Because this is a tree walk, this cannot add any duplicates. - for (int i = 0; i < (int)DomNodes.size(); ++i) - DomNodes.insert(DomNodes.end(), DomNodes[i]->begin(), DomNodes[i]->end()); - - // Now create a set of the basic blocks so we can quickly test for - // dominated successors. We could in theory use the DFS numbers of the - // dominator tree for this, but we want this to remain predictably fast - // even while we mutate the dominator tree in ways that would invalidate - // the DFS numbering. - for (DomTreeNode *InnerN : DomNodes) - DomSet.insert(InnerN->getBlock()); - - // Now re-walk the nodes, appending every successor of every node that isn't - // in the set. Note that we don't append the node itself, even though if it - // is a successor it does not strictly dominate itself and thus it would be - // part of the dominance frontier. The reason we don't append it is that - // the node passed in came *from* the worklist and so it has already been - // processed. - for (DomTreeNode *InnerN : DomNodes) - for (BasicBlock *SuccBB : successors(InnerN->getBlock())) - if (!DomSet.count(SuccBB)) - Worklist.insert(SuccBB); - - DomNodes.clear(); - DomSet.clear(); -} + // Visit this operand. + if (Visited.insert(OpI).second) + Worklist.push_back(OpI); + } + } while (!Worklist.empty()); -/// Update the dominator tree after unswitching a particular former exit block. -/// -/// This handles the full update of the dominator tree after hoisting a block -/// that previously was an exit block (or split off of an exit block) up to be -/// reached from the new immediate dominator of the preheader. -/// -/// The common case is simple -- we just move the unswitched block to have an -/// immediate dominator of the old preheader. But in complex cases, there may -/// be other blocks reachable from the unswitched block that are immediately -/// dominated by some node between the unswitched one and the old preheader. -/// All of these also need to be hoisted in the dominator tree. We also want to -/// minimize queries to the dominator tree because each step of this -/// invalidates any DFS numbers that would make queries fast. -static void updateDTAfterUnswitch(BasicBlock *UnswitchedBB, BasicBlock *OldPH, - DominatorTree &DT) { - DomTreeNode *OldPHNode = DT[OldPH]; - DomTreeNode *UnswitchedNode = DT[UnswitchedBB]; - // If the dominator tree has already been updated for this unswitched node, - // we're done. This makes it easier to use this routine if there are multiple - // paths to the same unswitched destination. - if (UnswitchedNode->getIDom() == OldPHNode) - return; + return Invariants; +} - // First collect the domtree nodes that we are hoisting over. These are the - // set of nodes which may have children that need to be hoisted as well. - SmallPtrSet<DomTreeNode *, 4> DomChain; - for (auto *IDom = UnswitchedNode->getIDom(); IDom != OldPHNode; - IDom = IDom->getIDom()) - DomChain.insert(IDom); - - // The unswitched block ends up immediately dominated by the old preheader -- - // regardless of whether it is the loop exit block or split off of the loop - // exit block. - DT.changeImmediateDominator(UnswitchedNode, OldPHNode); - - // For everything that moves up the dominator tree, we need to examine the - // dominator frontier to see if it additionally should move up the dominator - // tree. This lambda appends the dominator frontier for a node on the - // worklist. 
- SmallSetVector<BasicBlock *, 4> Worklist; - - // Scratch data structures reused by domfrontier finding. - SmallVector<DomTreeNode *, 4> DomNodes; - SmallPtrSet<BasicBlock *, 4> DomSet; - - // Append the initial dom frontier nodes. - appendDomFrontier(UnswitchedNode, Worklist, DomNodes, DomSet); - - // Walk the worklist. We grow the list in the loop and so must recompute size. - for (int i = 0; i < (int)Worklist.size(); ++i) { - auto *BB = Worklist[i]; - - DomTreeNode *Node = DT[BB]; - assert(!DomChain.count(Node) && - "Cannot be dominated by a block you can reach!"); - - // If this block had an immediate dominator somewhere in the chain - // we hoisted over, then its position in the domtree needs to move as it is - // reachable from a node hoisted over this chain. - if (!DomChain.count(Node->getIDom())) - continue; +static void replaceLoopInvariantUses(Loop &L, Value *Invariant, + Constant &Replacement) { + assert(!isa<Constant>(Invariant) && "Why are we unswitching on a constant?"); - DT.changeImmediateDominator(Node, OldPHNode); + // Replace uses of LIC in the loop with the given constant. + for (auto UI = Invariant->use_begin(), UE = Invariant->use_end(); UI != UE;) { + // Grab the use and walk past it so we can clobber it in the use list. + Use *U = &*UI++; + Instruction *UserI = dyn_cast<Instruction>(U->getUser()); - // Now add this node's dominator frontier to the worklist as well. - appendDomFrontier(Node, Worklist, DomNodes, DomSet); + // Replace this use within the loop body. + if (UserI && L.contains(UserI)) + U->set(&Replacement); } } @@ -261,6 +150,26 @@ static bool areLoopExitPHIsLoopInvariant(Loop &L, BasicBlock &ExitingBB, llvm_unreachable("Basic blocks should never be empty!"); } +/// Insert code to test a set of loop invariant values, and conditionally branch +/// on them. +static void buildPartialUnswitchConditionalBranch(BasicBlock &BB, + ArrayRef<Value *> Invariants, + bool Direction, + BasicBlock &UnswitchedSucc, + BasicBlock &NormalSucc) { + IRBuilder<> IRB(&BB); + Value *Cond = Invariants.front(); + for (Value *Invariant : + make_range(std::next(Invariants.begin()), Invariants.end())) + if (Direction) + Cond = IRB.CreateOr(Cond, Invariant); + else + Cond = IRB.CreateAnd(Cond, Invariant); + + IRB.CreateCondBr(Cond, Direction ? &UnswitchedSucc : &NormalSucc, + Direction ? &NormalSucc : &UnswitchedSucc); +} + /// Rewrite the PHI nodes in an unswitched loop exit basic block. /// /// Requires that the loop exit and unswitched basic block are the same, and @@ -271,19 +180,14 @@ static bool areLoopExitPHIsLoopInvariant(Loop &L, BasicBlock &ExitingBB, static void rewritePHINodesForUnswitchedExitBlock(BasicBlock &UnswitchedBB, BasicBlock &OldExitingBB, BasicBlock &OldPH) { - for (Instruction &I : UnswitchedBB) { - auto *PN = dyn_cast<PHINode>(&I); - if (!PN) - // No more PHIs to check. - break; - + for (PHINode &PN : UnswitchedBB.phis()) { // When the loop exit is directly unswitched we just need to update the // incoming basic block. We loop to handle weird cases with repeated // incoming blocks, but expect to typically only have one operand here. 
- for (auto i : seq<int>(0, PN->getNumOperands())) { - assert(PN->getIncomingBlock(i) == &OldExitingBB && + for (auto i : seq<int>(0, PN.getNumOperands())) { + assert(PN.getIncomingBlock(i) == &OldExitingBB && "Found incoming block different from unique predecessor!"); - PN->setIncomingBlock(i, &OldPH); + PN.setIncomingBlock(i, &OldPH); } } } @@ -298,18 +202,14 @@ static void rewritePHINodesForUnswitchedExitBlock(BasicBlock &UnswitchedBB, static void rewritePHINodesForExitAndUnswitchedBlocks(BasicBlock &ExitBB, BasicBlock &UnswitchedBB, BasicBlock &OldExitingBB, - BasicBlock &OldPH) { + BasicBlock &OldPH, + bool FullUnswitch) { assert(&ExitBB != &UnswitchedBB && "Must have different loop exit and unswitched blocks!"); Instruction *InsertPt = &*UnswitchedBB.begin(); - for (Instruction &I : ExitBB) { - auto *PN = dyn_cast<PHINode>(&I); - if (!PN) - // No more PHIs to check. - break; - - auto *NewPN = PHINode::Create(PN->getType(), /*NumReservedValues*/ 2, - PN->getName() + ".split", InsertPt); + for (PHINode &PN : ExitBB.phis()) { + auto *NewPN = PHINode::Create(PN.getType(), /*NumReservedValues*/ 2, + PN.getName() + ".split", InsertPt); // Walk backwards over the old PHI node's inputs to minimize the cost of // removing each one. We have to do this weird loop manually so that we @@ -320,18 +220,92 @@ static void rewritePHINodesForExitAndUnswitchedBlocks(BasicBlock &ExitBB, // allowed us to create a single entry for a predecessor block without // having separate entries for each "edge" even though these edges are // required to produce identical results. - for (int i = PN->getNumIncomingValues() - 1; i >= 0; --i) { - if (PN->getIncomingBlock(i) != &OldExitingBB) + for (int i = PN.getNumIncomingValues() - 1; i >= 0; --i) { + if (PN.getIncomingBlock(i) != &OldExitingBB) continue; - Value *Incoming = PN->removeIncomingValue(i); + Value *Incoming = PN.getIncomingValue(i); + if (FullUnswitch) + // No more edge from the old exiting block to the exit block. + PN.removeIncomingValue(i); + NewPN->addIncoming(Incoming, &OldPH); } // Now replace the old PHI with the new one and wire the old one in as an // input to the new one. - PN->replaceAllUsesWith(NewPN); - NewPN->addIncoming(PN, &ExitBB); + PN.replaceAllUsesWith(NewPN); + NewPN->addIncoming(&PN, &ExitBB); + } +} + +/// Hoist the current loop up to the innermost loop containing a remaining exit. +/// +/// Because we've removed an exit from the loop, we may have changed the set of +/// loops reachable and need to move the current loop up the loop nest or even +/// to an entirely separate nest. +static void hoistLoopToNewParent(Loop &L, BasicBlock &Preheader, + DominatorTree &DT, LoopInfo &LI) { + // If the loop is already at the top level, we can't hoist it anywhere. + Loop *OldParentL = L.getParentLoop(); + if (!OldParentL) + return; + + SmallVector<BasicBlock *, 4> Exits; + L.getExitBlocks(Exits); + Loop *NewParentL = nullptr; + for (auto *ExitBB : Exits) + if (Loop *ExitL = LI.getLoopFor(ExitBB)) + if (!NewParentL || NewParentL->contains(ExitL)) + NewParentL = ExitL; + + if (NewParentL == OldParentL) + return; + + // The new parent loop (if different) should always contain the old one. + if (NewParentL) + assert(NewParentL->contains(OldParentL) && + "Can only hoist this loop up the nest!"); + + // The preheader will need to move with the body of this loop. However, + // because it isn't in this loop we also need to update the primary loop map. 
+ assert(OldParentL == LI.getLoopFor(&Preheader) && + "Parent loop of this loop should contain this loop's preheader!"); + LI.changeLoopFor(&Preheader, NewParentL); + + // Remove this loop from its old parent. + OldParentL->removeChildLoop(&L); + + // Add the loop either to the new parent or as a top-level loop. + if (NewParentL) + NewParentL->addChildLoop(&L); + else + LI.addTopLevelLoop(&L); + + // Remove this loop's blocks from the old parent and every other loop up the + // nest until reaching the new parent. Also update all of these + // no-longer-containing loops to reflect the nesting change. + for (Loop *OldContainingL = OldParentL; OldContainingL != NewParentL; + OldContainingL = OldContainingL->getParentLoop()) { + llvm::erase_if(OldContainingL->getBlocksVector(), + [&](const BasicBlock *BB) { + return BB == &Preheader || L.contains(BB); + }); + + OldContainingL->getBlocksSet().erase(&Preheader); + for (BasicBlock *BB : L.blocks()) + OldContainingL->getBlocksSet().erase(BB); + + // Because we just hoisted a loop out of this one, we have essentially + // created new exit paths from it. That means we need to form LCSSA PHI + // nodes for values used in the no-longer-nested loop. + formLCSSA(*OldContainingL, DT, &LI, nullptr); + + // We shouldn't need to form dedicated exits because the exit introduced + // here is the (just split by unswitching) preheader. As such, it is + // necessarily dedicated. + assert(OldContainingL->hasDedicatedExits() && + "Unexpected predecessor of hoisted loop preheader!"); } } @@ -349,48 +323,83 @@ static void rewritePHINodesForExitAndUnswitchedBlocks(BasicBlock &ExitBB, /// (splitting the exit block as necessary). It simplifies the branch within /// the loop to an unconditional branch but doesn't remove it entirely. Further /// cleanup can be done with some simplify-cfg like pass. +/// +/// If `SE` is not null, it will be updated based on the potential loop SCEVs +/// invalidated by this. static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, - LoopInfo &LI) { + LoopInfo &LI, ScalarEvolution *SE) { assert(BI.isConditional() && "Can only unswitch a conditional branch!"); - DEBUG(dbgs() << " Trying to unswitch branch: " << BI << "\n"); + LLVM_DEBUG(dbgs() << " Trying to unswitch branch: " << BI << "\n"); - Value *LoopCond = BI.getCondition(); + // The loop invariant values that we want to unswitch. + TinyPtrVector<Value *> Invariants; - // Need a trivial loop condition to unswitch. - if (!L.isLoopInvariant(LoopCond)) - return false; + // When true, we're fully unswitching the branch rather than just unswitching + // some input conditions to the branch. + bool FullUnswitch = false; - // FIXME: We should compute this once at the start and update it! - SmallVector<BasicBlock *, 16> ExitBlocks; - L.getExitBlocks(ExitBlocks); - SmallPtrSet<BasicBlock *, 16> ExitBlockSet(ExitBlocks.begin(), - ExitBlocks.end()); - - // Check to see if a successor of the branch is guaranteed to - // exit through a unique exit block without having any - // side-effects. If so, determine the value of Cond that causes - // it to do this. 
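The partial-unswitching logic being introduced here rests on a simple boolean fact that is worth spelling out before the direction checks below: a true loop-invariant input forces an `or` condition true, and a false invariant input forces an `and` condition false, so only those opcode/edge pairings can be decided from the preheader. A tiny standalone demonstration (plain C++, not part of the commit):

    #include <cstdio>

    int main() {
      // Truth table for one invariant input and one loop-varying input.
      for (int Inv = 0; Inv <= 1; ++Inv)
        for (int Var = 0; Var <= 1; ++Var)
          std::printf("inv=%d var=%d  or=%d  and=%d\n", Inv, Var, Inv | Var,
                      Inv & Var);
      // Reading the table: inv=1 forces `or` to 1 whatever var is, and inv=0
      // forces `and` to 0; an invariant input alone decides the branch in
      // exactly one direction, which is the direction that may be unswitched.
    }

This is why the code pairs `or` graphs with exits on the true edge and `and` graphs with exits on the false edge.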
- ConstantInt *CondVal = ConstantInt::getTrue(BI.getContext()); - ConstantInt *Replacement = ConstantInt::getFalse(BI.getContext()); + if (L.isLoopInvariant(BI.getCondition())) { + Invariants.push_back(BI.getCondition()); + FullUnswitch = true; + } else { + if (auto *CondInst = dyn_cast<Instruction>(BI.getCondition())) + Invariants = collectHomogenousInstGraphLoopInvariants(L, *CondInst, LI); + if (Invariants.empty()) + // Couldn't find invariant inputs! + return false; + } + + // Check that one of the branch's successors exits, and which one. + bool ExitDirection = true; int LoopExitSuccIdx = 0; auto *LoopExitBB = BI.getSuccessor(0); - if (!ExitBlockSet.count(LoopExitBB)) { - std::swap(CondVal, Replacement); + if (L.contains(LoopExitBB)) { + ExitDirection = false; LoopExitSuccIdx = 1; LoopExitBB = BI.getSuccessor(1); - if (!ExitBlockSet.count(LoopExitBB)) + if (L.contains(LoopExitBB)) return false; } auto *ContinueBB = BI.getSuccessor(1 - LoopExitSuccIdx); - assert(L.contains(ContinueBB) && - "Cannot have both successors exit and still be in the loop!"); - auto *ParentBB = BI.getParent(); if (!areLoopExitPHIsLoopInvariant(L, *ParentBB, *LoopExitBB)) return false; - DEBUG(dbgs() << " unswitching trivial branch when: " << CondVal - << " == " << LoopCond << "\n"); + // When unswitching only part of the branch's condition, we need the exit + // block to be reached directly from the partially unswitched input. This can + // be done when the exit block is along the true edge and the branch condition + // is a graph of `or` operations, or the exit block is along the false edge + // and the condition is a graph of `and` operations. + if (!FullUnswitch) { + if (ExitDirection) { + if (cast<Instruction>(BI.getCondition())->getOpcode() != Instruction::Or) + return false; + } else { + if (cast<Instruction>(BI.getCondition())->getOpcode() != Instruction::And) + return false; + } + } + + LLVM_DEBUG({ + dbgs() << " unswitching trivial invariant conditions for: " << BI + << "\n"; + for (Value *Invariant : Invariants) { + dbgs() << " " << *Invariant << " == true"; + if (Invariant != Invariants.back()) + dbgs() << " ||"; + dbgs() << "\n"; + } + }); + + // If we have scalar evolutions, we need to invalidate them including this + // loop and the loop containing the exit block. + if (SE) { + if (Loop *ExitL = LI.getLoopFor(LoopExitBB)) + SE->forgetLoop(ExitL); + else + // Forget the entire nest as this exits the entire nest. + SE->forgetTopmostLoop(&L); + } // Split the preheader, so that we know that there is a safe place to insert // the conditional branch. We will change the preheader to have a conditional @@ -403,45 +412,73 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, // unswitching. We need to split this if there are other loop predecessors. // Because the loop is in simplified form, *any* other predecessor is enough. BasicBlock *UnswitchedBB; - if (BasicBlock *PredBB = LoopExitBB->getUniquePredecessor()) { - (void)PredBB; - assert(PredBB == BI.getParent() && + if (FullUnswitch && LoopExitBB->getUniquePredecessor()) { + assert(LoopExitBB->getUniquePredecessor() == BI.getParent() && "A branch's parent isn't a predecessor!"); UnswitchedBB = LoopExitBB; } else { UnswitchedBB = SplitBlock(LoopExitBB, &LoopExitBB->front(), &DT, &LI); } - // Now splice the branch to gate reaching the new preheader and re-point its - // successors. 
- OldPH->getInstList().splice(std::prev(OldPH->end()), - BI.getParent()->getInstList(), BI); + // Actually move the invariant uses into the unswitched position. If possible, + // we do this by moving the instructions, but when doing partial unswitching + // we do it by building a new merge of the values in the unswitched position. OldPH->getTerminator()->eraseFromParent(); - BI.setSuccessor(LoopExitSuccIdx, UnswitchedBB); - BI.setSuccessor(1 - LoopExitSuccIdx, NewPH); - - // Create a new unconditional branch that will continue the loop as a new - // terminator. - BranchInst::Create(ContinueBB, ParentBB); + if (FullUnswitch) { + // If fully unswitching, we can use the existing branch instruction. + // Splice it into the old PH to gate reaching the new preheader and re-point + // its successors. + OldPH->getInstList().splice(OldPH->end(), BI.getParent()->getInstList(), + BI); + BI.setSuccessor(LoopExitSuccIdx, UnswitchedBB); + BI.setSuccessor(1 - LoopExitSuccIdx, NewPH); + + // Create a new unconditional branch that will continue the loop as a new + // terminator. + BranchInst::Create(ContinueBB, ParentBB); + } else { + // Only unswitching a subset of inputs to the condition, so we will need to + // build a new branch that merges the invariant inputs. + if (ExitDirection) + assert(cast<Instruction>(BI.getCondition())->getOpcode() == + Instruction::Or && + "Must have an `or` of `i1`s for the condition!"); + else + assert(cast<Instruction>(BI.getCondition())->getOpcode() == + Instruction::And && + "Must have an `and` of `i1`s for the condition!"); + buildPartialUnswitchConditionalBranch(*OldPH, Invariants, ExitDirection, + *UnswitchedBB, *NewPH); + } // Rewrite the relevant PHI nodes. if (UnswitchedBB == LoopExitBB) rewritePHINodesForUnswitchedExitBlock(*UnswitchedBB, *ParentBB, *OldPH); else rewritePHINodesForExitAndUnswitchedBlocks(*LoopExitBB, *UnswitchedBB, - *ParentBB, *OldPH); + *ParentBB, *OldPH, FullUnswitch); // Now we need to update the dominator tree. - updateDTAfterUnswitch(UnswitchedBB, OldPH, DT); - // But if we split something off of the loop exit block then we also removed - // one of the predecessors for the loop exit block and may need to update its - // idom. - if (UnswitchedBB != LoopExitBB) - updateIDomWithKnownCommonDominator(LoopExitBB, L.getHeader(), DT); + DT.insertEdge(OldPH, UnswitchedBB); + if (FullUnswitch) + DT.deleteEdge(ParentBB, UnswitchedBB); + + // The constant we can replace all of our invariants with inside the loop + // body. If any of the invariants have a value other than this the loop won't + // be entered. + ConstantInt *Replacement = ExitDirection + ? ConstantInt::getFalse(BI.getContext()) + : ConstantInt::getTrue(BI.getContext()); // Since this is an i1 condition we can also trivially replace uses of it // within the loop with a constant. - replaceLoopUsesWithConstant(L, *LoopCond, *Replacement); + for (Value *Invariant : Invariants) + replaceLoopInvariantUses(L, Invariant, *Replacement); + + // If this was full unswitching, we may have changed the nesting relationship + // for this loop so hoist it to its correct parent if needed. + if (FullUnswitch) + hoistLoopToNewParent(L, *NewPH, DT, LI); ++NumTrivial; ++NumBranches; @@ -471,9 +508,12 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, /// switch will not be revisited. If after unswitching there is only a single /// in-loop successor, the switch is further simplified to an unconditional /// branch. Still more cleanup can be done with some simplify-cfg like pass. 
+/// +/// If `SE` is not null, it will be updated based on the potential loop SCEVs +/// invalidated by this. static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, - LoopInfo &LI) { - DEBUG(dbgs() << " Trying to unswitch switch: " << SI << "\n"); + LoopInfo &LI, ScalarEvolution *SE) { + LLVM_DEBUG(dbgs() << " Trying to unswitch switch: " << SI << "\n"); Value *LoopCond = SI.getCondition(); // If this isn't switching on an invariant condition, we can't unswitch it. @@ -482,41 +522,62 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, auto *ParentBB = SI.getParent(); - // FIXME: We should compute this once at the start and update it! - SmallVector<BasicBlock *, 16> ExitBlocks; - L.getExitBlocks(ExitBlocks); - SmallPtrSet<BasicBlock *, 16> ExitBlockSet(ExitBlocks.begin(), - ExitBlocks.end()); - SmallVector<int, 4> ExitCaseIndices; for (auto Case : SI.cases()) { auto *SuccBB = Case.getCaseSuccessor(); - if (ExitBlockSet.count(SuccBB) && + if (!L.contains(SuccBB) && areLoopExitPHIsLoopInvariant(L, *ParentBB, *SuccBB)) ExitCaseIndices.push_back(Case.getCaseIndex()); } BasicBlock *DefaultExitBB = nullptr; - if (ExitBlockSet.count(SI.getDefaultDest()) && + if (!L.contains(SI.getDefaultDest()) && areLoopExitPHIsLoopInvariant(L, *ParentBB, *SI.getDefaultDest()) && !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator())) DefaultExitBB = SI.getDefaultDest(); else if (ExitCaseIndices.empty()) return false; - DEBUG(dbgs() << " unswitching trivial cases...\n"); + LLVM_DEBUG(dbgs() << " unswitching trivial cases...\n"); + // We may need to invalidate SCEVs for the outermost loop reached by any of + // the exits. + Loop *OuterL = &L; + + if (DefaultExitBB) { + // Clear out the default destination temporarily to allow accurate + // predecessor lists to be examined below. + SI.setDefaultDest(nullptr); + // Check the loop containing this exit. + Loop *ExitL = LI.getLoopFor(DefaultExitBB); + if (!ExitL || ExitL->contains(OuterL)) + OuterL = ExitL; + } + + // Store the exit cases into a separate data structure and remove them from + // the switch. SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4> ExitCases; ExitCases.reserve(ExitCaseIndices.size()); // We walk the case indices backwards so that we remove the last case first // and don't disrupt the earlier indices. for (unsigned Index : reverse(ExitCaseIndices)) { auto CaseI = SI.case_begin() + Index; + // Compute the outer loop from this exit. + Loop *ExitL = LI.getLoopFor(CaseI->getCaseSuccessor()); + if (!ExitL || ExitL->contains(OuterL)) + OuterL = ExitL; // Save the value of this case. ExitCases.push_back({CaseI->getCaseValue(), CaseI->getCaseSuccessor()}); // Delete the unswitched cases. SI.removeCase(CaseI); } + if (SE) { + if (OuterL) + SE->forgetLoop(OuterL); + else + SE->forgetTopmostLoop(&L); + } + // Check if after this all of the remaining cases point at the same // successor. BasicBlock *CommonSuccBB = nullptr; @@ -527,23 +588,7 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, SI.case_begin()->getCaseSuccessor(); })) CommonSuccBB = SI.case_begin()->getCaseSuccessor(); - - if (DefaultExitBB) { - // We can't remove the default edge so replace it with an edge to either - // the single common remaining successor (if we have one) or an unreachable - // block. 
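The ScalarEvolution plumbing added in this hunk exists because SCEV caches per-loop facts such as backedge-taken counts, which the coming CFG surgery invalidates. A minimal sketch of the invalidation pattern the patch uses (mirroring, not replacing, the code above):

    // Forget the outermost loop an exit reaches; a null result from
    // LI.getLoopFor means the edge leaves every loop, so drop cached facts
    // for the whole nest containing L instead.
    if (SE) {
      if (Loop *ExitL = LI.getLoopFor(ExitBB))
        SE->forgetLoop(ExitL);
      else
        SE->forgetTopmostLoop(&L);
    }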
- if (CommonSuccBB) { - SI.setDefaultDest(CommonSuccBB); - } else { - BasicBlock *UnreachableBB = BasicBlock::Create( - ParentBB->getContext(), - Twine(ParentBB->getName()) + ".unreachable_default", - ParentBB->getParent()); - new UnreachableInst(ParentBB->getContext(), UnreachableBB); - SI.setDefaultDest(UnreachableBB); - DT.addNewBlock(UnreachableBB, ParentBB); - } - } else { + if (!DefaultExitBB) { // If we're not unswitching the default, we need it to match any cases to // have a common successor or if we have no cases it is the common // successor. @@ -580,9 +625,8 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, } else { auto *SplitBB = SplitBlock(DefaultExitBB, &DefaultExitBB->front(), &DT, &LI); - rewritePHINodesForExitAndUnswitchedBlocks(*DefaultExitBB, *SplitBB, - *ParentBB, *OldPH); - updateIDomWithKnownCommonDominator(DefaultExitBB, L.getHeader(), DT); + rewritePHINodesForExitAndUnswitchedBlocks( + *DefaultExitBB, *SplitBB, *ParentBB, *OldPH, /*FullUnswitch*/ true); DefaultExitBB = SplitExitBBMap[DefaultExitBB] = SplitBB; } } @@ -607,9 +651,8 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, if (!SplitExitBB) { // If this is the first time we see this, do the split and remember it. SplitExitBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI); - rewritePHINodesForExitAndUnswitchedBlocks(*ExitBB, *SplitExitBB, - *ParentBB, *OldPH); - updateIDomWithKnownCommonDominator(ExitBB, L.getHeader(), DT); + rewritePHINodesForExitAndUnswitchedBlocks( + *ExitBB, *SplitExitBB, *ParentBB, *OldPH, /*FullUnswitch*/ true); } // Update the case pair to point to the split block. CasePair.second = SplitExitBB; @@ -622,14 +665,12 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, BasicBlock *UnswitchedBB = CasePair.second; NewSI->addCase(CaseVal, UnswitchedBB); - updateDTAfterUnswitch(UnswitchedBB, OldPH, DT); } // If the default was unswitched, re-point it and add explicit cases for // entering the loop. if (DefaultExitBB) { NewSI->setDefaultDest(DefaultExitBB); - updateDTAfterUnswitch(DefaultExitBB, OldPH, DT); // We removed all the exit cases, so we just copy the cases to the // unswitched switch. @@ -643,11 +684,57 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, // pointing at unreachable and other complexity. if (CommonSuccBB) { BasicBlock *BB = SI.getParent(); + // We may have had multiple edges to this common successor block, so remove + // them as predecessors. We skip the first one, either the default or the + // actual first case. + bool SkippedFirst = DefaultExitBB == nullptr; + for (auto Case : SI.cases()) { + assert(Case.getCaseSuccessor() == CommonSuccBB && + "Non-common successor!"); + (void)Case; + if (!SkippedFirst) { + SkippedFirst = true; + continue; + } + CommonSuccBB->removePredecessor(BB, + /*DontDeleteUselessPHIs*/ true); + } + // Now nuke the switch and replace it with a direct branch. SI.eraseFromParent(); BranchInst::Create(CommonSuccBB, BB); + } else if (DefaultExitBB) { + assert(SI.getNumCases() > 0 && + "If we had no cases we'd have a common successor!"); + // Move the last case to the default successor. This is valid as if the + // default got unswitched it cannot be reached. This has the advantage of + // being simple and keeping the number of edges from this switch to + // successors the same, and avoiding any PHI update complexity. 
+ auto LastCaseI = std::prev(SI.case_end()); + SI.setDefaultDest(LastCaseI->getCaseSuccessor()); + SI.removeCase(LastCaseI); + } + + // Walk the unswitched exit blocks and the unswitched split blocks and update + // the dominator tree based on the CFG edits. While we are walking unordered + // containers here, the API for applyUpdates takes an unordered list of + // updates and requires them to not contain duplicates. + SmallVector<DominatorTree::UpdateType, 4> DTUpdates; + for (auto *UnswitchedExitBB : UnswitchedExitBBs) { + DTUpdates.push_back({DT.Delete, ParentBB, UnswitchedExitBB}); + DTUpdates.push_back({DT.Insert, OldPH, UnswitchedExitBB}); } + for (auto SplitUnswitchedPair : SplitExitBBMap) { + auto *UnswitchedBB = SplitUnswitchedPair.second; + DTUpdates.push_back({DT.Delete, ParentBB, UnswitchedBB}); + DTUpdates.push_back({DT.Insert, OldPH, UnswitchedBB}); + } + DT.applyUpdates(DTUpdates); + assert(DT.verify(DominatorTree::VerificationLevel::Fast)); + + // We may have changed the nesting relationship for this loop so hoist it to + // its correct parent if needed. + hoistLoopToNewParent(L, *NewPH, DT, LI); - DT.verifyDomTree(); ++NumTrivial; ++NumSwitches; return true; @@ -662,8 +749,11 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, /// /// The return value indicates whether anything was unswitched (and therefore /// changed). +/// +/// If `SE` is not null, it will be updated based on the potential loop SCEVs +/// invalidated by this. static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, - LoopInfo &LI) { + LoopInfo &LI, ScalarEvolution *SE) { bool Changed = false; // If loop header has only one reachable successor we should keep looking for @@ -697,8 +787,8 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, if (isa<Constant>(SI->getCondition())) return Changed; - if (!unswitchTrivialSwitch(L, *SI, DT, LI)) - // Coludn't unswitch this one so we're done. + if (!unswitchTrivialSwitch(L, *SI, DT, LI, SE)) + // Couldn't unswitch this one so we're done. return Changed; // Mark that we managed to unswitch something. @@ -729,17 +819,19 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, // Found a trivial condition candidate: non-foldable conditional branch. If // we fail to unswitch this, we can't do anything else that is trivial. - if (!unswitchTrivialBranch(L, *BI, DT, LI)) + if (!unswitchTrivialBranch(L, *BI, DT, LI, SE)) return Changed; // Mark that we managed to unswitch something. Changed = true; - // We unswitched the branch. This should always leave us with an - // unconditional branch that we can follow now. + // If we only unswitched some of the conditions feeding the branch, we won't + // have collapsed it to a single successor. BI = cast<BranchInst>(CurrentBB->getTerminator()); - assert(!BI->isConditional() && - "Cannot form a conditional branch by unswitching1"); + if (BI->isConditional()) + return Changed; + + // Follow the newly unconditional branch into its successor. CurrentBB = BI->getSuccessor(0); // When continuing, if we exit the loop or reach a previous visited block, @@ -758,8 +850,12 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, /// /// This routine handles cloning all of the necessary loop blocks and exit /// blocks including rewriting their instructions and the relevant PHI nodes. -/// It skips loop and exit blocks that are not necessary based on the provided -/// set. 
It also correctly creates the unconditional branch in the cloned +/// Any loop blocks or exit blocks which are dominated by a different successor +/// than the one for this clone of the loop blocks can be trivially skipped. We +/// use the `DominatingSucc` map to determine whether a block satisfies that +/// property with a simple map lookup. +/// +/// It also correctly creates the unconditional branch in the cloned /// unswitched parent block to only point at the unswitched successor. /// /// This does not handle most of the necessary updates to `LoopInfo`. Only exit @@ -773,9 +869,10 @@ static BasicBlock *buildClonedLoopBlocks( Loop &L, BasicBlock *LoopPH, BasicBlock *SplitBB, ArrayRef<BasicBlock *> ExitBlocks, BasicBlock *ParentBB, BasicBlock *UnswitchedSuccBB, BasicBlock *ContinueSuccBB, - const SmallPtrSetImpl<BasicBlock *> &SkippedLoopAndExitBlocks, - ValueToValueMapTy &VMap, AssumptionCache &AC, DominatorTree &DT, - LoopInfo &LI) { + const SmallDenseMap<BasicBlock *, BasicBlock *, 16> &DominatingSucc, + ValueToValueMapTy &VMap, + SmallVectorImpl<DominatorTree::UpdateType> &DTUpdates, AssumptionCache &AC, + DominatorTree &DT, LoopInfo &LI) { SmallVector<BasicBlock *, 4> NewBlocks; NewBlocks.reserve(L.getNumBlocks() + ExitBlocks.size()); @@ -790,26 +887,29 @@ static BasicBlock *buildClonedLoopBlocks( NewBlocks.push_back(NewBB); VMap[OldBB] = NewBB; - // Add the block to the domtree. We'll move it to the correct position - // below. - DT.addNewBlock(NewBB, SplitBB); - return NewBB; }; + // We skip cloning blocks when they have a dominating succ that is not the + // succ we are cloning for. + auto SkipBlock = [&](BasicBlock *BB) { + auto It = DominatingSucc.find(BB); + return It != DominatingSucc.end() && It->second != UnswitchedSuccBB; + }; + // First, clone the preheader. auto *ClonedPH = CloneBlock(LoopPH); // Then clone all the loop blocks, skipping the ones that aren't necessary. for (auto *LoopBB : L.blocks()) - if (!SkippedLoopAndExitBlocks.count(LoopBB)) + if (!SkipBlock(LoopBB)) CloneBlock(LoopBB); // Split all the loop exit edges so that when we clone the exit blocks, if // any of the exit blocks are *also* a preheader for some other loop, we // don't create multiple predecessors entering the loop header. for (auto *ExitBB : ExitBlocks) { - if (SkippedLoopAndExitBlocks.count(ExitBB)) + if (SkipBlock(ExitBB)) continue; // When we are going to clone an exit, we don't need to clone all the @@ -832,17 +932,6 @@ static BasicBlock *buildClonedLoopBlocks( assert(ClonedExitBB->getTerminator()->getSuccessor(0) == MergeBB && "Cloned exit block has the wrong successor!"); - // Move the merge block's idom to be the split point as one exit is - // dominated by one header, and the other by another, so we know the split - // point dominates both. While the dominator tree isn't fully accurate, we - // want sub-trees within the original loop to be correctly reflect - // dominance within that original loop (at least) and that requires moving - // the merge block out of that subtree. - // FIXME: This is very brittle as we essentially have a partial contract on - // the dominator tree. We really need to instead update it and keep it - // valid or stop relying on it. - DT.changeImmediateDominator(MergeBB, SplitBB); - // Remap any cloned instructions and create a merge phi node for them. 
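The remap loop that follows pairs each instruction of the original exit block with its clone and funnels pre-existing users through a merge PHI. A hedged per-instruction sketch (variable names illustrative; the patch's loop drives this via llvm::zip_first over the two blocks):

    // Route all existing users of OrigI through a PHI in MergeBB that selects
    // between the original definition and its clone.
    PHINode *MergePN = PHINode::Create(OrigI->getType(), /*NumReservedValues*/ 2,
                                       ".us-phi", &*MergeBB->begin());
    OrigI->replaceAllUsesWith(MergePN);
    MergePN->addIncoming(OrigI, ExitBB);
    MergePN->addIncoming(ClonedI, ClonedExitBB);

Calling replaceAllUsesWith before addIncoming is deliberate: the PHI's own incoming operands are added afterwards, so they are not rewritten to point at the PHI itself.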
for (auto ZippedInsts : llvm::zip_first( llvm::make_range(ExitBB->begin(), std::prev(ExitBB->end())), @@ -882,28 +971,63 @@ static BasicBlock *buildClonedLoopBlocks( AC.registerAssumption(II); } - // Remove the cloned parent as a predecessor of the cloned continue successor - // if we did in fact clone it. - auto *ClonedParentBB = cast<BasicBlock>(VMap.lookup(ParentBB)); - if (auto *ClonedContinueSuccBB = - cast_or_null<BasicBlock>(VMap.lookup(ContinueSuccBB))) - ClonedContinueSuccBB->removePredecessor(ClonedParentBB, - /*DontDeleteUselessPHIs*/ true); - // Replace the cloned branch with an unconditional branch to the cloneed - // unswitched successor. - auto *ClonedSuccBB = cast<BasicBlock>(VMap.lookup(UnswitchedSuccBB)); - ClonedParentBB->getTerminator()->eraseFromParent(); - BranchInst::Create(ClonedSuccBB, ClonedParentBB); - // Update any PHI nodes in the cloned successors of the skipped blocks to not // have spurious incoming values. for (auto *LoopBB : L.blocks()) - if (SkippedLoopAndExitBlocks.count(LoopBB)) + if (SkipBlock(LoopBB)) for (auto *SuccBB : successors(LoopBB)) if (auto *ClonedSuccBB = cast_or_null<BasicBlock>(VMap.lookup(SuccBB))) for (PHINode &PN : ClonedSuccBB->phis()) PN.removeIncomingValue(LoopBB, /*DeletePHIIfEmpty*/ false); + // Remove the cloned parent as a predecessor of any successor we ended up + // cloning other than the unswitched one. + auto *ClonedParentBB = cast<BasicBlock>(VMap.lookup(ParentBB)); + for (auto *SuccBB : successors(ParentBB)) { + if (SuccBB == UnswitchedSuccBB) + continue; + + auto *ClonedSuccBB = cast_or_null<BasicBlock>(VMap.lookup(SuccBB)); + if (!ClonedSuccBB) + continue; + + ClonedSuccBB->removePredecessor(ClonedParentBB, + /*DontDeleteUselessPHIs*/ true); + } + + // Replace the cloned branch with an unconditional branch to the cloned + // unswitched successor. + auto *ClonedSuccBB = cast<BasicBlock>(VMap.lookup(UnswitchedSuccBB)); + ClonedParentBB->getTerminator()->eraseFromParent(); + BranchInst::Create(ClonedSuccBB, ClonedParentBB); + + // If there are duplicate entries in the PHI nodes because of multiple edges + // to the unswitched successor, we need to nuke all but one as we replaced it + // with a direct branch. + for (PHINode &PN : ClonedSuccBB->phis()) { + bool Found = false; + // Loop over the incoming operands backwards so we can easily delete as we + // go without invalidating the index. + for (int i = PN.getNumOperands() - 1; i >= 0; --i) { + if (PN.getIncomingBlock(i) != ClonedParentBB) + continue; + if (!Found) { + Found = true; + continue; + } + PN.removeIncomingValue(i, /*DeletePHIIfEmpty*/ false); + } + } + + // Record the domtree updates for the new blocks. + SmallPtrSet<BasicBlock *, 4> SuccSet; + for (auto *ClonedBB : NewBlocks) { + for (auto *SuccBB : successors(ClonedBB)) + if (SuccSet.insert(SuccBB).second) + DTUpdates.push_back({DominatorTree::Insert, ClonedBB, SuccBB}); + SuccSet.clear(); + } + return ClonedPH; } @@ -921,11 +1045,8 @@ static Loop *cloneLoopNest(Loop &OrigRootL, Loop *RootParentL, for (auto *BB : OrigL.blocks()) { auto *ClonedBB = cast<BasicBlock>(VMap.lookup(BB)); ClonedL.addBlockEntry(ClonedBB); - if (LI.getLoopFor(BB) == &OrigL) { - assert(!LI.getLoopFor(ClonedBB) && - "Should not have an existing loop for this block!"); + if (LI.getLoopFor(BB) == &OrigL) LI.changeLoopFor(ClonedBB, &ClonedL); - } } }; @@ -975,9 +1096,9 @@ static Loop *cloneLoopNest(Loop &OrigRootL, Loop *RootParentL, /// original loop, multiple cloned sibling loops may be created. 
All of them /// are returned so that the newly introduced loop nest roots can be /// identified. -static Loop *buildClonedLoops(Loop &OrigL, ArrayRef<BasicBlock *> ExitBlocks, - const ValueToValueMapTy &VMap, LoopInfo &LI, - SmallVectorImpl<Loop *> &NonChildClonedLoops) { +static void buildClonedLoops(Loop &OrigL, ArrayRef<BasicBlock *> ExitBlocks, + const ValueToValueMapTy &VMap, LoopInfo &LI, + SmallVectorImpl<Loop *> &NonChildClonedLoops) { Loop *ClonedL = nullptr; auto *OrigPH = OrigL.getLoopPreheader(); @@ -1070,6 +1191,7 @@ static Loop *buildClonedLoops(Loop &OrigL, ArrayRef<BasicBlock *> ExitBlocks, } else { LI.addTopLevelLoop(ClonedL); } + NonChildClonedLoops.push_back(ClonedL); ClonedL->reserveBlocks(BlocksInClonedLoop.size()); // We don't want to just add the cloned loop blocks based on how we @@ -1138,11 +1260,11 @@ static Loop *buildClonedLoops(Loop &OrigL, ArrayRef<BasicBlock *> ExitBlocks, // matter as we're just trying to build up the map from inside-out; we use // the map in a more stably ordered way below. auto OrderedClonedExitsInLoops = ClonedExitsInLoops; - std::sort(OrderedClonedExitsInLoops.begin(), OrderedClonedExitsInLoops.end(), - [&](BasicBlock *LHS, BasicBlock *RHS) { - return ExitLoopMap.lookup(LHS)->getLoopDepth() < - ExitLoopMap.lookup(RHS)->getLoopDepth(); - }); + llvm::sort(OrderedClonedExitsInLoops.begin(), OrderedClonedExitsInLoops.end(), + [&](BasicBlock *LHS, BasicBlock *RHS) { + return ExitLoopMap.lookup(LHS)->getLoopDepth() < + ExitLoopMap.lookup(RHS)->getLoopDepth(); + }); // Populate the existing ExitLoopMap with everything reachable from each // exit, starting from the inner most exit. @@ -1222,60 +1344,69 @@ static Loop *buildClonedLoops(Loop &OrigL, ArrayRef<BasicBlock *> ExitBlocks, NonChildClonedLoops.push_back(cloneLoopNest( *ChildL, ExitLoopMap.lookup(ClonedChildHeader), VMap, LI)); } +} + +static void +deleteDeadClonedBlocks(Loop &L, ArrayRef<BasicBlock *> ExitBlocks, + ArrayRef<std::unique_ptr<ValueToValueMapTy>> VMaps, + DominatorTree &DT) { + // Find all the dead clones, and remove them from their successors. + SmallVector<BasicBlock *, 16> DeadBlocks; + for (BasicBlock *BB : llvm::concat<BasicBlock *const>(L.blocks(), ExitBlocks)) + for (auto &VMap : VMaps) + if (BasicBlock *ClonedBB = cast_or_null<BasicBlock>(VMap->lookup(BB))) + if (!DT.isReachableFromEntry(ClonedBB)) { + for (BasicBlock *SuccBB : successors(ClonedBB)) + SuccBB->removePredecessor(ClonedBB); + DeadBlocks.push_back(ClonedBB); + } - // Return the main cloned loop if any. - return ClonedL; + // Drop any remaining references to break cycles. + for (BasicBlock *BB : DeadBlocks) + BB->dropAllReferences(); + // Erase them from the IR. + for (BasicBlock *BB : DeadBlocks) + BB->eraseFromParent(); } -static void deleteDeadBlocksFromLoop(Loop &L, BasicBlock *DeadSubtreeRoot, - SmallVectorImpl<BasicBlock *> &ExitBlocks, - DominatorTree &DT, LoopInfo &LI) { - // Walk the dominator tree to build up the set of blocks we will delete here. - // The order is designed to allow us to always delete bottom-up and avoid any - // dangling uses. - SmallSetVector<BasicBlock *, 16> DeadBlocks; - DeadBlocks.insert(DeadSubtreeRoot); - for (int i = 0; i < (int)DeadBlocks.size(); ++i) - for (DomTreeNode *ChildN : *DT[DeadBlocks[i]]) { - // FIXME: This assert should pass and that means we don't change nearly - // as much below! Consider rewriting all of this to avoid deleting - // blocks. They are always cloned before being deleted, and so instead - // could just be moved. 
- // FIXME: This in turn means that we might actually be more able to - // update the domtree. - assert((L.contains(ChildN->getBlock()) || - llvm::find(ExitBlocks, ChildN->getBlock()) != ExitBlocks.end()) && - "Should never reach beyond the loop and exits when deleting!"); - DeadBlocks.insert(ChildN->getBlock()); +static void +deleteDeadBlocksFromLoop(Loop &L, + SmallVectorImpl<BasicBlock *> &ExitBlocks, + DominatorTree &DT, LoopInfo &LI) { + // Find all the dead blocks, and remove them from their successors. + SmallVector<BasicBlock *, 16> DeadBlocks; + for (BasicBlock *BB : llvm::concat<BasicBlock *const>(L.blocks(), ExitBlocks)) + if (!DT.isReachableFromEntry(BB)) { + for (BasicBlock *SuccBB : successors(BB)) + SuccBB->removePredecessor(BB); + DeadBlocks.push_back(BB); } + SmallPtrSet<BasicBlock *, 16> DeadBlockSet(DeadBlocks.begin(), + DeadBlocks.end()); + // Filter out the dead blocks from the exit blocks list so that it can be // used in the caller. llvm::erase_if(ExitBlocks, - [&](BasicBlock *BB) { return DeadBlocks.count(BB); }); - - // Remove these blocks from their successors. - for (auto *BB : DeadBlocks) - for (BasicBlock *SuccBB : successors(BB)) - SuccBB->removePredecessor(BB, /*DontDeleteUselessPHIs*/ true); + [&](BasicBlock *BB) { return DeadBlockSet.count(BB); }); // Walk from this loop up through its parents removing all of the dead blocks. for (Loop *ParentL = &L; ParentL; ParentL = ParentL->getParentLoop()) { for (auto *BB : DeadBlocks) ParentL->getBlocksSet().erase(BB); llvm::erase_if(ParentL->getBlocksVector(), - [&](BasicBlock *BB) { return DeadBlocks.count(BB); }); + [&](BasicBlock *BB) { return DeadBlockSet.count(BB); }); } // Now delete the dead child loops. This raw delete will clear them // recursively. llvm::erase_if(L.getSubLoopsVector(), [&](Loop *ChildL) { - if (!DeadBlocks.count(ChildL->getHeader())) + if (!DeadBlockSet.count(ChildL->getHeader())) return false; assert(llvm::all_of(ChildL->blocks(), [&](BasicBlock *ChildBB) { - return DeadBlocks.count(ChildBB); + return DeadBlockSet.count(ChildBB); }) && "If the child loop header is dead all blocks in the child loop must " "be dead as well!"); @@ -1283,19 +1414,20 @@ static void deleteDeadBlocksFromLoop(Loop &L, BasicBlock *DeadSubtreeRoot, return true; }); - // Remove the mappings for the dead blocks. - for (auto *BB : DeadBlocks) + // Remove the loop mappings for the dead blocks and drop all the references + // from these blocks to others to handle cyclic references as we start + // deleting the blocks themselves. + for (auto *BB : DeadBlocks) { + // Check that the dominator tree has already been updated. + assert(!DT.getNode(BB) && "Should already have cleared domtree!"); LI.changeLoopFor(BB, nullptr); - - // Drop all the references from these blocks to others to handle cyclic - // references as we start deleting the blocks themselves. - for (auto *BB : DeadBlocks) BB->dropAllReferences(); + } - for (auto *BB : llvm::reverse(DeadBlocks)) { - DT.eraseNode(BB); + // Actually delete the blocks now that they've been fully unhooked from the + // IR. + for (auto *BB : DeadBlocks) BB->eraseFromParent(); - } } /// Recompute the set of blocks in a loop after unswitching. @@ -1343,14 +1475,15 @@ static SmallPtrSet<const BasicBlock *, 16> recomputeLoopBlockSet(Loop &L, if (LoopBlockSet.empty()) return LoopBlockSet; - // Add the loop header to the set. - LoopBlockSet.insert(Header); - // We found backedges, recurse through them to identify the loop blocks. 
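Both deletion helpers above share one idiom: once DT.applyUpdates has run, a block is dead exactly when the dominator tree no longer reaches it, so it is unhooked from its successors' PHIs, has its references dropped to break use cycles, and only then erased. A compact sketch of that idiom (assuming the domtree is already up to date):

    SmallVector<BasicBlock *, 16> Dead;
    for (BasicBlock *BB : CandidateBlocks)    // e.g. loop blocks + exit blocks
      if (!DT.isReachableFromEntry(BB)) {
        for (BasicBlock *SuccBB : successors(BB))
          SuccBB->removePredecessor(BB);      // fix up PHIs first
        Dead.push_back(BB);
      }
    for (BasicBlock *BB : Dead)
      BB->dropAllReferences();                // break cyclic uses among the dead
    for (BasicBlock *BB : Dead)
      BB->eraseFromParent();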
while (!Worklist.empty()) { BasicBlock *BB = Worklist.pop_back_val(); assert(LoopBlockSet.count(BB) && "Didn't put block into the loop set!"); + // No need to walk past the header. + if (BB == Header) + continue; + // Because we know the inner loop structure remains valid we can use the // loop structure to jump immediately across the entire nested loop. // Further, because it is in loop simplified form, we can directly jump @@ -1371,9 +1504,10 @@ static SmallPtrSet<const BasicBlock *, 16> recomputeLoopBlockSet(Loop &L, continue; // Insert all of the blocks (other than those already present) into - // the loop set. The only block we expect to already be in the set is - // the one we used to find this loop as we immediately handle the - // others the first time we encounter the loop. + // the loop set. We expect at least the block that led us to find the + // inner loop to be in the block set, but we may also have other loop + // blocks if they were already enqueued as predecessors of some other + // outer loop block. for (auto *InnerBB : InnerL->blocks()) { if (InnerBB == BB) { assert(LoopBlockSet.count(InnerBB) && @@ -1381,9 +1515,7 @@ static SmallPtrSet<const BasicBlock *, 16> recomputeLoopBlockSet(Loop &L, continue; } - bool Inserted = LoopBlockSet.insert(InnerBB).second; - (void)Inserted; - assert(Inserted && "Should only insert an inner loop once!"); + LoopBlockSet.insert(InnerBB); } // Add the preheader to the worklist so we will continue past the @@ -1399,6 +1531,8 @@ static SmallPtrSet<const BasicBlock *, 16> recomputeLoopBlockSet(Loop &L, Worklist.push_back(Pred); } + assert(LoopBlockSet.count(Header) && "Cannot fail to add the header!"); + // We've found all the blocks participating in the loop, return our completed // set. return LoopBlockSet; @@ -1646,32 +1780,58 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) { } while (!DomWorklist.empty()); } -/// Take an invariant branch that has been determined to be safe and worthwhile -/// to unswitch despite being non-trivial to do so and perform the unswitch. -/// -/// This directly updates the CFG to hoist the predicate out of the loop, and -/// clone the necessary parts of the loop to maintain behavior. -/// -/// It also updates both dominator tree and loopinfo based on the unswitching. -/// -/// Once unswitching has been performed it runs the provided callback to report -/// the new loops and no-longer valid loops to the caller. -static bool unswitchInvariantBranch( - Loop &L, BranchInst &BI, DominatorTree &DT, LoopInfo &LI, - AssumptionCache &AC, - function_ref<void(bool, ArrayRef<Loop *>)> NonTrivialUnswitchCB) { - assert(BI.isConditional() && "Can only unswitch a conditional branch!"); - assert(L.isLoopInvariant(BI.getCondition()) && - "Can only unswitch an invariant branch condition!"); +static bool unswitchNontrivialInvariants( + Loop &L, TerminatorInst &TI, ArrayRef<Value *> Invariants, + DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, + function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, + ScalarEvolution *SE) { + auto *ParentBB = TI.getParent(); + BranchInst *BI = dyn_cast<BranchInst>(&TI); + SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI); + + // We can only unswitch switches, conditional branches with an invariant + // condition, or combining invariant conditions with an instruction. 
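collectHomogenousInstGraphLoopInvariants, used by the assertions that follow and throughout this patch, walks a graph of same-opcode `and`/`or` instructions and returns its loop-invariant leaves. A simplified reconstruction of that walk, offered as a sketch only (the real helper lives earlier in this file and also handles a fully invariant root and deduplication):

    static void collectInvariantLeaves(Loop &L, Instruction &Root,
                                       SmallVectorImpl<Value *> &Invariants) {
      unsigned Opc = Root.getOpcode();          // Instruction::And or ::Or
      SmallVector<Instruction *, 4> Worklist = {&Root};
      while (!Worklist.empty()) {
        Instruction *I = Worklist.pop_back_val();
        for (Value *Op : I->operands()) {
          if (L.isLoopInvariant(Op)) {
            Invariants.push_back(Op);           // an unswitchable input
            continue;
          }
          auto *OpI = dyn_cast<Instruction>(Op);
          if (OpI && OpI->getOpcode() == Opc)
            Worklist.push_back(OpI);            // keep walking the homogeneous graph
        }
      }
    }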
+ assert((SI || BI->isConditional()) && + "Can only unswitch switches and conditional branches!"); + bool FullUnswitch = SI || BI->getCondition() == Invariants[0]; + if (FullUnswitch) + assert(Invariants.size() == 1 && + "Cannot have other invariants with full unswitching!"); + else + assert(isa<Instruction>(BI->getCondition()) && + "Partial unswitching requires an instruction as the condition!"); + + // Constant and BBs tracking the cloned and continuing successor. When we are + // unswitching the entire condition, this can just be trivially chosen to + // unswitch towards `true`. However, when we are unswitching a set of + // invariants combined with `and` or `or`, the combining operation determines + // the best direction to unswitch: we want to unswitch the direction that will + // collapse the branch. + bool Direction = true; + int ClonedSucc = 0; + if (!FullUnswitch) { + if (cast<Instruction>(BI->getCondition())->getOpcode() != Instruction::Or) { + assert(cast<Instruction>(BI->getCondition())->getOpcode() == + Instruction::And && + "Only `or` and `and` instructions can combine invariants being " + "unswitched."); + Direction = false; + ClonedSucc = 1; + } + } - // Constant and BBs tracking the cloned and continuing successor. - const int ClonedSucc = 0; - auto *ParentBB = BI.getParent(); - auto *UnswitchedSuccBB = BI.getSuccessor(ClonedSucc); - auto *ContinueSuccBB = BI.getSuccessor(1 - ClonedSucc); + BasicBlock *RetainedSuccBB = + BI ? BI->getSuccessor(1 - ClonedSucc) : SI->getDefaultDest(); + SmallSetVector<BasicBlock *, 4> UnswitchedSuccBBs; + if (BI) + UnswitchedSuccBBs.insert(BI->getSuccessor(ClonedSucc)); + else + for (auto Case : SI->cases()) + if (Case.getCaseSuccessor() != RetainedSuccBB) + UnswitchedSuccBBs.insert(Case.getCaseSuccessor()); - assert(UnswitchedSuccBB != ContinueSuccBB && - "Should not unswitch a branch that always goes to the same place!"); + assert(!UnswitchedSuccBBs.count(RetainedSuccBB) && + "Should not unswitch the same successor we are retaining!"); // The branch should be in this exact loop. Any inner loop's invariant branch // should be handled by unswitching that inner loop. The caller of this @@ -1690,9 +1850,6 @@ static bool unswitchInvariantBranch( if (isa<CleanupPadInst>(ExitBB->getFirstNonPHI())) return false; - SmallPtrSet<BasicBlock *, 4> ExitBlockSet(ExitBlocks.begin(), - ExitBlocks.end()); - // Compute the parent loop now before we start hacking on things. Loop *ParentL = L.getParentLoop(); @@ -1711,27 +1868,31 @@ static bool unswitchInvariantBranch( OuterExitL = NewOuterExitL; } - // If the edge we *aren't* cloning in the unswitch (the continuing edge) - // dominates its target, we can skip cloning the dominated region of the loop - // and its exits. We compute this as a set of nodes to be skipped. - SmallPtrSet<BasicBlock *, 4> SkippedLoopAndExitBlocks; - if (ContinueSuccBB->getUniquePredecessor() || - llvm::all_of(predecessors(ContinueSuccBB), [&](BasicBlock *PredBB) { - return PredBB == ParentBB || DT.dominates(ContinueSuccBB, PredBB); - })) { - visitDomSubTree(DT, ContinueSuccBB, [&](BasicBlock *BB) { - SkippedLoopAndExitBlocks.insert(BB); - return true; - }); + // At this point, we're definitely going to unswitch something so invalidate + // any cached information in ScalarEvolution for the outermost loop + // containing an exit block and all nested loops.
+ if (SE) { + if (OuterExitL) + SE->forgetLoop(OuterExitL); + else + SE->forgetTopmostLoop(&L); } - // Similarly, if the edge we *are* cloning in the unswitch (the unswitched - // edge) dominates its target, we will end up with dead nodes in the original - // loop and its exits that will need to be deleted. Here, we just retain that - // the property holds and will compute the deleted set later. - bool DeleteUnswitchedSucc = - UnswitchedSuccBB->getUniquePredecessor() || - llvm::all_of(predecessors(UnswitchedSuccBB), [&](BasicBlock *PredBB) { - return PredBB == ParentBB || DT.dominates(UnswitchedSuccBB, PredBB); + + // If the edge from this terminator to a successor dominates that successor, + // store a map from each block in its dominator subtree to it. This lets us + // tell when cloning for a particular successor if a block is dominated by + // some *other* successor with a single data structure. We use this to + // significantly reduce cloning. + SmallDenseMap<BasicBlock *, BasicBlock *, 16> DominatingSucc; + for (auto *SuccBB : llvm::concat<BasicBlock *const>( + makeArrayRef(RetainedSuccBB), UnswitchedSuccBBs)) + if (SuccBB->getUniquePredecessor() || + llvm::all_of(predecessors(SuccBB), [&](BasicBlock *PredBB) { + return PredBB == ParentBB || DT.dominates(SuccBB, PredBB); + })) + visitDomSubTree(DT, SuccBB, [&](BasicBlock *BB) { + DominatingSucc[BB] = SuccBB; + return true; }); // Split the preheader, so that we know that there is a safe place to insert @@ -1742,52 +1903,162 @@ static bool unswitchInvariantBranch( BasicBlock *SplitBB = L.getLoopPreheader(); BasicBlock *LoopPH = SplitEdge(SplitBB, L.getHeader(), &DT, &LI); - // Keep a mapping for the cloned values. - ValueToValueMapTy VMap; + // Keep track of the dominator tree updates needed. + SmallVector<DominatorTree::UpdateType, 4> DTUpdates; + + // Clone the loop for each unswitched successor. + SmallVector<std::unique_ptr<ValueToValueMapTy>, 4> VMaps; + VMaps.reserve(UnswitchedSuccBBs.size()); + SmallDenseMap<BasicBlock *, BasicBlock *, 4> ClonedPHs; + for (auto *SuccBB : UnswitchedSuccBBs) { + VMaps.emplace_back(new ValueToValueMapTy()); + ClonedPHs[SuccBB] = buildClonedLoopBlocks( + L, LoopPH, SplitBB, ExitBlocks, ParentBB, SuccBB, RetainedSuccBB, + DominatingSucc, *VMaps.back(), DTUpdates, AC, DT, LI); + } + + // The stitching of the branched code back together depends on whether we're + // doing full unswitching or not with the exception that we always want to + // nuke the initial terminator placed in the split block. + SplitBB->getTerminator()->eraseFromParent(); + if (FullUnswitch) { + // First we need to unhook the successor relationship as we'll be replacing + // the terminator with a direct branch. This is much simpler for branches + // than switches so we handle those first. + if (BI) { + // Remove the parent as a predecessor of the unswitched successor. + assert(UnswitchedSuccBBs.size() == 1 && + "Only one possible unswitched block for a branch!"); + BasicBlock *UnswitchedSuccBB = *UnswitchedSuccBBs.begin(); + UnswitchedSuccBB->removePredecessor(ParentBB, + /*DontDeleteUselessPHIs*/ true); + DTUpdates.push_back({DominatorTree::Delete, ParentBB, UnswitchedSuccBB}); + } else { + // Note that we actually want to remove the parent block as a predecessor + // of *every* case successor. 
The case successor is either unswitched, + // completely eliminating an edge from the parent to that successor, or it + // is a duplicate edge to the retained successor as the retained successor + // is always the default successor and as we'll replace this with a direct + // branch we no longer need the duplicate entries in the PHI nodes. + assert(SI->getDefaultDest() == RetainedSuccBB && + "Not retaining default successor!"); + for (auto &Case : SI->cases()) + Case.getCaseSuccessor()->removePredecessor( + ParentBB, + /*DontDeleteUselessPHIs*/ true); + + // We need to use the set to populate domtree updates as even when there + // are multiple cases pointing at the same successor we only want to + // remove and insert one edge in the domtree. + for (BasicBlock *SuccBB : UnswitchedSuccBBs) + DTUpdates.push_back({DominatorTree::Delete, ParentBB, SuccBB}); + } + + // Now that we've unhooked the successor relationship, splice the terminator + // from the original loop to the split. + SplitBB->getInstList().splice(SplitBB->end(), ParentBB->getInstList(), TI); + + // Now wire up the terminator to the preheaders. + if (BI) { + BasicBlock *ClonedPH = ClonedPHs.begin()->second; + BI->setSuccessor(ClonedSucc, ClonedPH); + BI->setSuccessor(1 - ClonedSucc, LoopPH); + DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); + } else { + assert(SI && "Must either be a branch or switch!"); + + // Walk the cases and directly update their successors. + SI->setDefaultDest(LoopPH); + for (auto &Case : SI->cases()) + if (Case.getCaseSuccessor() == RetainedSuccBB) + Case.setSuccessor(LoopPH); + else + Case.setSuccessor(ClonedPHs.find(Case.getCaseSuccessor())->second); + + // We need to use the set to populate domtree updates as even when there + // are multiple cases pointing at the same successor we only want to + // remove and insert one edge in the domtree. + for (BasicBlock *SuccBB : UnswitchedSuccBBs) + DTUpdates.push_back( + {DominatorTree::Insert, SplitBB, ClonedPHs.find(SuccBB)->second}); + } + + // Create a new unconditional branch to the continuing block (as opposed to + // the one cloned). + BranchInst::Create(RetainedSuccBB, ParentBB); + } else { + assert(BI && "Only branches have partial unswitching."); + assert(UnswitchedSuccBBs.size() == 1 && + "Only one possible unswitched block for a branch!"); + BasicBlock *ClonedPH = ClonedPHs.begin()->second; + // When doing a partial unswitch, we have to do a bit more work to build up + // the branch in the split block. + buildPartialUnswitchConditionalBranch(*SplitBB, Invariants, Direction, + *ClonedPH, *LoopPH); + DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); + } - // Build the cloned blocks from the loop. - auto *ClonedPH = buildClonedLoopBlocks( - L, LoopPH, SplitBB, ExitBlocks, ParentBB, UnswitchedSuccBB, - ContinueSuccBB, SkippedLoopAndExitBlocks, VMap, AC, DT, LI); + // Apply the updates accumulated above to get an up-to-date dominator tree. + DT.applyUpdates(DTUpdates); + + // Now that we have an accurate dominator tree, first delete the dead cloned + // blocks so that we can accurately build any cloned loops. It is important to + // not delete the blocks from the original loop yet because we still want to + // reference the original loop to understand the cloned loop's structure. + deleteDeadClonedBlocks(L, ExitBlocks, VMaps, DT); // Build the cloned loop structure itself. This may be substantially // different from the original structure due to the simplified CFG. 
This also // handles inserting all the cloned blocks into the correct loops. SmallVector<Loop *, 4> NonChildClonedLoops; - Loop *ClonedL = - buildClonedLoops(L, ExitBlocks, VMap, LI, NonChildClonedLoops); - - // Remove the parent as a predecessor of the unswitched successor. - UnswitchedSuccBB->removePredecessor(ParentBB, /*DontDeleteUselessPHIs*/ true); - - // Now splice the branch from the original loop and use it to select between - // the two loops. - SplitBB->getTerminator()->eraseFromParent(); - SplitBB->getInstList().splice(SplitBB->end(), ParentBB->getInstList(), BI); - BI.setSuccessor(ClonedSucc, ClonedPH); - BI.setSuccessor(1 - ClonedSucc, LoopPH); - - // Create a new unconditional branch to the continuing block (as opposed to - // the one cloned). - BranchInst::Create(ContinueSuccBB, ParentBB); - - // Delete anything that was made dead in the original loop due to - // unswitching. - if (DeleteUnswitchedSucc) - deleteDeadBlocksFromLoop(L, UnswitchedSuccBB, ExitBlocks, DT, LI); + for (std::unique_ptr<ValueToValueMapTy> &VMap : VMaps) + buildClonedLoops(L, ExitBlocks, *VMap, LI, NonChildClonedLoops); + // Now that our cloned loops have been built, we can update the original loop. + // First we delete the dead blocks from it and then we rebuild the loop + // structure taking these deletions into account. + deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI); SmallVector<Loop *, 4> HoistedLoops; bool IsStillLoop = rebuildLoopAfterUnswitch(L, ExitBlocks, LI, HoistedLoops); - // This will have completely invalidated the dominator tree. We can't easily - // bound how much is invalid because in some cases we will refine the - // predecessor set of exit blocks of the loop which can move large unrelated - // regions of code into a new subtree. - // - // FIXME: Eventually, we should use an incremental update utility that - // leverages the existing information in the dominator tree (and potentially - // the nature of the change) to more efficiently update things. - DT.recalculate(*SplitBB->getParent()); + // This transformation has a high risk of corrupting the dominator tree, and + // the below steps to rebuild loop structures will result in hard to debug + // errors in that case so verify that the dominator tree is sane first. + // FIXME: Remove this when the bugs stop showing up and rely on existing + // verification steps. + assert(DT.verify(DominatorTree::VerificationLevel::Fast)); + + if (BI) { + // If we unswitched a branch which collapses the condition to a known + // constant we want to replace all the uses of the invariants within both + // the original and cloned blocks. We do this here so that we can use the + // now updated dominator tree to identify which side the users are on. + assert(UnswitchedSuccBBs.size() == 1 && + "Only one possible unswitched block for a branch!"); + BasicBlock *ClonedPH = ClonedPHs.begin()->second; + ConstantInt *UnswitchedReplacement = + Direction ? ConstantInt::getTrue(BI->getContext()) + : ConstantInt::getFalse(BI->getContext()); + ConstantInt *ContinueReplacement = + Direction ? ConstantInt::getFalse(BI->getContext()) + : ConstantInt::getTrue(BI->getContext()); + for (Value *Invariant : Invariants) + for (auto UI = Invariant->use_begin(), UE = Invariant->use_end(); + UI != UE;) { + // Grab the use and walk past it so we can clobber it in the use list. 
+ Use *U = &*UI++; + Instruction *UserI = dyn_cast<Instruction>(U->getUser()); + if (!UserI) + continue; + + // Replace it with the 'continue' side if in the main loop body, and the + // unswitched side if in the cloned blocks. + if (DT.dominates(LoopPH, UserI->getParent())) + U->set(ContinueReplacement); + else if (DT.dominates(ClonedPH, UserI->getParent())) + U->set(UnswitchedReplacement); + } + } // We can change which blocks are exit blocks of all the cloned sibling // loops, the current loop, and any parent loops which shared exit blocks @@ -1801,57 +2072,50 @@ static bool unswitchInvariantBranch( // also need to cover any intervening loops. We add all of these loops to // a list and sort them by loop depth to achieve this without updating // unnecessary loops. - auto UpdateLCSSA = [&](Loop &UpdateL) { + auto UpdateLoop = [&](Loop &UpdateL) { #ifndef NDEBUG - for (Loop *ChildL : UpdateL) + UpdateL.verifyLoop(); + for (Loop *ChildL : UpdateL) { + ChildL->verifyLoop(); assert(ChildL->isRecursivelyLCSSAForm(DT, LI) && "Perturbed a child loop's LCSSA form!"); + } #endif + // First build LCSSA for this loop so that we can preserve it when + // forming dedicated exits. We don't want to perturb some other loop's + // LCSSA while doing that CFG edit. formLCSSA(UpdateL, DT, &LI, nullptr); + + // For loops reached by this loop's original exit blocks we may have + // introduced new, non-dedicated exits. At least try to re-form dedicated + // exits for these loops. This may fail if they couldn't have dedicated + // exits to start with. + formDedicatedExitBlocks(&UpdateL, &DT, &LI, /*PreserveLCSSA*/ true); + }; // For non-child cloned loops and hoisted loops, we just need to update LCSSA // and we can do it in any order as they don't nest relative to each other. - for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops)) - UpdateLCSSA(*UpdatedL); + // + // Also check if any of the loops we have updated have become top-level loops + // as that will necessitate widening the outer loop scope. + for (Loop *UpdatedL : + llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops)) { + UpdateLoop(*UpdatedL); + if (!UpdatedL->getParentLoop()) + OuterExitL = nullptr; + } + if (IsStillLoop) { + UpdateLoop(L); + if (!L.getParentLoop()) + OuterExitL = nullptr; + } // If the original loop had exit blocks, walk up through the outer most loop // of those exit blocks to update LCSSA and form updated dedicated exits. - if (OuterExitL != &L) { - SmallVector<Loop *, 4> OuterLoops; - // We start with the cloned loop and the current loop if they are loops and - // move toward OuterExitL. Also, if either the cloned loop or the current - // loop have become top level loops we need to walk all the way out. - if (ClonedL) { - OuterLoops.push_back(ClonedL); - if (!ClonedL->getParentLoop()) - OuterExitL = nullptr; - } - if (IsStillLoop) { - OuterLoops.push_back(&L); - if (!L.getParentLoop()) - OuterExitL = nullptr; - } - // Grab all of the enclosing loops now. + if (OuterExitL != &L) for (Loop *OuterL = ParentL; OuterL != OuterExitL; OuterL = OuterL->getParentLoop()) - OuterLoops.push_back(OuterL); - - // Finally, update our list of outer loops. This is nicely ordered to work - // inside-out. - for (Loop *OuterL : OuterLoops) { - // First build LCSSA for this loop so that we can preserve it when - // forming dedicated exits. We don't want to perturb some other loop's - // LCSSA while doing that CFG edit.
- UpdateLCSSA(*OuterL); - - // For loops reached by this loop's original exit blocks we may - // introduced new, non-dedicated exits. At least try to re-form dedicated - // exits for these loops. This may fail if they couldn't have dedicated - // exits to start with. - formDedicatedExitBlocks(OuterL, &DT, &LI, /*PreserveLCSSA*/ true); - } - } + UpdateLoop(*OuterL); #ifndef NDEBUG // Verify the entire loop structure to catch any incorrect updates before we @@ -1866,7 +2130,7 @@ static bool unswitchInvariantBranch( for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops)) if (UpdatedL->getParentLoop() == ParentL) SibLoops.push_back(UpdatedL); - NonTrivialUnswitchCB(IsStillLoop, SibLoops); + UnswitchCB(IsStillLoop, SibLoops); ++NumBranches; return true; @@ -1905,50 +2169,69 @@ computeDomSubtreeCost(DomTreeNode &N, return Cost; } -/// Unswitch control flow predicated on loop invariant conditions. -/// -/// This first hoists all branches or switches which are trivial (IE, do not -/// require duplicating any part of the loop) out of the loop body. It then -/// looks at other loop invariant control flows and tries to unswitch those as -/// well by cloning the loop if the result is small enough. static bool -unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - TargetTransformInfo &TTI, bool NonTrivial, - function_ref<void(bool, ArrayRef<Loop *>)> NonTrivialUnswitchCB) { - assert(L.isRecursivelyLCSSAForm(DT, LI) && - "Loops must be in LCSSA form before unswitching."); - bool Changed = false; +unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, TargetTransformInfo &TTI, + function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, + ScalarEvolution *SE) { + // Collect all invariant conditions within this loop (as opposed to an inner + // loop which would be handled when visiting that inner loop). + SmallVector<std::pair<TerminatorInst *, TinyPtrVector<Value *>>, 4> + UnswitchCandidates; + for (auto *BB : L.blocks()) { + if (LI.getLoopFor(BB) != &L) + continue; - // Must be in loop simplified form: we need a preheader and dedicated exits. - if (!L.isLoopSimplifyForm()) - return false; + if (auto *SI = dyn_cast<SwitchInst>(BB->getTerminator())) { + // We can only consider fully loop-invariant switch conditions as we need + // to completely eliminate the switch after unswitching. + if (!isa<Constant>(SI->getCondition()) && + L.isLoopInvariant(SI->getCondition())) + UnswitchCandidates.push_back({SI, {SI->getCondition()}}); + continue; + } - // Try trivial unswitch first before loop over other basic blocks in the loop. - Changed |= unswitchAllTrivialConditions(L, DT, LI); + auto *BI = dyn_cast<BranchInst>(BB->getTerminator()); + if (!BI || !BI->isConditional() || isa<Constant>(BI->getCondition()) || + BI->getSuccessor(0) == BI->getSuccessor(1)) + continue; - // If we're not doing non-trivial unswitching, we're done. We both accept - // a parameter but also check a local flag that can be used for testing - // a debugging. - if (!NonTrivial && !EnableNonTrivialUnswitch) - return Changed; - - // Collect all remaining invariant branch conditions within this loop (as - // opposed to an inner loop which would be handled when visiting that inner - // loop). 
- SmallVector<TerminatorInst *, 4> UnswitchCandidates; - for (auto *BB : L.blocks()) - if (LI.getLoopFor(BB) == &L) - if (auto *BI = dyn_cast<BranchInst>(BB->getTerminator())) - if (BI->isConditional() && L.isLoopInvariant(BI->getCondition()) && - BI->getSuccessor(0) != BI->getSuccessor(1)) - UnswitchCandidates.push_back(BI); + if (L.isLoopInvariant(BI->getCondition())) { + UnswitchCandidates.push_back({BI, {BI->getCondition()}}); + continue; + } + + Instruction &CondI = *cast<Instruction>(BI->getCondition()); + if (CondI.getOpcode() != Instruction::And && + CondI.getOpcode() != Instruction::Or) + continue; + + TinyPtrVector<Value *> Invariants = + collectHomogenousInstGraphLoopInvariants(L, CondI, LI); + if (Invariants.empty()) + continue; + + UnswitchCandidates.push_back({BI, std::move(Invariants)}); + } // If we didn't find any candidates, we're done. if (UnswitchCandidates.empty()) - return Changed; + return false; + + // Check if there are irreducible CFG cycles in this loop. If so, we cannot + // easily unswitch non-trivial edges out of the loop. Doing so might turn the + // irreducible control flow into reducible control flow and introduce new + // loops "out of thin air". If we ever discover important use cases for doing + // this, we can add support to loop unswitch, but it is a lot of complexity + // for what seems little or no real world benefit. + LoopBlocksRPO RPOT(&L); + RPOT.perform(&LI); + if (containsIrreducibleCFG<const BasicBlock *>(RPOT, LI)) + return false; - DEBUG(dbgs() << "Considering " << UnswitchCandidates.size() - << " non-trivial loop invariant conditions for unswitching.\n"); + LLVM_DEBUG( + dbgs() << "Considering " << UnswitchCandidates.size() + << " non-trivial loop invariant conditions for unswitching.\n"); // Given that unswitching these terminators will require duplicating parts of // the loop, so we need to be able to model that cost. Compute the ephemeral @@ -1972,10 +2255,10 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, continue; if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB)) - return Changed; + return false; if (auto CS = CallSite(&I)) if (CS.isConvergent() || CS.cannotDuplicate()) - return Changed; + return false; Cost += TTI.getUserCost(&I); } @@ -1984,7 +2267,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, assert(LoopCost >= 0 && "Must not have negative loop costs!"); BBCostMap[BB] = Cost; } - DEBUG(dbgs() << " Total loop cost: " << LoopCost << "\n"); + LLVM_DEBUG(dbgs() << " Total loop cost: " << LoopCost << "\n"); // Now we find the best candidate by searching for the one with the following // properties in order: @@ -2003,8 +2286,8 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, SmallDenseMap<DomTreeNode *, int, 4> DTCostMap; // Given a terminator which might be unswitched, computes the non-duplicated // cost for that terminator. - auto ComputeUnswitchedCost = [&](TerminatorInst *TI) { - BasicBlock &BB = *TI->getParent(); + auto ComputeUnswitchedCost = [&](TerminatorInst &TI, bool FullUnswitch) { + BasicBlock &BB = *TI.getParent(); SmallPtrSet<BasicBlock *, 4> Visited; int Cost = LoopCost; @@ -2013,6 +2296,26 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, if (!Visited.insert(SuccBB).second) continue; + // If this is a partial unswitch candidate, then it must be a conditional + // branch with a condition of either `or` or `and`. 
In that case, one of + // the successors is necessarily duplicated, so don't even try to remove + // its cost. + if (!FullUnswitch) { + auto &BI = cast<BranchInst>(TI); + if (cast<Instruction>(BI.getCondition())->getOpcode() == + Instruction::And) { + if (SuccBB == BI.getSuccessor(1)) + continue; + } else { + assert(cast<Instruction>(BI.getCondition())->getOpcode() == + Instruction::Or && + "Only `and` and `or` conditions can result in a partial " + "unswitch!"); + if (SuccBB == BI.getSuccessor(0)) + continue; + } + } + // This successor's domtree will not need to be duplicated after // unswitching if the edge to the successor dominates it (and thus the // entire tree). This essentially means there is no other path into this @@ -2036,27 +2339,95 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, }; TerminatorInst *BestUnswitchTI = nullptr; int BestUnswitchCost; - for (TerminatorInst *CandidateTI : UnswitchCandidates) { - int CandidateCost = ComputeUnswitchedCost(CandidateTI); - DEBUG(dbgs() << " Computed cost of " << CandidateCost - << " for unswitch candidate: " << *CandidateTI << "\n"); + ArrayRef<Value *> BestUnswitchInvariants; + for (auto &TerminatorAndInvariants : UnswitchCandidates) { + TerminatorInst &TI = *TerminatorAndInvariants.first; + ArrayRef<Value *> Invariants = TerminatorAndInvariants.second; + BranchInst *BI = dyn_cast<BranchInst>(&TI); + int CandidateCost = ComputeUnswitchedCost( + TI, /*FullUnswitch*/ !BI || (Invariants.size() == 1 && + Invariants[0] == BI->getCondition())); + LLVM_DEBUG(dbgs() << " Computed cost of " << CandidateCost + << " for unswitch candidate: " << TI << "\n"); if (!BestUnswitchTI || CandidateCost < BestUnswitchCost) { - BestUnswitchTI = CandidateTI; + BestUnswitchTI = &TI; BestUnswitchCost = CandidateCost; + BestUnswitchInvariants = Invariants; } } - if (BestUnswitchCost < UnswitchThreshold) { - DEBUG(dbgs() << " Trying to unswitch non-trivial (cost = " - << BestUnswitchCost << ") branch: " << *BestUnswitchTI - << "\n"); - Changed |= unswitchInvariantBranch(L, cast<BranchInst>(*BestUnswitchTI), DT, - LI, AC, NonTrivialUnswitchCB); - } else { - DEBUG(dbgs() << "Cannot unswitch, lowest cost found: " << BestUnswitchCost - << "\n"); + if (BestUnswitchCost >= UnswitchThreshold) { + LLVM_DEBUG(dbgs() << "Cannot unswitch, lowest cost found: " + << BestUnswitchCost << "\n"); + return false; + } + + LLVM_DEBUG(dbgs() << " Trying to unswitch non-trivial (cost = " + << BestUnswitchCost << ") terminator: " << *BestUnswitchTI + << "\n"); + return unswitchNontrivialInvariants( + L, *BestUnswitchTI, BestUnswitchInvariants, DT, LI, AC, UnswitchCB, SE); +} + +/// Unswitch control flow predicated on loop invariant conditions. +/// +/// This first hoists all branches or switches which are trivial (IE, do not +/// require duplicating any part of the loop) out of the loop body. It then +/// looks at other loop invariant control flows and tries to unswitch those as +/// well by cloning the loop if the result is small enough. +/// +/// The `DT`, `LI`, `AC`, `TTI` parameters are required analyses that are also +/// updated based on the unswitch. +/// +/// If either `NonTrivial` is true or the flag `EnableNonTrivialUnswitch` is +/// true, we will attempt to do non-trivial unswitching as well as trivial +/// unswitching. +/// +/// The `UnswitchCB` callback provided will be run after unswitching is +/// complete, with the first parameter set to `true` if the provided loop +/// remains a loop, and a list of new sibling loops created. 
+/// +/// If `SE` is non-null, we will update that analysis based on the unswitching +/// done. +static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, TargetTransformInfo &TTI, + bool NonTrivial, + function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, + ScalarEvolution *SE) { + assert(L.isRecursivelyLCSSAForm(DT, LI) && + "Loops must be in LCSSA form before unswitching."); + bool Changed = false; + + // Must be in loop simplified form: we need a preheader and dedicated exits. + if (!L.isLoopSimplifyForm()) + return false; + + // Try trivial unswitch first before loop over other basic blocks in the loop. + if (unswitchAllTrivialConditions(L, DT, LI, SE)) { + // If we unswitched successfully we will want to clean up the loop before + // processing it further so just mark it as unswitched and return. + UnswitchCB(/*CurrentLoopValid*/ true, {}); + return true; } + // If we're not doing non-trivial unswitching, we're done. We both accept + // a parameter but also check a local flag that can be used for testing + // a debugging. + if (!NonTrivial && !EnableNonTrivialUnswitch) + return false; + + // For non-trivial unswitching, because it often creates new loops, we rely on + // the pass manager to iterate on the loops rather than trying to immediately + // reach a fixed point. There is no substantial advantage to iterating + // internally, and if any of the new loops are simplified enough to contain + // trivial unswitching we want to prefer those. + + // Try to unswitch the best invariant condition. We prefer this full unswitch to + // a partial unswitch when possible below the threshold. + if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB, SE)) + return true; + + // No other opportunities to unswitch. return Changed; } @@ -2066,16 +2437,18 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, Function &F = *L.getHeader()->getParent(); (void)F; - DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L << "\n"); + LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L + << "\n"); // Save the current loop name in a variable so that we can report it even // after it has been deleted. std::string LoopName = L.getName(); - auto NonTrivialUnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid, - ArrayRef<Loop *> NewLoops) { + auto UnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid, + ArrayRef<Loop *> NewLoops) { // If we did a non-trivial unswitch, we have added new (cloned) loops. - U.addSiblingLoops(NewLoops); + if (!NewLoops.empty()) + U.addSiblingLoops(NewLoops); // If the current loop remains valid, we should revisit it to catch any // other unswitch opportunities. Otherwise, we need to mark it as deleted. @@ -2085,15 +2458,13 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, U.markLoopAsDeleted(L, LoopName); }; - if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, - NonTrivialUnswitchCB)) + if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, UnswitchCB, + &AR.SE)) return PreservedAnalyses::all(); -#ifndef NDEBUG // Historically this pass has had issues with the dominator tree so verify it // in asserts builds. 
- AR.DT.verifyDomTree(); -#endif + assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast)); return getLoopPassPreservedAnalyses(); } @@ -2128,15 +2499,19 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { Function &F = *L->getHeader()->getParent(); - DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << *L << "\n"); + LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << *L + << "\n"); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - auto NonTrivialUnswitchCB = [&L, &LPM](bool CurrentLoopValid, - ArrayRef<Loop *> NewLoops) { + auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); + auto *SE = SEWP ? &SEWP->getSE() : nullptr; + + auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid, + ArrayRef<Loop *> NewLoops) { // If we did a non-trivial unswitch, we have added new (cloned) loops. for (auto *NewL : NewLoops) LPM.addLoop(*NewL); @@ -2150,18 +2525,16 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { LPM.markLoopAsDeleted(*L); }; - bool Changed = - unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, NonTrivialUnswitchCB); + bool Changed = unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB, SE); // If anything was unswitched, also clear any cached information about this // loop. LPM.deleteSimpleAnalysisLoop(L); -#ifndef NDEBUG // Historically this pass has had issues with the dominator tree so verify it // in asserts builds. - DT.verifyDomTree(); -#endif + assert(DT.verify(DominatorTree::VerificationLevel::Fast)); + return Changed; } diff --git a/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/lib/Transforms/Scalar/SimplifyCFGPass.cpp index 1522170dc3b9..b7b1db76b492 100644 --- a/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -28,6 +28,7 @@ #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" @@ -39,7 +40,6 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/SimplifyCFG.h" -#include "llvm/Transforms/Utils/Local.h" #include <utility> using namespace llvm; diff --git a/lib/Transforms/Scalar/Sink.cpp b/lib/Transforms/Scalar/Sink.cpp index cfb8a062299f..ca6b93e0b4a9 100644 --- a/lib/Transforms/Scalar/Sink.cpp +++ b/lib/Transforms/Scalar/Sink.cpp @@ -114,7 +114,7 @@ static bool IsAcceptableTarget(Instruction *Inst, BasicBlock *SuccToSinkTo, if (SuccToSinkTo->getUniquePredecessor() != Inst->getParent()) { // We cannot sink a load across a critical edge - there may be stores in // other code paths. 
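The `Sink.cpp` hunk that follows widens this test from loads to anything that reads memory. A hypothetical illustration of the hazard (not from the patch):

```cpp
// The load of *p is a sinking candidate, but the edge from the branch to
// the return is critical, and the other path into the join stores to *p.
int hazard(bool c, int *p) {
  int x = *p; // candidate to sink toward its only use
  if (!c)
    *p = 1;   // store on the other path into the join point
  return x;   // sinking the load here would observe the store when !c
}
```

`mayReadFromMemory()` also covers read-only calls and atomic reads, which the old `isa<LoadInst>` check let through despite the same problem.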
- if (isa<LoadInst>(Inst)) + if (Inst->mayReadFromMemory()) return false; // We don't want to sink across a critical edge if we don't dominate the @@ -187,11 +187,9 @@ static bool SinkInstruction(Instruction *Inst, if (!SuccToSinkTo) return false; - DEBUG(dbgs() << "Sink" << *Inst << " ("; - Inst->getParent()->printAsOperand(dbgs(), false); - dbgs() << " -> "; - SuccToSinkTo->printAsOperand(dbgs(), false); - dbgs() << ")\n"); + LLVM_DEBUG(dbgs() << "Sink" << *Inst << " ("; + Inst->getParent()->printAsOperand(dbgs(), false); dbgs() << " -> "; + SuccToSinkTo->printAsOperand(dbgs(), false); dbgs() << ")\n"); // Move the instruction. Inst->moveBefore(&*SuccToSinkTo->getFirstInsertionPt()); @@ -244,7 +242,7 @@ static bool iterativelySinkInstructions(Function &F, DominatorTree &DT, do { MadeChange = false; - DEBUG(dbgs() << "Sinking iteration " << NumSinkIter << "\n"); + LLVM_DEBUG(dbgs() << "Sinking iteration " << NumSinkIter << "\n"); // Process all basic blocks. for (BasicBlock &I : F) MadeChange |= ProcessBlock(I, DT, LI, AA); diff --git a/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp b/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp index 23156d5a4d83..6743e19a7c92 100644 --- a/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp +++ b/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp @@ -64,7 +64,7 @@ isSafeToSpeculatePHIUsers(PHINode &PN, DominatorTree &DT, // block. We should consider using actual post-dominance here in the // future. if (UI->getParent() != PhiBB) { - DEBUG(dbgs() << " Unsafe: use in a different BB: " << *UI << "\n"); + LLVM_DEBUG(dbgs() << " Unsafe: use in a different BB: " << *UI << "\n"); return false; } @@ -75,7 +75,7 @@ isSafeToSpeculatePHIUsers(PHINode &PN, DominatorTree &DT, // probably change this to do at least a limited scan of the intervening // instructions and allow handling stores in easily proven safe cases. if (mayBeMemoryDependent(*UI)) { - DEBUG(dbgs() << " Unsafe: can't speculate use: " << *UI << "\n"); + LLVM_DEBUG(dbgs() << " Unsafe: can't speculate use: " << *UI << "\n"); return false; } @@ -126,8 +126,8 @@ isSafeToSpeculatePHIUsers(PHINode &PN, DominatorTree &DT, // If when we directly test whether this is safe it fails, bail. if (UnsafeSet.count(OpI) || ParentBB != PhiBB || mayBeMemoryDependent(*OpI)) { - DEBUG(dbgs() << " Unsafe: can't speculate transitive use: " << *OpI - << "\n"); + LLVM_DEBUG(dbgs() << " Unsafe: can't speculate transitive use: " + << *OpI << "\n"); // Record the stack of instructions which reach this node as unsafe // so we prune subsequent searches. UnsafeSet.insert(OpI); @@ -229,7 +229,7 @@ static bool isSafeAndProfitableToSpeculateAroundPHI( NonFreeMat |= MatCost != TTI.TCC_Free; } if (!NonFreeMat) { - DEBUG(dbgs() << " Free: " << PN << "\n"); + LLVM_DEBUG(dbgs() << " Free: " << PN << "\n"); // No profit in free materialization. return false; } @@ -237,7 +237,7 @@ static bool isSafeAndProfitableToSpeculateAroundPHI( // Now check that the uses of this PHI can actually be speculated, // otherwise we'll still have to materialize the PHI value. if (!isSafeToSpeculatePHIUsers(PN, DT, PotentialSpecSet, UnsafeSet)) { - DEBUG(dbgs() << " Unsafe PHI: " << PN << "\n"); + LLVM_DEBUG(dbgs() << " Unsafe PHI: " << PN << "\n"); return false; } @@ -266,7 +266,7 @@ static bool isSafeAndProfitableToSpeculateAroundPHI( // Assume we will commute the constant to the RHS to be canonical. Idx = 1; - // Get the intrinsic ID if this user is an instrinsic. + // Get the intrinsic ID if this user is an intrinsic. 
Intrinsic::ID IID = Intrinsic::not_intrinsic; if (auto *UserII = dyn_cast<IntrinsicInst>(UserI)) IID = UserII->getIntrinsicID(); @@ -288,9 +288,13 @@ static bool isSafeAndProfitableToSpeculateAroundPHI( // just bail. We're only interested in cases where folding the incoming // constants is at least break-even on all paths. if (FoldedCost > MatCost) { - DEBUG(dbgs() << " Not profitable to fold imm: " << *IncomingC << "\n" - " Materializing cost: " << MatCost << "\n" - " Accumulated folded cost: " << FoldedCost << "\n"); + LLVM_DEBUG(dbgs() << " Not profitable to fold imm: " << *IncomingC + << "\n" + " Materializing cost: " + << MatCost + << "\n" + " Accumulated folded cost: " + << FoldedCost << "\n"); return false; } } @@ -310,8 +314,8 @@ static bool isSafeAndProfitableToSpeculateAroundPHI( "less that its materialized cost, " "the sum must be as well."); - DEBUG(dbgs() << " Cost savings " << (TotalMatCost - TotalFoldedCost) - << ": " << PN << "\n"); + LLVM_DEBUG(dbgs() << " Cost savings " << (TotalMatCost - TotalFoldedCost) + << ": " << PN << "\n"); CostSavingsMap[&PN] = TotalMatCost - TotalFoldedCost; return true; } @@ -489,9 +493,13 @@ findProfitablePHIs(ArrayRef<PHINode *> PNs, // and zero out the cost of everything it depends on. int CostSavings = CostSavingsMap.find(PN)->second; if (SpecCost > CostSavings) { - DEBUG(dbgs() << " Not profitable, speculation cost: " << *PN << "\n" - " Cost savings: " << CostSavings << "\n" - " Speculation cost: " << SpecCost << "\n"); + LLVM_DEBUG(dbgs() << " Not profitable, speculation cost: " << *PN + << "\n" + " Cost savings: " + << CostSavings + << "\n" + " Speculation cost: " + << SpecCost << "\n"); continue; } @@ -545,7 +553,7 @@ static void speculatePHIs(ArrayRef<PHINode *> SpecPNs, SmallPtrSetImpl<Instruction *> &PotentialSpecSet, SmallSetVector<BasicBlock *, 16> &PredSet, DominatorTree &DT) { - DEBUG(dbgs() << " Speculating around " << SpecPNs.size() << " PHIs!\n"); + LLVM_DEBUG(dbgs() << " Speculating around " << SpecPNs.size() << " PHIs!\n"); NumPHIsSpeculated += SpecPNs.size(); // Split any critical edges so that we have a block to hoist into. @@ -558,8 +566,8 @@ static void speculatePHIs(ArrayRef<PHINode *> SpecPNs, CriticalEdgeSplittingOptions(&DT).setMergeIdenticalEdges()); if (NewPredBB) { ++NumEdgesSplit; - DEBUG(dbgs() << " Split critical edge from: " << PredBB->getName() - << "\n"); + LLVM_DEBUG(dbgs() << " Split critical edge from: " << PredBB->getName() + << "\n"); SpecPreds.push_back(NewPredBB); } else { assert(PredBB->getSingleSuccessor() == ParentBB && @@ -593,14 +601,15 @@ static void speculatePHIs(ArrayRef<PHINode *> SpecPNs, int NumSpecInsts = SpecList.size() * SpecPreds.size(); int NumRedundantInsts = NumSpecInsts - SpecList.size(); - DEBUG(dbgs() << " Inserting " << NumSpecInsts << " speculated instructions, " - << NumRedundantInsts << " redundancies\n"); + LLVM_DEBUG(dbgs() << " Inserting " << NumSpecInsts + << " speculated instructions, " << NumRedundantInsts + << " redundancies\n"); NumSpeculatedInstructions += NumSpecInsts; NumNewRedundantInstructions += NumRedundantInsts; // Each predecessor is numbered by its index in `SpecPreds`, so for each // instruction we speculate, the speculated instruction is stored in that - // index of the vector asosciated with the original instruction. We also + // index of the vector associated with the original instruction. We also // store the incoming values for each predecessor from any PHIs used. 
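The counters above follow directly from the shape of the transform: every instruction in `SpecList` is re-created once per predecessor block. A trivial worked example with hypothetical sizes:

```cpp
#include <iostream>

int main() {
  int SpecListSize = 3; // instructions to speculate (hypothetical)
  int NumPreds = 2;     // predecessor blocks receiving copies
  int NumSpecInsts = SpecListSize * NumPreds;
  int NumRedundantInsts = NumSpecInsts - SpecListSize;
  std::cout << NumSpecInsts << " speculated, "
            << NumRedundantInsts << " redundant\n"; // 6 speculated, 3 redundant
}
```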
SmallDenseMap<Instruction *, SmallVector<Value *, 2>, 16> SpeculatedValueMap; @@ -716,7 +725,7 @@ static void speculatePHIs(ArrayRef<PHINode *> SpecPNs, /// true when at least some speculation occurs. static bool tryToSpeculatePHIs(SmallVectorImpl<PHINode *> &PNs, DominatorTree &DT, TargetTransformInfo &TTI) { - DEBUG(dbgs() << "Evaluating phi nodes for speculation:\n"); + LLVM_DEBUG(dbgs() << "Evaluating phi nodes for speculation:\n"); // Savings in cost from speculating around a PHI node. SmallDenseMap<PHINode *, int, 16> CostSavingsMap; @@ -745,7 +754,7 @@ static bool tryToSpeculatePHIs(SmallVectorImpl<PHINode *> &PNs, PNs.end()); // If no PHIs were profitable, skip. if (PNs.empty()) { - DEBUG(dbgs() << " No safe and profitable PHIs found!\n"); + LLVM_DEBUG(dbgs() << " No safe and profitable PHIs found!\n"); return false; } @@ -763,13 +772,13 @@ static bool tryToSpeculatePHIs(SmallVectorImpl<PHINode *> &PNs, // differently. if (isa<IndirectBrInst>(PredBB->getTerminator()) || isa<InvokeInst>(PredBB->getTerminator())) { - DEBUG(dbgs() << " Invalid: predecessor terminator: " << PredBB->getName() - << "\n"); + LLVM_DEBUG(dbgs() << " Invalid: predecessor terminator: " + << PredBB->getName() << "\n"); return false; } } if (PredSet.size() < 2) { - DEBUG(dbgs() << " Unimportant: phi with only one predecessor\n"); + LLVM_DEBUG(dbgs() << " Unimportant: phi with only one predecessor\n"); return false; } diff --git a/lib/Transforms/Scalar/SpeculativeExecution.cpp b/lib/Transforms/Scalar/SpeculativeExecution.cpp index a7c308b59877..f5e1dd6ed850 100644 --- a/lib/Transforms/Scalar/SpeculativeExecution.cpp +++ b/lib/Transforms/Scalar/SpeculativeExecution.cpp @@ -62,7 +62,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/SpeculativeExecution.h" -#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Instructions.h" @@ -137,6 +137,7 @@ INITIALIZE_PASS_END(SpeculativeExecutionLegacyPass, "speculative-execution", void SpeculativeExecutionLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<TargetTransformInfoWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); + AU.setPreservesCFG(); } bool SpeculativeExecutionLegacyPass::runOnFunction(Function &F) { @@ -151,8 +152,8 @@ namespace llvm { bool SpeculativeExecutionPass::runImpl(Function &F, TargetTransformInfo *TTI) { if (OnlyIfDivergentTarget && !TTI->hasBranchDivergence()) { - DEBUG(dbgs() << "Not running SpeculativeExecution because " - "TTI->hasBranchDivergence() is false.\n"); + LLVM_DEBUG(dbgs() << "Not running SpeculativeExecution because " + "TTI->hasBranchDivergence() is false.\n"); return false; } @@ -251,7 +252,7 @@ static unsigned ComputeSpeculationCost(const Instruction *I, bool SpeculativeExecutionPass::considerHoistingFromTo( BasicBlock &FromBlock, BasicBlock &ToBlock) { - SmallSet<const Instruction *, 8> NotHoisted; + SmallPtrSet<const Instruction *, 8> NotHoisted; const auto AllPrecedingUsesFromBlockHoisted = [&NotHoisted](User *U) { for (Value* V : U->operand_values()) { if (Instruction *I = dyn_cast<Instruction>(V)) { @@ -314,6 +315,7 @@ PreservedAnalyses SpeculativeExecutionPass::run(Function &F, return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserve<GlobalsAA>(); + PA.preserveSet<CFGAnalyses>(); return PA; } } // namespace llvm diff --git a/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp 
b/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp index ce40af1223f6..2061db13639a 100644 --- a/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ b/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -61,6 +61,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -80,7 +81,6 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> #include <limits> diff --git a/lib/Transforms/Scalar/StructurizeCFG.cpp b/lib/Transforms/Scalar/StructurizeCFG.cpp index 2972e1cff9a4..d650264176aa 100644 --- a/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -40,6 +40,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include <algorithm> #include <cassert> @@ -55,6 +56,12 @@ static const char *const FlowBlockName = "Flow"; namespace { +static cl::opt<bool> ForceSkipUniformRegions( + "structurizecfg-skip-uniform-regions", + cl::Hidden, + cl::desc("Force whether the StructurizeCFG pass skips uniform regions"), + cl::init(false)); + // Definition of the complex types used in this pass. using BBValuePair = std::pair<BasicBlock *, Value *>; @@ -120,7 +127,7 @@ public: bool resultIsRememberedBlock() { return ResultIsRemembered; } }; -/// @brief Transforms the control flow graph on one single entry/exit region +/// Transforms the control flow graph on one single entry/exit region /// at a time. /// /// After the transform all "If"/"Then"/"Else" style control flow looks like @@ -176,6 +183,7 @@ class StructurizeCFG : public RegionPass { Function *Func; Region *ParentRegion; + DivergenceAnalysis *DA; DominatorTree *DT; LoopInfo *LI; @@ -196,6 +204,9 @@ class StructurizeCFG : public RegionPass { void orderNodes(); + Loop *getAdjustedLoop(RegionNode *RN); + unsigned getAdjustedLoopDepth(RegionNode *RN); + void analyzeLoops(RegionNode *N); Value *invert(Value *Condition); @@ -242,8 +253,11 @@ class StructurizeCFG : public RegionPass { public: static char ID; - explicit StructurizeCFG(bool SkipUniformRegions = false) - : RegionPass(ID), SkipUniformRegions(SkipUniformRegions) { + explicit StructurizeCFG(bool SkipUniformRegions_ = false) + : RegionPass(ID), + SkipUniformRegions(SkipUniformRegions_) { + if (ForceSkipUniformRegions.getNumOccurrences()) + SkipUniformRegions = ForceSkipUniformRegions.getValue(); initializeStructurizeCFGPass(*PassRegistry::getPassRegistry()); } @@ -278,7 +292,7 @@ INITIALIZE_PASS_DEPENDENCY(RegionInfoPass) INITIALIZE_PASS_END(StructurizeCFG, "structurizecfg", "Structurize the CFG", false, false) -/// \brief Initialize the types and constants used in the pass +/// Initialize the types and constants used in the pass bool StructurizeCFG::doInitialization(Region *R, RGPassManager &RGM) { LLVMContext &Context = R->getEntry()->getContext(); @@ -290,7 +304,27 @@ bool StructurizeCFG::doInitialization(Region *R, RGPassManager &RGM) { return false; } -/// \brief Build up the general order of nodes +/// Use the exit block to determine the loop if RN is a SubRegion. 
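The `ForceSkipUniformRegions` flag added above follows a common LLVM pattern: a hidden `cl::opt` overrides a constructor argument only when the flag actually appeared on the command line, so an explicit `=false` and "flag not given" behave differently. A condensed sketch with mock names (not the pass itself):

```cpp
#include "llvm/Support/CommandLine.h"

static llvm::cl::opt<bool> ForceFeature(
    "force-feature", llvm::cl::Hidden,
    llvm::cl::desc("Force the feature on or off for testing"),
    llvm::cl::init(false));

struct MyPass {
  bool Feature;
  explicit MyPass(bool Feature_ = false) : Feature(Feature_) {
    // getNumOccurrences() distinguishes "defaulted to false" from an
    // explicit -force-feature=false on the command line.
    if (ForceFeature.getNumOccurrences())
      Feature = ForceFeature.getValue();
  }
};
```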
+Loop *StructurizeCFG::getAdjustedLoop(RegionNode *RN) { + if (RN->isSubRegion()) { + Region *SubRegion = RN->getNodeAs<Region>(); + return LI->getLoopFor(SubRegion->getExit()); + } + + return LI->getLoopFor(RN->getEntry()); +} + +/// Use the exit block to determine the loop depth if RN is a SubRegion. +unsigned StructurizeCFG::getAdjustedLoopDepth(RegionNode *RN) { + if (RN->isSubRegion()) { + Region *SubR = RN->getNodeAs<Region>(); + return LI->getLoopDepth(SubR->getExit()); + } + + return LI->getLoopDepth(RN->getEntry()); +} + +/// Build up the general order of nodes void StructurizeCFG::orderNodes() { ReversePostOrderTraversal<Region*> RPOT(ParentRegion); SmallDenseMap<Loop*, unsigned, 8> LoopBlocks; @@ -299,16 +333,15 @@ void StructurizeCFG::orderNodes() { // to what we want. The only problem with it is that sometimes backedges // for outer loops will be visited before backedges for inner loops. for (RegionNode *RN : RPOT) { - BasicBlock *BB = RN->getEntry(); - Loop *Loop = LI->getLoopFor(BB); + Loop *Loop = getAdjustedLoop(RN); ++LoopBlocks[Loop]; } unsigned CurrentLoopDepth = 0; Loop *CurrentLoop = nullptr; for (auto I = RPOT.begin(), E = RPOT.end(); I != E; ++I) { - BasicBlock *BB = (*I)->getEntry(); - unsigned LoopDepth = LI->getLoopDepth(BB); + RegionNode *RN = cast<RegionNode>(*I); + unsigned LoopDepth = getAdjustedLoopDepth(RN); if (is_contained(Order, *I)) continue; @@ -320,15 +353,14 @@ void StructurizeCFG::orderNodes() { auto LoopI = I; while (unsigned &BlockCount = LoopBlocks[CurrentLoop]) { LoopI++; - BasicBlock *LoopBB = (*LoopI)->getEntry(); - if (LI->getLoopFor(LoopBB) == CurrentLoop) { + if (getAdjustedLoop(cast<RegionNode>(*LoopI)) == CurrentLoop) { --BlockCount; Order.push_back(*LoopI); } } } - CurrentLoop = LI->getLoopFor(BB); + CurrentLoop = getAdjustedLoop(RN); if (CurrentLoop) LoopBlocks[CurrentLoop]--; @@ -343,7 +375,7 @@ void StructurizeCFG::orderNodes() { std::reverse(Order.begin(), Order.end()); } -/// \brief Determine the end of the loops +/// Determine the end of the loops void StructurizeCFG::analyzeLoops(RegionNode *N) { if (N->isSubRegion()) { // Test for exit as back edge @@ -362,15 +394,16 @@ void StructurizeCFG::analyzeLoops(RegionNode *N) { } } -/// \brief Invert the given condition +/// Invert the given condition Value *StructurizeCFG::invert(Value *Condition) { // First: Check if it's a constant if (Constant *C = dyn_cast<Constant>(Condition)) return ConstantExpr::getNot(C); // Second: If the condition is already inverted, return the original value - if (match(Condition, m_Not(m_Value(Condition)))) - return Condition; + Value *NotCondition; + if (match(Condition, m_Not(m_Value(NotCondition)))) + return NotCondition; if (Instruction *Inst = dyn_cast<Instruction>(Condition)) { // Third: Check all the users for an invert @@ -394,7 +427,7 @@ Value *StructurizeCFG::invert(Value *Condition) { llvm_unreachable("Unhandled condition to invert"); } -/// \brief Build the condition for one edge +/// Build the condition for one edge Value *StructurizeCFG::buildCondition(BranchInst *Term, unsigned Idx, bool Invert) { Value *Cond = Invert ? 
BoolFalse : BoolTrue; @@ -407,7 +440,7 @@ Value *StructurizeCFG::buildCondition(BranchInst *Term, unsigned Idx, return Cond; } -/// \brief Analyze the predecessors of each block and build up predicates +/// Analyze the predecessors of each block and build up predicates void StructurizeCFG::gatherPredicates(RegionNode *N) { RegionInfo *RI = ParentRegion->getRegionInfo(); BasicBlock *BB = N->getEntry(); @@ -465,7 +498,7 @@ void StructurizeCFG::gatherPredicates(RegionNode *N) { } } -/// \brief Collect various loop and predicate infos +/// Collect various loop and predicate infos void StructurizeCFG::collectInfos() { // Reset predicate Predicates.clear(); @@ -478,10 +511,10 @@ void StructurizeCFG::collectInfos() { Visited.clear(); for (RegionNode *RN : reverse(Order)) { - DEBUG(dbgs() << "Visiting: " - << (RN->isSubRegion() ? "SubRegion with entry: " : "") - << RN->getEntry()->getName() << " Loop Depth: " - << LI->getLoopDepth(RN->getEntry()) << "\n"); + LLVM_DEBUG(dbgs() << "Visiting: " + << (RN->isSubRegion() ? "SubRegion with entry: " : "") + << RN->getEntry()->getName() << " Loop Depth: " + << LI->getLoopDepth(RN->getEntry()) << "\n"); // Analyze all the conditions leading to a node gatherPredicates(RN); @@ -494,7 +527,7 @@ void StructurizeCFG::collectInfos() { } } -/// \brief Insert the missing branch conditions +/// Insert the missing branch conditions void StructurizeCFG::insertConditions(bool Loops) { BranchVector &Conds = Loops ? LoopConds : Conditions; Value *Default = Loops ? BoolTrue : BoolFalse; @@ -540,14 +573,11 @@ void StructurizeCFG::insertConditions(bool Loops) { } } -/// \brief Remove all PHI values coming from "From" into "To" and remember +/// Remove all PHI values coming from "From" into "To" and remember /// them in DeletedPhis void StructurizeCFG::delPhiValues(BasicBlock *From, BasicBlock *To) { PhiMap &Map = DeletedPhis[To]; - for (Instruction &I : *To) { - if (!isa<PHINode>(I)) - break; - PHINode &Phi = cast<PHINode>(I); + for (PHINode &Phi : To->phis()) { while (Phi.getBasicBlockIndex(From) != -1) { Value *Deleted = Phi.removeIncomingValue(From, false); Map[&Phi].push_back(std::make_pair(From, Deleted)); @@ -555,19 +585,16 @@ void StructurizeCFG::delPhiValues(BasicBlock *From, BasicBlock *To) { } } -/// \brief Add a dummy PHI value as soon as we knew the new predecessor +/// Add a dummy PHI value as soon as we knew the new predecessor void StructurizeCFG::addPhiValues(BasicBlock *From, BasicBlock *To) { - for (Instruction &I : *To) { - if (!isa<PHINode>(I)) - break; - PHINode &Phi = cast<PHINode>(I); + for (PHINode &Phi : To->phis()) { Value *Undef = UndefValue::get(Phi.getType()); Phi.addIncoming(Undef, From); } AddedPhis[To].push_back(From); } -/// \brief Add the real PHI value as soon as everything is set up +/// Add the real PHI value as soon as everything is set up void StructurizeCFG::setPhiValues() { SSAUpdater Updater; for (const auto &AddedPhi : AddedPhis) { @@ -607,7 +634,7 @@ void StructurizeCFG::setPhiValues() { assert(DeletedPhis.empty()); } -/// \brief Remove phi values from all successors and then remove the terminator. +/// Remove phi values from all successors and then remove the terminator. 
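Returning to the `invert()` hunk earlier in this file: the fresh `NotCondition` variable matters because PatternMatch's `m_Value()` binds its argument as soon as a sub-pattern matches, even if the enclosing match fails afterwards. Matching into `Condition` itself could therefore clobber it on the failure path (for example, on a `xor` without an all-ones operand) before the fall-through code runs. A sketch of the corrected shape, assuming LLVM's `PatternMatch.h`:

```cpp
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"
using namespace llvm;
using namespace PatternMatch;

// Return X for a condition of the form `xor X, -1`, else null. Binding into
// a scratch variable leaves Condition intact whether or not we match.
Value *getInvertedOperand(Value *Condition) {
  Value *NotCondition;
  if (match(Condition, m_Not(m_Value(NotCondition))))
    return NotCondition;
  return nullptr; // Condition is still the original value here
}
```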
void StructurizeCFG::killTerminator(BasicBlock *BB) { TerminatorInst *Term = BB->getTerminator(); if (!Term) @@ -617,10 +644,12 @@ void StructurizeCFG::killTerminator(BasicBlock *BB) { SI != SE; ++SI) delPhiValues(BB, *SI); + if (DA) + DA->removeValue(Term); Term->eraseFromParent(); } -/// \brief Let node exit(s) point to NewExit +/// Let node exit(s) point to NewExit void StructurizeCFG::changeExit(RegionNode *Node, BasicBlock *NewExit, bool IncludeDominator) { if (Node->isSubRegion()) { @@ -666,7 +695,7 @@ void StructurizeCFG::changeExit(RegionNode *Node, BasicBlock *NewExit, } } -/// \brief Create a new flow node and update dominator tree and region info +/// Create a new flow node and update dominator tree and region info BasicBlock *StructurizeCFG::getNextFlow(BasicBlock *Dominator) { LLVMContext &Context = Func->getContext(); BasicBlock *Insert = Order.empty() ? ParentRegion->getExit() : @@ -678,7 +707,7 @@ BasicBlock *StructurizeCFG::getNextFlow(BasicBlock *Dominator) { return Flow; } -/// \brief Create a new or reuse the previous node as flow node +/// Create a new or reuse the previous node as flow node BasicBlock *StructurizeCFG::needPrefix(bool NeedEmpty) { BasicBlock *Entry = PrevNode->getEntry(); @@ -697,7 +726,7 @@ BasicBlock *StructurizeCFG::needPrefix(bool NeedEmpty) { return Flow; } -/// \brief Returns the region exit if possible, otherwise just a new flow node +/// Returns the region exit if possible, otherwise just a new flow node BasicBlock *StructurizeCFG::needPostfix(BasicBlock *Flow, bool ExitUseAllowed) { if (!Order.empty() || !ExitUseAllowed) @@ -709,13 +738,13 @@ BasicBlock *StructurizeCFG::needPostfix(BasicBlock *Flow, return Exit; } -/// \brief Set the previous node +/// Set the previous node void StructurizeCFG::setPrevNode(BasicBlock *BB) { PrevNode = ParentRegion->contains(BB) ? ParentRegion->getBBNode(BB) : nullptr; } -/// \brief Does BB dominate all the predicates of Node? +/// Does BB dominate all the predicates of Node? bool StructurizeCFG::dominatesPredicates(BasicBlock *BB, RegionNode *Node) { BBPredicates &Preds = Predicates[Node->getEntry()]; return llvm::all_of(Preds, [&](std::pair<BasicBlock *, Value *> Pred) { @@ -723,7 +752,7 @@ bool StructurizeCFG::dominatesPredicates(BasicBlock *BB, RegionNode *Node) { }); } -/// \brief Can we predict that this node will always be called? +/// Can we predict that this node will always be called? bool StructurizeCFG::isPredictableTrue(RegionNode *Node) { BBPredicates &Preds = Predicates[Node->getEntry()]; bool Dominated = false; @@ -851,7 +880,7 @@ void StructurizeCFG::createFlow() { } /// Handle a rare case where the disintegrated nodes instructions -/// no longer dominate all their uses. Not sure if this is really nessasary +/// no longer dominate all their uses. 
Not sure if this is really necessary void StructurizeCFG::rebuildSSA() { SSAUpdater Updater; for (BasicBlock *BB : ParentRegion->blocks()) @@ -884,30 +913,60 @@ void StructurizeCFG::rebuildSSA() { } } -static bool hasOnlyUniformBranches(const Region *R, +static bool hasOnlyUniformBranches(Region *R, unsigned UniformMDKindID, const DivergenceAnalysis &DA) { - for (const BasicBlock *BB : R->blocks()) { - const BranchInst *Br = dyn_cast<BranchInst>(BB->getTerminator()); - if (!Br || !Br->isConditional()) - continue; + for (auto E : R->elements()) { + if (!E->isSubRegion()) { + auto Br = dyn_cast<BranchInst>(E->getEntry()->getTerminator()); + if (!Br || !Br->isConditional()) + continue; - if (!DA.isUniform(Br->getCondition())) - return false; - DEBUG(dbgs() << "BB: " << BB->getName() << " has uniform terminator\n"); + if (!DA.isUniform(Br)) + return false; + LLVM_DEBUG(dbgs() << "BB: " << Br->getParent()->getName() + << " has uniform terminator\n"); + } else { + // Explicitly refuse to treat regions as uniform if they have non-uniform + // subregions. We cannot rely on DivergenceAnalysis for branches in + // subregions because those branches may have been removed and re-created, + // so we look for our metadata instead. + // + // Warning: It would be nice to treat regions as uniform based only on + // their direct child basic blocks' terminators, regardless of whether + // subregions are uniform or not. However, this requires a very careful + // look at SIAnnotateControlFlow to make sure nothing breaks there. + for (auto BB : E->getNodeAs<Region>()->blocks()) { + auto Br = dyn_cast<BranchInst>(BB->getTerminator()); + if (!Br || !Br->isConditional()) + continue; + + if (!Br->getMetadata(UniformMDKindID)) + return false; + } + } } return true; } -/// \brief Run the transformation for each region found +/// Run the transformation for each region found bool StructurizeCFG::runOnRegion(Region *R, RGPassManager &RGM) { if (R->isTopLevelRegion()) return false; + DA = nullptr; + if (SkipUniformRegions) { // TODO: We could probably be smarter here with how we handle sub-regions. - auto &DA = getAnalysis<DivergenceAnalysis>(); - if (hasOnlyUniformBranches(R, DA)) { - DEBUG(dbgs() << "Skipping region with uniform control flow: " << *R << '\n'); + // We currently rely on the fact that metadata is set by earlier invocations + // of the pass on sub-regions, and that this metadata doesn't get lost -- + // but we shouldn't rely on metadata for correctness! + unsigned UniformMDKindID = + R->getEntry()->getContext().getMDKindID("structurizecfg.uniform"); + DA = &getAnalysis<DivergenceAnalysis>(); + + if (hasOnlyUniformBranches(R, UniformMDKindID, *DA)) { + LLVM_DEBUG(dbgs() << "Skipping region with uniform control flow: " << *R + << '\n'); // Mark all direct child block terminators as having been treated as // uniform. 
To account for a possible future in which non-uniform @@ -919,7 +978,7 @@ bool StructurizeCFG::runOnRegion(Region *R, RGPassManager &RGM) { continue; if (Instruction *Term = E->getEntry()->getTerminator()) - Term->setMetadata("structurizecfg.uniform", MD); + Term->setMetadata(UniformMDKindID, MD); } return false; diff --git a/lib/Transforms/Scalar/TailRecursionElimination.cpp b/lib/Transforms/Scalar/TailRecursionElimination.cpp index 2a1106b41de2..f8cd6c17a5a6 100644 --- a/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -87,7 +87,7 @@ STATISTIC(NumEliminated, "Number of tail calls removed"); STATISTIC(NumRetDuped, "Number of return duplicated"); STATISTIC(NumAccumAdded, "Number of accumulators introduced"); -/// \brief Scan the specified function for alloca instructions. +/// Scan the specified function for alloca instructions. /// If it contains any dynamic allocas, returns false. static bool canTRE(Function &F) { // Because of PR962, we don't TRE dynamic allocas. @@ -302,7 +302,7 @@ static bool markTails(Function &F, bool &AllCallsAreTailCalls, if (Visited[CI->getParent()] != ESCAPED) { // If the escape point was part way through the block, calls after the // escape point wouldn't have been put into DeferredTails. - DEBUG(dbgs() << "Marked as tail call candidate: " << *CI << "\n"); + LLVM_DEBUG(dbgs() << "Marked as tail call candidate: " << *CI << "\n"); CI->setTailCall(); Modified = true; } else { @@ -699,8 +699,8 @@ static bool foldReturnAndProcessPred( BranchInst *BI = UncondBranchPreds.pop_back_val(); BasicBlock *Pred = BI->getParent(); if (CallInst *CI = findTRECandidate(BI, CannotTailCallElimCallsMarkedTail, TTI)){ - DEBUG(dbgs() << "FOLDING: " << *BB - << "INTO UNCOND BRANCH PRED: " << *Pred); + LLVM_DEBUG(dbgs() << "FOLDING: " << *BB + << "INTO UNCOND BRANCH PRED: " << *Pred); ReturnInst *RI = FoldReturnIntoUncondBranch(Ret, BB, Pred); // Cleanup: if all predecessors of BB have been eliminated by diff --git a/lib/Transforms/Utils/AddDiscriminators.cpp b/lib/Transforms/Utils/AddDiscriminators.cpp index 0f0668f24db5..e3ef42362223 100644 --- a/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/lib/Transforms/Utils/AddDiscriminators.cpp @@ -69,7 +69,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include <utility> using namespace llvm; @@ -114,7 +114,7 @@ static bool shouldHaveDiscriminator(const Instruction *I) { return !isa<IntrinsicInst>(I) || isa<MemIntrinsic>(I); } -/// \brief Assign DWARF discriminators. +/// Assign DWARF discriminators. /// /// To assign discriminators, we examine the boundaries of every /// basic block and its successors. Suppose there is a basic block B1 @@ -210,9 +210,9 @@ static bool addDiscriminators(Function &F) { // it in 1 byte ULEB128 representation. unsigned Discriminator = R.second ? 
++LDM[L] : LDM[L]; I.setDebugLoc(DIL->setBaseDiscriminator(Discriminator)); - DEBUG(dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":" - << DIL->getColumn() << ":" << Discriminator << " " << I - << "\n"); + LLVM_DEBUG(dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":" + << DIL->getColumn() << ":" << Discriminator << " " << I + << "\n"); Changed = true; } } diff --git a/lib/Transforms/Utils/BasicBlockUtils.cpp b/lib/Transforms/Utils/BasicBlockUtils.cpp index 606bd8baccaa..516a785dce1e 100644 --- a/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -20,6 +20,7 @@ #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" @@ -36,7 +37,6 @@ #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/Casting.h" -#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> #include <string> @@ -45,16 +45,22 @@ using namespace llvm; -void llvm::DeleteDeadBlock(BasicBlock *BB) { +void llvm::DeleteDeadBlock(BasicBlock *BB, DeferredDominance *DDT) { assert((pred_begin(BB) == pred_end(BB) || // Can delete self loop. BB->getSinglePredecessor() == BB) && "Block is not dead!"); TerminatorInst *BBTerm = BB->getTerminator(); + std::vector<DominatorTree::UpdateType> Updates; // Loop through all of our successors and make sure they know that one // of their predecessors is going away. - for (BasicBlock *Succ : BBTerm->successors()) + if (DDT) + Updates.reserve(BBTerm->getNumSuccessors()); + for (BasicBlock *Succ : BBTerm->successors()) { Succ->removePredecessor(BB); + if (DDT) + Updates.push_back({DominatorTree::Delete, BB, Succ}); + } // Zap all the instructions in the block. while (!BB->empty()) { @@ -69,8 +75,12 @@ void llvm::DeleteDeadBlock(BasicBlock *BB) { BB->getInstList().pop_back(); } - // Zap the block! - BB->eraseFromParent(); + if (DDT) { + DDT->applyUpdates(Updates); + DDT->deleteBB(BB); // Deferred deletion of BB. + } else { + BB->eraseFromParent(); // Zap the block! + } } void llvm::FoldSingleEntryPHINodes(BasicBlock *BB, @@ -94,9 +104,8 @@ bool llvm::DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI) { // Recursively deleting a PHI may cause multiple PHIs to be deleted // or RAUW'd undef, so use an array of WeakTrackingVH for the PHIs to delete. SmallVector<WeakTrackingVH, 8> PHIs; - for (BasicBlock::iterator I = BB->begin(); - PHINode *PN = dyn_cast<PHINode>(I); ++I) - PHIs.push_back(PN); + for (PHINode &PN : BB->phis()) + PHIs.push_back(&PN); bool Changed = false; for (unsigned i = 0, e = PHIs.size(); i != e; ++i) @@ -108,9 +117,12 @@ bool llvm::DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI) { bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, LoopInfo *LI, - MemoryDependenceResults *MemDep) { - // Don't merge away blocks who have their address taken. - if (BB->hasAddressTaken()) return false; + MemoryDependenceResults *MemDep, + DeferredDominance *DDT) { + assert(!(DT && DDT) && "Cannot call with both DT and DDT."); + + if (BB->hasAddressTaken()) + return false; // Can't merge if there are multiple predecessors, or no predecessors. 
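The `DeferredDominance` plumbing above batches CFG edits instead of mutating the dominator tree eagerly. A fragment showing the pattern in isolation; the header locations and exact API are as of this era of LLVM and should be treated as assumptions:

```cpp
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h" // DeferredDominance in this era
#include <vector>
using namespace llvm;

// Queue one edge deletion per successor, then hand the block itself to the
// deferred updater; nothing is recomputed until the updates are flushed.
void removeBlockDeferred(BasicBlock *BB, DeferredDominance &DDT) {
  std::vector<DominatorTree::UpdateType> Updates;
  Updates.reserve(BB->getTerminator()->getNumSuccessors());
  for (BasicBlock *Succ : successors(BB)) {
    Succ->removePredecessor(BB);
    Updates.push_back({DominatorTree::Delete, BB, Succ});
  }
  DDT.applyUpdates(Updates);
  DDT.deleteBB(BB); // BB is erased later, once the updates are applied
}
```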
BasicBlock *PredBB = BB->getUniquePredecessor(); @@ -122,39 +134,38 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, if (PredBB->getTerminator()->isExceptional()) return false; - succ_iterator SI(succ_begin(PredBB)), SE(succ_end(PredBB)); - BasicBlock *OnlySucc = BB; - for (; SI != SE; ++SI) - if (*SI != OnlySucc) { - OnlySucc = nullptr; // There are multiple distinct successors! - break; - } - - // Can't merge if there are multiple successors. - if (!OnlySucc) return false; + // Can't merge if there are multiple distinct successors. + if (PredBB->getUniqueSuccessor() != BB) + return false; // Can't merge if there is PHI loop. - for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE; ++BI) { - if (PHINode *PN = dyn_cast<PHINode>(BI)) { - for (Value *IncValue : PN->incoming_values()) - if (IncValue == PN) - return false; - } else - break; - } + for (PHINode &PN : BB->phis()) + for (Value *IncValue : PN.incoming_values()) + if (IncValue == &PN) + return false; // Begin by getting rid of unneeded PHIs. - SmallVector<Value *, 4> IncomingValues; + SmallVector<AssertingVH<Value>, 4> IncomingValues; if (isa<PHINode>(BB->front())) { - for (auto &I : *BB) - if (PHINode *PN = dyn_cast<PHINode>(&I)) { - if (PN->getIncomingValue(0) != PN) - IncomingValues.push_back(PN->getIncomingValue(0)); - } else - break; + for (PHINode &PN : BB->phis()) + if (!isa<PHINode>(PN.getIncomingValue(0)) || + cast<PHINode>(PN.getIncomingValue(0))->getParent() != BB) + IncomingValues.push_back(PN.getIncomingValue(0)); FoldSingleEntryPHINodes(BB, MemDep); } + // Deferred DT update: Collect all the edges that exit BB. These + // dominator edges will be redirected from Pred. + std::vector<DominatorTree::UpdateType> Updates; + if (DDT) { + Updates.reserve(1 + (2 * succ_size(BB))); + Updates.push_back({DominatorTree::Delete, PredBB, BB}); + for (auto I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { + Updates.push_back({DominatorTree::Delete, BB, *I}); + Updates.push_back({DominatorTree::Insert, PredBB, *I}); + } + } + // Delete the unconditional branch from the predecessor... PredBB->getInstList().pop_back(); @@ -166,8 +177,8 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, PredBB->getInstList().splice(PredBB->end(), BB->getInstList()); // Eliminate duplicate dbg.values describing the entry PHI node post-splice. - for (auto *Incoming : IncomingValues) { - if (isa<Instruction>(Incoming)) { + for (auto Incoming : IncomingValues) { + if (isa<Instruction>(*Incoming)) { SmallVector<DbgValueInst *, 2> DbgValues; SmallDenseSet<std::pair<DILocalVariable *, DIExpression *>, 2> DbgValueSet; @@ -201,7 +212,12 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, if (MemDep) MemDep->invalidateCachedPredecessors(); - BB->eraseFromParent(); + if (DDT) { + DDT->deleteBB(BB); // Deferred deletion of BB. + DDT->applyUpdates(Updates); + } else { + BB->eraseFromParent(); // Nuke BB. + } return true; } @@ -317,13 +333,21 @@ static void UpdateAnalysisInformation(BasicBlock *OldBB, BasicBlock *NewBB, DominatorTree *DT, LoopInfo *LI, bool PreserveLCSSA, bool &HasLoopExit) { // Update dominator tree if available. - if (DT) - DT->splitBlock(NewBB); + if (DT) { + if (OldBB == DT->getRootNode()->getBlock()) { + assert(NewBB == &NewBB->getParent()->getEntryBlock()); + DT->setNewRoot(NewBB); + } else { + // Split block expects NewBB to have a non-empty set of predecessors. 
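The `reserve(1 + (2 * succ_size(BB)))` above is exact rather than a guess: merging queues one deletion for the `Pred -> BB` edge plus a delete/insert pair per successor of `BB`. With hypothetical numbers:

```cpp
#include <iostream>

int main() {
  unsigned NumSuccs = 3; // successors of the merged block (hypothetical)
  unsigned NumUpdates = 1 + 2 * NumSuccs;
  std::cout << NumUpdates << " domtree updates queued\n"; // 7
}
```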
+ DT->splitBlock(NewBB); + } + } // The rest of the logic is only relevant for updating the loop structures. if (!LI) return; + assert(DT && "DT should be available to update LoopInfo!"); Loop *L = LI->getLoopFor(OldBB); // If we need to preserve loop analyses, collect some information about how @@ -331,6 +355,12 @@ static void UpdateAnalysisInformation(BasicBlock *OldBB, BasicBlock *NewBB, bool IsLoopEntry = !!L; bool SplitMakesNewLoopHeader = false; for (BasicBlock *Pred : Preds) { + // Preds that are not reachable from entry should not be used to identify if + // OldBB is a loop entry or if SplitMakesNewLoopHeader. Unreachable blocks + // are not within any loops, so we incorrectly mark SplitMakesNewLoopHeader + // as true and make the NewBB the header of some loop. This breaks LI. + if (!DT->isReachableFromEntry(Pred)) + continue; // If we need to preserve LCSSA, determine if any of the preds is a loop // exit. if (PreserveLCSSA) @@ -495,7 +525,6 @@ BasicBlock *llvm::SplitBlockPredecessors(BasicBlock *BB, // Insert dummy values as the incoming value. for (BasicBlock::iterator I = BB->begin(); isa<PHINode>(I); ++I) cast<PHINode>(I)->addIncoming(UndefValue::get(I->getType()), NewBB); - return NewBB; } // Update DominatorTree, LoopInfo, and LCCSA analysis information. @@ -503,8 +532,11 @@ BasicBlock *llvm::SplitBlockPredecessors(BasicBlock *BB, UpdateAnalysisInformation(BB, NewBB, Preds, DT, LI, PreserveLCSSA, HasLoopExit); - // Update the PHI nodes in BB with the values coming from NewBB. - UpdatePHINodes(BB, NewBB, Preds, BI, HasLoopExit); + if (!Preds.empty()) { + // Update the PHI nodes in BB with the values coming from NewBB. + UpdatePHINodes(BB, NewBB, Preds, BI, HasLoopExit); + } + return NewBB; } diff --git a/lib/Transforms/Utils/BreakCriticalEdges.cpp b/lib/Transforms/Utils/BreakCriticalEdges.cpp index 3653c307619b..3e30c27a9f33 100644 --- a/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -28,7 +28,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Type.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/ValueMapper.h" @@ -106,10 +106,9 @@ static void createPHIsForSplitLoopExit(ArrayRef<BasicBlock *> Preds, SplitBB->isLandingPad()) && "SplitBB has non-PHI nodes!"); // For each PHI in the destination block. - for (BasicBlock::iterator I = DestBB->begin(); - PHINode *PN = dyn_cast<PHINode>(I); ++I) { - unsigned Idx = PN->getBasicBlockIndex(SplitBB); - Value *V = PN->getIncomingValue(Idx); + for (PHINode &PN : DestBB->phis()) { + unsigned Idx = PN.getBasicBlockIndex(SplitBB); + Value *V = PN.getIncomingValue(Idx); // If the input is a PHI which already satisfies LCSSA, don't create // a new one. @@ -119,13 +118,13 @@ static void createPHIsForSplitLoopExit(ArrayRef<BasicBlock *> Preds, // Otherwise a new PHI is needed. Create one and populate it. PHINode *NewPN = PHINode::Create( - PN->getType(), Preds.size(), "split", + PN.getType(), Preds.size(), "split", SplitBB->isLandingPad() ? &SplitBB->front() : SplitBB->getTerminator()); for (unsigned i = 0, e = Preds.size(); i != e; ++i) NewPN->addIncoming(V, Preds[i]); // Update the original PHI. 
- PN->setIncomingValue(Idx, NewPN); + PN.setIncomingValue(Idx, NewPN); } } diff --git a/lib/Transforms/Utils/BuildLibCalls.cpp b/lib/Transforms/Utils/BuildLibCalls.cpp index b60dfb4f3541..5f5c4150d3bb 100644 --- a/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/lib/Transforms/Utils/BuildLibCalls.cpp @@ -105,12 +105,23 @@ static bool setRetNonNull(Function &F) { return true; } +static bool setNonLazyBind(Function &F) { + if (F.hasFnAttribute(Attribute::NonLazyBind)) + return false; + F.addFnAttr(Attribute::NonLazyBind); + return true; +} + bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { LibFunc TheLibFunc; if (!(TLI.getLibFunc(F, TheLibFunc) && TLI.has(TheLibFunc))) return false; bool Changed = false; + + if (F.getParent() != nullptr && F.getParent()->getRtLibUseGOT()) + Changed |= setNonLazyBind(F); + switch (TheLibFunc) { case LibFunc_strlen: case LibFunc_wcslen: @@ -375,6 +386,7 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { case LibFunc_fseek: case LibFunc_ftell: case LibFunc_fgetc: + case LibFunc_fgetc_unlocked: case LibFunc_fseeko: case LibFunc_ftello: case LibFunc_fileno: @@ -393,6 +405,7 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { Changed |= setOnlyReadsMemory(F); return Changed; case LibFunc_fputc: + case LibFunc_fputc_unlocked: case LibFunc_fstat: case LibFunc_frexp: case LibFunc_frexpf: @@ -402,21 +415,25 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { Changed |= setDoesNotCapture(F, 1); return Changed; case LibFunc_fgets: + case LibFunc_fgets_unlocked: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 2); return Changed; case LibFunc_fread: + case LibFunc_fread_unlocked: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 3); return Changed; case LibFunc_fwrite: + case LibFunc_fwrite_unlocked: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 3); // FIXME: readonly #1? 
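The `setNonLazyBind` gate added at the top of `inferLibFuncAttributes` keys off the module-level `RtLibUseGOT` flag (set when the frontend requests GOT-based runtime-library calls, e.g. `-fno-plt` builds), marking recognized library functions so their calls avoid a lazy PLT stub. A condensed sketch of just that gate, assuming this era's headers:

```cpp
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
using namespace llvm;

static bool addNonLazyBindIfGOT(Function &F) {
  Module *M = F.getParent();
  if (!M || !M->getRtLibUseGOT())
    return false; // module not built for GOT-based runtime calls
  if (F.hasFnAttribute(Attribute::NonLazyBind))
    return false; // already marked; report "no change"
  F.addFnAttr(Attribute::NonLazyBind);
  return true;
}
```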
return Changed; case LibFunc_fputs: + case LibFunc_fputs_unlocked: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 0); Changed |= setDoesNotCapture(F, 1); @@ -447,6 +464,7 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { return Changed; case LibFunc_gets: case LibFunc_getchar: + case LibFunc_getchar_unlocked: Changed |= setDoesNotThrow(F); return Changed; case LibFunc_getitimer: @@ -485,6 +503,7 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { Changed |= setOnlyReadsMemory(F, 1); return Changed; case LibFunc_putc: + case LibFunc_putc_unlocked: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 1); return Changed; @@ -505,6 +524,7 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { Changed |= setOnlyReadsMemory(F, 1); return Changed; case LibFunc_putchar: + case LibFunc_putchar_unlocked: Changed |= setDoesNotThrow(F); return Changed; case LibFunc_popen: @@ -687,9 +707,9 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { Changed |= setRetNonNull(F); Changed |= setRetDoesNotAlias(F); return Changed; - //TODO: add LibFunc entries for: - //case LibFunc_memset_pattern4: - //case LibFunc_memset_pattern8: + // TODO: add LibFunc entries for: + // case LibFunc_memset_pattern4: + // case LibFunc_memset_pattern8: case LibFunc_memset_pattern16: Changed |= setOnlyAccessesArgMemory(F); Changed |= setDoesNotCapture(F, 0); @@ -709,6 +729,19 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { } } +bool llvm::hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty, + LibFunc DoubleFn, LibFunc FloatFn, + LibFunc LongDoubleFn) { + switch (Ty->getTypeID()) { + case Type::FloatTyID: + return TLI->has(FloatFn); + case Type::DoubleTyID: + return TLI->has(DoubleFn); + default: + return TLI->has(LongDoubleFn); + } +} + //- Emit LibCalls ------------------------------------------------------------// Value *llvm::castToCStr(Value *V, IRBuilder<> &B) { @@ -973,6 +1006,24 @@ Value *llvm::emitFPutC(Value *Char, Value *File, IRBuilder<> &B, return CI; } +Value *llvm::emitFPutCUnlocked(Value *Char, Value *File, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { + if (!TLI->has(LibFunc_fputc_unlocked)) + return nullptr; + + Module *M = B.GetInsertBlock()->getModule(); + Constant *F = M->getOrInsertFunction("fputc_unlocked", B.getInt32Ty(), + B.getInt32Ty(), File->getType()); + if (File->getType()->isPointerTy()) + inferLibFuncAttributes(*M->getFunction("fputc_unlocked"), *TLI); + Char = B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/ true, "chari"); + CallInst *CI = B.CreateCall(F, {Char, File}, "fputc_unlocked"); + + if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + CI->setCallingConv(Fn->getCallingConv()); + return CI; +} + Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilder<> &B, const TargetLibraryInfo *TLI) { if (!TLI->has(LibFunc_fputs)) @@ -991,6 +1042,24 @@ Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilder<> &B, return CI; } +Value *llvm::emitFPutSUnlocked(Value *Str, Value *File, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { + if (!TLI->has(LibFunc_fputs_unlocked)) + return nullptr; + + Module *M = B.GetInsertBlock()->getModule(); + StringRef FPutsUnlockedName = TLI->getName(LibFunc_fputs_unlocked); + Constant *F = M->getOrInsertFunction(FPutsUnlockedName, B.getInt32Ty(), + B.getInt8PtrTy(), File->getType()); + if (File->getType()->isPointerTy()) + 
inferLibFuncAttributes(*M->getFunction(FPutsUnlockedName), *TLI); + CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), File}, "fputs_unlocked"); + + if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + CI->setCallingConv(Fn->getCallingConv()); + return CI; +} + Value *llvm::emitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilder<> &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { if (!TLI->has(LibFunc_fwrite)) @@ -1013,3 +1082,119 @@ Value *llvm::emitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilder<> &B, CI->setCallingConv(Fn->getCallingConv()); return CI; } + +Value *llvm::emitMalloc(Value *Num, IRBuilder<> &B, const DataLayout &DL, + const TargetLibraryInfo *TLI) { + if (!TLI->has(LibFunc_malloc)) + return nullptr; + + Module *M = B.GetInsertBlock()->getModule(); + LLVMContext &Context = B.GetInsertBlock()->getContext(); + Value *Malloc = M->getOrInsertFunction("malloc", B.getInt8PtrTy(), + DL.getIntPtrType(Context)); + inferLibFuncAttributes(*M->getFunction("malloc"), *TLI); + CallInst *CI = B.CreateCall(Malloc, Num, "malloc"); + + if (const Function *F = dyn_cast<Function>(Malloc->stripPointerCasts())) + CI->setCallingConv(F->getCallingConv()); + + return CI; +} + +Value *llvm::emitCalloc(Value *Num, Value *Size, const AttributeList &Attrs, + IRBuilder<> &B, const TargetLibraryInfo &TLI) { + if (!TLI.has(LibFunc_calloc)) + return nullptr; + + Module *M = B.GetInsertBlock()->getModule(); + const DataLayout &DL = M->getDataLayout(); + IntegerType *PtrType = DL.getIntPtrType((B.GetInsertBlock()->getContext())); + Value *Calloc = M->getOrInsertFunction("calloc", Attrs, B.getInt8PtrTy(), + PtrType, PtrType); + inferLibFuncAttributes(*M->getFunction("calloc"), TLI); + CallInst *CI = B.CreateCall(Calloc, {Num, Size}, "calloc"); + + if (const auto *F = dyn_cast<Function>(Calloc->stripPointerCasts())) + CI->setCallingConv(F->getCallingConv()); + + return CI; +} + +Value *llvm::emitFWriteUnlocked(Value *Ptr, Value *Size, Value *N, Value *File, + IRBuilder<> &B, const DataLayout &DL, + const TargetLibraryInfo *TLI) { + if (!TLI->has(LibFunc_fwrite_unlocked)) + return nullptr; + + Module *M = B.GetInsertBlock()->getModule(); + LLVMContext &Context = B.GetInsertBlock()->getContext(); + StringRef FWriteUnlockedName = TLI->getName(LibFunc_fwrite_unlocked); + Constant *F = M->getOrInsertFunction( + FWriteUnlockedName, DL.getIntPtrType(Context), B.getInt8PtrTy(), + DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType()); + + if (File->getType()->isPointerTy()) + inferLibFuncAttributes(*M->getFunction(FWriteUnlockedName), *TLI); + CallInst *CI = B.CreateCall(F, {castToCStr(Ptr, B), Size, N, File}); + + if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + CI->setCallingConv(Fn->getCallingConv()); + return CI; +} + +Value *llvm::emitFGetCUnlocked(Value *File, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { + if (!TLI->has(LibFunc_fgetc_unlocked)) + return nullptr; + + Module *M = B.GetInsertBlock()->getModule(); + Constant *F = + M->getOrInsertFunction("fgetc_unlocked", B.getInt32Ty(), File->getType()); + if (File->getType()->isPointerTy()) + inferLibFuncAttributes(*M->getFunction("fgetc_unlocked"), *TLI); + CallInst *CI = B.CreateCall(F, File, "fgetc_unlocked"); + + if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + CI->setCallingConv(Fn->getCallingConv()); + return CI; +} + +Value *llvm::emitFGetSUnlocked(Value *Str, Value *Size, Value *File, + IRBuilder<> &B, const TargetLibraryInfo *TLI) { + if 
(!TLI->has(LibFunc_fgets_unlocked)) + return nullptr; + + Module *M = B.GetInsertBlock()->getModule(); + Constant *F = + M->getOrInsertFunction("fgets_unlocked", B.getInt8PtrTy(), + B.getInt8PtrTy(), B.getInt32Ty(), File->getType()); + inferLibFuncAttributes(*M->getFunction("fgets_unlocked"), *TLI); + CallInst *CI = + B.CreateCall(F, {castToCStr(Str, B), Size, File}, "fgets_unlocked"); + + if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + CI->setCallingConv(Fn->getCallingConv()); + return CI; +} + +Value *llvm::emitFReadUnlocked(Value *Ptr, Value *Size, Value *N, Value *File, + IRBuilder<> &B, const DataLayout &DL, + const TargetLibraryInfo *TLI) { + if (!TLI->has(LibFunc_fread_unlocked)) + return nullptr; + + Module *M = B.GetInsertBlock()->getModule(); + LLVMContext &Context = B.GetInsertBlock()->getContext(); + StringRef FReadUnlockedName = TLI->getName(LibFunc_fread_unlocked); + Constant *F = M->getOrInsertFunction( + FReadUnlockedName, DL.getIntPtrType(Context), B.getInt8PtrTy(), + DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType()); + + if (File->getType()->isPointerTy()) + inferLibFuncAttributes(*M->getFunction(FReadUnlockedName), *TLI); + CallInst *CI = B.CreateCall(F, {castToCStr(Ptr, B), Size, N, File}); + + if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) + CI->setCallingConv(Fn->getCallingConv()); + return CI; +} diff --git a/lib/Transforms/Utils/BypassSlowDivision.cpp b/lib/Transforms/Utils/BypassSlowDivision.cpp index f711b192f604..05512a6dff3e 100644 --- a/lib/Transforms/Utils/BypassSlowDivision.cpp +++ b/lib/Transforms/Utils/BypassSlowDivision.cpp @@ -21,6 +21,7 @@ #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -34,7 +35,6 @@ #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" #include "llvm/Support/KnownBits.h" -#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> @@ -173,7 +173,7 @@ Value *FastDivInsertionTask::getReplacement(DivCacheTy &Cache) { return isDivisionOp() ? Value.Quotient : Value.Remainder; } -/// \brief Check if a value looks like a hash. +/// Check if a value looks like a hash. /// /// The routine is expected to detect values computed using the most common hash /// algorithms. 
Typically, hash computations end with one of the following diff --git a/lib/Transforms/Utils/CMakeLists.txt b/lib/Transforms/Utils/CMakeLists.txt index 972e47f9270a..c87b74f739f4 100644 --- a/lib/Transforms/Utils/CMakeLists.txt +++ b/lib/Transforms/Utils/CMakeLists.txt @@ -25,8 +25,10 @@ add_llvm_library(LLVMTransformUtils LCSSA.cpp LibCallsShrinkWrap.cpp Local.cpp + LoopRotationUtils.cpp LoopSimplify.cpp LoopUnroll.cpp + LoopUnrollAndJam.cpp LoopUnrollPeel.cpp LoopUnrollRuntime.cpp LoopUtils.cpp @@ -43,10 +45,10 @@ add_llvm_library(LLVMTransformUtils PromoteMemoryToRegister.cpp StripGCRelocates.cpp SSAUpdater.cpp + SSAUpdaterBulk.cpp SanitizerStats.cpp SimplifyCFG.cpp SimplifyIndVar.cpp - SimplifyInstructions.cpp SimplifyLibCalls.cpp SplitModule.cpp StripNonLineTableDebugInfo.cpp diff --git a/lib/Transforms/Utils/CallPromotionUtils.cpp b/lib/Transforms/Utils/CallPromotionUtils.cpp index 8825f77555e7..4d9c22e57a68 100644 --- a/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -47,14 +47,11 @@ using namespace llvm; /// static void fixupPHINodeForNormalDest(InvokeInst *Invoke, BasicBlock *OrigBlock, BasicBlock *MergeBlock) { - for (auto &I : *Invoke->getNormalDest()) { - auto *Phi = dyn_cast<PHINode>(&I); - if (!Phi) - break; - int Idx = Phi->getBasicBlockIndex(OrigBlock); + for (PHINode &Phi : Invoke->getNormalDest()->phis()) { + int Idx = Phi.getBasicBlockIndex(OrigBlock); if (Idx == -1) continue; - Phi->setIncomingBlock(Idx, MergeBlock); + Phi.setIncomingBlock(Idx, MergeBlock); } } @@ -82,16 +79,13 @@ static void fixupPHINodeForNormalDest(InvokeInst *Invoke, BasicBlock *OrigBlock, static void fixupPHINodeForUnwindDest(InvokeInst *Invoke, BasicBlock *OrigBlock, BasicBlock *ThenBlock, BasicBlock *ElseBlock) { - for (auto &I : *Invoke->getUnwindDest()) { - auto *Phi = dyn_cast<PHINode>(&I); - if (!Phi) - break; - int Idx = Phi->getBasicBlockIndex(OrigBlock); + for (PHINode &Phi : Invoke->getUnwindDest()->phis()) { + int Idx = Phi.getBasicBlockIndex(OrigBlock); if (Idx == -1) continue; - auto *V = Phi->getIncomingValue(Idx); - Phi->setIncomingBlock(Idx, ThenBlock); - Phi->addIncoming(V, ElseBlock); + auto *V = Phi.getIncomingValue(Idx); + Phi.setIncomingBlock(Idx, ThenBlock); + Phi.addIncoming(V, ElseBlock); } } @@ -395,12 +389,14 @@ Instruction *llvm::promoteCall(CallSite CS, Function *Callee, // Inspect the arguments of the call site. If an argument's type doesn't // match the corresponding formal argument's type in the callee, bitcast it // to the correct type. 
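The `promoteCall` rewrite below bounds the cast loop by the callee's formal parameter count instead of walking every call-site argument, so a call that passes extra arguments no longer indexes past the formal parameter list. A standalone sketch of the bounds change with mock types (type names are illustrative only):

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> FormalTys = {"i32", "i8*"};            // callee formals
  std::vector<std::string> ActualTys = {"i32", "i32*", "double"}; // call args
  // Iterate only over the callee's formals, as the new loop does.
  for (unsigned ArgNo = 0; ArgNo < FormalTys.size(); ++ArgNo)
    if (FormalTys[ArgNo] != ActualTys[ArgNo])
      std::cout << "bitcast arg " << ArgNo << " to " << FormalTys[ArgNo] << '\n';
  // The trailing extra argument is left untouched rather than read
  // out of bounds of the formal parameter list.
}
```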
- for (Use &U : CS.args()) { - unsigned ArgNo = CS.getArgumentNo(&U); - Type *FormalTy = Callee->getFunctionType()->getParamType(ArgNo); - Type *ActualTy = U.get()->getType(); + auto CalleeType = Callee->getFunctionType(); + auto CalleeParamNum = CalleeType->getNumParams(); + for (unsigned ArgNo = 0; ArgNo < CalleeParamNum; ++ArgNo) { + auto *Arg = CS.getArgument(ArgNo); + Type *FormalTy = CalleeType->getParamType(ArgNo); + Type *ActualTy = Arg->getType(); if (FormalTy != ActualTy) { - auto *Cast = CastInst::Create(Instruction::BitCast, U.get(), FormalTy, "", + auto *Cast = CastInst::Create(Instruction::BitCast, Arg, FormalTy, "", CS.getInstruction()); CS.setArgument(ArgNo, Cast); } diff --git a/lib/Transforms/Utils/CloneFunction.cpp b/lib/Transforms/Utils/CloneFunction.cpp index 3b19ba1b50f2..61448e9acb57 100644 --- a/lib/Transforms/Utils/CloneFunction.cpp +++ b/lib/Transforms/Utils/CloneFunction.cpp @@ -18,6 +18,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfo.h" @@ -31,7 +32,6 @@ #include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <map> using namespace llvm; @@ -43,44 +43,36 @@ BasicBlock *llvm::CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap, DebugInfoFinder *DIFinder) { DenseMap<const MDNode *, MDNode *> Cache; BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), "", F); - if (BB->hasName()) NewBB->setName(BB->getName()+NameSuffix); + if (BB->hasName()) + NewBB->setName(BB->getName() + NameSuffix); bool hasCalls = false, hasDynamicAllocas = false, hasStaticAllocas = false; Module *TheModule = F ? F->getParent() : nullptr; // Loop over all instructions, and copy them over. - for (BasicBlock::const_iterator II = BB->begin(), IE = BB->end(); - II != IE; ++II) { - - if (DIFinder && TheModule) { - if (auto *DDI = dyn_cast<DbgDeclareInst>(II)) - DIFinder->processDeclare(*TheModule, DDI); - else if (auto *DVI = dyn_cast<DbgValueInst>(II)) - DIFinder->processValue(*TheModule, DVI); + for (const Instruction &I : *BB) { + if (DIFinder && TheModule) + DIFinder->processInstruction(*TheModule, I); - if (auto DbgLoc = II->getDebugLoc()) - DIFinder->processLocation(*TheModule, DbgLoc.get()); - } - - Instruction *NewInst = II->clone(); - if (II->hasName()) - NewInst->setName(II->getName()+NameSuffix); + Instruction *NewInst = I.clone(); + if (I.hasName()) + NewInst->setName(I.getName() + NameSuffix); NewBB->getInstList().push_back(NewInst); - VMap[&*II] = NewInst; // Add instruction map to value. + VMap[&I] = NewInst; // Add instruction map to value. 
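// The VMap entry recorded above is what later operand remapping keys off of;
// a minimal sketch of the usual clone-then-remap pairing, assuming the
// standard ValueMapper helpers:
//
//   ValueToValueMapTy VMap;
//   BasicBlock *Copy = CloneBasicBlock(BB, VMap, ".copy", BB->getParent());
//   for (Instruction &I : *Copy)
//     RemapInstruction(&I, VMap,
//                      RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
//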
- hasCalls |= (isa<CallInst>(II) && !isa<DbgInfoIntrinsic>(II)); - if (const AllocaInst *AI = dyn_cast<AllocaInst>(II)) { + hasCalls |= (isa<CallInst>(I) && !isa<DbgInfoIntrinsic>(I)); + if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) { if (isa<ConstantInt>(AI->getArraySize())) hasStaticAllocas = true; else hasDynamicAllocas = true; } } - + if (CodeInfo) { CodeInfo->ContainsCalls |= hasCalls; CodeInfo->ContainsDynamicAllocas |= hasDynamicAllocas; - CodeInfo->ContainsDynamicAllocas |= hasStaticAllocas && + CodeInfo->ContainsDynamicAllocas |= hasStaticAllocas && BB != &BB->getParent()->getEntryBlock(); } return NewBB; @@ -175,7 +167,7 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, // Create a new basic block and copy instructions into it! BasicBlock *CBB = CloneBasicBlock(&BB, VMap, NameSuffix, NewFunc, CodeInfo, - SP ? &DIFinder : nullptr); + ModuleLevelChanges ? &DIFinder : nullptr); // Add basic block mapping. VMap[&BB] = CBB; @@ -197,15 +189,15 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, Returns.push_back(RI); } - for (DISubprogram *ISP : DIFinder.subprograms()) { - if (ISP != SP) { + for (DISubprogram *ISP : DIFinder.subprograms()) + if (ISP != SP) VMap.MD()[ISP].reset(ISP); - } - } - for (auto *Type : DIFinder.types()) { + for (DICompileUnit *CU : DIFinder.compile_units()) + VMap.MD()[CU].reset(CU); + + for (DIType *Type : DIFinder.types()) VMap.MD()[Type].reset(Type); - } // Loop over all of the instructions in the function, fixing up operand // references as we go. This uses VMap to do all the hard work. @@ -283,7 +275,7 @@ namespace { /// The specified block is found to be reachable, clone it and /// anything that it can reach. - void CloneBlock(const BasicBlock *BB, + void CloneBlock(const BasicBlock *BB, BasicBlock::const_iterator StartingInst, std::vector<const BasicBlock*> &ToClone); }; @@ -493,17 +485,13 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, // Handle PHI nodes specially, as we have to remove references to dead // blocks. - for (BasicBlock::const_iterator I = BI.begin(), E = BI.end(); I != E; ++I) { + for (const PHINode &PN : BI.phis()) { // PHI nodes may have been remapped to non-PHI nodes by the caller or // during the cloning process. - if (const PHINode *PN = dyn_cast<PHINode>(I)) { - if (isa<PHINode>(VMap[PN])) - PHIToResolve.push_back(PN); - else - break; - } else { + if (isa<PHINode>(VMap[&PN])) + PHIToResolve.push_back(&PN); + else break; - } } // Finally, remap the terminator instructions, as those can't be remapped @@ -550,7 +538,7 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, // phi nodes will have invalid entries. Update the PHI nodes in this // case. PHINode *PN = cast<PHINode>(NewBB->begin()); - NumPreds = std::distance(pred_begin(NewBB), pred_end(NewBB)); + NumPreds = pred_size(NewBB); if (NumPreds != PN->getNumIncomingValues()) { assert(NumPreds < PN->getNumIncomingValues()); // Count how many times each predecessor comes to this block. @@ -722,7 +710,7 @@ void llvm::CloneAndPruneFunctionInto(Function *NewFunc, const Function *OldFunc, ModuleLevelChanges, Returns, NameSuffix, CodeInfo); } -/// \brief Remaps instructions in \p Blocks using the mapping in \p VMap. +/// Remaps instructions in \p Blocks using the mapping in \p VMap. void llvm::remapInstructionsInBlocks( const SmallVectorImpl<BasicBlock *> &Blocks, ValueToValueMapTy &VMap) { // Rewrite the code to refer to itself. 
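// (Typical client, hedged: cloneLoopWithPreheader() below fills in VMap and
// the block list but leaves cloned operands still pointing into the original
// loop, so callers finish the job like this:)
//
//   SmallVector<BasicBlock *, 8> NewBlocks;
//   Loop *Copy = cloneLoopWithPreheader(InsertBefore, LoopDomBB, OrigLoop,
//                                       VMap, ".cl", LI, DT, NewBlocks);
//   remapInstructionsInBlocks(NewBlocks, VMap);
//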
@@ -732,7 +720,7 @@ void llvm::remapInstructionsInBlocks( RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); } -/// \brief Clones a loop \p OrigLoop. Returns the loop and the blocks in \p +/// Clones a loop \p OrigLoop. Returns the loop and the blocks in \p /// Blocks. /// /// Updates LoopInfo and DominatorTree assuming the loop is dominated by block @@ -796,12 +784,13 @@ Loop *llvm::cloneLoopWithPreheader(BasicBlock *Before, BasicBlock *LoopDomBB, return NewLoop; } -/// \brief Duplicate non-Phi instructions from the beginning of block up to +/// Duplicate non-Phi instructions from the beginning of block up to /// StopAt instruction into a split block between BB and its predecessor. BasicBlock * llvm::DuplicateInstructionsInSplitBetween(BasicBlock *BB, BasicBlock *PredBB, Instruction *StopAt, - ValueToValueMapTy &ValueMapping) { + ValueToValueMapTy &ValueMapping, + DominatorTree *DT) { // We are going to have to map operands from the original BB block to the new // copy of the block 'NewBB'. If there are PHI nodes in BB, evaluate them to // account for entry from PredBB. @@ -809,13 +798,15 @@ llvm::DuplicateInstructionsInSplitBetween(BasicBlock *BB, BasicBlock *PredBB, for (; PHINode *PN = dyn_cast<PHINode>(BI); ++BI) ValueMapping[PN] = PN->getIncomingValueForBlock(PredBB); - BasicBlock *NewBB = SplitEdge(PredBB, BB); + BasicBlock *NewBB = SplitEdge(PredBB, BB, DT); NewBB->setName(PredBB->getName() + ".split"); Instruction *NewTerm = NewBB->getTerminator(); // Clone the non-phi instructions of BB into NewBB, keeping track of the // mapping and using it to remap operands in the cloned instructions. - for (; StopAt != &*BI; ++BI) { + // Stop once we see the terminator too. This covers the case where BB's + // terminator gets replaced and StopAt == BB's terminator. + for (; StopAt != &*BI && BB->getTerminator() != &*BI; ++BI) { Instruction *New = BI->clone(); New->setName(BI->getName()); New->insertBefore(NewTerm); diff --git a/lib/Transforms/Utils/CloneModule.cpp b/lib/Transforms/Utils/CloneModule.cpp index 8fee10854229..35c7511a24b9 100644 --- a/lib/Transforms/Utils/CloneModule.cpp +++ b/lib/Transforms/Utils/CloneModule.cpp @@ -32,33 +32,34 @@ static void copyComdat(GlobalObject *Dst, const GlobalObject *Src) { /// copies of global variables and functions, and making their (initializers and /// references, respectively) refer to the right globals. /// -std::unique_ptr<Module> llvm::CloneModule(const Module *M) { +std::unique_ptr<Module> llvm::CloneModule(const Module &M) { // Create the value map that maps things from the old module over to the new // module. ValueToValueMapTy VMap; return CloneModule(M, VMap); } -std::unique_ptr<Module> llvm::CloneModule(const Module *M, +std::unique_ptr<Module> llvm::CloneModule(const Module &M, ValueToValueMapTy &VMap) { return CloneModule(M, VMap, [](const GlobalValue *GV) { return true; }); } std::unique_ptr<Module> llvm::CloneModule( - const Module *M, ValueToValueMapTy &VMap, + const Module &M, ValueToValueMapTy &VMap, function_ref<bool(const GlobalValue *)> ShouldCloneDefinition) { // First off, we need to create the new module. 
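// (API note: the hunk above changes CloneModule() to take the module by
// reference; a caller migrating from the old pointer form, sketched:)
//
//   ValueToValueMapTy VMap;
//   std::unique_ptr<Module> Copy = CloneModule(*M, VMap); // was CloneModule(M, VMap)
//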
std::unique_ptr<Module> New = - llvm::make_unique<Module>(M->getModuleIdentifier(), M->getContext()); - New->setDataLayout(M->getDataLayout()); - New->setTargetTriple(M->getTargetTriple()); - New->setModuleInlineAsm(M->getModuleInlineAsm()); - + llvm::make_unique<Module>(M.getModuleIdentifier(), M.getContext()); + New->setSourceFileName(M.getSourceFileName()); + New->setDataLayout(M.getDataLayout()); + New->setTargetTriple(M.getTargetTriple()); + New->setModuleInlineAsm(M.getModuleInlineAsm()); + // Loop over all of the global variables, making corresponding globals in the // new module. Here we add them to the VMap and to the new Module. We // don't worry about attributes or initializers, they will come later. // - for (Module::const_global_iterator I = M->global_begin(), E = M->global_end(); + for (Module::const_global_iterator I = M.global_begin(), E = M.global_end(); I != E; ++I) { GlobalVariable *GV = new GlobalVariable(*New, I->getValueType(), @@ -72,7 +73,7 @@ std::unique_ptr<Module> llvm::CloneModule( } // Loop over the functions in the module, making external functions as before - for (const Function &I : *M) { + for (const Function &I : M) { Function *NF = Function::Create(cast<FunctionType>(I.getValueType()), I.getLinkage(), I.getName(), New.get()); NF->copyAttributesFrom(&I); @@ -80,7 +81,7 @@ std::unique_ptr<Module> llvm::CloneModule( } // Loop over the aliases in the module - for (Module::const_alias_iterator I = M->alias_begin(), E = M->alias_end(); + for (Module::const_alias_iterator I = M.alias_begin(), E = M.alias_end(); I != E; ++I) { if (!ShouldCloneDefinition(&*I)) { // An alias cannot act as an external reference, so we need to create @@ -114,7 +115,7 @@ std::unique_ptr<Module> llvm::CloneModule( // have been created, loop through and copy the global variable referrers // over... We also set the attributes on the global now. // - for (Module::const_global_iterator I = M->global_begin(), E = M->global_end(); + for (Module::const_global_iterator I = M.global_begin(), E = M.global_end(); I != E; ++I) { if (I->isDeclaration()) continue; @@ -139,7 +140,7 @@ std::unique_ptr<Module> llvm::CloneModule( // Similarly, copy over function bodies now... // - for (const Function &I : *M) { + for (const Function &I : M) { if (I.isDeclaration()) continue; @@ -169,7 +170,7 @@ std::unique_ptr<Module> llvm::CloneModule( } // And aliases - for (Module::const_alias_iterator I = M->alias_begin(), E = M->alias_end(); + for (Module::const_alias_iterator I = M.alias_begin(), E = M.alias_end(); I != E; ++I) { // We already dealt with undefined aliases above. if (!ShouldCloneDefinition(&*I)) @@ -180,8 +181,9 @@ std::unique_ptr<Module> llvm::CloneModule( } // And named metadata.... 
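// (Equivalent range-for form of the copy below, hedged; MapMetadata pulls
// remapped operands out of VMap so references into the old module get
// rewritten to the clone:)
//
//   for (const NamedMDNode &NMD : M.named_metadata()) {
//     NamedMDNode *NewNMD = New->getOrInsertNamedMetadata(NMD.getName());
//     for (const MDNode *Op : NMD.operands())
//       NewNMD->addOperand(MapMetadata(Op, VMap));
//   }
//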
- for (Module::const_named_metadata_iterator I = M->named_metadata_begin(), - E = M->named_metadata_end(); I != E; ++I) { + for (Module::const_named_metadata_iterator I = M.named_metadata_begin(), + E = M.named_metadata_end(); + I != E; ++I) { const NamedMDNode &NMD = *I; NamedMDNode *NewNMD = New->getOrInsertNamedMetadata(NMD.getName()); for (unsigned i = 0, e = NMD.getNumOperands(); i != e; ++i) @@ -194,7 +196,7 @@ std::unique_ptr<Module> llvm::CloneModule( extern "C" { LLVMModuleRef LLVMCloneModule(LLVMModuleRef M) { - return wrap(CloneModule(unwrap(M)).release()); + return wrap(CloneModule(*unwrap(M)).release()); } } diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp index 7a404241cb14..f31dab9f96af 100644 --- a/lib/Transforms/Utils/CodeExtractor.cpp +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -66,6 +66,7 @@ #include <vector> using namespace llvm; +using ProfileCount = Function::ProfileCount; #define DEBUG_TYPE "code-extractor" @@ -77,12 +78,10 @@ static cl::opt<bool> AggregateArgsOpt("aggregate-extracted-args", cl::Hidden, cl::desc("Aggregate arguments to code-extracted functions")); -/// \brief Test whether a block is valid for extraction. -bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB, - bool AllowVarArgs) { - // Landing pads must be in the function where they were inserted for cleanup. - if (BB.isEHPad()) - return false; +/// Test whether a block is valid for extraction. +static bool isBlockValidForExtraction(const BasicBlock &BB, + const SetVector<BasicBlock *> &Result, + bool AllowVarArgs, bool AllowAlloca) { // taking the address of a basic block moved to another function is illegal if (BB.hasAddressTaken()) return false; @@ -111,11 +110,63 @@ bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB, } } - // Don't hoist code containing allocas or invokes. If explicitly requested, - // allow vastart. + // If explicitly requested, allow vastart and alloca. For invoke instructions + // verify that extraction is valid. for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) { - if (isa<AllocaInst>(I) || isa<InvokeInst>(I)) - return false; + if (isa<AllocaInst>(I)) { + if (!AllowAlloca) + return false; + continue; + } + + if (const auto *II = dyn_cast<InvokeInst>(I)) { + // Unwind destination (either a landingpad, catchswitch, or cleanuppad) + // must be a part of the subgraph which is being extracted. + if (auto *UBB = II->getUnwindDest()) + if (!Result.count(UBB)) + return false; + continue; + } + + // All catch handlers of a catchswitch instruction as well as the unwind + // destination must be in the subgraph. + if (const auto *CSI = dyn_cast<CatchSwitchInst>(I)) { + if (auto *UBB = CSI->getUnwindDest()) + if (!Result.count(UBB)) + return false; + for (auto *HBB : CSI->handlers()) + if (!Result.count(const_cast<BasicBlock*>(HBB))) + return false; + continue; + } + + // Make sure that entire catch handler is within subgraph. It is sufficient + // to check that catch return's block is in the list. + if (const auto *CPI = dyn_cast<CatchPadInst>(I)) { + for (const auto *U : CPI->users()) + if (const auto *CRI = dyn_cast<CatchReturnInst>(U)) + if (!Result.count(const_cast<BasicBlock*>(CRI->getParent()))) + return false; + continue; + } + + // And do similar checks for cleanup handler - the entire handler must be + // in subgraph which is going to be extracted. For cleanup return should + // additionally check that the unwind destination is also in the subgraph. 
+ if (const auto *CPI = dyn_cast<CleanupPadInst>(I)) { + for (const auto *U : CPI->users()) + if (const auto *CRI = dyn_cast<CleanupReturnInst>(U)) + if (!Result.count(const_cast<BasicBlock*>(CRI->getParent()))) + return false; + continue; + } + if (const auto *CRI = dyn_cast<CleanupReturnInst>(I)) { + if (auto *UBB = CRI->getUnwindDest()) + if (!Result.count(UBB)) + return false; + continue; + } + if (const CallInst *CI = dyn_cast<CallInst>(I)) if (const Function *F = CI->getCalledFunction()) if (F->getIntrinsicID() == Intrinsic::vastart) { @@ -129,10 +180,10 @@ bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB, return true; } -/// \brief Build a set of blocks to extract if the input blocks are viable. +/// Build a set of blocks to extract if the input blocks are viable. static SetVector<BasicBlock *> buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, - bool AllowVarArgs) { + bool AllowVarArgs, bool AllowAlloca) { assert(!BBs.empty() && "The set of blocks to extract must be non-empty"); SetVector<BasicBlock *> Result; @@ -145,32 +196,42 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, if (!Result.insert(BB)) llvm_unreachable("Repeated basic blocks in extraction input"); - if (!CodeExtractor::isBlockValidForExtraction(*BB, AllowVarArgs)) { - Result.clear(); - return Result; - } } -#ifndef NDEBUG - for (SetVector<BasicBlock *>::iterator I = std::next(Result.begin()), - E = Result.end(); - I != E; ++I) - for (pred_iterator PI = pred_begin(*I), PE = pred_end(*I); - PI != PE; ++PI) - assert(Result.count(*PI) && - "No blocks in this region may have entries from outside the region" - " except for the first block!"); -#endif + for (auto *BB : Result) { + if (!isBlockValidForExtraction(*BB, Result, AllowVarArgs, AllowAlloca)) + return {}; + + // Make sure that the first block is not a landing pad. + if (BB == Result.front()) { + if (BB->isEHPad()) { + LLVM_DEBUG(dbgs() << "The first block cannot be an unwind block\n"); + return {}; + } + continue; + } + + // All blocks other than the first must not have predecessors outside of + // the subgraph which is being extracted. + for (auto *PBB : predecessors(BB)) + if (!Result.count(PBB)) { + LLVM_DEBUG( + dbgs() << "No blocks in this region may have entries from " + "outside the region except for the first block!\n"); + return {}; + } + } return Result; } CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, bool AggregateArgs, BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI, bool AllowVarArgs) + BranchProbabilityInfo *BPI, bool AllowVarArgs, + bool AllowAlloca) : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), BPI(BPI), AllowVarArgs(AllowVarArgs), - Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs)) {} + Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)) {} CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, BlockFrequencyInfo *BFI, @@ -178,7 +239,8 @@ CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), BPI(BPI), AllowVarArgs(false), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT, - /* AllowVarArgs */ false)) {} + /* AllowVarArgs */ false, + /* AllowAlloca */ false)) {} /// definedInRegion - Return true if the specified value is defined in the /// extracted region. 
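A caller-side sketch of the widened constructor (hedged; AllowAlloca is the new trailing flag, the other arguments are as in the signature shown above):

    // Extract a region that may now contain allocas and invokes, provided
    // each EH construct is wholly inside the region.
    CodeExtractor CE(Blocks, &DT, /*AggregateArgs=*/false, /*BFI=*/nullptr,
                     /*BPI=*/nullptr, /*AllowVarArgs=*/false,
                     /*AllowAlloca=*/true);
    if (CE.isEligible()) {
      Function *Outlined = CE.extractCodeRegion();
      // Outlined may still be nullptr if extraction fails for other reasons.
      (void)Outlined;
    }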
@@ -562,8 +624,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, BasicBlock *newHeader, Function *oldFunction, Module *M) { - DEBUG(dbgs() << "inputs: " << inputs.size() << "\n"); - DEBUG(dbgs() << "outputs: " << outputs.size() << "\n"); + LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n"); + LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n"); // This function returns unsigned, outputs will go back by reference. switch (NumExitBlocks) { @@ -577,20 +639,20 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, // Add the types of the input values to the function's argument list for (Value *value : inputs) { - DEBUG(dbgs() << "value used in func: " << *value << "\n"); + LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n"); paramTy.push_back(value->getType()); } // Add the types of the output values to the function's argument list. for (Value *output : outputs) { - DEBUG(dbgs() << "instr used in func: " << *output << "\n"); + LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n"); if (AggregateArgs) paramTy.push_back(output->getType()); else paramTy.push_back(PointerType::getUnqual(output->getType())); } - DEBUG({ + LLVM_DEBUG({ dbgs() << "Function type: " << *RetTy << " f("; for (Type *i : paramTy) dbgs() << *i << ", "; @@ -620,16 +682,89 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, if (oldFunction->hasUWTable()) newFunction->setHasUWTable(); - // Inherit all of the target dependent attributes. + // Inherit all of the target dependent attributes and white-listed + // target independent attributes. // (e.g. If the extracted region contains a call to an x86.sse // instruction we need to make sure that the extracted region has the // "target-features" attribute allowing it to be lowered. // FIXME: This should be changed to check to see if a specific // attribute can not be inherited. - AttrBuilder AB(oldFunction->getAttributes().getFnAttributes()); - for (const auto &Attr : AB.td_attrs()) - newFunction->addFnAttr(Attr.first, Attr.second); + for (const auto &Attr : oldFunction->getAttributes().getFnAttributes()) { + if (Attr.isStringAttribute()) { + if (Attr.getKindAsString() == "thunk") + continue; + } else + switch (Attr.getKindAsEnum()) { + // Those attributes cannot be propagated safely. Explicitly list them + // here so we get a warning if new attributes are added. This list also + // includes non-function attributes. + case Attribute::Alignment: + case Attribute::AllocSize: + case Attribute::ArgMemOnly: + case Attribute::Builtin: + case Attribute::ByVal: + case Attribute::Convergent: + case Attribute::Dereferenceable: + case Attribute::DereferenceableOrNull: + case Attribute::InAlloca: + case Attribute::InReg: + case Attribute::InaccessibleMemOnly: + case Attribute::InaccessibleMemOrArgMemOnly: + case Attribute::JumpTable: + case Attribute::Naked: + case Attribute::Nest: + case Attribute::NoAlias: + case Attribute::NoBuiltin: + case Attribute::NoCapture: + case Attribute::NoReturn: + case Attribute::None: + case Attribute::NonNull: + case Attribute::ReadNone: + case Attribute::ReadOnly: + case Attribute::Returned: + case Attribute::ReturnsTwice: + case Attribute::SExt: + case Attribute::Speculatable: + case Attribute::StackAlignment: + case Attribute::StructRet: + case Attribute::SwiftError: + case Attribute::SwiftSelf: + case Attribute::WriteOnly: + case Attribute::ZExt: + case Attribute::EndAttrKinds: + continue; + // Those attributes should be safe to propagate to the extracted function. 
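// (Design note: this switch deliberately has no default case, so a newly
// added Attribute enum value triggers a covered-switch warning here and
// forces an explicit safe/unsafe classification.)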
+ case Attribute::AlwaysInline: + case Attribute::Cold: + case Attribute::NoRecurse: + case Attribute::InlineHint: + case Attribute::MinSize: + case Attribute::NoDuplicate: + case Attribute::NoImplicitFloat: + case Attribute::NoInline: + case Attribute::NonLazyBind: + case Attribute::NoRedZone: + case Attribute::NoUnwind: + case Attribute::OptForFuzzing: + case Attribute::OptimizeNone: + case Attribute::OptimizeForSize: + case Attribute::SafeStack: + case Attribute::ShadowCallStack: + case Attribute::SanitizeAddress: + case Attribute::SanitizeMemory: + case Attribute::SanitizeThread: + case Attribute::SanitizeHWAddress: + case Attribute::StackProtect: + case Attribute::StackProtectReq: + case Attribute::StackProtectStrong: + case Attribute::StrictFP: + case Attribute::UWTable: + case Attribute::NoCfCheck: + break; + } + newFunction->addFnAttr(Attr); + } newFunction->getBasicBlockList().push_back(newRootNode); // Create an iterator to name all of the arguments we inserted. @@ -1093,10 +1228,10 @@ Function *CodeExtractor::extractCodeRegion() { // Update the entry count of the function. if (BFI) { - Optional<uint64_t> EntryCount = - BFI->getProfileCountFromFreq(EntryFreq.getFrequency()); - if (EntryCount.hasValue()) - newFunction->setEntryCount(EntryCount.getValue()); + auto Count = BFI->getProfileCountFromFreq(EntryFreq.getFrequency()); + if (Count.hasValue()) + newFunction->setEntryCount( + ProfileCount(Count.getValue(), Function::PCT_Real)); // FIXME BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency()); } @@ -1104,6 +1239,10 @@ Function *CodeExtractor::extractCodeRegion() { moveCodeToFunction(newFunction); + // Propagate personality info to the new function if there is one. + if (oldFunction->hasPersonalityFn()) + newFunction->setPersonalityFn(oldFunction->getPersonalityFn()); + // Update the branch weights for the exit block. if (BFI && NumExitBlocks > 1) calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI); @@ -1139,7 +1278,7 @@ Function *CodeExtractor::extractCodeRegion() { } } - DEBUG(if (verifyFunction(*newFunction)) - report_fatal_error("verifyFunction failed!")); + LLVM_DEBUG(if (verifyFunction(*newFunction)) + report_fatal_error("verifyFunction failed!")); return newFunction; } diff --git a/lib/Transforms/Utils/CtorUtils.cpp b/lib/Transforms/Utils/CtorUtils.cpp index 82b67c293102..9a0240144d08 100644 --- a/lib/Transforms/Utils/CtorUtils.cpp +++ b/lib/Transforms/Utils/CtorUtils.cpp @@ -138,7 +138,7 @@ bool optimizeGlobalCtorsList(Module &M, if (!F) continue; - DEBUG(dbgs() << "Optimizing Global Constructor: " << *F << "\n"); + LLVM_DEBUG(dbgs() << "Optimizing Global Constructor: " << *F << "\n"); // We cannot simplify external ctor functions. 
if (F->empty()) diff --git a/lib/Transforms/Utils/DemoteRegToStack.cpp b/lib/Transforms/Utils/DemoteRegToStack.cpp index 6d3d287defdb..56ff03c7f5e1 100644 --- a/lib/Transforms/Utils/DemoteRegToStack.cpp +++ b/lib/Transforms/Utils/DemoteRegToStack.cpp @@ -9,11 +9,11 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/Analysis/CFG.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Type.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" using namespace llvm; /// DemoteRegToStack - This function takes a virtual register computed by an diff --git a/lib/Transforms/Utils/EntryExitInstrumenter.cpp b/lib/Transforms/Utils/EntryExitInstrumenter.cpp index 421663f82565..569ea58a3047 100644 --- a/lib/Transforms/Utils/EntryExitInstrumenter.cpp +++ b/lib/Transforms/Utils/EntryExitInstrumenter.cpp @@ -9,14 +9,13 @@ #include "llvm/Transforms/Utils/EntryExitInstrumenter.h" #include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/CodeGen/Passes.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/Pass.h" -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" using namespace llvm; static void insertCall(Function &CurFn, StringRef Func, @@ -92,17 +91,27 @@ static bool runOnFunction(Function &F, bool PostInlining) { if (!ExitFunc.empty()) { for (BasicBlock &BB : F) { - TerminatorInst *T = BB.getTerminator(); + Instruction *T = BB.getTerminator(); + if (!isa<ReturnInst>(T)) + continue; + + // If T is preceded by a musttail call, that's the real terminator. + Instruction *Prev = T->getPrevNode(); + if (BitCastInst *BCI = dyn_cast_or_null<BitCastInst>(Prev)) + Prev = BCI->getPrevNode(); + if (CallInst *CI = dyn_cast_or_null<CallInst>(Prev)) { + if (CI->isMustTailCall()) + T = CI; + } + DebugLoc DL; if (DebugLoc TerminatorDL = T->getDebugLoc()) DL = TerminatorDL; else if (auto SP = F.getSubprogram()) DL = DebugLoc::get(0, 0, SP); - if (isa<ReturnInst>(T)) { - insertCall(F, ExitFunc, T, DL); - Changed = true; - } + insertCall(F, ExitFunc, T, DL); + Changed = true; } F.removeAttribute(AttributeList::FunctionIndex, ExitAttr); } diff --git a/lib/Transforms/Utils/EscapeEnumerator.cpp b/lib/Transforms/Utils/EscapeEnumerator.cpp index 78d7474e5b95..c9c96fbe5da0 100644 --- a/lib/Transforms/Utils/EscapeEnumerator.cpp +++ b/lib/Transforms/Utils/EscapeEnumerator.cpp @@ -14,9 +14,9 @@ #include "llvm/Transforms/Utils/EscapeEnumerator.h" #include "llvm/Analysis/EHPersonalities.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Module.h" -#include "llvm/Transforms/Utils/Local.h" using namespace llvm; static Constant *getDefaultPersonalityFn(Module *M) { @@ -73,8 +73,8 @@ IRBuilder<> *EscapeEnumerator::Next() { F.setPersonalityFn(PersFn); } - if (isFuncletEHPersonality(classifyEHPersonality(F.getPersonalityFn()))) { - report_fatal_error("Funclet EH not supported"); + if (isScopedEHPersonality(classifyEHPersonality(F.getPersonalityFn()))) { + report_fatal_error("Scoped EH not supported"); } LandingPadInst *LPad = diff --git a/lib/Transforms/Utils/Evaluator.cpp b/lib/Transforms/Utils/Evaluator.cpp index 3c5e299fae98..7fd9425efed3 100644 --- a/lib/Transforms/Utils/Evaluator.cpp +++ b/lib/Transforms/Utils/Evaluator.cpp @@ -24,6 +24,7 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" +#include 
"llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InstrTypes.h" @@ -174,6 +175,11 @@ static bool isSimpleEnoughPointerToCommit(Constant *C) { return false; } +static Constant *getInitializer(Constant *C) { + auto *GV = dyn_cast<GlobalVariable>(C); + return GV && GV->hasDefinitiveInitializer() ? GV->getInitializer() : nullptr; +} + /// Return the value that would be computed by a load from P after the stores /// reflected by 'memory' have been performed. If we can't decide, return null. Constant *Evaluator::ComputeLoadResult(Constant *P) { @@ -189,18 +195,96 @@ Constant *Evaluator::ComputeLoadResult(Constant *P) { return nullptr; } - // Handle a constantexpr getelementptr. - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(P)) - if (CE->getOpcode() == Instruction::GetElementPtr && - isa<GlobalVariable>(CE->getOperand(0))) { - GlobalVariable *GV = cast<GlobalVariable>(CE->getOperand(0)); - if (GV->hasDefinitiveInitializer()) - return ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE); + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(P)) { + switch (CE->getOpcode()) { + // Handle a constantexpr getelementptr. + case Instruction::GetElementPtr: + if (auto *I = getInitializer(CE->getOperand(0))) + return ConstantFoldLoadThroughGEPConstantExpr(I, CE); + break; + // Handle a constantexpr bitcast. + case Instruction::BitCast: + Constant *Val = getVal(CE->getOperand(0)); + auto MM = MutatedMemory.find(Val); + auto *I = (MM != MutatedMemory.end()) ? MM->second + : getInitializer(CE->getOperand(0)); + if (I) + return ConstantFoldLoadThroughBitcast( + I, P->getType()->getPointerElementType(), DL); + break; } + } return nullptr; // don't know how to evaluate. } +static Function *getFunction(Constant *C) { + if (auto *Fn = dyn_cast<Function>(C)) + return Fn; + + if (auto *Alias = dyn_cast<GlobalAlias>(C)) + if (auto *Fn = dyn_cast<Function>(Alias->getAliasee())) + return Fn; + return nullptr; +} + +Function * +Evaluator::getCalleeWithFormalArgs(CallSite &CS, + SmallVector<Constant *, 8> &Formals) { + auto *V = CS.getCalledValue(); + if (auto *Fn = getFunction(getVal(V))) + return getFormalParams(CS, Fn, Formals) ? Fn : nullptr; + + auto *CE = dyn_cast<ConstantExpr>(V); + if (!CE || CE->getOpcode() != Instruction::BitCast || + !getFormalParams(CS, getFunction(CE->getOperand(0)), Formals)) + return nullptr; + + return dyn_cast<Function>( + ConstantFoldLoadThroughBitcast(CE, CE->getOperand(0)->getType(), DL)); +} + +bool Evaluator::getFormalParams(CallSite &CS, Function *F, + SmallVector<Constant *, 8> &Formals) { + if (!F) + return false; + + auto *FTy = F->getFunctionType(); + if (FTy->getNumParams() > CS.getNumArgOperands()) { + LLVM_DEBUG(dbgs() << "Too few arguments for function.\n"); + return false; + } + + auto ArgI = CS.arg_begin(); + for (auto ParI = FTy->param_begin(), ParE = FTy->param_end(); ParI != ParE; + ++ParI) { + auto *ArgC = ConstantFoldLoadThroughBitcast(getVal(*ArgI), *ParI, DL); + if (!ArgC) { + LLVM_DEBUG(dbgs() << "Can not convert function argument.\n"); + return false; + } + Formals.push_back(ArgC); + ++ArgI; + } + return true; +} + +/// If call expression contains bitcast then we may need to cast +/// evaluated return value to a type of the call expression. 
+Constant *Evaluator::castCallResultIfNeeded(Value *CallExpr, Constant *RV) { + ConstantExpr *CE = dyn_cast<ConstantExpr>(CallExpr); + if (!RV || !CE || CE->getOpcode() != Instruction::BitCast) + return RV; + + if (auto *FT = + dyn_cast<FunctionType>(CE->getType()->getPointerElementType())) { + RV = ConstantFoldLoadThroughBitcast(RV, FT->getReturnType(), DL); + if (!RV) + LLVM_DEBUG(dbgs() << "Failed to fold bitcast call expr\n"); + } + return RV; +} + /// Evaluate all instructions in block BB, returning true if successful, false /// if we can't evaluate it. NewBB returns the next BB that control flows into, /// or null upon return. @@ -210,22 +294,23 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, while (true) { Constant *InstResult = nullptr; - DEBUG(dbgs() << "Evaluating Instruction: " << *CurInst << "\n"); + LLVM_DEBUG(dbgs() << "Evaluating Instruction: " << *CurInst << "\n"); if (StoreInst *SI = dyn_cast<StoreInst>(CurInst)) { if (!SI->isSimple()) { - DEBUG(dbgs() << "Store is not simple! Can not evaluate.\n"); + LLVM_DEBUG(dbgs() << "Store is not simple! Can not evaluate.\n"); return false; // no volatile/atomic accesses. } Constant *Ptr = getVal(SI->getOperand(1)); if (auto *FoldedPtr = ConstantFoldConstant(Ptr, DL, TLI)) { - DEBUG(dbgs() << "Folding constant ptr expression: " << *Ptr); + LLVM_DEBUG(dbgs() << "Folding constant ptr expression: " << *Ptr); Ptr = FoldedPtr; - DEBUG(dbgs() << "; To: " << *Ptr << "\n"); + LLVM_DEBUG(dbgs() << "; To: " << *Ptr << "\n"); } if (!isSimpleEnoughPointerToCommit(Ptr)) { // If this is too complex for us to commit, reject it. - DEBUG(dbgs() << "Pointer is too complex for us to evaluate store."); + LLVM_DEBUG( + dbgs() << "Pointer is too complex for us to evaluate store."); return false; } @@ -234,14 +319,15 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, // If this might be too difficult for the backend to handle (e.g. the addr // of one global variable divided by another) then we can't commit it. if (!isSimpleEnoughValueToCommit(Val, SimpleConstants, DL)) { - DEBUG(dbgs() << "Store value is too complex to evaluate store. " << *Val - << "\n"); + LLVM_DEBUG(dbgs() << "Store value is too complex to evaluate store. " + << *Val << "\n"); return false; } if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) { if (CE->getOpcode() == Instruction::BitCast) { - DEBUG(dbgs() << "Attempting to resolve bitcast on constant ptr.\n"); + LLVM_DEBUG(dbgs() + << "Attempting to resolve bitcast on constant ptr.\n"); // If we're evaluating a store through a bitcast, then we need // to pull the bitcast off the pointer type and push it onto the // stored value. @@ -252,7 +338,8 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, // In order to push the bitcast onto the stored value, a bitcast // from NewTy to Val's type must be legal. If it's not, we can try // introspecting NewTy to find a legal conversion. - while (!Val->getType()->canLosslesslyBitCastTo(NewTy)) { + Constant *NewVal; + while (!(NewVal = ConstantFoldLoadThroughBitcast(Val, NewTy, DL))) { // If NewTy is a struct, we can convert the pointer to the struct // into a pointer to its first member. // FIXME: This could be extended to support arrays as well. @@ -270,17 +357,14 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, // If we can't improve the situation by introspecting NewTy, // we have to give up. 
} else { - DEBUG(dbgs() << "Failed to bitcast constant ptr, can not " - "evaluate.\n"); + LLVM_DEBUG(dbgs() << "Failed to bitcast constant ptr, can not " + "evaluate.\n"); return false; } } - // If we found compatible types, go ahead and push the bitcast - // onto the stored value. - Val = ConstantExpr::getBitCast(Val, NewTy); - - DEBUG(dbgs() << "Evaluated bitcast: " << *Val << "\n"); + Val = NewVal; + LLVM_DEBUG(dbgs() << "Evaluated bitcast: " << *Val << "\n"); } } @@ -289,37 +373,37 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, InstResult = ConstantExpr::get(BO->getOpcode(), getVal(BO->getOperand(0)), getVal(BO->getOperand(1))); - DEBUG(dbgs() << "Found a BinaryOperator! Simplifying: " << *InstResult - << "\n"); + LLVM_DEBUG(dbgs() << "Found a BinaryOperator! Simplifying: " + << *InstResult << "\n"); } else if (CmpInst *CI = dyn_cast<CmpInst>(CurInst)) { InstResult = ConstantExpr::getCompare(CI->getPredicate(), getVal(CI->getOperand(0)), getVal(CI->getOperand(1))); - DEBUG(dbgs() << "Found a CmpInst! Simplifying: " << *InstResult - << "\n"); + LLVM_DEBUG(dbgs() << "Found a CmpInst! Simplifying: " << *InstResult + << "\n"); } else if (CastInst *CI = dyn_cast<CastInst>(CurInst)) { InstResult = ConstantExpr::getCast(CI->getOpcode(), getVal(CI->getOperand(0)), CI->getType()); - DEBUG(dbgs() << "Found a Cast! Simplifying: " << *InstResult - << "\n"); + LLVM_DEBUG(dbgs() << "Found a Cast! Simplifying: " << *InstResult + << "\n"); } else if (SelectInst *SI = dyn_cast<SelectInst>(CurInst)) { InstResult = ConstantExpr::getSelect(getVal(SI->getOperand(0)), getVal(SI->getOperand(1)), getVal(SI->getOperand(2))); - DEBUG(dbgs() << "Found a Select! Simplifying: " << *InstResult - << "\n"); + LLVM_DEBUG(dbgs() << "Found a Select! Simplifying: " << *InstResult + << "\n"); } else if (auto *EVI = dyn_cast<ExtractValueInst>(CurInst)) { InstResult = ConstantExpr::getExtractValue( getVal(EVI->getAggregateOperand()), EVI->getIndices()); - DEBUG(dbgs() << "Found an ExtractValueInst! Simplifying: " << *InstResult - << "\n"); + LLVM_DEBUG(dbgs() << "Found an ExtractValueInst! Simplifying: " + << *InstResult << "\n"); } else if (auto *IVI = dyn_cast<InsertValueInst>(CurInst)) { InstResult = ConstantExpr::getInsertValue( getVal(IVI->getAggregateOperand()), getVal(IVI->getInsertedValueOperand()), IVI->getIndices()); - DEBUG(dbgs() << "Found an InsertValueInst! Simplifying: " << *InstResult - << "\n"); + LLVM_DEBUG(dbgs() << "Found an InsertValueInst! Simplifying: " + << *InstResult << "\n"); } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(CurInst)) { Constant *P = getVal(GEP->getOperand(0)); SmallVector<Constant*, 8> GEPOps; @@ -329,60 +413,63 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, InstResult = ConstantExpr::getGetElementPtr(GEP->getSourceElementType(), P, GEPOps, cast<GEPOperator>(GEP)->isInBounds()); - DEBUG(dbgs() << "Found a GEP! Simplifying: " << *InstResult - << "\n"); + LLVM_DEBUG(dbgs() << "Found a GEP! Simplifying: " << *InstResult << "\n"); } else if (LoadInst *LI = dyn_cast<LoadInst>(CurInst)) { if (!LI->isSimple()) { - DEBUG(dbgs() << "Found a Load! Not a simple load, can not evaluate.\n"); + LLVM_DEBUG( + dbgs() << "Found a Load! Not a simple load, can not evaluate.\n"); return false; // no volatile/atomic accesses. 
} Constant *Ptr = getVal(LI->getOperand(0)); if (auto *FoldedPtr = ConstantFoldConstant(Ptr, DL, TLI)) { Ptr = FoldedPtr; - DEBUG(dbgs() << "Found a constant pointer expression, constant " - "folding: " << *Ptr << "\n"); + LLVM_DEBUG(dbgs() << "Found a constant pointer expression, constant " + "folding: " + << *Ptr << "\n"); } InstResult = ComputeLoadResult(Ptr); if (!InstResult) { - DEBUG(dbgs() << "Failed to compute load result. Can not evaluate load." - "\n"); + LLVM_DEBUG( + dbgs() << "Failed to compute load result. Can not evaluate load." + "\n"); return false; // Could not evaluate load. } - DEBUG(dbgs() << "Evaluated load: " << *InstResult << "\n"); + LLVM_DEBUG(dbgs() << "Evaluated load: " << *InstResult << "\n"); } else if (AllocaInst *AI = dyn_cast<AllocaInst>(CurInst)) { if (AI->isArrayAllocation()) { - DEBUG(dbgs() << "Found an array alloca. Can not evaluate.\n"); + LLVM_DEBUG(dbgs() << "Found an array alloca. Can not evaluate.\n"); return false; // Cannot handle array allocs. } Type *Ty = AI->getAllocatedType(); AllocaTmps.push_back(llvm::make_unique<GlobalVariable>( Ty, false, GlobalValue::InternalLinkage, UndefValue::get(Ty), - AI->getName())); + AI->getName(), /*TLMode=*/GlobalValue::NotThreadLocal, + AI->getType()->getPointerAddressSpace())); InstResult = AllocaTmps.back().get(); - DEBUG(dbgs() << "Found an alloca. Result: " << *InstResult << "\n"); + LLVM_DEBUG(dbgs() << "Found an alloca. Result: " << *InstResult << "\n"); } else if (isa<CallInst>(CurInst) || isa<InvokeInst>(CurInst)) { CallSite CS(&*CurInst); // Debug info can safely be ignored here. if (isa<DbgInfoIntrinsic>(CS.getInstruction())) { - DEBUG(dbgs() << "Ignoring debug info.\n"); + LLVM_DEBUG(dbgs() << "Ignoring debug info.\n"); ++CurInst; continue; } // Cannot handle inline asm. if (isa<InlineAsm>(CS.getCalledValue())) { - DEBUG(dbgs() << "Found inline asm, can not evaluate.\n"); + LLVM_DEBUG(dbgs() << "Found inline asm, can not evaluate.\n"); return false; } if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CS.getInstruction())) { if (MemSetInst *MSI = dyn_cast<MemSetInst>(II)) { if (MSI->isVolatile()) { - DEBUG(dbgs() << "Can not optimize a volatile memset " << - "intrinsic.\n"); + LLVM_DEBUG(dbgs() << "Can not optimize a volatile memset " + << "intrinsic.\n"); return false; } Constant *Ptr = getVal(MSI->getDest()); @@ -390,7 +477,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, Constant *DestVal = ComputeLoadResult(getVal(Ptr)); if (Val->isNullValue() && DestVal && DestVal->isNullValue()) { // This memset is a no-op. - DEBUG(dbgs() << "Ignoring no-op memset.\n"); + LLVM_DEBUG(dbgs() << "Ignoring no-op memset.\n"); ++CurInst; continue; } @@ -398,7 +485,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, if (II->getIntrinsicID() == Intrinsic::lifetime_start || II->getIntrinsicID() == Intrinsic::lifetime_end) { - DEBUG(dbgs() << "Ignoring lifetime intrinsic.\n"); + LLVM_DEBUG(dbgs() << "Ignoring lifetime intrinsic.\n"); ++CurInst; continue; } @@ -407,7 +494,8 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, // We don't insert an entry into Values, as it doesn't have a // meaningful return value. if (!II->use_empty()) { - DEBUG(dbgs() << "Found unused invariant_start. Can't evaluate.\n"); + LLVM_DEBUG(dbgs() + << "Found unused invariant_start. 
Can't evaluate.\n"); return false; } ConstantInt *Size = cast<ConstantInt>(II->getArgOperand(0)); @@ -419,54 +507,54 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, Size->getValue().getLimitedValue() >= DL.getTypeStoreSize(ElemTy)) { Invariants.insert(GV); - DEBUG(dbgs() << "Found a global var that is an invariant: " << *GV - << "\n"); + LLVM_DEBUG(dbgs() << "Found a global var that is an invariant: " + << *GV << "\n"); } else { - DEBUG(dbgs() << "Found a global var, but can not treat it as an " - "invariant.\n"); + LLVM_DEBUG(dbgs() + << "Found a global var, but can not treat it as an " + "invariant.\n"); } } // Continue even if we do nothing. ++CurInst; continue; } else if (II->getIntrinsicID() == Intrinsic::assume) { - DEBUG(dbgs() << "Skipping assume intrinsic.\n"); + LLVM_DEBUG(dbgs() << "Skipping assume intrinsic.\n"); ++CurInst; continue; } else if (II->getIntrinsicID() == Intrinsic::sideeffect) { - DEBUG(dbgs() << "Skipping sideeffect intrinsic.\n"); + LLVM_DEBUG(dbgs() << "Skipping sideeffect intrinsic.\n"); ++CurInst; continue; } - DEBUG(dbgs() << "Unknown intrinsic. Can not evaluate.\n"); + LLVM_DEBUG(dbgs() << "Unknown intrinsic. Can not evaluate.\n"); return false; } // Resolve function pointers. - Function *Callee = dyn_cast<Function>(getVal(CS.getCalledValue())); + SmallVector<Constant *, 8> Formals; + Function *Callee = getCalleeWithFormalArgs(CS, Formals); if (!Callee || Callee->isInterposable()) { - DEBUG(dbgs() << "Can not resolve function pointer.\n"); + LLVM_DEBUG(dbgs() << "Can not resolve function pointer.\n"); return false; // Cannot resolve. } - SmallVector<Constant*, 8> Formals; - for (User::op_iterator i = CS.arg_begin(), e = CS.arg_end(); i != e; ++i) - Formals.push_back(getVal(*i)); - if (Callee->isDeclaration()) { // If this is a function we can constant fold, do it. if (Constant *C = ConstantFoldCall(CS, Callee, Formals, TLI)) { - InstResult = C; - DEBUG(dbgs() << "Constant folded function call. Result: " << - *InstResult << "\n"); + InstResult = castCallResultIfNeeded(CS.getCalledValue(), C); + if (!InstResult) + return false; + LLVM_DEBUG(dbgs() << "Constant folded function call. Result: " + << *InstResult << "\n"); } else { - DEBUG(dbgs() << "Can not constant fold function call.\n"); + LLVM_DEBUG(dbgs() << "Can not constant fold function call.\n"); return false; } } else { if (Callee->getFunctionType()->isVarArg()) { - DEBUG(dbgs() << "Can not constant fold vararg function call.\n"); + LLVM_DEBUG(dbgs() << "Can not constant fold vararg function call.\n"); return false; } @@ -474,21 +562,24 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, // Execute the call, if successful, use the return value. ValueStack.emplace_back(); if (!EvaluateFunction(Callee, RetVal, Formals)) { - DEBUG(dbgs() << "Failed to evaluate function.\n"); + LLVM_DEBUG(dbgs() << "Failed to evaluate function.\n"); return false; } ValueStack.pop_back(); - InstResult = RetVal; + InstResult = castCallResultIfNeeded(CS.getCalledValue(), RetVal); + if (RetVal && !InstResult) + return false; if (InstResult) { - DEBUG(dbgs() << "Successfully evaluated function. Result: " - << *InstResult << "\n\n"); + LLVM_DEBUG(dbgs() << "Successfully evaluated function. Result: " + << *InstResult << "\n\n"); } else { - DEBUG(dbgs() << "Successfully evaluated function. Result: 0\n\n"); + LLVM_DEBUG(dbgs() + << "Successfully evaluated function. 
Result: 0\n\n"); } } } else if (isa<TerminatorInst>(CurInst)) { - DEBUG(dbgs() << "Found a terminator instruction.\n"); + LLVM_DEBUG(dbgs() << "Found a terminator instruction.\n"); if (BranchInst *BI = dyn_cast<BranchInst>(CurInst)) { if (BI->isUnconditional()) { @@ -515,17 +606,18 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, NextBB = nullptr; } else { // invoke, unwind, resume, unreachable. - DEBUG(dbgs() << "Can not handle terminator."); + LLVM_DEBUG(dbgs() << "Can not handle terminator."); return false; // Cannot handle this terminator. } // We succeeded at evaluating this block! - DEBUG(dbgs() << "Successfully evaluated block.\n"); + LLVM_DEBUG(dbgs() << "Successfully evaluated block.\n"); return true; } else { // Did not know how to evaluate this! - DEBUG(dbgs() << "Failed to evaluate block due to unhandled instruction." - "\n"); + LLVM_DEBUG( + dbgs() << "Failed to evaluate block due to unhandled instruction." + "\n"); return false; } @@ -539,7 +631,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, // If we just processed an invoke, we finished evaluating the block. if (InvokeInst *II = dyn_cast<InvokeInst>(CurInst)) { NextBB = II->getNormalDest(); - DEBUG(dbgs() << "Found an invoke instruction. Finished Block.\n\n"); + LLVM_DEBUG(dbgs() << "Found an invoke instruction. Finished Block.\n\n"); return true; } @@ -578,7 +670,7 @@ bool Evaluator::EvaluateFunction(Function *F, Constant *&RetVal, while (true) { BasicBlock *NextBB = nullptr; // Initialized to avoid compiler warnings. - DEBUG(dbgs() << "Trying to evaluate BB: " << *CurBB << "\n"); + LLVM_DEBUG(dbgs() << "Trying to evaluate BB: " << *CurBB << "\n"); if (!EvaluateBlock(CurInst, NextBB)) return false; diff --git a/lib/Transforms/Utils/FlattenCFG.cpp b/lib/Transforms/Utils/FlattenCFG.cpp index 5fdcc6d1d727..3c6c9c9a5df4 100644 --- a/lib/Transforms/Utils/FlattenCFG.cpp +++ b/lib/Transforms/Utils/FlattenCFG.cpp @@ -13,6 +13,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" @@ -24,7 +25,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include <cassert> using namespace llvm; @@ -36,16 +36,16 @@ namespace { class FlattenCFGOpt { AliasAnalysis *AA; - /// \brief Use parallel-and or parallel-or to generate conditions for + /// Use parallel-and or parallel-or to generate conditions for /// conditional branches. bool FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder); - /// \brief If \param BB is the merge block of an if-region, attempt to merge + /// If \param BB is the merge block of an if-region, attempt to merge /// the if-region with an adjacent if-region upstream if two if-regions /// contain identical instructions. bool MergeIfRegion(BasicBlock *BB, IRBuilder<> &Builder); - /// \brief Compare a pair of blocks: \p Block1 and \p Block2, which + /// Compare a pair of blocks: \p Block1 and \p Block2, which /// are from two if-regions whose entry blocks are \p Head1 and \p /// Head2. \returns true if \p Block1 and \p Block2 contain identical /// instructions, and have no memory reference alias with \p Head2. 
@@ -312,7 +312,7 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder) { new UnreachableInst(CB->getContext(), CB); } while (Iteration); - DEBUG(dbgs() << "Use parallel and/or in:\n" << *FirstCondBlock); + LLVM_DEBUG(dbgs() << "Use parallel and/or in:\n" << *FirstCondBlock); return true; } @@ -469,7 +469,7 @@ bool FlattenCFGOpt::MergeIfRegion(BasicBlock *BB, IRBuilder<> &Builder) { // Remove \param SecondEntryBlock SecondEntryBlock->dropAllReferences(); SecondEntryBlock->eraseFromParent(); - DEBUG(dbgs() << "If conditions merged into:\n" << *FirstEntryBlock); + LLVM_DEBUG(dbgs() << "If conditions merged into:\n" << *FirstEntryBlock); return true; } diff --git a/lib/Transforms/Utils/FunctionComparator.cpp b/lib/Transforms/Utils/FunctionComparator.cpp index bddcbd86e914..69203f9f2485 100644 --- a/lib/Transforms/Utils/FunctionComparator.cpp +++ b/lib/Transforms/Utils/FunctionComparator.cpp @@ -18,7 +18,6 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -377,7 +376,7 @@ int FunctionComparator::cmpConstants(const Constant *L, } } default: // Unknown constant, abort. - DEBUG(dbgs() << "Looking at valueID " << L->getValueID() << "\n"); + LLVM_DEBUG(dbgs() << "Looking at valueID " << L->getValueID() << "\n"); llvm_unreachable("Constant ValueID not recognized."); return -1; } @@ -710,7 +709,7 @@ int FunctionComparator::cmpInlineAsm(const InlineAsm *L, return Res; if (int Res = cmpNumbers(L->getDialect(), R->getDialect())) return Res; - llvm_unreachable("InlineAsm blocks were not uniqued."); + assert(L->getFunctionType() != R->getFunctionType()); return 0; } @@ -925,7 +924,7 @@ FunctionComparator::FunctionHash FunctionComparator::functionHash(Function &F) { H.add(F.arg_size()); SmallVector<const BasicBlock *, 8> BBs; - SmallSet<const BasicBlock *, 16> VisitedBBs; + SmallPtrSet<const BasicBlock *, 16> VisitedBBs; // Walk the blocks in the same order as FunctionComparator::cmpBasicBlocks(), // accumulating the hash of the function "structure." (BB and opcode sequence) diff --git a/lib/Transforms/Utils/FunctionImportUtils.cpp b/lib/Transforms/Utils/FunctionImportUtils.cpp index 6b5f593073b4..479816a339d0 100644 --- a/lib/Transforms/Utils/FunctionImportUtils.cpp +++ b/lib/Transforms/Utils/FunctionImportUtils.cpp @@ -206,15 +206,10 @@ void FunctionImportGlobalProcessing::processGlobalForThinLTO(GlobalValue &GV) { // definition. if (GV.hasName()) { ValueInfo VI = ImportIndex.getValueInfo(GV.getGUID()); - if (VI) { - // Need to check all summaries are local in case of hash collisions. 
- bool IsLocal = VI.getSummaryList().size() && - llvm::all_of(VI.getSummaryList(), - [](const std::unique_ptr<GlobalValueSummary> &Summary) { - return Summary->isDSOLocal(); - }); - if (IsLocal) - GV.setDSOLocal(true); + if (VI && VI.isDSOLocal()) { + GV.setDSOLocal(true); + if (GV.hasDLLImportStorageClass()) + GV.setDLLStorageClass(GlobalValue::DefaultStorageClass); } } diff --git a/lib/Transforms/Utils/GlobalStatus.cpp b/lib/Transforms/Utils/GlobalStatus.cpp index 245fefb38ee8..ff6970db47da 100644 --- a/lib/Transforms/Utils/GlobalStatus.cpp +++ b/lib/Transforms/Utils/GlobalStatus.cpp @@ -60,7 +60,7 @@ bool llvm::isSafeToDestroyConstant(const Constant *C) { } static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, - SmallPtrSetImpl<const PHINode *> &PhiUsers) { + SmallPtrSetImpl<const Value *> &VisitedUsers) { if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) if (GV->isExternallyInitialized()) GS.StoredType = GlobalStatus::StoredOnce; @@ -75,7 +75,8 @@ static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, if (!isa<PointerType>(CE->getType())) return true; - if (analyzeGlobalAux(CE, GS, PhiUsers)) + // FIXME: Do we need to add constexpr selects to VisitedUsers? + if (analyzeGlobalAux(CE, GS, VisitedUsers)) return true; } else if (const Instruction *I = dyn_cast<Instruction>(UR)) { if (!GS.HasMultipleAccessingFunctions) { @@ -137,20 +138,18 @@ static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, GS.StoredType = GlobalStatus::Stored; } } - } else if (isa<BitCastInst>(I)) { - if (analyzeGlobalAux(I, GS, PhiUsers)) + } else if (isa<BitCastInst>(I) || isa<GetElementPtrInst>(I)) { + // Skip over bitcasts and GEPs; we don't care about the type or offset + // of the pointer. + if (analyzeGlobalAux(I, GS, VisitedUsers)) return true; - } else if (isa<GetElementPtrInst>(I)) { - if (analyzeGlobalAux(I, GS, PhiUsers)) - return true; - } else if (isa<SelectInst>(I)) { - if (analyzeGlobalAux(I, GS, PhiUsers)) - return true; - } else if (const PHINode *PN = dyn_cast<PHINode>(I)) { - // PHI nodes we can check just like select or GEP instructions, but we - // have to be careful about infinite recursion. - if (PhiUsers.insert(PN).second) // Not already visited. - if (analyzeGlobalAux(I, GS, PhiUsers)) + } else if (isa<SelectInst>(I) || isa<PHINode>(I)) { + // Look through selects and PHIs to find if the pointer is + // conditionally accessed. Make sure we only visit an instruction + // once; otherwise, we can get infinite recursion or exponential + // compile time. + if (VisitedUsers.insert(I).second) + if (analyzeGlobalAux(I, GS, VisitedUsers)) return true; } else if (isa<CmpInst>(I)) { GS.IsCompared = true; @@ -191,6 +190,6 @@ static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, GlobalStatus::GlobalStatus() = default; bool GlobalStatus::analyzeGlobal(const Value *V, GlobalStatus &GS) { - SmallPtrSet<const PHINode *, 16> PhiUsers; - return analyzeGlobalAux(V, GS, PhiUsers); + SmallPtrSet<const Value *, 16> VisitedUsers; + return analyzeGlobalAux(V, GS, VisitedUsers); } diff --git a/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp b/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp index b8c12ad5ea84..8382220fc9e1 100644 --- a/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp +++ b/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp @@ -161,7 +161,7 @@ void ImportedFunctionsInliningStatistics::dump(const bool Verbose) { void ImportedFunctionsInliningStatistics::calculateRealInlines() { // Removing duplicated Callers. 
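// (llvm::sort, used in the hunk below, is a drop-in replacement for
// std::sort that shuffles the range first when expensive checks are enabled,
// exposing comparators that are not strict weak orderings; it pairs with the
// classic dedup idiom:)
//
//   llvm::sort(V.begin(), V.end());
//   V.erase(std::unique(V.begin(), V.end()), V.end());
//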
- std::sort(NonImportedCallers.begin(), NonImportedCallers.end()); + llvm::sort(NonImportedCallers.begin(), NonImportedCallers.end()); NonImportedCallers.erase( std::unique(NonImportedCallers.begin(), NonImportedCallers.end()), NonImportedCallers.end()); @@ -190,13 +190,14 @@ ImportedFunctionsInliningStatistics::getSortedNodes() { for (const NodesMapTy::value_type& Node : NodesMap) SortedNodes.push_back(&Node); - std::sort( + llvm::sort( SortedNodes.begin(), SortedNodes.end(), [&](const SortedNodesTy::value_type &Lhs, const SortedNodesTy::value_type &Rhs) { if (Lhs->second->NumberOfInlines != Rhs->second->NumberOfInlines) return Lhs->second->NumberOfInlines > Rhs->second->NumberOfInlines; - if (Lhs->second->NumberOfRealInlines != Rhs->second->NumberOfRealInlines) + if (Lhs->second->NumberOfRealInlines != + Rhs->second->NumberOfRealInlines) return Lhs->second->NumberOfRealInlines > Rhs->second->NumberOfRealInlines; return Lhs->first() < Rhs->first(); diff --git a/lib/Transforms/Utils/InlineFunction.cpp b/lib/Transforms/Utils/InlineFunction.cpp index fedf6e100d6c..0315aac1cf84 100644 --- a/lib/Transforms/Utils/InlineFunction.cpp +++ b/lib/Transforms/Utils/InlineFunction.cpp @@ -29,6 +29,7 @@ #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" @@ -60,7 +61,6 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <cassert> @@ -72,6 +72,7 @@ #include <vector> using namespace llvm; +using ProfileCount = Function::ProfileCount; static cl::opt<bool> EnableNoAliasConversion("enable-noalias-to-md-conversion", cl::init(true), @@ -1247,7 +1248,7 @@ static void HandleByValArgumentInit(Value *Dst, Value *Src, Module *M, // Always generate a memcpy of alignment 1 here because we don't know // the alignment of the src pointer. Other optimizations can infer // better alignment. - Builder.CreateMemCpy(Dst, Src, Size, /*Align=*/1); + Builder.CreateMemCpy(Dst, /*DstAlign*/1, Src, /*SrcAlign*/1, Size); } /// When inlining a call site that has a byval argument, @@ -1431,29 +1432,29 @@ static void updateCallerBFI(BasicBlock *CallSiteBlock, /// Update the branch metadata for cloned call instructions. static void updateCallProfile(Function *Callee, const ValueToValueMapTy &VMap, - const Optional<uint64_t> &CalleeEntryCount, + const ProfileCount &CalleeEntryCount, const Instruction *TheCall, ProfileSummaryInfo *PSI, BlockFrequencyInfo *CallerBFI) { - if (!CalleeEntryCount.hasValue() || CalleeEntryCount.getValue() < 1) + if (!CalleeEntryCount.hasValue() || CalleeEntryCount.isSynthetic() || + CalleeEntryCount.getCount() < 1) return; - Optional<uint64_t> CallSiteCount = - PSI ? PSI->getProfileCount(TheCall, CallerBFI) : None; + auto CallSiteCount = PSI ? PSI->getProfileCount(TheCall, CallerBFI) : None; uint64_t CallCount = std::min(CallSiteCount.hasValue() ? 
CallSiteCount.getValue() : 0, - CalleeEntryCount.getValue()); + CalleeEntryCount.getCount()); for (auto const &Entry : VMap) if (isa<CallInst>(Entry.first)) if (auto *CI = dyn_cast_or_null<CallInst>(Entry.second)) - CI->updateProfWeight(CallCount, CalleeEntryCount.getValue()); + CI->updateProfWeight(CallCount, CalleeEntryCount.getCount()); for (BasicBlock &BB : *Callee) // No need to update the callsite if it is pruned during inlining. if (VMap.count(&BB)) for (Instruction &I : BB) if (CallInst *CI = dyn_cast<CallInst>(&I)) - CI->updateProfWeight(CalleeEntryCount.getValue() - CallCount, - CalleeEntryCount.getValue()); + CI->updateProfWeight(CalleeEntryCount.getCount() - CallCount, + CalleeEntryCount.getCount()); } /// Update the entry count of callee after inlining. @@ -1467,18 +1468,19 @@ static void updateCalleeCount(BlockFrequencyInfo *CallerBFI, BasicBlock *CallBB, // callsite is M, the new callee count is set to N - M. M is estimated from // the caller's entry count, its entry block frequency and the block frequency // of the callsite. - Optional<uint64_t> CalleeCount = Callee->getEntryCount(); + auto CalleeCount = Callee->getEntryCount(); if (!CalleeCount.hasValue() || !PSI) return; - Optional<uint64_t> CallCount = PSI->getProfileCount(CallInst, CallerBFI); + auto CallCount = PSI->getProfileCount(CallInst, CallerBFI); if (!CallCount.hasValue()) return; // Since CallSiteCount is an estimate, it could exceed the original callee // count and has to be set to 0. - if (CallCount.getValue() > CalleeCount.getValue()) - Callee->setEntryCount(0); + if (CallCount.getValue() > CalleeCount.getCount()) + CalleeCount.setCount(0); else - Callee->setEntryCount(CalleeCount.getValue() - CallCount.getValue()); + CalleeCount.setCount(CalleeCount.getCount() - CallCount.getValue()); + Callee->setEntryCount(CalleeCount); } /// This function inlines the called function into the basic block of the @@ -1500,10 +1502,9 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, IFI.reset(); Function *CalledFunc = CS.getCalledFunction(); - if (!CalledFunc || // Can't inline external function or indirect - CalledFunc->isDeclaration() || - (!ForwardVarArgsTo && CalledFunc->isVarArg())) // call, or call to a vararg function! - return false; + if (!CalledFunc || // Can't inline external function or indirect + CalledFunc->isDeclaration()) // call! + return false; // The inliner does not know how to inline through calls with operand bundles // in general ... @@ -1568,7 +1569,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, Instruction *CallSiteEHPad = nullptr; if (CallerPersonality) { EHPersonality Personality = classifyEHPersonality(CallerPersonality); - if (isFuncletEHPersonality(Personality)) { + if (isScopedEHPersonality(Personality)) { Optional<OperandBundleUse> ParentFunclet = CS.getOperandBundle(LLVMContext::OB_funclet); if (ParentFunclet) @@ -1630,9 +1631,6 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, auto &DL = Caller->getParent()->getDataLayout(); - assert((CalledFunc->arg_size() == CS.arg_size() || ForwardVarArgsTo) && - "Varargs calls can only be inlined if the Varargs are forwarded!"); - // Calculate the vector of arguments to pass into the function cloner, which // matches up the formal to the actual argument values. 
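// For example (names illustrative), inlining a call `g(x, y, z)` into
// `void g(i32 %a, ...)` maps the single fixed formal %a to x for the cloner,
// while y and z are left over as varargs, collected into VarArgsToForward
// below when vararg forwarding applies.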
CallSite::arg_iterator AI = CS.arg_begin(); @@ -1815,9 +1813,12 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, } SmallVector<Value*,4> VarArgsToForward; + SmallVector<AttributeSet, 4> VarArgsAttrs; for (unsigned i = CalledFunc->getFunctionType()->getNumParams(); - i < CS.getNumArgOperands(); i++) + i < CS.getNumArgOperands(); i++) { VarArgsToForward.push_back(CS.getArgOperand(i)); + VarArgsAttrs.push_back(CS.getAttributes().getParamAttributes(i)); + } bool InlinedMustTailCalls = false, InlinedDeoptimizeCalls = false; if (InlinedFunctionInfo.ContainsCalls) { @@ -1825,6 +1826,10 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, if (CallInst *CI = dyn_cast<CallInst>(TheCall)) CallSiteTailKind = CI->getTailCallKind(); + // For inlining purposes, the "notail" marker is the same as no marker. + if (CallSiteTailKind == CallInst::TCK_NoTail) + CallSiteTailKind = CallInst::TCK_None; + for (Function::iterator BB = FirstNewBlock, E = Caller->end(); BB != E; ++BB) { for (auto II = BB->begin(); II != BB->end();) { @@ -1833,6 +1838,40 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, if (!CI) continue; + // Forward varargs from inlined call site to calls to the + // ForwardVarArgsTo function, if requested, and to musttail calls. + if (!VarArgsToForward.empty() && + ((ForwardVarArgsTo && + CI->getCalledFunction() == ForwardVarArgsTo) || + CI->isMustTailCall())) { + // Collect attributes for non-vararg parameters. + AttributeList Attrs = CI->getAttributes(); + SmallVector<AttributeSet, 8> ArgAttrs; + if (!Attrs.isEmpty() || !VarArgsAttrs.empty()) { + for (unsigned ArgNo = 0; + ArgNo < CI->getFunctionType()->getNumParams(); ++ArgNo) + ArgAttrs.push_back(Attrs.getParamAttributes(ArgNo)); + } + + // Add VarArg attributes. + ArgAttrs.append(VarArgsAttrs.begin(), VarArgsAttrs.end()); + Attrs = AttributeList::get(CI->getContext(), Attrs.getFnAttributes(), + Attrs.getRetAttributes(), ArgAttrs); + // Add VarArgs to existing parameters. + SmallVector<Value *, 6> Params(CI->arg_operands()); + Params.append(VarArgsToForward.begin(), VarArgsToForward.end()); + CallInst *NewCI = + CallInst::Create(CI->getCalledFunction() ? CI->getCalledFunction() + : CI->getCalledValue(), + Params, "", CI); + NewCI->setDebugLoc(CI->getDebugLoc()); + NewCI->setAttributes(Attrs); + NewCI->setCallingConv(CI->getCallingConv()); + CI->replaceAllUsesWith(NewCI); + CI->eraseFromParent(); + CI = NewCI; + } + if (Function *F = CI->getCalledFunction()) InlinedDeoptimizeCalls |= F->getIntrinsicID() == Intrinsic::experimental_deoptimize; @@ -1850,6 +1889,8 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // f -> musttail g -> tail f ==> f -> tail f // f -> g -> musttail f ==> f -> f // f -> g -> tail f ==> f -> f + // + // Inlined notail calls should remain notail calls. CallInst::TailCallKind ChildTCK = CI->getTailCallKind(); if (ChildTCK != CallInst::TCK_NoTail) ChildTCK = std::min(CallSiteTailKind, ChildTCK); @@ -1860,16 +1901,6 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // 'nounwind'. 
if (MarkNoUnwind) CI->setDoesNotThrow(); - - if (ForwardVarArgsTo && !VarArgsToForward.empty() && - CI->getCalledFunction() == ForwardVarArgsTo) { - SmallVector<Value*, 6> Params(CI->arg_operands()); - Params.append(VarArgsToForward.begin(), VarArgsToForward.end()); - CallInst *Call = CallInst::Create(CI->getCalledFunction(), Params, "", CI); - Call->setDebugLoc(CI->getDebugLoc()); - CI->replaceAllUsesWith(Call); - CI->eraseFromParent(); - } } } } diff --git a/lib/Transforms/Utils/InstructionNamer.cpp b/lib/Transforms/Utils/InstructionNamer.cpp index 23ec45edb3ef..003721f2b939 100644 --- a/lib/Transforms/Utils/InstructionNamer.cpp +++ b/lib/Transforms/Utils/InstructionNamer.cpp @@ -17,7 +17,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/Type.h" #include "llvm/Pass.h" -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" using namespace llvm; namespace { diff --git a/lib/Transforms/Utils/IntegerDivision.cpp b/lib/Transforms/Utils/IntegerDivision.cpp index 5a90dcb033b2..3fbb3487884b 100644 --- a/lib/Transforms/Utils/IntegerDivision.cpp +++ b/lib/Transforms/Utils/IntegerDivision.cpp @@ -372,7 +372,7 @@ static Value *generateUnsignedDivisionCode(Value *Dividend, Value *Divisor, /// information about the operands is known. Implements both 32bit and 64bit /// scalar division. /// -/// @brief Replace Rem with generated code. +/// Replace Rem with generated code. bool llvm::expandRemainder(BinaryOperator *Rem) { assert((Rem->getOpcode() == Instruction::SRem || Rem->getOpcode() == Instruction::URem) && @@ -430,7 +430,7 @@ bool llvm::expandRemainder(BinaryOperator *Rem) { /// when more information about the operands is known. Implements both /// 32bit and 64bit scalar division. /// -/// @brief Replace Div with generated code. +/// Replace Div with generated code. bool llvm::expandDivision(BinaryOperator *Div) { assert((Div->getOpcode() == Instruction::SDiv || Div->getOpcode() == Instruction::UDiv) && @@ -482,7 +482,7 @@ bool llvm::expandDivision(BinaryOperator *Div) { /// that have no or very little support for smaller than 32 bit integer /// arithmetic. /// -/// @brief Replace Rem with emulation code. +/// Replace Rem with emulation code. bool llvm::expandRemainderUpTo32Bits(BinaryOperator *Rem) { assert((Rem->getOpcode() == Instruction::SRem || Rem->getOpcode() == Instruction::URem) && @@ -531,7 +531,7 @@ bool llvm::expandRemainderUpTo32Bits(BinaryOperator *Rem) { /// 64 bits. Uses the above routines and extends the inputs/truncates the /// outputs to operate in 64 bits. /// -/// @brief Replace Rem with emulation code. +/// Replace Rem with emulation code. bool llvm::expandRemainderUpTo64Bits(BinaryOperator *Rem) { assert((Rem->getOpcode() == Instruction::SRem || Rem->getOpcode() == Instruction::URem) && @@ -580,7 +580,7 @@ bool llvm::expandRemainderUpTo64Bits(BinaryOperator *Rem) { /// in 32 bits; that is, these routines are good for targets that have no /// or very little support for smaller than 32 bit integer arithmetic. /// -/// @brief Replace Div with emulation code. +/// Replace Div with emulation code. bool llvm::expandDivisionUpTo32Bits(BinaryOperator *Div) { assert((Div->getOpcode() == Instruction::SDiv || Div->getOpcode() == Instruction::UDiv) && @@ -628,7 +628,7 @@ bool llvm::expandDivisionUpTo32Bits(BinaryOperator *Div) { /// above routines and extends the inputs/truncates the outputs to operate /// in 64 bits. /// -/// @brief Replace Div with emulation code. +/// Replace Div with emulation code.
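// Aside: a rough standalone model of the shift-and-subtract loop that the
// expandDivision-style routines above emit as IR on targets without a
// hardware divider. Illustrative sketch only (hypothetical helper name,
// assumes Divisor != 0, and omits the sign fixups the SDiv/SRem paths add):
static uint32_t softUDiv32(uint32_t Dividend, uint32_t Divisor) {
  uint32_t Quotient = 0;
  uint64_t Remainder = 0;
  for (int Bit = 31; Bit >= 0; --Bit) {
    // Shift the next dividend bit into the running remainder.
    Remainder = (Remainder << 1) | ((Dividend >> Bit) & 1);
    if (Remainder >= Divisor) { // Restoring step: subtract when it fits.
      Remainder -= Divisor;
      Quotient |= 1u << Bit;
    }
  }
  return Quotient; // On exit, Remainder holds Dividend % Divisor.
}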
bool llvm::expandDivisionUpTo64Bits(BinaryOperator *Div) { assert((Div->getOpcode() == Instruction::SDiv || Div->getOpcode() == Instruction::UDiv) && diff --git a/lib/Transforms/Utils/LCSSA.cpp b/lib/Transforms/Utils/LCSSA.cpp index ae0e2bb6c280..956d0387c7a8 100644 --- a/lib/Transforms/Utils/LCSSA.cpp +++ b/lib/Transforms/Utils/LCSSA.cpp @@ -36,13 +36,14 @@ #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/PredIteratorCache.h" #include "llvm/Pass.h" -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" using namespace llvm; @@ -214,18 +215,27 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, Worklist.push_back(PostProcessPN); // Keep track of PHI nodes that we want to remove because they did not have - // any uses rewritten. + // any uses rewritten. If the new PHI is used, store it so that we can + // try to propagate dbg.value intrinsics to it. + SmallVector<PHINode *, 2> NeedDbgValues; for (PHINode *PN : AddedPHIs) if (PN->use_empty()) PHIsToRemove.insert(PN); - + else + NeedDbgValues.push_back(PN); + insertDebugValuesForPHIs(InstBB, NeedDbgValues); Changed = true; } - // Remove PHI nodes that did not have any uses rewritten. - for (PHINode *PN : PHIsToRemove) { - assert (PN->use_empty() && "Trying to remove a phi with uses."); - PN->eraseFromParent(); - } + // Remove PHI nodes that did not have any uses rewritten. We need to redo the + // use_empty() check here, because even if the PHI node wasn't used when added + // to PHIsToRemove, later added PHI nodes can be using it. This cleanup is + // not guaranteed to handle trees/cycles of PHI nodes that are only used by + // each other. Such situations have only been noticed when the input IR + // contains unreachable code, and leaving some extra redundant PHI nodes in + // such situations is considered a minor problem. + for (PHINode *PN : PHIsToRemove) + if (PN->use_empty()) + PN->eraseFromParent(); return Changed; } diff --git a/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/lib/Transforms/Utils/LibCallsShrinkWrap.cpp index 42aca757c2af..9832a6f24e1f 100644 --- a/lib/Transforms/Utils/LibCallsShrinkWrap.cpp +++ b/lib/Transforms/Utils/LibCallsShrinkWrap.cpp @@ -79,11 +79,11 @@ public: bool perform() { bool Changed = false; for (auto &CI : WorkList) { - DEBUG(dbgs() << "CDCE calls: " << CI->getCalledFunction()->getName() - << "\n"); + LLVM_DEBUG(dbgs() << "CDCE calls: " << CI->getCalledFunction()->getName() - << "\n"); + if (perform(CI)) { Changed = true; - DEBUG(dbgs() << "Transformed\n"); + LLVM_DEBUG(dbgs() << "Transformed\n"); } } return Changed; } @@ -421,7 +421,7 @@ Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI, const LibFunc &Func) { // FIXME: LibFunc_powf and powl TBD.
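// Aside: the overall shrink-wrapping shape, sketched. For a libcall whose
// result is unused and whose only observable side effect is setting errno,
// the pass guards the call with a cheap domain check so it only executes
// when it could actually set errno, e.g.
//   acos(x);   =>   if (x < -1.0 || x > 1.0)
//                     acos(x);   // out of domain: may set EDOM
// The generateCondFor*() helpers build those conditions.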
if (Func != LibFunc_pow) { - DEBUG(dbgs() << "Not handled powf() and powl()\n"); + LLVM_DEBUG(dbgs() << "Not handled powf() and powl()\n"); return nullptr; } @@ -433,7 +433,7 @@ Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI, if (ConstantFP *CF = dyn_cast<ConstantFP>(Base)) { double D = CF->getValueAPF().convertToDouble(); if (D < 1.0f || D > APInt::getMaxValue(8).getZExtValue()) { - DEBUG(dbgs() << "Not handled pow(): constant base out of range\n"); + LLVM_DEBUG(dbgs() << "Not handled pow(): constant base out of range\n"); return nullptr; } @@ -447,7 +447,7 @@ Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI, // If the Base value comes from an integer type. Instruction *I = dyn_cast<Instruction>(Base); if (!I) { - DEBUG(dbgs() << "Not handled pow(): FP type base\n"); + LLVM_DEBUG(dbgs() << "Not handled pow(): FP type base\n"); return nullptr; } unsigned Opcode = I->getOpcode(); @@ -461,7 +461,7 @@ Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI, else if (BW == 32) UpperV = 32.0f; else { - DEBUG(dbgs() << "Not handled pow(): type too wide\n"); + LLVM_DEBUG(dbgs() << "Not handled pow(): type too wide\n"); return nullptr; } @@ -477,7 +477,7 @@ Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI, Value *Cond0 = BBBuilder.CreateFCmp(CmpInst::FCMP_OLE, Base, V0); return BBBuilder.CreateOr(Cond0, Cond); } - DEBUG(dbgs() << "Not handled pow(): base not from integer convert\n"); + LLVM_DEBUG(dbgs() << "Not handled pow(): base not from integer convert\n"); return nullptr; } @@ -496,9 +496,9 @@ void LibCallsShrinkWrap::shrinkWrapCI(CallInst *CI, Value *Cond) { SuccBB->setName("cdce.end"); CI->removeFromParent(); CallBB->getInstList().insert(CallBB->getFirstInsertionPt(), CI); - DEBUG(dbgs() << "== Basic Block After =="); - DEBUG(dbgs() << *CallBB->getSinglePredecessor() << *CallBB - << *CallBB->getSingleSuccessor() << "\n"); + LLVM_DEBUG(dbgs() << "== Basic Block After =="); + LLVM_DEBUG(dbgs() << *CallBB->getSinglePredecessor() << *CallBB + << *CallBB->getSingleSuccessor() << "\n"); } // Perform the transformation to a single candidate. @@ -529,10 +529,7 @@ static bool runImpl(Function &F, const TargetLibraryInfo &TLI, bool Changed = CCDCE.perform(); // Verify the dominator after we've updated it locally. -#ifndef NDEBUG - if (DT) - DT->verifyDomTree(); -#endif + assert(!DT || DT->verify(DominatorTree::VerificationLevel::Fast)); return Changed; } diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp index a1961eecb391..ae3cb077a3af 100644 --- a/lib/Transforms/Utils/Local.cpp +++ b/lib/Transforms/Utils/Local.cpp @@ -73,6 +73,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <cassert> #include <climits> @@ -100,26 +101,23 @@ STATISTIC(NumRemoved, "Number of unreachable basic blocks removed"); /// conditions and indirectbr addresses this might make dead if /// DeleteDeadConditions is true.
bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, - const TargetLibraryInfo *TLI) { + const TargetLibraryInfo *TLI, + DeferredDominance *DDT) { TerminatorInst *T = BB->getTerminator(); IRBuilder<> Builder(T); // Branch - See if we are conditional jumping on constant - if (BranchInst *BI = dyn_cast<BranchInst>(T)) { + if (auto *BI = dyn_cast<BranchInst>(T)) { if (BI->isUnconditional()) return false; // Can't optimize uncond branch BasicBlock *Dest1 = BI->getSuccessor(0); BasicBlock *Dest2 = BI->getSuccessor(1); - if (ConstantInt *Cond = dyn_cast<ConstantInt>(BI->getCondition())) { + if (auto *Cond = dyn_cast<ConstantInt>(BI->getCondition())) { // Are we branching on constant? // YES. Change to unconditional branch... BasicBlock *Destination = Cond->getZExtValue() ? Dest1 : Dest2; BasicBlock *OldDest = Cond->getZExtValue() ? Dest2 : Dest1; - //cerr << "Function: " << T->getParent()->getParent() - // << "\nRemoving branch from " << T->getParent() - // << "\n\nTo: " << OldDest << endl; - // Let the basic block know that we are letting go of it. Based on this, // it will adjust its PHI nodes. OldDest->removePredecessor(BB); @@ -127,6 +125,8 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, // Replace the conditional branch with an unconditional one. Builder.CreateBr(Destination); BI->eraseFromParent(); + if (DDT) + DDT->deleteEdge(BB, OldDest); return true; } @@ -150,10 +150,10 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, return false; } - if (SwitchInst *SI = dyn_cast<SwitchInst>(T)) { + if (auto *SI = dyn_cast<SwitchInst>(T)) { // If we are switching on a constant, we can convert the switch to an // unconditional branch. - ConstantInt *CI = dyn_cast<ConstantInt>(SI->getCondition()); + auto *CI = dyn_cast<ConstantInt>(SI->getCondition()); BasicBlock *DefaultDest = SI->getDefaultDest(); BasicBlock *TheOnlyDest = DefaultDest; @@ -197,9 +197,12 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, createBranchWeights(Weights)); } // Remove this entry. - DefaultDest->removePredecessor(SI->getParent()); + BasicBlock *ParentBB = SI->getParent(); + DefaultDest->removePredecessor(ParentBB); i = SI->removeCase(i); e = SI->case_end(); + if (DDT) + DDT->deleteEdge(ParentBB, DefaultDest); continue; } @@ -225,14 +228,20 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, // Insert the new branch. Builder.CreateBr(TheOnlyDest); BasicBlock *BB = SI->getParent(); + std::vector <DominatorTree::UpdateType> Updates; + if (DDT) + Updates.reserve(SI->getNumSuccessors() - 1); // Remove entries from PHI nodes which we no longer branch to... for (BasicBlock *Succ : SI->successors()) { // Found case matching a constant operand? - if (Succ == TheOnlyDest) + if (Succ == TheOnlyDest) { TheOnlyDest = nullptr; // Don't modify the first branch to TheOnlyDest - else + } else { Succ->removePredecessor(BB); + if (DDT) + Updates.push_back({DominatorTree::Delete, BB, Succ}); + } } // Delete the old switch.
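// Note the batching pattern above: instead of one DDT->deleteEdge() call per
// removed successor, the edge removals are queued in a single
// std::vector<DominatorTree::UpdateType> and handed to one
// DDT->applyUpdates() call (just below), letting DeferredDominance
// deduplicate them and update the dominator tree lazily.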
@@ -240,6 +249,8 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, SI->eraseFromParent(); if (DeleteDeadConditions) RecursivelyDeleteTriviallyDeadInstructions(Cond, TLI); + if (DDT) + DDT->applyUpdates(Updates); return true; } @@ -280,19 +291,28 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, return false; } - if (IndirectBrInst *IBI = dyn_cast<IndirectBrInst>(T)) { + if (auto *IBI = dyn_cast<IndirectBrInst>(T)) { // indirectbr blockaddress(@F, @BB) -> br label @BB - if (BlockAddress *BA = + if (auto *BA = dyn_cast<BlockAddress>(IBI->getAddress()->stripPointerCasts())) { BasicBlock *TheOnlyDest = BA->getBasicBlock(); + std::vector <DominatorTree::UpdateType> Updates; + if (DDT) + Updates.reserve(IBI->getNumDestinations() - 1); + // Insert the new branch. Builder.CreateBr(TheOnlyDest); for (unsigned i = 0, e = IBI->getNumDestinations(); i != e; ++i) { - if (IBI->getDestination(i) == TheOnlyDest) + if (IBI->getDestination(i) == TheOnlyDest) { TheOnlyDest = nullptr; - else - IBI->getDestination(i)->removePredecessor(IBI->getParent()); + } else { + BasicBlock *ParentBB = IBI->getParent(); + BasicBlock *DestBB = IBI->getDestination(i); + DestBB->removePredecessor(ParentBB); + if (DDT) + Updates.push_back({DominatorTree::Delete, ParentBB, DestBB}); + } } Value *Address = IBI->getAddress(); IBI->eraseFromParent(); @@ -307,6 +327,8 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, new UnreachableInst(BB->getContext(), BB); } + if (DDT) + DDT->applyUpdates(Updates); return true; } } @@ -350,6 +372,11 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, return false; return true; } + if (DbgLabelInst *DLI = dyn_cast<DbgLabelInst>(I)) { + if (DLI->getLabel()) + return false; + return true; + } if (!I->mayHaveSideEffects()) return true; @@ -357,8 +384,9 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, // Special case intrinsics that "may have side effects" but can be deleted // when dead. if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { - // Safe to delete llvm.stacksave if dead. - if (II->getIntrinsicID() == Intrinsic::stacksave) + // Safe to delete llvm.stacksave and launder.invariant.group if dead. + if (II->getIntrinsicID() == Intrinsic::stacksave || + II->getIntrinsicID() == Intrinsic::launder_invariant_group) return true; // Lifetime intrinsics are dead when their right-hand is undef. @@ -406,17 +434,31 @@ llvm::RecursivelyDeleteTriviallyDeadInstructions(Value *V, SmallVector<Instruction*, 16> DeadInsts; DeadInsts.push_back(I); + RecursivelyDeleteTriviallyDeadInstructions(DeadInsts, TLI); - do { - I = DeadInsts.pop_back_val(); + return true; +} + +void llvm::RecursivelyDeleteTriviallyDeadInstructions( + SmallVectorImpl<Instruction *> &DeadInsts, const TargetLibraryInfo *TLI) { + // Process the dead instruction list until empty. + while (!DeadInsts.empty()) { + Instruction &I = *DeadInsts.pop_back_val(); + assert(I.use_empty() && "Instructions with uses are not dead."); + assert(isInstructionTriviallyDead(&I, TLI) && + "Live instruction found in dead worklist!"); + + // Don't lose the debug info while deleting the instructions. + salvageDebugInfo(I); // Null out all of the instruction's operands to see if any operand becomes // dead as we go. 
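// Usage sketch for the worklist-based overload introduced above (driver code
// assumed): several known-dead roots can be batched, and the helper cascades
// through operands that become dead along the way:
//   SmallVector<Instruction *, 16> Dead;
//   Dead.push_back(KnownDeadI);
//   RecursivelyDeleteTriviallyDeadInstructions(Dead, &TLI);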
- for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { - Value *OpV = I->getOperand(i); - I->setOperand(i, nullptr); + for (Use &OpU : I.operands()) { + Value *OpV = OpU.get(); + OpU.set(nullptr); - if (!OpV->use_empty()) continue; + if (!OpV->use_empty()) + continue; // If the operand is an instruction that became dead as we nulled out the // operand, and if it is 'trivially' dead, delete it in a future loop @@ -426,10 +468,8 @@ llvm::RecursivelyDeleteTriviallyDeadInstructions(Value *V, DeadInsts.push_back(OpI); } - I->eraseFromParent(); - } while (!DeadInsts.empty()); - - return true; + I.eraseFromParent(); + } } /// areAllUsesEqual - Check whether the uses of a value are all the same. @@ -481,6 +521,8 @@ simplifyAndDCEInstruction(Instruction *I, const DataLayout &DL, const TargetLibraryInfo *TLI) { if (isInstructionTriviallyDead(I, TLI)) { + salvageDebugInfo(*I); + // Null out all of the instruction's operands to see if any operand becomes // dead as we go. for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { @@ -583,7 +625,8 @@ bool llvm::SimplifyInstructionsInBlock(BasicBlock *BB, /// /// .. and delete the predecessor corresponding to the '1', this will attempt to /// recursively fold the and to 0. -void llvm::RemovePredecessorAndSimplify(BasicBlock *BB, BasicBlock *Pred) { +void llvm::RemovePredecessorAndSimplify(BasicBlock *BB, BasicBlock *Pred, + DeferredDominance *DDT) { // This only adjusts blocks with PHI nodes. if (!isa<PHINode>(BB->begin())) return; @@ -606,13 +649,18 @@ void llvm::RemovePredecessorAndSimplify(BasicBlock *BB, BasicBlock *Pred) { // of the block. if (PhiIt != OldPhiIt) PhiIt = &BB->front(); } + if (DDT) + DDT->deleteEdge(Pred, BB); } /// MergeBasicBlockIntoOnlyPred - DestBB is a block with one predecessor and its /// predecessor is known to have one successor (DestBB!). Eliminate the edge /// between them, moving the instructions in the predecessor into DestBB and /// deleting the predecessor block. -void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, DominatorTree *DT) { +void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, DominatorTree *DT, + DeferredDominance *DDT) { + assert(!(DT && DDT) && "Cannot call with both DT and DDT."); + // If BB has single-entry PHI nodes, fold them. while (PHINode *PN = dyn_cast<PHINode>(DestBB->begin())) { Value *NewVal = PN->getIncomingValue(0); @@ -625,6 +673,24 @@ void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, DominatorTree *DT) { BasicBlock *PredBB = DestBB->getSinglePredecessor(); assert(PredBB && "Block doesn't have a single predecessor!"); + bool ReplaceEntryBB = false; + if (PredBB == &DestBB->getParent()->getEntryBlock()) + ReplaceEntryBB = true; + + // Deferred DT update: Collect all the edges that enter PredBB. These + // dominator edges will be redirected to DestBB. + std::vector <DominatorTree::UpdateType> Updates; + if (DDT && !ReplaceEntryBB) { + Updates.reserve(1 + (2 * pred_size(PredBB))); + Updates.push_back({DominatorTree::Delete, PredBB, DestBB}); + for (auto I = pred_begin(PredBB), E = pred_end(PredBB); I != E; ++I) { + Updates.push_back({DominatorTree::Delete, *I, PredBB}); + // This predecessor of PredBB may already have DestBB as a successor. + if (llvm::find(successors(*I), DestBB) == succ_end(*I)) + Updates.push_back({DominatorTree::Insert, *I, DestBB}); + } + } + // Zap anything that took the address of DestBB. Not doing this will give the // address an invalid value. 
if (DestBB->hasAddressTaken()) { @@ -645,7 +711,7 @@ void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, DominatorTree *DT) { // If the PredBB is the entry block of the function, move DestBB up to // become the entry block after we erase PredBB. - if (PredBB == &DestBB->getParent()->getEntryBlock()) + if (ReplaceEntryBB) DestBB->moveAfter(PredBB); if (DT) { @@ -657,8 +723,19 @@ void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, DominatorTree *DT) { DT->eraseNode(PredBB); } } - // Nuke BB. - PredBB->eraseFromParent(); + + if (DDT) { + DDT->deleteBB(PredBB); // Deferred deletion of BB. + if (ReplaceEntryBB) + // The entry block was removed and there is no external interface for the + // dominator tree to be notified of this change. In this corner-case we + // recalculate the entire tree. + DDT->recalculate(*(DestBB->getParent())); + else + DDT->applyUpdates(Updates); + } else { + PredBB->eraseFromParent(); // Nuke BB. + } } /// CanMergeValues - Return true if we can choose one of these values to use @@ -675,8 +752,8 @@ static bool CanMergeValues(Value *First, Value *Second) { static bool CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ) { assert(*succ_begin(BB) == Succ && "Succ is not successor of BB!"); - DEBUG(dbgs() << "Looking to fold " << BB->getName() << " into " - << Succ->getName() << "\n"); + LLVM_DEBUG(dbgs() << "Looking to fold " << BB->getName() << " into " + << Succ->getName() << "\n"); // Shortcut, if there is only a single predecessor it must be BB and merging // is always safe if (Succ->getSinglePredecessor()) return true; @@ -699,10 +776,11 @@ static bool CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ) { if (BBPreds.count(IBB) && !CanMergeValues(BBPN->getIncomingValueForBlock(IBB), PN->getIncomingValue(PI))) { - DEBUG(dbgs() << "Can't fold, phi node " << PN->getName() << " in " - << Succ->getName() << " is conflicting with " - << BBPN->getName() << " with regard to common predecessor " - << IBB->getName() << "\n"); + LLVM_DEBUG(dbgs() + << "Can't fold, phi node " << PN->getName() << " in " + << Succ->getName() << " is conflicting with " + << BBPN->getName() << " with regard to common predecessor " + << IBB->getName() << "\n"); return false; } } @@ -715,9 +793,10 @@ static bool CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ) { BasicBlock *IBB = PN->getIncomingBlock(PI); if (BBPreds.count(IBB) && !CanMergeValues(Val, PN->getIncomingValue(PI))) { - DEBUG(dbgs() << "Can't fold, phi node " << PN->getName() << " in " - << Succ->getName() << " is conflicting with regard to common " - << "predecessor " << IBB->getName() << "\n"); + LLVM_DEBUG(dbgs() << "Can't fold, phi node " << PN->getName() + << " in " << Succ->getName() + << " is conflicting with regard to common " + << "predecessor " << IBB->getName() << "\n"); return false; } } @@ -730,7 +809,7 @@ static bool CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ) { using PredBlockVector = SmallVector<BasicBlock *, 16>; using IncomingValueMap = DenseMap<BasicBlock *, Value *>; -/// \brief Determines the value to use as the phi node input for a block. +/// Determines the value to use as the phi node input for a block. 
/// /// Select between \p OldVal any value that we know flows from \p BB /// to a particular phi on the basis of which one (if either) is not @@ -759,7 +838,7 @@ static Value *selectIncomingValueForBlock(Value *OldVal, BasicBlock *BB, return OldVal; } -/// \brief Create a map from block to value for the operands of a +/// Create a map from block to value for the operands of a /// given phi. /// /// Create a map from block to value for each non-undef value flowing @@ -778,7 +857,7 @@ static void gatherIncomingValuesToPhi(PHINode *PN, } } -/// \brief Replace the incoming undef values to a phi with the values +/// Replace the incoming undef values to a phi with the values /// from a block-to-value map. /// /// \param PN The phi we are replacing the undefs in. @@ -798,7 +877,7 @@ static void replaceUndefValuesInPhi(PHINode *PN, } } -/// \brief Replace a value flowing from a block to a phi with +/// Replace a value flowing from a block to a phi with /// potentially multiple instances of that value flowing from the /// block's predecessors to the phi. /// @@ -865,7 +944,8 @@ static void redirectValuesFromPredecessorsToPhi(BasicBlock *BB, /// potential side-effect free intrinsics and the branch. If possible, /// eliminate BB by rewriting all the predecessors to branch to the successor /// block and return true. If we can't transform, return false. -bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB) { +bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, + DeferredDominance *DDT) { assert(BB != &BB->getParent()->getEntryBlock() && "TryToSimplifyUncondBranchFromEmptyBlock called on entry block!"); @@ -904,7 +984,20 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB) { } } - DEBUG(dbgs() << "Killing Trivial BB: \n" << *BB); + LLVM_DEBUG(dbgs() << "Killing Trivial BB: \n" << *BB); + + std::vector<DominatorTree::UpdateType> Updates; + if (DDT) { + Updates.reserve(1 + (2 * pred_size(BB))); + Updates.push_back({DominatorTree::Delete, BB, Succ}); + // All predecessors of BB will be moved to Succ. + for (auto I = pred_begin(BB), E = pred_end(BB); I != E; ++I) { + Updates.push_back({DominatorTree::Delete, *I, BB}); + // This predecessor of BB may already have Succ as a successor. + if (llvm::find(successors(*I), Succ) == succ_end(*I)) + Updates.push_back({DominatorTree::Insert, *I, Succ}); + } + } if (isa<PHINode>(Succ->begin())) { // If there is more than one pred of succ, and there are PHI nodes in @@ -950,7 +1043,13 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB) { // Everything that jumped to BB now goes to Succ. BB->replaceAllUsesWith(Succ); if (!Succ->hasName()) Succ->takeName(BB); - BB->eraseFromParent(); // Delete the old basic block. + + if (DDT) { + DDT->deleteBB(BB); // Deferred deletion of the old basic block. + DDT->applyUpdates(Updates); + } else { + BB->eraseFromParent(); // Delete the old basic block. + } return true; } @@ -1129,6 +1228,31 @@ static bool PhiHasDebugValue(DILocalVariable *DIVar, return false; } +/// Check if the alloc size of \p ValTy is large enough to cover the variable +/// (or fragment of the variable) described by \p DII. +/// +/// This is primarily intended as a helper for the different +/// ConvertDebugDeclareToDebugValue functions. The dbg.declare/dbg.addr that is +/// converted describes an alloca'd variable, so we need to use the +/// alloc size of the value when doing the comparison. E.g. 
an i1 value will be +/// identified as covering an n-bit fragment, if the store size of i1 is at +/// least n bits. +static bool valueCoversEntireFragment(Type *ValTy, DbgInfoIntrinsic *DII) { + const DataLayout &DL = DII->getModule()->getDataLayout(); + uint64_t ValueSize = DL.getTypeAllocSizeInBits(ValTy); + if (auto FragmentSize = DII->getFragmentSizeInBits()) + return ValueSize >= *FragmentSize; + // We can't always calculate the size of the DI variable (e.g. if it is a + // VLA). Try to use the size of the alloca that the dbg intrinsic describes + // instead. + if (DII->isAddressOfVariable()) + if (auto *AI = dyn_cast_or_null<AllocaInst>(DII->getVariableLocation())) + if (auto FragmentSize = AI->getAllocationSizeInBits(DL)) + return ValueSize >= *FragmentSize; + // Could not determine size of variable. Conservatively return false. + return false; +} + /// Inserts a llvm.dbg.value intrinsic before a store to an alloca'd value /// that has an associated llvm.dbg.declare or llvm.dbg.addr intrinsic. void llvm::ConvertDebugDeclareToDebugValue(DbgInfoIntrinsic *DII, @@ -1139,6 +1263,21 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgInfoIntrinsic *DII, auto *DIExpr = DII->getExpression(); Value *DV = SI->getOperand(0); + if (!valueCoversEntireFragment(SI->getValueOperand()->getType(), DII)) { + // FIXME: If storing to a part of the variable described by the dbg.declare, + // then we want to insert a dbg.value for the corresponding fragment. + LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to dbg.value: " + << *DII << '\n'); + // For now, when there is a store to parts of the variable (but we do not + // know which part) we insert a dbg.value intrinsic to indicate that we + // know nothing about the variable's content. + DV = UndefValue::get(DV->getType()); + if (!LdStHasDebugValue(DIVar, DIExpr, SI)) + Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, DII->getDebugLoc(), + SI); + return; + } + // If an argument is zero extended then use argument directly. The ZExt // may be zapped by an optimization pass in future. Argument *ExtendedArg = nullptr; @@ -1182,6 +1321,15 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgInfoIntrinsic *DII, if (LdStHasDebugValue(DIVar, DIExpr, LI)) return; + if (!valueCoversEntireFragment(LI->getType(), DII)) { + // FIXME: If only referring to a part of the variable described by the + // dbg.declare, then we want to insert a dbg.value for the corresponding + // fragment. + LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to dbg.value: " + << *DII << '\n'); + return; + } + // We are now tracking the loaded value instead of the address. In the // future if multi-location support is added to the IR, it might be // preferable to keep tracking both the loaded value and the original @@ -1202,6 +1350,15 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgInfoIntrinsic *DII, if (PhiHasDebugValue(DIVar, DIExpr, APN)) return; + if (!valueCoversEntireFragment(APN->getType(), DII)) { + // FIXME: If only referring to a part of the variable described by the + // dbg.declare, then we want to insert a dbg.value for the corresponding + // fragment. + LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to dbg.value: " + << *DII << '\n'); + return; + } + BasicBlock *BB = APN->getParent(); auto InsertionPt = BB->getFirstInsertionPt(); @@ -1241,33 +1398,91 @@ bool llvm::LowerDbgDeclare(Function &F) { // stored on the stack, while the dbg.declare can only describe // the stack slot (and at a lexical-scope granularity). Later // passes will attempt to elide the stack slot.
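// (A concrete instance of the valueCoversEntireFragment() check above, with
// assumed sizes: if dbg.declare describes a 64-bit variable and a store
// writes an i32 to its alloca, 32 < 64 means the store does not cover the
// fragment, so the store case emits dbg.value(undef, ...) rather than
// claiming a wrong 32-bit value; an i64 store, 64 >= 64, converts normally.)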
- if (AI && !isArray(AI)) { - for (auto &AIUse : AI->uses()) { - User *U = AIUse.getUser(); - if (StoreInst *SI = dyn_cast<StoreInst>(U)) { - if (AIUse.getOperandNo() == 1) - ConvertDebugDeclareToDebugValue(DDI, SI, DIB); - } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) { - ConvertDebugDeclareToDebugValue(DDI, LI, DIB); - } else if (CallInst *CI = dyn_cast<CallInst>(U)) { - // This is a call by-value or some other instruction that - // takes a pointer to the variable. Insert a *value* - // intrinsic that describes the alloca. - DIB.insertDbgValueIntrinsic(AI, DDI->getVariable(), - DDI->getExpression(), DDI->getDebugLoc(), - CI); - } + if (!AI || isArray(AI)) + continue; + + // A volatile load/store means that the alloca can't be elided anyway. + if (llvm::any_of(AI->users(), [](User *U) -> bool { + if (LoadInst *LI = dyn_cast<LoadInst>(U)) + return LI->isVolatile(); + if (StoreInst *SI = dyn_cast<StoreInst>(U)) + return SI->isVolatile(); + return false; + })) + continue; + + for (auto &AIUse : AI->uses()) { + User *U = AIUse.getUser(); + if (StoreInst *SI = dyn_cast<StoreInst>(U)) { + if (AIUse.getOperandNo() == 1) + ConvertDebugDeclareToDebugValue(DDI, SI, DIB); + } else if (LoadInst *LI = dyn_cast<LoadInst>(U)) { + ConvertDebugDeclareToDebugValue(DDI, LI, DIB); + } else if (CallInst *CI = dyn_cast<CallInst>(U)) { + // This is a call by-value or some other instruction that takes a + // pointer to the variable. Insert a *value* intrinsic that describes + // the variable by dereferencing the alloca. + auto *DerefExpr = + DIExpression::append(DDI->getExpression(), dwarf::DW_OP_deref); + DIB.insertDbgValueIntrinsic(AI, DDI->getVariable(), DerefExpr, + DDI->getDebugLoc(), CI); } - DDI->eraseFromParent(); } + DDI->eraseFromParent(); } return true; } +/// Propagate dbg.value intrinsics through the newly inserted PHIs. +void llvm::insertDebugValuesForPHIs(BasicBlock *BB, + SmallVectorImpl<PHINode *> &InsertedPHIs) { + assert(BB && "No BasicBlock to clone dbg.value(s) from."); + if (InsertedPHIs.size() == 0) + return; + + // Map existing PHI nodes to their dbg.values. + ValueToValueMapTy DbgValueMap; + for (auto &I : *BB) { + if (auto DbgII = dyn_cast<DbgInfoIntrinsic>(&I)) { + if (auto *Loc = dyn_cast_or_null<PHINode>(DbgII->getVariableLocation())) + DbgValueMap.insert({Loc, DbgII}); + } + } + if (DbgValueMap.size() == 0) + return; + + // Then iterate through the new PHIs and look to see if they use one of the + // previously mapped PHIs. If so, insert a new dbg.value intrinsic that will + // propagate the info through the new PHI. + LLVMContext &C = BB->getContext(); + for (auto PHI : InsertedPHIs) { + BasicBlock *Parent = PHI->getParent(); + // Avoid inserting an intrinsic into an EH block. + if (Parent->getFirstNonPHI()->isEHPad()) + continue; + auto PhiMAV = MetadataAsValue::get(C, ValueAsMetadata::get(PHI)); + for (auto VI : PHI->operand_values()) { + auto V = DbgValueMap.find(VI); + if (V != DbgValueMap.end()) { + auto *DbgII = cast<DbgInfoIntrinsic>(V->second); + Instruction *NewDbgII = DbgII->clone(); + NewDbgII->setOperand(0, PhiMAV); + auto InsertionPt = Parent->getFirstInsertionPt(); + assert(InsertionPt != Parent->end() && "Ill-formed basic block"); + NewDbgII->insertBefore(&*InsertionPt); + } + } + } +} + /// Finds all intrinsics declaring local variables as living in the memory that /// 'V' points to. This may include a mix of dbg.declare and /// dbg.addr intrinsics. TinyPtrVector<DbgInfoIntrinsic *> llvm::FindDbgAddrUses(Value *V) { + // This function is hot. 
Check whether the value has any metadata to avoid a + // DenseMap lookup. + if (!V->isUsedByMetadata()) + return {}; auto *L = LocalAsMetadata::getIfExists(V); if (!L) return {}; @@ -1286,6 +1501,10 @@ TinyPtrVector<DbgInfoIntrinsic *> llvm::FindDbgAddrUses(Value *V) { } void llvm::findDbgValues(SmallVectorImpl<DbgValueInst *> &DbgValues, Value *V) { + // This function is hot. Check whether the value has any metadata to avoid a + // DenseMap lookup. + if (!V->isUsedByMetadata()) + return; if (auto *L = LocalAsMetadata::getIfExists(V)) if (auto *MDV = MetadataAsValue::getIfExists(V->getContext(), L)) for (User *U : MDV->users()) @@ -1293,8 +1512,12 @@ void llvm::findDbgValues(SmallVectorImpl<DbgValueInst *> &DbgValues, Value *V) { DbgValues.push_back(DVI); } -static void findDbgUsers(SmallVectorImpl<DbgInfoIntrinsic *> &DbgUsers, - Value *V) { +void llvm::findDbgUsers(SmallVectorImpl<DbgInfoIntrinsic *> &DbgUsers, + Value *V) { + // This function is hot. Check whether the value has any metadata to avoid a + // DenseMap lookup. + if (!V->isUsedByMetadata()) + return; if (auto *L = LocalAsMetadata::getIfExists(V)) if (auto *MDV = MetadataAsValue::getIfExists(V->getContext(), L)) for (User *U : MDV->users()) @@ -1312,11 +1535,11 @@ bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, auto *DIExpr = DII->getExpression(); assert(DIVar && "Missing variable"); DIExpr = DIExpression::prepend(DIExpr, DerefBefore, Offset, DerefAfter); - // Insert llvm.dbg.declare immediately after InsertBefore, and remove old + // Insert llvm.dbg.declare immediately before InsertBefore, and remove old // llvm.dbg.declare. Builder.insertDeclare(NewAddress, DIVar, DIExpr, Loc, InsertBefore); if (DII == InsertBefore) - InsertBefore = &*std::next(InsertBefore->getIterator()); + InsertBefore = InsertBefore->getNextNode(); DII->eraseFromParent(); } return !DbgAddrs.empty(); @@ -1368,66 +1591,293 @@ void llvm::replaceDbgValueForAlloca(AllocaInst *AI, Value *NewAllocaAddress, } } -void llvm::salvageDebugInfo(Instruction &I) { - SmallVector<DbgValueInst *, 1> DbgValues; +/// Wrap \p V in a ValueAsMetadata instance. +static MetadataAsValue *wrapValueInMetadata(LLVMContext &C, Value *V) { + return MetadataAsValue::get(C, ValueAsMetadata::get(V)); +} + +bool llvm::salvageDebugInfo(Instruction &I) { + SmallVector<DbgInfoIntrinsic *, 1> DbgUsers; + findDbgUsers(DbgUsers, &I); + if (DbgUsers.empty()) + return false; + auto &M = *I.getModule(); + auto &DL = M.getDataLayout(); + auto &Ctx = I.getContext(); + auto wrapMD = [&](Value *V) { return wrapValueInMetadata(Ctx, V); }; - auto wrapMD = [&](Value *V) { - return MetadataAsValue::get(I.getContext(), ValueAsMetadata::get(V)); + auto doSalvage = [&](DbgInfoIntrinsic *DII, SmallVectorImpl<uint64_t> &Ops) { + auto *DIExpr = DII->getExpression(); + if (!Ops.empty()) { + // Do not add DW_OP_stack_value for DbgDeclare and DbgAddr, because they + // are implicitly pointing out the value as a DWARF memory location + // description. 
+ bool WithStackValue = isa<DbgValueInst>(DII); + DIExpr = DIExpression::prependOpcodes(DIExpr, Ops, WithStackValue); + } + DII->setOperand(0, wrapMD(I.getOperand(0))); + DII->setOperand(2, MetadataAsValue::get(Ctx, DIExpr)); + LLVM_DEBUG(dbgs() << "SALVAGE: " << *DII << '\n'); }; - auto applyOffset = [&](DbgValueInst *DVI, uint64_t Offset) { - auto *DIExpr = DVI->getExpression(); - DIExpr = DIExpression::prepend(DIExpr, DIExpression::NoDeref, Offset, - DIExpression::NoDeref, - DIExpression::WithStackValue); - DVI->setOperand(0, wrapMD(I.getOperand(0))); - DVI->setOperand(2, MetadataAsValue::get(I.getContext(), DIExpr)); - DEBUG(dbgs() << "SALVAGE: " << *DVI << '\n'); + auto applyOffset = [&](DbgInfoIntrinsic *DII, uint64_t Offset) { + SmallVector<uint64_t, 8> Ops; + DIExpression::appendOffset(Ops, Offset); + doSalvage(DII, Ops); }; - if (isa<BitCastInst>(&I) || isa<IntToPtrInst>(&I)) { - // Bitcasts are entirely irrelevant for debug info. Rewrite dbg.value, - // dbg.addr, and dbg.declare to use the cast's source. - SmallVector<DbgInfoIntrinsic *, 1> DbgUsers; - findDbgUsers(DbgUsers, &I); + auto applyOps = [&](DbgInfoIntrinsic *DII, + std::initializer_list<uint64_t> Opcodes) { + SmallVector<uint64_t, 8> Ops(Opcodes); + doSalvage(DII, Ops); + }; + + if (auto *CI = dyn_cast<CastInst>(&I)) { + if (!CI->isNoopCast(DL)) + return false; + + // No-op casts are irrelevant for debug info. + MetadataAsValue *CastSrc = wrapMD(I.getOperand(0)); for (auto *DII : DbgUsers) { - DII->setOperand(0, wrapMD(I.getOperand(0))); - DEBUG(dbgs() << "SALVAGE: " << *DII << '\n'); + DII->setOperand(0, CastSrc); + LLVM_DEBUG(dbgs() << "SALVAGE: " << *DII << '\n'); } + return true; } else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { - findDbgValues(DbgValues, &I); - for (auto *DVI : DbgValues) { - unsigned BitWidth = - M.getDataLayout().getPointerSizeInBits(GEP->getPointerAddressSpace()); - APInt Offset(BitWidth, 0); - // Rewrite a constant GEP into a DIExpression. Since we are performing - // arithmetic to compute the variable's *value* in the DIExpression, we - // need to mark the expression with a DW_OP_stack_value. - if (GEP->accumulateConstantOffset(M.getDataLayout(), Offset)) - // GEP offsets are i32 and thus always fit into an int64_t. - applyOffset(DVI, Offset.getSExtValue()); - } + unsigned BitWidth = + M.getDataLayout().getIndexSizeInBits(GEP->getPointerAddressSpace()); + // Rewrite a constant GEP into a DIExpression. Since we are performing + // arithmetic to compute the variable's *value* in the DIExpression, we + // need to mark the expression with a DW_OP_stack_value. + APInt Offset(BitWidth, 0); + if (GEP->accumulateConstantOffset(M.getDataLayout(), Offset)) + for (auto *DII : DbgUsers) + applyOffset(DII, Offset.getSExtValue()); + return true; } else if (auto *BI = dyn_cast<BinaryOperator>(&I)) { - if (BI->getOpcode() == Instruction::Add) - if (auto *ConstInt = dyn_cast<ConstantInt>(I.getOperand(1))) - if (ConstInt->getBitWidth() <= 64) { - APInt Offset = ConstInt->getValue(); - findDbgValues(DbgValues, &I); - for (auto *DVI : DbgValues) - applyOffset(DVI, Offset.getSExtValue()); - } + // Rewrite binary operations with constant integer operands. 
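// For example, when `%x = mul i64 %y, 16` is going away, a user
// `dbg.value(%x, ...)` is redirected below to %y with
// `DW_OP_constu 16, DW_OP_mul` (plus DW_OP_stack_value) appended, so the
// debugger recomputes the lost value from the surviving operand.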
+ auto *ConstInt = dyn_cast<ConstantInt>(I.getOperand(1)); + if (!ConstInt || ConstInt->getBitWidth() > 64) + return false; + + uint64_t Val = ConstInt->getSExtValue(); + for (auto *DII : DbgUsers) { + switch (BI->getOpcode()) { + case Instruction::Add: + applyOffset(DII, Val); + break; + case Instruction::Sub: + applyOffset(DII, -int64_t(Val)); + break; + case Instruction::Mul: + applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_mul}); + break; + case Instruction::SDiv: + applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_div}); + break; + case Instruction::SRem: + applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_mod}); + break; + case Instruction::Or: + applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_or}); + break; + case Instruction::And: + applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_and}); + break; + case Instruction::Xor: + applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_xor}); + break; + case Instruction::Shl: + applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_shl}); + break; + case Instruction::LShr: + applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_shr}); + break; + case Instruction::AShr: + applyOps(DII, {dwarf::DW_OP_constu, Val, dwarf::DW_OP_shra}); + break; + default: + // TODO: Salvage constants from each kind of binop we know about. + return false; + } + } + return true; } else if (isa<LoadInst>(&I)) { - findDbgValues(DbgValues, &I); - for (auto *DVI : DbgValues) { + MetadataAsValue *AddrMD = wrapMD(I.getOperand(0)); + for (auto *DII : DbgUsers) { // Rewrite the load into DW_OP_deref. - auto *DIExpr = DVI->getExpression(); + auto *DIExpr = DII->getExpression(); DIExpr = DIExpression::prepend(DIExpr, DIExpression::WithDeref); - DVI->setOperand(0, wrapMD(I.getOperand(0))); - DVI->setOperand(2, MetadataAsValue::get(I.getContext(), DIExpr)); - DEBUG(dbgs() << "SALVAGE: " << *DVI << '\n'); + DII->setOperand(0, AddrMD); + DII->setOperand(2, MetadataAsValue::get(Ctx, DIExpr)); + LLVM_DEBUG(dbgs() << "SALVAGE: " << *DII << '\n'); + } + return true; + } + return false; +} + +/// A replacement for a dbg.value expression. +using DbgValReplacement = Optional<DIExpression *>; + +/// Point debug users of \p From to \p To using exprs given by \p RewriteExpr, +/// possibly moving/deleting users to prevent use-before-def. Returns true if +/// changes are made. +static bool rewriteDebugUsers( + Instruction &From, Value &To, Instruction &DomPoint, DominatorTree &DT, + function_ref<DbgValReplacement(DbgInfoIntrinsic &DII)> RewriteExpr) { + // Find debug users of From. + SmallVector<DbgInfoIntrinsic *, 1> Users; + findDbgUsers(Users, &From); + if (Users.empty()) + return false; + + // Prevent use-before-def of To. + bool Changed = false; + SmallPtrSet<DbgInfoIntrinsic *, 1> DeleteOrSalvage; + if (isa<Instruction>(&To)) { + bool DomPointAfterFrom = From.getNextNonDebugInstruction() == &DomPoint; + + for (auto *DII : Users) { + // It's common to see a debug user between From and DomPoint. Move it + // after DomPoint to preserve the variable update without any reordering. + if (DomPointAfterFrom && DII->getNextNonDebugInstruction() == &DomPoint) { + LLVM_DEBUG(dbgs() << "MOVE: " << *DII << '\n'); + DII->moveAfter(&DomPoint); + Changed = true; + + // Users which otherwise aren't dominated by the replacement value must + // be salvaged or deleted. + } else if (!DT.dominates(&DomPoint, DII)) { + DeleteOrSalvage.insert(DII); + } } } + + // Update debug users without use-before-def risk. 
+ for (auto *DII : Users) { + if (DeleteOrSalvage.count(DII)) + continue; + + LLVMContext &Ctx = DII->getContext(); + DbgValReplacement DVR = RewriteExpr(*DII); + if (!DVR) + continue; + + DII->setOperand(0, wrapValueInMetadata(Ctx, &To)); + DII->setOperand(2, MetadataAsValue::get(Ctx, *DVR)); + LLVM_DEBUG(dbgs() << "REWRITE: " << *DII << '\n'); + Changed = true; + } + + if (!DeleteOrSalvage.empty()) { + // Try to salvage the remaining debug users. + Changed |= salvageDebugInfo(From); + + // Delete the debug users which weren't salvaged. + for (auto *DII : DeleteOrSalvage) { + if (DII->getVariableLocation() == &From) { + LLVM_DEBUG(dbgs() << "Erased UseBeforeDef: " << *DII << '\n'); + DII->eraseFromParent(); + Changed = true; + } + } + } + + return Changed; +} + +/// Check if a bitcast between a value of type \p FromTy to type \p ToTy would +/// losslessly preserve the bits and semantics of the value. This predicate is +/// symmetric, i.e. swapping \p FromTy and \p ToTy should give the same result. +/// +/// Note that Type::canLosslesslyBitCastTo is not suitable here because it +/// allows semantically inequivalent bitcasts, such as <2 x i64> -> <4 x i32>, +/// and also does not allow lossless pointer <-> integer conversions. +static bool isBitCastSemanticsPreserving(const DataLayout &DL, Type *FromTy, + Type *ToTy) { + // Trivially compatible types. + if (FromTy == ToTy) + return true; + + // Handle compatible pointer <-> integer conversions. + if (FromTy->isIntOrPtrTy() && ToTy->isIntOrPtrTy()) { + bool SameSize = DL.getTypeSizeInBits(FromTy) == DL.getTypeSizeInBits(ToTy); + bool LosslessConversion = !DL.isNonIntegralPointerType(FromTy) && + !DL.isNonIntegralPointerType(ToTy); + return SameSize && LosslessConversion; + } + + // TODO: This is not exhaustive. + return false; +} + +bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To, + Instruction &DomPoint, DominatorTree &DT) { + // Exit early if From has no debug users. + if (!From.isUsedByMetadata()) + return false; + + assert(&From != &To && "Can't replace something with itself"); + + Type *FromTy = From.getType(); + Type *ToTy = To.getType(); + + auto Identity = [&](DbgInfoIntrinsic &DII) -> DbgValReplacement { + return DII.getExpression(); + }; + + // Handle no-op conversions. + Module &M = *From.getModule(); + const DataLayout &DL = M.getDataLayout(); + if (isBitCastSemanticsPreserving(DL, FromTy, ToTy)) + return rewriteDebugUsers(From, To, DomPoint, DT, Identity); + + // Handle integer-to-integer widening and narrowing. + // FIXME: Use DW_OP_convert when it's available everywhere. + if (FromTy->isIntegerTy() && ToTy->isIntegerTy()) { + uint64_t FromBits = FromTy->getPrimitiveSizeInBits(); + uint64_t ToBits = ToTy->getPrimitiveSizeInBits(); + assert(FromBits != ToBits && "Unexpected no-op conversion"); + + // When the width of the result grows, assume that a debugger will only + // access the low `FromBits` bits when inspecting the source variable. + if (FromBits < ToBits) + return rewriteDebugUsers(From, To, DomPoint, DT, Identity); + + // The width of the result has shrunk. Use sign/zero extension to describe + // the source variable's high bits. + auto SignOrZeroExt = [&](DbgInfoIntrinsic &DII) -> DbgValReplacement { + DILocalVariable *Var = DII.getVariable(); + + // Without knowing signedness, sign/zero extension isn't possible.
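// (Signedness is taken from the variable's DIBasicType DWARF encoding,
// e.g. DW_ATE_signed vs. DW_ATE_unsigned. For an i32 -> i16 truncation of a
// signed variable, ToBits - 1 is 15, so the ops built below are
//   DW_OP_dup, DW_OP_constu 15, DW_OP_shr, DW_OP_lit0, DW_OP_not,
//   DW_OP_mul, DW_OP_or
// following the "compute the high bits and OR them with the low bits"
// recipe described in the comment below.)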
+ auto Signedness = Var->getSignedness(); + if (!Signedness) + return None; + + bool Signed = *Signedness == DIBasicType::Signedness::Signed; + + if (!Signed) { + // In the unsigned case, assume that a debugger will initialize the + // high bits to 0 and do a no-op conversion. + return Identity(DII); + } else { + // In the signed case, the high bits are given by sign extension, i.e: + // (To >> (ToBits - 1)) * ((2 ^ FromBits) - 1) + // Calculate the high bits and OR them together with the low bits. + SmallVector<uint64_t, 8> Ops({dwarf::DW_OP_dup, dwarf::DW_OP_constu, + (ToBits - 1), dwarf::DW_OP_shr, + dwarf::DW_OP_lit0, dwarf::DW_OP_not, + dwarf::DW_OP_mul, dwarf::DW_OP_or}); + return DIExpression::appendToStack(DII.getExpression(), Ops); + } + }; + return rewriteDebugUsers(From, To, DomPoint, DT, SignOrZeroExt); + } + + // TODO: Floating-point conversions, vectors. + return false; } unsigned llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) { @@ -1452,13 +1902,19 @@ unsigned llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) { } unsigned llvm::changeToUnreachable(Instruction *I, bool UseLLVMTrap, - bool PreserveLCSSA) { + bool PreserveLCSSA, DeferredDominance *DDT) { BasicBlock *BB = I->getParent(); + std::vector <DominatorTree::UpdateType> Updates; + // Loop over all of the successors, removing BB's entry from any PHI // nodes. - for (BasicBlock *Successor : successors(BB)) + if (DDT) + Updates.reserve(BB->getTerminator()->getNumSuccessors()); + for (BasicBlock *Successor : successors(BB)) { Successor->removePredecessor(BB, PreserveLCSSA); - + if (DDT) + Updates.push_back({DominatorTree::Delete, BB, Successor}); + } // Insert a call to llvm.trap right before this. This turns the undefined // behavior into a hard fail instead of falling through into random code. if (UseLLVMTrap) { @@ -1478,11 +1934,13 @@ unsigned llvm::changeToUnreachable(Instruction *I, bool UseLLVMTrap, BB->getInstList().erase(BBI++); ++NumInstrsRemoved; } + if (DDT) + DDT->applyUpdates(Updates); return NumInstrsRemoved; } /// changeToCall - Convert the specified invoke into a normal call. -static void changeToCall(InvokeInst *II) { +static void changeToCall(InvokeInst *II, DeferredDominance *DDT = nullptr) { SmallVector<Value*, 8> Args(II->arg_begin(), II->arg_end()); SmallVector<OperandBundleDef, 1> OpBundles; II->getOperandBundlesAsDefs(OpBundles); @@ -1495,11 +1953,16 @@ static void changeToCall(InvokeInst *II) { II->replaceAllUsesWith(NewCall); // Follow the call by a branch to the normal destination. - BranchInst::Create(II->getNormalDest(), II); + BasicBlock *NormalDestBB = II->getNormalDest(); + BranchInst::Create(NormalDestBB, II); // Update PHI nodes in the unwind destination - II->getUnwindDest()->removePredecessor(II->getParent()); + BasicBlock *BB = II->getParent(); + BasicBlock *UnwindDestBB = II->getUnwindDest(); + UnwindDestBB->removePredecessor(BB); II->eraseFromParent(); + if (DDT) + DDT->deleteEdge(BB, UnwindDestBB); } BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, @@ -1540,7 +2003,8 @@ BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, } static bool markAliveBlocks(Function &F, - SmallPtrSetImpl<BasicBlock*> &Reachable) { + SmallPtrSetImpl<BasicBlock*> &Reachable, + DeferredDominance *DDT = nullptr) { SmallVector<BasicBlock*, 128> Worklist; BasicBlock *BB = &F.front(); Worklist.push_back(BB); @@ -1553,41 +2017,44 @@ static bool markAliveBlocks(Function &F, // instructions into LLVM unreachable insts. 
The instruction combining pass // canonicalizes unreachable insts into stores to null or undef. for (Instruction &I : *BB) { - // Assumptions that are known to be false are equivalent to unreachable. - // Also, if the condition is undefined, then we make the choice most - // beneficial to the optimizer, and choose that to also be unreachable. - if (auto *II = dyn_cast<IntrinsicInst>(&I)) { - if (II->getIntrinsicID() == Intrinsic::assume) { - if (match(II->getArgOperand(0), m_CombineOr(m_Zero(), m_Undef()))) { - // Don't insert a call to llvm.trap right before the unreachable. - changeToUnreachable(II, false); - Changed = true; - break; - } - } - - if (II->getIntrinsicID() == Intrinsic::experimental_guard) { - // A call to the guard intrinsic bails out of the current compilation - // unit if the predicate passed to it is false. If the predicate is a - // constant false, then we know the guard will bail out of the current - // compile unconditionally, so all code following it is dead. - // - // Note: unlike in llvm.assume, it is not "obviously profitable" for - // guards to treat `undef` as `false` since a guard on `undef` can - // still be useful for widening. - if (match(II->getArgOperand(0), m_Zero())) - if (!isa<UnreachableInst>(II->getNextNode())) { - changeToUnreachable(II->getNextNode(), /*UseLLVMTrap=*/ false); + if (auto *CI = dyn_cast<CallInst>(&I)) { + Value *Callee = CI->getCalledValue(); + // Handle intrinsic calls. + if (Function *F = dyn_cast<Function>(Callee)) { + auto IntrinsicID = F->getIntrinsicID(); + // Assumptions that are known to be false are equivalent to + // unreachable. Also, if the condition is undefined, then we make the + // choice most beneficial to the optimizer, and choose that to also be + // unreachable. + if (IntrinsicID == Intrinsic::assume) { + if (match(CI->getArgOperand(0), m_CombineOr(m_Zero(), m_Undef()))) { + // Don't insert a call to llvm.trap right before the unreachable. + changeToUnreachable(CI, false, false, DDT); Changed = true; break; } - } - } - - if (auto *CI = dyn_cast<CallInst>(&I)) { - Value *Callee = CI->getCalledValue(); - if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) { - changeToUnreachable(CI, /*UseLLVMTrap=*/false); + } else if (IntrinsicID == Intrinsic::experimental_guard) { + // A call to the guard intrinsic bails out of the current + // compilation unit if the predicate passed to it is false. If the + // predicate is a constant false, then we know the guard will bail + // out of the current compile unconditionally, so all code following + // it is dead. + // + // Note: unlike in llvm.assume, it is not "obviously profitable" for + // guards to treat `undef` as `false` since a guard on `undef` can + // still be useful for widening. + if (match(CI->getArgOperand(0), m_Zero())) + if (!isa<UnreachableInst>(CI->getNextNode())) { + changeToUnreachable(CI->getNextNode(), /*UseLLVMTrap=*/false, + false, DDT); + Changed = true; + break; + } + } + } else if ((isa<ConstantPointerNull>(Callee) && + !NullPointerIsDefined(CI->getFunction())) || + isa<UndefValue>(Callee)) { + changeToUnreachable(CI, /*UseLLVMTrap=*/false, false, DDT); Changed = true; break; } @@ -1597,17 +2064,16 @@ static bool markAliveBlocks(Function &F, // though. if (!isa<UnreachableInst>(CI->getNextNode())) { // Don't insert a call to llvm.trap right before the unreachable. 
- changeToUnreachable(CI->getNextNode(), false); + changeToUnreachable(CI->getNextNode(), false, false, DDT); Changed = true; } break; } - } + } else if (auto *SI = dyn_cast<StoreInst>(&I)) { + // Store to undef and store to null are undefined and used to signal + // that they should be changed to unreachable by passes that can't + // modify the CFG. - // Store to undef and store to null are undefined and used to signal that - // they should be changed to unreachable by passes that can't modify the - // CFG. - if (auto *SI = dyn_cast<StoreInst>(&I)) { // Don't touch volatile stores. if (SI->isVolatile()) continue; @@ -1615,8 +2081,9 @@ static bool markAliveBlocks(Function &F, if (isa<UndefValue>(Ptr) || (isa<ConstantPointerNull>(Ptr) && - SI->getPointerAddressSpace() == 0)) { - changeToUnreachable(SI, true); + !NullPointerIsDefined(SI->getFunction(), + SI->getPointerAddressSpace()))) { + changeToUnreachable(SI, true, false, DDT); Changed = true; break; } @@ -1627,17 +2094,23 @@ static bool markAliveBlocks(Function &F, if (auto *II = dyn_cast<InvokeInst>(Terminator)) { // Turn invokes that call 'nounwind' functions into ordinary calls. Value *Callee = II->getCalledValue(); - if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) { - changeToUnreachable(II, true); + if ((isa<ConstantPointerNull>(Callee) && + !NullPointerIsDefined(BB->getParent())) || + isa<UndefValue>(Callee)) { + changeToUnreachable(II, true, false, DDT); Changed = true; } else if (II->doesNotThrow() && canSimplifyInvokeNoUnwind(&F)) { if (II->use_empty() && II->onlyReadsMemory()) { // jump to the normal destination branch. - BranchInst::Create(II->getNormalDest(), II); - II->getUnwindDest()->removePredecessor(II->getParent()); + BasicBlock *NormalDestBB = II->getNormalDest(); + BasicBlock *UnwindDestBB = II->getUnwindDest(); + BranchInst::Create(NormalDestBB, II); + UnwindDestBB->removePredecessor(II->getParent()); II->eraseFromParent(); + if (DDT) + DDT->deleteEdge(BB, UnwindDestBB); } else - changeToCall(II); + changeToCall(II, DDT); Changed = true; } } else if (auto *CatchSwitch = dyn_cast<CatchSwitchInst>(Terminator)) { @@ -1683,7 +2156,7 @@ static bool markAliveBlocks(Function &F, } } - Changed |= ConstantFoldTerminator(BB, true); + Changed |= ConstantFoldTerminator(BB, true, nullptr, DDT); for (BasicBlock *Successor : successors(BB)) if (Reachable.insert(Successor).second) Worklist.push_back(Successor); @@ -1691,11 +2164,11 @@ static bool markAliveBlocks(Function &F, return Changed; } -void llvm::removeUnwindEdge(BasicBlock *BB) { +void llvm::removeUnwindEdge(BasicBlock *BB, DeferredDominance *DDT) { TerminatorInst *TI = BB->getTerminator(); if (auto *II = dyn_cast<InvokeInst>(TI)) { - changeToCall(II); + changeToCall(II, DDT); return; } @@ -1723,15 +2196,18 @@ void llvm::removeUnwindEdge(BasicBlock *BB) { UnwindDest->removePredecessor(BB); TI->replaceAllUsesWith(NewTI); TI->eraseFromParent(); + if (DDT) + DDT->deleteEdge(BB, UnwindDest); } /// removeUnreachableBlocks - Remove blocks that are not reachable, even /// if they are in a dead cycle. Return true if a change was made, false /// otherwise. If `LVI` is passed, this function preserves LazyValueInfo /// after modifying the CFG. 
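Stripped of the instruction rewriting, the control flow of markAliveBlocks and removeUnreachableBlocks is an ordinary worklist walk over successor edges. A minimal sketch in plain C++ (toy types, not LLVM's):

    #include <cstddef>
    #include <set>
    #include <vector>

    // Toy CFG node: just a list of successor indices. Node 0 plays the role
    // of F.front(), the function entry block.
    struct Block { std::vector<std::size_t> Succs; };

    static std::set<std::size_t> markAlive(const std::vector<Block> &CFG) {
      std::set<std::size_t> Reachable;
      std::vector<std::size_t> Worklist{0};
      Reachable.insert(0);
      while (!Worklist.empty()) {
        std::size_t BB = Worklist.back();
        Worklist.pop_back();
        for (std::size_t Succ : CFG[BB].Succs)
          if (Reachable.insert(Succ).second) // enqueue on first visit only
            Worklist.push_back(Succ);
      }
      return Reachable; // every block not in this set is dead
    }

    int main() {
      // 0 -> 1 -> 2, with node 3 unreachable (it only points at 2).
      std::vector<Block> CFG{{{1}}, {{2}}, {{}}, {{2}}};
      return markAlive(CFG).count(3) ? 1 : 0;
    }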
-bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI) { +bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI, + DeferredDominance *DDT) { SmallPtrSet<BasicBlock*, 16> Reachable; - bool Changed = markAliveBlocks(F, Reachable); + bool Changed = markAliveBlocks(F, Reachable, DDT); // If there are unreachable blocks in the CFG... if (Reachable.size() == F.size()) @@ -1741,25 +2217,39 @@ bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI) { NumRemoved += F.size()-Reachable.size(); // Loop over all of the basic blocks that are not reachable, dropping all of - // their internal references... - for (Function::iterator BB = ++F.begin(), E = F.end(); BB != E; ++BB) { - if (Reachable.count(&*BB)) + // their internal references. Update DDT and LVI if available. + std::vector <DominatorTree::UpdateType> Updates; + for (Function::iterator I = ++F.begin(), E = F.end(); I != E; ++I) { + auto *BB = &*I; + if (Reachable.count(BB)) continue; - - for (BasicBlock *Successor : successors(&*BB)) + for (BasicBlock *Successor : successors(BB)) { if (Reachable.count(Successor)) - Successor->removePredecessor(&*BB); + Successor->removePredecessor(BB); + if (DDT) + Updates.push_back({DominatorTree::Delete, BB, Successor}); + } if (LVI) - LVI->eraseBlock(&*BB); + LVI->eraseBlock(BB); BB->dropAllReferences(); } - for (Function::iterator I = ++F.begin(); I != F.end();) - if (!Reachable.count(&*I)) - I = F.getBasicBlockList().erase(I); - else + for (Function::iterator I = ++F.begin(); I != F.end();) { + auto *BB = &*I; + if (Reachable.count(BB)) { ++I; + continue; + } + if (DDT) { + DDT->deleteBB(BB); // deferred deletion of BB. + ++I; + } else { + I = F.getBasicBlockList().erase(I); + } + } + if (DDT) + DDT->applyUpdates(Updates); return true; } @@ -1852,8 +2342,8 @@ static unsigned replaceDominatedUsesWith(Value *From, Value *To, if (!Dominates(Root, U)) continue; U.set(To); - DEBUG(dbgs() << "Replace dominated use of '" << From->getName() << "' as " - << *To << " in " << *U << "\n"); + LLVM_DEBUG(dbgs() << "Replace dominated use of '" << From->getName() + << "' as " << *To << " in " << *U << "\n"); ++Count; } return Count; @@ -1957,7 +2447,7 @@ void llvm::copyRangeMetadata(const DataLayout &DL, const LoadInst &OldLI, if (!NewTy->isPointerTy()) return; - unsigned BitWidth = DL.getTypeSizeInBits(NewTy); + unsigned BitWidth = DL.getIndexTypeSizeInBits(NewTy); if (!getConstantRangeFromMetadata(*N).contains(APInt(BitWidth, 0))) { MDNode *NN = MDNode::get(OldLI.getContext(), None); NewLI.setMetadata(LLVMContext::MD_nonnull, NN); @@ -2269,7 +2759,7 @@ bool llvm::canReplaceOperandWithVariable(const Instruction *I, unsigned OpIdx) { // Static allocas (constant size in the entry block) are handled by // prologue/epilogue insertion so they're free anyway. We definitely don't // want to make them non-constant. - return !dyn_cast<AllocaInst>(I)->isStaticAlloca(); + return !cast<AllocaInst>(I)->isStaticAlloca(); case Instruction::GetElementPtr: if (OpIdx == 0) return true; diff --git a/lib/Transforms/Utils/LoopRotationUtils.cpp b/lib/Transforms/Utils/LoopRotationUtils.cpp new file mode 100644 index 000000000000..6e92e679f999 --- /dev/null +++ b/lib/Transforms/Utils/LoopRotationUtils.cpp @@ -0,0 +1,645 @@ +//===----------------- LoopRotationUtils.cpp -----------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides utilities to convert a loop into a loop with bottom test.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/LoopRotationUtils.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/SSAUpdater.h"
+#include "llvm/Transforms/Utils/ValueMapper.h"
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-rotate"
+
+STATISTIC(NumRotated, "Number of loops rotated");
+
+namespace {
+/// A simple loop rotation transformation.
+class LoopRotate {
+  const unsigned MaxHeaderSize;
+  LoopInfo *LI;
+  const TargetTransformInfo *TTI;
+  AssumptionCache *AC;
+  DominatorTree *DT;
+  ScalarEvolution *SE;
+  const SimplifyQuery &SQ;
+  bool RotationOnly;
+  bool IsUtilMode;
+
+public:
+  LoopRotate(unsigned MaxHeaderSize, LoopInfo *LI,
+             const TargetTransformInfo *TTI, AssumptionCache *AC,
+             DominatorTree *DT, ScalarEvolution *SE, const SimplifyQuery &SQ,
+             bool RotationOnly, bool IsUtilMode)
+      : MaxHeaderSize(MaxHeaderSize), LI(LI), TTI(TTI), AC(AC), DT(DT), SE(SE),
+        SQ(SQ), RotationOnly(RotationOnly), IsUtilMode(IsUtilMode) {}
+  bool processLoop(Loop *L);
+
+private:
+  bool rotateLoop(Loop *L, bool SimplifiedLatch);
+  bool simplifyLoopLatch(Loop *L);
+};
+} // end anonymous namespace
+
+/// RewriteUsesOfClonedInstructions - We just cloned the instructions from the
+/// old header into the preheader. If there were uses of the values produced by
+/// these instructions that were outside of the loop, we have to insert PHI
+/// nodes to merge the two values. Do this now.
+static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader,
+                                            BasicBlock *OrigPreheader,
+                                            ValueToValueMapTy &ValueMap,
+                                SmallVectorImpl<PHINode*> *InsertedPHIs) {
+  // Remove PHI node entries that are no longer live.
+  BasicBlock::iterator I, E = OrigHeader->end();
+  for (I = OrigHeader->begin(); PHINode *PN = dyn_cast<PHINode>(I); ++I)
+    PN->removeIncomingValue(PN->getBasicBlockIndex(OrigPreheader));
+
+  // Now fix up users of the instructions in OrigHeader, inserting PHI nodes
+  // as necessary.
+  SSAUpdater SSA(InsertedPHIs);
+  for (I = OrigHeader->begin(); I != E; ++I) {
+    Value *OrigHeaderVal = &*I;
+
+    // If there are no uses of the value (e.g. because it returns void), there
+    // is nothing to rewrite.
+    if (OrigHeaderVal->use_empty())
+      continue;
+
+    Value *OrigPreHeaderVal = ValueMap.lookup(OrigHeaderVal);
+
+    // The value now exists in two versions: the initial value in the preheader
+    // and the loop "next" value in the original header.
+    SSA.Initialize(OrigHeaderVal->getType(), OrigHeaderVal->getName());
+    SSA.AddAvailableValue(OrigHeader, OrigHeaderVal);
+    SSA.AddAvailableValue(OrigPreheader, OrigPreHeaderVal);
+
+    // Visit each use of the OrigHeader instruction.
+    for (Value::use_iterator UI = OrigHeaderVal->use_begin(),
+                             UE = OrigHeaderVal->use_end();
+         UI != UE;) {
+      // Grab the use before incrementing the iterator.
+      Use &U = *UI;
+
+      // Increment the iterator before removing the use from the list.
+      ++UI;
+
+      // SSAUpdater can't handle a non-PHI use in the same block as an
+      // earlier def. We can easily handle those cases manually.
+      Instruction *UserInst = cast<Instruction>(U.getUser());
+      if (!isa<PHINode>(UserInst)) {
+        BasicBlock *UserBB = UserInst->getParent();
+
+        // The original users in the OrigHeader are already using the
+        // original definitions.
+        if (UserBB == OrigHeader)
+          continue;
+
+        // Users in the OrigPreHeader need to use the value to which the
+        // original definitions are mapped.
+        if (UserBB == OrigPreheader) {
+          U = OrigPreHeaderVal;
+          continue;
+        }
+      }
+
+      // Anything else can be handled by SSAUpdater.
+      SSA.RewriteUse(U);
+    }
+
+    // Replace MetadataAsValue(ValueAsMetadata(OrigHeaderVal)) uses in debug
+    // intrinsics.
+    SmallVector<DbgValueInst *, 1> DbgValues;
+    llvm::findDbgValues(DbgValues, OrigHeaderVal);
+    for (auto &DbgValue : DbgValues) {
+      // The original users in the OrigHeader are already using the original
+      // definitions.
+      BasicBlock *UserBB = DbgValue->getParent();
+      if (UserBB == OrigHeader)
+        continue;
+
+      // Users in the OrigPreHeader need to use the value to which the
+      // original definitions are mapped and anything else can be handled by
+      // the SSAUpdater. To avoid adding PHINodes, check if the value is
+      // available in UserBB, if not substitute undef.
+      Value *NewVal;
+      if (UserBB == OrigPreheader)
+        NewVal = OrigPreHeaderVal;
+      else if (SSA.HasValueForBlock(UserBB))
+        NewVal = SSA.GetValueInMiddleOfBlock(UserBB);
+      else
+        NewVal = UndefValue::get(OrigHeaderVal->getType());
+      DbgValue->setOperand(0,
+                           MetadataAsValue::get(OrigHeaderVal->getContext(),
+                                                ValueAsMetadata::get(NewVal)));
+    }
+  }
+}
+
+// Look for a phi which is only used outside the loop (via a LCSSA phi)
+// in the exit from the header. This means that rotating the loop can
+// remove the phi.
+static bool shouldRotateLoopExitingLatch(Loop *L) {
+  BasicBlock *Header = L->getHeader();
+  BasicBlock *HeaderExit = Header->getTerminator()->getSuccessor(0);
+  if (L->contains(HeaderExit))
+    HeaderExit = Header->getTerminator()->getSuccessor(1);
+
+  for (auto &Phi : Header->phis()) {
+    // Look for uses of this phi in the loop/via exits other than the header.
+    if (llvm::any_of(Phi.users(), [HeaderExit](const User *U) {
+          return cast<Instruction>(U)->getParent() != HeaderExit;
+        }))
+      continue;
+    return true;
+  }
+
+  return false;
+}
+
+/// Rotate loop LP. Return true if the loop is rotated.
+///
+/// \param SimplifiedLatch is true if the latch was just folded into the final
+/// loop exit. In this case we may want to rotate even though the new latch is
+/// now an exiting branch. This rotation would have happened had the latch not
+/// been simplified. However, if SimplifiedLatch is false, then we avoid
+/// rotating loops in which the latch exits to avoid excessive or endless
+/// rotation. LoopRotate should be repeatable and converge to a canonical
+/// form. This property is satisfied because simplifying the loop latch can
+/// only happen once across multiple invocations of the LoopRotate pass.
+bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
+  // If the loop has only one block then there is not much to rotate.
+  if (L->getBlocks().size() == 1)
+    return false;
+
+  BasicBlock *OrigHeader = L->getHeader();
+  BasicBlock *OrigLatch = L->getLoopLatch();
+
+  BranchInst *BI = dyn_cast<BranchInst>(OrigHeader->getTerminator());
+  if (!BI || BI->isUnconditional())
+    return false;
+
+  // If the loop header is not one of the loop exiting blocks then
+  // either this loop is already rotated or it is not
+  // suitable for loop rotation transformations.
+  if (!L->isLoopExiting(OrigHeader))
+    return false;
+
+  // If the loop latch already contains a branch that leaves the loop then the
+  // loop is already rotated.
+  if (!OrigLatch)
+    return false;
+
+  // Rotate if either the loop latch does *not* exit the loop, or if the loop
+  // latch was just simplified. Or if we think it will be profitable.
+  if (L->isLoopExiting(OrigLatch) && !SimplifiedLatch && IsUtilMode == false &&
+      !shouldRotateLoopExitingLatch(L))
+    return false;
+
+  // Check the size of the original header and reject the loop if it is very
+  // big or we can't duplicate blocks inside it.
+  {
+    SmallPtrSet<const Value *, 32> EphValues;
+    CodeMetrics::collectEphemeralValues(L, AC, EphValues);
+
+    CodeMetrics Metrics;
+    Metrics.analyzeBasicBlock(OrigHeader, *TTI, EphValues);
+    if (Metrics.notDuplicatable) {
+      LLVM_DEBUG(
+          dbgs() << "LoopRotation: NOT rotating - contains non-duplicatable"
+                 << " instructions: ";
+          L->dump());
+      return false;
+    }
+    if (Metrics.convergent) {
+      LLVM_DEBUG(dbgs() << "LoopRotation: NOT rotating - contains convergent "
+                           "instructions: ";
+                 L->dump());
+      return false;
+    }
+    if (Metrics.NumInsts > MaxHeaderSize)
+      return false;
+  }
+
+  // Now, this loop is suitable for rotation.
+  BasicBlock *OrigPreheader = L->getLoopPreheader();
+
+  // If the loop could not be converted to canonical form, it must have an
+  // indirectbr in it, just give up.
+  if (!OrigPreheader || !L->hasDedicatedExits())
+    return false;
+
+  // Anything ScalarEvolution may know about this loop or the PHI nodes
+  // in its header will soon be invalidated. We should also invalidate
+  // all outer loops because insertion and deletion of blocks that happens
+  // during the rotation may violate invariants related to backedge taken
+  // infos in them.
+  if (SE)
+    SE->forgetTopmostLoop(L);
+
+  LLVM_DEBUG(dbgs() << "LoopRotation: rotating "; L->dump());
+
+  // Find the new loop header. NewHeader is the Header's one and only
+  // successor that is inside the loop. The Header's other successor is
+  // outside the loop. Otherwise the loop is not suitable for rotation.
+  BasicBlock *Exit = BI->getSuccessor(0);
+  BasicBlock *NewHeader = BI->getSuccessor(1);
+  if (L->contains(Exit))
+    std::swap(Exit, NewHeader);
+  assert(NewHeader && "Unable to determine new loop header");
+  assert(L->contains(NewHeader) && !L->contains(Exit) &&
+         "Unable to determine loop header and exit blocks");
+
+  // This code assumes that the new header has exactly one predecessor.
+  // Remove any single-entry PHI nodes in it.
+  assert(NewHeader->getSinglePredecessor() &&
+         "New header doesn't have one pred!");
+  FoldSingleEntryPHINodes(NewHeader);
+
+  // Begin by walking OrigHeader and populating ValueMap with an entry for
+  // each Instruction.
+  BasicBlock::iterator I = OrigHeader->begin(), E = OrigHeader->end();
+  ValueToValueMapTy ValueMap;
+
+  // For PHI nodes, the value available in OldPreHeader is just the
+  // incoming value from OldPreHeader.
+  for (; PHINode *PN = dyn_cast<PHINode>(I); ++I)
+    ValueMap[PN] = PN->getIncomingValueForBlock(OrigPreheader);
+
+  // For the rest of the instructions, either hoist to the OrigPreheader if
+  // possible or create a clone in the OldPreHeader if not.
+  TerminatorInst *LoopEntryBranch = OrigPreheader->getTerminator();
+
+  // Record all debug intrinsics preceding LoopEntryBranch to avoid duplication.
+  using DbgIntrinsicHash =
+      std::pair<std::pair<Value *, DILocalVariable *>, DIExpression *>;
+  auto makeHash = [](DbgInfoIntrinsic *D) -> DbgIntrinsicHash {
+    return {{D->getVariableLocation(), D->getVariable()}, D->getExpression()};
+  };
+  SmallDenseSet<DbgIntrinsicHash, 8> DbgIntrinsics;
+  for (auto I = std::next(OrigPreheader->rbegin()), E = OrigPreheader->rend();
+       I != E; ++I) {
+    if (auto *DII = dyn_cast<DbgInfoIntrinsic>(&*I))
+      DbgIntrinsics.insert(makeHash(DII));
+    else
+      break;
+  }
+
+  while (I != E) {
+    Instruction *Inst = &*I++;
+
+    // If the instruction's operands are invariant and it doesn't read or write
+    // memory, then it is safe to hoist. Doing this doesn't change the order of
+    // execution in the preheader, but does prevent the instruction from
+    // executing in each iteration of the loop. This means it is safe to hoist
+    // something that might trap, but isn't safe to hoist something that reads
+    // memory (without proving that the loop doesn't write).
+    if (L->hasLoopInvariantOperands(Inst) && !Inst->mayReadFromMemory() &&
+        !Inst->mayWriteToMemory() && !isa<TerminatorInst>(Inst) &&
+        !isa<DbgInfoIntrinsic>(Inst) && !isa<AllocaInst>(Inst)) {
+      Inst->moveBefore(LoopEntryBranch);
+      continue;
+    }
+
+    // Otherwise, create a duplicate of the instruction.
+    Instruction *C = Inst->clone();
+
+    // Eagerly remap the operands of the instruction.
+    RemapInstruction(C, ValueMap,
+                     RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+
+    // Avoid inserting the same intrinsic twice.
+    if (auto *DII = dyn_cast<DbgInfoIntrinsic>(C))
+      if (DbgIntrinsics.count(makeHash(DII))) {
+        C->deleteValue();
+        continue;
+      }
+
+    // With the operands remapped, see if the instruction constant folds or is
+    // otherwise simplifiable. This commonly occurs because the entry from PHI
+    // nodes allows icmps and other instructions to fold.
+    Value *V = SimplifyInstruction(C, SQ);
+    if (V && LI->replacementPreservesLCSSAForm(C, V)) {
+      // If so, then delete the temporary instruction and stick the folded value
+      // in the map.
+      ValueMap[Inst] = V;
+      if (!C->mayHaveSideEffects()) {
+        C->deleteValue();
+        C = nullptr;
+      }
+    } else {
+      ValueMap[Inst] = C;
+    }
+    if (C) {
+      // Otherwise, stick the new instruction into the new block!
+      C->setName(Inst->getName());
+      C->insertBefore(LoopEntryBranch);
+
+      if (auto *II = dyn_cast<IntrinsicInst>(C))
+        if (II->getIntrinsicID() == Intrinsic::assume)
+          AC->registerAssumption(II);
+    }
+  }
+
+  // Along with all the other instructions, we just cloned OrigHeader's
+  // terminator into OrigPreHeader. Fix up the PHI nodes in each of OrigHeader's
+  // successors by duplicating their incoming values for OrigHeader.
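The DbgIntrinsicHash used above is plain keyed deduplication with a composite key. The same idea in standalone C++, with strings standing in for the value/variable/expression pointers (purely illustrative):

    #include <cassert>
    #include <set>
    #include <string>
    #include <utility>

    using DbgKey = std::pair<std::pair<std::string, std::string>, std::string>;

    int main() {
      std::set<DbgKey> Seen;
      // A dbg.value for (%i, !var_i, !expr) already sits before the branch.
      Seen.insert({{"%i", "i"}, "()"});
      // Its clone hashes identically, so the insert fails and it is skipped.
      assert(!Seen.insert({{"%i", "i"}, "()"}).second);
      // A different variable still gets inserted.
      assert(Seen.insert({{"%j", "j"}, "()"}).second);
      return 0;
    }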
+  TerminatorInst *TI = OrigHeader->getTerminator();
+  for (BasicBlock *SuccBB : TI->successors())
+    for (BasicBlock::iterator BI = SuccBB->begin();
+         PHINode *PN = dyn_cast<PHINode>(BI); ++BI)
+      PN->addIncoming(PN->getIncomingValueForBlock(OrigHeader), OrigPreheader);
+
+  // Now that OrigPreHeader has a clone of OrigHeader's terminator, remove
+  // OrigPreHeader's old terminator (the original branch into the loop), and
+  // remove the corresponding incoming values from the PHI nodes in OrigHeader.
+  LoopEntryBranch->eraseFromParent();
+
+
+  SmallVector<PHINode*, 2> InsertedPHIs;
+  // If there were any uses of instructions in the duplicated block outside the
+  // loop, update them, inserting PHI nodes as required.
+  RewriteUsesOfClonedInstructions(OrigHeader, OrigPreheader, ValueMap,
+                                  &InsertedPHIs);
+
+  // Attach dbg.value intrinsics to the new phis if that phi uses a value that
+  // previously had debug metadata attached. This keeps the debug info
+  // up-to-date in the loop body.
+  if (!InsertedPHIs.empty())
+    insertDebugValuesForPHIs(OrigHeader, InsertedPHIs);
+
+  // NewHeader is now the header of the loop.
+  L->moveToHeader(NewHeader);
+  assert(L->getHeader() == NewHeader && "Latch block is our new header");
+
+  // Inform DT about changes to the CFG.
+  if (DT) {
+    // The OrigPreheader branches to the NewHeader and Exit now. Then, inform
+    // the DT about the edge to the OrigHeader that got removed.
+    SmallVector<DominatorTree::UpdateType, 3> Updates;
+    Updates.push_back({DominatorTree::Insert, OrigPreheader, Exit});
+    Updates.push_back({DominatorTree::Insert, OrigPreheader, NewHeader});
+    Updates.push_back({DominatorTree::Delete, OrigPreheader, OrigHeader});
+    DT->applyUpdates(Updates);
+  }
+
+  // At this point, we've finished our major CFG changes. As part of cloning
+  // the loop into the preheader we've simplified instructions and the
+  // duplicated conditional branch may now be branching on a constant. If it is
+  // branching on a constant and if that constant means that we enter the loop,
+  // then we fold away the cond branch to an uncond branch. This simplifies the
+  // loop in cases important for nested loops, and it also means we don't have
+  // to split as many edges.
+  BranchInst *PHBI = cast<BranchInst>(OrigPreheader->getTerminator());
+  assert(PHBI->isConditional() && "Should be clone of BI condbr!");
+  if (!isa<ConstantInt>(PHBI->getCondition()) ||
+      PHBI->getSuccessor(cast<ConstantInt>(PHBI->getCondition())->isZero()) !=
+          NewHeader) {
+    // The conditional branch can't be folded, handle the general case.
+    // Split edges as necessary to preserve LoopSimplify form.
+
+    // Right now OrigPreHeader has two successors, NewHeader and ExitBlock, and
+    // thus is not a preheader anymore.
+    // Split the edge to form a real preheader.
+    BasicBlock *NewPH = SplitCriticalEdge(
+        OrigPreheader, NewHeader,
+        CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA());
+    NewPH->setName(NewHeader->getName() + ".lr.ph");
+
+    // Preserve canonical loop form, which means that 'Exit' should have only
+    // one predecessor. Note that Exit could be an exit block for multiple
+    // nested loops, causing both of the edges to now be critical and need to
+    // be split.
+    SmallVector<BasicBlock *, 4> ExitPreds(pred_begin(Exit), pred_end(Exit));
+    bool SplitLatchEdge = false;
+    for (BasicBlock *ExitPred : ExitPreds) {
+      // We only need to split loop exit edges.
+ Loop *PredLoop = LI->getLoopFor(ExitPred); + if (!PredLoop || PredLoop->contains(Exit)) + continue; + if (isa<IndirectBrInst>(ExitPred->getTerminator())) + continue; + SplitLatchEdge |= L->getLoopLatch() == ExitPred; + BasicBlock *ExitSplit = SplitCriticalEdge( + ExitPred, Exit, + CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA()); + ExitSplit->moveBefore(Exit); + } + assert(SplitLatchEdge && + "Despite splitting all preds, failed to split latch exit?"); + } else { + // We can fold the conditional branch in the preheader, this makes things + // simpler. The first step is to remove the extra edge to the Exit block. + Exit->removePredecessor(OrigPreheader, true /*preserve LCSSA*/); + BranchInst *NewBI = BranchInst::Create(NewHeader, PHBI); + NewBI->setDebugLoc(PHBI->getDebugLoc()); + PHBI->eraseFromParent(); + + // With our CFG finalized, update DomTree if it is available. + if (DT) DT->deleteEdge(OrigPreheader, Exit); + } + + assert(L->getLoopPreheader() && "Invalid loop preheader after loop rotation"); + assert(L->getLoopLatch() && "Invalid loop latch after loop rotation"); + + // Now that the CFG and DomTree are in a consistent state again, try to merge + // the OrigHeader block into OrigLatch. This will succeed if they are + // connected by an unconditional branch. This is just a cleanup so the + // emitted code isn't too gross in this common case. + MergeBlockIntoPredecessor(OrigHeader, DT, LI); + + LLVM_DEBUG(dbgs() << "LoopRotation: into "; L->dump()); + + ++NumRotated; + return true; +} + +/// Determine whether the instructions in this range may be safely and cheaply +/// speculated. This is not an important enough situation to develop complex +/// heuristics. We handle a single arithmetic instruction along with any type +/// conversions. +static bool shouldSpeculateInstrs(BasicBlock::iterator Begin, + BasicBlock::iterator End, Loop *L) { + bool seenIncrement = false; + bool MultiExitLoop = false; + + if (!L->getExitingBlock()) + MultiExitLoop = true; + + for (BasicBlock::iterator I = Begin; I != End; ++I) { + + if (!isSafeToSpeculativelyExecute(&*I)) + return false; + + if (isa<DbgInfoIntrinsic>(I)) + continue; + + switch (I->getOpcode()) { + default: + return false; + case Instruction::GetElementPtr: + // GEPs are cheap if all indices are constant. + if (!cast<GEPOperator>(I)->hasAllConstantIndices()) + return false; + // fall-thru to increment case + LLVM_FALLTHROUGH; + case Instruction::Add: + case Instruction::Sub: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: { + Value *IVOpnd = + !isa<Constant>(I->getOperand(0)) + ? I->getOperand(0) + : !isa<Constant>(I->getOperand(1)) ? I->getOperand(1) : nullptr; + if (!IVOpnd) + return false; + + // If increment operand is used outside of the loop, this speculation + // could cause extra live range interference. + if (MultiExitLoop) { + for (User *UseI : IVOpnd->users()) { + auto *UserInst = cast<Instruction>(UseI); + if (!L->contains(UserInst)) + return false; + } + } + + if (seenIncrement) + return false; + seenIncrement = true; + break; + } + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + // ignore type conversions + break; + } + } + return true; +} + +/// Fold the loop tail into the loop exit by speculating the loop tail +/// instructions. Typically, this is a single post-increment. 
In the case of a +/// simple 2-block loop, hoisting the increment can be much better than +/// duplicating the entire loop header. In the case of loops with early exits, +/// rotation will not work anyway, but simplifyLoopLatch will put the loop in +/// canonical form so downstream passes can handle it. +/// +/// I don't believe this invalidates SCEV. +bool LoopRotate::simplifyLoopLatch(Loop *L) { + BasicBlock *Latch = L->getLoopLatch(); + if (!Latch || Latch->hasAddressTaken()) + return false; + + BranchInst *Jmp = dyn_cast<BranchInst>(Latch->getTerminator()); + if (!Jmp || !Jmp->isUnconditional()) + return false; + + BasicBlock *LastExit = Latch->getSinglePredecessor(); + if (!LastExit || !L->isLoopExiting(LastExit)) + return false; + + BranchInst *BI = dyn_cast<BranchInst>(LastExit->getTerminator()); + if (!BI) + return false; + + if (!shouldSpeculateInstrs(Latch->begin(), Jmp->getIterator(), L)) + return false; + + LLVM_DEBUG(dbgs() << "Folding loop latch " << Latch->getName() << " into " + << LastExit->getName() << "\n"); + + // Hoist the instructions from Latch into LastExit. + LastExit->getInstList().splice(BI->getIterator(), Latch->getInstList(), + Latch->begin(), Jmp->getIterator()); + + unsigned FallThruPath = BI->getSuccessor(0) == Latch ? 0 : 1; + BasicBlock *Header = Jmp->getSuccessor(0); + assert(Header == L->getHeader() && "expected a backward branch"); + + // Remove Latch from the CFG so that LastExit becomes the new Latch. + BI->setSuccessor(FallThruPath, Header); + Latch->replaceSuccessorsPhiUsesWith(LastExit); + Jmp->eraseFromParent(); + + // Nuke the Latch block. + assert(Latch->empty() && "unable to evacuate Latch"); + LI->removeBlock(Latch); + if (DT) + DT->eraseNode(Latch); + Latch->eraseFromParent(); + return true; +} + +/// Rotate \c L, and return true if any modification was made. +bool LoopRotate::processLoop(Loop *L) { + // Save the loop metadata. + MDNode *LoopMD = L->getLoopID(); + + bool SimplifiedLatch = false; + + // Simplify the loop latch before attempting to rotate the header + // upward. Rotation may not be needed if the loop tail can be folded into the + // loop exit. + if (!RotationOnly) + SimplifiedLatch = simplifyLoopLatch(L); + + bool MadeChange = rotateLoop(L, SimplifiedLatch); + assert((!MadeChange || L->isLoopExiting(L->getLoopLatch())) && + "Loop latch should be exiting after loop-rotate."); + + // Restore the loop metadata. + // NB! We presume LoopRotation DOESN'T ADD its own metadata. + if ((MadeChange || SimplifiedLatch) && LoopMD) + L->setLoopID(LoopMD); + + return MadeChange || SimplifiedLatch; +} + + +/// The utility to convert a loop into a loop with bottom test. 
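Concretely, the "bottom test" shape looks like this at the source level. This is an illustrative sketch only; LoopRotate works on IR, and the function names here are invented:

    #include <cassert>

    // Top-tested loop: the header both tests and begins the body. This is the
    // shape rotateLoop starts from.
    static int sumBefore(int N) {
      int S = 0, I = 0;
      while (I < N) { S += I; ++I; }
      return S;
    }

    // Rotated form: a copy of the header's test guards entry (the code cloned
    // into the preheader), and the loop now exits from its latch.
    static int sumAfter(int N) {
      int S = 0, I = 0;
      if (I < N) {                         // cloned header test
        do { S += I; ++I; } while (I < N); // exit test moved to the latch
      }
      return S;
    }

    int main() {
      for (int N = -2; N < 8; ++N)
        assert(sumBefore(N) == sumAfter(N));
      return 0;
    }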
+bool llvm::LoopRotation(Loop *L, LoopInfo *LI, const TargetTransformInfo *TTI, + AssumptionCache *AC, DominatorTree *DT, + ScalarEvolution *SE, const SimplifyQuery &SQ, + bool RotationOnly = true, + unsigned Threshold = unsigned(-1), + bool IsUtilMode = true) { + LoopRotate LR(Threshold, LI, TTI, AC, DT, SE, SQ, RotationOnly, IsUtilMode); + + return LR.processLoop(L); +} diff --git a/lib/Transforms/Utils/LoopSimplify.cpp b/lib/Transforms/Utils/LoopSimplify.cpp index f43af9772771..970494eb4704 100644 --- a/lib/Transforms/Utils/LoopSimplify.cpp +++ b/lib/Transforms/Utils/LoopSimplify.cpp @@ -52,6 +52,7 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -64,9 +65,8 @@ #include "llvm/IR/Type.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; @@ -141,8 +141,8 @@ BasicBlock *llvm::InsertPreheaderForLoop(Loop *L, DominatorTree *DT, if (!PreheaderBB) return nullptr; - DEBUG(dbgs() << "LoopSimplify: Creating pre-header " - << PreheaderBB->getName() << "\n"); + LLVM_DEBUG(dbgs() << "LoopSimplify: Creating pre-header " + << PreheaderBB->getName() << "\n"); // Make sure that NewBB is put someplace intelligent, which doesn't mess up // code layout too horribly. @@ -170,7 +170,7 @@ static void addBlockAndPredsToSet(BasicBlock *InputBB, BasicBlock *StopBlock, } while (!Worklist.empty()); } -/// \brief The first part of loop-nestification is to find a PHI node that tells +/// The first part of loop-nestification is to find a PHI node that tells /// us how to partition the loops. static PHINode *findPHIToPartitionLoops(Loop *L, DominatorTree *DT, AssumptionCache *AC) { @@ -195,7 +195,7 @@ static PHINode *findPHIToPartitionLoops(Loop *L, DominatorTree *DT, return nullptr; } -/// \brief If this loop has multiple backedges, try to pull one of them out into +/// If this loop has multiple backedges, try to pull one of them out into /// a nested loop. /// /// This is important for code that looks like @@ -242,7 +242,7 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, OuterLoopPreds.push_back(PN->getIncomingBlock(i)); } } - DEBUG(dbgs() << "LoopSimplify: Splitting out a new outer loop\n"); + LLVM_DEBUG(dbgs() << "LoopSimplify: Splitting out a new outer loop\n"); // If ScalarEvolution is around and knows anything about values in // this loop, tell it to forget them, because we're about to @@ -332,7 +332,7 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, return NewOuter; } -/// \brief This method is called when the specified loop has more than one +/// This method is called when the specified loop has more than one /// backedge in it. 
 ///
 /// If this occurs, revector all of these backedges to target a new basic block
@@ -371,8 +371,8 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader,
   BranchInst *BETerminator = BranchInst::Create(Header, BEBlock);
   BETerminator->setDebugLoc(Header->getFirstNonPHI()->getDebugLoc());
 
-  DEBUG(dbgs() << "LoopSimplify: Inserting unique backedge block "
-               << BEBlock->getName() << "\n");
+  LLVM_DEBUG(dbgs() << "LoopSimplify: Inserting unique backedge block "
+                    << BEBlock->getName() << "\n");
 
   // Move the new backedge block to right after the last backedge block.
   Function::iterator InsertPos = ++BackedgeBlocks.back()->getIterator();
@@ -457,7 +457,7 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader,
   return BEBlock;
 }
 
-/// \brief Simplify one loop and queue further loops for simplification.
+/// Simplify one loop and queue further loops for simplification.
 static bool simplifyOneLoop(Loop *L, SmallVectorImpl<Loop *> &Worklist,
                             DominatorTree *DT, LoopInfo *LI,
                             ScalarEvolution *SE, AssumptionCache *AC,
@@ -484,8 +484,8 @@ ReprocessLoop:
 
   // Delete each unique out-of-loop (and thus dead) predecessor.
   for (BasicBlock *P : BadPreds) {
 
-    DEBUG(dbgs() << "LoopSimplify: Deleting edge from dead predecessor "
-                 << P->getName() << "\n");
+    LLVM_DEBUG(dbgs() << "LoopSimplify: Deleting edge from dead predecessor "
+                      << P->getName() << "\n");
 
     // Zap the dead pred's terminator and replace it with unreachable.
     TerminatorInst *TI = P->getTerminator();
@@ -504,16 +504,13 @@ ReprocessLoop:
     if (BI->isConditional()) {
       if (UndefValue *Cond = dyn_cast<UndefValue>(BI->getCondition())) {
 
-        DEBUG(dbgs() << "LoopSimplify: Resolving \"br i1 undef\" to exit in "
-                     << ExitingBlock->getName() << "\n");
+        LLVM_DEBUG(dbgs()
+                   << "LoopSimplify: Resolving \"br i1 undef\" to exit in "
+                   << ExitingBlock->getName() << "\n");
 
         BI->setCondition(ConstantInt::get(Cond->getType(),
                                           !L->contains(BI->getSuccessor(0))));
 
-        // This may make the loop analyzable, force SCEV recomputation.
-        if (SE)
-          SE->forgetLoop(L);
-
         Changed = true;
       }
     }
@@ -617,11 +614,8 @@ ReprocessLoop:
   // comparison and the branch.
   bool AllInvariant = true;
   bool AnyInvariant = false;
-  for (BasicBlock::iterator I = ExitingBlock->begin(); &*I != BI; ) {
+  for (auto I = ExitingBlock->instructionsWithoutDebug().begin(); &*I != BI; ) {
     Instruction *Inst = &*I++;
-    // Skip debug info intrinsics.
-    if (isa<DbgInfoIntrinsic>(Inst))
-      continue;
     if (Inst == CI)
       continue;
     if (!L->makeLoopInvariant(Inst, AnyInvariant,
@@ -648,15 +642,8 @@ ReprocessLoop:
 
   // Success. The block is now dead, so remove it from the loop,
   // update the dominator tree and delete it.
-  DEBUG(dbgs() << "LoopSimplify: Eliminating exiting block "
-               << ExitingBlock->getName() << "\n");
-
-  // Notify ScalarEvolution before deleting this block. Currently assume the
-  // parent loop doesn't change (spliting edges doesn't count). If blocks,
-  // CFG edges, or other values in the parent loop change, then we need call
-  // to forgetLoop() for the parent instead.
-  if (SE)
-    SE->forgetLoop(L);
+  LLVM_DEBUG(dbgs() << "LoopSimplify: Eliminating exiting block "
+                    << ExitingBlock->getName() << "\n");
 
   assert(pred_begin(ExitingBlock) == pred_end(ExitingBlock));
   Changed = true;
@@ -679,6 +666,12 @@ ReprocessLoop:
     }
   }
 
+  // Changing exit conditions for blocks may affect exit counts of this loop and
+  // any of its parents, so we must invalidate the entire subtree if we've made
+  // any changes.
+ if (Changed && SE) + SE->forgetTopmostLoop(L); + return Changed; } diff --git a/lib/Transforms/Utils/LoopUnroll.cpp b/lib/Transforms/Utils/LoopUnroll.cpp index dc98a39adcc5..04b8c1417e0a 100644 --- a/lib/Transforms/Utils/LoopUnroll.cpp +++ b/lib/Transforms/Utils/LoopUnroll.cpp @@ -23,6 +23,7 @@ #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfoMetadata.h" @@ -33,7 +34,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SimplifyIndVar.h" @@ -63,8 +63,7 @@ UnrollVerifyDomtree("unroll-verify-domtree", cl::Hidden, /// Convert the instruction operands from referencing the current values into /// those specified by VMap. -static inline void remapInstruction(Instruction *I, - ValueToValueMapTy &VMap) { +void llvm::remapInstruction(Instruction *I, ValueToValueMapTy &VMap) { for (unsigned op = 0, E = I->getNumOperands(); op != E; ++op) { Value *Op = I->getOperand(op); @@ -97,16 +96,10 @@ static inline void remapInstruction(Instruction *I, /// Folds a basic block into its predecessor if it only has one predecessor, and /// that predecessor only has one successor. -/// The LoopInfo Analysis that is passed will be kept consistent. If folding is -/// successful references to the containing loop must be removed from -/// ScalarEvolution by calling ScalarEvolution::forgetLoop because SE may have -/// references to the eliminated BB. The argument ForgottenLoops contains a set -/// of loops that have already been forgotten to prevent redundant, expensive -/// calls to ScalarEvolution::forgetLoop. Returns the new combined block. -static BasicBlock * -foldBlockIntoPredecessor(BasicBlock *BB, LoopInfo *LI, ScalarEvolution *SE, - SmallPtrSetImpl<Loop *> &ForgottenLoops, - DominatorTree *DT) { +/// The LoopInfo Analysis that is passed will be kept consistent. +BasicBlock *llvm::foldBlockIntoPredecessor(BasicBlock *BB, LoopInfo *LI, + ScalarEvolution *SE, + DominatorTree *DT) { // Merge basic blocks into their predecessor if there is only one distinct // pred, and if there is only one distinct successor of the predecessor, and // if there are no PHI nodes. @@ -116,7 +109,8 @@ foldBlockIntoPredecessor(BasicBlock *BB, LoopInfo *LI, ScalarEvolution *SE, if (OnlyPred->getTerminator()->getNumSuccessors() != 1) return nullptr; - DEBUG(dbgs() << "Merging: " << *BB << "into: " << *OnlyPred); + LLVM_DEBUG(dbgs() << "Merging: " << BB->getName() << " into " + << OnlyPred->getName() << "\n"); // Resolve any PHI nodes at the start of the block. They are all // guaranteed to have exactly one entry if they exist, unless there are @@ -149,13 +143,6 @@ foldBlockIntoPredecessor(BasicBlock *BB, LoopInfo *LI, ScalarEvolution *SE, DT->eraseNode(BB); } - // ScalarEvolution holds references to loop exit blocks. - if (SE) { - if (Loop *L = LI->getLoopFor(BB)) { - if (ForgottenLoops.insert(L).second) - SE->forgetLoop(L); - } - } LI->removeBlock(BB); // Inherit predecessor's name if it exists... 
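The remapInstruction routine above is a pointer rewrite through a clone map; the RF_IgnoreMissingLocals behaviour seen elsewhere in this patch means operands with no mapping are left alone. A toy equivalent in standalone C++ (all types invented for illustration):

    #include <cassert>
    #include <unordered_map>
    #include <vector>

    // Toy instruction: just a list of operand pointers.
    struct Inst {
      std::vector<const Inst *> Ops;
    };

    static void remap(Inst &I,
                      const std::unordered_map<const Inst *, const Inst *> &VMap) {
      for (const Inst *&Op : I.Ops) {
        auto It = VMap.find(Op);
        if (It != VMap.end())
          Op = It->second; // operand now refers to the cloned value
      }
    }

    int main() {
      Inst A, AClone, B, User{{&A, &B}};
      std::unordered_map<const Inst *, const Inst *> VMap{{&A, &AClone}};
      remap(User, VMap);
      assert(User.Ops[0] == &AClone && User.Ops[1] == &B); // B had no mapping
      return 0;
    }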
@@ -258,16 +245,55 @@ static bool isEpilogProfitable(Loop *L) { BasicBlock *PreHeader = L->getLoopPreheader(); BasicBlock *Header = L->getHeader(); assert(PreHeader && Header); - for (Instruction &BBI : *Header) { - PHINode *PN = dyn_cast<PHINode>(&BBI); - if (!PN) - break; - if (isa<ConstantInt>(PN->getIncomingValueForBlock(PreHeader))) + for (const PHINode &PN : Header->phis()) { + if (isa<ConstantInt>(PN.getIncomingValueForBlock(PreHeader))) return true; } return false; } +/// Perform some cleanup and simplifications on loops after unrolling. It is +/// useful to simplify the IV's in the new loop, as well as do a quick +/// simplify/dce pass of the instructions. +void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI, + ScalarEvolution *SE, DominatorTree *DT, + AssumptionCache *AC) { + // Simplify any new induction variables in the partially unrolled loop. + if (SE && SimplifyIVs) { + SmallVector<WeakTrackingVH, 16> DeadInsts; + simplifyLoopIVs(L, SE, DT, LI, DeadInsts); + + // Aggressively clean up dead instructions that simplifyLoopIVs already + // identified. Any remaining should be cleaned up below. + while (!DeadInsts.empty()) + if (Instruction *Inst = + dyn_cast_or_null<Instruction>(&*DeadInsts.pop_back_val())) + RecursivelyDeleteTriviallyDeadInstructions(Inst); + } + + // At this point, the code is well formed. We now do a quick sweep over the + // inserted code, doing constant propagation and dead code elimination as we + // go. + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + const std::vector<BasicBlock *> &NewLoopBlocks = L->getBlocks(); + for (BasicBlock *BB : NewLoopBlocks) { + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { + Instruction *Inst = &*I++; + + if (Value *V = SimplifyInstruction(Inst, {DL, nullptr, DT, AC})) + if (LI->replacementPreservesLCSSAForm(Inst, V)) + Inst->replaceAllUsesWith(V); + if (isInstructionTriviallyDead(Inst)) + BB->getInstList().erase(Inst); + } + } + + // TODO: after peeling or unrolling, previously loop variant conditions are + // likely to fold to constants, eagerly propagating those here will require + // fewer cleanup passes to be run. Alternatively, a LoopEarlyCSE might be + // appropriate. +} + /// Unroll the given loop by Count. The loop must be in LCSSA form. Unrolling /// can only fail when the loop's latch block is not terminated by a conditional /// branch instruction. However, if the trip count (and multiple) are not known, @@ -313,19 +339,19 @@ LoopUnrollResult llvm::UnrollLoop( BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) { - DEBUG(dbgs() << " Can't unroll; loop preheader-insertion failed.\n"); + LLVM_DEBUG(dbgs() << " Can't unroll; loop preheader-insertion failed.\n"); return LoopUnrollResult::Unmodified; } BasicBlock *LatchBlock = L->getLoopLatch(); if (!LatchBlock) { - DEBUG(dbgs() << " Can't unroll; loop exit-block-insertion failed.\n"); + LLVM_DEBUG(dbgs() << " Can't unroll; loop exit-block-insertion failed.\n"); return LoopUnrollResult::Unmodified; } // Loops with indirectbr cannot be cloned. if (!L->isSafeToClone()) { - DEBUG(dbgs() << " Can't unroll; Loop body cannot be cloned.\n"); + LLVM_DEBUG(dbgs() << " Can't unroll; Loop body cannot be cloned.\n"); return LoopUnrollResult::Unmodified; } @@ -338,8 +364,9 @@ LoopUnrollResult llvm::UnrollLoop( if (!BI || BI->isUnconditional()) { // The loop-rotate pass can be helpful to avoid this in many cases. 
- DEBUG(dbgs() << - " Can't unroll; loop not terminated by a conditional branch.\n"); + LLVM_DEBUG( + dbgs() + << " Can't unroll; loop not terminated by a conditional branch.\n"); return LoopUnrollResult::Unmodified; } @@ -348,22 +375,22 @@ LoopUnrollResult llvm::UnrollLoop( }; if (!CheckSuccessors(0, 1) && !CheckSuccessors(1, 0)) { - DEBUG(dbgs() << "Can't unroll; only loops with one conditional latch" - " exiting the loop can be unrolled\n"); + LLVM_DEBUG(dbgs() << "Can't unroll; only loops with one conditional latch" + " exiting the loop can be unrolled\n"); return LoopUnrollResult::Unmodified; } if (Header->hasAddressTaken()) { // The loop-rotate pass can be helpful to avoid this in many cases. - DEBUG(dbgs() << - " Won't unroll loop: address of header block is taken.\n"); + LLVM_DEBUG( + dbgs() << " Won't unroll loop: address of header block is taken.\n"); return LoopUnrollResult::Unmodified; } if (TripCount != 0) - DEBUG(dbgs() << " Trip Count = " << TripCount << "\n"); + LLVM_DEBUG(dbgs() << " Trip Count = " << TripCount << "\n"); if (TripMultiple != 1) - DEBUG(dbgs() << " Trip Multiple = " << TripMultiple << "\n"); + LLVM_DEBUG(dbgs() << " Trip Multiple = " << TripMultiple << "\n"); // Effectively "DCE" unrolled iterations that are beyond the tripcount // and will never be executed. @@ -372,7 +399,7 @@ LoopUnrollResult llvm::UnrollLoop( // Don't enter the unroll code if there is nothing to do. if (TripCount == 0 && Count < 2 && PeelCount == 0) { - DEBUG(dbgs() << "Won't unroll; almost nothing to do\n"); + LLVM_DEBUG(dbgs() << "Won't unroll; almost nothing to do\n"); return LoopUnrollResult::Unmodified; } @@ -406,8 +433,9 @@ LoopUnrollResult llvm::UnrollLoop( "Did not expect runtime trip-count unrolling " "and peeling for the same loop"); + bool Peeled = false; if (PeelCount) { - bool Peeled = peelLoop(L, PeelCount, LI, SE, DT, AC, PreserveLCSSA); + Peeled = peelLoop(L, PeelCount, LI, SE, DT, AC, PreserveLCSSA); // Successful peeling may result in a change in the loop preheader/trip // counts. If we later unroll the loop, we want these to be updated. @@ -422,7 +450,7 @@ LoopUnrollResult llvm::UnrollLoop( // Loops containing convergent instructions must have a count that divides // their TripMultiple. - DEBUG( + LLVM_DEBUG( { bool HasConvergent = false; for (auto &BB : L->blocks()) @@ -445,18 +473,12 @@ LoopUnrollResult llvm::UnrollLoop( if (Force) RuntimeTripCount = false; else { - DEBUG( - dbgs() << "Wont unroll; remainder loop could not be generated" - "when assuming runtime trip count\n"); + LLVM_DEBUG(dbgs() << "Won't unroll; remainder loop could not be " + "generated when assuming runtime trip count\n"); return LoopUnrollResult::Unmodified; } } - // Notify ScalarEvolution that the loop will be substantially changed, - // if not outright eliminated. - if (SE) - SE->forgetLoop(L); - // If we know the trip count, we know the multiple... unsigned BreakoutTrip = 0; if (TripCount != 0) { @@ -471,8 +493,8 @@ LoopUnrollResult llvm::UnrollLoop( using namespace ore; // Report the unrolling decision. 
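At the source level, partial unrolling under a runtime trip count produces the familiar main-loop-plus-epilogue shape; the remainder loop discussed above plays the part of the tail. An illustrative C++ sketch (not generated code):

    #include <cassert>

    static long sumUnrolled(const int *A, int N) {
      long S = 0;
      int I = 0;
      for (; I + 1 < N; I += 2) { // main loop: two iterations per trip
        S += A[I];
        S += A[I + 1];
      }
      for (; I < N; ++I) // epilogue: picks up the runtime remainder
        S += A[I];
      return S;
    }

    int main() {
      int A[5] = {1, 2, 3, 4, 5}; // odd trip count exercises the remainder
      assert(sumUnrolled(A, 5) == 15);
      return 0;
    }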
if (CompletelyUnroll) { - DEBUG(dbgs() << "COMPLETELY UNROLLING loop %" << Header->getName() - << " with trip count " << TripCount << "!\n"); + LLVM_DEBUG(dbgs() << "COMPLETELY UNROLLING loop %" << Header->getName() + << " with trip count " << TripCount << "!\n"); if (ORE) ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(), @@ -481,8 +503,8 @@ LoopUnrollResult llvm::UnrollLoop( << NV("UnrollCount", TripCount) << " iterations"; }); } else if (PeelCount) { - DEBUG(dbgs() << "PEELING loop %" << Header->getName() - << " with iteration count " << PeelCount << "!\n"); + LLVM_DEBUG(dbgs() << "PEELING loop %" << Header->getName() + << " with iteration count " << PeelCount << "!\n"); if (ORE) ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "Peeled", L->getStartLoc(), @@ -498,31 +520,42 @@ LoopUnrollResult llvm::UnrollLoop( << NV("UnrollCount", Count); }; - DEBUG(dbgs() << "UNROLLING loop %" << Header->getName() - << " by " << Count); + LLVM_DEBUG(dbgs() << "UNROLLING loop %" << Header->getName() << " by " + << Count); if (TripMultiple == 0 || BreakoutTrip != TripMultiple) { - DEBUG(dbgs() << " with a breakout at trip " << BreakoutTrip); + LLVM_DEBUG(dbgs() << " with a breakout at trip " << BreakoutTrip); if (ORE) ORE->emit([&]() { return DiagBuilder() << " with a breakout at trip " << NV("BreakoutTrip", BreakoutTrip); }); } else if (TripMultiple != 1) { - DEBUG(dbgs() << " with " << TripMultiple << " trips per branch"); + LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch"); if (ORE) ORE->emit([&]() { return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple) << " trips per branch"; }); } else if (RuntimeTripCount) { - DEBUG(dbgs() << " with run-time trip count"); + LLVM_DEBUG(dbgs() << " with run-time trip count"); if (ORE) ORE->emit( [&]() { return DiagBuilder() << " with run-time trip count"; }); } - DEBUG(dbgs() << "!\n"); + LLVM_DEBUG(dbgs() << "!\n"); } + // We are going to make changes to this loop. SCEV may be keeping cached info + // about it, in particular about backedge taken count. The changes we make + // are guaranteed to invalidate this information for our loop. It is tempting + // to only invalidate the loop being unrolled, but it is incorrect as long as + // all exiting branches from all inner loops have impact on the outer loops, + // and if something changes inside them then any of outer loops may also + // change. When we forget outermost loop, we also forget all contained loops + // and this is what we need here. + if (SE) + SE->forgetTopmostLoop(L); + bool ContinueOnTrue = L->contains(BI->getSuccessor(0)); BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue); @@ -580,14 +613,9 @@ LoopUnrollResult llvm::UnrollLoop( "Header should not be in a sub-loop"); // Tell LI about New. const Loop *OldLoop = addClonedBlockToLoopInfo(*BB, New, LI, NewLoops); - if (OldLoop) { + if (OldLoop) LoopsToSimplify.insert(NewLoops[OldLoop]); - // Forget the old loop, since its inputs may have changed. - if (SE) - SE->forgetLoop(OldLoop); - } - if (*BB == Header) // Loop over all of the PHI nodes in the block, changing them to use // the incoming values from the previous block. 
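The ORE->emit([&]() { ... }) calls above pass a callable so the remark is only constructed when remarks are actually enabled. The same lazy pattern in standalone form (emitIf is an invented stand-in, not an LLVM API):

    #include <cstdio>
    #include <functional>
    #include <string>

    static void emitIf(bool Enabled, const std::function<std::string()> &Make) {
      if (Enabled)                            // the lambda runs only here, so
        std::printf("%s\n", Make().c_str()); // formatting costs nothing otherwise
    }

    int main() {
      int Count = 4;
      emitIf(false, [&] { return "unrolled by " + std::to_string(Count); });
      emitIf(true, [&] { return "unrolled by " + std::to_string(Count); });
      return 0;
    }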
@@ -611,13 +639,12 @@ LoopUnrollResult llvm::UnrollLoop( for (BasicBlock *Succ : successors(*BB)) { if (L->contains(Succ)) continue; - for (BasicBlock::iterator BBI = Succ->begin(); - PHINode *phi = dyn_cast<PHINode>(BBI); ++BBI) { - Value *Incoming = phi->getIncomingValueForBlock(*BB); + for (PHINode &PHI : Succ->phis()) { + Value *Incoming = PHI.getIncomingValueForBlock(*BB); ValueToValueMapTy::iterator It = LastValueMap.find(Incoming); if (It != LastValueMap.end()) Incoming = It->second; - phi->addIncoming(Incoming, New); + PHI.addIncoming(Incoming, New); } } // Keep track of new headers and latches as we create them, so that @@ -721,10 +748,8 @@ LoopUnrollResult llvm::UnrollLoop( for (BasicBlock *Succ: successors(BB)) { if (Succ == Headers[i]) continue; - for (BasicBlock::iterator BBI = Succ->begin(); - PHINode *Phi = dyn_cast<PHINode>(BBI); ++BBI) { - Phi->removeIncomingValue(BB, false); - } + for (PHINode &Phi : Succ->phis()) + Phi.removeIncomingValue(BB, false); } } // Replace the conditional branch with an unconditional one. @@ -775,17 +800,15 @@ LoopUnrollResult llvm::UnrollLoop( } } - if (DT && UnrollVerifyDomtree) - DT->verifyDomTree(); + assert(!DT || !UnrollVerifyDomtree || + DT->verify(DominatorTree::VerificationLevel::Fast)); // Merge adjacent basic blocks, if possible. - SmallPtrSet<Loop *, 4> ForgottenLoops; for (BasicBlock *Latch : Latches) { BranchInst *Term = cast<BranchInst>(Latch->getTerminator()); if (Term->isUnconditional()) { BasicBlock *Dest = Term->getSuccessor(0); - if (BasicBlock *Fold = - foldBlockIntoPredecessor(Dest, LI, SE, ForgottenLoops, DT)) { + if (BasicBlock *Fold = foldBlockIntoPredecessor(Dest, LI, SE, DT)) { // Dest has been folded into Fold. Update our worklists accordingly. std::replace(Latches.begin(), Latches.end(), Dest, Fold); UnrolledLoopBlocks.erase(std::remove(UnrolledLoopBlocks.begin(), @@ -795,40 +818,10 @@ LoopUnrollResult llvm::UnrollLoop( } } - // Simplify any new induction variables in the partially unrolled loop. - if (SE && !CompletelyUnroll && Count > 1) { - SmallVector<WeakTrackingVH, 16> DeadInsts; - simplifyLoopIVs(L, SE, DT, LI, DeadInsts); - - // Aggressively clean up dead instructions that simplifyLoopIVs already - // identified. Any remaining should be cleaned up below. - while (!DeadInsts.empty()) - if (Instruction *Inst = - dyn_cast_or_null<Instruction>(&*DeadInsts.pop_back_val())) - RecursivelyDeleteTriviallyDeadInstructions(Inst); - } - - // At this point, the code is well formed. We now do a quick sweep over the - // inserted code, doing constant propagation and dead code elimination as we - // go. - const DataLayout &DL = Header->getModule()->getDataLayout(); - const std::vector<BasicBlock*> &NewLoopBlocks = L->getBlocks(); - for (BasicBlock *BB : NewLoopBlocks) { - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ) { - Instruction *Inst = &*I++; - - if (Value *V = SimplifyInstruction(Inst, {DL, nullptr, DT, AC})) - if (LI->replacementPreservesLCSSAForm(Inst, V)) - Inst->replaceAllUsesWith(V); - if (isInstructionTriviallyDead(Inst)) - BB->getInstList().erase(Inst); - } - } - - // TODO: after peeling or unrolling, previously loop variant conditions are - // likely to fold to constants, eagerly propagating those here will require - // fewer cleanup passes to be run. Alternatively, a LoopEarlyCSE might be - // appropriate. + // At this point, the code is well formed. We now simplify the unrolled loop, + // doing constant propagation and dead code elimination as we go. 
+ simplifyLoopAfterUnroll(L, !CompletelyUnroll && (Count > 1 || Peeled), LI, SE, + DT, AC); NumCompletelyUnrolled += CompletelyUnroll; ++NumUnrolled; diff --git a/lib/Transforms/Utils/LoopUnrollAndJam.cpp b/lib/Transforms/Utils/LoopUnrollAndJam.cpp new file mode 100644 index 000000000000..b919f73c3817 --- /dev/null +++ b/lib/Transforms/Utils/LoopUnrollAndJam.cpp @@ -0,0 +1,785 @@ +//===-- LoopUnrollAndJam.cpp - Loop unrolling utilities -------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements loop unroll and jam as a routine, much like +// LoopUnroll.cpp implements loop unroll. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/DependenceAnalysis.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/Utils/Local.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/LoopSimplify.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/SimplifyIndVar.h" +#include "llvm/Transforms/Utils/UnrollLoop.h" +using namespace llvm; + +#define DEBUG_TYPE "loop-unroll-and-jam" + +STATISTIC(NumUnrolledAndJammed, "Number of loops unroll and jammed"); +STATISTIC(NumCompletelyUnrolledAndJammed, "Number of loops unroll and jammed"); + +typedef SmallPtrSet<BasicBlock *, 4> BasicBlockSet; + +// Partition blocks in an outer/inner loop pair into blocks before and after +// the loop +static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop, + BasicBlockSet &ForeBlocks, + BasicBlockSet &SubLoopBlocks, + BasicBlockSet &AftBlocks, + DominatorTree *DT) { + BasicBlock *SubLoopLatch = SubLoop->getLoopLatch(); + SubLoopBlocks.insert(SubLoop->block_begin(), SubLoop->block_end()); + + for (BasicBlock *BB : L->blocks()) { + if (!SubLoop->contains(BB)) { + if (DT->dominates(SubLoopLatch, BB)) + AftBlocks.insert(BB); + else + ForeBlocks.insert(BB); + } + } + + // Check that all blocks in ForeBlocks together dominate the subloop + // TODO: This might ideally be done better with a dominator/postdominators. + BasicBlock *SubLoopPreHeader = SubLoop->getLoopPreheader(); + for (BasicBlock *BB : ForeBlocks) { + if (BB == SubLoopPreHeader) + continue; + TerminatorInst *TI = BB->getTerminator(); + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + if (!ForeBlocks.count(TI->getSuccessor(i))) + return false; + } + + return true; +} + +// Looks at the phi nodes in Header for values coming from Latch. For these +// instructions and all their operands calls Visit on them, keeping going for +// all the operands in AftBlocks. 
Returns false if Visit returns false,
+// otherwise returns true. This is used to process the instructions in the
+// Aft blocks that need to be moved before the subloop. It is used in two
+// places. One to check that the required set of instructions can be moved
+// before the loop. Then to collect the instructions to actually move in
+// moveHeaderPhiOperandsToForeBlocks.
+template <typename T>
+static bool processHeaderPhiOperands(BasicBlock *Header, BasicBlock *Latch,
+                                     BasicBlockSet &AftBlocks, T Visit) {
+  SmallVector<Instruction *, 8> Worklist;
+  for (auto &Phi : Header->phis()) {
+    Value *V = Phi.getIncomingValueForBlock(Latch);
+    if (Instruction *I = dyn_cast<Instruction>(V))
+      Worklist.push_back(I);
+  }
+
+  while (!Worklist.empty()) {
+    Instruction *I = Worklist.back();
+    Worklist.pop_back();
+    if (!Visit(I))
+      return false;
+
+    if (AftBlocks.count(I->getParent()))
+      for (auto &U : I->operands())
+        if (Instruction *II = dyn_cast<Instruction>(U))
+          Worklist.push_back(II);
+  }
+
+  return true;
+}
+
+// Move the phi operands of Header from Latch out of AftBlocks to InsertLoc.
+static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header,
+                                              BasicBlock *Latch,
+                                              Instruction *InsertLoc,
+                                              BasicBlockSet &AftBlocks) {
+  // We need to ensure we move the instructions in the correct order,
+  // starting with the earliest required instruction and moving forward.
+  std::vector<Instruction *> Visited;
+  processHeaderPhiOperands(Header, Latch, AftBlocks,
+                           [&Visited, &AftBlocks](Instruction *I) {
+                             if (AftBlocks.count(I->getParent()))
+                               Visited.push_back(I);
+                             return true;
+                           });
+
+  // Move all instructions in program order to before the InsertLoc
+  BasicBlock *InsertLocBB = InsertLoc->getParent();
+  for (Instruction *I : reverse(Visited)) {
+    if (I->getParent() != InsertLocBB)
+      I->moveBefore(InsertLoc);
+  }
+}
+
+/*
+  This method performs Unroll and Jam. For a simple loop like:
+  for (i = ..)
+    Fore(i)
+    for (j = ..)
+      SubLoop(i, j)
+    Aft(i)
+
+  Instead of doing normal inner or outer unrolling, we do:
+  for (i = .., i+=2)
+    Fore(i)
+    Fore(i+1)
+    for (j = ..)
+      SubLoop(i, j)
+      SubLoop(i+1, j)
+    Aft(i)
+    Aft(i+1)
+
+  So the outer loop is essentially unrolled and then the inner loops are fused
+  ("jammed") together into a single loop. This can increase speed when there
+  are loads in SubLoop that are invariant to i, as they become shared between
+  the now jammed inner loops.
+
+  We do this by splitting the blocks in the loop into Fore, Subloop and Aft.
+  Fore blocks are those before the inner loop, Aft are those after. Normal
+  Unroll code is used to copy each of these sets of blocks and the results are
+  combined together into the final form above.
+
+  isSafeToUnrollAndJam should be used prior to calling this to make sure the
+  unrolling will be valid. Checking profitability is also advisable.
+*/
+LoopUnrollResult
+llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount,
+                       unsigned TripMultiple, bool UnrollRemainder,
+                       LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,
+                       AssumptionCache *AC, OptimizationRemarkEmitter *ORE) {
+
+  // When we enter here we should have already checked that it is safe
+  BasicBlock *Header = L->getHeader();
+  assert(L->getSubLoops().size() == 1);
+  Loop *SubLoop = *L->begin();
+
+  // Don't enter the unroll code if there is nothing to do.
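The Fore/SubLoop/Aft picture in the comment above, written out as runnable C++ for the unroll-by-2 case (loop bounds and array shapes are arbitrary illustrations):

    #include <cassert>

    static void original(int R[4][8]) {
      for (int i = 0; i < 4; ++i) {
        int Acc = 0;                // Fore(i)
        for (int j = 0; j < 8; ++j) // SubLoop(i, j)
          Acc += i * j;
        R[i][0] = Acc;              // Aft(i)
      }
    }

    static void unrollAndJammed(int R[4][8]) {
      for (int i = 0; i < 4; i += 2) {
        int Acc0 = 0, Acc1 = 0;       // Fore(i), Fore(i+1)
        for (int j = 0; j < 8; ++j) { // jammed inner loop
          Acc0 += i * j;              // SubLoop(i, j)
          Acc1 += (i + 1) * j;        // SubLoop(i+1, j)
        }
        R[i][0] = Acc0;               // Aft(i)
        R[i + 1][0] = Acc1;           // Aft(i+1)
      }
    }

    int main() {
      int A[4][8] = {}, B[4][8] = {};
      original(A);
      unrollAndJammed(B);
      for (int i = 0; i < 4; ++i)
        assert(A[i][0] == B[i][0]);
      return 0;
    }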
+  if (TripCount == 0 && Count < 2) {
+    LLVM_DEBUG(dbgs() << "Won't unroll; almost nothing to do\n");
+    return LoopUnrollResult::Unmodified;
+  }
+
+  assert(Count > 0);
+  assert(TripMultiple > 0);
+  assert(TripCount == 0 || TripCount % TripMultiple == 0);
+
+  // Are we eliminating the loop control altogether?
+  bool CompletelyUnroll = (Count == TripCount);
+
+  // We use the runtime remainder in cases where we don't know the trip
+  // multiple
+  if (TripMultiple == 1 || TripMultiple % Count != 0) {
+    if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false,
+                                    /*UseEpilogRemainder*/ true,
+                                    UnrollRemainder, LI, SE, DT, AC, true)) {
+      LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be "
+                           "generated when assuming runtime trip count\n");
+      return LoopUnrollResult::Unmodified;
+    }
+  }
+
+  // Notify ScalarEvolution that the loop will be substantially changed,
+  // if not outright eliminated.
+  if (SE) {
+    SE->forgetLoop(L);
+    SE->forgetLoop(SubLoop);
+  }
+
+  using namespace ore;
+  // Report the unrolling decision.
+  if (CompletelyUnroll) {
+    LLVM_DEBUG(dbgs() << "COMPLETELY UNROLL AND JAMMING loop %"
+                      << Header->getName() << " with trip count " << TripCount
+                      << "!\n");
+    ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
+                                 L->getHeader())
+              << "completely unroll and jammed loop with "
+              << NV("UnrollCount", TripCount) << " iterations");
+  } else {
+    auto DiagBuilder = [&]() {
+      OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
+                              L->getHeader());
+      return Diag << "unroll and jammed loop by a factor of "
+                  << NV("UnrollCount", Count);
+    };
+
+    LLVM_DEBUG(dbgs() << "UNROLL AND JAMMING loop %" << Header->getName()
+                      << " by " << Count);
+    if (TripMultiple != 1) {
+      LLVM_DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
+      ORE->emit([&]() {
+        return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple)
+                             << " trips per branch";
+      });
+    } else {
+      LLVM_DEBUG(dbgs() << " with run-time trip count");
+      ORE->emit([&]() { return DiagBuilder() << " with run-time trip count"; });
+    }
+    LLVM_DEBUG(dbgs() << "!\n");
+  }
+
+  BasicBlock *Preheader = L->getLoopPreheader();
+  BasicBlock *LatchBlock = L->getLoopLatch();
+  BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
+  assert(Preheader && LatchBlock && Header);
+  assert(BI && !BI->isUnconditional());
+  bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
+  BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
+  bool SubLoopContinueOnTrue = SubLoop->contains(
+      SubLoop->getLoopLatch()->getTerminator()->getSuccessor(0));
+
+  // Partition blocks in an outer/inner loop pair into blocks before and after
+  // the loop
+  BasicBlockSet SubLoopBlocks;
+  BasicBlockSet ForeBlocks;
+  BasicBlockSet AftBlocks;
+  partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, AftBlocks,
+                           DT);
+
+  // We keep track of the entering/first and exiting/last block of each of
+  // Fore/SubLoop/Aft in each iteration. This helps make the stapling up of
+  // blocks easier.
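+  // For example (illustrative, assuming Count == 2): the blocks end up
+  // ordered as F(0) F(1) S(0) S(1) A(0) A(1), with ForeBlocksFirst holding
+  // the entry block of F(0) and F(1), ForeBlocksLast their exiting blocks,
+  // and likewise for the SubLoop and Aft vectors below.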
+  std::vector<BasicBlock *> ForeBlocksFirst;
+  std::vector<BasicBlock *> ForeBlocksLast;
+  std::vector<BasicBlock *> SubLoopBlocksFirst;
+  std::vector<BasicBlock *> SubLoopBlocksLast;
+  std::vector<BasicBlock *> AftBlocksFirst;
+  std::vector<BasicBlock *> AftBlocksLast;
+  ForeBlocksFirst.push_back(Header);
+  ForeBlocksLast.push_back(SubLoop->getLoopPreheader());
+  SubLoopBlocksFirst.push_back(SubLoop->getHeader());
+  SubLoopBlocksLast.push_back(SubLoop->getExitingBlock());
+  AftBlocksFirst.push_back(SubLoop->getExitBlock());
+  AftBlocksLast.push_back(L->getExitingBlock());
+  // Maps Blocks[0] -> Blocks[It]
+  ValueToValueMapTy LastValueMap;
+
+  // Move any instructions defining the fore phi operands from AftBlocks into
+  // the Fore blocks.
+  moveHeaderPhiOperandsToForeBlocks(
+      Header, LatchBlock, SubLoop->getLoopPreheader()->getTerminator(),
+      AftBlocks);
+
+  // The current on-the-fly SSA update requires blocks to be processed in
+  // reverse postorder so that LastValueMap contains the correct value at each
+  // exit.
+  LoopBlocksDFS DFS(L);
+  DFS.perform(LI);
+  // Stash the DFS iterators before adding blocks to the loop.
+  LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
+  LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();
+
+  if (Header->getParent()->isDebugInfoForProfiling())
+    for (BasicBlock *BB : L->getBlocks())
+      for (Instruction &I : *BB)
+        if (!isa<DbgInfoIntrinsic>(&I))
+          if (const DILocation *DIL = I.getDebugLoc())
+            I.setDebugLoc(DIL->cloneWithDuplicationFactor(Count));
+
+  // Copy all blocks
+  for (unsigned It = 1; It != Count; ++It) {
+    std::vector<BasicBlock *> NewBlocks;
+    // Maps Blocks[It] -> Blocks[It-1]
+    DenseMap<Value *, Value *> PrevItValueMap;
+
+    for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
+      ValueToValueMapTy VMap;
+      BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
+      Header->getParent()->getBasicBlockList().push_back(New);
+
+      if (ForeBlocks.count(*BB)) {
+        L->addBasicBlockToLoop(New, *LI);
+
+        if (*BB == ForeBlocksFirst[0])
+          ForeBlocksFirst.push_back(New);
+        if (*BB == ForeBlocksLast[0])
+          ForeBlocksLast.push_back(New);
+      } else if (SubLoopBlocks.count(*BB)) {
+        SubLoop->addBasicBlockToLoop(New, *LI);
+
+        if (*BB == SubLoopBlocksFirst[0])
+          SubLoopBlocksFirst.push_back(New);
+        if (*BB == SubLoopBlocksLast[0])
+          SubLoopBlocksLast.push_back(New);
+      } else if (AftBlocks.count(*BB)) {
+        L->addBasicBlockToLoop(New, *LI);
+
+        if (*BB == AftBlocksFirst[0])
+          AftBlocksFirst.push_back(New);
+        if (*BB == AftBlocksLast[0])
+          AftBlocksLast.push_back(New);
+      } else {
+        llvm_unreachable("BB being cloned should be in Fore/Sub/Aft");
+      }
+
+      // Update our running maps of newest clones
+      PrevItValueMap[New] = (It == 1 ? *BB : LastValueMap[*BB]);
+      LastValueMap[*BB] = New;
+      for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
+           VI != VE; ++VI) {
+        PrevItValueMap[VI->second] =
+            const_cast<Value *>(It == 1 ? VI->first : LastValueMap[VI->first]);
+        LastValueMap[VI->first] = VI->second;
+      }
+
+      NewBlocks.push_back(New);
+
+      // Update DomTree:
+      if (*BB == ForeBlocksFirst[0])
+        DT->addNewBlock(New, ForeBlocksLast[It - 1]);
+      else if (*BB == SubLoopBlocksFirst[0])
+        DT->addNewBlock(New, SubLoopBlocksLast[It - 1]);
+      else if (*BB == AftBlocksFirst[0])
+        DT->addNewBlock(New, AftBlocksLast[It - 1]);
+      else {
+        // Each set of blocks (Fore/Sub/Aft) will have the same internal domtree
+        // structure.
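+        // Illustrative example: if block B's immediate dominator was D in
+        // the original iteration, the clone B.1 is added under D.1, the
+        // clone of D recorded in LastValueMap.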
+        auto BBDomNode = DT->getNode(*BB);
+        auto BBIDom = BBDomNode->getIDom();
+        BasicBlock *OriginalBBIDom = BBIDom->getBlock();
+        assert(OriginalBBIDom);
+        assert(LastValueMap[cast<Value>(OriginalBBIDom)]);
+        DT->addNewBlock(
+            New, cast<BasicBlock>(LastValueMap[cast<Value>(OriginalBBIDom)]));
+      }
+    }
+
+    // Remap all instructions in the most recent iteration
+    for (BasicBlock *NewBlock : NewBlocks) {
+      for (Instruction &I : *NewBlock) {
+        ::remapInstruction(&I, LastValueMap);
+        if (auto *II = dyn_cast<IntrinsicInst>(&I))
+          if (II->getIntrinsicID() == Intrinsic::assume)
+            AC->registerAssumption(II);
+      }
+    }
+
+    // Alter the ForeBlocks phis, pointing them at the latest version of the
+    // value from the previous iteration's phis
+    for (PHINode &Phi : ForeBlocksFirst[It]->phis()) {
+      Value *OldValue = Phi.getIncomingValueForBlock(AftBlocksLast[It]);
+      assert(OldValue && "should have incoming edge from Aft[It]");
+      Value *NewValue = OldValue;
+      if (Value *PrevValue = PrevItValueMap[OldValue])
+        NewValue = PrevValue;
+
+      assert(Phi.getNumOperands() == 2);
+      Phi.setIncomingBlock(0, ForeBlocksLast[It - 1]);
+      Phi.setIncomingValue(0, NewValue);
+      Phi.removeIncomingValue(1);
+    }
+  }
+
+  // Now that all the basic blocks for the unrolled iterations are in place,
+  // finish up connecting the blocks and phi nodes. At this point LastValueMap
+  // holds the last unrolled iteration's values.
+
+  // Update Phis in BB from OldBB to point to NewBB
+  auto updatePHIBlocks = [](BasicBlock *BB, BasicBlock *OldBB,
+                            BasicBlock *NewBB) {
+    for (PHINode &Phi : BB->phis()) {
+      int I = Phi.getBasicBlockIndex(OldBB);
+      Phi.setIncomingBlock(I, NewBB);
+    }
+  };
+  // Update Phis in BB from OldBB to point to NewBB and use the latest value
+  // from LastValueMap
+  auto updatePHIBlocksAndValues = [](BasicBlock *BB, BasicBlock *OldBB,
+                                     BasicBlock *NewBB,
+                                     ValueToValueMapTy &LastValueMap) {
+    for (PHINode &Phi : BB->phis()) {
+      for (unsigned b = 0; b < Phi.getNumIncomingValues(); ++b) {
+        if (Phi.getIncomingBlock(b) == OldBB) {
+          Value *OldValue = Phi.getIncomingValue(b);
+          if (Value *LastValue = LastValueMap[OldValue])
+            Phi.setIncomingValue(b, LastValue);
+          Phi.setIncomingBlock(b, NewBB);
+          break;
+        }
+      }
+    }
+  };
+  // Move all the phis from Src into Dest
+  auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) {
+    Instruction *insertPoint = Dest->getFirstNonPHI();
+    while (PHINode *Phi = dyn_cast<PHINode>(Src->begin()))
+      Phi->moveBefore(insertPoint);
+  };
+
+  // Update the PHI values outside the loop to point to the last block
+  updatePHIBlocksAndValues(LoopExit, AftBlocksLast[0], AftBlocksLast.back(),
+                           LastValueMap);
+
+  // Update ForeBlocks successors and phi nodes
+  BranchInst *ForeTerm =
+      cast<BranchInst>(ForeBlocksLast.back()->getTerminator());
+  BasicBlock *Dest = SubLoopBlocksFirst[0];
+  ForeTerm->setSuccessor(0, Dest);
+
+  if (CompletelyUnroll) {
+    while (PHINode *Phi = dyn_cast<PHINode>(ForeBlocksFirst[0]->begin())) {
+      Phi->replaceAllUsesWith(Phi->getIncomingValueForBlock(Preheader));
+      Phi->getParent()->getInstList().erase(Phi);
+    }
+  } else {
+    // Update the PHI values to point to the last aft block
+    updatePHIBlocksAndValues(ForeBlocksFirst[0], AftBlocksLast[0],
+                             AftBlocksLast.back(), LastValueMap);
+  }
+
+  for (unsigned It = 1; It != Count; It++) {
+    // Remap ForeBlock successors from previous iteration to this
+    BranchInst *ForeTerm =
+        cast<BranchInst>(ForeBlocksLast[It - 1]->getTerminator());
+    BasicBlock *Dest = ForeBlocksFirst[It];
+    ForeTerm->setSuccessor(0, Dest);
+  }
+
+  // Subloop successors and phis
+  BranchInst *SubTerm =
+      cast<BranchInst>(SubLoopBlocksLast.back()->getTerminator());
+  SubTerm->setSuccessor(!SubLoopContinueOnTrue, SubLoopBlocksFirst[0]);
+  SubTerm->setSuccessor(SubLoopContinueOnTrue, AftBlocksFirst[0]);
+  updatePHIBlocks(SubLoopBlocksFirst[0], ForeBlocksLast[0],
+                  ForeBlocksLast.back());
+  updatePHIBlocks(SubLoopBlocksFirst[0], SubLoopBlocksLast[0],
+                  SubLoopBlocksLast.back());
+
+  for (unsigned It = 1; It != Count; It++) {
+    // Replace the conditional branch of the previous iteration subloop with an
+    // unconditional one to this one
+    BranchInst *SubTerm =
+        cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator());
+    BranchInst::Create(SubLoopBlocksFirst[It], SubTerm);
+    SubTerm->eraseFromParent();
+
+    updatePHIBlocks(SubLoopBlocksFirst[It], ForeBlocksLast[It],
+                    ForeBlocksLast.back());
+    updatePHIBlocks(SubLoopBlocksFirst[It], SubLoopBlocksLast[It],
+                    SubLoopBlocksLast.back());
+    movePHIs(SubLoopBlocksFirst[It], SubLoopBlocksFirst[0]);
+  }
+
+  // Aft blocks successors and phis
+  BranchInst *Term = cast<BranchInst>(AftBlocksLast.back()->getTerminator());
+  if (CompletelyUnroll) {
+    BranchInst::Create(LoopExit, Term);
+    Term->eraseFromParent();
+  } else {
+    Term->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]);
+  }
+  updatePHIBlocks(AftBlocksFirst[0], SubLoopBlocksLast[0],
+                  SubLoopBlocksLast.back());
+
+  for (unsigned It = 1; It != Count; It++) {
+    // Replace the conditional branch of the previous iteration's aft blocks
+    // with an unconditional one to this iteration's
+    BranchInst *AftTerm =
+        cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator());
+    BranchInst::Create(AftBlocksFirst[It], AftTerm);
+    AftTerm->eraseFromParent();
+
+    updatePHIBlocks(AftBlocksFirst[It], SubLoopBlocksLast[It],
+                    SubLoopBlocksLast.back());
+    movePHIs(AftBlocksFirst[It], AftBlocksFirst[0]);
+  }
+
+  // Dominator Tree. Remove the old links between Fore, Sub and Aft, adding the
+  // new ones required.
+  if (Count != 1) {
+    SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
+    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete, ForeBlocksLast[0],
+                           SubLoopBlocksFirst[0]);
+    DTUpdates.emplace_back(DominatorTree::UpdateKind::Delete,
+                           SubLoopBlocksLast[0], AftBlocksFirst[0]);
+
+    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
+                           ForeBlocksLast.back(), SubLoopBlocksFirst[0]);
+    DTUpdates.emplace_back(DominatorTree::UpdateKind::Insert,
+                           SubLoopBlocksLast.back(), AftBlocksFirst[0]);
+    DT->applyUpdates(DTUpdates);
+  }
+
+  // Merge adjacent basic blocks, if possible.
+  SmallPtrSet<BasicBlock *, 16> MergeBlocks;
+  MergeBlocks.insert(ForeBlocksLast.begin(), ForeBlocksLast.end());
+  MergeBlocks.insert(SubLoopBlocksLast.begin(), SubLoopBlocksLast.end());
+  MergeBlocks.insert(AftBlocksLast.begin(), AftBlocksLast.end());
+  while (!MergeBlocks.empty()) {
+    BasicBlock *BB = *MergeBlocks.begin();
+    BranchInst *Term = dyn_cast<BranchInst>(BB->getTerminator());
+    if (Term && Term->isUnconditional() && L->contains(Term->getSuccessor(0))) {
+      BasicBlock *Dest = Term->getSuccessor(0);
+      if (BasicBlock *Fold = foldBlockIntoPredecessor(Dest, LI, SE, DT)) {
+        // Don't remove BB and add Fold as they are the same BB
+        assert(Fold == BB);
+        (void)Fold;
+        MergeBlocks.erase(Dest);
+      } else
+        MergeBlocks.erase(BB);
+    } else
+      MergeBlocks.erase(BB);
+  }
+
+  // At this point, the code is well formed. We now do a quick sweep over the
+  // inserted code, doing constant propagation and dead code elimination as we
+  // go.
+  simplifyLoopAfterUnroll(SubLoop, true, LI, SE, DT, AC);
+  simplifyLoopAfterUnroll(L, !CompletelyUnroll && Count > 1, LI, SE, DT, AC);
+
+  NumCompletelyUnrolledAndJammed += CompletelyUnroll;
+  ++NumUnrolledAndJammed;
+
+#ifndef NDEBUG
+  // We shouldn't have done anything to break loop simplify form or LCSSA.
+  Loop *OuterL = L->getParentLoop();
+  Loop *OutestLoop = OuterL ? OuterL : (!CompletelyUnroll ? L : SubLoop);
+  assert(OutestLoop->isRecursivelyLCSSAForm(*DT, *LI));
+  if (!CompletelyUnroll)
+    assert(L->isLoopSimplifyForm());
+  assert(SubLoop->isLoopSimplifyForm());
+  assert(DT->verify());
+#endif
+
+  // Update LoopInfo if the loop is completely removed.
+  if (CompletelyUnroll)
+    LI->erase(L);
+
+  return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
+                          : LoopUnrollResult::PartiallyUnrolled;
+}
+
+static bool getLoadsAndStores(BasicBlockSet &Blocks,
+                              SmallVector<Value *, 4> &MemInstr) {
+  // Scan the BBs and collect legal loads and stores.
+  // Returns false if non-simple loads/stores are found.
+  for (BasicBlock *BB : Blocks) {
+    for (Instruction &I : *BB) {
+      if (auto *Ld = dyn_cast<LoadInst>(&I)) {
+        if (!Ld->isSimple())
+          return false;
+        MemInstr.push_back(&I);
+      } else if (auto *St = dyn_cast<StoreInst>(&I)) {
+        if (!St->isSimple())
+          return false;
+        MemInstr.push_back(&I);
+      } else if (I.mayReadOrWriteMemory()) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+static bool checkDependencies(SmallVector<Value *, 4> &Earlier,
+                              SmallVector<Value *, 4> &Later,
+                              unsigned LoopDepth, bool InnerLoop,
+                              DependenceInfo &DI) {
+  // Use DA to check for dependencies between loads and stores that make unroll
+  // and jam invalid
+  for (Value *I : Earlier) {
+    for (Value *J : Later) {
+      Instruction *Src = cast<Instruction>(I);
+      Instruction *Dst = cast<Instruction>(J);
+      if (Src == Dst)
+        continue;
+      // Ignore Input dependencies.
+      if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
+        continue;
+
+      // Track dependencies, and if we find them take a conservative approach
+      // by allowing only = or < (not >), although some > would be safe
+      // (depending upon unroll width).
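+      // For example (illustrative): a store of A[i] in Fore read as A[i] in
+      // Aft within the same outer iteration is an '=' dependence and is
+      // accepted below, while a '>' dependence would be broken by the block
+      // reordering and is rejected.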
+      // For the inner loop, we need to disallow any (> <) dependencies
+      // FIXME: Allow > so long as distance is less than unroll width
+      if (auto D = DI.depends(Src, Dst, true)) {
+        assert(D->isOrdered() && "Expected an output, flow or anti dep.");
+
+        if (D->isConfused())
+          return false;
+        if (!InnerLoop) {
+          if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT)
+            return false;
+        } else {
+          assert(LoopDepth + 1 <= D->getLevels());
+          if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT &&
+              D->getDirection(LoopDepth + 1) & Dependence::DVEntry::LT)
+            return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+static bool checkDependencies(Loop *L, BasicBlockSet &ForeBlocks,
+                              BasicBlockSet &SubLoopBlocks,
+                              BasicBlockSet &AftBlocks, DependenceInfo &DI) {
+  // Collect the loads/stores for each set of blocks
+  SmallVector<Value *, 4> ForeMemInstr;
+  SmallVector<Value *, 4> SubLoopMemInstr;
+  SmallVector<Value *, 4> AftMemInstr;
+  if (!getLoadsAndStores(ForeBlocks, ForeMemInstr) ||
+      !getLoadsAndStores(SubLoopBlocks, SubLoopMemInstr) ||
+      !getLoadsAndStores(AftBlocks, AftMemInstr))
+    return false;
+
+  // Check for dependencies between any blocks that may change order
+  unsigned LoopDepth = L->getLoopDepth();
+  return checkDependencies(ForeMemInstr, SubLoopMemInstr, LoopDepth, false,
+                           DI) &&
+         checkDependencies(ForeMemInstr, AftMemInstr, LoopDepth, false, DI) &&
+         checkDependencies(SubLoopMemInstr, AftMemInstr, LoopDepth, false,
+                           DI) &&
+         checkDependencies(SubLoopMemInstr, SubLoopMemInstr, LoopDepth, true,
+                           DI);
+}
+
+bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
+                                DependenceInfo &DI) {
+  /* We currently handle outer loops like this:
+        |
+    ForeFirst    <----\    }
+     Blocks           |    } ForeBlocks
+    ForeLast          |    }
+        |             |
+    SubLoopFirst  <\  |    }
+     Blocks        |  |    } SubLoopBlocks
+    SubLoopLast   -/  |    }
+        |             |
+    AftFirst          |    }
+     Blocks           |    } AftBlocks
+    AftLast     ------/    }
+        |
+
+    There are (theoretically) any number of blocks in ForeBlocks, SubLoopBlocks
+    and AftBlocks, provided that there is one edge from Fores to SubLoops,
+    one edge from SubLoops to Afts and a single outer loop exit (from Afts).
+    In practice we currently limit Aft blocks to a single block, and limit
+    things further in the profitability checks of the unroll and jam pass.
+
+    Because of the way we rearrange basic blocks, we also require that
+    the Fore blocks on all unrolled iterations are safe to move before the
+    SubLoop blocks of all iterations. So we require that the phi node looping
+    operands of ForeHeader can be moved to at least the end of ForeEnd, so that
+    we can arrange cloned Fore Blocks before the subloop and match up the phis
+    correctly.
+
+    i.e. The old order of blocks used to be F1 S1_1 S1_2 A1 F2 S2_1 S2_2 A2.
+    It needs to be safe to transform this to F1 F2 S1_1 S2_1 S1_2 S2_2 A1 A2.
+
+    There are then a number of checks along the lines of no calls, no
+    exceptions, inner loop IV is consistent, etc. Note that for loops requiring
+    runtime unrolling, UnrollRuntimeLoopRemainder can also fail in
+    UnrollAndJamLoop if the trip count cannot be easily calculated.
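+
+    As an illustrative example of a disqualifying dependency: if an Aft block
+    stores A[i+1] and a Fore block loads A[i], then Fore(i+1) reads the value
+    stored by Aft(i). After the transform Fore(i+1) would run before Aft(i),
+    so the dependency checks in this function must reject such loops.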
+  */
+
+  if (!L->isLoopSimplifyForm() || L->getSubLoops().size() != 1)
+    return false;
+  Loop *SubLoop = L->getSubLoops()[0];
+  if (!SubLoop->isLoopSimplifyForm())
+    return false;
+
+  BasicBlock *Header = L->getHeader();
+  BasicBlock *Latch = L->getLoopLatch();
+  BasicBlock *Exit = L->getExitingBlock();
+  BasicBlock *SubLoopHeader = SubLoop->getHeader();
+  BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
+  BasicBlock *SubLoopExit = SubLoop->getExitingBlock();
+
+  if (Latch != Exit)
+    return false;
+  if (SubLoopLatch != SubLoopExit)
+    return false;
+
+  if (Header->hasAddressTaken() || SubLoopHeader->hasAddressTaken())
+    return false;
+
+  // Split blocks into Fore/SubLoop/Aft based on dominators
+  BasicBlockSet SubLoopBlocks;
+  BasicBlockSet ForeBlocks;
+  BasicBlockSet AftBlocks;
+  if (!partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks,
+                                AftBlocks, &DT))
+    return false;
+
+  // Aft blocks may need to move instructions to fore blocks, which becomes more
+  // difficult if there are multiple (potentially conditionally executed)
+  // blocks. For now we just exclude loops with multiple aft blocks.
+  if (AftBlocks.size() != 1)
+    return false;
+
+  // Check the inner loop IV is consistent between all iterations
+  const SCEV *SubLoopBECountSC = SE.getExitCount(SubLoop, SubLoopLatch);
+  if (isa<SCEVCouldNotCompute>(SubLoopBECountSC) ||
+      !SubLoopBECountSC->getType()->isIntegerTy())
+    return false;
+  ScalarEvolution::LoopDisposition LD =
+      SE.getLoopDisposition(SubLoopBECountSC, L);
+  if (LD != ScalarEvolution::LoopInvariant)
+    return false;
+
+  // Check the loop safety info for exceptions.
+  LoopSafetyInfo LSI;
+  computeLoopSafetyInfo(&LSI, L);
+  if (LSI.MayThrow)
+    return false;
+
+  // We've ruled out the easy stuff, and now need to check that there are no
+  // interdependencies which may prevent us from moving:
+  //   ForeBlocks before Subloop and AftBlocks.
+  //   Subloop before AftBlocks.
+  //   ForeBlock phi operands before the subloop.
+
+  // Make sure we can move all instructions we need to before the subloop
+  if (!processHeaderPhiOperands(
+          Header, Latch, AftBlocks, [&AftBlocks, &SubLoop](Instruction *I) {
+            if (SubLoop->contains(I->getParent()))
+              return false;
+            if (AftBlocks.count(I->getParent())) {
+              // If we hit a phi node in afts we know we are done (probably
+              // LCSSA)
+              if (isa<PHINode>(I))
+                return false;
+              // Can't move instructions with side effects or memory
+              // reads/writes
+              if (I->mayHaveSideEffects() || I->mayReadOrWriteMemory())
+                return false;
+            }
+            // Keep going
+            return true;
+          }))
+    return false;
+
+  // Check for memory dependencies which prohibit the unrolling we are doing.
+  // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check
+  // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub.
+  if (!checkDependencies(L, ForeBlocks, SubLoopBlocks, AftBlocks, DI))
+    return false;
+
+  return true;
+}
diff --git a/lib/Transforms/Utils/LoopUnrollPeel.cpp b/lib/Transforms/Utils/LoopUnrollPeel.cpp
index c84ae7d693d7..13794c53f24b 100644
--- a/lib/Transforms/Utils/LoopUnrollPeel.cpp
+++ b/lib/Transforms/Utils/LoopUnrollPeel.cpp
@@ -20,6 +20,7 @@
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopIterator.h"
 #include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Dominators.h"
@@ -30,6 +31,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -46,6 +48,7 @@
 #include <limits>
 
 using namespace llvm;
+using namespace llvm::PatternMatch;
 
 #define DEBUG_TYPE "loop-unroll"
 
@@ -66,7 +69,7 @@
 static const unsigned InfiniteIterationsToInvariance =
     std::numeric_limits<unsigned>::max();
 
 // Check whether we are capable of peeling this loop.
-static bool canPeel(Loop *L) {
+bool llvm::canPeel(Loop *L) {
   // Make sure the loop is in simplified form
   if (!L->isLoopSimplifyForm())
     return false;
@@ -136,11 +139,109 @@ static unsigned calculateIterationsToInvariance(
   return ToInvariance;
 }
 
+// Return the number of iterations to peel off that make conditions in the
+// body true/false. For example, if we peel 2 iterations off the loop below,
+// the condition i < 2 can be evaluated at compile time.
+//   for (i = 0; i < n; i++)
+//     if (i < 2)
+//       ..
+//     else
+//       ..
+//
+static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount,
+                                         ScalarEvolution &SE) {
+  assert(L.isLoopSimplifyForm() && "Loop needs to be in loop simplify form");
+  unsigned DesiredPeelCount = 0;
+
+  for (auto *BB : L.blocks()) {
+    auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
+    if (!BI || BI->isUnconditional())
+      continue;
+
+    // Ignore loop exit condition.
+    if (L.getLoopLatch() == BB)
+      continue;
+
+    Value *Condition = BI->getCondition();
+    Value *LeftVal, *RightVal;
+    CmpInst::Predicate Pred;
+    if (!match(Condition, m_ICmp(Pred, m_Value(LeftVal), m_Value(RightVal))))
+      continue;
+
+    const SCEV *LeftSCEV = SE.getSCEV(LeftVal);
+    const SCEV *RightSCEV = SE.getSCEV(RightVal);
+
+    // Do not consider predicates that are known to be true or false
+    // independently of the loop iteration.
+    if (SE.isKnownPredicate(Pred, LeftSCEV, RightSCEV) ||
+        SE.isKnownPredicate(ICmpInst::getInversePredicate(Pred), LeftSCEV,
+                            RightSCEV))
+      continue;
+
+    // Check if we have a condition with one AddRec and one non AddRec
+    // expression. Normalize LeftSCEV to be the AddRec.
+    if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
+      if (isa<SCEVAddRecExpr>(RightSCEV)) {
+        std::swap(LeftSCEV, RightSCEV);
+        Pred = ICmpInst::getSwappedPredicate(Pred);
+      } else
+        continue;
+    }
+
+    const SCEVAddRecExpr *LeftAR = cast<SCEVAddRecExpr>(LeftSCEV);
+
+    // Avoid huge SCEV computations in the loop below, make sure we only
+    // consider AddRecs of the loop we are trying to peel and avoid
+    // non-monotonic predicates, as we will not be able to simplify the loop
+    // body.
+    // FIXME: For the non-monotonic predicates ICMP_EQ and ICMP_NE we can
+    //        simplify the loop, if we peel 1 additional iteration, if there
+    //        is no wrapping.
+    bool Increasing;
+    if (!LeftAR->isAffine() || LeftAR->getLoop() != &L ||
+        !SE.isMonotonicPredicate(LeftAR, Pred, Increasing))
+      continue;
+    (void)Increasing;
+
+    // Check if extending the current DesiredPeelCount lets us evaluate Pred
+    // or !Pred in the loop body statically.
+    unsigned NewPeelCount = DesiredPeelCount;
+
+    const SCEV *IterVal = LeftAR->evaluateAtIteration(
+        SE.getConstant(LeftSCEV->getType(), NewPeelCount), SE);
+
+    // If the original condition is not known, get the negated predicate
+    // (which holds on the else branch) and check if it is known. This allows
+    // us to peel off iterations that make the original condition false.
+    if (!SE.isKnownPredicate(Pred, IterVal, RightSCEV))
+      Pred = ICmpInst::getInversePredicate(Pred);
+
+    const SCEV *Step = LeftAR->getStepRecurrence(SE);
+    while (NewPeelCount < MaxPeelCount &&
+           SE.isKnownPredicate(Pred, IterVal, RightSCEV)) {
+      IterVal = SE.getAddExpr(IterVal, Step);
+      NewPeelCount++;
+    }
+
+    // Only peel the loop if the monotonic predicate !Pred becomes known in the
+    // first iteration of the loop body after peeling.
+    if (NewPeelCount > DesiredPeelCount &&
+        SE.isKnownPredicate(ICmpInst::getInversePredicate(Pred), IterVal,
+                            RightSCEV))
+      DesiredPeelCount = NewPeelCount;
+  }
+
+  return DesiredPeelCount;
+}
+
 // Return the number of iterations we want to peel off.
 void llvm::computePeelCount(Loop *L, unsigned LoopSize,
                             TargetTransformInfo::UnrollingPreferences &UP,
-                            unsigned &TripCount) {
+                            unsigned &TripCount, ScalarEvolution &SE) {
   assert(LoopSize > 0 && "Zero loop size is not allowed!");
 
+  // Save the UP.PeelCount value set by the target in
+  // TTI.getUnrollingPreferences or by the flag -unroll-peel-count.
+  unsigned TargetPeelCount = UP.PeelCount;
   UP.PeelCount = 0;
   if (!canPeel(L))
     return;
@@ -149,6 +250,19 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
   if (!L->empty())
     return;
 
+  // If the user provided a peel count, use that.
+  bool UserPeelCount = UnrollForcePeelCount.getNumOccurrences() > 0;
+  if (UserPeelCount) {
+    LLVM_DEBUG(dbgs() << "Force-peeling first " << UnrollForcePeelCount
+                      << " iterations.\n");
+    UP.PeelCount = UnrollForcePeelCount;
+    return;
+  }
+
+  // Skip peeling if it's disabled.
+  if (!UP.AllowPeeling)
+    return;
+
   // Here we try to get rid of Phis which become invariants after 1, 2, ..., N
   // iterations of the loop. For this we compute the number of iterations after
   // which every Phi is guaranteed to become an invariant, and try to peel the
@@ -160,7 +274,9 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
   SmallDenseMap<PHINode *, unsigned> IterationsToInvariance;
   // Now go through all Phis to calculate the number of iterations they
   // need to become invariants.
-  unsigned DesiredPeelCount = 0;
+  // Start the max computation with the UP.PeelCount value set by the target
+  // in TTI.getUnrollingPreferences or by the flag -unroll-peel-count.
+  unsigned DesiredPeelCount = TargetPeelCount;
   BasicBlock *BackEdge = L->getLoopLatch();
   assert(BackEdge && "Loop is not in simplified form?");
   for (auto BI = L->getHeader()->begin(); isa<PHINode>(&*BI); ++BI) {
@@ -170,15 +286,21 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
     if (ToInvariance != InfiniteIterationsToInvariance)
       DesiredPeelCount = std::max(DesiredPeelCount, ToInvariance);
   }
+
+  // Pay respect to limitations implied by loop size and the max peel count.
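+  // For instance (illustrative numbers): with UP.Threshold == 150 and
+  // LoopSize == 30, the size limit below allows peeling at most
+  // 150 / 30 - 1 == 4 iterations.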
+  unsigned MaxPeelCount = UnrollPeelMaxCount;
+  MaxPeelCount = std::min(MaxPeelCount, UP.Threshold / LoopSize - 1);
+
+  DesiredPeelCount = std::max(DesiredPeelCount,
+                              countToEliminateCompares(*L, MaxPeelCount, SE));
+
   if (DesiredPeelCount > 0) {
-    // Pay respect to limitations implied by loop size and the max peel count.
-    unsigned MaxPeelCount = UnrollPeelMaxCount;
-    MaxPeelCount = std::min(MaxPeelCount, UP.Threshold / LoopSize - 1);
     DesiredPeelCount = std::min(DesiredPeelCount, MaxPeelCount);
     // Consider max peel count limitation.
     assert(DesiredPeelCount > 0 && "Wrong loop size estimation?");
-    DEBUG(dbgs() << "Peel " << DesiredPeelCount << " iteration(s) to turn"
-                 << " some Phis into invariants.\n");
+    LLVM_DEBUG(dbgs() << "Peel " << DesiredPeelCount
+                      << " iteration(s) to turn"
+                      << " some Phis into invariants.\n");
     UP.PeelCount = DesiredPeelCount;
     return;
   }
@@ -189,44 +311,37 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize,
   if (TripCount)
     return;
 
-  // If the user provided a peel count, use that.
-  bool UserPeelCount = UnrollForcePeelCount.getNumOccurrences() > 0;
-  if (UserPeelCount) {
-    DEBUG(dbgs() << "Force-peeling first " << UnrollForcePeelCount
-                 << " iterations.\n");
-    UP.PeelCount = UnrollForcePeelCount;
-    return;
-  }
-
   // If we don't know the trip count, but have reason to believe the average
   // trip count is low, peeling should be beneficial, since we will usually
   // hit the peeled section.
   // We only do this in the presence of profile information, since otherwise
   // our estimates of the trip count are not reliable enough.
-  if (UP.AllowPeeling && L->getHeader()->getParent()->hasProfileData()) {
+  if (L->getHeader()->getParent()->hasProfileData()) {
    Optional<unsigned> PeelCount = getLoopEstimatedTripCount(L);
     if (!PeelCount)
       return;
 
-    DEBUG(dbgs() << "Profile-based estimated trip count is " << *PeelCount
-                 << "\n");
+    LLVM_DEBUG(dbgs() << "Profile-based estimated trip count is " << *PeelCount
+                      << "\n");
 
     if (*PeelCount) {
      if ((*PeelCount <= UnrollPeelMaxCount) &&
           (LoopSize * (*PeelCount + 1) <= UP.Threshold)) {
-        DEBUG(dbgs() << "Peeling first " << *PeelCount << " iterations.\n");
+        LLVM_DEBUG(dbgs() << "Peeling first " << *PeelCount
+                          << " iterations.\n");
         UP.PeelCount = *PeelCount;
         return;
       }
-      DEBUG(dbgs() << "Requested peel count: " << *PeelCount << "\n");
-      DEBUG(dbgs() << "Max peel count: " << UnrollPeelMaxCount << "\n");
-      DEBUG(dbgs() << "Peel cost: " << LoopSize * (*PeelCount + 1) << "\n");
-      DEBUG(dbgs() << "Max peel cost: " << UP.Threshold << "\n");
+      LLVM_DEBUG(dbgs() << "Requested peel count: " << *PeelCount << "\n");
+      LLVM_DEBUG(dbgs() << "Max peel count: " << UnrollPeelMaxCount << "\n");
+      LLVM_DEBUG(dbgs() << "Peel cost: " << LoopSize * (*PeelCount + 1)
+                        << "\n");
+      LLVM_DEBUG(dbgs() << "Max peel cost: " << UP.Threshold << "\n");
     }
   }
 }
 
-/// \brief Update the branch weights of the latch of a peeled-off loop
+/// Update the branch weights of the latch of a peeled-off loop
 /// iteration.
 /// This sets the branch weights for the latch of the recently peeled off loop
 /// iteration correctly.
@@ -267,12 +382,12 @@ static void updateBranchWeights(BasicBlock *Header, BranchInst *LatchBR,
   }
 }
 
-/// \brief Clones the body of the loop L, putting it between \p InsertTop and \p
+/// Clones the body of the loop L, putting it between \p InsertTop and \p
 /// InsertBot.
 /// \param IterNumber The serial number of the iteration currently being
 /// peeled off.
 /// \param Exit The exit block of the original loop.
-/// \param[out] NewBlocks A list of the the blocks in the newly created clone
+/// \param[out] NewBlocks A list of the blocks in the newly created clone
 /// \param[out] VMap The value map between the loop and the new clone.
 /// \param LoopBlocks A helper for DFS-traversal of the loop.
 /// \param LVMap A value-map that maps instructions from the original loop to
@@ -376,7 +491,7 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop,
     LVMap[KV.first] = KV.second;
 }
 
-/// \brief Peel off the first \p PeelCount iterations of loop \p L.
+/// Peel off the first \p PeelCount iterations of loop \p L.
 ///
 /// Note that this does not peel them off as a single straight-line block.
 /// Rather, each iteration is peeled off separately, and needs to check the
@@ -388,8 +503,8 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop,
 bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
                     ScalarEvolution *SE, DominatorTree *DT,
                     AssumptionCache *AC, bool PreserveLCSSA) {
-  if (!canPeel(L))
-    return false;
+  assert(PeelCount > 0 && "Attempt to peel out zero iterations?");
+  assert(canPeel(L) && "Attempt to peel a loop which is not peelable?");
 
   LoopBlocksDFS LoopBlocks(L);
   LoopBlocks.perform(LI);
@@ -500,10 +615,7 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
     // the original loop body.
     if (Iter == 0)
       DT->changeImmediateDominator(Exit, cast<BasicBlock>(LVMap[Latch]));
-#ifndef NDEBUG
-    if (VerifyDomInfo)
-      DT->verifyDomTree();
-#endif
+    assert(DT->verify(DominatorTree::VerificationLevel::Fast));
   }
 
   updateBranchWeights(InsertBot, cast<BranchInst>(VMap[LatchBR]), Iter,
diff --git a/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index e00541d3c812..0057b4ba7ce1 100644
--- a/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -21,8 +21,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
-#include "llvm/ADT/SmallSet.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/LoopIterator.h"
 #include "llvm/Analysis/ScalarEvolution.h"
@@ -33,7 +33,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
@@ -80,25 +80,21 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
   // The new PHI node value is added as an operand of a PHI node in either
   // the loop header or the loop exit block.
   for (BasicBlock *Succ : successors(Latch)) {
-    for (Instruction &BBI : *Succ) {
-      PHINode *PN = dyn_cast<PHINode>(&BBI);
-      // Exit when we passed all PHI nodes.
-      if (!PN)
-        break;
+    for (PHINode &PN : Succ->phis()) {
       // Add a new PHI node to the prolog end block and add the
       // appropriate incoming values.
-      PHINode *NewPN = PHINode::Create(PN->getType(), 2, PN->getName() + ".unr",
+      PHINode *NewPN = PHINode::Create(PN.getType(), 2, PN.getName() + ".unr",
                                        PrologExit->getFirstNonPHI());
       // Adding a value to the new PHI node from the original loop preheader.
      // This is the value that skips all the prolog code.
-      if (L->contains(PN)) {
-        NewPN->addIncoming(PN->getIncomingValueForBlock(NewPreHeader),
+      if (L->contains(&PN)) {
+        NewPN->addIncoming(PN.getIncomingValueForBlock(NewPreHeader),
                            PreHeader);
       } else {
-        NewPN->addIncoming(UndefValue::get(PN->getType()), PreHeader);
+        NewPN->addIncoming(UndefValue::get(PN.getType()), PreHeader);
       }
 
-      Value *V = PN->getIncomingValueForBlock(Latch);
+      Value *V = PN.getIncomingValueForBlock(Latch);
       if (Instruction *I = dyn_cast<Instruction>(V)) {
         if (L->contains(I)) {
           V = VMap.lookup(I);
@@ -111,10 +107,10 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
       // Update the existing PHI node operand with the value from the
       // new PHI node. How this is done depends on if the existing
       // PHI node is in the original loop block, or the exit block.
-      if (L->contains(PN)) {
-        PN->setIncomingValue(PN->getBasicBlockIndex(NewPreHeader), NewPN);
+      if (L->contains(&PN)) {
+        PN.setIncomingValue(PN.getBasicBlockIndex(NewPreHeader), NewPN);
       } else {
-        PN->addIncoming(NewPN, PrologExit);
+        PN.addIncoming(NewPN, PrologExit);
       }
     }
   }
@@ -191,11 +187,7 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
   //   Exit (EpilogPN)
 
   // Update PHI nodes at NewExit and Exit.
-  for (Instruction &BBI : *NewExit) {
-    PHINode *PN = dyn_cast<PHINode>(&BBI);
-    // Exit when we passed all PHI nodes.
-    if (!PN)
-      break;
+  for (PHINode &PN : NewExit->phis()) {
     // PN should be used in another PHI located in Exit block as
     // Exit was split by SplitBlockPredecessors into Exit and NewExit
     // Basically it should look like:
     // Exit:
     //   EpilogPN = PHI [PN, NewExit]
     // NewExit:
     //   PN = PHI [I, Latch]
     //
     // There is EpilogPreHeader incoming block instead of NewExit as
     // NewExit was split 1 more time to get EpilogPreHeader.
-    assert(PN->hasOneUse() && "The phi should have 1 use");
-    PHINode *EpilogPN = cast<PHINode> (PN->use_begin()->getUser());
+    assert(PN.hasOneUse() && "The phi should have 1 use");
+    PHINode *EpilogPN = cast<PHINode>(PN.use_begin()->getUser());
     assert(EpilogPN->getParent() == Exit && "EpilogPN should be in Exit block");
 
     // Add incoming PreHeader from branch around the Loop
-    PN->addIncoming(UndefValue::get(PN->getType()), PreHeader);
+    PN.addIncoming(UndefValue::get(PN.getType()), PreHeader);
 
-    Value *V = PN->getIncomingValueForBlock(Latch);
+    Value *V = PN.getIncomingValueForBlock(Latch);
     Instruction *I = dyn_cast<Instruction>(V);
     if (I && L->contains(I))
       // If value comes from an instruction in the loop add VMap value.
@@ -242,23 +234,19 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
     // Skip this as we already updated phis in exit blocks.
     if (!L->contains(Succ))
       continue;
-    for (Instruction &BBI : *Succ) {
-      PHINode *PN = dyn_cast<PHINode>(&BBI);
-      // Exit when we passed all PHI nodes.
-      if (!PN)
-        break;
+    for (PHINode &PN : Succ->phis()) {
       // Add new PHI nodes to the loop exit block and update epilog
      // PHIs with the new PHI values.
-      PHINode *NewPN = PHINode::Create(PN->getType(), 2, PN->getName() + ".unr",
+      PHINode *NewPN = PHINode::Create(PN.getType(), 2, PN.getName() + ".unr",
                                        NewExit->getFirstNonPHI());
      // Adding a value to the new PHI node from the unrolling loop preheader.
-      NewPN->addIncoming(PN->getIncomingValueForBlock(NewPreHeader), PreHeader);
+      NewPN->addIncoming(PN.getIncomingValueForBlock(NewPreHeader), PreHeader);
      // Adding a value to the new PHI node from the unrolling loop latch.
-      NewPN->addIncoming(PN->getIncomingValueForBlock(Latch), Latch);
+      NewPN->addIncoming(PN.getIncomingValueForBlock(Latch), Latch);
 
       // Update the existing PHI node operand with the value from the new PHI
       // node. Corresponding instruction in epilog loop should be PHI.
-      PHINode *VPN = cast<PHINode>(VMap[&BBI]);
+      PHINode *VPN = cast<PHINode>(VMap[&PN]);
       VPN->setIncomingValue(VPN->getBasicBlockIndex(EpilogPreHeader), NewPN);
     }
   }
@@ -430,8 +418,9 @@ canSafelyUnrollMultiExitLoop(Loop *L, SmallVectorImpl<BasicBlock *> &OtherExits,
   // UnrollRuntimeMultiExit is true. This will need updating the logic in
   // connectEpilog/connectProlog.
   if (!LatchExit->getSinglePredecessor()) {
-    DEBUG(dbgs() << "Bailout for multi-exit handling when latch exit has >1 "
-                    "predecessor.\n");
+    LLVM_DEBUG(
+        dbgs() << "Bailout for multi-exit handling when latch exit has >1 "
+                  "predecessor.\n");
     return false;
   }
   // FIXME: We bail out of multi-exit unrolling when epilog loop is generated
@@ -540,14 +529,14 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,
                                       LoopInfo *LI, ScalarEvolution *SE,
                                       DominatorTree *DT, AssumptionCache *AC,
                                       bool PreserveLCSSA) {
-  DEBUG(dbgs() << "Trying runtime unrolling on Loop: \n");
-  DEBUG(L->dump());
-  DEBUG(UseEpilogRemainder ? dbgs() << "Using epilog remainder.\n" :
-        dbgs() << "Using prolog remainder.\n");
+  LLVM_DEBUG(dbgs() << "Trying runtime unrolling on Loop: \n");
+  LLVM_DEBUG(L->dump());
+  LLVM_DEBUG(UseEpilogRemainder ? dbgs() << "Using epilog remainder.\n"
+                                : dbgs() << "Using prolog remainder.\n");
 
   // Make sure the loop is in canonical form.
   if (!L->isLoopSimplifyForm()) {
-    DEBUG(dbgs() << "Not in simplify form!\n");
+    LLVM_DEBUG(dbgs() << "Not in simplify form!\n");
     return false;
   }
 
@@ -573,7 +562,7 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,
   // Support only single exit and exiting block unless multi-exit loop unrolling is enabled.
   if (!isMultiExitUnrollingEnabled &&
       (!L->getExitingBlock() || OtherExits.size())) {
-    DEBUG(
+    LLVM_DEBUG(
         dbgs()
         << "Multiple exit/exiting blocks in loop and multi-exit unrolling not "
           "enabled!\n");
     return false;
   }
@@ -593,7 +582,7 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,
   const SCEV *BECountSC = SE->getExitCount(L, Latch);
   if (isa<SCEVCouldNotCompute>(BECountSC) ||
       !BECountSC->getType()->isIntegerTy()) {
-    DEBUG(dbgs() << "Could not compute exit block SCEV\n");
+    LLVM_DEBUG(dbgs() << "Could not compute exit block SCEV\n");
     return false;
   }
 
@@ -603,7 +592,7 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,
   const SCEV *TripCountSC =
       SE->getAddExpr(BECountSC, SE->getConstant(BECountSC->getType(), 1));
   if (isa<SCEVCouldNotCompute>(TripCountSC)) {
-    DEBUG(dbgs() << "Could not compute trip count SCEV.\n");
+    LLVM_DEBUG(dbgs() << "Could not compute trip count SCEV.\n");
    return false;
   }
 
@@ -613,15 +602,16 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,
   SCEVExpander Expander(*SE, DL, "loop-unroll");
   if (!AllowExpensiveTripCount &&
       Expander.isHighCostExpansion(TripCountSC, L, PreHeaderBR)) {
-    DEBUG(dbgs() << "High cost for expanding trip count scev!\n");
+    LLVM_DEBUG(dbgs() << "High cost for expanding trip count scev!\n");
     return false;
   }
 
   // This constraint lets us deal with an overflowing trip count easily; see the
   // comment on ModVal below.
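+  // For example (illustrative): unrolling by Count == 8 when the
+  // backedge-taken count is only an i2 (BEWidth == 2) is rejected, since
+  // Log2_32(8) == 3 > 2 and Count itself would not be representable in that
+  // narrow type.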
   if (Log2_32(Count) > BEWidth) {
-    DEBUG(dbgs()
-          << "Count failed constraint on overflow trip count calculation.\n");
+    LLVM_DEBUG(
+        dbgs()
+        << "Count failed constraint on overflow trip count calculation.\n");
     return false;
   }
 
@@ -775,7 +765,7 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,
   // values from the cloned region. Also update the dominator info for
   // OtherExits and their immediate successors, since we have new edges into
   // OtherExits.
-  SmallSet<BasicBlock*, 8> ImmediateSuccessorsOfExitBlocks;
+  SmallPtrSet<BasicBlock*, 8> ImmediateSuccessorsOfExitBlocks;
   for (auto *BB : OtherExits) {
     for (auto &II : *BB) {
 
@@ -890,10 +880,9 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,
                                     NewPreHeader, VMap, DT, LI, PreserveLCSSA);
   }
 
-  // If this loop is nested, then the loop unroller changes the code in the
-  // parent loop, so the Scalar Evolution pass needs to be run again.
-  if (Loop *ParentLoop = L->getParentLoop())
-    SE->forgetLoop(ParentLoop);
+  // If this loop is nested, then the loop unroller changes the code in any
+  // of its parent loops, so the Scalar Evolution pass needs to be run again.
+  SE->forgetTopmostLoop(L);
 
   // Canonicalize to LoopSimplifyForm both original and remainder loops. We
   // cannot rely on the LoopUnrollPass to do this because it only does
@@ -909,7 +898,7 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,
   }
 
   if (remainderLoop && UnrollRemainder) {
-    DEBUG(dbgs() << "Unrolling remainder loop\n");
+    LLVM_DEBUG(dbgs() << "Unrolling remainder loop\n");
     UnrollLoop(remainderLoop, /*Count*/ Count - 1, /*TripCount*/ Count - 1,
                /*Force*/ false, /*AllowRuntime*/ false,
               /*AllowExpensiveTripCount*/ false, /*PreserveCondBr*/ true,
diff --git a/lib/Transforms/Utils/LoopUtils.cpp b/lib/Transforms/Utils/LoopUtils.cpp
index fe106e33bca1..46af120a428b 100644
--- a/lib/Transforms/Utils/LoopUtils.cpp
+++ b/lib/Transforms/Utils/LoopUtils.cpp
@@ -16,13 +16,16 @@
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MustExecute.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
 #include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
@@ -30,6 +33,7 @@
 #include "llvm/IR/ValueHandle.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/KnownBits.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 
 using namespace llvm;
 
@@ -77,10 +81,13 @@ bool RecurrenceDescriptor::isArithmeticRecurrenceKind(RecurrenceKind Kind) {
   return false;
 }
 
-Instruction *
-RecurrenceDescriptor::lookThroughAnd(PHINode *Phi, Type *&RT,
-                                     SmallPtrSetImpl<Instruction *> &Visited,
-                                     SmallPtrSetImpl<Instruction *> &CI) {
+/// Determines if Phi may have been type-promoted. If Phi has a single user
+/// that ANDs the Phi with a type mask, return the user. RT is updated to
+/// account for the narrower bit width represented by the mask, and the AND
+/// instruction is added to CI.
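+/// For example (illustrative IR, not taken from the sources):
+///   %phi = phi i32 [ %init, %entry ], [ %next, %loop ]
+///   %and = and i32 %phi, 255
+/// Here the mask 255 describes an i8 value, so RT becomes i8 and the 'and'
+/// is recorded in CI.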
+static Instruction *lookThroughAnd(PHINode *Phi, Type *&RT,
+                                   SmallPtrSetImpl<Instruction *> &Visited,
+                                   SmallPtrSetImpl<Instruction *> &CI) {
   if (!Phi->hasOneUse())
     return Phi;
 
@@ -101,70 +108,92 @@ RecurrenceDescriptor::lookThroughAnd(PHINode *Phi, Type *&RT,
   return Phi;
 }
 
-bool RecurrenceDescriptor::getSourceExtensionKind(
-    Instruction *Start, Instruction *Exit, Type *RT, bool &IsSigned,
-    SmallPtrSetImpl<Instruction *> &Visited,
-    SmallPtrSetImpl<Instruction *> &CI) {
+/// Compute the minimal bit width needed to represent a reduction whose exit
+/// instruction is given by Exit.
+static std::pair<Type *, bool> computeRecurrenceType(Instruction *Exit,
+                                                     DemandedBits *DB,
+                                                     AssumptionCache *AC,
+                                                     DominatorTree *DT) {
+  bool IsSigned = false;
+  const DataLayout &DL = Exit->getModule()->getDataLayout();
+  uint64_t MaxBitWidth = DL.getTypeSizeInBits(Exit->getType());
+
+  if (DB) {
+    // Use the demanded bits analysis to determine the bits that are live out
+    // of the exit instruction, rounding up to the nearest power of two. If the
+    // use of demanded bits results in a smaller bit width, we know the value
+    // must be positive (i.e., IsSigned = false), because if this were not the
+    // case, the sign bit would have been demanded.
+    auto Mask = DB->getDemandedBits(Exit);
+    MaxBitWidth = Mask.getBitWidth() - Mask.countLeadingZeros();
+  }
+
+  if (MaxBitWidth == DL.getTypeSizeInBits(Exit->getType()) && AC && DT) {
+    // If demanded bits wasn't able to limit the bit width, we can try to use
+    // value tracking instead. This can be the case, for example, if the value
+    // may be negative.
+    auto NumSignBits = ComputeNumSignBits(Exit, DL, 0, AC, nullptr, DT);
+    auto NumTypeBits = DL.getTypeSizeInBits(Exit->getType());
+    MaxBitWidth = NumTypeBits - NumSignBits;
+    KnownBits Bits = computeKnownBits(Exit, DL);
+    if (!Bits.isNonNegative()) {
+      // If the value is not known to be non-negative, we set IsSigned to true,
+      // meaning that we will use sext instructions instead of zext
+      // instructions to restore the original type.
+      IsSigned = true;
+      if (!Bits.isNegative())
+        // If the value is not known to be negative, we don't know what the
+        // upper bit is, and therefore, we don't know what kind of extend we
+        // will need. In this case, just increase the bit width by one bit and
+        // use sext.
+        ++MaxBitWidth;
+    }
+  }
+  if (!isPowerOf2_64(MaxBitWidth))
+    MaxBitWidth = NextPowerOf2(MaxBitWidth);
+
+  return std::make_pair(Type::getIntNTy(Exit->getContext(), MaxBitWidth),
+                        IsSigned);
+}
+
+/// Collect cast instructions that can be ignored in the vectorizer's cost
+/// model, given a reduction exit value and the minimal type in which the
+/// reduction can be represented.
+static void collectCastsToIgnore(Loop *TheLoop, Instruction *Exit,
+                                 Type *RecurrenceType,
+                                 SmallPtrSetImpl<Instruction *> &Casts) {
   SmallVector<Instruction *, 8> Worklist;
-  bool FoundOneOperand = false;
-  unsigned DstSize = RT->getPrimitiveSizeInBits();
+  SmallPtrSet<Instruction *, 8> Visited;
   Worklist.push_back(Exit);
 
-  // Traverse the instructions in the reduction expression, beginning with the
-  // exit value.
   while (!Worklist.empty()) {
-    Instruction *I = Worklist.pop_back_val();
-    for (Use &U : I->operands()) {
-
-      // Terminate the traversal if the operand is not an instruction, or we
-      // reach the starting value.
-      Instruction *J = dyn_cast<Instruction>(U.get());
-      if (!J || J == Start)
-        continue;
-
-      // Otherwise, investigate the operation if it is also in the expression.
-      if (Visited.count(J)) {
-        Worklist.push_back(J);
+    Instruction *Val = Worklist.pop_back_val();
+    Visited.insert(Val);
+    if (auto *Cast = dyn_cast<CastInst>(Val))
+      if (Cast->getSrcTy() == RecurrenceType) {
+        // If the source type of a cast instruction is equal to the recurrence
+        // type, it will be eliminated, and should be ignored in the vectorizer
+        // cost model.
+        Casts.insert(Cast);
         continue;
       }
 
-      // If the operand is not in Visited, it is not a reduction operation, but
-      // it does feed into one. Make sure it is either a single-use sign- or
-      // zero-extend instruction.
-      CastInst *Cast = dyn_cast<CastInst>(J);
-      bool IsSExtInst = isa<SExtInst>(J);
-      if (!Cast || !Cast->hasOneUse() || !(isa<ZExtInst>(J) || IsSExtInst))
-        return false;
-
-      // Ensure the source type of the extend is no larger than the reduction
-      // type. It is not necessary for the types to be identical.
-      unsigned SrcSize = Cast->getSrcTy()->getPrimitiveSizeInBits();
-      if (SrcSize > DstSize)
-        return false;
-
-      // Furthermore, ensure that all such extends are of the same kind.
-      if (FoundOneOperand) {
-        if (IsSigned != IsSExtInst)
-          return false;
-      } else {
-        FoundOneOperand = true;
-        IsSigned = IsSExtInst;
-      }
-
-      // Lastly, if the source type of the extend matches the reduction type,
-      // add the extend to CI so that we can avoid accounting for it in the
-      // cost model.
-      if (SrcSize == DstSize)
-        CI.insert(Cast);
-    }
+    // Add all operands to the work list if they are loop-varying values that
+    // we haven't yet visited.
+    for (Value *O : cast<User>(Val)->operands())
+      if (auto *I = dyn_cast<Instruction>(O))
+        if (TheLoop->contains(I) && !Visited.count(I))
+          Worklist.push_back(I);
   }
-  return true;
 }
 
 bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,
                                            Loop *TheLoop, bool HasFunNoNaNAttr,
-                                           RecurrenceDescriptor &RedDes) {
+                                           RecurrenceDescriptor &RedDes,
+                                           DemandedBits *DB,
+                                           AssumptionCache *AC,
+                                           DominatorTree *DT) {
   if (Phi->getNumIncomingValues() != 2)
     return false;
 
@@ -353,14 +382,49 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,
   if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction)
     return false;
 
-  // If we think Phi may have been type-promoted, we also need to ensure that
-  // all source operands of the reduction are either SExtInsts or ZEstInsts. If
-  // so, we will be able to evaluate the reduction in the narrower bit width.
-  if (Start != Phi)
-    if (!getSourceExtensionKind(Start, ExitInstruction, RecurrenceType,
-                                IsSigned, VisitedInsts, CastInsts))
+  if (Start != Phi) {
+    // If the starting value is not the same as the phi node, we speculatively
+    // looked through an 'and' instruction when evaluating a potential
+    // arithmetic reduction to determine if it may have been type-promoted.
+    //
+    // We now compute the minimal bit width that is required to represent the
+    // reduction. If this is the same width that was indicated by the 'and', we
+    // can represent the reduction in the smaller type. The 'and' instruction
+    // will be eliminated since it will essentially be a cast instruction that
+    // can be ignored in the cost model. If we compute a different type than we
+    // did when evaluating the 'and', the 'and' will not be eliminated, and we
+    // will end up with different kinds of operations in the recurrence
+    // expression (e.g., RK_IntegerAND, RK_IntegerADD). We give up if this is
+    // the case.
+    //
+    // The vectorizer relies on InstCombine to perform the actual
+    // type-shrinking. It does this by inserting instructions to truncate the
+    // exit value of the reduction to the width indicated by RecurrenceType and
+    // then extend this value back to the original width. If IsSigned is false,
+    // a 'zext' instruction will be generated; otherwise, a 'sext' will be
+    // used.
+    //
+    // TODO: We should not rely on InstCombine to rewrite the reduction in the
+    //       smaller type. We should just generate a correctly typed expression
+    //       to begin with.
+    Type *ComputedType;
+    std::tie(ComputedType, IsSigned) =
+        computeRecurrenceType(ExitInstruction, DB, AC, DT);
+    if (ComputedType != RecurrenceType)
       return false;
 
+    // The recurrence expression will be represented in a narrower type. If
+    // there are any cast instructions that will be unnecessary, collect them
+    // in CastInsts. Note that the 'and' instruction was already included in
+    // this list.
+    //
+    // TODO: A better way to represent this may be to tag in some way all the
+    //       instructions that are a part of the reduction. The vectorizer cost
+    //       model could then apply the recurrence type to these instructions,
+    //       without needing a white list of instructions to ignore.
+    collectCastsToIgnore(TheLoop, ExitInstruction, RecurrenceType, CastInsts);
+  }
+
   // We found a reduction var if we have reached the original phi node and we
   // only have a single instruction with out-of-loop users.
 
@@ -480,48 +544,59 @@ bool RecurrenceDescriptor::hasMultipleUsesOf(
   return false;
 }
 
 bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
-                                          RecurrenceDescriptor &RedDes) {
+                                          RecurrenceDescriptor &RedDes,
+                                          DemandedBits *DB, AssumptionCache *AC,
+                                          DominatorTree *DT) {
   BasicBlock *Header = TheLoop->getHeader();
   Function &F = *Header->getParent();
   bool HasFunNoNaNAttr =
       F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
 
-  if (AddReductionVar(Phi, RK_IntegerAdd, TheLoop, HasFunNoNaNAttr, RedDes)) {
-    DEBUG(dbgs() << "Found an ADD reduction PHI." << *Phi << "\n");
+  if (AddReductionVar(Phi, RK_IntegerAdd, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+                      AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found an ADD reduction PHI." << *Phi << "\n");
     return true;
   }
-  if (AddReductionVar(Phi, RK_IntegerMult, TheLoop, HasFunNoNaNAttr, RedDes)) {
-    DEBUG(dbgs() << "Found a MUL reduction PHI." << *Phi << "\n");
+  if (AddReductionVar(Phi, RK_IntegerMult, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+                      AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found a MUL reduction PHI." << *Phi << "\n");
     return true;
   }
-  if (AddReductionVar(Phi, RK_IntegerOr, TheLoop, HasFunNoNaNAttr, RedDes)) {
-    DEBUG(dbgs() << "Found an OR reduction PHI." << *Phi << "\n");
+  if (AddReductionVar(Phi, RK_IntegerOr, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+                      AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found an OR reduction PHI." << *Phi << "\n");
     return true;
   }
-  if (AddReductionVar(Phi, RK_IntegerAnd, TheLoop, HasFunNoNaNAttr, RedDes)) {
-    DEBUG(dbgs() << "Found an AND reduction PHI." << *Phi << "\n");
+  if (AddReductionVar(Phi, RK_IntegerAnd, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+                      AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found an AND reduction PHI." << *Phi << "\n");
    return true;
   }
-  if (AddReductionVar(Phi, RK_IntegerXor, TheLoop, HasFunNoNaNAttr, RedDes)) {
-    DEBUG(dbgs() << "Found a XOR reduction PHI." << *Phi << "\n");
+  if (AddReductionVar(Phi, RK_IntegerXor, TheLoop, HasFunNoNaNAttr, RedDes, DB,
+                      AC, DT)) {
+    LLVM_DEBUG(dbgs() << "Found a XOR reduction PHI." << *Phi << "\n");
<< *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_IntegerMinMax, TheLoop, HasFunNoNaNAttr, - RedDes)) { - DEBUG(dbgs() << "Found a MINMAX reduction PHI." << *Phi << "\n"); + if (AddReductionVar(Phi, RK_IntegerMinMax, TheLoop, HasFunNoNaNAttr, RedDes, + DB, AC, DT)) { + LLVM_DEBUG(dbgs() << "Found a MINMAX reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_FloatMult, TheLoop, HasFunNoNaNAttr, RedDes)) { - DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n"); + if (AddReductionVar(Phi, RK_FloatMult, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { + LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_FloatAdd, TheLoop, HasFunNoNaNAttr, RedDes)) { - DEBUG(dbgs() << "Found an FAdd reduction PHI." << *Phi << "\n"); + if (AddReductionVar(Phi, RK_FloatAdd, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { + LLVM_DEBUG(dbgs() << "Found an FAdd reduction PHI." << *Phi << "\n"); return true; } - if (AddReductionVar(Phi, RK_FloatMinMax, TheLoop, HasFunNoNaNAttr, RedDes)) { - DEBUG(dbgs() << "Found an float MINMAX reduction PHI." << *Phi << "\n"); + if (AddReductionVar(Phi, RK_FloatMinMax, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { + LLVM_DEBUG(dbgs() << "Found an float MINMAX reduction PHI." << *Phi + << "\n"); return true; } // Not a reduction of known type. @@ -849,13 +924,13 @@ bool InductionDescriptor::isFPInductionPHI(PHINode *Phi, const Loop *TheLoop, } /// This function is called when we suspect that the update-chain of a phi node -/// (whose symbolic SCEV expression sin \p PhiScev) contains redundant casts, -/// that can be ignored. (This can happen when the PSCEV rewriter adds a runtime -/// predicate P under which the SCEV expression for the phi can be the -/// AddRecurrence \p AR; See createAddRecFromPHIWithCast). We want to find the -/// cast instructions that are involved in the update-chain of this induction. -/// A caller that adds the required runtime predicate can be free to drop these -/// cast instructions, and compute the phi using \p AR (instead of some scev +/// (whose symbolic SCEV expression sin \p PhiScev) contains redundant casts, +/// that can be ignored. (This can happen when the PSCEV rewriter adds a runtime +/// predicate P under which the SCEV expression for the phi can be the +/// AddRecurrence \p AR; See createAddRecFromPHIWithCast). We want to find the +/// cast instructions that are involved in the update-chain of this induction. +/// A caller that adds the required runtime predicate can be free to drop these +/// cast instructions, and compute the phi using \p AR (instead of some scev /// expression with casts). /// /// For example, without a predicate the scev expression can take the following @@ -890,7 +965,7 @@ static bool getCastsForInductionPHI(PredicatedScalarEvolution &PSE, assert(PSE.getSCEV(PN) == AR && "Unexpected phi node SCEV expression"); const Loop *L = AR->getLoop(); - // Find any cast instructions that participate in the def-use chain of + // Find any cast instructions that participate in the def-use chain of // PhiScev in the loop. // FORNOW/TODO: We currently expect the def-use chain to include only // two-operand instructions, where one of the operands is an invariant. 
@@ -978,7 +1053,7 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop, AR = PSE.getAsAddRec(Phi); if (!AR) { - DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n"); + LLVM_DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n"); return false; } @@ -1012,14 +1087,15 @@ bool InductionDescriptor::isInductionPHI( const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PhiScev); if (!AR) { - DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n"); + LLVM_DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n"); return false; } if (AR->getLoop() != TheLoop) { // FIXME: We should treat this as a uniform. Unfortunately, we // don't currently know how to handled uniform PHIs. - DEBUG(dbgs() << "LV: PHI is a recurrence with respect to an outer loop.\n"); + LLVM_DEBUG( + dbgs() << "LV: PHI is a recurrence with respect to an outer loop.\n"); return false; } @@ -1100,11 +1176,12 @@ bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI, BB, InLoopPredecessors, ".loopexit", DT, LI, PreserveLCSSA); if (!NewExitBB) - DEBUG(dbgs() << "WARNING: Can't create a dedicated exit block for loop: " - << *L << "\n"); + LLVM_DEBUG( + dbgs() << "WARNING: Can't create a dedicated exit block for loop: " + << *L << "\n"); else - DEBUG(dbgs() << "LoopSimplify: Creating dedicated exit block " - << NewExitBB->getName() << "\n"); + LLVM_DEBUG(dbgs() << "LoopSimplify: Creating dedicated exit block " + << NewExitBB->getName() << "\n"); return true; }; @@ -1127,7 +1204,7 @@ bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI, return Changed; } -/// \brief Returns the instructions that use values defined in the loop. +/// Returns the instructions that use values defined in the loop. SmallVector<Instruction *, 8> llvm::findDefsUsedOutsideOfLoop(Loop *L) { SmallVector<Instruction *, 8> UsedOutside; @@ -1204,7 +1281,7 @@ void llvm::initializeLoopPassPass(PassRegistry &Registry) { INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) } -/// \brief Find string metadata for loop +/// Find string metadata for loop /// /// If it has a value (e.g. {"llvm.distribute", 1} return the value as an /// operand or null otherwise. If the string metadata is not found return @@ -1321,13 +1398,12 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr, // Rewrite phis in the exit block to get their inputs from the Preheader // instead of the exiting block. - BasicBlock::iterator BI = ExitBlock->begin(); - while (PHINode *P = dyn_cast<PHINode>(BI)) { + for (PHINode &P : ExitBlock->phis()) { // Set the zero'th element of Phi to be from the preheader and remove all // other incoming values. Given the loop has dedicated exits, all other // incoming values must be from the exiting blocks. int PredIndex = 0; - P->setIncomingBlock(PredIndex, Preheader); + P.setIncomingBlock(PredIndex, Preheader); // Removes all incoming values from all other exiting blocks (including // duplicate values from an exiting block). // Nuke all entries except the zero'th entry which is the preheader entry. @@ -1335,13 +1411,12 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr, // below, to keep the indices valid for deletion (removeIncomingValues // updates getNumIncomingValues and shifts all values down into the operand // being deleted). 
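The cast-collection logic above (and getCastsForInductionPHI before it) leans on the runtime predicate added by the PSCEV rewriter: the narrow induction only agrees with the wide AddRec while that predicate holds. A standalone sketch of the invariant, in plain C++ where an i8 cast is modeled by uint8_t and the 256 bound is illustrative:

  #include <cassert>
  #include <cstdint>

  int main() {
    const uint32_t TripCount = 200; // runtime predicate: TripCount < 256
    uint32_t Wide = 0;              // the AddRec {0,+,1} in the wide type
    uint8_t Narrow = 0;             // the same chain computed through casts
    for (uint32_t I = 0; I != TripCount; ++I) {
      assert(Wide == Narrow);       // casts are redundant under the predicate
      ++Wide;
      Narrow = static_cast<uint8_t>(Narrow + 1);
    }
  }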
- for (unsigned i = 0, e = P->getNumIncomingValues() - 1; i != e; ++i) - P->removeIncomingValue(e - i, false); + for (unsigned i = 0, e = P.getNumIncomingValues() - 1; i != e; ++i) + P.removeIncomingValue(e - i, false); - assert((P->getNumIncomingValues() == 1 && - P->getIncomingBlock(PredIndex) == Preheader) && + assert((P.getNumIncomingValues() == 1 && + P.getIncomingBlock(PredIndex) == Preheader) && "Should have exactly one value and that's from the preheader!"); - ++BI; } // Disconnect the loop body by branching directly to its exit. @@ -1358,6 +1433,32 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr, DT->deleteEdge(Preheader, L->getHeader()); } + // Given LCSSA form is satisfied, we should not have users of instructions + // within the dead loop outside of the loop. However, LCSSA doesn't take + // unreachable uses into account. We handle them here. + // We could do this after dropping all references (in that case all users in + // the loop would already be eliminated, leaving us less work to do), but + // according to the API doc of User::dropAllReferences the only valid + // operation after dropping references is deletion. So let's first substitute + // all uses of instructions from the loop with undef values of the + // corresponding types. + for (auto *Block : L->blocks()) + for (Instruction &I : *Block) { + auto *Undef = UndefValue::get(I.getType()); + for (Value::use_iterator UI = I.use_begin(), E = I.use_end(); UI != E;) { + Use &U = *UI; + ++UI; + if (auto *Usr = dyn_cast<Instruction>(U.getUser())) + if (L->contains(Usr->getParent())) + continue; + // If we have a DT then we can check that uses outside the loop occur + // only in unreachable blocks. + if (DT) + assert(!DT->isReachableFromEntry(U) && + "Unexpected user in reachable block"); + U.set(Undef); + } + } + // Remove the block from the reference counting scheme, so that we can // delete it freely later. for (auto *Block : L->blocks()) @@ -1385,54 +1486,12 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr, } } -/// Returns true if the instruction in a loop is guaranteed to execute at least -/// once. -bool llvm::isGuaranteedToExecute(const Instruction &Inst, - const DominatorTree *DT, const Loop *CurLoop, - const LoopSafetyInfo *SafetyInfo) { - // We have to check to make sure that the instruction dominates all - // of the exit blocks. If it doesn't, then there is a path out of the loop - // which does not execute this instruction, so we can't hoist it. - - // If the instruction is in the header block for the loop (which is very - // common), it is always guaranteed to dominate the exit blocks. Since this - // is a common case, and can save some work, check it now. - if (Inst.getParent() == CurLoop->getHeader()) - // If there's a throw in the header block, we can't guarantee we'll reach - // Inst. - return !SafetyInfo->HeaderMayThrow; - - // Somewhere in this loop there is an instruction which may throw and make us - // exit the loop. - if (SafetyInfo->MayThrow) - return false; - - // Get the exit blocks for the current loop. - SmallVector<BasicBlock *, 8> ExitBlocks; - CurLoop->getExitBlocks(ExitBlocks); - - // Verify that the block dominates each of the exit blocks of the loop. - for (BasicBlock *ExitBlock : ExitBlocks) - if (!DT->dominates(Inst.getParent(), ExitBlock)) - return false; - - // As a degenerate case, if the loop is statically infinite then we haven't - // proven anything since there are no exit blocks.
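Note how the loop above advances the use iterator before calling U.set(): rewriting a use re-links it out of the instruction's use-list, which would invalidate an iterator still pointing at it. A standalone sketch of the same iteration shape (plain C++; std::list stands in for the use-list, and assigning an int is a harmless stand-in for U.set()):

  #include <list>

  void rewriteAll(std::list<int> &Uses, int Undef) {
    for (auto UI = Uses.begin(), E = Uses.end(); UI != E;) {
      auto &U = *UI;
      ++UI;       // step past U first; in LLVM, U.set() would unlink U
      U = Undef;  // now the current element can be rewritten safely
    }
  }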
- if (ExitBlocks.empty()) - return false; - - // FIXME: In general, we have to prove that the loop isn't an infinite loop. - // See http::llvm.org/PR24078 . (The "ExitBlocks.empty()" check above is - // just a special case of this.) - return true; -} - Optional<unsigned> llvm::getLoopEstimatedTripCount(Loop *L) { // Only support loops with a unique exiting block, and a latch. if (!L->getExitingBlock()) return None; - // Get the branch weights for the the loop's backedge. + // Get the branch weights for the loop's backedge. BranchInst *LatchBR = dyn_cast<BranchInst>(L->getLoopLatch()->getTerminator()); if (!LatchBR || LatchBR->getNumSuccessors() != 2) @@ -1460,7 +1519,7 @@ Optional<unsigned> llvm::getLoopEstimatedTripCount(Loop *L) { return (FalseVal + (TrueVal / 2)) / TrueVal; } -/// \brief Adds a 'fast' flag to floating point operations. +/// Adds a 'fast' flag to floating point operations. static Value *addFastMathFlag(Value *V) { if (isa<FPMathOperator>(V)) { FastMathFlags Flags; @@ -1470,6 +1529,38 @@ static Value *addFastMathFlag(Value *V) { return V; } +// Helper to generate an ordered reduction. +Value * +llvm::getOrderedReduction(IRBuilder<> &Builder, Value *Acc, Value *Src, + unsigned Op, + RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind, + ArrayRef<Value *> RedOps) { + unsigned VF = Src->getType()->getVectorNumElements(); + + // Extract and apply reduction ops in ascending order: + // e.g. ((((Acc + Scl[0]) + Scl[1]) + Scl[2]) + ...) + Scl[VF-1] + Value *Result = Acc; + for (unsigned ExtractIdx = 0; ExtractIdx != VF; ++ExtractIdx) { + Value *Ext = + Builder.CreateExtractElement(Src, Builder.getInt32(ExtractIdx)); + + if (Op != Instruction::ICmp && Op != Instruction::FCmp) { + Result = Builder.CreateBinOp((Instruction::BinaryOps)Op, Result, Ext, + "bin.rdx"); + } else { + assert(MinMaxKind != RecurrenceDescriptor::MRK_Invalid && + "Invalid min/max"); + Result = RecurrenceDescriptor::createMinMaxOp(Builder, MinMaxKind, Result, + Ext); + } + + if (!RedOps.empty()) + propagateIRFlags(Result, RedOps); + } + + return Result; +} + // Helper to generate a log2 shuffle reduction. Value * llvm::getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op, diff --git a/lib/Transforms/Utils/LoopVersioning.cpp b/lib/Transforms/Utils/LoopVersioning.cpp index 29756d9dab7f..abbcd5f9e3b8 100644 --- a/lib/Transforms/Utils/LoopVersioning.cpp +++ b/lib/Transforms/Utils/LoopVersioning.cpp @@ -140,9 +140,12 @@ void LoopVersioning::addPHINodes( if (!PN) { PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver", &PHIBlock->front()); - for (auto *User : Inst->users()) - if (!VersionedLoop->contains(cast<Instruction>(User)->getParent())) - User->replaceUsesOfWith(Inst, PN); + SmallVector<User*, 8> UsersToUpdate; + for (User *U : Inst->users()) + if (!VersionedLoop->contains(cast<Instruction>(U)->getParent())) + UsersToUpdate.push_back(U); + for (User *U : UsersToUpdate) + U->replaceUsesOfWith(Inst, PN); PN->addIncoming(Inst, VersionedLoop->getExitingBlock()); } } @@ -248,7 +251,7 @@ void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst, } namespace { -/// \brief Also expose this as a pass. Currently this is only used for +/// Also expose this as a pass. Currently this is only used for /// unit-testing. It adds all memchecks necessary to remove all may-aliasing /// array accesses from the loop.
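getOrderedReduction above folds the vector lanes into the accumulator strictly left to right, which pins down floating-point rounding order; the log2 shuffle reduction that follows trades that for fewer operations. A standalone sketch with a fixed VF of 4 (plain C++; the values are chosen to make the ordering observable):

  #include <array>
  #include <cstdio>

  // (((Acc + S[0]) + S[1]) + S[2]) + S[3], one extract + add per lane.
  float orderedReduce(float Acc, const std::array<float, 4> &S) {
    for (float Lane : S)
      Acc = Acc + Lane;
    return Acc;
  }

  int main() {
    std::array<float, 4> V{1e8f, 1.0f, -1e8f, 1.0f};
    // Prints 1.000000; a pairwise tree sum (1e8f+1.0f) + (-1e8f+1.0f) of the
    // same lanes yields 0, so the two reduction shapes are not interchangeable
    // for FP unless fast-math flags permit reassociation.
    std::printf("%f\n", orderedReduce(0.0f, V));
  }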
class LoopVersioningPass : public FunctionPass { diff --git a/lib/Transforms/Utils/LowerInvoke.cpp b/lib/Transforms/Utils/LowerInvoke.cpp index ee84541e526d..c852d538b0d1 100644 --- a/lib/Transforms/Utils/LowerInvoke.cpp +++ b/lib/Transforms/Utils/LowerInvoke.cpp @@ -21,7 +21,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" using namespace llvm; #define DEBUG_TYPE "lowerinvoke" @@ -48,10 +48,12 @@ static bool runImpl(Function &F) { bool Changed = false; for (BasicBlock &BB : F) if (InvokeInst *II = dyn_cast<InvokeInst>(BB.getTerminator())) { - SmallVector<Value *, 16> CallArgs(II->op_begin(), II->op_end() - 3); + SmallVector<Value *, 16> CallArgs(II->arg_begin(), II->arg_end()); + SmallVector<OperandBundleDef, 1> OpBundles; + II->getOperandBundlesAsDefs(OpBundles); // Insert a normal call instruction... CallInst *NewCall = - CallInst::Create(II->getCalledValue(), CallArgs, "", II); + CallInst::Create(II->getCalledValue(), CallArgs, OpBundles, "", II); NewCall->takeName(II); NewCall->setCallingConv(II->getCallingConv()); NewCall->setAttributes(II->getAttributes()); diff --git a/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/lib/Transforms/Utils/LowerMemIntrinsics.cpp index 57dc225e9dab..03006ef3a2d3 100644 --- a/lib/Transforms/Utils/LowerMemIntrinsics.cpp +++ b/lib/Transforms/Utils/LowerMemIntrinsics.cpp @@ -409,8 +409,8 @@ void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy, /* SrcAddr */ Memcpy->getRawSource(), /* DstAddr */ Memcpy->getRawDest(), /* CopyLen */ CI, - /* SrcAlign */ Memcpy->getAlignment(), - /* DestAlign */ Memcpy->getAlignment(), + /* SrcAlign */ Memcpy->getSourceAlignment(), + /* DestAlign */ Memcpy->getDestAlignment(), /* SrcIsVolatile */ Memcpy->isVolatile(), /* DstIsVolatile */ Memcpy->isVolatile(), /* TargetTransformInfo */ TTI); @@ -419,8 +419,8 @@ void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy, /* SrcAddr */ Memcpy->getRawSource(), /* DstAddr */ Memcpy->getRawDest(), /* CopyLen */ Memcpy->getLength(), - /* SrcAlign */ Memcpy->getAlignment(), - /* DestAlign */ Memcpy->getAlignment(), + /* SrcAlign */ Memcpy->getSourceAlignment(), + /* DestAlign */ Memcpy->getDestAlignment(), /* SrcIsVolatile */ Memcpy->isVolatile(), /* DstIsVolatile */ Memcpy->isVolatile(), /* TargetTransformInfo */ TTI); @@ -432,8 +432,8 @@ void llvm::expandMemMoveAsLoop(MemMoveInst *Memmove) { /* SrcAddr */ Memmove->getRawSource(), /* DstAddr */ Memmove->getRawDest(), /* CopyLen */ Memmove->getLength(), - /* SrcAlign */ Memmove->getAlignment(), - /* DestAlign */ Memmove->getAlignment(), + /* SrcAlign */ Memmove->getSourceAlignment(), + /* DestAlign */ Memmove->getDestAlignment(), /* SrcIsVolatile */ Memmove->isVolatile(), /* DstIsVolatile */ Memmove->isVolatile()); } @@ -443,6 +443,6 @@ void llvm::expandMemSetAsLoop(MemSetInst *Memset) { /* DstAddr */ Memset->getRawDest(), /* CopyLen */ Memset->getLength(), /* SetValue */ Memset->getValue(), - /* Alignment */ Memset->getAlignment(), + /* Alignment */ Memset->getDestAlignment(), Memset->isVolatile()); } diff --git a/lib/Transforms/Utils/LowerSwitch.cpp b/lib/Transforms/Utils/LowerSwitch.cpp index 344cb35df986..e99ecfef19cd 100644 --- a/lib/Transforms/Utils/LowerSwitch.cpp +++ b/lib/Transforms/Utils/LowerSwitch.cpp @@ -29,7 +29,7 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include
"llvm/Transforms/Utils/BasicBlockUtils.h" #include <algorithm> #include <cassert> @@ -74,7 +74,7 @@ namespace { LowerSwitch() : FunctionPass(ID) { initializeLowerSwitchPass(*PassRegistry::getPassRegistry()); - } + } bool runOnFunction(Function &F) override; @@ -155,11 +155,8 @@ bool LowerSwitch::runOnFunction(Function &F) { } /// Used for debugging purposes. -static raw_ostream& operator<<(raw_ostream &O, - const LowerSwitch::CaseVector &C) - LLVM_ATTRIBUTE_USED; - -static raw_ostream& operator<<(raw_ostream &O, +LLVM_ATTRIBUTE_USED +static raw_ostream &operator<<(raw_ostream &O, const LowerSwitch::CaseVector &C) { O << "["; @@ -172,7 +169,7 @@ static raw_ostream& operator<<(raw_ostream &O, return O << "]"; } -/// \brief Update the first occurrence of the "switch statement" BB in the PHI +/// Update the first occurrence of the "switch statement" BB in the PHI /// node with the "new" BB. The other occurrences will: /// /// 1) Be updated by subsequent calls to this function. Switch statements may @@ -245,14 +242,13 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, unsigned Mid = Size / 2; std::vector<CaseRange> LHS(Begin, Begin + Mid); - DEBUG(dbgs() << "LHS: " << LHS << "\n"); + LLVM_DEBUG(dbgs() << "LHS: " << LHS << "\n"); std::vector<CaseRange> RHS(Begin + Mid, End); - DEBUG(dbgs() << "RHS: " << RHS << "\n"); + LLVM_DEBUG(dbgs() << "RHS: " << RHS << "\n"); CaseRange &Pivot = *(Begin + Mid); - DEBUG(dbgs() << "Pivot ==> " - << Pivot.Low->getValue() - << " -" << Pivot.High->getValue() << "\n"); + LLVM_DEBUG(dbgs() << "Pivot ==> " << Pivot.Low->getValue() << " -" + << Pivot.High->getValue() << "\n"); // NewLowerBound here should never be the integer minimal value. // This is because it is computed from a case range that is never @@ -274,20 +270,14 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, NewUpperBound = LHS.back().High; } - DEBUG(dbgs() << "LHS Bounds ==> "; - if (LowerBound) { - dbgs() << LowerBound->getSExtValue(); - } else { - dbgs() << "NONE"; - } - dbgs() << " - " << NewUpperBound->getSExtValue() << "\n"; - dbgs() << "RHS Bounds ==> "; - dbgs() << NewLowerBound->getSExtValue() << " - "; - if (UpperBound) { - dbgs() << UpperBound->getSExtValue() << "\n"; - } else { - dbgs() << "NONE\n"; - }); + LLVM_DEBUG(dbgs() << "LHS Bounds ==> "; if (LowerBound) { + dbgs() << LowerBound->getSExtValue(); + } else { dbgs() << "NONE"; } dbgs() << " - " + << NewUpperBound->getSExtValue() << "\n"; + dbgs() << "RHS Bounds ==> "; + dbgs() << NewLowerBound->getSExtValue() << " - "; if (UpperBound) { + dbgs() << UpperBound->getSExtValue() << "\n"; + } else { dbgs() << "NONE\n"; }); // Create a new node that checks if the value is < pivot. Go to the // left branch if it is and right branch if not. 
@@ -337,7 +327,7 @@ BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, } else if (Leaf.Low->isZero()) { // Val >= 0 && Val <= Hi --> Val <=u Hi Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, - "SwitchLeaf"); + "SwitchLeaf"); } else { // Emit V-Lo <=u Hi-Lo Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); @@ -364,7 +354,7 @@ BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, for (uint64_t j = 0; j < Range; ++j) { PN->removeIncomingValue(OrigBlock); } - + int BlockIdx = PN->getBasicBlockIndex(OrigBlock); assert(BlockIdx != -1 && "Switch didn't go to this successor??"); PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); @@ -382,7 +372,7 @@ unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(), Case.getCaseSuccessor())); - std::sort(Cases.begin(), Cases.end(), CaseCmp()); + llvm::sort(Cases.begin(), Cases.end(), CaseCmp()); // Merge case into clusters if (Cases.size() >= 2) { @@ -443,9 +433,9 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, // Prepare cases vector. CaseVector Cases; unsigned numCmps = Clusterify(Cases, SI); - DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() - << ". Total compares: " << numCmps << "\n"); - DEBUG(dbgs() << "Cases: " << Cases << "\n"); + LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() + << ". Total compares: " << numCmps << "\n"); + LLVM_DEBUG(dbgs() << "Cases: " << Cases << "\n"); (void)numCmps; ConstantInt *LowerBound = nullptr; @@ -505,6 +495,10 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, } #endif + // As the default block in the switch is unreachable, update the PHI nodes + // (remove the entry to the default block) to reflect this. + Default->removePredecessor(OrigBlock); + // Use the most popular block as the new default, reducing the number of // cases. assert(MaxPop > 0 && PopSucc); @@ -518,29 +512,33 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, if (Cases.empty()) { BranchInst::Create(Default, CurBlock); SI->eraseFromParent(); + // As all the cases have been replaced with a single branch, only keep + // one entry in the PHI nodes. + for (unsigned I = 0 ; I < (MaxPop - 1) ; ++I) + PopSucc->removePredecessor(OrigBlock); return; } } + unsigned NrOfDefaults = (SI->getDefaultDest() == Default) ? 1 : 0; + for (const auto &Case : SI->cases()) + if (Case.getCaseSuccessor() == Default) + NrOfDefaults++; + // Create a new, empty default block so that the new hierarchy of // if-then statements go to this and the PHI nodes are happy. BasicBlock *NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); F->getBasicBlockList().insert(Default->getIterator(), NewDefault); BranchInst::Create(Default, NewDefault); - // If there is an entry in any PHI nodes for the default edge, make sure - // to update them as well. - for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) { - PHINode *PN = cast<PHINode>(I); - int BlockIdx = PN->getBasicBlockIndex(OrigBlock); - assert(BlockIdx != -1 && "Switch didn't go to this successor??"); - PN->setIncomingBlock((unsigned)BlockIdx, NewDefault); - } - BasicBlock *SwitchBlock = switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, OrigBlock, OrigBlock, NewDefault, UnreachableRanges); + // If there are entries in any PHI nodes for the default edge, make sure + // to update them as well. + fixPhis(Default, OrigBlock, NewDefault, NrOfDefaults); + // Branch to our shiny new if-then stuff... 
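Clusterify above first sorts the cases and then folds neighbours into ranges, so a dense switch collapses into a handful of range checks. A standalone sketch of the merge step (plain C++; assumes the input is already sorted by value, as after llvm::sort):

  #include <vector>

  struct Case { int Value; int Succ; };
  struct Cluster { int Low, High, Succ; };

  std::vector<Cluster> clusterify(const std::vector<Case> &Sorted) {
    std::vector<Cluster> Out;
    for (const Case &C : Sorted) {
      if (!Out.empty() && Out.back().High + 1 == C.Value &&
          Out.back().Succ == C.Succ)
        Out.back().High = C.Value; // consecutive value, same successor: extend
      else
        Out.push_back({C.Value, C.Value, C.Succ});
    }
    return Out;
  }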
BranchInst::Create(SwitchBlock, OrigBlock); diff --git a/lib/Transforms/Utils/Mem2Reg.cpp b/lib/Transforms/Utils/Mem2Reg.cpp index 29f289b62da0..23145e584751 100644 --- a/lib/Transforms/Utils/Mem2Reg.cpp +++ b/lib/Transforms/Utils/Mem2Reg.cpp @@ -22,7 +22,7 @@ #include "llvm/IR/PassManager.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <vector> diff --git a/lib/Transforms/Utils/MetaRenamer.cpp b/lib/Transforms/Utils/MetaRenamer.cpp index 0f7bd76c03ca..323f2552ca80 100644 --- a/lib/Transforms/Utils/MetaRenamer.cpp +++ b/lib/Transforms/Utils/MetaRenamer.cpp @@ -29,7 +29,7 @@ #include "llvm/IR/Type.h" #include "llvm/IR/TypeFinder.h" #include "llvm/Pass.h" -#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Utils.h" using namespace llvm; diff --git a/lib/Transforms/Utils/OrderedInstructions.cpp b/lib/Transforms/Utils/OrderedInstructions.cpp index dc780542ce68..6d0b96f6aa8a 100644 --- a/lib/Transforms/Utils/OrderedInstructions.cpp +++ b/lib/Transforms/Utils/OrderedInstructions.cpp @@ -14,19 +14,38 @@ #include "llvm/Transforms/Utils/OrderedInstructions.h" using namespace llvm; +bool OrderedInstructions::localDominates(const Instruction *InstA, + const Instruction *InstB) const { + assert(InstA->getParent() == InstB->getParent() && + "Instructions must be in the same basic block"); + + const BasicBlock *IBB = InstA->getParent(); + auto OBB = OBBMap.find(IBB); + if (OBB == OBBMap.end()) + OBB = OBBMap.insert({IBB, make_unique<OrderedBasicBlock>(IBB)}).first; + return OBB->second->dominates(InstA, InstB); +} + /// Given 2 instructions, use OrderedBasicBlock to check for dominance relation /// if the instructions are in the same basic block, Otherwise, use dominator /// tree. bool OrderedInstructions::dominates(const Instruction *InstA, const Instruction *InstB) const { - const BasicBlock *IBB = InstA->getParent(); // Use ordered basic block to do dominance check in case the 2 instructions // are in the same basic block. - if (IBB == InstB->getParent()) { - auto OBB = OBBMap.find(IBB); - if (OBB == OBBMap.end()) - OBB = OBBMap.insert({IBB, make_unique<OrderedBasicBlock>(IBB)}).first; - return OBB->second->dominates(InstA, InstB); - } + if (InstA->getParent() == InstB->getParent()) + return localDominates(InstA, InstB); return DT->dominates(InstA->getParent(), InstB->getParent()); } + +bool OrderedInstructions::dfsBefore(const Instruction *InstA, + const Instruction *InstB) const { + // Use ordered basic block in case the 2 instructions are in the same basic + // block. 
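localDominates above builds an OrderedBasicBlock for a block on first use and caches it in OBBMap, so repeated same-block dominance queries avoid rescanning instructions. A standalone sketch of that memoization (plain C++; instructions are opaque pointers here, and the real OrderedBasicBlock numbers instructions lazily rather than all at once):

  #include <cstddef>
  #include <map>
  #include <unordered_map>
  #include <vector>

  using Block = std::vector<const void *>; // a block as an instruction list

  struct OrderedBlockCache {
    std::map<const Block *, std::unordered_map<const void *, size_t>> Cache;

    bool localDominates(const Block &BB, const void *A, const void *B) {
      auto &Order = Cache[&BB];
      if (Order.empty())                  // first query against this block
        for (size_t I = 0; I != BB.size(); ++I)
          Order[BB[I]] = I;               // number the instructions once
      return Order[A] < Order[B];         // A dominates B iff A comes first
    }
  };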
+ if (InstA->getParent() == InstB->getParent()) + return localDominates(InstA, InstB); + + DomTreeNode *DA = DT->getNode(InstA->getParent()); + DomTreeNode *DB = DT->getNode(InstB->getParent()); + return DA->getDFSNumIn() < DB->getDFSNumIn(); +} diff --git a/lib/Transforms/Utils/PredicateInfo.cpp b/lib/Transforms/Utils/PredicateInfo.cpp index d47be6ea566b..2923977b791a 100644 --- a/lib/Transforms/Utils/PredicateInfo.cpp +++ b/lib/Transforms/Utils/PredicateInfo.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CFG.h" #include "llvm/IR/AssemblyAnnotationWriter.h" @@ -24,6 +25,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" @@ -32,7 +34,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/DebugCounter.h" #include "llvm/Support/FormattedStream.h" -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/OrderedInstructions.h" #include <algorithm> #define DEBUG_TYPE "predicateinfo" @@ -118,7 +120,7 @@ static bool valueComesBefore(OrderedInstructions &OI, const Value *A, return false; if (ArgA && ArgB) return ArgA->getArgNo() < ArgB->getArgNo(); - return OI.dominates(cast<Instruction>(A), cast<Instruction>(B)); + return OI.dfsBefore(cast<Instruction>(A), cast<Instruction>(B)); } // This compares ValueDFS structures, creating OrderedBasicBlocks where @@ -479,6 +481,19 @@ void PredicateInfo::buildPredicateInfo() { renameUses(OpsToRename); } +// Create a ssa_copy declaration with custom mangling, because +// Intrinsic::getDeclaration does not handle overloaded unnamed types properly: +// all unnamed types get mangled to the same string. We use the pointer +// to the type as name here, as it guarantees unique names for different +// types and we remove the declarations when destroying PredicateInfo. +// It is a workaround for PR38117, because solving it in a fully general way is +// tricky (FIXME). +static Function *getCopyDeclaration(Module *M, Type *Ty) { + std::string Name = "llvm.ssa.copy." + utostr((uintptr_t) Ty); + return cast<Function>(M->getOrInsertFunction( + Name, getType(M->getContext(), Intrinsic::ssa_copy, Ty))); +} + // Given the renaming stack, make all the operands currently on the stack real // by inserting them into the IR. Return the last operation's value. Value *PredicateInfo::materializeStack(unsigned int &Counter, @@ -507,8 +522,9 @@ Value *PredicateInfo::materializeStack(unsigned int &Counter, // order in the case of multiple predicateinfo in the same block. if (isa<PredicateWithEdge>(ValInfo)) { IRBuilder<> B(getBranchTerminator(ValInfo)); - Function *IF = Intrinsic::getDeclaration( - F.getParent(), Intrinsic::ssa_copy, Op->getType()); + Function *IF = getCopyDeclaration(F.getParent(), Op->getType()); + if (IF->user_begin() == IF->user_end()) + CreatedDeclarations.insert(IF); CallInst *PIC = B.CreateCall(IF, Op, Op->getName() + "." 
+ Twine(Counter++)); PredicateMap.insert({PIC, ValInfo}); @@ -518,8 +534,9 @@ Value *PredicateInfo::materializeStack(unsigned int &Counter, assert(PAssume && "Should not have gotten here without it being an assume"); IRBuilder<> B(PAssume->AssumeInst); - Function *IF = Intrinsic::getDeclaration( - F.getParent(), Intrinsic::ssa_copy, Op->getType()); + Function *IF = getCopyDeclaration(F.getParent(), Op->getType()); + if (IF->user_begin() == IF->user_end()) + CreatedDeclarations.insert(IF); CallInst *PIC = B.CreateCall(IF, Op); PredicateMap.insert({PIC, ValInfo}); Result.Def = PIC; @@ -553,10 +570,11 @@ void PredicateInfo::renameUses(SmallPtrSetImpl<Value *> &OpSet) { auto Comparator = [&](const Value *A, const Value *B) { return valueComesBefore(OI, A, B); }; - std::sort(OpsToRename.begin(), OpsToRename.end(), Comparator); + llvm::sort(OpsToRename.begin(), OpsToRename.end(), Comparator); ValueDFS_Compare Compare(OI); // Compute liveness, and rename in O(uses) per Op. for (auto *Op : OpsToRename) { + LLVM_DEBUG(dbgs() << "Visiting " << *Op << "\n"); unsigned Counter = 0; SmallVector<ValueDFS, 16> OrderedUses; const auto &ValueInfo = getValueInfo(Op); @@ -625,15 +643,15 @@ void PredicateInfo::renameUses(SmallPtrSetImpl<Value *> &OpSet) { // we want to. bool PossibleCopy = VD.PInfo != nullptr; if (RenameStack.empty()) { - DEBUG(dbgs() << "Rename Stack is empty\n"); + LLVM_DEBUG(dbgs() << "Rename Stack is empty\n"); } else { - DEBUG(dbgs() << "Rename Stack Top DFS numbers are (" - << RenameStack.back().DFSIn << "," - << RenameStack.back().DFSOut << ")\n"); + LLVM_DEBUG(dbgs() << "Rename Stack Top DFS numbers are (" + << RenameStack.back().DFSIn << "," + << RenameStack.back().DFSOut << ")\n"); } - DEBUG(dbgs() << "Current DFS numbers are (" << VD.DFSIn << "," - << VD.DFSOut << ")\n"); + LLVM_DEBUG(dbgs() << "Current DFS numbers are (" << VD.DFSIn << "," + << VD.DFSOut << ")\n"); bool ShouldPush = (VD.Def || PossibleCopy); bool OutOfScope = !stackIsInScope(RenameStack, VD); @@ -652,7 +670,7 @@ void PredicateInfo::renameUses(SmallPtrSetImpl<Value *> &OpSet) { if (VD.Def || PossibleCopy) continue; if (!DebugCounter::shouldExecute(RenameCounter)) { - DEBUG(dbgs() << "Skipping execution due to debug counter\n"); + LLVM_DEBUG(dbgs() << "Skipping execution due to debug counter\n"); continue; } ValueDFS &Result = RenameStack.back(); @@ -663,8 +681,9 @@ void PredicateInfo::renameUses(SmallPtrSetImpl<Value *> &OpSet) { if (!Result.Def) Result.Def = materializeStack(Counter, RenameStack, Op); - DEBUG(dbgs() << "Found replacement " << *Result.Def << " for " - << *VD.U->get() << " in " << *(VD.U->getUser()) << "\n"); + LLVM_DEBUG(dbgs() << "Found replacement " << *Result.Def << " for " + << *VD.U->get() << " in " << *(VD.U->getUser()) + << "\n"); assert(DT.dominates(cast<Instruction>(Result.Def), *VD.U) && "Predicateinfo def should have dominated this use"); VD.U->set(Result.Def); @@ -702,7 +721,22 @@ PredicateInfo::PredicateInfo(Function &F, DominatorTree &DT, buildPredicateInfo(); } -PredicateInfo::~PredicateInfo() {} +// Remove all declarations we created. The PredicateInfo consumers are +// responsible for removing the ssa_copy calls created. +PredicateInfo::~PredicateInfo() { + // Collect function pointers in a set first, as SmallSet uses a SmallVector + // internally and we have to remove the asserting value handles first.
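getCopyDeclaration above works around the PR38117 mangling collision by baking the Type's address into the declaration name, which is unique per type within the process. A standalone sketch of the naming scheme (plain C++; std::to_string models llvm::utostr):

  #include <cstdint>
  #include <string>

  std::string copyDeclName(const void *TypeKey) {
    // Distinct Type objects have distinct addresses, so the suffix cannot
    // collide the way the generic mangling of unnamed types does.
    return "llvm.ssa.copy." +
           std::to_string(reinterpret_cast<std::uintptr_t>(TypeKey));
  }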
+ SmallPtrSet<Function *, 20> FunctionPtrs; + for (auto &F : CreatedDeclarations) + FunctionPtrs.insert(&*F); + CreatedDeclarations.clear(); + + for (Function *F : FunctionPtrs) { + assert(F->user_begin() == F->user_end() && + "PredicateInfo consumer did not remove all SSA copies."); + F->eraseFromParent(); + } +} void PredicateInfo::verifyPredicateInfo() const {} @@ -720,6 +754,20 @@ void PredicateInfoPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<AssumptionCacheTracker>(); } +// Replace ssa_copy calls created by PredicateInfo with their operand. +static void replaceCreatedSSACopys(PredicateInfo &PredInfo, Function &F) { + for (auto I = inst_begin(F), E = inst_end(F); I != E;) { + Instruction *Inst = &*I++; + const auto *PI = PredInfo.getPredicateInfoFor(Inst); + auto *II = dyn_cast<IntrinsicInst>(Inst); + if (!PI || !II || II->getIntrinsicID() != Intrinsic::ssa_copy) + continue; + + Inst->replaceAllUsesWith(II->getOperand(0)); + Inst->eraseFromParent(); + } +} + bool PredicateInfoPrinterLegacyPass::runOnFunction(Function &F) { auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); @@ -727,6 +775,8 @@ bool PredicateInfoPrinterLegacyPass::runOnFunction(Function &F) { PredInfo->print(dbgs()); if (VerifyPredicateInfo) PredInfo->verifyPredicateInfo(); + + replaceCreatedSSACopys(*PredInfo, F); return false; } @@ -735,12 +785,14 @@ PreservedAnalyses PredicateInfoPrinterPass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); OS << "PredicateInfo for function: " << F.getName() << "\n"; - make_unique<PredicateInfo>(F, DT, AC)->print(OS); + auto PredInfo = make_unique<PredicateInfo>(F, DT, AC); + PredInfo->print(OS); + replaceCreatedSSACopys(*PredInfo, F); return PreservedAnalyses::all(); } -/// \brief An assembly annotator class to print PredicateInfo information in +/// An assembly annotator class to print PredicateInfo information in /// comments. class PredicateInfoAnnotatedWriter : public AssemblyAnnotationWriter { friend class PredicateInfo; diff --git a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index fcd3bd08482a..86e15bbd7f22 100644 --- a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -26,6 +26,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/IteratedDominanceFrontier.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -45,7 +46,6 @@ #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/Support/Casting.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> #include <cassert> @@ -164,26 +164,27 @@ struct AllocaInfo { } }; -// Data package used by RenamePass() -class RenamePassData { -public: +/// Data package used by RenamePass(). 
+struct RenamePassData { using ValVector = std::vector<Value *>; + using LocationVector = std::vector<DebugLoc>; - RenamePassData(BasicBlock *B, BasicBlock *P, ValVector V) - : BB(B), Pred(P), Values(std::move(V)) {} + RenamePassData(BasicBlock *B, BasicBlock *P, ValVector V, LocationVector L) + : BB(B), Pred(P), Values(std::move(V)), Locations(std::move(L)) {} BasicBlock *BB; BasicBlock *Pred; ValVector Values; + LocationVector Locations; }; -/// \brief This assigns and keeps a per-bb relative ordering of load/store +/// This assigns and keeps a per-bb relative ordering of load/store /// instructions in the block that directly load or store an alloca. /// /// This functionality is important because it avoids scanning large basic /// blocks multiple times when promoting many allocas in the same block. class LargeBlockInfo { - /// \brief For each instruction that we track, keep the index of the + /// For each instruction that we track, keep the index of the /// instruction. /// /// The index starts out as the number of the instruction from the start of @@ -242,7 +243,7 @@ struct PromoteMem2Reg { /// Reverse mapping of Allocas. DenseMap<AllocaInst *, unsigned> AllocaLookup; - /// \brief The PhiNodes we're adding. + /// The PhiNodes we're adding. /// /// That map is used to simplify some Phi nodes as we iterate over it, so /// it should have deterministic iterators. We could use a MapVector, but @@ -294,7 +295,7 @@ private: unsigned getNumPreds(const BasicBlock *BB) { unsigned &NP = BBNumPreds[BB]; if (NP == 0) - NP = std::distance(pred_begin(BB), pred_end(BB)) + 1; + NP = pred_size(BB) + 1; return NP - 1; } @@ -303,6 +304,7 @@ private: SmallPtrSetImpl<BasicBlock *> &LiveInBlocks); void RenamePass(BasicBlock *BB, BasicBlock *Pred, RenamePassData::ValVector &IncVals, + RenamePassData::LocationVector &IncLocs, std::vector<RenamePassData> &Worklist); bool QueuePhiNode(BasicBlock *BB, unsigned AllocaIdx, unsigned &Version); }; @@ -345,7 +347,7 @@ static void removeLifetimeIntrinsicUsers(AllocaInst *AI) { } } -/// \brief Rewrite as many loads as possible given a single store. +/// Rewrite as many loads as possible given a single store. /// /// When there is only a single store, we can use the domtree to trivially /// replace all of the dominated loads with the stored value. Do so, and return @@ -475,7 +477,7 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, // Sort the stores by their index, making it efficient to do a lookup with a // binary search. - std::sort(StoresByIndex.begin(), StoresByIndex.end(), less_first()); + llvm::sort(StoresByIndex.begin(), StoresByIndex.end(), less_first()); // Walk all of the loads from this alloca, replacing them with the nearest // store above them, if any. @@ -509,6 +511,11 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, !isKnownNonZero(ReplVal, DL, 0, AC, LI, &DT)) addAssumeNonNull(AC, LI); + // If the replacement value is the load, this must occur in unreachable + // code. 
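promoteSingleBlockAlloca keeps StoresByIndex sorted so that each load's reaching store is the last entry strictly before the load's position, found with one binary search. A standalone sketch of that lookup (plain C++; (index, value) pairs stand in for (instruction index, stored value)):

  #include <algorithm>
  #include <iterator>
  #include <utility>
  #include <vector>

  int valueAtLoad(const std::vector<std::pair<unsigned, int>> &StoresByIndex,
                  unsigned LoadIdx, int Undef) {
    auto It = std::lower_bound(
        StoresByIndex.begin(), StoresByIndex.end(), std::make_pair(LoadIdx, 0),
        [](const auto &A, const auto &B) { return A.first < B.first; });
    if (It == StoresByIndex.begin())
      return Undef;                 // a load before any store reads undef
    return std::prev(It)->second;   // nearest store above the load
  }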
+ if (ReplVal == LI) + ReplVal = UndefValue::get(LI->getType()); + LI->replaceAllUsesWith(ReplVal); } @@ -631,10 +638,10 @@ void PromoteMem2Reg::run() { SmallVector<BasicBlock *, 32> PHIBlocks; IDF.calculate(PHIBlocks); if (PHIBlocks.size() > 1) - std::sort(PHIBlocks.begin(), PHIBlocks.end(), - [this](BasicBlock *A, BasicBlock *B) { - return BBNumbers.lookup(A) < BBNumbers.lookup(B); - }); + llvm::sort(PHIBlocks.begin(), PHIBlocks.end(), + [this](BasicBlock *A, BasicBlock *B) { + return BBNumbers.lookup(A) < BBNumbers.lookup(B); + }); unsigned CurrentVersion = 0; for (BasicBlock *BB : PHIBlocks) @@ -653,15 +660,20 @@ void PromoteMem2Reg::run() { for (unsigned i = 0, e = Allocas.size(); i != e; ++i) Values[i] = UndefValue::get(Allocas[i]->getAllocatedType()); + // When handling debug info, treat all incoming values as if they have unknown + // locations until proven otherwise. + RenamePassData::LocationVector Locations(Allocas.size()); + // Walks all basic blocks in the function performing the SSA rename algorithm // and inserting the phi nodes we marked as necessary std::vector<RenamePassData> RenamePassWorkList; - RenamePassWorkList.emplace_back(&F.front(), nullptr, std::move(Values)); + RenamePassWorkList.emplace_back(&F.front(), nullptr, std::move(Values), + std::move(Locations)); do { RenamePassData RPD = std::move(RenamePassWorkList.back()); RenamePassWorkList.pop_back(); // RenamePass may add new worklist entries. - RenamePass(RPD.BB, RPD.Pred, RPD.Values, RenamePassWorkList); + RenamePass(RPD.BB, RPD.Pred, RPD.Values, RPD.Locations, RenamePassWorkList); } while (!RenamePassWorkList.empty()); // The renamer uses the Visited set to avoid infinite loops. Clear it now. @@ -740,7 +752,7 @@ void PromoteMem2Reg::run() { // Ok, now we know that all of the PHI nodes are missing entries for some // basic blocks. Start by sorting the incoming predecessors for efficient // access. - std::sort(Preds.begin(), Preds.end()); + llvm::sort(Preds.begin(), Preds.end()); // Now we loop through all BB's which have entries in SomePHI and remove // them from the Preds list. @@ -772,7 +784,7 @@ void PromoteMem2Reg::run() { NewPhiNodes.clear(); } -/// \brief Determine which blocks the value is live in. +/// Determine which blocks the value is live in. /// /// These are blocks which lead to uses. Knowing this allows us to avoid /// inserting PHI nodes into blocks which don't lead to uses (thus, the @@ -846,7 +858,7 @@ void PromoteMem2Reg::ComputeLiveInBlocks( } } -/// \brief Queue a phi-node to be added to a basic-block for a specific Alloca. +/// Queue a phi-node to be added to a basic-block for a specific Alloca. /// /// Returns true if there wasn't already a phi-node for that variable bool PromoteMem2Reg::QueuePhiNode(BasicBlock *BB, unsigned AllocaNo, @@ -868,13 +880,24 @@ bool PromoteMem2Reg::QueuePhiNode(BasicBlock *BB, unsigned AllocaNo, return true; } -/// \brief Recursively traverse the CFG of the function, renaming loads and +/// Update the debug location of a phi. \p ApplyMergedLoc indicates whether to +/// create a merged location incorporating \p DL, or to set \p DL directly. +static void updateForIncomingValueLocation(PHINode *PN, DebugLoc DL, + bool ApplyMergedLoc) { + if (ApplyMergedLoc) + PN->applyMergedLocation(PN->getDebugLoc(), DL); + else + PN->setDebugLoc(DL); +} + +/// Recursively traverse the CFG of the function, renaming loads and /// stores to the allocas which we are promoting. 
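updateForIncomingValueLocation above seeds a phi's debug location from its first incoming store and merges subsequent ones. A rough standalone model of that policy (plain C++17; a line number stands in for DebugLoc, and the real applyMergedLocation also preserves common scope and inlining information):

  #include <optional>

  using LineLoc = std::optional<int>; // nullopt models an unknown location

  LineLoc updateLoc(LineLoc Current, LineLoc Incoming, bool ApplyMergedLoc) {
    if (!ApplyMergedLoc)        // first incoming value: take the location
      return Incoming;
    if (Current == Incoming)    // agreeing locations merge to themselves
      return Current;
    return std::nullopt;        // disagreeing lines degrade to unknown
  }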
/// /// IncomingVals indicates what value each Alloca contains on exit from the /// predecessor block Pred. void PromoteMem2Reg::RenamePass(BasicBlock *BB, BasicBlock *Pred, RenamePassData::ValVector &IncomingVals, + RenamePassData::LocationVector &IncomingLocs, std::vector<RenamePassData> &Worklist) { NextIteration: // If we are inserting any phi nodes into this BB, they will already be in the @@ -899,6 +922,10 @@ NextIteration: do { unsigned AllocaNo = PhiToAllocaMap[APN]; + // Update the location of the phi node. + updateForIncomingValueLocation(APN, IncomingLocs[AllocaNo], + APN->getNumIncomingValues() > 0); + // Add N incoming values to the PHI node. for (unsigned i = 0; i != NumEdges; ++i) APN->addIncoming(IncomingVals[AllocaNo], Pred); @@ -960,8 +987,11 @@ NextIteration: continue; // what value were we writing? - IncomingVals[ai->second] = SI->getOperand(0); + unsigned AllocaNo = ai->second; + IncomingVals[AllocaNo] = SI->getOperand(0); + // Record debuginfo for the store before removing it. + IncomingLocs[AllocaNo] = SI->getDebugLoc(); for (DbgInfoIntrinsic *DII : AllocaDbgDeclares[ai->second]) ConvertDebugDeclareToDebugValue(DII, SI, DIB); BB->getInstList().erase(SI); @@ -984,7 +1014,7 @@ NextIteration: for (; I != E; ++I) if (VisitedSuccs.insert(*I).second) - Worklist.emplace_back(*I, Pred, IncomingVals); + Worklist.emplace_back(*I, Pred, IncomingVals, IncomingLocs); goto NextIteration; } diff --git a/lib/Transforms/Utils/SSAUpdater.cpp b/lib/Transforms/Utils/SSAUpdater.cpp index e4b20b0faa15..ca184ed7c4e3 100644 --- a/lib/Transforms/Utils/SSAUpdater.cpp +++ b/lib/Transforms/Utils/SSAUpdater.cpp @@ -147,11 +147,9 @@ Value *SSAUpdater::GetValueInMiddleOfBlock(BasicBlock *BB) { if (isa<PHINode>(BB->begin())) { SmallDenseMap<BasicBlock *, Value *, 8> ValueMapping(PredValues.begin(), PredValues.end()); - PHINode *SomePHI; - for (BasicBlock::iterator It = BB->begin(); - (SomePHI = dyn_cast<PHINode>(It)); ++It) { - if (IsEquivalentPHI(SomePHI, ValueMapping)) - return SomePHI; + for (PHINode &SomePHI : BB->phis()) { + if (IsEquivalentPHI(&SomePHI, ValueMapping)) + return &SomePHI; } } @@ -180,7 +178,7 @@ Value *SSAUpdater::GetValueInMiddleOfBlock(BasicBlock *BB) { // If the client wants to know about all new instructions, tell it. if (InsertedPHIs) InsertedPHIs->push_back(InsertedPHI); - DEBUG(dbgs() << " Inserted PHI: " << *InsertedPHI << "\n"); + LLVM_DEBUG(dbgs() << " Inserted PHI: " << *InsertedPHI << "\n"); return InsertedPHI; } diff --git a/lib/Transforms/Utils/SSAUpdaterBulk.cpp b/lib/Transforms/Utils/SSAUpdaterBulk.cpp new file mode 100644 index 000000000000..397bac2940a4 --- /dev/null +++ b/lib/Transforms/Utils/SSAUpdaterBulk.cpp @@ -0,0 +1,191 @@ +//===- SSAUpdaterBulk.cpp - Unstructured SSA Update Tool ------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the SSAUpdaterBulk class. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/SSAUpdaterBulk.h" +#include "llvm/Analysis/IteratedDominanceFrontier.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" + +using namespace llvm; + +#define DEBUG_TYPE "ssaupdaterbulk" + +/// Helper function for finding a block which should have a value for the given +/// user. For PHI-nodes this block is the corresponding predecessor, for other +/// instructions it's their parent block. +static BasicBlock *getUserBB(Use *U) { + auto *User = cast<Instruction>(U->getUser()); + + if (auto *UserPN = dyn_cast<PHINode>(User)) + return UserPN->getIncomingBlock(*U); + else + return User->getParent(); +} + +/// Add a new variable to the SSA rewriter. This needs to be called before +/// AddAvailableValue or AddUse calls. +unsigned SSAUpdaterBulk::AddVariable(StringRef Name, Type *Ty) { + unsigned Var = Rewrites.size(); + LLVM_DEBUG(dbgs() << "SSAUpdater: Var=" << Var << ": initialized with Ty = " + << *Ty << ", Name = " << Name << "\n"); + RewriteInfo RI(Name, Ty); + Rewrites.push_back(RI); + return Var; +} + +/// Indicate that a rewritten value is available in the specified block with the +/// specified value. +void SSAUpdaterBulk::AddAvailableValue(unsigned Var, BasicBlock *BB, Value *V) { + assert(Var < Rewrites.size() && "Variable not found!"); + LLVM_DEBUG(dbgs() << "SSAUpdater: Var=" << Var + << ": added new available value" << *V << " in " + << BB->getName() << "\n"); + Rewrites[Var].Defines[BB] = V; +} + +/// Record a use of the symbolic value. This use will be updated with a +/// rewritten value when RewriteAllUses is called. +void SSAUpdaterBulk::AddUse(unsigned Var, Use *U) { + assert(Var < Rewrites.size() && "Variable not found!"); + LLVM_DEBUG(dbgs() << "SSAUpdater: Var=" << Var << ": added a use" << *U->get() + << " in " << getUserBB(U)->getName() << "\n"); + Rewrites[Var].Uses.push_back(U); +} + +/// Return true if the SSAUpdater already has a value for the specified variable +/// in the specified block. +bool SSAUpdaterBulk::HasValueForBlock(unsigned Var, BasicBlock *BB) { + return (Var < Rewrites.size()) ? Rewrites[Var].Defines.count(BB) : false; +} + +// Compute value at the given block BB. We either should already know it, or we +// should be able to recursively reach it going up dominator tree. +Value *SSAUpdaterBulk::computeValueAt(BasicBlock *BB, RewriteInfo &R, + DominatorTree *DT) { + if (!R.Defines.count(BB)) { + if (DT->isReachableFromEntry(BB) && PredCache.get(BB).size()) { + BasicBlock *IDom = DT->getNode(BB)->getIDom()->getBlock(); + Value *V = computeValueAt(IDom, R, DT); + R.Defines[BB] = V; + } else + R.Defines[BB] = UndefValue::get(R.Ty); + } + return R.Defines[BB]; +} + +/// Given sets of UsingBlocks and DefBlocks, compute the set of LiveInBlocks. +/// This is basically a subgraph limited by DefBlocks and UsingBlocks. +static void +ComputeLiveInBlocks(const SmallPtrSetImpl<BasicBlock *> &UsingBlocks, + const SmallPtrSetImpl<BasicBlock *> &DefBlocks, + SmallPtrSetImpl<BasicBlock *> &LiveInBlocks, + PredIteratorCache &PredCache) { + // To determine liveness, we must iterate through the predecessors of blocks + // where the def is live. Blocks are added to the worklist if we need to + // check their predecessors. Start with all the using blocks. 
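The entry points above compose into a small driver: declare a variable, register its reaching definitions and its uses, then rewrite everything at once. A hypothetical usage sketch against the API defined in this file; the blocks, values, use list and dominator tree are assumed to come from the caller:

  #include "llvm/ADT/ArrayRef.h"
  #include "llvm/ADT/SmallVector.h"
  #include "llvm/IR/Dominators.h"
  #include "llvm/Transforms/Utils/SSAUpdaterBulk.h"
  using namespace llvm;

  void rewireVariable(BasicBlock *BB0, Value *V0, BasicBlock *BB1, Value *V1,
                      ArrayRef<Use *> UsesOfX, DominatorTree &DT) {
    SSAUpdaterBulk Updater;
    unsigned X = Updater.AddVariable("x", V0->getType());
    Updater.AddAvailableValue(X, BB0, V0); // definition reaching out of BB0
    Updater.AddAvailableValue(X, BB1, V1); // another definition in BB1
    for (Use *U : UsesOfX)
      Updater.AddUse(X, U);                // uses to be rewired
    SmallVector<PHINode *, 8> NewPHIs;
    Updater.RewriteAllUses(&DT, &NewPHIs); // phis go in at the IDF
  }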
+ SmallVector<BasicBlock *, 64> LiveInBlockWorklist(UsingBlocks.begin(), + UsingBlocks.end()); + + // Now that we have a set of blocks where the phi is live-in, recursively add + // their predecessors until we find the full region where the value is live. + while (!LiveInBlockWorklist.empty()) { + BasicBlock *BB = LiveInBlockWorklist.pop_back_val(); + + // The block really is live in here, insert it into the set. If already in + // the set, then it has already been processed. + if (!LiveInBlocks.insert(BB).second) + continue; + + // Since the value is live into BB, it is either defined in a predecessor or + // live into it too. Add the preds to the worklist unless they are a + // defining block. + for (BasicBlock *P : PredCache.get(BB)) { + // The value is not live into a predecessor if it defines the value. + if (DefBlocks.count(P)) + continue; + + // Otherwise it is, add to the worklist. + LiveInBlockWorklist.push_back(P); + } + } +} + +/// Perform all the necessary updates, including new PHI-nodes insertion and the +/// requested uses update. +void SSAUpdaterBulk::RewriteAllUses(DominatorTree *DT, + SmallVectorImpl<PHINode *> *InsertedPHIs) { + for (auto &R : Rewrites) { + // Compute locations for new phi-nodes. + // For that we need to initialize DefBlocks from definitions in R.Defines, + // UsingBlocks from uses in R.Uses, then compute LiveInBlocks, and then use + // this set for computing iterated dominance frontier (IDF). + // The IDF blocks are the blocks where we need to insert new phi-nodes. + ForwardIDFCalculator IDF(*DT); + LLVM_DEBUG(dbgs() << "SSAUpdater: rewriting " << R.Uses.size() + << " use(s)\n"); + + SmallPtrSet<BasicBlock *, 2> DefBlocks; + for (auto &Def : R.Defines) + DefBlocks.insert(Def.first); + IDF.setDefiningBlocks(DefBlocks); + + SmallPtrSet<BasicBlock *, 2> UsingBlocks; + for (Use *U : R.Uses) + UsingBlocks.insert(getUserBB(U)); + + SmallVector<BasicBlock *, 32> IDFBlocks; + SmallPtrSet<BasicBlock *, 32> LiveInBlocks; + ComputeLiveInBlocks(UsingBlocks, DefBlocks, LiveInBlocks, PredCache); + IDF.resetLiveInBlocks(); + IDF.setLiveInBlocks(LiveInBlocks); + IDF.calculate(IDFBlocks); + + // We've computed IDF, now insert new phi-nodes there. + SmallVector<PHINode *, 4> InsertedPHIsForVar; + for (auto *FrontierBB : IDFBlocks) { + IRBuilder<> B(FrontierBB, FrontierBB->begin()); + PHINode *PN = B.CreatePHI(R.Ty, 0, R.Name); + R.Defines[FrontierBB] = PN; + InsertedPHIsForVar.push_back(PN); + if (InsertedPHIs) + InsertedPHIs->push_back(PN); + } + + // Fill in arguments of the inserted PHIs. + for (auto *PN : InsertedPHIsForVar) { + BasicBlock *PBB = PN->getParent(); + for (BasicBlock *Pred : PredCache.get(PBB)) + PN->addIncoming(computeValueAt(Pred, R, DT), Pred); + } + + // Rewrite actual uses with the inserted definitions. + SmallPtrSet<Use *, 4> ProcessedUses; + for (Use *U : R.Uses) { + if (!ProcessedUses.insert(U).second) + continue; + Value *V = computeValueAt(getUserBB(U), R, DT); + Value *OldVal = U->get(); + assert(OldVal && "Invalid use!"); + // Notify the users of the existing value that it is being replaced.
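ComputeLiveInBlocks above is a backwards reachability walk: start from the using blocks and keep pulling in predecessors until a defining block cuts the search off. A standalone sketch (plain C++; integer block ids and an explicit predecessor map are illustrative):

  #include <map>
  #include <set>
  #include <vector>

  std::set<int> liveInBlocks(const std::set<int> &UsingBlocks,
                             const std::set<int> &DefBlocks,
                             const std::map<int, std::vector<int>> &Preds) {
    std::set<int> LiveIn;
    std::vector<int> Worklist(UsingBlocks.begin(), UsingBlocks.end());
    while (!Worklist.empty()) {
      int BB = Worklist.back();
      Worklist.pop_back();
      if (!LiveIn.insert(BB).second)
        continue;                    // already processed
      auto It = Preds.find(BB);
      if (It == Preds.end())
        continue;                    // entry block: no predecessors
      for (int P : It->second)
        if (!DefBlocks.count(P))     // a definition kills liveness
          Worklist.push_back(P);
    }
    return LiveIn;
  }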
+ if (OldVal != V && OldVal->hasValueHandle()) + ValueHandleBase::ValueIsRAUWd(OldVal, V); + LLVM_DEBUG(dbgs() << "SSAUpdater: replacing " << *OldVal << " with " << *V + << "\n"); + U->set(V); + } + } +} diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp index e7358dbcb624..c87b5c16ffce 100644 --- a/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -19,7 +19,6 @@ #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" @@ -28,6 +27,7 @@ #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -66,7 +66,6 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <cassert> @@ -283,12 +282,8 @@ isProfitableToFoldUnconditional(BranchInst *SI1, BranchInst *SI2, /// of Succ. static void AddPredecessorToBlock(BasicBlock *Succ, BasicBlock *NewPred, BasicBlock *ExistPred) { - if (!isa<PHINode>(Succ->begin())) - return; // Quick exit if nothing to do - - PHINode *PN; - for (BasicBlock::iterator I = Succ->begin(); (PN = dyn_cast<PHINode>(I)); ++I) - PN->addIncoming(PN->getIncomingValueForBlock(ExistPred), NewPred); + for (PHINode &PN : Succ->phis()) + PN.addIncoming(PN.getIncomingValueForBlock(ExistPred), NewPred); } /// Compute an abstract "cost" of speculating the given instruction, @@ -692,9 +687,7 @@ Value *SimplifyCFGOpt::isValueEqualityComparison(TerminatorInst *TI) { if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { // Do not permit merging of large switch instructions into their // predecessors unless there is only one predecessor. - if (SI->getNumSuccessors() * std::distance(pred_begin(SI->getParent()), - pred_end(SI->getParent())) <= - 128) + if (SI->getNumSuccessors() * pred_size(SI->getParent()) <= 128) CV = SI->getCondition(); } else if (BranchInst *BI = dyn_cast<BranchInst>(TI)) if (BI->isConditional() && BI->getCondition()->hasOneUse()) @@ -851,9 +844,9 @@ bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor( // Remove PHI node entries for the dead edge. ThisCases[0].Dest->removePredecessor(TI->getParent()); - DEBUG(dbgs() << "Threading pred instr: " << *Pred->getTerminator() - << "Through successor TI: " << *TI << "Leaving: " << *NI - << "\n"); + LLVM_DEBUG(dbgs() << "Threading pred instr: " << *Pred->getTerminator() + << "Through successor TI: " << *TI << "Leaving: " << *NI + << "\n"); EraseTerminatorInstAndDCECond(TI); return true; @@ -865,8 +858,8 @@ bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor( for (unsigned i = 0, e = PredCases.size(); i != e; ++i) DeadCases.insert(PredCases[i].Value); - DEBUG(dbgs() << "Threading pred instr: " << *Pred->getTerminator() - << "Through successor TI: " << *TI); + LLVM_DEBUG(dbgs() << "Threading pred instr: " << *Pred->getTerminator() + << "Through successor TI: " << *TI); // Collect branch weights into a vector. 
SmallVector<uint32_t, 8> Weights; @@ -892,7 +885,7 @@ bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor( if (HasWeight && Weights.size() >= 2) setBranchWeights(SI, Weights); - DEBUG(dbgs() << "Leaving: " << *TI << "\n"); + LLVM_DEBUG(dbgs() << "Leaving: " << *TI << "\n"); return true; } @@ -933,9 +926,9 @@ bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor( Instruction *NI = Builder.CreateBr(TheRealDest); (void)NI; - DEBUG(dbgs() << "Threading pred instr: " << *Pred->getTerminator() - << "Through successor TI: " << *TI << "Leaving: " << *NI - << "\n"); + LLVM_DEBUG(dbgs() << "Threading pred instr: " << *Pred->getTerminator() + << "Through successor TI: " << *TI << "Leaving: " << *NI + << "\n"); EraseTerminatorInstAndDCECond(TI); return true; @@ -1228,11 +1221,9 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, static bool isSafeToHoistInvoke(BasicBlock *BB1, BasicBlock *BB2, Instruction *I1, Instruction *I2) { for (BasicBlock *Succ : successors(BB1)) { - PHINode *PN; - for (BasicBlock::iterator BBI = Succ->begin(); - (PN = dyn_cast<PHINode>(BBI)); ++BBI) { - Value *BB1V = PN->getIncomingValueForBlock(BB1); - Value *BB2V = PN->getIncomingValueForBlock(BB2); + for (const PHINode &PN : Succ->phis()) { + Value *BB1V = PN.getIncomingValueForBlock(BB1); + Value *BB2V = PN.getIncomingValueForBlock(BB2); if (BB1V != BB2V && (BB1V == I1 || BB2V == I2)) { return false; } @@ -1282,34 +1273,58 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, if (isa<TerminatorInst>(I1)) goto HoistTerminator; + // If we're going to hoist a call, make sure that the two instructions we're + // commoning/hoisting are both marked with musttail, or neither of them is + // marked as such. Otherwise, we might end up in a situation where we hoist + // from a block where the terminator is a `ret` to a block where the terminator + // is a `br`, and `musttail` calls expect to be followed by a return. + auto *C1 = dyn_cast<CallInst>(I1); + auto *C2 = dyn_cast<CallInst>(I2); + if (C1 && C2) + if (C1->isMustTailCall() != C2->isMustTailCall()) + return Changed; + if (!TTI.isProfitableToHoist(I1) || !TTI.isProfitableToHoist(I2)) return Changed; - // For a normal instruction, we just move one to right before the branch, - // then replace all uses of the other with the first. Finally, we remove - // the now redundant second instruction. - BIParent->getInstList().splice(BI->getIterator(), BB1->getInstList(), I1); - if (!I2->use_empty()) - I2->replaceAllUsesWith(I1); - I1->andIRFlags(I2); - unsigned KnownIDs[] = {LLVMContext::MD_tbaa, - LLVMContext::MD_range, - LLVMContext::MD_fpmath, - LLVMContext::MD_invariant_load, - LLVMContext::MD_nonnull, - LLVMContext::MD_invariant_group, - LLVMContext::MD_align, - LLVMContext::MD_dereferenceable, - LLVMContext::MD_dereferenceable_or_null, - LLVMContext::MD_mem_parallel_loop_access}; - combineMetadata(I1, I2, KnownIDs); - - // I1 and I2 are being combined into a single instruction. Its debug - // location is the merged locations of the original instructions. - I1->applyMergedLocation(I1->getDebugLoc(), I2->getDebugLoc()); - - I2->eraseFromParent(); - Changed = true; + if (isa<DbgInfoIntrinsic>(I1) || isa<DbgInfoIntrinsic>(I2)) { + assert (isa<DbgInfoIntrinsic>(I1) && isa<DbgInfoIntrinsic>(I2)); + // The debug location is an integral part of a debug info intrinsic + // and can't be separated from it or replaced. Instead of attempting + // to merge locations, simply hoist both copies of the intrinsic. 
+ BIParent->getInstList().splice(BI->getIterator(), + BB1->getInstList(), I1); + BIParent->getInstList().splice(BI->getIterator(), + BB2->getInstList(), I2); + Changed = true; + } else { + // For a normal instruction, we just move one to right before the branch, + // then replace all uses of the other with the first. Finally, we remove + // the now redundant second instruction. + BIParent->getInstList().splice(BI->getIterator(), + BB1->getInstList(), I1); + if (!I2->use_empty()) + I2->replaceAllUsesWith(I1); + I1->andIRFlags(I2); + unsigned KnownIDs[] = {LLVMContext::MD_tbaa, + LLVMContext::MD_range, + LLVMContext::MD_fpmath, + LLVMContext::MD_invariant_load, + LLVMContext::MD_nonnull, + LLVMContext::MD_invariant_group, + LLVMContext::MD_align, + LLVMContext::MD_dereferenceable, + LLVMContext::MD_dereferenceable_or_null, + LLVMContext::MD_mem_parallel_loop_access}; + combineMetadata(I1, I2, KnownIDs); + + // I1 and I2 are being combined into a single instruction. Its debug + // location is the merged locations of the original instructions. + I1->applyMergedLocation(I1->getDebugLoc(), I2->getDebugLoc()); + + I2->eraseFromParent(); + Changed = true; + } I1 = &*BB1_Itr++; I2 = &*BB2_Itr++; @@ -1332,18 +1347,16 @@ HoistTerminator: return Changed; for (BasicBlock *Succ : successors(BB1)) { - PHINode *PN; - for (BasicBlock::iterator BBI = Succ->begin(); - (PN = dyn_cast<PHINode>(BBI)); ++BBI) { - Value *BB1V = PN->getIncomingValueForBlock(BB1); - Value *BB2V = PN->getIncomingValueForBlock(BB2); + for (PHINode &PN : Succ->phis()) { + Value *BB1V = PN.getIncomingValueForBlock(BB1); + Value *BB2V = PN.getIncomingValueForBlock(BB2); if (BB1V == BB2V) continue; // Check for passingValueIsAlwaysUndefined here because we would rather // eliminate undefined control flow than convert it to a select. - if (passingValueIsAlwaysUndefined(BB1V, PN) || - passingValueIsAlwaysUndefined(BB2V, PN)) + if (passingValueIsAlwaysUndefined(BB1V, &PN) || + passingValueIsAlwaysUndefined(BB2V, &PN)) return Changed; if (isa<ConstantExpr>(BB1V) && !isSafeToSpeculativelyExecute(BB1V)) @@ -1369,11 +1382,9 @@ HoistTerminator: // nodes, so we insert select instruction to compute the final result. std::map<std::pair<Value *, Value *>, SelectInst *> InsertedSelects; for (BasicBlock *Succ : successors(BB1)) { - PHINode *PN; - for (BasicBlock::iterator BBI = Succ->begin(); - (PN = dyn_cast<PHINode>(BBI)); ++BBI) { - Value *BB1V = PN->getIncomingValueForBlock(BB1); - Value *BB2V = PN->getIncomingValueForBlock(BB2); + for (PHINode &PN : Succ->phis()) { + Value *BB1V = PN.getIncomingValueForBlock(BB1); + Value *BB2V = PN.getIncomingValueForBlock(BB2); if (BB1V == BB2V) continue; @@ -1386,9 +1397,9 @@ HoistTerminator: BB1V->getName() + "."
+ BB2V->getName(), BI)); // Make the PHI node use the select for all incoming values for BB1/BB2 - for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (PN->getIncomingBlock(i) == BB1 || PN->getIncomingBlock(i) == BB2) - PN->setIncomingValue(i, SI); + for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) + if (PN.getIncomingBlock(i) == BB1 || PN.getIncomingBlock(i) == BB2) + PN.setIncomingValue(i, SI); } } @@ -1727,7 +1738,8 @@ static bool SinkCommonCodeFromPredecessors(BasicBlock *BB) { LockstepReverseIterator LRI(UnconditionalPreds); while (LRI.isValid() && canSinkInstructions(*LRI, PHIOperands)) { - DEBUG(dbgs() << "SINK: instruction can be sunk: " << *(*LRI)[0] << "\n"); + LLVM_DEBUG(dbgs() << "SINK: instruction can be sunk: " << *(*LRI)[0] + << "\n"); InstructionsToSink.insert((*LRI).begin(), (*LRI).end()); ++ScanIdx; --LRI; @@ -1739,7 +1751,7 @@ static bool SinkCommonCodeFromPredecessors(BasicBlock *BB) { for (auto *V : PHIOperands[I]) if (InstructionsToSink.count(V) == 0) ++NumPHIdValues; - DEBUG(dbgs() << "SINK: #phid values: " << NumPHIdValues << "\n"); + LLVM_DEBUG(dbgs() << "SINK: #phid values: " << NumPHIdValues << "\n"); unsigned NumPHIInsts = NumPHIdValues / UnconditionalPreds.size(); if ((NumPHIdValues % UnconditionalPreds.size()) != 0) NumPHIInsts++; @@ -1767,7 +1779,7 @@ static bool SinkCommonCodeFromPredecessors(BasicBlock *BB) { if (!Profitable) return false; - DEBUG(dbgs() << "SINK: Splitting edge\n"); + LLVM_DEBUG(dbgs() << "SINK: Splitting edge\n"); // We have a conditional edge and we're going to sink some instructions. // Insert a new block postdominating all blocks we're going to sink from. if (!SplitBlockPredecessors(BB, UnconditionalPreds, ".sink.split")) @@ -1789,16 +1801,17 @@ static bool SinkCommonCodeFromPredecessors(BasicBlock *BB) { // and never actually sink it which means we produce more PHIs than intended. // This is unlikely in practice though. for (unsigned SinkIdx = 0; SinkIdx != ScanIdx; ++SinkIdx) { - DEBUG(dbgs() << "SINK: Sink: " - << *UnconditionalPreds[0]->getTerminator()->getPrevNode() - << "\n"); + LLVM_DEBUG(dbgs() << "SINK: Sink: " + << *UnconditionalPreds[0]->getTerminator()->getPrevNode() + << "\n"); // Because we've sunk every instruction in turn, the current instruction to // sink is always at index 0. LRI.reset(); if (!ProfitableToSinkInstruction(LRI)) { // Too many PHIs would be created. - DEBUG(dbgs() << "SINK: stopping here, too many PHIs would be created!\n"); + LLVM_DEBUG( + dbgs() << "SINK: stopping here, too many PHIs would be created!\n"); break; } @@ -1810,7 +1823,7 @@ static bool SinkCommonCodeFromPredecessors(BasicBlock *BB) { return Changed; } -/// \brief Determine if we can hoist sink a sole store instruction out of a +/// Determine if we can hoist sink a sole store instruction out of a /// conditional block. /// /// We are looking for code like the following: @@ -1850,12 +1863,9 @@ static Value *isSafeToSpeculateStore(Instruction *I, BasicBlock *BrBB, // Look for a store to the same pointer in BrBB. unsigned MaxNumInstToLookAt = 9; - for (Instruction &CurI : reverse(*BrBB)) { + for (Instruction &CurI : reverse(BrBB->instructionsWithoutDebug())) { if (!MaxNumInstToLookAt) break; - // Skip debug info. - if (isa<DbgInfoIntrinsic>(CurI)) - continue; --MaxNumInstToLookAt; // Could be calling an instruction that affects memory like free(). 
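For orientation, a C-level sketch of the pattern isSafeToSpeculateStore looks for (function and variable names hypothetical): a conditional store dominated by another store to the same pointer can be rewritten as one unconditional store of a select.

void before(int c, int *p, int v) {
  *p = 0;          /* the dominating store found in BrBB         */
  if (c)
    *p = v;        /* the conditional store that gets speculated */
}

void after(int c, int *p, int v) {
  *p = c ? v : 0;  /* one unconditional store of a select        */
}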
@@ -1874,7 +1884,7 @@ static Value *isSafeToSpeculateStore(Instruction *I, BasicBlock *BrBB, return nullptr; } -/// \brief Speculate a conditional basic block flattening the CFG. +/// Speculate a conditional basic block flattening the CFG. /// /// Note that this is a very risky transform currently. Speculating /// instructions like this is most often not desirable. Instead, there is an MI @@ -1999,10 +2009,9 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // Check that the PHI nodes can be converted to selects. bool HaveRewritablePHIs = false; - for (BasicBlock::iterator I = EndBB->begin(); - PHINode *PN = dyn_cast<PHINode>(I); ++I) { - Value *OrigV = PN->getIncomingValueForBlock(BB); - Value *ThenV = PN->getIncomingValueForBlock(ThenBB); + for (PHINode &PN : EndBB->phis()) { + Value *OrigV = PN.getIncomingValueForBlock(BB); + Value *ThenV = PN.getIncomingValueForBlock(ThenBB); // FIXME: Try to remove some of the duplication with HoistThenElseCodeToIf. // Skip PHIs which are trivial. @@ -2010,8 +2019,8 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, continue; // Don't convert to selects if we could remove undefined behavior instead. - if (passingValueIsAlwaysUndefined(OrigV, PN) || - passingValueIsAlwaysUndefined(ThenV, PN)) + if (passingValueIsAlwaysUndefined(OrigV, &PN) || + passingValueIsAlwaysUndefined(ThenV, &PN)) return false; HaveRewritablePHIs = true; @@ -2045,7 +2054,7 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, return false; // If we get here, we can hoist the instruction and if-convert. - DEBUG(dbgs() << "SPECULATIVELY EXECUTING BB" << *ThenBB << "\n";); + LLVM_DEBUG(dbgs() << "SPECULATIVELY EXECUTING BB" << *ThenBB << "\n";); // Insert a select of the value of the speculated store. if (SpeculatedStoreValue) { @@ -2072,12 +2081,11 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // Insert selects and rewrite the PHI operands. IRBuilder<NoFolder> Builder(BI); - for (BasicBlock::iterator I = EndBB->begin(); - PHINode *PN = dyn_cast<PHINode>(I); ++I) { - unsigned OrigI = PN->getBasicBlockIndex(BB); - unsigned ThenI = PN->getBasicBlockIndex(ThenBB); - Value *OrigV = PN->getIncomingValue(OrigI); - Value *ThenV = PN->getIncomingValue(ThenI); + for (PHINode &PN : EndBB->phis()) { + unsigned OrigI = PN.getBasicBlockIndex(BB); + unsigned ThenI = PN.getBasicBlockIndex(ThenBB); + Value *OrigV = PN.getIncomingValue(OrigI); + Value *ThenV = PN.getIncomingValue(ThenI); // Skip PHIs which are trivial. if (OrigV == ThenV) @@ -2091,8 +2099,8 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, std::swap(TrueV, FalseV); Value *V = Builder.CreateSelect( BrCond, TrueV, FalseV, "spec.select", BI); - PN->setIncomingValue(OrigI, V); - PN->setIncomingValue(ThenI, V); + PN.setIncomingValue(OrigI, V); + PN.setIncomingValue(ThenI, V); } // Remove speculated dbg intrinsics. @@ -2107,19 +2115,16 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, /// Return true if we can thread a branch across this block. static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB) { - BranchInst *BI = cast<BranchInst>(BB->getTerminator()); unsigned Size = 0; - for (BasicBlock::iterator BBI = BB->begin(); &*BBI != BI; ++BBI) { - if (isa<DbgInfoIntrinsic>(BBI)) - continue; + for (Instruction &I : BB->instructionsWithoutDebug()) { if (Size > 10) return false; // Don't clone large BB's. 
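A minimal C-level picture of the if-conversion SpeculativelyExecuteBB performs, assuming the 'then' block is cheap and side-effect free (names hypothetical):

int before(int c, int x) {
  int r = x;
  if (c)
    r = x + 1;       /* the speculated block feeds a phi at the join */
  return r;
}

int after(int c, int x) {
  int t = x + 1;     /* executed unconditionally */
  return c ? t : x;  /* the phi becomes a select */
}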
++Size; // We can only support instructions that do not define values that are // live outside of the current basic block. - for (User *U : BBI->users()) { + for (User *U : I.users()) { Instruction *UI = cast<Instruction>(U); if (UI->getParent() != BB || isa<PHINode>(UI)) return false; @@ -2261,6 +2266,10 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, // dependence information for this check, but simplifycfg can't keep it up // to date, and this catches most of the cases we care about anyway. BasicBlock *BB = PN->getParent(); + const Function *Fn = BB->getParent(); + if (Fn && Fn->hasFnAttribute(Attribute::OptForFuzzing)) + return false; + BasicBlock *IfTrue, *IfFalse; Value *IfCond = GetIfCondition(BB, IfTrue, IfFalse); if (!IfCond || @@ -2351,8 +2360,9 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, } } - DEBUG(dbgs() << "FOUND IF CONDITION! " << *IfCond << " T: " - << IfTrue->getName() << " F: " << IfFalse->getName() << "\n"); + LLVM_DEBUG(dbgs() << "FOUND IF CONDITION! " << *IfCond + << " T: " << IfTrue->getName() + << " F: " << IfFalse->getName() << "\n"); // If we can still promote the PHI nodes after this gauntlet of tests, // do all of the PHI's now. @@ -2476,9 +2486,9 @@ static bool SimplifyCondBranchToTwoReturns(BranchInst *BI, (void)RI; - DEBUG(dbgs() << "\nCHANGING BRANCH TO TWO RETURNS INTO SELECT:" - << "\n " << *BI << "NewRet = " << *RI - << "TRUEBLOCK: " << *TrueSucc << "FALSEBLOCK: " << *FalseSucc); + LLVM_DEBUG(dbgs() << "\nCHANGING BRANCH TO TWO RETURNS INTO SELECT:" + << "\n " << *BI << "NewRet = " << *RI << "TRUEBLOCK: " + << *TrueSucc << "FALSEBLOCK: " << *FalseSucc); EraseTerminatorInstAndDCECond(BI); @@ -2487,7 +2497,7 @@ static bool SimplifyCondBranchToTwoReturns(BranchInst *BI, /// Return true if the given instruction is available /// in its predecessor block. If yes, the instruction will be removed. -static bool checkCSEInPredecessor(Instruction *Inst, BasicBlock *PB) { +static bool tryCSEWithPredecessor(Instruction *Inst, BasicBlock *PB) { if (!isa<BinaryOperator>(Inst) && !isa<CmpInst>(Inst)) return false; for (Instruction &I : *PB) { @@ -2544,14 +2554,16 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { if (PBI->isConditional() && (BI->getSuccessor(0) == PBI->getSuccessor(0) || BI->getSuccessor(0) == PBI->getSuccessor(1))) { - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { + for (auto I = BB->instructionsWithoutDebug().begin(), + E = BB->instructionsWithoutDebug().end(); + I != E;) { Instruction *Curr = &*I++; if (isa<CmpInst>(Curr)) { Cond = Curr; break; } // Quit if we can't remove this instruction. - if (!checkCSEInPredecessor(Curr, PB)) + if (!tryCSEWithPredecessor(Curr, PB)) return false; } } @@ -2651,7 +2663,7 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { continue; } - DEBUG(dbgs() << "FOLDING BRANCH TO COMMON DEST:\n" << *PBI << *BB); + LLVM_DEBUG(dbgs() << "FOLDING BRANCH TO COMMON DEST:\n" << *PBI << *BB); IRBuilder<> Builder(PBI); // If we need to invert the condition in the pred block to match, do so now. @@ -2861,7 +2873,7 @@ static Value *ensureValueAvailableInSuccessor(Value *V, BasicBlock *BB, if (!AlternativeV) break; - assert(std::distance(pred_begin(Succ), pred_end(Succ)) == 2); + assert(pred_size(Succ) == 2); auto PredI = pred_begin(Succ); BasicBlock *OtherPredBB = *PredI == BB ? 
*++PredI : *PredI; if (PHI->getIncomingValueForBlock(OtherPredBB) == AlternativeV) @@ -2904,14 +2916,13 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB, // instructions inside are all cheap (arithmetic/GEPs), it's worthwhile to // thread this store. unsigned N = 0; - for (auto &I : *BB) { + for (auto &I : BB->instructionsWithoutDebug()) { // Cheap instructions viable for folding. if (isa<BinaryOperator>(I) || isa<GetElementPtrInst>(I) || isa<StoreInst>(I)) ++N; // Free instructions. - else if (isa<TerminatorInst>(I) || isa<DbgInfoIntrinsic>(I) || - IsaBitcastOfPointerType(I)) + else if (isa<TerminatorInst>(I) || IsaBitcastOfPointerType(I)) continue; else return false; @@ -2966,6 +2977,21 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB, if (&*I != PStore && I->mayReadOrWriteMemory()) return false; + // If PostBB has more than two predecessors, we need to split it so we can + // sink the store. + if (std::next(pred_begin(PostBB), 2) != pred_end(PostBB)) { + // We know that QFB's only successor is PostBB. And QFB has a single + // predecessor. If QTB exists, then its only successor is also PostBB. + // If QTB does not exist, then QFB's only predecessor has a conditional + // branch to QFB and PostBB. + BasicBlock *TruePred = QTB ? QTB : QFB->getSinglePredecessor(); + BasicBlock *NewBB = SplitBlockPredecessors(PostBB, { QFB, TruePred}, + "condstore.split"); + if (!NewBB) + return false; + PostBB = NewBB; + } + // OK, we're going to sink the stores to PostBB. The store has to be // conditional though, so first create the predicate. Value *PCond = cast<BranchInst>(PFB->getSinglePredecessor()->getTerminator()) @@ -3101,7 +3127,7 @@ static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI, if ((PTB && !HasOnePredAndOneSucc(PTB, PBI->getParent(), QBI->getParent())) || (QTB && !HasOnePredAndOneSucc(QTB, QBI->getParent(), PostBB))) return false; - if (!PostBB->hasNUses(2) || !QBI->getParent()->hasNUses(2)) + if (!QBI->getParent()->hasNUses(2)) return false; // OK, this is a sequence of two diamonds or triangles. @@ -3201,11 +3227,9 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, // If this is a conditional branch in an empty block, and if any // predecessors are a conditional branch to one of our destinations, // fold the conditions into logical ops and one cond br. - BasicBlock::iterator BBI = BB->begin(); + // Ignore dbg intrinsics. - while (isa<DbgInfoIntrinsic>(BBI)) - ++BBI; - if (&*BBI != BI) + if (&*BB->instructionsWithoutDebug().begin() != BI) return false; int PBIOp, BIOp; @@ -3262,8 +3286,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, // Finally, if everything is ok, fold the branches to logical ops. BasicBlock *OtherDest = BI->getSuccessor(BIOp ^ 1); - DEBUG(dbgs() << "FOLDING BRs:" << *PBI->getParent() - << "AND: " << *BI->getParent()); + LLVM_DEBUG(dbgs() << "FOLDING BRs:" << *PBI->getParent() + << "AND: " << *BI->getParent()); // If OtherDest *is* BB, then BB is a basic block with a single conditional // branch in it, where one edge (OtherDest) goes back to itself but the other @@ -3281,7 +3305,7 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, OtherDest = InfLoopBlock; } - DEBUG(dbgs() << *PBI->getParent()->getParent()); + LLVM_DEBUG(dbgs() << *PBI->getParent()->getParent()); // BI may have other predecessors. Because of this, we leave // it alone, but modify PBI. 
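As a rough source-level sketch of what "FOLDING BRs" does (hypothetical code): two conditional branches into a common destination collapse into a single branch on a non-short-circuiting logical op, so both conditions are evaluated.

int before(int a, int b) {
  if (a) return 1;
  if (b) return 1;
  return 0;
}

int after(int a, int b) {
  if (a | b)   /* bitwise or: both conditions evaluated */
    return 1;
  return 0;
}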
@@ -3335,17 +3359,15 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, // it. If it has PHIs though, the PHIs may have different // entries for BB and PBI's BB. If so, insert a select to make // them agree. - PHINode *PN; - for (BasicBlock::iterator II = CommonDest->begin(); - (PN = dyn_cast<PHINode>(II)); ++II) { - Value *BIV = PN->getIncomingValueForBlock(BB); - unsigned PBBIdx = PN->getBasicBlockIndex(PBI->getParent()); - Value *PBIV = PN->getIncomingValue(PBBIdx); + for (PHINode &PN : CommonDest->phis()) { + Value *BIV = PN.getIncomingValueForBlock(BB); + unsigned PBBIdx = PN.getBasicBlockIndex(PBI->getParent()); + Value *PBIV = PN.getIncomingValue(PBBIdx); if (BIV != PBIV) { // Insert a select in PBI to pick the right value. SelectInst *NV = cast<SelectInst>( Builder.CreateSelect(PBICond, PBIV, BIV, PBIV->getName() + ".mux")); - PN->setIncomingValue(PBBIdx, NV); + PN.setIncomingValue(PBBIdx, NV); // Although the select has the same condition as PBI, the original branch // weights for PBI do not apply to the new select because the select's // 'logical' edges are incoming edges of the phi that is eliminated, not @@ -3367,8 +3389,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, } } - DEBUG(dbgs() << "INTO: " << *PBI->getParent()); - DEBUG(dbgs() << *PBI->getParent()->getParent()); + LLVM_DEBUG(dbgs() << "INTO: " << *PBI->getParent()); + LLVM_DEBUG(dbgs() << *PBI->getParent()->getParent()); // This basic block is probably dead. We know it has at least // one fewer predecessor. @@ -3668,9 +3690,9 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, IRBuilder<> &Builder, BasicBlock *BB = BI->getParent(); - DEBUG(dbgs() << "Converting 'icmp' chain with " << Values.size() - << " cases into SWITCH. BB is:\n" - << *BB); + LLVM_DEBUG(dbgs() << "Converting 'icmp' chain with " << Values.size() + << " cases into SWITCH. BB is:\n" + << *BB); // If there are any extra values that couldn't be folded into the switch // then we evaluate them with an explicit branch first. Split the block @@ -3693,8 +3715,8 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, IRBuilder<> &Builder, // for the edge we just added. AddPredecessorToBlock(EdgeBB, BB, NewBB); - DEBUG(dbgs() << " ** 'icmp' chain unhandled condition: " << *ExtraCase - << "\nEXTRABB = " << *BB); + LLVM_DEBUG(dbgs() << " ** 'icmp' chain unhandled condition: " << *ExtraCase + << "\nEXTRABB = " << *BB); BB = NewBB; } @@ -3725,7 +3747,7 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, IRBuilder<> &Builder, // Erase the old branch instruction. 
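A source-level sketch of the 'icmp' chain conversion this function performs (the values are hypothetical):

int before(int x) {
  return x == 1 || x == 4 || x == 9;
}

int after(int x) {
  switch (x) {   /* the equality chain becomes switch cases */
  case 1:
  case 4:
  case 9:
    return 1;
  default:
    return 0;
  }
}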
EraseTerminatorInstAndDCECond(BI); - DEBUG(dbgs() << " ** 'icmp' chain result is:\n" << *BB << '\n'); + LLVM_DEBUG(dbgs() << " ** 'icmp' chain result is:\n" << *BB << '\n'); return true; } @@ -3876,6 +3898,7 @@ static bool removeEmptyCleanup(CleanupReturnInst *RI) { switch (IntrinsicID) { case Intrinsic::dbg_declare: case Intrinsic::dbg_value: + case Intrinsic::dbg_label: case Intrinsic::lifetime_end: break; default: @@ -4052,8 +4075,8 @@ bool SimplifyCFGOpt::SimplifyReturn(ReturnInst *RI, IRBuilder<> &Builder) { if (!UncondBranchPreds.empty() && DupRet) { while (!UncondBranchPreds.empty()) { BasicBlock *Pred = UncondBranchPreds.pop_back_val(); - DEBUG(dbgs() << "FOLDING: " << *BB - << "INTO UNCOND BRANCH PRED: " << *Pred); + LLVM_DEBUG(dbgs() << "FOLDING: " << *BB + << "INTO UNCOND BRANCH PRED: " << *Pred); (void)FoldReturnIntoUncondBranch(RI, BB, Pred); } @@ -4377,7 +4400,8 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, if (Known.Zero.intersects(CaseVal) || !Known.One.isSubsetOf(CaseVal) || (CaseVal.getMinSignedBits() > MaxSignificantBitsInCond)) { DeadCases.push_back(Case.getCaseValue()); - DEBUG(dbgs() << "SimplifyCFG: switch case " << CaseVal << " is dead.\n"); + LLVM_DEBUG(dbgs() << "SimplifyCFG: switch case " << CaseVal + << " is dead.\n"); } } @@ -4393,7 +4417,7 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, if (HasDefault && DeadCases.empty() && NumUnknownBits < 64 /* avoid overflow */ && SI->getNumCases() == (1ULL << NumUnknownBits)) { - DEBUG(dbgs() << "SimplifyCFG: switch default is dead.\n"); + LLVM_DEBUG(dbgs() << "SimplifyCFG: switch default is dead.\n"); BasicBlock *NewDefault = SplitBlockPredecessors(SI->getDefaultDest(), SI->getParent(), ""); SI->setDefaultDest(&*NewDefault); @@ -4451,17 +4475,16 @@ static PHINode *FindPHIForConditionForwarding(ConstantInt *CaseValue, BasicBlock *Succ = Branch->getSuccessor(0); - BasicBlock::iterator I = Succ->begin(); - while (PHINode *PHI = dyn_cast<PHINode>(I++)) { - int Idx = PHI->getBasicBlockIndex(BB); + for (PHINode &PHI : Succ->phis()) { + int Idx = PHI.getBasicBlockIndex(BB); assert(Idx >= 0 && "PHI has no entry for predecessor?"); - Value *InValue = PHI->getIncomingValue(Idx); + Value *InValue = PHI.getIncomingValue(Idx); if (InValue != CaseValue) continue; *PhiIndex = Idx; - return PHI; + return &PHI; } return nullptr; @@ -4491,19 +4514,16 @@ static bool ForwardSwitchConditionToPHI(SwitchInst *SI) { // --> // %r = phi i32 ... [ %x, %switchbb ] ... - for (Instruction &InstInCaseDest : *CaseDest) { - auto *Phi = dyn_cast<PHINode>(&InstInCaseDest); - if (!Phi) break; - + for (PHINode &Phi : CaseDest->phis()) { // This only works if there is exactly 1 incoming edge from the switch to // a phi. If there is >1, that means multiple cases of the switch map to 1 // value in the phi, and that phi value is not the switch condition. Thus, // this transform would not make sense (the phi would be invalid because // a phi can't have different incoming values from the same block). 
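The comment above can be made concrete with a small C-level sketch (names hypothetical): when the phi's incoming value equals the case value, the phi entry may use the switch condition itself, which later lets the case block fold away.

int before(int x) {
  int r;
  switch (x) {
  case 7:  r = 7;  break;  /* incoming value equals the case value */
  default: r = -1; break;
  }
  return r;
}

int after(int x) {
  int r;
  switch (x) {
  case 7:  r = x;  break;  /* forwarded: on this edge x is known to be 7 */
  default: r = -1; break;
  }
  return r;
}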
- int SwitchBBIdx = Phi->getBasicBlockIndex(SwitchBlock); - if (Phi->getIncomingValue(SwitchBBIdx) == CaseValue && - count(Phi->blocks(), SwitchBlock) == 1) { - Phi->setIncomingValue(SwitchBBIdx, SI->getCondition()); + int SwitchBBIdx = Phi.getBasicBlockIndex(SwitchBlock); + if (Phi.getIncomingValue(SwitchBBIdx) == CaseValue && + count(Phi.blocks(), SwitchBlock) == 1) { + Phi.setIncomingValue(SwitchBBIdx, SI->getCondition()); Changed = true; } } @@ -4614,24 +4634,20 @@ GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, // which we can constant-propagate the CaseVal, continue to its successor. SmallDenseMap<Value *, Constant *> ConstantPool; ConstantPool.insert(std::make_pair(SI->getCondition(), CaseVal)); - for (BasicBlock::iterator I = CaseDest->begin(), E = CaseDest->end(); I != E; - ++I) { - if (TerminatorInst *T = dyn_cast<TerminatorInst>(I)) { + for (Instruction &I :CaseDest->instructionsWithoutDebug()) { + if (TerminatorInst *T = dyn_cast<TerminatorInst>(&I)) { // If the terminator is a simple branch, continue to the next block. if (T->getNumSuccessors() != 1 || T->isExceptional()) return false; Pred = CaseDest; CaseDest = T->getSuccessor(0); - } else if (isa<DbgInfoIntrinsic>(I)) { - // Skip debug intrinsic. - continue; - } else if (Constant *C = ConstantFold(&*I, DL, ConstantPool)) { + } else if (Constant *C = ConstantFold(&I, DL, ConstantPool)) { // Instruction is side-effect free and constant. // If the instruction has uses outside this block or a phi node slot for // the block, it is not safe to bypass the instruction since it would then // no longer dominate all its uses. - for (auto &Use : I->uses()) { + for (auto &Use : I.uses()) { User *User = Use.getUser(); if (Instruction *I = dyn_cast<Instruction>(User)) if (I->getParent() == CaseDest) @@ -4642,7 +4658,7 @@ GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, return false; } - ConstantPool.insert(std::make_pair(&*I, C)); + ConstantPool.insert(std::make_pair(&I, C)); } else { break; } @@ -4656,14 +4672,13 @@ GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, return false; // Get the values for this case from phi nodes in the destination block. - BasicBlock::iterator I = (*CommonDest)->begin(); - while (PHINode *PHI = dyn_cast<PHINode>(I++)) { - int Idx = PHI->getBasicBlockIndex(Pred); + for (PHINode &PHI : (*CommonDest)->phis()) { + int Idx = PHI.getBasicBlockIndex(Pred); if (Idx == -1) continue; Constant *ConstVal = - LookupConstant(PHI->getIncomingValue(Idx), ConstantPool); + LookupConstant(PHI.getIncomingValue(Idx), ConstantPool); if (!ConstVal) return false; @@ -4671,37 +4686,38 @@ GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, if (!ValidLookupTableConstant(ConstVal, TTI)) return false; - Res.push_back(std::make_pair(PHI, ConstVal)); + Res.push_back(std::make_pair(&PHI, ConstVal)); } return Res.size() > 0; } // Helper function used to add CaseVal to the list of cases that generate -// Result. -static void MapCaseToResult(ConstantInt *CaseVal, - SwitchCaseResultVectorTy &UniqueResults, - Constant *Result) { +// Result. Returns the updated number of cases that generate this result. 
+static uintptr_t MapCaseToResult(ConstantInt *CaseVal, + SwitchCaseResultVectorTy &UniqueResults, + Constant *Result) { for (auto &I : UniqueResults) { if (I.first == Result) { I.second.push_back(CaseVal); - return; + return I.second.size(); } } UniqueResults.push_back( std::make_pair(Result, SmallVector<ConstantInt *, 4>(1, CaseVal))); + return 1; } // Helper function that initializes a map containing // results for the PHI node of the common destination block for a switch // instruction. Returns false if multiple PHI nodes have been found or if // there is not a common destination block for the switch. -static bool InitializeUniqueCases(SwitchInst *SI, PHINode *&PHI, - BasicBlock *&CommonDest, - SwitchCaseResultVectorTy &UniqueResults, - Constant *&DefaultResult, - const DataLayout &DL, - const TargetTransformInfo &TTI) { +static bool +InitializeUniqueCases(SwitchInst *SI, PHINode *&PHI, BasicBlock *&CommonDest, + SwitchCaseResultVectorTy &UniqueResults, + Constant *&DefaultResult, const DataLayout &DL, + const TargetTransformInfo &TTI, + uintptr_t MaxUniqueResults, uintptr_t MaxCasesPerResult) { for (auto &I : SI->cases()) { ConstantInt *CaseVal = I.getCaseValue(); @@ -4711,10 +4727,21 @@ static bool InitializeUniqueCases(SwitchInst *SI, PHINode *&PHI, DL, TTI)) return false; - // Only one value per case is permitted + // Only one value per case is permitted. if (Results.size() > 1) return false; - MapCaseToResult(CaseVal, UniqueResults, Results.begin()->second); + + // Add the case->result mapping to UniqueResults. + const uintptr_t NumCasesForResult = + MapCaseToResult(CaseVal, UniqueResults, Results.begin()->second); + + // Early out if there are too many cases for this result. + if (NumCasesForResult > MaxCasesPerResult) + return false; + + // Early out if there are too many unique results. + if (UniqueResults.size() > MaxUniqueResults) + return false; // Check the PHI consistency. if (!PHI) @@ -4814,7 +4841,7 @@ static bool switchToSelect(SwitchInst *SI, IRBuilder<> &Builder, SwitchCaseResultVectorTy UniqueResults; // Collect all the cases that will deliver the same value from the switch. if (!InitializeUniqueCases(SI, PHI, CommonDest, UniqueResults, DefaultResult, - DL, TTI)) + DL, TTI, 2, 1)) return false; // Selects choose between maximum two values. if (UniqueResults.size() != 2) @@ -5392,8 +5419,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, } bool ReturnedEarly = false; - for (size_t I = 0, E = PHIs.size(); I != E; ++I) { - PHINode *PHI = PHIs[I]; + for (PHINode *PHI : PHIs) { const ResultListTy &ResultList = ResultLists[PHI]; // If using a bitmask, use any value to fill the lookup table holes. @@ -5483,7 +5509,7 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, SmallVector<int64_t,4> Values; for (auto &C : SI->cases()) Values.push_back(C.getCaseValue()->getValue().getSExtValue()); - std::sort(Values.begin(), Values.end()); + llvm::sort(Values.begin(), Values.end()); // If the switch is already dense, there's nothing useful to do here. if (isSwitchDense(Values)) @@ -5566,11 +5592,7 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { // If the block only contains the switch, see if we can fold the block // away into any preds. - BasicBlock::iterator BBI = BB->begin(); - // Ignore dbg intrinsics. 
- while (isa<DbgInfoIntrinsic>(BBI)) - ++BBI; - if (SI == &*BBI) + if (SI == &*BB->instructionsWithoutDebug().begin()) if (FoldValueComparisonIntoPredecessors(SI, Builder)) return simplifyCFG(BB, TTI, Options) | true; } @@ -5657,7 +5679,7 @@ bool SimplifyCFGOpt::SimplifyIndirectBr(IndirectBrInst *IBI) { /// any transform which might inhibit optimization (such as our ability to /// specialize a particular handler via tail commoning). We do this by not /// merging any blocks which require us to introduce a phi. Since the same -/// values are flowing through both blocks, we don't loose any ability to +/// values are flowing through both blocks, we don't lose any ability to /// specialize. If anything, we make such specialization more likely. /// /// TODO - This transformation could remove entries from a phi in the target @@ -5687,7 +5709,7 @@ static bool TryToMergeLandingPad(LandingPadInst *LPad, BranchInst *BI, // We've found an identical block. Update our predecessors to take that // path instead and make ourselves dead. - SmallSet<BasicBlock *, 16> Preds; + SmallPtrSet<BasicBlock *, 16> Preds; Preds.insert(pred_begin(BB), pred_end(BB)); for (BasicBlock *Pred : Preds) { InvokeInst *II = cast<InvokeInst>(Pred->getTerminator()); @@ -5705,7 +5727,7 @@ static bool TryToMergeLandingPad(LandingPadInst *LPad, BranchInst *BI, Inst.eraseFromParent(); } - SmallSet<BasicBlock *, 16> Succs; + SmallPtrSet<BasicBlock *, 16> Succs; Succs.insert(succ_begin(BB), succ_end(BB)); for (BasicBlock *Succ : Succs) { Succ->removePredecessor(BB); @@ -5729,9 +5751,12 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, // header. (This is for early invocations before loop simplify and // vectorization to keep canonical loop forms for nested loops. These blocks // can be eliminated when the pass is invoked later in the back-end.) + // Note that if BB has only one predecessor then we do not introduce new + // backedge, so we can eliminate BB. bool NeedCanonicalLoop = Options.NeedCanonicalLoop && - (LoopHeaders && (LoopHeaders->count(BB) || LoopHeaders->count(Succ))); + (LoopHeaders && pred_size(BB) > 1 && + (LoopHeaders->count(BB) || LoopHeaders->count(Succ))); BasicBlock::iterator I = BB->getFirstNonPHIOrDbg()->getIterator(); if (I->isTerminator() && BB != &BB->getParent()->getEntryBlock() && !NeedCanonicalLoop && TryToSimplifyUncondBranchFromEmptyBlock(BB)) @@ -5779,6 +5804,9 @@ static BasicBlock *allPredecessorsComeFromSameSource(BasicBlock *BB) { bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { BasicBlock *BB = BI->getParent(); + const Function *Fn = BB->getParent(); + if (Fn && Fn->hasFnAttribute(Attribute::OptForFuzzing)) + return false; // Conditional branch if (isValueEqualityComparison(BI)) { @@ -5791,18 +5819,12 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // This block must be empty, except for the setcond inst, if it exists. // Ignore dbg intrinsics. - BasicBlock::iterator I = BB->begin(); - // Ignore dbg intrinsics. - while (isa<DbgInfoIntrinsic>(I)) - ++I; + auto I = BB->instructionsWithoutDebug().begin(); if (&*I == BI) { if (FoldValueComparisonIntoPredecessors(BI, Builder)) return simplifyCFG(BB, TTI, Options) | true; } else if (&*I == cast<Instruction>(BI->getCondition())) { ++I; - // Ignore dbg intrinsics. 
- while (isa<DbgInfoIntrinsic>(I)) - ++I; if (&*I == BI && FoldValueComparisonIntoPredecessors(BI, Builder)) return simplifyCFG(BB, TTI, Options) | true; } @@ -5928,17 +5950,20 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I) { // Load from null is undefined. if (LoadInst *LI = dyn_cast<LoadInst>(Use)) if (!LI->isVolatile()) - return LI->getPointerAddressSpace() == 0; + return !NullPointerIsDefined(LI->getFunction(), + LI->getPointerAddressSpace()); // Store to null is undefined. if (StoreInst *SI = dyn_cast<StoreInst>(Use)) if (!SI->isVolatile()) - return SI->getPointerAddressSpace() == 0 && + return (!NullPointerIsDefined(SI->getFunction(), + SI->getPointerAddressSpace())) && SI->getPointerOperand() == I; // A call to null is undefined. if (auto CS = CallSite(Use)) - return CS.getCalledValue() == I; + return !NullPointerIsDefined(CS->getFunction()) && + CS.getCalledValue() == I; } return false; } @@ -5946,14 +5971,13 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I) { /// If BB has an incoming value that will always trigger undefined behavior /// (eg. null pointer dereference), remove the branch leading here. static bool removeUndefIntroducingPredecessor(BasicBlock *BB) { - for (BasicBlock::iterator i = BB->begin(); - PHINode *PHI = dyn_cast<PHINode>(i); ++i) - for (unsigned i = 0, e = PHI->getNumIncomingValues(); i != e; ++i) - if (passingValueIsAlwaysUndefined(PHI->getIncomingValue(i), PHI)) { - TerminatorInst *T = PHI->getIncomingBlock(i)->getTerminator(); + for (PHINode &PHI : BB->phis()) + for (unsigned i = 0, e = PHI.getNumIncomingValues(); i != e; ++i) + if (passingValueIsAlwaysUndefined(PHI.getIncomingValue(i), &PHI)) { + TerminatorInst *T = PHI.getIncomingBlock(i)->getTerminator(); IRBuilder<> Builder(T); if (BranchInst *BI = dyn_cast<BranchInst>(T)) { - BB->removePredecessor(PHI->getIncomingBlock(i)); + BB->removePredecessor(PHI.getIncomingBlock(i)); // Turn uncoditional branches into unreachables and remove the dead // destination from conditional branches. if (BI->isUnconditional()) @@ -5980,7 +6004,7 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { // or that just have themself as a predecessor. These are unreachable. 
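A C-level illustration of the passingValueIsAlwaysUndefined logic above, assuming the default address space where null is not dereferenceable (names hypothetical): the branch edge that funnels null into the load triggers undefined behavior and can be deleted.

int before(int c, int *p) {
  int *q = c ? (int *)0 : p;  /* phi of null and p          */
  return *q;                  /* load: UB whenever c is true */
}

int after(int c, int *p) {
  (void)c;                    /* the c-true edge is removed as undefined */
  return *p;
}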
if ((pred_empty(BB) && BB != &BB->getParent()->getEntryBlock()) || BB->getSinglePredecessor() == BB) { - DEBUG(dbgs() << "Removing BB: \n" << *BB); + LLVM_DEBUG(dbgs() << "Removing BB: \n" << *BB); DeleteDeadBlock(BB); return true; } diff --git a/lib/Transforms/Utils/SimplifyIndVar.cpp b/lib/Transforms/Utils/SimplifyIndVar.cpp index ad1faea0a7ae..e381fbc34ab4 100644 --- a/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -26,6 +26,7 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -80,6 +81,7 @@ namespace { bool replaceIVUserWithLoopInvariant(Instruction *UseInst); bool eliminateOverflowIntrinsic(CallInst *CI); + bool eliminateTrunc(TruncInst *TI); bool eliminateIVUser(Instruction *UseInst, Instruction *IVOperand); bool makeIVComparisonInvariant(ICmpInst *ICmp, Value *IVOperand); void eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand); @@ -147,8 +149,8 @@ Value *SimplifyIndvar::foldIVUser(Instruction *UseInst, Instruction *IVOperand) if (SE->getSCEV(UseInst) != FoldedExpr) return nullptr; - DEBUG(dbgs() << "INDVARS: Eliminated IV operand: " << *IVOperand - << " -> " << *UseInst << '\n'); + LLVM_DEBUG(dbgs() << "INDVARS: Eliminated IV operand: " << *IVOperand + << " -> " << *UseInst << '\n'); UseInst->setOperand(OperIdx, IVSrc); assert(SE->getSCEV(UseInst) == FoldedExpr && "bad SCEV with folded oper"); @@ -221,7 +223,7 @@ bool SimplifyIndvar::makeIVComparisonInvariant(ICmpInst *ICmp, // for now. return false; - DEBUG(dbgs() << "INDVARS: Simplified comparison: " << *ICmp << '\n'); + LLVM_DEBUG(dbgs() << "INDVARS: Simplified comparison: " << *ICmp << '\n'); ICmp->setPredicate(InvariantPredicate); ICmp->setOperand(0, NewLHS); ICmp->setOperand(1, NewRHS); @@ -252,11 +254,11 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) { if (SE->isKnownPredicate(Pred, S, X)) { ICmp->replaceAllUsesWith(ConstantInt::getTrue(ICmp->getContext())); DeadInsts.emplace_back(ICmp); - DEBUG(dbgs() << "INDVARS: Eliminated comparison: " << *ICmp << '\n'); + LLVM_DEBUG(dbgs() << "INDVARS: Eliminated comparison: " << *ICmp << '\n'); } else if (SE->isKnownPredicate(ICmpInst::getInversePredicate(Pred), S, X)) { ICmp->replaceAllUsesWith(ConstantInt::getFalse(ICmp->getContext())); DeadInsts.emplace_back(ICmp); - DEBUG(dbgs() << "INDVARS: Eliminated comparison: " << *ICmp << '\n'); + LLVM_DEBUG(dbgs() << "INDVARS: Eliminated comparison: " << *ICmp << '\n'); } else if (makeIVComparisonInvariant(ICmp, IVOperand)) { // fallthrough to end of function } else if (ICmpInst::isSigned(OriginalPred) && @@ -267,7 +269,8 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) { // we turn the instruction's predicate to its unsigned version. Note that // we cannot rely on Pred here unless we check if we have swapped it. 
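A quick C-level sanity check of the rewrite described above: when both operands are provably non-negative, the signed and unsigned comparisons agree, so the predicate can be flipped to its unsigned form.

int cmp_signed(int i, int n) { return i < n; }

/* For 0 <= i and 0 <= n, this returns the same result: */
int cmp_unsigned(int i, int n) { return (unsigned)i < (unsigned)n; }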
assert(ICmp->getPredicate() == OriginalPred && "Predicate changed?"); - DEBUG(dbgs() << "INDVARS: Turn to unsigned comparison: " << *ICmp << '\n'); + LLVM_DEBUG(dbgs() << "INDVARS: Turn to unsigned comparison: " << *ICmp + << '\n'); ICmp->setPredicate(ICmpInst::getUnsignedPredicate(OriginalPred)); } else return; @@ -293,7 +296,7 @@ bool SimplifyIndvar::eliminateSDiv(BinaryOperator *SDiv) { SDiv->getName() + ".udiv", SDiv); UDiv->setIsExact(SDiv->isExact()); SDiv->replaceAllUsesWith(UDiv); - DEBUG(dbgs() << "INDVARS: Simplified sdiv: " << *SDiv << '\n'); + LLVM_DEBUG(dbgs() << "INDVARS: Simplified sdiv: " << *SDiv << '\n'); ++NumSimplifiedSDiv; Changed = true; DeadInsts.push_back(SDiv); @@ -309,7 +312,7 @@ void SimplifyIndvar::replaceSRemWithURem(BinaryOperator *Rem) { auto *URem = BinaryOperator::Create(BinaryOperator::URem, N, D, Rem->getName() + ".urem", Rem); Rem->replaceAllUsesWith(URem); - DEBUG(dbgs() << "INDVARS: Simplified srem: " << *Rem << '\n'); + LLVM_DEBUG(dbgs() << "INDVARS: Simplified srem: " << *Rem << '\n'); ++NumSimplifiedSRem; Changed = true; DeadInsts.emplace_back(Rem); @@ -318,7 +321,7 @@ void SimplifyIndvar::replaceSRemWithURem(BinaryOperator *Rem) { // i % n --> i if i is in [0,n). void SimplifyIndvar::replaceRemWithNumerator(BinaryOperator *Rem) { Rem->replaceAllUsesWith(Rem->getOperand(0)); - DEBUG(dbgs() << "INDVARS: Simplified rem: " << *Rem << '\n'); + LLVM_DEBUG(dbgs() << "INDVARS: Simplified rem: " << *Rem << '\n'); ++NumElimRem; Changed = true; DeadInsts.emplace_back(Rem); @@ -332,7 +335,7 @@ void SimplifyIndvar::replaceRemWithNumeratorOrZero(BinaryOperator *Rem) { SelectInst *Sel = SelectInst::Create(ICmp, ConstantInt::get(T, 0), N, "iv.rem", Rem); Rem->replaceAllUsesWith(Sel); - DEBUG(dbgs() << "INDVARS: Simplified rem: " << *Rem << '\n'); + LLVM_DEBUG(dbgs() << "INDVARS: Simplified rem: " << *Rem << '\n'); ++NumElimRem; Changed = true; DeadInsts.emplace_back(Rem); @@ -492,6 +495,118 @@ bool SimplifyIndvar::eliminateOverflowIntrinsic(CallInst *CI) { return true; } +bool SimplifyIndvar::eliminateTrunc(TruncInst *TI) { + // It is always legal to replace + // icmp <pred> i32 trunc(iv), n + // with + // icmp <pred> i64 sext(trunc(iv)), sext(n), if pred is signed predicate. + // Or with + // icmp <pred> i64 zext(trunc(iv)), zext(n), if pred is unsigned predicate. + // Or with either of these if pred is an equality predicate. + // + // If we can prove that iv == sext(trunc(iv)) or iv == zext(trunc(iv)) for + // every comparison which uses trunc, it means that we can replace each of + // them with comparison of iv against sext/zext(n). We no longer need trunc + // after that. + // + // TODO: Should we do this if we can widen *some* comparisons, but not all + // of them? Sometimes it is enough to enable other optimizations, but the + // trunc instruction will stay in the loop. + Value *IV = TI->getOperand(0); + Type *IVTy = IV->getType(); + const SCEV *IVSCEV = SE->getSCEV(IV); + const SCEV *TISCEV = SE->getSCEV(TI); + + // Check if iv == zext(trunc(iv)) and if iv == sext(trunc(iv)). If so, we can + // get rid of trunc + bool DoesSExtCollapse = false; + bool DoesZExtCollapse = false; + if (IVSCEV == SE->getSignExtendExpr(TISCEV, IVTy)) + DoesSExtCollapse = true; + if (IVSCEV == SE->getZeroExtendExpr(TISCEV, IVTy)) + DoesZExtCollapse = true; + + // If neither sext nor zext does collapse, it is not profitable to do any + // transform. Bail. 
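A C-level sketch of the trunc elimination, assuming SCEV can prove sext(trunc(iv)) == iv (here because 0 <= i < 100 fits in 32 bits); the per-iteration trunc is replaced by one loop-invariant extension of the bound:

long before(int bound) {
  long n = 0;
  for (long i = 0; i < 100; ++i)
    if ((int)i < bound)          /* icmp slt (trunc iv), bound */
      ++n;
  return n;
}

long after(int bound) {
  long n = 0;
  long bound_ext = (long)bound;  /* sext(bound), loop-invariant */
  for (long i = 0; i < 100; ++i)
    if (i < bound_ext)           /* full-width compare; the trunc is dead */
      ++n;
  return n;
}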
+ if (!DoesSExtCollapse && !DoesZExtCollapse) + return false; + + // Collect users of the trunc that look like comparisons against invariants. + // Bail if we find something different. + SmallVector<ICmpInst *, 4> ICmpUsers; + for (auto *U : TI->users()) { + // We don't care about users in unreachable blocks. + if (isa<Instruction>(U) && + !DT->isReachableFromEntry(cast<Instruction>(U)->getParent())) + continue; + if (ICmpInst *ICI = dyn_cast<ICmpInst>(U)) { + if (ICI->getOperand(0) == TI && L->isLoopInvariant(ICI->getOperand(1))) { + assert(L->contains(ICI->getParent()) && "LCSSA form broken?"); + // If we cannot get rid of trunc, bail. + if (ICI->isSigned() && !DoesSExtCollapse) + return false; + if (ICI->isUnsigned() && !DoesZExtCollapse) + return false; + // For equality, either signed or unsigned works. + ICmpUsers.push_back(ICI); + } else + return false; + } else + return false; + } + + auto CanUseZExt = [&](ICmpInst *ICI) { + // Unsigned comparison can be widened as unsigned. + if (ICI->isUnsigned()) + return true; + // Is it profitable to do zext? + if (!DoesZExtCollapse) + return false; + // For equality, we can safely zext both parts. + if (ICI->isEquality()) + return true; + // Otherwise we can only use zext when comparing two non-negative or two + // negative values. But in practice, we will never pass DoesZExtCollapse + // check for a negative value, because zext(trunc(x)) is non-negative. So + // it only make sense to check for non-negativity here. + const SCEV *SCEVOP1 = SE->getSCEV(ICI->getOperand(0)); + const SCEV *SCEVOP2 = SE->getSCEV(ICI->getOperand(1)); + return SE->isKnownNonNegative(SCEVOP1) && SE->isKnownNonNegative(SCEVOP2); + }; + // Replace all comparisons against trunc with comparisons against IV. + for (auto *ICI : ICmpUsers) { + auto *Op1 = ICI->getOperand(1); + Instruction *Ext = nullptr; + // For signed/unsigned predicate, replace the old comparison with comparison + // of immediate IV against sext/zext of the invariant argument. If we can + // use either sext or zext (i.e. we are dealing with equality predicate), + // then prefer zext as a more canonical form. + // TODO: If we see a signed comparison which can be turned into unsigned, + // we can do it here for canonicalization purposes. + ICmpInst::Predicate Pred = ICI->getPredicate(); + if (CanUseZExt(ICI)) { + assert(DoesZExtCollapse && "Unprofitable zext?"); + Ext = new ZExtInst(Op1, IVTy, "zext", ICI); + Pred = ICmpInst::getUnsignedPredicate(Pred); + } else { + assert(DoesSExtCollapse && "Unprofitable sext?"); + Ext = new SExtInst(Op1, IVTy, "sext", ICI); + assert(Pred == ICmpInst::getSignedPredicate(Pred) && "Must be signed!"); + } + bool Changed; + L->makeLoopInvariant(Ext, Changed); + (void)Changed; + ICmpInst *NewICI = new ICmpInst(ICI, Pred, IV, Ext); + ICI->replaceAllUsesWith(NewICI); + DeadInsts.emplace_back(ICI); + } + + // Trunc no longer needed. + TI->replaceAllUsesWith(UndefValue::get(TI->getType())); + DeadInsts.emplace_back(TI); + return true; +} + /// Eliminate an operation that consumes a simple IV and has no observable /// side-effect given the range of IV values. IVOperand is guaranteed SCEVable, /// but UseInst may not be. 
@@ -516,6 +631,10 @@ bool SimplifyIndvar::eliminateIVUser(Instruction *UseInst, if (eliminateOverflowIntrinsic(CI)) return true; + if (auto *TI = dyn_cast<TruncInst>(UseInst)) + if (eliminateTrunc(TI)) + return true; + if (eliminateIdentitySCEV(UseInst, IVOperand)) return true; @@ -548,8 +667,8 @@ bool SimplifyIndvar::replaceIVUserWithLoopInvariant(Instruction *I) { auto *Invariant = Rewriter.expandCodeFor(S, I->getType(), IP); I->replaceAllUsesWith(Invariant); - DEBUG(dbgs() << "INDVARS: Replace IV user: " << *I - << " with loop invariant: " << *S << '\n'); + LLVM_DEBUG(dbgs() << "INDVARS: Replace IV user: " << *I + << " with loop invariant: " << *S << '\n'); ++NumFoldedUser; Changed = true; DeadInsts.emplace_back(I); @@ -589,7 +708,7 @@ bool SimplifyIndvar::eliminateIdentitySCEV(Instruction *UseInst, if (!LI->replacementPreservesLCSSAForm(UseInst, IVOperand)) return false; - DEBUG(dbgs() << "INDVARS: Eliminated identity: " << *UseInst << '\n'); + LLVM_DEBUG(dbgs() << "INDVARS: Eliminated identity: " << *UseInst << '\n'); UseInst->replaceAllUsesWith(IVOperand); ++NumElimIdentity; @@ -771,6 +890,15 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { SimpleIVUsers.pop_back_val(); Instruction *UseInst = UseOper.first; + // If a user of the IndVar is trivially dead, we prefer just to mark it dead + // rather than try to do some complex analysis or transformation (such as + // widening) basing on it. + // TODO: Propagate TLI and pass it here to handle more cases. + if (isInstructionTriviallyDead(UseInst, /* TLI */ nullptr)) { + DeadInsts.emplace_back(UseInst); + continue; + } + // Bypass back edges to avoid extra work. if (UseInst == CurrIV) continue; @@ -783,7 +911,7 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { for (unsigned N = 0; IVOperand; ++N) { assert(N <= Simplified.size() && "runaway iteration"); - Value *NewOper = foldIVUser(UseOper.first, IVOperand); + Value *NewOper = foldIVUser(UseInst, IVOperand); if (!NewOper) break; // done folding IVOperand = dyn_cast<Instruction>(NewOper); @@ -791,12 +919,12 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { if (!IVOperand) continue; - if (eliminateIVUser(UseOper.first, IVOperand)) { + if (eliminateIVUser(UseInst, IVOperand)) { pushIVUsers(IVOperand, L, Simplified, SimpleIVUsers); continue; } - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(UseOper.first)) { + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(UseInst)) { if ((isa<OverflowingBinaryOperator>(BO) && strengthenOverflowingOperation(BO, IVOperand)) || (isa<ShlOperator>(BO) && strengthenRightShift(BO, IVOperand))) { @@ -806,13 +934,13 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { } } - CastInst *Cast = dyn_cast<CastInst>(UseOper.first); + CastInst *Cast = dyn_cast<CastInst>(UseInst); if (V && Cast) { V->visitCast(Cast); continue; } - if (isSimpleIVUser(UseOper.first, L, SE)) { - pushIVUsers(UseOper.first, L, Simplified, SimpleIVUsers); + if (isSimpleIVUser(UseInst, L, SE)) { + pushIVUsers(UseInst, L, Simplified, SimpleIVUsers); } } } diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp index 03a1d55ddc30..8c48597fc2e4 100644 --- a/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -7,10 +7,8 @@ // //===----------------------------------------------------------------------===// // -// This is a utility pass used for testing the InstructionSimplify analysis. 
-// The analysis is applied to every instruction, and if it simplifies then the
-// instruction is replaced by the simplification. If you are looking for a pass
-// that performs serious instruction folding, use the instcombine pass instead.
+// This file implements the library calls simplifier. It does not implement
+// any pass, but can be used by other passes to do simplifications.
//
//===----------------------------------------------------------------------===//
@@ -21,7 +19,9 @@
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
@@ -33,7 +33,6 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
-#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
using namespace PatternMatch;
@@ -104,19 +103,51 @@ static bool callHasFloatingPointArgument(const CallInst *CI) {
  });
}
-/// \brief Check whether the overloaded unary floating point function
-/// corresponding to \a Ty is available.
-static bool hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty,
-                            LibFunc DoubleFn, LibFunc FloatFn,
-                            LibFunc LongDoubleFn) {
-  switch (Ty->getTypeID()) {
-  case Type::FloatTyID:
-    return TLI->has(FloatFn);
-  case Type::DoubleTyID:
-    return TLI->has(DoubleFn);
-  default:
-    return TLI->has(LongDoubleFn);
-  }
+static Value *convertStrToNumber(CallInst *CI, StringRef &Str, int64_t Base) {
+  if (Base < 2 || Base > 36)
+    // handle special zero base
+    if (Base != 0)
+      return nullptr;
+
+  char *End;
+  std::string nptr = Str.str();
+  errno = 0;
+  long long int Result = strtoll(nptr.c_str(), &End, Base);
+  if (errno)
+    return nullptr;
+
+  // if we assume all possible target locales are ASCII supersets,
+  // then if strtoll successfully parses a number on the host,
+  // it will also successfully parse the same way on the target
+  if (*End != '\0')
+    return nullptr;
+
+  if (!isIntN(CI->getType()->getPrimitiveSizeInBits(), Result))
+    return nullptr;
+
+  return ConstantInt::get(CI->getType(), Result);
+}
+
+static bool isLocallyOpenedFile(Value *File, CallInst *CI, IRBuilder<> &B,
+                                const TargetLibraryInfo *TLI) {
+  CallInst *FOpen = dyn_cast<CallInst>(File);
+  if (!FOpen)
+    return false;
+
+  Function *InnerCallee = FOpen->getCalledFunction();
+  if (!InnerCallee)
+    return false;
+
+  LibFunc Func;
+  if (!TLI->getLibFunc(*InnerCallee, Func) || !TLI->has(Func) ||
+      Func != LibFunc_fopen)
+    return false;
+
+  inferLibFuncAttributes(*CI->getCalledFunction(), *TLI);
+  if (PointerMayBeCaptured(File, true, true))
+    return false;
+
+  return true;
}
//===----------------------------------------------------------------------===//
@@ -156,9 +187,8 @@ Value *LibCallSimplifier::emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len,
  // We have enough information to now generate the memcpy call to do the
  // concatenation for us. Make a memcpy to copy the nul byte with align = 1.
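A runnable sketch of which inputs the convertStrToNumber helper above will and will not fold (it requires a constant string, a null endptr, and an in-range base; the literals here are hypothetical):

#include <stdio.h>
#include <stdlib.h>

int main(void) {
  printf("%ld\n", strtol("42", NULL, 10));  /* foldable: prints 42   */
  printf("%ld\n", strtol("f00", NULL, 16)); /* foldable: prints 3840 */
  printf("%ld\n", strtol("42x", NULL, 10)); /* not folded: trailing 'x'
                                               (still prints 42 at run time) */
  return 0;
}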
- B.CreateMemCpy(CpyDst, Src, - ConstantInt::get(DL.getIntPtrType(Src->getContext()), Len + 1), - 1); + B.CreateMemCpy(CpyDst, 1, Src, 1, + ConstantInt::get(DL.getIntPtrType(Src->getContext()), Len + 1)); return Dst; } @@ -346,8 +376,8 @@ Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilder<> &B) { // We have enough information to now generate the memcpy call to do the // copy for us. Make a memcpy to copy the nul byte with align = 1. - B.CreateMemCpy(Dst, Src, - ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len), 1); + B.CreateMemCpy(Dst, 1, Src, 1, + ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len)); return Dst; } @@ -371,7 +401,7 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilder<> &B) { // We have enough information to now generate the memcpy call to do the // copy for us. Make a memcpy to copy the nul byte with align = 1. - B.CreateMemCpy(Dst, Src, LenV, 1); + B.CreateMemCpy(Dst, 1, Src, 1, LenV); return DstEnd; } @@ -388,7 +418,7 @@ Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilder<> &B) { --SrcLen; if (SrcLen == 0) { - // strncpy(x, "", y) -> memset(x, '\0', y, 1) + // strncpy(x, "", y) -> memset(align 1 x, '\0', y) B.CreateMemSet(Dst, B.getInt8('\0'), LenOp, 1); return Dst; } @@ -407,8 +437,8 @@ Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilder<> &B) { return nullptr; Type *PT = Callee->getFunctionType()->getParamType(0); - // strncpy(x, s, c) -> memcpy(x, s, c, 1) [s and c are constant] - B.CreateMemCpy(Dst, Src, ConstantInt::get(DL.getIntPtrType(PT), Len), 1); + // strncpy(x, s, c) -> memcpy(align 1 x, align 1 s, c) [s and c are constant] + B.CreateMemCpy(Dst, 1, Src, 1, ConstantInt::get(DL.getIntPtrType(PT), Len)); return Dst; } @@ -508,7 +538,7 @@ Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeWcslen(CallInst *CI, IRBuilder<> &B) { - Module &M = *CI->getParent()->getParent()->getParent(); + Module &M = *CI->getModule(); unsigned WCharSize = TLI->getWCharSize(M) * 8; // We cannot perform this optimization without wchar_size metadata. if (WCharSize == 0) @@ -816,40 +846,19 @@ Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeMemCpy(CallInst *CI, IRBuilder<> &B) { - // memcpy(x, y, n) -> llvm.memcpy(x, y, n, 1) - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); + // memcpy(x, y, n) -> llvm.memcpy(align 1 x, align 1 y, n) + B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, + CI->getArgOperand(2)); return CI->getArgOperand(0); } Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilder<> &B) { - // memmove(x, y, n) -> llvm.memmove(x, y, n, 1) - B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); + // memmove(x, y, n) -> llvm.memmove(align 1 x, align 1 y, n) + B.CreateMemMove(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, + CI->getArgOperand(2)); return CI->getArgOperand(0); } -// TODO: Does this belong in BuildLibCalls or should all of those similar -// functions be moved here? 
-static Value *emitCalloc(Value *Num, Value *Size, const AttributeList &Attrs, - IRBuilder<> &B, const TargetLibraryInfo &TLI) { - LibFunc Func; - if (!TLI.getLibFunc("calloc", Func) || !TLI.has(Func)) - return nullptr; - - Module *M = B.GetInsertBlock()->getModule(); - const DataLayout &DL = M->getDataLayout(); - IntegerType *PtrType = DL.getIntPtrType((B.GetInsertBlock()->getContext())); - Value *Calloc = M->getOrInsertFunction("calloc", Attrs, B.getInt8PtrTy(), - PtrType, PtrType); - CallInst *CI = B.CreateCall(Calloc, { Num, Size }, "calloc"); - - if (const auto *F = dyn_cast<Function>(Calloc->stripPointerCasts())) - CI->setCallingConv(F->getCallingConv()); - - return CI; -} - /// Fold memset[_chk](malloc(n), 0, n) --> calloc(1, n). static Value *foldMallocMemset(CallInst *Memset, IRBuilder<> &B, const TargetLibraryInfo &TLI) { @@ -901,12 +910,19 @@ Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilder<> &B) { if (auto *Calloc = foldMallocMemset(CI, B, *TLI)) return Calloc; - // memset(p, v, n) -> llvm.memset(p, v, n, 1) + // memset(p, v, n) -> llvm.memset(align 1 p, v, n) Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false); B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1); return CI->getArgOperand(0); } +Value *LibCallSimplifier::optimizeRealloc(CallInst *CI, IRBuilder<> &B) { + if (isa<ConstantPointerNull>(CI->getArgOperand(0))) + return emitMalloc(CI->getArgOperand(1), B, DL, TLI); + + return nullptr; +} + //===----------------------------------------------------------------------===// // Math Library Optimizations //===----------------------------------------------------------------------===// @@ -1666,12 +1682,12 @@ Value *LibCallSimplifier::optimizeFls(CallInst *CI, IRBuilder<> &B) { } Value *LibCallSimplifier::optimizeAbs(CallInst *CI, IRBuilder<> &B) { - // abs(x) -> x >s -1 ? x : -x - Value *Op = CI->getArgOperand(0); - Value *Pos = - B.CreateICmpSGT(Op, Constant::getAllOnesValue(Op->getType()), "ispos"); - Value *Neg = B.CreateNeg(Op, "neg"); - return B.CreateSelect(Pos, Op, Neg); + // abs(x) -> x <s 0 ? -x : x + // The negation has 'nsw' because abs of INT_MIN is undefined. 
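The new abs lowering, sketched in C; the nsw flag on the negation corresponds to the assumption noted below that x is never INT_MIN:

int abs_lowered(int x) {
  /* abs(x) -> x <s 0 ? -x : x; abs(INT_MIN) is undefined, so the
     negation may assume it cannot overflow. */
  return x < 0 ? -x : x;
}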
+ Value *X = CI->getArgOperand(0); + Value *IsNeg = B.CreateICmpSLT(X, Constant::getNullValue(X->getType())); + Value *NegX = B.CreateNSWNeg(X, "neg"); + return B.CreateSelect(IsNeg, NegX, X); } Value *LibCallSimplifier::optimizeIsDigit(CallInst *CI, IRBuilder<> &B) { @@ -1695,6 +1711,29 @@ Value *LibCallSimplifier::optimizeToAscii(CallInst *CI, IRBuilder<> &B) { ConstantInt::get(CI->getType(), 0x7F)); } +Value *LibCallSimplifier::optimizeAtoi(CallInst *CI, IRBuilder<> &B) { + StringRef Str; + if (!getConstantStringInfo(CI->getArgOperand(0), Str)) + return nullptr; + + return convertStrToNumber(CI, Str, 10); +} + +Value *LibCallSimplifier::optimizeStrtol(CallInst *CI, IRBuilder<> &B) { + StringRef Str; + if (!getConstantStringInfo(CI->getArgOperand(0), Str)) + return nullptr; + + if (!isa<ConstantPointerNull>(CI->getArgOperand(1))) + return nullptr; + + if (ConstantInt *CInt = dyn_cast<ConstantInt>(CI->getArgOperand(2))) { + return convertStrToNumber(CI, Str, CInt->getSExtValue()); + } + + return nullptr; +} + //===----------------------------------------------------------------------===// // Formatting and IO Library Call Optimizations //===----------------------------------------------------------------------===// @@ -1826,15 +1865,13 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, IRBuilder<> &B) { if (CI->getNumArgOperands() == 2) { // Make sure there's no % in the constant array. We could try to handle // %% -> % in the future if we cared. - for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) - if (FormatStr[i] == '%') - return nullptr; // we found a format specifier, bail out. + if (FormatStr.find('%') != StringRef::npos) + return nullptr; // we found a format specifier, bail out. - // sprintf(str, fmt) -> llvm.memcpy(str, fmt, strlen(fmt)+1, 1) - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), + // sprintf(str, fmt) -> llvm.memcpy(align 1 str, align 1 fmt, strlen(fmt)+1) + B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, ConstantInt::get(DL.getIntPtrType(CI->getContext()), - FormatStr.size() + 1), - 1); // Copy the null byte. + FormatStr.size() + 1)); // Copy the null byte. return ConstantInt::get(CI->getType(), FormatStr.size()); } @@ -1868,7 +1905,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, IRBuilder<> &B) { return nullptr; Value *IncLen = B.CreateAdd(Len, ConstantInt::get(Len->getType(), 1), "leninc"); - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(2), IncLen, 1); + B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(2), 1, IncLen); // The sprintf result is the unincremented number of bytes in the string. return B.CreateIntCast(Len, CI->getType(), false); @@ -1897,6 +1934,93 @@ Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilder<> &B) { return nullptr; } +Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI, IRBuilder<> &B) { + // Check for a fixed format string. + StringRef FormatStr; + if (!getConstantStringInfo(CI->getArgOperand(2), FormatStr)) + return nullptr; + + // Check for size + ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + if (!Size) + return nullptr; + + uint64_t N = Size->getZExtValue(); + + // If we just have a format string (nothing else crazy) transform it. + if (CI->getNumArgOperands() == 3) { + // Make sure there's no % in the constant array. We could try to handle + // %% -> % in the future if we cared. + if (FormatStr.find('%') != StringRef::npos) + return nullptr; // we found a format specifier, bail out. 
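For reference, a runnable sketch of the constant-format snprintf cases this routine (together with the %c and %s cases that follow) will fold; the buffer and strings are hypothetical:

#include <stdio.h>

int main(void) {
  char buf[16];
  int a = snprintf(buf, sizeof buf, "hello");    /* -> memcpy of 6 bytes, returns 5 */
  int b = snprintf(buf, 0, "hello");             /* -> 5, nothing is written        */
  int c = snprintf(buf, sizeof buf, "%c", 'x');  /* -> two stores, returns 1        */
  int d = snprintf(buf, sizeof buf, "%s", "hi"); /* -> memcpy of 3 bytes, returns 2 */
  return a + b + c + d;
}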
+ + if (N == 0) + return ConstantInt::get(CI->getType(), FormatStr.size()); + else if (N < FormatStr.size() + 1) + return nullptr; + + // sprintf(str, size, fmt) -> llvm.memcpy(align 1 str, align 1 fmt, + // strlen(fmt)+1) + B.CreateMemCpy( + CI->getArgOperand(0), 1, CI->getArgOperand(2), 1, + ConstantInt::get(DL.getIntPtrType(CI->getContext()), + FormatStr.size() + 1)); // Copy the null byte. + return ConstantInt::get(CI->getType(), FormatStr.size()); + } + + // The remaining optimizations require the format string to be "%s" or "%c" + // and have an extra operand. + if (FormatStr.size() == 2 && FormatStr[0] == '%' && + CI->getNumArgOperands() == 4) { + + // Decode the second character of the format string. + if (FormatStr[1] == 'c') { + if (N == 0) + return ConstantInt::get(CI->getType(), 1); + else if (N == 1) + return nullptr; + + // snprintf(dst, size, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0 + if (!CI->getArgOperand(3)->getType()->isIntegerTy()) + return nullptr; + Value *V = B.CreateTrunc(CI->getArgOperand(3), B.getInt8Ty(), "char"); + Value *Ptr = castToCStr(CI->getArgOperand(0), B); + B.CreateStore(V, Ptr); + Ptr = B.CreateGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); + B.CreateStore(B.getInt8(0), Ptr); + + return ConstantInt::get(CI->getType(), 1); + } + + if (FormatStr[1] == 's') { + // snprintf(dest, size, "%s", str) to llvm.memcpy(dest, str, len+1, 1) + StringRef Str; + if (!getConstantStringInfo(CI->getArgOperand(3), Str)) + return nullptr; + + if (N == 0) + return ConstantInt::get(CI->getType(), Str.size()); + else if (N < Str.size() + 1) + return nullptr; + + B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(3), 1, + ConstantInt::get(CI->getType(), Str.size() + 1)); + + // The snprintf result is the unincremented number of bytes in the string. + return ConstantInt::get(CI->getType(), Str.size()); + } + } + return nullptr; +} + +Value *LibCallSimplifier::optimizeSnPrintF(CallInst *CI, IRBuilder<> &B) { + if (Value *V = optimizeSnPrintFString(CI, B)) { + return V; + } + + return nullptr; +} + Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, IRBuilder<> &B) { optimizeErrorReporting(CI, B, 0); @@ -1913,9 +2037,9 @@ Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, IRBuilder<> &B) { // fprintf(F, "foo") --> fwrite("foo", 3, 1, F) if (CI->getNumArgOperands() == 2) { - for (unsigned i = 0, e = FormatStr.size(); i != e; ++i) - if (FormatStr[i] == '%') // Could handle %% -> % if we cared. - return nullptr; // We found a format specifier. + // Could handle %% -> % if we cared. + if (FormatStr.find('%') != StringRef::npos) + return nullptr; // We found a format specifier. return emitFWrite( CI->getArgOperand(1), @@ -1973,22 +2097,27 @@ Value *LibCallSimplifier::optimizeFWrite(CallInst *CI, IRBuilder<> &B) { // Get the element size and count. ConstantInt *SizeC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); ConstantInt *CountC = dyn_cast<ConstantInt>(CI->getArgOperand(2)); - if (!SizeC || !CountC) - return nullptr; - uint64_t Bytes = SizeC->getZExtValue() * CountC->getZExtValue(); - - // If this is writing zero records, remove the call (it's a noop). - if (Bytes == 0) - return ConstantInt::get(CI->getType(), 0); - - // If this is writing one byte, turn it into fputc. - // This optimisation is only valid, if the return value is unused. 
-  if (Bytes == 1 && CI->use_empty()) { // fwrite(S,1,1,F) -> fputc(S[0],F)
-    Value *Char = B.CreateLoad(castToCStr(CI->getArgOperand(0), B), "char");
-    Value *NewCI = emitFPutC(Char, CI->getArgOperand(3), B, TLI);
-    return NewCI ? ConstantInt::get(CI->getType(), 1) : nullptr;
+  if (SizeC && CountC) {
+    uint64_t Bytes = SizeC->getZExtValue() * CountC->getZExtValue();
+
+    // If this is writing zero records, remove the call (it's a noop).
+    if (Bytes == 0)
+      return ConstantInt::get(CI->getType(), 0);
+
+    // If this is writing one byte, turn it into fputc.
+    // This optimisation is only valid if the return value is unused.
+    if (Bytes == 1 && CI->use_empty()) { // fwrite(S,1,1,F) -> fputc(S[0],F)
+      Value *Char = B.CreateLoad(castToCStr(CI->getArgOperand(0), B), "char");
+      Value *NewCI = emitFPutC(Char, CI->getArgOperand(3), B, TLI);
+      return NewCI ? ConstantInt::get(CI->getType(), 1) : nullptr;
+    }
   }
 
+  if (isLocallyOpenedFile(CI->getArgOperand(3), CI, B, TLI))
+    return emitFWriteUnlocked(CI->getArgOperand(0), CI->getArgOperand(1),
+                              CI->getArgOperand(2), CI->getArgOperand(3), B, DL,
+                              TLI);
+
   return nullptr;
 }
 
@@ -1997,12 +2126,18 @@ Value *LibCallSimplifier::optimizeFPuts(CallInst *CI, IRBuilder<> &B) {
 
   // Don't rewrite fputs to fwrite when optimising for size because fwrite
   // requires more arguments and thus extra MOVs are required.
-  if (CI->getParent()->getParent()->optForSize())
+  if (CI->getFunction()->optForSize())
     return nullptr;
 
-  // We can't optimize if return value is used.
-  if (!CI->use_empty())
-    return nullptr;
+  // Check if the return value has any uses.
+  if (!CI->use_empty()) {
+    if (isLocallyOpenedFile(CI->getArgOperand(1), CI, B, TLI))
+      return emitFPutSUnlocked(CI->getArgOperand(0), CI->getArgOperand(1), B,
+                               TLI);
+    else
+      // We can't optimize if the return value is used.
+      return nullptr;
+  }
 
   // fputs(s,F) --> fwrite(s,1,strlen(s),F)
   uint64_t Len = GetStringLength(CI->getArgOperand(0));
@@ -2016,6 +2151,40 @@
                     CI->getArgOperand(1), B, DL, TLI);
 }
 
+Value *LibCallSimplifier::optimizeFPutc(CallInst *CI, IRBuilder<> &B) {
+  optimizeErrorReporting(CI, B, 1);
+
+  if (isLocallyOpenedFile(CI->getArgOperand(1), CI, B, TLI))
+    return emitFPutCUnlocked(CI->getArgOperand(0), CI->getArgOperand(1), B,
+                             TLI);
+
+  return nullptr;
+}
+
+Value *LibCallSimplifier::optimizeFGetc(CallInst *CI, IRBuilder<> &B) {
+  if (isLocallyOpenedFile(CI->getArgOperand(0), CI, B, TLI))
+    return emitFGetCUnlocked(CI->getArgOperand(0), B, TLI);
+
+  return nullptr;
+}
+
+Value *LibCallSimplifier::optimizeFGets(CallInst *CI, IRBuilder<> &B) {
+  if (isLocallyOpenedFile(CI->getArgOperand(2), CI, B, TLI))
+    return emitFGetSUnlocked(CI->getArgOperand(0), CI->getArgOperand(1),
+                             CI->getArgOperand(2), B, TLI);
+
+  return nullptr;
+}
+
+Value *LibCallSimplifier::optimizeFRead(CallInst *CI, IRBuilder<> &B) {
+  if (isLocallyOpenedFile(CI->getArgOperand(3), CI, B, TLI))
+    return emitFReadUnlocked(CI->getArgOperand(0), CI->getArgOperand(1),
+                             CI->getArgOperand(2), CI->getArgOperand(3), B, DL,
+                             TLI);
+
+  return nullptr;
+}
+
 Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilder<> &B) {
   // Check for a constant string.
StringRef Str; @@ -2099,6 +2268,8 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, return optimizeMemMove(CI, Builder); case LibFunc_memset: return optimizeMemSet(CI, Builder); + case LibFunc_realloc: + return optimizeRealloc(CI, Builder); case LibFunc_wcslen: return optimizeWcslen(CI, Builder); default: @@ -2290,16 +2461,33 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { return optimizeIsAscii(CI, Builder); case LibFunc_toascii: return optimizeToAscii(CI, Builder); + case LibFunc_atoi: + case LibFunc_atol: + case LibFunc_atoll: + return optimizeAtoi(CI, Builder); + case LibFunc_strtol: + case LibFunc_strtoll: + return optimizeStrtol(CI, Builder); case LibFunc_printf: return optimizePrintF(CI, Builder); case LibFunc_sprintf: return optimizeSPrintF(CI, Builder); + case LibFunc_snprintf: + return optimizeSnPrintF(CI, Builder); case LibFunc_fprintf: return optimizeFPrintF(CI, Builder); case LibFunc_fwrite: return optimizeFWrite(CI, Builder); + case LibFunc_fread: + return optimizeFRead(CI, Builder); case LibFunc_fputs: return optimizeFPuts(CI, Builder); + case LibFunc_fgets: + return optimizeFGets(CI, Builder); + case LibFunc_fputc: + return optimizeFPutc(CI, Builder); + case LibFunc_fgetc: + return optimizeFGetc(CI, Builder); case LibFunc_puts: return optimizePuts(CI, Builder); case LibFunc_perror: @@ -2307,8 +2495,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { case LibFunc_vfprintf: case LibFunc_fiprintf: return optimizeErrorReporting(CI, Builder, 0); - case LibFunc_fputc: - return optimizeErrorReporting(CI, Builder, 1); default: return nullptr; } @@ -2393,8 +2579,8 @@ bool FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI, IRBuilder<> &B) { if (isFortifiedCallFoldable(CI, 3, 2, false)) { - B.CreateMemCpy(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); + B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, + CI->getArgOperand(2)); return CI->getArgOperand(0); } return nullptr; @@ -2403,8 +2589,8 @@ Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI, IRBuilder<> &B) { if (isFortifiedCallFoldable(CI, 3, 2, false)) { - B.CreateMemMove(CI->getArgOperand(0), CI->getArgOperand(1), - CI->getArgOperand(2), 1); + B.CreateMemMove(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, + CI->getArgOperand(2)); return CI->getArgOperand(0); } return nullptr; diff --git a/lib/Transforms/Utils/SplitModule.cpp b/lib/Transforms/Utils/SplitModule.cpp index 968eb0208f43..f8d758c54983 100644 --- a/lib/Transforms/Utils/SplitModule.cpp +++ b/lib/Transforms/Utils/SplitModule.cpp @@ -101,7 +101,8 @@ static void findPartitions(Module *M, ClusterIDMapType &ClusterIDMap, // At this point module should have the proper mix of globals and locals. // As we attempt to partition this module, we must not change any // locals to globals. 
- DEBUG(dbgs() << "Partition module with (" << M->size() << ")functions\n"); + LLVM_DEBUG(dbgs() << "Partition module with (" << M->size() + << ")functions\n"); ClusterMapType GVtoClusterMap; ComdatMembersType ComdatMembers; @@ -180,28 +181,31 @@ static void findPartitions(Module *M, ClusterIDMapType &ClusterIDMap, std::make_pair(std::distance(GVtoClusterMap.member_begin(I), GVtoClusterMap.member_end()), I)); - std::sort(Sets.begin(), Sets.end(), [](const SortType &a, const SortType &b) { - if (a.first == b.first) - return a.second->getData()->getName() > b.second->getData()->getName(); - else - return a.first > b.first; - }); + llvm::sort(Sets.begin(), Sets.end(), + [](const SortType &a, const SortType &b) { + if (a.first == b.first) + return a.second->getData()->getName() > + b.second->getData()->getName(); + else + return a.first > b.first; + }); for (auto &I : Sets) { unsigned CurrentClusterID = BalancinQueue.top().first; unsigned CurrentClusterSize = BalancinQueue.top().second; BalancinQueue.pop(); - DEBUG(dbgs() << "Root[" << CurrentClusterID << "] cluster_size(" << I.first - << ") ----> " << I.second->getData()->getName() << "\n"); + LLVM_DEBUG(dbgs() << "Root[" << CurrentClusterID << "] cluster_size(" + << I.first << ") ----> " << I.second->getData()->getName() + << "\n"); for (ClusterMapType::member_iterator MI = GVtoClusterMap.findLeader(I.second); MI != GVtoClusterMap.member_end(); ++MI) { if (!Visited.insert(*MI).second) continue; - DEBUG(dbgs() << "----> " << (*MI)->getName() - << ((*MI)->hasLocalLinkage() ? " l " : " e ") << "\n"); + LLVM_DEBUG(dbgs() << "----> " << (*MI)->getName() + << ((*MI)->hasLocalLinkage() ? " l " : " e ") << "\n"); Visited.insert(*MI); ClusterIDMap[*MI] = CurrentClusterID; CurrentClusterSize++; @@ -270,7 +274,7 @@ void llvm::SplitModule( for (unsigned I = 0; I < N; ++I) { ValueToValueMapTy VMap; std::unique_ptr<Module> MPart( - CloneModule(M.get(), VMap, [&](const GlobalValue *GV) { + CloneModule(*M, VMap, [&](const GlobalValue *GV) { if (ClusterIDMap.count(GV)) return (ClusterIDMap[GV] == I); else diff --git a/lib/Transforms/Utils/StripGCRelocates.cpp b/lib/Transforms/Utils/StripGCRelocates.cpp index 49dc15cf5e7c..ac0b519f4a77 100644 --- a/lib/Transforms/Utils/StripGCRelocates.cpp +++ b/lib/Transforms/Utils/StripGCRelocates.cpp @@ -21,7 +21,6 @@ #include "llvm/IR/Type.h" #include "llvm/Pass.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" using namespace llvm; @@ -75,6 +74,3 @@ bool StripGCRelocates::runOnFunction(Function &F) { INITIALIZE_PASS(StripGCRelocates, "strip-gc-relocates", "Strip gc.relocates inserted through RewriteStatepointsForGC", true, false) -FunctionPass *llvm::createStripGCRelocatesPass() { - return new StripGCRelocates(); -} diff --git a/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp b/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp index cd0378e0140c..8956a089a99c 100644 --- a/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp +++ b/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp @@ -9,7 +9,7 @@ #include "llvm/IR/DebugInfo.h" #include "llvm/Pass.h" -#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Utils.h" using namespace llvm; namespace { diff --git a/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp b/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp index ed444e4cf43c..e633ac0c874d 100644 --- a/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp +++ b/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp @@ -19,7 +19,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" 
#include "llvm/IR/Type.h" -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" using namespace llvm; char UnifyFunctionExitNodes::ID = 0; diff --git a/lib/Transforms/Utils/Utils.cpp b/lib/Transforms/Utils/Utils.cpp index f6c7d1c4989e..afd842f59911 100644 --- a/lib/Transforms/Utils/Utils.cpp +++ b/lib/Transforms/Utils/Utils.cpp @@ -12,7 +12,10 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Utils.h" #include "llvm-c/Initialization.h" +#include "llvm-c/Transforms/Utils.h" +#include "llvm/IR/LegacyPassManager.h" #include "llvm/InitializePasses.h" #include "llvm/PassRegistry.h" @@ -33,7 +36,6 @@ void llvm::initializeTransformUtils(PassRegistry &Registry) { initializePromoteLegacyPassPass(Registry); initializeStripNonLineTableDebugInfoPass(Registry); initializeUnifyFunctionExitNodesPass(Registry); - initializeInstSimplifierPass(Registry); initializeMetaRenamerPass(Registry); initializeStripGCRelocatesPass(Registry); initializePredicateInfoPrinterLegacyPassPass(Registry); @@ -43,3 +45,12 @@ void llvm::initializeTransformUtils(PassRegistry &Registry) { void LLVMInitializeTransformUtils(LLVMPassRegistryRef R) { initializeTransformUtils(*unwrap(R)); } + +void LLVMAddLowerSwitchPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLowerSwitchPass()); +} + +void LLVMAddPromoteMemoryToRegisterPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createPromoteMemoryToRegisterPass()); +} + diff --git a/lib/Transforms/Utils/VNCoercion.cpp b/lib/Transforms/Utils/VNCoercion.cpp index c3feea6a0a41..948d9bd5baad 100644 --- a/lib/Transforms/Utils/VNCoercion.cpp +++ b/lib/Transforms/Utils/VNCoercion.cpp @@ -20,8 +20,14 @@ bool canCoerceMustAliasedValueToLoad(Value *StoredVal, Type *LoadTy, StoredVal->getType()->isStructTy() || StoredVal->getType()->isArrayTy()) return false; + uint64_t StoreSize = DL.getTypeSizeInBits(StoredVal->getType()); + + // The store size must be byte-aligned to support future type casts. + if (llvm::alignTo(StoreSize, 8) != StoreSize) + return false; + // The store has to be at least as big as the load. - if (DL.getTypeSizeInBits(StoredVal->getType()) < DL.getTypeSizeInBits(LoadTy)) + if (StoreSize < DL.getTypeSizeInBits(LoadTy)) return false; // Don't coerce non-integral pointers to integers or vice versa. @@ -389,8 +395,8 @@ Value *getLoadValueForLoad(LoadInst *SrcVal, unsigned Offset, Type *LoadTy, NewLoad->takeName(SrcVal); NewLoad->setAlignment(SrcVal->getAlignment()); - DEBUG(dbgs() << "GVN WIDENED LOAD: " << *SrcVal << "\n"); - DEBUG(dbgs() << "TO: " << *NewLoad << "\n"); + LLVM_DEBUG(dbgs() << "GVN WIDENED LOAD: " << *SrcVal << "\n"); + LLVM_DEBUG(dbgs() << "TO: " << *NewLoad << "\n"); // Replace uses of the original load with the wider load. On a big endian // system, we need to shift down to get the relevant bits. 
diff --git a/lib/Transforms/Utils/ValueMapper.cpp b/lib/Transforms/Utils/ValueMapper.cpp index 8c9ecbc3503e..55fff3f3872a 100644 --- a/lib/Transforms/Utils/ValueMapper.cpp +++ b/lib/Transforms/Utils/ValueMapper.cpp @@ -25,6 +25,7 @@ #include "llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalAlias.h" @@ -536,13 +537,23 @@ Optional<Metadata *> MDNodeMapper::tryToMapOperand(const Metadata *Op) { return None; } +static Metadata *cloneOrBuildODR(const MDNode &N) { + auto *CT = dyn_cast<DICompositeType>(&N); + // If ODR type uniquing is enabled, we would have uniqued composite types + // with identifiers during bitcode reading, so we can just use CT. + if (CT && CT->getContext().isODRUniquingDebugTypes() && + CT->getIdentifier() != "") + return const_cast<DICompositeType *>(CT); + return MDNode::replaceWithDistinct(N.clone()); +} + MDNode *MDNodeMapper::mapDistinctNode(const MDNode &N) { assert(N.isDistinct() && "Expected a distinct node"); assert(!M.getVM().getMappedMD(&N) && "Expected an unmapped node"); - DistinctWorklist.push_back(cast<MDNode>( - (M.Flags & RF_MoveDistinctMDs) - ? M.mapToSelf(&N) - : M.mapToMetadata(&N, MDNode::replaceWithDistinct(N.clone())))); + DistinctWorklist.push_back( + cast<MDNode>((M.Flags & RF_MoveDistinctMDs) + ? M.mapToSelf(&N) + : M.mapToMetadata(&N, cloneOrBuildODR(N)))); return DistinctWorklist.back(); } diff --git a/lib/Transforms/Vectorize/CMakeLists.txt b/lib/Transforms/Vectorize/CMakeLists.txt index 7622ed6d194f..27a4d241b320 100644 --- a/lib/Transforms/Vectorize/CMakeLists.txt +++ b/lib/Transforms/Vectorize/CMakeLists.txt @@ -1,9 +1,13 @@ add_llvm_library(LLVMVectorize LoadStoreVectorizer.cpp + LoopVectorizationLegality.cpp LoopVectorize.cpp SLPVectorizer.cpp Vectorize.cpp VPlan.cpp + VPlanHCFGBuilder.cpp + VPlanHCFGTransforms.cpp + VPlanVerifier.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms diff --git a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index dc83b6d4d292..5f3d127202ad 100644 --- a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -6,6 +6,38 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// +// +// This pass merges loads/stores to/from sequential memory addresses into vector +// loads/stores. Although there's nothing GPU-specific in here, this pass is +// motivated by the microarchitectural quirks of nVidia and AMD GPUs. +// +// (For simplicity below we talk about loads only, but everything also applies +// to stores.) +// +// This pass is intended to be run late in the pipeline, after other +// vectorization opportunities have been exploited. So the assumption here is +// that immediately following our new vector load we'll need to extract out the +// individual elements of the load, so we can operate on them individually. +// +// On CPUs this transformation is usually not beneficial, because extracting the +// elements of a vector register is expensive on most architectures. It's +// usually better just to load each element individually into its own scalar +// register. +// +// However, nVidia and AMD GPUs don't have proper vector registers. Instead, a +// "vector load" loads directly into a series of scalar registers. 
In effect, +// extracting the elements of the vector is free. It's therefore always +// beneficial to vectorize a sequence of loads on these architectures. +// +// Vectorizing (perhaps a better name might be "coalescing") loads can have +// large performance impacts on GPU kernels, and opportunities for vectorizing +// are common in GPU code. This pass tries very hard to find such +// opportunities; its runtime is quadratic in the number of loads in a BB. +// +// Some CPU architectures, such as ARM, have instructions that load into +// multiple scalar registers, similar to a GPU vectorized load. In theory ARM +// could use this pass (with some modifications), but currently it implements +// its own pass to do something similar to what we do here. #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" @@ -21,6 +53,7 @@ #include "llvm/Analysis/OrderedBasicBlock.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/Attributes.h" @@ -45,7 +78,6 @@ #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Vectorize.h" #include <algorithm> #include <cassert> @@ -65,8 +97,16 @@ static const unsigned StackAdjustedAlignment = 4; namespace { +/// ChainID is an arbitrary token that is allowed to be different only for the +/// accesses that are guaranteed to be considered non-consecutive by +/// Vectorizer::isConsecutiveAccess. It's used for grouping instructions +/// together and reducing the number of instructions the main search operates on +/// at a time, i.e. this is to reduce compile time and nothing else as the main +/// search has O(n^2) time complexity. The underlying type of ChainID should not +/// be relied upon. +using ChainID = const Value *; using InstrList = SmallVector<Instruction *, 8>; -using InstrListMap = MapVector<Value *, InstrList>; +using InstrListMap = MapVector<ChainID, InstrList>; class Vectorizer { Function &F; @@ -86,10 +126,6 @@ public: bool run(); private: - Value *getPointerOperand(Value *I) const; - - GetElementPtrInst *getSourceGEP(Value *Src) const; - unsigned getPointerAddressSpace(Value *I); unsigned getAlignment(LoadInst *LI) const { @@ -108,7 +144,15 @@ private: return DL.getABITypeAlignment(SI->getValueOperand()->getType()); } + static const unsigned MaxDepth = 3; + bool isConsecutiveAccess(Value *A, Value *B); + bool areConsecutivePointers(Value *PtrA, Value *PtrB, const APInt &PtrDelta, + unsigned Depth = 0) const; + bool lookThroughComplexAddresses(Value *PtrA, Value *PtrB, APInt PtrDelta, + unsigned Depth) const; + bool lookThroughSelects(Value *PtrA, Value *PtrB, const APInt &PtrDelta, + unsigned Depth) const; /// After vectorization, reorder the instructions that I depends on /// (the instructions defining its operands), to ensure they dominate I. 
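
(Context for the new ChainID token above; an editorial sketch, not part of the patch — the loop is an abridged form of collectInstructions further down.) Accesses are bucketed by chain ID so that the quadratic consecutive-access search runs within each bucket rather than across every access in the block:

    // Abridged collection loop: key each simple load by getChainID(), which
    // returns the underlying object, or the select condition when the
    // pointer is itself a select (see getChainID later in this patch).
    for (Instruction &I : *BB)
      if (auto *LI = dyn_cast<LoadInst>(&I))
        if (LI->isSimple())
          LoadRefs[getChainID(LI->getPointerOperand(), DL)].push_back(LI);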
@@ -239,14 +283,6 @@ bool Vectorizer::run() { return Changed; } -Value *Vectorizer::getPointerOperand(Value *I) const { - if (LoadInst *LI = dyn_cast<LoadInst>(I)) - return LI->getPointerOperand(); - if (StoreInst *SI = dyn_cast<StoreInst>(I)) - return SI->getPointerOperand(); - return nullptr; -} - unsigned Vectorizer::getPointerAddressSpace(Value *I) { if (LoadInst *L = dyn_cast<LoadInst>(I)) return L->getPointerAddressSpace(); @@ -255,23 +291,10 @@ unsigned Vectorizer::getPointerAddressSpace(Value *I) { return -1; } -GetElementPtrInst *Vectorizer::getSourceGEP(Value *Src) const { - // First strip pointer bitcasts. Make sure pointee size is the same with - // and without casts. - // TODO: a stride set by the add instruction below can match the difference - // in pointee type size here. Currently it will not be vectorized. - Value *SrcPtr = getPointerOperand(Src); - Value *SrcBase = SrcPtr->stripPointerCasts(); - if (DL.getTypeStoreSize(SrcPtr->getType()->getPointerElementType()) == - DL.getTypeStoreSize(SrcBase->getType()->getPointerElementType())) - SrcPtr = SrcBase; - return dyn_cast<GetElementPtrInst>(SrcPtr); -} - // FIXME: Merge with llvm::isConsecutiveAccess bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) { - Value *PtrA = getPointerOperand(A); - Value *PtrB = getPointerOperand(B); + Value *PtrA = getLoadStorePointerOperand(A); + Value *PtrB = getLoadStorePointerOperand(B); unsigned ASA = getPointerAddressSpace(A); unsigned ASB = getPointerAddressSpace(B); @@ -280,18 +303,27 @@ bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) { return false; // Make sure that A and B are different pointers of the same size type. - unsigned PtrBitWidth = DL.getPointerSizeInBits(ASA); Type *PtrATy = PtrA->getType()->getPointerElementType(); Type *PtrBTy = PtrB->getType()->getPointerElementType(); if (PtrA == PtrB || + PtrATy->isVectorTy() != PtrBTy->isVectorTy() || DL.getTypeStoreSize(PtrATy) != DL.getTypeStoreSize(PtrBTy) || DL.getTypeStoreSize(PtrATy->getScalarType()) != DL.getTypeStoreSize(PtrBTy->getScalarType())) return false; + unsigned PtrBitWidth = DL.getPointerSizeInBits(ASA); APInt Size(PtrBitWidth, DL.getTypeStoreSize(PtrATy)); - APInt OffsetA(PtrBitWidth, 0), OffsetB(PtrBitWidth, 0); + return areConsecutivePointers(PtrA, PtrB, Size); +} + +bool Vectorizer::areConsecutivePointers(Value *PtrA, Value *PtrB, + const APInt &PtrDelta, + unsigned Depth) const { + unsigned PtrBitWidth = DL.getPointerTypeSizeInBits(PtrA->getType()); + APInt OffsetA(PtrBitWidth, 0); + APInt OffsetB(PtrBitWidth, 0); PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA); PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB); @@ -300,11 +332,11 @@ bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) { // Check if they are based on the same pointer. That makes the offsets // sufficient. if (PtrA == PtrB) - return OffsetDelta == Size; + return OffsetDelta == PtrDelta; // Compute the necessary base pointer delta to have the necessary final delta - // equal to the size. - APInt BaseDelta = Size - OffsetDelta; + // equal to the pointer delta requested. + APInt BaseDelta = PtrDelta - OffsetDelta; // Compute the distance with SCEV between the base pointers. const SCEV *PtrSCEVA = SE.getSCEV(PtrA); @@ -314,71 +346,127 @@ bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) { if (X == PtrSCEVB) return true; + // The above check will not catch the cases where one of the pointers is + // factorized but the other one is not, such as (C + (S * (A + B))) vs + // (AS + BS). 
Get the minus SCEV. That will allow re-combining the expressions
+  // and getting the simplified difference.
+  const SCEV *Dist = SE.getMinusSCEV(PtrSCEVB, PtrSCEVA);
+  if (C == Dist)
+    return true;
+
   // Sometimes even this doesn't work, because SCEV can't always see through
   // patterns that look like (gep (ext (add (shl X, C1), C2))). Try checking
   // things the hard way.
+  return lookThroughComplexAddresses(PtrA, PtrB, BaseDelta, Depth);
+}
+
+bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB,
+                                             APInt PtrDelta,
+                                             unsigned Depth) const {
+  auto *GEPA = dyn_cast<GetElementPtrInst>(PtrA);
+  auto *GEPB = dyn_cast<GetElementPtrInst>(PtrB);
+  if (!GEPA || !GEPB)
+    return lookThroughSelects(PtrA, PtrB, PtrDelta, Depth);
 
   // Look through GEPs after checking they're the same except for the last
   // index.
-  GetElementPtrInst *GEPA = getSourceGEP(A);
-  GetElementPtrInst *GEPB = getSourceGEP(B);
-  if (!GEPA || !GEPB || GEPA->getNumOperands() != GEPB->getNumOperands())
+  if (GEPA->getNumOperands() != GEPB->getNumOperands() ||
+      GEPA->getPointerOperand() != GEPB->getPointerOperand())
     return false;
-  unsigned FinalIndex = GEPA->getNumOperands() - 1;
-  for (unsigned i = 0; i < FinalIndex; i++)
-    if (GEPA->getOperand(i) != GEPB->getOperand(i))
+  gep_type_iterator GTIA = gep_type_begin(GEPA);
+  gep_type_iterator GTIB = gep_type_begin(GEPB);
+  for (unsigned I = 0, E = GEPA->getNumIndices() - 1; I < E; ++I) {
+    if (GTIA.getOperand() != GTIB.getOperand())
       return false;
+    ++GTIA;
+    ++GTIB;
+  }
 
-  Instruction *OpA = dyn_cast<Instruction>(GEPA->getOperand(FinalIndex));
-  Instruction *OpB = dyn_cast<Instruction>(GEPB->getOperand(FinalIndex));
+  Instruction *OpA = dyn_cast<Instruction>(GTIA.getOperand());
+  Instruction *OpB = dyn_cast<Instruction>(GTIB.getOperand());
   if (!OpA || !OpB || OpA->getOpcode() != OpB->getOpcode() ||
       OpA->getType() != OpB->getType())
     return false;
 
+  if (PtrDelta.isNegative()) {
+    if (PtrDelta.isMinSignedValue())
+      return false;
+    PtrDelta.negate();
+    std::swap(OpA, OpB);
+  }
+  uint64_t Stride = DL.getTypeAllocSize(GTIA.getIndexedType());
+  if (PtrDelta.urem(Stride) != 0)
+    return false;
+  unsigned IdxBitWidth = OpA->getType()->getScalarSizeInBits();
+  APInt IdxDiff = PtrDelta.udiv(Stride).zextOrSelf(IdxBitWidth);
+
   // Only look through a ZExt/SExt.
   if (!isa<SExtInst>(OpA) && !isa<ZExtInst>(OpA))
     return false;
 
   bool Signed = isa<SExtInst>(OpA);
 
-  OpA = dyn_cast<Instruction>(OpA->getOperand(0));
+  // At this point A could be a function parameter, i.e. not an instruction.
+  Value *ValA = OpA->getOperand(0);
   OpB = dyn_cast<Instruction>(OpB->getOperand(0));
-  if (!OpA || !OpB || OpA->getType() != OpB->getType())
+  if (!OpB || ValA->getType() != OpB->getType())
     return false;
 
-  // Now we need to prove that adding 1 to OpA won't overflow.
+  // Now we need to prove that adding IdxDiff to ValA won't overflow.
   bool Safe = false;
-  // First attempt: if OpB is an add with NSW/NUW, and OpB is 1 added to OpA,
-  // we're okay.
+  // First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
+  // ValA, we're okay.
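+  // For example (editorial illustration): with an index stride of 4 bytes
+  // and PtrDelta == 8, IdxDiff is 2; if ValA is %x and OpB is "add nuw %x, 2",
+  // the nuw flag already proves that %x + 2 cannot wrap, so Safe is set.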
if (OpB->getOpcode() == Instruction::Add && isa<ConstantInt>(OpB->getOperand(1)) && - cast<ConstantInt>(OpB->getOperand(1))->getSExtValue() > 0) { + IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue())) { if (Signed) Safe = cast<BinaryOperator>(OpB)->hasNoSignedWrap(); else Safe = cast<BinaryOperator>(OpB)->hasNoUnsignedWrap(); } - unsigned BitWidth = OpA->getType()->getScalarSizeInBits(); + unsigned BitWidth = ValA->getType()->getScalarSizeInBits(); // Second attempt: - // If any bits are known to be zero other than the sign bit in OpA, we can - // add 1 to it while guaranteeing no overflow of any sort. + // If all set bits of IdxDiff or any higher order bit other than the sign bit + // are known to be zero in ValA, we can add Diff to it while guaranteeing no + // overflow of any sort. if (!Safe) { + OpA = dyn_cast<Instruction>(ValA); + if (!OpA) + return false; KnownBits Known(BitWidth); computeKnownBits(OpA, Known, DL, 0, nullptr, OpA, &DT); - if (Known.countMaxTrailingOnes() < (BitWidth - 1)) - Safe = true; + APInt BitsAllowedToBeSet = Known.Zero.zext(IdxDiff.getBitWidth()); + if (Signed) + BitsAllowedToBeSet.clearBit(BitWidth - 1); + if (BitsAllowedToBeSet.ult(IdxDiff)) + return false; } - if (!Safe) + const SCEV *OffsetSCEVA = SE.getSCEV(ValA); + const SCEV *OffsetSCEVB = SE.getSCEV(OpB); + const SCEV *C = SE.getConstant(IdxDiff.trunc(BitWidth)); + const SCEV *X = SE.getAddExpr(OffsetSCEVA, C); + return X == OffsetSCEVB; +} + +bool Vectorizer::lookThroughSelects(Value *PtrA, Value *PtrB, + const APInt &PtrDelta, + unsigned Depth) const { + if (Depth++ == MaxDepth) return false; - const SCEV *OffsetSCEVA = SE.getSCEV(OpA); - const SCEV *OffsetSCEVB = SE.getSCEV(OpB); - const SCEV *One = SE.getConstant(APInt(BitWidth, 1)); - const SCEV *X2 = SE.getAddExpr(OffsetSCEVA, One); - return X2 == OffsetSCEVB; + if (auto *SelectA = dyn_cast<SelectInst>(PtrA)) { + if (auto *SelectB = dyn_cast<SelectInst>(PtrB)) { + return SelectA->getCondition() == SelectB->getCondition() && + areConsecutivePointers(SelectA->getTrueValue(), + SelectB->getTrueValue(), PtrDelta, Depth) && + areConsecutivePointers(SelectA->getFalseValue(), + SelectB->getFalseValue(), PtrDelta, Depth); + } + } + return false; } void Vectorizer::reorder(Instruction *I) { @@ -448,7 +536,7 @@ Vectorizer::getBoundaryInstrs(ArrayRef<Instruction *> Chain) { void Vectorizer::eraseInstructions(ArrayRef<Instruction *> Chain) { SmallVector<Instruction *, 16> Instrs; for (Instruction *I : Chain) { - Value *PtrOperand = getPointerOperand(I); + Value *PtrOperand = getLoadStorePointerOperand(I); assert(PtrOperand && "Instruction must have a pointer operand."); Instrs.push_back(I); if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(PtrOperand)) @@ -484,7 +572,7 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) { SmallVector<Instruction *, 16> ChainInstrs; bool IsLoadChain = isa<LoadInst>(Chain[0]); - DEBUG({ + LLVM_DEBUG({ for (Instruction *I : Chain) { if (IsLoadChain) assert(isa<LoadInst>(I) && @@ -506,11 +594,12 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) { Intrinsic::sideeffect) { // Ignore llvm.sideeffect calls. 
     } else if (IsLoadChain && (I.mayWriteToMemory() || I.mayThrow())) {
-      DEBUG(dbgs() << "LSV: Found may-write/throw operation: " << I << '\n');
+      LLVM_DEBUG(dbgs() << "LSV: Found may-write/throw operation: " << I
+                        << '\n');
       break;
     } else if (!IsLoadChain && (I.mayReadOrWriteMemory() || I.mayThrow())) {
-      DEBUG(dbgs() << "LSV: Found may-read/write/throw operation: " << I
-                   << '\n');
+      LLVM_DEBUG(dbgs() << "LSV: Found may-read/write/throw operation: " << I
+                        << '\n');
       break;
     }
   }
@@ -536,32 +625,40 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
       if (BarrierMemoryInstr && OBB.dominates(BarrierMemoryInstr, MemInstr))
         break;
 
-      if (isa<LoadInst>(MemInstr) && isa<LoadInst>(ChainInstr))
+      auto *MemLoad = dyn_cast<LoadInst>(MemInstr);
+      auto *ChainLoad = dyn_cast<LoadInst>(ChainInstr);
+      if (MemLoad && ChainLoad)
         continue;
 
+      // We can ignore the alias if we have a load/store pair and the load
+      // is known to be invariant. The load cannot be clobbered by the store.
+      auto IsInvariantLoad = [](const LoadInst *LI) -> bool {
+        return LI->getMetadata(LLVMContext::MD_invariant_load);
+      };
+
       // We can ignore the alias as long as the load comes before the store,
       // because that means we won't be moving the load past the store to
      // vectorize it (the vectorized load is inserted at the location of the
       // first load in the chain).
-      if (isa<StoreInst>(MemInstr) && isa<LoadInst>(ChainInstr) &&
-          OBB.dominates(ChainInstr, MemInstr))
+      if (isa<StoreInst>(MemInstr) && ChainLoad &&
+          (IsInvariantLoad(ChainLoad) || OBB.dominates(ChainLoad, MemInstr)))
         continue;
 
       // Same case, but in reverse.
-      if (isa<LoadInst>(MemInstr) && isa<StoreInst>(ChainInstr) &&
-          OBB.dominates(MemInstr, ChainInstr))
+      if (MemLoad && isa<StoreInst>(ChainInstr) &&
+          (IsInvariantLoad(MemLoad) || OBB.dominates(MemLoad, ChainInstr)))
        continue;
 
       if (!AA.isNoAlias(MemoryLocation::get(MemInstr),
                         MemoryLocation::get(ChainInstr))) {
-        DEBUG({
+        LLVM_DEBUG({
           dbgs() << "LSV: Found alias:\n"
                     "  Aliasing instruction and pointer:\n"
                  << "  " << *MemInstr << '\n'
-                 << "  " << *getPointerOperand(MemInstr) << '\n'
+                 << "  " << *getLoadStorePointerOperand(MemInstr) << '\n'
                  << "  Aliased instruction and pointer:\n"
                  << "  " << *ChainInstr << '\n'
-                 << "  " << *getPointerOperand(ChainInstr) << '\n';
+                 << "  " << *getLoadStorePointerOperand(ChainInstr) << '\n';
         });
         // Save this aliasing memory instruction as a barrier, but allow other
         // instructions that precede the barrier to be vectorized with this one.
@@ -594,6 +691,20 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) {
   return Chain.slice(0, ChainIdx);
 }
 
+static ChainID getChainID(const Value *Ptr, const DataLayout &DL) {
+  const Value *ObjPtr = GetUnderlyingObject(Ptr, DL);
+  if (const auto *Sel = dyn_cast<SelectInst>(ObjPtr)) {
+    // The selects themselves are distinct instructions even if they share
+    // the same condition and evaluate to consecutive pointers for the true
+    // and false values of the condition. Therefore using the selects
+    // themselves for grouping instructions would put consecutive accesses
+    // into different lists; they would never even be checked for being
+    // consecutive, and would not be vectorized.
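+    // Example (editorial illustration): given
+    //   %p = select i1 %c, i32* %a, i32* %b
+    //   %q = select i1 %c, i32* %a4, i32* %b4   ; %a4/%b4 = gep %a/%b, 1
+    // keying both accesses on %c keeps them in one list, and
+    // lookThroughSelects() can then match them arm by arm.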
+ return Sel->getCondition(); + } + return ObjPtr; +} + std::pair<InstrListMap, InstrListMap> Vectorizer::collectInstructions(BasicBlock *BB) { InstrListMap LoadRefs; @@ -632,8 +743,12 @@ Vectorizer::collectInstructions(BasicBlock *BB) { unsigned AS = Ptr->getType()->getPointerAddressSpace(); unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); + unsigned VF = VecRegSize / TySize; + VectorType *VecTy = dyn_cast<VectorType>(Ty); + // No point in looking at these if they're too big to vectorize. - if (TySize > VecRegSize / 2) + if (TySize > VecRegSize / 2 || + (VecTy && TTI.getLoadVectorFactor(VF, TySize, TySize / 8, VecTy) == 0)) continue; // Make sure all the users of a vector are constant-index extracts. @@ -644,8 +759,8 @@ Vectorizer::collectInstructions(BasicBlock *BB) { continue; // Save the load locations. - Value *ObjPtr = GetUnderlyingObject(Ptr, DL); - LoadRefs[ObjPtr].push_back(LI); + const ChainID ID = getChainID(Ptr, DL); + LoadRefs[ID].push_back(LI); } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) { if (!SI->isSimple()) continue; @@ -675,8 +790,12 @@ Vectorizer::collectInstructions(BasicBlock *BB) { unsigned AS = Ptr->getType()->getPointerAddressSpace(); unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); + unsigned VF = VecRegSize / TySize; + VectorType *VecTy = dyn_cast<VectorType>(Ty); + // No point in looking at these if they're too big to vectorize. - if (TySize > VecRegSize / 2) + if (TySize > VecRegSize / 2 || + (VecTy && TTI.getStoreVectorFactor(VF, TySize, TySize / 8, VecTy) == 0)) continue; if (isa<VectorType>(Ty) && !llvm::all_of(SI->users(), [](const User *U) { @@ -686,8 +805,8 @@ Vectorizer::collectInstructions(BasicBlock *BB) { continue; // Save store location. - Value *ObjPtr = GetUnderlyingObject(Ptr, DL); - StoreRefs[ObjPtr].push_back(SI); + const ChainID ID = getChainID(Ptr, DL); + StoreRefs[ID].push_back(SI); } } @@ -697,12 +816,12 @@ Vectorizer::collectInstructions(BasicBlock *BB) { bool Vectorizer::vectorizeChains(InstrListMap &Map) { bool Changed = false; - for (const std::pair<Value *, InstrList> &Chain : Map) { + for (const std::pair<ChainID, InstrList> &Chain : Map) { unsigned Size = Chain.second.size(); if (Size < 2) continue; - DEBUG(dbgs() << "LSV: Analyzing a chain of length " << Size << ".\n"); + LLVM_DEBUG(dbgs() << "LSV: Analyzing a chain of length " << Size << ".\n"); // Process the stores in chunks of 64. for (unsigned CI = 0, CE = Size; CI < CE; CI += 64) { @@ -716,7 +835,8 @@ bool Vectorizer::vectorizeChains(InstrListMap &Map) { } bool Vectorizer::vectorizeInstructions(ArrayRef<Instruction *> Instrs) { - DEBUG(dbgs() << "LSV: Vectorizing " << Instrs.size() << " instructions.\n"); + LLVM_DEBUG(dbgs() << "LSV: Vectorizing " << Instrs.size() + << " instructions.\n"); SmallVector<int, 16> Heads, Tails; int ConsecutiveChain[64]; @@ -852,14 +972,14 @@ bool Vectorizer::vectorizeStoreChain( // vector factor, break it into two pieces. unsigned TargetVF = TTI.getStoreVectorFactor(VF, Sz, SzInBytes, VecTy); if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) { - DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor." - " Creating two separate arrays.\n"); + LLVM_DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor." 
+ " Creating two separate arrays.\n"); return vectorizeStoreChain(Chain.slice(0, TargetVF), InstructionsProcessed) | vectorizeStoreChain(Chain.slice(TargetVF), InstructionsProcessed); } - DEBUG({ + LLVM_DEBUG({ dbgs() << "LSV: Stores to vectorize:\n"; for (Instruction *I : Chain) dbgs() << " " << *I << "\n"; @@ -1000,8 +1120,8 @@ bool Vectorizer::vectorizeLoadChain( // vector factor, break it into two pieces. unsigned TargetVF = TTI.getLoadVectorFactor(VF, Sz, SzInBytes, VecTy); if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) { - DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor." - " Creating two separate arrays.\n"); + LLVM_DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor." + " Creating two separate arrays.\n"); return vectorizeLoadChain(Chain.slice(0, TargetVF), InstructionsProcessed) | vectorizeLoadChain(Chain.slice(TargetVF), InstructionsProcessed); } @@ -1024,7 +1144,7 @@ bool Vectorizer::vectorizeLoadChain( Alignment = NewAlign; } - DEBUG({ + LLVM_DEBUG({ dbgs() << "LSV: Loads to vectorize:\n"; for (Instruction *I : Chain) I->dump(); @@ -1107,7 +1227,7 @@ bool Vectorizer::accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace, bool Allows = TTI.allowsMisalignedMemoryAccesses(F.getParent()->getContext(), SzInBytes * 8, AddressSpace, Alignment, &Fast); - DEBUG(dbgs() << "LSV: Target said misaligned is allowed? " << Allows - << " and fast? " << Fast << "\n";); + LLVM_DEBUG(dbgs() << "LSV: Target said misaligned is allowed? " << Allows + << " and fast? " << Fast << "\n";); return !Allows || !Fast; } diff --git a/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp new file mode 100644 index 000000000000..697bc1b448d7 --- /dev/null +++ b/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -0,0 +1,1072 @@ +//===- LoopVectorizationLegality.cpp --------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file provides loop vectorization legality analysis. Original code +// resided in LoopVectorize.cpp for a long time. +// +// At this point, it is implemented as a utility class, not as an analysis +// pass. It should be easy to create an analysis pass around it if there +// is a need (but D45420 needs to happen first). 
+// +#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/IntrinsicInst.h" + +using namespace llvm; + +#define LV_NAME "loop-vectorize" +#define DEBUG_TYPE LV_NAME + +static cl::opt<bool> + EnableIfConversion("enable-if-conversion", cl::init(true), cl::Hidden, + cl::desc("Enable if-conversion during vectorization.")); + +static cl::opt<unsigned> PragmaVectorizeMemoryCheckThreshold( + "pragma-vectorize-memory-check-threshold", cl::init(128), cl::Hidden, + cl::desc("The maximum allowed number of runtime memory checks with a " + "vectorize(enable) pragma.")); + +static cl::opt<unsigned> VectorizeSCEVCheckThreshold( + "vectorize-scev-check-threshold", cl::init(16), cl::Hidden, + cl::desc("The maximum number of SCEV checks allowed.")); + +static cl::opt<unsigned> PragmaVectorizeSCEVCheckThreshold( + "pragma-vectorize-scev-check-threshold", cl::init(128), cl::Hidden, + cl::desc("The maximum number of SCEV checks allowed with a " + "vectorize(enable) pragma")); + +/// Maximum vectorization interleave count. +static const unsigned MaxInterleaveFactor = 16; + +namespace llvm { + +OptimizationRemarkAnalysis createLVMissedAnalysis(const char *PassName, + StringRef RemarkName, + Loop *TheLoop, + Instruction *I) { + Value *CodeRegion = TheLoop->getHeader(); + DebugLoc DL = TheLoop->getStartLoc(); + + if (I) { + CodeRegion = I->getParent(); + // If there is no debug location attached to the instruction, revert back to + // using the loop's. + if (I->getDebugLoc()) + DL = I->getDebugLoc(); + } + + OptimizationRemarkAnalysis R(PassName, RemarkName, DL, CodeRegion); + R << "loop not vectorized: "; + return R; +} + +bool LoopVectorizeHints::Hint::validate(unsigned Val) { + switch (Kind) { + case HK_WIDTH: + return isPowerOf2_32(Val) && Val <= VectorizerParams::MaxVectorWidth; + case HK_UNROLL: + return isPowerOf2_32(Val) && Val <= MaxInterleaveFactor; + case HK_FORCE: + return (Val <= 1); + case HK_ISVECTORIZED: + return (Val == 0 || Val == 1); + } + return false; +} + +LoopVectorizeHints::LoopVectorizeHints(const Loop *L, bool DisableInterleaving, + OptimizationRemarkEmitter &ORE) + : Width("vectorize.width", VectorizerParams::VectorizationFactor, HK_WIDTH), + Interleave("interleave.count", DisableInterleaving, HK_UNROLL), + Force("vectorize.enable", FK_Undefined, HK_FORCE), + IsVectorized("isvectorized", 0, HK_ISVECTORIZED), TheLoop(L), ORE(ORE) { + // Populate values with existing loop metadata. + getHintsFromMetadata(); + + // force-vector-interleave overrides DisableInterleaving. + if (VectorizerParams::isInterleaveForced()) + Interleave.Value = VectorizerParams::VectorizationInterleave; + + if (IsVectorized.Value != 1) + // If the vectorization width and interleaving count are both 1 then + // consider the loop to have been already vectorized because there's + // nothing more that we can do. 
+ IsVectorized.Value = Width.Value == 1 && Interleave.Value == 1; + LLVM_DEBUG(if (DisableInterleaving && Interleave.Value == 1) dbgs() + << "LV: Interleaving disabled by the pass manager\n"); +} + +bool LoopVectorizeHints::allowVectorization(Function *F, Loop *L, + bool AlwaysVectorize) const { + if (getForce() == LoopVectorizeHints::FK_Disabled) { + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: #pragma vectorize disable.\n"); + emitRemarkWithHints(); + return false; + } + + if (!AlwaysVectorize && getForce() != LoopVectorizeHints::FK_Enabled) { + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: No #pragma vectorize enable.\n"); + emitRemarkWithHints(); + return false; + } + + if (getIsVectorized() == 1) { + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: Disabled/already vectorized.\n"); + // FIXME: Add interleave.disable metadata. This will allow + // vectorize.disable to be used without disabling the pass and errors + // to differentiate between disabled vectorization and a width of 1. + ORE.emit([&]() { + return OptimizationRemarkAnalysis(vectorizeAnalysisPassName(), + "AllDisabled", L->getStartLoc(), + L->getHeader()) + << "loop not vectorized: vectorization and interleaving are " + "explicitly disabled, or the loop has already been " + "vectorized"; + }); + return false; + } + + return true; +} + +void LoopVectorizeHints::emitRemarkWithHints() const { + using namespace ore; + + ORE.emit([&]() { + if (Force.Value == LoopVectorizeHints::FK_Disabled) + return OptimizationRemarkMissed(LV_NAME, "MissedExplicitlyDisabled", + TheLoop->getStartLoc(), + TheLoop->getHeader()) + << "loop not vectorized: vectorization is explicitly disabled"; + else { + OptimizationRemarkMissed R(LV_NAME, "MissedDetails", + TheLoop->getStartLoc(), TheLoop->getHeader()); + R << "loop not vectorized"; + if (Force.Value == LoopVectorizeHints::FK_Enabled) { + R << " (Force=" << NV("Force", true); + if (Width.Value != 0) + R << ", Vector Width=" << NV("VectorWidth", Width.Value); + if (Interleave.Value != 0) + R << ", Interleave Count=" << NV("InterleaveCount", Interleave.Value); + R << ")"; + } + return R; + } + }); +} + +const char *LoopVectorizeHints::vectorizeAnalysisPassName() const { + if (getWidth() == 1) + return LV_NAME; + if (getForce() == LoopVectorizeHints::FK_Disabled) + return LV_NAME; + if (getForce() == LoopVectorizeHints::FK_Undefined && getWidth() == 0) + return LV_NAME; + return OptimizationRemarkAnalysis::AlwaysPrint; +} + +void LoopVectorizeHints::getHintsFromMetadata() { + MDNode *LoopID = TheLoop->getLoopID(); + if (!LoopID) + return; + + // First operand should refer to the loop id itself. + assert(LoopID->getNumOperands() > 0 && "requires at least one operand"); + assert(LoopID->getOperand(0) == LoopID && "invalid loop id"); + + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + const MDString *S = nullptr; + SmallVector<Metadata *, 4> Args; + + // The expected hint is either a MDString or a MDNode with the first + // operand a MDString. + if (const MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i))) { + if (!MD || MD->getNumOperands() == 0) + continue; + S = dyn_cast<MDString>(MD->getOperand(0)); + for (unsigned i = 1, ie = MD->getNumOperands(); i < ie; ++i) + Args.push_back(MD->getOperand(i)); + } else { + S = dyn_cast<MDString>(LoopID->getOperand(i)); + assert(Args.size() == 0 && "too many arguments for MDString"); + } + + if (!S) + continue; + + // Check if the hint starts with the loop metadata prefix. 
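+    // For example (editorial illustration), the metadata
+    //   !{!"llvm.loop.vectorize.width", i32 8}
+    // yields Name == "llvm.loop.vectorize.width" with one argument, and
+    // setHint() strips the "llvm.loop." prefix before matching
+    // "vectorize.width".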
+ StringRef Name = S->getString(); + if (Args.size() == 1) + setHint(Name, Args[0]); + } +} + +void LoopVectorizeHints::setHint(StringRef Name, Metadata *Arg) { + if (!Name.startswith(Prefix())) + return; + Name = Name.substr(Prefix().size(), StringRef::npos); + + const ConstantInt *C = mdconst::dyn_extract<ConstantInt>(Arg); + if (!C) + return; + unsigned Val = C->getZExtValue(); + + Hint *Hints[] = {&Width, &Interleave, &Force, &IsVectorized}; + for (auto H : Hints) { + if (Name == H->Name) { + if (H->validate(Val)) + H->Value = Val; + else + LLVM_DEBUG(dbgs() << "LV: ignoring invalid hint '" << Name << "'\n"); + break; + } + } +} + +MDNode *LoopVectorizeHints::createHintMetadata(StringRef Name, + unsigned V) const { + LLVMContext &Context = TheLoop->getHeader()->getContext(); + Metadata *MDs[] = { + MDString::get(Context, Name), + ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Context), V))}; + return MDNode::get(Context, MDs); +} + +bool LoopVectorizeHints::matchesHintMetadataName(MDNode *Node, + ArrayRef<Hint> HintTypes) { + MDString *Name = dyn_cast<MDString>(Node->getOperand(0)); + if (!Name) + return false; + + for (auto H : HintTypes) + if (Name->getString().endswith(H.Name)) + return true; + return false; +} + +void LoopVectorizeHints::writeHintsToMetadata(ArrayRef<Hint> HintTypes) { + if (HintTypes.empty()) + return; + + // Reserve the first element to LoopID (see below). + SmallVector<Metadata *, 4> MDs(1); + // If the loop already has metadata, then ignore the existing operands. + MDNode *LoopID = TheLoop->getLoopID(); + if (LoopID) { + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + MDNode *Node = cast<MDNode>(LoopID->getOperand(i)); + // If node in update list, ignore old value. + if (!matchesHintMetadataName(Node, HintTypes)) + MDs.push_back(Node); + } + } + + // Now, add the missing hints. + for (auto H : HintTypes) + MDs.push_back(createHintMetadata(Twine(Prefix(), H.Name).str(), H.Value)); + + // Replace current metadata node with new one. + LLVMContext &Context = TheLoop->getHeader()->getContext(); + MDNode *NewLoopID = MDNode::get(Context, MDs); + // Set operand 0 to refer to the loop id itself. + NewLoopID->replaceOperandWith(0, NewLoopID); + + TheLoop->setLoopID(NewLoopID); +} + +bool LoopVectorizationRequirements::doesNotMeet( + Function *F, Loop *L, const LoopVectorizeHints &Hints) { + const char *PassName = Hints.vectorizeAnalysisPassName(); + bool Failed = false; + if (UnsafeAlgebraInst && !Hints.allowReordering()) { + ORE.emit([&]() { + return OptimizationRemarkAnalysisFPCommute( + PassName, "CantReorderFPOps", UnsafeAlgebraInst->getDebugLoc(), + UnsafeAlgebraInst->getParent()) + << "loop not vectorized: cannot prove it is safe to reorder " + "floating-point operations"; + }); + Failed = true; + } + + // Test if runtime memcheck thresholds are exceeded. 
+ bool PragmaThresholdReached = + NumRuntimePointerChecks > PragmaVectorizeMemoryCheckThreshold; + bool ThresholdReached = + NumRuntimePointerChecks > VectorizerParams::RuntimeMemoryCheckThreshold; + if ((ThresholdReached && !Hints.allowReordering()) || + PragmaThresholdReached) { + ORE.emit([&]() { + return OptimizationRemarkAnalysisAliasing(PassName, "CantReorderMemOps", + L->getStartLoc(), + L->getHeader()) + << "loop not vectorized: cannot prove it is safe to reorder " + "memory operations"; + }); + LLVM_DEBUG(dbgs() << "LV: Too many memory checks needed.\n"); + Failed = true; + } + + return Failed; +} + +// Return true if the inner loop \p Lp is uniform with regard to the outer loop +// \p OuterLp (i.e., if the outer loop is vectorized, all the vector lanes +// executing the inner loop will execute the same iterations). This check is +// very constrained for now but it will be relaxed in the future. \p Lp is +// considered uniform if it meets all the following conditions: +// 1) it has a canonical IV (starting from 0 and with stride 1), +// 2) its latch terminator is a conditional branch and, +// 3) its latch condition is a compare instruction whose operands are the +// canonical IV and an OuterLp invariant. +// This check doesn't take into account the uniformity of other conditions not +// related to the loop latch because they don't affect the loop uniformity. +// +// NOTE: We decided to keep all these checks and its associated documentation +// together so that we can easily have a picture of the current supported loop +// nests. However, some of the current checks don't depend on \p OuterLp and +// would be redundantly executed for each \p Lp if we invoked this function for +// different candidate outer loops. This is not the case for now because we +// don't currently have the infrastructure to evaluate multiple candidate outer +// loops and \p OuterLp will be a fixed parameter while we only support explicit +// outer loop vectorization. It's also very likely that these checks go away +// before introducing the aforementioned infrastructure. However, if this is not +// the case, we should move the \p OuterLp independent checks to a separate +// function that is only executed once for each \p Lp. +static bool isUniformLoop(Loop *Lp, Loop *OuterLp) { + assert(Lp->getLoopLatch() && "Expected loop with a single latch."); + + // If Lp is the outer loop, it's uniform by definition. + if (Lp == OuterLp) + return true; + assert(OuterLp->contains(Lp) && "OuterLp must contain Lp."); + + // 1. + PHINode *IV = Lp->getCanonicalInductionVariable(); + if (!IV) { + LLVM_DEBUG(dbgs() << "LV: Canonical IV not found.\n"); + return false; + } + + // 2. + BasicBlock *Latch = Lp->getLoopLatch(); + auto *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator()); + if (!LatchBr || LatchBr->isUnconditional()) { + LLVM_DEBUG(dbgs() << "LV: Unsupported loop latch branch.\n"); + return false; + } + + // 3. 
+ auto *LatchCmp = dyn_cast<CmpInst>(LatchBr->getCondition()); + if (!LatchCmp) { + LLVM_DEBUG( + dbgs() << "LV: Loop latch condition is not a compare instruction.\n"); + return false; + } + + Value *CondOp0 = LatchCmp->getOperand(0); + Value *CondOp1 = LatchCmp->getOperand(1); + Value *IVUpdate = IV->getIncomingValueForBlock(Latch); + if (!(CondOp0 == IVUpdate && OuterLp->isLoopInvariant(CondOp1)) && + !(CondOp1 == IVUpdate && OuterLp->isLoopInvariant(CondOp0))) { + LLVM_DEBUG(dbgs() << "LV: Loop latch condition is not uniform.\n"); + return false; + } + + return true; +} + +// Return true if \p Lp and all its nested loops are uniform with regard to \p +// OuterLp. +static bool isUniformLoopNest(Loop *Lp, Loop *OuterLp) { + if (!isUniformLoop(Lp, OuterLp)) + return false; + + // Check if nested loops are uniform. + for (Loop *SubLp : *Lp) + if (!isUniformLoopNest(SubLp, OuterLp)) + return false; + + return true; +} + +/// Check whether it is safe to if-convert this phi node. +/// +/// Phi nodes with constant expressions that can trap are not safe to if +/// convert. +static bool canIfConvertPHINodes(BasicBlock *BB) { + for (PHINode &Phi : BB->phis()) { + for (Value *V : Phi.incoming_values()) + if (auto *C = dyn_cast<Constant>(V)) + if (C->canTrap()) + return false; + } + return true; +} + +static Type *convertPointerToIntegerType(const DataLayout &DL, Type *Ty) { + if (Ty->isPointerTy()) + return DL.getIntPtrType(Ty); + + // It is possible that char's or short's overflow when we ask for the loop's + // trip count, work around this by changing the type size. + if (Ty->getScalarSizeInBits() < 32) + return Type::getInt32Ty(Ty->getContext()); + + return Ty; +} + +static Type *getWiderType(const DataLayout &DL, Type *Ty0, Type *Ty1) { + Ty0 = convertPointerToIntegerType(DL, Ty0); + Ty1 = convertPointerToIntegerType(DL, Ty1); + if (Ty0->getScalarSizeInBits() > Ty1->getScalarSizeInBits()) + return Ty0; + return Ty1; +} + +/// Check that the instruction has outside loop users and is not an +/// identified reduction variable. +static bool hasOutsideLoopUser(const Loop *TheLoop, Instruction *Inst, + SmallPtrSetImpl<Value *> &AllowedExit) { + // Reduction and Induction instructions are allowed to have exit users. All + // other instructions must not have external users. + if (!AllowedExit.count(Inst)) + // Check that all of the users of the loop are inside the BB. + for (User *U : Inst->users()) { + Instruction *UI = cast<Instruction>(U); + // This user may be a reduction exit value. + if (!TheLoop->contains(UI)) { + LLVM_DEBUG(dbgs() << "LV: Found an outside user for : " << *UI << '\n'); + return true; + } + } + return false; +} + +int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { + const ValueToValueMap &Strides = + getSymbolicStrides() ? *getSymbolicStrides() : ValueToValueMap(); + + int Stride = getPtrStride(PSE, Ptr, TheLoop, Strides, true, false); + if (Stride == 1 || Stride == -1) + return Stride; + return 0; +} + +bool LoopVectorizationLegality::isUniform(Value *V) { + return LAI->isUniform(V); +} + +bool LoopVectorizationLegality::canVectorizeOuterLoop() { + assert(!TheLoop->empty() && "We are not vectorizing an outer loop."); + // Store the result and return it at the end instead of exiting early, in case + // allowExtraAnalysis is used to report multiple reasons for not vectorizing. + bool Result = true; + bool DoExtraAnalysis = ORE->allowExtraAnalysis(DEBUG_TYPE); + + for (BasicBlock *BB : TheLoop->blocks()) { + // Check whether the BB terminator is a BranchInst. 
Any other terminator is
+    // not supported yet.
+    auto *Br = dyn_cast<BranchInst>(BB->getTerminator());
+    if (!Br) {
+      LLVM_DEBUG(dbgs() << "LV: Unsupported basic block terminator.\n");
+      ORE->emit(createMissedAnalysis("CFGNotUnderstood")
+                << "loop control flow is not understood by vectorizer");
+      if (DoExtraAnalysis)
+        Result = false;
+      else
+        return false;
+    }
+
+    // Check whether the BranchInst is a supported one. Only unconditional
+    // branches, conditional branches with an outer loop invariant condition,
+    // or backedges are supported.
+    if (Br && Br->isConditional() &&
+        !TheLoop->isLoopInvariant(Br->getCondition()) &&
+        !LI->isLoopHeader(Br->getSuccessor(0)) &&
+        !LI->isLoopHeader(Br->getSuccessor(1))) {
+      LLVM_DEBUG(dbgs() << "LV: Unsupported conditional branch.\n");
+      ORE->emit(createMissedAnalysis("CFGNotUnderstood")
+                << "loop control flow is not understood by vectorizer");
+      if (DoExtraAnalysis)
+        Result = false;
+      else
+        return false;
+    }
+  }
+
+  // Check whether inner loops are uniform. At this point, we only support
+  // simple outer loop scenarios with uniform nested loops.
+  if (!isUniformLoopNest(TheLoop /*loop nest*/,
+                         TheLoop /*context outer loop*/)) {
+    LLVM_DEBUG(
+        dbgs()
+        << "LV: Not vectorizing: Outer loop contains divergent loops.\n");
+    ORE->emit(createMissedAnalysis("CFGNotUnderstood")
+              << "loop control flow is not understood by vectorizer");
+    if (DoExtraAnalysis)
+      Result = false;
+    else
+      return false;
+  }
+
+  return Result;
+}
+
+void LoopVectorizationLegality::addInductionPhi(
+    PHINode *Phi, const InductionDescriptor &ID,
+    SmallPtrSetImpl<Value *> &AllowedExit) {
+  Inductions[Phi] = ID;
+
+  // In case this induction also comes with casts that we know we can ignore
+  // in the vectorized loop body, record them here. All casts could be
+  // recorded here for ignoring, but it suffices to record only the first
+  // (as it is the only one that may be used outside the cast sequence).
+  const SmallVectorImpl<Instruction *> &Casts = ID.getCastInsts();
+  if (!Casts.empty())
+    InductionCastsToIgnore.insert(*Casts.begin());
+
+  Type *PhiTy = Phi->getType();
+  const DataLayout &DL = Phi->getModule()->getDataLayout();
+
+  // Get the widest type.
+  if (!PhiTy->isFloatingPointTy()) {
+    if (!WidestIndTy)
+      WidestIndTy = convertPointerToIntegerType(DL, PhiTy);
+    else
+      WidestIndTy = getWiderType(DL, PhiTy, WidestIndTy);
+  }
+
+  // Int inductions are special because we only allow one IV.
+  if (ID.getKind() == InductionDescriptor::IK_IntInduction &&
+      ID.getConstIntStepValue() && ID.getConstIntStepValue()->isOne() &&
+      isa<Constant>(ID.getStartValue()) &&
+      cast<Constant>(ID.getStartValue())->isNullValue()) {
+
+    // Use the phi node with the widest type as induction. Use the last
+    // one if there are multiple (no good reason for doing this other
+    // than it is expedient). We've checked that it begins at zero and
+    // steps by one, so this is a canonical induction variable.
+    if (!PrimaryInduction || PhiTy == WidestIndTy)
+      PrimaryInduction = Phi;
+  }
+
+  // Both the PHI node itself, and the "post-increment" value feeding
+  // back into the PHI node may have external users.
+  // We can allow those uses, except if the SCEVs we have for them rely
+  // on predicates that only hold within the loop, since allowing the exit
+  // currently means re-using this SCEV outside the loop.
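+  // For instance (editorial illustration), in
+  //   for (i = 0; i != n; ++i) { ... }  use(i);
+  // the exit value of i may be used after the loop; that is allowed only
+  // while its SCEV relies on no loop-local predicate, because the expansion
+  // would be reused outside the loop.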
+  if (PSE.getUnionPredicate().isAlwaysTrue()) {
+    AllowedExit.insert(Phi);
+    AllowedExit.insert(Phi->getIncomingValueForBlock(TheLoop->getLoopLatch()));
+  }
+
+  LLVM_DEBUG(dbgs() << "LV: Found an induction variable.\n");
+}
+
+bool LoopVectorizationLegality::canVectorizeInstrs() {
+  BasicBlock *Header = TheLoop->getHeader();
+
+  // Look for the attribute signaling the absence of NaNs.
+  Function &F = *Header->getParent();
+  HasFunNoNaNAttr =
+      F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
+
+  // For each block in the loop.
+  for (BasicBlock *BB : TheLoop->blocks()) {
+    // Scan the instructions in the block and look for hazards.
+    for (Instruction &I : *BB) {
+      if (auto *Phi = dyn_cast<PHINode>(&I)) {
+        Type *PhiTy = Phi->getType();
+        // Check that this PHI type is allowed.
+        if (!PhiTy->isIntegerTy() && !PhiTy->isFloatingPointTy() &&
+            !PhiTy->isPointerTy()) {
+          ORE->emit(createMissedAnalysis("CFGNotUnderstood", Phi)
+                    << "loop control flow is not understood by vectorizer");
+          LLVM_DEBUG(dbgs() << "LV: Found a non-int non-pointer PHI.\n");
+          return false;
+        }
+
+        // If this PHINode is not in the header block, then we know that we
+        // can convert it to select during if-conversion. No need to check if
+        // the PHIs in this block are induction or reduction variables.
+        if (BB != Header) {
+          // Check that this instruction has no outside users or is an
+          // identified reduction value with an outside user.
+          if (!hasOutsideLoopUser(TheLoop, Phi, AllowedExit))
+            continue;
+          ORE->emit(createMissedAnalysis("NeitherInductionNorReduction", Phi)
+                    << "value could not be identified as "
+                       "an induction or reduction variable");
+          return false;
+        }
+
+        // We only allow if-converted PHIs with exactly two incoming values.
+        if (Phi->getNumIncomingValues() != 2) {
+          ORE->emit(createMissedAnalysis("CFGNotUnderstood", Phi)
+                    << "control flow not understood by vectorizer");
+          LLVM_DEBUG(dbgs() << "LV: Found an invalid PHI.\n");
+          return false;
+        }
+
+        RecurrenceDescriptor RedDes;
+        if (RecurrenceDescriptor::isReductionPHI(Phi, TheLoop, RedDes, DB, AC,
+                                                 DT)) {
+          if (RedDes.hasUnsafeAlgebra())
+            Requirements->addUnsafeAlgebraInst(RedDes.getUnsafeAlgebraInst());
+          AllowedExit.insert(RedDes.getLoopExitInstr());
+          Reductions[Phi] = RedDes;
+          continue;
+        }
+
+        InductionDescriptor ID;
+        if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID)) {
+          addInductionPhi(Phi, ID, AllowedExit);
+          if (ID.hasUnsafeAlgebra() && !HasFunNoNaNAttr)
+            Requirements->addUnsafeAlgebraInst(ID.getUnsafeAlgebraInst());
+          continue;
+        }
+
+        if (RecurrenceDescriptor::isFirstOrderRecurrence(Phi, TheLoop,
+                                                         SinkAfter, DT)) {
+          FirstOrderRecurrences.insert(Phi);
+          continue;
+        }
+
+        // As a last resort, coerce the PHI to an AddRec expression
+        // and re-try classifying it as an induction PHI.
+        if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID, true)) {
+          addInductionPhi(Phi, ID, AllowedExit);
+          continue;
+        }
+
+        ORE->emit(createMissedAnalysis("NonReductionValueUsedOutsideLoop", Phi)
+                  << "value that could not be identified as "
+                     "reduction is used outside the loop");
+        LLVM_DEBUG(dbgs() << "LV: Found an unidentified PHI." << *Phi << "\n");
+        return false;
+      } // end of PHI handling
+
+      // We handle calls that (see the examples just below):
+      //   * Are debug info intrinsics.
+      //   * Have a mapping to an IR intrinsic.
+      //   * Have a vector version available.
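+      // For example (an illustrative sketch, not from the original sources):
+      //   call void @llvm.dbg.value(...)   ; debug info intrinsic -> ignored
+      //   call float @fabsf(float %x)      ; has an IR intrinsic (llvm.fabs)
+      //   call float @expf(float %x)       ; may instead be covered by a TLI
+      //                                    ; vector variant (-vector-library)
+      //   call void @opaque()              ; none of the above -> rejected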
+      auto *CI = dyn_cast<CallInst>(&I);
+      if (CI && !getVectorIntrinsicIDForCall(CI, TLI) &&
+          !isa<DbgInfoIntrinsic>(CI) &&
+          !(CI->getCalledFunction() && TLI &&
+            TLI->isFunctionVectorizable(CI->getCalledFunction()->getName()))) {
+        ORE->emit(createMissedAnalysis("CantVectorizeCall", CI)
+                  << "call instruction cannot be vectorized");
+        LLVM_DEBUG(
+            dbgs() << "LV: Found a non-intrinsic, non-libfunc callsite.\n");
+        return false;
+      }
+
+      // Intrinsics such as powi, cttz and ctlz are legal to vectorize if the
+      // second argument is the same (i.e. loop invariant).
+      if (CI && hasVectorInstrinsicScalarOpd(
+                    getVectorIntrinsicIDForCall(CI, TLI), 1)) {
+        auto *SE = PSE.getSE();
+        if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(1)), TheLoop)) {
+          ORE->emit(createMissedAnalysis("CantVectorizeIntrinsic", CI)
+                    << "intrinsic instruction cannot be vectorized");
+          LLVM_DEBUG(dbgs()
+                     << "LV: Found unvectorizable intrinsic " << *CI << "\n");
+          return false;
+        }
+      }
+
+      // Check that the instruction return type is vectorizable.
+      // Also, we can't vectorize extractelement instructions.
+      if ((!VectorType::isValidElementType(I.getType()) &&
+           !I.getType()->isVoidTy()) ||
+          isa<ExtractElementInst>(I)) {
+        ORE->emit(createMissedAnalysis("CantVectorizeInstructionReturnType", &I)
+                  << "instruction return type cannot be vectorized");
+        LLVM_DEBUG(dbgs() << "LV: Found unvectorizable type.\n");
+        return false;
+      }
+
+      // Check that the stored type is vectorizable.
+      if (auto *ST = dyn_cast<StoreInst>(&I)) {
+        Type *T = ST->getValueOperand()->getType();
+        if (!VectorType::isValidElementType(T)) {
+          ORE->emit(createMissedAnalysis("CantVectorizeStore", ST)
+                    << "store instruction cannot be vectorized");
+          return false;
+        }
+
+        // FP instructions can allow unsafe algebra, thus vectorizable by
+        // non-IEEE-754 compliant SIMD units.
+        // This applies to floating-point math operations and calls, not memory
+        // operations, shuffles, or casts, as they don't change precision or
+        // semantics.
+      } else if (I.getType()->isFloatingPointTy() && (CI || I.isBinaryOp()) &&
+                 !I.isFast()) {
+        LLVM_DEBUG(dbgs() << "LV: Found FP op with unsafe algebra.\n");
+        Hints->setPotentiallyUnsafe();
+      }
+
+      // Reduction instructions are allowed to have exit users.
+      // All other instructions must not have external users.
+      if (hasOutsideLoopUser(TheLoop, &I, AllowedExit)) {
+        ORE->emit(createMissedAnalysis("ValueUsedOutsideLoop", &I)
+                  << "value cannot be used outside the loop");
+        return false;
+      }
+    } // next instr.
+  }
+
+  if (!PrimaryInduction) {
+    LLVM_DEBUG(dbgs() << "LV: Did not find one integer induction var.\n");
+    if (Inductions.empty()) {
+      ORE->emit(createMissedAnalysis("NoInductionVariable")
+                << "loop induction variable could not be identified");
+      return false;
+    }
+  }
+
+  // Now we know the widest induction type, check if our found induction
+  // is the same size. If it's not, unset it here and InnerLoopVectorizer
+  // will create another.
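+  // For instance (illustrative): with an i32 primary induction but another
+  // i64 induction in the loop, WidestIndTy is i64, so the i32 phi is dropped
+  // below and InnerLoopVectorizer materializes a fresh i64 canonical IV.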
+  if (PrimaryInduction && WidestIndTy != PrimaryInduction->getType())
+    PrimaryInduction = nullptr;
+
+  return true;
+}
+
+bool LoopVectorizationLegality::canVectorizeMemory() {
+  LAI = &(*GetLAA)(*TheLoop);
+  const OptimizationRemarkAnalysis *LAR = LAI->getReport();
+  if (LAR) {
+    ORE->emit([&]() {
+      return OptimizationRemarkAnalysis(Hints->vectorizeAnalysisPassName(),
+                                        "loop not vectorized: ", *LAR);
+    });
+  }
+  if (!LAI->canVectorizeMemory())
+    return false;
+
+  if (LAI->hasStoreToLoopInvariantAddress()) {
+    ORE->emit(createMissedAnalysis("CantVectorizeStoreToLoopInvariantAddress")
+              << "write to a loop invariant address could not be vectorized");
+    LLVM_DEBUG(dbgs() << "LV: We don't allow storing to uniform addresses\n");
+    return false;
+  }
+
+  Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks());
+  PSE.addPredicate(LAI->getPSE().getUnionPredicate());
+
+  return true;
+}
+
+bool LoopVectorizationLegality::isInductionPhi(const Value *V) {
+  Value *In0 = const_cast<Value *>(V);
+  PHINode *PN = dyn_cast_or_null<PHINode>(In0);
+  if (!PN)
+    return false;
+
+  return Inductions.count(PN);
+}
+
+bool LoopVectorizationLegality::isCastedInductionVariable(const Value *V) {
+  auto *Inst = dyn_cast<Instruction>(V);
+  return (Inst && InductionCastsToIgnore.count(Inst));
+}
+
+bool LoopVectorizationLegality::isInductionVariable(const Value *V) {
+  return isInductionPhi(V) || isCastedInductionVariable(V);
+}
+
+bool LoopVectorizationLegality::isFirstOrderRecurrence(const PHINode *Phi) {
+  return FirstOrderRecurrences.count(Phi);
+}
+
+bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) {
+  return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
+}
+
+bool LoopVectorizationLegality::blockCanBePredicated(
+    BasicBlock *BB, SmallPtrSetImpl<Value *> &SafePtrs) {
+  const bool IsAnnotatedParallel = TheLoop->isAnnotatedParallel();
+
+  for (Instruction &I : *BB) {
+    // Check that we don't have a constant expression that can trap as operand.
+    for (Value *Operand : I.operands()) {
+      if (auto *C = dyn_cast<Constant>(Operand))
+        if (C->canTrap())
+          return false;
+    }
+    // We might be able to hoist the load.
+    if (I.mayReadFromMemory()) {
+      auto *LI = dyn_cast<LoadInst>(&I);
+      if (!LI)
+        return false;
+      if (!SafePtrs.count(LI->getPointerOperand())) {
+        // !llvm.mem.parallel_loop_access implies if-conversion safety.
+        // Otherwise, record that the load needs (real or emulated) masking
+        // and let the cost model decide.
+        if (!IsAnnotatedParallel)
+          MaskedOp.insert(LI);
+        continue;
+      }
+    }
+
+    if (I.mayWriteToMemory()) {
+      auto *SI = dyn_cast<StoreInst>(&I);
+      if (!SI)
+        return false;
+      // Predicated store requires some form of masking:
+      // 1) masked store HW instruction,
+      // 2) emulation via load-blend-store (only if safe and legal to do so,
+      //    be aware of the race conditions), or
+      // 3) element-by-element predicate check and scalar store.
+      MaskedOp.insert(SI);
+      continue;
+    }
+    if (I.mayThrow())
+      return false;
+  }
+
+  return true;
+}
+
+bool LoopVectorizationLegality::canVectorizeWithIfConvert() {
+  if (!EnableIfConversion) {
+    ORE->emit(createMissedAnalysis("IfConversionDisabled")
+              << "if-conversion is disabled");
+    return false;
+  }
+
+  assert(TheLoop->getNumBlocks() > 1 && "Single block loops are vectorizable");
+
+  // A list of pointers that we can safely read and write to.
+  SmallPtrSet<Value *, 8> SafePointers;
+
+  // Collect safe addresses.
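+  // Sketch of the idea (illustrative, not from the original sources): in
+  //
+  //   for (i = 0; i < n; i++) {
+  //     a[i] += 1;                // unconditional block
+  //     if (c[i]) b[i] = 0;       // predicated block
+  //   }
+  //
+  // &a[i] and &c[i] are dereferenced on every iteration, so they are safe to
+  // access after if-conversion; &b[i] is not collected here.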
+  for (BasicBlock *BB : TheLoop->blocks()) {
+    if (blockNeedsPredication(BB))
+      continue;
+
+    for (Instruction &I : *BB)
+      if (auto *Ptr = getLoadStorePointerOperand(&I))
+        SafePointers.insert(Ptr);
+  }
+
+  // Collect the blocks that need predication.
+  BasicBlock *Header = TheLoop->getHeader();
+  for (BasicBlock *BB : TheLoop->blocks()) {
+    // We don't support switch statements inside loops.
+    if (!isa<BranchInst>(BB->getTerminator())) {
+      ORE->emit(createMissedAnalysis("LoopContainsSwitch", BB->getTerminator())
+                << "loop contains a switch statement");
+      return false;
+    }
+
+    // We must be able to predicate all blocks that need to be predicated.
+    if (blockNeedsPredication(BB)) {
+      if (!blockCanBePredicated(BB, SafePointers)) {
+        ORE->emit(createMissedAnalysis("NoCFGForSelect", BB->getTerminator())
+                  << "control flow cannot be substituted for a select");
+        return false;
+      }
+    } else if (BB != Header && !canIfConvertPHINodes(BB)) {
+      ORE->emit(createMissedAnalysis("NoCFGForSelect", BB->getTerminator())
+                << "control flow cannot be substituted for a select");
+      return false;
+    }
+  }
+
+  // We can if-convert this loop.
+  return true;
+}
+
+// Helper function for canVectorizeLoopNestCFG.
+bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp,
+                                                    bool UseVPlanNativePath) {
+  assert((UseVPlanNativePath || Lp->empty()) &&
+         "VPlan-native path is not enabled.");
+
+  // TODO: ORE should be improved to show more accurate information when an
+  // outer loop can't be vectorized because a nested loop is not understood or
+  // legal. Something like: "outer_loop_location: loop not vectorized:
+  // (inner_loop_location) loop control flow is not understood by vectorizer".
+
+  // Store the result and return it at the end instead of exiting early, in
+  // case allowExtraAnalysis is used to report multiple reasons for not
+  // vectorizing.
+  bool Result = true;
+  bool DoExtraAnalysis = ORE->allowExtraAnalysis(DEBUG_TYPE);
+
+  // We must have a loop in canonical form. Loops with indirectbr in them
+  // cannot be canonicalized.
+  if (!Lp->getLoopPreheader()) {
+    LLVM_DEBUG(dbgs() << "LV: Loop doesn't have a legal pre-header.\n");
+    ORE->emit(createMissedAnalysis("CFGNotUnderstood")
+              << "loop control flow is not understood by vectorizer");
+    if (DoExtraAnalysis)
+      Result = false;
+    else
+      return false;
+  }
+
+  // We must have a single backedge.
+  if (Lp->getNumBackEdges() != 1) {
+    ORE->emit(createMissedAnalysis("CFGNotUnderstood")
+              << "loop control flow is not understood by vectorizer");
+    if (DoExtraAnalysis)
+      Result = false;
+    else
+      return false;
+  }
+
+  // We must have a single exiting block.
+  if (!Lp->getExitingBlock()) {
+    ORE->emit(createMissedAnalysis("CFGNotUnderstood")
+              << "loop control flow is not understood by vectorizer");
+    if (DoExtraAnalysis)
+      Result = false;
+    else
+      return false;
+  }
+
+  // We only handle bottom-tested loops, i.e. loops in which the condition is
+  // checked at the end of each iteration. With that we can assume that all
+  // instructions in the loop are executed the same number of times.
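+  // Illustrative note: a rotated 'do { ... } while (++i < n);' loop is
+  // bottom-tested (latch == exiting block); a loop whose exit test does not
+  // sit in the latch block is rejected by the check below.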
+ if (Lp->getExitingBlock() != Lp->getLoopLatch()) { + ORE->emit(createMissedAnalysis("CFGNotUnderstood") + << "loop control flow is not understood by vectorizer"); + if (DoExtraAnalysis) + Result = false; + else + return false; + } + + return Result; +} + +bool LoopVectorizationLegality::canVectorizeLoopNestCFG( + Loop *Lp, bool UseVPlanNativePath) { + // Store the result and return it at the end instead of exiting early, in case + // allowExtraAnalysis is used to report multiple reasons for not vectorizing. + bool Result = true; + bool DoExtraAnalysis = ORE->allowExtraAnalysis(DEBUG_TYPE); + if (!canVectorizeLoopCFG(Lp, UseVPlanNativePath)) { + if (DoExtraAnalysis) + Result = false; + else + return false; + } + + // Recursively check whether the loop control flow of nested loops is + // understood. + for (Loop *SubLp : *Lp) + if (!canVectorizeLoopNestCFG(SubLp, UseVPlanNativePath)) { + if (DoExtraAnalysis) + Result = false; + else + return false; + } + + return Result; +} + +bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) { + // Store the result and return it at the end instead of exiting early, in case + // allowExtraAnalysis is used to report multiple reasons for not vectorizing. + bool Result = true; + + bool DoExtraAnalysis = ORE->allowExtraAnalysis(DEBUG_TYPE); + // Check whether the loop-related control flow in the loop nest is expected by + // vectorizer. + if (!canVectorizeLoopNestCFG(TheLoop, UseVPlanNativePath)) { + if (DoExtraAnalysis) + Result = false; + else + return false; + } + + // We need to have a loop header. + LLVM_DEBUG(dbgs() << "LV: Found a loop: " << TheLoop->getHeader()->getName() + << '\n'); + + // Specific checks for outer loops. We skip the remaining legal checks at this + // point because they don't support outer loops. + if (!TheLoop->empty()) { + assert(UseVPlanNativePath && "VPlan-native path is not enabled."); + + if (!canVectorizeOuterLoop()) { + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: Unsupported outer loop.\n"); + // TODO: Implement DoExtraAnalysis when subsequent legal checks support + // outer loops. + return false; + } + + LLVM_DEBUG(dbgs() << "LV: We can vectorize this outer loop!\n"); + return Result; + } + + assert(TheLoop->empty() && "Inner loop expected."); + // Check if we can if-convert non-single-bb loops. + unsigned NumBlocks = TheLoop->getNumBlocks(); + if (NumBlocks != 1 && !canVectorizeWithIfConvert()) { + LLVM_DEBUG(dbgs() << "LV: Can't if-convert the loop.\n"); + if (DoExtraAnalysis) + Result = false; + else + return false; + } + + // Check if we can vectorize the instructions and CFG in this loop. + if (!canVectorizeInstrs()) { + LLVM_DEBUG(dbgs() << "LV: Can't vectorize the instructions or CFG\n"); + if (DoExtraAnalysis) + Result = false; + else + return false; + } + + // Go over each instruction and look at memory deps. + if (!canVectorizeMemory()) { + LLVM_DEBUG(dbgs() << "LV: Can't vectorize due to memory conflicts\n"); + if (DoExtraAnalysis) + Result = false; + else + return false; + } + + LLVM_DEBUG(dbgs() << "LV: We can vectorize this loop" + << (LAI->getRuntimePointerChecking()->Need + ? 
" (with a runtime bound check)" + : "") + << "!\n"); + + unsigned SCEVThreshold = VectorizeSCEVCheckThreshold; + if (Hints->getForce() == LoopVectorizeHints::FK_Enabled) + SCEVThreshold = PragmaVectorizeSCEVCheckThreshold; + + if (PSE.getUnionPredicate().getComplexity() > SCEVThreshold) { + ORE->emit(createMissedAnalysis("TooManySCEVRunTimeChecks") + << "Too many SCEV assumptions need to be made and checked " + << "at runtime"); + LLVM_DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n"); + if (DoExtraAnalysis) + Result = false; + else + return false; + } + + // Okay! We've done all the tests. If any have failed, return false. Otherwise + // we can vectorize, and at this point we don't have any other mem analysis + // which may limit our maximum vectorization factor, so just return true with + // no restrictions. + return Result; +} + +} // namespace llvm diff --git a/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/lib/Transforms/Vectorize/LoopVectorizationPlanner.h new file mode 100644 index 000000000000..2aa219064299 --- /dev/null +++ b/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -0,0 +1,282 @@ +//===- LoopVectorizationPlanner.h - Planner for LoopVectorization ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file provides a LoopVectorizationPlanner class. +/// InnerLoopVectorizer vectorizes loops which contain only one basic +/// LoopVectorizationPlanner - drives the vectorization process after having +/// passed Legality checks. +/// The planner builds and optimizes the Vectorization Plans which record the +/// decisions how to vectorize the given loop. In particular, represent the +/// control-flow of the vectorized version, the replication of instructions that +/// are to be scalarized, and interleave access groups. +/// +/// Also provides a VPlan-based builder utility analogous to IRBuilder. +/// It provides an instruction-level API for generating VPInstructions while +/// abstracting away the Recipe manipulation details. +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_VECTORIZE_LOOPVECTORIZATIONPLANNER_H +#define LLVM_TRANSFORMS_VECTORIZE_LOOPVECTORIZATIONPLANNER_H + +#include "VPlan.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" + +namespace llvm { + +/// VPlan-based builder utility analogous to IRBuilder. +class VPBuilder { +private: + VPBasicBlock *BB = nullptr; + VPBasicBlock::iterator InsertPt = VPBasicBlock::iterator(); + + VPInstruction *createInstruction(unsigned Opcode, + ArrayRef<VPValue *> Operands) { + VPInstruction *Instr = new VPInstruction(Opcode, Operands); + if (BB) + BB->insert(Instr, InsertPt); + return Instr; + } + + VPInstruction *createInstruction(unsigned Opcode, + std::initializer_list<VPValue *> Operands) { + return createInstruction(Opcode, ArrayRef<VPValue *>(Operands)); + } + +public: + VPBuilder() {} + + /// Clear the insertion point: created instructions will not be inserted into + /// a block. + void clearInsertionPoint() { + BB = nullptr; + InsertPt = VPBasicBlock::iterator(); + } + + VPBasicBlock *getInsertBlock() const { return BB; } + VPBasicBlock::iterator getInsertPoint() const { return InsertPt; } + + /// InsertPoint - A saved insertion point. 
+ class VPInsertPoint { + VPBasicBlock *Block = nullptr; + VPBasicBlock::iterator Point; + + public: + /// Creates a new insertion point which doesn't point to anything. + VPInsertPoint() = default; + + /// Creates a new insertion point at the given location. + VPInsertPoint(VPBasicBlock *InsertBlock, VPBasicBlock::iterator InsertPoint) + : Block(InsertBlock), Point(InsertPoint) {} + + /// Returns true if this insert point is set. + bool isSet() const { return Block != nullptr; } + + VPBasicBlock *getBlock() const { return Block; } + VPBasicBlock::iterator getPoint() const { return Point; } + }; + + /// Sets the current insert point to a previously-saved location. + void restoreIP(VPInsertPoint IP) { + if (IP.isSet()) + setInsertPoint(IP.getBlock(), IP.getPoint()); + else + clearInsertionPoint(); + } + + /// This specifies that created VPInstructions should be appended to the end + /// of the specified block. + void setInsertPoint(VPBasicBlock *TheBB) { + assert(TheBB && "Attempting to set a null insert point"); + BB = TheBB; + InsertPt = BB->end(); + } + + /// This specifies that created instructions should be inserted at the + /// specified point. + void setInsertPoint(VPBasicBlock *TheBB, VPBasicBlock::iterator IP) { + BB = TheBB; + InsertPt = IP; + } + + /// Insert and return the specified instruction. + VPInstruction *insert(VPInstruction *I) const { + BB->insert(I, InsertPt); + return I; + } + + /// Create an N-ary operation with \p Opcode, \p Operands and set \p Inst as + /// its underlying Instruction. + VPValue *createNaryOp(unsigned Opcode, ArrayRef<VPValue *> Operands, + Instruction *Inst = nullptr) { + VPInstruction *NewVPInst = createInstruction(Opcode, Operands); + NewVPInst->setUnderlyingValue(Inst); + return NewVPInst; + } + VPValue *createNaryOp(unsigned Opcode, + std::initializer_list<VPValue *> Operands, + Instruction *Inst = nullptr) { + return createNaryOp(Opcode, ArrayRef<VPValue *>(Operands), Inst); + } + + VPValue *createNot(VPValue *Operand) { + return createInstruction(VPInstruction::Not, {Operand}); + } + + VPValue *createAnd(VPValue *LHS, VPValue *RHS) { + return createInstruction(Instruction::BinaryOps::And, {LHS, RHS}); + } + + VPValue *createOr(VPValue *LHS, VPValue *RHS) { + return createInstruction(Instruction::BinaryOps::Or, {LHS, RHS}); + } + + //===--------------------------------------------------------------------===// + // RAII helpers. + //===--------------------------------------------------------------------===// + + /// RAII object that stores the current insertion point and restores it when + /// the object is destroyed. + class InsertPointGuard { + VPBuilder &Builder; + VPBasicBlock *Block; + VPBasicBlock::iterator Point; + + public: + InsertPointGuard(VPBuilder &B) + : Builder(B), Block(B.getInsertBlock()), Point(B.getInsertPoint()) {} + + InsertPointGuard(const InsertPointGuard &) = delete; + InsertPointGuard &operator=(const InsertPointGuard &) = delete; + + ~InsertPointGuard() { Builder.restoreIP(VPInsertPoint(Block, Point)); } + }; +}; + +/// TODO: The following VectorizationFactor was pulled out of +/// LoopVectorizationCostModel class. LV also deals with +/// VectorizerParams::VectorizationFactor and VectorizationCostTy. +/// We need to streamline them. + +/// Information about vectorization costs +struct VectorizationFactor { + // Vector width with best cost + unsigned Width; + // Cost of the loop with that width + unsigned Cost; +}; + +/// Planner drives the vectorization process after having passed +/// Legality checks. 
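+///
+/// A typical driver sequence, sketched and simplified from the callers in
+/// LoopVectorize.cpp (names abbreviated, error handling omitted):
+///
+///   LoopVectorizationPlanner LVP(L, LI, TLI, TTI, &LVL, CM);
+///   VectorizationFactor VF = LVP.plan(OptForSize, UserVF); // pick best VF
+///   LVP.setBestPlan(VF.Width, IC);                         // keep one VPlan
+///   InnerLoopVectorizer LB(/*...*/);                       // codegen helper
+///   LVP.executePlan(LB, DT);                               // emit vector IR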
+class LoopVectorizationPlanner {
+  /// The loop that we evaluate.
+  Loop *OrigLoop;
+
+  /// Loop Info analysis.
+  LoopInfo *LI;
+
+  /// Target Library Info.
+  const TargetLibraryInfo *TLI;
+
+  /// Target Transform Info.
+  const TargetTransformInfo *TTI;
+
+  /// The legality analysis.
+  LoopVectorizationLegality *Legal;
+
+  /// The profitability analysis.
+  LoopVectorizationCostModel &CM;
+
+  using VPlanPtr = std::unique_ptr<VPlan>;
+
+  SmallVector<VPlanPtr, 4> VPlans;
+
+  /// This class is used to enable the VPlan to invoke a method of ILV. This is
+  /// needed until the method is refactored out of ILV and becomes reusable.
+  struct VPCallbackILV : public VPCallback {
+    InnerLoopVectorizer &ILV;
+
+    VPCallbackILV(InnerLoopVectorizer &ILV) : ILV(ILV) {}
+
+    Value *getOrCreateVectorValues(Value *V, unsigned Part) override;
+  };
+
+  /// A builder used to construct the current plan.
+  VPBuilder Builder;
+
+  unsigned BestVF = 0;
+  unsigned BestUF = 0;
+
+public:
+  LoopVectorizationPlanner(Loop *L, LoopInfo *LI, const TargetLibraryInfo *TLI,
+                           const TargetTransformInfo *TTI,
+                           LoopVectorizationLegality *Legal,
+                           LoopVectorizationCostModel &CM)
+      : OrigLoop(L), LI(LI), TLI(TLI), TTI(TTI), Legal(Legal), CM(CM) {}
+
+  /// Plan how to best vectorize, return the best VF and its cost.
+  VectorizationFactor plan(bool OptForSize, unsigned UserVF);
+
+  /// Use the VPlan-native path to plan how to best vectorize, return the best
+  /// VF and its cost.
+  VectorizationFactor planInVPlanNativePath(bool OptForSize, unsigned UserVF);
+
+  /// Finalize the best decision and dispose of all other VPlans.
+  void setBestPlan(unsigned VF, unsigned UF);
+
+  /// Generate the IR code for the body of the vectorized loop according to the
+  /// best selected VPlan.
+  void executePlan(InnerLoopVectorizer &LB, DominatorTree *DT);
+
+  void printPlans(raw_ostream &O) {
+    for (const auto &Plan : VPlans)
+      O << *Plan;
+  }
+
+  /// Test a \p Predicate on a \p Range of VF's. Return the value of applying
+  /// \p Predicate on Range.Start, possibly decreasing Range.End such that the
+  /// returned value holds for the entire \p Range.
+  static bool
+  getDecisionAndClampRange(const std::function<bool(unsigned)> &Predicate,
+                           VFRange &Range);
+
+protected:
+  /// Collect the instructions from the original loop that would be trivially
+  /// dead in the vectorized loop if generated.
+  void collectTriviallyDeadInstructions(
+      SmallPtrSetImpl<Instruction *> &DeadInstructions);
+
+  /// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive,
+  /// according to the information gathered by Legal when it checked if it is
+  /// legal to vectorize the loop.
+  void buildVPlans(unsigned MinVF, unsigned MaxVF);
+
+private:
+  /// Build a VPlan according to the information gathered by Legal. \return a
+  /// VPlan for vectorization factors \p Range.Start and up to \p Range.End
+  /// exclusive, possibly decreasing \p Range.End.
+  VPlanPtr buildVPlan(VFRange &Range);
+
+  /// Build a VPlan using VPRecipes according to the information gathered by
+  /// Legal. This method is only used for the legacy inner loop vectorizer.
+  VPlanPtr
+  buildVPlanWithVPRecipes(VFRange &Range, SmallPtrSetImpl<Value *> &NeedDef,
+                          SmallPtrSetImpl<Instruction *> &DeadInstructions);
+
+  /// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive,
+  /// according to the information gathered by Legal when it checked if it is
+  /// legal to vectorize the loop. This method creates VPlans using VPRecipes.
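+  /// For example (an illustrative sketch): with MinVF=2 and MaxVF=8, a single
+  /// plan may cover the whole range [2,9), or getDecisionAndClampRange() may
+  /// clamp Range.End so that several plans cover narrower subranges instead.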
+ void buildVPlansWithVPRecipes(unsigned MinVF, unsigned MaxVF); +}; + +} // namespace llvm + +#endif // LLVM_TRANSFORMS_VECTORIZE_LOOPVECTORIZATIONPLANNER_H diff --git a/lib/Transforms/Vectorize/LoopVectorize.cpp b/lib/Transforms/Vectorize/LoopVectorize.cpp index 52f32cda2609..3c693f5d5ee0 100644 --- a/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -26,6 +26,14 @@ // of vectorization. It decides on the optimal vector width, which // can be one, if vectorization is not profitable. // +// There is a development effort going on to migrate loop vectorizer to the +// VPlan infrastructure and to introduce outer loop vectorization support (see +// docs/Proposal/VectorizationPlan.rst and +// http://lists.llvm.org/pipermail/llvm-dev/2017-December/119523.html). For this +// purpose, we temporarily introduced the VPlan-native vectorization path: an +// alternative vectorization path that is natively implemented on top of the +// VPlan infrastructure. See EnableVPlanNativePath for enabling. +// //===----------------------------------------------------------------------===// // // The reduction-variable vectorization is based on the paper: @@ -47,8 +55,9 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Vectorize/LoopVectorize.h" -#include "VPlan.h" -#include "VPlanBuilder.h" +#include "LoopVectorizationPlanner.h" +#include "VPRecipeBuilder.h" +#include "VPlanHCFGBuilder.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -57,11 +66,9 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" -#include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" @@ -70,6 +77,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/GlobalsModRef.h" @@ -124,6 +132,7 @@ #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/LoopVersioning.h" +#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h" #include <algorithm> #include <cassert> #include <cstdint> @@ -145,10 +154,6 @@ using namespace llvm; STATISTIC(LoopsVectorized, "Number of loops vectorized"); STATISTIC(LoopsAnalyzed, "Number of loops analyzed for vectorization"); -static cl::opt<bool> - EnableIfConversion("enable-if-conversion", cl::init(true), cl::Hidden, - cl::desc("Enable if-conversion during vectorization.")); - /// Loops with a known constant trip count below this number are vectorized only /// if no scalar iteration overheads are incurred. static cl::opt<unsigned> TinyTripCountVectorThreshold( @@ -184,9 +189,6 @@ static cl::opt<unsigned> ForceTargetNumVectorRegs( "force-target-num-vector-regs", cl::init(0), cl::Hidden, cl::desc("A flag that overrides the target's number of vector registers.")); -/// Maximum vectorization interleave count. 
-static const unsigned MaxInterleaveFactor = 16;
-
 static cl::opt<unsigned> ForceTargetMaxScalarInterleaveFactor(
     "force-target-max-scalar-interleave", cl::init(0), cl::Hidden,
     cl::desc("A flag that overrides the target's max interleave factor for "
@@ -209,7 +211,7 @@ static cl::opt<unsigned> SmallLoopCost(
         "The cost of a loop that is considered 'small' by the interleaver."));
 
 static cl::opt<bool> LoopVectorizeWithBlockFrequency(
-    "loop-vectorize-with-block-frequency", cl::init(false), cl::Hidden,
+    "loop-vectorize-with-block-frequency", cl::init(true), cl::Hidden,
     cl::desc("Enable the use of the block frequency analysis to access PGO "
              "heuristics minimizing code growth in cold regions and being more "
             "aggressive in hot regions."));
@@ -238,71 +240,21 @@ static cl::opt<unsigned> MaxNestedScalarReductionIC(
     cl::desc("The maximum interleave count to use when interleaving a scalar "
              "reduction in a nested loop."));
 
-static cl::opt<unsigned> PragmaVectorizeMemoryCheckThreshold(
-    "pragma-vectorize-memory-check-threshold", cl::init(128), cl::Hidden,
-    cl::desc("The maximum allowed number of runtime memory checks with a "
-             "vectorize(enable) pragma."));
-
-static cl::opt<unsigned> VectorizeSCEVCheckThreshold(
-    "vectorize-scev-check-threshold", cl::init(16), cl::Hidden,
-    cl::desc("The maximum number of SCEV checks allowed."));
-
-static cl::opt<unsigned> PragmaVectorizeSCEVCheckThreshold(
-    "pragma-vectorize-scev-check-threshold", cl::init(128), cl::Hidden,
-    cl::desc("The maximum number of SCEV checks allowed with a "
-             "vectorize(enable) pragma"));
-
-/// Create an analysis remark that explains why vectorization failed
-///
-/// \p PassName is the name of the pass (e.g. can be AlwaysPrint). \p
-/// RemarkName is the identifier for the remark. If \p I is passed it is an
-/// instruction that prevents vectorization. Otherwise \p TheLoop is used for
-/// the location of the remark. \return the remark object that can be
-/// streamed to.
-static OptimizationRemarkAnalysis
-createMissedAnalysis(const char *PassName, StringRef RemarkName, Loop *TheLoop,
-                     Instruction *I = nullptr) {
-  Value *CodeRegion = TheLoop->getHeader();
-  DebugLoc DL = TheLoop->getStartLoc();
-
-  if (I) {
-    CodeRegion = I->getParent();
-    // If there is no debug location attached to the instruction, revert back to
-    // using the loop's.
-    if (I->getDebugLoc())
-      DL = I->getDebugLoc();
-  }
-
-  OptimizationRemarkAnalysis R(PassName, RemarkName, DL, CodeRegion);
-  R << "loop not vectorized: ";
-  return R;
-}
-
-namespace {
-
-class LoopVectorizationLegality;
-class LoopVectorizationCostModel;
-class LoopVectorizationRequirements;
-
-} // end anonymous namespace
-
-/// Returns true if the given loop body has a cycle, excluding the loop
-/// itself.
-static bool hasCyclesInLoopBody(const Loop &L) {
-  if (!L.empty())
-    return true;
-
-  for (const auto &SCC :
-       make_range(scc_iterator<Loop, LoopBodyTraits>::begin(L),
-                  scc_iterator<Loop, LoopBodyTraits>::end(L))) {
-    if (SCC.size() > 1) {
-      DEBUG(dbgs() << "LVL: Detected a cycle in the loop body:\n");
-      DEBUG(L.dump());
-      return true;
-    }
-  }
-  return false;
-}
+static cl::opt<bool> EnableVPlanNativePath(
+    "enable-vplan-native-path", cl::init(false), cl::Hidden,
+    cl::desc("Enable VPlan-native vectorization path with "
+             "support for outer loop vectorization."));
+
+// This flag enables the stress testing of the VPlan H-CFG construction in the
+// VPlan-native vectorization path. It must be used in conjunction with
-vplan-verify-hcfg can also be used to enable the +// verification of the H-CFGs built. +static cl::opt<bool> VPlanBuildStressTest( + "vplan-build-stress-test", cl::init(false), cl::Hidden, + cl::desc( + "Build VPlan for every supported loop nest in the function and bail " + "out right after the build (stress test the VPlan H-CFG construction " + "in the VPlan-native vectorization path).")); /// A helper function for converting Scalar types to vector types. /// If the incoming type is void, we return void. If the VF is 1, we return @@ -317,16 +269,6 @@ static Type *ToVectorTy(Type *Scalar, unsigned VF) { // in the project. They can be effectively organized in a common Load/Store // utilities unit. -/// A helper function that returns the pointer operand of a load or store -/// instruction. -static Value *getPointerOperand(Value *I) { - if (auto *LI = dyn_cast<LoadInst>(I)) - return LI->getPointerOperand(); - if (auto *SI = dyn_cast<StoreInst>(I)) - return SI->getPointerOperand(); - return nullptr; -} - /// A helper function that returns the type of loaded or stored value. static Type *getMemInstValueType(Value *I) { assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && @@ -373,7 +315,7 @@ static bool hasIrregularType(Type *Ty, const DataLayout &DL, unsigned VF) { /// A helper function that returns the reciprocal of the block probability of /// predicated blocks. If we return X, we are assuming the predicated block -/// will execute once for for every X iterations of the loop header. +/// will execute once for every X iterations of the loop header. /// /// TODO: We should use actual block probability here, if available. Currently, /// we always assume predicated blocks have a 50% chance of executing. @@ -502,7 +444,7 @@ public: void vectorizeMemoryInstruction(Instruction *Instr, VectorParts *BlockInMask = nullptr); - /// \brief Set the debug location in the builder using the debug location in + /// Set the debug location in the builder using the debug location in /// the instruction. void setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr); @@ -538,7 +480,7 @@ protected: /// vectorizing this phi node. void fixReduction(PHINode *Phi); - /// \brief The Loop exit block may have single value PHI nodes with some + /// The Loop exit block may have single value PHI nodes with some /// incoming value. While vectorizing we only handled real values /// that were defined inside the loop and we should have one value for /// each predecessor of its parent basic block. See PR14725. @@ -573,9 +515,9 @@ protected: /// Compute scalar induction steps. \p ScalarIV is the scalar induction /// variable on which to base the steps, \p Step is the size of the step, and /// \p EntryVal is the value from the original loop that maps to the steps. - /// Note that \p EntryVal doesn't have to be an induction variable (e.g., it - /// can be a truncate instruction). - void buildScalarSteps(Value *ScalarIV, Value *Step, Value *EntryVal, + /// Note that \p EntryVal doesn't have to be an induction variable - it + /// can also be a truncate instruction. + void buildScalarSteps(Value *ScalarIV, Value *Step, Instruction *EntryVal, const InductionDescriptor &ID); /// Create a vector induction phi node based on an existing scalar one. \p @@ -602,10 +544,20 @@ protected: /// vector loop for both the Phi and the cast. /// If \p VectorLoopValue is a scalarized value, \p Lane is also specified, /// Otherwise, \p VectorLoopValue is a widened/vectorized value. 
-  void recordVectorLoopValueForInductionCast (const InductionDescriptor &ID,
-                                              Value *VectorLoopValue,
-                                              unsigned Part,
-                                              unsigned Lane = UINT_MAX);
+  ///
+  /// \p EntryVal is the value from the original loop that maps to the vector
+  /// phi node and is used to distinguish what is the IV currently being
+  /// processed - original one (if \p EntryVal is a phi corresponding to the
+  /// original IV) or the "newly-created" one based on the proof mentioned above
+  /// (see also buildScalarSteps() and createVectorIntOrFPInductionPHI()). In the
+  /// latter case \p EntryVal is a TruncInst and we must not record anything for
+  /// that IV, but it's error-prone to expect callers of this routine to care
+  /// about that, hence this explicit parameter.
+  void recordVectorLoopValueForInductionCast(const InductionDescriptor &ID,
+                                             const Instruction *EntryVal,
+                                             Value *VectorLoopValue,
+                                             unsigned Part,
+                                             unsigned Lane = UINT_MAX);
 
   /// Generate a shuffle sequence that will reverse the vector Vec.
   virtual Value *reverseVector(Value *Vec);
@@ -646,7 +598,7 @@ protected:
   /// loop.
   void addMetadata(Instruction *To, Instruction *From);
 
-  /// \brief Similar to the previous function but it adds the metadata to a
+  /// Similar to the previous function but it adds the metadata to a
   /// vector of instructions.
   void addMetadata(ArrayRef<Value *> To, Instruction *From);
 
@@ -679,7 +631,7 @@ protected:
   /// Interface to emit optimization remarks.
   OptimizationRemarkEmitter *ORE;
 
-  /// \brief LoopVersioning. It's only set up (non-null) if memchecks were
+  /// LoopVersioning. It's only set up (non-null) if memchecks were
   /// used.
   ///
   /// This is currently only used to add no-alias metadata based on the
@@ -777,7 +729,7 @@ private:
 
 } // end namespace llvm
 
-/// \brief Look for a meaningful debug location on the instruction or it's
+/// Look for a meaningful debug location on the instruction or its
 /// operands.
 static Instruction *getDebugLocFromInstOrOperands(Instruction *I) {
   if (!I)
@@ -849,7 +801,7 @@ void InnerLoopVectorizer::addMetadata(ArrayRef<Value *> To,
 
 namespace llvm {
 
-/// \brief The group of interleaved loads/stores sharing the same stride and
+/// The group of interleaved loads/stores sharing the same stride and
 /// close to each other.
 ///
 /// Each member in this group has an index starting from 0, and the largest
@@ -893,7 +845,7 @@ public:
   unsigned getAlignment() const { return Align; }
   unsigned getNumMembers() const { return Members.size(); }
 
-  /// \brief Try to insert a new member \p Instr with index \p Index and
+  /// Try to insert a new member \p Instr with index \p Index and
   /// alignment \p NewAlign. The index is related to the leader and it could be
  /// negative if it is the new leader.
   ///
@@ -927,7 +879,7 @@ public:
     return true;
   }
 
-  /// \brief Get the member with the given index \p Index
+  /// Get the member with the given index \p Index
   ///
   /// \returns nullptr if contains no such member.
   Instruction *getMember(unsigned Index) const {
@@ -938,7 +890,7 @@ public:
     return Members.find(Key)->second;
   }
 
-  /// \brief Get the index for the given member. Unlike the key in the member
+  /// Get the index for the given member. Unlike the key in the member
   /// map, the index starts from 0.
   unsigned getIndex(Instruction *Instr) const {
     for (auto I : Members)
@@ -989,7 +941,7 @@ private:
 
 namespace {
 
-/// \brief Drive the analysis of interleaved memory accesses in the loop.
+/// Drive the analysis of interleaved memory accesses in the loop.
/// /// Use this class to analyze interleaved accesses only when we can vectorize /// a loop. Otherwise it's meaningless to do analysis as the vectorization @@ -1000,11 +952,12 @@ namespace { class InterleavedAccessInfo { public: InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L, - DominatorTree *DT, LoopInfo *LI) - : PSE(PSE), TheLoop(L), DT(DT), LI(LI) {} + DominatorTree *DT, LoopInfo *LI, + const LoopAccessInfo *LAI) + : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {} ~InterleavedAccessInfo() { - SmallSet<InterleaveGroup *, 4> DelSet; + SmallPtrSet<InterleaveGroup *, 4> DelSet; // Avoid releasing a pointer twice. for (auto &I : InterleaveGroupMap) DelSet.insert(I.second); @@ -1012,16 +965,16 @@ public: delete Ptr; } - /// \brief Analyze the interleaved accesses and collect them in interleave + /// Analyze the interleaved accesses and collect them in interleave /// groups. Substitute symbolic strides using \p Strides. - void analyzeInterleaving(const ValueToValueMap &Strides); + void analyzeInterleaving(); - /// \brief Check if \p Instr belongs to any interleave group. + /// Check if \p Instr belongs to any interleave group. bool isInterleaved(Instruction *Instr) const { return InterleaveGroupMap.count(Instr); } - /// \brief Get the interleave group that \p Instr belongs to. + /// Get the interleave group that \p Instr belongs to. /// /// \returns nullptr if doesn't have such group. InterleaveGroup *getInterleaveGroup(Instruction *Instr) const { @@ -1030,13 +983,10 @@ public: return nullptr; } - /// \brief Returns true if an interleaved group that may access memory + /// Returns true if an interleaved group that may access memory /// out-of-bounds requires a scalar epilogue iteration for correctness. bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; } - /// \brief Initialize the LoopAccessInfo used for dependence checking. - void setLAI(const LoopAccessInfo *Info) { LAI = Info; } - private: /// A wrapper around ScalarEvolution, used to add runtime SCEV checks. /// Simplifies SCEV expressions in the context of existing SCEV assumptions. @@ -1047,7 +997,7 @@ private: Loop *TheLoop; DominatorTree *DT; LoopInfo *LI; - const LoopAccessInfo *LAI = nullptr; + const LoopAccessInfo *LAI; /// True if the loop may contain non-reversed interleaved groups with /// out-of-bounds accesses. We ensure we don't speculatively access memory @@ -1061,7 +1011,7 @@ private: /// access to a set of dependent sink accesses. DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences; - /// \brief The descriptor for a strided memory access. + /// The descriptor for a strided memory access. struct StrideDescriptor { StrideDescriptor() = default; StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size, @@ -1081,10 +1031,10 @@ private: unsigned Align = 0; }; - /// \brief A type for holding instructions and their stride descriptors. + /// A type for holding instructions and their stride descriptors. using StrideEntry = std::pair<Instruction *, StrideDescriptor>; - /// \brief Create a new interleave group with the given instruction \p Instr, + /// Create a new interleave group with the given instruction \p Instr, /// stride \p Stride and alignment \p Align. /// /// \returns the newly created interleave group. @@ -1096,7 +1046,7 @@ private: return InterleaveGroupMap[Instr]; } - /// \brief Release the group and remove all the relationships. + /// Release the group and remove all the relationships. 
void releaseGroup(InterleaveGroup *Group) { for (unsigned i = 0; i < Group->getFactor(); i++) if (Instruction *Member = Group->getMember(i)) @@ -1105,28 +1055,28 @@ private: delete Group; } - /// \brief Collect all the accesses with a constant stride in program order. + /// Collect all the accesses with a constant stride in program order. void collectConstStrideAccesses( MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo, const ValueToValueMap &Strides); - /// \brief Returns true if \p Stride is allowed in an interleaved group. + /// Returns true if \p Stride is allowed in an interleaved group. static bool isStrided(int Stride) { unsigned Factor = std::abs(Stride); return Factor >= 2 && Factor <= MaxInterleaveGroupFactor; } - /// \brief Returns true if \p BB is a predicated block. + /// Returns true if \p BB is a predicated block. bool isPredicated(BasicBlock *BB) const { return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT); } - /// \brief Returns true if LoopAccessInfo can be used for dependence queries. + /// Returns true if LoopAccessInfo can be used for dependence queries. bool areDependencesValid() const { return LAI && LAI->getDepChecker().getDependences(); } - /// \brief Returns true if memory accesses \p A and \p B can be reordered, if + /// Returns true if memory accesses \p A and \p B can be reordered, if /// necessary, when constructing interleaved groups. /// /// \p A must precede \p B in program order. We return false if reordering is @@ -1174,7 +1124,7 @@ private: return !Dependences.count(Src) || !Dependences.lookup(Src).count(Sink); } - /// \brief Collect the dependences from LoopAccessInfo. + /// Collect the dependences from LoopAccessInfo. /// /// We process the dependences once during the interleaved access analysis to /// enable constant-time dependence queries. @@ -1187,315 +1137,6 @@ private: } }; -/// Utility class for getting and setting loop vectorizer hints in the form -/// of loop metadata. -/// This class keeps a number of loop annotations locally (as member variables) -/// and can, upon request, write them back as metadata on the loop. It will -/// initially scan the loop for existing metadata, and will update the local -/// values based on information in the loop. -/// We cannot write all values to metadata, as the mere presence of some info, -/// for example 'force', means a decision has been made. So, we need to be -/// careful NOT to add them if the user hasn't specifically asked so. -class LoopVectorizeHints { - enum HintKind { HK_WIDTH, HK_UNROLL, HK_FORCE, HK_ISVECTORIZED }; - - /// Hint - associates name and validation with the hint value. - struct Hint { - const char *Name; - unsigned Value; // This may have to change for non-numeric values. - HintKind Kind; - - Hint(const char *Name, unsigned Value, HintKind Kind) - : Name(Name), Value(Value), Kind(Kind) {} - - bool validate(unsigned Val) { - switch (Kind) { - case HK_WIDTH: - return isPowerOf2_32(Val) && Val <= VectorizerParams::MaxVectorWidth; - case HK_UNROLL: - return isPowerOf2_32(Val) && Val <= MaxInterleaveFactor; - case HK_FORCE: - return (Val <= 1); - case HK_ISVECTORIZED: - return (Val==0 || Val==1); - } - return false; - } - }; - - /// Vectorization width. - Hint Width; - - /// Vectorization interleave factor. - Hint Interleave; - - /// Vectorization forced - Hint Force; - - /// Already Vectorized - Hint IsVectorized; - - /// Return the loop metadata prefix. - static StringRef Prefix() { return "llvm.loop."; } - - /// True if there is any unsafe math in the loop. 
- bool PotentiallyUnsafe = false; - -public: - enum ForceKind { - FK_Undefined = -1, ///< Not selected. - FK_Disabled = 0, ///< Forcing disabled. - FK_Enabled = 1, ///< Forcing enabled. - }; - - LoopVectorizeHints(const Loop *L, bool DisableInterleaving, - OptimizationRemarkEmitter &ORE) - : Width("vectorize.width", VectorizerParams::VectorizationFactor, - HK_WIDTH), - Interleave("interleave.count", DisableInterleaving, HK_UNROLL), - Force("vectorize.enable", FK_Undefined, HK_FORCE), - IsVectorized("isvectorized", 0, HK_ISVECTORIZED), TheLoop(L), ORE(ORE) { - // Populate values with existing loop metadata. - getHintsFromMetadata(); - - // force-vector-interleave overrides DisableInterleaving. - if (VectorizerParams::isInterleaveForced()) - Interleave.Value = VectorizerParams::VectorizationInterleave; - - if (IsVectorized.Value != 1) - // If the vectorization width and interleaving count are both 1 then - // consider the loop to have been already vectorized because there's - // nothing more that we can do. - IsVectorized.Value = Width.Value == 1 && Interleave.Value == 1; - DEBUG(if (DisableInterleaving && Interleave.Value == 1) dbgs() - << "LV: Interleaving disabled by the pass manager\n"); - } - - /// Mark the loop L as already vectorized by setting the width to 1. - void setAlreadyVectorized() { - IsVectorized.Value = 1; - Hint Hints[] = {IsVectorized}; - writeHintsToMetadata(Hints); - } - - bool allowVectorization(Function *F, Loop *L, bool AlwaysVectorize) const { - if (getForce() == LoopVectorizeHints::FK_Disabled) { - DEBUG(dbgs() << "LV: Not vectorizing: #pragma vectorize disable.\n"); - emitRemarkWithHints(); - return false; - } - - if (!AlwaysVectorize && getForce() != LoopVectorizeHints::FK_Enabled) { - DEBUG(dbgs() << "LV: Not vectorizing: No #pragma vectorize enable.\n"); - emitRemarkWithHints(); - return false; - } - - if (getIsVectorized() == 1) { - DEBUG(dbgs() << "LV: Not vectorizing: Disabled/already vectorized.\n"); - // FIXME: Add interleave.disable metadata. This will allow - // vectorize.disable to be used without disabling the pass and errors - // to differentiate between disabled vectorization and a width of 1. - ORE.emit([&]() { - return OptimizationRemarkAnalysis(vectorizeAnalysisPassName(), - "AllDisabled", L->getStartLoc(), - L->getHeader()) - << "loop not vectorized: vectorization and interleaving are " - "explicitly disabled, or the loop has already been " - "vectorized"; - }); - return false; - } - - return true; - } - - /// Dumps all the hint information. 
- void emitRemarkWithHints() const { - using namespace ore; - - ORE.emit([&]() { - if (Force.Value == LoopVectorizeHints::FK_Disabled) - return OptimizationRemarkMissed(LV_NAME, "MissedExplicitlyDisabled", - TheLoop->getStartLoc(), - TheLoop->getHeader()) - << "loop not vectorized: vectorization is explicitly disabled"; - else { - OptimizationRemarkMissed R(LV_NAME, "MissedDetails", - TheLoop->getStartLoc(), - TheLoop->getHeader()); - R << "loop not vectorized"; - if (Force.Value == LoopVectorizeHints::FK_Enabled) { - R << " (Force=" << NV("Force", true); - if (Width.Value != 0) - R << ", Vector Width=" << NV("VectorWidth", Width.Value); - if (Interleave.Value != 0) - R << ", Interleave Count=" - << NV("InterleaveCount", Interleave.Value); - R << ")"; - } - return R; - } - }); - } - - unsigned getWidth() const { return Width.Value; } - unsigned getInterleave() const { return Interleave.Value; } - unsigned getIsVectorized() const { return IsVectorized.Value; } - enum ForceKind getForce() const { return (ForceKind)Force.Value; } - - /// \brief If hints are provided that force vectorization, use the AlwaysPrint - /// pass name to force the frontend to print the diagnostic. - const char *vectorizeAnalysisPassName() const { - if (getWidth() == 1) - return LV_NAME; - if (getForce() == LoopVectorizeHints::FK_Disabled) - return LV_NAME; - if (getForce() == LoopVectorizeHints::FK_Undefined && getWidth() == 0) - return LV_NAME; - return OptimizationRemarkAnalysis::AlwaysPrint; - } - - bool allowReordering() const { - // When enabling loop hints are provided we allow the vectorizer to change - // the order of operations that is given by the scalar loop. This is not - // enabled by default because can be unsafe or inefficient. For example, - // reordering floating-point operations will change the way round-off - // error accumulates in the loop. - return getForce() == LoopVectorizeHints::FK_Enabled || getWidth() > 1; - } - - bool isPotentiallyUnsafe() const { - // Avoid FP vectorization if the target is unsure about proper support. - // This may be related to the SIMD unit in the target not handling - // IEEE 754 FP ops properly, or bad single-to-double promotions. - // Otherwise, a sequence of vectorized loops, even without reduction, - // could lead to different end results on the destination vectors. - return getForce() != LoopVectorizeHints::FK_Enabled && PotentiallyUnsafe; - } - - void setPotentiallyUnsafe() { PotentiallyUnsafe = true; } - -private: - /// Find hints specified in the loop metadata and update local values. - void getHintsFromMetadata() { - MDNode *LoopID = TheLoop->getLoopID(); - if (!LoopID) - return; - - // First operand should refer to the loop id itself. - assert(LoopID->getNumOperands() > 0 && "requires at least one operand"); - assert(LoopID->getOperand(0) == LoopID && "invalid loop id"); - - for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { - const MDString *S = nullptr; - SmallVector<Metadata *, 4> Args; - - // The expected hint is either a MDString or a MDNode with the first - // operand a MDString. 
- if (const MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i))) { - if (!MD || MD->getNumOperands() == 0) - continue; - S = dyn_cast<MDString>(MD->getOperand(0)); - for (unsigned i = 1, ie = MD->getNumOperands(); i < ie; ++i) - Args.push_back(MD->getOperand(i)); - } else { - S = dyn_cast<MDString>(LoopID->getOperand(i)); - assert(Args.size() == 0 && "too many arguments for MDString"); - } - - if (!S) - continue; - - // Check if the hint starts with the loop metadata prefix. - StringRef Name = S->getString(); - if (Args.size() == 1) - setHint(Name, Args[0]); - } - } - - /// Checks string hint with one operand and set value if valid. - void setHint(StringRef Name, Metadata *Arg) { - if (!Name.startswith(Prefix())) - return; - Name = Name.substr(Prefix().size(), StringRef::npos); - - const ConstantInt *C = mdconst::dyn_extract<ConstantInt>(Arg); - if (!C) - return; - unsigned Val = C->getZExtValue(); - - Hint *Hints[] = {&Width, &Interleave, &Force, &IsVectorized}; - for (auto H : Hints) { - if (Name == H->Name) { - if (H->validate(Val)) - H->Value = Val; - else - DEBUG(dbgs() << "LV: ignoring invalid hint '" << Name << "'\n"); - break; - } - } - } - - /// Create a new hint from name / value pair. - MDNode *createHintMetadata(StringRef Name, unsigned V) const { - LLVMContext &Context = TheLoop->getHeader()->getContext(); - Metadata *MDs[] = {MDString::get(Context, Name), - ConstantAsMetadata::get( - ConstantInt::get(Type::getInt32Ty(Context), V))}; - return MDNode::get(Context, MDs); - } - - /// Matches metadata with hint name. - bool matchesHintMetadataName(MDNode *Node, ArrayRef<Hint> HintTypes) { - MDString *Name = dyn_cast<MDString>(Node->getOperand(0)); - if (!Name) - return false; - - for (auto H : HintTypes) - if (Name->getString().endswith(H.Name)) - return true; - return false; - } - - /// Sets current hints into loop metadata, keeping other values intact. - void writeHintsToMetadata(ArrayRef<Hint> HintTypes) { - if (HintTypes.empty()) - return; - - // Reserve the first element to LoopID (see below). - SmallVector<Metadata *, 4> MDs(1); - // If the loop already has metadata, then ignore the existing operands. - MDNode *LoopID = TheLoop->getLoopID(); - if (LoopID) { - for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { - MDNode *Node = cast<MDNode>(LoopID->getOperand(i)); - // If node in update list, ignore old value. - if (!matchesHintMetadataName(Node, HintTypes)) - MDs.push_back(Node); - } - } - - // Now, add the missing hints. - for (auto H : HintTypes) - MDs.push_back(createHintMetadata(Twine(Prefix(), H.Name).str(), H.Value)); - - // Replace current metadata node with new one. - LLVMContext &Context = TheLoop->getHeader()->getContext(); - MDNode *NewLoopID = MDNode::get(Context, MDs); - // Set operand 0 to refer to the loop id itself. - NewLoopID->replaceOperandWith(0, NewLoopID); - - TheLoop->setLoopID(NewLoopID); - } - - /// The loop these hints belong to. - const Loop *TheLoop; - - /// Interface to emit optimization remarks. - OptimizationRemarkEmitter &ORE; -}; - } // end anonymous namespace static void emitMissedWarning(Function *F, Loop *L, @@ -1519,324 +1160,7 @@ static void emitMissedWarning(Function *F, Loop *L, } } -namespace { - -/// LoopVectorizationLegality checks if it is legal to vectorize a loop, and -/// to what vectorization factor. -/// This class does not look at the profitability of vectorization, only the -/// legality. 
This class has two main kinds of checks: -/// * Memory checks - The code in canVectorizeMemory checks if vectorization -/// will change the order of memory accesses in a way that will change the -/// correctness of the program. -/// * Scalars checks - The code in canVectorizeInstrs and canVectorizeMemory -/// checks for a number of different conditions, such as the availability of a -/// single induction variable, that all types are supported and vectorize-able, -/// etc. This code reflects the capabilities of InnerLoopVectorizer. -/// This class is also used by InnerLoopVectorizer for identifying -/// induction variable and the different reduction variables. -class LoopVectorizationLegality { -public: - LoopVectorizationLegality( - Loop *L, PredicatedScalarEvolution &PSE, DominatorTree *DT, - TargetLibraryInfo *TLI, AliasAnalysis *AA, Function *F, - const TargetTransformInfo *TTI, - std::function<const LoopAccessInfo &(Loop &)> *GetLAA, LoopInfo *LI, - OptimizationRemarkEmitter *ORE, LoopVectorizationRequirements *R, - LoopVectorizeHints *H) - : TheLoop(L), PSE(PSE), TLI(TLI), TTI(TTI), DT(DT), GetLAA(GetLAA), - ORE(ORE), InterleaveInfo(PSE, L, DT, LI), Requirements(R), Hints(H) {} - - /// ReductionList contains the reduction descriptors for all - /// of the reductions that were found in the loop. - using ReductionList = DenseMap<PHINode *, RecurrenceDescriptor>; - - /// InductionList saves induction variables and maps them to the - /// induction descriptor. - using InductionList = MapVector<PHINode *, InductionDescriptor>; - - /// RecurrenceSet contains the phi nodes that are recurrences other than - /// inductions and reductions. - using RecurrenceSet = SmallPtrSet<const PHINode *, 8>; - - /// Returns true if it is legal to vectorize this loop. - /// This does not mean that it is profitable to vectorize this - /// loop, only that it is legal to do so. - bool canVectorize(); - - /// Returns the primary induction variable. - PHINode *getPrimaryInduction() { return PrimaryInduction; } - - /// Returns the reduction variables found in the loop. - ReductionList *getReductionVars() { return &Reductions; } - - /// Returns the induction variables found in the loop. - InductionList *getInductionVars() { return &Inductions; } - - /// Return the first-order recurrences found in the loop. - RecurrenceSet *getFirstOrderRecurrences() { return &FirstOrderRecurrences; } - - /// Return the set of instructions to sink to handle first-order recurrences. - DenseMap<Instruction *, Instruction *> &getSinkAfter() { return SinkAfter; } - - /// Returns the widest induction type. - Type *getWidestInductionType() { return WidestIndTy; } - - /// Returns True if V is a Phi node of an induction variable in this loop. - bool isInductionPhi(const Value *V); - - /// Returns True if V is a cast that is part of an induction def-use chain, - /// and had been proven to be redundant under a runtime guard (in other - /// words, the cast has the same SCEV expression as the induction phi). - bool isCastedInductionVariable(const Value *V); - - /// Returns True if V can be considered as an induction variable in this - /// loop. V can be the induction phi, or some redundant cast in the def-use - /// chain of the inducion phi. - bool isInductionVariable(const Value *V); - - /// Returns True if PN is a reduction variable in this loop. - bool isReductionVariable(PHINode *PN) { return Reductions.count(PN); } - - /// Returns True if Phi is a first-order recurrence in this loop. 
- bool isFirstOrderRecurrence(const PHINode *Phi); - - /// Return true if the block BB needs to be predicated in order for the loop - /// to be vectorized. - bool blockNeedsPredication(BasicBlock *BB); - - /// Check if this pointer is consecutive when vectorizing. This happens - /// when the last index of the GEP is the induction variable, or that the - /// pointer itself is an induction variable. - /// This check allows us to vectorize A[idx] into a wide load/store. - /// Returns: - /// 0 - Stride is unknown or non-consecutive. - /// 1 - Address is consecutive. - /// -1 - Address is consecutive, and decreasing. - /// NOTE: This method must only be used before modifying the original scalar - /// loop. Do not use after invoking 'createVectorizedLoopSkeleton' (PR34965). - int isConsecutivePtr(Value *Ptr); - - /// Returns true if the value V is uniform within the loop. - bool isUniform(Value *V); - - /// Returns the information that we collected about runtime memory check. - const RuntimePointerChecking *getRuntimePointerChecking() const { - return LAI->getRuntimePointerChecking(); - } - - const LoopAccessInfo *getLAI() const { return LAI; } - - /// \brief Check if \p Instr belongs to any interleaved access group. - bool isAccessInterleaved(Instruction *Instr) { - return InterleaveInfo.isInterleaved(Instr); - } - - /// \brief Get the interleaved access group that \p Instr belongs to. - const InterleaveGroup *getInterleavedAccessGroup(Instruction *Instr) { - return InterleaveInfo.getInterleaveGroup(Instr); - } - - /// \brief Returns true if an interleaved group requires a scalar iteration - /// to handle accesses with gaps. - bool requiresScalarEpilogue() const { - return InterleaveInfo.requiresScalarEpilogue(); - } - - unsigned getMaxSafeDepDistBytes() { return LAI->getMaxSafeDepDistBytes(); } - - uint64_t getMaxSafeRegisterWidth() const { - return LAI->getDepChecker().getMaxSafeRegisterWidth(); - } - - bool hasStride(Value *V) { return LAI->hasStride(V); } - - /// Returns true if the target machine supports masked store operation - /// for the given \p DataType and kind of access to \p Ptr. - bool isLegalMaskedStore(Type *DataType, Value *Ptr) { - return isConsecutivePtr(Ptr) && TTI->isLegalMaskedStore(DataType); - } - - /// Returns true if the target machine supports masked load operation - /// for the given \p DataType and kind of access to \p Ptr. - bool isLegalMaskedLoad(Type *DataType, Value *Ptr) { - return isConsecutivePtr(Ptr) && TTI->isLegalMaskedLoad(DataType); - } - - /// Returns true if the target machine supports masked scatter operation - /// for the given \p DataType. - bool isLegalMaskedScatter(Type *DataType) { - return TTI->isLegalMaskedScatter(DataType); - } - - /// Returns true if the target machine supports masked gather operation - /// for the given \p DataType. - bool isLegalMaskedGather(Type *DataType) { - return TTI->isLegalMaskedGather(DataType); - } - - /// Returns true if the target machine can represent \p V as a masked gather - /// or scatter operation. - bool isLegalGatherOrScatter(Value *V) { - auto *LI = dyn_cast<LoadInst>(V); - auto *SI = dyn_cast<StoreInst>(V); - if (!LI && !SI) - return false; - auto *Ptr = getPointerOperand(V); - auto *Ty = cast<PointerType>(Ptr->getType())->getElementType(); - return (LI && isLegalMaskedGather(Ty)) || (SI && isLegalMaskedScatter(Ty)); - } - - /// Returns true if vector representation of the instruction \p I - /// requires mask. 
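Concretely, the consecutive/masked distinctions above map to access patterns like the following (an illustrative sketch; the real classification is done on SCEV expressions, not on source syntax):

void accesses(int *A, int *B, int *C, const int *M, int N) {
  for (int I = 0; I < N; ++I) {
    A[I] = 0;          // isConsecutivePtr() == 1: unit stride, wide store
    B[N - 1 - I] = 0;  // isConsecutivePtr() == -1: consecutive but decreasing
    C[2 * I] = 0;      // isConsecutivePtr() == 0: needs a scatter (if legal)
    if (M[I])          // conditional access: isLegalMaskedStore() decides
      A[I] = 1;        // between a masked store and scalarization
  }
}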
- bool isMaskRequired(const Instruction *I) { return (MaskedOp.count(I) != 0); } - - unsigned getNumStores() const { return LAI->getNumStores(); } - unsigned getNumLoads() const { return LAI->getNumLoads(); } - unsigned getNumPredStores() const { return NumPredStores; } - - /// Returns true if \p I is an instruction that will be scalarized with - /// predication. Such instructions include conditional stores and - /// instructions that may divide by zero. - bool isScalarWithPredication(Instruction *I); - - /// Returns true if \p I is a memory instruction with consecutive memory - /// access that can be widened. - bool memoryInstructionCanBeWidened(Instruction *I, unsigned VF = 1); - - // Returns true if the NoNaN attribute is set on the function. - bool hasFunNoNaNAttr() const { return HasFunNoNaNAttr; } - -private: - /// Check if a single basic block loop is vectorizable. - /// At this point we know that this is a loop with a constant trip count - /// and we only need to check individual instructions. - bool canVectorizeInstrs(); - - /// When we vectorize loops we may change the order in which - /// we read and write from memory. This method checks if it is - /// legal to vectorize the code, considering only memory constraints. - /// Returns true if the loop is vectorizable. - bool canVectorizeMemory(); - - /// Return true if we can vectorize this loop using the IF-conversion - /// transformation. - bool canVectorizeWithIfConvert(); - - /// Return true if all of the instructions in the block can be speculatively - /// executed. \p SafePtrs is a list of addresses that are known to be legal - /// and we know that we can read from them without segfault. - bool blockCanBePredicated(BasicBlock *BB, SmallPtrSetImpl<Value *> &SafePtrs); - - /// Updates the vectorization state by adding \p Phi to the inductions list. - /// This can set \p Phi as the main induction of the loop if \p Phi is a - /// better choice for the main induction than the existing one. - void addInductionPhi(PHINode *Phi, const InductionDescriptor &ID, - SmallPtrSetImpl<Value *> &AllowedExit); - - /// Create an analysis remark that explains why vectorization failed. - /// - /// \p RemarkName is the identifier for the remark. If \p I is passed it is - /// an instruction that prevents vectorization. Otherwise the loop is used - /// for the location of the remark. \return the remark object that can be - /// streamed to. - OptimizationRemarkAnalysis - createMissedAnalysis(StringRef RemarkName, Instruction *I = nullptr) const { - return ::createMissedAnalysis(Hints->vectorizeAnalysisPassName(), - RemarkName, TheLoop, I); - } - - /// \brief If an access has a symbolic stride, this maps the pointer value to - /// the stride symbol. - const ValueToValueMap *getSymbolicStrides() { - // FIXME: Currently, the set of symbolic strides is sometimes queried before - // it's collected. This happens from canVectorizeWithIfConvert, when the - // pointer is checked to reference consecutive elements suitable for a - // masked access. - return LAI ? &LAI->getSymbolicStrides() : nullptr; - } - - unsigned NumPredStores = 0; - - /// The loop that we evaluate. - Loop *TheLoop; - - /// A wrapper around ScalarEvolution used to add runtime SCEV checks. - /// Applies dynamic knowledge to simplify SCEV expressions in the context - /// of existing SCEV assumptions. The analysis will also add a minimal set - /// of new predicates if this is required to enable vectorization and - /// unrolling. - PredicatedScalarEvolution &PSE; - - /// Target Library Info.
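The canonical isScalarWithPredication() case is a lane that may trap if executed unconditionally, e.g. (sketch):

void divide(int *Out, const int *In, int N) {
  for (int I = 0; I < N; ++I)
    if (In[I] != 0)         // predicated block
      Out[I] = 100 / In[I]; // would divide by zero if run for every lane;
                            // the lane stays behind its predicate, either
                            // via a masked store or a scalarized if-block
}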
- TargetLibraryInfo *TLI; - - /// Target Transform Info - const TargetTransformInfo *TTI; - - /// Dominator Tree. - DominatorTree *DT; - - // LoopAccess analysis. - std::function<const LoopAccessInfo &(Loop &)> *GetLAA; - - // And the loop-accesses info corresponding to this loop. This pointer is - // null until canVectorizeMemory sets it up. - const LoopAccessInfo *LAI = nullptr; - - /// Interface to emit optimization remarks. - OptimizationRemarkEmitter *ORE; - - /// The interleave access information contains groups of interleaved accesses - /// with the same stride and close to each other. - InterleavedAccessInfo InterleaveInfo; - - // --- vectorization state --- // - - /// Holds the primary induction variable. This is the counter of the - /// loop. - PHINode *PrimaryInduction = nullptr; - - /// Holds the reduction variables. - ReductionList Reductions; - - /// Holds all of the induction variables that we found in the loop. - /// Notice that inductions don't need to start at zero and that induction - /// variables can be pointers. - InductionList Inductions; - - /// Holds all the casts that participate in the update chain of the induction - /// variables, and that have been proven to be redundant (possibly under a - /// runtime guard). These casts can be ignored when creating the vectorized - /// loop body. - SmallPtrSet<Instruction *, 4> InductionCastsToIgnore; - - /// Holds the phi nodes that are first-order recurrences. - RecurrenceSet FirstOrderRecurrences; - - /// Holds instructions that need to sink past other instructions to handle - /// first-order recurrences. - DenseMap<Instruction *, Instruction *> SinkAfter; - - /// Holds the widest induction type encountered. - Type *WidestIndTy = nullptr; - - /// Allowed outside users. This holds the induction and reduction - /// vars which can be accessed from outside the loop. - SmallPtrSet<Value *, 4> AllowedExit; - - /// Can we assume the absence of NaNs. - bool HasFunNoNaNAttr = false; - - /// Vectorization requirements that will go through late-evaluation. - LoopVectorizationRequirements *Requirements; - - /// Used to emit an analysis of any legality issues. - LoopVectorizeHints *Hints; - - /// While vectorizing these instructions we have to generate a - /// call to the appropriate masked intrinsic - SmallPtrSet<const Instruction *, 8> MaskedOp; -}; +namespace llvm { /// LoopVectorizationCostModel - estimates the expected speedups due to /// vectorization. @@ -1853,23 +1177,15 @@ public: const TargetLibraryInfo *TLI, DemandedBits *DB, AssumptionCache *AC, OptimizationRemarkEmitter *ORE, const Function *F, - const LoopVectorizeHints *Hints) + const LoopVectorizeHints *Hints, + InterleavedAccessInfo &IAI) : TheLoop(L), PSE(PSE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB), - AC(AC), ORE(ORE), TheFunction(F), Hints(Hints) {} + AC(AC), ORE(ORE), TheFunction(F), Hints(Hints), InterleaveInfo(IAI) {} /// \return An upper bound for the vectorization factor, or None if /// vectorization should be avoided up front. Optional<unsigned> computeMaxVF(bool OptForSize); - /// Information about vectorization costs - struct VectorizationFactor { - // Vector width with best cost - unsigned Width; - - // Cost of the loop with that width - unsigned Cost; - }; - /// \return The most profitable vectorization factor and the cost of that VF. /// This method checks every power of two up to MaxVF. If UserVF is not ZERO /// then this vectorization factor will be selected if vectorization is @@ -1903,7 +1219,7 @@ public: /// avoid redundant calculations. 
void setCostBasedWideningDecision(unsigned VF); - /// \brief A struct that represents some properties of the register usage + /// A struct that represents some properties of the register usage /// of a loop. struct RegisterUsage { /// Holds the number of loop invariant values that are used in the loop. @@ -1911,9 +1227,6 @@ public: /// Holds the maximum number of concurrent live intervals in the loop. unsigned MaxLocalUsers; - - /// Holds the number of instructions in the loop. - unsigned NumInstructions; }; /// \return Returns information about the register usages of the loop for the @@ -2063,7 +1376,69 @@ public: collectLoopScalars(VF); } + /// Returns true if the target machine supports masked store operation + /// for the given \p DataType and kind of access to \p Ptr. + bool isLegalMaskedStore(Type *DataType, Value *Ptr) { + return Legal->isConsecutivePtr(Ptr) && TTI.isLegalMaskedStore(DataType); + } + + /// Returns true if the target machine supports masked load operation + /// for the given \p DataType and kind of access to \p Ptr. + bool isLegalMaskedLoad(Type *DataType, Value *Ptr) { + return Legal->isConsecutivePtr(Ptr) && TTI.isLegalMaskedLoad(DataType); + } + + /// Returns true if the target machine supports masked scatter operation + /// for the given \p DataType. + bool isLegalMaskedScatter(Type *DataType) { + return TTI.isLegalMaskedScatter(DataType); + } + + /// Returns true if the target machine supports masked gather operation + /// for the given \p DataType. + bool isLegalMaskedGather(Type *DataType) { + return TTI.isLegalMaskedGather(DataType); + } + + /// Returns true if the target machine can represent \p V as a masked gather + /// or scatter operation. + bool isLegalGatherOrScatter(Value *V) { + bool LI = isa<LoadInst>(V); + bool SI = isa<StoreInst>(V); + if (!LI && !SI) + return false; + auto *Ty = getMemInstValueType(V); + return (LI && isLegalMaskedGather(Ty)) || (SI && isLegalMaskedScatter(Ty)); + } + + /// Returns true if \p I is an instruction that will be scalarized with + /// predication. Such instructions include conditional stores and + /// instructions that may divide by zero. + bool isScalarWithPredication(Instruction *I); + + /// Returns true if \p I is a memory instruction with consecutive memory + /// access that can be widened. + bool memoryInstructionCanBeWidened(Instruction *I, unsigned VF = 1); + + /// Check if \p Instr belongs to any interleaved access group. + bool isAccessInterleaved(Instruction *Instr) { + return InterleaveInfo.isInterleaved(Instr); + } + + /// Get the interleaved access group that \p Instr belongs to. + const InterleaveGroup *getInterleavedAccessGroup(Instruction *Instr) { + return InterleaveInfo.getInterleaveGroup(Instr); + } + + /// Returns true if an interleaved group requires a scalar iteration + /// to handle accesses with gaps. + bool requiresScalarEpilogue() const { + return InterleaveInfo.requiresScalarEpilogue(); + } + private: + unsigned NumPredStores = 0; + /// \return An upper bound for the vectorization factor, larger than zero. /// One is returned if vectorization should best be avoided due to cost. unsigned computeFeasibleMaxVF(bool OptForSize, unsigned ConstTripCount); @@ -2115,12 +1490,16 @@ private: /// as a vector operation. bool isConsecutiveLoadOrStore(Instruction *I); + /// Returns true if an artificially high cost for emulated masked memrefs + /// should be used. 
+ bool useEmulatedMaskMemRefHack(Instruction *I); + /// Create an analysis remark that explains why vectorization failed /// /// \p RemarkName is the identifier for the remark. \return the remark object /// that can be streamed to. OptimizationRemarkAnalysis createMissedAnalysis(StringRef RemarkName) { - return ::createMissedAnalysis(Hints->vectorizeAnalysisPassName(), + return createLVMissedAnalysis(Hints->vectorizeAnalysisPassName(), RemarkName, TheLoop); } @@ -2222,6 +1601,10 @@ public: /// Loop Vectorize Hint. const LoopVectorizeHints *Hints; + /// The interleave access information contains groups of interleaved accesses + /// with the same stride and close to each other. + InterleavedAccessInfo &InterleaveInfo; + /// Values to ignore in the cost model. SmallPtrSet<const Value *, 16> ValuesToIgnore; @@ -2229,271 +1612,78 @@ public: SmallPtrSet<const Value *, 16> VecValuesToIgnore; }; -} // end anonymous namespace - -namespace llvm { - -/// InnerLoopVectorizer vectorizes loops which contain only one basic -/// LoopVectorizationPlanner - drives the vectorization process after having -/// passed Legality checks. -/// The planner builds and optimizes the Vectorization Plans which record the -/// decisions how to vectorize the given loop. In particular, represent the -/// control-flow of the vectorized version, the replication of instructions that -/// are to be scalarized, and interleave access groups. -class LoopVectorizationPlanner { - /// The loop that we evaluate. - Loop *OrigLoop; - - /// Loop Info analysis. - LoopInfo *LI; - - /// Target Library Info. - const TargetLibraryInfo *TLI; - - /// Target Transform Info. - const TargetTransformInfo *TTI; - - /// The legality analysis. - LoopVectorizationLegality *Legal; - - /// The profitablity analysis. - LoopVectorizationCostModel &CM; - - using VPlanPtr = std::unique_ptr<VPlan>; - - SmallVector<VPlanPtr, 4> VPlans; - - /// This class is used to enable the VPlan to invoke a method of ILV. This is - /// needed until the method is refactored out of ILV and becomes reusable. - struct VPCallbackILV : public VPCallback { - InnerLoopVectorizer &ILV; - - VPCallbackILV(InnerLoopVectorizer &ILV) : ILV(ILV) {} - - Value *getOrCreateVectorValues(Value *V, unsigned Part) override { - return ILV.getOrCreateVectorValue(V, Part); - } - }; - - /// A builder used to construct the current plan. - VPBuilder Builder; - - /// When we if-convert we need to create edge masks. We have to cache values - /// so that we don't end up with exponential recursion/IR. Note that - /// if-conversion currently takes place during VPlan-construction, so these - /// caches are only used at that stage. - using EdgeMaskCacheTy = - DenseMap<std::pair<BasicBlock *, BasicBlock *>, VPValue *>; - using BlockMaskCacheTy = DenseMap<BasicBlock *, VPValue *>; - EdgeMaskCacheTy EdgeMaskCache; - BlockMaskCacheTy BlockMaskCache; - - unsigned BestVF = 0; - unsigned BestUF = 0; - -public: - LoopVectorizationPlanner(Loop *L, LoopInfo *LI, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, - LoopVectorizationLegality *Legal, - LoopVectorizationCostModel &CM) - : OrigLoop(L), LI(LI), TLI(TLI), TTI(TTI), Legal(Legal), CM(CM) {} - - /// Plan how to best vectorize, return the best VF and its cost. - LoopVectorizationCostModel::VectorizationFactor plan(bool OptForSize, - unsigned UserVF); - - /// Finalize the best decision and dispose of all other VPlans. 
- void setBestPlan(unsigned VF, unsigned UF); - - /// Generate the IR code for the body of the vectorized loop according to the - /// best selected VPlan. - void executePlan(InnerLoopVectorizer &LB, DominatorTree *DT); - - void printPlans(raw_ostream &O) { - for (const auto &Plan : VPlans) - O << *Plan; - } - -protected: - /// Collect the instructions from the original loop that would be trivially - /// dead in the vectorized loop if generated. - void collectTriviallyDeadInstructions( - SmallPtrSetImpl<Instruction *> &DeadInstructions); - - /// A range of powers-of-2 vectorization factors with fixed start and - /// adjustable end. The range includes start and excludes end, e.g.,: - /// [1, 9) = {1, 2, 4, 8} - struct VFRange { - // A power of 2. - const unsigned Start; - - // Need not be a power of 2. If End <= Start range is empty. - unsigned End; - }; - - /// Test a \p Predicate on a \p Range of VF's. Return the value of applying - /// \p Predicate on Range.Start, possibly decreasing Range.End such that the - /// returned value holds for the entire \p Range. - bool getDecisionAndClampRange(const std::function<bool(unsigned)> &Predicate, - VFRange &Range); - - /// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive, - /// according to the information gathered by Legal when it checked if it is - /// legal to vectorize the loop. - void buildVPlans(unsigned MinVF, unsigned MaxVF); - -private: - /// A helper function that computes the predicate of the block BB, assuming - /// that the header block of the loop is set to True. It returns the *entry* - /// mask for the block BB. - VPValue *createBlockInMask(BasicBlock *BB, VPlanPtr &Plan); - - /// A helper function that computes the predicate of the edge between SRC - /// and DST. - VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst, VPlanPtr &Plan); - - /// Check if \I belongs to an Interleave Group within the given VF \p Range, - /// \return true in the first returned value if so and false otherwise. - /// Build a new VPInterleaveGroup Recipe if \I is the primary member of an IG - /// for \p Range.Start, and provide it as the second returned value. - /// Note that if \I is an adjunct member of an IG for \p Range.Start, the - /// \return value is <true, nullptr>, as it is handled by another recipe. - /// \p Range.End may be decreased to ensure same decision from \p Range.Start - /// to \p Range.End. - VPInterleaveRecipe *tryToInterleaveMemory(Instruction *I, VFRange &Range); - - // Check if \I is a memory instruction to be widened for \p Range.Start and - // potentially masked. Such instructions are handled by a recipe that takes an - // additional VPInstruction for the mask. - VPWidenMemoryInstructionRecipe *tryToWidenMemory(Instruction *I, - VFRange &Range, - VPlanPtr &Plan); - - /// Check if an induction recipe should be constructed for \I within the given - /// VF \p Range. If so build and return it. If not, return null. \p Range.End - /// may be decreased to ensure same decision from \p Range.Start to - /// \p Range.End. - VPWidenIntOrFpInductionRecipe *tryToOptimizeInduction(Instruction *I, - VFRange &Range); - - /// Handle non-loop phi nodes. Currently all such phi nodes are turned into - /// a sequence of select instructions as the vectorizer currently performs - /// full if-conversion. - VPBlendRecipe *tryToBlend(Instruction *I, VPlanPtr &Plan); - - /// Check if \p I can be widened within the given VF \p Range. 
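The VFRange clamping contract above is compact; here is a standalone sketch of the same idea (mirroring the documented behavior, not the LLVM sources verbatim):

#include <functional>

struct VFRange { const unsigned Start; unsigned End; }; // [Start, End), powers of 2

// Return Predicate(Start); shrink End so the result holds for all of [Start, End).
static bool getDecisionAndClampRange(const std::function<bool(unsigned)> &Predicate,
                                     VFRange &Range) {
  bool DecisionAtStart = Predicate(Range.Start);
  for (unsigned VF = Range.Start * 2; VF < Range.End; VF *= 2)
    if (Predicate(VF) != DecisionAtStart) {
      Range.End = VF; // the first VF that disagrees ends the range
      break;
    }
  return DecisionAtStart;
}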
If \p I can be - /// widened for \p Range.Start, check if the last recipe of \p VPBB can be - /// extended to include \p I or else build a new VPWidenRecipe for it and - /// append it to \p VPBB. Return true if \p I can be widened for Range.Start, - /// false otherwise. Range.End may be decreased to ensure same decision from - /// \p Range.Start to \p Range.End. - bool tryToWiden(Instruction *I, VPBasicBlock *VPBB, VFRange &Range); - - /// Build a VPReplicationRecipe for \p I and enclose it within a Region if it - /// is predicated. \return \p VPBB augmented with this new recipe if \p I is - /// not predicated, otherwise \return a new VPBasicBlock that succeeds the new - /// Region. Update the packing decision of predicated instructions if they - /// feed \p I. Range.End may be decreased to ensure same recipe behavior from - /// \p Range.Start to \p Range.End. - VPBasicBlock *handleReplication( - Instruction *I, VFRange &Range, VPBasicBlock *VPBB, - DenseMap<Instruction *, VPReplicateRecipe *> &PredInst2Recipe, - VPlanPtr &Plan); - - /// Create a replicating region for instruction \p I that requires - /// predication. \p PredRecipe is a VPReplicateRecipe holding \p I. - VPRegionBlock *createReplicateRegion(Instruction *I, VPRecipeBase *PredRecipe, - VPlanPtr &Plan); - - /// Build a VPlan according to the information gathered by Legal. \return a - /// VPlan for vectorization factors \p Range.Start and up to \p Range.End - /// exclusive, possibly decreasing \p Range.End. - VPlanPtr buildVPlan(VFRange &Range, - const SmallPtrSetImpl<Value *> &NeedDef); -}; - } // end namespace llvm -namespace { - -/// \brief This holds vectorization requirements that must be verified late in -/// the process. The requirements are set by legalize and costmodel. Once -/// vectorization has been determined to be possible and profitable the -/// requirements can be verified by looking for metadata or compiler options. -/// For example, some loops require FP commutativity which is only allowed if -/// vectorization is explicitly specified or if the fast-math compiler option -/// has been provided. -/// Late evaluation of these requirements allows helpful diagnostics to be -/// composed that tell the user what needs to be done to vectorize the loop. For -/// example, by specifying #pragma clang loop vectorize or -ffast-math. Late -/// evaluation should be used only when diagnostics can be generated that can be -/// followed by a non-expert user. -class LoopVectorizationRequirements { -public: - LoopVectorizationRequirements(OptimizationRemarkEmitter &ORE) : ORE(ORE) {} - - void addUnsafeAlgebraInst(Instruction *I) { - // First unsafe algebra instruction. - if (!UnsafeAlgebraInst) - UnsafeAlgebraInst = I; - } - - void addRuntimePointerChecks(unsigned Num) { NumRuntimePointerChecks = Num; } - - bool doesNotMeet(Function *F, Loop *L, const LoopVectorizeHints &Hints) { - const char *PassName = Hints.vectorizeAnalysisPassName(); - bool Failed = false; - if (UnsafeAlgebraInst && !Hints.allowReordering()) { - ORE.emit([&]() { - return OptimizationRemarkAnalysisFPCommute( - PassName, "CantReorderFPOps", - UnsafeAlgebraInst->getDebugLoc(), - UnsafeAlgebraInst->getParent()) - << "loop not vectorized: cannot prove it is safe to reorder " - "floating-point operations"; - }); - Failed = true; - } - - // Test if runtime memcheck thresholds are exceeded.
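The canonical loop that trips the UnsafeAlgebraInst requirement is a plain floating-point reduction; vectorizing it reorders the additions, so doesNotMeet() rejects it unless reordering was allowed (sketch):

float sum(const float *A, int N) {
  float S = 0.0f;
  for (int I = 0; I < N; ++I)
    S += A[I]; // vector lanes accumulate in a different order than the
               // scalar loop; needs -ffast-math (or an explicit vectorize
               // pragma) to avoid the "CantReorderFPOps" remark above
  return S;
}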
- bool PragmaThresholdReached = - NumRuntimePointerChecks > PragmaVectorizeMemoryCheckThreshold; - bool ThresholdReached = - NumRuntimePointerChecks > VectorizerParams::RuntimeMemoryCheckThreshold; - if ((ThresholdReached && !Hints.allowReordering()) || - PragmaThresholdReached) { - ORE.emit([&]() { - return OptimizationRemarkAnalysisAliasing(PassName, "CantReorderMemOps", - L->getStartLoc(), - L->getHeader()) - << "loop not vectorized: cannot prove it is safe to reorder " - "memory operations"; - }); - DEBUG(dbgs() << "LV: Too many memory checks needed.\n"); - Failed = true; - } +// Return true if \p OuterLp is an outer loop annotated with hints for explicit +// vectorization. The loop needs to be annotated with #pragma omp simd +// simdlen(#) or #pragma clang vectorize(enable) vectorize_width(#). If the +// vector length information is not provided, vectorization is not considered +// explicit. Interleave hints are not allowed either. These limitations will be +// relaxed in the future. +// Please, note that we are currently forced to abuse the pragma 'clang +// vectorize' semantics. This pragma provides *auto-vectorization hints* +// (i.e., LV must check that vectorization is legal) whereas pragma 'omp simd' +// provides *explicit vectorization hints* (LV can bypass legal checks and +// assume that vectorization is legal). However, both hints are implemented +// using the same metadata (llvm.loop.vectorize, processed by +// LoopVectorizeHints). This will be fixed in the future when the native IR +// representation for pragma 'omp simd' is introduced. +static bool isExplicitVecOuterLoop(Loop *OuterLp, + OptimizationRemarkEmitter *ORE) { + assert(!OuterLp->empty() && "This is not an outer loop"); + LoopVectorizeHints Hints(OuterLp, true /*DisableInterleaving*/, *ORE); + + // Only outer loops with an explicit vectorization hint are supported. + // Unannotated outer loops are ignored. + if (Hints.getForce() == LoopVectorizeHints::FK_Undefined) + return false; - return Failed; + Function *Fn = OuterLp->getHeader()->getParent(); + if (!Hints.allowVectorization(Fn, OuterLp, false /*AlwaysVectorize*/)) { + LLVM_DEBUG(dbgs() << "LV: Loop hints prevent outer loop vectorization.\n"); + return false; } -private: - unsigned NumRuntimePointerChecks = 0; - Instruction *UnsafeAlgebraInst = nullptr; + if (!Hints.getWidth()) { + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: No user vector width.\n"); + emitMissedWarning(Fn, OuterLp, Hints, ORE); + return false; + } - /// Interface to emit optimization remarks. - OptimizationRemarkEmitter &ORE; -}; + if (Hints.getInterleave() > 1) { + // TODO: Interleave support is future work. + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: Interleave is not supported for " + "outer loops.\n"); + emitMissedWarning(Fn, OuterLp, Hints, ORE); + return false; + } -} // end anonymous namespace + return true; +} -static void addAcyclicInnerLoop(Loop &L, SmallVectorImpl<Loop *> &V) { - if (L.empty()) { - if (!hasCyclesInLoopBody(L)) +static void collectSupportedLoops(Loop &L, LoopInfo *LI, + OptimizationRemarkEmitter *ORE, + SmallVectorImpl<Loop *> &V) { + // Collect inner loops and outer loops without irreducible control flow. For + // now, only collect outer loops that have explicit vectorization hints. If we + // are stress testing the VPlan H-CFG construction, we collect the outermost + // loop of every loop nest. 
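As consumed by isExplicitVecOuterLoop() above, the two accepted annotations look like this in source (a sketch; note the actual clang spelling is '#pragma clang loop ...', and the OpenMP form requires -fopenmp or -fopenmp-simd):

void sweep(float **Rows, int M, int N) {
  #pragma clang loop vectorize(enable) vectorize_width(4) // explicit width
  for (int I = 0; I < M; ++I)      // outer loop: the vectorization candidate
    for (int J = 0; J < N; ++J)    // inner loop stays scalar for now
      Rows[I][J] += 1.0f;
}
// Equivalent OpenMP form on the outer loop:
//   #pragma omp simd simdlen(4)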
+ if (L.empty() || VPlanBuildStressTest || + (EnableVPlanNativePath && isExplicitVecOuterLoop(&L, ORE))) { + LoopBlocksRPO RPOT(&L); + RPOT.perform(LI); + if (!containsIrreducibleCFG<const BasicBlock *>(RPOT, *LI)) { V.push_back(&L); + // TODO: Collect inner loops inside marked outer loops in case + // vectorization fails for the outer loop. Do not invoke + // 'containsIrreducibleCFG' again for inner loops when the outer loop is + // already known to be reducible. We can use an inherited attribute for + // that. + return; + } } for (Loop *InnerL : L) - addAcyclicInnerLoop(*InnerL, V); + collectSupportedLoops(*InnerL, LI, ORE, V); } namespace { @@ -2562,14 +1752,16 @@ struct LoopVectorize : public FunctionPass { //===----------------------------------------------------------------------===// Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) { - // We need to place the broadcast of invariant variables outside the loop. + // We need to place the broadcast of invariant variables outside the loop, + // but only if it's proven safe to do so. Otherwise, the broadcast will be + // emitted inside the vector loop body. Instruction *Instr = dyn_cast<Instruction>(V); - bool NewInstr = (Instr && Instr->getParent() == LoopVectorBody); - bool Invariant = OrigLoop->isLoopInvariant(V) && !NewInstr; - + bool SafeToHoist = OrigLoop->isLoopInvariant(V) && + (!Instr || + DT->dominates(Instr->getParent(), LoopVectorPreHeader)); // Place the code for broadcasting invariant variables in the new preheader. IRBuilder<>::InsertPointGuard Guard(Builder); - if (Invariant) + if (SafeToHoist) Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); // Broadcast the scalar into all locations in the vector. @@ -2580,6 +1772,8 @@ Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) { void InnerLoopVectorizer::createVectorIntOrFpInductionPHI( const InductionDescriptor &II, Value *Step, Instruction *EntryVal) { + assert((isa<PHINode>(EntryVal) || isa<TruncInst>(EntryVal)) && + "Expected either an induction phi-node or a truncate of it!"); Value *Start = II.getStartValue(); // Construct the initial value of the vector IV in the vector loop preheader @@ -2627,14 +1821,18 @@ void InnerLoopVectorizer::createVectorIntOrFpInductionPHI( // factor. The last of those goes into the PHI. PHINode *VecInd = PHINode::Create(SteppedStart->getType(), 2, "vec.ind", &*LoopVectorBody->getFirstInsertionPt()); + VecInd->setDebugLoc(EntryVal->getDebugLoc()); Instruction *LastInduction = VecInd; for (unsigned Part = 0; Part < UF; ++Part) { VectorLoopValueMap.setVectorValue(EntryVal, Part, LastInduction); - recordVectorLoopValueForInductionCast(II, LastInduction, Part); + if (isa<TruncInst>(EntryVal)) addMetadata(LastInduction, EntryVal); + recordVectorLoopValueForInductionCast(II, EntryVal, LastInduction, Part); + LastInduction = cast<Instruction>(addFastMathFlag( Builder.CreateBinOp(AddOp, LastInduction, SplatVF, "step.add"))); + LastInduction->setDebugLoc(EntryVal->getDebugLoc()); } // Move the last step to the end of the latch block.
This ensures consistent @@ -2665,8 +1863,20 @@ bool InnerLoopVectorizer::needsScalarInduction(Instruction *IV) const { } void InnerLoopVectorizer::recordVectorLoopValueForInductionCast( - const InductionDescriptor &ID, Value *VectorLoopVal, unsigned Part, - unsigned Lane) { + const InductionDescriptor &ID, const Instruction *EntryVal, + Value *VectorLoopVal, unsigned Part, unsigned Lane) { + assert((isa<PHINode>(EntryVal) || isa<TruncInst>(EntryVal)) && + "Expected either an induction phi-node or a truncate of it!"); + + // This induction variable is not the phi from the original loop but the + // newly-created IV based on the proof that casted Phi is equal to the + // uncasted Phi in the vectorized loop (under a runtime guard possibly). It + // re-uses the same InductionDescriptor that original IV uses but we don't + // have to do any recording in this case - that is done when original IV is + // processed. + if (isa<TruncInst>(EntryVal)) + return; + const SmallVectorImpl<Instruction *> &Casts = ID.getCastInsts(); if (Casts.empty()) return; @@ -2754,15 +1964,16 @@ void InnerLoopVectorizer::widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc) { // If we haven't yet vectorized the induction variable, splat the scalar // induction variable, and build the necessary step vectors. + // TODO: Don't do it unless the vectorized IV is really required. if (!VectorizedIV) { Value *Broadcasted = getBroadcastInstrs(ScalarIV); for (unsigned Part = 0; Part < UF; ++Part) { Value *EntryPart = getStepVector(Broadcasted, VF * Part, Step, ID.getInductionOpcode()); VectorLoopValueMap.setVectorValue(EntryVal, Part, EntryPart); - recordVectorLoopValueForInductionCast(ID, EntryPart, Part); if (Trunc) addMetadata(EntryPart, Trunc); + recordVectorLoopValueForInductionCast(ID, EntryVal, EntryPart, Part); } } @@ -2833,7 +2044,7 @@ Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step, } void InnerLoopVectorizer::buildScalarSteps(Value *ScalarIV, Value *Step, - Value *EntryVal, + Instruction *EntryVal, const InductionDescriptor &ID) { // We shouldn't have to build scalar steps if we aren't vectorizing. assert(VF > 1 && "VF should be greater than one"); @@ -2868,25 +2079,11 @@ void InnerLoopVectorizer::buildScalarSteps(Value *ScalarIV, Value *Step, auto *Mul = addFastMathFlag(Builder.CreateBinOp(MulOp, StartIdx, Step)); auto *Add = addFastMathFlag(Builder.CreateBinOp(AddOp, ScalarIV, Mul)); VectorLoopValueMap.setScalarValue(EntryVal, {Part, Lane}, Add); - recordVectorLoopValueForInductionCast(ID, Add, Part, Lane); + recordVectorLoopValueForInductionCast(ID, EntryVal, Add, Part, Lane); } } } -int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { - const ValueToValueMap &Strides = getSymbolicStrides() ? 
*getSymbolicStrides() : - ValueToValueMap(); - - int Stride = getPtrStride(PSE, Ptr, TheLoop, Strides, true, false); - if (Stride == 1 || Stride == -1) - return Stride; - return 0; -} - -bool LoopVectorizationLegality::isUniform(Value *V) { - return LAI->isUniform(V); -} - Value *InnerLoopVectorizer::getOrCreateVectorValue(Value *V, unsigned Part) { assert(V != Induction && "The new induction variable should not be used."); assert(!V->getType()->isVectorTy() && "Can't widen a vector"); @@ -3046,7 +2243,7 @@ Value *InnerLoopVectorizer::reverseVector(Value *Vec) { // <0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11> ; Interleave R,G,B elements // store <12 x i32> %interleaved.vec ; Write 4 tuples of R,G,B void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { - const InterleaveGroup *Group = Legal->getInterleavedAccessGroup(Instr); + const InterleaveGroup *Group = Cost->getInterleavedAccessGroup(Instr); assert(Group && "Fail to get an interleaved access group."); // Skip if current instruction is not the insert position. @@ -3054,7 +2251,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { return; const DataLayout &DL = Instr->getModule()->getDataLayout(); - Value *Ptr = getPointerOperand(Instr); + Value *Ptr = getLoadStorePointerOperand(Instr); // Prepare for the vector type of the interleaved load/store. Type *ScalarTy = getMemInstValueType(Instr); @@ -3076,6 +2273,10 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { if (Group->isReverse()) Index += (VF - 1) * Group->getFactor(); + bool InBounds = false; + if (auto *gep = dyn_cast<GetElementPtrInst>(Ptr->stripPointerCasts())) + InBounds = gep->isInBounds(); + for (unsigned Part = 0; Part < UF; Part++) { Value *NewPtr = getOrCreateScalarValue(Ptr, {Part, 0}); @@ -3091,6 +2292,8 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { // A[i+2] = c; // Member of index 2 (Current instruction) // Current pointer is pointed to A[i+2], adjust it to A[i]. NewPtr = Builder.CreateGEP(NewPtr, Builder.getInt32(-Index)); + if (InBounds) + cast<GetElementPtrInst>(NewPtr)->setIsInBounds(true); // Cast to the vector pointer type. NewPtrs.push_back(Builder.CreateBitCast(NewPtr, PtrTy)); @@ -3196,7 +2399,7 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr, Type *ScalarDataTy = getMemInstValueType(Instr); Type *DataTy = VectorType::get(ScalarDataTy, VF); - Value *Ptr = getPointerOperand(Instr); + Value *Ptr = getLoadStorePointerOperand(Instr); unsigned Alignment = getMemInstAlignment(Instr); // An alignment of 0 means target abi alignment. We need to use the scalar's // target abi alignment in such a case. @@ -3227,10 +2430,37 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr, if (isMaskRequired) Mask = *BlockInMask; + bool InBounds = false; + if (auto *gep = dyn_cast<GetElementPtrInst>( + getLoadStorePointerOperand(Instr)->stripPointerCasts())) + InBounds = gep->isInBounds(); + + const auto CreateVecPtr = [&](unsigned Part, Value *Ptr) -> Value * { + // Calculate the pointer for the specific unroll-part. + GetElementPtrInst *PartPtr = nullptr; + + if (Reverse) { + // If the address is consecutive but reversed, then the + // wide store needs to start at the last vector element. 
+ PartPtr = cast<GetElementPtrInst>( + Builder.CreateGEP(Ptr, Builder.getInt32(-Part * VF))); + PartPtr->setIsInBounds(InBounds); + PartPtr = cast<GetElementPtrInst>( + Builder.CreateGEP(PartPtr, Builder.getInt32(1 - VF))); + PartPtr->setIsInBounds(InBounds); + if (isMaskRequired) // Reverse of a null all-one mask is a null mask. + Mask[Part] = reverseVector(Mask[Part]); + } else { + PartPtr = cast<GetElementPtrInst>( + Builder.CreateGEP(Ptr, Builder.getInt32(Part * VF))); + PartPtr->setIsInBounds(InBounds); + } + + return Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace)); + }; + // Handle Stores: if (SI) { - assert(!Legal->isUniform(SI->getPointerOperand()) && - "We do not allow storing to uniform addresses"); setDebugLocFromInst(Builder, SI); for (unsigned Part = 0; Part < UF; ++Part) { @@ -3242,30 +2472,14 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr, NewSI = Builder.CreateMaskedScatter(StoredVal, VectorGep, Alignment, MaskPart); } else { - // Calculate the pointer for the specific unroll-part. - Value *PartPtr = - Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(Part * VF)); - if (Reverse) { // If we store to reverse consecutive memory locations, then we need // to reverse the order of elements in the stored value. StoredVal = reverseVector(StoredVal); // We don't want to update the value in the map as it might be used in // another expression. So don't call resetVectorValue(StoredVal). - - // If the address is consecutive but reversed, then the - // wide store needs to start at the last vector element. - PartPtr = - Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(-Part * VF)); - PartPtr = - Builder.CreateGEP(nullptr, PartPtr, Builder.getInt32(1 - VF)); - if (isMaskRequired) // Reverse of a null all-one mask is a null mask. - Mask[Part] = reverseVector(Mask[Part]); } - - Value *VecPtr = - Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace)); - + auto *VecPtr = CreateVecPtr(Part, Ptr); if (isMaskRequired) NewSI = Builder.CreateMaskedStore(StoredVal, VecPtr, Alignment, Mask[Part]); @@ -3289,21 +2503,7 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr, nullptr, "wide.masked.gather"); addMetadata(NewLI, LI); } else { - // Calculate the pointer for the specific unroll-part. - Value *PartPtr = - Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(Part * VF)); - - if (Reverse) { - // If the address is consecutive but reversed, then the - // wide load needs to start at the last vector element. - PartPtr = Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(-Part * VF)); - PartPtr = Builder.CreateGEP(nullptr, PartPtr, Builder.getInt32(1 - VF)); - if (isMaskRequired) // Reverse of a null all-one mask is a null mask. - Mask[Part] = reverseVector(Mask[Part]); - } - - Value *VecPtr = - Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace)); + auto *VecPtr = CreateVecPtr(Part, Ptr); if (isMaskRequired) NewLI = Builder.CreateMaskedLoad(VecPtr, Alignment, Mask[Part], UndefValue::get(DataTy), @@ -3457,7 +2657,7 @@ Value *InnerLoopVectorizer::getOrCreateVectorTripCount(Loop *L) { // does not evenly divide the trip count, no adjustment is necessary since // there will already be scalar iterations. Note that the minimum iterations // check ensures that N >= Step. 
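The two negative-offset GEPs in the CreateVecPtr lambda above are easy to sanity-check with plain pointer arithmetic; a runnable sketch with hypothetical VF/UF values:

#include <cassert>
#include <cstdio>

int main() {
  const int VF = 4, UF = 2;
  int Data[64];
  for (int I = 0; I < 64; ++I) Data[I] = I;
  int *Ptr = Data + 32; // stands in for the current scalar iteration's address

  for (int Part = 0; Part < UF; ++Part) {
    // The reverse case: GEP(Ptr, -Part * VF) followed by GEP(..., 1 - VF).
    int *PartPtr = (Ptr - Part * VF) + (1 - VF);
    // After reverseVector(), lane L holds the element at Ptr - Part*VF - L.
    for (int Lane = 0; Lane < VF; ++Lane)
      assert(PartPtr[VF - 1 - Lane] == *(Ptr - Part * VF - Lane));
    printf("part %d covers elements [%td, %td]\n", Part,
           PartPtr - Data, (PartPtr + VF - 1) - Data);
  }
  return 0;
}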
- if (VF > 1 && Legal->requiresScalarEpilogue()) { + if (VF > 1 && Cost->requiresScalarEpilogue()) { auto *IsZero = Builder.CreateICmpEQ(R, ConstantInt::get(R->getType(), 0)); R = Builder.CreateSelect(IsZero, Step, R); } @@ -3508,8 +2708,8 @@ void InnerLoopVectorizer::emitMinimumIterationCountCheck(Loop *L, // vector trip count is zero. This check also covers the case where adding one // to the backedge-taken count overflowed leading to an incorrect trip count // of zero. In this case we will also jump to the scalar loop. - auto P = Legal->requiresScalarEpilogue() ? ICmpInst::ICMP_ULE - : ICmpInst::ICMP_ULT; + auto P = Cost->requiresScalarEpilogue() ? ICmpInst::ICMP_ULE + : ICmpInst::ICMP_ULT; Value *CheckMinIters = Builder.CreateICmp( P, Count, ConstantInt::get(Count->getType(), VF * UF), "min.iters.check"); @@ -3714,6 +2914,8 @@ BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton() { // Create phi nodes to merge from the backedge-taken check block. PHINode *BCResumeVal = PHINode::Create( OrigPhi->getType(), 3, "bc.resume.val", ScalarPH->getTerminator()); + // Copy original phi DL over to the new one. + BCResumeVal->setDebugLoc(OrigPhi->getDebugLoc()); Value *&EndValue = IVEndValues[OrigPhi]; if (OrigPhi == OldInduction) { // We know what the end value is. @@ -3871,7 +3073,7 @@ struct CSEDenseMapInfo { } // end anonymous namespace -///\brief Perform cse of induction variable instructions. +///Perform cse of induction variable instructions. static void cse(BasicBlock *BB) { // Perform simple cse. SmallDenseMap<Instruction *, Instruction *, 4, CSEDenseMapInfo> CSEMap; @@ -3893,7 +3095,7 @@ static void cse(BasicBlock *BB) { } } -/// \brief Estimate the overhead of scalarizing an instruction. This is a +/// Estimate the overhead of scalarizing an instruction. This is a /// convenience wrapper for the type-based getScalarizationOverhead API. static unsigned getScalarizationOverhead(Instruction *I, unsigned VF, const TargetTransformInfo &TTI) { @@ -4074,7 +3276,7 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() { SI->getOperand(1), VectorType::get(ScalarTruncatedTy, Elements1)); NewI = B.CreateShuffleVector(O0, O1, SI->getMask()); - } else if (isa<LoadInst>(I)) { + } else if (isa<LoadInst>(I) || isa<PHINode>(I)) { // Don't do anything with the operands, just extend the result. continue; } else if (auto *IE = dyn_cast<InsertElementInst>(I)) { @@ -4089,7 +3291,8 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() { EE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements)); NewI = B.CreateExtractElement(O0, EE->getOperand(2)); } else { - llvm_unreachable("Unhandled instruction type!"); + // If we don't know what to do, be conservative and don't do anything. + continue; } // Lastly, extend the result. @@ -4164,15 +3367,12 @@ void InnerLoopVectorizer::fixCrossIterationPHIs() { // the currently empty PHI nodes. At this point every instruction in the // original loop is widened to a vector form so we can use them to construct // the incoming edges. - for (Instruction &I : *OrigLoop->getHeader()) { - PHINode *Phi = dyn_cast<PHINode>(&I); - if (!Phi) - break; + for (PHINode &Phi : OrigLoop->getHeader()->phis()) { // Handle first-order recurrences and reductions that need to be fixed. 
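The vector trip count and epilogue-peeling logic above (the VF > 1 case) reduces to a few lines of arithmetic; a runnable sketch with Step = VF * UF:

#include <cassert>

static unsigned vectorTripCount(unsigned N, unsigned Step,
                                bool RequiresScalarEpilogue) {
  unsigned R = N % Step;
  if (RequiresScalarEpilogue && R == 0)
    R = Step; // the IsZero select: peel one full vector step so the
              // scalar epilogue always executes at least once
  return N - R;
}

int main() {
  assert(vectorTripCount(13, 8, false) == 8);  // 5 scalar iterations remain
  assert(vectorTripCount(16, 8, false) == 16); // divides evenly, no remainder
  assert(vectorTripCount(16, 8, true) == 8);   // peeled: epilogue handles 8
  return 0;
}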
- if (Legal->isFirstOrderRecurrence(Phi)) - fixFirstOrderRecurrence(Phi); - else if (Legal->isReductionVariable(Phi)) - fixReduction(Phi); + if (Legal->isFirstOrderRecurrence(&Phi)) + fixFirstOrderRecurrence(&Phi); + else if (Legal->isReductionVariable(&Phi)) + fixReduction(&Phi); } } @@ -4335,15 +3535,11 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) { // Finally, fix users of the recurrence outside the loop. The users will need // either the last value of the scalar recurrence or the last value of the // vector recurrence we extracted in the middle block. Since the loop is in - // LCSSA form, we just need to find the phi node for the original scalar + // LCSSA form, we just need to find all the phi nodes for the original scalar // recurrence in the exit block, and then add an edge for the middle block. - for (auto &I : *LoopExitBlock) { - auto *LCSSAPhi = dyn_cast<PHINode>(&I); - if (!LCSSAPhi) - break; - if (LCSSAPhi->getIncomingValue(0) == Phi) { - LCSSAPhi->addIncoming(ExtractForPhiUsedOutsideLoop, LoopMiddleBlock); - break; + for (PHINode &LCSSAPhi : LoopExitBlock->phis()) { + if (LCSSAPhi.getIncomingValue(0) == Phi) { + LCSSAPhi.addIncoming(ExtractForPhiUsedOutsideLoop, LoopMiddleBlock); } } } @@ -4499,21 +3695,15 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) { // inside and outside of the scalar remainder loop. // We know that the loop is in LCSSA form. We need to update the // PHI nodes in the exit blocks. - for (BasicBlock::iterator LEI = LoopExitBlock->begin(), - LEE = LoopExitBlock->end(); - LEI != LEE; ++LEI) { - PHINode *LCSSAPhi = dyn_cast<PHINode>(LEI); - if (!LCSSAPhi) - break; - + for (PHINode &LCSSAPhi : LoopExitBlock->phis()) { // All PHINodes need to have a single entry edge, or two if // we already fixed them. - assert(LCSSAPhi->getNumIncomingValues() < 3 && "Invalid LCSSA PHI"); + assert(LCSSAPhi.getNumIncomingValues() < 3 && "Invalid LCSSA PHI"); // We found a reduction value exit-PHI. Update it with the // incoming bypass edge. - if (LCSSAPhi->getIncomingValue(0) == LoopExitInst) - LCSSAPhi->addIncoming(ReducedPartRdx, LoopMiddleBlock); + if (LCSSAPhi.getIncomingValue(0) == LoopExitInst) + LCSSAPhi.addIncoming(ReducedPartRdx, LoopMiddleBlock); } // end of the LCSSA phi scan. // Fix the scalar loop reduction variable with the incoming reduction sum @@ -4528,14 +3718,11 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) { } void InnerLoopVectorizer::fixLCSSAPHIs() { - for (Instruction &LEI : *LoopExitBlock) { - auto *LCSSAPhi = dyn_cast<PHINode>(&LEI); - if (!LCSSAPhi) - break; - if (LCSSAPhi->getNumIncomingValues() == 1) { - assert(OrigLoop->isLoopInvariant(LCSSAPhi->getIncomingValue(0)) && + for (PHINode &LCSSAPhi : LoopExitBlock->phis()) { + if (LCSSAPhi.getNumIncomingValues() == 1) { + assert(OrigLoop->isLoopInvariant(LCSSAPhi.getIncomingValue(0)) && "Incoming value isn't loop invariant"); - LCSSAPhi->addIncoming(LCSSAPhi->getIncomingValue(0), LoopMiddleBlock); + LCSSAPhi.addIncoming(LCSSAPhi.getIncomingValue(0), LoopMiddleBlock); } } } @@ -4955,7 +4142,7 @@ void InnerLoopVectorizer::widenInstruction(Instruction &I) { default: // This instruction is not vectorized by simple widening. - DEBUG(dbgs() << "LV: Found an unhandled instruction: " << I); + LLVM_DEBUG(dbgs() << "LV: Found an unhandled instruction: " << I); llvm_unreachable("Unhandled instruction!"); } // end of switch. 
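For reference, a first-order recurrence of the kind fixFirstOrderRecurrence() patches up (sketch):

void delta(int *Out, const int *In, int N) {
  int Prev = 0;              // becomes the recurrence phi in the loop header
  for (int I = 0; I < N; ++I) {
    Out[I] = In[I] - Prev;   // uses the value produced last iteration
    Prev = In[I];
  }
  // Vectorized, each part shuffles the previous part's last lane in front of
  // its own first VF-1 lanes; the middle block extracts the final lane for
  // users outside the loop, which is what the LCSSA phi loop above wires up.
}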
} @@ -4973,467 +4160,7 @@ void InnerLoopVectorizer::updateAnalysis() { DT->addNewBlock(LoopScalarPreHeader, LoopBypassBlocks[0]); DT->changeImmediateDominator(LoopScalarBody, LoopScalarPreHeader); DT->changeImmediateDominator(LoopExitBlock, LoopBypassBlocks[0]); - DEBUG(DT->verifyDomTree()); -} - -/// \brief Check whether it is safe to if-convert this phi node. -/// -/// Phi nodes with constant expressions that can trap are not safe to if -/// convert. -static bool canIfConvertPHINodes(BasicBlock *BB) { - for (Instruction &I : *BB) { - auto *Phi = dyn_cast<PHINode>(&I); - if (!Phi) - return true; - for (Value *V : Phi->incoming_values()) - if (auto *C = dyn_cast<Constant>(V)) - if (C->canTrap()) - return false; - } - return true; -} - -bool LoopVectorizationLegality::canVectorizeWithIfConvert() { - if (!EnableIfConversion) { - ORE->emit(createMissedAnalysis("IfConversionDisabled") - << "if-conversion is disabled"); - return false; - } - - assert(TheLoop->getNumBlocks() > 1 && "Single block loops are vectorizable"); - - // A list of pointers that we can safely read and write to. - SmallPtrSet<Value *, 8> SafePointes; - - // Collect safe addresses. - for (BasicBlock *BB : TheLoop->blocks()) { - if (blockNeedsPredication(BB)) - continue; - - for (Instruction &I : *BB) - if (auto *Ptr = getPointerOperand(&I)) - SafePointes.insert(Ptr); - } - - // Collect the blocks that need predication. - BasicBlock *Header = TheLoop->getHeader(); - for (BasicBlock *BB : TheLoop->blocks()) { - // We don't support switch statements inside loops. - if (!isa<BranchInst>(BB->getTerminator())) { - ORE->emit(createMissedAnalysis("LoopContainsSwitch", BB->getTerminator()) - << "loop contains a switch statement"); - return false; - } - - // We must be able to predicate all blocks that need to be predicated. - if (blockNeedsPredication(BB)) { - if (!blockCanBePredicated(BB, SafePointes)) { - ORE->emit(createMissedAnalysis("NoCFGForSelect", BB->getTerminator()) - << "control flow cannot be substituted for a select"); - return false; - } - } else if (BB != Header && !canIfConvertPHINodes(BB)) { - ORE->emit(createMissedAnalysis("NoCFGForSelect", BB->getTerminator()) - << "control flow cannot be substituted for a select"); - return false; - } - } - - // We can if-convert this loop. - return true; -} - -bool LoopVectorizationLegality::canVectorize() { - // Store the result and return it at the end instead of exiting early, in case - // allowExtraAnalysis is used to report multiple reasons for not vectorizing. - bool Result = true; - - bool DoExtraAnalysis = ORE->allowExtraAnalysis(DEBUG_TYPE); - // We must have a loop in canonical form. Loops with indirectbr in them cannot - // be canonicalized. - if (!TheLoop->getLoopPreheader()) { - DEBUG(dbgs() << "LV: Loop doesn't have a legal pre-header.\n"); - ORE->emit(createMissedAnalysis("CFGNotUnderstood") - << "loop control flow is not understood by vectorizer"); - if (DoExtraAnalysis) - Result = false; - else - return false; - } - - // FIXME: The code is currently dead, since the loop gets sent to - // LoopVectorizationLegality is already an innermost loop. - // - // We can only vectorize innermost loops. - if (!TheLoop->empty()) { - ORE->emit(createMissedAnalysis("NotInnermostLoop") - << "loop is not the innermost loop"); - if (DoExtraAnalysis) - Result = false; - else - return false; - } - - // We must have a single backedge. 
- if (TheLoop->getNumBackEdges() != 1) { - ORE->emit(createMissedAnalysis("CFGNotUnderstood") - << "loop control flow is not understood by vectorizer"); - if (DoExtraAnalysis) - Result = false; - else - return false; - } - - // We must have a single exiting block. - if (!TheLoop->getExitingBlock()) { - ORE->emit(createMissedAnalysis("CFGNotUnderstood") - << "loop control flow is not understood by vectorizer"); - if (DoExtraAnalysis) - Result = false; - else - return false; - } - - // We only handle bottom-tested loops, i.e. loop in which the condition is - // checked at the end of each iteration. With that we can assume that all - // instructions in the loop are executed the same number of times. - if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) { - ORE->emit(createMissedAnalysis("CFGNotUnderstood") - << "loop control flow is not understood by vectorizer"); - if (DoExtraAnalysis) - Result = false; - else - return false; - } - - // We need to have a loop header. - DEBUG(dbgs() << "LV: Found a loop: " << TheLoop->getHeader()->getName() - << '\n'); - - // Check if we can if-convert non-single-bb loops. - unsigned NumBlocks = TheLoop->getNumBlocks(); - if (NumBlocks != 1 && !canVectorizeWithIfConvert()) { - DEBUG(dbgs() << "LV: Can't if-convert the loop.\n"); - if (DoExtraAnalysis) - Result = false; - else - return false; - } - - // Check if we can vectorize the instructions and CFG in this loop. - if (!canVectorizeInstrs()) { - DEBUG(dbgs() << "LV: Can't vectorize the instructions or CFG\n"); - if (DoExtraAnalysis) - Result = false; - else - return false; - } - - // Go over each instruction and look at memory deps. - if (!canVectorizeMemory()) { - DEBUG(dbgs() << "LV: Can't vectorize due to memory conflicts\n"); - if (DoExtraAnalysis) - Result = false; - else - return false; - } - - DEBUG(dbgs() << "LV: We can vectorize this loop" - << (LAI->getRuntimePointerChecking()->Need - ? " (with a runtime bound check)" - : "") - << "!\n"); - - bool UseInterleaved = TTI->enableInterleavedAccessVectorization(); - - // If an override option has been passed in for interleaved accesses, use it. - if (EnableInterleavedMemAccesses.getNumOccurrences() > 0) - UseInterleaved = EnableInterleavedMemAccesses; - - // Analyze interleaved memory accesses. - if (UseInterleaved) - InterleaveInfo.analyzeInterleaving(*getSymbolicStrides()); - - unsigned SCEVThreshold = VectorizeSCEVCheckThreshold; - if (Hints->getForce() == LoopVectorizeHints::FK_Enabled) - SCEVThreshold = PragmaVectorizeSCEVCheckThreshold; - - if (PSE.getUnionPredicate().getComplexity() > SCEVThreshold) { - ORE->emit(createMissedAnalysis("TooManySCEVRunTimeChecks") - << "Too many SCEV assumptions need to be made and checked " - << "at runtime"); - DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n"); - if (DoExtraAnalysis) - Result = false; - else - return false; - } - - // Okay! We've done all the tests. If any have failed, return false. Otherwise - // we can vectorize, and at this point we don't have any other mem analysis - // which may limit our maximum vectorization factor, so just return true with - // no restrictions. - return Result; -} - -static Type *convertPointerToIntegerType(const DataLayout &DL, Type *Ty) { - if (Ty->isPointerTy()) - return DL.getIntPtrType(Ty); - - // It is possible that char's or short's overflow when we ask for the loop's - // trip count, work around this by changing the type size. 
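The structural checks above (single backedge, single exiting block, bottom-tested) accept and reject shapes like these (a sketch; assumes no earlier transform has restructured the loops):

void accepted(int *A, int N) {
  for (int I = 0; I < N; ++I) // one latch, which is also the only exiting
    A[I] += 1;                // block; the condition is tested at the bottom
}

void rejected(int *A, int N) {
  for (int I = 0; I < N; ++I) {
    if (A[I] < 0)
      break;                  // a second exiting block: getExitingBlock()
    A[I] += 1;                // returns null and the loop is rejected
  }
}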
- if (Ty->getScalarSizeInBits() < 32) - return Type::getInt32Ty(Ty->getContext()); - - return Ty; -} - -static Type *getWiderType(const DataLayout &DL, Type *Ty0, Type *Ty1) { - Ty0 = convertPointerToIntegerType(DL, Ty0); - Ty1 = convertPointerToIntegerType(DL, Ty1); - if (Ty0->getScalarSizeInBits() > Ty1->getScalarSizeInBits()) - return Ty0; - return Ty1; -} - -/// \brief Check that the instruction has outside loop users and is not an -/// identified reduction variable. -static bool hasOutsideLoopUser(const Loop *TheLoop, Instruction *Inst, - SmallPtrSetImpl<Value *> &AllowedExit) { - // Reduction and Induction instructions are allowed to have exit users. All - // other instructions must not have external users. - if (!AllowedExit.count(Inst)) - // Check that all of the users of the loop are inside the BB. - for (User *U : Inst->users()) { - Instruction *UI = cast<Instruction>(U); - // This user may be a reduction exit value. - if (!TheLoop->contains(UI)) { - DEBUG(dbgs() << "LV: Found an outside user for : " << *UI << '\n'); - return true; - } - } - return false; -} - -void LoopVectorizationLegality::addInductionPhi( - PHINode *Phi, const InductionDescriptor &ID, - SmallPtrSetImpl<Value *> &AllowedExit) { - Inductions[Phi] = ID; - - // In case this induction also comes with casts that we know we can ignore - // in the vectorized loop body, record them here. All casts could be recorded - // here for ignoring, but suffices to record only the first (as it is the - // only one that may bw used outside the cast sequence). - const SmallVectorImpl<Instruction *> &Casts = ID.getCastInsts(); - if (!Casts.empty()) - InductionCastsToIgnore.insert(*Casts.begin()); - - Type *PhiTy = Phi->getType(); - const DataLayout &DL = Phi->getModule()->getDataLayout(); - - // Get the widest type. - if (!PhiTy->isFloatingPointTy()) { - if (!WidestIndTy) - WidestIndTy = convertPointerToIntegerType(DL, PhiTy); - else - WidestIndTy = getWiderType(DL, PhiTy, WidestIndTy); - } - - // Int inductions are special because we only allow one IV. - if (ID.getKind() == InductionDescriptor::IK_IntInduction && - ID.getConstIntStepValue() && - ID.getConstIntStepValue()->isOne() && - isa<Constant>(ID.getStartValue()) && - cast<Constant>(ID.getStartValue())->isNullValue()) { - - // Use the phi node with the widest type as induction. Use the last - // one if there are multiple (no good reason for doing this other - // than it is expedient). We've checked that it begins at zero and - // steps by one, so this is a canonical induction variable. - if (!PrimaryInduction || PhiTy == WidestIndTy) - PrimaryInduction = Phi; - } - - // Both the PHI node itself, and the "post-increment" value feeding - // back into the PHI node may have external users. - // We can allow those uses, except if the SCEVs we have for them rely - // on predicates that only hold within the loop, since allowing the exit - // currently means re-using this SCEV outside the loop. - if (PSE.getUnionPredicate().isAlwaysTrue()) { - AllowedExit.insert(Phi); - AllowedExit.insert(Phi->getIncomingValueForBlock(TheLoop->getLoopLatch())); - } - - DEBUG(dbgs() << "LV: Found an induction variable.\n"); -} - -bool LoopVectorizationLegality::canVectorizeInstrs() { - BasicBlock *Header = TheLoop->getHeader(); - - // Look for the attribute signaling the absence of NaNs. - Function &F = *Header->getParent(); - HasFunNoNaNAttr = - F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true"; - - // For each block in the loop. 
-  for (BasicBlock *BB : TheLoop->blocks()) {
-    // Scan the instructions in the block and look for hazards.
-    for (Instruction &I : *BB) {
-      if (auto *Phi = dyn_cast<PHINode>(&I)) {
-        Type *PhiTy = Phi->getType();
-        // Check that this PHI type is allowed.
-        if (!PhiTy->isIntegerTy() && !PhiTy->isFloatingPointTy() &&
-            !PhiTy->isPointerTy()) {
-          ORE->emit(createMissedAnalysis("CFGNotUnderstood", Phi)
-                    << "loop control flow is not understood by vectorizer");
-          DEBUG(dbgs() << "LV: Found a non-int non-pointer PHI.\n");
-          return false;
-        }
-
-        // If this PHINode is not in the header block, then we know that we
-        // can convert it to select during if-conversion. No need to check if
-        // the PHIs in this block are induction or reduction variables.
-        if (BB != Header) {
-          // Check that this instruction has no outside users or is an
-          // identified reduction value with an outside user.
-          if (!hasOutsideLoopUser(TheLoop, Phi, AllowedExit))
-            continue;
-          ORE->emit(createMissedAnalysis("NeitherInductionNorReduction", Phi)
-                    << "value could not be identified as "
-                       "an induction or reduction variable");
-          return false;
-        }
-
-        // We only allow if-converted PHIs with exactly two incoming values.
-        if (Phi->getNumIncomingValues() != 2) {
-          ORE->emit(createMissedAnalysis("CFGNotUnderstood", Phi)
-                    << "control flow not understood by vectorizer");
-          DEBUG(dbgs() << "LV: Found an invalid PHI.\n");
-          return false;
-        }
-
-        RecurrenceDescriptor RedDes;
-        if (RecurrenceDescriptor::isReductionPHI(Phi, TheLoop, RedDes)) {
-          if (RedDes.hasUnsafeAlgebra())
-            Requirements->addUnsafeAlgebraInst(RedDes.getUnsafeAlgebraInst());
-          AllowedExit.insert(RedDes.getLoopExitInstr());
-          Reductions[Phi] = RedDes;
-          continue;
-        }
-
-        InductionDescriptor ID;
-        if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID)) {
-          addInductionPhi(Phi, ID, AllowedExit);
-          if (ID.hasUnsafeAlgebra() && !HasFunNoNaNAttr)
-            Requirements->addUnsafeAlgebraInst(ID.getUnsafeAlgebraInst());
-          continue;
-        }
-
-        if (RecurrenceDescriptor::isFirstOrderRecurrence(Phi, TheLoop,
-                                                         SinkAfter, DT)) {
-          FirstOrderRecurrences.insert(Phi);
-          continue;
-        }
-
-        // As a last resort, coerce the PHI to an AddRec expression
-        // and re-try classifying it as an induction PHI.
-        if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID, true)) {
-          addInductionPhi(Phi, ID, AllowedExit);
-          continue;
-        }
-
-        ORE->emit(createMissedAnalysis("NonReductionValueUsedOutsideLoop", Phi)
-                  << "value that could not be identified as "
-                     "reduction is used outside the loop");
-        DEBUG(dbgs() << "LV: Found an unidentified PHI." << *Phi << "\n");
-        return false;
-      } // end of PHI handling
-
-      // We handle calls that:
-      //   * Are debug info intrinsics.
-      //   * Have a mapping to an IR intrinsic.
-      //   * Have a vector version available.
-      auto *CI = dyn_cast<CallInst>(&I);
-      if (CI && !getVectorIntrinsicIDForCall(CI, TLI) &&
-          !isa<DbgInfoIntrinsic>(CI) &&
-          !(CI->getCalledFunction() && TLI &&
-            TLI->isFunctionVectorizable(CI->getCalledFunction()->getName()))) {
-        ORE->emit(createMissedAnalysis("CantVectorizeCall", CI)
-                  << "call instruction cannot be vectorized");
-        DEBUG(dbgs() << "LV: Found a non-intrinsic, non-libfunc callsite.\n");
-        return false;
-      }
-
-      // Intrinsics such as powi, cttz and ctlz are legal to vectorize if the
-      // second argument is the same (i.e.
loop invariant) - if (CI && hasVectorInstrinsicScalarOpd( - getVectorIntrinsicIDForCall(CI, TLI), 1)) { - auto *SE = PSE.getSE(); - if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(1)), TheLoop)) { - ORE->emit(createMissedAnalysis("CantVectorizeIntrinsic", CI) - << "intrinsic instruction cannot be vectorized"); - DEBUG(dbgs() << "LV: Found unvectorizable intrinsic " << *CI << "\n"); - return false; - } - } - - // Check that the instruction return type is vectorizable. - // Also, we can't vectorize extractelement instructions. - if ((!VectorType::isValidElementType(I.getType()) && - !I.getType()->isVoidTy()) || - isa<ExtractElementInst>(I)) { - ORE->emit(createMissedAnalysis("CantVectorizeInstructionReturnType", &I) - << "instruction return type cannot be vectorized"); - DEBUG(dbgs() << "LV: Found unvectorizable type.\n"); - return false; - } - - // Check that the stored type is vectorizable. - if (auto *ST = dyn_cast<StoreInst>(&I)) { - Type *T = ST->getValueOperand()->getType(); - if (!VectorType::isValidElementType(T)) { - ORE->emit(createMissedAnalysis("CantVectorizeStore", ST) - << "store instruction cannot be vectorized"); - return false; - } - - // FP instructions can allow unsafe algebra, thus vectorizable by - // non-IEEE-754 compliant SIMD units. - // This applies to floating-point math operations and calls, not memory - // operations, shuffles, or casts, as they don't change precision or - // semantics. - } else if (I.getType()->isFloatingPointTy() && (CI || I.isBinaryOp()) && - !I.isFast()) { - DEBUG(dbgs() << "LV: Found FP op with unsafe algebra.\n"); - Hints->setPotentiallyUnsafe(); - } - - // Reduction instructions are allowed to have exit users. - // All other instructions must not have external users. - if (hasOutsideLoopUser(TheLoop, &I, AllowedExit)) { - ORE->emit(createMissedAnalysis("ValueUsedOutsideLoop", &I) - << "value cannot be used outside the loop"); - return false; - } - } // next instr. - } - - if (!PrimaryInduction) { - DEBUG(dbgs() << "LV: Did not find one integer induction var.\n"); - if (Inductions.empty()) { - ORE->emit(createMissedAnalysis("NoInductionVariable") - << "loop induction variable could not be identified"); - return false; - } - } - - // Now we know the widest induction type, check if our found induction - // is the same size. If it's not, unset it here and InnerLoopVectorizer - // will create another. 
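// Illustrative aside, not part of this patch: a source loop that trips the
// !I.isFast() check above. Vectorizing an FP reduction reassociates the
// additions, so the loop is only marked potentially unsafe, not rejected.
float sumArray(const float *X, int N) {
  float S = 0.0f;
  for (int I = 0; I < N; ++I)
    S += X[I];  // rounding of S depends on association order
  return S;
}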
- if (PrimaryInduction && WidestIndTy != PrimaryInduction->getType()) - PrimaryInduction = nullptr; - - return true; + assert(DT->verify(DominatorTree::VerificationLevel::Fast)); } void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { @@ -5461,7 +4188,7 @@ void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { if (auto *Store = dyn_cast<StoreInst>(MemAccess)) if (Ptr == Store->getValueOperand()) return WideningDecision == CM_Scalarize; - assert(Ptr == getPointerOperand(MemAccess) && + assert(Ptr == getLoadStorePointerOperand(MemAccess) && "Ptr is neither a value or pointer operand"); return WideningDecision != CM_GatherScatter; }; @@ -5527,7 +4254,7 @@ void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { } for (auto *I : ScalarPtrs) if (!PossibleNonScalarPtrs.count(I)) { - DEBUG(dbgs() << "LV: Found scalar instruction: " << *I << "\n"); + LLVM_DEBUG(dbgs() << "LV: Found scalar instruction: " << *I << "\n"); Worklist.insert(I); } @@ -5544,8 +4271,9 @@ void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { continue; Worklist.insert(Ind); Worklist.insert(IndUpdate); - DEBUG(dbgs() << "LV: Found scalar instruction: " << *Ind << "\n"); - DEBUG(dbgs() << "LV: Found scalar instruction: " << *IndUpdate << "\n"); + LLVM_DEBUG(dbgs() << "LV: Found scalar instruction: " << *Ind << "\n"); + LLVM_DEBUG(dbgs() << "LV: Found scalar instruction: " << *IndUpdate + << "\n"); } // Insert the forced scalars. @@ -5572,7 +4300,7 @@ void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { isScalarUse(J, Src)); })) { Worklist.insert(Src); - DEBUG(dbgs() << "LV: Found scalar instruction: " << *Src << "\n"); + LLVM_DEBUG(dbgs() << "LV: Found scalar instruction: " << *Src << "\n"); } } @@ -5612,21 +4340,30 @@ void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { // The induction variable and its update instruction will remain scalar. Worklist.insert(Ind); Worklist.insert(IndUpdate); - DEBUG(dbgs() << "LV: Found scalar instruction: " << *Ind << "\n"); - DEBUG(dbgs() << "LV: Found scalar instruction: " << *IndUpdate << "\n"); + LLVM_DEBUG(dbgs() << "LV: Found scalar instruction: " << *Ind << "\n"); + LLVM_DEBUG(dbgs() << "LV: Found scalar instruction: " << *IndUpdate + << "\n"); } Scalars[VF].insert(Worklist.begin(), Worklist.end()); } -bool LoopVectorizationLegality::isScalarWithPredication(Instruction *I) { - if (!blockNeedsPredication(I->getParent())) +bool LoopVectorizationCostModel::isScalarWithPredication(Instruction *I) { + if (!Legal->blockNeedsPredication(I->getParent())) return false; switch(I->getOpcode()) { default: break; - case Instruction::Store: - return !isMaskRequired(I); + case Instruction::Load: + case Instruction::Store: { + if (!Legal->isMaskRequired(I)) + return false; + auto *Ptr = getLoadStorePointerOperand(I); + auto *Ty = getMemInstValueType(I); + return isa<LoadInst>(I) ? + !(isLegalMaskedLoad(Ty, Ptr) || isLegalMaskedGather(Ty)) + : !(isLegalMaskedStore(Ty, Ptr) || isLegalMaskedScatter(Ty)); + } case Instruction::UDiv: case Instruction::SDiv: case Instruction::SRem: @@ -5636,17 +4373,17 @@ bool LoopVectorizationLegality::isScalarWithPredication(Instruction *I) { return false; } -bool LoopVectorizationLegality::memoryInstructionCanBeWidened(Instruction *I, - unsigned VF) { +bool LoopVectorizationCostModel::memoryInstructionCanBeWidened(Instruction *I, + unsigned VF) { // Get and ensure we have a valid memory instruction. 
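// Illustrative aside, not part of this patch: a source pattern for the
// Load/Store cases added to isScalarWithPredication above. The guarded
// store executes on only some lanes, so it needs a legal masked store or
// scatter on the target, and otherwise stays scalar with predication.
void clampNegatives(int *A, int N) {
  for (int I = 0; I < N; ++I)
    if (A[I] < 0)
      A[I] = 0;  // masked store if legal, else scalarized and predicated
}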
LoadInst *LI = dyn_cast<LoadInst>(I); StoreInst *SI = dyn_cast<StoreInst>(I); assert((LI || SI) && "Invalid memory instruction"); - auto *Ptr = getPointerOperand(I); + auto *Ptr = getLoadStorePointerOperand(I); // In order to be widened, the pointer should be consecutive, first of all. - if (!isConsecutivePtr(Ptr)) + if (!Legal->isConsecutivePtr(Ptr)) return false; // If the instruction is a store located in a predicated block, it will be @@ -5697,7 +4434,7 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { auto *Cmp = dyn_cast<Instruction>(Latch->getTerminator()->getOperand(0)); if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse()) { Worklist.insert(Cmp); - DEBUG(dbgs() << "LV: Found uniform instruction: " << *Cmp << "\n"); + LLVM_DEBUG(dbgs() << "LV: Found uniform instruction: " << *Cmp << "\n"); } // Holds consecutive and consecutive-like pointers. Consecutive-like pointers @@ -5729,7 +4466,7 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { for (auto *BB : TheLoop->blocks()) for (auto &I : *BB) { // If there's no pointer operand, there's nothing to do. - auto *Ptr = dyn_cast_or_null<Instruction>(getPointerOperand(&I)); + auto *Ptr = dyn_cast_or_null<Instruction>(getLoadStorePointerOperand(&I)); if (!Ptr) continue; @@ -5737,7 +4474,7 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { // pointer operand. auto UsersAreMemAccesses = llvm::all_of(Ptr->users(), [&](User *U) -> bool { - return getPointerOperand(U) == Ptr; + return getLoadStorePointerOperand(U) == Ptr; }); // Ensure the memory instruction will not be scalarized or used by @@ -5758,7 +4495,7 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { // aren't also identified as possibly non-uniform. for (auto *V : ConsecutiveLikePtrs) if (!PossibleNonUniformPtrs.count(V)) { - DEBUG(dbgs() << "LV: Found uniform instruction: " << *V << "\n"); + LLVM_DEBUG(dbgs() << "LV: Found uniform instruction: " << *V << "\n"); Worklist.insert(V); } @@ -5777,10 +4514,11 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { if (llvm::all_of(OI->users(), [&](User *U) -> bool { auto *J = cast<Instruction>(U); return !TheLoop->contains(J) || Worklist.count(J) || - (OI == getPointerOperand(J) && isUniformDecision(J, VF)); + (OI == getLoadStorePointerOperand(J) && + isUniformDecision(J, VF)); })) { Worklist.insert(OI); - DEBUG(dbgs() << "LV: Found uniform instruction: " << *OI << "\n"); + LLVM_DEBUG(dbgs() << "LV: Found uniform instruction: " << *OI << "\n"); } } } @@ -5788,7 +4526,7 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { // Returns true if Ptr is the pointer operand of a memory access instruction // I, and I is known to not require scalarization. auto isVectorizedMemAccessUse = [&](Instruction *I, Value *Ptr) -> bool { - return getPointerOperand(I) == Ptr && isUniformDecision(I, VF); + return getLoadStorePointerOperand(I) == Ptr && isUniformDecision(I, VF); }; // For an instruction to be added into Worklist above, all its users inside @@ -5825,123 +4563,14 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { // The induction variable and its update instruction will remain uniform. 
Worklist.insert(Ind); Worklist.insert(IndUpdate); - DEBUG(dbgs() << "LV: Found uniform instruction: " << *Ind << "\n"); - DEBUG(dbgs() << "LV: Found uniform instruction: " << *IndUpdate << "\n"); + LLVM_DEBUG(dbgs() << "LV: Found uniform instruction: " << *Ind << "\n"); + LLVM_DEBUG(dbgs() << "LV: Found uniform instruction: " << *IndUpdate + << "\n"); } Uniforms[VF].insert(Worklist.begin(), Worklist.end()); } -bool LoopVectorizationLegality::canVectorizeMemory() { - LAI = &(*GetLAA)(*TheLoop); - InterleaveInfo.setLAI(LAI); - const OptimizationRemarkAnalysis *LAR = LAI->getReport(); - if (LAR) { - ORE->emit([&]() { - return OptimizationRemarkAnalysis(Hints->vectorizeAnalysisPassName(), - "loop not vectorized: ", *LAR); - }); - } - if (!LAI->canVectorizeMemory()) - return false; - - if (LAI->hasStoreToLoopInvariantAddress()) { - ORE->emit(createMissedAnalysis("CantVectorizeStoreToLoopInvariantAddress") - << "write to a loop invariant address could not be vectorized"); - DEBUG(dbgs() << "LV: We don't allow storing to uniform addresses\n"); - return false; - } - - Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); - PSE.addPredicate(LAI->getPSE().getUnionPredicate()); - - return true; -} - -bool LoopVectorizationLegality::isInductionPhi(const Value *V) { - Value *In0 = const_cast<Value *>(V); - PHINode *PN = dyn_cast_or_null<PHINode>(In0); - if (!PN) - return false; - - return Inductions.count(PN); -} - -bool LoopVectorizationLegality::isCastedInductionVariable(const Value *V) { - auto *Inst = dyn_cast<Instruction>(V); - return (Inst && InductionCastsToIgnore.count(Inst)); -} - -bool LoopVectorizationLegality::isInductionVariable(const Value *V) { - return isInductionPhi(V) || isCastedInductionVariable(V); -} - -bool LoopVectorizationLegality::isFirstOrderRecurrence(const PHINode *Phi) { - return FirstOrderRecurrences.count(Phi); -} - -bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) { - return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT); -} - -bool LoopVectorizationLegality::blockCanBePredicated( - BasicBlock *BB, SmallPtrSetImpl<Value *> &SafePtrs) { - const bool IsAnnotatedParallel = TheLoop->isAnnotatedParallel(); - - for (Instruction &I : *BB) { - // Check that we don't have a constant expression that can trap as operand. - for (Value *Operand : I.operands()) { - if (auto *C = dyn_cast<Constant>(Operand)) - if (C->canTrap()) - return false; - } - // We might be able to hoist the load. - if (I.mayReadFromMemory()) { - auto *LI = dyn_cast<LoadInst>(&I); - if (!LI) - return false; - if (!SafePtrs.count(LI->getPointerOperand())) { - if (isLegalMaskedLoad(LI->getType(), LI->getPointerOperand()) || - isLegalMaskedGather(LI->getType())) { - MaskedOp.insert(LI); - continue; - } - // !llvm.mem.parallel_loop_access implies if-conversion safety. - if (IsAnnotatedParallel) - continue; - return false; - } - } - - if (I.mayWriteToMemory()) { - auto *SI = dyn_cast<StoreInst>(&I); - // We only support predication of stores in basic blocks with one - // predecessor. - if (!SI) - return false; - - // Build a masked store if it is legal for the target. 
- if (isLegalMaskedStore(SI->getValueOperand()->getType(), - SI->getPointerOperand()) || - isLegalMaskedScatter(SI->getValueOperand()->getType())) { - MaskedOp.insert(SI); - continue; - } - - bool isSafePtr = (SafePtrs.count(SI->getPointerOperand()) != 0); - bool isSinglePredecessor = SI->getParent()->getSinglePredecessor(); - - if (++NumPredStores > NumberOfStoresToPredicate || !isSafePtr || - !isSinglePredecessor) - return false; - } - if (I.mayThrow()) - return false; - } - - return true; -} - void InterleavedAccessInfo::collectConstStrideAccesses( MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo, const ValueToValueMap &Strides) { @@ -5962,7 +4591,7 @@ void InterleavedAccessInfo::collectConstStrideAccesses( if (!LI && !SI) continue; - Value *Ptr = getPointerOperand(&I); + Value *Ptr = getLoadStorePointerOperand(&I); // We don't check wrapping here because we don't know yet if Ptr will be // part of a full group or a group with gaps. Checking wrapping for all // pointers (even those that end up in groups with no gaps) will be overly @@ -6022,9 +4651,9 @@ void InterleavedAccessInfo::collectConstStrideAccesses( // this group because it and (2) are dependent. However, (1) can be grouped // with other accesses that may precede it in program order. Note that a // bottom-up order does not imply that WAW dependences should not be checked. -void InterleavedAccessInfo::analyzeInterleaving( - const ValueToValueMap &Strides) { - DEBUG(dbgs() << "LV: Analyzing interleaved accesses...\n"); +void InterleavedAccessInfo::analyzeInterleaving() { + LLVM_DEBUG(dbgs() << "LV: Analyzing interleaved accesses...\n"); + const ValueToValueMap &Strides = LAI->getSymbolicStrides(); // Holds all accesses with a constant stride. MapVector<Instruction *, StrideDescriptor> AccessStrideInfo; @@ -6065,7 +4694,8 @@ void InterleavedAccessInfo::analyzeInterleaving( if (isStrided(DesB.Stride)) { Group = getInterleaveGroup(B); if (!Group) { - DEBUG(dbgs() << "LV: Creating an interleave group with:" << *B << '\n'); + LLVM_DEBUG(dbgs() << "LV: Creating an interleave group with:" << *B + << '\n'); Group = createInterleaveGroup(B, DesB.Stride, DesB.Align); } if (B->mayWriteToMemory()) @@ -6124,7 +4754,12 @@ void InterleavedAccessInfo::analyzeInterleaving( // Ignore A if it's already in a group or isn't the same kind of memory // operation as B. - if (isInterleaved(A) || A->mayReadFromMemory() != B->mayReadFromMemory()) + // Note that mayReadFromMemory() isn't mutually exclusive to mayWriteToMemory + // in the case of atomic loads. We shouldn't see those here, canVectorizeMemory() + // should have returned false - except for the case we asked for optimization + // remarks. + if (isInterleaved(A) || (A->mayReadFromMemory() != B->mayReadFromMemory()) + || (A->mayWriteToMemory() != B->mayWriteToMemory())) continue; // Check rules 1 and 2. Ignore A if its stride or size is different from @@ -6163,8 +4798,9 @@ void InterleavedAccessInfo::analyzeInterleaving( // Try to insert A into B's group. if (Group->insertMember(A, IndexA, DesA.Align)) { - DEBUG(dbgs() << "LV: Inserted:" << *A << '\n' - << " into the interleave group with" << *B << '\n'); + LLVM_DEBUG(dbgs() << "LV: Inserted:" << *A << '\n' + << " into the interleave group with" << *B + << '\n'); InterleaveGroupMap[A] = Group; // Set the first load in program order as the insert position. @@ -6177,8 +4813,9 @@ void InterleavedAccessInfo::analyzeInterleaving( // Remove interleaved store groups with gaps. 
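// Illustrative aside, not part of this patch: an access pattern the
// grouping logic above targets. The two stride-2 loads off the same base
// become members 0 and 1 of one interleave group of factor 2, which can be
// lowered to a single wide load plus shuffles.
void deinterleave(float *Re, float *Im, const float *C, int N) {
  for (int I = 0; I < N; ++I) {
    Re[I] = C[2 * I];      // group member at index 0
    Im[I] = C[2 * I + 1];  // group member at index 1
  }
}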
  for (InterleaveGroup *Group : StoreGroups)
    if (Group->getNumMembers() != Group->getFactor()) {
-      DEBUG(dbgs() << "LV: Invalidate candidate interleaved store group due "
-                      "to gaps.\n");
+      LLVM_DEBUG(
+          dbgs() << "LV: Invalidate candidate interleaved store group due "
+                    "to gaps.\n");
      releaseGroup(Group);
    }
  // Remove interleaved groups with gaps (currently only loads) whose memory
@@ -6207,21 +4844,23 @@ void InterleavedAccessInfo::analyzeInterleaving(
  // So we check only group member 0 (which is always guaranteed to exist),
  // and group member Factor - 1; If the latter doesn't exist we rely on
  // peeling (if it is a non-reversed access -- see Case 3).
-  Value *FirstMemberPtr = getPointerOperand(Group->getMember(0));
+  Value *FirstMemberPtr = getLoadStorePointerOperand(Group->getMember(0));
  if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false,
                    /*ShouldCheckWrap=*/true)) {
-    DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to "
-                    "first group member potentially pointer-wrapping.\n");
+    LLVM_DEBUG(
+        dbgs() << "LV: Invalidate candidate interleaved group due to "
+                  "first group member potentially pointer-wrapping.\n");
    releaseGroup(Group);
    continue;
  }
  Instruction *LastMember = Group->getMember(Group->getFactor() - 1);
  if (LastMember) {
-    Value *LastMemberPtr = getPointerOperand(LastMember);
+    Value *LastMemberPtr = getLoadStorePointerOperand(LastMember);
    if (!getPtrStride(PSE, LastMemberPtr, TheLoop, Strides, /*Assume=*/false,
                      /*ShouldCheckWrap=*/true)) {
-      DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to "
-                      "last group member potentially pointer-wrapping.\n");
+      LLVM_DEBUG(
+          dbgs() << "LV: Invalidate candidate interleaved group due to "
+                    "last group member potentially pointer-wrapping.\n");
      releaseGroup(Group);
    }
  } else {
@@ -6231,29 +4870,25 @@ void InterleavedAccessInfo::analyzeInterleaving(
    // to look for a member at index factor - 1, since every group must have
    // a member at index zero.
    if (Group->isReverse()) {
-      DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to "
-                      "a reverse access with gaps.\n");
+      LLVM_DEBUG(
+          dbgs() << "LV: Invalidate candidate interleaved group due to "
+                    "a reverse access with gaps.\n");
      releaseGroup(Group);
      continue;
    }
-    DEBUG(dbgs() << "LV: Interleaved group requires epilogue iteration.\n");
+    LLVM_DEBUG(
+        dbgs() << "LV: Interleaved group requires epilogue iteration.\n");
    RequiresScalarEpilogue = true;
  }
}

Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) {
-  if (!EnableCondStoresVectorization && Legal->getNumPredStores()) {
-    ORE->emit(createMissedAnalysis("ConditionalStore")
-              << "store that is conditionally executed prevents vectorization");
-    DEBUG(dbgs() << "LV: No vectorization. There are conditional stores.\n");
-    return None;
-  }
-
  if (Legal->getRuntimePointerChecking()->Need && TTI.hasBranchDivergence()) {
    // TODO: It may be useful to do since it's still likely to be dynamically
    // uniform if the target can skip.
-    DEBUG(dbgs() << "LV: Not inserting runtime ptr check for divergent target");
+    LLVM_DEBUG(
+        dbgs() << "LV: Not inserting runtime ptr check for divergent target");

    ORE->emit(
        createMissedAnalysis("CantVersionLoopWithDivergentTarget")
@@ -6271,20 +4906,22 @@ Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) {
              << "runtime pointer checks needed. Enable vectorization of this "
                 "loop with '#pragma clang loop vectorize(enable)' when "
                 "compiling with -Os/-Oz");
-    DEBUG(dbgs()
-          << "LV: Aborting.
Runtime ptr check is required with -Os/-Oz.\n"); + LLVM_DEBUG( + dbgs() + << "LV: Aborting. Runtime ptr check is required with -Os/-Oz.\n"); return None; } // If we optimize the program for size, avoid creating the tail loop. - DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); + LLVM_DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); // If we don't know the precise trip count, don't try to vectorize. if (TC < 2) { ORE->emit( createMissedAnalysis("UnknownLoopCountComplexCFG") << "unable to calculate the loop count due to complex control flow"); - DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); + LLVM_DEBUG( + dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); return None; } @@ -6302,7 +4939,8 @@ Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) { "same time. Enable vectorization of this loop " "with '#pragma clang loop vectorize(enable)' " "when compiling with -Os/-Oz"); - DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); + LLVM_DEBUG( + dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); return None; } @@ -6327,29 +4965,30 @@ LoopVectorizationCostModel::computeFeasibleMaxVF(bool OptForSize, unsigned MaxVectorSize = WidestRegister / WidestType; - DEBUG(dbgs() << "LV: The Smallest and Widest types: " << SmallestType << " / " - << WidestType << " bits.\n"); - DEBUG(dbgs() << "LV: The Widest register safe to use is: " << WidestRegister - << " bits.\n"); + LLVM_DEBUG(dbgs() << "LV: The Smallest and Widest types: " << SmallestType + << " / " << WidestType << " bits.\n"); + LLVM_DEBUG(dbgs() << "LV: The Widest register safe to use is: " + << WidestRegister << " bits.\n"); - assert(MaxVectorSize <= 64 && "Did not expect to pack so many elements" - " into one vector!"); + assert(MaxVectorSize <= 256 && "Did not expect to pack so many elements" + " into one vector!"); if (MaxVectorSize == 0) { - DEBUG(dbgs() << "LV: The target has no vector registers.\n"); + LLVM_DEBUG(dbgs() << "LV: The target has no vector registers.\n"); MaxVectorSize = 1; return MaxVectorSize; } else if (ConstTripCount && ConstTripCount < MaxVectorSize && isPowerOf2_32(ConstTripCount)) { // We need to clamp the VF to be the ConstTripCount. There is no point in // choosing a higher viable VF as done in the loop below. - DEBUG(dbgs() << "LV: Clamping the MaxVF to the constant trip count: " - << ConstTripCount << "\n"); + LLVM_DEBUG(dbgs() << "LV: Clamping the MaxVF to the constant trip count: " + << ConstTripCount << "\n"); MaxVectorSize = ConstTripCount; return MaxVectorSize; } unsigned MaxVF = MaxVectorSize; - if (MaximizeBandwidth && !OptForSize) { + if (TTI.shouldMaximizeVectorBandwidth(OptForSize) || + (MaximizeBandwidth && !OptForSize)) { // Collect all viable vectorization factors larger than the default MaxVF // (i.e. MaxVectorSize). 
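// Illustrative aside, not part of this patch: the width arithmetic above
// with assumed numbers. 256-bit registers and a widest element type of 32
// bits give 8 lanes; a constant trip count of 4 (a power of two) would then
// clamp MaxVF down to 4.
constexpr unsigned WidestRegisterBits = 256;
constexpr unsigned WidestTypeBits = 32;
constexpr unsigned MaxVectorSize = WidestRegisterBits / WidestTypeBits; // 8
static_assert(MaxVectorSize <= 256, "within the relaxed assert above");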
SmallVector<unsigned, 8> VFs; @@ -6369,24 +5008,30 @@ LoopVectorizationCostModel::computeFeasibleMaxVF(bool OptForSize, break; } } + if (unsigned MinVF = TTI.getMinimumVF(SmallestType)) { + if (MaxVF < MinVF) { + LLVM_DEBUG(dbgs() << "LV: Overriding calculated MaxVF(" << MaxVF + << ") with target's minimum: " << MinVF << '\n'); + MaxVF = MinVF; + } + } } return MaxVF; } -LoopVectorizationCostModel::VectorizationFactor +VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor(unsigned MaxVF) { float Cost = expectedCost(1).first; -#ifndef NDEBUG const float ScalarCost = Cost; -#endif /* NDEBUG */ unsigned Width = 1; - DEBUG(dbgs() << "LV: Scalar loop costs: " << (int)ScalarCost << ".\n"); + LLVM_DEBUG(dbgs() << "LV: Scalar loop costs: " << (int)ScalarCost << ".\n"); bool ForceVectorization = Hints->getForce() == LoopVectorizeHints::FK_Enabled; - // Ignore scalar width, because the user explicitly wants vectorization. if (ForceVectorization && MaxVF > 1) { - Width = 2; - Cost = expectedCost(Width).first / (float)Width; + // Ignore scalar width, because the user explicitly wants vectorization. + // Initialize cost to max so that VF = 2 is, at least, chosen during cost + // evaluation. + Cost = std::numeric_limits<float>::max(); } for (unsigned i = 2; i <= MaxVF; i *= 2) { @@ -6395,10 +5040,10 @@ LoopVectorizationCostModel::selectVectorizationFactor(unsigned MaxVF) { // the vector elements. VectorizationCostTy C = expectedCost(i); float VectorCost = C.first / (float)i; - DEBUG(dbgs() << "LV: Vector loop of width " << i - << " costs: " << (int)VectorCost << ".\n"); + LLVM_DEBUG(dbgs() << "LV: Vector loop of width " << i + << " costs: " << (int)VectorCost << ".\n"); if (!C.second && !ForceVectorization) { - DEBUG( + LLVM_DEBUG( dbgs() << "LV: Not considering vector loop of width " << i << " because it will not generate any vector instructions.\n"); continue; @@ -6409,10 +5054,19 @@ LoopVectorizationCostModel::selectVectorizationFactor(unsigned MaxVF) { } } - DEBUG(if (ForceVectorization && Width > 1 && Cost >= ScalarCost) dbgs() - << "LV: Vectorization seems to be not beneficial, " - << "but was forced by a user.\n"); - DEBUG(dbgs() << "LV: Selecting VF: " << Width << ".\n"); + if (!EnableCondStoresVectorization && NumPredStores) { + ORE->emit(createMissedAnalysis("ConditionalStore") + << "store that is conditionally executed prevents vectorization"); + LLVM_DEBUG( + dbgs() << "LV: No vectorization. There are conditional stores.\n"); + Width = 1; + Cost = ScalarCost; + } + + LLVM_DEBUG(if (ForceVectorization && Width > 1 && Cost >= ScalarCost) dbgs() + << "LV: Vectorization seems to be not beneficial, " + << "but was forced by a user.\n"); + LLVM_DEBUG(dbgs() << "LV: Selecting VF: " << Width << ".\n"); VectorizationFactor Factor = {Width, (unsigned)(Width * Cost)}; return Factor; } @@ -6460,7 +5114,7 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() { // optimization to non-pointer types. 
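// Illustrative aside, not part of this patch: a hypothetical mirror of the
// per-lane comparison `C.first / (float)i` in the VF loop above. E.g. with
// scalar cost 8.0, VF=2 cost 10 gives 5.0/lane and VF=4 cost 14 gives
// 3.5/lane, so VF=4 wins against the running minimum.
float costPerLane(unsigned VectorCost, unsigned VF) {
  return VectorCost / static_cast<float>(VF);
}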
// if (T->isPointerTy() && !isConsecutiveLoadOrStore(&I) && - !Legal->isAccessInterleaved(&I) && !Legal->isLegalGatherOrScatter(&I)) + !isAccessInterleaved(&I) && !isLegalGatherOrScatter(&I)) continue; MinWidth = std::min(MinWidth, @@ -6504,8 +5158,8 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, return 1; unsigned TargetNumRegisters = TTI.getNumberOfRegisters(VF > 1); - DEBUG(dbgs() << "LV: The target has " << TargetNumRegisters - << " registers\n"); + LLVM_DEBUG(dbgs() << "LV: The target has " << TargetNumRegisters + << " registers\n"); if (VF == 1) { if (ForceTargetNumScalarRegs.getNumOccurrences() > 0) @@ -6519,7 +5173,6 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, // We divide by these constants so assume that we have at least one // instruction that uses at least one register. R.MaxLocalUsers = std::max(R.MaxLocalUsers, 1U); - R.NumInstructions = std::max(R.NumInstructions, 1U); // We calculate the interleave count using the following formula. // Subtract the number of loop invariants from the number of available @@ -6564,7 +5217,7 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, // Interleave if we vectorized this loop and there is a reduction that could // benefit from interleaving. if (VF > 1 && !Legal->getReductionVars()->empty()) { - DEBUG(dbgs() << "LV: Interleaving because of reductions.\n"); + LLVM_DEBUG(dbgs() << "LV: Interleaving because of reductions.\n"); return IC; } @@ -6575,7 +5228,7 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, // We want to interleave small loops in order to reduce the loop overhead and // potentially expose ILP opportunities. - DEBUG(dbgs() << "LV: Loop cost is " << LoopCost << '\n'); + LLVM_DEBUG(dbgs() << "LV: Loop cost is " << LoopCost << '\n'); if (!InterleavingRequiresRuntimePointerCheck && LoopCost < SmallLoopCost) { // We assume that the cost overhead is 1 and we use the cost model // to estimate the cost of the loop and interleave until the cost of the @@ -6603,11 +5256,12 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, if (EnableLoadStoreRuntimeInterleave && std::max(StoresIC, LoadsIC) > SmallIC) { - DEBUG(dbgs() << "LV: Interleaving to saturate store or load ports.\n"); + LLVM_DEBUG( + dbgs() << "LV: Interleaving to saturate store or load ports.\n"); return std::max(StoresIC, LoadsIC); } - DEBUG(dbgs() << "LV: Interleaving to reduce branch cost.\n"); + LLVM_DEBUG(dbgs() << "LV: Interleaving to reduce branch cost.\n"); return SmallIC; } @@ -6615,11 +5269,11 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, // this point) that could benefit from interleaving. bool HasReductions = !Legal->getReductionVars()->empty(); if (TTI.enableAggressiveInterleaving(HasReductions)) { - DEBUG(dbgs() << "LV: Interleaving to expose ILP.\n"); + LLVM_DEBUG(dbgs() << "LV: Interleaving to expose ILP.\n"); return IC; } - DEBUG(dbgs() << "LV: Not Interleaving.\n"); + LLVM_DEBUG(dbgs() << "LV: Not Interleaving.\n"); return 1; } @@ -6646,7 +5300,6 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { DFS.perform(LI); RegisterUsage RU; - RU.NumInstructions = 0; // Each 'key' in the map opens a new interval. The values // of the map are the index of the 'last seen' usage of the @@ -6658,14 +5311,13 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { // Marks the end of each interval. 
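// Illustrative aside, not part of this patch: a simplified sketch of the
// interleave-count formula described above. The real code also rounds to a
// power of two and applies target-specific bounds. Example: 16 registers,
// 2 loop invariants, 5 simultaneously live values -> IC = (16 - 2) / 5 = 2.
unsigned interleaveCount(unsigned TargetRegs, unsigned InvariantRegs,
                         unsigned MaxLocalUsers) {
  if (MaxLocalUsers == 0)
    MaxLocalUsers = 1;  // mirrors the std::max(R.MaxLocalUsers, 1U) clamp
  unsigned IC = (TargetRegs - InvariantRegs) / MaxLocalUsers; // assumes
  return IC ? IC : 1;   // InvariantRegs <= TargetRegs
}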
IntervalMap EndPoint; // Saves the list of instruction indices that are used in the loop. - SmallSet<Instruction *, 8> Ends; + SmallPtrSet<Instruction *, 8> Ends; // Saves the list of values that are used in the loop but are // defined outside the loop, such as arguments and constants. SmallPtrSet<Value *, 8> LoopInvariants; unsigned Index = 0; for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) { - RU.NumInstructions += BB->size(); for (Instruction &I : *BB) { IdxToInstr[Index++] = &I; @@ -6698,7 +5350,7 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { for (auto &Interval : EndPoint) TransposeEnds[Interval.second].push_back(Interval.first); - SmallSet<Instruction *, 8> OpenIntervals; + SmallPtrSet<Instruction *, 8> OpenIntervals; // Get the size of the widest register. unsigned MaxSafeDepDist = -1U; @@ -6711,7 +5363,7 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { SmallVector<RegisterUsage, 8> RUs(VFs.size()); SmallVector<unsigned, 8> MaxUsages(VFs.size(), 0); - DEBUG(dbgs() << "LV(REG): Calculating max register usage:\n"); + LLVM_DEBUG(dbgs() << "LV(REG): Calculating max register usage:\n"); // A lambda that gets the register usage for the given type and VF. auto GetRegUsage = [&DL, WidestRegister](Type *Ty, unsigned VF) { @@ -6756,8 +5408,8 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { MaxUsages[j] = std::max(MaxUsages[j], RegUsage); } - DEBUG(dbgs() << "LV(REG): At #" << i << " Interval # " - << OpenIntervals.size() << '\n'); + LLVM_DEBUG(dbgs() << "LV(REG): At #" << i << " Interval # " + << OpenIntervals.size() << '\n'); // Add the current instruction to the list of open intervals. OpenIntervals.insert(I); @@ -6772,10 +5424,10 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { Invariant += GetRegUsage(Inst->getType(), VFs[i]); } - DEBUG(dbgs() << "LV(REG): VF = " << VFs[i] << '\n'); - DEBUG(dbgs() << "LV(REG): Found max usage: " << MaxUsages[i] << '\n'); - DEBUG(dbgs() << "LV(REG): Found invariant usage: " << Invariant << '\n'); - DEBUG(dbgs() << "LV(REG): LoopSize: " << RU.NumInstructions << '\n'); + LLVM_DEBUG(dbgs() << "LV(REG): VF = " << VFs[i] << '\n'); + LLVM_DEBUG(dbgs() << "LV(REG): Found max usage: " << MaxUsages[i] << '\n'); + LLVM_DEBUG(dbgs() << "LV(REG): Found invariant usage: " << Invariant + << '\n'); RU.LoopInvariantRegs = Invariant; RU.MaxLocalUsers = MaxUsages[i]; @@ -6785,6 +5437,22 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { return RUs; } +bool LoopVectorizationCostModel::useEmulatedMaskMemRefHack(Instruction *I){ + // TODO: Cost model for emulated masked load/store is completely + // broken. This hack guides the cost model to use an artificially + // high enough value to practically disable vectorization with such + // operations, except where previously deployed legality hack allowed + // using very low cost values. This is to avoid regressions coming simply + // from moving "masked load/store" check from legality to cost model. + // Masked Load/Gather emulation was previously never allowed. + // Limited number of Masked Store/Scatter emulation was allowed. 
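// Illustrative aside, not part of this patch: the interval bookkeeping in
// calculateRegisterUsage above reduces to counting how many values are live
// at once. A self-contained, hypothetical helper (quadratic for clarity,
// where the real code sweeps instruction indices in order):
#include <algorithm>
#include <utility>
#include <vector>

// Maximum number of simultaneously open [Start, End) intervals -- the
// quantity tracked above as OpenIntervals.size() at each index.
unsigned maxOpenIntervals(const std::vector<std::pair<unsigned, unsigned>> &Iv) {
  unsigned Best = 0;
  for (const auto &A : Iv) {
    unsigned Open = 0;
    for (const auto &B : Iv)
      if (B.first <= A.first && A.first < B.second)
        ++Open;                   // B is live at A's start point
    Best = std::max(Best, Open);
  }
  return Best;
}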
+ assert(isScalarWithPredication(I) && + "Expecting a scalar emulated instruction"); + return isa<LoadInst>(I) || + (isa<StoreInst>(I) && + NumPredStores > NumberOfStoresToPredicate); +} + void LoopVectorizationCostModel::collectInstsToScalarize(unsigned VF) { // If we aren't vectorizing the loop, or if we've already collected the // instructions to scalarize, there's nothing to do. Collection may already @@ -6805,11 +5473,13 @@ void LoopVectorizationCostModel::collectInstsToScalarize(unsigned VF) { if (!Legal->blockNeedsPredication(BB)) continue; for (Instruction &I : *BB) - if (Legal->isScalarWithPredication(&I)) { + if (isScalarWithPredication(&I)) { ScalarCostsTy ScalarCosts; - if (computePredInstDiscount(&I, ScalarCosts, VF) >= 0) + // Do not apply discount logic if hacked cost is needed + // for emulated masked memrefs. + if (!useEmulatedMaskMemRefHack(&I) && + computePredInstDiscount(&I, ScalarCosts, VF) >= 0) ScalarCostsVF.insert(ScalarCosts.begin(), ScalarCosts.end()); - // Remember that BB will remain after vectorization. PredicatedBBsAfterVectorization.insert(BB); } @@ -6844,7 +5514,7 @@ int LoopVectorizationCostModel::computePredInstDiscount( // If the instruction is scalar with predication, it will be analyzed // separately. We ignore it within the context of PredInst. - if (Legal->isScalarWithPredication(I)) + if (isScalarWithPredication(I)) return false; // If any of the instruction's operands are uniform after vectorization, @@ -6898,7 +5568,7 @@ int LoopVectorizationCostModel::computePredInstDiscount( // Compute the scalarization overhead of needed insertelement instructions // and phi nodes. - if (Legal->isScalarWithPredication(I) && !I->getType()->isVoidTy()) { + if (isScalarWithPredication(I) && !I->getType()->isVoidTy()) { ScalarCost += TTI.getScalarizationOverhead(ToVectorTy(I->getType(), VF), true, false); ScalarCost += VF * TTI.getCFInstrCost(Instruction::PHI); @@ -6940,11 +5610,7 @@ LoopVectorizationCostModel::expectedCost(unsigned VF) { VectorizationCostTy BlockCost; // For each instruction in the old loop. - for (Instruction &I : *BB) { - // Skip dbg intrinsics. - if (isa<DbgInfoIntrinsic>(I)) - continue; - + for (Instruction &I : BB->instructionsWithoutDebug()) { // Skip ignored values. if (ValuesToIgnore.count(&I) || (VF > 1 && VecValuesToIgnore.count(&I))) @@ -6958,8 +5624,9 @@ LoopVectorizationCostModel::expectedCost(unsigned VF) { BlockCost.first += C.first; BlockCost.second |= C.second; - DEBUG(dbgs() << "LV: Found an estimated cost of " << C.first << " for VF " - << VF << " For instruction: " << I << '\n'); + LLVM_DEBUG(dbgs() << "LV: Found an estimated cost of " << C.first + << " for VF " << VF << " For instruction: " << I + << '\n'); } // If we are vectorizing a predicated block, it will have been @@ -6978,7 +5645,7 @@ LoopVectorizationCostModel::expectedCost(unsigned VF) { return Cost; } -/// \brief Gets Address Access SCEV after verifying that the access pattern +/// Gets Address Access SCEV after verifying that the access pattern /// is loop invariant except the induction variable dependence. 
/// /// This SCEV can be sent to the Target in order to estimate the address @@ -7020,7 +5687,7 @@ unsigned LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, unsigned Alignment = getMemInstAlignment(I); unsigned AS = getMemInstAddressSpace(I); - Value *Ptr = getPointerOperand(I); + Value *Ptr = getLoadStorePointerOperand(I); Type *PtrTy = ToVectorTy(Ptr->getType(), VF); // Figure out whether the access is strided and get the stride value @@ -7041,9 +5708,15 @@ unsigned LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, // If we have a predicated store, it may not be executed for each vector // lane. Scale the cost by the probability of executing the predicated // block. - if (Legal->isScalarWithPredication(I)) + if (isScalarWithPredication(I)) { Cost /= getReciprocalPredBlockProb(); + if (useEmulatedMaskMemRefHack(I)) + // Artificially setting to a high enough value to practically disable + // vectorization with such operations. + Cost = 3000000; + } + return Cost; } @@ -7052,7 +5725,7 @@ unsigned LoopVectorizationCostModel::getConsecutiveMemOpCost(Instruction *I, Type *ValTy = getMemInstValueType(I); Type *VectorTy = ToVectorTy(ValTy, VF); unsigned Alignment = getMemInstAlignment(I); - Value *Ptr = getPointerOperand(I); + Value *Ptr = getLoadStorePointerOperand(I); unsigned AS = getMemInstAddressSpace(I); int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); @@ -7088,7 +5761,7 @@ unsigned LoopVectorizationCostModel::getGatherScatterCost(Instruction *I, Type *ValTy = getMemInstValueType(I); Type *VectorTy = ToVectorTy(ValTy, VF); unsigned Alignment = getMemInstAlignment(I); - Value *Ptr = getPointerOperand(I); + Value *Ptr = getLoadStorePointerOperand(I); return TTI.getAddressComputationCost(VectorTy) + TTI.getGatherScatterOpCost(I->getOpcode(), VectorTy, Ptr, @@ -7101,7 +5774,7 @@ unsigned LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, Type *VectorTy = ToVectorTy(ValTy, VF); unsigned AS = getMemInstAddressSpace(I); - auto Group = Legal->getInterleavedAccessGroup(I); + auto Group = getInterleavedAccessGroup(I); assert(Group && "Fail to get an interleaved access group."); unsigned InterleaveFactor = Group->getFactor(); @@ -7168,13 +5841,16 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { if (VF == 1) return; + NumPredStores = 0; for (BasicBlock *BB : TheLoop->blocks()) { // For each instruction in the old loop. for (Instruction &I : *BB) { - Value *Ptr = getPointerOperand(&I); + Value *Ptr = getLoadStorePointerOperand(&I); if (!Ptr) continue; + if (isa<StoreInst>(&I) && isScalarWithPredication(&I)) + NumPredStores++; if (isa<LoadInst>(&I) && Legal->isUniform(Ptr)) { // Scalar load + broadcast unsigned Cost = getUniformMemOpCost(&I, VF); @@ -7183,9 +5859,10 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { } // We assume that widening is the best solution when possible. 
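// Illustrative aside, not part of this patch: the scaling
// `Cost /= getReciprocalPredBlockProb()` above bills a predicated block as
// executing on only a fraction of iterations. The probability here is an
// assumption for illustration, not the patch's constant.
unsigned predicatedCost(unsigned ScalarizedCost) {
  const unsigned ReciprocalPredBlockProb = 2;  // assumed probability 1/2
  return ScalarizedCost / ReciprocalPredBlockProb; // e.g. 12 is billed as 6
}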
- if (Legal->memoryInstructionCanBeWidened(&I, VF)) { + if (memoryInstructionCanBeWidened(&I, VF)) { unsigned Cost = getConsecutiveMemOpCost(&I, VF); - int ConsecutiveStride = Legal->isConsecutivePtr(getPointerOperand(&I)); + int ConsecutiveStride = + Legal->isConsecutivePtr(getLoadStorePointerOperand(&I)); assert((ConsecutiveStride == 1 || ConsecutiveStride == -1) && "Expected consecutive stride."); InstWidening Decision = @@ -7197,8 +5874,8 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { // Choose between Interleaving, Gather/Scatter or Scalarization. unsigned InterleaveCost = std::numeric_limits<unsigned>::max(); unsigned NumAccesses = 1; - if (Legal->isAccessInterleaved(&I)) { - auto Group = Legal->getInterleavedAccessGroup(&I); + if (isAccessInterleaved(&I)) { + auto Group = getInterleavedAccessGroup(&I); assert(Group && "Fail to get an interleaved access group."); // Make one decision for the whole group. @@ -7210,7 +5887,7 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { } unsigned GatherScatterCost = - Legal->isLegalGatherOrScatter(&I) + isLegalGatherOrScatter(&I) ? getGatherScatterCost(&I, VF) * NumAccesses : std::numeric_limits<unsigned>::max(); @@ -7235,7 +5912,7 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { // If the instructions belongs to an interleave group, the whole group // receives the same decision. The whole group receives the cost, but // the cost will actually be assigned to one instruction. - if (auto Group = Legal->getInterleavedAccessGroup(&I)) + if (auto Group = getInterleavedAccessGroup(&I)) setWideningDecision(Group, VF, Decision, Cost); else setWideningDecision(&I, VF, Decision, Cost); @@ -7255,7 +5932,7 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { for (BasicBlock *BB : TheLoop->blocks()) for (Instruction &I : *BB) { Instruction *PtrDef = - dyn_cast_or_null<Instruction>(getPointerOperand(&I)); + dyn_cast_or_null<Instruction>(getLoadStorePointerOperand(&I)); if (PtrDef && TheLoop->contains(PtrDef) && getWideningDecision(&I, VF) != CM_GatherScatter) AddrDefs.insert(PtrDef); @@ -7285,7 +5962,7 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { // Scalarize a widened load of address. setWideningDecision(I, VF, CM_Scalarize, (VF * getMemoryInstructionCost(I, 1))); - else if (auto Group = Legal->getInterleavedAccessGroup(I)) { + else if (auto Group = getInterleavedAccessGroup(I)) { // Scalarize an interleave group of address loads. for (unsigned I = 0; I < Group->getFactor(); ++I) { if (Instruction *Member = Group->getMember(I)) @@ -7371,7 +6048,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, // vector lane. Get the scalarization cost and scale this amount by the // probability of executing the predicated block. If the instruction is not // predicated, we fall through to the next case. - if (VF > 1 && Legal->isScalarWithPredication(I)) { + if (VF > 1 && isScalarWithPredication(I)) { unsigned Cost = 0; // These instructions have a non-void type, so account for the phi nodes @@ -7569,7 +6246,7 @@ Pass *createLoopVectorizePass(bool NoUnrolling, bool AlwaysVectorize) { bool LoopVectorizationCostModel::isConsecutiveLoadOrStore(Instruction *Inst) { // Check if the pointer operand of a load or store instruction is // consecutive. 
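// Illustrative aside, not part of this patch: once widening is ruled out,
// the decision above is simply the cheapest remaining candidate. A
// hypothetical distillation of that comparison:
#include <algorithm>

enum class WideningDecision { Interleave, GatherScatter, Scalarize };

WideningDecision pickDecision(unsigned InterleaveCost, unsigned GSCost,
                              unsigned ScalarizationCost) {
  unsigned Best = std::min({InterleaveCost, GSCost, ScalarizationCost});
  if (Best == InterleaveCost)
    return WideningDecision::Interleave;
  return Best == GSCost ? WideningDecision::GatherScatter
                        : WideningDecision::Scalarize;
}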
- if (auto *Ptr = getPointerOperand(Inst)) + if (auto *Ptr = getLoadStorePointerOperand(Inst)) return Legal->isConsecutivePtr(Ptr); return false; } @@ -7594,23 +6271,59 @@ void LoopVectorizationCostModel::collectValuesToIgnore() { } } -LoopVectorizationCostModel::VectorizationFactor +VectorizationFactor +LoopVectorizationPlanner::planInVPlanNativePath(bool OptForSize, + unsigned UserVF) { + // Width 1 means no vectorization, cost 0 means uncomputed cost. + const VectorizationFactor NoVectorization = {1U, 0U}; + + // Outer loop handling: They may require CFG and instruction level + // transformations before even evaluating whether vectorization is profitable. + // Since we cannot modify the incoming IR, we need to build VPlan upfront in + // the vectorization pipeline. + if (!OrigLoop->empty()) { + // TODO: If UserVF is not provided, we set UserVF to 4 for stress testing. + // This won't be necessary when UserVF is not required in the VPlan-native + // path. + if (VPlanBuildStressTest && !UserVF) + UserVF = 4; + + assert(EnableVPlanNativePath && "VPlan-native path is not enabled."); + assert(UserVF && "Expected UserVF for outer loop vectorization."); + assert(isPowerOf2_32(UserVF) && "VF needs to be a power of two"); + LLVM_DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n"); + buildVPlans(UserVF, UserVF); + + // For VPlan build stress testing, we bail out after VPlan construction. + if (VPlanBuildStressTest) + return NoVectorization; + + return {UserVF, 0}; + } + + LLVM_DEBUG( + dbgs() << "LV: Not vectorizing. Inner loops aren't supported in the " + "VPlan-native path.\n"); + return NoVectorization; +} + +VectorizationFactor LoopVectorizationPlanner::plan(bool OptForSize, unsigned UserVF) { - // Width 1 means no vectorize, cost 0 means uncomputed cost. - const LoopVectorizationCostModel::VectorizationFactor NoVectorization = {1U, - 0U}; + assert(OrigLoop->empty() && "Inner loop expected."); + // Width 1 means no vectorization, cost 0 means uncomputed cost. + const VectorizationFactor NoVectorization = {1U, 0U}; Optional<unsigned> MaybeMaxVF = CM.computeMaxVF(OptForSize); if (!MaybeMaxVF.hasValue()) // Cases considered too costly to vectorize. return NoVectorization; if (UserVF) { - DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n"); + LLVM_DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n"); assert(isPowerOf2_32(UserVF) && "VF needs to be a power of two"); // Collect the instructions (and their associated costs) that will be more // profitable to scalarize. CM.selectUserVectorizationFactor(UserVF); - buildVPlans(UserVF, UserVF); - DEBUG(printPlans(dbgs())); + buildVPlansWithVPRecipes(UserVF, UserVF); + LLVM_DEBUG(printPlans(dbgs())); return {UserVF, 0}; } @@ -7627,8 +6340,8 @@ LoopVectorizationPlanner::plan(bool OptForSize, unsigned UserVF) { CM.collectInstsToScalarize(VF); } - buildVPlans(1, MaxVF); - DEBUG(printPlans(dbgs())); + buildVPlansWithVPRecipes(1, MaxVF); + LLVM_DEBUG(printPlans(dbgs())); if (MaxVF == 1) return NoVectorization; @@ -7637,7 +6350,8 @@ LoopVectorizationPlanner::plan(bool OptForSize, unsigned UserVF) { } void LoopVectorizationPlanner::setBestPlan(unsigned VF, unsigned UF) { - DEBUG(dbgs() << "Setting best plan to VF=" << VF << ", UF=" << UF << '\n'); + LLVM_DEBUG(dbgs() << "Setting best plan to VF=" << VF << ", UF=" << UF + << '\n'); BestVF = VF; BestUF = UF; @@ -7787,30 +6501,15 @@ bool LoopVectorizationPlanner::getDecisionAndClampRange( /// vectorization decision can potentially shorten this sub-range during /// buildVPlan(). 
void LoopVectorizationPlanner::buildVPlans(unsigned MinVF, unsigned MaxVF) { - - // Collect conditions feeding internal conditional branches; they need to be - // represented in VPlan for it to model masking. - SmallPtrSet<Value *, 1> NeedDef; - - auto *Latch = OrigLoop->getLoopLatch(); - for (BasicBlock *BB : OrigLoop->blocks()) { - if (BB == Latch) - continue; - BranchInst *Branch = dyn_cast<BranchInst>(BB->getTerminator()); - if (Branch && Branch->isConditional()) - NeedDef.insert(Branch->getCondition()); - } - for (unsigned VF = MinVF; VF < MaxVF + 1;) { VFRange SubRange = {VF, MaxVF + 1}; - VPlans.push_back(buildVPlan(SubRange, NeedDef)); + VPlans.push_back(buildVPlan(SubRange)); VF = SubRange.End; } } -VPValue *LoopVectorizationPlanner::createEdgeMask(BasicBlock *Src, - BasicBlock *Dst, - VPlanPtr &Plan) { +VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst, + VPlanPtr &Plan) { assert(is_contained(predecessors(Dst), Src) && "Invalid edge"); // Look for cached value. @@ -7840,8 +6539,7 @@ VPValue *LoopVectorizationPlanner::createEdgeMask(BasicBlock *Src, return EdgeMaskCache[Edge] = EdgeMask; } -VPValue *LoopVectorizationPlanner::createBlockInMask(BasicBlock *BB, - VPlanPtr &Plan) { +VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlanPtr &Plan) { assert(OrigLoop->contains(BB) && "Block is not a part of a loop"); // Look for cached value. @@ -7874,10 +6572,9 @@ VPValue *LoopVectorizationPlanner::createBlockInMask(BasicBlock *BB, return BlockMaskCache[BB] = BlockMask; } -VPInterleaveRecipe * -LoopVectorizationPlanner::tryToInterleaveMemory(Instruction *I, - VFRange &Range) { - const InterleaveGroup *IG = Legal->getInterleavedAccessGroup(I); +VPInterleaveRecipe *VPRecipeBuilder::tryToInterleaveMemory(Instruction *I, + VFRange &Range) { + const InterleaveGroup *IG = CM.getInterleavedAccessGroup(I); if (!IG) return nullptr; @@ -7889,7 +6586,7 @@ LoopVectorizationPlanner::tryToInterleaveMemory(Instruction *I, LoopVectorizationCostModel::CM_Interleave); }; }; - if (!getDecisionAndClampRange(isIGMember(I), Range)) + if (!LoopVectorizationPlanner::getDecisionAndClampRange(isIGMember(I), Range)) return nullptr; // I is a member of an InterleaveGroup for VF's in the (possibly trimmed) @@ -7902,8 +6599,8 @@ LoopVectorizationPlanner::tryToInterleaveMemory(Instruction *I, } VPWidenMemoryInstructionRecipe * -LoopVectorizationPlanner::tryToWidenMemory(Instruction *I, VFRange &Range, - VPlanPtr &Plan) { +VPRecipeBuilder::tryToWidenMemory(Instruction *I, VFRange &Range, + VPlanPtr &Plan) { if (!isa<LoadInst>(I) && !isa<StoreInst>(I)) return nullptr; @@ -7922,7 +6619,7 @@ LoopVectorizationPlanner::tryToWidenMemory(Instruction *I, VFRange &Range, return Decision != LoopVectorizationCostModel::CM_Scalarize; }; - if (!getDecisionAndClampRange(willWiden, Range)) + if (!LoopVectorizationPlanner::getDecisionAndClampRange(willWiden, Range)) return nullptr; VPValue *Mask = nullptr; @@ -7933,8 +6630,7 @@ LoopVectorizationPlanner::tryToWidenMemory(Instruction *I, VFRange &Range, } VPWidenIntOrFpInductionRecipe * -LoopVectorizationPlanner::tryToOptimizeInduction(Instruction *I, - VFRange &Range) { +VPRecipeBuilder::tryToOptimizeInduction(Instruction *I, VFRange &Range) { if (PHINode *Phi = dyn_cast<PHINode>(I)) { // Check if this is an integer or fp induction. If so, build the recipe that // produces its scalar and vector values. 
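// Illustrative aside, not part of this patch: per lane, the masks built by
// createEdgeMask/createBlockInMask above compose as boolean algebra. A
// single-lane scalar sketch with hypothetical helpers:
bool edgeMask(bool SrcBlockMask, bool BranchCond, bool IsTrueEdge) {
  // Source block's mask AND the branch condition, negated on the false edge.
  return SrcBlockMask && (IsTrueEdge ? BranchCond : !BranchCond);
}
bool blockMask(bool EdgeMask0, bool EdgeMask1) {
  return EdgeMask0 || EdgeMask1;  // OR over all incoming edge masks
}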
@@ -7959,15 +6655,14 @@ LoopVectorizationPlanner::tryToOptimizeInduction(Instruction *I, [=](unsigned VF) -> bool { return CM.isOptimizableIVTruncate(K, VF); }; }; - if (isa<TruncInst>(I) && - getDecisionAndClampRange(isOptimizableIVTruncate(I), Range)) + if (isa<TruncInst>(I) && LoopVectorizationPlanner::getDecisionAndClampRange( + isOptimizableIVTruncate(I), Range)) return new VPWidenIntOrFpInductionRecipe(cast<PHINode>(I->getOperand(0)), cast<TruncInst>(I)); return nullptr; } -VPBlendRecipe * -LoopVectorizationPlanner::tryToBlend(Instruction *I, VPlanPtr &Plan) { +VPBlendRecipe *VPRecipeBuilder::tryToBlend(Instruction *I, VPlanPtr &Plan) { PHINode *Phi = dyn_cast<PHINode>(I); if (!Phi || Phi->getParent() == OrigLoop->getHeader()) return nullptr; @@ -7991,9 +6686,9 @@ LoopVectorizationPlanner::tryToBlend(Instruction *I, VPlanPtr &Plan) { return new VPBlendRecipe(Phi, Masks); } -bool LoopVectorizationPlanner::tryToWiden(Instruction *I, VPBasicBlock *VPBB, - VFRange &Range) { - if (Legal->isScalarWithPredication(I)) +bool VPRecipeBuilder::tryToWiden(Instruction *I, VPBasicBlock *VPBB, + VFRange &Range) { + if (CM.isScalarWithPredication(I)) return false; auto IsVectorizableOpcode = [](unsigned Opcode) { @@ -8077,7 +6772,7 @@ bool LoopVectorizationPlanner::tryToWiden(Instruction *I, VPBasicBlock *VPBB, return true; }; - if (!getDecisionAndClampRange(willWiden, Range)) + if (!LoopVectorizationPlanner::getDecisionAndClampRange(willWiden, Range)) return false; // Success: widen this instruction. We optimize the common case where @@ -8092,15 +6787,15 @@ bool LoopVectorizationPlanner::tryToWiden(Instruction *I, VPBasicBlock *VPBB, return true; } -VPBasicBlock *LoopVectorizationPlanner::handleReplication( +VPBasicBlock *VPRecipeBuilder::handleReplication( Instruction *I, VFRange &Range, VPBasicBlock *VPBB, DenseMap<Instruction *, VPReplicateRecipe *> &PredInst2Recipe, VPlanPtr &Plan) { - bool IsUniform = getDecisionAndClampRange( + bool IsUniform = LoopVectorizationPlanner::getDecisionAndClampRange( [&](unsigned VF) { return CM.isUniformAfterVectorization(I, VF); }, Range); - bool IsPredicated = Legal->isScalarWithPredication(I); + bool IsPredicated = CM.isScalarWithPredication(I); auto *Recipe = new VPReplicateRecipe(I, IsUniform, IsPredicated); // Find if I uses a predicated instruction. If so, it will use its scalar @@ -8113,24 +6808,25 @@ VPBasicBlock *LoopVectorizationPlanner::handleReplication( // Finalize the recipe for Instr, first if it is not predicated. if (!IsPredicated) { - DEBUG(dbgs() << "LV: Scalarizing:" << *I << "\n"); + LLVM_DEBUG(dbgs() << "LV: Scalarizing:" << *I << "\n"); VPBB->appendRecipe(Recipe); return VPBB; } - DEBUG(dbgs() << "LV: Scalarizing and predicating:" << *I << "\n"); + LLVM_DEBUG(dbgs() << "LV: Scalarizing and predicating:" << *I << "\n"); assert(VPBB->getSuccessors().empty() && "VPBB has successors when handling predicated replication."); // Record predicated instructions for above packing optimizations. 
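// Illustrative aside, not part of this patch: a VPBlendRecipe created by
// tryToBlend above ultimately lowers a non-header PHI to a chain of selects
// over the incoming values' masks. A single-lane sketch:
int blendTwo(bool M1, int V0, int V1) { return M1 ? V1 : V0; }
int blendThree(bool M1, bool M2, int V0, int V1, int V2) {
  return M2 ? V2 : (M1 ? V1 : V0);  // three incoming values, two selects
}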
PredInst2Recipe[I] = Recipe; - VPBlockBase *Region = - VPBB->setOneSuccessor(createReplicateRegion(I, Recipe, Plan)); - return cast<VPBasicBlock>(Region->setOneSuccessor(new VPBasicBlock())); + VPBlockBase *Region = createReplicateRegion(I, Recipe, Plan); + VPBlockUtils::insertBlockAfter(Region, VPBB); + auto *RegSucc = new VPBasicBlock(); + VPBlockUtils::insertBlockAfter(RegSucc, Region); + return RegSucc; } -VPRegionBlock * -LoopVectorizationPlanner::createReplicateRegion(Instruction *Instr, - VPRecipeBase *PredRecipe, - VPlanPtr &Plan) { +VPRegionBlock *VPRecipeBuilder::createReplicateRegion(Instruction *Instr, + VPRecipeBase *PredRecipe, + VPlanPtr &Plan) { // Instructions marked for predication are replicated and placed under an // if-then construct to prevent side-effects. @@ -8150,19 +6846,67 @@ LoopVectorizationPlanner::createReplicateRegion(Instruction *Instr, // Note: first set Entry as region entry and then connect successors starting // from it in order, to propagate the "parent" of each VPBasicBlock. - Entry->setTwoSuccessors(Pred, Exit); - Pred->setOneSuccessor(Exit); + VPBlockUtils::insertTwoBlocksAfter(Pred, Exit, BlockInMask, Entry); + VPBlockUtils::connectBlocks(Pred, Exit); return Region; } -LoopVectorizationPlanner::VPlanPtr -LoopVectorizationPlanner::buildVPlan(VFRange &Range, - const SmallPtrSetImpl<Value *> &NeedDef) { - EdgeMaskCache.clear(); - BlockMaskCache.clear(); - DenseMap<Instruction *, Instruction *> &SinkAfter = Legal->getSinkAfter(); - DenseMap<Instruction *, Instruction *> SinkAfterInverse; +bool VPRecipeBuilder::tryToCreateRecipe(Instruction *Instr, VFRange &Range, + VPlanPtr &Plan, VPBasicBlock *VPBB) { + VPRecipeBase *Recipe = nullptr; + // Check if Instr should belong to an interleave memory recipe, or already + // does. In the latter case Instr is irrelevant. + if ((Recipe = tryToInterleaveMemory(Instr, Range))) { + VPBB->appendRecipe(Recipe); + return true; + } + + // Check if Instr is a memory operation that should be widened. + if ((Recipe = tryToWidenMemory(Instr, Range, Plan))) { + VPBB->appendRecipe(Recipe); + return true; + } + + // Check if Instr should form some PHI recipe. + if ((Recipe = tryToOptimizeInduction(Instr, Range))) { + VPBB->appendRecipe(Recipe); + return true; + } + if ((Recipe = tryToBlend(Instr, Plan))) { + VPBB->appendRecipe(Recipe); + return true; + } + if (PHINode *Phi = dyn_cast<PHINode>(Instr)) { + VPBB->appendRecipe(new VPWidenPHIRecipe(Phi)); + return true; + } + + // Check if Instr is to be widened by a general VPWidenRecipe, after + // having first checked for specific widening recipes that deal with + // Interleave Groups, Inductions and Phi nodes. + if (tryToWiden(Instr, VPBB, Range)) + return true; + + return false; +} + +void LoopVectorizationPlanner::buildVPlansWithVPRecipes(unsigned MinVF, + unsigned MaxVF) { + assert(OrigLoop->empty() && "Inner loop expected."); + + // Collect conditions feeding internal conditional branches; they need to be + // represented in VPlan for it to model masking. + SmallPtrSet<Value *, 1> NeedDef; + + auto *Latch = OrigLoop->getLoopLatch(); + for (BasicBlock *BB : OrigLoop->blocks()) { + if (BB == Latch) + continue; + BranchInst *Branch = dyn_cast<BranchInst>(BB->getTerminator()); + if (Branch && Branch->isConditional()) + NeedDef.insert(Branch->getCondition()); + } // Collect instructions from the original loop that will become trivially dead // in the vectorized loop. We don't need to vectorize these instructions. 
For @@ -8173,15 +6917,31 @@ LoopVectorizationPlanner::buildVPlan(VFRange &Range, SmallPtrSet<Instruction *, 4> DeadInstructions; collectTriviallyDeadInstructions(DeadInstructions); + for (unsigned VF = MinVF; VF < MaxVF + 1;) { + VFRange SubRange = {VF, MaxVF + 1}; + VPlans.push_back( + buildVPlanWithVPRecipes(SubRange, NeedDef, DeadInstructions)); + VF = SubRange.End; + } +} + +LoopVectorizationPlanner::VPlanPtr +LoopVectorizationPlanner::buildVPlanWithVPRecipes( + VFRange &Range, SmallPtrSetImpl<Value *> &NeedDef, + SmallPtrSetImpl<Instruction *> &DeadInstructions) { // Hold a mapping from predicated instructions to their recipes, in order to // fix their AlsoPack behavior if a user is determined to replicate and use a // scalar instead of vector value. DenseMap<Instruction *, VPReplicateRecipe *> PredInst2Recipe; + DenseMap<Instruction *, Instruction *> &SinkAfter = Legal->getSinkAfter(); + DenseMap<Instruction *, Instruction *> SinkAfterInverse; + // Create a dummy pre-entry VPBasicBlock to start building the VPlan. VPBasicBlock *VPBB = new VPBasicBlock("Pre-Entry"); auto Plan = llvm::make_unique<VPlan>(VPBB); + VPRecipeBuilder RecipeBuilder(OrigLoop, TLI, TTI, Legal, CM, Builder); // Represent values that will have defs inside VPlan. for (Value *V : NeedDef) Plan->addVPValue(V); @@ -8196,7 +6956,7 @@ LoopVectorizationPlanner::buildVPlan(VFRange &Range, // ingredients and fill a new VPBasicBlock. unsigned VPBBsForBB = 0; auto *FirstVPBBForBB = new VPBasicBlock(BB->getName()); - VPBB->setOneSuccessor(FirstVPBBForBB); + VPBlockUtils::insertBlockAfter(FirstVPBBForBB, VPBB); VPBB = FirstVPBBForBB; Builder.setInsertPoint(VPBB); @@ -8204,18 +6964,17 @@ LoopVectorizationPlanner::buildVPlan(VFRange &Range, // Organize the ingredients to vectorize from current basic block in the // right order. - for (Instruction &I : *BB) { + for (Instruction &I : BB->instructionsWithoutDebug()) { Instruction *Instr = &I; // First filter out irrelevant instructions, to ensure no recipes are // built for them. - if (isa<BranchInst>(Instr) || isa<DbgInfoIntrinsic>(Instr) || - DeadInstructions.count(Instr)) + if (isa<BranchInst>(Instr) || DeadInstructions.count(Instr)) continue; // I is a member of an InterleaveGroup for Range.Start. If it's an adjunct // member of the IG, do not construct any Recipe for it. - const InterleaveGroup *IG = Legal->getInterleavedAccessGroup(Instr); + const InterleaveGroup *IG = CM.getInterleavedAccessGroup(Instr); if (IG && Instr != IG->getInsertPos() && Range.Start >= 2 && // Query is illegal for VF == 1 CM.getWideningDecision(Instr, Range.Start) == @@ -8230,8 +6989,9 @@ LoopVectorizationPlanner::buildVPlan(VFRange &Range, // should follow. auto SAIt = SinkAfter.find(Instr); if (SAIt != SinkAfter.end()) { - DEBUG(dbgs() << "Sinking" << *SAIt->first << " after" << *SAIt->second - << " to vectorize a 1st order recurrence.\n"); + LLVM_DEBUG(dbgs() << "Sinking" << *SAIt->first << " after" + << *SAIt->second + << " to vectorize a 1st order recurrence.\n"); SinkAfterInverse[SAIt->second] = Instr; continue; } @@ -8247,45 +7007,13 @@ LoopVectorizationPlanner::buildVPlan(VFRange &Range, // Introduce each ingredient into VPlan. for (Instruction *Instr : Ingredients) { - VPRecipeBase *Recipe = nullptr; - - // Check if Instr should belong to an interleave memory recipe, or already - // does. In the latter case Instr is irrelevant. 
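The SinkAfter map consulted above records instructions that must move below another instruction so a first-order recurrence can be vectorized; SinkAfterInverse lets the traversal emit the sunk instruction immediately after its destination. A simplified model of that reordering, using ints in place of Instruction pointers and assuming at most one instruction is sunk after any given destination:

#include <map>
#include <vector>

// Given instructions in program order and a SinkAfter map (sunk instruction
// -> instruction it must follow), emit each sunk instruction immediately
// after its destination instead of at its original position.
static std::vector<int>
orderWithSinkAfter(const std::vector<int> &Insts,
                   const std::map<int, int> &SinkAfter) {
  std::map<int, int> Inverse; // destination -> instruction sunk after it
  for (const auto &KV : SinkAfter)
    Inverse[KV.second] = KV.first;

  std::vector<int> Ordered;
  for (int I : Insts) {
    if (SinkAfter.count(I))
      continue; // skipped here; re-emitted right after its destination
    Ordered.push_back(I);
    auto It = Inverse.find(I);
    if (It != Inverse.end())
      Ordered.push_back(It->second);
  }
  return Ordered;
}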
- if ((Recipe = tryToInterleaveMemory(Instr, Range))) { - VPBB->appendRecipe(Recipe); - continue; - } - - // Check if Instr is a memory operation that should be widened. - if ((Recipe = tryToWidenMemory(Instr, Range, Plan))) { - VPBB->appendRecipe(Recipe); - continue; - } - - // Check if Instr should form some PHI recipe. - if ((Recipe = tryToOptimizeInduction(Instr, Range))) { - VPBB->appendRecipe(Recipe); - continue; - } - if ((Recipe = tryToBlend(Instr, Plan))) { - VPBB->appendRecipe(Recipe); - continue; - } - if (PHINode *Phi = dyn_cast<PHINode>(Instr)) { - VPBB->appendRecipe(new VPWidenPHIRecipe(Phi)); - continue; - } - - // Check if Instr is to be widened by a general VPWidenRecipe, after - // having first checked for specific widening recipes that deal with - // Interleave Groups, Inductions and Phi nodes. - if (tryToWiden(Instr, VPBB, Range)) + if (RecipeBuilder.tryToCreateRecipe(Instr, Range, Plan, VPBB)) continue; // Otherwise, if all widening options failed, Instruction is to be // replicated. This may create a successor for VPBB. - VPBasicBlock *NextVPBB = - handleReplication(Instr, Range, VPBB, PredInst2Recipe, Plan); + VPBasicBlock *NextVPBB = RecipeBuilder.handleReplication( + Instr, Range, VPBB, PredInst2Recipe, Plan); if (NextVPBB != VPBB) { VPBB = NextVPBB; VPBB->setName(BB->hasName() ? BB->getName() + "." + Twine(VPBBsForBB++) @@ -8300,7 +7028,7 @@ LoopVectorizationPlanner::buildVPlan(VFRange &Range, VPBasicBlock *PreEntry = cast<VPBasicBlock>(Plan->getEntry()); assert(PreEntry->empty() && "Expecting empty pre-entry block."); VPBlockBase *Entry = Plan->setEntry(PreEntry->getSingleSuccessor()); - PreEntry->disconnectSuccessor(Entry); + VPBlockUtils::disconnectBlocks(PreEntry, Entry); delete PreEntry; std::string PlanName; @@ -8319,6 +7047,30 @@ LoopVectorizationPlanner::buildVPlan(VFRange &Range, return Plan; } +LoopVectorizationPlanner::VPlanPtr +LoopVectorizationPlanner::buildVPlan(VFRange &Range) { + // Outer loop handling: They may require CFG and instruction level + // transformations before even evaluating whether vectorization is profitable. + // Since we cannot modify the incoming IR, we need to build VPlan upfront in + // the vectorization pipeline. + assert(!OrigLoop->empty()); + assert(EnableVPlanNativePath && "VPlan-native path is not enabled."); + + // Create new empty VPlan + auto Plan = llvm::make_unique<VPlan>(); + + // Build hierarchical CFG + VPlanHCFGBuilder HCFGBuilder(OrigLoop, LI); + HCFGBuilder.buildHierarchicalCFG(*Plan.get()); + + return Plan; +} + +Value* LoopVectorizationPlanner::VPCallbackILV:: +getOrCreateVectorValues(Value *V, unsigned Part) { + return ILV.getOrCreateVectorValue(V, Part); +} + void VPInterleaveRecipe::print(raw_ostream &O, const Twine &Indent) const { O << " +\n" << Indent << "\"INTERLEAVE-GROUP with factor " << IG->getFactor() << " at "; @@ -8483,28 +7235,66 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { State.ILV->vectorizeMemoryInstruction(&Instr, &MaskValues); } +// Process the loop in the VPlan-native vectorization path. This path builds +// VPlan upfront in the vectorization pipeline, which allows to apply +// VPlan-to-VPlan transformations from the very beginning without modifying the +// input LLVM IR. 
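buildVPlanWithVPRecipes grows the plan from a dummy "Pre-Entry" block so every real block can simply be inserted after an existing one, then promotes the dummy's single successor to be the entry and deletes the dummy (the setEntry/disconnectBlocks sequence above). A toy version of that finalization step, assuming the pre-entry block is heap-allocated:

#include <cassert>
#include <vector>

struct Block { std::vector<Block *> Succs; };

// Promote the dummy pre-entry's single successor to be the real entry,
// disconnect the dummy, and free it.
static Block *stripPreEntry(Block *PreEntry) {
  assert(PreEntry->Succs.size() == 1 && "pre-entry must have one successor");
  Block *Entry = PreEntry->Succs.front();
  PreEntry->Succs.clear(); // disconnect the dummy from the CFG
  delete PreEntry;         // it served only as an insertion anchor
  return Entry;
}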
+static bool processLoopInVPlanNativePath( + Loop *L, PredicatedScalarEvolution &PSE, LoopInfo *LI, DominatorTree *DT, + LoopVectorizationLegality *LVL, TargetTransformInfo *TTI, + TargetLibraryInfo *TLI, DemandedBits *DB, AssumptionCache *AC, + OptimizationRemarkEmitter *ORE, LoopVectorizeHints &Hints) { + + assert(EnableVPlanNativePath && "VPlan-native path is disabled."); + Function *F = L->getHeader()->getParent(); + InterleavedAccessInfo IAI(PSE, L, DT, LI, LVL->getLAI()); + LoopVectorizationCostModel CM(L, PSE, LI, LVL, *TTI, TLI, DB, AC, ORE, F, + &Hints, IAI); + // Use the planner for outer loop vectorization. + // TODO: CM is not used at this point inside the planner. Turn CM into an + // optional argument if we don't need it in the future. + LoopVectorizationPlanner LVP(L, LI, TLI, TTI, LVL, CM); + + // Get user vectorization factor. + unsigned UserVF = Hints.getWidth(); + + // Check the function attributes to find out if this function should be + // optimized for size. + bool OptForSize = + Hints.getForce() != LoopVectorizeHints::FK_Enabled && F->optForSize(); + + // Plan how to best vectorize, return the best VF and its cost. + LVP.planInVPlanNativePath(OptForSize, UserVF); + + // Returning false. We are currently not generating vector code in the VPlan + // native path. + return false; +} + bool LoopVectorizePass::processLoop(Loop *L) { - assert(L->empty() && "Only process inner loops."); + assert((EnableVPlanNativePath || L->empty()) && + "VPlan-native path is not enabled. Only process inner loops."); #ifndef NDEBUG const std::string DebugLocStr = getDebugLocString(L); #endif /* NDEBUG */ - DEBUG(dbgs() << "\nLV: Checking a loop in \"" - << L->getHeader()->getParent()->getName() << "\" from " - << DebugLocStr << "\n"); + LLVM_DEBUG(dbgs() << "\nLV: Checking a loop in \"" + << L->getHeader()->getParent()->getName() << "\" from " + << DebugLocStr << "\n"); LoopVectorizeHints Hints(L, DisableUnrolling, *ORE); - DEBUG(dbgs() << "LV: Loop hints:" - << " force=" - << (Hints.getForce() == LoopVectorizeHints::FK_Disabled - ? "disabled" - : (Hints.getForce() == LoopVectorizeHints::FK_Enabled - ? "enabled" - : "?")) - << " width=" << Hints.getWidth() - << " unroll=" << Hints.getInterleave() << "\n"); + LLVM_DEBUG( + dbgs() << "LV: Loop hints:" + << " force=" + << (Hints.getForce() == LoopVectorizeHints::FK_Disabled + ? "disabled" + : (Hints.getForce() == LoopVectorizeHints::FK_Enabled + ? "enabled" + : "?")) + << " width=" << Hints.getWidth() + << " unroll=" << Hints.getInterleave() << "\n"); // Function containing loop Function *F = L->getHeader()->getParent(); @@ -8518,7 +7308,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { // benefit from vectorization, respectively. if (!Hints.allowVectorization(F, L, AlwaysVectorize)) { - DEBUG(dbgs() << "LV: Loop hints prevent vectorization.\n"); + LLVM_DEBUG(dbgs() << "LV: Loop hints prevent vectorization.\n"); return false; } @@ -8526,10 +7316,10 @@ bool LoopVectorizePass::processLoop(Loop *L) { // Check if it is legal to vectorize the loop. 
LoopVectorizationRequirements Requirements(*ORE); - LoopVectorizationLegality LVL(L, PSE, DT, TLI, AA, F, TTI, GetLAA, LI, ORE, - &Requirements, &Hints); - if (!LVL.canVectorize()) { - DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); + LoopVectorizationLegality LVL(L, PSE, DT, TLI, AA, F, GetLAA, LI, ORE, + &Requirements, &Hints, DB, AC); + if (!LVL.canVectorize(EnableVPlanNativePath)) { + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); emitMissedWarning(F, L, Hints, ORE); return false; } @@ -8539,11 +7329,33 @@ bool LoopVectorizePass::processLoop(Loop *L) { bool OptForSize = Hints.getForce() != LoopVectorizeHints::FK_Enabled && F->optForSize(); + // Entrance to the VPlan-native vectorization path. Outer loops are processed + // here. They may require CFG and instruction level transformations before + // even evaluating whether vectorization is profitable. Since we cannot modify + // the incoming IR, we need to build VPlan upfront in the vectorization + // pipeline. + if (!L->empty()) + return processLoopInVPlanNativePath(L, PSE, LI, DT, &LVL, TTI, TLI, DB, AC, + ORE, Hints); + + assert(L->empty() && "Inner loop expected."); // Check the loop for a trip count threshold: vectorize loops with a tiny trip // count by optimizing for size, to minimize overheads. - unsigned ExpectedTC = SE->getSmallConstantMaxTripCount(L); - bool HasExpectedTC = (ExpectedTC > 0); - + // Prefer constant trip counts over profile data, over upper bound estimate. + unsigned ExpectedTC = 0; + bool HasExpectedTC = false; + if (const SCEVConstant *ConstExits = + dyn_cast<SCEVConstant>(SE->getBackedgeTakenCount(L))) { + const APInt &ExitsCount = ConstExits->getAPInt(); + // We are interested in small values for ExpectedTC. Skip over those that + // can't fit an unsigned. + if (ExitsCount.ult(std::numeric_limits<unsigned>::max())) { + ExpectedTC = static_cast<unsigned>(ExitsCount.getZExtValue()) + 1; + HasExpectedTC = true; + } + } + // ExpectedTC may be large because it's bound by a variable. Check + // profiling information to validate we should vectorize. if (!HasExpectedTC && LoopVectorizeWithBlockFrequency) { auto EstimatedTC = getLoopEstimatedTripCount(L); if (EstimatedTC) { @@ -8551,15 +7363,19 @@ bool LoopVectorizePass::processLoop(Loop *L) { HasExpectedTC = true; } } + if (!HasExpectedTC) { + ExpectedTC = SE->getSmallConstantMaxTripCount(L); + HasExpectedTC = (ExpectedTC > 0); + } if (HasExpectedTC && ExpectedTC < TinyTripCountVectorThreshold) { - DEBUG(dbgs() << "LV: Found a loop with a very small trip count. " - << "This loop is worth vectorizing only if no scalar " - << "iteration overheads are incurred."); + LLVM_DEBUG(dbgs() << "LV: Found a loop with a very small trip count. " + << "This loop is worth vectorizing only if no scalar " + << "iteration overheads are incurred."); if (Hints.getForce() == LoopVectorizeHints::FK_Enabled) - DEBUG(dbgs() << " But vectorizing was explicitly forced.\n"); + LLVM_DEBUG(dbgs() << " But vectorizing was explicitly forced.\n"); else { - DEBUG(dbgs() << "\n"); + LLVM_DEBUG(dbgs() << "\n"); // Loops with a very small trip count are considered for vectorization // under OptForSize, thereby making sure the cost of their loop body is // dominant, free of runtime guards and scalar iteration overheads. @@ -8572,10 +7388,10 @@ bool LoopVectorizePass::processLoop(Loop *L) { // an integer loop and the vector instructions selected are purely integer // vector instructions? 
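The reworked trip-count logic above consults three sources in decreasing order of trust: an exact constant backedge-taken count from SCEV, then a profile-based estimate, then SCEV's small constant maximum. A condensed sketch of that fallback chain; the three getters are hypothetical stand-ins (returning "unknown" here) for the ScalarEvolution and profile queries named in the diff:

#include <optional>

// Stand-ins: exact count (backedge-taken + 1), profile estimate, upper bound.
static std::optional<unsigned> getConstTC() { return std::nullopt; }
static std::optional<unsigned> getProfileTC() { return std::nullopt; }
static unsigned getMaxTC() { return 0; } // 0 means "unknown"

static std::optional<unsigned> expectedTripCount() {
  if (auto TC = getConstTC())      // most reliable: a compile-time constant
    return TC;
  if (auto TC = getProfileTC())    // next: what profiling observed
    return TC;
  if (unsigned MaxTC = getMaxTC()) // last resort: a conservative bound
    return MaxTC;
  return std::nullopt;
}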
if (F->hasFnAttribute(Attribute::NoImplicitFloat)) { - DEBUG(dbgs() << "LV: Can't vectorize when the NoImplicitFloat" - "attribute is used.\n"); - ORE->emit(createMissedAnalysis(Hints.vectorizeAnalysisPassName(), - "NoImplicitFloat", L) + LLVM_DEBUG(dbgs() << "LV: Can't vectorize when the NoImplicitFloat" + "attribute is used.\n"); + ORE->emit(createLVMissedAnalysis(Hints.vectorizeAnalysisPassName(), + "NoImplicitFloat", L) << "loop not vectorized due to NoImplicitFloat attribute"); emitMissedWarning(F, L, Hints, ORE); return false; @@ -8587,17 +7403,30 @@ bool LoopVectorizePass::processLoop(Loop *L) { // additional fp-math flags can help. if (Hints.isPotentiallyUnsafe() && TTI->isFPVectorizationPotentiallyUnsafe()) { - DEBUG(dbgs() << "LV: Potentially unsafe FP op prevents vectorization.\n"); + LLVM_DEBUG( + dbgs() << "LV: Potentially unsafe FP op prevents vectorization.\n"); ORE->emit( - createMissedAnalysis(Hints.vectorizeAnalysisPassName(), "UnsafeFP", L) + createLVMissedAnalysis(Hints.vectorizeAnalysisPassName(), "UnsafeFP", L) << "loop not vectorized due to unsafe FP support."); emitMissedWarning(F, L, Hints, ORE); return false; } + bool UseInterleaved = TTI->enableInterleavedAccessVectorization(); + InterleavedAccessInfo IAI(PSE, L, DT, LI, LVL.getLAI()); + + // If an override option has been passed in for interleaved accesses, use it. + if (EnableInterleavedMemAccesses.getNumOccurrences() > 0) + UseInterleaved = EnableInterleavedMemAccesses; + + // Analyze interleaved memory accesses. + if (UseInterleaved) { + IAI.analyzeInterleaving(); + } + // Use the cost model. LoopVectorizationCostModel CM(L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, F, - &Hints); + &Hints, IAI); CM.collectValuesToIgnore(); // Use the planner for vectorization. @@ -8607,8 +7436,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { unsigned UserVF = Hints.getWidth(); // Plan how to best vectorize, return the best VF and its cost. - LoopVectorizationCostModel::VectorizationFactor VF = - LVP.plan(OptForSize, UserVF); + VectorizationFactor VF = LVP.plan(OptForSize, UserVF); // Select the interleave count. unsigned IC = CM.selectInterleaveCount(OptForSize, VF.Width, VF.Cost); @@ -8620,14 +7448,14 @@ bool LoopVectorizePass::processLoop(Loop *L) { std::pair<StringRef, std::string> VecDiagMsg, IntDiagMsg; bool VectorizeLoop = true, InterleaveLoop = true; if (Requirements.doesNotMeet(F, L, Hints)) { - DEBUG(dbgs() << "LV: Not vectorizing: loop did not meet vectorization " - "requirements.\n"); + LLVM_DEBUG(dbgs() << "LV: Not vectorizing: loop did not meet vectorization " + "requirements.\n"); emitMissedWarning(F, L, Hints, ORE); return false; } if (VF.Width == 1) { - DEBUG(dbgs() << "LV: Vectorization is possible but not beneficial.\n"); + LLVM_DEBUG(dbgs() << "LV: Vectorization is possible but not beneficial.\n"); VecDiagMsg = std::make_pair( "VectorizationNotBeneficial", "the cost-model indicates that vectorization is not beneficial"); @@ -8636,7 +7464,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { if (IC == 1 && UserIC <= 1) { // Tell the user interleaving is not beneficial. - DEBUG(dbgs() << "LV: Interleaving is not beneficial.\n"); + LLVM_DEBUG(dbgs() << "LV: Interleaving is not beneficial.\n"); IntDiagMsg = std::make_pair( "InterleavingNotBeneficial", "the cost-model indicates that interleaving is not beneficial"); @@ -8648,8 +7476,8 @@ bool LoopVectorizePass::processLoop(Loop *L) { } } else if (IC > 1 && UserIC == 1) { // Tell the user interleaving is beneficial, but it explicitly disabled. 
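The VectorizeLoop/InterleaveLoop pair set up above yields exactly four outcomes, and only the case where both are off abandons the loop; the remarks emitted in the following hunk correspond to these branches. Spelled out with an illustrative enum:

enum class LVAction { GiveUp, InterleaveOnly, VectorizeOnly, Both };

static LVAction decide(bool VectorizeLoop, bool InterleaveLoop) {
  if (!VectorizeLoop && !InterleaveLoop)
    return LVAction::GiveUp;         // emit both missed-optimization remarks
  if (!VectorizeLoop)
    return LVAction::InterleaveOnly; // VF == 1, IC > 1
  if (!InterleaveLoop)
    return LVAction::VectorizeOnly;  // VF > 1, IC == 1
  return LVAction::Both;             // VF > 1, IC > 1
}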
- DEBUG(dbgs() - << "LV: Interleaving is beneficial but is explicitly disabled."); + LLVM_DEBUG( + dbgs() << "LV: Interleaving is beneficial but is explicitly disabled."); IntDiagMsg = std::make_pair( "InterleavingBeneficialButDisabled", "the cost-model indicates that interleaving is beneficial " @@ -8676,24 +7504,24 @@ bool LoopVectorizePass::processLoop(Loop *L) { }); return false; } else if (!VectorizeLoop && InterleaveLoop) { - DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); + LLVM_DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); ORE->emit([&]() { return OptimizationRemarkAnalysis(VAPassName, VecDiagMsg.first, L->getStartLoc(), L->getHeader()) << VecDiagMsg.second; }); } else if (VectorizeLoop && !InterleaveLoop) { - DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width << ") in " - << DebugLocStr << '\n'); + LLVM_DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width + << ") in " << DebugLocStr << '\n'); ORE->emit([&]() { return OptimizationRemarkAnalysis(LV_NAME, IntDiagMsg.first, L->getStartLoc(), L->getHeader()) << IntDiagMsg.second; }); } else if (VectorizeLoop && InterleaveLoop) { - DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width << ") in " - << DebugLocStr << '\n'); - DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); + LLVM_DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width + << ") in " << DebugLocStr << '\n'); + LLVM_DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); } LVP.setBestPlan(VF.Width, IC); @@ -8740,7 +7568,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { // Mark the loop as already vectorized to avoid vectorizing again. Hints.setAlreadyVectorized(); - DEBUG(verifyFunction(*L->getHeader()->getParent())); + LLVM_DEBUG(verifyFunction(*L->getHeader()->getParent())); return true; } @@ -8788,7 +7616,7 @@ bool LoopVectorizePass::runImpl( SmallVector<Loop *, 8> Worklist; for (Loop *L : *LI) - addAcyclicInnerLoop(*L, Worklist); + collectSupportedLoops(*L, LI, ORE, Worklist); LoopsAnalyzed += Worklist.size(); diff --git a/lib/Transforms/Vectorize/SLPVectorizer.cpp b/lib/Transforms/Vectorize/SLPVectorizer.cpp index a7ccd3faec44..ac8c4f046c6f 100644 --- a/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -161,7 +161,7 @@ static const unsigned MaxMemDepDistance = 160; /// regions to be handled. static const int MinScheduleRegionSize = 16; -/// \brief Predicate for the element types that the SLP vectorizer supports. +/// Predicate for the element types that the SLP vectorizer supports. /// /// The most important thing to filter here are types which are invalid in LLVM /// vectors. We also filter target specific types which have absolutely no @@ -246,13 +246,15 @@ static bool isSplat(ArrayRef<Value *> VL) { /// %ins4 = insertelement <4 x i8> %ins3, i8 %9, i32 3 /// ret <4 x i8> %ins4 /// InstCombiner transforms this into a shuffle and vector mul +/// TODO: Can we split off and reuse the shuffle mask detection from +/// TargetTransformInfo::getInstructionThroughput? 
static Optional<TargetTransformInfo::ShuffleKind> isShuffle(ArrayRef<Value *> VL) { auto *EI0 = cast<ExtractElementInst>(VL[0]); unsigned Size = EI0->getVectorOperandType()->getVectorNumElements(); Value *Vec1 = nullptr; Value *Vec2 = nullptr; - enum ShuffleMode {Unknown, FirstAlternate, SecondAlternate, Permute}; + enum ShuffleMode { Unknown, Select, Permute }; ShuffleMode CommonShuffleMode = Unknown; for (unsigned I = 0, E = VL.size(); I < E; ++I) { auto *EI = cast<ExtractElementInst>(VL[I]); @@ -272,7 +274,11 @@ isShuffle(ArrayRef<Value *> VL) { continue; // For correct shuffling we have to have at most 2 different vector operands // in all extractelement instructions. - if (Vec1 && Vec2 && Vec != Vec1 && Vec != Vec2) + if (!Vec1 || Vec1 == Vec) + Vec1 = Vec; + else if (!Vec2 || Vec2 == Vec) + Vec2 = Vec; + else return None; if (CommonShuffleMode == Permute) continue; @@ -282,119 +288,17 @@ isShuffle(ArrayRef<Value *> VL) { CommonShuffleMode = Permute; continue; } - // Check the shuffle mode for the current operation. - if (!Vec1) - Vec1 = Vec; - else if (Vec != Vec1) - Vec2 = Vec; - // Example: shufflevector A, B, <0,5,2,7> - // I is odd and IntIdx for A == I - FirstAlternate shuffle. - // I is even and IntIdx for B == I - FirstAlternate shuffle. - // Example: shufflevector A, B, <4,1,6,3> - // I is even and IntIdx for A == I - SecondAlternate shuffle. - // I is odd and IntIdx for B == I - SecondAlternate shuffle. - const bool IIsEven = I & 1; - const bool CurrVecIsA = Vec == Vec1; - const bool IIsOdd = !IIsEven; - const bool CurrVecIsB = !CurrVecIsA; - ShuffleMode CurrentShuffleMode = - ((IIsOdd && CurrVecIsA) || (IIsEven && CurrVecIsB)) ? FirstAlternate - : SecondAlternate; - // Common mode is not set or the same as the shuffle mode of the current - // operation - alternate. - if (CommonShuffleMode == Unknown) - CommonShuffleMode = CurrentShuffleMode; - // Common shuffle mode is not the same as the shuffle mode of the current - // operation - permutation. - if (CommonShuffleMode != CurrentShuffleMode) - CommonShuffleMode = Permute; + CommonShuffleMode = Select; } // If we're not crossing lanes in different vectors, consider it as blending. - if ((CommonShuffleMode == FirstAlternate || - CommonShuffleMode == SecondAlternate) && - Vec2) - return TargetTransformInfo::SK_Alternate; + if (CommonShuffleMode == Select && Vec2) + return TargetTransformInfo::SK_Select; // If Vec2 was never used, we have a permutation of a single vector, otherwise // we have permutation of 2 vectors. return Vec2 ? TargetTransformInfo::SK_PermuteTwoSrc : TargetTransformInfo::SK_PermuteSingleSrc; } -///\returns Opcode that can be clubbed with \p Op to create an alternate -/// sequence which can later be merged as a ShuffleVector instruction. -static unsigned getAltOpcode(unsigned Op) { - switch (Op) { - case Instruction::FAdd: - return Instruction::FSub; - case Instruction::FSub: - return Instruction::FAdd; - case Instruction::Add: - return Instruction::Sub; - case Instruction::Sub: - return Instruction::Add; - default: - return 0; - } -} - -static bool isOdd(unsigned Value) { - return Value & 1; -} - -static bool sameOpcodeOrAlt(unsigned Opcode, unsigned AltOpcode, - unsigned CheckedOpcode) { - return Opcode == CheckedOpcode || AltOpcode == CheckedOpcode; -} - -/// Chooses the correct key for scheduling data. If \p Op has the same (or -/// alternate) opcode as \p OpValue, the key is \p Op. Otherwise the key is \p -/// OpValue. 
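The rewritten isShuffle above collapses the old FirstAlternate/SecondAlternate bookkeeping into a single Select mode: if every lane I extracts element I from one of at most two source vectors, the sequence is a lane-preserving blend (SK_Select); any cross-lane index degrades it to a permute, single- or two-source depending on whether a second vector was seen. A self-contained model of that classification:

#include <optional>
#include <vector>

enum class Kind { Select, PermuteSingleSrc, PermuteTwoSrc };

struct Extract { int Vec; unsigned Idx; }; // source vector id, element index

static std::optional<Kind> classify(const std::vector<Extract> &Lanes) {
  int Vec1 = -1, Vec2 = -1;
  bool CrossLane = false;
  for (unsigned I = 0, E = Lanes.size(); I < E; ++I) {
    int V = Lanes[I].Vec;
    if (Vec1 < 0 || Vec1 == V)
      Vec1 = V;
    else if (Vec2 < 0 || Vec2 == V)
      Vec2 = V;
    else
      return std::nullopt; // more than two distinct sources: not a shuffle
    if (Lanes[I].Idx != I)
      CrossLane = true;    // some lane changes position: at best a permute
  }
  if (!CrossLane && Vec2 >= 0)
    return Kind::Select;
  return Vec2 >= 0 ? Kind::PermuteTwoSrc : Kind::PermuteSingleSrc;
}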
-static Value *isOneOf(Value *OpValue, Value *Op) { - auto *I = dyn_cast<Instruction>(Op); - if (!I) - return OpValue; - auto *OpInst = cast<Instruction>(OpValue); - unsigned OpInstOpcode = OpInst->getOpcode(); - unsigned IOpcode = I->getOpcode(); - if (sameOpcodeOrAlt(OpInstOpcode, getAltOpcode(OpInstOpcode), IOpcode)) - return Op; - return OpValue; -} - -namespace { - -/// Contains data for the instructions going to be vectorized. -struct RawInstructionsData { - /// Main Opcode of the instructions going to be vectorized. - unsigned Opcode = 0; - - /// The list of instructions have some instructions with alternate opcodes. - bool HasAltOpcodes = false; -}; - -} // end anonymous namespace - -/// Checks the list of the vectorized instructions \p VL and returns info about -/// this list. -static RawInstructionsData getMainOpcode(ArrayRef<Value *> VL) { - auto *I0 = dyn_cast<Instruction>(VL[0]); - if (!I0) - return {}; - RawInstructionsData Res; - unsigned Opcode = I0->getOpcode(); - // Walk through the list of the vectorized instructions - // in order to check its structure described by RawInstructionsData. - for (unsigned Cnt = 0, E = VL.size(); Cnt != E; ++Cnt) { - auto *I = dyn_cast<Instruction>(VL[Cnt]); - if (!I) - return {}; - if (Opcode != I->getOpcode()) - Res.HasAltOpcodes = true; - } - Res.Opcode = Opcode; - return Res; -} - namespace { /// Main data required for vectorization of instructions. @@ -402,42 +306,90 @@ struct InstructionsState { /// The very first instruction in the list with the main opcode. Value *OpValue = nullptr; - /// The main opcode for the list of instructions. - unsigned Opcode = 0; + /// The main/alternate instruction. + Instruction *MainOp = nullptr; + Instruction *AltOp = nullptr; + + /// The main/alternate opcodes for the list of instructions. + unsigned getOpcode() const { + return MainOp ? MainOp->getOpcode() : 0; + } + + unsigned getAltOpcode() const { + return AltOp ? AltOp->getOpcode() : 0; + } /// Some of the instructions in the list have alternate opcodes. - bool IsAltShuffle = false; + bool isAltShuffle() const { return getOpcode() != getAltOpcode(); } + + bool isOpcodeOrAlt(Instruction *I) const { + unsigned CheckedOpcode = I->getOpcode(); + return getOpcode() == CheckedOpcode || getAltOpcode() == CheckedOpcode; + } - InstructionsState() = default; - InstructionsState(Value *OpValue, unsigned Opcode, bool IsAltShuffle) - : OpValue(OpValue), Opcode(Opcode), IsAltShuffle(IsAltShuffle) {} + InstructionsState() = delete; + InstructionsState(Value *OpValue, Instruction *MainOp, Instruction *AltOp) + : OpValue(OpValue), MainOp(MainOp), AltOp(AltOp) {} }; } // end anonymous namespace +/// Chooses the correct key for scheduling data. If \p Op has the same (or +/// alternate) opcode as \p OpValue, the key is \p Op. Otherwise the key is \p +/// OpValue. +static Value *isOneOf(const InstructionsState &S, Value *Op) { + auto *I = dyn_cast<Instruction>(Op); + if (I && S.isOpcodeOrAlt(I)) + return Op; + return S.OpValue; +} + /// \returns analysis of the Instructions in \p VL described in /// InstructionsState, the Opcode that we suppose the whole list /// could be vectorized even if its structure is diverse. 
-static InstructionsState getSameOpcode(ArrayRef<Value *> VL) { - auto Res = getMainOpcode(VL); - unsigned Opcode = Res.Opcode; - if (!Res.HasAltOpcodes) - return InstructionsState(VL[0], Opcode, false); - auto *OpInst = cast<Instruction>(VL[0]); - unsigned AltOpcode = getAltOpcode(Opcode); - // Examine each element in the list instructions VL to determine - // if some operations there could be considered as an alternative - // (for example as subtraction relates to addition operation). +static InstructionsState getSameOpcode(ArrayRef<Value *> VL, + unsigned BaseIndex = 0) { + // Make sure these are all Instructions. + if (llvm::any_of(VL, [](Value *V) { return !isa<Instruction>(V); })) + return InstructionsState(VL[BaseIndex], nullptr, nullptr); + + bool IsCastOp = isa<CastInst>(VL[BaseIndex]); + bool IsBinOp = isa<BinaryOperator>(VL[BaseIndex]); + unsigned Opcode = cast<Instruction>(VL[BaseIndex])->getOpcode(); + unsigned AltOpcode = Opcode; + unsigned AltIndex = BaseIndex; + + // Check for one alternate opcode from another BinaryOperator. + // TODO - generalize to support all operators (types, calls etc.). for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) { - auto *I = cast<Instruction>(VL[Cnt]); - unsigned InstOpcode = I->getOpcode(); - if ((Res.HasAltOpcodes && - InstOpcode != (isOdd(Cnt) ? AltOpcode : Opcode)) || - (!Res.HasAltOpcodes && InstOpcode != Opcode)) { - return InstructionsState(OpInst, 0, false); - } + unsigned InstOpcode = cast<Instruction>(VL[Cnt])->getOpcode(); + if (IsBinOp && isa<BinaryOperator>(VL[Cnt])) { + if (InstOpcode == Opcode || InstOpcode == AltOpcode) + continue; + if (Opcode == AltOpcode) { + AltOpcode = InstOpcode; + AltIndex = Cnt; + continue; + } + } else if (IsCastOp && isa<CastInst>(VL[Cnt])) { + Type *Ty0 = cast<Instruction>(VL[BaseIndex])->getOperand(0)->getType(); + Type *Ty1 = cast<Instruction>(VL[Cnt])->getOperand(0)->getType(); + if (Ty0 == Ty1) { + if (InstOpcode == Opcode || InstOpcode == AltOpcode) + continue; + if (Opcode == AltOpcode) { + AltOpcode = InstOpcode; + AltIndex = Cnt; + continue; + } + } + } else if (InstOpcode == Opcode || InstOpcode == AltOpcode) + continue; + return InstructionsState(VL[BaseIndex], nullptr, nullptr); } - return InstructionsState(OpInst, Opcode, Res.HasAltOpcodes); + + return InstructionsState(VL[BaseIndex], cast<Instruction>(VL[BaseIndex]), + cast<Instruction>(VL[AltIndex])); } /// \returns true if all of the values in \p VL have the same type or false @@ -452,16 +404,21 @@ static bool allSameType(ArrayRef<Value *> VL) { } /// \returns True if Extract{Value,Element} instruction extracts element Idx. 
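The new getSameOpcode generalizes the old fixed add/sub pairing: a bundle may contain at most two distinct BinaryOperator opcodes (or two cast opcodes with matching source operand types), and the alternate opcode is discovered on the fly rather than derived from the first one. Stripped of the IR types, the scan reduces to:

#include <optional>
#include <utility>
#include <vector>

// Returns {main, alt} if at most two distinct opcodes occur, otherwise
// nullopt. When no alternate exists, main == alt (i.e. no alt-shuffle).
static std::optional<std::pair<unsigned, unsigned>>
mainAltOpcode(const std::vector<unsigned> &Opcodes) {
  unsigned Main = Opcodes.front(), Alt = Main;
  for (unsigned Op : Opcodes) {
    if (Op == Main || Op == Alt)
      continue;
    if (Main == Alt) {
      Alt = Op; // first deviation becomes the alternate opcode
      continue;
    }
    return std::nullopt; // a third distinct opcode fails the whole bundle
  }
  return std::make_pair(Main, Alt);
}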
-static bool matchExtractIndex(Instruction *E, unsigned Idx, unsigned Opcode) { - assert(Opcode == Instruction::ExtractElement || - Opcode == Instruction::ExtractValue); +static Optional<unsigned> getExtractIndex(Instruction *E) { + unsigned Opcode = E->getOpcode(); + assert((Opcode == Instruction::ExtractElement || + Opcode == Instruction::ExtractValue) && + "Expected extractelement or extractvalue instruction."); if (Opcode == Instruction::ExtractElement) { - ConstantInt *CI = dyn_cast<ConstantInt>(E->getOperand(1)); - return CI && CI->getZExtValue() == Idx; - } else { - ExtractValueInst *EI = cast<ExtractValueInst>(E); - return EI->getNumIndices() == 1 && *EI->idx_begin() == Idx; + auto *CI = dyn_cast<ConstantInt>(E->getOperand(1)); + if (!CI) + return None; + return CI->getZExtValue(); } + ExtractValueInst *EI = cast<ExtractValueInst>(E); + if (EI->getNumIndices() != 1) + return None; + return *EI->idx_begin(); } /// \returns True if in-tree use also needs extract. This refers to @@ -549,7 +506,7 @@ public: MinVecRegSize = TTI->getMinVectorRegisterBitWidth(); } - /// \brief Vectorize the tree that starts with the elements in \p VL. + /// Vectorize the tree that starts with the elements in \p VL. /// Returns the vectorized root. Value *vectorizeTree(); @@ -585,8 +542,8 @@ public: ScalarToTreeEntry.clear(); MustGather.clear(); ExternalUses.clear(); - NumLoadsWantToKeepOrder = 0; - NumLoadsWantToChangeOrder = 0; + NumOpsWantToKeepOrder.clear(); + NumOpsWantToKeepOriginalOrder = 0; for (auto &Iter : BlocksSchedules) { BlockScheduling *BS = Iter.second.get(); BS->clear(); @@ -596,12 +553,22 @@ public: unsigned getTreeSize() const { return VectorizableTree.size(); } - /// \brief Perform LICM and CSE on the newly generated gather sequences. - void optimizeGatherSequence(Function &F); + /// Perform LICM and CSE on the newly generated gather sequences. + void optimizeGatherSequence(); + + /// \returns The best order of instructions for vectorization. + Optional<ArrayRef<unsigned>> bestOrder() const { + auto I = std::max_element( + NumOpsWantToKeepOrder.begin(), NumOpsWantToKeepOrder.end(), + [](const decltype(NumOpsWantToKeepOrder)::value_type &D1, + const decltype(NumOpsWantToKeepOrder)::value_type &D2) { + return D1.second < D2.second; + }); + if (I == NumOpsWantToKeepOrder.end() || + I->getSecond() <= NumOpsWantToKeepOriginalOrder) + return None; - /// \returns true if it is beneficial to reverse the vector order. - bool shouldReorder() const { - return NumLoadsWantToChangeOrder > NumLoadsWantToKeepOrder; + return makeArrayRef(I->getFirst()); } /// \return The vector element size in bits to use when vectorizing the @@ -625,7 +592,7 @@ public: return MinVecRegSize; } - /// \brief Check if ArrayType or StructType is isomorphic to some VectorType. + /// Check if ArrayType or StructType is isomorphic to some VectorType. /// /// \returns number of elements in vector if isomorphism exists, 0 otherwise. unsigned canMapToVector(Type *T, const DataLayout &DL) const; @@ -648,9 +615,13 @@ private: /// This is the recursive part of buildTree. void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth, int); - /// \returns True if the ExtractElement/ExtractValue instructions in VL can - /// be vectorized to use the original vector (or aggregate "bitcast" to a vector). 
- bool canReuseExtract(ArrayRef<Value *> VL, Value *OpValue) const; + /// \returns true if the ExtractElement/ExtractValue instructions in \p VL can + /// be vectorized to use the original vector (or aggregate "bitcast" to a + /// vector) and sets \p CurrentOrder to the identity permutation; otherwise + /// returns false, setting \p CurrentOrder to either an empty vector or a + /// non-identity permutation that allows to reuse extract instructions. + bool canReuseExtract(ArrayRef<Value *> VL, Value *OpValue, + SmallVectorImpl<unsigned> &CurrentOrder) const; /// Vectorize a single entry in the tree. Value *vectorizeTree(TreeEntry *E); @@ -658,22 +629,19 @@ private: /// Vectorize a single entry in the tree, starting in \p VL. Value *vectorizeTree(ArrayRef<Value *> VL); - /// \returns the pointer to the vectorized value if \p VL is already - /// vectorized, or NULL. They may happen in cycles. - Value *alreadyVectorized(ArrayRef<Value *> VL, Value *OpValue) const; - /// \returns the scalarization cost for this type. Scalarization in this /// context means the creation of vectors from a group of scalars. - int getGatherCost(Type *Ty); + int getGatherCost(Type *Ty, const DenseSet<unsigned> &ShuffledIndices); /// \returns the scalarization cost for this list of values. Assuming that /// this subtree gets vectorized, we may need to extract the values from the /// roots. This method calculates the cost of extracting the values. int getGatherCost(ArrayRef<Value *> VL); - /// \brief Set the Builder insert point to one after the last instruction in + /// Set the Builder insert point to one after the last instruction in /// the bundle - void setInsertPointAfterBundle(ArrayRef<Value *> VL, Value *OpValue); + void setInsertPointAfterBundle(ArrayRef<Value *> VL, + const InstructionsState &S); /// \returns a vector from a collection of scalars in \p VL. Value *Gather(ArrayRef<Value *> VL, VectorType *Ty); @@ -684,7 +652,8 @@ private: /// \reorder commutative operands in alt shuffle if they result in /// vectorized code. - void reorderAltShuffleOperands(unsigned Opcode, ArrayRef<Value *> VL, + void reorderAltShuffleOperands(const InstructionsState &S, + ArrayRef<Value *> VL, SmallVectorImpl<Value *> &Left, SmallVectorImpl<Value *> &Right); @@ -698,8 +667,12 @@ private: /// \returns true if the scalars in VL are equal to this entry. bool isSame(ArrayRef<Value *> VL) const { - assert(VL.size() == Scalars.size() && "Invalid size"); - return std::equal(VL.begin(), VL.end(), Scalars.begin()); + if (VL.size() == Scalars.size()) + return std::equal(VL.begin(), VL.end(), Scalars.begin()); + return VL.size() == ReuseShuffleIndices.size() && + std::equal( + VL.begin(), VL.end(), ReuseShuffleIndices.begin(), + [this](Value *V, unsigned Idx) { return V == Scalars[Idx]; }); } /// A vector of scalars. @@ -711,6 +684,12 @@ private: /// Do we need to gather this sequence ? bool NeedToGather = false; + /// Does this sequence require some shuffling? + SmallVector<unsigned, 4> ReuseShuffleIndices; + + /// Does this entry require reordering? + ArrayRef<unsigned> ReorderIndices; + /// Points back to the VectorizableTree. /// /// Only used for Graphviz right now. Unfortunately GraphTrait::NodeRef has @@ -725,13 +704,17 @@ private: }; /// Create a new VectorizableTree entry. 
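The new contract of canReuseExtract distinguishes three cases: the per-lane extract indices form the identity (the source vector can be reused directly), some other permutation (reusable plus a shuffle, recorded in CurrentOrder), or no usable order at all. Modeling the lanes' extract indices as plain unsigneds, the index check amounts to:

#include <vector>

enum class ExtractOrder { Identity, Permutation, Unusable };

static ExtractOrder classifyExtracts(const std::vector<unsigned> &IdxPerLane) {
  const unsigned N = IdxPerLane.size();
  std::vector<bool> Seen(N, false);
  bool IsIdentity = true;
  for (unsigned Lane = 0; Lane < N; ++Lane) {
    unsigned Idx = IdxPerLane[Lane];
    if (Idx >= N || Seen[Idx])
      return ExtractOrder::Unusable; // out of range or duplicated element
    Seen[Idx] = true;
    IsIdentity = IsIdentity && (Idx == Lane);
  }
  return IsIdentity ? ExtractOrder::Identity : ExtractOrder::Permutation;
}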
- TreeEntry *newTreeEntry(ArrayRef<Value *> VL, bool Vectorized, - int &UserTreeIdx) { + void newTreeEntry(ArrayRef<Value *> VL, bool Vectorized, int &UserTreeIdx, + ArrayRef<unsigned> ReuseShuffleIndices = None, + ArrayRef<unsigned> ReorderIndices = None) { VectorizableTree.emplace_back(VectorizableTree); int idx = VectorizableTree.size() - 1; TreeEntry *Last = &VectorizableTree[idx]; Last->Scalars.insert(Last->Scalars.begin(), VL.begin(), VL.end()); Last->NeedToGather = !Vectorized; + Last->ReuseShuffleIndices.append(ReuseShuffleIndices.begin(), + ReuseShuffleIndices.end()); + Last->ReorderIndices = ReorderIndices; if (Vectorized) { for (int i = 0, e = VL.size(); i != e; ++i) { assert(!getTreeEntry(VL[i]) && "Scalar already in tree!"); @@ -744,7 +727,6 @@ private: if (UserTreeIdx >= 0) Last->UserTreeIndices.push_back(UserTreeIdx); UserTreeIdx = idx; - return Last; } /// -- Vectorization State -- @@ -758,13 +740,6 @@ private: return nullptr; } - const TreeEntry *getTreeEntry(Value *V) const { - auto I = ScalarToTreeEntry.find(V); - if (I != ScalarToTreeEntry.end()) - return &VectorizableTree[I->second]; - return nullptr; - } - /// Maps a specific scalar to its tree entry. SmallDenseMap<Value*, int> ScalarToTreeEntry; @@ -1038,7 +1013,7 @@ private: template <typename ReadyListType> void schedule(ScheduleData *SD, ReadyListType &ReadyList) { SD->IsScheduled = true; - DEBUG(dbgs() << "SLP: schedule " << *SD << "\n"); + LLVM_DEBUG(dbgs() << "SLP: schedule " << *SD << "\n"); ScheduleData *BundleMember = SD; while (BundleMember) { @@ -1061,8 +1036,8 @@ private: assert(!DepBundle->IsScheduled && "already scheduled bundle gets ready"); ReadyList.insert(DepBundle); - DEBUG(dbgs() - << "SLP: gets ready (def): " << *DepBundle << "\n"); + LLVM_DEBUG(dbgs() + << "SLP: gets ready (def): " << *DepBundle << "\n"); } }); } @@ -1075,8 +1050,8 @@ private: assert(!DepBundle->IsScheduled && "already scheduled bundle gets ready"); ReadyList.insert(DepBundle); - DEBUG(dbgs() << "SLP: gets ready (mem): " << *DepBundle - << "\n"); + LLVM_DEBUG(dbgs() + << "SLP: gets ready (mem): " << *DepBundle << "\n"); } } BundleMember = BundleMember->NextInBundle; @@ -1101,7 +1076,8 @@ private: doForAllOpcodes(I, [&](ScheduleData *SD) { if (SD->isSchedulingEntity() && SD->isReady()) { ReadyList.insert(SD); - DEBUG(dbgs() << "SLP: initially in ready list: " << *I << "\n"); + LLVM_DEBUG(dbgs() + << "SLP: initially in ready list: " << *I << "\n"); } }); } @@ -1110,7 +1086,8 @@ private: /// Checks if a bundle of instructions can be scheduled, i.e. has no /// cyclic dependencies. This is only a dry-run, no instructions are /// actually moved at this stage. - bool tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, Value *OpValue); + bool tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, + const InstructionsState &S); /// Un-bundles a group of instructions. void cancelScheduling(ArrayRef<Value *> VL, Value *OpValue); @@ -1120,7 +1097,7 @@ private: /// Extends the scheduling region so that V is inside the region. /// \returns true if the region size is within the limit. - bool extendSchedulingRegion(Value *V, Value *OpValue); + bool extendSchedulingRegion(Value *V, const InstructionsState &S); /// Initialize the ScheduleData structures for new instructions in the /// scheduling region. @@ -1201,11 +1178,38 @@ private: /// List of users to ignore during scheduling and that don't need extracting. ArrayRef<Value *> UserIgnoreList; - // Number of load bundles that contain consecutive loads. 
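The schedule() template above is classic list scheduling: committing a bundle decrements the unscheduled-dependency counters of everything that waits on it, and whatever reaches zero joins the ready list. The core mechanism, minus the def/memory split and the bundle chaining:

#include <queue>
#include <vector>

struct Bundle {
  int UnscheduledDeps = 0;          // incoming dependencies not yet scheduled
  std::vector<Bundle *> Dependents; // bundles that wait on this one
};

// Scheduling SD releases its dependents; any dependent left with no
// unscheduled dependencies becomes ready.
static void schedule(Bundle *SD, std::queue<Bundle *> &ReadyList) {
  for (Bundle *Dep : SD->Dependents)
    if (--Dep->UnscheduledDeps == 0)
      ReadyList.push(Dep);
}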
- int NumLoadsWantToKeepOrder = 0; + using OrdersType = SmallVector<unsigned, 4>; + /// A DenseMapInfo implementation for holding DenseMaps and DenseSets of + /// sorted SmallVectors of unsigned. + struct OrdersTypeDenseMapInfo { + static OrdersType getEmptyKey() { + OrdersType V; + V.push_back(~1U); + return V; + } + + static OrdersType getTombstoneKey() { + OrdersType V; + V.push_back(~2U); + return V; + } + + static unsigned getHashValue(const OrdersType &V) { + return static_cast<unsigned>(hash_combine_range(V.begin(), V.end())); + } + + static bool isEqual(const OrdersType &LHS, const OrdersType &RHS) { + return LHS == RHS; + } + }; - // Number of load bundles that contain consecutive loads in reversed order. - int NumLoadsWantToChangeOrder = 0; + /// Contains orders of operations along with the number of bundles that have + /// operations in this order. It stores only those orders that require + /// reordering, if reordering is not required it is counted using \a + /// NumOpsWantToKeepOriginalOrder. + DenseMap<OrdersType, unsigned, OrdersTypeDenseMapInfo> NumOpsWantToKeepOrder; + /// Number of bundles that do not require reordering. + unsigned NumOpsWantToKeepOriginalOrder = 0; // Analysis and block reference. Function *F; @@ -1242,7 +1246,7 @@ template <> struct GraphTraits<BoUpSLP *> { /// NodeRef has to be a pointer per the GraphWriter. using NodeRef = TreeEntry *; - /// \brief Add the VectorizableTree to the index iterator to be able to return + /// Add the VectorizableTree to the index iterator to be able to return /// TreeEntry pointers. struct ChildIteratorType : public iterator_adaptor_base<ChildIteratorType, @@ -1340,17 +1344,22 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, // For each lane: for (int Lane = 0, LE = Entry->Scalars.size(); Lane != LE; ++Lane) { Value *Scalar = Entry->Scalars[Lane]; + int FoundLane = Lane; + if (!Entry->ReuseShuffleIndices.empty()) { + FoundLane = + std::distance(Entry->ReuseShuffleIndices.begin(), + llvm::find(Entry->ReuseShuffleIndices, FoundLane)); + } // Check if the scalar is externally used as an extra arg. auto ExtI = ExternallyUsedValues.find(Scalar); if (ExtI != ExternallyUsedValues.end()) { - DEBUG(dbgs() << "SLP: Need to extract: Extra arg from lane " << - Lane << " from " << *Scalar << ".\n"); - ExternalUses.emplace_back(Scalar, nullptr, Lane); - continue; + LLVM_DEBUG(dbgs() << "SLP: Need to extract: Extra arg from lane " + << Lane << " from " << *Scalar << ".\n"); + ExternalUses.emplace_back(Scalar, nullptr, FoundLane); } for (User *U : Scalar->users()) { - DEBUG(dbgs() << "SLP: Checking user:" << *U << ".\n"); + LLVM_DEBUG(dbgs() << "SLP: Checking user:" << *U << ".\n"); Instruction *UserInst = dyn_cast<Instruction>(U); if (!UserInst) @@ -1364,8 +1373,8 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, // be used. 
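NumOpsWantToKeepOrder and NumOpsWantToKeepOriginalOrder implement a vote: each vectorizable bundle bumps the counter of the operand order it wants, identity votes are pooled separately, and bestOrder() (shown in an earlier hunk) reorders only if some explicit order strictly out-votes the original one. The tally, minus the DenseMap plumbing:

#include <map>
#include <optional>
#include <vector>

using Order = std::vector<unsigned>;

static std::optional<Order>
bestOrder(const std::map<Order, unsigned> &Votes, unsigned OriginalVotes) {
  const Order *Best = nullptr;
  unsigned BestVotes = 0;
  for (const auto &KV : Votes)
    if (KV.second > BestVotes) {
      Best = &KV.first;
      BestVotes = KV.second;
    }
  if (!Best || BestVotes <= OriginalVotes)
    return std::nullopt; // keeping the original order wins the vote
  return *Best;
}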
if (UseScalar != U || !InTreeUserNeedToExtract(Scalar, UserInst, TLI)) { - DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U - << ".\n"); + LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U + << ".\n"); assert(!UseEntry->NeedToGather && "Bad state"); continue; } @@ -1375,9 +1384,9 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, if (is_contained(UserIgnoreList, UserInst)) continue; - DEBUG(dbgs() << "SLP: Need to extract:" << *U << " from lane " << - Lane << " from " << *Scalar << ".\n"); - ExternalUses.push_back(ExternalUser(Scalar, U, Lane)); + LLVM_DEBUG(dbgs() << "SLP: Need to extract:" << *U << " from lane " + << Lane << " from " << *Scalar << ".\n"); + ExternalUses.push_back(ExternalUser(Scalar, U, FoundLane)); } } } @@ -1389,28 +1398,28 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, InstructionsState S = getSameOpcode(VL); if (Depth == RecursionMaxDepth) { - DEBUG(dbgs() << "SLP: Gathering due to max recursion depth.\n"); + LLVM_DEBUG(dbgs() << "SLP: Gathering due to max recursion depth.\n"); newTreeEntry(VL, false, UserTreeIdx); return; } // Don't handle vectors. if (S.OpValue->getType()->isVectorTy()) { - DEBUG(dbgs() << "SLP: Gathering due to vector type.\n"); + LLVM_DEBUG(dbgs() << "SLP: Gathering due to vector type.\n"); newTreeEntry(VL, false, UserTreeIdx); return; } if (StoreInst *SI = dyn_cast<StoreInst>(S.OpValue)) if (SI->getValueOperand()->getType()->isVectorTy()) { - DEBUG(dbgs() << "SLP: Gathering due to store vector type.\n"); + LLVM_DEBUG(dbgs() << "SLP: Gathering due to store vector type.\n"); newTreeEntry(VL, false, UserTreeIdx); return; } // If all of the operands are identical or constant we have a simple solution. - if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !S.Opcode) { - DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O. \n"); + if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !S.getOpcode()) { + LLVM_DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O. \n"); newTreeEntry(VL, false, UserTreeIdx); return; } @@ -1421,8 +1430,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Don't vectorize ephemeral values. for (unsigned i = 0, e = VL.size(); i != e; ++i) { if (EphValues.count(VL[i])) { - DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] << - ") is ephemeral.\n"); + LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] + << ") is ephemeral.\n"); newTreeEntry(VL, false, UserTreeIdx); return; } @@ -1430,18 +1439,17 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Check if this is a duplicate of another entry. if (TreeEntry *E = getTreeEntry(S.OpValue)) { - for (unsigned i = 0, e = VL.size(); i != e; ++i) { - DEBUG(dbgs() << "SLP: \tChecking bundle: " << *VL[i] << ".\n"); - if (E->Scalars[i] != VL[i]) { - DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n"); - newTreeEntry(VL, false, UserTreeIdx); - return; - } + LLVM_DEBUG(dbgs() << "SLP: \tChecking bundle: " << *S.OpValue << ".\n"); + if (!E->isSame(VL)) { + LLVM_DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n"); + newTreeEntry(VL, false, UserTreeIdx); + return; } // Record the reuse of the tree node. FIXME, currently this is only used to // properly draw the graph rather than for the actual vectorization. 
E->UserTreeIndices.push_back(UserTreeIdx); - DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *S.OpValue << ".\n"); + LLVM_DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *S.OpValue + << ".\n"); return; } @@ -1451,8 +1459,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (!I) continue; if (getTreeEntry(I)) { - DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] << - ") is already in tree.\n"); + LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] + << ") is already in tree.\n"); newTreeEntry(VL, false, UserTreeIdx); return; } @@ -1462,7 +1470,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // we need to gather the scalars. for (unsigned i = 0, e = VL.size(); i != e; ++i) { if (MustGather.count(VL[i])) { - DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); + LLVM_DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); newTreeEntry(VL, false, UserTreeIdx); return; } @@ -1476,19 +1484,32 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (!DT->isReachableFromEntry(BB)) { // Don't go into unreachable blocks. They may contain instructions with // dependency cycles which confuse the final scheduling. - DEBUG(dbgs() << "SLP: bundle in unreachable block.\n"); + LLVM_DEBUG(dbgs() << "SLP: bundle in unreachable block.\n"); newTreeEntry(VL, false, UserTreeIdx); return; } // Check that every instruction appears once in this bundle. - for (unsigned i = 0, e = VL.size(); i < e; ++i) - for (unsigned j = i + 1; j < e; ++j) - if (VL[i] == VL[j]) { - DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n"); - newTreeEntry(VL, false, UserTreeIdx); - return; - } + SmallVector<unsigned, 4> ReuseShuffleIndicies; + SmallVector<Value *, 4> UniqueValues; + DenseMap<Value *, unsigned> UniquePositions; + for (Value *V : VL) { + auto Res = UniquePositions.try_emplace(V, UniqueValues.size()); + ReuseShuffleIndicies.emplace_back(Res.first->second); + if (Res.second) + UniqueValues.emplace_back(V); + } + if (UniqueValues.size() == VL.size()) { + ReuseShuffleIndicies.clear(); + } else { + LLVM_DEBUG(dbgs() << "SLP: Shuffle for reused scalars.\n"); + if (UniqueValues.size() <= 1 || !llvm::isPowerOf2_32(UniqueValues.size())) { + LLVM_DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n"); + newTreeEntry(VL, false, UserTreeIdx); + return; + } + VL = UniqueValues; + } auto &BSRef = BlocksSchedules[BB]; if (!BSRef) @@ -1496,18 +1517,18 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, BlockScheduling &BS = *BSRef.get(); - if (!BS.tryScheduleBundle(VL, this, S.OpValue)) { - DEBUG(dbgs() << "SLP: We are not able to schedule this bundle!\n"); + if (!BS.tryScheduleBundle(VL, this, S)) { + LLVM_DEBUG(dbgs() << "SLP: We are not able to schedule this bundle!\n"); assert((!BS.getScheduleData(VL0) || !BS.getScheduleData(VL0)->isPartOfBundle()) && "tryScheduleBundle should cancelScheduling on failure"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); return; } - DEBUG(dbgs() << "SLP: We are able to schedule this bundle.\n"); + LLVM_DEBUG(dbgs() << "SLP: We are able to schedule this bundle.\n"); - unsigned ShuffleOrOp = S.IsAltShuffle ? - (unsigned) Instruction::ShuffleVector : S.Opcode; + unsigned ShuffleOrOp = S.isAltShuffle() ? 
+ (unsigned) Instruction::ShuffleVector : S.getOpcode(); switch (ShuffleOrOp) { case Instruction::PHI: { PHINode *PH = dyn_cast<PHINode>(VL0); @@ -1518,15 +1539,17 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, TerminatorInst *Term = dyn_cast<TerminatorInst>( cast<PHINode>(VL[j])->getIncomingValueForBlock(PH->getIncomingBlock(i))); if (Term) { - DEBUG(dbgs() << "SLP: Need to swizzle PHINodes (TerminatorInst use).\n"); + LLVM_DEBUG( + dbgs() + << "SLP: Need to swizzle PHINodes (TerminatorInst use).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); return; } } - newTreeEntry(VL, true, UserTreeIdx); - DEBUG(dbgs() << "SLP: added a vector of PHINodes.\n"); + newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: added a vector of PHINodes.\n"); for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) { ValueList Operands; @@ -1541,13 +1564,35 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } case Instruction::ExtractValue: case Instruction::ExtractElement: { - bool Reuse = canReuseExtract(VL, VL0); + OrdersType CurrentOrder; + bool Reuse = canReuseExtract(VL, VL0, CurrentOrder); if (Reuse) { - DEBUG(dbgs() << "SLP: Reusing extract sequence.\n"); - } else { - BS.cancelScheduling(VL, VL0); + LLVM_DEBUG(dbgs() << "SLP: Reusing or shuffling extract sequence.\n"); + ++NumOpsWantToKeepOriginalOrder; + newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, + ReuseShuffleIndicies); + return; } - newTreeEntry(VL, Reuse, UserTreeIdx); + if (!CurrentOrder.empty()) { + LLVM_DEBUG({ + dbgs() << "SLP: Reusing or shuffling of reordered extract sequence " + "with order"; + for (unsigned Idx : CurrentOrder) + dbgs() << " " << Idx; + dbgs() << "\n"; + }); + // Insert new order with initial value 0, if it does not exist, + // otherwise return the iterator to the existing one. + auto StoredCurrentOrderAndNum = + NumOpsWantToKeepOrder.try_emplace(CurrentOrder).first; + ++StoredCurrentOrderAndNum->getSecond(); + newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, ReuseShuffleIndicies, + StoredCurrentOrderAndNum->getFirst()); + return; + } + LLVM_DEBUG(dbgs() << "SLP: Gather extract sequence.\n"); + newTreeEntry(VL, /*Vectorized=*/false, UserTreeIdx, ReuseShuffleIndicies); + BS.cancelScheduling(VL, VL0); return; } case Instruction::Load: { @@ -1562,62 +1607,67 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (DL->getTypeSizeInBits(ScalarTy) != DL->getTypeAllocSizeInBits(ScalarTy)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); - DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n"); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n"); return; } // Make sure all loads in the bundle are simple - we can't vectorize // atomic or volatile loads. 
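Rather than rejecting bundles with repeated scalars outright, the dedup code above keeps each scalar once in UniqueValues and records in ReuseShuffleIndicies, for every original lane, which unique lane it must be broadcast from; the bundle is still gathered when the deduplicated size is 1 or not a power of two. The bookkeeping reduces to:

#include <map>
#include <vector>

struct DedupResult {
  std::vector<int> Unique;          // each scalar once, in first-seen order
  std::vector<unsigned> ReuseIndex; // original lane -> lane in Unique
};

static DedupResult dedupBundle(const std::vector<int> &VL) {
  DedupResult D;
  std::map<int, unsigned> Pos;
  for (int V : VL) {
    auto Res = Pos.try_emplace(V, static_cast<unsigned>(D.Unique.size()));
    if (Res.second)
      D.Unique.push_back(V); // first time we see this scalar
    D.ReuseIndex.push_back(Res.first->second);
  }
  return D;
}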
- for (unsigned i = 0, e = VL.size() - 1; i < e; ++i) { - LoadInst *L = cast<LoadInst>(VL[i]); + SmallVector<Value *, 4> PointerOps(VL.size()); + auto POIter = PointerOps.begin(); + for (Value *V : VL) { + auto *L = cast<LoadInst>(V); if (!L->isSimple()) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); - DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n"); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n"); return; } + *POIter = L->getPointerOperand(); + ++POIter; } - // Check if the loads are consecutive, reversed, or neither. - // TODO: What we really want is to sort the loads, but for now, check - // the two likely directions. - bool Consecutive = true; - bool ReverseConsecutive = true; - for (unsigned i = 0, e = VL.size() - 1; i < e; ++i) { - if (!isConsecutiveAccess(VL[i], VL[i + 1], *DL, *SE)) { - Consecutive = false; - break; + OrdersType CurrentOrder; + // Check the order of pointer operands. + if (llvm::sortPtrAccesses(PointerOps, *DL, *SE, CurrentOrder)) { + Value *Ptr0; + Value *PtrN; + if (CurrentOrder.empty()) { + Ptr0 = PointerOps.front(); + PtrN = PointerOps.back(); } else { - ReverseConsecutive = false; + Ptr0 = PointerOps[CurrentOrder.front()]; + PtrN = PointerOps[CurrentOrder.back()]; } - } - - if (Consecutive) { - ++NumLoadsWantToKeepOrder; - newTreeEntry(VL, true, UserTreeIdx); - DEBUG(dbgs() << "SLP: added a vector of loads.\n"); - return; - } - - // If none of the load pairs were consecutive when checked in order, - // check the reverse order. - if (ReverseConsecutive) - for (unsigned i = VL.size() - 1; i > 0; --i) - if (!isConsecutiveAccess(VL[i], VL[i - 1], *DL, *SE)) { - ReverseConsecutive = false; - break; + const SCEV *Scev0 = SE->getSCEV(Ptr0); + const SCEV *ScevN = SE->getSCEV(PtrN); + const auto *Diff = + dyn_cast<SCEVConstant>(SE->getMinusSCEV(ScevN, Scev0)); + uint64_t Size = DL->getTypeAllocSize(ScalarTy); + // Check that the sorted loads are consecutive. + if (Diff && Diff->getAPInt().getZExtValue() == (VL.size() - 1) * Size) { + if (CurrentOrder.empty()) { + // Original loads are consecutive and does not require reordering. + ++NumOpsWantToKeepOriginalOrder; + newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, + ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: added a vector of loads.\n"); + } else { + // Need to reorder. 
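The load case above now sorts the pointer operands (llvm::sortPtrAccesses) and then checks one SCEV distance: the bundle is consecutive iff the last sorted pointer lies exactly (NumLoads - 1) * ElementSize bytes past the first, the sorted offsets being distinct by construction. A stricter, self-contained model that verifies every adjacent pair instead of relying on that precondition:

#include <algorithm>
#include <cstdint>
#include <vector>

static bool loadsAreConsecutive(std::vector<uint64_t> Addrs, uint64_t Size) {
  std::sort(Addrs.begin(), Addrs.end());
  for (size_t I = 1; I < Addrs.size(); ++I)
    if (Addrs[I] - Addrs[I - 1] != Size)
      return false; // a gap or a duplicate breaks the run
  return true;
}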
+ auto I = NumOpsWantToKeepOrder.try_emplace(CurrentOrder).first; + ++I->getSecond(); + newTreeEntry(VL, /*Vectorized=*/true, UserTreeIdx, + ReuseShuffleIndicies, I->getFirst()); + LLVM_DEBUG(dbgs() << "SLP: added a vector of jumbled loads.\n"); } + return; + } + } + LLVM_DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); - - if (ReverseConsecutive) { - ++NumLoadsWantToChangeOrder; - DEBUG(dbgs() << "SLP: Gathering reversed loads.\n"); - } else { - DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n"); - } + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); return; } case Instruction::ZExt: @@ -1637,13 +1687,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Type *Ty = cast<Instruction>(VL[i])->getOperand(0)->getType(); if (Ty != SrcTy || !isValidElementType(Ty)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); - DEBUG(dbgs() << "SLP: Gathering casts with different src types.\n"); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() + << "SLP: Gathering casts with different src types.\n"); return; } } - newTreeEntry(VL, true, UserTreeIdx); - DEBUG(dbgs() << "SLP: added a vector of casts.\n"); + newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: added a vector of casts.\n"); for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { ValueList Operands; @@ -1665,14 +1716,15 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (Cmp->getPredicate() != P0 || Cmp->getOperand(0)->getType() != ComparedTy) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); - DEBUG(dbgs() << "SLP: Gathering cmp with different predicate.\n"); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() + << "SLP: Gathering cmp with different predicate.\n"); return; } } - newTreeEntry(VL, true, UserTreeIdx); - DEBUG(dbgs() << "SLP: added a vector of compares.\n"); + newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: added a vector of compares.\n"); for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { ValueList Operands; @@ -1703,14 +1755,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, case Instruction::And: case Instruction::Or: case Instruction::Xor: - newTreeEntry(VL, true, UserTreeIdx); - DEBUG(dbgs() << "SLP: added a vector of bin op.\n"); + newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: added a vector of bin op.\n"); // Sort operands of the instructions so that each side is more likely to // have the same opcode. if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) { ValueList Left, Right; - reorderInputsAccordingToOpcode(S.Opcode, VL, Left, Right); + reorderInputsAccordingToOpcode(S.getOpcode(), VL, Left, Right); buildTree_rec(Left, Depth + 1, UserTreeIdx); buildTree_rec(Right, Depth + 1, UserTreeIdx); return; @@ -1730,9 +1782,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // We don't combine GEPs with complicated (nested) indexing. 
for (unsigned j = 0; j < VL.size(); ++j) { if (cast<Instruction>(VL[j])->getNumOperands() != 2) { - DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n"); + LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); return; } } @@ -1743,9 +1795,10 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (unsigned j = 0; j < VL.size(); ++j) { Type *CurTy = cast<Instruction>(VL[j])->getOperand(0)->getType(); if (Ty0 != CurTy) { - DEBUG(dbgs() << "SLP: not-vectorizable GEP (different types).\n"); + LLVM_DEBUG(dbgs() + << "SLP: not-vectorizable GEP (different types).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); return; } } @@ -1754,16 +1807,16 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (unsigned j = 0; j < VL.size(); ++j) { auto Op = cast<Instruction>(VL[j])->getOperand(1); if (!isa<ConstantInt>(Op)) { - DEBUG( - dbgs() << "SLP: not-vectorizable GEP (non-constant indexes).\n"); + LLVM_DEBUG(dbgs() + << "SLP: not-vectorizable GEP (non-constant indexes).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); return; } } - newTreeEntry(VL, true, UserTreeIdx); - DEBUG(dbgs() << "SLP: added a vector of GEPs.\n"); + newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: added a vector of GEPs.\n"); for (unsigned i = 0, e = 2; i < e; ++i) { ValueList Operands; // Prepare the operand vector. @@ -1779,13 +1832,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (unsigned i = 0, e = VL.size() - 1; i < e; ++i) if (!isConsecutiveAccess(VL[i], VL[i + 1], *DL, *SE)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); - DEBUG(dbgs() << "SLP: Non-consecutive store.\n"); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: Non-consecutive store.\n"); return; } - newTreeEntry(VL, true, UserTreeIdx); - DEBUG(dbgs() << "SLP: added a vector of stores.\n"); + newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: added a vector of stores.\n"); ValueList Operands; for (Value *j : VL) @@ -1802,8 +1855,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); if (!isTriviallyVectorizable(ID)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); - DEBUG(dbgs() << "SLP: Non-vectorizable call.\n"); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: Non-vectorizable call.\n"); return; } Function *Int = CI->getCalledFunction(); @@ -1816,9 +1869,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, getVectorIntrinsicIDForCall(CI2, TLI) != ID || !CI->hasIdenticalOperandBundleSchema(*CI2)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); - DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *VL[i] - << "\n"); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *VL[i] + << "\n"); return; } // ctlz,cttz and powi are special intrinsics whose second argument @@ -1827,10 +1880,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Value *A1J = 
CI2->getArgOperand(1); if (A1I != A1J) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); - DEBUG(dbgs() << "SLP: mismatched arguments in call:" << *CI - << " argument "<< A1I<<"!=" << A1J - << "\n"); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: mismatched arguments in call:" << *CI + << " argument " << A1I << "!=" << A1J << "\n"); return; } } @@ -1840,14 +1892,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, CI->op_begin() + CI->getBundleOperandsEndIndex(), CI2->op_begin() + CI2->getBundleOperandsStartIndex())) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); - DEBUG(dbgs() << "SLP: mismatched bundle operands in calls:" << *CI << "!=" - << *VL[i] << '\n'); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: mismatched bundle operands in calls:" + << *CI << "!=" << *VL[i] << '\n'); return; } } - newTreeEntry(VL, true, UserTreeIdx); + newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) { ValueList Operands; // Prepare the operand vector. @@ -1862,19 +1914,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, case Instruction::ShuffleVector: // If this is not an alternate sequence of opcode like add-sub // then do not vectorize this instruction. - if (!S.IsAltShuffle) { + if (!S.isAltShuffle()) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); - DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n"); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n"); return; } - newTreeEntry(VL, true, UserTreeIdx); - DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n"); + newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n"); // Reorder operands if reordering would enable vectorization. if (isa<BinaryOperator>(VL0)) { ValueList Left, Right; - reorderAltShuffleOperands(S.Opcode, VL, Left, Right); + reorderAltShuffleOperands(S, VL, Left, Right); buildTree_rec(Left, Depth + 1, UserTreeIdx); buildTree_rec(Right, Depth + 1, UserTreeIdx); return; @@ -1892,8 +1944,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, default: BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); - DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n"); + newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); + LLVM_DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n"); return; } } @@ -1923,15 +1975,18 @@ unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const { return N; } -bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue) const { +bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue, + SmallVectorImpl<unsigned> &CurrentOrder) const { Instruction *E0 = cast<Instruction>(OpValue); assert(E0->getOpcode() == Instruction::ExtractElement || E0->getOpcode() == Instruction::ExtractValue); - assert(E0->getOpcode() == getSameOpcode(VL).Opcode && "Invalid opcode"); + assert(E0->getOpcode() == getSameOpcode(VL).getOpcode() && "Invalid opcode"); // Check if all of the extracts come from the same vector and from the // correct offset. Value *Vec = E0->getOperand(0); + CurrentOrder.clear(); + // We have to extract from a vector/aggregate with the same number of elements. 
unsigned NElts; if (E0->getOpcode() == Instruction::ExtractValue) { @@ -1951,15 +2006,40 @@ bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue) const { return false; // Check that all of the indices extract from the correct offset. - for (unsigned I = 0, E = VL.size(); I < E; ++I) { - Instruction *Inst = cast<Instruction>(VL[I]); - if (!matchExtractIndex(Inst, I, Inst->getOpcode())) - return false; + bool ShouldKeepOrder = true; + unsigned E = VL.size(); + // Assign to all items the initial value E + 1 so we can check if the extract + // instruction index was used already. + // Also, later we can check that all the indices are used and we have a + // consecutive access in the extract instructions, by checking that no + // element of CurrentOrder still has value E + 1. + CurrentOrder.assign(E, E + 1); + unsigned I = 0; + for (; I < E; ++I) { + auto *Inst = cast<Instruction>(VL[I]); if (Inst->getOperand(0) != Vec) - return false; + break; + Optional<unsigned> Idx = getExtractIndex(Inst); + if (!Idx) + break; + const unsigned ExtIdx = *Idx; + if (ExtIdx != I) { + if (ExtIdx >= E || CurrentOrder[ExtIdx] != E + 1) + break; + ShouldKeepOrder = false; + CurrentOrder[ExtIdx] = I; + } else { + if (CurrentOrder[I] != E + 1) + break; + CurrentOrder[I] = I; + } + } + if (I < E) { + CurrentOrder.clear(); + return false; } - return true; + return ShouldKeepOrder; } bool BoUpSLP::areAllUsersVectorized(Instruction *I) const { @@ -1985,13 +2065,22 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { VecTy = VectorType::get( IntegerType::get(F->getContext(), MinBWs[VL[0]].first), VL.size()); + unsigned ReuseShuffleNumbers = E->ReuseShuffleIndices.size(); + bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty(); + int ReuseShuffleCost = 0; + if (NeedToShuffleReuses) { + ReuseShuffleCost = + TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy); + } if (E->NeedToGather) { if (allConstant(VL)) return 0; if (isSplat(VL)) { - return TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, 0); + return ReuseShuffleCost + + TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, 0); } - if (getSameOpcode(VL).Opcode == Instruction::ExtractElement) { + if (getSameOpcode(VL).getOpcode() == Instruction::ExtractElement && + allSameType(VL) && allSameBlock(VL)) { Optional<TargetTransformInfo::ShuffleKind> ShuffleKind = isShuffle(VL); if (ShuffleKind.hasValue()) { int Cost = TTI->getShuffleCost(ShuffleKind.getValue(), VecTy); @@ -2008,37 +2097,86 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { IO->getZExtValue()); } } - return Cost; + return ReuseShuffleCost + Cost; } } - return getGatherCost(E->Scalars); + return ReuseShuffleCost + getGatherCost(VL); } InstructionsState S = getSameOpcode(VL); - assert(S.Opcode && allSameType(VL) && allSameBlock(VL) && "Invalid VL"); + assert(S.getOpcode() && allSameType(VL) && allSameBlock(VL) && "Invalid VL"); Instruction *VL0 = cast<Instruction>(S.OpValue); - unsigned ShuffleOrOp = S.IsAltShuffle ? - (unsigned) Instruction::ShuffleVector : S.Opcode; + unsigned ShuffleOrOp = S.isAltShuffle() ? 
+ (unsigned) Instruction::ShuffleVector : S.getOpcode(); switch (ShuffleOrOp) { case Instruction::PHI: return 0; case Instruction::ExtractValue: case Instruction::ExtractElement: - if (canReuseExtract(VL, S.OpValue)) { - int DeadCost = 0; + if (NeedToShuffleReuses) { + unsigned Idx = 0; + for (unsigned I : E->ReuseShuffleIndices) { + if (ShuffleOrOp == Instruction::ExtractElement) { + auto *IO = cast<ConstantInt>( + cast<ExtractElementInst>(VL[I])->getIndexOperand()); + Idx = IO->getZExtValue(); + ReuseShuffleCost -= TTI->getVectorInstrCost( + Instruction::ExtractElement, VecTy, Idx); + } else { + ReuseShuffleCost -= TTI->getVectorInstrCost( + Instruction::ExtractElement, VecTy, Idx); + ++Idx; + } + } + Idx = ReuseShuffleNumbers; + for (Value *V : VL) { + if (ShuffleOrOp == Instruction::ExtractElement) { + auto *IO = cast<ConstantInt>( + cast<ExtractElementInst>(V)->getIndexOperand()); + Idx = IO->getZExtValue(); + } else { + --Idx; + } + ReuseShuffleCost += + TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, Idx); + } + } + if (!E->NeedToGather) { + int DeadCost = ReuseShuffleCost; + if (!E->ReorderIndices.empty()) { + // TODO: Merge this shuffle with the ReuseShuffleCost. + DeadCost += TTI->getShuffleCost( + TargetTransformInfo::SK_PermuteSingleSrc, VecTy); + } for (unsigned i = 0, e = VL.size(); i < e; ++i) { Instruction *E = cast<Instruction>(VL[i]); // If all users are going to be vectorized, instruction can be // considered as dead. // Likewise, if it has only one user, it will be vectorized for sure. - if (areAllUsersVectorized(E)) + if (areAllUsersVectorized(E)) { // Take credit for instruction that will become dead. - DeadCost += + if (E->hasOneUse()) { + Instruction *Ext = E->user_back(); + if ((isa<SExtInst>(Ext) || isa<ZExtInst>(Ext)) && + all_of(Ext->users(), + [](User *U) { return isa<GetElementPtrInst>(U); })) { + // Use getExtractWithExtendCost() to calculate the cost of + // extractelement/ext pair. + DeadCost -= TTI->getExtractWithExtendCost( + Ext->getOpcode(), Ext->getType(), VecTy, i); + // Add back the cost of s|zext which is subtracted separately. + DeadCost += TTI->getCastInstrCost( + Ext->getOpcode(), Ext->getType(), E->getType(), Ext); + continue; + } + } + DeadCost -= TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, i); + } } - return -DeadCost; + return DeadCost; } - return getGatherCost(VecTy); + return ReuseShuffleCost + getGatherCost(VL); case Instruction::ZExt: case Instruction::SExt: @@ -2053,24 +2191,37 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { case Instruction::FPTrunc: case Instruction::BitCast: { Type *SrcTy = VL0->getOperand(0)->getType(); + int ScalarEltCost = + TTI->getCastInstrCost(S.getOpcode(), ScalarTy, SrcTy, VL0); + if (NeedToShuffleReuses) { + ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; + } // Calculate the cost of this instruction. - int ScalarCost = VL.size() * TTI->getCastInstrCost(VL0->getOpcode(), - VL0->getType(), SrcTy, VL0); + int ScalarCost = VL.size() * ScalarEltCost; VectorType *SrcVecTy = VectorType::get(SrcTy, VL.size()); - int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy, VL0); + int VecCost = 0; + // Check if the values are candidates to demote. + if (!MinBWs.count(VL0) || VecTy != SrcVecTy) { + VecCost = ReuseShuffleCost + + TTI->getCastInstrCost(S.getOpcode(), VecTy, SrcVecTy, VL0); + } return VecCost - ScalarCost; } case Instruction::FCmp: case Instruction::ICmp: case Instruction::Select: { // Calculate the cost of this instruction.
+ int ScalarEltCost = TTI->getCmpSelInstrCost(S.getOpcode(), ScalarTy, + Builder.getInt1Ty(), VL0); + if (NeedToShuffleReuses) { + ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; + } VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size()); - int ScalarCost = VecTy->getNumElements() * - TTI->getCmpSelInstrCost(S.Opcode, ScalarTy, Builder.getInt1Ty(), VL0); - int VecCost = TTI->getCmpSelInstrCost(S.Opcode, VecTy, MaskTy, VL0); - return VecCost - ScalarCost; + int ScalarCost = VecTy->getNumElements() * ScalarEltCost; + int VecCost = TTI->getCmpSelInstrCost(S.getOpcode(), VecTy, MaskTy, VL0); + return ReuseShuffleCost + VecCost - ScalarCost; } case Instruction::Add: case Instruction::FAdd: @@ -2099,42 +2250,43 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { TargetTransformInfo::OperandValueProperties Op1VP = TargetTransformInfo::OP_None; TargetTransformInfo::OperandValueProperties Op2VP = - TargetTransformInfo::OP_None; + TargetTransformInfo::OP_PowerOf2; // If all operands are exactly the same ConstantInt then set the // operand kind to OK_UniformConstantValue. // If instead not all operands are constants, then set the operand kind // to OK_AnyValue. If all operands are constants but not the same, // then set the operand kind to OK_NonUniformConstantValue. - ConstantInt *CInt = nullptr; - for (unsigned i = 0; i < VL.size(); ++i) { + ConstantInt *CInt0 = nullptr; + for (unsigned i = 0, e = VL.size(); i < e; ++i) { const Instruction *I = cast<Instruction>(VL[i]); - if (!isa<ConstantInt>(I->getOperand(1))) { + ConstantInt *CInt = dyn_cast<ConstantInt>(I->getOperand(1)); + if (!CInt) { Op2VK = TargetTransformInfo::OK_AnyValue; + Op2VP = TargetTransformInfo::OP_None; break; } + if (Op2VP == TargetTransformInfo::OP_PowerOf2 && + !CInt->getValue().isPowerOf2()) + Op2VP = TargetTransformInfo::OP_None; if (i == 0) { - CInt = cast<ConstantInt>(I->getOperand(1)); + CInt0 = CInt; continue; } - if (Op2VK == TargetTransformInfo::OK_UniformConstantValue && - CInt != cast<ConstantInt>(I->getOperand(1))) + if (CInt0 != CInt) Op2VK = TargetTransformInfo::OK_NonUniformConstantValue; } - // FIXME: Currently cost of model modification for division by power of - // 2 is handled for X86 and AArch64. Add support for other targets. 
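The operand-classification loop above makes a single pass: a non-constant second operand forces OK_AnyValue, differing constants demote uniform to non-uniform, and the power-of-two property starts optimistic and is cleared by the first counterexample. A standalone model under those assumptions (plain C++, hypothetical names, std::optional standing in for 'is a ConstantInt'):

#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

enum class OperandKind { AnyValue, UniformConstant, NonUniformConstant };
enum class OperandProp { None, PowerOf2 };

// Each element models the second operand of one scalar instruction: a value
// when it is a constant integer, std::nullopt otherwise.
static std::pair<OperandKind, OperandProp>
classifyOperands(const std::vector<std::optional<uint64_t>> &Ops) {
  OperandKind Kind = OperandKind::UniformConstant;
  OperandProp Prop = OperandProp::PowerOf2; // Optimistic, like OP_PowerOf2.
  std::optional<uint64_t> First;
  for (unsigned I = 0, E = static_cast<unsigned>(Ops.size()); I < E; ++I) {
    if (!Ops[I]) // One non-constant operand ends the analysis.
      return {OperandKind::AnyValue, OperandProp::None};
    uint64_t C = *Ops[I];
    if (Prop == OperandProp::PowerOf2 && (C == 0 || (C & (C - 1)) != 0))
      Prop = OperandProp::None; // First counterexample clears the property.
    if (I == 0)
      First = C;
    else if (*First != C)
      Kind = OperandKind::NonUniformConstant;
  }
  return {Kind, Prop};
}

int main() {
  assert(classifyOperands({4, 4, 4, 4}) ==
         std::make_pair(OperandKind::UniformConstant, OperandProp::PowerOf2));
  assert(classifyOperands({4, 8, 2, 16}) ==
         std::make_pair(OperandKind::NonUniformConstant, OperandProp::PowerOf2));
  assert(classifyOperands({4, 6, 4, 4}).second == OperandProp::None);
  assert(classifyOperands({4, std::nullopt}).first == OperandKind::AnyValue);
}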
- if (Op2VK == TargetTransformInfo::OK_UniformConstantValue && CInt && - CInt->getValue().isPowerOf2()) - Op2VP = TargetTransformInfo::OP_PowerOf2; SmallVector<const Value *, 4> Operands(VL0->operand_values()); - int ScalarCost = - VecTy->getNumElements() * - TTI->getArithmeticInstrCost(S.Opcode, ScalarTy, Op1VK, Op2VK, Op1VP, - Op2VP, Operands); - int VecCost = TTI->getArithmeticInstrCost(S.Opcode, VecTy, Op1VK, Op2VK, - Op1VP, Op2VP, Operands); - return VecCost - ScalarCost; + int ScalarEltCost = TTI->getArithmeticInstrCost( + S.getOpcode(), ScalarTy, Op1VK, Op2VK, Op1VP, Op2VP, Operands); + if (NeedToShuffleReuses) { + ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; + } + int ScalarCost = VecTy->getNumElements() * ScalarEltCost; + int VecCost = TTI->getArithmeticInstrCost(S.getOpcode(), VecTy, Op1VK, + Op2VK, Op1VP, Op2VP, Operands); + return ReuseShuffleCost + VecCost - ScalarCost; } case Instruction::GetElementPtr: { TargetTransformInfo::OperandValueKind Op1VK = @@ -2142,83 +2294,119 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { TargetTransformInfo::OperandValueKind Op2VK = TargetTransformInfo::OK_UniformConstantValue; - int ScalarCost = - VecTy->getNumElements() * + int ScalarEltCost = TTI->getArithmeticInstrCost(Instruction::Add, ScalarTy, Op1VK, Op2VK); + if (NeedToShuffleReuses) { + ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; + } + int ScalarCost = VecTy->getNumElements() * ScalarEltCost; int VecCost = TTI->getArithmeticInstrCost(Instruction::Add, VecTy, Op1VK, Op2VK); - - return VecCost - ScalarCost; + return ReuseShuffleCost + VecCost - ScalarCost; } case Instruction::Load: { // Cost of wide load - cost of scalar loads. - unsigned alignment = dyn_cast<LoadInst>(VL0)->getAlignment(); - int ScalarLdCost = VecTy->getNumElements() * + unsigned alignment = cast<LoadInst>(VL0)->getAlignment(); + int ScalarEltCost = TTI->getMemoryOpCost(Instruction::Load, ScalarTy, alignment, 0, VL0); - int VecLdCost = TTI->getMemoryOpCost(Instruction::Load, - VecTy, alignment, 0, VL0); - return VecLdCost - ScalarLdCost; + if (NeedToShuffleReuses) { + ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; + } + int ScalarLdCost = VecTy->getNumElements() * ScalarEltCost; + int VecLdCost = + TTI->getMemoryOpCost(Instruction::Load, VecTy, alignment, 0, VL0); + if (!E->ReorderIndices.empty()) { + // TODO: Merge this shuffle with the ReuseShuffleCost. + VecLdCost += TTI->getShuffleCost( + TargetTransformInfo::SK_PermuteSingleSrc, VecTy); + } + return ReuseShuffleCost + VecLdCost - ScalarLdCost; } case Instruction::Store: { // We know that we can merge the stores. Calculate the cost. - unsigned alignment = dyn_cast<StoreInst>(VL0)->getAlignment(); - int ScalarStCost = VecTy->getNumElements() * + unsigned alignment = cast<StoreInst>(VL0)->getAlignment(); + int ScalarEltCost = TTI->getMemoryOpCost(Instruction::Store, ScalarTy, alignment, 0, VL0); - int VecStCost = TTI->getMemoryOpCost(Instruction::Store, - VecTy, alignment, 0, VL0); - return VecStCost - ScalarStCost; + if (NeedToShuffleReuses) { + ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; + } + int ScalarStCost = VecTy->getNumElements() * ScalarEltCost; + int VecStCost = + TTI->getMemoryOpCost(Instruction::Store, VecTy, alignment, 0, VL0); + return ReuseShuffleCost + VecStCost - ScalarStCost; } case Instruction::Call: { CallInst *CI = cast<CallInst>(VL0); Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); // Calculate the cost of the scalar and vector calls. 
- SmallVector<Type*, 4> ScalarTys; - for (unsigned op = 0, opc = CI->getNumArgOperands(); op!= opc; ++op) + SmallVector<Type *, 4> ScalarTys; + for (unsigned op = 0, opc = CI->getNumArgOperands(); op != opc; ++op) ScalarTys.push_back(CI->getArgOperand(op)->getType()); FastMathFlags FMF; if (auto *FPMO = dyn_cast<FPMathOperator>(CI)) FMF = FPMO->getFastMathFlags(); - int ScalarCallCost = VecTy->getNumElements() * + int ScalarEltCost = TTI->getIntrinsicInstrCost(ID, ScalarTy, ScalarTys, FMF); + if (NeedToShuffleReuses) { + ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * ScalarEltCost; + } + int ScalarCallCost = VecTy->getNumElements() * ScalarEltCost; SmallVector<Value *, 4> Args(CI->arg_operands()); int VecCallCost = TTI->getIntrinsicInstrCost(ID, CI->getType(), Args, FMF, VecTy->getNumElements()); - DEBUG(dbgs() << "SLP: Call cost "<< VecCallCost - ScalarCallCost - << " (" << VecCallCost << "-" << ScalarCallCost << ")" - << " for " << *CI << "\n"); + LLVM_DEBUG(dbgs() << "SLP: Call cost " << VecCallCost - ScalarCallCost + << " (" << VecCallCost << "-" << ScalarCallCost << ")" + << " for " << *CI << "\n"); - return VecCallCost - ScalarCallCost; + return ReuseShuffleCost + VecCallCost - ScalarCallCost; } case Instruction::ShuffleVector: { - TargetTransformInfo::OperandValueKind Op1VK = - TargetTransformInfo::OK_AnyValue; - TargetTransformInfo::OperandValueKind Op2VK = - TargetTransformInfo::OK_AnyValue; + assert(S.isAltShuffle() && + ((Instruction::isBinaryOp(S.getOpcode()) && + Instruction::isBinaryOp(S.getAltOpcode())) || + (Instruction::isCast(S.getOpcode()) && + Instruction::isCast(S.getAltOpcode()))) && + "Invalid Shuffle Vector Operand"); int ScalarCost = 0; - int VecCost = 0; + if (NeedToShuffleReuses) { + for (unsigned Idx : E->ReuseShuffleIndices) { + Instruction *I = cast<Instruction>(VL[Idx]); + ReuseShuffleCost -= TTI->getInstructionCost( + I, TargetTransformInfo::TCK_RecipThroughput); + } + for (Value *V : VL) { + Instruction *I = cast<Instruction>(V); + ReuseShuffleCost += TTI->getInstructionCost( + I, TargetTransformInfo::TCK_RecipThroughput); + } + } for (Value *i : VL) { Instruction *I = cast<Instruction>(i); - if (!I) - break; - ScalarCost += - TTI->getArithmeticInstrCost(I->getOpcode(), ScalarTy, Op1VK, Op2VK); + assert(S.isOpcodeOrAlt(I) && "Unexpected main/alternate opcode"); + ScalarCost += TTI->getInstructionCost( + I, TargetTransformInfo::TCK_RecipThroughput); } // VecCost is equal to sum of the cost of creating 2 vectors // and the cost of creating shuffle. 
- Instruction *I0 = cast<Instruction>(VL[0]); - VecCost = - TTI->getArithmeticInstrCost(I0->getOpcode(), VecTy, Op1VK, Op2VK); - Instruction *I1 = cast<Instruction>(VL[1]); - VecCost += - TTI->getArithmeticInstrCost(I1->getOpcode(), VecTy, Op1VK, Op2VK); - VecCost += - TTI->getShuffleCost(TargetTransformInfo::SK_Alternate, VecTy, 0); - return VecCost - ScalarCost; + int VecCost = 0; + if (Instruction::isBinaryOp(S.getOpcode())) { + VecCost = TTI->getArithmeticInstrCost(S.getOpcode(), VecTy); + VecCost += TTI->getArithmeticInstrCost(S.getAltOpcode(), VecTy); + } else { + Type *Src0SclTy = S.MainOp->getOperand(0)->getType(); + Type *Src1SclTy = S.AltOp->getOperand(0)->getType(); + VectorType *Src0Ty = VectorType::get(Src0SclTy, VL.size()); + VectorType *Src1Ty = VectorType::get(Src1SclTy, VL.size()); + VecCost = TTI->getCastInstrCost(S.getOpcode(), VecTy, Src0Ty); + VecCost += TTI->getCastInstrCost(S.getAltOpcode(), VecTy, Src1Ty); + } + VecCost += TTI->getShuffleCost(TargetTransformInfo::SK_Select, VecTy, 0); + return ReuseShuffleCost + VecCost - ScalarCost; } default: llvm_unreachable("Unknown instruction"); @@ -2226,8 +2414,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { } bool BoUpSLP::isFullyVectorizableTinyTree() { - DEBUG(dbgs() << "SLP: Check whether the tree with height " << - VectorizableTree.size() << " is fully vectorizable .\n"); + LLVM_DEBUG(dbgs() << "SLP: Check whether the tree with height " + << VectorizableTree.size() << " is fully vectorizable .\n"); // We only handle trees of heights 1 and 2. if (VectorizableTree.size() == 1 && !VectorizableTree[0].NeedToGather) @@ -2297,13 +2485,13 @@ int BoUpSLP::getSpillCost() { LiveValues.insert(cast<Instruction>(&*J)); } - DEBUG( + LLVM_DEBUG({ dbgs() << "SLP: #LV: " << LiveValues.size(); for (auto *X : LiveValues) dbgs() << " " << X->getName(); dbgs() << ", Looking at "; Inst->dump(); - ); + }); // Now find the sequence of instructions between PrevInst and Inst. BasicBlock::reverse_iterator InstIt = ++Inst->getIterator().getReverse(), @@ -2315,7 +2503,10 @@ int BoUpSLP::getSpillCost() { continue; } - if (isa<CallInst>(&*PrevInstIt) && &*PrevInstIt != PrevInst) { + // Debug information doesn't impact spill cost. + if ((isa<CallInst>(&*PrevInstIt) && + !isa<DbgInfoIntrinsic>(&*PrevInstIt)) && + &*PrevInstIt != PrevInst) { SmallVector<Type*, 4> V; for (auto *II : LiveValues) V.push_back(VectorType::get(II->getType(), BundleWidth)); @@ -2333,19 +2524,41 @@ int BoUpSLP::getTreeCost() { int Cost = 0; - DEBUG(dbgs() << "SLP: Calculating cost for tree of size " << - VectorizableTree.size() << ".\n"); + LLVM_DEBUG(dbgs() << "SLP: Calculating cost for tree of size " + << VectorizableTree.size() << ".\n"); unsigned BundleWidth = VectorizableTree[0].Scalars.size(); - for (TreeEntry &TE : VectorizableTree) { + for (unsigned I = 0, E = VectorizableTree.size(); I < E; ++I) { + TreeEntry &TE = VectorizableTree[I]; + + // We create duplicate tree entries for gather sequences that have multiple + // uses. However, we should not compute the cost of duplicate sequences. + // For example, if we have a build vector (i.e., insertelement sequence) + // that is used by more than one vector instruction, we only need to + // compute the cost of the insertelement instructions once. The redundant + // instructions will be eliminated by CSE. + // + // We should consider not creating duplicate tree entries for gather + // sequences, and instead add additional edges to the tree representing + // their uses.
Since such an approach results in fewer total entries, + existing heuristics based on tree size may yield different results. + // + if (TE.NeedToGather && + std::any_of(std::next(VectorizableTree.begin(), I + 1), + VectorizableTree.end(), [TE](TreeEntry &Entry) { + return Entry.NeedToGather && Entry.isSame(TE.Scalars); + })) + continue; + int C = getEntryCost(&TE); - DEBUG(dbgs() << "SLP: Adding cost " << C << " for bundle that starts with " - << *TE.Scalars[0] << ".\n"); + LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C + << " for bundle that starts with " << *TE.Scalars[0] + << ".\n"); Cost += C; } - SmallSet<Value *, 16> ExtractCostCalculated; + SmallPtrSet<Value *, 16> ExtractCostCalculated; int ExtractCost = 0; for (ExternalUser &EU : ExternalUses) { // We only add extract cost once for the same scalar. @@ -2386,7 +2599,7 @@ int BoUpSLP::getTreeCost() { << "SLP: Extract Cost = " << ExtractCost << ".\n" << "SLP: Total Cost = " << Cost << ".\n"; } - DEBUG(dbgs() << Str); + LLVM_DEBUG(dbgs() << Str); if (ViewSLPTree) ViewGraph(this, "SLP" + F->getName(), false, Str); @@ -2394,10 +2607,14 @@ int BoUpSLP::getTreeCost() { return Cost; } -int BoUpSLP::getGatherCost(Type *Ty) { +int BoUpSLP::getGatherCost(Type *Ty, + const DenseSet<unsigned> &ShuffledIndices) { int Cost = 0; for (unsigned i = 0, e = cast<VectorType>(Ty)->getNumElements(); i < e; ++i) - Cost += TTI->getVectorInstrCost(Instruction::InsertElement, Ty, i); + if (!ShuffledIndices.count(i)) + Cost += TTI->getVectorInstrCost(Instruction::InsertElement, Ty, i); + if (!ShuffledIndices.empty()) + Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, Ty); return Cost; } @@ -2408,7 +2625,17 @@ int BoUpSLP::getGatherCost(ArrayRef<Value *> VL) { ScalarTy = SI->getValueOperand()->getType(); VectorType *VecTy = VectorType::get(ScalarTy, VL.size()); // Find the cost of inserting/extracting values from the vector. - return getGatherCost(VecTy); + // Check if the same elements are inserted several times and count them as + // shuffle candidates. + DenseSet<unsigned> ShuffledElements; + DenseSet<Value *> UniqueElements; + // Iterate in reverse order to consider insert elements with the high cost. + for (unsigned I = VL.size(); I > 0; --I) { + unsigned Idx = I - 1; + if (!UniqueElements.insert(VL[Idx]).second) + ShuffledElements.insert(Idx); + } + return getGatherCost(VecTy, ShuffledElements); } // Reorder commutative operations in alternate shuffle if the resulting vectors @@ -2420,16 +2647,14 @@ int BoUpSLP::getGatherCost(ArrayRef<Value *> VL) { // load a[3] + load b[3] // Reordering the second load b[1] load a[1] would allow us to vectorize this // code.
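The duplicate handling in getGatherCost above admits a small standalone model: walk the scalars in reverse so the last occurrence of each value is the one charged as an insertelement, and bill all earlier repeats to a single final shuffle. A sketch with hypothetical names and abstract unit costs:

#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>
#include <vector>

static int gatherCostModel(const std::vector<std::string> &Scalars,
                           int InsertCost, int ShuffleCost) {
  std::unordered_set<std::string> Unique;
  std::unordered_set<std::size_t> Shuffled;
  // Reverse walk keeps the last (highest-cost) occurrence as the insert.
  for (std::size_t I = Scalars.size(); I > 0; --I)
    if (!Unique.insert(Scalars[I - 1]).second)
      Shuffled.insert(I - 1);
  int Cost = 0;
  for (std::size_t I = 0; I < Scalars.size(); ++I)
    if (!Shuffled.count(I))
      Cost += InsertCost; // One insertelement per unique lane.
  if (!Shuffled.empty())
    Cost += ShuffleCost; // All duplicated lanes share one permute.
  return Cost;
}

int main() {
  // {a, b, a, b}: two unique inserts plus one shuffle.
  assert(gatherCostModel({"a", "b", "a", "b"}, 1, 3) == 2 + 3);
  // No duplicates: four inserts, no shuffle.
  assert(gatherCostModel({"a", "b", "c", "d"}, 1, 3) == 4);
}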
-void BoUpSLP::reorderAltShuffleOperands(unsigned Opcode, ArrayRef<Value *> VL, +void BoUpSLP::reorderAltShuffleOperands(const InstructionsState &S, + ArrayRef<Value *> VL, SmallVectorImpl<Value *> &Left, SmallVectorImpl<Value *> &Right) { // Push left and right operands of binary operation into Left and Right - unsigned AltOpcode = getAltOpcode(Opcode); - (void)AltOpcode; for (Value *V : VL) { auto *I = cast<Instruction>(V); - assert(sameOpcodeOrAlt(Opcode, AltOpcode, I->getOpcode()) && - "Incorrect instruction in vector"); + assert(S.isOpcodeOrAlt(I) && "Incorrect instruction in vector"); Left.push_back(I->getOperand(0)); Right.push_back(I->getOperand(1)); } @@ -2609,7 +2834,7 @@ void BoUpSLP::reorderInputsAccordingToOpcode(unsigned Opcode, // add a[1],c[2] load b[1] // b[2] load b[2] // add a[3],c[3] load b[3] - for (unsigned j = 0; j < VL.size() - 1; ++j) { + for (unsigned j = 0, e = VL.size() - 1; j < e; ++j) { if (LoadInst *L = dyn_cast<LoadInst>(Left[j])) { if (LoadInst *L1 = dyn_cast<LoadInst>(Right[j + 1])) { if (isConsecutiveAccess(L, L1, *DL, *SE)) { @@ -2630,17 +2855,15 @@ void BoUpSLP::reorderInputsAccordingToOpcode(unsigned Opcode, } } -void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL, Value *OpValue) { +void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL, + const InstructionsState &S) { // Get the basic block this bundle is in. All instructions in the bundle // should be in this block. - auto *Front = cast<Instruction>(OpValue); + auto *Front = cast<Instruction>(S.OpValue); auto *BB = Front->getParent(); - const unsigned Opcode = cast<Instruction>(OpValue)->getOpcode(); - const unsigned AltOpcode = getAltOpcode(Opcode); assert(llvm::all_of(make_range(VL.begin(), VL.end()), [=](Value *V) -> bool { - return !sameOpcodeOrAlt(Opcode, AltOpcode, - cast<Instruction>(V)->getOpcode()) || - cast<Instruction>(V)->getParent() == BB; + auto *I = cast<Instruction>(V); + return !S.isOpcodeOrAlt(I) || I->getParent() == BB; })); // The last instruction in the bundle in program order. @@ -2652,7 +2875,7 @@ void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL, Value *OpValue) { // bundle. The end of the bundle is marked by null ScheduleData. if (BlocksSchedules.count(BB)) { auto *Bundle = - BlocksSchedules[BB]->getScheduleData(isOneOf(OpValue, VL.back())); + BlocksSchedules[BB]->getScheduleData(isOneOf(S, VL.back())); if (Bundle && Bundle->isPartOfBundle()) for (; Bundle; Bundle = Bundle->NextInBundle) if (Bundle->OpValue == Bundle->Inst) @@ -2680,7 +2903,7 @@ void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL, Value *OpValue) { if (!LastInst) { SmallPtrSet<Value *, 16> Bundle(VL.begin(), VL.end()); for (auto &I : make_range(BasicBlock::iterator(Front), BB->end())) { - if (Bundle.erase(&I) && sameOpcodeOrAlt(Opcode, AltOpcode, I.getOpcode())) + if (Bundle.erase(&I) && S.isOpcodeOrAlt(&I)) LastInst = &I; if (Bundle.empty()) break; @@ -2706,7 +2929,7 @@ Value *BoUpSLP::Gather(ArrayRef<Value *> VL, VectorType *Ty) { if (TreeEntry *E = getTreeEntry(VL[i])) { // Find which lane we need to extract. int FoundLane = -1; - for (unsigned Lane = 0, LE = VL.size(); Lane != LE; ++Lane) { + for (unsigned Lane = 0, LE = E->Scalars.size(); Lane != LE; ++Lane) { // Is this the lane of the scalar that we are looking for ? 
if (E->Scalars[Lane] == VL[i]) { FoundLane = Lane; @@ -2714,6 +2937,11 @@ Value *BoUpSLP::Gather(ArrayRef<Value *> VL, VectorType *Ty) { } } assert(FoundLane >= 0 && "Could not find the correct lane"); + if (!E->ReuseShuffleIndices.empty()) { + FoundLane = + std::distance(E->ReuseShuffleIndices.begin(), + llvm::find(E->ReuseShuffleIndices, FoundLane)); + } ExternalUses.push_back(ExternalUser(VL[i], Insrt, FoundLane)); } } @@ -2722,66 +2950,128 @@ Value *BoUpSLP::Gather(ArrayRef<Value *> VL, VectorType *Ty) { return Vec; } -Value *BoUpSLP::alreadyVectorized(ArrayRef<Value *> VL, Value *OpValue) const { - if (const TreeEntry *En = getTreeEntry(OpValue)) { - if (En->isSame(VL) && En->VectorizedValue) - return En->VectorizedValue; - } - return nullptr; -} - Value *BoUpSLP::vectorizeTree(ArrayRef<Value *> VL) { InstructionsState S = getSameOpcode(VL); - if (S.Opcode) { + if (S.getOpcode()) { if (TreeEntry *E = getTreeEntry(S.OpValue)) { - if (E->isSame(VL)) - return vectorizeTree(E); + if (E->isSame(VL)) { + Value *V = vectorizeTree(E); + if (VL.size() == E->Scalars.size() && !E->ReuseShuffleIndices.empty()) { + // We need to get the vectorized value but without shuffle. + if (auto *SV = dyn_cast<ShuffleVectorInst>(V)) { + V = SV->getOperand(0); + } else { + // Reshuffle to get only unique values. + SmallVector<unsigned, 4> UniqueIdxs; + SmallSet<unsigned, 4> UsedIdxs; + for(unsigned Idx : E->ReuseShuffleIndices) + if (UsedIdxs.insert(Idx).second) + UniqueIdxs.emplace_back(Idx); + V = Builder.CreateShuffleVector(V, UndefValue::get(V->getType()), + UniqueIdxs); + } + } + return V; + } } } Type *ScalarTy = S.OpValue->getType(); if (StoreInst *SI = dyn_cast<StoreInst>(S.OpValue)) ScalarTy = SI->getValueOperand()->getType(); + + // Check that every instruction appears once in this bundle. + SmallVector<unsigned, 4> ReuseShuffleIndicies; + SmallVector<Value *, 4> UniqueValues; + if (VL.size() > 2) { + DenseMap<Value *, unsigned> UniquePositions; + for (Value *V : VL) { + auto Res = UniquePositions.try_emplace(V, UniqueValues.size()); + ReuseShuffleIndicies.emplace_back(Res.first->second); + if (Res.second || isa<Constant>(V)) + UniqueValues.emplace_back(V); + } + // Do not shuffle single element or if number of unique values is not power + // of 2. 
+ if (UniqueValues.size() == VL.size() || UniqueValues.size() <= 1 || + !llvm::isPowerOf2_32(UniqueValues.size())) + ReuseShuffleIndicies.clear(); + else + VL = UniqueValues; + } VectorType *VecTy = VectorType::get(ScalarTy, VL.size()); - return Gather(VL, VecTy); + Value *V = Gather(VL, VecTy); + if (!ReuseShuffleIndicies.empty()) { + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + ReuseShuffleIndicies, "shuffle"); + if (auto *I = dyn_cast<Instruction>(V)) { + GatherSeq.insert(I); + CSEBlocks.insert(I->getParent()); + } + } + return V; +} + +static void inversePermutation(ArrayRef<unsigned> Indices, + SmallVectorImpl<unsigned> &Mask) { + Mask.clear(); + const unsigned E = Indices.size(); + Mask.resize(E); + for (unsigned I = 0; I < E; ++I) + Mask[Indices[I]] = I; } Value *BoUpSLP::vectorizeTree(TreeEntry *E) { IRBuilder<>::InsertPointGuard Guard(Builder); if (E->VectorizedValue) { - DEBUG(dbgs() << "SLP: Diamond merged for " << *E->Scalars[0] << ".\n"); + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *E->Scalars[0] << ".\n"); return E->VectorizedValue; } InstructionsState S = getSameOpcode(E->Scalars); - Instruction *VL0 = cast<Instruction>(E->Scalars[0]); + Instruction *VL0 = cast<Instruction>(S.OpValue); Type *ScalarTy = VL0->getType(); if (StoreInst *SI = dyn_cast<StoreInst>(VL0)) ScalarTy = SI->getValueOperand()->getType(); VectorType *VecTy = VectorType::get(ScalarTy, E->Scalars.size()); + bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty(); + if (E->NeedToGather) { - setInsertPointAfterBundle(E->Scalars, VL0); + setInsertPointAfterBundle(E->Scalars, S); auto *V = Gather(E->Scalars, VecTy); + if (NeedToShuffleReuses) { + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + E->ReuseShuffleIndices, "shuffle"); + if (auto *I = dyn_cast<Instruction>(V)) { + GatherSeq.insert(I); + CSEBlocks.insert(I->getParent()); + } + } E->VectorizedValue = V; return V; } - unsigned ShuffleOrOp = S.IsAltShuffle ? - (unsigned) Instruction::ShuffleVector : S.Opcode; + unsigned ShuffleOrOp = S.isAltShuffle() ? + (unsigned) Instruction::ShuffleVector : S.getOpcode(); switch (ShuffleOrOp) { case Instruction::PHI: { PHINode *PH = dyn_cast<PHINode>(VL0); Builder.SetInsertPoint(PH->getParent()->getFirstNonPHI()); Builder.SetCurrentDebugLocation(PH->getDebugLoc()); PHINode *NewPhi = Builder.CreatePHI(VecTy, PH->getNumIncomingValues()); - E->VectorizedValue = NewPhi; + Value *V = NewPhi; + if (NeedToShuffleReuses) { + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + E->ReuseShuffleIndices, "shuffle"); + } + E->VectorizedValue = V; // PHINodes may have multiple entries from the same block. We want to // visit every block once. - SmallSet<BasicBlock*, 4> VisitedBBs; + SmallPtrSet<BasicBlock*, 4> VisitedBBs; for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) { ValueList Operands; @@ -2804,32 +3094,74 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { assert(NewPhi->getNumIncomingValues() == PH->getNumIncomingValues() && "Invalid number of incoming values"); - return NewPhi; + return V; } case Instruction::ExtractElement: { - if (canReuseExtract(E->Scalars, VL0)) { + if (!E->NeedToGather) { Value *V = VL0->getOperand(0); + if (!E->ReorderIndices.empty()) { + OrdersType Mask; + inversePermutation(E->ReorderIndices, Mask); + Builder.SetInsertPoint(VL0); + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), Mask, + "reorder_shuffle"); + } + if (NeedToShuffleReuses) { + // TODO: Merge this shuffle with the ReorderShuffleMask. 
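inversePermutation above builds the mask M with M[Indices[I]] = I, so applying the reordering and then the mask restores the original lane order. A minimal standalone check (plain C++, illustrative names):

#include <cassert>
#include <vector>

static std::vector<unsigned> inversePerm(const std::vector<unsigned> &Indices) {
  std::vector<unsigned> Mask(Indices.size());
  for (unsigned I = 0, E = static_cast<unsigned>(Indices.size()); I < E; ++I)
    Mask[Indices[I]] = I; // Same assignment as inversePermutation above.
  return Mask;
}

int main() {
  std::vector<unsigned> Perm = {2, 0, 3, 1};
  std::vector<unsigned> Inv = inversePerm(Perm); // {1, 3, 0, 2}
  for (unsigned I = 0; I < Perm.size(); ++I)
    assert(Inv[Perm[I]] == I); // Composition is the identity permutation.
}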
+ if (!E->ReorderIndices.empty()) + Builder.SetInsertPoint(VL0); + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + E->ReuseShuffleIndices, "shuffle"); + } E->VectorizedValue = V; return V; } - setInsertPointAfterBundle(E->Scalars, VL0); + setInsertPointAfterBundle(E->Scalars, S); auto *V = Gather(E->Scalars, VecTy); + if (NeedToShuffleReuses) { + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + E->ReuseShuffleIndices, "shuffle"); + if (auto *I = dyn_cast<Instruction>(V)) { + GatherSeq.insert(I); + CSEBlocks.insert(I->getParent()); + } + } E->VectorizedValue = V; return V; } case Instruction::ExtractValue: { - if (canReuseExtract(E->Scalars, VL0)) { + if (!E->NeedToGather) { LoadInst *LI = cast<LoadInst>(VL0->getOperand(0)); Builder.SetInsertPoint(LI); PointerType *PtrTy = PointerType::get(VecTy, LI->getPointerAddressSpace()); Value *Ptr = Builder.CreateBitCast(LI->getOperand(0), PtrTy); LoadInst *V = Builder.CreateAlignedLoad(Ptr, LI->getAlignment()); - E->VectorizedValue = V; - return propagateMetadata(V, E->Scalars); + Value *NewV = propagateMetadata(V, E->Scalars); + if (!E->ReorderIndices.empty()) { + OrdersType Mask; + inversePermutation(E->ReorderIndices, Mask); + NewV = Builder.CreateShuffleVector(NewV, UndefValue::get(VecTy), Mask, + "reorder_shuffle"); + } + if (NeedToShuffleReuses) { + // TODO: Merge this shuffle with the ReorderShuffleMask. + NewV = Builder.CreateShuffleVector( + NewV, UndefValue::get(VecTy), E->ReuseShuffleIndices, "shuffle"); + } + E->VectorizedValue = NewV; + return NewV; } - setInsertPointAfterBundle(E->Scalars, VL0); + setInsertPointAfterBundle(E->Scalars, S); auto *V = Gather(E->Scalars, VecTy); + if (NeedToShuffleReuses) { + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + E->ReuseShuffleIndices, "shuffle"); + if (auto *I = dyn_cast<Instruction>(V)) { + GatherSeq.insert(I); + CSEBlocks.insert(I->getParent()); + } + } E->VectorizedValue = V; return V; } @@ -2849,15 +3181,21 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { for (Value *V : E->Scalars) INVL.push_back(cast<Instruction>(V)->getOperand(0)); - setInsertPointAfterBundle(E->Scalars, VL0); + setInsertPointAfterBundle(E->Scalars, S); Value *InVec = vectorizeTree(INVL); - if (Value *V = alreadyVectorized(E->Scalars, VL0)) - return V; + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; + } CastInst *CI = dyn_cast<CastInst>(VL0); Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy); + if (NeedToShuffleReuses) { + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + E->ReuseShuffleIndices, "shuffle"); + } E->VectorizedValue = V; ++NumVectorInstructions; return V; @@ -2870,23 +3208,29 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { RHSV.push_back(cast<Instruction>(V)->getOperand(1)); } - setInsertPointAfterBundle(E->Scalars, VL0); + setInsertPointAfterBundle(E->Scalars, S); Value *L = vectorizeTree(LHSV); Value *R = vectorizeTree(RHSV); - if (Value *V = alreadyVectorized(E->Scalars, VL0)) - return V; + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; + } CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate(); Value *V; - if (S.Opcode == Instruction::FCmp) + if (S.getOpcode() == Instruction::FCmp) V = Builder.CreateFCmp(P0, L, R); else V = Builder.CreateICmp(P0, L, R); + propagateIRFlags(V, E->Scalars, VL0); + if (NeedToShuffleReuses) { + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + 
E->ReuseShuffleIndices, "shuffle"); + } E->VectorizedValue = V; - propagateIRFlags(E->VectorizedValue, E->Scalars, VL0); ++NumVectorInstructions; return V; } @@ -2898,16 +3242,22 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { FalseVec.push_back(cast<Instruction>(V)->getOperand(2)); } - setInsertPointAfterBundle(E->Scalars, VL0); + setInsertPointAfterBundle(E->Scalars, S); Value *Cond = vectorizeTree(CondVec); Value *True = vectorizeTree(TrueVec); Value *False = vectorizeTree(FalseVec); - if (Value *V = alreadyVectorized(E->Scalars, VL0)) - return V; + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; + } Value *V = Builder.CreateSelect(Cond, True, False); + if (NeedToShuffleReuses) { + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + E->ReuseShuffleIndices, "shuffle"); + } E->VectorizedValue = V; ++NumVectorInstructions; return V; @@ -2932,7 +3282,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::Xor: { ValueList LHSVL, RHSVL; if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) - reorderInputsAccordingToOpcode(S.Opcode, E->Scalars, LHSVL, + reorderInputsAccordingToOpcode(S.getOpcode(), E->Scalars, LHSVL, RHSVL); else for (Value *V : E->Scalars) { @@ -2941,29 +3291,40 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { RHSVL.push_back(I->getOperand(1)); } - setInsertPointAfterBundle(E->Scalars, VL0); + setInsertPointAfterBundle(E->Scalars, S); Value *LHS = vectorizeTree(LHSVL); Value *RHS = vectorizeTree(RHSVL); - if (Value *V = alreadyVectorized(E->Scalars, VL0)) - return V; + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; + } Value *V = Builder.CreateBinOp( - static_cast<Instruction::BinaryOps>(S.Opcode), LHS, RHS); + static_cast<Instruction::BinaryOps>(S.getOpcode()), LHS, RHS); + propagateIRFlags(V, E->Scalars, VL0); + if (auto *I = dyn_cast<Instruction>(V)) + V = propagateMetadata(I, E->Scalars); + + if (NeedToShuffleReuses) { + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + E->ReuseShuffleIndices, "shuffle"); + } E->VectorizedValue = V; - propagateIRFlags(E->VectorizedValue, E->Scalars, VL0); ++NumVectorInstructions; - if (Instruction *I = dyn_cast<Instruction>(V)) - return propagateMetadata(I, E->Scalars); - return V; } case Instruction::Load: { // Loads are inserted at the head of the tree because we don't want to // sink them all the way down past store instructions. - setInsertPointAfterBundle(E->Scalars, VL0); + bool IsReorder = !E->ReorderIndices.empty(); + if (IsReorder) { + S = getSameOpcode(E->Scalars, E->ReorderIndices.front()); + VL0 = cast<Instruction>(S.OpValue); + } + setInsertPointAfterBundle(E->Scalars, S); LoadInst *LI = cast<LoadInst>(VL0); Type *ScalarLoadTy = LI->getType(); @@ -2985,9 +3346,21 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Alignment = DL->getABITypeAlignment(ScalarLoadTy); } LI->setAlignment(Alignment); - E->VectorizedValue = LI; + Value *V = propagateMetadata(LI, E->Scalars); + if (IsReorder) { + OrdersType Mask; + inversePermutation(E->ReorderIndices, Mask); + V = Builder.CreateShuffleVector(V, UndefValue::get(V->getType()), + Mask, "reorder_shuffle"); + } + if (NeedToShuffleReuses) { + // TODO: Merge this shuffle with the ReorderShuffleMask. 
+ V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + E->ReuseShuffleIndices, "shuffle"); + } + E->VectorizedValue = V; ++NumVectorInstructions; - return propagateMetadata(LI, E->Scalars); + return V; } case Instruction::Store: { StoreInst *SI = cast<StoreInst>(VL0); @@ -2998,12 +3371,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { for (Value *V : E->Scalars) ScalarStoreValues.push_back(cast<StoreInst>(V)->getValueOperand()); - setInsertPointAfterBundle(E->Scalars, VL0); + setInsertPointAfterBundle(E->Scalars, S); Value *VecValue = vectorizeTree(ScalarStoreValues); Value *ScalarPtr = SI->getPointerOperand(); Value *VecPtr = Builder.CreateBitCast(ScalarPtr, VecTy->getPointerTo(AS)); - StoreInst *S = Builder.CreateStore(VecValue, VecPtr); + StoreInst *ST = Builder.CreateStore(VecValue, VecPtr); // The pointer operand uses an in-tree scalar, so add the new BitCast to // ExternalUses to make sure that an extract will be generated in the @@ -3014,13 +3387,18 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { if (!Alignment) Alignment = DL->getABITypeAlignment(SI->getValueOperand()->getType()); - S->setAlignment(Alignment); - E->VectorizedValue = S; + ST->setAlignment(Alignment); + Value *V = propagateMetadata(ST, E->Scalars); + if (NeedToShuffleReuses) { + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + E->ReuseShuffleIndices, "shuffle"); + } + E->VectorizedValue = V; ++NumVectorInstructions; - return propagateMetadata(S, E->Scalars); + return V; } case Instruction::GetElementPtr: { - setInsertPointAfterBundle(E->Scalars, VL0); + setInsertPointAfterBundle(E->Scalars, S); ValueList Op0VL; for (Value *V : E->Scalars) @@ -3041,17 +3419,21 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *V = Builder.CreateGEP( cast<GetElementPtrInst>(VL0)->getSourceElementType(), Op0, OpVecs); + if (Instruction *I = dyn_cast<Instruction>(V)) + V = propagateMetadata(I, E->Scalars); + + if (NeedToShuffleReuses) { + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + E->ReuseShuffleIndices, "shuffle"); + } E->VectorizedValue = V; ++NumVectorInstructions; - if (Instruction *I = dyn_cast<Instruction>(V)) - return propagateMetadata(I, E->Scalars); - return V; } case Instruction::Call: { CallInst *CI = cast<CallInst>(VL0); - setInsertPointAfterBundle(E->Scalars, VL0); + setInsertPointAfterBundle(E->Scalars, S); Function *FI; Intrinsic::ID IID = Intrinsic::not_intrinsic; Value *ScalarArg = nullptr; @@ -3075,7 +3457,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } Value *OpVec = vectorizeTree(OpVL); - DEBUG(dbgs() << "SLP: OpVec[" << j << "]: " << *OpVec << "\n"); + LLVM_DEBUG(dbgs() << "SLP: OpVec[" << j << "]: " << *OpVec << "\n"); OpVecs.push_back(OpVec); } @@ -3093,58 +3475,87 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { if (ScalarArg && getTreeEntry(ScalarArg)) ExternalUses.push_back(ExternalUser(ScalarArg, cast<User>(V), 0)); + propagateIRFlags(V, E->Scalars, VL0); + if (NeedToShuffleReuses) { + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + E->ReuseShuffleIndices, "shuffle"); + } E->VectorizedValue = V; - propagateIRFlags(E->VectorizedValue, E->Scalars, VL0); ++NumVectorInstructions; return V; } case Instruction::ShuffleVector: { ValueList LHSVL, RHSVL; - assert(Instruction::isBinaryOp(S.Opcode) && + assert(S.isAltShuffle() && + ((Instruction::isBinaryOp(S.getOpcode()) && + Instruction::isBinaryOp(S.getAltOpcode())) || + (Instruction::isCast(S.getOpcode()) && + Instruction::isCast(S.getAltOpcode()))) && "Invalid Shuffle Vector Operand"); - 
reorderAltShuffleOperands(S.Opcode, E->Scalars, LHSVL, RHSVL); - setInsertPointAfterBundle(E->Scalars, VL0); - Value *LHS = vectorizeTree(LHSVL); - Value *RHS = vectorizeTree(RHSVL); - - if (Value *V = alreadyVectorized(E->Scalars, VL0)) - return V; + Value *LHS, *RHS; + if (Instruction::isBinaryOp(S.getOpcode())) { + reorderAltShuffleOperands(S, E->Scalars, LHSVL, RHSVL); + setInsertPointAfterBundle(E->Scalars, S); + LHS = vectorizeTree(LHSVL); + RHS = vectorizeTree(RHSVL); + } else { + ValueList INVL; + for (Value *V : E->Scalars) + INVL.push_back(cast<Instruction>(V)->getOperand(0)); + setInsertPointAfterBundle(E->Scalars, S); + LHS = vectorizeTree(INVL); + } - // Create a vector of LHS op1 RHS - Value *V0 = Builder.CreateBinOp( - static_cast<Instruction::BinaryOps>(S.Opcode), LHS, RHS); + if (E->VectorizedValue) { + LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); + return E->VectorizedValue; + } - unsigned AltOpcode = getAltOpcode(S.Opcode); - // Create a vector of LHS op2 RHS - Value *V1 = Builder.CreateBinOp( - static_cast<Instruction::BinaryOps>(AltOpcode), LHS, RHS); + Value *V0, *V1; + if (Instruction::isBinaryOp(S.getOpcode())) { + V0 = Builder.CreateBinOp( + static_cast<Instruction::BinaryOps>(S.getOpcode()), LHS, RHS); + V1 = Builder.CreateBinOp( + static_cast<Instruction::BinaryOps>(S.getAltOpcode()), LHS, RHS); + } else { + V0 = Builder.CreateCast( + static_cast<Instruction::CastOps>(S.getOpcode()), LHS, VecTy); + V1 = Builder.CreateCast( + static_cast<Instruction::CastOps>(S.getAltOpcode()), LHS, VecTy); + } // Create shuffle to take alternate operations from the vector. - // Also, gather up odd and even scalar ops to propagate IR flags to + // Also, gather up main and alt scalar ops to propagate IR flags to // each vector operation. - ValueList OddScalars, EvenScalars; + ValueList OpScalars, AltScalars; unsigned e = E->Scalars.size(); SmallVector<Constant *, 8> Mask(e); for (unsigned i = 0; i < e; ++i) { - if (isOdd(i)) { + auto *OpInst = cast<Instruction>(E->Scalars[i]); + assert(S.isOpcodeOrAlt(OpInst) && "Unexpected main/alternate opcode"); + if (OpInst->getOpcode() == S.getAltOpcode()) { Mask[i] = Builder.getInt32(e + i); - OddScalars.push_back(E->Scalars[i]); + AltScalars.push_back(E->Scalars[i]); } else { Mask[i] = Builder.getInt32(i); - EvenScalars.push_back(E->Scalars[i]); + OpScalars.push_back(E->Scalars[i]); } } Value *ShuffleMask = ConstantVector::get(Mask); - propagateIRFlags(V0, EvenScalars); - propagateIRFlags(V1, OddScalars); + propagateIRFlags(V0, OpScalars); + propagateIRFlags(V1, AltScalars); Value *V = Builder.CreateShuffleVector(V0, V1, ShuffleMask); + if (Instruction *I = dyn_cast<Instruction>(V)) + V = propagateMetadata(I, E->Scalars); + if (NeedToShuffleReuses) { + V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), + E->ReuseShuffleIndices, "shuffle"); + } E->VectorizedValue = V; ++NumVectorInstructions; - if (Instruction *I = dyn_cast<Instruction>(V)) - return propagateMetadata(I, E->Scalars); return V; } @@ -3183,7 +3594,8 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { VectorizableTree[0].VectorizedValue = Trunc; } - DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() << " values .\n"); + LLVM_DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() + << " values .\n"); // If necessary, sign-extend or zero-extend ScalarRoot to the larger type // specified by ScalarType. 
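The blend mask built above follows the two-source shufflevector convention: lane i takes element i of V0 (the main-opcode vector) or element e + i of V1 (the alternate-opcode vector). A standalone sketch of just the mask construction, with hypothetical names:

#include <cassert>
#include <vector>

static std::vector<unsigned> buildAltMask(const std::vector<bool> &IsAltLane) {
  unsigned E = static_cast<unsigned>(IsAltLane.size());
  std::vector<unsigned> Mask(E);
  for (unsigned I = 0; I < E; ++I)
    Mask[I] = IsAltLane[I] ? E + I : I; // Second-source lanes are offset by E.
  return Mask;
}

int main() {
  // add/sub/add/sub: lanes 1 and 3 come from the alternate-opcode vector.
  std::vector<unsigned> Mask = buildAltMask({false, true, false, true});
  assert((Mask == std::vector<unsigned>{0, 5, 2, 7}));
}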
@@ -3260,7 +3672,7 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { Ex = extend(ScalarRoot, Ex, Scalar->getType()); CSEBlocks.insert(cast<Instruction>(User)->getParent()); User->replaceUsesOfWith(Scalar, Ex); - } + } } else { Builder.SetInsertPoint(&F->getEntryBlock().front()); Value *Ex = Builder.CreateExtractElement(Vec, Lane); @@ -3269,7 +3681,7 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { User->replaceUsesOfWith(Scalar, Ex); } - DEBUG(dbgs() << "SLP: Replaced:" << *User << ".\n"); + LLVM_DEBUG(dbgs() << "SLP: Replaced:" << *User << ".\n"); } // For each vectorized value: @@ -3290,7 +3702,7 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { if (!Ty->isVoidTy()) { #ifndef NDEBUG for (User *U : Scalar->users()) { - DEBUG(dbgs() << "SLP: \tvalidating user:" << *U << ".\n"); + LLVM_DEBUG(dbgs() << "SLP: \tvalidating user:" << *U << ".\n"); // It is legal to replace users in the ignorelist by undef. assert((getTreeEntry(U) || is_contained(UserIgnoreList, U)) && @@ -3300,7 +3712,7 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { Value *Undef = UndefValue::get(Ty); Scalar->replaceAllUsesWith(Undef); } - DEBUG(dbgs() << "SLP: \tErasing scalar:" << *Scalar << ".\n"); + LLVM_DEBUG(dbgs() << "SLP: \tErasing scalar:" << *Scalar << ".\n"); eraseInstruction(cast<Instruction>(Scalar)); } } @@ -3310,18 +3722,16 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { return VectorizableTree[0].VectorizedValue; } -void BoUpSLP::optimizeGatherSequence(Function &F) { - DEBUG(dbgs() << "SLP: Optimizing " << GatherSeq.size() - << " gather sequences instructions.\n"); +void BoUpSLP::optimizeGatherSequence() { + LLVM_DEBUG(dbgs() << "SLP: Optimizing " << GatherSeq.size() + << " gather sequences instructions.\n"); // LICM InsertElementInst sequences. - for (Instruction *it : GatherSeq) { - InsertElementInst *Insert = dyn_cast<InsertElementInst>(it); - - if (!Insert) + for (Instruction *I : GatherSeq) { + if (!isa<InsertElementInst>(I) && !isa<ShuffleVectorInst>(I)) continue; // Check if this block is inside a loop. - Loop *L = LI->getLoopFor(Insert->getParent()); + Loop *L = LI->getLoopFor(I->getParent()); if (!L) continue; @@ -3333,27 +3743,41 @@ void BoUpSLP::optimizeGatherSequence(Function &F) { // If the vector or the element that we insert into it are // instructions that are defined in this basic block then we can't // hoist this instruction. - Instruction *CurrVec = dyn_cast<Instruction>(Insert->getOperand(0)); - Instruction *NewElem = dyn_cast<Instruction>(Insert->getOperand(1)); - if (CurrVec && L->contains(CurrVec)) + auto *Op0 = dyn_cast<Instruction>(I->getOperand(0)); + auto *Op1 = dyn_cast<Instruction>(I->getOperand(1)); + if (Op0 && L->contains(Op0)) continue; - if (NewElem && L->contains(NewElem)) + if (Op1 && L->contains(Op1)) continue; // We can hoist this instruction. Move it to the pre-header. - Insert->moveBefore(PreHeader->getTerminator()); + I->moveBefore(PreHeader->getTerminator()); } + // Make a list of all reachable blocks in our CSE queue. + SmallVector<const DomTreeNode *, 8> CSEWorkList; + CSEWorkList.reserve(CSEBlocks.size()); + for (BasicBlock *BB : CSEBlocks) + if (DomTreeNode *N = DT->getNode(BB)) { + assert(DT->isReachableFromEntry(N)); + CSEWorkList.push_back(N); + } + + // Sort blocks by domination. This ensures we visit a block after all blocks + // dominating it are visited. 
+ std::stable_sort(CSEWorkList.begin(), CSEWorkList.end(), + [this](const DomTreeNode *A, const DomTreeNode *B) { + return DT->properlyDominates(A, B); + }); + // Perform O(N^2) search over the gather sequences and merge identical // instructions. TODO: We can further optimize this scan if we split the // instructions into different buckets based on the insert lane. SmallVector<Instruction *, 16> Visited; - ReversePostOrderTraversal<Function *> RPOT(&F); - for (auto BB : RPOT) { - // Traverse CSEBlocks by RPOT order. - if (!CSEBlocks.count(BB)) - continue; - + for (auto I = CSEWorkList.begin(), E = CSEWorkList.end(); I != E; ++I) { + assert((I == CSEWorkList.begin() || !DT->dominates(*I, *std::prev(I))) && + "Worklist not sorted properly!"); + BasicBlock *BB = (*I)->getBlock(); // For all instructions in blocks containing gather sequences: for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e;) { Instruction *In = &*it++; @@ -3384,8 +3808,9 @@ void BoUpSLP::optimizeGatherSequence(Function &F) { // Groups the instructions to a bundle (which is then a single scheduling entity) // and schedules instructions until the bundle gets ready. bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, - BoUpSLP *SLP, Value *OpValue) { - if (isa<PHINode>(OpValue)) + BoUpSLP *SLP, + const InstructionsState &S) { + if (isa<PHINode>(S.OpValue)) return true; // Initialize the instruction bundle. @@ -3393,12 +3818,12 @@ bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, ScheduleData *PrevInBundle = nullptr; ScheduleData *Bundle = nullptr; bool ReSchedule = false; - DEBUG(dbgs() << "SLP: bundle: " << *OpValue << "\n"); + LLVM_DEBUG(dbgs() << "SLP: bundle: " << *S.OpValue << "\n"); // Make sure that the scheduling region contains all // instructions of the bundle. for (Value *V : VL) { - if (!extendSchedulingRegion(V, OpValue)) + if (!extendSchedulingRegion(V, S)) return false; } @@ -3410,8 +3835,8 @@ bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, // A bundle member was scheduled as single instruction before and now // needs to be scheduled as part of the bundle. We just get rid of the // existing schedule. 
- DEBUG(dbgs() << "SLP: reset schedule because " << *BundleMember - << " was already scheduled\n"); + LLVM_DEBUG(dbgs() << "SLP: reset schedule because " << *BundleMember + << " was already scheduled\n"); ReSchedule = true; } assert(BundleMember->isSchedulingEntity() && @@ -3446,8 +3871,8 @@ bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, initialFillReadyList(ReadyInsts); } - DEBUG(dbgs() << "SLP: try schedule bundle " << *Bundle << " in block " - << BB->getName() << "\n"); + LLVM_DEBUG(dbgs() << "SLP: try schedule bundle " << *Bundle << " in block " + << BB->getName() << "\n"); calculateDependencies(Bundle, true, SLP); @@ -3465,7 +3890,7 @@ bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, } } if (!Bundle->isReady()) { - cancelScheduling(VL, OpValue); + cancelScheduling(VL, S.OpValue); return false; } return true; @@ -3477,7 +3902,7 @@ void BoUpSLP::BlockScheduling::cancelScheduling(ArrayRef<Value *> VL, return; ScheduleData *Bundle = getScheduleData(OpValue); - DEBUG(dbgs() << "SLP: cancel scheduling of " << *Bundle << "\n"); + LLVM_DEBUG(dbgs() << "SLP: cancel scheduling of " << *Bundle << "\n"); assert(!Bundle->IsScheduled && "Can't cancel bundle which is already scheduled"); assert(Bundle->isSchedulingEntity() && Bundle->isPartOfBundle() && @@ -3508,13 +3933,13 @@ BoUpSLP::ScheduleData *BoUpSLP::BlockScheduling::allocateScheduleDataChunks() { } bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, - Value *OpValue) { - if (getScheduleData(V, isOneOf(OpValue, V))) + const InstructionsState &S) { + if (getScheduleData(V, isOneOf(S, V))) return true; Instruction *I = dyn_cast<Instruction>(V); assert(I && "bundle member must be an instruction"); assert(!isa<PHINode>(I) && "phi nodes don't need to be scheduled"); - auto &&CheckSheduleForI = [this, OpValue](Instruction *I) -> bool { + auto &&CheckSheduleForI = [this, &S](Instruction *I) -> bool { ScheduleData *ISD = getScheduleData(I); if (!ISD) return false; @@ -3522,8 +3947,8 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, "ScheduleData not in scheduling region"); ScheduleData *SD = allocateScheduleDataChunks(); SD->Inst = I; - SD->init(SchedulingRegionID, OpValue); - ExtraScheduleDataMap[I][OpValue] = SD; + SD->init(SchedulingRegionID, S.OpValue); + ExtraScheduleDataMap[I][S.OpValue] = SD; return true; }; if (CheckSheduleForI(I)) @@ -3533,10 +3958,10 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, initScheduleData(I, I->getNextNode(), nullptr, nullptr); ScheduleStart = I; ScheduleEnd = I->getNextNode(); - if (isOneOf(OpValue, I) != I) + if (isOneOf(S, I) != I) CheckSheduleForI(I); assert(ScheduleEnd && "tried to vectorize a TerminatorInst?"); - DEBUG(dbgs() << "SLP: initialize schedule region to " << *I << "\n"); + LLVM_DEBUG(dbgs() << "SLP: initialize schedule region to " << *I << "\n"); return true; } // Search up and down at the same time, because we don't know if the new @@ -3548,7 +3973,7 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, BasicBlock::iterator LowerEnd = BB->end(); while (true) { if (++ScheduleRegionSize > ScheduleRegionSizeLimit) { - DEBUG(dbgs() << "SLP: exceeded schedule region size limit\n"); + LLVM_DEBUG(dbgs() << "SLP: exceeded schedule region size limit\n"); return false; } @@ -3556,9 +3981,10 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, if (&*UpIter == I) { initScheduleData(I, ScheduleStart, nullptr, FirstLoadStoreInRegion); ScheduleStart = I; - if (isOneOf(OpValue, I) != I) + if 
(isOneOf(S, I) != I) CheckSheduleForI(I); - DEBUG(dbgs() << "SLP: extend schedule region start to " << *I << "\n"); + LLVM_DEBUG(dbgs() << "SLP: extend schedule region start to " << *I + << "\n"); return true; } UpIter++; @@ -3568,10 +3994,11 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, initScheduleData(ScheduleEnd, I->getNextNode(), LastLoadStoreInRegion, nullptr); ScheduleEnd = I->getNextNode(); - if (isOneOf(OpValue, I) != I) + if (isOneOf(S, I) != I) CheckSheduleForI(I); assert(ScheduleEnd && "tried to vectorize a TerminatorInst?"); - DEBUG(dbgs() << "SLP: extend schedule region end to " << *I << "\n"); + LLVM_DEBUG(dbgs() << "SLP: extend schedule region end to " << *I + << "\n"); return true; } DownIter++; @@ -3635,7 +4062,8 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, assert(isInSchedulingRegion(BundleMember)); if (!BundleMember->hasValidDependencies()) { - DEBUG(dbgs() << "SLP: update deps of " << *BundleMember << "\n"); + LLVM_DEBUG(dbgs() << "SLP: update deps of " << *BundleMember + << "\n"); BundleMember->Dependencies = 0; BundleMember->resetUnscheduledDeps(); @@ -3727,7 +4155,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, // i0 to i3, we have transitive dependencies from i0 to i6,i7,i8 // and we can abort this loop at i6. if (DistToSrc >= 2 * MaxMemDepDistance) - break; + break; DistToSrc++; } } @@ -3736,7 +4164,8 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, } if (InsertInReadyList && SD->isReady()) { ReadyInsts.push_back(SD); - DEBUG(dbgs() << "SLP: gets ready on update: " << *SD->Inst << "\n"); + LLVM_DEBUG(dbgs() << "SLP: gets ready on update: " << *SD->Inst + << "\n"); } } } @@ -3759,7 +4188,7 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) { if (!BS->ScheduleStart) return; - DEBUG(dbgs() << "SLP: schedule block " << BS->BB->getName() << "\n"); + LLVM_DEBUG(dbgs() << "SLP: schedule block " << BS->BB->getName() << "\n"); BS->resetSchedule(); @@ -4025,7 +4454,11 @@ void BoUpSLP::computeMinimumValueSizes() { // We start by looking at each entry that can be demoted. We compute the // maximum bit width required to store the scalar by using ValueTracking to // compute the number of high-order bits we can truncate. - if (MaxBitWidth == DL->getTypeSizeInBits(TreeRoot[0]->getType())) { + if (MaxBitWidth == DL->getTypeSizeInBits(TreeRoot[0]->getType()) && + llvm::all_of(TreeRoot, [](Value *R) { + assert(R->hasOneUse() && "Root should have only one use!"); + return isa<GetElementPtrInst>(R->user_back()); + })) { MaxBitWidth = 8u; // Determine if the sign bit of all the roots is known to be zero. If not, @@ -4188,7 +4621,7 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_, if (F.hasFnAttribute(Attribute::NoImplicitFloat)) return false; - DEBUG(dbgs() << "SLP: Analyzing blocks in " << F.getName() << ".\n"); + LLVM_DEBUG(dbgs() << "SLP: Analyzing blocks in " << F.getName() << ".\n"); // Use the bottom up slp vectorizer to construct chains that start with // store instructions. @@ -4203,8 +4636,8 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_, // Vectorize trees that end at stores. 
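The scheduling-region growth earlier in this hunk scans upwards and downwards from the current region, one instruction per direction per iteration, and gives up once the region exceeds ScheduleRegionSizeLimit. A standalone sketch of that windowing scheme, with indices standing in for the UpIter/DownIter iterators (illustrative only):

#include <cstddef>

// Grow the half-open window [Begin, End) one step in each direction per
// iteration until it reaches Target, or fail once the window exceeds Limit.
static bool growRegion(size_t &Begin, size_t &End, size_t Target, size_t Hi,
                       size_t Limit) {
  size_t RegionSize = End - Begin;
  while (Begin > 0 || End < Hi) {
    if (++RegionSize > Limit)
      return false;                     // "exceeded schedule region size limit"
    if (Begin > 0 && --Begin == Target) // one step upwards
      return true;
    if (End < Hi && End++ == Target)    // one step downwards
      return true;
  }
  return false;                         // Target lies outside [0, Hi)
}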
if (!Stores.empty()) { - DEBUG(dbgs() << "SLP: Found stores for " << Stores.size() - << " underlying objects.\n"); + LLVM_DEBUG(dbgs() << "SLP: Found stores for " << Stores.size() + << " underlying objects.\n"); Changed |= vectorizeStoreChains(R); } @@ -4215,21 +4648,21 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_, // is primarily intended to catch gather-like idioms ending at // non-consecutive loads. if (!GEPs.empty()) { - DEBUG(dbgs() << "SLP: Found GEPs for " << GEPs.size() - << " underlying objects.\n"); + LLVM_DEBUG(dbgs() << "SLP: Found GEPs for " << GEPs.size() + << " underlying objects.\n"); Changed |= vectorizeGEPIndices(BB, R); } } if (Changed) { - R.optimizeGatherSequence(F); - DEBUG(dbgs() << "SLP: vectorized \"" << F.getName() << "\"\n"); - DEBUG(verifyFunction(F)); + R.optimizeGatherSequence(); + LLVM_DEBUG(dbgs() << "SLP: vectorized \"" << F.getName() << "\"\n"); + LLVM_DEBUG(verifyFunction(F)); } return Changed; } -/// \brief Check that the Values in the slice in VL array are still existent in +/// Check that the Values in the slice in VL array are still existent in /// the WeakTrackingVH array. /// Vectorization of part of the VL array may cause later values in the VL array /// to become invalid. We track when this has happened in the WeakTrackingVH @@ -4244,30 +4677,28 @@ static bool hasValueBeenRAUWed(ArrayRef<Value *> VL, bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, unsigned VecRegSize) { - unsigned ChainLen = Chain.size(); - DEBUG(dbgs() << "SLP: Analyzing a store chain of length " << ChainLen - << "\n"); - unsigned Sz = R.getVectorElementSize(Chain[0]); - unsigned VF = VecRegSize / Sz; + const unsigned ChainLen = Chain.size(); + LLVM_DEBUG(dbgs() << "SLP: Analyzing a store chain of length " << ChainLen + << "\n"); + const unsigned Sz = R.getVectorElementSize(Chain[0]); + const unsigned VF = VecRegSize / Sz; if (!isPowerOf2_32(Sz) || VF < 2) return false; // Keep track of values that were deleted by vectorizing in the loop below. - SmallVector<WeakTrackingVH, 8> TrackValues(Chain.begin(), Chain.end()); + const SmallVector<WeakTrackingVH, 8> TrackValues(Chain.begin(), Chain.end()); bool Changed = false; // Look for profitable vectorizable trees at all offsets, starting at zero. - for (unsigned i = 0, e = ChainLen; i < e; ++i) { - if (i + VF > e) - break; + for (unsigned i = 0, e = ChainLen; i + VF <= e; ++i) { // Check that a previous iteration of this loop did not delete the Value. 
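The rewritten loop bound above ("i + VF <= e") folds the old in-loop break into the loop condition: a window of VF consecutive stores is tried at every offset. A standalone model of the surrounding arithmetic, with example values for the register and element sizes:

#include <cstdio>

static bool isPowerOf2(unsigned X) { return X && !(X & (X - 1)); }

int main() {
  unsigned VecRegSize = 128;     // bits per vector register (example value)
  unsigned Sz = 32;              // bits per scalar element (example value)
  unsigned VF = VecRegSize / Sz; // vectorization factor: 4 lanes
  if (!isPowerOf2(Sz) || VF < 2)
    return 0;                    // mirrors the early "return false"
  unsigned ChainLen = 10;
  for (unsigned i = 0; i + VF <= ChainLen; ++i)
    printf("analyze %u stores at offset %u\n", VF, i); // Chain.slice(i, VF)
  return 0;
}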
if (hasValueBeenRAUWed(Chain, TrackValues, i, VF)) continue; - DEBUG(dbgs() << "SLP: Analyzing " << VF << " stores at offset " << i - << "\n"); + LLVM_DEBUG(dbgs() << "SLP: Analyzing " << VF << " stores at offset " << i + << "\n"); ArrayRef<Value *> Operands = Chain.slice(i, VF); R.buildTree(Operands); @@ -4278,9 +4709,10 @@ bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, int Cost = R.getTreeCost(); - DEBUG(dbgs() << "SLP: Found cost=" << Cost << " for VF=" << VF << "\n"); + LLVM_DEBUG(dbgs() << "SLP: Found cost=" << Cost << " for VF=" << VF + << "\n"); if (Cost < -SLPCostThreshold) { - DEBUG(dbgs() << "SLP: Decided to vectorize cost=" << Cost << "\n"); + LLVM_DEBUG(dbgs() << "SLP: Decided to vectorize cost=" << Cost << "\n"); using namespace ore; @@ -4417,66 +4849,48 @@ bool SLPVectorizerPass::tryToVectorizePair(Value *A, Value *B, BoUpSLP &R) { if (!A || !B) return false; Value *VL[] = { A, B }; - return tryToVectorizeList(VL, R, None, true); + return tryToVectorizeList(VL, R, /*UserCost=*/0, true); } bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, - ArrayRef<Value *> BuildVector, - bool AllowReorder, - bool NeedExtraction) { + int UserCost, bool AllowReorder) { if (VL.size() < 2) return false; - DEBUG(dbgs() << "SLP: Trying to vectorize a list of length = " << VL.size() - << ".\n"); + LLVM_DEBUG(dbgs() << "SLP: Trying to vectorize a list of length = " + << VL.size() << ".\n"); - // Check that all of the parts are scalar instructions of the same type. - Instruction *I0 = dyn_cast<Instruction>(VL[0]); - if (!I0) + // Check that all of the parts are scalar instructions of the same type, + // we permit an alternate opcode via InstructionsState. + InstructionsState S = getSameOpcode(VL); + if (!S.getOpcode()) return false; - unsigned Opcode0 = I0->getOpcode(); - + Instruction *I0 = cast<Instruction>(S.OpValue); unsigned Sz = R.getVectorElementSize(I0); unsigned MinVF = std::max(2U, R.getMinVecRegSize() / Sz); unsigned MaxVF = std::max<unsigned>(PowerOf2Floor(VL.size()), MinVF); if (MaxVF < 2) { - R.getORE()->emit([&]() { - return OptimizationRemarkMissed( - SV_NAME, "SmallVF", I0) - << "Cannot SLP vectorize list: vectorization factor " - << "less than 2 is not supported"; - }); - return false; + R.getORE()->emit([&]() { + return OptimizationRemarkMissed(SV_NAME, "SmallVF", I0) + << "Cannot SLP vectorize list: vectorization factor " + << "less than 2 is not supported"; + }); + return false; } for (Value *V : VL) { Type *Ty = V->getType(); if (!isValidElementType(Ty)) { - // NOTE: the following will give user internal llvm type name, which may not be useful + // NOTE: the following will give user internal llvm type name, which may + // not be useful. 
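The vectorization-factor bounds computed above are worth spelling out: MinVF is how many elements are needed to fill the smallest vector register, and MaxVF is the largest power of two not exceeding the list length. A standalone illustration with example sizes, where powerOf2Floor stands in for llvm::PowerOf2Floor:

#include <algorithm>
#include <cstdio>

static unsigned powerOf2Floor(unsigned X) {
  unsigned P = 0;
  for (unsigned Q = 1; Q != 0 && Q <= X; Q <<= 1)
    P = Q; // largest power of two <= X; 0 when X == 0
  return P;
}

int main() {
  unsigned MinVecRegSize = 128, Sz = 32, NumVals = 7; // example values
  unsigned MinVF = std::max(2u, MinVecRegSize / Sz);        // here: 4
  unsigned MaxVF = std::max(powerOf2Floor(NumVals), MinVF); // here: 4
  printf("MinVF=%u MaxVF=%u\n", MinVF, MaxVF);
  return 0;
}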
R.getORE()->emit([&]() { - std::string type_str; - llvm::raw_string_ostream rso(type_str); - Ty->print(rso); - return OptimizationRemarkMissed( - SV_NAME, "UnsupportedType", I0) - << "Cannot SLP vectorize list: type " - << rso.str() + " is unsupported by vectorizer"; - }); - return false; - } - Instruction *Inst = dyn_cast<Instruction>(V); - - if (!Inst) - return false; - if (Inst->getOpcode() != Opcode0) { - R.getORE()->emit([&]() { - return OptimizationRemarkMissed( - SV_NAME, "InequableTypes", I0) - << "Cannot SLP vectorize list: not all of the " - << "parts of scalar instructions are of the same type: " - << ore::NV("Instruction1Opcode", I0) << " and " - << ore::NV("Instruction2Opcode", Inst); + std::string type_str; + llvm::raw_string_ostream rso(type_str); + Ty->print(rso); + return OptimizationRemarkMissed(SV_NAME, "UnsupportedType", I0) + << "Cannot SLP vectorize list: type " + << rso.str() + " is unsupported by vectorizer"; }); return false; } @@ -4513,24 +4927,20 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, if (hasValueBeenRAUWed(VL, TrackValues, I, OpsWidth)) continue; - DEBUG(dbgs() << "SLP: Analyzing " << OpsWidth << " operations " - << "\n"); + LLVM_DEBUG(dbgs() << "SLP: Analyzing " << OpsWidth << " operations " + << "\n"); ArrayRef<Value *> Ops = VL.slice(I, OpsWidth); - ArrayRef<Value *> EmptyArray; - ArrayRef<Value *> BuildVectorSlice; - if (!BuildVector.empty()) - BuildVectorSlice = BuildVector.slice(I, OpsWidth); - - R.buildTree(Ops, NeedExtraction ? EmptyArray : BuildVectorSlice); + R.buildTree(Ops); + Optional<ArrayRef<unsigned>> Order = R.bestOrder(); // TODO: check if we can allow reordering for more cases. - if (AllowReorder && R.shouldReorder()) { + if (AllowReorder && Order) { + // TODO: reorder tree nodes without tree rebuilding. // Conceptually, there is nothing actually preventing us from trying to // reorder a larger list. In fact, we do exactly this when vectorizing // reductions. However, at this point, we only expect to get here when // there are exactly two operations. assert(Ops.size() == 2); - assert(BuildVectorSlice.empty()); Value *ReorderedOps[] = {Ops[1], Ops[0]}; R.buildTree(ReorderedOps, None); } @@ -4538,43 +4948,19 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, continue; R.computeMinimumValueSizes(); - int Cost = R.getTreeCost(); + int Cost = R.getTreeCost() - UserCost; CandidateFound = true; MinCost = std::min(MinCost, Cost); if (Cost < -SLPCostThreshold) { - DEBUG(dbgs() << "SLP: Vectorizing list at cost:" << Cost << ".\n"); + LLVM_DEBUG(dbgs() << "SLP: Vectorizing list at cost:" << Cost << ".\n"); R.getORE()->emit(OptimizationRemark(SV_NAME, "VectorizedList", cast<Instruction>(Ops[0])) << "SLP vectorized with cost " << ore::NV("Cost", Cost) << " and with tree size " << ore::NV("TreeSize", R.getTreeSize())); - Value *VectorizedRoot = R.vectorizeTree(); - - // Reconstruct the build vector by extracting the vectorized root. This - // way we handle the case where some elements of the vector are - // undefined. - // (return (inserelt <4 xi32> (insertelt undef (opd0) 0) (opd1) 2)) - if (!BuildVectorSlice.empty()) { - // The insert point is the last build vector instruction. The - // vectorized root will precede it. This guarantees that we get an - // instruction. The vectorized tree could have been constant folded. 
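Both here and in the reduction path later in this file, a permutation returned by bestOrder() is applied to the scalar list before the tree is rebuilt. A standalone illustration of that index-mapping idiom using std::transform (llvm::transform, used in the reduction hunk below, is a thin range wrapper around it):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<const char *> VL = {"a", "b", "c", "d"};
  std::vector<unsigned> Order = {2, 0, 3, 1}; // stand-in for *bestOrder()
  std::vector<const char *> ReorderedOps(VL.size());
  std::transform(Order.begin(), Order.end(), ReorderedOps.begin(),
                 [&VL](unsigned Idx) { return VL[Idx]; });
  for (const char *S : ReorderedOps)
    printf("%s ", S); // prints: c a d b
  return 0;
}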
- Instruction *InsertAfter = cast<Instruction>(BuildVectorSlice.back()); - unsigned VecIdx = 0; - for (auto &V : BuildVectorSlice) { - IRBuilder<NoFolder> Builder(InsertAfter->getParent(), - ++BasicBlock::iterator(InsertAfter)); - Instruction *I = cast<Instruction>(V); - assert(isa<InsertElementInst>(I) || isa<InsertValueInst>(I)); - Instruction *Extract = - cast<Instruction>(Builder.CreateExtractElement( - VectorizedRoot, Builder.getInt32(VecIdx++))); - I->setOperand(1, Extract); - I->moveAfter(Extract); - InsertAfter = I; - } - } + R.vectorizeTree(); // Move to the next bundle. I += VF - 1; NextInst = I + 1; @@ -4585,18 +4971,16 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, if (!Changed && CandidateFound) { R.getORE()->emit([&]() { - return OptimizationRemarkMissed( - SV_NAME, "NotBeneficial", I0) - << "List vectorization was possible but not beneficial with cost " - << ore::NV("Cost", MinCost) << " >= " - << ore::NV("Treshold", -SLPCostThreshold); + return OptimizationRemarkMissed(SV_NAME, "NotBeneficial", I0) + << "List vectorization was possible but not beneficial with cost " + << ore::NV("Cost", MinCost) << " >= " + << ore::NV("Treshold", -SLPCostThreshold); }); } else if (!Changed) { R.getORE()->emit([&]() { - return OptimizationRemarkMissed( - SV_NAME, "NotPossible", I0) - << "Cannot SLP vectorize list: vectorization was impossible" - << " with available vectorization factors"; + return OptimizationRemarkMissed(SV_NAME, "NotPossible", I0) + << "Cannot SLP vectorize list: vectorization was impossible" + << " with available vectorization factors"; }); } return Changed; @@ -4645,7 +5029,7 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) { return false; } -/// \brief Generate a shuffle mask to be used in a reduction tree. +/// Generate a shuffle mask to be used in a reduction tree. /// /// \param VecLen The length of the vector to be reduced. /// \param NumEltsToRdx The number of elements that should be reduced in the @@ -5128,6 +5512,77 @@ class HorizontalReduction { return OperationData( Instruction::FCmp, LHS, RHS, RK_Max, cast<Instruction>(Select->getCondition())->hasNoNaNs()); + } else { + // Try harder: look for min/max pattern based on instructions producing + // same values such as: select ((cmp Inst1, Inst2), Inst1, Inst2). + // During the intermediate stages of SLP, it's very common to have + // pattern like this (since optimizeGatherSequence is run only once + // at the end): + // %1 = extractelement <2 x i32> %a, i32 0 + // %2 = extractelement <2 x i32> %a, i32 1 + // %cond = icmp sgt i32 %1, %2 + // %3 = extractelement <2 x i32> %a, i32 0 + // %4 = extractelement <2 x i32> %a, i32 1 + // %select = select i1 %cond, i32 %3, i32 %4 + CmpInst::Predicate Pred; + Instruction *L1; + Instruction *L2; + + LHS = Select->getTrueValue(); + RHS = Select->getFalseValue(); + Value *Cond = Select->getCondition(); + + // TODO: Support inverse predicates. 
+ if (match(Cond, m_Cmp(Pred, m_Specific(LHS), m_Instruction(L2)))) { + if (!isa<ExtractElementInst>(RHS) || + !L2->isIdenticalTo(cast<Instruction>(RHS))) + return OperationData(V); + } else if (match(Cond, m_Cmp(Pred, m_Instruction(L1), m_Specific(RHS)))) { + if (!isa<ExtractElementInst>(LHS) || + !L1->isIdenticalTo(cast<Instruction>(LHS))) + return OperationData(V); + } else { + if (!isa<ExtractElementInst>(LHS) || !isa<ExtractElementInst>(RHS)) + return OperationData(V); + if (!match(Cond, m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2))) || + !L1->isIdenticalTo(cast<Instruction>(LHS)) || + !L2->isIdenticalTo(cast<Instruction>(RHS))) + return OperationData(V); + } + switch (Pred) { + default: + return OperationData(V); + + case CmpInst::ICMP_ULT: + case CmpInst::ICMP_ULE: + return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin); + + case CmpInst::ICMP_SLT: + case CmpInst::ICMP_SLE: + return OperationData(Instruction::ICmp, LHS, RHS, RK_Min); + + case CmpInst::FCMP_OLT: + case CmpInst::FCMP_OLE: + case CmpInst::FCMP_ULT: + case CmpInst::FCMP_ULE: + return OperationData(Instruction::FCmp, LHS, RHS, RK_Min, + cast<Instruction>(Cond)->hasNoNaNs()); + + case CmpInst::ICMP_UGT: + case CmpInst::ICMP_UGE: + return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax); + + case CmpInst::ICMP_SGT: + case CmpInst::ICMP_SGE: + return OperationData(Instruction::ICmp, LHS, RHS, RK_Max); + + case CmpInst::FCMP_OGT: + case CmpInst::FCMP_OGE: + case CmpInst::FCMP_UGT: + case CmpInst::FCMP_UGE: + return OperationData(Instruction::FCmp, LHS, RHS, RK_Max, + cast<Instruction>(Cond)->hasNoNaNs()); + } } } return OperationData(V); @@ -5136,7 +5591,7 @@ class HorizontalReduction { public: HorizontalReduction() = default; - /// \brief Try to find a reduction tree. + /// Try to find a reduction tree. bool matchAssociativeReduction(PHINode *Phi, Instruction *B) { assert((!Phi || is_contained(Phi->operands(), B)) && "Thi phi needs to use the binary operator"); @@ -5164,6 +5619,8 @@ public: Type *Ty = B->getType(); if (!isValidElementType(Ty)) return false; + if (!Ty->isIntOrIntVectorTy() && !Ty->isFPOrFPVectorTy()) + return false; ReducedValueData.clear(); ReductionRoot = B; @@ -5262,7 +5719,7 @@ public: return true; } - /// \brief Attempt to vectorize the tree found by + /// Attempt to vectorize the tree found by /// matchAssociativeReduction. bool tryToReduce(BoUpSLP &V, TargetTransformInfo *TTI) { if (ReducedVals.empty()) @@ -5295,9 +5752,14 @@ public: while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) { auto VL = makeArrayRef(&ReducedVals[i], ReduxWidth); V.buildTree(VL, ExternallyUsedValues, IgnoreList); - if (V.shouldReorder()) { - SmallVector<Value *, 8> Reversed(VL.rbegin(), VL.rend()); - V.buildTree(Reversed, ExternallyUsedValues, IgnoreList); + Optional<ArrayRef<unsigned>> Order = V.bestOrder(); + // TODO: Handle orders of size less than number of elements in the vector. + if (Order && Order->size() == VL.size()) { + // TODO: reorder tree nodes without tree rebuilding. + SmallVector<Value *, 4> ReorderedOps(VL.size()); + llvm::transform(*Order, ReorderedOps.begin(), + [VL](const unsigned Idx) { return VL[Idx]; }); + V.buildTree(ReorderedOps, ExternallyUsedValues, IgnoreList); } if (V.isTreeTinyAndNotFullyVectorizable()) break; @@ -5305,8 +5767,9 @@ public: V.computeMinimumValueSizes(); // Estimate cost. 
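The new min/max recognition above tolerates the duplicated extractelements that exist before optimizeGatherSequence runs: the compare may use one pair of extracts and the select a second, identical pair. A hedged sketch of the core check, reduced from the pass's OperationData logic (this simplified helper is not part of the patch):

#include "llvm/IR/Instructions.h"
#include "llvm/IR/PatternMatch.h"
using namespace llvm;
using namespace PatternMatch;

// Return true if Sel looks like select(cmp(L, R), TV, FV) where L and R are
// instructions identical to the extractelements TV and FV, binding Pred.
static bool looksLikeMinMax(SelectInst *Sel, CmpInst::Predicate &Pred) {
  Value *TV = Sel->getTrueValue(), *FV = Sel->getFalseValue();
  Instruction *L, *R;
  if (!match(Sel->getCondition(),
             m_Cmp(Pred, m_Instruction(L), m_Instruction(R))))
    return false;
  auto *TI = dyn_cast<ExtractElementInst>(TV);
  auto *FI = dyn_cast<ExtractElementInst>(FV);
  return TI && FI && L->isIdenticalTo(TI) && R->isIdenticalTo(FI);
}

The predicate bound here is then classified exactly as in the switch above, e.g. ICMP_SLT/ICMP_SLE yield RK_Min and ICMP_SGT/ICMP_SGE yield RK_Max.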
- int Cost = - V.getTreeCost() + getReductionCost(TTI, ReducedVals[i], ReduxWidth); + int TreeCost = V.getTreeCost(); + int ReductionCost = getReductionCost(TTI, ReducedVals[i], ReduxWidth); + int Cost = TreeCost + ReductionCost; if (Cost >= -SLPCostThreshold) { V.getORE()->emit([&]() { return OptimizationRemarkMissed( @@ -5319,8 +5782,8 @@ public: break; } - DEBUG(dbgs() << "SLP: Vectorizing horizontal reduction at cost:" << Cost - << ". (HorRdx)\n"); + LLVM_DEBUG(dbgs() << "SLP: Vectorizing horizontal reduction at cost:" + << Cost << ". (HorRdx)\n"); V.getORE()->emit([&]() { return OptimizationRemark( SV_NAME, "VectorizedHorizontalReduction", cast<Instruction>(VL[0])) @@ -5382,7 +5845,7 @@ public: } private: - /// \brief Calculate the cost of a reduction. + /// Calculate the cost of a reduction. int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal, unsigned ReduxWidth) { Type *ScalarTy = FirstReducedVal->getType(); @@ -5441,16 +5904,16 @@ private: } ScalarReduxCost *= (ReduxWidth - 1); - DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost - << " for reduction that starts with " << *FirstReducedVal - << " (It is a " - << (IsPairwiseReduction ? "pairwise" : "splitting") - << " reduction)\n"); + LLVM_DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost + << " for reduction that starts with " << *FirstReducedVal + << " (It is a " + << (IsPairwiseReduction ? "pairwise" : "splitting") + << " reduction)\n"); return VecReduxCost - ScalarReduxCost; } - /// \brief Emit a horizontal reduction of the vectorized value. + /// Emit a horizontal reduction of the vectorized value. Value *emitReduction(Value *VectorizedValue, IRBuilder<> &Builder, unsigned ReduxWidth, const TargetTransformInfo *TTI) { assert(VectorizedValue && "Need to have a vectorized tree node"); @@ -5486,7 +5949,7 @@ private: } // end anonymous namespace -/// \brief Recognize construction of vectors like +/// Recognize construction of vectors like /// %ra = insertelement <4 x float> undef, float %s0, i32 0 /// %rb = insertelement <4 x float> %ra, float %s1, i32 1 /// %rc = insertelement <4 x float> %rb, float %s2, i32 2 @@ -5495,11 +5958,17 @@ private: /// /// Returns true if it matches static bool findBuildVector(InsertElementInst *LastInsertElem, - SmallVectorImpl<Value *> &BuildVector, - SmallVectorImpl<Value *> &BuildVectorOpds) { + TargetTransformInfo *TTI, + SmallVectorImpl<Value *> &BuildVectorOpds, + int &UserCost) { + UserCost = 0; Value *V = nullptr; do { - BuildVector.push_back(LastInsertElem); + if (auto *CI = dyn_cast<ConstantInt>(LastInsertElem->getOperand(2))) { + UserCost += TTI->getVectorInstrCost(Instruction::InsertElement, + LastInsertElem->getType(), + CI->getZExtValue()); + } BuildVectorOpds.push_back(LastInsertElem->getOperand(1)); V = LastInsertElem->getOperand(0); if (isa<UndefValue>(V)) @@ -5508,20 +5977,17 @@ static bool findBuildVector(InsertElementInst *LastInsertElem, if (!LastInsertElem || !LastInsertElem->hasOneUse()) return false; } while (true); - std::reverse(BuildVector.begin(), BuildVector.end()); std::reverse(BuildVectorOpds.begin(), BuildVectorOpds.end()); return true; } -/// \brief Like findBuildVector, but looks for construction of aggregate. +/// Like findBuildVector, but looks for construction of aggregate. /// /// \return true if it matches. 
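The UserCost threading introduced in this patch credits the insertelement instructions that a successful build-vector match will make dead: each constant-lane insert is priced via TTI and later subtracted from the tree cost ("int Cost = R.getTreeCost() - UserCost;"). A hypothetical standalone model of that bookkeeping (CostModel is a stand-in, not TargetTransformInfo):

#include <cstdio>

struct CostModel {
  // Pretend lane 0 inserts are free and all other lanes cost one unit.
  int insertElementCost(unsigned Lane) const { return Lane == 0 ? 0 : 1; }
};

int main() {
  CostModel TTI;
  unsigned Lanes[] = {0, 1, 2, 3}; // constant lanes of an insert chain
  int UserCost = 0;
  for (unsigned L : Lanes)
    UserCost += TTI.insertElementCost(L); // accumulates 3
  int TreeCost = 2;
  printf("effective cost = %d\n", TreeCost - UserCost); // 2 - 3 = -1
  return 0;
}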
static bool findBuildAggregate(InsertValueInst *IV, - SmallVectorImpl<Value *> &BuildVector, SmallVectorImpl<Value *> &BuildVectorOpds) { Value *V; do { - BuildVector.push_back(IV); BuildVectorOpds.push_back(IV->getInsertedValueOperand()); V = IV->getAggregateOperand(); if (isa<UndefValue>(V)) @@ -5530,7 +5996,6 @@ static bool findBuildAggregate(InsertValueInst *IV, if (!IV || !IV->hasOneUse()) return false; } while (true); - std::reverse(BuildVector.begin(), BuildVector.end()); std::reverse(BuildVectorOpds.begin(), BuildVectorOpds.end()); return true; } @@ -5539,7 +6004,7 @@ static bool PhiTypeSorterFunc(Value *V, Value *V2) { return V->getType() < V2->getType(); } -/// \brief Try and get a reduction value from a phi node. +/// Try and get a reduction value from a phi node. /// /// Given a phi node \p P in a block \p ParentBB, consider possible reductions /// if they come from either \p ParentBB or a containing loop latch. @@ -5552,9 +6017,8 @@ static Value *getReductionValue(const DominatorTree *DT, PHINode *P, // reduction phi. Vectorizing such cases has been reported to cause // miscompiles. See PR25787. auto DominatedReduxValue = [&](Value *R) { - return ( - dyn_cast<Instruction>(R) && - DT->dominates(P->getParent(), dyn_cast<Instruction>(R)->getParent())); + return isa<Instruction>(R) && + DT->dominates(P->getParent(), cast<Instruction>(R)->getParent()); }; Value *Rdx = nullptr; @@ -5624,7 +6088,7 @@ static bool tryToVectorizeHorReductionOrInstOperands( // Interrupt the process if the Root instruction itself was vectorized or all // sub-trees not higher that RecursionMaxDepth were analyzed/vectorized. SmallVector<std::pair<WeakTrackingVH, unsigned>, 8> Stack(1, {Root, 0}); - SmallSet<Value *, 8> VisitedInstrs; + SmallPtrSet<Value *, 8> VisitedInstrs; bool Res = false; while (!Stack.empty()) { Value *V; @@ -5706,27 +6170,29 @@ bool SLPVectorizerPass::vectorizeInsertValueInst(InsertValueInst *IVI, if (!R.canMapToVector(IVI->getType(), DL)) return false; - SmallVector<Value *, 16> BuildVector; SmallVector<Value *, 16> BuildVectorOpds; - if (!findBuildAggregate(IVI, BuildVector, BuildVectorOpds)) + if (!findBuildAggregate(IVI, BuildVectorOpds)) return false; - DEBUG(dbgs() << "SLP: array mappable to vector: " << *IVI << "\n"); + LLVM_DEBUG(dbgs() << "SLP: array mappable to vector: " << *IVI << "\n"); // Aggregate value is unlikely to be processed in vector register, we need to // extract scalars into scalar registers, so NeedExtraction is set true. - return tryToVectorizeList(BuildVectorOpds, R, BuildVector, false, true); + return tryToVectorizeList(BuildVectorOpds, R); } bool SLPVectorizerPass::vectorizeInsertElementInst(InsertElementInst *IEI, BasicBlock *BB, BoUpSLP &R) { - SmallVector<Value *, 16> BuildVector; + int UserCost; SmallVector<Value *, 16> BuildVectorOpds; - if (!findBuildVector(IEI, BuildVector, BuildVectorOpds)) + if (!findBuildVector(IEI, TTI, BuildVectorOpds, UserCost) || + (llvm::all_of(BuildVectorOpds, + [](Value *V) { return isa<ExtractElementInst>(V); }) && + isShuffle(BuildVectorOpds))) return false; // Vectorize starting with the build vector operands ignoring the BuildVector // instructions for the purpose of scheduling and user extraction. 
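findBuildAggregate above walks an insertvalue chain from the last insert back to an undef root, collecting the inserted operands and requiring a single use at every intermediate step. A minimal standalone model of that walk, with Node standing in for InsertValueInst:

#include <algorithm>
#include <vector>

struct Node {
  int Operand;   // stand-in for getInsertedValueOperand()
  Node *Agg;     // stand-in for getAggregateOperand(); null == undef
  bool HasOneUse;
};

static bool collectChain(Node *Last, std::vector<int> &Opds) {
  for (Node *N = Last;;) {
    Opds.push_back(N->Operand);
    if (!N->Agg)
      break;            // reached the undef root: the chain matches
    N = N->Agg;
    if (!N->HasOneUse)
      return false;     // an intermediate aggregate is reused elsewhere
  }
  std::reverse(Opds.begin(), Opds.end()); // operands in insertion order
  return true;
}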
- return tryToVectorizeList(BuildVectorOpds, R, BuildVector); + return tryToVectorizeList(BuildVectorOpds, R, UserCost); } bool SLPVectorizerPass::vectorizeCmpInst(CmpInst *CI, BasicBlock *BB, @@ -5763,7 +6229,7 @@ bool SLPVectorizerPass::vectorizeSimpleInstructions( bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { bool Changed = false; SmallVector<Value *, 4> Incoming; - SmallSet<Value *, 16> VisitedInstrs; + SmallPtrSet<Value *, 16> VisitedInstrs; bool HaveVectorizedPhiNodes = true; while (HaveVectorizedPhiNodes) { @@ -5798,14 +6264,15 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { // Try to vectorize them. unsigned NumElts = (SameTypeIt - IncIt); - DEBUG(errs() << "SLP: Trying to vectorize starting at PHIs (" << NumElts << ")\n"); + LLVM_DEBUG(dbgs() << "SLP: Trying to vectorize starting at PHIs (" + << NumElts << ")\n"); // The order in which the phi nodes appear in the program does not matter. // So allow tryToVectorizeList to reorder them if it is beneficial. This // is done when there are exactly two elements since tryToVectorizeList // asserts that there are only two values when AllowReorder is true. bool AllowReorder = NumElts == 2; if (NumElts > 1 && tryToVectorizeList(makeArrayRef(IncIt, NumElts), R, - None, AllowReorder)) { + /*UserCost=*/0, AllowReorder)) { // Success start over because instructions might have been changed. HaveVectorizedPhiNodes = true; Changed = true; @@ -5885,7 +6352,6 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { if (isa<InsertElementInst>(it) || isa<CmpInst>(it) || isa<InsertValueInst>(it)) PostProcessInstructions.push_back(&*it); - } return Changed; @@ -5899,8 +6365,8 @@ bool SLPVectorizerPass::vectorizeGEPIndices(BasicBlock *BB, BoUpSLP &R) { if (Entry.second.size() < 2) continue; - DEBUG(dbgs() << "SLP: Analyzing a getelementptr list of length " - << Entry.second.size() << ".\n"); + LLVM_DEBUG(dbgs() << "SLP: Analyzing a getelementptr list of length " + << Entry.second.size() << ".\n"); // We process the getelementptr list in chunks of 16 (like we do for // stores) to minimize compile-time. @@ -5982,14 +6448,14 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { if (it->second.size() < 2) continue; - DEBUG(dbgs() << "SLP: Analyzing a store chain of length " - << it->second.size() << ".\n"); + LLVM_DEBUG(dbgs() << "SLP: Analyzing a store chain of length " + << it->second.size() << ".\n"); // Process the stores in chunks of 16. // TODO: The limit of 16 inhibits greater vectorization factors. // For example, AVX2 supports v32i8. Increasing this limit, however, // may cause a significant compile-time increase. - for (unsigned CI = 0, CE = it->second.size(); CI < CE; CI+=16) { + for (unsigned CI = 0, CE = it->second.size(); CI < CE; CI += 16) { unsigned Len = std::min<unsigned>(CE - CI, 16); Changed |= vectorizeStores(makeArrayRef(&it->second[CI], Len), R); } diff --git a/lib/Transforms/Vectorize/VPRecipeBuilder.h b/lib/Transforms/Vectorize/VPRecipeBuilder.h new file mode 100644 index 000000000000..f43a8bb123b1 --- /dev/null +++ b/lib/Transforms/Vectorize/VPRecipeBuilder.h @@ -0,0 +1,131 @@ +//===- VPRecipeBuilder.h - Helper class to build recipes --------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_VECTORIZE_VPRECIPEBUILDER_H +#define LLVM_TRANSFORMS_VECTORIZE_VPRECIPEBUILDER_H + +#include "LoopVectorizationPlanner.h" +#include "VPlan.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/IR/IRBuilder.h" + +namespace llvm { + +class LoopVectorizationLegality; +class LoopVectorizationCostModel; +class TargetTransformInfo; +class TargetLibraryInfo; + +/// Helper class to create VPRecipes from IR instructions. +class VPRecipeBuilder { + /// The loop that we evaluate. + Loop *OrigLoop; + + /// Target Library Info. + const TargetLibraryInfo *TLI; + + /// Target Transform Info. + const TargetTransformInfo *TTI; + + /// The legality analysis. + LoopVectorizationLegality *Legal; + + /// The profitability analysis. + LoopVectorizationCostModel &CM; + + VPBuilder &Builder; + + /// When we if-convert we need to create edge masks. We have to cache values + /// so that we don't end up with exponential recursion/IR. Note that + /// if-conversion currently takes place during VPlan-construction, so these + /// caches are only used at that stage. + using EdgeMaskCacheTy = + DenseMap<std::pair<BasicBlock *, BasicBlock *>, VPValue *>; + using BlockMaskCacheTy = DenseMap<BasicBlock *, VPValue *>; + EdgeMaskCacheTy EdgeMaskCache; + BlockMaskCacheTy BlockMaskCache; + +public: + /// A helper function that computes the predicate of the block BB, assuming + /// that the header block of the loop is set to True. It returns the *entry* + /// mask for the block BB. + VPValue *createBlockInMask(BasicBlock *BB, VPlanPtr &Plan); + + /// A helper function that computes the predicate of the edge between SRC + /// and DST. + VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst, VPlanPtr &Plan); + + /// Check if \p I belongs to an Interleave Group within the given VF \p Range, + /// \return true in the first returned value if so and false otherwise. + /// Build a new VPInterleaveGroup Recipe if \p I is the primary member of an IG + /// for \p Range.Start, and provide it as the second returned value. + /// Note that if \p I is an adjunct member of an IG for \p Range.Start, the + /// \return value is <true, nullptr>, as it is handled by another recipe. + /// \p Range.End may be decreased to ensure same decision from \p Range.Start + /// to \p Range.End. + VPInterleaveRecipe *tryToInterleaveMemory(Instruction *I, VFRange &Range); + + /// Check if \p I is a memory instruction to be widened for \p Range.Start and + /// potentially masked. Such instructions are handled by a recipe that takes + /// an additional VPInstruction for the mask. + VPWidenMemoryInstructionRecipe * + tryToWidenMemory(Instruction *I, VFRange &Range, VPlanPtr &Plan); + + /// Check if an induction recipe should be constructed for \p I within the given + /// VF \p Range. If so build and return it. If not, return null. \p Range.End + /// may be decreased to ensure same decision from \p Range.Start to + /// \p Range.End. + VPWidenIntOrFpInductionRecipe *tryToOptimizeInduction(Instruction *I, + VFRange &Range); + + /// Handle non-loop phi nodes. Currently all such phi nodes are turned into + /// a sequence of select instructions as the vectorizer currently performs + /// full if-conversion. + VPBlendRecipe *tryToBlend(Instruction *I, VPlanPtr &Plan); + + /// Check if \p I can be widened within the given VF \p Range.
If \p I can be + /// widened for \p Range.Start, check if the last recipe of \p VPBB can be + /// extended to include \p I or else build a new VPWidenRecipe for it and + /// append it to \p VPBB. Return true if \p I can be widened for Range.Start, + /// false otherwise. Range.End may be decreased to ensure same decision from + /// \p Range.Start to \p Range.End. + bool tryToWiden(Instruction *I, VPBasicBlock *VPBB, VFRange &Range); + + /// Create a replicating region for instruction \p I that requires + /// predication. \p PredRecipe is a VPReplicateRecipe holding \p I. + VPRegionBlock *createReplicateRegion(Instruction *I, VPRecipeBase *PredRecipe, + VPlanPtr &Plan); + +public: + VPRecipeBuilder(Loop *OrigLoop, const TargetLibraryInfo *TLI, + const TargetTransformInfo *TTI, + LoopVectorizationLegality *Legal, + LoopVectorizationCostModel &CM, VPBuilder &Builder) + : OrigLoop(OrigLoop), TLI(TLI), TTI(TTI), Legal(Legal), CM(CM), + Builder(Builder) {} + + /// Check if a recipe can be created for \p I within the given VF \p Range. + /// If a recipe can be created, it adds it to \p VPBB. + bool tryToCreateRecipe(Instruction *Instr, VFRange &Range, VPlanPtr &Plan, + VPBasicBlock *VPBB); + + /// Build a VPReplicateRecipe for \p I and enclose it within a Region if it + /// is predicated. \return \p VPBB augmented with this new recipe if \p I is + /// not predicated, otherwise \return a new VPBasicBlock that succeeds the new + /// Region. Update the packing decision of predicated instructions if they + /// feed \p I. Range.End may be decreased to ensure same recipe behavior from + /// \p Range.Start to \p Range.End. + VPBasicBlock *handleReplication( + Instruction *I, VFRange &Range, VPBasicBlock *VPBB, + DenseMap<Instruction *, VPReplicateRecipe *> &PredInst2Recipe, + VPlanPtr &Plan); +}; +} // end namespace llvm + +#endif // LLVM_TRANSFORMS_VECTORIZE_VPRECIPEBUILDER_H diff --git a/lib/Transforms/Vectorize/VPlan.cpp b/lib/Transforms/Vectorize/VPlan.cpp index 4e54fc6db2a5..f7b07b722bb1 100644 --- a/lib/Transforms/Vectorize/VPlan.cpp +++ b/lib/Transforms/Vectorize/VPlan.cpp @@ -116,7 +116,7 @@ VPBasicBlock::createEmptyBasicBlock(VPTransformState::CFGState &CFG) { BasicBlock *PrevBB = CFG.PrevBB; BasicBlock *NewBB = BasicBlock::Create(PrevBB->getContext(), getName(), PrevBB->getParent(), CFG.LastBB); - DEBUG(dbgs() << "LV: created " << NewBB->getName() << '\n'); + LLVM_DEBUG(dbgs() << "LV: created " << NewBB->getName() << '\n'); // Hook up the new basic block to its predecessors. for (VPBlockBase *PredVPBlock : getHierarchicalPredecessors()) { @@ -125,7 +125,7 @@ VPBasicBlock::createEmptyBasicBlock(VPTransformState::CFGState &CFG) { BasicBlock *PredBB = CFG.VPBB2IRBB[PredVPBB]; assert(PredBB && "Predecessor basic-block not found building successor."); auto *PredBBTerminator = PredBB->getTerminator(); - DEBUG(dbgs() << "LV: draw edge from" << PredBB->getName() << '\n'); + LLVM_DEBUG(dbgs() << "LV: draw edge from" << PredBB->getName() << '\n'); if (isa<UnreachableInst>(PredBBTerminator)) { assert(PredVPSuccessors.size() == 1 && "Predecessor ending w/o branch must have single successor."); @@ -175,8 +175,8 @@ void VPBasicBlock::execute(VPTransformState *State) { } // 2. Fill the IR basic block with IR instructions.
- DEBUG(dbgs() << "LV: vectorizing VPBB:" << getName() - << " in BB:" << NewBB->getName() << '\n'); + LLVM_DEBUG(dbgs() << "LV: vectorizing VPBB:" << getName() + << " in BB:" << NewBB->getName() << '\n'); State->CFG.VPBB2IRBB[this] = NewBB; State->CFG.PrevVPBB = this; @@ -184,7 +184,7 @@ void VPBasicBlock::execute(VPTransformState *State) { for (VPRecipeBase &Recipe : Recipes) Recipe.execute(*State); - DEBUG(dbgs() << "LV: filled BB:" << *NewBB); + LLVM_DEBUG(dbgs() << "LV: filled BB:" << *NewBB); } void VPRegionBlock::execute(VPTransformState *State) { @@ -193,7 +193,7 @@ void VPRegionBlock::execute(VPTransformState *State) { if (!isReplicator()) { // Visit the VPBlocks connected to "this", starting from it. for (VPBlockBase *Block : RPOT) { - DEBUG(dbgs() << "LV: VPBlock in RPO " << Block->getName() << '\n'); + LLVM_DEBUG(dbgs() << "LV: VPBlock in RPO " << Block->getName() << '\n'); Block->execute(State); } return; @@ -210,7 +210,7 @@ void VPRegionBlock::execute(VPTransformState *State) { State->Instance->Lane = Lane; // Visit the VPBlocks connected to \p this, starting from it. for (VPBlockBase *Block : RPOT) { - DEBUG(dbgs() << "LV: VPBlock in RPO " << Block->getName() << '\n'); + LLVM_DEBUG(dbgs() << "LV: VPBlock in RPO " << Block->getName() << '\n'); Block->execute(State); } } @@ -220,6 +220,15 @@ void VPRegionBlock::execute(VPTransformState *State) { State->Instance.reset(); } +void VPRecipeBase::insertBefore(VPRecipeBase *InsertPos) { + Parent = InsertPos->getParent(); + Parent->getRecipeList().insert(InsertPos->getIterator(), this); +} + +iplist<VPRecipeBase>::iterator VPRecipeBase::eraseFromParent() { + return getParent()->getRecipeList().erase(getIterator()); +} + void VPInstruction::generateInstruction(VPTransformState &State, unsigned Part) { IRBuilder<> &Builder = State.Builder; @@ -356,7 +365,7 @@ void VPlan::updateDominatorTree(DominatorTree *DT, BasicBlock *LoopPreHeaderBB, "One successor of a basic block does not lead to the other."); assert(InterimSucc->getSinglePredecessor() && "Interim successor has more than one predecessor."); - assert(std::distance(pred_begin(PostDomSucc), pred_end(PostDomSucc)) == 2 && + assert(pred_size(PostDomSucc) == 2 && "PostDom successor has more than two predecessors."); DT->addNewBlock(InterimSucc, BB); DT->addNewBlock(PostDomSucc, BB); @@ -448,6 +457,18 @@ void VPlanPrinter::dumpBasicBlock(const VPBasicBlock *BasicBlock) { bumpIndent(1); for (const VPRecipeBase &Recipe : *BasicBlock) Recipe.print(OS, Indent); + + // Dump the condition bit. + const VPValue *CBV = BasicBlock->getCondBit(); + if (CBV) { + OS << " +\n" << Indent << " \"CondBit: "; + if (const VPInstruction *CBI = dyn_cast<VPInstruction>(CBV)) { + CBI->printAsOperand(OS); + OS << " (" << DOT::EscapeString(CBI->getParent()->getName()) << ")\\l\""; + } else + CBV->printAsOperand(OS); + } + bumpIndent(-2); OS << "\n" << Indent << "]\n"; dumpEdges(BasicBlock); diff --git a/lib/Transforms/Vectorize/VPlan.h b/lib/Transforms/Vectorize/VPlan.h index 2ccabfd6af25..866951cb79a4 100644 --- a/lib/Transforms/Vectorize/VPlan.h +++ b/lib/Transforms/Vectorize/VPlan.h @@ -30,6 +30,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/GraphTraits.h" #include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" @@ -42,15 +43,10 @@ #include <map> #include <string> -// The (re)use of existing LoopVectorize classes is subject to future VPlan -// refactoring. 
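The replicator path of VPRegionBlock::execute above runs the region's blocks once per unroll part and vector lane, recording the current pair in State->Instance. A standalone sketch of that iteration scheme, with UF and VF as example unroll and vectorization factors:

#include <cstdio>

int main() {
  unsigned UF = 2, VF = 4; // example factors
  for (unsigned Part = 0; Part < UF; ++Part)   // State->Instance->Part
    for (unsigned Lane = 0; Lane < VF; ++Lane) // State->Instance->Lane
      printf("execute region blocks for Part=%u Lane=%u\n", Part, Lane);
  return 0;
}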
-namespace { -class LoopVectorizationLegality; -class LoopVectorizationCostModel; -} // namespace - namespace llvm { +class LoopVectorizationLegality; +class LoopVectorizationCostModel; class BasicBlock; class DominatorTree; class InnerLoopVectorizer; @@ -60,6 +56,20 @@ class raw_ostream; class Value; class VPBasicBlock; class VPRegionBlock; +class VPlan; + +/// A range of powers-of-2 vectorization factors with fixed start and +/// adjustable end. The range includes start and excludes end, e.g.,: +/// [1, 9) = {1, 2, 4, 8} +struct VFRange { + // A power of 2. + const unsigned Start; + + // Need not be a power of 2. If End <= Start range is empty. + unsigned End; +}; + +using VPlanPtr = std::unique_ptr<VPlan>; /// In what follows, the term "input IR" refers to code that is fed into the /// vectorizer whereas the term "output IR" refers to code that is generated by @@ -311,6 +321,8 @@ struct VPTransformState { /// VPBlockBase is the building block of the Hierarchical Control-Flow Graph. /// A VPBlockBase can be either a VPBasicBlock or a VPRegionBlock. class VPBlockBase { + friend class VPBlockUtils; + private: const unsigned char SubclassID; ///< Subclass identifier (for isa/dyn_cast). @@ -327,6 +339,9 @@ private: /// List of successor blocks. SmallVector<VPBlockBase *, 1> Successors; + /// Successor selector, null for zero or single successor blocks. + VPValue *CondBit = nullptr; + /// Add \p Successor as the last successor to this block. void appendSuccessor(VPBlockBase *Successor) { assert(Successor && "Cannot add nullptr successor!"); @@ -377,6 +392,7 @@ public: /// for any other purpose, as the values may change as LLVM evolves. unsigned getVPBlockID() const { return SubclassID; } + VPRegionBlock *getParent() { return Parent; } const VPRegionBlock *getParent() const { return Parent; } void setParent(VPRegionBlock *P) { Parent = P; } @@ -411,6 +427,9 @@ public: return (Predecessors.size() == 1 ? *Predecessors.begin() : nullptr); } + size_t getNumSuccessors() const { return Successors.size(); } + size_t getNumPredecessors() const { return Predecessors.size(); } + /// An Enclosing Block of a block B is any block containing B, including B /// itself. \return the closest enclosing block starting from "this", which /// has successors. \return the root enclosing block if all enclosing blocks @@ -454,34 +473,41 @@ public: return getEnclosingBlockWithPredecessors()->getSinglePredecessor(); } - /// Sets a given VPBlockBase \p Successor as the single successor and \return - /// \p Successor. The parent of this Block is copied to be the parent of - /// \p Successor. - VPBlockBase *setOneSuccessor(VPBlockBase *Successor) { + /// \return the condition bit selecting the successor. + VPValue *getCondBit() { return CondBit; } + + const VPValue *getCondBit() const { return CondBit; } + + void setCondBit(VPValue *CV) { CondBit = CV; } + + /// Set a given VPBlockBase \p Successor as the single successor of this + /// VPBlockBase. This VPBlockBase is not added as predecessor of \p Successor. + /// This VPBlockBase must have no successors. + void setOneSuccessor(VPBlockBase *Successor) { assert(Successors.empty() && "Setting one successor when others exist."); appendSuccessor(Successor); - Successor->appendPredecessor(this); - Successor->Parent = Parent; - return Successor; } - /// Sets two given VPBlockBases \p IfTrue and \p IfFalse to be the two - /// successors. The parent of this Block is copied to be the parent of both - /// \p IfTrue and \p IfFalse. 
- void setTwoSuccessors(VPBlockBase *IfTrue, VPBlockBase *IfFalse) { + /// Set two given VPBlockBases \p IfTrue and \p IfFalse to be the two + /// successors of this VPBlockBase. \p Condition is set as the successor + /// selector. This VPBlockBase is not added as predecessor of \p IfTrue or \p + /// IfFalse. This VPBlockBase must have no successors. + void setTwoSuccessors(VPBlockBase *IfTrue, VPBlockBase *IfFalse, + VPValue *Condition) { assert(Successors.empty() && "Setting two successors when others exist."); + assert(Condition && "Setting two successors without condition!"); + CondBit = Condition; appendSuccessor(IfTrue); appendSuccessor(IfFalse); - IfTrue->appendPredecessor(this); - IfFalse->appendPredecessor(this); - IfTrue->Parent = Parent; - IfFalse->Parent = Parent; } - void disconnectSuccessor(VPBlockBase *Successor) { - assert(Successor && "Successor to disconnect is null."); - removeSuccessor(Successor); - Successor->removePredecessor(this); + /// Set each VPBasicBlock in \p NewPreds as predecessor of this VPBlockBase. + /// This VPBlockBase must have no predecessors. This VPBlockBase is not added + /// as successor of any VPBasicBlock in \p NewPreds. + void setPredecessors(ArrayRef<VPBlockBase *> NewPreds) { + assert(Predecessors.empty() && "Block predecessors already set."); + for (auto *Pred : NewPreds) + appendPredecessor(Pred); } /// The method which generates the output IR that correspond to this @@ -539,6 +565,15 @@ public: /// Each recipe prints itself. virtual void print(raw_ostream &O, const Twine &Indent) const = 0; + + /// Insert an unlinked recipe into a basic block immediately before + /// the specified recipe. + void insertBefore(VPRecipeBase *InsertPos); + + /// This method unlinks 'this' from the containing basic block and deletes it. + /// + /// \returns an iterator pointing to the element after the erased one + iplist<VPRecipeBase>::iterator eraseFromParent(); }; /// This is a concrete Recipe that models a single VPlan-level instruction. @@ -546,6 +581,8 @@ public: /// executed, these instructions would always form a single-def expression as /// the VPInstruction is also a single def-use vertex. class VPInstruction : public VPUser, public VPRecipeBase { + friend class VPlanHCFGTransforms; + public: /// VPlan opcodes, extending LLVM IR with idiomatics instructions. enum { Not = Instruction::OtherOpsEnd + 1 }; @@ -559,10 +596,13 @@ private: void generateInstruction(VPTransformState &State, unsigned Part); public: - VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands) + VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands) : VPUser(VPValue::VPInstructionSC, Operands), VPRecipeBase(VPRecipeBase::VPInstructionSC), Opcode(Opcode) {} + VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands) + : VPInstruction(Opcode, ArrayRef<VPValue *>(Operands)) {} + /// Method to support type inquiry through isa, cast, and dyn_cast. static inline bool classof(const VPValue *V) { return V->getVPValueID() == VPValue::VPInstructionSC; @@ -907,7 +947,10 @@ public: inline const VPRecipeBase &back() const { return Recipes.back(); } inline VPRecipeBase &back() { return Recipes.back(); } - /// \brief Returns a pointer to a member of the recipe list. + /// Returns a reference to the list of recipes. + RecipeListTy &getRecipeList() { return Recipes; } + + /// Returns a pointer to a member of the recipe list. 
static RecipeListTy VPBasicBlock::*getSublistAccess(VPRecipeBase *) { return &VPBasicBlock::Recipes; } @@ -968,6 +1011,9 @@ public: Entry->setParent(this); Exit->setParent(this); } + VPRegionBlock(const std::string &Name = "", bool IsReplicator = false) + : VPBlockBase(VPRegionBlockSC, Name), Entry(nullptr), Exit(nullptr), + IsReplicator(IsReplicator) {} ~VPRegionBlock() override { if (Entry) @@ -982,9 +1028,27 @@ public: const VPBlockBase *getEntry() const { return Entry; } VPBlockBase *getEntry() { return Entry; } + /// Set \p EntryBlock as the entry VPBlockBase of this VPRegionBlock. \p + /// EntryBlock must have no predecessors. + void setEntry(VPBlockBase *EntryBlock) { + assert(EntryBlock->getPredecessors().empty() && + "Entry block cannot have predecessors."); + Entry = EntryBlock; + EntryBlock->setParent(this); + } + const VPBlockBase *getExit() const { return Exit; } VPBlockBase *getExit() { return Exit; } + /// Set \p ExitBlock as the exit VPBlockBase of this VPRegionBlock. \p + /// ExitBlock must have no successors. + void setExit(VPBlockBase *ExitBlock) { + assert(ExitBlock->getSuccessors().empty() && + "Exit block cannot have successors."); + Exit = ExitBlock; + ExitBlock->setParent(this); + } + /// An indicator whether this region is to generate multiple replicated /// instances of output IR corresponding to its VPBlockBases. bool isReplicator() const { return IsReplicator; } @@ -1012,6 +1076,13 @@ private: /// Holds the name of the VPlan, for printing. std::string Name; + /// Holds all the external definitions created for this VPlan. + // TODO: Introduce a specific representation for external definitions in + // VPlan. External definitions must be immutable and hold a pointer to its + // underlying IR that will be used to implement its structural comparison + // (operators '==' and '<'). + SmallPtrSet<VPValue *, 16> VPExternalDefs; + /// Holds a mapping between Values and their corresponding VPValue inside /// VPlan. Value2VPValueTy Value2VPValue; @@ -1024,6 +1095,8 @@ public: VPBlockBase::deleteCFG(Entry); for (auto &MapEntry : Value2VPValue) delete MapEntry.second; + for (VPValue *Def : VPExternalDefs) + delete Def; } /// Generate the IR code for this VPlan. @@ -1042,6 +1115,12 @@ public: void setName(const Twine &newName) { Name = newName.str(); } + /// Add \p VPVal to the pool of external definitions if it's not already + /// in the pool. + void addExternalDef(VPValue *VPVal) { + VPExternalDefs.insert(VPVal); + } + void addVPValue(Value *V) { assert(V && "Trying to add a null Value to VPlan"); assert(!Value2VPValue.count(V) && "Value already exists in VPlan"); @@ -1189,6 +1268,72 @@ template <> struct GraphTraits<Inverse<VPBlockBase *>> { } }; +//===----------------------------------------------------------------------===// +// VPlan Utilities +//===----------------------------------------------------------------------===// + +/// Class that provides utilities for VPBlockBases in VPlan. +class VPBlockUtils { +public: + VPBlockUtils() = delete; + + /// Insert disconnected VPBlockBase \p NewBlock after \p BlockPtr. Add \p + /// NewBlock as successor of \p BlockPtr and \p BlockPtr as predecessor of \p + /// NewBlock, and propagate \p BlockPtr parent to \p NewBlock. If \p BlockPtr + /// has more than one successor, its conditional bit is propagated to \p + /// NewBlock. \p NewBlock must have neither successors nor predecessors. 
+ static void insertBlockAfter(VPBlockBase *NewBlock, VPBlockBase *BlockPtr) { + assert(NewBlock->getSuccessors().empty() && + "Can't insert new block with successors."); + // TODO: move successors from BlockPtr to NewBlock when this functionality + // is necessary. For now, setBlockSingleSuccessor will assert if BlockPtr + // already has successors. + BlockPtr->setOneSuccessor(NewBlock); + NewBlock->setPredecessors({BlockPtr}); + NewBlock->setParent(BlockPtr->getParent()); + } + + /// Insert disconnected VPBlockBases \p IfTrue and \p IfFalse after \p + /// BlockPtr. Add \p IfTrue and \p IfFalse as successors of \p BlockPtr and \p + /// BlockPtr as predecessor of \p IfTrue and \p IfFalse. Propagate \p BlockPtr + /// parent to \p IfTrue and \p IfFalse. \p Condition is set as the successor + /// selector. \p BlockPtr must have no successors and \p IfTrue and \p IfFalse + /// must have neither successors nor predecessors. + static void insertTwoBlocksAfter(VPBlockBase *IfTrue, VPBlockBase *IfFalse, + VPValue *Condition, VPBlockBase *BlockPtr) { + assert(IfTrue->getSuccessors().empty() && + "Can't insert IfTrue with successors."); + assert(IfFalse->getSuccessors().empty() && + "Can't insert IfFalse with successors."); + BlockPtr->setTwoSuccessors(IfTrue, IfFalse, Condition); + IfTrue->setPredecessors({BlockPtr}); + IfFalse->setPredecessors({BlockPtr}); + IfTrue->setParent(BlockPtr->getParent()); + IfFalse->setParent(BlockPtr->getParent()); + } + + /// Connect VPBlockBases \p From and \p To bi-directionally. Append \p To to + /// the successors of \p From and \p From to the predecessors of \p To. Both + /// VPBlockBases must have the same parent, which can be null. Both + /// VPBlockBases can be already connected to other VPBlockBases. + static void connectBlocks(VPBlockBase *From, VPBlockBase *To) { + assert((From->getParent() == To->getParent()) && + "Can't connect two blocks with different parents"); + assert(From->getNumSuccessors() < 2 && + "Blocks can't have more than two successors."); + From->appendSuccessor(To); + To->appendPredecessor(From); + } + + /// Disconnect VPBlockBases \p From and \p To bi-directionally. Remove \p To + /// from the successors of \p From and \p From from the predecessors of \p To. + static void disconnectBlocks(VPBlockBase *From, VPBlockBase *To) { + assert(To && "Successor to disconnect is null."); + From->removeSuccessor(To); + To->removePredecessor(From); + } +}; + } // end namespace llvm #endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_H diff --git a/lib/Transforms/Vectorize/VPlanBuilder.h b/lib/Transforms/Vectorize/VPlanBuilder.h deleted file mode 100644 index d6eb3397d044..000000000000 --- a/lib/Transforms/Vectorize/VPlanBuilder.h +++ /dev/null @@ -1,61 +0,0 @@ -//===- VPlanBuilder.h - A VPlan utility for constructing VPInstructions ---===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file provides a VPlan-based builder utility analogous to IRBuilder. -/// It provides an instruction-level API for generating VPInstructions while -/// abstracting away the Recipe manipulation details.
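Taken together, the VPBlockUtils helpers above are enough to stitch small CFG shapes into a plan. A hedged usage sketch that hangs a diamond below an existing block, assuming the VPBasicBlock(name) constructor and the setParent accessor declared earlier in VPlan.h (illustration only, not code from the patch):

#include "VPlan.h"
using namespace llvm;

static void buildDiamond(VPBlockBase *Top, VPValue *Cond) {
  VPBasicBlock *Then = new VPBasicBlock("then");
  VPBasicBlock *Else = new VPBasicBlock("else");
  VPBasicBlock *Join = new VPBasicBlock("join");
  // Top gains two successors selected by Cond; parents are propagated.
  VPBlockUtils::insertTwoBlocksAfter(Then, Else, Cond, Top);
  // connectBlocks requires both blocks to share a parent already.
  Join->setParent(Top->getParent());
  VPBlockUtils::connectBlocks(Then, Join);
  VPBlockUtils::connectBlocks(Else, Join);
}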
-//===----------------------------------------------------------------------===// - -#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLAN_BUILDER_H -#define LLVM_TRANSFORMS_VECTORIZE_VPLAN_BUILDER_H - -#include "VPlan.h" - -namespace llvm { - -class VPBuilder { -private: - VPBasicBlock *BB = nullptr; - VPBasicBlock::iterator InsertPt = VPBasicBlock::iterator(); - - VPInstruction *createInstruction(unsigned Opcode, - std::initializer_list<VPValue *> Operands) { - VPInstruction *Instr = new VPInstruction(Opcode, Operands); - BB->insert(Instr, InsertPt); - return Instr; - } - -public: - VPBuilder() {} - - /// \brief This specifies that created VPInstructions should be appended to - /// the end of the specified block. - void setInsertPoint(VPBasicBlock *TheBB) { - assert(TheBB && "Attempting to set a null insert point"); - BB = TheBB; - InsertPt = BB->end(); - } - - VPValue *createNot(VPValue *Operand) { - return createInstruction(VPInstruction::Not, {Operand}); - } - - VPValue *createAnd(VPValue *LHS, VPValue *RHS) { - return createInstruction(Instruction::BinaryOps::And, {LHS, RHS}); - } - - VPValue *createOr(VPValue *LHS, VPValue *RHS) { - return createInstruction(Instruction::BinaryOps::Or, {LHS, RHS}); - } -}; - -} // namespace llvm - -#endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_BUILDER_H diff --git a/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp b/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp new file mode 100644 index 000000000000..08129b74cddf --- /dev/null +++ b/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp @@ -0,0 +1,336 @@ +//===-- VPlanHCFGBuilder.cpp ----------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements the construction of a VPlan-based Hierarchical CFG +/// (H-CFG) for an incoming IR. This construction comprises the following +/// components and steps: +// +/// 1. PlainCFGBuilder class: builds a plain VPBasicBlock-based CFG that +/// faithfully represents the CFG in the incoming IR. A VPRegionBlock (Top +/// Region) is created to enclose and serve as parent of all the VPBasicBlocks +/// in the plain CFG. +/// NOTE: At this point, there is a direct correspondence between all the +/// VPBasicBlocks created for the initial plain CFG and the incoming +/// BasicBlocks. However, this might change in the future. +/// +//===----------------------------------------------------------------------===// + +#include "VPlanHCFGBuilder.h" +#include "LoopVectorizationPlanner.h" +#include "llvm/Analysis/LoopIterator.h" + +#define DEBUG_TYPE "loop-vectorize" + +using namespace llvm; + +namespace { +// Class that is used to build the plain CFG for the incoming IR. +class PlainCFGBuilder { +private: + // The outermost loop of the input loop nest considered for vectorization. + Loop *TheLoop; + + // Loop Info analysis. + LoopInfo *LI; + + // Vectorization plan that we are working on. + VPlan &Plan; + + // Output Top Region. + VPRegionBlock *TopRegion = nullptr; + + // Builder of the VPlan instruction-level representation. + VPBuilder VPIRBuilder; + + // NOTE: The following maps are intentionally destroyed after the plain CFG + // construction because subsequent VPlan-to-VPlan transformation may + // invalidate them. + // Map incoming BasicBlocks to their newly-created VPBasicBlocks. 
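The two maps above implement a standard two-phase translation: nodes are created on first reference, while phi operands are deferred (PhisToFix) until every block exists. A standalone model of the getOrCreate idiom that getOrCreateVPBB below follows (the types are stand-ins):

#include <map>
#include <memory>
#include <string>

struct VPBB {
  std::string Name;
};

struct Builder {
  std::map<std::string, std::unique_ptr<VPBB>> BB2VPBB;

  VPBB *getOrCreate(const std::string &BB) {
    auto It = BB2VPBB.find(BB);
    if (It != BB2VPBB.end())
      return It->second.get(); // created earlier, on a prior reference
    auto New = std::make_unique<VPBB>();
    New->Name = BB;
    VPBB *Raw = New.get();
    BB2VPBB[BB] = std::move(New);
    return Raw;
  }
};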
+ DenseMap<BasicBlock *, VPBasicBlock *> BB2VPBB; + // Map incoming Value definitions to their newly-created VPValues. + DenseMap<Value *, VPValue *> IRDef2VPValue; + + // Hold phi nodes that need to be fixed once the plain CFG has been built. + SmallVector<PHINode *, 8> PhisToFix; + + // Utility functions. + void setVPBBPredsFromBB(VPBasicBlock *VPBB, BasicBlock *BB); + void fixPhiNodes(); + VPBasicBlock *getOrCreateVPBB(BasicBlock *BB); + bool isExternalDef(Value *Val); + VPValue *getOrCreateVPOperand(Value *IRVal); + void createVPInstructionsForVPBB(VPBasicBlock *VPBB, BasicBlock *BB); + +public: + PlainCFGBuilder(Loop *Lp, LoopInfo *LI, VPlan &P) + : TheLoop(Lp), LI(LI), Plan(P) {} + + // Build the plain CFG and return its Top Region. + VPRegionBlock *buildPlainCFG(); +}; +} // anonymous namespace + +// Set predecessors of \p VPBB in the same order as they are in \p BB. \p VPBB +// must have no predecessors. +void PlainCFGBuilder::setVPBBPredsFromBB(VPBasicBlock *VPBB, BasicBlock *BB) { + SmallVector<VPBlockBase *, 8> VPBBPreds; + // Collect VPBB predecessors. + for (BasicBlock *Pred : predecessors(BB)) + VPBBPreds.push_back(getOrCreateVPBB(Pred)); + + VPBB->setPredecessors(VPBBPreds); +} + +// Add operands to VPInstructions representing phi nodes from the input IR. +void PlainCFGBuilder::fixPhiNodes() { + for (auto *Phi : PhisToFix) { + assert(IRDef2VPValue.count(Phi) && "Missing VPInstruction for PHINode."); + VPValue *VPVal = IRDef2VPValue[Phi]; + assert(isa<VPInstruction>(VPVal) && "Expected VPInstruction for phi node."); + auto *VPPhi = cast<VPInstruction>(VPVal); + assert(VPPhi->getNumOperands() == 0 && + "Expected VPInstruction with no operands."); + + for (Value *Op : Phi->operands()) + VPPhi->addOperand(getOrCreateVPOperand(Op)); + } +} + +// Create a new empty VPBasicBlock for an incoming BasicBlock or retrieve an +// existing one if it was already created. +VPBasicBlock *PlainCFGBuilder::getOrCreateVPBB(BasicBlock *BB) { + auto BlockIt = BB2VPBB.find(BB); + if (BlockIt != BB2VPBB.end()) + // Retrieve existing VPBB. + return BlockIt->second; + + // Create new VPBB. + LLVM_DEBUG(dbgs() << "Creating VPBasicBlock for " << BB->getName() << "\n"); + VPBasicBlock *VPBB = new VPBasicBlock(BB->getName()); + BB2VPBB[BB] = VPBB; + VPBB->setParent(TopRegion); + return VPBB; +} + +// Return true if \p Val is considered an external definition. An external +// definition is either: +// 1. A Value that is not an Instruction. This will be refined in the future. +// 2. An Instruction that is outside of the CFG snippet represented in VPlan, +// i.e., is not part of: a) the loop nest, b) outermost loop PH, and c) +// outermost loop exits. +bool PlainCFGBuilder::isExternalDef(Value *Val) { + // All the Values that are not Instructions are considered external + // definitions for now. + Instruction *Inst = dyn_cast<Instruction>(Val); + if (!Inst) + return true; + + BasicBlock *InstParent = Inst->getParent(); + assert(InstParent && "Expected instruction parent."); + + // Check whether Instruction definition is in loop PH. + BasicBlock *PH = TheLoop->getLoopPreheader(); + assert(PH && "Expected loop pre-header."); + + if (InstParent == PH) + // Instruction definition is in outermost loop PH. + return false; + + // Check whether Instruction definition is in the loop exit. + BasicBlock *Exit = TheLoop->getUniqueExitBlock(); + assert(Exit && "Expected loop with single exit."); + if (InstParent == Exit) { + // Instruction definition is in outermost loop exit.
+
+// Create a new VPValue or retrieve an existing one for the Instruction's
+// operand \p IRVal. This function must only be used to create/retrieve
+// VPValues for *Instruction's operands* and not to create regular
+// VPInstructions. For the latter, see 'createVPInstructionsForVPBB'.
+VPValue *PlainCFGBuilder::getOrCreateVPOperand(Value *IRVal) {
+  auto VPValIt = IRDef2VPValue.find(IRVal);
+  if (VPValIt != IRDef2VPValue.end())
+    // Operand has an associated VPInstruction or VPValue that was previously
+    // created.
+    return VPValIt->second;
+
+  // Operand doesn't have a previously created VPInstruction/VPValue. This
+  // means that operand is:
+  //   A) a definition external to VPlan, or
+  //   B) any other Value without specific representation in VPlan.
+  // For now, we use VPValue to represent A and B and classify both as
+  // external definitions. We may introduce specific VPValue subclasses for
+  // them in the future.
+  assert(isExternalDef(IRVal) && "Expected external definition as operand.");
+
+  // A and B: Create VPValue and add it to the pool of external definitions
+  // and to the Value->VPValue map.
+  VPValue *NewVPVal = new VPValue(IRVal);
+  Plan.addExternalDef(NewVPVal);
+  IRDef2VPValue[IRVal] = NewVPVal;
+  return NewVPVal;
+}
+
+// Create new VPInstructions in a VPBasicBlock, given its BasicBlock
+// counterpart. This function must be invoked in RPO so that the operands of a
+// VPInstruction in \p BB have been visited before (except for Phi nodes).
+void PlainCFGBuilder::createVPInstructionsForVPBB(VPBasicBlock *VPBB,
+                                                  BasicBlock *BB) {
+  VPIRBuilder.setInsertPoint(VPBB);
+  for (Instruction &InstRef : *BB) {
+    Instruction *Inst = &InstRef;
+
+    // There shouldn't be any VPValue for Inst at this point. Otherwise, we
+    // visited Inst when we shouldn't, breaking the RPO traversal order.
+    assert(!IRDef2VPValue.count(Inst) &&
+           "Instruction shouldn't have been visited.");
+
+    if (auto *Br = dyn_cast<BranchInst>(Inst)) {
+      // Branch instruction is not explicitly represented in VPlan but we need
+      // to represent its condition bit when it's conditional.
+      if (Br->isConditional())
+        getOrCreateVPOperand(Br->getCondition());
+
+      // Skip the rest of the Instruction processing for Branch instructions.
+      continue;
+    }
+
+    VPInstruction *NewVPInst;
+    if (auto *Phi = dyn_cast<PHINode>(Inst)) {
+      // Phi node's operands may not have been visited at this point. We
+      // create an empty VPInstruction that we will fix once the whole plain
+      // CFG has been built.
+      NewVPInst = cast<VPInstruction>(VPIRBuilder.createNaryOp(
+          Inst->getOpcode(), {} /*No operands*/, Inst));
+      PhisToFix.push_back(Phi);
+    } else {
+      // Translate LLVM-IR operands into VPValue operands and set them in the
+      // new VPInstruction.
+      SmallVector<VPValue *, 4> VPOperands;
+      for (Value *Op : Inst->operands())
+        VPOperands.push_back(getOrCreateVPOperand(Op));
+
+      // Build VPInstruction for any arbitrary Instruction without specific
+      // representation in VPlan.
+      NewVPInst = cast<VPInstruction>(
+          VPIRBuilder.createNaryOp(Inst->getOpcode(), VPOperands, Inst));
+    }
+
+    IRDef2VPValue[Inst] = NewVPInst;
+  }
+}
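+
+// Continuing the illustrative sketch above (not from the original file): when
+// %header is visited in RPO, the back-edge operand %inc of %iv has no VPValue
+// yet, so the phi is materialized with no operands and queued for a later
+// fix-up:
+//
+//   NewVPInst = cast<VPInstruction>(VPIRBuilder.createNaryOp(
+//       Inst->getOpcode(), {} /*No operands*/, Inst)); // opcode is PHI here
+//   PhisToFix.push_back(Phi);
+//
+// Once the whole plain CFG exists, fixPhiNodes() requests VPValues for the
+// incoming values: %inc already has a VPInstruction by then, while the
+// constant 0 becomes an external-definition VPValue.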
+
+// Main interface to build the plain CFG.
+VPRegionBlock *PlainCFGBuilder::buildPlainCFG() {
+  // 1. Create the Top Region. It will be the parent of all VPBBs.
+  TopRegion = new VPRegionBlock("TopRegion", false /*isReplicator*/);
+
+  // 2. Scan the body of the loop in a topological order to visit each basic
+  // block after having visited its predecessor basic blocks. Create a VPBB
+  // for each BB and link it to its successor and predecessor VPBBs. Note that
+  // predecessors must be set in the same order as they are in the incoming
+  // IR. Otherwise, there might be problems with existing phi nodes and
+  // algorithms based on predecessor traversal.
+
+  // Loop PH needs to be explicitly visited since it's not taken into account
+  // by LoopBlocksDFS.
+  BasicBlock *PreheaderBB = TheLoop->getLoopPreheader();
+  assert((PreheaderBB->getTerminator()->getNumSuccessors() == 1) &&
+         "Unexpected loop preheader");
+  VPBasicBlock *PreheaderVPBB = getOrCreateVPBB(PreheaderBB);
+  createVPInstructionsForVPBB(PreheaderVPBB, PreheaderBB);
+  // Create empty VPBB for Loop H so that we can link PH->H.
+  VPBlockBase *HeaderVPBB = getOrCreateVPBB(TheLoop->getHeader());
+  // Preheader's predecessors will be set during the loop RPO traversal below.
+  PreheaderVPBB->setOneSuccessor(HeaderVPBB);
+
+  LoopBlocksRPO RPO(TheLoop);
+  RPO.perform(LI);
+
+  for (BasicBlock *BB : RPO) {
+    // Create or retrieve the VPBasicBlock for this BB and create its
+    // VPInstructions.
+    VPBasicBlock *VPBB = getOrCreateVPBB(BB);
+    createVPInstructionsForVPBB(VPBB, BB);
+
+    // Set VPBB successors. We create empty VPBBs for successors if they don't
+    // exist already. Recipes will be created when the successor is visited
+    // during the RPO traversal.
+    TerminatorInst *TI = BB->getTerminator();
+    assert(TI && "Terminator expected.");
+    unsigned NumSuccs = TI->getNumSuccessors();
+
+    if (NumSuccs == 1) {
+      VPBasicBlock *SuccVPBB = getOrCreateVPBB(TI->getSuccessor(0));
+      assert(SuccVPBB && "VPBB Successor not found.");
+      VPBB->setOneSuccessor(SuccVPBB);
+    } else if (NumSuccs == 2) {
+      VPBasicBlock *SuccVPBB0 = getOrCreateVPBB(TI->getSuccessor(0));
+      assert(SuccVPBB0 && "Successor 0 not found.");
+      VPBasicBlock *SuccVPBB1 = getOrCreateVPBB(TI->getSuccessor(1));
+      assert(SuccVPBB1 && "Successor 1 not found.");
+
+      // Get VPBB's condition bit.
+      assert(isa<BranchInst>(TI) && "Unsupported terminator!");
+      auto *Br = cast<BranchInst>(TI);
+      Value *BrCond = Br->getCondition();
+      // Look up the branch condition to get the corresponding VPValue
+      // representing the condition bit in VPlan (which may be in another
+      // VPBB).
+      assert(IRDef2VPValue.count(BrCond) &&
+             "Missing condition bit in IRDef2VPValue!");
+      VPValue *VPCondBit = IRDef2VPValue[BrCond];
+
+      // Link successors using condition bit.
+      VPBB->setTwoSuccessors(SuccVPBB0, SuccVPBB1, VPCondBit);
+    } else
+      llvm_unreachable("Number of successors not supported.");
+
+    // Set VPBB predecessors in the same order as they are in the incoming BB.
+    setVPBBPredsFromBB(VPBB, BB);
+  }
+
+  // 3. Process outermost loop exit. We created an empty VPBB for the loop
+  // single exit BB during the RPO traversal of the loop body but Instructions
+  // weren't visited because it's not part of the loop.
+  BasicBlock *LoopExitBB = TheLoop->getUniqueExitBlock();
+  assert(LoopExitBB && "Loops with multiple exits are not supported.");
+  VPBasicBlock *LoopExitVPBB = BB2VPBB[LoopExitBB];
+  createVPInstructionsForVPBB(LoopExitVPBB, LoopExitBB);
+  // Loop exit was already set as successor of the loop exiting BB.
+  // We only set its predecessor VPBB now.
+  setVPBBPredsFromBB(LoopExitVPBB, LoopExitBB);
+
+  // 4. The whole CFG has been built at this point so all the input Values
+  // must have a VPlan counterpart. Fix VPlan phi nodes by adding their
+  // corresponding VPlan operands.
+  fixPhiNodes();
+
+  // 5. Final Top Region setup. Set outermost loop pre-header and single exit
+  // as Top Region entry and exit.
+  TopRegion->setEntry(PreheaderVPBB);
+  TopRegion->setExit(LoopExitVPBB);
+  return TopRegion;
+}
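+
+// For the illustrative sketch above (not from the original file), the plain
+// CFG mirrors the input CFG one-to-one:
+//
+//   TopRegion: entry = ph, exit = exit
+//     ph     -> header
+//     header -> header, exit  ; two successors, condition bit = %cmp's VPValue
+//     exit
+//
+// where entry and exit are only set in step 5, after all blocks are linked.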
+
+// Public interface to build an H-CFG.
+void VPlanHCFGBuilder::buildHierarchicalCFG(VPlan &Plan) {
+  // Build Top Region enclosing the plain CFG and set it as VPlan entry.
+  PlainCFGBuilder PCFGBuilder(TheLoop, LI, Plan);
+  VPRegionBlock *TopRegion = PCFGBuilder.buildPlainCFG();
+  Plan.setEntry(TopRegion);
+  LLVM_DEBUG(Plan.setName("HCFGBuilder: Plain CFG\n"); dbgs() << Plan);
+
+  Verifier.verifyHierarchicalCFG(TopRegion);
+}
diff --git a/lib/Transforms/Vectorize/VPlanHCFGBuilder.h b/lib/Transforms/Vectorize/VPlanHCFGBuilder.h
new file mode 100644
index 000000000000..c4e69843615a
--- /dev/null
+++ b/lib/Transforms/Vectorize/VPlanHCFGBuilder.h
@@ -0,0 +1,55 @@
+//===-- VPlanHCFGBuilder.h --------------------------------------*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines the VPlanHCFGBuilder class, which contains the public
+/// interface (buildHierarchicalCFG) to build a VPlan-based Hierarchical CFG
+/// (H-CFG) for an incoming IR.
+///
+/// An H-CFG in VPlan is a control-flow graph whose nodes are VPBasicBlocks
+/// and/or VPRegionBlocks (i.e., other H-CFGs). The outermost H-CFG of a VPlan
+/// consists of a VPRegionBlock, denoted Top Region, which encloses any other
+/// VPBlockBase in the H-CFG. This guarantees that any VPBlockBase in the H-CFG
+/// other than the Top Region will have a parent VPRegionBlock, and allows us
+/// to easily add more nodes before/after the main vector loop (such as the
+/// reduction epilogue).
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLAN_VPLANHCFGBUILDER_H
+#define LLVM_TRANSFORMS_VECTORIZE_VPLAN_VPLANHCFGBUILDER_H
+
+#include "VPlan.h"
+#include "VPlanVerifier.h"
+
+namespace llvm {
+
+class Loop;
+
+/// Main class to build the VPlan H-CFG for an incoming IR.
+class VPlanHCFGBuilder {
+private:
+  // The outermost loop of the input loop nest considered for vectorization.
+  Loop *TheLoop;
+
+  // Loop Info analysis.
+  LoopInfo *LI;
+
+  // VPlan verifier utility.
+  VPlanVerifier Verifier;
+
+public:
+  VPlanHCFGBuilder(Loop *Lp, LoopInfo *LI) : TheLoop(Lp), LI(LI) {}
+
+  /// Build H-CFG for TheLoop and update \p Plan accordingly.
+  void buildHierarchicalCFG(VPlan &Plan);
+};
+} // namespace llvm
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_VPLANHCFGBUILDER_H
diff --git a/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp b/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp
new file mode 100644
index 000000000000..e3cbab077e61
--- /dev/null
+++ b/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp
@@ -0,0 +1,73 @@
+//===-- VPlanHCFGTransforms.cpp - Utility VPlan to VPlan transforms -------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file implements a set of utility VPlan to VPlan transformations.
+/// +//===----------------------------------------------------------------------===// + +#include "VPlanHCFGTransforms.h" +#include "llvm/ADT/PostOrderIterator.h" + +using namespace llvm; + +void VPlanHCFGTransforms::VPInstructionsToVPRecipes( + VPlanPtr &Plan, + LoopVectorizationLegality::InductionList *Inductions, + SmallPtrSetImpl<Instruction *> &DeadInstructions) { + + VPRegionBlock *TopRegion = dyn_cast<VPRegionBlock>(Plan->getEntry()); + ReversePostOrderTraversal<VPBlockBase *> RPOT(TopRegion->getEntry()); + for (VPBlockBase *Base : RPOT) { + // Do not widen instructions in pre-header and exit blocks. + if (Base->getNumPredecessors() == 0 || Base->getNumSuccessors() == 0) + continue; + + VPBasicBlock *VPBB = Base->getEntryBasicBlock(); + VPRecipeBase *LastRecipe = nullptr; + // Introduce each ingredient into VPlan. + for (auto I = VPBB->begin(), E = VPBB->end(); I != E;) { + VPRecipeBase *Ingredient = &*I++; + // Can only handle VPInstructions. + VPInstruction *VPInst = cast<VPInstruction>(Ingredient); + Instruction *Inst = cast<Instruction>(VPInst->getUnderlyingValue()); + if (DeadInstructions.count(Inst)) { + Ingredient->eraseFromParent(); + continue; + } + + VPRecipeBase *NewRecipe = nullptr; + // Create VPWidenMemoryInstructionRecipe for loads and stores. + if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst)) + NewRecipe = new VPWidenMemoryInstructionRecipe(*Inst, nullptr /*Mask*/); + else if (PHINode *Phi = dyn_cast<PHINode>(Inst)) { + InductionDescriptor II = Inductions->lookup(Phi); + if (II.getKind() == InductionDescriptor::IK_IntInduction || + II.getKind() == InductionDescriptor::IK_FpInduction) { + NewRecipe = new VPWidenIntOrFpInductionRecipe(Phi); + } else + NewRecipe = new VPWidenPHIRecipe(Phi); + } else { + // If the last recipe is a VPWidenRecipe, add Inst to it instead of + // creating a new recipe. + if (VPWidenRecipe *WidenRecipe = + dyn_cast_or_null<VPWidenRecipe>(LastRecipe)) { + WidenRecipe->appendInstruction(Inst); + Ingredient->eraseFromParent(); + continue; + } + NewRecipe = new VPWidenRecipe(Inst); + } + + NewRecipe->insertBefore(Ingredient); + LastRecipe = NewRecipe; + Ingredient->eraseFromParent(); + } + } +} diff --git a/lib/Transforms/Vectorize/VPlanHCFGTransforms.h b/lib/Transforms/Vectorize/VPlanHCFGTransforms.h new file mode 100644 index 000000000000..ae549c6871b3 --- /dev/null +++ b/lib/Transforms/Vectorize/VPlanHCFGTransforms.h @@ -0,0 +1,36 @@ +//===- VPlanHCFGTransforms.h - Utility VPlan to VPlan transforms ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file provides utility VPlan to VPlan transformations. +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLANHCFGTRANSFORMS_H +#define LLVM_TRANSFORMS_VECTORIZE_VPLANHCFGTRANSFORMS_H + +#include "VPlan.h" +#include "llvm/IR/Instruction.h" +#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h" + +namespace llvm { + +class VPlanHCFGTransforms { + +public: + /// Replaces the VPInstructions in \p Plan with corresponding + /// widen recipes. 
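+  ///
+  /// A minimal invocation sketch (illustrative only; assumes a built VPlan
+  /// \p Plan and a completed LoopVectorizationLegality analysis 'Legal'):
+  /// \code
+  ///   SmallPtrSet<Instruction *, 4> DeadInstructions;
+  ///   VPlanHCFGTransforms::VPInstructionsToVPRecipes(
+  ///       Plan, Legal->getInductionVars(), DeadInstructions);
+  /// \endcode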
+ static void VPInstructionsToVPRecipes( + VPlanPtr &Plan, + LoopVectorizationLegality::InductionList *Inductions, + SmallPtrSetImpl<Instruction *> &DeadInstructions); +}; + +} // namespace llvm + +#endif // LLVM_TRANSFORMS_VECTORIZE_VPLANHCFGTRANSFORMS_H diff --git a/lib/Transforms/Vectorize/VPlanValue.h b/lib/Transforms/Vectorize/VPlanValue.h index 50966891e0eb..08f142915b49 100644 --- a/lib/Transforms/Vectorize/VPlanValue.h +++ b/lib/Transforms/Vectorize/VPlanValue.h @@ -37,13 +37,34 @@ class VPUser; // coming from the input IR, instructions which VPlan will generate if executed // and live-outs which the VPlan will need to fix accordingly. class VPValue { + friend class VPBuilder; private: const unsigned char SubclassID; ///< Subclass identifier (for isa/dyn_cast). SmallVector<VPUser *, 1> Users; protected: - VPValue(const unsigned char SC) : SubclassID(SC) {} + // Hold the underlying Value, if any, attached to this VPValue. + Value *UnderlyingVal; + + VPValue(const unsigned char SC, Value *UV = nullptr) + : SubclassID(SC), UnderlyingVal(UV) {} + + // DESIGN PRINCIPLE: Access to the underlying IR must be strictly limited to + // the front-end and back-end of VPlan so that the middle-end is as + // independent as possible of the underlying IR. We grant access to the + // underlying IR using friendship. In that way, we should be able to use VPlan + // for multiple underlying IRs (Polly?) by providing a new VPlan front-end, + // back-end and analysis information for the new IR. + + /// Return the underlying Value attached to this VPValue. + Value *getUnderlyingValue() { return UnderlyingVal; } + + // Set \p Val as the underlying Value of this VPValue. + void setUnderlyingValue(Value *Val) { + assert(!UnderlyingVal && "Underlying Value is already set."); + UnderlyingVal = Val; + } public: /// An enumeration for keeping track of the concrete subclass of VPValue that @@ -52,7 +73,7 @@ public: /// type identification. enum { VPValueSC, VPUserSC, VPInstructionSC }; - VPValue() : SubclassID(VPValueSC) {} + VPValue(Value *UV = nullptr) : VPValue(VPValueSC, UV) {} VPValue(const VPValue &) = delete; VPValue &operator=(const VPValue &) = delete; @@ -94,11 +115,6 @@ class VPUser : public VPValue { private: SmallVector<VPValue *, 2> Operands; - void addOperand(VPValue *Operand) { - Operands.push_back(Operand); - Operand->addUser(*this); - } - protected: VPUser(const unsigned char SC) : VPValue(SC) {} VPUser(const unsigned char SC, ArrayRef<VPValue *> Operands) : VPValue(SC) { @@ -120,6 +136,11 @@ public: V->getVPValueID() <= VPInstructionSC; } + void addOperand(VPValue *Operand) { + Operands.push_back(Operand); + Operand->addUser(*this); + } + unsigned getNumOperands() const { return Operands.size(); } inline VPValue *getOperand(unsigned N) const { assert(N < Operands.size() && "Operand index out of bounds"); diff --git a/lib/Transforms/Vectorize/VPlanVerifier.cpp b/lib/Transforms/Vectorize/VPlanVerifier.cpp new file mode 100644 index 000000000000..054bed4e177f --- /dev/null +++ b/lib/Transforms/Vectorize/VPlanVerifier.cpp @@ -0,0 +1,133 @@ +//===-- VPlanVerifier.cpp -------------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines the class VPlanVerifier, which contains utility functions
+/// to check the consistency and invariants of a VPlan.
+///
+//===----------------------------------------------------------------------===//
+
+#include "VPlanVerifier.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+
+#define DEBUG_TYPE "loop-vectorize"
+
+using namespace llvm;
+
+static cl::opt<bool> EnableHCFGVerifier("vplan-verify-hcfg", cl::init(false),
+                                        cl::Hidden,
+                                        cl::desc("Verify VPlan H-CFG."));
+
+#ifndef NDEBUG
+/// Utility function that checks whether \p VPBlockVec has duplicate
+/// VPBlockBases.
+static bool hasDuplicates(const SmallVectorImpl<VPBlockBase *> &VPBlockVec) {
+  SmallDenseSet<const VPBlockBase *, 8> VPBlockSet;
+  for (const auto *Block : VPBlockVec) {
+    if (VPBlockSet.count(Block))
+      return true;
+    VPBlockSet.insert(Block);
+  }
+  return false;
+}
+#endif
+
+/// Helper function that verifies the CFG invariants of the VPBlockBases within
+/// \p Region. Checks in this function are generic to VPBlockBases; they are
+/// not specific to VPBasicBlocks or VPRegionBlocks.
+static void verifyBlocksInRegion(const VPRegionBlock *Region) {
+  for (const VPBlockBase *VPB :
+       make_range(df_iterator<const VPBlockBase *>::begin(Region->getEntry()),
+                  df_iterator<const VPBlockBase *>::end(Region->getExit()))) {
+    // Check block's parent.
+    assert(VPB->getParent() == Region && "VPBlockBase has wrong parent");
+
+    // Check block's condition bit.
+    if (VPB->getNumSuccessors() > 1)
+      assert(VPB->getCondBit() && "Missing condition bit!");
+    else
+      assert(!VPB->getCondBit() && "Unexpected condition bit!");
+
+    // Check block's successors.
+    const auto &Successors = VPB->getSuccessors();
+    // There must be only one instance of a successor in a block's successor
+    // list.
+    // TODO: This won't work for switch statements.
+    assert(!hasDuplicates(Successors) &&
+           "Multiple instances of the same successor.");
+
+    for (const VPBlockBase *Succ : Successors) {
+      // There must be a bi-directional link between block and successor.
+      const auto &SuccPreds = Succ->getPredecessors();
+      assert(std::find(SuccPreds.begin(), SuccPreds.end(), VPB) !=
+                 SuccPreds.end() &&
+             "Missing predecessor link.");
+      (void)SuccPreds;
+    }
+
+    // Check block's predecessors.
+    const auto &Predecessors = VPB->getPredecessors();
+    // There must be only one instance of a predecessor in a block's
+    // predecessor list.
+    // TODO: This won't work for switch statements.
+    assert(!hasDuplicates(Predecessors) &&
+           "Multiple instances of the same predecessor.");
+
+    for (const VPBlockBase *Pred : Predecessors) {
+      // Block and predecessor must be inside the same region.
+      assert(Pred->getParent() == VPB->getParent() &&
+             "Predecessor is not in the same region.");
+
+      // There must be a bi-directional link between block and predecessor.
+      const auto &PredSuccs = Pred->getSuccessors();
+      assert(std::find(PredSuccs.begin(), PredSuccs.end(), VPB) !=
+                 PredSuccs.end() &&
+             "Missing successor link.");
+      (void)PredSuccs;
+    }
+  }
+}
+
+/// Verify the CFG invariants of VPRegionBlock \p Region and its nested
+/// VPBlockBases. Do not recurse inside nested VPRegionBlocks.
+static void verifyRegion(const VPRegionBlock *Region) {
+  const VPBlockBase *Entry = Region->getEntry();
+  const VPBlockBase *Exit = Region->getExit();
+
+  // Entry and Exit shouldn't have any predecessor/successor, respectively.
+ assert(!Entry->getNumPredecessors() && "Region entry has predecessors."); + assert(!Exit->getNumSuccessors() && "Region exit has successors."); + (void)Entry; + (void)Exit; + + verifyBlocksInRegion(Region); +} + +/// Verify the CFG invariants of VPRegionBlock \p Region and its nested +/// VPBlockBases. Recurse inside nested VPRegionBlocks. +static void verifyRegionRec(const VPRegionBlock *Region) { + verifyRegion(Region); + + // Recurse inside nested regions. + for (const VPBlockBase *VPB : + make_range(df_iterator<const VPBlockBase *>::begin(Region->getEntry()), + df_iterator<const VPBlockBase *>::end(Region->getExit()))) { + if (const auto *SubRegion = dyn_cast<VPRegionBlock>(VPB)) + verifyRegionRec(SubRegion); + } +} + +void VPlanVerifier::verifyHierarchicalCFG( + const VPRegionBlock *TopRegion) const { + if (!EnableHCFGVerifier) + return; + + LLVM_DEBUG(dbgs() << "Verifying VPlan H-CFG.\n"); + assert(!TopRegion->getParent() && "VPlan Top Region should have no parent."); + verifyRegionRec(TopRegion); +} diff --git a/lib/Transforms/Vectorize/VPlanVerifier.h b/lib/Transforms/Vectorize/VPlanVerifier.h new file mode 100644 index 000000000000..d2f99d006a66 --- /dev/null +++ b/lib/Transforms/Vectorize/VPlanVerifier.h @@ -0,0 +1,44 @@ +//===-- VPlanVerifier.h -----------------------------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file declares the class VPlanVerifier, which contains utility functions +/// to check the consistency of a VPlan. This includes the following kinds of +/// invariants: +/// +/// 1. Region/Block invariants: +/// - Region's entry/exit block must have no predecessors/successors, +/// respectively. +/// - Block's parent must be the region immediately containing the block. +/// - Linked blocks must have a bi-directional link (successor/predecessor). +/// - All predecessors/successors of a block must belong to the same region. +/// - Blocks must have no duplicated successor/predecessor. +/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLANVERIFIER_H +#define LLVM_TRANSFORMS_VECTORIZE_VPLANVERIFIER_H + +#include "VPlan.h" + +namespace llvm { + +/// Class with utility functions that can be used to check the consistency and +/// invariants of a VPlan, including the components of its H-CFG. +class VPlanVerifier { +public: + /// Verify the invariants of the H-CFG starting from \p TopRegion. The + /// verification process comprises the following steps: + /// 1. Region/Block verification: Check the Region/Block verification + /// invariants for every region in the H-CFG. + void verifyHierarchicalCFG(const VPRegionBlock *TopRegion) const; +}; +} // namespace llvm + +#endif //LLVM_TRANSFORMS_VECTORIZE_VPLANVERIFIER_H diff --git a/lib/Transforms/Vectorize/Vectorize.cpp b/lib/Transforms/Vectorize/Vectorize.cpp index b04905bfc6fa..f62a88558328 100644 --- a/lib/Transforms/Vectorize/Vectorize.cpp +++ b/lib/Transforms/Vectorize/Vectorize.cpp @@ -34,10 +34,6 @@ void LLVMInitializeVectorization(LLVMPassRegistryRef R) { initializeVectorization(*unwrap(R)); } -// DEPRECATED: Remove after the LLVM 5 release. 
-void LLVMAddBBVectorizePass(LLVMPassManagerRef PM) {
-}
-
 void LLVMAddLoopVectorizePass(LLVMPassManagerRef PM) {
   unwrap(PM)->add(createLoopVectorizePass());
 }
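As a usage note (an editor's sketch, not part of this commit): with the
deprecated BBVectorize entry point removed, C API clients schedule the
remaining vectorizers explicitly, for example:

  #include "llvm-c/Core.h"
  #include "llvm-c/Transforms/Vectorize.h"

  // Mod is an existing LLVMModuleRef.
  LLVMPassManagerRef PM = LLVMCreatePassManager();
  LLVMAddLoopVectorizePass(PM); // loop vectorizer
  LLVMAddSLPVectorizePass(PM);  // SLP vectorizer
  LLVMRunPassManager(PM, Mod);
  LLVMDisposePassManager(PM);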