| author | Dimitry Andric <dim@FreeBSD.org> | 2022-07-04 19:20:19 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2023-02-08 19:02:26 +0000 |
| commit | 81ad626541db97eb356e2c1d4a20eb2a26a766ab (patch) | |
| tree | 311b6a8987c32b1e1dcbab65c54cfac3fdb56175 /contrib/llvm-project/llvm/lib/Transforms | |
| parent | 5fff09660e06a66bed6482da9c70df328e16bbb6 (diff) | |
| parent | 145449b1e420787bb99721a429341fa6be3adfb6 (diff) | |
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms')
252 files changed, 20837 insertions, 16793 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index 7243e39c9029..1fd8b88dd776 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -22,8 +22,8 @@ #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -36,6 +36,10 @@ using namespace llvm; using namespace PatternMatch; +namespace llvm { +class DataLayout; +} + #define DEBUG_TYPE "aggressive-instcombine" STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded"); @@ -200,14 +204,13 @@ static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) { /// of 'and' ops, then we also need to capture the fact that we saw an /// "and X, 1", so that's an extra return value for that case. struct MaskOps { - Value *Root; + Value *Root = nullptr; APInt Mask; bool MatchAndChain; - bool FoundAnd1; + bool FoundAnd1 = false; MaskOps(unsigned BitWidth, bool MatchAnds) - : Root(nullptr), Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds), - FoundAnd1(false) {} + : Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds) {} }; /// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a @@ -363,10 +366,72 @@ static bool tryToRecognizePopCount(Instruction &I) { return false; } +/// Fold smin(smax(fptosi(x), C1), C2) to llvm.fptosi.sat(x), providing C1 and +/// C2 saturate the value of the fp conversion. The transform is not reversable +/// as the fptosi.sat is more defined than the input - all values produce a +/// valid value for the fptosi.sat, where as some produce poison for original +/// that were out of range of the integer conversion. The reversed pattern may +/// use fmax and fmin instead. As we cannot directly reverse the transform, and +/// it is not always profitable, we make it conditional on the cost being +/// reported as lower by TTI. +static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) { + // Look for min(max(fptosi, converting to fptosi_sat. + Value *In; + const APInt *MinC, *MaxC; + if (!match(&I, m_SMax(m_OneUse(m_SMin(m_OneUse(m_FPToSI(m_Value(In))), + m_APInt(MinC))), + m_APInt(MaxC))) && + !match(&I, m_SMin(m_OneUse(m_SMax(m_OneUse(m_FPToSI(m_Value(In))), + m_APInt(MaxC))), + m_APInt(MinC)))) + return false; + + // Check that the constants clamp a saturate. 
+ if (!(*MinC + 1).isPowerOf2() || -*MaxC != *MinC + 1) + return false; + + Type *IntTy = I.getType(); + Type *FpTy = In->getType(); + Type *SatTy = + IntegerType::get(IntTy->getContext(), (*MinC + 1).exactLogBase2() + 1); + if (auto *VecTy = dyn_cast<VectorType>(IntTy)) + SatTy = VectorType::get(SatTy, VecTy->getElementCount()); + + // Get the cost of the intrinsic, and check that against the cost of + // fptosi+smin+smax + InstructionCost SatCost = TTI.getIntrinsicInstrCost( + IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}), + TTI::TCK_RecipThroughput); + SatCost += TTI.getCastInstrCost(Instruction::SExt, SatTy, IntTy, + TTI::CastContextHint::None, + TTI::TCK_RecipThroughput); + + InstructionCost MinMaxCost = TTI.getCastInstrCost( + Instruction::FPToSI, IntTy, FpTy, TTI::CastContextHint::None, + TTI::TCK_RecipThroughput); + MinMaxCost += TTI.getIntrinsicInstrCost( + IntrinsicCostAttributes(Intrinsic::smin, IntTy, {IntTy}), + TTI::TCK_RecipThroughput); + MinMaxCost += TTI.getIntrinsicInstrCost( + IntrinsicCostAttributes(Intrinsic::smax, IntTy, {IntTy}), + TTI::TCK_RecipThroughput); + + if (SatCost >= MinMaxCost) + return false; + + IRBuilder<> Builder(&I); + Function *Fn = Intrinsic::getDeclaration(I.getModule(), Intrinsic::fptosi_sat, + {SatTy, FpTy}); + Value *Sat = Builder.CreateCall(Fn, In); + I.replaceAllUsesWith(Builder.CreateSExt(Sat, IntTy)); + return true; +} + /// This is the entry point for folds that could be implemented in regular /// InstCombine, but they are separated because they are not expected to /// occur frequently and/or have more than a constant-length pattern match. -static bool foldUnusualPatterns(Function &F, DominatorTree &DT) { +static bool foldUnusualPatterns(Function &F, DominatorTree &DT, + TargetTransformInfo &TTI) { bool MadeChange = false; for (BasicBlock &BB : F) { // Ignore unreachable basic blocks. @@ -382,6 +447,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT) { MadeChange |= foldAnyOrAllBitsSet(I); MadeChange |= foldGuardedFunnelShift(I, DT); MadeChange |= tryToRecognizePopCount(I); + MadeChange |= tryToFPToSat(I, TTI); } } @@ -395,13 +461,13 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT) { /// This is the entry point for all transforms. Pass manager differences are /// handled in the callers of this function. 
-static bool runImpl(Function &F, AssumptionCache &AC, TargetLibraryInfo &TLI, - DominatorTree &DT) { +static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI, + TargetLibraryInfo &TLI, DominatorTree &DT) { bool MadeChange = false; const DataLayout &DL = F.getParent()->getDataLayout(); TruncInstCombine TIC(AC, TLI, DL, DT); MadeChange |= TIC.run(F); - MadeChange |= foldUnusualPatterns(F, DT); + MadeChange |= foldUnusualPatterns(F, DT, TTI); return MadeChange; } @@ -411,6 +477,7 @@ void AggressiveInstCombinerLegacyPass::getAnalysisUsage( AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); AU.addPreserved<AAResultsWrapperPass>(); AU.addPreserved<BasicAAWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); @@ -421,7 +488,8 @@ bool AggressiveInstCombinerLegacyPass::runOnFunction(Function &F) { auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - return runImpl(F, AC, TLI, DT); + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + return runImpl(F, AC, TTI, TLI, DT); } PreservedAnalyses AggressiveInstCombinePass::run(Function &F, @@ -429,7 +497,8 @@ PreservedAnalyses AggressiveInstCombinePass::run(Function &F, auto &AC = AM.getResult<AssumptionAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); - if (!runImpl(F, AC, TLI, DT)) { + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + if (!runImpl(F, AC, TTI, TLI, DT)) { // No changes, all analyses are preserved. return PreservedAnalyses::all(); } @@ -446,6 +515,7 @@ INITIALIZE_PASS_BEGIN(AggressiveInstCombinerLegacyPass, INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(AggressiveInstCombinerLegacyPass, "aggressive-instcombine", "Combine pattern based expressions", false, false) diff --git a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h index 5d69e26d6ecc..9fc103d45d98 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h +++ b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h @@ -23,14 +23,14 @@ using namespace llvm; //===----------------------------------------------------------------------===// -// TruncInstCombine - looks for expression dags dominated by trunc instructions -// and for each eligible dag, it will create a reduced bit-width expression and -// replace the old expression with this new one and remove the old one. -// Eligible expression dag is such that: +// TruncInstCombine - looks for expression graphs dominated by trunc +// instructions and for each eligible graph, it will create a reduced bit-width +// expression and replace the old expression with this new one and remove the +// old one. Eligible expression graph is such that: // 1. Contains only supported instructions. // 2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value. // 3. 
Can be evaluated into type with reduced legal bit-width (or Trunc type). -// 4. All instructions in the dag must not have users outside the dag. +// 4. All instructions in the graph must not have users outside the graph. // Only exception is for {ZExt, SExt}Inst with operand type equal to the // new reduced type chosen in (3). // @@ -61,9 +61,9 @@ class TruncInstCombine { SmallVector<TruncInst *, 4> Worklist; /// Current processed TruncInst instruction. - TruncInst *CurrentTruncInst; + TruncInst *CurrentTruncInst = nullptr; - /// Information per each instruction in the expression dag. + /// Information per each instruction in the expression graph. struct Info { /// Number of LSBs that are needed to generate a valid expression. unsigned ValidBitWidth = 0; @@ -72,26 +72,26 @@ class TruncInstCombine { /// The reduced value generated to replace the old instruction. Value *NewValue = nullptr; }; - /// An ordered map representing expression dag post-dominated by current - /// processed TruncInst. It maps each instruction in the dag to its Info + /// An ordered map representing expression graph post-dominated by current + /// processed TruncInst. It maps each instruction in the graph to its Info /// structure. The map is ordered such that each instruction appears before - /// all other instructions in the dag that uses it. + /// all other instructions in the graph that uses it. MapVector<Instruction *, Info> InstInfoMap; public: TruncInstCombine(AssumptionCache &AC, TargetLibraryInfo &TLI, const DataLayout &DL, const DominatorTree &DT) - : AC(AC), TLI(TLI), DL(DL), DT(DT), CurrentTruncInst(nullptr) {} + : AC(AC), TLI(TLI), DL(DL), DT(DT) {} /// Perform TruncInst pattern optimization on given function. bool run(Function &F); private: - /// Build expression dag dominated by the /p CurrentTruncInst and append it to - /// the InstInfoMap container. + /// Build expression graph dominated by the /p CurrentTruncInst and append it + /// to the InstInfoMap container. /// - /// \return true only if succeed to generate an eligible sub expression dag. - bool buildTruncExpressionDag(); + /// \return true only if succeed to generate an eligible sub expression graph. + bool buildTruncExpressionGraph(); /// Calculate the minimal allowed bit-width of the chain ending with the /// currently visited truncate's operand. @@ -100,12 +100,12 @@ private: /// truncate's operand can be shrunk to. unsigned getMinBitWidth(); - /// Build an expression dag dominated by the current processed TruncInst and + /// Build an expression graph dominated by the current processed TruncInst and /// Check if it is eligible to be reduced to a smaller type. /// /// \return the scalar version of the new type to be used for the reduced - /// expression dag, or nullptr if the expression dag is not eligible - /// to be reduced. + /// expression graph, or nullptr if the expression graph is not + /// eligible to be reduced. Type *getBestTruncatedType(); KnownBits computeKnownBits(const Value *V) const { @@ -128,12 +128,12 @@ private: /// \return the new reduced value. Value *getReducedOperand(Value *V, Type *SclTy); - /// Create a new expression dag using the reduced /p SclTy type and replace - /// the old expression dag with it. Also erase all instructions in the old - /// dag, except those that are still needed outside the dag. + /// Create a new expression graph using the reduced /p SclTy type and replace + /// the old expression graph with it. 
Also erase all instructions in the old + /// graph, except those that are still needed outside the graph. /// - /// \param SclTy scalar version of new type to reduce expression dag into. - void ReduceExpressionDag(Type *SclTy); + /// \param SclTy scalar version of new type to reduce expression graph into. + void ReduceExpressionGraph(Type *SclTy); }; } // end namespace llvm. diff --git a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp index 4624b735bef8..70ea68587b8e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp @@ -6,14 +6,14 @@ // //===----------------------------------------------------------------------===// // -// TruncInstCombine - looks for expression dags post-dominated by TruncInst and -// for each eligible dag, it will create a reduced bit-width expression, replace -// the old expression with this new one and remove the old expression. -// Eligible expression dag is such that: +// TruncInstCombine - looks for expression graphs post-dominated by TruncInst +// and for each eligible graph, it will create a reduced bit-width expression, +// replace the old expression with this new one and remove the old expression. +// Eligible expression graph is such that: // 1. Contains only supported instructions. // 2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value. // 3. Can be evaluated into type with reduced legal bit-width. -// 4. All instructions in the dag must not have users outside the dag. +// 4. All instructions in the graph must not have users outside the graph. // The only exception is for {ZExt, SExt}Inst with operand type equal to // the new reduced type evaluated in (3). // @@ -28,7 +28,6 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstantFolding.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" @@ -39,14 +38,13 @@ using namespace llvm; #define DEBUG_TYPE "aggressive-instcombine" -STATISTIC( - NumDAGsReduced, - "Number of truncations eliminated by reducing bit width of expression DAG"); +STATISTIC(NumExprsReduced, "Number of truncations eliminated by reducing bit " + "width of expression graph"); STATISTIC(NumInstrsReduced, "Number of instructions whose bit width was reduced"); /// Given an instruction and a container, it fills all the relevant operands of -/// that instruction, with respect to the Trunc expression dag optimizaton. +/// that instruction, with respect to the Trunc expression graph optimizaton. static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) { unsigned Opc = I->getOpcode(); switch (Opc) { @@ -78,15 +76,19 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) { Ops.push_back(I->getOperand(1)); Ops.push_back(I->getOperand(2)); break; + case Instruction::PHI: + for (Value *V : cast<PHINode>(I)->incoming_values()) + Ops.push_back(V); + break; default: llvm_unreachable("Unreachable!"); } } -bool TruncInstCombine::buildTruncExpressionDag() { +bool TruncInstCombine::buildTruncExpressionGraph() { SmallVector<Value *, 8> Worklist; SmallVector<Instruction *, 8> Stack; - // Clear old expression dag. + // Clear old instructions info. 
InstInfoMap.clear(); Worklist.push_back(CurrentTruncInst->getOperand(0)); @@ -150,11 +152,19 @@ bool TruncInstCombine::buildTruncExpressionDag() { append_range(Worklist, Operands); break; } + case Instruction::PHI: { + SmallVector<Value *, 2> Operands; + getRelevantOperands(I, Operands); + // Add only operands not in Stack to prevent cycle + for (auto *Op : Operands) + if (all_of(Stack, [Op](Value *V) { return Op != V; })) + Worklist.push_back(Op); + break; + } default: // TODO: Can handle more cases here: // 1. shufflevector // 2. sdiv, srem - // 3. phi node(and loop handling) // ... return false; } @@ -254,7 +264,7 @@ unsigned TruncInstCombine::getMinBitWidth() { } Type *TruncInstCombine::getBestTruncatedType() { - if (!buildTruncExpressionDag()) + if (!buildTruncExpressionGraph()) return nullptr; // We don't want to duplicate instructions, which isn't profitable. Thus, we @@ -367,8 +377,10 @@ Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) { return Entry.NewValue; } -void TruncInstCombine::ReduceExpressionDag(Type *SclTy) { +void TruncInstCombine::ReduceExpressionGraph(Type *SclTy) { NumInstrsReduced += InstInfoMap.size(); + // Pairs of old and new phi-nodes + SmallVector<std::pair<PHINode *, PHINode *>, 2> OldNewPHINodes; for (auto &Itr : InstInfoMap) { // Forward Instruction *I = Itr.first; TruncInstCombine::Info &NodeInfo = Itr.second; @@ -451,6 +463,12 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) { Res = Builder.CreateSelect(Op0, LHS, RHS); break; } + case Instruction::PHI: { + Res = Builder.CreatePHI(getReducedType(I, SclTy), I->getNumOperands()); + OldNewPHINodes.push_back( + std::make_pair(cast<PHINode>(I), cast<PHINode>(Res))); + break; + } default: llvm_unreachable("Unhandled instruction"); } @@ -460,6 +478,14 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) { ResI->takeName(I); } + for (auto &Node : OldNewPHINodes) { + PHINode *OldPN = Node.first; + PHINode *NewPN = Node.second; + for (auto Incoming : zip(OldPN->incoming_values(), OldPN->blocks())) + NewPN->addIncoming(getReducedOperand(std::get<0>(Incoming), SclTy), + std::get<1>(Incoming)); + } + Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy); Type *DstTy = CurrentTruncInst->getType(); if (Res->getType() != DstTy) { @@ -470,17 +496,29 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) { } CurrentTruncInst->replaceAllUsesWith(Res); - // Erase old expression dag, which was replaced by the reduced expression dag. - // We iterate backward, which means we visit the instruction before we visit - // any of its operands, this way, when we get to the operand, we already - // removed the instructions (from the expression dag) that uses it. + // Erase old expression graph, which was replaced by the reduced expression + // graph. CurrentTruncInst->eraseFromParent(); + // First, erase old phi-nodes and its uses + for (auto &Node : OldNewPHINodes) { + PHINode *OldPN = Node.first; + OldPN->replaceAllUsesWith(PoisonValue::get(OldPN->getType())); + InstInfoMap.erase(OldPN); + OldPN->eraseFromParent(); + } + // Now we have expression graph turned into dag. + // We iterate backward, which means we visit the instruction before we + // visit any of its operands, this way, when we get to the operand, we already + // removed the instructions (from the expression dag) that uses it. 
for (auto &I : llvm::reverse(InstInfoMap)) { // We still need to check that the instruction has no users before we erase // it, because {SExt, ZExt}Inst Instruction might have other users that was // not reduced, in such case, we need to keep that instruction. if (I.first->use_empty()) I.first->eraseFromParent(); + else + assert((isa<SExtInst>(I.first) || isa<ZExtInst>(I.first)) && + "Only {SExt, ZExt}Inst might have unreduced users"); } } @@ -498,18 +536,18 @@ bool TruncInstCombine::run(Function &F) { } // Process all TruncInst in the Worklist, for each instruction: - // 1. Check if it dominates an eligible expression dag to be reduced. - // 2. Create a reduced expression dag and replace the old one with it. + // 1. Check if it dominates an eligible expression graph to be reduced. + // 2. Create a reduced expression graph and replace the old one with it. while (!Worklist.empty()) { CurrentTruncInst = Worklist.pop_back_val(); if (Type *NewDstSclTy = getBestTruncatedType()) { LLVM_DEBUG( - dbgs() << "ICE: TruncInstCombine reducing type of expression dag " + dbgs() << "ICE: TruncInstCombine reducing type of expression graph " "dominated by: " << CurrentTruncInst << '\n'); - ReduceExpressionDag(NewDstSclTy); - ++NumDAGsReduced; + ReduceExpressionGraph(NewDstSclTy); + ++NumExprsReduced; MadeIRChange = true; } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp index 67f8828e4c75..f7bbdcffd2ec 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp @@ -10,9 +10,9 @@ #include "CoroInternal.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/Pass.h" -#include "llvm/Transforms/Scalar.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Function.h" +#include "llvm/Transforms/Scalar/SimplifyCFG.h" using namespace llvm; @@ -23,19 +23,10 @@ namespace { struct Lowerer : coro::LowererBase { IRBuilder<> Builder; Lowerer(Module &M) : LowererBase(M), Builder(Context) {} - bool lowerRemainingCoroIntrinsics(Function &F); + bool lower(Function &F); }; } -static void simplifyCFG(Function &F) { - llvm::legacy::FunctionPassManager FPM(F.getParent()); - FPM.add(createCFGSimplificationPass()); - - FPM.doInitialization(); - FPM.run(F); - FPM.doFinalization(); -} - static void lowerSubFn(IRBuilder<> &Builder, CoroSubFnInst *SubFn) { Builder.SetInsertPoint(SubFn); Value *FrameRaw = SubFn->getFrame(); @@ -53,12 +44,10 @@ static void lowerSubFn(IRBuilder<> &Builder, CoroSubFnInst *SubFn) { SubFn->replaceAllUsesWith(Load); } -bool Lowerer::lowerRemainingCoroIntrinsics(Function &F) { +bool Lowerer::lower(Function &F) { + bool IsPrivateAndUnprocessed = F.isPresplitCoroutine() && F.hasLocalLinkage(); bool Changed = false; - bool IsPrivateAndUnprocessed = - F.hasFnAttribute(CORO_PRESPLIT_ATTR) && F.hasLocalLinkage(); - for (Instruction &I : llvm::make_early_inc_range(instructions(F))) { if (auto *II = dyn_cast<IntrinsicInst>(&I)) { switch (II->getIntrinsicID()) { @@ -116,11 +105,6 @@ bool Lowerer::lowerRemainingCoroIntrinsics(Function &F) { } } - if (Changed) { - // After replacement were made we can cleanup the function body a little. 
- simplifyCFG(F); - } - return Changed; } @@ -132,50 +116,21 @@ static bool declaresCoroCleanupIntrinsics(const Module &M) { "llvm.coro.async.resume"}); } -PreservedAnalyses CoroCleanupPass::run(Function &F, - FunctionAnalysisManager &AM) { - auto &M = *F.getParent(); - if (!declaresCoroCleanupIntrinsics(M) || - !Lowerer(M).lowerRemainingCoroIntrinsics(F)) +PreservedAnalyses CoroCleanupPass::run(Module &M, + ModuleAnalysisManager &MAM) { + if (!declaresCoroCleanupIntrinsics(M)) return PreservedAnalyses::all(); - return PreservedAnalyses::none(); -} - -namespace { - -struct CoroCleanupLegacy : FunctionPass { - static char ID; // Pass identification, replacement for typeid + FunctionAnalysisManager &FAM = + MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); - CoroCleanupLegacy() : FunctionPass(ID) { - initializeCoroCleanupLegacyPass(*PassRegistry::getPassRegistry()); - } + FunctionPassManager FPM; + FPM.addPass(SimplifyCFGPass()); - std::unique_ptr<Lowerer> L; + Lowerer L(M); + for (auto &F : M) + if (L.lower(F)) + FPM.run(F, FAM); - // This pass has work to do only if we find intrinsics we are going to lower - // in the module. - bool doInitialization(Module &M) override { - if (declaresCoroCleanupIntrinsics(M)) - L = std::make_unique<Lowerer>(M); - return false; - } - - bool runOnFunction(Function &F) override { - if (L) - return L->lowerRemainingCoroIntrinsics(F); - return false; - } - void getAnalysisUsage(AnalysisUsage &AU) const override { - if (!L) - AU.setPreservesAll(); - } - StringRef getPassName() const override { return "Coroutine Cleanup"; } -}; + return PreservedAnalyses::none(); } - -char CoroCleanupLegacy::ID = 0; -INITIALIZE_PASS(CoroCleanupLegacy, "coro-cleanup", - "Lower all coroutine related intrinsics", false, false) - -Pass *llvm::createCoroCleanupLegacyPass() { return new CoroCleanupLegacy(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp new file mode 100644 index 000000000000..3d26a43ceba7 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp @@ -0,0 +1,24 @@ +//===- CoroConditionalWrapper.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Coroutines/CoroConditionalWrapper.h" +#include "CoroInternal.h" +#include "llvm/IR/Module.h" + +using namespace llvm; + +CoroConditionalWrapper::CoroConditionalWrapper(ModulePassManager &&PM) + : PM(std::move(PM)) {} + +PreservedAnalyses CoroConditionalWrapper::run(Module &M, + ModuleAnalysisManager &AM) { + if (!coro::declaresAnyIntrinsic(M)) + return PreservedAnalyses::all(); + + return PM.run(M, AM); +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroEarly.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroEarly.cpp index 1533e1805f17..dd7cb23f3f3d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroEarly.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroEarly.cpp @@ -8,10 +8,10 @@ #include "llvm/Transforms/Coroutines/CoroEarly.h" #include "CoroInternal.h" +#include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Module.h" -#include "llvm/Pass.h" using namespace llvm; @@ -35,7 +35,7 @@ public: AnyResumeFnPtrTy(FunctionType::get(Type::getVoidTy(Context), Int8Ptr, /*isVarArg=*/false) ->getPointerTo()) {} - bool lowerEarlyIntrinsics(Function &F); + void lowerEarlyIntrinsics(Function &F); }; } @@ -145,14 +145,16 @@ static void setCannotDuplicate(CoroIdInst *CoroId) { CB->setCannotDuplicate(); } -bool Lowerer::lowerEarlyIntrinsics(Function &F) { - bool Changed = false; +void Lowerer::lowerEarlyIntrinsics(Function &F) { CoroIdInst *CoroId = nullptr; SmallVector<CoroFreeInst *, 4> CoroFrees; bool HasCoroSuspend = false; for (Instruction &I : llvm::make_early_inc_range(instructions(F))) { - if (auto *CB = dyn_cast<CallBase>(&I)) { - switch (CB->getIntrinsicID()) { + auto *CB = dyn_cast<CallBase>(&I); + if (!CB) + continue; + + switch (CB->getIntrinsicID()) { default: continue; case Intrinsic::coro_free: @@ -178,12 +180,9 @@ bool Lowerer::lowerEarlyIntrinsics(Function &F) { case Intrinsic::coro_id: if (auto *CII = cast<CoroIdInst>(&I)) { if (CII->getInfo().isPreSplit()) { - assert(F.hasFnAttribute(CORO_PRESPLIT_ATTR) && - F.getFnAttribute(CORO_PRESPLIT_ATTR).getValueAsString() == - UNPREPARED_FOR_SPLIT && + assert(F.isPresplitCoroutine() && "The frontend uses Swtich-Resumed ABI should emit " - "\"coroutine.presplit\" attribute with value \"0\" for the " - "coroutine."); + "\"coroutine.presplit\" attribute for the coroutine."); setCannotDuplicate(CII); CII->setCoroutineSelf(); CoroId = cast<CoroIdInst>(&I); @@ -193,9 +192,7 @@ bool Lowerer::lowerEarlyIntrinsics(Function &F) { case Intrinsic::coro_id_retcon: case Intrinsic::coro_id_retcon_once: case Intrinsic::coro_id_async: - // TODO: Remove the line once we support it in the corresponding - // frontend. - F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT); + F.setPresplitCoroutine(); break; case Intrinsic::coro_resume: lowerResumeOrDestroy(*CB, CoroSubFnInst::ResumeIndex); @@ -209,16 +206,16 @@ bool Lowerer::lowerEarlyIntrinsics(Function &F) { case Intrinsic::coro_done: lowerCoroDone(cast<IntrinsicInst>(&I)); break; - } - Changed = true; } } + // Make sure that all CoroFree reference the coro.id intrinsic. // Token type is not exposed through coroutine C/C++ builtins to plain C, so // we allow specifying none and fixing it up here. 
if (CoroId) for (CoroFreeInst *CF : CoroFrees) CF->setArgOperand(0, CoroId); + // Coroutine suspention could potentially lead to any argument modified // outside of the function, hence arguments should not have noalias // attributes. @@ -226,7 +223,6 @@ bool Lowerer::lowerEarlyIntrinsics(Function &F) { for (Argument &A : F.args()) if (A.hasNoAliasAttr()) A.removeAttr(Attribute::NoAlias); - return Changed; } static bool declaresCoroEarlyIntrinsics(const Module &M) { @@ -238,52 +234,15 @@ static bool declaresCoroEarlyIntrinsics(const Module &M) { "llvm.coro.suspend"}); } -PreservedAnalyses CoroEarlyPass::run(Function &F, FunctionAnalysisManager &) { - Module &M = *F.getParent(); - if (!declaresCoroEarlyIntrinsics(M) || !Lowerer(M).lowerEarlyIntrinsics(F)) +PreservedAnalyses CoroEarlyPass::run(Module &M, ModuleAnalysisManager &) { + if (!declaresCoroEarlyIntrinsics(M)) return PreservedAnalyses::all(); + Lowerer L(M); + for (auto &F : M) + L.lowerEarlyIntrinsics(F); + PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); return PA; } - -namespace { - -struct CoroEarlyLegacy : public FunctionPass { - static char ID; // Pass identification, replacement for typeid. - CoroEarlyLegacy() : FunctionPass(ID) { - initializeCoroEarlyLegacyPass(*PassRegistry::getPassRegistry()); - } - - std::unique_ptr<Lowerer> L; - - // This pass has work to do only if we find intrinsics we are going to lower - // in the module. - bool doInitialization(Module &M) override { - if (declaresCoroEarlyIntrinsics(M)) - L = std::make_unique<Lowerer>(M); - return false; - } - - bool runOnFunction(Function &F) override { - if (!L) - return false; - - return L->lowerEarlyIntrinsics(F); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - } - StringRef getPassName() const override { - return "Lower early coroutine intrinsics"; - } -}; -} - -char CoroEarlyLegacy::ID = 0; -INITIALIZE_PASS(CoroEarlyLegacy, "coro-early", - "Lower early coroutine intrinsics", false, false) - -Pass *llvm::createCoroEarlyLegacyPass() { return new CoroEarlyLegacy(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroElide.cpp index 84bebb7bf42d..6f78fc8db311 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroElide.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroElide.cpp @@ -14,8 +14,6 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FileSystem.h" @@ -103,21 +101,12 @@ static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) { // Given a resume function @f.resume(%f.frame* %frame), returns the size // and expected alignment of %f.frame type. -static std::pair<uint64_t, Align> getFrameLayout(Function *Resume) { - // Prefer to pull information from the function attributes. +static Optional<std::pair<uint64_t, Align>> getFrameLayout(Function *Resume) { + // Pull information from the function attributes. auto Size = Resume->getParamDereferenceableBytes(0); - auto Align = Resume->getParamAlign(0); - - // If those aren't given, extract them from the type. 
- if (Size == 0 || !Align) { - auto *FrameTy = Resume->arg_begin()->getType()->getPointerElementType(); - - const DataLayout &DL = Resume->getParent()->getDataLayout(); - if (!Size) Size = DL.getTypeAllocSize(FrameTy); - if (!Align) Align = DL.getABITypeAlign(FrameTy); - } - - return std::make_pair(Size, *Align); + if (!Size) + return None; + return std::make_pair(Size, Resume->getParamAlign(0).valueOrOne()); } // Finds first non alloca instruction in the entry block of a function. @@ -347,56 +336,37 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA, assert(Resumers && "PostSplit coro.id Info argument must refer to an array" "of coroutine subfunctions"); auto *ResumeAddrConstant = - ConstantExpr::getExtractValue(Resumers, CoroSubFnInst::ResumeIndex); + Resumers->getAggregateElement(CoroSubFnInst::ResumeIndex); replaceWithConstant(ResumeAddrConstant, ResumeAddr); bool ShouldElide = shouldElide(CoroId->getFunction(), DT); - auto *DestroyAddrConstant = ConstantExpr::getExtractValue( - Resumers, + auto *DestroyAddrConstant = Resumers->getAggregateElement( ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex); for (auto &It : DestroyAddr) replaceWithConstant(DestroyAddrConstant, It.second); if (ShouldElide) { - auto FrameSizeAndAlign = getFrameLayout(cast<Function>(ResumeAddrConstant)); - elideHeapAllocations(CoroId->getFunction(), FrameSizeAndAlign.first, - FrameSizeAndAlign.second, AA); - coro::replaceCoroFree(CoroId, /*Elide=*/true); - NumOfCoroElided++; + if (auto FrameSizeAndAlign = + getFrameLayout(cast<Function>(ResumeAddrConstant))) { + elideHeapAllocations(CoroId->getFunction(), FrameSizeAndAlign->first, + FrameSizeAndAlign->second, AA); + coro::replaceCoroFree(CoroId, /*Elide=*/true); + NumOfCoroElided++; #ifndef NDEBUG - if (!CoroElideInfoOutputFilename.empty()) - *getOrCreateLogFile() - << "Elide " << CoroId->getCoroutine()->getName() << " in " - << CoroId->getFunction()->getName() << "\n"; + if (!CoroElideInfoOutputFilename.empty()) + *getOrCreateLogFile() + << "Elide " << CoroId->getCoroutine()->getName() << " in " + << CoroId->getFunction()->getName() << "\n"; #endif + } } return true; } -// See if there are any coro.subfn.addr instructions referring to coro.devirt -// trigger, if so, replace them with a direct call to devirt trigger function. -static bool replaceDevirtTrigger(Function &F) { - SmallVector<CoroSubFnInst *, 1> DevirtAddr; - for (auto &I : instructions(F)) - if (auto *SubFn = dyn_cast<CoroSubFnInst>(&I)) - if (SubFn->getIndex() == CoroSubFnInst::RestartTrigger) - DevirtAddr.push_back(SubFn); - - if (DevirtAddr.empty()) - return false; - - Module &M = *F.getParent(); - Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN); - assert(DevirtFn && "coro.devirt.fn not found"); - replaceWithConstant(DevirtFn, DevirtAddr); - - return true; -} - static bool declaresCoroElideIntrinsics(Module &M) { return coro::declaresIntrinsics(M, {"llvm.coro.id", "llvm.coro.id.async"}); } @@ -422,62 +392,3 @@ PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) { return Changed ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); } - -namespace { -struct CoroElideLegacy : FunctionPass { - static char ID; - CoroElideLegacy() : FunctionPass(ID) { - initializeCoroElideLegacyPass(*PassRegistry::getPassRegistry()); - } - - std::unique_ptr<Lowerer> L; - - bool doInitialization(Module &M) override { - if (declaresCoroElideIntrinsics(M)) - L = std::make_unique<Lowerer>(M); - return false; - } - - bool runOnFunction(Function &F) override { - if (!L) - return false; - - bool Changed = false; - - if (F.hasFnAttribute(CORO_PRESPLIT_ATTR)) - Changed = replaceDevirtTrigger(F); - - L->CoroIds.clear(); - L->collectPostSplitCoroIds(&F); - // If we did not find any coro.id, there is nothing to do. - if (L->CoroIds.empty()) - return Changed; - - AAResults &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); - DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - - for (auto *CII : L->CoroIds) - Changed |= L->processCoroId(CII, AA, DT); - - return Changed; - } - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - } - StringRef getPassName() const override { return "Coroutine Elision"; } -}; -} - -char CoroElideLegacy::ID = 0; -INITIALIZE_PASS_BEGIN( - CoroElideLegacy, "coro-elide", - "Coroutine frame allocation elision and indirect calls replacement", false, - false) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_END( - CoroElideLegacy, "coro-elide", - "Coroutine frame allocation elision and indirect calls replacement", false, - false) - -Pass *llvm::createCoroElideLegacyPass() { return new CoroElideLegacy(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroFrame.cpp index 9c16d3750998..d09607bb1c4c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -27,7 +27,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" -#include "llvm/Support/CommandLine.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/OptimizedStructLayout.h" @@ -44,13 +44,6 @@ using namespace llvm; // "coro-frame", which results in leaner debug spew. #define DEBUG_TYPE "coro-suspend-crossing" -static cl::opt<bool> EnableReuseStorageInFrame( - "reuse-storage-in-coroutine-frame", cl::Hidden, - cl::desc( - "Enable the optimization which would reuse the storage in the coroutine \ - frame for allocas whose liferanges are not overlapped, for testing purposes"), - llvm::cl::init(false)); - enum { SmallVectorThreshold = 32 }; // Provides two way mapping between the blocks and numbers. 
@@ -347,15 +340,26 @@ struct FrameDataInfo { FieldIndexMap[V] = Index; } - uint64_t getAlign(Value *V) const { + Align getAlign(Value *V) const { auto Iter = FieldAlignMap.find(V); assert(Iter != FieldAlignMap.end()); return Iter->second; } - void setAlign(Value *V, uint64_t Align) { + void setAlign(Value *V, Align AL) { assert(FieldAlignMap.count(V) == 0); - FieldAlignMap.insert({V, Align}); + FieldAlignMap.insert({V, AL}); + } + + uint64_t getDynamicAlign(Value *V) const { + auto Iter = FieldDynamicAlignMap.find(V); + assert(Iter != FieldDynamicAlignMap.end()); + return Iter->second; + } + + void setDynamicAlign(Value *V, uint64_t Align) { + assert(FieldDynamicAlignMap.count(V) == 0); + FieldDynamicAlignMap.insert({V, Align}); } uint64_t getOffset(Value *V) const { @@ -382,7 +386,8 @@ private: DenseMap<Value *, uint32_t> FieldIndexMap; // Map from values to their alignment on the frame. They would be set after // the frame is built. - DenseMap<Value *, uint64_t> FieldAlignMap; + DenseMap<Value *, Align> FieldAlignMap; + DenseMap<Value *, uint64_t> FieldDynamicAlignMap; // Map from values to their offset on the frame. They would be set after // the frame is built. DenseMap<Value *, uint64_t> FieldOffsetMap; @@ -423,6 +428,7 @@ private: FieldIDType LayoutFieldIndex; Align Alignment; Align TyAlignment; + uint64_t DynamicAlignBuffer; }; const DataLayout &DL; @@ -489,7 +495,7 @@ public: coro::Shape &Shape); /// Add a field to this structure. - LLVM_NODISCARD FieldIDType addField(Type *Ty, MaybeAlign FieldAlignment, + LLVM_NODISCARD FieldIDType addField(Type *Ty, MaybeAlign MaybeFieldAlignment, bool IsHeader = false, bool IsSpillOfValue = false) { assert(!IsFinished && "adding fields to a finished builder"); @@ -508,13 +514,21 @@ public: // to remember the type alignment anyway to build the type. // If we are spilling values we don't need to worry about ABI alignment // concerns. - auto ABIAlign = DL.getABITypeAlign(Ty); - Align TyAlignment = - (IsSpillOfValue && MaxFrameAlignment) - ? (*MaxFrameAlignment < ABIAlign ? *MaxFrameAlignment : ABIAlign) - : ABIAlign; - if (!FieldAlignment) { - FieldAlignment = TyAlignment; + Align ABIAlign = DL.getABITypeAlign(Ty); + Align TyAlignment = ABIAlign; + if (IsSpillOfValue && MaxFrameAlignment && *MaxFrameAlignment < ABIAlign) + TyAlignment = *MaxFrameAlignment; + Align FieldAlignment = MaybeFieldAlignment.value_or(TyAlignment); + + // The field alignment could be bigger than the max frame case, in that case + // we request additional storage to be able to dynamically align the + // pointer. + uint64_t DynamicAlignBuffer = 0; + if (MaxFrameAlignment && (FieldAlignment > *MaxFrameAlignment)) { + DynamicAlignBuffer = + offsetToAlignment(MaxFrameAlignment->value(), FieldAlignment); + FieldAlignment = *MaxFrameAlignment; + FieldSize = FieldSize + DynamicAlignBuffer; } // Lay out header fields immediately. @@ -523,12 +537,13 @@ public: Offset = alignTo(StructSize, FieldAlignment); StructSize = Offset + FieldSize; - // Everything else has a flexible offset. + // Everything else has a flexible offset. 
} else { Offset = OptimizedStructLayoutField::FlexibleOffset; } - Fields.push_back({FieldSize, Offset, Ty, 0, *FieldAlignment, TyAlignment}); + Fields.push_back({FieldSize, Offset, Ty, 0, FieldAlignment, TyAlignment, + DynamicAlignBuffer}); return Fields.size() - 1; } @@ -561,7 +576,12 @@ void FrameDataInfo::updateLayoutIndex(FrameTypeBuilder &B) { auto Updater = [&](Value *I) { auto Field = B.getLayoutField(getFieldIndex(I)); setFieldIndex(I, Field.LayoutFieldIndex); - setAlign(I, Field.Alignment.value()); + setAlign(I, Field.Alignment); + uint64_t dynamicAlign = + Field.DynamicAlignBuffer + ? Field.DynamicAlignBuffer + Field.Alignment.value() + : 0; + setDynamicAlign(I, dynamicAlign); setOffset(I, Field.Offset); }; LayoutIndexUpdateStarted = true; @@ -588,7 +608,7 @@ void FrameTypeBuilder::addFieldForAllocas(const Function &F, } }); - if (!Shape.OptimizeFrame && !EnableReuseStorageInFrame) { + if (!Shape.OptimizeFrame) { for (const auto &A : FrameData.Allocas) { AllocaInst *Alloca = A.Alloca; NonOverlapedAllocas.emplace_back(AllocaSetType(1, Alloca)); @@ -755,6 +775,10 @@ void FrameTypeBuilder::finish(StructType *Ty) { F.LayoutFieldIndex = FieldTypes.size(); FieldTypes.push_back(F.Ty); + if (F.DynamicAlignBuffer) { + FieldTypes.push_back( + ArrayType::get(Type::getInt8Ty(Context), F.DynamicAlignBuffer)); + } LastOffset = Offset + F.Size; } @@ -807,9 +831,10 @@ static StringRef solveTypeName(Type *Ty) { return "__floating_type_"; } - if (Ty->isPointerTy()) { - auto *PtrTy = cast<PointerType>(Ty); - Type *PointeeTy = PtrTy->getPointerElementType(); + if (auto *PtrTy = dyn_cast<PointerType>(Ty)) { + if (PtrTy->isOpaque()) + return "PointerType"; + Type *PointeeTy = PtrTy->getNonOpaquePointerElementType(); auto Name = solveTypeName(PointeeTy); if (Name == "UnknownType") return "PointerType"; @@ -826,10 +851,9 @@ static StringRef solveTypeName(Type *Ty) { auto Name = Ty->getStructName(); SmallString<16> Buffer(Name); - for_each(Buffer, [](auto &Iter) { + for (auto &Iter : Buffer) if (Iter == '.' || Iter == ':') Iter = '_'; - }); auto *MDName = MDString::get(Ty->getContext(), Buffer.str()); return MDName->getString(); } @@ -1012,7 +1036,7 @@ static void buildFrameDebugInfo(Function &F, coro::Shape &Shape, auto Index = FrameData.getFieldIndex(V); OffsetCache.insert( - {Index, {FrameData.getAlign(V), FrameData.getOffset(V)}}); + {Index, {FrameData.getAlign(V).value(), FrameData.getOffset(V)}}); } DenseMap<Type *, DIType *> DITypeCache; @@ -1078,7 +1102,7 @@ static void buildFrameDebugInfo(Function &F, coro::Shape &Shape, DBuilder.insertDeclare(Shape.FramePtr, FrameDIVar, DBuilder.createExpression(), DILoc, - Shape.FramePtr->getNextNode()); + Shape.getInsertPtAfterFramePtr()); } // Build a struct that will keep state for an active coroutine. @@ -1367,7 +1391,7 @@ struct AllocaUseVisitor : PtrUseVisitor<AllocaUseVisitor> { bool getShouldLiveOnFrame() const { if (!ShouldLiveOnFrame) ShouldLiveOnFrame = computeShouldLiveOnFrame(); - return ShouldLiveOnFrame.getValue(); + return *ShouldLiveOnFrame; } bool getMayWriteBeforeCoroBegin() const { return MayWriteBeforeCoroBegin; } @@ -1455,7 +1479,7 @@ private: auto Itr = AliasOffetMap.find(&I); if (Itr == AliasOffetMap.end()) { AliasOffetMap[&I] = Offset; - } else if (Itr->second.hasValue() && Itr->second.getValue() != Offset) { + } else if (Itr->second && *Itr->second != Offset) { // If we have seen two different possible values for this alias, we set // it to empty. 
AliasOffetMap[&I].reset(); @@ -1517,13 +1541,12 @@ static void createFramePtr(coro::Shape &Shape) { // whatever // // -static Instruction *insertSpills(const FrameDataInfo &FrameData, - coro::Shape &Shape) { +static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { auto *CB = Shape.CoroBegin; LLVMContext &C = CB->getContext(); IRBuilder<> Builder(C); StructType *FrameTy = Shape.FrameTy; - Instruction *FramePtr = Shape.FramePtr; + Value *FramePtr = Shape.FramePtr; DominatorTree DT(*CB->getFunction()); SmallDenseMap<llvm::Value *, llvm::AllocaInst *, 4> DbgPtrAllocaCache; @@ -1550,7 +1573,18 @@ static Instruction *insertSpills(const FrameDataInfo &FrameData, auto GEP = cast<GetElementPtrInst>( Builder.CreateInBoundsGEP(FrameTy, FramePtr, Indices)); - if (isa<AllocaInst>(Orig)) { + if (auto *AI = dyn_cast<AllocaInst>(Orig)) { + if (FrameData.getDynamicAlign(Orig) != 0) { + assert(FrameData.getDynamicAlign(Orig) == AI->getAlign().value()); + auto *M = AI->getModule(); + auto *IntPtrTy = M->getDataLayout().getIntPtrType(AI->getType()); + auto *PtrValue = Builder.CreatePtrToInt(GEP, IntPtrTy); + auto *AlignMask = + ConstantInt::get(IntPtrTy, AI->getAlign().value() - 1); + PtrValue = Builder.CreateAdd(PtrValue, AlignMask); + PtrValue = Builder.CreateAnd(PtrValue, Builder.CreateNot(AlignMask)); + return Builder.CreateIntToPtr(PtrValue, AI->getType()); + } // If the type of GEP is not equal to the type of AllocaInst, it implies // that the AllocaInst may be reused in the Frame slot of other // AllocaInst. So We cast GEP to the AllocaInst here to re-use @@ -1571,20 +1605,19 @@ static Instruction *insertSpills(const FrameDataInfo &FrameData, // Create a store instruction storing the value into the // coroutine frame. Instruction *InsertPt = nullptr; - bool NeedToCopyArgPtrValue = false; + Type *ByValTy = nullptr; if (auto *Arg = dyn_cast<Argument>(Def)) { // For arguments, we will place the store instruction right after // the coroutine frame pointer instruction, i.e. bitcast of // coro.begin from i8* to %f.frame*. - InsertPt = FramePtr->getNextNode(); + InsertPt = Shape.getInsertPtAfterFramePtr(); // If we're spilling an Argument, make sure we clear 'nocapture' // from the coroutine function. Arg->getParent()->removeParamAttr(Arg->getArgNo(), Attribute::NoCapture); if (Arg->hasByValAttr()) - NeedToCopyArgPtrValue = true; - + ByValTy = Arg->getParamByValType(); } else if (auto *CSI = dyn_cast<AnyCoroSuspendInst>(Def)) { // Don't spill immediately after a suspend; splitting assumes // that the suspend will be followed by a branch. @@ -1594,7 +1627,7 @@ static Instruction *insertSpills(const FrameDataInfo &FrameData, if (!DT.dominates(CB, I)) { // If it is not dominated by CoroBegin, then spill should be // inserted immediately after CoroFrame is computed. - InsertPt = FramePtr->getNextNode(); + InsertPt = Shape.getInsertPtAfterFramePtr(); } else if (auto *II = dyn_cast<InvokeInst>(I)) { // If we are spilling the result of the invoke instruction, split // the normal edge and insert the spill in the new block. @@ -1619,11 +1652,10 @@ static Instruction *insertSpills(const FrameDataInfo &FrameData, Builder.SetInsertPoint(InsertPt); auto *G = Builder.CreateConstInBoundsGEP2_32( FrameTy, FramePtr, 0, Index, Def->getName() + Twine(".spill.addr")); - if (NeedToCopyArgPtrValue) { + if (ByValTy) { // For byval arguments, we need to store the pointed value in the frame, // instead of the pointer itself. 
- auto *Value = - Builder.CreateLoad(Def->getType()->getPointerElementType(), Def); + auto *Value = Builder.CreateLoad(ByValTy, Def); Builder.CreateAlignedStore(Value, G, SpillAlignment); } else { Builder.CreateAlignedStore(Def, G, SpillAlignment); @@ -1641,7 +1673,7 @@ static Instruction *insertSpills(const FrameDataInfo &FrameData, auto *GEP = GetFramePointer(E.first); GEP->setName(E.first->getName() + Twine(".reload.addr")); - if (NeedToCopyArgPtrValue) + if (ByValTy) CurrentReload = GEP; else CurrentReload = Builder.CreateAlignedLoad( @@ -1664,6 +1696,12 @@ static Instruction *insertSpills(const FrameDataInfo &FrameData, } } + // Salvage debug info on any dbg.addr that we see. We do not insert them + // into each block where we have a use though. + if (auto *DI = dyn_cast<DbgAddrIntrinsic>(U)) { + coro::salvageDebugInfo(DbgPtrAllocaCache, DI, Shape.OptimizeFrame); + } + // If we have a single edge PHINode, remove it and replace it with a // reload from the coroutine frame. (We already took care of multi edge // PHINodes by rewriting them in the rewritePHIs function). @@ -1682,10 +1720,10 @@ static Instruction *insertSpills(const FrameDataInfo &FrameData, } } - BasicBlock *FramePtrBB = FramePtr->getParent(); + BasicBlock *FramePtrBB = Shape.getInsertPtAfterFramePtr()->getParent(); - auto SpillBlock = - FramePtrBB->splitBasicBlock(FramePtr->getNextNode(), "AllocaSpillBB"); + auto SpillBlock = FramePtrBB->splitBasicBlock( + Shape.getInsertPtAfterFramePtr(), "AllocaSpillBB"); SpillBlock->splitBasicBlock(&SpillBlock->front(), "PostSpill"); Shape.AllocaSpillBlock = SpillBlock; @@ -1704,7 +1742,7 @@ static Instruction *insertSpills(const FrameDataInfo &FrameData, Alloca->replaceAllUsesWith(G); Alloca->eraseFromParent(); } - return FramePtr; + return; } // If we found any alloca, replace all of their remaining uses with GEP @@ -1735,7 +1773,7 @@ static Instruction *insertSpills(const FrameDataInfo &FrameData, for (Instruction *I : UsersToUpdate) I->replaceUsesOfWith(Alloca, G); } - Builder.SetInsertPoint(FramePtr->getNextNode()); + Builder.SetInsertPoint(Shape.getInsertPtAfterFramePtr()); for (const auto &A : FrameData.Allocas) { AllocaInst *Alloca = A.Alloca; if (A.MayWriteBeforeCoroBegin) { @@ -1755,16 +1793,16 @@ static Instruction *insertSpills(const FrameDataInfo &FrameData, auto *FramePtr = GetFramePointer(Alloca); auto *FramePtrRaw = Builder.CreateBitCast(FramePtr, Type::getInt8PtrTy(C)); - auto *AliasPtr = Builder.CreateGEP( - Type::getInt8Ty(C), FramePtrRaw, - ConstantInt::get(Type::getInt64Ty(C), Alias.second.getValue())); + auto &Value = *Alias.second; + auto ITy = IntegerType::get(C, Value.getBitWidth()); + auto *AliasPtr = Builder.CreateGEP(Type::getInt8Ty(C), FramePtrRaw, + ConstantInt::get(ITy, Value)); auto *AliasPtrTyped = Builder.CreateBitCast(AliasPtr, Alias.first->getType()); Alias.first->replaceUsesWithIf( AliasPtrTyped, [&](Use &U) { return DT.dominates(CB, U); }); } } - return FramePtr; } // Moves the values in the PHIs in SuccBB that correspong to PredBB into a new @@ -2130,7 +2168,7 @@ static void lowerLocalAllocas(ArrayRef<CoroAllocaAllocInst*> LocalAllocas, // Allocate memory. auto Alloca = Builder.CreateAlloca(Builder.getInt8Ty(), AI->getSize()); - Alloca->setAlignment(Align(AI->getAlignment())); + Alloca->setAlignment(AI->getAlignment()); for (auto U : AI->users()) { // Replace gets with the allocation. 
@@ -2279,7 +2317,10 @@ static void eliminateSwiftErrorArgument(Function &F, Argument &Arg, IRBuilder<> Builder(F.getEntryBlock().getFirstNonPHIOrDbg()); auto ArgTy = cast<PointerType>(Arg.getType()); - auto ValueTy = ArgTy->getPointerElementType(); + // swifterror arguments are required to have pointer-to-pointer type, + // so create a pointer-typed alloca with opaque pointers. + auto ValueTy = ArgTy->isOpaque() ? PointerType::getUnqual(F.getContext()) + : ArgTy->getNonOpaquePointerElementType(); // Reduce to the alloca case: @@ -2520,6 +2561,7 @@ void coro::salvageDebugInfo( bool SkipOutermostLoad = !isa<DbgValueInst>(DVI); Value *Storage = DVI->getVariableLocationOp(0); Value *OriginalStorage = Storage; + while (auto *Inst = dyn_cast_or_null<Instruction>(Storage)) { if (auto *LdInst = dyn_cast<LoadInst>(Inst)) { Storage = LdInst->getOperand(0); @@ -2559,7 +2601,7 @@ void coro::salvageDebugInfo( // // Avoid to create the alloca would be eliminated by optimization // passes and the corresponding dbg.declares would be invalid. - if (!OptimizeFrame && !EnableReuseStorageInFrame) + if (!OptimizeFrame) if (auto *Arg = dyn_cast<llvm::Argument>(Storage)) { auto &Cached = DbgPtrAllocaCache[Storage]; if (!Cached) { @@ -2575,14 +2617,15 @@ void coro::salvageDebugInfo( // expression, we need to add a DW_OP_deref at the *start* of the // expression to first load the contents of the alloca before // adjusting it with the expression. - if (Expr && Expr->isComplex()) - Expr = DIExpression::prepend(Expr, DIExpression::DerefBefore); + Expr = DIExpression::prepend(Expr, DIExpression::DerefBefore); } DVI->replaceVariableLocationOp(OriginalStorage, Storage); DVI->setExpression(Expr); - /// It makes no sense to move the dbg.value intrinsic. - if (!isa<DbgValueInst>(DVI)) { + // We only hoist dbg.declare today since it doesn't make sense to hoist + // dbg.value or dbg.addr since they do not have the same function wide + // guarantees that dbg.declare does. + if (!isa<DbgValueInst>(DVI) && !isa<DbgAddrIntrinsic>(DVI)) { if (auto *II = dyn_cast<InvokeInst>(Storage)) DVI->moveBefore(II->getNormalDest()->getFirstNonPHI()); else if (auto *CBI = dyn_cast<CallBrInst>(Storage)) @@ -2661,13 +2704,6 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { for (User *U : I.users()) if (Checker.isDefinitionAcrossSuspend(I, U)) Spills[&I].push_back(cast<Instruction>(U)); - - // Manually add dbg.value metadata uses of I. 
- SmallVector<DbgValueInst *, 16> DVIs; - findDbgValues(DVIs, &I); - for (auto *DVI : DVIs) - if (Checker.isDefinitionAcrossSuspend(I, DVI)) - Spills[&I].push_back(DVI); } if (Spills.empty()) @@ -2754,10 +2790,9 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { auto *V = Iter.first; SmallVector<DbgValueInst *, 16> DVIs; findDbgValues(DVIs, V); - llvm::for_each(DVIs, [&](DbgValueInst *DVI) { + for (DbgValueInst *DVI : DVIs) if (Checker.isDefinitionAcrossSuspend(*V, DVI)) FrameData.Spills[V].push_back(DVI); - }); } LLVM_DEBUG(dumpSpills("Spills", FrameData.Spills)); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInternal.h b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInternal.h index 9a17068df3a9..5557370c82ba 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInternal.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInternal.h @@ -13,7 +13,6 @@ #include "CoroInstr.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/Transforms/Coroutines.h" namespace llvm { @@ -21,40 +20,13 @@ class CallGraph; class CallGraphSCC; class PassRegistry; -void initializeCoroEarlyLegacyPass(PassRegistry &); -void initializeCoroSplitLegacyPass(PassRegistry &); -void initializeCoroElideLegacyPass(PassRegistry &); -void initializeCoroCleanupLegacyPass(PassRegistry &); - -// CoroEarly pass marks every function that has coro.begin with a string -// attribute "coroutine.presplit"="0". CoroSplit pass processes the coroutine -// twice. First, it lets it go through complete IPO optimization pipeline as a -// single function. It forces restart of the pipeline by inserting an indirect -// call to an empty function "coro.devirt.trigger" which is devirtualized by -// CoroElide pass that triggers a restart of the pipeline by CGPassManager. -// When CoroSplit pass sees the same coroutine the second time, it splits it up, -// adds coroutine subfunctions to the SCC to be processed by IPO pipeline. -// Async lowering similarily triggers a restart of the pipeline after it has -// split the coroutine. -// -// FIXME: Refactor these attributes as LLVM attributes instead of string -// attributes since these attributes are already used outside LLVM's -// coroutine module. -// FIXME: Remove these values once we remove the Legacy PM. -#define CORO_PRESPLIT_ATTR "coroutine.presplit" -#define UNPREPARED_FOR_SPLIT "0" -#define PREPARED_FOR_SPLIT "1" -#define ASYNC_RESTART_AFTER_SPLIT "2" - -#define CORO_DEVIRT_TRIGGER_FN "coro.devirt.trigger" - namespace coro { +bool declaresAnyIntrinsic(const Module &M); bool declaresIntrinsics(const Module &M, const std::initializer_list<StringRef>); void replaceCoroFree(CoroIdInst *CoroId, bool Elide); -void updateCallGraph(Function &Caller, ArrayRef<Function *> Funcs, - CallGraph &CG, CallGraphSCC &SCC); + /// Recover a dbg.declare prepared by the frontend and emit an alloca /// holding a pointer to the coroutine frame. void salvageDebugInfo( @@ -128,7 +100,7 @@ struct LLVM_LIBRARY_VISIBILITY Shape { StructType *FrameTy; Align FrameAlign; uint64_t FrameSize; - Instruction *FramePtr; + Value *FramePtr; BasicBlock *AllocaSpillBlock; /// This would only be true if optimization are enabled. 
@@ -210,10 +182,9 @@ struct LLVM_LIBRARY_VISIBILITY Shape { FunctionType *getResumeFunctionType() const { switch (ABI) { - case coro::ABI::Switch: { - auto *FnPtrTy = getSwitchResumePointerType(); - return cast<FunctionType>(FnPtrTy->getPointerElementType()); - } + case coro::ABI::Switch: + return FunctionType::get(Type::getVoidTy(FrameTy->getContext()), + FrameTy->getPointerTo(), /*IsVarArg*/false); case coro::ABI::Retcon: case coro::ABI::RetconOnce: return RetconLowering.ResumePrototype->getFunctionType(); @@ -267,6 +238,12 @@ struct LLVM_LIBRARY_VISIBILITY Shape { return nullptr; } + Instruction *getInsertPtAfterFramePtr() const { + if (auto *I = dyn_cast<Instruction>(FramePtr)) + return I->getNextNode(); + return &cast<Argument>(FramePtr)->getParent()->getEntryBlock().front(); + } + /// Allocate memory according to the rules of the active lowering. /// /// \param CG - if non-null, will be updated for the new call diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp index b5129809c6a6..ead552d9be4e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -22,15 +22,17 @@ #include "CoroInstr.h" #include "CoroInternal.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/PriorityWorklist.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CallGraph.h" -#include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/BinaryFormat/Dwarf.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -50,13 +52,10 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/IR/Verifier.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/PrettyStackTrace.h" @@ -869,11 +868,16 @@ void CoroCloner::create() { OrigF.getParent()->end(), ActiveSuspend); } - // Replace all args with undefs. The buildCoroutineFrame algorithm already - // rewritten access to the args that occurs after suspend points with loads - // and stores to/from the coroutine frame. - for (Argument &A : OrigF.args()) - VMap[&A] = UndefValue::get(A.getType()); + // Replace all args with dummy instructions. If an argument is the old frame + // pointer, the dummy will be replaced by the new frame pointer once it is + // computed below. Uses of all other arguments should have already been + // rewritten by buildCoroutineFrame() to use loads/stores on the coroutine + // frame. + SmallVector<Instruction *> DummyArgs; + for (Argument &A : OrigF.args()) { + DummyArgs.push_back(new FreezeInst(UndefValue::get(A.getType()))); + VMap[&A] = DummyArgs.back(); + } SmallVector<ReturnInst *, 4> Returns; @@ -923,6 +927,12 @@ void CoroCloner::create() { NewF->setVisibility(savedVisibility); NewF->setUnnamedAddr(savedUnnamedAddr); NewF->setDLLStorageClass(savedDLLStorageClass); + // The function sanitizer metadata needs to match the signature of the + // function it is being attached to. 
However this does not hold for split + // functions here. Thus remove the metadata for split functions. + if (Shape.ABI == coro::ABI::Switch && + NewF->hasMetadata(LLVMContext::MD_func_sanitize)) + NewF->eraseMetadata(LLVMContext::MD_func_sanitize); // Replace the attributes of the new function: auto OrigAttrs = NewF->getAttributes(); @@ -932,7 +942,8 @@ void CoroCloner::create() { case coro::ABI::Switch: // Bootstrap attributes by copying function attributes from the // original function. This should include optimization settings and so on. - NewAttrs = NewAttrs.addFnAttributes(Context, AttrBuilder(Context, OrigAttrs.getFnAttrs())); + NewAttrs = NewAttrs.addFnAttributes( + Context, AttrBuilder(Context, OrigAttrs.getFnAttrs())); addFramePointerAttrs(NewAttrs, Context, 0, Shape.FrameSize, Shape.FrameAlign); @@ -1013,7 +1024,15 @@ void CoroCloner::create() { auto *NewVFrame = Builder.CreateBitCast( NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame"); Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]); - OldVFrame->replaceAllUsesWith(NewVFrame); + if (OldVFrame != NewVFrame) + OldVFrame->replaceAllUsesWith(NewVFrame); + + // All uses of the arguments should have been resolved by this point, + // so we can safely remove the dummy values. + for (Instruction *DummyArg : DummyArgs) { + DummyArg->replaceAllUsesWith(UndefValue::get(DummyArg->getType())); + DummyArg->deleteValue(); + } switch (Shape.ABI) { case coro::ABI::Switch: @@ -1063,13 +1082,6 @@ static Function *createClone(Function &F, const Twine &Suffix, return Cloner.getFunction(); } -/// Remove calls to llvm.coro.end in the original function. -static void removeCoroEnds(const coro::Shape &Shape, CallGraph *CG) { - for (auto End : Shape.CoroEnds) { - replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, CG); - } -} - static void updateAsyncFuncPointerContextSize(coro::Shape &Shape) { assert(Shape.ABI == coro::ABI::Async); @@ -1150,7 +1162,8 @@ static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn, Function *DestroyFn, Function *CleanupFn) { assert(Shape.ABI == coro::ABI::Switch); - IRBuilder<> Builder(Shape.FramePtr->getNextNode()); + IRBuilder<> Builder(Shape.getInsertPtAfterFramePtr()); + auto *ResumeAddr = Builder.CreateStructGEP( Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Resume, "resume.addr"); @@ -1559,7 +1572,8 @@ static void simplifySuspendPoints(coro::Shape &Shape) { } static void splitSwitchCoroutine(Function &F, coro::Shape &Shape, - SmallVectorImpl<Function *> &Clones) { + SmallVectorImpl<Function *> &Clones, + TargetTransformInfo &TTI) { assert(Shape.ABI == coro::ABI::Switch); createResumeEntryBlock(F, Shape); @@ -1574,7 +1588,13 @@ static void splitSwitchCoroutine(Function &F, coro::Shape &Shape, postSplitCleanup(*DestroyClone); postSplitCleanup(*CleanupClone); - addMustTailToCoroResumes(*ResumeClone); + // Adding musttail call to support symmetric transfer. + // Skip targets which don't support tail call. + // + // FIXME: Could we support symmetric transfer effectively without musttail + // call? + if (TTI.supportsTailCalls()) + addMustTailToCoroResumes(*ResumeClone); // Store addresses resume/destroy/cleanup functions in the coroutine frame. updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone); @@ -1661,7 +1681,7 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape, // Map all uses of llvm.coro.begin to the allocated frame pointer. { // Make sure we don't invalidate Shape.FramePtr. 
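// [Editor's note: not part of this patch.] Shape.FramePtr is now declared as
// a plain Value * (see the CoroInternal.h hunk above), because the frame
// pointer is no longer guaranteed to be an Instruction; the new
// getInsertPtAfterFramePtr() handles the Argument case explicitly. The
// tracking handle below is therefore widened from TrackingVH<Instruction> to
// TrackingVH<Value>, which keeps Shape.FramePtr valid across the
// replaceAllUsesWith for either kind of value.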
- TrackingVH<Instruction> Handle(Shape.FramePtr); + TrackingVH<Value> Handle(Shape.FramePtr); Shape.CoroBegin->replaceAllUsesWith(FramePtr); Shape.FramePtr = Handle.getValPtr(); } @@ -1773,7 +1793,7 @@ static void splitRetconCoroutine(Function &F, coro::Shape &Shape, // Map all uses of llvm.coro.begin to the allocated frame pointer. { // Make sure we don't invalidate Shape.FramePtr. - TrackingVH<Instruction> Handle(Shape.FramePtr); + TrackingVH<Value> Handle(Shape.FramePtr); Shape.CoroBegin->replaceAllUsesWith(RawFramePtr); Shape.FramePtr = Handle.getValPtr(); } @@ -1879,6 +1899,7 @@ namespace { static coro::Shape splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones, + TargetTransformInfo &TTI, bool OptimizeFrame) { PrettyStackTraceFunction prettyStackTrace(F); @@ -1901,7 +1922,7 @@ static coro::Shape splitCoroutine(Function &F, } else { switch (Shape.ABI) { case coro::ABI::Switch: - splitSwitchCoroutine(F, Shape, Clones); + splitSwitchCoroutine(F, Shape, Clones, TTI); break; case coro::ABI::Async: splitAsyncCoroutine(F, Shape, Clones); @@ -1917,21 +1938,27 @@ static coro::Shape splitCoroutine(Function &F, // This invalidates SwiftErrorOps in the Shape. replaceSwiftErrorOps(F, Shape, nullptr); - return Shape; -} - -static void -updateCallGraphAfterCoroutineSplit(Function &F, const coro::Shape &Shape, - const SmallVectorImpl<Function *> &Clones, - CallGraph &CG, CallGraphSCC &SCC) { - if (!Shape.CoroBegin) - return; - - removeCoroEnds(Shape, &CG); - postSplitCleanup(F); + // Finally, salvage the llvm.dbg.{declare,addr} in our original function that + // point into the coroutine frame. We only do this for the current function + // since the Cloner salvaged debug info for us in the new coroutine funclets. + SmallVector<DbgVariableIntrinsic *, 8> Worklist; + SmallDenseMap<llvm::Value *, llvm::AllocaInst *, 4> DbgPtrAllocaCache; + for (auto &BB : F) { + for (auto &I : BB) { + if (auto *DDI = dyn_cast<DbgDeclareInst>(&I)) { + Worklist.push_back(DDI); + continue; + } + if (auto *DDI = dyn_cast<DbgAddrIntrinsic>(&I)) { + Worklist.push_back(DDI); + continue; + } + } + } + for (auto *DDI : Worklist) + coro::salvageDebugInfo(DbgPtrAllocaCache, DDI, Shape.OptimizeFrame); - // Update call graph and add the functions we created to the SCC. - coro::updateCallGraph(F, Clones, CG, SCC); + return Shape; } static void updateCallGraphAfterCoroutineSplit( @@ -1976,70 +2003,6 @@ static void updateCallGraphAfterCoroutineSplit( updateCGAndAnalysisManagerForFunctionPass(CG, C, N, AM, UR, FAM); } -// When we see the coroutine the first time, we insert an indirect call to a -// devirt trigger function and mark the coroutine that it is now ready for -// split. -// Async lowering uses this after it has split the function to restart the -// pipeline. -static void prepareForSplit(Function &F, CallGraph &CG, - bool MarkForAsyncRestart = false) { - Module &M = *F.getParent(); - LLVMContext &Context = F.getContext(); -#ifndef NDEBUG - Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN); - assert(DevirtFn && "coro.devirt.trigger function not found"); -#endif - - F.addFnAttr(CORO_PRESPLIT_ATTR, MarkForAsyncRestart - ? ASYNC_RESTART_AFTER_SPLIT - : PREPARED_FOR_SPLIT); - - // Insert an indirect call sequence that will be devirtualized by CoroElide - // pass: - // %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1) - // %1 = bitcast i8* %0 to void(i8*)* - // call void %1(i8* null) - coro::LowererBase Lowerer(M); - Instruction *InsertPt = - MarkForAsyncRestart ? 
F.getEntryBlock().getFirstNonPHIOrDbgOrLifetime() - : F.getEntryBlock().getTerminator(); - auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(Context)); - auto *DevirtFnAddr = - Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt); - FunctionType *FnTy = FunctionType::get(Type::getVoidTy(Context), - {Type::getInt8PtrTy(Context)}, false); - auto *IndirectCall = CallInst::Create(FnTy, DevirtFnAddr, Null, "", InsertPt); - - // Update CG graph with an indirect call we just added. - CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode()); -} - -// Make sure that there is a devirtualization trigger function that the -// coro-split pass uses to force a restart of the CGSCC pipeline. If the devirt -// trigger function is not found, we will create one and add it to the current -// SCC. -static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) { - Module &M = CG.getModule(); - if (M.getFunction(CORO_DEVIRT_TRIGGER_FN)) - return; - - LLVMContext &C = M.getContext(); - auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C), - /*isVarArg=*/false); - Function *DevirtFn = - Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage, - CORO_DEVIRT_TRIGGER_FN, &M); - DevirtFn->addFnAttr(Attribute::AlwaysInline); - auto *Entry = BasicBlock::Create(C, "entry", DevirtFn); - ReturnInst::Create(C, Entry); - - auto *Node = CG.getOrInsertFunction(DevirtFn); - - SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end()); - Nodes.push_back(Node); - SCC.initialize(Nodes); -} - /// Replace a call to llvm.coro.prepare.retcon. static void replacePrepare(CallInst *Prepare, LazyCallGraph &CG, LazyCallGraph::SCC &C) { @@ -2076,59 +2039,6 @@ static void replacePrepare(CallInst *Prepare, LazyCallGraph &CG, Cast->eraseFromParent(); } } -/// Replace a call to llvm.coro.prepare.retcon. -static void replacePrepare(CallInst *Prepare, CallGraph &CG) { - auto CastFn = Prepare->getArgOperand(0); // as an i8* - auto Fn = CastFn->stripPointerCasts(); // as its original type - - // Find call graph nodes for the preparation. - CallGraphNode *PrepareUserNode = nullptr, *FnNode = nullptr; - if (auto ConcreteFn = dyn_cast<Function>(Fn)) { - PrepareUserNode = CG[Prepare->getFunction()]; - FnNode = CG[ConcreteFn]; - } - - // Attempt to peephole this pattern: - // %0 = bitcast [[TYPE]] @some_function to i8* - // %1 = call @llvm.coro.prepare.retcon(i8* %0) - // %2 = bitcast %1 to [[TYPE]] - // ==> - // %2 = @some_function - for (Use &U : llvm::make_early_inc_range(Prepare->uses())) { - // Look for bitcasts back to the original function type. - auto *Cast = dyn_cast<BitCastInst>(U.getUser()); - if (!Cast || Cast->getType() != Fn->getType()) continue; - - // Check whether the replacement will introduce new direct calls. - // If so, we'll need to update the call graph. - if (PrepareUserNode) { - for (auto &Use : Cast->uses()) { - if (auto *CB = dyn_cast<CallBase>(Use.getUser())) { - if (!CB->isCallee(&Use)) - continue; - PrepareUserNode->removeCallEdgeFor(*CB); - PrepareUserNode->addCalledFunction(CB, FnNode); - } - } - } - - // Replace and remove the cast. - Cast->replaceAllUsesWith(Fn); - Cast->eraseFromParent(); - } - - // Replace any remaining uses with the function as an i8*. - // This can never directly be a callee, so we don't need to update CG. - Prepare->replaceAllUsesWith(CastFn); - Prepare->eraseFromParent(); - - // Kill dead bitcasts. 
- while (auto *Cast = dyn_cast<BitCastInst>(CastFn)) { - if (!Cast->use_empty()) break; - CastFn = Cast->getOperand(0); - Cast->eraseFromParent(); - } -} static bool replaceAllPrepares(Function *PrepareFn, LazyCallGraph &CG, LazyCallGraph::SCC &C) { @@ -2143,30 +2053,6 @@ static bool replaceAllPrepares(Function *PrepareFn, LazyCallGraph &CG, return Changed; } -/// Remove calls to llvm.coro.prepare.retcon, a barrier meant to prevent -/// IPO from operating on calls to a retcon coroutine before it's been -/// split. This is only safe to do after we've split all retcon -/// coroutines in the module. We can do that this in this pass because -/// this pass does promise to split all retcon coroutines (as opposed to -/// switch coroutines, which are lowered in multiple stages). -static bool replaceAllPrepares(Function *PrepareFn, CallGraph &CG) { - bool Changed = false; - for (Use &P : llvm::make_early_inc_range(PrepareFn->uses())) { - // Intrinsics can only be used in calls. - auto *Prepare = cast<CallInst>(P.getUser()); - replacePrepare(Prepare, CG); - Changed = true; - } - - return Changed; -} - -static bool declaresCoroSplitIntrinsics(const Module &M) { - return coro::declaresIntrinsics(M, {"llvm.coro.begin", - "llvm.coro.prepare.retcon", - "llvm.coro.prepare.async"}); -} - static void addPrepareFunction(const Module &M, SmallVectorImpl<Function *> &Fns, StringRef Name) { @@ -2185,18 +2071,15 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, auto &FAM = AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); - if (!declaresCoroSplitIntrinsics(M)) - return PreservedAnalyses::all(); - // Check for uses of llvm.coro.prepare.retcon/async. SmallVector<Function *, 2> PrepareFns; addPrepareFunction(M, PrepareFns, "llvm.coro.prepare.retcon"); addPrepareFunction(M, PrepareFns, "llvm.coro.prepare.async"); // Find coroutines for processing. - SmallVector<LazyCallGraph::Node *, 4> Coroutines; + SmallVector<LazyCallGraph::Node *> Coroutines; for (LazyCallGraph::Node &N : C) - if (N.getFunction().hasFnAttribute(CORO_PRESPLIT_ATTR)) + if (N.getFunction().isPresplitCoroutine()) Coroutines.push_back(&N); if (Coroutines.empty() && PrepareFns.empty()) @@ -2212,13 +2095,12 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, for (LazyCallGraph::Node *N : Coroutines) { Function &F = N->getFunction(); LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F.getName() - << "' state: " - << F.getFnAttribute(CORO_PRESPLIT_ATTR).getValueAsString() << "\n"); - F.removeFnAttr(CORO_PRESPLIT_ATTR); + F.setSplittedCoroutine(); SmallVector<Function *, 4> Clones; - const coro::Shape Shape = splitCoroutine(F, Clones, OptimizeFrame); + const coro::Shape Shape = splitCoroutine( + F, Clones, FAM.getResult<TargetIRAnalysis>(F), OptimizeFrame); updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM); if (!Shape.CoroSuspends.empty()) { @@ -2237,122 +2119,3 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, return PreservedAnalyses::none(); } - -namespace { - -// We present a coroutine to LLVM as an ordinary function with suspension -// points marked up with intrinsics. We let the optimizer party on the coroutine -// as a single function for as long as possible. 
Shortly before the coroutine is -// eligible to be inlined into its callers, we split up the coroutine into parts -// corresponding to initial, resume and destroy invocations of the coroutine, -// add them to the current SCC and restart the IPO pipeline to optimize the -// coroutine subfunctions we extracted before proceeding to the caller of the -// coroutine. -struct CoroSplitLegacy : public CallGraphSCCPass { - static char ID; // Pass identification, replacement for typeid - - CoroSplitLegacy(bool OptimizeFrame = false) - : CallGraphSCCPass(ID), OptimizeFrame(OptimizeFrame) { - initializeCoroSplitLegacyPass(*PassRegistry::getPassRegistry()); - } - - bool Run = false; - bool OptimizeFrame; - - // A coroutine is identified by the presence of coro.begin intrinsic, if - // we don't have any, this pass has nothing to do. - bool doInitialization(CallGraph &CG) override { - Run = declaresCoroSplitIntrinsics(CG.getModule()); - return CallGraphSCCPass::doInitialization(CG); - } - - bool runOnSCC(CallGraphSCC &SCC) override { - if (!Run) - return false; - - // Check for uses of llvm.coro.prepare.retcon. - SmallVector<Function *, 2> PrepareFns; - auto &M = SCC.getCallGraph().getModule(); - addPrepareFunction(M, PrepareFns, "llvm.coro.prepare.retcon"); - addPrepareFunction(M, PrepareFns, "llvm.coro.prepare.async"); - - // Find coroutines for processing. - SmallVector<Function *, 4> Coroutines; - for (CallGraphNode *CGN : SCC) - if (auto *F = CGN->getFunction()) - if (F->hasFnAttribute(CORO_PRESPLIT_ATTR)) - Coroutines.push_back(F); - - if (Coroutines.empty() && PrepareFns.empty()) - return false; - - CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); - - if (Coroutines.empty()) { - bool Changed = false; - for (auto *PrepareFn : PrepareFns) - Changed |= replaceAllPrepares(PrepareFn, CG); - return Changed; - } - - createDevirtTriggerFunc(CG, SCC); - - // Split all the coroutines. - for (Function *F : Coroutines) { - Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR); - StringRef Value = Attr.getValueAsString(); - LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName() - << "' state: " << Value << "\n"); - // Async lowering marks coroutines to trigger a restart of the pipeline - // after it has split them. - if (Value == ASYNC_RESTART_AFTER_SPLIT) { - F->removeFnAttr(CORO_PRESPLIT_ATTR); - continue; - } - if (Value == UNPREPARED_FOR_SPLIT) { - prepareForSplit(*F, CG); - continue; - } - F->removeFnAttr(CORO_PRESPLIT_ATTR); - - SmallVector<Function *, 4> Clones; - const coro::Shape Shape = splitCoroutine(*F, Clones, OptimizeFrame); - updateCallGraphAfterCoroutineSplit(*F, Shape, Clones, CG, SCC); - if (Shape.ABI == coro::ABI::Async) { - // Restart SCC passes. - // Mark function for CoroElide pass. It will devirtualize causing a - // restart of the SCC pipeline. 
- prepareForSplit(*F, CG, true /*MarkForAsyncRestart*/); - } - } - - for (auto *PrepareFn : PrepareFns) - replaceAllPrepares(PrepareFn, CG); - - return true; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - CallGraphSCCPass::getAnalysisUsage(AU); - } - - StringRef getPassName() const override { return "Coroutine Splitting"; } -}; - -} // end anonymous namespace - -char CoroSplitLegacy::ID = 0; - -INITIALIZE_PASS_BEGIN( - CoroSplitLegacy, "coro-split", - "Split coroutine into a set of functions driving its state machine", false, - false) -INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_END( - CoroSplitLegacy, "coro-split", - "Split coroutine into a set of functions driving its state machine", false, - false) - -Pass *llvm::createCoroSplitLegacyPass(bool OptimizeFrame) { - return new CoroSplitLegacy(OptimizeFrame); -} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp index 965a146c143f..1742e9319c3b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -10,14 +10,11 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Coroutines.h" #include "CoroInstr.h" #include "CoroInternal.h" -#include "llvm-c/Transforms/Coroutines.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/CallGraph.h" -#include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" @@ -26,14 +23,10 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" -#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" -#include "llvm/InitializePasses.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Transforms/IPO.h" -#include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstddef> @@ -41,55 +34,6 @@ using namespace llvm; -void llvm::initializeCoroutines(PassRegistry &Registry) { - initializeCoroEarlyLegacyPass(Registry); - initializeCoroSplitLegacyPass(Registry); - initializeCoroElideLegacyPass(Registry); - initializeCoroCleanupLegacyPass(Registry); -} - -static void addCoroutineOpt0Passes(const PassManagerBuilder &Builder, - legacy::PassManagerBase &PM) { - PM.add(createCoroSplitLegacyPass()); - PM.add(createCoroElideLegacyPass()); - - PM.add(createBarrierNoopPass()); - PM.add(createCoroCleanupLegacyPass()); -} - -static void addCoroutineEarlyPasses(const PassManagerBuilder &Builder, - legacy::PassManagerBase &PM) { - PM.add(createCoroEarlyLegacyPass()); -} - -static void addCoroutineScalarOptimizerPasses(const PassManagerBuilder &Builder, - legacy::PassManagerBase &PM) { - PM.add(createCoroElideLegacyPass()); -} - -static void addCoroutineSCCPasses(const PassManagerBuilder &Builder, - legacy::PassManagerBase &PM) { - PM.add(createCoroSplitLegacyPass(Builder.OptLevel != 0)); -} - -static void addCoroutineOptimizerLastPasses(const PassManagerBuilder &Builder, - legacy::PassManagerBase &PM) { - PM.add(createCoroCleanupLegacyPass()); -} - -void llvm::addCoroutinePassesToExtensionPoints(PassManagerBuilder &Builder) { - Builder.addExtension(PassManagerBuilder::EP_EarlyAsPossible, - addCoroutineEarlyPasses); - 
Builder.addExtension(PassManagerBuilder::EP_EnabledOnOptLevel0, - addCoroutineOpt0Passes); - Builder.addExtension(PassManagerBuilder::EP_CGSCCOptimizerLate, - addCoroutineSCCPasses); - Builder.addExtension(PassManagerBuilder::EP_ScalarOptimizerLate, - addCoroutineScalarOptimizerPasses); - Builder.addExtension(PassManagerBuilder::EP_OptimizerLast, - addCoroutineOptimizerLastPasses); -} - // Construct the lowerer base class and initialize its members. coro::LowererBase::LowererBase(Module &M) : TheModule(M), Context(M.getContext()), @@ -119,44 +63,55 @@ Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index, return Bitcast; } +// NOTE: Must be sorted! +static const char *const CoroIntrinsics[] = { + "llvm.coro.align", + "llvm.coro.alloc", + "llvm.coro.async.context.alloc", + "llvm.coro.async.context.dealloc", + "llvm.coro.async.resume", + "llvm.coro.async.size.replace", + "llvm.coro.async.store_resume", + "llvm.coro.begin", + "llvm.coro.destroy", + "llvm.coro.done", + "llvm.coro.end", + "llvm.coro.end.async", + "llvm.coro.frame", + "llvm.coro.free", + "llvm.coro.id", + "llvm.coro.id.async", + "llvm.coro.id.retcon", + "llvm.coro.id.retcon.once", + "llvm.coro.noop", + "llvm.coro.prepare.async", + "llvm.coro.prepare.retcon", + "llvm.coro.promise", + "llvm.coro.resume", + "llvm.coro.save", + "llvm.coro.size", + "llvm.coro.subfn.addr", + "llvm.coro.suspend", + "llvm.coro.suspend.async", + "llvm.coro.suspend.retcon", +}; + #ifndef NDEBUG static bool isCoroutineIntrinsicName(StringRef Name) { - // NOTE: Must be sorted! - static const char *const CoroIntrinsics[] = { - "llvm.coro.align", - "llvm.coro.alloc", - "llvm.coro.async.context.alloc", - "llvm.coro.async.context.dealloc", - "llvm.coro.async.resume", - "llvm.coro.async.size.replace", - "llvm.coro.async.store_resume", - "llvm.coro.begin", - "llvm.coro.destroy", - "llvm.coro.done", - "llvm.coro.end", - "llvm.coro.end.async", - "llvm.coro.frame", - "llvm.coro.free", - "llvm.coro.id", - "llvm.coro.id.async", - "llvm.coro.id.retcon", - "llvm.coro.id.retcon.once", - "llvm.coro.noop", - "llvm.coro.prepare.async", - "llvm.coro.prepare.retcon", - "llvm.coro.promise", - "llvm.coro.resume", - "llvm.coro.save", - "llvm.coro.size", - "llvm.coro.subfn.addr", - "llvm.coro.suspend", - "llvm.coro.suspend.async", - "llvm.coro.suspend.retcon", - }; return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1; } #endif +bool coro::declaresAnyIntrinsic(const Module &M) { + for (StringRef Name : CoroIntrinsics) { + assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic"); + if (M.getNamedValue(Name)) + return true; + } + + return false; +} + // Verifies if a module has named values listed. Also, in debug mode verifies // that names are intrinsic names. bool coro::declaresIntrinsics(const Module &M, @@ -191,46 +146,6 @@ void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) { } } -// FIXME: This code is stolen from CallGraph::addToCallGraph(Function *F), which -// happens to be private. It is better for this functionality exposed by the -// CallGraph. -static void buildCGN(CallGraph &CG, CallGraphNode *Node) { - Function *F = Node->getFunction(); - - // Look for calls by this function. - for (Instruction &I : instructions(F)) - if (auto *Call = dyn_cast<CallBase>(&I)) { - const Function *Callee = Call->getCalledFunction(); - if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID())) - // Indirect calls of intrinsics are not allowed so no need to check. 
- // We can be more precise here by using TargetArg returned by - // Intrinsic::isLeaf. - Node->addCalledFunction(Call, CG.getCallsExternalNode()); - else if (!Callee->isIntrinsic()) - Node->addCalledFunction(Call, CG.getOrInsertFunction(Callee)); - } -} - -// Rebuild CGN after we extracted parts of the code from ParentFunc into -// NewFuncs. Builds CGNs for the NewFuncs and adds them to the current SCC. -void coro::updateCallGraph(Function &ParentFunc, ArrayRef<Function *> NewFuncs, - CallGraph &CG, CallGraphSCC &SCC) { - // Rebuild CGN from scratch for the ParentFunc - auto *ParentNode = CG[&ParentFunc]; - ParentNode->removeAllCalledFunctions(); - buildCGN(CG, ParentNode); - - SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end()); - - for (Function *F : NewFuncs) { - CallGraphNode *Callee = CG.getOrInsertFunction(F); - Nodes.push_back(Callee); - buildCGN(CG, Callee); - } - - SCC.initialize(Nodes); -} - static void clear(coro::Shape &Shape) { Shape.CoroBegin = nullptr; Shape.CoroEnds.clear(); @@ -735,25 +650,3 @@ void CoroAsyncEndInst::checkWellFormed() const { "match the tail arguments", MustTailCallFunc); } - -void LLVMAddCoroEarlyPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createCoroEarlyLegacyPass()); -} - -void LLVMAddCoroSplitPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createCoroSplitLegacyPass()); -} - -void LLVMAddCoroElidePass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createCoroElideLegacyPass()); -} - -void LLVMAddCoroCleanupPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createCoroCleanupLegacyPass()); -} - -void -LLVMPassManagerBuilderAddCoroutinePassesToExtensionPoints(LLVMPassManagerBuilderRef PMB) { - PassManagerBuilder *Builder = unwrap(PMB); - addCoroutinePassesToExtensionPoints(*Builder); -} diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp index a6d9ce1033f3..58cea7ebb749 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp @@ -1,4 +1,4 @@ -//===- InlineAlways.cpp - Code to inline always_inline functions ----------===// +//===- AlwaysInliner.cpp - Code to inline always_inline functions ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
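// [Editor's note: not part of this patch.] Summary of the AlwaysInliner.cpp
// hunks that follow: a call site carrying both alwaysinline and noinline is
// no longer force-inlined (both the new-PM pass and the legacy getInlineCost
// path now check for the noinline call-site attribute), and a mandatory
// inline that fails emits an OptimizationRemarkMissed with the failure
// reason, roughly "'callee' is not inlined into 'caller': <reason>", instead
// of asserting that inlining always succeeds.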
@@ -16,15 +16,10 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InlineCost.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" -#include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/IR/CallingConv.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" #include "llvm/InitializePasses.h" -#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/Inliner.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -60,31 +55,38 @@ PreservedAnalyses AlwaysInlinerPass::run(Module &M, for (User *U : F.users()) if (auto *CB = dyn_cast<CallBase>(U)) if (CB->getCalledFunction() == &F && - CB->hasFnAttr(Attribute::AlwaysInline)) - Calls.insert(CB); + CB->hasFnAttr(Attribute::AlwaysInline) && + !CB->getAttributes().hasFnAttr(Attribute::NoInline)) + Calls.insert(CB); for (CallBase *CB : Calls) { Function *Caller = CB->getCaller(); OptimizationRemarkEmitter ORE(Caller); - auto OIC = shouldInline( - *CB, - [&](CallBase &CB) { - return InlineCost::getAlways("always inline attribute"); - }, - ORE); - assert(OIC); - emitInlinedIntoBasedOnCost(ORE, CB->getDebugLoc(), CB->getParent(), F, - *Caller, *OIC, false, DEBUG_TYPE); + DebugLoc DLoc = CB->getDebugLoc(); + BasicBlock *Block = CB->getParent(); InlineFunctionInfo IFI( /*cg=*/nullptr, GetAssumptionCache, &PSI, - &FAM.getResult<BlockFrequencyAnalysis>(*(CB->getCaller())), + &FAM.getResult<BlockFrequencyAnalysis>(*Caller), &FAM.getResult<BlockFrequencyAnalysis>(F)); InlineResult Res = InlineFunction( *CB, IFI, &FAM.getResult<AAManager>(F), InsertLifetime); - assert(Res.isSuccess() && "unexpected failure to inline"); - (void)Res; + if (!Res.isSuccess()) { + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, + Block) + << "'" << ore::NV("Callee", &F) << "' is not inlined into '" + << ore::NV("Caller", Caller) + << "': " << ore::NV("Reason", Res.getFailureReason()); + }); + continue; + } + + emitInlinedIntoBasedOnCost( + ORE, DLoc, Block, F, *Caller, + InlineCost::getAlways("always inline attribute"), + /*ForProfileContext=*/false, DEBUG_TYPE); // Merge the attributes based on the inlining. 
AttributeFuncs::mergeAttributesForInlining(*Caller, F); @@ -210,6 +212,9 @@ InlineCost AlwaysInlinerLegacyPass::getInlineCost(CallBase &CB) { if (!CB.hasFnAttr(Attribute::AlwaysInline)) return InlineCost::getNever("no alwaysinline attribute"); + if (Callee->hasFnAttribute(Attribute::AlwaysInline) && CB.isNoInline()) + return InlineCost::getNever("noinline call site attribute"); + auto IsViable = isInlineViable(*Callee); if (!IsViable.isSuccess()) return InlineCost::getNever(IsViable.getFailureReason()); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp index e6a542385662..62cfc3294968 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -29,9 +29,8 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/ArgumentPromotion.h" + #include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" @@ -40,15 +39,11 @@ #include "llvm/ADT/Twine.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" -#include "llvm/Analysis/CGSCCPassManager.h" #include "llvm/Analysis/CallGraph.h" -#include "llvm/Analysis/CallGraphSCCPass.h" -#include "llvm/Analysis/LazyCallGraph.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryLocation.h" -#include "llvm/Analysis/ValueTracking.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -56,33 +51,26 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Metadata.h" -#include "llvm/IR/Module.h" #include "llvm/IR/NoFolder.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> #include <cassert> #include <cstdint> -#include <functional> -#include <iterator> -#include <map> -#include <set> #include <utility> #include <vector> @@ -91,43 +79,81 @@ using namespace llvm; #define DEBUG_TYPE "argpromotion" STATISTIC(NumArgumentsPromoted, "Number of pointer arguments promoted"); -STATISTIC(NumAggregatesPromoted, "Number of aggregate arguments promoted"); -STATISTIC(NumByValArgsPromoted, "Number of byval arguments promoted"); STATISTIC(NumArgumentsDead, "Number of dead pointer args eliminated"); -/// A vector used to hold the indices of a single GEP instruction -using IndicesVector = std::vector<uint64_t>; +namespace { + +struct ArgPart { + Type *Ty; + Align Alignment; + /// A representative guaranteed-executed load or store instruction for use by + /// metadata transfer. 
+ Instruction *MustExecInstr; +}; + +using OffsetAndArgPart = std::pair<int64_t, ArgPart>; + +} // end anonymous namespace + +static Value *createByteGEP(IRBuilderBase &IRB, const DataLayout &DL, + Value *Ptr, Type *ResElemTy, int64_t Offset) { + // For non-opaque pointers, try to create a "nice" GEP if possible, otherwise + // fall back to an i8 GEP to a specific offset. + unsigned AddrSpace = Ptr->getType()->getPointerAddressSpace(); + APInt OrigOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), Offset); + if (!Ptr->getType()->isOpaquePointerTy()) { + Type *OrigElemTy = Ptr->getType()->getNonOpaquePointerElementType(); + if (OrigOffset == 0 && OrigElemTy == ResElemTy) + return Ptr; + + if (OrigElemTy->isSized()) { + APInt TmpOffset = OrigOffset; + Type *TmpTy = OrigElemTy; + SmallVector<APInt> IntIndices = + DL.getGEPIndicesForOffset(TmpTy, TmpOffset); + if (TmpOffset == 0) { + // Try to add trailing zero indices to reach the right type. + while (TmpTy != ResElemTy) { + Type *NextTy = GetElementPtrInst::getTypeAtIndex(TmpTy, (uint64_t)0); + if (!NextTy) + break; + + IntIndices.push_back(APInt::getZero( + isa<StructType>(TmpTy) ? 32 : OrigOffset.getBitWidth())); + TmpTy = NextTy; + } + + SmallVector<Value *> Indices; + for (const APInt &Index : IntIndices) + Indices.push_back(IRB.getInt(Index)); + + if (OrigOffset != 0 || TmpTy == ResElemTy) { + Ptr = IRB.CreateGEP(OrigElemTy, Ptr, Indices); + return IRB.CreateBitCast(Ptr, ResElemTy->getPointerTo(AddrSpace)); + } + } + } + } + + if (OrigOffset != 0) { + Ptr = IRB.CreateBitCast(Ptr, IRB.getInt8PtrTy(AddrSpace)); + Ptr = IRB.CreateGEP(IRB.getInt8Ty(), Ptr, IRB.getInt(OrigOffset)); + } + return IRB.CreateBitCast(Ptr, ResElemTy->getPointerTo(AddrSpace)); +} /// DoPromotion - This method actually performs the promotion of the specified /// arguments, and returns the new function. At this point, we know that it's /// safe to do so. static Function * -doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, - SmallPtrSetImpl<Argument *> &ByValArgsToTransform, - Optional<function_ref<void(CallBase &OldCS, CallBase &NewCS)>> - ReplaceCallSite) { +doPromotion(Function *F, FunctionAnalysisManager &FAM, + const DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>> + &ArgsToPromote) { // Start by computing a new prototype for the function, which is the same as // the old function, but has modified arguments. FunctionType *FTy = F->getFunctionType(); std::vector<Type *> Params; - using ScalarizeTable = std::set<std::pair<Type *, IndicesVector>>; - - // ScalarizedElements - If we are promoting a pointer that has elements - // accessed out of it, keep track of which elements are accessed so that we - // can add one argument for each. - // - // Arguments that are directly loaded will have a zero element value here, to - // handle cases where there are both a direct load and GEP accesses. - std::map<Argument *, ScalarizeTable> ScalarizedElements; - - // OriginalLoads - Keep track of a representative load instruction from the - // original function so that we can tell the alias analysis implementation - // what the new GEP/Load instructions we are inserting look like. - // We need to keep the original loads for each argument and the elements - // of the argument that are accessed. - std::map<std::pair<Argument *, IndicesVector>, LoadInst *> OriginalLoads; - // Attribute - Keep track of the parameter attributes for the arguments // that we are *not* promoting. 
For the ones that we do promote, the parameter // attributes are lost @@ -138,15 +164,7 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, unsigned ArgNo = 0; for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; ++I, ++ArgNo) { - if (ByValArgsToTransform.count(&*I)) { - // Simple byval argument? Just add all the struct element types. - Type *AgTy = I->getParamByValType(); - StructType *STy = cast<StructType>(AgTy); - llvm::append_range(Params, STy->elements()); - ArgAttrVec.insert(ArgAttrVec.end(), STy->getNumElements(), - AttributeSet()); - ++NumByValArgsPromoted; - } else if (!ArgsToPromote.count(&*I)) { + if (!ArgsToPromote.count(&*I)) { // Unchanged argument Params.push_back(I->getType()); ArgAttrVec.push_back(PAL.getParamAttrs(ArgNo)); @@ -154,58 +172,12 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, // Dead argument (which are always marked as promotable) ++NumArgumentsDead; } else { - // Okay, this is being promoted. This means that the only uses are loads - // or GEPs which are only used by loads - - // In this table, we will track which indices are loaded from the argument - // (where direct loads are tracked as no indices). - ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; - for (User *U : make_early_inc_range(I->users())) { - Instruction *UI = cast<Instruction>(U); - Type *SrcTy; - if (LoadInst *L = dyn_cast<LoadInst>(UI)) - SrcTy = L->getType(); - else - SrcTy = cast<GetElementPtrInst>(UI)->getSourceElementType(); - // Skip dead GEPs and remove them. - if (isa<GetElementPtrInst>(UI) && UI->use_empty()) { - UI->eraseFromParent(); - continue; - } - - IndicesVector Indices; - Indices.reserve(UI->getNumOperands() - 1); - // Since loads will only have a single operand, and GEPs only a single - // non-index operand, this will record direct loads without any indices, - // and gep+loads with the GEP indices. - for (const Use &I : llvm::drop_begin(UI->operands())) - Indices.push_back(cast<ConstantInt>(I)->getSExtValue()); - // GEPs with a single 0 index can be merged with direct loads - if (Indices.size() == 1 && Indices.front() == 0) - Indices.clear(); - ArgIndices.insert(std::make_pair(SrcTy, Indices)); - LoadInst *OrigLoad; - if (LoadInst *L = dyn_cast<LoadInst>(UI)) - OrigLoad = L; - else - // Take any load, we will use it only to update Alias Analysis - OrigLoad = cast<LoadInst>(UI->user_back()); - OriginalLoads[std::make_pair(&*I, Indices)] = OrigLoad; - } - - // Add a parameter to the function for each element passed in. - for (const auto &ArgIndex : ArgIndices) { - // not allowed to dereference ->begin() if size() is 0 - Params.push_back(GetElementPtrInst::getIndexedType( - I->getType()->getPointerElementType(), ArgIndex.second)); + const auto &ArgParts = ArgsToPromote.find(&*I)->second; + for (const auto &Pair : ArgParts) { + Params.push_back(Pair.second.Ty); ArgAttrVec.push_back(AttributeSet()); - assert(Params.back()); } - - if (ArgIndices.size() == 1 && ArgIndices.begin()->second.empty()) - ++NumArgumentsPromoted; - else - ++NumAggregatesPromoted; + ++NumArgumentsPromoted; } } @@ -222,24 +194,30 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, // The new function will have the !dbg metadata copied from the original // function. The original function may not be deleted, and dbg metadata need - // to be unique so we need to drop it. + // to be unique, so we need to drop it. 
F->setSubprogram(nullptr); LLVM_DEBUG(dbgs() << "ARG PROMOTION: Promoting to:" << *NF << "\n" << "From: " << *F); + uint64_t LargestVectorWidth = 0; + for (auto *I : Params) + if (auto *VT = dyn_cast<llvm::VectorType>(I)) + LargestVectorWidth = std::max( + LargestVectorWidth, VT->getPrimitiveSizeInBits().getKnownMinSize()); + // Recompute the parameter attributes list based on the new arguments for // the function. NF->setAttributes(AttributeList::get(F->getContext(), PAL.getFnAttrs(), PAL.getRetAttrs(), ArgAttrVec)); + AttributeFuncs::updateMinLegalVectorWidthAttr(*NF, LargestVectorWidth); ArgAttrVec.clear(); F->getParent()->getFunctionList().insert(F->getIterator(), NF); NF->takeName(F); - // Loop over all of the callers of the function, transforming the call sites - // to pass in the loaded pointers. - // + // Loop over all the callers of the function, transforming the call sites to + // pass in the loaded pointers. SmallVector<Value *, 16> Args; const DataLayout &DL = F->getParent()->getDataLayout(); while (!F->use_empty()) { @@ -250,74 +228,34 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, // Loop over the operands, inserting GEP and loads in the caller as // appropriate. - auto AI = CB.arg_begin(); + auto *AI = CB.arg_begin(); ArgNo = 0; for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; - ++I, ++AI, ++ArgNo) - if (!ArgsToPromote.count(&*I) && !ByValArgsToTransform.count(&*I)) { + ++I, ++AI, ++ArgNo) { + if (!ArgsToPromote.count(&*I)) { Args.push_back(*AI); // Unmodified argument ArgAttrVec.push_back(CallPAL.getParamAttrs(ArgNo)); - } else if (ByValArgsToTransform.count(&*I)) { - // Emit a GEP and load for each element of the struct. - Type *AgTy = I->getParamByValType(); - StructType *STy = cast<StructType>(AgTy); - Value *Idxs[2] = { - ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), nullptr}; - const StructLayout *SL = DL.getStructLayout(STy); - Align StructAlign = *I->getParamAlign(); - for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { - Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i); - auto *Idx = - IRB.CreateGEP(STy, *AI, Idxs, (*AI)->getName() + "." + Twine(i)); - // TODO: Tell AA about the new values? - Align Alignment = - commonAlignment(StructAlign, SL->getElementOffset(i)); - Args.push_back(IRB.CreateAlignedLoad( - STy->getElementType(i), Idx, Alignment, Idx->getName() + ".val")); - ArgAttrVec.push_back(AttributeSet()); - } } else if (!I->use_empty()) { - // Non-dead argument: insert GEPs and loads as appropriate. - ScalarizeTable &ArgIndices = ScalarizedElements[&*I]; - // Store the Value* version of the indices in here, but declare it now - // for reuse. - std::vector<Value *> Ops; - for (const auto &ArgIndex : ArgIndices) { - Value *V = *AI; - LoadInst *OrigLoad = - OriginalLoads[std::make_pair(&*I, ArgIndex.second)]; - if (!ArgIndex.second.empty()) { - Ops.reserve(ArgIndex.second.size()); - Type *ElTy = V->getType(); - for (auto II : ArgIndex.second) { - // Use i32 to index structs, and i64 for others (pointers/arrays). - // This satisfies GEP constraints. - Type *IdxTy = - (ElTy->isStructTy() ? Type::getInt32Ty(F->getContext()) - : Type::getInt64Ty(F->getContext())); - Ops.push_back(ConstantInt::get(IdxTy, II)); - // Keep track of the type we're currently indexing. - if (auto *ElPTy = dyn_cast<PointerType>(ElTy)) - ElTy = ElPTy->getPointerElementType(); - else - ElTy = GetElementPtrInst::getTypeAtIndex(ElTy, II); - } - // And create a GEP to extract those indices. 
- V = IRB.CreateGEP(ArgIndex.first, V, Ops, V->getName() + ".idx"); - Ops.clear(); + Value *V = *AI; + const auto &ArgParts = ArgsToPromote.find(&*I)->second; + for (const auto &Pair : ArgParts) { + LoadInst *LI = IRB.CreateAlignedLoad( + Pair.second.Ty, + createByteGEP(IRB, DL, V, Pair.second.Ty, Pair.first), + Pair.second.Alignment, V->getName() + ".val"); + if (Pair.second.MustExecInstr) { + LI->setAAMetadata(Pair.second.MustExecInstr->getAAMetadata()); + LI->copyMetadata(*Pair.second.MustExecInstr, + {LLVMContext::MD_range, LLVMContext::MD_nonnull, + LLVMContext::MD_dereferenceable, + LLVMContext::MD_dereferenceable_or_null, + LLVMContext::MD_align, LLVMContext::MD_noundef}); } - // Since we're replacing a load make sure we take the alignment - // of the previous load. - LoadInst *newLoad = - IRB.CreateLoad(OrigLoad->getType(), V, V->getName() + ".val"); - newLoad->setAlignment(OrigLoad->getAlign()); - // Transfer the AA info too. - newLoad->setAAMetadata(OrigLoad->getAAMetadata()); - - Args.push_back(newLoad); + Args.push_back(LI); ArgAttrVec.push_back(AttributeSet()); } } + } // Push any varargs arguments on the list. for (; AI != CB.arg_end(); ++AI, ++ArgNo) { @@ -345,9 +283,8 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, Args.clear(); ArgAttrVec.clear(); - // Update the callgraph to know that the callsite has been transformed. - if (ReplaceCallSite) - (*ReplaceCallSite)(CB, *NewCS); + AttributeFuncs::updateMinLegalVectorWidthAttr(*CB.getCaller(), + LargestVectorWidth); if (!CB.use_empty()) { CB.replaceAllUsesWith(NewCS); @@ -364,11 +301,15 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, // function empty. NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList()); + // We will collect all the new created allocas to promote them into registers + // after the following loop + SmallVector<AllocaInst *, 4> Allocas; + // Loop over the argument list, transferring uses of the old arguments over to // the new arguments, also transferring over the names as well. Function::arg_iterator I2 = NF->arg_begin(); for (Argument &Arg : F->args()) { - if (!ArgsToPromote.count(&Arg) && !ByValArgsToTransform.count(&Arg)) { + if (!ArgsToPromote.count(&Arg)) { // If this is an unmodified argument, move the name and users over to the // new version. Arg.replaceAllUsesWith(&*I2); @@ -377,37 +318,6 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, continue; } - if (ByValArgsToTransform.count(&Arg)) { - // In the callee, we create an alloca, and store each of the new incoming - // arguments into the alloca. - Instruction *InsertPt = &NF->begin()->front(); - - // Just add all the struct element types. - Type *AgTy = Arg.getParamByValType(); - Align StructAlign = *Arg.getParamAlign(); - Value *TheAlloca = new AllocaInst(AgTy, DL.getAllocaAddrSpace(), nullptr, - StructAlign, "", InsertPt); - StructType *STy = cast<StructType>(AgTy); - Value *Idxs[2] = {ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), - nullptr}; - const StructLayout *SL = DL.getStructLayout(STy); - - for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { - Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i); - Value *Idx = GetElementPtrInst::Create( - AgTy, TheAlloca, Idxs, TheAlloca->getName() + "." + Twine(i), - InsertPt); - I2->setName(Arg.getName() + "." 
+ Twine(i)); - Align Alignment = commonAlignment(StructAlign, SL->getElementOffset(i)); - new StoreInst(&*I2++, Idx, false, Alignment, InsertPt); - } - - // Anything that used the arg should now use the alloca. - Arg.replaceAllUsesWith(TheAlloca); - TheAlloca->takeName(&Arg); - continue; - } - // There potentially are metadata uses for things like llvm.dbg.value. // Replace them with undef, after handling the other regular uses. auto RauwUndefMetadata = make_scope_exit( @@ -416,57 +326,95 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, if (Arg.use_empty()) continue; - // Otherwise, if we promoted this argument, then all users are load - // instructions (or GEPs with only load users), and all loads should be - // using the new argument that we added. - ScalarizeTable &ArgIndices = ScalarizedElements[&Arg]; - - while (!Arg.use_empty()) { - if (LoadInst *LI = dyn_cast<LoadInst>(Arg.user_back())) { - assert(ArgIndices.begin()->second.empty() && - "Load element should sort to front!"); - I2->setName(Arg.getName() + ".val"); - LI->replaceAllUsesWith(&*I2); - LI->eraseFromParent(); - LLVM_DEBUG(dbgs() << "*** Promoted load of argument '" << Arg.getName() - << "' in function '" << F->getName() << "'\n"); - } else { - GetElementPtrInst *GEP = cast<GetElementPtrInst>(Arg.user_back()); - assert(!GEP->use_empty() && - "GEPs without uses should be cleaned up already"); - IndicesVector Operands; - Operands.reserve(GEP->getNumIndices()); - for (const Use &Idx : GEP->indices()) - Operands.push_back(cast<ConstantInt>(Idx)->getSExtValue()); - - // GEPs with a single 0 index can be merged with direct loads - if (Operands.size() == 1 && Operands.front() == 0) - Operands.clear(); - - Function::arg_iterator TheArg = I2; - for (ScalarizeTable::iterator It = ArgIndices.begin(); - It->second != Operands; ++It, ++TheArg) { - assert(It != ArgIndices.end() && "GEP not handled??"); - } + // Otherwise, if we promoted this argument, we have to create an alloca in + // the callee for every promotable part and store each of the new incoming + // arguments into the corresponding alloca, what lets the old code (the + // store instructions if they are allowed especially) a chance to work as + // before. + assert(Arg.getType()->isPointerTy() && + "Only arguments with a pointer type are promotable"); - TheArg->setName(formatv("{0}.{1:$[.]}.val", Arg.getName(), - make_range(Operands.begin(), Operands.end()))); + IRBuilder<NoFolder> IRB(&NF->begin()->front()); - LLVM_DEBUG(dbgs() << "*** Promoted agg argument '" << TheArg->getName() - << "' of function '" << NF->getName() << "'\n"); + // Add only the promoted elements, so parts from ArgsToPromote + SmallDenseMap<int64_t, AllocaInst *> OffsetToAlloca; + for (const auto &Pair : ArgsToPromote.find(&Arg)->second) { + int64_t Offset = Pair.first; + const ArgPart &Part = Pair.second; - // All of the uses must be load instructions. Replace them all with - // the argument specified by ArgNo. - while (!GEP->use_empty()) { - LoadInst *L = cast<LoadInst>(GEP->user_back()); - L->replaceAllUsesWith(&*TheArg); - L->eraseFromParent(); - } - GEP->eraseFromParent(); + Argument *NewArg = I2++; + NewArg->setName(Arg.getName() + "." + Twine(Offset) + ".val"); + + AllocaInst *NewAlloca = IRB.CreateAlloca( + Part.Ty, nullptr, Arg.getName() + "." 
+ Twine(Offset) + ".allc"); + NewAlloca->setAlignment(Pair.second.Alignment); + IRB.CreateAlignedStore(NewArg, NewAlloca, Pair.second.Alignment); + + // Collect the alloca to retarget the users to + OffsetToAlloca.insert({Offset, NewAlloca}); + } + + auto GetAlloca = [&](Value *Ptr) { + APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0); + Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset, + /* AllowNonInbounds */ true); + assert(Ptr == &Arg && "Not constant offset from arg?"); + return OffsetToAlloca.lookup(Offset.getSExtValue()); + }; + + // Cleanup the code from the dead instructions: GEPs and BitCasts in between + // the original argument and its users: loads and stores. Retarget every + // user to the new created alloca. + SmallVector<Value *, 16> Worklist; + SmallVector<Instruction *, 16> DeadInsts; + append_range(Worklist, Arg.users()); + while (!Worklist.empty()) { + Value *V = Worklist.pop_back_val(); + if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V)) { + DeadInsts.push_back(cast<Instruction>(V)); + append_range(Worklist, V->users()); + continue; + } + + if (auto *LI = dyn_cast<LoadInst>(V)) { + Value *Ptr = LI->getPointerOperand(); + LI->setOperand(LoadInst::getPointerOperandIndex(), GetAlloca(Ptr)); + continue; } + + if (auto *SI = dyn_cast<StoreInst>(V)) { + assert(!SI->isVolatile() && "Volatile operations can't be promoted."); + Value *Ptr = SI->getPointerOperand(); + SI->setOperand(StoreInst::getPointerOperandIndex(), GetAlloca(Ptr)); + continue; + } + + llvm_unreachable("Unexpected user"); + } + + for (Instruction *I : DeadInsts) { + I->replaceAllUsesWith(PoisonValue::get(I->getType())); + I->eraseFromParent(); } - // Increment I2 past all of the arguments added for this promoted pointer. - std::advance(I2, ArgIndices.size()); + + // Collect the allocas for promotion + for (const auto &Pair : OffsetToAlloca) { + assert(isAllocaPromotable(Pair.second) && + "By design, only promotable allocas should be produced."); + Allocas.push_back(Pair.second); + } + } + + LLVM_DEBUG(dbgs() << "ARG PROMOTION: " << Allocas.size() + << " alloca(s) are promotable by Mem2Reg\n"); + + if (!Allocas.empty()) { + // And we are able to call the `promoteMemoryToRegister()` function. + // Our earlier checks have ensured that PromoteMemToReg() will + // succeed. + auto &DT = FAM.getResult<DominatorTreeAnalysis>(*NF); + auto &AC = FAM.getResult<AssumptionAnalysis>(*NF); + PromoteMemToReg(Allocas, DT, &AC); } return NF; @@ -474,100 +422,37 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, /// Return true if we can prove that all callees pass in a valid pointer for the /// specified function argument. -static bool allCallersPassValidPointerForArgument(Argument *Arg, Type *Ty) { +static bool allCallersPassValidPointerForArgument(Argument *Arg, + Align NeededAlign, + uint64_t NeededDerefBytes) { Function *Callee = Arg->getParent(); const DataLayout &DL = Callee->getParent()->getDataLayout(); + APInt Bytes(64, NeededDerefBytes); - unsigned ArgNo = Arg->getArgNo(); + // Check if the argument itself is marked dereferenceable and aligned. + if (isDereferenceableAndAlignedPointer(Arg, NeededAlign, Bytes, DL)) + return true; // Look at all call sites of the function. At this point we know we only have // direct callees. 
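// [Editor's note: illustrative example, not part of this patch.] The caller
// check is now phrased in terms of alignment plus dereferenceable bytes
// rather than a type-based dereferenceability test, because the promoted
// loads are emitted unconditionally in every caller. Hypothetical numbers: if
// the only promoted access that is not guaranteed to execute on entry is a
// 4-byte i32 load at byte offset 8 with align 4, findArgParts records
// NeededAlign = 4 and NeededDerefBytes = 8 + 4 = 12, and each call site (or
// the argument itself) must then satisfy
//   isDereferenceableAndAlignedPointer(Ptr, Align(4), APInt(64, 12), DL)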
- for (User *U : Callee->users()) { + return all_of(Callee->users(), [&](User *U) { CallBase &CB = cast<CallBase>(*U); - - if (!isDereferenceablePointer(CB.getArgOperand(ArgNo), Ty, DL)) - return false; - } - return true; + return isDereferenceableAndAlignedPointer(CB.getArgOperand(Arg->getArgNo()), + NeededAlign, Bytes, DL); + }); } -/// Returns true if Prefix is a prefix of longer. That means, Longer has a size -/// that is greater than or equal to the size of prefix, and each of the -/// elements in Prefix is the same as the corresponding elements in Longer. -/// -/// This means it also returns true when Prefix and Longer are equal! -static bool isPrefix(const IndicesVector &Prefix, const IndicesVector &Longer) { - if (Prefix.size() > Longer.size()) - return false; - return std::equal(Prefix.begin(), Prefix.end(), Longer.begin()); -} - -/// Checks if Indices, or a prefix of Indices, is in Set. -static bool prefixIn(const IndicesVector &Indices, - std::set<IndicesVector> &Set) { - std::set<IndicesVector>::iterator Low; - Low = Set.upper_bound(Indices); - if (Low != Set.begin()) - Low--; - // Low is now the last element smaller than or equal to Indices. This means - // it points to a prefix of Indices (possibly Indices itself), if such - // prefix exists. - // - // This load is safe if any prefix of its operands is safe to load. - return Low != Set.end() && isPrefix(*Low, Indices); -} - -/// Mark the given indices (ToMark) as safe in the given set of indices -/// (Safe). Marking safe usually means adding ToMark to Safe. However, if there -/// is already a prefix of Indices in Safe, Indices are implicitely marked safe -/// already. Furthermore, any indices that Indices is itself a prefix of, are -/// removed from Safe (since they are implicitely safe because of Indices now). -static void markIndicesSafe(const IndicesVector &ToMark, - std::set<IndicesVector> &Safe) { - std::set<IndicesVector>::iterator Low; - Low = Safe.upper_bound(ToMark); - // Guard against the case where Safe is empty - if (Low != Safe.begin()) - Low--; - // Low is now the last element smaller than or equal to Indices. This - // means it points to a prefix of Indices (possibly Indices itself), if - // such prefix exists. - if (Low != Safe.end()) { - if (isPrefix(*Low, ToMark)) - // If there is already a prefix of these indices (or exactly these - // indices) marked a safe, don't bother adding these indices - return; - - // Increment Low, so we can use it as a "insert before" hint - ++Low; - } - // Insert - Low = Safe.insert(Low, ToMark); - ++Low; - // If there we're a prefix of longer index list(s), remove those - std::set<IndicesVector>::iterator End = Safe.end(); - while (Low != End && isPrefix(ToMark, *Low)) { - std::set<IndicesVector>::iterator Remove = Low; - ++Low; - Safe.erase(Remove); - } -} - -/// isSafeToPromoteArgument - As you might guess from the name of this method, -/// it checks to see if it is both safe and useful to promote the argument. -/// This method limits promotion of aggregates to only promote up to three -/// elements of the aggregate in order to avoid exploding the number of -/// arguments passed in. -static bool isSafeToPromoteArgument(Argument *Arg, Type *ByValTy, AAResults &AAR, - unsigned MaxElements) { - using GEPIndicesSet = std::set<IndicesVector>; - +/// Determine that this argument is safe to promote, and find the argument +/// parts it can be promoted into. 
+static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR, + unsigned MaxElements, bool IsRecursive, + SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec) { // Quick exit for unused arguments if (Arg->use_empty()) return true; - // We can only promote this argument if all of the uses are loads, or are GEP - // instructions (with constant indices) that are subsequently loaded. + // We can only promote this argument if all the uses are loads at known + // offsets. // // Promoting the argument causes it to be loaded in the caller // unconditionally. This is only safe if we can prove that either the load @@ -578,157 +463,193 @@ static bool isSafeToPromoteArgument(Argument *Arg, Type *ByValTy, AAResults &AAR // anyway, in the latter case, invalid loads won't happen. This prevents us // from introducing an invalid load that wouldn't have happened in the // original code. - // - // This set will contain all sets of indices that are loaded in the entry - // block, and thus are safe to unconditionally load in the caller. - GEPIndicesSet SafeToUnconditionallyLoad; - - // This set contains all the sets of indices that we are planning to promote. - // This makes it possible to limit the number of arguments added. - GEPIndicesSet ToPromote; - - // If the pointer is always valid, any load with first index 0 is valid. - - if (ByValTy) - SafeToUnconditionallyLoad.insert(IndicesVector(1, 0)); - - // Whenever a new underlying type for the operand is found, make sure it's - // consistent with the GEPs and loads we've already seen and, if necessary, - // use it to see if all incoming pointers are valid (which implies the 0-index - // is safe). - Type *BaseTy = ByValTy; - auto UpdateBaseTy = [&](Type *NewBaseTy) { - if (BaseTy) - return BaseTy == NewBaseTy; - - BaseTy = NewBaseTy; - if (allCallersPassValidPointerForArgument(Arg, BaseTy)) { - assert(SafeToUnconditionallyLoad.empty()); - SafeToUnconditionallyLoad.insert(IndicesVector(1, 0)); + + SmallDenseMap<int64_t, ArgPart, 4> ArgParts; + Align NeededAlign(1); + uint64_t NeededDerefBytes = 0; + + // And if this is a byval argument we also allow to have store instructions. + // Only handle in such way arguments with specified alignment; + // if it's unspecified, the actual alignment of the argument is + // target-specific. + bool AreStoresAllowed = Arg->getParamByValType() && Arg->getParamAlign(); + + // An end user of a pointer argument is a load or store instruction. + // Returns None if this load or store is not based on the argument. Return + // true if we can promote the instruction, false otherwise. + auto HandleEndUser = [&](auto *I, Type *Ty, + bool GuaranteedToExecute) -> Optional<bool> { + // Don't promote volatile or atomic instructions. + if (!I->isSimple()) + return false; + + Value *Ptr = I->getPointerOperand(); + APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0); + Ptr = Ptr->stripAndAccumulateConstantOffsets(DL, Offset, + /* AllowNonInbounds */ true); + if (Ptr != Arg) + return None; + + if (Offset.getSignificantBits() >= 64) + return false; + + TypeSize Size = DL.getTypeStoreSize(Ty); + // Don't try to promote scalable types. + if (Size.isScalable()) + return false; + + // If this is a recursive function and one of the types is a pointer, + // then promoting it might lead to recursive promotion. + if (IsRecursive && Ty->isPointerTy()) + return false; + + int64_t Off = Offset.getSExtValue(); + auto Pair = ArgParts.try_emplace( + Off, ArgPart{Ty, I->getAlign(), GuaranteedToExecute ? 
I : nullptr}); + ArgPart &Part = Pair.first->second; + bool OffsetNotSeenBefore = Pair.second; + + // We limit promotion to only promoting up to a fixed number of elements of + // the aggregate. + if (MaxElements > 0 && ArgParts.size() > MaxElements) { + LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: " + << "more than " << MaxElements << " parts\n"); + return false; } - return true; - }; + // For now, we only support loading/storing one specific type at a given + // offset. + if (Part.Ty != Ty) { + LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: " + << "accessed as both " << *Part.Ty << " and " << *Ty + << " at offset " << Off << "\n"); + return false; + } - // First, iterate functions that are guaranteed to execution on function - // entry and mark loads of (geps of) arguments as safe. - BasicBlock &EntryBlock = Arg->getParent()->front(); - // Declare this here so we can reuse it - IndicesVector Indices; - for (Instruction &I : EntryBlock) { - if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { - Value *V = LI->getPointerOperand(); - if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) { - V = GEP->getPointerOperand(); - if (V == Arg) { - // This load actually loads (part of) Arg? Check the indices then. - Indices.reserve(GEP->getNumIndices()); - for (Use &Idx : GEP->indices()) - if (ConstantInt *CI = dyn_cast<ConstantInt>(Idx)) - Indices.push_back(CI->getSExtValue()); - else - // We found a non-constant GEP index for this argument? Bail out - // right away, can't promote this argument at all. - return false; - - if (!UpdateBaseTy(GEP->getSourceElementType())) - return false; - - // Indices checked out, mark them as safe - markIndicesSafe(Indices, SafeToUnconditionallyLoad); - Indices.clear(); - } - } else if (V == Arg) { - // Direct loads are equivalent to a GEP with a single 0 index. - markIndicesSafe(IndicesVector(1, 0), SafeToUnconditionallyLoad); + // If this instruction is not guaranteed to execute, and we haven't seen a + // load or store at this offset before (or it had lower alignment), then we + // need to remember that requirement. + // Note that skipping instructions of previously seen offsets is only + // correct because we only allow a single type for a given offset, which + // also means that the number of accessed bytes will be the same. + if (!GuaranteedToExecute && + (OffsetNotSeenBefore || Part.Alignment < I->getAlign())) { + // We won't be able to prove dereferenceability for negative offsets. + if (Off < 0) + return false; - if (BaseTy && LI->getType() != BaseTy) - return false; + // If the offset is not aligned, an aligned base pointer won't help. + if (!isAligned(I->getAlign(), Off)) + return false; - BaseTy = LI->getType(); - } + NeededDerefBytes = std::max(NeededDerefBytes, Off + Size.getFixedValue()); + NeededAlign = std::max(NeededAlign, I->getAlign()); } + Part.Alignment = std::max(Part.Alignment, I->getAlign()); + return true; + }; + + // Look for loads and stores that are guaranteed to execute on entry. 
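The per-offset bookkeeping above boils down to a few rules: key parts by their constant byte offset, refuse a second type at the same offset, cap the total number of parts, and keep the strongest alignment seen. A standalone sketch of those rules, with hypothetical names and std::map/std::string standing in for the SmallDenseMap and llvm::Type used here:

#include <algorithm>
#include <cstdint>
#include <map>
#include <string>

// Simplified stand-in for an ArgPart: the accessed type and its alignment.
struct Part {
  std::string Ty;
  uint64_t Align = 1;
};

static bool recordAccess(std::map<int64_t, Part> &Parts, int64_t Off,
                         const std::string &Ty, uint64_t Align,
                         unsigned MaxElements) {
  auto It = Parts.try_emplace(Off, Part{Ty, Align}).first;
  if (MaxElements > 0 && Parts.size() > MaxElements)
    return false;               // too many distinct parts, give up
  if (It->second.Ty != Ty)
    return false;               // same offset accessed as two different types
  It->second.Align = std::max(It->second.Align, Align);
  return true;
}

int main() {
  std::map<int64_t, Part> Parts;
  bool Ok = recordAccess(Parts, 0, "i32", 4, 3) &&
            recordAccess(Parts, 4, "float", 4, 3) &&
            recordAccess(Parts, 0, "i32", 8, 3);    // same offset, same type
  bool Clash = !recordAccess(Parts, 4, "i64", 8, 3); // type mismatch rejected
  return (Ok && Clash) ? 0 : 1;
}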
+ for (Instruction &I : Arg->getParent()->getEntryBlock()) { + Optional<bool> Res{}; + if (LoadInst *LI = dyn_cast<LoadInst>(&I)) + Res = HandleEndUser(LI, LI->getType(), /* GuaranteedToExecute */ true); + else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) + Res = HandleEndUser(SI, SI->getValueOperand()->getType(), + /* GuaranteedToExecute */ true); + if (Res && !*Res) + return false; + if (!isGuaranteedToTransferExecutionToSuccessor(&I)) break; } - // Now, iterate all uses of the argument to see if there are any uses that are - // not (GEP+)loads, or any (GEP+)loads that are not safe to promote. + // Now look at all loads of the argument. Remember the load instructions + // for the aliasing check below. + SmallVector<const Use *, 16> Worklist; + SmallPtrSet<const Use *, 16> Visited; SmallVector<LoadInst *, 16> Loads; - IndicesVector Operands; - for (Use &U : Arg->uses()) { - User *UR = U.getUser(); - Operands.clear(); - if (LoadInst *LI = dyn_cast<LoadInst>(UR)) { - // Don't hack volatile/atomic loads - if (!LI->isSimple()) - return false; - Loads.push_back(LI); - // Direct loads are equivalent to a GEP with a zero index and then a load. - Operands.push_back(0); + auto AppendUses = [&](const Value *V) { + for (const Use &U : V->uses()) + if (Visited.insert(&U).second) + Worklist.push_back(&U); + }; + AppendUses(Arg); + while (!Worklist.empty()) { + const Use *U = Worklist.pop_back_val(); + Value *V = U->getUser(); + if (isa<BitCastInst>(V)) { + AppendUses(V); + continue; + } - if (!UpdateBaseTy(LI->getType())) + if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) { + if (!GEP->hasAllConstantIndices()) return false; - } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(UR)) { - if (GEP->use_empty()) { - // Dead GEP's cause trouble later. Just remove them if we run into - // them. - continue; - } + AppendUses(V); + continue; + } - if (!UpdateBaseTy(GEP->getSourceElementType())) + if (auto *LI = dyn_cast<LoadInst>(V)) { + if (!*HandleEndUser(LI, LI->getType(), /* GuaranteedToExecute */ false)) return false; + Loads.push_back(LI); + continue; + } - // Ensure that all of the indices are constants. - for (Use &Idx : GEP->indices()) - if (ConstantInt *C = dyn_cast<ConstantInt>(Idx)) - Operands.push_back(C->getSExtValue()); - else - return false; // Not a constant operand GEP! - - // Ensure that the only users of the GEP are load instructions. - for (User *GEPU : GEP->users()) - if (LoadInst *LI = dyn_cast<LoadInst>(GEPU)) { - // Don't hack volatile/atomic loads - if (!LI->isSimple()) - return false; - Loads.push_back(LI); - } else { - // Other uses than load? - return false; - } - } else { - return false; // Not a load or a GEP. + // Stores are allowed for byval arguments + auto *SI = dyn_cast<StoreInst>(V); + if (AreStoresAllowed && SI && + U->getOperandNo() == StoreInst::getPointerOperandIndex()) { + if (!*HandleEndUser(SI, SI->getValueOperand()->getType(), + /* GuaranteedToExecute */ false)) + return false; + continue; + // Only stores TO the argument is allowed, all the other stores are + // unknown users } - // Now, see if it is safe to promote this load / loads of this GEP. Loading - // is safe if Operands, or a prefix of Operands, is marked as safe. - if (!prefixIn(Operands, SafeToUnconditionallyLoad)) - return false; + // Unknown user. + LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: " + << "unknown user " << *V << "\n"); + return false; + } - // See if we are already promoting a load with these indices. 
If not, check - // to make sure that we aren't promoting too many elements. If so, nothing - // to do. - if (ToPromote.find(Operands) == ToPromote.end()) { - if (MaxElements > 0 && ToPromote.size() == MaxElements) { - LLVM_DEBUG(dbgs() << "argpromotion not promoting argument '" - << Arg->getName() - << "' because it would require adding more " - << "than " << MaxElements - << " arguments to the function.\n"); - // We limit aggregate promotion to only promoting up to a fixed number - // of elements of the aggregate. - return false; - } - ToPromote.insert(std::move(Operands)); + if (NeededDerefBytes || NeededAlign > 1) { + // Try to prove a required deref / aligned requirement. + if (!allCallersPassValidPointerForArgument(Arg, NeededAlign, + NeededDerefBytes)) { + LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: " + << "not dereferenceable or aligned\n"); + return false; } } - if (Loads.empty()) + if (ArgParts.empty()) return true; // No users, this is a dead argument. - // Okay, now we know that the argument is only used by load instructions and + // Sort parts by offset. + append_range(ArgPartsVec, ArgParts); + sort(ArgPartsVec, + [](const auto &A, const auto &B) { return A.first < B.first; }); + + // Make sure the parts are non-overlapping. + int64_t Offset = ArgPartsVec[0].first; + for (const auto &Pair : ArgPartsVec) { + if (Pair.first < Offset) + return false; // Overlap with previous part. + + Offset = Pair.first + DL.getTypeStoreSize(Pair.second.Ty); + } + + // If store instructions are allowed, the path from the entry of the function + // to each load may be not free of instructions that potentially invalidate + // the load, and this is an admissible situation. + if (AreStoresAllowed) + return true; + + // Okay, now we know that the argument is only used by load instructions, and // it is safe to unconditionally perform all of them. Use alias analysis to // check to see if the pointer is guaranteed to not be modified from entry of // the function to each of the load instructions. @@ -762,118 +683,31 @@ static bool isSafeToPromoteArgument(Argument *Arg, Type *ByValTy, AAResults &AAR return true; } -bool ArgumentPromotionPass::isDenselyPacked(Type *type, const DataLayout &DL) { - // There is no size information, so be conservative. - if (!type->isSized()) - return false; - - // If the alloc size is not equal to the storage size, then there are padding - // bytes. For x86_fp80 on x86-64, size: 80 alloc size: 128. - if (DL.getTypeSizeInBits(type) != DL.getTypeAllocSizeInBits(type)) - return false; - - // FIXME: This isn't the right way to check for padding in vectors with - // non-byte-size elements. - if (VectorType *seqTy = dyn_cast<VectorType>(type)) - return isDenselyPacked(seqTy->getElementType(), DL); - - // For array types, check for padding within members. - if (ArrayType *seqTy = dyn_cast<ArrayType>(type)) - return isDenselyPacked(seqTy->getElementType(), DL); - - if (!isa<StructType>(type)) - return true; - - // Check for padding within and between elements of a struct. 
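The overlap check above is a single sweep over the parts once they are sorted by offset: each part has to begin at or after the point where the previous one ends. A minimal sketch using (offset, store-size) pairs, with plain integers in place of the DataLayout-derived type sizes:

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Parts are (offset, size) pairs; after sorting by offset, every part must
// start at or after the end of the previous part.
static bool partsAreDisjoint(std::vector<std::pair<int64_t, uint64_t>> Parts) {
  std::sort(Parts.begin(), Parts.end());
  int64_t End = Parts.empty() ? 0 : Parts.front().first;
  for (const auto &[Off, Size] : Parts) {
    if (Off < End)
      return false;                          // overlaps the previous part
    End = Off + static_cast<int64_t>(Size);
  }
  return true;
}

int main() {
  // {0,4} and {4,8} are disjoint; adding {6,2} would overlap the second part.
  bool Ok = partsAreDisjoint({{0, 4}, {4, 8}});
  bool Bad = !partsAreDisjoint({{0, 4}, {4, 8}, {6, 2}});
  return (Ok && Bad) ? 0 : 1;
}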
- StructType *StructTy = cast<StructType>(type); - const StructLayout *Layout = DL.getStructLayout(StructTy); - uint64_t StartPos = 0; - for (unsigned i = 0, E = StructTy->getNumElements(); i < E; ++i) { - Type *ElTy = StructTy->getElementType(i); - if (!isDenselyPacked(ElTy, DL)) - return false; - if (StartPos != Layout->getElementOffsetInBits(i)) - return false; - StartPos += DL.getTypeAllocSizeInBits(ElTy); - } - - return true; -} - -/// Checks if the padding bytes of an argument could be accessed. -static bool canPaddingBeAccessed(Argument *arg) { - assert(arg->hasByValAttr()); - - // Track all the pointers to the argument to make sure they are not captured. - SmallPtrSet<Value *, 16> PtrValues; - PtrValues.insert(arg); - - // Track all of the stores. - SmallVector<StoreInst *, 16> Stores; - - // Scan through the uses recursively to make sure the pointer is always used - // sanely. - SmallVector<Value *, 16> WorkList(arg->users()); - while (!WorkList.empty()) { - Value *V = WorkList.pop_back_val(); - if (isa<GetElementPtrInst>(V) || isa<PHINode>(V)) { - if (PtrValues.insert(V).second) - llvm::append_range(WorkList, V->users()); - } else if (StoreInst *Store = dyn_cast<StoreInst>(V)) { - Stores.push_back(Store); - } else if (!isa<LoadInst>(V)) { - return true; - } - } - - // Check to make sure the pointers aren't captured - for (StoreInst *Store : Stores) - if (PtrValues.count(Store->getValueOperand())) - return true; - - return false; -} - -/// Check if callers and the callee \p F agree how promoted arguments would be -/// passed. The ones that they do not agree on are eliminated from the sets but -/// the return value has to be observed as well. -static bool areFunctionArgsABICompatible( - const Function &F, const TargetTransformInfo &TTI, - SmallPtrSetImpl<Argument *> &ArgsToPromote, - SmallPtrSetImpl<Argument *> &ByValArgsToTransform) { - // TODO: Check individual arguments so we can promote a subset? - SmallVector<Type *, 32> Types; - for (Argument *Arg : ArgsToPromote) - Types.push_back(Arg->getType()->getPointerElementType()); - for (Argument *Arg : ByValArgsToTransform) - Types.push_back(Arg->getParamByValType()); - - for (const Use &U : F.uses()) { +/// Check if callers and callee agree on how promoted arguments would be +/// passed. +static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F, + const TargetTransformInfo &TTI) { + return all_of(F.uses(), [&](const Use &U) { CallBase *CB = dyn_cast<CallBase>(U.getUser()); if (!CB) return false; + const Function *Caller = CB->getCaller(); const Function *Callee = CB->getCalledFunction(); - if (!TTI.areTypesABICompatible(Caller, Callee, Types)) - return false; - } - return true; + return TTI.areTypesABICompatible(Caller, Callee, Types); + }); } /// PromoteArguments - This method checks the specified function to see if there /// are any promotable arguments and if it is safe to promote the function (for /// example, all callers are direct). If safe to promote some arguments, it /// calls the DoPromotion method. 
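The ABI-compatibility gate above reduces to an all-of over the function's uses: every use must be a direct call whose caller and callee agree on how the promoted types would be passed (for example, matching vector target features). A deliberately simplified sketch of that predicate, with hypothetical flags in place of the TTI query:

#include <algorithm>
#include <vector>

// Simplified model of one use of the callee: is it a direct call, and do
// caller and callee agree on the ABI for the promoted types?
struct CallSite {
  bool IsDirectCall;
  bool ABICompatible;
};

static bool allCallSitesCompatible(const std::vector<CallSite> &Uses) {
  return std::all_of(Uses.begin(), Uses.end(), [](const CallSite &CS) {
    return CS.IsDirectCall && CS.ABICompatible;
  });
}

int main() {
  bool Ok = allCallSitesCompatible({{true, true}, {true, true}});
  bool Bad = !allCallSitesCompatible({{true, true}, {false, true}});
  return (Ok && Bad) ? 0 : 1;
}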
-static Function * -promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter, - unsigned MaxElements, - Optional<function_ref<void(CallBase &OldCS, CallBase &NewCS)>> - ReplaceCallSite, - const TargetTransformInfo &TTI) { +static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM, + unsigned MaxElements, bool IsRecursive) { // Don't perform argument promotion for naked functions; otherwise we can end // up removing parameters that are seemingly 'not used' as they are referred // to in the assembly. - if(F->hasFnAttribute(Attribute::Naked)) + if (F->hasFnAttribute(Attribute::Naked)) return nullptr; // Make sure that it is local to this module. @@ -903,20 +737,20 @@ promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter, // Second check: make sure that all callers are direct callers. We can't // transform functions that have indirect callers. Also see if the function - // is self-recursive and check that target features are compatible. - bool isSelfRecursive = false; + // is self-recursive. for (Use &U : F->uses()) { CallBase *CB = dyn_cast<CallBase>(U.getUser()); // Must be a direct call. - if (CB == nullptr || !CB->isCallee(&U)) + if (CB == nullptr || !CB->isCallee(&U) || + CB->getFunctionType() != F->getFunctionType()) return nullptr; // Can't change signature of musttail callee if (CB->isMustTailCall()) return nullptr; - if (CB->getParent()->getParent() == F) - isSelfRecursive = true; + if (CB->getFunction() == F) + IsRecursive = true; } // Can't change signature of musttail caller @@ -926,16 +760,13 @@ promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter, return nullptr; const DataLayout &DL = F->getParent()->getDataLayout(); - - AAResults &AAR = AARGetter(*F); + auto &AAR = FAM.getResult<AAManager>(*F); + const auto &TTI = FAM.getResult<TargetIRAnalysis>(*F); // Check to see which arguments are promotable. If an argument is promotable, // add it to ArgsToPromote. - SmallPtrSet<Argument *, 8> ArgsToPromote; - SmallPtrSet<Argument *, 8> ByValArgsToTransform; + DenseMap<Argument *, SmallVector<OffsetAndArgPart, 4>> ArgsToPromote; for (Argument *PtrArg : PointerArgs) { - Type *AgTy = PtrArg->getType()->getPointerElementType(); - // Replace sret attribute with noalias. This reduces register pressure by // avoiding a register copy. if (PtrArg->hasStructRetAttr()) { @@ -949,72 +780,25 @@ promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter, } } - // If this is a byval argument, and if the aggregate type is small, just - // pass the elements, which is always safe, if the passed value is densely - // packed or if we can prove the padding bytes are never accessed. - // - // Only handle arguments with specified alignment; if it's unspecified, the - // actual alignment of the argument is target-specific. - bool isSafeToPromote = PtrArg->hasByValAttr() && PtrArg->getParamAlign() && - (ArgumentPromotionPass::isDenselyPacked(AgTy, DL) || - !canPaddingBeAccessed(PtrArg)); - if (isSafeToPromote) { - if (StructType *STy = dyn_cast<StructType>(AgTy)) { - if (MaxElements > 0 && STy->getNumElements() > MaxElements) { - LLVM_DEBUG(dbgs() << "argpromotion disable promoting argument '" - << PtrArg->getName() - << "' because it would require adding more" - << " than " << MaxElements - << " arguments to the function.\n"); - continue; - } - - // If all the elements are single-value types, we can promote it. 
- bool AllSimple = true; - for (const auto *EltTy : STy->elements()) { - if (!EltTy->isSingleValueType()) { - AllSimple = false; - break; - } - } + // If we can promote the pointer to its value. + SmallVector<OffsetAndArgPart, 4> ArgParts; - // Safe to transform, don't even bother trying to "promote" it. - // Passing the elements as a scalar will allow sroa to hack on - // the new alloca we introduce. - if (AllSimple) { - ByValArgsToTransform.insert(PtrArg); - continue; - } - } - } + if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts)) { + SmallVector<Type *, 4> Types; + for (const auto &Pair : ArgParts) + Types.push_back(Pair.second.Ty); - // If the argument is a recursive type and we're in a recursive - // function, we could end up infinitely peeling the function argument. - if (isSelfRecursive) { - if (StructType *STy = dyn_cast<StructType>(AgTy)) { - bool RecursiveType = - llvm::is_contained(STy->elements(), PtrArg->getType()); - if (RecursiveType) - continue; + if (areTypesABICompatible(Types, *F, TTI)) { + ArgsToPromote.insert({PtrArg, std::move(ArgParts)}); } } - - // Otherwise, see if we can promote the pointer to its value. - Type *ByValTy = - PtrArg->hasByValAttr() ? PtrArg->getParamByValType() : nullptr; - if (isSafeToPromoteArgument(PtrArg, ByValTy, AAR, MaxElements)) - ArgsToPromote.insert(PtrArg); } // No promotable pointer arguments. - if (ArgsToPromote.empty() && ByValArgsToTransform.empty()) + if (ArgsToPromote.empty()) return nullptr; - if (!areFunctionArgsABICompatible( - *F, TTI, ArgsToPromote, ByValArgsToTransform)) - return nullptr; - - return doPromotion(F, ArgsToPromote, ByValArgsToTransform, ReplaceCallSite); + return doPromotion(F, FAM, ArgsToPromote); } PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C, @@ -1030,19 +814,10 @@ PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C, FunctionAnalysisManager &FAM = AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); + bool IsRecursive = C.size() > 1; for (LazyCallGraph::Node &N : C) { Function &OldF = N.getFunction(); - - // FIXME: This lambda must only be used with this function. We should - // skip the lambda and just get the AA results directly. - auto AARGetter = [&](Function &F) -> AAResults & { - assert(&F == &OldF && "Called with an unexpected function!"); - return FAM.getResult<AAManager>(F); - }; - - const TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(OldF); - Function *NewF = - promoteArguments(&OldF, AARGetter, MaxElements, None, TTI); + Function *NewF = promoteArguments(&OldF, FAM, MaxElements, IsRecursive); if (!NewF) continue; LocalChange = true; @@ -1077,111 +852,3 @@ PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C, PA.preserveSet<AllAnalysesOn<Function>>(); return PA; } - -namespace { - -/// ArgPromotion - The 'by reference' to 'by value' argument promotion pass. 
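For readers coming to this pass cold, the 'by reference' to 'by value' rewrite is easiest to see at the source level. A hypothetical C++ analogue of what the IR-level transform achieves (the real pass operates on per-offset parts of the pointee, not on named fields):

#include <cstdio>

struct S { int A; float B; };

// Before promotion: the callee takes the struct by reference and only loads
// two fields from it.
static int Callee(const S *P) { return P->A + static_cast<int>(P->B); }
static int Caller() { S Val{1, 2.0f}; return Callee(&Val); }

// After promotion (conceptually): the loads move to the call site and the
// loaded parts are passed by value, so the pointer argument disappears.
static int CalleePromoted(int A, float B) { return A + static_cast<int>(B); }
static int CallerPromoted() { S Val{1, 2.0f}; return CalleePromoted(Val.A, Val.B); }

int main() {
  std::printf("%d %d\n", Caller(), CallerPromoted()); // both print 3
  return 0;
}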
-struct ArgPromotion : public CallGraphSCCPass { - // Pass identification, replacement for typeid - static char ID; - - explicit ArgPromotion(unsigned MaxElements = 3) - : CallGraphSCCPass(ID), MaxElements(MaxElements) { - initializeArgPromotionPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - getAAResultsAnalysisUsage(AU); - CallGraphSCCPass::getAnalysisUsage(AU); - } - - bool runOnSCC(CallGraphSCC &SCC) override; - -private: - using llvm::Pass::doInitialization; - - bool doInitialization(CallGraph &CG) override; - - /// The maximum number of elements to expand, or 0 for unlimited. - unsigned MaxElements; -}; - -} // end anonymous namespace - -char ArgPromotion::ID = 0; - -INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion", - "Promote 'by reference' arguments to scalars", false, - false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(ArgPromotion, "argpromotion", - "Promote 'by reference' arguments to scalars", false, false) - -Pass *llvm::createArgumentPromotionPass(unsigned MaxElements) { - return new ArgPromotion(MaxElements); -} - -bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) { - if (skipSCC(SCC)) - return false; - - // Get the callgraph information that we need to update to reflect our - // changes. - CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); - - LegacyAARGetter AARGetter(*this); - - bool Changed = false, LocalChange; - - // Iterate until we stop promoting from this SCC. - do { - LocalChange = false; - // Attempt to promote arguments from all functions in this SCC. - for (CallGraphNode *OldNode : SCC) { - Function *OldF = OldNode->getFunction(); - if (!OldF) - continue; - - auto ReplaceCallSite = [&](CallBase &OldCS, CallBase &NewCS) { - Function *Caller = OldCS.getParent()->getParent(); - CallGraphNode *NewCalleeNode = - CG.getOrInsertFunction(NewCS.getCalledFunction()); - CallGraphNode *CallerNode = CG[Caller]; - CallerNode->replaceCallEdge(cast<CallBase>(OldCS), - cast<CallBase>(NewCS), NewCalleeNode); - }; - - const TargetTransformInfo &TTI = - getAnalysis<TargetTransformInfoWrapperPass>().getTTI(*OldF); - if (Function *NewF = promoteArguments(OldF, AARGetter, MaxElements, - {ReplaceCallSite}, TTI)) { - LocalChange = true; - - // Update the call graph for the newly promoted function. - CallGraphNode *NewNode = CG.getOrInsertFunction(NewF); - NewNode->stealCalledFunctionsFrom(OldNode); - if (OldNode->getNumReferences() == 0) - delete CG.removeFunctionFromModule(OldNode); - else - OldF->setLinkage(Function::ExternalLinkage); - - // And updat ethe SCC we're iterating as well. - SCC.ReplaceNode(OldNode, NewNode); - } - } - // Remember that we changed something. 
- Changed |= LocalChange; - } while (LocalChange); - - return Changed; -} - -bool ArgPromotion::doInitialization(CallGraph &CG) { - return CallGraphSCCPass::doInitialization(CG); -} diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp index 7bca2084c448..b05b7990e3f0 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp @@ -15,29 +15,25 @@ #include "llvm/Transforms/IPO/Attributor.h" -#include "llvm/ADT/GraphTraits.h" #include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/TinyPtrVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/InlineCost.h" -#include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" -#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/MustExecute.h" -#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/NoFolder.h" #include "llvm/IR/ValueHandle.h" -#include "llvm/IR/Verifier.h" #include "llvm/InitializePasses.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" @@ -50,6 +46,10 @@ #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" +#ifdef EXPENSIVE_CHECKS +#include "llvm/IR/Verifier.h" +#endif + #include <cassert> #include <string> @@ -123,13 +123,13 @@ static cl::list<std::string> SeedAllowList("attributor-seed-allow-list", cl::Hidden, cl::desc("Comma seperated list of attribute names that are " "allowed to be seeded."), - cl::ZeroOrMore, cl::CommaSeparated); + cl::CommaSeparated); static cl::list<std::string> FunctionSeedAllowList( "attributor-function-seed-allow-list", cl::Hidden, cl::desc("Comma seperated list of function names that are " "allowed to be seeded."), - cl::ZeroOrMore, cl::CommaSeparated); + cl::CommaSeparated); #endif static cl::opt<bool> @@ -209,33 +209,25 @@ bool AA::isNoSyncInst(Attributor &A, const Instruction &I, } bool AA::isDynamicallyUnique(Attributor &A, const AbstractAttribute &QueryingAA, - const Value &V) { - if (auto *C = dyn_cast<Constant>(&V)) - return !C->isThreadDependent(); - // TODO: Inspect and cache more complex instructions. - if (auto *CB = dyn_cast<CallBase>(&V)) - return CB->getNumOperands() == 0 && !CB->mayHaveSideEffects() && - !CB->mayReadFromMemory(); - const Function *Scope = nullptr; - if (auto *I = dyn_cast<Instruction>(&V)) - Scope = I->getFunction(); - if (auto *A = dyn_cast<Argument>(&V)) - Scope = A->getParent(); - if (!Scope) + const Value &V, bool ForAnalysisOnly) { + // TODO: See the AAInstanceInfo class comment. 
+ if (!ForAnalysisOnly) return false; - auto &NoRecurseAA = A.getAAFor<AANoRecurse>( - QueryingAA, IRPosition::function(*Scope), DepClassTy::OPTIONAL); - return NoRecurseAA.isAssumedNoRecurse(); + auto &InstanceInfoAA = A.getAAFor<AAInstanceInfo>( + QueryingAA, IRPosition::value(V), DepClassTy::OPTIONAL); + return InstanceInfoAA.isAssumedUniqueForAnalysis(); } Constant *AA::getInitialValueForObj(Value &Obj, Type &Ty, const TargetLibraryInfo *TLI) { if (isa<AllocaInst>(Obj)) return UndefValue::get(&Ty); - if (isAllocationFn(&Obj, TLI)) - return getInitialValueOfAllocation(&cast<CallBase>(Obj), TLI, &Ty); + if (Constant *Init = getInitialValueOfAllocation(&Obj, TLI, &Ty)) + return Init; auto *GV = dyn_cast<GlobalVariable>(&Obj); - if (!GV || !GV->hasLocalLinkage()) + if (!GV) + return nullptr; + if (!GV->hasLocalLinkage() && !(GV->isConstant() && GV->hasInitializer())) return nullptr; if (!GV->hasInitializer()) return UndefValue::get(&Ty); @@ -252,19 +244,29 @@ bool AA::isValidInScope(const Value &V, const Function *Scope) { return false; } -bool AA::isValidAtPosition(const Value &V, const Instruction &CtxI, +bool AA::isValidAtPosition(const AA::ValueAndContext &VAC, InformationCache &InfoCache) { - if (isa<Constant>(V)) + if (isa<Constant>(VAC.getValue()) || VAC.getValue() == VAC.getCtxI()) return true; - const Function *Scope = CtxI.getFunction(); - if (auto *A = dyn_cast<Argument>(&V)) + const Function *Scope = nullptr; + const Instruction *CtxI = VAC.getCtxI(); + if (CtxI) + Scope = CtxI->getFunction(); + if (auto *A = dyn_cast<Argument>(VAC.getValue())) return A->getParent() == Scope; - if (auto *I = dyn_cast<Instruction>(&V)) + if (auto *I = dyn_cast<Instruction>(VAC.getValue())) { if (I->getFunction() == Scope) { - const DominatorTree *DT = - InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(*Scope); - return DT && DT->dominates(I, &CtxI); + if (const DominatorTree *DT = + InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>( + *Scope)) + return DT->dominates(I, CtxI); + // Local dominance check mostly for the old PM passes. + if (CtxI && I->getParent() == CtxI->getParent()) + return llvm::any_of( + make_range(I->getIterator(), I->getParent()->end()), + [&](const Instruction &AfterI) { return &AfterI == CtxI; }); } + } return false; } @@ -295,11 +297,11 @@ AA::combineOptionalValuesInAAValueLatice(const Optional<Value *> &A, const Optional<Value *> &B, Type *Ty) { if (A == B) return A; - if (!B.hasValue()) + if (!B) return A; if (*B == nullptr) return nullptr; - if (!A.hasValue()) + if (!A) return Ty ? 
getWithType(**B, *Ty) : nullptr; if (*A == nullptr) return nullptr; @@ -314,22 +316,33 @@ AA::combineOptionalValuesInAAValueLatice(const Optional<Value *> &A, return nullptr; } -bool AA::getPotentialCopiesOfStoredValue( - Attributor &A, StoreInst &SI, SmallSetVector<Value *, 4> &PotentialCopies, - const AbstractAttribute &QueryingAA, bool &UsedAssumedInformation) { +template <bool IsLoad, typename Ty> +static bool getPotentialCopiesOfMemoryValue( + Attributor &A, Ty &I, SmallSetVector<Value *, 4> &PotentialCopies, + SmallSetVector<Instruction *, 4> &PotentialValueOrigins, + const AbstractAttribute &QueryingAA, bool &UsedAssumedInformation, + bool OnlyExact) { + LLVM_DEBUG(dbgs() << "Trying to determine the potential copies of " << I + << " (only exact: " << OnlyExact << ")\n";); - Value &Ptr = *SI.getPointerOperand(); + Value &Ptr = *I.getPointerOperand(); SmallVector<Value *, 8> Objects; - if (!AA::getAssumedUnderlyingObjects(A, Ptr, Objects, QueryingAA, &SI, + if (!AA::getAssumedUnderlyingObjects(A, Ptr, Objects, QueryingAA, &I, UsedAssumedInformation)) { LLVM_DEBUG( dbgs() << "Underlying objects stored into could not be determined\n";); return false; } + // Containers to remember the pointer infos and new copies while we are not + // sure that we can find all of them. If we abort we want to avoid spurious + // dependences and potential copies in the provided container. SmallVector<const AAPointerInfo *> PIs; SmallVector<Value *> NewCopies; + SmallVector<Instruction *> NewCopyOrigins; + const auto *TLI = + A.getInfoCache().getTargetLibraryInfoForFunction(*I.getFunction()); for (Value *Obj : Objects) { LLVM_DEBUG(dbgs() << "Visit underlying object " << *Obj << "\n"); if (isa<UndefValue>(Obj)) @@ -337,7 +350,7 @@ bool AA::getPotentialCopiesOfStoredValue( if (isa<ConstantPointerNull>(Obj)) { // A null pointer access can be undefined but any offset from null may // be OK. We do not try to optimize the latter. - if (!NullPointerIsDefined(SI.getFunction(), + if (!NullPointerIsDefined(I.getFunction(), Ptr.getType()->getPointerAddressSpace()) && A.getAssumedSimplified(Ptr, QueryingAA, UsedAssumedInformation) == Obj) @@ -346,37 +359,74 @@ bool AA::getPotentialCopiesOfStoredValue( dbgs() << "Underlying object is a valid nullptr, giving up.\n";); return false; } + // TODO: Use assumed noalias return. if (!isa<AllocaInst>(Obj) && !isa<GlobalVariable>(Obj) && - !isNoAliasCall(Obj)) { + !(IsLoad ? 
isAllocationFn(Obj, TLI) : isNoAliasCall(Obj))) { LLVM_DEBUG(dbgs() << "Underlying object is not supported yet: " << *Obj << "\n";); return false; } if (auto *GV = dyn_cast<GlobalVariable>(Obj)) - if (!GV->hasLocalLinkage()) { + if (!GV->hasLocalLinkage() && + !(GV->isConstant() && GV->hasInitializer())) { LLVM_DEBUG(dbgs() << "Underlying object is global with external " "linkage, not supported yet: " << *Obj << "\n";); return false; } + if (IsLoad) { + Value *InitialValue = AA::getInitialValueForObj(*Obj, *I.getType(), TLI); + if (!InitialValue) + return false; + NewCopies.push_back(InitialValue); + NewCopyOrigins.push_back(nullptr); + } + auto CheckAccess = [&](const AAPointerInfo::Access &Acc, bool IsExact) { - if (!Acc.isRead()) + if ((IsLoad && !Acc.isWrite()) || (!IsLoad && !Acc.isRead())) + return true; + if (IsLoad && Acc.isWrittenValueYetUndetermined()) return true; - auto *LI = dyn_cast<LoadInst>(Acc.getRemoteInst()); - if (!LI) { - LLVM_DEBUG(dbgs() << "Underlying object read through a non-load " - "instruction not supported yet: " - << *Acc.getRemoteInst() << "\n";); + if (OnlyExact && !IsExact && + !isa_and_nonnull<UndefValue>(Acc.getWrittenValue())) { + LLVM_DEBUG(dbgs() << "Non exact access " << *Acc.getRemoteInst() + << ", abort!\n"); return false; } - NewCopies.push_back(LI); + if (IsLoad) { + assert(isa<LoadInst>(I) && "Expected load or store instruction only!"); + if (!Acc.isWrittenValueUnknown()) { + NewCopies.push_back(Acc.getWrittenValue()); + NewCopyOrigins.push_back(Acc.getRemoteInst()); + return true; + } + auto *SI = dyn_cast<StoreInst>(Acc.getRemoteInst()); + if (!SI) { + LLVM_DEBUG(dbgs() << "Underlying object written through a non-store " + "instruction not supported yet: " + << *Acc.getRemoteInst() << "\n";); + return false; + } + NewCopies.push_back(SI->getValueOperand()); + NewCopyOrigins.push_back(SI); + } else { + assert(isa<StoreInst>(I) && "Expected load or store instruction only!"); + auto *LI = dyn_cast<LoadInst>(Acc.getRemoteInst()); + if (!LI && OnlyExact) { + LLVM_DEBUG(dbgs() << "Underlying object read through a non-load " + "instruction not supported yet: " + << *Acc.getRemoteInst() << "\n";); + return false; + } + NewCopies.push_back(Acc.getRemoteInst()); + } return true; }; auto &PI = A.getAAFor<AAPointerInfo>(QueryingAA, IRPosition::value(*Obj), DepClassTy::NONE); - if (!PI.forallInterferingAccesses(SI, CheckAccess)) { + if (!PI.forallInterferingAccesses(A, QueryingAA, I, CheckAccess)) { LLVM_DEBUG( dbgs() << "Failed to verify all interfering accesses for underlying object: " @@ -386,16 +436,40 @@ bool AA::getPotentialCopiesOfStoredValue( PIs.push_back(&PI); } + // Only if we were successful collection all potential copies we record + // dependences (on non-fix AAPointerInfo AAs). We also only then modify the + // given PotentialCopies container. 
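Note the shape of this code: candidate copies are accumulated in the local NewCopies/NewCopyOrigins containers during the walk, and the commit step that follows only publishes them once every access has checked out, so an aborted walk leaves no stale entries or spurious dependences behind. A small standalone sketch of that commit-on-success idiom (hypothetical names):

#include <functional>
#include <set>
#include <vector>

// Gather candidate values, but only publish them if *all* checks succeed.
static bool collectIfAllValid(const std::vector<int> &Candidates,
                              const std::function<bool(int)> &IsValid,
                              std::set<int> &Out) {
  std::vector<int> Tmp;
  for (int C : Candidates) {
    if (!IsValid(C))
      return false;                       // abort without touching Out
    Tmp.push_back(C);
  }
  Out.insert(Tmp.begin(), Tmp.end());     // commit only on full success
  return true;
}

int main() {
  std::set<int> Out;
  bool Ok = collectIfAllValid({1, 2, 3}, [](int V) { return V > 0; }, Out);
  bool Fail = !collectIfAllValid({1, -2}, [](int V) { return V > 0; }, Out);
  // After the failed call, Out still holds only {1, 2, 3}.
  return (Ok && Fail && Out.size() == 3) ? 0 : 1;
}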
for (auto *PI : PIs) { if (!PI->getState().isAtFixpoint()) UsedAssumedInformation = true; A.recordDependence(*PI, QueryingAA, DepClassTy::OPTIONAL); } PotentialCopies.insert(NewCopies.begin(), NewCopies.end()); + PotentialValueOrigins.insert(NewCopyOrigins.begin(), NewCopyOrigins.end()); return true; } +bool AA::getPotentiallyLoadedValues( + Attributor &A, LoadInst &LI, SmallSetVector<Value *, 4> &PotentialValues, + SmallSetVector<Instruction *, 4> &PotentialValueOrigins, + const AbstractAttribute &QueryingAA, bool &UsedAssumedInformation, + bool OnlyExact) { + return getPotentialCopiesOfMemoryValue</* IsLoad */ true>( + A, LI, PotentialValues, PotentialValueOrigins, QueryingAA, + UsedAssumedInformation, OnlyExact); +} + +bool AA::getPotentialCopiesOfStoredValue( + Attributor &A, StoreInst &SI, SmallSetVector<Value *, 4> &PotentialCopies, + const AbstractAttribute &QueryingAA, bool &UsedAssumedInformation, + bool OnlyExact) { + SmallSetVector<Instruction *, 4> PotentialValueOrigins; + return getPotentialCopiesOfMemoryValue</* IsLoad */ false>( + A, SI, PotentialCopies, PotentialValueOrigins, QueryingAA, + UsedAssumedInformation, OnlyExact); +} + static bool isAssumedReadOnlyOrReadNone(Attributor &A, const IRPosition &IRP, const AbstractAttribute &QueryingAA, bool RequireReadNone, bool &IsKnown) { @@ -450,6 +524,8 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI, SmallVector<const Instruction *> Worklist; Worklist.push_back(&FromI); + const auto &NoRecurseAA = A.getAAFor<AANoRecurse>( + QueryingAA, IRPosition::function(ToFn), DepClassTy::OPTIONAL); while (!Worklist.empty()) { const Instruction *CurFromI = Worklist.pop_back_val(); if (!Visited.insert(CurFromI).second) @@ -469,7 +545,8 @@ isPotentiallyReachable(Attributor &A, const Instruction &FromI, << *ToI << " [Intra]\n"); if (Result) return true; - continue; + if (NoRecurseAA.isAssumedNoRecurse()) + continue; } // TODO: If we can go arbitrarily backwards we will eventually reach an @@ -632,7 +709,7 @@ Argument *IRPosition::getAssociatedArgument() const { assert(ACS.getCalledFunction()->arg_size() > u && "ACS mapped into var-args arguments!"); - if (CBCandidateArg.hasValue()) { + if (CBCandidateArg) { CBCandidateArg = nullptr; break; } @@ -641,7 +718,7 @@ Argument *IRPosition::getAssociatedArgument() const { } // If we found a unique callback candidate argument, return it. - if (CBCandidateArg.hasValue() && CBCandidateArg.getValue()) + if (CBCandidateArg && CBCandidateArg.getValue()) return CBCandidateArg.getValue(); // If no callbacks were found, or none used the underlying call site operand @@ -950,22 +1027,24 @@ Attributor::getAssumedConstant(const IRPosition &IRP, bool &UsedAssumedInformation) { // First check all callbacks provided by outside AAs. If any of them returns // a non-null value that is different from the associated value, or None, we - // assume it's simpliied. + // assume it's simplified. 
for (auto &CB : SimplificationCallbacks.lookup(IRP)) { Optional<Value *> SimplifiedV = CB(IRP, &AA, UsedAssumedInformation); - if (!SimplifiedV.hasValue()) + if (!SimplifiedV) return llvm::None; if (isa_and_nonnull<Constant>(*SimplifiedV)) return cast<Constant>(*SimplifiedV); return nullptr; } + if (auto *C = dyn_cast<Constant>(&IRP.getAssociatedValue())) + return C; const auto &ValueSimplifyAA = getAAFor<AAValueSimplify>(AA, IRP, DepClassTy::NONE); Optional<Value *> SimplifiedV = ValueSimplifyAA.getAssumedSimplifiedValue(*this); bool IsKnown = ValueSimplifyAA.isAtFixpoint(); UsedAssumedInformation |= !IsKnown; - if (!SimplifiedV.hasValue()) { + if (!SimplifiedV) { recordDependence(ValueSimplifyAA, AA, DepClassTy::OPTIONAL); return llvm::None; } @@ -988,18 +1067,18 @@ Attributor::getAssumedSimplified(const IRPosition &IRP, bool &UsedAssumedInformation) { // First check all callbacks provided by outside AAs. If any of them returns // a non-null value that is different from the associated value, or None, we - // assume it's simpliied. + // assume it's simplified. for (auto &CB : SimplificationCallbacks.lookup(IRP)) return CB(IRP, AA, UsedAssumedInformation); - // If no high-level/outside simplification occured, use AAValueSimplify. + // If no high-level/outside simplification occurred, use AAValueSimplify. const auto &ValueSimplifyAA = getOrCreateAAFor<AAValueSimplify>(IRP, AA, DepClassTy::NONE); Optional<Value *> SimplifiedV = ValueSimplifyAA.getAssumedSimplifiedValue(*this); bool IsKnown = ValueSimplifyAA.isAtFixpoint(); UsedAssumedInformation |= !IsKnown; - if (!SimplifiedV.hasValue()) { + if (!SimplifiedV) { if (AA) recordDependence(ValueSimplifyAA, *AA, DepClassTy::OPTIONAL); return llvm::None; @@ -1018,7 +1097,7 @@ Attributor::getAssumedSimplified(const IRPosition &IRP, Optional<Value *> Attributor::translateArgumentToCallSiteContent( Optional<Value *> V, CallBase &CB, const AbstractAttribute &AA, bool &UsedAssumedInformation) { - if (!V.hasValue()) + if (!V) return V; if (*V == nullptr || isa<Constant>(*V)) return V; @@ -1079,6 +1158,19 @@ bool Attributor::isAssumedDead(const Use &U, BasicBlock *IncomingBB = PHI->getIncomingBlock(U); return isAssumedDead(*IncomingBB->getTerminator(), QueryingAA, FnLivenessAA, UsedAssumedInformation, CheckBBLivenessOnly, DepClass); + } else if (StoreInst *SI = dyn_cast<StoreInst>(UserI)) { + if (!CheckBBLivenessOnly && SI->getPointerOperand() != U.get()) { + const IRPosition IRP = IRPosition::inst(*SI); + const AAIsDead &IsDeadAA = + getOrCreateAAFor<AAIsDead>(IRP, QueryingAA, DepClassTy::NONE); + if (IsDeadAA.isRemovableStore()) { + if (QueryingAA) + recordDependence(IsDeadAA, *QueryingAA, DepClass); + if (!IsDeadAA.isKnown(AAIsDead::IS_REMOVABLE)) + UsedAssumedInformation = true; + return true; + } + } } return isAssumedDead(IRPosition::inst(*UserI), QueryingAA, FnLivenessAA, @@ -1192,6 +1284,7 @@ bool Attributor::checkForAllUses( function_ref<bool(const Use &, bool &)> Pred, const AbstractAttribute &QueryingAA, const Value &V, bool CheckBBLivenessOnly, DepClassTy LivenessDepClass, + bool IgnoreDroppableUses, function_ref<bool(const Use &OldU, const Use &NewU)> EquivalentUseCB) { // Check the trivial case first as it catches void values. 
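Several interfaces in this area, including the simplification callbacks consulted above, use a three-state Optional<Value *>: an empty optional means no information yet, a contained nullptr means there is no single simplified value, and a contained pointer is the simplified value. A standalone sketch of how two such states combine, using std::optional in place of llvm::Optional and integers in place of IR values; the real combine additionally strips pointer casts and retypes the result:

#include <optional>

// Empty optional = nothing known yet, contained nullptr = no single value,
// contained non-null pointer = the one simplified value.
using Simplified = std::optional<const int *>;

static Simplified combine(Simplified A, Simplified B) {
  if (A == B)
    return A;
  if (!B)
    return A;               // B carries no information
  if (*B == nullptr)
    return nullptr;         // B already gave up: conflict wins
  if (!A)
    return B;               // A had no information, adopt B
  if (*A == nullptr)
    return nullptr;
  return (*A == *B) ? A : Simplified(nullptr);  // different values: conflict
}

int main() {
  static const int X = 42;
  Simplified None;                 // nothing known
  Simplified Conflict = nullptr;   // known to have no single value
  Simplified Val = &X;             // simplified to X
  bool Ok = combine(None, Val) == Val && combine(Val, Conflict) == Conflict &&
            combine(Val, Val) == Val;
  return Ok ? 0 : 1;
}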
@@ -1232,7 +1325,7 @@ bool Attributor::checkForAllUses( LLVM_DEBUG(dbgs() << "[Attributor] Dead use, skip!\n"); continue; } - if (U->getUser()->isDroppable()) { + if (IgnoreDroppableUses && U->getUser()->isDroppable()) { LLVM_DEBUG(dbgs() << "[Attributor] Droppable user, skip!\n"); continue; } @@ -1242,9 +1335,9 @@ bool Attributor::checkForAllUses( if (!Visited.insert(U).second) continue; SmallSetVector<Value *, 4> PotentialCopies; - if (AA::getPotentialCopiesOfStoredValue(*this, *SI, PotentialCopies, - QueryingAA, - UsedAssumedInformation)) { + if (AA::getPotentialCopiesOfStoredValue( + *this, *SI, PotentialCopies, QueryingAA, UsedAssumedInformation, + /* OnlyExact */ true)) { LLVM_DEBUG(dbgs() << "[Attributor] Value is stored, continue with " << PotentialCopies.size() << " potential copies instead!\n"); @@ -1324,8 +1417,7 @@ bool Attributor::checkForAllCallSites(function_ref<bool(AbstractCallSite)> Pred, continue; } if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U.getUser())) { - if (CE->isCast() && CE->getType()->isPointerTy() && - CE->getType()->getPointerElementType()->isFunctionTy()) { + if (CE->isCast() && CE->getType()->isPointerTy()) { LLVM_DEBUG( dbgs() << "[Attributor] Use, is constant cast expression, add " << CE->getNumUses() @@ -1472,30 +1564,24 @@ static bool checkForAllInstructionsImpl( } bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred, + const Function *Fn, const AbstractAttribute &QueryingAA, const ArrayRef<unsigned> &Opcodes, bool &UsedAssumedInformation, bool CheckBBLivenessOnly, bool CheckPotentiallyDead) { - - const IRPosition &IRP = QueryingAA.getIRPosition(); // Since we need to provide instructions we have to have an exact definition. - const Function *AssociatedFunction = IRP.getAssociatedFunction(); - if (!AssociatedFunction) - return false; - - if (AssociatedFunction->isDeclaration()) + if (!Fn || Fn->isDeclaration()) return false; // TODO: use the function scope once we have call site AAReturnedValues. - const IRPosition &QueryIRP = IRPosition::function(*AssociatedFunction); + const IRPosition &QueryIRP = IRPosition::function(*Fn); const auto *LivenessAA = (CheckBBLivenessOnly || CheckPotentiallyDead) ? nullptr : &(getAAFor<AAIsDead>(QueryingAA, QueryIRP, DepClassTy::NONE)); - auto &OpcodeInstMap = - InfoCache.getOpcodeInstMapForFunction(*AssociatedFunction); + auto &OpcodeInstMap = InfoCache.getOpcodeInstMapForFunction(*Fn); if (!checkForAllInstructionsImpl(this, OpcodeInstMap, Pred, &QueryingAA, LivenessAA, Opcodes, UsedAssumedInformation, CheckBBLivenessOnly, CheckPotentiallyDead)) @@ -1504,6 +1590,19 @@ bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred, return true; } +bool Attributor::checkForAllInstructions(function_ref<bool(Instruction &)> Pred, + const AbstractAttribute &QueryingAA, + const ArrayRef<unsigned> &Opcodes, + bool &UsedAssumedInformation, + bool CheckBBLivenessOnly, + bool CheckPotentiallyDead) { + const IRPosition &IRP = QueryingAA.getIRPosition(); + const Function *AssociatedFunction = IRP.getAssociatedFunction(); + return checkForAllInstructions(Pred, AssociatedFunction, QueryingAA, Opcodes, + UsedAssumedInformation, CheckBBLivenessOnly, + CheckPotentiallyDead); +} + bool Attributor::checkForAllReadWriteInstructions( function_ref<bool(Instruction &)> Pred, AbstractAttribute &QueryingAA, bool &UsedAssumedInformation) { @@ -1542,11 +1641,8 @@ void Attributor::runTillFixpoint() { // the abstract analysis. 
unsigned IterationCounter = 1; - unsigned MaxFixedPointIterations; - if (MaxFixpointIterations) - MaxFixedPointIterations = MaxFixpointIterations.getValue(); - else - MaxFixedPointIterations = SetFixpointIterations; + unsigned MaxIterations = + Configuration.MaxFixpointIterations.value_or(SetFixpointIterations); SmallVector<AbstractAttribute *, 32> ChangedAAs; SetVector<AbstractAttribute *> Worklist, InvalidAAs; @@ -1631,21 +1727,20 @@ void Attributor::runTillFixpoint() { QueryAAsAwaitingUpdate.end()); QueryAAsAwaitingUpdate.clear(); - } while (!Worklist.empty() && (IterationCounter++ < MaxFixedPointIterations || - VerifyMaxFixpointIterations)); + } while (!Worklist.empty() && + (IterationCounter++ < MaxIterations || VerifyMaxFixpointIterations)); - if (IterationCounter > MaxFixedPointIterations && !Worklist.empty()) { + if (IterationCounter > MaxIterations && !Functions.empty()) { auto Remark = [&](OptimizationRemarkMissed ORM) { return ORM << "Attributor did not reach a fixpoint after " - << ore::NV("Iterations", MaxFixedPointIterations) - << " iterations."; + << ore::NV("Iterations", MaxIterations) << " iterations."; }; - Function *F = Worklist.front()->getIRPosition().getAssociatedFunction(); + Function *F = Functions.front(); emitRemark<OptimizationRemarkMissed>(F, "FixedPoint", Remark); } LLVM_DEBUG(dbgs() << "\n[Attributor] Fixpoint iteration done after: " - << IterationCounter << "/" << MaxFixpointIterations + << IterationCounter << "/" << MaxIterations << " iterations\n"); // Reset abstract arguments not settled in a sound fixpoint by now. This @@ -1679,11 +1774,9 @@ void Attributor::runTillFixpoint() { << " abstract attributes.\n"; }); - if (VerifyMaxFixpointIterations && - IterationCounter != MaxFixedPointIterations) { + if (VerifyMaxFixpointIterations && IterationCounter != MaxIterations) { errs() << "\n[Attributor] Fixpoint iteration done after: " - << IterationCounter << "/" << MaxFixedPointIterations - << " iterations\n"; + << IterationCounter << "/" << MaxIterations << " iterations\n"; llvm_unreachable("The fixpoint was not reached with exactly the number of " "specified iterations!"); } @@ -1720,6 +1813,9 @@ ChangeStatus Attributor::manifestAttributes() { if (!State.isValidState()) continue; + if (AA->getCtxI() && !isRunOn(*AA->getAnchorScope())) + continue; + // Skip dead code. bool UsedAssumedInformation = false; if (isAssumedDead(*AA, nullptr, UsedAssumedInformation, @@ -1769,7 +1865,7 @@ ChangeStatus Attributor::manifestAttributes() { void Attributor::identifyDeadInternalFunctions() { // Early exit if we don't intend to delete functions. - if (!DeleteFns) + if (!Configuration.DeleteFns) return; // Identify dead internal functions and delete them. This happens outside @@ -1821,7 +1917,8 @@ ChangeStatus Attributor::cleanupIR() { << ToBeDeletedBlocks.size() << " blocks and " << ToBeDeletedInsts.size() << " instructions and " << ToBeChangedValues.size() << " values and " - << ToBeChangedUses.size() << " uses. " + << ToBeChangedUses.size() << " uses. To insert " + << ToBeChangedToUnreachableInsts.size() << " unreachables." << "Preserve manifest added " << ManifestAddedBlocks.size() << " blocks\n"); @@ -1839,12 +1936,15 @@ ChangeStatus Attributor::cleanupIR() { NewV = Entry.first; } while (true); + Instruction *I = dyn_cast<Instruction>(U->getUser()); + assert((!I || isRunOn(*I->getFunction())) && + "Cannot replace an instruction outside the current SCC!"); + // Do not replace uses in returns if the value is a must-tail call we will // not delete. 
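The fixpoint driver reworked above resolves its iteration budget with value_or: a per-run configuration value, if present, overrides the command-line default, and the update loop stops either at a fixpoint or when the budget runs out. A minimal standalone sketch of that driver shape (hypothetical names, trivial update step):

#include <cstdio>
#include <optional>

// Iterate an update step until nothing changes or the budget is exhausted.
// The per-run override, if set, wins over the global default.
static unsigned runTillFixpoint(std::optional<unsigned> ConfiguredMax,
                                unsigned DefaultMax, bool (*UpdateStep)()) {
  unsigned MaxIterations = ConfiguredMax.value_or(DefaultMax);
  unsigned Iteration = 0;
  while (Iteration < MaxIterations && UpdateStep())
    ++Iteration;
  return Iteration;
}

static int Counter = 0;
static bool step() { return ++Counter < 5; } // "changes" four times, then settles

int main() {
  unsigned Iters = runTillFixpoint(/*ConfiguredMax=*/std::nullopt,
                                   /*DefaultMax=*/32, step);
  std::printf("settled after %u iterations\n", Iters);
  return 0;
}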
- if (auto *RI = dyn_cast<ReturnInst>(U->getUser())) { + if (auto *RI = dyn_cast_or_null<ReturnInst>(I)) { if (auto *CI = dyn_cast<CallInst>(OldV->stripPointerCasts())) - if (CI->isMustTailCall() && - (!ToBeDeletedInsts.count(CI) || !isRunOn(*CI->getCaller()))) + if (CI->isMustTailCall() && !ToBeDeletedInsts.count(CI)) return; // If we rewrite a return and the new value is not an argument, strip the // `returned` attribute as it is wrong now. @@ -1854,8 +1954,8 @@ ChangeStatus Attributor::cleanupIR() { } // Do not perform call graph altering changes outside the SCC. - if (auto *CB = dyn_cast<CallBase>(U->getUser())) - if (CB->isCallee(U) && !isRunOn(*CB->getCaller())) + if (auto *CB = dyn_cast_or_null<CallBase>(I)) + if (CB->isCallee(U)) return; LLVM_DEBUG(dbgs() << "Use " << *NewV << " in " << *U->getUser() @@ -1903,8 +2003,12 @@ ChangeStatus Attributor::cleanupIR() { for (auto &U : OldV->uses()) if (Entry.second || !U.getUser()->isDroppable()) Uses.push_back(&U); - for (Use *U : Uses) + for (Use *U : Uses) { + if (auto *I = dyn_cast<Instruction>(U->getUser())) + if (!isRunOn(*I->getFunction())) + continue; ReplaceUse(U, NewV); + } } for (auto &V : InvokeWithDeadSuccessor) @@ -1935,15 +2039,15 @@ ChangeStatus Attributor::cleanupIR() { } } for (Instruction *I : TerminatorsToFold) { - if (!isRunOn(*I->getFunction())) - continue; + assert(isRunOn(*I->getFunction()) && + "Cannot replace a terminator outside the current SCC!"); CGModifiedFunctions.insert(I->getFunction()); ConstantFoldTerminator(I->getParent()); } for (auto &V : ToBeChangedToUnreachableInsts) if (Instruction *I = dyn_cast_or_null<Instruction>(V)) { - if (!isRunOn(*I->getFunction())) - continue; + assert(isRunOn(*I->getFunction()) && + "Cannot replace an instruction outside the current SCC!"); CGModifiedFunctions.insert(I->getFunction()); changeToUnreachable(I); } @@ -1951,10 +2055,10 @@ ChangeStatus Attributor::cleanupIR() { for (auto &V : ToBeDeletedInsts) { if (Instruction *I = dyn_cast_or_null<Instruction>(V)) { if (auto *CB = dyn_cast<CallBase>(I)) { - if (!isRunOn(*I->getFunction())) - continue; + assert(isRunOn(*I->getFunction()) && + "Cannot delete an instruction outside the current SCC!"); if (!isa<IntrinsicInst>(CB)) - CGUpdater.removeCallSite(*CB); + Configuration.CGUpdater.removeCallSite(*CB); } I->dropDroppableUses(); CGModifiedFunctions.insert(I->getFunction()); @@ -1967,9 +2071,7 @@ ChangeStatus Attributor::cleanupIR() { } } - llvm::erase_if(DeadInsts, [&](WeakTrackingVH I) { - return !I || !isRunOn(*cast<Instruction>(I)->getFunction()); - }); + llvm::erase_if(DeadInsts, [&](WeakTrackingVH I) { return !I; }); LLVM_DEBUG({ dbgs() << "[Attributor] DeadInsts size: " << DeadInsts.size() << "\n"; @@ -2005,12 +2107,12 @@ ChangeStatus Attributor::cleanupIR() { for (Function *Fn : CGModifiedFunctions) if (!ToBeDeletedFunctions.count(Fn) && Functions.count(Fn)) - CGUpdater.reanalyzeFunction(*Fn); + Configuration.CGUpdater.reanalyzeFunction(*Fn); for (Function *Fn : ToBeDeletedFunctions) { if (!Functions.count(Fn)) continue; - CGUpdater.removeFunction(*Fn); + Configuration.CGUpdater.removeFunction(*Fn); } if (!ToBeChangedUses.empty()) @@ -2249,7 +2351,7 @@ bool Attributor::internalizeFunctions(SmallPtrSetImpl<Function *> &FnSet, bool Attributor::isValidFunctionSignatureRewrite( Argument &Arg, ArrayRef<Type *> ReplacementTypes) { - if (!RewriteSignatures) + if (!Configuration.RewriteSignatures) return false; Function *Fn = Arg.getParent(); @@ -2364,7 +2466,7 @@ bool Attributor::shouldSeedAttribute(AbstractAttribute &AA) { } 
ChangeStatus Attributor::rewriteFunctionSignatures( - SmallPtrSetImpl<Function *> &ModifiedFns) { + SmallSetVector<Function *, 8> &ModifiedFns) { ChangeStatus Changed = ChangeStatus::UNCHANGED; for (auto &It : ArgumentReplacementMap) { @@ -2397,6 +2499,12 @@ ChangeStatus Attributor::rewriteFunctionSignatures( } } + uint64_t LargestVectorWidth = 0; + for (auto *I : NewArgumentTypes) + if (auto *VT = dyn_cast<llvm::VectorType>(I)) + LargestVectorWidth = std::max( + LargestVectorWidth, VT->getPrimitiveSizeInBits().getKnownMinSize()); + FunctionType *OldFnTy = OldFn->getFunctionType(); Type *RetTy = OldFnTy->getReturnType(); @@ -2426,6 +2534,7 @@ ChangeStatus Attributor::rewriteFunctionSignatures( NewFn->setAttributes(AttributeList::get( Ctx, OldFnAttributeList.getFnAttrs(), OldFnAttributeList.getRetAttrs(), NewArgumentAttributes)); + AttributeFuncs::updateMinLegalVectorWidthAttr(*NewFn, LargestVectorWidth); // Since we have now created the new function, splice the body of the old // function right into the new function, leaving the old rotting hulk of the @@ -2503,6 +2612,9 @@ ChangeStatus Attributor::rewriteFunctionSignatures( Ctx, OldCallAttributeList.getFnAttrs(), OldCallAttributeList.getRetAttrs(), NewArgOperandAttributes)); + AttributeFuncs::updateMinLegalVectorWidthAttr(*NewCB->getCaller(), + LargestVectorWidth); + CallSitePairs.push_back({OldCB, NewCB}); return true; }; @@ -2523,6 +2635,9 @@ ChangeStatus Attributor::rewriteFunctionSignatures( ARIs[OldArgNum]) { if (ARI->CalleeRepairCB) ARI->CalleeRepairCB(*ARI, *NewFn, NewFnArgIt); + if (ARI->ReplacementTypes.empty()) + OldFnArgIt->replaceAllUsesWith( + PoisonValue::get(OldFnArgIt->getType())); NewFnArgIt += ARI->ReplacementTypes.size(); } else { NewFnArgIt->takeName(&*OldFnArgIt); @@ -2538,17 +2653,17 @@ ChangeStatus Attributor::rewriteFunctionSignatures( assert(OldCB.getType() == NewCB.getType() && "Cannot handle call sites with different types!"); ModifiedFns.insert(OldCB.getFunction()); - CGUpdater.replaceCallSite(OldCB, NewCB); + Configuration.CGUpdater.replaceCallSite(OldCB, NewCB); OldCB.replaceAllUsesWith(&NewCB); OldCB.eraseFromParent(); } // Replace the function in the call graph (if any). - CGUpdater.replaceFunctionWith(*OldFn, *NewFn); + Configuration.CGUpdater.replaceFunctionWith(*OldFn, *NewFn); // If the old function was modified and needed to be reanalyzed, the new one // does now. - if (ModifiedFns.erase(OldFn)) + if (ModifiedFns.remove(OldFn)) ModifiedFns.insert(NewFn); Changed = ChangeStatus::CHANGED; @@ -2568,6 +2683,30 @@ void InformationCache::initializeInformationCache(const Function &CF, // queried by abstract attributes during their initialization or update. // This has to happen before we create attributes. + DenseMap<const Value *, Optional<short>> AssumeUsesMap; + + // Add \p V to the assume uses map which track the number of uses outside of + // "visited" assumes. If no outside uses are left the value is added to the + // assume only use vector. 
+ auto AddToAssumeUsesMap = [&](const Value &V) -> void { + SmallVector<const Instruction *> Worklist; + if (auto *I = dyn_cast<Instruction>(&V)) + Worklist.push_back(I); + while (!Worklist.empty()) { + const Instruction *I = Worklist.pop_back_val(); + Optional<short> &NumUses = AssumeUsesMap[I]; + if (!NumUses) + NumUses = I->getNumUses(); + NumUses = NumUses.getValue() - /* this assume */ 1; + if (NumUses.getValue() != 0) + continue; + AssumeOnlyValues.insert(I); + for (const Value *Op : I->operands()) + if (auto *OpI = dyn_cast<Instruction>(Op)) + Worklist.push_back(OpI); + } + }; + for (Instruction &I : instructions(&F)) { bool IsInterestingOpcode = false; @@ -2588,6 +2727,7 @@ void InformationCache::initializeInformationCache(const Function &CF, // For `must-tail` calls we remember the caller and callee. if (auto *Assume = dyn_cast<AssumeInst>(&I)) { fillMapFromAssume(*Assume, KnowledgeMap); + AddToAssumeUsesMap(*Assume->getArgOperand(0)); } else if (cast<CallInst>(I).isMustTailCall()) { FI.ContainsMustTailCall = true; if (const Function *Callee = cast<CallInst>(I).getCalledFunction()) @@ -2736,7 +2876,8 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { getOrCreateAAFor<AAIsDead>(RetPos); // Every function might be simplified. - getOrCreateAAFor<AAValueSimplify>(RetPos); + bool UsedAssumedInformation = false; + getAssumedSimplified(RetPos, nullptr, UsedAssumedInformation); // Every returned value might be marked noundef. getOrCreateAAFor<AANoUndef>(RetPos); @@ -2828,7 +2969,8 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { if (!Callee->getReturnType()->isVoidTy() && !CB.use_empty()) { IRPosition CBRetPos = IRPosition::callsite_returned(CB); - getOrCreateAAFor<AAValueSimplify>(CBRetPos); + bool UsedAssumedInformation = false; + getAssumedSimplified(CBRetPos, nullptr, UsedAssumedInformation); } for (int I = 0, E = CB.arg_size(); I < E; ++I) { @@ -2891,10 +3033,15 @@ void Attributor::identifyDefaultAbstractAttributes(Function &F) { getOrCreateAAFor<AAAlign>( IRPosition::value(*cast<LoadInst>(I).getPointerOperand())); if (SimplifyAllLoads) - getOrCreateAAFor<AAValueSimplify>(IRPosition::value(I)); - } else - getOrCreateAAFor<AAAlign>( - IRPosition::value(*cast<StoreInst>(I).getPointerOperand())); + getAssumedSimplified(IRPosition::value(I), nullptr, + UsedAssumedInformation); + } else { + auto &SI = cast<StoreInst>(I); + getOrCreateAAFor<AAIsDead>(IRPosition::inst(I)); + getAssumedSimplified(IRPosition::value(*SI.getValueOperand()), nullptr, + UsedAssumedInformation); + getOrCreateAAFor<AAAlign>(IRPosition::value(*SI.getPointerOperand())); + } return true; }; Success = checkForAllInstructionsImpl( @@ -2969,8 +3116,8 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, if (!S.isValidState()) OS << "full-set"; else { - for (auto &it : S.getAssumedSet()) - OS << it << ", "; + for (auto &It : S.getAssumedSet()) + OS << It << ", "; if (S.undefIsContained()) OS << "undef "; } @@ -3012,8 +3159,12 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, OS << " [" << Acc.getKind() << "] " << *Acc.getRemoteInst(); if (Acc.getLocalInst() != Acc.getRemoteInst()) OS << " via " << *Acc.getLocalInst(); - if (Acc.getContent().hasValue()) - OS << " [" << *Acc.getContent() << "]"; + if (Acc.getContent()) { + if (*Acc.getContent()) + OS << " [" << **Acc.getContent() << "]"; + else + OS << " [ <unknown> ]"; + } return OS; } ///} @@ -3026,7 +3177,7 @@ static bool runAttributorOnFunctions(InformationCache &InfoCache, SetVector<Function *> &Functions, AnalysisGetter &AG, 
CallGraphUpdater &CGUpdater, - bool DeleteFns) { + bool DeleteFns, bool IsModulePass) { if (Functions.empty()) return false; @@ -3039,8 +3190,10 @@ static bool runAttributorOnFunctions(InformationCache &InfoCache, // Create an Attributor and initially empty information cache that is filled // while we identify default attribute opportunities. - Attributor A(Functions, InfoCache, CGUpdater, /* Allowed */ nullptr, - DeleteFns); + AttributorConfig AC(CGUpdater); + AC.IsModulePass = IsModulePass; + AC.DeleteFns = DeleteFns; + Attributor A(Functions, InfoCache, AC); // Create shallow wrappers for all functions that are not IPO amendable if (AllowShallowWrappers) @@ -3145,7 +3298,7 @@ PreservedAnalyses AttributorPass::run(Module &M, ModuleAnalysisManager &AM) { BumpPtrAllocator Allocator; InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ nullptr); if (runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater, - /* DeleteFns */ true)) { + /* DeleteFns */ true, /* IsModulePass */ true)) { // FIXME: Think about passes we will preserve and add them here. return PreservedAnalyses::none(); } @@ -3173,7 +3326,8 @@ PreservedAnalyses AttributorCGSCCPass::run(LazyCallGraph::SCC &C, BumpPtrAllocator Allocator; InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ &Functions); if (runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater, - /* DeleteFns */ false)) { + /* DeleteFns */ false, + /* IsModulePass */ false)) { // FIXME: Think about passes we will preserve and add them here. PreservedAnalyses PA; PA.preserve<FunctionAnalysisManagerCGSCCProxy>(); @@ -3249,7 +3403,8 @@ struct AttributorLegacyPass : public ModulePass { BumpPtrAllocator Allocator; InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ nullptr); return runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater, - /* DeleteFns*/ true); + /* DeleteFns*/ true, + /* IsModulePass */ true); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -3286,7 +3441,8 @@ struct AttributorCGSCCLegacyPass : public CallGraphSCCPass { BumpPtrAllocator Allocator; InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ &Functions); return runAttributorOnFunctions(InfoCache, Functions, AG, CGUpdater, - /* DeleteFns */ false); + /* DeleteFns */ false, + /* IsModulePass */ false); } void getAnalysisUsage(AnalysisUsage &AU) const override { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp index 61a973f869d4..4d99ce7e3175 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -14,9 +14,11 @@ #include "llvm/Transforms/IPO/Attributor.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -30,22 +32,29 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" #include "llvm/IR/Assumptions.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include 
"llvm/IR/IntrinsicInst.h" #include "llvm/IR/NoFolder.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" #include "llvm/Support/Alignment.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/FileSystem.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/IPO/ArgumentPromotion.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/ValueMapper.h" #include <cassert> using namespace llvm; @@ -70,11 +79,11 @@ static cl::opt<unsigned, true> MaxPotentialValues( cl::location(llvm::PotentialConstantIntValuesState::MaxPotentialValues), cl::init(7)); -static cl::opt<unsigned> - MaxInterferingWrites("attributor-max-interfering-writes", cl::Hidden, - cl::desc("Maximum number of interfering writes to " - "check before assuming all might interfere."), - cl::init(6)); +static cl::opt<unsigned> MaxInterferingAccesses( + "attributor-max-interfering-accesses", cl::Hidden, + cl::desc("Maximum number of interfering accesses to " + "check before assuming all might interfere."), + cl::init(6)); STATISTIC(NumAAs, "Number of abstract attributes created"); @@ -141,6 +150,7 @@ PIPE_OPERATOR(AANonNull) PIPE_OPERATOR(AANoAlias) PIPE_OPERATOR(AADereferenceable) PIPE_OPERATOR(AAAlign) +PIPE_OPERATOR(AAInstanceInfo) PIPE_OPERATOR(AANoCapture) PIPE_OPERATOR(AAValueSimplify) PIPE_OPERATOR(AANoFree) @@ -151,7 +161,7 @@ PIPE_OPERATOR(AAMemoryLocation) PIPE_OPERATOR(AAValueConstantRange) PIPE_OPERATOR(AAPrivatizablePtr) PIPE_OPERATOR(AAUndefinedBehavior) -PIPE_OPERATOR(AAPotentialValues) +PIPE_OPERATOR(AAPotentialConstantValues) PIPE_OPERATOR(AANoUndef) PIPE_OPERATOR(AACallEdges) PIPE_OPERATOR(AAFunctionReachability) @@ -171,6 +181,45 @@ ChangeStatus clampStateAndIndicateChange<DerefState>(DerefState &S, } // namespace llvm +/// Checks if a type could have padding bytes. +static bool isDenselyPacked(Type *Ty, const DataLayout &DL) { + // There is no size information, so be conservative. + if (!Ty->isSized()) + return false; + + // If the alloc size is not equal to the storage size, then there are padding + // bytes. For x86_fp80 on x86-64, size: 80 alloc size: 128. + if (DL.getTypeSizeInBits(Ty) != DL.getTypeAllocSizeInBits(Ty)) + return false; + + // FIXME: This isn't the right way to check for padding in vectors with + // non-byte-size elements. + if (VectorType *SeqTy = dyn_cast<VectorType>(Ty)) + return isDenselyPacked(SeqTy->getElementType(), DL); + + // For array types, check for padding within members. + if (ArrayType *SeqTy = dyn_cast<ArrayType>(Ty)) + return isDenselyPacked(SeqTy->getElementType(), DL); + + if (!isa<StructType>(Ty)) + return true; + + // Check for padding within and between elements of a struct. + StructType *StructTy = cast<StructType>(Ty); + const StructLayout *Layout = DL.getStructLayout(StructTy); + uint64_t StartPos = 0; + for (unsigned I = 0, E = StructTy->getNumElements(); I < E; ++I) { + Type *ElTy = StructTy->getElementType(I); + if (!isDenselyPacked(ElTy, DL)) + return false; + if (StartPos != Layout->getElementOffsetInBits(I)) + return false; + StartPos += DL.getTypeAllocSizeInBits(ElTy); + } + + return true; +} + /// Get pointer operand of memory accessing instruction. If \p I is /// not a memory accessing instruction, return nullptr. If \p AllowVolatile, /// is set to false and the instruction is volatile, return nullptr. 
@@ -253,8 +302,9 @@ static Value *constructPointer(Type *ResTy, Type *PtrElemTy, Value *Ptr, /// once. Note that the value used for the callback may still be the value /// associated with \p IRP (due to PHIs). To limit how much effort is invested, /// we will never visit more values than specified by \p MaxValues. -/// If \p Intraprocedural is set to true only values valid in the scope of -/// \p CtxI will be visited and simplification into other scopes is prevented. +/// If \p VS does not contain the Interprocedural bit, only values valid in the +/// scope of \p CtxI will be visited and simplification into other scopes is +/// prevented. template <typename StateTy> static bool genericValueTraversal( Attributor &A, IRPosition IRP, const AbstractAttribute &QueryingAA, @@ -264,13 +314,13 @@ static bool genericValueTraversal( const Instruction *CtxI, bool &UsedAssumedInformation, bool UseValueSimplify = true, int MaxValues = 16, function_ref<Value *(Value *)> StripCB = nullptr, - bool Intraprocedural = false) { + AA::ValueScope VS = AA::Interprocedural) { struct LivenessInfo { const AAIsDead *LivenessAA = nullptr; bool AnyDead = false; }; - DenseMap<const Function *, LivenessInfo> LivenessAAs; + SmallMapVector<const Function *, LivenessInfo, 4> LivenessAAs; auto GetLivenessInfo = [&](const Function &F) -> LivenessInfo & { LivenessInfo &LI = LivenessAAs[&F]; if (!LI.LivenessAA) @@ -329,7 +379,7 @@ static bool genericValueTraversal( if (auto *SI = dyn_cast<SelectInst>(V)) { Optional<Constant *> C = A.getAssumedConstant( *SI->getCondition(), QueryingAA, UsedAssumedInformation); - bool NoValueYet = !C.hasValue(); + bool NoValueYet = !C; if (NoValueYet || isa_and_nonnull<UndefValue>(*C)) continue; if (auto *CI = dyn_cast_or_null<ConstantInt>(*C)) { @@ -362,7 +412,7 @@ static bool genericValueTraversal( } if (auto *Arg = dyn_cast<Argument>(V)) { - if (!Intraprocedural && !Arg->hasPassPointeeByValueCopyAttr()) { + if ((VS & AA::Interprocedural) && !Arg->hasPassPointeeByValueCopyAttr()) { SmallVector<Item> CallSiteValues; bool UsedAssumedInformation = false; if (A.checkForAllCallSites( @@ -385,11 +435,11 @@ static bool genericValueTraversal( if (UseValueSimplify && !isa<Constant>(V)) { Optional<Value *> SimpleV = A.getAssumedSimplified(*V, QueryingAA, UsedAssumedInformation); - if (!SimpleV.hasValue()) + if (!SimpleV) continue; Value *NewV = SimpleV.getValue(); if (NewV && NewV != V) { - if (!Intraprocedural || !CtxI || + if ((VS & AA::Interprocedural) || !CtxI || AA::isValidInScope(*NewV, CtxI->getFunction())) { Worklist.push_back({NewV, CtxI}); continue; @@ -397,6 +447,37 @@ static bool genericValueTraversal( } } + if (auto *LI = dyn_cast<LoadInst>(V)) { + bool UsedAssumedInformation = false; + // If we ask for the potentially loaded values from the initial pointer we + // will simply end up here again. The load is as far as we can make it. + if (LI->getPointerOperand() != InitialV) { + SmallSetVector<Value *, 4> PotentialCopies; + SmallSetVector<Instruction *, 4> PotentialValueOrigins; + if (AA::getPotentiallyLoadedValues(A, *LI, PotentialCopies, + PotentialValueOrigins, QueryingAA, + UsedAssumedInformation, + /* OnlyExact */ true)) { + // Values have to be dynamically unique or we loose the fact that a + // single llvm::Value might represent two runtime values (e.g., stack + // locations in different recursive calls). 
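// A small usage sketch for the isDenselyPacked helper introduced above; the
// data layout string and struct types below are illustrative, not taken from
// the patch:
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include <cassert>
static void sketchDenselyPackedExamples() {
  llvm::LLVMContext Ctx;
  // An x86-64-style layout where i32 is 4-byte aligned.
  llvm::DataLayout DL("e-m:e-i64:64-f80:128-n8:16:32:64-S128");
  llvm::Type *I8 = llvm::Type::getInt8Ty(Ctx);
  llvm::Type *I32 = llvm::Type::getInt32Ty(Ctx);
  // {i32, i32}: every member starts exactly where the previous one ended.
  llvm::StructType *NoPadding = llvm::StructType::get(Ctx, {I32, I32});
  // {i8, i32}: three padding bytes are inserted before the i32 member.
  llvm::StructType *Padded = llvm::StructType::get(Ctx, {I8, I32});
  assert(isDenselyPacked(NoPadding, DL));
  assert(!isDenselyPacked(Padded, DL));
  (void)NoPadding;
  (void)Padded;
}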
+ bool DynamicallyUnique = + llvm::all_of(PotentialCopies, [&A, &QueryingAA](Value *PC) { + return AA::isDynamicallyUnique(A, QueryingAA, *PC); + }); + if (DynamicallyUnique && + ((VS & AA::Interprocedural) || !CtxI || + llvm::all_of(PotentialCopies, [CtxI](Value *PC) { + return AA::isValidInScope(*PC, CtxI->getFunction()); + }))) { + for (auto *PotentialCopy : PotentialCopies) + Worklist.push_back({PotentialCopy, CtxI}); + continue; + } + } + } + } + // Once a leaf is reached we inform the user through the callback. if (!VisitValueCB(*V, CtxI, State, Iteration > 1)) { LLVM_DEBUG(dbgs() << "Generic value traversal visit callback failed for: " @@ -420,7 +501,7 @@ bool AA::getAssumedUnderlyingObjects(Attributor &A, const Value &Ptr, const AbstractAttribute &QueryingAA, const Instruction *CtxI, bool &UsedAssumedInformation, - bool Intraprocedural) { + AA::ValueScope VS) { auto StripCB = [&](Value *V) { return getUnderlyingObject(V); }; SmallPtrSet<Value *, 8> SeenObjects; auto VisitValueCB = [&SeenObjects](Value &Val, const Instruction *, @@ -432,15 +513,16 @@ bool AA::getAssumedUnderlyingObjects(Attributor &A, const Value &Ptr, }; if (!genericValueTraversal<decltype(Objects)>( A, IRPosition::value(Ptr), QueryingAA, Objects, VisitValueCB, CtxI, - UsedAssumedInformation, true, 32, StripCB, Intraprocedural)) + UsedAssumedInformation, true, 32, StripCB, VS)) return false; return true; } -const Value *stripAndAccumulateMinimalOffsets( - Attributor &A, const AbstractAttribute &QueryingAA, const Value *Val, - const DataLayout &DL, APInt &Offset, bool AllowNonInbounds, - bool UseAssumed = false) { +static const Value * +stripAndAccumulateOffsets(Attributor &A, const AbstractAttribute &QueryingAA, + const Value *Val, const DataLayout &DL, APInt &Offset, + bool GetMinOffset, bool AllowNonInbounds, + bool UseAssumed = false) { auto AttributorAnalysis = [&](Value &V, APInt &ROffset) -> bool { const IRPosition &Pos = IRPosition::value(V); @@ -451,14 +533,20 @@ const Value *stripAndAccumulateMinimalOffsets( : DepClassTy::NONE); ConstantRange Range = UseAssumed ? ValueConstantRangeAA.getAssumed() : ValueConstantRangeAA.getKnown(); + if (Range.isFullSet()) + return false; + // We can only use the lower part of the range because the upper part can // be higher than what the value can really be. 
- ROffset = Range.getSignedMin(); + if (GetMinOffset) + ROffset = Range.getSignedMin(); + else + ROffset = Range.getSignedMax(); return true; }; return Val->stripAndAccumulateConstantOffsets(DL, Offset, AllowNonInbounds, - /* AllowInvariant */ false, + /* AllowInvariant */ true, AttributorAnalysis); } @@ -467,8 +555,9 @@ getMinimalBaseOfPointer(Attributor &A, const AbstractAttribute &QueryingAA, const Value *Ptr, int64_t &BytesOffset, const DataLayout &DL, bool AllowNonInbounds = false) { APInt OffsetAPInt(DL.getIndexTypeSizeInBits(Ptr->getType()), 0); - const Value *Base = stripAndAccumulateMinimalOffsets( - A, QueryingAA, Ptr, DL, OffsetAPInt, AllowNonInbounds); + const Value *Base = + stripAndAccumulateOffsets(A, QueryingAA, Ptr, DL, OffsetAPInt, + /* GetMinOffset */ true, AllowNonInbounds); BytesOffset = OffsetAPInt.getSExtValue(); return Base; @@ -502,10 +591,9 @@ static void clampReturnedValueStates( LLVM_DEBUG(dbgs() << "[Attributor] RV: " << RV << " AA: " << AA.getAsStr() << " @ " << RVPos << "\n"); const StateType &AAS = AA.getState(); - if (T.hasValue()) - *T &= AAS; - else - T = AAS; + if (!T) + T = StateType::getBestState(AAS); + *T &= AAS; LLVM_DEBUG(dbgs() << "[Attributor] AA State: " << AAS << " RV State: " << T << "\n"); return T->isValidState(); @@ -513,7 +601,7 @@ static void clampReturnedValueStates( if (!A.checkForAllReturnedValues(CheckReturnValue, QueryingAA)) S.indicatePessimisticFixpoint(); - else if (T.hasValue()) + else if (T) S ^= *T; } @@ -569,10 +657,9 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA, LLVM_DEBUG(dbgs() << "[Attributor] ACS: " << *ACS.getInstruction() << " AA: " << AA.getAsStr() << " @" << ACSArgPos << "\n"); const StateType &AAS = AA.getState(); - if (T.hasValue()) - *T &= AAS; - else - T = AAS; + if (!T) + T = StateType::getBestState(AAS); + *T &= AAS; LLVM_DEBUG(dbgs() << "[Attributor] AA State: " << AAS << " CSA State: " << T << "\n"); return T->isValidState(); @@ -582,7 +669,7 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA, if (!A.checkForAllCallSites(CallSiteCheck, QueryingAA, true, UsedAssumedInformation)) S.indicatePessimisticFixpoint(); - else if (T.hasValue()) + else if (T) S ^= *T; } @@ -676,7 +763,6 @@ struct AACallSiteReturnedFromReturned : public BaseType { return clampStateAndIndicateChange(S, AA.getState()); } }; -} // namespace /// Helper function to accumulate uses. template <class AAType, typename StateType = typename AAType::StateType> @@ -788,6 +874,7 @@ static void followUsesInMBEC(AAType &AA, Attributor &A, StateType &S, S += ParentState; } } +} // namespace /// ------------------------ PointerInfo --------------------------------------- @@ -795,9 +882,6 @@ namespace llvm { namespace AA { namespace PointerInfo { -/// An access kind description as used by AAPointerInfo. -struct OffsetAndSize; - struct State; } // namespace PointerInfo @@ -815,7 +899,7 @@ struct DenseMapInfo<AAPointerInfo::Access> : DenseMapInfo<Instruction *> { /// Helper that allows OffsetAndSize as a key in a DenseMap. 
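// clampReturnedValueStates and clampCallSiteArgumentStates above now seed the
// accumulator with the best (most optimistic) state and meet ("&=") every
// queried state into it. A condensed sketch of that pattern with a
// hypothetical bit-lattice state, not an actual Attributor type:
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
struct SketchBitState {
  unsigned Assumed = ~0u; // best state: everything is still assumed
  static SketchBitState getBestState(const SketchBitState &) { return {}; }
  SketchBitState &operator&=(const SketchBitState &R) {
    Assumed &= R.Assumed;
    return *this;
  }
  bool isValidState() const { return Assumed != 0; }
};
static bool sketchClampOverAll(llvm::ArrayRef<SketchBitState> States,
                               SketchBitState &S) {
  llvm::Optional<SketchBitState> T;
  for (const SketchBitState &AAS : States) {
    if (!T)
      T = SketchBitState::getBestState(AAS);
    *T &= AAS;
    if (!T->isValidState())
      return false; // the caller pessimizes S in this case
  }
  if (T)
    S &= *T; // the real code joins with operator^= instead
  return true;
}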
template <> -struct DenseMapInfo<AA::PointerInfo ::OffsetAndSize> +struct DenseMapInfo<AAPointerInfo ::OffsetAndSize> : DenseMapInfo<std::pair<int64_t, int64_t>> {}; /// Helper for AA::PointerInfo::Acccess DenseMap/Set usage ignoring everythign @@ -831,90 +915,15 @@ struct AccessAsInstructionInfo : DenseMapInfo<Instruction *> { } // namespace llvm -/// Helper to represent an access offset and size, with logic to deal with -/// uncertainty and check for overlapping accesses. -struct AA::PointerInfo::OffsetAndSize : public std::pair<int64_t, int64_t> { - using BaseTy = std::pair<int64_t, int64_t>; - OffsetAndSize(int64_t Offset, int64_t Size) : BaseTy(Offset, Size) {} - OffsetAndSize(const BaseTy &P) : BaseTy(P) {} - int64_t getOffset() const { return first; } - int64_t getSize() const { return second; } - static OffsetAndSize getUnknown() { return OffsetAndSize(Unknown, Unknown); } - - /// Return true if offset or size are unknown. - bool offsetOrSizeAreUnknown() const { - return getOffset() == OffsetAndSize::Unknown || - getSize() == OffsetAndSize::Unknown; - } - - /// Return true if this offset and size pair might describe an address that - /// overlaps with \p OAS. - bool mayOverlap(const OffsetAndSize &OAS) const { - // Any unknown value and we are giving up -> overlap. - if (offsetOrSizeAreUnknown() || OAS.offsetOrSizeAreUnknown()) - return true; - - // Check if one offset point is in the other interval [offset, offset+size]. - return OAS.getOffset() + OAS.getSize() > getOffset() && - OAS.getOffset() < getOffset() + getSize(); - } - - /// Constant used to represent unknown offset or sizes. - static constexpr int64_t Unknown = 1 << 31; -}; - -/// Implementation of the DenseMapInfo. -/// -///{ -inline llvm::AccessAsInstructionInfo::Access -llvm::AccessAsInstructionInfo::getEmptyKey() { - return Access(Base::getEmptyKey(), nullptr, AAPointerInfo::AK_READ, nullptr); -} -inline llvm::AccessAsInstructionInfo::Access -llvm::AccessAsInstructionInfo::getTombstoneKey() { - return Access(Base::getTombstoneKey(), nullptr, AAPointerInfo::AK_READ, - nullptr); -} -unsigned llvm::AccessAsInstructionInfo::getHashValue( - const llvm::AccessAsInstructionInfo::Access &A) { - return Base::getHashValue(A.getRemoteInst()); -} -bool llvm::AccessAsInstructionInfo::isEqual( - const llvm::AccessAsInstructionInfo::Access &LHS, - const llvm::AccessAsInstructionInfo::Access &RHS) { - return LHS.getRemoteInst() == RHS.getRemoteInst(); -} -inline llvm::DenseMapInfo<AAPointerInfo::Access>::Access -llvm::DenseMapInfo<AAPointerInfo::Access>::getEmptyKey() { - return AAPointerInfo::Access(nullptr, nullptr, AAPointerInfo::AK_READ, - nullptr); -} -inline llvm::DenseMapInfo<AAPointerInfo::Access>::Access -llvm::DenseMapInfo<AAPointerInfo::Access>::getTombstoneKey() { - return AAPointerInfo::Access(nullptr, nullptr, AAPointerInfo::AK_WRITE, - nullptr); -} - -unsigned llvm::DenseMapInfo<AAPointerInfo::Access>::getHashValue( - const llvm::DenseMapInfo<AAPointerInfo::Access>::Access &A) { - return detail::combineHashValue( - DenseMapInfo<Instruction *>::getHashValue(A.getRemoteInst()), - (A.isWrittenValueYetUndetermined() - ? ~0 - : DenseMapInfo<Value *>::getHashValue(A.getWrittenValue()))) + - A.getKind(); -} - -bool llvm::DenseMapInfo<AAPointerInfo::Access>::isEqual( - const llvm::DenseMapInfo<AAPointerInfo::Access>::Access &LHS, - const llvm::DenseMapInfo<AAPointerInfo::Access>::Access &RHS) { - return LHS == RHS; -} -///} - /// A type to track pointer/struct usage and accesses for AAPointerInfo. 
struct AA::PointerInfo::State : public AbstractState { + ~State() { + // We do not delete the Accesses objects but need to destroy them still. + for (auto &It : AccessBins) + It.second->~Accesses(); + } + /// Return the best possible representable state. static State getBestState(const State &SIS) { return State(); } @@ -925,9 +934,10 @@ struct AA::PointerInfo::State : public AbstractState { return R; } - State() {} - State(const State &SIS) : AccessBins(SIS.AccessBins) {} - State(State &&SIS) : AccessBins(std::move(SIS.AccessBins)) {} + State() = default; + State(State &&SIS) : AccessBins(std::move(SIS.AccessBins)) { + SIS.AccessBins.clear(); + } const State &getAssumed() const { return *this; } @@ -976,15 +986,11 @@ struct AA::PointerInfo::State : public AbstractState { return false; auto &Accs = It->getSecond(); auto &RAccs = RIt->getSecond(); - if (Accs.size() != RAccs.size()) + if (Accs->size() != RAccs->size()) return false; - auto AccIt = Accs.begin(), RAccIt = RAccs.begin(), AccE = Accs.end(); - while (AccIt != AccE) { - if (*AccIt != *RAccIt) + for (const auto &ZipIt : llvm::zip(*Accs, *RAccs)) + if (std::get<0>(ZipIt) != std::get<1>(ZipIt)) return false; - ++AccIt; - ++RAccIt; - } ++It; ++RIt; } @@ -993,42 +999,88 @@ struct AA::PointerInfo::State : public AbstractState { bool operator!=(const State &R) const { return !(*this == R); } /// We store accesses in a set with the instruction as key. - using Accesses = DenseSet<AAPointerInfo::Access, AccessAsInstructionInfo>; + struct Accesses { + SmallVector<AAPointerInfo::Access, 4> Accesses; + DenseMap<const Instruction *, unsigned> Map; + + unsigned size() const { return Accesses.size(); } + + using vec_iterator = decltype(Accesses)::iterator; + vec_iterator begin() { return Accesses.begin(); } + vec_iterator end() { return Accesses.end(); } + + using iterator = decltype(Map)::const_iterator; + iterator find(AAPointerInfo::Access &Acc) { + return Map.find(Acc.getRemoteInst()); + } + iterator find_end() { return Map.end(); } + + AAPointerInfo::Access &get(iterator &It) { + return Accesses[It->getSecond()]; + } + + void insert(AAPointerInfo::Access &Acc) { + Map[Acc.getRemoteInst()] = Accesses.size(); + Accesses.push_back(Acc); + } + }; /// We store all accesses in bins denoted by their offset and size. - using AccessBinsTy = DenseMap<OffsetAndSize, Accesses>; + using AccessBinsTy = DenseMap<AAPointerInfo::OffsetAndSize, Accesses *>; AccessBinsTy::const_iterator begin() const { return AccessBins.begin(); } AccessBinsTy::const_iterator end() const { return AccessBins.end(); } protected: /// The bins with all the accesses for the associated pointer. - DenseMap<OffsetAndSize, Accesses> AccessBins; + AccessBinsTy AccessBins; /// Add a new access to the state at offset \p Offset and with size \p Size. /// The access is associated with \p I, writes \p Content (if anything), and /// is of kind \p Kind. /// \Returns CHANGED, if the state changed, UNCHANGED otherwise. - ChangeStatus addAccess(int64_t Offset, int64_t Size, Instruction &I, - Optional<Value *> Content, + ChangeStatus addAccess(Attributor &A, int64_t Offset, int64_t Size, + Instruction &I, Optional<Value *> Content, AAPointerInfo::AccessKind Kind, Type *Ty, Instruction *RemoteI = nullptr, Accesses *BinPtr = nullptr) { - OffsetAndSize Key{Offset, Size}; - Accesses &Bin = BinPtr ? *BinPtr : AccessBins[Key]; + AAPointerInfo::OffsetAndSize Key{Offset, Size}; + Accesses *&Bin = BinPtr ? 
BinPtr : AccessBins[Key]; + if (!Bin) + Bin = new (A.Allocator) Accesses; AAPointerInfo::Access Acc(&I, RemoteI ? RemoteI : &I, Content, Kind, Ty); // Check if we have an access for this instruction in this bin, if not, // simply add it. - auto It = Bin.find(Acc); - if (It == Bin.end()) { - Bin.insert(Acc); + auto It = Bin->find(Acc); + if (It == Bin->find_end()) { + Bin->insert(Acc); return ChangeStatus::CHANGED; } // If the existing access is the same as then new one, nothing changed. - AAPointerInfo::Access Before = *It; + AAPointerInfo::Access &Current = Bin->get(It); + AAPointerInfo::Access Before = Current; // The new one will be combined with the existing one. - *It &= Acc; - return *It == Before ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; + Current &= Acc; + return Current == Before ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; + } + + /// See AAPointerInfo::forallInterferingAccesses. + bool forallInterferingAccesses( + AAPointerInfo::OffsetAndSize OAS, + function_ref<bool(const AAPointerInfo::Access &, bool)> CB) const { + if (!isValidState()) + return false; + + for (auto &It : AccessBins) { + AAPointerInfo::OffsetAndSize ItOAS = It.getFirst(); + if (!OAS.mayOverlap(ItOAS)) + continue; + bool IsExact = OAS == ItOAS && !OAS.offsetOrSizeAreUnknown(); + for (auto &Access : *It.getSecond()) + if (!CB(Access, IsExact)) + return false; + } + return true; } /// See AAPointerInfo::forallInterferingAccesses. @@ -1037,10 +1089,11 @@ protected: function_ref<bool(const AAPointerInfo::Access &, bool)> CB) const { if (!isValidState()) return false; + // First find the offset and size of I. - OffsetAndSize OAS(-1, -1); + AAPointerInfo::OffsetAndSize OAS(-1, -1); for (auto &It : AccessBins) { - for (auto &Access : It.getSecond()) { + for (auto &Access : *It.getSecond()) { if (Access.getRemoteInst() == &I) { OAS = It.getFirst(); break; @@ -1049,21 +1102,13 @@ protected: if (OAS.getSize() != -1) break; } + // No access for I was found, we are done. if (OAS.getSize() == -1) return true; // Now that we have an offset and size, find all overlapping ones and use // the callback on the accesses. 
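// The forallInterferingAccesses helpers above only look at bins whose
// OffsetAndSize keys may overlap. A stand-alone sketch of that interval test,
// mirroring the semantics of the struct this patch moves out of this file
// (the Unknown sentinel keeps the same value):
#include <cstdint>
struct SketchOffsetAndSize {
  static constexpr int64_t Unknown = int64_t(1) << 31;
  int64_t Offset;
  int64_t Size;
  bool offsetOrSizeAreUnknown() const {
    return Offset == Unknown || Size == Unknown;
  }
  bool mayOverlap(const SketchOffsetAndSize &OAS) const {
    // Any unknown offset or size -> conservatively report an overlap.
    if (offsetOrSizeAreUnknown() || OAS.offsetOrSizeAreUnknown())
      return true;
    // Otherwise the byte ranges [Offset, Offset + Size) must intersect.
    return OAS.Offset + OAS.Size > Offset && OAS.Offset < Offset + Size;
  }
};
// For example, {0, 4} and {4, 4} do not overlap, while {0, 8} and {4, 4} do.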
- for (auto &It : AccessBins) { - OffsetAndSize ItOAS = It.getFirst(); - if (!OAS.mayOverlap(ItOAS)) - continue; - bool IsExact = OAS == ItOAS && !OAS.offsetOrSizeAreUnknown(); - for (auto &Access : It.getSecond()) - if (!CB(Access, IsExact)) - return false; - } - return true; + return forallInterferingAccesses(OAS, CB); } private: @@ -1071,6 +1116,7 @@ private: BooleanState BS; }; +namespace { struct AAPointerInfoImpl : public StateWrapper<AA::PointerInfo::State, AAPointerInfo> { using BaseTy = StateWrapper<AA::PointerInfo::State, AAPointerInfo>; @@ -1093,22 +1139,18 @@ struct AAPointerInfoImpl } bool forallInterferingAccesses( - LoadInst &LI, function_ref<bool(const AAPointerInfo::Access &, bool)> CB) + OffsetAndSize OAS, + function_ref<bool(const AAPointerInfo::Access &, bool)> CB) const override { - return State::forallInterferingAccesses(LI, CB); + return State::forallInterferingAccesses(OAS, CB); } bool forallInterferingAccesses( - StoreInst &SI, function_ref<bool(const AAPointerInfo::Access &, bool)> CB) - const override { - return State::forallInterferingAccesses(SI, CB); - } - bool forallInterferingWrites( - Attributor &A, const AbstractAttribute &QueryingAA, LoadInst &LI, + Attributor &A, const AbstractAttribute &QueryingAA, Instruction &I, function_ref<bool(const Access &, bool)> UserCB) const override { SmallPtrSet<const Access *, 8> DominatingWrites; - SmallVector<std::pair<const Access *, bool>, 8> InterferingWrites; + SmallVector<std::pair<const Access *, bool>, 8> InterferingAccesses; - Function &Scope = *LI.getFunction(); + Function &Scope = *I.getFunction(); const auto &NoSyncAA = A.getAAFor<AANoSync>( QueryingAA, IRPosition::function(Scope), DepClassTy::OPTIONAL); const auto *ExecDomainAA = A.lookupAAFor<AAExecutionDomain>( @@ -1136,13 +1178,15 @@ struct AAPointerInfoImpl // TODO: Use inter-procedural reachability and dominance. const auto &NoRecurseAA = A.getAAFor<AANoRecurse>( - QueryingAA, IRPosition::function(*LI.getFunction()), - DepClassTy::OPTIONAL); + QueryingAA, IRPosition::function(Scope), DepClassTy::OPTIONAL); - const bool CanUseCFGResoning = CanIgnoreThreading(LI); + const bool FindInterferingWrites = I.mayReadFromMemory(); + const bool FindInterferingReads = I.mayWriteToMemory(); + const bool UseDominanceReasoning = FindInterferingWrites; + const bool CanUseCFGResoning = CanIgnoreThreading(I); InformationCache &InfoCache = A.getInfoCache(); const DominatorTree *DT = - NoRecurseAA.isKnownNoRecurse() + NoRecurseAA.isKnownNoRecurse() && UseDominanceReasoning ? InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>( Scope) : nullptr; @@ -1198,33 +1242,37 @@ struct AAPointerInfoImpl } auto AccessCB = [&](const Access &Acc, bool Exact) { - if (!Acc.isWrite()) + if ((!FindInterferingWrites || !Acc.isWrite()) && + (!FindInterferingReads || !Acc.isRead())) return true; // For now we only filter accesses based on CFG reasoning which does not // work yet if we have threading effects, or the access is complicated. 
if (CanUseCFGResoning) { - if (!AA::isPotentiallyReachable(A, *Acc.getLocalInst(), LI, QueryingAA, - IsLiveInCalleeCB)) + if ((!Acc.isWrite() || + !AA::isPotentiallyReachable(A, *Acc.getLocalInst(), I, QueryingAA, + IsLiveInCalleeCB)) && + (!Acc.isRead() || + !AA::isPotentiallyReachable(A, I, *Acc.getLocalInst(), QueryingAA, + IsLiveInCalleeCB))) return true; - if (DT && Exact && - (Acc.getLocalInst()->getFunction() == LI.getFunction()) && + if (DT && Exact && (Acc.getLocalInst()->getFunction() == &Scope) && IsSameThreadAsLoad(Acc)) { - if (DT->dominates(Acc.getLocalInst(), &LI)) + if (DT->dominates(Acc.getLocalInst(), &I)) DominatingWrites.insert(&Acc); } } - InterferingWrites.push_back({&Acc, Exact}); + InterferingAccesses.push_back({&Acc, Exact}); return true; }; - if (!State::forallInterferingAccesses(LI, AccessCB)) + if (!State::forallInterferingAccesses(I, AccessCB)) return false; // If we cannot use CFG reasoning we only filter the non-write accesses // and are done here. if (!CanUseCFGResoning) { - for (auto &It : InterferingWrites) + for (auto &It : InterferingAccesses) if (!UserCB(*It.first, It.second)) return false; return true; @@ -1251,11 +1299,11 @@ struct AAPointerInfoImpl return false; }; - // Run the user callback on all writes we cannot skip and return if that + // Run the user callback on all accesses we cannot skip and return if that // succeeded for all or not. - unsigned NumInterferingWrites = InterferingWrites.size(); - for (auto &It : InterferingWrites) { - if (!DT || NumInterferingWrites > MaxInterferingWrites || + unsigned NumInterferingAccesses = InterferingAccesses.size(); + for (auto &It : InterferingAccesses) { + if (!DT || NumInterferingAccesses > MaxInterferingAccesses || !CanSkipAccess(*It.first, It.second)) { if (!UserCB(*It.first, It.second)) return false; @@ -1264,36 +1312,39 @@ struct AAPointerInfoImpl return true; } - ChangeStatus translateAndAddCalleeState(Attributor &A, - const AAPointerInfo &CalleeAA, - int64_t CallArgOffset, CallBase &CB) { + ChangeStatus translateAndAddState(Attributor &A, const AAPointerInfo &OtherAA, + int64_t Offset, CallBase &CB, + bool FromCallee = false) { using namespace AA::PointerInfo; - if (!CalleeAA.getState().isValidState() || !isValidState()) + if (!OtherAA.getState().isValidState() || !isValidState()) return indicatePessimisticFixpoint(); - const auto &CalleeImplAA = static_cast<const AAPointerInfoImpl &>(CalleeAA); - bool IsByval = CalleeImplAA.getAssociatedArgument()->hasByValAttr(); + const auto &OtherAAImpl = static_cast<const AAPointerInfoImpl &>(OtherAA); + bool IsByval = + FromCallee && OtherAAImpl.getAssociatedArgument()->hasByValAttr(); // Combine the accesses bin by bin. 
ChangeStatus Changed = ChangeStatus::UNCHANGED; - for (auto &It : CalleeImplAA.getState()) { + for (auto &It : OtherAAImpl.getState()) { OffsetAndSize OAS = OffsetAndSize::getUnknown(); - if (CallArgOffset != OffsetAndSize::Unknown) - OAS = OffsetAndSize(It.first.getOffset() + CallArgOffset, - It.first.getSize()); - Accesses &Bin = AccessBins[OAS]; - for (const AAPointerInfo::Access &RAcc : It.second) { + if (Offset != OffsetAndSize::Unknown) + OAS = OffsetAndSize(It.first.getOffset() + Offset, It.first.getSize()); + Accesses *Bin = AccessBins.lookup(OAS); + for (const AAPointerInfo::Access &RAcc : *It.second) { if (IsByval && !RAcc.isRead()) continue; bool UsedAssumedInformation = false; - Optional<Value *> Content = A.translateArgumentToCallSiteContent( - RAcc.getContent(), CB, *this, UsedAssumedInformation); - AccessKind AK = - AccessKind(RAcc.getKind() & (IsByval ? AccessKind::AK_READ - : AccessKind::AK_READ_WRITE)); + AccessKind AK = RAcc.getKind(); + Optional<Value *> Content = RAcc.getContent(); + if (FromCallee) { + Content = A.translateArgumentToCallSiteContent( + RAcc.getContent(), CB, *this, UsedAssumedInformation); + AK = AccessKind( + AK & (IsByval ? AccessKind::AK_READ : AccessKind::AK_READ_WRITE)); + } Changed = - Changed | addAccess(OAS.getOffset(), OAS.getSize(), CB, Content, AK, - RAcc.getType(), RAcc.getRemoteInst(), &Bin); + Changed | addAccess(A, OAS.getOffset(), OAS.getSize(), CB, Content, + AK, RAcc.getType(), RAcc.getRemoteInst(), Bin); } } return Changed; @@ -1316,7 +1367,7 @@ struct AAPointerInfoFloating : public AAPointerInfoImpl { bool handleAccess(Attributor &A, Instruction &I, Value &Ptr, Optional<Value *> Content, AccessKind Kind, int64_t Offset, ChangeStatus &Changed, Type *Ty, - int64_t Size = AA::PointerInfo::OffsetAndSize::Unknown) { + int64_t Size = OffsetAndSize::Unknown) { using namespace AA::PointerInfo; // No need to find a size if one is given or the offset is unknown. if (Offset != OffsetAndSize::Unknown && Size == OffsetAndSize::Unknown && @@ -1326,13 +1377,13 @@ struct AAPointerInfoFloating : public AAPointerInfoImpl { if (!AccessSize.isScalable()) Size = AccessSize.getFixedSize(); } - Changed = Changed | addAccess(Offset, Size, I, Content, Kind, Ty); + Changed = Changed | addAccess(A, Offset, Size, I, Content, Kind, Ty); return true; }; /// Helper struct, will support ranges eventually. struct OffsetInfo { - int64_t Offset = AA::PointerInfo::OffsetAndSize::Unknown; + int64_t Offset = OffsetAndSize::Unknown; bool operator==(const OffsetInfo &OI) const { return Offset == OI.Offset; } }; @@ -1340,7 +1391,6 @@ struct AAPointerInfoFloating : public AAPointerInfoImpl { /// See AbstractAttribute::updateImpl(...). 
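// translateAndAddState above rebases every access bin of the other AA by the
// constant offset of the call-site operand before merging it into this state.
// A reduced sketch of just the key translation; the pair-of-ints stands in for
// OffsetAndSize and the sentinel matches the patch:
#include <cstdint>
#include <utility>
static constexpr int64_t SketchUnknownOffset = int64_t(1) << 31;
using SketchBinKey = std::pair<int64_t, int64_t>; // {Offset, Size}
static SketchBinKey sketchRebaseBinKey(SketchBinKey CalleeKey,
                                       int64_t CallSiteOffset) {
  // An unknown call-site offset collapses the access into the unknown bin.
  if (CallSiteOffset == SketchUnknownOffset)
    return {SketchUnknownOffset, SketchUnknownOffset};
  // Otherwise shift the callee-relative offset into the caller's frame; the
  // size of the access is unchanged.
  return {CalleeKey.first + CallSiteOffset, CalleeKey.second};
}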
ChangeStatus updateImpl(Attributor &A) override { using namespace AA::PointerInfo; - State S = getState(); ChangeStatus Changed = ChangeStatus::UNCHANGED; Value &AssociatedValue = getAssociatedValue(); @@ -1348,7 +1398,7 @@ struct AAPointerInfoFloating : public AAPointerInfoImpl { DenseMap<Value *, OffsetInfo> OffsetInfoMap; OffsetInfoMap[&AssociatedValue] = OffsetInfo{0}; - auto HandlePassthroughUser = [&](Value *Usr, OffsetInfo &PtrOI, + auto HandlePassthroughUser = [&](Value *Usr, OffsetInfo PtrOI, bool &Follow) { OffsetInfo &UsrOI = OffsetInfoMap[Usr]; UsrOI = PtrOI; @@ -1486,8 +1536,8 @@ struct AAPointerInfoFloating : public AAPointerInfoImpl { const auto &CSArgPI = A.getAAFor<AAPointerInfo>( *this, IRPosition::callsite_argument(*CB, ArgNo), DepClassTy::REQUIRED); - Changed = translateAndAddCalleeState( - A, CSArgPI, OffsetInfoMap[CurPtr].Offset, *CB) | + Changed = translateAndAddState(A, CSArgPI, + OffsetInfoMap[CurPtr].Offset, *CB) | Changed; return true; } @@ -1508,7 +1558,7 @@ struct AAPointerInfoFloating : public AAPointerInfoImpl { }; if (!A.checkForAllUses(UsePred, *this, AssociatedValue, /* CheckBBLivenessOnly */ true, DepClassTy::OPTIONAL, - EquivalentUseCB)) + /* IgnoreDroppableUses */ true, EquivalentUseCB)) return indicatePessimisticFixpoint(); LLVM_DEBUG({ @@ -1516,15 +1566,19 @@ struct AAPointerInfoFloating : public AAPointerInfoImpl { for (auto &It : AccessBins) { dbgs() << "[" << It.first.getOffset() << "-" << It.first.getOffset() + It.first.getSize() - << "] : " << It.getSecond().size() << "\n"; - for (auto &Acc : It.getSecond()) { + << "] : " << It.getSecond()->size() << "\n"; + for (auto &Acc : *It.getSecond()) { dbgs() << " - " << Acc.getKind() << " - " << *Acc.getLocalInst() << "\n"; if (Acc.getLocalInst() != Acc.getRemoteInst()) dbgs() << " --> " << *Acc.getRemoteInst() << "\n"; - if (!Acc.isWrittenValueYetUndetermined()) - dbgs() << " - " << Acc.getWrittenValue() << "\n"; + if (!Acc.isWrittenValueYetUndetermined()) { + if (Acc.getWrittenValue()) + dbgs() << " - c: " << *Acc.getWrittenValue() << "\n"; + else + dbgs() << " - c: <unknown>\n"; + } } } }); @@ -1587,7 +1641,7 @@ struct AAPointerInfoCallSiteArgument final : AAPointerInfoFloating { LengthVal = Length->getSExtValue(); Value &Ptr = getAssociatedValue(); unsigned ArgNo = getIRPosition().getCallSiteArgNo(); - ChangeStatus Changed; + ChangeStatus Changed = ChangeStatus::UNCHANGED; if (ArgNo == 0) { handleAccess(A, *MI, Ptr, nullptr, AccessKind::AK_WRITE, 0, Changed, nullptr, LengthVal); @@ -1612,7 +1666,8 @@ struct AAPointerInfoCallSiteArgument final : AAPointerInfoFloating { const IRPosition &ArgPos = IRPosition::argument(*Arg); auto &ArgAA = A.getAAFor<AAPointerInfo>(*this, ArgPos, DepClassTy::REQUIRED); - return translateAndAddCalleeState(A, ArgAA, 0, *cast<CallBase>(getCtxI())); + return translateAndAddState(A, ArgAA, 0, *cast<CallBase>(getCtxI()), + /* FromCallee */ true); } /// See AbstractAttribute::trackStatistics() @@ -1630,9 +1685,11 @@ struct AAPointerInfoCallSiteReturned final : AAPointerInfoFloating { AAPointerInfoImpl::trackPointerInfoStatistics(getIRPosition()); } }; +} // namespace /// -----------------------NoUnwind Function Attribute-------------------------- +namespace { struct AANoUnwindImpl : AANoUnwind { AANoUnwindImpl(const IRPosition &IRP, Attributor &A) : AANoUnwind(IRP, A) {} @@ -1704,9 +1761,11 @@ struct AANoUnwindCallSite final : AANoUnwindImpl { /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nounwind); } }; +} // 
namespace /// --------------------- Function Return Values ------------------------------- +namespace { /// "Attribute" that collects all potential returned values and the return /// instructions that they arise from. /// @@ -1832,7 +1891,7 @@ ChangeStatus AAReturnedValuesImpl::manifest(Attributor &A) { // Check if we have an assumed unique return value that we could manifest. Optional<Value *> UniqueRV = getAssumedUniqueReturnValue(A); - if (!UniqueRV.hasValue() || !UniqueRV.getValue()) + if (!UniqueRV || !UniqueRV.getValue()) return Changed; // Bookkeeping. @@ -1911,7 +1970,7 @@ ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) { A, IRPosition::value(*Ret.getReturnValue()), *this, Ret, ReturnValueCB, &I, UsedAssumedInformation, /* UseValueSimplify */ true, /* MaxValues */ 16, - /* StripCB */ nullptr, /* Intraprocedural */ true); + /* StripCB */ nullptr, AA::Intraprocedural); }; // Discover returned values from all live returned instructions in the @@ -1953,20 +2012,10 @@ struct AAReturnedValuesCallSite final : AAReturnedValuesImpl { /// See AbstractAttribute::trackStatistics() void trackStatistics() const override {} }; +} // namespace /// ------------------------ NoSync Function Attribute ------------------------- -struct AANoSyncImpl : AANoSync { - AANoSyncImpl(const IRPosition &IRP, Attributor &A) : AANoSync(IRP, A) {} - - const std::string getAsStr() const override { - return getAssumed() ? "nosync" : "may-sync"; - } - - /// See AbstractAttribute::updateImpl(...). - ChangeStatus updateImpl(Attributor &A) override; -}; - bool AANoSync::isNonRelaxedAtomic(const Instruction *I) { if (!I->isAtomic()) return false; @@ -2009,6 +2058,18 @@ bool AANoSync::isNoSyncIntrinsic(const Instruction *I) { return false; } +namespace { +struct AANoSyncImpl : AANoSync { + AANoSyncImpl(const IRPosition &IRP, Attributor &A) : AANoSync(IRP, A) {} + + const std::string getAsStr() const override { + return getAssumed() ? "nosync" : "may-sync"; + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override; +}; + ChangeStatus AANoSyncImpl::updateImpl(Attributor &A) { auto CheckRWInstForNoSync = [&](Instruction &I) { @@ -2071,9 +2132,11 @@ struct AANoSyncCallSite final : AANoSyncImpl { /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(nosync); } }; +} // namespace /// ------------------------ No-Free Attributes ---------------------------- +namespace { struct AANoFreeImpl : public AANoFree { AANoFreeImpl(const IRPosition &IRP, Attributor &A) : AANoFree(IRP, A) {} @@ -2255,8 +2318,10 @@ struct AANoFreeCallSiteReturned final : AANoFreeFloating { /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(nofree) } }; +} // namespace /// ------------------------ NonNull Argument Attribute ------------------------ +namespace { static int64_t getKnownNonNullAndDerefBytesForUse( Attributor &A, const AbstractAttribute &QueryingAA, Value &AssociatedValue, const Use *U, const Instruction *I, bool &IsNonNull, bool &TrackUse) { @@ -2344,7 +2409,7 @@ struct AANonNullImpl : AANonNull { /// See AbstractAttribute::initialize(...). 
void initialize(Attributor &A) override { - Value &V = getAssociatedValue(); + Value &V = *getAssociatedValue().stripPointerCasts(); if (!NullIsDefined && hasAttr({Attribute::NonNull, Attribute::Dereferenceable}, /* IgnoreSubsumingPositions */ false, &A)) { @@ -2368,7 +2433,7 @@ struct AANonNullImpl : AANonNull { } } - if (isa<GlobalValue>(&getAssociatedValue())) { + if (isa<GlobalValue>(V)) { indicatePessimisticFixpoint(); return; } @@ -2486,9 +2551,11 @@ struct AANonNullCallSiteReturned final /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(nonnull) } }; +} // namespace /// ------------------------ No-Recurse Attributes ---------------------------- +namespace { struct AANoRecurseImpl : public AANoRecurse { AANoRecurseImpl(const IRPosition &IRP, Attributor &A) : AANoRecurse(IRP, A) {} @@ -2564,9 +2631,11 @@ struct AANoRecurseCallSite final : AANoRecurseImpl { /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(norecurse); } }; +} // namespace /// -------------------- Undefined-Behavior Attributes ------------------------ +namespace { struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { AAUndefinedBehaviorImpl(const IRPosition &IRP, Attributor &A) : AAUndefinedBehavior(IRP, A) {} @@ -2597,7 +2666,7 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { // Either we stopped and the appropriate action was taken, // or we got back a simplified value to continue. Optional<Value *> SimplifiedPtrOp = stopOnUndefOrAssumed(A, PtrOp, &I); - if (!SimplifiedPtrOp.hasValue() || !SimplifiedPtrOp.getValue()) + if (!SimplifiedPtrOp || !SimplifiedPtrOp.getValue()) return true; const Value *PtrOpVal = SimplifiedPtrOp.getValue(); @@ -2642,7 +2711,7 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { // or we got back a simplified value to continue. Optional<Value *> SimplifiedCond = stopOnUndefOrAssumed(A, BrInst->getCondition(), BrInst); - if (!SimplifiedCond.hasValue() || !SimplifiedCond.getValue()) + if (!SimplifiedCond || !*SimplifiedCond) return true; AssumedNoUBInsts.insert(&I); return true; @@ -2688,10 +2757,9 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { IRPosition::value(*ArgVal), *this, UsedAssumedInformation); if (UsedAssumedInformation) continue; - if (SimplifiedVal.hasValue() && !SimplifiedVal.getValue()) + if (SimplifiedVal && !SimplifiedVal.getValue()) return true; - if (!SimplifiedVal.hasValue() || - isa<UndefValue>(*SimplifiedVal.getValue())) { + if (!SimplifiedVal || isa<UndefValue>(*SimplifiedVal.getValue())) { KnownUBInsts.insert(&I); continue; } @@ -2712,7 +2780,7 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { // or we got back a simplified return value to continue. 
Optional<Value *> SimplifiedRetValue = stopOnUndefOrAssumed(A, RI.getReturnValue(), &I); - if (!SimplifiedRetValue.hasValue() || !SimplifiedRetValue.getValue()) + if (!SimplifiedRetValue || !*SimplifiedRetValue) return true; // Check if a return instruction always cause UB or not @@ -2790,7 +2858,7 @@ struct AAUndefinedBehaviorImpl : public AAUndefinedBehavior { case Instruction::AtomicRMW: return !AssumedNoUBInsts.count(I); case Instruction::Br: { - auto BrInst = cast<BranchInst>(I); + auto *BrInst = cast<BranchInst>(I); if (BrInst->isUnconditional()) return false; return !AssumedNoUBInsts.count(I); @@ -2861,13 +2929,13 @@ private: IRPosition::value(*V), *this, UsedAssumedInformation); if (!UsedAssumedInformation) { // Don't depend on assumed values. - if (!SimplifiedV.hasValue()) { + if (!SimplifiedV) { // If it is known (which we tested above) but it doesn't have a value, // then we can assume `undef` and hence the instruction is UB. KnownUBInsts.insert(I); return llvm::None; } - if (!SimplifiedV.getValue()) + if (!*SimplifiedV) return nullptr; V = *SimplifiedV; } @@ -2891,9 +2959,11 @@ struct AAUndefinedBehaviorFunction final : AAUndefinedBehaviorImpl { KnownUBInsts.size(); } }; +} // namespace /// ------------------------ Will-Return Attributes ---------------------------- +namespace { // Helper function that checks whether a function has any cycle which we don't // know if it is bounded or not. // Loops with maximum trip count are considered bounded, any other cycle not. @@ -3032,9 +3102,11 @@ struct AAWillReturnCallSite final : AAWillReturnImpl { /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(willreturn); } }; +} // namespace /// -------------------AAReachability Attribute-------------------------- +namespace { struct AAReachabilityImpl : AAReachability { AAReachabilityImpl(const IRPosition &IRP, Attributor &A) : AAReachability(IRP, A) {} @@ -3046,10 +3118,6 @@ struct AAReachabilityImpl : AAReachability { /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { - const auto &NoRecurseAA = A.getAAFor<AANoRecurse>( - *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); - if (!NoRecurseAA.isAssumedNoRecurse()) - return indicatePessimisticFixpoint(); return ChangeStatus::UNCHANGED; } }; @@ -3061,9 +3129,11 @@ struct AAReachabilityFunction final : public AAReachabilityImpl { /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_FN_ATTR(reachable); } }; +} // namespace /// ------------------------ NoAlias Argument Attribute ------------------------ +namespace { struct AANoAliasImpl : AANoAlias { AANoAliasImpl(const IRPosition &IRP, Attributor &A) : AANoAlias(IRP, A) { assert(getAssociatedType()->isPointerTy() && @@ -3260,14 +3330,20 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { return false; } + auto IsDereferenceableOrNull = [&](Value *O, const DataLayout &DL) { + const auto &DerefAA = A.getAAFor<AADereferenceable>( + *this, IRPosition::value(*O), DepClassTy::OPTIONAL); + return DerefAA.getAssumedDereferenceableBytes(); + }; + A.recordDependence(NoAliasAA, *this, DepClassTy::OPTIONAL); const IRPosition &VIRP = IRPosition::value(getAssociatedValue()); const Function *ScopeFn = VIRP.getAnchorScope(); auto &NoCaptureAA = A.getAAFor<AANoCapture>(*this, VIRP, DepClassTy::NONE); // Check whether the value is captured in the scope using AANoCapture. 
- // Look at CFG and check only uses possibly executed before this - // callsite. + // Look at CFG and check only uses possibly executed before this + // callsite. auto UsePred = [&](const Use &U, bool &Follow) -> bool { Instruction *UserI = cast<Instruction>(U.getUser()); @@ -3279,12 +3355,6 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { return true; if (ScopeFn) { - const auto &ReachabilityAA = A.getAAFor<AAReachability>( - *this, IRPosition::function(*ScopeFn), DepClassTy::OPTIONAL); - - if (!ReachabilityAA.isAssumedReachable(A, *UserI, *getCtxI())) - return true; - if (auto *CB = dyn_cast<CallBase>(UserI)) { if (CB->isArgOperand(&U)) { @@ -3298,17 +3368,26 @@ struct AANoAliasCallSiteArgument final : AANoAliasImpl { return true; } } + + if (!AA::isPotentiallyReachable(A, *UserI, *getCtxI(), *this)) + return true; } - // For cases which can potentially have more users - if (isa<GetElementPtrInst>(U) || isa<BitCastInst>(U) || isa<PHINode>(U) || - isa<SelectInst>(U)) { + // TODO: We should track the capturing uses in AANoCapture but the problem + // is CGSCC runs. For those we would need to "allow" AANoCapture for + // a value in the module slice. + switch (DetermineUseCaptureKind(U, IsDereferenceableOrNull)) { + case UseCaptureKind::NO_CAPTURE: + return true; + case UseCaptureKind::MAY_CAPTURE: + LLVM_DEBUG(dbgs() << "[AANoAliasCSArg] Unknown user: " << *UserI + << "\n"); + return false; + case UseCaptureKind::PASSTHROUGH: Follow = true; return true; } - - LLVM_DEBUG(dbgs() << "[AANoAliasCSArg] Unknown user: " << *U << "\n"); - return false; + llvm_unreachable("unknown UseCaptureKind"); }; if (!NoCaptureAA.isAssumedNoCaptureMaybeReturned()) { @@ -3437,12 +3516,21 @@ struct AANoAliasCallSiteReturned final : AANoAliasImpl { /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CSRET_ATTR(noalias); } }; +} // namespace /// -------------------AAIsDead Function Attribute----------------------- +namespace { struct AAIsDeadValueImpl : public AAIsDead { AAIsDeadValueImpl(const IRPosition &IRP, Attributor &A) : AAIsDead(IRP, A) {} + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + if (auto *Scope = getAnchorScope()) + if (!A.isRunOn(*Scope)) + indicatePessimisticFixpoint(); + } + /// See AAIsDead::isAssumedDead(). bool isAssumedDead() const override { return isAssumed(IS_DEAD); } @@ -3466,22 +3554,25 @@ struct AAIsDeadValueImpl : public AAIsDead { } /// See AbstractAttribute::getAsStr(). - const std::string getAsStr() const override { + virtual const std::string getAsStr() const override { return isAssumedDead() ? "assumed-dead" : "assumed-live"; } /// Check if all uses are assumed dead. bool areAllUsesAssumedDead(Attributor &A, Value &V) { // Callers might not check the type, void has no uses. - if (V.getType()->isVoidTy()) + if (V.getType()->isVoidTy() || V.use_empty()) return true; // If we replace a value with a constant there are no uses left afterwards. if (!isa<Constant>(V)) { + if (auto *I = dyn_cast<Instruction>(&V)) + if (!A.isRunOn(*I->getFunction())) + return false; bool UsedAssumedInformation = false; Optional<Constant *> C = A.getAssumedConstant(V, *this, UsedAssumedInformation); - if (!C.hasValue() || *C) + if (!C || *C) return true; } @@ -3491,7 +3582,8 @@ struct AAIsDeadValueImpl : public AAIsDead { // without going through N update cycles. This is not required for // correctness. 
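// The NO_CAPTURE / MAY_CAPTURE / PASSTHROUGH switch above comes from
// DetermineUseCaptureKind in CaptureTracking.h. A minimal use-walk in the same
// spirit (illustrative only; the real AANoCapture/AANoAlias logic layers
// reachability and call-site reasoning on top of this):
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Value.h"
static bool sketchMayBeCaptured(llvm::Value &V) {
  using namespace llvm;
  // Conservatively claim nothing is known dereferenceable-or-null.
  auto NoDerefInfo = [](Value *, const DataLayout &) { return false; };
  SmallVector<const Use *, 16> Worklist;
  for (const Use &U : V.uses())
    Worklist.push_back(&U);
  while (!Worklist.empty()) {
    const Use *U = Worklist.pop_back_val();
    switch (DetermineUseCaptureKind(*U, NoDerefInfo)) {
    case UseCaptureKind::NO_CAPTURE:
      continue;
    case UseCaptureKind::MAY_CAPTURE:
      return true;
    case UseCaptureKind::PASSTHROUGH:
      // The user yields an equivalent pointer (GEP, bitcast, phi, select);
      // follow its uses instead.
      for (const Use &UU : U->getUser()->uses())
        Worklist.push_back(&UU);
      continue;
    }
  }
  return false;
}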
return A.checkForAllUses(UsePred, *this, V, /* CheckBBLivenessOnly */ false, - DepClassTy::REQUIRED); + DepClassTy::REQUIRED, + /* IgnoreDroppableUses */ false); } /// Determine if \p I is assumed to be side-effect free. @@ -3522,6 +3614,8 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { + AAIsDeadValueImpl::initialize(A); + if (isa<UndefValue>(getAssociatedValue())) { indicatePessimisticFixpoint(); return; @@ -3552,6 +3646,15 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl { }); } + /// See AbstractAttribute::getAsStr(). + const std::string getAsStr() const override { + Instruction *I = dyn_cast<Instruction>(&getAssociatedValue()); + if (isa_and_nonnull<StoreInst>(I)) + if (isValidState()) + return "assumed-dead-store"; + return AAIsDeadValueImpl::getAsStr(); + } + /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { Instruction *I = dyn_cast<Instruction>(&getAssociatedValue()); @@ -3567,6 +3670,10 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl { return ChangeStatus::UNCHANGED; } + bool isRemovableStore() const override { + return isAssumed(IS_REMOVABLE) && isa<StoreInst>(&getAssociatedValue()); + } + /// See AbstractAttribute::manifest(...). ChangeStatus manifest(Attributor &A) override { Value &V = getAssociatedValue(); @@ -3581,21 +3688,7 @@ struct AAIsDeadFloating : public AAIsDeadValueImpl { return ChangeStatus::CHANGED; } } - if (V.use_empty()) - return ChangeStatus::UNCHANGED; - - bool UsedAssumedInformation = false; - Optional<Constant *> C = - A.getAssumedConstant(V, *this, UsedAssumedInformation); - if (C.hasValue() && C.getValue()) - return ChangeStatus::UNCHANGED; - - // Replace the value with undef as it is dead but keep droppable uses around - // as they provide information we don't want to give up on just yet. - UndefValue &UV = *UndefValue::get(V.getType()); - bool AnyChange = - A.changeValueAfterManifest(V, UV, /* ChangeDropppable */ false); - return AnyChange ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED; + return ChangeStatus::UNCHANGED; } /// See AbstractAttribute::trackStatistics() @@ -3610,23 +3703,22 @@ struct AAIsDeadArgument : public AAIsDeadFloating { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { + AAIsDeadFloating::initialize(A); if (!A.isFunctionIPOAmendable(*getAnchorScope())) indicatePessimisticFixpoint(); } /// See AbstractAttribute::manifest(...). ChangeStatus manifest(Attributor &A) override { - ChangeStatus Changed = AAIsDeadFloating::manifest(A); Argument &Arg = *getAssociatedArgument(); if (A.isValidFunctionSignatureRewrite(Arg, /* ReplacementTypes */ {})) if (A.registerFunctionSignatureRewrite( Arg, /* ReplacementTypes */ {}, Attributor::ArgumentReplacementInfo::CalleeRepairCBTy{}, Attributor::ArgumentReplacementInfo::ACSRepairCBTy{})) { - Arg.dropDroppableUses(); return ChangeStatus::CHANGED; } - return Changed; + return ChangeStatus::UNCHANGED; } /// See AbstractAttribute::trackStatistics() @@ -3639,6 +3731,7 @@ struct AAIsDeadCallSiteArgument : public AAIsDeadValueImpl { /// See AbstractAttribute::initialize(...). 
void initialize(Attributor &A) override { + AAIsDeadValueImpl::initialize(A); if (isa<UndefValue>(getAssociatedValue())) indicatePessimisticFixpoint(); } @@ -3675,7 +3768,7 @@ struct AAIsDeadCallSiteArgument : public AAIsDeadValueImpl { struct AAIsDeadCallSiteReturned : public AAIsDeadFloating { AAIsDeadCallSiteReturned(const IRPosition &IRP, Attributor &A) - : AAIsDeadFloating(IRP, A), IsAssumedSideEffectFree(true) {} + : AAIsDeadFloating(IRP, A) {} /// See AAIsDead::isAssumedDead(). bool isAssumedDead() const override { @@ -3684,6 +3777,7 @@ struct AAIsDeadCallSiteReturned : public AAIsDeadFloating { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { + AAIsDeadFloating::initialize(A); if (isa<UndefValue>(getAssociatedValue())) { indicatePessimisticFixpoint(); return; @@ -3721,7 +3815,7 @@ struct AAIsDeadCallSiteReturned : public AAIsDeadFloating { } private: - bool IsAssumedSideEffectFree; + bool IsAssumedSideEffectFree = true; }; struct AAIsDeadReturned : public AAIsDeadValueImpl { @@ -3774,17 +3868,13 @@ struct AAIsDeadFunction : public AAIsDead { /// See AbstractAttribute::initialize(...). void initialize(Attributor &A) override { - const Function *F = getAnchorScope(); - if (F && !F->isDeclaration()) { - // We only want to compute liveness once. If the function is not part of - // the SCC, skip it. - if (A.isRunOn(*const_cast<Function *>(F))) { - ToBeExploredFrom.insert(&F->getEntryBlock().front()); - assumeLive(A, F->getEntryBlock()); - } else { - indicatePessimisticFixpoint(); - } + Function *F = getAnchorScope(); + if (!F || F->isDeclaration() || !A.isRunOn(*F)) { + indicatePessimisticFixpoint(); + return; } + ToBeExploredFrom.insert(&F->getEntryBlock().front()); + assumeLive(A, F->getEntryBlock()); } /// See AbstractAttribute::getAsStr(). @@ -3989,7 +4079,7 @@ identifyAliveSuccessors(Attributor &A, const BranchInst &BI, } else { Optional<Constant *> C = A.getAssumedConstant(*BI.getCondition(), AA, UsedAssumedInformation); - if (!C.hasValue() || isa_and_nonnull<UndefValue>(C.getValue())) { + if (!C || isa_and_nonnull<UndefValue>(*C)) { // No value yet, assume both edges are dead. } else if (isa_and_nonnull<ConstantInt>(*C)) { const BasicBlock *SuccBB = @@ -4011,7 +4101,7 @@ identifyAliveSuccessors(Attributor &A, const SwitchInst &SI, bool UsedAssumedInformation = false; Optional<Constant *> C = A.getAssumedConstant(*SI.getCondition(), AA, UsedAssumedInformation); - if (!C.hasValue() || isa_and_nonnull<UndefValue>(C.getValue())) { + if (!C || isa_and_nonnull<UndefValue>(C.getValue())) { // No value yet, assume all edges are dead. } else if (isa_and_nonnull<ConstantInt>(C.getValue())) { for (auto &CaseIt : SI.cases()) { @@ -4158,9 +4248,11 @@ struct AAIsDeadCallSite final : AAIsDeadFunction { /// See AbstractAttribute::trackStatistics() void trackStatistics() const override {} }; +} // namespace /// -------------------- Dereferenceable Argument Attribute -------------------- +namespace { struct AADereferenceableImpl : AADereferenceable { AADereferenceableImpl(const IRPosition &IRP, Attributor &A) : AADereferenceable(IRP, A) {} @@ -4168,6 +4260,7 @@ struct AADereferenceableImpl : AADereferenceable { /// See AbstractAttribute::initialize(...). 
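// identifyAliveSuccessors above prunes control-flow edges using the constant
// the Attributor assumes for the branch or switch condition. A reduced sketch
// of the conditional-branch case (function name is illustrative):
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
static void sketchAliveSuccessors(
    const llvm::BranchInst &BI, llvm::Optional<llvm::Constant *> C,
    llvm::SmallVectorImpl<const llvm::Instruction *> &AliveSuccessors) {
  using namespace llvm;
  if (BI.isUnconditional()) {
    AliveSuccessors.push_back(&BI.getSuccessor(0)->front());
    return;
  }
  if (!C || isa_and_nonnull<UndefValue>(*C)) {
    // No (or an undef) value yet: optimistically assume both edges are dead.
  } else if (auto *CI = dyn_cast_or_null<ConstantInt>(*C)) {
    // A known condition: only the taken successor stays alive.
    AliveSuccessors.push_back(&BI.getSuccessor(1 - CI->getZExtValue())->front());
  } else {
    // Anything else: both successors have to be treated as reachable.
    AliveSuccessors.push_back(&BI.getSuccessor(0)->front());
    AliveSuccessors.push_back(&BI.getSuccessor(1)->front());
  }
}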
void initialize(Attributor &A) override { + Value &V = *getAssociatedValue().stripPointerCasts(); SmallVector<Attribute, 4> Attrs; getAttrs({Attribute::Dereferenceable, Attribute::DereferenceableOrNull}, Attrs, /* IgnoreSubsumingPositions */ false, &A); @@ -4178,9 +4271,8 @@ struct AADereferenceableImpl : AADereferenceable { NonNullAA = &A.getAAFor<AANonNull>(*this, IRP, DepClassTy::NONE); bool CanBeNull, CanBeFreed; - takeKnownDerefBytesMaximum( - IRP.getAssociatedValue().getPointerDereferenceableBytes( - A.getDataLayout(), CanBeNull, CanBeFreed)); + takeKnownDerefBytesMaximum(V.getPointerDereferenceableBytes( + A.getDataLayout(), CanBeNull, CanBeFreed)); bool IsFnInterface = IRP.isFnInterfaceKind(); Function *FnScope = IRP.getAnchorScope(); @@ -4279,8 +4371,9 @@ struct AADereferenceableFloating : AADereferenceableImpl { unsigned IdxWidth = DL.getIndexSizeInBits(V.getType()->getPointerAddressSpace()); APInt Offset(IdxWidth, 0); - const Value *Base = - stripAndAccumulateMinimalOffsets(A, *this, &V, DL, Offset, false); + const Value *Base = stripAndAccumulateOffsets( + A, *this, &V, DL, Offset, /* GetMinOffset */ false, + /* AllowNonInbounds */ true); const auto &AA = A.getAAFor<AADereferenceable>( *this, IRPosition::value(*Base), DepClassTy::REQUIRED); @@ -4395,9 +4488,11 @@ struct AADereferenceableCallSiteReturned final STATS_DECLTRACK_CS_ATTR(dereferenceable); } }; +} // namespace // ------------------------ Align Argument Attribute ------------------------ +namespace { static unsigned getKnownAlignForUse(Attributor &A, AAAlign &QueryingAA, Value &AssociatedValue, const Use *U, const Instruction *I, bool &TrackUse) { @@ -4468,14 +4563,8 @@ struct AAAlignImpl : AAAlign { for (const Attribute &Attr : Attrs) takeKnownMaximum(Attr.getValueAsInt()); - Value &V = getAssociatedValue(); - // TODO: This is a HACK to avoid getPointerAlignment to introduce a ptr2int - // use of the function pointer. This was caused by D73131. We want to - // avoid this for function pointers especially because we iterate - // their uses and int2ptr is not handled. It is not a correctness - // problem though! - if (!V.getType()->getPointerElementType()->isFunctionTy()) - takeKnownMaximum(V.getPointerAlignment(A.getDataLayout()).value()); + Value &V = *getAssociatedValue().stripPointerCasts(); + takeKnownMaximum(V.getPointerAlignment(A.getDataLayout()).value()); if (getIRPosition().isFnInterfaceKind() && (!getAnchorScope() || @@ -4497,16 +4586,16 @@ struct AAAlignImpl : AAAlign { for (const Use &U : AssociatedValue.uses()) { if (auto *SI = dyn_cast<StoreInst>(U.getUser())) { if (SI->getPointerOperand() == &AssociatedValue) - if (SI->getAlignment() < getAssumedAlign()) { + if (SI->getAlign() < getAssumedAlign()) { STATS_DECLTRACK(AAAlign, Store, "Number of times alignment added to a store"); - SI->setAlignment(Align(getAssumedAlign())); + SI->setAlignment(getAssumedAlign()); LoadStoreChanged = ChangeStatus::CHANGED; } } else if (auto *LI = dyn_cast<LoadInst>(U.getUser())) { if (LI->getPointerOperand() == &AssociatedValue) - if (LI->getAlignment() < getAssumedAlign()) { - LI->setAlignment(Align(getAssumedAlign())); + if (LI->getAlign() < getAssumedAlign()) { + LI->setAlignment(getAssumedAlign()); STATS_DECLTRACK(AAAlign, Load, "Number of times alignment added to a load"); LoadStoreChanged = ChangeStatus::CHANGED; @@ -4550,9 +4639,8 @@ struct AAAlignImpl : AAAlign { /// See AbstractAttribute::getAsStr(). const std::string getAsStr() const override { - return getAssumedAlign() ? 
("align<" + std::to_string(getKnownAlign()) + - "-" + std::to_string(getAssumedAlign()) + ">") - : "unknown-align"; + return "align<" + std::to_string(getKnownAlign().value()) + "-" + + std::to_string(getAssumedAlign().value()) + ">"; } }; @@ -4566,6 +4654,8 @@ struct AAAlignFloating : AAAlignImpl { auto VisitValueCB = [&](Value &V, const Instruction *, AAAlign::StateType &T, bool Stripped) -> bool { + if (isa<UndefValue>(V) || isa<ConstantPointerNull>(V)) + return true; const auto &AA = A.getAAFor<AAAlign>(*this, IRPosition::value(V), DepClassTy::REQUIRED); if (!Stripped && this == &AA) { @@ -4573,6 +4663,7 @@ struct AAAlignFloating : AAAlignImpl { unsigned Alignment = 1; if (const Value *Base = GetPointerBaseWithConstantOffset(&V, Offset, DL)) { + // TODO: Use AAAlign for the base too. Align PA = Base->getPointerAlignment(DL); // BasePointerAddr + Offset = Alignment * Q for some integer Q. // So we can say that the maximum power of two which is a divisor of @@ -4677,7 +4768,7 @@ struct AAAlignCallSiteArgument final : AAAlignFloating { // so we do not need to track a dependence. const auto &ArgAlignAA = A.getAAFor<AAAlign>( *this, IRPosition::argument(*Arg), DepClassTy::NONE); - takeKnownMaximum(ArgAlignAA.getKnownAlign()); + takeKnownMaximum(ArgAlignAA.getKnownAlign().value()); } return Changed; } @@ -4704,8 +4795,10 @@ struct AAAlignCallSiteReturned final /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(align); } }; +} // namespace /// ------------------ Function No-Return Attribute ---------------------------- +namespace { struct AANoReturnImpl : public AANoReturn { AANoReturnImpl(const IRPosition &IRP, Attributor &A) : AANoReturn(IRP, A) {} @@ -4773,9 +4866,179 @@ struct AANoReturnCallSite final : AANoReturnImpl { /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_CS_ATTR(noreturn); } }; +} // namespace + +/// ----------------------- Instance Info --------------------------------- + +namespace { +/// A class to hold the state of for no-capture attributes. +struct AAInstanceInfoImpl : public AAInstanceInfo { + AAInstanceInfoImpl(const IRPosition &IRP, Attributor &A) + : AAInstanceInfo(IRP, A) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + Value &V = getAssociatedValue(); + if (auto *C = dyn_cast<Constant>(&V)) { + if (C->isThreadDependent()) + indicatePessimisticFixpoint(); + else + indicateOptimisticFixpoint(); + return; + } + if (auto *CB = dyn_cast<CallBase>(&V)) + if (CB->arg_size() == 0 && !CB->mayHaveSideEffects() && + !CB->mayReadFromMemory()) { + indicateOptimisticFixpoint(); + return; + } + } + + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + ChangeStatus Changed = ChangeStatus::UNCHANGED; + + Value &V = getAssociatedValue(); + const Function *Scope = nullptr; + if (auto *I = dyn_cast<Instruction>(&V)) + Scope = I->getFunction(); + if (auto *A = dyn_cast<Argument>(&V)) { + Scope = A->getParent(); + if (!Scope->hasLocalLinkage()) + return Changed; + } + if (!Scope) + return indicateOptimisticFixpoint(); + + auto &NoRecurseAA = A.getAAFor<AANoRecurse>( + *this, IRPosition::function(*Scope), DepClassTy::OPTIONAL); + if (NoRecurseAA.isAssumedNoRecurse()) + return Changed; + + auto UsePred = [&](const Use &U, bool &Follow) { + const Instruction *UserI = dyn_cast<Instruction>(U.getUser()); + if (!UserI || isa<GetElementPtrInst>(UserI) || isa<CastInst>(UserI) || + isa<PHINode>(UserI) || isa<SelectInst>(UserI)) { + Follow = true; + return true; + } + if (isa<LoadInst>(UserI) || isa<CmpInst>(UserI) || + (isa<StoreInst>(UserI) && + cast<StoreInst>(UserI)->getValueOperand() != U.get())) + return true; + if (auto *CB = dyn_cast<CallBase>(UserI)) { + // This check is not guaranteeing uniqueness but for now that we cannot + // end up with two versions of \p U thinking it was one. + if (!CB->getCalledFunction() || + !CB->getCalledFunction()->hasLocalLinkage()) + return true; + if (!CB->isArgOperand(&U)) + return false; + const auto &ArgInstanceInfoAA = A.getAAFor<AAInstanceInfo>( + *this, IRPosition::callsite_argument(*CB, CB->getArgOperandNo(&U)), + DepClassTy::OPTIONAL); + if (!ArgInstanceInfoAA.isAssumedUniqueForAnalysis()) + return false; + // If this call base might reach the scope again we might forward the + // argument back here. This is very conservative. + if (AA::isPotentiallyReachable(A, *CB, *Scope, *this, nullptr)) + return false; + return true; + } + return false; + }; + + auto EquivalentUseCB = [&](const Use &OldU, const Use &NewU) { + if (auto *SI = dyn_cast<StoreInst>(OldU.getUser())) { + auto *Ptr = SI->getPointerOperand()->stripPointerCasts(); + if (isa<AllocaInst>(Ptr) && AA::isDynamicallyUnique(A, *this, *Ptr)) + return true; + auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction( + *SI->getFunction()); + if (isAllocationFn(Ptr, TLI) && AA::isDynamicallyUnique(A, *this, *Ptr)) + return true; + } + return false; + }; + + if (!A.checkForAllUses(UsePred, *this, V, /* CheckBBLivenessOnly */ true, + DepClassTy::OPTIONAL, + /* IgnoreDroppableUses */ true, EquivalentUseCB)) + return indicatePessimisticFixpoint(); + + return Changed; + } + + /// See AbstractState::getAsStr(). + const std::string getAsStr() const override { + return isAssumedUniqueForAnalysis() ? "<unique [fAa]>" : "<unknown>"; + } + + /// See AbstractAttribute::trackStatistics() + void trackStatistics() const override {} +}; + +/// InstanceInfo attribute for floating values. +struct AAInstanceInfoFloating : AAInstanceInfoImpl { + AAInstanceInfoFloating(const IRPosition &IRP, Attributor &A) + : AAInstanceInfoImpl(IRP, A) {} +}; + +/// NoCapture attribute for function arguments. +struct AAInstanceInfoArgument final : AAInstanceInfoFloating { + AAInstanceInfoArgument(const IRPosition &IRP, Attributor &A) + : AAInstanceInfoFloating(IRP, A) {} +}; + +/// InstanceInfo attribute for call site arguments. +struct AAInstanceInfoCallSiteArgument final : AAInstanceInfoImpl { + AAInstanceInfoCallSiteArgument(const IRPosition &IRP, Attributor &A) + : AAInstanceInfoImpl(IRP, A) {} + + /// See AbstractAttribute::updateImpl(...). 
+ ChangeStatus updateImpl(Attributor &A) override { + // TODO: Once we have call site specific value information we can provide + // call site specific liveness information and then it makes + // sense to specialize attributes for call sites arguments instead of + // redirecting requests to the callee argument. + Argument *Arg = getAssociatedArgument(); + if (!Arg) + return indicatePessimisticFixpoint(); + const IRPosition &ArgPos = IRPosition::argument(*Arg); + auto &ArgAA = + A.getAAFor<AAInstanceInfo>(*this, ArgPos, DepClassTy::REQUIRED); + return clampStateAndIndicateChange(getState(), ArgAA.getState()); + } +}; + +/// InstanceInfo attribute for function return value. +struct AAInstanceInfoReturned final : AAInstanceInfoImpl { + AAInstanceInfoReturned(const IRPosition &IRP, Attributor &A) + : AAInstanceInfoImpl(IRP, A) { + llvm_unreachable("InstanceInfo is not applicable to function returns!"); + } + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + llvm_unreachable("InstanceInfo is not applicable to function returns!"); + } + + /// See AbstractAttribute::updateImpl(...). + ChangeStatus updateImpl(Attributor &A) override { + llvm_unreachable("InstanceInfo is not applicable to function returns!"); + } +}; + +/// InstanceInfo attribute deduction for a call site return value. +struct AAInstanceInfoCallSiteReturned final : AAInstanceInfoFloating { + AAInstanceInfoCallSiteReturned(const IRPosition &IRP, Attributor &A) + : AAInstanceInfoFloating(IRP, A) {} +}; +} // namespace /// ----------------------- Variable Capturing --------------------------------- +namespace { /// A class to hold the state of for no-capture attributes. struct AANoCaptureImpl : public AANoCapture { AANoCaptureImpl(const IRPosition &IRP, Attributor &A) : AANoCapture(IRP, A) {} @@ -4883,143 +5146,69 @@ struct AANoCaptureImpl : public AANoCapture { return "assumed not-captured-maybe-returned"; return "assumed-captured"; } -}; - -/// Attributor-aware capture tracker. -struct AACaptureUseTracker final : public CaptureTracker { - - /// Create a capture tracker that can lookup in-flight abstract attributes - /// through the Attributor \p A. - /// - /// If a use leads to a potential capture, \p CapturedInMemory is set and the - /// search is stopped. If a use leads to a return instruction, - /// \p CommunicatedBack is set to true and \p CapturedInMemory is not changed. - /// If a use leads to a ptr2int which may capture the value, - /// \p CapturedInInteger is set. If a use is found that is currently assumed - /// "no-capture-maybe-returned", the user is added to the \p PotentialCopies - /// set. All values in \p PotentialCopies are later tracked as well. For every - /// explored use we decrement \p RemainingUsesToExplore. Once it reaches 0, - /// the search is stopped with \p CapturedInMemory and \p CapturedInInteger - /// conservatively set to true. - AACaptureUseTracker(Attributor &A, AANoCapture &NoCaptureAA, - const AAIsDead &IsDeadAA, AANoCapture::StateType &State, - SmallSetVector<Value *, 4> &PotentialCopies, - unsigned &RemainingUsesToExplore) - : A(A), NoCaptureAA(NoCaptureAA), IsDeadAA(IsDeadAA), State(State), - PotentialCopies(PotentialCopies), - RemainingUsesToExplore(RemainingUsesToExplore) {} - - /// Determine if \p V maybe captured. 
*Also updates the state!* - bool valueMayBeCaptured(const Value *V) { - if (V->getType()->isPointerTy()) { - PointerMayBeCaptured(V, this); - } else { - State.indicatePessimisticFixpoint(); - } - return State.isAssumed(AANoCapture::NO_CAPTURE_MAYBE_RETURNED); - } - - /// See CaptureTracker::tooManyUses(). - void tooManyUses() override { - State.removeAssumedBits(AANoCapture::NO_CAPTURE); - } - - bool isDereferenceableOrNull(Value *O, const DataLayout &DL) override { - if (CaptureTracker::isDereferenceableOrNull(O, DL)) - return true; - const auto &DerefAA = A.getAAFor<AADereferenceable>( - NoCaptureAA, IRPosition::value(*O), DepClassTy::OPTIONAL); - return DerefAA.getAssumedDereferenceableBytes(); - } - - /// See CaptureTracker::captured(...). - bool captured(const Use *U) override { - Instruction *UInst = cast<Instruction>(U->getUser()); - LLVM_DEBUG(dbgs() << "Check use: " << *U->get() << " in " << *UInst - << "\n"); - // Because we may reuse the tracker multiple times we keep track of the - // number of explored uses ourselves as well. - if (RemainingUsesToExplore-- == 0) { - LLVM_DEBUG(dbgs() << " - too many uses to explore!\n"); - return isCapturedIn(/* Memory */ true, /* Integer */ true, - /* Return */ true); - } + /// Check the use \p U and update \p State accordingly. Return true if we + /// should continue to update the state. + bool checkUse(Attributor &A, AANoCapture::StateType &State, const Use &U, + bool &Follow) { + Instruction *UInst = cast<Instruction>(U.getUser()); + LLVM_DEBUG(dbgs() << "[AANoCapture] Check use: " << *U.get() << " in " + << *UInst << "\n"); // Deal with ptr2int by following uses. if (isa<PtrToIntInst>(UInst)) { LLVM_DEBUG(dbgs() << " - ptr2int assume the worst!\n"); - return valueMayBeCaptured(UInst); + return isCapturedIn(State, /* Memory */ true, /* Integer */ true, + /* Return */ true); } - // For stores we check if we can follow the value through memory or not. - if (auto *SI = dyn_cast<StoreInst>(UInst)) { - if (SI->isVolatile()) - return isCapturedIn(/* Memory */ true, /* Integer */ false, - /* Return */ false); - bool UsedAssumedInformation = false; - if (!AA::getPotentialCopiesOfStoredValue( - A, *SI, PotentialCopies, NoCaptureAA, UsedAssumedInformation)) - return isCapturedIn(/* Memory */ true, /* Integer */ false, - /* Return */ false); - // Not captured directly, potential copies will be checked. - return isCapturedIn(/* Memory */ false, /* Integer */ false, + // For stores we already checked if we can follow them, if they make it + // here we give up. + if (isa<StoreInst>(UInst)) + return isCapturedIn(State, /* Memory */ true, /* Integer */ false, /* Return */ false); - } // Explicitly catch return instructions. if (isa<ReturnInst>(UInst)) { - if (UInst->getFunction() == NoCaptureAA.getAnchorScope()) - return isCapturedIn(/* Memory */ false, /* Integer */ false, + if (UInst->getFunction() == getAnchorScope()) + return isCapturedIn(State, /* Memory */ false, /* Integer */ false, /* Return */ true); - return isCapturedIn(/* Memory */ true, /* Integer */ true, + return isCapturedIn(State, /* Memory */ true, /* Integer */ true, /* Return */ true); } // For now we only use special logic for call sites. However, the tracker // itself knows about a lot of other non-capturing cases already. 
auto *CB = dyn_cast<CallBase>(UInst); - if (!CB || !CB->isArgOperand(U)) - return isCapturedIn(/* Memory */ true, /* Integer */ true, + if (!CB || !CB->isArgOperand(&U)) + return isCapturedIn(State, /* Memory */ true, /* Integer */ true, /* Return */ true); - unsigned ArgNo = CB->getArgOperandNo(U); + unsigned ArgNo = CB->getArgOperandNo(&U); const IRPosition &CSArgPos = IRPosition::callsite_argument(*CB, ArgNo); // If we have a abstract no-capture attribute for the argument we can use // it to justify a non-capture attribute here. This allows recursion! auto &ArgNoCaptureAA = - A.getAAFor<AANoCapture>(NoCaptureAA, CSArgPos, DepClassTy::REQUIRED); + A.getAAFor<AANoCapture>(*this, CSArgPos, DepClassTy::REQUIRED); if (ArgNoCaptureAA.isAssumedNoCapture()) - return isCapturedIn(/* Memory */ false, /* Integer */ false, + return isCapturedIn(State, /* Memory */ false, /* Integer */ false, /* Return */ false); if (ArgNoCaptureAA.isAssumedNoCaptureMaybeReturned()) { - addPotentialCopy(*CB); - return isCapturedIn(/* Memory */ false, /* Integer */ false, + Follow = true; + return isCapturedIn(State, /* Memory */ false, /* Integer */ false, /* Return */ false); } // Lastly, we could not find a reason no-capture can be assumed so we don't. - return isCapturedIn(/* Memory */ true, /* Integer */ true, + return isCapturedIn(State, /* Memory */ true, /* Integer */ true, /* Return */ true); } - /// Register \p CS as potential copy of the value we are checking. - void addPotentialCopy(CallBase &CB) { PotentialCopies.insert(&CB); } - - /// See CaptureTracker::shouldExplore(...). - bool shouldExplore(const Use *U) override { - // Check liveness and ignore droppable users. - bool UsedAssumedInformation = false; - return !U->getUser()->isDroppable() && - !A.isAssumedDead(*U, &NoCaptureAA, &IsDeadAA, - UsedAssumedInformation); - } - - /// Update the state according to \p CapturedInMem, \p CapturedInInt, and - /// \p CapturedInRet, then return the appropriate value for use in the - /// CaptureTracker::captured() interface. - bool isCapturedIn(bool CapturedInMem, bool CapturedInInt, - bool CapturedInRet) { + /// Update \p State according to \p CapturedInMem, \p CapturedInInt, and + /// \p CapturedInRet, then return true if we should continue updating the + /// state. + static bool isCapturedIn(AANoCapture::StateType &State, bool CapturedInMem, + bool CapturedInInt, bool CapturedInRet) { LLVM_DEBUG(dbgs() << " - captures [Mem " << CapturedInMem << "|Int " << CapturedInInt << "|Ret " << CapturedInRet << "]\n"); if (CapturedInMem) @@ -5028,27 +5217,8 @@ struct AACaptureUseTracker final : public CaptureTracker { State.removeAssumedBits(AANoCapture::NOT_CAPTURED_IN_INT); if (CapturedInRet) State.removeAssumedBits(AANoCapture::NOT_CAPTURED_IN_RET); - return !State.isAssumed(AANoCapture::NO_CAPTURE_MAYBE_RETURNED); + return State.isAssumed(AANoCapture::NO_CAPTURE_MAYBE_RETURNED); } - -private: - /// The attributor providing in-flight abstract attributes. - Attributor &A; - - /// The abstract attribute currently updated. - AANoCapture &NoCaptureAA; - - /// The abstract liveness state. - const AAIsDead &IsDeadAA; - - /// The state currently updated. - AANoCapture::StateType &State; - - /// Set of potential copies of the tracked value. - SmallSetVector<Value *, 4> &PotentialCopies; - - /// Global counter to limit the number of explored uses. 
- unsigned &RemainingUsesToExplore; }; ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) { @@ -5062,7 +5232,6 @@ ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) { isArgumentPosition() ? IRP.getAssociatedFunction() : IRP.getAnchorScope(); assert(F && "Expected a function!"); const IRPosition &FnPos = IRPosition::function(*F); - const auto &IsDeadAA = A.getAAFor<AAIsDead>(*this, FnPos, DepClassTy::NONE); AANoCapture::StateType T; @@ -5079,6 +5248,8 @@ ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) { // AAReturnedValues, e.g., track all values that escape through returns // directly somehow. auto CheckReturnedArgs = [&](const AAReturnedValues &RVAA) { + if (!RVAA.getState().isValidState()) + return false; bool SeenConstant = false; for (auto &It : RVAA.returned_values()) { if (isa<Constant>(It.first)) { @@ -5114,21 +5285,27 @@ ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) { } } - // Use the CaptureTracker interface and logic with the specialized tracker, - // defined in AACaptureUseTracker, that can look at in-flight abstract - // attributes and directly updates the assumed state. - SmallSetVector<Value *, 4> PotentialCopies; - unsigned RemainingUsesToExplore = - getDefaultMaxUsesToExploreForCaptureTracking(); - AACaptureUseTracker Tracker(A, *this, IsDeadAA, T, PotentialCopies, - RemainingUsesToExplore); + auto IsDereferenceableOrNull = [&](Value *O, const DataLayout &DL) { + const auto &DerefAA = A.getAAFor<AADereferenceable>( + *this, IRPosition::value(*O), DepClassTy::OPTIONAL); + return DerefAA.getAssumedDereferenceableBytes(); + }; - // Check all potential copies of the associated value until we can assume - // none will be captured or we have to assume at least one might be. - unsigned Idx = 0; - PotentialCopies.insert(V); - while (T.isAssumed(NO_CAPTURE_MAYBE_RETURNED) && Idx < PotentialCopies.size()) - Tracker.valueMayBeCaptured(PotentialCopies[Idx++]); + auto UseCheck = [&](const Use &U, bool &Follow) -> bool { + switch (DetermineUseCaptureKind(U, IsDereferenceableOrNull)) { + case UseCaptureKind::NO_CAPTURE: + return true; + case UseCaptureKind::MAY_CAPTURE: + return checkUse(A, T, U, Follow); + case UseCaptureKind::PASSTHROUGH: + Follow = true; + return true; + } + llvm_unreachable("Unexpected use capture kind!"); + }; + + if (!A.checkForAllUses(UseCheck, *this, *V)) + return indicatePessimisticFixpoint(); AANoCapture::StateType &S = getState(); auto Assumed = S.getAssumed(); @@ -5228,6 +5405,7 @@ struct AANoCaptureCallSiteReturned final : AANoCaptureImpl { STATS_DECLTRACK_CSRET_ATTR(nocapture) } }; +} // namespace /// ------------------ Value Simplify Attribute ---------------------------- @@ -5239,7 +5417,7 @@ bool ValueSimplifyStateType::unionAssumed(Optional<Value *> Other) { return false; LLVM_DEBUG({ - if (SimplifiedAssociatedValue.hasValue()) + if (SimplifiedAssociatedValue) dbgs() << "[ValueSimplify] is assumed to be " << **SimplifiedAssociatedValue << "\n"; else @@ -5248,6 +5426,7 @@ bool ValueSimplifyStateType::unionAssumed(Optional<Value *> Other) { return true; } +namespace { struct AAValueSimplifyImpl : AAValueSimplify { AAValueSimplifyImpl(const IRPosition &IRP, Attributor &A) : AAValueSimplify(IRP, A) {} @@ -5263,9 +5442,9 @@ struct AAValueSimplifyImpl : AAValueSimplify { /// See AbstractAttribute::getAsStr(). 
const std::string getAsStr() const override { LLVM_DEBUG({ - errs() << "SAV: " << SimplifiedAssociatedValue << " "; + dbgs() << "SAV: " << (bool)SimplifiedAssociatedValue << " "; if (SimplifiedAssociatedValue && *SimplifiedAssociatedValue) - errs() << "SAV: " << **SimplifiedAssociatedValue << " "; + dbgs() << "SAV: " << **SimplifiedAssociatedValue << " "; }); return isValidState() ? (isAtFixpoint() ? "simplified" : "maybe-simple") : "not-simple"; @@ -5279,24 +5458,101 @@ struct AAValueSimplifyImpl : AAValueSimplify { return SimplifiedAssociatedValue; } + /// Ensure the return value is \p V with type \p Ty, if not possible return + /// nullptr. If \p Check is true we will only verify such an operation would + /// suceed and return a non-nullptr value if that is the case. No IR is + /// generated or modified. + static Value *ensureType(Attributor &A, Value &V, Type &Ty, Instruction *CtxI, + bool Check) { + if (auto *TypedV = AA::getWithType(V, Ty)) + return TypedV; + if (CtxI && V.getType()->canLosslesslyBitCastTo(&Ty)) + return Check ? &V + : BitCastInst::CreatePointerBitCastOrAddrSpaceCast(&V, &Ty, + "", CtxI); + return nullptr; + } + + /// Reproduce \p I with type \p Ty or return nullptr if that is not posisble. + /// If \p Check is true we will only verify such an operation would suceed and + /// return a non-nullptr value if that is the case. No IR is generated or + /// modified. + static Value *reproduceInst(Attributor &A, + const AbstractAttribute &QueryingAA, + Instruction &I, Type &Ty, Instruction *CtxI, + bool Check, ValueToValueMapTy &VMap) { + assert(CtxI && "Cannot reproduce an instruction without context!"); + if (Check && (I.mayReadFromMemory() || + !isSafeToSpeculativelyExecute(&I, CtxI, /* DT */ nullptr, + /* TLI */ nullptr))) + return nullptr; + for (Value *Op : I.operands()) { + Value *NewOp = reproduceValue(A, QueryingAA, *Op, Ty, CtxI, Check, VMap); + if (!NewOp) { + assert(Check && "Manifest of new value unexpectedly failed!"); + return nullptr; + } + if (!Check) + VMap[Op] = NewOp; + } + if (Check) + return &I; + + Instruction *CloneI = I.clone(); + // TODO: Try to salvage debug information here. + CloneI->setDebugLoc(DebugLoc()); + VMap[&I] = CloneI; + CloneI->insertBefore(CtxI); + RemapInstruction(CloneI, VMap); + return CloneI; + } + + /// Reproduce \p V with type \p Ty or return nullptr if that is not posisble. + /// If \p Check is true we will only verify such an operation would suceed and + /// return a non-nullptr value if that is the case. No IR is generated or + /// modified. 
+ static Value *reproduceValue(Attributor &A, + const AbstractAttribute &QueryingAA, Value &V, + Type &Ty, Instruction *CtxI, bool Check, + ValueToValueMapTy &VMap) { + if (const auto &NewV = VMap.lookup(&V)) + return NewV; + bool UsedAssumedInformation = false; + Optional<Value *> SimpleV = + A.getAssumedSimplified(V, QueryingAA, UsedAssumedInformation); + if (!SimpleV) + return PoisonValue::get(&Ty); + Value *EffectiveV = &V; + if (SimpleV.getValue()) + EffectiveV = SimpleV.getValue(); + if (auto *C = dyn_cast<Constant>(EffectiveV)) + if (!C->canTrap()) + return C; + if (CtxI && AA::isValidAtPosition(AA::ValueAndContext(*EffectiveV, *CtxI), + A.getInfoCache())) + return ensureType(A, *EffectiveV, Ty, CtxI, Check); + if (auto *I = dyn_cast<Instruction>(EffectiveV)) + if (Value *NewV = reproduceInst(A, QueryingAA, *I, Ty, CtxI, Check, VMap)) + return ensureType(A, *NewV, Ty, CtxI, Check); + return nullptr; + } + /// Return a value we can use as replacement for the associated one, or /// nullptr if we don't have one that makes sense. - Value *getReplacementValue(Attributor &A) const { - Value *NewV; - NewV = SimplifiedAssociatedValue.hasValue() - ? SimplifiedAssociatedValue.getValue() - : UndefValue::get(getAssociatedType()); - if (!NewV) - return nullptr; - NewV = AA::getWithType(*NewV, *getAssociatedType()); - if (!NewV || NewV == &getAssociatedValue()) - return nullptr; - const Instruction *CtxI = getCtxI(); - if (CtxI && !AA::isValidAtPosition(*NewV, *CtxI, A.getInfoCache())) - return nullptr; - if (!CtxI && !AA::isValidInScope(*NewV, getAnchorScope())) - return nullptr; - return NewV; + Value *manifestReplacementValue(Attributor &A, Instruction *CtxI) const { + Value *NewV = SimplifiedAssociatedValue + ? SimplifiedAssociatedValue.getValue() + : UndefValue::get(getAssociatedType()); + if (NewV && NewV != &getAssociatedValue()) { + ValueToValueMapTy VMap; + // First verify we can reprduce the value with the required type at the + // context location before we actually start modifying the IR. + if (reproduceValue(A, *this, *NewV, *getAssociatedType(), CtxI, + /* CheckOnly */ true, VMap)) + return reproduceValue(A, *this, *NewV, *getAssociatedType(), CtxI, + /* CheckOnly */ false, VMap); + } + return nullptr; } /// Helper function for querying AAValueSimplify and updating candicate. @@ -5320,14 +5576,14 @@ struct AAValueSimplifyImpl : AAValueSimplify { const auto &AA = A.getAAFor<AAType>(*this, getIRPosition(), DepClassTy::NONE); - Optional<ConstantInt *> COpt = AA.getAssumedConstantInt(A); + Optional<Constant *> COpt = AA.getAssumedConstant(A); - if (!COpt.hasValue()) { + if (!COpt) { SimplifiedAssociatedValue = llvm::None; A.recordDependence(AA, *this, DepClassTy::OPTIONAL); return true; } - if (auto *C = COpt.getValue()) { + if (auto *C = *COpt) { SimplifiedAssociatedValue = C; A.recordDependence(AA, *this, DepClassTy::OPTIONAL); return true; @@ -5338,7 +5594,7 @@ struct AAValueSimplifyImpl : AAValueSimplify { bool askSimplifiedValueForOtherAAs(Attributor &A) { if (askSimplifiedValueFor<AAValueConstantRange>(A)) return true; - if (askSimplifiedValueFor<AAPotentialValues>(A)) + if (askSimplifiedValueFor<AAPotentialConstantValues>(A)) return true; return false; } @@ -5346,14 +5602,18 @@ struct AAValueSimplifyImpl : AAValueSimplify { /// See AbstractAttribute::manifest(...). 
ChangeStatus manifest(Attributor &A) override { ChangeStatus Changed = ChangeStatus::UNCHANGED; - if (getAssociatedValue().user_empty()) - return Changed; - - if (auto *NewV = getReplacementValue(A)) { - LLVM_DEBUG(dbgs() << "[ValueSimplify] " << getAssociatedValue() << " -> " - << *NewV << " :: " << *this << "\n"); - if (A.changeValueAfterManifest(getAssociatedValue(), *NewV)) - Changed = ChangeStatus::CHANGED; + for (auto &U : getAssociatedValue().uses()) { + // Check if we need to adjust the insertion point to make sure the IR is + // valid. + Instruction *IP = dyn_cast<Instruction>(U.getUser()); + if (auto *PHI = dyn_cast_or_null<PHINode>(IP)) + IP = PHI->getIncomingBlock(U)->getTerminator(); + if (auto *NewV = manifestReplacementValue(A, IP)) { + LLVM_DEBUG(dbgs() << "[ValueSimplify] " << getAssociatedValue() + << " -> " << *NewV << " :: " << *this << "\n"); + if (A.changeUseAfterManifest(U, *NewV)) + Changed = ChangeStatus::CHANGED; + } } return Changed | AAValueSimplify::manifest(A); @@ -5364,74 +5624,6 @@ struct AAValueSimplifyImpl : AAValueSimplify { SimplifiedAssociatedValue = &getAssociatedValue(); return AAValueSimplify::indicatePessimisticFixpoint(); } - - static bool handleLoad(Attributor &A, const AbstractAttribute &AA, - LoadInst &L, function_ref<bool(Value &)> Union) { - auto UnionWrapper = [&](Value &V, Value &Obj) { - if (isa<AllocaInst>(Obj)) - return Union(V); - if (!AA::isDynamicallyUnique(A, AA, V)) - return false; - if (!AA::isValidAtPosition(V, L, A.getInfoCache())) - return false; - return Union(V); - }; - - Value &Ptr = *L.getPointerOperand(); - SmallVector<Value *, 8> Objects; - bool UsedAssumedInformation = false; - if (!AA::getAssumedUnderlyingObjects(A, Ptr, Objects, AA, &L, - UsedAssumedInformation)) - return false; - - const auto *TLI = - A.getInfoCache().getTargetLibraryInfoForFunction(*L.getFunction()); - for (Value *Obj : Objects) { - LLVM_DEBUG(dbgs() << "Visit underlying object " << *Obj << "\n"); - if (isa<UndefValue>(Obj)) - continue; - if (isa<ConstantPointerNull>(Obj)) { - // A null pointer access can be undefined but any offset from null may - // be OK. We do not try to optimize the latter. 
- if (!NullPointerIsDefined(L.getFunction(), - Ptr.getType()->getPointerAddressSpace()) && - A.getAssumedSimplified(Ptr, AA, UsedAssumedInformation) == Obj) - continue; - return false; - } - Constant *InitialVal = AA::getInitialValueForObj(*Obj, *L.getType(), TLI); - if (!InitialVal || !Union(*InitialVal)) - return false; - - LLVM_DEBUG(dbgs() << "Underlying object amenable to load-store " - "propagation, checking accesses next.\n"); - - auto CheckAccess = [&](const AAPointerInfo::Access &Acc, bool IsExact) { - LLVM_DEBUG(dbgs() << " - visit access " << Acc << "\n"); - if (Acc.isWrittenValueYetUndetermined()) - return true; - Value *Content = Acc.getWrittenValue(); - if (!Content) - return false; - Value *CastedContent = - AA::getWithType(*Content, *AA.getAssociatedType()); - if (!CastedContent) - return false; - if (IsExact) - return UnionWrapper(*CastedContent, *Obj); - if (auto *C = dyn_cast<Constant>(CastedContent)) - if (C->isNullValue() || C->isAllOnesValue() || isa<UndefValue>(C)) - return UnionWrapper(*CastedContent, *Obj); - return false; - }; - - auto &PI = A.getAAFor<AAPointerInfo>(AA, IRPosition::value(*Obj), - DepClassTy::REQUIRED); - if (!PI.forallInterferingWrites(A, AA, L, CheckAccess)) - return false; - } - return true; - } }; struct AAValueSimplifyArgument final : AAValueSimplifyImpl { @@ -5446,15 +5638,6 @@ struct AAValueSimplifyArgument final : AAValueSimplifyImpl { Attribute::StructRet, Attribute::Nest, Attribute::ByVal}, /* IgnoreSubsumingPositions */ true)) indicatePessimisticFixpoint(); - - // FIXME: This is a hack to prevent us from propagating function poiner in - // the new pass manager CGSCC pass as it creates call edges the - // CallGraphUpdater cannot handle yet. - Value &V = getAssociatedValue(); - if (V.getType()->isPointerTy() && - V.getType()->getPointerElementType()->isFunctionTy() && - !A.isModulePass()) - indicatePessimisticFixpoint(); } /// See AbstractAttribute::updateImpl(...). 
@@ -5487,7 +5670,7 @@ struct AAValueSimplifyArgument final : AAValueSimplifyImpl { bool UsedAssumedInformation = false; Optional<Constant *> SimpleArgOp = A.getAssumedConstant(ACSArgPos, *this, UsedAssumedInformation); - if (!SimpleArgOp.hasValue()) + if (!SimpleArgOp) return true; if (!SimpleArgOp.getValue()) return false; @@ -5537,12 +5720,16 @@ struct AAValueSimplifyReturned : AAValueSimplifyImpl { ChangeStatus updateImpl(Attributor &A) override { auto Before = SimplifiedAssociatedValue; - auto PredForReturned = [&](Value &V) { - return checkAndUpdate(A, *this, - IRPosition::value(V, getCallBaseContext())); + auto ReturnInstCB = [&](Instruction &I) { + auto &RI = cast<ReturnInst>(I); + return checkAndUpdate( + A, *this, + IRPosition::value(*RI.getReturnValue(), getCallBaseContext())); }; - if (!A.checkForAllReturnedValues(PredForReturned, *this)) + bool UsedAssumedInformation = false; + if (!A.checkForAllInstructions(ReturnInstCB, *this, {Instruction::Ret}, + UsedAssumedInformation)) if (!askSimplifiedValueForOtherAAs(A)) return indicatePessimisticFixpoint(); @@ -5552,29 +5739,9 @@ struct AAValueSimplifyReturned : AAValueSimplifyImpl { } ChangeStatus manifest(Attributor &A) override { - ChangeStatus Changed = ChangeStatus::UNCHANGED; - - if (auto *NewV = getReplacementValue(A)) { - auto PredForReturned = - [&](Value &, const SmallSetVector<ReturnInst *, 4> &RetInsts) { - for (ReturnInst *RI : RetInsts) { - Value *ReturnedVal = RI->getReturnValue(); - if (ReturnedVal == NewV || isa<UndefValue>(ReturnedVal)) - return true; - assert(RI->getFunction() == getAnchorScope() && - "ReturnInst in wrong function!"); - LLVM_DEBUG(dbgs() - << "[ValueSimplify] " << *ReturnedVal << " -> " - << *NewV << " in " << *RI << " :: " << *this << "\n"); - if (A.changeUseAfterManifest(RI->getOperandUse(0), *NewV)) - Changed = ChangeStatus::CHANGED; - } - return true; - }; - A.checkForAllReturnedValuesAndReturnInsts(PredForReturned, *this); - } - - return Changed | AAValueSimplify::manifest(A); + // We queried AAValueSimplify for the returned values so they will be + // replaced if a simplified form was found. Nothing to do here. + return ChangeStatus::UNCHANGED; } /// See AbstractAttribute::trackStatistics() @@ -5618,7 +5785,7 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl { const auto &SimplifiedLHS = A.getAssumedSimplified(IRPosition::value(*LHS, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedLHS.hasValue()) + if (!SimplifiedLHS) return true; if (!SimplifiedLHS.getValue()) return false; @@ -5627,7 +5794,7 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl { const auto &SimplifiedRHS = A.getAssumedSimplified(IRPosition::value(*RHS, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedRHS.hasValue()) + if (!SimplifiedRHS) return true; if (!SimplifiedRHS.getValue()) return false; @@ -5683,15 +5850,6 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl { return true; } - bool updateWithLoad(Attributor &A, LoadInst &L) { - auto Union = [&](Value &V) { - SimplifiedAssociatedValue = AA::combineOptionalValuesInAAValueLatice( - SimplifiedAssociatedValue, &V, L.getType()); - return SimplifiedAssociatedValue != Optional<Value *>(nullptr); - }; - return handleLoad(A, *this, L, Union); - } - /// Use the generic, non-optimistic InstSimplfy functionality if we managed to /// simplify any operand of the instruction \p I. Return true if successful, /// in that case SimplifiedAssociatedValue will be updated. 
@@ -5707,7 +5865,7 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl { *this, UsedAssumedInformation); // If we are not sure about any operand we are not sure about the entire // instruction, we'll wait. - if (!SimplifiedOp.hasValue()) + if (!SimplifiedOp) return true; if (SimplifiedOp.getValue()) @@ -5735,7 +5893,7 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl { const DataLayout &DL = I.getModule()->getDataLayout(); SimplifyQuery Q(DL, TLI, DT, AC, &I); if (Value *SimplifiedI = - SimplifyInstructionWithOperands(&I, NewOps, Q, ORE)) { + simplifyInstructionWithOperands(&I, NewOps, Q, ORE)) { SimplifiedAssociatedValue = AA::combineOptionalValuesInAAValueLatice( SimplifiedAssociatedValue, SimplifiedI, I.getType()); return SimplifiedAssociatedValue != Optional<Value *>(nullptr); @@ -5747,6 +5905,36 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl { ChangeStatus updateImpl(Attributor &A) override { auto Before = SimplifiedAssociatedValue; + // Do not simplify loads that are only used in llvm.assume if we cannot also + // remove all stores that may feed into the load. The reason is that the + // assume is probably worth something as long as the stores are around. + if (auto *LI = dyn_cast<LoadInst>(&getAssociatedValue())) { + InformationCache &InfoCache = A.getInfoCache(); + if (InfoCache.isOnlyUsedByAssume(*LI)) { + SmallSetVector<Value *, 4> PotentialCopies; + SmallSetVector<Instruction *, 4> PotentialValueOrigins; + bool UsedAssumedInformation = false; + if (AA::getPotentiallyLoadedValues(A, *LI, PotentialCopies, + PotentialValueOrigins, *this, + UsedAssumedInformation, + /* OnlyExact */ true)) { + if (!llvm::all_of(PotentialValueOrigins, [&](Instruction *I) { + if (!I) + return true; + if (auto *SI = dyn_cast<StoreInst>(I)) + return A.isAssumedDead(SI->getOperandUse(0), this, + /* LivenessAA */ nullptr, + UsedAssumedInformation, + /* CheckBBLivenessOnly */ false); + return A.isAssumedDead(*I, this, /* LivenessAA */ nullptr, + UsedAssumedInformation, + /* CheckBBLivenessOnly */ false); + })) + return indicatePessimisticFixpoint(); + } + } + } + auto VisitValueCB = [&](Value &V, const Instruction *CtxI, bool &, bool Stripped) -> bool { auto &AA = A.getAAFor<AAValueSimplify>( @@ -5755,9 +5943,6 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl { if (!Stripped && this == &AA) { if (auto *I = dyn_cast<Instruction>(&V)) { - if (auto *LI = dyn_cast<LoadInst>(&V)) - if (updateWithLoad(A, *LI)) - return true; if (auto *Cmp = dyn_cast<CmpInst>(&V)) if (handleCmp(A, *Cmp)) return true; @@ -5829,8 +6014,23 @@ struct AAValueSimplifyCallSiteReturned : AAValueSimplifyImpl { void initialize(Attributor &A) override { AAValueSimplifyImpl::initialize(A); - if (!getAssociatedFunction()) + Function *Fn = getAssociatedFunction(); + if (!Fn) { indicatePessimisticFixpoint(); + return; + } + for (Argument &Arg : Fn->args()) { + if (Arg.hasReturnedAttr()) { + auto IRP = IRPosition::callsite_argument(*cast<CallBase>(getCtxI()), + Arg.getArgNo()); + if (IRP.getPositionKind() == IRPosition::IRP_CALL_SITE_ARGUMENT && + checkAndUpdate(A, *this, IRP)) + indicateOptimisticFixpoint(); + else + indicatePessimisticFixpoint(); + return; + } + } } /// See AbstractAttribute::updateImpl(...). @@ -5868,8 +6068,13 @@ struct AAValueSimplifyCallSiteArgument : AAValueSimplifyFloating { /// See AbstractAttribute::manifest(...). ChangeStatus manifest(Attributor &A) override { ChangeStatus Changed = ChangeStatus::UNCHANGED; + // TODO: We should avoid simplification duplication to begin with. 
+ auto *FloatAA = A.lookupAAFor<AAValueSimplify>( + IRPosition::value(getAssociatedValue()), this, DepClassTy::NONE); + if (FloatAA && FloatAA->getState().isValidState()) + return Changed; - if (auto *NewV = getReplacementValue(A)) { + if (auto *NewV = manifestReplacementValue(A, getCtxI())) { Use &U = cast<CallBase>(&getAnchorValue()) ->getArgOperandUse(getCallSiteArgNo()); if (A.changeUseAfterManifest(U, *NewV)) @@ -5883,8 +6088,10 @@ struct AAValueSimplifyCallSiteArgument : AAValueSimplifyFloating { STATS_DECLTRACK_CSARG_ATTR(value_simplify) } }; +} // namespace /// ----------------------- Heap-To-Stack Conversion --------------------------- +namespace { struct AAHeapToStackFunction final : public AAHeapToStack { struct AllocationInfo { @@ -5906,7 +6113,7 @@ struct AAHeapToStackFunction final : public AAHeapToStack { bool HasPotentiallyFreeingUnknownUses = false; /// The set of free calls that use this allocation. - SmallPtrSet<CallBase *, 1> PotentialFreeCalls{}; + SmallSetVector<CallBase *, 1> PotentialFreeCalls{}; }; struct DeallocationInfo { @@ -5918,7 +6125,7 @@ struct AAHeapToStackFunction final : public AAHeapToStack { bool MightFreeUnknownObjects = false; /// The set of allocation calls that are potentially freed. - SmallPtrSet<CallBase *, 1> PotentialAllocationCalls{}; + SmallSetVector<CallBase *, 1> PotentialAllocationCalls{}; }; AAHeapToStackFunction(const IRPosition &IRP, Attributor &A) @@ -5928,9 +6135,9 @@ struct AAHeapToStackFunction final : public AAHeapToStack { // Ensure we call the destructor so we release any memory allocated in the // sets. for (auto &It : AllocationInfos) - It.getSecond()->~AllocationInfo(); + It.second->~AllocationInfo(); for (auto &It : DeallocationInfos) - It.getSecond()->~DeallocationInfo(); + It.second->~DeallocationInfo(); } void initialize(Attributor &A) override { @@ -5955,7 +6162,8 @@ struct AAHeapToStackFunction final : public AAHeapToStack { if (nullptr != getInitialValueOfAllocation(CB, TLI, I8Ty)) { AllocationInfo *AI = new (A.Allocator) AllocationInfo{CB}; AllocationInfos[CB] = AI; - TLI->getLibFunc(*CB, AI->LibraryFunctionId); + if (TLI) + TLI->getLibFunc(*CB, AI->LibraryFunctionId); } } return true; @@ -5968,6 +6176,16 @@ struct AAHeapToStackFunction final : public AAHeapToStack { /* CheckPotentiallyDead */ true); (void)Success; assert(Success && "Did not expect the call base visit callback to fail!"); + + Attributor::SimplifictionCallbackTy SCB = + [](const IRPosition &, const AbstractAttribute *, + bool &) -> Optional<Value *> { return nullptr; }; + for (const auto &It : AllocationInfos) + A.registerSimplificationCallback(IRPosition::callsite_returned(*It.first), + SCB); + for (const auto &It : DeallocationInfos) + A.registerSimplificationCallback(IRPosition::callsite_returned(*It.first), + SCB); } const std::string getAsStr() const override { @@ -5994,7 +6212,8 @@ struct AAHeapToStackFunction final : public AAHeapToStack { bool isAssumedHeapToStack(const CallBase &CB) const override { if (isValidState()) - if (AllocationInfo *AI = AllocationInfos.lookup(&CB)) + if (AllocationInfo *AI = + AllocationInfos.lookup(const_cast<CallBase *>(&CB))) return AI->Status != AllocationInfo::INVALID; return false; } @@ -6023,6 +6242,17 @@ struct AAHeapToStackFunction final : public AAHeapToStack { Function *F = getAnchorScope(); const auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction(*F); + LoopInfo *LI = + A.getInfoCache().getAnalysisResultForFunction<LoopAnalysis>(*F); + Optional<bool> MayContainIrreducibleControl; + auto IsInLoop = 
[&](BasicBlock &BB) { + if (!MayContainIrreducibleControl.has_value()) + MayContainIrreducibleControl = mayContainIrreducibleControl(*F, LI); + if (MayContainIrreducibleControl.value()) + return true; + return LI->getLoopFor(&BB) != nullptr; + }; + for (auto &It : AllocationInfos) { AllocationInfo &AI = *It.second; if (AI.Status == AllocationInfo::INVALID) @@ -6052,7 +6282,7 @@ struct AAHeapToStackFunction final : public AAHeapToStack { const DataLayout &DL = A.getInfoCache().getDL(); Value *Size; Optional<APInt> SizeAPI = getSize(A, *this, AI); - if (SizeAPI.hasValue()) { + if (SizeAPI) { Size = ConstantInt::get(AI.CB->getContext(), *SizeAPI); } else { LLVMContext &Ctx = AI.CB->getContext(); @@ -6064,21 +6294,25 @@ struct AAHeapToStackFunction final : public AAHeapToStack { Size = SizeOffsetPair.first; } + Instruction *IP = (!SizeAPI.has_value() || IsInLoop(*AI.CB->getParent())) + ? AI.CB + : &F->getEntryBlock().front(); + Align Alignment(1); if (MaybeAlign RetAlign = AI.CB->getRetAlign()) - Alignment = max(Alignment, RetAlign); + Alignment = std::max(Alignment, *RetAlign); if (Value *Align = getAllocAlignment(AI.CB, TLI)) { Optional<APInt> AlignmentAPI = getAPInt(A, *this, *Align); - assert(AlignmentAPI.hasValue() && + assert(AlignmentAPI && AlignmentAPI.getValue().getZExtValue() > 0 && "Expected an alignment during manifest!"); - Alignment = - max(Alignment, MaybeAlign(AlignmentAPI.getValue().getZExtValue())); + Alignment = std::max( + Alignment, assumeAligned(AlignmentAPI.getValue().getZExtValue())); } // TODO: Hoist the alloca towards the function entry. unsigned AS = DL.getAllocaAddrSpace(); Instruction *Alloca = new AllocaInst(Type::getInt8Ty(F->getContext()), AS, - Size, Alignment, "", AI.CB); + Size, Alignment, "", IP); if (Alloca->getType() != AI.CB->getType()) Alloca = BitCastInst::CreatePointerBitCastOrAddrSpaceCast( @@ -6089,7 +6323,7 @@ struct AAHeapToStackFunction final : public AAHeapToStack { assert(InitVal && "Must be able to materialize initial memory state of allocation"); - A.changeValueAfterManifest(*AI.CB, *Alloca); + A.changeAfterManifest(IRPosition::inst(*AI.CB), *Alloca); if (auto *II = dyn_cast<InvokeInst>(AI.CB)) { auto *NBB = II->getNormalDest(); @@ -6118,7 +6352,7 @@ struct AAHeapToStackFunction final : public AAHeapToStack { bool UsedAssumedInformation = false; Optional<Constant *> SimpleV = A.getAssumedConstant(V, AA, UsedAssumedInformation); - if (!SimpleV.hasValue()) + if (!SimpleV) return APInt(64, 0); if (auto *CI = dyn_cast_or_null<ConstantInt>(SimpleV.getValue())) return CI->getValue(); @@ -6143,11 +6377,11 @@ struct AAHeapToStackFunction final : public AAHeapToStack { /// Collection of all malloc-like calls in a function with associated /// information. - DenseMap<CallBase *, AllocationInfo *> AllocationInfos; + MapVector<CallBase *, AllocationInfo *> AllocationInfos; /// Collection of all free-like calls in a function with associated /// information. 
- DenseMap<CallBase *, DeallocationInfo *> DeallocationInfos; + MapVector<CallBase *, DeallocationInfo *> DeallocationInfos; ChangeStatus updateImpl(Attributor &A) override; }; @@ -6263,6 +6497,8 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { dbgs() << "[H2S] unique free call might free unknown allocations\n"); return false; } + if (DI->PotentialAllocationCalls.empty()) + return true; if (DI->PotentialAllocationCalls.size() > 1) { LLVM_DEBUG(dbgs() << "[H2S] unique free call might free " << DI->PotentialAllocationCalls.size() @@ -6340,7 +6576,7 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { if (ValidUsesOnly && AI.LibraryFunctionId == LibFunc___kmpc_alloc_shared) - A.emitRemark<OptimizationRemarkMissed>(AI.CB, "OMP113", Remark); + A.emitRemark<OptimizationRemarkMissed>(CB, "OMP113", Remark); LLVM_DEBUG(dbgs() << "[H2S] Bad user: " << *UserI << "\n"); ValidUsesOnly = false; @@ -6372,7 +6608,8 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { continue; if (Value *Align = getAllocAlignment(AI.CB, TLI)) { - if (!getAPInt(A, *this, *Align)) { + Optional<APInt> APAlign = getAPInt(A, *this, *Align); + if (!APAlign) { // Can't generate an alloca which respects the required alignment // on the allocation. LLVM_DEBUG(dbgs() << "[H2S] Unknown allocation alignment: " << *AI.CB @@ -6380,14 +6617,23 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { AI.Status = AllocationInfo::INVALID; Changed = ChangeStatus::CHANGED; continue; + } else { + if (APAlign->ugt(llvm::Value::MaximumAlignment) || + !APAlign->isPowerOf2()) { + LLVM_DEBUG(dbgs() << "[H2S] Invalid allocation alignment: " << APAlign + << "\n"); + AI.Status = AllocationInfo::INVALID; + Changed = ChangeStatus::CHANGED; + continue; + } } } if (MaxHeapToStackSize != -1) { Optional<APInt> Size = getSize(A, *this, AI); - if (!Size.hasValue() || Size.getValue().ugt(MaxHeapToStackSize)) { + if (!Size || Size.getValue().ugt(MaxHeapToStackSize)) { LLVM_DEBUG({ - if (!Size.hasValue()) + if (!Size) dbgs() << "[H2S] Unknown allocation size: " << *AI.CB << "\n"; else dbgs() << "[H2S] Allocation size too large: " << *AI.CB << " vs. " @@ -6419,8 +6665,10 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { return Changed; } +} // namespace /// ----------------------- Privatizable Pointers ------------------------------ +namespace { struct AAPrivatizablePtrImpl : public AAPrivatizablePtr { AAPrivatizablePtrImpl(const IRPosition &IRP, Attributor &A) : AAPrivatizablePtr(IRP, A), PrivatizableType(llvm::None) {} @@ -6438,9 +6686,9 @@ struct AAPrivatizablePtrImpl : public AAPrivatizablePtr { /// Return a privatizable type that encloses both T0 and T1. /// TODO: This is merely a stub for now as we should manage a mapping as well. Optional<Type *> combineTypes(Optional<Type *> T0, Optional<Type *> T1) { - if (!T0.hasValue()) + if (!T0) return T1; - if (!T1.hasValue()) + if (!T1) return T0; if (T0 == T1) return T0; @@ -6470,10 +6718,12 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { // If this is a byval argument and we know all the call sites (so we can // rewrite them), there is no need to check them explicitly. 
bool UsedAssumedInformation = false; - if (getIRPosition().hasAttr(Attribute::ByVal) && + SmallVector<Attribute, 1> Attrs; + getAttrs({Attribute::ByVal}, Attrs, /* IgnoreSubsumingPositions */ true); + if (!Attrs.empty() && A.checkForAllCallSites([](AbstractCallSite ACS) { return true; }, *this, true, UsedAssumedInformation)) - return getAssociatedValue().getType()->getPointerElementType(); + return Attrs[0].getValueAsType(); Optional<Type *> Ty; unsigned ArgNo = getIRPosition().getCallSiteArgNo(); @@ -6498,9 +6748,9 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { LLVM_DEBUG({ dbgs() << "[AAPrivatizablePtr] ACSPos: " << ACSArgPos << ", CSTy: "; - if (CSTy.hasValue() && CSTy.getValue()) + if (CSTy && CSTy.getValue()) CSTy.getValue()->print(dbgs()); - else if (CSTy.hasValue()) + else if (CSTy) dbgs() << "<nullptr>"; else dbgs() << "<none>"; @@ -6510,16 +6760,16 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { LLVM_DEBUG({ dbgs() << " : New Type: "; - if (Ty.hasValue() && Ty.getValue()) + if (Ty && Ty.getValue()) Ty.getValue()->print(dbgs()); - else if (Ty.hasValue()) + else if (Ty) dbgs() << "<nullptr>"; else dbgs() << "<none>"; dbgs() << "\n"; }); - return !Ty.hasValue() || Ty.getValue(); + return !Ty || Ty.getValue(); }; if (!A.checkForAllCallSites(CallSiteCheck, *this, true, @@ -6531,7 +6781,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { PrivatizableType = identifyPrivatizableType(A); - if (!PrivatizableType.hasValue()) + if (!PrivatizableType) return ChangeStatus::UNCHANGED; if (!PrivatizableType.getValue()) return indicatePessimisticFixpoint(); @@ -6543,8 +6793,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { // Avoid arguments with padding for now. if (!getIRPosition().hasAttr(Attribute::ByVal) && - !ArgumentPromotionPass::isDenselyPacked(PrivatizableType.getValue(), - A.getInfoCache().getDL())) { + !isDenselyPacked(*PrivatizableType, A.getInfoCache().getDL())) { LLVM_DEBUG(dbgs() << "[AAPrivatizablePtr] Padding detected\n"); return indicatePessimisticFixpoint(); } @@ -6552,7 +6801,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { // Collect the types that will replace the privatizable type in the function // signature. SmallVector<Type *, 16> ReplacementTypes; - identifyReplacementTypes(PrivatizableType.getValue(), ReplacementTypes); + identifyReplacementTypes(*PrivatizableType, ReplacementTypes); // Verify callee and caller agree on how the promoted argument would be // passed. @@ -6620,7 +6869,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { *this, IRPosition::argument(CBArg), DepClassTy::REQUIRED); if (CBArgPrivAA.isValidState()) { auto CBArgPrivTy = CBArgPrivAA.getPrivatizableType(); - if (!CBArgPrivTy.hasValue()) + if (!CBArgPrivTy) continue; if (CBArgPrivTy.getValue() == PrivatizableType) continue; @@ -6667,7 +6916,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { DepClassTy::REQUIRED); if (DCArgPrivAA.isValidState()) { auto DCArgPrivTy = DCArgPrivAA.getPrivatizableType(); - if (!DCArgPrivTy.hasValue()) + if (!DCArgPrivTy) return true; if (DCArgPrivTy.getValue() == PrivatizableType) return true; @@ -6809,7 +7058,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { /// See AbstractAttribute::manifest(...) 
ChangeStatus manifest(Attributor &A) override { - if (!PrivatizableType.hasValue()) + if (!PrivatizableType) return ChangeStatus::UNCHANGED; assert(PrivatizableType.getValue() && "Expected privatizable type!"); @@ -6868,8 +7117,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { // When no alignment is specified for the load instruction, // natural alignment is assumed. createReplacementValues( - assumeAligned(AlignAA.getAssumedAlign()), - PrivatizableType.getValue(), ACS, + AlignAA.getAssumedAlign(), *PrivatizableType, ACS, ACS.getCallArgOperand(ARI.getReplacedArg().getArgNo()), NewArgOperands); }; @@ -6877,7 +7125,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { // Collect the types that will replace the privatizable type in the function // signature. SmallVector<Type *, 16> ReplacementTypes; - identifyReplacementTypes(PrivatizableType.getValue(), ReplacementTypes); + identifyReplacementTypes(*PrivatizableType, ReplacementTypes); // Register a rewrite of the argument. if (A.registerFunctionSignatureRewrite(*Arg, ReplacementTypes, @@ -6924,7 +7172,7 @@ struct AAPrivatizablePtrFloating : public AAPrivatizablePtrImpl { auto &PrivArgAA = A.getAAFor<AAPrivatizablePtr>( *this, IRPosition::argument(*Arg), DepClassTy::REQUIRED); if (PrivArgAA.isAssumedPrivatizablePtr()) - return Obj->getType()->getPointerElementType(); + return PrivArgAA.getPrivatizableType(); } LLVM_DEBUG(dbgs() << "[AAPrivatizablePtr] Underlying object neither valid " @@ -6953,7 +7201,7 @@ struct AAPrivatizablePtrCallSiteArgument final /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { PrivatizableType = identifyPrivatizableType(A); - if (!PrivatizableType.hasValue()) + if (!PrivatizableType) return ChangeStatus::UNCHANGED; if (!PrivatizableType.getValue()) return indicatePessimisticFixpoint(); @@ -7019,10 +7267,12 @@ struct AAPrivatizablePtrReturned final : public AAPrivatizablePtrFloating { STATS_DECLTRACK_FNRET_ATTR(privatizable_ptr); } }; +} // namespace /// -------------------- Memory Behavior Attributes ---------------------------- /// Includes read-none, read-only, and write-only. 
/// ---------------------------------------------------------------------------- +namespace { struct AAMemoryBehaviorImpl : public AAMemoryBehavior { AAMemoryBehaviorImpl(const IRPosition &IRP, Attributor &A) : AAMemoryBehavior(IRP, A) {} @@ -7522,6 +7772,7 @@ void AAMemoryBehaviorFloating::analyzeUseIn(Attributor &A, const Use &U, if (UserI->mayWriteToMemory()) removeAssumedBits(NO_WRITES); } +} // namespace /// -------------------- Memory Locations Attributes --------------------------- /// Includes read-none, argmemonly, inaccessiblememonly, @@ -7555,6 +7806,7 @@ std::string AAMemoryLocation::getMemoryLocationsAsStr( return S; } +namespace { struct AAMemoryLocationImpl : public AAMemoryLocation { AAMemoryLocationImpl(const IRPosition &IRP, Attributor &A) @@ -7802,7 +8054,7 @@ void AAMemoryLocationImpl::categorizePtrValue( bool UsedAssumedInformation = false; if (!AA::getAssumedUnderlyingObjects(A, Ptr, Objects, *this, &I, UsedAssumedInformation, - /* Intraprocedural */ true)) { + AA::Intraprocedural)) { LLVM_DEBUG( dbgs() << "[AAMemoryLocation] Pointer locations not categorized\n"); updateStateAndAccessesMap(State, NO_UNKOWN_MEM, &I, nullptr, Changed, @@ -8071,9 +8323,11 @@ struct AAMemoryLocationCallSite final : AAMemoryLocationImpl { STATS_DECLTRACK_CS_ATTR(readnone) } }; +} // namespace /// ------------------ Value Constant Range Attribute ------------------------- +namespace { struct AAValueConstantRangeImpl : AAValueConstantRange { using StateType = IntegerRangeState; AAValueConstantRangeImpl(const IRPosition &IRP, Attributor &A) @@ -8408,7 +8662,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { const auto &SimplifiedLHS = A.getAssumedSimplified(IRPosition::value(*LHS, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedLHS.hasValue()) + if (!SimplifiedLHS) return true; if (!SimplifiedLHS.getValue()) return false; @@ -8417,7 +8671,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { const auto &SimplifiedRHS = A.getAssumedSimplified(IRPosition::value(*RHS, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedRHS.hasValue()) + if (!SimplifiedRHS) return true; if (!SimplifiedRHS.getValue()) return false; @@ -8461,7 +8715,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { const auto &SimplifiedOpV = A.getAssumedSimplified(IRPosition::value(*OpV, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedOpV.hasValue()) + if (!SimplifiedOpV) return true; if (!SimplifiedOpV.getValue()) return false; @@ -8491,7 +8745,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { const auto &SimplifiedLHS = A.getAssumedSimplified(IRPosition::value(*LHS, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedLHS.hasValue()) + if (!SimplifiedLHS) return true; if (!SimplifiedLHS.getValue()) return false; @@ -8500,7 +8754,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { const auto &SimplifiedRHS = A.getAssumedSimplified(IRPosition::value(*RHS, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedRHS.hasValue()) + if (!SimplifiedRHS) return true; if (!SimplifiedRHS.getValue()) return false; @@ -8565,7 +8819,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl { const auto &SimplifiedOpV = A.getAssumedSimplified(IRPosition::value(V, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedOpV.hasValue()) + if (!SimplifiedOpV) return true; if (!SimplifiedOpV.getValue()) return false; @@ -8714,21 
+8968,23 @@ struct AAValueConstantRangeCallSiteArgument : AAValueConstantRangeFloating { STATS_DECLTRACK_CSARG_ATTR(value_range) } }; +} // namespace /// ------------------ Potential Values Attribute ------------------------- -struct AAPotentialValuesImpl : AAPotentialValues { +namespace { +struct AAPotentialConstantValuesImpl : AAPotentialConstantValues { using StateType = PotentialConstantIntValuesState; - AAPotentialValuesImpl(const IRPosition &IRP, Attributor &A) - : AAPotentialValues(IRP, A) {} + AAPotentialConstantValuesImpl(const IRPosition &IRP, Attributor &A) + : AAPotentialConstantValues(IRP, A) {} /// See AbstractAttribute::initialize(..). void initialize(Attributor &A) override { if (A.hasSimplificationCallback(getIRPosition())) indicatePessimisticFixpoint(); else - AAPotentialValues::initialize(A); + AAPotentialConstantValues::initialize(A); } /// See AbstractAttribute::getAsStr(). @@ -8745,13 +9001,14 @@ struct AAPotentialValuesImpl : AAPotentialValues { } }; -struct AAPotentialValuesArgument final - : AAArgumentFromCallSiteArguments<AAPotentialValues, AAPotentialValuesImpl, +struct AAPotentialConstantValuesArgument final + : AAArgumentFromCallSiteArguments<AAPotentialConstantValues, + AAPotentialConstantValuesImpl, PotentialConstantIntValuesState> { - using Base = - AAArgumentFromCallSiteArguments<AAPotentialValues, AAPotentialValuesImpl, - PotentialConstantIntValuesState>; - AAPotentialValuesArgument(const IRPosition &IRP, Attributor &A) + using Base = AAArgumentFromCallSiteArguments<AAPotentialConstantValues, + AAPotentialConstantValuesImpl, + PotentialConstantIntValuesState>; + AAPotentialConstantValuesArgument(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} /// See AbstractAttribute::initialize(..). @@ -8769,11 +9026,12 @@ struct AAPotentialValuesArgument final } }; -struct AAPotentialValuesReturned - : AAReturnedFromReturnedValues<AAPotentialValues, AAPotentialValuesImpl> { - using Base = - AAReturnedFromReturnedValues<AAPotentialValues, AAPotentialValuesImpl>; - AAPotentialValuesReturned(const IRPosition &IRP, Attributor &A) +struct AAPotentialConstantValuesReturned + : AAReturnedFromReturnedValues<AAPotentialConstantValues, + AAPotentialConstantValuesImpl> { + using Base = AAReturnedFromReturnedValues<AAPotentialConstantValues, + AAPotentialConstantValuesImpl>; + AAPotentialConstantValuesReturned(const IRPosition &IRP, Attributor &A) : Base(IRP, A) {} /// See AbstractAttribute::trackStatistics() @@ -8782,13 +9040,13 @@ struct AAPotentialValuesReturned } }; -struct AAPotentialValuesFloating : AAPotentialValuesImpl { - AAPotentialValuesFloating(const IRPosition &IRP, Attributor &A) - : AAPotentialValuesImpl(IRP, A) {} +struct AAPotentialConstantValuesFloating : AAPotentialConstantValuesImpl { + AAPotentialConstantValuesFloating(const IRPosition &IRP, Attributor &A) + : AAPotentialConstantValuesImpl(IRP, A) {} /// See AbstractAttribute::initialize(..). 
void initialize(Attributor &A) override { - AAPotentialValuesImpl::initialize(A); + AAPotentialConstantValuesImpl::initialize(A); if (isAtFixpoint()) return; @@ -8814,7 +9072,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { indicatePessimisticFixpoint(); - LLVM_DEBUG(dbgs() << "[AAPotentialValues] We give up: " + LLVM_DEBUG(dbgs() << "[AAPotentialConstantValues] We give up: " << getAssociatedValue() << "\n"); } @@ -8922,7 +9180,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { const auto &SimplifiedLHS = A.getAssumedSimplified(IRPosition::value(*LHS, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedLHS.hasValue()) + if (!SimplifiedLHS) return ChangeStatus::UNCHANGED; if (!SimplifiedLHS.getValue()) return indicatePessimisticFixpoint(); @@ -8931,7 +9189,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { const auto &SimplifiedRHS = A.getAssumedSimplified(IRPosition::value(*RHS, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedRHS.hasValue()) + if (!SimplifiedRHS) return ChangeStatus::UNCHANGED; if (!SimplifiedRHS.getValue()) return indicatePessimisticFixpoint(); @@ -8940,18 +9198,18 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) return indicatePessimisticFixpoint(); - auto &LHSAA = A.getAAFor<AAPotentialValues>(*this, IRPosition::value(*LHS), - DepClassTy::REQUIRED); + auto &LHSAA = A.getAAFor<AAPotentialConstantValues>( + *this, IRPosition::value(*LHS), DepClassTy::REQUIRED); if (!LHSAA.isValidState()) return indicatePessimisticFixpoint(); - auto &RHSAA = A.getAAFor<AAPotentialValues>(*this, IRPosition::value(*RHS), - DepClassTy::REQUIRED); + auto &RHSAA = A.getAAFor<AAPotentialConstantValues>( + *this, IRPosition::value(*RHS), DepClassTy::REQUIRED); if (!RHSAA.isValidState()) return indicatePessimisticFixpoint(); - const DenseSet<APInt> &LHSAAPVS = LHSAA.getAssumedSet(); - const DenseSet<APInt> &RHSAAPVS = RHSAA.getAssumedSet(); + const SetTy &LHSAAPVS = LHSAA.getAssumedSet(); + const SetTy &RHSAAPVS = RHSAA.getAssumedSet(); // TODO: make use of undef flag to limit potential values aggressively. bool MaybeTrue = false, MaybeFalse = false; @@ -9005,7 +9263,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { const auto &SimplifiedLHS = A.getAssumedSimplified(IRPosition::value(*LHS, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedLHS.hasValue()) + if (!SimplifiedLHS) return ChangeStatus::UNCHANGED; if (!SimplifiedLHS.getValue()) return indicatePessimisticFixpoint(); @@ -9014,7 +9272,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { const auto &SimplifiedRHS = A.getAssumedSimplified(IRPosition::value(*RHS, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedRHS.hasValue()) + if (!SimplifiedRHS) return ChangeStatus::UNCHANGED; if (!SimplifiedRHS.getValue()) return indicatePessimisticFixpoint(); @@ -9028,21 +9286,21 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { // Check if we only need one operand. 
bool OnlyLeft = false, OnlyRight = false; - if (C.hasValue() && *C && (*C)->isOneValue()) + if (C && *C && (*C)->isOneValue()) OnlyLeft = true; - else if (C.hasValue() && *C && (*C)->isZeroValue()) + else if (C && *C && (*C)->isZeroValue()) OnlyRight = true; - const AAPotentialValues *LHSAA = nullptr, *RHSAA = nullptr; + const AAPotentialConstantValues *LHSAA = nullptr, *RHSAA = nullptr; if (!OnlyRight) { - LHSAA = &A.getAAFor<AAPotentialValues>(*this, IRPosition::value(*LHS), - DepClassTy::REQUIRED); + LHSAA = &A.getAAFor<AAPotentialConstantValues>( + *this, IRPosition::value(*LHS), DepClassTy::REQUIRED); if (!LHSAA->isValidState()) return indicatePessimisticFixpoint(); } if (!OnlyLeft) { - RHSAA = &A.getAAFor<AAPotentialValues>(*this, IRPosition::value(*RHS), - DepClassTy::REQUIRED); + RHSAA = &A.getAAFor<AAPotentialConstantValues>( + *this, IRPosition::value(*RHS), DepClassTy::REQUIRED); if (!RHSAA->isValidState()) return indicatePessimisticFixpoint(); } @@ -9080,17 +9338,17 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { const auto &SimplifiedSrc = A.getAssumedSimplified(IRPosition::value(*Src, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedSrc.hasValue()) + if (!SimplifiedSrc) return ChangeStatus::UNCHANGED; if (!SimplifiedSrc.getValue()) return indicatePessimisticFixpoint(); Src = *SimplifiedSrc; - auto &SrcAA = A.getAAFor<AAPotentialValues>(*this, IRPosition::value(*Src), - DepClassTy::REQUIRED); + auto &SrcAA = A.getAAFor<AAPotentialConstantValues>( + *this, IRPosition::value(*Src), DepClassTy::REQUIRED); if (!SrcAA.isValidState()) return indicatePessimisticFixpoint(); - const DenseSet<APInt> &SrcAAPVS = SrcAA.getAssumedSet(); + const SetTy &SrcAAPVS = SrcAA.getAssumedSet(); if (SrcAA.undefIsContained()) unionAssumedWithUndef(); else { @@ -9113,7 +9371,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { const auto &SimplifiedLHS = A.getAssumedSimplified(IRPosition::value(*LHS, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedLHS.hasValue()) + if (!SimplifiedLHS) return ChangeStatus::UNCHANGED; if (!SimplifiedLHS.getValue()) return indicatePessimisticFixpoint(); @@ -9122,7 +9380,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { const auto &SimplifiedRHS = A.getAssumedSimplified(IRPosition::value(*RHS, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedRHS.hasValue()) + if (!SimplifiedRHS) return ChangeStatus::UNCHANGED; if (!SimplifiedRHS.getValue()) return indicatePessimisticFixpoint(); @@ -9131,18 +9389,18 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) return indicatePessimisticFixpoint(); - auto &LHSAA = A.getAAFor<AAPotentialValues>(*this, IRPosition::value(*LHS), - DepClassTy::REQUIRED); + auto &LHSAA = A.getAAFor<AAPotentialConstantValues>( + *this, IRPosition::value(*LHS), DepClassTy::REQUIRED); if (!LHSAA.isValidState()) return indicatePessimisticFixpoint(); - auto &RHSAA = A.getAAFor<AAPotentialValues>(*this, IRPosition::value(*RHS), - DepClassTy::REQUIRED); + auto &RHSAA = A.getAAFor<AAPotentialConstantValues>( + *this, IRPosition::value(*RHS), DepClassTy::REQUIRED); if (!RHSAA.isValidState()) return indicatePessimisticFixpoint(); - const DenseSet<APInt> &LHSAAPVS = LHSAA.getAssumedSet(); - const DenseSet<APInt> &RHSAAPVS = RHSAA.getAssumedSet(); + const SetTy &LHSAAPVS = LHSAA.getAssumedSet(); + const SetTy &RHSAAPVS = RHSAA.getAssumedSet(); const APInt Zero = 
APInt(LHS->getType()->getIntegerBitWidth(), 0); // TODO: make use of undef flag to limit potential values aggressively. @@ -9181,13 +9439,13 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { const auto &SimplifiedIncomingValue = A.getAssumedSimplified( IRPosition::value(*IncomingValue, getCallBaseContext()), *this, UsedAssumedInformation); - if (!SimplifiedIncomingValue.hasValue()) + if (!SimplifiedIncomingValue) continue; if (!SimplifiedIncomingValue.getValue()) return indicatePessimisticFixpoint(); IncomingValue = *SimplifiedIncomingValue; - auto &PotentialValuesAA = A.getAAFor<AAPotentialValues>( + auto &PotentialValuesAA = A.getAAFor<AAPotentialConstantValues>( *this, IRPosition::value(*IncomingValue), DepClassTy::REQUIRED); if (!PotentialValuesAA.isValidState()) return indicatePessimisticFixpoint(); @@ -9200,30 +9458,6 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { : ChangeStatus::CHANGED; } - ChangeStatus updateWithLoad(Attributor &A, LoadInst &L) { - if (!L.getType()->isIntegerTy()) - return indicatePessimisticFixpoint(); - - auto Union = [&](Value &V) { - if (isa<UndefValue>(V)) { - unionAssumedWithUndef(); - return true; - } - if (ConstantInt *CI = dyn_cast<ConstantInt>(&V)) { - unionAssumed(CI->getValue()); - return true; - } - return false; - }; - auto AssumedBefore = getAssumed(); - - if (!AAValueSimplifyImpl::handleLoad(A, *this, L, Union)) - return indicatePessimisticFixpoint(); - - return AssumedBefore == getAssumed() ? ChangeStatus::UNCHANGED - : ChangeStatus::CHANGED; - } - /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { Value &V = getAssociatedValue(); @@ -9244,9 +9478,6 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { if (auto *PHI = dyn_cast<PHINode>(I)) return updateWithPHINode(A, PHI); - if (auto *L = dyn_cast<LoadInst>(I)) - return updateWithLoad(A, *L); - return indicatePessimisticFixpoint(); } @@ -9256,14 +9487,15 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { } }; -struct AAPotentialValuesFunction : AAPotentialValuesImpl { - AAPotentialValuesFunction(const IRPosition &IRP, Attributor &A) - : AAPotentialValuesImpl(IRP, A) {} +struct AAPotentialConstantValuesFunction : AAPotentialConstantValuesImpl { + AAPotentialConstantValuesFunction(const IRPosition &IRP, Attributor &A) + : AAPotentialConstantValuesImpl(IRP, A) {} /// See AbstractAttribute::initialize(...). 
ChangeStatus updateImpl(Attributor &A) override { - llvm_unreachable("AAPotentialValues(Function|CallSite)::updateImpl will " - "not be called"); + llvm_unreachable( + "AAPotentialConstantValues(Function|CallSite)::updateImpl will " + "not be called"); } /// See AbstractAttribute::trackStatistics() @@ -9272,9 +9504,9 @@ struct AAPotentialValuesFunction : AAPotentialValuesImpl { } }; -struct AAPotentialValuesCallSite : AAPotentialValuesFunction { - AAPotentialValuesCallSite(const IRPosition &IRP, Attributor &A) - : AAPotentialValuesFunction(IRP, A) {} +struct AAPotentialConstantValuesCallSite : AAPotentialConstantValuesFunction { + AAPotentialConstantValuesCallSite(const IRPosition &IRP, Attributor &A) + : AAPotentialConstantValuesFunction(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { @@ -9282,11 +9514,13 @@ struct AAPotentialValuesCallSite : AAPotentialValuesFunction { } }; -struct AAPotentialValuesCallSiteReturned - : AACallSiteReturnedFromReturned<AAPotentialValues, AAPotentialValuesImpl> { - AAPotentialValuesCallSiteReturned(const IRPosition &IRP, Attributor &A) - : AACallSiteReturnedFromReturned<AAPotentialValues, - AAPotentialValuesImpl>(IRP, A) {} +struct AAPotentialConstantValuesCallSiteReturned + : AACallSiteReturnedFromReturned<AAPotentialConstantValues, + AAPotentialConstantValuesImpl> { + AAPotentialConstantValuesCallSiteReturned(const IRPosition &IRP, + Attributor &A) + : AACallSiteReturnedFromReturned<AAPotentialConstantValues, + AAPotentialConstantValuesImpl>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { @@ -9294,13 +9528,15 @@ struct AAPotentialValuesCallSiteReturned } }; -struct AAPotentialValuesCallSiteArgument : AAPotentialValuesFloating { - AAPotentialValuesCallSiteArgument(const IRPosition &IRP, Attributor &A) - : AAPotentialValuesFloating(IRP, A) {} +struct AAPotentialConstantValuesCallSiteArgument + : AAPotentialConstantValuesFloating { + AAPotentialConstantValuesCallSiteArgument(const IRPosition &IRP, + Attributor &A) + : AAPotentialConstantValuesFloating(IRP, A) {} /// See AbstractAttribute::initialize(..). void initialize(Attributor &A) override { - AAPotentialValuesImpl::initialize(A); + AAPotentialConstantValuesImpl::initialize(A); if (isAtFixpoint()) return; @@ -9323,8 +9559,8 @@ struct AAPotentialValuesCallSiteArgument : AAPotentialValuesFloating { ChangeStatus updateImpl(Attributor &A) override { Value &V = getAssociatedValue(); auto AssumedBefore = getAssumed(); - auto &AA = A.getAAFor<AAPotentialValues>(*this, IRPosition::value(V), - DepClassTy::REQUIRED); + auto &AA = A.getAAFor<AAPotentialConstantValues>( + *this, IRPosition::value(V), DepClassTy::REQUIRED); const auto &S = AA.getAssumed(); unionAssumed(S); return AssumedBefore == getAssumed() ? ChangeStatus::UNCHANGED @@ -9396,7 +9632,7 @@ struct AANoUndefImpl : AANoUndef { // considered to be dead. We don't manifest noundef in such positions for // the same reason above. 
if (!A.getAssumedSimplified(getIRPosition(), *this, UsedAssumedInformation) - .hasValue()) + .has_value()) return ChangeStatus::UNCHANGED; return AANoUndef::manifest(A); } @@ -9564,7 +9800,9 @@ struct AACallEdgesCallSite : public AACallEdgesImpl { CallBase *CB = cast<CallBase>(getCtxI()); if (CB->isInlineAsm()) { - setHasUnknownCallee(false, Change); + if (!hasAssumption(*CB->getCaller(), "ompx_no_call_asm") && + !hasAssumption(*CB, "ompx_no_call_asm")) + setHasUnknownCallee(false, Change); return Change; } @@ -9691,7 +9929,7 @@ private: ArrayRef<const AACallEdges *> AAEdgesList, const Function &Fn) { Optional<bool> Cached = isCachedReachable(Fn); - if (Cached.hasValue()) + if (Cached) return Cached.getValue(); // The query was not cached, thus it is new. We need to request an update @@ -9726,6 +9964,10 @@ private: const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges(); for (Function *Edge : Edges) { + // Functions that do not call back into the module can be ignored. + if (Edge->hasFnAttribute(Attribute::NoCallback)) + continue; + // We don't need a dependency if the result is reachable. const AAFunctionReachability &EdgeReachability = A.getAAFor<AAFunctionReachability>( @@ -9855,22 +10097,21 @@ public: } // Update the Instruction queries. - const AAReachability *Reachability; if (!InstQueries.empty()) { - Reachability = &A.getAAFor<AAReachability>( + const AAReachability *Reachability = &A.getAAFor<AAReachability>( *this, IRPosition::function(*getAssociatedFunction()), DepClassTy::REQUIRED); - } - // Check for local callbases first. - for (auto &InstPair : InstQueries) { - SmallVector<const AACallEdges *> CallEdges; - bool AllKnown = - getReachableCallEdges(A, *Reachability, *InstPair.first, CallEdges); - // Update will return change if we this effects any queries. - if (!AllKnown) - InstPair.second.CanReachUnknownCallee = true; - Change |= InstPair.second.update(A, *this, CallEdges); + // Check for local callbases first. + for (auto &InstPair : InstQueries) { + SmallVector<const AACallEdges *> CallEdges; + bool AllKnown = + getReachableCallEdges(A, *Reachability, *InstPair.first, CallEdges); + // Update will return change if we this effects any queries. + if (!AllKnown) + InstPair.second.CanReachUnknownCallee = true; + Change |= InstPair.second.update(A, *this, CallEdges); + } } return Change; @@ -9897,13 +10138,15 @@ private: /// Used to answer if a call base inside this function can reach a specific /// function. - DenseMap<const CallBase *, QueryResolver> CBQueries; + MapVector<const CallBase *, QueryResolver> CBQueries; /// This is for instruction queries than scan "forward". 
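The CBQueries and InstQueries members at this point in the diff move from DenseMap to MapVector, which keeps hashed lookup but iterates in insertion order, making the reachability update loop deterministic from run to run. A small sketch of that difference, assuming an LLVM build to compile against; the keys and values are placeholders.

// mapvector_order.cpp -- illustrative sketch; link against LLVM to build.
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // DenseMap iteration order depends on hashing and growth history.
  llvm::DenseMap<int, const char *> DM;
  // MapVector pairs a map index with a vector, so iteration follows
  // insertion order, the property the QueryResolver maps above rely on.
  llvm::MapVector<int, const char *> MV;

  for (int Key : {30, 10, 20}) {
    DM[Key] = "value";
    MV[Key] = "value";
  }

  llvm::outs() << "MapVector keys in insertion order:";
  for (auto &Entry : MV)
    llvm::outs() << ' ' << Entry.first; // always prints 30 10 20
  llvm::outs() << '\n';
  return 0;
}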
- DenseMap<const Instruction *, QueryResolver> InstQueries; + MapVector<const Instruction *, QueryResolver> InstQueries; }; +} // namespace /// ---------------------- Assumption Propagation ------------------------------ +namespace { struct AAAssumptionInfoImpl : public AAAssumptionInfo { AAAssumptionInfoImpl(const IRPosition &IRP, Attributor &A, const DenseSet<StringRef> &Known) @@ -10037,6 +10280,7 @@ private: return Assumptions; } }; +} // namespace AACallGraphNode *AACallEdgeIterator::operator*() const { return static_cast<AACallGraphNode *>(const_cast<AACallEdges *>( @@ -10059,6 +10303,7 @@ const char AANoReturn::ID = 0; const char AAIsDead::ID = 0; const char AADereferenceable::ID = 0; const char AAAlign::ID = 0; +const char AAInstanceInfo::ID = 0; const char AANoCapture::ID = 0; const char AAValueSimplify::ID = 0; const char AAHeapToStack::ID = 0; @@ -10066,7 +10311,7 @@ const char AAPrivatizablePtr::ID = 0; const char AAMemoryBehavior::ID = 0; const char AAMemoryLocation::ID = 0; const char AAValueConstantRange::ID = 0; -const char AAPotentialValues::ID = 0; +const char AAPotentialConstantValues::ID = 0; const char AANoUndef::ID = 0; const char AACallEdges::ID = 0; const char AAFunctionReachability::ID = 0; @@ -10181,9 +10426,10 @@ CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoAlias) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPrivatizablePtr) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AADereferenceable) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAAlign) +CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAInstanceInfo) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoCapture) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAValueConstantRange) -CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPotentialValues) +CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPotentialConstantValues) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AANoUndef) CREATE_VALUE_ABSTRACT_ATTRIBUTE_FOR_POSITION(AAPointerInfo) diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/BlockExtractor.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/BlockExtractor.cpp index 7c178f9a9834..9e27ae49a901 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/BlockExtractor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/BlockExtractor.cpp @@ -135,7 +135,8 @@ void BlockExtractor::loadFile() { if (LineSplit.empty()) continue; if (LineSplit.size()!=2) - report_fatal_error("Invalid line format, expecting lines like: 'funcname bb1[;bb2..]'"); + report_fatal_error("Invalid line format, expecting lines like: 'funcname bb1[;bb2..]'", + /*GenCrashDiag=*/false); SmallVector<StringRef, 4> BBNames; LineSplit[1].split(BBNames, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false); @@ -194,13 +195,15 @@ bool BlockExtractor::runOnModule(Module &M) { for (const auto &BInfo : BlocksByName) { Function *F = M.getFunction(BInfo.first); if (!F) - report_fatal_error("Invalid function name specified in the input file"); + report_fatal_error("Invalid function name specified in the input file", + /*GenCrashDiag=*/false); for (const auto &BBInfo : BInfo.second) { auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) { return BB.getName().equals(BBInfo); }); if (Res == F->end()) - report_fatal_error("Invalid block name specified in the input file"); + report_fatal_error("Invalid block name specified in the input file", + /*GenCrashDiag=*/false); GroupsOfBlocks[NextGroupIdx].push_back(&*Res); } ++NextGroupIdx; @@ -212,7 +215,7 @@ bool BlockExtractor::runOnModule(Module &M) { for (BasicBlock *BB : BBs) { // Check if the module contains BB. 
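The BlockExtractor hunks in this range pass /*GenCrashDiag=*/false so that malformed user input produces a plain fatal error rather than LLVM's crash diagnostics and backtrace. A hedged sketch of the distinction; the line check below is a made-up stand-in for the pass's real parsing.

// gencrashdiag.cpp -- illustrative sketch; link against LLVMSupport to build.
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"

// Hypothetical helper loosely mirroring BlockExtractor::loadFile()'s checks.
static void checkLine(llvm::StringRef Line) {
  if (!Line.contains(' '))
    // User error, not a compiler bug: suppress the crash-diagnostic banner
    // and backtrace by passing GenCrashDiag = false.
    llvm::report_fatal_error(
        "Invalid line format, expecting lines like: 'funcname bb1[;bb2..]'",
        /*GenCrashDiag=*/false);
}

int main() {
  checkLine("foo bb1;bb2"); // fine
  checkLine("malformed");   // aborts with a plain error message
  return 0;
}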
if (BB->getParent()->getParent() != &M) - report_fatal_error("Invalid basic block"); + report_fatal_error("Invalid basic block", /*GenCrashDiag=*/false); LLVM_DEBUG(dbgs() << "BlockExtractor: Extracting " << BB->getParent()->getName() << ":" << BB->getName() << "\n"); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp index 927dceec8865..64bfcb2a9a9f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp @@ -19,11 +19,13 @@ #include "llvm/Transforms/IPO/CalledValuePropagation.h" #include "llvm/Analysis/SparsePropagation.h" #include "llvm/Analysis/ValueLatticeUtils.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/MDBuilder.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" + using namespace llvm; #define DEBUG_TYPE "called-value-propagation" @@ -68,7 +70,7 @@ public: } }; - CVPLatticeVal() : LatticeState(Undefined) {} + CVPLatticeVal() = default; CVPLatticeVal(CVPLatticeStateTy LatticeState) : LatticeState(LatticeState) {} CVPLatticeVal(std::vector<Function *> &&Functions) : LatticeState(FunctionSet), Functions(std::move(Functions)) { @@ -94,7 +96,7 @@ public: private: /// Holds the state this lattice value is in. - CVPLatticeStateTy LatticeState; + CVPLatticeStateTy LatticeState = Undefined; /// Holds functions indicating the possible targets of call sites. This set /// is empty for lattice values in the undefined, overdefined, and untracked diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ConstantMerge.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ConstantMerge.cpp index 178d3f41963e..73af30ece47c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ConstantMerge.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ConstantMerge.cpp @@ -85,7 +85,7 @@ static void copyDebugLocMetadata(const GlobalVariable *From, } static Align getAlign(GlobalVariable *GV) { - return GV->getAlign().getValueOr( + return GV->getAlign().value_or( GV->getParent()->getDataLayout().getPreferredAlign(GV)); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp index 2fe9a59ad210..dfe33ac9da0d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp @@ -15,21 +15,16 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Triple.h" -#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalObject.h" -#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Operator.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp index 2a6e38b0437f..99fa4baf355d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ 
b/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -16,18 +16,17 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/IPO/DeadArgumentElimination.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DIBuilder.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" @@ -44,9 +43,9 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/DeadArgumentElimination.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include <cassert> -#include <cstdint> #include <utility> #include <vector> @@ -55,36 +54,36 @@ using namespace llvm; #define DEBUG_TYPE "deadargelim" STATISTIC(NumArgumentsEliminated, "Number of unread args removed"); -STATISTIC(NumRetValsEliminated , "Number of unused return values removed"); -STATISTIC(NumArgumentsReplacedWithUndef, - "Number of unread args replaced with undef"); +STATISTIC(NumRetValsEliminated, "Number of unused return values removed"); +STATISTIC(NumArgumentsReplacedWithPoison, + "Number of unread args replaced with poison"); namespace { - /// DAE - The dead argument elimination pass. - class DAE : public ModulePass { - protected: - // DAH uses this to specify a different ID. - explicit DAE(char &ID) : ModulePass(ID) {} +/// The dead argument elimination pass. +class DAE : public ModulePass { +protected: + // DAH uses this to specify a different ID. + explicit DAE(char &ID) : ModulePass(ID) {} - public: - static char ID; // Pass identification, replacement for typeid +public: + static char ID; // Pass identification, replacement for typeid - DAE() : ModulePass(ID) { - initializeDAEPass(*PassRegistry::getPassRegistry()); - } + DAE() : ModulePass(ID) { + initializeDAEPass(*PassRegistry::getPassRegistry()); + } - bool runOnModule(Module &M) override { - if (skipModule(M)) - return false; - DeadArgumentEliminationPass DAEP(ShouldHackArguments()); - ModuleAnalysisManager DummyMAM; - PreservedAnalyses PA = DAEP.run(M, DummyMAM); - return !PA.areAllPreserved(); - } + bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + DeadArgumentEliminationPass DAEP(shouldHackArguments()); + ModuleAnalysisManager DummyMAM; + PreservedAnalyses PA = DAEP.run(M, DummyMAM); + return !PA.areAllPreserved(); + } - virtual bool ShouldHackArguments() const { return false; } - }; + virtual bool shouldHackArguments() const { return false; } +}; } // end anonymous namespace @@ -94,51 +93,51 @@ INITIALIZE_PASS(DAE, "deadargelim", "Dead Argument Elimination", false, false) namespace { - /// DAH - DeadArgumentHacking pass - Same as dead argument elimination, but - /// deletes arguments to functions which are external. This is only for use - /// by bugpoint. - struct DAH : public DAE { - static char ID; +/// The DeadArgumentHacking pass, same as dead argument elimination, but deletes +/// arguments to functions which are external. This is only for use by bugpoint. 
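The statistic above becomes NumArgumentsReplacedWithPoison because the pass now substitutes poison rather than undef for arguments it proves dead (see the replaceAllUsesWith and setArgOperand hunks further down in this file). A reduced, self-contained sketch of that idiom, not the pass itself; the module and functions are invented for illustration.

// poison_dead_args.cpp -- illustrative sketch; link against LLVMCore to build.
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("poison-demo", Ctx);
  IRBuilder<> B(Ctx);

  // declare void @callee(i32)   ; pretend the i32 parameter was proven unused.
  FunctionType *CalleeTy =
      FunctionType::get(B.getVoidTy(), {B.getInt32Ty()}, /*isVarArg=*/false);
  Function *Callee =
      Function::Create(CalleeTy, Function::ExternalLinkage, "callee", &M);

  // define void @caller() { call void @callee(i32 7); ret void }
  Function *Caller = Function::Create(FunctionType::get(B.getVoidTy(), false),
                                      Function::ExternalLinkage, "caller", &M);
  B.SetInsertPoint(BasicBlock::Create(Ctx, "entry", Caller));
  CallInst *CI = B.CreateCall(Callee, {B.getInt32(7)});
  B.CreateRetVoid();

  // The idiom this commit adopts: feed poison, not undef, into the dead slot.
  CI->setArgOperand(0, PoisonValue::get(CI->getArgOperand(0)->getType()));

  M.print(outs(), nullptr);
  return 0;
}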
+struct DAH : public DAE { + static char ID; - DAH() : DAE(ID) {} + DAH() : DAE(ID) {} - bool ShouldHackArguments() const override { return true; } - }; + bool shouldHackArguments() const override { return true; } +}; } // end anonymous namespace char DAH::ID = 0; INITIALIZE_PASS(DAH, "deadarghaX0r", - "Dead Argument Hacking (BUGPOINT USE ONLY; DO NOT USE)", - false, false) + "Dead Argument Hacking (BUGPOINT USE ONLY; DO NOT USE)", false, + false) -/// createDeadArgEliminationPass - This pass removes arguments from functions -/// which are not used by the body of the function. +/// This pass removes arguments from functions which are not used by the body of +/// the function. ModulePass *llvm::createDeadArgEliminationPass() { return new DAE(); } ModulePass *llvm::createDeadArgHackingPass() { return new DAH(); } -/// DeleteDeadVarargs - If this is an function that takes a ... list, and if -/// llvm.vastart is never called, the varargs list is dead for the function. -bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) { - assert(Fn.getFunctionType()->isVarArg() && "Function isn't varargs!"); - if (Fn.isDeclaration() || !Fn.hasLocalLinkage()) return false; +/// If this is an function that takes a ... list, and if llvm.vastart is never +/// called, the varargs list is dead for the function. +bool DeadArgumentEliminationPass::deleteDeadVarargs(Function &F) { + assert(F.getFunctionType()->isVarArg() && "Function isn't varargs!"); + if (F.isDeclaration() || !F.hasLocalLinkage()) + return false; // Ensure that the function is only directly called. - if (Fn.hasAddressTaken()) + if (F.hasAddressTaken()) return false; // Don't touch naked functions. The assembly might be using an argument, or // otherwise rely on the frame layout in a way that this analysis will not // see. - if (Fn.hasFnAttribute(Attribute::Naked)) { + if (F.hasFnAttribute(Attribute::Naked)) { return false; } // Okay, we know we can transform this function if safe. Scan its body // looking for calls marked musttail or calls to llvm.vastart. - for (BasicBlock &BB : Fn) { + for (BasicBlock &BB : F) { for (Instruction &I : BB) { CallInst *CI = dyn_cast<CallInst>(&I); if (!CI) @@ -157,25 +156,24 @@ bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) { // Start by computing a new prototype for the function, which is the same as // the old function, but doesn't have isVarArg set. - FunctionType *FTy = Fn.getFunctionType(); + FunctionType *FTy = F.getFunctionType(); std::vector<Type *> Params(FTy->param_begin(), FTy->param_end()); - FunctionType *NFTy = FunctionType::get(FTy->getReturnType(), - Params, false); + FunctionType *NFTy = FunctionType::get(FTy->getReturnType(), Params, false); unsigned NumArgs = Params.size(); // Create the new function body and insert it into the module... - Function *NF = Function::Create(NFTy, Fn.getLinkage(), Fn.getAddressSpace()); - NF->copyAttributesFrom(&Fn); - NF->setComdat(Fn.getComdat()); - Fn.getParent()->getFunctionList().insert(Fn.getIterator(), NF); - NF->takeName(&Fn); + Function *NF = Function::Create(NFTy, F.getLinkage(), F.getAddressSpace()); + NF->copyAttributesFrom(&F); + NF->setComdat(F.getComdat()); + F.getParent()->getFunctionList().insert(F.getIterator(), NF); + NF->takeName(&F); - // Loop over all of the callers of the function, transforming the call sites + // Loop over all the callers of the function, transforming the call sites // to pass in a smaller number of arguments into the new function. 
// std::vector<Value *> Args; - for (User *U : llvm::make_early_inc_range(Fn.users())) { + for (User *U : llvm::make_early_inc_range(F.users())) { CallBase *CB = dyn_cast<CallBase>(U); if (!CB) continue; @@ -189,7 +187,7 @@ bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) { SmallVector<AttributeSet, 8> ArgAttrs; for (unsigned ArgNo = 0; ArgNo < NumArgs; ++ArgNo) ArgAttrs.push_back(PAL.getParamAttrs(ArgNo)); - PAL = AttributeList::get(Fn.getContext(), PAL.getFnAttrs(), + PAL = AttributeList::get(F.getContext(), PAL.getFnAttrs(), PAL.getRetAttrs(), ArgAttrs); } @@ -224,64 +222,67 @@ bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) { // Since we have now created the new function, splice the body of the old // function right into the new function, leaving the old rotting hulk of the // function empty. - NF->getBasicBlockList().splice(NF->begin(), Fn.getBasicBlockList()); + NF->getBasicBlockList().splice(NF->begin(), F.getBasicBlockList()); // Loop over the argument list, transferring uses of the old arguments over to - // the new arguments, also transferring over the names as well. While we're at - // it, remove the dead arguments from the DeadArguments list. - for (Function::arg_iterator I = Fn.arg_begin(), E = Fn.arg_end(), - I2 = NF->arg_begin(); I != E; ++I, ++I2) { + // the new arguments, also transferring over the names as well. While we're + // at it, remove the dead arguments from the DeadArguments list. + for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(), + I2 = NF->arg_begin(); + I != E; ++I, ++I2) { // Move the name and users over to the new version. I->replaceAllUsesWith(&*I2); I2->takeName(&*I); } - // Clone metadatas from the old function, including debug info descriptor. + // Clone metadata from the old function, including debug info descriptor. SmallVector<std::pair<unsigned, MDNode *>, 1> MDs; - Fn.getAllMetadata(MDs); + F.getAllMetadata(MDs); for (auto MD : MDs) NF->addMetadata(MD.first, *MD.second); // Fix up any BlockAddresses that refer to the function. - Fn.replaceAllUsesWith(ConstantExpr::getBitCast(NF, Fn.getType())); + F.replaceAllUsesWith(ConstantExpr::getBitCast(NF, F.getType())); // Delete the bitcast that we just created, so that NF does not // appear to be address-taken. NF->removeDeadConstantUsers(); // Finally, nuke the old function. - Fn.eraseFromParent(); + F.eraseFromParent(); return true; } -/// RemoveDeadArgumentsFromCallers - Checks if the given function has any -/// arguments that are unused, and changes the caller parameters to be undefined -/// instead. -bool DeadArgumentEliminationPass::RemoveDeadArgumentsFromCallers(Function &Fn) { +/// Checks if the given function has any arguments that are unused, and changes +/// the caller parameters to be poison instead. +bool DeadArgumentEliminationPass::removeDeadArgumentsFromCallers(Function &F) { // We cannot change the arguments if this TU does not define the function or // if the linker may choose a function body from another TU, even if the // nominal linkage indicates that other copies of the function have the same // semantics. In the below example, the dead load from %p may not have been - // eliminated from the linker-chosen copy of f, so replacing %p with undef + // eliminated from the linker-chosen copy of f, so replacing %p with poison // in callers may introduce undefined behavior. 
// // define linkonce_odr void @f(i32* %p) { // %v = load i32 %p // ret void // } - if (!Fn.hasExactDefinition()) + if (!F.hasExactDefinition()) return false; - // Functions with local linkage should already have been handled, except the - // fragile (variadic) ones which we can improve here. - if (Fn.hasLocalLinkage() && !Fn.getFunctionType()->isVarArg()) + // Functions with local linkage should already have been handled, except if + // they are fully alive (e.g., called indirectly) and except for the fragile + // (variadic) ones. In these cases, we may still be able to improve their + // statically known call sites. + if ((F.hasLocalLinkage() && !LiveFunctions.count(&F)) && + !F.getFunctionType()->isVarArg()) return false; // Don't touch naked functions. The assembly might be using an argument, or // otherwise rely on the frame layout in a way that this analysis will not // see. - if (Fn.hasFnAttribute(Attribute::Naked)) + if (F.hasFnAttribute(Attribute::Naked)) return false; - if (Fn.use_empty()) + if (F.use_empty()) return false; SmallVector<unsigned, 8> UnusedArgs; @@ -289,35 +290,36 @@ bool DeadArgumentEliminationPass::RemoveDeadArgumentsFromCallers(Function &Fn) { AttributeMask UBImplyingAttributes = AttributeFuncs::getUBImplyingAttributes(); - for (Argument &Arg : Fn.args()) { + for (Argument &Arg : F.args()) { if (!Arg.hasSwiftErrorAttr() && Arg.use_empty() && !Arg.hasPassPointeeByValueCopyAttr()) { if (Arg.isUsedByMetadata()) { - Arg.replaceAllUsesWith(UndefValue::get(Arg.getType())); + Arg.replaceAllUsesWith(PoisonValue::get(Arg.getType())); Changed = true; } UnusedArgs.push_back(Arg.getArgNo()); - Fn.removeParamAttrs(Arg.getArgNo(), UBImplyingAttributes); + F.removeParamAttrs(Arg.getArgNo(), UBImplyingAttributes); } } if (UnusedArgs.empty()) return false; - for (Use &U : Fn.uses()) { + for (Use &U : F.uses()) { CallBase *CB = dyn_cast<CallBase>(U.getUser()); - if (!CB || !CB->isCallee(&U)) + if (!CB || !CB->isCallee(&U) || + CB->getFunctionType() != F.getFunctionType()) continue; - // Now go through all unused args and replace them with "undef". + // Now go through all unused args and replace them with poison. for (unsigned I = 0, E = UnusedArgs.size(); I != E; ++I) { unsigned ArgNo = UnusedArgs[I]; Value *Arg = CB->getArgOperand(ArgNo); - CB->setArgOperand(ArgNo, UndefValue::get(Arg->getType())); + CB->setArgOperand(ArgNo, PoisonValue::get(Arg->getType())); CB->removeParamAttrs(ArgNo, UBImplyingAttributes); - ++NumArgumentsReplacedWithUndef; + ++NumArgumentsReplacedWithPoison; Changed = true; } } @@ -328,16 +330,15 @@ bool DeadArgumentEliminationPass::RemoveDeadArgumentsFromCallers(Function &Fn) { /// Convenience function that returns the number of return values. It returns 0 /// for void functions and 1 for functions not returning a struct. It returns /// the number of struct elements for functions returning a struct. -static unsigned NumRetVals(const Function *F) { +static unsigned numRetVals(const Function *F) { Type *RetTy = F->getReturnType(); if (RetTy->isVoidTy()) return 0; - else if (StructType *STy = dyn_cast<StructType>(RetTy)) + if (StructType *STy = dyn_cast<StructType>(RetTy)) return STy->getNumElements(); - else if (ArrayType *ATy = dyn_cast<ArrayType>(RetTy)) + if (ArrayType *ATy = dyn_cast<ArrayType>(RetTy)) return ATy->getNumElements(); - else - return 1; + return 1; } /// Returns the sub-type a function will return at a given Idx. 
Should @@ -349,20 +350,18 @@ static Type *getRetComponentType(const Function *F, unsigned Idx) { if (StructType *STy = dyn_cast<StructType>(RetTy)) return STy->getElementType(Idx); - else if (ArrayType *ATy = dyn_cast<ArrayType>(RetTy)) + if (ArrayType *ATy = dyn_cast<ArrayType>(RetTy)) return ATy->getElementType(); - else - return RetTy; + return RetTy; } -/// MarkIfNotLive - This checks Use for liveness in LiveValues. If Use is not -/// live, it adds Use to the MaybeLiveUses argument. Returns the determined -/// liveness of Use. +/// Checks Use for liveness in LiveValues. If Use is not live, it adds Use to +/// the MaybeLiveUses argument. Returns the determined liveness of Use. DeadArgumentEliminationPass::Liveness -DeadArgumentEliminationPass::MarkIfNotLive(RetOrArg Use, +DeadArgumentEliminationPass::markIfNotLive(RetOrArg Use, UseVector &MaybeLiveUses) { // We're live if our use or its Function is already marked as live. - if (IsLive(Use)) + if (isLive(Use)) return Live; // We're maybe live otherwise, but remember that we must become live if @@ -371,127 +370,127 @@ DeadArgumentEliminationPass::MarkIfNotLive(RetOrArg Use, return MaybeLive; } -/// SurveyUse - This looks at a single use of an argument or return value -/// and determines if it should be alive or not. Adds this use to MaybeLiveUses -/// if it causes the used value to become MaybeLive. +/// Looks at a single use of an argument or return value and determines if it +/// should be alive or not. Adds this use to MaybeLiveUses if it causes the +/// used value to become MaybeLive. /// /// RetValNum is the return value number to use when this use is used in a /// return instruction. This is used in the recursion, you should always leave /// it at 0. DeadArgumentEliminationPass::Liveness -DeadArgumentEliminationPass::SurveyUse(const Use *U, UseVector &MaybeLiveUses, +DeadArgumentEliminationPass::surveyUse(const Use *U, UseVector &MaybeLiveUses, unsigned RetValNum) { - const User *V = U->getUser(); - if (const ReturnInst *RI = dyn_cast<ReturnInst>(V)) { - // The value is returned from a function. It's only live when the - // function's return value is live. We use RetValNum here, for the case - // that U is really a use of an insertvalue instruction that uses the - // original Use. - const Function *F = RI->getParent()->getParent(); - if (RetValNum != -1U) { - RetOrArg Use = CreateRet(F, RetValNum); - // We might be live, depending on the liveness of Use. - return MarkIfNotLive(Use, MaybeLiveUses); - } else { - DeadArgumentEliminationPass::Liveness Result = MaybeLive; - for (unsigned Ri = 0; Ri < NumRetVals(F); ++Ri) { - RetOrArg Use = CreateRet(F, Ri); - // We might be live, depending on the liveness of Use. If any - // sub-value is live, then the entire value is considered live. This - // is a conservative choice, and better tracking is possible. - DeadArgumentEliminationPass::Liveness SubResult = - MarkIfNotLive(Use, MaybeLiveUses); - if (Result != Live) - Result = SubResult; - } - return Result; - } + const User *V = U->getUser(); + if (const ReturnInst *RI = dyn_cast<ReturnInst>(V)) { + // The value is returned from a function. It's only live when the + // function's return value is live. We use RetValNum here, for the case + // that U is really a use of an insertvalue instruction that uses the + // original Use. + const Function *F = RI->getParent()->getParent(); + if (RetValNum != -1U) { + RetOrArg Use = createRet(F, RetValNum); + // We might be live, depending on the liveness of Use. 
+ return markIfNotLive(Use, MaybeLiveUses); } - if (const InsertValueInst *IV = dyn_cast<InsertValueInst>(V)) { - if (U->getOperandNo() != InsertValueInst::getAggregateOperandIndex() - && IV->hasIndices()) - // The use we are examining is inserted into an aggregate. Our liveness - // depends on all uses of that aggregate, but if it is used as a return - // value, only index at which we were inserted counts. - RetValNum = *IV->idx_begin(); - - // Note that if we are used as the aggregate operand to the insertvalue, - // we don't change RetValNum, but do survey all our uses. - - Liveness Result = MaybeLive; - for (const Use &UU : IV->uses()) { - Result = SurveyUse(&UU, MaybeLiveUses, RetValNum); - if (Result == Live) - break; - } - return Result; + + DeadArgumentEliminationPass::Liveness Result = MaybeLive; + for (unsigned Ri = 0; Ri < numRetVals(F); ++Ri) { + RetOrArg Use = createRet(F, Ri); + // We might be live, depending on the liveness of Use. If any + // sub-value is live, then the entire value is considered live. This + // is a conservative choice, and better tracking is possible. + DeadArgumentEliminationPass::Liveness SubResult = + markIfNotLive(Use, MaybeLiveUses); + if (Result != Live) + Result = SubResult; + } + return Result; + } + + if (const InsertValueInst *IV = dyn_cast<InsertValueInst>(V)) { + if (U->getOperandNo() != InsertValueInst::getAggregateOperandIndex() && + IV->hasIndices()) + // The use we are examining is inserted into an aggregate. Our liveness + // depends on all uses of that aggregate, but if it is used as a return + // value, only index at which we were inserted counts. + RetValNum = *IV->idx_begin(); + + // Note that if we are used as the aggregate operand to the insertvalue, + // we don't change RetValNum, but do survey all our uses. + + Liveness Result = MaybeLive; + for (const Use &UU : IV->uses()) { + Result = surveyUse(&UU, MaybeLiveUses, RetValNum); + if (Result == Live) + break; } + return Result; + } - if (const auto *CB = dyn_cast<CallBase>(V)) { - const Function *F = CB->getCalledFunction(); - if (F) { - // Used in a direct call. + if (const auto *CB = dyn_cast<CallBase>(V)) { + const Function *F = CB->getCalledFunction(); + if (F) { + // Used in a direct call. - // The function argument is live if it is used as a bundle operand. - if (CB->isBundleOperand(U)) - return Live; + // The function argument is live if it is used as a bundle operand. + if (CB->isBundleOperand(U)) + return Live; - // Find the argument number. We know for sure that this use is an - // argument, since if it was the function argument this would be an - // indirect call and the we know can't be looking at a value of the - // label type (for the invoke instruction). - unsigned ArgNo = CB->getArgOperandNo(U); + // Find the argument number. We know for sure that this use is an + // argument, since if it was the function argument this would be an + // indirect call and that we know can't be looking at a value of the + // label type (for the invoke instruction). + unsigned ArgNo = CB->getArgOperandNo(U); - if (ArgNo >= F->getFunctionType()->getNumParams()) - // The value is passed in through a vararg! Must be live. - return Live; + if (ArgNo >= F->getFunctionType()->getNumParams()) + // The value is passed in through a vararg! Must be live. 
+ return Live; - assert(CB->getArgOperand(ArgNo) == CB->getOperand(U->getOperandNo()) && - "Argument is not where we expected it"); + assert(CB->getArgOperand(ArgNo) == CB->getOperand(U->getOperandNo()) && + "Argument is not where we expected it"); - // Value passed to a normal call. It's only live when the corresponding - // argument to the called function turns out live. - RetOrArg Use = CreateArg(F, ArgNo); - return MarkIfNotLive(Use, MaybeLiveUses); - } + // Value passed to a normal call. It's only live when the corresponding + // argument to the called function turns out live. + RetOrArg Use = createArg(F, ArgNo); + return markIfNotLive(Use, MaybeLiveUses); } - // Used in any other way? Value must be live. - return Live; + } + // Used in any other way? Value must be live. + return Live; } -/// SurveyUses - This looks at all the uses of the given value +/// Looks at all the uses of the given value /// Returns the Liveness deduced from the uses of this value. /// /// Adds all uses that cause the result to be MaybeLive to MaybeLiveRetUses. If /// the result is Live, MaybeLiveUses might be modified but its content should /// be ignored (since it might not be complete). DeadArgumentEliminationPass::Liveness -DeadArgumentEliminationPass::SurveyUses(const Value *V, +DeadArgumentEliminationPass::surveyUses(const Value *V, UseVector &MaybeLiveUses) { // Assume it's dead (which will only hold if there are no uses at all..). Liveness Result = MaybeLive; // Check each use. for (const Use &U : V->uses()) { - Result = SurveyUse(&U, MaybeLiveUses); + Result = surveyUse(&U, MaybeLiveUses); if (Result == Live) break; } return Result; } -// SurveyFunction - This performs the initial survey of the specified function, -// checking out whether or not it uses any of its incoming arguments or whether -// any callers use the return value. This fills in the LiveValues set and Uses -// map. -// -// We consider arguments of non-internal functions to be intrinsically alive as -// well as arguments to functions which have their "address taken". -void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { +/// Performs the initial survey of the specified function, checking out whether +/// it uses any of its incoming arguments or whether any callers use the return +/// value. This fills in the LiveValues set and Uses map. +/// +/// We consider arguments of non-internal functions to be intrinsically alive as +/// well as arguments to functions which have their "address taken". +void DeadArgumentEliminationPass::surveyFunction(const Function &F) { // Functions with inalloca/preallocated parameters are expecting args in a // particular register and memory layout. if (F.getAttributes().hasAttrSomewhere(Attribute::InAlloca) || F.getAttributes().hasAttrSomewhere(Attribute::Preallocated)) { - MarkLive(F); + markLive(F); return; } @@ -499,11 +498,11 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { // otherwise rely on the frame layout in a way that this analysis will not // see. 
if (F.hasFnAttribute(Attribute::Naked)) { - MarkLive(F); + markLive(F); return; } - unsigned RetCount = NumRetVals(&F); + unsigned RetCount = numRetVals(&F); // Assume all return values are dead using RetVals = SmallVector<Liveness, 5>; @@ -518,20 +517,10 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { RetUses MaybeLiveRetUses(RetCount); bool HasMustTailCalls = false; - - for (Function::const_iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { - if (const ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator())) { - if (RI->getNumOperands() != 0 && RI->getOperand(0)->getType() - != F.getFunctionType()->getReturnType()) { - // We don't support old style multiple return values. - MarkLive(F); - return; - } - } - + for (const BasicBlock &BB : F) { // If we have any returns of `musttail` results - the signature can't // change - if (BB->getTerminatingMustTailCall() != nullptr) + if (BB.getTerminatingMustTailCall() != nullptr) HasMustTailCalls = true; } @@ -541,7 +530,7 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { } if (!F.hasLocalLinkage() && (!ShouldHackArguments || F.isIntrinsic())) { - MarkLive(F); + markLive(F); return; } @@ -559,8 +548,9 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { // If the function is PASSED IN as an argument, its address has been // taken. const auto *CB = dyn_cast<CallBase>(U.getUser()); - if (!CB || !CB->isCallee(&U)) { - MarkLive(F); + if (!CB || !CB->isCallee(&U) || + CB->getFunctionType() != F.getFunctionType()) { + markLive(F); return; } @@ -577,13 +567,13 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { continue; // Check all uses of the return value. - for (const Use &U : CB->uses()) { - if (ExtractValueInst *Ext = dyn_cast<ExtractValueInst>(U.getUser())) { + for (const Use &UU : CB->uses()) { + if (ExtractValueInst *Ext = dyn_cast<ExtractValueInst>(UU.getUser())) { // This use uses a part of our return value, survey the uses of // that part and store the results for this index only. unsigned Idx = *Ext->idx_begin(); if (RetValLiveness[Idx] != Live) { - RetValLiveness[Idx] = SurveyUses(Ext, MaybeLiveRetUses[Idx]); + RetValLiveness[Idx] = surveyUses(Ext, MaybeLiveRetUses[Idx]); if (RetValLiveness[Idx] == Live) NumLiveRetVals++; } @@ -591,16 +581,16 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { // Used by something else than extractvalue. Survey, but assume that the // result applies to all sub-values. UseVector MaybeLiveAggregateUses; - if (SurveyUse(&U, MaybeLiveAggregateUses) == Live) { + if (surveyUse(&UU, MaybeLiveAggregateUses) == Live) { NumLiveRetVals = RetCount; RetValLiveness.assign(RetCount, Live); break; - } else { - for (unsigned Ri = 0; Ri != RetCount; ++Ri) { - if (RetValLiveness[Ri] != Live) - MaybeLiveRetUses[Ri].append(MaybeLiveAggregateUses.begin(), - MaybeLiveAggregateUses.end()); - } + } + + for (unsigned Ri = 0; Ri != RetCount; ++Ri) { + if (RetValLiveness[Ri] != Live) + MaybeLiveRetUses[Ri].append(MaybeLiveAggregateUses.begin(), + MaybeLiveAggregateUses.end()); } } } @@ -613,7 +603,7 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { // Now we've inspected all callers, record the liveness of our return values. 
for (unsigned Ri = 0; Ri != RetCount; ++Ri) - MarkValue(CreateRet(&F, Ri), RetValLiveness[Ri], MaybeLiveRetUses[Ri]); + markValue(createRet(&F, Ri), RetValLiveness[Ri], MaybeLiveRetUses[Ri]); LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - Inspecting args for fn: " << F.getName() << "\n"); @@ -641,81 +631,77 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { } else { // See what the effect of this use is (recording any uses that cause // MaybeLive in MaybeLiveArgUses). - Result = SurveyUses(&*AI, MaybeLiveArgUses); + Result = surveyUses(&*AI, MaybeLiveArgUses); } // Mark the result. - MarkValue(CreateArg(&F, ArgI), Result, MaybeLiveArgUses); + markValue(createArg(&F, ArgI), Result, MaybeLiveArgUses); // Clear the vector again for the next iteration. MaybeLiveArgUses.clear(); } } -/// MarkValue - This function marks the liveness of RA depending on L. If L is -/// MaybeLive, it also takes all uses in MaybeLiveUses and records them in Uses, -/// such that RA will be marked live if any use in MaybeLiveUses gets marked -/// live later on. -void DeadArgumentEliminationPass::MarkValue(const RetOrArg &RA, Liveness L, +/// Marks the liveness of RA depending on L. If L is MaybeLive, it also takes +/// all uses in MaybeLiveUses and records them in Uses, such that RA will be +/// marked live if any use in MaybeLiveUses gets marked live later on. +void DeadArgumentEliminationPass::markValue(const RetOrArg &RA, Liveness L, const UseVector &MaybeLiveUses) { switch (L) { - case Live: - MarkLive(RA); - break; - case MaybeLive: - assert(!IsLive(RA) && "Use is already live!"); - for (const auto &MaybeLiveUse : MaybeLiveUses) { - if (IsLive(MaybeLiveUse)) { - // A use is live, so this value is live. - MarkLive(RA); - break; - } else { - // Note any uses of this value, so this value can be - // marked live whenever one of the uses becomes live. - Uses.insert(std::make_pair(MaybeLiveUse, RA)); - } + case Live: + markLive(RA); + break; + case MaybeLive: + assert(!isLive(RA) && "Use is already live!"); + for (const auto &MaybeLiveUse : MaybeLiveUses) { + if (isLive(MaybeLiveUse)) { + // A use is live, so this value is live. + markLive(RA); + break; } - break; + // Note any uses of this value, so this value can be + // marked live whenever one of the uses becomes live. + Uses.emplace(MaybeLiveUse, RA); + } + break; } } -/// MarkLive - Mark the given Function as alive, meaning that it cannot be -/// changed in any way. Additionally, -/// mark any values that are used as this function's parameters or by its return -/// values (according to Uses) live as well. -void DeadArgumentEliminationPass::MarkLive(const Function &F) { +/// Mark the given Function as alive, meaning that it cannot be changed in any +/// way. Additionally, mark any values that are used as this function's +/// parameters or by its return values (according to Uses) live as well. +void DeadArgumentEliminationPass::markLive(const Function &F) { LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - Intrinsically live fn: " << F.getName() << "\n"); // Mark the function as live. LiveFunctions.insert(&F); // Mark all arguments as live. for (unsigned ArgI = 0, E = F.arg_size(); ArgI != E; ++ArgI) - PropagateLiveness(CreateArg(&F, ArgI)); + propagateLiveness(createArg(&F, ArgI)); // Mark all return values as live. 
- for (unsigned Ri = 0, E = NumRetVals(&F); Ri != E; ++Ri) - PropagateLiveness(CreateRet(&F, Ri)); + for (unsigned Ri = 0, E = numRetVals(&F); Ri != E; ++Ri) + propagateLiveness(createRet(&F, Ri)); } -/// MarkLive - Mark the given return value or argument as live. Additionally, -/// mark any values that are used by this value (according to Uses) live as -/// well. -void DeadArgumentEliminationPass::MarkLive(const RetOrArg &RA) { - if (IsLive(RA)) +/// Mark the given return value or argument as live. Additionally, mark any +/// values that are used by this value (according to Uses) live as well. +void DeadArgumentEliminationPass::markLive(const RetOrArg &RA) { + if (isLive(RA)) return; // Already marked Live. LiveValues.insert(RA); LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - Marking " << RA.getDescription() << " live\n"); - PropagateLiveness(RA); + propagateLiveness(RA); } -bool DeadArgumentEliminationPass::IsLive(const RetOrArg &RA) { +bool DeadArgumentEliminationPass::isLive(const RetOrArg &RA) { return LiveFunctions.count(RA.F) || LiveValues.count(RA); } -/// PropagateLiveness - Given that RA is a live value, propagate it's liveness -/// to any other values it uses (according to Uses). -void DeadArgumentEliminationPass::PropagateLiveness(const RetOrArg &RA) { +/// Given that RA is a live value, propagate it's liveness to any other values +/// it uses (according to Uses). +void DeadArgumentEliminationPass::propagateLiveness(const RetOrArg &RA) { // We don't use upper_bound (or equal_range) here, because our recursive call // to ourselves is likely to cause the upper_bound (which is the first value // not belonging to RA) to become erased and the iterator invalidated. @@ -723,18 +709,17 @@ void DeadArgumentEliminationPass::PropagateLiveness(const RetOrArg &RA) { UseMap::iterator E = Uses.end(); UseMap::iterator I; for (I = Begin; I != E && I->first == RA; ++I) - MarkLive(I->second); + markLive(I->second); // Erase RA from the Uses map (from the lower bound to wherever we ended up // after the loop). Uses.erase(Begin, I); } -// RemoveDeadStuffFromFunction - Remove any arguments and return values from F -// that are not in LiveValues. Transform the function and all of the callees of -// the function to not have these arguments and return values. -// -bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { +/// Remove any arguments and return values from F that are not in LiveValues. +/// Transform the function and all the callees of the function to not have these +/// arguments and return values. +bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) { // Don't modify fully live functions if (LiveFunctions.count(F)) return false; @@ -742,7 +727,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { // Start by computing a new prototype for the function, which is the same as // the old function, but has fewer arguments and a different return type. 
FunctionType *FTy = F->getFunctionType(); - std::vector<Type*> Params; + std::vector<Type *> Params; // Keep track of if we have a live 'returned' argument bool HasLiveReturnedArg = false; @@ -759,7 +744,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { unsigned ArgI = 0; for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; ++I, ++ArgI) { - RetOrArg Arg = CreateArg(F, ArgI); + RetOrArg Arg = createArg(F, ArgI); if (LiveValues.erase(Arg)) { Params.push_back(I->getType()); ArgAlive[ArgI] = true; @@ -776,11 +761,11 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { // Find out the new return value. Type *RetTy = FTy->getReturnType(); Type *NRetTy = nullptr; - unsigned RetCount = NumRetVals(F); + unsigned RetCount = numRetVals(F); // -1 means unused, other numbers are the new index SmallVector<int, 5> NewRetIdxs(RetCount, -1); - std::vector<Type*> RetTypes; + std::vector<Type *> RetTypes; // If there is a function with a live 'returned' argument but a dead return // value, then there are two possible actions: @@ -792,9 +777,9 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { // It's not clear in the general case which option is more profitable because, // even in the absence of explicit uses of the return value, code generation // is free to use the 'returned' attribute to do things like eliding - // save/restores of registers across calls. Whether or not this happens is - // target and ABI-specific as well as depending on the amount of register - // pressure, so there's no good way for an IR-level pass to figure this out. + // save/restores of registers across calls. Whether this happens is target and + // ABI-specific as well as depending on the amount of register pressure, so + // there's no good way for an IR-level pass to figure this out. // // Fortunately, the only places where 'returned' is currently generated by // the FE are places where 'returned' is basically free and almost always a @@ -806,7 +791,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { } else { // Look at each of the original return values individually. for (unsigned Ri = 0; Ri != RetCount; ++Ri) { - RetOrArg Ret = CreateRet(F, Ri); + RetOrArg Ret = createRet(F, Ri); if (LiveValues.erase(Ret)) { RetTypes.push_back(getRetComponentType(F, Ri)); NewRetIdxs[Ri] = RetTypes.size() - 1; @@ -879,9 +864,9 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { F->getParent()->getFunctionList().insert(F->getIterator(), NF); NF->takeName(F); - // Loop over all of the callers of the function, transforming the call sites - // to pass in a smaller number of arguments into the new function. - std::vector<Value*> Args; + // Loop over all the callers of the function, transforming the call sites to + // pass in a smaller number of arguments into the new function. + std::vector<Value *> Args; while (!F->use_empty()) { CallBase &CB = cast<CallBase>(*F->user_back()); @@ -896,7 +881,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { // Declare these outside of the loops, so we can reuse them for the second // loop, which loops the varargs. - auto I = CB.arg_begin(); + auto *I = CB.arg_begin(); unsigned Pi = 0; // Loop over those operands, corresponding to the normal arguments to the // original function, and add those that are still alive. 
@@ -909,11 +894,11 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { // If the return type has changed, then get rid of 'returned' on the // call site. The alternative is to make all 'returned' attributes on // call sites keep the return value alive just like 'returned' - // attributes on function declaration but it's less clearly a win and + // attributes on function declaration, but it's less clearly a win and // this is not an expected case anyway ArgAttrVec.push_back(AttributeSet::get( - F->getContext(), - AttrBuilder(F->getContext(), Attrs).removeAttribute(Attribute::Returned))); + F->getContext(), AttrBuilder(F->getContext(), Attrs) + .removeAttribute(Attribute::Returned))); } else { // Otherwise, use the original attributes. ArgAttrVec.push_back(Attrs); @@ -921,7 +906,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { } // Push any varargs arguments on the list. Don't forget their attributes. - for (auto E = CB.arg_end(); I != E; ++I, ++Pi) { + for (auto *E = CB.arg_end(); I != E; ++I, ++Pi) { Args.push_back(*I); ArgAttrVec.push_back(CallPAL.getParamAttrs(Pi)); } @@ -934,8 +919,8 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { AttributeSet FnAttrs = CallPAL.getFnAttrs().removeAttribute( F->getContext(), Attribute::AllocSize); - AttributeList NewCallPAL = AttributeList::get( - F->getContext(), FnAttrs, RetAttrs, ArgAttrVec); + AttributeList NewCallPAL = + AttributeList::get(F->getContext(), FnAttrs, RetAttrs, ArgAttrVec); SmallVector<OperandBundleDef, 1> OpBundles; CB.getOperandBundlesAsDefs(OpBundles); @@ -961,10 +946,10 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { CB.replaceAllUsesWith(NewCB); NewCB->takeName(&CB); } else if (NewCB->getType()->isVoidTy()) { - // If the return value is dead, replace any uses of it with undef + // If the return value is dead, replace any uses of it with poison // (any non-debug value uses will get removed later on). if (!CB.getType()->isX86_MMXTy()) - CB.replaceAllUsesWith(UndefValue::get(CB.getType())); + CB.replaceAllUsesWith(PoisonValue::get(CB.getType())); } else { assert((RetTy->isStructTy() || RetTy->isArrayTy()) && "Return type changed, but not into a void. The old return type" @@ -980,8 +965,8 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { // with all the uses, we will just rebuild it using extract/insertvalue // chaining and let instcombine clean that up. // - // Start out building up our return value from undef - Value *RetVal = UndefValue::get(RetTy); + // Start out building up our return value from poison + Value *RetVal = PoisonValue::get(RetTy); for (unsigned Ri = 0; Ri != RetCount; ++Ri) if (NewRetIdxs[Ri] != -1) { Value *V; @@ -1026,10 +1011,10 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { I2->takeName(&*I); ++I2; } else { - // If this argument is dead, replace any uses of it with undef + // If this argument is dead, replace any uses of it with poison // (any non-debug value uses will get removed later on). if (!I->getType()->isX86_MMXTy()) - I->replaceAllUsesWith(UndefValue::get(I->getType())); + I->replaceAllUsesWith(PoisonValue::get(I->getType())); } // If we change the return value of the function we must rewrite any return @@ -1048,8 +1033,8 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { // This does generate messy code, but we'll let it to instcombine to // clean that up. 
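
A minimal sketch of the extract/insertvalue chaining referred to above, written against public IRBuilder APIs; the helper and its names are invented and simplified relative to the surrounding code.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"

// Rebuild a value of the original aggregate return type from the slimmed-down
// call, given NewRetIdxs[i] == -1 for return components that were dropped.
static llvm::Value *rebuildOldReturn(llvm::IRBuilder<> &B, llvm::Type *OldRetTy,
                                     llvm::CallBase *NewCB,
                                     llvm::ArrayRef<int> NewRetIdxs) {
  llvm::Value *RetVal = llvm::PoisonValue::get(OldRetTy);
  for (unsigned Ri = 0, E = NewRetIdxs.size(); Ri != E; ++Ri) {
    if (NewRetIdxs[Ri] == -1)
      continue;                             // dead component stays poison
    llvm::Value *V = NewCB;
    if (NewCB->getType()->isStructTy() || NewCB->getType()->isArrayTy())
      V = B.CreateExtractValue(NewCB, unsigned(NewRetIdxs[Ri]), "newret");
    RetVal = B.CreateInsertValue(RetVal, V, Ri, "oldret");
  }
  return RetVal;                            // instcombine tidies up the chain
}
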
Value *OldRet = RI->getOperand(0); - // Start out building up our return value from undef - RetVal = UndefValue::get(NRetTy); + // Start out building up our return value from poison + RetVal = PoisonValue::get(NRetTy); for (unsigned RetI = 0; RetI != RetCount; ++RetI) if (NewRetIdxs[RetI] != -1) { Value *EV = IRB.CreateExtractValue(OldRet, RetI, "oldret"); @@ -1074,12 +1059,22 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { BB.getInstList().erase(RI); } - // Clone metadatas from the old function, including debug info descriptor. + // Clone metadata from the old function, including debug info descriptor. SmallVector<std::pair<unsigned, MDNode *>, 1> MDs; F->getAllMetadata(MDs); for (auto MD : MDs) NF->addMetadata(MD.first, *MD.second); + // If either the return value(s) or argument(s) are removed, then probably the + // function does not follow standard calling conventions anymore. Hence, add + // DW_CC_nocall to DISubroutineType to inform debugger that it may not be safe + // to call this function or try to interpret the return value. + if (NFTy != FTy && NF->getSubprogram()) { + DISubprogram *SP = NF->getSubprogram(); + auto Temp = SP->getType()->cloneWithCC(llvm::dwarf::DW_CC_nocall); + SP->replaceType(MDNode::replaceWithPermanent(std::move(Temp))); + } + // Now that the old function is dead, delete it. F->eraseFromParent(); @@ -1097,26 +1092,25 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M, LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - Deleting dead varargs\n"); for (Function &F : llvm::make_early_inc_range(M)) if (F.getFunctionType()->isVarArg()) - Changed |= DeleteDeadVarargs(F); + Changed |= deleteDeadVarargs(F); - // Second phase:loop through the module, determining which arguments are live. - // We assume all arguments are dead unless proven otherwise (allowing us to - // determine that dead arguments passed into recursive functions are dead). - // + // Second phase: Loop through the module, determining which arguments are + // live. We assume all arguments are dead unless proven otherwise (allowing us + // to determine that dead arguments passed into recursive functions are dead). LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - Determining liveness\n"); for (auto &F : M) - SurveyFunction(F); + surveyFunction(F); // Now, remove all dead arguments and return values from each function in // turn. We use make_early_inc_range here because functions will probably get // removed (i.e. replaced by new ones). for (Function &F : llvm::make_early_inc_range(M)) - Changed |= RemoveDeadStuffFromFunction(&F); + Changed |= removeDeadStuffFromFunction(&F); // Finally, look for any unused parameters in functions with non-local - // linkage and replace the passed in parameters with undef. + // linkage and replace the passed in parameters with poison. 
for (auto &F : M) - Changed |= RemoveDeadArgumentsFromCallers(F); + Changed |= removeDeadArgumentsFromCallers(F); if (!Changed) return PreservedAnalyses::all(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ExtractGV.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ExtractGV.cpp index 387f114f6ffa..84280781ee70 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ExtractGV.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ExtractGV.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/SetVector.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" #include "llvm/Transforms/IPO.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp index 16d00a0c89e1..b10c2ea13469 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ForceFunctionAttrs.cpp @@ -8,9 +8,9 @@ #include "llvm/Transforms/IPO/ForceFunctionAttrs.h" #include "llvm/IR/Function.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" +#include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp index e2f1944cee63..49077f92884f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp @@ -30,7 +30,6 @@ #include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/LazyCallGraph.h" -#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" @@ -45,6 +44,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Metadata.h" +#include "llvm/IR/ModuleSummaryIndex.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" @@ -69,6 +69,7 @@ using namespace llvm; #define DEBUG_TYPE "function-attrs" +STATISTIC(NumArgMemOnly, "Number of functions marked argmemonly"); STATISTIC(NumReadNone, "Number of functions marked readnone"); STATISTIC(NumReadOnly, "Number of functions marked readonly"); STATISTIC(NumWriteOnly, "Number of functions marked writeonly"); @@ -121,28 +122,28 @@ using SCCNodeSet = SmallSetVector<Function *, 8>; /// result will be based only on AA results for the function declaration; it /// will be assumed that some other (perhaps less optimized) version of the /// function may be selected at link time. -static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, - AAResults &AAR, - const SCCNodeSet &SCCNodes) { +static FunctionModRefBehavior +checkFunctionMemoryAccess(Function &F, bool ThisBody, AAResults &AAR, + const SCCNodeSet &SCCNodes) { FunctionModRefBehavior MRB = AAR.getModRefBehavior(&F); if (MRB == FMRB_DoesNotAccessMemory) // Already perfect! - return MAK_ReadNone; + return MRB; - if (!ThisBody) { - if (AliasAnalysis::onlyReadsMemory(MRB)) - return MAK_ReadOnly; - - if (AliasAnalysis::onlyWritesMemory(MRB)) - return MAK_WriteOnly; - - // Conservatively assume it reads and writes to memory. 
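
The body scan that follows tracks three flags, ReadsMemory, WritesMemory and the new AccessesNonArgsOrAlloca; they classify functions roughly as in this invented source-level example, where each comment notes which flags the corresponding body would set.

// None of the flags set: the function can end up readnone.
int square(int x) { return x * x; }

// ReadsMemory, and only through an argument (AccessesNonArgsOrAlloca stays false).
int load_arg(const int *p) { return *p; }

// WritesMemory, again only through an argument.
void store_arg(int *p, int v) { *p = v; }

// ReadsMemory and AccessesNonArgsOrAlloca: the underlying object is a global.
int counter;
int load_global() { return counter; }
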
- return MAK_MayWrite; - } + if (!ThisBody) + return MRB; // Scan the function body for instructions that may read or write memory. bool ReadsMemory = false; bool WritesMemory = false; + // Track if the function accesses memory not based on pointer arguments or + // allocas. + bool AccessesNonArgsOrAlloca = false; + // Returns true if Ptr is not based on a function argument. + auto IsArgumentOrAlloca = [](const Value *Ptr) { + const Value *UO = getUnderlyingObject(Ptr); + return isa<Argument>(UO) || isa<AllocaInst>(UO); + }; for (Instruction &I : instructions(F)) { // Some instructions can be ignored even if they read or write memory. // Detect these now, skipping to the next instruction if one is found. @@ -175,6 +176,7 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, // If it reads, note it. if (isRefSet(MRI)) ReadsMemory = true; + AccessesNonArgsOrAlloca = true; continue; } @@ -187,12 +189,13 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, MemoryLocation Loc = MemoryLocation::getBeforeOrAfter(Arg, I.getAAMetadata()); - // Skip accesses to local or constant memory as they don't impact the // externally visible mod/ref behavior. if (AAR.pointsToConstantMemory(Loc, /*OrLocal=*/true)) continue; + AccessesNonArgsOrAlloca |= !IsArgumentOrAlloca(Loc.Ptr); + if (isModSet(MRI)) // Writes non-local memory. WritesMemory = true; @@ -202,24 +205,29 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, } continue; } else if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { + MemoryLocation Loc = MemoryLocation::get(LI); // Ignore non-volatile loads from local memory. (Atomic is okay here.) - if (!LI->isVolatile()) { - MemoryLocation Loc = MemoryLocation::get(LI); - if (AAR.pointsToConstantMemory(Loc, /*OrLocal=*/true)) - continue; - } + if (!LI->isVolatile() && + AAR.pointsToConstantMemory(Loc, /*OrLocal=*/true)) + continue; + AccessesNonArgsOrAlloca |= !IsArgumentOrAlloca(Loc.Ptr); } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) { + MemoryLocation Loc = MemoryLocation::get(SI); // Ignore non-volatile stores to local memory. (Atomic is okay here.) - if (!SI->isVolatile()) { - MemoryLocation Loc = MemoryLocation::get(SI); - if (AAR.pointsToConstantMemory(Loc, /*OrLocal=*/true)) - continue; - } + if (!SI->isVolatile() && + AAR.pointsToConstantMemory(Loc, /*OrLocal=*/true)) + continue; + AccessesNonArgsOrAlloca |= !IsArgumentOrAlloca(Loc.Ptr); } else if (VAArgInst *VI = dyn_cast<VAArgInst>(&I)) { // Ignore vaargs on local memory. MemoryLocation Loc = MemoryLocation::get(VI); if (AAR.pointsToConstantMemory(Loc, /*OrLocal=*/true)) continue; + AccessesNonArgsOrAlloca |= !IsArgumentOrAlloca(Loc.Ptr); + } else { + // If AccessesNonArgsOrAlloca has not been updated above, set it + // conservatively. + AccessesNonArgsOrAlloca |= I.mayReadOrWriteMemory(); } // Any remaining instructions need to be taken seriously! Check if they @@ -232,61 +240,74 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, ReadsMemory |= I.mayReadFromMemory(); } - if (WritesMemory) { - if (!ReadsMemory) - return MAK_WriteOnly; - else - return MAK_MayWrite; - } - - return ReadsMemory ? 
MAK_ReadOnly : MAK_ReadNone; + if (!WritesMemory && !ReadsMemory) + return FMRB_DoesNotAccessMemory; + + FunctionModRefBehavior Result = FunctionModRefBehavior(FMRL_Anywhere); + if (!AccessesNonArgsOrAlloca) + Result = FunctionModRefBehavior(FMRL_ArgumentPointees); + if (WritesMemory) + Result = FunctionModRefBehavior(Result | static_cast<int>(ModRefInfo::Mod)); + if (ReadsMemory) + Result = FunctionModRefBehavior(Result | static_cast<int>(ModRefInfo::Ref)); + return Result; } -MemoryAccessKind llvm::computeFunctionBodyMemoryAccess(Function &F, - AAResults &AAR) { +FunctionModRefBehavior llvm::computeFunctionBodyMemoryAccess(Function &F, + AAResults &AAR) { return checkFunctionMemoryAccess(F, /*ThisBody=*/true, AAR, {}); } -/// Deduce readonly/readnone attributes for the SCC. +/// Deduce readonly/readnone/writeonly attributes for the SCC. template <typename AARGetterT> -static void addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter, - SmallSet<Function *, 8> &Changed) { +static void addMemoryAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter, + SmallSet<Function *, 8> &Changed) { // Check if any of the functions in the SCC read or write memory. If they // write memory then they can't be marked readnone or readonly. bool ReadsMemory = false; bool WritesMemory = false; + // Check if all functions only access memory through their arguments. + bool ArgMemOnly = true; for (Function *F : SCCNodes) { // Call the callable parameter to look up AA results for this function. AAResults &AAR = AARGetter(*F); - // Non-exact function definitions may not be selected at link time, and an // alternative version that writes to memory may be selected. See the // comment on GlobalValue::isDefinitionExact for more details. - switch (checkFunctionMemoryAccess(*F, F->hasExactDefinition(), - AAR, SCCNodes)) { - case MAK_MayWrite: + FunctionModRefBehavior FMRB = + checkFunctionMemoryAccess(*F, F->hasExactDefinition(), AAR, SCCNodes); + if (FMRB == FMRB_DoesNotAccessMemory) + continue; + ModRefInfo MR = createModRefInfo(FMRB); + ReadsMemory |= isRefSet(MR); + WritesMemory |= isModSet(MR); + ArgMemOnly &= AliasAnalysis::onlyAccessesArgPointees(FMRB); + // Reached neither readnone, readonly, writeonly nor argmemonly can be + // inferred. Exit. + if (ReadsMemory && WritesMemory && !ArgMemOnly) return; - case MAK_ReadOnly: - ReadsMemory = true; - break; - case MAK_WriteOnly: - WritesMemory = true; - break; - case MAK_ReadNone: - // Nothing to do! - break; - } } - // If the SCC contains both functions that read and functions that write, then - // we cannot add readonly attributes. - if (ReadsMemory && WritesMemory) - return; - - // Success! Functions in this SCC do not access memory, or only read memory. - // Give them the appropriate attribute. + assert((!ReadsMemory || !WritesMemory || ArgMemOnly) && + "no memory attributes can be added for this SCC, should have exited " + "earlier"); + // Success! Functions in this SCC do not access memory, only read memory, + // only write memory, or only access memory through its arguments. Give them + // the appropriate attribute. for (Function *F : SCCNodes) { + // If possible add argmemonly attribute to F, if it accesses memory. + if (ArgMemOnly && !F->onlyAccessesArgMemory() && + (ReadsMemory || WritesMemory)) { + NumArgMemOnly++; + F->addFnAttr(Attribute::ArgMemOnly); + Changed.insert(F); + } + + // The SCC contains functions both writing and reading from memory. We + // cannot add readonly or writeonline attributes. 
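
Continuing in the same style, the new argmemonly deduction fires when every function in the SCC touches memory only through pointers derived from its arguments (or local allocas); a sketch, invented for illustration:

// Reads and writes, but only through 'dst' and 'src', so in addition to the
// read/write facts the function is a candidate for argmemonly.
void saxpy(float *dst, const float *src, float a, int n) {
  for (int i = 0; i < n; ++i)
    dst[i] += a * src[i];
}

// Also touches a global, so argmemonly cannot be inferred here.
int hits;
int probe(const int *p) { ++hits; return *p; }
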
+ if (ReadsMemory && WritesMemory) + continue; if (F->doesNotAccessMemory()) // Already perfect! continue; @@ -1810,7 +1831,7 @@ deriveAttrsInPostOrder(ArrayRef<Function *> Functions, AARGetterT &&AARGetter) { SmallSet<Function *, 8> Changed; addArgumentReturnedAttrs(Nodes.SCCNodes, Changed); - addReadAttrs(Nodes.SCCNodes, AARGetter, Changed); + addMemoryAttrs(Nodes.SCCNodes, AARGetter, Changed); addArgumentAttrs(Nodes.SCCNodes, Changed); inferConvergent(Nodes.SCCNodes, Changed); addNoReturnAttrs(Nodes.SCCNodes, Changed); @@ -1914,6 +1935,7 @@ struct PostOrderFunctionAttrsLegacyPass : public CallGraphSCCPass { char PostOrderFunctionAttrsLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(PostOrderFunctionAttrsLegacyPass, "function-attrs", "Deduce function attributes", false, false) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) INITIALIZE_PASS_END(PostOrderFunctionAttrsLegacyPass, "function-attrs", @@ -1993,12 +2015,13 @@ static bool addNoRecurseAttrsTopDown(Function &F) { // this function could be recursively (indirectly) called. Note that this // also detects if F is directly recursive as F is not yet marked as // a norecurse function. - for (auto *U : F.users()) { - auto *I = dyn_cast<Instruction>(U); + for (auto &U : F.uses()) { + auto *I = dyn_cast<Instruction>(U.getUser()); if (!I) return false; CallBase *CB = dyn_cast<CallBase>(I); - if (!CB || !CB->getParent()->getParent()->doesNotRecurse()) + if (!CB || !CB->isCallee(&U) || + !CB->getParent()->getParent()->doesNotRecurse()) return false; } F.setDoesNotRecurse(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp index d9b43109f629..56e2df14ff38 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp @@ -18,7 +18,6 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/StringSet.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/IR/AutoUpgrade.h" #include "llvm/IR/Constants.h" @@ -33,8 +32,6 @@ #include "llvm/IRReader/IRReader.h" #include "llvm/InitializePasses.h" #include "llvm/Linker/IRMover.h" -#include "llvm/Object/ModuleSymbolTable.h" -#include "llvm/Object/SymbolicFile.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" @@ -1112,12 +1109,13 @@ void llvm::thinLTOFinalizeInModule(Module &TheModule, llvm_unreachable("Expected GV to be converted"); } else { // If all copies of the original symbol had global unnamed addr and - // linkonce_odr linkage, it should be an auto hide symbol. In that case - // the thin link would have marked it as CanAutoHide. Add hidden visibility - // to the symbol to preserve the property. + // linkonce_odr linkage, or if all of them had local unnamed addr linkage + // and are constants, then it should be an auto hide symbol. In that case + // the thin link would have marked it as CanAutoHide. Add hidden + // visibility to the symbol to preserve the property. 
if (NewLinkage == GlobalValue::WeakODRLinkage && GS->second->canAutoHide()) { - assert(GV.hasLinkOnceODRLinkage() && GV.hasGlobalUnnamedAddr()); + assert(GV.canBeOmittedFromSymbolTable()); GV.setVisibility(GlobalValue::HiddenVisibility); } @@ -1330,10 +1328,9 @@ Expected<bool> FunctionImporter::importFunctions( << " from " << SrcModule->getSourceFileName() << "\n"; } - if (Error Err = Mover.move( - std::move(SrcModule), GlobalsToImport.getArrayRef(), - [](GlobalValue &, IRMover::ValueAdder) {}, - /*IsPerformingImport=*/true)) + if (Error Err = Mover.move(std::move(SrcModule), + GlobalsToImport.getArrayRef(), nullptr, + /*IsPerformingImport=*/true)) report_fatal_error(Twine("Function Import: link error: ") + toString(std::move(Err))); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp index 6c3cc3914337..dafd0dc865a2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -19,11 +19,8 @@ // Current limitations: // - It does not yet handle integer ranges. We do support "literal constants", // but that's off by default under an option. -// - Only 1 argument per function is specialised, // - The cost-model could be further looked into (it mainly focuses on inlining // benefits), -// - We are not yet caching analysis results, but profiling and checking where -// extra compile time is spent didn't suggest this to be a problem. // // Ideas: // - With a function specialization attribute for arguments, we could have @@ -49,15 +46,16 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" -#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueLattice.h" +#include "llvm/Analysis/ValueLatticeUtils.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/Transforms/Scalar/SCCP.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/SCCPSolver.h" #include "llvm/Transforms/Utils/SizeOpts.h" #include <cmath> @@ -98,8 +96,13 @@ static cl::opt<bool> SpecializeOnAddresses( "func-specialization-on-address", cl::init(false), cl::Hidden, cl::desc("Enable function specialization on the address of global values")); -// TODO: This needs checking to see the impact on compile-times, which is why -// this is off by default for now. +// Disabled by default as it can significantly increase compilation times. +// Running nikic's compile time tracker on x86 with instruction count as the +// metric shows 3-4% regression for SPASS while being neutral for all other +// benchmarks of the llvm test suite. +// +// https://llvm-compile-time-tracker.com +// https://github.com/nikic/llvm-compile-time-tracker static cl::opt<bool> EnableSpecializationForLiteralConstant( "function-specialization-for-literal-constant", cl::init(false), cl::Hidden, cl::desc("Enable specialization of functions that take a literal constant " @@ -108,24 +111,18 @@ static cl::opt<bool> EnableSpecializationForLiteralConstant( namespace { // Bookkeeping struct to pass data from the analysis and profitability phase // to the actual transform helper functions. 
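
To make the option above concrete (example invented, behavior hedged): by default the pass mainly targets arguments such as constant function pointers, and a plain literal like the 3 below only becomes a specialization candidate when -function-specialization-for-literal-constant is enabled.

static int power(int base, int exp) {
  int r = 1;
  while (exp--)
    r *= base;
  return r;
}

int cube(int x) {
  // 'exp' is bound to the literal constant 3 at this call site; specializing
  // power() on it would let the loop be unrolled, but literal-constant
  // candidates are only considered under the option above.
  return power(x, 3);
}
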
-struct ArgInfo { - Function *Fn; // The function to perform specialisation on. - Argument *Arg; // The Formal argument being analysed. - Constant *Const; // A corresponding actual constant argument. - InstructionCost Gain; // Profitability: Gain = Bonus - Cost. - - // Flag if this will be a partial specialization, in which case we will need - // to keep the original function around in addition to the added - // specializations. - bool Partial = false; - - ArgInfo(Function *F, Argument *A, Constant *C, InstructionCost G) - : Fn(F), Arg(A), Const(C), Gain(G){}; +struct SpecializationInfo { + SmallVector<ArgInfo, 8> Args; // Stores the {formal,actual} argument pairs. + InstructionCost Gain; // Profitability: Gain = Bonus - Cost. }; } // Anonymous namespace using FuncList = SmallVectorImpl<Function *>; -using ConstList = SmallVectorImpl<Constant *>; +using CallArgBinding = std::pair<CallBase *, Constant *>; +using CallSpecBinding = std::pair<CallBase *, SpecializationInfo>; +// We are using MapVector because it guarantees deterministic iteration +// order across executions. +using SpecializationMap = SmallMapVector<CallBase *, SpecializationInfo, 8>; // Helper to check if \p LV is either a constant or a constant // range with a single element. This should cover exactly the same cases as the @@ -204,41 +201,45 @@ static Constant *getConstantStackValue(CallInst *Call, Value *Val, // ret void // } // -static void constantArgPropagation(FuncList &WorkList, - Module &M, SCCPSolver &Solver) { +static void constantArgPropagation(FuncList &WorkList, Module &M, + SCCPSolver &Solver) { // Iterate over the argument tracked functions see if there // are any new constant values for the call instruction via // stack variables. for (auto *F : WorkList) { - // TODO: Generalize for any read only arguments. 
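
The generalized loop that follows inspects every pointer argument of each call rather than just the first one. At the source level the situation it targets looks like this invented example, where a constant built on the caller's stack can be promoted to an internal constant global (the "funcspec.arg" variable created below), letting the solver treat the callee's parameter as constant.

struct Config { int mode; int limit; };

// Only reads through 'c'; with the argument known to be read-only, the stack
// object passed below is a candidate for promotion.
static int consume(const Config *c) { return c->mode + c->limit; }

int run() {
  Config c{1, 64};        // initialized with constants and never modified
  return consume(&c);     // after the transform, &c refers to a constant global
}
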
- if (F->arg_size() != 1) - continue; - - auto &Arg = *F->arg_begin(); - if (!Arg.onlyReadsMemory() || !Arg.getType()->isPointerTy()) - continue; for (auto *User : F->users()) { + auto *Call = dyn_cast<CallInst>(User); if (!Call) - break; - auto *ArgOp = Call->getArgOperand(0); - auto *ArgOpType = ArgOp->getType(); - auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver); - if (!ConstVal) - break; + continue; - Value *GV = new GlobalVariable(M, ConstVal->getType(), true, - GlobalValue::InternalLinkage, ConstVal, - "funcspec.arg"); + bool Changed = false; + for (const Use &U : Call->args()) { + unsigned Idx = Call->getArgOperandNo(&U); + Value *ArgOp = Call->getArgOperand(Idx); + Type *ArgOpType = ArgOp->getType(); - if (ArgOpType != ConstVal->getType()) - GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOp->getType()); + if (!Call->onlyReadsMemory(Idx) || !ArgOpType->isPointerTy()) + continue; - Call->setArgOperand(0, GV); + auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver); + if (!ConstVal) + continue; + + Value *GV = new GlobalVariable(M, ConstVal->getType(), true, + GlobalValue::InternalLinkage, ConstVal, + "funcspec.arg"); + if (ArgOpType != ConstVal->getType()) + GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType); + + Call->setArgOperand(Idx, GV); + Changed = true; + } // Add the changed CallInst to Solver Worklist - Solver.visitCall(*Call); + if (Changed) + Solver.visitCall(*Call); } } } @@ -275,7 +276,10 @@ class FunctionSpecializer { std::function<TargetTransformInfo &(Function &)> GetTTI; std::function<TargetLibraryInfo &(Function &)> GetTLI; - SmallPtrSet<Function *, 2> SpecializedFuncs; + SmallPtrSet<Function *, 4> SpecializedFuncs; + SmallPtrSet<Function *, 4> FullySpecialized; + SmallVector<Instruction *> ReplacedWithConstant; + DenseMap<Function *, CodeMetrics> FunctionMetrics; public: FunctionSpecializer(SCCPSolver &Solver, @@ -284,42 +288,66 @@ public: std::function<TargetLibraryInfo &(Function &)> GetTLI) : Solver(Solver), GetAC(GetAC), GetTTI(GetTTI), GetTLI(GetTLI) {} + ~FunctionSpecializer() { + // Eliminate dead code. + removeDeadInstructions(); + removeDeadFunctions(); + } + /// Attempt to specialize functions in the module to enable constant /// propagation across function boundaries. /// /// \returns true if at least one function is specialized. 
- bool - specializeFunctions(FuncList &FuncDecls, - FuncList &CurrentSpecializations) { + bool specializeFunctions(FuncList &Candidates, FuncList &WorkList) { bool Changed = false; - for (auto *F : FuncDecls) { - if (!isCandidateFunction(F, CurrentSpecializations)) + for (auto *F : Candidates) { + if (!isCandidateFunction(F)) continue; auto Cost = getSpecializationCost(F); if (!Cost.isValid()) { LLVM_DEBUG( - dbgs() << "FnSpecialization: Invalid specialisation cost.\n"); + dbgs() << "FnSpecialization: Invalid specialization cost.\n"); continue; } - auto ConstArgs = calculateGains(F, Cost); - if (ConstArgs.empty()) { - LLVM_DEBUG(dbgs() << "FnSpecialization: no possible constants found\n"); + LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for " + << F->getName() << " is " << Cost << "\n"); + + SmallVector<CallSpecBinding, 8> Specializations; + if (!calculateGains(F, Cost, Specializations)) { + LLVM_DEBUG(dbgs() << "FnSpecialization: No possible constants found\n"); continue; } - for (auto &CA : ConstArgs) { - specializeFunction(CA, CurrentSpecializations); - Changed = true; - } + Changed = true; + for (auto &Entry : Specializations) + specializeFunction(F, Entry.second, WorkList); } - updateSpecializedFuncs(FuncDecls, CurrentSpecializations); + updateSpecializedFuncs(Candidates, WorkList); NumFuncSpecialized += NbFunctionsSpecialized; return Changed; } + void removeDeadInstructions() { + for (auto *I : ReplacedWithConstant) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead instruction " << *I + << "\n"); + I->eraseFromParent(); + } + ReplacedWithConstant.clear(); + } + + void removeDeadFunctions() { + for (auto *F : FullySpecialized) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead function " + << F->getName() << "\n"); + F->eraseFromParent(); + } + FullySpecialized.clear(); + } + bool tryToReplaceWithConstant(Value *V) { if (!V->getType()->isSingleValueType() || isa<CallBase>(V) || V->user_empty()) @@ -330,17 +358,26 @@ public: return false; auto *Const = isConstant(IV) ? Solver.getConstant(IV) : UndefValue::get(V->getType()); - V->replaceAllUsesWith(Const); - for (auto *U : Const->users()) + LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing " << *V + << "\nFnSpecialization: with " << *Const << "\n"); + + // Record uses of V to avoid visiting irrelevant uses of const later. + SmallVector<Instruction *> UseInsts; + for (auto *U : V->users()) if (auto *I = dyn_cast<Instruction>(U)) if (Solver.isBlockExecutable(I->getParent())) - Solver.visit(I); + UseInsts.push_back(I); + + V->replaceAllUsesWith(Const); + + for (auto *I : UseInsts) + Solver.visit(I); // Remove the instruction from Block and Solver. if (auto *I = dyn_cast<Instruction>(V)) { if (I->isSafeToRemove()) { - I->eraseFromParent(); + ReplacedWithConstant.push_back(I); Solver.removeLatticeValueFor(I); } } @@ -352,92 +389,108 @@ private: // also in the cost model. unsigned NbFunctionsSpecialized = 0; + // Compute the code metrics for function \p F. + CodeMetrics &analyzeFunction(Function *F) { + auto I = FunctionMetrics.insert({F, CodeMetrics()}); + CodeMetrics &Metrics = I.first->second; + if (I.second) { + // The code metrics were not cached. 
+ SmallPtrSet<const Value *, 32> EphValues; + CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues); + for (BasicBlock &BB : *F) + Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues); + + LLVM_DEBUG(dbgs() << "FnSpecialization: Code size of function " + << F->getName() << " is " << Metrics.NumInsts + << " instructions\n"); + } + return Metrics; + } + /// Clone the function \p F and remove the ssa_copy intrinsics added by /// the SCCPSolver in the cloned version. - Function *cloneCandidateFunction(Function *F) { - ValueToValueMapTy EmptyMap; - Function *Clone = CloneFunction(F, EmptyMap); + Function *cloneCandidateFunction(Function *F, ValueToValueMapTy &Mappings) { + Function *Clone = CloneFunction(F, Mappings); removeSSACopy(*Clone); return Clone; } - /// This function decides whether it's worthwhile to specialize function \p F - /// based on the known constant values its arguments can take on, i.e. it - /// calculates a gain and returns a list of actual arguments that are deemed - /// profitable to specialize. Specialization is performed on the first - /// interesting argument. Specializations based on additional arguments will - /// be evaluated on following iterations of the main IPSCCP solve loop. - SmallVector<ArgInfo> calculateGains(Function *F, InstructionCost Cost) { - SmallVector<ArgInfo> Worklist; + /// This function decides whether it's worthwhile to specialize function + /// \p F based on the known constant values its arguments can take on. It + /// only discovers potential specialization opportunities without actually + /// applying them. + /// + /// \returns true if any specializations have been found. + bool calculateGains(Function *F, InstructionCost Cost, + SmallVectorImpl<CallSpecBinding> &WorkList) { + SpecializationMap Specializations; // Determine if we should specialize the function based on the values the // argument can take on. If specialization is not profitable, we continue // on to the next argument. for (Argument &FormalArg : F->args()) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing arg: " - << FormalArg.getName() << "\n"); // Determine if this argument is interesting. If we know the argument can - // take on any constant values, they are collected in Constants. If the - // argument can only ever equal a constant value in Constants, the - // function will be completely specialized, and the IsPartial flag will - // be set to false by isArgumentInteresting (that function only adds - // values to the Constants list that are deemed profitable). - bool IsPartial = true; - SmallVector<Constant *> ActualConstArg; - if (!isArgumentInteresting(&FormalArg, ActualConstArg, IsPartial)) { - LLVM_DEBUG(dbgs() << "FnSpecialization: Argument is not interesting\n"); + // take on any constant values, they are collected in Constants. + SmallVector<CallArgBinding, 8> ActualArgs; + if (!isArgumentInteresting(&FormalArg, ActualArgs)) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Argument " + << FormalArg.getNameOrAsOperand() + << " is not interesting\n"); continue; } - for (auto *ActualArg : ActualConstArg) { - InstructionCost Gain = - ForceFunctionSpecialization - ? 
1 - : getSpecializationBonus(&FormalArg, ActualArg) - Cost; - - if (Gain <= 0) - continue; - Worklist.push_back({F, &FormalArg, ActualArg, Gain}); - } + for (const auto &Entry : ActualArgs) { + CallBase *Call = Entry.first; + Constant *ActualArg = Entry.second; - if (Worklist.empty()) - continue; + auto I = Specializations.insert({Call, SpecializationInfo()}); + SpecializationInfo &S = I.first->second; - // Sort the candidates in descending order. - llvm::stable_sort(Worklist, [](const ArgInfo &L, const ArgInfo &R) { - return L.Gain > R.Gain; - }); - - // Truncate the worklist to 'MaxClonesThreshold' candidates if - // necessary. - if (Worklist.size() > MaxClonesThreshold) { - LLVM_DEBUG(dbgs() << "FnSpecialization: number of candidates exceed " - << "the maximum number of clones threshold.\n" - << "Truncating worklist to " << MaxClonesThreshold - << " candidates.\n"); - Worklist.erase(Worklist.begin() + MaxClonesThreshold, - Worklist.end()); + if (I.second) + S.Gain = ForceFunctionSpecialization ? 1 : 0 - Cost; + if (!ForceFunctionSpecialization) + S.Gain += getSpecializationBonus(&FormalArg, ActualArg); + S.Args.push_back({&FormalArg, ActualArg}); } + } - if (IsPartial || Worklist.size() < ActualConstArg.size()) - for (auto &ActualArg : Worklist) - ActualArg.Partial = true; - - LLVM_DEBUG(dbgs() << "Sorted list of candidates by gain:\n"; - for (auto &C - : Worklist) { - dbgs() << "- Function = " << C.Fn->getName() << ", "; - dbgs() << "FormalArg = " << C.Arg->getName() << ", "; - dbgs() << "ActualArg = " << C.Const->getName() << ", "; - dbgs() << "Gain = " << C.Gain << "\n"; - }); - - // FIXME: Only one argument per function. - break; + // Remove unprofitable specializations. + Specializations.remove_if( + [](const auto &Entry) { return Entry.second.Gain <= 0; }); + + // Clear the MapVector and return the underlying vector. + WorkList = Specializations.takeVector(); + + // Sort the candidates in descending order. + llvm::stable_sort(WorkList, [](const auto &L, const auto &R) { + return L.second.Gain > R.second.Gain; + }); + + // Truncate the worklist to 'MaxClonesThreshold' candidates if necessary. + if (WorkList.size() > MaxClonesThreshold) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Number of candidates exceed " + << "the maximum number of clones threshold.\n" + << "FnSpecialization: Truncating worklist to " + << MaxClonesThreshold << " candidates.\n"); + WorkList.erase(WorkList.begin() + MaxClonesThreshold, WorkList.end()); } - return Worklist; + + LLVM_DEBUG(dbgs() << "FnSpecialization: Specializations for function " + << F->getName() << "\n"; + for (const auto &Entry + : WorkList) { + dbgs() << "FnSpecialization: Gain = " << Entry.second.Gain + << "\n"; + for (const ArgInfo &Arg : Entry.second.Args) + dbgs() << "FnSpecialization: FormalArg = " + << Arg.Formal->getNameOrAsOperand() + << ", ActualArg = " + << Arg.Actual->getNameOrAsOperand() << "\n"; + }); + + return !WorkList.empty(); } - bool isCandidateFunction(Function *F, FuncList &Specializations) { + bool isCandidateFunction(Function *F) { // Do not specialize the cloned function again. 
if (SpecializedFuncs.contains(F)) return false; @@ -461,44 +514,45 @@ private: return true; } - void specializeFunction(ArgInfo &AI, FuncList &Specializations) { - Function *Clone = cloneCandidateFunction(AI.Fn); - Argument *ClonedArg = Clone->getArg(AI.Arg->getArgNo()); + void specializeFunction(Function *F, SpecializationInfo &S, + FuncList &WorkList) { + ValueToValueMapTy Mappings; + Function *Clone = cloneCandidateFunction(F, Mappings); // Rewrite calls to the function so that they call the clone instead. - rewriteCallSites(AI.Fn, Clone, *ClonedArg, AI.Const); + rewriteCallSites(Clone, S.Args, Mappings); // Initialize the lattice state of the arguments of the function clone, // marking the argument on which we specialized the function constant // with the given value. - Solver.markArgInFuncSpecialization(AI.Fn, ClonedArg, AI.Const); + Solver.markArgInFuncSpecialization(Clone, S.Args); // Mark all the specialized functions - Specializations.push_back(Clone); + WorkList.push_back(Clone); NbFunctionsSpecialized++; // If the function has been completely specialized, the original function // is no longer needed. Mark it unreachable. - if (!AI.Partial) - Solver.markFunctionUnreachable(AI.Fn); + if (F->getNumUses() == 0 || all_of(F->users(), [F](User *U) { + if (auto *CS = dyn_cast<CallBase>(U)) + return CS->getFunction() == F; + return false; + })) { + Solver.markFunctionUnreachable(F); + FullySpecialized.insert(F); + } } /// Compute and return the cost of specializing function \p F. InstructionCost getSpecializationCost(Function *F) { - // Compute the code metrics for the function. - SmallPtrSet<const Value *, 32> EphValues; - CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues); - CodeMetrics Metrics; - for (BasicBlock &BB : *F) - Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues); - + CodeMetrics &Metrics = analyzeFunction(F); // If the code metrics reveal that we shouldn't duplicate the function, we // shouldn't specialize it. Set the specialization cost to Invalid. // Or if the lines of codes implies that this function is easy to get // inlined so that we shouldn't specialize it. - if (Metrics.notDuplicatable || + if (Metrics.notDuplicatable || !Metrics.NumInsts.isValid() || (!ForceFunctionSpecialization && - Metrics.NumInsts < SmallFunctionThreshold)) { + *Metrics.NumInsts.getValue() < SmallFunctionThreshold)) { InstructionCost C{}; C.setInvalid(); return C; @@ -539,31 +593,20 @@ private: DominatorTree DT(*F); LoopInfo LI(DT); auto &TTI = (GetTTI)(*F); - LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for: " << *A - << "\n"); + LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: " + << C->getNameOrAsOperand() << "\n"); InstructionCost TotalCost = 0; for (auto *U : A->users()) { TotalCost += getUserBonus(U, TTI, LI); - LLVM_DEBUG(dbgs() << "FnSpecialization: User cost "; + LLVM_DEBUG(dbgs() << "FnSpecialization: User cost "; TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n"); } // The below heuristic is only concerned with exposing inlining // opportunities via indirect call promotion. If the argument is not a - // function pointer, give up. - if (!isa<PointerType>(A->getType()) || - !isa<FunctionType>(A->getType()->getPointerElementType())) - return TotalCost; - - // Since the argument is a function pointer, its incoming constant values - // should be functions or constant expressions. The code below attempts to - // look through cast expressions to find the function that will be called. 
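
Since the bonus above is all about turning an indirect call through a constant function pointer into a direct, inlinable call, here is an invented end-to-end source example; with this patch a single clone can also bind several constant arguments from the same call site at once.

static long compute(long (*op)(long), long scale, long x) {
  return op(x) * scale;
}

long twice(long v) { return 2 * v; }

long user(long x) {
  // One call site binds two interesting constants: op = &twice and scale = 10.
  // A profitable specialization clones 'compute' for that pair, which turns
  // the indirect call into a direct (and inlinable) one and lets the multiply
  // by the now-known scale be folded.
  return compute(&twice, 10, x);
}
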
- Value *CalledValue = C; - while (isa<ConstantExpr>(CalledValue) && - cast<ConstantExpr>(CalledValue)->isCast()) - CalledValue = cast<User>(CalledValue)->getOperand(0); - Function *CalledFunction = dyn_cast<Function>(CalledValue); + // (potentially casted) function pointer, give up. + Function *CalledFunction = dyn_cast<Function>(C->stripPointerCasts()); if (!CalledFunction) return TotalCost; @@ -603,6 +646,9 @@ private: Bonus += Params.DefaultThreshold; else if (IC.isVariable() && IC.getCostDelta() > 0) Bonus += IC.getCostDelta(); + + LLVM_DEBUG(dbgs() << "FnSpecialization: Inlining bonus " << Bonus + << " for user " << *U << "\n"); } return TotalCost + Bonus; @@ -615,15 +661,12 @@ private: /// specializing the function based on the incoming values of argument \p A /// would result in any significant optimization opportunities. If /// optimization opportunities exist, the constant values of \p A on which to - /// specialize the function are collected in \p Constants. If the values in - /// \p Constants represent the complete set of values that \p A can take on, - /// the function will be completely specialized, and the \p IsPartial flag is - /// set to false. + /// specialize the function are collected in \p Constants. /// /// \returns true if the function should be specialized on the given /// argument. - bool isArgumentInteresting(Argument *A, ConstList &Constants, - bool &IsPartial) { + bool isArgumentInteresting(Argument *A, + SmallVectorImpl<CallArgBinding> &Constants) { // For now, don't attempt to specialize functions based on the values of // composite types. if (!A->getType()->isSingleValueType() || A->user_empty()) @@ -632,8 +675,9 @@ private: // If the argument isn't overdefined, there's nothing to do. It should // already be constant. if (!Solver.getLatticeValueFor(A).isOverdefined()) { - LLVM_DEBUG(dbgs() << "FnSpecialization: nothing to do, arg is already " - << "constant?\n"); + LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, argument " + << A->getNameOrAsOperand() + << " is already constant?\n"); return false; } @@ -650,20 +694,26 @@ private: // // TODO 2: this currently does not support constants, i.e. integer ranges. // - IsPartial = !getPossibleConstants(A, Constants); - LLVM_DEBUG(dbgs() << "FnSpecialization: interesting arg: " << *A << "\n"); + getPossibleConstants(A, Constants); + + if (Constants.empty()) + return false; + + LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting argument " + << A->getNameOrAsOperand() << "\n"); return true; } /// Collect in \p Constants all the constant values that argument \p A can /// take on. - /// - /// \returns true if all of the values the argument can take on are constant - /// (e.g., the argument's parent function cannot be called with an - /// overdefined value). - bool getPossibleConstants(Argument *A, ConstList &Constants) { + void getPossibleConstants(Argument *A, + SmallVectorImpl<CallArgBinding> &Constants) { Function *F = A->getParent(); - bool AllConstant = true; + + // SCCP solver does not record an argument that will be constructed on + // stack. + if (A->hasByValAttr() && !F->onlyReadsMemory()) + return; // Iterate over all the call sites of the argument's parent function. for (User *U : F->users()) { @@ -672,10 +722,8 @@ private: auto &CS = *cast<CallBase>(U); // If the call site has attribute minsize set, that callsite won't be // specialized. 
- if (CS.hasFnAttr(Attribute::MinSize)) { - AllConstant = false; + if (CS.hasFnAttr(Attribute::MinSize)) continue; - } // If the parent of the call site will never be executed, we don't need // to worry about the passed value. @@ -684,13 +732,7 @@ private: auto *V = CS.getArgOperand(A->getArgNo()); if (isa<PoisonValue>(V)) - return false; - - // For now, constant expressions are fine but only if they are function - // calls. - if (auto *CE = dyn_cast<ConstantExpr>(V)) - if (!isa<Function>(CE->getOperand(0))) - return false; + return; // TrackValueOfGlobalVariable only tracks scalar global variables. if (auto *GV = dyn_cast<GlobalVariable>(V)) { @@ -698,36 +740,32 @@ private: // global values. if (!GV->isConstant()) if (!SpecializeOnAddresses) - return false; + return; if (!GV->getValueType()->isSingleValueType()) - return false; + return; } if (isa<Constant>(V) && (Solver.getLatticeValueFor(V).isConstant() || EnableSpecializationForLiteralConstant)) - Constants.push_back(cast<Constant>(V)); - else - AllConstant = false; + Constants.push_back({&CS, cast<Constant>(V)}); } - - // If the argument can only take on constant values, AllConstant will be - // true. - return AllConstant; } /// Rewrite calls to function \p F to call function \p Clone instead. /// - /// This function modifies calls to function \p F whose argument at index \p - /// ArgNo is equal to constant \p C. The calls are rewritten to call function - /// \p Clone instead. + /// This function modifies calls to function \p F as long as the actual + /// arguments match those in \p Args. Note that for recursive calls we + /// need to compare against the cloned formal arguments. /// /// Callsites that have been marked with the MinSize function attribute won't /// be specialized and rewritten. - void rewriteCallSites(Function *F, Function *Clone, Argument &Arg, - Constant *C) { - unsigned ArgNo = Arg.getArgNo(); - SmallVector<CallBase *, 4> CallSitesToRewrite; + void rewriteCallSites(Function *Clone, const SmallVectorImpl<ArgInfo> &Args, + ValueToValueMapTy &Mappings) { + assert(!Args.empty() && "Specialization without arguments"); + Function *F = Args[0].Formal->getParent(); + + SmallVector<CallBase *, 8> CallSitesToRewrite; for (auto *U : F->users()) { if (!isa<CallInst>(U) && !isa<InvokeInst>(U)) continue; @@ -736,35 +774,50 @@ private: continue; CallSitesToRewrite.push_back(&CS); } + + LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing call sites of " + << F->getName() << " with " << Clone->getName() << "\n"); + for (auto *CS : CallSitesToRewrite) { - if ((CS->getFunction() == Clone && CS->getArgOperand(ArgNo) == &Arg) || - CS->getArgOperand(ArgNo) == C) { + LLVM_DEBUG(dbgs() << "FnSpecialization: " + << CS->getFunction()->getName() << " ->" << *CS + << "\n"); + if (/* recursive call */ + (CS->getFunction() == Clone && + all_of(Args, + [CS, &Mappings](const ArgInfo &Arg) { + unsigned ArgNo = Arg.Formal->getArgNo(); + return CS->getArgOperand(ArgNo) == Mappings[Arg.Formal]; + })) || + /* normal call */ + all_of(Args, [CS](const ArgInfo &Arg) { + unsigned ArgNo = Arg.Formal->getArgNo(); + return CS->getArgOperand(ArgNo) == Arg.Actual; + })) { CS->setCalledFunction(Clone); Solver.markOverdefined(CS); } } } - void updateSpecializedFuncs(FuncList &FuncDecls, - FuncList &CurrentSpecializations) { - for (auto *SpecializedFunc : CurrentSpecializations) { - SpecializedFuncs.insert(SpecializedFunc); + void updateSpecializedFuncs(FuncList &Candidates, FuncList &WorkList) { + for (auto *F : WorkList) { + SpecializedFuncs.insert(F); // 
Initialize the state of the newly created functions, marking them // argument-tracked and executable. - if (SpecializedFunc->hasExactDefinition() && - !SpecializedFunc->hasFnAttribute(Attribute::Naked)) - Solver.addTrackedFunction(SpecializedFunc); + if (F->hasExactDefinition() && !F->hasFnAttribute(Attribute::Naked)) + Solver.addTrackedFunction(F); - Solver.addArgumentTrackedFunction(SpecializedFunc); - FuncDecls.push_back(SpecializedFunc); - Solver.markBlockExecutable(&SpecializedFunc->front()); + Solver.addArgumentTrackedFunction(F); + Candidates.push_back(F); + Solver.markBlockExecutable(&F->front()); // Replace the function arguments for the specialized functions. - for (Argument &Arg : SpecializedFunc->args()) + for (Argument &Arg : F->args()) if (!Arg.use_empty() && tryToReplaceWithConstant(&Arg)) LLVM_DEBUG(dbgs() << "FnSpecialization: Replaced constant argument: " - << Arg.getName() << "\n"); + << Arg.getNameOrAsOperand() << "\n"); } } }; @@ -871,22 +924,26 @@ bool llvm::runFunctionSpecialization( // Initially resolve the constants in all the argument tracked functions. RunSCCPSolver(FuncDecls); - SmallVector<Function *, 2> CurrentSpecializations; + SmallVector<Function *, 8> WorkList; unsigned I = 0; while (FuncSpecializationMaxIters != I++ && - FS.specializeFunctions(FuncDecls, CurrentSpecializations)) { + FS.specializeFunctions(FuncDecls, WorkList)) { + LLVM_DEBUG(dbgs() << "FnSpecialization: Finished iteration " << I << "\n"); // Run the solver for the specialized functions. - RunSCCPSolver(CurrentSpecializations); + RunSCCPSolver(WorkList); // Replace some unresolved constant arguments. constantArgPropagation(FuncDecls, M, Solver); - CurrentSpecializations.clear(); + WorkList.clear(); Changed = true; } - // Clean up the IR by removing ssa_copy intrinsics. + LLVM_DEBUG(dbgs() << "FnSpecialization: Number of specializations = " + << NumFuncSpecialized << "\n"); + + // Remove any ssa_copy intrinsics that may have been introduced. removeSSACopy(M); return Changed; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalDCE.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalDCE.cpp index 5e5d2086adc2..f35827220bb6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalDCE.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalDCE.cpp @@ -21,7 +21,6 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Operator.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" @@ -34,7 +33,7 @@ using namespace llvm; #define DEBUG_TYPE "globaldce" static cl::opt<bool> - ClEnableVFE("enable-vfe", cl::Hidden, cl::init(true), cl::ZeroOrMore, + ClEnableVFE("enable-vfe", cl::Hidden, cl::init(true), cl::desc("Enable virtual function elimination")); STATISTIC(NumAliases , "Number of global aliases removed"); @@ -86,6 +85,9 @@ ModulePass *llvm::createGlobalDCEPass() { /// Returns true if F is effectively empty. static bool isEmptyFunction(Function *F) { + // Skip external functions. 
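
For the ctor cleanup handled here, an invented source-level case: once earlier passes fold a dynamic initializer down to nothing observable, the function left in llvm.global_ctors is "effectively empty" in the sense checked by isEmptyFunction (nothing but debug or pseudo instructions before the return), and its llvm.global_ctors entry can be dropped; the new isDeclaration() guard that follows simply avoids asking a bodiless declaration for its entry block.

struct Tracer {
  Tracer() {}          // user-provided but does nothing observable
};
Tracer t;              // its registered initializer typically reduces to a
                       // bare return after inlining, leaving an empty ctor
                       // entry for GlobalDCE to remove
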
+ if (F->isDeclaration()) + return false; BasicBlock &Entry = F->getEntryBlock(); for (auto &I : Entry) { if (I.isDebugOrPseudoInst()) @@ -214,14 +216,14 @@ void GlobalDCEPass::ScanVTableLoad(Function *Caller, Metadata *TypeId, if (!Ptr) { LLVM_DEBUG(dbgs() << "can't find pointer in vtable!\n"); VFESafeVTables.erase(VTable); - return; + continue; } auto Callee = dyn_cast<Function>(Ptr->stripPointerCasts()); if (!Callee) { LLVM_DEBUG(dbgs() << "vtable entry is not function pointer!\n"); VFESafeVTables.erase(VTable); - return; + continue; } LLVM_DEBUG(dbgs() << "vfunc dep " << Caller->getName() << " -> " @@ -298,7 +300,8 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { // marked as alive are discarded. // Remove empty functions from the global ctors list. - Changed |= optimizeGlobalCtorsList(M, isEmptyFunction); + Changed |= optimizeGlobalCtorsList( + M, [](uint32_t, Function *F) { return isEmptyFunction(F); }); // Collect the set of members for each comdat. for (Function &F : M) @@ -317,7 +320,7 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { // Loop over the module, adding globals which are obviously necessary. for (GlobalObject &GO : M.global_objects()) { - Changed |= RemoveUnusedGlobalValue(GO); + GO.removeDeadConstantUsers(); // Functions with external linkage are needed if they have a body. // Externally visible & appending globals are needed, if they have an // initializer. @@ -330,7 +333,7 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { // Compute direct dependencies of aliases. for (GlobalAlias &GA : M.aliases()) { - Changed |= RemoveUnusedGlobalValue(GA); + GA.removeDeadConstantUsers(); // Externally visible aliases are needed. if (!GA.isDiscardableIfUnused()) MarkLive(GA); @@ -340,7 +343,7 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { // Compute direct dependencies of ifuncs. for (GlobalIFunc &GIF : M.ifuncs()) { - Changed |= RemoveUnusedGlobalValue(GIF); + GIF.removeDeadConstantUsers(); // Externally visible ifuncs are needed. if (!GIF.isDiscardableIfUnused()) MarkLive(GIF); @@ -403,7 +406,7 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { // Now that all interferences have been dropped, delete the actual objects // themselves. auto EraseUnusedGlobalValue = [&](GlobalValue *GV) { - RemoveUnusedGlobalValue(*GV); + GV->removeDeadConstantUsers(); GV->eraseFromParent(); Changed = true; }; @@ -455,16 +458,3 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { return PreservedAnalyses::none(); return PreservedAnalyses::all(); } - -// RemoveUnusedGlobalValue - Loop over all of the uses of the specified -// GlobalValue, looking for the constant pointer ref that may be pointing to it. -// If found, check to see if the constant pointer ref is safe to destroy, and if -// so, nuke it. This will reduce the reference count on the global value, which -// might make it deader. 
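
The vtable scan above supports virtual function elimination; switching from return to continue keeps scanning the remaining vtables associated with the same type id after one problematic vtable has been dropped from VFESafeVTables, instead of abandoning the whole query. As a rough source-level picture (invented, and dependent on building with whole-program visibility so the type and !vcall_visibility metadata are present):

struct Shape {
  virtual ~Shape() = default;
  virtual double area() const = 0;
  virtual void debugDump() const;   // never invoked through any vtable
};

// If every use of Shape is visible and debugDump() is never called virtually,
// its slot is never loaded from any vtable, so virtual function elimination
// can discard its body even though the vtable itself stays live.
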
-// -bool GlobalDCEPass::RemoveUnusedGlobalValue(GlobalValue &GV) { - if (GV.use_empty()) - return false; - GV.removeDeadConstantUsers(); - return GV.use_empty(); -} diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp index 1cb32e32c895..1a1bde4f0668 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/iterator_range.h" @@ -37,7 +38,6 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" -#include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" @@ -60,7 +60,6 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/CtorUtils.h" @@ -100,7 +99,7 @@ static cl::opt<bool> cl::init(false), cl::Hidden); static cl::opt<int> ColdCCRelFreq( - "coldcc-rel-freq", cl::Hidden, cl::init(2), cl::ZeroOrMore, + "coldcc-rel-freq", cl::Hidden, cl::init(2), cl::desc( "Maximum block frequency, expressed as a percentage of caller's " "entry frequency, for a call site to be considered cold for enabling" @@ -232,7 +231,7 @@ CleanupPointerRootUsers(GlobalVariable *GV, if (MemSrc && MemSrc->isConstant()) { Changed = true; MTI->eraseFromParent(); - } else if (Instruction *I = dyn_cast<Instruction>(MemSrc)) { + } else if (Instruction *I = dyn_cast<Instruction>(MTI->getSource())) { if (I->hasOneUse()) Dead.push_back(std::make_pair(I, MTI)); } @@ -405,9 +404,37 @@ static void transferSRADebugInfo(GlobalVariable *GV, GlobalVariable *NGV, for (auto *GVE : GVs) { DIVariable *Var = GVE->getVariable(); DIExpression *Expr = GVE->getExpression(); + int64_t CurVarOffsetInBytes = 0; + uint64_t CurVarOffsetInBits = 0; + + // Calculate the offset (Bytes), Continue if unknown. + if (!Expr->extractIfOffset(CurVarOffsetInBytes)) + continue; + + // Ignore negative offset. + if (CurVarOffsetInBytes < 0) + continue; + + // Convert offset to bits. + CurVarOffsetInBits = CHAR_BIT * (uint64_t)CurVarOffsetInBytes; + + // Current var starts after the fragment, ignore. + if (CurVarOffsetInBits >= (FragmentOffsetInBits + FragmentSizeInBits)) + continue; + + uint64_t CurVarSize = Var->getType()->getSizeInBits(); + // Current variable ends before start of fragment, ignore. + if (CurVarSize != 0 && + (CurVarOffsetInBits + CurVarSize) <= FragmentOffsetInBits) + continue; + + // Current variable fits in the fragment. + if (CurVarOffsetInBits == FragmentOffsetInBits && + CurVarSize == FragmentSizeInBits) + Expr = DIExpression::get(Expr->getContext(), {}); // If the FragmentSize is smaller than the variable, // emit a fragment expression. - if (FragmentSizeInBits < VarSize) { + else if (FragmentSizeInBits < VarSize) { if (auto E = DIExpression::createFragmentExpression( Expr, FragmentOffsetInBits, FragmentSizeInBits)) Expr = *E; @@ -581,17 +608,14 @@ static bool AllUsesOfValueWillTrapIfNull(const Value *V, // Will trap. 
} else if (const StoreInst *SI = dyn_cast<StoreInst>(U)) { if (SI->getOperand(0) == V) { - //cerr << "NONTRAPPING USE: " << *U; return false; // Storing the value. } } else if (const CallInst *CI = dyn_cast<CallInst>(U)) { if (CI->getCalledOperand() != V) { - //cerr << "NONTRAPPING USE: " << *U; return false; // Not calling the ptr } } else if (const InvokeInst *II = dyn_cast<InvokeInst>(U)) { if (II->getCalledOperand() != V) { - //cerr << "NONTRAPPING USE: " << *U; return false; // Not calling the ptr } } else if (const BitCastInst *CI = dyn_cast<BitCastInst>(U)) { @@ -615,7 +639,6 @@ static bool AllUsesOfValueWillTrapIfNull(const Value *V, // the comparing of the value of the created global init bool later in // optimizeGlobalAddressOfAllocation for the global variable. } else { - //cerr << "NONTRAPPING USE: " << *U; return false; } } @@ -878,7 +901,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI, } } - SmallPtrSet<Constant *, 1> RepValues; + SmallSetVector<Constant *, 1> RepValues; RepValues.insert(NewGV); // If there is a comparison against null, we will insert a global bool to @@ -1015,7 +1038,6 @@ valueIsOnlyUsedLocallyOrStoredToOneGlobal(const CallInst *CI, /// accessing the data, and exposes the resultant global to further GlobalOpt. static bool tryToOptimizeStoreOfAllocationToGlobal(GlobalVariable *GV, CallInst *CI, - AtomicOrdering Ordering, const DataLayout &DL, TargetLibraryInfo *TLI) { if (!isAllocRemovable(CI, TLI)) @@ -1062,7 +1084,7 @@ static bool tryToOptimizeStoreOfAllocationToGlobal(GlobalVariable *GV, // its initializer) is ever stored to the global. static bool optimizeOnceStoredGlobal(GlobalVariable *GV, Value *StoredOnceVal, - AtomicOrdering Ordering, const DataLayout &DL, + const DataLayout &DL, function_ref<TargetLibraryInfo &(Function &)> GetTLI) { // Ignore no-op GEPs and bitcasts. StoredOnceVal = StoredOnceVal->stripPointerCasts(); @@ -1087,7 +1109,7 @@ optimizeOnceStoredGlobal(GlobalVariable *GV, Value *StoredOnceVal, } else if (isAllocationFn(StoredOnceVal, GetTLI)) { if (auto *CI = dyn_cast<CallInst>(StoredOnceVal)) { auto *TLI = &GetTLI(*CI->getFunction()); - if (tryToOptimizeStoreOfAllocationToGlobal(GV, CI, Ordering, DL, TLI)) + if (tryToOptimizeStoreOfAllocationToGlobal(GV, CI, DL, TLI)) return true; } } @@ -1257,8 +1279,10 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { return true; } -static bool deleteIfDead( - GlobalValue &GV, SmallPtrSetImpl<const Comdat *> &NotDiscardableComdats) { +static bool +deleteIfDead(GlobalValue &GV, + SmallPtrSetImpl<const Comdat *> &NotDiscardableComdats, + function_ref<void(Function &)> DeleteFnCallback = nullptr) { GV.removeDeadConstantUsers(); if (!GV.isDiscardableIfUnused() && !GV.isDeclaration()) @@ -1277,6 +1301,10 @@ static bool deleteIfDead( return false; LLVM_DEBUG(dbgs() << "GLOBAL DEAD: " << GV << "\n"); + if (auto *F = dyn_cast<Function>(&GV)) { + if (DeleteFnCallback) + DeleteFnCallback(*F); + } GV.eraseFromParent(); ++NumDeleted; return true; @@ -1416,6 +1444,42 @@ static void makeAllConstantUsesInstructions(Constant *C) { } } +// For a global variable with one store, if the store dominates any loads, +// those loads will always load the stored value (as opposed to the +// initializer), even in the presence of recursion. 
+static bool forwardStoredOnceStore( + GlobalVariable *GV, const StoreInst *StoredOnceStore, + function_ref<DominatorTree &(Function &)> LookupDomTree) { + const Value *StoredOnceValue = StoredOnceStore->getValueOperand(); + // We can do this optimization for non-constants in nosync + norecurse + // functions, but globals used in exactly one norecurse functions are already + // promoted to an alloca. + if (!isa<Constant>(StoredOnceValue)) + return false; + const Function *F = StoredOnceStore->getFunction(); + SmallVector<LoadInst *> Loads; + for (User *U : GV->users()) { + if (auto *LI = dyn_cast<LoadInst>(U)) { + if (LI->getFunction() == F && + LI->getType() == StoredOnceValue->getType() && LI->isSimple()) + Loads.push_back(LI); + } + } + // Only compute DT if we have any loads to examine. + bool MadeChange = false; + if (!Loads.empty()) { + auto &DT = LookupDomTree(*const_cast<Function *>(F)); + for (auto *LI : Loads) { + if (DT.dominates(StoredOnceStore, LI)) { + LI->replaceAllUsesWith(const_cast<Value *>(StoredOnceValue)); + LI->eraseFromParent(); + MadeChange = true; + } + } + } + return MadeChange; +} + /// Analyze the specified global variable and optimize /// it if possible. If we make a change, return true. static bool @@ -1572,9 +1636,15 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, // Try to optimize globals based on the knowledge that only one value // (besides its initializer) is ever stored to the global. - if (optimizeOnceStoredGlobal(GV, StoredOnceValue, GS.Ordering, DL, GetTLI)) + if (optimizeOnceStoredGlobal(GV, StoredOnceValue, DL, GetTLI)) return true; + // Try to forward the store to any loads. If we have more than one store, we + // may have a store of the initializer between StoredOnceStore and a load. + if (GS.NumStores == 1) + if (forwardStoredOnceStore(GV, GS.StoredOnceStore, LookupDomTree)) + return true; + // Otherwise, if the global was not a boolean, we can shrink it to be a // boolean. Skip this optimization for AS that doesn't allow an initializer. if (SOVConstant && GS.Ordering == AtomicOrdering::NotAtomic && @@ -1755,7 +1825,7 @@ hasOnlyColdCalls(Function &F, return false; if (!CalledFn->hasLocalLinkage()) return false; - // Skip over instrinsics since they won't remain as function calls. + // Skip over intrinsics since they won't remain as function calls. if (CalledFn->getIntrinsicID() != Intrinsic::not_intrinsic) continue; // Check if it's valid to use coldcc calling convention. @@ -1884,7 +1954,9 @@ OptimizeFunctions(Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI, function_ref<BlockFrequencyInfo &(Function &)> GetBFI, function_ref<DominatorTree &(Function &)> LookupDomTree, - SmallPtrSetImpl<const Comdat *> &NotDiscardableComdats) { + SmallPtrSetImpl<const Comdat *> &NotDiscardableComdats, + function_ref<void(Function &F)> ChangedCFGCallback, + function_ref<void(Function &F)> DeleteFnCallback) { bool Changed = false; @@ -1904,7 +1976,7 @@ OptimizeFunctions(Module &M, if (!F.hasName() && !F.isDeclaration() && !F.hasLocalLinkage()) F.setLinkage(GlobalValue::InternalLinkage); - if (deleteIfDead(F, NotDiscardableComdats)) { + if (deleteIfDead(F, NotDiscardableComdats, DeleteFnCallback)) { Changed = true; continue; } @@ -1917,13 +1989,11 @@ OptimizeFunctions(Module &M, // So, remove unreachable blocks from the function, because a) there's // no point in analyzing them and b) GlobalOpt should otherwise grow // some more complicated logic to break these cycles. 
- // Removing unreachable blocks might invalidate the dominator so we - // recalculate it. + // Notify the analysis manager that we've modified the function's CFG. if (!F.isDeclaration()) { if (removeUnreachableBlocks(F)) { - auto &DT = LookupDomTree(F); - DT.recalculate(F); Changed = true; + ChangedCFGCallback(F); } } @@ -2031,6 +2101,9 @@ OptimizeGlobalVars(Module &M, /// can, false otherwise. static bool EvaluateStaticConstructor(Function *F, const DataLayout &DL, TargetLibraryInfo *TLI) { + // Skip external functions. + if (F->isDeclaration()) + return false; // Call the function. Evaluator Eval(DL, TLI); Constant *RetValDummy; @@ -2383,15 +2456,19 @@ static bool OptimizeEmptyGlobalCXXDtors(Function *CXAAtExitFn) { return Changed; } -static bool optimizeGlobalsInModule( - Module &M, const DataLayout &DL, - function_ref<TargetLibraryInfo &(Function &)> GetTLI, - function_ref<TargetTransformInfo &(Function &)> GetTTI, - function_ref<BlockFrequencyInfo &(Function &)> GetBFI, - function_ref<DominatorTree &(Function &)> LookupDomTree) { +static bool +optimizeGlobalsInModule(Module &M, const DataLayout &DL, + function_ref<TargetLibraryInfo &(Function &)> GetTLI, + function_ref<TargetTransformInfo &(Function &)> GetTTI, + function_ref<BlockFrequencyInfo &(Function &)> GetBFI, + function_ref<DominatorTree &(Function &)> LookupDomTree, + function_ref<void(Function &F)> ChangedCFGCallback, + function_ref<void(Function &F)> DeleteFnCallback) { SmallPtrSet<const Comdat *, 8> NotDiscardableComdats; bool Changed = false; bool LocalChange = true; + Optional<uint32_t> FirstNotFullyEvaluatedPriority; + while (LocalChange) { LocalChange = false; @@ -2411,12 +2488,20 @@ static bool optimizeGlobalsInModule( // Delete functions that are trivially dead, ccc -> fastcc LocalChange |= OptimizeFunctions(M, GetTLI, GetTTI, GetBFI, LookupDomTree, - NotDiscardableComdats); + NotDiscardableComdats, ChangedCFGCallback, + DeleteFnCallback); // Optimize global_ctors list. - LocalChange |= optimizeGlobalCtorsList(M, [&](Function *F) { - return EvaluateStaticConstructor(F, DL, &GetTLI(*F)); - }); + LocalChange |= + optimizeGlobalCtorsList(M, [&](uint32_t Priority, Function *F) { + if (FirstNotFullyEvaluatedPriority && + *FirstNotFullyEvaluatedPriority != Priority) + return false; + bool Evaluated = EvaluateStaticConstructor(F, DL, &GetTLI(*F)); + if (!Evaluated) + FirstNotFullyEvaluatedPriority = Priority; + return Evaluated; + }); // Optimize non-address-taken globals. LocalChange |= OptimizeGlobalVars(M, GetTTI, GetTLI, LookupDomTree, @@ -2457,10 +2542,23 @@ PreservedAnalyses GlobalOptPass::run(Module &M, ModuleAnalysisManager &AM) { auto GetBFI = [&FAM](Function &F) -> BlockFrequencyInfo & { return FAM.getResult<BlockFrequencyAnalysis>(F); }; + auto ChangedCFGCallback = [&FAM](Function &F) { + FAM.invalidate(F, PreservedAnalyses::none()); + }; + auto DeleteFnCallback = [&FAM](Function &F) { FAM.clear(F, F.getName()); }; - if (!optimizeGlobalsInModule(M, DL, GetTLI, GetTTI, GetBFI, LookupDomTree)) + if (!optimizeGlobalsInModule(M, DL, GetTLI, GetTTI, GetBFI, LookupDomTree, + ChangedCFGCallback, DeleteFnCallback)) return PreservedAnalyses::all(); - return PreservedAnalyses::none(); + + PreservedAnalyses PA = PreservedAnalyses::none(); + // We made sure to clear analyses for deleted functions. + PA.preserve<FunctionAnalysisManagerModuleProxy>(); + // The only place we modify the CFG is when calling + // removeUnreachableBlocks(), but there we make sure to invalidate analyses + // for modified functions. 
+ PA.preserveSet<CFGAnalyses>(); + return PA; } namespace { @@ -2491,8 +2589,13 @@ struct GlobalOptLegacyPass : public ModulePass { return this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); }; - return optimizeGlobalsInModule(M, DL, GetTLI, GetTTI, GetBFI, - LookupDomTree); + auto ChangedCFGCallback = [&LookupDomTree](Function &F) { + auto &DT = LookupDomTree(F); + DT.recalculate(F); + }; + + return optimizeGlobalsInModule(M, DL, GetTLI, GetTTI, GetBFI, LookupDomTree, + ChangedCFGCallback, nullptr); } void getAnalysisUsage(AnalysisUsage &AU) const override { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalSplit.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalSplit.cpp index e7d698c42fcf..7d9e6135b2eb 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalSplit.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalSplit.cpp @@ -134,9 +134,9 @@ static bool splitGlobal(GlobalVariable &GV) { } // Finally, remove the original global. Any remaining uses refer to invalid - // elements of the global, so replace with undef. + // elements of the global, so replace with poison. if (!GV.use_empty()) - GV.replaceAllUsesWith(UndefValue::get(GV.getType())); + GV.replaceAllUsesWith(PoisonValue::get(GV.getType())); GV.eraseFromParent(); return true; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp index a964fcde0396..95e8ae0fd22f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp @@ -29,46 +29,33 @@ #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfo.h" -#include "llvm/Analysis/BranchProbabilityInfo.h" -#include "llvm/Analysis/CFG.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" -#include "llvm/Support/BlockFrequency.h" -#include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" -#include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/CodeExtractor.h" -#include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> -#include <limits> #include <cassert> +#include <limits> #include <string> #define DEBUG_TYPE "hotcoldsplit" @@ -126,7 +113,8 @@ bool unlikelyExecuted(BasicBlock &BB) { // mark sanitizer traps as cold. 
for (Instruction &I : BB) if (auto *CB = dyn_cast<CallBase>(&I)) - if (CB->hasFnAttr(Attribute::Cold) && !CB->getMetadata("nosanitize")) + if (CB->hasFnAttr(Attribute::Cold) && + !CB->getMetadata(LLVMContext::MD_nosanitize)) return true; // The block is cold if it has an unreachable terminator, unless it's @@ -352,7 +340,7 @@ Function *HotColdSplitting::extractColdRegion( // TODO: Pass BFI and BPI to update profile information. CodeExtractor CE(Region, &DT, /* AggregateArgs */ false, /* BFI */ nullptr, /* BPI */ nullptr, AC, /* AllowVarArgs */ false, - /* AllowAlloca */ false, + /* AllowAlloca */ false, /* AllocaBlock */ nullptr, /* Suffix */ "cold." + std::to_string(Count)); // Perform a simple cost/benefit analysis to decide whether or not to permit @@ -740,7 +728,7 @@ bool HotColdSplittingLegacyPass::runOnModule(Module &M) { std::function<OptimizationRemarkEmitter &(Function &)> GetORE = [&ORE](Function &F) -> OptimizationRemarkEmitter & { ORE.reset(new OptimizationRemarkEmitter(&F)); - return *ORE.get(); + return *ORE; }; auto LookupAC = [this](Function &F) -> AssumptionCache * { if (auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>()) @@ -772,7 +760,7 @@ HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) { std::function<OptimizationRemarkEmitter &(Function &)> GetORE = [&ORE](Function &F) -> OptimizationRemarkEmitter & { ORE.reset(new OptimizationRemarkEmitter(&F)); - return *ORE.get(); + return *ORE; }; ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/IPO.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/IPO.cpp index de1c1d379502..ec2b80012ed6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/IPO.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/IPO.cpp @@ -24,7 +24,6 @@ using namespace llvm; void llvm::initializeIPO(PassRegistry &Registry) { initializeOpenMPOptCGSCCLegacyPassPass(Registry); - initializeArgPromotionPass(Registry); initializeAnnotation2MetadataLegacyPass(Registry); initializeCalledValuePropagationLegacyPassPass(Registry); initializeConstantMergeLegacyPassPass(Registry); @@ -70,10 +69,6 @@ void LLVMInitializeIPO(LLVMPassRegistryRef R) { initializeIPO(*unwrap(R)); } -void LLVMAddArgumentPromotionPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createArgumentPromotionPass()); -} - void LLVMAddCalledValuePropagationPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createCalledValuePropagationPass()); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp index faf7cb7d566a..d75d99e307fd 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp @@ -16,8 +16,9 @@ #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Attributes.h" -#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DIBuilder.h" +#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Mangler.h" #include "llvm/IR/PassManager.h" @@ -25,8 +26,6 @@ #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" -#include <map> -#include <set> #include <vector> #define DEBUG_TYPE "iroutliner" @@ -183,11 +182,24 @@ static void getSortedConstantKeys(std::vector<Value *> &SortedKeys, Value *OutlinableRegion::findCorrespondingValueIn(const OutlinableRegion &Other, Value *V) { 
Optional<unsigned> GVN = Candidate->getGVN(V); - assert(GVN.hasValue() && "No GVN for incoming value"); + assert(GVN && "No GVN for incoming value"); Optional<unsigned> CanonNum = Candidate->getCanonicalNum(*GVN); Optional<unsigned> FirstGVN = Other.Candidate->fromCanonicalNum(*CanonNum); Optional<Value *> FoundValueOpt = Other.Candidate->fromGVN(*FirstGVN); - return FoundValueOpt.getValueOr(nullptr); + return FoundValueOpt.value_or(nullptr); +} + +BasicBlock * +OutlinableRegion::findCorrespondingBlockIn(const OutlinableRegion &Other, + BasicBlock *BB) { + Instruction *FirstNonPHI = BB->getFirstNonPHI(); + assert(FirstNonPHI && "block is empty?"); + Value *CorrespondingVal = findCorrespondingValueIn(Other, FirstNonPHI); + if (!CorrespondingVal) + return nullptr; + BasicBlock *CorrespondingBlock = + cast<Instruction>(CorrespondingVal)->getParent(); + return CorrespondingBlock; } /// Rewrite the BranchInsts in the incoming blocks to \p PHIBlock that are found @@ -264,13 +276,33 @@ void OutlinableRegion::splitCandidate() { // We iterate over the instructions in the region, if we find a PHINode, we // check if there are predecessors outside of the region, if there are, // we ignore this region since we are unable to handle the severing of the - // phi node right now. + // phi node right now. + + // TODO: Handle extraneous inputs for PHINodes through variable number of + // inputs, similar to how outputs are handled. BasicBlock::iterator It = StartInst->getIterator(); + EndBB = BackInst->getParent(); + BasicBlock *IBlock; + BasicBlock *PHIPredBlock = nullptr; + bool EndBBTermAndBackInstDifferent = EndBB->getTerminator() != BackInst; while (PHINode *PN = dyn_cast<PHINode>(&*It)) { unsigned NumPredsOutsideRegion = 0; - for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (!BBSet.contains(PN->getIncomingBlock(i))) + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + if (!BBSet.contains(PN->getIncomingBlock(i))) { + PHIPredBlock = PN->getIncomingBlock(i); + ++NumPredsOutsideRegion; + continue; + } + + // We must consider the case there the incoming block to the PHINode is + // the same as the final block of the OutlinableRegion. If this is the + // case, the branch from this block must also be outlined to be valid. + IBlock = PN->getIncomingBlock(i); + if (IBlock == EndBB && EndBBTermAndBackInstDifferent) { + PHIPredBlock = PN->getIncomingBlock(i); ++NumPredsOutsideRegion; + } + } if (NumPredsOutsideRegion > 1) return; @@ -285,11 +317,9 @@ void OutlinableRegion::splitCandidate() { // If the region ends with a PHINode, but does not contain all of the phi node // instructions of the region, we ignore it for now. - if (isa<PHINode>(BackInst)) { - EndBB = BackInst->getParent(); - if (BackInst != &*std::prev(EndBB->getFirstInsertionPt())) - return; - } + if (isa<PHINode>(BackInst) && + BackInst != &*std::prev(EndBB->getFirstInsertionPt())) + return; // The basic block gets split like so: // block: block: @@ -310,6 +340,10 @@ void OutlinableRegion::splitCandidate() { StartBB = PrevBB->splitBasicBlock(StartInst, OriginalName + "_to_outline"); PrevBB->replaceSuccessorsPhiUsesWith(PrevBB, StartBB); + // If there was a PHINode with an incoming block outside the region, + // make sure is correctly updated in the newly split block. 
+ if (PHIPredBlock) + PrevBB->replaceSuccessorsPhiUsesWith(PHIPredBlock, PrevBB); CandidateSplit = true; if (!BackInst->isTerminator()) { @@ -353,6 +387,25 @@ void OutlinableRegion::reattachCandidate() { assert(StartBB != nullptr && "StartBB for Candidate is not defined!"); assert(PrevBB->getTerminator() && "Terminator removed from PrevBB!"); + // Make sure PHINode references to the block we are merging into are + // updated to be incoming blocks from the predecessor to the current block. + + // NOTE: If this is updated such that the outlined block can have more than + // one incoming block to a PHINode, this logic will have to updated + // to handle multiple precessors instead. + + // We only need to update this if the outlined section contains a PHINode, if + // it does not, then the incoming block was never changed in the first place. + // On the other hand, if PrevBB has no predecessors, it means that all + // incoming blocks to the first block are contained in the region, and there + // will be nothing to update. + Instruction *StartInst = (*Candidate->begin()).Inst; + if (isa<PHINode>(StartInst) && !PrevBB->hasNPredecessors(0)) { + assert(!PrevBB->hasNPredecessorsOrMore(2) && + "PrevBB has more than one predecessor. Should be 0 or 1."); + BasicBlock *BeforePrevBB = PrevBB->getSinglePredecessor(); + PrevBB->replaceSuccessorsPhiUsesWith(PrevBB, BeforePrevBB); + } PrevBB->getTerminator()->eraseFromParent(); // If we reattaching after outlining, we iterate over the phi nodes to @@ -501,7 +554,7 @@ collectRegionsConstants(OutlinableRegion &Region, // the the number has been found to be not the same value in each instance. for (Value *V : ID.OperVals) { Optional<unsigned> GVNOpt = C.getGVN(V); - assert(GVNOpt.hasValue() && "Expected a GVN for operand?"); + assert(GVNOpt && "Expected a GVN for operand?"); unsigned GVN = GVNOpt.getValue(); // Check if this global value has been found to not be the same already. @@ -516,7 +569,7 @@ collectRegionsConstants(OutlinableRegion &Region, // global value number. If the global value does not map to a Constant, // it is considered to not be the same value. Optional<bool> ConstantMatches = constantMatches(V, GVN, GVNToConstant); - if (ConstantMatches.hasValue()) { + if (ConstantMatches) { if (ConstantMatches.getValue()) continue; else @@ -597,7 +650,7 @@ Function *IROutliner::createFunction(Module &M, OutlinableGroup &Group, "outlined_ir_func_" + std::to_string(FunctionNameSuffix), M); // Transfer the swifterr attribute to the correct function parameter. - if (Group.SwiftErrorArgument.hasValue()) + if (Group.SwiftErrorArgument) Group.OutlinedFunction->addParamAttr(Group.SwiftErrorArgument.getValue(), Attribute::SwiftError); @@ -666,6 +719,18 @@ static void moveFunctionData(Function &Old, Function &New, if (!isa<CallInst>(&Val)) { // Remove the debug information for outlined functions. Val.setDebugLoc(DebugLoc()); + + // Loop info metadata may contain line locations. Update them to have no + // value in the new subprogram since the outlined code could be from + // several locations. 
+ auto updateLoopInfoLoc = [&New](Metadata *MD) -> Metadata * { + if (DISubprogram *SP = New.getSubprogram()) + if (auto *Loc = dyn_cast_or_null<DILocation>(MD)) + return DILocation::get(New.getContext(), Loc->getLine(), + Loc->getColumn(), SP, nullptr); + return MD; + }; + updateLoopMetadataDebugLocations(Val, updateLoopInfoLoc); continue; } @@ -691,8 +756,6 @@ static void moveFunctionData(Function &Old, Function &New, for (Instruction *I : DebugInsts) I->eraseFromParent(); } - - assert(NewEnds.size() > 0 && "No return instruction for new function?"); } /// Find the the constants that will need to be lifted into arguments @@ -714,7 +777,7 @@ static void findConstants(IRSimilarityCandidate &C, DenseSet<unsigned> &NotSame, for (Value *V : (*IDIt).OperVals) { // Since these are stored before any outlining, they will be in the // global value numbering. - unsigned GVN = C.getGVN(V).getValue(); + unsigned GVN = *C.getGVN(V); if (isa<Constant>(V)) if (NotSame.contains(GVN) && !Seen.contains(GVN)) { Inputs.push_back(GVN); @@ -745,8 +808,7 @@ static void mapInputsToGVNs(IRSimilarityCandidate &C, assert(Input && "Have a nullptr as an input"); if (OutputMappings.find(Input) != OutputMappings.end()) Input = OutputMappings.find(Input)->second; - assert(C.getGVN(Input).hasValue() && - "Could not find a numbering for the given input"); + assert(C.getGVN(Input) && "Could not find a numbering for the given input"); EndInputNumbers.push_back(C.getGVN(Input).getValue()); } } @@ -885,11 +947,11 @@ findExtractedInputToOverallInputMapping(OutlinableRegion &Region, // numbering overrides any discovered location for the extracted code. for (unsigned InputVal : InputGVNs) { Optional<unsigned> CanonicalNumberOpt = C.getCanonicalNum(InputVal); - assert(CanonicalNumberOpt.hasValue() && "Canonical number not found?"); + assert(CanonicalNumberOpt && "Canonical number not found?"); unsigned CanonicalNumber = CanonicalNumberOpt.getValue(); Optional<Value *> InputOpt = C.fromGVN(InputVal); - assert(InputOpt.hasValue() && "Global value number not found?"); + assert(InputOpt && "Global value number not found?"); Value *Input = InputOpt.getValue(); DenseMap<unsigned, unsigned>::iterator AggArgIt = @@ -901,7 +963,7 @@ findExtractedInputToOverallInputMapping(OutlinableRegion &Region, // argument in the overall function. if (Input->isSwiftError()) { assert( - !Group.SwiftErrorArgument.hasValue() && + !Group.SwiftErrorArgument && "Argument already marked with swifterr for this OutlinableGroup!"); Group.SwiftErrorArgument = TypeIndex; } @@ -969,12 +1031,11 @@ static bool outputHasNonPHI(Value *V, unsigned PHILoc, PHINode &PN, // We check to see if the value is used by the PHINode from some other // predecessor not included in the region. If it is, we make sure // to keep it as an output. - SmallVector<unsigned, 2> IncomingNumbers(PN.getNumIncomingValues()); - std::iota(IncomingNumbers.begin(), IncomingNumbers.end(), 0); - if (any_of(IncomingNumbers, [PHILoc, &PN, V, &BlocksInRegion](unsigned Idx) { - return (Idx != PHILoc && V == PN.getIncomingValue(Idx) && - !BlocksInRegion.contains(PN.getIncomingBlock(Idx))); - })) + if (any_of(llvm::seq<unsigned>(0, PN.getNumIncomingValues()), + [PHILoc, &PN, V, &BlocksInRegion](unsigned Idx) { + return (Idx != PHILoc && V == PN.getIncomingValue(Idx) && + !BlocksInRegion.contains(PN.getIncomingBlock(Idx))); + })) return true; // Check if the value is used by any other instructions outside the region. 
@@ -1098,30 +1159,72 @@ static hash_code encodePHINodeData(PHINodeData &PND) { /// /// \param Region - The region that \p PN is an output for. /// \param PN - The PHINode we are analyzing. +/// \param Blocks - The blocks for the region we are analyzing. /// \param AggArgIdx - The argument \p PN will be stored into. /// \returns An optional holding the assigned canonical number, or None if /// there is some attribute of the PHINode blocking it from being used. static Optional<unsigned> getGVNForPHINode(OutlinableRegion &Region, - PHINode *PN, unsigned AggArgIdx) { + PHINode *PN, + DenseSet<BasicBlock *> &Blocks, + unsigned AggArgIdx) { OutlinableGroup &Group = *Region.Parent; IRSimilarityCandidate &Cand = *Region.Candidate; BasicBlock *PHIBB = PN->getParent(); CanonList PHIGVNs; - for (Value *Incoming : PN->incoming_values()) { - // If we cannot find a GVN, this means that the input to the PHINode is - // not included in the region we are trying to analyze, meaning, that if - // it was outlined, we would be adding an extra input. We ignore this - // case for now, and so ignore the region. + Value *Incoming; + BasicBlock *IncomingBlock; + for (unsigned Idx = 0, EIdx = PN->getNumIncomingValues(); Idx < EIdx; Idx++) { + Incoming = PN->getIncomingValue(Idx); + IncomingBlock = PN->getIncomingBlock(Idx); + // If we cannot find a GVN, and the incoming block is included in the region + // this means that the input to the PHINode is not included in the region we + // are trying to analyze, meaning, that if it was outlined, we would be + // adding an extra input. We ignore this case for now, and so ignore the + // region. Optional<unsigned> OGVN = Cand.getGVN(Incoming); - if (!OGVN.hasValue()) { + if (!OGVN && Blocks.contains(IncomingBlock)) { Region.IgnoreRegion = true; return None; } + // If the incoming block isn't in the region, we don't have to worry about + // this incoming value. + if (!Blocks.contains(IncomingBlock)) + continue; + // Collect the canonical numbers of the values in the PHINode. - unsigned GVN = OGVN.getValue(); + unsigned GVN = *OGVN; OGVN = Cand.getCanonicalNum(GVN); - assert(OGVN.hasValue() && "No GVN found for incoming value?"); + assert(OGVN && "No GVN found for incoming value?"); + PHIGVNs.push_back(*OGVN); + + // Find the incoming block and use the canonical numbering as well to define + // the hash for the PHINode. + OGVN = Cand.getGVN(IncomingBlock); + + // If there is no number for the incoming block, it is becaause we have + // split the candidate basic blocks. So we use the previous block that it + // was split from to find the valid global value numbering for the PHINode. + if (!OGVN) { + assert(Cand.getStartBB() == IncomingBlock && + "Unknown basic block used in exit path PHINode."); + + BasicBlock *PrevBlock = nullptr; + // Iterate over the predecessors to the incoming block of the + // PHINode, when we find a block that is not contained in the region + // we know that this is the first block that we split from, and should + // have a valid global value numbering. 
+ for (BasicBlock *Pred : predecessors(IncomingBlock)) + if (!Blocks.contains(Pred)) { + PrevBlock = Pred; + break; + } + assert(PrevBlock && "Expected a predecessor not in the reigon!"); + OGVN = Cand.getGVN(PrevBlock); + } + GVN = *OGVN; + OGVN = Cand.getCanonicalNum(GVN); + assert(OGVN && "No GVN found for incoming block?"); PHIGVNs.push_back(*OGVN); } @@ -1131,11 +1234,10 @@ static Optional<unsigned> getGVNForPHINode(OutlinableRegion &Region, DenseMap<hash_code, unsigned>::iterator GVNToPHIIt; DenseMap<unsigned, PHINodeData>::iterator PHIToGVNIt; Optional<unsigned> BBGVN = Cand.getGVN(PHIBB); - assert(BBGVN.hasValue() && "Could not find GVN for the incoming block!"); + assert(BBGVN && "Could not find GVN for the incoming block!"); BBGVN = Cand.getCanonicalNum(BBGVN.getValue()); - assert(BBGVN.hasValue() && - "Could not find canonical number for the incoming block!"); + assert(BBGVN && "Could not find canonical number for the incoming block!"); // Create a pair of the exit block canonical value, and the aggregate // argument location, connected to the canonical numbers stored in the // PHINode. @@ -1262,9 +1364,9 @@ findExtractedOutputToOverallOutputMapping(OutlinableRegion &Region, // If two PHINodes have the same canonical values, but different aggregate // argument locations, then they will have distinct Canonical Values. - GVN = getGVNForPHINode(Region, PN, AggArgIdx); - if (!GVN.hasValue()) - return; + GVN = getGVNForPHINode(Region, PN, BlocksInRegion, AggArgIdx); + if (!GVN) + return; } else { // If we do not have a PHINode we use the global value numbering for the // output value, to find the canonical number to add to the set of stored @@ -1413,7 +1515,7 @@ CallInst *replaceCalledFunction(Module &M, OutlinableRegion &Region) { // Make sure that the argument in the new function has the SwiftError // argument. - if (Group.SwiftErrorArgument.hasValue()) + if (Group.SwiftErrorArgument) Call->addParamAttr(Group.SwiftErrorArgument.getValue(), Attribute::SwiftError); @@ -1520,17 +1622,18 @@ getPassedArgumentAndAdjustArgumentLocation(const Argument *A, /// \param OutputMappings [in] - The mapping of output values from outlined /// region to their original values. /// \param CanonNums [out] - The canonical numbering for the incoming values to -/// \p PN. +/// \p PN paired with their incoming block. /// \param ReplacedWithOutlinedCall - A flag to use the extracted function call /// of \p Region rather than the overall function's call. -static void -findCanonNumsForPHI(PHINode *PN, OutlinableRegion &Region, - const DenseMap<Value *, Value *> &OutputMappings, - DenseSet<unsigned> &CanonNums, - bool ReplacedWithOutlinedCall = true) { +static void findCanonNumsForPHI( + PHINode *PN, OutlinableRegion &Region, + const DenseMap<Value *, Value *> &OutputMappings, + SmallVector<std::pair<unsigned, BasicBlock *>> &CanonNums, + bool ReplacedWithOutlinedCall = true) { // Iterate over the incoming values. for (unsigned Idx = 0, EIdx = PN->getNumIncomingValues(); Idx < EIdx; Idx++) { Value *IVal = PN->getIncomingValue(Idx); + BasicBlock *IBlock = PN->getIncomingBlock(Idx); // If we have an argument as incoming value, we need to grab the passed // value from the call itself. if (Argument *A = dyn_cast<Argument>(IVal)) { @@ -1545,10 +1648,10 @@ findCanonNumsForPHI(PHINode *PN, OutlinableRegion &Region, // Find and add the canonical number for the incoming value. 
Optional<unsigned> GVN = Region.Candidate->getGVN(IVal); - assert(GVN.hasValue() && "No GVN for incoming value"); + assert(GVN && "No GVN for incoming value"); Optional<unsigned> CanonNum = Region.Candidate->getCanonicalNum(*GVN); - assert(CanonNum.hasValue() && "No Canonical Number for GVN"); - CanonNums.insert(*CanonNum); + assert(CanonNum && "No Canonical Number for GVN"); + CanonNums.push_back(std::make_pair(*CanonNum, IBlock)); } } @@ -1557,19 +1660,26 @@ findCanonNumsForPHI(PHINode *PN, OutlinableRegion &Region, /// function. /// /// \param PN [in] - The PHINode that we are finding the canonical numbers for. -/// \param Region [in] - The OutlinableRegion containing \p PN. +/// \param Region [in] - The OutlinableRegion containing \p PN. /// \param OverallPhiBlock [in] - The overall PHIBlock we are trying to find /// \p PN in. /// \param OutputMappings [in] - The mapping of output values from outlined /// region to their original values. +/// \param UsedPHIs [in, out] - The PHINodes in the block that have already been +/// matched. /// \return the newly found or created PHINode in \p OverallPhiBlock. static PHINode* findOrCreatePHIInBlock(PHINode &PN, OutlinableRegion &Region, BasicBlock *OverallPhiBlock, - const DenseMap<Value *, Value *> &OutputMappings) { + const DenseMap<Value *, Value *> &OutputMappings, + DenseSet<PHINode *> &UsedPHIs) { OutlinableGroup &Group = *Region.Parent; - DenseSet<unsigned> PNCanonNums; + + // A list of the canonical numbering assigned to each incoming value, paired + // with the incoming block for the PHINode passed into this function. + SmallVector<std::pair<unsigned, BasicBlock *>> PNCanonNums; + // We have to use the extracted function since we have merged this region into // the overall function yet. We make sure to reassign the argument numbering // since it is possible that the argument ordering is different between the @@ -1578,18 +1688,61 @@ findOrCreatePHIInBlock(PHINode &PN, OutlinableRegion &Region, /* ReplacedWithOutlinedCall = */ false); OutlinableRegion *FirstRegion = Group.Regions[0]; - DenseSet<unsigned> CurrentCanonNums; + + // A list of the canonical numbering assigned to each incoming value, paired + // with the incoming block for the PHINode that we are currently comparing + // the passed PHINode to. + SmallVector<std::pair<unsigned, BasicBlock *>> CurrentCanonNums; + // Find the Canonical Numbering for each PHINode, if it matches, we replace // the uses of the PHINode we are searching for, with the found PHINode. for (PHINode &CurrPN : OverallPhiBlock->phis()) { + // If this PHINode has already been matched to another PHINode to be merged, + // we skip it. + if (UsedPHIs.contains(&CurrPN)) + continue; + CurrentCanonNums.clear(); findCanonNumsForPHI(&CurrPN, *FirstRegion, OutputMappings, CurrentCanonNums, /* ReplacedWithOutlinedCall = */ true); - if (all_of(PNCanonNums, [&CurrentCanonNums](unsigned CanonNum) { - return CurrentCanonNums.contains(CanonNum); - })) + // If the list of incoming values is not the same length, then they cannot + // match since there is not an analogue for each incoming value. + if (PNCanonNums.size() != CurrentCanonNums.size()) + continue; + + bool FoundMatch = true; + + // We compare the canonical value for each incoming value in the passed + // in PHINode to one already present in the outlined region. If the + // incoming values do not match, then the PHINodes do not match. 
+ + // We also check to make sure that the incoming block matches as well by + // finding the corresponding incoming block in the combined outlined region + // for the current outlined region. + for (unsigned Idx = 0, Edx = PNCanonNums.size(); Idx < Edx; ++Idx) { + std::pair<unsigned, BasicBlock *> ToCompareTo = CurrentCanonNums[Idx]; + std::pair<unsigned, BasicBlock *> ToAdd = PNCanonNums[Idx]; + if (ToCompareTo.first != ToAdd.first) { + FoundMatch = false; + break; + } + + BasicBlock *CorrespondingBlock = + Region.findCorrespondingBlockIn(*FirstRegion, ToAdd.second); + assert(CorrespondingBlock && "Found block is nullptr"); + if (CorrespondingBlock != ToCompareTo.second) { + FoundMatch = false; + break; + } + } + + // If all incoming values and branches matched, then we can merge + // into the found PHINode. + if (FoundMatch) { + UsedPHIs.insert(&CurrPN); return &CurrPN; + } } // If we've made it here, it means we weren't able to replace the PHINode, so @@ -1603,12 +1756,8 @@ findOrCreatePHIInBlock(PHINode &PN, OutlinableRegion &Region, // Find corresponding basic block in the overall function for the incoming // block. - Instruction *FirstNonPHI = IncomingBlock->getFirstNonPHI(); - assert(FirstNonPHI && "Incoming block is empty?"); - Value *CorrespondingVal = - Region.findCorrespondingValueIn(*FirstRegion, FirstNonPHI); - assert(CorrespondingVal && "Value is nullptr?"); - BasicBlock *BlockToUse = cast<Instruction>(CorrespondingVal)->getParent(); + BasicBlock *BlockToUse = + Region.findCorrespondingBlockIn(*FirstRegion, IncomingBlock); NewPN->setIncomingBlock(Idx, BlockToUse); // If we have an argument we make sure we replace using the argument from @@ -1623,6 +1772,10 @@ findOrCreatePHIInBlock(PHINode &PN, OutlinableRegion &Region, IncomingVal = findOutputMapping(OutputMappings, IncomingVal); Value *Val = Region.findCorrespondingValueIn(*FirstRegion, IncomingVal); assert(Val && "Value is nullptr?"); + DenseMap<Value *, Value *>::iterator RemappedIt = + FirstRegion->RemappedArguments.find(Val); + if (RemappedIt != FirstRegion->RemappedArguments.end()) + Val = RemappedIt->second; NewPN->setIncomingValue(Idx, Val); } return NewPN; @@ -1649,6 +1802,7 @@ replaceArgumentUses(OutlinableRegion &Region, if (FirstFunction) DominatingFunction = Group.OutlinedFunction; DominatorTree DT(*DominatingFunction); + DenseSet<PHINode *> UsedPHIs; for (unsigned ArgIdx = 0; ArgIdx < Region.ExtractedFunction->arg_size(); ArgIdx++) { @@ -1665,6 +1819,8 @@ replaceArgumentUses(OutlinableRegion &Region, << *Region.ExtractedFunction << " with " << *AggArg << " in function " << *Group.OutlinedFunction << "\n"); Arg->replaceAllUsesWith(AggArg); + Value *V = Region.Call->getArgOperand(ArgIdx); + Region.RemappedArguments.insert(std::make_pair(V, AggArg)); continue; } @@ -1713,7 +1869,7 @@ replaceArgumentUses(OutlinableRegion &Region, // If this is storing a PHINode, we must make sure it is included in the // overall function. if (!isa<PHINode>(ValueOperand) || - Region.Candidate->getGVN(ValueOperand).hasValue()) { + Region.Candidate->getGVN(ValueOperand).has_value()) { if (FirstFunction) continue; Value *CorrVal = @@ -1725,7 +1881,7 @@ replaceArgumentUses(OutlinableRegion &Region, PHINode *PN = cast<PHINode>(SI->getValueOperand()); // If it has a value, it was not split by the code extractor, which // is what we are looking for. 
- if (Region.Candidate->getGVN(PN).hasValue()) + if (Region.Candidate->getGVN(PN)) continue; // We record the parent block for the PHINode in the Region so that @@ -1748,8 +1904,8 @@ replaceArgumentUses(OutlinableRegion &Region, // For our PHINode, we find the combined canonical numbering, and // attempt to find a matching PHINode in the overall PHIBlock. If we // cannot, we copy the PHINode and move it into this new block. - PHINode *NewPN = - findOrCreatePHIInBlock(*PN, Region, OverallPhiBlock, OutputMappings); + PHINode *NewPN = findOrCreatePHIInBlock(*PN, Region, OverallPhiBlock, + OutputMappings, UsedPHIs); NewI->setOperand(0, NewPN); } @@ -1923,7 +2079,7 @@ static void alignOutputBlockWithAggFunc( // If there is, we remove the new output blocks. If it does not, // we add it to our list of sets of output blocks. - if (MatchingBB.hasValue()) { + if (MatchingBB) { LLVM_DEBUG(dbgs() << "Set output block for region in function" << Region.ExtractedFunction << " to " << MatchingBB.getValue()); @@ -2279,6 +2435,9 @@ void IROutliner::pruneIncompatibleRegions( if (BBHasAddressTaken) continue; + if (IRSC.getFunction()->hasOptNone()) + continue; + if (IRSC.front()->Inst->getFunction()->hasLinkOnceODRLinkage() && !OutlineFromLinkODRs) continue; @@ -2343,9 +2502,9 @@ static Value *findOutputValueInRegion(OutlinableRegion &Region, OutputCanon = *It->second.second.begin(); } Optional<unsigned> OGVN = Region.Candidate->fromCanonicalNum(OutputCanon); - assert(OGVN.hasValue() && "Could not find GVN for Canonical Number?"); + assert(OGVN && "Could not find GVN for Canonical Number?"); Optional<Value *> OV = Region.Candidate->fromGVN(*OGVN); - assert(OV.hasValue() && "Could not find value for GVN?"); + assert(OV && "Could not find value for GVN?"); return *OV; } @@ -2400,11 +2559,8 @@ static InstructionCost findCostForOutputBlocks(Module &M, for (Value *V : ID.OperVals) { BasicBlock *BB = static_cast<BasicBlock *>(V); - DenseSet<BasicBlock *>::iterator CBIt = CandidateBlocks.find(BB); - if (CBIt != CandidateBlocks.end() || FoundBlocks.contains(BB)) - continue; - FoundBlocks.insert(BB); - NumOutputBranches++; + if (!CandidateBlocks.contains(BB) && FoundBlocks.insert(BB).second) + NumOutputBranches++; } } @@ -2520,7 +2676,7 @@ void IROutliner::updateOutputMapping(OutlinableRegion &Region, // If we found an output register, place a mapping of the new value // to the original in the mapping. 
- if (!OutputIdx.hasValue()) + if (!OutputIdx) return; if (OutputMappings.find(Outputs[OutputIdx.getValue()]) == @@ -2680,7 +2836,7 @@ unsigned IROutliner::doOutline(Module &M) { OS->Candidate->getBasicBlocks(BlocksInRegion, BE); OS->CE = new (ExtractorAllocator.Allocate()) CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false, - false, "outlined"); + false, nullptr, "outlined"); findAddInputsOutputs(M, *OS, NotSame); if (!OS->IgnoreRegion) OutlinedRegions.push_back(OS); @@ -2791,7 +2947,7 @@ unsigned IROutliner::doOutline(Module &M) { OS->Candidate->getBasicBlocks(BlocksInRegion, BE); OS->CE = new (ExtractorAllocator.Allocate()) CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false, - false, "outlined"); + false, nullptr, "outlined"); bool FunctionOutlined = extractSection(*OS); if (FunctionOutlined) { unsigned StartIdx = OS->Candidate->getStartIdx(); @@ -2874,7 +3030,7 @@ bool IROutlinerLegacyPass::runOnModule(Module &M) { std::unique_ptr<OptimizationRemarkEmitter> ORE; auto GORE = [&ORE](Function &F) -> OptimizationRemarkEmitter & { ORE.reset(new OptimizationRemarkEmitter(&F)); - return *ORE.get(); + return *ORE; }; auto GTTI = [this](Function &F) -> TargetTransformInfo & { @@ -2905,7 +3061,7 @@ PreservedAnalyses IROutlinerPass::run(Module &M, ModuleAnalysisManager &AM) { std::function<OptimizationRemarkEmitter &(Function &)> GORE = [&ORE](Function &F) -> OptimizationRemarkEmitter & { ORE.reset(new OptimizationRemarkEmitter(&F)); - return *ORE.get(); + return *ORE; }; if (IROutliner(GTTI, GIRSI, GORE).run(M)) diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp index c32e09875a12..76f8f1a7a482 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp @@ -9,11 +9,8 @@ #include "llvm/Transforms/IPO/InferFunctionAttrs.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Function.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -32,7 +29,7 @@ static bool inferAllPrototypeAttributes( // explicitly visited by CGSCC passes in the new pass manager.) 
if (F.isDeclaration() && !F.hasOptNone()) { if (!F.hasFnAttribute(Attribute::NoBuiltin)) - Changed |= inferLibFuncAttributes(F, GetTLI(F)); + Changed |= inferNonMandatoryLibFuncAttrs(F, GetTLI(F)); Changed |= inferAttributesFromOthers(F); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/InlineSimple.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/InlineSimple.cpp index 76f1d0c54d08..2143e39d488d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/InlineSimple.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/InlineSimple.cpp @@ -12,14 +12,8 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InlineCost.h" -#include "llvm/Analysis/ProfileSummaryInfo.h" -#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/IR/CallingConv.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" #include "llvm/InitializePasses.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/Inliner.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp index 10abea7ebd32..4d32266eb9ea 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp @@ -14,8 +14,8 @@ #include "llvm/Transforms/IPO/Inliner.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" +#include "llvm/ADT/PriorityWorklist.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SetVector.h" @@ -29,7 +29,6 @@ #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CGSCCPassManager.h" #include "llvm/Analysis/CallGraph.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InlineAdvisor.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/InlineOrder.h" @@ -38,11 +37,9 @@ #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ReplayInlineAdvisor.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/Utils/ImportedFunctionsInliningStatistics.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/DiagnosticInfo.h" @@ -68,8 +65,6 @@ #include <algorithm> #include <cassert> #include <functional> -#include <sstream> -#include <tuple> #include <utility> #include <vector> @@ -110,6 +105,11 @@ static cl::opt<int> IntraSCCCostMultiplier( static cl::opt<bool> KeepAdvisorForPrinting("keep-inline-advisor-for-printing", cl::init(false), cl::Hidden); +/// Allows printing the contents of the advisor after each SCC inliner pass. 
+static cl::opt<bool> + EnablePostSCCAdvisorPrinting("enable-scc-inline-advisor-printing", + cl::init(false), cl::Hidden); + extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats; static cl::opt<std::string> CGSCCInlineReplayFile( @@ -163,10 +163,6 @@ static cl::opt<CallSiteFormat::Format> CGSCCInlineReplayFormat( "<Line Number>:<Column Number>.<Discriminator> (default)")), cl::desc("How cgscc inline replay file is formatted"), cl::Hidden); -static cl::opt<bool> InlineEnablePriorityOrder( - "inline-enable-priority-order", cl::Hidden, cl::init(false), - cl::desc("Enable the priority inline order for the inliner")); - LegacyInlinerBase::LegacyInlinerBase(char &ID) : CallGraphSCCPass(ID) {} LegacyInlinerBase::LegacyInlinerBase(char &ID, bool InsertLifetime) @@ -721,8 +717,9 @@ InlinerPass::getAdvisor(const ModuleAnalysisManagerCGSCCProxy::Result &MAM, // duration of the inliner pass, and thus the lifetime of the owned advisor. // The one we would get from the MAM can be invalidated as a result of the // inliner's activity. - OwnedAdvisor = - std::make_unique<DefaultInlineAdvisor>(M, FAM, getInlineParams()); + OwnedAdvisor = std::make_unique<DefaultInlineAdvisor>( + M, FAM, getInlineParams(), + InlineContext{LTOPhase, InlinePass::CGSCCInliner}); if (!CGSCCInlineReplayFile.empty()) OwnedAdvisor = getReplayInlineAdvisor( @@ -731,7 +728,9 @@ InlinerPass::getAdvisor(const ModuleAnalysisManagerCGSCCProxy::Result &MAM, CGSCCInlineReplayScope, CGSCCInlineReplayFallback, {CGSCCInlineReplayFormat}}, - /*EmitRemarks=*/true); + /*EmitRemarks=*/true, + InlineContext{LTOPhase, + InlinePass::ReplayCGSCCInliner}); return *OwnedAdvisor; } @@ -757,7 +756,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, .getManager(); InlineAdvisor &Advisor = getAdvisor(MAMProxy, FAM, M); - Advisor.onPassEntry(); + Advisor.onPassEntry(&InitialC); auto AdvisorOnExit = make_scope_exit([&] { Advisor.onPassExit(&InitialC); }); @@ -786,12 +785,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // this model, but it is uniformly spread across all the functions in the SCC // and eventually they all become too large to inline, rather than // incrementally maknig a single function grow in a super linear fashion. - std::unique_ptr<InlineOrder<std::pair<CallBase *, int>>> Calls; - if (InlineEnablePriorityOrder) - Calls = std::make_unique<PriorityInlineOrder<InlineSizePriority>>(); - else - Calls = std::make_unique<DefaultInlineOrder<std::pair<CallBase *, int>>>(); - assert(Calls != nullptr && "Expected an initialized InlineOrder"); + DefaultInlineOrder<std::pair<CallBase *, int>> Calls; // Populate the initial list of calls in this SCC. for (auto &N : InitialC) { @@ -806,7 +800,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, if (auto *CB = dyn_cast<CallBase>(&I)) if (Function *Callee = CB->getCalledFunction()) { if (!Callee->isDeclaration()) - Calls->push({CB, -1}); + Calls.push({CB, -1}); else if (!isa<IntrinsicInst>(I)) { using namespace ore; setInlineRemark(*CB, "unavailable definition"); @@ -820,7 +814,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, } } } - if (Calls->empty()) + if (Calls.empty()) return PreservedAnalyses::all(); // Capture updatable variable for the current SCC. @@ -846,15 +840,15 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, SmallVector<Function *, 4> DeadFunctionsInComdats; // Loop forward over all of the calls. 
- while (!Calls->empty()) { + while (!Calls.empty()) { // We expect the calls to typically be batched with sequences of calls that // have the same caller, so we first set up some shared infrastructure for // this caller. We also do any pruning we can at this layer on the caller // alone. - Function &F = *Calls->front().first->getCaller(); + Function &F = *Calls.front().first->getCaller(); LazyCallGraph::Node &N = *CG.lookup(F); if (CG.lookupSCC(N) != C) { - Calls->pop(); + Calls.pop(); continue; } @@ -870,8 +864,8 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // We bail out as soon as the caller has to change so we can update the // call graph and prepare the context of that new caller. bool DidInline = false; - while (!Calls->empty() && Calls->front().first->getCaller() == &F) { - auto P = Calls->pop(); + while (!Calls.empty() && Calls.front().first->getCaller() == &F) { + auto P = Calls.pop(); CallBase *CB = P.first; const int InlineHistoryID = P.second; Function &Callee = *CB->getCalledFunction(); @@ -913,7 +907,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, int CBCostMult = getStringFnAttrAsInt( *CB, InlineConstants::FunctionInlineCostMultiplierAttributeName) - .getValueOr(1); + .value_or(1); // Setup the data structure used to plumb customization into the // `InlineFunction` routine. @@ -955,7 +949,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, } if (NewCallee) { if (!NewCallee->isDeclaration()) { - Calls->push({ICB, NewHistoryID}); + Calls.push({ICB, NewHistoryID}); // Continually inlining through an SCC can result in huge compile // times and bloated code since we arbitrarily stop at some point // when the inliner decides it's not profitable to inline anymore. @@ -990,7 +984,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, if (Callee.isDiscardableIfUnused() && Callee.hasZeroLiveUses() && !CG.isLibFunction(Callee)) { if (Callee.hasLocalLinkage() || !Callee.hasComdat()) { - Calls->erase_if([&](const std::pair<CallBase *, int> &Call) { + Calls.erase_if([&](const std::pair<CallBase *, int> &Call) { return Call.first->getCaller() == &Callee; }); // Clear the body and queue the function itself for deletion when we @@ -1120,17 +1114,24 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, ModuleInlinerWrapperPass::ModuleInlinerWrapperPass(InlineParams Params, bool MandatoryFirst, + InlineContext IC, InliningAdvisorMode Mode, unsigned MaxDevirtIterations) - : Params(Params), Mode(Mode), MaxDevirtIterations(MaxDevirtIterations) { + : Params(Params), IC(IC), Mode(Mode), + MaxDevirtIterations(MaxDevirtIterations) { // Run the inliner first. The theory is that we are walking bottom-up and so // the callees have already been fully optimized, and we want to inline them // into the callers so that our optimizations can reflect that. // For PreLinkThinLTO pass, we disable hot-caller heuristic for sample PGO // because it makes profile annotation in the backend inaccurate. 
- if (MandatoryFirst) + if (MandatoryFirst) { PM.addPass(InlinerPass(/*OnlyMandatory*/ true)); + if (EnablePostSCCAdvisorPrinting) + PM.addPass(InlineAdvisorAnalysisPrinterPass(dbgs())); + } PM.addPass(InlinerPass()); + if (EnablePostSCCAdvisorPrinting) + PM.addPass(InlineAdvisorAnalysisPrinterPass(dbgs())); } PreservedAnalyses ModuleInlinerWrapperPass::run(Module &M, @@ -1140,7 +1141,8 @@ PreservedAnalyses ModuleInlinerWrapperPass::run(Module &M, {CGSCCInlineReplayFile, CGSCCInlineReplayScope, CGSCCInlineReplayFallback, - {CGSCCInlineReplayFormat}})) { + {CGSCCInlineReplayFormat}}, + IC)) { M.getContext().emitError( "Could not setup Inlining Advisor for the requested " "mode and/or options"); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/Internalize.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/Internalize.cpp index 692e445cb7cb..5aa5b905f06c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/Internalize.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/Internalize.cpp @@ -19,7 +19,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/Internalize.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/Triple.h" @@ -33,8 +32,6 @@ #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" -#include "llvm/Transforms/Utils/GlobalStatus.h" -#include "llvm/Transforms/Utils/ModuleUtils.h" using namespace llvm; #define DEBUG_TYPE "internalize" diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/LoopExtractor.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/LoopExtractor.cpp index d9a59dd35fde..ad1927c09803 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/LoopExtractor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/LoopExtractor.cpp @@ -23,14 +23,9 @@ #include "llvm/IR/PassManager.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/CodeExtractor.h" -#include <fstream> -#include <set> using namespace llvm; #define DEBUG_TYPE "loop-extract" diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp index 8e83d7bcb6c2..d5f1d291f41f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp @@ -1223,6 +1223,7 @@ void LowerTypeTestsModule::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) { static const unsigned kX86JumpTableEntrySize = 8; static const unsigned kARMJumpTableEntrySize = 4; static const unsigned kARMBTIJumpTableEntrySize = 8; +static const unsigned kRISCVJumpTableEntrySize = 8; unsigned LowerTypeTestsModule::getJumpTableEntrySize() { switch (Arch) { @@ -1238,6 +1239,9 @@ unsigned LowerTypeTestsModule::getJumpTableEntrySize() { if (BTE->getZExtValue()) return kARMBTIJumpTableEntrySize; return kARMJumpTableEntrySize; + case Triple::riscv32: + case Triple::riscv64: + return kRISCVJumpTableEntrySize; default: report_fatal_error("Unsupported architecture for jump tables"); } @@ -1265,6 +1269,9 @@ void LowerTypeTestsModule::createJumpTableEntry( AsmOS << "b $" << ArgIndex << "\n"; } else if (JumpTableArch == Triple::thumb) { AsmOS << "b.w $" << ArgIndex << "\n"; + } 
else if (JumpTableArch == Triple::riscv32 || + JumpTableArch == Triple::riscv64) { + AsmOS << "tail $" << ArgIndex << "@plt\n"; } else { report_fatal_error("Unsupported architecture for jump tables"); } @@ -1282,7 +1289,8 @@ Type *LowerTypeTestsModule::getJumpTableEntryType() { void LowerTypeTestsModule::buildBitSetsFromFunctions( ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) { if (Arch == Triple::x86 || Arch == Triple::x86_64 || Arch == Triple::arm || - Arch == Triple::thumb || Arch == Triple::aarch64) + Arch == Triple::thumb || Arch == Triple::aarch64 || + Arch == Triple::riscv32 || Arch == Triple::riscv64) buildBitSetsFromFunctionsNative(TypeIds, Functions); else if (Arch == Triple::wasm32 || Arch == Triple::wasm64) buildBitSetsFromFunctionsWASM(TypeIds, Functions); @@ -1427,6 +1435,11 @@ void LowerTypeTestsModule::createJumpTable( F->addFnAttr("branch-target-enforcement", "false"); F->addFnAttr("sign-return-address", "none"); } + if (JumpTableArch == Triple::riscv32 || JumpTableArch == Triple::riscv64) { + // Make sure the jump table assembly is not modified by the assembler or + // the linker. + F->addFnAttr("target-features", "-c,-relax"); + } // Make sure we don't emit .eh_frame for this function. F->addFnAttr(Attribute::NoUnwind); @@ -2187,11 +2200,7 @@ bool LowerTypeTestsModule::lower() { } Sets.emplace_back(I, MaxUniqueId); } - llvm::sort(Sets, - [](const std::pair<GlobalClassesTy::iterator, unsigned> &S1, - const std::pair<GlobalClassesTy::iterator, unsigned> &S2) { - return S1.second < S2.second; - }); + llvm::sort(Sets, llvm::less_second()); // For each disjoint set we found... for (const auto &S : Sets) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp index 97ef872c5499..b850591b4aa6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp @@ -88,12 +88,11 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/IPO/MergeFunctions.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Argument.h" -#include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -113,7 +112,6 @@ #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" -#include "llvm/IR/ValueMap.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" @@ -121,8 +119,8 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" -#include "llvm/Transforms/IPO/MergeFunctions.h" #include "llvm/Transforms/Utils/FunctionComparator.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" #include <algorithm> #include <cassert> #include <iterator> @@ -139,10 +137,10 @@ STATISTIC(NumThunksWritten, "Number of thunks generated"); STATISTIC(NumAliasesWritten, "Number of aliases generated"); STATISTIC(NumDoubleWeak, "Number of new functions created"); -static cl::opt<unsigned> NumFunctionsForSanityCheck( - "mergefunc-sanity", - cl::desc("How many functions in module could be used for " - "MergeFunctions pass sanity check. 
" +static cl::opt<unsigned> NumFunctionsForVerificationCheck( + "mergefunc-verify", + cl::desc("How many functions in a module could be used for " + "MergeFunctions to pass a basic correctness check. " "'0' disables this check. Works only with '-debug' key."), cl::init(0), cl::Hidden); @@ -228,10 +226,13 @@ private: /// analyzed again. std::vector<WeakTrackingVH> Deferred; + /// Set of values marked as used in llvm.used and llvm.compiler.used. + SmallPtrSet<GlobalValue *, 4> Used; + #ifndef NDEBUG /// Checks the rules of order relation introduced among functions set. - /// Returns true, if sanity check has been passed, and false if failed. - bool doSanityCheck(std::vector<WeakTrackingVH> &Worklist); + /// Returns true, if check has been passed, and false if failed. + bool doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist); #endif /// Insert a ComparableFunction into the FnTree, or merge it away if it's @@ -330,12 +331,12 @@ PreservedAnalyses MergeFunctionsPass::run(Module &M, } #ifndef NDEBUG -bool MergeFunctions::doSanityCheck(std::vector<WeakTrackingVH> &Worklist) { - if (const unsigned Max = NumFunctionsForSanityCheck) { +bool MergeFunctions::doFunctionalCheck(std::vector<WeakTrackingVH> &Worklist) { + if (const unsigned Max = NumFunctionsForVerificationCheck) { unsigned TripleNumber = 0; bool Valid = true; - dbgs() << "MERGEFUNC-SANITY: Started for first " << Max << " functions.\n"; + dbgs() << "MERGEFUNC-VERIFY: Started for first " << Max << " functions.\n"; unsigned i = 0; for (std::vector<WeakTrackingVH>::iterator I = Worklist.begin(), @@ -351,7 +352,7 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakTrackingVH> &Worklist) { // If F1 <= F2, then F2 >= F1, otherwise report failure. if (Res1 != -Res2) { - dbgs() << "MERGEFUNC-SANITY: Non-symmetric; triple: " << TripleNumber + dbgs() << "MERGEFUNC-VERIFY: Non-symmetric; triple: " << TripleNumber << "\n"; dbgs() << *F1 << '\n' << *F2 << '\n'; Valid = false; @@ -384,7 +385,7 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakTrackingVH> &Worklist) { } if (!Transitive) { - dbgs() << "MERGEFUNC-SANITY: Non-transitive; triple: " + dbgs() << "MERGEFUNC-VERIFY: Non-transitive; triple: " << TripleNumber << "\n"; dbgs() << "Res1, Res3, Res4: " << Res1 << ", " << Res3 << ", " << Res4 << "\n"; @@ -395,7 +396,7 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakTrackingVH> &Worklist) { } } - dbgs() << "MERGEFUNC-SANITY: " << (Valid ? "Passed." : "Failed.") << "\n"; + dbgs() << "MERGEFUNC-VERIFY: " << (Valid ? "Passed." : "Failed.") << "\n"; return Valid; } return true; @@ -410,6 +411,11 @@ static bool isEligibleForMerging(Function &F) { bool MergeFunctions::runOnModule(Module &M) { bool Changed = false; + SmallVector<GlobalValue *, 4> UsedV; + collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/false); + collectUsedGlobalVariables(M, UsedV, /*CompilerUsed=*/true); + Used.insert(UsedV.begin(), UsedV.end()); + // All functions in the module, ordered by hash. Functions with a unique // hash value are easily eliminated. 
std::vector<std::pair<FunctionComparator::FunctionHash, Function *>> @@ -436,7 +442,7 @@ bool MergeFunctions::runOnModule(Module &M) { std::vector<WeakTrackingVH> Worklist; Deferred.swap(Worklist); - LLVM_DEBUG(doSanityCheck(Worklist)); + LLVM_DEBUG(doFunctionalCheck(Worklist)); LLVM_DEBUG(dbgs() << "size of module: " << M.size() << '\n'); LLVM_DEBUG(dbgs() << "size of worklist: " << Worklist.size() << '\n'); @@ -456,6 +462,7 @@ bool MergeFunctions::runOnModule(Module &M) { FnTree.clear(); FNodesInTree.clear(); GlobalNumbers.clear(); + Used.clear(); return Changed; } @@ -484,7 +491,7 @@ static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) { if (SrcTy->isStructTy()) { assert(DestTy->isStructTy()); assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements()); - Value *Result = UndefValue::get(DestTy); + Value *Result = PoisonValue::get(DestTy); for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) { Value *Element = createCast( Builder, Builder.CreateExtractValue(V, makeArrayRef(I)), @@ -828,7 +835,10 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { // For better debugability, under MergeFunctionsPDI, we do not modify G's // call sites to point to F even when within the same translation unit. if (!G->isInterposable() && !MergeFunctionsPDI) { - if (G->hasGlobalUnnamedAddr()) { + // Functions referred to by llvm.used/llvm.compiler.used are special: + // there are uses of the symbol name that are not visible to LLVM, + // usually from inline asm. + if (G->hasGlobalUnnamedAddr() && !Used.contains(G)) { // G might have been a key in our GlobalNumberState, and it's illegal // to replace a key in ValueMap<GlobalValue *> with a non-global. GlobalNumbers.erase(G); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ModuleInliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ModuleInliner.cpp index d515303e4911..143715006512 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ModuleInliner.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ModuleInliner.cpp @@ -14,43 +14,33 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/ModuleInliner.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfo.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InlineAdvisor.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/InlineOrder.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/ReplayInlineAdvisor.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/IR/DebugLoc.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" -#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" -#include "llvm/IR/User.h" -#include "llvm/IR/Value.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/Transforms/Utils/Cloning.h" -#include 
"llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/ModuleUtils.h" #include <cassert> -#include <functional> using namespace llvm; @@ -94,7 +84,9 @@ InlineAdvisor &ModuleInlinerPass::getAdvisor(const ModuleAnalysisManager &MAM, // inliner pass, and thus the lifetime of the owned advisor. The one we // would get from the MAM can be invalidated as a result of the inliner's // activity. - OwnedAdvisor = std::make_unique<DefaultInlineAdvisor>(M, FAM, Params); + OwnedAdvisor = std::make_unique<DefaultInlineAdvisor>( + M, FAM, Params, + InlineContext{LTOPhase, InlinePass::ModuleInliner}); return *OwnedAdvisor; } @@ -119,7 +111,9 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M, LLVM_DEBUG(dbgs() << "---- Module Inliner is Running ---- \n"); auto &IAA = MAM.getResult<InlineAdvisorAnalysis>(M); - if (!IAA.tryCreate(Params, Mode, {})) { + if (!IAA.tryCreate( + Params, Mode, {}, + InlineContext{LTOPhase, InlinePass::ModuleInliner})) { M.getContext().emitError( "Could not setup Inlining Advisor for the requested " "mode and/or options"); @@ -153,7 +147,8 @@ PreservedAnalyses ModuleInlinerPass::run(Module &M, // the SCC inliner, which need some refactoring. std::unique_ptr<InlineOrder<std::pair<CallBase *, int>>> Calls; if (InlineEnablePriorityOrder) - Calls = std::make_unique<PriorityInlineOrder<InlineSizePriority>>(); + Calls = std::make_unique<PriorityInlineOrder>( + std::make_unique<SizePriority>()); else Calls = std::make_unique<DefaultInlineOrder<std::pair<CallBase *, int>>>(); assert(Calls != nullptr && "Expected an initialized InlineOrder"); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index 7205ae178d21..227ad8501f25 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -49,7 +49,6 @@ #include "llvm/Transforms/IPO/Attributor.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/CallGraphUpdater.h" -#include "llvm/Transforms/Utils/CodeExtractor.h" #include <algorithm> @@ -59,17 +58,16 @@ using namespace omp; #define DEBUG_TYPE "openmp-opt" static cl::opt<bool> DisableOpenMPOptimizations( - "openmp-opt-disable", cl::ZeroOrMore, - cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, - cl::init(false)); + "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."), + cl::Hidden, cl::init(false)); static cl::opt<bool> EnableParallelRegionMerging( - "openmp-opt-enable-merging", cl::ZeroOrMore, + "openmp-opt-enable-merging", cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, cl::init(false)); static cl::opt<bool> - DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore, + DisableInternalization("openmp-opt-disable-internalization", cl::desc("Disable function internalization."), cl::Hidden, cl::init(false)); @@ -85,42 +83,47 @@ static cl::opt<bool> HideMemoryTransferLatency( cl::Hidden, cl::init(false)); static cl::opt<bool> DisableOpenMPOptDeglobalization( - "openmp-opt-disable-deglobalization", cl::ZeroOrMore, + "openmp-opt-disable-deglobalization", cl::desc("Disable OpenMP optimizations involving deglobalization."), cl::Hidden, cl::init(false)); static cl::opt<bool> DisableOpenMPOptSPMDization( - "openmp-opt-disable-spmdization", cl::ZeroOrMore, + "openmp-opt-disable-spmdization", cl::desc("Disable OpenMP optimizations involving SPMD-ization."), cl::Hidden, cl::init(false)); static cl::opt<bool> 
DisableOpenMPOptFolding( - "openmp-opt-disable-folding", cl::ZeroOrMore, + "openmp-opt-disable-folding", cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, cl::init(false)); static cl::opt<bool> DisableOpenMPOptStateMachineRewrite( - "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore, + "openmp-opt-disable-state-machine-rewrite", cl::desc("Disable OpenMP optimizations that replace the state machine."), cl::Hidden, cl::init(false)); static cl::opt<bool> DisableOpenMPOptBarrierElimination( - "openmp-opt-disable-barrier-elimination", cl::ZeroOrMore, + "openmp-opt-disable-barrier-elimination", cl::desc("Disable OpenMP optimizations that eliminate barriers."), cl::Hidden, cl::init(false)); static cl::opt<bool> PrintModuleAfterOptimizations( - "openmp-opt-print-module", cl::ZeroOrMore, + "openmp-opt-print-module-after", cl::desc("Print the current module after OpenMP optimizations."), cl::Hidden, cl::init(false)); +static cl::opt<bool> PrintModuleBeforeOptimizations( + "openmp-opt-print-module-before", + cl::desc("Print the current module before OpenMP optimizations."), + cl::Hidden, cl::init(false)); + static cl::opt<bool> AlwaysInlineDeviceFunctions( - "openmp-opt-inline-device", cl::ZeroOrMore, + "openmp-opt-inline-device", cl::desc("Inline all applicible functions on the device."), cl::Hidden, cl::init(false)); static cl::opt<bool> - EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::ZeroOrMore, + EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::desc("Enables more verbose remarks."), cl::Hidden, cl::init(false)); @@ -129,6 +132,11 @@ static cl::opt<unsigned> cl::desc("Maximal number of attributor iterations."), cl::init(256)); +static cl::opt<unsigned> + SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden, + cl::desc("Maximum amount of shared memory to use."), + cl::init(std::numeric_limits<unsigned>::max())); + STATISTIC(NumOpenMPRuntimeCallsDeduplicated, "Number of OpenMP runtime calls deduplicated"); STATISTIC(NumOpenMPParallelRegionsDeleted, @@ -493,11 +501,14 @@ struct OMPInformationCache : public InformationCache { // Remove the `noinline` attribute from `__kmpc`, `_OMP::` and `omp_` // functions, except if `optnone` is present. - for (Function &F : M) { - for (StringRef Prefix : {"__kmpc", "_ZN4_OMP", "omp_"}) - if (F.getName().startswith(Prefix) && - !F.hasFnAttribute(Attribute::OptimizeNone)) - F.removeFnAttr(Attribute::NoInline); + if (isOpenMPDevice(M)) { + for (Function &F : M) { + for (StringRef Prefix : {"__kmpc", "_ZN4_OMP", "omp_"}) + if (F.hasFnAttribute(Attribute::NoInline) && + F.getName().startswith(Prefix) && + !F.hasFnAttribute(Attribute::OptimizeNone)) + F.removeFnAttr(Attribute::NoInline); + } } // TODO: We should attach the attributes defined in OMPKinds.def. 
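Aside on the option declarations above: this patch consistently drops cl::ZeroOrMore from the OpenMPOpt cl::opt declarations while leaving the rest of each declaration unchanged, which suggests the flag is simply no longer needed for plain cl::opt options in this LLVM version. A minimal sketch of the trimmed declaration style, assuming a hypothetical flag name that is not part of the patch:

    #include "llvm/Support/CommandLine.h"
    using namespace llvm;

    // Hypothetical example option (not from the patch): only name, visibility,
    // default and description remain; cl::ZeroOrMore is not spelled out.
    static cl::opt<bool> ExampleDisableFoo(
        "example-disable-foo", cl::Hidden, cl::init(false),
        cl::desc("Hypothetical flag illustrating the trimmed cl::opt style."));

    int main(int argc, char **argv) {
      // Parse the command line; repeated occurrences of the option are
      // tolerated without requesting cl::ZeroOrMore explicitly.
      cl::ParseCommandLineOptions(argc, argv);
      return ExampleDisableFoo ? 1 : 0;
    }

This is illustrative only; the authoritative declarations are the ones shown in the hunks above.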
@@ -591,7 +602,7 @@ struct KernelInfoState : AbstractState { /// Abstract State interface ///{ - KernelInfoState() {} + KernelInfoState() = default; KernelInfoState(bool BestState) { if (!BestState) indicatePessimisticFixpoint(); @@ -926,8 +937,7 @@ private: SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap; BasicBlock *StartBB = nullptr, *EndBB = nullptr; - auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, - BasicBlock &ContinuationIP) { + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) { BasicBlock *CGStartBB = CodeGenIP.getBlock(); BasicBlock *CGEndBB = SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI); @@ -966,8 +976,7 @@ private: const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc(); ParentBB->getTerminator()->eraseFromParent(); - auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, - BasicBlock &ContinuationIP) { + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) { BasicBlock *CGStartBB = CodeGenIP.getBlock(); BasicBlock *CGEndBB = SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI); @@ -1107,10 +1116,8 @@ private: // callbacks. SmallVector<Value *, 8> Args; for (auto *CI : MergableCIs) { - Value *Callee = - CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts(); - FunctionType *FT = - cast<FunctionType>(Callee->getType()->getPointerElementType()); + Value *Callee = CI->getArgOperand(CallbackCalleeOperand); + FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask; Args.clear(); Args.push_back(OutlinedFn->getArg(0)); Args.push_back(OutlinedFn->getArg(1)); @@ -2408,8 +2415,7 @@ struct AAICVTrackerFunction : public AAICVTracker { auto CallCheck = [&](Instruction &I) { Optional<Value *> ReplVal = getValueForCall(A, I, ICV); - if (ReplVal.hasValue() && - ValuesMap.insert(std::make_pair(&I, *ReplVal)).second) + if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second) HasChanged = ChangeStatus::CHANGED; return true; @@ -2469,7 +2475,8 @@ struct AAICVTrackerFunction : public AAICVTracker { if (ICVTrackingAA.isAssumedTracked()) { Optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV); - if (!URV || (*URV && AA::isValidAtPosition(**URV, I, OMPInfoCache))) + if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I), + OMPInfoCache))) return URV; } @@ -2510,13 +2517,13 @@ struct AAICVTrackerFunction : public AAICVTracker { if (ValuesMap.count(CurrInst)) { Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst); // Unknown value, track new. - if (!ReplVal.hasValue()) { + if (!ReplVal) { ReplVal = NewReplVal; break; } // If we found a new value, we can't know the icv value anymore. - if (NewReplVal.hasValue()) + if (NewReplVal) if (ReplVal != NewReplVal) return nullptr; @@ -2524,11 +2531,11 @@ struct AAICVTrackerFunction : public AAICVTracker { } Optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV); - if (!NewReplVal.hasValue()) + if (!NewReplVal) continue; // Unknown value, track new. - if (!ReplVal.hasValue()) { + if (!ReplVal) { ReplVal = NewReplVal; break; } @@ -2540,7 +2547,7 @@ struct AAICVTrackerFunction : public AAICVTracker { } // If we are in the same BB and we have a value, we are done. - if (CurrBB == I->getParent() && ReplVal.hasValue()) + if (CurrBB == I->getParent() && ReplVal) return ReplVal; // Go through all predecessors and add terminators for analysis. 
@@ -2598,7 +2605,7 @@ struct AAICVTrackerFunctionReturned : AAICVTracker { ICVTrackingAA.getReplacementValue(ICV, &I, A); // If we found a second ICV value there is no unique returned value. - if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal) + if (UniqueICVValue && UniqueICVValue != NewReplVal) return false; UniqueICVValue = NewReplVal; @@ -2649,10 +2656,10 @@ struct AAICVTrackerCallSite : AAICVTracker { } ChangeStatus manifest(Attributor &A) override { - if (!ReplVal.hasValue() || !ReplVal.getValue()) + if (!ReplVal || !*ReplVal) return ChangeStatus::UNCHANGED; - A.changeValueAfterManifest(*getCtxI(), **ReplVal); + A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal); A.deleteAfterManifest(*getCtxI()); return ChangeStatus::CHANGED; @@ -2790,7 +2797,7 @@ struct AAExecutionDomainFunction : public AAExecutionDomain { SmallSetVector<const BasicBlock *, 16> SingleThreadedBBs; /// Total number of basic blocks in this function. - long unsigned NumBBs; + long unsigned NumBBs = 0; }; ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { @@ -2953,12 +2960,23 @@ struct AAHeapToSharedFunction : public AAHeapToShared { } void initialize(Attributor &A) override { + if (DisableOpenMPOptDeglobalization) { + indicatePessimisticFixpoint(); + return; + } + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared]; + Attributor::SimplifictionCallbackTy SCB = + [](const IRPosition &, const AbstractAttribute *, + bool &) -> Optional<Value *> { return nullptr; }; for (User *U : RFI.Declaration->users()) - if (CallBase *CB = dyn_cast<CallBase>(U)) + if (CallBase *CB = dyn_cast<CallBase>(U)) { MallocCalls.insert(CB); + A.registerSimplificationCallback(IRPosition::callsite_returned(*CB), + SCB); + } findPotentialRemovedFreeCalls(A); } @@ -3000,6 +3018,14 @@ struct AAHeapToSharedFunction : public AAHeapToShared { auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0)); + if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) { + LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB + << " with shared memory." + << " Shared memory usage is limited to " + << SharedMemoryLimit << " bytes\n"); + continue; + } + LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB << " with " << AllocSize->getZExtValue() << " bytes of shared memory\n"); @@ -3030,11 +3056,12 @@ struct AAHeapToSharedFunction : public AAHeapToShared { "HeapToShared on allocation without alignment attribute"); SharedMem->setAlignment(MaybeAlign(Alignment)); - A.changeValueAfterManifest(*CB, *NewBuffer); + A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer); A.deleteAfterManifest(*CB); A.deleteAfterManifest(*FreeCalls.front()); - NumBytesMovedToSharedMemory += AllocSize->getZExtValue(); + SharedMemoryUsed += AllocSize->getZExtValue(); + NumBytesMovedToSharedMemory = SharedMemoryUsed; Changed = ChangeStatus::CHANGED; } @@ -3070,6 +3097,8 @@ struct AAHeapToSharedFunction : public AAHeapToShared { SmallSetVector<CallBase *, 4> MallocCalls; /// Collection of potentially removed free calls in a function. SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls; + /// The total amount of shared memory that has been used for HeapToShared. 
+ unsigned SharedMemoryUsed = 0; }; struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> { @@ -3138,12 +3167,6 @@ struct AAKernelInfoFunction : AAKernelInfo { auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); Function *Fn = getAnchorScope(); - if (!OMPInfoCache.Kernels.count(Fn)) - return; - - // Add itself to the reaching kernel and set IsKernelEntry. - ReachingKernelEntries.insert(Fn); - IsKernelEntry = true; OMPInformationCache::RuntimeFunctionInfo &InitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; @@ -3177,10 +3200,12 @@ struct AAKernelInfoFunction : AAKernelInfo { Fn); // Ignore kernels without initializers such as global constructors. - if (!KernelInitCB || !KernelDeinitCB) { - indicateOptimisticFixpoint(); + if (!KernelInitCB || !KernelDeinitCB) return; - } + + // Add itself to the reaching kernel and set IsKernelEntry. + ReachingKernelEntries.insert(Fn); + IsKernelEntry = true; // For kernels we might need to initialize/finalize the IsSPMD state and // we need to register a simplification callback so that the Attributor @@ -3346,8 +3371,17 @@ struct AAKernelInfoFunction : AAKernelInfo { return false; } - // Check if the kernel is already in SPMD mode, if so, return success. + // Get the actual kernel, could be the caller of the anchor scope if we have + // a debug wrapper. Function *Kernel = getAnchorScope(); + if (Kernel->hasLocalLinkage()) { + assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper."); + auto *CB = cast<CallBase>(Kernel->user_back()); + Kernel = CB->getCaller(); + } + assert(OMPInfoCache.Kernels.count(Kernel) && "Expected kernel function!"); + + // Check if the kernel is already in SPMD mode, if so, return success. GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable( (Kernel->getName() + "_exec_mode").str()); assert(ExecMode && "Kernel without exec mode?"); @@ -4242,10 +4276,10 @@ struct AAKernelInfoCallSite : AAKernelInfo { unsigned ScheduleTypeVal = ScheduleTypeCI ? 
ScheduleTypeCI->getZExtValue() : 0; switch (OMPScheduleType(ScheduleTypeVal)) { - case OMPScheduleType::Static: - case OMPScheduleType::StaticChunked: - case OMPScheduleType::Distribute: - case OMPScheduleType::DistributeChunked: + case OMPScheduleType::UnorderedStatic: + case OMPScheduleType::UnorderedStaticChunked: + case OMPScheduleType::OrderedDistribute: + case OMPScheduleType::OrderedDistributeChunked: break; default: SPMDCompatibilityTracker.indicatePessimisticFixpoint(); @@ -4391,7 +4425,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall { std::string Str("simplified value: "); - if (!SimplifiedValue.hasValue()) + if (!SimplifiedValue) return Str + std::string("none"); if (!SimplifiedValue.getValue()) @@ -4421,8 +4455,8 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall { IRPosition::callsite_returned(CB), [&](const IRPosition &IRP, const AbstractAttribute *AA, bool &UsedAssumedInformation) -> Optional<Value *> { - assert((isValidState() || (SimplifiedValue.hasValue() && - SimplifiedValue.getValue() == nullptr)) && + assert((isValidState() || + (SimplifiedValue && SimplifiedValue.getValue() == nullptr)) && "Unexpected invalid state!"); if (!isAtFixpoint()) { @@ -4462,9 +4496,9 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall { ChangeStatus manifest(Attributor &A) override { ChangeStatus Changed = ChangeStatus::UNCHANGED; - if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) { + if (SimplifiedValue && *SimplifiedValue) { Instruction &I = *getCtxI(); - A.changeValueAfterManifest(I, **SimplifiedValue); + A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue); A.deleteAfterManifest(I); CallBase *CB = dyn_cast<CallBase>(&I); @@ -4550,7 +4584,7 @@ private: // We have empty reaching kernels, therefore we cannot tell if the // associated call site can be folded. At this moment, SimplifiedValue // must be none. - assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none"); + assert(!SimplifiedValue && "SimplifiedValue should be none"); } return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED @@ -4593,7 +4627,7 @@ private: return indicatePessimisticFixpoint(); if (CallerKernelInfoAA.ReachingKernelEntries.empty()) { - assert(!SimplifiedValue.hasValue() && + assert(!SimplifiedValue && "SimplifiedValue should keep none at this point"); return ChangeStatus::UNCHANGED; } @@ -4701,18 +4735,23 @@ void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) { void OpenMPOpt::registerAAs(bool IsModulePass) { if (SCC.empty()) - return; + if (IsModulePass) { // Ensure we create the AAKernelInfo AAs first and without triggering an // update. This will make sure we register all value simplification // callbacks before any other AA has the chance to create an AAValueSimplify // or similar. 
- for (Function *Kernel : OMPInfoCache.Kernels) + auto CreateKernelInfoCB = [&](Use &, Function &Kernel) { A.getOrCreateAAFor<AAKernelInfo>( - IRPosition::function(*Kernel), /* QueryingAA */ nullptr, + IRPosition::function(Kernel), /* QueryingAA */ nullptr, DepClassTy::NONE, /* ForceUpdate */ false, /* UpdateAfterInit */ false); + return false; + }; + OMPInformationCache::RuntimeFunctionInfo &InitRFI = + OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; + InitRFI.foreachUse(SCC, CreateKernelInfoCB); registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id); registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode); @@ -4900,6 +4939,9 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); KernelSet Kernels = getDeviceKernels(M); + if (PrintModuleBeforeOptimizations) + LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M); + auto IsCalled = [&](Function &F) { if (Kernels.contains(&F)) return true; @@ -4959,8 +5001,15 @@ PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? SetFixpointIterations : 32; - Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false, - MaxFixpointIterations, OREGetter, DEBUG_TYPE); + + AttributorConfig AC(CGUpdater); + AC.DefaultInitializeLiveInternals = false; + AC.RewriteSignatures = false; + AC.MaxFixpointIterations = MaxFixpointIterations; + AC.OREGetter = OREGetter; + AC.PassName = DEBUG_TYPE; + + Attributor A(Functions, InfoCache, AC); OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); bool Changed = OMPOpt.run(true); @@ -5002,6 +5051,9 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, Module &M = *C.begin()->getFunction().getParent(); + if (PrintModuleBeforeOptimizations) + LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M); + KernelSet Kernels = getDeviceKernels(M); FunctionAnalysisManager &FAM = @@ -5023,8 +5075,16 @@ PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C, unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? SetFixpointIterations : 32; - Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true, - MaxFixpointIterations, OREGetter, DEBUG_TYPE); + + AttributorConfig AC(CGUpdater); + AC.DefaultInitializeLiveInternals = false; + AC.IsModulePass = false; + AC.RewriteSignatures = false; + AC.MaxFixpointIterations = MaxFixpointIterations; + AC.OREGetter = OREGetter; + AC.PassName = DEBUG_TYPE; + + Attributor A(Functions, InfoCache, AC); OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); bool Changed = OMPOpt.run(false); @@ -5094,8 +5154,16 @@ struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass { unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 
SetFixpointIterations : 32; - Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true, - MaxFixpointIterations, OREGetter, DEBUG_TYPE); + + AttributorConfig AC(CGUpdater); + AC.DefaultInitializeLiveInternals = false; + AC.IsModulePass = false; + AC.RewriteSignatures = false; + AC.MaxFixpointIterations = MaxFixpointIterations; + AC.OREGetter = OREGetter; + AC.PassName = DEBUG_TYPE; + + Attributor A(Functions, InfoCache, AC); OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A); bool Result = OMPOpt.run(false); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp index 5f2223e4047e..54c72bdbb203 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp @@ -14,7 +14,6 @@ #include "llvm/Transforms/IPO/PartialInlining.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -40,6 +39,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/User.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" @@ -55,8 +55,6 @@ #include <algorithm> #include <cassert> #include <cstdint> -#include <functional> -#include <iterator> #include <memory> #include <tuple> #include <vector> @@ -99,7 +97,7 @@ static cl::opt<bool> // This is an option used by testing: static cl::opt<bool> SkipCostAnalysis("skip-partial-inlining-cost-analysis", - cl::init(false), cl::ZeroOrMore, + cl::ReallyHidden, cl::desc("Skip Cost Analysis")); // Used to determine if a cold region is worth outlining based on @@ -129,7 +127,7 @@ static cl::opt<unsigned> MaxNumInlineBlocks( // Command line option to set the maximum number of partial inlining allowed // for the module. The default value of -1 means no limit. static cl::opt<int> MaxNumPartialInlining( - "max-partial-inlining", cl::init(-1), cl::Hidden, cl::ZeroOrMore, + "max-partial-inlining", cl::init(-1), cl::Hidden, cl::desc("Max number of partial inlining. The default is unlimited")); // Used only when PGO or user annotated branch data is absent. It is @@ -137,7 +135,7 @@ static cl::opt<int> MaxNumPartialInlining( // produces larger value, the BFI value will be used. 
static cl::opt<int> OutlineRegionFreqPercent("outline-region-freq-percent", cl::init(75), - cl::Hidden, cl::ZeroOrMore, + cl::Hidden, cl::desc("Relative frequency of outline region to " "the entry block")); @@ -169,7 +167,7 @@ struct FunctionOutliningInfo { }; struct FunctionOutliningMultiRegionInfo { - FunctionOutliningMultiRegionInfo() {} + FunctionOutliningMultiRegionInfo() = default; // Container for outline regions struct OutlineRegionInfo { @@ -440,7 +438,7 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo( }; auto BBProfileCount = [BFI](BasicBlock *BB) { - return BFI->getBlockProfileCount(BB).getValueOr(0); + return BFI->getBlockProfileCount(BB).value_or(0); }; // Use the same computeBBInlineCost function to compute the cost savings of @@ -741,7 +739,7 @@ BranchProbability PartialInlinerImpl::getOutliningCallBBRelativeFreq( auto OutlineRegionRelFreq = BranchProbability::getBranchProbability( OutliningCallFreq.getFrequency(), EntryFreq.getFrequency()); - if (hasProfileData(*Cloner.OrigFunc, *Cloner.ClonedOI.get())) + if (hasProfileData(*Cloner.OrigFunc, *Cloner.ClonedOI)) return OutlineRegionRelFreq; // When profile data is not available, we need to be conservative in diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp index 6e5aeb9c41f6..ae787be40c55 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -15,19 +15,13 @@ #include "llvm-c/Transforms/PassManagerBuilder.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/CFLAndersAliasAnalysis.h" #include "llvm/Analysis/CFLSteensAliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/InlineCost.h" -#include "llvm/Analysis/Passes.h" #include "llvm/Analysis/ScopedNoAliasAA.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TypeBasedAliasAnalysis.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/LegacyPassManager.h" -#include "llvm/IR/Verifier.h" -#include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Target/CGPassBuilderOption.h" @@ -41,22 +35,16 @@ #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVN.h" -#include "llvm/Transforms/Scalar/InstSimplifyPass.h" #include "llvm/Transforms/Scalar/LICM.h" #include "llvm/Transforms/Scalar/LoopUnrollPass.h" -#include "llvm/Transforms/Scalar/SCCP.h" #include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Vectorize.h" -#include "llvm/Transforms/Vectorize/LoopVectorize.h" -#include "llvm/Transforms/Vectorize/SLPVectorizer.h" -#include "llvm/Transforms/Vectorize/VectorCombine.h" using namespace llvm; namespace llvm { -cl::opt<bool> RunPartialInlining("enable-partial-inlining", cl::init(false), - cl::Hidden, cl::ZeroOrMore, +cl::opt<bool> RunPartialInlining("enable-partial-inlining", cl::Hidden, cl::desc("Run Partial inlinining pass")); static cl::opt<bool> @@ -111,8 +99,8 @@ static cl::opt<bool> EnablePerformThinLTO("perform-thinlto", cl::init(false), cl::Hidden, cl::desc("Enable performing ThinLTO.")); -cl::opt<bool> EnableHotColdSplit("hot-cold-split", cl::init(false), - cl::ZeroOrMore, cl::desc("Enable hot-cold splitting pass")); +cl::opt<bool> 
EnableHotColdSplit("hot-cold-split", + cl::desc("Enable hot-cold splitting pass")); cl::opt<bool> EnableIROutliner("ir-outliner", cl::init(false), cl::Hidden, cl::desc("Enable ir outliner pass")); @@ -126,12 +114,12 @@ cl::opt<bool> cl::desc("Disable pre-instrumentation inliner")); cl::opt<int> PreInlineThreshold( - "preinline-threshold", cl::Hidden, cl::init(75), cl::ZeroOrMore, + "preinline-threshold", cl::Hidden, cl::init(75), cl::desc("Control the amount of inlining in pre-instrumentation inliner " "(default = 75)")); cl::opt<bool> - EnableGVNHoist("enable-gvn-hoist", cl::init(false), cl::ZeroOrMore, + EnableGVNHoist("enable-gvn-hoist", cl::desc("Enable the GVN hoisting pass (default = off)")); static cl::opt<bool> @@ -139,13 +127,8 @@ static cl::opt<bool> cl::Hidden, cl::desc("Disable shrink-wrap library calls")); -static cl::opt<bool> EnableSimpleLoopUnswitch( - "enable-simple-loop-unswitch", cl::init(false), cl::Hidden, - cl::desc("Enable the simple loop unswitch pass. Also enables independent " - "cleanup passes integrated into the loop pass manager pipeline.")); - cl::opt<bool> - EnableGVNSink("enable-gvn-sink", cl::init(false), cl::ZeroOrMore, + EnableGVNSink("enable-gvn-sink", cl::desc("Enable the GVN sinking pass (default = off)")); // This option is used in simplifying testing SampleFDO optimizations for @@ -336,61 +319,6 @@ void PassManagerBuilder::populateFunctionPassManager( FPM.add(createEarlyCSEPass()); } -// Do PGO instrumentation generation or use pass as the option specified. -void PassManagerBuilder::addPGOInstrPasses(legacy::PassManagerBase &MPM, - bool IsCS = false) { - if (IsCS) { - if (!EnablePGOCSInstrGen && !EnablePGOCSInstrUse) - return; - } else if (!EnablePGOInstrGen && PGOInstrUse.empty() && PGOSampleUse.empty()) - return; - - // Perform the preinline and cleanup passes for O1 and above. - // We will not do this inline for context sensitive PGO (when IsCS is true). - if (OptLevel > 0 && !DisablePreInliner && PGOSampleUse.empty() && !IsCS) { - // Create preinline pass. We construct an InlineParams object and specify - // the threshold here to avoid the command line options of the regular - // inliner to influence pre-inlining. The only fields of InlineParams we - // care about are DefaultThreshold and HintThreshold. - InlineParams IP; - IP.DefaultThreshold = PreInlineThreshold; - // FIXME: The hint threshold has the same value used by the regular inliner - // when not optimzing for size. This should probably be lowered after - // performance testing. - // Use PreInlineThreshold for both -Os and -Oz. Not running preinliner makes - // the instrumented binary unusably large. Even if PreInlineThreshold is not - // correct thresold for -Oz, it is better than not running preinliner. - IP.HintThreshold = SizeLevel > 0 ? PreInlineThreshold : 325; - - MPM.add(createFunctionInliningPass(IP)); - MPM.add(createSROAPass()); - MPM.add(createEarlyCSEPass()); // Catch trivial redundancies - MPM.add(createCFGSimplificationPass( - SimplifyCFGOptions().convertSwitchRangeToICmp( - true))); // Merge & remove BBs - MPM.add(createInstructionCombiningPass()); // Combine silly seq's - addExtensionsToPM(EP_Peephole, MPM); - } - if ((EnablePGOInstrGen && !IsCS) || (EnablePGOCSInstrGen && IsCS)) { - MPM.add(createPGOInstrumentationGenLegacyPass(IsCS)); - // Add the profile lowering pass. 
- InstrProfOptions Options; - if (!PGOInstrGen.empty()) - Options.InstrProfileOutput = PGOInstrGen; - Options.DoCounterPromotion = true; - Options.UseBFIInPromotion = IsCS; - MPM.add(createLoopRotatePass()); - MPM.add(createInstrProfilingLegacyPass(Options, IsCS)); - } - if (!PGOInstrUse.empty()) - MPM.add(createPGOInstrumentationUseLegacyPass(PGOInstrUse, IsCS)); - // Indirect call promotion that promotes intra-module targets only. - // For ThinLTO this is done earlier due to interactions with globalopt - // for imported functions. We don't run this at -O0. - if (OptLevel > 0 && !IsCS) - MPM.add( - createPGOIndirectCallPromotionLegacyPass(false, !PGOSampleUse.empty())); -} void PassManagerBuilder::addFunctionSimplificationPasses( legacy::PassManagerBase &MPM) { // Start of function pass. @@ -432,10 +360,6 @@ void PassManagerBuilder::addFunctionSimplificationPasses( MPM.add(createLibCallsShrinkWrapPass()); addExtensionsToPM(EP_Peephole, MPM); - // Optimize memory intrinsic calls based on the profiled size information. - if (SizeLevel == 0) - MPM.add(createPGOMemOPSizeOptLegacyPass()); - // TODO: Investigate the cost/benefit of tail call elimination on debugging. if (OptLevel > 1) MPM.add(createTailCallEliminationPass()); // Eliminate tail calls @@ -450,13 +374,13 @@ void PassManagerBuilder::addFunctionSimplificationPasses( MPM.add(createVectorCombinePass()); // Begin the loop pass pipeline. - if (EnableSimpleLoopUnswitch) { - // The simple loop unswitch pass relies on separate cleanup passes. Schedule - // them first so when we re-process a loop they run before other loop - // passes. - MPM.add(createLoopInstSimplifyPass()); - MPM.add(createLoopSimplifyCFGPass()); - } + + // The simple loop unswitch pass relies on separate cleanup passes. Schedule + // them first so when we re-process a loop they run before other loop + // passes. + MPM.add(createLoopInstSimplifyPass()); + MPM.add(createLoopSimplifyCFGPass()); + // Try to remove as much code from the loop header as possible, // to reduce amount of IR that will have to be duplicated. However, // do not perform speculative hoisting the first time as LICM @@ -470,10 +394,7 @@ void PassManagerBuilder::addFunctionSimplificationPasses( // TODO: Investigate promotion cap for O1. MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap, /*AllowSpeculation=*/true)); - if (EnableSimpleLoopUnswitch) - MPM.add(createSimpleLoopUnswitchLegacyPass()); - else - MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3, DivergentTarget)); + MPM.add(createSimpleLoopUnswitchLegacyPass(OptLevel == 3)); // FIXME: We break the loop pass pipeline here in order to do full // simplifycfg. Eventually loop-simplifycfg should be enhanced to replace the // need for this. @@ -596,7 +517,7 @@ void PassManagerBuilder::addVectorPasses(legacy::PassManagerBase &PM, PM.add(createInstructionCombiningPass()); PM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap, /*AllowSpeculation=*/true)); - PM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3, DivergentTarget)); + PM.add(createSimpleLoopUnswitchLegacyPass()); PM.add(createCFGSimplificationPass( SimplifyCFGOptions().convertSwitchRangeToICmp(true))); PM.add(createInstructionCombiningPass()); @@ -675,10 +596,6 @@ void PassManagerBuilder::addVectorPasses(legacy::PassManagerBase &PM, void PassManagerBuilder::populateModulePassManager( legacy::PassManagerBase &MPM) { - // Whether this is a default or *LTO pre-link pipeline. 
The FullLTO post-link - // is handled separately, so just check this is not the ThinLTO post-link. - bool DefaultOrPreLinkPipeline = !PerformThinLTO; - MPM.add(createAnnotation2MetadataLegacyPass()); if (!PGOSampleUse.empty()) { @@ -696,7 +613,6 @@ void PassManagerBuilder::populateModulePassManager( // If all optimizations are disabled, just run the always-inline pass and, // if enabled, the function merging pass. if (OptLevel == 0) { - addPGOInstrPasses(MPM); if (Inliner) { MPM.add(Inliner); Inliner = nullptr; @@ -750,8 +666,6 @@ void PassManagerBuilder::populateModulePassManager( // earlier in the pass pipeline, here before globalopt. Otherwise imported // available_externally functions look unreferenced and are removed. if (PerformThinLTO) { - MPM.add(createPGOIndirectCallPromotionLegacyPass(/*InLTO = */ true, - !PGOSampleUse.empty())); MPM.add(createLowerTypeTestsPass(nullptr, nullptr, true)); } @@ -794,19 +708,6 @@ void PassManagerBuilder::populateModulePassManager( createCFGSimplificationPass(SimplifyCFGOptions().convertSwitchRangeToICmp( true))); // Clean up after IPCP & DAE - // For SamplePGO in ThinLTO compile phase, we do not want to do indirect - // call promotion as it will change the CFG too much to make the 2nd - // profile annotation in backend more difficult. - // PGO instrumentation is added during the compile phase for ThinLTO, do - // not run it a second time - if (DefaultOrPreLinkPipeline && !PrepareForThinLTOUsingPGOSampleProfile) - addPGOInstrPasses(MPM); - - // Create profile COMDAT variables. Lld linker wants to see all variables - // before the LTO/ThinLTO link since it needs to resolve symbols/comdats. - if (!PerformThinLTO && EnablePGOCSInstrGen) - MPM.add(createPGOInstrumentationGenCreateVarLegacyPass(PGOInstrGen)); - // We add a module alias analysis pass here. In part due to bugs in the // analysis infrastructure this "works" in that the analysis stays alive // for the entire SCC pass run below. @@ -831,8 +732,6 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createOpenMPOptCGSCCLegacyPass()); MPM.add(createPostOrderFunctionAttrsLegacyPass()); - if (OptLevel > 2) - MPM.add(createArgumentPromotionPass()); // Scalarize uninlined fn args addExtensionsToPM(EP_CGSCCOptimizerLate, MPM); addFunctionSimplificationPasses(MPM); @@ -857,14 +756,6 @@ void PassManagerBuilder::populateModulePassManager( // and saves running remaining passes on the eliminated functions. MPM.add(createEliminateAvailableExternallyPass()); - // CSFDO instrumentation and use pass. Don't invoke this for Prepare pass - // for LTO and ThinLTO -- The actual pass will be called after all inlines - // are performed. - // Need to do this after COMDAT variables have been eliminated, - // (i.e. after EliminateAvailableExternallyPass). - if (!(PrepareForLTO || PrepareForThinLTO)) - addPGOInstrPasses(MPM, /* IsCS */ true); - if (EnableOrderFileInstrumentation) MPM.add(createInstrOrderFilePass()); @@ -1031,13 +922,6 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { // Split call-site with more constrained arguments. PM.add(createCallSiteSplittingPass()); - // Indirect call promotion. This should promote all the targets that are - // left by the earlier promotion pass that promotes intra-module targets. - // This two-step promotion is to save the compile time. For LTO, it should - // produce the same result as if we only do promotion here. 
- PM.add( - createPGOIndirectCallPromotionLegacyPass(true, !PGOSampleUse.empty())); - // Propage constant function arguments by specializing the functions. if (EnableFunctionSpecialization && OptLevel > 2) PM.add(createFunctionSpecializationPass()); @@ -1103,9 +987,6 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { PM.add(createPruneEHPass()); // Remove dead EH info. - // CSFDO instrumentation and use pass. - addPGOInstrPasses(PM, /* IsCS */ true); - // Infer attributes on declarations, call sites, arguments, etc. for an SCC. if (AttributorRun & AttributorRunOption::CGSCC) PM.add(createAttributorCGSCCLegacyPass()); @@ -1120,14 +1001,10 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { PM.add(createGlobalOptimizerPass()); PM.add(createGlobalDCEPass()); // Remove dead functions. - // If we didn't decide to inline a function, check to see if we can - // transform it to pass arguments by value instead of by reference. - PM.add(createArgumentPromotionPass()); - // The IPO passes may leave cruft around. Clean up after them. PM.add(createInstructionCombiningPass()); addExtensionsToPM(EP_Peephole, PM); - PM.add(createJumpThreadingPass(/*FreezeSelectCond*/ true)); + PM.add(createJumpThreadingPass()); // Break up allocas PM.add(createSROAPass()); @@ -1172,7 +1049,7 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { addExtensionsToPM(EP_Peephole, PM); - PM.add(createJumpThreadingPass(/*FreezeSelectCond*/ true)); + PM.add(createJumpThreadingPass()); } void PassManagerBuilder::addLateLTOOptimizationPasses( @@ -1198,80 +1075,6 @@ void PassManagerBuilder::addLateLTOOptimizationPasses( PM.add(createMergeFunctionsPass()); } -void PassManagerBuilder::populateThinLTOPassManager( - legacy::PassManagerBase &PM) { - PerformThinLTO = true; - if (LibraryInfo) - PM.add(new TargetLibraryInfoWrapperPass(*LibraryInfo)); - - if (VerifyInput) - PM.add(createVerifierPass()); - - if (ImportSummary) { - // This pass imports type identifier resolutions for whole-program - // devirtualization and CFI. It must run early because other passes may - // disturb the specific instruction patterns that these passes look for, - // creating dependencies on resolutions that may not appear in the summary. - // - // For example, GVN may transform the pattern assume(type.test) appearing in - // two basic blocks into assume(phi(type.test, type.test)), which would - // transform a dependency on a WPD resolution into a dependency on a type - // identifier resolution for CFI. - // - // Also, WPD has access to more precise information than ICP and can - // devirtualize more effectively, so it should operate on the IR first. - PM.add(createWholeProgramDevirtPass(nullptr, ImportSummary)); - PM.add(createLowerTypeTestsPass(nullptr, ImportSummary)); - } - - populateModulePassManager(PM); - - if (VerifyOutput) - PM.add(createVerifierPass()); - PerformThinLTO = false; -} - -void PassManagerBuilder::populateLTOPassManager(legacy::PassManagerBase &PM) { - if (LibraryInfo) - PM.add(new TargetLibraryInfoWrapperPass(*LibraryInfo)); - - if (VerifyInput) - PM.add(createVerifierPass()); - - addExtensionsToPM(EP_FullLinkTimeOptimizationEarly, PM); - - if (OptLevel != 0) - addLTOOptimizationPasses(PM); - else { - // The whole-program-devirt pass needs to run at -O0 because only it knows - // about the llvm.type.checked.load intrinsic: it needs to both lower the - // intrinsic itself and handle it in the summary. 
- PM.add(createWholeProgramDevirtPass(ExportSummary, nullptr)); - } - - // Create a function that performs CFI checks for cross-DSO calls with targets - // in the current module. - PM.add(createCrossDSOCFIPass()); - - // Lower type metadata and the type.test intrinsic. This pass supports Clang's - // control flow integrity mechanisms (-fsanitize=cfi*) and needs to run at - // link time if CFI is enabled. The pass does nothing if CFI is disabled. - PM.add(createLowerTypeTestsPass(ExportSummary, nullptr)); - // Run a second time to clean up any type tests left behind by WPD for use - // in ICP (which is performed earlier than this in the regular LTO pipeline). - PM.add(createLowerTypeTestsPass(nullptr, nullptr, true)); - - if (OptLevel != 0) - addLateLTOOptimizationPasses(PM); - - addExtensionsToPM(EP_FullLinkTimeOptimizationLast, PM); - - PM.add(createAnnotationRemarksLegacyPass()); - - if (VerifyOutput) - PM.add(createVerifierPass()); -} - LLVMPassManagerBuilderRef LLVMPassManagerBuilderCreate() { PassManagerBuilder *PMB = new PassManagerBuilder(); return wrap(PMB); @@ -1337,18 +1140,3 @@ LLVMPassManagerBuilderPopulateModulePassManager(LLVMPassManagerBuilderRef PMB, legacy::PassManagerBase *MPM = unwrap(PM); Builder->populateModulePassManager(*MPM); } - -void LLVMPassManagerBuilderPopulateLTOPassManager(LLVMPassManagerBuilderRef PMB, - LLVMPassManagerRef PM, - LLVMBool Internalize, - LLVMBool RunInliner) { - PassManagerBuilder *Builder = unwrap(PMB); - legacy::PassManagerBase *LPM = unwrap(PM); - - // A small backwards compatibility hack. populateLTOPassManager used to take - // an RunInliner option. - if (RunInliner && !Builder->Inliner) - Builder->Inliner = createFunctionInliningPass(); - - Builder->populateLTOPassManager(*LPM); -} diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/PruneEH.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/PruneEH.cpp index 39de19ca9e9d..e0836a9fd699 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/PruneEH.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/PruneEH.cpp @@ -14,7 +14,6 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" @@ -24,9 +23,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/InitializePasses.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/CallGraphUpdater.h" #include "llvm/Transforms/Utils/Local.h" @@ -246,7 +243,7 @@ static void DeleteBasicBlock(BasicBlock *BB, CallGraphUpdater &CGU) { } if (!I->use_empty()) - I->replaceAllUsesWith(UndefValue::get(I->getType())); + I->replaceAllUsesWith(PoisonValue::get(I->getType())); } if (TokenInst) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp index 5779553ee732..26fb7d676429 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp @@ -18,6 +18,7 @@ #include "llvm/InitializePasses.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Scalar/SCCP.h" +#include "llvm/Transforms/Utils/SCCPSolver.h" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp 
b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp index 7334bf695b67..6859953de962 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp @@ -14,7 +14,8 @@ #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/DebugInfoMetadata.h" -#include "llvm/IR/Instructions.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/ProfileData/SampleProf.h" #include <map> #include <queue> @@ -62,23 +63,24 @@ ContextTrieNode::getHottestChildContext(const LineLocation &CallSite) { return ChildNodeRet; } -ContextTrieNode &ContextTrieNode::moveToChildContext( - const LineLocation &CallSite, ContextTrieNode &&NodeToMove, - uint32_t ContextFramesToRemove, bool DeleteNode) { +ContextTrieNode & +SampleContextTracker::moveContextSamples(ContextTrieNode &ToNodeParent, + const LineLocation &CallSite, + ContextTrieNode &&NodeToMove) { uint64_t Hash = FunctionSamples::getCallSiteHash(NodeToMove.getFuncName(), CallSite); + std::map<uint64_t, ContextTrieNode> &AllChildContext = + ToNodeParent.getAllChildContext(); assert(!AllChildContext.count(Hash) && "Node to remove must exist"); - LineLocation OldCallSite = NodeToMove.CallSiteLoc; - ContextTrieNode &OldParentContext = *NodeToMove.getParentContext(); AllChildContext[Hash] = NodeToMove; ContextTrieNode &NewNode = AllChildContext[Hash]; - NewNode.CallSiteLoc = CallSite; + NewNode.setCallSiteLoc(CallSite); // Walk through nodes in the moved the subtree, and update // FunctionSamples' context as for the context promotion. // We also need to set new parant link for all children. std::queue<ContextTrieNode *> NodeToUpdate; - NewNode.setParentContext(this); + NewNode.setParentContext(&ToNodeParent); NodeToUpdate.push(&NewNode); while (!NodeToUpdate.empty()) { @@ -87,10 +89,8 @@ ContextTrieNode &ContextTrieNode::moveToChildContext( FunctionSamples *FSamples = Node->getFunctionSamples(); if (FSamples) { - FSamples->getContext().promoteOnPath(ContextFramesToRemove); + setContextNode(FSamples, Node); FSamples->getContext().setState(SyntheticContext); - LLVM_DEBUG(dbgs() << " Context promoted to: " - << FSamples->getContext().toString() << "\n"); } for (auto &It : Node->getAllChildContext()) { @@ -100,10 +100,6 @@ ContextTrieNode &ContextTrieNode::moveToChildContext( } } - // Original context no longer needed, destroy if requested. 
- if (DeleteNode) - OldParentContext.removeChildContext(OldCallSite, NewNode.getFuncName()); - return NewNode; } @@ -131,7 +127,7 @@ void ContextTrieNode::setFunctionSamples(FunctionSamples *FSamples) { Optional<uint32_t> ContextTrieNode::getFunctionSize() const { return FuncSize; } void ContextTrieNode::addFunctionSize(uint32_t FSize) { - if (!FuncSize.hasValue()) + if (!FuncSize) FuncSize = 0; FuncSize = FuncSize.getValue() + FSize; @@ -147,6 +143,10 @@ void ContextTrieNode::setParentContext(ContextTrieNode *Parent) { ParentContext = Parent; } +void ContextTrieNode::setCallSiteLoc(const LineLocation &Loc) { + CallSiteLoc = Loc; +} + void ContextTrieNode::dumpNode() { dbgs() << "Node: " << FuncName << "\n" << " Callsite: " << CallSiteLoc << "\n" @@ -202,13 +202,23 @@ SampleContextTracker::SampleContextTracker( SampleContext Context = FuncSample.first; LLVM_DEBUG(dbgs() << "Tracking Context for function: " << Context.toString() << "\n"); - if (!Context.isBaseContext()) - FuncToCtxtProfiles[Context.getName()].insert(FSamples); ContextTrieNode *NewNode = getOrCreateContextPath(Context, true); assert(!NewNode->getFunctionSamples() && "New node can't have sample profile"); NewNode->setFunctionSamples(FSamples); } + populateFuncToCtxtMap(); +} + +void SampleContextTracker::populateFuncToCtxtMap() { + for (auto *Node : *this) { + FunctionSamples *FSamples = Node->getFunctionSamples(); + if (FSamples) { + FSamples->getContext().setState(RawContext); + setContextNode(FSamples, Node); + FuncToCtxtProfiles[Node->getFuncName()].push_back(FSamples); + } + } } FunctionSamples * @@ -231,7 +241,7 @@ SampleContextTracker::getCalleeContextSamplesFor(const CallBase &Inst, if (CalleeContext) { FunctionSamples *FSamples = CalleeContext->getFunctionSamples(); LLVM_DEBUG(if (FSamples) { - dbgs() << " Callee context found: " << FSamples->getContext().toString() + dbgs() << " Callee context found: " << getContextString(CalleeContext) << "\n"; }); return FSamples; @@ -333,7 +343,7 @@ FunctionSamples *SampleContextTracker::getBaseSamplesFor(StringRef Name, if (Context.hasState(InlinedContext) || Context.hasState(MergedContext)) continue; - ContextTrieNode *FromNode = getContextFor(Context); + ContextTrieNode *FromNode = getContextNodeForProfile(CSamples); if (FromNode == Node) continue; @@ -354,7 +364,7 @@ void SampleContextTracker::markContextSamplesInlined( const FunctionSamples *InlinedSamples) { assert(InlinedSamples && "Expect non-null inlined samples"); LLVM_DEBUG(dbgs() << "Marking context profile as inlined: " - << InlinedSamples->getContext().toString() << "\n"); + << getContextString(*InlinedSamples) << "\n"); InlinedSamples->getContext().setState(InlinedContext); } @@ -405,17 +415,43 @@ ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( // the context profile in the base (context-less) profile. FunctionSamples *FromSamples = NodeToPromo.getFunctionSamples(); assert(FromSamples && "Shouldn't promote a context without profile"); + (void)FromSamples; // Unused in release build. 
+ LLVM_DEBUG(dbgs() << " Found context tree root to promote: " - << FromSamples->getContext().toString() << "\n"); + << getContextString(&NodeToPromo) << "\n"); assert(!FromSamples->getContext().hasState(InlinedContext) && "Shouldn't promote inlined context profile"); - uint32_t ContextFramesToRemove = - FromSamples->getContext().getContextFrames().size() - 1; - return promoteMergeContextSamplesTree(NodeToPromo, RootContext, - ContextFramesToRemove); + return promoteMergeContextSamplesTree(NodeToPromo, RootContext); +} + +#ifndef NDEBUG +std::string +SampleContextTracker::getContextString(const FunctionSamples &FSamples) const { + return getContextString(getContextNodeForProfile(&FSamples)); } +std::string +SampleContextTracker::getContextString(ContextTrieNode *Node) const { + SampleContextFrameVector Res; + if (Node == &RootContext) + return std::string(); + Res.emplace_back(Node->getFuncName(), LineLocation(0, 0)); + + ContextTrieNode *PreNode = Node; + Node = Node->getParentContext(); + while (Node && Node != &RootContext) { + Res.emplace_back(Node->getFuncName(), PreNode->getCallSiteLoc()); + PreNode = Node; + Node = Node->getParentContext(); + } + + std::reverse(Res.begin(), Res.end()); + + return SampleContext::getContextString(Res); +} +#endif + void SampleContextTracker::dump() { RootContext.dumpTree(); } StringRef SampleContextTracker::getFuncNameFor(ContextTrieNode *Node) const { @@ -526,8 +562,7 @@ ContextTrieNode &SampleContextTracker::addTopLevelContextNode(StringRef FName) { } void SampleContextTracker::mergeContextNode(ContextTrieNode &FromNode, - ContextTrieNode &ToNode, - uint32_t ContextFramesToRemove) { + ContextTrieNode &ToNode) { FunctionSamples *FromSamples = FromNode.getFunctionSamples(); FunctionSamples *ToSamples = ToNode.getFunctionSamples(); if (FromSamples && ToSamples) { @@ -540,16 +575,13 @@ void SampleContextTracker::mergeContextNode(ContextTrieNode &FromNode, } else if (FromSamples) { // Transfer FromSamples from FromNode to ToNode ToNode.setFunctionSamples(FromSamples); + setContextNode(FromSamples, &ToNode); FromSamples->getContext().setState(SyntheticContext); - FromSamples->getContext().promoteOnPath(ContextFramesToRemove); - FromNode.setFunctionSamples(nullptr); } } ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( - ContextTrieNode &FromNode, ContextTrieNode &ToNodeParent, - uint32_t ContextFramesToRemove) { - assert(ContextFramesToRemove && "Context to remove can't be empty"); + ContextTrieNode &FromNode, ContextTrieNode &ToNodeParent) { // Ignore call site location if destination is top level under root LineLocation NewCallSiteLoc = LineLocation(0, 0); @@ -566,22 +598,25 @@ ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( if (!ToNode) { // Do not delete node to move from its parent here because // caller is iterating over children of that parent node. 
- ToNode = &ToNodeParent.moveToChildContext( - NewCallSiteLoc, std::move(FromNode), ContextFramesToRemove, false); + ToNode = + &moveContextSamples(ToNodeParent, NewCallSiteLoc, std::move(FromNode)); + LLVM_DEBUG({ + dbgs() << " Context promoted and merged to: " << getContextString(ToNode) + << "\n"; + }); } else { // Destination node exists, merge samples for the context tree - mergeContextNode(FromNode, *ToNode, ContextFramesToRemove); + mergeContextNode(FromNode, *ToNode); LLVM_DEBUG({ if (ToNode->getFunctionSamples()) dbgs() << " Context promoted and merged to: " - << ToNode->getFunctionSamples()->getContext().toString() << "\n"; + << getContextString(ToNode) << "\n"; }); // Recursively promote and merge children for (auto &It : FromNode.getAllChildContext()) { ContextTrieNode &FromChildNode = It.second; - promoteMergeContextSamplesTree(FromChildNode, *ToNode, - ContextFramesToRemove); + promoteMergeContextSamplesTree(FromChildNode, *ToNode); } // Remove children once they're all merged @@ -594,4 +629,14 @@ ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree( return *ToNode; } + +void SampleContextTracker::createContextLessProfileMap( + SampleProfileMap &ContextLessProfiles) { + for (auto *Node : *this) { + FunctionSamples *FProfile = Node->getFunctionSamples(); + // Profile's context can be empty, use ContextNode's func name. + if (FProfile) + ContextLessProfiles[Node->getFuncName()].merge(*FProfile); + } +} } // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp index bc6051de90c4..40de69bbf2cf 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp @@ -25,11 +25,8 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/PriorityQueue.h" #include "llvm/ADT/SCCIterator.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringMap.h" @@ -38,22 +35,16 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfoImpl.h" #include "llvm/Analysis/CallGraph.h" -#include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/InlineAdvisor.h" #include "llvm/Analysis/InlineCost.h" -#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" -#include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ReplayInlineAdvisor.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CFG.h" -#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/DiagnosticInfo.h" -#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/InstrTypes.h" @@ -64,6 +55,7 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/PseudoProbe.h" #include "llvm/IR/ValueSymbolTable.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" @@ -73,9 +65,7 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/ErrorOr.h" -#include "llvm/Support/GenericDomTree.h" #include 
"llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/ProfiledCallGraph.h" @@ -84,7 +74,6 @@ #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/SampleProfileInference.h" #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h" #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h" #include <algorithm> @@ -151,8 +140,7 @@ static cl::opt<bool> ProfileSampleBlockAccurate( "them conservatively as unknown. ")); static cl::opt<bool> ProfileAccurateForSymsInList( - "profile-accurate-for-symsinlist", cl::Hidden, cl::ZeroOrMore, - cl::init(true), + "profile-accurate-for-symsinlist", cl::Hidden, cl::init(true), cl::desc("For symbols in profile symbol list, regard their profiles to " "be accurate. It may be overriden by profile-sample-accurate. ")); @@ -183,6 +171,15 @@ static cl::opt<bool> ProfileSizeInline( cl::desc("Inline cold call sites in profile loader if it's beneficial " "for code size.")); +// Since profiles are consumed by many passes, turning on this option has +// side effects. For instance, pre-link SCC inliner would see merged profiles +// and inline the hot functions (that are skipped in this pass). +static cl::opt<bool> DisableSampleLoaderInlining( + "disable-sample-loader-inlining", cl::Hidden, cl::init(false), + cl::desc("If true, artifically skip inline transformation in sample-loader " + "pass, and merge (or scale) profiles (as configured by " + "--sample-profile-merge-inlinee).")); + cl::opt<int> ProfileInlineGrowthLimit( "sample-profile-inline-growth-limit", cl::Hidden, cl::init(12), cl::desc("The size growth ratio limit for proirity-based sample profile " @@ -219,19 +216,19 @@ static cl::opt<unsigned> ProfileICPRelativeHotnessSkip( "Skip relative hotness check for ICP up to given number of targets.")); static cl::opt<bool> CallsitePrioritizedInline( - "sample-profile-prioritized-inline", cl::Hidden, cl::ZeroOrMore, - cl::init(false), + "sample-profile-prioritized-inline", cl::Hidden, + cl::desc("Use call site prioritized inlining for sample profile loader." 
"Currently only CSSPGO is supported.")); static cl::opt<bool> UsePreInlinerDecision( - "sample-profile-use-preinliner", cl::Hidden, cl::ZeroOrMore, - cl::init(false), + "sample-profile-use-preinliner", cl::Hidden, + cl::desc("Use the preinliner decisions stored in profile context.")); static cl::opt<bool> AllowRecursiveInline( - "sample-profile-recursive-inline", cl::Hidden, cl::ZeroOrMore, - cl::init(false), + "sample-profile-recursive-inline", cl::Hidden, + cl::desc("Allow sample loader inliner to inline recursive calls.")); static cl::opt<std::string> ProfileInlineReplayFile( @@ -287,7 +284,6 @@ static cl::opt<CallSiteFormat::Format> ProfileInlineReplayFormat( static cl::opt<unsigned> MaxNumPromotions("sample-profile-icp-max-prom", cl::init(3), cl::Hidden, - cl::ZeroOrMore, cl::desc("Max number of promotions for a single indirect " "call callsite in sample profile loader")); @@ -295,6 +291,13 @@ static cl::opt<bool> OverwriteExistingWeights( "overwrite-existing-weights", cl::Hidden, cl::init(false), cl::desc("Ignore existing branch weights on IR and always overwrite.")); +static cl::opt<bool> AnnotateSampleProfileInlinePhase( + "annotate-sample-profile-inline-phase", cl::Hidden, cl::init(false), + cl::desc("Annotate LTO phase (prelink / postlink), or main (no LTO) for " + "sample-profile inline pass name.")); + +extern cl::opt<bool> EnableExtTspBlockPlacement; + namespace { using BlockWeightMap = DenseMap<const BasicBlock *, uint64_t>; @@ -425,7 +428,11 @@ public: : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName)), GetAC(std::move(GetAssumptionCache)), GetTTI(std::move(GetTargetTransformInfo)), GetTLI(std::move(GetTLI)), - LTOPhase(LTOPhase) {} + LTOPhase(LTOPhase), + AnnotatedPassName(AnnotateSampleProfileInlinePhase + ? llvm::AnnotateInlinePassName(InlineContext{ + LTOPhase, InlinePass::SampleProfileInliner}) + : CSINLINE_DEBUG) {} bool doInitialization(Module &M, FunctionAnalysisManager *FAM = nullptr); bool runOnModule(Module &M, ModuleAnalysisManager *AM, @@ -487,15 +494,13 @@ protected: /// Profile tracker for different context. std::unique_ptr<SampleContextTracker> ContextTracker; - /// Flag indicating whether input profile is context-sensitive - bool ProfileIsCSFlat = false; - /// Flag indicating which LTO/ThinLTO phase the pass is invoked in. /// /// We need to know the LTO phase because for example in ThinLTOPrelink /// phase, in annotation, we should not promote indirect calls. Instead, /// we will mark GUIDs that needs to be annotated to the function. - ThinOrFullLTOPhase LTOPhase; + const ThinOrFullLTOPhase LTOPhase; + const std::string AnnotatedPassName; /// Profle Symbol list tells whether a function name appears in the binary /// used to generate the current profile. @@ -535,6 +540,11 @@ protected: // A pseudo probe helper to correlate the imported sample counts. std::unique_ptr<PseudoProbeManager> ProbeManager; + +private: + const char *getAnnotatedRemarkPassName() const { + return AnnotatedPassName.c_str(); + } }; class SampleProfileLoaderLegacyPass : public ModulePass { @@ -605,7 +615,7 @@ ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) { // call instruction should have 0 count. // For CS profile, the callsite count of previously inlined callees is // populated with the entry count of the callees. 
- if (!ProfileIsCSFlat) + if (!FunctionSamples::ProfileIsCS) if (const auto *CB = dyn_cast<CallBase>(&Inst)) if (!CB->isIndirectCall() && findCalleeFunctionSamples(*CB)) return 0; @@ -644,7 +654,7 @@ ErrorOr<uint64_t> SampleProfileLoader::getProbeWeight(const Instruction &Inst) { // call instruction should have 0 count. // For CS profile, the callsite count of previously inlined callees is // populated with the entry count of the callees. - if (!ProfileIsCSFlat) + if (!FunctionSamples::ProfileIsCS) if (const auto *CB = dyn_cast<CallBase>(&Inst)) if (!CB->isIndirectCall() && findCalleeFunctionSamples(*CB)) return 0; @@ -698,7 +708,7 @@ SampleProfileLoader::findCalleeFunctionSamples(const CallBase &Inst) const { if (Function *Callee = Inst.getCalledFunction()) CalleeName = Callee->getName(); - if (ProfileIsCSFlat) + if (FunctionSamples::ProfileIsCS) return ContextTracker->getCalleeContextSamplesFor(Inst, CalleeName); const FunctionSamples *FS = findFunctionSamples(Inst); @@ -730,7 +740,7 @@ SampleProfileLoader::findIndirectCallFunctionSamples( FunctionSamples::getGUID(R->getName()); }; - if (ProfileIsCSFlat) { + if (FunctionSamples::ProfileIsCS) { auto CalleeSamples = ContextTracker->getIndirectCalleeContextSamplesFor(DIL); if (CalleeSamples.empty()) @@ -783,7 +793,7 @@ SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { auto it = DILocation2SampleMap.try_emplace(DIL,nullptr); if (it.second) { - if (ProfileIsCSFlat) + if (FunctionSamples::ProfileIsCS) it.first->second = ContextTracker->getContextSamplesFor(DIL); else it.first->second = @@ -839,6 +849,13 @@ static void updateIDTMetaData(Instruction &Inst, const SmallVectorImpl<InstrProfValueData> &CallTargets, uint64_t Sum) { + // Bail out early if MaxNumPromotions is zero. + // This prevents allocating an array of zero length below. + // + // Note `updateIDTMetaData` is called in two places so check + // `MaxNumPromotions` inside it. + if (MaxNumPromotions == 0) + return; uint32_t NumVals = 0; // OldSum is the existing total count in the value profile data. uint64_t OldSum = 0; @@ -922,6 +939,14 @@ updateIDTMetaData(Instruction &Inst, bool SampleProfileLoader::tryPromoteAndInlineCandidate( Function &F, InlineCandidate &Candidate, uint64_t SumOrigin, uint64_t &Sum, SmallVector<CallBase *, 8> *InlinedCallSite) { + // Bail out early if sample-loader inliner is disabled. + if (DisableSampleLoaderInlining) + return false; + + // Bail out early if MaxNumPromotions is zero. + // This prevents allocating an array of zero length in callees below. + if (MaxNumPromotions == 0) + return false; auto CalleeFunctionName = Candidate.CalleeSamples->getFuncName(); auto R = SymbolMap.find(CalleeFunctionName); if (R == SymbolMap.end() || !R->getValue()) @@ -1009,8 +1034,9 @@ void SampleProfileLoader::emitOptimizationRemarksForInlineCandidates( for (auto I : Candidates) { Function *CalledFunction = I->getCalledFunction(); if (CalledFunction) { - ORE->emit(OptimizationRemarkAnalysis(CSINLINE_DEBUG, "InlineAttempt", - I->getDebugLoc(), I->getParent()) + ORE->emit(OptimizationRemarkAnalysis(getAnnotatedRemarkPassName(), + "InlineAttempt", I->getDebugLoc(), + I->getParent()) << "previous inlining reattempted for " << (Hot ? "hotness: '" : "size: '") << ore::NV("Callee", CalledFunction) << "' into '" @@ -1042,13 +1068,12 @@ void SampleProfileLoader::findExternalInlineCandidate( // For AutoFDO profile, retrieve candidate profiles by walking over // the nested inlinee profiles. 
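Both updateIDTMetaData and tryPromoteAndInlineCandidate now bail out when the promotion budget is zero, which also avoids sizing a zero-length scratch array. A minimal standalone sketch of that guard; the container and names here are invented, not the loader's value-profile code.

#include <algorithm>
#include <cstdint>
#include <vector>

// With a budget of zero there is nothing to record, so return before
// allocating or truncating any buffers.
void recordPromotedTargets(std::vector<uint64_t> &Records,
                           const std::vector<uint64_t> &Candidates,
                           unsigned MaxNumPromotions) {
  if (MaxNumPromotions == 0)
    return;
  size_t N = std::min<size_t>(Candidates.size(), MaxNumPromotions);
  Records.assign(Candidates.begin(), Candidates.begin() + N);
}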
- if (!ProfileIsCSFlat) { + if (!FunctionSamples::ProfileIsCS) { Samples->findInlinedFunctions(InlinedGUIDs, SymbolMap, Threshold); return; } - ContextTrieNode *Caller = - ContextTracker->getContextFor(Samples->getContext()); + ContextTrieNode *Caller = ContextTracker->getContextNodeForProfile(Samples); std::queue<ContextTrieNode *> CalleeList; CalleeList.push(Caller); while (!CalleeList.empty()) { @@ -1098,11 +1123,20 @@ void SampleProfileLoader::findExternalInlineCandidate( /// Iteratively inline hot callsites of a function. /// -/// Iteratively traverse all callsites of the function \p F, and find if -/// the corresponding inlined instance exists and is hot in profile. If -/// it is hot enough, inline the callsites and adds new callsites of the -/// callee into the caller. If the call is an indirect call, first promote -/// it to direct call. Each indirect call is limited with a single target. +/// Iteratively traverse all callsites of the function \p F, so as to +/// find out callsites with corresponding inline instances. +/// +/// For such callsites, +/// - If it is hot enough, inline the callsites and adds callsites of the callee +/// into the caller. If the call is an indirect call, first promote +/// it to direct call. Each indirect call is limited with a single target. +/// +/// - If a callsite is not inlined, merge the its profile to the outline +/// version (if --sample-profile-merge-inlinee is true), or scale the +/// counters of standalone function based on the profile of inlined +/// instances (if --sample-profile-merge-inlinee is false). +/// +/// Later passes may consume the updated profiles. /// /// \param F function to perform iterative inlining. /// \param InlinedGUIDs a set to be updated to include all GUIDs that are @@ -1137,7 +1171,7 @@ bool SampleProfileLoader::inlineHotFunctions( assert((!FunctionSamples::UseMD5 || FS->GUIDToFuncNameMap) && "GUIDToFuncNameMap has to be populated"); AllCandidates.push_back(CB); - if (FS->getEntrySamples() > 0 || ProfileIsCSFlat) + if (FS->getEntrySamples() > 0 || FunctionSamples::ProfileIsCS) LocalNotInlinedCallSites.try_emplace(CB, FS); if (callsiteIsHot(FS, PSI, ProfAccForSymsInList)) Hot = true; @@ -1200,13 +1234,17 @@ bool SampleProfileLoader::inlineHotFunctions( // For CS profile, profile for not inlined context will be merged when // base profile is being retrieved. - if (!FunctionSamples::ProfileIsCSFlat) + if (!FunctionSamples::ProfileIsCS) promoteMergeNotInlinedContextSamples(LocalNotInlinedCallSites, F); return Changed; } bool SampleProfileLoader::tryInlineCandidate( InlineCandidate &Candidate, SmallVector<CallBase *, 8> *InlinedCallSites) { + // Do not attempt to inline a candidate if + // --disable-sample-loader-inlining is true. + if (DisableSampleLoaderInlining) + return false; CallBase &CB = *Candidate.CallInstr; Function *CalledFunction = CB.getCalledFunction(); @@ -1216,7 +1254,8 @@ bool SampleProfileLoader::tryInlineCandidate( InlineCost Cost = shouldInlineCandidate(Candidate); if (Cost.isNever()) { - ORE->emit(OptimizationRemarkAnalysis(CSINLINE_DEBUG, "InlineFail", DLoc, BB) + ORE->emit(OptimizationRemarkAnalysis(getAnnotatedRemarkPassName(), + "InlineFail", DLoc, BB) << "incompatible inlining"); return false; } @@ -1226,45 +1265,45 @@ bool SampleProfileLoader::tryInlineCandidate( InlineFunctionInfo IFI(nullptr, GetAC); IFI.UpdateProfile = false; - if (InlineFunction(CB, IFI).isSuccess()) { - // Merge the attributes based on the inlining. 
- AttributeFuncs::mergeAttributesForInlining(*BB->getParent(), - *CalledFunction); - - // The call to InlineFunction erases I, so we can't pass it here. - emitInlinedIntoBasedOnCost(*ORE, DLoc, BB, *CalledFunction, - *BB->getParent(), Cost, true, CSINLINE_DEBUG); - - // Now populate the list of newly exposed call sites. - if (InlinedCallSites) { - InlinedCallSites->clear(); - for (auto &I : IFI.InlinedCallSites) - InlinedCallSites->push_back(I); - } + if (!InlineFunction(CB, IFI).isSuccess()) + return false; - if (ProfileIsCSFlat) - ContextTracker->markContextSamplesInlined(Candidate.CalleeSamples); - ++NumCSInlined; - - // Prorate inlined probes for a duplicated inlining callsite which probably - // has a distribution less than 100%. Samples for an inlinee should be - // distributed among the copies of the original callsite based on each - // callsite's distribution factor for counts accuracy. Note that an inlined - // probe may come with its own distribution factor if it has been duplicated - // in the inlinee body. The two factor are multiplied to reflect the - // aggregation of duplication. - if (Candidate.CallsiteDistribution < 1) { - for (auto &I : IFI.InlinedCallSites) { - if (Optional<PseudoProbe> Probe = extractProbe(*I)) - setProbeDistributionFactor(*I, Probe->Factor * - Candidate.CallsiteDistribution); - } - NumDuplicatedInlinesite++; - } + // Merge the attributes based on the inlining. + AttributeFuncs::mergeAttributesForInlining(*BB->getParent(), + *CalledFunction); - return true; + // The call to InlineFunction erases I, so we can't pass it here. + emitInlinedIntoBasedOnCost(*ORE, DLoc, BB, *CalledFunction, *BB->getParent(), + Cost, true, getAnnotatedRemarkPassName()); + + // Now populate the list of newly exposed call sites. + if (InlinedCallSites) { + InlinedCallSites->clear(); + for (auto &I : IFI.InlinedCallSites) + InlinedCallSites->push_back(I); } - return false; + + if (FunctionSamples::ProfileIsCS) + ContextTracker->markContextSamplesInlined(Candidate.CalleeSamples); + ++NumCSInlined; + + // Prorate inlined probes for a duplicated inlining callsite which probably + // has a distribution less than 100%. Samples for an inlinee should be + // distributed among the copies of the original callsite based on each + // callsite's distribution factor for counts accuracy. Note that an inlined + // probe may come with its own distribution factor if it has been duplicated + // in the inlinee body. The two factor are multiplied to reflect the + // aggregation of duplication. + if (Candidate.CallsiteDistribution < 1) { + for (auto &I : IFI.InlinedCallSites) { + if (Optional<PseudoProbe> Probe = extractProbe(*I)) + setProbeDistributionFactor(*I, Probe->Factor * + Candidate.CallsiteDistribution); + } + NumDuplicatedInlinesite++; + } + + return true; } bool SampleProfileLoader::getInlineCandidate(InlineCandidate *NewCandidate, @@ -1285,14 +1324,8 @@ bool SampleProfileLoader::getInlineCandidate(InlineCandidate *NewCandidate, if (Optional<PseudoProbe> Probe = extractProbe(*CB)) Factor = Probe->Factor; - uint64_t CallsiteCount = 0; - ErrorOr<uint64_t> Weight = getBlockWeight(CB->getParent()); - if (Weight) - CallsiteCount = Weight.get(); - if (CalleeSamples) - CallsiteCount = std::max( - CallsiteCount, uint64_t(CalleeSamples->getEntrySamples() * Factor)); - + uint64_t CallsiteCount = + CalleeSamples ? 
CalleeSamples->getEntrySamples() * Factor : 0; *NewCandidate = {CB, CalleeSamples, CallsiteCount, Factor}; return true; } @@ -1387,7 +1420,6 @@ SampleProfileLoader::shouldInlineCandidate(InlineCandidate &Candidate) { bool SampleProfileLoader::inlineHotFunctionsWithPriority( Function &F, DenseSet<GlobalValue::GUID> &InlinedGUIDs) { - // ProfAccForSymsInList is used in callsiteIsHot. The assertion makes sure // Profile symbol list is ignored when profile-sample-accurate is on. assert((!ProfAccForSymsInList || @@ -1513,7 +1545,7 @@ bool SampleProfileLoader::inlineHotFunctionsWithPriority( // For CS profile, profile for not inlined context will be merged when // base profile is being retrieved. - if (!FunctionSamples::ProfileIsCSFlat) + if (!FunctionSamples::ProfileIsCS) promoteMergeNotInlinedContextSamples(LocalNotInlinedCallSites, F); return Changed; } @@ -1528,11 +1560,11 @@ void SampleProfileLoader::promoteMergeNotInlinedContextSamples( if (!Callee || Callee->isDeclaration()) continue; - ORE->emit(OptimizationRemarkAnalysis(CSINLINE_DEBUG, "NotInline", - I->getDebugLoc(), I->getParent()) - << "previous inlining not repeated: '" - << ore::NV("Callee", Callee) << "' into '" - << ore::NV("Caller", &F) << "'"); + ORE->emit( + OptimizationRemarkAnalysis(getAnnotatedRemarkPassName(), "NotInline", + I->getDebugLoc(), I->getParent()) + << "previous inlining not repeated: '" << ore::NV("Callee", Callee) + << "' into '" << ore::NV("Caller", &F) << "'"); ++NumCSNotInlined; const FunctionSamples *FS = Pair.getSecond(); @@ -1540,6 +1572,10 @@ void SampleProfileLoader::promoteMergeNotInlinedContextSamples( continue; } + // Do not merge a context that is already duplicated into the base profile. + if (FS->getContext().hasAttribute(sampleprof::ContextDuplicatedIntoBase)) + continue; + if (ProfileMergeInlinee) { // A function call can be replicated by optimizations like callsite // splitting or jump threading and the replicates end up sharing the @@ -1623,7 +1659,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { // With CSSPGO all indirect call targets are counted torwards the // original indirect call site in the profile, including both // inlined and non-inlined targets. - if (!FunctionSamples::ProfileIsCSFlat) { + if (!FunctionSamples::ProfileIsCS) { if (const FunctionSamplesMap *M = FS->findFunctionSamplesMapAt(CallSite)) { for (const auto &NameFS : *M) @@ -1714,6 +1750,11 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { } } + // FIXME: Re-enable for sample profiling after investigating why the sum + // of branch weights can be 0 + // + // misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false); + uint64_t TempWeight; // Only set weights if there is at least one non-zero weight. // In any other case, let the analyzer set weights. 
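With this change the candidate's callsite count comes straight from the callee's entry samples scaled by the probe distribution factor, and the prioritized inliner consumes candidates hottest-first. A rough standalone sketch of that ordering; the types and field names are invented, not the loader's InlineCandidate.

#include <cstdint>
#include <queue>
#include <string>
#include <vector>

// Invented candidate record: hotter callsites (higher sample count) are
// popped and considered for inlining first.
struct Candidate {
  std::string Callee;
  uint64_t CallsiteCount;
};

struct ColderThan {
  bool operator()(const Candidate &A, const Candidate &B) const {
    return A.CallsiteCount < B.CallsiteCount; // max-heap on the count
  }
};

int main() {
  std::priority_queue<Candidate, std::vector<Candidate>, ColderThan> Queue;
  Queue.push({"foo", 500});
  Queue.push({"bar", 1500});
  // "bar" comes out first; a real implementation would also re-push the
  // call sites exposed by each successful inline.
  return Queue.top().Callee == "bar" ? 0 : 1;
}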
@@ -1798,7 +1839,7 @@ INITIALIZE_PASS_END(SampleProfileLoaderLegacyPass, "sample-profile", std::unique_ptr<ProfiledCallGraph> SampleProfileLoader::buildProfiledCallGraph(CallGraph &CG) { std::unique_ptr<ProfiledCallGraph> ProfiledCG; - if (ProfileIsCSFlat) + if (FunctionSamples::ProfileIsCS) ProfiledCG = std::make_unique<ProfiledCallGraph>(*ContextTracker); else ProfiledCG = std::make_unique<ProfiledCallGraph>(Reader->getProfiles()); @@ -1843,8 +1884,8 @@ SampleProfileLoader::buildFunctionOrder(Module &M, CallGraph *CG) { assert(&CG->getModule() == &M); - if (UseProfiledCallGraph || - (ProfileIsCSFlat && !UseProfiledCallGraph.getNumOccurrences())) { + if (UseProfiledCallGraph || (FunctionSamples::ProfileIsCS && + !UseProfiledCallGraph.getNumOccurrences())) { // Use profiled call edges to augment the top-down order. There are cases // that the top-down order computed based on the static call graph doesn't // reflect real execution order. For example @@ -1973,40 +2014,50 @@ bool SampleProfileLoader::doInitialization(Module &M, ProfileInlineReplayScope, ProfileInlineReplayFallback, {ProfileInlineReplayFormat}}, - /*EmitRemarks=*/false); + /*EmitRemarks=*/false, InlineContext{LTOPhase, InlinePass::ReplaySampleProfileInliner}); } - // Apply tweaks if context-sensitive profile is available. - if (Reader->profileIsCSFlat() || Reader->profileIsCSNested()) { - ProfileIsCSFlat = Reader->profileIsCSFlat(); + // Apply tweaks if context-sensitive or probe-based profile is available. + if (Reader->profileIsCS() || Reader->profileIsPreInlined() || + Reader->profileIsProbeBased()) { + if (!UseIterativeBFIInference.getNumOccurrences()) + UseIterativeBFIInference = true; + if (!SampleProfileUseProfi.getNumOccurrences()) + SampleProfileUseProfi = true; + if (!EnableExtTspBlockPlacement.getNumOccurrences()) + EnableExtTspBlockPlacement = true; // Enable priority-base inliner and size inline by default for CSSPGO. if (!ProfileSizeInline.getNumOccurrences()) ProfileSizeInline = true; if (!CallsitePrioritizedInline.getNumOccurrences()) CallsitePrioritizedInline = true; - - // For CSSPGO, use preinliner decision by default when available. - if (!UsePreInlinerDecision.getNumOccurrences()) - UsePreInlinerDecision = true; - // For CSSPGO, we also allow recursive inline to best use context profile. if (!AllowRecursiveInline.getNumOccurrences()) AllowRecursiveInline = true; - // Enable iterative-BFI by default for CSSPGO. - if (!UseIterativeBFIInference.getNumOccurrences()) - UseIterativeBFIInference = true; - // Enable Profi by default for CSSPGO. - if (!SampleProfileUseProfi.getNumOccurrences()) - SampleProfileUseProfi = true; + if (Reader->profileIsPreInlined()) { + if (!UsePreInlinerDecision.getNumOccurrences()) + UsePreInlinerDecision = true; + } - if (FunctionSamples::ProfileIsCSFlat) { - // Tracker for profiles under different context - ContextTracker = std::make_unique<SampleContextTracker>( - Reader->getProfiles(), &GUIDToFuncNameMap); + if (!Reader->profileIsCS()) { + // Non-CS profile should be fine without a function size budget for the + // inliner since the contexts in the profile are either all from inlining + // in the prevoius build or pre-computed by the preinliner with a size + // cap, thus they are bounded. 
+ if (!ProfileInlineLimitMin.getNumOccurrences()) + ProfileInlineLimitMin = std::numeric_limits<unsigned>::max(); + if (!ProfileInlineLimitMax.getNumOccurrences()) + ProfileInlineLimitMax = std::numeric_limits<unsigned>::max(); } } + if (Reader->profileIsCS()) { + // Tracker for profiles under different context + ContextTracker = std::make_unique<SampleContextTracker>( + Reader->getProfiles(), &GUIDToFuncNameMap); + } + // Load pseudo probe descriptors for probe-based function samples. if (Reader->profileIsProbeBased()) { ProbeManager = std::make_unique<PseudoProbeManager>(M); @@ -2082,7 +2133,7 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, } // Account for cold calls not inlined.... - if (!ProfileIsCSFlat) + if (!FunctionSamples::ProfileIsCS) for (const std::pair<Function *, NotInlinedProfileInfo> &pair : notInlinedCallInfo) updateProfileCallee(pair.first, pair.second.entryCount); @@ -2145,7 +2196,7 @@ bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) // Initialize entry count when the function has no existing entry // count value. - if (!F.getEntryCount().hasValue()) + if (!F.getEntryCount()) F.setEntryCount(ProfileCount(initialEntryCount, Function::PCT_Real)); std::unique_ptr<OptimizationRemarkEmitter> OwnedORE; if (AM) { @@ -2158,7 +2209,7 @@ bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) ORE = OwnedORE.get(); } - if (ProfileIsCSFlat) + if (FunctionSamples::ProfileIsCS) Samples = ContextTracker->getBaseSamplesFor(F); else Samples = Reader->getSamplesFor(F); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp index e104ae00e916..d1ab2649ee2e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp @@ -13,21 +13,19 @@ #include "llvm/Transforms/IPO/SampleProfileProbe.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/BlockFrequencyInfo.h" -#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CFG.h" -#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfoMetadata.h" -#include "llvm/IR/GlobalValue.h" -#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/PseudoProbe.h" #include "llvm/ProfileData/SampleProf.h" #include "llvm/Support/CRC.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include <unordered_set> @@ -416,7 +414,7 @@ void PseudoProbeUpdatePass::runOnFunction(Function &F, FunctionAnalysisManager &FAM) { BlockFrequencyInfo &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); auto BBProfileCount = [&BFI](BasicBlock *BB) { - return BFI.getBlockProfileCount(BB).getValueOr(0); + return BFI.getBlockProfileCount(BB).value_or(0); }; // Collect the sum of execution weight for each probe. 
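The doInitialization hunks above only flip these defaults (iterative BFI inference, profi, ext-tsp layout, the inline size limits) when the user has not passed the flag explicitly, which is what the getNumOccurrences() checks test. Outside the cl::opt machinery the pattern looks roughly like this; the Knob struct is invented for illustration.

#include <limits>

// Invented stand-in for a command-line knob: a value plus the number of
// times the user set it explicitly (what getNumOccurrences() reports).
struct Knob {
  unsigned Value = 0;
  unsigned TimesSetByUser = 0;
};

// Apply a profile-driven default only when the user did not override it.
void applyDefaultIfUnset(Knob &K, unsigned Default) {
  if (K.TimesSetByUser == 0)
    K.Value = Default;
}

// E.g. lifting the inliner size budget for bounded (pre-inlined) profiles.
void liftInlineLimits(Knob &LimitMin, Knob &LimitMax) {
  applyDefaultIfUnset(LimitMin, std::numeric_limits<unsigned>::max());
  applyDefaultIfUnset(LimitMax, std::numeric_limits<unsigned>::max());
}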
diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp index 95393d9476e0..c7d54b8cdeb0 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SyntheticCountsPropagation.cpp @@ -25,18 +25,13 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/SyntheticCountsPropagation.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CallGraph.h" -#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/SyntheticCountsUtils.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" using namespace llvm; using Scaled64 = ScaledNumber<uint64_t>; @@ -47,18 +42,17 @@ using ProfileCount = Function::ProfileCount; namespace llvm { cl::opt<int> InitialSyntheticCount("initial-synthetic-count", cl::Hidden, cl::init(10), - cl::ZeroOrMore, cl::desc("Initial value of synthetic entry count")); } // namespace llvm /// Initial synthetic count assigned to inline functions. static cl::opt<int> InlineSyntheticCount( - "inline-synthetic-count", cl::Hidden, cl::init(15), cl::ZeroOrMore, + "inline-synthetic-count", cl::Hidden, cl::init(15), cl::desc("Initial synthetic entry count for inline functions.")); /// Initial synthetic count assigned to cold functions. static cl::opt<int> ColdSyntheticCount( - "cold-synthetic-count", cl::Hidden, cl::init(5), cl::ZeroOrMore, + "cold-synthetic-count", cl::Hidden, cl::init(5), cl::desc("Initial synthetic entry count for cold functions.")); // Assign initial synthetic entry counts to functions. diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index 52708ff2f226..a360a768a2bc 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -21,7 +21,6 @@ #include "llvm/InitializePasses.h" #include "llvm/Object/ModuleSymbolTable.h" #include "llvm/Pass.h" -#include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" @@ -311,7 +310,8 @@ void splitAndWriteThinLTOBitcode( return; } if (!F->isDeclaration() && - computeFunctionBodyMemoryAccess(*F, AARGetter(*F)) == MAK_ReadNone) + computeFunctionBodyMemoryAccess(*F, AARGetter(*F)) == + FMRB_DoesNotAccessMemory) EligibleVirtualFns.insert(F); }); } @@ -542,11 +542,11 @@ class WriteThinLTOBitcode : public ModulePass { raw_ostream &OS; // raw_ostream to print on // The output stream on which to emit a minimized module for use // just in the thin link, if requested. 
- raw_ostream *ThinLinkOS; + raw_ostream *ThinLinkOS = nullptr; public: static char ID; // Pass identification, replacement for typeid - WriteThinLTOBitcode() : ModulePass(ID), OS(dbgs()), ThinLinkOS(nullptr) { + WriteThinLTOBitcode() : ModulePass(ID), OS(dbgs()) { initializeWriteThinLTOBitcodePass(*PassRegistry::getPassRegistry()); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index 8b30f0e989a1..898a213d0849 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -57,6 +57,7 @@ #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" #include "llvm/ADT/Triple.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AssumptionCache.h" @@ -79,6 +80,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/ModuleSummaryIndexYAML.h" @@ -95,6 +97,7 @@ #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/Transforms/Utils/Evaluator.h" #include <algorithm> #include <cstddef> @@ -107,6 +110,15 @@ using namespace wholeprogramdevirt; #define DEBUG_TYPE "wholeprogramdevirt" +STATISTIC(NumDevirtTargets, "Number of whole program devirtualization targets"); +STATISTIC(NumSingleImpl, "Number of single implementation devirtualizations"); +STATISTIC(NumBranchFunnel, "Number of branch funnels"); +STATISTIC(NumUniformRetVal, "Number of uniform return value optimizations"); +STATISTIC(NumUniqueRetVal, "Number of unique return value optimizations"); +STATISTIC(NumVirtConstProp1Bit, + "Number of 1 bit virtual constant propagations"); +STATISTIC(NumVirtConstProp, "Number of virtual constant propagations"); + static cl::opt<PassSummaryAction> ClSummaryAction( "wholeprogramdevirt-summary-action", cl::desc("What to do with the summary when running this pass"), @@ -132,13 +144,12 @@ static cl::opt<std::string> ClWriteSummary( static cl::opt<unsigned> ClThreshold("wholeprogramdevirt-branch-funnel-threshold", cl::Hidden, - cl::init(10), cl::ZeroOrMore, + cl::init(10), cl::desc("Maximum number of call targets per " "call site to enable branch funnels")); static cl::opt<bool> PrintSummaryDevirt("wholeprogramdevirt-print-index-based", cl::Hidden, - cl::init(false), cl::ZeroOrMore, cl::desc("Print index-based devirtualization messages")); /// Provide a way to force enable whole program visibility in tests. @@ -146,30 +157,34 @@ static cl::opt<bool> /// !vcall_visibility metadata (the mere presense of type tests /// previously implied hidden visibility). static cl::opt<bool> - WholeProgramVisibility("whole-program-visibility", cl::init(false), - cl::Hidden, cl::ZeroOrMore, + WholeProgramVisibility("whole-program-visibility", cl::Hidden, cl::desc("Enable whole program visibility")); /// Provide a way to force disable whole program for debugging or workarounds, /// when enabled via the linker. 
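The WriteThinLTOBitcode tweak at the top of this hunk replaces a constructor-initializer entry with an in-class default for ThinLinkOS, matching similar member cleanups elsewhere in the commit. A tiny standalone version of the idiom, with std::ostream standing in for raw_ostream:

#include <ostream>

// The in-class default applies to every constructor that does not mention
// the member, so the explicit ThinLinkOS(nullptr) entry becomes redundant.
class WriterLike {
  std::ostream *ThinLinkOS = nullptr;

public:
  WriterLike() = default;
  explicit WriterLike(std::ostream &ThinLink) : ThinLinkOS(&ThinLink) {}
};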
static cl::opt<bool> DisableWholeProgramVisibility( - "disable-whole-program-visibility", cl::init(false), cl::Hidden, - cl::ZeroOrMore, + "disable-whole-program-visibility", cl::Hidden, cl::desc("Disable whole program visibility (overrides enabling options)")); /// Provide way to prevent certain function from being devirtualized static cl::list<std::string> SkipFunctionNames("wholeprogramdevirt-skip", cl::desc("Prevent function(s) from being devirtualized"), - cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated); - -/// Mechanism to add runtime checking of devirtualization decisions, trapping on -/// any that are not correct. Useful for debugging undefined behavior leading to -/// failures with WPD. -static cl::opt<bool> - CheckDevirt("wholeprogramdevirt-check", cl::init(false), cl::Hidden, - cl::ZeroOrMore, - cl::desc("Add code to trap on incorrect devirtualizations")); + cl::Hidden, cl::CommaSeparated); + +/// Mechanism to add runtime checking of devirtualization decisions, optionally +/// trapping or falling back to indirect call on any that are not correct. +/// Trapping mode is useful for debugging undefined behavior leading to failures +/// with WPD. Fallback mode is useful for ensuring safety when whole program +/// visibility may be compromised. +enum WPDCheckMode { None, Trap, Fallback }; +static cl::opt<WPDCheckMode> DevirtCheckMode( + "wholeprogramdevirt-check", cl::Hidden, + cl::desc("Type of checking for incorrect devirtualizations"), + cl::values(clEnumValN(WPDCheckMode::None, "none", "No checking"), + clEnumValN(WPDCheckMode::Trap, "trap", "Trap when incorrect"), + clEnumValN(WPDCheckMode::Fallback, "fallback", + "Fallback to indirect when incorrect"))); namespace { struct PatternList { @@ -866,13 +881,14 @@ void updateVCallVisibilityInIndex( if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) return; for (auto &P : Index) { + // Don't upgrade the visibility for symbols exported to the dynamic + // linker, as we have no information on their eventual use. + if (DynamicExportSymbols.count(P.first)) + continue; for (auto &S : P.second.SummaryList) { auto *GVar = dyn_cast<GlobalVarSummary>(S.get()); if (!GVar || - GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic || - // Don't upgrade the visibility for symbols exported to the dynamic - // linker, as we have no information on their eventual use. - DynamicExportSymbols.count(P.first)) + GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic) continue; GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit); } @@ -1133,16 +1149,17 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, if (RemarksEnabled) VCallSite.emitRemark("single-impl", TheFn->stripPointerCasts()->getName(), OREGetter); + NumSingleImpl++; auto &CB = VCallSite.CB; assert(!CB.getCalledFunction() && "devirtualizing direct call?"); IRBuilder<> Builder(&CB); Value *Callee = Builder.CreateBitCast(TheFn, CB.getCalledOperand()->getType()); - // If checking is enabled, add support to compare the virtual function - // pointer to the devirtualized target. In case of a mismatch, perform a - // debug trap. - if (CheckDevirt) { + // If trap checking is enabled, add support to compare the virtual + // function pointer to the devirtualized target. In case of a mismatch, + // perform a debug trap. 
+ if (DevirtCheckMode == WPDCheckMode::Trap) { auto *Cond = Builder.CreateICmpNE(CB.getCalledOperand(), Callee); Instruction *ThenTerm = SplitBlockAndInsertIfThen(Cond, &CB, /*Unreachable=*/false); @@ -1152,8 +1169,38 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, CallTrap->setDebugLoc(CB.getDebugLoc()); } - // Devirtualize. - CB.setCalledOperand(Callee); + // If fallback checking is enabled, add support to compare the virtual + // function pointer to the devirtualized target. In case of a mismatch, + // fall back to indirect call. + if (DevirtCheckMode == WPDCheckMode::Fallback) { + MDNode *Weights = + MDBuilder(M.getContext()).createBranchWeights((1U << 20) - 1, 1); + // Version the indirect call site. If the called value is equal to the + // given callee, 'NewInst' will be executed, otherwise the original call + // site will be executed. + CallBase &NewInst = versionCallSite(CB, Callee, Weights); + NewInst.setCalledOperand(Callee); + // Since the new call site is direct, we must clear metadata that + // is only appropriate for indirect calls. This includes !prof and + // !callees metadata. + NewInst.setMetadata(LLVMContext::MD_prof, nullptr); + NewInst.setMetadata(LLVMContext::MD_callees, nullptr); + // Additionally, we should remove them from the fallback indirect call, + // so that we don't attempt to perform indirect call promotion later. + CB.setMetadata(LLVMContext::MD_prof, nullptr); + CB.setMetadata(LLVMContext::MD_callees, nullptr); + } + + // In either trapping or non-checking mode, devirtualize original call. + else { + // Devirtualize unconditionally. + CB.setCalledOperand(Callee); + // Since the call site is now direct, we must clear metadata that + // is only appropriate for indirect calls. This includes !prof and + // !callees metadata. + CB.setMetadata(LLVMContext::MD_prof, nullptr); + CB.setMetadata(LLVMContext::MD_callees, nullptr); + } // This use is no longer unsafe. if (VCallSite.NumUnsafeUses) @@ -1208,7 +1255,7 @@ bool DevirtModule::trySingleImplDevirt( return false; // If so, update each call site to call that implementation directly. - if (RemarksEnabled) + if (RemarksEnabled || AreStatisticsEnabled()) TargetsForSlot[0].WasDevirt = true; bool IsExported = false; @@ -1279,7 +1326,7 @@ bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot, return false; // Collect functions devirtualized at least for one call site for stats. 
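The new -wholeprogramdevirt-check modes replace the old boolean: "trap" keeps the debug-trap comparison, while "fallback" versions the call site so a mismatched virtual pointer still goes through the original indirect call. A standalone sketch of the control flow being produced, using plain function pointers and std::abort in place of the trap intrinsic; this is not the IR-rewriting code itself.

#include <cstdlib>

using VirtualFn = int (*)(int);

int devirtTarget(int X) { return X + 1; } // the single known implementation

enum class CheckMode { None, Trap, Fallback };

int callSite(VirtualFn VFn, int Arg, CheckMode Mode) {
  if (Mode == CheckMode::None)
    return devirtTarget(Arg); // unconditional direct (devirtualized) call
  if (VFn == &devirtTarget)
    return devirtTarget(Arg); // pointer matches: take the direct call
  if (Mode == CheckMode::Trap)
    std::abort();             // mismatch under "trap": fail loudly
  return VFn(Arg);            // mismatch under "fallback": stay indirect
}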
- if (PrintSummaryDevirt) + if (PrintSummaryDevirt || AreStatisticsEnabled()) DevirtTargets.insert(TheFn); auto &S = TheFn.getSummaryList()[0]; @@ -1385,6 +1432,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, !FSAttr.getValueAsString().contains("+retpoline")) continue; + NumBranchFunnel++; if (RemarksEnabled) VCallSite.emitRemark("branch-funnel", JT->stripPointerCasts()->getName(), OREGetter); @@ -1476,6 +1524,7 @@ void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, for (auto Call : CSInfo.CallSites) { if (!OptimizedCalls.insert(&Call.CB).second) continue; + NumUniformRetVal++; Call.replaceAndErase( "uniform-ret-val", FnName, RemarksEnabled, OREGetter, ConstantInt::get(cast<IntegerType>(Call.CB.getType()), TheRetVal)); @@ -1499,7 +1548,7 @@ bool DevirtModule::tryUniformRetValOpt( } applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal); - if (RemarksEnabled) + if (RemarksEnabled || AreStatisticsEnabled()) for (auto &&Target : TargetsForSlot) Target.WasDevirt = true; return true; @@ -1592,6 +1641,7 @@ void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable, B.CreateBitCast(UniqueMemberAddr, Call.VTable->getType())); Cmp = B.CreateZExt(Cmp, Call.CB.getType()); + NumUniqueRetVal++; Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter, Cmp); } @@ -1636,7 +1686,7 @@ bool DevirtModule::tryUniqueRetValOpt( UniqueMemberAddr); // Update devirtualization statistics for targets. - if (RemarksEnabled) + if (RemarksEnabled || AreStatisticsEnabled()) for (auto &&Target : TargetsForSlot) Target.WasDevirt = true; @@ -1665,11 +1715,13 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, Value *Bits = B.CreateLoad(Int8Ty, Addr); Value *BitsAndBit = B.CreateAnd(Bits, Bit); auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); + NumVirtConstProp1Bit++; Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled, OREGetter, IsBitSet); } else { Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); Value *Val = B.CreateLoad(RetType, ValAddr); + NumVirtConstProp++; Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, OREGetter, Val); } @@ -1701,7 +1753,7 @@ bool DevirtModule::tryVirtualConstProp( for (VirtualCallTarget &Target : TargetsForSlot) { if (Target.Fn->isDeclaration() || computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) != - MAK_ReadNone || + FMRB_DoesNotAccessMemory || Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() || Target.Fn->getReturnType() != RetType) return false; @@ -1755,7 +1807,7 @@ bool DevirtModule::tryVirtualConstProp( setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte, OffsetBit); - if (RemarksEnabled) + if (RemarksEnabled || AreStatisticsEnabled()) for (auto &&Target : TargetsForSlot) Target.WasDevirt = true; @@ -1963,7 +2015,7 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { // (although this is unlikely). In that case, explicitly build a pair and // RAUW it. 
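tryUniformRetValOpt, whose call sites pick up a NumUniformRetVal counter above, only fires when every possible target is known to return the same constant, in which case the virtual call folds to that constant. The check reduces to something like the following standalone sketch, with plain integers standing in for the summarized return values.

#include <cstdint>
#include <optional>
#include <vector>

// Returns the common constant if all targets agree, otherwise nothing.
std::optional<uint64_t>
uniformReturnValue(const std::vector<uint64_t> &RetValPerTarget) {
  if (RetValPerTarget.empty())
    return std::nullopt;
  for (uint64_t V : RetValPerTarget)
    if (V != RetValPerTarget.front())
      return std::nullopt;
  return RetValPerTarget.front();
}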
if (!CI->use_empty()) { - Value *Pair = UndefValue::get(CI->getType()); + Value *Pair = PoisonValue::get(CI->getType()); IRBuilder<> B(CI); Pair = B.CreateInsertValue(Pair, LoadedValue, {0}); Pair = B.CreateInsertValue(Pair, TypeTestCall, {1}); @@ -2151,9 +2203,9 @@ bool DevirtModule::run() { removeRedundantTypeTests(); - // We have lowered or deleted the type instrinsics, so we will no - // longer have enough information to reason about the liveness of virtual - // function pointers in GlobalDCE. + // We have lowered or deleted the type intrinsics, so we will no longer have + // enough information to reason about the liveness of virtual function + // pointers in GlobalDCE. for (GlobalVariable &GV : M.globals()) GV.eraseMetadata(LLVMContext::MD_vcall_visibility); @@ -2243,7 +2295,7 @@ bool DevirtModule::run() { } // Collect functions devirtualized at least for one call site for stats. - if (RemarksEnabled) + if (RemarksEnabled || AreStatisticsEnabled()) for (const auto &T : TargetsForSlot) if (T.WasDevirt) DevirtTargets[std::string(T.Fn->getName())] = T.Fn; @@ -2276,6 +2328,8 @@ bool DevirtModule::run() { } } + NumDevirtTargets += DevirtTargets.size(); + removeRedundantTypeTests(); // Rebuild each global we touched as part of virtual constant propagation to @@ -2284,9 +2338,9 @@ bool DevirtModule::run() { for (VTableBits &B : Bits) rebuildGlobal(B); - // We have lowered or deleted the type instrinsics, so we will no - // longer have enough information to reason about the liveness of virtual - // function pointers in GlobalDCE. + // We have lowered or deleted the type intrinsics, so we will no longer have + // enough information to reason about the liveness of virtual function + // pointers in GlobalDCE. for (GlobalVariable &GV : M.globals()) GV.eraseMetadata(LLVMContext::MD_vcall_visibility); @@ -2367,4 +2421,6 @@ void DevirtIndex::run() { if (PrintSummaryDevirt) for (const auto &DT : DevirtTargets) errs() << "Devirtualized call to " << DT << "\n"; + + NumDevirtTargets += DevirtTargets.size(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 0598f751febe..f4d8b79a5311 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -693,9 +693,6 @@ unsigned FAddCombine::calcInstrNumber(const AddendVect &Opnds) { unsigned OpndNum = Opnds.size(); unsigned InstrNeeded = OpndNum - 1; - // The number of addends in the form of "(-1)*x". - unsigned NegOpndNum = 0; - // Adjust the number of instructions needed to emit the N-ary add. for (const FAddend *Opnd : Opnds) { if (Opnd->isConstant()) @@ -707,9 +704,6 @@ unsigned FAddCombine::calcInstrNumber(const AddendVect &Opnds) { continue; const FAddendCoef &CE = Opnd->getCoef(); - if (CE.isMinusOne() || CE.isMinusTwo()) - NegOpndNum++; - // Let the addend be "c * x". If "c == +/-1", the value of the addend // is immediately available; otherwise, it needs exactly one instruction // to evaluate the value. 
@@ -1277,7 +1271,7 @@ static Instruction *factorizeMathWithShlOps(BinaryOperator &I, } Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { - if (Value *V = SimplifyAddInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyAddInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1375,6 +1369,13 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { } } + // (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit + if (match(&I, m_c_Add(m_And(m_Value(A), m_APInt(C1)), m_Deferred(A))) && + C1->isPowerOf2() && (ComputeNumSignBits(A) > C1->countLeadingZeros())) { + Constant *NewMask = ConstantInt::get(RHS->getType(), *C1 - 1); + return BinaryOperator::CreateAnd(A, NewMask); + } + // A+B --> A|B iff A and B have no bits set in common. if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) return BinaryOperator::CreateOr(LHS, RHS); @@ -1528,7 +1529,7 @@ static Instruction *factorizeFAddFSub(BinaryOperator &I, } Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) { - if (Value *V = SimplifyFAddInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyFAddInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1687,7 +1688,8 @@ Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS, // Require at least one GEP with a common base pointer on both sides. if (auto *LHSGEP = dyn_cast<GEPOperator>(LHS)) { // (gep X, ...) - X - if (LHSGEP->getOperand(0) == RHS) { + if (LHSGEP->getOperand(0)->stripPointerCasts() == + RHS->stripPointerCasts()) { GEP1 = LHSGEP; } else if (auto *RHSGEP = dyn_cast<GEPOperator>(RHS)) { // (gep X, ...) - (gep X, ...) 
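Two of the visitAdd rules in the hunk above are plain bit-level identities: (A & C) + A == A & (C - 1) for a power-of-two C whose bit lies in A's sign-extension region (the ComputeNumSignBits condition), and A + B == A | B whenever A and B share no set bits. Both can be brute-forced over 8-bit values in a standalone check using the same wrapping arithmetic the IR add has:

#include <cassert>
#include <cstdint>

// True if bits K..7 of A are all copies of A's sign bit, i.e. bit K sits in
// the sign-extension region.
static bool bitKIsSignCopy(uint8_t A, unsigned K) {
  uint8_t Sign = (A & 0x80) ? 0xFF : 0x00;
  uint8_t HighMask = uint8_t(0xFF << K);
  return uint8_t(A & HighMask) == uint8_t(Sign & HighMask);
}

int main() {
  for (unsigned A = 0; A < 256; ++A) {
    for (unsigned K = 0; K < 8; ++K) {
      uint8_t C1 = uint8_t(1u << K);
      if (bitKIsSignCopy(uint8_t(A), K))
        assert(uint8_t((A & C1) + A) == uint8_t(A & (C1 - 1)));
    }
    for (unsigned B = 0; B < 256; ++B)
      if ((A & B) == 0) // no common bits set: add cannot carry
        assert(uint8_t(A + B) == uint8_t(A | B));
  }
  return 0;
}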
@@ -1749,7 +1751,7 @@ Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS, } Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { - if (Value *V = SimplifySubInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifySubInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -2014,6 +2016,37 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { } } + if (auto *II = dyn_cast<MinMaxIntrinsic>(Op1)) { + { + // sub(add(X,Y), s/umin(X,Y)) --> s/umax(X,Y) + // sub(add(X,Y), s/umax(X,Y)) --> s/umin(X,Y) + Value *X = II->getLHS(); + Value *Y = II->getRHS(); + if (match(Op0, m_c_Add(m_Specific(X), m_Specific(Y))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID()); + Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, Y); + return replaceInstUsesWith(I, InvMaxMin); + } + } + + { + // sub(add(X,Y),umin(Y,Z)) --> add(X,usub.sat(Y,Z)) + // sub(add(X,Z),umin(Y,Z)) --> add(X,usub.sat(Z,Y)) + Value *X, *Y, *Z; + if (match(Op1, m_OneUse(m_UMin(m_Value(Y), m_Value(Z))))) { + if (match(Op0, m_OneUse(m_c_Add(m_Specific(Y), m_Value(X))))) + return BinaryOperator::CreateAdd( + X, Builder.CreateIntrinsic(Intrinsic::usub_sat, I.getType(), + {Y, Z})); + if (match(Op0, m_OneUse(m_c_Add(m_Specific(Z), m_Value(X))))) + return BinaryOperator::CreateAdd( + X, Builder.CreateIntrinsic(Intrinsic::usub_sat, I.getType(), + {Z, Y})); + } + } + } + { // If we have a subtraction between some value and a select between // said value and something else, sink subtraction into select hands, i.e.: @@ -2089,36 +2122,6 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { return BinaryOperator::CreateSub(X, Not); } - // TODO: This is the same logic as above but handles the cmp-select idioms - // for min/max, so the use checks are increased to account for the - // extra instructions. If we canonicalize to intrinsics, this block - // can likely be removed. - { - Value *LHS, *RHS, *A; - Value *NotA = Op0, *MinMax = Op1; - SelectPatternFlavor SPF = matchSelectPattern(MinMax, LHS, RHS).Flavor; - if (!SelectPatternResult::isMinOrMax(SPF)) { - NotA = Op1; - MinMax = Op0; - SPF = matchSelectPattern(MinMax, LHS, RHS).Flavor; - } - if (SelectPatternResult::isMinOrMax(SPF) && - match(NotA, m_Not(m_Value(A))) && (NotA == LHS || NotA == RHS)) { - if (NotA == LHS) - std::swap(LHS, RHS); - // LHS is now Y above and expected to have at least 2 uses (the min/max) - // NotA is expected to have 2 uses from the min/max and 1 from the sub. - if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && - !NotA->hasNUsesOrMore(4)) { - Value *Not = Builder.CreateNot(MinMax); - if (NotA == Op0) - return BinaryOperator::CreateSub(Not, A); - else - return BinaryOperator::CreateSub(A, Not); - } - } - } - // Optimize pointer differences into the same array into a size. Consider: // &A[10] - &A[0]: we should compile this to "10". Value *LHSOp, *RHSOp; @@ -2149,11 +2152,11 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { // B = ashr i32 A, 31 ; smear the sign bit // sub (xor A, B), B ; flip bits if negative and subtract -1 (add 1) // --> (A < 0) ? -A : A - Value *Cmp = Builder.CreateICmpSLT(A, ConstantInt::getNullValue(Ty)); + Value *IsNeg = Builder.CreateIsNeg(A); // Copy the nuw/nsw flags from the sub to the negate. 
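The MinMaxIntrinsic folds added to visitSub in this hunk rest on simple identities: since min(X,Y) + max(X,Y) == X + Y, the subtraction in sub(add(X,Y), min(X,Y)) cancels exactly to max(X,Y); and sub(add(X,Y), umin(Y,Z)) equals add(X, usub.sat(Y,Z)) under wrapping arithmetic. A standalone brute-force check of the unsigned flavors, with the saturating subtract spelled out rather than taken from an intrinsic:

#include <algorithm>
#include <cassert>
#include <cstdint>

static uint8_t usubSat(uint8_t A, uint8_t B) { return A > B ? A - B : 0; }

int main() {
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned Y = 0; Y < 256; ++Y) {
      // sub(add(X,Y), umin(X,Y)) --> umax(X,Y), valid even when X+Y wraps.
      assert(uint8_t((X + Y) - std::min(X, Y)) == uint8_t(std::max(X, Y)));
      for (unsigned Z = 0; Z < 256; Z += 17) {
        // sub(add(X,Y), umin(Y,Z)) --> add(X, usub.sat(Y,Z))
        assert(uint8_t((X + Y) - std::min(Y, Z)) ==
               uint8_t(X + usubSat(uint8_t(Y), uint8_t(Z))));
      }
    }
  return 0;
}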
- Value *Neg = Builder.CreateNeg(A, "", I.hasNoUnsignedWrap(), - I.hasNoSignedWrap()); - return SelectInst::Create(Cmp, Neg, A); + Value *NegA = Builder.CreateNeg(A, "", I.hasNoUnsignedWrap(), + I.hasNoSignedWrap()); + return SelectInst::Create(IsNeg, NegA, A); } // If we are subtracting a low-bit masked subset of some value from an add @@ -2187,12 +2190,23 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { return replaceInstUsesWith( I, Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {X, Op1})); + // Op0 - umin(X, Op0) --> usub.sat(Op0, X) + if (match(Op1, m_OneUse(m_c_UMin(m_Value(X), m_Specific(Op0))))) + return replaceInstUsesWith( + I, Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {Op0, X})); + // Op0 - umax(X, Op0) --> 0 - usub.sat(X, Op0) if (match(Op1, m_OneUse(m_c_UMax(m_Value(X), m_Specific(Op0))))) { Value *USub = Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {X, Op0}); return BinaryOperator::CreateNeg(USub); } + // umin(X, Op1) - Op1 --> 0 - usub.sat(Op1, X) + if (match(Op0, m_OneUse(m_c_UMin(m_Value(X), m_Specific(Op1))))) { + Value *USub = Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {Op1, X}); + return BinaryOperator::CreateNeg(USub); + } + // C - ctpop(X) => ctpop(~X) if C is bitwidth if (match(Op0, m_SpecificInt(Ty->getScalarSizeInBits())) && match(Op1, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))) @@ -2264,7 +2278,7 @@ static Instruction *hoistFNegAboveFMulFDiv(Instruction &I, Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { Value *Op = I.getOperand(0); - if (Value *V = SimplifyFNegInst(Op, I.getFastMathFlags(), + if (Value *V = simplifyFNegInst(Op, I.getFastMathFlags(), getSimplifyQuery().getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -2287,10 +2301,11 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { // Unlike most transforms, this one is not safe to propagate nsz unless // it is present on the original select. (We are conservatively intersecting // the nsz flags from the select and root fneg instruction.) - auto propagateSelectFMF = [&](SelectInst *S) { + auto propagateSelectFMF = [&](SelectInst *S, bool CommonOperand) { S->copyFastMathFlags(&I); if (auto *OldSel = dyn_cast<SelectInst>(Op)) - if (!OldSel->hasNoSignedZeros()) + if (!OldSel->hasNoSignedZeros() && !CommonOperand && + !isGuaranteedNotToBeUndefOrPoison(OldSel->getCondition())) S->setHasNoSignedZeros(false); }; // -(Cond ? -P : Y) --> Cond ? P : -Y @@ -2298,14 +2313,14 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { if (match(X, m_FNeg(m_Value(P)))) { Value *NegY = Builder.CreateFNegFMF(Y, &I, Y->getName() + ".neg"); SelectInst *NewSel = SelectInst::Create(Cond, P, NegY); - propagateSelectFMF(NewSel); + propagateSelectFMF(NewSel, P == Y); return NewSel; } // -(Cond ? X : -P) --> Cond ? 
-X : P if (match(Y, m_FNeg(m_Value(P)))) { Value *NegX = Builder.CreateFNegFMF(X, &I, X->getName() + ".neg"); SelectInst *NewSel = SelectInst::Create(Cond, NegX, P); - propagateSelectFMF(NewSel); + propagateSelectFMF(NewSel, P == X); return NewSel; } } @@ -2314,7 +2329,7 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { } Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { - if (Value *V = SimplifyFSubInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyFSubInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), getSimplifyQuery().getWithInstruction(&I))) return replaceInstUsesWith(I, V); diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 7eaa28bd1320..ae8865651ece 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -24,32 +24,6 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" -/// Similar to getICmpCode but for FCmpInst. This encodes a fcmp predicate into -/// a four bit mask. -static unsigned getFCmpCode(FCmpInst::Predicate CC) { - assert(FCmpInst::FCMP_FALSE <= CC && CC <= FCmpInst::FCMP_TRUE && - "Unexpected FCmp predicate!"); - // Take advantage of the bit pattern of FCmpInst::Predicate here. - // U L G E - static_assert(FCmpInst::FCMP_FALSE == 0, ""); // 0 0 0 0 - static_assert(FCmpInst::FCMP_OEQ == 1, ""); // 0 0 0 1 - static_assert(FCmpInst::FCMP_OGT == 2, ""); // 0 0 1 0 - static_assert(FCmpInst::FCMP_OGE == 3, ""); // 0 0 1 1 - static_assert(FCmpInst::FCMP_OLT == 4, ""); // 0 1 0 0 - static_assert(FCmpInst::FCMP_OLE == 5, ""); // 0 1 0 1 - static_assert(FCmpInst::FCMP_ONE == 6, ""); // 0 1 1 0 - static_assert(FCmpInst::FCMP_ORD == 7, ""); // 0 1 1 1 - static_assert(FCmpInst::FCMP_UNO == 8, ""); // 1 0 0 0 - static_assert(FCmpInst::FCMP_UEQ == 9, ""); // 1 0 0 1 - static_assert(FCmpInst::FCMP_UGT == 10, ""); // 1 0 1 0 - static_assert(FCmpInst::FCMP_UGE == 11, ""); // 1 0 1 1 - static_assert(FCmpInst::FCMP_ULT == 12, ""); // 1 1 0 0 - static_assert(FCmpInst::FCMP_ULE == 13, ""); // 1 1 0 1 - static_assert(FCmpInst::FCMP_UNE == 14, ""); // 1 1 1 0 - static_assert(FCmpInst::FCMP_TRUE == 15, ""); // 1 1 1 1 - return CC; -} - /// This is the complement of getICmpCode, which turns an opcode and two /// operands into either a constant true or false, or a brand new ICmp /// instruction. The sign is passed in to determine which kind of predicate to @@ -66,14 +40,10 @@ static Value *getNewICmpValue(unsigned Code, bool Sign, Value *LHS, Value *RHS, /// operands into either a FCmp instruction, or a true/false constant. 
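The deleted static_assert table above documents that the FCmp predicates are laid out as a 4-bit unordered/less/greater/equal mask, which is exactly what lets an and/or of two compares be computed by bitwise and/or on the codes (now funneled through getPredForFCmpCode). A standalone restatement of that encoding and the property it buys:

#include <cassert>

// One bit each for Unordered, Less, Greater, Equal (U L G E).
enum FCmpCode : unsigned {
  FCMP_FALSE = 0x0, // ----
  FCMP_OEQ   = 0x1, // ---E
  FCMP_OGT   = 0x2, // --G-
  FCMP_OGE   = 0x3, // --GE
  FCMP_OLT   = 0x4, // -L--
  FCMP_OLE   = 0x5, // -L-E
  FCMP_ONE   = 0x6, // -LG-
  FCMP_ORD   = 0x7, // -LGE
  FCMP_UNO   = 0x8, // U---
  FCMP_UEQ   = 0x9, // U--E
  FCMP_UGT   = 0xA, // U-G-
  FCMP_UGE   = 0xB, // U-GE
  FCMP_ULT   = 0xC, // UL--
  FCMP_ULE   = 0xD, // UL-E
  FCMP_UNE   = 0xE, // ULG-
  FCMP_TRUE  = 0xF  // ULGE
};

int main() {
  assert((FCMP_OLT | FCMP_OEQ) == FCMP_OLE); // (x < y) || (x == y) ==> x <= y
  assert((FCMP_OLE & FCMP_OGE) == FCMP_OEQ); // (x <= y) && (x >= y) ==> x == y
  assert((FCMP_ORD | FCMP_UNO) == FCMP_TRUE);
  return 0;
}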
static Value *getFCmpValue(unsigned Code, Value *LHS, Value *RHS, InstCombiner::BuilderTy &Builder) { - const auto Pred = static_cast<FCmpInst::Predicate>(Code); - assert(FCmpInst::FCMP_FALSE <= Pred && Pred <= FCmpInst::FCMP_TRUE && - "Unexpected FCmp predicate!"); - if (Pred == FCmpInst::FCMP_FALSE) - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - if (Pred == FCmpInst::FCMP_TRUE) - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 1); - return Builder.CreateFCmp(Pred, LHS, RHS); + FCmpInst::Predicate NewPred; + if (Constant *TorF = getPredForFCmpCode(Code, LHS->getType(), NewPred)) + return TorF; + return Builder.CreateFCmp(NewPred, LHS, RHS); } /// Transform BITWISE_OP(BSWAP(A),BSWAP(B)) or @@ -395,6 +365,7 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, /// (icmp(A & X) ==/!= Y), where the left-hand side is of type Mask_NotAllZeros /// and the right hand side is of type BMask_Mixed. For example, /// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8). +/// Also used for logical and/or, must be poison safe. static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C, Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, @@ -409,9 +380,9 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( // // We currently handle the case of B, C, D, E are constant. // - ConstantInt *BCst, *CCst, *DCst, *ECst; - if (!match(B, m_ConstantInt(BCst)) || !match(C, m_ConstantInt(CCst)) || - !match(D, m_ConstantInt(DCst)) || !match(E, m_ConstantInt(ECst))) + const APInt *BCst, *CCst, *DCst, *OrigECst; + if (!match(B, m_APInt(BCst)) || !match(C, m_APInt(CCst)) || + !match(D, m_APInt(DCst)) || !match(E, m_APInt(OrigECst))) return nullptr; ICmpInst::Predicate NewCC = IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; @@ -420,19 +391,20 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( // canonicalized as, // (icmp ne (A & D), 0) -> (icmp eq (A & D), D) or // (icmp ne (A & D), D) -> (icmp eq (A & D), 0). + APInt ECst = *OrigECst; if (PredR != NewCC) - ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst)); + ECst ^= *DCst; // If B or D is zero, skip because if LHS or RHS can be trivially folded by // other folding rules and this pattern won't apply any more. - if (BCst->getValue() == 0 || DCst->getValue() == 0) + if (*BCst == 0 || *DCst == 0) return nullptr; // If B and D don't intersect, ie. (B & D) == 0, no folding because we can't // deduce anything from it. // For example, // (icmp ne (A & 12), 0) & (icmp eq (A & 3), 1) -> no folding. 
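The doc comment for this helper gives (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8) as its motivating example, and the equivalence is easy to sanity-check exhaustively; a throwaway sketch in plain unsigned arithmetic, nothing LLVM-specific:

#include <cassert>

int main() {
  // If (A & 15) == 8 then bit 3 is set, so (A & 12) != 0 adds no information;
  // the conjunction collapses to the right-hand compare.
  for (unsigned A = 0; A < 256; ++A) {
    bool Original = ((A & 12) != 0) && ((A & 15) == 8);
    bool Folded = (A & 15) == 8;
    assert(Original == Folded);
  }
  return 0;
}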
- if ((BCst->getValue() & DCst->getValue()) == 0) + if ((*BCst & *DCst) == 0) return nullptr; // If the following two conditions are met: @@ -451,22 +423,21 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( // For example, // (icmp ne (A & 12), 0) & (icmp eq (A & 7), 1) -> (icmp eq (A & 15), 9) // (icmp ne (A & 15), 0) & (icmp eq (A & 7), 0) -> (icmp eq (A & 15), 8) - if ((((BCst->getValue() & DCst->getValue()) & ECst->getValue()) == 0) && - (BCst->getValue() & (BCst->getValue() ^ DCst->getValue())).isPowerOf2()) { - APInt BorD = BCst->getValue() | DCst->getValue(); - APInt BandBxorDorE = (BCst->getValue() & (BCst->getValue() ^ DCst->getValue())) | - ECst->getValue(); - Value *NewMask = ConstantInt::get(BCst->getType(), BorD); - Value *NewMaskedValue = ConstantInt::get(BCst->getType(), BandBxorDorE); + if ((((*BCst & *DCst) & ECst) == 0) && + (*BCst & (*BCst ^ *DCst)).isPowerOf2()) { + APInt BorD = *BCst | *DCst; + APInt BandBxorDorE = (*BCst & (*BCst ^ *DCst)) | ECst; + Value *NewMask = ConstantInt::get(A->getType(), BorD); + Value *NewMaskedValue = ConstantInt::get(A->getType(), BandBxorDorE); Value *NewAnd = Builder.CreateAnd(A, NewMask); return Builder.CreateICmp(NewCC, NewAnd, NewMaskedValue); } - auto IsSubSetOrEqual = [](ConstantInt *C1, ConstantInt *C2) { - return (C1->getValue() & C2->getValue()) == C1->getValue(); + auto IsSubSetOrEqual = [](const APInt *C1, const APInt *C2) { + return (*C1 & *C2) == *C1; }; - auto IsSuperSetOrEqual = [](ConstantInt *C1, ConstantInt *C2) { - return (C1->getValue() & C2->getValue()) == C2->getValue(); + auto IsSuperSetOrEqual = [](const APInt *C1, const APInt *C2) { + return (*C1 & *C2) == *C2; }; // In the following, we consider only the cases where B is a superset of D, B @@ -486,7 +457,7 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( // For example, // (icmp ne (A & 3), 0) & (icmp eq (A & 7), 0) -> false. // (icmp ne (A & 15), 0) & (icmp eq (A & 3), 0) -> no folding. - if (ECst->isZero()) { + if (ECst.isZero()) { if (IsSubSetOrEqual(BCst, DCst)) return ConstantInt::get(LHS->getType(), !IsAnd); return nullptr; @@ -504,7 +475,7 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( // ie. (B & E) != 0, then LHS is subsumed by RHS. For example. // (icmp ne (A & 12), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8). assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code"); - if ((BCst->getValue() & ECst->getValue()) != 0) + if ((*BCst & ECst) != 0) return RHS; // Otherwise, LHS and RHS contradict and the whole expression becomes false // (or true if negated.) For example, @@ -516,6 +487,7 @@ static Value *foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed( /// Try to fold (icmp(A & B) ==/!= 0) &/| (icmp(A & D) ==/!= E) into a single /// (icmp(A & X) ==/!= Y), where the left-hand side and the right hand side /// aren't of the common mask pattern type. +/// Also used for logical and/or, must be poison safe. static Value *foldLogOpOfMaskedICmpsAsymmetric( ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *A, Value *B, Value *C, Value *D, Value *E, ICmpInst::Predicate PredL, ICmpInst::Predicate PredR, @@ -550,6 +522,7 @@ static Value *foldLogOpOfMaskedICmpsAsymmetric( /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) /// into a single (icmp(A & X) ==/!= Y). 
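One of the worked examples above, (icmp ne (A & 12), 0) & (icmp eq (A & 7), 1) -> (icmp eq (A & 15), 9), exercises the B | D and (B & (B ^ D)) | E construction directly; a quick standalone check with those constants:

#include <cassert>

int main() {
  // B = 12, D = 7, E = 1: B & (B ^ D) == 8 is a power of two and does not
  // intersect E, so the two masked compares merge into one against B | D = 15
  // with the expected value (B & (B ^ D)) | E = 9.
  for (unsigned A = 0; A < 256; ++A) {
    bool Original = ((A & 12) != 0) && ((A & 7) == 1);
    bool Folded = (A & 15) == 9;
    assert(Original == Folded);
  }
  return 0;
}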
static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, + bool IsLogical, InstCombiner::BuilderTy &Builder) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); @@ -594,6 +567,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, if (Mask & Mask_AllZeros) { // (icmp eq (A & B), 0) & (icmp eq (A & D), 0) // -> (icmp eq (A & (B|D)), 0) + if (IsLogical && !isGuaranteedNotToBeUndefOrPoison(D)) + return nullptr; // TODO: Use freeze? Value *NewOr = Builder.CreateOr(B, D); Value *NewAnd = Builder.CreateAnd(A, NewOr); // We can't use C as zero because we might actually handle @@ -605,6 +580,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, if (Mask & BMask_AllOnes) { // (icmp eq (A & B), B) & (icmp eq (A & D), D) // -> (icmp eq (A & (B|D)), (B|D)) + if (IsLogical && !isGuaranteedNotToBeUndefOrPoison(D)) + return nullptr; // TODO: Use freeze? Value *NewOr = Builder.CreateOr(B, D); Value *NewAnd = Builder.CreateAnd(A, NewOr); return Builder.CreateICmp(NewCC, NewAnd, NewOr); @@ -612,6 +589,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, if (Mask & AMask_AllOnes) { // (icmp eq (A & B), A) & (icmp eq (A & D), A) // -> (icmp eq (A & (B&D)), A) + if (IsLogical && !isGuaranteedNotToBeUndefOrPoison(D)) + return nullptr; // TODO: Use freeze? Value *NewAnd1 = Builder.CreateAnd(B, D); Value *NewAnd2 = Builder.CreateAnd(A, NewAnd1); return Builder.CreateICmp(NewCC, NewAnd2, A); @@ -736,47 +715,6 @@ Value *InstCombinerImpl::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, return Builder.CreateICmp(NewPred, Input, RangeEnd); } -static Value * -foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS, - bool JoinedByAnd, - InstCombiner::BuilderTy &Builder) { - Value *X = LHS->getOperand(0); - if (X != RHS->getOperand(0)) - return nullptr; - - const APInt *C1, *C2; - if (!match(LHS->getOperand(1), m_APInt(C1)) || - !match(RHS->getOperand(1), m_APInt(C2))) - return nullptr; - - // We only handle (X != C1 && X != C2) and (X == C1 || X == C2). - ICmpInst::Predicate Pred = LHS->getPredicate(); - if (Pred != RHS->getPredicate()) - return nullptr; - if (JoinedByAnd && Pred != ICmpInst::ICMP_NE) - return nullptr; - if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ) - return nullptr; - - // The larger unsigned constant goes on the right. - if (C1->ugt(*C2)) - std::swap(C1, C2); - - APInt Xor = *C1 ^ *C2; - if (Xor.isPowerOf2()) { - // If LHSC and RHSC differ by only one bit, then set that bit in X and - // compare against the larger constant: - // (X == C1 || X == C2) --> (X | (C1 ^ C2)) == C2 - // (X != C1 && X != C2) --> (X | (C1 ^ C2)) != C2 - // We choose an 'or' with a Pow2 constant rather than the inverse mask with - // 'and' because that may lead to smaller codegen from a smaller constant. 
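The block being deleted here handled equality compares against constants that differ in exactly one bit, e.g. (X == 8 || X == 9) --> (X | 1) == 9, presumably because the range-based combining extended elsewhere in this patch covers the same shapes. The identity itself is easy to confirm:

#include <cassert>

int main() {
  // C1 = 8 and C2 = 9 differ only in bit 0, so setting that bit maps both
  // values onto the larger constant:
  //   (X == 8 || X == 9)  <=>  (X | 1) == 9
  //   (X != 8 && X != 9)  <=>  (X | 1) != 9
  for (unsigned X = 0; X < 256; ++X) {
    assert(((X == 8) || (X == 9)) == ((X | 1) == 9));
    assert(((X != 8) && (X != 9)) == ((X | 1) != 9));
  }
  return 0;
}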
- Value *Or = Builder.CreateOr(X, ConstantInt::get(X->getType(), Xor)); - return Builder.CreateICmp(Pred, Or, ConstantInt::get(X->getType(), *C2)); - } - - return nullptr; -} - // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) Value *InstCombinerImpl::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, @@ -941,7 +879,29 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1, CxtI.getName() + ".simplified"); } +/// Fold (icmp eq ctpop(X) 1) | (icmp eq X 0) into (icmp ult ctpop(X) 2) and +/// fold (icmp ne ctpop(X) 1) & (icmp ne X 0) into (icmp ugt ctpop(X) 1). +/// Also used for logical and/or, must be poison safe. +static Value *foldIsPowerOf2OrZero(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd, + InstCombiner::BuilderTy &Builder) { + CmpInst::Predicate Pred0, Pred1; + Value *X; + if (!match(Cmp0, m_ICmp(Pred0, m_Intrinsic<Intrinsic::ctpop>(m_Value(X)), + m_SpecificInt(1))) || + !match(Cmp1, m_ICmp(Pred1, m_Specific(X), m_ZeroInt()))) + return nullptr; + + Value *CtPop = Cmp0->getOperand(0); + if (IsAnd && Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_NE) + return Builder.CreateICmpUGT(CtPop, ConstantInt::get(CtPop->getType(), 1)); + if (!IsAnd && Pred0 == ICmpInst::ICMP_EQ && Pred1 == ICmpInst::ICMP_EQ) + return Builder.CreateICmpULT(CtPop, ConstantInt::get(CtPop->getType(), 2)); + + return nullptr; +} + /// Reduce a pair of compares that check if a value has exactly 1 bit set. +/// Also used for logical and/or, must be poison safe. static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd, InstCombiner::BuilderTy &Builder) { // Handle 'and' / 'or' commutation: make the equality check the first operand. @@ -1001,22 +961,13 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp, }; // Given ZeroCmpOp = (A + B) - // ZeroCmpOp <= A && ZeroCmpOp != 0 --> (0-B) < A - // ZeroCmpOp > A || ZeroCmpOp == 0 --> (0-B) >= A - // // ZeroCmpOp < A && ZeroCmpOp != 0 --> (0-X) < Y iff // ZeroCmpOp >= A || ZeroCmpOp == 0 --> (0-X) >= Y iff // with X being the value (A/B) that is known to be non-zero, // and Y being remaining value. - if (UnsignedPred == ICmpInst::ICMP_ULE && EqPred == ICmpInst::ICMP_NE && - IsAnd) - return Builder.CreateICmpULT(Builder.CreateNeg(B), A); if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_NE && IsAnd && GetKnownNonZeroAndOther(B, A)) return Builder.CreateICmpULT(Builder.CreateNeg(B), A); - if (UnsignedPred == ICmpInst::ICMP_UGT && EqPred == ICmpInst::ICMP_EQ && - !IsAnd) - return Builder.CreateICmpUGE(Builder.CreateNeg(B), A); if (UnsignedPred == ICmpInst::ICMP_UGE && EqPred == ICmpInst::ICMP_EQ && !IsAnd && GetKnownNonZeroAndOther(B, A)) return Builder.CreateICmpUGE(Builder.CreateNeg(B), A); @@ -1143,12 +1094,9 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, /// common operand with the constant. Callers are expected to call this with /// Cmp0/Cmp1 switched to handle logic op commutativity. static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, - BinaryOperator &Logic, + bool IsAnd, InstCombiner::BuilderTy &Builder, const SimplifyQuery &Q) { - bool IsAnd = Logic.getOpcode() == Instruction::And; - assert((IsAnd || Logic.getOpcode() == Instruction::Or) && "Wrong logic op"); - // Match an equality compare with a non-poison constant as Cmp0. // Also, give up if the compare can be constant-folded to avoid looping. 
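The new foldIsPowerOf2OrZero helper above collapses a ctpop-equals-1 test combined with a zero test into a single unsigned compare of the population count; a standalone sketch of the underlying identity, with a hand-rolled popcount to stay self-contained:

#include <cassert>
#include <cstdint>

static int PopCount(uint32_t X) {
  int N = 0;
  for (; X != 0; X &= X - 1) // clear the lowest set bit each iteration
    ++N;
  return N;
}

int main() {
  for (uint32_t X = 0; X <= 0xFFFF; ++X) {
    // (ctpop(X) == 1) | (X == 0)  -->  ctpop(X) < 2
    assert(((PopCount(X) == 1) || (X == 0)) == (PopCount(X) < 2));
    // (ctpop(X) != 1) & (X != 0)  -->  ctpop(X) > 1
    assert(((PopCount(X) != 1) && (X != 0)) == (PopCount(X) > 1));
  }
  return 0;
}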
ICmpInst::Predicate Pred0; @@ -1174,7 +1122,7 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, // (X != C) || (Y Pred1 X) --> (X != C) || (Y Pred1 C) // Can think of the 'or' substitution with the 'and' bool equivalent: // A || B --> A || (!A && B) - Value *SubstituteCmp = SimplifyICmpInst(Pred1, Y, C, Q); + Value *SubstituteCmp = simplifyICmpInst(Pred1, Y, C, Q); if (!SubstituteCmp) { // If we need to create a new instruction, require that the old compare can // be removed. @@ -1182,16 +1130,24 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, return nullptr; SubstituteCmp = Builder.CreateICmp(Pred1, Y, C); } - return Builder.CreateBinOp(Logic.getOpcode(), Cmp0, SubstituteCmp); + return Builder.CreateBinOp(IsAnd ? Instruction::And : Instruction::Or, Cmp0, + SubstituteCmp); } /// Fold (icmp Pred1 V1, C1) & (icmp Pred2 V2, C2) /// or (icmp Pred1 V1, C1) | (icmp Pred2 V2, C2) /// into a single comparison using range-based reasoning. -static Value *foldAndOrOfICmpsUsingRanges( - ICmpInst::Predicate Pred1, Value *V1, const APInt &C1, - ICmpInst::Predicate Pred2, Value *V2, const APInt &C2, - IRBuilderBase &Builder, bool IsAnd) { +/// NOTE: This is also used for logical and/or, must be poison-safe! +Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1, + ICmpInst *ICmp2, + bool IsAnd) { + ICmpInst::Predicate Pred1, Pred2; + Value *V1, *V2; + const APInt *C1, *C2; + if (!match(ICmp1, m_ICmp(Pred1, m_Value(V1), m_APInt(C1))) || + !match(ICmp2, m_ICmp(Pred2, m_Value(V2), m_APInt(C2)))) + return nullptr; + // Look through add of a constant offset on V1, V2, or both operands. This // allows us to interpret the V + C' < C'' range idiom into a proper range. const APInt *Offset1 = nullptr, *Offset2 = nullptr; @@ -1206,152 +1162,51 @@ static Value *foldAndOrOfICmpsUsingRanges( if (V1 != V2) return nullptr; - ConstantRange CR1 = ConstantRange::makeExactICmpRegion(Pred1, C1); + ConstantRange CR1 = ConstantRange::makeExactICmpRegion( + IsAnd ? ICmpInst::getInversePredicate(Pred1) : Pred1, *C1); if (Offset1) CR1 = CR1.subtract(*Offset1); - ConstantRange CR2 = ConstantRange::makeExactICmpRegion(Pred2, C2); + ConstantRange CR2 = ConstantRange::makeExactICmpRegion( + IsAnd ? ICmpInst::getInversePredicate(Pred2) : Pred2, *C2); if (Offset2) CR2 = CR2.subtract(*Offset2); - Optional<ConstantRange> CR = - IsAnd ? CR1.exactIntersectWith(CR2) : CR1.exactUnionWith(CR2); - if (!CR) - return nullptr; - - CmpInst::Predicate NewPred; - APInt NewC, Offset; - CR->getEquivalentICmp(NewPred, NewC, Offset); - Type *Ty = V1->getType(); Value *NewV = V1; - if (Offset != 0) - NewV = Builder.CreateAdd(NewV, ConstantInt::get(Ty, Offset)); - return Builder.CreateICmp(NewPred, NewV, ConstantInt::get(Ty, NewC)); -} - -/// Fold (icmp)&(icmp) if possible. -Value *InstCombinerImpl::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, - BinaryOperator &And) { - const SimplifyQuery Q = SQ.getWithInstruction(&And); - - // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) - // if K1 and K2 are a one-bit mask. 
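The one-bit-mask fold named in the comment just above (and carried over into the merged foldAndOrOfICmps further down) is another identity worth pinning down with concrete masks; a quick sketch with K1 = 4 and K2 = 1:

#include <cassert>

int main() {
  // With single-bit masks K1 = 4 and K2 = 1:
  //   (A & 4) != 0 && (A & 1) != 0   <=>  (A & 5) == 5
  //   (A & 4) == 0 || (A & 1) == 0   <=>  (A & 5) != 5
  for (unsigned A = 0; A < 256; ++A) {
    assert((((A & 4) != 0) && ((A & 1) != 0)) == ((A & 5) == 5));
    assert((((A & 4) == 0) || ((A & 1) == 0)) == ((A & 5) != 5));
  }
  return 0;
}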
- if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, &And, - /* IsAnd */ true)) - return V; - - ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); - - // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B) - if (predicatesFoldable(PredL, PredR)) { - if (LHS->getOperand(0) == RHS->getOperand(1) && - LHS->getOperand(1) == RHS->getOperand(0)) - LHS->swapOperands(); - if (LHS->getOperand(0) == RHS->getOperand(0) && - LHS->getOperand(1) == RHS->getOperand(1)) { - Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1); - unsigned Code = getICmpCode(LHS) & getICmpCode(RHS); - bool IsSigned = LHS->isSigned() || RHS->isSigned(); - return getNewICmpValue(Code, IsSigned, Op0, Op1, Builder); - } - } - - // handle (roughly): (icmp eq (A & B), C) & (icmp eq (A & D), E) - if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, true, Builder)) - return V; - - if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, And, Builder, Q)) - return V; - if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, And, Builder, Q)) - return V; - - // E.g. (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n - if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/false)) - return V; - - // E.g. (icmp slt x, n) & (icmp sge x, 0) --> icmp ult x, n - if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/false)) - return V; - - if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, true, Builder)) - return V; - - if (Value *V = foldSignedTruncationCheck(LHS, RHS, And, Builder)) - return V; - - if (Value *V = foldIsPowerOf2(LHS, RHS, true /* JoinedByAnd */, Builder)) - return V; - - if (Value *X = - foldUnsignedUnderflowCheck(LHS, RHS, /*IsAnd=*/true, Q, Builder)) - return X; - if (Value *X = - foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/true, Q, Builder)) - return X; - - if (Value *X = foldEqOfParts(LHS, RHS, /*IsAnd=*/true)) - return X; + Optional<ConstantRange> CR = CR1.exactUnionWith(CR2); + if (!CR) { + if (!(ICmp1->hasOneUse() && ICmp2->hasOneUse()) || CR1.isWrappedSet() || + CR2.isWrappedSet()) + return nullptr; - // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). - Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); + // Check whether we have equal-size ranges that only differ by one bit. + // In that case we can apply a mask to map one range onto the other. + APInt LowerDiff = CR1.getLower() ^ CR2.getLower(); + APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1); + APInt CR1Size = CR1.getUpper() - CR1.getLower(); + if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff || + CR1Size != CR2.getUpper() - CR2.getLower()) + return nullptr; - // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) - // TODO: Remove this when foldLogOpOfMaskedICmps can handle undefs. - if (PredL == ICmpInst::ICMP_EQ && match(LHS->getOperand(1), m_ZeroInt()) && - PredR == ICmpInst::ICMP_EQ && match(RHS->getOperand(1), m_ZeroInt()) && - LHS0->getType() == RHS0->getType()) { - Value *NewOr = Builder.CreateOr(LHS0, RHS0); - return Builder.CreateICmp(PredL, NewOr, - Constant::getNullValue(NewOr->getType())); + CR = CR1.getLower().ult(CR2.getLower()) ? CR1 : CR2; + NewV = Builder.CreateAnd(NewV, ConstantInt::get(Ty, ~LowerDiff)); } - const APInt *LHSC, *RHSC; - if (!match(LHS->getOperand(1), m_APInt(LHSC)) || - !match(RHS->getOperand(1), m_APInt(RHSC))) - return nullptr; - - // (trunc x) == C1 & (and x, CA) == C2 -> (and x, CA|CMAX) == C1|C2 - // where CMAX is the all ones value for the truncated type, - // iff the lower bits of C2 and CA are zero. 
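Instantiating the trunc/and comment above with a 16-bit x truncated to 8 bits, C1 = 0x12, CA = 0xFF00 and C2 = 0x3400 (so the low byte of both CA and C2 is zero), gives a fold that can be verified directly; the same transform is re-added to the merged foldAndOrOfICmps later in this hunk:

#include <cassert>
#include <cstdint>

int main() {
  // CMAX for the 8-bit truncation is 0xFF, so the combined compare is
  //   (x & (0xFF00 | 0xFF)) == (0x12 | 0x3400), i.e. x == 0x3412.
  for (uint32_t X = 0; X <= 0xFFFF; ++X) {
    bool Original = ((X & 0xFF) == 0x12) && ((X & 0xFF00) == 0x3400);
    bool Folded = (X & 0xFFFF) == 0x3412;
    assert(Original == Folded);
  }
  return 0;
}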
- if (PredL == ICmpInst::ICMP_EQ && PredL == PredR && LHS->hasOneUse() && - RHS->hasOneUse()) { - Value *V; - const APInt *AndC, *SmallC = nullptr, *BigC = nullptr; - - // (trunc x) == C1 & (and x, CA) == C2 - // (and x, CA) == C2 & (trunc x) == C1 - if (match(RHS0, m_Trunc(m_Value(V))) && - match(LHS0, m_And(m_Specific(V), m_APInt(AndC)))) { - SmallC = RHSC; - BigC = LHSC; - } else if (match(LHS0, m_Trunc(m_Value(V))) && - match(RHS0, m_And(m_Specific(V), m_APInt(AndC)))) { - SmallC = LHSC; - BigC = RHSC; - } - - if (SmallC && BigC) { - unsigned BigBitSize = BigC->getBitWidth(); - unsigned SmallBitSize = SmallC->getBitWidth(); + if (IsAnd) + CR = CR->inverse(); - // Check that the low bits are zero. - APInt Low = APInt::getLowBitsSet(BigBitSize, SmallBitSize); - if ((Low & *AndC).isZero() && (Low & *BigC).isZero()) { - Value *NewAnd = Builder.CreateAnd(V, Low | *AndC); - APInt N = SmallC->zext(BigBitSize) | *BigC; - Value *NewVal = ConstantInt::get(NewAnd->getType(), N); - return Builder.CreateICmp(PredL, NewAnd, NewVal); - } - } - } + CmpInst::Predicate NewPred; + APInt NewC, Offset; + CR->getEquivalentICmp(NewPred, NewC, Offset); - return foldAndOrOfICmpsUsingRanges(PredL, LHS0, *LHSC, PredR, RHS0, *RHSC, - Builder, /* IsAnd */ true); + if (Offset != 0) + NewV = Builder.CreateAdd(NewV, ConstantInt::get(Ty, Offset)); + return Builder.CreateICmp(NewPred, NewV, ConstantInt::get(Ty, NewC)); } Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, - bool IsAnd) { + bool IsAnd, bool IsLogicalSelect) { Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); @@ -1380,11 +1235,22 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, unsigned FCmpCodeL = getFCmpCode(PredL); unsigned FCmpCodeR = getFCmpCode(PredR); unsigned NewPred = IsAnd ? FCmpCodeL & FCmpCodeR : FCmpCodeL | FCmpCodeR; + + // Intersect the fast math flags. + // TODO: We can union the fast math flags unless this is a logical select. + IRBuilder<>::FastMathFlagGuard FMFG(Builder); + FastMathFlags FMF = LHS->getFastMathFlags(); + FMF &= RHS->getFastMathFlags(); + Builder.setFastMathFlags(FMF); + return getFCmpValue(NewPred, LHS0, LHS1, Builder); } - if ((PredL == FCmpInst::FCMP_ORD && PredR == FCmpInst::FCMP_ORD && IsAnd) || - (PredL == FCmpInst::FCMP_UNO && PredR == FCmpInst::FCMP_UNO && !IsAnd)) { + // This transform is not valid for a logical select. + if (!IsLogicalSelect && + ((PredL == FCmpInst::FCMP_ORD && PredR == FCmpInst::FCMP_ORD && IsAnd) || + (PredL == FCmpInst::FCMP_UNO && PredR == FCmpInst::FCMP_UNO && + !IsAnd))) { if (LHS0->getType() != RHS0->getType()) return nullptr; @@ -1574,9 +1440,10 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) { Value *Cast1Src = Cast1->getOperand(0); // fold logic(cast(A), cast(B)) -> cast(logic(A, B)) - if (shouldOptimizeCast(Cast0) && shouldOptimizeCast(Cast1)) { + if ((Cast0->hasOneUse() || Cast1->hasOneUse()) && + shouldOptimizeCast(Cast0) && shouldOptimizeCast(Cast1)) { Value *NewOp = Builder.CreateBinOp(LogicOpc, Cast0Src, Cast1Src, - I.getName()); + I.getName()); return CastInst::Create(CastOpcode, NewOp, DestTy); } @@ -1589,9 +1456,8 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) { ICmpInst *ICmp0 = dyn_cast<ICmpInst>(Cast0Src); ICmpInst *ICmp1 = dyn_cast<ICmpInst>(Cast1Src); if (ICmp0 && ICmp1) { - Value *Res = LogicOpc == Instruction::And ? 
foldAndOfICmps(ICmp0, ICmp1, I) - : foldOrOfICmps(ICmp0, ICmp1, I); - if (Res) + if (Value *Res = + foldAndOrOfICmps(ICmp0, ICmp1, I, LogicOpc == Instruction::And)) return CastInst::Create(CastOpcode, Res, DestTy); return nullptr; } @@ -1862,7 +1728,7 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I, Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { Type *Ty = I.getType(); - if (Value *V = SimplifyAndInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyAndInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1930,25 +1796,6 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { return BinaryOperator::CreateOr(And, ConstantInt::get(Ty, Together)); } - // If the mask is only needed on one incoming arm, push the 'and' op up. - if (match(Op0, m_OneUse(m_Xor(m_Value(X), m_Value(Y)))) || - match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { - APInt NotAndMask(~(*C)); - BinaryOperator::BinaryOps BinOp = cast<BinaryOperator>(Op0)->getOpcode(); - if (MaskedValueIsZero(X, NotAndMask, 0, &I)) { - // Not masking anything out for the LHS, move mask to RHS. - // and ({x}or X, Y), C --> {x}or X, (and Y, C) - Value *NewRHS = Builder.CreateAnd(Y, Op1, Y->getName() + ".masked"); - return BinaryOperator::Create(BinOp, X, NewRHS); - } - if (!isa<Constant>(Y) && MaskedValueIsZero(Y, NotAndMask, 0, &I)) { - // Not masking anything out for the RHS, move mask to LHS. - // and ({x}or X, Y), C --> {x}or (and X, C), Y - Value *NewLHS = Builder.CreateAnd(X, Op1, X->getName() + ".masked"); - return BinaryOperator::Create(BinOp, NewLHS, Y); - } - } - unsigned Width = Ty->getScalarSizeInBits(); const APInt *ShiftC; if (match(Op0, m_OneUse(m_SExt(m_AShr(m_Value(X), m_APInt(ShiftC)))))) { @@ -1989,7 +1836,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { // ((C1 OP zext(X)) & C2) -> zext((C1 OP X) & C2) if C2 fits in the // bitwidth of X and OP behaves well when given trunc(C1) and X. - auto isSuitableBinOpcode = [](BinaryOperator *B) { + auto isNarrowableBinOpcode = [](BinaryOperator *B) { switch (B->getOpcode()) { case Instruction::Xor: case Instruction::Or: @@ -2002,22 +1849,125 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { } }; BinaryOperator *BO; - if (match(Op0, m_OneUse(m_BinOp(BO))) && isSuitableBinOpcode(BO)) { + if (match(Op0, m_OneUse(m_BinOp(BO))) && isNarrowableBinOpcode(BO)) { + Instruction::BinaryOps BOpcode = BO->getOpcode(); Value *X; const APInt *C1; // TODO: The one-use restrictions could be relaxed a little if the AND // is going to be removed. + // Try to narrow the 'and' and a binop with constant operand: + // and (bo (zext X), C1), C --> zext (and (bo X, TruncC1), TruncC) if (match(BO, m_c_BinOp(m_OneUse(m_ZExt(m_Value(X))), m_APInt(C1))) && C->isIntN(X->getType()->getScalarSizeInBits())) { unsigned XWidth = X->getType()->getScalarSizeInBits(); Constant *TruncC1 = ConstantInt::get(X->getType(), C1->trunc(XWidth)); Value *BinOp = isa<ZExtInst>(BO->getOperand(0)) - ? Builder.CreateBinOp(BO->getOpcode(), X, TruncC1) - : Builder.CreateBinOp(BO->getOpcode(), TruncC1, X); + ? 
Builder.CreateBinOp(BOpcode, X, TruncC1) + : Builder.CreateBinOp(BOpcode, TruncC1, X); Constant *TruncC = ConstantInt::get(X->getType(), C->trunc(XWidth)); Value *And = Builder.CreateAnd(BinOp, TruncC); return new ZExtInst(And, Ty); } + + // Similar to above: if the mask matches the zext input width, then the + // 'and' can be eliminated, so we can truncate the other variable op: + // and (bo (zext X), Y), C --> zext (bo X, (trunc Y)) + if (isa<Instruction>(BO->getOperand(0)) && + match(BO->getOperand(0), m_OneUse(m_ZExt(m_Value(X)))) && + C->isMask(X->getType()->getScalarSizeInBits())) { + Y = BO->getOperand(1); + Value *TrY = Builder.CreateTrunc(Y, X->getType(), Y->getName() + ".tr"); + Value *NewBO = + Builder.CreateBinOp(BOpcode, X, TrY, BO->getName() + ".narrow"); + return new ZExtInst(NewBO, Ty); + } + // and (bo Y, (zext X)), C --> zext (bo (trunc Y), X) + if (isa<Instruction>(BO->getOperand(1)) && + match(BO->getOperand(1), m_OneUse(m_ZExt(m_Value(X)))) && + C->isMask(X->getType()->getScalarSizeInBits())) { + Y = BO->getOperand(0); + Value *TrY = Builder.CreateTrunc(Y, X->getType(), Y->getName() + ".tr"); + Value *NewBO = + Builder.CreateBinOp(BOpcode, TrY, X, BO->getName() + ".narrow"); + return new ZExtInst(NewBO, Ty); + } + } + + // This is intentionally placed after the narrowing transforms for + // efficiency (transform directly to the narrow logic op if possible). + // If the mask is only needed on one incoming arm, push the 'and' op up. + if (match(Op0, m_OneUse(m_Xor(m_Value(X), m_Value(Y)))) || + match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { + APInt NotAndMask(~(*C)); + BinaryOperator::BinaryOps BinOp = cast<BinaryOperator>(Op0)->getOpcode(); + if (MaskedValueIsZero(X, NotAndMask, 0, &I)) { + // Not masking anything out for the LHS, move mask to RHS. + // and ({x}or X, Y), C --> {x}or X, (and Y, C) + Value *NewRHS = Builder.CreateAnd(Y, Op1, Y->getName() + ".masked"); + return BinaryOperator::Create(BinOp, X, NewRHS); + } + if (!isa<Constant>(Y) && MaskedValueIsZero(Y, NotAndMask, 0, &I)) { + // Not masking anything out for the RHS, move mask to LHS. + // and ({x}or X, Y), C --> {x}or (and X, C), Y + Value *NewLHS = Builder.CreateAnd(X, Op1, X->getName() + ".masked"); + return BinaryOperator::Create(BinOp, NewLHS, Y); + } + } + + // When the mask is a power-of-2 constant and op0 is a shifted-power-of-2 + // constant, test if the shift amount equals the offset bit index: + // (ShiftC << X) & C --> X == (log2(C) - log2(ShiftC)) ? C : 0 + // (ShiftC >> X) & C --> X == (log2(ShiftC) - log2(C)) ? C : 0 + if (C->isPowerOf2() && + match(Op0, m_OneUse(m_LogicalShift(m_Power2(ShiftC), m_Value(X))))) { + int Log2ShiftC = ShiftC->exactLogBase2(); + int Log2C = C->exactLogBase2(); + bool IsShiftLeft = + cast<BinaryOperator>(Op0)->getOpcode() == Instruction::Shl; + int BitNum = IsShiftLeft ? 
Log2C - Log2ShiftC : Log2ShiftC - Log2C; + assert(BitNum >= 0 && "Expected demanded bits to handle impossible mask"); + Value *Cmp = Builder.CreateICmpEQ(X, ConstantInt::get(Ty, BitNum)); + return SelectInst::Create(Cmp, ConstantInt::get(Ty, *C), + ConstantInt::getNullValue(Ty)); + } + + Constant *C1, *C2; + const APInt *C3 = C; + Value *X; + if (C3->isPowerOf2()) { + Constant *Log2C3 = ConstantInt::get(Ty, C3->countTrailingZeros()); + if (match(Op0, m_OneUse(m_LShr(m_Shl(m_ImmConstant(C1), m_Value(X)), + m_ImmConstant(C2)))) && + match(C1, m_Power2())) { + Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1); + Constant *LshrC = ConstantExpr::getAdd(C2, Log2C3); + KnownBits KnownLShrc = computeKnownBits(LshrC, 0, nullptr); + if (KnownLShrc.getMaxValue().ult(Width)) { + // iff C1,C3 is pow2 and C2 + cttz(C3) < BitWidth: + // ((C1 << X) >> C2) & C3 -> X == (cttz(C3)+C2-cttz(C1)) ? C3 : 0 + Constant *CmpC = ConstantExpr::getSub(LshrC, Log2C1); + Value *Cmp = Builder.CreateICmpEQ(X, CmpC); + return SelectInst::Create(Cmp, ConstantInt::get(Ty, *C3), + ConstantInt::getNullValue(Ty)); + } + } + + if (match(Op0, m_OneUse(m_Shl(m_LShr(m_ImmConstant(C1), m_Value(X)), + m_ImmConstant(C2)))) && + match(C1, m_Power2())) { + Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1); + Constant *Cmp = + ConstantExpr::getCompare(ICmpInst::ICMP_ULT, Log2C3, C2); + if (Cmp->isZeroValue()) { + // iff C1,C3 is pow2 and Log2(C3) >= C2: + // ((C1 >> X) << C2) & C3 -> X == (cttz(C1)+C2-cttz(C3)) ? C3 : 0 + Constant *ShlC = ConstantExpr::getAdd(C2, Log2C1); + Constant *CmpC = ConstantExpr::getSub(ShlC, Log2C3); + Value *Cmp = Builder.CreateICmpEQ(X, CmpC); + return SelectInst::Create(Cmp, ConstantInt::get(Ty, *C3), + ConstantInt::getNullValue(Ty)); + } + } } } @@ -2127,32 +2077,50 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { ICmpInst *LHS = dyn_cast<ICmpInst>(Op0); ICmpInst *RHS = dyn_cast<ICmpInst>(Op1); if (LHS && RHS) - if (Value *Res = foldAndOfICmps(LHS, RHS, I)) + if (Value *Res = foldAndOrOfICmps(LHS, RHS, I, /* IsAnd */ true)) return replaceInstUsesWith(I, Res); // TODO: Make this recursive; it's a little tricky because an arbitrary // number of 'and' instructions might have to be created. - if (LHS && match(Op1, m_OneUse(m_And(m_Value(X), m_Value(Y))))) { + if (LHS && match(Op1, m_OneUse(m_LogicalAnd(m_Value(X), m_Value(Y))))) { + bool IsLogical = isa<SelectInst>(Op1); + // LHS & (X && Y) --> (LHS && X) && Y if (auto *Cmp = dyn_cast<ICmpInst>(X)) - if (Value *Res = foldAndOfICmps(LHS, Cmp, I)) - return replaceInstUsesWith(I, Builder.CreateAnd(Res, Y)); + if (Value *Res = + foldAndOrOfICmps(LHS, Cmp, I, /* IsAnd */ true, IsLogical)) + return replaceInstUsesWith(I, IsLogical + ? Builder.CreateLogicalAnd(Res, Y) + : Builder.CreateAnd(Res, Y)); + // LHS & (X && Y) --> X && (LHS & Y) if (auto *Cmp = dyn_cast<ICmpInst>(Y)) - if (Value *Res = foldAndOfICmps(LHS, Cmp, I)) - return replaceInstUsesWith(I, Builder.CreateAnd(Res, X)); - } - if (RHS && match(Op0, m_OneUse(m_And(m_Value(X), m_Value(Y))))) { + if (Value *Res = foldAndOrOfICmps(LHS, Cmp, I, /* IsAnd */ true, + /* IsLogical */ false)) + return replaceInstUsesWith(I, IsLogical + ? 
Builder.CreateLogicalAnd(X, Res) + : Builder.CreateAnd(X, Res)); + } + if (RHS && match(Op0, m_OneUse(m_LogicalAnd(m_Value(X), m_Value(Y))))) { + bool IsLogical = isa<SelectInst>(Op0); + // (X && Y) & RHS --> (X && RHS) && Y if (auto *Cmp = dyn_cast<ICmpInst>(X)) - if (Value *Res = foldAndOfICmps(Cmp, RHS, I)) - return replaceInstUsesWith(I, Builder.CreateAnd(Res, Y)); + if (Value *Res = + foldAndOrOfICmps(Cmp, RHS, I, /* IsAnd */ true, IsLogical)) + return replaceInstUsesWith(I, IsLogical + ? Builder.CreateLogicalAnd(Res, Y) + : Builder.CreateAnd(Res, Y)); + // (X && Y) & RHS --> X && (Y & RHS) if (auto *Cmp = dyn_cast<ICmpInst>(Y)) - if (Value *Res = foldAndOfICmps(Cmp, RHS, I)) - return replaceInstUsesWith(I, Builder.CreateAnd(Res, X)); + if (Value *Res = foldAndOrOfICmps(Cmp, RHS, I, /* IsAnd */ true, + /* IsLogical */ false)) + return replaceInstUsesWith(I, IsLogical + ? Builder.CreateLogicalAnd(X, Res) + : Builder.CreateAnd(X, Res)); } } if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) - if (Value *Res = foldLogicOfFCmps(LHS, RHS, true)) + if (Value *Res = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true)) return replaceInstUsesWith(I, Res); if (Instruction *FoldedFCmps = reassociateFCmps(I, Builder)) @@ -2181,18 +2149,16 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { unsigned FullShift = Ty->getScalarSizeInBits() - 1; if (match(&I, m_c_And(m_OneUse(m_AShr(m_Value(X), m_SpecificInt(FullShift))), m_Value(Y)))) { - Constant *Zero = ConstantInt::getNullValue(Ty); - Value *Cmp = Builder.CreateICmpSLT(X, Zero, "isneg"); - return SelectInst::Create(Cmp, Y, Zero); + Value *IsNeg = Builder.CreateIsNeg(X, "isneg"); + return SelectInst::Create(IsNeg, Y, ConstantInt::getNullValue(Ty)); } // If there's a 'not' of the shifted value, swap the select operands: // ~(iN X s>> (N-1)) & Y --> (X s< 0) ? 0 : Y if (match(&I, m_c_And(m_OneUse(m_Not( m_AShr(m_Value(X), m_SpecificInt(FullShift)))), m_Value(Y)))) { - Constant *Zero = ConstantInt::getNullValue(Ty); - Value *Cmp = Builder.CreateICmpSLT(X, Zero, "isneg"); - return SelectInst::Create(Cmp, Zero, Y); + Value *IsNeg = Builder.CreateIsNeg(X, "isneg"); + return SelectInst::Create(IsNeg, ConstantInt::getNullValue(Ty), Y); } // (~x) & y --> ~(x | (~y)) iff that gets rid of inversions @@ -2505,15 +2471,46 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, return nullptr; } -/// Fold (icmp)|(icmp) if possible. -Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, - BinaryOperator &Or) { - const SimplifyQuery Q = SQ.getWithInstruction(&Or); +// (icmp eq X, 0) | (icmp ult Other, X) -> (icmp ule Other, X-1) +// (icmp ne X, 0) & (icmp uge Other, X) -> (icmp ugt Other, X-1) +Value *foldAndOrOfICmpEqZeroAndICmp(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, + IRBuilderBase &Builder) { + ICmpInst::Predicate LPred = + IsAnd ? LHS->getInversePredicate() : LHS->getPredicate(); + ICmpInst::Predicate RPred = + IsAnd ? RHS->getInversePredicate() : RHS->getPredicate(); + Value *LHS0 = LHS->getOperand(0); + if (LPred != ICmpInst::ICMP_EQ || !match(LHS->getOperand(1), m_Zero()) || + !LHS0->getType()->isIntOrIntVectorTy() || + !(LHS->hasOneUse() || RHS->hasOneUse())) + return nullptr; + + Value *Other; + if (RPred == ICmpInst::ICMP_ULT && RHS->getOperand(1) == LHS0) + Other = RHS->getOperand(0); + else if (RPred == ICmpInst::ICMP_UGT && RHS->getOperand(0) == LHS0) + Other = RHS->getOperand(1); + else + return nullptr; + + return Builder.CreateICmp( + IsAnd ? 
ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE, + Builder.CreateAdd(LHS0, Constant::getAllOnesValue(LHS0->getType())), + Other); +} + +/// Fold (icmp)&(icmp) or (icmp)|(icmp) if possible. +/// If IsLogical is true, then the and/or is in select form and the transform +/// must be poison-safe. +Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, + Instruction &I, bool IsAnd, + bool IsLogical) { + const SimplifyQuery Q = SQ.getWithInstruction(&I); // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) + // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) // if K1 and K2 are a one-bit mask. - if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, &Or, - /* IsAnd */ false)) + if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, &I, IsAnd, IsLogical)) return V; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); @@ -2523,64 +2520,16 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, match(LHS1, m_APInt(LHSC)); match(RHS1, m_APInt(RHSC)); - // Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3) - // --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3) - // The original condition actually refers to the following two ranges: - // [MAX_UINT-C1+1, MAX_UINT-C1+1+C3] and [MAX_UINT-C2+1, MAX_UINT-C2+1+C3] - // We can fold these two ranges if: - // 1) C1 and C2 is unsigned greater than C3. - // 2) The two ranges are separated. - // 3) C1 ^ C2 is one-bit mask. - // 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask. - // This implies all values in the two ranges differ by exactly one bit. - if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) && - PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() && - LHSC->getBitWidth() == RHSC->getBitWidth() && *LHSC == *RHSC) { - - Value *AddOpnd; - const APInt *LAddC, *RAddC; - if (match(LHS0, m_Add(m_Value(AddOpnd), m_APInt(LAddC))) && - match(RHS0, m_Add(m_Specific(AddOpnd), m_APInt(RAddC))) && - LAddC->ugt(*LHSC) && RAddC->ugt(*LHSC)) { - - APInt DiffC = *LAddC ^ *RAddC; - if (DiffC.isPowerOf2()) { - const APInt *MaxAddC = nullptr; - if (LAddC->ult(*RAddC)) - MaxAddC = RAddC; - else - MaxAddC = LAddC; - - APInt RRangeLow = -*RAddC; - APInt RRangeHigh = RRangeLow + *LHSC; - APInt LRangeLow = -*LAddC; - APInt LRangeHigh = LRangeLow + *LHSC; - APInt LowRangeDiff = RRangeLow ^ LRangeLow; - APInt HighRangeDiff = RRangeHigh ^ LRangeHigh; - APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow - : RRangeLow - LRangeLow; - - if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff && - RangeDiff.ugt(*LHSC)) { - Type *Ty = AddOpnd->getType(); - Value *MaskC = ConstantInt::get(Ty, ~DiffC); - - Value *NewAnd = Builder.CreateAnd(AddOpnd, MaskC); - Value *NewAdd = Builder.CreateAdd(NewAnd, - ConstantInt::get(Ty, *MaxAddC)); - return Builder.CreateICmp(LHS->getPredicate(), NewAdd, - ConstantInt::get(Ty, *LHSC)); - } - } - } - } - // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B) + // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B) if (predicatesFoldable(PredL, PredR)) { - if (LHS0 == RHS1 && LHS1 == RHS0) - LHS->swapOperands(); + if (LHS0 == RHS1 && LHS1 == RHS0) { + PredL = ICmpInst::getSwappedPredicate(PredL); + std::swap(LHS0, LHS1); + } if (LHS0 == RHS0 && LHS1 == RHS1) { - unsigned Code = getICmpCode(LHS) | getICmpCode(RHS); + unsigned Code = IsAnd ? 
getICmpCode(PredL) & getICmpCode(PredR) + : getICmpCode(PredL) | getICmpCode(PredR); bool IsSigned = LHS->isSigned() || RHS->isSigned(); return getNewICmpValue(Code, IsSigned, LHS0, LHS1, Builder); } @@ -2588,68 +2537,70 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // handle (roughly): // (icmp ne (A & B), C) | (icmp ne (A & D), E) - if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, false, Builder)) + // (icmp eq (A & B), C) & (icmp eq (A & D), E) + if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder)) return V; - if (LHS->hasOneUse() || RHS->hasOneUse()) { - // (icmp eq B, 0) | (icmp ult A, B) -> (icmp ule A, B-1) - // (icmp eq B, 0) | (icmp ugt B, A) -> (icmp ule A, B-1) - Value *A = nullptr, *B = nullptr; - if (PredL == ICmpInst::ICMP_EQ && match(LHS1, m_Zero())) { - B = LHS0; - if (PredR == ICmpInst::ICMP_ULT && LHS0 == RHS1) - A = RHS0; - else if (PredR == ICmpInst::ICMP_UGT && LHS0 == RHS0) - A = RHS1; - } - // (icmp ult A, B) | (icmp eq B, 0) -> (icmp ule A, B-1) - // (icmp ugt B, A) | (icmp eq B, 0) -> (icmp ule A, B-1) - else if (PredR == ICmpInst::ICMP_EQ && match(RHS1, m_Zero())) { - B = RHS0; - if (PredL == ICmpInst::ICMP_ULT && RHS0 == LHS1) - A = LHS0; - else if (PredL == ICmpInst::ICMP_UGT && RHS0 == LHS0) - A = LHS1; - } - if (A && B && B->getType()->isIntOrIntVectorTy()) - return Builder.CreateICmp( - ICmpInst::ICMP_UGE, - Builder.CreateAdd(B, Constant::getAllOnesValue(B->getType())), A); - } - - if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, Or, Builder, Q)) + // TODO: One of these directions is fine with logical and/or, the other could + // be supported by inserting freeze. + if (!IsLogical) { + if (Value *V = foldAndOrOfICmpEqZeroAndICmp(LHS, RHS, IsAnd, Builder)) + return V; + if (Value *V = foldAndOrOfICmpEqZeroAndICmp(RHS, LHS, IsAnd, Builder)) + return V; + } + + // TODO: Verify whether this is safe for logical and/or. + if (!IsLogical) { + if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, IsAnd, Builder, Q)) + return V; + if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, IsAnd, Builder, Q)) + return V; + } + + if (Value *V = foldIsPowerOf2OrZero(LHS, RHS, IsAnd, Builder)) return V; - if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, Or, Builder, Q)) + if (Value *V = foldIsPowerOf2OrZero(RHS, LHS, IsAnd, Builder)) return V; - // E.g. (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n - if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/true)) - return V; + // TODO: One of these directions is fine with logical and/or, the other could + // be supported by inserting freeze. + if (!IsLogical) { + // E.g. (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n + // E.g. (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n + if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/!IsAnd)) + return V; - // E.g. (icmp sgt x, n) | (icmp slt x, 0) --> icmp ugt x, n - if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/true)) - return V; + // E.g. (icmp sgt x, n) | (icmp slt x, 0) --> icmp ugt x, n + // E.g. (icmp slt x, n) & (icmp sge x, 0) --> icmp ult x, n + if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/!IsAnd)) + return V; + } - if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, false, Builder)) - return V; + // TODO: Add conjugated or fold, check whether it is safe for logical and/or. 
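The foldAndOrOfICmpEqZeroAndICmp helper introduced a little further up rewrites an equality-with-zero test combined with an unsigned compare into a single compare against X - 1, leaning on unsigned wraparound when X is zero; a self-contained check over 8-bit values:

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned X = 0; X < 256; ++X) {
    for (unsigned Other = 0; Other < 256; ++Other) {
      uint8_t XMinus1 = static_cast<uint8_t>(X - 1); // wraps to 255 when X == 0
      // (icmp eq X, 0) | (icmp ult Other, X)  -->  (icmp ule Other, X - 1)
      assert(((X == 0) || (Other < X)) == (Other <= XMinus1));
      // (icmp ne X, 0) & (icmp uge Other, X)  -->  (icmp ugt Other, X - 1)
      assert(((X != 0) && (Other >= X)) == (Other > XMinus1));
    }
  }
  return 0;
}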
+ if (IsAnd && !IsLogical) + if (Value *V = foldSignedTruncationCheck(LHS, RHS, I, Builder)) + return V; - if (Value *V = foldIsPowerOf2(LHS, RHS, false /* JoinedByAnd */, Builder)) + if (Value *V = foldIsPowerOf2(LHS, RHS, IsAnd, Builder)) return V; - if (Value *X = - foldUnsignedUnderflowCheck(LHS, RHS, /*IsAnd=*/false, Q, Builder)) - return X; - if (Value *X = - foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/false, Q, Builder)) - return X; + // TODO: Verify whether this is safe for logical and/or. + if (!IsLogical) { + if (Value *X = foldUnsignedUnderflowCheck(LHS, RHS, IsAnd, Q, Builder)) + return X; + if (Value *X = foldUnsignedUnderflowCheck(RHS, LHS, IsAnd, Q, Builder)) + return X; + } - if (Value *X = foldEqOfParts(LHS, RHS, /*IsAnd=*/false)) + if (Value *X = foldEqOfParts(LHS, RHS, IsAnd)) return X; // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0) + // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) // TODO: Remove this when foldLogOpOfMaskedICmps can handle undefs. - if (PredL == ICmpInst::ICMP_NE && match(LHS1, m_ZeroInt()) && - PredR == ICmpInst::ICMP_NE && match(RHS1, m_ZeroInt()) && + if (!IsLogical && PredL == (IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) && + PredL == PredR && match(LHS1, m_ZeroInt()) && match(RHS1, m_ZeroInt()) && LHS0->getType() == RHS0->getType()) { Value *NewOr = Builder.CreateOr(LHS0, RHS0); return Builder.CreateICmp(PredL, NewOr, @@ -2660,15 +2611,83 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (!LHSC || !RHSC) return nullptr; - return foldAndOrOfICmpsUsingRanges(PredL, LHS0, *LHSC, PredR, RHS0, *RHSC, - Builder, /* IsAnd */ false); + // (trunc x) == C1 & (and x, CA) == C2 -> (and x, CA|CMAX) == C1|C2 + // (trunc x) != C1 | (and x, CA) != C2 -> (and x, CA|CMAX) != C1|C2 + // where CMAX is the all ones value for the truncated type, + // iff the lower bits of C2 and CA are zero. + if (PredL == (IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) && + PredL == PredR && LHS->hasOneUse() && RHS->hasOneUse()) { + Value *V; + const APInt *AndC, *SmallC = nullptr, *BigC = nullptr; + + // (trunc x) == C1 & (and x, CA) == C2 + // (and x, CA) == C2 & (trunc x) == C1 + if (match(RHS0, m_Trunc(m_Value(V))) && + match(LHS0, m_And(m_Specific(V), m_APInt(AndC)))) { + SmallC = RHSC; + BigC = LHSC; + } else if (match(LHS0, m_Trunc(m_Value(V))) && + match(RHS0, m_And(m_Specific(V), m_APInt(AndC)))) { + SmallC = LHSC; + BigC = RHSC; + } + + if (SmallC && BigC) { + unsigned BigBitSize = BigC->getBitWidth(); + unsigned SmallBitSize = SmallC->getBitWidth(); + + // Check that the low bits are zero. + APInt Low = APInt::getLowBitsSet(BigBitSize, SmallBitSize); + if ((Low & *AndC).isZero() && (Low & *BigC).isZero()) { + Value *NewAnd = Builder.CreateAnd(V, Low | *AndC); + APInt N = SmallC->zext(BigBitSize) | *BigC; + Value *NewVal = ConstantInt::get(NewAnd->getType(), N); + return Builder.CreateICmp(PredL, NewAnd, NewVal); + } + } + } + + // Match naive pattern (and its inverted form) for checking if two values + // share same sign. 
An example of the pattern: + // (icmp slt (X & Y), 0) | (icmp sgt (X | Y), -1) -> (icmp sgt (X ^ Y), -1) + // Inverted form (example): + // (icmp slt (X | Y), 0) & (icmp sgt (X & Y), -1) -> (icmp slt (X ^ Y), 0) + bool TrueIfSignedL, TrueIfSignedR; + if (InstCombiner::isSignBitCheck(PredL, *LHSC, TrueIfSignedL) && + InstCombiner::isSignBitCheck(PredR, *RHSC, TrueIfSignedR) && + (RHS->hasOneUse() || LHS->hasOneUse())) { + Value *X, *Y; + if (IsAnd) { + if ((TrueIfSignedL && !TrueIfSignedR && + match(LHS0, m_Or(m_Value(X), m_Value(Y))) && + match(RHS0, m_c_And(m_Specific(X), m_Specific(Y)))) || + (!TrueIfSignedL && TrueIfSignedR && + match(LHS0, m_And(m_Value(X), m_Value(Y))) && + match(RHS0, m_c_Or(m_Specific(X), m_Specific(Y))))) { + Value *NewXor = Builder.CreateXor(X, Y); + return Builder.CreateIsNeg(NewXor); + } + } else { + if ((TrueIfSignedL && !TrueIfSignedR && + match(LHS0, m_And(m_Value(X), m_Value(Y))) && + match(RHS0, m_c_Or(m_Specific(X), m_Specific(Y)))) || + (!TrueIfSignedL && TrueIfSignedR && + match(LHS0, m_Or(m_Value(X), m_Value(Y))) && + match(RHS0, m_c_And(m_Specific(X), m_Specific(Y))))) { + Value *NewXor = Builder.CreateXor(X, Y); + return Builder.CreateIsNotNeg(NewXor); + } + } + } + + return foldAndOrOfICmpsUsingRanges(LHS, RHS, IsAnd); } // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { - if (Value *V = SimplifyOrInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyOrInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -2834,6 +2853,14 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { if (match(Op1, m_Xor(m_Specific(B), m_Specific(A)))) return BinaryOperator::CreateOr(Op1, C); + // ((A & B) ^ C) | B -> C | B + if (match(Op0, m_c_Xor(m_c_And(m_Value(A), m_Specific(Op1)), m_Value(C)))) + return BinaryOperator::CreateOr(C, Op1); + + // B | ((A & B) ^ C) -> B | C + if (match(Op1, m_c_Xor(m_c_And(m_Value(A), m_Specific(Op0)), m_Value(C)))) + return BinaryOperator::CreateOr(Op0, C); + // ((B | C) & A) | B -> B | (A & C) if (match(Op0, m_And(m_Or(m_Specific(Op1), m_Value(C)), m_Value(A)))) return BinaryOperator::CreateOr(Op1, Builder.CreateAnd(A, C)); @@ -2895,33 +2922,51 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { ICmpInst *LHS = dyn_cast<ICmpInst>(Op0); ICmpInst *RHS = dyn_cast<ICmpInst>(Op1); if (LHS && RHS) - if (Value *Res = foldOrOfICmps(LHS, RHS, I)) + if (Value *Res = foldAndOrOfICmps(LHS, RHS, I, /* IsAnd */ false)) return replaceInstUsesWith(I, Res); // TODO: Make this recursive; it's a little tricky because an arbitrary // number of 'or' instructions might have to be created. Value *X, *Y; - if (LHS && match(Op1, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { + if (LHS && match(Op1, m_OneUse(m_LogicalOr(m_Value(X), m_Value(Y))))) { + bool IsLogical = isa<SelectInst>(Op1); + // LHS | (X || Y) --> (LHS || X) || Y if (auto *Cmp = dyn_cast<ICmpInst>(X)) - if (Value *Res = foldOrOfICmps(LHS, Cmp, I)) - return replaceInstUsesWith(I, Builder.CreateOr(Res, Y)); + if (Value *Res = + foldAndOrOfICmps(LHS, Cmp, I, /* IsAnd */ false, IsLogical)) + return replaceInstUsesWith(I, IsLogical + ? 
Builder.CreateLogicalOr(Res, Y) + : Builder.CreateOr(Res, Y)); + // LHS | (X || Y) --> X || (LHS | Y) if (auto *Cmp = dyn_cast<ICmpInst>(Y)) - if (Value *Res = foldOrOfICmps(LHS, Cmp, I)) - return replaceInstUsesWith(I, Builder.CreateOr(Res, X)); - } - if (RHS && match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { + if (Value *Res = foldAndOrOfICmps(LHS, Cmp, I, /* IsAnd */ false, + /* IsLogical */ false)) + return replaceInstUsesWith(I, IsLogical + ? Builder.CreateLogicalOr(X, Res) + : Builder.CreateOr(X, Res)); + } + if (RHS && match(Op0, m_OneUse(m_LogicalOr(m_Value(X), m_Value(Y))))) { + bool IsLogical = isa<SelectInst>(Op0); + // (X || Y) | RHS --> (X || RHS) || Y if (auto *Cmp = dyn_cast<ICmpInst>(X)) - if (Value *Res = foldOrOfICmps(Cmp, RHS, I)) - return replaceInstUsesWith(I, Builder.CreateOr(Res, Y)); + if (Value *Res = + foldAndOrOfICmps(Cmp, RHS, I, /* IsAnd */ false, IsLogical)) + return replaceInstUsesWith(I, IsLogical + ? Builder.CreateLogicalOr(Res, Y) + : Builder.CreateOr(Res, Y)); + // (X || Y) | RHS --> X || (Y | RHS) if (auto *Cmp = dyn_cast<ICmpInst>(Y)) - if (Value *Res = foldOrOfICmps(Cmp, RHS, I)) - return replaceInstUsesWith(I, Builder.CreateOr(Res, X)); + if (Value *Res = foldAndOrOfICmps(Cmp, RHS, I, /* IsAnd */ false, + /* IsLogical */ false)) + return replaceInstUsesWith(I, IsLogical + ? Builder.CreateLogicalOr(X, Res) + : Builder.CreateOr(X, Res)); } } if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) - if (Value *Res = foldLogicOfFCmps(LHS, RHS, false)) + if (Value *Res = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false)) return replaceInstUsesWith(I, Res); if (Instruction *FoldedFCmps = reassociateFCmps(I, Builder)) @@ -3035,6 +3080,36 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { if (matchSimpleRecurrence(&I, PN, Start, Step) && DT.dominates(Step, PN)) return replaceInstUsesWith(I, Builder.CreateOr(Start, Step)); + // (A & B) | (C | D) or (C | D) | (A & B) + // Can be combined if C or D is of type (A/B & X) + if (match(&I, m_c_Or(m_OneUse(m_And(m_Value(A), m_Value(B))), + m_OneUse(m_Or(m_Value(C), m_Value(D)))))) { + // (A & B) | (C | ?) -> C | (? | (A & B)) + // (A & B) | (C | ?) -> C | (? | (A & B)) + // (A & B) | (C | ?) -> C | (? | (A & B)) + // (A & B) | (C | ?) -> C | (? | (A & B)) + // (C | ?) | (A & B) -> C | (? | (A & B)) + // (C | ?) | (A & B) -> C | (? | (A & B)) + // (C | ?) | (A & B) -> C | (? | (A & B)) + // (C | ?) | (A & B) -> C | (? | (A & B)) + if (match(D, m_OneUse(m_c_And(m_Specific(A), m_Value()))) || + match(D, m_OneUse(m_c_And(m_Specific(B), m_Value())))) + return BinaryOperator::CreateOr( + C, Builder.CreateOr(D, Builder.CreateAnd(A, B))); + // (A & B) | (? | D) -> (? | (A & B)) | D + // (A & B) | (? | D) -> (? | (A & B)) | D + // (A & B) | (? | D) -> (? | (A & B)) | D + // (A & B) | (? | D) -> (? | (A & B)) | D + // (? | D) | (A & B) -> (? | (A & B)) | D + // (? | D) | (A & B) -> (? | (A & B)) | D + // (? | D) | (A & B) -> (? | (A & B)) | D + // (? | D) | (A & B) -> (? 
| (A & B)) | D + if (match(C, m_OneUse(m_c_And(m_Specific(A), m_Value()))) || + match(C, m_OneUse(m_c_And(m_Specific(B), m_Value())))) + return BinaryOperator::CreateOr( + Builder.CreateOr(C, Builder.CreateAnd(A, B)), D); + } + return nullptr; } @@ -3096,26 +3171,26 @@ Value *InstCombinerImpl::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, assert(I.getOpcode() == Instruction::Xor && I.getOperand(0) == LHS && I.getOperand(1) == RHS && "Should be 'xor' with these operands"); - if (predicatesFoldable(LHS->getPredicate(), RHS->getPredicate())) { - if (LHS->getOperand(0) == RHS->getOperand(1) && - LHS->getOperand(1) == RHS->getOperand(0)) - LHS->swapOperands(); - if (LHS->getOperand(0) == RHS->getOperand(0) && - LHS->getOperand(1) == RHS->getOperand(1)) { + ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); + Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); + + if (predicatesFoldable(PredL, PredR)) { + if (LHS0 == RHS1 && LHS1 == RHS0) { + std::swap(LHS0, LHS1); + PredL = ICmpInst::getSwappedPredicate(PredL); + } + if (LHS0 == RHS0 && LHS1 == RHS1) { // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B) - Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1); - unsigned Code = getICmpCode(LHS) ^ getICmpCode(RHS); + unsigned Code = getICmpCode(PredL) ^ getICmpCode(PredR); bool IsSigned = LHS->isSigned() || RHS->isSigned(); - return getNewICmpValue(Code, IsSigned, Op0, Op1, Builder); + return getNewICmpValue(Code, IsSigned, LHS0, LHS1, Builder); } } // TODO: This can be generalized to compares of non-signbits using // decomposeBitTestICmp(). It could be enhanced more by using (something like) // foldLogOpOfMaskedICmps(). - ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); - Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); - Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); if ((LHS->hasOneUse() || RHS->hasOneUse()) && LHS0->getType() == RHS0->getType() && LHS0->getType()->isIntOrIntVectorTy()) { @@ -3124,19 +3199,17 @@ Value *InstCombinerImpl::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, if ((PredL == CmpInst::ICMP_SGT && match(LHS1, m_AllOnes()) && PredR == CmpInst::ICMP_SGT && match(RHS1, m_AllOnes())) || (PredL == CmpInst::ICMP_SLT && match(LHS1, m_Zero()) && - PredR == CmpInst::ICMP_SLT && match(RHS1, m_Zero()))) { - Value *Zero = ConstantInt::getNullValue(LHS0->getType()); - return Builder.CreateICmpSLT(Builder.CreateXor(LHS0, RHS0), Zero); - } + PredR == CmpInst::ICMP_SLT && match(RHS1, m_Zero()))) + return Builder.CreateIsNeg(Builder.CreateXor(LHS0, RHS0)); + // (X > -1) ^ (Y < 0) --> (X ^ Y) > -1 // (X < 0) ^ (Y > -1) --> (X ^ Y) > -1 if ((PredL == CmpInst::ICMP_SGT && match(LHS1, m_AllOnes()) && PredR == CmpInst::ICMP_SLT && match(RHS1, m_Zero())) || (PredL == CmpInst::ICMP_SLT && match(LHS1, m_Zero()) && - PredR == CmpInst::ICMP_SGT && match(RHS1, m_AllOnes()))) { - Value *MinusOne = ConstantInt::getAllOnesValue(LHS0->getType()); - return Builder.CreateICmpSGT(Builder.CreateXor(LHS0, RHS0), MinusOne); - } + PredR == CmpInst::ICMP_SGT && match(RHS1, m_AllOnes()))) + return Builder.CreateIsNotNeg(Builder.CreateXor(LHS0, RHS0)); + } // Instead of trying to imitate the folds for and/or, decompose this 'xor' @@ -3145,10 +3218,10 @@ Value *InstCombinerImpl::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, // // This is based on a truth table definition of xor: // X ^ Y --> (X | Y) & !(X & Y) - if (Value *OrICmp = SimplifyBinOp(Instruction::Or, 
LHS, RHS, SQ)) { + if (Value *OrICmp = simplifyBinOp(Instruction::Or, LHS, RHS, SQ)) { // TODO: If OrICmp is true, then the definition of xor simplifies to !(X&Y). // TODO: If OrICmp is false, the whole thing is false (InstSimplify?). - if (Value *AndICmp = SimplifyBinOp(Instruction::And, LHS, RHS, SQ)) { + if (Value *AndICmp = simplifyBinOp(Instruction::And, LHS, RHS, SQ)) { // TODO: Independently handle cases where the 'and' side is a constant. ICmpInst *X = nullptr, *Y = nullptr; if (OrICmp == LHS && AndICmp == RHS) { @@ -3284,12 +3357,12 @@ static Instruction *canonicalizeAbs(BinaryOperator &Xor, // Op1 = ashr i32 A, 31 ; smear the sign bit // xor (add A, Op1), Op1 ; add -1 and flip bits if negative // --> (A < 0) ? -A : A - Value *Cmp = Builder.CreateICmpSLT(A, ConstantInt::getNullValue(Ty)); + Value *IsNeg = Builder.CreateIsNeg(A); // Copy the nuw/nsw flags from the add to the negate. auto *Add = cast<BinaryOperator>(Op0); - Value *Neg = Builder.CreateNeg(A, "", Add->hasNoUnsignedWrap(), + Value *NegA = Builder.CreateNeg(A, "", Add->hasNoUnsignedWrap(), Add->hasNoSignedWrap()); - return SelectInst::Create(Cmp, Neg, A); + return SelectInst::Create(IsNeg, NegA, A); } return nullptr; } @@ -3475,51 +3548,7 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { } } - // TODO: Remove folds if we canonicalize to intrinsics (see above). - // Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max: - // - // %notx = xor i32 %x, -1 - // %cmp1 = icmp sgt i32 %notx, %y - // %smax = select i1 %cmp1, i32 %notx, i32 %y - // %res = xor i32 %smax, -1 - // => - // %noty = xor i32 %y, -1 - // %cmp2 = icmp slt %x, %noty - // %res = select i1 %cmp2, i32 %x, i32 %noty - // - // Same is applicable for smin/umax/umin. if (NotOp->hasOneUse()) { - Value *LHS, *RHS; - SelectPatternFlavor SPF = matchSelectPattern(NotOp, LHS, RHS).Flavor; - if (SelectPatternResult::isMinOrMax(SPF)) { - // It's possible we get here before the not has been simplified, so make - // sure the input to the not isn't freely invertible. - if (match(LHS, m_Not(m_Value(X))) && !isFreeToInvert(X, X->hasOneUse())) { - Value *NotY = Builder.CreateNot(RHS); - return SelectInst::Create( - Builder.CreateICmp(getInverseMinMaxPred(SPF), X, NotY), X, NotY); - } - - // It's possible we get here before the not has been simplified, so make - // sure the input to the not isn't freely invertible. - if (match(RHS, m_Not(m_Value(Y))) && !isFreeToInvert(Y, Y->hasOneUse())) { - Value *NotX = Builder.CreateNot(LHS); - return SelectInst::Create( - Builder.CreateICmp(getInverseMinMaxPred(SPF), NotX, Y), NotX, Y); - } - - // If both sides are freely invertible, then we can get rid of the xor - // completely. - if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && - isFreeToInvert(RHS, !RHS->hasNUsesOrMore(3))) { - Value *NotLHS = Builder.CreateNot(LHS); - Value *NotRHS = Builder.CreateNot(RHS); - return SelectInst::Create( - Builder.CreateICmp(getInverseMinMaxPred(SPF), NotLHS, NotRHS), - NotLHS, NotRHS); - } - } - // Pull 'not' into operands of select if both operands are one-use compares // or one is one-use compare and the other one is a constant. // Inverting the predicates eliminates the 'not' operation. @@ -3559,7 +3588,7 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. 
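Two of the patterns touched above come down to sign-bit arithmetic on two's-complement values: the xor-of-sign-tests folds and the add/ashr/xor absolute-value idiom in canonicalizeAbs. A small sketch over 16-bit inputs, assuming arithmetic right shift for negative operands as on the usual targets:

#include <cassert>
#include <cstdint>

int main() {
  for (int32_t X = -32768; X <= 32767; ++X) {
    for (int32_t Y : {-32768, -1, 0, 1, 7, 32767}) {
      // (X > -1) ^ (Y > -1)  -->  (X ^ Y) < 0 : signs differ iff the xor is negative.
      assert(((X > -1) ^ (Y > -1)) == ((X ^ Y) < 0));
    }
    // canonicalizeAbs: Smear = X >> 31 replicates the sign bit, so
    // (X + Smear) ^ Smear equals (X < 0 ? -X : X) for values away from INT32_MIN.
    int32_t Smear = X >> 31;
    assert(((X + Smear) ^ Smear) == (X < 0 ? -X : X));
  }
  return 0;
}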
Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { - if (Value *V = SimplifyXorInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyXorInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -3606,8 +3635,20 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { Value *X, *Y; Constant *C1; if (match(Op1, m_Constant(C1))) { - // Use DeMorgan and reassociation to eliminate a 'not' op. Constant *C2; + + if (match(Op0, m_OneUse(m_Or(m_Value(X), m_ImmConstant(C2)))) && + match(C1, m_ImmConstant())) { + // (X | C2) ^ C1 --> (X & ~C2) ^ (C1^C2) + C2 = Constant::replaceUndefsWith( + C2, Constant::getAllOnesValue(C2->getType()->getScalarType())); + Value *And = Builder.CreateAnd( + X, Constant::mergeUndefsWith(ConstantExpr::getNot(C2), C1)); + return BinaryOperator::CreateXor( + And, Constant::mergeUndefsWith(ConstantExpr::getXor(C1, C2), C1)); + } + + // Use DeMorgan and reassociation to eliminate a 'not' op. if (match(Op0, m_OneUse(m_Or(m_Not(m_Value(X)), m_Constant(C2))))) { // (~X | C2) ^ C1 --> ((X & ~C2) ^ -1) ^ C1 --> (X & ~C2) ^ ~C1 Value *And = Builder.CreateAnd(X, ConstantExpr::getNot(C2)); @@ -3629,9 +3670,8 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { *CA == X->getType()->getScalarSizeInBits() - 1 && !match(C1, m_AllOnes())) { assert(!C1->isZeroValue() && "Unexpected xor with 0"); - Value *ICmp = - Builder.CreateICmpSGT(X, Constant::getAllOnesValue(X->getType())); - return SelectInst::Create(ICmp, Op1, Builder.CreateNot(Op1)); + Value *IsNotNeg = Builder.CreateIsNotNeg(X); + return SelectInst::Create(IsNotNeg, Op1, Builder.CreateNot(Op1)); } } @@ -3687,9 +3727,8 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { APInt FoldConst = C1->getValue().lshr(C2->getValue()); FoldConst ^= C3->getValue(); // Prepare the two operands. 
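The new (X | C2) ^ C1 --> (X & ~C2) ^ (C1 ^ C2) rewrite in visitXor above is a pure bitwise identity; a quick C++ sanity check with illustrative names, separate from the patch itself (the lshr fold it sits next to continues below):

  #include <cstdint>

  // Per bit: where c2 is 1 both sides evaluate to ~c1; where c2 is 0 both
  // sides reduce to x ^ c1.
  bool orXorRewriteHolds(uint32_t x, uint32_t c1, uint32_t c2) {
    uint32_t before = (x | c2) ^ c1;
    uint32_t after = (x & ~c2) ^ (c1 ^ c2);
    return before == after;
  }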
- auto *Opnd0 = cast<Instruction>(Builder.CreateLShr(X, C2)); - Opnd0->takeName(cast<Instruction>(Op0)); - Opnd0->setDebugLoc(I.getDebugLoc()); + auto *Opnd0 = Builder.CreateLShr(X, C2); + Opnd0->takeName(Op0); return BinaryOperator::CreateXor(Opnd0, ConstantInt::get(Ty, FoldConst)); } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp index 495493aab4b5..2540e545ae4d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp @@ -12,7 +12,6 @@ #include "InstCombineInternal.h" #include "llvm/IR/Instructions.h" -#include "llvm/Transforms/InstCombine/InstCombiner.h" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 05b28328afbf..67ef2e895b6c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -15,21 +15,18 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/FloatingPointMode.h" #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" -#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/Twine.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumeBundleQueries.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryBuiltins.h" -#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/Attributes.h" @@ -74,7 +71,6 @@ #include <algorithm> #include <cassert> #include <cstdint> -#include <cstring> #include <utility> #include <vector> @@ -108,6 +104,19 @@ static Type *getPromotedType(Type *Ty) { return Ty; } +/// Recognize a memcpy/memmove from a trivially otherwise unused alloca. +/// TODO: This should probably be integrated with visitAllocSites, but that +/// requires a deeper change to allow either unread or unwritten objects. +static bool hasUndefSource(AnyMemTransferInst *MI) { + auto *Src = MI->getRawSource(); + while (isa<GetElementPtrInst>(Src) || isa<BitCastInst>(Src)) { + if (!Src->hasOneUse()) + return false; + Src = cast<Instruction>(Src)->getOperand(0); + } + return isa<AllocaInst>(Src) && Src->hasOneUse(); +} + Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { Align DstAlign = getKnownAlignment(MI->getRawDest(), DL, MI, &AC, &DT); MaybeAlign CopyDstAlign = MI->getDestAlign(); @@ -132,6 +141,14 @@ Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { return MI; } + // If the source is provably undef, the memcpy/memmove doesn't do anything + // (unless the transfer is volatile). + if (hasUndefSource(MI) && !MI->isVolatile()) { + // Set the size of the copy to 0, it will be deleted on the next iteration. + MI->setLength(Constant::getNullValue(MI->getLength()->getType())); + return MI; + } + // If MemCpyInst length is 1/2/4/8 bytes then replace memcpy with // load/store. 
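For the comment just above about replacing a 1/2/4/8-byte memcpy with a load/store pair (the length check itself follows below), this is the source-level shape of the equivalence; a rough sketch only, since the real fold also checks alignment and volatility, and the helper names are made up:

  #include <cstdint>
  #include <cstring>

  // A 4-byte copy between suitably aligned 4-byte objects...
  void copy4_memcpy(uint32_t *dst, const uint32_t *src) {
    std::memcpy(dst, src, 4);
  }

  // ...behaves like a single 32-bit load followed by a 32-bit store, which is
  // what the IR-level rewrite emits.
  void copy4_load_store(uint32_t *dst, const uint32_t *src) {
    uint32_t v = *src; // load
    *dst = v;          // store
  }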
ConstantInt *MemOpLength = dyn_cast<ConstantInt>(MI->getLength()); @@ -241,6 +258,15 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) { return MI; } + // Remove memset with an undef value. + // FIXME: This is technically incorrect because it might overwrite a poison + // value. Change to PoisonValue once #52930 is resolved. + if (isa<UndefValue>(MI->getValue())) { + // Set the size of the copy to 0, it will be deleted on the next iteration. + MI->setLength(Constant::getNullValue(MI->getLength()->getType())); + return MI; + } + // Extract the length and alignment and fill if they are constant. ConstantInt *LenC = dyn_cast<ConstantInt>(MI->getLength()); ConstantInt *FillC = dyn_cast<ConstantInt>(MI->getValue()); @@ -248,7 +274,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) { return nullptr; const uint64_t Len = LenC->getLimitedValue(); assert(Len && "0-sized memory setting should be removed already."); - const Align Alignment = assumeAligned(MI->getDestAlignment()); + const Align Alignment = MI->getDestAlign().valueOrOne(); // If it is an atomic and alignment is less than the size then we will // introduce the unaligned memory access which will be later transformed @@ -769,7 +795,7 @@ static CallInst *canonicalizeConstantArg0ToArg1(CallInst &Call) { /// \p Result and a constant \p Overflow value. static Instruction *createOverflowTuple(IntrinsicInst *II, Value *Result, Constant *Overflow) { - Constant *V[] = {UndefValue::get(Result->getType()), Overflow}; + Constant *V[] = {PoisonValue::get(Result->getType()), Overflow}; StructType *ST = cast<StructType>(II->getType()); Constant *Struct = ConstantStruct::get(ST, V); return InsertValueInst::Create(Struct, Result, 0); @@ -795,6 +821,10 @@ static Optional<bool> getKnownSign(Value *Op, Instruction *CxtI, if (Known.isNegative()) return true; + Value *X, *Y; + if (match(Op, m_NSWSub(m_Value(X), m_Value(Y)))) + return isImpliedByDomCondition(ICmpInst::ICMP_SLT, X, Y, CxtI, DL); + return isImpliedByDomCondition( ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL); } @@ -837,6 +867,67 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II, return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1)) : BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1)); } +/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value. +Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) { + Type *Ty = MinMax1.getType(); + + // We are looking for a tree of: + // max(INT_MIN, min(INT_MAX, add(sext(A), sext(B)))) + // Where the min and max could be reversed + Instruction *MinMax2; + BinaryOperator *AddSub; + const APInt *MinValue, *MaxValue; + if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) { + if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue)))) + return nullptr; + } else if (match(&MinMax1, + m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) { + if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue)))) + return nullptr; + } else + return nullptr; + + // Check that the constants clamp a saturate, and that the new type would be + // sensible to convert to. + if (!(*MaxValue + 1).isPowerOf2() || -*MinValue != *MaxValue + 1) + return nullptr; + // In what bitwidth can this be treated as saturating arithmetics? 
+ unsigned NewBitWidth = (*MaxValue + 1).logBase2() + 1; + // FIXME: This isn't quite right for vectors, but using the scalar type is a + // good first approximation for what should be done there. + if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth)) + return nullptr; + + // Also make sure that the inner min/max and the add/sub have one use. + if (!MinMax2->hasOneUse() || !AddSub->hasOneUse()) + return nullptr; + + // Create the new type (which can be a vector type) + Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth); + + Intrinsic::ID IntrinsicID; + if (AddSub->getOpcode() == Instruction::Add) + IntrinsicID = Intrinsic::sadd_sat; + else if (AddSub->getOpcode() == Instruction::Sub) + IntrinsicID = Intrinsic::ssub_sat; + else + return nullptr; + + // The two operands of the add/sub must be nsw-truncatable to the NewTy. This + // is usually achieved via a sext from a smaller type. + if (ComputeMaxSignificantBits(AddSub->getOperand(0), 0, AddSub) > + NewBitWidth || + ComputeMaxSignificantBits(AddSub->getOperand(1), 0, AddSub) > NewBitWidth) + return nullptr; + + // Finally create and return the sat intrinsic, truncated to the new type + Function *F = Intrinsic::getDeclaration(MinMax1.getModule(), IntrinsicID, NewTy); + Value *AT = Builder.CreateTrunc(AddSub->getOperand(0), NewTy); + Value *BT = Builder.CreateTrunc(AddSub->getOperand(1), NewTy); + Value *Sat = Builder.CreateCall(F, {AT, BT}); + return CastInst::Create(Instruction::SExt, Sat, Ty); +} + /// If we have a clamp pattern like max (min X, 42), 41 -- where the output /// can only be one of two possible constant values -- turn that into a select @@ -879,6 +970,59 @@ static Instruction *foldClampRangeOfTwo(IntrinsicInst *II, return SelectInst::Create(Cmp, ConstantInt::get(II->getType(), *C0), I1); } +/// If this min/max has a constant operand and an operand that is a matching +/// min/max with a constant operand, constant-fold the 2 constant operands. +static Instruction *reassociateMinMaxWithConstants(IntrinsicInst *II) { + Intrinsic::ID MinMaxID = II->getIntrinsicID(); + auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0)); + if (!LHS || LHS->getIntrinsicID() != MinMaxID) + return nullptr; + + Constant *C0, *C1; + if (!match(LHS->getArgOperand(1), m_ImmConstant(C0)) || + !match(II->getArgOperand(1), m_ImmConstant(C1))) + return nullptr; + + // max (max X, C0), C1 --> max X, (max C0, C1) --> max X, NewC + ICmpInst::Predicate Pred = MinMaxIntrinsic::getPredicate(MinMaxID); + Constant *CondC = ConstantExpr::getICmp(Pred, C0, C1); + Constant *NewC = ConstantExpr::getSelect(CondC, C0, C1); + + Module *Mod = II->getModule(); + Function *MinMax = Intrinsic::getDeclaration(Mod, MinMaxID, II->getType()); + return CallInst::Create(MinMax, {LHS->getArgOperand(0), NewC}); +} + +/// If this min/max has a matching min/max operand with a constant, try to push +/// the constant operand into this instruction. This can enable more folds. +static Instruction * +reassociateMinMaxWithConstantInOperand(IntrinsicInst *II, + InstCombiner::BuilderTy &Builder) { + // Match and capture a min/max operand candidate. + Value *X, *Y; + Constant *C; + Instruction *Inner; + if (!match(II, m_c_MaxOrMin(m_OneUse(m_CombineAnd( + m_Instruction(Inner), + m_MaxOrMin(m_Value(X), m_ImmConstant(C)))), + m_Value(Y)))) + return nullptr; + + // The inner op must match. Check for constants to avoid infinite loops. 
+ Intrinsic::ID MinMaxID = II->getIntrinsicID(); + auto *InnerMM = dyn_cast<IntrinsicInst>(Inner); + if (!InnerMM || InnerMM->getIntrinsicID() != MinMaxID || + match(X, m_ImmConstant()) || match(Y, m_ImmConstant())) + return nullptr; + + // max (max X, C), Y --> max (max X, Y), C + Function *MinMax = + Intrinsic::getDeclaration(II->getModule(), MinMaxID, II->getType()); + Value *NewInner = Builder.CreateBinaryIntrinsic(MinMaxID, X, Y); + NewInner->takeName(Inner); + return CallInst::Create(MinMax, {NewInner, C}); +} + /// Reduce a sequence of min/max intrinsics with a common operand. static Instruction *factorizeMinMaxTree(IntrinsicInst *II) { // Match 3 of the same min/max ops. Example: umin(umin(), umin()). @@ -936,6 +1080,56 @@ static Instruction *factorizeMinMaxTree(IntrinsicInst *II) { return CallInst::Create(MinMax, { MinMaxOp, ThirdOp }); } +/// If all arguments of the intrinsic are unary shuffles with the same mask, +/// try to shuffle after the intrinsic. +static Instruction * +foldShuffledIntrinsicOperands(IntrinsicInst *II, + InstCombiner::BuilderTy &Builder) { + // TODO: This should be extended to handle other intrinsics like fshl, ctpop, + // etc. Use llvm::isTriviallyVectorizable() and related to determine + // which intrinsics are safe to shuffle? + switch (II->getIntrinsicID()) { + case Intrinsic::smax: + case Intrinsic::smin: + case Intrinsic::umax: + case Intrinsic::umin: + case Intrinsic::fma: + case Intrinsic::fshl: + case Intrinsic::fshr: + break; + default: + return nullptr; + } + + Value *X; + ArrayRef<int> Mask; + if (!match(II->getArgOperand(0), + m_Shuffle(m_Value(X), m_Undef(), m_Mask(Mask)))) + return nullptr; + + // At least 1 operand must have 1 use because we are creating 2 instructions. + if (none_of(II->args(), [](Value *V) { return V->hasOneUse(); })) + return nullptr; + + // See if all arguments are shuffled with the same mask. + SmallVector<Value *, 4> NewArgs(II->arg_size()); + NewArgs[0] = X; + Type *SrcTy = X->getType(); + for (unsigned i = 1, e = II->arg_size(); i != e; ++i) { + if (!match(II->getArgOperand(i), + m_Shuffle(m_Value(X), m_Undef(), m_SpecificMask(Mask))) || + X->getType() != SrcTy) + return nullptr; + NewArgs[i] = X; + } + + // intrinsic (shuf X, M), (shuf Y, M), ... --> shuf (intrinsic X, Y, ...), M + Instruction *FPI = isa<FPMathOperator>(II) ? II : nullptr; + Value *NewIntrinsic = + Builder.CreateIntrinsic(II->getIntrinsicID(), SrcTy, NewArgs, FPI); + return new ShuffleVectorInst(NewIntrinsic, Mask); +} + /// CallInst simplification. This mostly only handles folding of intrinsic /// instructions. For normal calls, it allows visitCallBase to do the heavy /// lifting. @@ -943,14 +1137,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // Don't try to simplify calls without uses. It will not do anything useful, // but will result in the following folds being skipped. if (!CI.use_empty()) - if (Value *V = SimplifyCall(&CI, SQ.getWithInstruction(&CI))) + if (Value *V = simplifyCall(&CI, SQ.getWithInstruction(&CI))) return replaceInstUsesWith(CI, V); if (isFreeCall(&CI, &TLI)) return visitFree(CI); - // If the caller function is nounwind, mark the call as nounwind, even if the - // callee isn't. + // If the caller function (i.e. us, the function that contains this CallInst) + // is nounwind, mark the call as nounwind, even if the callee isn't. 
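foldShuffledIntrinsicOperands above uses the fact that an elementwise operation commutes with applying the same permutation to every operand; a scalar-loop sketch of that property for a 4-lane min (types and names are illustrative, mask entries assumed to be in range):

  #include <algorithm>
  #include <array>

  // Applying min lane-wise to two identically shuffled vectors picks the same
  // source lanes as shuffling the lane-wise min of the unshuffled vectors.
  bool shuffleCommutesWithMin(const std::array<int, 4> &x,
                              const std::array<int, 4> &y,
                              const std::array<int, 4> &mask) { // entries in [0, 3]
    std::array<int, 4> minXY;
    for (int i = 0; i < 4; ++i)
      minXY[i] = std::min(x[i], y[i]);               // intrinsic on unshuffled operands
    for (int i = 0; i < 4; ++i)
      if (std::min(x[mask[i]], y[mask[i]]) != minXY[mask[i]])
        return false;                                // never taken for a valid mask
    return true;
  }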
if (CI.getFunction()->doesNotThrow() && !CI.doesNotThrow()) { CI.setDoesNotThrow(); return &CI; @@ -980,13 +1174,6 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Constant *NumBytes = dyn_cast<Constant>(MI->getLength())) { if (NumBytes->isNullValue()) return eraseInstFromFunction(CI); - - if (ConstantInt *CI = dyn_cast<ConstantInt>(NumBytes)) - if (CI->getZExtValue() == 1) { - // Replace the instruction with just byte operations. We would - // transform other cases to loads/stores, but we don't know if - // alignment is sufficient. - } } // No other transformations apply to volatile transfers. @@ -1050,10 +1237,19 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { return NewCall; } + // Unused constrained FP intrinsic calls may have declared side effect, which + // prevents it from being removed. In some cases however the side effect is + // actually absent. To detect this case, call SimplifyConstrainedFPCall. If it + // returns a replacement, the call may be removed. + if (CI.use_empty() && isa<ConstrainedFPIntrinsic>(CI)) { + if (simplifyConstrainedFPCall(&CI, SQ.getWithInstruction(&CI))) + return eraseInstFromFunction(CI); + } + Intrinsic::ID IID = II->getIntrinsicID(); switch (IID) { case Intrinsic::objectsize: - if (Value *V = lowerObjectSizeCall(II, DL, &TLI, /*MustSucceed=*/false)) + if (Value *V = lowerObjectSizeCall(II, DL, &TLI, AA, /*MustSucceed=*/false)) return replaceInstUsesWith(CI, V); return nullptr; case Intrinsic::abs: { @@ -1224,6 +1420,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Instruction *R = FoldOpIntoSelect(*II, Sel)) return R; + if (Instruction *NewMinMax = reassociateMinMaxWithConstants(II)) + return NewMinMax; + + if (Instruction *R = reassociateMinMaxWithConstantInOperand(II, Builder)) + return R; + if (Instruction *NewMinMax = factorizeMinMaxTree(II)) return NewMinMax; @@ -1231,14 +1433,35 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } case Intrinsic::bswap: { Value *IIOperand = II->getArgOperand(0); - Value *X = nullptr; + + // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as + // inverse-shift-of-bswap: + // bswap (shl X, Y) --> lshr (bswap X), Y + // bswap (lshr X, Y) --> shl (bswap X), Y + Value *X, *Y; + if (match(IIOperand, m_OneUse(m_LogicalShift(m_Value(X), m_Value(Y))))) { + // The transform allows undef vector elements, so try a constant match + // first. If knownbits can handle that case, that clause could be removed. + unsigned BitWidth = IIOperand->getType()->getScalarSizeInBits(); + const APInt *C; + if ((match(Y, m_APIntAllowUndef(C)) && (*C & 7) == 0) || + MaskedValueIsZero(Y, APInt::getLowBitsSet(BitWidth, 3))) { + Value *NewSwap = Builder.CreateUnaryIntrinsic(Intrinsic::bswap, X); + BinaryOperator::BinaryOps InverseShift = + cast<BinaryOperator>(IIOperand)->getOpcode() == Instruction::Shl + ? 
Instruction::LShr + : Instruction::Shl; + return BinaryOperator::Create(InverseShift, NewSwap, Y); + } + } KnownBits Known = computeKnownBits(IIOperand, 0, II); uint64_t LZ = alignDown(Known.countMinLeadingZeros(), 8); uint64_t TZ = alignDown(Known.countMinTrailingZeros(), 8); + unsigned BW = Known.getBitWidth(); // bswap(x) -> shift(x) if x has exactly one "active byte" - if (Known.getBitWidth() - LZ - TZ == 8) { + if (BW - LZ - TZ == 8) { assert(LZ != TZ && "active byte cannot be in the middle"); if (LZ > TZ) // -> shl(x) if the "active byte" is in the low part of x return BinaryOperator::CreateNUWShl( @@ -1250,8 +1473,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // bswap(trunc(bswap(x))) -> trunc(lshr(x, c)) if (match(IIOperand, m_Trunc(m_BSwap(m_Value(X))))) { - unsigned C = X->getType()->getScalarSizeInBits() - - IIOperand->getType()->getScalarSizeInBits(); + unsigned C = X->getType()->getScalarSizeInBits() - BW; Value *CV = ConstantInt::get(X->getType(), C); Value *V = Builder.CreateLShr(X, CV); return new TruncInst(V, IIOperand->getType()); @@ -1618,7 +1840,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } // Try to simplify the underlying FMul. - if (Value *V = SimplifyFMulInst(II->getArgOperand(0), II->getArgOperand(1), + if (Value *V = simplifyFMulInst(II->getArgOperand(0), II->getArgOperand(1), II->getFastMathFlags(), SQ.getWithInstruction(II))) { auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2)); @@ -1649,7 +1871,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // Try to simplify the underlying FMul. We can only apply simplifications // that do not require rounding. - if (Value *V = SimplifyFMAFMul(II->getArgOperand(0), II->getArgOperand(1), + if (Value *V = simplifyFMAFMul(II->getArgOperand(0), II->getArgOperand(1), II->getFastMathFlags(), SQ.getWithInstruction(II))) { auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2)); @@ -2135,7 +2357,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } break; } - case Intrinsic::experimental_vector_insert: { + case Intrinsic::vector_insert: { Value *Vec = II->getArgOperand(0); Value *SubVec = II->getArgOperand(1); Value *Idx = II->getArgOperand(2); @@ -2181,7 +2403,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } break; } - case Intrinsic::experimental_vector_extract: { + case Intrinsic::vector_extract: { Value *Vec = II->getArgOperand(0); Value *Idx = II->getArgOperand(1); @@ -2456,11 +2678,15 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { default: { // Handle target specific intrinsics Optional<Instruction *> V = targetInstCombineIntrinsic(*II); - if (V.hasValue()) + if (V) return V.getValue(); break; } } + + if (Instruction *Shuf = foldShuffledIntrinsicOperands(II, Builder)) + return Shuf; + // Some intrinsics (like experimental_gc_statepoint) can be used in invoke // context, so it is handled in visitCallBase and we should trigger it. return visitCallBase(*II); @@ -2648,47 +2874,56 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) { return nullptr; } -void InstCombinerImpl::annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI) { +bool InstCombinerImpl::annotateAnyAllocSite(CallBase &Call, + const TargetLibraryInfo *TLI) { // Note: We only handle cases which can't be driven from generic attributes // here. 
So, for example, nonnull and noalias (which are common properties // of some allocation functions) are expected to be handled via annotation // of the respective allocator declaration with generic attributes. + bool Changed = false; - uint64_t Size; - ObjectSizeOpts Opts; - if (getObjectSize(&Call, Size, DL, TLI, Opts) && Size > 0) { - // TODO: We really should just emit deref_or_null here and then - // let the generic inference code combine that with nonnull. - if (Call.hasRetAttr(Attribute::NonNull)) - Call.addRetAttr(Attribute::getWithDereferenceableBytes( - Call.getContext(), Size)); - else - Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( - Call.getContext(), Size)); + if (isAllocationFn(&Call, TLI)) { + uint64_t Size; + ObjectSizeOpts Opts; + if (getObjectSize(&Call, Size, DL, TLI, Opts) && Size > 0) { + // TODO: We really should just emit deref_or_null here and then + // let the generic inference code combine that with nonnull. + if (Call.hasRetAttr(Attribute::NonNull)) { + Changed = !Call.hasRetAttr(Attribute::Dereferenceable); + Call.addRetAttr( + Attribute::getWithDereferenceableBytes(Call.getContext(), Size)); + } else { + Changed = !Call.hasRetAttr(Attribute::DereferenceableOrNull); + Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Size)); + } + } } // Add alignment attribute if alignment is a power of two constant. Value *Alignment = getAllocAlignment(&Call, TLI); if (!Alignment) - return; + return Changed; ConstantInt *AlignOpC = dyn_cast<ConstantInt>(Alignment); if (AlignOpC && AlignOpC->getValue().ult(llvm::Value::MaximumAlignment)) { uint64_t AlignmentVal = AlignOpC->getZExtValue(); if (llvm::isPowerOf2_64(AlignmentVal)) { - Call.removeRetAttr(Attribute::Alignment); - Call.addRetAttr(Attribute::getWithAlignment(Call.getContext(), - Align(AlignmentVal))); + Align ExistingAlign = Call.getRetAlign().valueOrOne(); + Align NewAlign = Align(AlignmentVal); + if (NewAlign > ExistingAlign) { + Call.addRetAttr( + Attribute::getWithAlignment(Call.getContext(), NewAlign)); + Changed = true; + } } } + return Changed; } /// Improvements for call, callbr and invoke instructions. Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { - if (isAllocationFn(&Call, &TLI)) - annotateAnyAllocSite(Call, &TLI); - - bool Changed = false; + bool Changed = annotateAnyAllocSite(Call, &TLI); // Mark any parameters that are known to be non-null with the nonnull // attribute. This is helpful for inlining calls to functions with null @@ -2718,10 +2953,12 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { // If the callee is a pointer to a function, attempt to move any casts to the // arguments of the call/callbr/invoke. Value *Callee = Call.getCalledOperand(); - if (!isa<Function>(Callee) && transformConstExprCastCall(Call)) + Function *CalleeF = dyn_cast<Function>(Callee); + if ((!CalleeF || CalleeF->getFunctionType() != Call.getFunctionType()) && + transformConstExprCastCall(Call)) return nullptr; - if (Function *CalleeF = dyn_cast<Function>(Callee)) { + if (CalleeF) { // Remove the convergent attr on calls when the callee is not convergent. 
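A standalone check of the bswap canonicalization added a bit earlier in this file (bswap (shl X, Y) --> lshr (bswap X), Y, and the reverse) for shift amounts that are a multiple of 8; plain C++ with a local byte-swap helper rather than any LLVM API, so the names here are not from the patch:

  #include <cstdint>

  static uint32_t bswap32(uint32_t v) {
    return (v >> 24) | ((v >> 8) & 0x0000FF00u) |
           ((v << 8) & 0x00FF0000u) | (v << 24);
  }

  // Swapping bytes after a byte-granular shift is the same as shifting the
  // swapped value the other way.
  bool bswapShiftCommutes(uint32_t x) {
    return bswap32(x << 8) == (bswap32(x) >> 8) &&
           bswap32(x >> 16) == (bswap32(x) << 16);
  }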
if (Call.isConvergent() && !CalleeF->isConvergent() && !CalleeF->isIntrinsic()) { @@ -2905,7 +3142,7 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { Optional<OperandBundleUse> Bundle = GCSP.getOperandBundle(LLVMContext::OB_gc_live); unsigned NumOfGCLives = LiveGcValues.size(); - if (!Bundle.hasValue() || NumOfGCLives == Bundle->Inputs.size()) + if (!Bundle || NumOfGCLives == Bundle->Inputs.size()) break; // We can reduce the size of gc live bundle. DenseMap<Value *, unsigned> Val2Idx; @@ -3026,8 +3263,7 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { // // Similarly, avoid folding away bitcasts of byval calls. if (Callee->getAttributes().hasAttrSomewhere(Attribute::InAlloca) || - Callee->getAttributes().hasAttrSomewhere(Attribute::Preallocated) || - Callee->getAttributes().hasAttrSomewhere(Attribute::ByVal)) + Callee->getAttributes().hasAttrSomewhere(Attribute::Preallocated)) return false; auto AI = Call.arg_begin(); @@ -3038,12 +3274,15 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { if (!CastInst::isBitOrNoopPointerCastable(ActTy, ParamTy, DL)) return false; // Cannot transform this parameter value. + // Check if there are any incompatible attributes we cannot drop safely. if (AttrBuilder(FT->getContext(), CallerPAL.getParamAttrs(i)) - .overlaps(AttributeFuncs::typeIncompatible(ParamTy))) + .overlaps(AttributeFuncs::typeIncompatible( + ParamTy, AttributeFuncs::ASK_UNSAFE_TO_DROP))) return false; // Attribute not compatible with transformed value. - if (Call.isInAllocaArgument(i)) - return false; // Cannot transform to and from inalloca. + if (Call.isInAllocaArgument(i) || + CallerPAL.hasParamAttr(i, Attribute::Preallocated)) + return false; // Cannot transform to and from inalloca/preallocated. if (CallerPAL.hasParamAttr(i, Attribute::SwiftError)) return false; @@ -3052,13 +3291,18 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { // sized type and the sized type has to have the same size as the old type. if (ParamTy != ActTy && CallerPAL.hasParamAttr(i, Attribute::ByVal)) { PointerType *ParamPTy = dyn_cast<PointerType>(ParamTy); - if (!ParamPTy || !ParamPTy->getPointerElementType()->isSized()) + if (!ParamPTy) return false; - Type *CurElTy = Call.getParamByValType(i); - if (DL.getTypeAllocSize(CurElTy) != - DL.getTypeAllocSize(ParamPTy->getPointerElementType())) - return false; + if (!ParamPTy->isOpaque()) { + Type *ParamElTy = ParamPTy->getNonOpaquePointerElementType(); + if (!ParamElTy->isSized()) + return false; + + Type *CurElTy = Call.getParamByValType(i); + if (DL.getTypeAllocSize(CurElTy) != DL.getTypeAllocSize(ParamElTy)) + return false; + } } } @@ -3116,13 +3360,20 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { NewArg = Builder.CreateBitOrPointerCast(*AI, ParamTy); Args.push_back(NewArg); - // Add any parameter attributes. - if (CallerPAL.hasParamAttr(i, Attribute::ByVal)) { - AttrBuilder AB(FT->getContext(), CallerPAL.getParamAttrs(i)); - AB.addByValAttr(NewArg->getType()->getPointerElementType()); + // Add any parameter attributes except the ones incompatible with the new + // type. Note that we made sure all incompatible ones are safe to drop. 
+ AttributeMask IncompatibleAttrs = AttributeFuncs::typeIncompatible( + ParamTy, AttributeFuncs::ASK_SAFE_TO_DROP); + if (CallerPAL.hasParamAttr(i, Attribute::ByVal) && + !ParamTy->isOpaquePointerTy()) { + AttrBuilder AB(Ctx, CallerPAL.getParamAttrs(i).removeAttributes( + Ctx, IncompatibleAttrs)); + AB.addByValAttr(ParamTy->getNonOpaquePointerElementType()); ArgAttrs.push_back(AttributeSet::get(Ctx, AB)); - } else - ArgAttrs.push_back(CallerPAL.getParamAttrs(i)); + } else { + ArgAttrs.push_back( + CallerPAL.getParamAttrs(i).removeAttributes(Ctx, IncompatibleAttrs)); + } } // If the function takes more arguments than the call was taking, add them diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index f11ba8772f3c..e9e779b8619b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -13,13 +13,10 @@ #include "InstCombineInternal.h" #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/ConstantFolding.h" -#include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/IR/DIBuilder.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" -#include <numeric> using namespace llvm; using namespace PatternMatch; @@ -39,8 +36,10 @@ static Value *decomposeSimpleLinearExpr(Value *Val, unsigned &Scale, if (BinaryOperator *I = dyn_cast<BinaryOperator>(Val)) { // Cannot look past anything that might overflow. + // We specifically require nuw because we store the Scale in an unsigned + // and perform an unsigned divide on it. OverflowingBinaryOperator *OBI = dyn_cast<OverflowingBinaryOperator>(Val); - if (OBI && !OBI->hasNoUnsignedWrap() && !OBI->hasNoSignedWrap()) { + if (OBI && !OBI->hasNoUnsignedWrap()) { Scale = 1; Offset = 0; return Val; @@ -639,10 +638,12 @@ Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) { /// Try to narrow the width of math or bitwise logic instructions by pulling a /// truncate ahead of binary operators. -/// TODO: Transforms for truncated shifts should be moved into here. Instruction *InstCombinerImpl::narrowBinOp(TruncInst &Trunc) { Type *SrcTy = Trunc.getSrcTy(); Type *DestTy = Trunc.getType(); + unsigned SrcWidth = SrcTy->getScalarSizeInBits(); + unsigned DestWidth = DestTy->getScalarSizeInBits(); + if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy)) return nullptr; @@ -685,7 +686,30 @@ Instruction *InstCombinerImpl::narrowBinOp(TruncInst &Trunc) { } break; } - + case Instruction::LShr: + case Instruction::AShr: { + // trunc (*shr (trunc A), C) --> trunc(*shr A, C) + Value *A; + Constant *C; + if (match(BinOp0, m_Trunc(m_Value(A))) && match(BinOp1, m_Constant(C))) { + unsigned MaxShiftAmt = SrcWidth - DestWidth; + // If the shift is small enough, all zero/sign bits created by the shift + // are removed by the trunc. + if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, + APInt(SrcWidth, MaxShiftAmt)))) { + auto *OldShift = cast<Instruction>(Trunc.getOperand(0)); + bool IsExact = OldShift->isExact(); + auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true); + ShAmt = Constant::mergeUndefsWith(ShAmt, C); + Value *Shift = + OldShift->getOpcode() == Instruction::AShr + ? 
Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact) + : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact); + return CastInst::CreateTruncOrBitCast(Shift, DestTy); + } + } + break; + } default: break; } @@ -873,26 +897,6 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { // TODO: Mask high bits with 'and'. } - // trunc (*shr (trunc A), C) --> trunc(*shr A, C) - if (match(Src, m_OneUse(m_Shr(m_Trunc(m_Value(A)), m_Constant(C))))) { - unsigned MaxShiftAmt = SrcWidth - DestWidth; - - // If the shift is small enough, all zero/sign bits created by the shift are - // removed by the trunc. - if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, - APInt(SrcWidth, MaxShiftAmt)))) { - auto *OldShift = cast<Instruction>(Src); - bool IsExact = OldShift->isExact(); - auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true); - ShAmt = Constant::mergeUndefsWith(ShAmt, C); - Value *Shift = - OldShift->getOpcode() == Instruction::AShr - ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact) - : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact); - return CastInst::CreateTruncOrBitCast(Shift, DestTy); - } - } - if (Instruction *I = narrowBinOp(Trunc)) return I; @@ -971,7 +975,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { Attribute Attr = Trunc.getFunction()->getFnAttribute(Attribute::VScaleRange); if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { - if (Log2_32(MaxVScale.getValue()) < DestWidth) { + if (Log2_32(*MaxVScale) < DestWidth) { Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); return replaceInstUsesWith(Trunc, VScale); } @@ -986,13 +990,18 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) // If we are just checking for a icmp eq of a single bit and zext'ing it // to an integer, then shift the bit to the appropriate place and then // cast to integer to avoid the comparison. + + // FIXME: This set of transforms does not check for extra uses and/or creates + // an extra instruction (an optional final cast is not included + // in the transform comments). We may also want to favor icmp over + // shifts in cases of equal instructions because icmp has better + // analysis in general (invert the transform). + const APInt *Op1CV; if (match(Cmp->getOperand(1), m_APInt(Op1CV))) { // zext (x <s 0) to i32 --> x>>u31 true if signbit set. - // zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear. 
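The transformZExtICmp comment right above ("zext (x <s 0) to i32 --> x>>u31") comes down to the sign bit, moved into bit zero, being the widened boolean; a one-function C++ check with an illustrative name:

  #include <cstdint>

  // A logical shift by 31 leaves just the sign bit in bit zero, which is the
  // zero-extended result of the signed compare against zero.
  bool zextSignTestHolds(int32_t x) {
    uint32_t viaCmp = (x < 0) ? 1u : 0u;
    uint32_t viaShift = (uint32_t)x >> 31;
    return viaCmp == viaShift;
  }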
- if ((Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isZero()) || - (Cmp->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnes())) { + if (Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isZero()) { Value *In = Cmp->getOperand(0); Value *Sh = ConstantInt::get(In->getType(), In->getType()->getScalarSizeInBits() - 1); @@ -1000,11 +1009,6 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) if (In->getType() != Zext.getType()) In = Builder.CreateIntCast(In, Zext.getType(), false /*ZExt*/); - if (Cmp->getPredicate() == ICmpInst::ICMP_SGT) { - Constant *One = ConstantInt::get(In->getType(), 1); - In = Builder.CreateXor(In, One, In->getName() + ".not"); - } - return replaceInstUsesWith(Zext, In); } @@ -1080,7 +1084,7 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) KnownBits KnownLHS = computeKnownBits(LHS, 0, &Zext); KnownBits KnownRHS = computeKnownBits(RHS, 0, &Zext); - if (KnownLHS.Zero == KnownRHS.Zero && KnownLHS.One == KnownRHS.One) { + if (KnownLHS == KnownRHS) { APInt KnownBits = KnownLHS.Zero | KnownLHS.One; APInt UnknownBit = ~KnownBits; if (UnknownBit.countPopulation() == 1) { @@ -1343,7 +1347,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { Attribute Attr = CI.getFunction()->getFnAttribute(Attribute::VScaleRange); if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { unsigned TypeWidth = Src->getType()->getScalarSizeInBits(); - if (Log2_32(MaxVScale.getValue()) < TypeWidth) { + if (Log2_32(*MaxVScale) < TypeWidth) { Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); return replaceInstUsesWith(CI, VScale); } @@ -1506,10 +1510,8 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { unsigned SrcBitSize = SrcTy->getScalarSizeInBits(); unsigned DestBitSize = DestTy->getScalarSizeInBits(); - // If we know that the value being extended is positive, we can use a zext - // instead. - KnownBits Known = computeKnownBits(Src, 0, &CI); - if (Known.isNonNegative()) + // If the value being extended is zero or positive, use a zext instead. + if (isKnownNonNegative(Src, DL, 0, &AC, &CI, &DT)) return CastInst::Create(Instruction::ZExt, Src, DestTy); // Try to extend the entire expression tree to the wide destination type. @@ -1597,14 +1599,20 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { // Splatting a bit of constant-index across a value: // sext (ashr (trunc iN X to iM), M-1) to iN --> ashr (shl X, N-M), N-1 - // TODO: If the dest type is different, use a cast (adjust use check). + // If the dest type is different, use a cast (adjust use check). 
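The visitSExt change above now uses isKnownNonNegative to turn a sign-extension into a zero-extension; the underlying fact, restated in plain C++ (names illustrative):

  #include <cstdint>

  // With the sign bit known clear, sign- and zero-extension produce the same
  // wide value, so the sext can be replaced by a zext.
  bool extensionsAgree(int32_t x) {
    if (x < 0)
      return true; // the fold only fires when x is known non-negative
    return (int64_t)x == (int64_t)(uint32_t)x;
  }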
if (match(Src, m_OneUse(m_AShr(m_Trunc(m_Value(X)), - m_SpecificInt(SrcBitSize - 1)))) && - X->getType() == DestTy) { - Constant *ShlAmtC = ConstantInt::get(DestTy, DestBitSize - SrcBitSize); - Constant *AshrAmtC = ConstantInt::get(DestTy, DestBitSize - 1); - Value *Shl = Builder.CreateShl(X, ShlAmtC); - return BinaryOperator::CreateAShr(Shl, AshrAmtC); + m_SpecificInt(SrcBitSize - 1))))) { + Type *XTy = X->getType(); + unsigned XBitSize = XTy->getScalarSizeInBits(); + Constant *ShlAmtC = ConstantInt::get(XTy, XBitSize - SrcBitSize); + Constant *AshrAmtC = ConstantInt::get(XTy, XBitSize - 1); + if (XTy == DestTy) + return BinaryOperator::CreateAShr(Builder.CreateShl(X, ShlAmtC), + AshrAmtC); + if (cast<BinaryOperator>(Src)->getOperand(0)->hasOneUse()) { + Value *Ashr = Builder.CreateAShr(Builder.CreateShl(X, ShlAmtC), AshrAmtC); + return CastInst::CreateIntegerCast(Ashr, DestTy, /* isSigned */ true); + } } if (match(Src, m_VScale(DL))) { @@ -1612,7 +1620,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { Attribute Attr = CI.getFunction()->getFnAttribute(Attribute::VScaleRange); if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { - if (Log2_32(MaxVScale.getValue()) < (SrcBitSize - 1)) { + if (Log2_32(*MaxVScale) < (SrcBitSize - 1)) { Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); return replaceInstUsesWith(CI, VScale); } @@ -1712,7 +1720,7 @@ static Type *getMinimumFPType(Value *V) { /// Return true if the cast from integer to FP can be proven to be exact for all /// possible inputs (the conversion does not lose any precision). -static bool isKnownExactCastIntToFP(CastInst &I) { +static bool isKnownExactCastIntToFP(CastInst &I, InstCombinerImpl &IC) { CastInst::CastOps Opcode = I.getOpcode(); assert((Opcode == CastInst::SIToFP || Opcode == CastInst::UIToFP) && "Unexpected cast"); @@ -1749,6 +1757,12 @@ static bool isKnownExactCastIntToFP(CastInst &I) { // TODO: // Try harder to find if the source integer type has less significant bits. // For example, compute number of sign bits or compute low bit mask. + KnownBits SrcKnown = IC.computeKnownBits(Src, 0, &I); + int LowBits = + (int)SrcTy->getScalarSizeInBits() - SrcKnown.countMinLeadingZeros(); + if (LowBits <= DestNumSigBits) + return true; + return false; } @@ -1929,7 +1943,7 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) { Value *Src = FPT.getOperand(0); if (isa<SIToFPInst>(Src) || isa<UIToFPInst>(Src)) { auto *FPCast = cast<CastInst>(Src); - if (isKnownExactCastIntToFP(*FPCast)) + if (isKnownExactCastIntToFP(*FPCast, *this)) return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty); } @@ -1943,7 +1957,7 @@ Instruction *InstCombinerImpl::visitFPExt(CastInst &FPExt) { Value *Src = FPExt.getOperand(0); if (isa<SIToFPInst>(Src) || isa<UIToFPInst>(Src)) { auto *FPCast = cast<CastInst>(Src); - if (isKnownExactCastIntToFP(*FPCast)) + if (isKnownExactCastIntToFP(*FPCast, *this)) return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty); } @@ -1970,13 +1984,13 @@ Instruction *InstCombinerImpl::foldItoFPtoI(CastInst &FI) { // This means this is also safe for a signed input and unsigned output, since // a negative input would lead to undefined behavior. - if (!isKnownExactCastIntToFP(*OpI)) { + if (!isKnownExactCastIntToFP(*OpI, *this)) { // The first cast may not round exactly based on the source integer width // and FP width, but the overflow UB rules can still allow this to fold. 
// If the destination type is narrow, that means the intermediate FP value // must be large enough to hold the source value exactly. // For example, (uint8_t)((float)(uint32_t 16777217) is undefined behavior. - int OutputSize = (int)DestType->getScalarSizeInBits() - IsOutputSigned; + int OutputSize = (int)DestType->getScalarSizeInBits(); if (OutputSize > OpI->getType()->getFPMantissaWidth()) return nullptr; } @@ -2150,14 +2164,10 @@ optimizeVectorResizeWithIntegerBitCasts(Value *InVal, VectorType *DestTy, // Now that the element types match, get the shuffle mask and RHS of the // shuffle to use, which depends on whether we're increasing or decreasing the // size of the input. - SmallVector<int, 16> ShuffleMaskStorage; + auto ShuffleMaskStorage = llvm::to_vector<16>(llvm::seq<int>(0, SrcElts)); ArrayRef<int> ShuffleMask; Value *V2; - // Produce an identify shuffle mask for the src vector. - ShuffleMaskStorage.resize(SrcElts); - std::iota(ShuffleMaskStorage.begin(), ShuffleMaskStorage.end(), 0); - if (SrcElts > DestElts) { // If we're shrinking the number of elements (rewriting an integer // truncate), just shuffle in the elements corresponding to the least @@ -2278,6 +2288,8 @@ static bool collectInsertionElements(Value *V, unsigned Shift, switch (I->getOpcode()) { default: return false; // Unhandled case. case Instruction::BitCast: + if (I->getOperand(0)->getType()->isVectorTy()) + return false; return collectInsertionElements(I->getOperand(0), Shift, Elements, VecEltTy, isBigEndian); case Instruction::ZExt: @@ -2351,21 +2363,28 @@ static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI, /// usually not type-specific like scalar integer or scalar floating-point. static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast, InstCombinerImpl &IC) { - // TODO: Create and use a pattern matcher for ExtractElementInst. - auto *ExtElt = dyn_cast<ExtractElementInst>(BitCast.getOperand(0)); - if (!ExtElt || !ExtElt->hasOneUse()) + Value *VecOp, *Index; + if (!match(BitCast.getOperand(0), + m_OneUse(m_ExtractElt(m_Value(VecOp), m_Value(Index))))) return nullptr; // The bitcast must be to a vectorizable type, otherwise we can't make a new // type to extract from. Type *DestType = BitCast.getType(); - if (!VectorType::isValidElementType(DestType)) - return nullptr; + VectorType *VecType = cast<VectorType>(VecOp->getType()); + if (VectorType::isValidElementType(DestType)) { + auto *NewVecType = VectorType::get(DestType, VecType); + auto *NewBC = IC.Builder.CreateBitCast(VecOp, NewVecType, "bc"); + return ExtractElementInst::Create(NewBC, Index); + } + + // Only solve DestType is vector to avoid inverse transform in visitBitCast. + // bitcast (extractelement <1 x elt>, dest) -> bitcast(<1 x elt>, dest) + auto *FixedVType = dyn_cast<FixedVectorType>(VecType); + if (DestType->isVectorTy() && FixedVType && FixedVType->getNumElements() == 1) + return CastInst::Create(Instruction::BitCast, VecOp, DestType); - auto *NewVecType = VectorType::get(DestType, ExtElt->getVectorOperandType()); - auto *NewBC = IC.Builder.CreateBitCast(ExtElt->getVectorOperand(), - NewVecType, "bc"); - return ExtractElementInst::Create(NewBC, ExtElt->getIndexOperand()); + return nullptr; } /// Change the type of a bitwise logic operation if we can eliminate a bitcast. 
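isKnownExactCastIntToFP, amended above to also consult known bits, is proving that the integer fits inside the FP significand; the concrete property for float's 24-bit significand, sketched in C++ (function name is illustrative):

  #include <cstdint>

  // Any unsigned value below 2^24 fits in float's 24-bit significand, so the
  // int->fp->int round trip is lossless and the cast pair can be removed.
  bool roundTripIsExact(uint32_t x) {
    if (x >= (1u << 24))
      return true; // outside the range this sketch covers
    float f = (float)x;
    return (uint32_t)f == x;
  }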
@@ -2373,8 +2392,8 @@ static Instruction *foldBitCastBitwiseLogic(BitCastInst &BitCast, InstCombiner::BuilderTy &Builder) { Type *DestTy = BitCast.getType(); BinaryOperator *BO; - if (!DestTy->isIntOrIntVectorTy() || - !match(BitCast.getOperand(0), m_OneUse(m_BinOp(BO))) || + + if (!match(BitCast.getOperand(0), m_OneUse(m_BinOp(BO))) || !BO->isBitwiseLogicOp()) return nullptr; @@ -2384,6 +2403,32 @@ static Instruction *foldBitCastBitwiseLogic(BitCastInst &BitCast, if (!DestTy->isVectorTy() || !BO->getType()->isVectorTy()) return nullptr; + if (DestTy->isFPOrFPVectorTy()) { + Value *X, *Y; + // bitcast(logic(bitcast(X), bitcast(Y))) -> bitcast'(logic(bitcast'(X), Y)) + if (match(BO->getOperand(0), m_OneUse(m_BitCast(m_Value(X)))) && + match(BO->getOperand(1), m_OneUse(m_BitCast(m_Value(Y))))) { + if (X->getType()->isFPOrFPVectorTy() && + Y->getType()->isIntOrIntVectorTy()) { + Value *CastedOp = + Builder.CreateBitCast(BO->getOperand(0), Y->getType()); + Value *NewBO = Builder.CreateBinOp(BO->getOpcode(), CastedOp, Y); + return CastInst::CreateBitOrPointerCast(NewBO, DestTy); + } + if (X->getType()->isIntOrIntVectorTy() && + Y->getType()->isFPOrFPVectorTy()) { + Value *CastedOp = + Builder.CreateBitCast(BO->getOperand(1), X->getType()); + Value *NewBO = Builder.CreateBinOp(BO->getOpcode(), CastedOp, X); + return CastInst::CreateBitOrPointerCast(NewBO, DestTy); + } + } + return nullptr; + } + + if (!DestTy->isIntOrIntVectorTy()) + return nullptr; + Value *X; if (match(BO->getOperand(0), m_OneUse(m_BitCast(m_Value(X)))) && X->getType() == DestTy && !isa<Constant>(X)) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index e45be5745fcc..d1f89973caa1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -17,13 +17,11 @@ #include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" @@ -105,10 +103,14 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) { /// /// If AndCst is non-null, then the loaded value is masked with that constant /// before doing the comparison. This handles cases like "A[i]&4 == 0". -Instruction * -InstCombinerImpl::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, - GlobalVariable *GV, CmpInst &ICI, - ConstantInt *AndCst) { +Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( + LoadInst *LI, GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI, + ConstantInt *AndCst) { + if (LI->isVolatile() || LI->getType() != GEP->getResultElementType() || + GV->getValueType() != GEP->getSourceElementType() || + !GV->isConstant() || !GV->hasDefinitiveInitializer()) + return nullptr; + Constant *Init = GV->getInitializer(); if (!isa<ConstantArray>(Init) && !isa<ConstantDataArray>(Init)) return nullptr; @@ -188,8 +190,11 @@ InstCombinerImpl::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, if (!Elt) return nullptr; // If this is indexing an array of structures, get the structure element. 
- if (!LaterIndices.empty()) - Elt = ConstantExpr::getExtractValue(Elt, LaterIndices); + if (!LaterIndices.empty()) { + Elt = ConstantFoldExtractValueInstruction(Elt, LaterIndices); + if (!Elt) + return nullptr; + } // If the element is masked, handle it. if (AndCst) Elt = ConstantExpr::getAnd(Elt, AndCst); @@ -757,7 +762,7 @@ getAsConstantIndexedAddress(Type *ElemTy, Value *V, const DataLayout &DL) { V = GEP->getOperand(0); Constant *GEPIndex = static_cast<Constant *>(GEP->getOperand(1)); Index = ConstantExpr::getAdd( - Index, ConstantExpr::getSExtOrBitCast(GEPIndex, IndexType)); + Index, ConstantExpr::getSExtOrTrunc(GEPIndex, IndexType)); continue; } break; @@ -887,7 +892,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, if (PtrBase != GEPRHS->getOperand(0)) { bool IndicesTheSame = GEPLHS->getNumOperands() == GEPRHS->getNumOperands() && - GEPLHS->getType() == GEPRHS->getType() && + GEPLHS->getPointerOperand()->getType() == + GEPRHS->getPointerOperand()->getType() && GEPLHS->getSourceElementType() == GEPRHS->getSourceElementType(); if (IndicesTheSame) for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i) @@ -950,7 +956,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, return foldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds(); - if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands()) { + if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands() && + GEPLHS->getSourceElementType() == GEPRHS->getSourceElementType()) { // If the GEPs only differ by one index, compare it. unsigned NumDifferences = 0; // Keep track of # differences. unsigned DiffOperand = 0; // The operand that differs. @@ -1001,8 +1008,7 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, } Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI, - const AllocaInst *Alloca, - const Value *Other) { + const AllocaInst *Alloca) { assert(ICI.isEquality() && "Cannot fold non-equality comparison."); // It would be tempting to fold away comparisons between allocas and any @@ -1071,10 +1077,9 @@ Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI, } } - Type *CmpTy = CmpInst::makeCmpResultType(Other->getType()); - return replaceInstUsesWith( - ICI, - ConstantInt::get(CmpTy, !CmpInst::isTrueWhenEqual(ICI.getPredicate()))); + auto *Res = ConstantInt::get(ICI.getType(), + !CmpInst::isTrueWhenEqual(ICI.getPredicate())); + return replaceInstUsesWith(ICI, Res); } /// Fold "icmp pred (X+C), X". @@ -1376,8 +1381,7 @@ Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) { // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) if (Pred == ICmpInst::ICMP_SGT) { Value *A, *B; - SelectPatternResult SPR = matchSelectPattern(Cmp.getOperand(0), A, B); - if (SPR.Flavor == SPF_SMIN) { + if (match(Cmp.getOperand(0), m_SMin(m_Value(A), m_Value(B)))) { if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT)) return new ICmpInst(Pred, B, Cmp.getOperand(1)); if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT)) @@ -1530,7 +1534,7 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) { return nullptr; } -/// Fold icmp (trunc X, Y), C. +/// Fold icmp (trunc X), C. 
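The foldICmpWithZero hunk above ("icmp sgt smin(PosA, B) 0 -> icmp sgt B 0") rests on a simple property of min; a C++ restatement under the fold's precondition (name illustrative):

  #include <algorithm>

  // If a is known positive, min(a, b) > 0 holds exactly when b > 0.
  bool sminGtZeroFoldHolds(int a, int b) {
    if (a <= 0)
      return true; // precondition of the fold: a must be known positive
    return (std::min(a, b) > 0) == (b > 0);
  }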
Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc, const APInt &C) { @@ -1547,6 +1551,16 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, unsigned DstBits = Trunc->getType()->getScalarSizeInBits(), SrcBits = X->getType()->getScalarSizeInBits(); if (Cmp.isEquality() && Trunc->hasOneUse()) { + // Canonicalize to a mask and wider compare if the wide type is suitable: + // (trunc X to i8) == C --> (X & 0xff) == (zext C) + if (!X->getType()->isVectorTy() && shouldChangeType(DstBits, SrcBits)) { + Constant *Mask = ConstantInt::get(X->getType(), + APInt::getLowBitsSet(SrcBits, DstBits)); + Value *And = Builder.CreateAnd(X, Mask); + Constant *WideC = ConstantInt::get(X->getType(), C.zext(SrcBits)); + return new ICmpInst(Pred, And, WideC); + } + // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all // of the high bits truncated out of x are known. KnownBits Known = computeKnownBits(X, 0, &Cmp); @@ -1865,15 +1879,13 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, // Try to optimize things like "A[i] & 42 == 0" to index computations. Value *X = And->getOperand(0); Value *Y = And->getOperand(1); - if (auto *LI = dyn_cast<LoadInst>(X)) - if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0))) - if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) - if (GV->isConstant() && GV->hasDefinitiveInitializer() && - !LI->isVolatile() && isa<ConstantInt>(Y)) { - ConstantInt *C2 = cast<ConstantInt>(Y); - if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, Cmp, C2)) + if (auto *C2 = dyn_cast<ConstantInt>(Y)) + if (auto *LI = dyn_cast<LoadInst>(X)) + if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0))) + if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) + if (Instruction *Res = + foldCmpLoadFromIndexedGlobal(LI, GEP, GV, Cmp, C2)) return Res; - } if (!Cmp.isEquality()) return nullptr; @@ -2216,22 +2228,41 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, if (Cmp.isEquality() && Shr->isExact() && C.isZero()) return new ICmpInst(Pred, X, Cmp.getOperand(1)); - const APInt *ShiftVal; - if (Cmp.isEquality() && match(Shr->getOperand(0), m_APInt(ShiftVal))) - return foldICmpShrConstConst(Cmp, Shr->getOperand(1), C, *ShiftVal); - - const APInt *ShiftAmt; - if (!match(Shr->getOperand(1), m_APInt(ShiftAmt))) + bool IsAShr = Shr->getOpcode() == Instruction::AShr; + const APInt *ShiftValC; + if (match(Shr->getOperand(0), m_APInt(ShiftValC))) { + if (Cmp.isEquality()) + return foldICmpShrConstConst(Cmp, Shr->getOperand(1), C, *ShiftValC); + + // If the shifted constant is a power-of-2, test the shift amount directly: + // (ShiftValC >> X) >u C --> X <u (LZ(C) - LZ(ShiftValC)) + // (ShiftValC >> X) <u C --> X >=u (LZ(C-1) - LZ(ShiftValC)) + if (!IsAShr && ShiftValC->isPowerOf2() && + (Pred == CmpInst::ICMP_UGT || Pred == CmpInst::ICMP_ULT)) { + bool IsUGT = Pred == CmpInst::ICMP_UGT; + assert(ShiftValC->uge(C) && "Expected simplify of compare"); + assert((IsUGT || !C.isZero()) && "Expected X u< 0 to simplify"); + + unsigned CmpLZ = + IsUGT ? C.countLeadingZeros() : (C - 1).countLeadingZeros(); + unsigned ShiftLZ = ShiftValC->countLeadingZeros(); + Constant *NewC = ConstantInt::get(Shr->getType(), CmpLZ - ShiftLZ); + auto NewPred = IsUGT ? 
CmpInst::ICMP_ULT : CmpInst::ICMP_UGE; + return new ICmpInst(NewPred, Shr->getOperand(1), NewC); + } + } + + const APInt *ShiftAmtC; + if (!match(Shr->getOperand(1), m_APInt(ShiftAmtC))) return nullptr; // Check that the shift amount is in range. If not, don't perform undefined // shifts. When the shift is visited it will be simplified. unsigned TypeBits = C.getBitWidth(); - unsigned ShAmtVal = ShiftAmt->getLimitedValue(TypeBits); + unsigned ShAmtVal = ShiftAmtC->getLimitedValue(TypeBits); if (ShAmtVal >= TypeBits || ShAmtVal == 0) return nullptr; - bool IsAShr = Shr->getOpcode() == Instruction::AShr; bool IsExact = Shr->isExact(); Type *ShrTy = Shr->getType(); // TODO: If we could guarantee that InstSimplify would handle all of the @@ -2256,8 +2287,11 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, } if (Pred == CmpInst::ICMP_UGT) { // icmp ugt (ashr X, ShAmtC), C --> icmp ugt X, ((C + 1) << ShAmtC) - 1 + // 'C + 1 << ShAmtC' can overflow as a signed number, so the 2nd + // clause accounts for that pattern. APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1; - if ((ShiftedC + 1).ashr(ShAmtVal) == (C + 1)) + if ((ShiftedC + 1).ashr(ShAmtVal) == (C + 1) || + (C + 1).shl(ShAmtVal).isMinSignedValue()) return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); } @@ -2337,7 +2371,8 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp, // constant power-of-2 value: // (X % pow2C) sgt/slt 0 const ICmpInst::Predicate Pred = Cmp.getPredicate(); - if (Pred != ICmpInst::ICMP_SGT && Pred != ICmpInst::ICMP_SLT) + if (Pred != ICmpInst::ICMP_SGT && Pred != ICmpInst::ICMP_SLT && + Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE) return nullptr; // TODO: The one-use check is standard because we do not typically want to @@ -2347,7 +2382,15 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp, return nullptr; const APInt *DivisorC; - if (!C.isZero() || !match(SRem->getOperand(1), m_Power2(DivisorC))) + if (!match(SRem->getOperand(1), m_Power2(DivisorC))) + return nullptr; + + // For cmp_sgt/cmp_slt only zero valued C is handled. + // For cmp_eq/cmp_ne only positive valued C is handled. + if (((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT) && + !C.isZero()) || + ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && + !C.isStrictlyPositive())) return nullptr; // Mask off the sign bit and the modulo bits (low-bits). @@ -2356,6 +2399,9 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp, Constant *MaskC = ConstantInt::get(Ty, SignMask | (*DivisorC - 1)); Value *And = Builder.CreateAnd(SRem->getOperand(0), MaskC); + if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) + return new ICmpInst(Pred, And, ConstantInt::get(Ty, C)); + // For 'is positive?' check that the sign-bit is clear and at least 1 masked // bit is set. 
Example: // (i8 X % 32) s> 0 --> (X & 159) s> 0 @@ -2372,26 +2418,30 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp, Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, const APInt &C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = UDiv->getOperand(0); + Value *Y = UDiv->getOperand(1); + Type *Ty = UDiv->getType(); + const APInt *C2; - if (!match(UDiv->getOperand(0), m_APInt(C2))) + if (!match(X, m_APInt(C2))) return nullptr; assert(*C2 != 0 && "udiv 0, X should have been simplified already."); // (icmp ugt (udiv C2, Y), C) -> (icmp ule Y, C2/(C+1)) - Value *Y = UDiv->getOperand(1); - if (Cmp.getPredicate() == ICmpInst::ICMP_UGT) { + if (Pred == ICmpInst::ICMP_UGT) { assert(!C.isMaxValue() && "icmp ugt X, UINT_MAX should have been simplified already."); return new ICmpInst(ICmpInst::ICMP_ULE, Y, - ConstantInt::get(Y->getType(), C2->udiv(C + 1))); + ConstantInt::get(Ty, C2->udiv(C + 1))); } // (icmp ult (udiv C2, Y), C) -> (icmp ugt Y, C2/C) - if (Cmp.getPredicate() == ICmpInst::ICMP_ULT) { + if (Pred == ICmpInst::ICMP_ULT) { assert(C != 0 && "icmp ult X, 0 should have been simplified already."); return new ICmpInst(ICmpInst::ICMP_UGT, Y, - ConstantInt::get(Y->getType(), C2->udiv(C))); + ConstantInt::get(Ty, C2->udiv(C))); } return nullptr; @@ -2401,6 +2451,28 @@ Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp, Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div, const APInt &C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Div->getOperand(0); + Value *Y = Div->getOperand(1); + Type *Ty = Div->getType(); + bool DivIsSigned = Div->getOpcode() == Instruction::SDiv; + + // If unsigned division and the compare constant is bigger than + // UMAX/2 (negative), there's only one pair of values that satisfies an + // equality check, so eliminate the division: + // (X u/ Y) == C --> (X == C) && (Y == 1) + // (X u/ Y) != C --> (X != C) || (Y != 1) + // Similarly, if signed division and the compare constant is exactly SMIN: + // (X s/ Y) == SMIN --> (X == SMIN) && (Y == 1) + // (X s/ Y) != SMIN --> (X != SMIN) || (Y != 1) + if (Cmp.isEquality() && Div->hasOneUse() && C.isSignBitSet() && + (!DivIsSigned || C.isMinSignedValue())) { + Value *XBig = Builder.CreateICmp(Pred, X, ConstantInt::get(Ty, C)); + Value *YOne = Builder.CreateICmp(Pred, Y, ConstantInt::get(Ty, 1)); + auto Logic = Pred == ICmpInst::ICMP_EQ ? Instruction::And : Instruction::Or; + return BinaryOperator::Create(Logic, XBig, YOne); + } + // Fold: icmp pred ([us]div X, C2), C -> range test // Fold this div into the comparison, producing a range check. // Determine, based on the divide type, what the range is being @@ -2408,7 +2480,7 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, // it, otherwise compute the range [low, hi) bounding the new value. // See: InsertRangeTest above for the kinds of replacements possible. const APInt *C2; - if (!match(Div->getOperand(1), m_APInt(C2))) + if (!match(Y, m_APInt(C2))) return nullptr; // FIXME: If the operand types don't match the type of the divide @@ -2419,7 +2491,6 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, // (x /u C2) <u C. Simply casting the operands and result won't // work. :( The if statement below tests that condition and bails // if it finds it. 
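The new equality fold in foldICmpDivConstant above ("(X u/ Y) == C --> (X == C) && (Y == 1)" when C has its sign bit set) works because any divisor of at least 2 halves the largest possible quotient; a C++ check under those preconditions (name illustrative):

  #include <cstdint>

  // With c > UINT32_MAX / 2, any divisor y >= 2 forces x / y <= UINT32_MAX / 2,
  // which is < c, so the quotient equals c only when x == c and y == 1.
  bool udivEqFoldHolds(uint32_t x, uint32_t y, uint32_t c) {
    if (y == 0 || c <= UINT32_MAX / 2)
      return true; // outside the fold's preconditions
    return ((x / y) == c) == (x == c && y == 1);
  }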
- bool DivIsSigned = Div->getOpcode() == Instruction::SDiv; if (!Cmp.isEquality() && DivIsSigned != Cmp.isSigned()) return nullptr; @@ -2441,8 +2512,6 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, // instruction that we're folding. bool ProdOV = (DivIsSigned ? Prod.sdiv(*C2) : Prod.udiv(*C2)) != C; - ICmpInst::Predicate Pred = Cmp.getPredicate(); - // If the division is known to be exact, then there is no remainder from the // divide, so the covered range size is unit, otherwise it is the divisor. APInt RangeSize = Div->isExact() ? APInt(C2->getBitWidth(), 1) : *C2; @@ -2457,7 +2526,7 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, int LoOverflow = 0, HiOverflow = 0; APInt LoBound, HiBound; - if (!DivIsSigned) { // udiv + if (!DivIsSigned) { // udiv // e.g. X/5 op 3 --> [15, 20) LoBound = Prod; HiOverflow = LoOverflow = ProdOV; @@ -2472,7 +2541,7 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, LoBound = -(RangeSize - 1); HiBound = RangeSize; } else if (C.isStrictlyPositive()) { // (X / pos) op pos - LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) + LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) HiOverflow = LoOverflow = ProdOV; if (!HiOverflow) HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true); @@ -2492,18 +2561,19 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, // e.g. X/-5 op 0 --> [-4, 5) LoBound = RangeSize + 1; HiBound = -RangeSize; - if (HiBound == *C2) { // -INTMIN = INTMIN - HiOverflow = 1; // [INTMIN+1, overflow) - HiBound = APInt(); // e.g. X/INTMIN = 0 --> X > INTMIN + if (HiBound == *C2) { // -INTMIN = INTMIN + HiOverflow = 1; // [INTMIN+1, overflow) + HiBound = APInt(); // e.g. X/INTMIN = 0 --> X > INTMIN } } else if (C.isStrictlyPositive()) { // (X / neg) op pos // e.g. X/-5 op 3 --> [-19, -14) HiBound = Prod + 1; HiOverflow = LoOverflow = ProdOV ? -1 : 0; if (!LoOverflow) - LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; - } else { // (X / neg) op neg - LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20) + LoOverflow = + addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1 : 0; + } else { // (X / neg) op neg + LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20) LoOverflow = HiOverflow = ProdOV; if (!HiOverflow) HiOverflow = subWithOverflow(HiBound, Prod, RangeSize, true); @@ -2513,54 +2583,47 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, Pred = ICmpInst::getSwappedPredicate(Pred); } - Value *X = Div->getOperand(0); switch (Pred) { - default: llvm_unreachable("Unhandled icmp opcode!"); - case ICmpInst::ICMP_EQ: - if (LoOverflow && HiOverflow) - return replaceInstUsesWith(Cmp, Builder.getFalse()); - if (HiOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, - ConstantInt::get(Div->getType(), LoBound)); - if (LoOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, - ConstantInt::get(Div->getType(), HiBound)); - return replaceInstUsesWith( - Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, true)); - case ICmpInst::ICMP_NE: - if (LoOverflow && HiOverflow) - return replaceInstUsesWith(Cmp, Builder.getTrue()); - if (HiOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, - ConstantInt::get(Div->getType(), LoBound)); - if (LoOverflow) - return new ICmpInst(DivIsSigned ? 
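The LoBound/HiBound computation here turns `(X /[su] C2) pred C` into a range test on X. A quick standalone check of the two worked examples from the comments (`X/5 == 3` and `X/-5 == 3`), assuming C++ `/` matches sdiv/udiv truncation toward zero:

```cpp
// X u/ 5  == 3 holds exactly for X in [15, 20)
// X s/ -5 == 3 holds exactly for X in [-19, -14)
#include <cstdio>

int main() {
  for (unsigned x = 0; x < 256; ++x)
    if ((x / 5u == 3u) != (x >= 15u && x < 20u)) {
      std::printf("udiv range mismatch at %u\n", x);
      return 1;
    }
  for (int x = -128; x <= 127; ++x)
    if ((x / -5 == 3) != (x >= -19 && x < -14)) {
      std::printf("sdiv range mismatch at %d\n", x);
      return 1;
    }
  std::puts("div-vs-range equivalences verified over i8");
  return 0;
}
```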
ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, - ConstantInt::get(Div->getType(), HiBound)); - return replaceInstUsesWith(Cmp, - insertRangeTest(X, LoBound, HiBound, - DivIsSigned, false)); - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_SLT: - if (LoOverflow == +1) // Low bound is greater than input range. - return replaceInstUsesWith(Cmp, Builder.getTrue()); - if (LoOverflow == -1) // Low bound is less than input range. - return replaceInstUsesWith(Cmp, Builder.getFalse()); - return new ICmpInst(Pred, X, ConstantInt::get(Div->getType(), LoBound)); - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_SGT: - if (HiOverflow == +1) // High bound greater than input range. - return replaceInstUsesWith(Cmp, Builder.getFalse()); - if (HiOverflow == -1) // High bound less than input range. - return replaceInstUsesWith(Cmp, Builder.getTrue()); - if (Pred == ICmpInst::ICMP_UGT) - return new ICmpInst(ICmpInst::ICMP_UGE, X, - ConstantInt::get(Div->getType(), HiBound)); - return new ICmpInst(ICmpInst::ICMP_SGE, X, - ConstantInt::get(Div->getType(), HiBound)); + default: + llvm_unreachable("Unhandled icmp predicate!"); + case ICmpInst::ICMP_EQ: + if (LoOverflow && HiOverflow) + return replaceInstUsesWith(Cmp, Builder.getFalse()); + if (HiOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, + X, ConstantInt::get(Ty, LoBound)); + if (LoOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, + X, ConstantInt::get(Ty, HiBound)); + return replaceInstUsesWith( + Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, true)); + case ICmpInst::ICMP_NE: + if (LoOverflow && HiOverflow) + return replaceInstUsesWith(Cmp, Builder.getTrue()); + if (HiOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, + X, ConstantInt::get(Ty, LoBound)); + if (LoOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, + X, ConstantInt::get(Ty, HiBound)); + return replaceInstUsesWith( + Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, false)); + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLT: + if (LoOverflow == +1) // Low bound is greater than input range. + return replaceInstUsesWith(Cmp, Builder.getTrue()); + if (LoOverflow == -1) // Low bound is less than input range. + return replaceInstUsesWith(Cmp, Builder.getFalse()); + return new ICmpInst(Pred, X, ConstantInt::get(Ty, LoBound)); + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGT: + if (HiOverflow == +1) // High bound greater than input range. + return replaceInstUsesWith(Cmp, Builder.getFalse()); + if (HiOverflow == -1) // High bound less than input range. + return replaceInstUsesWith(Cmp, Builder.getTrue()); + if (Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, HiBound)); + return new ICmpInst(ICmpInst::ICMP_SGE, X, ConstantInt::get(Ty, HiBound)); } return nullptr; @@ -2593,18 +2656,24 @@ Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp, !subWithOverflow(SubResult, *C2, C, Cmp.isSigned())) return new ICmpInst(SwappedPred, Y, ConstantInt::get(Ty, SubResult)); + // X - Y == 0 --> X == Y. + // X - Y != 0 --> X != Y. + // TODO: We allow this with multiple uses as long as the other uses are not + // in phis. The phi use check is guarding against a codegen regression + // for a loop test. If the backend could undo this (and possibly + // subsequent transforms), we would not need this hack. 
+ if (Cmp.isEquality() && C.isZero() && + none_of((Sub->users()), [](const User *U) { return isa<PHINode>(U); })) + return new ICmpInst(Pred, X, Y); + // The following transforms are only worth it if the only user of the subtract // is the icmp. // TODO: This is an artificial restriction for all of the transforms below - // that only need a single replacement icmp. + // that only need a single replacement icmp. Can these use the phi test + // like the transform above here? if (!Sub->hasOneUse()) return nullptr; - // X - Y == 0 --> X == Y. - // X - Y != 0 --> X != Y. - if (Cmp.isEquality() && C.isZero()) - return new ICmpInst(Pred, X, Y); - if (Sub->hasNoSignedWrap()) { // (icmp sgt (sub nsw X, Y), -1) -> (icmp sge X, Y) if (Pred == ICmpInst::ICMP_SGT && C.isAllOnes()) @@ -2855,10 +2924,13 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { ICmpInst::Predicate Pred = Cmp.getPredicate(); Value *Op1 = Cmp.getOperand(1); Value *BCSrcOp = Bitcast->getOperand(0); + Type *SrcType = Bitcast->getSrcTy(); + Type *DstType = Bitcast->getType(); - // Make sure the bitcast doesn't change the number of vector elements. - if (Bitcast->getSrcTy()->getScalarSizeInBits() == - Bitcast->getDestTy()->getScalarSizeInBits()) { + // Make sure the bitcast doesn't change between scalar and vector and + // doesn't change the number of vector elements. + if (SrcType->isVectorTy() == DstType->isVectorTy() && + SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits()) { // Zero-equality and sign-bit checks are preserved through sitofp + bitcast. Value *X; if (match(BCSrcOp, m_SIToFP(m_Value(X)))) { @@ -2903,8 +2975,7 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { Type *XType = X->getType(); // We can't currently handle Power style floating point operations here. - if (!(XType->isPPC_FP128Ty() || BCSrcOp->getType()->isPPC_FP128Ty())) { - + if (!(XType->isPPC_FP128Ty() || SrcType->isPPC_FP128Ty())) { Type *NewType = Builder.getIntNTy(XType->getScalarSizeInBits()); if (auto *XVTy = dyn_cast<VectorType>(XType)) NewType = VectorType::get(NewType, XVTy->getElementCount()); @@ -2922,21 +2993,19 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { // Test to see if the operands of the icmp are casted versions of other // values. If the ptr->ptr cast can be stripped off both arguments, do so. - if (Bitcast->getType()->isPointerTy() && - (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) { + if (DstType->isPointerTy() && (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) { // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast // so eliminate it as well. if (auto *BC2 = dyn_cast<BitCastInst>(Op1)) Op1 = BC2->getOperand(0); - Op1 = Builder.CreateBitCast(Op1, BCSrcOp->getType()); + Op1 = Builder.CreateBitCast(Op1, SrcType); return new ICmpInst(Pred, BCSrcOp, Op1); } const APInt *C; - if (!match(Cmp.getOperand(1), m_APInt(C)) || - !Bitcast->getType()->isIntegerTy() || - !Bitcast->getSrcTy()->isIntOrIntVectorTy()) + if (!match(Cmp.getOperand(1), m_APInt(C)) || !DstType->isIntegerTy() || + !SrcType->isIntOrIntVectorTy()) return nullptr; // If this is checking if all elements of a vector compare are set or not, @@ -2948,9 +3017,8 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { // TODO: Try harder to reduce compare of 2 freely invertible operands? 
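The tightened guard above still permits the sitofp + bitcast folds because sign-bit and zero-equality tests are unaffected by the int-to-FP conversion. A standalone demonstration (memcpy standing in for the bitcast, an i8 source assumed so the conversion is exact):

```cpp
// (bitcast (sitofp X) to i32) <  0  <=>  X <  0
// (bitcast (sitofp X) to i32) == 0  <=>  X == 0
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  for (int i = -128; i <= 127; ++i) {
    float f = static_cast<float>(i);          // sitofp i8 -> float
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);      // bitcast float -> i32
    bool SignOrig = i < 0, SignCast = (bits >> 31) != 0;
    bool ZeroOrig = i == 0, ZeroCast = bits == 0;
    if (SignOrig != SignCast || ZeroOrig != ZeroCast) {
      std::printf("mismatch at %d\n", i);
      return 1;
    }
  }
  std::puts("sitofp+bitcast sign/zero checks verified");
  return 0;
}
```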
if (Cmp.isEquality() && C->isAllOnes() && Bitcast->hasOneUse() && isFreeToInvert(BCSrcOp, BCSrcOp->hasOneUse())) { - Type *ScalarTy = Bitcast->getType(); - Value *Cast = Builder.CreateBitCast(Builder.CreateNot(BCSrcOp), ScalarTy); - return new ICmpInst(Pred, Cast, ConstantInt::getNullValue(ScalarTy)); + Value *Cast = Builder.CreateBitCast(Builder.CreateNot(BCSrcOp), DstType); + return new ICmpInst(Pred, Cast, ConstantInt::getNullValue(DstType)); } // If this is checking if all elements of an extended vector are clear or not, @@ -2978,7 +3046,7 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) { // Check whether every element of Mask is the same constant if (is_splat(Mask)) { - auto *VecTy = cast<VectorType>(BCSrcOp->getType()); + auto *VecTy = cast<VectorType>(SrcType); auto *EltTy = cast<IntegerType>(VecTy->getElementType()); if (C->isSplat(EltTy->getBitWidth())) { // Fold the icmp based on the value of C @@ -3000,83 +3068,31 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { /// where X is some kind of instruction. Instruction *InstCombinerImpl::foldICmpInstWithConstant(ICmpInst &Cmp) { const APInt *C; - if (!match(Cmp.getOperand(1), m_APInt(C))) - return nullptr; - if (auto *BO = dyn_cast<BinaryOperator>(Cmp.getOperand(0))) { - switch (BO->getOpcode()) { - case Instruction::Xor: - if (Instruction *I = foldICmpXorConstant(Cmp, BO, *C)) - return I; - break; - case Instruction::And: - if (Instruction *I = foldICmpAndConstant(Cmp, BO, *C)) - return I; - break; - case Instruction::Or: - if (Instruction *I = foldICmpOrConstant(Cmp, BO, *C)) - return I; - break; - case Instruction::Mul: - if (Instruction *I = foldICmpMulConstant(Cmp, BO, *C)) - return I; - break; - case Instruction::Shl: - if (Instruction *I = foldICmpShlConstant(Cmp, BO, *C)) - return I; - break; - case Instruction::LShr: - case Instruction::AShr: - if (Instruction *I = foldICmpShrConstant(Cmp, BO, *C)) - return I; - break; - case Instruction::SRem: - if (Instruction *I = foldICmpSRemConstant(Cmp, BO, *C)) - return I; - break; - case Instruction::UDiv: - if (Instruction *I = foldICmpUDivConstant(Cmp, BO, *C)) - return I; - LLVM_FALLTHROUGH; - case Instruction::SDiv: - if (Instruction *I = foldICmpDivConstant(Cmp, BO, *C)) + if (match(Cmp.getOperand(1), m_APInt(C))) { + if (auto *BO = dyn_cast<BinaryOperator>(Cmp.getOperand(0))) + if (Instruction *I = foldICmpBinOpWithConstant(Cmp, BO, *C)) return I; - break; - case Instruction::Sub: - if (Instruction *I = foldICmpSubConstant(Cmp, BO, *C)) - return I; - break; - case Instruction::Add: - if (Instruction *I = foldICmpAddConstant(Cmp, BO, *C)) - return I; - break; - default: - break; - } - // TODO: These folds could be refactored to be part of the above calls. - if (Instruction *I = foldICmpBinOpEqualityWithConstant(Cmp, BO, *C)) - return I; - } - // Match against CmpInst LHS being instructions other than binary operators. + if (auto *SI = dyn_cast<SelectInst>(Cmp.getOperand(0))) + // For now, we only support constant integers while folding the + // ICMP(SELECT)) pattern. We can extend this to support vector of integers + // similar to the cases handled by binary ops above. + if (auto *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1))) + if (Instruction *I = foldICmpSelectConstant(Cmp, SI, ConstRHS)) + return I; - if (auto *SI = dyn_cast<SelectInst>(Cmp.getOperand(0))) { - // For now, we only support constant integers while folding the - // ICMP(SELECT)) pattern. 
We can extend this to support vector of integers - // similar to the cases handled by binary ops above. - if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1))) - if (Instruction *I = foldICmpSelectConstant(Cmp, SI, ConstRHS)) + if (auto *TI = dyn_cast<TruncInst>(Cmp.getOperand(0))) + if (Instruction *I = foldICmpTruncConstant(Cmp, TI, *C)) return I; - } - if (auto *TI = dyn_cast<TruncInst>(Cmp.getOperand(0))) { - if (Instruction *I = foldICmpTruncConstant(Cmp, TI, *C)) - return I; + if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0))) + if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, II, *C)) + return I; } - if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0))) - if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, II, *C)) - return I; + if (match(Cmp.getOperand(1), m_APIntAllowUndef(C))) + return foldICmpInstWithConstantAllowUndef(Cmp, *C); return nullptr; } @@ -3233,12 +3249,6 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( case Intrinsic::fshl: case Intrinsic::fshr: if (II->getArgOperand(0) == II->getArgOperand(1)) { - // (rot X, ?) == 0/-1 --> X == 0/-1 - // TODO: This transform is safe to re-use undef elts in a vector, but - // the constant value passed in by the caller doesn't allow that. - if (C.isZero() || C.isAllOnes()) - return new ICmpInst(Pred, II->getArgOperand(0), Cmp.getOperand(1)); - const APInt *RotAmtC; // ror(X, RotAmtC) == C --> X == rol(C, RotAmtC) // rol(X, RotAmtC) == C --> X == ror(C, RotAmtC) @@ -3311,6 +3321,89 @@ static Instruction *foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp) { return nullptr; } +/// Try to fold integer comparisons with a constant operand: icmp Pred X, C +/// where X is some kind of instruction and C is AllowUndef. +/// TODO: Move more folds which allow undef to this function. +Instruction * +InstCombinerImpl::foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp, + const APInt &C) { + const ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0))) { + switch (II->getIntrinsicID()) { + default: + break; + case Intrinsic::fshl: + case Intrinsic::fshr: + if (Cmp.isEquality() && II->getArgOperand(0) == II->getArgOperand(1)) { + // (rot X, ?) == 0/-1 --> X == 0/-1 + if (C.isZero() || C.isAllOnes()) + return new ICmpInst(Pred, II->getArgOperand(0), Cmp.getOperand(1)); + } + break; + } + } + + return nullptr; +} + +/// Fold an icmp with BinaryOp and constant operand: icmp Pred BO, C. 
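Both the rotate-equality fold kept here and the 0/-1 special case moved into the new AllowUndef helper rest on the fact that a rotate only permutes bits. A standalone i8 sketch (hand-rolled rotate helpers, illustrative only):

```cpp
#include <cstdint>
#include <cstdio>

static uint8_t rotl8(uint8_t v, unsigned s) { s &= 7; return (uint8_t)((v << s) | (v >> ((8 - s) & 7))); }
static uint8_t rotr8(uint8_t v, unsigned s) { s &= 7; return (uint8_t)((v >> s) | (v << ((8 - s) & 7))); }

int main() {
  for (unsigned x = 0; x < 256; ++x)
    for (unsigned s = 0; s < 8; ++s) {
      // ror(X, s) == C  <=>  X == rol(C, s)
      for (unsigned c = 0; c < 256; ++c)
        if ((rotr8((uint8_t)x, s) == (uint8_t)c) != ((uint8_t)x == rotl8((uint8_t)c, s))) {
          std::printf("rotate-compare mismatch\n");
          return 1;
        }
      // (rot X, s) == 0 / == -1  <=>  X == 0 / == -1 (rotation permutes bits)
      if (((rotr8((uint8_t)x, s) == 0x00) != (x == 0x00)) ||
          ((rotr8((uint8_t)x, s) == 0xFF) != (x == 0xFF))) {
        std::printf("rotate 0/-1 mismatch\n");
        return 1;
      }
    }
  std::puts("rotate equality folds verified over i8");
  return 0;
}
```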
+Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp, + BinaryOperator *BO, + const APInt &C) { + switch (BO->getOpcode()) { + case Instruction::Xor: + if (Instruction *I = foldICmpXorConstant(Cmp, BO, C)) + return I; + break; + case Instruction::And: + if (Instruction *I = foldICmpAndConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Or: + if (Instruction *I = foldICmpOrConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Mul: + if (Instruction *I = foldICmpMulConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Shl: + if (Instruction *I = foldICmpShlConstant(Cmp, BO, C)) + return I; + break; + case Instruction::LShr: + case Instruction::AShr: + if (Instruction *I = foldICmpShrConstant(Cmp, BO, C)) + return I; + break; + case Instruction::SRem: + if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C)) + return I; + break; + case Instruction::UDiv: + if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C)) + return I; + LLVM_FALLTHROUGH; + case Instruction::SDiv: + if (Instruction *I = foldICmpDivConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Sub: + if (Instruction *I = foldICmpSubConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Add: + if (Instruction *I = foldICmpAddConstant(Cmp, BO, C)) + return I; + break; + default: + break; + } + + // TODO: These folds could be refactored to be part of the above calls. + return foldICmpBinOpEqualityWithConstant(Cmp, BO, C); +} + /// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, IntrinsicInst *II, @@ -3406,64 +3499,6 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) { if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) return NV; break; - case Instruction::Select: { - // If either operand of the select is a constant, we can fold the - // comparison into the select arms, which will cause one to be - // constant folded and the select turned into a bitwise or. - Value *Op1 = nullptr, *Op2 = nullptr; - ConstantInt *CI = nullptr; - - auto SimplifyOp = [&](Value *V) { - Value *Op = nullptr; - if (Constant *C = dyn_cast<Constant>(V)) { - Op = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); - } else if (RHSC->isNullValue()) { - // If null is being compared, check if it can be further simplified. - Op = SimplifyICmpInst(I.getPredicate(), V, RHSC, SQ); - } - return Op; - }; - Op1 = SimplifyOp(LHSI->getOperand(1)); - if (Op1) - CI = dyn_cast<ConstantInt>(Op1); - - Op2 = SimplifyOp(LHSI->getOperand(2)); - if (Op2) - CI = dyn_cast<ConstantInt>(Op2); - - // We only want to perform this transformation if it will not lead to - // additional code. This is true if either both sides of the select - // fold to a constant (in which case the icmp is replaced with a select - // which will usually simplify) or this is the only user of the - // select (in which case we are trading a select+icmp for a simpler - // select+icmp) or all uses of the select can be replaced based on - // dominance information ("Global cases"). - bool Transform = false; - if (Op1 && Op2) - Transform = true; - else if (Op1 || Op2) { - // Local case - if (LHSI->hasOneUse()) - Transform = true; - // Global cases - else if (CI && !CI->isZero()) - // When Op1 is constant try replacing select with second operand. - // Otherwise Op2 is constant and try replacing select with first - // operand. - Transform = - replacedSelectWithOperand(cast<SelectInst>(LHSI), &I, Op1 ? 
2 : 1); - } - if (Transform) { - if (!Op1) - Op1 = Builder.CreateICmp(I.getPredicate(), LHSI->getOperand(1), RHSC, - I.getName()); - if (!Op2) - Op2 = Builder.CreateICmp(I.getPredicate(), LHSI->getOperand(2), RHSC, - I.getName()); - return SelectInst::Create(LHSI->getOperand(0), Op1, Op2); - } - break; - } case Instruction::IntToPtr: // icmp pred inttoptr(X), null -> icmp pred X, 0 if (RHSC->isNullValue() && @@ -3476,19 +3511,72 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) { case Instruction::Load: // Try to optimize things like "A[i] > 4" to index computations. if (GetElementPtrInst *GEP = - dyn_cast<GetElementPtrInst>(LHSI->getOperand(0))) { + dyn_cast<GetElementPtrInst>(LHSI->getOperand(0))) if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) - if (GV->isConstant() && GV->hasDefinitiveInitializer() && - !cast<LoadInst>(LHSI)->isVolatile()) - if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, I)) - return Res; - } + if (Instruction *Res = + foldCmpLoadFromIndexedGlobal(cast<LoadInst>(LHSI), GEP, GV, I)) + return Res; break; } return nullptr; } +Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred, + SelectInst *SI, Value *RHS, + const ICmpInst &I) { + // Try to fold the comparison into the select arms, which will cause the + // select to be converted into a logical and/or. + auto SimplifyOp = [&](Value *Op, bool SelectCondIsTrue) -> Value * { + if (Value *Res = simplifyICmpInst(Pred, Op, RHS, SQ)) + return Res; + if (Optional<bool> Impl = isImpliedCondition(SI->getCondition(), Pred, Op, + RHS, DL, SelectCondIsTrue)) + return ConstantInt::get(I.getType(), *Impl); + return nullptr; + }; + + ConstantInt *CI = nullptr; + Value *Op1 = SimplifyOp(SI->getOperand(1), true); + if (Op1) + CI = dyn_cast<ConstantInt>(Op1); + + Value *Op2 = SimplifyOp(SI->getOperand(2), false); + if (Op2) + CI = dyn_cast<ConstantInt>(Op2); + + // We only want to perform this transformation if it will not lead to + // additional code. This is true if either both sides of the select + // fold to a constant (in which case the icmp is replaced with a select + // which will usually simplify) or this is the only user of the + // select (in which case we are trading a select+icmp for a simpler + // select+icmp) or all uses of the select can be replaced based on + // dominance information ("Global cases"). + bool Transform = false; + if (Op1 && Op2) + Transform = true; + else if (Op1 || Op2) { + // Local case + if (SI->hasOneUse()) + Transform = true; + // Global cases + else if (CI && !CI->isZero()) + // When Op1 is constant try replacing select with second operand. + // Otherwise Op2 is constant and try replacing select with first + // operand. + Transform = replacedSelectWithOperand(SI, &I, Op1 ? 2 : 1); + } + if (Transform) { + if (!Op1) + Op1 = Builder.CreateICmp(Pred, SI->getOperand(1), RHS, I.getName()); + if (!Op2) + Op2 = Builder.CreateICmp(Pred, SI->getOperand(2), RHS, I.getName()); + return SelectInst::Create(SI->getOperand(0), Op1, Op2); + } + + return nullptr; +} + /// Some comparisons can be simplified. /// In this case, we are looking for comparisons that look like /// a check for a lossy truncation. @@ -3756,7 +3844,7 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, // Can we fold (XShAmt+YShAmt) ? 
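The new foldSelectICmp uses simplification plus implied conditions to push the compare into both select arms. A hypothetical example of the kind of collapse this enables, checked exhaustively for i8 (not taken from the commit's tests; the true arm is implied by the select condition, the false arm folds to a constant):

```cpp
// icmp ult (select (icmp ult X, 4), X, 10), 8  -->  icmp ult X, 4
#include <cstdio>

int main() {
  for (unsigned x = 0; x < 256; ++x) {
    unsigned sel = (x < 4) ? x : 10;
    if ((sel < 8) != (x < 4)) {
      std::printf("mismatch at %u\n", x);
      return 1;
    }
  }
  std::puts("select+icmp collapse verified over i8");
  return 0;
}
```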
auto *NewShAmt = dyn_cast_or_null<Constant>( - SimplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false, + simplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false, /*isNUW=*/false, SQ.getWithInstruction(&I))); if (!NewShAmt) return nullptr; @@ -3957,6 +4045,24 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, return new ICmpInst(Pred, X, Builder.CreateNot(Op0)); { + // (Op1 + X) + C u</u>= Op1 --> ~C - X u</u>= Op1 + Constant *C; + if (match(Op0, m_OneUse(m_Add(m_c_Add(m_Specific(Op1), m_Value(X)), + m_ImmConstant(C)))) && + (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) { + Constant *C2 = ConstantExpr::getNot(C); + return new ICmpInst(Pred, Builder.CreateSub(C2, X), Op1); + } + // Op0 u>/u<= (Op0 + X) + C --> Op0 u>/u<= ~C - X + if (match(Op1, m_OneUse(m_Add(m_c_Add(m_Specific(Op0), m_Value(X)), + m_ImmConstant(C)))) && + (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE)) { + Constant *C2 = ConstantExpr::getNot(C); + return new ICmpInst(Pred, Op0, Builder.CreateSub(C2, X)); + } + } + + { // Similar to above: an unsigned overflow comparison may use offset + mask: // ((Op1 + C) & C) u< Op1 --> Op1 != 0 // ((Op1 + C) & C) u>= Op1 --> Op1 == 0 @@ -4114,29 +4220,38 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, // icmp (A + C1), (C + C2) -> icmp A, (C + C3) // s.t. C3 = C2 - C1 if (A && C && NoOp0WrapProblem && NoOp1WrapProblem && - (BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned()) - if (ConstantInt *C1 = dyn_cast<ConstantInt>(B)) - if (ConstantInt *C2 = dyn_cast<ConstantInt>(D)) { - const APInt &AP1 = C1->getValue(); - const APInt &AP2 = C2->getValue(); - if (AP1.isNegative() == AP2.isNegative()) { - APInt AP1Abs = C1->getValue().abs(); - APInt AP2Abs = C2->getValue().abs(); - if (AP1Abs.uge(AP2Abs)) { - ConstantInt *C3 = Builder.getInt(AP1 - AP2); - bool HasNUW = BO0->hasNoUnsignedWrap() && C3->getValue().ule(AP1); - bool HasNSW = BO0->hasNoSignedWrap(); - Value *NewAdd = Builder.CreateAdd(A, C3, "", HasNUW, HasNSW); - return new ICmpInst(Pred, NewAdd, C); - } else { - ConstantInt *C3 = Builder.getInt(AP2 - AP1); - bool HasNUW = BO1->hasNoUnsignedWrap() && C3->getValue().ule(AP2); - bool HasNSW = BO1->hasNoSignedWrap(); - Value *NewAdd = Builder.CreateAdd(C, C3, "", HasNUW, HasNSW); - return new ICmpInst(Pred, A, NewAdd); - } - } + (BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned()) { + const APInt *AP1, *AP2; + // TODO: Support non-uniform vectors. + // TODO: Allow undef passthrough if B AND D's element is undef. 
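The first new fold in this block rewrites an add-of-add unsigned overflow check so the variable operand appears only once. A brute-force i8 check of the `u<` form (the `u>=` form is its negation); standalone sketch only:

```cpp
// (Op1 + X) + C u< Op1  -->  ~C - X u< Op1
#include <cstdint>
#include <cstdio>

int main() {
  for (unsigned op1 = 0; op1 < 256; ++op1)
    for (unsigned x = 0; x < 256; ++x)
      for (unsigned c = 0; c < 256; ++c) {
        uint8_t Lhs = (uint8_t)((op1 + x) + c);   // (Op1 + X) + C, wrapping i8
        uint8_t Rhs = (uint8_t)(~c - x);          // ~C - X, wrapping i8
        if ((Lhs < (uint8_t)op1) != (Rhs < (uint8_t)op1)) {
          std::printf("mismatch op1=%u x=%u c=%u\n", op1, x, c);
          return 1;
        }
      }
  std::puts("(Op1 + X) + C u< Op1  <=>  ~C - X u< Op1 verified over i8");
  return 0;
}
```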
+ if (match(B, m_APIntAllowUndef(AP1)) && match(D, m_APIntAllowUndef(AP2)) && + AP1->isNegative() == AP2->isNegative()) { + APInt AP1Abs = AP1->abs(); + APInt AP2Abs = AP2->abs(); + if (AP1Abs.uge(AP2Abs)) { + APInt Diff = *AP1 - *AP2; + bool HasNUW = BO0->hasNoUnsignedWrap() && Diff.ule(*AP1); + bool HasNSW = BO0->hasNoSignedWrap(); + Constant *C3 = Constant::getIntegerValue(BO0->getType(), Diff); + Value *NewAdd = Builder.CreateAdd(A, C3, "", HasNUW, HasNSW); + return new ICmpInst(Pred, NewAdd, C); + } else { + APInt Diff = *AP2 - *AP1; + bool HasNUW = BO1->hasNoUnsignedWrap() && Diff.ule(*AP2); + bool HasNSW = BO1->hasNoSignedWrap(); + Constant *C3 = Constant::getIntegerValue(BO0->getType(), Diff); + Value *NewAdd = Builder.CreateAdd(C, C3, "", HasNUW, HasNSW); + return new ICmpInst(Pred, A, NewAdd); } + } + Constant *Cst1, *Cst2; + if (match(B, m_ImmConstant(Cst1)) && match(D, m_ImmConstant(Cst2)) && + ICmpInst::isEquality(Pred)) { + Constant *Diff = ConstantExpr::getSub(Cst2, Cst1); + Value *NewAdd = Builder.CreateAdd(C, Diff); + return new ICmpInst(Pred, A, NewAdd); + } + } // Analyze the case when either Op0 or Op1 is a sub instruction. // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null). @@ -4524,18 +4639,21 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { // (A >> C) == (B >> C) --> (A^B) u< (1 << C) // For lshr and ashr pairs. - if ((match(Op0, m_OneUse(m_LShr(m_Value(A), m_ConstantInt(Cst1)))) && - match(Op1, m_OneUse(m_LShr(m_Value(B), m_Specific(Cst1))))) || - (match(Op0, m_OneUse(m_AShr(m_Value(A), m_ConstantInt(Cst1)))) && - match(Op1, m_OneUse(m_AShr(m_Value(B), m_Specific(Cst1)))))) { - unsigned TypeBits = Cst1->getBitWidth(); - unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); + const APInt *AP1, *AP2; + if ((match(Op0, m_OneUse(m_LShr(m_Value(A), m_APIntAllowUndef(AP1)))) && + match(Op1, m_OneUse(m_LShr(m_Value(B), m_APIntAllowUndef(AP2))))) || + (match(Op0, m_OneUse(m_AShr(m_Value(A), m_APIntAllowUndef(AP1)))) && + match(Op1, m_OneUse(m_AShr(m_Value(B), m_APIntAllowUndef(AP2)))))) { + if (AP1 != AP2) + return nullptr; + unsigned TypeBits = AP1->getBitWidth(); + unsigned ShAmt = AP1->getLimitedValue(TypeBits); if (ShAmt < TypeBits && ShAmt != 0) { ICmpInst::Predicate NewPred = Pred == ICmpInst::ICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; Value *Xor = Builder.CreateXor(A, B, I.getName() + ".unshifted"); APInt CmpVal = APInt::getOneBitSet(TypeBits, ShAmt); - return new ICmpInst(NewPred, Xor, Builder.getInt(CmpVal)); + return new ICmpInst(NewPred, Xor, ConstantInt::get(A->getType(), CmpVal)); } } @@ -4665,8 +4783,7 @@ static Instruction *foldICmpWithTrunc(ICmpInst &ICmp, return nullptr; } -static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp, - InstCombiner::BuilderTy &Builder) { +Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) { assert(isa<CastInst>(ICmp.getOperand(0)) && "Expected cast for operand 0"); auto *CastOp0 = cast<CastInst>(ICmp.getOperand(0)); Value *X; @@ -4675,25 +4792,37 @@ static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp, bool IsSignedExt = CastOp0->getOpcode() == Instruction::SExt; bool IsSignedCmp = ICmp.isSigned(); - if (auto *CastOp1 = dyn_cast<CastInst>(ICmp.getOperand(1))) { - // If the signedness of the two casts doesn't agree (i.e. one is a sext - // and the other is a zext), then we can't handle this. - // TODO: This is too strict. We can handle some predicates (equality?). 
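The shift-pair equality fold now matched via m_APIntAllowUndef compares only the bits that survive the shift, which is why a single xor-and-compare suffices. A standalone i8 verification for both lshr and ashr (arithmetic right shift of negative values assumed, as guaranteed by C++20 and all mainstream compilers):

```cpp
// (A >> C) == (B >> C)  -->  (A ^ B) u< (1 << C)
#include <cstdint>
#include <cstdio>

int main() {
  for (unsigned a = 0; a < 256; ++a)
    for (unsigned b = 0; b < 256; ++b)
      for (unsigned c = 1; c < 8; ++c) {            // shift amount in (0, bitwidth)
        bool Fold = (uint8_t)(a ^ b) < (1u << c);
        bool LShr = ((uint8_t)a >> c) == ((uint8_t)b >> c);
        bool AShr = ((int8_t)a >> c) == ((int8_t)b >> c);  // arithmetic shift
        if (LShr != Fold || AShr != Fold) {
          std::printf("mismatch a=%u b=%u c=%u\n", a, b, c);
          return 1;
        }
      }
  std::puts("(A >> C) == (B >> C)  <=>  (A ^ B) u< (1 << C) verified over i8");
  return 0;
}
```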
- if (CastOp0->getOpcode() != CastOp1->getOpcode()) - return nullptr; + + // icmp Pred (ext X), (ext Y) + Value *Y; + if (match(ICmp.getOperand(1), m_ZExtOrSExt(m_Value(Y)))) { + bool IsZext0 = isa<ZExtOperator>(ICmp.getOperand(0)); + bool IsZext1 = isa<ZExtOperator>(ICmp.getOperand(1)); + + // If we have mismatched casts, treat the zext of a non-negative source as + // a sext to simulate matching casts. Otherwise, we are done. + // TODO: Can we handle some predicates (equality) without non-negative? + if (IsZext0 != IsZext1) { + if ((IsZext0 && isKnownNonNegative(X, DL, 0, &AC, &ICmp, &DT)) || + (IsZext1 && isKnownNonNegative(Y, DL, 0, &AC, &ICmp, &DT))) + IsSignedExt = true; + else + return nullptr; + } // Not an extension from the same type? - Value *Y = CastOp1->getOperand(0); Type *XTy = X->getType(), *YTy = Y->getType(); if (XTy != YTy) { // One of the casts must have one use because we are creating a new cast. - if (!CastOp0->hasOneUse() && !CastOp1->hasOneUse()) + if (!ICmp.getOperand(0)->hasOneUse() && !ICmp.getOperand(1)->hasOneUse()) return nullptr; // Extend the narrower operand to the type of the wider operand. + CastInst::CastOps CastOpcode = + IsSignedExt ? Instruction::SExt : Instruction::ZExt; if (XTy->getScalarSizeInBits() < YTy->getScalarSizeInBits()) - X = Builder.CreateCast(CastOp0->getOpcode(), X, YTy); + X = Builder.CreateCast(CastOpcode, X, YTy); else if (YTy->getScalarSizeInBits() < XTy->getScalarSizeInBits()) - Y = Builder.CreateCast(CastOp0->getOpcode(), Y, XTy); + Y = Builder.CreateCast(CastOpcode, Y, XTy); else return nullptr; } @@ -4742,7 +4871,7 @@ static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp, // or could not be determined to be equal (in the case of a constant // expression), so the constant cannot be represented in the shorter type. // All the cases that fold to true or false will have already been handled - // by SimplifyICmpInst, so only deal with the tricky case. + // by simplifyICmpInst, so only deal with the tricky case. 
if (IsSignedCmp || !IsSignedExt || !isa<ConstantInt>(C)) return nullptr; @@ -4811,7 +4940,7 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) { if (Instruction *R = foldICmpWithTrunc(ICmp, Builder)) return R; - return foldICmpWithZextOrSext(ICmp, Builder); + return foldICmpWithZextOrSext(ICmp); } static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) { @@ -5449,35 +5578,23 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { LHS = Op0; Value *X; - if (match(LHS, m_Shl(m_One(), m_Value(X)))) { - APInt ValToCheck = Op0KnownZeroInverted; + const APInt *C1; + if (match(LHS, m_Shl(m_Power2(C1), m_Value(X)))) { Type *XTy = X->getType(); - if (ValToCheck.isPowerOf2()) { - // ((1 << X) & 8) == 0 -> X != 3 - // ((1 << X) & 8) != 0 -> X == 3 - auto *CmpC = ConstantInt::get(XTy, ValToCheck.countTrailingZeros()); - auto NewPred = ICmpInst::getInversePredicate(Pred); - return new ICmpInst(NewPred, X, CmpC); - } else if ((++ValToCheck).isPowerOf2()) { - // ((1 << X) & 7) == 0 -> X >= 3 - // ((1 << X) & 7) != 0 -> X < 3 - auto *CmpC = ConstantInt::get(XTy, ValToCheck.countTrailingZeros()); + unsigned Log2C1 = C1->countTrailingZeros(); + APInt C2 = Op0KnownZeroInverted; + APInt C2Pow2 = (C2 & ~(*C1 - 1)) + *C1; + if (C2Pow2.isPowerOf2()) { + // iff (C1 is pow2) & ((C2 & ~(C1-1)) + C1) is pow2): + // ((C1 << X) & C2) == 0 -> X >= (Log2(C2+C1) - Log2(C1)) + // ((C1 << X) & C2) != 0 -> X < (Log2(C2+C1) - Log2(C1)) + unsigned Log2C2 = C2Pow2.countTrailingZeros(); + auto *CmpC = ConstantInt::get(XTy, Log2C2 - Log2C1); auto NewPred = Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGE : CmpInst::ICMP_ULT; return new ICmpInst(NewPred, X, CmpC); } } - - // Check if the LHS is 8 >>u x and the result is a power of 2 like 1. - const APInt *CI; - if (Op0KnownZeroInverted.isOne() && - match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) { - // ((8 >>u X) & 1) == 0 -> X != 3 - // ((8 >>u X) & 1) != 0 -> X == 3 - unsigned CmpVal = CI->countTrailingZeros(); - auto NewPred = ICmpInst::getInversePredicate(Pred); - return new ICmpInst(NewPred, X, ConstantInt::get(X->getType(), CmpVal)); - } } break; } @@ -5557,6 +5674,28 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { return nullptr; } +/// If one operand of an icmp is effectively a bool (value range of {0,1}), +/// then try to reduce patterns based on that limit. 
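The generalized known-bits fold above replaces the old `1 << X` special cases. Its precondition and result can be brute-forced at i8; a standalone sketch with hypothetical helper names (`isPow2`, `log2u`), restricted to in-range shift amounts:

```cpp
// ((C1 << X) & C2) == 0  -->  X u>= Log2(C2Pow2) - Log2(C1),
// where C2Pow2 = (C2 & ~(C1 - 1)) + C1 must be a power of two.
#include <cstdint>
#include <cstdio>

static bool isPow2(unsigned v) { return v && !(v & (v - 1)); }
static unsigned log2u(unsigned v) { unsigned n = 0; while (v >>= 1) ++n; return n; }

int main() {
  for (unsigned C1 = 1; C1 < 256; C1 <<= 1)            // C1 must be a power of 2
    for (unsigned C2 = 1; C2 < 256; ++C2) {
      unsigned C2Pow2 = (C2 & ~(C1 - 1)) + C1;
      if (!isPow2(C2Pow2))
        continue;                                       // precondition of the fold
      unsigned CmpC = log2u(C2Pow2) - log2u(C1);
      for (unsigned X = 0; X < 8; ++X) {
        bool Orig = (uint8_t)((C1 << X) & C2) == 0;
        bool Fold = X >= CmpC;
        if (Orig != Fold) {
          std::printf("mismatch C1=%u C2=%u X=%u\n", C1, C2, X);
          return 1;
        }
      }
    }
  std::puts("((C1 << X) & C2) == 0 fold verified over i8");
  return 0;
}
```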
+static Instruction *foldICmpUsingBoolRange(ICmpInst &I, + InstCombiner::BuilderTy &Builder) { + Value *X, *Y; + ICmpInst::Predicate Pred; + + // X must be 0 and bool must be true for "ULT": + // X <u (zext i1 Y) --> (X == 0) & Y + if (match(&I, m_c_ICmp(Pred, m_Value(X), m_OneUse(m_ZExt(m_Value(Y))))) && + Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULT) + return BinaryOperator::CreateAnd(Builder.CreateIsNull(X), Y); + + // X must be 0 or bool must be true for "ULE": + // X <=u (sext i1 Y) --> (X == 0) | Y + if (match(&I, m_c_ICmp(Pred, m_Value(X), m_OneUse(m_SExt(m_Value(Y))))) && + Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULE) + return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y); + + return nullptr; +} + llvm::Optional<std::pair<CmpInst::Predicate, Constant *>> InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, Constant *C) { @@ -5948,7 +6087,7 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { Changed = true; } - if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, Q)) + if (Value *V = simplifyICmpInst(I.getPredicate(), Op0, Op1, Q)) return replaceInstUsesWith(I, V); // Comparing -val or val with non-zero is the same as just comparing val @@ -5984,6 +6123,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpWithDominatingICmp(I)) return Res; + if (Instruction *Res = foldICmpUsingBoolRange(I, Builder)) + return Res; + if (Instruction *Res = foldICmpUsingKnownBits(I)) return Res; @@ -6057,14 +6199,21 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *NI = foldGEPICmp(GEP, Op0, I.getSwappedPredicate(), I)) return NI; + if (auto *SI = dyn_cast<SelectInst>(Op0)) + if (Instruction *NI = foldSelectICmp(I.getPredicate(), SI, Op1, I)) + return NI; + if (auto *SI = dyn_cast<SelectInst>(Op1)) + if (Instruction *NI = foldSelectICmp(I.getSwappedPredicate(), SI, Op0, I)) + return NI; + // Try to optimize equality comparisons against alloca-based pointers. if (Op0->getType()->isPointerTy() && I.isEquality()) { assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?"); if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op0))) - if (Instruction *New = foldAllocaCmp(I, Alloca, Op1)) + if (Instruction *New = foldAllocaCmp(I, Alloca)) return New; if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op1))) - if (Instruction *New = foldAllocaCmp(I, Alloca, Op0)) + if (Instruction *New = foldAllocaCmp(I, Alloca)) return New; } @@ -6529,6 +6678,25 @@ static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) { } } +static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) { + CmpInst::Predicate Pred = I.getPredicate(); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // Canonicalize fneg as Op1. 
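The new foldICmpUsingBoolRange reduces unsigned compares against a zero- or sign-extended i1. A standalone i8 check of both patterns (sketch only):

```cpp
// X u<  (zext i1 Y)  -->  (X == 0) & Y
// X u<= (sext i1 Y)  -->  (X == 0) | Y
#include <cstdint>
#include <cstdio>

int main() {
  for (unsigned x = 0; x < 256; ++x)
    for (int y = 0; y <= 1; ++y) {
      uint8_t ZExtY = (uint8_t)y;                 // zext i1 -> i8 : 0 or 1
      uint8_t SExtY = (uint8_t)(y ? 0xFF : 0x00); // sext i1 -> i8 : 0 or -1
      bool Ult = (uint8_t)x < ZExtY,  UltFold = (x == 0) && y;
      bool Ule = (uint8_t)x <= SExtY, UleFold = (x == 0) || y;
      if (Ult != UltFold || Ule != UleFold) {
        std::printf("mismatch x=%u y=%d\n", x, y);
        return 1;
      }
    }
  std::puts("bool-range icmp folds verified over i8");
  return 0;
}
```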
+ if (match(Op0, m_FNeg(m_Value())) && !match(Op1, m_FNeg(m_Value()))) { + std::swap(Op0, Op1); + Pred = I.getSwappedPredicate(); + } + + if (!match(Op1, m_FNeg(m_Specific(Op0)))) + return nullptr; + + // Replace the negated operand with 0.0: + // fcmp Pred Op0, -Op0 --> fcmp Pred Op0, 0.0 + Constant *Zero = ConstantFP::getNullValue(Op0->getType()); + return new FCmpInst(Pred, Op0, Zero, "", &I); +} + Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { bool Changed = false; @@ -6542,7 +6710,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { const CmpInst::Predicate Pred = I.getPredicate(); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (Value *V = SimplifyFCmpInst(Pred, Op0, Op1, I.getFastMathFlags(), + if (Value *V = simplifyFCmpInst(Pred, Op0, Op1, I.getFastMathFlags(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -6587,6 +6755,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) return new FCmpInst(I.getSwappedPredicate(), X, Y, "", &I); + if (Instruction *R = foldFCmpFNegCommonOp(I)) + return R; + // Test if the FCmpInst instruction is used exclusively by a select as // part of a minimum or maximum operation. If so, refrain from doing // any other folding. This helps out other analyses which understand @@ -6632,10 +6803,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { case Instruction::Load: if (auto *GEP = dyn_cast<GetElementPtrInst>(LHSI->getOperand(0))) if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) - if (GV->isConstant() && GV->hasDefinitiveInitializer() && - !cast<LoadInst>(LHSI)->isVolatile()) - if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, I)) - return Res; + if (Instruction *Res = foldCmpLoadFromIndexedGlobal( + cast<LoadInst>(LHSI), GEP, GV, I)) + return Res; break; } } @@ -6657,7 +6827,6 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { if (match(Op1, m_FPExt(m_Value(Y))) && X->getType() == Y->getType()) return new FCmpInst(Pred, X, Y, "", &I); - // fcmp (fpext X), C -> fcmp X, (fptrunc C) if fptrunc is lossless const APFloat *C; if (match(Op1, m_APFloat(C))) { const fltSemantics &FPSem = @@ -6666,6 +6835,31 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { APFloat TruncC = *C; TruncC.convert(FPSem, APFloat::rmNearestTiesToEven, &Lossy); + if (Lossy) { + // X can't possibly equal the higher-precision constant, so reduce any + // equality comparison. + // TODO: Other predicates can be handled via getFCmpCode(). + switch (Pred) { + case FCmpInst::FCMP_OEQ: + // X is ordered and equal to an impossible constant --> false + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + case FCmpInst::FCMP_ONE: + // X is ordered and not equal to an impossible constant --> ordered + return new FCmpInst(FCmpInst::FCMP_ORD, X, + ConstantFP::getNullValue(X->getType())); + case FCmpInst::FCMP_UEQ: + // X is unordered or equal to an impossible constant --> unordered + return new FCmpInst(FCmpInst::FCMP_UNO, X, + ConstantFP::getNullValue(X->getType())); + case FCmpInst::FCMP_UNE: + // X is unordered or not equal to an impossible constant --> true + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + default: + break; + } + } + + // fcmp (fpext X), C -> fcmp X, (fptrunc C) if fptrunc is lossless // Avoid lossy conversions and denormals. // Zero is a special case that's OK to convert. 
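foldFCmpFNegCommonOp replaces the negated operand with 0.0, which holds for every predicate under IEEE-754 semantics (including ±0, infinities, and NaN). A standalone spot-check over representative values:

```cpp
// fcmp Pred X, -X  -->  fcmp Pred X, 0.0
#include <cstdio>
#include <limits>

int main() {
  const float Vals[] = {0.0f, -0.0f, 1.5f, -1.5f,
                        std::numeric_limits<float>::infinity(),
                        -std::numeric_limits<float>::infinity(),
                        std::numeric_limits<float>::denorm_min(),
                        std::numeric_limits<float>::quiet_NaN()};
  for (float x : Vals) {
    bool Ok = ((x <  -x) == (x <  0.0f)) && ((x <= -x) == (x <= 0.0f)) &&
              ((x >  -x) == (x >  0.0f)) && ((x >= -x) == (x >= 0.0f)) &&
              ((x == -x) == (x == 0.0f)) && ((x != -x) == (x != 0.0f));
    if (!Ok) {
      std::printf("mismatch for %g\n", x);
      return 1;
    }
  }
  std::puts("fcmp X, -X  <=>  fcmp X, 0.0 verified on sample values");
  return 0;
}
```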
APFloat Fabs = TruncC; diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 7743b4c41555..271154bb3f5a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -71,7 +71,7 @@ public: : InstCombiner(Worklist, Builder, MinimizeSize, AA, AC, TLI, TTI, DT, ORE, BFI, PSI, DL, LI) {} - virtual ~InstCombinerImpl() {} + virtual ~InstCombinerImpl() = default; /// Run the combiner over the entire worklist until it is empty. /// @@ -172,7 +172,8 @@ public: Instruction *visitLandingPadInst(LandingPadInst &LI); Instruction *visitVAEndInst(VAEndInst &I); Value *pushFreezeToPreventPoisonFromPropagating(FreezeInst &FI); - bool freezeDominatedUses(FreezeInst &FI); + bool freezeOtherUses(FreezeInst &FI); + Instruction *foldFreezeIntoRecurrence(FreezeInst &I, PHINode *PN); Instruction *visitFreeze(FreezeInst &I); /// Specify what to return for unhandled instructions. @@ -192,7 +193,7 @@ public: const Twine &Suffix = ""); private: - void annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI); + bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI); bool isDesirableIntType(unsigned BitWidth) const; bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const; bool shouldChangeType(Type *From, Type *To) const; @@ -325,7 +326,7 @@ private: Instruction *narrowMathIfNoOverflow(BinaryOperator &I); Instruction *narrowFunnelShift(TruncInst &Trunc); Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); - Instruction *matchSAddSubSat(Instruction &MinMax1); + Instruction *matchSAddSubSat(IntrinsicInst &MinMax1); Instruction *foldNot(BinaryOperator &I); void freelyInvertAllUsersOf(Value *V); @@ -344,16 +345,20 @@ private: const CastInst *CI2); Value *simplifyIntToPtrRoundTripCast(Value *Val); - Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &And); - Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Or); + Value *foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &I, + bool IsAnd, bool IsLogical = false); Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Xor); Value *foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd); + Value *foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1, ICmpInst *ICmp2, + bool IsAnd); + /// Optimize (fcmp)&(fcmp) or (fcmp)|(fcmp). /// NOTE: Unlike most of instcombine, this returns a Value which should /// already be inserted into the function. - Value *foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd); + Value *foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd, + bool IsLogicalSelect = false); Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI, bool IsAnd, @@ -407,7 +412,7 @@ public: // If we are replacing the instruction with itself, this must be in a // segment of unreachable code, so just clobber the instruction. 
if (&I == V) - V = UndefValue::get(I.getType()); + V = PoisonValue::get(I.getType()); LLVM_DEBUG(dbgs() << "IC: Replacing " << I << "\n" << " with " << *V << '\n'); @@ -435,7 +440,7 @@ public: void CreateNonTerminatorUnreachable(Instruction *InsertAt) { auto &Ctx = InsertAt->getContext(); new StoreInst(ConstantInt::getTrue(Ctx), - UndefValue::get(Type::getInt1PtrTy(Ctx)), + PoisonValue::get(Type::getInt1PtrTy(Ctx)), InsertAt); } @@ -621,7 +626,8 @@ public: /// other operand, try to fold the binary operator into the select arguments. /// This also works for Cast instructions, which obviously do not have a /// second operand. - Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI); + Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI, + bool FoldWithMultiUse = false); /// This is a convenience wrapper function for the above two functions. Instruction *foldBinOpIntoSelectOrPhi(BinaryOperator &I); @@ -650,22 +656,27 @@ public: Instruction *foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, ICmpInst::Predicate Cond, Instruction &I); - Instruction *foldAllocaCmp(ICmpInst &ICI, const AllocaInst *Alloca, - const Value *Other); - Instruction *foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, + Instruction *foldSelectICmp(ICmpInst::Predicate Pred, SelectInst *SI, + Value *RHS, const ICmpInst &I); + Instruction *foldAllocaCmp(ICmpInst &ICI, const AllocaInst *Alloca); + Instruction *foldCmpLoadFromIndexedGlobal(LoadInst *LI, + GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI, ConstantInt *AndCst = nullptr); Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, Constant *RHSC); Instruction *foldICmpAddOpConst(Value *X, const APInt &C, ICmpInst::Predicate Pred); - Instruction *foldICmpWithCastOp(ICmpInst &ICI); + Instruction *foldICmpWithCastOp(ICmpInst &ICmp); + Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp); Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp); Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp); Instruction *foldICmpWithConstant(ICmpInst &Cmp); Instruction *foldICmpInstWithConstant(ICmpInst &Cmp); Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp); + Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp, + const APInt &C); Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ); Instruction *foldICmpEquality(ICmpInst &Cmp); Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I); @@ -674,6 +685,8 @@ public: Value *foldMultiplicationOverflowCheck(ICmpInst &Cmp); + Instruction *foldICmpBinOpWithConstant(ICmpInst &Cmp, BinaryOperator *BO, + const APInt &C); Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select, ConstantInt *C); Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc, diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 756792918dba..e03b7026f802 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -16,15 +16,12 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/Loads.h" -#include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/IR/MDBuilder.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" 
-#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; using namespace PatternMatch; @@ -775,7 +772,7 @@ static bool isObjectSizeLessThanOrEq(Value *V, uint64_t MaxSize, uint64_t TypeSize = DL.getTypeAllocSize(AI->getAllocatedType()); // Make sure that, even if the multiplication below would wrap as an // uint64_t, we still do the right thing. - if ((CS->getValue().zextOrSelf(128)*APInt(128, TypeSize)).ugt(MaxSize)) + if ((CS->getValue().zext(128) * APInt(128, TypeSize)).ugt(MaxSize)) return false; continue; } @@ -1395,8 +1392,10 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { if (StoreInst *PrevSI = dyn_cast<StoreInst>(BBI)) { // Prev store isn't volatile, and stores to the same location? - if (PrevSI->isUnordered() && equivalentAddressValues(PrevSI->getOperand(1), - SI.getOperand(1))) { + if (PrevSI->isUnordered() && + equivalentAddressValues(PrevSI->getOperand(1), SI.getOperand(1)) && + PrevSI->getValueOperand()->getType() == + SI.getValueOperand()->getType()) { ++NumDeadStore; // Manually add back the original store to the worklist now, so it will // be processed after the operands of the removed store, as this may @@ -1436,6 +1435,8 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { } // store undef, Ptr -> noop + // FIXME: This is technically incorrect because it might overwrite a poison + // value. Change to PoisonValue once #52930 is resolved. if (isa<UndefValue>(Val)) return eraseInstFromFunction(SI); diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 1aa10b550fc4..2a34edbf6cb8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -12,7 +12,6 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" -#include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -30,13 +29,9 @@ #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include <cassert> -#include <cstddef> -#include <cstdint> -#include <utility> #define DEBUG_TYPE "instcombine" #include "llvm/Transforms/Utils/InstructionWorklist.h" @@ -145,7 +140,7 @@ static Value *foldMulSelectToNegate(BinaryOperator &I, } Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { - if (Value *V = SimplifyMulInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyMulInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -297,15 +292,24 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { auto RemOpc = Div->getOpcode() == Instruction::UDiv ? Instruction::URem : Instruction::SRem; - Value *Rem = Builder.CreateBinOp(RemOpc, X, DivOp1); + // X must be frozen because we are increasing its number of uses. 
+ Value *XFreeze = Builder.CreateFreeze(X, X->getName() + ".fr"); + Value *Rem = Builder.CreateBinOp(RemOpc, XFreeze, DivOp1); if (DivOp1 == Y) - return BinaryOperator::CreateSub(X, Rem); - return BinaryOperator::CreateSub(Rem, X); + return BinaryOperator::CreateSub(XFreeze, Rem); + return BinaryOperator::CreateSub(Rem, XFreeze); } } - /// i1 mul -> i1 and. - if (I.getType()->isIntOrIntVectorTy(1)) + // Fold the following two scenarios: + // 1) i1 mul -> i1 and. + // 2) X * Y --> X & Y, iff X, Y can be only {0,1}. + // Note: We could use known bits to generalize this and related patterns with + // shifts/truncs + Type *Ty = I.getType(); + if (Ty->isIntOrIntVectorTy(1) || + (match(Op0, m_And(m_Value(), m_One())) && + match(Op1, m_And(m_Value(), m_One())))) return BinaryOperator::CreateAnd(Op0, Op1); // X*(1 << Y) --> X << Y @@ -338,7 +342,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() && (Op0->hasOneUse() || Op1->hasOneUse() || X == Y)) { Value *And = Builder.CreateAnd(X, Y, "mulbool"); - return CastInst::Create(Instruction::ZExt, And, I.getType()); + return CastInst::Create(Instruction::ZExt, And, Ty); } // (sext bool X) * (zext bool Y) --> sext (and X, Y) // (zext bool X) * (sext bool Y) --> sext (and X, Y) @@ -348,42 +352,56 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() && (Op0->hasOneUse() || Op1->hasOneUse())) { Value *And = Builder.CreateAnd(X, Y, "mulbool"); - return CastInst::Create(Instruction::SExt, And, I.getType()); + return CastInst::Create(Instruction::SExt, And, Ty); } // (zext bool X) * Y --> X ? Y : 0 // Y * (zext bool X) --> X ? Y : 0 if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) - return SelectInst::Create(X, Op1, ConstantInt::get(I.getType(), 0)); + return SelectInst::Create(X, Op1, ConstantInt::getNullValue(Ty)); if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) - return SelectInst::Create(X, Op0, ConstantInt::get(I.getType(), 0)); + return SelectInst::Create(X, Op0, ConstantInt::getNullValue(Ty)); - // (sext bool X) * C --> X ? -C : 0 Constant *ImmC; - if (match(Op0, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1) && - match(Op1, m_ImmConstant(ImmC))) { - Constant *NegC = ConstantExpr::getNeg(ImmC); - return SelectInst::Create(X, NegC, ConstantInt::getNullValue(I.getType())); + if (match(Op1, m_ImmConstant(ImmC))) { + // (sext bool X) * C --> X ? -C : 0 + if (match(Op0, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { + Constant *NegC = ConstantExpr::getNeg(ImmC); + return SelectInst::Create(X, NegC, ConstantInt::getNullValue(Ty)); + } + + // (ashr i32 X, 31) * C --> (X < 0) ? -C : 0 + const APInt *C; + if (match(Op0, m_OneUse(m_AShr(m_Value(X), m_APInt(C)))) && + *C == C->getBitWidth() - 1) { + Constant *NegC = ConstantExpr::getNeg(ImmC); + Value *IsNeg = Builder.CreateIsNeg(X, "isneg"); + return SelectInst::Create(IsNeg, NegC, ConstantInt::getNullValue(Ty)); + } } - // (lshr X, 31) * Y --> (ashr X, 31) & Y - // Y * (lshr X, 31) --> (ashr X, 31) & Y + // (lshr X, 31) * Y --> (X < 0) ? Y : 0 // TODO: We are not checking one-use because the elimination of the multiply // is better for analysis? - // TODO: Should we canonicalize to '(X < 0) ? Y : 0' instead? That would be - // more similar to what we're doing above. 
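Two of the visitMul changes here are easy to sanity-check: the div/rem rewrite (now guarded by a freeze for the extra use) and the new `X * Y --> X & Y` case when both operands are masked to one bit. A standalone i8 check, skipping the divisions that are UB:

```cpp
#include <cstdio>

int main() {
  for (int x = -128; x <= 127; ++x)
    for (int y = -128; y <= 127; ++y) {
      // (X s/ Y) * Y --> X - (X s% Y); sdiv by 0 and INT8_MIN / -1 are UB, skip.
      if (y != 0 && !(x == -128 && y == -1))
        if ((x / y) * y != x - x % y) {
          std::printf("div/rem mismatch x=%d y=%d\n", x, y);
          return 1;
        }
      // X * Y --> X & Y when both operands are known to be 0 or 1.
      if (((x & 1) * (y & 1)) != ((x & 1) & (y & 1))) {
        std::printf("and-mul mismatch x=%d y=%d\n", x, y);
        return 1;
      }
    }
  std::puts("mul folds verified over i8");
  return 0;
}
```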
const APInt *C; - if (match(Op0, m_LShr(m_Value(X), m_APInt(C))) && *C == C->getBitWidth() - 1) - return BinaryOperator::CreateAnd(Builder.CreateAShr(X, *C), Op1); - if (match(Op1, m_LShr(m_Value(X), m_APInt(C))) && *C == C->getBitWidth() - 1) - return BinaryOperator::CreateAnd(Builder.CreateAShr(X, *C), Op0); + if (match(&I, m_c_BinOp(m_LShr(m_Value(X), m_APInt(C)), m_Value(Y))) && + *C == C->getBitWidth() - 1) { + Value *IsNeg = Builder.CreateIsNeg(X, "isneg"); + return SelectInst::Create(IsNeg, Y, ConstantInt::getNullValue(Ty)); + } + + // (and X, 1) * Y --> (trunc X) ? Y : 0 + if (match(&I, m_c_BinOp(m_OneUse(m_And(m_Value(X), m_One())), m_Value(Y)))) { + Value *Tr = Builder.CreateTrunc(X, CmpInst::makeCmpResultType(Ty)); + return SelectInst::Create(Tr, Y, ConstantInt::getNullValue(Ty)); + } // ((ashr X, 31) | 1) * X --> abs(X) // X * ((ashr X, 31) | 1) --> abs(X) if (match(&I, m_c_BinOp(m_Or(m_AShr(m_Value(X), - m_SpecificIntAllowUndef(BitWidth - 1)), - m_One()), - m_Deferred(X)))) { + m_SpecificIntAllowUndef(BitWidth - 1)), + m_One()), + m_Deferred(X)))) { Value *Abs = Builder.CreateBinaryIntrinsic( Intrinsic::abs, X, ConstantInt::getBool(I.getContext(), I.hasNoSignedWrap())); @@ -442,7 +460,7 @@ Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) { } Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { - if (Value *V = SimplifyFMulInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyFMulInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -532,9 +550,8 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { // sqrt(X) * sqrt(Y) -> sqrt(X * Y) // nnan disallows the possibility of returning a number if both operands are // negative (in that case, we should return NaN). - if (I.hasNoNaNs() && - match(Op0, m_OneUse(m_Intrinsic<Intrinsic::sqrt>(m_Value(X)))) && - match(Op1, m_OneUse(m_Intrinsic<Intrinsic::sqrt>(m_Value(Y))))) { + if (I.hasNoNaNs() && match(Op0, m_OneUse(m_Sqrt(m_Value(X)))) && + match(Op1, m_OneUse(m_Sqrt(m_Value(Y))))) { Value *XY = Builder.CreateFMulFMF(X, Y, &I); Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I); return replaceInstUsesWith(I, Sqrt); @@ -548,11 +565,11 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { // has the necessary (reassoc) fast-math-flags. 
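The `((ashr X, 31) | 1) * X --> abs(X)` fold holds even at the minimum signed value under wrapping semantics. A standalone i8 check (shift amount 7 instead of 31; arithmetic right shift assumed):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  for (int x = -128; x <= 127; ++x) {
    uint8_t ux = (uint8_t)x;
    uint8_t sign = (uint8_t)(x >> 7);               // ashr X, 7 : 0x00 or 0xFF
    uint8_t lhs = (uint8_t)((sign | 1) * ux);       // ((ashr X, 7) | 1) * X, wrapping
    uint8_t abs8 = x < 0 ? (uint8_t)(0u - ux) : ux; // llvm.abs, wrapping at INT8_MIN
    if (lhs != abs8) {
      std::printf("mismatch at %d\n", x);
      return 1;
    }
  }
  std::puts("((ashr X, 7) | 1) * X == abs(X) verified over i8");
  return 0;
}
```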
if (I.hasNoSignedZeros() && match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) && - match(Y, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) && Op1 == X) + match(Y, m_Sqrt(m_Value(X))) && Op1 == X) return BinaryOperator::CreateFDivFMF(X, Y, &I); if (I.hasNoSignedZeros() && match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) && - match(Y, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) && Op0 == X) + match(Y, m_Sqrt(m_Value(X))) && Op0 == X) return BinaryOperator::CreateFDivFMF(X, Y, &I); // Like the similar transform in instsimplify, this requires 'nsz' because @@ -561,14 +578,12 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { Op0->hasNUses(2)) { // Peek through fdiv to find squaring of square root: // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y - if (match(Op0, m_FDiv(m_Value(X), - m_Intrinsic<Intrinsic::sqrt>(m_Value(Y))))) { + if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) { Value *XX = Builder.CreateFMulFMF(X, X, &I); return BinaryOperator::CreateFDivFMF(XX, Y, &I); } // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X) - if (match(Op0, m_FDiv(m_Intrinsic<Intrinsic::sqrt>(m_Value(Y)), - m_Value(X)))) { + if (match(Op0, m_FDiv(m_Sqrt(m_Value(Y)), m_Value(X)))) { Value *XX = Builder.CreateFMulFMF(X, X, &I); return BinaryOperator::CreateFDivFMF(Y, XX, &I); } @@ -777,7 +792,8 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { // TODO: Adapt simplifyDivRemOfSelectWithZeroOp to allow this and other folds. if (match(Op0, m_ImmConstant()) && match(Op1, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()))) { - if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1))) + if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1), + /*FoldWithMultiUse*/ true)) return R; } @@ -853,12 +869,13 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { if (match(Op0, m_One())) { assert(!Ty->isIntOrIntVectorTy(1) && "i1 divide not removed?"); if (IsSigned) { - // If Op1 is 0 then it's undefined behaviour, if Op1 is 1 then the - // result is one, if Op1 is -1 then the result is minus one, otherwise - // it's zero. - Value *Inc = Builder.CreateAdd(Op1, Op0); + // 1 / 0 --> undef ; 1 / 1 --> 1 ; 1 / -1 --> -1 ; 1 / anything else --> 0 + // (Op1 + 1) u< 3 ? Op1 : 0 + // Op1 must be frozen because we are increasing its number of uses. + Value *F1 = Builder.CreateFreeze(Op1, Op1->getName() + ".fr"); + Value *Inc = Builder.CreateAdd(F1, Op0); Value *Cmp = Builder.CreateICmpULT(Inc, ConstantInt::get(Ty, 3)); - return SelectInst::Create(Cmp, Op1, ConstantInt::get(Ty, 0)); + return SelectInst::Create(Cmp, F1, ConstantInt::get(Ty, 0)); } else { // If Op1 is 0 then it's undefined behaviour. If Op1 is 1 then the // result is one, otherwise it's zero. @@ -900,113 +917,69 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { static const unsigned MaxDepth = 6; -namespace { - -using FoldUDivOperandCb = Instruction *(*)(Value *Op0, Value *Op1, - const BinaryOperator &I, - InstCombinerImpl &IC); - -/// Used to maintain state for visitUDivOperand(). -struct UDivFoldAction { - /// Informs visitUDiv() how to fold this operand. This can be zero if this - /// action joins two actions together. - FoldUDivOperandCb FoldAction; - - /// Which operand to fold. - Value *OperandToFold; - - union { - /// The instruction returned when FoldAction is invoked. - Instruction *FoldResult; - - /// Stores the LHS action index if this action joins two actions together. - size_t SelectLHSIdx; +// Take the exact integer log2 of the value. 
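The `1 s/ Y` rewrite selects Y itself when `Y + 1` is unsigned-less-than 3 (that is, Y is -1, 0, or 1) and 0 otherwise; the freeze only guards the extra use of Y. A standalone i8 check, skipping the undefined Y == 0 case:

```cpp
// 1 s/ Y  -->  ((Y + 1) u< 3) ? Y : 0
#include <cstdint>
#include <cstdio>

int main() {
  for (int y = -128; y <= 127; ++y) {
    if (y == 0)
      continue;                                  // sdiv by zero is UB, skip
    int Orig = 1 / y;                            // 1 s/ Y
    uint8_t Inc = (uint8_t)(y + 1);              // add i8 Y, 1 (wrapping)
    int Fold = (Inc < 3) ? y : 0;                // (Y + 1) u< 3 ? Y : 0
    if (Orig != Fold) {
      std::printf("mismatch at %d\n", y);
      return 1;
    }
  }
  std::puts("1 s/ Y fold verified over i8");
  return 0;
}
```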
If DoFold is true, create the +// actual instructions, otherwise return a non-null dummy value. Return nullptr +// on failure. +static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, + bool DoFold) { + auto IfFold = [DoFold](function_ref<Value *()> Fn) { + if (!DoFold) + return reinterpret_cast<Value *>(-1); + return Fn(); }; - UDivFoldAction(FoldUDivOperandCb FA, Value *InputOperand) - : FoldAction(FA), OperandToFold(InputOperand), FoldResult(nullptr) {} - UDivFoldAction(FoldUDivOperandCb FA, Value *InputOperand, size_t SLHS) - : FoldAction(FA), OperandToFold(InputOperand), SelectLHSIdx(SLHS) {} -}; - -} // end anonymous namespace - -// X udiv 2^C -> X >> C -static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1, - const BinaryOperator &I, - InstCombinerImpl &IC) { - Constant *C1 = ConstantExpr::getExactLogBase2(cast<Constant>(Op1)); - if (!C1) - llvm_unreachable("Failed to constant fold udiv -> logbase2"); - BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, C1); - if (I.isExact()) - LShr->setIsExact(); - return LShr; -} - -// X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2) -// X udiv (zext (C1 << N)), where C1 is "1<<C2" --> X >> (N+C2) -static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, - InstCombinerImpl &IC) { - Value *ShiftLeft; - if (!match(Op1, m_ZExt(m_Value(ShiftLeft)))) - ShiftLeft = Op1; - - Constant *CI; - Value *N; - if (!match(ShiftLeft, m_Shl(m_Constant(CI), m_Value(N)))) - llvm_unreachable("match should never fail here!"); - Constant *Log2Base = ConstantExpr::getExactLogBase2(CI); - if (!Log2Base) - llvm_unreachable("getLogBase2 should never fail here!"); - N = IC.Builder.CreateAdd(N, Log2Base); - if (Op1 != ShiftLeft) - N = IC.Builder.CreateZExt(N, Op1->getType()); - BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, N); - if (I.isExact()) - LShr->setIsExact(); - return LShr; -} - -// Recursively visits the possible right hand operands of a udiv -// instruction, seeing through select instructions, to determine if we can -// replace the udiv with something simpler. If we find that an operand is not -// able to simplify the udiv, we abort the entire transformation. -static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I, - SmallVectorImpl<UDivFoldAction> &Actions, - unsigned Depth = 0) { // FIXME: assert that Op1 isn't/doesn't contain undef. - // Check to see if this is an unsigned division with an exact power of 2, - // if so, convert to a right shift. - if (match(Op1, m_Power2())) { - Actions.push_back(UDivFoldAction(foldUDivPow2Cst, Op1)); - return Actions.size(); - } - - // X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2) - if (match(Op1, m_Shl(m_Power2(), m_Value())) || - match(Op1, m_ZExt(m_Shl(m_Power2(), m_Value())))) { - Actions.push_back(UDivFoldAction(foldUDivShl, Op1)); - return Actions.size(); - } + // log2(2^C) -> C + if (match(Op, m_Power2())) + return IfFold([&]() { + Constant *C = ConstantExpr::getExactLogBase2(cast<Constant>(Op)); + if (!C) + llvm_unreachable("Failed to constant fold udiv -> logbase2"); + return C; + }); // The remaining tests are all recursive, so bail out if we hit the limit. if (Depth++ == MaxDepth) - return 0; - - if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) - // FIXME: missed optimization: if one of the hands of select is/contains - // undef, just directly pick the other one. - // FIXME: can both hands contain undef? 
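A small hand-written IR example of what the exact-log2 walk above enables for udiv; the function name is hypothetical, and the 1 << N divisor is one of the shapes the removed foldUDivShl comment describes:

define i32 @udiv_by_shl_pow2(i32 %x, i32 %n) {
  ; before: %d = shl i32 1, %n
  ;         %r = udiv i32 %x, %d
  ; after: the divisor's exact log2 is simply %n
  %r = lshr i32 %x, %n
  ret i32 %r
}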
- if (size_t LHSIdx = - visitUDivOperand(Op0, SI->getOperand(1), I, Actions, Depth)) - if (visitUDivOperand(Op0, SI->getOperand(2), I, Actions, Depth)) { - Actions.push_back(UDivFoldAction(nullptr, Op1, LHSIdx - 1)); - return Actions.size(); - } + return nullptr; + + // log2(zext X) -> zext log2(X) + // FIXME: Require one use? + Value *X, *Y; + if (match(Op, m_ZExt(m_Value(X)))) + if (Value *LogX = takeLog2(Builder, X, Depth, DoFold)) + return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); }); + + // log2(X << Y) -> log2(X) + Y + // FIXME: Require one use unless X is 1? + if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) + if (Value *LogX = takeLog2(Builder, X, Depth, DoFold)) + return IfFold([&]() { return Builder.CreateAdd(LogX, Y); }); + + // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y) + // FIXME: missed optimization: if one of the hands of select is/contains + // undef, just directly pick the other one. + // FIXME: can both hands contain undef? + // FIXME: Require one use? + if (SelectInst *SI = dyn_cast<SelectInst>(Op)) + if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, DoFold)) + if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, DoFold)) + return IfFold([&]() { + return Builder.CreateSelect(SI->getOperand(0), LogX, LogY); + }); + + // log2(umin(X, Y)) -> umin(log2(X), log2(Y)) + // log2(umax(X, Y)) -> umax(log2(X), log2(Y)) + auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op); + if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) + if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, DoFold)) + if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, DoFold)) + return IfFold([&]() { + return Builder.CreateBinaryIntrinsic( + MinMax->getIntrinsicID(), LogX, LogY); + }); - return 0; + return nullptr; } /// If we have zero-extended operands of an unsigned div or rem, we may be able @@ -1047,7 +1020,7 @@ static Instruction *narrowUDivURem(BinaryOperator &I, } Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { - if (Value *V = SimplifyUDivInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyUDivInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1106,42 +1079,18 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { return BinaryOperator::CreateUDiv(A, X); } - // (LHS udiv (select (select (...)))) -> (LHS >> (select (select (...)))) - SmallVector<UDivFoldAction, 6> UDivActions; - if (visitUDivOperand(Op0, Op1, I, UDivActions)) - for (unsigned i = 0, e = UDivActions.size(); i != e; ++i) { - FoldUDivOperandCb Action = UDivActions[i].FoldAction; - Value *ActionOp1 = UDivActions[i].OperandToFold; - Instruction *Inst; - if (Action) - Inst = Action(Op0, ActionOp1, I, *this); - else { - // This action joins two actions together. The RHS of this action is - // simply the last action we processed, we saved the LHS action index in - // the joining action. - size_t SelectRHSIdx = i - 1; - Value *SelectRHS = UDivActions[SelectRHSIdx].FoldResult; - size_t SelectLHSIdx = UDivActions[i].SelectLHSIdx; - Value *SelectLHS = UDivActions[SelectLHSIdx].FoldResult; - Inst = SelectInst::Create(cast<SelectInst>(ActionOp1)->getCondition(), - SelectLHS, SelectRHS); - } - - // If this is the last action to process, return it to the InstCombiner. - // Otherwise, we insert it before the UDiv and record it so that we may - // use it as part of a joining action (i.e., a SelectInst). 
- if (e - i != 1) { - Inst->insertBefore(&I); - UDivActions[i].FoldResult = Inst; - } else - return Inst; - } + // Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away. + if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) { + Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true); + return replaceInstUsesWith( + I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact())); + } return nullptr; } Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { - if (Value *V = SimplifySDivInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifySDivInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1223,9 +1172,9 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { if (match(&I, m_c_BinOp( m_OneUse(m_Intrinsic<Intrinsic::abs>(m_Value(X), m_One())), m_Deferred(X)))) { - Constant *NegOne = ConstantInt::getAllOnesValue(Ty); - Value *Cond = Builder.CreateICmpSGT(X, NegOne); - return SelectInst::Create(Cond, ConstantInt::get(Ty, 1), NegOne); + Value *Cond = Builder.CreateIsNotNeg(X); + return SelectInst::Create(Cond, ConstantInt::get(Ty, 1), + ConstantInt::getAllOnesValue(Ty)); } // If the sign bits of both operands are zero (i.e. we can prove they are @@ -1242,8 +1191,10 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { if (match(Op1, m_NegatedPower2())) { // X sdiv (-(1 << C)) -> -(X sdiv (1 << C)) -> // -> -(X udiv (1 << C)) -> -(X u>> C) - return BinaryOperator::CreateNeg(Builder.Insert(foldUDivPow2Cst( - Op0, ConstantExpr::getNeg(cast<Constant>(Op1)), I, *this))); + Constant *CNegLog2 = ConstantExpr::getExactLogBase2( + ConstantExpr::getNeg(cast<Constant>(Op1))); + Value *Shr = Builder.CreateLShr(Op0, CNegLog2, I.getName(), I.isExact()); + return BinaryOperator::CreateNeg(Shr); } if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) { @@ -1368,7 +1319,9 @@ static Instruction *foldFDivPowDivisor(BinaryOperator &I, } Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) { - if (Value *V = SimplifyFDivInst(I.getOperand(0), I.getOperand(1), + Module *M = I.getModule(); + + if (Value *V = simplifyFDivInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1433,8 +1386,8 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) { !IsTan && match(Op0, m_Intrinsic<Intrinsic::cos>(m_Value(X))) && match(Op1, m_Intrinsic<Intrinsic::sin>(m_Specific(X))); - if ((IsTan || IsCot) && - hasFloatFn(&TLI, I.getType(), LibFunc_tan, LibFunc_tanf, LibFunc_tanl)) { + if ((IsTan || IsCot) && hasFloatFn(M, &TLI, I.getType(), LibFunc_tan, + LibFunc_tanf, LibFunc_tanl)) { IRBuilder<> B(&I); IRBuilder<>::FastMathFlagGuard FMFGuard(B); B.setFastMathFlags(I.getFastMathFlags()); @@ -1498,7 +1451,8 @@ Instruction *InstCombinerImpl::commonIRemTransforms(BinaryOperator &I) { // TODO: Adapt simplifyDivRemOfSelectWithZeroOp to allow this and other folds. 
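A hand-worked IR instance of the sdiv-by-negated-power-of-two rewrite mentioned above; the mask that makes %x provably non-negative stands in for the sign-bit check the surrounding code performs, and the names are hypothetical:

define i32 @sdiv_by_neg_pow2(i32 %x.arg) {
  %x = and i32 %x.arg, 2147483647     ; %x is now known non-negative
  ; before: %r = sdiv i32 %x, -8
  ; after:  -(X u>> 3)
  %shr = lshr i32 %x, 3
  %r = sub i32 0, %shr
  ret i32 %r
}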
if (match(Op0, m_ImmConstant()) && match(Op1, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()))) { - if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1))) + if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1), + /*FoldWithMultiUse*/ true)) return R; } @@ -1530,7 +1484,7 @@ Instruction *InstCombinerImpl::commonIRemTransforms(BinaryOperator &I) { } Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) { - if (Value *V = SimplifyURemInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyURemInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1560,11 +1514,13 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) { return CastInst::CreateZExtOrBitCast(Cmp, Ty); } - // X urem C -> X < C ? X : X - C, where C >= signbit. + // Op0 urem C -> Op0 < C ? Op0 : Op0 - C, where C >= signbit. + // Op0 must be frozen because we are increasing its number of uses. if (match(Op1, m_Negative())) { - Value *Cmp = Builder.CreateICmpULT(Op0, Op1); - Value *Sub = Builder.CreateSub(Op0, Op1); - return SelectInst::Create(Cmp, Op0, Sub); + Value *F0 = Builder.CreateFreeze(Op0, Op0->getName() + ".fr"); + Value *Cmp = Builder.CreateICmpULT(F0, Op1); + Value *Sub = Builder.CreateSub(F0, Op1); + return SelectInst::Create(Cmp, F0, Sub); } // If the divisor is a sext of a boolean, then the divisor must be max @@ -1581,7 +1537,7 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) { } Instruction *InstCombinerImpl::visitSRem(BinaryOperator &I) { - if (Value *V = SimplifySRemInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifySRemInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1653,7 +1609,7 @@ Instruction *InstCombinerImpl::visitSRem(BinaryOperator &I) { } Instruction *InstCombinerImpl::visitFRem(BinaryOperator &I) { - if (Value *V = SimplifyFRemInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyFRemInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp index 42ba4a34a5a9..c573b03f31a6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp @@ -248,6 +248,20 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { return nullptr; switch (I->getOpcode()) { + case Instruction::And: { + Constant *ShAmt; + // sub(y,and(lshr(x,C),1)) --> add(ashr(shl(x,(BW-1)-C),BW-1),y) + if (match(I, m_c_And(m_OneUse(m_TruncOrSelf( + m_LShr(m_Value(X), m_ImmConstant(ShAmt)))), + m_One()))) { + unsigned BW = X->getType()->getScalarSizeInBits(); + Constant *BWMinusOne = ConstantInt::get(X->getType(), BW - 1); + Value *R = Builder.CreateShl(X, Builder.CreateSub(BWMinusOne, ShAmt)); + R = Builder.CreateAShr(R, BWMinusOne); + return Builder.CreateTruncOrBitCast(R, I->getType()); + } + break; + } case Instruction::SDiv: // `sdiv` is negatible if divisor is not undef/INT_MIN/1. 
// While this is normally not behind a use-check, diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 09694d50468f..90a796a0939e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -511,7 +511,8 @@ Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) { // Scan to see if all operands are the same opcode, and all have one user. for (Value *V : drop_begin(PN.incoming_values())) { GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V); - if (!GEP || !GEP->hasOneUser() || GEP->getType() != FirstInst->getType() || + if (!GEP || !GEP->hasOneUser() || + GEP->getSourceElementType() != FirstInst->getSourceElementType() || GEP->getNumOperands() != FirstInst->getNumOperands()) return nullptr; @@ -657,6 +658,10 @@ static bool isSafeAndProfitableToSinkLoad(LoadInst *L) { Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) { LoadInst *FirstLI = cast<LoadInst>(PN.getIncomingValue(0)); + // Can't forward swifterror through a phi. + if (FirstLI->getOperand(0)->isSwiftError()) + return nullptr; + // FIXME: This is overconservative; this transform is allowed in some cases // for atomic operations. if (FirstLI->isAtomic()) @@ -693,6 +698,10 @@ Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) { LI->getPointerAddressSpace() != LoadAddrSpace) return nullptr; + // Can't forward swifterror through a phi. + if (LI->getOperand(0)->isSwiftError()) + return nullptr; + // We can't sink the load if the loaded value could be modified between // the load and the PHI. if (LI->getParent() != InBB || !isSafeAndProfitableToSinkLoad(LI)) @@ -1112,6 +1121,13 @@ Instruction *InstCombinerImpl::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { return nullptr; } + // If the incoming value is a PHI node before a catchswitch, we cannot + // extract the value within that BB because we cannot insert any non-PHI + // instructions in the BB. + for (auto *Pred : PN->blocks()) + if (Pred->getFirstInsertionPt() == Pred->end()) + return nullptr; + for (User *U : PN->users()) { Instruction *UserI = cast<Instruction>(U); @@ -1260,12 +1276,12 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN, // ... ... // \ / // phi [true] [false] - if (!PN.getType()->isIntegerTy(1)) - return nullptr; - - if (PN.getNumOperands() != 2) - return nullptr; - + // and + // switch (cond) + // case v1: / \ case v2: + // ... ... + // \ / + // phi [v1] [v2] // Make sure all inputs are constants. if (!all_of(PN.operands(), [](Value *V) { return isa<ConstantInt>(V); })) return nullptr; @@ -1275,50 +1291,77 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN, if (!DT.isReachableFromEntry(BB)) return nullptr; - // Same inputs. - if (PN.getOperand(0) == PN.getOperand(1)) - return PN.getOperand(0); + // Determine which value the condition of the idom has for which successor. 
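As a sketch of the switch case that the extended fold above now handles, a phi whose incoming constants mirror the switch's case values collapses to the switch condition (hand-written; names are hypothetical):

define i8 @phi_from_switch_condition(i8 %x) {
entry:
  switch i8 %x, label %other [
    i8 1, label %one
    i8 2, label %two
  ]
one:
  br label %merge
two:
  br label %merge
other:
  ret i8 0
merge:
  ; before: %p = phi i8 [ 1, %one ], [ 2, %two ]
  ;         ret i8 %p
  ; after: the phi is just the switch condition
  ret i8 %x
}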
+ LLVMContext &Context = PN.getContext(); + auto *IDom = DT.getNode(BB)->getIDom()->getBlock(); + Value *Cond; + SmallDenseMap<ConstantInt *, BasicBlock *, 8> SuccForValue; + SmallDenseMap<BasicBlock *, unsigned, 8> SuccCount; + auto AddSucc = [&](ConstantInt *C, BasicBlock *Succ) { + SuccForValue[C] = Succ; + ++SuccCount[Succ]; + }; + if (auto *BI = dyn_cast<BranchInst>(IDom->getTerminator())) { + if (BI->isUnconditional()) + return nullptr; - BasicBlock *TruePred = nullptr, *FalsePred = nullptr; - for (auto *Pred : predecessors(BB)) { - auto *Input = cast<ConstantInt>(PN.getIncomingValueForBlock(Pred)); - if (Input->isAllOnesValue()) - TruePred = Pred; - else - FalsePred = Pred; + Cond = BI->getCondition(); + AddSucc(ConstantInt::getTrue(Context), BI->getSuccessor(0)); + AddSucc(ConstantInt::getFalse(Context), BI->getSuccessor(1)); + } else if (auto *SI = dyn_cast<SwitchInst>(IDom->getTerminator())) { + Cond = SI->getCondition(); + ++SuccCount[SI->getDefaultDest()]; + for (auto Case : SI->cases()) + AddSucc(Case.getCaseValue(), Case.getCaseSuccessor()); + } else { + return nullptr; } - assert(TruePred && FalsePred && "Must be!"); - // Check which edge of the dominator dominates the true input. If it is the - // false edge, we should invert the condition. - auto *IDom = DT.getNode(BB)->getIDom()->getBlock(); - auto *BI = dyn_cast<BranchInst>(IDom->getTerminator()); - if (!BI || BI->isUnconditional()) + if (Cond->getType() != PN.getType()) return nullptr; // Check that edges outgoing from the idom's terminators dominate respective // inputs of the Phi. - BasicBlockEdge TrueOutEdge(IDom, BI->getSuccessor(0)); - BasicBlockEdge FalseOutEdge(IDom, BI->getSuccessor(1)); + Optional<bool> Invert; + for (auto Pair : zip(PN.incoming_values(), PN.blocks())) { + auto *Input = cast<ConstantInt>(std::get<0>(Pair)); + BasicBlock *Pred = std::get<1>(Pair); + auto IsCorrectInput = [&](ConstantInt *Input) { + // The input needs to be dominated by the corresponding edge of the idom. + // This edge cannot be a multi-edge, as that would imply that multiple + // different condition values follow the same edge. + auto It = SuccForValue.find(Input); + return It != SuccForValue.end() && SuccCount[It->second] == 1 && + DT.dominates(BasicBlockEdge(IDom, It->second), + BasicBlockEdge(Pred, BB)); + }; + + // Depending on the constant, the condition may need to be inverted. + bool NeedsInvert; + if (IsCorrectInput(Input)) + NeedsInvert = false; + else if (IsCorrectInput(cast<ConstantInt>(ConstantExpr::getNot(Input)))) + NeedsInvert = true; + else + return nullptr; + + // Make sure the inversion requirement is always the same. + if (Invert && *Invert != NeedsInvert) + return nullptr; - BasicBlockEdge TrueIncEdge(TruePred, BB); - BasicBlockEdge FalseIncEdge(FalsePred, BB); + Invert = NeedsInvert; + } - auto *Cond = BI->getCondition(); - if (DT.dominates(TrueOutEdge, TrueIncEdge) && - DT.dominates(FalseOutEdge, FalseIncEdge)) - // This Phi is actually equivalent to branching condition of IDom. + if (!*Invert) return Cond; - if (DT.dominates(TrueOutEdge, FalseIncEdge) && - DT.dominates(FalseOutEdge, TrueIncEdge)) { - // This Phi is actually opposite to branching condition of IDom. We invert - // the condition that will potentially open up some opportunities for - // sinking. - auto InsertPt = BB->getFirstInsertionPt(); - if (InsertPt != BB->end()) { - Self.Builder.SetInsertPoint(&*InsertPt); - return Self.Builder.CreateNot(Cond); - } + + // This Phi is actually opposite to branching condition of IDom. 
We invert + // the condition that will potentially open up some opportunities for + // sinking. + auto InsertPt = BB->getFirstInsertionPt(); + if (InsertPt != BB->end()) { + Self.Builder.SetInsertPoint(&*InsertPt); + return Self.Builder.CreateNot(Cond); } return nullptr; @@ -1327,7 +1370,7 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN, // PHINode simplification // Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { - if (Value *V = SimplifyInstruction(&PN, SQ.getWithInstruction(&PN))) + if (Value *V = simplifyInstruction(&PN, SQ.getWithInstruction(&PN))) return replaceInstUsesWith(PN, V); if (Instruction *Result = foldPHIArgZextsIntoPHI(PN)) diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 881b00f2a55a..ad96a5f475f1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -22,6 +22,7 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" @@ -49,13 +50,6 @@ using namespace llvm; using namespace PatternMatch; -static Value *createMinMax(InstCombiner::BuilderTy &Builder, - SelectPatternFlavor SPF, Value *A, Value *B) { - CmpInst::Predicate Pred = getMinMaxPred(SPF); - assert(CmpInst::isIntPredicate(Pred) && "Expected integer predicate"); - return Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); -} - /// Replace a select operand based on an equality comparison with the identity /// constant of a binop. static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, @@ -370,6 +364,7 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, // one-use constraint, but that needs be examined carefully since it may not // reduce the total number of instructions. if (TI->getNumOperands() != 2 || FI->getNumOperands() != 2 || + !TI->isSameOperationAs(FI) || (!isa<BinaryOperator>(TI) && !isa<GetElementPtrInst>(TI)) || !TI->hasOneUse() || !FI->hasOneUse()) return nullptr; @@ -444,69 +439,56 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, Value *FalseVal) { // See the comment above GetSelectFoldableOperands for a description of the // transformation we are doing here. - if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) { - if (TVI->hasOneUse() && !isa<Constant>(FalseVal)) { - if (unsigned SFO = getSelectFoldableOperands(TVI)) { - unsigned OpToFold = 0; - if ((SFO & 1) && FalseVal == TVI->getOperand(0)) { - OpToFold = 1; - } else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) { - OpToFold = 2; - } - - if (OpToFold) { - Constant *C = ConstantExpr::getBinOpIdentity(TVI->getOpcode(), - TVI->getType(), true); - Value *OOp = TVI->getOperand(2-OpToFold); - // Avoid creating select between 2 constants unless it's selecting - // between 0, 1 and -1. 
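A minimal IR illustration of the select-into-binop rewrite being restructured above, using add and its identity value 0 (hypothetical names and types):

define i32 @select_of_add_into_add(i1 %c, i32 %x, i32 %y) {
  ; before: %a = add i32 %x, %y
  ;         %r = select i1 %c, i32 %a, i32 %x
  ; after: fold the select into the add via the identity constant 0
  %s = select i1 %c, i32 %y, i32 0
  %r = add i32 %x, %s
  ret i32 %r
}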
- const APInt *OOpC; - bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); - if (!isa<Constant>(OOp) || - (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { - Value *NewSel = Builder.CreateSelect(SI.getCondition(), OOp, C); - NewSel->takeName(TVI); - BinaryOperator *BO = BinaryOperator::Create(TVI->getOpcode(), - FalseVal, NewSel); - BO->copyIRFlags(TVI); - return BO; + auto TryFoldSelectIntoOp = [&](SelectInst &SI, Value *TrueVal, + Value *FalseVal, + bool Swapped) -> Instruction * { + if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) { + if (TVI->hasOneUse() && !isa<Constant>(FalseVal)) { + if (unsigned SFO = getSelectFoldableOperands(TVI)) { + unsigned OpToFold = 0; + if ((SFO & 1) && FalseVal == TVI->getOperand(0)) + OpToFold = 1; + else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) + OpToFold = 2; + + if (OpToFold) { + FastMathFlags FMF; + // TODO: We probably ought to revisit cases where the select and FP + // instructions have different flags and add tests to ensure the + // behaviour is correct. + if (isa<FPMathOperator>(&SI)) + FMF = SI.getFastMathFlags(); + Constant *C = ConstantExpr::getBinOpIdentity( + TVI->getOpcode(), TVI->getType(), true, FMF.noSignedZeros()); + Value *OOp = TVI->getOperand(2 - OpToFold); + // Avoid creating select between 2 constants unless it's selecting + // between 0, 1 and -1. + const APInt *OOpC; + bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); + if (!isa<Constant>(OOp) || + (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { + Value *NewSel = Builder.CreateSelect( + SI.getCondition(), Swapped ? C : OOp, Swapped ? OOp : C); + if (isa<FPMathOperator>(&SI)) + cast<Instruction>(NewSel)->setFastMathFlags(FMF); + NewSel->takeName(TVI); + BinaryOperator *BO = + BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); + BO->copyIRFlags(TVI); + return BO; + } } } } } - } + return nullptr; + }; - if (auto *FVI = dyn_cast<BinaryOperator>(FalseVal)) { - if (FVI->hasOneUse() && !isa<Constant>(TrueVal)) { - if (unsigned SFO = getSelectFoldableOperands(FVI)) { - unsigned OpToFold = 0; - if ((SFO & 1) && TrueVal == FVI->getOperand(0)) { - OpToFold = 1; - } else if ((SFO & 2) && TrueVal == FVI->getOperand(1)) { - OpToFold = 2; - } + if (Instruction *R = TryFoldSelectIntoOp(SI, TrueVal, FalseVal, false)) + return R; - if (OpToFold) { - Constant *C = ConstantExpr::getBinOpIdentity(FVI->getOpcode(), - FVI->getType(), true); - Value *OOp = FVI->getOperand(2-OpToFold); - // Avoid creating select between 2 constants unless it's selecting - // between 0, 1 and -1. - const APInt *OOpC; - bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); - if (!isa<Constant>(OOp) || - (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { - Value *NewSel = Builder.CreateSelect(SI.getCondition(), C, OOp); - NewSel->takeName(FVI); - BinaryOperator *BO = BinaryOperator::Create(FVI->getOpcode(), - TrueVal, NewSel); - BO->copyIRFlags(FVI); - return BO; - } - } - } - } - } + if (Instruction *R = TryFoldSelectIntoOp(SI, FalseVal, TrueVal, true)) + return R; return nullptr; } @@ -535,6 +517,16 @@ static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp, // Where %B may be optionally shifted: lshr %X, %Z. Value *X, *Z; const bool HasShift = match(B, m_OneUse(m_LShr(m_Value(X), m_Value(Z)))); + + // The shift must be valid. + // TODO: This restricts the fold to constant shift amounts. Is there a way to + // handle variable shifts safely? 
PR47012 + if (HasShift && + !match(Z, m_SpecificInt_ICMP(CmpInst::ICMP_ULT, + APInt(SelType->getScalarSizeInBits(), + SelType->getScalarSizeInBits())))) + return nullptr; + if (!HasShift) X = B; @@ -1096,74 +1088,55 @@ static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) { return true; } -/// If this is an integer min/max (icmp + select) with a constant operand, -/// create the canonical icmp for the min/max operation and canonicalize the -/// constant to the 'false' operand of the select: -/// select (icmp Pred X, C1), C2, X --> select (icmp Pred' X, C2), X, C2 -/// Note: if C1 != C2, this will change the icmp constant to the existing -/// constant operand of the select. -static Instruction *canonicalizeMinMaxWithConstant(SelectInst &Sel, - ICmpInst &Cmp, - InstCombinerImpl &IC) { - if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1))) - return nullptr; - - // Canonicalize the compare predicate based on whether we have min or max. +static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp, + InstCombinerImpl &IC) { Value *LHS, *RHS; - SelectPatternResult SPR = matchSelectPattern(&Sel, LHS, RHS); - if (!SelectPatternResult::isMinOrMax(SPR.Flavor)) - return nullptr; - - // Is this already canonical? - ICmpInst::Predicate CanonicalPred = getMinMaxPred(SPR.Flavor); - if (Cmp.getOperand(0) == LHS && Cmp.getOperand(1) == RHS && - Cmp.getPredicate() == CanonicalPred) - return nullptr; - - // Bail out on unsimplified X-0 operand (due to some worklist management bug), - // as this may cause an infinite combine loop. Let the sub be folded first. - if (match(LHS, m_Sub(m_Value(), m_Zero())) || - match(RHS, m_Sub(m_Value(), m_Zero()))) - return nullptr; - - // Create the canonical compare and plug it into the select. - IC.replaceOperand(Sel, 0, IC.Builder.CreateICmp(CanonicalPred, LHS, RHS)); - - // If the select operands did not change, we're done. - if (Sel.getTrueValue() == LHS && Sel.getFalseValue() == RHS) - return &Sel; - - // If we are swapping the select operands, swap the metadata too. - assert(Sel.getTrueValue() == RHS && Sel.getFalseValue() == LHS && - "Unexpected results from matchSelectPattern"); - Sel.swapValues(); - Sel.swapProfMetadata(); - return &Sel; -} - -static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp, - InstCombinerImpl &IC) { - if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1))) + // TODO: What to do with pointer min/max patterns? + if (!Sel.getType()->isIntOrIntVectorTy()) return nullptr; - Value *LHS, *RHS; SelectPatternFlavor SPF = matchSelectPattern(&Sel, LHS, RHS).Flavor; - if (SPF != SelectPatternFlavor::SPF_ABS && - SPF != SelectPatternFlavor::SPF_NABS) - return nullptr; - - // Note that NSW flag can only be propagated for normal, non-negated abs! - bool IntMinIsPoison = SPF == SelectPatternFlavor::SPF_ABS && - match(RHS, m_NSWNeg(m_Specific(LHS))); - Constant *IntMinIsPoisonC = - ConstantInt::get(Type::getInt1Ty(Sel.getContext()), IntMinIsPoison); - Instruction *Abs = - IC.Builder.CreateBinaryIntrinsic(Intrinsic::abs, LHS, IntMinIsPoisonC); - - if (SPF == SelectPatternFlavor::SPF_NABS) - return BinaryOperator::CreateNeg(Abs); // Always without NSW flag! + if (SPF == SelectPatternFlavor::SPF_ABS || + SPF == SelectPatternFlavor::SPF_NABS) { + if (!Cmp.hasOneUse() && !RHS->hasOneUse()) + return nullptr; // TODO: Relax this restriction. + + // Note that NSW flag can only be propagated for normal, non-negated abs! 
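For the abs flavor handled just above, a hand-written IR sketch of the canonicalization to the llvm.abs intrinsic (names are hypothetical; the i1 false argument reflects a plain, non-nsw negate):

define i32 @abs_select_to_intrinsic(i32 %x) {
  ; before: %isnotneg = icmp sgt i32 %x, -1
  ;         %neg      = sub i32 0, %x
  ;         %r        = select i1 %isnotneg, i32 %x, i32 %neg
  ; after (second argument is the int-min-is-poison flag):
  %r = call i32 @llvm.abs.i32(i32 %x, i1 false)
  ret i32 %r
}

declare i32 @llvm.abs.i32(i32, i1 immarg)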
+ bool IntMinIsPoison = SPF == SelectPatternFlavor::SPF_ABS && + match(RHS, m_NSWNeg(m_Specific(LHS))); + Constant *IntMinIsPoisonC = + ConstantInt::get(Type::getInt1Ty(Sel.getContext()), IntMinIsPoison); + Instruction *Abs = + IC.Builder.CreateBinaryIntrinsic(Intrinsic::abs, LHS, IntMinIsPoisonC); + + if (SPF == SelectPatternFlavor::SPF_NABS) + return BinaryOperator::CreateNeg(Abs); // Always without NSW flag! + return IC.replaceInstUsesWith(Sel, Abs); + } + + if (SelectPatternResult::isMinOrMax(SPF)) { + Intrinsic::ID IntrinsicID; + switch (SPF) { + case SelectPatternFlavor::SPF_UMIN: + IntrinsicID = Intrinsic::umin; + break; + case SelectPatternFlavor::SPF_UMAX: + IntrinsicID = Intrinsic::umax; + break; + case SelectPatternFlavor::SPF_SMIN: + IntrinsicID = Intrinsic::smin; + break; + case SelectPatternFlavor::SPF_SMAX: + IntrinsicID = Intrinsic::smax; + break; + default: + llvm_unreachable("Unexpected SPF"); + } + return IC.replaceInstUsesWith( + Sel, IC.Builder.CreateBinaryIntrinsic(IntrinsicID, LHS, RHS)); + } - return IC.replaceInstUsesWith(Sel, Abs); + return nullptr; } /// If we have a select with an equality comparison, then we know the value in @@ -1336,6 +1309,7 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, ICmpInst::Predicate::ICMP_NE, APInt::getAllOnes(C0->getType()->getScalarSizeInBits())))) return nullptr; // Can't do, have all-ones element[s]. + Pred0 = ICmpInst::getFlippedStrictnessPredicate(Pred0); C0 = InstCombiner::AddOne(C0); break; default: @@ -1401,15 +1375,22 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, case ICmpInst::Predicate::ICMP_SGE: // Also non-canonical, but here we don't need to change C2, // so we don't have any restrictions on C2, so we can just handle it. + Pred1 = ICmpInst::Predicate::ICMP_SLT; std::swap(ReplacementLow, ReplacementHigh); break; default: return nullptr; // Unknown predicate. } + assert(Pred1 == ICmpInst::Predicate::ICMP_SLT && + "Unexpected predicate type."); // The thresholds of this clamp-like pattern. auto *ThresholdLowIncl = ConstantExpr::getNeg(C1); auto *ThresholdHighExcl = ConstantExpr::getSub(C0, C1); + + assert((Pred0 == ICmpInst::Predicate::ICMP_ULT || + Pred0 == ICmpInst::Predicate::ICMP_UGE) && + "Unexpected predicate type."); if (Pred0 == ICmpInst::Predicate::ICMP_UGE) std::swap(ThresholdLowIncl, ThresholdHighExcl); @@ -1530,17 +1511,71 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, return &Sel; } +static Instruction *foldSelectZeroOrOnes(ICmpInst *Cmp, Value *TVal, + Value *FVal, + InstCombiner::BuilderTy &Builder) { + if (!Cmp->hasOneUse()) + return nullptr; + + const APInt *CmpC; + if (!match(Cmp->getOperand(1), m_APIntAllowUndef(CmpC))) + return nullptr; + + // (X u< 2) ? -X : -1 --> sext (X != 0) + Value *X = Cmp->getOperand(0); + if (Cmp->getPredicate() == ICmpInst::ICMP_ULT && *CmpC == 2 && + match(TVal, m_Neg(m_Specific(X))) && match(FVal, m_AllOnes())) + return new SExtInst(Builder.CreateIsNotNull(X), TVal->getType()); + + // (X u> 1) ? 
-1 : -X --> sext (X != 0) + if (Cmp->getPredicate() == ICmpInst::ICMP_UGT && *CmpC == 1 && + match(FVal, m_Neg(m_Specific(X))) && match(TVal, m_AllOnes())) + return new SExtInst(Builder.CreateIsNotNull(X), TVal->getType()); + + return nullptr; +} + +static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI) { + const APInt *CmpC; + Value *V; + CmpInst::Predicate Pred; + if (!match(ICI, m_ICmp(Pred, m_Value(V), m_APInt(CmpC)))) + return nullptr; + + BinaryOperator *BO; + const APInt *C; + CmpInst::Predicate CPred; + if (match(&SI, m_Select(m_Specific(ICI), m_APInt(C), m_BinOp(BO)))) + CPred = ICI->getPredicate(); + else if (match(&SI, m_Select(m_Specific(ICI), m_BinOp(BO), m_APInt(C)))) + CPred = ICI->getInversePredicate(); + else + return nullptr; + + const APInt *BinOpC; + if (!match(BO, m_BinOp(m_Specific(V), m_APInt(BinOpC)))) + return nullptr; + + ConstantRange R = ConstantRange::makeExactICmpRegion(CPred, *CmpC) + .binaryOp(BO->getOpcode(), *BinOpC); + if (R == *C) { + BO->dropPoisonGeneratingFlags(); + return BO; + } + return nullptr; +} + /// Visit a SelectInst that has an ICmpInst as its first operand. Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI)) return NewSel; - if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, *this)) - return NewSel; + if (Instruction *NewSPF = canonicalizeSPF(SI, *ICI, *this)) + return NewSPF; - if (Instruction *NewAbs = canonicalizeAbsNabs(SI, *ICI, *this)) - return NewAbs; + if (Value *V = foldSelectInstWithICmpConst(SI, ICI)) + return replaceInstUsesWith(SI, V); if (Value *V = canonicalizeClampLike(SI, *ICI, Builder)) return replaceInstUsesWith(SI, V); @@ -1581,8 +1616,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, ICI->hasOneUse()) { InstCombiner::BuilderTy::InsertPointGuard Guard(Builder); Builder.SetInsertPoint(&SI); - Value *IsNeg = Builder.CreateICmpSLT( - CmpLHS, ConstantInt::getNullValue(CmpLHS->getType()), ICI->getName()); + Value *IsNeg = Builder.CreateIsNeg(CmpLHS, ICI->getName()); replaceOperand(SI, 0, IsNeg); SI.swapValues(); SI.swapProfMetadata(); @@ -1646,6 +1680,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Instruction *V = foldSelectCtlzToCttz(ICI, TrueVal, FalseVal, Builder)) return V; + if (Instruction *V = foldSelectZeroOrOnes(ICI, TrueVal, FalseVal, Builder)) + return V; + if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); @@ -1715,114 +1752,6 @@ Instruction *InstCombinerImpl::foldSPFofSPF(Instruction *Inner, // TODO: This could be done in instsimplify. if (SPF1 == SPF2 && SelectPatternResult::isMinOrMax(SPF1)) return replaceInstUsesWith(Outer, Inner); - - // MAX(MIN(a, b), a) -> a - // MIN(MAX(a, b), a) -> a - // TODO: This could be done in instsimplify. - if ((SPF1 == SPF_SMIN && SPF2 == SPF_SMAX) || - (SPF1 == SPF_SMAX && SPF2 == SPF_SMIN) || - (SPF1 == SPF_UMIN && SPF2 == SPF_UMAX) || - (SPF1 == SPF_UMAX && SPF2 == SPF_UMIN)) - return replaceInstUsesWith(Outer, C); - } - - if (SPF1 == SPF2) { - const APInt *CB, *CC; - if (match(B, m_APInt(CB)) && match(C, m_APInt(CC))) { - // MIN(MIN(A, 23), 97) -> MIN(A, 23) - // MAX(MAX(A, 97), 23) -> MAX(A, 97) - // TODO: This could be done in instsimplify. 
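A hand-written IR example in the spirit of the ConstantRange-based select fold added earlier in this hunk; the values and widths are made up for illustration:

define i8 @select_covered_by_lshr(i8 %v) {
  ; before: %cmp = icmp ult i8 %v, 8
  ;         %shr = lshr i8 %v, 3
  ;         %r   = select i1 %cmp, i8 0, i8 %shr
  ; after: for every %v with %v u< 8, %v >> 3 is already 0,
  ; so the select is redundant and the result is just the shift
  %r = lshr i8 %v, 3
  ret i8 %r
}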
- if ((SPF1 == SPF_UMIN && CB->ule(*CC)) || - (SPF1 == SPF_SMIN && CB->sle(*CC)) || - (SPF1 == SPF_UMAX && CB->uge(*CC)) || - (SPF1 == SPF_SMAX && CB->sge(*CC))) - return replaceInstUsesWith(Outer, Inner); - - // MIN(MIN(A, 97), 23) -> MIN(A, 23) - // MAX(MAX(A, 23), 97) -> MAX(A, 97) - if ((SPF1 == SPF_UMIN && CB->ugt(*CC)) || - (SPF1 == SPF_SMIN && CB->sgt(*CC)) || - (SPF1 == SPF_UMAX && CB->ult(*CC)) || - (SPF1 == SPF_SMAX && CB->slt(*CC))) { - Outer.replaceUsesOfWith(Inner, A); - return &Outer; - } - } - } - - // max(max(A, B), min(A, B)) --> max(A, B) - // min(min(A, B), max(A, B)) --> min(A, B) - // TODO: This could be done in instsimplify. - if (SPF1 == SPF2 && - ((SPF1 == SPF_UMIN && match(C, m_c_UMax(m_Specific(A), m_Specific(B)))) || - (SPF1 == SPF_SMIN && match(C, m_c_SMax(m_Specific(A), m_Specific(B)))) || - (SPF1 == SPF_UMAX && match(C, m_c_UMin(m_Specific(A), m_Specific(B)))) || - (SPF1 == SPF_SMAX && match(C, m_c_SMin(m_Specific(A), m_Specific(B)))))) - return replaceInstUsesWith(Outer, Inner); - - // ABS(ABS(X)) -> ABS(X) - // NABS(NABS(X)) -> NABS(X) - // TODO: This could be done in instsimplify. - if (SPF1 == SPF2 && (SPF1 == SPF_ABS || SPF1 == SPF_NABS)) { - return replaceInstUsesWith(Outer, Inner); - } - - // ABS(NABS(X)) -> ABS(X) - // NABS(ABS(X)) -> NABS(X) - if ((SPF1 == SPF_ABS && SPF2 == SPF_NABS) || - (SPF1 == SPF_NABS && SPF2 == SPF_ABS)) { - SelectInst *SI = cast<SelectInst>(Inner); - Value *NewSI = - Builder.CreateSelect(SI->getCondition(), SI->getFalseValue(), - SI->getTrueValue(), SI->getName(), SI); - return replaceInstUsesWith(Outer, NewSI); - } - - auto IsFreeOrProfitableToInvert = - [&](Value *V, Value *&NotV, bool &ElidesXor) { - if (match(V, m_Not(m_Value(NotV)))) { - // If V has at most 2 uses then we can get rid of the xor operation - // entirely. - ElidesXor |= !V->hasNUsesOrMore(3); - return true; - } - - if (isFreeToInvert(V, !V->hasNUsesOrMore(3))) { - NotV = nullptr; - return true; - } - - return false; - }; - - Value *NotA, *NotB, *NotC; - bool ElidesXor = false; - - // MIN(MIN(~A, ~B), ~C) == ~MAX(MAX(A, B), C) - // MIN(MAX(~A, ~B), ~C) == ~MAX(MIN(A, B), C) - // MAX(MIN(~A, ~B), ~C) == ~MIN(MAX(A, B), C) - // MAX(MAX(~A, ~B), ~C) == ~MIN(MIN(A, B), C) - // - // This transform is performance neutral if we can elide at least one xor from - // the set of three operands, since we'll be tacking on an xor at the very - // end. 
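The two-operand form of the identity stated in the comment above, written out as IR for illustration (hypothetical names):

define i32 @smin_of_nots(i32 %a, i32 %b) {
  ; before: %na = xor i32 %a, -1
  ;         %nb = xor i32 %b, -1
  ;         %r  = call i32 @llvm.smin.i32(i32 %na, i32 %nb)
  ; after: pull the 'not' out through the inverted min/max
  %max = call i32 @llvm.smax.i32(i32 %a, i32 %b)
  %r = xor i32 %max, -1
  ret i32 %r
}

declare i32 @llvm.smax.i32(i32, i32)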
- if (SelectPatternResult::isMinOrMax(SPF1) && - SelectPatternResult::isMinOrMax(SPF2) && - IsFreeOrProfitableToInvert(A, NotA, ElidesXor) && - IsFreeOrProfitableToInvert(B, NotB, ElidesXor) && - IsFreeOrProfitableToInvert(C, NotC, ElidesXor) && ElidesXor) { - if (!NotA) - NotA = Builder.CreateNot(A); - if (!NotB) - NotB = Builder.CreateNot(B); - if (!NotC) - NotC = Builder.CreateNot(C); - - Value *NewInner = createMinMax(Builder, getInverseMinMaxFlavor(SPF1), NotA, - NotB); - Value *NewOuter = Builder.CreateNot( - createMinMax(Builder, getInverseMinMaxFlavor(SPF2), NewInner, NotC)); - return replaceInstUsesWith(Outer, NewOuter); } return nullptr; @@ -2255,163 +2184,6 @@ static Value *foldSelectCmpXchg(SelectInst &SI) { return nullptr; } -static Instruction *moveAddAfterMinMax(SelectPatternFlavor SPF, Value *X, - Value *Y, - InstCombiner::BuilderTy &Builder) { - assert(SelectPatternResult::isMinOrMax(SPF) && "Expected min/max pattern"); - bool IsUnsigned = SPF == SelectPatternFlavor::SPF_UMIN || - SPF == SelectPatternFlavor::SPF_UMAX; - // TODO: If InstSimplify could fold all cases where C2 <= C1, we could change - // the constant value check to an assert. - Value *A; - const APInt *C1, *C2; - if (IsUnsigned && match(X, m_NUWAdd(m_Value(A), m_APInt(C1))) && - match(Y, m_APInt(C2)) && C2->uge(*C1) && X->hasNUses(2)) { - // umin (add nuw A, C1), C2 --> add nuw (umin A, C2 - C1), C1 - // umax (add nuw A, C1), C2 --> add nuw (umax A, C2 - C1), C1 - Value *NewMinMax = createMinMax(Builder, SPF, A, - ConstantInt::get(X->getType(), *C2 - *C1)); - return BinaryOperator::CreateNUW(BinaryOperator::Add, NewMinMax, - ConstantInt::get(X->getType(), *C1)); - } - - if (!IsUnsigned && match(X, m_NSWAdd(m_Value(A), m_APInt(C1))) && - match(Y, m_APInt(C2)) && X->hasNUses(2)) { - bool Overflow; - APInt Diff = C2->ssub_ov(*C1, Overflow); - if (!Overflow) { - // smin (add nsw A, C1), C2 --> add nsw (smin A, C2 - C1), C1 - // smax (add nsw A, C1), C2 --> add nsw (smax A, C2 - C1), C1 - Value *NewMinMax = createMinMax(Builder, SPF, A, - ConstantInt::get(X->getType(), Diff)); - return BinaryOperator::CreateNSW(BinaryOperator::Add, NewMinMax, - ConstantInt::get(X->getType(), *C1)); - } - } - - return nullptr; -} - -/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value. -Instruction *InstCombinerImpl::matchSAddSubSat(Instruction &MinMax1) { - Type *Ty = MinMax1.getType(); - - // We are looking for a tree of: - // max(INT_MIN, min(INT_MAX, add(sext(A), sext(B)))) - // Where the min and max could be reversed - Instruction *MinMax2; - BinaryOperator *AddSub; - const APInt *MinValue, *MaxValue; - if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) { - if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue)))) - return nullptr; - } else if (match(&MinMax1, - m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) { - if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue)))) - return nullptr; - } else - return nullptr; - - // Check that the constants clamp a saturate, and that the new type would be - // sensible to convert to. - if (!(*MaxValue + 1).isPowerOf2() || -*MinValue != *MaxValue + 1) - return nullptr; - // In what bitwidth can this be treated as saturating arithmetics? - unsigned NewBitWidth = (*MaxValue + 1).logBase2() + 1; - // FIXME: This isn't quite right for vectors, but using the scalar type is a - // good first approximation for what should be done there. 
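A hand-written IR sketch of the clamp pattern described above being recognized as a saturating add; the i8/i32 widths are illustrative, not taken from the patch:

define i32 @clamped_add_to_sadd_sat(i8 %a, i8 %b) {
  ; before: %sa  = sext i8 %a to i32
  ;         %sb  = sext i8 %b to i32
  ;         %add = add i32 %sa, %sb
  ;         %lo  = call i32 @llvm.smax.i32(i32 %add, i32 -128)
  ;         %r   = call i32 @llvm.smin.i32(i32 %lo, i32 127)
  ; after: the clamp to [-128, 127] is an i8 saturating add
  %sat = call i8 @llvm.sadd.sat.i8(i8 %a, i8 %b)
  %r = sext i8 %sat to i32
  ret i32 %r
}

declare i8 @llvm.sadd.sat.i8(i8, i8)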
- if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth)) - return nullptr; - - // Also make sure that the number of uses is as expected. The 3 is for the - // the two items of the compare and the select, or 2 from a min/max. - unsigned ExpUses = isa<IntrinsicInst>(MinMax1) ? 2 : 3; - if (MinMax2->hasNUsesOrMore(ExpUses) || AddSub->hasNUsesOrMore(ExpUses)) - return nullptr; - - // Create the new type (which can be a vector type) - Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth); - - Intrinsic::ID IntrinsicID; - if (AddSub->getOpcode() == Instruction::Add) - IntrinsicID = Intrinsic::sadd_sat; - else if (AddSub->getOpcode() == Instruction::Sub) - IntrinsicID = Intrinsic::ssub_sat; - else - return nullptr; - - // The two operands of the add/sub must be nsw-truncatable to the NewTy. This - // is usually achieved via a sext from a smaller type. - if (ComputeMaxSignificantBits(AddSub->getOperand(0), 0, AddSub) > - NewBitWidth || - ComputeMaxSignificantBits(AddSub->getOperand(1), 0, AddSub) > NewBitWidth) - return nullptr; - - // Finally create and return the sat intrinsic, truncated to the new type - Function *F = Intrinsic::getDeclaration(MinMax1.getModule(), IntrinsicID, NewTy); - Value *AT = Builder.CreateTrunc(AddSub->getOperand(0), NewTy); - Value *BT = Builder.CreateTrunc(AddSub->getOperand(1), NewTy); - Value *Sat = Builder.CreateCall(F, {AT, BT}); - return CastInst::Create(Instruction::SExt, Sat, Ty); -} - -/// Reduce a sequence of min/max with a common operand. -static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, - Value *RHS, - InstCombiner::BuilderTy &Builder) { - assert(SelectPatternResult::isMinOrMax(SPF) && "Expected a min/max"); - // TODO: Allow FP min/max with nnan/nsz. - if (!LHS->getType()->isIntOrIntVectorTy()) - return nullptr; - - // Match 3 of the same min/max ops. Example: umin(umin(), umin()). - Value *A, *B, *C, *D; - SelectPatternResult L = matchSelectPattern(LHS, A, B); - SelectPatternResult R = matchSelectPattern(RHS, C, D); - if (SPF != L.Flavor || L.Flavor != R.Flavor) - return nullptr; - - // Look for a common operand. The use checks are different than usual because - // a min/max pattern typically has 2 uses of each op: 1 by the cmp and 1 by - // the select. - Value *MinMaxOp = nullptr; - Value *ThirdOp = nullptr; - if (!LHS->hasNUsesOrMore(3) && RHS->hasNUsesOrMore(3)) { - // If the LHS is only used in this chain and the RHS is used outside of it, - // reuse the RHS min/max because that will eliminate the LHS. - if (D == A || C == A) { - // min(min(a, b), min(c, a)) --> min(min(c, a), b) - // min(min(a, b), min(a, d)) --> min(min(a, d), b) - MinMaxOp = RHS; - ThirdOp = B; - } else if (D == B || C == B) { - // min(min(a, b), min(c, b)) --> min(min(c, b), a) - // min(min(a, b), min(b, d)) --> min(min(b, d), a) - MinMaxOp = RHS; - ThirdOp = A; - } - } else if (!RHS->hasNUsesOrMore(3)) { - // Reuse the LHS. This will eliminate the RHS. 
- if (D == A || D == B) { - // min(min(a, b), min(c, a)) --> min(min(a, b), c) - // min(min(a, b), min(c, b)) --> min(min(a, b), c) - MinMaxOp = LHS; - ThirdOp = C; - } else if (C == A || C == B) { - // min(min(a, b), min(b, d)) --> min(min(a, b), d) - // min(min(a, b), min(c, b)) --> min(min(a, b), d) - MinMaxOp = LHS; - ThirdOp = D; - } - } - if (!MinMaxOp || !ThirdOp) - return nullptr; - - CmpInst::Predicate P = getMinMaxPred(SPF); - Value *CmpABC = Builder.CreateICmp(P, MinMaxOp, ThirdOp); - return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp); -} - /// Try to reduce a funnel/rotate pattern that includes a compare and select /// into a funnel shift intrinsic. Example: /// rotl32(a, b) --> (b == 0 ? a : ((a >> (32 - b)) | (a << b))) @@ -2501,7 +2273,8 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel, // Match select ?, TC, FC where the constants are equal but negated. // TODO: Generalize to handle a negated variable operand? const APFloat *TC, *FC; - if (!match(TVal, m_APFloat(TC)) || !match(FVal, m_APFloat(FC)) || + if (!match(TVal, m_APFloatAllowUndef(TC)) || + !match(FVal, m_APFloatAllowUndef(FC)) || !abs(*TC).bitwiseIsEqual(abs(*FC))) return nullptr; @@ -2521,17 +2294,16 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel, // (bitcast X) < 0 ? TC : -TC --> copysign(TC, -X) // (bitcast X) >= 0 ? -TC : TC --> copysign(TC, -X) // (bitcast X) >= 0 ? TC : -TC --> copysign(TC, X) + // Note: FMF from the select can not be propagated to the new instructions. if (IsTrueIfSignSet ^ TC->isNegative()) - X = Builder.CreateFNegFMF(X, &Sel); + X = Builder.CreateFNeg(X); // Canonicalize the magnitude argument as the positive constant since we do // not care about its sign. - Value *MagArg = TC->isNegative() ? FVal : TVal; + Value *MagArg = ConstantFP::get(SelType, abs(*TC)); Function *F = Intrinsic::getDeclaration(Sel.getModule(), Intrinsic::copysign, Sel.getType()); - Instruction *CopySign = CallInst::Create(F, { MagArg, X }); - CopySign->setFastMathFlags(Sel.getFastMathFlags()); - return CopySign; + return CallInst::Create(F, { MagArg, X }); } Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { @@ -2732,29 +2504,144 @@ Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op, } } -Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { +// Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need +// fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. +static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI, + InstCombinerImpl &IC) { Value *CondVal = SI.getCondition(); - Value *TrueVal = SI.getTrueValue(); - Value *FalseVal = SI.getFalseValue(); - Type *SelType = SI.getType(); - // FIXME: Remove this workaround when freeze related patches are done. - // For select with undef operand which feeds into an equality comparison, - // don't simplify it so loop unswitch can know the equality comparison - // may have an undef operand. This is a workaround for PR31652 caused by - // descrepancy about branch on undef between LoopUnswitch and GVN. 
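For illustration, a minimal IR instance of one of the copysign cases listed above; the constant 2.0 and the function name are made up:

define float @signbit_select_to_copysign(float %x) {
  ; before: %bits   = bitcast float %x to i32
  ;         %nonneg = icmp sgt i32 %bits, -1
  ;         %r      = select i1 %nonneg, float 2.0, float -2.0
  ; after:
  %r = call float @llvm.copysign.f32(float 2.0, float %x)
  ret float %r
}

declare float @llvm.copysign.f32(float, float)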
- if (match(TrueVal, m_Undef()) || match(FalseVal, m_Undef())) { - if (llvm::any_of(SI.users(), [&](User *U) { - ICmpInst *CI = dyn_cast<ICmpInst>(U); - if (CI && CI->isEquality()) - return true; - return false; - })) { + for (bool Swap : {false, true}) { + Value *TrueVal = SI.getTrueValue(); + Value *X = SI.getFalseValue(); + CmpInst::Predicate Pred; + + if (Swap) + std::swap(TrueVal, X); + + if (!match(CondVal, m_FCmp(Pred, m_Specific(X), m_AnyZeroFP()))) + continue; + + // fold (X <= +/-0.0) ? (0.0 - X) : X to fabs(X), when 'Swap' is false + // fold (X > +/-0.0) ? X : (0.0 - X) to fabs(X), when 'Swap' is true + if (match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(X)))) { + if (!Swap && (Pred == FCmpInst::FCMP_OLE || Pred == FCmpInst::FCMP_ULE)) { + Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI); + return IC.replaceInstUsesWith(SI, Fabs); + } + if (Swap && (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_UGT)) { + Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI); + return IC.replaceInstUsesWith(SI, Fabs); + } + } + + // With nsz, when 'Swap' is false: + // fold (X < +/-0.0) ? -X : X or (X <= +/-0.0) ? -X : X to fabs(X) + // fold (X > +/-0.0) ? -X : X or (X >= +/-0.0) ? -X : X to -fabs(x) + // when 'Swap' is true: + // fold (X > +/-0.0) ? X : -X or (X >= +/-0.0) ? X : -X to fabs(X) + // fold (X < +/-0.0) ? X : -X or (X <= +/-0.0) ? X : -X to -fabs(X) + if (!match(TrueVal, m_FNeg(m_Specific(X))) || !SI.hasNoSignedZeros()) return nullptr; + + if (Swap) + Pred = FCmpInst::getSwappedPredicate(Pred); + + bool IsLTOrLE = Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE || + Pred == FCmpInst::FCMP_ULT || Pred == FCmpInst::FCMP_ULE; + bool IsGTOrGE = Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE || + Pred == FCmpInst::FCMP_UGT || Pred == FCmpInst::FCMP_UGE; + + if (IsLTOrLE) { + Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI); + return IC.replaceInstUsesWith(SI, Fabs); + } + if (IsGTOrGE) { + Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI); + Instruction *NewFNeg = UnaryOperator::CreateFNeg(Fabs); + NewFNeg->setFastMathFlags(SI.getFastMathFlags()); + return NewFNeg; } } - if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, + return nullptr; +} + +// Match the following IR pattern: +// %x.lowbits = and i8 %x, %lowbitmask +// %x.lowbits.are.zero = icmp eq i8 %x.lowbits, 0 +// %x.biased = add i8 %x, %bias +// %x.biased.highbits = and i8 %x.biased, %highbitmask +// %x.roundedup = select i1 %x.lowbits.are.zero, i8 %x, i8 %x.biased.highbits +// Define: +// %alignment = add i8 %lowbitmask, 1 +// Iff 1. an %alignment is a power-of-two (aka, %lowbitmask is a low bit mask) +// and 2. %bias is equal to either %lowbitmask or %alignment, +// and 3. %highbitmask is equal to ~%lowbitmask (aka, to -%alignment) +// then this pattern can be transformed into: +// %x.offset = add i8 %x, %lowbitmask +// %x.roundedup = and i8 %x.offset, %highbitmask +static Value * +foldRoundUpIntegerWithPow2Alignment(SelectInst &SI, + InstCombiner::BuilderTy &Builder) { + Value *Cond = SI.getCondition(); + Value *X = SI.getTrueValue(); + Value *XBiasedHighBits = SI.getFalseValue(); + + ICmpInst::Predicate Pred; + Value *XLowBits; + if (!match(Cond, m_ICmp(Pred, m_Value(XLowBits), m_ZeroInt())) || + !ICmpInst::isEquality(Pred)) + return nullptr; + + if (Pred == ICmpInst::Predicate::ICMP_NE) + std::swap(X, XBiasedHighBits); + + // FIXME: we could support non non-splats here. 
+ + const APInt *LowBitMaskCst; + if (!match(XLowBits, m_And(m_Specific(X), m_APIntAllowUndef(LowBitMaskCst)))) + return nullptr; + + const APInt *BiasCst, *HighBitMaskCst; + if (!match(XBiasedHighBits, + m_And(m_Add(m_Specific(X), m_APIntAllowUndef(BiasCst)), + m_APIntAllowUndef(HighBitMaskCst)))) + return nullptr; + + if (!LowBitMaskCst->isMask()) + return nullptr; + + APInt InvertedLowBitMaskCst = ~*LowBitMaskCst; + if (InvertedLowBitMaskCst != *HighBitMaskCst) + return nullptr; + + APInt AlignmentCst = *LowBitMaskCst + 1; + + if (*BiasCst != AlignmentCst && *BiasCst != *LowBitMaskCst) + return nullptr; + + if (!XBiasedHighBits->hasOneUse()) { + if (*BiasCst == *LowBitMaskCst) + return XBiasedHighBits; + return nullptr; + } + + // FIXME: could we preserve undef's here? + Type *Ty = X->getType(); + Value *XOffset = Builder.CreateAdd(X, ConstantInt::get(Ty, *LowBitMaskCst), + X->getName() + ".biased"); + Value *R = Builder.CreateAnd(XOffset, ConstantInt::get(Ty, *HighBitMaskCst)); + R->takeName(&SI); + return R; +} + +Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { + Value *CondVal = SI.getCondition(); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + Type *SelType = SI.getType(); + + if (Value *V = simplifySelectInst(CondVal, TrueVal, FalseVal, SQ.getWithInstruction(&SI))) return replaceInstUsesWith(SI, V); @@ -2764,8 +2651,6 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, *this)) return I; - CmpInst::Predicate Pred; - // Avoid potential infinite loops by checking for non-constant condition. // TODO: Can we assert instead by improving canonicalizeSelectToShuffle()? // Scalar select must have simplified? @@ -2774,13 +2659,29 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { // Folding select to and/or i1 isn't poison safe in general. impliesPoison // checks whether folding it does not convert a well-defined value into // poison. 
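A concrete, hand-written instance of the power-of-two round-up rewrite introduced just above, using alignment 16 (so the low-bit mask is 15 and the high-bit mask is -16); the function name is hypothetical:

define i8 @round_up_to_16(i8 %x) {
  ; before: %low    = and i8 %x, 15
  ;         %iszero = icmp eq i8 %low, 0
  ;         %biased = add i8 %x, 16
  ;         %hi     = and i8 %biased, -16
  ;         %r      = select i1 %iszero, i8 %x, i8 %hi
  ; after:
  %offset = add i8 %x, 15
  %r = and i8 %offset, -16
  ret i8 %r
}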
- if (match(TrueVal, m_One()) && impliesPoison(FalseVal, CondVal)) { - // Change: A = select B, true, C --> A = or B, C - return BinaryOperator::CreateOr(CondVal, FalseVal); + if (match(TrueVal, m_One())) { + if (impliesPoison(FalseVal, CondVal)) { + // Change: A = select B, true, C --> A = or B, C + return BinaryOperator::CreateOr(CondVal, FalseVal); + } + + if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) + if (auto *RHS = dyn_cast<FCmpInst>(FalseVal)) + if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false, + /*IsSelectLogical*/ true)) + return replaceInstUsesWith(SI, V); } - if (match(FalseVal, m_Zero()) && impliesPoison(TrueVal, CondVal)) { - // Change: A = select B, C, false --> A = and B, C - return BinaryOperator::CreateAnd(CondVal, TrueVal); + if (match(FalseVal, m_Zero())) { + if (impliesPoison(TrueVal, CondVal)) { + // Change: A = select B, C, false --> A = and B, C + return BinaryOperator::CreateAnd(CondVal, TrueVal); + } + + if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) + if (auto *RHS = dyn_cast<FCmpInst>(TrueVal)) + if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true, + /*IsSelectLogical*/ true)) + return replaceInstUsesWith(SI, V); } auto *One = ConstantInt::getTrue(SelType); @@ -2838,6 +2739,20 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero())) return replaceOperand(SI, 0, A); + Value *C; + // select (~a | c), a, b -> and a, (or c, freeze(b)) + if (match(CondVal, m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))) && + CondVal->hasOneUse()) { + FalseVal = Builder.CreateFreeze(FalseVal); + return BinaryOperator::CreateAnd(TrueVal, Builder.CreateOr(C, FalseVal)); + } + // select (~c & b), a, b -> and b, (or freeze(a), c) + if (match(CondVal, m_c_And(m_Not(m_Value(C)), m_Specific(FalseVal))) && + CondVal->hasOneUse()) { + TrueVal = Builder.CreateFreeze(TrueVal); + return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal)); + } + if (!SelType->isVectorTy()) { if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal, One, SQ, /* AllowRefinement */ true)) @@ -2863,16 +2778,11 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { /* IsAnd */ IsAnd)) return I; - if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal)) { - if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1)) { - if (auto *V = foldAndOrOfICmpsOfAndWithPow2(ICmp0, ICmp1, &SI, IsAnd, - /* IsLogical */ true)) - return replaceInstUsesWith(SI, V); - - if (auto *V = foldEqOfParts(ICmp0, ICmp1, IsAnd)) + if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal)) + if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1)) + if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd, + /* IsLogical */ true)) return replaceInstUsesWith(SI, V); - } - } } // select (select a, true, b), c, false -> select a, c, false @@ -2976,42 +2886,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } } - // Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need - // fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. - // (X <= +/-0.0) ? (0.0 - X) : X --> fabs(X) - if (match(CondVal, m_FCmp(Pred, m_Specific(FalseVal), m_AnyZeroFP())) && - match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(FalseVal))) && - (Pred == FCmpInst::FCMP_OLE || Pred == FCmpInst::FCMP_ULE)) { - Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, FalseVal, &SI); - return replaceInstUsesWith(SI, Fabs); - } - // (X > +/-0.0) ? 
X : (0.0 - X) --> fabs(X) - if (match(CondVal, m_FCmp(Pred, m_Specific(TrueVal), m_AnyZeroFP())) && - match(FalseVal, m_FSub(m_PosZeroFP(), m_Specific(TrueVal))) && - (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_UGT)) { - Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, TrueVal, &SI); - return replaceInstUsesWith(SI, Fabs); - } - // With nnan and nsz: - // (X < +/-0.0) ? -X : X --> fabs(X) - // (X <= +/-0.0) ? -X : X --> fabs(X) - if (match(CondVal, m_FCmp(Pred, m_Specific(FalseVal), m_AnyZeroFP())) && - match(TrueVal, m_FNeg(m_Specific(FalseVal))) && SI.hasNoSignedZeros() && - (Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE || - Pred == FCmpInst::FCMP_ULT || Pred == FCmpInst::FCMP_ULE)) { - Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, FalseVal, &SI); - return replaceInstUsesWith(SI, Fabs); - } - // With nnan and nsz: - // (X > +/-0.0) ? X : -X --> fabs(X) - // (X >= +/-0.0) ? X : -X --> fabs(X) - if (match(CondVal, m_FCmp(Pred, m_Specific(TrueVal), m_AnyZeroFP())) && - match(FalseVal, m_FNeg(m_Specific(TrueVal))) && SI.hasNoSignedZeros() && - (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE || - Pred == FCmpInst::FCMP_UGT || Pred == FCmpInst::FCMP_UGE)) { - Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, TrueVal, &SI); - return replaceInstUsesWith(SI, Fabs); - } + // Fold selecting to fabs. + if (Instruction *Fabs = foldSelectWithFCmpToFabs(SI, *this)) + return Fabs; // See if we are selecting two values based on a comparison of the two values. if (ICmpInst *ICI = dyn_cast<ICmpInst>(CondVal)) @@ -3083,8 +2960,6 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (Instruction *R = foldSPFofSPF(cast<Instruction>(RHS), SPF2, LHS2, RHS2, SI, SPF, LHS)) return R; - // TODO. - // ABS(-X) -> ABS(X) } if (SelectPatternResult::isMinOrMax(SPF)) { @@ -3119,46 +2994,6 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *NewCast = Builder.CreateCast(CastOp, NewSI, SelType); return replaceInstUsesWith(SI, NewCast); } - - // MAX(~a, ~b) -> ~MIN(a, b) - // MAX(~a, C) -> ~MIN(a, ~C) - // MIN(~a, ~b) -> ~MAX(a, b) - // MIN(~a, C) -> ~MAX(a, ~C) - auto moveNotAfterMinMax = [&](Value *X, Value *Y) -> Instruction * { - Value *A; - if (match(X, m_Not(m_Value(A))) && !X->hasNUsesOrMore(3) && - !isFreeToInvert(A, A->hasOneUse()) && - // Passing false to only consider m_Not and constants. - isFreeToInvert(Y, false)) { - Value *B = Builder.CreateNot(Y); - Value *NewMinMax = createMinMax(Builder, getInverseMinMaxFlavor(SPF), - A, B); - // Copy the profile metadata. - if (MDNode *MD = SI.getMetadata(LLVMContext::MD_prof)) { - cast<SelectInst>(NewMinMax)->setMetadata(LLVMContext::MD_prof, MD); - // Swap the metadata if the operands are swapped. 
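A minimal IR sketch of the fsub-based fabs canonicalization that the new foldSelectWithFCmpToFabs helper centralizes (hand-written; names are hypothetical):

define float @fsub_select_to_fabs(float %x) {
  ; before: %cmp = fcmp ole float %x, 0.0
  ;         %neg = fsub float 0.0, %x
  ;         %r   = select i1 %cmp, float %neg, float %x
  ; after:
  %r = call float @llvm.fabs.f32(float %x)
  ret float %r
}

declare float @llvm.fabs.f32(float)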
- if (X == SI.getFalseValue() && Y == SI.getTrueValue()) - cast<SelectInst>(NewMinMax)->swapProfMetadata(); - } - - return BinaryOperator::CreateNot(NewMinMax); - } - - return nullptr; - }; - - if (Instruction *I = moveNotAfterMinMax(LHS, RHS)) - return I; - if (Instruction *I = moveNotAfterMinMax(RHS, LHS)) - return I; - - if (Instruction *I = moveAddAfterMinMax(SPF, LHS, RHS, Builder)) - return I; - - if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder)) - return I; - if (Instruction *I = matchSAddSubSat(SI)) - return I; } } @@ -3324,35 +3159,42 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (Value *Fr = foldSelectWithFrozenICmp(SI, Builder)) return replaceInstUsesWith(SI, Fr); + if (Value *V = foldRoundUpIntegerWithPow2Alignment(SI, Builder)) + return replaceInstUsesWith(SI, V); + // select(mask, mload(,,mask,0), 0) -> mload(,,mask,0) // Load inst is intentionally not checked for hasOneUse() if (match(FalseVal, m_Zero()) && - match(TrueVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(CondVal), - m_CombineOr(m_Undef(), m_Zero())))) { - auto *MaskedLoad = cast<IntrinsicInst>(TrueVal); - if (isa<UndefValue>(MaskedLoad->getArgOperand(3))) - MaskedLoad->setArgOperand(3, FalseVal /* Zero */); - return replaceInstUsesWith(SI, MaskedLoad); + (match(TrueVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(CondVal), + m_CombineOr(m_Undef(), m_Zero()))) || + match(TrueVal, m_MaskedGather(m_Value(), m_Value(), m_Specific(CondVal), + m_CombineOr(m_Undef(), m_Zero()))))) { + auto *MaskedInst = cast<IntrinsicInst>(TrueVal); + if (isa<UndefValue>(MaskedInst->getArgOperand(3))) + MaskedInst->setArgOperand(3, FalseVal /* Zero */); + return replaceInstUsesWith(SI, MaskedInst); } Value *Mask; if (match(TrueVal, m_Zero()) && - match(FalseVal, m_MaskedLoad(m_Value(), m_Value(), m_Value(Mask), - m_CombineOr(m_Undef(), m_Zero()))) && + (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(), m_Value(Mask), + m_CombineOr(m_Undef(), m_Zero()))) || + match(FalseVal, m_MaskedGather(m_Value(), m_Value(), m_Value(Mask), + m_CombineOr(m_Undef(), m_Zero())))) && (CondVal->getType() == Mask->getType())) { // We can remove the select by ensuring the load zeros all lanes the // select would have. We determine this by proving there is no overlap // between the load and select masks. 
// (i.e (load_mask & select_mask) == 0 == no overlap) bool CanMergeSelectIntoLoad = false; - if (Value *V = SimplifyAndInst(CondVal, Mask, SQ.getWithInstruction(&SI))) + if (Value *V = simplifyAndInst(CondVal, Mask, SQ.getWithInstruction(&SI))) CanMergeSelectIntoLoad = match(V, m_Zero()); if (CanMergeSelectIntoLoad) { - auto *MaskedLoad = cast<IntrinsicInst>(FalseVal); - if (isa<UndefValue>(MaskedLoad->getArgOperand(3))) - MaskedLoad->setArgOperand(3, TrueVal /* Zero */); - return replaceInstUsesWith(SI, MaskedLoad); + auto *MaskedInst = cast<IntrinsicInst>(FalseVal); + if (isa<UndefValue>(MaskedInst->getArgOperand(3))) + MaskedInst->setArgOperand(3, TrueVal /* Zero */); + return replaceInstUsesWith(SI, MaskedInst); } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 17f0c5c4cff0..f4e2d1239f0f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" -#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" @@ -108,7 +107,7 @@ Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts( // Can we fold (ShAmt0+ShAmt1) ? auto *NewShAmt = dyn_cast_or_null<Constant>( - SimplifyAddInst(ShAmt0, ShAmt1, /*isNSW=*/false, /*isNUW=*/false, + simplifyAddInst(ShAmt0, ShAmt1, /*isNSW=*/false, /*isNUW=*/false, SQ.getWithInstruction(Sh0))); if (!NewShAmt) return nullptr; // Did not simplify. @@ -232,7 +231,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, return nullptr; // Can we simplify (MaskShAmt+ShiftShAmt) ? - auto *SumOfShAmts = dyn_cast_or_null<Constant>(SimplifyAddInst( + auto *SumOfShAmts = dyn_cast_or_null<Constant>(simplifyAddInst( MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q)); if (!SumOfShAmts) return nullptr; // Did not simplify. @@ -264,7 +263,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, return nullptr; // Can we simplify (ShiftShAmt-MaskShAmt) ? - auto *ShAmtsDiff = dyn_cast_or_null<Constant>(SimplifySubInst( + auto *ShAmtsDiff = dyn_cast_or_null<Constant>(simplifySubInst( ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q)); if (!ShAmtsDiff) return nullptr; // Did not simplify. @@ -374,11 +373,12 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); assert(Op0->getType() == Op1->getType()); + Type *Ty = I.getType(); // If the shift amount is a one-use `sext`, we can demote it to `zext`. Value *Y; if (match(Op1, m_OneUse(m_SExt(m_Value(Y))))) { - Value *NewExt = Builder.CreateZExt(Y, I.getType(), Op1->getName()); + Value *NewExt = Builder.CreateZExt(Y, Ty, Op1->getName()); return BinaryOperator::Create(I.getOpcode(), Op0, NewExt); } @@ -400,15 +400,56 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) { reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ))) return NewShift; - // (C1 shift (A add C2)) -> (C1 shift C2) shift A) - // iff A and C2 are both positive. 
+ // Pre-shift a constant shifted by a variable amount with constant offset: + // C shift (A add nuw C1) --> (C shift C1) shift A Value *A; - Constant *C; - if (match(Op0, m_Constant()) && match(Op1, m_Add(m_Value(A), m_Constant(C)))) - if (isKnownNonNegative(A, DL, 0, &AC, &I, &DT) && - isKnownNonNegative(C, DL, 0, &AC, &I, &DT)) - return BinaryOperator::Create( - I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), Op0, C), A); + Constant *C, *C1; + if (match(Op0, m_Constant(C)) && + match(Op1, m_NUWAdd(m_Value(A), m_Constant(C1)))) { + Value *NewC = Builder.CreateBinOp(I.getOpcode(), C, C1); + return BinaryOperator::Create(I.getOpcode(), NewC, A); + } + + unsigned BitWidth = Ty->getScalarSizeInBits(); + + const APInt *AC, *AddC; + // Try to pre-shift a constant shifted by a variable amount added with a + // negative number: + // C << (X - AddC) --> (C >> AddC) << X + // and + // C >> (X - AddC) --> (C << AddC) >> X + if (match(Op0, m_APInt(AC)) && match(Op1, m_Add(m_Value(A), m_APInt(AddC))) && + AddC->isNegative() && (-*AddC).ult(BitWidth)) { + assert(!AC->isZero() && "Expected simplify of shifted zero"); + unsigned PosOffset = (-*AddC).getZExtValue(); + + auto isSuitableForPreShift = [PosOffset, &I, AC]() { + switch (I.getOpcode()) { + default: + return false; + case Instruction::Shl: + return (I.hasNoSignedWrap() || I.hasNoUnsignedWrap()) && + AC->eq(AC->lshr(PosOffset).shl(PosOffset)); + case Instruction::LShr: + return I.isExact() && AC->eq(AC->shl(PosOffset).lshr(PosOffset)); + case Instruction::AShr: + return I.isExact() && AC->eq(AC->shl(PosOffset).ashr(PosOffset)); + } + }; + if (isSuitableForPreShift()) { + Constant *NewC = ConstantInt::get(Ty, I.getOpcode() == Instruction::Shl + ? AC->lshr(PosOffset) + : AC->shl(PosOffset)); + BinaryOperator *NewShiftOp = + BinaryOperator::Create(I.getOpcode(), NewC, A); + if (I.getOpcode() == Instruction::Shl) { + NewShiftOp->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + } else { + NewShiftOp->setIsExact(); + } + return NewShiftOp; + } + } // X shift (A srem C) -> X shift (A and (C - 1)) iff C is a power of 2. // Because shifts by negative values (which could occur if A were negative) @@ -417,7 +458,7 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) { match(C, m_Power2())) { // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't // demand the sign bit (and many others) here?? 
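The pre-shift folds added to commonShiftTransforms above rest on simple modular-arithmetic identities. A small standalone C++ check (illustrative only, not from the commit), assuming 32-bit unsigned values and in-range shift amounts; the nuw/nsw/exact conditions of the IR patterns are not modeled here:

#include <cassert>
#include <cstdint>

int main() {
  // C shift (A add nuw C1) --> (C shift C1) shift A
  for (uint32_t a = 0; a < 28; ++a)
    assert((5u << (a + 3)) == ((5u << 3) << a));

  // C << (X - P) --> (C >> P) << X, valid when the low P bits of C are zero.
  const uint32_t C = 0xF0u, P = 4;
  for (uint32_t x = P; x < 28; ++x)
    assert((C << (x - P)) == ((C >> P) << x));
  return 0;
}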
- Constant *Mask = ConstantExpr::getSub(C, ConstantInt::get(I.getType(), 1)); + Constant *Mask = ConstantExpr::getSub(C, ConstantInt::get(Ty, 1)); Value *Rem = Builder.CreateAnd(A, Mask, Op1->getName()); return replaceOperand(I, 1, Rem); } @@ -661,10 +702,18 @@ static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift, } } -Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1, +Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, BinaryOperator &I) { + // (C2 << X) << C1 --> (C2 << C1) << X + // (C2 >> X) >> C1 --> (C2 >> C1) >> X + Constant *C2; + Value *X; + if (match(Op0, m_BinOp(I.getOpcode(), m_Constant(C2), m_Value(X)))) + return BinaryOperator::Create( + I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X); + const APInt *Op1C; - if (!match(Op1, m_APInt(Op1C))) + if (!match(C1, m_APInt(Op1C))) return nullptr; // See if we can propagate this shift into the input, this covers the trivial @@ -701,11 +750,11 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1, const APInt *Op0C; if (match(Op0BO->getOperand(1), m_APInt(Op0C))) { if (canShiftBinOpWithConstantRHS(I, Op0BO)) { - Constant *NewRHS = ConstantExpr::get( - I.getOpcode(), cast<Constant>(Op0BO->getOperand(1)), Op1); + Value *NewRHS = + Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(1), C1); Value *NewShift = - Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1); + Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), C1); NewShift->takeName(Op0BO); return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, NewRHS); @@ -730,10 +779,10 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1, if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal && match(TBO->getOperand(1), m_APInt(C)) && canShiftBinOpWithConstantRHS(I, TBO)) { - Constant *NewRHS = ConstantExpr::get( - I.getOpcode(), cast<Constant>(TBO->getOperand(1)), Op1); + Value *NewRHS = + Builder.CreateBinOp(I.getOpcode(), TBO->getOperand(1), C1); - Value *NewShift = Builder.CreateBinOp(I.getOpcode(), FalseVal, Op1); + Value *NewShift = Builder.CreateBinOp(I.getOpcode(), FalseVal, C1); Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift, NewRHS); return SelectInst::Create(Cond, NewOp, NewShift); } @@ -747,10 +796,10 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1, if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal && match(FBO->getOperand(1), m_APInt(C)) && canShiftBinOpWithConstantRHS(I, FBO)) { - Constant *NewRHS = ConstantExpr::get( - I.getOpcode(), cast<Constant>(FBO->getOperand(1)), Op1); + Value *NewRHS = + Builder.CreateBinOp(I.getOpcode(), FBO->getOperand(1), C1); - Value *NewShift = Builder.CreateBinOp(I.getOpcode(), TrueVal, Op1); + Value *NewShift = Builder.CreateBinOp(I.getOpcode(), TrueVal, C1); Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift, NewRHS); return SelectInst::Create(Cond, NewShift, NewOp); } @@ -762,7 +811,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1, Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { const SimplifyQuery Q = SQ.getWithInstruction(&I); - if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), Q)) return replaceInstUsesWith(I, V); @@ -968,10 +1017,6 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { if (match(Op1, m_Constant(C1))) { Constant *C2; Value *X; 
- // (C2 << X) << C1 --> (C2 << C1) << X - if (match(Op0, m_OneUse(m_Shl(m_Constant(C2), m_Value(X))))) - return BinaryOperator::CreateShl(ConstantExpr::getShl(C2, C1), X); - // (X * C2) << C1 --> X * (C2 << C1) if (match(Op0, m_Mul(m_Value(X), m_Constant(C2)))) return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1)); @@ -993,7 +1038,7 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { } Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { - if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), + if (Value *V = simplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1164,15 +1209,54 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { } } - // Look for a "splat" mul pattern - it replicates bits across each half of - // a value, so a right shift is just a mask of the low bits: - // lshr i32 (mul nuw X, Pow2+1), 16 --> and X, Pow2-1 - // TODO: Generalize to allow more than just half-width shifts? const APInt *MulC; - if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) && - ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() && - MulC->logBase2() == ShAmtC) - return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2)); + if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC)))) { + // Look for a "splat" mul pattern - it replicates bits across each half of + // a value, so a right shift is just a mask of the low bits: + // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1 + // TODO: Generalize to allow more than just half-width shifts? + if (BitWidth > 2 && ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() && + MulC->logBase2() == ShAmtC) + return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2)); + + // The one-use check is not strictly necessary, but codegen may not be + // able to invert the transform and perf may suffer with an extra mul + // instruction. + if (Op0->hasOneUse()) { + APInt NewMulC = MulC->lshr(ShAmtC); + // if c is divisible by (1 << ShAmtC): + // lshr (mul nuw x, MulC), ShAmtC -> mul nuw x, (MulC >> ShAmtC) + if (MulC->eq(NewMulC.shl(ShAmtC))) { + auto *NewMul = + BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC)); + BinaryOperator *OrigMul = cast<BinaryOperator>(Op0); + NewMul->setHasNoSignedWrap(OrigMul->hasNoSignedWrap()); + return NewMul; + } + } + } + + // Try to narrow bswap. + // In the case where the shift amount equals the bitwidth difference, the + // shift is eliminated. + if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::bswap>( + m_OneUse(m_ZExt(m_Value(X))))))) { + unsigned SrcWidth = X->getType()->getScalarSizeInBits(); + unsigned WidthDiff = BitWidth - SrcWidth; + if (SrcWidth % 16 == 0) { + Value *NarrowSwap = Builder.CreateUnaryIntrinsic(Intrinsic::bswap, X); + if (ShAmtC >= WidthDiff) { + // (bswap (zext X)) >> C --> zext (bswap X >> C') + Value *NewShift = Builder.CreateLShr(NarrowSwap, ShAmtC - WidthDiff); + return new ZExtInst(NewShift, Ty); + } else { + // (bswap (zext X)) >> C --> (zext (bswap X)) << C' + Value *NewZExt = Builder.CreateZExt(NarrowSwap, Ty); + Constant *ShiftDiff = ConstantInt::get(Ty, WidthDiff - ShAmtC); + return BinaryOperator::CreateShl(NewZExt, ShiftDiff); + } + } + } // If the shifted-out value is known-zero, then this is an exact shift. 
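The lshr-of-mul folds above (the half-width "splat" multiply and the multiplier divisible by the shifted-out power of two) follow from ordinary unsigned arithmetic whenever the multiply cannot overflow. A standalone C++ sketch of both, assuming 32-bit values small enough to satisfy the nuw requirement (illustrative only, not from the commit):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t x = 0; x <= 0xFFFFu; ++x) {
    // "Splat" multiply: lshr (mul nuw X, 0x10001), 16 --> and X, 0xFFFF
    assert(((x * 0x10001u) >> 16) == (x & 0xFFFFu));
    // MulC divisible by 1<<ShAmt: lshr (mul nuw X, 48), 4 --> mul nuw X, 3
    assert(((x * 48u) >> 4) == x * 3u);
  }
  return 0;
}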
if (!I.isExact() && @@ -1263,7 +1347,7 @@ InstCombinerImpl::foldVariableSignZeroExtensionOfVariableHighBitExtract( } Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { - if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), + if (Value *V = simplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 3f064cfda712..9d4c01ac03e2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -12,8 +12,8 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" -#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" @@ -154,6 +154,29 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (Depth == 0 && !V->hasOneUse()) DemandedMask.setAllBits(); + // If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care + // about the high bits of the operands. + auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) { + unsigned NLZ = DemandedMask.countLeadingZeros(); + // Right fill the mask of bits for the operands to demand the most + // significant bit and all those below it. + DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); + if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || + SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) || + ShrinkDemandedConstant(I, 1, DemandedFromOps) || + SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) { + if (NLZ > 0) { + // Disable the nsw and nuw flags here: We can no longer guarantee that + // we won't wrap after simplification. Removing the nsw/nuw flags is + // legal here because the top bit is not demanded. + I->setHasNoSignedWrap(false); + I->setHasNoUnsignedWrap(false); + } + return true; + } + return false; + }; + switch (I->getOpcode()) { default: computeKnownBits(I, Known, Depth, CxtI); @@ -297,13 +320,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, (LHSKnown.One & RHSKnown.One & DemandedMask) != 0) { APInt NewMask = ~(LHSKnown.One & RHSKnown.One & DemandedMask); - Constant *AndC = - ConstantInt::get(I->getType(), NewMask & AndRHS->getValue()); + Constant *AndC = ConstantInt::get(VTy, NewMask & AndRHS->getValue()); Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC); InsertNewInstWith(NewAnd, *I); - Constant *XorC = - ConstantInt::get(I->getType(), NewMask & XorRHS->getValue()); + Constant *XorC = ConstantInt::get(VTy, NewMask & XorRHS->getValue()); Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC); return InsertNewInstWith(NewXor, *I); } @@ -311,33 +332,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, break; } case Instruction::Select: { - Value *LHS, *RHS; - SelectPatternFlavor SPF = matchSelectPattern(I, LHS, RHS).Flavor; - if (SPF == SPF_UMAX) { - // UMax(A, C) == A if ... - // The lowest non-zero bit of DemandMask is higher than the highest - // non-zero bit of C. 
- const APInt *C; - unsigned CTZ = DemandedMask.countTrailingZeros(); - if (match(RHS, m_APInt(C)) && CTZ >= C->getActiveBits()) - return LHS; - } else if (SPF == SPF_UMIN) { - // UMin(A, C) == A if ... - // The lowest non-zero bit of DemandMask is higher than the highest - // non-one bit of C. - // This comes from using DeMorgans on the above umax example. - const APInt *C; - unsigned CTZ = DemandedMask.countTrailingZeros(); - if (match(RHS, m_APInt(C)) && - CTZ >= C->getBitWidth() - C->countLeadingOnes()) - return LHS; - } - - // If this is a select as part of any other min/max pattern, don't simplify - // any further in case we break the structure. - if (SPF != SPF_UNKNOWN) - return nullptr; - if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1) || SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Depth + 1)) return I; @@ -393,12 +387,12 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) { // The shift amount must be valid (not poison) in the narrow type, and // it must not be greater than the high bits demanded of the result. - if (C->ult(I->getType()->getScalarSizeInBits()) && + if (C->ult(VTy->getScalarSizeInBits()) && C->ule(DemandedMask.countLeadingZeros())) { // trunc (lshr X, C) --> lshr (trunc X), C IRBuilderBase::InsertPointGuard Guard(Builder); Builder.SetInsertPoint(I); - Value *Trunc = Builder.CreateTrunc(X, I->getType()); + Value *Trunc = Builder.CreateTrunc(X, VTy); return Builder.CreateLShr(Trunc, C->getZExtValue()); } } @@ -420,9 +414,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (!I->getOperand(0)->getType()->isIntOrIntVectorTy()) return nullptr; // vector->int or fp->int? - if (VectorType *DstVTy = dyn_cast<VectorType>(I->getType())) { - if (VectorType *SrcVTy = - dyn_cast<VectorType>(I->getOperand(0)->getType())) { + if (auto *DstVTy = dyn_cast<VectorType>(VTy)) { + if (auto *SrcVTy = dyn_cast<VectorType>(I->getOperand(0)->getType())) { if (cast<FixedVectorType>(DstVTy)->getNumElements() != cast<FixedVectorType>(SrcVTy)->getNumElements()) // Don't touch a bitcast between vectors of different element counts. @@ -507,26 +500,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } LLVM_FALLTHROUGH; case Instruction::Sub: { - /// If the high-bits of an ADD/SUB are not demanded, then we do not care - /// about the high bits of the operands. - unsigned NLZ = DemandedMask.countLeadingZeros(); - // Right fill the mask of bits for this ADD/SUB to demand the most - // significant bit and all those below it. - APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ)); - if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || - SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) || - ShrinkDemandedConstant(I, 1, DemandedFromOps) || - SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) { - if (NLZ > 0) { - // Disable the nsw and nuw flags here: We can no longer guarantee that - // we won't wrap after simplification. Removing the nsw/nuw flags is - // legal here because the top bit is not demanded. - BinaryOperator &BinOP = *cast<BinaryOperator>(I); - BinOP.setHasNoSignedWrap(false); - BinOP.setHasNoUnsignedWrap(false); - } + APInt DemandedFromOps; + if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps)) return I; - } // If we are known to be adding/subtracting zeros to every bit below // the highest demanded bit, we just return the other side. 
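The shared simplifyOperandsBasedOnUnusedHighBits logic used by the Add/Sub (and, below, Mul) cases relies on the fact that masked-off high bits of the operands cannot influence the demanded low bits. A standalone C++ check with an 8-bit demanded mask (illustrative only, not from the commit):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t vals[] = {0u, 1u, 0xFFu, 0x1234FFu, 0xDEADBEEFu, 0xFFFFFFFFu};
  for (uint32_t a : vals)
    for (uint32_t b : vals) {
      // Only the low 8 bits are demanded, so the operands' high bits are moot.
      assert(((a + b) & 0xFFu) == (((a & 0xFFu) + (b & 0xFFu)) & 0xFFu));
      assert(((a - b) & 0xFFu) == (((a & 0xFFu) - (b & 0xFFu)) & 0xFFu));
    }
  return 0;
}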
@@ -544,6 +520,36 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, NSW, LHSKnown, RHSKnown); break; } + case Instruction::Mul: { + APInt DemandedFromOps; + if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps)) + return I; + + if (DemandedMask.isPowerOf2()) { + // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1. + // If we demand exactly one bit N and we have "X * (C' << N)" where C' is + // odd (has LSB set), then the left-shifted low bit of X is the answer. + unsigned CTZ = DemandedMask.countTrailingZeros(); + const APInt *C; + if (match(I->getOperand(1), m_APInt(C)) && + C->countTrailingZeros() == CTZ) { + Constant *ShiftC = ConstantInt::get(VTy, CTZ); + Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC); + return InsertNewInstWith(Shl, *I); + } + } + // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because: + // X * X is odd iff X is odd. + // 'Quadratic Reciprocity': X * X -> 0 for bit[1] + if (I->getOperand(0) == I->getOperand(1) && DemandedMask.ult(4)) { + Constant *One = ConstantInt::get(VTy, 1); + Instruction *And1 = BinaryOperator::CreateAnd(I->getOperand(0), One); + return InsertNewInstWith(And1, *I); + } + + computeKnownBits(I, Known, Depth, CxtI); + break; + } case Instruction::Shl: { const APInt *SA; if (match(I->getOperand(1), m_APInt(SA))) { @@ -554,7 +560,26 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, DemandedMask, Known)) return R; + // TODO: If we only want bits that already match the signbit then we don't + // need to shift. + + // If we can pre-shift a right-shifted constant to the left without + // losing any high bits amd we don't demand the low bits, then eliminate + // the left-shift: + // (C >> X) << LeftShiftAmtC --> (C << RightShiftAmtC) >> X uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); + Value *X; + Constant *C; + if (DemandedMask.countTrailingZeros() >= ShiftAmt && + match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) { + Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt); + Constant *NewC = ConstantExpr::getShl(C, LeftShiftAmtC); + if (ConstantExpr::getLShr(NewC, LeftShiftAmtC) == C) { + Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X); + return InsertNewInstWith(Lshr, *I); + } + } + APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt)); // If the shift is NUW/NSW, then it does demand the high bits. @@ -584,7 +609,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, else if (SignBitOne) Known.One.setSignBit(); if (Known.hasConflict()) - return UndefValue::get(I->getType()); + return UndefValue::get(VTy); } } else { // This is a variable shift, so we can't shift the demand mask by a known @@ -607,6 +632,34 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (match(I->getOperand(1), m_APInt(SA))) { uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); + // If we are just demanding the shifted sign bit and below, then this can + // be treated as an ASHR in disguise. + if (DemandedMask.countLeadingZeros() >= ShiftAmt) { + // If we only want bits that already match the signbit then we don't + // need to shift. 
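The new Mul handling in the preceding hunk uses two bit-level facts: bit 1 of X*X is always zero (and bit 0 is X's low bit), and multiplying by an odd constant shifted left by N acts like a plain shift when only bit N is demanded. A standalone C++ check over a small range (illustrative only, not from the commit):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t x = 0; x < 4096; ++x) {
    assert(((x * x) & 2u) == 0u);        // bit 1 of X*X is always 0
    assert(((x * x) & 1u) == (x & 1u));  // bit 0 of X*X is X[0]
    // Demand only bit 4; the multiplier 0xB0 is 0xB << 4 with 0xB odd.
    assert(((x * 0xB0u) & 0x10u) == ((x << 4) & 0x10u));
  }
  return 0;
}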
+ unsigned NumHiDemandedBits = + BitWidth - DemandedMask.countTrailingZeros(); + unsigned SignBits = + ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); + if (SignBits >= NumHiDemandedBits) + return I->getOperand(0); + + // If we can pre-shift a left-shifted constant to the right without + // losing any low bits (we already know we don't demand the high bits), + // then eliminate the right-shift: + // (C << X) >> RightShiftAmtC --> (C >> RightShiftAmtC) << X + Value *X; + Constant *C; + if (match(I->getOperand(0), m_Shl(m_ImmConstant(C), m_Value(X)))) { + Constant *RightShiftAmtC = ConstantInt::get(VTy, ShiftAmt); + Constant *NewC = ConstantExpr::getLShr(C, RightShiftAmtC); + if (ConstantExpr::getShl(NewC, RightShiftAmtC) == C) { + Instruction *Shl = BinaryOperator::CreateShl(NewC, X); + return InsertNewInstWith(Shl, *I); + } + } + } + // Unsigned shift right. APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); @@ -628,6 +681,14 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, break; } case Instruction::AShr: { + unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); + + // If we only want bits that already match the signbit then we don't need + // to shift. + unsigned NumHiDemandedBits = BitWidth - DemandedMask.countTrailingZeros(); + if (SignBits >= NumHiDemandedBits) + return I->getOperand(0); + // If this is an arithmetic shift right and only the low-bit is set, we can // always convert this into a logical shr, even if the shift amount is // variable. The low bit of the shift cannot be an input sign bit unless @@ -639,11 +700,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return InsertNewInstWith(NewVal, *I); } - // If the sign bit is the only bit demanded by this ashr, then there is no - // need to do it, the shift doesn't change the high bit. - if (DemandedMask.isSignMask()) - return I->getOperand(0); - const APInt *SA; if (match(I->getOperand(1), m_APInt(SA))) { uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1); @@ -663,8 +719,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) return I; - unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); - assert(!Known.hasConflict() && "Bits known to be one AND zero?"); // Compute the new bits that are at the top now plus sign bits. APInt HighBits(APInt::getHighBitsSet( @@ -713,13 +767,13 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, break; } case Instruction::SRem: { - ConstantInt *Rem; - if (match(I->getOperand(1), m_ConstantInt(Rem))) { + const APInt *Rem; + if (match(I->getOperand(1), m_APInt(Rem))) { // X % -1 demands all the bits because we don't want to introduce // INT_MIN % -1 (== undef) by accident. 
- if (Rem->isMinusOne()) + if (Rem->isAllOnes()) break; - APInt RA = Rem->getValue().abs(); + APInt RA = Rem->abs(); if (RA.isPowerOf2()) { if (DemandedMask.ult(RA)) // srem won't affect demanded bits return I->getOperand(0); @@ -786,7 +840,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (DemandedMask == 1 && VTy->getScalarSizeInBits() % 2 == 0 && match(II->getArgOperand(0), m_Not(m_Value(X)))) { Function *Ctpop = Intrinsic::getDeclaration( - II->getModule(), Intrinsic::ctpop, II->getType()); + II->getModule(), Intrinsic::ctpop, VTy); return InsertNewInstWith(CallInst::Create(Ctpop, {X}), *I); } break; @@ -809,12 +863,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Instruction *NewVal; if (NLZ > NTZ) NewVal = BinaryOperator::CreateLShr( - II->getArgOperand(0), - ConstantInt::get(I->getType(), NLZ - NTZ)); + II->getArgOperand(0), ConstantInt::get(VTy, NLZ - NTZ)); else NewVal = BinaryOperator::CreateShl( - II->getArgOperand(0), - ConstantInt::get(I->getType(), NTZ - NLZ)); + II->getArgOperand(0), ConstantInt::get(VTy, NTZ - NLZ)); NewVal->takeName(I); return InsertNewInstWith(NewVal, *I); } @@ -872,7 +924,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Handle target specific intrinsics Optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic( *II, DemandedMask, Known, KnownBitsComputed); - if (V.hasValue()) + if (V) return V.getValue(); break; } @@ -1583,7 +1635,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, Optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic( *II, DemandedElts, UndefElts, UndefElts2, UndefElts3, simplifyAndSetOp); - if (V.hasValue()) + if (V) return V.getValue(); break; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 736cf9c825d5..22659a8e4951 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -42,7 +42,6 @@ #include <utility> #define DEBUG_TYPE "instcombine" -#include "llvm/Transforms/Utils/InstructionWorklist.h" using namespace llvm; using namespace PatternMatch; @@ -378,7 +377,7 @@ ConstantInt *getPreferredVectorIndex(ConstantInt *IndexC) { Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { Value *SrcVec = EI.getVectorOperand(); Value *Index = EI.getIndexOperand(); - if (Value *V = SimplifyExtractElementInst(SrcVec, Index, + if (Value *V = simplifyExtractElementInst(SrcVec, Index, SQ.getWithInstruction(&EI))) return replaceInstUsesWith(EI, V); @@ -879,7 +878,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( // of an aggregate. If we did, that means the CurrIVI will later be // overwritten with the already-recorded value. But if not, let's record it! Optional<Instruction *> &Elt = AggElts[Indices.front()]; - Elt = Elt.getValueOr(InsertedValue); + Elt = Elt.value_or(InsertedValue); // FIXME: should we handle chain-terminating undef base operand? 
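The SRem case above keeps the dividend when only bits below the (power-of-two) divisor magnitude are demanded, because srem by 2^k preserves the low k bits of the dividend. A standalone C++ sketch with divisor 8 (illustrative only, not from the commit):

#include <cassert>
#include <cstdint>

int main() {
  for (int32_t x = -1000; x <= 1000; ++x) {
    // srem by 8 leaves the low 3 bits unchanged, for negatives as well.
    assert((static_cast<uint32_t>(x % 8) & 7u) ==
           (static_cast<uint32_t>(x) & 7u));
  }
  return 0;
}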
} @@ -1489,7 +1488,7 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { Value *ScalarOp = IE.getOperand(1); Value *IdxOp = IE.getOperand(2); - if (auto *V = SimplifyInsertElementInst( + if (auto *V = simplifyInsertElementInst( VecOp, ScalarOp, IdxOp, SQ.getWithInstruction(&IE))) return replaceInstUsesWith(IE, V); @@ -1919,24 +1918,29 @@ static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) { Value *BO0 = BO->getOperand(0), *BO1 = BO->getOperand(1); Type *Ty = BO->getType(); switch (BO->getOpcode()) { - case Instruction::Shl: { - // shl X, C --> mul X, (1 << C) - Constant *C; - if (match(BO1, m_Constant(C))) { - Constant *ShlOne = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C); - return { Instruction::Mul, BO0, ShlOne }; - } - break; - } - case Instruction::Or: { - // or X, C --> add X, C (when X and C have no common bits set) - const APInt *C; - if (match(BO1, m_APInt(C)) && MaskedValueIsZero(BO0, *C, DL)) - return { Instruction::Add, BO0, BO1 }; - break; + case Instruction::Shl: { + // shl X, C --> mul X, (1 << C) + Constant *C; + if (match(BO1, m_Constant(C))) { + Constant *ShlOne = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C); + return {Instruction::Mul, BO0, ShlOne}; } - default: - break; + break; + } + case Instruction::Or: { + // or X, C --> add X, C (when X and C have no common bits set) + const APInt *C; + if (match(BO1, m_APInt(C)) && MaskedValueIsZero(BO0, *C, DL)) + return {Instruction::Add, BO0, BO1}; + break; + } + case Instruction::Sub: + // sub 0, X --> mul X, -1 + if (match(BO0, m_ZeroInt())) + return {Instruction::Mul, BO1, ConstantInt::getAllOnesValue(Ty)}; + break; + default: + break; } return {}; } @@ -2053,15 +2057,20 @@ Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) { !match(Shuf.getOperand(1), m_BinOp(B1))) return nullptr; + // If one operand is "0 - X", allow that to be viewed as "X * -1" + // (ConstantsAreOp1) by getAlternateBinop below. If the neg is not paired + // with a multiply, we will exit because C0/C1 will not be set. Value *X, *Y; - Constant *C0, *C1; + Constant *C0 = nullptr, *C1 = nullptr; bool ConstantsAreOp1; - if (match(B0, m_BinOp(m_Value(X), m_Constant(C0))) && - match(B1, m_BinOp(m_Value(Y), m_Constant(C1)))) - ConstantsAreOp1 = true; - else if (match(B0, m_BinOp(m_Constant(C0), m_Value(X))) && - match(B1, m_BinOp(m_Constant(C1), m_Value(Y)))) + if (match(B0, m_BinOp(m_Constant(C0), m_Value(X))) && + match(B1, m_BinOp(m_Constant(C1), m_Value(Y)))) ConstantsAreOp1 = false; + else if (match(B0, m_CombineOr(m_BinOp(m_Value(X), m_Constant(C0)), + m_Neg(m_Value(X)))) && + match(B1, m_CombineOr(m_BinOp(m_Value(Y), m_Constant(C1)), + m_Neg(m_Value(Y))))) + ConstantsAreOp1 = true; else return nullptr; @@ -2086,7 +2095,7 @@ Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) { } } - if (Opc0 != Opc1) + if (Opc0 != Opc1 || !C0 || !C1) return nullptr; // The opcodes must be the same. Use a new name to make that clear. @@ -2233,6 +2242,88 @@ static Instruction *narrowVectorSelect(ShuffleVectorInst &Shuf, return SelectInst::Create(NarrowCond, NarrowX, NarrowY); } +/// Canonicalize FP negate after shuffle. 
+static Instruction *foldFNegShuffle(ShuffleVectorInst &Shuf, + InstCombiner::BuilderTy &Builder) { + Instruction *FNeg0; + Value *X; + if (!match(Shuf.getOperand(0), m_CombineAnd(m_Instruction(FNeg0), + m_FNeg(m_Value(X))))) + return nullptr; + + // shuffle (fneg X), Mask --> fneg (shuffle X, Mask) + if (FNeg0->hasOneUse() && match(Shuf.getOperand(1), m_Undef())) { + Value *NewShuf = Builder.CreateShuffleVector(X, Shuf.getShuffleMask()); + return UnaryOperator::CreateFNegFMF(NewShuf, FNeg0); + } + + Instruction *FNeg1; + Value *Y; + if (!match(Shuf.getOperand(1), m_CombineAnd(m_Instruction(FNeg1), + m_FNeg(m_Value(Y))))) + return nullptr; + + // shuffle (fneg X), (fneg Y), Mask --> fneg (shuffle X, Y, Mask) + if (FNeg0->hasOneUse() || FNeg1->hasOneUse()) { + Value *NewShuf = Builder.CreateShuffleVector(X, Y, Shuf.getShuffleMask()); + Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewShuf); + NewFNeg->copyIRFlags(FNeg0); + NewFNeg->andIRFlags(FNeg1); + return NewFNeg; + } + + return nullptr; +} + +/// Canonicalize casts after shuffle. +static Instruction *foldCastShuffle(ShuffleVectorInst &Shuf, + InstCombiner::BuilderTy &Builder) { + // Do we have 2 matching cast operands? + auto *Cast0 = dyn_cast<CastInst>(Shuf.getOperand(0)); + auto *Cast1 = dyn_cast<CastInst>(Shuf.getOperand(1)); + if (!Cast0 || !Cast1 || Cast0->getOpcode() != Cast1->getOpcode() || + Cast0->getSrcTy() != Cast1->getSrcTy()) + return nullptr; + + // TODO: Allow other opcodes? That would require easing the type restrictions + // below here. + CastInst::CastOps CastOpcode = Cast0->getOpcode(); + switch (CastOpcode) { + case Instruction::FPToSI: + case Instruction::FPToUI: + case Instruction::SIToFP: + case Instruction::UIToFP: + break; + default: + return nullptr; + } + + VectorType *ShufTy = Shuf.getType(); + VectorType *ShufOpTy = cast<VectorType>(Shuf.getOperand(0)->getType()); + VectorType *CastSrcTy = cast<VectorType>(Cast0->getSrcTy()); + + // TODO: Allow length-increasing shuffles? + if (ShufTy->getElementCount().getKnownMinValue() > + ShufOpTy->getElementCount().getKnownMinValue()) + return nullptr; + + // TODO: Allow element-size-decreasing casts (ex: fptosi float to i8)? + assert(isa<FixedVectorType>(CastSrcTy) && isa<FixedVectorType>(ShufOpTy) && + "Expected fixed vector operands for casts and binary shuffle"); + if (CastSrcTy->getPrimitiveSizeInBits() > ShufOpTy->getPrimitiveSizeInBits()) + return nullptr; + + // At least one of the operands must have only one use (the shuffle). + if (!Cast0->hasOneUse() && !Cast1->hasOneUse()) + return nullptr; + + // shuffle (cast X), (cast Y), Mask --> cast (shuffle X, Y, Mask) + Value *X = Cast0->getOperand(0); + Value *Y = Cast1->getOperand(0); + Value *NewShuf = Builder.CreateShuffleVector(X, Y, Shuf.getShuffleMask()); + return CastInst::Create(CastOpcode, NewShuf, ShufTy); +} + /// Try to fold an extract subvector operation. 
static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) { Value *Op0 = Shuf.getOperand(0), *Op1 = Shuf.getOperand(1); @@ -2442,7 +2533,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { Value *LHS = SVI.getOperand(0); Value *RHS = SVI.getOperand(1); SimplifyQuery ShufQuery = SQ.getWithInstruction(&SVI); - if (auto *V = SimplifyShuffleVectorInst(LHS, RHS, SVI.getShuffleMask(), + if (auto *V = simplifyShuffleVectorInst(LHS, RHS, SVI.getShuffleMask(), SVI.getType(), ShufQuery)) return replaceInstUsesWith(SVI, V); @@ -2497,7 +2588,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (!ScaledMask.empty()) { // If the shuffled source vector simplifies, cast that value to this // shuffle's type. - if (auto *V = SimplifyShuffleVectorInst(X, UndefValue::get(XType), + if (auto *V = simplifyShuffleVectorInst(X, UndefValue::get(XType), ScaledMask, XType, ShufQuery)) return BitCastInst::Create(Instruction::BitCast, V, SVI.getType()); } @@ -2528,6 +2619,12 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (Instruction *I = narrowVectorSelect(SVI, Builder)) return I; + if (Instruction *I = foldFNegShuffle(SVI, Builder)) + return I; + + if (Instruction *I = foldCastShuffle(SVI, Builder)) + return I; + APInt UndefElts(VWidth, 0); APInt AllOnesEltMask(APInt::getAllOnes(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&SVI, AllOnesEltMask, UndefElts)) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 3091905ca534..0816a4a575d9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -42,7 +42,6 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/TinyPtrVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" @@ -60,6 +59,7 @@ #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/BasicBlock.h" @@ -90,8 +90,6 @@ #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/Support/CBindingWrapping.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" @@ -140,6 +138,10 @@ static cl::opt<bool> EnableCodeSinking("instcombine-code-sinking", cl::desc("Enable code sinking"), cl::init(true)); +static cl::opt<unsigned> MaxSinkNumUsers( + "instcombine-max-sink-users", cl::init(32), + cl::desc("Maximum number of undroppable users for instruction sinking")); + static cl::opt<unsigned> LimitMaxIterations( "instcombine-max-iterations", cl::desc("Limit the maximum number of instruction combining iterations"), @@ -424,7 +426,7 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) { Value *C = I.getOperand(1); // Does "B op C" simplify? - if (Value *V = SimplifyBinOp(Opcode, B, C, SQ.getWithInstruction(&I))) { + if (Value *V = simplifyBinOp(Opcode, B, C, SQ.getWithInstruction(&I))) { // It simplifies to V. Form "A op V". 
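The reassociation step described just above only pays off when "B op C" folds to something simpler, in which case the whole expression collapses to "A op V". A standalone C++ sketch of that shape, using integer add where "B op C" simplifies to zero (illustrative only, not from the commit):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t a = 0; a < 100; ++a)
    for (uint32_t b = 0; b < 100; ++b) {
      uint32_t v = b + (0u - b);                 // "B op C" simplifies to 0
      assert(((a + b) + (0u - b)) == (a + v));   // (A op B) op C == A op V
      assert((a + v) == a);                      // and A op V folds to A
    }
  return 0;
}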
replaceOperand(I, 0, A); replaceOperand(I, 1, V); @@ -457,7 +459,7 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) { Value *C = Op1->getOperand(1); // Does "A op B" simplify? - if (Value *V = SimplifyBinOp(Opcode, A, B, SQ.getWithInstruction(&I))) { + if (Value *V = simplifyBinOp(Opcode, A, B, SQ.getWithInstruction(&I))) { // It simplifies to V. Form "V op C". replaceOperand(I, 0, V); replaceOperand(I, 1, C); @@ -485,7 +487,7 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) { Value *C = I.getOperand(1); // Does "C op A" simplify? - if (Value *V = SimplifyBinOp(Opcode, C, A, SQ.getWithInstruction(&I))) { + if (Value *V = simplifyBinOp(Opcode, C, A, SQ.getWithInstruction(&I))) { // It simplifies to V. Form "V op B". replaceOperand(I, 0, V); replaceOperand(I, 1, B); @@ -505,7 +507,7 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) { Value *C = Op1->getOperand(1); // Does "C op A" simplify? - if (Value *V = SimplifyBinOp(Opcode, C, A, SQ.getWithInstruction(&I))) { + if (Value *V = simplifyBinOp(Opcode, C, A, SQ.getWithInstruction(&I))) { // It simplifies to V. Form "B op V". replaceOperand(I, 0, B); replaceOperand(I, 1, V); @@ -652,7 +654,7 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I, std::swap(C, D); // Consider forming "A op' (B op D)". // If "B op D" simplifies then it can be formed with no cost. - V = SimplifyBinOp(TopLevelOpcode, B, D, SQ.getWithInstruction(&I)); + V = simplifyBinOp(TopLevelOpcode, B, D, SQ.getWithInstruction(&I)); // If "B op D" doesn't simplify then only go on if both of the existing // operations "A op' B" and "C op' D" will be zapped as no longer used. if (!V && LHS->hasOneUse() && RHS->hasOneUse()) @@ -671,7 +673,7 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I, std::swap(C, D); // Consider forming "(A op C) op' B". // If "A op C" simplifies then it can be formed with no cost. - V = SimplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I)); + V = simplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I)); // If "A op C" doesn't simplify then only go on if both of the existing // operations "A op' B" and "C op' D" will be zapped as no longer used. @@ -780,8 +782,8 @@ Value *InstCombinerImpl::SimplifyUsingDistributiveLaws(BinaryOperator &I) { // Disable the use of undef because it's not safe to distribute undef. auto SQDistributive = SQ.getWithInstruction(&I).getWithoutUndef(); - Value *L = SimplifyBinOp(TopLevelOpcode, A, C, SQDistributive); - Value *R = SimplifyBinOp(TopLevelOpcode, B, C, SQDistributive); + Value *L = simplifyBinOp(TopLevelOpcode, A, C, SQDistributive); + Value *R = simplifyBinOp(TopLevelOpcode, B, C, SQDistributive); // Do "A op C" and "B op C" both simplify? if (L && R) { @@ -819,8 +821,8 @@ Value *InstCombinerImpl::SimplifyUsingDistributiveLaws(BinaryOperator &I) { // Disable the use of undef because it's not safe to distribute undef. auto SQDistributive = SQ.getWithInstruction(&I).getWithoutUndef(); - Value *L = SimplifyBinOp(TopLevelOpcode, A, B, SQDistributive); - Value *R = SimplifyBinOp(TopLevelOpcode, A, C, SQDistributive); + Value *L = simplifyBinOp(TopLevelOpcode, A, B, SQDistributive); + Value *R = simplifyBinOp(TopLevelOpcode, A, C, SQDistributive); // Do "A op B" and "A op C" both simplify? if (L && R) { @@ -876,8 +878,8 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, if (LHSIsSelect && RHSIsSelect && A == D) { // (A ? B : C) op (A ? E : F) -> A ? 
(B op E) : (C op F) Cond = A; - True = SimplifyBinOp(Opcode, B, E, FMF, Q); - False = SimplifyBinOp(Opcode, C, F, FMF, Q); + True = simplifyBinOp(Opcode, B, E, FMF, Q); + False = simplifyBinOp(Opcode, C, F, FMF, Q); if (LHS->hasOneUse() && RHS->hasOneUse()) { if (False && !True) @@ -888,13 +890,13 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, } else if (LHSIsSelect && LHS->hasOneUse()) { // (A ? B : C) op Y -> A ? (B op Y) : (C op Y) Cond = A; - True = SimplifyBinOp(Opcode, B, RHS, FMF, Q); - False = SimplifyBinOp(Opcode, C, RHS, FMF, Q); + True = simplifyBinOp(Opcode, B, RHS, FMF, Q); + False = simplifyBinOp(Opcode, C, RHS, FMF, Q); } else if (RHSIsSelect && RHS->hasOneUse()) { // X op (D ? E : F) -> D ? (X op E) : (X op F) Cond = D; - True = SimplifyBinOp(Opcode, LHS, E, FMF, Q); - False = SimplifyBinOp(Opcode, LHS, F, FMF, Q); + True = simplifyBinOp(Opcode, LHS, E, FMF, Q); + False = simplifyBinOp(Opcode, LHS, F, FMF, Q); } if (!True || !False) @@ -986,8 +988,8 @@ Instruction *InstCombinerImpl::foldBinopOfSextBoolToSelect(BinaryOperator &BO) { // bo (sext i1 X), C --> select X, (bo -1, C), (bo 0, C) Constant *Ones = ConstantInt::getAllOnesValue(BO.getType()); Constant *Zero = ConstantInt::getNullValue(BO.getType()); - Constant *TVal = ConstantExpr::get(BO.getOpcode(), Ones, C); - Constant *FVal = ConstantExpr::get(BO.getOpcode(), Zero, C); + Value *TVal = Builder.CreateBinOp(BO.getOpcode(), Ones, C); + Value *FVal = Builder.CreateBinOp(BO.getOpcode(), Zero, C); return SelectInst::Create(X, TVal, FVal); } @@ -1018,12 +1020,6 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, bool ConstIsRHS = isa<Constant>(I.getOperand(1)); Constant *ConstOperand = cast<Constant>(I.getOperand(ConstIsRHS)); - if (auto *SOC = dyn_cast<Constant>(SO)) { - if (ConstIsRHS) - return ConstantExpr::get(I.getOpcode(), SOC, ConstOperand); - return ConstantExpr::get(I.getOpcode(), ConstOperand, SOC); - } - Value *Op0 = SO, *Op1 = ConstOperand; if (!ConstIsRHS) std::swap(Op0, Op1); @@ -1035,10 +1031,10 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, return NewBO; } -Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, - SelectInst *SI) { - // Don't modify shared select instructions. - if (!SI->hasOneUse()) +Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, + bool FoldWithMultiUse) { + // Don't modify shared select instructions unless set FoldWithMultiUse + if (!SI->hasOneUse() && !FoldWithMultiUse) return nullptr; Value *TV = SI->getTrueValue(); @@ -1114,12 +1110,6 @@ static Value *foldOperationIntoPhiValue(BinaryOperator *I, Value *InV, bool ConstIsRHS = isa<Constant>(I->getOperand(1)); Constant *C = cast<Constant>(I->getOperand(ConstIsRHS)); - if (auto *InC = dyn_cast<Constant>(InV)) { - if (ConstIsRHS) - return ConstantExpr::get(I->getOpcode(), InC, C); - return ConstantExpr::get(I->getOpcode(), C, InC); - } - Value *Op0 = InV, *Op1 = C; if (!ConstIsRHS) std::swap(Op0, Op1); @@ -1175,10 +1165,11 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { if (cast<Instruction>(InVal)->getParent() == NonConstBB) return nullptr; - // If the incoming non-constant value is in I's block, we will remove one - // instruction, but insert another equivalent one, leading to infinite - // instcombine. 
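Several of the folds above (distributing a binary operator over a select, and rewriting a binop of a sign-extended i1 into a select) reduce to case analysis on the condition. A standalone C++ sketch with add and mul as sample operators (illustrative only, not from the commit):

#include <cassert>
#include <cstdint>

int main() {
  const int32_t C = 37, B = 11, D = -4, Y = 5;
  for (int b = 0; b <= 1; ++b) {
    int32_t sextB = b ? -1 : 0;                       // sext i1 X
    // bo (sext i1 X), C --> select X, (bo -1, C), (bo 0, C), with bo = add
    assert(sextB + C == (b ? (-1 + C) : (0 + C)));
    // (A ? B : D) op Y --> A ? (B op Y) : (D op Y), with op = mul
    assert(((b ? B : D) * Y) == (b ? B * Y : D * Y));
  }
  return 0;
}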
- if (isPotentiallyReachable(I.getParent(), NonConstBB, nullptr, &DT, LI)) + // If the incoming non-constant value is reachable from the phis block, + // we'll push the operation across a loop backedge. This could result in + // an infinite combine loop, and is generally non-profitable (especially + // if the operation was originally outside the loop). + if (isPotentiallyReachable(PN->getParent(), NonConstBB, nullptr, &DT, LI)) return nullptr; } @@ -1941,10 +1932,8 @@ static Instruction *foldSelectGEP(GetElementPtrInst &GEP, SmallVector<Value *, 4> IndexC(GEP.indices()); bool IsInBounds = GEP.isInBounds(); Type *Ty = GEP.getSourceElementType(); - Value *NewTrueC = IsInBounds ? Builder.CreateInBoundsGEP(Ty, TrueC, IndexC) - : Builder.CreateGEP(Ty, TrueC, IndexC); - Value *NewFalseC = IsInBounds ? Builder.CreateInBoundsGEP(Ty, FalseC, IndexC) - : Builder.CreateGEP(Ty, FalseC, IndexC); + Value *NewTrueC = Builder.CreateGEP(Ty, TrueC, IndexC, "", IsInBounds); + Value *NewFalseC = Builder.CreateGEP(Ty, FalseC, IndexC, "", IsInBounds); return SelectInst::Create(Cond, NewTrueC, NewFalseC, "", nullptr, Sel); } @@ -1953,13 +1942,11 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, // Combine Indices - If the source pointer to this getelementptr instruction // is a getelementptr instruction with matching element type, combine the // indices of the two getelementptr instructions into a single instruction. - if (Src->getResultElementType() != GEP.getSourceElementType()) - return nullptr; - if (!shouldMergeGEPs(*cast<GEPOperator>(&GEP), *Src)) return nullptr; - if (Src->getNumOperands() == 2 && GEP.getNumOperands() == 2 && + if (Src->getResultElementType() == GEP.getSourceElementType() && + Src->getNumOperands() == 2 && GEP.getNumOperands() == 2 && Src->hasOneUse()) { Value *GO1 = GEP.getOperand(1); Value *SO1 = Src->getOperand(1); @@ -1971,45 +1958,21 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, // invariant: this breaks the dependence between GEPs and allows LICM // to hoist the invariant part out of the loop. if (L->isLoopInvariant(GO1) && !L->isLoopInvariant(SO1)) { - // We have to be careful here. - // We have something like: - // %src = getelementptr <ty>, <ty>* %base, <ty> %idx - // %gep = getelementptr <ty>, <ty>* %src, <ty> %idx2 - // If we just swap idx & idx2 then we could inadvertantly - // change %src from a vector to a scalar, or vice versa. - // Cases: - // 1) %base a scalar & idx a scalar & idx2 a vector - // => Swapping idx & idx2 turns %src into a vector type. - // 2) %base a scalar & idx a vector & idx2 a scalar - // => Swapping idx & idx2 turns %src in a scalar type - // 3) %base, %idx, and %idx2 are scalars - // => %src & %gep are scalars - // => swapping idx & idx2 is safe - // 4) %base a vector - // => %src is a vector - // => swapping idx & idx2 is safe. - auto *SO0 = Src->getOperand(0); - auto *SO0Ty = SO0->getType(); - if (!isa<VectorType>(GEP.getType()) || // case 3 - isa<VectorType>(SO0Ty)) { // case 4 - Src->setOperand(1, GO1); - GEP.setOperand(1, SO1); - return &GEP; - } else { - // Case 1 or 2 - // -- have to recreate %src & %gep - // put NewSrc at same location as %src - Builder.SetInsertPoint(cast<Instruction>(Src)); - Value *NewSrc = Builder.CreateGEP( - GEP.getSourceElementType(), SO0, GO1, Src->getName()); - // Propagate 'inbounds' if the new source was not constant-folded. 
- if (auto *NewSrcGEPI = dyn_cast<GetElementPtrInst>(NewSrc)) - NewSrcGEPI->setIsInBounds(Src->isInBounds()); - GetElementPtrInst *NewGEP = GetElementPtrInst::Create( - GEP.getSourceElementType(), NewSrc, {SO1}); - NewGEP->setIsInBounds(GEP.isInBounds()); - return NewGEP; - } + // The swapped GEPs are inbounds if both original GEPs are inbounds + // and the sign of the offsets is the same. For simplicity, only + // handle both offsets being non-negative. + bool IsInBounds = Src->isInBounds() && GEP.isInBounds() && + isKnownNonNegative(SO1, DL, 0, &AC, &GEP, &DT) && + isKnownNonNegative(GO1, DL, 0, &AC, &GEP, &DT); + // Put NewSrc at same location as %src. + Builder.SetInsertPoint(cast<Instruction>(Src)); + Value *NewSrc = Builder.CreateGEP(GEP.getSourceElementType(), + Src->getPointerOperand(), GO1, + Src->getName(), IsInBounds); + GetElementPtrInst *NewGEP = GetElementPtrInst::Create( + GEP.getSourceElementType(), NewSrc, {SO1}); + NewGEP->setIsInBounds(IsInBounds); + return NewGEP; } } } @@ -2022,6 +1985,87 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, if (SrcGEP->getNumOperands() == 2 && shouldMergeGEPs(*Src, *SrcGEP)) return nullptr; // Wait until our source is folded to completion. + // For constant GEPs, use a more general offset-based folding approach. + // Only do this for opaque pointers, as the result element type may change. + Type *PtrTy = Src->getType()->getScalarType(); + if (PtrTy->isOpaquePointerTy() && GEP.hasAllConstantIndices() && + (Src->hasOneUse() || Src->hasAllConstantIndices())) { + // Split Src into a variable part and a constant suffix. + gep_type_iterator GTI = gep_type_begin(*Src); + Type *BaseType = GTI.getIndexedType(); + bool IsFirstType = true; + unsigned NumVarIndices = 0; + for (auto Pair : enumerate(Src->indices())) { + if (!isa<ConstantInt>(Pair.value())) { + BaseType = GTI.getIndexedType(); + IsFirstType = false; + NumVarIndices = Pair.index() + 1; + } + ++GTI; + } + + // Determine the offset for the constant suffix of Src. + APInt Offset(DL.getIndexTypeSizeInBits(PtrTy), 0); + if (NumVarIndices != Src->getNumIndices()) { + // FIXME: getIndexedOffsetInType() does not handled scalable vectors. + if (isa<ScalableVectorType>(BaseType)) + return nullptr; + + SmallVector<Value *> ConstantIndices; + if (!IsFirstType) + ConstantIndices.push_back( + Constant::getNullValue(Type::getInt32Ty(GEP.getContext()))); + append_range(ConstantIndices, drop_begin(Src->indices(), NumVarIndices)); + Offset += DL.getIndexedOffsetInType(BaseType, ConstantIndices); + } + + // Add the offset for GEP (which is fully constant). + if (!GEP.accumulateConstantOffset(DL, Offset)) + return nullptr; + + APInt OffsetOld = Offset; + // Convert the total offset back into indices. + SmallVector<APInt> ConstIndices = + DL.getGEPIndicesForOffset(BaseType, Offset); + if (!Offset.isZero() || (!IsFirstType && !ConstIndices[0].isZero())) { + // If both GEP are constant-indexed, and cannot be merged in either way, + // convert them to a GEP of i8. + if (Src->hasAllConstantIndices()) + return isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)) + ? 
GetElementPtrInst::CreateInBounds( + Builder.getInt8Ty(), Src->getOperand(0), + Builder.getInt(OffsetOld), GEP.getName()) + : GetElementPtrInst::Create( + Builder.getInt8Ty(), Src->getOperand(0), + Builder.getInt(OffsetOld), GEP.getName()); + return nullptr; + } + + bool IsInBounds = isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)); + SmallVector<Value *> Indices; + append_range(Indices, drop_end(Src->indices(), + Src->getNumIndices() - NumVarIndices)); + for (const APInt &Idx : drop_begin(ConstIndices, !IsFirstType)) { + Indices.push_back(ConstantInt::get(GEP.getContext(), Idx)); + // Even if the total offset is inbounds, we may end up representing it + // by first performing a larger negative offset, and then a smaller + // positive one. The large negative offset might go out of bounds. Only + // preserve inbounds if all signs are the same. + IsInBounds &= Idx.isNonNegative() == ConstIndices[0].isNonNegative(); + } + + return IsInBounds + ? GetElementPtrInst::CreateInBounds(Src->getSourceElementType(), + Src->getOperand(0), Indices, + GEP.getName()) + : GetElementPtrInst::Create(Src->getSourceElementType(), + Src->getOperand(0), Indices, + GEP.getName()); + } + + if (Src->getResultElementType() != GEP.getSourceElementType()) + return nullptr; + SmallVector<Value*, 8> Indices; // Find out whether the last index in the source GEP is a sequential idx. @@ -2045,7 +2089,7 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, return nullptr; Value *Sum = - SimplifyAddInst(GO1, SO1, false, false, SQ.getWithInstruction(&GEP)); + simplifyAddInst(GO1, SO1, false, false, SQ.getWithInstruction(&GEP)); // Only do the combine when we are sure the cost after the // merge is never more than that before the merge. if (Sum == nullptr) @@ -2116,9 +2160,8 @@ Instruction *InstCombinerImpl::visitGEPOfBitcast(BitCastInst *BCI, // existing GEP Value. Causing issues if this Value is accessed when // constructing an AddrSpaceCastInst SmallVector<Value *, 8> Indices(GEP.indices()); - Value *NGEP = GEP.isInBounds() - ? Builder.CreateInBoundsGEP(SrcEltType, SrcOp, Indices) - : Builder.CreateGEP(SrcEltType, SrcOp, Indices); + Value *NGEP = + Builder.CreateGEP(SrcEltType, SrcOp, Indices, "", GEP.isInBounds()); NGEP->takeName(&GEP); // Preserve GEP address space to satisfy users @@ -2169,12 +2212,10 @@ Instruction *InstCombinerImpl::visitGEPOfBitcast(BitCastInst *BCI, // Otherwise, if the offset is non-zero, we need to find out if there is a // field at Offset in 'A's type. If so, we can pull the cast through the // GEP. - SmallVector<Value*, 8> NewIndices; + SmallVector<Value *, 8> NewIndices; if (findElementAtOffset(SrcType, Offset.getSExtValue(), NewIndices, DL)) { - Value *NGEP = - GEP.isInBounds() - ? 
Builder.CreateInBoundsGEP(SrcEltType, SrcOp, NewIndices) - : Builder.CreateGEP(SrcEltType, SrcOp, NewIndices); + Value *NGEP = Builder.CreateGEP(SrcEltType, SrcOp, NewIndices, "", + GEP.isInBounds()); if (NGEP->getType() == GEP.getType()) return replaceInstUsesWith(GEP, NGEP); @@ -2195,7 +2236,7 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { Type *GEPType = GEP.getType(); Type *GEPEltType = GEP.getSourceElementType(); bool IsGEPSrcEleScalable = isa<ScalableVectorType>(GEPEltType); - if (Value *V = SimplifyGEPInst(GEPEltType, PtrOp, Indices, GEP.isInBounds(), + if (Value *V = simplifyGEPInst(GEPEltType, PtrOp, Indices, GEP.isInBounds(), SQ.getWithInstruction(&GEP))) return replaceInstUsesWith(GEP, V); @@ -2280,7 +2321,8 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { for (auto I = PN->op_begin()+1, E = PN->op_end(); I !=E; ++I) { auto *Op2 = dyn_cast<GetElementPtrInst>(*I); - if (!Op2 || Op1->getNumOperands() != Op2->getNumOperands()) + if (!Op2 || Op1->getNumOperands() != Op2->getNumOperands() || + Op1->getSourceElementType() != Op2->getSourceElementType()) return nullptr; // As for Op1 above, don't try to fold a GEP into itself. @@ -2476,11 +2518,8 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { // addrspacecast i8 addrspace(1)* %0 to i8* SmallVector<Value *, 8> Idx(GEP.indices()); Value *NewGEP = - GEP.isInBounds() - ? Builder.CreateInBoundsGEP(StrippedPtrEltTy, StrippedPtr, - Idx, GEP.getName()) - : Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Idx, - GEP.getName()); + Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Idx, + GEP.getName(), GEP.isInBounds()); return new AddrSpaceCastInst(NewGEP, GEPType); } } @@ -2495,13 +2534,9 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()) == DL.getTypeAllocSize(GEPEltType)) { Type *IdxType = DL.getIndexType(GEPType); - Value *Idx[2] = { Constant::getNullValue(IdxType), GEP.getOperand(1) }; - Value *NewGEP = - GEP.isInBounds() - ? Builder.CreateInBoundsGEP(StrippedPtrEltTy, StrippedPtr, Idx, - GEP.getName()) - : Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Idx, - GEP.getName()); + Value *Idx[2] = {Constant::getNullValue(IdxType), GEP.getOperand(1)}; + Value *NewGEP = Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Idx, + GEP.getName(), GEP.isInBounds()); // V and GEP are both pointer types --> BitCast return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, GEPType); @@ -2533,11 +2568,8 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { // If the multiplication NewIdx * Scale may overflow then the new // GEP may not be "inbounds". Value *NewGEP = - GEP.isInBounds() && NSW - ? Builder.CreateInBoundsGEP(StrippedPtrEltTy, StrippedPtr, - NewIdx, GEP.getName()) - : Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, NewIdx, - GEP.getName()); + Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, NewIdx, + GEP.getName(), GEP.isInBounds() && NSW); // The NewGEP must be pointer typed, so must the old one -> BitCast return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, @@ -2578,11 +2610,8 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { Value *Off[2] = {Constant::getNullValue(IndTy), NewIdx}; Value *NewGEP = - GEP.isInBounds() && NSW - ? 
Builder.CreateInBoundsGEP(StrippedPtrEltTy, StrippedPtr, - Off, GEP.getName()) - : Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Off, - GEP.getName()); + Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Off, + GEP.getName(), GEP.isInBounds() && NSW); // The NewGEP must be pointer typed, so must the old one -> BitCast return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, GEPType); @@ -2672,6 +2701,7 @@ static bool isAllocSiteRemovable(Instruction *AI, SmallVectorImpl<WeakTrackingVH> &Users, const TargetLibraryInfo &TLI) { SmallVector<Instruction*, 4> Worklist; + const Optional<StringRef> Family = getAllocationFamily(AI, &TLI); Worklist.push_back(AI); do { @@ -2740,12 +2770,15 @@ static bool isAllocSiteRemovable(Instruction *AI, continue; } - if (isFreeCall(I, &TLI)) { + if (isFreeCall(I, &TLI) && getAllocationFamily(I, &TLI) == Family) { + assert(Family); Users.emplace_back(I); continue; } - if (isReallocLikeFn(I, &TLI)) { + if (isReallocLikeFn(I, &TLI) && + getAllocationFamily(I, &TLI) == Family) { + assert(Family); Users.emplace_back(I); Worklist.push_back(I); continue; @@ -2803,7 +2836,7 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { if (II->getIntrinsicID() == Intrinsic::objectsize) { Value *Result = - lowerObjectSizeCall(II, DL, &TLI, /*MustSucceed=*/true); + lowerObjectSizeCall(II, DL, &TLI, AA, /*MustSucceed=*/true); replaceInstUsesWith(*I, Result); eraseInstFromFunction(*I); Users[i] = nullptr; // Skip examining in the next loop. @@ -3192,7 +3225,7 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { if (!EV.hasIndices()) return replaceInstUsesWith(EV, Agg); - if (Value *V = SimplifyExtractValueInst(Agg, EV.getIndices(), + if (Value *V = simplifyExtractValueInst(Agg, EV.getIndices(), SQ.getWithInstruction(&EV))) return replaceInstUsesWith(EV, V); @@ -3248,6 +3281,15 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { makeArrayRef(exti, exte)); } if (WithOverflowInst *WO = dyn_cast<WithOverflowInst>(Agg)) { + // extractvalue (any_mul_with_overflow X, -1), 0 --> -X + Intrinsic::ID OvID = WO->getIntrinsicID(); + if (*EV.idx_begin() == 0 && + (OvID == Intrinsic::smul_with_overflow || + OvID == Intrinsic::umul_with_overflow) && + match(WO->getArgOperand(1), m_AllOnes())) { + return BinaryOperator::CreateNeg(WO->getArgOperand(0)); + } + // We're extracting from an overflow intrinsic, see if we're the only user, // which allows us to simplify multiple result intrinsics to simpler // things that just get one value. @@ -3723,21 +3765,116 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) { if (!MaybePoisonOperand) return OrigOp; - auto *FrozenMaybePoisonOperand = new FreezeInst( + Builder.SetInsertPoint(OrigOpInst); + auto *FrozenMaybePoisonOperand = Builder.CreateFreeze( MaybePoisonOperand->get(), MaybePoisonOperand->get()->getName() + ".fr"); replaceUse(*MaybePoisonOperand, FrozenMaybePoisonOperand); - FrozenMaybePoisonOperand->insertBefore(OrigOpInst); return OrigOp; } -bool InstCombinerImpl::freezeDominatedUses(FreezeInst &FI) { +Instruction *InstCombinerImpl::foldFreezeIntoRecurrence(FreezeInst &FI, + PHINode *PN) { + // Detect whether this is a recurrence with a start value and some number of + // backedge values. We'll check whether we can push the freeze through the + // backedge values (possibly dropping poison flags along the way) until we + // reach the phi again. 
In that case, we can move the freeze to the start + // value. + Use *StartU = nullptr; + SmallVector<Value *> Worklist; + for (Use &U : PN->incoming_values()) { + if (DT.dominates(PN->getParent(), PN->getIncomingBlock(U))) { + // Add backedge value to worklist. + Worklist.push_back(U.get()); + continue; + } + + // Don't bother handling multiple start values. + if (StartU) + return nullptr; + StartU = &U; + } + + if (!StartU || Worklist.empty()) + return nullptr; // Not a recurrence. + + Value *StartV = StartU->get(); + BasicBlock *StartBB = PN->getIncomingBlock(*StartU); + bool StartNeedsFreeze = !isGuaranteedNotToBeUndefOrPoison(StartV); + // We can't insert freeze if the the start value is the result of the + // terminator (e.g. an invoke). + if (StartNeedsFreeze && StartBB->getTerminator() == StartV) + return nullptr; + + SmallPtrSet<Value *, 32> Visited; + SmallVector<Instruction *> DropFlags; + while (!Worklist.empty()) { + Value *V = Worklist.pop_back_val(); + if (!Visited.insert(V).second) + continue; + + if (Visited.size() > 32) + return nullptr; // Limit the total number of values we inspect. + + // Assume that PN is non-poison, because it will be after the transform. + if (V == PN || isGuaranteedNotToBeUndefOrPoison(V)) + continue; + + Instruction *I = dyn_cast<Instruction>(V); + if (!I || canCreateUndefOrPoison(cast<Operator>(I), + /*ConsiderFlags*/ false)) + return nullptr; + + DropFlags.push_back(I); + append_range(Worklist, I->operands()); + } + + for (Instruction *I : DropFlags) + I->dropPoisonGeneratingFlags(); + + if (StartNeedsFreeze) { + Builder.SetInsertPoint(StartBB->getTerminator()); + Value *FrozenStartV = Builder.CreateFreeze(StartV, + StartV->getName() + ".fr"); + replaceUse(*StartU, FrozenStartV); + } + return replaceInstUsesWith(FI, PN); +} + +bool InstCombinerImpl::freezeOtherUses(FreezeInst &FI) { Value *Op = FI.getOperand(0); - if (isa<Constant>(Op)) + if (isa<Constant>(Op) || Op->hasOneUse()) return false; + // Move the freeze directly after the definition of its operand, so that + // it dominates the maximum number of uses. Note that it may not dominate + // *all* uses if the operand is an invoke/callbr and the use is in a phi on + // the normal/default destination. This is why the domination check in the + // replacement below is still necessary. 
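// A minimal sketch of the recurrence fold above (illustrative, not part of
// this patch), assuming a plain induction variable:
//   loop:
//     %iv      = phi i32 [ %start, %entry ], [ %iv.next, %loop ]
//     %iv.next = add nsw i32 %iv, 1
//     %iv.fr   = freeze i32 %iv
// becomes, once the poison-generating nsw flag on the backedge is dropped:
//   entry:
//     %start.fr = freeze i32 %start
//   loop:
//     %iv      = phi i32 [ %start.fr, %entry ], [ %iv.next, %loop ]
//     %iv.next = add i32 %iv, 1
// and the original freeze is simply replaced by %iv.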
+ Instruction *MoveBefore = nullptr; + if (isa<Argument>(Op)) { + MoveBefore = &FI.getFunction()->getEntryBlock().front(); + while (isa<AllocaInst>(MoveBefore)) + MoveBefore = MoveBefore->getNextNode(); + } else if (auto *PN = dyn_cast<PHINode>(Op)) { + MoveBefore = PN->getParent()->getFirstNonPHI(); + } else if (auto *II = dyn_cast<InvokeInst>(Op)) { + MoveBefore = II->getNormalDest()->getFirstNonPHI(); + } else if (auto *CB = dyn_cast<CallBrInst>(Op)) { + MoveBefore = CB->getDefaultDest()->getFirstNonPHI(); + } else { + auto *I = cast<Instruction>(Op); + assert(!I->isTerminator() && "Cannot be a terminator"); + MoveBefore = I->getNextNode(); + } + bool Changed = false; + if (&FI != MoveBefore) { + FI.moveBefore(MoveBefore); + Changed = true; + } + Op->replaceUsesWithIf(&FI, [&](Use &U) -> bool { bool Dominates = DT.dominates(&FI, U); Changed |= Dominates; @@ -3750,48 +3887,63 @@ bool InstCombinerImpl::freezeDominatedUses(FreezeInst &FI) { Instruction *InstCombinerImpl::visitFreeze(FreezeInst &I) { Value *Op0 = I.getOperand(0); - if (Value *V = SimplifyFreezeInst(Op0, SQ.getWithInstruction(&I))) + if (Value *V = simplifyFreezeInst(Op0, SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); // freeze (phi const, x) --> phi const, (freeze x) if (auto *PN = dyn_cast<PHINode>(Op0)) { if (Instruction *NV = foldOpIntoPhi(I, PN)) return NV; + if (Instruction *NV = foldFreezeIntoRecurrence(I, PN)) + return NV; } if (Value *NI = pushFreezeToPreventPoisonFromPropagating(I)) return replaceInstUsesWith(I, NI); - if (match(Op0, m_Undef())) { - // If I is freeze(undef), see its uses and fold it to the best constant. - // - or: pick -1 - // - select's condition: pick the value that leads to choosing a constant - // - other ops: pick 0 + // If I is freeze(undef), check its uses and fold it to a fixed constant. + // - or: pick -1 + // - select's condition: if the true value is constant, choose it by making + // the condition true. + // - default: pick 0 + // + // Note that this transform is intentionally done here rather than + // via an analysis in InstSimplify or at individual user sites. That is + // because we must produce the same value for all uses of the freeze - + // it's the reason "freeze" exists! + // + // TODO: This could use getBinopAbsorber() / getBinopIdentity() to avoid + // duplicating logic for binops at least. + auto getUndefReplacement = [&I](Type *Ty) { Constant *BestValue = nullptr; - Constant *NullValue = Constant::getNullValue(I.getType()); + Constant *NullValue = Constant::getNullValue(Ty); for (const auto *U : I.users()) { Constant *C = NullValue; - if (match(U, m_Or(m_Value(), m_Value()))) - C = Constant::getAllOnesValue(I.getType()); - else if (const auto *SI = dyn_cast<SelectInst>(U)) { - if (SI->getCondition() == &I) { - APInt CondVal(1, isa<Constant>(SI->getFalseValue()) ? 
0 : 1); - C = Constant::getIntegerValue(I.getType(), CondVal); - } - } + C = ConstantInt::getAllOnesValue(Ty); + else if (match(U, m_Select(m_Specific(&I), m_Constant(), m_Value()))) + C = ConstantInt::getTrue(Ty); if (!BestValue) BestValue = C; else if (BestValue != C) BestValue = NullValue; } + assert(BestValue && "Must have at least one use"); + return BestValue; + }; - return replaceInstUsesWith(I, BestValue); + if (match(Op0, m_Undef())) + return replaceInstUsesWith(I, getUndefReplacement(I.getType())); + + Constant *C; + if (match(Op0, m_Constant(C)) && C->containsUndefOrPoisonElement()) { + Constant *ReplaceC = getUndefReplacement(I.getType()->getScalarType()); + return replaceInstUsesWith(I, Constant::replaceUndefsWith(C, ReplaceC)); } - // Replace all dominated uses of Op to freeze(Op). - if (freezeDominatedUses(I)) + // Replace uses of Op with freeze(Op). + if (freezeOtherUses(I)) return &I; return nullptr; @@ -3847,7 +3999,6 @@ static bool SoleWriteToDeadLocal(Instruction *I, TargetLibraryInfo &TLI) { /// block. static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock, TargetLibraryInfo &TLI) { - assert(I->getUniqueUndroppableUser() && "Invariants didn't hold!"); BasicBlock *SrcBlock = I->getParent(); // Cannot move control-flow-involving, volatile loads, vaarg, etc. @@ -4014,48 +4165,68 @@ bool InstCombinerImpl::run() { [this](Instruction *I) -> Optional<BasicBlock *> { if (!EnableCodeSinking) return None; - auto *UserInst = cast_or_null<Instruction>(I->getUniqueUndroppableUser()); - if (!UserInst) - return None; BasicBlock *BB = I->getParent(); BasicBlock *UserParent = nullptr; + unsigned NumUsers = 0; - // Special handling for Phi nodes - get the block the use occurs in. - if (PHINode *PN = dyn_cast<PHINode>(UserInst)) { - for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) { - if (PN->getIncomingValue(i) == I) { - // Bail out if we have uses in different blocks. We don't do any - // sophisticated analysis (i.e finding NearestCommonDominator of these - // use blocks). - if (UserParent && UserParent != PN->getIncomingBlock(i)) - return None; - UserParent = PN->getIncomingBlock(i); + for (auto *U : I->users()) { + if (U->isDroppable()) + continue; + if (NumUsers > MaxSinkNumUsers) + return None; + + Instruction *UserInst = cast<Instruction>(U); + // Special handling for Phi nodes - get the block the use occurs in. + if (PHINode *PN = dyn_cast<PHINode>(UserInst)) { + for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) { + if (PN->getIncomingValue(i) == I) { + // Bail out if we have uses in different blocks. We don't do any + // sophisticated analysis (i.e finding NearestCommonDominator of + // these use blocks). + if (UserParent && UserParent != PN->getIncomingBlock(i)) + return None; + UserParent = PN->getIncomingBlock(i); + } } + assert(UserParent && "expected to find user block!"); + } else { + if (UserParent && UserParent != UserInst->getParent()) + return None; + UserParent = UserInst->getParent(); } - assert(UserParent && "expected to find user block!"); - } else - UserParent = UserInst->getParent(); - // Try sinking to another block. If that block is unreachable, then do - // not bother. SimplifyCFG should handle it. - if (UserParent == BB || !DT.isReachableFromEntry(UserParent)) - return None; + // Make sure these checks are done only once, naturally we do the checks + // the first time we get the userparent, this will save compile time. + if (NumUsers == 0) { + // Try sinking to another block. 
If that block is unreachable, then do + // not bother. SimplifyCFG should handle it. + if (UserParent == BB || !DT.isReachableFromEntry(UserParent)) + return None; + + auto *Term = UserParent->getTerminator(); + // See if the user is one of our successors that has only one + // predecessor, so that we don't have to split the critical edge. + // Another option where we can sink is a block that ends with a + // terminator that does not pass control to other block (such as + // return or unreachable or resume). In this case: + // - I dominates the User (by SSA form); + // - the User will be executed at most once. + // So sinking I down to User is always profitable or neutral. + if (UserParent->getUniquePredecessor() != BB && !succ_empty(Term)) + return None; + + assert(DT.dominates(BB, UserParent) && "Dominance relation broken?"); + } - auto *Term = UserParent->getTerminator(); - // See if the user is one of our successors that has only one - // predecessor, so that we don't have to split the critical edge. - // Another option where we can sink is a block that ends with a - // terminator that does not pass control to other block (such as - // return or unreachable or resume). In this case: - // - I dominates the User (by SSA form); - // - the User will be executed at most once. - // So sinking I down to User is always profitable or neutral. - if (UserParent->getUniquePredecessor() == BB || succ_empty(Term)) { - assert(DT.dominates(BB, UserParent) && "Dominance relation broken?"); - return UserParent; + NumUsers++; } - return None; + + // No user or only has droppable users. + if (!UserParent) + return None; + + return UserParent; }; auto OptBB = getOptionalSinkBlockForInst(I); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp index 8f94172a6402..7a5a74aa4fff 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -31,6 +31,7 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/BinaryFormat/MachO.h" +#include "llvm/Demangle/Demangle.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -42,14 +43,12 @@ #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" -#include "llvm/IR/InstIterator.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" @@ -63,15 +62,12 @@ #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" #include "llvm/MC/MCSectionMachO.h" -#include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" -#include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Instrumentation/AddressSanitizerCommon.h" @@ -87,7 +83,6 @@ #include <cstdint> #include <iomanip> #include <limits> -#include <memory> #include <sstream> #include <string> #include <tuple> @@ 
-116,7 +111,7 @@ static const uint64_t kFreeBSDKasan_ShadowOffset64 = 0xdffff7c000000000; static const uint64_t kNetBSD_ShadowOffset32 = 1ULL << 30; static const uint64_t kNetBSD_ShadowOffset64 = 1ULL << 46; static const uint64_t kNetBSDKasan_ShadowOffset64 = 0xdfff900000000000; -static const uint64_t kPS4CPU_ShadowOffset64 = 1ULL << 40; +static const uint64_t kPS_ShadowOffset64 = 1ULL << 40; static const uint64_t kWindowsShadowOffset32 = 3ULL << 28; static const uint64_t kEmscriptenShadowOffset = 0; @@ -335,6 +330,11 @@ static cl::opt<std::string> ClMemoryAccessCallbackPrefix( cl::desc("Prefix for memory access callbacks"), cl::Hidden, cl::init("__asan_")); +static cl::opt<bool> ClKasanMemIntrinCallbackPrefix( + "asan-kernel-mem-intrinsic-prefix", + cl::desc("Use prefix for memory intrinsics in KASAN mode"), cl::Hidden, + cl::init(false)); + static cl::opt<bool> ClInstrumentDynamicAllocas("asan-instrument-dynamic-allocas", cl::desc("instrument dynamic allocas"), @@ -465,11 +465,12 @@ struct ShadowMapping { static ShadowMapping getShadowMapping(const Triple &TargetTriple, int LongSize, bool IsKasan) { bool IsAndroid = TargetTriple.isAndroid(); - bool IsIOS = TargetTriple.isiOS() || TargetTriple.isWatchOS(); + bool IsIOS = TargetTriple.isiOS() || TargetTriple.isWatchOS() || + TargetTriple.isDriverKit(); bool IsMacOS = TargetTriple.isMacOSX(); bool IsFreeBSD = TargetTriple.isOSFreeBSD(); bool IsNetBSD = TargetTriple.isOSNetBSD(); - bool IsPS4CPU = TargetTriple.isPS4CPU(); + bool IsPS = TargetTriple.isPS(); bool IsLinux = TargetTriple.isOSLinux(); bool IsPPC64 = TargetTriple.getArch() == Triple::ppc64 || TargetTriple.getArch() == Triple::ppc64le; @@ -528,8 +529,8 @@ static ShadowMapping getShadowMapping(const Triple &TargetTriple, int LongSize, Mapping.Offset = kNetBSDKasan_ShadowOffset64; else Mapping.Offset = kNetBSD_ShadowOffset64; - } else if (IsPS4CPU) - Mapping.Offset = kPS4CPU_ShadowOffset64; + } else if (IsPS) + Mapping.Offset = kPS_ShadowOffset64; else if (IsLinux && IsX86_64) { if (IsKasan) Mapping.Offset = kLinuxKasan_ShadowOffset64; @@ -568,7 +569,7 @@ static ShadowMapping getShadowMapping(const Triple &TargetTriple, int LongSize, // offset is not necessary 1/8-th of the address space. On SystemZ, // we could OR the constant in a single instruction, but it's more // efficient to load it once and use indexed addressing. - Mapping.OrShadowOffset = !IsAArch64 && !IsPPC64 && !IsSystemZ && !IsPS4CPU && + Mapping.OrShadowOffset = !IsAArch64 && !IsPPC64 && !IsSystemZ && !IsPS && !IsRISCV64 && !(Mapping.Offset & (Mapping.Offset - 1)) && Mapping.Offset != kDynamicShadowSentinel; @@ -621,41 +622,9 @@ static uint64_t GetCtorAndDtorPriority(Triple &TargetTriple) { namespace { -/// Module analysis for getting various metadata about the module. -class ASanGlobalsMetadataWrapperPass : public ModulePass { -public: - static char ID; - - ASanGlobalsMetadataWrapperPass() : ModulePass(ID) { - initializeASanGlobalsMetadataWrapperPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnModule(Module &M) override { - GlobalsMD = GlobalsMetadata(M); - return false; - } - - StringRef getPassName() const override { - return "ASanGlobalsMetadataWrapperPass"; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - } - - GlobalsMetadata &getGlobalsMD() { return GlobalsMD; } - -private: - GlobalsMetadata GlobalsMD; -}; - -char ASanGlobalsMetadataWrapperPass::ID = 0; - /// AddressSanitizer: instrument the code in module to find memory bugs. 
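// For orientation (a sketch, not something this change adds): the constants
// and flags chosen in getShadowMapping() feed the usual ASan mapping
//   Shadow = (Addr >> Mapping.Scale) + Mapping.Offset   // Scale is 3 by default
// so with kPS_ShadowOffset64 = 1ULL << 40 the shadow byte covering an 8-byte
// granule at Addr sits at (Addr >> 3) + 0x10000000000; OrShadowOffset only
// swaps the add for an or when the offset is a single set bit and the target
// permits it.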
struct AddressSanitizer { - AddressSanitizer(Module &M, const GlobalsMetadata *GlobalsMD, - const StackSafetyGlobalInfo *SSGI, + AddressSanitizer(Module &M, const StackSafetyGlobalInfo *SSGI, bool CompileKernel = false, bool Recover = false, bool UseAfterScope = false, AsanDetectStackUseAfterReturnMode UseAfterReturn = @@ -666,7 +635,7 @@ struct AddressSanitizer { UseAfterScope(UseAfterScope || ClUseAfterScope), UseAfterReturn(ClUseAfterReturn.getNumOccurrences() ? ClUseAfterReturn : UseAfterReturn), - GlobalsMD(*GlobalsMD), SSGI(SSGI) { + SSGI(SSGI) { C = &(M.getContext()); LongSize = M.getDataLayout().getPointerSizeInBits(); IntptrTy = Type::getIntNTy(*C, LongSize); @@ -779,7 +748,6 @@ private: FunctionCallee AsanMemmove, AsanMemcpy, AsanMemset; Value *LocalDynamicShadow = nullptr; - const GlobalsMetadata &GlobalsMD; const StackSafetyGlobalInfo *SSGI; DenseMap<const AllocaInst *, bool> ProcessedAllocas; @@ -787,60 +755,13 @@ private: FunctionCallee AMDGPUAddressPrivate; }; -class AddressSanitizerLegacyPass : public FunctionPass { -public: - static char ID; - - explicit AddressSanitizerLegacyPass( - bool CompileKernel = false, bool Recover = false, - bool UseAfterScope = false, - AsanDetectStackUseAfterReturnMode UseAfterReturn = - AsanDetectStackUseAfterReturnMode::Runtime) - : FunctionPass(ID), CompileKernel(CompileKernel), Recover(Recover), - UseAfterScope(UseAfterScope), UseAfterReturn(UseAfterReturn) { - initializeAddressSanitizerLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - StringRef getPassName() const override { - return "AddressSanitizerFunctionPass"; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<ASanGlobalsMetadataWrapperPass>(); - if (ClUseStackSafety) - AU.addRequired<StackSafetyGlobalInfoWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } - - bool runOnFunction(Function &F) override { - GlobalsMetadata &GlobalsMD = - getAnalysis<ASanGlobalsMetadataWrapperPass>().getGlobalsMD(); - const StackSafetyGlobalInfo *const SSGI = - ClUseStackSafety - ? &getAnalysis<StackSafetyGlobalInfoWrapperPass>().getResult() - : nullptr; - const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - AddressSanitizer ASan(*F.getParent(), &GlobalsMD, SSGI, CompileKernel, - Recover, UseAfterScope, UseAfterReturn); - return ASan.instrumentFunction(F, TLI); - } - -private: - bool CompileKernel; - bool Recover; - bool UseAfterScope; - AsanDetectStackUseAfterReturnMode UseAfterReturn; -}; - class ModuleAddressSanitizer { public: - ModuleAddressSanitizer(Module &M, const GlobalsMetadata *GlobalsMD, - bool CompileKernel = false, bool Recover = false, - bool UseGlobalsGC = true, bool UseOdrIndicator = false, + ModuleAddressSanitizer(Module &M, bool CompileKernel = false, + bool Recover = false, bool UseGlobalsGC = true, + bool UseOdrIndicator = false, AsanDtorKind DestructorKind = AsanDtorKind::Global) - : GlobalsMD(*GlobalsMD), - CompileKernel(ClEnableKasan.getNumOccurrences() > 0 ? ClEnableKasan + : CompileKernel(ClEnableKasan.getNumOccurrences() > 0 ? ClEnableKasan : CompileKernel), Recover(ClRecover.getNumOccurrences() > 0 ? 
ClRecover : Recover), UseGlobalsGC(UseGlobalsGC && ClUseGlobalsGC && !this->CompileKernel), @@ -906,7 +827,6 @@ private: uint64_t getRedzoneSizeForGlobal(uint64_t SizeInBytes) const; int GetAsanVersion(const Module &M) const; - const GlobalsMetadata &GlobalsMD; bool CompileKernel; bool Recover; bool UseGlobalsGC; @@ -931,44 +851,6 @@ private: Function *AsanDtorFunction = nullptr; }; -class ModuleAddressSanitizerLegacyPass : public ModulePass { -public: - static char ID; - - explicit ModuleAddressSanitizerLegacyPass( - bool CompileKernel = false, bool Recover = false, bool UseGlobalGC = true, - bool UseOdrIndicator = false, - AsanDtorKind DestructorKind = AsanDtorKind::Global) - : ModulePass(ID), CompileKernel(CompileKernel), Recover(Recover), - UseGlobalGC(UseGlobalGC), UseOdrIndicator(UseOdrIndicator), - DestructorKind(DestructorKind) { - initializeModuleAddressSanitizerLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - StringRef getPassName() const override { return "ModuleAddressSanitizer"; } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<ASanGlobalsMetadataWrapperPass>(); - } - - bool runOnModule(Module &M) override { - GlobalsMetadata &GlobalsMD = - getAnalysis<ASanGlobalsMetadataWrapperPass>().getGlobalsMD(); - ModuleAddressSanitizer ASanModule(M, &GlobalsMD, CompileKernel, Recover, - UseGlobalGC, UseOdrIndicator, - DestructorKind); - return ASanModule.instrumentModule(M); - } - -private: - bool CompileKernel; - bool Recover; - bool UseGlobalGC; - bool UseOdrIndicator; - AsanDtorKind DestructorKind; -}; - // Stack poisoning does not play well with exception handling. // When an exception is thrown, we essentially bypass the code // that unpoisones the stack. This is why the run-time library has @@ -1221,85 +1103,6 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { } // end anonymous namespace -void LocationMetadata::parse(MDNode *MDN) { - assert(MDN->getNumOperands() == 3); - MDString *DIFilename = cast<MDString>(MDN->getOperand(0)); - Filename = DIFilename->getString(); - LineNo = mdconst::extract<ConstantInt>(MDN->getOperand(1))->getLimitedValue(); - ColumnNo = - mdconst::extract<ConstantInt>(MDN->getOperand(2))->getLimitedValue(); -} - -// FIXME: It would be cleaner to instead attach relevant metadata to the globals -// we want to sanitize instead and reading this metadata on each pass over a -// function instead of reading module level metadata at first. -GlobalsMetadata::GlobalsMetadata(Module &M) { - NamedMDNode *Globals = M.getNamedMetadata("llvm.asan.globals"); - if (!Globals) - return; - for (auto MDN : Globals->operands()) { - // Metadata node contains the global and the fields of "Entry". - assert(MDN->getNumOperands() == 5); - auto *V = mdconst::extract_or_null<Constant>(MDN->getOperand(0)); - // The optimizer may optimize away a global entirely. - if (!V) - continue; - auto *StrippedV = V->stripPointerCasts(); - auto *GV = dyn_cast<GlobalVariable>(StrippedV); - if (!GV) - continue; - // We can already have an entry for GV if it was merged with another - // global. 
- Entry &E = Entries[GV]; - if (auto *Loc = cast_or_null<MDNode>(MDN->getOperand(1))) - E.SourceLoc.parse(Loc); - if (auto *Name = cast_or_null<MDString>(MDN->getOperand(2))) - E.Name = Name->getString(); - ConstantInt *IsDynInit = mdconst::extract<ConstantInt>(MDN->getOperand(3)); - E.IsDynInit |= IsDynInit->isOne(); - ConstantInt *IsExcluded = - mdconst::extract<ConstantInt>(MDN->getOperand(4)); - E.IsExcluded |= IsExcluded->isOne(); - } -} - -AnalysisKey ASanGlobalsMetadataAnalysis::Key; - -GlobalsMetadata ASanGlobalsMetadataAnalysis::run(Module &M, - ModuleAnalysisManager &AM) { - return GlobalsMetadata(M); -} - -PreservedAnalyses AddressSanitizerPass::run(Function &F, - AnalysisManager<Function> &AM) { - auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); - Module &M = *F.getParent(); - if (auto *R = MAMProxy.getCachedResult<ASanGlobalsMetadataAnalysis>(M)) { - const TargetLibraryInfo *TLI = &AM.getResult<TargetLibraryAnalysis>(F); - AddressSanitizer Sanitizer(M, R, nullptr, Options.CompileKernel, - Options.Recover, Options.UseAfterScope, - Options.UseAfterReturn); - if (Sanitizer.instrumentFunction(F, TLI)) - return PreservedAnalyses::none(); - return PreservedAnalyses::all(); - } - - report_fatal_error( - "The ASanGlobalsMetadataAnalysis is required to run before " - "AddressSanitizer can run"); - return PreservedAnalyses::all(); -} - -void AddressSanitizerPass::printPipeline( - raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { - static_cast<PassInfoMixin<AddressSanitizerPass> *>(this)->printPipeline( - OS, MapClassName2PassName); - OS << "<"; - if (Options.CompileKernel) - OS << "kernel"; - OS << ">"; -} - void ModuleAddressSanitizerPass::printPipeline( raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { static_cast<PassInfoMixin<ModuleAddressSanitizerPass> *>(this)->printPipeline( @@ -1318,8 +1121,7 @@ ModuleAddressSanitizerPass::ModuleAddressSanitizerPass( PreservedAnalyses ModuleAddressSanitizerPass::run(Module &M, ModuleAnalysisManager &MAM) { - GlobalsMetadata &GlobalsMD = MAM.getResult<ASanGlobalsMetadataAnalysis>(M); - ModuleAddressSanitizer ModuleSanitizer(M, &GlobalsMD, Options.CompileKernel, + ModuleAddressSanitizer ModuleSanitizer(M, Options.CompileKernel, Options.Recover, UseGlobalGC, UseOdrIndicator, DestructorKind); bool Modified = false; @@ -1327,9 +1129,9 @@ PreservedAnalyses ModuleAddressSanitizerPass::run(Module &M, const StackSafetyGlobalInfo *const SSGI = ClUseStackSafety ? &MAM.getResult<StackSafetyGlobalAnalysis>(M) : nullptr; for (Function &F : M) { - AddressSanitizer FunctionSanitizer( - M, &GlobalsMD, SSGI, Options.CompileKernel, Options.Recover, - Options.UseAfterScope, Options.UseAfterReturn); + AddressSanitizer FunctionSanitizer(M, SSGI, Options.CompileKernel, + Options.Recover, Options.UseAfterScope, + Options.UseAfterReturn); const TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F); Modified |= FunctionSanitizer.instrumentFunction(F, &TLI); } @@ -1337,75 +1139,20 @@ PreservedAnalyses ModuleAddressSanitizerPass::run(Module &M, return Modified ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); } -INITIALIZE_PASS(ASanGlobalsMetadataWrapperPass, "asan-globals-md", - "Read metadata to mark which globals should be instrumented " - "when running ASan.", - false, true) - -char AddressSanitizerLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN( - AddressSanitizerLegacyPass, "asan", - "AddressSanitizer: detects use-after-free and out-of-bounds bugs.", false, - false) -INITIALIZE_PASS_DEPENDENCY(ASanGlobalsMetadataWrapperPass) -INITIALIZE_PASS_DEPENDENCY(StackSafetyGlobalInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END( - AddressSanitizerLegacyPass, "asan", - "AddressSanitizer: detects use-after-free and out-of-bounds bugs.", false, - false) - -FunctionPass *llvm::createAddressSanitizerFunctionPass( - bool CompileKernel, bool Recover, bool UseAfterScope, - AsanDetectStackUseAfterReturnMode UseAfterReturn) { - assert(!CompileKernel || Recover); - return new AddressSanitizerLegacyPass(CompileKernel, Recover, UseAfterScope, - UseAfterReturn); -} - -char ModuleAddressSanitizerLegacyPass::ID = 0; - -INITIALIZE_PASS( - ModuleAddressSanitizerLegacyPass, "asan-module", - "AddressSanitizer: detects use-after-free and out-of-bounds bugs." - "ModulePass", - false, false) - -ModulePass *llvm::createModuleAddressSanitizerLegacyPassPass( - bool CompileKernel, bool Recover, bool UseGlobalsGC, bool UseOdrIndicator, - AsanDtorKind Destructor) { - assert(!CompileKernel || Recover); - return new ModuleAddressSanitizerLegacyPass( - CompileKernel, Recover, UseGlobalsGC, UseOdrIndicator, Destructor); -} - static size_t TypeSizeToSizeIndex(uint32_t TypeSize) { size_t Res = countTrailingZeros(TypeSize / 8); assert(Res < kNumberOfAccessSizes); return Res; } -/// Create a global describing a source location. -static GlobalVariable *createPrivateGlobalForSourceLoc(Module &M, - LocationMetadata MD) { - Constant *LocData[] = { - createPrivateGlobalForString(M, MD.Filename, true, kAsanGenPrefix), - ConstantInt::get(Type::getInt32Ty(M.getContext()), MD.LineNo), - ConstantInt::get(Type::getInt32Ty(M.getContext()), MD.ColumnNo), - }; - auto LocStruct = ConstantStruct::getAnon(LocData); - auto GV = new GlobalVariable(M, LocStruct->getType(), true, - GlobalValue::PrivateLinkage, LocStruct, - kAsanGenPrefix); - GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); - return GV; -} - /// Check if \p G has been created by a trusted compiler pass. static bool GlobalWasGeneratedByCompiler(GlobalVariable *G) { // Do not instrument @llvm.global_ctors, @llvm.used, etc. - if (G->getName().startswith("llvm.")) + if (G->getName().startswith("llvm.") || + // Do not instrument gcov counter arrays. + G->getName().startswith("__llvm_gcov_ctr") || + // Do not instrument rtti proxy symbols for function sanitizer. + G->getName().startswith("__llvm_rtti_proxy")) return true; // Do not instrument asan globals. @@ -1414,10 +1161,6 @@ static bool GlobalWasGeneratedByCompiler(GlobalVariable *G) { G->getName().startswith(kODRGenPrefix)) return true; - // Do not instrument gcov counter arrays. - if (G->getName() == "__llvm_gcov_ctr") - return true; - return false; } @@ -1518,10 +1261,6 @@ bool AddressSanitizer::ignoreAccess(Instruction *Inst, Value *Ptr) { void AddressSanitizer::getInterestingMemoryOperands( Instruction *I, SmallVectorImpl<InterestingMemoryOperand> &Interesting) { - // Skip memory accesses inserted by another instrumentation. 
- if (I->hasMetadata("nosanitize")) - return; - // Do not instrument the load fetching the dynamic shadow address. if (LocalDynamicShadow == I) return; @@ -1613,10 +1352,13 @@ bool AddressSanitizer::GlobalIsLinkerInitialized(GlobalVariable *G) { // If a global variable does not have dynamic initialization we don't // have to instrument it. However, if a global does not have initializer // at all, we assume it has dynamic initializer (in other TU). - // - // FIXME: Metadata should be attched directly to the global directly instead - // of being added to llvm.asan.globals. - return G->hasInitializer() && !GlobalsMD.get(G).IsDynInit; + if (!G->hasInitializer()) + return false; + + if (G->hasSanitizerMetadata() && G->getSanitizerMetadata().IsDynInit) + return false; + + return true; } void AddressSanitizer::instrumentPointerComparisonOrSubtraction( @@ -1977,9 +1719,8 @@ bool ModuleAddressSanitizer::shouldInstrumentGlobal(GlobalVariable *G) const { Type *Ty = G->getValueType(); LLVM_DEBUG(dbgs() << "GLOBAL: " << *G << "\n"); - // FIXME: Metadata should be attched directly to the global directly instead - // of being added to llvm.asan.globals. - if (GlobalsMD.get(G).IsExcluded) return false; + if (G->hasSanitizerMetadata() && G->getSanitizerMetadata().NoAddress) + return false; if (!Ty->isSized()) return false; if (!G->hasInitializer()) return false; // Globals in address space 1 and 4 are supported for AMDGPU. @@ -2125,6 +1866,8 @@ bool ModuleAddressSanitizer::ShouldUseMachOGlobalsSection() const { return true; if (TargetTriple.isWatchOS() && !TargetTriple.isOSVersionLT(2)) return true; + if (TargetTriple.isDriverKit()) + return true; return false; } @@ -2136,7 +1879,9 @@ StringRef ModuleAddressSanitizer::getGlobalMetadataSection() const { case Triple::MachO: return "__DATA,__asan_globals,regular"; case Triple::Wasm: case Triple::GOFF: + case Triple::SPIRV: case Triple::XCOFF: + case Triple::DXContainer: report_fatal_error( "ModuleAddressSanitizer not implemented for object file format"); case Triple::UnknownObjectFormat: @@ -2470,7 +2215,7 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, // const char *name; // const char *module_name; // size_t has_dynamic_init; - // void *source_location; + // size_t padding_for_windows_msvc_incremental_link; // size_t odr_indicator; // We initialize an array of such structures and pass it to a run-time call. StructType *GlobalStructTy = @@ -2489,15 +2234,16 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, for (size_t i = 0; i < n; i++) { GlobalVariable *G = GlobalsToChange[i]; - // FIXME: Metadata should be attched directly to the global directly instead - // of being added to llvm.asan.globals. - auto MD = GlobalsMD.get(G); - StringRef NameForGlobal = G->getName(); - // Create string holding the global name (use global name from metadata - // if it's available, otherwise just write the name of global variable). - GlobalVariable *Name = createPrivateGlobalForString( - M, MD.Name.empty() ? NameForGlobal : MD.Name, - /*AllowMerging*/ true, kAsanGenPrefix); + GlobalValue::SanitizerMetadata MD; + if (G->hasSanitizerMetadata()) + MD = G->getSanitizerMetadata(); + + // TODO: Symbol names in the descriptor can be demangled by the runtime + // library. This could save ~0.4% of VM size for a private large binary. 
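// For reference (not something this change adds): llvm::demangle() is
// best-effort, e.g.
//   llvm::demangle("_ZN3foo3barEv") == "foo::bar()"
//   llvm::demangle("main")          == "main"   // unrecognized names pass through
// so plain C symbol names keep their previous spelling in the descriptor.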
+ std::string NameForGlobal = llvm::demangle(G->getName().str()); + GlobalVariable *Name = + createPrivateGlobalForString(M, NameForGlobal, + /*AllowMerging*/ true, kAsanGenPrefix); Type *Ty = G->getValueType(); const uint64_t SizeInBytes = DL.getTypeAllocSize(Ty); @@ -2545,14 +2291,6 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, G->eraseFromParent(); NewGlobals[i] = NewGlobal; - Constant *SourceLoc; - if (!MD.SourceLoc.empty()) { - auto SourceLocGlobal = createPrivateGlobalForSourceLoc(M, MD.SourceLoc); - SourceLoc = ConstantExpr::getPointerCast(SourceLocGlobal, IntptrTy); - } else { - SourceLoc = ConstantInt::get(IntptrTy, 0); - } - Constant *ODRIndicator = ConstantExpr::getNullValue(IRB.getInt8PtrTy()); GlobalValue *InstrumentedGlobal = NewGlobal; @@ -2593,10 +2331,12 @@ bool ModuleAddressSanitizer::InstrumentGlobals(IRBuilder<> &IRB, Module &M, ConstantInt::get(IntptrTy, SizeInBytes + RightRedzoneSize), ConstantExpr::getPointerCast(Name, IntptrTy), ConstantExpr::getPointerCast(ModuleName, IntptrTy), - ConstantInt::get(IntptrTy, MD.IsDynInit), SourceLoc, + ConstantInt::get(IntptrTy, MD.IsDynInit), + Constant::getNullValue(IntptrTy), ConstantExpr::getPointerCast(ODRIndicator, IntptrTy)); - if (ClInitializers && MD.IsDynInit) HasDynamicallyInitializedGlobals = true; + if (ClInitializers && MD.IsDynInit) + HasDynamicallyInitializedGlobals = true; LLVM_DEBUG(dbgs() << "NEW GLOBAL: " << *NewGlobal << "\n"); @@ -2759,7 +2499,9 @@ void AddressSanitizer::initializeCallbacks(Module &M) { } const std::string MemIntrinCallbackPrefix = - CompileKernel ? std::string("") : ClMemoryAccessCallbackPrefix; + (CompileKernel && !ClKasanMemIntrinCallbackPrefix) + ? std::string("") + : ClMemoryAccessCallbackPrefix; AsanMemmove = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy); @@ -2888,6 +2630,9 @@ bool AddressSanitizer::instrumentFunction(Function &F, // Leave if the function doesn't need instrumentation. if (!F.hasFnAttribute(Attribute::SanitizeAddress)) return FunctionModified; + if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation)) + return FunctionModified; + LLVM_DEBUG(dbgs() << "ASAN instrumenting:\n" << F << "\n"); initializeCallbacks(*F.getParent()); @@ -2908,7 +2653,6 @@ bool AddressSanitizer::instrumentFunction(Function &F, SmallVector<Instruction *, 8> NoReturnCalls; SmallVector<BasicBlock *, 16> AllBlocks; SmallVector<Instruction *, 16> PointerComparisonsOrSubtracts; - int NumAllocas = 0; // Fill the set of memory operations to instrument. for (auto &BB : F) { @@ -2917,6 +2661,9 @@ bool AddressSanitizer::instrumentFunction(Function &F, int NumInsnsPerBB = 0; for (auto &Inst : BB) { if (LooksLikeCodeInBug11395(&Inst)) return false; + // Skip instructions inserted by another instrumentation. + if (Inst.hasMetadata(LLVMContext::MD_nosanitize)) + continue; SmallVector<InterestingMemoryOperand, 1> InterestingOperands; getInterestingMemoryOperands(&Inst, InterestingOperands); @@ -2948,11 +2695,10 @@ bool AddressSanitizer::instrumentFunction(Function &F, IntrinToInstrument.push_back(MI); NumInsnsPerBB++; } else { - if (isa<AllocaInst>(Inst)) NumAllocas++; if (auto *CB = dyn_cast<CallBase>(&Inst)) { // A call inside BB. 
TempsToInstrument.clear(); - if (CB->doesNotReturn() && !CB->hasMetadata("nosanitize")) + if (CB->doesNotReturn()) NoReturnCalls.push_back(CB); } if (CallInst *CI = dyn_cast<CallInst>(&Inst)) @@ -3347,7 +3093,7 @@ void FunctionStackPoisoner::processStaticAllocas() { ASanStackVariableDescription D = {AI->getName().data(), ASan.getAllocaSizeInBytes(*AI), 0, - AI->getAlignment(), + AI->getAlign().value(), AI, 0, 0}; @@ -3611,7 +3357,7 @@ void FunctionStackPoisoner::poisonAlloca(Value *V, uint64_t Size, void FunctionStackPoisoner::handleDynamicAllocaCall(AllocaInst *AI) { IRBuilder<> IRB(AI); - const uint64_t Alignment = std::max(kAllocaRzSize, AI->getAlignment()); + const Align Alignment = std::max(Align(kAllocaRzSize), AI->getAlign()); const uint64_t AllocaRedzoneMask = kAllocaRzSize - 1; Value *Zero = Constant::getNullValue(IntptrTy); @@ -3642,17 +3388,19 @@ void FunctionStackPoisoner::handleDynamicAllocaCall(AllocaInst *AI) { // Alignment is added to locate left redzone, PartialPadding for possible // partial redzone and kAllocaRzSize for right redzone respectively. Value *AdditionalChunkSize = IRB.CreateAdd( - ConstantInt::get(IntptrTy, Alignment + kAllocaRzSize), PartialPadding); + ConstantInt::get(IntptrTy, Alignment.value() + kAllocaRzSize), + PartialPadding); Value *NewSize = IRB.CreateAdd(OldSize, AdditionalChunkSize); // Insert new alloca with new NewSize and Alignment params. AllocaInst *NewAlloca = IRB.CreateAlloca(IRB.getInt8Ty(), NewSize); - NewAlloca->setAlignment(Align(Alignment)); + NewAlloca->setAlignment(Alignment); // NewAddress = Address + Alignment - Value *NewAddress = IRB.CreateAdd(IRB.CreatePtrToInt(NewAlloca, IntptrTy), - ConstantInt::get(IntptrTy, Alignment)); + Value *NewAddress = + IRB.CreateAdd(IRB.CreatePtrToInt(NewAlloca, IntptrTy), + ConstantInt::get(IntptrTy, Alignment.value())); // Insert __asan_alloca_poison call for new created alloca. 
IRB.CreateCall(AsanAllocaPoisonFunc, {NewAddress, OldSize}); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp index 4ad07cab001a..1eadafb4e4b4 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -19,7 +19,6 @@ #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" -#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" @@ -29,7 +28,6 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include <cstdint> #include <utility> @@ -142,6 +140,9 @@ static void insertBoundsCheck(Value *Or, BuilderTy &IRB, GetTrapBBT GetTrapBB) { static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI, ScalarEvolution &SE) { + if (F.hasFnAttribute(Attribute::NoSanitizeBounds)) + return false; + const DataLayout &DL = F.getParent()->getDataLayout(); ObjectSizeOpts EvalOpts; EvalOpts.RoundToAlign = true; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/CGProfile.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/CGProfile.cpp index 1a7f7a365ce4..b11b84d65d23 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/CGProfile.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/CGProfile.cpp @@ -13,15 +13,12 @@ #include "llvm/Analysis/LazyBlockFrequencyInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Constants.h" -#include "llvm/IR/Instructions.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PassManager.h" #include "llvm/InitializePasses.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/Transforms/Instrumentation.h" -#include <array> - using namespace llvm; static bool diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp index 497aac30c3f6..e5c0705b916e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp @@ -26,6 +26,7 @@ #include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PassManager.h" #include "llvm/InitializePasses.h" @@ -145,27 +146,27 @@ FunctionPass *llvm::createControlHeightReductionLegacyPass() { namespace { struct CHRStats { - CHRStats() : NumBranches(0), NumBranchesDelta(0), - WeightedNumBranchesDelta(0) {} + CHRStats() = default; void print(raw_ostream &OS) const { OS << "CHRStats: NumBranches " << NumBranches << " NumBranchesDelta " << NumBranchesDelta << " WeightedNumBranchesDelta " << WeightedNumBranchesDelta; } - uint64_t NumBranches; // The original number of conditional branches / - // selects - uint64_t NumBranchesDelta; // The decrease of the number of conditional - // branches / selects in the hot paths due to CHR. - uint64_t WeightedNumBranchesDelta; // NumBranchesDelta weighted by the profile - // count at the scope entry. 
+ // The original number of conditional branches / selects + uint64_t NumBranches = 0; + // The decrease of the number of conditional branches / selects in the hot + // paths due to CHR. + uint64_t NumBranchesDelta = 0; + // NumBranchesDelta weighted by the profile count at the scope entry. + uint64_t WeightedNumBranchesDelta = 0; }; // RegInfo - some properties of a Region. struct RegInfo { - RegInfo() : R(nullptr), HasBranch(false) {} - RegInfo(Region *RegionIn) : R(RegionIn), HasBranch(false) {} - Region *R; - bool HasBranch; + RegInfo() = default; + RegInfo(Region *RegionIn) : R(RegionIn) {} + Region *R = nullptr; + bool HasBranch = false; SmallVector<SelectInst *, 8> Selects; }; @@ -769,9 +770,21 @@ CHRScope * CHR::findScope(Region *R) { return nullptr; // If any of the basic blocks have address taken, we must skip this region // because we cannot clone basic blocks that have address taken. - for (BasicBlock *BB : R->blocks()) + for (BasicBlock *BB : R->blocks()) { if (BB->hasAddressTaken()) return nullptr; + // If we encounter llvm.coro.id, skip this region because if the basic block + // is cloned, we end up inserting a token type PHI node to the block with + // llvm.coro.begin. + // FIXME: This could lead to less optimal codegen, because the region is + // excluded, it can prevent CHR from merging adjacent regions into bigger + // scope and hoisting more branches. + for (Instruction &I : *BB) + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + if (II->getIntrinsicID() == Intrinsic::coro_id) + return nullptr; + } + if (Exit) { // Try to find an if-then block (check if R is an if-then). // if (cond) { @@ -1752,7 +1765,7 @@ void CHR::transformScopes(CHRScope *Scope, DenseSet<PHINode *> &TrivialPHIs) { // Create the combined branch condition and constant-fold the branches/selects // in the hot path. fixupBranchesAndSelects(Scope, PreEntryBlock, MergedBr, - ProfileCount.getValueOr(0)); + ProfileCount.value_or(0)); } // A helper for transformScopes. Clone the blocks in the scope (excluding the @@ -1949,28 +1962,27 @@ void CHR::fixupSelect(SelectInst *SI, CHRScope *Scope, // A helper for fixupBranch/fixupSelect. Add a branch condition to the merged // condition. void CHR::addToMergedCondition(bool IsTrueBiased, Value *Cond, - Instruction *BranchOrSelect, - CHRScope *Scope, - IRBuilder<> &IRB, - Value *&MergedCondition) { - if (IsTrueBiased) { - MergedCondition = IRB.CreateAnd(MergedCondition, Cond); - } else { + Instruction *BranchOrSelect, CHRScope *Scope, + IRBuilder<> &IRB, Value *&MergedCondition) { + if (!IsTrueBiased) { // If Cond is an icmp and all users of V except for BranchOrSelect is a // branch, negate the icmp predicate and swap the branch targets and avoid // inserting an Xor to negate Cond. - bool Done = false; - if (auto *ICmp = dyn_cast<ICmpInst>(Cond)) - if (negateICmpIfUsedByBranchOrSelectOnly(ICmp, BranchOrSelect, Scope)) { - MergedCondition = IRB.CreateAnd(MergedCondition, Cond); - Done = true; - } - if (!Done) { - Value *Negate = IRB.CreateXor( - ConstantInt::getTrue(F.getContext()), Cond); - MergedCondition = IRB.CreateAnd(MergedCondition, Negate); - } + auto *ICmp = dyn_cast<ICmpInst>(Cond); + if (!ICmp || + !negateICmpIfUsedByBranchOrSelectOnly(ICmp, BranchOrSelect, Scope)) + Cond = IRB.CreateXor(ConstantInt::getTrue(F.getContext()), Cond); } + + // Select conditions can be poison, while branching on poison is immediate + // undefined behavior. As such, we need to freeze potentially poisonous + // conditions derived from selects. 
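// Roughly the IR this now produces for a select-derived condition (a sketch,
// not taken from the patch):
//   %c.fr   = freeze i1 %c
//   %merged = select i1 %merged0, i1 %c.fr, i1 false   ; CreateLogicalAnd
// The select form short-circuits, so poison in a later condition cannot leak
// into paths where an earlier condition is already false.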
+ if (isa<SelectInst>(BranchOrSelect) && + !isGuaranteedNotToBeUndefOrPoison(Cond)) + Cond = IRB.CreateFreeze(Cond); + + // Use logical and to avoid propagating poison from later conditions. + MergedCondition = IRB.CreateLogicalAnd(MergedCondition, Cond); } void CHR::transformScopes(SmallVectorImpl<CHRScope *> &CHRScopes) { @@ -2080,7 +2092,7 @@ bool ControlHeightReductionLegacyPass::runOnFunction(Function &F) { RegionInfo &RI = getAnalysis<RegionInfoPass>().getRegionInfo(); std::unique_ptr<OptimizationRemarkEmitter> OwnedORE = std::make_unique<OptimizationRemarkEmitter>(&F); - return CHR(F, BFI, DT, PSI, RI, *OwnedORE.get()).run(); + return CHR(F, BFI, DT, PSI, RI, *OwnedORE).run(); } namespace llvm { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp index ff3aa14a2a83..6815688827d2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp @@ -66,8 +66,8 @@ #include "llvm/ADT/None.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" #include "llvm/ADT/Triple.h" #include "llvm/ADT/iterator.h" #include "llvm/Analysis/ValueTracking.h" @@ -84,13 +84,11 @@ #include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InlineAsm.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" @@ -112,7 +110,6 @@ #include <cassert> #include <cstddef> #include <cstdint> -#include <iterator> #include <memory> #include <set> #include <string> @@ -187,6 +184,15 @@ static cl::opt<bool> ClCombineOffsetLabelsOnGEP( "doing pointer arithmetic."), cl::Hidden, cl::init(true)); +static cl::list<std::string> ClCombineTaintLookupTables( + "dfsan-combine-taint-lookup-table", + cl::desc( + "When dfsan-combine-offset-labels-on-gep and/or " + "dfsan-combine-pointer-labels-on-load are false, this flag can " + "be used to re-enable combining offset and/or pointer taint when " + "loading specific constant global variables (i.e. 
lookup tables)."), + cl::Hidden); + static cl::opt<bool> ClDebugNonzeroLabels( "dfsan-debug-nonzero-labels", cl::desc("Insert calls to __dfsan_nonzero_label on observing a parameter, " @@ -433,6 +439,7 @@ class DataFlowSanitizer { FunctionType *DFSanUnionLoadFnTy; FunctionType *DFSanLoadLabelAndOriginFnTy; FunctionType *DFSanUnimplementedFnTy; + FunctionType *DFSanWrapperExternWeakNullFnTy; FunctionType *DFSanSetLabelFnTy; FunctionType *DFSanNonzeroLabelFnTy; FunctionType *DFSanVarargWrapperFnTy; @@ -448,6 +455,7 @@ class DataFlowSanitizer { FunctionCallee DFSanUnionLoadFn; FunctionCallee DFSanLoadLabelAndOriginFn; FunctionCallee DFSanUnimplementedFn; + FunctionCallee DFSanWrapperExternWeakNullFn; FunctionCallee DFSanSetLabelFn; FunctionCallee DFSanNonzeroLabelFn; FunctionCallee DFSanVarargWrapperFn; @@ -467,6 +475,7 @@ class DataFlowSanitizer { DFSanABIList ABIList; DenseMap<Value *, Function *> UnwrappedFnMap; AttributeMask ReadOnlyNoneAttrs; + StringSet<> CombineTaintLookupTableNames; /// Memory map parameters used in calculation mapping application addresses /// to shadow addresses and origin addresses. @@ -480,14 +489,13 @@ class DataFlowSanitizer { bool isInstrumented(const Function *F); bool isInstrumented(const GlobalAlias *GA); bool isForceZeroLabels(const Function *F); - FunctionType *getTrampolineFunctionType(FunctionType *T); TransformedFunction getCustomFunctionType(FunctionType *T); WrapperKind getWrapperKind(Function *F); void addGlobalNameSuffix(GlobalValue *GV); + void buildExternWeakCheckIfNeeded(IRBuilder<> &IRB, Function *F); Function *buildWrapperFunction(Function *F, StringRef NewFName, GlobalValue::LinkageTypes NewFLink, FunctionType *NewFT); - Constant *getOrBuildTrampolineFunction(FunctionType *FT, StringRef FName); void initializeCallbackFunctions(Module &M); void initializeRuntimeFunctions(Module &M); void injectMetadataGlobals(Module &M); @@ -658,6 +666,8 @@ struct DFSanFunction { // branch instruction using the given conditional expression. void addConditionalCallbacksIfEnabled(Instruction &I, Value *Condition); + bool isLookupTableConstant(Value *P); + private: /// Collapses the shadow with aggregate type into a single primitive shadow /// value. @@ -792,25 +802,9 @@ DataFlowSanitizer::DataFlowSanitizer( // FIXME: should we propagate vfs::FileSystem to this constructor? 
ABIList.set( SpecialCaseList::createOrDie(AllABIListFiles, *vfs::getRealFileSystem())); -} -FunctionType *DataFlowSanitizer::getTrampolineFunctionType(FunctionType *T) { - assert(!T->isVarArg()); - SmallVector<Type *, 4> ArgTypes; - ArgTypes.push_back(T->getPointerTo()); - ArgTypes.append(T->param_begin(), T->param_end()); - ArgTypes.append(T->getNumParams(), PrimitiveShadowTy); - Type *RetType = T->getReturnType(); - if (!RetType->isVoidTy()) - ArgTypes.push_back(PrimitiveShadowPtrTy); - - if (shouldTrackOrigins()) { - ArgTypes.append(T->getNumParams(), OriginTy); - if (!RetType->isVoidTy()) - ArgTypes.push_back(OriginPtrTy); - } - - return FunctionType::get(T->getReturnType(), ArgTypes, false); + for (StringRef v : ClCombineTaintLookupTables) + CombineTaintLookupTableNames.insert(v); } TransformedFunction DataFlowSanitizer::getCustomFunctionType(FunctionType *T) { @@ -823,16 +817,8 @@ TransformedFunction DataFlowSanitizer::getCustomFunctionType(FunctionType *T) { std::vector<unsigned> ArgumentIndexMapping; for (unsigned I = 0, E = T->getNumParams(); I != E; ++I) { Type *ParamType = T->getParamType(I); - FunctionType *FT; - if (isa<PointerType>(ParamType) && - (FT = dyn_cast<FunctionType>(ParamType->getPointerElementType()))) { - ArgumentIndexMapping.push_back(ArgTypes.size()); - ArgTypes.push_back(getTrampolineFunctionType(FT)->getPointerTo()); - ArgTypes.push_back(Type::getInt8PtrTy(*Ctx)); - } else { - ArgumentIndexMapping.push_back(ArgTypes.size()); - ArgTypes.push_back(ParamType); - } + ArgumentIndexMapping.push_back(ArgTypes.size()); + ArgTypes.push_back(ParamType); } for (unsigned I = 0, E = T->getNumParams(); I != E; ++I) ArgTypes.push_back(PrimitiveShadowTy); @@ -1058,6 +1044,10 @@ bool DataFlowSanitizer::initializeModule(Module &M) { /*isVarArg=*/false); DFSanUnimplementedFnTy = FunctionType::get( Type::getVoidTy(*Ctx), Type::getInt8PtrTy(*Ctx), /*isVarArg=*/false); + Type *DFSanWrapperExternWeakNullArgs[2] = {Int8Ptr, Int8Ptr}; + DFSanWrapperExternWeakNullFnTy = + FunctionType::get(Type::getVoidTy(*Ctx), DFSanWrapperExternWeakNullArgs, + /*isVarArg=*/false); Type *DFSanSetLabelArgs[4] = {PrimitiveShadowTy, OriginTy, Type::getInt8PtrTy(*Ctx), IntptrTy}; DFSanSetLabelFnTy = FunctionType::get(Type::getVoidTy(*Ctx), @@ -1149,6 +1139,23 @@ void DataFlowSanitizer::addGlobalNameSuffix(GlobalValue *GV) { } } +void DataFlowSanitizer::buildExternWeakCheckIfNeeded(IRBuilder<> &IRB, + Function *F) { + // If the function we are wrapping was ExternWeak, it may be null. + // The original code before calling this wrapper may have checked for null, + // but replacing with a known-to-not-be-null wrapper can break this check. + // When replacing uses of the extern weak function with the wrapper we try + // to avoid replacing uses in conditionals, but this is not perfect. + // In the case where we fail, and accidentially optimize out a null check + // for a extern weak function, add a check here to help identify the issue. 
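// The user-level pattern being guarded against (a sketch, not from the
// patch) is roughly:
//   extern void maybe_absent(void) __attribute__((weak));
//   ...
//   if (maybe_absent)   // after wrapping, this can end up testing the
//     maybe_absent();   // never-null wrapper, yet the weak symbol itself
//                       // may still be null at run time
// The call emitted below passes the real (possibly null) symbol and its name
// to the runtime so the issue can be identified at run time.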
+ if (GlobalValue::isExternalWeakLinkage(F->getLinkage())) { + std::vector<Value *> Args; + Args.push_back(IRB.CreatePointerCast(F, IRB.getInt8PtrTy())); + Args.push_back(IRB.CreateGlobalStringPtr(F->getName())); + IRB.CreateCall(DFSanWrapperExternWeakNullFn, Args); + } +} + Function * DataFlowSanitizer::buildWrapperFunction(Function *F, StringRef NewFName, GlobalValue::LinkageTypes NewFLink, @@ -1181,61 +1188,6 @@ DataFlowSanitizer::buildWrapperFunction(Function *F, StringRef NewFName, return NewF; } -Constant *DataFlowSanitizer::getOrBuildTrampolineFunction(FunctionType *FT, - StringRef FName) { - FunctionType *FTT = getTrampolineFunctionType(FT); - FunctionCallee C = Mod->getOrInsertFunction(FName, FTT); - Function *F = dyn_cast<Function>(C.getCallee()); - if (F && F->isDeclaration()) { - F->setLinkage(GlobalValue::LinkOnceODRLinkage); - BasicBlock *BB = BasicBlock::Create(*Ctx, "entry", F); - std::vector<Value *> Args; - Function::arg_iterator AI = F->arg_begin() + 1; - for (unsigned N = FT->getNumParams(); N != 0; ++AI, --N) - Args.push_back(&*AI); - CallInst *CI = CallInst::Create(FT, &*F->arg_begin(), Args, "", BB); - Type *RetType = FT->getReturnType(); - ReturnInst *RI = RetType->isVoidTy() ? ReturnInst::Create(*Ctx, BB) - : ReturnInst::Create(*Ctx, CI, BB); - - // F is called by a wrapped custom function with primitive shadows. So - // its arguments and return value need conversion. - DFSanFunction DFSF(*this, F, /*IsNativeABI=*/true, - /*IsForceZeroLabels=*/false); - Function::arg_iterator ValAI = F->arg_begin(), ShadowAI = AI; - ++ValAI; - for (unsigned N = FT->getNumParams(); N != 0; ++ValAI, ++ShadowAI, --N) { - Value *Shadow = - DFSF.expandFromPrimitiveShadow(ValAI->getType(), &*ShadowAI, CI); - DFSF.ValShadowMap[&*ValAI] = Shadow; - } - Function::arg_iterator RetShadowAI = ShadowAI; - const bool ShouldTrackOrigins = shouldTrackOrigins(); - if (ShouldTrackOrigins) { - ValAI = F->arg_begin(); - ++ValAI; - Function::arg_iterator OriginAI = ShadowAI; - if (!RetType->isVoidTy()) - ++OriginAI; - for (unsigned N = FT->getNumParams(); N != 0; ++ValAI, ++OriginAI, --N) { - DFSF.ValOriginMap[&*ValAI] = &*OriginAI; - } - } - DFSanVisitor(DFSF).visitCallInst(*CI); - if (!RetType->isVoidTy()) { - Value *PrimitiveShadow = DFSF.collapseToPrimitiveShadow( - DFSF.getShadow(RI->getReturnValue()), RI); - new StoreInst(PrimitiveShadow, &*RetShadowAI, RI); - if (ShouldTrackOrigins) { - Value *Origin = DFSF.getOrigin(RI->getReturnValue()); - new StoreInst(Origin, &*std::prev(F->arg_end()), RI); - } - } - } - - return cast<Constant>(C.getCallee()); -} - // Initialize DataFlowSanitizer runtime functions and declare them in the module void DataFlowSanitizer::initializeRuntimeFunctions(Module &M) { { @@ -1256,6 +1208,8 @@ void DataFlowSanitizer::initializeRuntimeFunctions(Module &M) { } DFSanUnimplementedFn = Mod->getOrInsertFunction("__dfsan_unimplemented", DFSanUnimplementedFnTy); + DFSanWrapperExternWeakNullFn = Mod->getOrInsertFunction( + "__dfsan_wrapper_extern_weak_null", DFSanWrapperExternWeakNullFnTy); { AttributeList AL; AL = AL.addParamAttribute(M.getContext(), 0, Attribute::ZExt); @@ -1300,6 +1254,8 @@ void DataFlowSanitizer::initializeRuntimeFunctions(Module &M) { DFSanRuntimeFunctions.insert( DFSanUnimplementedFn.getCallee()->stripPointerCasts()); DFSanRuntimeFunctions.insert( + DFSanWrapperExternWeakNullFn.getCallee()->stripPointerCasts()); + DFSanRuntimeFunctions.insert( DFSanSetLabelFn.getCallee()->stripPointerCasts()); DFSanRuntimeFunctions.insert( 
DFSanNonzeroLabelFn.getCallee()->stripPointerCasts()); @@ -1500,7 +1456,40 @@ bool DataFlowSanitizer::runImpl(Module &M) { Value *WrappedFnCst = ConstantExpr::getBitCast(NewF, PointerType::getUnqual(FT)); - F.replaceAllUsesWith(WrappedFnCst); + + // Extern weak functions can sometimes be null at execution time. + // Code will sometimes check if an extern weak function is null. + // This could look something like: + // declare extern_weak i8 @my_func(i8) + // br i1 icmp ne (i8 (i8)* @my_func, i8 (i8)* null), label %use_my_func, + // label %avoid_my_func + // The @"dfsw$my_func" wrapper is never null, so if we replace this use + // in the comparison, the icmp will simplify to false and we have + // accidentally optimized away a null check that is necessary. + // This can lead to a crash when the null extern_weak my_func is called. + // + // To prevent (the most common pattern of) this problem, + // do not replace uses in comparisons with the wrapper. + // We definitely want to replace uses in call instructions. + // Other uses (e.g. store the function address somewhere) might be + // called or compared or both - this case may not be handled correctly. + // We will default to replacing with wrapper in cases we are unsure. + auto IsNotCmpUse = [](Use &U) -> bool { + User *Usr = U.getUser(); + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Usr)) { + // This is the most common case for icmp ne null + if (CE->getOpcode() == Instruction::ICmp) { + return false; + } + } + if (Instruction *I = dyn_cast<Instruction>(Usr)) { + if (I->getOpcode() == Instruction::ICmp) { + return false; + } + } + return true; + }; + F.replaceUsesWithIf(WrappedFnCst, IsNotCmpUse); UnwrappedFnMap[WrappedFnCst] = &F; *FI = NewF; @@ -1919,6 +1908,14 @@ Align DFSanFunction::getOriginAlign(Align InstAlignment) { return Align(std::max(MinOriginAlignment, Alignment)); } +bool DFSanFunction::isLookupTableConstant(Value *P) { + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(P->stripPointerCasts())) + if (GV->isConstant() && GV->hasName()) + return DFS.CombineTaintLookupTableNames.count(GV->getName()); + + return false; +} + bool DFSanFunction::useCallbackLoadLabelAndOrigin(uint64_t Size, Align InstAlignment) { // When enabling tracking load instructions, we always use @@ -2172,6 +2169,29 @@ static AtomicOrdering addAcquireOrdering(AtomicOrdering AO) { llvm_unreachable("Unknown ordering"); } +Value *StripPointerGEPsAndCasts(Value *V) { + if (!V->getType()->isPointerTy()) + return V; + + // DFSan pass should be running on valid IR, but we'll + // keep a seen set to ensure there are no issues.
+ SmallPtrSet<const Value *, 4> Visited; + Visited.insert(V); + do { + if (auto *GEP = dyn_cast<GEPOperator>(V)) { + V = GEP->getPointerOperand(); + } else if (Operator::getOpcode(V) == Instruction::BitCast) { + V = cast<Operator>(V)->getOperand(0); + if (!V->getType()->isPointerTy()) + return V; + } else if (isa<GlobalAlias>(V)) { + V = cast<GlobalAlias>(V)->getAliasee(); + } + } while (Visited.insert(V).second); + + return V; +} + void DFSanVisitor::visitLoadInst(LoadInst &LI) { auto &DL = LI.getModule()->getDataLayout(); uint64_t Size = DL.getTypeStoreSize(LI.getType()); @@ -2200,7 +2220,9 @@ void DFSanVisitor::visitLoadInst(LoadInst &LI) { Shadows.push_back(PrimitiveShadow); Origins.push_back(Origin); } - if (ClCombinePointerLabelsOnLoad) { + if (ClCombinePointerLabelsOnLoad || + DFSF.isLookupTableConstant( + StripPointerGEPsAndCasts(LI.getPointerOperand()))) { Value *PtrShadow = DFSF.getShadow(LI.getPointerOperand()); PrimitiveShadow = DFSF.combineShadows(PrimitiveShadow, PtrShadow, Pos); if (ShouldTrackOrigins) { @@ -2562,7 +2584,9 @@ void DFSanVisitor::visitLandingPadInst(LandingPadInst &LPI) { } void DFSanVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { - if (ClCombineOffsetLabelsOnGEP) { + if (ClCombineOffsetLabelsOnGEP || + DFSF.isLookupTableConstant( + StripPointerGEPsAndCasts(GEPI.getPointerOperand()))) { visitInstOperands(GEPI); return; } @@ -2722,13 +2746,8 @@ void DFSanVisitor::visitMemTransferInst(MemTransferInst &I) { auto *MTI = cast<MemTransferInst>( IRB.CreateCall(I.getFunctionType(), I.getCalledOperand(), {DestShadow, SrcShadow, LenShadow, I.getVolatileCst()})); - if (ClPreserveAlignment) { - MTI->setDestAlignment(I.getDestAlign() * DFSF.DFS.ShadowWidthBytes); - MTI->setSourceAlignment(I.getSourceAlign() * DFSF.DFS.ShadowWidthBytes); - } else { - MTI->setDestAlignment(Align(DFSF.DFS.ShadowWidthBytes)); - MTI->setSourceAlignment(Align(DFSF.DFS.ShadowWidthBytes)); - } + MTI->setDestAlignment(DFSF.getShadowAlign(I.getDestAlign().valueOrOne())); + MTI->setSourceAlignment(DFSF.getShadowAlign(I.getSourceAlign().valueOrOne())); if (ClEventCallbacks) { IRB.CreateCall(DFSF.DFS.DFSanMemTransferCallbackFn, {RawDestShadow, @@ -2864,16 +2883,19 @@ bool DFSanVisitor::visitWrappedCallBase(Function &F, CallBase &CB) { CB.setCalledFunction(&F); IRB.CreateCall(DFSF.DFS.DFSanUnimplementedFn, IRB.CreateGlobalStringPtr(F.getName())); + DFSF.DFS.buildExternWeakCheckIfNeeded(IRB, &F); DFSF.setShadow(&CB, DFSF.DFS.getZeroShadow(&CB)); DFSF.setOrigin(&CB, DFSF.DFS.ZeroOrigin); return true; case DataFlowSanitizer::WK_Discard: CB.setCalledFunction(&F); + DFSF.DFS.buildExternWeakCheckIfNeeded(IRB, &F); DFSF.setShadow(&CB, DFSF.DFS.getZeroShadow(&CB)); DFSF.setOrigin(&CB, DFSF.DFS.ZeroOrigin); return true; case DataFlowSanitizer::WK_Functional: CB.setCalledFunction(&F); + DFSF.DFS.buildExternWeakCheckIfNeeded(IRB, &F); visitInstOperands(CB); return true; case DataFlowSanitizer::WK_Custom: @@ -2905,22 +2927,7 @@ bool DFSanVisitor::visitWrappedCallBase(Function &F, CallBase &CB) { // Adds non-variable arguments. 
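The lookup-table handling above (isLookupTableConstant together with the extra label operands in visitLoadInst and visitGetElementPtrInst) targets code that indexes a constant table with tainted data. A minimal, hypothetical C++ sketch of such a table (kToUpper stands in for a global that would have to be named via the new lookup-table option):

#include <cstdint>

// Constant data: the table bytes themselves carry no taint.
static const uint8_t kToUpper[256] = {};

uint8_t to_upper(uint8_t c) {
  // Without combining labels, the loaded byte would be unlabeled even when the
  // index 'c' is tainted; with the table named on the option, the GEP offset
  // and pointer labels are folded into the label of the loaded value.
  return kToUpper[c];
}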
auto *I = CB.arg_begin(); for (unsigned N = FT->getNumParams(); N != 0; ++I, --N) { - Type *T = (*I)->getType(); - FunctionType *ParamFT; - if (isa<PointerType>(T) && - (ParamFT = dyn_cast<FunctionType>(T->getPointerElementType()))) { - std::string TName = "dfst"; - TName += utostr(FT->getNumParams() - N); - TName += "$"; - TName += F.getName(); - Constant *Trampoline = - DFSF.DFS.getOrBuildTrampolineFunction(ParamFT, TName); - Args.push_back(Trampoline); - Args.push_back( - IRB.CreateBitCast(*I, Type::getInt8PtrTy(*DFSF.DFS.Ctx))); - } else { - Args.push_back(*I); - } + Args.push_back(*I); } // Adds shadow arguments. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp index 325089fc4402..ac4a1fd6bb7e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -14,19 +14,15 @@ //===----------------------------------------------------------------------===// #include "CFGMST.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/IR/CFG.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/IRBuilder.h" @@ -34,8 +30,6 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/CRC.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -87,7 +81,7 @@ GCOVOptions GCOVOptions::getDefault() { if (DefaultGCOVVersion.size() != 4) { llvm::report_fatal_error(Twine("Invalid -default-gcov-version: ") + - DefaultGCOVVersion); + DefaultGCOVVersion, /*GenCrashDiag=*/false); } memcpy(Options.Version, DefaultGCOVVersion.c_str(), 4); return Options; @@ -169,39 +163,6 @@ private: StringMap<bool> InstrumentedFiles; }; -class GCOVProfilerLegacyPass : public ModulePass { -public: - static char ID; - GCOVProfilerLegacyPass() - : GCOVProfilerLegacyPass(GCOVOptions::getDefault()) {} - GCOVProfilerLegacyPass(const GCOVOptions &Opts) - : ModulePass(ID), Profiler(Opts) { - initializeGCOVProfilerLegacyPassPass(*PassRegistry::getPassRegistry()); - } - StringRef getPassName() const override { return "GCOV Profiler"; } - - bool runOnModule(Module &M) override { - auto GetBFI = [this](Function &F) { - return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); - }; - auto GetBPI = [this](Function &F) { - return &this->getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI(); - }; - auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { - return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - }; - return Profiler.runOnModule(M, GetBFI, GetBPI, GetTLI); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<BlockFrequencyInfoWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } - -private: - GCOVProfiler Profiler; -}; - struct BBInfo { BBInfo *Group; uint32_t Index; @@ -237,21 +198,6 @@ struct Edge { }; } -char GCOVProfilerLegacyPass::ID = 0; 
-INITIALIZE_PASS_BEGIN( - GCOVProfilerLegacyPass, "insert-gcov-profiling", - "Insert instrumentation for GCOV profiling", false, false) -INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END( - GCOVProfilerLegacyPass, "insert-gcov-profiling", - "Insert instrumentation for GCOV profiling", false, false) - -ModulePass *llvm::createGCOVProfilerPass(const GCOVOptions &Options) { - return new GCOVProfilerLegacyPass(Options); -} - static StringRef getFunctionName(const DISubprogram *SP) { if (!SP->getLinkageName().empty()) return SP->getLinkageName(); @@ -862,7 +808,8 @@ bool GCOVProfiler::emitProfileNotes( // Split indirectbr critical edges here before computing the MST rather // than later in getInstrBB() to avoid invalidating it. - SplitIndirectBrCriticalEdges(F, BPI, BFI); + SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI, + BFI); CFGMST<Edge, BBInfo> MST(F, /*InstrumentFuncEntry_=*/false, BPI, BFI); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp index 7b3741d19a1b..218b4bbfb6c0 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -13,14 +13,15 @@ #include "llvm/Transforms/Instrumentation/HWAddressSanitizer.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" -#include "llvm/Analysis/CFG.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/StackSafetyAnalysis.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/BinaryFormat/Dwarf.h" #include "llvm/BinaryFormat/ELF.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -33,7 +34,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" -#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" @@ -43,19 +44,15 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/PassRegistry.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Instrumentation/AddressSanitizerCommon.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/MemoryTaggingSupport.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" -#include <sstream> using namespace llvm; @@ -83,6 +80,11 @@ static cl::opt<std::string> cl::desc("Prefix for memory access callbacks"), cl::Hidden, cl::init("__hwasan_")); +static cl::opt<bool> ClKasanMemIntrinCallbackPrefix( + "hwasan-kernel-mem-intrinsic-prefix", + cl::desc("Use prefix for memory intrinsics in KASAN mode"), cl::Hidden, + cl::init(false)); + static cl::opt<bool> ClInstrumentWithCalls( "hwasan-instrument-with-calls", cl::desc("instrument reads and writes with callbacks"), cl::Hidden, @@ -145,7 +147,7 @@ static cl::opt<bool> ClGenerateTagsWithCalls( 
cl::init(false)); static cl::opt<bool> ClGlobals("hwasan-globals", cl::desc("Instrument globals"), - cl::Hidden, cl::init(false), cl::ZeroOrMore); + cl::Hidden, cl::init(false)); static cl::opt<int> ClMatchAllTag( "hwasan-match-all-tag", @@ -191,17 +193,16 @@ static cl::opt<bool> static cl::opt<bool> ClInstrumentLandingPads("hwasan-instrument-landing-pads", cl::desc("instrument landing pads"), cl::Hidden, - cl::init(false), cl::ZeroOrMore); + cl::init(false)); static cl::opt<bool> ClUseShortGranules( "hwasan-use-short-granules", cl::desc("use short granules in allocas and outlined checks"), cl::Hidden, - cl::init(false), cl::ZeroOrMore); + cl::init(false)); static cl::opt<bool> ClInstrumentPersonalityFunctions( "hwasan-instrument-personality-functions", - cl::desc("instrument personality functions"), cl::Hidden, cl::init(false), - cl::ZeroOrMore); + cl::desc("instrument personality functions"), cl::Hidden); static cl::opt<bool> ClInlineAllChecks("hwasan-inline-all-checks", cl::desc("inline all checks"), @@ -244,13 +245,6 @@ bool shouldDetectUseAfterScope(const Triple &TargetTriple) { /// An instrumentation pass implementing detection of addressability bugs /// using tagged pointers. class HWAddressSanitizer { -private: - struct AllocaInfo { - AllocaInst *AI; - SmallVector<IntrinsicInst *, 2> LifetimeStart; - SmallVector<IntrinsicInst *, 2> LifetimeEnd; - }; - public: HWAddressSanitizer(Module &M, bool CompileKernel, bool Recover, const StackSafetyGlobalInfo *SSI) @@ -265,11 +259,7 @@ public: void setSSI(const StackSafetyGlobalInfo *S) { SSI = S; } - DenseMap<AllocaInst *, AllocaInst *> padInterestingAllocas( - const MapVector<AllocaInst *, AllocaInfo> &AllocasToInstrument); - bool sanitizeFunction(Function &F, - llvm::function_ref<const DominatorTree &()> GetDT, - llvm::function_ref<const PostDominatorTree &()> GetPDT); + bool sanitizeFunction(Function &F, FunctionAnalysisManager &FAM); void initializeModule(); void createHwasanCtorComdat(); @@ -301,16 +291,9 @@ public: void tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, size_t Size); Value *tagPointer(IRBuilder<> &IRB, Type *Ty, Value *PtrLong, Value *Tag); Value *untagPointer(IRBuilder<> &IRB, Value *PtrLong); - static bool isStandardLifetime(const AllocaInfo &AllocaInfo, - const DominatorTree &DT); - bool instrumentStack( - bool ShouldDetectUseAfterScope, - MapVector<AllocaInst *, AllocaInfo> &AllocasToInstrument, - SmallVector<Instruction *, 4> &UnrecognizedLifetimes, - DenseMap<AllocaInst *, std::vector<DbgVariableIntrinsic *>> &AllocaDbgMap, - SmallVectorImpl<Instruction *> &RetVec, Value *StackTag, - llvm::function_ref<const DominatorTree &()> GetDT, - llvm::function_ref<const PostDominatorTree &()> GetPDT); + bool instrumentStack(memtag::StackInfo &Info, Value *StackTag, + const DominatorTree &DT, const PostDominatorTree &PDT, + const LoopInfo &LI); Value *readRegister(IRBuilder<> &IRB, StringRef Name); bool instrumentLandingPads(SmallVectorImpl<Instruction *> &RetVec); Value *getNextTagWithCall(IRBuilder<> &IRB); @@ -328,6 +311,9 @@ public: void instrumentGlobal(GlobalVariable *GV, uint8_t Tag); void instrumentGlobals(); + Value *getPC(IRBuilder<> &IRB); + Value *getSP(IRBuilder<> &IRB); + void instrumentPersonalityFunctions(); private: @@ -397,96 +383,12 @@ private: Value *ShadowBase = nullptr; Value *StackBaseTag = nullptr; + Value *CachedSP = nullptr; GlobalValue *ThreadPtrGlobal = nullptr; }; -class HWAddressSanitizerLegacyPass : public FunctionPass { -public: - // Pass identification, replacement for typeid. 
- static char ID; - - explicit HWAddressSanitizerLegacyPass(bool CompileKernel = false, - bool Recover = false, - bool DisableOptimization = false) - : FunctionPass(ID), CompileKernel(CompileKernel), Recover(Recover), - DisableOptimization(DisableOptimization) { - initializeHWAddressSanitizerLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - StringRef getPassName() const override { return "HWAddressSanitizer"; } - - bool doInitialization(Module &M) override { - HWASan = std::make_unique<HWAddressSanitizer>(M, CompileKernel, Recover, - /*SSI=*/nullptr); - return true; - } - - bool runOnFunction(Function &F) override { - auto TargetTriple = Triple(F.getParent()->getTargetTriple()); - if (shouldUseStackSafetyAnalysis(TargetTriple, DisableOptimization)) { - // We cannot call getAnalysis in doInitialization, that would cause a - // crash as the required analyses are not initialized yet. - HWASan->setSSI( - &getAnalysis<StackSafetyGlobalInfoWrapperPass>().getResult()); - } - return HWASan->sanitizeFunction( - F, - [&]() -> const DominatorTree & { - return getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - }, - [&]() -> const PostDominatorTree & { - return getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); - }); - } - - bool doFinalization(Module &M) override { - HWASan.reset(); - return false; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - // This is an over-estimation of, in case we are building for an - // architecture that doesn't allow stack tagging we will still load the - // analysis. - // This is so we don't need to plumb TargetTriple all the way to here. - if (mightUseStackSafetyAnalysis(DisableOptimization)) - AU.addRequired<StackSafetyGlobalInfoWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<PostDominatorTreeWrapperPass>(); - } - -private: - std::unique_ptr<HWAddressSanitizer> HWASan; - bool CompileKernel; - bool Recover; - bool DisableOptimization; -}; - } // end anonymous namespace -char HWAddressSanitizerLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN( - HWAddressSanitizerLegacyPass, "hwasan", - "HWAddressSanitizer: detect memory bugs using tagged addressing.", false, - false) -INITIALIZE_PASS_DEPENDENCY(StackSafetyGlobalInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) -INITIALIZE_PASS_END( - HWAddressSanitizerLegacyPass, "hwasan", - "HWAddressSanitizer: detect memory bugs using tagged addressing.", false, - false) - -FunctionPass * -llvm::createHWAddressSanitizerLegacyPassPass(bool CompileKernel, bool Recover, - bool DisableOptimization) { - assert(!CompileKernel || Recover); - return new HWAddressSanitizerLegacyPass(CompileKernel, Recover, - DisableOptimization); -} - PreservedAnalyses HWAddressSanitizerPass::run(Module &M, ModuleAnalysisManager &MAM) { const StackSafetyGlobalInfo *SSI = nullptr; @@ -497,16 +399,8 @@ PreservedAnalyses HWAddressSanitizerPass::run(Module &M, HWAddressSanitizer HWASan(M, Options.CompileKernel, Options.Recover, SSI); bool Modified = false; auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); - for (Function &F : M) { - Modified |= HWASan.sanitizeFunction( - F, - [&]() -> const DominatorTree & { - return FAM.getResult<DominatorTreeAnalysis>(F); - }, - [&]() -> const PostDominatorTree & { - return FAM.getResult<PostDominatorTreeAnalysis>(F); - }); - } + for (Function &F : M) + Modified |= HWASan.sanitizeFunction(F, FAM); if (Modified) return PreservedAnalyses::none(); return 
PreservedAnalyses::all(); @@ -739,7 +633,9 @@ void HWAddressSanitizer::initializeCallbacks(Module &M) { ArrayType::get(IRB.getInt8Ty(), 0)); const std::string MemIntrinCallbackPrefix = - CompileKernel ? std::string("") : ClMemoryAccessCallbackPrefix; + (CompileKernel && !ClKasanMemIntrinCallbackPrefix) + ? std::string("") + : ClMemoryAccessCallbackPrefix; HWAsanMemmove = M.getOrInsertFunction(MemIntrinCallbackPrefix + "memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy); @@ -812,7 +708,7 @@ bool HWAddressSanitizer::ignoreAccess(Instruction *Inst, Value *Ptr) { void HWAddressSanitizer::getInterestingMemoryOperands( Instruction *I, SmallVectorImpl<InterestingMemoryOperand> &Interesting) { // Skip memory accesses inserted by another instrumentation. - if (I->hasMetadata("nosanitize")) + if (I->hasMetadata(LLVMContext::MD_nosanitize)) return; // Do not instrument the load fetching the dynamic shadow address. @@ -1056,18 +952,6 @@ bool HWAddressSanitizer::instrumentMemAccess(InterestingMemoryOperand &O) { return true; } -static uint64_t getAllocaSizeInBytes(const AllocaInst &AI) { - uint64_t ArraySize = 1; - if (AI.isArrayAllocation()) { - const ConstantInt *CI = dyn_cast<ConstantInt>(AI.getArraySize()); - assert(CI && "non-constant array size"); - ArraySize = CI->getZExtValue(); - } - Type *Ty = AI.getAllocatedType(); - uint64_t SizeInBytes = AI.getModule()->getDataLayout().getTypeAllocSize(Ty); - return SizeInBytes * ArraySize; -} - void HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, size_t Size) { size_t AlignedSize = alignTo(Size, Mapping.getObjectAlignment()); @@ -1141,19 +1025,10 @@ Value *HWAddressSanitizer::getStackBaseTag(IRBuilder<> &IRB) { return getNextTagWithCall(IRB); if (StackBaseTag) return StackBaseTag; - // FIXME: use addressofreturnaddress (but implement it in aarch64 backend - // first). - Module *M = IRB.GetInsertBlock()->getParent()->getParent(); - auto GetStackPointerFn = Intrinsic::getDeclaration( - M, Intrinsic::frameaddress, - IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace())); - Value *StackPointer = IRB.CreateCall( - GetStackPointerFn, {Constant::getNullValue(IRB.getInt32Ty())}); - // Extract some entropy from the stack pointer for the tags. // Take bits 20..28 (ASLR entropy) and xor with bits 0..8 (these differ // between functions). - Value *StackPointerLong = IRB.CreatePointerCast(StackPointer, IntptrTy); + Value *StackPointerLong = getSP(IRB); Value *StackTag = applyTagMask(IRB, IRB.CreateXor(StackPointerLong, IRB.CreateLShr(StackPointerLong, 20))); @@ -1233,6 +1108,30 @@ Value *HWAddressSanitizer::getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty) { return nullptr; } +Value *HWAddressSanitizer::getPC(IRBuilder<> &IRB) { + if (TargetTriple.getArch() == Triple::aarch64) + return readRegister(IRB, "pc"); + else + return IRB.CreatePtrToInt(IRB.GetInsertBlock()->getParent(), IntptrTy); +} + +Value *HWAddressSanitizer::getSP(IRBuilder<> &IRB) { + if (!CachedSP) { + // FIXME: use addressofreturnaddress (but implement it in aarch64 backend + // first). 
+ Function *F = IRB.GetInsertBlock()->getParent(); + Module *M = F->getParent(); + auto GetStackPointerFn = Intrinsic::getDeclaration( + M, Intrinsic::frameaddress, + IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace())); + CachedSP = IRB.CreatePtrToInt( + IRB.CreateCall(GetStackPointerFn, + {Constant::getNullValue(IRB.getInt32Ty())}), + IntptrTy); + } + return CachedSP; +} + void HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord) { if (!Mapping.InTls) ShadowBase = getShadowNonTls(IRB); @@ -1251,23 +1150,12 @@ void HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord) { TargetTriple.isAArch64() ? ThreadLong : untagPointer(IRB, ThreadLong); if (WithFrameRecord) { - Function *F = IRB.GetInsertBlock()->getParent(); StackBaseTag = IRB.CreateAShr(ThreadLong, 3); // Prepare ring buffer data. - Value *PC; - if (TargetTriple.getArch() == Triple::aarch64) - PC = readRegister(IRB, "pc"); - else - PC = IRB.CreatePtrToInt(F, IntptrTy); - Module *M = F->getParent(); - auto GetStackPointerFn = Intrinsic::getDeclaration( - M, Intrinsic::frameaddress, - IRB.getInt8PtrTy(M->getDataLayout().getAllocaAddrSpace())); - Value *SP = IRB.CreatePtrToInt( - IRB.CreateCall(GetStackPointerFn, - {Constant::getNullValue(IRB.getInt32Ty())}), - IntptrTy); + Value *PC = getPC(IRB); + Value *SP = getSP(IRB); + // Mix SP and PC. // Assumptions: // PC is 0x0000PPPPPPPPPPPP (48 bits are meaningful, others are zero) @@ -1330,43 +1218,16 @@ bool HWAddressSanitizer::instrumentLandingPads( return true; } -static bool -maybeReachableFromEachOther(const SmallVectorImpl<IntrinsicInst *> &Insts, - const DominatorTree &DT) { - // If we have too many lifetime ends, give up, as the algorithm below is N^2. - if (Insts.size() > ClMaxLifetimes) - return true; - for (size_t I = 0; I < Insts.size(); ++I) { - for (size_t J = 0; J < Insts.size(); ++J) { - if (I == J) - continue; - if (isPotentiallyReachable(Insts[I], Insts[J], nullptr, &DT)) - return true; - } - } - return false; -} - -// static -bool HWAddressSanitizer::isStandardLifetime(const AllocaInfo &AllocaInfo, - const DominatorTree &DT) { - // An alloca that has exactly one start and end in every possible execution. - // If it has multiple ends, they have to be unreachable from each other, so - // at most one of them is actually used for each execution of the function. - return AllocaInfo.LifetimeStart.size() == 1 && - (AllocaInfo.LifetimeEnd.size() == 1 || - (AllocaInfo.LifetimeEnd.size() > 0 && - !maybeReachableFromEachOther(AllocaInfo.LifetimeEnd, DT))); +static bool isLifetimeIntrinsic(Value *V) { + auto *II = dyn_cast<IntrinsicInst>(V); + return II && II->isLifetimeStartOrEnd(); } -bool HWAddressSanitizer::instrumentStack( - bool ShouldDetectUseAfterScope, - MapVector<AllocaInst *, AllocaInfo> &AllocasToInstrument, - SmallVector<Instruction *, 4> &UnrecognizedLifetimes, - DenseMap<AllocaInst *, std::vector<DbgVariableIntrinsic *>> &AllocaDbgMap, - SmallVectorImpl<Instruction *> &RetVec, Value *StackTag, - llvm::function_ref<const DominatorTree &()> GetDT, - llvm::function_ref<const PostDominatorTree &()> GetPDT) { +bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo, + Value *StackTag, + const DominatorTree &DT, + const PostDominatorTree &PDT, + const LoopInfo &LI) { // Ideally, we want to calculate tagged stack base pointer, and rewrite all // alloca addresses using that. Unfortunately, offsets are not known yet // (unless we use ASan-style mega-alloca). 
Instead we keep the base tag in a @@ -1374,10 +1235,10 @@ bool HWAddressSanitizer::instrumentStack( // This generates one extra instruction per alloca use. unsigned int I = 0; - for (auto &KV : AllocasToInstrument) { + for (auto &KV : SInfo.AllocasToInstrument) { auto N = I++; auto *AI = KV.first; - AllocaInfo &Info = KV.second; + memtag::AllocaInfo &Info = KV.second; IRBuilder<> IRB(AI->getNextNode()); // Replace uses of the alloca with tagged address. @@ -1388,10 +1249,34 @@ bool HWAddressSanitizer::instrumentStack( AI->hasName() ? AI->getName().str() : "alloca." + itostr(N); Replacement->setName(Name + ".hwasan"); - AI->replaceUsesWithIf(Replacement, - [AILong](Use &U) { return U.getUser() != AILong; }); + size_t Size = memtag::getAllocaSizeInBytes(*AI); + size_t AlignedSize = alignTo(Size, Mapping.getObjectAlignment()); + + Value *AICast = IRB.CreatePointerCast(AI, Int8PtrTy); + + auto HandleLifetime = [&](IntrinsicInst *II) { + // Set the lifetime intrinsic to cover the whole alloca. This reduces the + // set of assumptions we need to make about the lifetime. Without this we + // would need to ensure that we can track the lifetime pointer to a + // constant offset from the alloca, and would still need to change the + // size to include the extra alignment we use for the untagging to make + // the size consistent. + // + // The check for standard lifetime below makes sure that we have exactly + // one set of start / end in any execution (i.e. the ends are not + // reachable from each other), so this will not cause any problems. + II->setArgOperand(0, ConstantInt::get(Int64Ty, AlignedSize)); + II->setArgOperand(1, AICast); + }; + llvm::for_each(Info.LifetimeStart, HandleLifetime); + llvm::for_each(Info.LifetimeEnd, HandleLifetime); - for (auto *DDI : AllocaDbgMap.lookup(AI)) { + AI->replaceUsesWithIf(Replacement, [AICast, AILong](Use &U) { + auto *User = U.getUser(); + return User != AILong && User != AICast && !isLifetimeIntrinsic(User); + }); + + for (auto *DDI : Info.DbgVariableIntrinsics) { // Prepend "tag_offset, N" to the dwarf expression. // Tag offset logically applies to the alloca pointer, and it makes sense // to put it at the beginning of the expression. @@ -1403,37 +1288,47 @@ bool HWAddressSanitizer::instrumentStack( NewOps, LocNo)); } - size_t Size = getAllocaSizeInBytes(*AI); - size_t AlignedSize = alignTo(Size, Mapping.getObjectAlignment()); auto TagEnd = [&](Instruction *Node) { IRB.SetInsertPoint(Node); Value *UARTag = getUARTag(IRB, StackTag); + // When untagging, use the `AlignedSize` because we need to set the tags + // for the entire alloca to zero. If we used `Size` here, we would + // keep the last granule tagged, and store zero in the last byte of the + // last granule, due to how short granules are implemented. tagAlloca(IRB, AI, UARTag, AlignedSize); }; + // Calls to functions that may return twice (e.g. setjmp) confuse the + // postdominator analysis, and will leave us to keep memory tagged after + // function return. Work around this by always untagging at every return + // statement if return_twice functions are called. 
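The returns_twice workaround described in the comment above can be pictured with an ordinary setjmp/longjmp pair (a hedged C++ sketch with illustrative names):

#include <csetjmp>

static std::jmp_buf Env;

void helper() { std::longjmp(Env, 1); }  // produces the second return from setjmp

int frame_with_tagged_buffer() {
  char buf[64] = {0};                    // stack object that gets tagged on entry
  if (setjmp(Env) == 0)
    helper();                            // does not fall through
  // Untagging placed purely by post-dominance can miss the path that re-enters
  // through setjmp, so the pass untags at every return whenever a returns_twice
  // callee appears in the function.
  return buf[0];
}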
bool StandardLifetime = - UnrecognizedLifetimes.empty() && isStandardLifetime(Info, GetDT()); - if (ShouldDetectUseAfterScope && StandardLifetime) { + SInfo.UnrecognizedLifetimes.empty() && + memtag::isStandardLifetime(Info.LifetimeStart, Info.LifetimeEnd, &DT, + &LI, ClMaxLifetimes) && + !SInfo.CallsReturnTwice; + if (DetectUseAfterScope && StandardLifetime) { IntrinsicInst *Start = Info.LifetimeStart[0]; IRB.SetInsertPoint(Start->getNextNode()); tagAlloca(IRB, AI, Tag, Size); - if (!forAllReachableExits(GetDT(), GetPDT(), Start, Info.LifetimeEnd, - RetVec, TagEnd)) { + if (!memtag::forAllReachableExits(DT, PDT, LI, Start, Info.LifetimeEnd, + SInfo.RetVec, TagEnd)) { for (auto *End : Info.LifetimeEnd) End->eraseFromParent(); } } else { tagAlloca(IRB, AI, Tag, Size); - for (auto *RI : RetVec) + for (auto *RI : SInfo.RetVec) TagEnd(RI); - if (!StandardLifetime) { - for (auto &II : Info.LifetimeStart) - II->eraseFromParent(); - for (auto &II : Info.LifetimeEnd) - II->eraseFromParent(); - } + // We inserted tagging outside of the lifetimes, so we have to remove + // them. + for (auto &II : Info.LifetimeStart) + II->eraseFromParent(); + for (auto &II : Info.LifetimeEnd) + II->eraseFromParent(); } + memtag::alignAndPadAlloca(Info, Align(Mapping.getObjectAlignment())); } - for (auto &I : UnrecognizedLifetimes) + for (auto &I : SInfo.UnrecognizedLifetimes) I->eraseFromParent(); return true; } @@ -1443,7 +1338,7 @@ bool HWAddressSanitizer::isInterestingAlloca(const AllocaInst &AI) { // FIXME: instrument dynamic allocas, too AI.isStaticAlloca() && // alloca() may be called with 0 size, ignore it. - getAllocaSizeInBytes(AI) > 0 && + memtag::getAllocaSizeInBytes(AI) > 0 && // We are only interested in allocas not promotable to registers. // Promotable allocas are common under -O0. 
!isAllocaPromotable(&AI) && @@ -1456,42 +1351,8 @@ bool HWAddressSanitizer::isInterestingAlloca(const AllocaInst &AI) { !(SSI && SSI->isSafe(AI)); } -DenseMap<AllocaInst *, AllocaInst *> HWAddressSanitizer::padInterestingAllocas( - const MapVector<AllocaInst *, AllocaInfo> &AllocasToInstrument) { - DenseMap<AllocaInst *, AllocaInst *> AllocaToPaddedAllocaMap; - for (auto &KV : AllocasToInstrument) { - AllocaInst *AI = KV.first; - uint64_t Size = getAllocaSizeInBytes(*AI); - uint64_t AlignedSize = alignTo(Size, Mapping.getObjectAlignment()); - AI->setAlignment( - Align(std::max(AI->getAlignment(), Mapping.getObjectAlignment()))); - if (Size != AlignedSize) { - Type *AllocatedType = AI->getAllocatedType(); - if (AI->isArrayAllocation()) { - uint64_t ArraySize = - cast<ConstantInt>(AI->getArraySize())->getZExtValue(); - AllocatedType = ArrayType::get(AllocatedType, ArraySize); - } - Type *TypeWithPadding = StructType::get( - AllocatedType, ArrayType::get(Int8Ty, AlignedSize - Size)); - auto *NewAI = new AllocaInst( - TypeWithPadding, AI->getType()->getAddressSpace(), nullptr, "", AI); - NewAI->takeName(AI); - NewAI->setAlignment(AI->getAlign()); - NewAI->setUsedWithInAlloca(AI->isUsedWithInAlloca()); - NewAI->setSwiftError(AI->isSwiftError()); - NewAI->copyMetadata(*AI); - auto *Bitcast = new BitCastInst(NewAI, AI->getType(), "", AI); - AI->replaceAllUsesWith(Bitcast); - AllocaToPaddedAllocaMap[AI] = NewAI; - } - } - return AllocaToPaddedAllocaMap; -} - -bool HWAddressSanitizer::sanitizeFunction( - Function &F, llvm::function_ref<const DominatorTree &()> GetDT, - llvm::function_ref<const PostDominatorTree &()> GetPDT) { +bool HWAddressSanitizer::sanitizeFunction(Function &F, + FunctionAnalysisManager &FAM) { if (&F == HwasanCtorFunction) return false; @@ -1502,72 +1363,27 @@ bool HWAddressSanitizer::sanitizeFunction( SmallVector<InterestingMemoryOperand, 16> OperandsToInstrument; SmallVector<MemIntrinsic *, 16> IntrinToInstrument; - MapVector<AllocaInst *, AllocaInfo> AllocasToInstrument; - SmallVector<Instruction *, 8> RetVec; SmallVector<Instruction *, 8> LandingPadVec; - SmallVector<Instruction *, 4> UnrecognizedLifetimes; - DenseMap<AllocaInst *, std::vector<DbgVariableIntrinsic *>> AllocaDbgMap; - bool CallsReturnTwice = false; - for (auto &BB : F) { - for (auto &Inst : BB) { - if (CallInst *CI = dyn_cast<CallInst>(&Inst)) { - if (CI->canReturnTwice()) { - CallsReturnTwice = true; - } - } - if (InstrumentStack) { - if (AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) { - if (isInterestingAlloca(*AI)) - AllocasToInstrument.insert({AI, {}}); - continue; - } - auto *II = dyn_cast<IntrinsicInst>(&Inst); - if (II && (II->getIntrinsicID() == Intrinsic::lifetime_start || - II->getIntrinsicID() == Intrinsic::lifetime_end)) { - AllocaInst *AI = findAllocaForValue(II->getArgOperand(1)); - if (!AI) { - UnrecognizedLifetimes.push_back(&Inst); - continue; - } - if (!isInterestingAlloca(*AI)) - continue; - if (II->getIntrinsicID() == Intrinsic::lifetime_start) - AllocasToInstrument[AI].LifetimeStart.push_back(II); - else - AllocasToInstrument[AI].LifetimeEnd.push_back(II); - continue; - } - } - if (isa<ReturnInst>(Inst)) { - if (CallInst *CI = Inst.getParent()->getTerminatingMustTailCall()) - RetVec.push_back(CI); - else - RetVec.push_back(&Inst); - } else if (isa<ResumeInst, CleanupReturnInst>(Inst)) { - RetVec.push_back(&Inst); - } - - if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&Inst)) { - for (Value *V : DVI->location_ops()) { - if (auto *Alloca = dyn_cast_or_null<AllocaInst>(V)) - if 
(!AllocaDbgMap.count(Alloca) || - AllocaDbgMap[Alloca].back() != DVI) - AllocaDbgMap[Alloca].push_back(DVI); - } - } + memtag::StackInfoBuilder SIB( + [this](const AllocaInst &AI) { return isInterestingAlloca(AI); }); + for (auto &Inst : instructions(F)) { + if (InstrumentStack) { + SIB.visit(Inst); + } - if (InstrumentLandingPads && isa<LandingPadInst>(Inst)) - LandingPadVec.push_back(&Inst); + if (InstrumentLandingPads && isa<LandingPadInst>(Inst)) + LandingPadVec.push_back(&Inst); - getInterestingMemoryOperands(&Inst, OperandsToInstrument); + getInterestingMemoryOperands(&Inst, OperandsToInstrument); - if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(&Inst)) - if (!ignoreMemIntrinsic(MI)) - IntrinToInstrument.push_back(MI); - } + if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(&Inst)) + if (!ignoreMemIntrinsic(MI)) + IntrinToInstrument.push_back(MI); } + memtag::StackInfo &SInfo = SIB.get(); + initializeCallbacks(*F.getParent()); bool Changed = false; @@ -1575,7 +1391,7 @@ bool HWAddressSanitizer::sanitizeFunction( if (!LandingPadVec.empty()) Changed |= instrumentLandingPads(LandingPadVec); - if (AllocasToInstrument.empty() && F.hasPersonalityFn() && + if (SInfo.AllocasToInstrument.empty() && F.hasPersonalityFn() && F.getPersonalityFn()->getName() == kHwasanPersonalityThunkName) { // __hwasan_personality_thunk is a no-op for functions without an // instrumented stack, so we can drop it. @@ -1583,7 +1399,7 @@ bool HWAddressSanitizer::sanitizeFunction( Changed = true; } - if (AllocasToInstrument.empty() && OperandsToInstrument.empty() && + if (SInfo.AllocasToInstrument.empty() && OperandsToInstrument.empty() && IntrinToInstrument.empty()) return Changed; @@ -1593,42 +1409,16 @@ bool HWAddressSanitizer::sanitizeFunction( IRBuilder<> EntryIRB(InsertPt); emitPrologue(EntryIRB, /*WithFrameRecord*/ ClRecordStackHistory && - Mapping.WithFrameRecord && !AllocasToInstrument.empty()); + Mapping.WithFrameRecord && + !SInfo.AllocasToInstrument.empty()); - if (!AllocasToInstrument.empty()) { + if (!SInfo.AllocasToInstrument.empty()) { + const DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); + const PostDominatorTree &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F); + const LoopInfo &LI = FAM.getResult<LoopAnalysis>(F); Value *StackTag = ClGenerateTagsWithCalls ? nullptr : getStackBaseTag(EntryIRB); - // Calls to functions that may return twice (e.g. setjmp) confuse the - // postdominator analysis, and will leave us to keep memory tagged after - // function return. Work around this by always untagging at every return - // statement if return_twice functions are called. - instrumentStack(DetectUseAfterScope && !CallsReturnTwice, - AllocasToInstrument, UnrecognizedLifetimes, AllocaDbgMap, - RetVec, StackTag, GetDT, GetPDT); - } - // Pad and align each of the allocas that we instrumented to stop small - // uninteresting allocas from hiding in instrumented alloca's padding and so - // that we have enough space to store real tags for short granules. 
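The padding that this now-removed block performed, and which memtag::alignAndPadAlloca provides instead, rounds each instrumented alloca up to the tag granule. A small worked sketch of that arithmetic, assuming the usual 16-byte HWASan object alignment rather than reading it from the shadow mapping:

#include <cstdint>

constexpr uint64_t alignToGranule(uint64_t Size, uint64_t Granule = 16) {
  return (Size + Granule - 1) / Granule * Granule;  // same result as alignTo()
}

static_assert(alignToGranule(5) == 16);   // a 5-byte alloca gains 11 bytes of padding
static_assert(alignToGranule(16) == 16);  // already granule-sized, unchanged
static_assert(alignToGranule(17) == 32);  // spills into a second granule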
- DenseMap<AllocaInst *, AllocaInst *> AllocaToPaddedAllocaMap = - padInterestingAllocas(AllocasToInstrument); - - if (!AllocaToPaddedAllocaMap.empty()) { - for (auto &BB : F) { - for (auto &Inst : BB) { - if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&Inst)) { - SmallDenseSet<Value *> LocationOps(DVI->location_ops().begin(), - DVI->location_ops().end()); - for (Value *V : LocationOps) { - if (auto *AI = dyn_cast_or_null<AllocaInst>(V)) { - if (auto *NewAI = AllocaToPaddedAllocaMap.lookup(AI)) - DVI->replaceVariableLocationOp(V, NewAI); - } - } - } - } - } - for (auto &P : AllocaToPaddedAllocaMap) - P.first->eraseFromParent(); + instrumentStack(SInfo, StackTag, DT, PDT, LI); } // If we split the entry block, move any allocas that were originally in the @@ -1654,6 +1444,7 @@ bool HWAddressSanitizer::sanitizeFunction( ShadowBase = nullptr; StackBaseTag = nullptr; + CachedSP = nullptr; return true; } @@ -1735,34 +1526,10 @@ void HWAddressSanitizer::instrumentGlobal(GlobalVariable *GV, uint8_t Tag) { GV->eraseFromParent(); } -static DenseSet<GlobalVariable *> getExcludedGlobals(Module &M) { - NamedMDNode *Globals = M.getNamedMetadata("llvm.asan.globals"); - if (!Globals) - return DenseSet<GlobalVariable *>(); - DenseSet<GlobalVariable *> Excluded(Globals->getNumOperands()); - for (auto MDN : Globals->operands()) { - // Metadata node contains the global and the fields of "Entry". - assert(MDN->getNumOperands() == 5); - auto *V = mdconst::extract_or_null<Constant>(MDN->getOperand(0)); - // The optimizer may optimize away a global entirely. - if (!V) - continue; - auto *StrippedV = V->stripPointerCasts(); - auto *GV = dyn_cast<GlobalVariable>(StrippedV); - if (!GV) - continue; - ConstantInt *IsExcluded = mdconst::extract<ConstantInt>(MDN->getOperand(4)); - if (IsExcluded->isOne()) - Excluded.insert(GV); - } - return Excluded; -} - void HWAddressSanitizer::instrumentGlobals() { std::vector<GlobalVariable *> Globals; - auto ExcludedGlobals = getExcludedGlobals(M); for (GlobalVariable &GV : M.globals()) { - if (ExcludedGlobals.count(&GV)) + if (GV.hasSanitizerMetadata() && GV.getSanitizerMetadata().NoHWAddress) continue; if (GV.isDeclarationForLinker() || GV.getName().startswith("llvm.") || diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index 9a3afa9cc924..3ef06907dfee 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -13,30 +13,20 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/IndirectCallPromotionAnalysis.h" #include "llvm/Analysis/IndirectCallVisitor.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" -#include "llvm/IR/Attributes.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PassManager.h" -#include "llvm/IR/Type.h" #include "llvm/IR/Value.h" -#include 
"llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" @@ -45,7 +35,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/CallPromotionUtils.h" #include <cassert> #include <cstdint> @@ -71,13 +60,13 @@ static cl::opt<bool> DisableICP("disable-icp", cl::init(false), cl::Hidden, // value. // For debug use only. static cl::opt<unsigned> - ICPCutOff("icp-cutoff", cl::init(0), cl::Hidden, cl::ZeroOrMore, + ICPCutOff("icp-cutoff", cl::init(0), cl::Hidden, cl::desc("Max number of promotions for this compilation")); // If ICPCSSkip is non zero, the first ICPCSSkip callsites will be skipped. // For debug use only. static cl::opt<unsigned> - ICPCSSkip("icp-csskip", cl::init(0), cl::Hidden, cl::ZeroOrMore, + ICPCSSkip("icp-csskip", cl::init(0), cl::Hidden, cl::desc("Skip Callsite up to this number for this compilation")); // Set if the pass is called in LTO optimization. The difference for LTO mode @@ -115,55 +104,6 @@ static cl::opt<bool> namespace { -class PGOIndirectCallPromotionLegacyPass : public ModulePass { -public: - static char ID; - - PGOIndirectCallPromotionLegacyPass(bool InLTO = false, bool SamplePGO = false) - : ModulePass(ID), InLTO(InLTO), SamplePGO(SamplePGO) { - initializePGOIndirectCallPromotionLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<ProfileSummaryInfoWrapperPass>(); - } - - StringRef getPassName() const override { return "PGOIndirectCallPromotion"; } - -private: - bool runOnModule(Module &M) override; - - // If this pass is called in LTO. We need to special handling the PGOFuncName - // for the static variables due to LTO's internalization. - bool InLTO; - - // If this pass is called in SamplePGO. We need to add the prof metadata to - // the promoted direct call. - bool SamplePGO; -}; - -} // end anonymous namespace - -char PGOIndirectCallPromotionLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(PGOIndirectCallPromotionLegacyPass, "pgo-icall-prom", - "Use PGO instrumentation profile to promote indirect " - "calls to direct calls.", - false, false) -INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) -INITIALIZE_PASS_END(PGOIndirectCallPromotionLegacyPass, "pgo-icall-prom", - "Use PGO instrumentation profile to promote indirect " - "calls to direct calls.", - false, false) - -ModulePass *llvm::createPGOIndirectCallPromotionLegacyPass(bool InLTO, - bool SamplePGO) { - return new PGOIndirectCallPromotionLegacyPass(InLTO, SamplePGO); -} - -namespace { - // The class for main data structure to promote indirect calls to conditional // direct calls. class ICallPromotionFunc { @@ -428,15 +368,6 @@ static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, return Changed; } -bool PGOIndirectCallPromotionLegacyPass::runOnModule(Module &M) { - ProfileSummaryInfo *PSI = - &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); - - // Command-line option has the priority for InLTO. 
- return promoteIndirectCalls(M, PSI, InLTO | ICPLTOMode, - SamplePGO | ICPSamplePGOMode); -} - PreservedAnalyses PGOIndirectCallPromotion::run(Module &M, ModuleAnalysisManager &AM) { ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp index 3ea314329079..2091881c29fe 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrOrderFile.cpp @@ -9,29 +9,22 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Instrumentation/InstrOrderFile.h" -#include "llvm/ADT/Statistic.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/PassRegistry.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" #include "llvm/Support/FileSystem.h" -#include "llvm/Support/Path.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include <fstream> -#include <map> #include <mutex> -#include <set> #include <sstream> using namespace llvm; @@ -61,7 +54,7 @@ private: ArrayType *MapTy; public: - InstrOrderFile() {} + InstrOrderFile() = default; void createOrderFileData(Module &M) { LLVMContext &Ctx = M.getContext(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp index 6868408ef5f5..7843b1522830 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -47,12 +47,10 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include <algorithm> #include <cassert> -#include <cstddef> #include <cstdint> #include <string> @@ -62,7 +60,7 @@ using namespace llvm; namespace llvm { cl::opt<bool> - DebugInfoCorrelate("debug-info-correlate", cl::ZeroOrMore, + DebugInfoCorrelate("debug-info-correlate", cl::desc("Use debug info to correlate profiles."), cl::init(false)); } // namespace llvm @@ -95,18 +93,18 @@ cl::opt<double> NumCountersPerValueSite( cl::init(1.0)); cl::opt<bool> AtomicCounterUpdateAll( - "instrprof-atomic-counter-update-all", cl::ZeroOrMore, + "instrprof-atomic-counter-update-all", cl::desc("Make all profile counter updates atomic (for testing only)"), cl::init(false)); cl::opt<bool> AtomicCounterUpdatePromoted( - "atomic-counter-update-promoted", cl::ZeroOrMore, + "atomic-counter-update-promoted", cl::desc("Do counter update using atomic fetch add " " for promoted counters only"), cl::init(false)); cl::opt<bool> AtomicFirstCounter( - "atomic-first-counter", cl::ZeroOrMore, + "atomic-first-counter", cl::desc("Use atomic fetch add for first counter in a function (usually " "the entry counter)"), cl::init(false)); @@ -116,37 +114,37 @@ cl::opt<bool> AtomicFirstCounter( // pipeline is setup, i.e., 
the default value of true of this option // does not mean the promotion will be done by default. Explicitly // setting this option can override the default behavior. -cl::opt<bool> DoCounterPromotion("do-counter-promotion", cl::ZeroOrMore, +cl::opt<bool> DoCounterPromotion("do-counter-promotion", cl::desc("Do counter register promotion"), cl::init(false)); cl::opt<unsigned> MaxNumOfPromotionsPerLoop( - cl::ZeroOrMore, "max-counter-promotions-per-loop", cl::init(20), + "max-counter-promotions-per-loop", cl::init(20), cl::desc("Max number counter promotions per loop to avoid" " increasing register pressure too much")); // A debug option cl::opt<int> - MaxNumOfPromotions(cl::ZeroOrMore, "max-counter-promotions", cl::init(-1), + MaxNumOfPromotions("max-counter-promotions", cl::init(-1), cl::desc("Max number of allowed counter promotions")); cl::opt<unsigned> SpeculativeCounterPromotionMaxExiting( - cl::ZeroOrMore, "speculative-counter-promotion-max-exiting", cl::init(3), + "speculative-counter-promotion-max-exiting", cl::init(3), cl::desc("The max number of exiting blocks of a loop to allow " " speculative counter promotion")); cl::opt<bool> SpeculativeCounterPromotionToLoop( - cl::ZeroOrMore, "speculative-counter-promotion-to-loop", cl::init(false), + "speculative-counter-promotion-to-loop", cl::desc("When the option is false, if the target block is in a loop, " "the promotion will be disallowed unless the promoted counter " " update can be further/iteratively promoted into an acyclic " " region.")); cl::opt<bool> IterativeCounterPromotion( - cl::ZeroOrMore, "iterative-counter-promotion", cl::init(true), + "iterative-counter-promotion", cl::init(true), cl::desc("Allow counter promotion across the whole loop nest.")); cl::opt<bool> SkipRetExitBlock( - cl::ZeroOrMore, "skip-ret-exit-block", cl::init(true), + "skip-ret-exit-block", cl::init(true), cl::desc("Suppress counter promotion if exit blocks contain ret.")); class InstrProfilingLegacyPass : public ModulePass { @@ -211,6 +209,18 @@ public: Value *Addr = cast<StoreInst>(Store)->getPointerOperand(); Type *Ty = LiveInValue->getType(); IRBuilder<> Builder(InsertPos); + if (auto *AddrInst = dyn_cast_or_null<IntToPtrInst>(Addr)) { + // If isRuntimeCounterRelocationEnabled() is true then the address of + // the store instruction is computed with two instructions in + // InstrProfiling::getCounterAddress(). We need to copy those + // instructions to this block to compute Addr correctly. + // %BiasAdd = add i64 ptrtoint <__profc_>, <__llvm_profile_counter_bias> + // %Addr = inttoptr i64 %BiasAdd to i64* + auto *OrigBiasInst = dyn_cast<BinaryOperator>(AddrInst->getOperand(0)); + assert(OrigBiasInst->getOpcode() == Instruction::BinaryOps::Add); + Value *BiasInst = Builder.Insert(OrigBiasInst->clone()); + Addr = Builder.CreateIntToPtr(BiasInst, Ty->getPointerTo()); + } if (AtomicCounterUpdatePromoted) // automic update currently can only be promoted across the current // loop, not the whole loop nest. @@ -303,8 +313,7 @@ public: auto PreheaderCount = BFI->getBlockProfileCount(L.getLoopPreheader()); // If the average loop trip count is not greater than 1.5, we skip // promotion. 
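The heuristic described in this comment is implemented by the comparison that follows: with InstrCount the profile count of the counter update and PreheaderCount the count of the loop preheader, an average trip count of at most 1.5 is exactly PreheaderCount * 3 >= InstrCount * 2. A small sketch of the arithmetic:

#include <cstdint>

constexpr bool skipPromotion(uint64_t PreheaderCount, uint64_t InstrCount) {
  return PreheaderCount * 3 >= InstrCount * 2;  // average trip count <= 1.5
}

static_assert(skipPromotion(100, 140));    // ~1.4 iterations per entry: not worth it
static_assert(!skipPromotion(100, 1000));  // ~10 iterations per entry: promote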
- if (PreheaderCount && - (PreheaderCount.getValue() * 3) >= (InstrCount.getValue() * 2)) + if (PreheaderCount && (*PreheaderCount * 3) >= (*InstrCount * 2)) continue; } @@ -705,10 +714,9 @@ Value *InstrProfiling::getCounterAddress(InstrProfInstBase *I) { Type *Int64Ty = Type::getInt64Ty(M->getContext()); Function *Fn = I->getParent()->getParent(); - Instruction &EntryI = Fn->getEntryBlock().front(); - LoadInst *LI = dyn_cast<LoadInst>(&EntryI); - if (!LI) { - IRBuilder<> EntryBuilder(&EntryI); + LoadInst *&BiasLI = FunctionToProfileBiasMap[Fn]; + if (!BiasLI) { + IRBuilder<> EntryBuilder(&Fn->getEntryBlock().front()); auto *Bias = M->getGlobalVariable(getInstrProfCounterBiasVarName()); if (!Bias) { // Compiler must define this variable when runtime counter relocation @@ -725,9 +733,9 @@ Value *InstrProfiling::getCounterAddress(InstrProfInstBase *I) { if (TT.supportsCOMDAT()) Bias->setComdat(M->getOrInsertComdat(Bias->getName())); } - LI = EntryBuilder.CreateLoad(Int64Ty, Bias); + BiasLI = EntryBuilder.CreateLoad(Int64Ty, Bias); } - auto *Add = Builder.CreateAdd(Builder.CreatePtrToInt(Addr, Int64Ty), LI); + auto *Add = Builder.CreateAdd(Builder.CreatePtrToInt(Addr, Int64Ty), BiasLI); return Builder.CreateIntToPtr(Add, Addr->getType()); } @@ -769,7 +777,8 @@ void InstrProfiling::lowerCoverageData(GlobalVariable *CoverageNamesVar) { Name->setLinkage(GlobalValue::PrivateLinkage); ReferencedNames.push_back(Name); - NC->dropAllReferences(); + if (isa<ConstantExpr>(NC)) + NC->dropAllReferences(); } CoverageNamesVar->eraseFromParent(); } @@ -856,8 +865,8 @@ static bool needsRuntimeRegistrationOfSectionRange(const Triple &TT) { if (TT.isOSDarwin()) return false; // Use linker script magic to get data/cnts/name start/end. - if (TT.isOSLinux() || TT.isOSFreeBSD() || TT.isOSNetBSD() || - TT.isOSSolaris() || TT.isOSFuchsia() || TT.isPS4CPU() || TT.isOSWindows()) + if (TT.isOSAIX() || TT.isOSLinux() || TT.isOSFreeBSD() || TT.isOSNetBSD() || + TT.isOSSolaris() || TT.isOSFuchsia() || TT.isPS() || TT.isOSWindows()) return false; return true; @@ -1236,7 +1245,7 @@ bool InstrProfiling::emitRuntimeHook() { new GlobalVariable(*M, Int32Ty, false, GlobalValue::ExternalLinkage, nullptr, getInstrProfRuntimeHookVarName()); - if (TT.isOSBinFormatELF()) { + if (TT.isOSBinFormatELF() && !TT.isPS()) { // Mark the user variable as used so that it isn't stripped out. CompilerUsedVars.push_back(Var); } else { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp index dda242492391..9ff0e632bd7f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/Instrumentation.cpp @@ -91,23 +91,13 @@ Comdat *llvm::getOrCreateFunctionComdat(Function &F, Triple &T) { /// initializeInstrumentation - Initialize all passes in the TransformUtils /// library. 
void llvm::initializeInstrumentation(PassRegistry &Registry) { - initializeAddressSanitizerLegacyPassPass(Registry); - initializeModuleAddressSanitizerLegacyPassPass(Registry); initializeMemProfilerLegacyPassPass(Registry); initializeModuleMemProfilerLegacyPassPass(Registry); initializeBoundsCheckingLegacyPassPass(Registry); initializeControlHeightReductionLegacyPassPass(Registry); - initializeGCOVProfilerLegacyPassPass(Registry); - initializePGOInstrumentationGenLegacyPassPass(Registry); - initializePGOInstrumentationUseLegacyPassPass(Registry); - initializePGOIndirectCallPromotionLegacyPassPass(Registry); - initializePGOMemOPSizeOptLegacyPassPass(Registry); initializeCGProfileLegacyPassPass(Registry); initializeInstrOrderFileLegacyPassPass(Registry); initializeInstrProfilingLegacyPassPass(Registry); - initializeMemorySanitizerLegacyPassPass(Registry); - initializeHWAddressSanitizerLegacyPassPass(Registry); - initializeThreadSanitizerLegacyPassPass(Registry); initializeModuleSanitizerCoverageLegacyPassPass(Registry); initializeDataFlowSanitizerLegacyPassPass(Registry); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MaximumSpanningTree.h b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MaximumSpanningTree.h deleted file mode 100644 index 892a6a26da91..000000000000 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MaximumSpanningTree.h +++ /dev/null @@ -1,109 +0,0 @@ -//===- llvm/Analysis/MaximumSpanningTree.h - Interface ----------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This module provides means for calculating a maximum spanning tree for a -// given set of weighted edges. The type parameter T is the type of a node. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_LIB_TRANSFORMS_INSTRUMENTATION_MAXIMUMSPANNINGTREE_H -#define LLVM_LIB_TRANSFORMS_INSTRUMENTATION_MAXIMUMSPANNINGTREE_H - -#include "llvm/ADT/EquivalenceClasses.h" -#include "llvm/IR/BasicBlock.h" -#include <algorithm> -#include <vector> - -namespace llvm { - - /// MaximumSpanningTree - A MST implementation. - /// The type parameter T determines the type of the nodes of the graph. - template <typename T> - class MaximumSpanningTree { - public: - typedef std::pair<const T*, const T*> Edge; - typedef std::pair<Edge, double> EdgeWeight; - typedef std::vector<EdgeWeight> EdgeWeights; - protected: - typedef std::vector<Edge> MaxSpanTree; - - MaxSpanTree MST; - - private: - // A comparing class for comparing weighted edges. - struct EdgeWeightCompare { - static bool getBlockSize(const T *X) { - const BasicBlock *BB = dyn_cast_or_null<BasicBlock>(X); - return BB ? BB->size() : 0; - } - - bool operator()(EdgeWeight X, EdgeWeight Y) const { - if (X.second > Y.second) return true; - if (X.second < Y.second) return false; - - // Equal edge weights: break ties by comparing block sizes. 
- size_t XSizeA = getBlockSize(X.first.first); - size_t YSizeA = getBlockSize(Y.first.first); - if (XSizeA > YSizeA) return true; - if (XSizeA < YSizeA) return false; - - size_t XSizeB = getBlockSize(X.first.second); - size_t YSizeB = getBlockSize(Y.first.second); - if (XSizeB > YSizeB) return true; - if (XSizeB < YSizeB) return false; - - return false; - } - }; - - public: - static char ID; // Class identification, replacement for typeinfo - - /// MaximumSpanningTree() - Takes a vector of weighted edges and returns a - /// spanning tree. - MaximumSpanningTree(EdgeWeights &EdgeVector) { - llvm::stable_sort(EdgeVector, EdgeWeightCompare()); - - // Create spanning tree, Forest contains a special data structure - // that makes checking if two nodes are already in a common (sub-)tree - // fast and cheap. - EquivalenceClasses<const T*> Forest; - for (typename EdgeWeights::iterator EWi = EdgeVector.begin(), - EWe = EdgeVector.end(); EWi != EWe; ++EWi) { - Edge e = (*EWi).first; - - Forest.insert(e.first); - Forest.insert(e.second); - } - - // Iterate over the sorted edges, biggest first. - for (typename EdgeWeights::iterator EWi = EdgeVector.begin(), - EWe = EdgeVector.end(); EWi != EWe; ++EWi) { - Edge e = (*EWi).first; - - if (Forest.findLeader(e.first) != Forest.findLeader(e.second)) { - Forest.unionSets(e.first, e.second); - // So we know now that the edge is not already in a subtree, so we push - // the edge to the MST. - MST.push_back(e); - } - } - } - - typename MaxSpanTree::iterator begin() { - return MST.begin(); - } - - typename MaxSpanTree::iterator end() { - return MST.end(); - } - }; - -} // End llvm namespace - -#endif // LLVM_LIB_TRANSFORMS_INSTRUMENTATION_MAXIMUMSPANNINGTREE_H diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp index 5e078f2c4212..01e3b2c20218 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp @@ -27,15 +27,14 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" +#include "llvm/ProfileData/InstrProf.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -156,7 +155,6 @@ static uint64_t getCtorAndDtorPriority(Triple &TargetTriple) { struct InterestingMemoryAccess { Value *Addr = nullptr; bool IsWrite; - unsigned Alignment; Type *AccessTy; uint64_t TypeSize; Value *MaybeMask = nullptr; @@ -182,8 +180,7 @@ public: void instrumentAddress(Instruction *OrigIns, Instruction *InsertBefore, Value *Addr, uint32_t TypeSize, bool IsWrite); void instrumentMaskedLoadOrStore(const DataLayout &DL, Value *Mask, - Instruction *I, Value *Addr, - unsigned Alignment, Type *AccessTy, + Instruction *I, Value *Addr, Type *AccessTy, bool IsWrite); void instrumentMemIntrinsic(MemIntrinsic *MI); Value *memToShadow(Value *Shadow, IRBuilder<> &IRB); @@ -255,7 +252,7 @@ public: } // end anonymous namespace -MemProfilerPass::MemProfilerPass() {} +MemProfilerPass::MemProfilerPass() = default; PreservedAnalyses MemProfilerPass::run(Function &F, AnalysisManager<Function> &AM) { @@ -266,7 +263,7 @@ 
PreservedAnalyses MemProfilerPass::run(Function &F, return PreservedAnalyses::all(); } -ModuleMemProfilerPass::ModuleMemProfilerPass() {} +ModuleMemProfilerPass::ModuleMemProfilerPass() = default; PreservedAnalyses ModuleMemProfilerPass::run(Module &M, AnalysisManager<Module> &AM) { @@ -341,28 +338,24 @@ MemProfiler::isInterestingMemoryAccess(Instruction *I) const { return None; Access.IsWrite = false; Access.AccessTy = LI->getType(); - Access.Alignment = LI->getAlignment(); Access.Addr = LI->getPointerOperand(); } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) { if (!ClInstrumentWrites) return None; Access.IsWrite = true; Access.AccessTy = SI->getValueOperand()->getType(); - Access.Alignment = SI->getAlignment(); Access.Addr = SI->getPointerOperand(); } else if (AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(I)) { if (!ClInstrumentAtomics) return None; Access.IsWrite = true; Access.AccessTy = RMW->getValOperand()->getType(); - Access.Alignment = 0; Access.Addr = RMW->getPointerOperand(); } else if (AtomicCmpXchgInst *XCHG = dyn_cast<AtomicCmpXchgInst>(I)) { if (!ClInstrumentAtomics) return None; Access.IsWrite = true; Access.AccessTy = XCHG->getCompareOperand()->getType(); - Access.Alignment = 0; Access.Addr = XCHG->getPointerOperand(); } else if (auto *CI = dyn_cast<CallInst>(I)) { auto *F = CI->getCalledFunction(); @@ -384,11 +377,6 @@ MemProfiler::isInterestingMemoryAccess(Instruction *I) const { } auto *BasePtr = CI->getOperand(0 + OpOffset); - if (auto *AlignmentConstant = - dyn_cast<ConstantInt>(CI->getOperand(1 + OpOffset))) - Access.Alignment = (unsigned)AlignmentConstant->getZExtValue(); - else - Access.Alignment = 1; // No alignment guarantees. We probably got Undef Access.MaybeMask = CI->getOperand(2 + OpOffset); Access.Addr = BasePtr; } @@ -410,6 +398,25 @@ MemProfiler::isInterestingMemoryAccess(Instruction *I) const { if (Access.Addr->isSwiftError()) return None; + // Peel off GEPs and BitCasts. + auto *Addr = Access.Addr->stripInBoundsOffsets(); + + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) { + // Do not instrument PGO counter updates. + if (GV->hasSection()) { + StringRef SectionName = GV->getSection(); + // Check if the global is in the PGO counters section. + auto OF = Triple(I->getModule()->getTargetTriple()).getObjectFormat(); + if (SectionName.endswith( + getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false))) + return None; + } + + // Do not instrument accesses to LLVM internal variables. 
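// Taken together with the section check above, the prefix check added below
// keeps the memory profiler from instrumenting its own (and PGO's)
// bookkeeping: on ELF, for instance, the __profc_* counter arrays live in the
// __llvm_prf_cnts section, and runtime-support globals such as
// __llvm_profile_counter_bias carry the __llvm prefix, so stores to either are
// left uninstrumented.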
+ if (GV->getName().startswith("__llvm")) + return None; + } + const DataLayout &DL = I->getModule()->getDataLayout(); Access.TypeSize = DL.getTypeStoreSizeInBits(Access.AccessTy); return Access; @@ -417,7 +424,6 @@ MemProfiler::isInterestingMemoryAccess(Instruction *I) const { void MemProfiler::instrumentMaskedLoadOrStore(const DataLayout &DL, Value *Mask, Instruction *I, Value *Addr, - unsigned Alignment, Type *AccessTy, bool IsWrite) { auto *VTy = cast<FixedVectorType>(AccessTy); uint64_t ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType()); @@ -468,8 +474,7 @@ void MemProfiler::instrumentMop(Instruction *I, const DataLayout &DL, if (Access.MaybeMask) { instrumentMaskedLoadOrStore(DL, Access.MaybeMask, I, Access.Addr, - Access.Alignment, Access.AccessTy, - Access.IsWrite); + Access.AccessTy, Access.IsWrite); } else { // Since the access counts will be accumulated across the entire allocation, // we only update the shadow access count for the first location and thus @@ -615,8 +620,6 @@ bool MemProfiler::instrumentFunction(Function &F) { initializeCallbacks(*F.getParent()); - FunctionModified |= insertDynamicShadowAtFunctionEntry(F); - SmallVector<Instruction *, 16> ToInstrument; // Fill the set of memory operations to instrument. @@ -627,6 +630,15 @@ bool MemProfiler::instrumentFunction(Function &F) { } } + if (ToInstrument.empty()) { + LLVM_DEBUG(dbgs() << "MEMPROF done instrumenting: " << FunctionModified + << " " << F << "\n"); + + return FunctionModified; + } + + FunctionModified |= insertDynamicShadowAtFunctionEntry(F); + int NumInstrumented = 0; for (auto *Inst : ToInstrument) { if (ClDebugMin < 0 || ClDebugMax < 0 || diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index c51acdf52f14..4d72f6c3d1a9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -174,24 +174,19 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsX86.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueMap.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/Support/Alignment.h" #include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -634,33 +629,6 @@ void insertModuleCtor(Module &M) { }); } -/// A legacy function pass for msan instrumentation. -/// -/// Instruments functions to detect uninitialized reads. -struct MemorySanitizerLegacyPass : public FunctionPass { - // Pass identification, replacement for typeid. 
- static char ID; - - MemorySanitizerLegacyPass(MemorySanitizerOptions Options = {}) - : FunctionPass(ID), Options(Options) { - initializeMemorySanitizerLegacyPassPass(*PassRegistry::getPassRegistry()); - } - StringRef getPassName() const override { return "MemorySanitizerLegacyPass"; } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } - - bool runOnFunction(Function &F) override { - return MSan->sanitizeFunction( - F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F)); - } - bool doInitialization(Module &M) override; - - Optional<MemorySanitizer> MSan; - MemorySanitizerOptions Options; -}; - template <class T> T getOptOrDefault(const cl::opt<T> &Opt, T Default) { return (Opt.getNumOccurrences() > 0) ? Opt : Default; } @@ -705,21 +673,6 @@ void MemorySanitizerPass::printPipeline( OS << ">"; } -char MemorySanitizerLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(MemorySanitizerLegacyPass, "msan", - "MemorySanitizer: detects uninitialized reads.", false, - false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(MemorySanitizerLegacyPass, "msan", - "MemorySanitizer: detects uninitialized reads.", false, - false) - -FunctionPass * -llvm::createMemorySanitizerLegacyPassPass(MemorySanitizerOptions Options) { - return new MemorySanitizerLegacyPass(Options); -} - /// Create a non-const global initialized with the given string. /// /// Creates a writable global for Str so that we can pass it to the @@ -1017,13 +970,6 @@ void MemorySanitizer::initializeModule(Module &M) { } } -bool MemorySanitizerLegacyPass::doInitialization(Module &M) { - if (!Options.Kernel) - insertModuleCtor(M); - MSan.emplace(M, Options); - return true; -} - namespace { /// A helper class that handles instrumentation of VarArg @@ -1674,7 +1620,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// or extracts if from ParamTLS (for function arguments). Value *getShadow(Value *V) { if (Instruction *I = dyn_cast<Instruction>(V)) { - if (!PropagateShadow || I->getMetadata("nosanitize")) + if (!PropagateShadow || I->getMetadata(LLVMContext::MD_nosanitize)) return getCleanShadow(V); // For instructions the shadow is already stored in the map. Value *Shadow = ShadowMap[V]; @@ -1694,9 +1640,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } if (Argument *A = dyn_cast<Argument>(V)) { // For arguments we compute the shadow on demand and store it in the map. 
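// The argument-shadow hunk below only changes how the cached map slot is
// named: a reference into the map (Value *&) instead of a pointer to the slot
// (Value **). The underlying idiom is a single-lookup cache -- operator[]
// default-constructs a null entry and the same reference is later assigned the
// computed shadow. A minimal, self-contained sketch of that idiom (names are
// illustrative; std::unordered_map is used here for its stable references,
// whereas the code below uses llvm::DenseMap):
//
//   #include <string>
//   #include <unordered_map>
//
//   std::unordered_map<int, std::string> Cache;
//
//   const std::string &lookup(int Key) {
//     std::string &Slot = Cache[Key];   // inserts an empty value if absent
//     if (Slot.empty())
//       Slot = "computed-" + std::to_string(Key);  // fill the slot in place
//     return Slot;                      // later calls reuse the cached value
//   }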
- Value **ShadowPtr = &ShadowMap[V]; - if (*ShadowPtr) - return *ShadowPtr; + Value *&ShadowPtr = ShadowMap[V]; + if (ShadowPtr) + return ShadowPtr; Function *F = A->getParent(); IRBuilder<> EntryIRB(FnPrologueEnd); unsigned ArgOffset = 0; @@ -1753,12 +1699,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (!PropagateShadow || Overflow || FArg.hasByValAttr() || (MS.EagerChecks && FArg.hasAttribute(Attribute::NoUndef))) { - *ShadowPtr = getCleanShadow(V); + ShadowPtr = getCleanShadow(V); setOrigin(A, getCleanOrigin()); } else { // Shadow over TLS Value *Base = getShadowPtrForArgument(&FArg, EntryIRB, ArgOffset); - *ShadowPtr = EntryIRB.CreateAlignedLoad(getShadowTy(&FArg), Base, + ShadowPtr = EntryIRB.CreateAlignedLoad(getShadowTy(&FArg), Base, kShadowTLSAlignment); if (MS.TrackOrigins) { Value *OriginPtr = @@ -1767,14 +1713,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } } LLVM_DEBUG(dbgs() - << " ARG: " << FArg << " ==> " << **ShadowPtr << "\n"); + << " ARG: " << FArg << " ==> " << *ShadowPtr << "\n"); break; } ArgOffset += alignTo(Size, kShadowTLSAlignment); } - assert(*ShadowPtr && "Could not find shadow for an argument"); - return *ShadowPtr; + assert(ShadowPtr && "Could not find shadow for an argument"); + return ShadowPtr; } // For everything else the shadow is zero. return getCleanShadow(V); @@ -1793,7 +1739,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { assert((isa<Instruction>(V) || isa<Argument>(V)) && "Unexpected value type in getOrigin()"); if (Instruction *I = dyn_cast<Instruction>(V)) { - if (I->getMetadata("nosanitize")) + if (I->getMetadata(LLVMContext::MD_nosanitize)) return getCleanOrigin(); } Value *Origin = OriginMap[V]; @@ -1916,7 +1862,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // ------------------- Visitors. using InstVisitor<MemorySanitizerVisitor>::visit; void visit(Instruction &I) { - if (I.getMetadata("nosanitize")) + if (I.getMetadata(LLVMContext::MD_nosanitize)) return; // Don't want to visit if we're in the prologue if (isInPrologue(I)) @@ -1930,12 +1876,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// Optionally, checks that the load address is fully defined. void visitLoadInst(LoadInst &I) { assert(I.getType()->isSized() && "Load type must have size"); - assert(!I.getMetadata("nosanitize")); + assert(!I.getMetadata(LLVMContext::MD_nosanitize)); IRBuilder<> IRB(I.getNextNode()); Type *ShadowTy = getShadowTy(&I); Value *Addr = I.getPointerOperand(); Value *ShadowPtr = nullptr, *OriginPtr = nullptr; - const Align Alignment = assumeAligned(I.getAlignment()); + const Align Alignment = I.getAlign(); if (PropagateShadow) { std::tie(ShadowPtr, OriginPtr) = getShadowOriginPtr(Addr, IRB, ShadowTy, Alignment, /*isStore*/ false); @@ -2573,6 +2519,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// /// Similar situation exists for memcpy and memset. void visitMemMoveInst(MemMoveInst &I) { + getShadow(I.getArgOperand(1)); // Ensure shadow initialized IRBuilder<> IRB(&I); IRB.CreateCall( MS.MemmoveFn, @@ -2587,6 +2534,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // FIXME: consider doing manual inline for small constant sizes and proper // alignment. 
void visitMemCpyInst(MemCpyInst &I) { + getShadow(I.getArgOperand(1)); // Ensure shadow initialized IRBuilder<> IRB(&I); IRB.CreateCall( MS.MemcpyFn, @@ -3252,27 +3200,37 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { SOC.Done(&I); } - // Instrument _mm_*_sd intrinsics - void handleUnarySdIntrinsic(IntrinsicInst &I) { + // Instrument _mm_*_sd|ss intrinsics + void handleUnarySdSsIntrinsic(IntrinsicInst &I) { IRBuilder<> IRB(&I); + unsigned Width = + cast<FixedVectorType>(I.getArgOperand(0)->getType())->getNumElements(); Value *First = getShadow(&I, 0); Value *Second = getShadow(&I, 1); - // High word of first operand, low word of second - Value *Shadow = - IRB.CreateShuffleVector(First, Second, llvm::makeArrayRef<int>({2, 1})); + // First element of second operand, remaining elements of first operand + SmallVector<int, 16> Mask; + Mask.push_back(Width); + for (unsigned i = 1; i < Width; i++) + Mask.push_back(i); + Value *Shadow = IRB.CreateShuffleVector(First, Second, Mask); setShadow(&I, Shadow); setOriginForNaryOp(I); } - void handleBinarySdIntrinsic(IntrinsicInst &I) { + void handleBinarySdSsIntrinsic(IntrinsicInst &I) { IRBuilder<> IRB(&I); + unsigned Width = + cast<FixedVectorType>(I.getArgOperand(0)->getType())->getNumElements(); Value *First = getShadow(&I, 0); Value *Second = getShadow(&I, 1); Value *OrShadow = IRB.CreateOr(First, Second); - // High word of first operand, low word of both OR'd together - Value *Shadow = IRB.CreateShuffleVector(First, OrShadow, - llvm::makeArrayRef<int>({2, 1})); + // First element of both OR'd together, remaining elements of first operand + SmallVector<int, 16> Mask; + Mask.push_back(Width); + for (unsigned i = 1; i < Width; i++) + Mask.push_back(i); + Value *Shadow = IRB.CreateShuffleVector(First, OrShadow, Mask); setShadow(&I, Shadow); setOriginForNaryOp(I); @@ -3547,11 +3505,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { break; case Intrinsic::x86_sse41_round_sd: - handleUnarySdIntrinsic(I); + case Intrinsic::x86_sse41_round_ss: + handleUnarySdSsIntrinsic(I); break; case Intrinsic::x86_sse2_max_sd: + case Intrinsic::x86_sse_max_ss: case Intrinsic::x86_sse2_min_sd: - handleBinarySdIntrinsic(I); + case Intrinsic::x86_sse_min_ss: + handleBinarySdSsIntrinsic(I); break; case Intrinsic::fshl: @@ -3630,7 +3591,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } void visitCallBase(CallBase &CB) { - assert(!CB.getMetadata("nosanitize")); + assert(!CB.getMetadata(LLVMContext::MD_nosanitize)); if (CB.isInlineAsm()) { // For inline asm (either a call to asm function, or callbr instruction), // do the usual thing: check argument shadow and mark all outputs as @@ -4083,8 +4044,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // Nothing to do here. } - void instrumentAsmArgument(Value *Operand, Instruction &I, IRBuilder<> &IRB, - const DataLayout &DL, bool isOutput) { + void instrumentAsmArgument(Value *Operand, Type *ElemTy, Instruction &I, + IRBuilder<> &IRB, const DataLayout &DL, + bool isOutput) { // For each assembly argument, we check its value for being initialized. // If the argument is a pointer, we assume it points to a single element // of the corresponding type (or to a 8-byte word, if the type is unsized). 
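// Worked instance of the shuffle masks built in handleUnarySdSsIntrinsic /
// handleBinarySdSsIntrinsic above: shufflevector mask indices 0..Width-1
// select lanes of the first operand and Width..2*Width-1 select lanes of the
// second. For the *_sd intrinsics Width == 2, so the mask is {2, 1} -- exactly
// the constant the old code hard-coded. For the newly handled *_ss intrinsics
// Width == 4 and the mask is {4, 1, 2, 3}: only the low lane's shadow is taken
// from the second vector (or from the OR of both shadows in the binary case),
// matching the "operate on the low element, pass the rest through" semantics
// of these intrinsics.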
@@ -4096,10 +4058,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { assert(!isOutput); return; } - Type *ElType = OpType->getPointerElementType(); - if (!ElType->isSized()) + if (!ElemTy->isSized()) return; - int Size = DL.getTypeStoreSize(ElType); + int Size = DL.getTypeStoreSize(ElemTy); Value *Ptr = IRB.CreatePointerCast(Operand, IRB.getInt8PtrTy()); Value *SizeVal = ConstantInt::get(MS.IntptrTy, Size); IRB.CreateCall(MS.MsanInstrumentAsmStoreFn, {Ptr, SizeVal}); @@ -4159,14 +4120,16 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // that we won't overwrite uninit values before checking them. for (int i = OutputArgs; i < NumOperands; i++) { Value *Operand = CB->getOperand(i); - instrumentAsmArgument(Operand, I, IRB, DL, /*isOutput*/ false); + instrumentAsmArgument(Operand, CB->getParamElementType(i), I, IRB, DL, + /*isOutput*/ false); } // Unpoison output arguments. This must happen before the actual InlineAsm // call, so that the shadow for memory published in the asm() statement // remains valid. for (int i = 0; i < OutputArgs; i++) { Value *Operand = CB->getOperand(i); - instrumentAsmArgument(Operand, I, IRB, DL, /*isOutput*/ true); + instrumentAsmArgument(Operand, CB->getParamElementType(i), I, IRB, DL, + /*isOutput*/ true); } setShadow(&I, getCleanShadow(&I)); @@ -4885,8 +4848,8 @@ struct VarArgPowerPC64Helper : public VarArgHelper { assert(A->getType()->isPointerTy()); Type *RealTy = CB.getParamByValType(ArgNo); uint64_t ArgSize = DL.getTypeAllocSize(RealTy); - MaybeAlign ArgAlign = CB.getParamAlign(ArgNo); - if (!ArgAlign || *ArgAlign < Align(8)) + Align ArgAlign = CB.getParamAlign(ArgNo).value_or(Align(8)); + if (ArgAlign < 8) ArgAlign = Align(8); VAArgOffset = alignTo(VAArgOffset, ArgAlign); if (!IsFixed) { @@ -4902,27 +4865,27 @@ struct VarArgPowerPC64Helper : public VarArgHelper { kShadowTLSAlignment, ArgSize); } } - VAArgOffset += alignTo(ArgSize, 8); + VAArgOffset += alignTo(ArgSize, Align(8)); } else { Value *Base; uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); - uint64_t ArgAlign = 8; + Align ArgAlign = Align(8); if (A->getType()->isArrayTy()) { // Arrays are aligned to element size, except for long double // arrays, which are aligned to 8 bytes. Type *ElementTy = A->getType()->getArrayElementType(); if (!ElementTy->isPPC_FP128Ty()) - ArgAlign = DL.getTypeAllocSize(ElementTy); + ArgAlign = Align(DL.getTypeAllocSize(ElementTy)); } else if (A->getType()->isVectorTy()) { // Vectors are naturally aligned. 
- ArgAlign = DL.getTypeAllocSize(A->getType()); + ArgAlign = Align(ArgSize); } if (ArgAlign < 8) - ArgAlign = 8; + ArgAlign = Align(8); VAArgOffset = alignTo(VAArgOffset, ArgAlign); if (DL.isBigEndian()) { - // Adjusting the shadow for argument with size < 8 to match the placement - // of bits in big endian system + // Adjusting the shadow for argument with size < 8 to match the + // placement of bits in big endian system if (ArgSize < 8) VAArgOffset += (8 - ArgSize); } @@ -4933,7 +4896,7 @@ struct VarArgPowerPC64Helper : public VarArgHelper { IRB.CreateAlignedStore(MSV.getShadow(A), Base, kShadowTLSAlignment); } VAArgOffset += ArgSize; - VAArgOffset = alignTo(VAArgOffset, 8); + VAArgOffset = alignTo(VAArgOffset, Align(8)); } if (IsFixed) VAArgBase = VAArgOffset; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index 0902a94452e3..3a29cd70e42e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -52,7 +52,6 @@ #include "ValueProfileCollector.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -68,6 +67,7 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -94,8 +94,6 @@ #include "llvm/IR/ProfileSummary.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/ProfileData/InstrProfReader.h" #include "llvm/Support/BranchProbability.h" @@ -110,6 +108,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/MisExpect.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include <algorithm> #include <cassert> @@ -173,14 +172,14 @@ static cl::opt<bool> DisableValueProfiling("disable-vp", cl::init(false), // Command line option to set the maximum number of VP annotations to write to // the metadata for a single indirect call callsite. static cl::opt<unsigned> MaxNumAnnotations( - "icp-max-annotations", cl::init(3), cl::Hidden, cl::ZeroOrMore, + "icp-max-annotations", cl::init(3), cl::Hidden, cl::desc("Max number of annotations for a single indirect " "call callsite")); // Command line option to set the maximum number of value annotations // to write to the metadata for a single memop intrinsic. 
static cl::opt<unsigned> MaxNumMemOPAnnotations( - "memop-max-annotations", cl::init(4), cl::Hidden, cl::ZeroOrMore, + "memop-max-annotations", cl::init(4), cl::Hidden, cl::desc("Max number of preicise value annotations for a single memop" "intrinsic")); @@ -256,7 +255,7 @@ static cl::opt<bool> PGOInstrumentEntry( cl::desc("Force to instrument function entry basicblock.")); static cl::opt<bool> PGOFunctionEntryCoverage( - "pgo-function-entry-coverage", cl::init(false), cl::Hidden, cl::ZeroOrMore, + "pgo-function-entry-coverage", cl::Hidden, cl::desc( "Use this option to enable function entry coverage instrumentation.")); @@ -431,125 +430,8 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> { unsigned getNumOfSelectInsts() const { return NSIs; } }; - -class PGOInstrumentationGenLegacyPass : public ModulePass { -public: - static char ID; - - PGOInstrumentationGenLegacyPass(bool IsCS = false) - : ModulePass(ID), IsCS(IsCS) { - initializePGOInstrumentationGenLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - StringRef getPassName() const override { return "PGOInstrumentationGenPass"; } - -private: - // Is this is context-sensitive instrumentation. - bool IsCS; - bool runOnModule(Module &M) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<BlockFrequencyInfoWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } -}; - -class PGOInstrumentationUseLegacyPass : public ModulePass { -public: - static char ID; - - // Provide the profile filename as the parameter. - PGOInstrumentationUseLegacyPass(std::string Filename = "", bool IsCS = false) - : ModulePass(ID), ProfileFileName(std::move(Filename)), IsCS(IsCS) { - if (!PGOTestProfileFile.empty()) - ProfileFileName = PGOTestProfileFile; - initializePGOInstrumentationUseLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - StringRef getPassName() const override { return "PGOInstrumentationUsePass"; } - -private: - std::string ProfileFileName; - // Is this is context-sensitive instrumentation use. - bool IsCS; - - bool runOnModule(Module &M) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<ProfileSummaryInfoWrapperPass>(); - AU.addRequired<BlockFrequencyInfoWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } -}; - -class PGOInstrumentationGenCreateVarLegacyPass : public ModulePass { -public: - static char ID; - StringRef getPassName() const override { - return "PGOInstrumentationGenCreateVarPass"; - } - PGOInstrumentationGenCreateVarLegacyPass(std::string CSInstrName = "") - : ModulePass(ID), InstrProfileOutput(CSInstrName) { - initializePGOInstrumentationGenCreateVarLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - -private: - bool runOnModule(Module &M) override { - createProfileFileNameVar(M, InstrProfileOutput); - // The variable in a comdat may be discarded by LTO. Ensure the - // declaration will be retained. 
- appendToCompilerUsed(M, createIRLevelProfileFlagVar(M, /*IsCS=*/true)); - return false; - } - std::string InstrProfileOutput; -}; - } // end anonymous namespace -char PGOInstrumentationGenLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(PGOInstrumentationGenLegacyPass, "pgo-instr-gen", - "PGO instrumentation.", false, false) -INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(PGOInstrumentationGenLegacyPass, "pgo-instr-gen", - "PGO instrumentation.", false, false) - -ModulePass *llvm::createPGOInstrumentationGenLegacyPass(bool IsCS) { - return new PGOInstrumentationGenLegacyPass(IsCS); -} - -char PGOInstrumentationUseLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(PGOInstrumentationUseLegacyPass, "pgo-instr-use", - "Read PGO instrumentation profile.", false, false) -INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) -INITIALIZE_PASS_END(PGOInstrumentationUseLegacyPass, "pgo-instr-use", - "Read PGO instrumentation profile.", false, false) - -ModulePass *llvm::createPGOInstrumentationUseLegacyPass(StringRef Filename, - bool IsCS) { - return new PGOInstrumentationUseLegacyPass(Filename.str(), IsCS); -} - -char PGOInstrumentationGenCreateVarLegacyPass::ID = 0; - -INITIALIZE_PASS(PGOInstrumentationGenCreateVarLegacyPass, - "pgo-instr-gen-create-var", - "Create PGO instrumentation version variable for CSPGO.", false, - false) - -ModulePass * -llvm::createPGOInstrumentationGenCreateVarLegacyPass(StringRef CSInstrName) { - return new PGOInstrumentationGenCreateVarLegacyPass(std::string(CSInstrName)); -} - namespace { /// An MST based instrumentation for PGO @@ -940,7 +822,7 @@ static void instrumentOneFunc( bool IsCS) { // Split indirectbr critical edges here before computing the MST rather than // later in getInstrBB() to avoid invalidating it. - SplitIndirectBrCriticalEdges(F, BPI, BFI); + SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI, BFI); FuncPGOInstrumentation<PGOEdge, BBInfo> FuncInfo( F, TLI, ComdatMembers, true, BPI, BFI, IsCS, PGOInstrumentEntry); @@ -1457,6 +1339,7 @@ void PGOUseFunc::populateCounters() { } LLVM_DEBUG(dbgs() << "Populate counts in " << NumPasses << " passes.\n"); + (void) NumPasses; #ifndef NDEBUG // Assert every BB has a valid counter. 
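// (The "(void)NumPasses;" added above only silences the set-but-unused
// warning in release builds, where the LLVM_DEBUG statement that reads
// NumPasses compiles away under NDEBUG.)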
for (auto &BB : F) { @@ -1697,22 +1580,6 @@ PGOInstrumentationGenCreateVar::run(Module &M, ModuleAnalysisManager &AM) { return PreservedAnalyses::all(); } -bool PGOInstrumentationGenLegacyPass::runOnModule(Module &M) { - if (skipModule(M)) - return false; - - auto LookupTLI = [this](Function &F) -> TargetLibraryInfo & { - return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - }; - auto LookupBPI = [this](Function &F) { - return &this->getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI(); - }; - auto LookupBFI = [this](Function &F) { - return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); - }; - return InstrumentAllFunctions(M, LookupTLI, LookupBPI, LookupBFI, IsCS); -} - PreservedAnalyses PGOInstrumentationGen::run(Module &M, ModuleAnalysisManager &AM) { auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); @@ -1740,7 +1607,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI, BlockFrequencyInfo NBFI(F, NBPI, LI); #ifndef NDEBUG auto BFIEntryCount = F.getEntryCount(); - assert(BFIEntryCount.hasValue() && (BFIEntryCount->getCount() > 0) && + assert(BFIEntryCount && (BFIEntryCount->getCount() > 0) && "Invalid BFI Entrycount"); #endif auto SumCount = APFloat::getZero(APFloat::IEEEdouble()); @@ -1752,7 +1619,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI, continue; auto BFICount = NBFI.getBlockProfileCount(&BBI); CountValue = Func.getBBInfo(&BBI).CountValue; - BFICountValue = BFICount.getValue(); + BFICountValue = *BFICount; SumCount.add(APFloat(CountValue * 1.0), APFloat::rmNearestTiesToEven); SumBFICount.add(APFloat(BFICountValue * 1.0), APFloat::rmNearestTiesToEven); } @@ -1805,7 +1672,7 @@ static void verifyFuncBFI(PGOUseFunc &Func, LoopInfo &LI, NonZeroBBNum++; auto BFICount = NBFI.getBlockProfileCount(&BBI); if (BFICount) - BFICountValue = BFICount.getValue(); + BFICountValue = *BFICount; if (HotBBOnly) { bool rawIsHot = CountValue >= HotCountThreshold; @@ -1929,7 +1796,7 @@ static bool annotateAllFunctions( auto *BFI = LookupBFI(F); // Split indirectbr critical edges here before computing the MST rather than // later in getInstrBB() to avoid invalidating it. 
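// (In the call below, SplitIndirectBrCriticalEdges simply gains an explicit
// IgnoreBlocksWithoutPHI argument; passing false presumably preserves the
// previous behaviour of splitting these critical edges regardless of whether
// the target block begins with PHI nodes, while letting other callers opt
// into a narrower PHI-only mode.)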
- SplitIndirectBrCriticalEdges(F, BPI, BFI); + SplitIndirectBrCriticalEdges(F, /*IgnoreBlocksWithoutPHI=*/false, BPI, BFI); PGOUseFunc Func(F, &M, TLI, ComdatMembers, BPI, BFI, PSI, IsCS, InstrumentFuncEntry); // When AllMinusOnes is true, it means the profile for the function @@ -2073,25 +1940,6 @@ PreservedAnalyses PGOInstrumentationUse::run(Module &M, return PreservedAnalyses::none(); } -bool PGOInstrumentationUseLegacyPass::runOnModule(Module &M) { - if (skipModule(M)) - return false; - - auto LookupTLI = [this](Function &F) -> TargetLibraryInfo & { - return this->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - }; - auto LookupBPI = [this](Function &F) { - return &this->getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI(); - }; - auto LookupBFI = [this](Function &F) { - return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); - }; - - auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); - return annotateAllFunctions(M, ProfileFileName, "", LookupTLI, LookupBPI, - LookupBFI, PSI, IsCS); -} - static std::string getSimpleNodeName(const BasicBlock *Node) { if (!Node->getName().empty()) return std::string(Node->getName()); @@ -2117,6 +1965,8 @@ void llvm::setProfMetadata(Module *M, Instruction *TI, dbgs() << W << " "; } dbgs() << "\n";); + misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false); + TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); if (EmitBranchProbability) { std::string BrCondStr = getBranchCondString(TI); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp index d4b78f2c14b0..b11f16894669 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp @@ -20,7 +20,6 @@ #include "llvm/ADT/Twine.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/DomTreeUpdater.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/BasicBlock.h" @@ -29,15 +28,11 @@ #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" -#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/PassRegistry.h" #include "llvm/ProfileData/InstrProf.h" #define INSTR_PROF_VALUE_PROF_MEMOP_API #include "llvm/ProfileData/InstrProfData.inc" @@ -46,8 +41,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" -#include "llvm/Support/WithColor.h" -#include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include <cassert> @@ -63,8 +56,7 @@ STATISTIC(NumOfPGOMemOPAnnotate, "Number of memop intrinsics annotated."); // The minimum call count to optimize memory intrinsic calls. 
static cl::opt<unsigned> - MemOPCountThreshold("pgo-memop-count-threshold", cl::Hidden, cl::ZeroOrMore, - cl::init(1000), + MemOPCountThreshold("pgo-memop-count-threshold", cl::Hidden, cl::init(1000), cl::desc("The minimum count to optimize memory " "intrinsic calls")); @@ -76,14 +68,13 @@ static cl::opt<bool> DisableMemOPOPT("disable-memop-opt", cl::init(false), // The percent threshold to optimize memory intrinsic calls. static cl::opt<unsigned> MemOPPercentThreshold("pgo-memop-percent-threshold", cl::init(40), - cl::Hidden, cl::ZeroOrMore, + cl::Hidden, cl::desc("The percentage threshold for the " "memory intrinsic calls optimization")); // Maximum number of versions for optimizing memory intrinsic call. static cl::opt<unsigned> MemOPMaxVersion("pgo-memop-max-version", cl::init(3), cl::Hidden, - cl::ZeroOrMore, cl::desc("The max version for the optimized memory " " intrinsic calls")); @@ -103,43 +94,6 @@ static cl::opt<unsigned> cl::desc("Optimize the memop size <= this value")); namespace { -class PGOMemOPSizeOptLegacyPass : public FunctionPass { -public: - static char ID; - - PGOMemOPSizeOptLegacyPass() : FunctionPass(ID) { - initializePGOMemOPSizeOptLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - StringRef getPassName() const override { return "PGOMemOPSize"; } - -private: - bool runOnFunction(Function &F) override; - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<BlockFrequencyInfoWrapperPass>(); - AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } -}; -} // end anonymous namespace - -char PGOMemOPSizeOptLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(PGOMemOPSizeOptLegacyPass, "pgo-memop-opt", - "Optimize memory intrinsic using its size value profile", - false, false) -INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(PGOMemOPSizeOptLegacyPass, "pgo-memop-opt", - "Optimize memory intrinsic using its size value profile", - false, false) - -FunctionPass *llvm::createPGOMemOPSizeOptLegacyPass() { - return new PGOMemOPSizeOptLegacyPass(); -} - -namespace { static const char *getMIName(const MemIntrinsic *MI) { switch (MI->getIntrinsicID()) { @@ -517,20 +471,6 @@ static bool PGOMemOPSizeOptImpl(Function &F, BlockFrequencyInfo &BFI, return MemOPSizeOpt.isChanged(); } -bool PGOMemOPSizeOptLegacyPass::runOnFunction(Function &F) { - BlockFrequencyInfo &BFI = - getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(); - auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - DominatorTree *DT = DTWP ? 
&DTWP->getDomTree() : nullptr; - TargetLibraryInfo &TLI = - getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - return PGOMemOPSizeOptImpl(F, BFI, ORE, DT, TLI); -} - -namespace llvm { -char &PGOMemOPSizeOptID = PGOMemOPSizeOptLegacyPass::ID; - PreservedAnalyses PGOMemOPSizeOpt::run(Function &F, FunctionAnalysisManager &FAM) { auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); @@ -544,4 +484,3 @@ PreservedAnalyses PGOMemOPSizeOpt::run(Function &F, PA.preserve<DominatorTreeAnalysis>(); return PA; } -} // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp index fc5267261851..0e39fe266369 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp @@ -60,15 +60,9 @@ #include "llvm/Transforms/Instrumentation/PoisonChecking.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstVisitor.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index d3b60c7add34..d9d11cc90d3d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -13,30 +13,24 @@ #include "llvm/Transforms/Instrumentation/SanitizerCoverage.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Triple.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/PostDominators.h" -#include "llvm/IR/CFG.h" #include "llvm/IR/Constant.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InlineAsm.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/IR/MDBuilder.h" -#include "llvm/IR/Mangler.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" #include "llvm/Support/SpecialCaseList.h" #include "llvm/Support/VirtualFileSystem.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -247,8 +241,7 @@ private: Type *Ty); void SetNoSanitizeMetadata(Instruction *I) { - I->setMetadata(I->getModule()->getMDKindID("nosanitize"), - MDNode::get(*C, None)); + I->setMetadata(LLVMContext::MD_nosanitize, MDNode::get(*C, None)); } std::string getSectionName(const std::string &Section) const; @@ -694,7 +687,7 @@ void ModuleSanitizerCoverage::instrumentFunction( for (auto &Inst : BB) { if (Options.IndirectCalls) { CallBase *CB = dyn_cast<CallBase>(&Inst); - if (CB && !CB->getCalledFunction()) + if (CB && CB->isIndirectCall()) IndirCalls.push_back(&Inst); } if (Options.TraceCmp) { @@ -996,15 +989,11 @@ void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function 
&F, BasicBlock &BB, // if we aren't splitting the block, it's nice for allocas to be before // calls. IP = PrepareToSplitEntryBlock(BB, IP); - } else { - EntryLoc = IP->getDebugLoc(); - if (!EntryLoc) - if (auto *SP = F.getSubprogram()) - EntryLoc = DILocation::get(SP->getContext(), 0, 0, SP); } - IRBuilder<> IRB(&*IP); - IRB.SetCurrentDebugLocation(EntryLoc); + InstrumentationIRBuilder IRB(&*IP); + if (EntryLoc) + IRB.SetCurrentDebugLocation(EntryLoc); if (Options.TracePC) { IRB.CreateCall(SanCovTracePC) ->setCannotMerge(); // gets the PC using GET_CALLER_PC. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index 180012198c42..c33b1b3b1a5c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -38,7 +38,6 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" -#include "llvm/InitializePasses.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -174,19 +173,6 @@ private: FunctionCallee MemmoveFn, MemcpyFn, MemsetFn; }; -struct ThreadSanitizerLegacyPass : FunctionPass { - ThreadSanitizerLegacyPass() : FunctionPass(ID) { - initializeThreadSanitizerLegacyPassPass(*PassRegistry::getPassRegistry()); - } - StringRef getPassName() const override; - void getAnalysisUsage(AnalysisUsage &AU) const override; - bool runOnFunction(Function &F) override; - bool doInitialization(Module &M) override; - static char ID; // Pass identification, replacement for typeid. -private: - Optional<ThreadSanitizer> TSan; -}; - void insertModuleCtor(Module &M) { getOrCreateSanitizerCtorAndInitFunctions( M, kTsanModuleCtorName, kTsanInitName, /*InitArgTypes=*/{}, @@ -195,7 +181,6 @@ void insertModuleCtor(Module &M) { // time. 
Hook them into the global ctors list in that case: [&](Function *Ctor, FunctionCallee) { appendToGlobalCtors(M, Ctor, 0); }); } - } // namespace PreservedAnalyses ThreadSanitizerPass::run(Function &F, @@ -211,38 +196,6 @@ PreservedAnalyses ModuleThreadSanitizerPass::run(Module &M, insertModuleCtor(M); return PreservedAnalyses::none(); } - -char ThreadSanitizerLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(ThreadSanitizerLegacyPass, "tsan", - "ThreadSanitizer: detects data races.", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(ThreadSanitizerLegacyPass, "tsan", - "ThreadSanitizer: detects data races.", false, false) - -StringRef ThreadSanitizerLegacyPass::getPassName() const { - return "ThreadSanitizerLegacyPass"; -} - -void ThreadSanitizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addRequired<TargetLibraryInfoWrapperPass>(); -} - -bool ThreadSanitizerLegacyPass::doInitialization(Module &M) { - insertModuleCtor(M); - TSan.emplace(); - return true; -} - -bool ThreadSanitizerLegacyPass::runOnFunction(Function &F) { - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - TSan->sanitizeFunction(F, TLI); - return true; -} - -FunctionPass *llvm::createThreadSanitizerLegacyPassPass() { - return new ThreadSanitizerLegacyPass(); -} - void ThreadSanitizer::initialize(Module &M) { const DataLayout &DL = M.getDataLayout(); IntptrTy = DL.getIntPtrType(M.getContext()); @@ -527,26 +480,22 @@ void ThreadSanitizer::chooseInstructionsToInstrument( Local.clear(); } -static bool isAtomic(Instruction *I) { +static bool isTsanAtomic(const Instruction *I) { // TODO: Ask TTI whether synchronization scope is between threads. - if (LoadInst *LI = dyn_cast<LoadInst>(I)) - return LI->isAtomic() && LI->getSyncScopeID() != SyncScope::SingleThread; - if (StoreInst *SI = dyn_cast<StoreInst>(I)) - return SI->isAtomic() && SI->getSyncScopeID() != SyncScope::SingleThread; - if (isa<AtomicRMWInst>(I)) - return true; - if (isa<AtomicCmpXchgInst>(I)) - return true; - if (isa<FenceInst>(I)) - return true; - return false; + auto SSID = getAtomicSyncScopeID(I); + if (!SSID) + return false; + if (isa<LoadInst>(I) || isa<StoreInst>(I)) + return SSID.getValue() != SyncScope::SingleThread; + return true; } void ThreadSanitizer::InsertRuntimeIgnores(Function &F) { - IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI()); + InstrumentationIRBuilder IRB(F.getEntryBlock().getFirstNonPHI()); IRB.CreateCall(TsanIgnoreBegin); EscapeEnumerator EE(F, "tsan_ignore_cleanup", ClHandleCxxExceptions); while (IRBuilder<> *AtExit = EE.Next()) { + InstrumentationIRBuilder::ensureDebugInfo(*AtExit, F); AtExit->CreateCall(TsanIgnoreEnd); } } @@ -581,7 +530,7 @@ bool ThreadSanitizer::sanitizeFunction(Function &F, // Traverse all instructions, collect loads/stores/returns, check for calls. for (auto &BB : F) { for (auto &Inst : BB) { - if (isAtomic(&Inst)) + if (isTsanAtomic(&Inst)) AtomicAccesses.push_back(&Inst); else if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst)) LocalLoadsAndStores.push_back(&Inst); @@ -629,7 +578,7 @@ bool ThreadSanitizer::sanitizeFunction(Function &F, // Instrument function entry/exit points if there were instrumented accesses. 
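// (The InstrumentationIRBuilder / ensureDebugInfo changes in this block and
// in the cleanup loop below presumably exist so that the inserted runtime
// calls carry a debug location whenever the surrounding function has debug
// info, rather than being emitted with no location at all.)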
if ((Res || HasCalls) && ClInstrumentFuncEntryExit) { - IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI()); + InstrumentationIRBuilder IRB(F.getEntryBlock().getFirstNonPHI()); Value *ReturnAddress = IRB.CreateCall( Intrinsic::getDeclaration(F.getParent(), Intrinsic::returnaddress), IRB.getInt32(0)); @@ -637,6 +586,7 @@ bool ThreadSanitizer::sanitizeFunction(Function &F, EscapeEnumerator EE(F, "tsan_cleanup", ClHandleCxxExceptions); while (IRBuilder<> *AtExit = EE.Next()) { + InstrumentationIRBuilder::ensureDebugInfo(*AtExit, F); AtExit->CreateCall(TsanFuncExit, {}); } Res = true; @@ -646,7 +596,7 @@ bool ThreadSanitizer::sanitizeFunction(Function &F, bool ThreadSanitizer::instrumentLoadOrStore(const InstructionInfo &II, const DataLayout &DL) { - IRBuilder<> IRB(II.Inst); + InstrumentationIRBuilder IRB(II.Inst); const bool IsWrite = isa<StoreInst>(*II.Inst); Value *Addr = IsWrite ? cast<StoreInst>(II.Inst)->getPointerOperand() : cast<LoadInst>(II.Inst)->getPointerOperand(); @@ -686,8 +636,8 @@ bool ThreadSanitizer::instrumentLoadOrStore(const InstructionInfo &II, return true; } - const unsigned Alignment = IsWrite ? cast<StoreInst>(II.Inst)->getAlignment() - : cast<LoadInst>(II.Inst)->getAlignment(); + const Align Alignment = IsWrite ? cast<StoreInst>(II.Inst)->getAlign() + : cast<LoadInst>(II.Inst)->getAlign(); const bool IsCompoundRW = ClCompoundReadBeforeWrite && (II.Flags & InstructionInfo::kCompoundRW); const bool IsVolatile = ClDistinguishVolatile && @@ -697,7 +647,7 @@ bool ThreadSanitizer::instrumentLoadOrStore(const InstructionInfo &II, const uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy); FunctionCallee OnAccessFunc = nullptr; - if (Alignment == 0 || Alignment >= 8 || (Alignment % (TypeSize / 8)) == 0) { + if (Alignment >= Align(8) || (Alignment.value() % (TypeSize / 8)) == 0) { if (IsCompoundRW) OnAccessFunc = TsanCompoundRW[Idx]; else if (IsVolatile) @@ -775,7 +725,7 @@ bool ThreadSanitizer::instrumentMemIntrinsic(Instruction *I) { // http://www.hpl.hp.com/personal/Hans_Boehm/c++mm/ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { - IRBuilder<> IRB(I); + InstrumentationIRBuilder IRB(I); if (LoadInst *LI = dyn_cast<LoadInst>(I)) { Value *Addr = LI->getPointerOperand(); Type *OrigTy = LI->getType(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfileCollector.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfileCollector.cpp index fb6216bb2177..32633bbc941b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfileCollector.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfileCollector.cpp @@ -10,12 +10,9 @@ // //===----------------------------------------------------------------------===// +#include "ValueProfileCollector.h" #include "ValueProfilePlugins.inc" -#include "llvm/IR/Function.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/InitializePasses.h" -#include <cassert> +#include "llvm/ProfileData/InstrProf.h" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfileCollector.h b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfileCollector.h index 584a60ab451e..10e5e4d128b1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfileCollector.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfileCollector.h @@ -16,7 +16,6 @@ #ifndef LLVM_ANALYSIS_PROFILE_GEN_ANALYSIS_H #define 
LLVM_ANALYSIS_PROFILE_GEN_ANALYSIS_H -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/ProfileData/InstrProf.h" #include <memory> #include <vector> @@ -25,6 +24,7 @@ namespace llvm { class Function; class Instruction; +class TargetLibraryInfo; class Value; /// Utility analysis that determines what values are worth profiling. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfilePlugins.inc b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfilePlugins.inc index 6a2c473a596a..3a129de1acd0 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfilePlugins.inc +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfilePlugins.inc @@ -15,6 +15,7 @@ #include "ValueProfileCollector.h" #include "llvm/Analysis/IndirectCallVisitor.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/InstVisitor.h" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp index 126845bb3308..70f150c9461a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp @@ -16,7 +16,6 @@ #include "llvm-c/Initialization.h" #include "llvm/Analysis/ObjCARCUtil.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InlineAsm.h" #include "llvm/IR/Instructions.h" #include "llvm/InitializePasses.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.h b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.h index 62f88a8cc02b..2bc0c8f87d77 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.h +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.h @@ -22,7 +22,6 @@ #ifndef LLVM_LIB_TRANSFORMS_OBJCARC_OBJCARC_H #define LLVM_LIB_TRANSFORMS_OBJCARC_OBJCARC_H -#include "ARCRuntimeEntryPoints.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/ObjCARCAnalysisUtils.h" #include "llvm/Analysis/ObjCARCUtil.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp index 210ec60f2f87..03e5fb18d5ac 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCAPElim.cpp @@ -23,11 +23,14 @@ /// //===----------------------------------------------------------------------===// -#include "ObjCARC.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/ObjCARCAnalysisUtils.h" +#include "llvm/Analysis/ObjCARCInstKind.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/PassManager.h" #include "llvm/InitializePasses.h" +#include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/ObjCARC.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp index 2985ae004d3c..f64c26ef2bed 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp @@ -102,11 +102,8 @@ public: }; class ObjCARCContractLegacyPass : public FunctionPass { - ObjCARCContract OCARCC; - public: void getAnalysisUsage(AnalysisUsage &AU) const override; - bool doInitialization(Module &M) override; bool runOnFunction(Function &F) 
override; static char ID; @@ -737,11 +734,9 @@ Pass *llvm::createObjCARCContractPass() { return new ObjCARCContractLegacyPass(); } -bool ObjCARCContractLegacyPass::doInitialization(Module &M) { - return OCARCC.init(M); -} - bool ObjCARCContractLegacyPass::runOnFunction(Function &F) { + ObjCARCContract OCARCC; + OCARCC.init(*F.getParent()); auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); return OCARCC.run(F, AA, DT); diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCExpand.cpp b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCExpand.cpp index 6b074ac5adab..efcdc51ef5e3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCExpand.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCExpand.cpp @@ -22,7 +22,7 @@ /// //===----------------------------------------------------------------------===// -#include "ObjCARC.h" +#include "llvm/Analysis/ObjCARCAnalysisUtils.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ADCE.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ADCE.cpp index 1cda206a7e14..cdf9de8d78d5 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ADCE.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ADCE.cpp @@ -35,7 +35,6 @@ #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" -#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp index e4ec5f266eb8..9571e99dfb19 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -15,8 +15,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm/IR/Instructions.h" -#include "llvm/InitializePasses.h" #include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" @@ -26,12 +24,11 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/Constant.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/IR/Module.h" +#include "llvm/InitializePasses.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp index a5e65ffc45fe..155f47b49357 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp @@ -16,11 +16,8 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" -#include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" #include 
"llvm/Transforms/Utils/MemoryOpRemark.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp index 95de59fa8262..cc12033fb677 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp @@ -57,6 +57,7 @@ #include "llvm/Transforms/Scalar/CallSiteSplitting.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/IntrinsicInst.h" @@ -65,7 +66,6 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" @@ -123,8 +123,8 @@ static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallBase &CB) { return false; } -typedef std::pair<ICmpInst *, unsigned> ConditionTy; -typedef SmallVector<ConditionTy, 2> ConditionsTy; +using ConditionTy = std::pair<ICmpInst *, unsigned>; +using ConditionsTy = SmallVector<ConditionTy, 2>; /// If From has a conditional jump to To, add the condition to Conditions, /// if it is relevant to any argument at CB. @@ -301,10 +301,9 @@ static void copyMustTailReturn(BasicBlock *SplitBB, Instruction *CI, /// Note that in case any arguments at the call-site are constrained by its /// predecessors, new call-sites with more constrained arguments will be /// created in createCallSitesOnPredicatedArgument(). -static void splitCallSite( - CallBase &CB, - const SmallVectorImpl<std::pair<BasicBlock *, ConditionsTy>> &Preds, - DomTreeUpdater &DTU) { +static void splitCallSite(CallBase &CB, + ArrayRef<std::pair<BasicBlock *, ConditionsTy>> Preds, + DomTreeUpdater &DTU) { BasicBlock *TailBB = CB.getParent(); bool IsMustTailCall = CB.isMustTailCall(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp index 25e8c3ef3b48..8a1761505d59 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -52,6 +52,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp index 13963657d183..6dfa2440023f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp @@ -19,15 +19,16 @@ #include "llvm/Analysis/ConstraintSystem.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DebugCounter.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Scalar.h" #include <string> @@ -42,48 +43,129 @@ 
DEBUG_COUNTER(EliminatedCounter, "conds-eliminated", "Controls which conditions are eliminated"); static int64_t MaxConstraintValue = std::numeric_limits<int64_t>::max(); +static int64_t MinSignedConstraintValue = std::numeric_limits<int64_t>::min(); namespace { -struct ConstraintTy { - SmallVector<int64_t, 8> Coefficients; - ConstraintTy(SmallVector<int64_t, 8> Coefficients) - : Coefficients(Coefficients) {} +class ConstraintInfo; - unsigned size() const { return Coefficients.size(); } +struct StackEntry { + unsigned NumIn; + unsigned NumOut; + bool IsNot; + bool IsSigned = false; + /// Variables that can be removed from the system once the stack entry gets + /// removed. + SmallVector<Value *, 2> ValuesToRelease; + + StackEntry(unsigned NumIn, unsigned NumOut, bool IsNot, bool IsSigned, + SmallVector<Value *, 2> ValuesToRelease) + : NumIn(NumIn), NumOut(NumOut), IsNot(IsNot), IsSigned(IsSigned), + ValuesToRelease(ValuesToRelease) {} }; -/// Struct to manage a list of constraints. -struct ConstraintListTy { - SmallVector<ConstraintTy, 4> Constraints; +/// Struct to express a pre-condition of the form %Op0 Pred %Op1. +struct PreconditionTy { + CmpInst::Predicate Pred; + Value *Op0; + Value *Op1; - ConstraintListTy() {} + PreconditionTy(CmpInst::Predicate Pred, Value *Op0, Value *Op1) + : Pred(Pred), Op0(Op0), Op1(Op1) {} +}; - ConstraintListTy(const SmallVector<ConstraintTy, 4> &Constraints) - : Constraints(Constraints) {} +struct ConstraintTy { + SmallVector<int64_t, 8> Coefficients; + SmallVector<PreconditionTy, 2> Preconditions; - void mergeIn(const ConstraintListTy &Other) { - append_range(Constraints, Other.Constraints); - } + bool IsSigned = false; + bool IsEq = false; + + ConstraintTy() = default; - unsigned size() const { return Constraints.size(); } + ConstraintTy(SmallVector<int64_t, 8> Coefficients, bool IsSigned) + : Coefficients(Coefficients), IsSigned(IsSigned) {} + + unsigned size() const { return Coefficients.size(); } - unsigned empty() const { return Constraints.empty(); } + unsigned empty() const { return Coefficients.empty(); } /// Returns true if any constraint has a non-zero coefficient for any of the /// newly added indices. Zero coefficients for new indices are removed. If it /// returns true, no new variable need to be added to the system. bool needsNewIndices(const DenseMap<Value *, unsigned> &NewIndices) { - assert(size() == 1); for (unsigned I = 0; I < NewIndices.size(); ++I) { - int64_t Last = get(0).Coefficients.pop_back_val(); + int64_t Last = Coefficients.pop_back_val(); if (Last != 0) return true; } return false; } - ConstraintTy &get(unsigned I) { return Constraints[I]; } + /// Returns true if all preconditions for this list of constraints are + /// satisfied given \p CS and the corresponding \p Value2Index mapping. + bool isValid(const ConstraintInfo &Info) const; +}; + +/// Wrapper encapsulating separate constraint systems and corresponding value +/// mappings for both unsigned and signed information. Facts are added to and +/// conditions are checked against the corresponding system depending on the +/// signed-ness of their predicates. While the information is kept separate +/// based on signed-ness, certain conditions can be transferred between the two +/// systems. +class ConstraintInfo { + DenseMap<Value *, unsigned> UnsignedValue2Index; + DenseMap<Value *, unsigned> SignedValue2Index; + + ConstraintSystem UnsignedCS; + ConstraintSystem SignedCS; + +public: + DenseMap<Value *, unsigned> &getValue2Index(bool Signed) { + return Signed ? 
SignedValue2Index : UnsignedValue2Index; + } + const DenseMap<Value *, unsigned> &getValue2Index(bool Signed) const { + return Signed ? SignedValue2Index : UnsignedValue2Index; + } + + ConstraintSystem &getCS(bool Signed) { + return Signed ? SignedCS : UnsignedCS; + } + const ConstraintSystem &getCS(bool Signed) const { + return Signed ? SignedCS : UnsignedCS; + } + + void popLastConstraint(bool Signed) { getCS(Signed).popLastConstraint(); } + void popLastNVariables(bool Signed, unsigned N) { + getCS(Signed).popLastNVariables(N); + } + + bool doesHold(CmpInst::Predicate Pred, Value *A, Value *B) const; + + void addFact(CmpInst::Predicate Pred, Value *A, Value *B, bool IsNegated, + unsigned NumIn, unsigned NumOut, + SmallVectorImpl<StackEntry> &DFSInStack); + + /// Turn a comparison of the form \p Op0 \p Pred \p Op1 into a vector of + /// constraints, using indices from the corresponding constraint system. + /// Additional indices for newly discovered values are added to \p NewIndices. + ConstraintTy getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, + DenseMap<Value *, unsigned> &NewIndices) const; + + /// Turn a condition \p CmpI into a vector of constraints, using indices from + /// the corresponding constraint system. Additional indices for newly + /// discovered values are added to \p NewIndices. + ConstraintTy getConstraint(CmpInst *Cmp, + DenseMap<Value *, unsigned> &NewIndices) const { + return getConstraint(Cmp->getPredicate(), Cmp->getOperand(0), + Cmp->getOperand(1), NewIndices); + } + + /// Try to add information from \p A \p Pred \p B to the unsigned/signed + /// system if \p Pred is signed/unsigned. + void transferToOtherSystem(CmpInst::Predicate Pred, Value *A, Value *B, + bool IsNegated, unsigned NumIn, unsigned NumOut, + SmallVectorImpl<StackEntry> &DFSInStack); }; } // namespace @@ -92,11 +174,28 @@ struct ConstraintListTy { // sum of the pairs equals \p V. The first pair is the constant-factor and X // must be nullptr. If the expression cannot be decomposed, returns an empty // vector. -static SmallVector<std::pair<int64_t, Value *>, 4> decompose(Value *V) { +static SmallVector<std::pair<int64_t, Value *>, 4> +decompose(Value *V, SmallVector<PreconditionTy, 4> &Preconditions, + bool IsSigned) { + + auto CanUseSExt = [](ConstantInt *CI) { + const APInt &Val = CI->getValue(); + return Val.sgt(MinSignedConstraintValue) && Val.slt(MaxConstraintValue); + }; + // Decompose \p V used with a signed predicate. + if (IsSigned) { + if (auto *CI = dyn_cast<ConstantInt>(V)) { + if (CanUseSExt(CI)) + return {{CI->getSExtValue(), nullptr}}; + } + + return {{0, nullptr}, {1, V}}; + } + if (auto *CI = dyn_cast<ConstantInt>(V)) { - if (CI->isNegative() || CI->uge(MaxConstraintValue)) + if (CI->uge(MaxConstraintValue)) return {}; - return {{CI->getSExtValue(), nullptr}}; + return {{CI->getZExtValue(), nullptr}}; } auto *GEP = dyn_cast<GetElementPtrInst>(V); if (GEP && GEP->getNumOperands() == 2 && GEP->isInBounds()) { @@ -106,11 +205,13 @@ static SmallVector<std::pair<int64_t, Value *>, 4> decompose(Value *V) { // If the index is zero-extended, it is guaranteed to be positive. 
if (match(GEP->getOperand(GEP->getNumOperands() - 1), m_ZExt(m_Value(Op0)))) { - if (match(Op0, m_NUWShl(m_Value(Op1), m_ConstantInt(CI)))) + if (match(Op0, m_NUWShl(m_Value(Op1), m_ConstantInt(CI))) && + CanUseSExt(CI)) return {{0, nullptr}, {1, GEP->getPointerOperand()}, {std::pow(int64_t(2), CI->getSExtValue()), Op1}}; - if (match(Op0, m_NSWAdd(m_Value(Op1), m_ConstantInt(CI)))) + if (match(Op0, m_NSWAdd(m_Value(Op1), m_ConstantInt(CI))) && + CanUseSExt(CI)) return {{CI->getSExtValue(), nullptr}, {1, GEP->getPointerOperand()}, {1, Op1}}; @@ -118,17 +219,19 @@ static SmallVector<std::pair<int64_t, Value *>, 4> decompose(Value *V) { } if (match(GEP->getOperand(GEP->getNumOperands() - 1), m_ConstantInt(CI)) && - !CI->isNegative()) + !CI->isNegative() && CanUseSExt(CI)) return {{CI->getSExtValue(), nullptr}, {1, GEP->getPointerOperand()}}; SmallVector<std::pair<int64_t, Value *>, 4> Result; if (match(GEP->getOperand(GEP->getNumOperands() - 1), - m_NUWShl(m_Value(Op0), m_ConstantInt(CI)))) + m_NUWShl(m_Value(Op0), m_ConstantInt(CI))) && + CanUseSExt(CI)) Result = {{0, nullptr}, {1, GEP->getPointerOperand()}, {std::pow(int64_t(2), CI->getSExtValue()), Op0}}; else if (match(GEP->getOperand(GEP->getNumOperands() - 1), - m_NSWAdd(m_Value(Op0), m_ConstantInt(CI)))) + m_NSWAdd(m_Value(Op0), m_ConstantInt(CI))) && + CanUseSExt(CI)) Result = {{CI->getSExtValue(), nullptr}, {1, GEP->getPointerOperand()}, {1, Op0}}; @@ -136,6 +239,10 @@ static SmallVector<std::pair<int64_t, Value *>, 4> decompose(Value *V) { Op0 = GEP->getOperand(GEP->getNumOperands() - 1); Result = {{0, nullptr}, {1, GEP->getPointerOperand()}, {1, Op0}}; } + // If Op0 is signed non-negative, the GEP is increasing monotonically and + // can be de-composed. + Preconditions.emplace_back(CmpInst::ICMP_SGE, Op0, + ConstantInt::get(Op0->getType(), 0)); return Result; } @@ -145,12 +252,20 @@ static SmallVector<std::pair<int64_t, Value *>, 4> decompose(Value *V) { Value *Op1; ConstantInt *CI; - if (match(V, m_NUWAdd(m_Value(Op0), m_ConstantInt(CI)))) + if (match(V, m_NUWAdd(m_Value(Op0), m_ConstantInt(CI))) && + !CI->uge(MaxConstraintValue)) + return {{CI->getZExtValue(), nullptr}, {1, Op0}}; + if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() && + CanUseSExt(CI)) { + Preconditions.emplace_back( + CmpInst::ICMP_UGE, Op0, + ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1)); return {{CI->getSExtValue(), nullptr}, {1, Op0}}; + } if (match(V, m_NUWAdd(m_Value(Op0), m_Value(Op1)))) return {{0, nullptr}, {1, Op0}, {1, Op1}}; - if (match(V, m_NUWSub(m_Value(Op0), m_ConstantInt(CI)))) + if (match(V, m_NUWSub(m_Value(Op0), m_ConstantInt(CI))) && CanUseSExt(CI)) return {{-1 * CI->getSExtValue(), nullptr}, {1, Op0}}; if (match(V, m_NUWSub(m_Value(Op0), m_Value(Op1)))) return {{0, nullptr}, {1, Op0}, {-1, Op1}}; @@ -158,73 +273,73 @@ static SmallVector<std::pair<int64_t, Value *>, 4> decompose(Value *V) { return {{0, nullptr}, {1, V}}; } -/// Turn a condition \p CmpI into a vector of constraints, using indices from \p -/// Value2Index. Additional indices for newly discovered values are added to \p -/// NewIndices. -static ConstraintListTy -getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, - const DenseMap<Value *, unsigned> &Value2Index, - DenseMap<Value *, unsigned> &NewIndices) { - int64_t Offset1 = 0; - int64_t Offset2 = 0; - - // First try to look up \p V in Value2Index and NewIndices. Otherwise add a - // new entry to NewIndices. 
- auto GetOrAddIndex = [&Value2Index, &NewIndices](Value *V) -> unsigned { - auto V2I = Value2Index.find(V); - if (V2I != Value2Index.end()) - return V2I->second; - auto NewI = NewIndices.find(V); - if (NewI != NewIndices.end()) - return NewI->second; - auto Insert = - NewIndices.insert({V, Value2Index.size() + NewIndices.size() + 1}); - return Insert.first->second; - }; - - if (Pred == CmpInst::ICMP_UGT || Pred == CmpInst::ICMP_UGE) - return getConstraint(CmpInst::getSwappedPredicate(Pred), Op1, Op0, - Value2Index, NewIndices); - - if (Pred == CmpInst::ICMP_EQ) { - if (match(Op1, m_Zero())) - return getConstraint(CmpInst::ICMP_ULE, Op0, Op1, Value2Index, - NewIndices); - - auto A = - getConstraint(CmpInst::ICMP_UGE, Op0, Op1, Value2Index, NewIndices); - auto B = - getConstraint(CmpInst::ICMP_ULE, Op0, Op1, Value2Index, NewIndices); - A.mergeIn(B); - return A; +ConstraintTy +ConstraintInfo::getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, + DenseMap<Value *, unsigned> &NewIndices) const { + bool IsEq = false; + // Try to convert Pred to one of ULE/SLT/SLE/SLT. + switch (Pred) { + case CmpInst::ICMP_UGT: + case CmpInst::ICMP_UGE: + case CmpInst::ICMP_SGT: + case CmpInst::ICMP_SGE: { + Pred = CmpInst::getSwappedPredicate(Pred); + std::swap(Op0, Op1); + break; } - - if (Pred == CmpInst::ICMP_NE && match(Op1, m_Zero())) { - return getConstraint(CmpInst::ICMP_UGT, Op0, Op1, Value2Index, NewIndices); + case CmpInst::ICMP_EQ: + if (match(Op1, m_Zero())) { + Pred = CmpInst::ICMP_ULE; + } else { + IsEq = true; + Pred = CmpInst::ICMP_ULE; + } + break; + case CmpInst::ICMP_NE: + if (!match(Op1, m_Zero())) + return {}; + Pred = CmpInst::getSwappedPredicate(CmpInst::ICMP_UGT); + std::swap(Op0, Op1); + break; + default: + break; } // Only ULE and ULT predicates are supported at the moment. - if (Pred != CmpInst::ICMP_ULE && Pred != CmpInst::ICMP_ULT) + if (Pred != CmpInst::ICMP_ULE && Pred != CmpInst::ICMP_ULT && + Pred != CmpInst::ICMP_SLE && Pred != CmpInst::ICMP_SLT) return {}; - auto ADec = decompose(Op0->stripPointerCastsSameRepresentation()); - auto BDec = decompose(Op1->stripPointerCastsSameRepresentation()); + SmallVector<PreconditionTy, 4> Preconditions; + bool IsSigned = CmpInst::isSigned(Pred); + auto &Value2Index = getValue2Index(IsSigned); + auto ADec = decompose(Op0->stripPointerCastsSameRepresentation(), + Preconditions, IsSigned); + auto BDec = decompose(Op1->stripPointerCastsSameRepresentation(), + Preconditions, IsSigned); // Skip if decomposing either of the values failed. if (ADec.empty() || BDec.empty()) return {}; - // Skip trivial constraints without any variables. - if (ADec.size() == 1 && BDec.size() == 1) - return {}; - - Offset1 = ADec[0].first; - Offset2 = BDec[0].first; + int64_t Offset1 = ADec[0].first; + int64_t Offset2 = BDec[0].first; Offset1 *= -1; // Create iterator ranges that skip the constant-factor. auto VariablesA = llvm::drop_begin(ADec); auto VariablesB = llvm::drop_begin(BDec); + // First try to look up \p V in Value2Index and NewIndices. Otherwise add a + // new entry to NewIndices. + auto GetOrAddIndex = [&Value2Index, &NewIndices](Value *V) -> unsigned { + auto V2I = Value2Index.find(V); + if (V2I != Value2Index.end()) + return V2I->second; + auto Insert = + NewIndices.insert({V, Value2Index.size() + NewIndices.size() + 1}); + return Insert.first->second; + }; + // Make sure all variables have entries in Value2Index or NewIndices. 
for (const auto &KV : concat<std::pair<int64_t, Value *>>(VariablesA, VariablesB)) @@ -232,22 +347,85 @@ getConstraint(CmpInst::Predicate Pred, Value *Op0, Value *Op1, // Build result constraint, by first adding all coefficients from A and then // subtracting all coefficients from B. - SmallVector<int64_t, 8> R(Value2Index.size() + NewIndices.size() + 1, 0); + ConstraintTy Res( + SmallVector<int64_t, 8>(Value2Index.size() + NewIndices.size() + 1, 0), + IsSigned); + Res.IsEq = IsEq; + auto &R = Res.Coefficients; for (const auto &KV : VariablesA) R[GetOrAddIndex(KV.second)] += KV.first; for (const auto &KV : VariablesB) R[GetOrAddIndex(KV.second)] -= KV.first; - R[0] = Offset1 + Offset2 + (Pred == CmpInst::ICMP_ULT ? -1 : 0); - return {{R}}; + int64_t OffsetSum; + if (AddOverflow(Offset1, Offset2, OffsetSum)) + return {}; + if (Pred == (IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT)) + if (AddOverflow(OffsetSum, int64_t(-1), OffsetSum)) + return {}; + R[0] = OffsetSum; + Res.Preconditions = std::move(Preconditions); + return Res; +} + +bool ConstraintTy::isValid(const ConstraintInfo &Info) const { + return Coefficients.size() > 0 && + all_of(Preconditions, [&Info](const PreconditionTy &C) { + return Info.doesHold(C.Pred, C.Op0, C.Op1); + }); +} + +bool ConstraintInfo::doesHold(CmpInst::Predicate Pred, Value *A, + Value *B) const { + DenseMap<Value *, unsigned> NewIndices; + auto R = getConstraint(Pred, A, B, NewIndices); + + if (!NewIndices.empty()) + return false; + + // TODO: properly check NewIndices. + return NewIndices.empty() && R.Preconditions.empty() && !R.IsEq && + !R.empty() && + getCS(CmpInst::isSigned(Pred)).isConditionImplied(R.Coefficients); } -static ConstraintListTy -getConstraint(CmpInst *Cmp, const DenseMap<Value *, unsigned> &Value2Index, - DenseMap<Value *, unsigned> &NewIndices) { - return getConstraint(Cmp->getPredicate(), Cmp->getOperand(0), - Cmp->getOperand(1), Value2Index, NewIndices); +void ConstraintInfo::transferToOtherSystem( + CmpInst::Predicate Pred, Value *A, Value *B, bool IsNegated, unsigned NumIn, + unsigned NumOut, SmallVectorImpl<StackEntry> &DFSInStack) { + // Check if we can combine facts from the signed and unsigned systems to + // derive additional facts. + if (!A->getType()->isIntegerTy()) + return; + // FIXME: This currently depends on the order we add facts. Ideally we + // would first add all known facts and only then try to add additional + // facts. + switch (Pred) { + default: + break; + case CmpInst::ICMP_ULT: + // If B is a signed positive constant, A >=s 0 and A <s B. 
+ if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0))) { + addFact(CmpInst::ICMP_SGE, A, ConstantInt::get(B->getType(), 0), + IsNegated, NumIn, NumOut, DFSInStack); + addFact(CmpInst::ICMP_SLT, A, B, IsNegated, NumIn, NumOut, DFSInStack); + } + break; + case CmpInst::ICMP_SLT: + if (doesHold(CmpInst::ICMP_SGE, A, ConstantInt::get(B->getType(), 0))) + addFact(CmpInst::ICMP_ULT, A, B, IsNegated, NumIn, NumOut, DFSInStack); + break; + case CmpInst::ICMP_SGT: + if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), -1))) + addFact(CmpInst::ICMP_UGE, A, ConstantInt::get(B->getType(), 0), + IsNegated, NumIn, NumOut, DFSInStack); + break; + case CmpInst::ICMP_SGE: + if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0))) { + addFact(CmpInst::ICMP_UGE, A, B, IsNegated, NumIn, NumOut, DFSInStack); + } + break; + } } namespace { @@ -271,134 +449,253 @@ struct ConstraintOrBlock { Not(Not), Condition(Condition) {} }; -struct StackEntry { - unsigned NumIn; - unsigned NumOut; - CmpInst *Condition; - bool IsNot; +/// Keep state required to build worklist. +struct State { + DominatorTree &DT; + SmallVector<ConstraintOrBlock, 64> WorkList; - StackEntry(unsigned NumIn, unsigned NumOut, CmpInst *Condition, bool IsNot) - : NumIn(NumIn), NumOut(NumOut), Condition(Condition), IsNot(IsNot) {} + State(DominatorTree &DT) : DT(DT) {} + + /// Process block \p BB and add known facts to work-list. + void addInfoFor(BasicBlock &BB); + + /// Returns true if we can add a known condition from BB to its successor + /// block Succ. Each predecessor of Succ can either be BB or be dominated + /// by Succ (e.g. the case when adding a condition from a pre-header to a + /// loop header). + bool canAddSuccessor(BasicBlock &BB, BasicBlock *Succ) const { + if (BB.getSingleSuccessor()) { + assert(BB.getSingleSuccessor() == Succ); + return DT.properlyDominates(&BB, Succ); + } + return any_of(successors(&BB), + [Succ](const BasicBlock *S) { return S != Succ; }) && + all_of(predecessors(Succ), [&BB, Succ, this](BasicBlock *Pred) { + return Pred == &BB || DT.dominates(Succ, Pred); + }); + } }; + } // namespace #ifndef NDEBUG -static void dumpWithNames(ConstraintTy &C, +static void dumpWithNames(const ConstraintSystem &CS, DenseMap<Value *, unsigned> &Value2Index) { SmallVector<std::string> Names(Value2Index.size(), ""); for (auto &KV : Value2Index) { Names[KV.second - 1] = std::string("%") + KV.first->getName().str(); } - ConstraintSystem CS; - CS.addVariableRowFill(C.Coefficients); CS.dump(Names); } -#endif -static bool eliminateConstraints(Function &F, DominatorTree &DT) { - bool Changed = false; - DT.updateDFSNumbers(); +static void dumpWithNames(ArrayRef<int64_t> C, + DenseMap<Value *, unsigned> &Value2Index) { ConstraintSystem CS; + CS.addVariableRowFill(C); + dumpWithNames(CS, Value2Index); +} +#endif - SmallVector<ConstraintOrBlock, 64> WorkList; - - // First, collect conditions implied by branches and blocks with their - // Dominator DFS in and out numbers. - for (BasicBlock &BB : F) { - if (!DT.getNode(&BB)) - continue; - WorkList.emplace_back(DT.getNode(&BB)); - - // True as long as long as the current instruction is guaranteed to execute. - bool GuaranteedToExecute = true; - // Scan BB for assume calls. - // TODO: also use this scan to queue conditions to simplify, so we can - // interleave facts from assumes and conditions to simplify in a single - // basic block. And to skip another traversal of each basic block when - // simplifying. 
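// [Editor's note — illustrative sketch, not part of this commit.] The
// transferToOtherSystem() cases above move facts between the unsigned and
// signed systems. A hypothetical example of the ICMP_ULT rule (names
// invented, and subject to pass ordering), where an unsigned bound against a
// known non-negative value also yields signed facts:
//
//   int at(const int *a, int n, unsigned i) {
//     if (n < 0) return 0;            // signed fact: n >= 0
//     if (i < (unsigned)n) {          // unsigned fact: i <u n
//       // The transfer adds i >=s 0 and i <s n to the signed system, so the
//       // signed bound check below is implied and folds away.
//       if ((int)i < n)
//         return a[i];
//     }
//     return 0;
//   }
// [End editor's note.]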
- for (Instruction &I : BB) { - Value *Cond; - // For now, just handle assumes with a single compare as condition. - if (match(&I, m_Intrinsic<Intrinsic::assume>(m_Value(Cond))) && - isa<CmpInst>(Cond)) { - if (GuaranteedToExecute) { - // The assume is guaranteed to execute when BB is entered, hence Cond - // holds on entry to BB. - WorkList.emplace_back(DT.getNode(&BB), cast<CmpInst>(Cond), false); - } else { - // Otherwise the condition only holds in the successors. - for (BasicBlock *Succ : successors(&BB)) - WorkList.emplace_back(DT.getNode(Succ), cast<CmpInst>(Cond), false); +void State::addInfoFor(BasicBlock &BB) { + WorkList.emplace_back(DT.getNode(&BB)); + + // True as long as long as the current instruction is guaranteed to execute. + bool GuaranteedToExecute = true; + // Scan BB for assume calls. + // TODO: also use this scan to queue conditions to simplify, so we can + // interleave facts from assumes and conditions to simplify in a single + // basic block. And to skip another traversal of each basic block when + // simplifying. + for (Instruction &I : BB) { + Value *Cond; + // For now, just handle assumes with a single compare as condition. + if (match(&I, m_Intrinsic<Intrinsic::assume>(m_Value(Cond))) && + isa<ICmpInst>(Cond)) { + if (GuaranteedToExecute) { + // The assume is guaranteed to execute when BB is entered, hence Cond + // holds on entry to BB. + WorkList.emplace_back(DT.getNode(&BB), cast<ICmpInst>(Cond), false); + } else { + // Otherwise the condition only holds in the successors. + for (BasicBlock *Succ : successors(&BB)) { + if (!canAddSuccessor(BB, Succ)) + continue; + WorkList.emplace_back(DT.getNode(Succ), cast<ICmpInst>(Cond), false); } } - GuaranteedToExecute &= isGuaranteedToTransferExecutionToSuccessor(&I); } + GuaranteedToExecute &= isGuaranteedToTransferExecutionToSuccessor(&I); + } - auto *Br = dyn_cast<BranchInst>(BB.getTerminator()); - if (!Br || !Br->isConditional()) - continue; + auto *Br = dyn_cast<BranchInst>(BB.getTerminator()); + if (!Br || !Br->isConditional()) + return; + + // If the condition is an OR of 2 compares and the false successor only has + // the current block as predecessor, queue both negated conditions for the + // false successor. + Value *Op0, *Op1; + if (match(Br->getCondition(), m_LogicalOr(m_Value(Op0), m_Value(Op1))) && + isa<ICmpInst>(Op0) && isa<ICmpInst>(Op1)) { + BasicBlock *FalseSuccessor = Br->getSuccessor(1); + if (canAddSuccessor(BB, FalseSuccessor)) { + WorkList.emplace_back(DT.getNode(FalseSuccessor), cast<ICmpInst>(Op0), + true); + WorkList.emplace_back(DT.getNode(FalseSuccessor), cast<ICmpInst>(Op1), + true); + } + return; + } - // Returns true if we can add a known condition from BB to its successor - // block Succ. Each predecessor of Succ can either be BB or be dominated by - // Succ (e.g. the case when adding a condition from a pre-header to a loop - // header). - auto CanAdd = [&BB, &DT](BasicBlock *Succ) { - return all_of(predecessors(Succ), [&BB, &DT, Succ](BasicBlock *Pred) { - return Pred == &BB || DT.dominates(Succ, Pred); - }); - }; - // If the condition is an OR of 2 compares and the false successor only has - // the current block as predecessor, queue both negated conditions for the - // false successor. 
- Value *Op0, *Op1; - if (match(Br->getCondition(), m_LogicalOr(m_Value(Op0), m_Value(Op1))) && - match(Op0, m_Cmp()) && match(Op1, m_Cmp())) { - BasicBlock *FalseSuccessor = Br->getSuccessor(1); - if (CanAdd(FalseSuccessor)) { - WorkList.emplace_back(DT.getNode(FalseSuccessor), cast<CmpInst>(Op0), - true); - WorkList.emplace_back(DT.getNode(FalseSuccessor), cast<CmpInst>(Op1), - true); - } - continue; + // If the condition is an AND of 2 compares and the true successor only has + // the current block as predecessor, queue both conditions for the true + // successor. + if (match(Br->getCondition(), m_LogicalAnd(m_Value(Op0), m_Value(Op1))) && + isa<ICmpInst>(Op0) && isa<ICmpInst>(Op1)) { + BasicBlock *TrueSuccessor = Br->getSuccessor(0); + if (canAddSuccessor(BB, TrueSuccessor)) { + WorkList.emplace_back(DT.getNode(TrueSuccessor), cast<ICmpInst>(Op0), + false); + WorkList.emplace_back(DT.getNode(TrueSuccessor), cast<ICmpInst>(Op1), + false); } + return; + } - // If the condition is an AND of 2 compares and the true successor only has - // the current block as predecessor, queue both conditions for the true - // successor. - if (match(Br->getCondition(), m_LogicalAnd(m_Value(Op0), m_Value(Op1))) && - match(Op0, m_Cmp()) && match(Op1, m_Cmp())) { - BasicBlock *TrueSuccessor = Br->getSuccessor(0); - if (CanAdd(TrueSuccessor)) { - WorkList.emplace_back(DT.getNode(TrueSuccessor), cast<CmpInst>(Op0), - false); - WorkList.emplace_back(DT.getNode(TrueSuccessor), cast<CmpInst>(Op1), - false); + auto *CmpI = dyn_cast<ICmpInst>(Br->getCondition()); + if (!CmpI) + return; + if (canAddSuccessor(BB, Br->getSuccessor(0))) + WorkList.emplace_back(DT.getNode(Br->getSuccessor(0)), CmpI, false); + if (canAddSuccessor(BB, Br->getSuccessor(1))) + WorkList.emplace_back(DT.getNode(Br->getSuccessor(1)), CmpI, true); +} + +void ConstraintInfo::addFact(CmpInst::Predicate Pred, Value *A, Value *B, + bool IsNegated, unsigned NumIn, unsigned NumOut, + SmallVectorImpl<StackEntry> &DFSInStack) { + // If the constraint has a pre-condition, skip the constraint if it does not + // hold. + DenseMap<Value *, unsigned> NewIndices; + auto R = getConstraint(Pred, A, B, NewIndices); + if (!R.isValid(*this)) + return; + + //LLVM_DEBUG(dbgs() << "Adding " << *Condition << " " << IsNegated << "\n"); + bool Added = false; + assert(CmpInst::isSigned(Pred) == R.IsSigned && + "condition and constraint signs must match"); + auto &CSToUse = getCS(R.IsSigned); + if (R.Coefficients.empty()) + return; + + Added |= CSToUse.addVariableRowFill(R.Coefficients); + + // If R has been added to the system, queue it for removal once it goes + // out-of-scope. + if (Added) { + SmallVector<Value *, 2> ValuesToRelease; + for (auto &KV : NewIndices) { + getValue2Index(R.IsSigned).insert(KV); + ValuesToRelease.push_back(KV.first); + } + + LLVM_DEBUG({ + dbgs() << " constraint: "; + dumpWithNames(R.Coefficients, getValue2Index(R.IsSigned)); + }); + + DFSInStack.emplace_back(NumIn, NumOut, IsNegated, R.IsSigned, + ValuesToRelease); + + if (R.IsEq) { + // Also add the inverted constraint for equality constraints. 
+ for (auto &Coeff : R.Coefficients) + Coeff *= -1; + CSToUse.addVariableRowFill(R.Coefficients); + + DFSInStack.emplace_back(NumIn, NumOut, IsNegated, R.IsSigned, + SmallVector<Value *, 2>()); + } + } +} + +static void +tryToSimplifyOverflowMath(IntrinsicInst *II, ConstraintInfo &Info, + SmallVectorImpl<Instruction *> &ToRemove) { + auto DoesConditionHold = [](CmpInst::Predicate Pred, Value *A, Value *B, + ConstraintInfo &Info) { + DenseMap<Value *, unsigned> NewIndices; + auto R = Info.getConstraint(Pred, A, B, NewIndices); + if (R.size() < 2 || R.needsNewIndices(NewIndices) || !R.isValid(Info)) + return false; + + auto &CSToUse = Info.getCS(CmpInst::isSigned(Pred)); + return CSToUse.isConditionImplied(R.Coefficients); + }; + + if (II->getIntrinsicID() == Intrinsic::ssub_with_overflow) { + // If A s>= B && B s>= 0, ssub.with.overflow(a, b) should not overflow and + // can be simplified to a regular sub. + Value *A = II->getArgOperand(0); + Value *B = II->getArgOperand(1); + if (!DoesConditionHold(CmpInst::ICMP_SGE, A, B, Info) || + !DoesConditionHold(CmpInst::ICMP_SGE, B, + ConstantInt::get(A->getType(), 0), Info)) + return; + + IRBuilder<> Builder(II->getParent(), II->getIterator()); + Value *Sub = nullptr; + for (User *U : make_early_inc_range(II->users())) { + if (match(U, m_ExtractValue<0>(m_Value()))) { + if (!Sub) + Sub = Builder.CreateSub(A, B); + U->replaceAllUsesWith(Sub); + } else if (match(U, m_ExtractValue<1>(m_Value()))) + U->replaceAllUsesWith(Builder.getFalse()); + else + continue; + + if (U->use_empty()) { + auto *I = cast<Instruction>(U); + ToRemove.push_back(I); + I->setOperand(0, PoisonValue::get(II->getType())); } - continue; } - auto *CmpI = dyn_cast<CmpInst>(Br->getCondition()); - if (!CmpI) + if (II->use_empty()) + II->eraseFromParent(); + } +} + +static bool eliminateConstraints(Function &F, DominatorTree &DT) { + bool Changed = false; + DT.updateDFSNumbers(); + + ConstraintInfo Info; + State S(DT); + + // First, collect conditions implied by branches and blocks with their + // Dominator DFS in and out numbers. + for (BasicBlock &BB : F) { + if (!DT.getNode(&BB)) continue; - if (CanAdd(Br->getSuccessor(0))) - WorkList.emplace_back(DT.getNode(Br->getSuccessor(0)), CmpI, false); - if (CanAdd(Br->getSuccessor(1))) - WorkList.emplace_back(DT.getNode(Br->getSuccessor(1)), CmpI, true); + S.addInfoFor(BB); } // Next, sort worklist by dominance, so that dominating blocks and conditions // come before blocks and conditions dominated by them. If a block and a // condition have the same numbers, the condition comes before the block, as // it holds on entry to the block. - sort(WorkList, [](const ConstraintOrBlock &A, const ConstraintOrBlock &B) { + stable_sort(S.WorkList, [](const ConstraintOrBlock &A, const ConstraintOrBlock &B) { return std::tie(A.NumIn, A.IsBlock) < std::tie(B.NumIn, B.IsBlock); }); + SmallVector<Instruction *> ToRemove; + // Finally, process ordered worklist and eliminate implied conditions. SmallVector<StackEntry, 16> DFSInStack; - DenseMap<Value *, unsigned> Value2Index; - for (ConstraintOrBlock &CB : WorkList) { + for (ConstraintOrBlock &CB : S.WorkList) { // First, pop entries from the stack that are out-of-scope for CB. Remove // the corresponding entry from the constraint system. 
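// [Editor's note — illustrative sketch, not part of this commit.] The new
// tryToSimplifyOverflowMath() above folds llvm.ssub.with.overflow once the
// collected facts prove A s>= B and B s>= 0. A hedged C++ example (function
// name invented; __builtin_sub_overflow lowers to that intrinsic in Clang):
//
//   int sub_checked(int a, int b) {
//     if (a >= b && b >= 0) {
//       int r;
//       // The subtraction cannot overflow here, so the overflow result
//       // folds to false and the intrinsic becomes a plain `sub`.
//       if (__builtin_sub_overflow(a, b, &r))
//         return -1;
//       return r;
//     }
//     return 0;
//   }
// [End editor's note.]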
while (!DFSInStack.empty()) { @@ -409,10 +706,20 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) { assert(E.NumIn <= CB.NumIn); if (CB.NumOut <= E.NumOut) break; - LLVM_DEBUG(dbgs() << "Removing " << *E.Condition << " " << E.IsNot - << "\n"); + LLVM_DEBUG({ + dbgs() << "Removing "; + dumpWithNames(Info.getCS(E.IsSigned).getLastConstraint(), + Info.getValue2Index(E.IsSigned)); + dbgs() << "\n"; + }); + + Info.popLastConstraint(E.IsSigned); + // Remove variables in the system that went out of scope. + auto &Mapping = Info.getValue2Index(E.IsSigned); + for (Value *V : E.ValuesToRelease) + Mapping.erase(V); + Info.popLastNVariables(E.IsSigned, E.ValuesToRelease.size()); DFSInStack.pop_back(); - CS.popLastConstraint(); } LLVM_DEBUG({ @@ -427,28 +734,30 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) { // For a block, check if any CmpInsts become known based on the current set // of constraints. if (CB.IsBlock) { - for (Instruction &I : *CB.BB) { - auto *Cmp = dyn_cast<CmpInst>(&I); + for (Instruction &I : make_early_inc_range(*CB.BB)) { + if (auto *II = dyn_cast<WithOverflowInst>(&I)) { + tryToSimplifyOverflowMath(II, Info, ToRemove); + continue; + } + auto *Cmp = dyn_cast<ICmpInst>(&I); if (!Cmp) continue; DenseMap<Value *, unsigned> NewIndices; - auto R = getConstraint(Cmp, Value2Index, NewIndices); - if (R.size() != 1) - continue; - - if (R.needsNewIndices(NewIndices)) + auto R = Info.getConstraint(Cmp, NewIndices); + if (R.IsEq || R.empty() || R.needsNewIndices(NewIndices) || + !R.isValid(Info)) continue; - if (CS.isConditionImplied(R.get(0).Coefficients)) { + auto &CSToUse = Info.getCS(R.IsSigned); + if (CSToUse.isConditionImplied(R.Coefficients)) { if (!DebugCounter::shouldExecute(EliminatedCounter)) continue; - LLVM_DEBUG(dbgs() << "Condition " << *Cmp - << " implied by dominating constraints\n"); LLVM_DEBUG({ - for (auto &E : reverse(DFSInStack)) - dbgs() << " C " << *E.Condition << " " << E.IsNot << "\n"; + dbgs() << "Condition " << *Cmp + << " implied by dominating constraints\n"; + dumpWithNames(CSToUse, Info.getValue2Index(R.IsSigned)); }); Cmp->replaceUsesWithIf( ConstantInt::getTrue(F.getParent()->getContext()), [](Use &U) { @@ -460,16 +769,15 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) { NumCondsRemoved++; Changed = true; } - if (CS.isConditionImplied( - ConstraintSystem::negate(R.get(0).Coefficients))) { + if (CSToUse.isConditionImplied( + ConstraintSystem::negate(R.Coefficients))) { if (!DebugCounter::shouldExecute(EliminatedCounter)) continue; - LLVM_DEBUG(dbgs() << "Condition !" << *Cmp - << " implied by dominating constraints\n"); LLVM_DEBUG({ - for (auto &E : reverse(DFSInStack)) - dbgs() << " C " << *E.Condition << " " << E.IsNot << "\n"; + dbgs() << "Condition !" << *Cmp + << " implied by dominating constraints\n"; + dumpWithNames(CSToUse, Info.getValue2Index(R.IsSigned)); }); Cmp->replaceAllUsesWith( ConstantInt::getFalse(F.getParent()->getContext())); @@ -482,7 +790,7 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) { // Set up a function to restore the predicate at the end of the scope if it // has been negated. Negate the predicate in-place, if required. 
- auto *CI = dyn_cast<CmpInst>(CB.Condition); + auto *CI = dyn_cast<ICmpInst>(CB.Condition); auto PredicateRestorer = make_scope_exit([CI, &CB]() { if (CB.Not && CI) CI->setPredicate(CI->getInversePredicate()); @@ -496,34 +804,28 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT) { } } - // Otherwise, add the condition to the system and stack, if we can transform - // it into a constraint. - DenseMap<Value *, unsigned> NewIndices; - auto R = getConstraint(CB.Condition, Value2Index, NewIndices); - if (R.empty()) - continue; - - for (auto &KV : NewIndices) - Value2Index.insert(KV); - - LLVM_DEBUG(dbgs() << "Adding " << *CB.Condition << " " << CB.Not << "\n"); - bool Added = false; - for (auto &C : R.Constraints) { - auto Coeffs = C.Coefficients; - LLVM_DEBUG({ - dbgs() << " constraint: "; - dumpWithNames(C, Value2Index); - }); - Added |= CS.addVariableRowFill(Coeffs); - // If R has been added to the system, queue it for removal once it goes - // out-of-scope. - if (Added) - DFSInStack.emplace_back(CB.NumIn, CB.NumOut, CB.Condition, CB.Not); + ICmpInst::Predicate Pred; + Value *A, *B; + if (match(CB.Condition, m_ICmp(Pred, m_Value(A), m_Value(B)))) { + // Otherwise, add the condition to the system and stack, if we can + // transform it into a constraint. + Info.addFact(Pred, A, B, CB.Not, CB.NumIn, CB.NumOut, DFSInStack); + Info.transferToOtherSystem(Pred, A, B, CB.Not, CB.NumIn, CB.NumOut, + DFSInStack); } } - assert(CS.size() == DFSInStack.size() && +#ifndef NDEBUG + unsigned SignedEntries = + count_if(DFSInStack, [](const StackEntry &E) { return E.IsSigned; }); + assert(Info.getCS(false).size() == DFSInStack.size() - SignedEntries && + "updates to CS and DFSInStack are out of sync"); + assert(Info.getCS(true).size() == SignedEntries && "updates to CS and DFSInStack are out of sync"); +#endif + + for (Instruction *I : ToRemove) + I->eraseFromParent(); return Changed; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index a3fd97079b1d..64bd4241f37c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -41,8 +41,6 @@ #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> @@ -215,6 +213,53 @@ static bool simplifyCommonValuePhi(PHINode *P, LazyValueInfo *LVI, return true; } +static Value *getValueOnEdge(LazyValueInfo *LVI, Value *Incoming, + BasicBlock *From, BasicBlock *To, + Instruction *CxtI) { + if (Constant *C = LVI->getConstantOnEdge(Incoming, From, To, CxtI)) + return C; + + // Look if the incoming value is a select with a scalar condition for which + // LVI can tells us the value. In that case replace the incoming value with + // the appropriate value of the select. This often allows us to remove the + // select later. + auto *SI = dyn_cast<SelectInst>(Incoming); + if (!SI) + return nullptr; + + // Once LVI learns to handle vector types, we could also add support + // for vector type constants that are not all zeroes or all ones. 
+ Value *Condition = SI->getCondition(); + if (!Condition->getType()->isVectorTy()) { + if (Constant *C = LVI->getConstantOnEdge(Condition, From, To, CxtI)) { + if (C->isOneValue()) + return SI->getTrueValue(); + if (C->isZeroValue()) + return SI->getFalseValue(); + } + } + + // Look if the select has a constant but LVI tells us that the incoming + // value can never be that constant. In that case replace the incoming + // value with the other value of the select. This often allows us to + // remove the select later. + + // The "false" case + if (auto *C = dyn_cast<Constant>(SI->getFalseValue())) + if (LVI->getPredicateOnEdge(ICmpInst::ICMP_EQ, SI, C, From, To, CxtI) == + LazyValueInfo::False) + return SI->getTrueValue(); + + // The "true" case, + // similar to the select "false" case, but try the select "true" value + if (auto *C = dyn_cast<Constant>(SI->getTrueValue())) + if (LVI->getPredicateOnEdge(ICmpInst::ICMP_EQ, SI, C, From, To, CxtI) == + LazyValueInfo::False) + return SI->getFalseValue(); + + return nullptr; +} + static bool processPHI(PHINode *P, LazyValueInfo *LVI, DominatorTree *DT, const SimplifyQuery &SQ) { bool Changed = false; @@ -224,53 +269,14 @@ static bool processPHI(PHINode *P, LazyValueInfo *LVI, DominatorTree *DT, Value *Incoming = P->getIncomingValue(i); if (isa<Constant>(Incoming)) continue; - Value *V = LVI->getConstantOnEdge(Incoming, P->getIncomingBlock(i), BB, P); - - // Look if the incoming value is a select with a scalar condition for which - // LVI can tells us the value. In that case replace the incoming value with - // the appropriate value of the select. This often allows us to remove the - // select later. - if (!V) { - SelectInst *SI = dyn_cast<SelectInst>(Incoming); - if (!SI) continue; - - Value *Condition = SI->getCondition(); - if (!Condition->getType()->isVectorTy()) { - if (Constant *C = LVI->getConstantOnEdge( - Condition, P->getIncomingBlock(i), BB, P)) { - if (C->isOneValue()) { - V = SI->getTrueValue(); - } else if (C->isZeroValue()) { - V = SI->getFalseValue(); - } - // Once LVI learns to handle vector types, we could also add support - // for vector type constants that are not all zeroes or all ones. - } - } - - // Look if the select has a constant but LVI tells us that the incoming - // value can never be that constant. In that case replace the incoming - // value with the other value of the select. This often allows us to - // remove the select later. 
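// [Editor's note — illustrative sketch, not part of this commit.] The new
// getValueOnEdge() helper above consolidates the existing select handling in
// processPHI() and adds the symmetric "true value" case. Roughly the pattern
// it targets (names invented):
//
//   int f(bool c, int a) {
//     int x = c ? a : 0;   // select feeding a phi through the branch below
//     int r = 0;
//     if (c)
//       r = x;             // on this edge c is known true, so the phi's
//                          // incoming value is rewritten to `a` and the
//                          // select can later be removed
//     return r;
//   }
// [End editor's note.]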
- if (!V) { - Constant *C = dyn_cast<Constant>(SI->getFalseValue()); - if (!C) continue; - - if (LVI->getPredicateOnEdge(ICmpInst::ICMP_EQ, SI, C, - P->getIncomingBlock(i), BB, P) != - LazyValueInfo::False) - continue; - V = SI->getTrueValue(); - } - - LLVM_DEBUG(dbgs() << "CVP: Threading PHI over " << *SI << '\n'); + Value *V = getValueOnEdge(LVI, Incoming, P->getIncomingBlock(i), BB, P); + if (V) { + P->setIncomingValue(i, V); + Changed = true; } - - P->setIncomingValue(i, V); - Changed = true; } - if (Value *V = SimplifyInstruction(P, SQ)) { + if (Value *V = simplifyInstruction(P, SQ)) { P->replaceAllUsesWith(V); P->eraseFromParent(); Changed = true; @@ -575,7 +581,7 @@ static bool processOverflowIntrinsic(WithOverflowInst *WO, LazyValueInfo *LVI) { StructType *ST = cast<StructType>(WO->getType()); Constant *Struct = ConstantStruct::get(ST, - { UndefValue::get(ST->getElementType(0)), + { PoisonValue::get(ST->getElementType(0)), ConstantInt::getFalse(ST->getElementType(1)) }); Value *NewI = B.CreateInsertValue(Struct, NewOp, 0); WO->replaceAllUsesWith(NewI); @@ -735,8 +741,7 @@ static bool narrowSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) { // sdiv/srem is UB if divisor is -1 and divident is INT_MIN, so unless we can // prove that such a combination is impossible, we need to bump the bitwidth. if (CRs[1]->contains(APInt::getAllOnes(OrigWidth)) && - CRs[0]->contains( - APInt::getSignedMinValue(MinSignedBits).sextOrSelf(OrigWidth))) + CRs[0]->contains(APInt::getSignedMinValue(MinSignedBits).sext(OrigWidth))) ++MinSignedBits; // Don't shrink below 8 bits wide. @@ -955,7 +960,8 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) { ++NumAShrsConverted; auto *BO = BinaryOperator::CreateLShr(SDI->getOperand(0), SDI->getOperand(1), - SDI->getName(), SDI); + "", SDI); + BO->takeName(SDI); BO->setDebugLoc(SDI->getDebugLoc()); BO->setIsExact(SDI->isExact()); SDI->replaceAllUsesWith(BO); @@ -974,8 +980,8 @@ static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) { return false; ++NumSExt; - auto *ZExt = - CastInst::CreateZExtOrBitCast(Base, SDI->getType(), SDI->getName(), SDI); + auto *ZExt = CastInst::CreateZExtOrBitCast(Base, SDI->getType(), "", SDI); + ZExt->takeName(SDI); ZExt->setDebugLoc(SDI->getDebugLoc()); SDI->replaceAllUsesWith(ZExt); SDI->eraseFromParent(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp index 143a78f604fc..5667eefabad5 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp @@ -60,30 +60,31 @@ #include "llvm/Transforms/Scalar/DFAJumpThreading.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" -#include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Verifier.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" 
#include "llvm/Transforms/Utils/SSAUpdaterBulk.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <deque> +#ifdef EXPENSIVE_CHECKS +#include "llvm/IR/Verifier.h" +#endif + using namespace llvm; #define DEBUG_TYPE "dfa-jump-threading" @@ -102,6 +103,11 @@ static cl::opt<unsigned> MaxPathLength( cl::desc("Max number of blocks searched to find a threading path"), cl::Hidden, cl::init(20)); +static cl::opt<unsigned> MaxNumPaths( + "dfa-max-num-paths", + cl::desc("Max number of paths enumerated around a switch"), + cl::Hidden, cl::init(200)); + static cl::opt<unsigned> CostThreshold("dfa-cost-threshold", cl::desc("Maximum cost accepted for the transformation"), @@ -414,7 +420,7 @@ inline raw_ostream &operator<<(raw_ostream &OS, const ThreadingPath &TPath) { struct MainSwitch { MainSwitch(SwitchInst *SI, OptimizationRemarkEmitter *ORE) { - if (isPredictable(SI)) { + if (isCandidate(SI)) { Instr = SI; } else { ORE->emit([&]() { @@ -432,83 +438,60 @@ struct MainSwitch { } private: - /// Do a use-def chain traversal. Make sure the value of the switch variable - /// is always a known constant. This means that all conditional jumps based on - /// switch variable can be converted to unconditional jumps. - bool isPredictable(const SwitchInst *SI) { - std::deque<Instruction *> Q; + /// Do a use-def chain traversal starting from the switch condition to see if + /// \p SI is a potential condidate. + /// + /// Also, collect select instructions to unfold. + bool isCandidate(const SwitchInst *SI) { + std::deque<Value *> Q; SmallSet<Value *, 16> SeenValues; SelectInsts.clear(); - Value *FirstDef = SI->getOperand(0); - auto *Inst = dyn_cast<Instruction>(FirstDef); - - // If this is a function argument or another non-instruction, then give up. - // We are interested in loop local variables. 
- if (!Inst) - return false; - - // Require the first definition to be a PHINode - if (!isa<PHINode>(Inst)) + Value *SICond = SI->getCondition(); + LLVM_DEBUG(dbgs() << "\tSICond: " << *SICond << "\n"); + if (!isa<PHINode>(SICond)) return false; - LLVM_DEBUG(dbgs() << "\tisPredictable() FirstDef: " << *Inst << "\n"); - - Q.push_back(Inst); - SeenValues.insert(FirstDef); + addToQueue(SICond, Q, SeenValues); while (!Q.empty()) { - Instruction *Current = Q.front(); + Value *Current = Q.front(); Q.pop_front(); if (auto *Phi = dyn_cast<PHINode>(Current)) { for (Value *Incoming : Phi->incoming_values()) { - if (!isPredictableValue(Incoming, SeenValues)) - return false; - addInstToQueue(Incoming, Q, SeenValues); + addToQueue(Incoming, Q, SeenValues); } - LLVM_DEBUG(dbgs() << "\tisPredictable() phi: " << *Phi << "\n"); + LLVM_DEBUG(dbgs() << "\tphi: " << *Phi << "\n"); } else if (SelectInst *SelI = dyn_cast<SelectInst>(Current)) { if (!isValidSelectInst(SelI)) return false; - if (!isPredictableValue(SelI->getTrueValue(), SeenValues) || - !isPredictableValue(SelI->getFalseValue(), SeenValues)) { - return false; - } - addInstToQueue(SelI->getTrueValue(), Q, SeenValues); - addInstToQueue(SelI->getFalseValue(), Q, SeenValues); - LLVM_DEBUG(dbgs() << "\tisPredictable() select: " << *SelI << "\n"); + addToQueue(SelI->getTrueValue(), Q, SeenValues); + addToQueue(SelI->getFalseValue(), Q, SeenValues); + LLVM_DEBUG(dbgs() << "\tselect: " << *SelI << "\n"); if (auto *SelIUse = dyn_cast<PHINode>(SelI->user_back())) SelectInsts.push_back(SelectInstToUnfold(SelI, SelIUse)); + } else if (isa<Constant>(Current)) { + LLVM_DEBUG(dbgs() << "\tconst: " << *Current << "\n"); + continue; } else { - // If it is neither a phi nor a select, then we give up. - return false; + LLVM_DEBUG(dbgs() << "\tother: " << *Current << "\n"); + // Allow unpredictable values. The hope is that those will be the + // initial switch values that can be ignored (they will hit the + // unthreaded switch) but this assumption will get checked later after + // paths have been enumerated (in function getStateDefMap). + continue; } } return true; } - bool isPredictableValue(Value *InpVal, SmallSet<Value *, 16> &SeenValues) { - if (SeenValues.contains(InpVal)) - return true; - - if (isa<ConstantInt>(InpVal)) - return true; - - // If this is a function argument or another non-instruction, then give up. - if (!isa<Instruction>(InpVal)) - return false; - - return true; - } - - void addInstToQueue(Value *Val, std::deque<Instruction *> &Q, - SmallSet<Value *, 16> &SeenValues) { + void addToQueue(Value *Val, std::deque<Value *> &Q, + SmallSet<Value *, 16> &SeenValues) { if (SeenValues.contains(Val)) return; - if (Instruction *I = dyn_cast<Instruction>(Val)) - Q.push_back(I); + Q.push_back(Val); SeenValues.insert(Val); } @@ -562,7 +545,16 @@ struct AllSwitchPaths { void run() { VisitedBlocks Visited; PathsType LoopPaths = paths(SwitchBlock, Visited, /* PathDepth = */ 1); - StateDefMap StateDef = getStateDefMap(); + StateDefMap StateDef = getStateDefMap(LoopPaths); + + if (StateDef.empty()) { + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "SwitchNotPredictable", + Switch) + << "Switch instruction is not predictable."; + }); + return; + } for (PathType Path : LoopPaths) { ThreadingPath TPath; @@ -637,6 +629,9 @@ private: PathType NewPath(Path); NewPath.push_front(BB); Res.push_back(NewPath); + if (Res.size() >= MaxNumPaths) { + return Res; + } } } // This block could now be visited again from a different predecessor. 
Note @@ -647,14 +642,22 @@ private: } /// Walk the use-def chain and collect all the state-defining instructions. - StateDefMap getStateDefMap() const { + /// + /// Return an empty map if unpredictable values encountered inside the basic + /// blocks of \p LoopPaths. + StateDefMap getStateDefMap(const PathsType &LoopPaths) const { StateDefMap Res; + // Basic blocks belonging to any of the loops around the switch statement. + SmallPtrSet<BasicBlock *, 16> LoopBBs; + for (const PathType &Path : LoopPaths) { + for (BasicBlock *BB : Path) + LoopBBs.insert(BB); + } + Value *FirstDef = Switch->getOperand(0); - assert(isa<PHINode>(FirstDef) && "After select unfolding, all state " - "definitions are expected to be phi " - "nodes."); + assert(isa<PHINode>(FirstDef) && "The first definition must be a phi."); SmallVector<PHINode *, 8> Stack; Stack.push_back(dyn_cast<PHINode>(FirstDef)); @@ -666,15 +669,17 @@ private: Res[CurPhi->getParent()] = CurPhi; SeenValues.insert(CurPhi); - for (Value *Incoming : CurPhi->incoming_values()) { + for (BasicBlock *IncomingBB : CurPhi->blocks()) { + Value *Incoming = CurPhi->getIncomingValueForBlock(IncomingBB); + bool IsOutsideLoops = LoopBBs.count(IncomingBB) == 0; if (Incoming == FirstDef || isa<ConstantInt>(Incoming) || - SeenValues.contains(Incoming)) { + SeenValues.contains(Incoming) || IsOutsideLoops) { continue; } - assert(isa<PHINode>(Incoming) && "After select unfolding, all state " - "definitions are expected to be phi " - "nodes."); + // Any unpredictable value inside the loops means we must bail out. + if (!isa<PHINode>(Incoming)) + return StateDefMap(); Stack.push_back(cast<PHINode>(Incoming)); } @@ -823,6 +828,16 @@ private: }); return false; } + + if (!Metrics.NumInsts.isValid()) { + LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, contains " + << "instructions with invalid cost.\n"); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "ConvergentInst", Switch) + << "Contains instructions with invalid cost."; + }); + return false; + } } unsigned DuplicationCost = 0; @@ -836,7 +851,7 @@ private: // using binary search, hence the LogBase2(). unsigned CondBranches = APInt(32, Switch->getNumSuccessors()).ceilLogBase2(); - DuplicationCost = Metrics.NumInsts / CondBranches; + DuplicationCost = *Metrics.NumInsts.getValue() / CondBranches; } else { // Compared with jump tables, the DFA optimizer removes an indirect branch // on each loop iteration, thus making branch prediction more precise. The @@ -844,7 +859,7 @@ private: // predictor to make a mistake, and the more benefit there is in the DFA // optimizer. Thus, the more branch targets there are, the lower is the // cost of the DFA opt. - DuplicationCost = Metrics.NumInsts / JumpTableSize; + DuplicationCost = *Metrics.NumInsts.getValue() / JumpTableSize; } LLVM_DEBUG(dbgs() << "\nDFA Jump Threading: Cost to jump thread block " @@ -1197,7 +1212,7 @@ private: PhiToRemove.push_back(Phi); } for (PHINode *PN : PhiToRemove) { - PN->replaceAllUsesWith(UndefValue::get(PN->getType())); + PN->replaceAllUsesWith(PoisonValue::get(PN->getType())); PN->eraseFromParent(); } return; @@ -1246,7 +1261,7 @@ private: /// Returns true if IncomingBB is a predecessor of BB. 
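// [Editor's note — illustrative sketch, not part of this commit.] With
// isPredictable() renamed to isCandidate(), unpredictable values in the
// use-def chain of the switch condition no longer reject the switch up
// front; getStateDefMap() only bails out if such values occur inside the
// threaded loop itself. A simplified state machine whose initial state is a
// function argument can therefore still be considered (names invented;
// actual threading also depends on the cost model and the new path limit):
//
//   int run(int start, const int *in, int n) {
//     int state = start;                        // unpredictable initial state
//     for (int i = 0; i < n; ++i) {
//       switch (state) {
//       case 0:  state = in[i] ? 1 : 2; break;  // select gets unfolded
//       case 1:  state = 2; break;
//       default: state = 0; break;
//       }
//     }
//     return state;
//   }
// [End editor's note.]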
bool isPredecessor(BasicBlock *BB, BasicBlock *IncomingBB) { - return llvm::find(predecessors(BB), IncomingBB) != pred_end(BB); + return llvm::is_contained(predecessors(BB), IncomingBB); } AllSwitchPaths *SwitchPaths; @@ -1278,7 +1293,7 @@ bool DFAJumpThreading::run(Function &F) { continue; LLVM_DEBUG(dbgs() << "\nCheck if SwitchInst in BB " << BB.getName() - << " is predictable\n"); + << " is a candidate\n"); MainSwitch Switch(SI, ORE); if (!Switch.getInstr()) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp index c5c8e880eb3d..4c42869dbd58 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -38,7 +38,9 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" @@ -62,8 +64,6 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" @@ -75,7 +75,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/DebugCounter.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" @@ -83,7 +82,6 @@ #include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> -#include <cstddef> #include <cstdint> #include <iterator> #include <map> @@ -766,6 +764,9 @@ struct DSEState { // Post-order numbers for each basic block. Used to figure out if memory // accesses are executed before another access. DenseMap<BasicBlock *, unsigned> PostOrderNumbers; + // Values that are only used with assumes. Used to refine pointer escape + // analysis. + SmallPtrSet<const Value *, 32> EphValues; /// Keep track of instructions (partly) overlapping with killing MemoryDefs per /// basic block. @@ -780,10 +781,10 @@ struct DSEState { DSEState &operator=(const DSEState &) = delete; DSEState(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, DominatorTree &DT, - PostDominatorTree &PDT, const TargetLibraryInfo &TLI, - const LoopInfo &LI) - : F(F), AA(AA), EI(DT, LI), BatchAA(AA, &EI), MSSA(MSSA), DT(DT), - PDT(PDT), TLI(TLI), DL(F.getParent()->getDataLayout()), LI(LI) { + PostDominatorTree &PDT, AssumptionCache &AC, + const TargetLibraryInfo &TLI, const LoopInfo &LI) + : F(F), AA(AA), EI(DT, LI, EphValues), BatchAA(AA, &EI), MSSA(MSSA), + DT(DT), PDT(PDT), TLI(TLI), DL(F.getParent()->getDataLayout()), LI(LI) { // Collect blocks with throwing instructions not modeled in MemorySSA and // alloc-like objects. 
unsigned PO = 0; @@ -813,6 +814,8 @@ struct DSEState { AnyUnreachableExit = any_of(PDT.roots(), [](const BasicBlock *E) { return isa<UnreachableInst>(E->getTerminator()); }); + + CodeMetrics::collectEphemeralValues(&F, &AC, EphValues); } /// Return 'OW_Complete' if a store to the 'KillingLoc' location (by \p @@ -959,7 +962,7 @@ struct DSEState { if (!isInvisibleToCallerOnUnwind(V)) { I.first->second = false; } else if (isNoAliasCall(V)) { - I.first->second = !PointerMayBeCaptured(V, true, false); + I.first->second = !PointerMayBeCaptured(V, true, false, EphValues); } } return I.first->second; @@ -978,7 +981,7 @@ struct DSEState { // with the killing MemoryDef. But we refrain from doing so for now to // limit compile-time and this does not cause any changes to the number // of stores removed on a large test set in practice. - I.first->second = PointerMayBeCaptured(V, false, true); + I.first->second = PointerMayBeCaptured(V, false, true, EphValues); return !I.first->second; } @@ -1011,7 +1014,8 @@ struct DSEState { if (CB->isLifetimeStartOrEnd()) return false; - return CB->use_empty() && CB->willReturn() && CB->doesNotThrow(); + return CB->use_empty() && CB->willReturn() && CB->doesNotThrow() && + !CB->isTerminator(); } return false; @@ -1241,6 +1245,9 @@ struct DSEState { // Reached TOP. if (MSSA.isLiveOnEntryDef(Current)) { LLVM_DEBUG(dbgs() << " ... found LiveOnEntryDef\n"); + if (CanOptimize && Current != KillingDef->getDefiningAccess()) + // The first clobbering def is... none. + KillingDef->setOptimized(Current); return None; } @@ -1317,7 +1324,6 @@ struct DSEState { // memory location and not located in different loops. if (!isGuaranteedLoopIndependent(CurrentI, KillingI, *CurrentLoc)) { LLVM_DEBUG(dbgs() << " ... not guaranteed loop independent\n"); - WalkerStepLimit -= 1; CanOptimize = false; continue; } @@ -1790,10 +1796,9 @@ struct DSEState { if (!isRemovable(DefI)) return false; - if (StoredConstant && isAllocationFn(DefUO, &TLI)) { - auto *CB = cast<CallBase>(DefUO); - auto *InitC = getInitialValueOfAllocation(CB, &TLI, - StoredConstant->getType()); + if (StoredConstant) { + Constant *InitC = + getInitialValueOfAllocation(DefUO, &TLI, StoredConstant->getType()); // If the clobbering access is LiveOnEntry, no instructions between them // can modify the memory location. 
if (InitC && InitC == StoredConstant) @@ -1931,11 +1936,13 @@ struct DSEState { static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, DominatorTree &DT, PostDominatorTree &PDT, + AssumptionCache &AC, const TargetLibraryInfo &TLI, const LoopInfo &LI) { bool MadeChange = false; - DSEState State(F, AA, MSSA, DT, PDT, TLI, LI); + MSSA.ensureOptimizedUses(); + DSEState State(F, AA, MSSA, DT, PDT, AC, TLI, LI); // For each store: for (unsigned I = 0; I < State.MemDefs.size(); I++) { MemoryDef *KillingDef = State.MemDefs[I]; @@ -2115,9 +2122,10 @@ PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) { DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA(); PostDominatorTree &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); + AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F); LoopInfo &LI = AM.getResult<LoopAnalysis>(F); - bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, TLI, LI); + bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, AC, TLI, LI); #ifdef LLVM_ENABLE_STATS if (AreStatisticsEnabled()) @@ -2157,9 +2165,11 @@ public: MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); PostDominatorTree &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); + AssumptionCache &AC = + getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, TLI, LI); + bool Changed = eliminateDeadStores(F, AA, MSSA, DT, PDT, AC, TLI, LI); #ifdef LLVM_ENABLE_STATS if (AreStatisticsEnabled()) @@ -2183,6 +2193,7 @@ public: AU.addPreserved<MemorySSAWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); AU.addPreserved<LoopInfoWrapperPass>(); + AU.addRequired<AssumptionCacheTracker>(); } }; @@ -2200,6 +2211,7 @@ INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_END(DSELegacyPass, "dse", "Dead Store Elimination", false, false) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/EarlyCSE.cpp index 59b934c16c8a..cf2824954122 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/EarlyCSE.cpp @@ -16,7 +16,6 @@ #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopedHashTable.h" -#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" @@ -30,19 +29,16 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" -#include "llvm/IR/Use.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" @@ -55,7 +51,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include 
"llvm/Transforms/Utils/AssumeBundleBuilder.h" -#include "llvm/Transforms/Utils/GuardUtils.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <deque> @@ -781,6 +776,21 @@ private: return getLoadStorePointerOperand(Inst); } + Type *getValueType() const { + // TODO: handle target-specific intrinsics. + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + switch (II->getIntrinsicID()) { + case Intrinsic::masked_load: + return II->getType(); + case Intrinsic::masked_store: + return II->getArgOperand(0)->getType(); + default: + return nullptr; + } + } + return getLoadStoreType(Inst); + } + bool mayReadFromMemory() const { if (IntrID != 0) return Info.ReadMem; @@ -1162,6 +1172,9 @@ bool EarlyCSE::overridingStores(const ParseMemoryInst &Earlier, "Violated invariant"); if (Earlier.getPointerOperand() != Later.getPointerOperand()) return false; + if (!Earlier.getValueType() || !Later.getValueType() || + Earlier.getValueType() != Later.getValueType()) + return false; if (Earlier.getMatchingId() != Later.getMatchingId()) return false; // At the moment, we don't remove ordered stores, but do remove @@ -1334,7 +1347,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // If the instruction can be simplified (e.g. X+0 = X) then replace it with // its simpler value. - if (Value *V = SimplifyInstruction(&Inst, SQ)) { + if (Value *V = simplifyInstruction(&Inst, SQ)) { LLVM_DEBUG(dbgs() << "EarlyCSE Simplify: " << Inst << " to: " << *V << '\n'); if (!DebugCounter::shouldExecute(CSECounter)) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp index 44017b555769..ad2041cd4253 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp @@ -11,8 +11,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/IR/CFG.h" -#include "llvm/IR/InstrTypes.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/ValueHandle.h" #include "llvm/InitializePasses.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Float2Int.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Float2Int.cpp index a98bb8358aef..56f2a3b3004d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Float2Int.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Float2Int.cpp @@ -11,24 +11,22 @@ // //===----------------------------------------------------------------------===// -#include "llvm/InitializePasses.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Scalar/Float2Int.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "llvm/InitializePasses.h" #include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include <deque> -#include <functional> // For std::function #define DEBUG_TYPE "float2int" @@ -236,116 +234,111 @@ void Float2IntPass::walkBackwards() { } } -// Walk forwards down the list of seen instructions, so we visit defs before -// uses. 
-void Float2IntPass::walkForwards() { - for (auto &It : reverse(SeenInsts)) { - if (It.second != unknownRange()) - continue; +// Calculate result range from operand ranges. +// Return None if the range cannot be calculated yet. +Optional<ConstantRange> Float2IntPass::calcRange(Instruction *I) { + SmallVector<ConstantRange, 4> OpRanges; + for (Value *O : I->operands()) { + if (Instruction *OI = dyn_cast<Instruction>(O)) { + auto OpIt = SeenInsts.find(OI); + assert(OpIt != SeenInsts.end() && "def not seen before use!"); + if (OpIt->second == unknownRange()) + return None; // Wait until operand range has been calculated. + OpRanges.push_back(OpIt->second); + } else if (ConstantFP *CF = dyn_cast<ConstantFP>(O)) { + // Work out if the floating point number can be losslessly represented + // as an integer. + // APFloat::convertToInteger(&Exact) purports to do what we want, but + // the exactness can be too precise. For example, negative zero can + // never be exactly converted to an integer. + // + // Instead, we ask APFloat to round itself to an integral value - this + // preserves sign-of-zero - then compare the result with the original. + // + const APFloat &F = CF->getValueAPF(); + + // First, weed out obviously incorrect values. Non-finite numbers + // can't be represented and neither can negative zero, unless + // we're in fast math mode. + if (!F.isFinite() || + (F.isZero() && F.isNegative() && isa<FPMathOperator>(I) && + !I->hasNoSignedZeros())) + return badRange(); + + APFloat NewF = F; + auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven); + if (Res != APFloat::opOK || NewF != F) + return badRange(); + + // OK, it's representable. Now get it. + APSInt Int(MaxIntegerBW+1, false); + bool Exact; + CF->getValueAPF().convertToInteger(Int, + APFloat::rmNearestTiesToEven, + &Exact); + OpRanges.push_back(ConstantRange(Int)); + } else { + llvm_unreachable("Should have already marked this as badRange!"); + } + } - Instruction *I = It.first; - std::function<ConstantRange(ArrayRef<ConstantRange>)> Op; - switch (I->getOpcode()) { - // FIXME: Handle select and phi nodes. - default: - case Instruction::UIToFP: - case Instruction::SIToFP: - llvm_unreachable("Should have been handled in walkForwards!"); + switch (I->getOpcode()) { + // FIXME: Handle select and phi nodes. 
+ default: + case Instruction::UIToFP: + case Instruction::SIToFP: + llvm_unreachable("Should have been handled in walkForwards!"); - case Instruction::FNeg: - Op = [](ArrayRef<ConstantRange> Ops) { - assert(Ops.size() == 1 && "FNeg is a unary operator!"); - unsigned Size = Ops[0].getBitWidth(); - auto Zero = ConstantRange(APInt::getZero(Size)); - return Zero.sub(Ops[0]); - }; - break; + case Instruction::FNeg: { + assert(OpRanges.size() == 1 && "FNeg is a unary operator!"); + unsigned Size = OpRanges[0].getBitWidth(); + auto Zero = ConstantRange(APInt::getZero(Size)); + return Zero.sub(OpRanges[0]); + } - case Instruction::FAdd: - case Instruction::FSub: - case Instruction::FMul: - Op = [I](ArrayRef<ConstantRange> Ops) { - assert(Ops.size() == 2 && "its a binary operator!"); - auto BinOp = (Instruction::BinaryOps) I->getOpcode(); - return Ops[0].binaryOp(BinOp, Ops[1]); - }; - break; + case Instruction::FAdd: + case Instruction::FSub: + case Instruction::FMul: { + assert(OpRanges.size() == 2 && "its a binary operator!"); + auto BinOp = (Instruction::BinaryOps) I->getOpcode(); + return OpRanges[0].binaryOp(BinOp, OpRanges[1]); + } - // - // Root-only instructions - we'll only see these if they're the - // first node in a walk. - // - case Instruction::FPToUI: - case Instruction::FPToSI: - Op = [I](ArrayRef<ConstantRange> Ops) { - assert(Ops.size() == 1 && "FPTo[US]I is a unary operator!"); - // Note: We're ignoring the casts output size here as that's what the - // caller expects. - auto CastOp = (Instruction::CastOps)I->getOpcode(); - return Ops[0].castOp(CastOp, MaxIntegerBW+1); - }; - break; + // + // Root-only instructions - we'll only see these if they're the + // first node in a walk. + // + case Instruction::FPToUI: + case Instruction::FPToSI: { + assert(OpRanges.size() == 1 && "FPTo[US]I is a unary operator!"); + // Note: We're ignoring the casts output size here as that's what the + // caller expects. + auto CastOp = (Instruction::CastOps)I->getOpcode(); + return OpRanges[0].castOp(CastOp, MaxIntegerBW+1); + } - case Instruction::FCmp: - Op = [](ArrayRef<ConstantRange> Ops) { - assert(Ops.size() == 2 && "FCmp is a binary operator!"); - return Ops[0].unionWith(Ops[1]); - }; - break; - } + case Instruction::FCmp: + assert(OpRanges.size() == 2 && "FCmp is a binary operator!"); + return OpRanges[0].unionWith(OpRanges[1]); + } +} - bool Abort = false; - SmallVector<ConstantRange,4> OpRanges; - for (Value *O : I->operands()) { - if (Instruction *OI = dyn_cast<Instruction>(O)) { - assert(SeenInsts.find(OI) != SeenInsts.end() && - "def not seen before use!"); - OpRanges.push_back(SeenInsts.find(OI)->second); - } else if (ConstantFP *CF = dyn_cast<ConstantFP>(O)) { - // Work out if the floating point number can be losslessly represented - // as an integer. - // APFloat::convertToInteger(&Exact) purports to do what we want, but - // the exactness can be too precise. For example, negative zero can - // never be exactly converted to an integer. - // - // Instead, we ask APFloat to round itself to an integral value - this - // preserves sign-of-zero - then compare the result with the original. - // - const APFloat &F = CF->getValueAPF(); - - // First, weed out obviously incorrect values. Non-finite numbers - // can't be represented and neither can negative zero, unless - // we're in fast math mode. 
- if (!F.isFinite() || - (F.isZero() && F.isNegative() && isa<FPMathOperator>(I) && - !I->hasNoSignedZeros())) { - seen(I, badRange()); - Abort = true; - break; - } +// Walk forwards down the list of seen instructions, so we visit defs before +// uses. +void Float2IntPass::walkForwards() { + std::deque<Instruction *> Worklist; + for (const auto &Pair : SeenInsts) + if (Pair.second == unknownRange()) + Worklist.push_back(Pair.first); - APFloat NewF = F; - auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven); - if (Res != APFloat::opOK || NewF != F) { - seen(I, badRange()); - Abort = true; - break; - } - // OK, it's representable. Now get it. - APSInt Int(MaxIntegerBW+1, false); - bool Exact; - CF->getValueAPF().convertToInteger(Int, - APFloat::rmNearestTiesToEven, - &Exact); - OpRanges.push_back(ConstantRange(Int)); - } else { - llvm_unreachable("Should have already marked this as badRange!"); - } - } + while (!Worklist.empty()) { + Instruction *I = Worklist.back(); + Worklist.pop_back(); - // Reduce the operands' ranges to a single range and return. - if (!Abort) - seen(I, Op(OpRanges)); + if (Optional<ConstantRange> Range = calcRange(I)) + seen(I, *Range); + else + Worklist.push_front(I); // Reprocess later. } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVN.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVN.cpp index 398c93e8758c..783301fe589e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVN.cpp @@ -19,7 +19,6 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" @@ -32,6 +31,7 @@ #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/InstructionPrecedenceTracking.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" @@ -42,12 +42,10 @@ #include "llvm/Analysis/PHITransAddr.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/Config/llvm-config.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" @@ -55,11 +53,9 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" @@ -72,7 +68,6 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" @@ -112,16 +107,16 @@ static cl::opt<bool> GVNEnableLoadInLoopPRE("enable-load-in-loop-pre", cl::init(true)); static cl::opt<bool> GVNEnableSplitBackedgeInLoadPRE("enable-split-backedge-in-load-pre", - cl::init(true)); + cl::init(false)); static cl::opt<bool> GVNEnableMemDep("enable-gvn-memdep", cl::init(true)); static 
cl::opt<uint32_t> MaxNumDeps( - "gvn-max-num-deps", cl::Hidden, cl::init(100), cl::ZeroOrMore, + "gvn-max-num-deps", cl::Hidden, cl::init(100), cl::desc("Max number of dependences to attempt Load PRE (default = 100)")); // This is based on IsValueFullyAvailableInBlockNumSpeculationsMax stat. static cl::opt<uint32_t> MaxBBSpeculations( - "gvn-max-block-speculations", cl::Hidden, cl::init(600), cl::ZeroOrMore, + "gvn-max-block-speculations", cl::Hidden, cl::init(600), cl::desc("Max number of blocks we're willing to speculate on (and recurse " "into) when deducing if a value is fully available or not in GVN " "(default = 600)")); @@ -129,6 +124,8 @@ static cl::opt<uint32_t> MaxBBSpeculations( struct llvm::GVNPass::Expression { uint32_t opcode; bool commutative = false; + // The type is not necessarily the result type of the expression, it may be + // any additional type needed to disambiguate the expression. Type *type = nullptr; SmallVector<uint32_t, 4> varargs; @@ -178,70 +175,88 @@ template <> struct DenseMapInfo<GVNPass::Expression> { /// implicitly associated with a rematerialization point which is the /// location of the instruction from which it was formed. struct llvm::gvn::AvailableValue { - enum ValType { + enum class ValType { SimpleVal, // A simple offsetted value that is accessed. LoadVal, // A value produced by a load. MemIntrin, // A memory intrinsic which is loaded from. - UndefVal // A UndefValue representing a value from dead block (which + UndefVal, // A UndefValue representing a value from dead block (which // is not yet physically removed from the CFG). + SelectVal, // A pointer select which is loaded from and for which the load + // can be replace by a value select. }; - /// V - The value that is live out of the block. - PointerIntPair<Value *, 2, ValType> Val; + /// Val - The value that is live out of the block. + Value *Val; + /// Kind of the live-out value. + ValType Kind; /// Offset - The byte offset in Val that is interesting for the load query. 
unsigned Offset = 0; static AvailableValue get(Value *V, unsigned Offset = 0) { AvailableValue Res; - Res.Val.setPointer(V); - Res.Val.setInt(SimpleVal); + Res.Val = V; + Res.Kind = ValType::SimpleVal; Res.Offset = Offset; return Res; } static AvailableValue getMI(MemIntrinsic *MI, unsigned Offset = 0) { AvailableValue Res; - Res.Val.setPointer(MI); - Res.Val.setInt(MemIntrin); + Res.Val = MI; + Res.Kind = ValType::MemIntrin; Res.Offset = Offset; return Res; } static AvailableValue getLoad(LoadInst *Load, unsigned Offset = 0) { AvailableValue Res; - Res.Val.setPointer(Load); - Res.Val.setInt(LoadVal); + Res.Val = Load; + Res.Kind = ValType::LoadVal; Res.Offset = Offset; return Res; } static AvailableValue getUndef() { AvailableValue Res; - Res.Val.setPointer(nullptr); - Res.Val.setInt(UndefVal); + Res.Val = nullptr; + Res.Kind = ValType::UndefVal; Res.Offset = 0; return Res; } - bool isSimpleValue() const { return Val.getInt() == SimpleVal; } - bool isCoercedLoadValue() const { return Val.getInt() == LoadVal; } - bool isMemIntrinValue() const { return Val.getInt() == MemIntrin; } - bool isUndefValue() const { return Val.getInt() == UndefVal; } + static AvailableValue getSelect(SelectInst *Sel) { + AvailableValue Res; + Res.Val = Sel; + Res.Kind = ValType::SelectVal; + Res.Offset = 0; + return Res; + } + + bool isSimpleValue() const { return Kind == ValType::SimpleVal; } + bool isCoercedLoadValue() const { return Kind == ValType::LoadVal; } + bool isMemIntrinValue() const { return Kind == ValType::MemIntrin; } + bool isUndefValue() const { return Kind == ValType::UndefVal; } + bool isSelectValue() const { return Kind == ValType::SelectVal; } Value *getSimpleValue() const { assert(isSimpleValue() && "Wrong accessor"); - return Val.getPointer(); + return Val; } LoadInst *getCoercedLoadValue() const { assert(isCoercedLoadValue() && "Wrong accessor"); - return cast<LoadInst>(Val.getPointer()); + return cast<LoadInst>(Val); } MemIntrinsic *getMemIntrinValue() const { assert(isMemIntrinValue() && "Wrong accessor"); - return cast<MemIntrinsic>(Val.getPointer()); + return cast<MemIntrinsic>(Val); + } + + SelectInst *getSelectValue() const { + assert(isSelectValue() && "Wrong accessor"); + return cast<SelectInst>(Val); } /// Emit code at the specified insertion point to adjust the value defined @@ -275,6 +290,10 @@ struct llvm::gvn::AvailableValueInBlock { return get(BB, AvailableValue::getUndef()); } + static AvailableValueInBlock getSelect(BasicBlock *BB, SelectInst *Sel) { + return get(BB, AvailableValue::getSelect(Sel)); + } + /// Emit code at the end of this block to adjust the value defined here to /// the specified type. This handles various coercion cases. Value *MaterializeAdjustedValue(LoadInst *Load, GVNPass &gvn) const { @@ -379,6 +398,39 @@ GVNPass::ValueTable::createExtractvalueExpr(ExtractValueInst *EI) { return e; } +GVNPass::Expression GVNPass::ValueTable::createGEPExpr(GetElementPtrInst *GEP) { + Expression E; + Type *PtrTy = GEP->getType()->getScalarType(); + const DataLayout &DL = GEP->getModule()->getDataLayout(); + unsigned BitWidth = DL.getIndexTypeSizeInBits(PtrTy); + MapVector<Value *, APInt> VariableOffsets; + APInt ConstantOffset(BitWidth, 0); + if (PtrTy->isOpaquePointerTy() && + GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) { + // For opaque pointers, convert into offset representation, to recognize + // equivalent address calculations that use different type encoding. 
+ LLVMContext &Context = GEP->getContext(); + E.opcode = GEP->getOpcode(); + E.type = nullptr; + E.varargs.push_back(lookupOrAdd(GEP->getPointerOperand())); + for (const auto &Pair : VariableOffsets) { + E.varargs.push_back(lookupOrAdd(Pair.first)); + E.varargs.push_back(lookupOrAdd(ConstantInt::get(Context, Pair.second))); + } + if (!ConstantOffset.isZero()) + E.varargs.push_back( + lookupOrAdd(ConstantInt::get(Context, ConstantOffset))); + } else { + // If converting to offset representation fails (for typed pointers and + // scalable vectors), fall back to type-based implementation: + E.opcode = GEP->getOpcode(); + E.type = GEP->getSourceElementType(); + for (Use &Op : GEP->operands()) + E.varargs.push_back(lookupOrAdd(Op)); + } + return E; +} + //===----------------------------------------------------------------------===// // ValueTable External Functions //===----------------------------------------------------------------------===// @@ -562,9 +614,11 @@ uint32_t GVNPass::ValueTable::lookupOrAdd(Value *V) { case Instruction::InsertElement: case Instruction::ShuffleVector: case Instruction::InsertValue: - case Instruction::GetElementPtr: exp = createExpr(I); break; + case Instruction::GetElementPtr: + exp = createGEPExpr(cast<GetElementPtrInst>(I)); + break; case Instruction::ExtractValue: exp = createExtractvalueExpr(cast<ExtractValueInst>(I)); break; @@ -639,24 +693,24 @@ void GVNPass::ValueTable::verifyRemoved(const Value *V) const { //===----------------------------------------------------------------------===// bool GVNPass::isPREEnabled() const { - return Options.AllowPRE.getValueOr(GVNEnablePRE); + return Options.AllowPRE.value_or(GVNEnablePRE); } bool GVNPass::isLoadPREEnabled() const { - return Options.AllowLoadPRE.getValueOr(GVNEnableLoadPRE); + return Options.AllowLoadPRE.value_or(GVNEnableLoadPRE); } bool GVNPass::isLoadInLoopPREEnabled() const { - return Options.AllowLoadInLoopPRE.getValueOr(GVNEnableLoadInLoopPRE); + return Options.AllowLoadInLoopPRE.value_or(GVNEnableLoadInLoopPRE); } bool GVNPass::isLoadPRESplitBackedgeEnabled() const { - return Options.AllowLoadPRESplitBackedge.getValueOr( + return Options.AllowLoadPRESplitBackedge.value_or( GVNEnableSplitBackedgeInLoadPRE); } bool GVNPass::isMemDepEnabled() const { - return Options.AllowMemDep.getValueOr(GVNEnableMemDep); + return Options.AllowMemDep.value_or(GVNEnableMemDep); } PreservedAnalyses GVNPass::run(Function &F, FunctionAnalysisManager &AM) { @@ -897,6 +951,17 @@ ConstructSSAForLoadSet(LoadInst *Load, return SSAUpdate.GetValueInMiddleOfBlock(Load->getParent()); } +static LoadInst *findDominatingLoad(Value *Ptr, Type *LoadTy, SelectInst *Sel, + DominatorTree &DT) { + for (Value *U : Ptr->users()) { + auto *LI = dyn_cast<LoadInst>(U); + if (LI && LI->getType() == LoadTy && LI->getParent() == Sel->getParent() && + DT.dominates(LI, Sel)) + return LI; + } + return nullptr; +} + Value *AvailableValue::MaterializeAdjustedValue(LoadInst *Load, Instruction *InsertPt, GVNPass &gvn) const { @@ -937,6 +1002,17 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *Load, << " " << *getMemIntrinValue() << '\n' << *Res << '\n' << "\n\n\n"); + } else if (isSelectValue()) { + // Introduce a new value select for a load from an eligible pointer select. 
+ SelectInst *Sel = getSelectValue(); + LoadInst *L1 = findDominatingLoad(Sel->getOperand(1), LoadTy, Sel, + gvn.getDominatorTree()); + LoadInst *L2 = findDominatingLoad(Sel->getOperand(2), LoadTy, Sel, + gvn.getDominatorTree()); + assert(L1 && L2 && + "must be able to obtain dominating loads for both value operands of " + "the select"); + Res = SelectInst::Create(Sel->getCondition(), L1, L2, "", Sel); } else { llvm_unreachable("Should not materialize value from dead block"); } @@ -1023,8 +1099,54 @@ static void reportMayClobberedLoad(LoadInst *Load, MemDepResult DepInfo, ORE->emit(R); } +/// Check if a load from pointer-select \p Address in \p DepBB can be converted +/// to a value select. The following conditions need to be satisfied: +/// 1. The pointer select (\p Address) must be defined in \p DepBB. +/// 2. Both value operands of the pointer select must be loaded in the same +/// basic block, before the pointer select. +/// 3. There must be no instructions between the found loads and \p End that may +/// clobber the loads. +static Optional<AvailableValue> +tryToConvertLoadOfPtrSelect(BasicBlock *DepBB, BasicBlock::iterator End, + Value *Address, Type *LoadTy, DominatorTree &DT, + AAResults *AA) { + + auto *Sel = dyn_cast_or_null<SelectInst>(Address); + if (!Sel || DepBB != Sel->getParent()) + return None; + + LoadInst *L1 = findDominatingLoad(Sel->getOperand(1), LoadTy, Sel, DT); + LoadInst *L2 = findDominatingLoad(Sel->getOperand(2), LoadTy, Sel, DT); + if (!L1 || !L2) + return None; + + // Ensure there are no accesses that may modify the locations referenced by + // either L1 or L2 between L1, L2 and the specified End iterator. + Instruction *EarlierLoad = L1->comesBefore(L2) ? L1 : L2; + MemoryLocation L1Loc = MemoryLocation::get(L1); + MemoryLocation L2Loc = MemoryLocation::get(L2); + if (any_of(make_range(EarlierLoad->getIterator(), End), [&](Instruction &I) { + return isModSet(AA->getModRefInfo(&I, L1Loc)) || + isModSet(AA->getModRefInfo(&I, L2Loc)); + })) + return None; + + return AvailableValue::getSelect(Sel); +} + bool GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, Value *Address, AvailableValue &Res) { + if (!DepInfo.isDef() && !DepInfo.isClobber()) { + assert(isa<SelectInst>(Address)); + if (auto R = tryToConvertLoadOfPtrSelect( + Load->getParent(), Load->getIterator(), Address, Load->getType(), + getDominatorTree(), getAliasAnalysis())) { + Res = *R; + return true; + } + return false; + } + assert((DepInfo.isDef() || DepInfo.isClobber()) && "expected a local dependence"); assert(Load->isUnordered() && "rules below are incorrect for ordered access"); @@ -1066,9 +1188,7 @@ bool GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, canCoerceMustAliasedValueToLoad(DepLoad, LoadType, DL)) { const auto ClobberOff = MD->getClobberOffset(DepLoad); // GVN has no deal with a negative offset. - Offset = (ClobberOff == None || ClobberOff.getValue() < 0) - ? -1 - : ClobberOff.getValue(); + Offset = (ClobberOff == None || *ClobberOff < 0) ? -1 : *ClobberOff; } if (Offset == -1) Offset = @@ -1092,6 +1212,7 @@ bool GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, } } } + // Nothing known about this clobber, have to be conservative LLVM_DEBUG( // fast print dep, using operator<< on instruction is too slow. 
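// Illustrative sketch, not part of the patch: a hypothetical C++ input in which
// the load-of-pointer-select handling added above (getSelect/findDominatingLoad/
// tryToConvertLoadOfPtrSelect) can apply. Both select operands are loaded in the
// same block before the pointer select, and nothing in between may write memory,
// so the load through the select can be rewritten as a select of the two values.
int pick(int *A, int *B, bool C) {
  int X = *A;          // dominating load of the first select operand
  int Y = *B;          // dominating load of the second select operand
  int *P = C ? A : B;  // pointer select that feeds the load below
  return *P;           // GVN may materialize this load as: C ? X : Y
}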
@@ -1111,12 +1232,11 @@ bool GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, return true; } - if (isAllocationFn(DepInst, TLI)) - if (auto *InitVal = getInitialValueOfAllocation(cast<CallBase>(DepInst), - TLI, Load->getType())) { - Res = AvailableValue::get(InitVal); - return true; - } + if (Constant *InitVal = + getInitialValueOfAllocation(DepInst, TLI, Load->getType())) { + Res = AvailableValue::get(InitVal); + return true; + } if (StoreInst *S = dyn_cast<StoreInst>(DepInst)) { // Reject loads and stores that are to the same address but are of @@ -1176,16 +1296,23 @@ void GVNPass::AnalyzeLoadAvailability(LoadInst *Load, LoadDepVect &Deps, continue; } - if (!DepInfo.isDef() && !DepInfo.isClobber()) { - UnavailableBlocks.push_back(DepBB); - continue; - } - // The address being loaded in this non-local block may not be the same as // the pointer operand of the load if PHI translation occurs. Make sure // to consider the right address. Value *Address = Deps[i].getAddress(); + if (!DepInfo.isDef() && !DepInfo.isClobber()) { + if (auto R = tryToConvertLoadOfPtrSelect( + DepBB, DepBB->end(), Address, Load->getType(), getDominatorTree(), + getAliasAnalysis())) { + ValuesPerBlock.push_back( + AvailableValueInBlock::get(DepBB, std::move(*R))); + continue; + } + UnavailableBlocks.push_back(DepBB); + continue; + } + AvailableValue AV; if (AnalyzeLoadAvailability(Load, DepInfo, Address, AV)) { // subtlety: because we know this was a non-local dependency, we know @@ -1923,8 +2050,9 @@ bool GVNPass::processLoad(LoadInst *L) { if (Dep.isNonLocal()) return processNonLocalLoad(L); + Value *Address = L->getPointerOperand(); // Only handle the local case below - if (!Dep.isDef() && !Dep.isClobber()) { + if (!Dep.isDef() && !Dep.isClobber() && !isa<SelectInst>(Address)) { // This might be a NonFuncLocal or an Unknown LLVM_DEBUG( // fast print dep, using operator<< on instruction is too slow. @@ -1934,7 +2062,7 @@ bool GVNPass::processLoad(LoadInst *L) { } AvailableValue AV; - if (AnalyzeLoadAvailability(L, Dep, L->getPointerOperand(), AV)) { + if (AnalyzeLoadAvailability(L, Dep, Address, AV)) { Value *AvailableValue = AV.MaterializeAdjustedValue(L, L, *this); // Replace the load! @@ -2324,7 +2452,7 @@ bool GVNPass::processInstruction(Instruction *I) { // example if it determines that %y is equal to %x then the instruction // "%z = and i32 %x, %y" becomes "%z = and i32 %x, %x" which we now simplify. const DataLayout &DL = I->getModule()->getDataLayout(); - if (Value *V = SimplifyInstruction(I, {DL, TLI, DT, AC})) { + if (Value *V = simplifyInstruction(I, {DL, TLI, DT, AC})) { bool Changed = false; if (!I->use_empty()) { // Simplification can cause a special instruction to become not special. 
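// Illustrative sketch, not part of the patch: the AnalyzeLoadAvailability change
// above now queries getInitialValueOfAllocation directly, so a load whose
// dependence is an allocation with a known initial value (e.g. a calloc-like
// call) can be replaced by that value when no intervening instruction clobbers
// the location. Function and variable names here are hypothetical.
#include <cstdlib>
int firstElement() {
  int *P = static_cast<int *>(std::calloc(4, sizeof(int)));
  return P ? P[0] : 0;  // the load P[0] is known to yield the zero initializer
}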
@@ -2491,6 +2619,7 @@ bool GVNPass::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, unsigned Iteration = 0; while (ShouldContinue) { LLVM_DEBUG(dbgs() << "GVN iteration: " << Iteration << "\n"); + (void) Iteration; ShouldContinue = iterateOnFunction(F); Changed |= ShouldContinue; ++Iteration; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNHoist.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNHoist.cpp index fdc3afd9348a..6cdc671ddb64 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNHoist.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNHoist.cpp @@ -54,11 +54,9 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" -#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Use.h" @@ -126,7 +124,7 @@ using HoistingPointInfo = std::pair<BasicBlock *, SmallVecInsn>; using HoistingPointList = SmallVector<HoistingPointInfo, 4>; // A map from a pair of VNs to all the instructions with those VNs. -using VNType = std::pair<unsigned, unsigned>; +using VNType = std::pair<unsigned, uintptr_t>; using VNtoInsns = DenseMap<VNType, SmallVector<Instruction *, 4>>; @@ -161,7 +159,7 @@ using InValuesType = // An invalid value number Used when inserting a single value number into // VNtoInsns. -enum : unsigned { InvalidVN = ~2U }; +enum : uintptr_t { InvalidVN = ~(uintptr_t)2 }; // Records all scalar instructions candidate for code hoisting. class InsnInfo { @@ -187,7 +185,9 @@ public: void insert(LoadInst *Load, GVNPass::ValueTable &VN) { if (Load->isSimple()) { unsigned V = VN.lookupOrAdd(Load->getPointerOperand()); - VNtoLoads[{V, InvalidVN}].push_back(Load); + // With opaque pointers we may have loads from the same pointer with + // different result types, which should be disambiguated. + VNtoLoads[{V, (uintptr_t)Load->getType()}].push_back(Load); } } @@ -261,7 +261,9 @@ public: GVNHoist(DominatorTree *DT, PostDominatorTree *PDT, AliasAnalysis *AA, MemoryDependenceResults *MD, MemorySSA *MSSA) : DT(DT), PDT(PDT), AA(AA), MD(MD), MSSA(MSSA), - MSSAUpdater(std::make_unique<MemorySSAUpdater>(MSSA)) {} + MSSAUpdater(std::make_unique<MemorySSAUpdater>(MSSA)) { + MSSA->ensureOptimizedUses(); + } bool run(Function &F); @@ -1147,6 +1149,8 @@ std::pair<unsigned, unsigned> GVNHoist::hoist(HoistingPointList &HPL) { DFSNumber[Repl] = DFSNumber[Last]++; } + // Drop debug location as per debug info update guide. 
+ Repl->dropLocation(); NR += removeAndReplace(InstructionsToHoist, Repl, DestBB, MoveAccess); if (isa<LoadInst>(Repl)) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNSink.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNSink.cpp index e612a82fc89a..720b8e71fd56 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNSink.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNSink.cpp @@ -35,7 +35,6 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/None.h" @@ -45,7 +44,6 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -383,6 +381,8 @@ public: } }; +using BasicBlocksSet = SmallPtrSet<const BasicBlock *, 32>; + class ValueTable { DenseMap<Value *, uint32_t> ValueNumbering; DenseMap<GVNExpression::Expression *, uint32_t> ExpressionNumbering; @@ -390,6 +390,7 @@ class ValueTable { BumpPtrAllocator Allocator; ArrayRecycler<Value *> Recycler; uint32_t nextValueNumber = 1; + BasicBlocksSet ReachableBBs; /// Create an expression for I based on its opcode and its uses. If I /// touches or reads memory, the expression is also based upon its memory @@ -421,6 +422,11 @@ class ValueTable { public: ValueTable() = default; + /// Set basic blocks reachable from entry block. + void setReachableBBs(const BasicBlocksSet &ReachableBBs) { + this->ReachableBBs = ReachableBBs; + } + /// Returns the value number for the specified value, assigning /// it a new number if it did not have one before. uint32_t lookupOrAdd(Value *V) { @@ -434,6 +440,9 @@ public: } Instruction *I = cast<Instruction>(V); + if (!ReachableBBs.contains(I->getParent())) + return ~0U; + InstructionUseExpr *exp = nullptr; switch (I->getOpcode()) { case Instruction::Load: @@ -570,6 +579,7 @@ public: unsigned NumSunk = 0; ReversePostOrderTraversal<Function*> RPOT(&F); + VN.setReachableBBs(BasicBlocksSet(RPOT.begin(), RPOT.end())); for (auto *N : RPOT) NumSunk += sinkBB(N); @@ -648,12 +658,7 @@ Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking( VNums[N]++; } unsigned VNumToSink = - std::max_element(VNums.begin(), VNums.end(), - [](const std::pair<uint32_t, unsigned> &I, - const std::pair<uint32_t, unsigned> &J) { - return I.second < J.second; - }) - ->first; + std::max_element(VNums.begin(), VNums.end(), llvm::less_second())->first; if (VNums[VNumToSink] == 1) // Can't sink anything! @@ -776,12 +781,9 @@ unsigned GVNSink::sinkBB(BasicBlock *BBEnd) { unsigned NumOrigPreds = Preds.size(); // We can only sink instructions through unconditional branches. 
- for (auto I = Preds.begin(); I != Preds.end();) { - if ((*I)->getTerminator()->getNumSuccessors() != 1) - I = Preds.erase(I); - else - ++I; - } + llvm::erase_if(Preds, [](BasicBlock *BB) { + return BB->getTerminator()->getNumSuccessors() != 1; + }); LockstepReverseIterator LRI(Preds); SmallVector<SinkingInstructionCandidate, 4> Candidates; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GuardWidening.cpp index 82b81003ef21..af6062d142f0 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GuardWidening.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GuardWidening.cpp @@ -42,7 +42,6 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" @@ -496,6 +495,8 @@ void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) const { makeAvailableAt(Op, Loc); Inst->moveBefore(Loc); + // If we moved instruction before guard we must clean poison generating flags. + Inst->dropPoisonGeneratingFlags(); } bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/IVUsersPrinter.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/IVUsersPrinter.cpp index e2022aba97c4..26f2db183fbf 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/IVUsersPrinter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/IVUsersPrinter.cpp @@ -8,7 +8,6 @@ #include "llvm/Transforms/Scalar/IVUsersPrinter.h" #include "llvm/Analysis/IVUsers.h" -#include "llvm/Support/Debug.h" using namespace llvm; #define DEBUG_TYPE "iv-users" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index ceb03eb17f6d..e977dd18be9f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -25,10 +25,7 @@ #include "llvm/Transforms/Scalar/IndVarSimplify.h" #include "llvm/ADT/APFloat.h" -#include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -74,11 +71,9 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -387,7 +382,7 @@ bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { RecursivelyDeleteTriviallyDeadInstructions(Compare, TLI, MSSAU.get()); // Delete the old floating point increment. 
- Incr->replaceAllUsesWith(UndefValue::get(Incr->getType())); + Incr->replaceAllUsesWith(PoisonValue::get(Incr->getType())); RecursivelyDeleteTriviallyDeadInstructions(Incr, TLI, MSSAU.get()); // If the FP induction variable still has uses, this is because something else @@ -605,10 +600,10 @@ bool IndVarSimplify::simplifyAndExtend(Loop *L, Intrinsic::getName(Intrinsic::experimental_guard)); bool HasGuards = GuardDecl && !GuardDecl->use_empty(); - SmallVector<PHINode*, 8> LoopPhis; - for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) { - LoopPhis.push_back(cast<PHINode>(I)); - } + SmallVector<PHINode *, 8> LoopPhis; + for (PHINode &PN : L->getHeader()->phis()) + LoopPhis.push_back(&PN); + // Each round of simplification iterates through the SimplifyIVUsers worklist // for all current phis, then determines whether any IVs can be // widened. Widening adds new phis to LoopPhis, inducing another round of diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 0e5653eeb7d5..799669a19796 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -56,8 +56,6 @@ #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/IR/BasicBlock.h" @@ -1411,12 +1409,12 @@ bool LoopConstrainer::run() { bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate; Optional<SubRanges> MaybeSR = calculateSubRanges(IsSignedPredicate); - if (!MaybeSR.hasValue()) { + if (!MaybeSR) { LLVM_DEBUG(dbgs() << "irce: could not compute subranges\n"); return false; } - SubRanges SR = MaybeSR.getValue(); + SubRanges SR = *MaybeSR; bool Increasing = MainLoopStructure.IndVarIncreasing; IntegerType *IVTy = cast<IntegerType>(Range.getBegin()->getType()); @@ -1429,9 +1427,9 @@ bool LoopConstrainer::run() { // constructor. ClonedLoop PreLoop, PostLoop; bool NeedsPreLoop = - Increasing ? SR.LowLimit.hasValue() : SR.HighLimit.hasValue(); + Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value(); bool NeedsPostLoop = - Increasing ? SR.HighLimit.hasValue() : SR.LowLimit.hasValue(); + Increasing ? 
SR.HighLimit.has_value() : SR.LowLimit.has_value(); Value *ExitPreLoopAt = nullptr; Value *ExitMainLoopAt = nullptr; @@ -1710,7 +1708,7 @@ IntersectSignedRange(ScalarEvolution &SE, const InductiveRangeCheck::Range &R2) { if (R2.isEmpty(SE, /* IsSigned */ true)) return None; - if (!R1.hasValue()) + if (!R1) return R2; auto &R1Value = R1.getValue(); // We never return empty ranges from this function, and R1 is supposed to be @@ -1739,7 +1737,7 @@ IntersectUnsignedRange(ScalarEvolution &SE, const InductiveRangeCheck::Range &R2) { if (R2.isEmpty(SE, /* IsSigned */ false)) return None; - if (!R1.hasValue()) + if (!R1) return R2; auto &R1Value = R1.getValue(); // We never return empty ranges from this function, and R1 is supposed to be @@ -1763,10 +1761,14 @@ IntersectUnsignedRange(ScalarEvolution &SE, } PreservedAnalyses IRCEPass::run(Function &F, FunctionAnalysisManager &AM) { - auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); - auto &BPI = AM.getResult<BranchProbabilityAnalysis>(F); LoopInfo &LI = AM.getResult<LoopAnalysis>(F); + // There are no loops in the function. Return before computing other expensive + // analyses. + if (LI.empty()) + return PreservedAnalyses::all(); + auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); + auto &BPI = AM.getResult<BranchProbabilityAnalysis>(F); // Get BFI analysis result on demand. Please note that modification of // CFG invalidates this analysis and we should handle it. @@ -1854,7 +1856,7 @@ InductiveRangeCheckElimination::isProfitableToTransform(const Loop &L, LoopStructure &LS) { if (SkipProfitabilityChecks) return true; - if (GetBFI.hasValue()) { + if (GetBFI) { BlockFrequencyInfo &BFI = (*GetBFI)(); uint64_t hFreq = BFI.getBlockFreq(LS.Header).getFrequency(); uint64_t phFreq = BFI.getBlockFreq(L.getLoopPreheader()).getFrequency(); @@ -1920,12 +1922,12 @@ bool InductiveRangeCheckElimination::run( const char *FailureReason = nullptr; Optional<LoopStructure> MaybeLoopStructure = LoopStructure::parseLoopStructure(SE, *L, FailureReason); - if (!MaybeLoopStructure.hasValue()) { + if (!MaybeLoopStructure) { LLVM_DEBUG(dbgs() << "irce: could not parse loop structure: " << FailureReason << "\n";); return false; } - LoopStructure LS = MaybeLoopStructure.getValue(); + LoopStructure LS = *MaybeLoopStructure; if (!isProfitableToTransform(*L, LS)) return false; const SCEVAddRecExpr *IndVar = @@ -1946,10 +1948,10 @@ bool InductiveRangeCheckElimination::run( for (InductiveRangeCheck &IRC : RangeChecks) { auto Result = IRC.computeSafeIterationSpace(SE, IndVar, LS.IsSignedPredicate); - if (Result.hasValue()) { + if (Result) { auto MaybeSafeIterRange = IntersectRange(SE, SafeIterRange, Result.getValue()); - if (MaybeSafeIterRange.hasValue()) { + if (MaybeSafeIterRange) { assert( !MaybeSafeIterRange.getValue().isEmpty(SE, LS.IsSignedPredicate) && "We should never return empty ranges!"); @@ -1959,7 +1961,7 @@ bool InductiveRangeCheckElimination::run( } } - if (!SafeIterRange.hasValue()) + if (!SafeIterRange) return false; LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT, diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp index ddc747a2ca29..5eefde2e37a1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -92,8 +92,6 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include 
"llvm/ADT/DenseSet.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AssumptionCache.h" @@ -182,7 +180,7 @@ public: class InferAddressSpacesImpl { AssumptionCache &AC; - DominatorTree *DT = nullptr; + const DominatorTree *DT = nullptr; const TargetTransformInfo *TTI = nullptr; const DataLayout *DL = nullptr; @@ -213,10 +211,11 @@ class InferAddressSpacesImpl { // Changes the flat address expressions in function F to point to specific // address spaces if InferredAddrSpace says so. Postorder is the postorder of // all flat expressions in the use-def graph of function F. - bool rewriteWithNewAddressSpaces( - const TargetTransformInfo &TTI, ArrayRef<WeakTrackingVH> Postorder, - const ValueToAddrSpaceMapTy &InferredAddrSpace, - const PredicatedAddrSpaceMapTy &PredicatedAS, Function *F) const; + bool + rewriteWithNewAddressSpaces(ArrayRef<WeakTrackingVH> Postorder, + const ValueToAddrSpaceMapTy &InferredAddrSpace, + const PredicatedAddrSpaceMapTy &PredicatedAS, + Function *F) const; void appendsFlatAddressExpressionToPostorderStack( Value *V, PostorderStackTy &PostorderStack, @@ -240,7 +239,7 @@ class InferAddressSpacesImpl { unsigned getPredicatedAddrSpace(const Value &V, Value *Opnd) const; public: - InferAddressSpacesImpl(AssumptionCache &AC, DominatorTree *DT, + InferAddressSpacesImpl(AssumptionCache &AC, const DominatorTree *DT, const TargetTransformInfo *TTI, unsigned FlatAddrSpace) : AC(AC), DT(DT), TTI(TTI), FlatAddrSpace(FlatAddrSpace) {} bool run(Function &F); @@ -280,15 +279,15 @@ static bool isNoopPtrIntCastPair(const Operator *I2P, const DataLayout &DL, // arithmetic may also be undefined after invalid pointer reinterpret cast. // However, as we confirm through the target hooks that it's a no-op // addrspacecast, it doesn't matter since the bits should be the same. + unsigned P2IOp0AS = P2I->getOperand(0)->getType()->getPointerAddressSpace(); + unsigned I2PAS = I2P->getType()->getPointerAddressSpace(); return CastInst::isNoopCast(Instruction::CastOps(I2P->getOpcode()), I2P->getOperand(0)->getType(), I2P->getType(), DL) && CastInst::isNoopCast(Instruction::CastOps(P2I->getOpcode()), P2I->getOperand(0)->getType(), P2I->getType(), DL) && - TTI->isNoopAddrSpaceCast( - P2I->getOperand(0)->getType()->getPointerAddressSpace(), - I2P->getType()->getPointerAddressSpace()); + (P2IOp0AS == I2PAS || TTI->isNoopAddrSpaceCast(P2IOp0AS, I2PAS)); } // Returns true if V is an address expression. @@ -332,8 +331,7 @@ getPointerOperands(const Value &V, const DataLayout &DL, switch (Op.getOpcode()) { case Instruction::PHI: { auto IncomingValues = cast<PHINode>(Op).incoming_values(); - return SmallVector<Value *, 2>(IncomingValues.begin(), - IncomingValues.end()); + return {IncomingValues.begin(), IncomingValues.end()}; } case Instruction::BitCast: case Instruction::AddrSpaceCast: @@ -729,7 +727,7 @@ static Value *cloneConstantExprWithNewAddressSpace( NewOperands.push_back(cast<Constant>(NewOperand)); continue; } - if (auto CExpr = dyn_cast<ConstantExpr>(Operand)) + if (auto *CExpr = dyn_cast<ConstantExpr>(Operand)) if (Value *NewOperand = cloneConstantExprWithNewAddressSpace( CExpr, NewAddrSpace, ValueWithNewAddrSpace, DL, TTI)) { IsNew = true; @@ -741,7 +739,7 @@ static Value *cloneConstantExprWithNewAddressSpace( } // If !IsNew, we will replace the Value with itself. However, replaced values - // are assumed to wrapped in a addrspace cast later so drop it now. 
+ // are assumed to wrapped in an addrspacecast cast later so drop it now. if (!IsNew) return nullptr; @@ -824,8 +822,8 @@ bool InferAddressSpacesImpl::run(Function &F) { // Changes the address spaces of the flat address expressions who are inferred // to point to a specific address space. - return rewriteWithNewAddressSpaces(*TTI, Postorder, InferredAddrSpace, - PredicatedAS, &F); + return rewriteWithNewAddressSpaces(Postorder, InferredAddrSpace, PredicatedAS, + &F); } // Constants need to be tracked through RAUW to handle cases with nested @@ -1013,7 +1011,7 @@ static bool isSimplePointerUseValidToReplace(const TargetTransformInfo &TTI, } /// Update memory intrinsic uses that require more complex processing than -/// simple memory instructions. Thse require re-mangling and may have multiple +/// simple memory instructions. These require re-mangling and may have multiple /// pointer operands. static bool handleMemIntrinsicPtrUse(MemIntrinsic *MI, Value *OldV, Value *NewV) { @@ -1023,8 +1021,7 @@ static bool handleMemIntrinsicPtrUse(MemIntrinsic *MI, Value *OldV, MDNode *NoAliasMD = MI->getMetadata(LLVMContext::MD_noalias); if (auto *MSI = dyn_cast<MemSetInst>(MI)) { - B.CreateMemSet(NewV, MSI->getValue(), MSI->getLength(), - MaybeAlign(MSI->getDestAlignment()), + B.CreateMemSet(NewV, MSI->getValue(), MSI->getLength(), MSI->getDestAlign(), false, // isVolatile TBAA, ScopeMD, NoAliasMD); } else if (auto *MTI = dyn_cast<MemTransferInst>(MI)) { @@ -1107,7 +1104,7 @@ static Value::use_iterator skipToNextUser(Value::use_iterator I, } bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( - const TargetTransformInfo &TTI, ArrayRef<WeakTrackingVH> Postorder, + ArrayRef<WeakTrackingVH> Postorder, const ValueToAddrSpaceMapTy &InferredAddrSpace, const PredicatedAddrSpaceMapTy &PredicatedAS, Function *F) const { // For each address expression to be modified, creates a clone of it with its @@ -1181,7 +1178,7 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( I = skipToNextUser(I, E); if (isSimplePointerUseValidToReplace( - TTI, U, V->getType()->getPointerAddressSpace())) { + *TTI, U, V->getType()->getPointerAddressSpace())) { // If V is used as the pointer operand of a compatible memory operation, // sets the pointer operand to NewV. This replacement does not change // the element type, so the resultant load/store is still valid. @@ -1242,8 +1239,16 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( if (!cast<PointerType>(ASC->getType()) ->hasSameElementTypeAs( cast<PointerType>(NewV->getType()))) { + BasicBlock::iterator InsertPos; + if (Instruction *NewVInst = dyn_cast<Instruction>(NewV)) + InsertPos = std::next(NewVInst->getIterator()); + else if (Instruction *VInst = dyn_cast<Instruction>(V)) + InsertPos = std::next(VInst->getIterator()); + else + InsertPos = ASC->getIterator(); + NewV = CastInst::Create(Instruction::BitCast, NewV, - ASC->getType(), "", ASC); + ASC->getType(), "", &*InsertPos); } ASC->replaceAllUsesWith(NewV); DeadInstructions.push_back(ASC); @@ -1252,12 +1257,18 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( } // Otherwise, replaces the use with flat(NewV). - if (Instruction *Inst = dyn_cast<Instruction>(V)) { + if (Instruction *VInst = dyn_cast<Instruction>(V)) { // Don't create a copy of the original addrspacecast. if (U == V && isa<AddrSpaceCastInst>(V)) continue; - BasicBlock::iterator InsertPos = std::next(Inst->getIterator()); + // Insert the addrspacecast after NewV. 
+ BasicBlock::iterator InsertPos; + if (Instruction *NewVInst = dyn_cast<Instruction>(NewV)) + InsertPos = std::next(NewVInst->getIterator()); + else + InsertPos = std::next(VInst->getIterator()); + while (isa<PHINode>(InsertPos)) ++InsertPos; U.set(new AddrSpaceCastInst(NewV, V->getType(), "", &*InsertPos)); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp index c11d2e4c1d6b..4644905adba3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp @@ -7,21 +7,17 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/InstSimplifyPass.h" -#include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" -#include "llvm/IR/Type.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -55,7 +51,7 @@ static bool runImpl(Function &F, const SimplifyQuery &SQ, DeadInstsInBB.push_back(&I); Changed = true; } else if (!I.use_empty()) { - if (Value *V = SimplifyInstruction(&I, SQ, ORE)) { + if (Value *V = simplifyInstruction(&I, SQ, ORE)) { // Mark all uses for resimplification next time round the loop. for (User *U : I.users()) Next->insert(cast<Instruction>(U)); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpThreading.cpp index a3efad104ca6..5caefc422921 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -56,7 +56,6 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" -#include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" @@ -74,7 +73,6 @@ #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <cassert> -#include <cstddef> #include <cstdint> #include <iterator> #include <memory> @@ -106,11 +104,6 @@ static cl::opt<bool> PrintLVIAfterJumpThreading( cl::desc("Print the LazyValueInfo cache after JumpThreading"), cl::init(false), cl::Hidden); -static cl::opt<bool> JumpThreadingFreezeSelectCond( - "jump-threading-freeze-select-cond", - cl::desc("Freeze the condition when unfolding select"), cl::init(false), - cl::Hidden); - static cl::opt<bool> ThreadAcrossLoopHeaders( "jump-threading-across-loop-headers", cl::desc("Allow JumpThreading to thread across loop headers, for testing"), @@ -140,8 +133,7 @@ namespace { public: static char ID; // Pass identification - JumpThreading(bool InsertFreezeWhenUnfoldingSelect = false, int T = -1) - : FunctionPass(ID), Impl(InsertFreezeWhenUnfoldingSelect, T) { + JumpThreading(int T = -1) : FunctionPass(ID), Impl(T) { initializeJumpThreadingPass(*PassRegistry::getPassRegistry()); } @@ -175,12 +167,11 @@ INITIALIZE_PASS_END(JumpThreading, "jump-threading", "Jump Threading", false, false) // Public interface to the Jump Threading pass -FunctionPass 
*llvm::createJumpThreadingPass(bool InsertFr, int Threshold) { - return new JumpThreading(InsertFr, Threshold); +FunctionPass *llvm::createJumpThreadingPass(int Threshold) { + return new JumpThreading(Threshold); } -JumpThreadingPass::JumpThreadingPass(bool InsertFr, int T) { - InsertFreezeWhenUnfoldingSelect = JumpThreadingFreezeSelectCond | InsertFr; +JumpThreadingPass::JumpThreadingPass(int T) { DefaultBBDupThreshold = (T == -1) ? BBDuplicateThreshold : unsigned(T); } @@ -326,7 +317,7 @@ bool JumpThreading::runOnFunction(Function &F) { std::unique_ptr<BlockFrequencyInfo> BFI; std::unique_ptr<BranchProbabilityInfo> BPI; if (F.hasProfileData()) { - LoopInfo LI{DominatorTree(F)}; + LoopInfo LI{*DT}; BPI.reset(new BranchProbabilityInfo(F, LI, TLI)); BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); } @@ -491,14 +482,16 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, // at the end of block. RAUW unconditionally replaces all uses // including the guards/assumes themselves and the uses before the // guard/assume. -static void replaceFoldableUses(Instruction *Cond, Value *ToVal) { +static bool replaceFoldableUses(Instruction *Cond, Value *ToVal, + BasicBlock *KnownAtEndOfBB) { + bool Changed = false; assert(Cond->getType() == ToVal->getType()); - auto *BB = Cond->getParent(); // We can unconditionally replace all uses in non-local blocks (i.e. uses // strictly dominated by BB), since LVI information is true from the // terminator of BB. - replaceNonLocalUsesWith(Cond, ToVal); - for (Instruction &I : reverse(*BB)) { + if (Cond->getParent() == KnownAtEndOfBB) + Changed |= replaceNonLocalUsesWith(Cond, ToVal); + for (Instruction &I : reverse(*KnownAtEndOfBB)) { // Reached the Cond whose uses we are trying to replace, so there are no // more uses. if (&I == Cond) @@ -507,10 +500,13 @@ static void replaceFoldableUses(Instruction *Cond, Value *ToVal) { // of BB, where we know Cond is ToVal. if (!isGuaranteedToTransferExecutionToSuccessor(&I)) break; - I.replaceUsesOfWith(Cond, ToVal); + Changed |= I.replaceUsesOfWith(Cond, ToVal); } - if (Cond->use_empty() && !Cond->mayHaveSideEffects()) + if (Cond->use_empty() && !Cond->mayHaveSideEffects()) { Cond->eraseFromParent(); + Changed = true; + } + return Changed; } /// Return the cost of duplicating a piece of this block from first non-phi @@ -792,6 +788,7 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl( if (Preference != WantInteger) return false; if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) { + const DataLayout &DL = BO->getModule()->getDataLayout(); PredValueInfoTy LHSVals; computeValueKnownInPredecessorsImpl(BO->getOperand(0), BB, LHSVals, WantInteger, RecursionSet, CxtI); @@ -799,7 +796,8 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl( // Try to use constant folding to simplify the binary operator. 
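The fold in this hunk switches from building a ConstantExpr to calling the DataLayout-aware folder declared in llvm/Analysis/ConstantFolding.h, which simply returns null when the two constants do not fold; getKnownConstant then treats that as "no known constant". A minimal sketch of the new pattern, with the same names as the hunk:

    #include "llvm/Analysis/ConstantFolding.h"

    const DataLayout &DL = BO->getModule()->getDataLayout();
    // May be null if V <op> CI has no constant result.
    Constant *Folded = ConstantFoldBinaryOpOperands(BO->getOpcode(), V, CI, DL);
    if (Constant *KC = getKnownConstant(Folded, WantInteger))
      Result.emplace_back(KC, LHSVal.second);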
for (const auto &LHSVal : LHSVals) { Constant *V = LHSVal.first; - Constant *Folded = ConstantExpr::get(BO->getOpcode(), V, CI); + Constant *Folded = + ConstantFoldBinaryOpOperands(BO->getOpcode(), V, CI, DL); if (Constant *KC = getKnownConstant(Folded, WantInteger)) Result.emplace_back(KC, LHSVal.second); @@ -835,7 +833,7 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl( LHS = CmpLHS->DoPHITranslation(BB, PredBB); RHS = PN->getIncomingValue(i); } - Value *Res = SimplifyCmpInst(Pred, LHS, RHS, {DL}); + Value *Res = simplifyCmpInst(Pred, LHS, RHS, {DL}); if (!Res) { if (!isa<Constant>(RHS)) continue; @@ -1135,34 +1133,21 @@ bool JumpThreadingPass::processBlock(BasicBlock *BB) { return ConstantFolded; } - if (CmpInst *CondCmp = dyn_cast<CmpInst>(CondInst)) { + // Some of the following optimization can safely work on the unfrozen cond. + Value *CondWithoutFreeze = CondInst; + if (auto *FI = dyn_cast<FreezeInst>(CondInst)) + CondWithoutFreeze = FI->getOperand(0); + + if (CmpInst *CondCmp = dyn_cast<CmpInst>(CondWithoutFreeze)) { // If we're branching on a conditional, LVI might be able to determine // it's value at the branch instruction. We only handle comparisons // against a constant at this time. - // TODO: This should be extended to handle switches as well. - BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator()); - Constant *CondConst = dyn_cast<Constant>(CondCmp->getOperand(1)); - if (CondBr && CondConst) { - // We should have returned as soon as we turn a conditional branch to - // unconditional. Because its no longer interesting as far as jump - // threading is concerned. - assert(CondBr->isConditional() && "Threading on unconditional terminator"); - + if (Constant *CondConst = dyn_cast<Constant>(CondCmp->getOperand(1))) { LazyValueInfo::Tristate Ret = LVI->getPredicateAt(CondCmp->getPredicate(), CondCmp->getOperand(0), - CondConst, CondBr, /*UseBlockValue=*/false); + CondConst, BB->getTerminator(), + /*UseBlockValue=*/false); if (Ret != LazyValueInfo::Unknown) { - unsigned ToRemove = Ret == LazyValueInfo::True ? 1 : 0; - unsigned ToKeep = Ret == LazyValueInfo::True ? 0 : 1; - BasicBlock *ToRemoveSucc = CondBr->getSuccessor(ToRemove); - ToRemoveSucc->removePredecessor(BB, true); - BranchInst *UncondBr = - BranchInst::Create(CondBr->getSuccessor(ToKeep), CondBr); - UncondBr->setDebugLoc(CondBr->getDebugLoc()); - ++NumFolds; - CondBr->eraseFromParent(); - if (CondCmp->use_empty()) - CondCmp->eraseFromParent(); // We can safely replace *some* uses of the CondInst if it has // exactly one value as returned by LVI. RAUW is incorrect in the // presence of guards and assumes, that have the `Cond` as the use. This @@ -1170,17 +1155,11 @@ bool JumpThreadingPass::processBlock(BasicBlock *BB) { // at the end of block, but RAUW unconditionally replaces all uses // including the guards/assumes themselves and the uses before the // guard/assume. - else if (CondCmp->getParent() == BB) { - auto *CI = Ret == LazyValueInfo::True ? - ConstantInt::getTrue(CondCmp->getType()) : - ConstantInt::getFalse(CondCmp->getType()); - replaceFoldableUses(CondCmp, CI); - } - DTU->applyUpdatesPermissive( - {{DominatorTree::Delete, BB, ToRemoveSucc}}); - if (HasProfileData) - BPI->eraseBlock(BB); - return true; + auto *CI = Ret == LazyValueInfo::True ? 
+ ConstantInt::getTrue(CondCmp->getType()) : + ConstantInt::getFalse(CondCmp->getType()); + if (replaceFoldableUses(CondCmp, CI, BB)) + return true; } // We did not manage to simplify this branch, try to see whether @@ -1198,11 +1177,7 @@ bool JumpThreadingPass::processBlock(BasicBlock *BB) { // for loads that are used by a switch or by the condition for the branch. If // we see one, check to see if it's partially redundant. If so, insert a PHI // which can then be used to thread the values. - Value *SimplifyValue = CondInst; - - if (auto *FI = dyn_cast<FreezeInst>(SimplifyValue)) - // Look into freeze's operand - SimplifyValue = FI->getOperand(0); + Value *SimplifyValue = CondWithoutFreeze; if (CmpInst *CondCmp = dyn_cast<CmpInst>(SimplifyValue)) if (isa<Constant>(CondCmp->getOperand(1))) @@ -1227,10 +1202,7 @@ bool JumpThreadingPass::processBlock(BasicBlock *BB) { // If this is an otherwise-unfoldable branch on a phi node or freeze(phi) in // the current block, see if we can simplify. - PHINode *PN = dyn_cast<PHINode>( - isa<FreezeInst>(CondInst) ? cast<FreezeInst>(CondInst)->getOperand(0) - : CondInst); - + PHINode *PN = dyn_cast<PHINode>(CondWithoutFreeze); if (PN && PN->getParent() == BB && isa<BranchInst>(BB->getTerminator())) return processBranchOnPHI(PN); @@ -1253,6 +1225,17 @@ bool JumpThreadingPass::processImpliedCondition(BasicBlock *BB) { return false; Value *Cond = BI->getCondition(); + // Assuming that predecessor's branch was taken, if pred's branch condition + // (V) implies Cond, Cond can be either true, undef, or poison. In this case, + // freeze(Cond) is either true or a nondeterministic value. + // If freeze(Cond) has only one use, we can freely fold freeze(Cond) to true + // without affecting other instructions. + auto *FICond = dyn_cast<FreezeInst>(Cond); + if (FICond && FICond->hasOneUse()) + Cond = FICond->getOperand(0); + else + FICond = nullptr; + BasicBlock *CurrentBB = BB; BasicBlock *CurrentPred = BB->getSinglePredecessor(); unsigned Iter = 0; @@ -1269,6 +1252,15 @@ bool JumpThreadingPass::processImpliedCondition(BasicBlock *BB) { bool CondIsTrue = PBI->getSuccessor(0) == CurrentBB; Optional<bool> Implication = isImpliedCondition(PBI->getCondition(), Cond, DL, CondIsTrue); + + // If the branch condition of BB (which is Cond) and CurrentPred are + // exactly the same freeze instruction, Cond can be folded into CondIsTrue. + if (!Implication && FICond && isa<FreezeInst>(PBI->getCondition())) { + if (cast<FreezeInst>(PBI->getCondition())->getOperand(0) == + FICond->getOperand(0)) + Implication = CondIsTrue; + } + if (Implication) { BasicBlock *KeepSucc = BI->getSuccessor(*Implication ? 0 : 1); BasicBlock *RemoveSucc = BI->getSuccessor(*Implication ? 1 : 0); @@ -1277,6 +1269,9 @@ bool JumpThreadingPass::processImpliedCondition(BasicBlock *BB) { UncondBI->setDebugLoc(BI->getDebugLoc()); ++NumFolds; BI->eraseFromParent(); + if (FICond) + FICond->eraseFromParent(); + DTU->applyUpdatesPermissive({{DominatorTree::Delete, BB, RemoveSucc}}); if (HasProfileData) BPI->eraseBlock(BB); @@ -1338,10 +1333,10 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) { combineMetadataForCSE(NLoadI, LoadI, false); }; - // If the returned value is the load itself, replace with an undef. This can + // If the returned value is the load itself, replace with poison. This can // only happen in dead loops. 
if (AvailableVal == LoadI) - AvailableVal = UndefValue::get(LoadI->getType()); + AvailableVal = PoisonValue::get(LoadI->getType()); if (AvailableVal->getType() != LoadI->getType()) AvailableVal = CastInst::CreateBitOrPointerCast( AvailableVal, LoadI->getType(), "", LoadI); @@ -1566,10 +1561,8 @@ findMostPopularDest(BasicBlock *BB, DestPopularity[PredToDest.second]++; // Find the most popular dest. - using VT = decltype(DestPopularity)::value_type; auto MostPopular = std::max_element( - DestPopularity.begin(), DestPopularity.end(), - [](const VT &L, const VT &R) { return L.second < R.second; }); + DestPopularity.begin(), DestPopularity.end(), llvm::less_second()); // Okay, we have finally picked the most popular destination. return MostPopular->first; @@ -1742,9 +1735,8 @@ bool JumpThreadingPass::processThreadableEdges(Value *Cond, BasicBlock *BB, // at the end of block, but RAUW unconditionally replaces all uses // including the guards/assumes themselves and the uses before the // guard/assume. - else if (OnlyVal && OnlyVal != MultipleVal && - CondInst->getParent() == BB) - replaceFoldableUses(CondInst, OnlyVal); + else if (OnlyVal && OnlyVal != MultipleVal) + replaceFoldableUses(CondInst, OnlyVal, BB); } return true; } @@ -2672,7 +2664,7 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred( // If this instruction can be simplified after the operands are updated, // just use the simplified value instead. This frequently happens due to // phi translation. - if (Value *IV = SimplifyInstruction( + if (Value *IV = simplifyInstruction( New, {BB->getModule()->getDataLayout(), TLI, nullptr, nullptr, New})) { ValueMapping[&*BI] = IV; @@ -2912,9 +2904,7 @@ bool JumpThreadingPass::tryToUnfoldSelectInCurrBB(BasicBlock *BB) { continue; // Expand the select. 
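Tying this back to the hunks at the top of JumpThreading.cpp: with the jump-threading-freeze-select-cond option and the InsertFreezeWhenUnfoldingSelect flag removed, unfolding a select now always freezes a condition that is not provably free of undef or poison, since turning a select into control flow must not introduce a branch on a possibly-poison value. The guarded freeze in tryToUnfoldSelectInCurrBB reduces to (same names as the hunk that follows):

    Value *Cond = SI->getCondition();
    if (!isGuaranteedNotToBeUndefOrPoison(Cond, /*AC=*/nullptr, SI))
      Cond = new FreezeInst(Cond, "cond.fr", SI);
    Instruction *Term = SplitBlockAndInsertIfThen(Cond, SI, /*Unreachable=*/false);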
Value *Cond = SI->getCondition(); - if (InsertFreezeWhenUnfoldingSelect && - !isGuaranteedNotToBeUndefOrPoison(Cond, nullptr, SI, - &DTU->getDomTree())) + if (!isGuaranteedNotToBeUndefOrPoison(Cond, nullptr, SI)) Cond = new FreezeInst(Cond, "cond.fr", SI); Instruction *Term = SplitBlockAndInsertIfThen(Cond, SI, false); BasicBlock *SplitBB = SI->getParent(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LICM.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LICM.cpp index 6372ce19f8ee..492f4e40395a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LICM.cpp @@ -37,29 +37,27 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LICM.h" +#include "llvm/ADT/PriorityWorklist.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AliasSetTracker.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" -#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/ConstantFolding.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/LazyBlockFrequencyInfo.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/LoopNestAnalysis.h" #include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/MustExecute.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" @@ -78,7 +76,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" @@ -88,6 +85,11 @@ #include <utility> using namespace llvm; +namespace llvm { +class BlockFrequencyInfo; +class LPMUpdater; +} // namespace llvm + #define DEBUG_TYPE "licm" STATISTIC(NumCreatedBlocks, "Number of blocks created"); @@ -114,8 +116,7 @@ static cl::opt<uint32_t> MaxNumUsesTraversed( // Experimental option to allow imprecision in LICM in pathological cases, in // exchange for faster compile. This is to be removed if MemorySSA starts to -// address the same issue. This flag applies only when LICM uses MemorySSA -// instead on AliasSetTracker. LICM calls MemorySSAWalker's +// address the same issue. LICM calls MemorySSAWalker's // getClobberingMemoryAccess, up to the value of the Cap, getting perfect // accuracy. Afterwards, LICM will call into MemorySSA's getDefiningAccess, // which may not be precise, since optimizeUses is capped. 
The result is @@ -143,35 +144,32 @@ static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, bool LoopNestMode); static void hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, BasicBlock *Dest, ICFLoopSafetyInfo *SafetyInfo, - MemorySSAUpdater *MSSAU, ScalarEvolution *SE, + MemorySSAUpdater &MSSAU, ScalarEvolution *SE, OptimizationRemarkEmitter *ORE); static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, BlockFrequencyInfo *BFI, const Loop *CurLoop, - ICFLoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU, + ICFLoopSafetyInfo *SafetyInfo, MemorySSAUpdater &MSSAU, OptimizationRemarkEmitter *ORE); static bool isSafeToExecuteUnconditionally( Instruction &Inst, const DominatorTree *DT, const TargetLibraryInfo *TLI, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE, const Instruction *CtxI, bool AllowSpeculation); -static bool pointerInvalidatedByLoop(MemoryLocation MemLoc, - AliasSetTracker *CurAST, Loop *CurLoop, - AAResults *AA); -static bool pointerInvalidatedByLoopWithMSSA(MemorySSA *MSSA, MemoryUse *MU, - Loop *CurLoop, Instruction &I, - SinkAndHoistLICMFlags &Flags); -static bool pointerInvalidatedByBlockWithMSSA(BasicBlock &BB, MemorySSA &MSSA, - MemoryUse &MU); +static bool pointerInvalidatedByLoop(MemorySSA *MSSA, MemoryUse *MU, + Loop *CurLoop, Instruction &I, + SinkAndHoistLICMFlags &Flags); +static bool pointerInvalidatedByBlock(BasicBlock &BB, MemorySSA &MSSA, + MemoryUse &MU); static Instruction *cloneInstructionInExitBlock( Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI, - const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU); + const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater &MSSAU); static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, - MemorySSAUpdater *MSSAU); + MemorySSAUpdater &MSSAU); static void moveInstructionBefore(Instruction &I, Instruction &Dest, ICFLoopSafetyInfo &SafetyInfo, - MemorySSAUpdater *MSSAU, ScalarEvolution *SE); + MemorySSAUpdater &MSSAU, ScalarEvolution *SE); static void foreachMemoryAccess(MemorySSA *MSSA, Loop *L, function_ref<void(Instruction *)> Fn); @@ -268,8 +266,8 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, // but ORE cannot be preserved (see comment before the pass definition). OptimizationRemarkEmitter ORE(L.getHeader()->getParent()); - LoopInvariantCodeMotion LICM(LicmMssaOptCap, LicmMssaNoAccForPromotionCap, - LicmAllowSpeculation); + LoopInvariantCodeMotion LICM(Opts.MssaOptCap, Opts.MssaNoAccForPromotionCap, + Opts.AllowSpeculation); if (!LICM.runOnLoop(&L, &AR.AA, &AR.LI, &AR.DT, AR.BFI, &AR.TLI, &AR.TTI, &AR.SE, AR.MSSA, &ORE)) return PreservedAnalyses::all(); @@ -283,6 +281,16 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, return PA; } +void LICMPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<LICMPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + + OS << "<"; + OS << (Opts.AllowSpeculation ? "" : "no-") << "allowspeculation"; + OS << ">"; +} + PreservedAnalyses LNICMPass::run(LoopNest &LN, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { @@ -294,8 +302,8 @@ PreservedAnalyses LNICMPass::run(LoopNest &LN, LoopAnalysisManager &AM, // but ORE cannot be preserved (see comment before the pass definition). 
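The printPipeline overloads added to LICMPass (above) and LNICMPass (just below) print the knobs that both passes now receive through an options struct carried on the pass object instead of reading the file-level cl::opt flags directly. Judging only from the fields referenced in these hunks, the struct looks roughly like the sketch below; the authoritative definition is in llvm/Transforms/Scalar/LICM.h, and the textual pipeline spelling is assumed to match what printPipeline emits (licm<allowspeculation> / licm<no-allowspeculation>):

    struct LICMOptions {               // approximate shape, inferred from Opts.*
      unsigned MssaOptCap;
      unsigned MssaNoAccForPromotionCap;
      bool AllowSpeculation;
    };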
OptimizationRemarkEmitter ORE(LN.getParent()); - LoopInvariantCodeMotion LICM(LicmMssaOptCap, LicmMssaNoAccForPromotionCap, - LicmAllowSpeculation); + LoopInvariantCodeMotion LICM(Opts.MssaOptCap, Opts.MssaNoAccForPromotionCap, + Opts.AllowSpeculation); Loop &OutermostLoop = LN.getOutermostLoop(); bool Changed = LICM.runOnLoop(&OutermostLoop, &AR.AA, &AR.LI, &AR.DT, AR.BFI, @@ -313,6 +321,16 @@ PreservedAnalyses LNICMPass::run(LoopNest &LN, LoopAnalysisManager &AM, return PA; } +void LNICMPass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<LNICMPass> *>(this)->printPipeline( + OS, MapClassName2PassName); + + OS << "<"; + OS << (Opts.AllowSpeculation ? "" : "no-") << "allowspeculation"; + OS << ">"; +} + char LegacyLICMPass::ID = 0; INITIALIZE_PASS_BEGIN(LegacyLICMPass, "licm", "Loop Invariant Code Motion", false, false) @@ -372,6 +390,7 @@ bool LoopInvariantCodeMotion::runOnLoop( bool Changed = false; assert(L->isLCSSAForm(*DT) && "Loop is not in LCSSA form."); + MSSA->ensureOptimizedUses(); // If this loop has metadata indicating that LICM is not to be performed then // just exit. @@ -418,14 +437,14 @@ bool LoopInvariantCodeMotion::runOnLoop( if (L->hasDedicatedExits()) Changed |= LoopNestMode ? sinkRegionForLoopNest(DT->getNode(L->getHeader()), AA, LI, - DT, BFI, TLI, TTI, L, &MSSAU, + DT, BFI, TLI, TTI, L, MSSAU, &SafetyInfo, Flags, ORE) : sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, BFI, - TLI, TTI, L, &MSSAU, &SafetyInfo, Flags, ORE); + TLI, TTI, L, MSSAU, &SafetyInfo, Flags, ORE); Flags.setIsSink(false); if (Preheader) Changed |= hoistRegion(DT->getNode(L->getHeader()), AA, LI, DT, BFI, TLI, L, - &MSSAU, SE, &SafetyInfo, Flags, ORE, LoopNestMode, + MSSAU, SE, &SafetyInfo, Flags, ORE, LoopNestMode, LicmAllowSpeculation); // Now that all loop invariants have been removed from the loop, promote any @@ -459,8 +478,7 @@ bool LoopInvariantCodeMotion::runOnLoop( PredIteratorCache PIC; // Promoting one set of accesses may make the pointers for another set - // loop invariant, so run this in a loop (with the MaybePromotable set - // decreasing in size over time). + // loop invariant, so run this in a loop. bool Promoted = false; bool LocalPromoted; do { @@ -469,7 +487,7 @@ bool LoopInvariantCodeMotion::runOnLoop( collectPromotionCandidates(MSSA, AA, L)) { LocalPromoted |= promoteLoopAccessesToScalars( PointerMustAliases, ExitBlocks, InsertPts, MSSAInsertPts, PIC, LI, - DT, TLI, L, &MSSAU, &SafetyInfo, ORE, LicmAllowSpeculation); + DT, TLI, L, MSSAU, &SafetyInfo, ORE, LicmAllowSpeculation); } Promoted |= LocalPromoted; } while (LocalPromoted); @@ -510,17 +528,17 @@ bool LoopInvariantCodeMotion::runOnLoop( bool llvm::sinkRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, DominatorTree *DT, BlockFrequencyInfo *BFI, TargetLibraryInfo *TLI, TargetTransformInfo *TTI, - Loop *CurLoop, MemorySSAUpdater *MSSAU, + Loop *CurLoop, MemorySSAUpdater &MSSAU, ICFLoopSafetyInfo *SafetyInfo, SinkAndHoistLICMFlags &Flags, OptimizationRemarkEmitter *ORE, Loop *OutermostLoop) { // Verify inputs. assert(N != nullptr && AA != nullptr && LI != nullptr && DT != nullptr && - CurLoop != nullptr && MSSAU != nullptr && SafetyInfo != nullptr && + CurLoop != nullptr && SafetyInfo != nullptr && "Unexpected input to sinkRegion."); - // We want to visit children before parents. We will enque all the parents + // We want to visit children before parents. 
We will enqueue all the parents // before their children in the worklist and process the worklist in reverse // order. SmallVector<DomTreeNode *, 16> Worklist = collectChildrenInLoop(N, CurLoop); @@ -558,8 +576,7 @@ bool llvm::sinkRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, if (!I.mayHaveSideEffects() && isNotUsedOrFreeInLoop(I, LoopNestMode ? OutermostLoop : CurLoop, SafetyInfo, TTI, FreeInLoop, LoopNestMode) && - canSinkOrHoistInst(I, AA, DT, CurLoop, /*CurAST*/nullptr, MSSAU, true, - &Flags, ORE)) { + canSinkOrHoistInst(I, AA, DT, CurLoop, MSSAU, true, Flags, ORE)) { if (sink(I, LI, DT, BFI, CurLoop, SafetyInfo, MSSAU, ORE)) { if (!FreeInLoop) { ++II; @@ -572,14 +589,14 @@ bool llvm::sinkRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, } } if (VerifyMemorySSA) - MSSAU->getMemorySSA()->verifyMemorySSA(); + MSSAU.getMemorySSA()->verifyMemorySSA(); return Changed; } bool llvm::sinkRegionForLoopNest( DomTreeNode *N, AAResults *AA, LoopInfo *LI, DominatorTree *DT, BlockFrequencyInfo *BFI, TargetLibraryInfo *TLI, TargetTransformInfo *TTI, - Loop *CurLoop, MemorySSAUpdater *MSSAU, ICFLoopSafetyInfo *SafetyInfo, + Loop *CurLoop, MemorySSAUpdater &MSSAU, ICFLoopSafetyInfo *SafetyInfo, SinkAndHoistLICMFlags &Flags, OptimizationRemarkEmitter *ORE) { bool Changed = false; @@ -608,7 +625,7 @@ private: LoopInfo *LI; DominatorTree *DT; Loop *CurLoop; - MemorySSAUpdater *MSSAU; + MemorySSAUpdater &MSSAU; // A map of blocks in the loop to the block their instructions will be hoisted // to. @@ -620,7 +637,7 @@ private: public: ControlFlowHoister(LoopInfo *LI, DominatorTree *DT, Loop *CurLoop, - MemorySSAUpdater *MSSAU) + MemorySSAUpdater &MSSAU) : LI(LI), DT(DT), CurLoop(CurLoop), MSSAU(MSSAU) {} void registerPossiblyHoistableBranch(BranchInst *BI) { @@ -796,7 +813,7 @@ public: if (HoistTarget == InitialPreheader) { // Phis in the loop header now need to use the new preheader. InitialPreheader->replaceSuccessorsPhiUsesWith(HoistCommonSucc); - MSSAU->wireOldPredecessorsToNewImmediatePredecessor( + MSSAU.wireOldPredecessorsToNewImmediatePredecessor( HoistTarget->getSingleSuccessor(), HoistCommonSucc, {HoistTarget}); // The new preheader dominates the loop header. DomTreeNode *PreheaderNode = DT->getNode(HoistCommonSucc); @@ -830,14 +847,14 @@ public: bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, DominatorTree *DT, BlockFrequencyInfo *BFI, TargetLibraryInfo *TLI, Loop *CurLoop, - MemorySSAUpdater *MSSAU, ScalarEvolution *SE, + MemorySSAUpdater &MSSAU, ScalarEvolution *SE, ICFLoopSafetyInfo *SafetyInfo, SinkAndHoistLICMFlags &Flags, OptimizationRemarkEmitter *ORE, bool LoopNestMode, bool AllowSpeculation) { // Verify inputs. assert(N != nullptr && AA != nullptr && LI != nullptr && DT != nullptr && - CurLoop != nullptr && MSSAU != nullptr && SafetyInfo != nullptr && + CurLoop != nullptr && SafetyInfo != nullptr && "Unexpected input to hoistRegion."); ControlFlowHoister CFH(LI, DT, CurLoop, MSSAU); @@ -882,8 +899,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, // and we have accurately duplicated the control flow from the loop header // to that block. 
if (CurLoop->hasLoopInvariantOperands(&I) && - canSinkOrHoistInst(I, AA, DT, CurLoop, /*CurAST*/ nullptr, MSSAU, - true, &Flags, ORE) && + canSinkOrHoistInst(I, AA, DT, CurLoop, MSSAU, true, Flags, ORE) && isSafeToExecuteUnconditionally( I, DT, TLI, CurLoop, SafetyInfo, ORE, CurLoop->getLoopPreheader()->getTerminator(), AllowSpeculation)) { @@ -991,7 +1007,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, } } if (VerifyMemorySSA) - MSSAU->getMemorySSA()->verifyMemorySSA(); + MSSAU.getMemorySSA()->verifyMemorySSA(); // Now that we've finished hoisting make sure that LI and DT are still // valid. @@ -1092,30 +1108,19 @@ bool isHoistableAndSinkableInst(Instruction &I) { isa<ShuffleVectorInst>(I) || isa<ExtractValueInst>(I) || isa<InsertValueInst>(I) || isa<FreezeInst>(I)); } -/// Return true if all of the alias sets within this AST are known not to -/// contain a Mod, or if MSSA knows there are no MemoryDefs in the loop. -bool isReadOnly(AliasSetTracker *CurAST, const MemorySSAUpdater *MSSAU, - const Loop *L) { - if (CurAST) { - for (AliasSet &AS : *CurAST) { - if (!AS.isForwardingAliasSet() && AS.isMod()) { - return false; - } - } - return true; - } else { /*MSSAU*/ - for (auto *BB : L->getBlocks()) - if (MSSAU->getMemorySSA()->getBlockDefs(BB)) - return false; - return true; - } +/// Return true if MSSA knows there are no MemoryDefs in the loop. +bool isReadOnly(const MemorySSAUpdater &MSSAU, const Loop *L) { + for (auto *BB : L->getBlocks()) + if (MSSAU.getMemorySSA()->getBlockDefs(BB)) + return false; + return true; } /// Return true if I is the only Instruction with a MemoryAccess in L. bool isOnlyMemoryAccess(const Instruction *I, const Loop *L, - const MemorySSAUpdater *MSSAU) { + const MemorySSAUpdater &MSSAU) { for (auto *BB : L->getBlocks()) - if (auto *Accs = MSSAU->getMemorySSA()->getBlockAccesses(BB)) { + if (auto *Accs = MSSAU.getMemorySSA()->getBlockAccesses(BB)) { int NotAPhi = 0; for (const auto &Acc : *Accs) { if (isa<MemoryPhi>(&Acc)) @@ -1130,22 +1135,15 @@ bool isOnlyMemoryAccess(const Instruction *I, const Loop *L, } bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, - Loop *CurLoop, AliasSetTracker *CurAST, - MemorySSAUpdater *MSSAU, + Loop *CurLoop, MemorySSAUpdater &MSSAU, bool TargetExecutesOncePerLoop, - SinkAndHoistLICMFlags *Flags, + SinkAndHoistLICMFlags &Flags, OptimizationRemarkEmitter *ORE) { - assert(((CurAST != nullptr) ^ (MSSAU != nullptr)) && - "Either AliasSetTracker or MemorySSA should be initialized."); - // If we don't understand the instruction, bail early. if (!isHoistableAndSinkableInst(I)) return false; - MemorySSA *MSSA = MSSAU ? MSSAU->getMemorySSA() : nullptr; - if (MSSA) - assert(Flags != nullptr && "Flags cannot be null."); - + MemorySSA *MSSA = MSSAU.getMemorySSA(); // Loads have extra constraints we have to verify before we can hoist them. 
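As the signature shown just above makes explicit, canSinkOrHoistInst no longer takes an AliasSetTracker at all and now receives the MemorySSAUpdater and the sink/hoist flags by reference, so the assertion that enforced the old either-AST-or-MSSA contract is gone. Callers in sinkRegion and hoistRegion drop the /*CurAST*/ nullptr argument accordingly. The updated interface, copied from the definition in the hunk (the declaration itself lives in the LoopUtils header):

    bool canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
                            Loop *CurLoop, MemorySSAUpdater &MSSAU,
                            bool TargetExecutesOncePerLoop,
                            SinkAndHoistLICMFlags &Flags,
                            OptimizationRemarkEmitter *ORE);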
if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { if (!LI->isUnordered()) @@ -1165,13 +1163,8 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, if (isLoadInvariantInLoop(LI, DT, CurLoop)) return true; - bool Invalidated; - if (CurAST) - Invalidated = pointerInvalidatedByLoop(MemoryLocation::get(LI), CurAST, - CurLoop, AA); - else - Invalidated = pointerInvalidatedByLoopWithMSSA( - MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(LI)), CurLoop, I, *Flags); + bool Invalidated = pointerInvalidatedByLoop( + MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(LI)), CurLoop, I, Flags); // Check loop-invariant address because this may also be a sinkable load // whose address is not necessarily loop-invariant. if (ORE && Invalidated && CurLoop->isLoopInvariant(LI->getPointerOperand())) @@ -1219,24 +1212,17 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, if (AAResults::onlyAccessesArgPointees(Behavior)) { // TODO: expand to writeable arguments for (Value *Op : CI->args()) - if (Op->getType()->isPointerTy()) { - bool Invalidated; - if (CurAST) - Invalidated = pointerInvalidatedByLoop( - MemoryLocation::getBeforeOrAfter(Op), CurAST, CurLoop, AA); - else - Invalidated = pointerInvalidatedByLoopWithMSSA( + if (Op->getType()->isPointerTy() && + pointerInvalidatedByLoop( MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(CI)), CurLoop, I, - *Flags); - if (Invalidated) - return false; - } + Flags)) + return false; return true; } // If this call only reads from memory and there are no writes to memory // in the loop, we can hoist or sink the call as appropriate. - if (isReadOnly(CurAST, MSSAU, CurLoop)) + if (isReadOnly(MSSAU, CurLoop)) return true; } @@ -1247,21 +1233,7 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, } else if (auto *FI = dyn_cast<FenceInst>(&I)) { // Fences alias (most) everything to provide ordering. For the moment, // just give up if there are any other memory operations in the loop. - if (CurAST) { - auto Begin = CurAST->begin(); - assert(Begin != CurAST->end() && "must contain FI"); - if (std::next(Begin) != CurAST->end()) - // constant memory for instance, TODO: handle better - return false; - auto *UniqueI = Begin->getUniqueInstruction(); - if (!UniqueI) - // other memory op, give up - return false; - (void)FI; // suppress unused variable warning - assert(UniqueI == FI && "AS must contain FI"); - return true; - } else // MSSAU - return isOnlyMemoryAccess(FI, CurLoop, MSSAU); + return isOnlyMemoryAccess(FI, CurLoop, MSSAU); } else if (auto *SI = dyn_cast<StoreInst>(&I)) { if (!SI->isUnordered()) return false; // Don't sink/hoist volatile or ordered atomic store! @@ -1271,68 +1243,54 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, // load store promotion instead. TODO: We can extend this to cases where // there is exactly one write to the location and that write dominates an // arbitrary number of reads in the loop. - if (CurAST) { - auto &AS = CurAST->getAliasSetFor(MemoryLocation::get(SI)); - - if (AS.isRef() || !AS.isMustAlias()) - // Quick exit test, handled by the full path below as well. 
- return false; - auto *UniqueI = AS.getUniqueInstruction(); - if (!UniqueI) - // other memory op, give up - return false; - assert(UniqueI == SI && "AS must contain SI"); + if (isOnlyMemoryAccess(SI, CurLoop, MSSAU)) return true; - } else { // MSSAU - if (isOnlyMemoryAccess(SI, CurLoop, MSSAU)) - return true; - // If there are more accesses than the Promotion cap or no "quota" to - // check clobber, then give up as we're not walking a list that long. - if (Flags->tooManyMemoryAccesses() || Flags->tooManyClobberingCalls()) - return false; - // If there are interfering Uses (i.e. their defining access is in the - // loop), or ordered loads (stored as Defs!), don't move this store. - // Could do better here, but this is conservatively correct. - // TODO: Cache set of Uses on the first walk in runOnLoop, update when - // moving accesses. Can also extend to dominating uses. - auto *SIMD = MSSA->getMemoryAccess(SI); - for (auto *BB : CurLoop->getBlocks()) - if (auto *Accesses = MSSA->getBlockAccesses(BB)) { - for (const auto &MA : *Accesses) - if (const auto *MU = dyn_cast<MemoryUse>(&MA)) { - auto *MD = MU->getDefiningAccess(); - if (!MSSA->isLiveOnEntryDef(MD) && - CurLoop->contains(MD->getBlock())) - return false; - // Disable hoisting past potentially interfering loads. Optimized - // Uses may point to an access outside the loop, as getClobbering - // checks the previous iteration when walking the backedge. - // FIXME: More precise: no Uses that alias SI. - if (!Flags->getIsSink() && !MSSA->dominates(SIMD, MU)) - return false; - } else if (const auto *MD = dyn_cast<MemoryDef>(&MA)) { - if (auto *LI = dyn_cast<LoadInst>(MD->getMemoryInst())) { - (void)LI; // Silence warning. - assert(!LI->isUnordered() && "Expected unordered load"); + // If there are more accesses than the Promotion cap or no "quota" to + // check clobber, then give up as we're not walking a list that long. + if (Flags.tooManyMemoryAccesses() || Flags.tooManyClobberingCalls()) + return false; + // If there are interfering Uses (i.e. their defining access is in the + // loop), or ordered loads (stored as Defs!), don't move this store. + // Could do better here, but this is conservatively correct. + // TODO: Cache set of Uses on the first walk in runOnLoop, update when + // moving accesses. Can also extend to dominating uses. + auto *SIMD = MSSA->getMemoryAccess(SI); + for (auto *BB : CurLoop->getBlocks()) + if (auto *Accesses = MSSA->getBlockAccesses(BB)) { + for (const auto &MA : *Accesses) + if (const auto *MU = dyn_cast<MemoryUse>(&MA)) { + auto *MD = MU->getDefiningAccess(); + if (!MSSA->isLiveOnEntryDef(MD) && + CurLoop->contains(MD->getBlock())) + return false; + // Disable hoisting past potentially interfering loads. Optimized + // Uses may point to an access outside the loop, as getClobbering + // checks the previous iteration when walking the backedge. + // FIXME: More precise: no Uses that alias SI. + if (!Flags.getIsSink() && !MSSA->dominates(SIMD, MU)) + return false; + } else if (const auto *MD = dyn_cast<MemoryDef>(&MA)) { + if (auto *LI = dyn_cast<LoadInst>(MD->getMemoryInst())) { + (void)LI; // Silence warning. + assert(!LI->isUnordered() && "Expected unordered load"); + return false; + } + // Any call, while it may not be clobbering SI, it may be a use. + if (auto *CI = dyn_cast<CallInst>(MD->getMemoryInst())) { + // Check if the call may read from the memory location written + // to by SI. Check CI's attributes and arguments; the number of + // such checks performed is limited above by NoOfMemAccTooLarge. 
+ ModRefInfo MRI = AA->getModRefInfo(CI, MemoryLocation::get(SI)); + if (isModOrRefSet(MRI)) return false; - } - // Any call, while it may not be clobbering SI, it may be a use. - if (auto *CI = dyn_cast<CallInst>(MD->getMemoryInst())) { - // Check if the call may read from the memory location written - // to by SI. Check CI's attributes and arguments; the number of - // such checks performed is limited above by NoOfMemAccTooLarge. - ModRefInfo MRI = AA->getModRefInfo(CI, MemoryLocation::get(SI)); - if (isModOrRefSet(MRI)) - return false; - } } - } - auto *Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(SI); - Flags->incrementClobberingCalls(); - // If there are no clobbering Defs in the loop, store is safe to hoist. - return MSSA->isLiveOnEntryDef(Source) || - !CurLoop->contains(Source->getBlock()); - } + } + } + auto *Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(SI); + Flags.incrementClobberingCalls(); + // If there are no clobbering Defs in the loop, store is safe to hoist. + return MSSA->isLiveOnEntryDef(Source) || + !CurLoop->contains(Source->getBlock()); } assert(!I.mayReadOrWriteMemory() && "unhandled aliasing"); @@ -1430,7 +1388,7 @@ static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, static Instruction *cloneInstructionInExitBlock( Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI, - const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU) { + const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater &MSSAU) { Instruction *New; if (auto *CI = dyn_cast<CallInst>(&I)) { const auto &BlockColors = SafetyInfo->getBlockColors(); @@ -1466,16 +1424,16 @@ static Instruction *cloneInstructionInExitBlock( if (!I.getName().empty()) New->setName(I.getName() + ".le"); - if (MSSAU && MSSAU->getMemorySSA()->getMemoryAccess(&I)) { + if (MSSAU.getMemorySSA()->getMemoryAccess(&I)) { // Create a new MemoryAccess and let MemorySSA set its defining access. 
- MemoryAccess *NewMemAcc = MSSAU->createMemoryAccessInBB( + MemoryAccess *NewMemAcc = MSSAU.createMemoryAccessInBB( New, nullptr, New->getParent(), MemorySSA::Beginning); if (NewMemAcc) { if (auto *MemDef = dyn_cast<MemoryDef>(NewMemAcc)) - MSSAU->insertDef(MemDef, /*RenameUses=*/true); + MSSAU.insertDef(MemDef, /*RenameUses=*/true); else { auto *MemUse = cast<MemoryUse>(NewMemAcc); - MSSAU->insertUse(MemUse, /*RenameUses=*/true); + MSSAU.insertUse(MemUse, /*RenameUses=*/true); } } } @@ -1501,25 +1459,22 @@ static Instruction *cloneInstructionInExitBlock( } static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, - MemorySSAUpdater *MSSAU) { - if (MSSAU) - MSSAU->removeMemoryAccess(&I); + MemorySSAUpdater &MSSAU) { + MSSAU.removeMemoryAccess(&I); SafetyInfo.removeInstruction(&I); I.eraseFromParent(); } static void moveInstructionBefore(Instruction &I, Instruction &Dest, ICFLoopSafetyInfo &SafetyInfo, - MemorySSAUpdater *MSSAU, + MemorySSAUpdater &MSSAU, ScalarEvolution *SE) { SafetyInfo.removeInstruction(&I); SafetyInfo.insertInstructionTo(&I, Dest.getParent()); I.moveBefore(&Dest); - if (MSSAU) - if (MemoryUseOrDef *OldMemAcc = cast_or_null<MemoryUseOrDef>( - MSSAU->getMemorySSA()->getMemoryAccess(&I))) - MSSAU->moveToPlace(OldMemAcc, Dest.getParent(), - MemorySSA::BeforeTerminator); + if (MemoryUseOrDef *OldMemAcc = cast_or_null<MemoryUseOrDef>( + MSSAU.getMemorySSA()->getMemoryAccess(&I))) + MSSAU.moveToPlace(OldMemAcc, Dest.getParent(), MemorySSA::BeforeTerminator); if (SE) SE->forgetValue(&I); } @@ -1528,7 +1483,7 @@ static Instruction *sinkThroughTriviallyReplaceablePHI( PHINode *TPN, Instruction *I, LoopInfo *LI, SmallDenseMap<BasicBlock *, Instruction *, 32> &SunkCopies, const LoopSafetyInfo *SafetyInfo, const Loop *CurLoop, - MemorySSAUpdater *MSSAU) { + MemorySSAUpdater &MSSAU) { assert(isTriviallyReplaceablePHI(*TPN, *I) && "Expect only trivially replaceable PHI"); BasicBlock *ExitBlock = TPN->getParent(); @@ -1634,7 +1589,7 @@ static void splitPredecessorsOfLoopExit(PHINode *PN, DominatorTree *DT, /// static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, BlockFrequencyInfo *BFI, const Loop *CurLoop, - ICFLoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU, + ICFLoopSafetyInfo *SafetyInfo, MemorySSAUpdater &MSSAU, OptimizationRemarkEmitter *ORE) { bool Changed = false; LLVM_DEBUG(dbgs() << "LICM sinking instruction: " << I << "\n"); @@ -1651,7 +1606,7 @@ static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, continue; if (!DT->isReachableFromEntry(User->getParent())) { - U = UndefValue::get(I.getType()); + U = PoisonValue::get(I.getType()); Changed = true; continue; } @@ -1664,7 +1619,7 @@ static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, // unreachable. BasicBlock *BB = PN->getIncomingBlock(U); if (!DT->isReachableFromEntry(BB)) { - U = UndefValue::get(I.getType()); + U = PoisonValue::get(I.getType()); Changed = true; continue; } @@ -1678,7 +1633,7 @@ static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, // Split predecessors of the PHI so that we can make users trivially // replaceable. - splitPredecessorsOfLoopExit(PN, DT, LI, CurLoop, SafetyInfo, MSSAU); + splitPredecessorsOfLoopExit(PN, DT, LI, CurLoop, SafetyInfo, &MSSAU); // Should rebuild the iterators, as they may be invalidated by // splitPredecessorsOfLoopExit(). 
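A pattern that repeats through the LICM hunks: every helper that used to take MemorySSAUpdater * and guard it with if (MSSAU) now takes a reference, matching the removal of the AliasSetTracker fallback elsewhere in this file. eraseInstruction is the smallest example (copied from the hunk above), and note that the sink() call site that previously passed nullptr when erasing a PHI now updates MemorySSA as well:

    static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo,
                                 MemorySSAUpdater &MSSAU) {
      MSSAU.removeMemoryAccess(&I);   // unconditional: MSSAU can no longer be null
      SafetyInfo.removeInstruction(&I);
      I.eraseFromParent();
    }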
@@ -1729,7 +1684,7 @@ static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, Instruction *New = sinkThroughTriviallyReplaceablePHI( PN, &I, LI, SunkCopies, SafetyInfo, CurLoop, MSSAU); PN->replaceAllUsesWith(New); - eraseInstruction(*PN, *SafetyInfo, nullptr); + eraseInstruction(*PN, *SafetyInfo, MSSAU); Changed = true; } return Changed; @@ -1740,7 +1695,7 @@ static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, /// static void hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, BasicBlock *Dest, ICFLoopSafetyInfo *SafetyInfo, - MemorySSAUpdater *MSSAU, ScalarEvolution *SE, + MemorySSAUpdater &MSSAU, ScalarEvolution *SE, OptimizationRemarkEmitter *ORE) { LLVM_DEBUG(dbgs() << "LICM hoisting to " << Dest->getNameOrAsOperand() << ": " << I << "\n"); @@ -1816,7 +1771,7 @@ class LoopPromoter : public LoadAndStorePromoter { SmallVectorImpl<Instruction *> &LoopInsertPts; SmallVectorImpl<MemoryAccess *> &MSSAInsertPts; PredIteratorCache &PredCache; - MemorySSAUpdater *MSSAU; + MemorySSAUpdater &MSSAU; LoopInfo &LI; DebugLoc DL; Align Alignment; @@ -1848,7 +1803,7 @@ public: SmallVectorImpl<BasicBlock *> &LEB, SmallVectorImpl<Instruction *> &LIP, SmallVectorImpl<MemoryAccess *> &MSSAIP, PredIteratorCache &PIC, - MemorySSAUpdater *MSSAU, LoopInfo &li, DebugLoc dl, + MemorySSAUpdater &MSSAU, LoopInfo &li, DebugLoc dl, Align Alignment, bool UnorderedAtomic, const AAMDNodes &AATags, ICFLoopSafetyInfo &SafetyInfo, bool CanInsertStoresInExitBlocks) : LoadAndStorePromoter(Insts, S), SomePtr(SP), PointerMustAliases(PMA), @@ -1890,14 +1845,14 @@ public: MemoryAccess *MSSAInsertPoint = MSSAInsertPts[i]; MemoryAccess *NewMemAcc; if (!MSSAInsertPoint) { - NewMemAcc = MSSAU->createMemoryAccessInBB( + NewMemAcc = MSSAU.createMemoryAccessInBB( NewSI, nullptr, NewSI->getParent(), MemorySSA::Beginning); } else { NewMemAcc = - MSSAU->createMemoryAccessAfter(NewSI, nullptr, MSSAInsertPoint); + MSSAU.createMemoryAccessAfter(NewSI, nullptr, MSSAInsertPoint); } MSSAInsertPts[i] = NewMemAcc; - MSSAU->insertDef(cast<MemoryDef>(NewMemAcc), true); + MSSAU.insertDef(cast<MemoryDef>(NewMemAcc), true); // FIXME: true for safety, false may still be correct. } } @@ -1909,7 +1864,7 @@ public: void instructionDeleted(Instruction *I) const override { SafetyInfo.removeInstruction(I); - MSSAU->removeMemoryAccess(I); + MSSAU.removeMemoryAccess(I); } bool shouldDelete(Instruction *I) const override { @@ -1955,7 +1910,7 @@ bool llvm::promoteLoopAccessesToScalars( SmallVectorImpl<Instruction *> &InsertPts, SmallVectorImpl<MemoryAccess *> &MSSAInsertPts, PredIteratorCache &PIC, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, - Loop *CurLoop, MemorySSAUpdater *MSSAU, ICFLoopSafetyInfo *SafetyInfo, + Loop *CurLoop, MemorySSAUpdater &MSSAU, ICFLoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE, bool AllowSpeculation) { // Verify inputs. assert(LI != nullptr && DT != nullptr && CurLoop != nullptr && @@ -2004,6 +1959,7 @@ bool llvm::promoteLoopAccessesToScalars( bool DereferenceableInPH = false; bool SafeToInsertStore = false; + bool StoreIsGuanteedToExecute = false; bool FoundLoadToPromote = false; SmallVector<Instruction *, 64> LoopUses; @@ -2038,9 +1994,9 @@ bool llvm::promoteLoopAccessesToScalars( // different sizes. While we are at it, collect alignment and AA info. Type *AccessTy = nullptr; for (Value *ASIV : PointerMustAliases) { - for (User *U : ASIV->users()) { + for (Use &U : ASIV->uses()) { // Ignore instructions that are outside the loop. 
- Instruction *UI = dyn_cast<Instruction>(U); + Instruction *UI = dyn_cast<Instruction>(U.getUser()); if (!UI || !CurLoop->contains(UI)) continue; @@ -2070,7 +2026,7 @@ bool llvm::promoteLoopAccessesToScalars( } else if (const StoreInst *Store = dyn_cast<StoreInst>(UI)) { // Stores *of* the pointer are not interesting, only stores *to* the // pointer. - if (UI->getOperand(1) != ASIV) + if (U.getOperandNo() != StoreInst::getPointerOperandIndex()) continue; if (!Store->isUnordered()) return false; @@ -2084,10 +2040,12 @@ bool llvm::promoteLoopAccessesToScalars( // alignment than any other guaranteed stores, in which case we can // raise the alignment on the promoted store. Align InstAlignment = Store->getAlign(); - + bool GuaranteedToExecute = + SafetyInfo->isGuaranteedToExecute(*UI, DT, CurLoop); + StoreIsGuanteedToExecute |= GuaranteedToExecute; if (!DereferenceableInPH || !SafeToInsertStore || (InstAlignment > Alignment)) { - if (SafetyInfo->isGuaranteedToExecute(*UI, DT, CurLoop)) { + if (GuaranteedToExecute) { DereferenceableInPH = true; SafeToInsertStore = true; Alignment = std::max(Alignment, InstAlignment); @@ -2201,32 +2159,37 @@ bool llvm::promoteLoopAccessesToScalars( // Set up the preheader to have a definition of the value. It is the live-out // value from the preheader that uses in the loop will use. - LoadInst *PreheaderLoad = new LoadInst( - AccessTy, SomePtr, SomePtr->getName() + ".promoted", - Preheader->getTerminator()); - if (SawUnorderedAtomic) - PreheaderLoad->setOrdering(AtomicOrdering::Unordered); - PreheaderLoad->setAlignment(Alignment); - PreheaderLoad->setDebugLoc(DebugLoc()); - if (AATags) - PreheaderLoad->setAAMetadata(AATags); - SSA.AddAvailableValue(Preheader, PreheaderLoad); - - MemoryAccess *PreheaderLoadMemoryAccess = MSSAU->createMemoryAccessInBB( - PreheaderLoad, nullptr, PreheaderLoad->getParent(), MemorySSA::End); - MemoryUse *NewMemUse = cast<MemoryUse>(PreheaderLoadMemoryAccess); - MSSAU->insertUse(NewMemUse, /*RenameUses=*/true); + LoadInst *PreheaderLoad = nullptr; + if (FoundLoadToPromote || !StoreIsGuanteedToExecute) { + PreheaderLoad = + new LoadInst(AccessTy, SomePtr, SomePtr->getName() + ".promoted", + Preheader->getTerminator()); + if (SawUnorderedAtomic) + PreheaderLoad->setOrdering(AtomicOrdering::Unordered); + PreheaderLoad->setAlignment(Alignment); + PreheaderLoad->setDebugLoc(DebugLoc()); + if (AATags) + PreheaderLoad->setAAMetadata(AATags); + + MemoryAccess *PreheaderLoadMemoryAccess = MSSAU.createMemoryAccessInBB( + PreheaderLoad, nullptr, PreheaderLoad->getParent(), MemorySSA::End); + MemoryUse *NewMemUse = cast<MemoryUse>(PreheaderLoadMemoryAccess); + MSSAU.insertUse(NewMemUse, /*RenameUses=*/true); + SSA.AddAvailableValue(Preheader, PreheaderLoad); + } else { + SSA.AddAvailableValue(Preheader, PoisonValue::get(AccessTy)); + } if (VerifyMemorySSA) - MSSAU->getMemorySSA()->verifyMemorySSA(); + MSSAU.getMemorySSA()->verifyMemorySSA(); // Rewrite all the loads in the loop and remember all the definitions from // stores in the loop. Promoter.run(LoopUses); if (VerifyMemorySSA) - MSSAU->getMemorySSA()->verifyMemorySSA(); + MSSAU.getMemorySSA()->verifyMemorySSA(); // If the SSAUpdater didn't use the load in the preheader, just zap it now. 
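The promotion hunk above is the one behavioural change in this stretch: PreheaderLoad may now be null. When there is no load to promote and every store to the promoted location is guaranteed to execute, the preheader no longer materialises a ".promoted" load at all and instead seeds the SSAUpdater with poison, which is why the cleanup just below gains a null check. Condensed from the hunk (the commit really does spell the flag StoreIsGuanteedToExecute):

    if (FoundLoadToPromote || !StoreIsGuanteedToExecute) {
      PreheaderLoad =
          new LoadInst(AccessTy, SomePtr, SomePtr->getName() + ".promoted",
                       Preheader->getTerminator());
      SSA.AddAvailableValue(Preheader, PreheaderLoad);
    } else {
      // Every store executes, so the incoming value is never observed.
      SSA.AddAvailableValue(Preheader, PoisonValue::get(AccessTy));
    }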
- if (PreheaderLoad->use_empty()) + if (PreheaderLoad && PreheaderLoad->use_empty()) eraseInstruction(*PreheaderLoad, *SafetyInfo, MSSAU); return true; @@ -2253,8 +2216,7 @@ collectPromotionCandidates(MemorySSA *MSSA, AliasAnalysis *AA, Loop *L) { return false; }; - // Populate AST with potentially promotable accesses and remove them from - // MaybePromotable, so they will not be checked again on the next iteration. + // Populate AST with potentially promotable accesses. SmallPtrSet<Value *, 16> AttemptingPromotion; foreachMemoryAccess(MSSA, L, [&](Instruction *I) { if (IsPotentiallyPromotable(I)) { @@ -2293,15 +2255,9 @@ collectPromotionCandidates(MemorySSA *MSSA, AliasAnalysis *AA, Loop *L) { return Result; } -static bool pointerInvalidatedByLoop(MemoryLocation MemLoc, - AliasSetTracker *CurAST, Loop *CurLoop, - AAResults *AA) { - return CurAST->getAliasSetFor(MemLoc).isMod(); -} - -bool pointerInvalidatedByLoopWithMSSA(MemorySSA *MSSA, MemoryUse *MU, - Loop *CurLoop, Instruction &I, - SinkAndHoistLICMFlags &Flags) { +static bool pointerInvalidatedByLoop(MemorySSA *MSSA, MemoryUse *MU, + Loop *CurLoop, Instruction &I, + SinkAndHoistLICMFlags &Flags) { // For hoisting, use the walker to determine safety if (!Flags.getIsSink()) { MemoryAccess *Source; @@ -2336,17 +2292,16 @@ bool pointerInvalidatedByLoopWithMSSA(MemorySSA *MSSA, MemoryUse *MU, if (Flags.tooManyMemoryAccesses()) return true; for (auto *BB : CurLoop->getBlocks()) - if (pointerInvalidatedByBlockWithMSSA(*BB, *MSSA, *MU)) + if (pointerInvalidatedByBlock(*BB, *MSSA, *MU)) return true; // When sinking, the source block may not be part of the loop so check it. if (!CurLoop->contains(&I)) - return pointerInvalidatedByBlockWithMSSA(*I.getParent(), *MSSA, *MU); + return pointerInvalidatedByBlock(*I.getParent(), *MSSA, *MU); return false; } -bool pointerInvalidatedByBlockWithMSSA(BasicBlock &BB, MemorySSA &MSSA, - MemoryUse &MU) { +bool pointerInvalidatedByBlock(BasicBlock &BB, MemorySSA &MSSA, MemoryUse &MU) { if (const auto *Accesses = MSSA.getBlockDefs(&BB)) for (const auto &MA : *Accesses) if (const auto *MD = dyn_cast<MemoryDef>(&MA)) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp index 1c3ff1a61b7e..c063c0d3c88a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp @@ -8,6 +8,7 @@ #include "llvm/Transforms/Scalar/LoopAccessAnalysisPrinter.h" #include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" using namespace llvm; #define DEBUG_TYPE "loop-accesses" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp index d438d56e38ca..2b9800f11912 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp @@ -8,20 +8,15 @@ #include "llvm/Transforms/Scalar/LoopBoundSplit.h" #include "llvm/ADT/Sequence.h" -#include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/LoopIterator.h" -#include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/MemorySSA.h" -#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" 
#include "llvm/IR/PatternMatch.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopSimplify.h" -#include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #define DEBUG_TYPE "loop-bound-split" @@ -33,26 +28,23 @@ using namespace PatternMatch; namespace { struct ConditionInfo { /// Branch instruction with this condition - BranchInst *BI; + BranchInst *BI = nullptr; /// ICmp instruction with this condition - ICmpInst *ICmp; + ICmpInst *ICmp = nullptr; /// Preciate info - ICmpInst::Predicate Pred; + ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; /// AddRec llvm value - Value *AddRecValue; + Value *AddRecValue = nullptr; /// Non PHI AddRec llvm value Value *NonPHIAddRecValue; /// Bound llvm value - Value *BoundValue; + Value *BoundValue = nullptr; /// AddRec SCEV - const SCEVAddRecExpr *AddRecSCEV; + const SCEVAddRecExpr *AddRecSCEV = nullptr; /// Bound SCEV - const SCEV *BoundSCEV; + const SCEV *BoundSCEV = nullptr; - ConditionInfo() - : BI(nullptr), ICmp(nullptr), Pred(ICmpInst::BAD_ICMP_PREDICATE), - AddRecValue(nullptr), BoundValue(nullptr), AddRecSCEV(nullptr), - BoundSCEV(nullptr) {} + ConditionInfo() = default; }; } // namespace diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp index 57e36e5b9b90..9590fbbb1994 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -22,7 +22,6 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/Module.h" @@ -30,9 +29,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" -#include "llvm/Transforms/Utils/ValueMapper.h" #define DEBUG_TYPE "loop-data-prefetch" @@ -236,15 +233,14 @@ struct Prefetch { /// The address formula for this prefetch as returned by ScalarEvolution. const SCEVAddRecExpr *LSCEVAddRec; /// The point of insertion for the prefetch instruction. - Instruction *InsertPt; + Instruction *InsertPt = nullptr; /// True if targeting a write memory access. - bool Writes; + bool Writes = false; /// The (first seen) prefetched instruction. - Instruction *MemI; + Instruction *MemI = nullptr; /// Constructor to create a new Prefetch for \p I. 
- Prefetch(const SCEVAddRecExpr *L, Instruction *I) - : LSCEVAddRec(L), InsertPt(nullptr), Writes(false), MemI(nullptr) { + Prefetch(const SCEVAddRecExpr *L, Instruction *I) : LSCEVAddRec(L) { addInstruction(I); }; @@ -303,7 +299,11 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { } Metrics.analyzeBasicBlock(BB, *TTI, EphValues); } - unsigned LoopSize = Metrics.NumInsts; + + if (!Metrics.NumInsts.isValid()) + return MadeChange; + + unsigned LoopSize = *Metrics.NumInsts.getValue(); if (!LoopSize) LoopSize = 1; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDeletion.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDeletion.cpp index 361d6c0d9381..93f3cd704196 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDeletion.cpp @@ -17,12 +17,12 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/CFG.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/PatternMatch.h" @@ -192,13 +192,13 @@ getValueOnFirstIteration(Value *V, DenseMap<Value *, Value *> &FirstIterValue, getValueOnFirstIteration(BO->getOperand(0), FirstIterValue, SQ); Value *RHS = getValueOnFirstIteration(BO->getOperand(1), FirstIterValue, SQ); - FirstIterV = SimplifyBinOp(BO->getOpcode(), LHS, RHS, SQ); + FirstIterV = simplifyBinOp(BO->getOpcode(), LHS, RHS, SQ); } else if (auto *Cmp = dyn_cast<ICmpInst>(V)) { Value *LHS = getValueOnFirstIteration(Cmp->getOperand(0), FirstIterValue, SQ); Value *RHS = getValueOnFirstIteration(Cmp->getOperand(1), FirstIterValue, SQ); - FirstIterV = SimplifyICmpInst(Cmp->getPredicate(), LHS, RHS, SQ); + FirstIterV = simplifyICmpInst(Cmp->getPredicate(), LHS, RHS, SQ); } else if (auto *Select = dyn_cast<SelectInst>(V)) { Value *Cond = getValueOnFirstIteration(Select->getCondition(), FirstIterValue, SQ); @@ -458,13 +458,13 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, if (ExitBlock && isLoopNeverExecuted(L)) { LLVM_DEBUG(dbgs() << "Loop is proven to never execute, delete it!"); // We need to forget the loop before setting the incoming values of the exit - // phis to undef, so we properly invalidate the SCEV expressions for those + // phis to poison, so we properly invalidate the SCEV expressions for those // phis. SE.forgetLoop(L); - // Set incoming value to undef for phi nodes in the exit block. + // Set incoming value to poison for phi nodes in the exit block. 
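Several hunks in this batch make the same substitution for values that are provably never observed: JumpThreading's dead-loop load, LICM's operands of unreachable users, LoopDistribute's unused instructions, and the exit-block phis of a never-executed loop here in LoopDeletion all switch their placeholder from undef to poison, the stronger "don't care" value the optimizer has been moving toward. Illustrative helper only, not part of the commit:

    // The placeholder these hunks now use for dead values.
    static Constant *getDeadValuePlaceholder(Type *Ty) {
      return PoisonValue::get(Ty);   // previously UndefValue::get(Ty)
    }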
for (PHINode &P : ExitBlock->phis()) { std::fill(P.incoming_values().begin(), P.incoming_values().end(), - UndefValue::get(P.getType())); + PoisonValue::get(P.getType())); } ORE.emit([&]() { return OptimizationRemark(DEBUG_TYPE, "NeverExecutes", L->getStartLoc(), diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDistribute.cpp index 0f4c767c1e4c..03a10cb36bb6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -47,7 +47,6 @@ #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" -#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" @@ -231,7 +230,7 @@ public: // having to update as many def-use and use-def chains. for (auto *Inst : reverse(Unused)) { if (!Inst->use_empty()) - Inst->replaceAllUsesWith(UndefValue::get(Inst->getType())); + Inst->replaceAllUsesWith(PoisonValue::get(Inst->getType())); Inst->eraseFromParent(); } } @@ -601,7 +600,7 @@ private: {LLVMLoopDistributeFollowupAll, Part->hasDepCycle() ? LLVMLoopDistributeFollowupSequential : LLVMLoopDistributeFollowupCoincident}); - if (PartitionID.hasValue()) { + if (PartitionID) { Loop *NewLoop = Part->getDistributedLoop(); NewLoop->setLoopID(PartitionID.getValue()); } @@ -770,19 +769,19 @@ public: // Don't distribute the loop if we need too many SCEV run-time checks, or // any if it's illegal. - const SCEVUnionPredicate &Pred = LAI->getPSE().getUnionPredicate(); + const SCEVPredicate &Pred = LAI->getPSE().getPredicate(); if (LAI->hasConvergentOp() && !Pred.isAlwaysTrue()) { return fail("RuntimeCheckWithConvergent", "may not insert runtime check with convergent operation"); } - if (Pred.getComplexity() > (IsForced.getValueOr(false) + if (Pred.getComplexity() > (IsForced.value_or(false) ? PragmaDistributeSCEVCheckThreshold : DistributeSCEVCheckThreshold)) return fail("TooManySCEVRuntimeChecks", "too many SCEV run-time checks needed.\n"); - if (!IsForced.getValueOr(false) && hasDisableAllTransformsHint(L)) + if (!IsForced.value_or(false) && hasDisableAllTransformsHint(L)) return fail("HeuristicDisabled", "distribution heuristic disabled"); LLVM_DEBUG(dbgs() << "\nDistributing loop: " << *L << "\n"); @@ -859,7 +858,7 @@ public: /// Provide diagnostics then \return with false. bool fail(StringRef RemarkName, StringRef Message) { LLVMContext &Ctx = F->getContext(); - bool Forced = isForced().getValueOr(false); + bool Forced = isForced().value_or(false); LLVM_DEBUG(dbgs() << "Skipping; " << Message << "\n"); @@ -991,7 +990,7 @@ static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT, // If distribution was forced for the specific loop to be // enabled/disabled, follow that. Otherwise use the global flag. 
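The LoopDistribute changes around this point and the LoopFlatten changes below are part of a mechanical migration of llvm::Optional toward the std::optional spellings: getValueOr becomes value_or, and hasValue checks become plain boolean tests. A hypothetical local, just to show the renamed accessors:

    Optional<bool> IsForced = None;
    // Before: IsForced.getValueOr(false) and IsForced.hasValue()
    bool Forced = IsForced.value_or(false);
    if (IsForced)                    // preferred over hasValue()
      Forced = *IsForced;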
- if (LDL.isForced().getValueOr(EnableLoopDistribute)) + if (LDL.isForced().value_or(EnableLoopDistribute)) Changed |= LDL.processLoop(GetLAA); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index c46db4e63bfe..f36193fc468e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -54,6 +54,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopNestAnalysis.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -64,12 +65,12 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/IR/Verifier.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" @@ -210,8 +211,9 @@ struct FlattenInfo { if (!MatchedItCount) return false; - // Look through extends if the IV has been widened. - if (Widened && + // Look through extends if the IV has been widened. Don't look through + // extends if we already looked through a trunc. + if (Widened && IsAdd && (isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) { assert(MatchedItCount->getType() == InnerInductionPHI->getType() && "Unexpected type mismatch in types after widening"); @@ -410,7 +412,7 @@ static bool findLoopComponents( // pre-header and one from the latch. The incoming latch value is the // increment variable. Increment = - dyn_cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch)); + cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch)); if (Increment->hasNUsesOrMore(3)) { LLVM_DEBUG(dbgs() << "Could not find valid increment\n"); return false; @@ -921,7 +923,7 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, // this pass will simplify all loops that contain inner loops, // regardless of whether anything ends up being flattened. Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U, - MSSAU.hasValue() ? MSSAU.getPointer() : nullptr); + MSSAU ? MSSAU.getPointer() : nullptr); if (!Changed) return PreservedAnalyses::all(); @@ -987,7 +989,7 @@ bool LoopFlattenLegacyPass::runOnFunction(Function &F) { for (Loop *L : *LI) { auto LN = LoopNest::getLoopNest(*L, *SE); Changed |= Flatten(*LN, DT, LI, SE, AC, TTI, nullptr, - MSSAU.hasValue() ? MSSAU.getPointer() : nullptr); + MSSAU ? 
MSSAU.getPointer() : nullptr); } return Changed; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFuse.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFuse.cpp index bf4d275e04ba..d94b767c7b63 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFuse.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFuse.cpp @@ -117,7 +117,7 @@ static cl::opt<FusionDependenceAnalysisChoice> FusionDependenceAnalysis( "Use the dependence analysis interface"), clEnumValN(FUSION_DEPENDENCE_ANALYSIS_ALL, "all", "Use all available analyses")), - cl::Hidden, cl::init(FUSION_DEPENDENCE_ANALYSIS_ALL), cl::ZeroOrMore); + cl::Hidden, cl::init(FUSION_DEPENDENCE_ANALYSIS_ALL)); static cl::opt<unsigned> FusionPeelMaxCount( "loop-fusion-peel-max-count", cl::init(0), cl::Hidden, @@ -128,7 +128,7 @@ static cl::opt<unsigned> FusionPeelMaxCount( static cl::opt<bool> VerboseFusionDebugging("loop-fusion-verbose-debug", cl::desc("Enable verbose debugging for Loop Fusion"), - cl::Hidden, cl::init(false), cl::ZeroOrMore); + cl::Hidden, cl::init(false)); #endif namespace { @@ -178,12 +178,12 @@ struct FusionCandidate { /// FusionCandidateCompare function, required by FusionCandidateSet to /// determine where the FusionCandidate should be inserted into the set. These /// are used to establish ordering of the FusionCandidates based on dominance. - const DominatorTree *DT; + DominatorTree &DT; const PostDominatorTree *PDT; OptimizationRemarkEmitter &ORE; - FusionCandidate(Loop *L, const DominatorTree *DT, + FusionCandidate(Loop *L, DominatorTree &DT, const PostDominatorTree *PDT, OptimizationRemarkEmitter &ORE, TTI::PeelingPreferences PP) : Preheader(L->getLoopPreheader()), Header(L->getHeader()), @@ -192,7 +192,6 @@ struct FusionCandidate { GuardBranch(L->getLoopGuardBranch()), PP(PP), AbleToPeel(canPeel(L)), Peeled(false), DT(DT), PDT(PDT), ORE(ORE) { - assert(DT && "Expected non-null DT!"); // Walk over all blocks in the loop and check for conditions that may // prevent fusion. For each block, walk over all instructions and collect // the memory reads and writes If any instructions that prevent fusion are @@ -391,7 +390,7 @@ struct FusionCandidateCompare { /// IF RHS dominates LHS and LHS post-dominates RHS, return false; bool operator()(const FusionCandidate &LHS, const FusionCandidate &RHS) const { - const DominatorTree *DT = LHS.DT; + const DominatorTree *DT = &(LHS.DT); BasicBlock *LHSEntryBlock = LHS.getEntryBlock(); BasicBlock *RHSEntryBlock = RHS.getEntryBlock(); @@ -646,7 +645,7 @@ private: for (Loop *L : LV) { TTI::PeelingPreferences PP = gatherPeelingPreferences(L, SE, TTI, None, None); - FusionCandidate CurrCand(L, &DT, &PDT, ORE, PP); + FusionCandidate CurrCand(L, DT, &PDT, ORE, PP); if (!CurrCand.isEligibleForFusion(SE)) continue; @@ -991,7 +990,7 @@ private: FuseCounter); FusionCandidate FusedCand( - performFusion((Peel ? FC0Copy : *FC0), *FC1), &DT, &PDT, ORE, + performFusion((Peel ? 
FC0Copy : *FC0), *FC1), DT, &PDT, ORE, FC0Copy.PP); FusedCand.verify(); assert(FusedCand.isEligibleForFusion(SE) && diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 2635d0a213ff..88d6a7aff3c9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -61,7 +61,6 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -346,7 +345,7 @@ INITIALIZE_PASS_END(LoopIdiomRecognizeLegacyPass, "loop-idiom", Pass *llvm::createLoopIdiomPass() { return new LoopIdiomRecognizeLegacyPass(); } static void deleteDeadInstruction(Instruction *I) { - I->replaceAllUsesWith(UndefValue::get(I->getType())); + I->replaceAllUsesWith(PoisonValue::get(I->getType())); I->eraseFromParent(); } @@ -798,7 +797,7 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL, } /// processLoopMemIntrinsic - Template function for calling different processor -/// functions based on mem instrinsic type. +/// functions based on mem intrinsic type. template <typename MemInst> bool LoopIdiomRecognize::processLoopMemIntrinsic( BasicBlock *BB, @@ -995,9 +994,8 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, SmallPtrSet<Instruction *, 1> MSIs; MSIs.insert(MSI); return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()), - MaybeAlign(MSI->getDestAlignment()), - SplatValue, MSI, MSIs, Ev, BECount, - IsNegStride, /*IsLoopMemset=*/true); + MSI->getDestAlign(), SplatValue, MSI, MSIs, Ev, + BECount, IsNegStride, /*IsLoopMemset=*/true); } /// mayLoopAccessLocation - Return true if the specified loop might access the @@ -1101,6 +1099,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( Value *StoredVal, Instruction *TheStore, SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev, const SCEV *BECount, bool IsNegStride, bool IsLoopMemset) { + Module *M = TheStore->getModule(); Value *SplatValue = isBytewiseValue(StoredVal, *DL); Constant *PatternValue = nullptr; @@ -1183,15 +1182,14 @@ bool LoopIdiomRecognize::processLoopStridedStore( NewCall = Builder.CreateMemSet( BasePtr, SplatValue, NumBytes, MaybeAlign(StoreAlignment), /*isVolatile=*/false, AATags.TBAA, AATags.Scope, AATags.NoAlias); - } else { + } else if (isLibFuncEmittable(M, TLI, LibFunc_memset_pattern16)) { // Everything is emitted in default address space Type *Int8PtrTy = DestInt8PtrTy; - Module *M = TheStore->getModule(); StringRef FuncName = "memset_pattern16"; - FunctionCallee MSP = M->getOrInsertFunction(FuncName, Builder.getVoidTy(), - Int8PtrTy, Int8PtrTy, IntIdxTy); - inferLibFuncAttributes(M, FuncName, *TLI); + FunctionCallee MSP = getOrInsertLibFunc(M, *TLI, LibFunc_memset_pattern16, + Builder.getVoidTy(), Int8PtrTy, Int8PtrTy, IntIdxTy); + inferNonMandatoryLibFuncAttrs(M, FuncName, *TLI); // Otherwise we should form a memset_pattern16. PatternValue is known to be // an constant array of 16-bytes. Plop the value into a mergable global. 
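// A minimal source-level sketch of the idiom this memset_pattern16 path
// targets; the function name and constant are illustrative, and the rewrite
// only applies where TargetLibraryInfo reports the routine as emittable (the
// isLibFuncEmittable()/getOrInsertLibFunc() calls above) -- otherwise the
// guarded path returns without touching the loop.
void fillOnes(double *p, long n) {
  // 1.0 is not a bytewise splat (bytes 00 00 00 00 00 00 f0 3f), so a plain
  // memset cannot be formed; repeating the 8-byte value yields the 16-byte
  // pattern global handed to memset_pattern16(p, pattern, n * sizeof(double)).
  for (long i = 0; i < n; ++i)
    p[i] = 1.0;
}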
@@ -1202,7 +1200,9 @@ bool LoopIdiomRecognize::processLoopStridedStore( GV->setAlignment(Align(16)); Value *PatternPtr = ConstantExpr::getBitCast(GV, Int8PtrTy); NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes}); - } + } else + return Changed; + NewCall->setDebugLoc(TheStore->getDebugLoc()); if (MSSAU) { @@ -1277,9 +1277,8 @@ class MemmoveVerifier { public: explicit MemmoveVerifier(const Value &LoadBasePtr, const Value &StoreBasePtr, const DataLayout &DL) - : DL(DL), LoadOff(0), StoreOff(0), - BP1(llvm::GetPointerBaseWithConstantOffset( - LoadBasePtr.stripPointerCasts(), LoadOff, DL)), + : DL(DL), BP1(llvm::GetPointerBaseWithConstantOffset( + LoadBasePtr.stripPointerCasts(), LoadOff, DL)), BP2(llvm::GetPointerBaseWithConstantOffset( StoreBasePtr.stripPointerCasts(), StoreOff, DL)), IsSameObject(BP1 == BP2) {} @@ -1309,8 +1308,8 @@ public: private: const DataLayout &DL; - int64_t LoadOff; - int64_t StoreOff; + int64_t LoadOff = 0; + int64_t StoreOff = 0; const Value *BP1; const Value *BP2; @@ -1482,7 +1481,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad( return Changed; // We cannot allow unaligned ops for unordered load/store, so reject // anything where the alignment isn't at least the element size. - assert((StoreAlign.hasValue() && LoadAlign.hasValue()) && + assert((StoreAlign && LoadAlign) && "Expect unordered load/store to have align."); if (StoreAlign.getValue() < StoreSize || LoadAlign.getValue() < StoreSize) return Changed; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp index b9e63a4bc06f..4249512ea0f8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopInstSimplify.h" -#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -25,21 +24,17 @@ #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CFG.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" -#include "llvm/IR/User.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" -#include <algorithm> #include <utility> using namespace llvm; @@ -101,7 +96,7 @@ static bool simplifyLoopInst(Loop &L, DominatorTree &DT, LoopInfo &LI, if (!IsFirstIteration && !ToSimplify->count(&I)) continue; - Value *V = SimplifyInstruction(&I, SQ.getWithInstruction(&I)); + Value *V = simplifyInstruction(&I, SQ.getWithInstruction(&I)); if (!V || !LI.replacementPreservesLCSSAForm(&I, V)) continue; @@ -109,6 +104,10 @@ static bool simplifyLoopInst(Loop &L, DominatorTree &DT, LoopInfo &LI, auto *UserI = cast<Instruction>(U.getUser()); U.set(V); + // Do not bother dealing with unreachable code. + if (!DT.isReachableFromEntry(UserI->getParent())) + continue; + // If the instruction is used by a PHI node we have already processed // we'll need to iterate on the loop body to converge, so add it to // the next set. 
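// The Optional-style checks in the hunks above (e.g. the StoreAlign/LoadAlign
// assert and the value_or() calls) rely on contextual bool conversion and
// value_or(); a standalone sketch of the same semantics, using std::optional
// as a stand-in for llvm::Optional, with invented names and values:
#include <optional>

// Returns the known alignment, or the given fallback when none was recorded.
unsigned alignOrDefault(std::optional<unsigned> maybeAlign, unsigned fallback) {
  return maybeAlign.value_or(fallback); // same behaviour as getValueOr(fallback)
}

// True when both alignments are known; contextual bool conversion is what the
// plain "(StoreAlign && LoadAlign)" spelling relies on.
bool bothKnown(std::optional<unsigned> a, std::optional<unsigned> b) {
  return a && b;
}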
@@ -222,7 +221,7 @@ PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, AR.MSSA->verifyMemorySSA(); } if (!simplifyLoopInst(L, AR.DT, AR.LI, AR.AC, AR.TLI, - MSSAU.hasValue() ? MSSAU.getPointer() : nullptr)) + MSSAU ? MSSAU.getPointer() : nullptr)) return PreservedAnalyses::all(); auto PA = getLoopPassPreservedAnalyses(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index c2b065c4eb31..1d3023d04463 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/DependenceAnalysis.h" +#include "llvm/Analysis/LoopCacheAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopNestAnalysis.h" #include "llvm/Analysis/LoopPass.h" @@ -33,7 +34,6 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" @@ -44,7 +44,6 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include <cassert> @@ -120,8 +119,6 @@ static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, std::vector<char> Dep; Instruction *Src = cast<Instruction>(*I); Instruction *Dst = cast<Instruction>(*J); - if (Src == Dst) - continue; // Ignore Input dependencies. if (isa<LoadInst>(Src) && isa<LoadInst>(Dst)) continue; @@ -270,26 +267,28 @@ static bool isLegalToInterChangeLoops(CharMatrix &DepMatrix, return true; } -static LoopVector populateWorklist(Loop &L) { +static void populateWorklist(Loop &L, LoopVector &LoopList) { LLVM_DEBUG(dbgs() << "Calling populateWorklist on Func: " << L.getHeader()->getParent()->getName() << " Loop: %" << L.getHeader()->getName() << '\n'); - LoopVector LoopList; + assert(LoopList.empty() && "LoopList should initially be empty!"); Loop *CurrentLoop = &L; const std::vector<Loop *> *Vec = &CurrentLoop->getSubLoops(); while (!Vec->empty()) { // The current loop has multiple subloops in it hence it is not tightly // nested. // Discard all loops above it added into Worklist. - if (Vec->size() != 1) - return {}; + if (Vec->size() != 1) { + LoopList = {}; + return; + } LoopList.push_back(CurrentLoop); CurrentLoop = Vec->front(); Vec = &CurrentLoop->getSubLoops(); } LoopList.push_back(CurrentLoop); - return LoopList; + return; } namespace { @@ -360,8 +359,10 @@ public: : OuterLoop(Outer), InnerLoop(Inner), SE(SE), ORE(ORE) {} /// Check if the loop interchange is profitable. - bool isProfitable(unsigned InnerLoopId, unsigned OuterLoopId, - CharMatrix &DepMatrix); + bool isProfitable(const Loop *InnerLoop, const Loop *OuterLoop, + unsigned InnerLoopId, unsigned OuterLoopId, + CharMatrix &DepMatrix, + const DenseMap<const Loop *, unsigned> &CostMap); private: int getInstrOrderCost(); @@ -412,23 +413,26 @@ struct LoopInterchange { LoopInfo *LI = nullptr; DependenceInfo *DI = nullptr; DominatorTree *DT = nullptr; + std::unique_ptr<CacheCost> CC = nullptr; /// Interface to emit optimization remarks. 
OptimizationRemarkEmitter *ORE; LoopInterchange(ScalarEvolution *SE, LoopInfo *LI, DependenceInfo *DI, - DominatorTree *DT, OptimizationRemarkEmitter *ORE) - : SE(SE), LI(LI), DI(DI), DT(DT), ORE(ORE) {} + DominatorTree *DT, std::unique_ptr<CacheCost> &CC, + OptimizationRemarkEmitter *ORE) + : SE(SE), LI(LI), DI(DI), DT(DT), CC(std::move(CC)), ORE(ORE) {} bool run(Loop *L) { if (L->getParentLoop()) return false; - - return processLoopList(populateWorklist(*L)); + SmallVector<Loop *, 8> LoopList; + populateWorklist(*L, LoopList); + return processLoopList(LoopList); } bool run(LoopNest &LN) { - const auto &LoopList = LN.getLoops(); + SmallVector<Loop *, 8> LoopList(LN.getLoops().begin(), LN.getLoops().end()); for (unsigned I = 1; I < LoopList.size(); ++I) if (LoopList[I]->getParentLoop() != LoopList[I - 1]) return false; @@ -460,7 +464,7 @@ struct LoopInterchange { return LoopList.size() - 1; } - bool processLoopList(ArrayRef<Loop *> LoopList) { + bool processLoopList(SmallVectorImpl<Loop *> &LoopList) { bool Changed = false; unsigned LoopNestDepth = LoopList.size(); if (LoopNestDepth < 2) { @@ -500,27 +504,55 @@ struct LoopInterchange { } unsigned SelecLoopId = selectLoopForInterchange(LoopList); - // Move the selected loop outwards to the best possible position. - Loop *LoopToBeInterchanged = LoopList[SelecLoopId]; - for (unsigned i = SelecLoopId; i > 0; i--) { - bool Interchanged = processLoop(LoopToBeInterchanged, LoopList[i - 1], i, - i - 1, DependencyMatrix); - if (!Interchanged) - return Changed; - // Update the DependencyMatrix - interChangeDependencies(DependencyMatrix, i, i - 1); + // Obtain the loop vector returned from loop cache analysis beforehand, + // and put each <Loop, index> pair into a map for constant time query + // later. Indices in loop vector reprsent the optimal order of the + // corresponding loop, e.g., given a loopnest with depth N, index 0 + // indicates the loop should be placed as the outermost loop and index N + // indicates the loop should be placed as the innermost loop. + // + // For the old pass manager CacheCost would be null. + DenseMap<const Loop *, unsigned> CostMap; + if (CC != nullptr) { + const auto &LoopCosts = CC->getLoopCosts(); + for (unsigned i = 0; i < LoopCosts.size(); i++) { + CostMap[LoopCosts[i].first] = i; + } + } + // We try to achieve the globally optimal memory access for the loopnest, + // and do interchange based on a bubble-sort fasion. We start from + // the innermost loop, move it outwards to the best possible position + // and repeat this process. + for (unsigned j = SelecLoopId; j > 0; j--) { + bool ChangedPerIter = false; + for (unsigned i = SelecLoopId; i > SelecLoopId - j; i--) { + bool Interchanged = processLoop(LoopList[i], LoopList[i - 1], i, i - 1, + DependencyMatrix, CostMap); + if (!Interchanged) + continue; + // Loops interchanged, update LoopList accordingly. + std::swap(LoopList[i - 1], LoopList[i]); + // Update the DependencyMatrix + interChangeDependencies(DependencyMatrix, i, i - 1); #ifdef DUMP_DEP_MATRICIES - LLVM_DEBUG(dbgs() << "Dependence after interchange\n"); - printDepMatrix(DependencyMatrix); + LLVM_DEBUG(dbgs() << "Dependence after interchange\n"); + printDepMatrix(DependencyMatrix); #endif - Changed |= Interchanged; + ChangedPerIter |= Interchanged; + Changed |= Interchanged; + } + // Early abort if there was no interchange during an entire round of + // moving loops outwards. 
+ if (!ChangedPerIter) + break; } return Changed; } bool processLoop(Loop *InnerLoop, Loop *OuterLoop, unsigned InnerLoopId, unsigned OuterLoopId, - std::vector<std::vector<char>> &DependencyMatrix) { + std::vector<std::vector<char>> &DependencyMatrix, + const DenseMap<const Loop *, unsigned> &CostMap) { LLVM_DEBUG(dbgs() << "Processing InnerLoopId = " << InnerLoopId << " and OuterLoopId = " << OuterLoopId << "\n"); LoopInterchangeLegality LIL(OuterLoop, InnerLoop, SE, ORE); @@ -530,7 +562,8 @@ struct LoopInterchange { } LLVM_DEBUG(dbgs() << "Loops are legal to interchange\n"); LoopInterchangeProfitability LIP(OuterLoop, InnerLoop, SE, ORE); - if (!LIP.isProfitable(InnerLoopId, OuterLoopId, DependencyMatrix)) { + if (!LIP.isProfitable(InnerLoop, OuterLoop, InnerLoopId, OuterLoopId, + DependencyMatrix, CostMap)) { LLVM_DEBUG(dbgs() << "Interchanging loops not profitable.\n"); return false; } @@ -733,8 +766,12 @@ static PHINode *findInnerReductionPhi(Loop *L, Value *V) { if (PHI->getNumIncomingValues() == 1) continue; RecurrenceDescriptor RD; - if (RecurrenceDescriptor::isReductionPHI(PHI, L, RD)) + if (RecurrenceDescriptor::isReductionPHI(PHI, L, RD)) { + // Detect floating point reduction only when it can be reordered. + if (RD.getExactFPMathInst() != nullptr) + return nullptr; return PHI; + } return nullptr; } } @@ -893,28 +930,23 @@ areInnerLoopExitPHIsSupported(Loop *InnerL, Loop *OuterL, static bool areOuterLoopExitPHIsSupported(Loop *OuterLoop, Loop *InnerLoop) { BasicBlock *LoopNestExit = OuterLoop->getUniqueExitBlock(); for (PHINode &PHI : LoopNestExit->phis()) { - // FIXME: We currently are not able to detect floating point reductions - // and have to use floating point PHIs as a proxy to prevent - // interchanging in the presence of floating point reductions. - if (PHI.getType()->isFloatingPointTy()) - return false; for (unsigned i = 0; i < PHI.getNumIncomingValues(); i++) { - Instruction *IncomingI = dyn_cast<Instruction>(PHI.getIncomingValue(i)); - if (!IncomingI || IncomingI->getParent() != OuterLoop->getLoopLatch()) - continue; - - // The incoming value is defined in the outer loop latch. Currently we - // only support that in case the outer loop latch has a single predecessor. - // This guarantees that the outer loop latch is executed if and only if - // the inner loop is executed (because tightlyNested() guarantees that the - // outer loop header only branches to the inner loop or the outer loop - // latch). - // FIXME: We could weaken this logic and allow multiple predecessors, - // if the values are produced outside the loop latch. We would need - // additional logic to update the PHI nodes in the exit block as - // well. - if (OuterLoop->getLoopLatch()->getUniquePredecessor() == nullptr) - return false; + Instruction *IncomingI = dyn_cast<Instruction>(PHI.getIncomingValue(i)); + if (!IncomingI || IncomingI->getParent() != OuterLoop->getLoopLatch()) + continue; + + // The incoming value is defined in the outer loop latch. Currently we + // only support that in case the outer loop latch has a single predecessor. + // This guarantees that the outer loop latch is executed if and only if + // the inner loop is executed (because tightlyNested() guarantees that the + // outer loop header only branches to the inner loop or the outer loop + // latch). + // FIXME: We could weaken this logic and allow multiple predecessors, + // if the values are produced outside the loop latch. We would need + // additional logic to update the PHI nodes in the exit block as + // well. 
+ if (OuterLoop->getLoopLatch()->getUniquePredecessor() == nullptr) + return false; } } return true; @@ -1125,21 +1157,33 @@ static bool isProfitableForVectorization(unsigned InnerLoopId, return !DepMatrix.empty(); } -bool LoopInterchangeProfitability::isProfitable(unsigned InnerLoopId, - unsigned OuterLoopId, - CharMatrix &DepMatrix) { - // TODO: Add better profitability checks. - // e.g - // 1) Construct dependency matrix and move the one with no loop carried dep - // inside to enable vectorization. - - // This is rough cost estimation algorithm. It counts the good and bad order - // of induction variables in the instruction and allows reordering if number - // of bad orders is more than good. - int Cost = getInstrOrderCost(); - LLVM_DEBUG(dbgs() << "Cost = " << Cost << "\n"); - if (Cost < -LoopInterchangeCostThreshold) - return true; +bool LoopInterchangeProfitability::isProfitable( + const Loop *InnerLoop, const Loop *OuterLoop, unsigned InnerLoopId, + unsigned OuterLoopId, CharMatrix &DepMatrix, + const DenseMap<const Loop *, unsigned> &CostMap) { + // TODO: Remove the legacy cost model. + + // This is the new cost model returned from loop cache analysis. + // A smaller index means the loop should be placed an outer loop, and vice + // versa. + if (CostMap.find(InnerLoop) != CostMap.end() && + CostMap.find(OuterLoop) != CostMap.end()) { + unsigned InnerIndex = 0, OuterIndex = 0; + InnerIndex = CostMap.find(InnerLoop)->second; + OuterIndex = CostMap.find(OuterLoop)->second; + LLVM_DEBUG(dbgs() << "InnerIndex = " << InnerIndex + << ", OuterIndex = " << OuterIndex << "\n"); + if (InnerIndex < OuterIndex) + return true; + } else { + // Legacy cost model: this is rough cost estimation algorithm. It counts the + // good and bad order of induction variables in the instruction and allows + // reordering if number of bad orders is more than good. + int Cost = getInstrOrderCost(); + LLVM_DEBUG(dbgs() << "Cost = " << Cost << "\n"); + if (Cost < -LoopInterchangeCostThreshold) + return true; + } // It is not profitable as per current cache profitability model. But check if // we can move this loop outside to improve parallelism. @@ -1150,10 +1194,8 @@ bool LoopInterchangeProfitability::isProfitable(unsigned InnerLoopId, return OptimizationRemarkMissed(DEBUG_TYPE, "InterchangeNotProfitable", InnerLoop->getStartLoc(), InnerLoop->getHeader()) - << "Interchanging loops is too costly (cost=" - << ore::NV("Cost", Cost) << ", threshold=" - << ore::NV("Threshold", LoopInterchangeCostThreshold) - << ") and it does not improve parallelism."; + << "Interchanging loops is too costly and it does not improve " + "parallelism."; }); return false; } @@ -1424,9 +1466,13 @@ static void moveLCSSAPhis(BasicBlock *InnerExit, BasicBlock *InnerHeader, // Incoming values are guaranteed be instructions currently. auto IncI = cast<Instruction>(P.getIncomingValueForBlock(InnerLatch)); + // In case of multi-level nested loops, follow LCSSA to find the incoming + // value defined from the innermost loop. + auto IncIInnerMost = cast<Instruction>(followLCSSA(IncI)); // Skip phis with incoming values from the inner loop body, excluding the // header and latch. 
- if (IncI->getParent() != InnerLatch && IncI->getParent() != InnerHeader) + if (IncIInnerMost->getParent() != InnerLatch && + IncIInnerMost->getParent() != InnerHeader) continue; assert(all_of(P.users(), @@ -1695,8 +1741,8 @@ struct LoopInterchangeLegacyPass : public LoopPass { auto *DI = &getAnalysis<DependenceAnalysisWrapperPass>().getDI(); auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - - return LoopInterchange(SE, LI, DI, DT, ORE).run(L); + std::unique_ptr<CacheCost> CC = nullptr; + return LoopInterchange(SE, LI, DI, DT, CC, ORE).run(L); } }; } // namespace @@ -1723,8 +1769,10 @@ PreservedAnalyses LoopInterchangePass::run(LoopNest &LN, Function &F = *LN.getParent(); DependenceInfo DI(&F, &AR.AA, &AR.SE, &AR.LI); + std::unique_ptr<CacheCost> CC = + CacheCost::getCacheCost(LN.getOutermostLoop(), AR, DI); OptimizationRemarkEmitter ORE(&F); - if (!LoopInterchange(&AR.SE, &AR.LI, &DI, &AR.DT, &ORE).run(LN)) + if (!LoopInterchange(&AR.SE, &AR.LI, &DI, &AR.DT, CC, &ORE).run(LN)) return PreservedAnalyses::all(); return getLoopPassPreservedAnalyses(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp index 21d59936616b..1877ac1dfd08 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -61,7 +61,6 @@ #include <algorithm> #include <cassert> #include <forward_list> -#include <set> #include <tuple> #include <utility> @@ -213,7 +212,8 @@ public: continue; // Only progagate the value if they are of the same type. - if (Store->getPointerOperandType() != Load->getPointerOperandType()) + if (Store->getPointerOperandType() != Load->getPointerOperandType() || + getLoadStoreType(Store) != getLoadStoreType(Load)) continue; Candidates.emplace_front(Load, Store); @@ -528,7 +528,7 @@ public: return false; } - if (LAI.getPSE().getUnionPredicate().getComplexity() > + if (LAI.getPSE().getPredicate().getComplexity() > LoadElimSCEVCheckThreshold) { LLVM_DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n"); return false; @@ -539,7 +539,7 @@ public: return false; } - if (!Checks.empty() || !LAI.getPSE().getUnionPredicate().isAlwaysTrue()) { + if (!Checks.empty() || !LAI.getPSE().getPredicate().isAlwaysTrue()) { if (LAI.hasConvergentOp()) { LLVM_DEBUG(dbgs() << "Versioning is needed but not allowed with " "convergent calls\n"); @@ -706,8 +706,12 @@ FunctionPass *llvm::createLoopLoadEliminationPass() { PreservedAnalyses LoopLoadEliminationPass::run(Function &F, FunctionAnalysisManager &AM) { - auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); auto &LI = AM.getResult<LoopAnalysis>(F); + // There are no loops in the function. Return before computing other expensive + // analyses. 
+ if (LI.empty()) + return PreservedAnalyses::all(); + auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); auto &TTI = AM.getResult<TargetIRAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPassManager.cpp index 6c783848432b..d20d275ea60c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPassManager.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPassManager.cpp @@ -8,14 +8,12 @@ #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BranchProbabilityInfo.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/MemorySSA.h" -#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Support/Debug.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Support/TimeProfiler.h" using namespace llvm; @@ -311,12 +309,12 @@ PreservedAnalyses FunctionToLoopPassAdaptor::run(Function &F, #ifndef NDEBUG // LoopAnalysisResults should always be valid. - // Note that we don't LAR.SE.verify() because that can change observed SE - // queries. See PR44815. if (VerifyDomInfo) LAR.DT.verify(); if (VerifyLoopInfo) LAR.LI.verify(LAR.DT); + if (VerifySCEV) + LAR.SE.verify(); if (LAR.MSSA && VerifyMemorySSA) LAR.MSSA->verifyMemorySSA(); #endif diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPredication.cpp index aa7e79a589f2..d0ee5b47a8ca 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPredication.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -188,7 +188,6 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/IR/Function.h" -#include "llvm/IR/GlobalValue.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" @@ -244,7 +243,7 @@ struct LoopICmp { LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV, const SCEV *Limit) : Pred(Pred), IV(IV), Limit(Limit) {} - LoopICmp() {} + LoopICmp() = default; void dump() { dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV << ", Limit = " << *Limit << "\n"; @@ -778,7 +777,7 @@ unsigned LoopPredication::collectChecks(SmallVectorImpl<Value *> &Checks, if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) { if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Guard)) { - Checks.push_back(NewRangeCheck.getValue()); + Checks.push_back(*NewRangeCheck); NumWidened++; continue; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp index 9d22eceb987f..f4ef22562341 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp @@ -29,15 +29,11 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" 
#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" @@ -59,7 +55,6 @@ #include <cassert> #include <cstddef> #include <cstdint> -#include <cstdlib> #include <iterator> #include <map> #include <utility> @@ -559,12 +554,12 @@ bool LoopReroll::isLoopControlIV(Loop *L, Instruction *IV) { } // Must be a CMP or an ext (of a value with nsw) then CMP else { - Instruction *UUser = dyn_cast<Instruction>(UU); + auto *UUser = cast<Instruction>(UU); // Skip SExt if we are extending an nsw value // TODO: Allow ZExt too - if (BO->hasNoSignedWrap() && UUser && UUser->hasOneUse() && + if (BO->hasNoSignedWrap() && UUser->hasOneUse() && isa<SExtInst>(UUser)) - UUser = dyn_cast<Instruction>(*(UUser->user_begin())); + UUser = cast<Instruction>(*(UUser->user_begin())); if (!isCompareUsedByBranch(UUser)) return false; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRotation.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRotation.cpp index 5ba137b1c85f..d9c33b5f335a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRotation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRotation.cpp @@ -11,10 +11,10 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopRotation.h" -#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyBlockFrequencyInfo.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" @@ -22,9 +22,7 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/LoopRotationUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; @@ -62,8 +60,8 @@ PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM, MSSAU = MemorySSAUpdater(AR.MSSA); bool Changed = LoopRotation(&L, &AR.LI, &AR.TTI, &AR.AC, &AR.DT, &AR.SE, - MSSAU.hasValue() ? MSSAU.getPointer() : nullptr, SQ, false, - Threshold, false, PrepareForLTO || PrepareForLTOOption); + MSSAU ? MSSAU.getPointer() : nullptr, SQ, false, Threshold, + false, PrepareForLTO || PrepareForLTOOption); if (!Changed) return PreservedAnalyses::all(); @@ -133,9 +131,8 @@ public: : MaxHeaderSize; return LoopRotation(L, LI, TTI, AC, &DT, &SE, - MSSAU.hasValue() ? MSSAU.getPointer() : nullptr, SQ, - false, Threshold, false, - PrepareForLTO || PrepareForLTOOption); + MSSAU ? 
MSSAU.getPointer() : nullptr, SQ, false, + Threshold, false, PrepareForLTO || PrepareForLTOOption); } }; } // end namespace diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index d3fcba10c275..b7e0e32780b4 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -16,28 +16,21 @@ #include "llvm/Transforms/Scalar/LoopSimplifyCFG.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/DomTreeUpdater.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" -#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" -#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; @@ -261,13 +254,17 @@ private: assert(L.getNumBlocks() == LiveLoopBlocks.size() + DeadLoopBlocks.size() && "Malformed block sets?"); - // Now, all exit blocks that are not marked as live are dead. + // Now, all exit blocks that are not marked as live are dead, if all their + // predecessors are in the loop. This may not be the case, as the input loop + // may not by in loop-simplify/canonical form. SmallVector<BasicBlock *, 8> ExitBlocks; L.getExitBlocks(ExitBlocks); SmallPtrSet<BasicBlock *, 8> UniqueDeadExits; for (auto *ExitBlock : ExitBlocks) if (!LiveExitBlocks.count(ExitBlock) && - UniqueDeadExits.insert(ExitBlock).second) + UniqueDeadExits.insert(ExitBlock).second && + all_of(predecessors(ExitBlock), + [this](BasicBlock *Pred) { return L.contains(Pred); })) DeadExitBlocks.push_back(ExitBlock); // Whether or not the edge From->To will still be present in graph after the @@ -374,7 +371,7 @@ private: DeadInstructions.emplace_back(LandingPad); for (Instruction *I : DeadInstructions) { - I->replaceAllUsesWith(UndefValue::get(I->getType())); + I->replaceAllUsesWith(PoisonValue::get(I->getType())); I->eraseFromParent(); } @@ -704,8 +701,7 @@ PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM, MSSAU = MemorySSAUpdater(AR.MSSA); bool DeleteCurrentLoop = false; if (!simplifyLoopCFG(L, AR.DT, AR.LI, AR.SE, - MSSAU.hasValue() ? MSSAU.getPointer() : nullptr, - DeleteCurrentLoop)) + MSSAU ? MSSAU.getPointer() : nullptr, DeleteCurrentLoop)) return PreservedAnalyses::all(); if (DeleteCurrentLoop) @@ -739,9 +735,9 @@ public: if (MSSAA && VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); bool DeleteCurrentLoop = false; - bool Changed = simplifyLoopCFG( - *L, DT, LI, SE, MSSAU.hasValue() ? MSSAU.getPointer() : nullptr, - DeleteCurrentLoop); + bool Changed = + simplifyLoopCFG(*L, DT, LI, SE, MSSAU ? 
MSSAU.getPointer() : nullptr, + DeleteCurrentLoop); if (DeleteCurrentLoop) LPM.markLoopAsDeleted(*L); return Changed; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSink.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSink.cpp index c9c9e60d0921..dce1af475fb1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSink.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSink.cpp @@ -34,24 +34,18 @@ #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/AliasSetTracker.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/BlockFrequencyInfo.h" -#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Metadata.h" #include "llvm/InitializePasses.h" +#include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; @@ -70,14 +64,6 @@ static cl::opt<unsigned> MaxNumberOfUseBBsForSinking( "max-uses-for-sinking", cl::Hidden, cl::init(30), cl::desc("Do not sink instructions that have too many uses.")); -static cl::opt<bool> EnableMSSAInLoopSink( - "enable-mssa-in-loop-sink", cl::Hidden, cl::init(true), - cl::desc("Enable MemorySSA for LoopSink in new pass manager")); - -static cl::opt<bool> EnableMSSAInLegacyLoopSink( - "enable-mssa-in-legacy-loop-sink", cl::Hidden, cl::init(false), - cl::desc("Enable MemorySSA for LoopSink in legacy pass manager")); - /// Return adjusted total frequency of \p BBs. /// /// * If there is only one BB, sinking instruction will not introduce code @@ -279,9 +265,8 @@ static bool sinkInstruction( static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, DominatorTree &DT, BlockFrequencyInfo &BFI, - ScalarEvolution *SE, - AliasSetTracker *CurAST, - MemorySSA *MSSA) { + MemorySSA &MSSA, + ScalarEvolution *SE) { BasicBlock *Preheader = L.getLoopPreheader(); assert(Preheader && "Expected loop to have preheader"); @@ -297,13 +282,8 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, })) return false; - std::unique_ptr<MemorySSAUpdater> MSSAU; - std::unique_ptr<SinkAndHoistLICMFlags> LICMFlags; - if (MSSA) { - MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); - LICMFlags = - std::make_unique<SinkAndHoistLICMFlags>(/*IsSink=*/true, &L, MSSA); - } + MemorySSAUpdater MSSAU(&MSSA); + SinkAndHoistLICMFlags LICMFlags(/*IsSink=*/true, &L, &MSSA); bool Changed = false; @@ -324,14 +304,15 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, // on B (A appears after B), A needs to be sinked first before B can be // sinked. for (Instruction &I : llvm::make_early_inc_range(llvm::reverse(*Preheader))) { + if (isa<PHINode>(&I)) + continue; // No need to check for instruction's operands are loop invariant. 
assert(L.hasLoopInvariantOperands(&I) && "Insts in a loop's preheader should have loop invariant operands!"); - if (!canSinkOrHoistInst(I, &AA, &DT, &L, CurAST, MSSAU.get(), false, - LICMFlags.get())) + if (!canSinkOrHoistInst(I, &AA, &DT, &L, MSSAU, false, LICMFlags)) continue; if (sinkInstruction(L, I, ColdLoopBBs, LoopBlockNumber, LI, DT, BFI, - MSSAU.get())) + &MSSAU)) Changed = true; } @@ -340,13 +321,6 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, return Changed; } -static void computeAliasSet(Loop &L, BasicBlock &Preheader, - AliasSetTracker &CurAST) { - for (BasicBlock *BB : L.blocks()) - CurAST.add(*BB); - CurAST.add(Preheader); -} - PreservedAnalyses LoopSinkPass::run(Function &F, FunctionAnalysisManager &FAM) { LoopInfo &LI = FAM.getResult<LoopAnalysis>(F); // Nothing to do if there are no loops. @@ -356,10 +330,7 @@ PreservedAnalyses LoopSinkPass::run(Function &F, FunctionAnalysisManager &FAM) { AAResults &AA = FAM.getResult<AAManager>(F); DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); BlockFrequencyInfo &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); - - MemorySSA *MSSA = EnableMSSAInLoopSink - ? &FAM.getResult<MemorySSAAnalysis>(F).getMSSA() - : nullptr; + MemorySSA &MSSA = FAM.getResult<MemorySSAAnalysis>(F).getMSSA(); // We want to do a postorder walk over the loops. Since loops are a tree this // is equivalent to a reversed preorder walk and preorder is easy to compute @@ -381,18 +352,11 @@ PreservedAnalyses LoopSinkPass::run(Function &F, FunctionAnalysisManager &FAM) { if (!Preheader->getParent()->hasProfileData()) continue; - std::unique_ptr<AliasSetTracker> CurAST; - if (!EnableMSSAInLoopSink) { - CurAST = std::make_unique<AliasSetTracker>(AA); - computeAliasSet(L, *Preheader, *CurAST.get()); - } - // Note that we don't pass SCEV here because it is only used to invalidate // loops in SCEV and we don't preserve (or request) SCEV at all making that // unnecessary. - Changed |= sinkLoopInvariantInstructions(L, AA, LI, DT, BFI, - /*ScalarEvolution*/ nullptr, - CurAST.get(), MSSA); + Changed |= sinkLoopInvariantInstructions(L, AA, LI, DT, BFI, MSSA, + /*ScalarEvolution*/ nullptr); } while (!PreorderLoops.empty()); if (!Changed) @@ -400,13 +364,10 @@ PreservedAnalyses LoopSinkPass::run(Function &F, FunctionAnalysisManager &FAM) { PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); + PA.preserve<MemorySSAAnalysis>(); - if (MSSA) { - PA.preserve<MemorySSAAnalysis>(); - - if (VerifyMemorySSA) - MSSA->verifyMemorySSA(); - } + if (VerifyMemorySSA) + MSSA.verifyMemorySSA(); return PA; } @@ -432,24 +393,16 @@ struct LegacyLoopSinkPass : public LoopPass { return false; AAResults &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); auto *SE = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); - std::unique_ptr<AliasSetTracker> CurAST; - MemorySSA *MSSA = nullptr; - if (EnableMSSAInLegacyLoopSink) - MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); - else { - CurAST = std::make_unique<AliasSetTracker>(AA); - computeAliasSet(*L, *Preheader, *CurAST.get()); - } - bool Changed = sinkLoopInvariantInstructions( *L, AA, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), getAnalysis<DominatorTreeWrapperPass>().getDomTree(), getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(), - SE ? &SE->getSE() : nullptr, CurAST.get(), MSSA); + MSSA, SE ? 
&SE->getSE() : nullptr); - if (MSSA && VerifyMemorySSA) - MSSA->verifyMemorySSA(); + if (VerifyMemorySSA) + MSSA.verifyMemorySSA(); return Changed; } @@ -458,10 +411,8 @@ struct LegacyLoopSinkPass : public LoopPass { AU.setPreservesCFG(); AU.addRequired<BlockFrequencyInfoWrapperPass>(); getLoopAnalysisUsage(AU); - if (EnableMSSAInLegacyLoopSink) { - AU.addRequired<MemorySSAWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - } + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); } }; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 654f0d2a03a8..9959e408e2e2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -78,6 +78,7 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/BinaryFormat/Dwarf.h" #include "llvm/Config/llvm-config.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" @@ -91,9 +92,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" -#include "llvm/IR/OperandTraits.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" @@ -114,12 +113,12 @@ #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include <algorithm> #include <cassert> #include <cstddef> #include <cstdint> -#include <cstdlib> #include <iterator> #include <limits> #include <map> @@ -142,10 +141,7 @@ static const unsigned MaxIVUsers = 200; /// the salvaging is not too expensive for the compiler. static const unsigned MaxSCEVSalvageExpressionSize = 64; -// Temporary flag to cleanup congruent phis after LSR phi expansion. -// It's currently disabled until we can determine whether it's truly useful or -// not. The flag should be removed after the v3.0 release. -// This is now needed for ivchains. +// Cleanup congruent phis after LSR phi expansion. static cl::opt<bool> EnablePhiElim( "enable-lsr-phielim", cl::Hidden, cl::init(true), cl::desc("Enable LSR phi elimination")); @@ -481,6 +477,12 @@ void Formula::initialMatch(const SCEV *S, Loop *L, ScalarEvolution &SE) { canonicalize(*L); } +static bool containsAddRecDependentOnLoop(const SCEV *S, const Loop &L) { + return SCEVExprContains(S, [&L](const SCEV *S) { + return isa<SCEVAddRecExpr>(S) && (cast<SCEVAddRecExpr>(S)->getLoop() == &L); + }); +} + /// Check whether or not this formula satisfies the canonical /// representation. /// \see Formula::BaseRegs. @@ -494,18 +496,15 @@ bool Formula::isCanonical(const Loop &L) const { if (Scale == 1 && BaseRegs.empty()) return false; - const SCEVAddRecExpr *SAR = dyn_cast<const SCEVAddRecExpr>(ScaledReg); - if (SAR && SAR->getLoop() == &L) + if (containsAddRecDependentOnLoop(ScaledReg, L)) return true; // If ScaledReg is not a recurrent expr, or it is but its loop is not current // loop, meanwhile BaseRegs contains a recurrent expr reg related with current // loop, we want to swap the reg in BaseRegs with ScaledReg. 
- auto I = find_if(BaseRegs, [&](const SCEV *S) { - return isa<const SCEVAddRecExpr>(S) && - (cast<SCEVAddRecExpr>(S)->getLoop() == &L); + return none_of(BaseRegs, [&L](const SCEV *S) { + return containsAddRecDependentOnLoop(S, L); }); - return I == BaseRegs.end(); } /// Helper method to morph a formula into its canonical representation. @@ -537,11 +536,9 @@ void Formula::canonicalize(const Loop &L) { // If ScaledReg is an invariant with respect to L, find the reg from // BaseRegs containing the recurrent expr related with Loop L. Swap the // reg with ScaledReg. - const SCEVAddRecExpr *SAR = dyn_cast<const SCEVAddRecExpr>(ScaledReg); - if (!SAR || SAR->getLoop() != &L) { - auto I = find_if(BaseRegs, [&](const SCEV *S) { - return isa<const SCEVAddRecExpr>(S) && - (cast<SCEVAddRecExpr>(S)->getLoop() == &L); + if (!containsAddRecDependentOnLoop(ScaledReg, L)) { + auto I = find_if(BaseRegs, [&L](const SCEV *S) { + return containsAddRecDependentOnLoop(S, L); }); if (I != BaseRegs.end()) std::swap(ScaledReg, *I); @@ -1070,7 +1067,7 @@ public: C.ScaleCost = 0; } - bool isLess(Cost &Other); + bool isLess(const Cost &Other); void Lose(); @@ -1358,6 +1355,8 @@ void Cost::RateFormula(const Formula &F, const DenseSet<const SCEV *> &VisitedRegs, const LSRUse &LU, SmallPtrSetImpl<const SCEV *> *LoserRegs) { + if (isLoser()) + return; assert(F.isCanonical(*L) && "Cost is accurate only for canonical formula"); // Tally up the registers. unsigned PrevAddRecCost = C.AddRecCost; @@ -1467,7 +1466,7 @@ void Cost::Lose() { } /// Choose the lower cost. -bool Cost::isLess(Cost &Other) { +bool Cost::isLess(const Cost &Other) { if (InsnsCost.getNumOccurrences() > 0 && InsnsCost && C.Insns != Other.C.Insns) return C.Insns < Other.C.Insns; @@ -4081,23 +4080,24 @@ void LSRInstance::GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base) { continue; // Divide out the factor, ignoring high bits, since we'll be // scaling the value back up in the end. - if (const SCEV *Quotient = getExactSDiv(AR, FactorS, SE, true)) { - // TODO: This could be optimized to avoid all the copying. - Formula F = Base; - F.ScaledReg = Quotient; - F.deleteBaseReg(F.BaseRegs[i]); - // The canonical representation of 1*reg is reg, which is already in - // Base. In that case, do not try to insert the formula, it will be - // rejected anyway. - if (F.Scale == 1 && (F.BaseRegs.empty() || - (AR->getLoop() != L && LU.AllFixupsOutsideLoop))) - continue; - // If AllFixupsOutsideLoop is true and F.Scale is 1, we may generate - // non canonical Formula with ScaledReg's loop not being L. - if (F.Scale == 1 && LU.AllFixupsOutsideLoop) - F.canonicalize(*L); - (void)InsertFormula(LU, LUIdx, F); - } + if (const SCEV *Quotient = getExactSDiv(AR, FactorS, SE, true)) + if (!Quotient->isZero()) { + // TODO: This could be optimized to avoid all the copying. + Formula F = Base; + F.ScaledReg = Quotient; + F.deleteBaseReg(F.BaseRegs[i]); + // The canonical representation of 1*reg is reg, which is already in + // Base. In that case, do not try to insert the formula, it will be + // rejected anyway. + if (F.Scale == 1 && (F.BaseRegs.empty() || + (AR->getLoop() != L && LU.AllFixupsOutsideLoop))) + continue; + // If AllFixupsOutsideLoop is true and F.Scale is 1, we may generate + // non canonical Formula with ScaledReg's loop not being L. 
+ if (F.Scale == 1 && LU.AllFixupsOutsideLoop) + F.canonicalize(*L); + (void)InsertFormula(LU, LUIdx, F); + } } } } @@ -5601,6 +5601,27 @@ void LSRInstance::Rewrite(const LSRUse &LU, const LSRFixup &LF, DeadInsts.emplace_back(OperandIsInstr); } +// Check if there are any loop exit values which are only used once within the +// loop which may potentially be optimized with a call to rewriteLoopExitValue. +static bool LoopExitValHasSingleUse(Loop *L) { + BasicBlock *ExitBB = L->getExitBlock(); + if (!ExitBB) + return false; + + for (PHINode &ExitPhi : ExitBB->phis()) { + if (ExitPhi.getNumIncomingValues() != 1) + break; + + BasicBlock *Pred = ExitPhi.getIncomingBlock(0); + Value *IVNext = ExitPhi.getIncomingValueForBlock(Pred); + // One use would be the exit phi node, and there should be only one other + // use for this to be considered. + if (IVNext->getNumUses() == 2) + return true; + } + return false; +} + /// Rewrite all the fixup locations with new values, following the chosen /// solution. void LSRInstance::ImplementSolution( @@ -5894,40 +5915,57 @@ void LoopStrengthReduce::getAnalysisUsage(AnalysisUsage &AU) const { } namespace { + +/// Enables more convenient iteration over a DWARF expression vector. +static iterator_range<llvm::DIExpression::expr_op_iterator> +ToDwarfOpIter(SmallVectorImpl<uint64_t> &Expr) { + llvm::DIExpression::expr_op_iterator Begin = + llvm::DIExpression::expr_op_iterator(Expr.begin()); + llvm::DIExpression::expr_op_iterator End = + llvm::DIExpression::expr_op_iterator(Expr.end()); + return {Begin, End}; +} + struct SCEVDbgValueBuilder { SCEVDbgValueBuilder() = default; - SCEVDbgValueBuilder(const SCEVDbgValueBuilder &Base) { - Values = Base.Values; + SCEVDbgValueBuilder(const SCEVDbgValueBuilder &Base) { clone(Base); } + + void clone(const SCEVDbgValueBuilder &Base) { + LocationOps = Base.LocationOps; Expr = Base.Expr; } + void clear() { + LocationOps.clear(); + Expr.clear(); + } + /// The DIExpression as we translate the SCEV. SmallVector<uint64_t, 6> Expr; /// The location ops of the DIExpression. - SmallVector<llvm::ValueAsMetadata *, 2> Values; + SmallVector<Value *, 2> LocationOps; void pushOperator(uint64_t Op) { Expr.push_back(Op); } void pushUInt(uint64_t Operand) { Expr.push_back(Operand); } /// Add a DW_OP_LLVM_arg to the expression, followed by the index of the value /// in the set of values referenced by the expression. - void pushValue(llvm::Value *V) { + void pushLocation(llvm::Value *V) { Expr.push_back(llvm::dwarf::DW_OP_LLVM_arg); - auto *It = - std::find(Values.begin(), Values.end(), llvm::ValueAsMetadata::get(V)); + auto *It = std::find(LocationOps.begin(), LocationOps.end(), V); unsigned ArgIndex = 0; - if (It != Values.end()) { - ArgIndex = std::distance(Values.begin(), It); + if (It != LocationOps.end()) { + ArgIndex = std::distance(LocationOps.begin(), It); } else { - ArgIndex = Values.size(); - Values.push_back(llvm::ValueAsMetadata::get(V)); + ArgIndex = LocationOps.size(); + LocationOps.push_back(V); } Expr.push_back(ArgIndex); } void pushValue(const SCEVUnknown *U) { llvm::Value *V = cast<SCEVUnknown>(U)->getValue(); - pushValue(V); + pushLocation(V); } bool pushConst(const SCEVConstant *C) { @@ -5938,6 +5976,12 @@ struct SCEVDbgValueBuilder { return true; } + // Iterating the expression as DWARF ops is convenient when updating + // DWARF_OP_LLVM_args. 
+ iterator_range<llvm::DIExpression::expr_op_iterator> expr_ops() { + return ToDwarfOpIter(Expr); + } + /// Several SCEV types are sequences of the same arithmetic operator applied /// to constants and values that may be extended or truncated. bool pushArithmeticExpr(const llvm::SCEVCommutativeExpr *CommExpr, @@ -5979,7 +6023,7 @@ struct SCEVDbgValueBuilder { } else if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { if (!U->getValue()) return false; - pushValue(U->getValue()); + pushLocation(U->getValue()); } else if (const SCEVMulExpr *MulRec = dyn_cast<SCEVMulExpr>(S)) { Success &= pushArithmeticExpr(MulRec, llvm::dwarf::DW_OP_mul); @@ -6010,52 +6054,6 @@ struct SCEVDbgValueBuilder { return Success; } - void setFinalExpression(llvm::DbgValueInst &DI, const DIExpression *OldExpr) { - // Re-state assumption that this dbg.value is not variadic. Any remaining - // opcodes in its expression operate on a single value already on the - // expression stack. Prepend our operations, which will re-compute and - // place that value on the expression stack. - assert(!DI.hasArgList()); - auto *NewExpr = - DIExpression::prependOpcodes(OldExpr, Expr, /*StackValue*/ true); - DI.setExpression(NewExpr); - - auto ValArrayRef = llvm::ArrayRef<llvm::ValueAsMetadata *>(Values); - DI.setRawLocation(llvm::DIArgList::get(DI.getContext(), ValArrayRef)); - } - - /// If a DVI can be emitted without a DIArgList, omit DW_OP_llvm_arg and the - /// location op index 0. - void setShortFinalExpression(llvm::DbgValueInst &DI, - const DIExpression *OldExpr) { - assert((Expr[0] == llvm::dwarf::DW_OP_LLVM_arg && Expr[1] == 0) && - "Expected DW_OP_llvm_arg and 0."); - DI.replaceVariableLocationOp( - 0u, llvm::MetadataAsValue::get(DI.getContext(), Values[0])); - - // See setFinalExpression: prepend our opcodes on the start of any old - // expression opcodes. - assert(!DI.hasArgList()); - llvm::SmallVector<uint64_t, 6> FinalExpr(llvm::drop_begin(Expr, 2)); - auto *NewExpr = - DIExpression::prependOpcodes(OldExpr, FinalExpr, /*StackValue*/ true); - DI.setExpression(NewExpr); - } - - /// Once the IV and variable SCEV translation is complete, write it to the - /// source DVI. - void applyExprToDbgValue(llvm::DbgValueInst &DI, - const DIExpression *OldExpr) { - assert(!Expr.empty() && "Unexpected empty expression."); - // Emit a simpler form if only a single location is referenced. - if (Values.size() == 1 && Expr[0] == llvm::dwarf::DW_OP_LLVM_arg && - Expr[1] == 0) { - setShortFinalExpression(DI, OldExpr); - } else { - setFinalExpression(DI, OldExpr); - } - } - /// Return true if the combination of arithmetic operator and underlying /// SCEV constant value is an identity function. bool isIdentityFunction(uint64_t Op, const SCEV *S) { @@ -6104,6 +6102,48 @@ struct SCEVDbgValueBuilder { return true; } + /// Create an expression that is an offset from a value (usually the IV). + void createOffsetExpr(int64_t Offset, Value *OffsetValue) { + pushLocation(OffsetValue); + DIExpression::appendOffset(Expr, Offset); + LLVM_DEBUG( + dbgs() << "scev-salvage: Generated IV offset expression. Offset: " + << std::to_string(Offset) << "\n"); + } + + /// Combine a translation of the SCEV and the IV to create an expression that + /// recovers a location's value. + /// returns true if an expression was created. 
+ bool createIterCountExpr(const SCEV *S, + const SCEVDbgValueBuilder &IterationCount, + ScalarEvolution &SE) { + // SCEVs for SSA values are most frquently of the form + // {start,+,stride}, but sometimes they are ({start,+,stride} + %a + ..). + // This is because %a is a PHI node that is not the IV. However, these + // SCEVs have not been observed to result in debuginfo-lossy optimisations, + // so its not expected this point will be reached. + if (!isa<SCEVAddRecExpr>(S)) + return false; + + LLVM_DEBUG(dbgs() << "scev-salvage: Location to salvage SCEV: " << *S + << '\n'); + + const auto *Rec = cast<SCEVAddRecExpr>(S); + if (!Rec->isAffine()) + return false; + + if (S->getExpressionSize() > MaxSCEVSalvageExpressionSize) + return false; + + // Initialise a new builder with the iteration count expression. In + // combination with the value's SCEV this enables recovery. + clone(IterationCount); + if (!SCEVToValueExpr(*Rec, SE)) + return false; + + return true; + } + /// Convert a SCEV of a value to a DIExpression that is pushed onto the /// builder's expression stack. The stack should already contain an /// expression for the iteration count, so that it can be multiplied by @@ -6133,74 +6173,294 @@ struct SCEVDbgValueBuilder { } return true; } + + // Append the current expression and locations to a location list and an + // expression list. Modify the DW_OP_LLVM_arg indexes to account for + // the locations already present in the destination list. + void appendToVectors(SmallVectorImpl<uint64_t> &DestExpr, + SmallVectorImpl<Value *> &DestLocations) { + assert(!DestLocations.empty() && + "Expected the locations vector to contain the IV"); + // The DWARF_OP_LLVM_arg arguments of the expression being appended must be + // modified to account for the locations already in the destination vector. + // All builders contain the IV as the first location op. + assert(!LocationOps.empty() && + "Expected the location ops to contain the IV."); + // DestIndexMap[n] contains the index in DestLocations for the nth + // location in this SCEVDbgValueBuilder. + SmallVector<uint64_t, 2> DestIndexMap; + for (const auto &Op : LocationOps) { + auto It = find(DestLocations, Op); + if (It != DestLocations.end()) { + // Location already exists in DestLocations, reuse existing ArgIndex. + DestIndexMap.push_back(std::distance(DestLocations.begin(), It)); + continue; + } + // Location is not in DestLocations, add it. + DestIndexMap.push_back(DestLocations.size()); + DestLocations.push_back(Op); + } + + for (const auto &Op : expr_ops()) { + if (Op.getOp() != dwarf::DW_OP_LLVM_arg) { + Op.appendToVector(DestExpr); + continue; + } + + DestExpr.push_back(dwarf::DW_OP_LLVM_arg); + // `DW_OP_LLVM_arg n` represents the nth LocationOp in this SCEV, + // DestIndexMap[n] contains its new index in DestLocations. + uint64_t NewIndex = DestIndexMap[Op.getArg(0)]; + DestExpr.push_back(NewIndex); + } + } }; +/// Holds all the required data to salvage a dbg.value using the pre-LSR SCEVs +/// and DIExpression. 
struct DVIRecoveryRec { + DVIRecoveryRec(DbgValueInst *DbgValue) + : DVI(DbgValue), Expr(DbgValue->getExpression()), + HadLocationArgList(false) {} + DbgValueInst *DVI; DIExpression *Expr; - Metadata *LocationOp; - const llvm::SCEV *SCEV; + bool HadLocationArgList; + SmallVector<WeakVH, 2> LocationOps; + SmallVector<const llvm::SCEV *, 2> SCEVs; + SmallVector<std::unique_ptr<SCEVDbgValueBuilder>, 2> RecoveryExprs; + + void clear() { + for (auto &RE : RecoveryExprs) + RE.reset(); + RecoveryExprs.clear(); + } + + ~DVIRecoveryRec() { clear(); } }; } // namespace -static void RewriteDVIUsingIterCount(DVIRecoveryRec CachedDVI, - const SCEVDbgValueBuilder &IterationCount, - ScalarEvolution &SE) { - // LSR may add locations to previously single location-op DVIs which - // are currently not supported. - if (CachedDVI.DVI->getNumVariableLocationOps() != 1) - return; +/// Returns the total number of DW_OP_llvm_arg operands in the expression. +/// This helps in determining if a DIArglist is necessary or can be omitted from +/// the dbg.value. +static unsigned numLLVMArgOps(SmallVectorImpl<uint64_t> &Expr) { + auto expr_ops = ToDwarfOpIter(Expr); + unsigned Count = 0; + for (auto Op : expr_ops) + if (Op.getOp() == dwarf::DW_OP_LLVM_arg) + Count++; + return Count; +} + +/// Overwrites DVI with the location and Ops as the DIExpression. This will +/// create an invalid expression if Ops has any dwarf::DW_OP_llvm_arg operands, +/// because a DIArglist is not created for the first argument of the dbg.value. +static void updateDVIWithLocation(DbgValueInst &DVI, Value *Location, + SmallVectorImpl<uint64_t> &Ops) { + assert( + numLLVMArgOps(Ops) == 0 && + "Expected expression that does not contain any DW_OP_llvm_arg operands."); + DVI.setRawLocation(ValueAsMetadata::get(Location)); + DVI.setExpression(DIExpression::get(DVI.getContext(), Ops)); +} + +/// Overwrite DVI with locations placed into a DIArglist. +static void updateDVIWithLocations(DbgValueInst &DVI, + SmallVectorImpl<Value *> &Locations, + SmallVectorImpl<uint64_t> &Ops) { + assert(numLLVMArgOps(Ops) != 0 && + "Expected expression that references DIArglist locations using " + "DW_OP_llvm_arg operands."); + SmallVector<ValueAsMetadata *, 3> MetadataLocs; + for (Value *V : Locations) + MetadataLocs.push_back(ValueAsMetadata::get(V)); + auto ValArrayRef = llvm::ArrayRef<llvm::ValueAsMetadata *>(MetadataLocs); + DVI.setRawLocation(llvm::DIArgList::get(DVI.getContext(), ValArrayRef)); + DVI.setExpression(DIExpression::get(DVI.getContext(), Ops)); +} + +/// Write the new expression and new location ops for the dbg.value. If possible +/// reduce the szie of the dbg.value intrinsic by omitting DIArglist. This +/// can be omitted if: +/// 1. There is only a single location, refenced by a single DW_OP_llvm_arg. +/// 2. The DW_OP_LLVM_arg is the first operand in the expression. +static void UpdateDbgValueInst(DVIRecoveryRec &DVIRec, + SmallVectorImpl<Value *> &NewLocationOps, + SmallVectorImpl<uint64_t> &NewExpr) { + unsigned NumLLVMArgs = numLLVMArgOps(NewExpr); + if (NumLLVMArgs == 0) { + // Location assumed to be on the stack. + updateDVIWithLocation(*DVIRec.DVI, NewLocationOps[0], NewExpr); + } else if (NumLLVMArgs == 1 && NewExpr[0] == dwarf::DW_OP_LLVM_arg) { + // There is only a single DW_OP_llvm_arg at the start of the expression, + // so it can be omitted along with DIArglist. 
+ assert(NewExpr[1] == 0 && + "Lone LLVM_arg in a DIExpression should refer to location-op 0."); + llvm::SmallVector<uint64_t, 6> ShortenedOps(llvm::drop_begin(NewExpr, 2)); + updateDVIWithLocation(*DVIRec.DVI, NewLocationOps[0], ShortenedOps); + } else { + // Multiple DW_OP_llvm_arg, so DIArgList is strictly necessary. + updateDVIWithLocations(*DVIRec.DVI, NewLocationOps, NewExpr); + } + + // If the DIExpression was previously empty then add the stack terminator. + // Non-empty expressions have only had elements inserted into them and so the + // terminator should already be present e.g. stack_value or fragment. + DIExpression *SalvageExpr = DVIRec.DVI->getExpression(); + if (!DVIRec.Expr->isComplex() && SalvageExpr->isComplex()) { + SalvageExpr = DIExpression::append(SalvageExpr, {dwarf::DW_OP_stack_value}); + DVIRec.DVI->setExpression(SalvageExpr); + } +} + +/// Cached location ops may be erased during LSR, in which case an undef is +/// required when restoring from the cache. The type of that location is no +/// longer available, so just use int8. The undef will be replaced by one or +/// more locations later when a SCEVDbgValueBuilder selects alternative +/// locations to use for the salvage. +static Value *getValueOrUndef(WeakVH &VH, LLVMContext &C) { + return (VH) ? VH : UndefValue::get(llvm::Type::getInt8Ty(C)); +} + +/// Restore the DVI's pre-LSR arguments. Substitute undef for any erased values. +static void restorePreTransformState(DVIRecoveryRec &DVIRec) { + LLVM_DEBUG(dbgs() << "scev-salvage: restore dbg.value to pre-LSR state\n" + << "scev-salvage: post-LSR: " << *DVIRec.DVI << '\n'); + assert(DVIRec.Expr && "Expected an expression"); + DVIRec.DVI->setExpression(DVIRec.Expr); + + // Even a single location-op may be inside a DIArgList and referenced with + // DW_OP_LLVM_arg, which is valid only with a DIArgList. + if (!DVIRec.HadLocationArgList) { + assert(DVIRec.LocationOps.size() == 1 && + "Unexpected number of location ops."); + // LSR's unsuccessful salvage attempt may have added DIArgList, which in + // this case was not present before, so force the location back to a single + // uncontained Value. + Value *CachedValue = + getValueOrUndef(DVIRec.LocationOps[0], DVIRec.DVI->getContext()); + DVIRec.DVI->setRawLocation(ValueAsMetadata::get(CachedValue)); + } else { + SmallVector<ValueAsMetadata *, 3> MetadataLocs; + for (WeakVH VH : DVIRec.LocationOps) { + Value *CachedValue = getValueOrUndef(VH, DVIRec.DVI->getContext()); + MetadataLocs.push_back(ValueAsMetadata::get(CachedValue)); + } + auto ValArrayRef = llvm::ArrayRef<llvm::ValueAsMetadata *>(MetadataLocs); + DVIRec.DVI->setRawLocation( + llvm::DIArgList::get(DVIRec.DVI->getContext(), ValArrayRef)); + } + LLVM_DEBUG(dbgs() << "scev-salvage: pre-LSR: " << *DVIRec.DVI << '\n'); +} - // SCEVs for SSA values are most frquently of the form - // {start,+,stride}, but sometimes they are ({start,+,stride} + %a + ..). - // This is because %a is a PHI node that is not the IV. However, these - // SCEVs have not been observed to result in debuginfo-lossy optimisations, - // so its not expected this point will be reached. 
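For the offset-based salvage path, the builder assembles only a location reference plus a constant adjustment, and the single-location shortcut in UpdateDbgValueInst then collapses the result back to an ordinary one-operand dbg.value. A minimal sketch of those raw operands, using only calls that appear in the patch (not part of the patch itself; the offset of 8 and the helper name are arbitrary examples):

#include "llvm/ADT/SmallVector.h"
#include "llvm/BinaryFormat/Dwarf.h"
#include "llvm/IR/DebugInfoMetadata.h"

// Mirror what createOffsetExpr assembles for "dead value = IV + 8".
llvm::SmallVector<uint64_t, 6> buildOffsetSalvageOps() {
  llvm::SmallVector<uint64_t, 6> Expr;
  Expr.push_back(llvm::dwarf::DW_OP_LLVM_arg); // refer to a location operand...
  Expr.push_back(0);                           // ...at index 0, the chosen IV
  llvm::DIExpression::appendOffset(Expr, 8);   // append the constant difference
  // With one location and a leading DW_OP_LLVM_arg 0, UpdateDbgValueInst later
  // strips the arg prefix and avoids emitting a DIArgList at all.
  return Expr;
}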
- if (!isa<SCEVAddRecExpr>(CachedDVI.SCEV)) - return; +static bool SalvageDVI(llvm::Loop *L, ScalarEvolution &SE, + llvm::PHINode *LSRInductionVar, DVIRecoveryRec &DVIRec, + const SCEV *SCEVInductionVar, + SCEVDbgValueBuilder IterCountExpr) { + if (!DVIRec.DVI->isUndef()) + return false; - LLVM_DEBUG(dbgs() << "scev-salvage: Value to salvage SCEV: " - << *CachedDVI.SCEV << '\n'); + // LSR may have caused several changes to the dbg.value in the failed salvage + // attempt. So restore the DIExpression, the location ops and also the + // location ops format, which is always DIArglist for multiple ops, but only + // sometimes for a single op. + restorePreTransformState(DVIRec); + + // LocationOpIndexMap[i] will store the post-LSR location index of + // the non-optimised out location at pre-LSR index i. + SmallVector<int64_t, 2> LocationOpIndexMap; + LocationOpIndexMap.assign(DVIRec.LocationOps.size(), -1); + SmallVector<Value *, 2> NewLocationOps; + NewLocationOps.push_back(LSRInductionVar); + + for (unsigned i = 0; i < DVIRec.LocationOps.size(); i++) { + WeakVH VH = DVIRec.LocationOps[i]; + // Place the locations not optimised out in the list first, avoiding + // inserts later. The map is used to update the DIExpression's + // DW_OP_LLVM_arg arguments as the expression is updated. + if (VH && !isa<UndefValue>(VH)) { + NewLocationOps.push_back(VH); + LocationOpIndexMap[i] = NewLocationOps.size() - 1; + LLVM_DEBUG(dbgs() << "scev-salvage: Location index " << i + << " now at index " << LocationOpIndexMap[i] << "\n"); + continue; + } - const auto *Rec = cast<SCEVAddRecExpr>(CachedDVI.SCEV); - if (!Rec->isAffine()) - return; + // It's possible that a value referred to in the SCEV may have been + // optimised out by LSR. + if (SE.containsErasedValue(DVIRec.SCEVs[i]) || + SE.containsUndefs(DVIRec.SCEVs[i])) { + LLVM_DEBUG(dbgs() << "scev-salvage: SCEV for location at index: " << i + << " refers to a location that is now undef or erased. " + "Salvage abandoned.\n"); + return false; + } - if (CachedDVI.SCEV->getExpressionSize() > MaxSCEVSalvageExpressionSize) - return; + LLVM_DEBUG(dbgs() << "scev-salvage: salvaging location at index " << i + << " with SCEV: " << *DVIRec.SCEVs[i] << "\n"); + + DVIRec.RecoveryExprs[i] = std::make_unique<SCEVDbgValueBuilder>(); + SCEVDbgValueBuilder *SalvageExpr = DVIRec.RecoveryExprs[i].get(); + + // Create an offset-based salvage expression if possible, as it requires + // less DWARF ops than an iteration count-based expression. + if (Optional<APInt> Offset = + SE.computeConstantDifference(DVIRec.SCEVs[i], SCEVInductionVar)) { + if (Offset.getValue().getMinSignedBits() <= 64) + SalvageExpr->createOffsetExpr(Offset.getValue().getSExtValue(), + LSRInductionVar); + } else if (!SalvageExpr->createIterCountExpr(DVIRec.SCEVs[i], IterCountExpr, + SE)) + return false; + } - // Initialise a new builder with the iteration count expression. In - // combination with the value's SCEV this enables recovery. - SCEVDbgValueBuilder RecoverValue(IterationCount); - if (!RecoverValue.SCEVToValueExpr(*Rec, SE)) - return; + // Merge the DbgValueBuilder generated expressions and the original + // DIExpression, place the result into an new vector. 
+ SmallVector<uint64_t, 3> NewExpr; + if (DVIRec.Expr->getNumElements() == 0) { + assert(DVIRec.RecoveryExprs.size() == 1 && + "Expected only a single recovery expression for an empty " + "DIExpression."); + assert(DVIRec.RecoveryExprs[0] && + "Expected a SCEVDbgSalvageBuilder for location 0"); + SCEVDbgValueBuilder *B = DVIRec.RecoveryExprs[0].get(); + B->appendToVectors(NewExpr, NewLocationOps); + } + for (const auto &Op : DVIRec.Expr->expr_ops()) { + // Most Ops needn't be updated. + if (Op.getOp() != dwarf::DW_OP_LLVM_arg) { + Op.appendToVector(NewExpr); + continue; + } - LLVM_DEBUG(dbgs() << "scev-salvage: Updating: " << *CachedDVI.DVI << '\n'); - RecoverValue.applyExprToDbgValue(*CachedDVI.DVI, CachedDVI.Expr); - LLVM_DEBUG(dbgs() << "scev-salvage: to: " << *CachedDVI.DVI << '\n'); -} + uint64_t LocationArgIndex = Op.getArg(0); + SCEVDbgValueBuilder *DbgBuilder = + DVIRec.RecoveryExprs[LocationArgIndex].get(); + // The location doesn't have s SCEVDbgValueBuilder, so LSR did not + // optimise it away. So just translate the argument to the updated + // location index. + if (!DbgBuilder) { + NewExpr.push_back(dwarf::DW_OP_LLVM_arg); + assert(LocationOpIndexMap[Op.getArg(0)] != -1 && + "Expected a positive index for the location-op position."); + NewExpr.push_back(LocationOpIndexMap[Op.getArg(0)]); + continue; + } + // The location has a recovery expression. + DbgBuilder->appendToVectors(NewExpr, NewLocationOps); + } -static void RewriteDVIUsingOffset(DVIRecoveryRec &DVIRec, llvm::PHINode &IV, - int64_t Offset) { - assert(!DVIRec.DVI->hasArgList() && "Expected single location-op dbg.value."); - DbgValueInst *DVI = DVIRec.DVI; - SmallVector<uint64_t, 8> Ops; - DIExpression::appendOffset(Ops, Offset); - DIExpression *Expr = DIExpression::prependOpcodes(DVIRec.Expr, Ops, true); - LLVM_DEBUG(dbgs() << "scev-salvage: Updating: " << *DVIRec.DVI << '\n'); - DVI->setExpression(Expr); - llvm::Value *ValIV = dyn_cast<llvm::Value>(&IV); - DVI->replaceVariableLocationOp( - 0u, llvm::MetadataAsValue::get(DVI->getContext(), - llvm::ValueAsMetadata::get(ValIV))); - LLVM_DEBUG(dbgs() << "scev-salvage: updated with offset to IV: " - << *DVIRec.DVI << '\n'); + UpdateDbgValueInst(DVIRec, NewLocationOps, NewExpr); + LLVM_DEBUG(dbgs() << "scev-salvage: Updated DVI: " << *DVIRec.DVI << "\n"); + return true; } +/// Obtain an expression for the iteration count, then attempt to salvage the +/// dbg.value intrinsics. static void DbgRewriteSalvageableDVIs(llvm::Loop *L, ScalarEvolution &SE, llvm::PHINode *LSRInductionVar, - SmallVector<DVIRecoveryRec, 2> &DVIToUpdate) { + SmallVector<std::unique_ptr<DVIRecoveryRec>, 2> &DVIToUpdate) { if (DVIToUpdate.empty()) return; @@ -6213,49 +6473,22 @@ DbgRewriteSalvageableDVIs(llvm::Loop *L, ScalarEvolution &SE, if (!IVAddRec->isAffine()) return; + // Prevent translation using excessive resources. if (IVAddRec->getExpressionSize() > MaxSCEVSalvageExpressionSize) return; // The iteration count is required to recover location values. SCEVDbgValueBuilder IterCountExpr; - IterCountExpr.pushValue(LSRInductionVar); + IterCountExpr.pushLocation(LSRInductionVar); if (!IterCountExpr.SCEVToIterCountExpr(*IVAddRec, SE)) return; LLVM_DEBUG(dbgs() << "scev-salvage: IV SCEV: " << *SCEVInductionVar << '\n'); - // Needn't salvage if the location op hasn't been undef'd by LSR. for (auto &DVIRec : DVIToUpdate) { - if (!DVIRec.DVI->isUndef()) - continue; - - // Some DVIs that were single location-op when cached are now multi-op, - // due to LSR optimisations. 
However, multi-op salvaging is not yet - // supported by SCEV salvaging. But, we can attempt a salvage by restoring - // the pre-LSR single-op expression. - if (DVIRec.DVI->hasArgList()) { - if (!DVIRec.DVI->getVariableLocationOp(0)) - continue; - llvm::Type *Ty = DVIRec.DVI->getVariableLocationOp(0)->getType(); - DVIRec.DVI->setRawLocation( - llvm::ValueAsMetadata::get(UndefValue::get(Ty))); - DVIRec.DVI->setExpression(DVIRec.Expr); - } - - LLVM_DEBUG(dbgs() << "scev-salvage: value to recover SCEV: " - << *DVIRec.SCEV << '\n'); - - // Create a simple expression if the IV and value to salvage SCEVs - // start values differ by only a constant value. - if (Optional<APInt> Offset = - SE.computeConstantDifference(DVIRec.SCEV, SCEVInductionVar)) { - if (Offset.getValue().getMinSignedBits() <= 64) - RewriteDVIUsingOffset(DVIRec, *LSRInductionVar, - Offset.getValue().getSExtValue()); - } else { - RewriteDVIUsingIterCount(DVIRec, IterCountExpr, SE); - } + SalvageDVI(L, SE, LSRInductionVar, *DVIRec, SCEVInductionVar, + IterCountExpr); } } } @@ -6263,39 +6496,53 @@ DbgRewriteSalvageableDVIs(llvm::Loop *L, ScalarEvolution &SE, /// Identify and cache salvageable DVI locations and expressions along with the /// corresponding SCEV(s). Also ensure that the DVI is not deleted between /// cacheing and salvaging. -static void -DbgGatherSalvagableDVI(Loop *L, ScalarEvolution &SE, - SmallVector<DVIRecoveryRec, 2> &SalvageableDVISCEVs, - SmallSet<AssertingVH<DbgValueInst>, 2> &DVIHandles) { +static void DbgGatherSalvagableDVI( + Loop *L, ScalarEvolution &SE, + SmallVector<std::unique_ptr<DVIRecoveryRec>, 2> &SalvageableDVISCEVs, + SmallSet<AssertingVH<DbgValueInst>, 2> &DVIHandles) { for (auto &B : L->getBlocks()) { for (auto &I : *B) { auto DVI = dyn_cast<DbgValueInst>(&I); if (!DVI) continue; - + // Ensure that if any location op is undef that the dbg.vlue is not + // cached. if (DVI->isUndef()) continue; - if (DVI->hasArgList()) - continue; + // Check that the location op SCEVs are suitable for translation to + // DIExpression. + const auto &HasTranslatableLocationOps = + [&](const DbgValueInst *DVI) -> bool { + for (const auto LocOp : DVI->location_ops()) { + if (!LocOp) + return false; - if (!DVI->getVariableLocationOp(0) || - !SE.isSCEVable(DVI->getVariableLocationOp(0)->getType())) - continue; + if (!SE.isSCEVable(LocOp->getType())) + return false; - // SCEVUnknown wraps an llvm::Value, it does not have a start and stride. - // Therefore no translation to DIExpression is performed. - const SCEV *S = SE.getSCEV(DVI->getVariableLocationOp(0)); - if (isa<SCEVUnknown>(S)) - continue; + const SCEV *S = SE.getSCEV(LocOp); + if (SE.containsUndefs(S)) + return false; + } + return true; + }; - // Avoid wasting resources generating an expression containing undef. - if (SE.containsUndefs(S)) + if (!HasTranslatableLocationOps(DVI)) continue; - SalvageableDVISCEVs.push_back( - {DVI, DVI->getExpression(), DVI->getRawLocation(), - SE.getSCEV(DVI->getVariableLocationOp(0))}); + std::unique_ptr<DVIRecoveryRec> NewRec = + std::make_unique<DVIRecoveryRec>(DVI); + // Each location Op may need a SCEVDbgValueBuilder in order to recover it. + // Pre-allocating a vector will enable quick lookups of the builder later + // during the salvage. 
+ NewRec->RecoveryExprs.resize(DVI->getNumVariableLocationOps()); + for (const auto LocOp : DVI->location_ops()) { + NewRec->SCEVs.push_back(SE.getSCEV(LocOp)); + NewRec->LocationOps.push_back(LocOp); + NewRec->HadLocationArgList = DVI->hasArgList(); + } + SalvageableDVISCEVs.push_back(std::move(NewRec)); DVIHandles.insert(DVI); } } @@ -6344,9 +6591,9 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, // Debug preservation - before we start removing anything identify which DVI // meet the salvageable criteria and store their DIExpression and SCEVs. - SmallVector<DVIRecoveryRec, 2> SalvageableDVI; + SmallVector<std::unique_ptr<DVIRecoveryRec>, 2> SalvageableDVIRecords; SmallSet<AssertingVH<DbgValueInst>, 2> DVIHandles; - DbgGatherSalvagableDVI(L, SE, SalvageableDVI, DVIHandles); + DbgGatherSalvagableDVI(L, SE, SalvageableDVIRecords, DVIHandles); bool Changed = false; std::unique_ptr<MemorySSAUpdater> MSSAU; @@ -6375,8 +6622,26 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); } } + // LSR may at times remove all uses of an induction variable from a loop. + // The only remaining use is the PHI in the exit block. + // When this is the case, if the exit value of the IV can be calculated using + // SCEV, we can replace the exit block PHI with the final value of the IV and + // skip the updates in each loop iteration. + if (L->isRecursivelyLCSSAForm(DT, LI) && LoopExitValHasSingleUse(L)) { + SmallVector<WeakTrackingVH, 16> DeadInsts; + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + SCEVExpander Rewriter(SE, DL, "lsr", false); + int Rewrites = rewriteLoopExitValues(L, &LI, &TLI, &SE, &TTI, Rewriter, &DT, + OnlyCheapRepl, DeadInsts); + if (Rewrites) { + Changed = true; + RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts, &TLI, + MSSAU.get()); + DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); + } + } - if (SalvageableDVI.empty()) + if (SalvageableDVIRecords.empty()) return Changed; // Obtain relevant IVs and attempt to rewrite the salvageable DVIs with @@ -6384,13 +6649,16 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, // TODO: Allow for multiple IV references for nested AddRecSCEVs for (auto &L : LI) { if (llvm::PHINode *IV = GetInductionVariable(*L, SE, Reducer)) - DbgRewriteSalvageableDVIs(L, SE, IV, SalvageableDVI); + DbgRewriteSalvageableDVIs(L, SE, IV, SalvageableDVIRecords); else { LLVM_DEBUG(dbgs() << "scev-salvage: SCEV salvaging not possible. 
An IV " "could not be identified.\n"); } } + for (auto &Rec : SalvageableDVIRecords) + Rec->clear(); + SalvageableDVIRecords.clear(); DVIHandles.clear(); return Changed; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp index 1ecbb86724e1..8c2868563227 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp @@ -22,6 +22,7 @@ #include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopNestAnalysis.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -42,10 +43,8 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils.h" -#include "llvm/Transforms/Utils/LCSSA.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/LoopPeel.h" -#include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/UnrollLoop.h" #include <cassert> @@ -331,14 +330,23 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, SmallPtrSet<const Value *, 32> EphValues; CodeMetrics::collectEphemeralValues(L, &AC, EphValues); Loop *SubLoop = L->getSubLoops()[0]; - unsigned InnerLoopSize = + InstructionCost InnerLoopSizeIC = ApproximateLoopSize(SubLoop, NumInlineCandidates, NotDuplicatable, Convergent, TTI, EphValues, UP.BEInsns); - unsigned OuterLoopSize = + InstructionCost OuterLoopSizeIC = ApproximateLoopSize(L, NumInlineCandidates, NotDuplicatable, Convergent, TTI, EphValues, UP.BEInsns); - LLVM_DEBUG(dbgs() << " Outer Loop Size: " << OuterLoopSize << "\n"); - LLVM_DEBUG(dbgs() << " Inner Loop Size: " << InnerLoopSize << "\n"); + LLVM_DEBUG(dbgs() << " Outer Loop Size: " << OuterLoopSizeIC << "\n"); + LLVM_DEBUG(dbgs() << " Inner Loop Size: " << InnerLoopSizeIC << "\n"); + + if (!InnerLoopSizeIC.isValid() || !OuterLoopSizeIC.isValid()) { + LLVM_DEBUG(dbgs() << " Not unrolling loop which contains instructions" + << " with invalid cost.\n"); + return LoopUnrollResult::Unmodified; + } + unsigned InnerLoopSize = *InnerLoopSizeIC.getValue(); + unsigned OuterLoopSize = *OuterLoopSizeIC.getValue(); + if (NotDuplicatable) { LLVM_DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable " "instructions.\n"); @@ -364,7 +372,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, Optional<MDNode *> NewInnerEpilogueLoopID = makeFollowupLoopID( OrigOuterLoopID, {LLVMLoopUnrollAndJamFollowupAll, LLVMLoopUnrollAndJamFollowupRemainderInner}); - if (NewInnerEpilogueLoopID.hasValue()) + if (NewInnerEpilogueLoopID) SubLoop->setLoopID(NewInnerEpilogueLoopID.getValue()); // Find trip count and trip multiple @@ -394,14 +402,14 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, Optional<MDNode *> NewOuterEpilogueLoopID = makeFollowupLoopID( OrigOuterLoopID, {LLVMLoopUnrollAndJamFollowupAll, LLVMLoopUnrollAndJamFollowupRemainderOuter}); - if (NewOuterEpilogueLoopID.hasValue()) + if (NewOuterEpilogueLoopID) EpilogueOuterLoop->setLoopID(NewOuterEpilogueLoopID.getValue()); } Optional<MDNode *> NewInnerLoopID = makeFollowupLoopID(OrigOuterLoopID, {LLVMLoopUnrollAndJamFollowupAll, LLVMLoopUnrollAndJamFollowupInner}); - if 
(NewInnerLoopID.hasValue()) + if (NewInnerLoopID) SubLoop->setLoopID(NewInnerLoopID.getValue()); else SubLoop->setLoopID(OrigSubLoopID); @@ -410,7 +418,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, Optional<MDNode *> NewOuterLoopID = makeFollowupLoopID( OrigOuterLoopID, {LLVMLoopUnrollAndJamFollowupAll, LLVMLoopUnrollAndJamFollowupOuter}); - if (NewOuterLoopID.hasValue()) { + if (NewOuterLoopID) { L->setLoopID(NewOuterLoopID.getValue()); // Do not setLoopAlreadyUnrolled if a followup was given. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp index 9beb2281cf0f..fda86afe5f9d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -25,7 +25,6 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CodeMetrics.h" -#include "llvm/Analysis/LazyBlockFrequencyInfo.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" @@ -133,7 +132,7 @@ static cl::opt<bool> UnrollAllowRemainder( "when unrolling a loop.")); static cl::opt<bool> - UnrollRuntime("unroll-runtime", cl::ZeroOrMore, cl::Hidden, + UnrollRuntime("unroll-runtime", cl::Hidden, cl::desc("Unroll loops with run-time trip counts")); static cl::opt<unsigned> UnrollMaxUpperBound( @@ -254,19 +253,19 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences( UP.MaxIterationsCountToAnalyze = UnrollMaxIterationsCountToAnalyze; // Apply user values provided by argument - if (UserThreshold.hasValue()) { + if (UserThreshold) { UP.Threshold = *UserThreshold; UP.PartialThreshold = *UserThreshold; } - if (UserCount.hasValue()) + if (UserCount) UP.Count = *UserCount; - if (UserAllowPartial.hasValue()) + if (UserAllowPartial) UP.Partial = *UserAllowPartial; - if (UserRuntime.hasValue()) + if (UserRuntime) UP.Runtime = *UserRuntime; - if (UserUpperBound.hasValue()) + if (UserUpperBound) UP.UpperBound = *UserUpperBound; - if (UserFullUnrollMaxCount.hasValue()) + if (UserFullUnrollMaxCount) UP.FullUnrollMaxCount = *UserFullUnrollMaxCount; return UP; @@ -664,7 +663,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost( } /// ApproximateLoopSize - Approximate the size of the loop. -unsigned llvm::ApproximateLoopSize( +InstructionCost llvm::ApproximateLoopSize( const Loop *L, unsigned &NumCalls, bool &NotDuplicatable, bool &Convergent, const TargetTransformInfo &TTI, const SmallPtrSetImpl<const Value *> &EphValues, unsigned BEInsns) { @@ -675,7 +674,7 @@ unsigned llvm::ApproximateLoopSize( NotDuplicatable = Metrics.notDuplicatable; Convergent = Metrics.convergent; - unsigned LoopSize = Metrics.NumInsts; + InstructionCost LoopSize = Metrics.NumInsts; // Don't allow an estimate of size zero. This would allows unrolling of loops // with huge iteration counts, which is a compile time problem even if it's @@ -683,7 +682,9 @@ unsigned llvm::ApproximateLoopSize( // that each loop has at least three instructions (likely a conditional // branch, a comparison feeding that branch, and some kind of loop increment // feeding that comparison instruction). 
- LoopSize = std::max(LoopSize, BEInsns + 1); + if (LoopSize.isValid() && *LoopSize.getValue() < BEInsns + 1) + // This is an open coded max() on InstructionCost + LoopSize = BEInsns + 1; return LoopSize; } @@ -788,15 +789,13 @@ shouldPragmaUnroll(Loop *L, const PragmaInfo &PInfo, // 2nd priority is unroll count set by pragma. if (PInfo.PragmaCount > 0) { - if ((UP.AllowRemainder || (TripMultiple % PInfo.PragmaCount == 0)) && - UCE.getUnrolledLoopSize(UP, PInfo.PragmaCount) < PragmaUnrollThreshold) + if ((UP.AllowRemainder || (TripMultiple % PInfo.PragmaCount == 0))) return PInfo.PragmaCount; } - if (PInfo.PragmaFullUnroll && TripCount != 0) { - if (UCE.getUnrolledLoopSize(UP, TripCount) < PragmaUnrollThreshold) - return TripCount; - } + if (PInfo.PragmaFullUnroll && TripCount != 0) + return TripCount; + // if didn't return until here, should continue to other priorties return None; } @@ -912,7 +911,7 @@ bool llvm::computeUnrollCount( if (PP.PeelCount) { if (UnrollCount.getNumOccurrences() > 0) { report_fatal_error("Cannot specify both explicit peel count and " - "explicit unroll count"); + "explicit unroll count", /*GenCrashDiag=*/false); } UP.Count = 1; UP.Runtime = false; @@ -1192,10 +1191,18 @@ static LoopUnrollResult tryToUnrollLoop( SmallPtrSet<const Value *, 32> EphValues; CodeMetrics::collectEphemeralValues(L, &AC, EphValues); - unsigned LoopSize = + InstructionCost LoopSizeIC = ApproximateLoopSize(L, NumInlineCandidates, NotDuplicatable, Convergent, TTI, EphValues, UP.BEInsns); - LLVM_DEBUG(dbgs() << " Loop Size = " << LoopSize << "\n"); + LLVM_DEBUG(dbgs() << " Loop Size = " << LoopSizeIC << "\n"); + + if (!LoopSizeIC.isValid()) { + LLVM_DEBUG(dbgs() << " Not unrolling loop which contains instructions" + << " with invalid cost.\n"); + return LoopUnrollResult::Unmodified; + } + unsigned LoopSize = *LoopSizeIC.getValue(); + if (NotDuplicatable) { LLVM_DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable" << " instructions.\n"); @@ -1316,7 +1323,7 @@ static LoopUnrollResult tryToUnrollLoop( Optional<MDNode *> RemainderLoopID = makeFollowupLoopID(OrigLoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupRemainder}); - if (RemainderLoopID.hasValue()) + if (RemainderLoopID) RemainderLoop->setLoopID(RemainderLoopID.getValue()); } @@ -1324,7 +1331,7 @@ static LoopUnrollResult tryToUnrollLoop( Optional<MDNode *> NewLoopID = makeFollowupLoopID(OrigLoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupUnrolled}); - if (NewLoopID.hasValue()) { + if (NewLoopID) { L->setLoopID(NewLoopID.getValue()); // Do not setLoopAlreadyUnrolled if loop attributes have been specified @@ -1548,8 +1555,12 @@ PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM, PreservedAnalyses LoopUnrollPass::run(Function &F, FunctionAnalysisManager &AM) { - auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); auto &LI = AM.getResult<LoopAnalysis>(F); + // There are no loops in the function. Return before computing other expensive + // analyses. 
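Both unrolling passes now carry the estimated loop size as an InstructionCost rather than an unsigned, so a target reporting an un-costable instruction invalidates the whole estimate and the loop is skipped. A small sketch of that validity-then-clamp pattern (not part of the patch; the helper name is invented):

#include "llvm/Support/InstructionCost.h"

// Convert an InstructionCost-valued loop size back to the unsigned value the
// unrolling heuristics consume, mirroring the checks added above: reject
// invalid costs, and never report fewer than BEInsns + 1 instructions.
static unsigned clampLoopSize(llvm::InstructionCost Size, unsigned BEInsns,
                              bool &Invalid) {
  Invalid = !Size.isValid();
  if (Invalid)
    return 0; // caller bails out with LoopUnrollResult::Unmodified
  unsigned Val = static_cast<unsigned>(*Size.getValue());
  return Val < BEInsns + 1 ? BEInsns + 1 : Val; // open-coded max
}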
+ if (LI.empty()) + return PreservedAnalyses::all(); + auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); auto &TTI = AM.getResult<TargetIRAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp deleted file mode 100644 index 76bb5497c2c2..000000000000 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp +++ /dev/null @@ -1,1774 +0,0 @@ -//===- LoopUnswitch.cpp - Hoist loop-invariant conditionals in loop -------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This pass transforms loops that contain branches on loop-invariant conditions -// to multiple loops. For example, it turns the left into the right code: -// -// for (...) if (lic) -// A for (...) -// if (lic) A; B; C -// B else -// C for (...) -// A; C -// -// This can increase the size of the code exponentially (doubling it every time -// a loop is unswitched) so we only unswitch if the resultant code will be -// smaller than a threshold. -// -// This pass expects LICM to be run before it to hoist invariant conditions out -// of the loop, to make the unswitching opportunity obvious. -// -//===----------------------------------------------------------------------===// - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/CodeMetrics.h" -#include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/LazyBlockFrequencyInfo.h" -#include "llvm/Analysis/LegacyDivergenceAnalysis.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/LoopIterator.h" -#include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/MemorySSA.h" -#include "llvm/Analysis/MemorySSAUpdater.h" -#include "llvm/Analysis/MustExecute.h" -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/IR/Attributes.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Constant.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instruction.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/User.h" -#include "llvm/IR/Value.h" -#include "llvm/IR/ValueHandle.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/LoopPassManager.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Cloning.h" -#include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/LoopUtils.h" -#include "llvm/Transforms/Utils/ValueMapper.h" -#include <algorithm> -#include <cassert> -#include <map> -#include <set> -#include <tuple> 
-#include <utility> -#include <vector> - -using namespace llvm; - -#define DEBUG_TYPE "loop-unswitch" - -STATISTIC(NumBranches, "Number of branches unswitched"); -STATISTIC(NumSwitches, "Number of switches unswitched"); -STATISTIC(NumGuards, "Number of guards unswitched"); -STATISTIC(NumSelects , "Number of selects unswitched"); -STATISTIC(NumTrivial , "Number of unswitches that are trivial"); -STATISTIC(NumSimplify, "Number of simplifications of unswitched code"); -STATISTIC(TotalInsts, "Total number of instructions analyzed"); - -// The specific value of 100 here was chosen based only on intuition and a -// few specific examples. -static cl::opt<unsigned> -Threshold("loop-unswitch-threshold", cl::desc("Max loop size to unswitch"), - cl::init(100), cl::Hidden); - -static cl::opt<unsigned> - MSSAThreshold("loop-unswitch-memoryssa-threshold", - cl::desc("Max number of memory uses to explore during " - "partial unswitching analysis"), - cl::init(100), cl::Hidden); - -namespace { - - class LUAnalysisCache { - using UnswitchedValsMap = - DenseMap<const SwitchInst *, SmallPtrSet<const Value *, 8>>; - using UnswitchedValsIt = UnswitchedValsMap::iterator; - - struct LoopProperties { - unsigned CanBeUnswitchedCount; - unsigned WasUnswitchedCount; - unsigned SizeEstimation; - UnswitchedValsMap UnswitchedVals; - }; - - // Here we use std::map instead of DenseMap, since we need to keep valid - // LoopProperties pointer for current loop for better performance. - using LoopPropsMap = std::map<const Loop *, LoopProperties>; - using LoopPropsMapIt = LoopPropsMap::iterator; - - LoopPropsMap LoopsProperties; - UnswitchedValsMap *CurLoopInstructions = nullptr; - LoopProperties *CurrentLoopProperties = nullptr; - - // A loop unswitching with an estimated cost above this threshold - // is not performed. MaxSize is turned into unswitching quota for - // the current loop, and reduced correspondingly, though note that - // the quota is returned by releaseMemory() when the loop has been - // processed, so that MaxSize will return to its previous - // value. So in most cases MaxSize will equal the Threshold flag - // when a new loop is processed. An exception to that is that - // MaxSize will have a smaller value while processing nested loops - // that were introduced due to loop unswitching of an outer loop. - // - // FIXME: The way that MaxSize works is subtle and depends on the - // pass manager processing loops and calling releaseMemory() in a - // specific order. It would be good to find a more straightforward - // way of doing what MaxSize does. - unsigned MaxSize; - - public: - LUAnalysisCache() : MaxSize(Threshold) {} - - // Analyze loop. Check its size, calculate is it possible to unswitch - // it. Returns true if we can unswitch this loop. - bool countLoop(const Loop *L, const TargetTransformInfo &TTI, - AssumptionCache *AC); - - // Clean all data related to given loop. - void forgetLoop(const Loop *L); - - // Mark case value as unswitched. - // Since SI instruction can be partly unswitched, in order to avoid - // extra unswitching in cloned loops keep track all unswitched values. - void setUnswitched(const SwitchInst *SI, const Value *V); - - // Check was this case value unswitched before or not. - bool isUnswitched(const SwitchInst *SI, const Value *V); - - // Returns true if another unswitching could be done within the cost - // threshold. - bool costAllowsUnswitching(); - - // Clone all loop-unswitch related loop properties. - // Redistribute unswitching quotas. 
- // Note, that new loop data is stored inside the VMap. - void cloneData(const Loop *NewLoop, const Loop *OldLoop, - const ValueToValueMapTy &VMap); - }; - - class LoopUnswitch : public LoopPass { - LoopInfo *LI; // Loop information - LPPassManager *LPM; - AssumptionCache *AC; - - // Used to check if second loop needs processing after - // rewriteLoopBodyWithConditionConstant rewrites first loop. - std::vector<Loop*> LoopProcessWorklist; - - LUAnalysisCache BranchesInfo; - - bool OptimizeForSize; - bool RedoLoop = false; - - Loop *CurrentLoop = nullptr; - DominatorTree *DT = nullptr; - MemorySSA *MSSA = nullptr; - AAResults *AA = nullptr; - std::unique_ptr<MemorySSAUpdater> MSSAU; - BasicBlock *LoopHeader = nullptr; - BasicBlock *LoopPreheader = nullptr; - - bool SanitizeMemory; - SimpleLoopSafetyInfo SafetyInfo; - - // LoopBlocks contains all of the basic blocks of the loop, including the - // preheader of the loop, the body of the loop, and the exit blocks of the - // loop, in that order. - std::vector<BasicBlock*> LoopBlocks; - // NewBlocks contained cloned copy of basic blocks from LoopBlocks. - std::vector<BasicBlock*> NewBlocks; - - bool HasBranchDivergence; - - public: - static char ID; // Pass ID, replacement for typeid - - explicit LoopUnswitch(bool Os = false, bool HasBranchDivergence = false) - : LoopPass(ID), OptimizeForSize(Os), - HasBranchDivergence(HasBranchDivergence) { - initializeLoopUnswitchPass(*PassRegistry::getPassRegistry()); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override; - bool processCurrentLoop(); - bool isUnreachableDueToPreviousUnswitching(BasicBlock *); - - /// This transformation requires natural loop information & requires that - /// loop preheaders be inserted into the CFG. - /// - void getAnalysisUsage(AnalysisUsage &AU) const override { - // Lazy BFI and BPI are marked as preserved here so Loop Unswitching - // can remain part of the same loop pass as LICM - AU.addPreserved<LazyBlockFrequencyInfoPass>(); - AU.addPreserved<LazyBranchProbabilityInfoPass>(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addRequired<MemorySSAWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - if (HasBranchDivergence) - AU.addRequired<LegacyDivergenceAnalysis>(); - getLoopAnalysisUsage(AU); - } - - private: - void releaseMemory() override { BranchesInfo.forgetLoop(CurrentLoop); } - - void initLoopData() { - LoopHeader = CurrentLoop->getHeader(); - LoopPreheader = CurrentLoop->getLoopPreheader(); - } - - /// Split all of the edges from inside the loop to their exit blocks. - /// Update the appropriate Phi nodes as we do so. 
- void splitExitEdges(Loop *L, - const SmallVectorImpl<BasicBlock *> &ExitBlocks); - - bool tryTrivialLoopUnswitch(bool &Changed); - - bool unswitchIfProfitable(Value *LoopCond, Constant *Val, - Instruction *TI = nullptr, - ArrayRef<Instruction *> ToDuplicate = {}); - void unswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, - BasicBlock *ExitBlock, Instruction *TI); - void unswitchNontrivialCondition(Value *LIC, Constant *OnVal, Loop *L, - Instruction *TI, - ArrayRef<Instruction *> ToDuplicate = {}); - - void rewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, - Constant *Val, bool IsEqual); - - void - emitPreheaderBranchOnCondition(Value *LIC, Constant *Val, - BasicBlock *TrueDest, BasicBlock *FalseDest, - BranchInst *OldBranch, Instruction *TI, - ArrayRef<Instruction *> ToDuplicate = {}); - - void simplifyCode(std::vector<Instruction *> &Worklist, Loop *L); - - /// Given that the Invariant is not equal to Val. Simplify instructions - /// in the loop. - Value *simplifyInstructionWithNotEqual(Instruction *Inst, Value *Invariant, - Constant *Val); - }; - -} // end anonymous namespace - -// Analyze loop. Check its size, calculate is it possible to unswitch -// it. Returns true if we can unswitch this loop. -bool LUAnalysisCache::countLoop(const Loop *L, const TargetTransformInfo &TTI, - AssumptionCache *AC) { - LoopPropsMapIt PropsIt; - bool Inserted; - std::tie(PropsIt, Inserted) = - LoopsProperties.insert(std::make_pair(L, LoopProperties())); - - LoopProperties &Props = PropsIt->second; - - if (Inserted) { - // New loop. - - // Limit the number of instructions to avoid causing significant code - // expansion, and the number of basic blocks, to avoid loops with - // large numbers of branches which cause loop unswitching to go crazy. - // This is a very ad-hoc heuristic. - - SmallPtrSet<const Value *, 32> EphValues; - CodeMetrics::collectEphemeralValues(L, AC, EphValues); - - // FIXME: This is overly conservative because it does not take into - // consideration code simplification opportunities and code that can - // be shared by the resultant unswitched loops. - CodeMetrics Metrics; - for (BasicBlock *BB : L->blocks()) - Metrics.analyzeBasicBlock(BB, TTI, EphValues); - - Props.SizeEstimation = Metrics.NumInsts; - Props.CanBeUnswitchedCount = MaxSize / (Props.SizeEstimation); - Props.WasUnswitchedCount = 0; - MaxSize -= Props.SizeEstimation * Props.CanBeUnswitchedCount; - - if (Metrics.notDuplicatable) { - LLVM_DEBUG(dbgs() << "NOT unswitching loop %" << L->getHeader()->getName() - << ", contents cannot be " - << "duplicated!\n"); - return false; - } - } - - // Be careful. This links are good only before new loop addition. - CurrentLoopProperties = &Props; - CurLoopInstructions = &Props.UnswitchedVals; - - return true; -} - -// Clean all data related to given loop. -void LUAnalysisCache::forgetLoop(const Loop *L) { - LoopPropsMapIt LIt = LoopsProperties.find(L); - - if (LIt != LoopsProperties.end()) { - LoopProperties &Props = LIt->second; - MaxSize += (Props.CanBeUnswitchedCount + Props.WasUnswitchedCount) * - Props.SizeEstimation; - LoopsProperties.erase(LIt); - } - - CurrentLoopProperties = nullptr; - CurLoopInstructions = nullptr; -} - -// Mark case value as unswitched. -// Since SI instruction can be partly unswitched, in order to avoid -// extra unswitching in cloned loops keep track all unswitched values. 
-void LUAnalysisCache::setUnswitched(const SwitchInst *SI, const Value *V) { - (*CurLoopInstructions)[SI].insert(V); -} - -// Check was this case value unswitched before or not. -bool LUAnalysisCache::isUnswitched(const SwitchInst *SI, const Value *V) { - return (*CurLoopInstructions)[SI].count(V); -} - -bool LUAnalysisCache::costAllowsUnswitching() { - return CurrentLoopProperties->CanBeUnswitchedCount > 0; -} - -// Clone all loop-unswitch related loop properties. -// Redistribute unswitching quotas. -// Note, that new loop data is stored inside the VMap. -void LUAnalysisCache::cloneData(const Loop *NewLoop, const Loop *OldLoop, - const ValueToValueMapTy &VMap) { - LoopProperties &NewLoopProps = LoopsProperties[NewLoop]; - LoopProperties &OldLoopProps = *CurrentLoopProperties; - UnswitchedValsMap &Insts = OldLoopProps.UnswitchedVals; - - // Reallocate "can-be-unswitched quota" - - --OldLoopProps.CanBeUnswitchedCount; - ++OldLoopProps.WasUnswitchedCount; - NewLoopProps.WasUnswitchedCount = 0; - unsigned Quota = OldLoopProps.CanBeUnswitchedCount; - NewLoopProps.CanBeUnswitchedCount = Quota / 2; - OldLoopProps.CanBeUnswitchedCount = Quota - Quota / 2; - - NewLoopProps.SizeEstimation = OldLoopProps.SizeEstimation; - - // Clone unswitched values info: - // for new loop switches we clone info about values that was - // already unswitched and has redundant successors. - for (const auto &I : Insts) { - const SwitchInst *OldInst = I.first; - Value *NewI = VMap.lookup(OldInst); - const SwitchInst *NewInst = cast_or_null<SwitchInst>(NewI); - assert(NewInst && "All instructions that are in SrcBB must be in VMap."); - - NewLoopProps.UnswitchedVals[NewInst] = OldLoopProps.UnswitchedVals[OldInst]; - } -} - -char LoopUnswitch::ID = 0; - -INITIALIZE_PASS_BEGIN(LoopUnswitch, "loop-unswitch", "Unswitch loops", - false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LegacyDivergenceAnalysis) -INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) -INITIALIZE_PASS_END(LoopUnswitch, "loop-unswitch", "Unswitch loops", - false, false) - -Pass *llvm::createLoopUnswitchPass(bool Os, bool HasBranchDivergence) { - return new LoopUnswitch(Os, HasBranchDivergence); -} - -/// Operator chain lattice. -enum OperatorChain { - OC_OpChainNone, ///< There is no operator. - OC_OpChainOr, ///< There are only ORs. - OC_OpChainAnd, ///< There are only ANDs. - OC_OpChainMixed ///< There are ANDs and ORs. -}; - -/// Cond is a condition that occurs in L. If it is invariant in the loop, or has -/// an invariant piece, return the invariant. Otherwise, return null. -// -/// NOTE: findLIVLoopCondition will not return a partial LIV by walking up a -/// mixed operator chain, as we can not reliably find a value which will -/// simplify the operator chain. If the chain is AND-only or OR-only, we can use -/// 0 or ~0 to simplify the chain. -/// -/// NOTE: In case a partial LIV and a mixed operator chain, we may be able to -/// simplify the condition itself to a loop variant condition, but at the -/// cost of creating an entirely new loop. -static Value *findLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, - OperatorChain &ParentChain, - DenseMap<Value *, Value *> &Cache, - MemorySSAUpdater *MSSAU) { - auto CacheIt = Cache.find(Cond); - if (CacheIt != Cache.end()) - return CacheIt->second; - - // We started analyze new instruction, increment scanned instructions counter. 
- ++TotalInsts; - - // We can never unswitch on vector conditions. - if (Cond->getType()->isVectorTy()) - return nullptr; - - // Constants should be folded, not unswitched on! - if (isa<Constant>(Cond)) return nullptr; - - // TODO: Handle: br (VARIANT|INVARIANT). - - // Hoist simple values out. - if (L->makeLoopInvariant(Cond, Changed, nullptr, MSSAU)) { - Cache[Cond] = Cond; - return Cond; - } - - // Walk up the operator chain to find partial invariant conditions. - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond)) - if (BO->getOpcode() == Instruction::And || - BO->getOpcode() == Instruction::Or) { - // Given the previous operator, compute the current operator chain status. - OperatorChain NewChain; - switch (ParentChain) { - case OC_OpChainNone: - NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd : - OC_OpChainOr; - break; - case OC_OpChainOr: - NewChain = BO->getOpcode() == Instruction::Or ? OC_OpChainOr : - OC_OpChainMixed; - break; - case OC_OpChainAnd: - NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd : - OC_OpChainMixed; - break; - case OC_OpChainMixed: - NewChain = OC_OpChainMixed; - break; - } - - // If we reach a Mixed state, we do not want to keep walking up as we can not - // reliably find a value that will simplify the chain. With this check, we - // will return null on the first sight of mixed chain and the caller will - // either backtrack to find partial LIV in other operand or return null. - if (NewChain != OC_OpChainMixed) { - // Update the current operator chain type before we search up the chain. - ParentChain = NewChain; - // If either the left or right side is invariant, we can unswitch on this, - // which will cause the branch to go away in one loop and the condition to - // simplify in the other one. - if (Value *LHS = findLIVLoopCondition(BO->getOperand(0), L, Changed, - ParentChain, Cache, MSSAU)) { - Cache[Cond] = LHS; - return LHS; - } - // We did not manage to find a partial LIV in operand(0). Backtrack and try - // operand(1). - ParentChain = NewChain; - if (Value *RHS = findLIVLoopCondition(BO->getOperand(1), L, Changed, - ParentChain, Cache, MSSAU)) { - Cache[Cond] = RHS; - return RHS; - } - } - } - - Cache[Cond] = nullptr; - return nullptr; -} - -/// Cond is a condition that occurs in L. If it is invariant in the loop, or has -/// an invariant piece, return the invariant along with the operator chain type. -/// Otherwise, return null. -static std::pair<Value *, OperatorChain> -findLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, - MemorySSAUpdater *MSSAU) { - DenseMap<Value *, Value *> Cache; - OperatorChain OpChain = OC_OpChainNone; - Value *FCond = findLIVLoopCondition(Cond, L, Changed, OpChain, Cache, MSSAU); - - // In case we do find a LIV, it can not be obtained by walking up a mixed - // operator chain. 
- assert((!FCond || OpChain != OC_OpChainMixed) && - "Do not expect a partial LIV with mixed operator chain"); - return {FCond, OpChain}; -} - -bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPMRef) { - if (skipLoop(L)) - return false; - - AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache( - *L->getHeader()->getParent()); - LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - LPM = &LPMRef; - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); - MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); - CurrentLoop = L; - Function *F = CurrentLoop->getHeader()->getParent(); - - SanitizeMemory = F->hasFnAttribute(Attribute::SanitizeMemory); - if (SanitizeMemory) - SafetyInfo.computeLoopSafetyInfo(L); - - if (VerifyMemorySSA) - MSSA->verifyMemorySSA(); - - bool Changed = false; - do { - assert(CurrentLoop->isLCSSAForm(*DT)); - if (VerifyMemorySSA) - MSSA->verifyMemorySSA(); - RedoLoop = false; - Changed |= processCurrentLoop(); - } while (RedoLoop); - - if (VerifyMemorySSA) - MSSA->verifyMemorySSA(); - - return Changed; -} - -// Return true if the BasicBlock BB is unreachable from the loop header. -// Return false, otherwise. -bool LoopUnswitch::isUnreachableDueToPreviousUnswitching(BasicBlock *BB) { - auto *Node = DT->getNode(BB)->getIDom(); - BasicBlock *DomBB = Node->getBlock(); - while (CurrentLoop->contains(DomBB)) { - BranchInst *BInst = dyn_cast<BranchInst>(DomBB->getTerminator()); - - Node = DT->getNode(DomBB)->getIDom(); - DomBB = Node->getBlock(); - - if (!BInst || !BInst->isConditional()) - continue; - - Value *Cond = BInst->getCondition(); - if (!isa<ConstantInt>(Cond)) - continue; - - BasicBlock *UnreachableSucc = - Cond == ConstantInt::getTrue(Cond->getContext()) - ? BInst->getSuccessor(1) - : BInst->getSuccessor(0); - - if (DT->dominates(UnreachableSucc, BB)) - return true; - } - return false; -} - -/// FIXME: Remove this workaround when freeze related patches are done. -/// LoopUnswitch and Equality propagation in GVN have discrepancy about -/// whether branch on undef/poison has undefine behavior. Here it is to -/// rule out some common cases that we found such discrepancy already -/// causing problems. Detail could be found in PR31652. Note if the -/// func returns true, it is unsafe. But if it is false, it doesn't mean -/// it is necessarily safe. -static bool equalityPropUnSafe(Value &LoopCond) { - ICmpInst *CI = dyn_cast<ICmpInst>(&LoopCond); - if (!CI || !CI->isEquality()) - return false; - - Value *LHS = CI->getOperand(0); - Value *RHS = CI->getOperand(1); - if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) - return true; - - auto HasUndefInPHI = [](PHINode &PN) { - for (Value *Opd : PN.incoming_values()) { - if (isa<UndefValue>(Opd)) - return true; - } - return false; - }; - PHINode *LPHI = dyn_cast<PHINode>(LHS); - PHINode *RPHI = dyn_cast<PHINode>(RHS); - if ((LPHI && HasUndefInPHI(*LPHI)) || (RPHI && HasUndefInPHI(*RPHI))) - return true; - - auto HasUndefInSelect = [](SelectInst &SI) { - if (isa<UndefValue>(SI.getTrueValue()) || - isa<UndefValue>(SI.getFalseValue())) - return true; - return false; - }; - SelectInst *LSI = dyn_cast<SelectInst>(LHS); - SelectInst *RSI = dyn_cast<SelectInst>(RHS); - if ((LSI && HasUndefInSelect(*LSI)) || (RSI && HasUndefInSelect(*RSI))) - return true; - return false; -} - -/// Do actual work and unswitch loop if possible and profitable. 
-bool LoopUnswitch::processCurrentLoop() { - bool Changed = false; - - initLoopData(); - - // If LoopSimplify was unable to form a preheader, don't do any unswitching. - if (!LoopPreheader) - return false; - - // Loops with indirectbr cannot be cloned. - if (!CurrentLoop->isSafeToClone()) - return false; - - // Without dedicated exits, splitting the exit edge may fail. - if (!CurrentLoop->hasDedicatedExits()) - return false; - - LLVMContext &Context = LoopHeader->getContext(); - - // Analyze loop cost, and stop unswitching if loop content can not be duplicated. - if (!BranchesInfo.countLoop( - CurrentLoop, - getAnalysis<TargetTransformInfoWrapperPass>().getTTI( - *CurrentLoop->getHeader()->getParent()), - AC)) - return false; - - // Try trivial unswitch first before loop over other basic blocks in the loop. - if (tryTrivialLoopUnswitch(Changed)) { - return true; - } - - // Do not do non-trivial unswitch while optimizing for size. - // FIXME: Use Function::hasOptSize(). - if (OptimizeForSize || - LoopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize)) - return Changed; - - // Run through the instructions in the loop, keeping track of three things: - // - // - That we do not unswitch loops containing convergent operations, as we - // might be making them control dependent on the unswitch value when they - // were not before. - // FIXME: This could be refined to only bail if the convergent operation is - // not already control-dependent on the unswitch value. - // - // - That basic blocks in the loop contain invokes whose predecessor edges we - // cannot split. - // - // - The set of guard intrinsics encountered (these are non terminator - // instructions that are also profitable to be unswitched). - - SmallVector<IntrinsicInst *, 4> Guards; - - for (const auto BB : CurrentLoop->blocks()) { - for (auto &I : *BB) { - auto *CB = dyn_cast<CallBase>(&I); - if (!CB) - continue; - if (CB->isConvergent()) - return Changed; - if (auto *II = dyn_cast<InvokeInst>(&I)) - if (!II->getUnwindDest()->canSplitPredecessors()) - return Changed; - if (auto *II = dyn_cast<IntrinsicInst>(&I)) - if (II->getIntrinsicID() == Intrinsic::experimental_guard) - Guards.push_back(II); - } - } - - for (IntrinsicInst *Guard : Guards) { - Value *LoopCond = findLIVLoopCondition(Guard->getOperand(0), CurrentLoop, - Changed, MSSAU.get()) - .first; - if (LoopCond && - unswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { - // NB! Unswitching (if successful) could have erased some of the - // instructions in Guards leaving dangling pointers there. This is fine - // because we're returning now, and won't look at Guards again. - ++NumGuards; - return true; - } - } - - // Loop over all of the basic blocks in the loop. If we find an interior - // block that is branching on a loop-invariant condition, we can unswitch this - // loop. - for (Loop::block_iterator I = CurrentLoop->block_begin(), - E = CurrentLoop->block_end(); - I != E; ++I) { - Instruction *TI = (*I)->getTerminator(); - - // Unswitching on a potentially uninitialized predicate is not - // MSan-friendly. Limit this to the cases when the original predicate is - // guaranteed to execute, to avoid creating a use-of-uninitialized-value - // in the code that did not have one. - // This is a workaround for the discrepancy between LLVM IR and MSan - // semantics. See PR28054 for more details. 
- if (SanitizeMemory && - !SafetyInfo.isGuaranteedToExecute(*TI, DT, CurrentLoop)) - continue; - - if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { - // Some branches may be rendered unreachable because of previous - // unswitching. - // Unswitch only those branches that are reachable. - if (isUnreachableDueToPreviousUnswitching(*I)) - continue; - - // If this isn't branching on an invariant condition, we can't unswitch - // it. - if (BI->isConditional()) { - // See if this, or some part of it, is loop invariant. If so, we can - // unswitch on it if we desire. - Value *LoopCond = findLIVLoopCondition(BI->getCondition(), CurrentLoop, - Changed, MSSAU.get()) - .first; - if (LoopCond && !equalityPropUnSafe(*LoopCond) && - unswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) { - ++NumBranches; - return true; - } - } - } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { - Value *SC = SI->getCondition(); - Value *LoopCond; - OperatorChain OpChain; - std::tie(LoopCond, OpChain) = - findLIVLoopCondition(SC, CurrentLoop, Changed, MSSAU.get()); - - unsigned NumCases = SI->getNumCases(); - if (LoopCond && NumCases) { - // Find a value to unswitch on: - // FIXME: this should chose the most expensive case! - // FIXME: scan for a case with a non-critical edge? - Constant *UnswitchVal = nullptr; - // Find a case value such that at least one case value is unswitched - // out. - if (OpChain == OC_OpChainAnd) { - // If the chain only has ANDs and the switch has a case value of 0. - // Dropping in a 0 to the chain will unswitch out the 0-casevalue. - auto *AllZero = cast<ConstantInt>(Constant::getNullValue(SC->getType())); - if (BranchesInfo.isUnswitched(SI, AllZero)) - continue; - // We are unswitching 0 out. - UnswitchVal = AllZero; - } else if (OpChain == OC_OpChainOr) { - // If the chain only has ORs and the switch has a case value of ~0. - // Dropping in a ~0 to the chain will unswitch out the ~0-casevalue. - auto *AllOne = cast<ConstantInt>(Constant::getAllOnesValue(SC->getType())); - if (BranchesInfo.isUnswitched(SI, AllOne)) - continue; - // We are unswitching ~0 out. - UnswitchVal = AllOne; - } else { - assert(OpChain == OC_OpChainNone && - "Expect to unswitch on trivial chain"); - // Do not process same value again and again. - // At this point we have some cases already unswitched and - // some not yet unswitched. Let's find the first not yet unswitched one. - for (auto Case : SI->cases()) { - Constant *UnswitchValCandidate = Case.getCaseValue(); - if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) { - UnswitchVal = UnswitchValCandidate; - break; - } - } - } - - if (!UnswitchVal) - continue; - - if (unswitchIfProfitable(LoopCond, UnswitchVal)) { - ++NumSwitches; - // In case of a full LIV, UnswitchVal is the value we unswitched out. - // In case of a partial LIV, we only unswitch when its an AND-chain - // or OR-chain. In both cases switch input value simplifies to - // UnswitchVal. - BranchesInfo.setUnswitched(SI, UnswitchVal); - return true; - } - } - } - - // Scan the instructions to check for unswitchable values. 
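The choice just above of 0 for an all-AND operator chain and ~0 for an all-OR chain follows from ordinary bitwise algebra. A minimal standalone C++ sketch (function names invented for this note, not taken from the LLVM sources) spells out the invariant the removed code relies on:

// Minimal sketch, assuming nothing beyond the C++ standard library: if the
// loop-invariant input to a pure AND chain is 0, the whole chain is 0 no
// matter what the loop-varying operands are; dually, ~0 forces a pure OR
// chain to ~0, so the matching switch case can be decided outside the loop.
#include <cassert>

unsigned andChain(unsigned Invariant, unsigned X, unsigned Y) {
  return Invariant & X & Y;
}
unsigned orChain(unsigned Invariant, unsigned X, unsigned Y) {
  return Invariant | X | Y;
}

int main() {
  assert(andChain(0u, 123u, 456u) == 0u);  // AND chain collapses to 0
  assert(orChain(~0u, 123u, 456u) == ~0u); // OR chain collapses to ~0
}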
- for (BasicBlock::iterator BBI = (*I)->begin(), E = (*I)->end(); - BBI != E; ++BBI) - if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) { - Value *LoopCond = findLIVLoopCondition(SI->getCondition(), CurrentLoop, - Changed, MSSAU.get()) - .first; - if (LoopCond && - unswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { - ++NumSelects; - return true; - } - } - } - - // Check if there is a header condition that is invariant along the patch from - // either the true or false successors to the header. This allows unswitching - // conditions depending on memory accesses, if there's a path not clobbering - // the memory locations. Check if this transform has been disabled using - // metadata, to avoid unswitching the same loop multiple times. - if (MSSA && - !findOptionMDForLoop(CurrentLoop, "llvm.loop.unswitch.partial.disable")) { - if (auto Info = - hasPartialIVCondition(*CurrentLoop, MSSAThreshold, *MSSA, *AA)) { - assert(!Info->InstToDuplicate.empty() && - "need at least a partially invariant condition"); - LLVM_DEBUG(dbgs() << "loop-unswitch: Found partially invariant condition " - << *Info->InstToDuplicate[0] << "\n"); - - Instruction *TI = CurrentLoop->getHeader()->getTerminator(); - Value *LoopCond = Info->InstToDuplicate[0]; - - // If the partially unswitched path is a no-op and has a single exit - // block, we do not need to do full unswitching. Instead, we can directly - // branch to the exit. - // TODO: Instead of duplicating the checks, we could also just directly - // branch to the exit from the conditional branch in the loop. - if (Info->PathIsNoop) { - if (HasBranchDivergence && - getAnalysis<LegacyDivergenceAnalysis>().isDivergent(LoopCond)) { - LLVM_DEBUG(dbgs() << "NOT unswitching loop %" - << CurrentLoop->getHeader()->getName() - << " at non-trivial condition '" - << *Info->KnownValue << "' == " << *LoopCond << "\n" - << ". Condition is divergent.\n"); - return false; - } - - ++NumBranches; - - BasicBlock *TrueDest = LoopHeader; - BasicBlock *FalseDest = Info->ExitForPath; - if (Info->KnownValue->isOneValue()) - std::swap(TrueDest, FalseDest); - - auto *OldBr = - cast<BranchInst>(CurrentLoop->getLoopPreheader()->getTerminator()); - emitPreheaderBranchOnCondition(LoopCond, Info->KnownValue, TrueDest, - FalseDest, OldBr, TI, - Info->InstToDuplicate); - delete OldBr; - RedoLoop = false; - return true; - } - - // Otherwise, the path is not a no-op. Run regular unswitching. - if (unswitchIfProfitable(LoopCond, Info->KnownValue, - CurrentLoop->getHeader()->getTerminator(), - Info->InstToDuplicate)) { - ++NumBranches; - RedoLoop = false; - return true; - } - } - } - - return Changed; -} - -/// Check to see if all paths from BB exit the loop with no side effects -/// (including infinite loops). -/// -/// If true, we return true and set ExitBB to the block we -/// exit through. -/// -static bool isTrivialLoopExitBlockHelper(Loop *L, BasicBlock *BB, - BasicBlock *&ExitBB, - std::set<BasicBlock*> &Visited) { - if (!Visited.insert(BB).second) { - // Already visited. Without more analysis, this could indicate an infinite - // loop. - return false; - } - if (!L->contains(BB)) { - // Otherwise, this is a loop exit, this is fine so long as this is the - // first exit. - if (ExitBB) return false; - ExitBB = BB; - return true; - } - - // Otherwise, this is an unvisited intra-loop node. Check all successors. - for (BasicBlock *Succ : successors(BB)) { - // Check to see if the successor is a trivial loop exit. 
- if (!isTrivialLoopExitBlockHelper(L, Succ, ExitBB, Visited)) - return false; - } - - // Okay, everything after this looks good, check to make sure that this block - // doesn't include any side effects. - for (Instruction &I : *BB) - if (I.mayHaveSideEffects()) - return false; - - return true; -} - -/// Return true if the specified block unconditionally leads to an exit from -/// the specified loop, and has no side-effects in the process. If so, return -/// the block that is exited to, otherwise return null. -static BasicBlock *isTrivialLoopExitBlock(Loop *L, BasicBlock *BB) { - std::set<BasicBlock*> Visited; - Visited.insert(L->getHeader()); // Branches to header make infinite loops. - BasicBlock *ExitBB = nullptr; - if (isTrivialLoopExitBlockHelper(L, BB, ExitBB, Visited)) - return ExitBB; - return nullptr; -} - -/// We have found that we can unswitch CurrentLoop when LoopCond == Val to -/// simplify the loop. If we decide that this is profitable, -/// unswitch the loop, reprocess the pieces, then return true. -bool LoopUnswitch::unswitchIfProfitable(Value *LoopCond, Constant *Val, - Instruction *TI, - ArrayRef<Instruction *> ToDuplicate) { - // Check to see if it would be profitable to unswitch current loop. - if (!BranchesInfo.costAllowsUnswitching()) { - LLVM_DEBUG(dbgs() << "NOT unswitching loop %" - << CurrentLoop->getHeader()->getName() - << " at non-trivial condition '" << *Val - << "' == " << *LoopCond << "\n" - << ". Cost too high.\n"); - return false; - } - if (HasBranchDivergence && - getAnalysis<LegacyDivergenceAnalysis>().isDivergent(LoopCond)) { - LLVM_DEBUG(dbgs() << "NOT unswitching loop %" - << CurrentLoop->getHeader()->getName() - << " at non-trivial condition '" << *Val - << "' == " << *LoopCond << "\n" - << ". Condition is divergent.\n"); - return false; - } - - unswitchNontrivialCondition(LoopCond, Val, CurrentLoop, TI, ToDuplicate); - return true; -} - -/// Emit a conditional branch on two values if LIC == Val, branch to TrueDst, -/// otherwise branch to FalseDest. Insert the code immediately before OldBranch -/// and remove (but not erase!) it from the function. -void LoopUnswitch::emitPreheaderBranchOnCondition( - Value *LIC, Constant *Val, BasicBlock *TrueDest, BasicBlock *FalseDest, - BranchInst *OldBranch, Instruction *TI, - ArrayRef<Instruction *> ToDuplicate) { - assert(OldBranch->isUnconditional() && "Preheader is not split correctly"); - assert(TrueDest != FalseDest && "Branch targets should be different"); - - // Insert a conditional branch on LIC to the two preheaders. The original - // code is the true version and the new code is the false version. - Value *BranchVal = LIC; - bool Swapped = false; - - if (!ToDuplicate.empty()) { - ValueToValueMapTy Old2New; - for (Instruction *I : reverse(ToDuplicate)) { - auto *New = I->clone(); - New->insertBefore(OldBranch); - RemapInstruction(New, Old2New, - RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); - Old2New[I] = New; - - if (MSSAU) { - MemorySSA *MSSA = MSSAU->getMemorySSA(); - auto *MemA = dyn_cast_or_null<MemoryUse>(MSSA->getMemoryAccess(I)); - if (!MemA) - continue; - - Loop *L = LI->getLoopFor(I->getParent()); - auto *DefiningAccess = MemA->getDefiningAccess(); - // Get the first defining access before the loop. - while (L->contains(DefiningAccess->getBlock())) { - // If the defining access is a MemoryPhi, get the incoming - // value for the pre-header as defining access. 
- if (auto *MemPhi = dyn_cast<MemoryPhi>(DefiningAccess)) { - DefiningAccess = - MemPhi->getIncomingValueForBlock(L->getLoopPreheader()); - } else { - DefiningAccess = - cast<MemoryDef>(DefiningAccess)->getDefiningAccess(); - } - } - MSSAU->createMemoryAccessInBB(New, DefiningAccess, New->getParent(), - MemorySSA::BeforeTerminator); - } - } - BranchVal = Old2New[ToDuplicate[0]]; - } else { - - if (!isa<ConstantInt>(Val) || - Val->getType() != Type::getInt1Ty(LIC->getContext())) - BranchVal = new ICmpInst(OldBranch, ICmpInst::ICMP_EQ, LIC, Val); - else if (Val != ConstantInt::getTrue(Val->getContext())) { - // We want to enter the new loop when the condition is true. - std::swap(TrueDest, FalseDest); - Swapped = true; - } - } - - // Old branch will be removed, so save its parent and successor to update the - // DomTree. - auto *OldBranchSucc = OldBranch->getSuccessor(0); - auto *OldBranchParent = OldBranch->getParent(); - - // Insert the new branch. - BranchInst *BI = - IRBuilder<>(OldBranch).CreateCondBr(BranchVal, TrueDest, FalseDest, TI); - if (Swapped) - BI->swapProfMetadata(); - - // Remove the old branch so there is only one branch at the end. This is - // needed to perform DomTree's internal DFS walk on the function's CFG. - OldBranch->removeFromParent(); - - // Inform the DT about the new branch. - if (DT) { - // First, add both successors. - SmallVector<DominatorTree::UpdateType, 3> Updates; - if (TrueDest != OldBranchSucc) - Updates.push_back({DominatorTree::Insert, OldBranchParent, TrueDest}); - if (FalseDest != OldBranchSucc) - Updates.push_back({DominatorTree::Insert, OldBranchParent, FalseDest}); - // If both of the new successors are different from the old one, inform the - // DT that the edge was deleted. - if (OldBranchSucc != TrueDest && OldBranchSucc != FalseDest) { - Updates.push_back({DominatorTree::Delete, OldBranchParent, OldBranchSucc}); - } - - if (MSSAU) - MSSAU->applyUpdates(Updates, *DT, /*UpdateDT=*/true); - else - DT->applyUpdates(Updates); - } - - // If either edge is critical, split it. This helps preserve LoopSimplify - // form for enclosing loops. - auto Options = - CriticalEdgeSplittingOptions(DT, LI, MSSAU.get()).setPreserveLCSSA(); - SplitCriticalEdge(BI, 0, Options); - SplitCriticalEdge(BI, 1, Options); -} - -/// Given a loop that has a trivial unswitchable condition in it (a cond branch -/// from its header block to its latch block, where the path through the loop -/// that doesn't execute its body has no side-effects), unswitch it. This -/// doesn't involve any code duplication, just moving the conditional branch -/// outside of the loop and updating loop info. -void LoopUnswitch::unswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, - BasicBlock *ExitBlock, - Instruction *TI) { - LLVM_DEBUG(dbgs() << "loop-unswitch: Trivial-Unswitch loop %" - << LoopHeader->getName() << " [" << L->getBlocks().size() - << " blocks] in Function " - << L->getHeader()->getParent()->getName() - << " on cond: " << *Val << " == " << *Cond << "\n"); - // We are going to make essential changes to CFG. This may invalidate cached - // information for L or one of its parent loops in SCEV. - if (auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>()) - SEWP->getSE().forgetTopmostLoop(L); - - // First step, split the preheader, so that we know that there is a safe place - // to insert the conditional branch. We will change LoopPreheader to have a - // conditional branch on Cond. 
- BasicBlock *NewPH = SplitEdge(LoopPreheader, LoopHeader, DT, LI, MSSAU.get()); - - // Now that we have a place to insert the conditional branch, create a place - // to branch to: this is the exit block out of the loop that we should - // short-circuit to. - - // Split this block now, so that the loop maintains its exit block, and so - // that the jump from the preheader can execute the contents of the exit block - // without actually branching to it (the exit block should be dominated by the - // loop header, not the preheader). - assert(!L->contains(ExitBlock) && "Exit block is in the loop?"); - BasicBlock *NewExit = - SplitBlock(ExitBlock, &ExitBlock->front(), DT, LI, MSSAU.get()); - - // Okay, now we have a position to branch from and a position to branch to, - // insert the new conditional branch. - auto *OldBranch = dyn_cast<BranchInst>(LoopPreheader->getTerminator()); - assert(OldBranch && "Failed to split the preheader"); - emitPreheaderBranchOnCondition(Cond, Val, NewExit, NewPH, OldBranch, TI); - - // emitPreheaderBranchOnCondition removed the OldBranch from the function. - // Delete it, as it is no longer needed. - delete OldBranch; - - // We need to reprocess this loop, it could be unswitched again. - RedoLoop = true; - - // Now that we know that the loop is never entered when this condition is a - // particular value, rewrite the loop with this info. We know that this will - // at least eliminate the old branch. - rewriteLoopBodyWithConditionConstant(L, Cond, Val, /*IsEqual=*/false); - - ++NumTrivial; -} - -/// Check if the first non-constant condition starting from the loop header is -/// a trivial unswitch condition: that is, a condition controls whether or not -/// the loop does anything at all. If it is a trivial condition, unswitching -/// produces no code duplications (equivalently, it produces a simpler loop and -/// a new empty loop, which gets deleted). Therefore always unswitch trivial -/// condition. -bool LoopUnswitch::tryTrivialLoopUnswitch(bool &Changed) { - BasicBlock *CurrentBB = CurrentLoop->getHeader(); - Instruction *CurrentTerm = CurrentBB->getTerminator(); - LLVMContext &Context = CurrentBB->getContext(); - - // If loop header has only one reachable successor (currently via an - // unconditional branch or constant foldable conditional branch, but - // should also consider adding constant foldable switch instruction in - // future), we should keep looking for trivial condition candidates in - // the successor as well. An alternative is to constant fold conditions - // and merge successors into loop header (then we only need to check header's - // terminator). The reason for not doing this in LoopUnswitch pass is that - // it could potentially break LoopPassManager's invariants. Folding dead - // branches could either eliminate the current loop or make other loops - // unreachable. LCSSA form might also not be preserved after deleting - // branches. The following code keeps traversing loop header's successors - // until it finds the trivial condition candidate (condition that is not a - // constant). Since unswitching generates branches with constant conditions, - // this scenario could be very common in practice. - SmallPtrSet<BasicBlock*, 8> Visited; - - while (true) { - // If we exit loop or reach a previous visited block, then - // we can not reach any trivial condition candidates (unfoldable - // branch instructions or switch instructions) and no unswitch - // can happen. Exit and return false. 
-    if (!CurrentLoop->contains(CurrentBB) || !Visited.insert(CurrentBB).second)
-      return false;
-
-    // Check if this loop will execute any side-effecting instructions (e.g.
-    // stores, calls, volatile loads) in the part of the loop that the code
-    // *would* execute. Check the header first.
-    for (Instruction &I : *CurrentBB)
-      if (I.mayHaveSideEffects())
-        return false;
-
-    if (BranchInst *BI = dyn_cast<BranchInst>(CurrentTerm)) {
-      if (BI->isUnconditional()) {
-        CurrentBB = BI->getSuccessor(0);
-      } else if (BI->getCondition() == ConstantInt::getTrue(Context)) {
-        CurrentBB = BI->getSuccessor(0);
-      } else if (BI->getCondition() == ConstantInt::getFalse(Context)) {
-        CurrentBB = BI->getSuccessor(1);
-      } else {
-        // Found a trivial condition candidate: non-foldable conditional branch.
-        break;
-      }
-    } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) {
-      // At this point, any constant-foldable instructions should have probably
-      // been folded.
-      ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition());
-      if (!Cond)
-        break;
-      // Find the target block we are definitely going to.
-      CurrentBB = SI->findCaseValue(Cond)->getCaseSuccessor();
-    } else {
-      // We do not understand these terminator instructions.
-      break;
-    }
-
-    CurrentTerm = CurrentBB->getTerminator();
-  }
-
-  // CondVal is the condition that controls the trivial condition.
-  // LoopExitBB is the BasicBlock that loop exits when meets trivial condition.
-  Constant *CondVal = nullptr;
-  BasicBlock *LoopExitBB = nullptr;
-
-  if (BranchInst *BI = dyn_cast<BranchInst>(CurrentTerm)) {
-    // If this isn't branching on an invariant condition, we can't unswitch it.
-    if (!BI->isConditional())
-      return false;
-
-    Value *LoopCond = findLIVLoopCondition(BI->getCondition(), CurrentLoop,
-                                           Changed, MSSAU.get())
-                          .first;
-
-    // Unswitch only if the trivial condition itself is an LIV (not
-    // partial LIV which could occur in and/or)
-    if (!LoopCond || LoopCond != BI->getCondition())
-      return false;
-
-    // Check to see if a successor of the branch is guaranteed to
-    // exit through a unique exit block without having any
-    // side-effects. If so, determine the value of Cond that causes
-    // it to do this.
-    if ((LoopExitBB =
-             isTrivialLoopExitBlock(CurrentLoop, BI->getSuccessor(0)))) {
-      CondVal = ConstantInt::getTrue(Context);
-    } else if ((LoopExitBB =
-                    isTrivialLoopExitBlock(CurrentLoop, BI->getSuccessor(1)))) {
-      CondVal = ConstantInt::getFalse(Context);
-    }
-
-    // If we didn't find a single unique LoopExit block, or if the loop exit
-    // block contains phi nodes, this isn't trivial.
-    if (!LoopExitBB || isa<PHINode>(LoopExitBB->begin()))
-      return false; // Can't handle this.
-
-    if (equalityPropUnSafe(*LoopCond))
-      return false;
-
-    unswitchTrivialCondition(CurrentLoop, LoopCond, CondVal, LoopExitBB,
-                             CurrentTerm);
-    ++NumBranches;
-    return true;
-  } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) {
-    // If this isn't switching on an invariant condition, we can't unswitch it.
-    Value *LoopCond = findLIVLoopCondition(SI->getCondition(), CurrentLoop,
-                                           Changed, MSSAU.get())
-                          .first;
-
-    // Unswitch only if the trivial condition itself is an LIV (not
-    // partial LIV which could occur in and/or)
-    if (!LoopCond || LoopCond != SI->getCondition())
-      return false;
-
-    // Check to see if a successor of the switch is guaranteed to go to the
-    // latch block or exit through a one exit block without having any
-    // side-effects. If so, determine the value of Cond that causes it to do
-    // this.
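As a reading aid for the removed tryTrivialLoopUnswitch logic above: the shape of loop it targets is one whose header tests an invariant condition whose taken path leaves the loop without side effects, so the test can be hoisted with no cloning. A hypothetical standalone C++ example (written for this note, not part of the commit) of the before/after shapes:

#include <cstddef>

// Before: the invariant Flag is re-tested on every iteration, and the taken
// path exits the loop having done nothing.
void zeroUnlessFlag(int *A, std::size_t N, bool Flag) {
  for (std::size_t I = 0; I < N; ++I) {
    if (Flag)
      break;
    A[I] = 0;
  }
}

// After trivial unswitching: the test runs once, outside the loop, and the
// remaining loop body is unconditional.
void zeroUnlessFlagUnswitched(int *A, std::size_t N, bool Flag) {
  if (Flag)
    return;
  for (std::size_t I = 0; I < N; ++I)
    A[I] = 0;
}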
- // Note that we can't trivially unswitch on the default case or - // on already unswitched cases. - for (auto Case : SI->cases()) { - BasicBlock *LoopExitCandidate; - if ((LoopExitCandidate = - isTrivialLoopExitBlock(CurrentLoop, Case.getCaseSuccessor()))) { - // Okay, we found a trivial case, remember the value that is trivial. - ConstantInt *CaseVal = Case.getCaseValue(); - - // Check that it was not unswitched before, since already unswitched - // trivial vals are looks trivial too. - if (BranchesInfo.isUnswitched(SI, CaseVal)) - continue; - LoopExitBB = LoopExitCandidate; - CondVal = CaseVal; - break; - } - } - - // If we didn't find a single unique LoopExit block, or if the loop exit - // block contains phi nodes, this isn't trivial. - if (!LoopExitBB || isa<PHINode>(LoopExitBB->begin())) - return false; // Can't handle this. - - unswitchTrivialCondition(CurrentLoop, LoopCond, CondVal, LoopExitBB, - nullptr); - - // We are only unswitching full LIV. - BranchesInfo.setUnswitched(SI, CondVal); - ++NumSwitches; - return true; - } - return false; -} - -/// Split all of the edges from inside the loop to their exit blocks. -/// Update the appropriate Phi nodes as we do so. -void LoopUnswitch::splitExitEdges( - Loop *L, const SmallVectorImpl<BasicBlock *> &ExitBlocks) { - - for (unsigned I = 0, E = ExitBlocks.size(); I != E; ++I) { - BasicBlock *ExitBlock = ExitBlocks[I]; - SmallVector<BasicBlock *, 4> Preds(predecessors(ExitBlock)); - - // Although SplitBlockPredecessors doesn't preserve loop-simplify in - // general, if we call it on all predecessors of all exits then it does. - SplitBlockPredecessors(ExitBlock, Preds, ".us-lcssa", DT, LI, MSSAU.get(), - /*PreserveLCSSA*/ true); - } -} - -/// We determined that the loop is profitable to unswitch when LIC equal Val. -/// Split it into loop versions and test the condition outside of either loop. -/// Return the loops created as Out1/Out2. -void LoopUnswitch::unswitchNontrivialCondition( - Value *LIC, Constant *Val, Loop *L, Instruction *TI, - ArrayRef<Instruction *> ToDuplicate) { - Function *F = LoopHeader->getParent(); - LLVM_DEBUG(dbgs() << "loop-unswitch: Unswitching loop %" - << LoopHeader->getName() << " [" << L->getBlocks().size() - << " blocks] in Function " << F->getName() << " when '" - << *Val << "' == " << *LIC << "\n"); - - // We are going to make essential changes to CFG. This may invalidate cached - // information for L or one of its parent loops in SCEV. - if (auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>()) - SEWP->getSE().forgetTopmostLoop(L); - - LoopBlocks.clear(); - NewBlocks.clear(); - - if (MSSAU && VerifyMemorySSA) - MSSA->verifyMemorySSA(); - - // First step, split the preheader and exit blocks, and add these blocks to - // the LoopBlocks list. - BasicBlock *NewPreheader = - SplitEdge(LoopPreheader, LoopHeader, DT, LI, MSSAU.get()); - LoopBlocks.push_back(NewPreheader); - - // We want the loop to come after the preheader, but before the exit blocks. - llvm::append_range(LoopBlocks, L->blocks()); - - SmallVector<BasicBlock*, 8> ExitBlocks; - L->getUniqueExitBlocks(ExitBlocks); - - // Split all of the edges from inside the loop to their exit blocks. Update - // the appropriate Phi nodes as we do so. - splitExitEdges(L, ExitBlocks); - - // The exit blocks may have been changed due to edge splitting, recompute. - ExitBlocks.clear(); - L->getUniqueExitBlocks(ExitBlocks); - - // Add exit blocks to the loop blocks. 
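The non-trivial case handled by unswitchNontrivialCondition in the surrounding hunks (clone the loop blocks, branch in the preheader, then let rewriteLoopBodyWithConditionConstant fold the now-known condition in each copy) corresponds, at the source level, to the following transformation. The example is illustrative C++ written for this note, not code from the commit:

#include <cstddef>

// Before: an invariant condition is evaluated inside the loop body.
void adjust(int *A, std::size_t N, bool Add) {
  for (std::size_t I = 0; I < N; ++I) {
    if (Add)
      A[I] += 1;
    else
      A[I] -= 1;
  }
}

// After non-trivial unswitching: the loop is duplicated, the branch moves to
// the preheader, and constant folding removes the test from each copy.
void adjustUnswitched(int *A, std::size_t N, bool Add) {
  if (Add) {
    for (std::size_t I = 0; I < N; ++I)
      A[I] += 1;
  } else {
    for (std::size_t I = 0; I < N; ++I)
      A[I] -= 1;
  }
}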
- llvm::append_range(LoopBlocks, ExitBlocks); - - // Next step, clone all of the basic blocks that make up the loop (including - // the loop preheader and exit blocks), keeping track of the mapping between - // the instructions and blocks. - NewBlocks.reserve(LoopBlocks.size()); - ValueToValueMapTy VMap; - for (unsigned I = 0, E = LoopBlocks.size(); I != E; ++I) { - BasicBlock *NewBB = CloneBasicBlock(LoopBlocks[I], VMap, ".us", F); - - NewBlocks.push_back(NewBB); - VMap[LoopBlocks[I]] = NewBB; // Keep the BB mapping. - } - - // Splice the newly inserted blocks into the function right before the - // original preheader. - F->getBasicBlockList().splice(NewPreheader->getIterator(), - F->getBasicBlockList(), - NewBlocks[0]->getIterator(), F->end()); - - // Now we create the new Loop object for the versioned loop. - Loop *NewLoop = cloneLoop(L, L->getParentLoop(), VMap, LI, LPM); - - // Recalculate unswitching quota, inherit simplified switches info for NewBB, - // Probably clone more loop-unswitch related loop properties. - BranchesInfo.cloneData(NewLoop, L, VMap); - - Loop *ParentLoop = L->getParentLoop(); - if (ParentLoop) { - // Make sure to add the cloned preheader and exit blocks to the parent loop - // as well. - ParentLoop->addBasicBlockToLoop(NewBlocks[0], *LI); - } - - for (unsigned EBI = 0, EBE = ExitBlocks.size(); EBI != EBE; ++EBI) { - BasicBlock *NewExit = cast<BasicBlock>(VMap[ExitBlocks[EBI]]); - // The new exit block should be in the same loop as the old one. - if (Loop *ExitBBLoop = LI->getLoopFor(ExitBlocks[EBI])) - ExitBBLoop->addBasicBlockToLoop(NewExit, *LI); - - assert(NewExit->getTerminator()->getNumSuccessors() == 1 && - "Exit block should have been split to have one successor!"); - BasicBlock *ExitSucc = NewExit->getTerminator()->getSuccessor(0); - - // If the successor of the exit block had PHI nodes, add an entry for - // NewExit. - for (PHINode &PN : ExitSucc->phis()) { - Value *V = PN.getIncomingValueForBlock(ExitBlocks[EBI]); - ValueToValueMapTy::iterator It = VMap.find(V); - if (It != VMap.end()) V = It->second; - PN.addIncoming(V, NewExit); - } - - if (LandingPadInst *LPad = NewExit->getLandingPadInst()) { - PHINode *PN = PHINode::Create(LPad->getType(), 0, "", - &*ExitSucc->getFirstInsertionPt()); - - for (BasicBlock *BB : predecessors(ExitSucc)) { - LandingPadInst *LPI = BB->getLandingPadInst(); - LPI->replaceAllUsesWith(PN); - PN->addIncoming(LPI, BB); - } - } - } - - // Rewrite the code to refer to itself. - for (unsigned NBI = 0, NBE = NewBlocks.size(); NBI != NBE; ++NBI) { - for (Instruction &I : *NewBlocks[NBI]) { - RemapInstruction(&I, VMap, - RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); - if (auto *II = dyn_cast<AssumeInst>(&I)) - AC->registerAssumption(II); - } - } - - // Rewrite the original preheader to select between versions of the loop. - BranchInst *OldBR = cast<BranchInst>(LoopPreheader->getTerminator()); - assert(OldBR->isUnconditional() && OldBR->getSuccessor(0) == LoopBlocks[0] && - "Preheader splitting did not work correctly!"); - - if (MSSAU) { - // Update MemorySSA after cloning, and before splitting to unreachables, - // since that invalidates the 1:1 mapping of clones in VMap. - LoopBlocksRPO LBRPO(L); - LBRPO.perform(LI); - MSSAU->updateForClonedLoop(LBRPO, ExitBlocks, VMap); - } - - // Emit the new branch that selects between the two versions of this loop. - emitPreheaderBranchOnCondition(LIC, Val, NewBlocks[0], LoopBlocks[0], OldBR, - TI, ToDuplicate); - if (MSSAU) { - // Update MemoryPhis in Exit blocks. 
- MSSAU->updateExitBlocksForClonedLoop(ExitBlocks, VMap, *DT); - if (VerifyMemorySSA) - MSSA->verifyMemorySSA(); - } - - // The OldBr was replaced by a new one and removed (but not erased) by - // emitPreheaderBranchOnCondition. It is no longer needed, so delete it. - delete OldBR; - - LoopProcessWorklist.push_back(NewLoop); - RedoLoop = true; - - // Keep a WeakTrackingVH holding onto LIC. If the first call to - // RewriteLoopBody - // deletes the instruction (for example by simplifying a PHI that feeds into - // the condition that we're unswitching on), we don't rewrite the second - // iteration. - WeakTrackingVH LICHandle(LIC); - - if (ToDuplicate.empty()) { - // Now we rewrite the original code to know that the condition is true and - // the new code to know that the condition is false. - rewriteLoopBodyWithConditionConstant(L, LIC, Val, /*IsEqual=*/false); - - // It's possible that simplifying one loop could cause the other to be - // changed to another value or a constant. If its a constant, don't - // simplify it. - if (!LoopProcessWorklist.empty() && LoopProcessWorklist.back() == NewLoop && - LICHandle && !isa<Constant>(LICHandle)) - rewriteLoopBodyWithConditionConstant(NewLoop, LICHandle, Val, - /*IsEqual=*/true); - } else { - // Partial unswitching. Update the condition in the right loop with the - // constant. - auto *CC = cast<ConstantInt>(Val); - if (CC->isOneValue()) { - rewriteLoopBodyWithConditionConstant(NewLoop, VMap[LIC], Val, - /*IsEqual=*/true); - } else - rewriteLoopBodyWithConditionConstant(L, LIC, Val, /*IsEqual=*/true); - - // Mark the new loop as partially unswitched, to avoid unswitching on the - // same condition again. - auto &Context = NewLoop->getHeader()->getContext(); - MDNode *DisableUnswitchMD = MDNode::get( - Context, MDString::get(Context, "llvm.loop.unswitch.partial.disable")); - MDNode *NewLoopID = makePostTransformationMetadata( - Context, L->getLoopID(), {"llvm.loop.unswitch.partial"}, - {DisableUnswitchMD}); - NewLoop->setLoopID(NewLoopID); - } - - if (MSSA && VerifyMemorySSA) - MSSA->verifyMemorySSA(); -} - -/// Remove all instances of I from the worklist vector specified. -static void removeFromWorklist(Instruction *I, - std::vector<Instruction *> &Worklist) { - llvm::erase_value(Worklist, I); -} - -/// When we find that I really equals V, remove I from the -/// program, replacing all uses with V and update the worklist. -static void replaceUsesOfWith(Instruction *I, Value *V, - std::vector<Instruction *> &Worklist, Loop *L, - LPPassManager *LPM, MemorySSAUpdater *MSSAU) { - LLVM_DEBUG(dbgs() << "Replace with '" << *V << "': " << *I << "\n"); - - // Add uses to the worklist, which may be dead now. - for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) - if (Instruction *Use = dyn_cast<Instruction>(I->getOperand(i))) - Worklist.push_back(Use); - - // Add users to the worklist which may be simplified now. - for (User *U : I->users()) - Worklist.push_back(cast<Instruction>(U)); - removeFromWorklist(I, Worklist); - I->replaceAllUsesWith(V); - if (!I->mayHaveSideEffects()) { - if (MSSAU) - MSSAU->removeMemoryAccess(I); - I->eraseFromParent(); - } - ++NumSimplify; -} - -/// We know either that the value LIC has the value specified by Val in the -/// specified loop, or we know it does NOT have that value. -/// Rewrite any uses of LIC or of properties correlated to it. 
-void LoopUnswitch::rewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, - Constant *Val, - bool IsEqual) { - assert(!isa<Constant>(LIC) && "Why are we unswitching on a constant?"); - - // FIXME: Support correlated properties, like: - // for (...) - // if (li1 < li2) - // ... - // if (li1 > li2) - // ... - - // FOLD boolean conditions (X|LIC), (X&LIC). Fold conditional branches, - // selects, switches. - std::vector<Instruction*> Worklist; - LLVMContext &Context = Val->getContext(); - - // If we know that LIC == Val, or that LIC == NotVal, just replace uses of LIC - // in the loop with the appropriate one directly. - if (IsEqual || (isa<ConstantInt>(Val) && - Val->getType()->isIntegerTy(1))) { - Value *Replacement; - if (IsEqual) - Replacement = Val; - else - Replacement = ConstantInt::get(Type::getInt1Ty(Val->getContext()), - !cast<ConstantInt>(Val)->getZExtValue()); - - for (User *U : LIC->users()) { - Instruction *UI = dyn_cast<Instruction>(U); - if (!UI || !L->contains(UI)) - continue; - Worklist.push_back(UI); - } - - for (Instruction *UI : Worklist) - UI->replaceUsesOfWith(LIC, Replacement); - - simplifyCode(Worklist, L); - return; - } - - // Otherwise, we don't know the precise value of LIC, but we do know that it - // is certainly NOT "Val". As such, simplify any uses in the loop that we - // can. This case occurs when we unswitch switch statements. - for (User *U : LIC->users()) { - Instruction *UI = dyn_cast<Instruction>(U); - if (!UI || !L->contains(UI)) - continue; - - // At this point, we know LIC is definitely not Val. Try to use some simple - // logic to simplify the user w.r.t. to the context. - if (Value *Replacement = simplifyInstructionWithNotEqual(UI, LIC, Val)) { - if (LI->replacementPreservesLCSSAForm(UI, Replacement)) { - // This in-loop instruction has been simplified w.r.t. its context, - // i.e. LIC != Val, make sure we propagate its replacement value to - // all its users. - // - // We can not yet delete UI, the LIC user, yet, because that would invalidate - // the LIC->users() iterator !. However, we can make this instruction - // dead by replacing all its users and push it onto the worklist so that - // it can be properly deleted and its operands simplified. - UI->replaceAllUsesWith(Replacement); - } - } - - // This is a LIC user, push it into the worklist so that simplifyCode can - // attempt to simplify it. - Worklist.push_back(UI); - - // If we know that LIC is not Val, use this info to simplify code. - SwitchInst *SI = dyn_cast<SwitchInst>(UI); - if (!SI || !isa<ConstantInt>(Val)) continue; - - // NOTE: if a case value for the switch is unswitched out, we record it - // after the unswitch finishes. We can not record it here as the switch - // is not a direct user of the partial LIV. - SwitchInst::CaseHandle DeadCase = - *SI->findCaseValue(cast<ConstantInt>(Val)); - // Default case is live for multiple values. - if (DeadCase == *SI->case_default()) - continue; - - // Found a dead case value. Don't remove PHI nodes in the - // successor if they become single-entry, those PHI nodes may - // be in the Users list. - - BasicBlock *Switch = SI->getParent(); - BasicBlock *SISucc = DeadCase.getCaseSuccessor(); - BasicBlock *Latch = L->getLoopLatch(); - - if (!SI->findCaseDest(SISucc)) continue; // Edge is critical. - // If the DeadCase successor dominates the loop latch, then the - // transformation isn't safe since it will delete the sole predecessor edge - // to the latch. - if (Latch && DT->dominates(SISucc, Latch)) - continue; - - // FIXME: This is a hack. 
We need to keep the successor around - // and hooked up so as to preserve the loop structure, because - // trying to update it is complicated. So instead we preserve the - // loop structure and put the block on a dead code path. - SplitEdge(Switch, SISucc, DT, LI, MSSAU.get()); - // Compute the successors instead of relying on the return value - // of SplitEdge, since it may have split the switch successor - // after PHI nodes. - BasicBlock *NewSISucc = DeadCase.getCaseSuccessor(); - BasicBlock *OldSISucc = *succ_begin(NewSISucc); - // Create an "unreachable" destination. - BasicBlock *Abort = BasicBlock::Create(Context, "us-unreachable", - Switch->getParent(), - OldSISucc); - new UnreachableInst(Context, Abort); - // Force the new case destination to branch to the "unreachable" - // block while maintaining a (dead) CFG edge to the old block. - NewSISucc->getTerminator()->eraseFromParent(); - BranchInst::Create(Abort, OldSISucc, - ConstantInt::getTrue(Context), NewSISucc); - // Release the PHI operands for this edge. - for (PHINode &PN : NewSISucc->phis()) - PN.setIncomingValueForBlock(Switch, UndefValue::get(PN.getType())); - // Tell the domtree about the new block. We don't fully update the - // domtree here -- instead we force it to do a full recomputation - // after the pass is complete -- but we do need to inform it of - // new blocks. - DT->addNewBlock(Abort, NewSISucc); - } - - simplifyCode(Worklist, L); -} - -/// Now that we have simplified some instructions in the loop, walk over it and -/// constant prop, dce, and fold control flow where possible. Note that this is -/// effectively a very simple loop-structure-aware optimizer. During processing -/// of this loop, L could very well be deleted, so it must not be used. -/// -/// FIXME: When the loop optimizer is more mature, separate this out to a new -/// pass. -/// -void LoopUnswitch::simplifyCode(std::vector<Instruction *> &Worklist, Loop *L) { - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); - while (!Worklist.empty()) { - Instruction *I = Worklist.back(); - Worklist.pop_back(); - - // Simple DCE. - if (isInstructionTriviallyDead(I)) { - LLVM_DEBUG(dbgs() << "Remove dead instruction '" << *I << "\n"); - - // Add uses to the worklist, which may be dead now. - for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) - if (Instruction *Use = dyn_cast<Instruction>(I->getOperand(i))) - Worklist.push_back(Use); - removeFromWorklist(I, Worklist); - if (MSSAU) - MSSAU->removeMemoryAccess(I); - I->eraseFromParent(); - ++NumSimplify; - continue; - } - - // See if instruction simplification can hack this up. This is common for - // things like "select false, X, Y" after unswitching made the condition be - // 'false'. TODO: update the domtree properly so we can pass it here. - if (Value *V = SimplifyInstruction(I, DL)) - if (LI->replacementPreservesLCSSAForm(I, V)) { - replaceUsesOfWith(I, V, Worklist, L, LPM, MSSAU.get()); - continue; - } - - // Special case hacks that appear commonly in unswitched code. - if (BranchInst *BI = dyn_cast<BranchInst>(I)) { - if (BI->isUnconditional()) { - // If BI's parent is the only pred of the successor, fold the two blocks - // together. - BasicBlock *Pred = BI->getParent(); - (void)Pred; - BasicBlock *Succ = BI->getSuccessor(0); - BasicBlock *SinglePred = Succ->getSinglePredecessor(); - if (!SinglePred) continue; // Nothing to do. - assert(SinglePred == Pred && "CFG broken"); - - // Make the LPM and Worklist updates specific to LoopUnswitch. 
- removeFromWorklist(BI, Worklist); - auto SuccIt = Succ->begin(); - while (PHINode *PN = dyn_cast<PHINode>(SuccIt++)) { - for (unsigned It = 0, E = PN->getNumOperands(); It != E; ++It) - if (Instruction *Use = dyn_cast<Instruction>(PN->getOperand(It))) - Worklist.push_back(Use); - for (User *U : PN->users()) - Worklist.push_back(cast<Instruction>(U)); - removeFromWorklist(PN, Worklist); - ++NumSimplify; - } - // Merge the block and make the remaining analyses updates. - DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); - MergeBlockIntoPredecessor(Succ, &DTU, LI, MSSAU.get()); - ++NumSimplify; - continue; - } - - continue; - } - } -} - -/// Simple simplifications we can do given the information that Cond is -/// definitely not equal to Val. -Value *LoopUnswitch::simplifyInstructionWithNotEqual(Instruction *Inst, - Value *Invariant, - Constant *Val) { - // icmp eq cond, val -> false - ICmpInst *CI = dyn_cast<ICmpInst>(Inst); - if (CI && CI->isEquality()) { - Value *Op0 = CI->getOperand(0); - Value *Op1 = CI->getOperand(1); - if ((Op0 == Invariant && Op1 == Val) || (Op0 == Val && Op1 == Invariant)) { - LLVMContext &Ctx = Inst->getContext(); - if (CI->getPredicate() == CmpInst::ICMP_EQ) - return ConstantInt::getFalse(Ctx); - else - return ConstantInt::getTrue(Ctx); - } - } - - // FIXME: there may be other opportunities, e.g. comparison with floating - // point, or Invariant - Val != 0, etc. - return nullptr; -} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp index 2ff1e8480749..c733aa4701ed 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -70,14 +70,12 @@ #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" -#include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerAtomic.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerAtomicPass.cpp index 4063e4fe0472..6aba913005d0 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerAtomic.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerAtomicPass.cpp @@ -11,95 +11,17 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar/LowerAtomic.h" +#include "llvm/Transforms/Scalar/LowerAtomicPass.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/LowerAtomic.h" using namespace llvm; #define DEBUG_TYPE "loweratomic" -static bool LowerAtomicCmpXchgInst(AtomicCmpXchgInst *CXI) { - IRBuilder<> Builder(CXI); - Value *Ptr = CXI->getPointerOperand(); - Value *Cmp = CXI->getCompareOperand(); - Value *Val = CXI->getNewValOperand(); - - LoadInst *Orig = Builder.CreateLoad(Val->getType(), Ptr); - Value *Equal = Builder.CreateICmpEQ(Orig, Cmp); - Value *Res = Builder.CreateSelect(Equal, Val, Orig); - Builder.CreateStore(Res, Ptr); - - Res = 
Builder.CreateInsertValue(UndefValue::get(CXI->getType()), Orig, 0); - Res = Builder.CreateInsertValue(Res, Equal, 1); - - CXI->replaceAllUsesWith(Res); - CXI->eraseFromParent(); - return true; -} - -bool llvm::lowerAtomicRMWInst(AtomicRMWInst *RMWI) { - IRBuilder<> Builder(RMWI); - Value *Ptr = RMWI->getPointerOperand(); - Value *Val = RMWI->getValOperand(); - - LoadInst *Orig = Builder.CreateLoad(Val->getType(), Ptr); - Value *Res = nullptr; - - switch (RMWI->getOperation()) { - default: llvm_unreachable("Unexpected RMW operation"); - case AtomicRMWInst::Xchg: - Res = Val; - break; - case AtomicRMWInst::Add: - Res = Builder.CreateAdd(Orig, Val); - break; - case AtomicRMWInst::Sub: - Res = Builder.CreateSub(Orig, Val); - break; - case AtomicRMWInst::And: - Res = Builder.CreateAnd(Orig, Val); - break; - case AtomicRMWInst::Nand: - Res = Builder.CreateNot(Builder.CreateAnd(Orig, Val)); - break; - case AtomicRMWInst::Or: - Res = Builder.CreateOr(Orig, Val); - break; - case AtomicRMWInst::Xor: - Res = Builder.CreateXor(Orig, Val); - break; - case AtomicRMWInst::Max: - Res = Builder.CreateSelect(Builder.CreateICmpSLT(Orig, Val), - Val, Orig); - break; - case AtomicRMWInst::Min: - Res = Builder.CreateSelect(Builder.CreateICmpSLT(Orig, Val), - Orig, Val); - break; - case AtomicRMWInst::UMax: - Res = Builder.CreateSelect(Builder.CreateICmpULT(Orig, Val), - Val, Orig); - break; - case AtomicRMWInst::UMin: - Res = Builder.CreateSelect(Builder.CreateICmpULT(Orig, Val), - Orig, Val); - break; - case AtomicRMWInst::FAdd: - Res = Builder.CreateFAdd(Orig, Val); - break; - case AtomicRMWInst::FSub: - Res = Builder.CreateFSub(Orig, Val); - break; - } - Builder.CreateStore(Res, Ptr); - RMWI->replaceAllUsesWith(Orig); - RMWI->eraseFromParent(); - return true; -} - static bool LowerFenceInst(FenceInst *FI) { FI->eraseFromParent(); return true; @@ -121,7 +43,7 @@ static bool runOnBasicBlock(BasicBlock &BB) { if (FenceInst *FI = dyn_cast<FenceInst>(&Inst)) Changed |= LowerFenceInst(FI); else if (AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(&Inst)) - Changed |= LowerAtomicCmpXchgInst(CXI); + Changed |= lowerAtomicCmpXchgInst(CXI); else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(&Inst)) Changed |= lowerAtomicRMWInst(RMWI); else if (LoadInst *LI = dyn_cast<LoadInst>(&Inst)) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp index 186065db327e..47493b54a527 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp @@ -26,11 +26,9 @@ #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" -#include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" @@ -96,7 +94,7 @@ static bool replaceConditionalBranchesOnConstant(Instruction *II, return HasDeadBlocks; } -static bool lowerConstantIntrinsics(Function &F, const TargetLibraryInfo *TLI, +static bool lowerConstantIntrinsics(Function &F, const TargetLibraryInfo &TLI, DominatorTree *DT) { Optional<DomTreeUpdater> DTU; if (DT) @@ -140,21 +138,21 @@ static bool lowerConstantIntrinsics(Function &F, const TargetLibraryInfo *TLI, IsConstantIntrinsicsHandled++; break; case Intrinsic::objectsize: - NewValue = 
lowerObjectSizeCall(II, DL, TLI, true); + NewValue = lowerObjectSizeCall(II, DL, &TLI, true); ObjectSizeIntrinsicsHandled++; break; } HasDeadBlocks |= replaceConditionalBranchesOnConstant( - II, NewValue, DTU.hasValue() ? DTU.getPointer() : nullptr); + II, NewValue, DTU ? DTU.getPointer() : nullptr); } if (HasDeadBlocks) - removeUnreachableBlocks(F, DTU.hasValue() ? DTU.getPointer() : nullptr); + removeUnreachableBlocks(F, DTU ? DTU.getPointer() : nullptr); return !Worklist.empty(); } PreservedAnalyses LowerConstantIntrinsicsPass::run(Function &F, FunctionAnalysisManager &AM) { - if (lowerConstantIntrinsics(F, AM.getCachedResult<TargetLibraryAnalysis>(F), + if (lowerConstantIntrinsics(F, AM.getResult<TargetLibraryAnalysis>(F), AM.getCachedResult<DominatorTreeAnalysis>(F))) { PreservedAnalyses PA; PA.preserve<DominatorTreeAnalysis>(); @@ -178,8 +176,8 @@ public: } bool runOnFunction(Function &F) override { - auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); - const TargetLibraryInfo *TLI = TLIP ? &TLIP->getTLI(F) : nullptr; + const TargetLibraryInfo &TLI = + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); DominatorTree *DT = nullptr; if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>()) DT = &DTWP->getDomTree(); @@ -187,6 +185,7 @@ public: } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); } @@ -196,6 +195,7 @@ public: char LowerConstantIntrinsics::ID = 0; INITIALIZE_PASS_BEGIN(LowerConstantIntrinsics, "lower-constant-intrinsics", "Lower constant intrinsics", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(LowerConstantIntrinsics, "lower-constant-intrinsics", "Lower constant intrinsics", false, false) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp index a7eb60b5e032..88fad9896c59 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -21,12 +21,11 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" -#include "llvm/IR/Metadata.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/MisExpect.h" using namespace llvm; @@ -101,6 +100,8 @@ static bool handleSwitchExpect(SwitchInst &SI) { uint64_t Index = (Case == *SI.case_default()) ? 
0 : Case.getCaseIndex() + 1; Weights[Index] = LikelyBranchWeightVal; + misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true); + SI.setCondition(ArgValue); SI.setMetadata(LLVMContext::MD_prof, @@ -315,13 +316,16 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) { std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = getBranchWeight(Fn->getIntrinsicID(), CI, 2); + SmallVector<uint32_t, 4> ExpectedWeights; if ((ExpectedValue->getZExtValue() == ValueComparedTo) == (Predicate == CmpInst::ICMP_EQ)) { Node = MDB.createBranchWeights(LikelyBranchWeightVal, UnlikelyBranchWeightVal); + ExpectedWeights = {LikelyBranchWeightVal, UnlikelyBranchWeightVal}; } else { Node = MDB.createBranchWeights(UnlikelyBranchWeightVal, LikelyBranchWeightVal); + ExpectedWeights = {UnlikelyBranchWeightVal, LikelyBranchWeightVal}; } if (CmpI) @@ -329,6 +333,8 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) { else BSI.setCondition(ArgValue); + misexpect::checkFrontendInstrumentation(BSI, ExpectedWeights); + BSI.setMetadata(LLVMContext::MD_prof, Node); return true; @@ -409,7 +415,7 @@ public: bool runOnFunction(Function &F) override { return lowerExpectIntrinsic(F); } }; -} +} // namespace char LowerExpectIntrinsic::ID = 0; INITIALIZE_PASS(LowerExpectIntrinsic, "lower-expect", diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp index 45f5929e3b90..8dc037b10cc8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp @@ -15,7 +15,6 @@ #include "llvm/Transforms/Scalar/LowerGuardIntrinsic.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/GuardUtils.h" -#include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" @@ -49,9 +48,13 @@ static bool lowerGuardIntrinsic(Function &F) { return false; SmallVector<CallInst *, 8> ToLower; - for (auto &I : instructions(F)) - if (isGuard(&I)) - ToLower.push_back(cast<CallInst>(&I)); + // Traverse through the users of GuardDecl. + // This is presumably cheaper than traversing all instructions in the + // function. + for (auto *U : GuardDecl->users()) + if (auto *CI = dyn_cast<CallInst>(U)) + if (CI->getFunction() == &F) + ToLower.push_back(CI); if (ToLower.empty()) return false; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 296becb31e8f..c05906649f16 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -18,11 +18,11 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" -#include "llvm/ADT/GraphTraits.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/DomTreeUpdater.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -704,10 +704,10 @@ public: // We may remove II. By default continue on the next/prev instruction. ++II; // If we were to erase II, move again. 
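For readers coming to the LowerExpectIntrinsic hunks above cold: the intrinsic being lowered into branch-weight (and now misexpect-checked) metadata is what Clang emits for __builtin_expect. A hypothetical usage sketch, with the macro and function names invented for this note and relying on the Clang/GCC-specific builtin:

// Illustrative only: the source construct that becomes llvm.expect and is
// then rewritten into !prof branch weights by the pass.
#define LIKELY(x) __builtin_expect(!!(x), 1)

int count_nonzero(const int *V, int N) {
  int Count = 0;
  for (int I = 0; I < N; ++I)
    if (LIKELY(V[I] != 0)) // hints that the nonzero case dominates
      ++Count;
  return Count;
}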
- auto EraseFromParent = [&II](Value *V) { + auto EraseFromParent = [&II, &BB](Value *V) { auto *Inst = cast<Instruction>(V); if (Inst->use_empty()) { - if (Inst == &*II) { + if (II != BB.rend() && Inst == &*II) { ++II; } Inst->eraseFromParent(); @@ -718,7 +718,7 @@ public: Instruction *NewInst = nullptr; IRBuilder<> IB(&I); - MatrixBuilder<IRBuilder<>> Builder(IB); + MatrixBuilder Builder(IB); Value *TA, *TAMA, *TAMB; ConstantInt *R, *K, *C; @@ -766,28 +766,25 @@ public: // If we have a TT matmul, lift the transpose. We may be able to fold into // consuming multiply. for (BasicBlock &BB : Func) { - for (BasicBlock::iterator II = BB.begin(); II != BB.end();) { - Instruction *I = &*II; - // We may remove I. - ++II; + for (Instruction &I : llvm::make_early_inc_range(BB)) { Value *A, *B, *AT, *BT; ConstantInt *R, *K, *C; // A^t * B ^t -> (B * A)^t - if (match(&*I, m_Intrinsic<Intrinsic::matrix_multiply>( - m_Value(A), m_Value(B), m_ConstantInt(R), - m_ConstantInt(K), m_ConstantInt(C))) && + if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>( + m_Value(A), m_Value(B), m_ConstantInt(R), + m_ConstantInt(K), m_ConstantInt(C))) && match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) && match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value((BT))))) { - IRBuilder<> IB(&*I); - MatrixBuilder<IRBuilder<>> Builder(IB); + IRBuilder<> IB(&I); + MatrixBuilder Builder(IB); Value *M = Builder.CreateMatrixMultiply( BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue()); setShapeInfo(M, {C, R}); Instruction *NewInst = Builder.CreateMatrixTranspose( M, C->getZExtValue(), R->getZExtValue()); - ReplaceAllUsesWith(*I, NewInst); - if (I->use_empty()) - I->eraseFromParent(); + ReplaceAllUsesWith(I, NewInst); + if (I.use_empty()) + I.eraseFromParent(); if (A->use_empty()) cast<Instruction>(A)->eraseFromParent(); if (A != B && B->use_empty()) @@ -891,27 +888,27 @@ public: // having to update as many def-use and use-def chains. // // Because we add to ToRemove during fusion we can't guarantee that defs - // are before uses. Change uses to undef temporarily as these should get + // are before uses. Change uses to poison temporarily as these should get // removed as well. // - // For verification, we keep track of where we changed uses to undefs in - // UndefedInsts and then check that we in fact remove them. - SmallSet<Instruction *, 16> UndefedInsts; + // For verification, we keep track of where we changed uses to poison in + // PoisonedInsts and then check that we in fact remove them. + SmallSet<Instruction *, 16> PoisonedInsts; for (auto *Inst : reverse(ToRemove)) { for (Use &U : llvm::make_early_inc_range(Inst->uses())) { - if (auto *Undefed = dyn_cast<Instruction>(U.getUser())) - UndefedInsts.insert(Undefed); - U.set(UndefValue::get(Inst->getType())); + if (auto *Poisoned = dyn_cast<Instruction>(U.getUser())) + PoisonedInsts.insert(Poisoned); + U.set(PoisonValue::get(Inst->getType())); } Inst->eraseFromParent(); - UndefedInsts.erase(Inst); + PoisonedInsts.erase(Inst); } - if (!UndefedInsts.empty()) { - // If we didn't remove all undefed instructions, it's a hard error. - dbgs() << "Undefed but present instructions:\n"; - for (auto *I : UndefedInsts) + if (!PoisonedInsts.empty()) { + // If we didn't remove all poisoned instructions, it's a hard error. 
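The A^t * B^t -> (B * A)^t rewrite in the hunk above relies only on the standard transpose identity. A small self-contained C++ check (illustrative, using 2x2 integer matrices invented for this note) makes the algebra concrete:

#include <array>
#include <cassert>

using Mat2 = std::array<std::array<int, 2>, 2>;

Mat2 mul(const Mat2 &X, const Mat2 &Y) {
  Mat2 R{};
  for (int I = 0; I < 2; ++I)
    for (int J = 0; J < 2; ++J)
      for (int K = 0; K < 2; ++K)
        R[I][J] += X[I][K] * Y[K][J];
  return R;
}

Mat2 transpose(const Mat2 &X) {
  return {{{X[0][0], X[1][0]}, {X[0][1], X[1][1]}}};
}

int main() {
  Mat2 A{{{1, 2}, {3, 4}}};
  Mat2 B{{{5, 6}, {7, 8}}};
  // A^T * B^T == (B * A)^T, which is why lifting the transposes past the
  // multiply (and swapping the operands) preserves the result.
  assert(mul(transpose(A), transpose(B)) == transpose(mul(B, A)));
}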
+ dbgs() << "Poisoned but present instructions:\n"; + for (auto *I : PoisonedInsts) dbgs() << *I << "\n"; - llvm_unreachable("Undefed but instruction not removed"); + llvm_unreachable("Poisoned but instruction not removed"); } return Changed; @@ -1670,7 +1667,7 @@ public: for (unsigned I = 0; I < NewNumVecs; ++I) { // Build a single result vector. First initialize it. - Value *ResultVector = UndefValue::get( + Value *ResultVector = PoisonValue::get( FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); // Go through the old elements and insert it into the resulting vector. for (auto J : enumerate(InputMatrix.vectors())) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerWidenableCondition.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerWidenableCondition.cpp index 73b2cd06fa23..e2de322933bc 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerWidenableCondition.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerWidenableCondition.cpp @@ -13,8 +13,6 @@ #include "llvm/Transforms/Scalar/LowerWidenableCondition.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Analysis/GuardUtils.h" -#include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" @@ -24,7 +22,6 @@ #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/GuardUtils.h" using namespace llvm; @@ -50,9 +47,13 @@ static bool lowerWidenableCondition(Function &F) { using namespace llvm::PatternMatch; SmallVector<CallInst *, 8> ToLower; - for (auto &I : instructions(F)) - if (match(&I, m_Intrinsic<Intrinsic::experimental_widenable_condition>())) - ToLower.push_back(cast<CallInst>(&I)); + // Traverse through the users of WCDecl. + // This is presumably cheaper than traversing all instructions in the + // function. 
+ for (auto *U : WCDecl->users()) + if (auto *CI = dyn_cast<CallInst>(U)) + if (CI->getFunction() == &F) + ToLower.push_back(CI); if (ToLower.empty()) return false; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp index 5ffae128f5f0..a3f09a5a33c3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp @@ -33,13 +33,11 @@ #include "llvm/Transforms/Scalar/MakeGuardsExplicit.h" #include "llvm/Analysis/GuardUtils.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" -#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/GuardUtils.h" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 6698db26626b..1f5bc69acecd 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -28,14 +28,12 @@ #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" -#include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" @@ -45,7 +43,6 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/IR/User.h" @@ -61,15 +58,13 @@ #include <algorithm> #include <cassert> #include <cstdint> -#include <utility> using namespace llvm; #define DEBUG_TYPE "memcpyopt" static cl::opt<bool> EnableMemCpyOptWithoutLibcalls( - "enable-memcpyopt-without-libcalls", cl::init(false), cl::Hidden, - cl::ZeroOrMore, + "enable-memcpyopt-without-libcalls", cl::Hidden, cl::desc("Enable memcpyopt even when libcalls are disabled")); STATISTIC(NumMemCpyInstr, "Number of memcpy instructions deleted"); @@ -100,7 +95,7 @@ struct MemsetRange { Value *StartPtr; /// Alignment - The known alignment of the first store. - unsigned Alignment; + MaybeAlign Alignment; /// TheStores - The actual stores that make up this range. 
SmallVector<Instruction*, 16> TheStores; @@ -182,16 +177,16 @@ public: TypeSize StoreSize = DL.getTypeStoreSize(SI->getOperand(0)->getType()); assert(!StoreSize.isScalable() && "Can't track scalable-typed stores"); addRange(OffsetFromFirst, StoreSize.getFixedSize(), SI->getPointerOperand(), - SI->getAlign().value(), SI); + SI->getAlign(), SI); } void addMemSet(int64_t OffsetFromFirst, MemSetInst *MSI) { int64_t Size = cast<ConstantInt>(MSI->getLength())->getZExtValue(); - addRange(OffsetFromFirst, Size, MSI->getDest(), MSI->getDestAlignment(), MSI); + addRange(OffsetFromFirst, Size, MSI->getDest(), MSI->getDestAlign(), MSI); } - void addRange(int64_t Start, int64_t Size, Value *Ptr, - unsigned Alignment, Instruction *Inst); + void addRange(int64_t Start, int64_t Size, Value *Ptr, MaybeAlign Alignment, + Instruction *Inst); }; } // end anonymous namespace @@ -200,7 +195,7 @@ public: /// new range for the specified store at the specified offset, merging into /// existing ranges as appropriate. void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr, - unsigned Alignment, Instruction *Inst) { + MaybeAlign Alignment, Instruction *Inst) { int64_t End = Start+Size; range_iterator I = partition_point( @@ -352,9 +347,25 @@ static bool accessedBetween(AliasAnalysis &AA, MemoryLocation Loc, // Check for mod of Loc between Start and End, excluding both boundaries. // Start and End can be in different blocks. -static bool writtenBetween(MemorySSA *MSSA, MemoryLocation Loc, - const MemoryUseOrDef *Start, +static bool writtenBetween(MemorySSA *MSSA, AliasAnalysis &AA, + MemoryLocation Loc, const MemoryUseOrDef *Start, const MemoryUseOrDef *End) { + if (isa<MemoryUse>(End)) { + // For MemoryUses, getClobberingMemoryAccess may skip non-clobbering writes. + // Manually check read accesses between Start and End, if they are in the + // same block, for clobbers. Otherwise assume Loc is clobbered. + return Start->getBlock() != End->getBlock() || + any_of( + make_range(std::next(Start->getIterator()), End->getIterator()), + [&AA, Loc](const MemoryAccess &Acc) { + if (isa<MemoryUse>(&Acc)) + return false; + Instruction *AccInst = + cast<MemoryUseOrDef>(&Acc)->getMemoryInst(); + return isModSet(AA.getModRefInfo(AccInst, Loc)); + }); + } + // TODO: Only walk until we hit Start. MemoryAccess *Clobber = MSSA->getWalker()->getClobberingMemoryAccess( End->getDefiningAccess(), Loc); @@ -492,7 +503,7 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, StartPtr = Range.StartPtr; AMemSet = Builder.CreateMemSet(StartPtr, ByteVal, Range.End - Range.Start, - MaybeAlign(Range.Alignment)); + Range.Alignment); LLVM_DEBUG(dbgs() << "Replace stores:\n"; for (Instruction *SI : Range.TheStores) dbgs() << *SI << '\n'; @@ -749,36 +760,25 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { // Detect cases where we're performing call slot forwarding, but // happen to be using a load-store pair to implement it, rather than // a memcpy. - CallInst *C = nullptr; - if (auto *LoadClobber = dyn_cast<MemoryUseOrDef>( - MSSA->getWalker()->getClobberingMemoryAccess(LI))) { - // The load most post-dom the call. Limit to the same block for now. - // TODO: Support non-local call-slot optimization? - if (LoadClobber->getBlock() == SI->getParent()) - C = dyn_cast_or_null<CallInst>(LoadClobber->getMemoryInst()); - } - - if (C) { - // Check that nothing touches the dest of the "copy" between - // the call and the store. 
- MemoryLocation StoreLoc = MemoryLocation::get(SI); - if (accessedBetween(*AA, StoreLoc, MSSA->getMemoryAccess(C), - MSSA->getMemoryAccess(SI))) - C = nullptr; - } - - if (C) { - bool changed = performCallSlotOptzn( - LI, SI, SI->getPointerOperand()->stripPointerCasts(), - LI->getPointerOperand()->stripPointerCasts(), - DL.getTypeStoreSize(SI->getOperand(0)->getType()), - commonAlignment(SI->getAlign(), LI->getAlign()), C); - if (changed) { - eraseInstruction(SI); - eraseInstruction(LI); - ++NumMemCpyInstr; - return true; - } + auto GetCall = [&]() -> CallInst * { + // We defer this expensive clobber walk until the cheap checks + // have been done on the source inside performCallSlotOptzn. + if (auto *LoadClobber = dyn_cast<MemoryUseOrDef>( + MSSA->getWalker()->getClobberingMemoryAccess(LI))) + return dyn_cast_or_null<CallInst>(LoadClobber->getMemoryInst()); + return nullptr; + }; + + bool changed = performCallSlotOptzn( + LI, SI, SI->getPointerOperand()->stripPointerCasts(), + LI->getPointerOperand()->stripPointerCasts(), + DL.getTypeStoreSize(SI->getOperand(0)->getType()), + std::min(SI->getAlign(), LI->getAlign()), GetCall); + if (changed) { + eraseInstruction(SI); + eraseInstruction(LI); + ++NumMemCpyInstr; + return true; } } } @@ -853,7 +853,8 @@ bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) { bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, Instruction *cpyStore, Value *cpyDest, Value *cpySrc, TypeSize cpySize, - Align cpyAlign, CallInst *C) { + Align cpyAlign, + std::function<CallInst *()> GetC) { // The general transformation to keep in mind is // // call @func(..., src, ...) @@ -872,11 +873,6 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, if (cpySize.isScalable()) return false; - // Lifetime marks shouldn't be operated on. - if (Function *F = C->getCalledFunction()) - if (F->isIntrinsic() && F->getIntrinsicID() == Intrinsic::lifetime_start) - return false; - // Require that src be an alloca. This simplifies the reasoning considerably. auto *srcAlloca = dyn_cast<AllocaInst>(cpySrc); if (!srcAlloca) @@ -893,6 +889,33 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, if (cpySize < srcSize) return false; + CallInst *C = GetC(); + if (!C) + return false; + + // Lifetime marks shouldn't be operated on. + if (Function *F = C->getCalledFunction()) + if (F->isIntrinsic() && F->getIntrinsicID() == Intrinsic::lifetime_start) + return false; + + + if (C->getParent() != cpyStore->getParent()) { + LLVM_DEBUG(dbgs() << "Call Slot: block local restriction\n"); + return false; + } + + MemoryLocation DestLoc = isa<StoreInst>(cpyStore) ? + MemoryLocation::get(cpyStore) : + MemoryLocation::getForDest(cast<MemCpyInst>(cpyStore)); + + // Check that nothing touches the dest of the copy between + // the call and the store/memcpy. + if (accessedBetween(*AA, DestLoc, MSSA->getMemoryAccess(C), + MSSA->getMemoryAccess(cpyStore))) { + LLVM_DEBUG(dbgs() << "Call Slot: Dest pointer modified after call\n"); + return false; + } + // Check that accessing the first srcSize bytes of dest will not cause a // trap. Otherwise the transform is invalid since it might cause a trap // to occur earlier than it otherwise would. @@ -902,6 +925,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, return false; } + // Make sure that nothing can observe cpyDest being written early. There are // a number of cases to consider: // 1. 
cpyDest cannot be accessed between C and cpyStore as a precondition of @@ -1118,7 +1142,7 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, // then we could still perform the xform by moving M up to the first memcpy. // TODO: It would be sufficient to check the MDep source up to the memcpy // size of M, rather than MDep. - if (writtenBetween(MSSA, MemoryLocation::getForSource(MDep), + if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep), MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(M))) return false; @@ -1215,14 +1239,14 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, } // By default, create an unaligned memset. - unsigned Align = 1; + Align Alignment = Align(1); // If Dest is aligned, and SrcSize is constant, use the minimum alignment // of the sum. - const unsigned DestAlign = - std::max(MemSet->getDestAlignment(), MemCpy->getDestAlignment()); + const Align DestAlign = std::max(MemSet->getDestAlign().valueOrOne(), + MemCpy->getDestAlign().valueOrOne()); if (DestAlign > 1) if (auto *SrcSizeC = dyn_cast<ConstantInt>(SrcSize)) - Align = MinAlign(SrcSizeC->getZExtValue(), DestAlign); + Alignment = commonAlignment(DestAlign, SrcSizeC->getZExtValue()); IRBuilder<> Builder(MemCpy); @@ -1241,11 +1265,11 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, Ule, ConstantInt::getNullValue(DestSize->getType()), SizeDiff); unsigned DestAS = Dest->getType()->getPointerAddressSpace(); Instruction *NewMemSet = Builder.CreateMemSet( - Builder.CreateGEP(Builder.getInt8Ty(), - Builder.CreatePointerCast(Dest, - Builder.getInt8PtrTy(DestAS)), - SrcSize), - MemSet->getOperand(1), MemsetLen, MaybeAlign(Align)); + Builder.CreateGEP( + Builder.getInt8Ty(), + Builder.CreatePointerCast(Dest, Builder.getInt8PtrTy(DestAS)), + SrcSize), + MemSet->getOperand(1), MemsetLen, Alignment); assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)) && "MemCpy must be a MemoryDef"); @@ -1402,7 +1426,8 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { } MemoryUseOrDef *MA = MSSA->getMemoryAccess(M); - MemoryAccess *AnyClobber = MSSA->getWalker()->getClobberingMemoryAccess(MA); + // FIXME: Not using getClobberingMemoryAccess() here due to PR54682. + MemoryAccess *AnyClobber = MA->getDefiningAccess(); MemoryLocation DestLoc = MemoryLocation::getForDest(M); const MemoryAccess *DestClobber = MSSA->getWalker()->getClobberingMemoryAccess(AnyClobber, DestLoc); @@ -1431,28 +1456,20 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { if (Instruction *MI = MD->getMemoryInst()) { if (auto *CopySize = dyn_cast<ConstantInt>(M->getLength())) { if (auto *C = dyn_cast<CallInst>(MI)) { - // The memcpy must post-dom the call. Limit to the same block for - // now. Additionally, we need to ensure that there are no accesses - // to dest between the call and the memcpy. Accesses to src will be - // checked by performCallSlotOptzn(). - // TODO: Support non-local call-slot optimization? - if (C->getParent() == M->getParent() && - !accessedBetween(*AA, DestLoc, MD, MA)) { - // FIXME: Can we pass in either of dest/src alignment here instead - // of conservatively taking the minimum? 
- Align Alignment = std::min(M->getDestAlign().valueOrOne(), - M->getSourceAlign().valueOrOne()); - if (performCallSlotOptzn( - M, M, M->getDest(), M->getSource(), - TypeSize::getFixed(CopySize->getZExtValue()), Alignment, - C)) { - LLVM_DEBUG(dbgs() << "Performed call slot optimization:\n" - << " call: " << *C << "\n" - << " memcpy: " << *M << "\n"); - eraseInstruction(M); - ++NumMemCpyInstr; - return true; - } + // FIXME: Can we pass in either of dest/src alignment here instead + // of conservatively taking the minimum? + Align Alignment = std::min(M->getDestAlign().valueOrOne(), + M->getSourceAlign().valueOrOne()); + if (performCallSlotOptzn( + M, M, M->getDest(), M->getSource(), + TypeSize::getFixed(CopySize->getZExtValue()), Alignment, + [C]() -> CallInst * { return C; })) { + LLVM_DEBUG(dbgs() << "Performed call slot optimization:\n" + << " call: " << *C << "\n" + << " memcpy: " << *M << "\n"); + eraseInstruction(M); + ++NumMemCpyInstr; + return true; } } } @@ -1557,7 +1574,7 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) { // *b = 42; // foo(*a) // It would be invalid to transform the second memcpy into foo(*b). - if (writtenBetween(MSSA, MemoryLocation::getForSource(MDep), + if (writtenBetween(MSSA, *AA, MemoryLocation::getForSource(MDep), MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(&CB))) return false; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergeICmps.cpp index aac0deea5be3..ce01ae5b2692 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergeICmps.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergeICmps.cpp @@ -144,31 +144,33 @@ BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { LLVM_DEBUG(dbgs() << "volatile or atomic\n"); return {}; } - Value *const Addr = LoadI->getOperand(0); + Value *Addr = LoadI->getOperand(0); if (Addr->getType()->getPointerAddressSpace() != 0) { LLVM_DEBUG(dbgs() << "from non-zero AddressSpace\n"); return {}; } - auto *const GEP = dyn_cast<GetElementPtrInst>(Addr); - if (!GEP) - return {}; - LLVM_DEBUG(dbgs() << "GEP\n"); - if (GEP->isUsedOutsideOfBlock(LoadI->getParent())) { - LLVM_DEBUG(dbgs() << "used outside of block\n"); - return {}; - } - const auto &DL = GEP->getModule()->getDataLayout(); - if (!isDereferenceablePointer(GEP, LoadI->getType(), DL)) { + const auto &DL = LoadI->getModule()->getDataLayout(); + if (!isDereferenceablePointer(Addr, LoadI->getType(), DL)) { LLVM_DEBUG(dbgs() << "not dereferenceable\n"); // We need to make sure that we can do comparison in any order, so we // require memory to be unconditionnally dereferencable. return {}; } - APInt Offset = APInt(DL.getPointerTypeSizeInBits(GEP->getType()), 0); - if (!GEP->accumulateConstantOffset(DL, Offset)) - return {}; - return BCEAtom(GEP, LoadI, BaseId.getBaseId(GEP->getPointerOperand()), - Offset); + + APInt Offset = APInt(DL.getPointerTypeSizeInBits(Addr->getType()), 0); + Value *Base = Addr; + auto *GEP = dyn_cast<GetElementPtrInst>(Addr); + if (GEP) { + LLVM_DEBUG(dbgs() << "GEP\n"); + if (GEP->isUsedOutsideOfBlock(LoadI->getParent())) { + LLVM_DEBUG(dbgs() << "used outside of block\n"); + return {}; + } + if (!GEP->accumulateConstantOffset(DL, Offset)) + return {}; + Base = GEP->getPointerOperand(); + } + return BCEAtom(GEP, LoadI, BaseId.getBaseId(Base), Offset); } // A comparison between two BCE atoms, e.g. 
`a == o.a` in the example at the @@ -244,7 +246,7 @@ bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst, auto MayClobber = [&](LoadInst *LI) { // If a potentially clobbering instruction comes before the load, // we can still safely sink the load. - return !Inst->comesBefore(LI) && + return (Inst->getParent() != LI->getParent() || !Inst->comesBefore(LI)) && isModSet(AA.getModRefInfo(Inst, MemoryLocation::get(LI))); }; if (MayClobber(Cmp.Lhs.LoadI) || MayClobber(Cmp.Rhs.LoadI)) @@ -270,9 +272,8 @@ void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis &AA) const { } // Do the actual spliting. - for (Instruction *Inst : reverse(OtherInsts)) { - Inst->moveBefore(&*NewParent->begin()); - } + for (Instruction *Inst : reverse(OtherInsts)) + Inst->moveBefore(*NewParent, NewParent->begin()); } bool BCECmpBlock::canSplit(AliasAnalysis &AA) const { @@ -368,8 +369,11 @@ Optional<BCECmpBlock> visitCmpBlock(Value *const Val, BasicBlock *const Block, return None; BCECmpBlock::InstructionSet BlockInsts( - {Result->Lhs.GEP, Result->Rhs.GEP, Result->Lhs.LoadI, Result->Rhs.LoadI, - Result->CmpI, BranchI}); + {Result->Lhs.LoadI, Result->Rhs.LoadI, Result->CmpI, BranchI}); + if (Result->Lhs.GEP) + BlockInsts.insert(Result->Lhs.GEP); + if (Result->Rhs.GEP) + BlockInsts.insert(Result->Rhs.GEP); return BCECmpBlock(std::move(*Result), Block, BlockInsts); } @@ -604,8 +608,15 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, NextCmpBlock->getParent(), InsertBefore); IRBuilder<> Builder(BB); // Add the GEPs from the first BCECmpBlock. - Value *const Lhs = Builder.Insert(FirstCmp.Lhs().GEP->clone()); - Value *const Rhs = Builder.Insert(FirstCmp.Rhs().GEP->clone()); + Value *Lhs, *Rhs; + if (FirstCmp.Lhs().GEP) + Lhs = Builder.Insert(FirstCmp.Lhs().GEP->clone()); + else + Lhs = FirstCmp.Lhs().LoadI->getPointerOperand(); + if (FirstCmp.Rhs().GEP) + Rhs = Builder.Insert(FirstCmp.Rhs().GEP->clone()); + else + Rhs = FirstCmp.Rhs().LoadI->getPointerOperand(); Value *IsEqual = nullptr; LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons -> " diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index 734532a6670c..6383d6ea838b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -76,13 +76,9 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/MergedLoadStoreMotion.h" -#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/CFG.h" #include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/Loads.h" -#include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/Metadata.h" +#include "llvm/IR/Instructions.h" #include "llvm/InitializePasses.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/NewGVN.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/NewGVN.cpp index f35c9212a6f9..876ef3c427a6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -88,8 +88,6 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/PatternMatch.h" #include 
"llvm/IR/Type.h" #include "llvm/IR/Use.h" @@ -1076,6 +1074,9 @@ const Expression *NewGVN::createBinaryExpression(unsigned Opcode, Type *T, Value *Arg1, Value *Arg2, Instruction *I) const { auto *E = new (ExpressionAllocator) BasicExpression(2); + // TODO: we need to remove context instruction after Value Tracking + // can run without context instruction + const SimplifyQuery Q = SQ.getWithInstruction(I); E->setType(T); E->setOpcode(Opcode); @@ -1091,7 +1092,7 @@ const Expression *NewGVN::createBinaryExpression(unsigned Opcode, Type *T, E->op_push_back(lookupOperandLeader(Arg1)); E->op_push_back(lookupOperandLeader(Arg2)); - Value *V = SimplifyBinOp(Opcode, E->getOperand(0), E->getOperand(1), SQ); + Value *V = simplifyBinOp(Opcode, E->getOperand(0), E->getOperand(1), Q); if (auto Simplified = checkExprResults(E, I, V)) { addAdditionalUsers(Simplified, I); return Simplified.Expr; @@ -1147,6 +1148,9 @@ NewGVN::ExprResult NewGVN::checkExprResults(Expression *E, Instruction *I, NewGVN::ExprResult NewGVN::createExpression(Instruction *I) const { auto *E = new (ExpressionAllocator) BasicExpression(I->getNumOperands()); + // TODO: we need to remove context instruction after Value Tracking + // can run without context instruction + const SimplifyQuery Q = SQ.getWithInstruction(I); bool AllConstant = setBasicExpressionInfo(I, E); @@ -1169,13 +1173,13 @@ NewGVN::ExprResult NewGVN::createExpression(Instruction *I) const { Predicate = CmpInst::getSwappedPredicate(Predicate); } E->setOpcode((CI->getOpcode() << 8) | Predicate); - // TODO: 25% of our time is spent in SimplifyCmpInst with pointer operands + // TODO: 25% of our time is spent in simplifyCmpInst with pointer operands assert(I->getOperand(0)->getType() == I->getOperand(1)->getType() && "Wrong types on cmp instruction"); assert((E->getOperand(0)->getType() == I->getOperand(0)->getType() && E->getOperand(1)->getType() == I->getOperand(1)->getType())); Value *V = - SimplifyCmpInst(Predicate, E->getOperand(0), E->getOperand(1), SQ); + simplifyCmpInst(Predicate, E->getOperand(0), E->getOperand(1), Q); if (auto Simplified = checkExprResults(E, I, V)) return Simplified; } else if (isa<SelectInst>(I)) { @@ -1183,26 +1187,26 @@ NewGVN::ExprResult NewGVN::createExpression(Instruction *I) const { E->getOperand(1) == E->getOperand(2)) { assert(E->getOperand(1)->getType() == I->getOperand(1)->getType() && E->getOperand(2)->getType() == I->getOperand(2)->getType()); - Value *V = SimplifySelectInst(E->getOperand(0), E->getOperand(1), - E->getOperand(2), SQ); + Value *V = simplifySelectInst(E->getOperand(0), E->getOperand(1), + E->getOperand(2), Q); if (auto Simplified = checkExprResults(E, I, V)) return Simplified; } } else if (I->isBinaryOp()) { Value *V = - SimplifyBinOp(E->getOpcode(), E->getOperand(0), E->getOperand(1), SQ); + simplifyBinOp(E->getOpcode(), E->getOperand(0), E->getOperand(1), Q); if (auto Simplified = checkExprResults(E, I, V)) return Simplified; } else if (auto *CI = dyn_cast<CastInst>(I)) { Value *V = - SimplifyCastInst(CI->getOpcode(), E->getOperand(0), CI->getType(), SQ); + simplifyCastInst(CI->getOpcode(), E->getOperand(0), CI->getType(), Q); if (auto Simplified = checkExprResults(E, I, V)) return Simplified; } else if (auto *GEPI = dyn_cast<GetElementPtrInst>(I)) { Value *V = - SimplifyGEPInst(GEPI->getSourceElementType(), *E->op_begin(), + simplifyGEPInst(GEPI->getSourceElementType(), *E->op_begin(), makeArrayRef(std::next(E->op_begin()), E->op_end()), - GEPI->isInBounds(), SQ); + GEPI->isInBounds(), Q); if (auto Simplified = 
checkExprResults(E, I, V)) return Simplified; } else if (AllConstant) { @@ -1453,10 +1457,12 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr, if (Offset >= 0) { if (auto *C = dyn_cast<Constant>( lookupOperandLeader(DepSI->getValueOperand()))) { - LLVM_DEBUG(dbgs() << "Coercing load from store " << *DepSI - << " to constant " << *C << "\n"); - return createConstantExpression( - getConstantStoreValueForLoad(C, Offset, LoadType, DL)); + if (Constant *Res = + getConstantStoreValueForLoad(C, Offset, LoadType, DL)) { + LLVM_DEBUG(dbgs() << "Coercing load from store " << *DepSI + << " to constant " << *Res << "\n"); + return createConstantExpression(Res); + } } } } else if (auto *DepLI = dyn_cast<LoadInst>(DepInst)) { @@ -1503,9 +1509,8 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr, else if (auto *II = dyn_cast<IntrinsicInst>(DepInst)) { if (II->getIntrinsicID() == Intrinsic::lifetime_start) return createConstantExpression(UndefValue::get(LoadType)); - } else if (isAllocationFn(DepInst, TLI)) - if (auto *InitVal = getInitialValueOfAllocation(cast<CallBase>(DepInst), - TLI, LoadType)) + } else if (auto *InitVal = + getInitialValueOfAllocation(DepInst, TLI, LoadType)) return createConstantExpression(InitVal); return nullptr; @@ -3142,9 +3147,8 @@ bool NewGVN::singleReachablePHIPath( // connected component finding in this routine, and it's probably not worth // the complexity for the time being. So, we just keep a set of visited // MemoryAccess and return true when we hit a cycle. - if (Visited.count(First)) + if (!Visited.insert(First).second) return true; - Visited.insert(First); const auto *EndDef = First; for (auto *ChainDef : optimized_def_chain(First)) { @@ -3353,7 +3357,7 @@ void NewGVN::verifyStoreExpressions() const { // instruction set, propagating value numbers, marking things touched, etc, // until the set of touched instructions is completely empty. void NewGVN::iterateTouchedInstructions() { - unsigned int Iterations = 0; + uint64_t Iterations = 0; // Figure out where touchedinstructions starts int FirstInstr = TouchedInstructions.find_first(); // Nothing set, nothing to iterate, just return. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp index e0d0301c1ef6..689a2a286cb9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp @@ -125,6 +125,9 @@ static bool runPartiallyInlineLibCalls(Function &F, TargetLibraryInfo *TLI, if (Call->isNoBuiltin() || Call->isStrictFP()) continue; + if (Call->isMustTailCall()) + continue; + // Skip if function either has local linkage or is not a known library // function. LibFunc LF; @@ -137,7 +140,7 @@ static bool runPartiallyInlineLibCalls(Function &F, TargetLibraryInfo *TLI, case LibFunc_sqrt: if (TTI->haveFastSqrt(Call->getType()) && optimizeSQRT(Call, CalledFunc, *CurrBB, BB, TTI, - DTU.hasValue() ? DTU.getPointer() : nullptr)) + DTU ? 
DTU.getPointer() : nullptr)) break; continue; default: diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp index a110f7d5c241..e1cc3fc71c3e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp @@ -53,9 +53,9 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LegacyPassManager.h" @@ -65,6 +65,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" #define DEBUG_TYPE "safepoint-placement" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reassociate.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reassociate.cpp index c354fa177a60..da1737979305 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -24,7 +24,6 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PostOrderIterator.h" -#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" @@ -42,7 +41,6 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" @@ -54,7 +52,6 @@ #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" @@ -183,7 +180,7 @@ void ReassociatePass::BuildRankMap(Function &F, // we cannot move. This ensures that the ranks for these instructions are // all different in the block. for (Instruction &I : *BB) - if (mayBeMemoryDependent(I)) + if (mayHaveNonDefUseDependency(I)) ValueRankMap[&I] = ++BBRank; } } @@ -1076,7 +1073,7 @@ static BinaryOperator *ConvertShiftToMul(Instruction *Shl) { BinaryOperator *Mul = BinaryOperator::CreateMul(Shl->getOperand(0), MulCst, "", Shl); - Shl->setOperand(0, UndefValue::get(Shl->getType())); // Drop use of op. + Shl->setOperand(0, PoisonValue::get(Shl->getType())); // Drop use of op. Mul->takeName(Shl); // Everyone now refers to the mul instruction. 
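// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch). The Reassociate.cpp hunk above
// (ConvertShiftToMul) builds a multiply Mul = X * MulCst from a shift, and the
// patch only changes the placeholder written into the dead shift operand from
// undef to poison. The rewrite rests on the arithmetic identity
// x << c == x * (1 << c) for a constant in-range shift amount; a minimal
// standalone check of that identity, with hypothetical test values, in plain
// C++:
#include <cassert>
#include <cstdint>

static uint64_t shiftAsMul(uint64_t X, unsigned C) {
  // Multiply by the power-of-two constant that stands in for the shift.
  return X * (uint64_t{1} << C);
}

int main() {
  const uint64_t Xs[] = {0, 1, 42, 0xdeadbeef};
  const unsigned Cs[] = {0, 1, 7, 31};
  for (uint64_t X : Xs)
    for (unsigned C : Cs)
      assert((X << C) == shiftAsMul(X, C)); // wraps identically modulo 2^64
  return 0;
}
// ---------------------------------------------------------------------------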
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reg2Mem.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reg2Mem.cpp index a49b9ad3f62b..9dc64493a9ee 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reg2Mem.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reg2Mem.cpp @@ -24,8 +24,6 @@ #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index b795ad3899bc..51e4a5773f3e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -258,6 +258,7 @@ struct GCPtrLivenessData { // base relation will remain. Internally, we add a mixture of the two // types, then update all the second type to the first type using DefiningValueMapTy = MapVector<Value *, Value *>; +using IsKnownBaseMapTy = MapVector<Value *, bool>; using PointerToBaseTy = MapVector<Value *, Value *>; using StatepointLiveSetTy = SetVector<Value *>; using RematerializedValueMapTy = @@ -281,19 +282,29 @@ struct PartiallyConstructedSafepointRecord { RematerializedValueMapTy RematerializedValues; }; +struct RematerizlizationCandidateRecord { + // Chain from derived pointer to base. + SmallVector<Instruction *, 3> ChainToBase; + // Original base. + Value *RootOfChain; + // Cost of chain. + InstructionCost Cost; +}; +using RematCandTy = MapVector<Value *, RematerizlizationCandidateRecord>; + } // end anonymous namespace static ArrayRef<Use> GetDeoptBundleOperands(const CallBase *Call) { Optional<OperandBundleUse> DeoptBundle = Call->getOperandBundle(LLVMContext::OB_deopt); - if (!DeoptBundle.hasValue()) { + if (!DeoptBundle) { assert(AllowStatepointWithNoDeoptInfo && "Found non-leaf call without deopt info!"); return None; } - return DeoptBundle.getValue().Inputs; + return DeoptBundle->Inputs; } /// Compute the live-in set for every basic block in the function @@ -385,45 +396,16 @@ static void analyzeParsePointLiveness( Result.LiveSet = LiveSet; } -// Returns true is V is a knownBaseResult. -static bool isKnownBaseResult(Value *V); - -// Returns true if V is a BaseResult that already exists in the IR, i.e. it is -// not created by the findBasePointers algorithm. -static bool isOriginalBaseResult(Value *V); - -namespace { - -/// A single base defining value - An immediate base defining value for an -/// instruction 'Def' is an input to 'Def' whose base is also a base of 'Def'. -/// For instructions which have multiple pointer [vector] inputs or that -/// transition between vector and scalar types, there is no immediate base -/// defining value. The 'base defining value' for 'Def' is the transitive -/// closure of this relation stopping at the first instruction which has no -/// immediate base defining value. The b.d.v. might itself be a base pointer, -/// but it can also be an arbitrary derived pointer. -struct BaseDefiningValueResult { - /// Contains the value which is the base defining value. - Value * const BDV; - - /// True if the base defining value is also known to be an actual base - /// pointer. 
- const bool IsKnownBase; - - BaseDefiningValueResult(Value *BDV, bool IsKnownBase) - : BDV(BDV), IsKnownBase(IsKnownBase) { -#ifndef NDEBUG - // Check consistency between new and old means of checking whether a BDV is - // a base. - bool MustBeBase = isKnownBaseResult(BDV); - assert(!MustBeBase || MustBeBase == IsKnownBase); -#endif - } -}; +/// Returns true if V is a known base. +static bool isKnownBase(Value *V, const IsKnownBaseMapTy &KnownBases); -} // end anonymous namespace +/// Caches the IsKnownBase flag for a value and asserts that it wasn't present +/// in the cache before. +static void setKnownBase(Value *V, bool IsKnownBase, + IsKnownBaseMapTy &KnownBases); -static BaseDefiningValueResult findBaseDefiningValue(Value *I); +static Value *findBaseDefiningValue(Value *I, DefiningValueMapTy &Cache, + IsKnownBaseMapTy &KnownBases); /// Return a base defining value for the 'Index' element of the given vector /// instruction 'I'. If Index is null, returns a BDV for the entire vector @@ -434,76 +416,122 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I); /// vector returned is a BDV (and possibly a base) of the entire vector 'I'. /// If the later, the return pointer is a BDV (or possibly a base) for the /// particular element in 'I'. -static BaseDefiningValueResult -findBaseDefiningValueOfVector(Value *I) { +static Value *findBaseDefiningValueOfVector(Value *I, DefiningValueMapTy &Cache, + IsKnownBaseMapTy &KnownBases) { // Each case parallels findBaseDefiningValue below, see that code for // detailed motivation. - if (isa<Argument>(I)) + auto Cached = Cache.find(I); + if (Cached != Cache.end()) + return Cached->second; + + if (isa<Argument>(I)) { // An incoming argument to the function is a base pointer - return BaseDefiningValueResult(I, true); + Cache[I] = I; + setKnownBase(I, /* IsKnownBase */true, KnownBases); + return I; + } - if (isa<Constant>(I)) + if (isa<Constant>(I)) { // Base of constant vector consists only of constant null pointers. // For reasoning see similar case inside 'findBaseDefiningValue' function. - return BaseDefiningValueResult(ConstantAggregateZero::get(I->getType()), - true); + auto *CAZ = ConstantAggregateZero::get(I->getType()); + Cache[I] = CAZ; + setKnownBase(CAZ, /* IsKnownBase */true, KnownBases); + return CAZ; + } - if (isa<LoadInst>(I)) - return BaseDefiningValueResult(I, true); + if (isa<LoadInst>(I)) { + Cache[I] = I; + setKnownBase(I, /* IsKnownBase */true, KnownBases); + return I; + } - if (isa<InsertElementInst>(I)) + if (isa<InsertElementInst>(I)) { // We don't know whether this vector contains entirely base pointers or // not. To be conservatively correct, we treat it as a BDV and will // duplicate code as needed to construct a parallel vector of bases. - return BaseDefiningValueResult(I, false); + Cache[I] = I; + setKnownBase(I, /* IsKnownBase */false, KnownBases); + return I; + } - if (isa<ShuffleVectorInst>(I)) + if (isa<ShuffleVectorInst>(I)) { // We don't know whether this vector contains entirely base pointers or // not. To be conservatively correct, we treat it as a BDV and will // duplicate code as needed to construct a parallel vector of bases. // TODO: There a number of local optimizations which could be applied here // for particular sufflevector patterns. - return BaseDefiningValueResult(I, false); + Cache[I] = I; + setKnownBase(I, /* IsKnownBase */false, KnownBases); + return I; + } // The behavior of getelementptr instructions is the same for vector and // non-vector data types. 
- if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) - return findBaseDefiningValue(GEP->getPointerOperand()); + if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { + auto *BDV = + findBaseDefiningValue(GEP->getPointerOperand(), Cache, KnownBases); + Cache[GEP] = BDV; + return BDV; + } + + // The behavior of freeze instructions is the same for vector and + // non-vector data types. + if (auto *Freeze = dyn_cast<FreezeInst>(I)) { + auto *BDV = findBaseDefiningValue(Freeze->getOperand(0), Cache, KnownBases); + Cache[Freeze] = BDV; + return BDV; + } // If the pointer comes through a bitcast of a vector of pointers to // a vector of another type of pointer, then look through the bitcast - if (auto *BC = dyn_cast<BitCastInst>(I)) - return findBaseDefiningValue(BC->getOperand(0)); + if (auto *BC = dyn_cast<BitCastInst>(I)) { + auto *BDV = findBaseDefiningValue(BC->getOperand(0), Cache, KnownBases); + Cache[BC] = BDV; + return BDV; + } // We assume that functions in the source language only return base // pointers. This should probably be generalized via attributes to support // both source language and internal functions. - if (isa<CallInst>(I) || isa<InvokeInst>(I)) - return BaseDefiningValueResult(I, true); + if (isa<CallInst>(I) || isa<InvokeInst>(I)) { + Cache[I] = I; + setKnownBase(I, /* IsKnownBase */true, KnownBases); + return I; + } // A PHI or Select is a base defining value. The outer findBasePointer // algorithm is responsible for constructing a base value for this BDV. assert((isa<SelectInst>(I) || isa<PHINode>(I)) && "unknown vector instruction - no base found for vector element"); - return BaseDefiningValueResult(I, false); + Cache[I] = I; + setKnownBase(I, /* IsKnownBase */false, KnownBases); + return I; } /// Helper function for findBasePointer - Will return a value which either a) /// defines the base pointer for the input, b) blocks the simple search /// (i.e. a PHI or Select of two derived pointers), or c) involves a change /// from pointer to vector type or back. -static BaseDefiningValueResult findBaseDefiningValue(Value *I) { +static Value *findBaseDefiningValue(Value *I, DefiningValueMapTy &Cache, + IsKnownBaseMapTy &KnownBases) { assert(I->getType()->isPtrOrPtrVectorTy() && "Illegal to ask for the base pointer of a non-pointer type"); + auto Cached = Cache.find(I); + if (Cached != Cache.end()) + return Cached->second; if (I->getType()->isVectorTy()) - return findBaseDefiningValueOfVector(I); + return findBaseDefiningValueOfVector(I, Cache, KnownBases); - if (isa<Argument>(I)) + if (isa<Argument>(I)) { // An incoming argument to the function is a base pointer // We should have never reached here if this argument isn't an gc value - return BaseDefiningValueResult(I, true); + Cache[I] = I; + setKnownBase(I, /* IsKnownBase */true, KnownBases); + return I; + } if (isa<Constant>(I)) { // We assume that objects with a constant base (e.g. a global) can't move @@ -516,8 +544,10 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) { // "phi (const1, const2)" or "phi (const, regular gc ptr)". // See constant.ll file for relevant test cases. - return BaseDefiningValueResult( - ConstantPointerNull::get(cast<PointerType>(I->getType())), true); + auto *CPN = ConstantPointerNull::get(cast<PointerType>(I->getType())); + Cache[I] = CPN; + setKnownBase(CPN, /* IsKnownBase */true, KnownBases); + return CPN; } // inttoptrs in an integral address space are currently ill-defined. 
We @@ -525,8 +555,11 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) { // constant rule above and because we don't really have a better semantic // to give them. Note that the optimizer is always free to insert undefined // behavior on dynamically dead paths as well. - if (isa<IntToPtrInst>(I)) - return BaseDefiningValueResult(I, true); + if (isa<IntToPtrInst>(I)) { + Cache[I] = I; + setKnownBase(I, /* IsKnownBase */true, KnownBases); + return I; + } if (CastInst *CI = dyn_cast<CastInst>(I)) { Value *Def = CI->stripPointerCasts(); @@ -539,16 +572,31 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) { // not simply a pointer cast (i.e. an inttoptr). We don't know how to // handle int->ptr conversion. assert(!isa<CastInst>(Def) && "shouldn't find another cast here"); - return findBaseDefiningValue(Def); + auto *BDV = findBaseDefiningValue(Def, Cache, KnownBases); + Cache[CI] = BDV; + return BDV; } - if (isa<LoadInst>(I)) + if (isa<LoadInst>(I)) { // The value loaded is an gc base itself - return BaseDefiningValueResult(I, true); + Cache[I] = I; + setKnownBase(I, /* IsKnownBase */true, KnownBases); + return I; + } - if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I)) + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I)) { // The base of this GEP is the base - return findBaseDefiningValue(GEP->getPointerOperand()); + auto *BDV = + findBaseDefiningValue(GEP->getPointerOperand(), Cache, KnownBases); + Cache[GEP] = BDV; + return BDV; + } + + if (auto *Freeze = dyn_cast<FreezeInst>(I)) { + auto *BDV = findBaseDefiningValue(Freeze->getOperand(0), Cache, KnownBases); + Cache[Freeze] = BDV; + return BDV; + } if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { switch (II->getIntrinsicID()) { @@ -569,24 +617,32 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) { llvm_unreachable( "interaction with the gcroot mechanism is not supported"); case Intrinsic::experimental_gc_get_pointer_base: - return findBaseDefiningValue(II->getOperand(0)); + auto *BDV = findBaseDefiningValue(II->getOperand(0), Cache, KnownBases); + Cache[II] = BDV; + return BDV; } } // We assume that functions in the source language only return base // pointers. This should probably be generalized via attributes to support // both source language and internal functions. - if (isa<CallInst>(I) || isa<InvokeInst>(I)) - return BaseDefiningValueResult(I, true); + if (isa<CallInst>(I) || isa<InvokeInst>(I)) { + Cache[I] = I; + setKnownBase(I, /* IsKnownBase */true, KnownBases); + return I; + } // TODO: I have absolutely no idea how to implement this part yet. It's not // necessarily hard, I just haven't really looked at it yet. assert(!isa<LandingPadInst>(I) && "Landing Pad is unimplemented"); - if (isa<AtomicCmpXchgInst>(I)) + if (isa<AtomicCmpXchgInst>(I)) { // A CAS is effectively a atomic store and load combined under a // predicate. From the perspective of base pointers, we just treat it // like a load. - return BaseDefiningValueResult(I, true); + Cache[I] = I; + setKnownBase(I, /* IsKnownBase */true, KnownBases); + return I; + } assert(!isa<AtomicRMWInst>(I) && "Xchg handled above, all others are " "binary ops which don't apply to pointers"); @@ -594,8 +650,11 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) { // The aggregate ops. Aggregates can either be in the heap or on the // stack, but in either case, this is simply a field load. As a result, // this is a defining definition of the base just like a load is. 
- if (isa<ExtractValueInst>(I)) - return BaseDefiningValueResult(I, true); + if (isa<ExtractValueInst>(I)) { + Cache[I] = I; + setKnownBase(I, /* IsKnownBase */true, KnownBases); + return I; + } // We should never see an insert vector since that would require we be // tracing back a struct value not a pointer value. @@ -606,6 +665,8 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) { // substituting gc.get.pointer.base() intrinsic. bool IsKnownBase = isa<Instruction>(I) && cast<Instruction>(I)->getMetadata("is_base_value"); + setKnownBase(I, /* IsKnownBase */IsKnownBase, KnownBases); + Cache[I] = I; // An extractelement produces a base result exactly when it's input does. // We may need to insert a parallel instruction to extract the appropriate @@ -615,33 +676,38 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) { // Note: There a lot of obvious peephole cases here. This are deliberately // handled after the main base pointer inference algorithm to make writing // test cases to exercise that code easier. - return BaseDefiningValueResult(I, IsKnownBase); + return I; // The last two cases here don't return a base pointer. Instead, they // return a value which dynamically selects from among several base // derived pointers (each with it's own base potentially). It's the job of // the caller to resolve these. assert((isa<SelectInst>(I) || isa<PHINode>(I)) && - "missing instruction case in findBaseDefiningValing"); - return BaseDefiningValueResult(I, IsKnownBase); + "missing instruction case in findBaseDefiningValue"); + return I; } /// Returns the base defining value for this value. -static Value *findBaseDefiningValueCached(Value *I, DefiningValueMapTy &Cache) { - Value *&Cached = Cache[I]; - if (!Cached) { - Cached = findBaseDefiningValue(I).BDV; +static Value *findBaseDefiningValueCached(Value *I, DefiningValueMapTy &Cache, + IsKnownBaseMapTy &KnownBases) { + if (Cache.find(I) == Cache.end()) { + auto *BDV = findBaseDefiningValue(I, Cache, KnownBases); + Cache[I] = BDV; LLVM_DEBUG(dbgs() << "fBDV-cached: " << I->getName() << " -> " - << Cached->getName() << "\n"); + << Cache[I]->getName() << ", is known base = " + << KnownBases[I] << "\n"); } assert(Cache[I] != nullptr); - return Cached; + assert(KnownBases.find(Cache[I]) != KnownBases.end() && + "Cached value must be present in known bases map"); + return Cache[I]; } /// Return a base pointer for this value if known. Otherwise, return it's /// base defining value. -static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) { - Value *Def = findBaseDefiningValueCached(I, Cache); +static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache, + IsKnownBaseMapTy &KnownBases) { + Value *Def = findBaseDefiningValueCached(I, Cache, KnownBases); auto Found = Cache.find(Def); if (Found != Cache.end()) { // Either a base-of relation, or a self reference. Caller must check. @@ -651,6 +717,7 @@ static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) { return Def; } +#ifndef NDEBUG /// This value is a base pointer that is not generated by RS4GC, i.e. it already /// exists in the code. static bool isOriginalBaseResult(Value *V) { @@ -659,21 +726,22 @@ static bool isOriginalBaseResult(Value *V) { !isa<ExtractElementInst>(V) && !isa<InsertElementInst>(V) && !isa<ShuffleVectorInst>(V); } +#endif -/// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV, -/// is it known to be a base pointer? Or do we need to continue searching. 
-static bool isKnownBaseResult(Value *V) { - if (isOriginalBaseResult(V)) - return true; - if (isa<Instruction>(V) && - cast<Instruction>(V)->getMetadata("is_base_value")) { - // This is a previously inserted base phi or select. We know - // that this is a base value. - return true; - } +static bool isKnownBase(Value *V, const IsKnownBaseMapTy &KnownBases) { + auto It = KnownBases.find(V); + assert(It != KnownBases.end() && "Value not present in the map"); + return It->second; +} - // We need to keep searching - return false; +static void setKnownBase(Value *V, bool IsKnownBase, + IsKnownBaseMapTy &KnownBases) { +#ifndef NDEBUG + auto It = KnownBases.find(V); + if (It != KnownBases.end()) + assert(It->second == IsKnownBase && "Changing already present value"); +#endif + KnownBases[V] = IsKnownBase; } // Returns true if First and Second values are both scalar or both vector. @@ -801,10 +869,11 @@ static raw_ostream &operator<<(raw_ostream &OS, const BDVState &State) { /// For gc objects, this is simply itself. On success, returns a value which is /// the base pointer. (This is reliable and can be used for relocation.) On /// failure, returns nullptr. -static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { - Value *Def = findBaseOrBDV(I, Cache); +static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache, + IsKnownBaseMapTy &KnownBases) { + Value *Def = findBaseOrBDV(I, Cache, KnownBases); - if (isKnownBaseResult(Def) && areBothVectorOrScalar(Def, I)) + if (isKnownBase(Def, KnownBases) && areBothVectorOrScalar(Def, I)) return Def; // Here's the rough algorithm: @@ -887,8 +956,8 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { assert(!isOriginalBaseResult(Current) && "why did it get added?"); auto visitIncomingValue = [&](Value *InVal) { - Value *Base = findBaseOrBDV(InVal, Cache); - if (isKnownBaseResult(Base) && areBothVectorOrScalar(Base, InVal)) + Value *Base = findBaseOrBDV(InVal, Cache, KnownBases); + if (isKnownBase(Base, KnownBases) && areBothVectorOrScalar(Base, InVal)) // Known bases won't need new instructions introduced and can be // ignored safely. However, this can only be done when InVal and Base // are both scalar or both vector. Otherwise, we need to find a @@ -924,12 +993,16 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { for (auto Pair : States) { Value *BDV = Pair.first; auto canPruneInput = [&](Value *V) { - Value *BDV = findBaseOrBDV(V, Cache); - if (V->stripPointerCasts() != BDV) + // If the input of the BDV is the BDV itself we can prune it. This is + // only possible if the BDV is a PHI node. + if (V->stripPointerCasts() == BDV) + return true; + Value *VBDV = findBaseOrBDV(V, Cache, KnownBases); + if (V->stripPointerCasts() != VBDV) return false; // The assumption is that anything not in the state list is // propagates a base pointer. - return States.count(BDV) == 0; + return States.count(VBDV) == 0; }; bool CanPrune = true; @@ -975,13 +1048,13 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // Only values that do not have known bases or those that have differing // type (scalar versus vector) from a possible known base should be in the // lattice. 
- assert((!isKnownBaseResult(BDV) || + assert((!isKnownBase(BDV, KnownBases) || !areBothVectorOrScalar(BDV, Pair.second.getBaseValue())) && "why did it get added?"); BDVState NewState(BDV); visitBDVOperands(BDV, [&](Value *Op) { - Value *BDV = findBaseOrBDV(Op, Cache); + Value *BDV = findBaseOrBDV(Op, Cache, KnownBases); auto OpState = GetStateForBDV(BDV, Op); NewState.meet(OpState); }); @@ -1014,8 +1087,9 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // Only values that do not have known bases or those that have differing // type (scalar versus vector) from a possible known base should be in the // lattice. - assert((!isKnownBaseResult(I) || !areBothVectorOrScalar(I, BaseValue)) && - "why did it get added?"); + assert( + (!isKnownBase(I, KnownBases) || !areBothVectorOrScalar(I, BaseValue)) && + "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); if (!State.isBase() || !isa<VectorType>(BaseValue->getType())) @@ -1033,6 +1107,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE); BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {})); States[I] = BDVState(I, BDVState::Base, BaseInst); + setKnownBase(BaseInst, /* IsKnownBase */true, KnownBases); } else if (!isa<VectorType>(I->getType())) { // We need to handle cases that have a vector base but the instruction is // a scalar type (these could be phis or selects or any instruction that @@ -1055,7 +1130,8 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // Only values that do not have known bases or those that have differing // type (scalar versus vector) from a possible known base should be in the // lattice. - assert((!isKnownBaseResult(I) || !areBothVectorOrScalar(I, State.getBaseValue())) && + assert((!isKnownBase(I, KnownBases) || + !areBothVectorOrScalar(I, State.getBaseValue())) && "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); @@ -1087,6 +1163,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // Add metadata marking this as a base value BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {})); States[I] = BDVState(I, BDVState::Conflict, BaseInst); + setKnownBase(BaseInst, /* IsKnownBase */true, KnownBases); } #ifndef NDEBUG @@ -1102,7 +1179,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // assured to be able to determine an instruction which produces it's base // pointer. auto getBaseForInput = [&](Value *Input, Instruction *InsertPt) { - Value *BDV = findBaseOrBDV(Input, Cache); + Value *BDV = findBaseOrBDV(Input, Cache, KnownBases); Value *Base = nullptr; if (!States.count(BDV)) { assert(areBothVectorOrScalar(BDV, Input)); @@ -1129,7 +1206,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // Only values that do not have known bases or those that have differing // type (scalar versus vector) from a possible known base should be in the // lattice. 
- assert((!isKnownBaseResult(BDV) || + assert((!isKnownBase(BDV, KnownBases) || !areBothVectorOrScalar(BDV, State.getBaseValue())) && "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); @@ -1154,13 +1231,21 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { #ifndef NDEBUG Value *OldBase = BlockToValue[InBB]; Value *Base = getBaseForInput(InVal, nullptr); + + // We can't use `stripPointerCasts` instead of this function because + // `stripPointerCasts` doesn't handle vectors of pointers. + auto StripBitCasts = [](Value *V) -> Value * { + while (auto *BC = dyn_cast<BitCastInst>(V)) + V = BC->getOperand(0); + return V; + }; // In essence this assert states: the only way two values // incoming from the same basic block may be different is by // being different bitcasts of the same value. A cleanup // that remains TODO is changing findBaseOrBDV to return an // llvm::Value of the correct type (and still remain pure). // This will remove the need to add bitcasts. - assert(Base->stripPointerCasts() == OldBase->stripPointerCasts() && + assert(StripBitCasts(Base) == StripBitCasts(OldBase) && "findBaseOrBDV should be pure!"); #endif } @@ -1223,8 +1308,9 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // Only values that do not have known bases or those that have differing // type (scalar versus vector) from a possible known base should be in the // lattice. - assert((!isKnownBaseResult(BDV) || !areBothVectorOrScalar(BDV, Base)) && - "why did it get added?"); + assert( + (!isKnownBase(BDV, KnownBases) || !areBothVectorOrScalar(BDV, Base)) && + "why did it get added?"); LLVM_DEBUG( dbgs() << "Updating base value cache" @@ -1255,9 +1341,10 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // pointer was a base pointer. static void findBasePointers(const StatepointLiveSetTy &live, PointerToBaseTy &PointerToBase, DominatorTree *DT, - DefiningValueMapTy &DVCache) { + DefiningValueMapTy &DVCache, + IsKnownBaseMapTy &KnownBases) { for (Value *ptr : live) { - Value *base = findBasePointer(ptr, DVCache); + Value *base = findBasePointer(ptr, DVCache, KnownBases); assert(base && "failed to find base pointer"); PointerToBase[ptr] = base; assert((!isa<Instruction>(base) || !isa<Instruction>(ptr) || @@ -1272,7 +1359,8 @@ static void findBasePointers(const StatepointLiveSetTy &live, static void findBasePointers(DominatorTree &DT, DefiningValueMapTy &DVCache, CallBase *Call, PartiallyConstructedSafepointRecord &result, - PointerToBaseTy &PointerToBase) { + PointerToBaseTy &PointerToBase, + IsKnownBaseMapTy &KnownBases) { StatepointLiveSetTy PotentiallyDerivedPointers = result.LiveSet; // We assume that all pointers passed to deopt are base pointers; as an // optimization, we can use this to avoid seperately materializing the base @@ -1286,7 +1374,8 @@ static void findBasePointers(DominatorTree &DT, DefiningValueMapTy &DVCache, PotentiallyDerivedPointers.remove(V); PointerToBase[V] = V; } - findBasePointers(PotentiallyDerivedPointers, PointerToBase, &DT, DVCache); + findBasePointers(PotentiallyDerivedPointers, PointerToBase, &DT, DVCache, + KnownBases); } /// Given an updated version of the dataflow liveness results, update the @@ -1349,23 +1438,23 @@ static constexpr Attribute::AttrKind FnAttrsToStrip[] = // Create new attribute set containing only attributes which can be transferred // from original call to the safepoint. 
static AttributeList legalizeCallAttributes(LLVMContext &Ctx, - AttributeList AL) { - if (AL.isEmpty()) - return AL; + AttributeList OrigAL, + AttributeList StatepointAL) { + if (OrigAL.isEmpty()) + return StatepointAL; // Remove the readonly, readnone, and statepoint function attributes. - AttrBuilder FnAttrs(Ctx, AL.getFnAttrs()); + AttrBuilder FnAttrs(Ctx, OrigAL.getFnAttrs()); for (auto Attr : FnAttrsToStrip) FnAttrs.removeAttribute(Attr); - for (Attribute A : AL.getFnAttrs()) { + for (Attribute A : OrigAL.getFnAttrs()) { if (isStatepointDirectiveAttr(A)) FnAttrs.removeAttribute(A); } // Just skip parameter and return attributes for now - return AttributeList::get(Ctx, AttributeList::FunctionIndex, - AttributeSet::get(Ctx, FnAttrs)); + return StatepointAL.addFnAttributes(Ctx, FnAttrs); } /// Helper function to place all gc relocates necessary for the given @@ -1570,8 +1659,8 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ assert(DeoptLowering.equals("live-through") && "Unsupported value!"); } - Value *CallTarget = Call->getCalledOperand(); - if (Function *F = dyn_cast<Function>(CallTarget)) { + FunctionCallee CallTarget(Call->getFunctionType(), Call->getCalledOperand()); + if (Function *F = dyn_cast<Function>(CallTarget.getCallee())) { auto IID = F->getIntrinsicID(); if (IID == Intrinsic::experimental_deoptimize) { // Calls to llvm.experimental.deoptimize are lowered to calls to the @@ -1589,8 +1678,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ // the same module. This is fine -- we assume the frontend knew what it // was doing when generating this kind of IR. CallTarget = F->getParent() - ->getOrInsertFunction("__llvm_deoptimize", FTy) - .getCallee(); + ->getOrInsertFunction("__llvm_deoptimize", FTy); IsDeoptimize = true; } else if (IID == Intrinsic::memcpy_element_unordered_atomic || @@ -1686,8 +1774,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ CallTarget = F->getParent() - ->getOrInsertFunction(GetFunctionName(IID, ElementSizeCI), FTy) - .getCallee(); + ->getOrInsertFunction(GetFunctionName(IID, ElementSizeCI), FTy); } } @@ -1705,8 +1792,8 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ // function attributes. In case if we can handle this set of attributes - // set up function attrs directly on statepoint and return attrs later for // gc_result intrinsic. - SPCall->setAttributes( - legalizeCallAttributes(CI->getContext(), CI->getAttributes())); + SPCall->setAttributes(legalizeCallAttributes( + CI->getContext(), CI->getAttributes(), SPCall->getAttributes())); Token = cast<GCStatepointInst>(SPCall); @@ -1732,8 +1819,8 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ // function attributes. In case if we can handle this set of attributes - // set up function attrs directly on statepoint and return attrs later for // gc_result intrinsic. 
- SPInvoke->setAttributes( - legalizeCallAttributes(II->getContext(), II->getAttributes())); + SPInvoke->setAttributes(legalizeCallAttributes( + II->getContext(), II->getAttributes(), SPInvoke->getAttributes())); Token = cast<GCStatepointInst>(SPInvoke); @@ -2071,6 +2158,7 @@ static void relocationViaAlloca( assert(PromotableAllocas.size() == Live.size() + NumRematerializedValues && "we must have the same allocas with lives"); + (void) NumRematerializedValues; if (!PromotableAllocas.empty()) { // Apply mem2reg to promote alloca to SSA PromoteMemToReg(PromotableAllocas, DT); @@ -2221,27 +2309,25 @@ static bool AreEquivalentPhiNodes(PHINode &OrigRootPhi, PHINode &AlternateRootPh return true; } -// From the statepoint live set pick values that are cheaper to recompute then -// to relocate. Remove this values from the live set, rematerialize them after -// statepoint and record them in "Info" structure. Note that similar to -// relocated values we don't do any user adjustments here. -static void rematerializeLiveValues(CallBase *Call, - PartiallyConstructedSafepointRecord &Info, - PointerToBaseTy &PointerToBase, - TargetTransformInfo &TTI) { +// Find derived pointers that can be recomputed cheap enough and fill +// RematerizationCandidates with such candidates. +static void +findRematerializationCandidates(PointerToBaseTy PointerToBase, + RematCandTy &RematerizationCandidates, + TargetTransformInfo &TTI) { const unsigned int ChainLengthThreshold = 10; - // Record values we are going to delete from this statepoint live set. - // We can not di this in following loop due to iterator invalidation. - SmallVector<Value *, 32> LiveValuesToBeDeleted; + for (auto P2B : PointerToBase) { + auto *Derived = P2B.first; + auto *Base = P2B.second; + // Consider only derived pointers. + if (Derived == Base) + continue; - for (Value *LiveValue: Info.LiveSet) { - // For each live pointer find its defining chain + // For each live pointer find its defining chain. SmallVector<Instruction *, 3> ChainToBase; - assert(PointerToBase.count(LiveValue)); Value *RootOfChain = - findRematerializableChainToBasePointer(ChainToBase, - LiveValue); + findRematerializableChainToBasePointer(ChainToBase, Derived); // Nothing to do, or chain is too long if ( ChainToBase.size() == 0 || @@ -2250,9 +2336,9 @@ static void rematerializeLiveValues(CallBase *Call, // Handle the scenario where the RootOfChain is not equal to the // Base Value, but they are essentially the same phi values. - if (RootOfChain != PointerToBase[LiveValue]) { + if (RootOfChain != PointerToBase[Derived]) { PHINode *OrigRootPhi = dyn_cast<PHINode>(RootOfChain); - PHINode *AlternateRootPhi = dyn_cast<PHINode>(PointerToBase[LiveValue]); + PHINode *AlternateRootPhi = dyn_cast<PHINode>(PointerToBase[Derived]); if (!OrigRootPhi || !AlternateRootPhi) continue; // PHI nodes that have the same incoming values, and belonging to the same @@ -2266,33 +2352,61 @@ static void rematerializeLiveValues(CallBase *Call, // deficiency in the findBasePointer algorithm. if (!AreEquivalentPhiNodes(*OrigRootPhi, *AlternateRootPhi)) continue; - // Now that the phi nodes are proved to be the same, assert that - // findBasePointer's newly generated AlternateRootPhi is present in the - // liveset of the call. - assert(Info.LiveSet.count(AlternateRootPhi)); } - // Compute cost of this chain + // Compute cost of this chain. 
InstructionCost Cost = chainToBasePointerCost(ChainToBase, TTI); // TODO: We can also account for cases when we will be able to remove some // of the rematerialized values by later optimization passes. I.e if // we rematerialized several intersecting chains. Or if original values // don't have any uses besides this statepoint. + // Ok, there is a candidate. + RematerizlizationCandidateRecord Record; + Record.ChainToBase = ChainToBase; + Record.RootOfChain = RootOfChain; + Record.Cost = Cost; + RematerizationCandidates.insert({ Derived, Record }); + } +} + +// From the statepoint live set pick values that are cheaper to recompute then +// to relocate. Remove this values from the live set, rematerialize them after +// statepoint and record them in "Info" structure. Note that similar to +// relocated values we don't do any user adjustments here. +static void rematerializeLiveValues(CallBase *Call, + PartiallyConstructedSafepointRecord &Info, + PointerToBaseTy &PointerToBase, + RematCandTy &RematerizationCandidates, + TargetTransformInfo &TTI) { + // Record values we are going to delete from this statepoint live set. + // We can not di this in following loop due to iterator invalidation. + SmallVector<Value *, 32> LiveValuesToBeDeleted; + + for (Value *LiveValue : Info.LiveSet) { + auto It = RematerizationCandidates.find(LiveValue); + if (It == RematerizationCandidates.end()) + continue; + + RematerizlizationCandidateRecord &Record = It->second; + + InstructionCost Cost = Record.Cost; // For invokes we need to rematerialize each chain twice - for normal and // for unwind basic blocks. Model this by multiplying cost by two. - if (isa<InvokeInst>(Call)) { + if (isa<InvokeInst>(Call)) Cost *= 2; - } - // If it's too expensive - skip it + + // If it's too expensive - skip it. if (Cost >= RematerializationThreshold) continue; // Remove value from the live set LiveValuesToBeDeleted.push_back(LiveValue); - // Clone instructions and record them inside "Info" structure + // Clone instructions and record them inside "Info" structure. - // Walk backwards to visit top-most instructions first + // For each live pointer find get its defining chain. + SmallVector<Instruction *, 3> ChainToBase = Record.ChainToBase; + // Walk backwards to visit top-most instructions first. 
std::reverse(ChainToBase.begin(), ChainToBase.end()); // Utility function which clones all instructions from "ChainToBase" @@ -2352,7 +2466,7 @@ static void rematerializeLiveValues(CallBase *Call, Instruction *InsertBefore = Call->getNextNode(); assert(InsertBefore); Instruction *RematerializedValue = rematerializeChain( - InsertBefore, RootOfChain, PointerToBase[LiveValue]); + InsertBefore, Record.RootOfChain, PointerToBase[LiveValue]); Info.RematerializedValues[RematerializedValue] = LiveValue; } else { auto *Invoke = cast<InvokeInst>(Call); @@ -2363,9 +2477,9 @@ static void rematerializeLiveValues(CallBase *Call, &*Invoke->getUnwindDest()->getFirstInsertionPt(); Instruction *NormalRematerializedValue = rematerializeChain( - NormalInsertBefore, RootOfChain, PointerToBase[LiveValue]); + NormalInsertBefore, Record.RootOfChain, PointerToBase[LiveValue]); Instruction *UnwindRematerializedValue = rematerializeChain( - UnwindInsertBefore, RootOfChain, PointerToBase[LiveValue]); + UnwindInsertBefore, Record.RootOfChain, PointerToBase[LiveValue]); Info.RematerializedValues[NormalRematerializedValue] = LiveValue; Info.RematerializedValues[UnwindRematerializedValue] = LiveValue; @@ -2380,7 +2494,8 @@ static void rematerializeLiveValues(CallBase *Call, static bool inlineGetBaseAndOffset(Function &F, SmallVectorImpl<CallInst *> &Intrinsics, - DefiningValueMapTy &DVCache) { + DefiningValueMapTy &DVCache, + IsKnownBaseMapTy &KnownBases) { auto &Context = F.getContext(); auto &DL = F.getParent()->getDataLayout(); bool Changed = false; @@ -2389,7 +2504,8 @@ static bool inlineGetBaseAndOffset(Function &F, switch (Callsite->getIntrinsicID()) { case Intrinsic::experimental_gc_get_pointer_base: { Changed = true; - Value *Base = findBasePointer(Callsite->getOperand(0), DVCache); + Value *Base = + findBasePointer(Callsite->getOperand(0), DVCache, KnownBases); assert(!DVCache.count(Callsite)); auto *BaseBC = IRBuilder<>(Callsite).CreateBitCast( Base, Callsite->getType(), suffixed_name_or(Base, ".cast", "")); @@ -2404,7 +2520,7 @@ static bool inlineGetBaseAndOffset(Function &F, case Intrinsic::experimental_gc_get_pointer_offset: { Changed = true; Value *Derived = Callsite->getOperand(0); - Value *Base = findBasePointer(Derived, DVCache); + Value *Base = findBasePointer(Derived, DVCache, KnownBases); assert(!DVCache.count(Callsite)); unsigned AddressSpace = Derived->getType()->getPointerAddressSpace(); unsigned IntPtrSize = DL.getPointerSizeInBits(AddressSpace); @@ -2431,7 +2547,8 @@ static bool inlineGetBaseAndOffset(Function &F, static bool insertParsePoints(Function &F, DominatorTree &DT, TargetTransformInfo &TTI, SmallVectorImpl<CallBase *> &ToUpdate, - DefiningValueMapTy &DVCache) { + DefiningValueMapTy &DVCache, + IsKnownBaseMapTy &KnownBases) { #ifndef NDEBUG // Validate the input std::set<CallBase *> Uniqued; @@ -2487,7 +2604,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, // B) Find the base pointers for each live pointer for (size_t i = 0; i < Records.size(); i++) { PartiallyConstructedSafepointRecord &info = Records[i]; - findBasePointers(DT, DVCache, ToUpdate[i], info, PointerToBase); + findBasePointers(DT, DVCache, ToUpdate[i], info, PointerToBase, KnownBases); } if (PrintBasePointers) { errs() << "Base Pairs (w/o Relocation):\n"; @@ -2563,11 +2680,16 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Holders.clear(); + // Compute the cost of possible re-materialization of derived pointers. 
+ RematCandTy RematerizationCandidates; + findRematerializationCandidates(PointerToBase, RematerizationCandidates, TTI); + // In order to reduce live set of statepoint we might choose to rematerialize // some values instead of relocating them. This is purely an optimization and // does not influence correctness. for (size_t i = 0; i < Records.size(); i++) - rematerializeLiveValues(ToUpdate[i], Records[i], PointerToBase, TTI); + rematerializeLiveValues(ToUpdate[i], Records[i], PointerToBase, + RematerizationCandidates, TTI); // We need this to safely RAUW and delete call or invoke return values that // may themselves be live over a statepoint. For details, please see usage in @@ -2930,13 +3052,18 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, // inlineGetBaseAndOffset() and insertParsePoints(). DefiningValueMapTy DVCache; + // Mapping between a base values and a flag indicating whether it's a known + // base or not. + IsKnownBaseMapTy KnownBases; + if (!Intrinsics.empty()) // Inline @gc.get.pointer.base() and @gc.get.pointer.offset() before finding // live references. - MadeChange |= inlineGetBaseAndOffset(F, Intrinsics, DVCache); + MadeChange |= inlineGetBaseAndOffset(F, Intrinsics, DVCache, KnownBases); if (!ParsePointNeeded.empty()) - MadeChange |= insertParsePoints(F, DT, TTI, ParsePointNeeded, DVCache); + MadeChange |= + insertParsePoints(F, DT, TTI, ParsePointNeeded, DVCache, KnownBases); return MadeChange; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SCCP.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SCCP.cpp index fa1cfc84e4fd..2282ef636076 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SCCP.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SCCP.cpp @@ -17,20 +17,15 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/SCCP.h" -#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueLattice.h" #include "llvm/Analysis/ValueLatticeUtils.h" @@ -38,14 +33,13 @@ #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" -#include "llvm/IR/InstVisitor.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" @@ -59,7 +53,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/PredicateInfo.h" +#include "llvm/Transforms/Utils/SCCPSolver.h" #include <cassert> #include <utility> #include <vector> @@ -97,6 +91,18 @@ static bool isOverdefined(const ValueLatticeElement &LV) { return !LV.isUnknownOrUndef() && !isConstant(LV); } +static bool canRemoveInstruction(Instruction *I) { + if (wouldInstructionBeTriviallyDead(I)) + 
return true; + + // Some instructions can be handled but are rejected above. Catch + // those cases by falling through to here. + // TODO: Mark globals as being constant earlier, so + // TODO: wouldInstructionBeTriviallyDead() knows that atomic loads + // TODO: are safe to remove. + return isa<LoadInst>(I); +} + static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { Constant *Const = nullptr; if (V->getType()->isStructTy()) { @@ -127,7 +133,8 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { // Calls with "clang.arc.attachedcall" implicitly use the return value and // those uses cannot be updated with a constant. CallBase *CB = dyn_cast<CallBase>(V); - if (CB && ((CB->isMustTailCall() && !CB->isSafeToRemove()) || + if (CB && ((CB->isMustTailCall() && + !canRemoveInstruction(CB)) || CB->getOperandBundle(LLVMContext::OB_clang_arc_attachedcall))) { Function *F = CB->getCalledFunction(); @@ -156,7 +163,7 @@ static bool simplifyInstsInBlock(SCCPSolver &Solver, BasicBlock &BB, if (Inst.getType()->isVoidTy()) continue; if (tryToReplaceWithConstant(Solver, &Inst)) { - if (Inst.isSafeToRemove()) + if (canRemoveInstruction(&Inst)) Inst.eraseFromParent(); MadeChanges = true; @@ -170,6 +177,7 @@ static bool simplifyInstsInBlock(SCCPSolver &Solver, BasicBlock &BB, continue; if (IV.getConstantRange().isAllNonNegative()) { auto *ZExt = new ZExtInst(ExtOp, Inst.getType(), "", &Inst); + ZExt->takeName(&Inst); InsertedValues.insert(ZExt); Inst.replaceAllUsesWith(ZExt); Solver.removeLatticeValueFor(&Inst); @@ -182,10 +190,14 @@ static bool simplifyInstsInBlock(SCCPSolver &Solver, BasicBlock &BB, return MadeChanges; } +static bool removeNonFeasibleEdges(const SCCPSolver &Solver, BasicBlock *BB, + DomTreeUpdater &DTU, + BasicBlock *&NewUnreachableBB); + // runSCCP() - Run the Sparse Conditional Constant Propagation algorithm, // and return true if the function was modified. static bool runSCCP(Function &F, const DataLayout &DL, - const TargetLibraryInfo *TLI) { + const TargetLibraryInfo *TLI, DomTreeUpdater &DTU) { LLVM_DEBUG(dbgs() << "SCCP on function '" << F.getName() << "'\n"); SCCPSolver Solver( DL, [TLI](Function &F) -> const TargetLibraryInfo & { return *TLI; }, @@ -213,13 +225,12 @@ static bool runSCCP(Function &F, const DataLayout &DL, // as we cannot modify the CFG of the function. SmallPtrSet<Value *, 32> InsertedValues; + SmallVector<BasicBlock *, 8> BlocksToErase; for (BasicBlock &BB : F) { if (!Solver.isBlockExecutable(&BB)) { LLVM_DEBUG(dbgs() << " BasicBlock Dead:" << BB); - ++NumDeadBlocks; - NumInstRemoved += removeAllNonTerminatorAndEHPadInstructions(&BB).first; - + BlocksToErase.push_back(&BB); MadeChanges = true; continue; } @@ -228,17 +239,32 @@ static bool runSCCP(Function &F, const DataLayout &DL, NumInstRemoved, NumInstReplaced); } + // Remove unreachable blocks and non-feasible edges. 
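// Illustrative sketch (not from this commit): runSCCP() now threads a lazily
// updated DomTreeUpdater through the CFG surgery in this file, so
// dominator-tree updates are queued and only flushed when the tree is
// queried. Assuming a block that has already lost all of its predecessors,
// the pattern looks roughly like:
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;

static void retireDeadBlockSketch(BasicBlock *DeadBB, DominatorTree *DT) {
  // A null DT (analysis not cached) is tolerated: there is simply no tree to
  // keep in sync.
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  // Replace the block body with 'unreachable', informing the updater about
  // the successor edges that disappear.
  changeToUnreachable(DeadBB->getFirstNonPHI(), /*PreserveLCSSA=*/false, &DTU);
  // Block deletion is deferred until the queued updates have been applied.
  if (!DeadBB->hasAddressTaken())
    DTU.deleteBB(DeadBB);
}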
+ for (BasicBlock *DeadBB : BlocksToErase) + NumInstRemoved += changeToUnreachable(DeadBB->getFirstNonPHI(), + /*PreserveLCSSA=*/false, &DTU); + + BasicBlock *NewUnreachableBB = nullptr; + for (BasicBlock &BB : F) + MadeChanges |= removeNonFeasibleEdges(Solver, &BB, DTU, NewUnreachableBB); + + for (BasicBlock *DeadBB : BlocksToErase) + if (!DeadBB->hasAddressTaken()) + DTU.deleteBB(DeadBB); + return MadeChanges; } PreservedAnalyses SCCPPass::run(Function &F, FunctionAnalysisManager &AM) { const DataLayout &DL = F.getParent()->getDataLayout(); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); - if (!runSCCP(F, DL, &TLI)) + auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); + if (!runSCCP(F, DL, &TLI, DTU)) return PreservedAnalyses::all(); auto PA = PreservedAnalyses(); - PA.preserveSet<CFGAnalyses>(); + PA.preserve<DominatorTreeAnalysis>(); return PA; } @@ -261,7 +287,7 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); - AU.setPreservesCFG(); + AU.addPreserved<DominatorTreeWrapperPass>(); } // runOnFunction - Run the Sparse Conditional Constant Propagation @@ -272,7 +298,10 @@ public: const DataLayout &DL = F.getParent()->getDataLayout(); const TargetLibraryInfo *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); - return runSCCP(F, DL, TLI); + auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + DomTreeUpdater DTU(DTWP ? &DTWP->getDomTree() : nullptr, + DomTreeUpdater::UpdateStrategy::Lazy); + return runSCCP(F, DL, TLI, DTU); } }; @@ -363,7 +392,19 @@ static bool removeNonFeasibleEdges(const SCCPSolver &Solver, BasicBlock *BB, isa<IndirectBrInst>(TI)) && "Terminator must be a br, switch or indirectbr"); - if (FeasibleSuccessors.size() == 1) { + if (FeasibleSuccessors.size() == 0) { + // Branch on undef/poison, replace with unreachable. + SmallPtrSet<BasicBlock *, 8> SeenSuccs; + SmallVector<DominatorTree::UpdateType, 8> Updates; + for (BasicBlock *Succ : successors(BB)) { + Succ->removePredecessor(BB); + if (SeenSuccs.insert(Succ).second) + Updates.push_back({DominatorTree::Delete, BB, Succ}); + } + TI->eraseFromParent(); + new UnreachableInst(BB->getContext(), BB); + DTU.applyUpdatesPermissive(Updates); + } else if (FeasibleSuccessors.size() == 1) { // Replace with an unconditional branch to the only feasible successor. 
BasicBlock *OnlyFeasibleSuccessor = *FeasibleSuccessors.begin(); SmallVector<DominatorTree::UpdateType, 8> Updates; @@ -555,7 +596,8 @@ bool llvm::runIPSCCP( MadeChanges |= removeNonFeasibleEdges(Solver, &BB, DTU, NewUnreachableBB); for (BasicBlock *DeadBB : BlocksToErase) - DTU.deleteBB(DeadBB); + if (!DeadBB->hasAddressTaken()) + DTU.deleteBB(DeadBB); for (BasicBlock &BB : F) { for (Instruction &Inst : llvm::make_early_inc_range(BB)) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp index 8be8946702be..143a035749c7 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp @@ -57,11 +57,9 @@ #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" -#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" @@ -78,14 +76,12 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> #include <cassert> -#include <chrono> #include <cstddef> #include <cstdint> #include <cstring> @@ -1016,7 +1012,7 @@ private: I.getParent()->getFirstInsertionPt() == I.getParent()->end()) return PI.setAborted(&I); - // TODO: We could use SimplifyInstruction here to fold PHINodes and + // TODO: We could use simplifyInstruction here to fold PHINodes and // SelectInsts. However, doing so requires to change the current // dead-operand-tracking mechanism. For instance, suppose neither loading // from %U nor %other traps. Then "load (select undef, %U, %other)" does not @@ -1987,13 +1983,22 @@ static bool isIntegerWideningViableForSlice(const Slice &S, uint64_t RelBegin = S.beginOffset() - AllocBeginOffset; uint64_t RelEnd = S.endOffset() - AllocBeginOffset; + Use *U = S.getUse(); + + // Lifetime intrinsics operate over the whole alloca whose sizes are usually + // larger than other load/store slices (RelEnd > Size). But lifetime are + // always promotable and should not impact other slices' promotability of the + // partition. + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U->getUser())) { + if (II->isLifetimeStartOrEnd() || II->isDroppable()) + return true; + } + // We can't reasonably handle cases where the load or store extends past // the end of the alloca's type and into its padding. if (RelEnd > Size) return false; - Use *U = S.getUse(); - if (LoadInst *LI = dyn_cast<LoadInst>(U->getUser())) { if (LI->isVolatile()) return false; @@ -2048,9 +2053,6 @@ static bool isIntegerWideningViableForSlice(const Slice &S, return false; if (!S.isSplittable()) return false; // Skip any unsplittable intrinsics. 
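// Illustrative sketch (not from this commit): the integer-widening check in
// the SROA hunk above now treats lifetime markers and droppable intrinsics as
// always-acceptable users, since they cover the whole alloca and would
// otherwise trip the "slice extends past the partition" test. That check in
// isolation:
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Use.h"
using namespace llvm;

static bool isAlwaysPromotableUseSketch(const Use &U) {
  if (const auto *II = dyn_cast<IntrinsicInst>(U.getUser()))
    return II->isLifetimeStartOrEnd() || II->isDroppable();
  return false;
}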
- } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U->getUser())) { - if (!II->isLifetimeStartOrEnd() && !II->isDroppable()) - return false; } else { return false; } @@ -2179,10 +2181,7 @@ static Value *extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, return V; } - SmallVector<int, 8> Mask; - Mask.reserve(NumElements); - for (unsigned i = BeginIndex; i != EndIndex; ++i) - Mask.push_back(i); + auto Mask = llvm::to_vector<8>(llvm::seq<int>(BeginIndex, EndIndex)); V = IRB.CreateShuffleVector(V, Mask, Name + ".extract"); LLVM_DEBUG(dbgs() << " shuffle: " << *V << "\n"); return V; @@ -2734,10 +2733,9 @@ private: Type *SplatIntTy = Type::getIntNTy(VTy->getContext(), Size * 8); V = IRB.CreateMul( IRB.CreateZExt(V, SplatIntTy, "zext"), - ConstantExpr::getUDiv( - Constant::getAllOnesValue(SplatIntTy), - ConstantExpr::getZExt(Constant::getAllOnesValue(V->getType()), - SplatIntTy)), + IRB.CreateUDiv(Constant::getAllOnesValue(SplatIntTy), + IRB.CreateZExt(Constant::getAllOnesValue(V->getType()), + SplatIntTy)), "isplat"); return V; } @@ -2887,7 +2885,7 @@ private: assert((IsDest && II.getRawDest() == OldPtr) || (!IsDest && II.getRawSource() == OldPtr)); - MaybeAlign SliceAlign = getSliceAlign(); + Align SliceAlign = getSliceAlign(); // For unsplit intrinsics, we simply modify the source and destination // pointers in place. This isn't just an optimization, it is a matter of @@ -3481,19 +3479,13 @@ private: Type *Ty = GEPI.getSourceElementType(); Value *True = Sel->getTrueValue(); - Value *NTrue = - IsInBounds - ? IRB.CreateInBoundsGEP(Ty, True, Index, - True->getName() + ".sroa.gep") - : IRB.CreateGEP(Ty, True, Index, True->getName() + ".sroa.gep"); + Value *NTrue = IRB.CreateGEP(Ty, True, Index, True->getName() + ".sroa.gep", + IsInBounds); Value *False = Sel->getFalseValue(); - Value *NFalse = - IsInBounds - ? IRB.CreateInBoundsGEP(Ty, False, Index, - False->getName() + ".sroa.gep") - : IRB.CreateGEP(Ty, False, Index, False->getName() + ".sroa.gep"); + Value *NFalse = IRB.CreateGEP(Ty, False, Index, + False->getName() + ".sroa.gep", IsInBounds); Value *NSel = IRB.CreateSelect(Sel->getCondition(), NTrue, NFalse, Sel->getName() + ".sroa.sel"); @@ -3547,10 +3539,8 @@ private: IRB.SetInsertPoint(In->getParent(), std::next(In->getIterator())); Type *Ty = GEPI.getSourceElementType(); - NewVal = IsInBounds ? 
IRB.CreateInBoundsGEP(Ty, In, Index, - In->getName() + ".sroa.gep") - : IRB.CreateGEP(Ty, In, Index, - In->getName() + ".sroa.gep"); + NewVal = IRB.CreateGEP(Ty, In, Index, In->getName() + ".sroa.gep", + IsInBounds); } NewPN->addIncoming(NewVal, B); } @@ -3972,16 +3962,15 @@ bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { for (LoadInst *LI : Loads) { SplitLoads.clear(); - IntegerType *Ty = cast<IntegerType>(LI->getType()); - assert(Ty->getBitWidth() % 8 == 0); - uint64_t LoadSize = Ty->getBitWidth() / 8; - assert(LoadSize > 0 && "Cannot have a zero-sized integer load!"); - auto &Offsets = SplitOffsetsMap[LI]; - assert(LoadSize == Offsets.S->endOffset() - Offsets.S->beginOffset() && - "Slice size should always match load size exactly!"); + unsigned SliceSize = Offsets.S->endOffset() - Offsets.S->beginOffset(); + assert(LI->getType()->getIntegerBitWidth() % 8 == 0 && + "Load must have type size equal to store size"); + assert(LI->getType()->getIntegerBitWidth() / 8 >= SliceSize && + "Load must be >= slice size"); + uint64_t BaseOffset = Offsets.S->beginOffset(); - assert(BaseOffset + LoadSize > BaseOffset && + assert(BaseOffset + SliceSize > BaseOffset && "Cannot represent alloca access size using 64-bit integers!"); Instruction *BasePtr = cast<Instruction>(LI->getPointerOperand()); @@ -3992,7 +3981,7 @@ bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { uint64_t PartOffset = 0, PartSize = Offsets.Splits.front(); int Idx = 0, Size = Offsets.Splits.size(); for (;;) { - auto *PartTy = Type::getIntNTy(Ty->getContext(), PartSize * 8); + auto *PartTy = Type::getIntNTy(LI->getContext(), PartSize * 8); auto AS = LI->getPointerAddressSpace(); auto *PartPtrTy = PartTy->getPointerTo(AS); LoadInst *PLoad = IRB.CreateAlignedLoad( @@ -4025,7 +4014,7 @@ bool SROAPass::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { // Setup the next partition. PartOffset = Offsets.Splits[Idx]; ++Idx; - PartSize = (Idx < Size ? Offsets.Splits[Idx] : LoadSize) - PartOffset; + PartSize = (Idx < Size ? 
Offsets.Splits[Idx] : SliceSize) - PartOffset; } // Now that we have the split loads, do the slow walk over all uses of the diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalar.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalar.cpp index f9650efc051f..008ddfc72740 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalar.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalar.cpp @@ -16,16 +16,13 @@ #include "llvm-c/Initialization.h" #include "llvm-c/Transforms/Scalar.h" #include "llvm/Analysis/BasicAliasAnalysis.h" -#include "llvm/Analysis/Passes.h" #include "llvm/Analysis/ScopedNoAliasAA.h" #include "llvm/Analysis/TypeBasedAliasAnalysis.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Verifier.h" #include "llvm/InitializePasses.h" #include "llvm/Transforms/Scalar/GVN.h" #include "llvm/Transforms/Scalar/Scalarizer.h" -#include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" using namespace llvm; @@ -76,7 +73,6 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeLoopRerollLegacyPassPass(Registry); initializeLoopUnrollPass(Registry); initializeLoopUnrollAndJamPass(Registry); - initializeLoopUnswitchPass(Registry); initializeWarnMissedTransformationsLegacyPass(Registry); initializeLoopVersioningLICMLegacyPassPass(Registry); initializeLoopIdiomRecognizeLegacyPassPass(Registry); @@ -104,6 +100,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeSimpleLoopUnswitchLegacyPassPass(Registry); initializeSinkingLegacyPassPass(Registry); initializeTailCallElimPass(Registry); + initializeTLSVariableHoistLegacyPassPass(Registry); initializeSeparateConstOffsetFromGEPLegacyPassPass(Registry); initializeSpeculativeExecutionLegacyPassPass(Registry); initializeStraightLineStrengthReduceLegacyPassPass(Registry); @@ -214,10 +211,6 @@ void LLVMAddLoopUnrollAndJamPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createLoopUnrollAndJamPass()); } -void LLVMAddLoopUnswitchPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLoopUnswitchPass()); -} - void LLVMAddLowerAtomicPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createLowerAtomicPass()); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp index 29cea42e4a00..e2976ace3a4a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp @@ -1,5 +1,5 @@ //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===// -// instrinsics +// intrinsics // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
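// Illustrative sketch (not from this commit): several SROA hunks above replace
// the "IsInBounds ? CreateInBoundsGEP(...) : CreateGEP(...)" branching with
// the IRBuilder overload that takes the inbounds flag directly. The minimal
// shape of that call (names here are placeholders):
#include "llvm/IR/IRBuilder.h"
using namespace llvm;

static Value *emitGEPSketch(IRBuilderBase &B, Type *ElemTy, Value *Ptr,
                            Value *Idx, bool IsInBounds) {
  return B.CreateGEP(ElemTy, Ptr, Idx, Ptr->getName() + ".sketch.gep",
                     IsInBounds);
}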
@@ -24,11 +24,9 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" @@ -36,7 +34,6 @@ #include "llvm/Support/Casting.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include <algorithm> #include <cassert> using namespace llvm; @@ -876,7 +873,7 @@ static bool runImpl(Function &F, const TargetTransformInfo &TTI, for (BasicBlock &BB : llvm::make_early_inc_range(F)) { bool ModifiedDTOnIteration = false; MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL, - DTU.hasValue() ? DTU.getPointer() : nullptr); + DTU ? DTU.getPointer() : nullptr); // Restart BB iteration if the dominator tree of the Function was changed if (ModifiedDTOnIteration) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalarizer.cpp index 3606c8a4b073..08f4b2173da2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -39,8 +39,6 @@ #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/MathExtras.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> @@ -52,7 +50,7 @@ using namespace llvm; #define DEBUG_TYPE "scalarizer" -static cl::opt<bool> ScalarizeVariableInsertExtract( +static cl::opt<bool> ClScalarizeVariableInsertExtract( "scalarize-variable-insert-extract", cl::init(true), cl::Hidden, cl::desc("Allow the scalarizer pass to scalarize " "insertelement/extractelement with variable index")); @@ -60,9 +58,9 @@ static cl::opt<bool> ScalarizeVariableInsertExtract( // This is disabled by default because having separate loads and stores // makes it more likely that the -combiner-alias-analysis limits will be // reached. -static cl::opt<bool> - ScalarizeLoadStore("scalarize-load-store", cl::init(false), cl::Hidden, - cl::desc("Allow the scalarizer pass to scalarize loads and store")); +static cl::opt<bool> ClScalarizeLoadStore( + "scalarize-load-store", cl::init(false), cl::Hidden, + cl::desc("Allow the scalarizer pass to scalarize loads and store")); namespace { @@ -96,7 +94,7 @@ public: // Scatter V into Size components. If new instructions are needed, // insert them before BBI in BB. If Cache is nonnull, use it to cache // the results. - Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, + Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, Type *PtrElemTy, ValueVector *cachePtr = nullptr); // Return component I, creating a new Value for it if necessary. @@ -109,8 +107,8 @@ private: BasicBlock *BB; BasicBlock::iterator BBI; Value *V; + Type *PtrElemTy; ValueVector *CachePtr; - PointerType *PtrTy; ValueVector Tmp; unsigned Size; }; @@ -188,10 +186,23 @@ struct VectorLayout { uint64_t ElemSize = 0; }; +template <typename T> +T getWithDefaultOverride(const cl::opt<T> &ClOption, + const llvm::Optional<T> &DefaultOverride) { + return ClOption.getNumOccurrences() ? 
ClOption + : DefaultOverride.value_or(ClOption); +} + class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> { public: - ScalarizerVisitor(unsigned ParallelLoopAccessMDKind, DominatorTree *DT) - : ParallelLoopAccessMDKind(ParallelLoopAccessMDKind), DT(DT) { + ScalarizerVisitor(unsigned ParallelLoopAccessMDKind, DominatorTree *DT, + ScalarizerPassOptions Options) + : ParallelLoopAccessMDKind(ParallelLoopAccessMDKind), DT(DT), + ScalarizeVariableInsertExtract( + getWithDefaultOverride(ClScalarizeVariableInsertExtract, + Options.ScalarizeVariableInsertExtract)), + ScalarizeLoadStore(getWithDefaultOverride(ClScalarizeLoadStore, + Options.ScalarizeLoadStore)) { } bool visit(Function &F); @@ -216,8 +227,9 @@ public: bool visitCallInst(CallInst &ICI); private: - Scatterer scatter(Instruction *Point, Value *V); + Scatterer scatter(Instruction *Point, Value *V, Type *PtrElemTy = nullptr); void gather(Instruction *Op, const ValueVector &CV); + void replaceUses(Instruction *Op, Value *CV); bool canTransferMetadata(unsigned Kind); void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV); Optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment, @@ -231,12 +243,16 @@ private: ScatterMap Scattered; GatherList Gathered; + bool Scalarized; SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs; unsigned ParallelLoopAccessMDKind; DominatorTree *DT; + + const bool ScalarizeVariableInsertExtract; + const bool ScalarizeLoadStore; }; class ScalarizerLegacyPass : public FunctionPass { @@ -265,12 +281,14 @@ INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer", "Scalarize vector operations", false, false) Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, - ValueVector *cachePtr) - : BB(bb), BBI(bbi), V(v), CachePtr(cachePtr) { + Type *PtrElemTy, ValueVector *cachePtr) + : BB(bb), BBI(bbi), V(v), PtrElemTy(PtrElemTy), CachePtr(cachePtr) { Type *Ty = V->getType(); - PtrTy = dyn_cast<PointerType>(Ty); - if (PtrTy) - Ty = PtrTy->getPointerElementType(); + if (Ty->isPointerTy()) { + assert(cast<PointerType>(Ty)->isOpaqueOrPointeeTypeMatches(PtrElemTy) && + "Pointer element type mismatch"); + Ty = PtrElemTy; + } Size = cast<FixedVectorType>(Ty)->getNumElements(); if (!CachePtr) Tmp.resize(Size, nullptr); @@ -287,15 +305,15 @@ Value *Scatterer::operator[](unsigned I) { if (CV[I]) return CV[I]; IRBuilder<> Builder(BB, BBI); - if (PtrTy) { - Type *ElTy = - cast<VectorType>(PtrTy->getPointerElementType())->getElementType(); + if (PtrElemTy) { + Type *VectorElemTy = cast<VectorType>(PtrElemTy)->getElementType(); if (!CV[0]) { - Type *NewPtrTy = PointerType::get(ElTy, PtrTy->getAddressSpace()); + Type *NewPtrTy = PointerType::get( + VectorElemTy, V->getType()->getPointerAddressSpace()); CV[0] = Builder.CreateBitCast(V, NewPtrTy, V->getName() + ".i0"); } if (I != 0) - CV[I] = Builder.CreateConstGEP1_32(ElTy, CV[0], I, + CV[I] = Builder.CreateConstGEP1_32(VectorElemTy, CV[0], I, V->getName() + ".i" + Twine(I)); } else { // Search through a chain of InsertElementInsts looking for element I. 
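// Illustrative sketch (not from this commit): the Scatterer changes above stop
// reading the pointee type off the pointer type (which opaque pointers no
// longer carry) and instead take the element type from the memory access
// itself, along the lines of:
#include "llvm/IR/Instructions.h"
using namespace llvm;

static Type *getAccessedTypeSketch(Instruction &I) {
  if (auto *LI = dyn_cast<LoadInst>(&I))
    return LI->getType();
  if (auto *SI = dyn_cast<StoreInst>(&I))
    return SI->getValueOperand()->getType();
  return nullptr;
}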
@@ -334,7 +352,7 @@ bool ScalarizerLegacyPass::runOnFunction(Function &F) { unsigned ParallelLoopAccessMDKind = M.getContext().getMDKindID("llvm.mem.parallel_loop_access"); DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT); + ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT, ScalarizerPassOptions()); return Impl.visit(F); } @@ -345,6 +363,8 @@ FunctionPass *llvm::createScalarizerPass() { bool ScalarizerVisitor::visit(Function &F) { assert(Gathered.empty() && Scattered.empty()); + Scalarized = false; + // To ensure we replace gathered components correctly we need to do an ordered // traversal of the basic blocks in the function. ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock()); @@ -362,13 +382,14 @@ bool ScalarizerVisitor::visit(Function &F) { // Return a scattered form of V that can be accessed by Point. V must be a // vector or a pointer to a vector. -Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V) { +Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V, + Type *PtrElemTy) { if (Argument *VArg = dyn_cast<Argument>(V)) { // Put the scattered form of arguments in the entry block, // so that it can be used everywhere. Function *F = VArg->getParent(); BasicBlock *BB = &F->getEntryBlock(); - return Scatterer(BB, BB->begin(), V, &Scattered[V]); + return Scatterer(BB, BB->begin(), V, PtrElemTy, &Scattered[V]); } if (Instruction *VOp = dyn_cast<Instruction>(V)) { // When scalarizing PHI nodes we might try to examine/rewrite InsertElement @@ -379,17 +400,17 @@ Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V) { // need to analyse them further. if (!DT->isReachableFromEntry(VOp->getParent())) return Scatterer(Point->getParent(), Point->getIterator(), - UndefValue::get(V->getType())); + PoisonValue::get(V->getType()), PtrElemTy); // Put the scattered form of an instruction directly after the // instruction, skipping over PHI nodes and debug intrinsics. BasicBlock *BB = VOp->getParent(); return Scatterer( BB, skipPastPhiNodesAndDbg(std::next(BasicBlock::iterator(VOp))), V, - &Scattered[V]); + PtrElemTy, &Scattered[V]); } // In the fallback case, just put the scattered before Point and // keep the result local to Point. - return Scatterer(Point->getParent(), Point->getIterator(), V); + return Scatterer(Point->getParent(), Point->getIterator(), V, PtrElemTy); } // Replace Op with the gathered form of the components in CV. Defer the @@ -419,6 +440,15 @@ void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) { Gathered.push_back(GatherList::value_type(Op, &SV)); } +// Replace Op with CV and collect Op has a potentially dead instruction. +void ScalarizerVisitor::replaceUses(Instruction *Op, Value *CV) { + if (CV != Op) { + Op->replaceAllUsesWith(CV); + PotentiallyDeadInstrs.emplace_back(Op); + Scalarized = true; + } +} + // Return true if it is safe to transfer the given metadata tag from // vector to scalar instructions. 
bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) { @@ -558,9 +588,11 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { if (OpI->getType()->isVectorTy()) { Scattered[I] = scatter(&CI, OpI); assert(Scattered[I].size() == NumElems && "mismatched call operands"); + if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) + Tys.push_back(OpI->getType()->getScalarType()); } else { ScalarOperands[I] = OpI; - if (hasVectorInstrinsicOverloadedScalarOpd(ID, I)) + if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) Tys.push_back(OpI->getType()); } } @@ -576,7 +608,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { ScalarCallOps.clear(); for (unsigned J = 0; J != NumArgs; ++J) { - if (hasVectorInstrinsicScalarOpd(ID, J)) + if (isVectorIntrinsicWithScalarOpAtArg(ID, J)) ScalarCallOps.push_back(ScalarOperands[J]); else ScalarCallOps.push_back(Scattered[J][Elem]); @@ -809,7 +841,7 @@ bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) { Value *Res = Op0[CI->getValue().getZExtValue()]; - gather(&EEI, {Res}); + replaceUses(&EEI, Res); return true; } @@ -825,7 +857,7 @@ bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { Res = Builder.CreateSelect(ShouldExtract, Elt, Res, EEI.getName() + ".upto" + Twine(I)); } - gather(&EEI, {Res}); + replaceUses(&EEI, Res); return true; } @@ -891,7 +923,7 @@ bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) { unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements(); IRBuilder<> Builder(&LI); - Scatterer Ptr = scatter(&LI, LI.getPointerOperand()); + Scatterer Ptr = scatter(&LI, LI.getPointerOperand(), LI.getType()); ValueVector Res; Res.resize(NumElems); @@ -917,7 +949,7 @@ bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) { unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements(); IRBuilder<> Builder(&SI); - Scatterer VPtr = scatter(&SI, SI.getPointerOperand()); + Scatterer VPtr = scatter(&SI, SI.getPointerOperand(), FullValue->getType()); Scatterer VVal = scatter(&SI, FullValue); ValueVector Stores; @@ -940,7 +972,7 @@ bool ScalarizerVisitor::visitCallInst(CallInst &CI) { bool ScalarizerVisitor::finish() { // The presence of data in Gathered or Scattered indicates changes // made to the Function. 
- if (Gathered.empty() && Scattered.empty()) + if (Gathered.empty() && Scattered.empty() && !Scalarized) return false; for (const auto &GMI : Gathered) { Instruction *Op = GMI.first; @@ -971,6 +1003,7 @@ bool ScalarizerVisitor::finish() { } Gathered.clear(); Scattered.clear(); + Scalarized = false; RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); @@ -982,7 +1015,7 @@ PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) unsigned ParallelLoopAccessMDKind = M.getContext().getMDKindID("llvm.mem.parallel_loop_access"); DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F); - ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT); + ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT, Options); bool Changed = Impl.visit(F); PreservedAnalyses PA; PA.preserve<DominatorTreeAnalysis>(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index d23925042b0a..7da5a78772ad 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -189,7 +189,6 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <cassert> diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index a27da047bfd3..0535608244cc 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -19,7 +19,6 @@ #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/GuardUtils.h" -#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" @@ -28,6 +27,7 @@ #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/MustExecute.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" @@ -49,7 +49,9 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/GenericDomTree.h" +#include "llvm/Support/InstructionCost.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" @@ -81,7 +83,6 @@ static cl::opt<bool> EnableNonTrivialUnswitch( static cl::opt<int> UnswitchThreshold("unswitch-threshold", cl::init(50), cl::Hidden, - cl::ZeroOrMore, cl::desc("The cost threshold for unswitching a loop.")); static cl::opt<bool> EnableUnswitchCostMultiplier( @@ -110,17 +111,27 @@ static cl::opt<unsigned> "partial unswitching analysis"), cl::init(100), cl::Hidden); static cl::opt<bool> FreezeLoopUnswitchCond( - "freeze-loop-unswitch-cond", cl::init(false), cl::Hidden, + "freeze-loop-unswitch-cond", cl::init(true), cl::Hidden, cl::desc("If enabled, the freeze instruction will be added to condition " "of loop unswitch to prevent miscompilation.")); +// Helper to skip (select x, true, 
false), which matches both a logical AND and +// OR and can confuse code that tries to determine if \p Cond is either a +// logical AND or OR but not both. +static Value *skipTrivialSelect(Value *Cond) { + Value *CondNext; + while (match(Cond, m_Select(m_Value(CondNext), m_One(), m_Zero()))) + Cond = CondNext; + return Cond; +} + /// Collect all of the loop invariant input values transitively used by the /// homogeneous instruction graph from a given root. /// /// This essentially walks from a root recursively through loop variant operands -/// which have the exact same opcode and finds all inputs which are loop -/// invariant. For some operations these can be re-associated and unswitched out -/// of the loop entirely. +/// which have perform the same logical operation (AND or OR) and finds all +/// inputs which are loop invariant. For some operations these can be +/// re-associated and unswitched out of the loop entirely. static TinyPtrVector<Value *> collectHomogenousInstGraphLoopInvariants(Loop &L, Instruction &Root, LoopInfo &LI) { @@ -150,7 +161,7 @@ collectHomogenousInstGraphLoopInvariants(Loop &L, Instruction &Root, } // If not an instruction with the same opcode, nothing we can do. - Instruction *OpI = dyn_cast<Instruction>(OpV); + Instruction *OpI = dyn_cast<Instruction>(skipTrivialSelect(OpV)); if (OpI && ((IsRootAnd && match(OpI, m_LogicalAnd())) || (IsRootOr && match(OpI, m_LogicalOr())))) { @@ -202,13 +213,19 @@ static bool areLoopExitPHIsLoopInvariant(Loop &L, BasicBlock &ExitingBB, /// branch on a single value. static void buildPartialUnswitchConditionalBranch( BasicBlock &BB, ArrayRef<Value *> Invariants, bool Direction, - BasicBlock &UnswitchedSucc, BasicBlock &NormalSucc, bool InsertFreeze) { + BasicBlock &UnswitchedSucc, BasicBlock &NormalSucc, bool InsertFreeze, + Instruction *I, AssumptionCache *AC, DominatorTree &DT) { IRBuilder<> IRB(&BB); - Value *Cond = Direction ? IRB.CreateOr(Invariants) : - IRB.CreateAnd(Invariants); - if (InsertFreeze) - Cond = IRB.CreateFreeze(Cond, Cond->getName() + ".fr"); + SmallVector<Value *> FrozenInvariants; + for (Value *Inv : Invariants) { + if (InsertFreeze && !isGuaranteedNotToBeUndefOrPoison(Inv, AC, I, &DT)) + Inv = IRB.CreateFreeze(Inv, Inv->getName() + ".fr"); + FrozenInvariants.push_back(Inv); + } + + Value *Cond = Direction ? IRB.CreateOr(FrozenInvariants) + : IRB.CreateAnd(FrozenInvariants); IRB.CreateCondBr(Cond, Direction ? &UnswitchedSucc : &NormalSucc, Direction ? &NormalSucc : &UnswitchedSucc); } @@ -442,11 +459,12 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, // some input conditions to the branch. bool FullUnswitch = false; - if (L.isLoopInvariant(BI.getCondition())) { - Invariants.push_back(BI.getCondition()); + Value *Cond = skipTrivialSelect(BI.getCondition()); + if (L.isLoopInvariant(Cond)) { + Invariants.push_back(Cond); FullUnswitch = true; } else { - if (auto *CondInst = dyn_cast<Instruction>(BI.getCondition())) + if (auto *CondInst = dyn_cast<Instruction>(Cond)) Invariants = collectHomogenousInstGraphLoopInvariants(L, *CondInst, LI); if (Invariants.empty()) { LLVM_DEBUG(dbgs() << " Couldn't find invariant inputs!\n"); @@ -480,8 +498,8 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, // is a graph of `or` operations, or the exit block is along the false edge // and the condition is a graph of `and` operations. if (!FullUnswitch) { - if (ExitDirection ? 
!match(BI.getCondition(), m_LogicalOr()) - : !match(BI.getCondition(), m_LogicalAnd())) { + if (ExitDirection ? !match(Cond, m_LogicalOr()) + : !match(Cond, m_LogicalAnd())) { LLVM_DEBUG(dbgs() << " Branch condition is in improper form for " "non-full unswitch!\n"); return false; @@ -546,6 +564,7 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, // its successors. OldPH->getInstList().splice(OldPH->end(), BI.getParent()->getInstList(), BI); + BI.setCondition(Cond); if (MSSAU) { // Temporarily clone the terminator, to make MSSA update cheaper by // separating "insert edge" updates from "remove edge" ones. @@ -561,15 +580,16 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, // Only unswitching a subset of inputs to the condition, so we will need to // build a new branch that merges the invariant inputs. if (ExitDirection) - assert(match(BI.getCondition(), m_LogicalOr()) && + assert(match(skipTrivialSelect(BI.getCondition()), m_LogicalOr()) && "Must have an `or` of `i1`s or `select i1 X, true, Y`s for the " "condition!"); else - assert(match(BI.getCondition(), m_LogicalAnd()) && + assert(match(skipTrivialSelect(BI.getCondition()), m_LogicalAnd()) && "Must have an `and` of `i1`s or `select i1 X, Y, false`s for the" " condition!"); - buildPartialUnswitchConditionalBranch(*OldPH, Invariants, ExitDirection, - *UnswitchedBB, *NewPH, false); + buildPartialUnswitchConditionalBranch( + *OldPH, Invariants, ExitDirection, *UnswitchedBB, *NewPH, + FreezeLoopUnswitchCond, OldPH->getTerminator(), nullptr, DT); } // Update the dominator tree with the added edge. @@ -1019,7 +1039,8 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, // Don't bother trying to unswitch past an unconditional branch or a branch // with a constant value. These should be removed by simplifycfg prior to // running this pass. - if (!BI->isConditional() || isa<Constant>(BI->getCondition())) + if (!BI->isConditional() || + isa<Constant>(skipTrivialSelect(BI->getCondition()))) return Changed; // Found a trivial condition candidate: non-foldable conditional branch. If @@ -1663,7 +1684,7 @@ deleteDeadBlocksFromLoop(Loop &L, // uses in other blocks. 
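// Illustrative sketch (not from this commit): the reason skipTrivialSelect()
// was introduced above is that `select i1 %c, i1 true, i1 false` matches both
// m_LogicalAnd() and m_LogicalOr(), so code classifying a branch condition as
// one or the other must strip such selects first. The matchers themselves:
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"
using namespace llvm;
using namespace llvm::PatternMatch;

static bool looksLikeLogicalAndOrOrSketch(Value *Cond) {
  return match(Cond, m_CombineOr(m_LogicalAnd(), m_LogicalOr()));
}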
for (auto &I : *BB) if (!I.use_empty()) - I.replaceAllUsesWith(UndefValue::get(I.getType())); + I.replaceAllUsesWith(PoisonValue::get(I.getType())); BB->dropAllReferences(); } @@ -2042,12 +2063,13 @@ static void unswitchNontrivialInvariants( "Can only unswitch switches and conditional branch!"); bool PartiallyInvariant = !PartialIVInfo.InstToDuplicate.empty(); bool FullUnswitch = - SI || (BI->getCondition() == Invariants[0] && !PartiallyInvariant); + SI || (skipTrivialSelect(BI->getCondition()) == Invariants[0] && + !PartiallyInvariant); if (FullUnswitch) assert(Invariants.size() == 1 && "Cannot have other invariants with full unswitching!"); else - assert(isa<Instruction>(BI->getCondition()) && + assert(isa<Instruction>(skipTrivialSelect(BI->getCondition())) && "Partial unswitching requires an instruction as the condition!"); if (MSSAU && VerifyMemorySSA) @@ -2062,14 +2084,14 @@ static void unswitchNontrivialInvariants( bool Direction = true; int ClonedSucc = 0; if (!FullUnswitch) { - Value *Cond = BI->getCondition(); + Value *Cond = skipTrivialSelect(BI->getCondition()); (void)Cond; assert(((match(Cond, m_LogicalAnd()) ^ match(Cond, m_LogicalOr())) || PartiallyInvariant) && "Only `or`, `and`, an `select`, partially invariant instructions " "can combine invariants being unswitched."); - if (!match(BI->getCondition(), m_LogicalOr())) { - if (match(BI->getCondition(), m_LogicalAnd()) || + if (!match(Cond, m_LogicalOr())) { + if (match(Cond, m_LogicalAnd()) || (PartiallyInvariant && !PartialIVInfo.KnownValue->isOneValue())) { Direction = false; ClonedSucc = 1; @@ -2209,11 +2231,12 @@ static void unswitchNontrivialInvariants( BasicBlock *ClonedPH = ClonedPHs.begin()->second; BI->setSuccessor(ClonedSucc, ClonedPH); BI->setSuccessor(1 - ClonedSucc, LoopPH); + Value *Cond = skipTrivialSelect(BI->getCondition()); if (InsertFreeze) { - auto Cond = BI->getCondition(); if (!isGuaranteedNotToBeUndefOrPoison(Cond, &AC, BI, &DT)) - BI->setCondition(new FreezeInst(Cond, Cond->getName() + ".fr", BI)); + Cond = new FreezeInst(Cond, Cond->getName() + ".fr", BI); } + BI->setCondition(Cond); DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); } else { assert(SI && "Must either be a branch or switch!"); @@ -2311,9 +2334,11 @@ static void unswitchNontrivialInvariants( if (PartiallyInvariant) buildPartialInvariantUnswitchConditionalBranch( *SplitBB, Invariants, Direction, *ClonedPH, *LoopPH, L, MSSAU); - else - buildPartialUnswitchConditionalBranch(*SplitBB, Invariants, Direction, - *ClonedPH, *LoopPH, InsertFreeze); + else { + buildPartialUnswitchConditionalBranch( + *SplitBB, Invariants, Direction, *ClonedPH, *LoopPH, + FreezeLoopUnswitchCond, BI, &AC, DT); + } DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); if (MSSAU) { @@ -2745,22 +2770,16 @@ static bool unswitchBestCondition( BI->getSuccessor(0) == BI->getSuccessor(1)) continue; - // If BI's condition is 'select _, true, false', simplify it to confuse - // matchers - Value *Cond = BI->getCondition(), *CondNext; - while (match(Cond, m_Select(m_Value(CondNext), m_One(), m_Zero()))) - Cond = CondNext; - BI->setCondition(Cond); - + Value *Cond = skipTrivialSelect(BI->getCondition()); if (isa<Constant>(Cond)) continue; - if (L.isLoopInvariant(BI->getCondition())) { - UnswitchCandidates.push_back({BI, {BI->getCondition()}}); + if (L.isLoopInvariant(Cond)) { + UnswitchCandidates.push_back({BI, {Cond}}); continue; } - Instruction &CondI = *cast<Instruction>(BI->getCondition()); + Instruction &CondI = *cast<Instruction>(Cond); if 
(match(&CondI, m_CombineOr(m_LogicalAnd(), m_LogicalOr()))) { TinyPtrVector<Value *> Invariants = collectHomogenousInstGraphLoopInvariants(L, CondI, LI); @@ -2785,8 +2804,7 @@ static bool unswitchBestCondition( PartialIVInfo = *Info; PartialIVCondBranch = L.getHeader()->getTerminator(); TinyPtrVector<Value *> ValsToDuplicate; - for (auto *Inst : Info->InstToDuplicate) - ValsToDuplicate.push_back(Inst); + llvm::append_range(ValsToDuplicate, Info->InstToDuplicate); UnswitchCandidates.push_back( {L.getHeader()->getTerminator(), std::move(ValsToDuplicate)}); } @@ -2902,10 +2920,11 @@ static bool unswitchBestCondition( // its cost. if (!FullUnswitch) { auto &BI = cast<BranchInst>(TI); - if (match(BI.getCondition(), m_LogicalAnd())) { + Value *Cond = skipTrivialSelect(BI.getCondition()); + if (match(Cond, m_LogicalAnd())) { if (SuccBB == BI.getSuccessor(1)) continue; - } else if (match(BI.getCondition(), m_LogicalOr())) { + } else if (match(Cond, m_LogicalOr())) { if (SuccBB == BI.getSuccessor(0)) continue; } else if ((PartialIVInfo.KnownValue->isOneValue() && @@ -2947,8 +2966,9 @@ static bool unswitchBestCondition( ArrayRef<Value *> Invariants = TerminatorAndInvariants.second; BranchInst *BI = dyn_cast<BranchInst>(&TI); InstructionCost CandidateCost = ComputeUnswitchedCost( - TI, /*FullUnswitch*/ !BI || (Invariants.size() == 1 && - Invariants[0] == BI->getCondition())); + TI, /*FullUnswitch*/ !BI || + (Invariants.size() == 1 && + Invariants[0] == skipTrivialSelect(BI->getCondition()))); // Calculate cost multiplier which is a tool to limit potentially // exponential behavior of loop-unswitch. if (EnableUnswitchCostMultiplier) { @@ -3131,8 +3151,7 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, AR.MSSA->verifyMemorySSA(); } if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, Trivial, NonTrivial, - UnswitchCB, &AR.SE, - MSSAU.hasValue() ? MSSAU.getPointer() : nullptr, + UnswitchCB, &AR.SE, MSSAU ? 
MSSAU.getPointer() : nullptr, DestroyLoopCB)) return PreservedAnalyses::all(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp index b8972751066d..fb2d812a186d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -31,19 +31,16 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/CFG.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Module.h" #include "llvm/IR/ValueHandle.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/SimplifyCFG.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SimplifyCFGOptions.h" #include <utility> diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Sink.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Sink.cpp index 8600aacdb056..e8fde53005f0 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Sink.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Sink.cpp @@ -15,12 +15,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/CFG.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -48,7 +43,7 @@ static bool isSafeToMove(Instruction *Inst, AliasAnalysis &AA, } if (Inst->isTerminator() || isa<PHINode>(Inst) || Inst->isEHPad() || - Inst->mayThrow()) + Inst->mayThrow() || !Inst->willReturn()) return false; if (auto *Call = dyn_cast<CallBase>(Inst)) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp index 06169a7834f6..9ac4608134c2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp @@ -63,10 +63,10 @@ #include "llvm/Transforms/Scalar/SpeculativeExecution.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" @@ -275,7 +275,7 @@ bool SpeculativeExecutionPass::considerHoistingFromTo( }); } - // Usially debug label instrinsic corresponds to label in LLVM IR. In these + // Usially debug label intrinsic corresponds to label in LLVM IR. In these // cases we should not move it here. // TODO: Possible special processing needed to detect it is related to a // hoisted instruction. 
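// Illustrative sketch (not from this commit): the Sink.cpp change above adds a
// willReturn() requirement, so an instruction that might never return (for
// example a call that can loop forever) is no longer considered safe to sink
// past code that would otherwise execute. The shape of that guard, inverted
// into a positive predicate:
#include "llvm/IR/Instructions.h"
using namespace llvm;

static bool passesBasicSinkChecksSketch(const Instruction *I) {
  return !I->isTerminator() && !isa<PHINode>(I) && !I->isEHPad() &&
         !I->mayThrow() && I->willReturn();
}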
@@ -301,7 +301,7 @@ bool SpeculativeExecutionPass::considerHoistingFromTo( if (TotalSpeculationCost > SpecExecMaxSpeculationCost) return false; // too much to hoist } else { - // Debug info instrinsics should not be counted for threshold. + // Debug info intrinsics should not be counted for threshold. if (!isa<DbgInfoIntrinsic>(I)) NotHoistedInstCount++; if (NotHoistedInstCount > SpecExecMaxNotHoisted) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp index b47378808216..70df0cec0dca 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -68,7 +68,6 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -683,24 +682,16 @@ void StraightLineStrengthReduce::rewriteCandidateWithBasis( unsigned AS = Basis.Ins->getType()->getPointerAddressSpace(); Type *CharTy = Type::getInt8PtrTy(Basis.Ins->getContext(), AS); Reduced = Builder.CreateBitCast(Basis.Ins, CharTy); - if (InBounds) - Reduced = - Builder.CreateInBoundsGEP(Builder.getInt8Ty(), Reduced, Bump); - else - Reduced = Builder.CreateGEP(Builder.getInt8Ty(), Reduced, Bump); + Reduced = + Builder.CreateGEP(Builder.getInt8Ty(), Reduced, Bump, "", InBounds); Reduced = Builder.CreateBitCast(Reduced, C.Ins->getType()); } else { // C = gep Basis, Bump // Canonicalize bump to pointer size. Bump = Builder.CreateSExtOrTrunc(Bump, IntPtrTy); - if (InBounds) - Reduced = Builder.CreateInBoundsGEP( - cast<GetElementPtrInst>(Basis.Ins)->getResultElementType(), - Basis.Ins, Bump); - else - Reduced = Builder.CreateGEP( - cast<GetElementPtrInst>(Basis.Ins)->getResultElementType(), - Basis.Ins, Bump); + Reduced = Builder.CreateGEP( + cast<GetElementPtrInst>(Basis.Ins)->getResultElementType(), + Basis.Ins, Bump, "", InBounds); } break; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp index b3a445368537..f6525ad7de9b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -18,10 +18,8 @@ #include "llvm/Analysis/RegionInfo.h" #include "llvm/Analysis/RegionIterator.h" #include "llvm/Analysis/RegionPass.h" -#include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" -#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" @@ -33,7 +31,6 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" -#include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/InitializePasses.h" @@ -41,7 +38,6 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils.h" @@ -72,6 +68,11 @@ static cl::opt<bool> cl::desc("Allow relaxed uniform region checks"), cl::init(true)); +static cl::opt<unsigned> + ReorderNodeSize("structurizecfg-node-reorder-size", + cl::desc("Limit region size for reordering nodes"), + 
cl::init(100), cl::Hidden); + // Definition of the complex types used in this pass. using BBValuePair = std::pair<BasicBlock *, Value *>; @@ -266,6 +267,8 @@ class StructurizeCFG { void orderNodes(); + void reorderNodes(); + void analyzeLoops(RegionNode *N); Value *buildCondition(BranchInst *Term, unsigned Idx, bool Invert); @@ -424,6 +427,57 @@ void StructurizeCFG::orderNodes() { } } +/// Change the node ordering to decrease the range of live values, especially +/// the values that capture the control flow path for branches. We do this +/// by moving blocks with a single predecessor and successor to appear after +/// predecessor. The motivation is to move some loop exit blocks into a loop. +/// In cases where a loop has a large number of exit blocks, this reduces the +/// amount of values needed across the loop boundary. +void StructurizeCFG::reorderNodes() { + SmallVector<RegionNode *, 8> NewOrder; + DenseMap<BasicBlock *, unsigned> MoveTo; + BitVector Moved(Order.size()); + + // The benefits of reordering nodes occurs for large regions. + if (Order.size() <= ReorderNodeSize) + return; + + // The algorithm works with two passes over Order. The first pass identifies + // the blocks to move and the position to move them to. The second pass + // creates the new order based upon this information. We move blocks with + // a single predecessor and successor. If there are multiple candidates then + // maintain the original order. + BBSet Seen; + for (int I = Order.size() - 1; I >= 0; --I) { + auto *BB = Order[I]->getEntry(); + Seen.insert(BB); + auto *Pred = BB->getSinglePredecessor(); + auto *Succ = BB->getSingleSuccessor(); + // Consider only those basic blocks that have a predecessor in Order and a + // successor that exits the region. The region may contain subregions that + // have been structurized and are not included in Order. + if (Pred && Succ && Seen.count(Pred) && Succ == ParentRegion->getExit() && + !MoveTo.count(Pred)) { + MoveTo[Pred] = I; + Moved.set(I); + } + } + + // If no blocks have been moved then the original order is good. 
+ if (!Moved.count()) + return; + + for (size_t I = 0, E = Order.size(); I < E; ++I) { + auto *BB = Order[I]->getEntry(); + if (MoveTo.count(BB)) + NewOrder.push_back(Order[MoveTo[BB]]); + if (!Moved[I]) + NewOrder.push_back(Order[I]); + } + + Order.assign(NewOrder); +} + /// Determine the end of the loops void StructurizeCFG::analyzeLoops(RegionNode *N) { if (N->isSubRegion()) { @@ -685,7 +739,7 @@ void StructurizeCFG::simplifyAffectedPhis() { Q.DT = DT; for (WeakVH VH : AffectedPhis) { if (auto Phi = dyn_cast_or_null<PHINode>(VH)) { - if (auto NewValue = SimplifyInstruction(Phi, Q)) { + if (auto NewValue = simplifyInstruction(Phi, Q)) { Phi->replaceAllUsesWith(NewValue); Phi->eraseFromParent(); Changed = true; @@ -1085,12 +1139,13 @@ bool StructurizeCFG::run(Region *R, DominatorTree *DT) { ParentRegion = R; orderNodes(); + reorderNodes(); collectInfos(); createFlow(); insertConditions(false); insertConditions(true); - simplifyConditions(); setPhiValues(); + simplifyConditions(); simplifyAffectedPhis(); rebuildSSA(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp new file mode 100644 index 000000000000..16b3483f9687 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/TLSVariableHoist.cpp @@ -0,0 +1,306 @@ +//===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass identifies/eliminate Redundant TLS Loads if related option is set. +// The example: Please refer to the comment at the head of TLSVariableHoist.h. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/TLSVariableHoist.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <iterator> +#include <tuple> +#include <utility> + +using namespace llvm; +using namespace tlshoist; + +#define DEBUG_TYPE "tlshoist" + +static cl::opt<bool> TLSLoadHoist( + "tls-load-hoist", cl::init(false), cl::Hidden, + cl::desc("hoist the TLS loads in PIC model to eliminate redundant " + "TLS address calculation.")); + +namespace { + +/// The TLS Variable hoist pass. 
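A source-level illustration of the redundancy described by the option text above; this is an assumed example, not code from the patch. Under the PIC model every use of the thread-local variable inside the loop can expand into its own TLS address computation (for example a __tls_get_addr call), and the pass rewrites all uses to go through one hoisted cast of the global so the address is materialized a single time:

extern thread_local int Counter;

int sumWithCounter(const int *A, int N) {
  int S = 0;
  for (int I = 0; I < N; ++I)
    S += A[I] + Counter; // without hoisting: a TLS address calculation per use
  return S;
}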
+class TLSVariableHoistLegacyPass : public FunctionPass { +public: + static char ID; // Pass identification, replacement for typeid + + TLSVariableHoistLegacyPass() : FunctionPass(ID) { + initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &Fn) override; + + StringRef getPassName() const override { return "TLS Variable Hoist"; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + } + +private: + TLSVariableHoistPass Impl; +}; + +} // end anonymous namespace + +char TLSVariableHoistLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist", + "TLS Variable Hoist", false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist", + "TLS Variable Hoist", false, false) + +FunctionPass *llvm::createTLSVariableHoistPass() { + return new TLSVariableHoistLegacyPass(); +} + +/// Perform the TLS Variable Hoist optimization for the given function. +bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) { + if (skipFunction(Fn)) + return false; + + LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n"); + LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n'); + + bool MadeChange = + Impl.runImpl(Fn, getAnalysis<DominatorTreeWrapperPass>().getDomTree(), + getAnalysis<LoopInfoWrapperPass>().getLoopInfo()); + + if (MadeChange) { + LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: " + << Fn.getName() << '\n'); + LLVM_DEBUG(dbgs() << Fn); + } + LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n"); + + return MadeChange; +} + +void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) { + // Skip all cast instructions. They are visited indirectly later on. + if (Inst->isCast()) + return; + + // Scan all operands. + for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) { + auto *GV = dyn_cast<GlobalVariable>(Inst->getOperand(Idx)); + if (!GV || !GV->isThreadLocal()) + continue; + + // Add Candidate to TLSCandMap (GV --> Candidate). + TLSCandMap[GV].addUser(Inst, Idx); + } +} + +void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) { + // First, quickly check if there is TLS Variable. + Module *M = Fn.getParent(); + + bool HasTLS = llvm::any_of( + M->globals(), [](GlobalVariable &GV) { return GV.isThreadLocal(); }); + + // If non, directly return. + if (!HasTLS) + return; + + TLSCandMap.clear(); + + // Then, collect TLS Variable info. + for (BasicBlock &BB : Fn) { + // Ignore unreachable basic blocks. + if (!DT->isReachableFromEntry(&BB)) + continue; + + for (Instruction &Inst : BB) + collectTLSCandidate(&Inst); + } +} + +static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) { + if (Cand.Users.size() != 1) + return false; + + BasicBlock *BB = Cand.Users[0].Inst->getParent(); + if (LI->getLoopFor(BB)) + return false; + + return true; +} + +Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB, + Loop *L) { + assert(L && "Unexcepted Loop status!"); + + // Get the outermost loop. + while (Loop *Parent = L->getParentLoop()) + L = Parent; + + BasicBlock *PreHeader = L->getLoopPreheader(); + + // There is unique predecessor outside the loop. 
+ if (PreHeader) + return PreHeader->getTerminator(); + + BasicBlock *Header = L->getHeader(); + BasicBlock *Dom = Header; + for (BasicBlock *PredBB : predecessors(Header)) + Dom = DT->findNearestCommonDominator(Dom, PredBB); + + assert(Dom && "Not find dominator BB!"); + Instruction *Term = Dom->getTerminator(); + + return Term; +} + +Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1, + Instruction *I2) { + if (!I1) + return I2; + if (DT->dominates(I1, I2)) + return I1; + if (DT->dominates(I2, I1)) + return I2; + + // If there is no dominance relation, use common dominator. + BasicBlock *DomBB = + DT->findNearestCommonDominator(I1->getParent(), I2->getParent()); + + Instruction *Dom = DomBB->getTerminator(); + assert(Dom && "Common dominator not found!"); + + return Dom; +} + +BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn, + GlobalVariable *GV, + BasicBlock *&PosBB) { + tlshoist::TLSCandidate &Cand = TLSCandMap[GV]; + + // We should hoist the TLS use out of loop, so choose its nearest instruction + // which dominate the loop and the outside loops (if exist). + Instruction *LastPos = nullptr; + for (auto &User : Cand.Users) { + BasicBlock *BB = User.Inst->getParent(); + Instruction *Pos = User.Inst; + if (Loop *L = LI->getLoopFor(BB)) { + Pos = getNearestLoopDomInst(BB, L); + assert(Pos && "Not find insert position out of loop!"); + } + Pos = getDomInst(LastPos, Pos); + LastPos = Pos; + } + + assert(LastPos && "Unexpected insert position!"); + BasicBlock *Parent = LastPos->getParent(); + PosBB = Parent; + return LastPos->getIterator(); +} + +// Generate a bitcast (no type change) to replace the uses of TLS Candidate. +Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn, + GlobalVariable *GV) { + BasicBlock *PosBB = &Fn.getEntryBlock(); + BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB); + Type *Ty = GV->getType(); + auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast"); + PosBB->getInstList().insert(Iter, CastInst); + return CastInst; +} + +bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn, + GlobalVariable *GV) { + + tlshoist::TLSCandidate &Cand = TLSCandMap[GV]; + + // If only used 1 time and not in loops, we no need to replace it. + if (oneUseOutsideLoop(Cand, LI)) + return false; + + // Generate a bitcast (no type change) + auto *CastInst = genBitCastInst(Fn, GV); + + // to replace the uses of TLS Candidate + for (auto &User : Cand.Users) + User.Inst->setOperand(User.OpndIdx, CastInst); + + return true; +} + +bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) { + if (TLSCandMap.empty()) + return false; + + bool Replaced = false; + for (auto &GV2Cand : TLSCandMap) { + GlobalVariable *GV = GV2Cand.first; + Replaced |= tryReplaceTLSCandidate(Fn, GV); + } + + return Replaced; +} + +/// Optimize expensive TLS variables in the given function. +bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT, + LoopInfo &LI) { + if (Fn.hasOptNone()) + return false; + + if (!TLSLoadHoist && !Fn.getAttributes().hasFnAttr("tls-load-hoist")) + return false; + + this->LI = &LI; + this->DT = &DT; + assert(this->LI && this->DT && "Unexcepted requirement!"); + + // Collect all TLS variable candidates. 
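An assumed sketch of the per-function opt-in that runImpl() above checks for: besides the global -tls-load-hoist flag, a front end or an earlier pass can request the transformation by attaching the string function attribute the pass looks for.

#include "llvm/IR/Function.h"

void requestTLSLoadHoist(llvm::Function &F) {
  // Matches the hasFnAttr("tls-load-hoist") test in runImpl().
  F.addFnAttr("tls-load-hoist");
}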
+ collectTLSCandidates(Fn); + + bool MadeChange = tryReplaceTLSCandidates(Fn); + + return MadeChange; +} + +PreservedAnalyses TLSVariableHoistPass::run(Function &F, + FunctionAnalysisManager &AM) { + + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + + if (!runImpl(F, DT, LI)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp index 3bcf92e28a21..27c04177e894 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -53,11 +53,8 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/CFG.h" -#include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" @@ -76,14 +73,12 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" -#include "llvm/IR/ValueHandle.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" using namespace llvm; #define DEBUG_TYPE "tailcallelim" @@ -248,10 +243,10 @@ static bool markTails(Function &F, OptimizationRemarkEmitter *ORE) { isa<PseudoProbeInst>(&I)) continue; - // Special-case operand bundle "clang.arc.attachedcall". + // Special-case operand bundles "clang.arc.attachedcall" and "ptrauth". bool IsNoTail = CI->isNoTailCall() || CI->hasOperandBundlesOtherThan( - LLVMContext::OB_clang_arc_attachedcall); + {LLVMContext::OB_clang_arc_attachedcall, LLVMContext::OB_ptrauth}); if (!IsNoTail && CI->doesNotAccessMemory()) { // A call to a readnone function whose arguments are all things computed @@ -531,7 +526,7 @@ void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) { } // If the function doen't return void, create the RetPN and RetKnownPN PHI - // nodes to track our return value. We initialize RetPN with undef and + // nodes to track our return value. We initialize RetPN with poison and // RetKnownPN with false since we can't know our return value at function // entry. Type *RetType = F.getReturnType(); @@ -540,7 +535,7 @@ void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) { RetPN = PHINode::Create(RetType, 2, "ret.tr", InsertPos); RetKnownPN = PHINode::Create(BoolType, 2, "ret.known.tr", InsertPos); - RetPN->addIncoming(UndefValue::get(RetType), NewEntry); + RetPN->addIncoming(PoisonValue::get(RetType), NewEntry); RetKnownPN->addIncoming(ConstantInt::getFalse(BoolType), NewEntry); } @@ -734,7 +729,7 @@ void TailRecursionEliminator::cleanupAndFinalize() { // call. for (PHINode *PN : ArgumentPHIs) { // If the PHI Node is a dynamic constant, replace it with the value it is. 
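A minimal sketch of the placeholder change in the TailRecursionElimination.cpp hunk above: the return-value accumulator PHI is now seeded with poison rather than undef on the entry edge (the function and variable names here are illustrative):

#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"

llvm::PHINode *createRetAccumulator(llvm::Type *RetTy, llvm::BasicBlock *Entry,
                                    llvm::Instruction *InsertPos) {
  auto *PN = llvm::PHINode::Create(RetTy, /*NumReservedValues=*/2, "ret.tr",
                                   InsertPos);
  // The value is unknown at function entry; poison is the least constraining
  // placeholder for that edge.
  PN->addIncoming(llvm::PoisonValue::get(RetTy), Entry);
  return PN;
}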
- if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) { + if (Value *PNV = simplifyInstruction(PN, F.getParent()->getDataLayout())) { PN->replaceAllUsesWith(PNV); PN->eraseFromParent(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp index 80a7d3a43ad6..8367e61c1a47 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp @@ -61,7 +61,7 @@ static void warnAboutLeftoverTransformations(Loop *L, << "loop not vectorized: the optimizer was unable to perform the " "requested transformation; the transformation might be disabled " "or specified as part of an unsupported transformation ordering"); - else if (InterleaveCount.getValueOr(0) != 1) + else if (InterleaveCount.value_or(0) != 1) ORE->emit( DiagnosticInfoOptimizationFailure(DEBUG_TYPE, "FailedRequestedInterleaving", diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp index c734611836eb..24972db404be 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp @@ -50,9 +50,6 @@ static Value *callPrintfBegin(IRBuilder<> &Builder, Value *Version) { auto Int64Ty = Builder.getInt64Ty(); auto M = Builder.GetInsertBlock()->getModule(); auto Fn = M->getOrInsertFunction("__ockl_printf_begin", Int64Ty, Int64Ty); - if (!M->getModuleFlag("amdgpu_hostcall")) { - M->addModuleFlag(llvm::Module::Override, "amdgpu_hostcall", 1); - } return Builder.CreateCall(Fn, Version); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/ASanStackFrameLayout.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/ASanStackFrameLayout.cpp index cbc508bb863a..0318429a76a7 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/ASanStackFrameLayout.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/ASanStackFrameLayout.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/ASanStackFrameLayout.h" #include "llvm/ADT/SmallString.h" -#include "llvm/IR/DebugInfo.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/raw_ostream.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/AddDiscriminators.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/AddDiscriminators.cpp index e789194eb3ab..e6372fc5ab86 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/AddDiscriminators.cpp @@ -222,7 +222,7 @@ static bool addDiscriminators(Function &F) { << DIL->getColumn() << ":" << Discriminator << " " << I << "\n"); } else { - I.setDebugLoc(NewDIL.getValue()); + I.setDebugLoc(*NewDIL); LLVM_DEBUG(dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":" << DIL->getColumn() << ":" << Discriminator << " " << I << "\n"); @@ -260,7 +260,7 @@ static bool addDiscriminators(Function &F) { << CurrentDIL->getLine() << ":" << CurrentDIL->getColumn() << ":" << Discriminator << " " << I << "\n"); } else { - I.setDebugLoc(NewDIL.getValue()); + I.setDebugLoc(*NewDIL); Changed = true; } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp 
b/contrib/llvm-project/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp index f910f7c3c31f..02ea17825c2f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp @@ -18,6 +18,7 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/DebugCounter.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp index 15c4a64eb794..e9983ff82176 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -21,7 +21,6 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/MemorySSAUpdater.h" -#include "llvm/Analysis/PostDominators.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" @@ -33,7 +32,6 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/IR/PseudoProbe.h" #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" @@ -1164,7 +1162,11 @@ SplitBlockPredecessorsImpl(BasicBlock *BB, ArrayRef<BasicBlock *> Preds, if (NewLatch != OldLatch) { MDNode *MD = OldLatch->getTerminator()->getMetadata("llvm.loop"); NewLatch->getTerminator()->setMetadata("llvm.loop", MD); - OldLatch->getTerminator()->setMetadata("llvm.loop", nullptr); + // It's still possible that OldLatch is the latch of another inner loop, + // in which case we do not remove the metadata. + Loop *IL = LI->getLoopFor(OldLatch); + if (IL && IL->getLoopLatch() != OldLatch) + OldLatch->getTerminator()->setMetadata("llvm.loop", nullptr); } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp index 1bb80be8ef99..0b36e8708a03 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -27,9 +27,7 @@ #include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/Type.h" #include "llvm/InitializePasses.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -317,18 +315,11 @@ llvm::SplitKnownCriticalEdge(Instruction *TI, unsigned SuccNum, // predecessors of BB. static BasicBlock * findIBRPredecessor(BasicBlock *BB, SmallVectorImpl<BasicBlock *> &OtherPreds) { - // If the block doesn't have any PHIs, we don't care about it, since there's - // no point in splitting it. - PHINode *PN = dyn_cast<PHINode>(BB->begin()); - if (!PN) - return nullptr; - // Verify we have exactly one IBR predecessor. // Conservatively bail out if one of the other predecessors is not a "regular" // terminator (that is, not a switch or a br). 
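A hedged standalone sketch of the BasicBlockUtils.cpp fix above: the "llvm.loop" annotation is moved to the new latch, but it is only removed from the old latch when that block is not also the latch of an inner loop (the helper name is illustrative):

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Metadata.h"

void moveLoopAnnotation(llvm::BasicBlock *OldLatch, llvm::BasicBlock *NewLatch,
                        llvm::LoopInfo &LI) {
  llvm::MDNode *MD = OldLatch->getTerminator()->getMetadata("llvm.loop");
  NewLatch->getTerminator()->setMetadata("llvm.loop", MD);
  // Keep the annotation if OldLatch still serves as the latch of its loop.
  if (llvm::Loop *IL = LI.getLoopFor(OldLatch))
    if (IL->getLoopLatch() != OldLatch)
      OldLatch->getTerminator()->setMetadata("llvm.loop", nullptr);
}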
BasicBlock *IBB = nullptr; - for (unsigned Pred = 0, E = PN->getNumIncomingValues(); Pred != E; ++Pred) { - BasicBlock *PredBB = PN->getIncomingBlock(Pred); + for (BasicBlock *PredBB : predecessors(BB)) { Instruction *PredTerm = PredBB->getTerminator(); switch (PredTerm->getOpcode()) { case Instruction::IndirectBr: @@ -349,6 +340,7 @@ findIBRPredecessor(BasicBlock *BB, SmallVectorImpl<BasicBlock *> &OtherPreds) { } bool llvm::SplitIndirectBrCriticalEdges(Function &F, + bool IgnoreBlocksWithoutPHI, BranchProbabilityInfo *BPI, BlockFrequencyInfo *BFI) { // Check whether the function has any indirectbrs, and collect which blocks @@ -370,6 +362,9 @@ bool llvm::SplitIndirectBrCriticalEdges(Function &F, bool ShouldUpdateAnalysis = BPI && BFI; bool Changed = false; for (BasicBlock *Target : Targets) { + if (IgnoreBlocksWithoutPHI && Target->phis().empty()) + continue; + SmallVector<BasicBlock *, 16> OtherPreds; BasicBlock *IBRPred = findIBRPredecessor(Target, OtherPreds); // If we did not found an indirectbr, or the indirectbr is the only diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/BuildLibCalls.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/BuildLibCalls.cpp index 97f11ca71726..c4a58f36c171 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/BuildLibCalls.cpp @@ -13,16 +13,17 @@ #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/CallingConv.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" -#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Support/TypeSize.h" using namespace llvm; @@ -41,7 +42,6 @@ STATISTIC(NumInaccessibleMemOrArgMemOnly, STATISTIC(NumNoUnwind, "Number of functions inferred as nounwind"); STATISTIC(NumNoCapture, "Number of arguments inferred as nocapture"); STATISTIC(NumWriteOnlyArg, "Number of arguments inferred as writeonly"); -STATISTIC(NumSExtArg, "Number of arguments inferred as signext"); STATISTIC(NumReadOnlyArg, "Number of arguments inferred as readonly"); STATISTIC(NumNoAlias, "Number of function returns inferred as noalias"); STATISTIC(NumNoUndef, "Number of function returns inferred as noundef returns"); @@ -149,14 +149,6 @@ static bool setOnlyWritesMemory(Function &F, unsigned ArgNo) { return true; } -static bool setSignExtendedArg(Function &F, unsigned ArgNo) { - if (F.hasParamAttribute(ArgNo, Attribute::SExt)) - return false; - F.addParamAttr(ArgNo, Attribute::SExt); - ++NumSExtArg; - return true; -} - static bool setRetNoUndef(Function &F) { if (!F.getReturnType()->isVoidTy() && !F.hasRetAttribute(Attribute::NoUndef)) { @@ -224,15 +216,54 @@ static bool setWillReturn(Function &F) { return true; } -bool llvm::inferLibFuncAttributes(Module *M, StringRef Name, - const TargetLibraryInfo &TLI) { +static bool setAlignedAllocParam(Function &F, unsigned ArgNo) { + if (F.hasParamAttribute(ArgNo, Attribute::AllocAlign)) + return false; + F.addParamAttr(ArgNo, Attribute::AllocAlign); + return true; +} + +static bool setAllocatedPointerParam(Function &F, unsigned ArgNo) { + if (F.hasParamAttribute(ArgNo, Attribute::AllocatedPointer)) + return false; + F.addParamAttr(ArgNo, 
Attribute::AllocatedPointer); + return true; +} + +static bool setAllocSize(Function &F, unsigned ElemSizeArg, + Optional<unsigned> NumElemsArg) { + if (F.hasFnAttribute(Attribute::AllocSize)) + return false; + F.addFnAttr(Attribute::getWithAllocSizeArgs(F.getContext(), ElemSizeArg, + NumElemsArg)); + return true; +} + +static bool setAllocFamily(Function &F, StringRef Family) { + if (F.hasFnAttribute("alloc-family")) + return false; + F.addFnAttr("alloc-family", Family); + return true; +} + +static bool setAllocKind(Function &F, AllocFnKind K) { + if (F.hasFnAttribute(Attribute::AllocKind)) + return false; + F.addFnAttr( + Attribute::get(F.getContext(), Attribute::AllocKind, uint64_t(K))); + return true; +} + +bool llvm::inferNonMandatoryLibFuncAttrs(Module *M, StringRef Name, + const TargetLibraryInfo &TLI) { Function *F = M->getFunction(Name); if (!F) return false; - return inferLibFuncAttributes(*F, TLI); + return inferNonMandatoryLibFuncAttrs(*F, TLI); } -bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { +bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, + const TargetLibraryInfo &TLI) { LibFunc TheLibFunc; if (!(TLI.getLibFunc(F, TheLibFunc) && TLI.has(TheLibFunc))) return false; @@ -360,6 +391,7 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { Changed |= setArgNoUndef(F, 1); LLVM_FALLTHROUGH; case LibFunc_strdup: + Changed |= setAllocFamily(F, "malloc"); Changed |= setOnlyAccessesInaccessibleMemOrArgMem(F); Changed |= setDoesNotThrow(F); Changed |= setRetDoesNotAlias(F); @@ -416,9 +448,17 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { Changed |= setOnlyReadsMemory(F, 0); return Changed; case LibFunc_aligned_alloc: + Changed |= setAlignedAllocParam(F, 0); + Changed |= setAllocSize(F, 1, None); + Changed |= setAllocKind(F, AllocFnKind::Alloc | AllocFnKind::Uninitialized | AllocFnKind::Aligned); + LLVM_FALLTHROUGH; case LibFunc_valloc: case LibFunc_malloc: case LibFunc_vec_malloc: + Changed |= setAllocFamily(F, TheLibFunc == LibFunc_vec_malloc ? "vec_malloc" + : "malloc"); + Changed |= setAllocKind(F, AllocFnKind::Alloc | AllocFnKind::Uninitialized); + Changed |= setAllocSize(F, 0, None); Changed |= setOnlyAccessesInaccessibleMemory(F); Changed |= setRetAndArgsNoUndef(F); Changed |= setDoesNotThrow(F); @@ -481,6 +521,11 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { Changed |= setOnlyReadsMemory(F, 1); return Changed; case LibFunc_memalign: + Changed |= setAllocFamily(F, "malloc"); + Changed |= setAllocKind(F, AllocFnKind::Alloc | AllocFnKind::Aligned | + AllocFnKind::Uninitialized); + Changed |= setAllocSize(F, 1, None); + Changed |= setAlignedAllocParam(F, 0); Changed |= setOnlyAccessesInaccessibleMemory(F); Changed |= setRetNoUndef(F); Changed |= setDoesNotThrow(F); @@ -500,8 +545,13 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { Changed |= setDoesNotCapture(F, 0); return Changed; case LibFunc_realloc: - case LibFunc_vec_realloc: case LibFunc_reallocf: + case LibFunc_vec_realloc: + Changed |= setAllocFamily( + F, TheLibFunc == LibFunc_vec_realloc ? 
"vec_malloc" : "malloc"); + Changed |= setAllocKind(F, AllocFnKind::Realloc); + Changed |= setAllocatedPointerParam(F, 0); + Changed |= setAllocSize(F, 1, None); Changed |= setOnlyAccessesInaccessibleMemOrArgMem(F); Changed |= setRetNoUndef(F); Changed |= setDoesNotThrow(F); @@ -575,6 +625,10 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { return Changed; case LibFunc_calloc: case LibFunc_vec_calloc: + Changed |= setAllocFamily(F, TheLibFunc == LibFunc_vec_calloc ? "vec_malloc" + : "malloc"); + Changed |= setAllocKind(F, AllocFnKind::Alloc | AllocFnKind::Zeroed); + Changed |= setAllocSize(F, 0, 1); Changed |= setOnlyAccessesInaccessibleMemory(F); Changed |= setRetAndArgsNoUndef(F); Changed |= setDoesNotThrow(F); @@ -633,6 +687,10 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { return Changed; case LibFunc_free: case LibFunc_vec_free: + Changed |= setAllocFamily(F, TheLibFunc == LibFunc_vec_free ? "vec_malloc" + : "malloc"); + Changed |= setAllocKind(F, AllocFnKind::Free); + Changed |= setAllocatedPointerParam(F, 0); Changed |= setOnlyAccessesInaccessibleMemOrArgMem(F); Changed |= setArgsNoUndef(F); Changed |= setDoesNotThrow(F); @@ -1041,7 +1099,6 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { case LibFunc_ldexp: case LibFunc_ldexpf: case LibFunc_ldexpl: - Changed |= setSignExtendedArg(F, 1); Changed |= setWillReturn(F); return Changed; case LibFunc_abs: @@ -1178,34 +1235,179 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { } } -bool llvm::hasFloatFn(const TargetLibraryInfo *TLI, Type *Ty, +static void setArgExtAttr(Function &F, unsigned ArgNo, + const TargetLibraryInfo &TLI, bool Signed = true) { + Attribute::AttrKind ExtAttr = TLI.getExtAttrForI32Param(Signed); + if (ExtAttr != Attribute::None && !F.hasParamAttribute(ArgNo, ExtAttr)) + F.addParamAttr(ArgNo, ExtAttr); +} + +// Modeled after X86TargetLowering::markLibCallAttributes. +static void markRegisterParameterAttributes(Function *F) { + if (!F->arg_size() || F->isVarArg()) + return; + + const CallingConv::ID CC = F->getCallingConv(); + if (CC != CallingConv::C && CC != CallingConv::X86_StdCall) + return; + + const Module *M = F->getParent(); + unsigned N = M->getNumberRegisterParameters(); + if (!N) + return; + + const DataLayout &DL = M->getDataLayout(); + + for (Argument &A : F->args()) { + Type *T = A.getType(); + if (!T->isIntOrPtrTy()) + continue; + + const TypeSize &TS = DL.getTypeAllocSize(T); + if (TS > 8) + continue; + + assert(TS <= 4 && "Need to account for parameters larger than word size"); + const unsigned NumRegs = TS > 4 ? 2 : 1; + if (N < NumRegs) + return; + + N -= NumRegs; + F->addParamAttr(A.getArgNo(), Attribute::InReg); + } +} + +FunctionCallee llvm::getOrInsertLibFunc(Module *M, const TargetLibraryInfo &TLI, + LibFunc TheLibFunc, FunctionType *T, + AttributeList AttributeList) { + assert(TLI.has(TheLibFunc) && + "Creating call to non-existing library function."); + StringRef Name = TLI.getName(TheLibFunc); + FunctionCallee C = M->getOrInsertFunction(Name, T, AttributeList); + + // Make sure any mandatory argument attributes are added. + + // Any outgoing i32 argument should be handled with setArgExtAttr() which + // will add an extension attribute if the target ABI requires it. Adding + // argument extensions is typically done by the front end but when an + // optimizer is building a library call on its own it has to take care of + // this. 
Each such generated function must be handled here with sign or + // zero extensions as needed. F is retreived with cast<> because we demand + // of the caller to have called isLibFuncEmittable() first. + Function *F = cast<Function>(C.getCallee()); + assert(F->getFunctionType() == T && "Function type does not match."); + switch (TheLibFunc) { + case LibFunc_fputc: + case LibFunc_putchar: + setArgExtAttr(*F, 0, TLI); + break; + case LibFunc_ldexp: + case LibFunc_ldexpf: + case LibFunc_ldexpl: + case LibFunc_memchr: + case LibFunc_memrchr: + case LibFunc_strchr: + setArgExtAttr(*F, 1, TLI); + break; + case LibFunc_memccpy: + setArgExtAttr(*F, 2, TLI); + break; + + // These are functions that are known to not need any argument extension + // on any target: A size_t argument (which may be an i32 on some targets) + // should not trigger the assert below. + case LibFunc_bcmp: + case LibFunc_calloc: + case LibFunc_fwrite: + case LibFunc_malloc: + case LibFunc_memcmp: + case LibFunc_memcpy_chk: + case LibFunc_mempcpy: + case LibFunc_memset_pattern16: + case LibFunc_snprintf: + case LibFunc_stpncpy: + case LibFunc_strlcat: + case LibFunc_strlcpy: + case LibFunc_strncat: + case LibFunc_strncmp: + case LibFunc_strncpy: + case LibFunc_vsnprintf: + break; + + default: +#ifndef NDEBUG + for (unsigned i = 0; i < T->getNumParams(); i++) + assert(!isa<IntegerType>(T->getParamType(i)) && + "Unhandled integer argument."); +#endif + break; + } + + markRegisterParameterAttributes(F); + + return C; +} + +FunctionCallee llvm::getOrInsertLibFunc(Module *M, const TargetLibraryInfo &TLI, + LibFunc TheLibFunc, FunctionType *T) { + return getOrInsertLibFunc(M, TLI, TheLibFunc, T, AttributeList()); +} + +bool llvm::isLibFuncEmittable(const Module *M, const TargetLibraryInfo *TLI, + LibFunc TheLibFunc) { + StringRef FuncName = TLI->getName(TheLibFunc); + if (!TLI->has(TheLibFunc)) + return false; + + // Check if the Module already has a GlobalValue with the same name, in + // which case it must be a Function with the expected type. 
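A standalone sketch of the mandatory-extension rule the comment above describes, mirroring the file-local setArgExtAttr() helper (the name used here is illustrative): when the target ABI requires it, an outgoing i32 libcall argument gets a signext or zeroext parameter attribute on the generated declaration.

#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Function.h"

static void addI32ArgExtension(llvm::Function &F, unsigned ArgNo,
                               const llvm::TargetLibraryInfo &TLI,
                               bool Signed = true) {
  llvm::Attribute::AttrKind Ext = TLI.getExtAttrForI32Param(Signed);
  if (Ext != llvm::Attribute::None && !F.hasParamAttribute(ArgNo, Ext))
    F.addParamAttr(ArgNo, Ext);
}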
+ if (GlobalValue *GV = M->getNamedValue(FuncName)) { + if (auto *F = dyn_cast<Function>(GV)) + return TLI->isValidProtoForLibFunc(*F->getFunctionType(), TheLibFunc, *M); + return false; + } + + return true; +} + +bool llvm::isLibFuncEmittable(const Module *M, const TargetLibraryInfo *TLI, + StringRef Name) { + LibFunc TheLibFunc; + return TLI->getLibFunc(Name, TheLibFunc) && + isLibFuncEmittable(M, TLI, TheLibFunc); +} + +bool llvm::hasFloatFn(const Module *M, const TargetLibraryInfo *TLI, Type *Ty, LibFunc DoubleFn, LibFunc FloatFn, LibFunc LongDoubleFn) { switch (Ty->getTypeID()) { case Type::HalfTyID: return false; case Type::FloatTyID: - return TLI->has(FloatFn); + return isLibFuncEmittable(M, TLI, FloatFn); case Type::DoubleTyID: - return TLI->has(DoubleFn); + return isLibFuncEmittable(M, TLI, DoubleFn); default: - return TLI->has(LongDoubleFn); + return isLibFuncEmittable(M, TLI, LongDoubleFn); } } -StringRef llvm::getFloatFnName(const TargetLibraryInfo *TLI, Type *Ty, - LibFunc DoubleFn, LibFunc FloatFn, - LibFunc LongDoubleFn) { - assert(hasFloatFn(TLI, Ty, DoubleFn, FloatFn, LongDoubleFn) && +StringRef llvm::getFloatFn(const Module *M, const TargetLibraryInfo *TLI, + Type *Ty, LibFunc DoubleFn, LibFunc FloatFn, + LibFunc LongDoubleFn, LibFunc &TheLibFunc) { + assert(hasFloatFn(M, TLI, Ty, DoubleFn, FloatFn, LongDoubleFn) && "Cannot get name for unavailable function!"); switch (Ty->getTypeID()) { case Type::HalfTyID: llvm_unreachable("No name for HalfTy!"); case Type::FloatTyID: + TheLibFunc = FloatFn; return TLI->getName(FloatFn); case Type::DoubleTyID: + TheLibFunc = DoubleFn; return TLI->getName(DoubleFn); default: + TheLibFunc = LongDoubleFn; return TLI->getName(LongDoubleFn); } } @@ -1222,14 +1424,14 @@ static Value *emitLibCall(LibFunc TheLibFunc, Type *ReturnType, ArrayRef<Value *> Operands, IRBuilderBase &B, const TargetLibraryInfo *TLI, bool IsVaArgs = false) { - if (!TLI->has(TheLibFunc)) + Module *M = B.GetInsertBlock()->getModule(); + if (!isLibFuncEmittable(M, TLI, TheLibFunc)) return nullptr; - Module *M = B.GetInsertBlock()->getModule(); StringRef FuncName = TLI->getName(TheLibFunc); FunctionType *FuncType = FunctionType::get(ReturnType, ParamTypes, IsVaArgs); - FunctionCallee Callee = M->getOrInsertFunction(FuncName, FuncType); - inferLibFuncAttributes(M, FuncName, *TLI); + FunctionCallee Callee = getOrInsertLibFunc(M, *TLI, TheLibFunc, FuncType); + inferNonMandatoryLibFuncAttrs(M, FuncName, *TLI); CallInst *CI = B.CreateCall(Callee, Operands, FuncName); if (const Function *F = dyn_cast<Function>(Callee.getCallee()->stripPointerCasts())) @@ -1298,16 +1500,16 @@ Value *llvm::emitStpNCpy(Value *Dst, Value *Src, Value *Len, IRBuilderBase &B, Value *llvm::emitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc_memcpy_chk)) + Module *M = B.GetInsertBlock()->getModule(); + if (!isLibFuncEmittable(M, TLI, LibFunc_memcpy_chk)) return nullptr; - Module *M = B.GetInsertBlock()->getModule(); AttributeList AS; AS = AttributeList::get(M->getContext(), AttributeList::FunctionIndex, Attribute::NoUnwind); LLVMContext &Context = B.GetInsertBlock()->getContext(); - FunctionCallee MemCpy = M->getOrInsertFunction( - "__memcpy_chk", AttributeList::get(M->getContext(), AS), B.getInt8PtrTy(), + FunctionCallee MemCpy = getOrInsertLibFunc(M, *TLI, LibFunc_memcpy_chk, + AttributeList::get(M->getContext(), AS), B.getInt8PtrTy(), B.getInt8PtrTy(), B.getInt8PtrTy(), 
DL.getIntPtrType(Context), DL.getIntPtrType(Context)); Dst = castToCStr(Dst, B); @@ -1337,6 +1539,15 @@ Value *llvm::emitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilderBase &B, {castToCStr(Ptr, B), Val, Len}, B, TLI); } +Value *llvm::emitMemRChr(Value *Ptr, Value *Val, Value *Len, IRBuilderBase &B, + const DataLayout &DL, const TargetLibraryInfo *TLI) { + LLVMContext &Context = B.GetInsertBlock()->getContext(); + return emitLibCall( + LibFunc_memrchr, B.getInt8PtrTy(), + {B.getInt8PtrTy(), B.getInt32Ty(), DL.getIntPtrType(Context)}, + {castToCStr(Ptr, B), Val, Len}, B, TLI); +} + Value *llvm::emitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { LLVMContext &Context = B.GetInsertBlock()->getContext(); @@ -1441,14 +1652,15 @@ static void appendTypeSuffix(Value *Op, StringRef &Name, } } -static Value *emitUnaryFloatFnCallHelper(Value *Op, StringRef Name, - IRBuilderBase &B, - const AttributeList &Attrs) { +static Value *emitUnaryFloatFnCallHelper(Value *Op, LibFunc TheLibFunc, + StringRef Name, IRBuilderBase &B, + const AttributeList &Attrs, + const TargetLibraryInfo *TLI) { assert((Name != "") && "Must specify Name to emitUnaryFloatFnCall"); Module *M = B.GetInsertBlock()->getModule(); - FunctionCallee Callee = - M->getOrInsertFunction(Name, Op->getType(), Op->getType()); + FunctionCallee Callee = getOrInsertLibFunc(M, *TLI, TheLibFunc, Op->getType(), + Op->getType()); CallInst *CI = B.CreateCall(Callee, Op, Name); // The incoming attribute set may have come from a speculatable intrinsic, but @@ -1463,12 +1675,16 @@ static Value *emitUnaryFloatFnCallHelper(Value *Op, StringRef Name, return CI; } -Value *llvm::emitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilderBase &B, +Value *llvm::emitUnaryFloatFnCall(Value *Op, const TargetLibraryInfo *TLI, + StringRef Name, IRBuilderBase &B, const AttributeList &Attrs) { SmallString<20> NameBuffer; appendTypeSuffix(Op, Name, NameBuffer); - return emitUnaryFloatFnCallHelper(Op, Name, B, Attrs); + LibFunc TheLibFunc; + TLI->getLibFunc(Name, TheLibFunc); + + return emitUnaryFloatFnCallHelper(Op, TheLibFunc, Name, B, Attrs, TLI); } Value *llvm::emitUnaryFloatFnCall(Value *Op, const TargetLibraryInfo *TLI, @@ -1476,23 +1692,25 @@ Value *llvm::emitUnaryFloatFnCall(Value *Op, const TargetLibraryInfo *TLI, LibFunc LongDoubleFn, IRBuilderBase &B, const AttributeList &Attrs) { // Get the name of the function according to TLI. 
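An assumed caller-side sketch for the reworked hasFloatFn() above: availability is now checked against the module as well as TLI, selecting the double, float, or long double variant by operand type (sqrt is used here purely as an example triple):

#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"

bool canEmitSqrtLibCall(const llvm::Module *M,
                        const llvm::TargetLibraryInfo *TLI, llvm::Type *Ty) {
  return llvm::hasFloatFn(M, TLI, Ty, llvm::LibFunc_sqrt, llvm::LibFunc_sqrtf,
                          llvm::LibFunc_sqrtl);
}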
- StringRef Name = getFloatFnName(TLI, Op->getType(), - DoubleFn, FloatFn, LongDoubleFn); + Module *M = B.GetInsertBlock()->getModule(); + LibFunc TheLibFunc; + StringRef Name = getFloatFn(M, TLI, Op->getType(), DoubleFn, FloatFn, + LongDoubleFn, TheLibFunc); - return emitUnaryFloatFnCallHelper(Op, Name, B, Attrs); + return emitUnaryFloatFnCallHelper(Op, TheLibFunc, Name, B, Attrs, TLI); } static Value *emitBinaryFloatFnCallHelper(Value *Op1, Value *Op2, + LibFunc TheLibFunc, StringRef Name, IRBuilderBase &B, const AttributeList &Attrs, - const TargetLibraryInfo *TLI = nullptr) { + const TargetLibraryInfo *TLI) { assert((Name != "") && "Must specify Name to emitBinaryFloatFnCall"); Module *M = B.GetInsertBlock()->getModule(); - FunctionCallee Callee = M->getOrInsertFunction(Name, Op1->getType(), - Op1->getType(), Op2->getType()); - if (TLI != nullptr) - inferLibFuncAttributes(M, Name, *TLI); + FunctionCallee Callee = getOrInsertLibFunc(M, *TLI, TheLibFunc, Op1->getType(), + Op1->getType(), Op2->getType()); + inferNonMandatoryLibFuncAttrs(M, Name, *TLI); CallInst *CI = B.CreateCall(Callee, { Op1, Op2 }, Name); // The incoming attribute set may have come from a speculatable intrinsic, but @@ -1507,15 +1725,19 @@ static Value *emitBinaryFloatFnCallHelper(Value *Op1, Value *Op2, return CI; } -Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, - IRBuilderBase &B, +Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, + const TargetLibraryInfo *TLI, + StringRef Name, IRBuilderBase &B, const AttributeList &Attrs) { assert((Name != "") && "Must specify Name to emitBinaryFloatFnCall"); SmallString<20> NameBuffer; appendTypeSuffix(Op1, Name, NameBuffer); - return emitBinaryFloatFnCallHelper(Op1, Op2, Name, B, Attrs); + LibFunc TheLibFunc; + TLI->getLibFunc(Name, TheLibFunc); + + return emitBinaryFloatFnCallHelper(Op1, Op2, TheLibFunc, Name, B, Attrs, TLI); } Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, @@ -1524,22 +1746,24 @@ Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, LibFunc LongDoubleFn, IRBuilderBase &B, const AttributeList &Attrs) { // Get the name of the function according to TLI. 
- StringRef Name = getFloatFnName(TLI, Op1->getType(), - DoubleFn, FloatFn, LongDoubleFn); + Module *M = B.GetInsertBlock()->getModule(); + LibFunc TheLibFunc; + StringRef Name = getFloatFn(M, TLI, Op1->getType(), DoubleFn, FloatFn, + LongDoubleFn, TheLibFunc); - return emitBinaryFloatFnCallHelper(Op1, Op2, Name, B, Attrs, TLI); + return emitBinaryFloatFnCallHelper(Op1, Op2, TheLibFunc, Name, B, Attrs, TLI); } Value *llvm::emitPutChar(Value *Char, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc_putchar)) + Module *M = B.GetInsertBlock()->getModule(); + if (!isLibFuncEmittable(M, TLI, LibFunc_putchar)) return nullptr; - Module *M = B.GetInsertBlock()->getModule(); StringRef PutCharName = TLI->getName(LibFunc_putchar); - FunctionCallee PutChar = - M->getOrInsertFunction(PutCharName, B.getInt32Ty(), B.getInt32Ty()); - inferLibFuncAttributes(M, PutCharName, *TLI); + FunctionCallee PutChar = getOrInsertLibFunc(M, *TLI, LibFunc_putchar, + B.getInt32Ty(), B.getInt32Ty()); + inferNonMandatoryLibFuncAttrs(M, PutCharName, *TLI); CallInst *CI = B.CreateCall(PutChar, B.CreateIntCast(Char, B.getInt32Ty(), @@ -1555,14 +1779,14 @@ Value *llvm::emitPutChar(Value *Char, IRBuilderBase &B, Value *llvm::emitPutS(Value *Str, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc_puts)) + Module *M = B.GetInsertBlock()->getModule(); + if (!isLibFuncEmittable(M, TLI, LibFunc_puts)) return nullptr; - Module *M = B.GetInsertBlock()->getModule(); StringRef PutsName = TLI->getName(LibFunc_puts); - FunctionCallee PutS = - M->getOrInsertFunction(PutsName, B.getInt32Ty(), B.getInt8PtrTy()); - inferLibFuncAttributes(M, PutsName, *TLI); + FunctionCallee PutS = getOrInsertLibFunc(M, *TLI, LibFunc_puts, B.getInt32Ty(), + B.getInt8PtrTy()); + inferNonMandatoryLibFuncAttrs(M, PutsName, *TLI); CallInst *CI = B.CreateCall(PutS, castToCStr(Str, B), PutsName); if (const Function *F = dyn_cast<Function>(PutS.getCallee()->stripPointerCasts())) @@ -1572,15 +1796,15 @@ Value *llvm::emitPutS(Value *Str, IRBuilderBase &B, Value *llvm::emitFPutC(Value *Char, Value *File, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc_fputc)) + Module *M = B.GetInsertBlock()->getModule(); + if (!isLibFuncEmittable(M, TLI, LibFunc_fputc)) return nullptr; - Module *M = B.GetInsertBlock()->getModule(); StringRef FPutcName = TLI->getName(LibFunc_fputc); - FunctionCallee F = M->getOrInsertFunction(FPutcName, B.getInt32Ty(), - B.getInt32Ty(), File->getType()); + FunctionCallee F = getOrInsertLibFunc(M, *TLI, LibFunc_fputc, B.getInt32Ty(), + B.getInt32Ty(), File->getType()); if (File->getType()->isPointerTy()) - inferLibFuncAttributes(M, FPutcName, *TLI); + inferNonMandatoryLibFuncAttrs(M, FPutcName, *TLI); Char = B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/true, "chari"); CallInst *CI = B.CreateCall(F, {Char, File}, FPutcName); @@ -1593,15 +1817,15 @@ Value *llvm::emitFPutC(Value *Char, Value *File, IRBuilderBase &B, Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilderBase &B, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc_fputs)) + Module *M = B.GetInsertBlock()->getModule(); + if (!isLibFuncEmittable(M, TLI, LibFunc_fputs)) return nullptr; - Module *M = B.GetInsertBlock()->getModule(); StringRef FPutsName = TLI->getName(LibFunc_fputs); - FunctionCallee F = M->getOrInsertFunction(FPutsName, B.getInt32Ty(), - B.getInt8PtrTy(), File->getType()); + FunctionCallee F = getOrInsertLibFunc(M, *TLI, LibFunc_fputs, B.getInt32Ty(), + B.getInt8PtrTy(), File->getType()); if 
(File->getType()->isPointerTy()) - inferLibFuncAttributes(M, FPutsName, *TLI); + inferNonMandatoryLibFuncAttrs(M, FPutsName, *TLI); CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), File}, FPutsName); if (const Function *Fn = @@ -1612,18 +1836,18 @@ Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilderBase &B, Value *llvm::emitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc_fwrite)) + Module *M = B.GetInsertBlock()->getModule(); + if (!isLibFuncEmittable(M, TLI, LibFunc_fwrite)) return nullptr; - Module *M = B.GetInsertBlock()->getModule(); LLVMContext &Context = B.GetInsertBlock()->getContext(); StringRef FWriteName = TLI->getName(LibFunc_fwrite); - FunctionCallee F = M->getOrInsertFunction( - FWriteName, DL.getIntPtrType(Context), B.getInt8PtrTy(), - DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType()); + FunctionCallee F = getOrInsertLibFunc(M, *TLI, LibFunc_fwrite, + DL.getIntPtrType(Context), B.getInt8PtrTy(), DL.getIntPtrType(Context), + DL.getIntPtrType(Context), File->getType()); if (File->getType()->isPointerTy()) - inferLibFuncAttributes(M, FWriteName, *TLI); + inferNonMandatoryLibFuncAttrs(M, FWriteName, *TLI); CallInst *CI = B.CreateCall(F, {castToCStr(Ptr, B), Size, ConstantInt::get(DL.getIntPtrType(Context), 1), File}); @@ -1636,15 +1860,15 @@ Value *llvm::emitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilderBase &B, Value *llvm::emitMalloc(Value *Num, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI) { - if (!TLI->has(LibFunc_malloc)) + Module *M = B.GetInsertBlock()->getModule(); + if (!isLibFuncEmittable(M, TLI, LibFunc_malloc)) return nullptr; - Module *M = B.GetInsertBlock()->getModule(); StringRef MallocName = TLI->getName(LibFunc_malloc); LLVMContext &Context = B.GetInsertBlock()->getContext(); - FunctionCallee Malloc = M->getOrInsertFunction(MallocName, B.getInt8PtrTy(), - DL.getIntPtrType(Context)); - inferLibFuncAttributes(M, MallocName, *TLI); + FunctionCallee Malloc = getOrInsertLibFunc(M, *TLI, LibFunc_malloc, + B.getInt8PtrTy(), DL.getIntPtrType(Context)); + inferNonMandatoryLibFuncAttrs(M, MallocName, *TLI); CallInst *CI = B.CreateCall(Malloc, Num, MallocName); if (const Function *F = @@ -1656,16 +1880,16 @@ Value *llvm::emitMalloc(Value *Num, IRBuilderBase &B, const DataLayout &DL, Value *llvm::emitCalloc(Value *Num, Value *Size, IRBuilderBase &B, const TargetLibraryInfo &TLI) { - if (!TLI.has(LibFunc_calloc)) + Module *M = B.GetInsertBlock()->getModule(); + if (!isLibFuncEmittable(M, &TLI, LibFunc_calloc)) return nullptr; - Module *M = B.GetInsertBlock()->getModule(); StringRef CallocName = TLI.getName(LibFunc_calloc); const DataLayout &DL = M->getDataLayout(); IntegerType *PtrType = DL.getIntPtrType((B.GetInsertBlock()->getContext())); - FunctionCallee Calloc = - M->getOrInsertFunction(CallocName, B.getInt8PtrTy(), PtrType, PtrType); - inferLibFuncAttributes(M, CallocName, TLI); + FunctionCallee Calloc = getOrInsertLibFunc(M, TLI, LibFunc_calloc, + B.getInt8PtrTy(), PtrType, PtrType); + inferNonMandatoryLibFuncAttrs(M, CallocName, TLI); CallInst *CI = B.CreateCall(Calloc, {Num, Size}, CallocName); if (const auto *F = diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp index ac3839f2a4ab..1840f26add2d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp +++ 
b/contrib/llvm-project/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp @@ -14,6 +14,9 @@ #include "llvm/Transforms/Utils/CallGraphUpdater.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/CallGraphSCCPass.h" +#include "llvm/IR/Constants.h" #include "llvm/Transforms/Utils/ModuleUtils.h" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp index 56b6e4bc46a5..e530afc277db 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -279,8 +279,8 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) { /// ; The original call instruction stays in its original block. /// %t0 = musttail call i32 %ptr() /// ret %t0 -static CallBase &versionCallSite(CallBase &CB, Value *Callee, - MDNode *BranchWeights) { +CallBase &llvm::versionCallSite(CallBase &CB, Value *Callee, + MDNode *BranchWeights) { IRBuilder<> Builder(&CB); CallBase *OrigInst = &CB; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp index 6b01c0c71d00..f229d4bf14e9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp @@ -30,8 +30,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/CanonicalizeAliases.h" -#include "llvm/IR/Operator.h" -#include "llvm/IR/ValueHandle.h" +#include "llvm/IR/Constants.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp index 049c7d113521..a1ee3df907ec 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp @@ -29,7 +29,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/CanonicalizeFreezeInLoops.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/IVDescriptors.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneFunction.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneFunction.cpp index 86413df664a0..8f053cd56e0e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneFunction.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneFunction.cpp @@ -14,7 +14,6 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" @@ -23,7 +22,6 @@ #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" -#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" @@ -324,6 +322,9 @@ struct PruningFunctionCloner { bool ModuleLevelChanges; const char *NameSuffix; ClonedCodeInfo *CodeInfo; + bool HostFuncIsStrictFP; + + Instruction *cloneInstruction(BasicBlock::const_iterator II); public: 
PruningFunctionCloner(Function *newFunc, const Function *oldFunc, @@ -331,7 +332,10 @@ public: const char *nameSuffix, ClonedCodeInfo *codeInfo) : NewFunc(newFunc), OldFunc(oldFunc), VMap(valueMap), ModuleLevelChanges(moduleLevelChanges), NameSuffix(nameSuffix), - CodeInfo(codeInfo) {} + CodeInfo(codeInfo) { + HostFuncIsStrictFP = + newFunc->getAttributes().hasFnAttr(Attribute::StrictFP); + } /// The specified block is found to be reachable, clone it and /// anything that it can reach. @@ -340,6 +344,89 @@ public: }; } // namespace +static bool hasRoundingModeOperand(Intrinsic::ID CIID) { + switch (CIID) { +#define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC) \ + case Intrinsic::INTRINSIC: \ + return ROUND_MODE == 1; +#define FUNCTION INSTRUCTION +#include "llvm/IR/ConstrainedOps.def" + default: + llvm_unreachable("Unexpected constrained intrinsic id"); + } +} + +Instruction * +PruningFunctionCloner::cloneInstruction(BasicBlock::const_iterator II) { + const Instruction &OldInst = *II; + Instruction *NewInst = nullptr; + if (HostFuncIsStrictFP) { + Intrinsic::ID CIID = getConstrainedIntrinsicID(OldInst); + if (CIID != Intrinsic::not_intrinsic) { + // Instead of cloning the instruction, a call to constrained intrinsic + // should be created. + // Assume the first arguments of constrained intrinsics are the same as + // the operands of original instruction. + + // Determine overloaded types of the intrinsic. + SmallVector<Type *, 2> TParams; + SmallVector<Intrinsic::IITDescriptor, 8> Descriptor; + getIntrinsicInfoTableEntries(CIID, Descriptor); + for (unsigned I = 0, E = Descriptor.size(); I != E; ++I) { + Intrinsic::IITDescriptor Operand = Descriptor[I]; + switch (Operand.Kind) { + case Intrinsic::IITDescriptor::Argument: + if (Operand.getArgumentKind() != + Intrinsic::IITDescriptor::AK_MatchType) { + if (I == 0) + TParams.push_back(OldInst.getType()); + else + TParams.push_back(OldInst.getOperand(I - 1)->getType()); + } + break; + case Intrinsic::IITDescriptor::SameVecWidthArgument: + ++I; + break; + default: + break; + } + } + + // Create intrinsic call. + LLVMContext &Ctx = NewFunc->getContext(); + Function *IFn = + Intrinsic::getDeclaration(NewFunc->getParent(), CIID, TParams); + SmallVector<Value *, 4> Args; + unsigned NumOperands = OldInst.getNumOperands(); + if (isa<CallInst>(OldInst)) + --NumOperands; + for (unsigned I = 0; I < NumOperands; ++I) { + Value *Op = OldInst.getOperand(I); + Args.push_back(Op); + } + if (const auto *CmpI = dyn_cast<FCmpInst>(&OldInst)) { + FCmpInst::Predicate Pred = CmpI->getPredicate(); + StringRef PredName = FCmpInst::getPredicateName(Pred); + Args.push_back(MetadataAsValue::get(Ctx, MDString::get(Ctx, PredName))); + } + + // The last arguments of a constrained intrinsic are metadata that + // represent rounding mode (absents in some intrinsics) and exception + // behavior. The inlined function uses default settings. + if (hasRoundingModeOperand(CIID)) + Args.push_back( + MetadataAsValue::get(Ctx, MDString::get(Ctx, "round.tonearest"))); + Args.push_back( + MetadataAsValue::get(Ctx, MDString::get(Ctx, "fpexcept.ignore"))); + + NewInst = CallInst::Create(IFn, Args, OldInst.getName() + ".strict"); + } + } + if (!NewInst) + NewInst = II->clone(); + return NewInst; +} + /// The specified block is found to be reachable, clone it and /// anything that it can reach. 
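For comparison, a hedged IRBuilder-based sketch of the constrained form that cloneInstruction() above synthesizes by hand for a plain fadd, using the same defaults ("round.tonearest" and "fpexcept.ignore"); this is an assumed illustration of the target shape, not the code path the cloner uses:

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"

llvm::CallInst *emitStrictFAdd(llvm::IRBuilder<> &B, llvm::Value *L,
                               llvm::Value *R) {
  // Produces a call to llvm.experimental.constrained.fadd with default
  // rounding and ignored FP exceptions.
  return B.CreateConstrainedFPBinOp(
      llvm::Intrinsic::experimental_constrained_fadd, L, R,
      /*FMFSource=*/nullptr, "add.strict", /*FPMathTag=*/nullptr,
      llvm::RoundingMode::NearestTiesToEven, llvm::fp::ebIgnore);
}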
void PruningFunctionCloner::CloneBlock( @@ -379,7 +466,14 @@ void PruningFunctionCloner::CloneBlock( for (BasicBlock::const_iterator II = StartingInst, IE = --BB->end(); II != IE; ++II) { - Instruction *NewInst = II->clone(); + Instruction *NewInst = cloneInstruction(II); + + if (HostFuncIsStrictFP) { + // All function calls in the inlined function must get 'strictfp' + // attribute to prevent undesirable optimizations. + if (auto *Call = dyn_cast<CallInst>(NewInst)) + Call->addFnAttr(Attribute::StrictFP); + } // Eagerly remap operands to the newly cloned instruction, except for PHI // nodes for which we defer processing until we update the CFG. @@ -391,7 +485,7 @@ void PruningFunctionCloner::CloneBlock( // a mapping to that value rather than inserting a new instruction into // the basic block. if (Value *V = - SimplifyInstruction(NewInst, BB->getModule()->getDataLayout())) { + simplifyInstruction(NewInst, BB->getModule()->getDataLayout())) { // On the off-chance that this simplifies to an instruction in the old // function, map it back into the new function. if (NewFunc != OldFunc) @@ -674,7 +768,7 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, continue; // See if this instruction simplifies. - Value *SimpleV = SimplifyInstruction(I, DL); + Value *SimpleV = simplifyInstruction(I, DL); if (!SimpleV) continue; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneModule.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneModule.cpp index 57c273a0e3c5..55cda0f11e47 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneModule.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneModule.cpp @@ -11,13 +11,16 @@ // //===----------------------------------------------------------------------===// -#include "llvm/IR/Constant.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/ValueMapper.h" using namespace llvm; +namespace llvm { +class Constant; +} + static void copyComdat(GlobalObject *Dst, const GlobalObject *Src) { const Comdat *SC = Src->getComdat(); if (!SC) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp index cec159f6a448..f94d854f7ee8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -53,7 +53,6 @@ #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/Verifier.h" -#include "llvm/Pass.h" #include "llvm/Support/BlockFrequency.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/Casting.h" @@ -62,12 +61,10 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> #include <iterator> #include <map> -#include <set> #include <utility> #include <vector> @@ -249,9 +246,10 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, bool AggregateArgs, BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI, AssumptionCache *AC, bool AllowVarArgs, bool AllowAlloca, - std::string Suffix) + BasicBlock *AllocationBlock, std::string Suffix) : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), AC(AC), AllowVarArgs(AllowVarArgs), + BPI(BPI), AC(AC), AllocationBlock(AllocationBlock), + AllowVarArgs(AllowVarArgs), 
Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)), Suffix(Suffix) {} @@ -260,7 +258,7 @@ CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, BranchProbabilityInfo *BPI, AssumptionCache *AC, std::string Suffix) : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), AC(AC), AllowVarArgs(false), + BPI(BPI), AC(AC), AllocationBlock(nullptr), AllowVarArgs(false), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT, /* AllowVarArgs */ false, /* AllowAlloca */ false)), @@ -922,6 +920,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::StackAlignment: case Attribute::WillReturn: case Attribute::WriteOnly: + case Attribute::AllocKind: + case Attribute::PresplitCoroutine: continue; // Those attributes should be safe to propagate to the extracted function. case Attribute::AlwaysInline: @@ -939,6 +939,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::NonLazyBind: case Attribute::NoRedZone: case Attribute::NoUnwind: + case Attribute::NoSanitizeBounds: case Attribute::NoSanitizeCoverage: case Attribute::NullPointerIsValid: case Attribute::OptForFuzzing: @@ -964,6 +965,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, break; // These attributes cannot be applied to functions. case Attribute::Alignment: + case Attribute::AllocatedPointer: + case Attribute::AllocAlign: case Attribute::ByVal: case Attribute::Dereferenceable: case Attribute::DereferenceableOrNull: @@ -1190,9 +1193,10 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, // Allocate a struct at the beginning of this function StructArgTy = StructType::get(newFunction->getContext(), ArgTypes); - Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr, - "structArg", - &codeReplacer->getParent()->front().front()); + Struct = new AllocaInst( + StructArgTy, DL.getAllocaAddrSpace(), nullptr, "structArg", + AllocationBlock ? &*AllocationBlock->getFirstInsertionPt() + : &codeReplacer->getParent()->front().front()); params.push_back(Struct); // Store aggregated inputs in the struct. @@ -1771,7 +1775,7 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, // Update the entry count of the function. if (BFI) { auto Count = BFI->getProfileCountFromFreq(EntryFreq.getFrequency()); - if (Count.hasValue()) + if (Count) newFunction->setEntryCount( ProfileCount(Count.getValue(), Function::PCT_Real)); // FIXME BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency()); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeLayout.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeLayout.cpp index dfb9f608eab2..1ff0f148b3a9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeLayout.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeLayout.cpp @@ -40,11 +40,20 @@ #include "llvm/Transforms/Utils/CodeLayout.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" using namespace llvm; #define DEBUG_TYPE "code-layout" +cl::opt<bool> EnableExtTspBlockPlacement( + "enable-ext-tsp-block-placement", cl::Hidden, cl::init(false), + cl::desc("Enable machine block placement based on the ext-tsp model, " + "optimizing I-cache utilization.")); + +cl::opt<bool> ApplyExtTspWithoutProfile( + "ext-tsp-apply-without-profile", + cl::desc("Whether to apply ext-tsp placement for instances w/o profile"), + cl::init(true), cl::Hidden); + // Algorithm-specific constants. 
The values are tuned for the best performance // of large-scale front-end bound binaries. static cl::opt<double> @@ -63,6 +72,12 @@ static cl::opt<unsigned> BackwardDistance( "ext-tsp-backward-distance", cl::Hidden, cl::init(640), cl::desc("The maximum distance (in bytes) of a backward jump for ExtTSP")); +// The maximum size of a chain created by the algorithm. The size is bounded +// so that the algorithm can efficiently process extremely large instance. +static cl::opt<unsigned> + MaxChainSize("ext-tsp-max-chain-size", cl::Hidden, cl::init(4096), + cl::desc("The maximum size of a chain to create.")); + // The maximum size of a chain for splitting. Larger values of the threshold // may yield better quality at the cost of worsen run-time. static cl::opt<unsigned> ChainSplitThreshold( @@ -115,7 +130,7 @@ enum class MergeTypeTy : int { X_Y, X1_Y_X2, Y_X2_X1, X2_X1_Y }; /// together with the corresponfiding merge 'type' and 'offset'. class MergeGainTy { public: - explicit MergeGainTy() {} + explicit MergeGainTy() = default; explicit MergeGainTy(double Score, size_t MergeOffset, MergeTypeTy MergeType) : Score(Score), MergeOffset(MergeOffset), MergeType(MergeType) {} @@ -142,7 +157,6 @@ private: MergeTypeTy MergeType{MergeTypeTy::X_Y}; }; -class Block; class Jump; class Chain; class ChainEdge; @@ -223,6 +237,8 @@ public: const std::vector<Block *> &blocks() const { return Blocks; } + size_t numBlocks() const { return Blocks.size(); } + const std::vector<std::pair<Chain *, ChainEdge *>> &edges() const { return Edges; } @@ -499,7 +515,7 @@ private: AllEdges.reserve(AllJumps.size()); for (auto &Block : AllBlocks) { for (auto &Jump : Block.OutJumps) { - const auto SuccBlock = Jump->Target; + auto SuccBlock = Jump->Target; auto CurEdge = Block.CurChain->getEdge(SuccBlock->CurChain); // this edge is already present in the graph if (CurEdge != nullptr) { @@ -589,6 +605,10 @@ private: if (ChainPred == ChainSucc) continue; + // Stop early if the combined chain violates the maximum allowed size + if (ChainPred->numBlocks() + ChainSucc->numBlocks() >= MaxChainSize) + continue; + // Compute the gain of merging the two chains auto CurGain = getBestMergeGain(ChainPred, ChainSucc, ChainEdge); if (CurGain.score() <= EPS) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CtorUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CtorUtils.cpp index 069a86f6ab33..c997f39508e3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CtorUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CtorUtils.cpp @@ -18,6 +18,7 @@ #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include <numeric> #define DEBUG_TYPE "ctor_utils" @@ -62,21 +63,20 @@ static void removeGlobalCtors(GlobalVariable *GCL, const BitVector &CtorsToRemov /// Given a llvm.global_ctors list that we can understand, /// return a list of the functions and null terminator as a vector. 
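Among the ext-tsp changes above, the new `-ext-tsp-max-chain-size` option bounds how large a chain the greedy merger may build, and the added `numBlocks()` check simply skips any candidate merge whose combined size would reach that cap. A toy, self-contained version of a size-capped greedy merge is sketched below; chains are reduced to their block counts and the gain function is made up, so this only illustrates the control flow, not the real ExtTSP scoring.

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Toy model: a chain is just its block count; the "gain" of merging two
  // chains is arbitrary here (the smaller count). Merges that would reach
  // MaxChainSize are skipped, mirroring the early 'continue' in the patch.
  const std::size_t MaxChainSize = 8;
  std::vector<std::size_t> Chains = {6, 5, 3, 2};

  bool Merged = true;
  while (Merged) {
    Merged = false;
    std::size_t BestI = 0, BestJ = 0, BestGain = 0;
    for (std::size_t I = 0; I < Chains.size(); ++I) {
      for (std::size_t J = I + 1; J < Chains.size(); ++J) {
        if (Chains[I] + Chains[J] >= MaxChainSize)
          continue; // combined chain would violate the maximum allowed size
        std::size_t Gain = std::min(Chains[I], Chains[J]);
        if (Gain > BestGain) {
          BestGain = Gain;
          BestI = I;
          BestJ = J;
        }
      }
    }
    if (BestGain > 0) {
      Chains[BestI] += Chains[BestJ];
      Chains.erase(Chains.begin() + BestJ);
      Merged = true;
    }
  }
  for (std::size_t S : Chains)
    std::cout << S << ' '; // prints "6 7 3": 5 and 2 merged, nothing else fits
  std::cout << '\n';
}
```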
-static std::vector<Function *> parseGlobalCtors(GlobalVariable *GV) { - if (GV->getInitializer()->isNullValue()) - return std::vector<Function *>(); +static std::vector<std::pair<uint32_t, Function *>> +parseGlobalCtors(GlobalVariable *GV) { ConstantArray *CA = cast<ConstantArray>(GV->getInitializer()); - std::vector<Function *> Result; + std::vector<std::pair<uint32_t, Function *>> Result; Result.reserve(CA->getNumOperands()); for (auto &V : CA->operands()) { ConstantStruct *CS = cast<ConstantStruct>(V); - Result.push_back(dyn_cast<Function>(CS->getOperand(1))); + Result.emplace_back(cast<ConstantInt>(CS->getOperand(0))->getZExtValue(), + dyn_cast<Function>(CS->getOperand(1))); } return Result; } -/// Find the llvm.global_ctors list, verifying that all initializers have an -/// init priority of 65535. +/// Find the llvm.global_ctors list. static GlobalVariable *findGlobalCtors(Module &M) { GlobalVariable *GV = M.getGlobalVariable("llvm.global_ctors"); if (!GV) @@ -87,9 +87,11 @@ static GlobalVariable *findGlobalCtors(Module &M) { if (!GV->hasUniqueInitializer()) return nullptr; - if (isa<ConstantAggregateZero>(GV->getInitializer())) - return GV; - ConstantArray *CA = cast<ConstantArray>(GV->getInitializer()); + // If there are no ctors, then the initializer might be null/undef/poison. + // Ignore anything but an array. + ConstantArray *CA = dyn_cast<ConstantArray>(GV->getInitializer()); + if (!CA) + return nullptr; for (auto &V : CA->operands()) { if (isa<ConstantAggregateZero>(V)) @@ -98,54 +100,47 @@ static GlobalVariable *findGlobalCtors(Module &M) { if (isa<ConstantPointerNull>(CS->getOperand(1))) continue; - // Must have a function or null ptr. - if (!isa<Function>(CS->getOperand(1))) - return nullptr; - - // Init priority must be standard. - ConstantInt *CI = cast<ConstantInt>(CS->getOperand(0)); - if (CI->getZExtValue() != 65535) + // Can only handle global constructors with no arguments. + Function *F = dyn_cast<Function>(CS->getOperand(1)); + if (!F || F->arg_size() != 0) return nullptr; } - return GV; } /// Call "ShouldRemove" for every entry in M's global_ctor list and remove the /// entries for which it returns true. Return true if anything changed. bool llvm::optimizeGlobalCtorsList( - Module &M, function_ref<bool(Function *)> ShouldRemove) { + Module &M, function_ref<bool(uint32_t, Function *)> ShouldRemove) { GlobalVariable *GlobalCtors = findGlobalCtors(M); if (!GlobalCtors) return false; - std::vector<Function *> Ctors = parseGlobalCtors(GlobalCtors); + std::vector<std::pair<uint32_t, Function *>> Ctors = + parseGlobalCtors(GlobalCtors); if (Ctors.empty()) return false; bool MadeChange = false; - // Loop over global ctors, optimizing them when we can. - unsigned NumCtors = Ctors.size(); - BitVector CtorsToRemove(NumCtors); - for (unsigned i = 0; i != Ctors.size() && NumCtors > 0; ++i) { - Function *F = Ctors[i]; - // Found a null terminator in the middle of the list, prune off the rest of - // the list. + BitVector CtorsToRemove(Ctors.size()); + std::vector<size_t> CtorsByPriority(Ctors.size()); + std::iota(CtorsByPriority.begin(), CtorsByPriority.end(), 0); + stable_sort(CtorsByPriority, [&](size_t LHS, size_t RHS) { + return Ctors[LHS].first < Ctors[RHS].first; + }); + for (unsigned CtorIndex : CtorsByPriority) { + const uint32_t Priority = Ctors[CtorIndex].first; + Function *F = Ctors[CtorIndex].second; if (!F) continue; LLVM_DEBUG(dbgs() << "Optimizing Global Constructor: " << *F << "\n"); - // We cannot simplify external ctor functions. 
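`optimizeGlobalCtorsList` now keeps the `(priority, function)` pairs in their original positions and instead sorts a vector of indices (the new `<numeric>` include is for `std::iota`), so the `CtorsToRemove` bit-vector keeps lining up with `Ctors`. A small self-contained illustration of that sort-indices-not-elements idiom follows, using placeholder constructor names rather than real `llvm.global_ctors` entries.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Stand-ins for (init priority, constructor) pairs parsed from the module.
  std::vector<std::pair<uint32_t, std::string>> Ctors = {
      {65535, "default_prio_a"}, {101, "early_init"}, {65535, "default_prio_b"}};

  // Sort indices rather than the entries themselves, so any per-entry
  // bookkeeping (e.g. a "remove me" bit-vector) keeps matching positions.
  std::vector<std::size_t> ByPriority(Ctors.size());
  std::iota(ByPriority.begin(), ByPriority.end(), 0);
  std::stable_sort(ByPriority.begin(), ByPriority.end(),
                   [&](std::size_t L, std::size_t R) {
                     return Ctors[L].first < Ctors[R].first;
                   });

  // Visits early_init (priority 101) first, then the two 65535 entries in
  // their original relative order, courtesy of the stable sort.
  for (std::size_t Idx : ByPriority)
    std::cout << Ctors[Idx].first << ' ' << Ctors[Idx].second << '\n';
}
```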
- if (F->empty()) - continue; - // If we can evaluate the ctor at compile time, do. - if (ShouldRemove(F)) { - Ctors[i] = nullptr; - CtorsToRemove.set(i); - NumCtors--; + if (ShouldRemove(Priority, F)) { + Ctors[CtorIndex].second = nullptr; + CtorsToRemove.set(CtorIndex); MadeChange = true; continue; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/Debugify.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/Debugify.cpp index 589622d69578..205f7a7d9ed2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/Debugify.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/Debugify.cpp @@ -37,12 +37,16 @@ namespace { cl::opt<bool> Quiet("debugify-quiet", cl::desc("Suppress verbose debugify output")); +cl::opt<uint64_t> DebugifyFunctionsLimit( + "debugify-func-limit", + cl::desc("Set max number of processed functions per pass."), + cl::init(UINT_MAX)); + enum class Level { Locations, LocationsAndVariables }; -// Used for the synthetic mode only. cl::opt<Level> DebugifyLevel( "debugify-level", cl::desc("Kind of debug info to add"), cl::values(clEnumValN(Level::Locations, "locations", "Locations only"), @@ -210,15 +214,15 @@ bool llvm::applyDebugifyMetadata( static bool applyDebugify(Function &F, enum DebugifyMode Mode = DebugifyMode::SyntheticDebugInfo, - DebugInfoPerPassMap *DIPreservationMap = nullptr, + DebugInfoPerPass *DebugInfoBeforePass = nullptr, StringRef NameOfWrappedPass = "") { Module &M = *F.getParent(); auto FuncIt = F.getIterator(); if (Mode == DebugifyMode::SyntheticDebugInfo) return applyDebugifyMetadata(M, make_range(FuncIt, std::next(FuncIt)), "FunctionDebugify: ", /*ApplyToMF*/ nullptr); - assert(DIPreservationMap); - return collectDebugInfoMetadata(M, M.functions(), *DIPreservationMap, + assert(DebugInfoBeforePass); + return collectDebugInfoMetadata(M, M.functions(), *DebugInfoBeforePass, "FunctionDebugify (original debuginfo)", NameOfWrappedPass); } @@ -226,12 +230,12 @@ applyDebugify(Function &F, static bool applyDebugify(Module &M, enum DebugifyMode Mode = DebugifyMode::SyntheticDebugInfo, - DebugInfoPerPassMap *DIPreservationMap = nullptr, + DebugInfoPerPass *DebugInfoBeforePass = nullptr, StringRef NameOfWrappedPass = "") { if (Mode == DebugifyMode::SyntheticDebugInfo) return applyDebugifyMetadata(M, M.functions(), "ModuleDebugify: ", /*ApplyToMF*/ nullptr); - return collectDebugInfoMetadata(M, M.functions(), *DIPreservationMap, + return collectDebugInfoMetadata(M, M.functions(), *DebugInfoBeforePass, "ModuleDebugify (original debuginfo)", NameOfWrappedPass); } @@ -267,7 +271,7 @@ bool llvm::stripDebugifyMetadata(Module &M) { SmallVector<MDNode *, 4> Flags(NMD->operands()); NMD->clearOperands(); for (MDNode *Flag : Flags) { - MDString *Key = dyn_cast_or_null<MDString>(Flag->getOperand(1)); + auto *Key = cast<MDString>(Flag->getOperand(1)); if (Key->getString() == "Debug Info Version") { Changed = true; continue; @@ -283,32 +287,37 @@ bool llvm::stripDebugifyMetadata(Module &M) { bool llvm::collectDebugInfoMetadata(Module &M, iterator_range<Module::iterator> Functions, - DebugInfoPerPassMap &DIPreservationMap, + DebugInfoPerPass &DebugInfoBeforePass, StringRef Banner, StringRef NameOfWrappedPass) { LLVM_DEBUG(dbgs() << Banner << ": (before) " << NameOfWrappedPass << '\n'); - // Clear the map with the debug info before every single pass. 
- DIPreservationMap.clear(); - if (!M.getNamedMetadata("llvm.dbg.cu")) { dbg() << Banner << ": Skipping module without debug info\n"; return false; } + uint64_t FunctionsCnt = DebugInfoBeforePass.DIFunctions.size(); // Visit each instruction. for (Function &F : Functions) { + // Use DI collected after previous Pass (when -debugify-each is used). + if (DebugInfoBeforePass.DIFunctions.count(&F)) + continue; + if (isFunctionSkipped(F)) continue; + // Stop collecting DI if the Functions number reached the limit. + if (++FunctionsCnt >= DebugifyFunctionsLimit) + break; // Collect the DISubprogram. auto *SP = F.getSubprogram(); - DIPreservationMap[NameOfWrappedPass].DIFunctions.insert({F.getName(), SP}); + DebugInfoBeforePass.DIFunctions.insert({&F, SP}); if (SP) { LLVM_DEBUG(dbgs() << " Collecting subprogram: " << *SP << '\n'); for (const DINode *DN : SP->getRetainedNodes()) { if (const auto *DV = dyn_cast<DILocalVariable>(DN)) { - DIPreservationMap[NameOfWrappedPass].DIVariables[DV] = 0; + DebugInfoBeforePass.DIVariables[DV] = 0; } } } @@ -320,20 +329,22 @@ bool llvm::collectDebugInfoMetadata(Module &M, if (isa<PHINode>(I)) continue; - // Collect dbg.values and dbg.declares. - if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I)) { - if (!SP) - continue; - // Skip inlined variables. - if (I.getDebugLoc().getInlinedAt()) + // Cllect dbg.values and dbg.declare. + if (DebugifyLevel > Level::Locations) { + if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I)) { + if (!SP) + continue; + // Skip inlined variables. + if (I.getDebugLoc().getInlinedAt()) + continue; + // Skip undef values. + if (DVI->isUndef()) + continue; + + auto *Var = DVI->getVariable(); + DebugInfoBeforePass.DIVariables[Var]++; continue; - // Skip undef values. - if (DVI->isUndef()) - continue; - - auto *Var = DVI->getVariable(); - DIPreservationMap[NameOfWrappedPass].DIVariables[Var]++; - continue; + } } // Skip debug instructions other than dbg.value and dbg.declare. @@ -341,11 +352,11 @@ bool llvm::collectDebugInfoMetadata(Module &M, continue; LLVM_DEBUG(dbgs() << " Collecting info for inst: " << I << '\n'); - DIPreservationMap[NameOfWrappedPass].InstToDelete.insert({&I, &I}); + DebugInfoBeforePass.InstToDelete.insert({&I, &I}); const DILocation *Loc = I.getDebugLoc().get(); bool HasLoc = Loc != nullptr; - DIPreservationMap[NameOfWrappedPass].DILocations.insert({&I, HasLoc}); + DebugInfoBeforePass.DILocations.insert({&I, HasLoc}); } } } @@ -367,12 +378,12 @@ static bool checkFunctions(const DebugFnMap &DIFunctionsBefore, if (SPIt == DIFunctionsBefore.end()) { if (ShouldWriteIntoJSON) Bugs.push_back(llvm::json::Object({{"metadata", "DISubprogram"}, - {"name", F.first}, + {"name", F.first->getName()}, {"action", "not-generate"}})); else dbg() << "ERROR: " << NameOfWrappedPass - << " did not generate DISubprogram for " << F.first << " from " - << FileNameFromCU << '\n'; + << " did not generate DISubprogram for " << F.first->getName() + << " from " << FileNameFromCU << '\n'; Preserved = false; } else { auto SP = SPIt->second; @@ -382,11 +393,11 @@ static bool checkFunctions(const DebugFnMap &DIFunctionsBefore, // a debug info bug. 
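Two behavioural changes are folded into the collection loop above: functions whose debug info was already captured by a previous pass (the `-debugify-each` case) are skipped, and the new `-debugify-func-limit` option caps how many functions are processed per pass. The same skip-then-budget control flow, reduced to a standalone loop over plain names (the map contents and the limit are placeholders):

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  const uint64_t FuncLimit = 3; // stand-in for -debugify-func-limit
  // Snapshot carried over from a previous pass: function name -> "has DI".
  std::map<std::string, bool> Collected = {{"f0", true}};
  const std::vector<std::string> Functions = {"f0", "f1", "f2", "f3"};

  uint64_t Count = Collected.size(); // the budget also counts what we reuse
  for (const std::string &F : Functions) {
    if (Collected.count(F))
      continue; // reuse the info gathered after the previous pass
    if (++Count >= FuncLimit)
      break;    // stop collecting once the per-pass budget is exhausted
    Collected[F] = true; // "collect" debug info for this function
  }

  for (const auto &Entry : Collected)
    std::cout << Entry.first << '\n'; // f0 (reused) and f1; f2/f3 hit the cap
}
```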
if (ShouldWriteIntoJSON) Bugs.push_back(llvm::json::Object({{"metadata", "DISubprogram"}, - {"name", F.first}, + {"name", F.first->getName()}, {"action", "drop"}})); else dbg() << "ERROR: " << NameOfWrappedPass << " dropped DISubprogram of " - << F.first << " from " << FileNameFromCU << '\n'; + << F.first->getName() << " from " << FileNameFromCU << '\n'; Preserved = false; } } @@ -515,7 +526,7 @@ static void writeJSON(StringRef OrigDIVerifyBugsReportFilePath, bool llvm::checkDebugInfoMetadata(Module &M, iterator_range<Module::iterator> Functions, - DebugInfoPerPassMap &DIPreservationMap, + DebugInfoPerPass &DebugInfoBeforePass, StringRef Banner, StringRef NameOfWrappedPass, StringRef OrigDIVerifyBugsReportFilePath) { LLVM_DEBUG(dbgs() << Banner << ": (after) " << NameOfWrappedPass << '\n'); @@ -526,24 +537,26 @@ bool llvm::checkDebugInfoMetadata(Module &M, } // Map the debug info holding DIs after a pass. - DebugInfoPerPassMap DIPreservationAfter; + DebugInfoPerPass DebugInfoAfterPass; // Visit each instruction. for (Function &F : Functions) { if (isFunctionSkipped(F)) continue; + // Don't process functions without DI collected before the Pass. + if (!DebugInfoBeforePass.DIFunctions.count(&F)) + continue; // TODO: Collect metadata other than DISubprograms. // Collect the DISubprogram. auto *SP = F.getSubprogram(); - DIPreservationAfter[NameOfWrappedPass].DIFunctions.insert( - {F.getName(), SP}); + DebugInfoAfterPass.DIFunctions.insert({&F, SP}); if (SP) { LLVM_DEBUG(dbgs() << " Collecting subprogram: " << *SP << '\n'); for (const DINode *DN : SP->getRetainedNodes()) { if (const auto *DV = dyn_cast<DILocalVariable>(DN)) { - DIPreservationAfter[NameOfWrappedPass].DIVariables[DV] = 0; + DebugInfoAfterPass.DIVariables[DV] = 0; } } } @@ -556,19 +569,21 @@ bool llvm::checkDebugInfoMetadata(Module &M, continue; // Collect dbg.values and dbg.declares. - if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I)) { - if (!SP) - continue; - // Skip inlined variables. - if (I.getDebugLoc().getInlinedAt()) - continue; - // Skip undef values. - if (DVI->isUndef()) + if (DebugifyLevel > Level::Locations) { + if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I)) { + if (!SP) + continue; + // Skip inlined variables. + if (I.getDebugLoc().getInlinedAt()) + continue; + // Skip undef values. + if (DVI->isUndef()) + continue; + + auto *Var = DVI->getVariable(); + DebugInfoAfterPass.DIVariables[Var]++; continue; - - auto *Var = DVI->getVariable(); - DIPreservationAfter[NameOfWrappedPass].DIVariables[Var]++; - continue; + } } // Skip debug instructions other than dbg.value and dbg.declare. 
@@ -580,7 +595,7 @@ bool llvm::checkDebugInfoMetadata(Module &M, const DILocation *Loc = I.getDebugLoc().get(); bool HasLoc = Loc != nullptr; - DIPreservationAfter[NameOfWrappedPass].DILocations.insert({&I, HasLoc}); + DebugInfoAfterPass.DILocations.insert({&I, HasLoc}); } } } @@ -590,16 +605,16 @@ bool llvm::checkDebugInfoMetadata(Module &M, (cast<DICompileUnit>(M.getNamedMetadata("llvm.dbg.cu")->getOperand(0))) ->getFilename(); - auto DIFunctionsBefore = DIPreservationMap[NameOfWrappedPass].DIFunctions; - auto DIFunctionsAfter = DIPreservationAfter[NameOfWrappedPass].DIFunctions; + auto DIFunctionsBefore = DebugInfoBeforePass.DIFunctions; + auto DIFunctionsAfter = DebugInfoAfterPass.DIFunctions; - auto DILocsBefore = DIPreservationMap[NameOfWrappedPass].DILocations; - auto DILocsAfter = DIPreservationAfter[NameOfWrappedPass].DILocations; + auto DILocsBefore = DebugInfoBeforePass.DILocations; + auto DILocsAfter = DebugInfoAfterPass.DILocations; - auto InstToDelete = DIPreservationMap[NameOfWrappedPass].InstToDelete; + auto InstToDelete = DebugInfoBeforePass.InstToDelete; - auto DIVarsBefore = DIPreservationMap[NameOfWrappedPass].DIVariables; - auto DIVarsAfter = DIPreservationAfter[NameOfWrappedPass].DIVariables; + auto DIVarsBefore = DebugInfoBeforePass.DIVariables; + auto DIVarsAfter = DebugInfoAfterPass.DIVariables; bool ShouldWriteIntoJSON = !OrigDIVerifyBugsReportFilePath.empty(); llvm::json::Array Bugs; @@ -626,6 +641,11 @@ bool llvm::checkDebugInfoMetadata(Module &M, else dbg() << ResultBanner << ": FAIL\n"; + // In the case of the `debugify-each`, no need to go over all the instructions + // again in the collectDebugInfoMetadata(), since as an input we can use + // the debugging information from the previous pass. + DebugInfoBeforePass = DebugInfoAfterPass; + LLVM_DEBUG(dbgs() << "\n\n"); return Result; } @@ -770,14 +790,14 @@ bool checkDebugifyMetadata(Module &M, /// legacy module pass manager. struct DebugifyModulePass : public ModulePass { bool runOnModule(Module &M) override { - return applyDebugify(M, Mode, DIPreservationMap, NameOfWrappedPass); + return applyDebugify(M, Mode, DebugInfoBeforePass, NameOfWrappedPass); } DebugifyModulePass(enum DebugifyMode Mode = DebugifyMode::SyntheticDebugInfo, StringRef NameOfWrappedPass = "", - DebugInfoPerPassMap *DIPreservationMap = nullptr) + DebugInfoPerPass *DebugInfoBeforePass = nullptr) : ModulePass(ID), NameOfWrappedPass(NameOfWrappedPass), - DIPreservationMap(DIPreservationMap), Mode(Mode) {} + DebugInfoBeforePass(DebugInfoBeforePass), Mode(Mode) {} void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesAll(); @@ -787,7 +807,7 @@ struct DebugifyModulePass : public ModulePass { private: StringRef NameOfWrappedPass; - DebugInfoPerPassMap *DIPreservationMap; + DebugInfoPerPass *DebugInfoBeforePass; enum DebugifyMode Mode; }; @@ -795,15 +815,15 @@ private: /// single function, used with the legacy module pass manager. 
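Because the per-pass map is gone, `checkDebugInfoMetadata` now finishes by assigning `DebugInfoBeforePass = DebugInfoAfterPass`, so under `-debugify-each` the next pass starts from the snapshot produced after the previous one instead of re-collecting everything. A compact sketch of that snapshot hand-off across a pass pipeline; the pass names and the map type are invented.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Invented stand-in for DebugInfoPerPass: instruction name -> has a location.
using Snapshot = std::map<std::string, bool>;

// Pretend pass runner: one of the passes loses a debug location.
static Snapshot runPass(const std::string &Name, Snapshot State) {
  if (Name == "drops-loc")
    State["inst1"] = false;
  return State;
}

int main() {
  Snapshot Before = {{"inst0", true}, {"inst1", true}}; // initial collection
  const std::vector<std::string> Pipeline = {"keeps-loc", "drops-loc"};
  for (const std::string &Pass : Pipeline) {
    Snapshot After = runPass(Pass, Before);
    for (const auto &KV : Before)
      if (KV.second && !After.at(KV.first))
        std::cout << Pass << " dropped the location of " << KV.first << '\n';
    Before = After; // hand the 'after' snapshot to the next pass as its 'before'
  }
}
```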
struct DebugifyFunctionPass : public FunctionPass { bool runOnFunction(Function &F) override { - return applyDebugify(F, Mode, DIPreservationMap, NameOfWrappedPass); + return applyDebugify(F, Mode, DebugInfoBeforePass, NameOfWrappedPass); } DebugifyFunctionPass( enum DebugifyMode Mode = DebugifyMode::SyntheticDebugInfo, StringRef NameOfWrappedPass = "", - DebugInfoPerPassMap *DIPreservationMap = nullptr) + DebugInfoPerPass *DebugInfoBeforePass = nullptr) : FunctionPass(ID), NameOfWrappedPass(NameOfWrappedPass), - DIPreservationMap(DIPreservationMap), Mode(Mode) {} + DebugInfoBeforePass(DebugInfoBeforePass), Mode(Mode) {} void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesAll(); @@ -813,7 +833,7 @@ struct DebugifyFunctionPass : public FunctionPass { private: StringRef NameOfWrappedPass; - DebugInfoPerPassMap *DIPreservationMap; + DebugInfoPerPass *DebugInfoBeforePass; enum DebugifyMode Mode; }; @@ -825,7 +845,7 @@ struct CheckDebugifyModulePass : public ModulePass { return checkDebugifyMetadata(M, M.functions(), NameOfWrappedPass, "CheckModuleDebugify", Strip, StatsMap); return checkDebugInfoMetadata( - M, M.functions(), *DIPreservationMap, + M, M.functions(), *DebugInfoBeforePass, "CheckModuleDebugify (original debuginfo)", NameOfWrappedPass, OrigDIVerifyBugsReportFilePath); } @@ -834,11 +854,11 @@ struct CheckDebugifyModulePass : public ModulePass { bool Strip = false, StringRef NameOfWrappedPass = "", DebugifyStatsMap *StatsMap = nullptr, enum DebugifyMode Mode = DebugifyMode::SyntheticDebugInfo, - DebugInfoPerPassMap *DIPreservationMap = nullptr, + DebugInfoPerPass *DebugInfoBeforePass = nullptr, StringRef OrigDIVerifyBugsReportFilePath = "") : ModulePass(ID), NameOfWrappedPass(NameOfWrappedPass), OrigDIVerifyBugsReportFilePath(OrigDIVerifyBugsReportFilePath), - StatsMap(StatsMap), DIPreservationMap(DIPreservationMap), Mode(Mode), + StatsMap(StatsMap), DebugInfoBeforePass(DebugInfoBeforePass), Mode(Mode), Strip(Strip) {} void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -851,7 +871,7 @@ private: StringRef NameOfWrappedPass; StringRef OrigDIVerifyBugsReportFilePath; DebugifyStatsMap *StatsMap; - DebugInfoPerPassMap *DIPreservationMap; + DebugInfoPerPass *DebugInfoBeforePass; enum DebugifyMode Mode; bool Strip; }; @@ -867,7 +887,7 @@ struct CheckDebugifyFunctionPass : public FunctionPass { NameOfWrappedPass, "CheckFunctionDebugify", Strip, StatsMap); return checkDebugInfoMetadata( - M, make_range(FuncIt, std::next(FuncIt)), *DIPreservationMap, + M, make_range(FuncIt, std::next(FuncIt)), *DebugInfoBeforePass, "CheckFunctionDebugify (original debuginfo)", NameOfWrappedPass, OrigDIVerifyBugsReportFilePath); } @@ -876,11 +896,11 @@ struct CheckDebugifyFunctionPass : public FunctionPass { bool Strip = false, StringRef NameOfWrappedPass = "", DebugifyStatsMap *StatsMap = nullptr, enum DebugifyMode Mode = DebugifyMode::SyntheticDebugInfo, - DebugInfoPerPassMap *DIPreservationMap = nullptr, + DebugInfoPerPass *DebugInfoBeforePass = nullptr, StringRef OrigDIVerifyBugsReportFilePath = "") : FunctionPass(ID), NameOfWrappedPass(NameOfWrappedPass), OrigDIVerifyBugsReportFilePath(OrigDIVerifyBugsReportFilePath), - StatsMap(StatsMap), DIPreservationMap(DIPreservationMap), Mode(Mode), + StatsMap(StatsMap), DebugInfoBeforePass(DebugInfoBeforePass), Mode(Mode), Strip(Strip) {} void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -893,7 +913,7 @@ private: StringRef NameOfWrappedPass; StringRef OrigDIVerifyBugsReportFilePath; DebugifyStatsMap *StatsMap; - 
DebugInfoPerPassMap *DIPreservationMap; + DebugInfoPerPass *DebugInfoBeforePass; enum DebugifyMode Mode; bool Strip; }; @@ -923,21 +943,21 @@ void llvm::exportDebugifyStats(StringRef Path, const DebugifyStatsMap &Map) { ModulePass *createDebugifyModulePass(enum DebugifyMode Mode, llvm::StringRef NameOfWrappedPass, - DebugInfoPerPassMap *DIPreservationMap) { + DebugInfoPerPass *DebugInfoBeforePass) { if (Mode == DebugifyMode::SyntheticDebugInfo) return new DebugifyModulePass(); assert(Mode == DebugifyMode::OriginalDebugInfo && "Must be original mode"); - return new DebugifyModulePass(Mode, NameOfWrappedPass, DIPreservationMap); + return new DebugifyModulePass(Mode, NameOfWrappedPass, DebugInfoBeforePass); } FunctionPass * createDebugifyFunctionPass(enum DebugifyMode Mode, llvm::StringRef NameOfWrappedPass, - DebugInfoPerPassMap *DIPreservationMap) { + DebugInfoPerPass *DebugInfoBeforePass) { if (Mode == DebugifyMode::SyntheticDebugInfo) return new DebugifyFunctionPass(); assert(Mode == DebugifyMode::OriginalDebugInfo && "Must be original mode"); - return new DebugifyFunctionPass(Mode, NameOfWrappedPass, DIPreservationMap); + return new DebugifyFunctionPass(Mode, NameOfWrappedPass, DebugInfoBeforePass); } PreservedAnalyses NewPMDebugifyPass::run(Module &M, ModuleAnalysisManager &) { @@ -948,25 +968,25 @@ PreservedAnalyses NewPMDebugifyPass::run(Module &M, ModuleAnalysisManager &) { ModulePass *createCheckDebugifyModulePass( bool Strip, StringRef NameOfWrappedPass, DebugifyStatsMap *StatsMap, - enum DebugifyMode Mode, DebugInfoPerPassMap *DIPreservationMap, + enum DebugifyMode Mode, DebugInfoPerPass *DebugInfoBeforePass, StringRef OrigDIVerifyBugsReportFilePath) { if (Mode == DebugifyMode::SyntheticDebugInfo) return new CheckDebugifyModulePass(Strip, NameOfWrappedPass, StatsMap); assert(Mode == DebugifyMode::OriginalDebugInfo && "Must be original mode"); return new CheckDebugifyModulePass(false, NameOfWrappedPass, nullptr, Mode, - DIPreservationMap, + DebugInfoBeforePass, OrigDIVerifyBugsReportFilePath); } FunctionPass *createCheckDebugifyFunctionPass( bool Strip, StringRef NameOfWrappedPass, DebugifyStatsMap *StatsMap, - enum DebugifyMode Mode, DebugInfoPerPassMap *DIPreservationMap, + enum DebugifyMode Mode, DebugInfoPerPass *DebugInfoBeforePass, StringRef OrigDIVerifyBugsReportFilePath) { if (Mode == DebugifyMode::SyntheticDebugInfo) return new CheckDebugifyFunctionPass(Strip, NameOfWrappedPass, StatsMap); assert(Mode == DebugifyMode::OriginalDebugInfo && "Must be original mode"); return new CheckDebugifyFunctionPass(false, NameOfWrappedPass, nullptr, Mode, - DIPreservationMap, + DebugInfoBeforePass, OrigDIVerifyBugsReportFilePath); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp index 5f53d794fe8a..f6f80540ad95 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp @@ -8,11 +8,10 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/Analysis/CFG.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/Type.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" using namespace llvm; /// DemoteRegToStack - This function takes a virtual register computed by an diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/Evaluator.cpp 
b/contrib/llvm-project/llvm/lib/Transforms/Utils/Evaluator.cpp index e73287c060ae..7b8d8553bac2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/Evaluator.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/Evaluator.cpp @@ -29,7 +29,6 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Operator.h" #include "llvm/IR/Type.h" #include "llvm/IR/User.h" @@ -37,7 +36,6 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include <iterator> #define DEBUG_TYPE "evaluator" @@ -219,10 +217,13 @@ Constant *Evaluator::ComputeLoadResult(Constant *P, Type *Ty) { P = cast<Constant>(P->stripAndAccumulateConstantOffsets( DL, Offset, /* AllowNonInbounds */ true)); Offset = Offset.sextOrTrunc(DL.getIndexTypeSizeInBits(P->getType())); - auto *GV = dyn_cast<GlobalVariable>(P); - if (!GV) - return nullptr; + if (auto *GV = dyn_cast<GlobalVariable>(P)) + return ComputeLoadResult(GV, Ty, Offset); + return nullptr; +} +Constant *Evaluator::ComputeLoadResult(GlobalVariable *GV, Type *Ty, + const APInt &Offset) { auto It = MutatedMemory.find(GV); if (It != MutatedMemory.end()) return It->second.read(Ty, Offset, DL); @@ -335,50 +336,6 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, BasicBlock *&NextBB, auto Res = MutatedMemory.try_emplace(GV, GV->getInitializer()); if (!Res.first->second.write(Val, Offset, DL)) return false; - } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(CurInst)) { - InstResult = ConstantExpr::get(BO->getOpcode(), - getVal(BO->getOperand(0)), - getVal(BO->getOperand(1))); - LLVM_DEBUG(dbgs() << "Found a BinaryOperator! Simplifying: " - << *InstResult << "\n"); - } else if (CmpInst *CI = dyn_cast<CmpInst>(CurInst)) { - InstResult = ConstantExpr::getCompare(CI->getPredicate(), - getVal(CI->getOperand(0)), - getVal(CI->getOperand(1))); - LLVM_DEBUG(dbgs() << "Found a CmpInst! Simplifying: " << *InstResult - << "\n"); - } else if (CastInst *CI = dyn_cast<CastInst>(CurInst)) { - InstResult = ConstantExpr::getCast(CI->getOpcode(), - getVal(CI->getOperand(0)), - CI->getType()); - LLVM_DEBUG(dbgs() << "Found a Cast! Simplifying: " << *InstResult - << "\n"); - } else if (SelectInst *SI = dyn_cast<SelectInst>(CurInst)) { - InstResult = ConstantExpr::getSelect(getVal(SI->getOperand(0)), - getVal(SI->getOperand(1)), - getVal(SI->getOperand(2))); - LLVM_DEBUG(dbgs() << "Found a Select! Simplifying: " << *InstResult - << "\n"); - } else if (auto *EVI = dyn_cast<ExtractValueInst>(CurInst)) { - InstResult = ConstantExpr::getExtractValue( - getVal(EVI->getAggregateOperand()), EVI->getIndices()); - LLVM_DEBUG(dbgs() << "Found an ExtractValueInst! Simplifying: " - << *InstResult << "\n"); - } else if (auto *IVI = dyn_cast<InsertValueInst>(CurInst)) { - InstResult = ConstantExpr::getInsertValue( - getVal(IVI->getAggregateOperand()), - getVal(IVI->getInsertedValueOperand()), IVI->getIndices()); - LLVM_DEBUG(dbgs() << "Found an InsertValueInst! Simplifying: " - << *InstResult << "\n"); - } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(CurInst)) { - Constant *P = getVal(GEP->getOperand(0)); - SmallVector<Constant*, 8> GEPOps; - for (Use &Op : llvm::drop_begin(GEP->operands())) - GEPOps.push_back(getVal(Op)); - InstResult = - ConstantExpr::getGetElementPtr(GEP->getSourceElementType(), P, GEPOps, - cast<GEPOperator>(GEP)->isInBounds()); - LLVM_DEBUG(dbgs() << "Found a GEP! 
Simplifying: " << *InstResult << "\n"); } else if (LoadInst *LI = dyn_cast<LoadInst>(CurInst)) { if (!LI->isSimple()) { LLVM_DEBUG( @@ -438,16 +395,39 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, BasicBlock *&NextBB, << "intrinsic.\n"); return false; } + + auto *LenC = dyn_cast<ConstantInt>(getVal(MSI->getLength())); + if (!LenC) { + LLVM_DEBUG(dbgs() << "Memset with unknown length.\n"); + return false; + } + Constant *Ptr = getVal(MSI->getDest()); + APInt Offset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0); + Ptr = cast<Constant>(Ptr->stripAndAccumulateConstantOffsets( + DL, Offset, /* AllowNonInbounds */ true)); + auto *GV = dyn_cast<GlobalVariable>(Ptr); + if (!GV) { + LLVM_DEBUG(dbgs() << "Memset with unknown base.\n"); + return false; + } + Constant *Val = getVal(MSI->getValue()); - Constant *DestVal = - ComputeLoadResult(getVal(Ptr), MSI->getValue()->getType()); - if (Val->isNullValue() && DestVal && DestVal->isNullValue()) { - // This memset is a no-op. - LLVM_DEBUG(dbgs() << "Ignoring no-op memset.\n"); - ++CurInst; - continue; + APInt Len = LenC->getValue(); + while (Len != 0) { + Constant *DestVal = ComputeLoadResult(GV, Val->getType(), Offset); + if (DestVal != Val) { + LLVM_DEBUG(dbgs() << "Memset is not a no-op at offset " + << Offset << " of " << *GV << ".\n"); + return false; + } + ++Offset; + --Len; } + + LLVM_DEBUG(dbgs() << "Ignoring no-op memset.\n"); + ++CurInst; + continue; } if (II->isLifetimeStartOrEnd()) { @@ -602,11 +582,16 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, BasicBlock *&NextBB, LLVM_DEBUG(dbgs() << "Successfully evaluated block.\n"); return true; } else { - // Did not know how to evaluate this! - LLVM_DEBUG( - dbgs() << "Failed to evaluate block due to unhandled instruction." - "\n"); - return false; + SmallVector<Constant *> Ops; + for (Value *Op : CurInst->operands()) + Ops.push_back(getVal(Op)); + InstResult = ConstantFoldInstOperands(&*CurInst, Ops, DL, TLI); + if (!InstResult) { + LLVM_DEBUG(dbgs() << "Cannot fold instruction: " << *CurInst << "\n"); + return false; + } + LLVM_DEBUG(dbgs() << "Folded instruction " << *CurInst << " to " + << *InstResult << "\n"); } if (!CurInst->use_empty()) { @@ -631,6 +616,8 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, BasicBlock *&NextBB, /// function. bool Evaluator::EvaluateFunction(Function *F, Constant *&RetVal, const SmallVectorImpl<Constant*> &ActualArgs) { + assert(ActualArgs.size() == F->arg_size() && "wrong number of arguments"); + // Check to see if this function is already executing (recursion). If so, // bail out. TODO: we might want to accept limited recursion. if (is_contained(CallStack, F)) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/FixIrreducible.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/FixIrreducible.cpp index 8de3ce876bab..24539bd231c6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/FixIrreducible.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/FixIrreducible.cpp @@ -68,6 +68,7 @@ #include "llvm/Transforms/Utils/FixIrreducible.h" #include "llvm/ADT/SCCIterator.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" @@ -137,10 +138,18 @@ static void reconnectChildLoops(LoopInfo &LI, Loop *ParentLoop, Loop *NewLoop, // not be necessary if we can retain such backedges. 
if (Headers.count(Child->getHeader())) { for (auto BB : Child->blocks()) { + if (LI.getLoopFor(BB) != Child) + continue; LI.changeLoopFor(BB, NewLoop); LLVM_DEBUG(dbgs() << "moved block from child: " << BB->getName() << "\n"); } + std::vector<Loop *> GrandChildLoops; + std::swap(GrandChildLoops, Child->getSubLoopsVector()); + for (auto GrandChildLoop : GrandChildLoops) { + GrandChildLoop->setParentLoop(nullptr); + NewLoop->addChildLoop(GrandChildLoop); + } LI.destroy(Child); LLVM_DEBUG(dbgs() << "subsumed child loop (common header)\n"); continue; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp index 2946c0018c31..193806d9cc87 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp @@ -12,8 +12,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/FunctionImportUtils.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/InstIterator.h" using namespace llvm; /// Checks if we should import SGV as a definition, otherwise import as a diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/GlobalStatus.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/GlobalStatus.cpp index c1c5f5cc879f..c5aded3c45f4 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/GlobalStatus.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/GlobalStatus.cpp @@ -38,22 +38,26 @@ static AtomicOrdering strongerOrdering(AtomicOrdering X, AtomicOrdering Y) { } /// It is safe to destroy a constant iff it is only used by constants itself. -/// Note that constants cannot be cyclic, so this test is pretty easy to -/// implement recursively. -/// +/// Note that while constants cannot be cyclic, they can be tree-like, so we +/// should keep a visited set to avoid exponential runtime. 
bool llvm::isSafeToDestroyConstant(const Constant *C) { - if (isa<GlobalValue>(C)) - return false; - - if (isa<ConstantData>(C)) - return false; + SmallVector<const Constant *, 8> Worklist; + SmallPtrSet<const Constant *, 8> Visited; + Worklist.push_back(C); + while (!Worklist.empty()) { + const Constant *C = Worklist.pop_back_val(); + if (!Visited.insert(C).second) + continue; + if (isa<GlobalValue>(C) || isa<ConstantData>(C)) + return false; - for (const User *U : C->users()) - if (const Constant *CU = dyn_cast<Constant>(U)) { - if (!isSafeToDestroyConstant(CU)) + for (const User *U : C->users()) { + if (const Constant *CU = dyn_cast<Constant>(U)) + Worklist.push_back(CU); + else return false; - } else - return false; + } + } return true; } @@ -100,6 +104,8 @@ static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, if (SI->isVolatile()) return true; + ++GS.NumStores; + GS.Ordering = strongerOrdering(GS.Ordering, SI->getOrdering()); // If this is a direct store to the global (i.e., the global is a scalar diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp index 047bf5569ded..55bcb6f3b121 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/InjectTLIMappings.cpp @@ -19,7 +19,6 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/InstIterator.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/InlineFunction.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/InlineFunction.cpp index 923bcc781e47..2fb00f95b749 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -37,7 +37,6 @@ #include "llvm/IR/CFG.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" -#include "llvm/IR/DIBuilder.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DebugInfoMetadata.h" @@ -85,7 +84,7 @@ EnableNoAliasConversion("enable-noalias-to-md-conversion", cl::init(true), static cl::opt<bool> UseNoAliasIntrinsic("use-noalias-intrinsic-during-inlining", cl::Hidden, - cl::ZeroOrMore, cl::init(true), + cl::init(true), cl::desc("Use the llvm.experimental.noalias.scope.decl " "intrinsic during inlining.")); @@ -1044,12 +1043,10 @@ static void AddAliasScopeMetadata(CallBase &CB, ValueToValueMapTy &VMap, } for (Value *Arg : Call->args()) { - // We need to check the underlying objects of all arguments, not just - // the pointer arguments, because we might be passing pointers as - // integers, etc. - // However, if we know that the call only accesses pointer arguments, - // then we only need to check the pointer arguments. - if (IsArgMemOnlyCall && !Arg->getType()->isPointerTy()) + // Only care about pointer arguments. If a noalias argument is + // accessed through a non-pointer argument, it must be captured + // first (e.g. via ptrtoint), and we protect against captures below. + if (!Arg->getType()->isPointerTy()) continue; PtrArgs.push_back(Arg); @@ -1080,7 +1077,8 @@ static void AddAliasScopeMetadata(CallBase &CB, ValueToValueMapTy &VMap, // Figure out if we're derived from anything that is not a noalias // argument. 
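`isSafeToDestroyConstant` above trades recursion over users for an explicit worklist plus a visited set: constants are not cyclic, but they can share subexpressions, so the old recursive walk could revisit the same node an exponential number of times. The same conversion on a generic user graph, with an invented node type in place of `Constant`:

```cpp
#include <cassert>
#include <unordered_set>
#include <vector>

// Invented stand-in for a constant plus its users.
struct Node {
  bool Blocking = false;      // e.g. "is a GlobalValue / ConstantData"
  std::vector<Node *> Users;  // user edges; sharing makes this a DAG
};

// Iterative version of "this node and all transitive users are safe".
static bool allUsersAreSafe(Node *Root) {
  std::vector<Node *> Worklist{Root};
  std::unordered_set<Node *> Visited;
  while (!Worklist.empty()) {
    Node *N = Worklist.back();
    Worklist.pop_back();
    if (!Visited.insert(N).second)
      continue; // shared subexpression, already checked once
    if (N->Blocking)
      return false;
    for (Node *U : N->Users)
      Worklist.push_back(U);
  }
  return true;
}

int main() {
  Node Shared, A, B, Root;
  Root.Users = {&A, &B};
  A.Users = {&Shared};
  B.Users = {&Shared}; // diamond: Shared is reachable twice, visited once
  assert(allUsersAreSafe(&Root));
  Shared.Blocking = true;
  assert(!allUsersAreSafe(&Root));
}
```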
- bool CanDeriveViaCapture = false, UsesAliasingPtr = false; + bool RequiresNoCaptureBefore = false, UsesAliasingPtr = false, + UsesUnknownObject = false; for (const Value *V : ObjSet) { // Is this value a constant that cannot be derived from any pointer // value (we need to exclude constant expressions, for example, that @@ -1101,19 +1099,28 @@ static void AddAliasScopeMetadata(CallBase &CB, ValueToValueMapTy &VMap, UsesAliasingPtr = true; } - // If this is not some identified function-local object (which cannot - // directly alias a noalias argument), or some other argument (which, - // by definition, also cannot alias a noalias argument), then we could - // alias a noalias argument that has been captured). - if (!isa<Argument>(V) && - !isIdentifiedFunctionLocal(const_cast<Value*>(V))) - CanDeriveViaCapture = true; + if (isEscapeSource(V)) { + // An escape source can only alias with a noalias argument if it has + // been captured beforehand. + RequiresNoCaptureBefore = true; + } else if (!isa<Argument>(V) && !isIdentifiedObject(V)) { + // If this is neither an escape source, nor some identified object + // (which cannot directly alias a noalias argument), nor some other + // argument (which, by definition, also cannot alias a noalias + // argument), conservatively do not make any assumptions. + UsesUnknownObject = true; + } } + // Nothing we can do if the used underlying object cannot be reliably + // determined. + if (UsesUnknownObject) + continue; + // A function call can always get captured noalias pointers (via other // parameters, globals, etc.). if (IsFuncCall && !IsArgMemOnlyCall) - CanDeriveViaCapture = true; + RequiresNoCaptureBefore = true; // First, we want to figure out all of the sets with which we definitely // don't alias. Iterate over all noalias set, and add those for which: @@ -1124,16 +1131,16 @@ static void AddAliasScopeMetadata(CallBase &CB, ValueToValueMapTy &VMap, // noalias arguments via other noalias arguments or globals, and so we // must always check for prior capture. for (const Argument *A : NoAliasArgs) { - if (!ObjSet.count(A) && (!CanDeriveViaCapture || - // It might be tempting to skip the - // PointerMayBeCapturedBefore check if - // A->hasNoCaptureAttr() is true, but this is - // incorrect because nocapture only guarantees - // that no copies outlive the function, not - // that the value cannot be locally captured. - !PointerMayBeCapturedBefore(A, - /* ReturnCaptures */ false, - /* StoreCaptures */ false, I, &DT))) + if (ObjSet.contains(A)) + continue; // May be based on a noalias argument. + + // It might be tempting to skip the PointerMayBeCapturedBefore check if + // A->hasNoCaptureAttr() is true, but this is incorrect because + // nocapture only guarantees that no copies outlive the function, not + // that the value cannot be locally captured. + if (!RequiresNoCaptureBefore || + !PointerMayBeCapturedBefore(A, /* ReturnCaptures */ false, + /* StoreCaptures */ false, I, &DT)) NoAliases.push_back(NewScopes[A]); } @@ -1422,7 +1429,8 @@ static Value *HandleByValArgument(Type *ByValType, Value *Arg, // If the byval had an alignment specified, we *must* use at least that // alignment, as it is required by the byval argument (and uses of the // pointer inside the callee). 
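The rewritten loop above replaces the single `CanDeriveViaCapture` flag with a three-way decision per underlying object: escape sources require a prior-capture check, completely unknown objects forbid any conclusion, and arguments or identified function-local objects are harmless on their own. A rough standalone model of that classification follows; the enumerators and struct are invented and only mirror the shape of the logic, not LLVM's real analysis.

```cpp
#include <cassert>

// Invented kinds of "underlying object" a call argument may be based on.
enum class ObjectKind { OtherArgument, IdentifiedLocal, EscapeSource, Unknown };

struct Decision {
  bool SkipEntirely = false;            // UsesUnknownObject in the patch
  bool RequiresNoCaptureBefore = false; // must prove no earlier capture
};

static Decision classify(ObjectKind K) {
  Decision D;
  switch (K) {
  case ObjectKind::EscapeSource:
    D.RequiresNoCaptureBefore = true; // may alias only via a prior capture
    break;
  case ObjectKind::Unknown:
    D.SkipEntirely = true;            // conservatively assume nothing
    break;
  case ObjectKind::OtherArgument:
  case ObjectKind::IdentifiedLocal:
    break;                            // cannot alias a noalias argument itself
  }
  return D;
}

int main() {
  assert(classify(ObjectKind::EscapeSource).RequiresNoCaptureBefore);
  assert(classify(ObjectKind::Unknown).SkipEntirely);
  assert(!classify(ObjectKind::IdentifiedLocal).RequiresNoCaptureBefore);
}
```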
- Alignment = max(Alignment, MaybeAlign(ByValAlignment)); + if (ByValAlignment > 0) + Alignment = std::max(Alignment, Align(ByValAlignment)); Value *NewAlloca = new AllocaInst(ByValType, DL.getAllocaAddrSpace(), nullptr, Alignment, @@ -1601,7 +1609,7 @@ static void updateCallProfile(Function *Callee, const ValueToValueMapTy &VMap, return; auto CallSiteCount = PSI ? PSI->getProfileCount(TheCall, CallerBFI) : None; int64_t CallCount = - std::min(CallSiteCount.getValueOr(0), CalleeEntryCount.getCount()); + std::min(CallSiteCount.value_or(0), CalleeEntryCount.getCount()); updateProfileCallee(Callee, -CallCount, &VMap); } @@ -1609,7 +1617,7 @@ void llvm::updateProfileCallee( Function *Callee, int64_t EntryDelta, const ValueMap<const Value *, WeakTrackingVH> *VMap) { auto CalleeCount = Callee->getEntryCount(); - if (!CalleeCount.hasValue()) + if (!CalleeCount) return; const uint64_t PriorEntryCount = CalleeCount->getCount(); @@ -1789,6 +1797,13 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, BasicBlock *OrigBB = CB.getParent(); Function *Caller = OrigBB->getParent(); + // Do not inline strictfp function into non-strictfp one. It would require + // conversion of all FP operations in host function to constrained intrinsics. + if (CalledFunc->getAttributes().hasFnAttr(Attribute::StrictFP) && + !Caller->getAttributes().hasFnAttr(Attribute::StrictFP)) { + return InlineResult::failure("incompatible strictfp attributes"); + } + // GC poses two hazards to inlining, which only occur when the callee has GC: // 1. If the caller has no GC, then the callee's GC must be propagated to the // caller. @@ -2644,7 +2659,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, AssumptionCache *AC = IFI.GetAssumptionCache ? 
&IFI.GetAssumptionCache(*Caller) : nullptr; auto &DL = Caller->getParent()->getDataLayout(); - if (Value *V = SimplifyInstruction(PHI, {DL, nullptr, nullptr, AC})) { + if (Value *V = simplifyInstruction(PHI, {DL, nullptr, nullptr, AC})) { PHI->replaceAllUsesWith(V); PHI->eraseFromParent(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/IntegerDivision.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/IntegerDivision.cpp index 9082049c82da..47ab30f03d14 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/IntegerDivision.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/IntegerDivision.cpp @@ -18,7 +18,6 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" -#include <utility> using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LCSSA.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LCSSA.cpp index 72b864dc3e48..84d377d835f3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LCSSA.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LCSSA.cpp @@ -33,14 +33,13 @@ #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" -#include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp index 6958a89f5be6..6e87da9fb168 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp @@ -30,14 +30,12 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/Local.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/Local.cpp index 1c350a2585d0..b203259db1c6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/Local.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/Local.cpp @@ -29,7 +29,6 @@ #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -63,9 +62,7 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/IR/PseudoProbe.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" @@ -80,7 +77,6 @@ #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <cassert> -#include <climits> #include <cstdint> #include 
<iterator> #include <map> @@ -489,7 +485,7 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, if (auto *FPI = dyn_cast<ConstrainedFPIntrinsic>(I)) { Optional<fp::ExceptionBehavior> ExBehavior = FPI->getExceptionBehavior(); - return ExBehavior.getValue() != fp::ebStrict; + return *ExBehavior != fp::ebStrict; } } @@ -504,15 +500,12 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, if (isMathLibCallNoop(Call, TLI)) return true; - // To express possible interaction with floating point environment constrained - // intrinsics are described as if they access memory. So they look like having - // side effect but actually do not have it unless they raise floating point - // exception. If FP exceptions are ignored, the intrinsic may be deleted. - if (auto *CI = dyn_cast<ConstrainedFPIntrinsic>(I)) { - Optional<fp::ExceptionBehavior> EB = CI->getExceptionBehavior(); - if (!EB || *EB == fp::ExceptionBehavior::ebIgnore) - return true; - } + // Non-volatile atomic loads from constants can be removed. + if (auto *LI = dyn_cast<LoadInst>(I)) + if (auto *GV = dyn_cast<GlobalVariable>( + LI->getPointerOperand()->stripPointerCasts())) + if (!LI->isVolatile() && GV->isConstant()) + return true; return false; } @@ -682,7 +675,7 @@ simplifyAndDCEInstruction(Instruction *I, return true; } - if (Value *SimpleV = SimplifyInstruction(I, DL)) { + if (Value *SimpleV = simplifyInstruction(I, DL)) { // Add the users to the worklist. CAREFUL: an instruction can use itself, // in the case of a phi node. for (User *U : I->users()) { @@ -1133,7 +1126,7 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, // If there is more than one pred of succ, and there are PHI nodes in // the successor, then we need to add incoming edges for the PHI nodes // - const PredBlockVector BBPreds(pred_begin(BB), pred_end(BB)); + const PredBlockVector BBPreds(predecessors(BB)); // Loop over all of the PHI nodes in the successor of BB. for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { @@ -1393,7 +1386,7 @@ Align llvm::getOrEnforceKnownAlignment(Value *V, MaybeAlign PrefAlign, static bool PhiHasDebugValue(DILocalVariable *DIVar, DIExpression *DIExpr, PHINode *APN) { - // Since we can't guarantee that the original dbg.declare instrinsic + // Since we can't guarantee that the original dbg.declare intrinsic // is removed by LowerDbgDeclare(), we need to make sure that we are // not inserting the same dbg.value intrinsic over and over. SmallVector<DbgValueInst *, 1> DbgValues; @@ -1472,7 +1465,7 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to dbg.value: " << *DII << '\n'); // For now, when there is a store to parts of the variable (but we do not - // know which part) we insert an dbg.value instrinsic to indicate that we + // know which part) we insert an dbg.value intrinsic to indicate that we // know nothing about the variable's content. 
DV = UndefValue::get(DV->getType()); Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, SI); @@ -2240,6 +2233,7 @@ BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, II->setDebugLoc(CI->getDebugLoc()); II->setCallingConv(CI->getCallingConv()); II->setAttributes(CI->getAttributes()); + II->setMetadata(LLVMContext::MD_prof, CI->getMetadata(LLVMContext::MD_prof)); if (DTU) DTU->applyUpdates({{DominatorTree::Insert, BB, UnwindEdge}}); @@ -2349,19 +2343,42 @@ static bool markAliveBlocks(Function &F, isa<UndefValue>(Callee)) { changeToUnreachable(II, false, DTU); Changed = true; - } else if (II->doesNotThrow() && canSimplifyInvokeNoUnwind(&F)) { - if (II->use_empty() && !II->mayHaveSideEffects()) { - // jump to the normal destination branch. - BasicBlock *NormalDestBB = II->getNormalDest(); - BasicBlock *UnwindDestBB = II->getUnwindDest(); - BranchInst::Create(NormalDestBB, II); - UnwindDestBB->removePredecessor(II->getParent()); - II->eraseFromParent(); + } else { + if (II->doesNotReturn() && + !isa<UnreachableInst>(II->getNormalDest()->front())) { + // If we found an invoke of a no-return function, + // create a new empty basic block with an `unreachable` terminator, + // and set it as the normal destination for the invoke, + // unless that is already the case. + // Note that the original normal destination could have other uses. + BasicBlock *OrigNormalDest = II->getNormalDest(); + OrigNormalDest->removePredecessor(II->getParent()); + LLVMContext &Ctx = II->getContext(); + BasicBlock *UnreachableNormalDest = BasicBlock::Create( + Ctx, OrigNormalDest->getName() + ".unreachable", + II->getFunction(), OrigNormalDest); + new UnreachableInst(Ctx, UnreachableNormalDest); + II->setNormalDest(UnreachableNormalDest); if (DTU) - DTU->applyUpdates({{DominatorTree::Delete, BB, UnwindDestBB}}); - } else - changeToCall(II, DTU); - Changed = true; + DTU->applyUpdates( + {{DominatorTree::Delete, BB, OrigNormalDest}, + {DominatorTree::Insert, BB, UnreachableNormalDest}}); + Changed = true; + } + if (II->doesNotThrow() && canSimplifyInvokeNoUnwind(&F)) { + if (II->use_empty() && !II->mayHaveSideEffects()) { + // jump to the normal destination branch. + BasicBlock *NormalDestBB = II->getNormalDest(); + BasicBlock *UnwindDestBB = II->getUnwindDest(); + BranchInst::Create(NormalDestBB, II); + UnwindDestBB->removePredecessor(II->getParent()); + II->eraseFromParent(); + if (DTU) + DTU->applyUpdates({{DominatorTree::Delete, BB, UnwindDestBB}}); + } else + changeToCall(II, DTU); + Changed = true; + } } } else if (auto *CatchSwitch = dyn_cast<CatchSwitchInst>(Terminator)) { // Remove catchpads which cannot be reached. 
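`markAliveBlocks` gains a case for an invoke whose callee is known not to return: instead of touching the original normal destination (which may have other predecessors), it redirects only this invoke's normal edge to a fresh block containing nothing but `unreachable`. A minimal standalone CFG sketch of that edge rewrite; the block/invoke structs are invented and carry just enough state to show the update.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Just enough CFG to demonstrate the rewrite; not LLVM's IR classes.
struct Block {
  std::string Name;
  bool IsUnreachable = false; // block consists only of 'unreachable'
  std::vector<Block *> Preds;
};

struct Invoke {
  Block *Parent = nullptr;
  Block *NormalDest = nullptr;
  bool CalleeDoesNotReturn = false;
};

// Returns the freshly created dead-end block, or null if nothing was done.
static std::unique_ptr<Block> rewriteNoReturnInvoke(Invoke &II) {
  if (!II.CalleeDoesNotReturn || II.NormalDest->IsUnreachable)
    return nullptr;
  // Keep the old destination (it may have other predecessors); drop only the
  // edge coming from this invoke's parent block.
  auto &Preds = II.NormalDest->Preds;
  auto It = std::find(Preds.begin(), Preds.end(), II.Parent);
  if (It != Preds.end())
    Preds.erase(It);
  auto DeadEnd = std::make_unique<Block>();
  DeadEnd->Name = II.NormalDest->Name + ".unreachable";
  DeadEnd->IsUnreachable = true;
  DeadEnd->Preds.push_back(II.Parent);
  II.NormalDest = DeadEnd.get();
  return DeadEnd;
}

int main() {
  Block Entry{"entry"}, Cont{"cont"};
  Cont.Preds = {&Entry};
  Invoke II{&Entry, &Cont, /*CalleeDoesNotReturn=*/true};
  auto NewBB = rewriteNoReturnInvoke(II);
  assert(NewBB && II.NormalDest == NewBB.get() && Cont.Preds.empty());
}
```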
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopPeel.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopPeel.cpp index 5b66da1e7082..f093fea19c4d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopPeel.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopPeel.cpp @@ -28,7 +28,6 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" -#include "llvm/IR/Metadata.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" @@ -38,12 +37,10 @@ #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" -#include "llvm/Transforms/Utils/UnrollLoop.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <cassert> #include <cstdint> -#include <limits> using namespace llvm; using namespace llvm::PatternMatch; @@ -389,6 +386,10 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, if (!PP.AllowPeeling) return; + // Check that we can peel at least one iteration. + if (2 * LoopSize > Threshold) + return; + unsigned AlreadyPeeled = 0; if (auto Peeled = getOptionalIntLoopAttribute(L, PeeledCountMetaData)) AlreadyPeeled = *Peeled; @@ -401,47 +402,45 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, // which every Phi is guaranteed to become an invariant, and try to peel the // maximum number of iterations among these values, thus turning all those // Phis into invariants. - // First, check that we can peel at least one iteration. - if (2 * LoopSize <= Threshold && UnrollPeelMaxCount > 0) { - // Store the pre-calculated values here. - SmallDenseMap<PHINode *, Optional<unsigned> > IterationsToInvariance; - // Now go through all Phis to calculate their the number of iterations they - // need to become invariants. - // Start the max computation with the PP.PeelCount value set by the target - // in TTI.getPeelingPreferences or by the flag -unroll-peel-count. - unsigned DesiredPeelCount = TargetPeelCount; - BasicBlock *BackEdge = L->getLoopLatch(); - assert(BackEdge && "Loop is not in simplified form?"); - for (auto BI = L->getHeader()->begin(); isa<PHINode>(&*BI); ++BI) { - PHINode *Phi = cast<PHINode>(&*BI); - auto ToInvariance = calculateIterationsToInvariance( - Phi, L, BackEdge, IterationsToInvariance); - if (ToInvariance) - DesiredPeelCount = std::max(DesiredPeelCount, *ToInvariance); - } - // Pay respect to limitations implied by loop size and the max peel count. - unsigned MaxPeelCount = UnrollPeelMaxCount; - MaxPeelCount = std::min(MaxPeelCount, Threshold / LoopSize - 1); - - DesiredPeelCount = std::max(DesiredPeelCount, - countToEliminateCompares(*L, MaxPeelCount, SE)); - - if (DesiredPeelCount == 0) - DesiredPeelCount = peelToTurnInvariantLoadsDerefencebale(*L, DT); - - if (DesiredPeelCount > 0) { - DesiredPeelCount = std::min(DesiredPeelCount, MaxPeelCount); - // Consider max peel count limitation. - assert(DesiredPeelCount > 0 && "Wrong loop size estimation?"); - if (DesiredPeelCount + AlreadyPeeled <= UnrollPeelMaxCount) { - LLVM_DEBUG(dbgs() << "Peel " << DesiredPeelCount - << " iteration(s) to turn" - << " some Phis into invariants.\n"); - PP.PeelCount = DesiredPeelCount; - PP.PeelProfiledIterations = false; - return; - } + // Store the pre-calculated values here. + SmallDenseMap<PHINode *, Optional<unsigned>> IterationsToInvariance; + // Now go through all Phis to calculate their the number of iterations they + // need to become invariants. 
+ // Start the max computation with the PP.PeelCount value set by the target + // in TTI.getPeelingPreferences or by the flag -unroll-peel-count. + unsigned DesiredPeelCount = TargetPeelCount; + BasicBlock *BackEdge = L->getLoopLatch(); + assert(BackEdge && "Loop is not in simplified form?"); + for (auto BI = L->getHeader()->begin(); isa<PHINode>(&*BI); ++BI) { + PHINode *Phi = cast<PHINode>(&*BI); + auto ToInvariance = calculateIterationsToInvariance(Phi, L, BackEdge, + IterationsToInvariance); + if (ToInvariance) + DesiredPeelCount = std::max(DesiredPeelCount, *ToInvariance); + } + + // Pay respect to limitations implied by loop size and the max peel count. + unsigned MaxPeelCount = UnrollPeelMaxCount; + MaxPeelCount = std::min(MaxPeelCount, Threshold / LoopSize - 1); + + DesiredPeelCount = std::max(DesiredPeelCount, + countToEliminateCompares(*L, MaxPeelCount, SE)); + + if (DesiredPeelCount == 0) + DesiredPeelCount = peelToTurnInvariantLoadsDerefencebale(*L, DT); + + if (DesiredPeelCount > 0) { + DesiredPeelCount = std::min(DesiredPeelCount, MaxPeelCount); + // Consider max peel count limitation. + assert(DesiredPeelCount > 0 && "Wrong loop size estimation?"); + if (DesiredPeelCount + AlreadyPeeled <= UnrollPeelMaxCount) { + LLVM_DEBUG(dbgs() << "Peel " << DesiredPeelCount + << " iteration(s) to turn" + << " some Phis into invariants.\n"); + PP.PeelCount = DesiredPeelCount; + PP.PeelProfiledIterations = false; + return; } } @@ -461,27 +460,26 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, if (L->getHeader()->getParent()->hasProfileData()) { if (violatesLegacyMultiExitLoopCheck(L)) return; - Optional<unsigned> PeelCount = getLoopEstimatedTripCount(L); - if (!PeelCount) + Optional<unsigned> EstimatedTripCount = getLoopEstimatedTripCount(L); + if (!EstimatedTripCount) return; - LLVM_DEBUG(dbgs() << "Profile-based estimated trip count is " << *PeelCount - << "\n"); + LLVM_DEBUG(dbgs() << "Profile-based estimated trip count is " + << *EstimatedTripCount << "\n"); - if (*PeelCount) { - if ((*PeelCount + AlreadyPeeled <= UnrollPeelMaxCount) && - (LoopSize * (*PeelCount + 1) <= Threshold)) { - LLVM_DEBUG(dbgs() << "Peeling first " << *PeelCount - << " iterations.\n"); - PP.PeelCount = *PeelCount; + if (*EstimatedTripCount) { + if (*EstimatedTripCount + AlreadyPeeled <= MaxPeelCount) { + unsigned PeelCount = *EstimatedTripCount; + LLVM_DEBUG(dbgs() << "Peeling first " << PeelCount << " iterations.\n"); + PP.PeelCount = PeelCount; return; } - LLVM_DEBUG(dbgs() << "Requested peel count: " << *PeelCount << "\n"); LLVM_DEBUG(dbgs() << "Already peel count: " << AlreadyPeeled << "\n"); LLVM_DEBUG(dbgs() << "Max peel count: " << UnrollPeelMaxCount << "\n"); - LLVM_DEBUG(dbgs() << "Peel cost: " << LoopSize * (*PeelCount + 1) - << "\n"); + LLVM_DEBUG(dbgs() << "Loop cost: " << LoopSize << "\n"); LLVM_DEBUG(dbgs() << "Max peel cost: " << Threshold << "\n"); + LLVM_DEBUG(dbgs() << "Max peel count by cost: " + << (Threshold / LoopSize - 1) << "\n"); } } } @@ -579,7 +577,8 @@ static void cloneLoopBlocks( SmallVectorImpl<std::pair<BasicBlock *, BasicBlock *>> &ExitEdges, SmallVectorImpl<BasicBlock *> &NewBlocks, LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap, ValueToValueMapTy &LVMap, DominatorTree *DT, - LoopInfo *LI, ArrayRef<MDNode *> LoopLocalNoAliasDeclScopes) { + LoopInfo *LI, ArrayRef<MDNode *> LoopLocalNoAliasDeclScopes, + ScalarEvolution &SE) { BasicBlock *Header = L->getHeader(); BasicBlock *Latch = L->getLoopLatch(); BasicBlock *PreHeader = L->getLoopPreheader(); @@ -685,6 +684,7 @@ 
static void cloneLoopBlocks( if (LatchInst && L->contains(LatchInst)) LatchVal = VMap[LatchVal]; PHI.addIncoming(LatchVal, cast<BasicBlock>(VMap[Edge.first])); + SE.forgetValue(&PHI); } // LastValueMap is updated with the values for the current loop @@ -719,9 +719,9 @@ TargetTransformInfo::PeelingPreferences llvm::gatherPeelingPreferences( } // User specifed values provided by argument. - if (UserAllowPeeling.hasValue()) + if (UserAllowPeeling) PP.AllowPeeling = *UserAllowPeeling; - if (UserAllowProfileBasedPeeling.hasValue()) + if (UserAllowProfileBasedPeeling) PP.PeelProfiledIterations = *UserAllowProfileBasedPeeling; return PP; @@ -851,7 +851,7 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, cloneLoopBlocks(L, Iter, InsertTop, InsertBot, ExitEdges, NewBlocks, LoopBlocks, VMap, LVMap, &DT, LI, - LoopLocalNoAliasDeclScopes); + LoopLocalNoAliasDeclScopes, *SE); // Remap to use values from the current iteration instead of the // previous one. @@ -907,8 +907,10 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, // We modified the loop, update SE. SE->forgetTopmostLoop(L); +#ifdef EXPENSIVE_CHECKS // Finally DomtTree must be correct. assert(DT.verify(DominatorTree::VerificationLevel::Fast)); +#endif // FIXME: Incrementally update loop-simplify simplifyLoop(L, &DT, LI, SE, AC, nullptr, PreserveLCSSA); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp index c66fd7bb0588..0f33559c7e70 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp @@ -13,31 +13,24 @@ #include "llvm/Transforms/Utils/LoopRotationUtils.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/DomTreeUpdater.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" -#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/Function.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include "llvm/Transforms/Utils/ValueMapper.h" using namespace llvm; @@ -317,7 +310,13 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { L->dump()); return Rotated; } - if (Metrics.NumInsts > MaxHeaderSize) { + if (!Metrics.NumInsts.isValid()) { + LLVM_DEBUG(dbgs() << "LoopRotation: NOT rotating - contains instructions" + " with invalid cost: "; + L->dump()); + return Rotated; + } + if (*Metrics.NumInsts.getValue() > MaxHeaderSize) { LLVM_DEBUG(dbgs() << "LoopRotation: NOT rotating - contains " << Metrics.NumInsts << " instructions, which is more than the threshold (" @@ -446,7 +445,7 @@ bool 
LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // With the operands remapped, see if the instruction constant folds or is // otherwise simplifyable. This commonly occurs because the entry from PHI // nodes allows icmps and other instructions to fold. - Value *V = SimplifyInstruction(C, SQ); + Value *V = simplifyInstruction(C, SQ); if (V && LI->replacementPreservesLCSSAForm(C, V)) { // If so, then delete the temporary instruction and stick the folded value // in the map. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopSimplify.cpp index 67311ab4cd02..55d5c733733b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopSimplify.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopSimplify.cpp @@ -40,8 +40,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/LoopSimplify.h" -#include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -59,14 +57,11 @@ #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" #include "llvm/InitializePasses.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -181,7 +176,7 @@ static PHINode *findPHIToPartitionLoops(Loop *L, DominatorTree *DT, for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ) { PHINode *PN = cast<PHINode>(I); ++I; - if (Value *V = SimplifyInstruction(PN, {DL, nullptr, DT, AC})) { + if (Value *V = simplifyInstruction(PN, {DL, nullptr, DT, AC})) { // This is a degenerate PHI already, don't modify it! 
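Both call sites above use the renamed simplifyInstruction to fold degenerate header PHIs. A self-contained sketch of that pattern (the helper name is illustrative; the real callers additionally guard with replacementPreservesLCSSAForm and notify ScalarEvolution):

#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Fold away header PHIs that simplify to an already-available value.
static void foldDegeneratePhis(BasicBlock *Header, const SimplifyQuery &SQ) {
  for (PHINode &PN : make_early_inc_range(Header->phis())) {
    if (Value *V = simplifyInstruction(&PN, SQ)) {
      PN.replaceAllUsesWith(V);
      PN.eraseFromParent();
    }
  }
}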
PN->replaceAllUsesWith(V); PN->eraseFromParent(); @@ -602,7 +597,7 @@ ReprocessLoop: PHINode *PN; for (BasicBlock::iterator I = L->getHeader()->begin(); (PN = dyn_cast<PHINode>(I++)); ) - if (Value *V = SimplifyInstruction(PN, {DL, nullptr, DT, AC})) { + if (Value *V = simplifyInstruction(PN, {DL, nullptr, DT, AC})) { if (SE) SE->forgetValue(PN); if (!PreserveLCSSA || LI->replacementPreservesLCSSAForm(PN, V)) { PN->replaceAllUsesWith(V); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnroll.cpp index 9ca1f4f44b97..1be1082002fc 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -236,7 +236,7 @@ void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI, SmallVector<WeakTrackingVH, 16> DeadInsts; for (BasicBlock *BB : L->getBlocks()) { for (Instruction &Inst : llvm::make_early_inc_range(*BB)) { - if (Value *V = SimplifyInstruction(&Inst, {DL, nullptr, DT, AC})) + if (Value *V = simplifyInstruction(&Inst, {DL, nullptr, DT, AC})) if (LI->replacementPreservesLCSSAForm(&Inst, V)) Inst.replaceAllUsesWith(V); if (isInstructionTriviallyDead(&Inst)) @@ -513,7 +513,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, if (const DILocation *DIL = I.getDebugLoc()) { auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(ULO.Count); if (NewDIL) - I.setDebugLoc(NewDIL.getValue()); + I.setDebugLoc(*NewDIL); else LLVM_DEBUG(dbgs() << "Failed to create new discriminator: " diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp index 6efaa012aeca..96485d15c75b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp @@ -15,7 +15,6 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -39,7 +38,6 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" @@ -358,7 +356,7 @@ llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount, if (const DILocation *DIL = I.getDebugLoc()) { auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(Count); if (NewDIL) - I.setDebugLoc(NewDIL.getValue()); + I.setDebugLoc(*NewDIL); else LLVM_DEBUG(dbgs() << "Failed to create new discriminator: " diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index bb719a499a4c..cd3b6c1a095a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -20,20 +20,19 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Dominators.h" #include 
"llvm/IR/MDBuilder.h" -#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" @@ -74,7 +73,8 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, BasicBlock *OriginalLoopLatchExit, BasicBlock *PreHeader, BasicBlock *NewPreHeader, ValueToValueMapTy &VMap, DominatorTree *DT, - LoopInfo *LI, bool PreserveLCSSA) { + LoopInfo *LI, bool PreserveLCSSA, + ScalarEvolution &SE) { // Loop structure should be the following: // Preheader // PrologHeader @@ -134,6 +134,7 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, PN.setIncomingValueForBlock(NewPreHeader, NewPN); else PN.addIncoming(NewPN, PrologExit); + SE.forgetValue(&PN); } } @@ -192,7 +193,8 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, BasicBlock *Exit, BasicBlock *PreHeader, BasicBlock *EpilogPreHeader, BasicBlock *NewPreHeader, ValueToValueMapTy &VMap, DominatorTree *DT, - LoopInfo *LI, bool PreserveLCSSA) { + LoopInfo *LI, bool PreserveLCSSA, + ScalarEvolution &SE) { BasicBlock *Latch = L->getLoopLatch(); assert(Latch && "Loop must have a latch"); BasicBlock *EpilogLatch = cast<BasicBlock>(VMap[Latch]); @@ -233,6 +235,7 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, // Add incoming PreHeader from branch around the Loop PN.addIncoming(UndefValue::get(PN.getType()), PreHeader); + SE.forgetValue(&PN); Value *V = PN.getIncomingValueForBlock(Latch); Instruction *I = dyn_cast<Instruction>(V); @@ -398,7 +401,7 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder, Optional<MDNode *> NewLoopID = makeFollowupLoopID( LoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupRemainder}); - if (NewLoopID.hasValue()) { + if (NewLoopID) { NewLoop->setLoopID(NewLoopID.getValue()); // Do not setLoopAlreadyUnrolled if loop attributes have been defined @@ -739,11 +742,28 @@ bool llvm::UnrollRuntimeLoopRemainder( // Compute the number of extra iterations required, which is: // extra iterations = run-time trip count % loop unroll factor PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator()); + IRBuilder<> B(PreHeaderBR); Value *TripCount = Expander.expandCodeFor(TripCountSC, TripCountSC->getType(), PreHeaderBR); - Value *BECount = Expander.expandCodeFor(BECountSC, BECountSC->getType(), - PreHeaderBR); - IRBuilder<> B(PreHeaderBR); + Value *BECount; + // If there are other exits before the latch, that may cause the latch exit + // branch to never be executed, and the latch exit count may be poison. + // In this case, freeze the TripCount and base BECount on the frozen + // TripCount. We will introduce two branches using these values, and it's + // important that they see a consistent value (which would not be guaranteed + // if were frozen independently.) + if ((!OtherExits.empty() || !SE->loopHasNoAbnormalExits(L)) && + !isGuaranteedNotToBeUndefOrPoison(TripCount, AC, PreHeaderBR, DT)) { + TripCount = B.CreateFreeze(TripCount); + BECount = + B.CreateAdd(TripCount, ConstantInt::get(TripCount->getType(), -1)); + } else { + // If we don't need to freeze, use SCEVExpander for BECount as well, to + // allow slightly better value reuse. 
+ BECount = + Expander.expandCodeFor(BECountSC, BECountSC->getType(), PreHeaderBR); + } + Value * const ModVal = CreateTripRemainder(B, BECount, TripCount, Count); Value *BranchVal = @@ -884,9 +904,8 @@ bool llvm::UnrollRuntimeLoopRemainder( if (UseEpilogRemainder) { // Connect the epilog code to the original loop and update the // PHI functions. - ConnectEpilog(L, ModVal, NewExit, LatchExit, PreHeader, - EpilogPreHeader, NewPreHeader, VMap, DT, LI, - PreserveLCSSA); + ConnectEpilog(L, ModVal, NewExit, LatchExit, PreHeader, EpilogPreHeader, + NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE); // Update counter in loop for unrolling. // Use an incrementing IV. Pre-incr/post-incr is backedge/trip count. @@ -910,7 +929,7 @@ bool llvm::UnrollRuntimeLoopRemainder( // Connect the prolog code to the original loop and update the // PHI functions. ConnectProlog(L, BECount, Count, PrologExit, LatchExit, PreHeader, - NewPreHeader, VMap, DT, LI, PreserveLCSSA); + NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE); } // If this loop is nested, then the loop unroller changes the code in the any @@ -941,7 +960,7 @@ bool llvm::UnrollRuntimeLoopRemainder( SmallVector<WeakTrackingVH, 16> DeadInsts; for (BasicBlock *BB : RemainderBlocks) { for (Instruction &Inst : llvm::make_early_inc_range(*BB)) { - if (Value *V = SimplifyInstruction(&Inst, {DL, nullptr, DT, AC})) + if (Value *V = simplifyInstruction(&Inst, {DL, nullptr, DT, AC})) if (LI->replacementPreservesLCSSAForm(&Inst, V)) Inst.replaceAllUsesWith(V); if (isInstructionTriviallyDead(&Inst)) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp index 95db2fe8d310..ec898c463574 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -23,31 +23,25 @@ #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstSimplifyFolder.h" -#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" -#include "llvm/Analysis/MustExecute.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DIBuilder.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/ValueHandle.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/KnownBits.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" @@ -260,10 +254,10 @@ llvm::getOptionalElementCountLoopAttribute(const Loop *TheLoop) { Optional<int> Width = getOptionalIntLoopAttribute(TheLoop, "llvm.loop.vectorize.width"); - if (Width.hasValue()) { + if (Width) { Optional<int> IsScalable = getOptionalIntLoopAttribute( TheLoop, "llvm.loop.vectorize.scalable.enable"); - return ElementCount::get(*Width, IsScalable.getValueOr(false)); + return ElementCount::get(*Width, 
IsScalable.value_or(false)); } return None; @@ -364,7 +358,7 @@ TransformationMode llvm::hasUnrollTransformation(const Loop *L) { Optional<int> Count = getOptionalIntLoopAttribute(L, "llvm.loop.unroll.count"); - if (Count.hasValue()) + if (Count) return Count.getValue() == 1 ? TM_SuppressedByUser : TM_ForcedByUser; if (getBooleanLoopAttribute(L, "llvm.loop.unroll.enable")) @@ -385,7 +379,7 @@ TransformationMode llvm::hasUnrollAndJamTransformation(const Loop *L) { Optional<int> Count = getOptionalIntLoopAttribute(L, "llvm.loop.unroll_and_jam.count"); - if (Count.hasValue()) + if (Count) return Count.getValue() == 1 ? TM_SuppressedByUser : TM_ForcedByUser; if (getBooleanLoopAttribute(L, "llvm.loop.unroll_and_jam.enable")) @@ -497,9 +491,11 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, if (SE) SE->forgetLoop(L); - auto *OldBr = dyn_cast<BranchInst>(Preheader->getTerminator()); - assert(OldBr && "Preheader must end with a branch"); - assert(OldBr->isUnconditional() && "Preheader must have a single successor"); + Instruction *OldTerm = Preheader->getTerminator(); + assert(!OldTerm->mayHaveSideEffects() && + "Preheader must end with a side-effect-free terminator"); + assert(OldTerm->getNumSuccessors() == 1 && + "Preheader must have a single successor"); // Connect the preheader to the exit block. Keep the old edge to the header // around to perform the dominator tree update in two separate steps // -- #1 insertion of the edge preheader -> exit and #2 deletion of the edge @@ -525,7 +521,7 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, // coming to this inner loop, this will break the outer loop structure (by // deleting the backedge of the outer loop). If the outer loop is indeed a // non-loop, it will be deleted in a future iteration of loop deletion pass. - IRBuilder<> Builder(OldBr); + IRBuilder<> Builder(OldTerm); auto *ExitBlock = L->getUniqueExitBlock(); DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); @@ -535,7 +531,7 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, Builder.CreateCondBr(Builder.getFalse(), L->getHeader(), ExitBlock); // Remove the old branch. The conditional branch becomes a new terminator. - OldBr->eraseFromParent(); + OldTerm->eraseFromParent(); // Rewrite phis in the exit block to get their inputs from the Preheader // instead of the exiting block. 
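The runtime-unroll hunk above freezes the expanded trip count and derives the backedge count from the frozen value, and the runtime-check expansion that follows freezes pointer bounds the same way. A minimal sketch of the pattern, with an assumed helper name:

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include <utility>
using namespace llvm;

// Freeze once, then derive dependent values from the frozen copy so every
// later user observes a single, consistent (non-poison) value.
static std::pair<Value *, Value *>
freezeTripAndBackedgeCount(IRBuilder<> &B, Value *TripCount) {
  Value *FrozenTC = B.CreateFreeze(TripCount, "tc.fr");
  Value *BECount =
      B.CreateAdd(FrozenTC, ConstantInt::get(FrozenTC->getType(), -1), "bec");
  return {FrozenTC, BECount};
}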
@@ -579,7 +575,7 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, assert(L->hasNoExitBlocks() && "Loop should have either zero or one exit blocks."); - Builder.SetInsertPoint(OldBr); + Builder.SetInsertPoint(OldTerm); Builder.CreateUnreachable(); Preheader->getTerminator()->eraseFromParent(); } @@ -692,18 +688,12 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, } } -static Loop *getOutermostLoop(Loop *L) { - while (Loop *Parent = L->getParentLoop()) - L = Parent; - return L; -} - void llvm::breakLoopBackedge(Loop *L, DominatorTree &DT, ScalarEvolution &SE, LoopInfo &LI, MemorySSA *MSSA) { auto *Latch = L->getLoopLatch(); assert(Latch && "multiple latches not yet supported"); auto *Header = L->getHeader(); - Loop *OutermostLoop = getOutermostLoop(L); + Loop *OutermostLoop = L->getOutermostLoop(); SE.forgetLoop(L); @@ -1103,7 +1093,8 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B, return B.CreateFAddReduce(Start, Src); } -void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue) { +void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue, + bool IncludeWrapFlags) { auto *VecOp = dyn_cast<Instruction>(I); if (!VecOp) return; @@ -1112,7 +1103,7 @@ void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue) { if (!Intersection) return; const unsigned Opcode = Intersection->getOpcode(); - VecOp->copyIRFlags(Intersection); + VecOp->copyIRFlags(Intersection, IncludeWrapFlags); for (auto *V : VL) { auto *Instr = dyn_cast<Instruction>(V); if (!Instr) @@ -1536,6 +1527,11 @@ static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG, LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n"); Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc); End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc); + if (CG->NeedsFreeze) { + IRBuilder<> Builder(Loc); + Start = Builder.CreateFreeze(Start, Start->getName() + ".fr"); + End = Builder.CreateFreeze(End, End->getName() + ".fr"); + } LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High << "\n"); return {Start, End}; } @@ -1614,6 +1610,45 @@ Value *llvm::addRuntimeChecks( return MemoryRuntimeCheck; } +Value *llvm::addDiffRuntimeChecks( + Instruction *Loc, Loop *TheLoop, ArrayRef<PointerDiffInfo> Checks, + SCEVExpander &Expander, + function_ref<Value *(IRBuilderBase &, unsigned)> GetVF, unsigned IC) { + + LLVMContext &Ctx = Loc->getContext(); + IRBuilder<InstSimplifyFolder> ChkBuilder(Ctx, + Loc->getModule()->getDataLayout()); + ChkBuilder.SetInsertPoint(Loc); + // Our instructions might fold to a constant. + Value *MemoryRuntimeCheck = nullptr; + + for (auto &C : Checks) { + Type *Ty = C.SinkStart->getType(); + // Compute VF * IC * AccessSize. 
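With hypothetical numbers, the difference check assembled in this hunk behaves as follows; the sketch only mirrors the arithmetic of the emitted IR (pointer difference compared unsigned against VF * IC * AccessSize):

#include <cstdint>
#include <cstdio>

int main() {
  uint64_t VF = 4, IC = 2, AccessSize = 4;            // illustrative values
  uint64_t SinkStart = 0x1000, SrcStart = 0x0ff0;     // 16 bytes apart
  uint64_t Threshold = VF * IC * AccessSize;          // 32
  bool Conflict = (SinkStart - SrcStart) < Threshold; // diff.check
  std::printf("conflict: %d\n", Conflict);            // prints 1: keep the scalar loop
  return 0;
}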
+ auto *VFTimesUFTimesSize = + ChkBuilder.CreateMul(GetVF(ChkBuilder, Ty->getScalarSizeInBits()), + ConstantInt::get(Ty, IC * C.AccessSize)); + Value *Sink = Expander.expandCodeFor(C.SinkStart, Ty, Loc); + Value *Src = Expander.expandCodeFor(C.SrcStart, Ty, Loc); + if (C.NeedsFreeze) { + IRBuilder<> Builder(Loc); + Sink = Builder.CreateFreeze(Sink, Sink->getName() + ".fr"); + Src = Builder.CreateFreeze(Src, Src->getName() + ".fr"); + } + Value *Diff = ChkBuilder.CreateSub(Sink, Src); + Value *IsConflict = + ChkBuilder.CreateICmpULT(Diff, VFTimesUFTimesSize, "diff.check"); + + if (MemoryRuntimeCheck) { + IsConflict = + ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx"); + } + MemoryRuntimeCheck = IsConflict; + } + + return MemoryRuntimeCheck; +} + Optional<IVConditionInfo> llvm::hasPartialIVCondition(Loop &L, unsigned MSSAThreshold, MemorySSA &MSSA, diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp index f0bf625fa18e..97f29527bb95 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -41,9 +41,8 @@ LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, ArrayRef<RuntimePointerCheck> Checks, Loop *L, LoopInfo *LI, DominatorTree *DT, ScalarEvolution *SE) - : VersionedLoop(L), NonVersionedLoop(nullptr), - AliasChecks(Checks.begin(), Checks.end()), - Preds(LAI.getPSE().getUnionPredicate()), LAI(LAI), LI(LI), DT(DT), + : VersionedLoop(L), AliasChecks(Checks.begin(), Checks.end()), + Preds(LAI.getPSE().getPredicate()), LAI(LAI), LI(LI), DT(DT), SE(SE) { } @@ -277,7 +276,7 @@ bool runImpl(LoopInfo *LI, function_ref<const LoopAccessInfo &(Loop &)> GetLAA, const LoopAccessInfo &LAI = GetLAA(*L); if (!LAI.hasConvergentOp() && (LAI.getNumRuntimePointerChecks() || - !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) { + !LAI.getPSE().getPredicate().isAlwaysTrue())) { LoopVersioning LVer(LAI, LAI.getRuntimePointerChecking()->getChecks(), L, LI, DT, SE); LVer.versionLoop(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerAtomic.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerAtomic.cpp new file mode 100644 index 000000000000..8641581c8039 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerAtomic.cpp @@ -0,0 +1,93 @@ +//===- LowerAtomic.cpp - Lower atomic intrinsics --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass lowers atomic intrinsics to non-atomic form for use in a known +// non-preemptible environment. 
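The new LowerAtomic.cpp that begins here rewrites atomic operations as ordinary loads and stores, which is only sound when the environment is known to be non-preemptible. A plain C++ model of the cmpxchg lowering semantics (a sketch, not the pass itself):

#include <utility>

// Equivalent of the lowered cmpxchg: load, compare, select, store, with no
// atomicity guarantees.
static std::pair<int, bool> loweredCmpXchg(int *Ptr, int Cmp, int New) {
  int Orig = *Ptr;            // plain load
  bool Equal = (Orig == Cmp);
  *Ptr = Equal ? New : Orig;  // plain store of the selected value
  return {Orig, Equal};       // {original value, success flag}
}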
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/LowerAtomic.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +using namespace llvm; + +#define DEBUG_TYPE "loweratomic" + +bool llvm::lowerAtomicCmpXchgInst(AtomicCmpXchgInst *CXI) { + IRBuilder<> Builder(CXI); + Value *Ptr = CXI->getPointerOperand(); + Value *Cmp = CXI->getCompareOperand(); + Value *Val = CXI->getNewValOperand(); + + LoadInst *Orig = Builder.CreateLoad(Val->getType(), Ptr); + Value *Equal = Builder.CreateICmpEQ(Orig, Cmp); + Value *Res = Builder.CreateSelect(Equal, Val, Orig); + Builder.CreateStore(Res, Ptr); + + Res = Builder.CreateInsertValue(UndefValue::get(CXI->getType()), Orig, 0); + Res = Builder.CreateInsertValue(Res, Equal, 1); + + CXI->replaceAllUsesWith(Res); + CXI->eraseFromParent(); + return true; +} + +Value *llvm::buildAtomicRMWValue(AtomicRMWInst::BinOp Op, + IRBuilderBase &Builder, Value *Loaded, + Value *Inc) { + Value *NewVal; + switch (Op) { + case AtomicRMWInst::Xchg: + return Inc; + case AtomicRMWInst::Add: + return Builder.CreateAdd(Loaded, Inc, "new"); + case AtomicRMWInst::Sub: + return Builder.CreateSub(Loaded, Inc, "new"); + case AtomicRMWInst::And: + return Builder.CreateAnd(Loaded, Inc, "new"); + case AtomicRMWInst::Nand: + return Builder.CreateNot(Builder.CreateAnd(Loaded, Inc), "new"); + case AtomicRMWInst::Or: + return Builder.CreateOr(Loaded, Inc, "new"); + case AtomicRMWInst::Xor: + return Builder.CreateXor(Loaded, Inc, "new"); + case AtomicRMWInst::Max: + NewVal = Builder.CreateICmpSGT(Loaded, Inc); + return Builder.CreateSelect(NewVal, Loaded, Inc, "new"); + case AtomicRMWInst::Min: + NewVal = Builder.CreateICmpSLE(Loaded, Inc); + return Builder.CreateSelect(NewVal, Loaded, Inc, "new"); + case AtomicRMWInst::UMax: + NewVal = Builder.CreateICmpUGT(Loaded, Inc); + return Builder.CreateSelect(NewVal, Loaded, Inc, "new"); + case AtomicRMWInst::UMin: + NewVal = Builder.CreateICmpULE(Loaded, Inc); + return Builder.CreateSelect(NewVal, Loaded, Inc, "new"); + case AtomicRMWInst::FAdd: + return Builder.CreateFAdd(Loaded, Inc, "new"); + case AtomicRMWInst::FSub: + return Builder.CreateFSub(Loaded, Inc, "new"); + default: + llvm_unreachable("Unknown atomic op"); + } +} + +bool llvm::lowerAtomicRMWInst(AtomicRMWInst *RMWI) { + IRBuilder<> Builder(RMWI); + Value *Ptr = RMWI->getPointerOperand(); + Value *Val = RMWI->getValOperand(); + + LoadInst *Orig = Builder.CreateLoad(Val->getType(), Ptr); + Value *Res = buildAtomicRMWValue(RMWI->getOperation(), Builder, Orig, Val); + Builder.CreateStore(Res, Ptr); + RMWI->replaceAllUsesWith(Orig); + RMWI->eraseFromParent(); + return true; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp new file mode 100644 index 000000000000..010deb77a883 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp @@ -0,0 +1,221 @@ +//===-- LowerGlobalDtors.cpp - Lower @llvm.global_dtors -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Lower @llvm.global_dtors. 
+/// +/// Implement @llvm.global_dtors by creating wrapper functions that are +/// registered in @llvm.global_ctors and which contain a call to +/// `__cxa_atexit` to register their destructor functions. +/// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/LowerGlobalDtors.h" + +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" +#include <map> + +using namespace llvm; + +#define DEBUG_TYPE "lower-global-dtors" + +namespace { +class LowerGlobalDtorsLegacyPass final : public ModulePass { + StringRef getPassName() const override { + return "Lower @llvm.global_dtors via `__cxa_atexit`"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + ModulePass::getAnalysisUsage(AU); + } + + bool runOnModule(Module &M) override; + +public: + static char ID; + LowerGlobalDtorsLegacyPass() : ModulePass(ID) { + initializeLowerGlobalDtorsLegacyPassPass(*PassRegistry::getPassRegistry()); + } +}; +} // End anonymous namespace + +char LowerGlobalDtorsLegacyPass::ID = 0; +INITIALIZE_PASS(LowerGlobalDtorsLegacyPass, DEBUG_TYPE, + "Lower @llvm.global_dtors via `__cxa_atexit`", false, false) + +ModulePass *llvm::createLowerGlobalDtorsLegacyPass() { + return new LowerGlobalDtorsLegacyPass(); +} + +static bool runImpl(Module &M); +bool LowerGlobalDtorsLegacyPass::runOnModule(Module &M) { return runImpl(M); } + +PreservedAnalyses LowerGlobalDtorsPass::run(Module &M, + ModuleAnalysisManager &AM) { + bool Changed = runImpl(M); + if (!Changed) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; +} + +static bool runImpl(Module &M) { + GlobalVariable *GV = M.getGlobalVariable("llvm.global_dtors"); + if (!GV || !GV->hasInitializer()) + return false; + + const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer()); + if (!InitList) + return false; + + // Validate @llvm.global_dtor's type. + auto *ETy = dyn_cast<StructType>(InitList->getType()->getElementType()); + if (!ETy || ETy->getNumElements() != 3 || + !ETy->getTypeAtIndex(0U)->isIntegerTy() || + !ETy->getTypeAtIndex(1U)->isPointerTy() || + !ETy->getTypeAtIndex(2U)->isPointerTy()) + return false; // Not (int, ptr, ptr). + + // Collect the contents of @llvm.global_dtors, ordered by priority. Within a + // priority, sequences of destructors with the same associated object are + // recorded so that we can register them as a group. + std::map< + uint16_t, + std::vector<std::pair<Constant *, std::vector<Constant *>>> + > DtorFuncs; + for (Value *O : InitList->operands()) { + auto *CS = dyn_cast<ConstantStruct>(O); + if (!CS) + continue; // Malformed. + + auto *Priority = dyn_cast<ConstantInt>(CS->getOperand(0)); + if (!Priority) + continue; // Malformed. + uint16_t PriorityValue = Priority->getLimitedValue(UINT16_MAX); + + Constant *DtorFunc = CS->getOperand(1); + if (DtorFunc->isNullValue()) + break; // Found a null terminator, skip the rest. 
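At the C++ level the wrappers described above amount to roughly the following; all names and the two destructors are hypothetical, and the real pass emits one such pair per (priority, associated symbol) group:

extern "C" int __cxa_atexit(void (*Fn)(void *), void *Arg, void *DsoHandle);
extern "C" char __dso_handle;
extern void dtor_a(), dtor_b(); // stand-ins for entries of @llvm.global_dtors

// Runs the grouped destructors; entries listed later in @llvm.global_dtors
// are called first.
static void call_dtors(void *) {
  dtor_b();
  dtor_a();
}

// Appended to @llvm.global_ctors so registration happens at startup; trap if
// __cxa_atexit reports failure (for example, out of memory).
static void register_call_dtors() {
  if (__cxa_atexit(call_dtors, nullptr, &__dso_handle) != 0)
    __builtin_trap();
}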
+ + Constant *Associated = CS->getOperand(2); + Associated = cast<Constant>(Associated->stripPointerCasts()); + + auto &AtThisPriority = DtorFuncs[PriorityValue]; + if (AtThisPriority.empty() || AtThisPriority.back().first != Associated) { + std::vector<Constant *> NewList; + NewList.push_back(DtorFunc); + AtThisPriority.push_back(std::make_pair(Associated, NewList)); + } else { + AtThisPriority.back().second.push_back(DtorFunc); + } + } + if (DtorFuncs.empty()) + return false; + + // extern "C" int __cxa_atexit(void (*f)(void *), void *p, void *d); + LLVMContext &C = M.getContext(); + PointerType *VoidStar = Type::getInt8PtrTy(C); + Type *AtExitFuncArgs[] = {VoidStar}; + FunctionType *AtExitFuncTy = + FunctionType::get(Type::getVoidTy(C), AtExitFuncArgs, + /*isVarArg=*/false); + + FunctionCallee AtExit = M.getOrInsertFunction( + "__cxa_atexit", + FunctionType::get(Type::getInt32Ty(C), + {PointerType::get(AtExitFuncTy, 0), VoidStar, VoidStar}, + /*isVarArg=*/false)); + + // Declare __dso_local. + Type *DsoHandleTy = Type::getInt8Ty(C); + Constant *DsoHandle = M.getOrInsertGlobal("__dso_handle", DsoHandleTy, [&] { + auto *GV = new GlobalVariable(M, DsoHandleTy, /*isConstant=*/true, + GlobalVariable::ExternalWeakLinkage, nullptr, + "__dso_handle"); + GV->setVisibility(GlobalVariable::HiddenVisibility); + return GV; + }); + + // For each unique priority level and associated symbol, generate a function + // to call all the destructors at that level, and a function to register the + // first function with __cxa_atexit. + for (auto &PriorityAndMore : DtorFuncs) { + uint16_t Priority = PriorityAndMore.first; + uint64_t Id = 0; + auto &AtThisPriority = PriorityAndMore.second; + for (auto &AssociatedAndMore : AtThisPriority) { + Constant *Associated = AssociatedAndMore.first; + auto ThisId = Id++; + + Function *CallDtors = Function::Create( + AtExitFuncTy, Function::PrivateLinkage, + "call_dtors" + + (Priority != UINT16_MAX ? (Twine(".") + Twine(Priority)) + : Twine()) + + (AtThisPriority.size() > 1 ? Twine("$") + Twine(ThisId) + : Twine()) + + (!Associated->isNullValue() ? (Twine(".") + Associated->getName()) + : Twine()), + &M); + BasicBlock *BB = BasicBlock::Create(C, "body", CallDtors); + FunctionType *VoidVoid = FunctionType::get(Type::getVoidTy(C), + /*isVarArg=*/false); + + for (auto Dtor : reverse(AssociatedAndMore.second)) + CallInst::Create(VoidVoid, Dtor, "", BB); + ReturnInst::Create(C, BB); + + Function *RegisterCallDtors = Function::Create( + VoidVoid, Function::PrivateLinkage, + "register_call_dtors" + + (Priority != UINT16_MAX ? (Twine(".") + Twine(Priority)) + : Twine()) + + (AtThisPriority.size() > 1 ? Twine("$") + Twine(ThisId) + : Twine()) + + (!Associated->isNullValue() ? (Twine(".") + Associated->getName()) + : Twine()), + &M); + BasicBlock *EntryBB = BasicBlock::Create(C, "entry", RegisterCallDtors); + BasicBlock *FailBB = BasicBlock::Create(C, "fail", RegisterCallDtors); + BasicBlock *RetBB = BasicBlock::Create(C, "return", RegisterCallDtors); + + Value *Null = ConstantPointerNull::get(VoidStar); + Value *Args[] = {CallDtors, Null, DsoHandle}; + Value *Res = CallInst::Create(AtExit, Args, "call", EntryBB); + Value *Cmp = new ICmpInst(*EntryBB, ICmpInst::ICMP_NE, Res, + Constant::getNullValue(Res->getType())); + BranchInst::Create(FailBB, RetBB, Cmp, EntryBB); + + // If `__cxa_atexit` hits out-of-memory, trap, so that we don't misbehave. + // This should be very rare, because if the process is running out of + // memory before main has even started, something is wrong. 
+ CallInst::Create(Intrinsic::getDeclaration(&M, Intrinsic::trap), "", + FailBB); + new UnreachableInst(C, FailBB); + + ReturnInst::Create(C, RetBB); + + // Now register the registration function with @llvm.global_ctors. + appendToGlobalCtors(M, RegisterCallDtors, Priority, Associated); + } + } + + // Now that we've lowered everything, remove @llvm.global_dtors. + GV->eraseFromParent(); + + return true; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerInvoke.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerInvoke.cpp index fe0ff5899d8f..59cfa41fb7fd 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerInvoke.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerInvoke.cpp @@ -17,8 +17,6 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Transforms/Utils.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp index 3d75dd57456d..b4acb1b2ae90 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp @@ -7,9 +7,11 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; @@ -18,7 +20,9 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr, Value *DstAddr, ConstantInt *CopyLen, Align SrcAlign, Align DstAlign, bool SrcIsVolatile, bool DstIsVolatile, - const TargetTransformInfo &TTI) { + bool CanOverlap, + const TargetTransformInfo &TTI, + Optional<uint32_t> AtomicElementSize) { // No need to expand zero length copies. 
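For orientation, a plain C++ model of the structure createMemCpyLoopKnownSize expands a known-size memcpy into: a main loop over a TTI-chosen operand type plus a residual. The 8-byte operand type is assumed here, and the residual is shown as a byte loop even though the pass actually emits a short sequence of progressively smaller typed copies:

#include <cstddef>
#include <cstdint>
#include <cstring>

static void loweredKnownSizeMemCpy(unsigned char *Dst,
                                   const unsigned char *Src, size_t CopyLen) {
  const size_t LoopOpSize = sizeof(uint64_t); // stand-in for LoopOpType
  const size_t LoopEndCount = CopyLen / LoopOpSize;
  for (size_t I = 0; I != LoopEndCount; ++I) { // main copy loop
    uint64_t Tmp;
    std::memcpy(&Tmp, Src + I * LoopOpSize, LoopOpSize);
    std::memcpy(Dst + I * LoopOpSize, &Tmp, LoopOpSize);
  }
  for (size_t I = LoopEndCount * LoopOpSize; I != CopyLen; ++I)
    Dst[I] = Src[I]; // residual bytes
}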
if (CopyLen->isZero()) return; @@ -28,15 +32,25 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr, Function *ParentFunc = PreLoopBB->getParent(); LLVMContext &Ctx = PreLoopBB->getContext(); const DataLayout &DL = ParentFunc->getParent()->getDataLayout(); + MDBuilder MDB(Ctx); + MDNode *NewDomain = MDB.createAnonymousAliasScopeDomain("MemCopyDomain"); + StringRef Name = "MemCopyAliasScope"; + MDNode *NewScope = MDB.createAnonymousAliasScope(NewDomain, Name); unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace(); unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); Type *TypeOfCopyLen = CopyLen->getType(); Type *LoopOpType = TTI.getMemcpyLoopLoweringType( - Ctx, CopyLen, SrcAS, DstAS, SrcAlign.value(), DstAlign.value()); + Ctx, CopyLen, SrcAS, DstAS, SrcAlign.value(), DstAlign.value(), + AtomicElementSize); + assert((!AtomicElementSize || !LoopOpType->isVectorTy()) && + "Atomic memcpy lowering is not supported for vector operand type"); unsigned LoopOpSize = DL.getTypeStoreSize(LoopOpType); + assert((!AtomicElementSize || LoopOpSize % *AtomicElementSize == 0) && + "Atomic memcpy lowering is not supported for selected operand size"); + uint64_t LoopEndCount = CopyLen->getZExtValue() / LoopOpSize; if (LoopEndCount != 0) { @@ -68,12 +82,25 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr, // Loop Body Value *SrcGEP = LoopBuilder.CreateInBoundsGEP(LoopOpType, SrcAddr, LoopIndex); - Value *Load = LoopBuilder.CreateAlignedLoad(LoopOpType, SrcGEP, - PartSrcAlign, SrcIsVolatile); + LoadInst *Load = LoopBuilder.CreateAlignedLoad(LoopOpType, SrcGEP, + PartSrcAlign, SrcIsVolatile); + if (!CanOverlap) { + // Set alias scope for loads. + Load->setMetadata(LLVMContext::MD_alias_scope, + MDNode::get(Ctx, NewScope)); + } Value *DstGEP = LoopBuilder.CreateInBoundsGEP(LoopOpType, DstAddr, LoopIndex); - LoopBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign, DstIsVolatile); - + StoreInst *Store = LoopBuilder.CreateAlignedStore( + Load, DstGEP, PartDstAlign, DstIsVolatile); + if (!CanOverlap) { + // Indicate that stores don't overlap loads. 
+ Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope)); + } + if (AtomicElementSize) { + Load->setAtomic(AtomicOrdering::Unordered); + Store->setAtomic(AtomicOrdering::Unordered); + } Value *NewIndex = LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1U)); LoopIndex->addIncoming(NewIndex, LoopBB); @@ -93,7 +120,7 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr, SmallVector<Type *, 5> RemainingOps; TTI.getMemcpyLoopResidualLoweringType(RemainingOps, Ctx, RemainingBytes, SrcAS, DstAS, SrcAlign.value(), - DstAlign.value()); + DstAlign.value(), AtomicElementSize); for (auto OpTy : RemainingOps) { Align PartSrcAlign(commonAlignment(SrcAlign, BytesCopied)); @@ -101,6 +128,10 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr, // Calaculate the new index unsigned OperandSize = DL.getTypeStoreSize(OpTy); + assert( + (!AtomicElementSize || OperandSize % *AtomicElementSize == 0) && + "Atomic memcpy lowering is not supported for selected operand size"); + uint64_t GepIndex = BytesCopied / OperandSize; assert(GepIndex * OperandSize == BytesCopied && "Division should have no Remainder!"); @@ -111,9 +142,13 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr, : RBuilder.CreateBitCast(SrcAddr, SrcPtrType); Value *SrcGEP = RBuilder.CreateInBoundsGEP( OpTy, CastedSrc, ConstantInt::get(TypeOfCopyLen, GepIndex)); - Value *Load = + LoadInst *Load = RBuilder.CreateAlignedLoad(OpTy, SrcGEP, PartSrcAlign, SrcIsVolatile); - + if (!CanOverlap) { + // Set alias scope for loads. + Load->setMetadata(LLVMContext::MD_alias_scope, + MDNode::get(Ctx, NewScope)); + } // Cast destination to operand type and store. PointerType *DstPtrType = PointerType::get(OpTy, DstAS); Value *CastedDst = DstAddr->getType() == DstPtrType @@ -121,8 +156,16 @@ void llvm::createMemCpyLoopKnownSize(Instruction *InsertBefore, Value *SrcAddr, : RBuilder.CreateBitCast(DstAddr, DstPtrType); Value *DstGEP = RBuilder.CreateInBoundsGEP( OpTy, CastedDst, ConstantInt::get(TypeOfCopyLen, GepIndex)); - RBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign, DstIsVolatile); - + StoreInst *Store = RBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign, + DstIsVolatile); + if (!CanOverlap) { + // Indicate that stores don't overlap loads. 
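The load/store metadata wiring used above, pulled out into a standalone sketch (the helper name is assumed): one anonymous domain and scope, with loads tagged !alias.scope and stores tagged !noalias so alias analysis can tell the stores never clobber the loaded locations.

#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
using namespace llvm;

static void markNonOverlapping(LoadInst *Load, StoreInst *Store) {
  LLVMContext &Ctx = Load->getContext();
  MDBuilder MDB(Ctx);
  MDNode *Domain = MDB.createAnonymousAliasScopeDomain("MemCopyDomain");
  MDNode *Scope = MDB.createAnonymousAliasScope(Domain, "MemCopyAliasScope");
  Load->setMetadata(LLVMContext::MD_alias_scope, MDNode::get(Ctx, Scope));
  Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, Scope));
}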
+ Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope)); + } + if (AtomicElementSize) { + Load->setAtomic(AtomicOrdering::Unordered); + Store->setAtomic(AtomicOrdering::Unordered); + } BytesCopied += OperandSize; } } @@ -134,8 +177,9 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore, Value *SrcAddr, Value *DstAddr, Value *CopyLen, Align SrcAlign, Align DstAlign, bool SrcIsVolatile, - bool DstIsVolatile, - const TargetTransformInfo &TTI) { + bool DstIsVolatile, bool CanOverlap, + const TargetTransformInfo &TTI, + Optional<uint32_t> AtomicElementSize) { BasicBlock *PreLoopBB = InsertBefore->getParent(); BasicBlock *PostLoopBB = PreLoopBB->splitBasicBlock(InsertBefore, "post-loop-memcpy-expansion"); @@ -143,12 +187,22 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore, Function *ParentFunc = PreLoopBB->getParent(); const DataLayout &DL = ParentFunc->getParent()->getDataLayout(); LLVMContext &Ctx = PreLoopBB->getContext(); + MDBuilder MDB(Ctx); + MDNode *NewDomain = MDB.createAnonymousAliasScopeDomain("MemCopyDomain"); + StringRef Name = "MemCopyAliasScope"; + MDNode *NewScope = MDB.createAnonymousAliasScope(NewDomain, Name); + unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace(); unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); Type *LoopOpType = TTI.getMemcpyLoopLoweringType( - Ctx, CopyLen, SrcAS, DstAS, SrcAlign.value(), DstAlign.value()); + Ctx, CopyLen, SrcAS, DstAS, SrcAlign.value(), DstAlign.value(), + AtomicElementSize); + assert((!AtomicElementSize || !LoopOpType->isVectorTy()) && + "Atomic memcpy lowering is not supported for vector operand type"); unsigned LoopOpSize = DL.getTypeStoreSize(LoopOpType); + assert((!AtomicElementSize || LoopOpSize % *AtomicElementSize == 0) && + "Atomic memcpy lowering is not supported for selected operand size"); IRBuilder<> PLBuilder(PreLoopBB->getTerminator()); @@ -183,19 +237,40 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore, LoopIndex->addIncoming(ConstantInt::get(CopyLenType, 0U), PreLoopBB); Value *SrcGEP = LoopBuilder.CreateInBoundsGEP(LoopOpType, SrcAddr, LoopIndex); - Value *Load = LoopBuilder.CreateAlignedLoad(LoopOpType, SrcGEP, PartSrcAlign, - SrcIsVolatile); + LoadInst *Load = LoopBuilder.CreateAlignedLoad(LoopOpType, SrcGEP, + PartSrcAlign, SrcIsVolatile); + if (!CanOverlap) { + // Set alias scope for loads. + Load->setMetadata(LLVMContext::MD_alias_scope, MDNode::get(Ctx, NewScope)); + } Value *DstGEP = LoopBuilder.CreateInBoundsGEP(LoopOpType, DstAddr, LoopIndex); - LoopBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign, DstIsVolatile); - + StoreInst *Store = + LoopBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign, DstIsVolatile); + if (!CanOverlap) { + // Indicate that stores don't overlap loads. + Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope)); + } + if (AtomicElementSize) { + Load->setAtomic(AtomicOrdering::Unordered); + Store->setAtomic(AtomicOrdering::Unordered); + } Value *NewIndex = LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(CopyLenType, 1U)); LoopIndex->addIncoming(NewIndex, LoopBB); - if (!LoopOpIsInt8) { - // Add in the - Value *RuntimeResidual = PLBuilder.CreateURem(CopyLen, CILoopOpSize); - Value *RuntimeBytesCopied = PLBuilder.CreateSub(CopyLen, RuntimeResidual); + bool requiresResidual = + !LoopOpIsInt8 && !(AtomicElementSize && LoopOpSize == AtomicElementSize); + if (requiresResidual) { + Type *ResLoopOpType = AtomicElementSize + ? 
Type::getIntNTy(Ctx, *AtomicElementSize * 8) + : Int8Type; + unsigned ResLoopOpSize = DL.getTypeStoreSize(ResLoopOpType); + assert((ResLoopOpSize == AtomicElementSize ? *AtomicElementSize : 1) && + "Store size is expected to match type size"); + + // Add in the + Value *RuntimeResidual = PLBuilder.CreateURem(CopyLen, CILoopOpSize); + Value *RuntimeBytesCopied = PLBuilder.CreateSub(CopyLen, RuntimeResidual); // Loop body for the residual copy. BasicBlock *ResLoopBB = BasicBlock::Create(Ctx, "loop-memcpy-residual", @@ -230,21 +305,34 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore, ResBuilder.CreatePHI(CopyLenType, 2, "residual-loop-index"); ResidualIndex->addIncoming(Zero, ResHeaderBB); - Value *SrcAsInt8 = - ResBuilder.CreateBitCast(SrcAddr, PointerType::get(Int8Type, SrcAS)); - Value *DstAsInt8 = - ResBuilder.CreateBitCast(DstAddr, PointerType::get(Int8Type, DstAS)); + Value *SrcAsResLoopOpType = ResBuilder.CreateBitCast( + SrcAddr, PointerType::get(ResLoopOpType, SrcAS)); + Value *DstAsResLoopOpType = ResBuilder.CreateBitCast( + DstAddr, PointerType::get(ResLoopOpType, DstAS)); Value *FullOffset = ResBuilder.CreateAdd(RuntimeBytesCopied, ResidualIndex); - Value *SrcGEP = - ResBuilder.CreateInBoundsGEP(Int8Type, SrcAsInt8, FullOffset); - Value *Load = ResBuilder.CreateAlignedLoad(Int8Type, SrcGEP, PartSrcAlign, - SrcIsVolatile); - Value *DstGEP = - ResBuilder.CreateInBoundsGEP(Int8Type, DstAsInt8, FullOffset); - ResBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign, DstIsVolatile); - - Value *ResNewIndex = - ResBuilder.CreateAdd(ResidualIndex, ConstantInt::get(CopyLenType, 1U)); + Value *SrcGEP = ResBuilder.CreateInBoundsGEP( + ResLoopOpType, SrcAsResLoopOpType, FullOffset); + LoadInst *Load = ResBuilder.CreateAlignedLoad(ResLoopOpType, SrcGEP, + PartSrcAlign, SrcIsVolatile); + if (!CanOverlap) { + // Set alias scope for loads. + Load->setMetadata(LLVMContext::MD_alias_scope, + MDNode::get(Ctx, NewScope)); + } + Value *DstGEP = ResBuilder.CreateInBoundsGEP( + ResLoopOpType, DstAsResLoopOpType, FullOffset); + StoreInst *Store = ResBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign, + DstIsVolatile); + if (!CanOverlap) { + // Indicate that stores don't overlap loads. + Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope)); + } + if (AtomicElementSize) { + Load->setAtomic(AtomicOrdering::Unordered); + Store->setAtomic(AtomicOrdering::Unordered); + } + Value *ResNewIndex = ResBuilder.CreateAdd( + ResidualIndex, ConstantInt::get(CopyLenType, ResLoopOpSize)); ResidualIndex->addIncoming(ResNewIndex, ResLoopBB); // Create the loop branch condition. @@ -297,7 +385,13 @@ static void createMemMoveLoop(Instruction *InsertBefore, Value *SrcAddr, Function *F = OrigBB->getParent(); const DataLayout &DL = F->getParent()->getDataLayout(); - Type *EltTy = SrcAddr->getType()->getPointerElementType(); + // TODO: Use different element type if possible? 
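The residual split computed above (CreateURem/CreateSub) is easiest to see with hypothetical numbers; this sketch only reproduces the arithmetic:

#include <cstdio>

int main() {
  unsigned LoopOpSize = 8, CopyLen = 21;                   // illustrative
  unsigned RuntimeResidual = CopyLen % LoopOpSize;         // 5
  unsigned RuntimeBytesCopied = CopyLen - RuntimeResidual; // 16
  std::printf("%u bytes in the main loop, %u in the residual loop\n",
              RuntimeBytesCopied, RuntimeResidual);
  return 0;
}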
+ IRBuilder<> CastBuilder(InsertBefore); + Type *EltTy = CastBuilder.getInt8Ty(); + Type *PtrTy = + CastBuilder.getInt8PtrTy(SrcAddr->getType()->getPointerAddressSpace()); + SrcAddr = CastBuilder.CreateBitCast(SrcAddr, PtrTy); + DstAddr = CastBuilder.CreateBitCast(DstAddr, PtrTy); // Create the a comparison of src and dst, based on which we jump to either // the forward-copy part of the function (if src >= dst) or the backwards-copy @@ -419,8 +513,21 @@ static void createMemSetLoop(Instruction *InsertBefore, Value *DstAddr, NewBB); } +template <typename T> +static bool canOverlap(MemTransferBase<T> *Memcpy, ScalarEvolution *SE) { + if (SE) { + auto *SrcSCEV = SE->getSCEV(Memcpy->getRawSource()); + auto *DestSCEV = SE->getSCEV(Memcpy->getRawDest()); + if (SE->isKnownPredicateAt(CmpInst::ICMP_NE, SrcSCEV, DestSCEV, Memcpy)) + return false; + } + return true; +} + void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy, - const TargetTransformInfo &TTI) { + const TargetTransformInfo &TTI, + ScalarEvolution *SE) { + bool CanOverlap = canOverlap(Memcpy, SE); if (ConstantInt *CI = dyn_cast<ConstantInt>(Memcpy->getLength())) { createMemCpyLoopKnownSize( /* InsertBefore */ Memcpy, @@ -431,6 +538,7 @@ void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy, /* DestAlign */ Memcpy->getDestAlign().valueOrOne(), /* SrcIsVolatile */ Memcpy->isVolatile(), /* DstIsVolatile */ Memcpy->isVolatile(), + /* CanOverlap */ CanOverlap, /* TargetTransformInfo */ TTI); } else { createMemCpyLoopUnknownSize( @@ -442,6 +550,7 @@ void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy, /* DestAlign */ Memcpy->getDestAlign().valueOrOne(), /* SrcIsVolatile */ Memcpy->isVolatile(), /* DstIsVolatile */ Memcpy->isVolatile(), + /* CanOverlap */ CanOverlap, /* TargetTransformInfo */ TTI); } } @@ -465,3 +574,35 @@ void llvm::expandMemSetAsLoop(MemSetInst *Memset) { /* Alignment */ Memset->getDestAlign().valueOrOne(), Memset->isVolatile()); } + +void llvm::expandAtomicMemCpyAsLoop(AtomicMemCpyInst *AtomicMemcpy, + const TargetTransformInfo &TTI, + ScalarEvolution *SE) { + if (ConstantInt *CI = dyn_cast<ConstantInt>(AtomicMemcpy->getLength())) { + createMemCpyLoopKnownSize( + /* InsertBefore */ AtomicMemcpy, + /* SrcAddr */ AtomicMemcpy->getRawSource(), + /* DstAddr */ AtomicMemcpy->getRawDest(), + /* CopyLen */ CI, + /* SrcAlign */ AtomicMemcpy->getSourceAlign().valueOrOne(), + /* DestAlign */ AtomicMemcpy->getDestAlign().valueOrOne(), + /* SrcIsVolatile */ AtomicMemcpy->isVolatile(), + /* DstIsVolatile */ AtomicMemcpy->isVolatile(), + /* CanOverlap */ false, // SrcAddr & DstAddr may not overlap by spec. + /* TargetTransformInfo */ TTI, + /* AtomicCpySize */ AtomicMemcpy->getElementSizeInBytes()); + } else { + createMemCpyLoopUnknownSize( + /* InsertBefore */ AtomicMemcpy, + /* SrcAddr */ AtomicMemcpy->getRawSource(), + /* DstAddr */ AtomicMemcpy->getRawDest(), + /* CopyLen */ AtomicMemcpy->getLength(), + /* SrcAlign */ AtomicMemcpy->getSourceAlign().valueOrOne(), + /* DestAlign */ AtomicMemcpy->getDestAlign().valueOrOne(), + /* SrcIsVolatile */ AtomicMemcpy->isVolatile(), + /* DstIsVolatile */ AtomicMemcpy->isVolatile(), + /* CanOverlap */ false, // SrcAddr & DstAddr may not overlap by spec. 
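The overlap query introduced above can be read as the following standalone sketch (assumed free-function name); when ScalarEvolution proves the source and destination differ at the call site, the expanded loop is allowed to carry the noalias/alias.scope metadata:

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/IntrinsicInst.h"
using namespace llvm;

static bool mayOverlap(MemCpyInst *MC, ScalarEvolution *SE) {
  if (!SE)
    return true; // no SCEV available, stay conservative
  const SCEV *Src = SE->getSCEV(MC->getRawSource());
  const SCEV *Dst = SE->getSCEV(MC->getRawDest());
  return !SE->isKnownPredicateAt(CmpInst::ICMP_NE, Src, Dst, MC);
}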
+ /* TargetTransformInfo */ TTI, + /* AtomicCpySize */ AtomicMemcpy->getElementSizeInBytes()); + } +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp index aff9d1311688..44aeb26fadf9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp @@ -119,25 +119,27 @@ raw_ostream &operator<<(raw_ostream &O, const CaseVector &C) { void FixPhis( BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, const unsigned NumMergedCases = std::numeric_limits<unsigned>::max()) { - for (BasicBlock::iterator I = SuccBB->begin(), - IE = SuccBB->getFirstNonPHI()->getIterator(); - I != IE; ++I) { - PHINode *PN = cast<PHINode>(I); + for (auto &I : SuccBB->phis()) { + PHINode *PN = cast<PHINode>(&I); - // Only update the first occurrence. + // Only update the first occurrence if NewBB exists. unsigned Idx = 0, E = PN->getNumIncomingValues(); unsigned LocalNumMergedCases = NumMergedCases; - for (; Idx != E; ++Idx) { + for (; Idx != E && NewBB; ++Idx) { if (PN->getIncomingBlock(Idx) == OrigBB) { PN->setIncomingBlock(Idx, NewBB); break; } } + // Skip the updated incoming block so that it will not be removed. + if (NewBB) + ++Idx; + // Remove additional occurrences coming from condensed cases and keep the // number of incoming values equal to the number of branches to SuccBB. SmallVector<unsigned, 8> Indices; - for (++Idx; LocalNumMergedCases > 0 && Idx < E; ++Idx) + for (; LocalNumMergedCases > 0 && Idx < E; ++Idx) if (PN->getIncomingBlock(Idx) == OrigBB) { Indices.push_back(Idx); LocalNumMergedCases--; @@ -195,6 +197,13 @@ BasicBlock *NewLeafBlock(CaseRange &Leaf, Value *Val, ConstantInt *LowerBound, BasicBlock *Succ = Leaf.BB; BranchInst::Create(Succ, Default, Comp, NewLeaf); + // Update the PHI incoming value/block for the default. + for (auto &I : Default->phis()) { + PHINode *PN = cast<PHINode>(&I); + auto *V = PN->getIncomingValueForBlock(OrigBlock); + PN->addIncoming(V, NewLeaf); + } + // If there were any PHI nodes in this successor, rewrite one entry // from OrigBlock to come from NewLeaf. for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { @@ -494,19 +503,17 @@ void ProcessSwitchInst(SwitchInst *SI, Val = SI->getCondition(); } - // Create a new, empty default block so that the new hierarchy of - // if-then statements go to this and the PHI nodes are happy. - BasicBlock *NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); - F->getBasicBlockList().insert(Default->getIterator(), NewDefault); - BranchInst::Create(Default, NewDefault); - BasicBlock *SwitchBlock = SwitchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, - OrigBlock, OrigBlock, NewDefault, UnreachableRanges); - - // If there are entries in any PHI nodes for the default edge, make sure - // to update them as well. - FixPhis(Default, OrigBlock, NewDefault); + OrigBlock, OrigBlock, Default, UnreachableRanges); + + // We have added incoming values for newly-created predecessors in + // NewLeafBlock(). The only meaningful work we offload to FixPhis() is to + // remove the incoming values from OrigBlock. There might be a special case + // that SwitchBlock is the same as Default, under which the PHIs in Default + // are fixed inside SwitchConvert(). + if (SwitchBlock != Default) + FixPhis(Default, OrigBlock, nullptr); // Branch to our shiny new if-then stuff... 
BranchInst::Create(SwitchBlock, OrigBlock); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp new file mode 100644 index 000000000000..a1029475cf1d --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp @@ -0,0 +1,195 @@ +//== MemoryTaggingSupport.cpp - helpers for memory tagging implementations ===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares common infrastructure for HWAddressSanitizer and +// Aarch64StackTagging. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/MemoryTaggingSupport.h" + +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/IntrinsicInst.h" + +namespace llvm { +namespace memtag { +namespace { +bool maybeReachableFromEachOther(const SmallVectorImpl<IntrinsicInst *> &Insts, + const DominatorTree *DT, const LoopInfo *LI, + size_t MaxLifetimes) { + // If we have too many lifetime ends, give up, as the algorithm below is N^2. + if (Insts.size() > MaxLifetimes) + return true; + for (size_t I = 0; I < Insts.size(); ++I) { + for (size_t J = 0; J < Insts.size(); ++J) { + if (I == J) + continue; + if (isPotentiallyReachable(Insts[I], Insts[J], nullptr, DT, LI)) + return true; + } + } + return false; +} +} // namespace + +bool forAllReachableExits(const DominatorTree &DT, const PostDominatorTree &PDT, + const LoopInfo &LI, const Instruction *Start, + const SmallVectorImpl<IntrinsicInst *> &Ends, + const SmallVectorImpl<Instruction *> &RetVec, + llvm::function_ref<void(Instruction *)> Callback) { + if (Ends.size() == 1 && PDT.dominates(Ends[0], Start)) { + Callback(Ends[0]); + return true; + } + SmallPtrSet<BasicBlock *, 2> EndBlocks; + for (auto *End : Ends) { + EndBlocks.insert(End->getParent()); + } + SmallVector<Instruction *, 8> ReachableRetVec; + unsigned NumCoveredExits = 0; + for (auto *RI : RetVec) { + if (!isPotentiallyReachable(Start, RI, nullptr, &DT, &LI)) + continue; + ReachableRetVec.push_back(RI); + // If there is an end in the same basic block as the return, we know for + // sure that the return is covered. Otherwise, we can check whether there + // is a way to reach the RI from the start of the lifetime without passing + // through an end. + if (EndBlocks.count(RI->getParent()) > 0 || + !isPotentiallyReachable(Start, RI, &EndBlocks, &DT, &LI)) { + ++NumCoveredExits; + } + } + // If there's a mix of covered and non-covered exits, just put the untag + // on exits, so we avoid the redundancy of untagging twice. + if (NumCoveredExits == ReachableRetVec.size()) { + for (auto *End : Ends) + Callback(End); + } else { + for (auto *RI : ReachableRetVec) + Callback(RI); + // We may have inserted untag outside of the lifetime interval. + // Signal the caller to remove the lifetime end call for this alloca. 
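For the memory-tagging helper above: forAllReachableExits returns true when the untag callback could be placed on the lifetime.end markers and false when it had to fall back to reachable function exits, in which case the caller is expected to drop the lifetime ends for that alloca. A hedged sketch of such a caller follows; the analyses, the tag point, and the untagging lambda body are assumptions for illustration.

#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Transforms/Utils/MemoryTaggingSupport.h"

// Returns true if the untag callback was placed on the lifetime.end markers;
// false means it was placed on reachable function exits instead, and the
// caller should erase the lifetime.end calls for this alloca.
static bool emitUntags(const llvm::DominatorTree &DT,
                       const llvm::PostDominatorTree &PDT,
                       const llvm::LoopInfo &LI, llvm::Instruction *TagPoint,
                       const llvm::SmallVectorImpl<llvm::IntrinsicInst *> &Ends,
                       const llvm::SmallVectorImpl<llvm::Instruction *> &Rets) {
  return llvm::memtag::forAllReachableExits(
      DT, PDT, LI, TagPoint, Ends, Rets, [&](llvm::Instruction *InsertPt) {
        // untagAlloca(AI, InsertPt); // pass-specific helper, assumed here
        (void)InsertPt;
      });
}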
+ return false; + } + return true; +} + +bool isStandardLifetime(const SmallVectorImpl<IntrinsicInst *> &LifetimeStart, + const SmallVectorImpl<IntrinsicInst *> &LifetimeEnd, + const DominatorTree *DT, const LoopInfo *LI, + size_t MaxLifetimes) { + // An alloca that has exactly one start and end in every possible execution. + // If it has multiple ends, they have to be unreachable from each other, so + // at most one of them is actually used for each execution of the function. + return LifetimeStart.size() == 1 && + (LifetimeEnd.size() == 1 || + (LifetimeEnd.size() > 0 && + !maybeReachableFromEachOther(LifetimeEnd, DT, LI, MaxLifetimes))); +} + +Instruction *getUntagLocationIfFunctionExit(Instruction &Inst) { + if (isa<ReturnInst>(Inst)) { + if (CallInst *CI = Inst.getParent()->getTerminatingMustTailCall()) + return CI; + return &Inst; + } + if (isa<ResumeInst, CleanupReturnInst>(Inst)) { + return &Inst; + } + return nullptr; +} + +void StackInfoBuilder::visit(Instruction &Inst) { + if (CallInst *CI = dyn_cast<CallInst>(&Inst)) { + if (CI->canReturnTwice()) { + Info.CallsReturnTwice = true; + } + } + if (AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) { + if (IsInterestingAlloca(*AI)) { + Info.AllocasToInstrument[AI].AI = AI; + } + return; + } + auto *II = dyn_cast<IntrinsicInst>(&Inst); + if (II && (II->getIntrinsicID() == Intrinsic::lifetime_start || + II->getIntrinsicID() == Intrinsic::lifetime_end)) { + AllocaInst *AI = findAllocaForValue(II->getArgOperand(1)); + if (!AI) { + Info.UnrecognizedLifetimes.push_back(&Inst); + return; + } + if (!IsInterestingAlloca(*AI)) + return; + if (II->getIntrinsicID() == Intrinsic::lifetime_start) + Info.AllocasToInstrument[AI].LifetimeStart.push_back(II); + else + Info.AllocasToInstrument[AI].LifetimeEnd.push_back(II); + return; + } + if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&Inst)) { + for (Value *V : DVI->location_ops()) { + if (auto *AI = dyn_cast_or_null<AllocaInst>(V)) { + if (!IsInterestingAlloca(*AI)) + continue; + AllocaInfo &AInfo = Info.AllocasToInstrument[AI]; + auto &DVIVec = AInfo.DbgVariableIntrinsics; + if (DVIVec.empty() || DVIVec.back() != DVI) + DVIVec.push_back(DVI); + } + } + } + Instruction *ExitUntag = getUntagLocationIfFunctionExit(Inst); + if (ExitUntag) + Info.RetVec.push_back(ExitUntag); +} + +uint64_t getAllocaSizeInBytes(const AllocaInst &AI) { + auto DL = AI.getModule()->getDataLayout(); + return *AI.getAllocationSizeInBits(DL) / 8; +} + +void alignAndPadAlloca(memtag::AllocaInfo &Info, llvm::Align Alignment) { + const Align NewAlignment = std::max(Info.AI->getAlign(), Alignment); + Info.AI->setAlignment(NewAlignment); + auto &Ctx = Info.AI->getFunction()->getContext(); + + uint64_t Size = getAllocaSizeInBytes(*Info.AI); + uint64_t AlignedSize = alignTo(Size, Alignment); + if (Size == AlignedSize) + return; + + // Add padding to the alloca. + Type *AllocatedType = + Info.AI->isArrayAllocation() + ? 
ArrayType::get( + Info.AI->getAllocatedType(), + cast<ConstantInt>(Info.AI->getArraySize())->getZExtValue()) + : Info.AI->getAllocatedType(); + Type *PaddingType = ArrayType::get(Type::getInt8Ty(Ctx), AlignedSize - Size); + Type *TypeWithPadding = StructType::get(AllocatedType, PaddingType); + auto *NewAI = + new AllocaInst(TypeWithPadding, Info.AI->getType()->getAddressSpace(), + nullptr, "", Info.AI); + NewAI->takeName(Info.AI); + NewAI->setAlignment(Info.AI->getAlign()); + NewAI->setUsedWithInAlloca(Info.AI->isUsedWithInAlloca()); + NewAI->setSwiftError(Info.AI->isSwiftError()); + NewAI->copyMetadata(*Info.AI); + + auto *NewPtr = new BitCastInst(NewAI, Info.AI->getType(), "", Info.AI); + Info.AI->replaceAllUsesWith(NewPtr); + Info.AI->eraseFromParent(); + Info.AI = NewAI; +} + +} // namespace memtag +} // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/MisExpect.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/MisExpect.cpp new file mode 100644 index 000000000000..b73d68ebec7c --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/MisExpect.cpp @@ -0,0 +1,249 @@ +//===--- MisExpect.cpp - Check the use of llvm.expect with PGO data -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This contains code to emit warnings for potentially incorrect usage of the +// llvm.expect intrinsic. This utility extracts the threshold values from +// metadata associated with the instrumented Branch or Switch instruction. The +// threshold values are then used to determine if a warning should be emmited. +// +// MisExpect's implementation relies on two assumptions about how branch weights +// are managed in LLVM. +// +// 1) Frontend profiling weights are always in place before llvm.expect is +// lowered in LowerExpectIntrinsic.cpp. Frontend based instrumentation therefore +// needs to extract the branch weights and then compare them to the weights +// being added by the llvm.expect intrinsic lowering. +// +// 2) Sampling and IR based profiles will *only* have branch weight metadata +// before profiling data is consulted if they are from a lowered llvm.expect +// intrinsic. These profiles thus always extract the expected weights and then +// compare them to the weights collected during profiling to determine if a +// diagnostic message is warranted. 
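These checks are reached through the entry points defined further down in this file (checkFrontendInstrumentation, checkBackendInstrumentation, and the checkExpectAnnotations dispatcher). A pass that has just computed real branch weights from a profile might call into them roughly as sketched below; the surrounding pass context and names are assumptions for illustration.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Transforms/Utils/MisExpect.h"
#include <cstdint>

// Hypothetical call site in a pass that is about to attach branch_weights
// metadata computed from a profile. The check compares those weights with
// the ones implied by llvm.expect (or the reverse, for frontend PGO) and
// raises the misexpect diagnostic/remark when they disagree beyond the
// configured tolerance.
static void checkAgainstExpectHints(llvm::Instruction &Term,
                                    llvm::ArrayRef<uint32_t> ProfileWeights,
                                    bool IsFrontendInstrumented) {
  llvm::misexpect::checkExpectAnnotations(Term, ProfileWeights,
                                          IsFrontendInstrumented);
}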
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/MisExpect.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include <cstdint> +#include <functional> +#include <numeric> + +#define DEBUG_TYPE "misexpect" + +using namespace llvm; +using namespace misexpect; + +namespace llvm { + +// Command line option to enable/disable the warning when profile data suggests +// a mismatch with the use of the llvm.expect intrinsic +static cl::opt<bool> PGOWarnMisExpect( + "pgo-warn-misexpect", cl::init(false), cl::Hidden, + cl::desc("Use this option to turn on/off " + "warnings about incorrect usage of llvm.expect intrinsics.")); + +static cl::opt<unsigned> MisExpectTolerance( + "misexpect-tolerance", cl::init(0), + cl::desc("Prevents emiting diagnostics when profile counts are " + "within N% of the threshold..")); + +} // namespace llvm + +namespace { + +bool isMisExpectDiagEnabled(LLVMContext &Ctx) { + return PGOWarnMisExpect || Ctx.getMisExpectWarningRequested(); +} + +uint64_t getMisExpectTolerance(LLVMContext &Ctx) { + return std::max(static_cast<uint64_t>(MisExpectTolerance), + Ctx.getDiagnosticsMisExpectTolerance()); +} + +Instruction *getInstCondition(Instruction *I) { + assert(I != nullptr && "MisExpect target Instruction cannot be nullptr"); + Instruction *Ret = nullptr; + if (auto *B = dyn_cast<BranchInst>(I)) { + Ret = dyn_cast<Instruction>(B->getCondition()); + } + // TODO: Find a way to resolve condition location for switches + // Using the condition of the switch seems to often resolve to an earlier + // point in the program, i.e. the calculation of the switch condition, rather + // than the switch's location in the source code. Thus, we should use the + // instruction to get source code locations rather than the condition to + // improve diagnostic output, such as the caret. If the same problem exists + // for branch instructions, then we should remove this function and directly + // use the instruction + // + else if (auto *S = dyn_cast<SwitchInst>(I)) { + Ret = dyn_cast<Instruction>(S->getCondition()); + } + return Ret ? 
Ret : I; +} + +void emitMisexpectDiagnostic(Instruction *I, LLVMContext &Ctx, + uint64_t ProfCount, uint64_t TotalCount) { + double PercentageCorrect = (double)ProfCount / TotalCount; + auto PerString = + formatv("{0:P} ({1} / {2})", PercentageCorrect, ProfCount, TotalCount); + auto RemStr = formatv( + "Potential performance regression from use of the llvm.expect intrinsic: " + "Annotation was correct on {0} of profiled executions.", + PerString); + Twine Msg(PerString); + Instruction *Cond = getInstCondition(I); + if (isMisExpectDiagEnabled(Ctx)) + Ctx.diagnose(DiagnosticInfoMisExpect(Cond, Msg)); + OptimizationRemarkEmitter ORE(I->getParent()->getParent()); + ORE.emit(OptimizationRemark(DEBUG_TYPE, "misexpect", Cond) << RemStr.str()); +} + +} // namespace + +namespace llvm { +namespace misexpect { + +// Helper function to extract branch weights into a vector +Optional<SmallVector<uint32_t, 4>> extractWeights(Instruction *I, + LLVMContext &Ctx) { + assert(I && "MisExpect::extractWeights given invalid pointer"); + + auto *ProfileData = I->getMetadata(LLVMContext::MD_prof); + if (!ProfileData) + return None; + + unsigned NOps = ProfileData->getNumOperands(); + if (NOps < 3) + return None; + + auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)); + if (!ProfDataName || !ProfDataName->getString().equals("branch_weights")) + return None; + + SmallVector<uint32_t, 4> Weights(NOps - 1); + for (unsigned Idx = 1; Idx < NOps; Idx++) { + ConstantInt *Value = + mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx)); + uint32_t V = Value->getZExtValue(); + Weights[Idx - 1] = V; + } + + return Weights; +} + +// TODO: when clang allows c++17, use std::clamp instead +uint32_t clamp(uint64_t value, uint32_t low, uint32_t hi) { + if (value > hi) + return hi; + if (value < low) + return low; + return value; +} + +void verifyMisExpect(Instruction &I, ArrayRef<uint32_t> RealWeights, + ArrayRef<uint32_t> ExpectedWeights) { + // To determine if we emit a diagnostic, we need to compare the branch weights + // from the profile to those added by the llvm.expect intrinsic. + // So first, we extract the "likely" and "unlikely" weights from + // ExpectedWeights And determine the correct weight in the profile to compare + // against. + uint64_t LikelyBranchWeight = 0, + UnlikelyBranchWeight = std::numeric_limits<uint32_t>::max(); + size_t MaxIndex = 0; + for (size_t Idx = 0, End = ExpectedWeights.size(); Idx < End; Idx++) { + uint32_t V = ExpectedWeights[Idx]; + if (LikelyBranchWeight < V) { + LikelyBranchWeight = V; + MaxIndex = Idx; + } + if (UnlikelyBranchWeight > V) { + UnlikelyBranchWeight = V; + } + } + + const uint64_t ProfiledWeight = RealWeights[MaxIndex]; + const uint64_t RealWeightsTotal = + std::accumulate(RealWeights.begin(), RealWeights.end(), (uint64_t)0, + std::plus<uint64_t>()); + const uint64_t NumUnlikelyTargets = RealWeights.size() - 1; + + uint64_t TotalBranchWeight = + LikelyBranchWeight + (UnlikelyBranchWeight * NumUnlikelyTargets); + + // FIXME: When we've addressed sample profiling, restore the assertion + // + // We cannot calculate branch probability if either of these invariants aren't + // met. However, MisExpect diagnostics should not prevent code from compiling, + // so we simply forgo emitting diagnostics here, and return early. 
+ if ((TotalBranchWeight == 0) || (TotalBranchWeight <= LikelyBranchWeight)) + return; + + // To determine our threshold value we need to obtain the branch probability + // for the weights added by llvm.expect and use that proportion to calculate + // our threshold based on the collected profile data. + auto LikelyProbablilty = BranchProbability::getBranchProbability( + LikelyBranchWeight, TotalBranchWeight); + + uint64_t ScaledThreshold = LikelyProbablilty.scale(RealWeightsTotal); + + // clamp tolerance range to [0, 100) + auto Tolerance = getMisExpectTolerance(I.getContext()); + Tolerance = clamp(Tolerance, 0, 99); + + // Allow users to relax checking by N% i.e., if they use a 5% tolerance, + // then we check against 0.95*ScaledThreshold + if (Tolerance > 0) + ScaledThreshold *= (1.0 - Tolerance / 100.0); + + // When the profile weight is below the threshold, we emit the diagnostic + if (ProfiledWeight < ScaledThreshold) + emitMisexpectDiagnostic(&I, I.getContext(), ProfiledWeight, + RealWeightsTotal); +} + +void checkBackendInstrumentation(Instruction &I, + const ArrayRef<uint32_t> RealWeights) { + auto ExpectedWeightsOpt = extractWeights(&I, I.getContext()); + if (!ExpectedWeightsOpt) + return; + auto ExpectedWeights = ExpectedWeightsOpt.getValue(); + verifyMisExpect(I, RealWeights, ExpectedWeights); +} + +void checkFrontendInstrumentation(Instruction &I, + const ArrayRef<uint32_t> ExpectedWeights) { + auto RealWeightsOpt = extractWeights(&I, I.getContext()); + if (!RealWeightsOpt) + return; + auto RealWeights = RealWeightsOpt.getValue(); + verifyMisExpect(I, RealWeights, ExpectedWeights); +} + +void checkExpectAnnotations(Instruction &I, + const ArrayRef<uint32_t> ExistingWeights, + bool IsFrontendInstr) { + if (IsFrontendInstr) { + checkFrontendInstrumentation(I, ExistingWeights); + } else { + checkBackendInstrumentation(I, ExistingWeights); + } +} + +} // namespace misexpect +} // namespace llvm +#undef DEBUG_TYPE diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/ModuleUtils.cpp index d6a6be2762c7..5120ade70e16 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/ModuleUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/ModuleUtils.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/ModuleUtils.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -237,8 +236,8 @@ std::string llvm::getUniqueModuleId(Module *M) { return ("." 
+ Str).str(); } -void VFABI::setVectorVariantNames( - CallInst *CI, const SmallVector<std::string, 8> &VariantMappings) { +void VFABI::setVectorVariantNames(CallInst *CI, + ArrayRef<std::string> VariantMappings) { if (VariantMappings.empty()) return; @@ -255,7 +254,7 @@ void VFABI::setVectorVariantNames( for (const std::string &VariantMapping : VariantMappings) { LLVM_DEBUG(dbgs() << "VFABI: adding mapping '" << VariantMapping << "'\n"); Optional<VFInfo> VI = VFABI::tryDemangleForVFABI(VariantMapping, *M); - assert(VI.hasValue() && "Cannot add an invalid VFABI name."); + assert(VI && "Cannot add an invalid VFABI name."); assert(M->getNamedValue(VI.getValue().VectorName) && "Cannot add variant to attribute: " "vector function declaration is missing."); @@ -266,14 +265,15 @@ void VFABI::setVectorVariantNames( } void llvm::embedBufferInModule(Module &M, MemoryBufferRef Buf, - StringRef SectionName) { - // Embed the buffer into the module. + StringRef SectionName, Align Alignment) { + // Embed the memory buffer into the module. Constant *ModuleConstant = ConstantDataArray::get( M.getContext(), makeArrayRef(Buf.getBufferStart(), Buf.getBufferSize())); GlobalVariable *GV = new GlobalVariable( M, ModuleConstant->getType(), true, GlobalValue::PrivateLinkage, ModuleConstant, "llvm.embedded.object"); GV->setSection(SectionName); + GV->setAlignment(Alignment); appendToCompilerUsed(M, GV); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/PredicateInfo.cpp index bd2b6fafdf2e..53334bc2a369 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/PredicateInfo.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/PredicateInfo.cpp @@ -15,19 +15,12 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/CFG.h" #include "llvm/IR/AssemblyAnnotationWriter.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" @@ -35,7 +28,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/DebugCounter.h" #include "llvm/Support/FormattedStream.h" -#include "llvm/Transforms/Utils.h" #include <algorithm> #define DEBUG_TYPE "predicateinfo" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index 01b433b4782a..aff692b36288 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -20,7 +20,6 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/TinyPtrVector.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -32,7 +31,6 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/DIBuilder.h" #include "llvm/IR/DebugInfo.h" -#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" @@ -68,7 +66,7 @@ 
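The embedBufferInModule change above adds an explicit alignment for the emitted global. A caller might use it as in this sketch; the section name, buffer identifier, and alignment are placeholder values, not taken from this patch.

#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

// Hypothetical use: embed Data into its own section with 8-byte alignment.
static void embedBlob(llvm::Module &M, llvm::StringRef Data) {
  llvm::MemoryBufferRef Buf(Data, /*Identifier=*/"embedded.blob");
  llvm::embedBufferInModule(M, Buf, /*SectionName=*/".llvm.myblob",
                            llvm::Align(8));
}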
bool llvm::isAllocaPromotable(const AllocaInst *AI) { if (const LoadInst *LI = dyn_cast<LoadInst>(U)) { // Note that atomic loads can be transformed; atomic semantics do // not have any meaning for a local alloca. - if (LI->isVolatile()) + if (LI->isVolatile() || LI->getType() != AI->getAllocatedType()) return false; } else if (const StoreInst *SI = dyn_cast<StoreInst>(U)) { if (SI->getValueOperand() == AI || @@ -678,7 +676,7 @@ void PromoteMem2Reg::run() { A->eraseFromParent(); } - // Remove alloca's dbg.declare instrinsics from the function. + // Remove alloca's dbg.declare intrinsics from the function. for (auto &DbgUsers : AllocaDbgUsers) { for (auto *DII : DbgUsers) if (DII->isAddressOfVariable() || DII->getExpression()->startsWithDeref()) @@ -704,7 +702,7 @@ void PromoteMem2Reg::run() { PHINode *PN = I->second; // If this PHI node merges one value and/or undefs, get the value. - if (Value *V = SimplifyInstruction(PN, SQ)) { + if (Value *V = simplifyInstruction(PN, SQ)) { PN->replaceAllUsesWith(V); PN->eraseFromParent(); NewPhiNodes.erase(I++); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp index 65207056a3f4..926427450682 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp @@ -18,9 +18,6 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; @@ -38,11 +35,13 @@ static bool shouldConvertToRelLookupTable(Module &M, GlobalVariable &GV) { GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(GV.use_begin()->getUser()); - if (!GEP || !GEP->hasOneUse()) + if (!GEP || !GEP->hasOneUse() || + GV.getValueType() != GEP->getSourceElementType()) return false; LoadInst *Load = dyn_cast<LoadInst>(GEP->use_begin()->getUser()); - if (!Load || !Load->hasOneUse()) + if (!Load || !Load->hasOneUse() || + Load->getType() != GEP->getResultElementType()) return false; // If the original lookup table does not have local linkage and is @@ -144,7 +143,7 @@ static void convertToRelLookupTable(GlobalVariable &LookupTable) { Value *Offset = Builder.CreateShl(Index, ConstantInt::get(IntTy, 2), "reltable.shift"); - // Insert the call to load.relative instrinsic before LOAD. + // Insert the call to load.relative intrinsic before LOAD. // GEP might not be immediately followed by a LOAD, like it can be hoisted // outside the loop or another instruction might be inserted them in between. Builder.SetInsertPoint(Load); @@ -171,13 +170,17 @@ static void convertToRelLookupTable(GlobalVariable &LookupTable) { // Convert lookup tables to relative lookup tables in the module. static bool convertToRelativeLookupTables( Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) { - Module::iterator FI = M.begin(); - if (FI == M.end()) - return false; + for (Function &F : M) { + if (F.isDeclaration()) + continue; - // Check if we have a target that supports relative lookup tables. - if (!GetTTI(*FI).shouldBuildRelLookupTables()) - return false; + // Check if we have a target that supports relative lookup tables. + if (!GetTTI(F).shouldBuildRelLookupTables()) + return false; + + // We assume that the result is independent of the checked function. 
+ break; + } bool Changed = false; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SCCPSolver.cpp index d7e8eaf677c6..eee91e70292e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -15,14 +15,12 @@ #include "llvm/Transforms/Utils/SCCPSolver.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/ValueTracking.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" +#include "llvm/Analysis/ValueLattice.h" +#include "llvm/IR/InstVisitor.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <utility> #include <vector> @@ -452,7 +450,8 @@ public: return TrackingIncomingArguments; } - void markArgInFuncSpecialization(Function *F, Argument *A, Constant *C); + void markArgInFuncSpecialization(Function *F, + const SmallVectorImpl<ArgInfo> &Args); void markFunctionUnreachable(Function *F) { for (auto &BB : *F) @@ -526,29 +525,38 @@ Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV) const { return nullptr; } -void SCCPInstVisitor::markArgInFuncSpecialization(Function *F, Argument *A, - Constant *C) { - assert(F->arg_size() == A->getParent()->arg_size() && +void SCCPInstVisitor::markArgInFuncSpecialization( + Function *F, const SmallVectorImpl<ArgInfo> &Args) { + assert(!Args.empty() && "Specialization without arguments"); + assert(F->arg_size() == Args[0].Formal->getParent()->arg_size() && "Functions should have the same number of arguments"); - // Mark the argument constant in the new function. - markConstant(A, C); - - // For the remaining arguments in the new function, copy the lattice state - // over from the old function. - for (auto I = F->arg_begin(), J = A->getParent()->arg_begin(), - E = F->arg_end(); - I != E; ++I, ++J) - if (J != A && ValueState.count(I)) { + auto Iter = Args.begin(); + Argument *NewArg = F->arg_begin(); + Argument *OldArg = Args[0].Formal->getParent()->arg_begin(); + for (auto End = F->arg_end(); NewArg != End; ++NewArg, ++OldArg) { + + LLVM_DEBUG(dbgs() << "SCCP: Marking argument " + << NewArg->getNameOrAsOperand() << "\n"); + + if (Iter != Args.end() && OldArg == Iter->Formal) { + // Mark the argument constants in the new function. + markConstant(NewArg, Iter->Actual); + ++Iter; + } else if (ValueState.count(OldArg)) { + // For the remaining arguments in the new function, copy the lattice state + // over from the old function. + // // Note: This previously looked like this: - // ValueState[J] = ValueState[I]; + // ValueState[NewArg] = ValueState[OldArg]; // This is incorrect because the DenseMap class may resize the underlying - // memory when inserting `J`, which will invalidate the reference to `I`. - // Instead, we make sure `J` exists, then set it to `I` afterwards. - auto &NewValue = ValueState[J]; - NewValue = ValueState[I]; - pushToWorkList(NewValue, J); + // memory when inserting `NewArg`, which will invalidate the reference to + // `OldArg`. Instead, we make sure `NewArg` exists before setting it. 
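The comment above describes a general DenseMap hazard rather than anything solver-specific: two operator[] calls in one statement are separate accesses, and the one that inserts can rehash the table and invalidate the reference produced by the other. A standalone reduction, with placeholder key/value types:

#include "llvm/ADT/DenseMap.h"
#include <string>

// In `M[DstKey] = M[SrcKey];` the operator[] that inserts DstKey may rehash
// the table and leave the reference to the source element dangling. Taking
// the destination reference first and then re-reading the source avoids it.
void copyEntry(llvm::DenseMap<int, std::string> &M, int SrcKey, int DstKey) {
  std::string &Dst = M[DstKey]; // insert/locate the destination slot first
  Dst = M.lookup(SrcKey);       // then read the source with a fresh lookup
}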
+ auto &NewValue = ValueState[NewArg]; + NewValue = ValueState[OldArg]; + pushToWorkList(NewValue, NewArg); } + } } void SCCPInstVisitor::visitInstruction(Instruction &I) { @@ -988,7 +996,7 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) { if ((V1State.isConstant() || V2State.isConstant())) { Value *V1 = isConstant(V1State) ? getConstant(V1State) : I.getOperand(0); Value *V2 = isConstant(V2State) ? getConstant(V2State) : I.getOperand(1); - Value *R = SimplifyBinOp(I.getOpcode(), V1, V2, SimplifyQuery(DL)); + Value *R = simplifyBinOp(I.getOpcode(), V1, V2, SimplifyQuery(DL)); auto *C = dyn_cast_or_null<Constant>(R); if (C) { // X op Y -> undef. @@ -1287,17 +1295,6 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { return; } - // TODO: Actually filp MayIncludeUndef for the created range to false, - // once most places in the optimizer respect the branches on - // undef/poison are UB rule. The reason why the new range cannot be - // undef is as follows below: - // The new range is based on a branch condition. That guarantees that - // neither of the compare operands can be undef in the branch targets, - // unless we have conditions that are always true/false (e.g. icmp ule - // i32, %a, i32_max). For the latter overdefined/empty range will be - // inferred, but the branch will get folded accordingly anyways. - bool MayIncludeUndef = !isa<PredicateAssume>(PI); - ValueLatticeElement CondVal = getValueState(OtherOp); ValueLatticeElement &IV = ValueState[&CB]; if (CondVal.isConstantRange() || CopyOfVal.isConstantRange()) { @@ -1322,9 +1319,15 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { if (!CopyOfCR.contains(NewCR) && CopyOfCR.getSingleMissingElement()) NewCR = CopyOfCR; + // The new range is based on a branch condition. That guarantees that + // neither of the compare operands can be undef in the branch targets, + // unless we have conditions that are always true/false (e.g. icmp ule + // i32, %a, i32_max). For the latter overdefined/empty range will be + // inferred, but the branch will get folded accordingly anyways. addAdditionalUser(OtherOp, &CB); - mergeInValue(IV, &CB, - ValueLatticeElement::getRange(NewCR, MayIncludeUndef)); + mergeInValue( + IV, &CB, + ValueLatticeElement::getRange(NewCR, /*MayIncludeUndef*/ false)); return; } else if (Pred == CmpInst::ICMP_EQ && CondVal.isConstant()) { // For non-integer values or integer constant expressions, only @@ -1332,8 +1335,7 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { addAdditionalUser(OtherOp, &CB); mergeInValue(IV, &CB, CondVal); return; - } else if (Pred == CmpInst::ICMP_NE && CondVal.isConstant() && - !MayIncludeUndef) { + } else if (Pred == CmpInst::ICMP_NE && CondVal.isConstant()) { // Propagate inequalities. addAdditionalUser(OtherOp, &CB); mergeInValue(IV, &CB, @@ -1442,22 +1444,19 @@ void SCCPInstVisitor::solve() { } } -/// resolvedUndefsIn - While solving the dataflow for a function, we assume -/// that branches on undef values cannot reach any of their successors. -/// However, this is not a safe assumption. After we solve dataflow, this -/// method should be use to handle this. If this returns true, the solver -/// should be rerun. +/// While solving the dataflow for a function, we don't compute a result for +/// operations with an undef operand, to allow undef to be lowered to a +/// constant later. 
For example, constant folding of "zext i8 undef to i16" +/// would result in "i16 0", and if undef is later lowered to "i8 1", then the +/// zext result would become "i16 1" and would result into an overdefined +/// lattice value once merged with the previous result. Not computing the +/// result of the zext (treating undef the same as unknown) allows us to handle +/// a later undef->constant lowering more optimally. /// -/// This method handles this by finding an unresolved branch and marking it one -/// of the edges from the block as being feasible, even though the condition -/// doesn't say it would otherwise be. This allows SCCP to find the rest of the -/// CFG and only slightly pessimizes the analysis results (by marking one, -/// potentially infeasible, edge feasible). This cannot usefully modify the -/// constraints on the condition of the branch, as that would impact other users -/// of the value. -/// -/// This scan also checks for values that use undefs. It conservatively marks -/// them as overdefined. +/// However, if the operand remains undef when the solver returns, we do need +/// to assign some result to the instruction (otherwise we would treat it as +/// unreachable). For simplicity, we mark any instructions that are still +/// unknown as overdefined. bool SCCPInstVisitor::resolvedUndefsIn(Function &F) { bool MadeChange = false; for (BasicBlock &BB : F) { @@ -1486,7 +1485,7 @@ bool SCCPInstVisitor::resolvedUndefsIn(Function &F) { // more precise than this but it isn't worth bothering. for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { ValueLatticeElement &LV = getStructValueState(&I, i); - if (LV.isUnknownOrUndef()) { + if (LV.isUnknown()) { markOverdefined(LV, &I); MadeChange = true; } @@ -1495,7 +1494,7 @@ bool SCCPInstVisitor::resolvedUndefsIn(Function &F) { } ValueLatticeElement &LV = getValueState(&I); - if (!LV.isUnknownOrUndef()) + if (!LV.isUnknown()) continue; // There are two reasons a call can have an undef result @@ -1518,91 +1517,6 @@ bool SCCPInstVisitor::resolvedUndefsIn(Function &F) { markOverdefined(&I); MadeChange = true; } - - // Check to see if we have a branch or switch on an undefined value. If so - // we force the branch to go one way or the other to make the successor - // values live. It doesn't really matter which way we force it. - Instruction *TI = BB.getTerminator(); - if (auto *BI = dyn_cast<BranchInst>(TI)) { - if (!BI->isConditional()) - continue; - if (!getValueState(BI->getCondition()).isUnknownOrUndef()) - continue; - - // If the input to SCCP is actually branch on undef, fix the undef to - // false. - if (isa<UndefValue>(BI->getCondition())) { - BI->setCondition(ConstantInt::getFalse(BI->getContext())); - markEdgeExecutable(&BB, TI->getSuccessor(1)); - MadeChange = true; - continue; - } - - // Otherwise, it is a branch on a symbolic value which is currently - // considered to be undef. Make sure some edge is executable, so a - // branch on "undef" always flows somewhere. - // FIXME: Distinguish between dead code and an LLVM "undef" value. - BasicBlock *DefaultSuccessor = TI->getSuccessor(1); - if (markEdgeExecutable(&BB, DefaultSuccessor)) - MadeChange = true; - - continue; - } - - if (auto *IBR = dyn_cast<IndirectBrInst>(TI)) { - // Indirect branch with no successor ?. Its ok to assume it branches - // to no target. 
- if (IBR->getNumSuccessors() < 1) - continue; - - if (!getValueState(IBR->getAddress()).isUnknownOrUndef()) - continue; - - // If the input to SCCP is actually branch on undef, fix the undef to - // the first successor of the indirect branch. - if (isa<UndefValue>(IBR->getAddress())) { - IBR->setAddress(BlockAddress::get(IBR->getSuccessor(0))); - markEdgeExecutable(&BB, IBR->getSuccessor(0)); - MadeChange = true; - continue; - } - - // Otherwise, it is a branch on a symbolic value which is currently - // considered to be undef. Make sure some edge is executable, so a - // branch on "undef" always flows somewhere. - // FIXME: IndirectBr on "undef" doesn't actually need to go anywhere: - // we can assume the branch has undefined behavior instead. - BasicBlock *DefaultSuccessor = IBR->getSuccessor(0); - if (markEdgeExecutable(&BB, DefaultSuccessor)) - MadeChange = true; - - continue; - } - - if (auto *SI = dyn_cast<SwitchInst>(TI)) { - if (!SI->getNumCases() || - !getValueState(SI->getCondition()).isUnknownOrUndef()) - continue; - - // If the input to SCCP is actually switch on undef, fix the undef to - // the first constant. - if (isa<UndefValue>(SI->getCondition())) { - SI->setCondition(SI->case_begin()->getCaseValue()); - markEdgeExecutable(&BB, SI->case_begin()->getCaseSuccessor()); - MadeChange = true; - continue; - } - - // Otherwise, it is a branch on a symbolic value which is currently - // considered to be undef. Make sure some edge is executable, so a - // branch on "undef" always flows somewhere. - // FIXME: Distinguish between dead code and an LLVM "undef" value. - BasicBlock *DefaultSuccessor = SI->case_begin()->getCaseSuccessor(); - if (markEdgeExecutable(&BB, DefaultSuccessor)) - MadeChange = true; - - continue; - } } return MadeChange; @@ -1618,7 +1532,7 @@ SCCPSolver::SCCPSolver( LLVMContext &Ctx) : Visitor(new SCCPInstVisitor(DL, std::move(GetTLI), Ctx)) {} -SCCPSolver::~SCCPSolver() {} +SCCPSolver::~SCCPSolver() = default; void SCCPSolver::addAnalysis(Function &F, AnalysisResultsForFn A) { return Visitor->addAnalysis(F, std::move(A)); @@ -1713,9 +1627,9 @@ SmallPtrSetImpl<Function *> &SCCPSolver::getArgumentTrackedFunctions() { return Visitor->getArgumentTrackedFunctions(); } -void SCCPSolver::markArgInFuncSpecialization(Function *F, Argument *A, - Constant *C) { - Visitor->markArgInFuncSpecialization(F, A, C); +void SCCPSolver::markArgInFuncSpecialization( + Function *F, const SmallVectorImpl<ArgInfo> &Args) { + Visitor->markArgInFuncSpecialization(F, Args); } void SCCPSolver::markFunctionUnreachable(Function *F) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SSAUpdater.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SSAUpdater.cpp index 7d9992176658..37019e3bf95b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SSAUpdater.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SSAUpdater.cpp @@ -25,7 +25,6 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" -#include "llvm/IR/ValueHandle.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -166,7 +165,7 @@ Value *SSAUpdater::GetValueInMiddleOfBlock(BasicBlock *BB) { // See if the PHI node can be merged to a single value. This can happen in // loop cases when we get a PHI of itself and one other value. 
if (Value *V = - SimplifyInstruction(InsertedPHI, BB->getModule()->getDataLayout())) { + simplifyInstruction(InsertedPHI, BB->getModule()->getDataLayout())) { InsertedPHI->eraseFromParent(); return V; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SampleProfileInference.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SampleProfileInference.cpp index 961adf2570a7..5e92b9852a9f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SampleProfileInference.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SampleProfileInference.cpp @@ -15,15 +15,46 @@ #include "llvm/Transforms/Utils/SampleProfileInference.h" #include "llvm/ADT/BitVector.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include <queue> #include <set> +#include <stack> using namespace llvm; #define DEBUG_TYPE "sample-profile-inference" namespace { +static cl::opt<bool> SampleProfileEvenCountDistribution( + "sample-profile-even-count-distribution", cl::init(true), cl::Hidden, + cl::desc("Try to evenly distribute counts when there are multiple equally " + "likely options.")); + +static cl::opt<unsigned> SampleProfileMaxDfsCalls( + "sample-profile-max-dfs-calls", cl::init(10), cl::Hidden, + cl::desc("Maximum number of dfs iterations for even count distribution.")); + +static cl::opt<unsigned> SampleProfileProfiCostInc( + "sample-profile-profi-cost-inc", cl::init(10), cl::Hidden, + cl::desc("A cost of increasing a block's count by one.")); + +static cl::opt<unsigned> SampleProfileProfiCostDec( + "sample-profile-profi-cost-dec", cl::init(20), cl::Hidden, + cl::desc("A cost of decreasing a block's count by one.")); + +static cl::opt<unsigned> SampleProfileProfiCostIncZero( + "sample-profile-profi-cost-inc-zero", cl::init(11), cl::Hidden, + cl::desc("A cost of increasing a count of zero-weight block by one.")); + +static cl::opt<unsigned> SampleProfileProfiCostIncEntry( + "sample-profile-profi-cost-inc-entry", cl::init(40), cl::Hidden, + cl::desc("A cost of increasing the entry block's count by one.")); + +static cl::opt<unsigned> SampleProfileProfiCostDecEntry( + "sample-profile-profi-cost-dec-entry", cl::init(10), cl::Hidden, + cl::desc("A cost of decreasing the entry block's count by one.")); + /// A value indicating an infinite flow/capacity/weight of a block/edge. /// Not using numeric_limits<int64_t>::max(), as the values can be summed up /// during the execution. @@ -52,16 +83,16 @@ public: Nodes = std::vector<Node>(NodeCount); Edges = std::vector<std::vector<Edge>>(NodeCount, std::vector<Edge>()); + if (SampleProfileEvenCountDistribution) + AugmentingEdges = + std::vector<std::vector<Edge *>>(NodeCount, std::vector<Edge *>()); } // Run the algorithm. int64_t run() { - // Find an augmenting path and update the flow along the path - size_t AugmentationIters = 0; - while (findAugmentingPath()) { - augmentFlowAlongPath(); - AugmentationIters++; - } + // Iteratively find an augmentation path/dag in the network and send the + // flow along its edges + size_t AugmentationIters = applyFlowAugmentation(); // Compute the total flow and its cost int64_t TotalCost = 0; @@ -79,6 +110,7 @@ public: << " iterations with " << TotalFlow << " total flow" << " of " << TotalCost << " cost\n"); (void)TotalFlow; + (void)AugmentationIters; return TotalCost; } @@ -134,20 +166,61 @@ public: return Flow; } - /// A cost of increasing a block's count by one. - static constexpr int64_t AuxCostInc = 10; - /// A cost of decreasing a block's count by one. 
- static constexpr int64_t AuxCostDec = 20; - /// A cost of increasing a count of zero-weight block by one. - static constexpr int64_t AuxCostIncZero = 11; - /// A cost of increasing the entry block's count by one. - static constexpr int64_t AuxCostIncEntry = 40; - /// A cost of decreasing the entry block's count by one. - static constexpr int64_t AuxCostDecEntry = 10; /// A cost of taking an unlikely jump. static constexpr int64_t AuxCostUnlikely = ((int64_t)1) << 30; + /// Minimum BaseDistance for the jump distance values in island joining. + static constexpr uint64_t MinBaseDistance = 10000; private: + /// Iteratively find an augmentation path/dag in the network and send the + /// flow along its edges. The method returns the number of applied iterations. + size_t applyFlowAugmentation() { + size_t AugmentationIters = 0; + while (findAugmentingPath()) { + uint64_t PathCapacity = computeAugmentingPathCapacity(); + while (PathCapacity > 0) { + bool Progress = false; + if (SampleProfileEvenCountDistribution) { + // Identify node/edge candidates for augmentation + identifyShortestEdges(PathCapacity); + + // Find an augmenting DAG + auto AugmentingOrder = findAugmentingDAG(); + + // Apply the DAG augmentation + Progress = augmentFlowAlongDAG(AugmentingOrder); + PathCapacity = computeAugmentingPathCapacity(); + } + + if (!Progress) { + augmentFlowAlongPath(PathCapacity); + PathCapacity = 0; + } + + AugmentationIters++; + } + } + return AugmentationIters; + } + + /// Compute the capacity of the cannonical augmenting path. If the path is + /// saturated (that is, no flow can be sent along the path), then return 0. + uint64_t computeAugmentingPathCapacity() { + uint64_t PathCapacity = INF; + uint64_t Now = Target; + while (Now != Source) { + uint64_t Pred = Nodes[Now].ParentNode; + auto &Edge = Edges[Pred][Nodes[Now].ParentEdgeIndex]; + + assert(Edge.Capacity >= Edge.Flow && "incorrect edge flow"); + uint64_t EdgeCapacity = uint64_t(Edge.Capacity - Edge.Flow); + PathCapacity = std::min(PathCapacity, EdgeCapacity); + + Now = Pred; + } + return PathCapacity; + } + /// Check for existence of an augmenting path with a positive capacity. bool findAugmentingPath() { // Initialize data structures @@ -180,7 +253,7 @@ private: // from Source to Target; it follows from inequalities // Dist[Source, Target] >= Dist[Source, V] + Dist[V, Target] // >= Dist[Source, V] - if (Nodes[Target].Distance == 0) + if (!SampleProfileEvenCountDistribution && Nodes[Target].Distance == 0) break; if (Nodes[Src].Distance > Nodes[Target].Distance) continue; @@ -210,21 +283,9 @@ private: } /// Update the current flow along the augmenting path. - void augmentFlowAlongPath() { - // Find path capacity - int64_t PathCapacity = INF; - uint64_t Now = Target; - while (Now != Source) { - uint64_t Pred = Nodes[Now].ParentNode; - auto &Edge = Edges[Pred][Nodes[Now].ParentEdgeIndex]; - PathCapacity = std::min(PathCapacity, Edge.Capacity - Edge.Flow); - Now = Pred; - } - + void augmentFlowAlongPath(uint64_t PathCapacity) { assert(PathCapacity > 0 && "found an incorrect augmenting path"); - - // Update the flow along the path - Now = Target; + uint64_t Now = Target; while (Now != Source) { uint64_t Pred = Nodes[Now].ParentNode; auto &Edge = Edges[Pred][Nodes[Now].ParentEdgeIndex]; @@ -237,6 +298,220 @@ private: } } + /// Find an Augmenting DAG order using a modified version of DFS in which we + /// can visit a node multiple times. 
In the DFS search, when scanning each + /// edge out of a node, continue search at Edge.Dst endpoint if it has not + /// been discovered yet and its NumCalls < MaxDfsCalls. The algorithm + /// runs in O(MaxDfsCalls * |Edges| + |Nodes|) time. + /// It returns an Augmenting Order (Taken nodes in decreasing Finish time) + /// that starts with Source and ends with Target. + std::vector<uint64_t> findAugmentingDAG() { + // We use a stack based implemenation of DFS to avoid recursion. + // Defining DFS data structures: + // A pair (NodeIdx, EdgeIdx) at the top of the Stack denotes that + // - we are currently visiting Nodes[NodeIdx] and + // - the next edge to scan is Edges[NodeIdx][EdgeIdx] + typedef std::pair<uint64_t, uint64_t> StackItemType; + std::stack<StackItemType> Stack; + std::vector<uint64_t> AugmentingOrder; + + // Phase 0: Initialize Node attributes and Time for DFS run + for (auto &Node : Nodes) { + Node.Discovery = 0; + Node.Finish = 0; + Node.NumCalls = 0; + Node.Taken = false; + } + uint64_t Time = 0; + // Mark Target as Taken + // Taken attribute will be propagated backwards from Target towards Source + Nodes[Target].Taken = true; + + // Phase 1: Start DFS traversal from Source + Stack.emplace(Source, 0); + Nodes[Source].Discovery = ++Time; + while (!Stack.empty()) { + auto NodeIdx = Stack.top().first; + auto EdgeIdx = Stack.top().second; + + // If we haven't scanned all edges out of NodeIdx, continue scanning + if (EdgeIdx < Edges[NodeIdx].size()) { + auto &Edge = Edges[NodeIdx][EdgeIdx]; + auto &Dst = Nodes[Edge.Dst]; + Stack.top().second++; + + if (Edge.OnShortestPath) { + // If we haven't seen Edge.Dst so far, continue DFS search there + if (Dst.Discovery == 0 && Dst.NumCalls < SampleProfileMaxDfsCalls) { + Dst.Discovery = ++Time; + Stack.emplace(Edge.Dst, 0); + Dst.NumCalls++; + } else if (Dst.Taken && Dst.Finish != 0) { + // Else, if Edge.Dst already have a path to Target, so that NodeIdx + Nodes[NodeIdx].Taken = true; + } + } + } else { + // If we are done scanning all edge out of NodeIdx + Stack.pop(); + // If we haven't found a path from NodeIdx to Target, forget about it + if (!Nodes[NodeIdx].Taken) { + Nodes[NodeIdx].Discovery = 0; + } else { + // If we have found a path from NodeIdx to Target, then finish NodeIdx + // and propagate Taken flag to DFS parent unless at the Source + Nodes[NodeIdx].Finish = ++Time; + // NodeIdx == Source if and only if the stack is empty + if (NodeIdx != Source) { + assert(!Stack.empty() && "empty stack while running dfs"); + Nodes[Stack.top().first].Taken = true; + } + AugmentingOrder.push_back(NodeIdx); + } + } + } + // Nodes are collected decreasing Finish time, so the order is reversed + std::reverse(AugmentingOrder.begin(), AugmentingOrder.end()); + + // Phase 2: Extract all forward (DAG) edges and fill in AugmentingEdges + for (size_t Src : AugmentingOrder) { + AugmentingEdges[Src].clear(); + for (auto &Edge : Edges[Src]) { + uint64_t Dst = Edge.Dst; + if (Edge.OnShortestPath && Nodes[Src].Taken && Nodes[Dst].Taken && + Nodes[Dst].Finish < Nodes[Src].Finish) { + AugmentingEdges[Src].push_back(&Edge); + } + } + assert((Src == Target || !AugmentingEdges[Src].empty()) && + "incorrectly constructed augmenting edges"); + } + + return AugmentingOrder; + } + + /// Update the current flow along the given (acyclic) subgraph specified by + /// the vertex order, AugmentingOrder. The objective is to send as much flow + /// as possible while evenly distributing flow among successors of each node. 
+ /// After the update at least one edge is saturated. + bool augmentFlowAlongDAG(const std::vector<uint64_t> &AugmentingOrder) { + // Phase 0: Initialization + for (uint64_t Src : AugmentingOrder) { + Nodes[Src].FracFlow = 0; + Nodes[Src].IntFlow = 0; + for (auto &Edge : AugmentingEdges[Src]) { + Edge->AugmentedFlow = 0; + } + } + + // Phase 1: Send a unit of fractional flow along the DAG + uint64_t MaxFlowAmount = INF; + Nodes[Source].FracFlow = 1.0; + for (uint64_t Src : AugmentingOrder) { + assert((Src == Target || Nodes[Src].FracFlow > 0.0) && + "incorrectly computed fractional flow"); + // Distribute flow evenly among successors of Src + uint64_t Degree = AugmentingEdges[Src].size(); + for (auto &Edge : AugmentingEdges[Src]) { + double EdgeFlow = Nodes[Src].FracFlow / Degree; + Nodes[Edge->Dst].FracFlow += EdgeFlow; + if (Edge->Capacity == INF) + continue; + uint64_t MaxIntFlow = double(Edge->Capacity - Edge->Flow) / EdgeFlow; + MaxFlowAmount = std::min(MaxFlowAmount, MaxIntFlow); + } + } + // Stop early if we cannot send any (integral) flow from Source to Target + if (MaxFlowAmount == 0) + return false; + + // Phase 2: Send an integral flow of MaxFlowAmount + Nodes[Source].IntFlow = MaxFlowAmount; + for (uint64_t Src : AugmentingOrder) { + if (Src == Target) + break; + // Distribute flow evenly among successors of Src, rounding up to make + // sure all flow is sent + uint64_t Degree = AugmentingEdges[Src].size(); + // We are guaranteeed that Node[Src].IntFlow <= SuccFlow * Degree + uint64_t SuccFlow = (Nodes[Src].IntFlow + Degree - 1) / Degree; + for (auto &Edge : AugmentingEdges[Src]) { + uint64_t Dst = Edge->Dst; + uint64_t EdgeFlow = std::min(Nodes[Src].IntFlow, SuccFlow); + EdgeFlow = std::min(EdgeFlow, uint64_t(Edge->Capacity - Edge->Flow)); + Nodes[Dst].IntFlow += EdgeFlow; + Nodes[Src].IntFlow -= EdgeFlow; + Edge->AugmentedFlow += EdgeFlow; + } + } + assert(Nodes[Target].IntFlow <= MaxFlowAmount); + Nodes[Target].IntFlow = 0; + + // Phase 3: Send excess flow back traversing the nodes backwards. + // Because of rounding, not all flow can be sent along the edges of Src. + // Hence, sending the remaining flow back to maintain flow conservation + for (size_t Idx = AugmentingOrder.size() - 1; Idx > 0; Idx--) { + uint64_t Src = AugmentingOrder[Idx - 1]; + // Try to send excess flow back along each edge. + // Make sure we only send back flow we just augmented (AugmentedFlow). + for (auto &Edge : AugmentingEdges[Src]) { + uint64_t Dst = Edge->Dst; + if (Nodes[Dst].IntFlow == 0) + continue; + uint64_t EdgeFlow = std::min(Nodes[Dst].IntFlow, Edge->AugmentedFlow); + Nodes[Dst].IntFlow -= EdgeFlow; + Nodes[Src].IntFlow += EdgeFlow; + Edge->AugmentedFlow -= EdgeFlow; + } + } + + // Phase 4: Update flow values along all edges + bool HasSaturatedEdges = false; + for (uint64_t Src : AugmentingOrder) { + // Verify that we have sent all the excess flow from the node + assert(Src == Source || Nodes[Src].IntFlow == 0); + for (auto &Edge : AugmentingEdges[Src]) { + assert(uint64_t(Edge->Capacity - Edge->Flow) >= Edge->AugmentedFlow); + // Update flow values along the edge and its reverse copy + auto &RevEdge = Edges[Edge->Dst][Edge->RevEdgeIndex]; + Edge->Flow += Edge->AugmentedFlow; + RevEdge.Flow -= Edge->AugmentedFlow; + if (Edge->Capacity == Edge->Flow && Edge->AugmentedFlow > 0) + HasSaturatedEdges = true; + } + } + + // The augmentation is successful iff at least one edge becomes saturated + return HasSaturatedEdges; + } + + /// Identify candidate (shortest) edges for augmentation. 
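The even distribution in Phase 2 of augmentFlowAlongDAG reduces to a ceiling division per node, clamped by each edge's residual capacity, with any rounding excess kept at the node and routed back in Phase 3. Isolated from the solver, the arithmetic looks like this (plain C++; the function and variable names are illustrative):

#include <algorithm>
#include <cstdint>
#include <vector>

// Split Flow as evenly as possible across the outgoing edges, never sending
// more than an edge's residual capacity. Mirrors the rounding-up rule
// SuccFlow = ceil(Flow / Degree); whatever cannot be placed is returned so a
// later pass can push the excess back.
uint64_t splitEvenly(uint64_t Flow, const std::vector<uint64_t> &Residual,
                     std::vector<uint64_t> &Sent) {
  if (Residual.empty())
    return Flow;
  uint64_t Degree = Residual.size();
  uint64_t SuccFlow = (Flow + Degree - 1) / Degree; // ceiling division
  Sent.assign(Residual.size(), 0);
  for (size_t I = 0; I < Residual.size(); ++I) {
    uint64_t EdgeFlow = std::min({Flow, SuccFlow, Residual[I]});
    Sent[I] = EdgeFlow;
    Flow -= EdgeFlow;
  }
  return Flow; // rounding excess left at the node
}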
+ void identifyShortestEdges(uint64_t PathCapacity) { + assert(PathCapacity > 0 && "found an incorrect augmenting DAG"); + // To make sure the augmentation DAG contains only edges with large residual + // capacity, we prune all edges whose capacity is below a fraction of + // the capacity of the augmented path. + // (All edges of the path itself are always in the DAG) + uint64_t MinCapacity = std::max(PathCapacity / 2, uint64_t(1)); + + // Decide which edges are on a shortest path from Source to Target + for (size_t Src = 0; Src < Nodes.size(); Src++) { + // An edge cannot be augmenting if the endpoint has large distance + if (Nodes[Src].Distance > Nodes[Target].Distance) + continue; + + for (auto &Edge : Edges[Src]) { + uint64_t Dst = Edge.Dst; + Edge.OnShortestPath = + Src != Target && Dst != Source && + Nodes[Dst].Distance <= Nodes[Target].Distance && + Nodes[Dst].Distance == Nodes[Src].Distance + Edge.Cost && + Edge.Capacity > Edge.Flow && + uint64_t(Edge.Capacity - Edge.Flow) >= MinCapacity; + } + } + } + /// A node in a flow network. struct Node { /// The cost of the cheapest path from the source to the current node. @@ -247,7 +522,20 @@ private: uint64_t ParentEdgeIndex; /// An indicator of whether the current node is in a queue. bool Taken; + + /// Data fields utilized in DAG-augmentation: + /// Fractional flow. + double FracFlow; + /// Integral flow. + uint64_t IntFlow; + /// Discovery time. + uint64_t Discovery; + /// Finish time. + uint64_t Finish; + /// NumCalls. + uint64_t NumCalls; }; + /// An edge in a flow network. struct Edge { /// The cost of the edge. @@ -260,6 +548,12 @@ private: uint64_t Dst; /// The index of the reverse edge between Dst and the current node. uint64_t RevEdgeIndex; + + /// Data fields utilized in DAG-augmentation: + /// Whether the edge is currently on a shortest path from Source to Target. + bool OnShortestPath; + /// Extra flow along the edge. + uint64_t AugmentedFlow; }; /// The set of network nodes. @@ -270,8 +564,13 @@ private: uint64_t Source; /// Target (sink) node of the flow. uint64_t Target; + /// Augmenting edges. + std::vector<std::vector<Edge *>> AugmentingEdges; }; +constexpr int64_t MinCostMaxFlow::AuxCostUnlikely; +constexpr uint64_t MinCostMaxFlow::MinBaseDistance; + /// A post-processing adjustment of control flow. It applies two steps by /// rerouting some flow and making it more realistic: /// @@ -433,19 +732,22 @@ private: /// A distance of a path for a given jump. /// In order to incite the path to use blocks/jumps with large positive flow, /// and avoid changing branch probability of outgoing edges drastically, - /// set the distance as follows: - /// if Jump.Flow > 0, then distance = max(100 - Jump->Flow, 0) - /// if Block.Weight > 0, then distance = 1 - /// otherwise distance >> 1 + /// set the jump distance so as: + /// - to minimize the number of unlikely jumps used and subject to that, + /// - to minimize the number of Flow == 0 jumps used and subject to that, + /// - minimizes total multiplicative Flow increase for the remaining edges. + /// To capture this objective with integer distances, we round off fractional + /// parts to a multiple of 1 / BaseDistance. 
int64_t jumpDistance(FlowJump *Jump) const { - int64_t BaseDistance = 100; + uint64_t BaseDistance = + std::max(static_cast<uint64_t>(MinCostMaxFlow::MinBaseDistance), + std::min(Func.Blocks[Func.Entry].Flow, + MinCostMaxFlow::AuxCostUnlikely / NumBlocks())); if (Jump->IsUnlikely) return MinCostMaxFlow::AuxCostUnlikely; if (Jump->Flow > 0) - return std::max(BaseDistance - (int64_t)Jump->Flow, (int64_t)0); - if (Func.Blocks[Jump->Target].Weight > 0) - return BaseDistance; - return BaseDistance * (NumBlocks() + 1); + return BaseDistance + BaseDistance / Jump->Flow; + return BaseDistance * NumBlocks(); }; uint64_t NumBlocks() const { return Func.Blocks.size(); } @@ -511,7 +813,7 @@ private: std::vector<FlowBlock *> &KnownDstBlocks, std::vector<FlowBlock *> &UnknownBlocks) { // Run BFS from SrcBlock and make sure all paths are going through unknown - // blocks and end at a non-unknown DstBlock + // blocks and end at a known DstBlock auto Visited = BitVector(NumBlocks(), false); std::queue<uint64_t> Queue; @@ -778,8 +1080,8 @@ void initializeNetwork(MinCostMaxFlow &Network, FlowFunction &Func) { // We assume that decreasing block counts is more expensive than increasing, // and thus, setting separate costs here. In the future we may want to tune // the relative costs so as to maximize the quality of generated profiles. - int64_t AuxCostInc = MinCostMaxFlow::AuxCostInc; - int64_t AuxCostDec = MinCostMaxFlow::AuxCostDec; + int64_t AuxCostInc = SampleProfileProfiCostInc; + int64_t AuxCostDec = SampleProfileProfiCostDec; if (Block.UnknownWeight) { // Do not penalize changing weights of blocks w/o known profile count AuxCostInc = 0; @@ -788,12 +1090,12 @@ void initializeNetwork(MinCostMaxFlow &Network, FlowFunction &Func) { // Increasing the count for "cold" blocks with zero initial count is more // expensive than for "hot" ones if (Block.Weight == 0) { - AuxCostInc = MinCostMaxFlow::AuxCostIncZero; + AuxCostInc = SampleProfileProfiCostIncZero; } // Modifying the count of the entry block is expensive if (Block.isEntry()) { - AuxCostInc = MinCostMaxFlow::AuxCostIncEntry; - AuxCostDec = MinCostMaxFlow::AuxCostDecEntry; + AuxCostInc = SampleProfileProfiCostIncEntry; + AuxCostDec = SampleProfileProfiCostDecEntry; } } // For blocks with self-edges, do not penalize a reduction of the count, diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SampleProfileLoaderBaseUtil.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SampleProfileLoaderBaseUtil.cpp index ea0e8343eb88..a2588b8cec7d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SampleProfileLoaderBaseUtil.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SampleProfileLoaderBaseUtil.cpp @@ -11,6 +11,10 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Module.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" namespace llvm { @@ -35,9 +39,13 @@ cl::opt<bool> NoWarnSampleUnused( "samples but without debug information to use those samples. 
")); cl::opt<bool> SampleProfileUseProfi( - "sample-profile-use-profi", cl::init(false), cl::Hidden, cl::ZeroOrMore, + "sample-profile-use-profi", cl::Hidden, cl::desc("Use profi to infer block and edge counts.")); +cl::opt<bool> SampleProfileInferEntryCount( + "sample-profile-infer-entry-count", cl::init(true), cl::Hidden, + cl::desc("Use profi to infer function entry count.")); + namespace sampleprofutil { /// Return true if the given callsite is hot wrt to hot cutoff threshold. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SanitizerStats.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SanitizerStats.cpp index a1313c77ed77..fd21ee4cc408 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SanitizerStats.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SanitizerStats.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/SanitizerStats.h" -#include "llvm/ADT/Triple.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalVariable.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp index 5363a851fc27..401f1ee5a55d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -22,11 +22,8 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -276,7 +273,9 @@ Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode, } // If we haven't found this binop, insert it. - Instruction *BO = cast<Instruction>(Builder.CreateBinOp(Opcode, LHS, RHS)); + // TODO: Use the Builder, which will make CreateBinOp below fold with + // InstSimplifyFolder. + Instruction *BO = Builder.Insert(BinaryOperator::Create(Opcode, LHS, RHS)); BO->setDebugLoc(Loc); if (Flags & SCEV::FlagNUW) BO->setHasNoUnsignedWrap(); @@ -591,7 +590,9 @@ Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin, if (isa<DbgInfoIntrinsic>(IP)) ScanLimit++; if (IP->getOpcode() == Instruction::GetElementPtr && - IP->getOperand(0) == V && IP->getOperand(1) == Idx) + IP->getOperand(0) == V && IP->getOperand(1) == Idx && + cast<GEPOperator>(&*IP)->getSourceElementType() == + Type::getInt8Ty(Ty->getContext())) return &*IP; if (IP == BlockBegin) break; } @@ -1633,7 +1634,6 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { NewS = Ext; const SCEV *V = cast<SCEVAddRecExpr>(NewS)->evaluateAtIteration(IH, SE); - //cerr << "Evaluated: " << *this << "\n to: " << *V << "\n"; // Truncate the result down to the original type, if needed. const SCEV *T = SE.getTruncateOrNoop(V, Ty); @@ -1671,154 +1671,49 @@ Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) { return Builder.CreateSExt(V, Ty); } -Value *SCEVExpander::expandSMaxExpr(const SCEVNAryExpr *S) { - Value *LHS = expand(S->getOperand(S->getNumOperands()-1)); - Type *Ty = LHS->getType(); - for (int i = S->getNumOperands()-2; i >= 0; --i) { - // In the case of mixed integer and pointer types, do the - // rest of the comparisons as integer. 
- Type *OpTy = S->getOperand(i)->getType(); - if (OpTy->isIntegerTy() != Ty->isIntegerTy()) { - Ty = SE.getEffectiveSCEVType(Ty); - LHS = InsertNoopCastOfTo(LHS, Ty); - } - Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false); - Value *Sel; - if (Ty->isIntegerTy()) - Sel = Builder.CreateIntrinsic(Intrinsic::smax, {Ty}, {LHS, RHS}, - /*FMFSource=*/nullptr, "smax"); - else { - Value *ICmp = Builder.CreateICmpSGT(LHS, RHS); - Sel = Builder.CreateSelect(ICmp, LHS, RHS, "smax"); - } - LHS = Sel; - } - // In the case of mixed integer and pointer types, cast the - // final result back to the pointer type. - if (LHS->getType() != S->getType()) - LHS = InsertNoopCastOfTo(LHS, S->getType()); - return LHS; -} - -Value *SCEVExpander::expandUMaxExpr(const SCEVNAryExpr *S) { - Value *LHS = expand(S->getOperand(S->getNumOperands()-1)); - Type *Ty = LHS->getType(); - for (int i = S->getNumOperands()-2; i >= 0; --i) { - // In the case of mixed integer and pointer types, do the - // rest of the comparisons as integer. - Type *OpTy = S->getOperand(i)->getType(); - if (OpTy->isIntegerTy() != Ty->isIntegerTy()) { - Ty = SE.getEffectiveSCEVType(Ty); - LHS = InsertNoopCastOfTo(LHS, Ty); - } - Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false); - Value *Sel; - if (Ty->isIntegerTy()) - Sel = Builder.CreateIntrinsic(Intrinsic::umax, {Ty}, {LHS, RHS}, - /*FMFSource=*/nullptr, "umax"); - else { - Value *ICmp = Builder.CreateICmpUGT(LHS, RHS); - Sel = Builder.CreateSelect(ICmp, LHS, RHS, "umax"); - } - LHS = Sel; - } - // In the case of mixed integer and pointer types, cast the - // final result back to the pointer type. - if (LHS->getType() != S->getType()) - LHS = InsertNoopCastOfTo(LHS, S->getType()); - return LHS; -} - -Value *SCEVExpander::expandSMinExpr(const SCEVNAryExpr *S) { - Value *LHS = expand(S->getOperand(S->getNumOperands() - 1)); - Type *Ty = LHS->getType(); - for (int i = S->getNumOperands() - 2; i >= 0; --i) { - // In the case of mixed integer and pointer types, do the - // rest of the comparisons as integer. - Type *OpTy = S->getOperand(i)->getType(); - if (OpTy->isIntegerTy() != Ty->isIntegerTy()) { - Ty = SE.getEffectiveSCEVType(Ty); - LHS = InsertNoopCastOfTo(LHS, Ty); - } - Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false); - Value *Sel; - if (Ty->isIntegerTy()) - Sel = Builder.CreateIntrinsic(Intrinsic::smin, {Ty}, {LHS, RHS}, - /*FMFSource=*/nullptr, "smin"); - else { - Value *ICmp = Builder.CreateICmpSLT(LHS, RHS); - Sel = Builder.CreateSelect(ICmp, LHS, RHS, "smin"); - } - LHS = Sel; - } - // In the case of mixed integer and pointer types, cast the - // final result back to the pointer type. - if (LHS->getType() != S->getType()) - LHS = InsertNoopCastOfTo(LHS, S->getType()); - return LHS; -} - -Value *SCEVExpander::expandUMinExpr(const SCEVNAryExpr *S) { +Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S, + Intrinsic::ID IntrinID, Twine Name, + bool IsSequential) { Value *LHS = expand(S->getOperand(S->getNumOperands() - 1)); Type *Ty = LHS->getType(); + if (IsSequential) + LHS = Builder.CreateFreeze(LHS); for (int i = S->getNumOperands() - 2; i >= 0; --i) { - // In the case of mixed integer and pointer types, do the - // rest of the comparisons as integer. 
- Type *OpTy = S->getOperand(i)->getType(); - if (OpTy->isIntegerTy() != Ty->isIntegerTy()) { - Ty = SE.getEffectiveSCEVType(Ty); - LHS = InsertNoopCastOfTo(LHS, Ty); - } Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false); + if (IsSequential && i != 0) + RHS = Builder.CreateFreeze(RHS); Value *Sel; if (Ty->isIntegerTy()) - Sel = Builder.CreateIntrinsic(Intrinsic::umin, {Ty}, {LHS, RHS}, - /*FMFSource=*/nullptr, "umin"); + Sel = Builder.CreateIntrinsic(IntrinID, {Ty}, {LHS, RHS}, + /*FMFSource=*/nullptr, Name); else { - Value *ICmp = Builder.CreateICmpULT(LHS, RHS); - Sel = Builder.CreateSelect(ICmp, LHS, RHS, "umin"); + Value *ICmp = + Builder.CreateICmp(MinMaxIntrinsic::getPredicate(IntrinID), LHS, RHS); + Sel = Builder.CreateSelect(ICmp, LHS, RHS, Name); } LHS = Sel; } - // In the case of mixed integer and pointer types, cast the - // final result back to the pointer type. - if (LHS->getType() != S->getType()) - LHS = InsertNoopCastOfTo(LHS, S->getType()); return LHS; } Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) { - return expandSMaxExpr(S); + return expandMinMaxExpr(S, Intrinsic::smax, "smax"); } Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) { - return expandUMaxExpr(S); + return expandMinMaxExpr(S, Intrinsic::umax, "umax"); } Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) { - return expandSMinExpr(S); + return expandMinMaxExpr(S, Intrinsic::smin, "smin"); } Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) { - return expandUMinExpr(S); + return expandMinMaxExpr(S, Intrinsic::umin, "umin"); } Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) { - SmallVector<Value *> Ops; - for (const SCEV *Op : S->operands()) - Ops.emplace_back(expand(Op)); - - Value *SaturationPoint = - MinMaxIntrinsic::getSaturationPoint(Intrinsic::umin, S->getType()); - - SmallVector<Value *> OpIsZero; - for (Value *Op : ArrayRef<Value *>(Ops).drop_back()) - OpIsZero.emplace_back(Builder.CreateICmpEQ(Op, SaturationPoint)); - - Value *AnyOpIsZero = Builder.CreateLogicalOr(OpIsZero); - - Value *NaiveUMin = expandUMinExpr(S); - return Builder.CreateSelect(AnyOpIsZero, SaturationPoint, NaiveUMin); + return expandMinMaxExpr(S, Intrinsic::umin, "umin", /*IsSequential*/true); } Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, @@ -1868,35 +1763,33 @@ Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, bool Root) { return V; } -ScalarEvolution::ValueOffsetPair -SCEVExpander::FindValueInExprValueMap(const SCEV *S, - const Instruction *InsertPt) { - auto *Set = SE.getSCEVValues(S); +Value *SCEVExpander::FindValueInExprValueMap(const SCEV *S, + const Instruction *InsertPt) { // If the expansion is not in CanonicalMode, and the SCEV contains any // sub scAddRecExpr type SCEV, it is required to expand the SCEV literally. - if (CanonicalMode || !SE.containsAddRecurrence(S)) { - // If S is scConstant, it may be worse to reuse an existing Value. - if (S->getSCEVType() != scConstant && Set) { - // Choose a Value from the set which dominates the InsertPt. - // InsertPt should be inside the Value's parent loop so as not to break - // the LCSSA form. 
- for (auto const &VOPair : *Set) { - Value *V = VOPair.first; - ConstantInt *Offset = VOPair.second; - Instruction *EntInst = dyn_cast_or_null<Instruction>(V); - if (!EntInst) - continue; + if (!CanonicalMode && SE.containsAddRecurrence(S)) + return nullptr; - assert(EntInst->getFunction() == InsertPt->getFunction()); - if (S->getType() == V->getType() && - SE.DT.dominates(EntInst, InsertPt) && - (SE.LI.getLoopFor(EntInst->getParent()) == nullptr || - SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt))) - return {V, Offset}; - } - } + // If S is a constant, it may be worse to reuse an existing Value. + if (isa<SCEVConstant>(S)) + return nullptr; + + // Choose a Value from the set which dominates the InsertPt. + // InsertPt should be inside the Value's parent loop so as not to break + // the LCSSA form. + for (Value *V : SE.getSCEVValues(S)) { + Instruction *EntInst = dyn_cast<Instruction>(V); + if (!EntInst) + continue; + + assert(EntInst->getFunction() == InsertPt->getFunction()); + if (S->getType() == V->getType() && + SE.DT.dominates(EntInst, InsertPt) && + (SE.LI.getLoopFor(EntInst->getParent()) == nullptr || + SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt))) + return V; } - return {nullptr, nullptr}; + return nullptr; } // The expansion of SCEV will either reuse a previous Value in ExprValueMap, @@ -1965,9 +1858,7 @@ Value *SCEVExpander::expand(const SCEV *S) { Builder.SetInsertPoint(InsertPt); // Expand the expression into instructions. - ScalarEvolution::ValueOffsetPair VO = FindValueInExprValueMap(S, InsertPt); - Value *V = VO.first; - + Value *V = FindValueInExprValueMap(S, InsertPt); if (!V) V = visit(S); else { @@ -1978,21 +1869,6 @@ Value *SCEVExpander::expand(const SCEV *S) { if (auto *I = dyn_cast<Instruction>(V)) if (I->hasPoisonGeneratingFlags() && !programUndefinedIfPoison(I)) I->dropPoisonGeneratingFlags(); - - if (VO.second) { - if (PointerType *Vty = dyn_cast<PointerType>(V->getType())) { - int64_t Offset = VO.second->getSExtValue(); - ConstantInt *Idx = - ConstantInt::getSigned(VO.second->getType(), -Offset); - unsigned AS = Vty->getAddressSpace(); - V = Builder.CreateBitCast(V, Type::getInt8PtrTy(SE.getContext(), AS)); - V = Builder.CreateGEP(Type::getInt8Ty(SE.getContext()), V, Idx, - "uglygep"); - V = Builder.CreateBitCast(V, Vty); - } else { - V = Builder.CreateSub(V, VO.second); - } - } } // Remember the expanded value for this SCEV at this location. // @@ -2058,7 +1934,7 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, // so narrow phis can reuse them. 
for (PHINode *Phi : Phis) { auto SimplifyPHINode = [&](PHINode *PN) -> Value * { - if (Value *V = SimplifyInstruction(PN, {DL, &SE.TLI, &SE.DT, &SE.AC})) + if (Value *V = simplifyInstruction(PN, {DL, &SE.TLI, &SE.DT, &SE.AC})) return V; if (!SE.isSCEVable(PN->getType())) return nullptr; @@ -2174,9 +2050,9 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, return NumElim; } -Optional<ScalarEvolution::ValueOffsetPair> -SCEVExpander::getRelatedExistingExpansion(const SCEV *S, const Instruction *At, - Loop *L) { +Value *SCEVExpander::getRelatedExistingExpansion(const SCEV *S, + const Instruction *At, + Loop *L) { using namespace llvm::PatternMatch; SmallVector<BasicBlock *, 4> ExitingBlocks; @@ -2193,25 +2069,17 @@ SCEVExpander::getRelatedExistingExpansion(const SCEV *S, const Instruction *At, continue; if (SE.getSCEV(LHS) == S && SE.DT.dominates(LHS, At)) - return ScalarEvolution::ValueOffsetPair(LHS, nullptr); + return LHS; if (SE.getSCEV(RHS) == S && SE.DT.dominates(RHS, At)) - return ScalarEvolution::ValueOffsetPair(RHS, nullptr); + return RHS; } // Use expand's logic which is used for reusing a previous Value in // ExprValueMap. Note that we don't currently model the cost of // needing to drop poison generating flags on the instruction if we // want to reuse it. We effectively assume that has zero cost. - ScalarEvolution::ValueOffsetPair VO = FindValueInExprValueMap(S, At); - if (VO.first) - return VO; - - // There is potential to make this significantly smarter, but this simple - // heuristic already gets some interesting cases. - - // Can not find suitable value. - return None; + return FindValueInExprValueMap(S, At); } template<typename T> static InstructionCost costAndCollectOperands( @@ -2469,8 +2337,8 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred, switch (Pred->getKind()) { case SCEVPredicate::P_Union: return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP); - case SCEVPredicate::P_Equal: - return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP); + case SCEVPredicate::P_Compare: + return expandComparePredicate(cast<SCEVComparePredicate>(Pred), IP); case SCEVPredicate::P_Wrap: { auto *AddRecPred = cast<SCEVWrapPredicate>(Pred); return expandWrapPredicate(AddRecPred, IP); @@ -2479,15 +2347,16 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred, llvm_unreachable("Unknown SCEV predicate type"); } -Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred, - Instruction *IP) { +Value *SCEVExpander::expandComparePredicate(const SCEVComparePredicate *Pred, + Instruction *IP) { Value *Expr0 = expandCodeForImpl(Pred->getLHS(), Pred->getLHS()->getType(), IP, false); Value *Expr1 = expandCodeForImpl(Pred->getRHS(), Pred->getRHS()->getType(), IP, false); Builder.SetInsertPoint(IP); - auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check"); + auto InvPred = ICmpInst::getInversePredicate(Pred->getPredicate()); + auto *I = Builder.CreateICmp(InvPred, Expr0, Expr1, "ident.check"); return I; } @@ -2496,7 +2365,8 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, assert(AR->isAffine() && "Cannot generate RT check for " "non-affine expression"); - SCEVUnionPredicate Pred; + // FIXME: It is highly suspicious that we're ignoring the predicates here. 
+ SmallVector<const SCEVPredicate *, 4> Pred; const SCEV *ExitCount = SE.getPredicatedBackedgeTakenCount(AR->getLoop(), Pred); @@ -2710,10 +2580,10 @@ namespace { struct SCEVFindUnsafe { ScalarEvolution &SE; bool CanonicalMode; - bool IsUnsafe; + bool IsUnsafe = false; SCEVFindUnsafe(ScalarEvolution &SE, bool CanonicalMode) - : SE(SE), CanonicalMode(CanonicalMode), IsUnsafe(false) {} + : SE(SE), CanonicalMode(CanonicalMode) {} bool follow(const SCEV *S) { if (const SCEVUDivExpr *D = dyn_cast<SCEVUDivExpr>(S)) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 8c4e1b381b4d..567b866f7777 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -27,7 +27,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/ConstantFolding.h" -#include "llvm/Analysis/EHPersonalities.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemorySSA.h" @@ -50,7 +50,6 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" @@ -58,7 +57,6 @@ #include "llvm/IR/NoFolder.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/IR/PseudoProbe.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" @@ -74,7 +72,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/SSAUpdater.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <cassert> @@ -94,8 +91,8 @@ using namespace PatternMatch; #define DEBUG_TYPE "simplifycfg" cl::opt<bool> llvm::RequireAndPreserveDomTree( - "simplifycfg-require-and-preserve-domtree", cl::Hidden, cl::ZeroOrMore, - cl::init(false), + "simplifycfg-require-and-preserve-domtree", cl::Hidden, + cl::desc("Temorary development switch used to gradually uplift SimplifyCFG " "into preserving DomTree,")); @@ -167,6 +164,14 @@ static cl::opt<unsigned> BranchFoldToCommonDestVectorMultiplier( "to fold branch to common destination when vector operations are " "present")); +static cl::opt<bool> EnableMergeCompatibleInvokes( + "simplifycfg-merge-compatible-invokes", cl::Hidden, cl::init(true), + cl::desc("Allow SimplifyCFG to merge invokes together when appropriate")); + +static cl::opt<unsigned> MaxSwitchCasesPerResult( + "max-switch-cases-per-result", cl::Hidden, cl::init(16), + cl::desc("Limit cases to analyze when converting a switch to select")); + STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps"); STATISTIC(NumLinearMaps, "Number of switch instructions turned into linear mapping"); @@ -192,6 +197,8 @@ STATISTIC(NumSinkCommonInstrs, STATISTIC(NumSpeculations, "Number of speculative executed instructions"); STATISTIC(NumInvokes, "Number of invokes with empty resume blocks simplified into calls"); +STATISTIC(NumInvokesMerged, "Number of invokes that were merged together"); +STATISTIC(NumInvokeSetsFormed, "Number of invoke sets that were formed"); namespace { @@ -291,6 +298,34 @@ public: } // end anonymous namespace +/// Return true if all the PHI nodes in the basic block \p BB +/// 
receive compatible (identical) incoming values when coming from +/// all of the predecessor blocks that are specified in \p IncomingBlocks. +/// +/// Note that if the values aren't exactly identical, but \p EquivalenceSet +/// is provided, and *both* of the values are present in the set, +/// then they are considered equal. +static bool IncomingValuesAreCompatible( + BasicBlock *BB, ArrayRef<BasicBlock *> IncomingBlocks, + SmallPtrSetImpl<Value *> *EquivalenceSet = nullptr) { + assert(IncomingBlocks.size() == 2 && + "Only for a pair of incoming blocks at the time!"); + + // FIXME: it is okay if one of the incoming values is an `undef` value, + // iff the other incoming value is guaranteed to be a non-poison value. + // FIXME: it is okay if one of the incoming values is a `poison` value. + return all_of(BB->phis(), [IncomingBlocks, EquivalenceSet](PHINode &PN) { + Value *IV0 = PN.getIncomingValueForBlock(IncomingBlocks[0]); + Value *IV1 = PN.getIncomingValueForBlock(IncomingBlocks[1]); + if (IV0 == IV1) + return true; + if (EquivalenceSet && EquivalenceSet->contains(IV0) && + EquivalenceSet->contains(IV1)) + return true; + return false; + }); +} + /// Return true if it is safe to merge these two /// terminator instructions together. static bool @@ -307,17 +342,17 @@ SafeToMergeTerminators(Instruction *SI1, Instruction *SI2, SmallPtrSet<BasicBlock *, 16> SI1Succs(succ_begin(SI1BB), succ_end(SI1BB)); bool Fail = false; - for (BasicBlock *Succ : successors(SI2BB)) - if (SI1Succs.count(Succ)) - for (BasicBlock::iterator BBI = Succ->begin(); isa<PHINode>(BBI); ++BBI) { - PHINode *PN = cast<PHINode>(BBI); - if (PN->getIncomingValueForBlock(SI1BB) != - PN->getIncomingValueForBlock(SI2BB)) { - if (FailBlocks) - FailBlocks->insert(Succ); - Fail = true; - } - } + for (BasicBlock *Succ : successors(SI2BB)) { + if (!SI1Succs.count(Succ)) + continue; + if (IncomingValuesAreCompatible(Succ, {SI1BB, SI2BB})) + continue; + Fail = true; + if (FailBlocks) + FailBlocks->insert(Succ); + else + break; + } return !Fail; } @@ -347,6 +382,13 @@ static InstructionCost computeSpeculationCost(const User *I, return TTI.getUserCost(I, TargetTransformInfo::TCK_SizeAndLatency); } +/// Check whether this is a potentially trapping constant. +static bool canTrap(const Value *V) { + if (auto *C = dyn_cast<Constant>(V)) + return C->canTrap(); + return false; +} + /// If we have a merge point of an "if condition" as accepted above, /// return true if the specified value dominates the block. We /// don't handle the true generality of domination here, just a special case @@ -381,10 +423,7 @@ static bool dominatesMergePoint(Value *V, BasicBlock *BB, if (!I) { // Non-instructions all dominate instructions, but not all constantexprs // can be executed unconditionally. - if (ConstantExpr *C = dyn_cast<ConstantExpr>(V)) - if (C->canTrap()) - return false; - return true; + return !canTrap(V); } BasicBlock *PBB = I->getParent(); @@ -1459,7 +1498,7 @@ bool SimplifyCFGOpt::HoistThenElseCodeToIf(BranchInst *BI, return false; if (!I1NonDbg->isTerminator()) return false; - // Now we know that we only need to hoist debug instrinsics and the + // Now we know that we only need to hoist debug intrinsics and the // terminator. Let the loop below handle those 2 cases. 
} @@ -2212,6 +2251,320 @@ static bool SinkCommonCodeFromPredecessors(BasicBlock *BB, return Changed; } +namespace { + +struct CompatibleSets { + using SetTy = SmallVector<InvokeInst *, 2>; + + SmallVector<SetTy, 1> Sets; + + static bool shouldBelongToSameSet(ArrayRef<InvokeInst *> Invokes); + + SetTy &getCompatibleSet(InvokeInst *II); + + void insert(InvokeInst *II); +}; + +CompatibleSets::SetTy &CompatibleSets::getCompatibleSet(InvokeInst *II) { + // Perform a linear scan over all the existing sets, see if the new `invoke` + // is compatible with any particular set. Since we know that all the `invokes` + // within a set are compatible, only check the first `invoke` in each set. + // WARNING: at worst, this has quadratic complexity. + for (CompatibleSets::SetTy &Set : Sets) { + if (CompatibleSets::shouldBelongToSameSet({Set.front(), II})) + return Set; + } + + // Otherwise, we either had no sets yet, or this invoke forms a new set. + return Sets.emplace_back(); +} + +void CompatibleSets::insert(InvokeInst *II) { + getCompatibleSet(II).emplace_back(II); +} + +bool CompatibleSets::shouldBelongToSameSet(ArrayRef<InvokeInst *> Invokes) { + assert(Invokes.size() == 2 && "Always called with exactly two candidates."); + + // Can we theoretically merge these `invoke`s? + auto IsIllegalToMerge = [](InvokeInst *II) { + return II->cannotMerge() || II->isInlineAsm(); + }; + if (any_of(Invokes, IsIllegalToMerge)) + return false; + + // Either both `invoke`s must be direct, + // or both `invoke`s must be indirect. + auto IsIndirectCall = [](InvokeInst *II) { return II->isIndirectCall(); }; + bool HaveIndirectCalls = any_of(Invokes, IsIndirectCall); + bool AllCallsAreIndirect = all_of(Invokes, IsIndirectCall); + if (HaveIndirectCalls) { + if (!AllCallsAreIndirect) + return false; + } else { + // All callees must be identical. + Value *Callee = nullptr; + for (InvokeInst *II : Invokes) { + Value *CurrCallee = II->getCalledOperand(); + assert(CurrCallee && "There is always a called operand."); + if (!Callee) + Callee = CurrCallee; + else if (Callee != CurrCallee) + return false; + } + } + + // Either both `invoke`s must not have a normal destination, + // or both `invoke`s must have a normal destination, + auto HasNormalDest = [](InvokeInst *II) { + return !isa<UnreachableInst>(II->getNormalDest()->getFirstNonPHIOrDbg()); + }; + if (any_of(Invokes, HasNormalDest)) { + // Do not merge `invoke` that does not have a normal destination with one + // that does have a normal destination, even though doing so would be legal. + if (!all_of(Invokes, HasNormalDest)) + return false; + + // All normal destinations must be identical. + BasicBlock *NormalBB = nullptr; + for (InvokeInst *II : Invokes) { + BasicBlock *CurrNormalBB = II->getNormalDest(); + assert(CurrNormalBB && "There is always a 'continue to' basic block."); + if (!NormalBB) + NormalBB = CurrNormalBB; + else if (NormalBB != CurrNormalBB) + return false; + } + + // In the normal destination, the incoming values for these two `invoke`s + // must be compatible. + SmallPtrSet<Value *, 16> EquivalenceSet(Invokes.begin(), Invokes.end()); + if (!IncomingValuesAreCompatible( + NormalBB, {Invokes[0]->getParent(), Invokes[1]->getParent()}, + &EquivalenceSet)) + return false; + } + +#ifndef NDEBUG + // All unwind destinations must be identical. + // We know that because we have started from said unwind destination. 
+ BasicBlock *UnwindBB = nullptr; + for (InvokeInst *II : Invokes) { + BasicBlock *CurrUnwindBB = II->getUnwindDest(); + assert(CurrUnwindBB && "There is always an 'unwind to' basic block."); + if (!UnwindBB) + UnwindBB = CurrUnwindBB; + else + assert(UnwindBB == CurrUnwindBB && "Unexpected unwind destination."); + } +#endif + + // In the unwind destination, the incoming values for these two `invoke`s + // must be compatible. + if (!IncomingValuesAreCompatible( + Invokes.front()->getUnwindDest(), + {Invokes[0]->getParent(), Invokes[1]->getParent()})) + return false; + + // Ignoring arguments, these `invoke`s must be identical, + // including operand bundles. + const InvokeInst *II0 = Invokes.front(); + for (auto *II : Invokes.drop_front()) + if (!II->isSameOperationAs(II0)) + return false; + + // Can we theoretically form the data operands for the merged `invoke`? + auto IsIllegalToMergeArguments = [](auto Ops) { + Type *Ty = std::get<0>(Ops)->getType(); + assert(Ty == std::get<1>(Ops)->getType() && "Incompatible types?"); + return Ty->isTokenTy() && std::get<0>(Ops) != std::get<1>(Ops); + }; + assert(Invokes.size() == 2 && "Always called with exactly two candidates."); + if (any_of(zip(Invokes[0]->data_ops(), Invokes[1]->data_ops()), + IsIllegalToMergeArguments)) + return false; + + return true; +} + +} // namespace + +// Merge all invokes in the provided set, all of which are compatible +// as per the `CompatibleSets::shouldBelongToSameSet()`. +static void MergeCompatibleInvokesImpl(ArrayRef<InvokeInst *> Invokes, + DomTreeUpdater *DTU) { + assert(Invokes.size() >= 2 && "Must have at least two invokes to merge."); + + SmallVector<DominatorTree::UpdateType, 8> Updates; + if (DTU) + Updates.reserve(2 + 3 * Invokes.size()); + + bool HasNormalDest = + !isa<UnreachableInst>(Invokes[0]->getNormalDest()->getFirstNonPHIOrDbg()); + + // Clone one of the invokes into a new basic block. + // Since they are all compatible, it doesn't matter which invoke is cloned. + InvokeInst *MergedInvoke = [&Invokes, HasNormalDest]() { + InvokeInst *II0 = Invokes.front(); + BasicBlock *II0BB = II0->getParent(); + BasicBlock *InsertBeforeBlock = + II0->getParent()->getIterator()->getNextNode(); + Function *Func = II0BB->getParent(); + LLVMContext &Ctx = II0->getContext(); + + BasicBlock *MergedInvokeBB = BasicBlock::Create( + Ctx, II0BB->getName() + ".invoke", Func, InsertBeforeBlock); + + auto *MergedInvoke = cast<InvokeInst>(II0->clone()); + // NOTE: all invokes have the same attributes, so no handling needed. + MergedInvokeBB->getInstList().push_back(MergedInvoke); + + if (!HasNormalDest) { + // This set does not have a normal destination, + // so just form a new block with unreachable terminator. + BasicBlock *MergedNormalDest = BasicBlock::Create( + Ctx, II0BB->getName() + ".cont", Func, InsertBeforeBlock); + new UnreachableInst(Ctx, MergedNormalDest); + MergedInvoke->setNormalDest(MergedNormalDest); + } + + // The unwind destination, however, remainds identical for all invokes here. + + return MergedInvoke; + }(); + + if (DTU) { + // Predecessor blocks that contained these invokes will now branch to + // the new block that contains the merged invoke, ... + for (InvokeInst *II : Invokes) + Updates.push_back( + {DominatorTree::Insert, II->getParent(), MergedInvoke->getParent()}); + + // ... 
which has the new `unreachable` block as normal destination, + // or unwinds to the (same for all `invoke`s in this set) `landingpad`, + for (BasicBlock *SuccBBOfMergedInvoke : successors(MergedInvoke)) + Updates.push_back({DominatorTree::Insert, MergedInvoke->getParent(), + SuccBBOfMergedInvoke}); + + // Since predecessor blocks now unconditionally branch to a new block, + // they no longer branch to their original successors. + for (InvokeInst *II : Invokes) + for (BasicBlock *SuccOfPredBB : successors(II->getParent())) + Updates.push_back( + {DominatorTree::Delete, II->getParent(), SuccOfPredBB}); + } + + bool IsIndirectCall = Invokes[0]->isIndirectCall(); + + // Form the merged operands for the merged invoke. + for (Use &U : MergedInvoke->operands()) { + // Only PHI together the indirect callees and data operands. + if (MergedInvoke->isCallee(&U)) { + if (!IsIndirectCall) + continue; + } else if (!MergedInvoke->isDataOperand(&U)) + continue; + + // Don't create trivial PHI's with all-identical incoming values. + bool NeedPHI = any_of(Invokes, [&U](InvokeInst *II) { + return II->getOperand(U.getOperandNo()) != U.get(); + }); + if (!NeedPHI) + continue; + + // Form a PHI out of all the data ops under this index. + PHINode *PN = PHINode::Create( + U->getType(), /*NumReservedValues=*/Invokes.size(), "", MergedInvoke); + for (InvokeInst *II : Invokes) + PN->addIncoming(II->getOperand(U.getOperandNo()), II->getParent()); + + U.set(PN); + } + + // We've ensured that each PHI node has compatible (identical) incoming values + // when coming from each of the `invoke`s in the current merge set, + // so update the PHI nodes accordingly. + for (BasicBlock *Succ : successors(MergedInvoke)) + AddPredecessorToBlock(Succ, /*NewPred=*/MergedInvoke->getParent(), + /*ExistPred=*/Invokes.front()->getParent()); + + // And finally, replace the original `invoke`s with an unconditional branch + // to the block with the merged `invoke`. Also, give that merged `invoke` + // the merged debugloc of all the original `invoke`s. + const DILocation *MergedDebugLoc = nullptr; + for (InvokeInst *II : Invokes) { + // Compute the debug location common to all the original `invoke`s. + if (!MergedDebugLoc) + MergedDebugLoc = II->getDebugLoc(); + else + MergedDebugLoc = + DILocation::getMergedLocation(MergedDebugLoc, II->getDebugLoc()); + + // And replace the old `invoke` with an unconditionally branch + // to the block with the merged `invoke`. + for (BasicBlock *OrigSuccBB : successors(II->getParent())) + OrigSuccBB->removePredecessor(II->getParent()); + BranchInst::Create(MergedInvoke->getParent(), II->getParent()); + II->replaceAllUsesWith(MergedInvoke); + II->eraseFromParent(); + ++NumInvokesMerged; + } + MergedInvoke->setDebugLoc(MergedDebugLoc); + ++NumInvokeSetsFormed; + + if (DTU) + DTU->applyUpdates(Updates); +} + +/// If this block is a `landingpad` exception handling block, categorize all +/// the predecessor `invoke`s into sets, with all `invoke`s in each set +/// being "mergeable" together, and then merge invokes in each set together. +/// +/// This is a weird mix of hoisting and sinking. Visually, it goes from: +/// [...] [...] +/// | | +/// [invoke0] [invoke1] +/// / \ / \ +/// [cont0] [landingpad] [cont1] +/// to: +/// [...] [...] +/// \ / +/// [invoke] +/// / \ +/// [cont] [landingpad] +/// +/// But of course we can only do that if the invokes share the `landingpad`, +/// edges invoke0->cont0 and invoke1->cont1 are "compatible", +/// and the invoked functions are "compatible". 
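/// As an illustration only (not code from this patch), frontend input like
///   struct Guard { ~Guard(); };
///   int may_throw(int);
///   void sink(int);
///   void test(bool flag, int a, int b) {
///     Guard g;                      // gives both calls an unwind edge
///     sink(flag ? may_throw(a)      // invoke #1
///               : may_throw(b));    // invoke #2: same callee
///   }
/// lowers to two `invoke`s of the same callee that unwind to the same
/// cleanup landingpad; once their continuations have been folded into one
/// block with a PHI over the two results, they satisfy the compatibility
/// rules above and can be collapsed into a single `invoke`.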
+static bool MergeCompatibleInvokes(BasicBlock *BB, DomTreeUpdater *DTU) { + if (!EnableMergeCompatibleInvokes) + return false; + + bool Changed = false; + + // FIXME: generalize to all exception handling blocks? + if (!BB->isLandingPad()) + return Changed; + + CompatibleSets Grouper; + + // Record all the predecessors of this `landingpad`. As per verifier, + // the only allowed predecessor is the unwind edge of an `invoke`. + // We want to group "compatible" `invokes` into the same set to be merged. + for (BasicBlock *PredBB : predecessors(BB)) + Grouper.insert(cast<InvokeInst>(PredBB->getTerminator())); + + // And now, merge `invoke`s that were grouped togeter. + for (ArrayRef<InvokeInst *> Invokes : Grouper.Sets) { + if (Invokes.size() < 2) + continue; + Changed = true; + MergeCompatibleInvokesImpl(Invokes, DTU); + } + + return Changed; +} + /// Determine if we can hoist sink a sole store instruction out of a /// conditional block. /// @@ -2326,15 +2679,15 @@ static bool validateAndCostRequiredSelects(BasicBlock *BB, BasicBlock *ThenBB, passingValueIsAlwaysUndefined(ThenV, &PN)) return false; + if (canTrap(OrigV) || canTrap(ThenV)) + return false; + HaveRewritablePHIs = true; ConstantExpr *OrigCE = dyn_cast<ConstantExpr>(OrigV); ConstantExpr *ThenCE = dyn_cast<ConstantExpr>(ThenV); if (!OrigCE && !ThenCE) - continue; // Known safe and cheap. + continue; // Known cheap (FIXME: Maybe not true for aggregates). - if ((ThenCE && !isSafeToSpeculativelyExecute(ThenCE)) || - (OrigCE && !isSafeToSpeculativelyExecute(OrigCE))) - return false; InstructionCost OrigCost = OrigCE ? computeSpeculationCost(OrigCE, TTI) : 0; InstructionCost ThenCost = ThenCE ? computeSpeculationCost(ThenCE, TTI) : 0; InstructionCost MaxCost = @@ -2626,40 +2979,85 @@ static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB) { return true; } -/// If we have a conditional branch on a PHI node value that is defined in the -/// same block as the branch and if any PHI entries are constants, thread edges -/// corresponding to that entry to be branches to their ultimate destination. -static Optional<bool> FoldCondBranchOnPHIImpl(BranchInst *BI, - DomTreeUpdater *DTU, - const DataLayout &DL, - AssumptionCache *AC) { +static ConstantInt * +getKnownValueOnEdge(Value *V, BasicBlock *From, BasicBlock *To, + SmallDenseMap<std::pair<BasicBlock *, BasicBlock *>, + ConstantInt *> &Visited) { + // Don't look past the block defining the value, we might get the value from + // a previous loop iteration. + auto *I = dyn_cast<Instruction>(V); + if (I && I->getParent() == To) + return nullptr; + + // We know the value if the From block branches on it. + auto *BI = dyn_cast<BranchInst>(From->getTerminator()); + if (BI && BI->isConditional() && BI->getCondition() == V && + BI->getSuccessor(0) != BI->getSuccessor(1)) + return BI->getSuccessor(0) == To ? ConstantInt::getTrue(BI->getContext()) + : ConstantInt::getFalse(BI->getContext()); + + // Limit the amount of blocks we inspect. + if (Visited.size() >= 8) + return nullptr; + + auto Pair = Visited.try_emplace({From, To}, nullptr); + if (!Pair.second) + return Pair.first->second; + + // Check whether the known value is the same for all predecessors. 
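  // For example (hypothetical CFG): if the only predecessor A of From ends
  // in `br i1 %c, label %From, label %Other` and From branches to To
  // unconditionally, then %c is known to be true on the edge From -> To,
  // and ConstantInt::getTrue() is returned. If another predecessor of From
  // implied %c == false instead, the predecessors would disagree and the
  // loop below conservatively returns nullptr.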
+ ConstantInt *Common = nullptr; + for (BasicBlock *Pred : predecessors(From)) { + ConstantInt *C = getKnownValueOnEdge(V, Pred, From, Visited); + if (!C || (Common && Common != C)) + return nullptr; + Common = C; + } + return Visited[{From, To}] = Common; +} + +/// If we have a conditional branch on something for which we know the constant +/// value in predecessors (e.g. a phi node in the current block), thread edges +/// from the predecessor to their ultimate destination. +static Optional<bool> +FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, + const DataLayout &DL, + AssumptionCache *AC) { + SmallMapVector<BasicBlock *, ConstantInt *, 8> KnownValues; BasicBlock *BB = BI->getParent(); - PHINode *PN = dyn_cast<PHINode>(BI->getCondition()); - // NOTE: we currently cannot transform this case if the PHI node is used - // outside of the block. - if (!PN || PN->getParent() != BB || !PN->hasOneUse()) - return false; + Value *Cond = BI->getCondition(); + PHINode *PN = dyn_cast<PHINode>(Cond); + if (PN && PN->getParent() == BB) { + // Degenerate case of a single entry PHI. + if (PN->getNumIncomingValues() == 1) { + FoldSingleEntryPHINodes(PN->getParent()); + return true; + } - // Degenerate case of a single entry PHI. - if (PN->getNumIncomingValues() == 1) { - FoldSingleEntryPHINodes(PN->getParent()); - return true; + for (Use &U : PN->incoming_values()) + if (auto *CB = dyn_cast<ConstantInt>(U)) + KnownValues.insert({PN->getIncomingBlock(U), CB}); + } else { + SmallDenseMap<std::pair<BasicBlock *, BasicBlock *>, ConstantInt *> Visited; + for (BasicBlock *Pred : predecessors(BB)) { + if (ConstantInt *CB = getKnownValueOnEdge(Cond, Pred, BB, Visited)) + KnownValues.insert({Pred, CB}); + } } + if (KnownValues.empty()) + return false; + // Now we know that this block has multiple preds and two succs. + // Check that the block is small enough and values defined in the block are + // not used outside of it. if (!BlockIsSimpleEnoughToThreadThrough(BB)) return false; - // Okay, this is a simple enough basic block. See if any phi values are - // constants. - for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { - ConstantInt *CB = dyn_cast<ConstantInt>(PN->getIncomingValue(i)); - if (!CB || !CB->getType()->isIntegerTy(1)) - continue; - + for (const auto &Pair : KnownValues) { // Okay, we now know that all edges from PredBB should be revectored to // branch to RealDest. - BasicBlock *PredBB = PN->getIncomingBlock(i); + ConstantInt *CB = Pair.second; + BasicBlock *PredBB = Pair.first; BasicBlock *RealDest = BI->getSuccessor(!CB->getZExtValue()); if (RealDest == BB) @@ -2690,6 +3088,7 @@ static Optional<bool> FoldCondBranchOnPHIImpl(BranchInst *BI, // cloned instructions outside of EdgeBB. BasicBlock::iterator InsertPt = EdgeBB->begin(); DenseMap<Value *, Value *> TranslateMap; // Track translated values. + TranslateMap[Cond] = Pair.second; for (BasicBlock::iterator BBI = BB->begin(); &*BBI != BI; ++BBI) { if (PHINode *PN = dyn_cast<PHINode>(BBI)) { TranslateMap[PN] = PN->getIncomingValueForBlock(PredBB); @@ -2708,7 +3107,7 @@ static Optional<bool> FoldCondBranchOnPHIImpl(BranchInst *BI, } // Check for trivial simplification. 
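      // For instance, since TranslateMap maps the branch condition to its
      // known constant on this edge, a cloned
      //   %r = select i1 %cond, i32 %a, i32 %b
      // with %cond known to be true folds to %a here and is recorded in
      // TranslateMap rather than inserted into EdgeBB.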
- if (Value *V = SimplifyInstruction(N, {DL, nullptr, nullptr, AC})) { + if (Value *V = simplifyInstruction(N, {DL, nullptr, nullptr, AC})) { if (!BBI->use_empty()) TranslateMap[&*BBI] = V; if (!N->mayHaveSideEffects()) { @@ -2746,6 +3145,12 @@ static Optional<bool> FoldCondBranchOnPHIImpl(BranchInst *BI, DTU->applyUpdates(Updates); } + // For simplicity, we created a separate basic block for the edge. Merge + // it back into the predecessor if possible. This not only avoids + // unnecessary SimplifyCFG iterations, but also makes sure that we don't + // bypass the check for trivial cycles above. + MergeBlockIntoPredecessor(EdgeBB, DTU); + // Signal repeat, simplifying any other constants. return None; } @@ -2753,13 +3158,15 @@ static Optional<bool> FoldCondBranchOnPHIImpl(BranchInst *BI, return false; } -static bool FoldCondBranchOnPHI(BranchInst *BI, DomTreeUpdater *DTU, - const DataLayout &DL, AssumptionCache *AC) { +static bool FoldCondBranchOnValueKnownInPredecessor(BranchInst *BI, + DomTreeUpdater *DTU, + const DataLayout &DL, + AssumptionCache *AC) { Optional<bool> Result; bool EverChanged = false; do { // Note that None means "we changed things, but recurse further." - Result = FoldCondBranchOnPHIImpl(BI, DTU, DL, AC); + Result = FoldCondBranchOnValueKnownInPredecessorImpl(BI, DTU, DL, AC); EverChanged |= Result == None || *Result; } while (Result == None); return EverChanged; @@ -2847,7 +3254,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, bool Changed = false; for (BasicBlock::iterator II = BB->begin(); isa<PHINode>(II);) { PHINode *PN = cast<PHINode>(II++); - if (Value *V = SimplifyInstruction(PN, {DL, PN})) { + if (Value *V = simplifyInstruction(PN, {DL, PN})) { PN->replaceAllUsesWith(V); PN->eraseFromParent(); Changed = true; @@ -3186,18 +3593,18 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, DomTreeUpdater *DTU, Instruction *Cond = dyn_cast<Instruction>(BI->getCondition()); - if (!Cond || (!isa<CmpInst>(Cond) && !isa<BinaryOperator>(Cond)) || + if (!Cond || + (!isa<CmpInst>(Cond) && !isa<BinaryOperator>(Cond) && + !isa<SelectInst>(Cond)) || Cond->getParent() != BB || !Cond->hasOneUse()) return false; // Cond is known to be a compare or binary operator. Check to make sure that // neither operand is a potentially-trapping constant expression. - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Cond->getOperand(0))) - if (CE->canTrap()) - return false; - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Cond->getOperand(1))) - if (CE->canTrap()) - return false; + if (canTrap(Cond->getOperand(0))) + return false; + if (canTrap(Cond->getOperand(1))) + return false; // Finally, don't infinitely unroll conditional loops. if (is_contained(successors(BB), BB)) @@ -3384,7 +3791,9 @@ static bool mergeConditionalStoreToAddress( return false; // Now check the stores are compatible. - if (!QStore->isUnordered() || !PStore->isUnordered()) + if (!QStore->isUnordered() || !PStore->isUnordered() || + PStore->getValueOperand()->getType() != + QStore->getValueOperand()->getType()) return false; // Check that sinking the store won't cause program behavior changes. Sinking @@ -3687,7 +4096,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, if (PBI->getCondition() == BI->getCondition() && PBI->getSuccessor(0) != PBI->getSuccessor(1)) { // Okay, the outcome of this conditional branch is statically - // knowable. If this block had a single pred, handle specially. + // knowable. 
If this block had a single pred, handle specially, otherwise + // FoldCondBranchOnValueKnownInPredecessor() will handle it. if (BB->getSinglePredecessor()) { // Turn this into a branch on constant. bool CondIsTrue = PBI->getSuccessor(0) == BB; @@ -3695,35 +4105,6 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, ConstantInt::get(Type::getInt1Ty(BB->getContext()), CondIsTrue)); return true; // Nuke the branch on constant. } - - // Otherwise, if there are multiple predecessors, insert a PHI that merges - // in the constant and simplify the block result. Subsequent passes of - // simplifycfg will thread the block. - if (BlockIsSimpleEnoughToThreadThrough(BB)) { - pred_iterator PB = pred_begin(BB), PE = pred_end(BB); - PHINode *NewPN = PHINode::Create( - Type::getInt1Ty(BB->getContext()), std::distance(PB, PE), - BI->getCondition()->getName() + ".pr", &BB->front()); - // Okay, we're going to insert the PHI node. Since PBI is not the only - // predecessor, compute the PHI'd conditional value for all of the preds. - // Any predecessor where the condition is not computable we keep symbolic. - for (pred_iterator PI = PB; PI != PE; ++PI) { - BasicBlock *P = *PI; - if ((PBI = dyn_cast<BranchInst>(P->getTerminator())) && PBI != BI && - PBI->isConditional() && PBI->getCondition() == BI->getCondition() && - PBI->getSuccessor(0) != PBI->getSuccessor(1)) { - bool CondIsTrue = PBI->getSuccessor(0) == BB; - NewPN->addIncoming( - ConstantInt::get(Type::getInt1Ty(BB->getContext()), CondIsTrue), - P); - } else { - NewPN->addIncoming(BI->getCondition(), P); - } - } - - BI->setCondition(NewPN); - return true; - } } // If the previous block ended with a widenable branch, determine if reusing @@ -3732,9 +4113,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, if (tryWidenCondBranchToCondBranch(PBI, BI, DTU)) return true; - if (auto *CE = dyn_cast<ConstantExpr>(BI->getCondition())) - if (CE->canTrap()) - return false; + if (canTrap(BI->getCondition())) + return false; // If both branches are conditional and both contain stores to the same // address, remove the stores from the conditionals and create a conditional @@ -3791,15 +4171,13 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, PHINode *PN = cast<PHINode>(II); Value *BIV = PN->getIncomingValueForBlock(BB); - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(BIV)) - if (CE->canTrap()) - return false; + if (canTrap(BIV)) + return false; unsigned PBBIdx = PN->getBasicBlockIndex(PBI->getParent()); Value *PBIV = PN->getIncomingValue(PBBIdx); - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(PBIV)) - if (CE->canTrap()) - return false; + if (canTrap(PBIV)) + return false; } // Finally, if everything is ok, fold the branches to logical ops. @@ -4116,7 +4494,7 @@ bool SimplifyCFGOpt::tryToSimplifyUncondBranchWithICmpInIt( assert(VVal && "Should have a unique destination value"); ICI->setOperand(0, VVal); - if (Value *V = SimplifyInstruction(ICI, {DL, ICI})) { + if (Value *V = simplifyInstruction(ICI, {DL, ICI})) { ICI->replaceAllUsesWith(V); ICI->eraseFromParent(); } @@ -4812,8 +5190,9 @@ static void createUnreachableSwitchDefault(SwitchInst *Switch, } } -/// Turn a switch with two reachable destinations into an integer range -/// comparison and branch. +/// Turn a switch into an integer range comparison and branch. +/// Switches with more than 2 destinations are ignored. +/// Switches with 1 destination are also ignored. 
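/// For instance (an illustrative sketch, not taken from this patch), a
/// switch whose cases for one destination form the contiguous range [3, 5]
///   switch (x) { case 3: case 4: case 5: return 1; default: return 0; }
/// is rewritten as a subtract plus a single unsigned comparison:
///   return (x - 3u) <= 2u ? 1 : 0;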
bool SimplifyCFGOpt::TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) { assert(SI->getNumCases() > 1 && "Degenerate switch?"); @@ -4845,6 +5224,8 @@ bool SimplifyCFGOpt::TurnSwitchRangeIntoICmp(SwitchInst *SI, } return false; // More than two destinations. } + if (!DestB) + return false; // All destinations are the same and the default is unreachable assert(DestA && DestB && "Single-destination switch should have been folded."); @@ -5169,11 +5550,6 @@ ConstantFold(Instruction *I, const DataLayout &DL, return nullptr; } - if (CmpInst *Cmp = dyn_cast<CmpInst>(I)) { - return ConstantFoldCompareInstOperands(Cmp->getPredicate(), COps[0], - COps[1], DL); - } - return ConstantFoldInstOperands(I, COps, DL); } @@ -5182,7 +5558,7 @@ ConstantFold(Instruction *I, const DataLayout &DL, /// destionations CaseDest corresponding to value CaseVal (0 for the default /// case), of a switch instruction SI. static bool -GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, +getCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, BasicBlock **CommonDest, SmallVectorImpl<std::pair<PHINode *, Constant *>> &Res, const DataLayout &DL, const TargetTransformInfo &TTI) { @@ -5253,9 +5629,9 @@ GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, // Helper function used to add CaseVal to the list of cases that generate // Result. Returns the updated number of cases that generate this result. -static uintptr_t MapCaseToResult(ConstantInt *CaseVal, - SwitchCaseResultVectorTy &UniqueResults, - Constant *Result) { +static size_t mapCaseToResult(ConstantInt *CaseVal, + SwitchCaseResultVectorTy &UniqueResults, + Constant *Result) { for (auto &I : UniqueResults) { if (I.first == Result) { I.second.push_back(CaseVal); @@ -5271,18 +5647,19 @@ static uintptr_t MapCaseToResult(ConstantInt *CaseVal, // results for the PHI node of the common destination block for a switch // instruction. Returns false if multiple PHI nodes have been found or if // there is not a common destination block for the switch. -static bool -InitializeUniqueCases(SwitchInst *SI, PHINode *&PHI, BasicBlock *&CommonDest, - SwitchCaseResultVectorTy &UniqueResults, - Constant *&DefaultResult, const DataLayout &DL, - const TargetTransformInfo &TTI, - uintptr_t MaxUniqueResults, uintptr_t MaxCasesPerResult) { +static bool initializeUniqueCases(SwitchInst *SI, PHINode *&PHI, + BasicBlock *&CommonDest, + SwitchCaseResultVectorTy &UniqueResults, + Constant *&DefaultResult, + const DataLayout &DL, + const TargetTransformInfo &TTI, + uintptr_t MaxUniqueResults) { for (auto &I : SI->cases()) { ConstantInt *CaseVal = I.getCaseValue(); // Resulting value at phi nodes for this case value. SwitchCaseResultsTy Results; - if (!GetCaseResults(SI, CaseVal, I.getCaseSuccessor(), &CommonDest, Results, + if (!getCaseResults(SI, CaseVal, I.getCaseSuccessor(), &CommonDest, Results, DL, TTI)) return false; @@ -5291,11 +5668,11 @@ InitializeUniqueCases(SwitchInst *SI, PHINode *&PHI, BasicBlock *&CommonDest, return false; // Add the case->result mapping to UniqueResults. - const uintptr_t NumCasesForResult = - MapCaseToResult(CaseVal, UniqueResults, Results.begin()->second); + const size_t NumCasesForResult = + mapCaseToResult(CaseVal, UniqueResults, Results.begin()->second); // Early out if there are too many cases for this result. - if (NumCasesForResult > MaxCasesPerResult) + if (NumCasesForResult > MaxSwitchCasesPerResult) return false; // Early out if there are too many unique results. 
@@ -5311,7 +5688,7 @@ InitializeUniqueCases(SwitchInst *SI, PHINode *&PHI, BasicBlock *&CommonDest, // Find the default result value. SmallVector<std::pair<PHINode *, Constant *>, 1> DefaultResults; BasicBlock *DefaultDest = SI->getDefaultDest(); - GetCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, DefaultResults, + getCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, DefaultResults, DL, TTI); // If the default value is not found abort unless the default destination // is unreachable. @@ -5326,48 +5703,76 @@ InitializeUniqueCases(SwitchInst *SI, PHINode *&PHI, BasicBlock *&CommonDest, // Helper function that checks if it is possible to transform a switch with only // two cases (or two cases + default) that produces a result into a select. -// Example: -// switch (a) { -// case 10: %0 = icmp eq i32 %a, 10 -// return 10; %1 = select i1 %0, i32 10, i32 4 -// case 20: ----> %2 = icmp eq i32 %a, 20 -// return 2; %3 = select i1 %2, i32 2, i32 %1 -// default: -// return 4; -// } -static Value *ConvertTwoCaseSwitch(const SwitchCaseResultVectorTy &ResultVector, - Constant *DefaultResult, Value *Condition, - IRBuilder<> &Builder) { +// TODO: Handle switches with more than 2 cases that map to the same result. +static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector, + Constant *DefaultResult, Value *Condition, + IRBuilder<> &Builder) { // If we are selecting between only two cases transform into a simple // select or a two-way select if default is possible. + // Example: + // switch (a) { %0 = icmp eq i32 %a, 10 + // case 10: return 42; %1 = select i1 %0, i32 42, i32 4 + // case 20: return 2; ----> %2 = icmp eq i32 %a, 20 + // default: return 4; %3 = select i1 %2, i32 2, i32 %1 + // } if (ResultVector.size() == 2 && ResultVector[0].second.size() == 1 && ResultVector[1].second.size() == 1) { - ConstantInt *const FirstCase = ResultVector[0].second[0]; - ConstantInt *const SecondCase = ResultVector[1].second[0]; - - bool DefaultCanTrigger = DefaultResult; + ConstantInt *FirstCase = ResultVector[0].second[0]; + ConstantInt *SecondCase = ResultVector[1].second[0]; Value *SelectValue = ResultVector[1].first; - if (DefaultCanTrigger) { - Value *const ValueCompare = + if (DefaultResult) { + Value *ValueCompare = Builder.CreateICmpEQ(Condition, SecondCase, "switch.selectcmp"); SelectValue = Builder.CreateSelect(ValueCompare, ResultVector[1].first, DefaultResult, "switch.select"); } - Value *const ValueCompare = + Value *ValueCompare = Builder.CreateICmpEQ(Condition, FirstCase, "switch.selectcmp"); return Builder.CreateSelect(ValueCompare, ResultVector[0].first, SelectValue, "switch.select"); } - // Handle the degenerate case where two cases have the same value. - if (ResultVector.size() == 1 && ResultVector[0].second.size() == 2 && - DefaultResult) { - Value *Cmp1 = Builder.CreateICmpEQ( - Condition, ResultVector[0].second[0], "switch.selectcmp.case1"); - Value *Cmp2 = Builder.CreateICmpEQ( - Condition, ResultVector[0].second[1], "switch.selectcmp.case2"); - Value *Cmp = Builder.CreateOr(Cmp1, Cmp2, "switch.selectcmp"); - return Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult); + // Handle the degenerate case where two cases have the same result value. + if (ResultVector.size() == 1 && DefaultResult) { + ArrayRef<ConstantInt *> CaseValues = ResultVector[0].second; + unsigned CaseCount = CaseValues.size(); + // n bits group cases map to the same result: + // case 0,4 -> Cond & 0b1..1011 == 0 ? 
result : default + // case 0,2,4,6 -> Cond & 0b1..1001 == 0 ? result : default + // case 0,2,8,10 -> Cond & 0b1..0101 == 0 ? result : default + if (isPowerOf2_32(CaseCount)) { + ConstantInt *MinCaseVal = CaseValues[0]; + // Find mininal value. + for (auto Case : CaseValues) + if (Case->getValue().slt(MinCaseVal->getValue())) + MinCaseVal = Case; + + // Mark the bits case number touched. + APInt BitMask = APInt::getZero(MinCaseVal->getBitWidth()); + for (auto Case : CaseValues) + BitMask |= (Case->getValue() - MinCaseVal->getValue()); + + // Check if cases with the same result can cover all number + // in touched bits. + if (BitMask.countPopulation() == Log2_32(CaseCount)) { + if (!MinCaseVal->isNullValue()) + Condition = Builder.CreateSub(Condition, MinCaseVal); + Value *And = Builder.CreateAnd(Condition, ~BitMask, "switch.and"); + Value *Cmp = Builder.CreateICmpEQ( + And, Constant::getNullValue(And->getType()), "switch.selectcmp"); + return Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult); + } + } + + // Handle the degenerate case where two cases have the same value. + if (CaseValues.size() == 2) { + Value *Cmp1 = Builder.CreateICmpEQ(Condition, CaseValues[0], + "switch.selectcmp.case1"); + Value *Cmp2 = Builder.CreateICmpEQ(Condition, CaseValues[1], + "switch.selectcmp.case2"); + Value *Cmp = Builder.CreateOr(Cmp1, Cmp2, "switch.selectcmp"); + return Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult); + } } return nullptr; @@ -5375,10 +5780,10 @@ static Value *ConvertTwoCaseSwitch(const SwitchCaseResultVectorTy &ResultVector, // Helper function to cleanup a switch instruction that has been converted into // a select, fixing up PHI nodes and basic blocks. -static void RemoveSwitchAfterSelectConversion(SwitchInst *SI, PHINode *PHI, - Value *SelectValue, - IRBuilder<> &Builder, - DomTreeUpdater *DTU) { +static void removeSwitchAfterSelectFold(SwitchInst *SI, PHINode *PHI, + Value *SelectValue, + IRBuilder<> &Builder, + DomTreeUpdater *DTU) { std::vector<DominatorTree::UpdateType> Updates; BasicBlock *SelectBB = SI->getParent(); @@ -5409,33 +5814,31 @@ static void RemoveSwitchAfterSelectConversion(SwitchInst *SI, PHINode *PHI, DTU->applyUpdates(Updates); } -/// If the switch is only used to initialize one or more -/// phi nodes in a common successor block with only two different -/// constant values, replace the switch with select. -static bool switchToSelect(SwitchInst *SI, IRBuilder<> &Builder, - DomTreeUpdater *DTU, const DataLayout &DL, - const TargetTransformInfo &TTI) { +/// If a switch is only used to initialize one or more phi nodes in a common +/// successor block with only two different constant values, try to replace the +/// switch with a select. Returns true if the fold was made. +static bool trySwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder, + DomTreeUpdater *DTU, const DataLayout &DL, + const TargetTransformInfo &TTI) { Value *const Cond = SI->getCondition(); PHINode *PHI = nullptr; BasicBlock *CommonDest = nullptr; Constant *DefaultResult; SwitchCaseResultVectorTy UniqueResults; // Collect all the cases that will deliver the same value from the switch. 
- if (!InitializeUniqueCases(SI, PHI, CommonDest, UniqueResults, DefaultResult, - DL, TTI, /*MaxUniqueResults*/2, - /*MaxCasesPerResult*/2)) + if (!initializeUniqueCases(SI, PHI, CommonDest, UniqueResults, DefaultResult, + DL, TTI, /*MaxUniqueResults*/ 2)) return false; - assert(PHI != nullptr && "PHI for value select not found"); + assert(PHI != nullptr && "PHI for value select not found"); Builder.SetInsertPoint(SI); Value *SelectValue = - ConvertTwoCaseSwitch(UniqueResults, DefaultResult, Cond, Builder); - if (SelectValue) { - RemoveSwitchAfterSelectConversion(SI, PHI, SelectValue, Builder, DTU); - return true; - } - // The switch couldn't be converted into a select. - return false; + foldSwitchToSelect(UniqueResults, DefaultResult, Cond, Builder); + if (!SelectValue) + return false; + + removeSwitchAfterSelectFold(SI, PHI, SelectValue, Builder, DTU); + return true; } namespace { @@ -5655,7 +6058,7 @@ Value *SwitchLookupTable::BuildLookup(Value *Index, IRBuilder<> &Builder) { IntegerType *IT = cast<IntegerType>(Index->getType()); uint64_t TableSize = Array->getInitializer()->getType()->getArrayNumElements(); - if (TableSize > (1ULL << (IT->getBitWidth() - 1))) + if (TableSize > (1ULL << std::min(IT->getBitWidth() - 1, 63u))) Index = Builder.CreateZExt( Index, IntegerType::get(IT->getContext(), IT->getBitWidth() + 1), "switch.tableidx.zext"); @@ -5707,6 +6110,27 @@ static bool isTypeLegalForLookupTable(Type *Ty, const TargetTransformInfo &TTI, DL.fitsInLegalInteger(IT->getBitWidth()); } +static bool isSwitchDense(uint64_t NumCases, uint64_t CaseRange) { + // 40% is the default density for building a jump table in optsize/minsize + // mode. See also TargetLoweringBase::isSuitableForJumpTable(), which this + // function was based on. + const uint64_t MinDensity = 40; + + if (CaseRange >= UINT64_MAX / 100) + return false; // Avoid multiplication overflows below. + + return NumCases * 100 >= CaseRange * MinDensity; +} + +static bool isSwitchDense(ArrayRef<int64_t> Values) { + uint64_t Diff = (uint64_t)Values.back() - (uint64_t)Values.front(); + uint64_t Range = Diff + 1; + if (Range < Diff) + return false; // Overflow. + + return isSwitchDense(Values.size(), Range); +} + /// Determine whether a lookup table should be built for this switch, based on /// the number of cases, size of the table, and the types of the results. // TODO: We could support larger than legal types by limiting based on the @@ -5716,8 +6140,8 @@ static bool ShouldBuildLookupTable(SwitchInst *SI, uint64_t TableSize, const TargetTransformInfo &TTI, const DataLayout &DL, const SmallDenseMap<PHINode *, Type *> &ResultTypes) { - if (SI->getNumCases() > TableSize || TableSize >= UINT64_MAX / 10) - return false; // TableSize overflowed, or mul below might overflow. + if (SI->getNumCases() > TableSize) + return false; // TableSize overflowed. bool AllTablesFitInRegister = true; bool HasIllegalType = false; @@ -5747,10 +6171,7 @@ ShouldBuildLookupTable(SwitchInst *SI, uint64_t TableSize, if (HasIllegalType) return false; - // The table density should be at least 40%. This is the same criterion as for - // jump tables, see SelectionDAGBuilder::handleJTSwitchCase. - // FIXME: Find the best cut-off. - return SI->getNumCases() * 10 >= TableSize * 4; + return isSwitchDense(SI->getNumCases(), TableSize); } /// Try to reuse the switch table index compare. Following pattern: @@ -5888,7 +6309,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // Resulting value at phi nodes for this case value. 
using ResultsTy = SmallVector<std::pair<PHINode *, Constant *>, 4>; ResultsTy Results; - if (!GetCaseResults(SI, CaseVal, CI->getCaseSuccessor(), &CommonDest, + if (!getCaseResults(SI, CaseVal, CI->getCaseSuccessor(), &CommonDest, Results, DL, TTI)) return false; @@ -5916,7 +6337,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // or a bitmask that fits in a register. SmallVector<std::pair<PHINode *, Constant *>, 4> DefaultResultsList; bool HasDefaultResults = - GetCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, + getCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, DefaultResultsList, DL, TTI); bool NeedMask = (TableHasHoles && !HasDefaultResults); @@ -6086,17 +6507,6 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, return true; } -static bool isSwitchDense(ArrayRef<int64_t> Values) { - // See also SelectionDAGBuilder::isDense(), which this function was based on. - uint64_t Diff = (uint64_t)Values.back() - (uint64_t)Values.front(); - uint64_t Range = Diff + 1; - uint64_t NumCases = Values.size(); - // 40% is the default density for building a jump table in optsize/minsize mode. - uint64_t MinDensity = 40; - - return NumCases * 100 >= Range * MinDensity; -} - /// Try to transform a switch that has "holes" in it to a contiguous sequence /// of cases. /// @@ -6220,7 +6630,7 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { if (eliminateDeadSwitchCases(SI, DTU, Options.AC, DL)) return requestResimplify(); - if (switchToSelect(SI, Builder, DTU, DL, TTI)) + if (trySwitchToSelect(SI, Builder, DTU, DL, TTI)) return requestResimplify(); if (Options.ForwardSwitchCondToPhi && ForwardSwitchConditionToPHI(SI)) @@ -6523,12 +6933,11 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { return requestResimplify(); } - // If this is a branch on a phi node in the current block, thread control - // through this block if any PHI node entries are constants. - if (PHINode *PN = dyn_cast<PHINode>(BI->getCondition())) - if (PN->getParent() == BI->getParent()) - if (FoldCondBranchOnPHI(BI, DTU, DL, Options.AC)) - return requestResimplify(); + // If this is a branch on something for which we know the constant value in + // predecessors (e.g. a phi node in the current block), thread control + // through this block. + if (FoldCondBranchOnValueKnownInPredecessor(BI, DTU, DL, Options.AC)) + return requestResimplify(); // Scan predecessor blocks for conditional branches. for (BasicBlock *Pred : predecessors(BB)) @@ -6727,7 +7136,8 @@ bool SimplifyCFGOpt::simplifyOnce(BasicBlock *BB) { return true; if (SinkCommon && Options.SinkCommonInsts) - if (SinkCommonCodeFromPredecessors(BB, DTU)) { + if (SinkCommonCodeFromPredecessors(BB, DTU) || + MergeCompatibleInvokes(BB, DTU)) { // SinkCommonCodeFromPredecessors() does not automatically CSE PHI's, // so we may now how duplicate PHI's. 
// Let's rerun EliminateDuplicatePHINodes() first, diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp index 5b7fd4349c6c..dbef1ff2e739 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -13,11 +13,9 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/SimplifyIndVar.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" @@ -58,7 +56,7 @@ namespace { SCEVExpander &Rewriter; SmallVectorImpl<WeakTrackingVH> &DeadInsts; - bool Changed; + bool Changed = false; public: SimplifyIndvar(Loop *Loop, ScalarEvolution *SE, DominatorTree *DT, @@ -66,7 +64,7 @@ namespace { SCEVExpander &Rewriter, SmallVectorImpl<WeakTrackingVH> &Dead) : L(Loop), LI(LI), SE(SE), DT(DT), TTI(TTI), Rewriter(Rewriter), - DeadInsts(Dead), Changed(false) { + DeadInsts(Dead) { assert(LI && "IV simplification requires LoopInfo"); } @@ -161,11 +159,12 @@ Value *SimplifyIndvar::foldIVUser(Instruction *UseInst, Instruction *IVOperand) D = ConstantInt::get(UseInst->getContext(), APInt::getOneBitSet(BitWidth, D->getZExtValue())); } - FoldedExpr = SE->getUDivExpr(SE->getSCEV(IVSrc), SE->getSCEV(D)); + const auto *LHS = SE->getSCEV(IVSrc); + const auto *RHS = SE->getSCEV(D); + FoldedExpr = SE->getUDivExpr(LHS, RHS); // We might have 'exact' flag set at this point which will no longer be // correct after we make the replacement. - if (UseInst->isExact() && - SE->getSCEV(IVSrc) != SE->getMulExpr(FoldedExpr, SE->getSCEV(D))) + if (UseInst->isExact() && LHS != SE->getMulExpr(FoldedExpr, RHS)) MustDropExactFlag = true; } // We have something that might fold it's operand. Compare SCEVs. @@ -872,6 +871,7 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { Instruction *IVOperand = UseOper.second; for (unsigned N = 0; IVOperand; ++N) { assert(N <= Simplified.size() && "runaway iteration"); + (void) N; Value *NewOper = foldIVUser(UseInst, IVOperand); if (!NewOper) @@ -1757,10 +1757,6 @@ Instruction *WidenIV::widenIVUse(WidenIV::NarrowIVDefUse DU, SCEVExpander &Rewri truncateIVUse(DU, DT, LI); return nullptr; } - // Assume block terminators cannot evaluate to a recurrence. We can't to - // insert a Trunc after a terminator if there happens to be a critical edge. - assert(DU.NarrowUse != DU.NarrowUse->getParent()->getTerminator() && - "SCEV is not expected to evaluate a block terminator"); // Reuse the IV increment that SCEVExpander created as long as it dominates // NarrowUse. 
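A note on the exact-flag check in SimplifyIndvar::foldIVUser above: the SCEV comparison `LHS != SE->getMulExpr(FoldedExpr, RHS)` is simply asking whether the numerator is a provable multiple of the divisor, which is the only case in which the 'exact' flag can survive the fold. Below is a minimal standalone C++ sketch of that condition with made-up values; it is illustrative only and not part of the commit.

#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical values standing in for the SCEV expressions in foldIVUser.
  uint64_t IVSrc = 10;                    // numerator after the operand is folded away
  uint64_t D = 4;                         // constant divisor of the udiv/lshr user
  uint64_t Folded = IVSrc / D;            // FoldedExpr = IVSrc /u D
  bool KeepExact = (Folded * D == IVSrc); // does LHS == (LHS /u D) * D round-trip?
  assert(!KeepExact);                     // 10 is not a multiple of 4, so 'exact' cannot be kept
  return 0;
}

When the numerator is a multiple of the divisor the product round-trips and the flag may be preserved; otherwise MustDropExactFlag is set so the flag is stripped before the replacement is made.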
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index e02d02a05752..f4306bb43dfd 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -14,28 +14,23 @@ #include "llvm/Transforms/Utils/SimplifyLibCalls.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/SmallString.h" -#include "llvm/ADT/StringMap.h" #include "llvm/ADT/Triple.h" -#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" -#include "llvm/Analysis/ProfileSummaryInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/Analysis/CaptureTracking.h" -#include "llvm/Analysis/Loads.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SizeOpts.h" using namespace llvm; @@ -206,6 +201,11 @@ static Value *copyFlags(const CallInst &Old, Value *New) { return New; } +// Helper to avoid truncating the length if size_t is 32-bits. +static StringRef substr(StringRef Str, uint64_t Len) { + return Len >= Str.size() ? Str : Str.substr(0, Len); +} + //===----------------------------------------------------------------------===// // String and Memory Library Call Optimizations //===----------------------------------------------------------------------===// @@ -242,7 +242,7 @@ Value *LibCallSimplifier::emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len, // Now that we have the destination's length, we must index into the // destination's pointer to get the actual memcpy destination (end of // the string .. we're concatenating). - Value *CpyDst = B.CreateGEP(B.getInt8Ty(), Dst, DstLen, "endptr"); + Value *CpyDst = B.CreateInBoundsGEP(B.getInt8Ty(), Dst, DstLen, "endptr"); // We have enough information to now generate the memcpy call to do the // concatenation for us. Make a memcpy to copy the nul byte with align = 1. 
@@ -326,7 +326,7 @@ Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilderBase &B) { if (!getConstantStringInfo(SrcStr, Str)) { if (CharC->isZero()) // strchr(p, 0) -> p + strlen(p) if (Value *StrLen = emitStrLen(SrcStr, B, DL, TLI)) - return B.CreateGEP(B.getInt8Ty(), SrcStr, StrLen, "strchr"); + return B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, StrLen, "strchr"); return nullptr; } @@ -339,35 +339,29 @@ Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilderBase &B) { return Constant::getNullValue(CI->getType()); // strchr(s+n,c) -> gep(s+n+i,c) - return B.CreateGEP(B.getInt8Ty(), SrcStr, B.getInt64(I), "strchr"); + return B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, B.getInt64(I), "strchr"); } Value *LibCallSimplifier::optimizeStrRChr(CallInst *CI, IRBuilderBase &B) { Value *SrcStr = CI->getArgOperand(0); - ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + Value *CharVal = CI->getArgOperand(1); + ConstantInt *CharC = dyn_cast<ConstantInt>(CharVal); annotateNonNullNoUndefBasedOnAccess(CI, 0); - // Cannot fold anything if we're not looking for a constant. - if (!CharC) - return nullptr; - StringRef Str; if (!getConstantStringInfo(SrcStr, Str)) { // strrchr(s, 0) -> strchr(s, 0) - if (CharC->isZero()) + if (CharC && CharC->isZero()) return copyFlags(*CI, emitStrChr(SrcStr, '\0', B, TLI)); return nullptr; } - // Compute the offset. - size_t I = (0xFF & CharC->getSExtValue()) == 0 - ? Str.size() - : Str.rfind(CharC->getSExtValue()); - if (I == StringRef::npos) // Didn't find the char. Return null. - return Constant::getNullValue(CI->getType()); - - // strrchr(s+n,c) -> gep(s+n+i,c) - return B.CreateGEP(B.getInt8Ty(), SrcStr, B.getInt64(I), "strrchr"); + // Try to expand strrchr to the memrchr nonstandard extension if it's + // available, or simply fail otherwise. + uint64_t NBytes = Str.size() + 1; // Include the terminating nul. + Type *IntPtrType = DL.getIntPtrType(CI->getContext()); + Value *Size = ConstantInt::get(IntPtrType, NBytes); + return copyFlags(*CI, emitMemRChr(SrcStr, CharVal, Size, B, DL, TLI)); } Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) { @@ -428,6 +422,12 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) { return nullptr; } +// Optimize a memcmp or, when StrNCmp is true, strncmp call CI with constant +// arrays LHS and RHS and nonconstant Size. +static Value *optimizeMemCmpVarSize(CallInst *CI, Value *LHS, Value *RHS, + Value *Size, bool StrNCmp, + IRBuilderBase &B, const DataLayout &DL); + Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) { Value *Str1P = CI->getArgOperand(0); Value *Str2P = CI->getArgOperand(1); @@ -442,7 +442,7 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) { if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(Size)) Length = LengthArg->getZExtValue(); else - return nullptr; + return optimizeMemCmpVarSize(CI, Str1P, Str2P, Size, true, B, DL); if (Length == 0) // strncmp(x,y,0) -> 0 return ConstantInt::get(CI->getType(), 0); @@ -456,8 +456,9 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) { // strncmp(x, y) -> cnst (if both x and y are constant strings) if (HasStr1 && HasStr2) { - StringRef SubStr1 = Str1.substr(0, Length); - StringRef SubStr2 = Str2.substr(0, Length); + // Avoid truncating the 64-bit Length to 32 bits in ILP32. 
+ StringRef SubStr1 = substr(Str1, Length); + StringRef SubStr2 = substr(Str2, Length); return ConstantInt::get(CI->getType(), SubStr1.compare(SubStr2)); } @@ -557,8 +558,8 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) { Type *PT = Callee->getFunctionType()->getParamType(0); Value *LenV = ConstantInt::get(DL.getIntPtrType(PT), Len); - Value *DstEnd = B.CreateGEP(B.getInt8Ty(), Dst, - ConstantInt::get(DL.getIntPtrType(PT), Len - 1)); + Value *DstEnd = B.CreateInBoundsGEP( + B.getInt8Ty(), Dst, ConstantInt::get(DL.getIntPtrType(PT), Len - 1)); // We have enough information to now generate the memcpy call to do the // copy for us. Make a memcpy to copy the nul byte with align = 1. @@ -634,12 +635,51 @@ Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilderBase &B) { } Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilderBase &B, - unsigned CharSize) { + unsigned CharSize, + Value *Bound) { Value *Src = CI->getArgOperand(0); + Type *CharTy = B.getIntNTy(CharSize); + + if (isOnlyUsedInZeroEqualityComparison(CI) && + (!Bound || isKnownNonZero(Bound, DL))) { + // Fold strlen: + // strlen(x) != 0 --> *x != 0 + // strlen(x) == 0 --> *x == 0 + // and likewise strnlen with constant N > 0: + // strnlen(x, N) != 0 --> *x != 0 + // strnlen(x, N) == 0 --> *x == 0 + return B.CreateZExt(B.CreateLoad(CharTy, Src, "char0"), + CI->getType()); + } + + if (Bound) { + if (ConstantInt *BoundCst = dyn_cast<ConstantInt>(Bound)) { + if (BoundCst->isZero()) + // Fold strnlen(s, 0) -> 0 for any s, constant or otherwise. + return ConstantInt::get(CI->getType(), 0); + + if (BoundCst->isOne()) { + // Fold strnlen(s, 1) -> *s ? 1 : 0 for any s. + Value *CharVal = B.CreateLoad(CharTy, Src, "strnlen.char0"); + Value *ZeroChar = ConstantInt::get(CharTy, 0); + Value *Cmp = B.CreateICmpNE(CharVal, ZeroChar, "strnlen.char0cmp"); + return B.CreateZExt(Cmp, CI->getType()); + } + } + } + + if (uint64_t Len = GetStringLength(Src, CharSize)) { + Value *LenC = ConstantInt::get(CI->getType(), Len - 1); + // Fold strlen("xyz") -> 3 and strnlen("xyz", 2) -> 2 + // and strnlen("xyz", Bound) -> min(3, Bound) for nonconstant Bound. + if (Bound) + return B.CreateBinaryIntrinsic(Intrinsic::umin, LenC, Bound); + return LenC; + } - // Constant folding: strlen("xyz") -> 3 - if (uint64_t Len = GetStringLength(Src, CharSize)) - return ConstantInt::get(CI->getType(), Len - 1); + if (Bound) + // Punt for strnlen for now. + return nullptr; // If s is a constant pointer pointing to a string literal, we can fold // strlen(s + x) to strlen(s) - x, when x is known to be in the range @@ -650,6 +690,7 @@ Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilderBase &B, // very useful because calling strlen for a pointer of other types is // very uncommon. if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) { + // TODO: Handle subobjects. if (!isGEPBasedOnPointerToString(GEP, CharSize)) return nullptr; @@ -674,22 +715,15 @@ Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilderBase &B, Value *Offset = GEP->getOperand(2); KnownBits Known = computeKnownBits(Offset, DL, 0, nullptr, CI, nullptr); - Known.Zero.flipAllBits(); uint64_t ArrSize = cast<ArrayType>(GEP->getSourceElementType())->getNumElements(); - // KnownZero's bits are flipped, so zeros in KnownZero now represent - // bits known to be zeros in Offset, and ones in KnowZero represent - // bits unknown in Offset. 
Therefore, Offset is known to be in range - // [0, NullTermIdx] when the flipped KnownZero is non-negative and - // unsigned-less-than NullTermIdx. - // // If Offset is not provably in the range [0, NullTermIdx], we can still // optimize if we can prove that the program has undefined behavior when // Offset is outside that range. That is the case when GEP->getOperand(0) // is a pointer to an object whose memory extent is NullTermIdx+1. - if ((Known.Zero.isNonNegative() && Known.Zero.ule(NullTermIdx)) || - (GEP->isInBounds() && isa<GlobalVariable>(GEP->getOperand(0)) && + if ((Known.isNonNegative() && Known.getMaxValue().ule(NullTermIdx)) || + (isa<GlobalVariable>(GEP->getOperand(0)) && NullTermIdx == ArrSize - 1)) { Offset = B.CreateSExtOrTrunc(Offset, CI->getType()); return B.CreateSub(ConstantInt::get(CI->getType(), NullTermIdx), @@ -713,12 +747,6 @@ Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilderBase &B, } } - // strlen(x) != 0 --> *x != 0 - // strlen(x) == 0 --> *x == 0 - if (isOnlyUsedInZeroEqualityComparison(CI)) - return B.CreateZExt(B.CreateLoad(B.getIntNTy(CharSize), Src, "strlenfirst"), - CI->getType()); - return nullptr; } @@ -729,6 +757,16 @@ Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilderBase &B) { return nullptr; } +Value *LibCallSimplifier::optimizeStrNLen(CallInst *CI, IRBuilderBase &B) { + Value *Bound = CI->getArgOperand(1); + if (Value *V = optimizeStringLength(CI, B, 8, Bound)) + return V; + + if (isKnownNonZero(Bound, DL)) + annotateNonNullNoUndefBasedOnAccess(CI, 0); + return nullptr; +} + Value *LibCallSimplifier::optimizeWcslen(CallInst *CI, IRBuilderBase &B) { Module &M = *CI->getModule(); unsigned WCharSize = TLI->getWCharSize(M) * 8; @@ -755,8 +793,8 @@ Value *LibCallSimplifier::optimizeStrPBrk(CallInst *CI, IRBuilderBase &B) { if (I == StringRef::npos) // No match. return Constant::getNullValue(CI->getType()); - return B.CreateGEP(B.getInt8Ty(), CI->getArgOperand(0), B.getInt64(I), - "strpbrk"); + return B.CreateInBoundsGEP(B.getInt8Ty(), CI->getArgOperand(0), + B.getInt64(I), "strpbrk"); } // strpbrk(s, "a") -> strchr(s, 'a') @@ -880,35 +918,190 @@ Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilderBase &B) { } Value *LibCallSimplifier::optimizeMemRChr(CallInst *CI, IRBuilderBase &B) { - if (isKnownNonZero(CI->getOperand(2), DL)) - annotateNonNullNoUndefBasedOnAccess(CI, 0); - return nullptr; + Value *SrcStr = CI->getArgOperand(0); + Value *Size = CI->getArgOperand(2); + annotateNonNullAndDereferenceable(CI, 0, Size, DL); + Value *CharVal = CI->getArgOperand(1); + ConstantInt *LenC = dyn_cast<ConstantInt>(Size); + Value *NullPtr = Constant::getNullValue(CI->getType()); + + if (LenC) { + if (LenC->isZero()) + // Fold memrchr(x, y, 0) --> null. + return NullPtr; + + if (LenC->isOne()) { + // Fold memrchr(x, y, 1) --> *x == y ? x : null for any x and y, + // constant or otherwise. + Value *Val = B.CreateLoad(B.getInt8Ty(), SrcStr, "memrchr.char0"); + // Slice off the character's high end bits. + CharVal = B.CreateTrunc(CharVal, B.getInt8Ty()); + Value *Cmp = B.CreateICmpEQ(Val, CharVal, "memrchr.char0cmp"); + return B.CreateSelect(Cmp, SrcStr, NullPtr, "memrchr.sel"); + } + } + + StringRef Str; + if (!getConstantStringInfo(SrcStr, Str, 0, /*TrimAtNul=*/false)) + return nullptr; + + if (Str.size() == 0) + // If the array is empty fold memrchr(A, C, N) to null for any value + // of C and N on the basis that the only valid value of N is zero + // (otherwise the call is undefined). 
+ return NullPtr; + + uint64_t EndOff = UINT64_MAX; + if (LenC) { + EndOff = LenC->getZExtValue(); + if (Str.size() < EndOff) + // Punt out-of-bounds accesses to sanitizers and/or libc. + return nullptr; + } + + if (ConstantInt *CharC = dyn_cast<ConstantInt>(CharVal)) { + // Fold memrchr(S, C, N) for a constant C. + size_t Pos = Str.rfind(CharC->getZExtValue(), EndOff); + if (Pos == StringRef::npos) + // When the character is not in the source array fold the result + // to null regardless of Size. + return NullPtr; + + if (LenC) + // Fold memrchr(s, c, N) --> s + Pos for constant N > Pos. + return B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, B.getInt64(Pos)); + + if (Str.find(Str[Pos]) == Pos) { + // When there is just a single occurrence of C in S, i.e., the one + // in Str[Pos], fold + // memrchr(s, c, N) --> N <= Pos ? null : s + Pos + // for nonconstant N. + Value *Cmp = B.CreateICmpULE(Size, ConstantInt::get(Size->getType(), Pos), + "memrchr.cmp"); + Value *SrcPlus = B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, + B.getInt64(Pos), "memrchr.ptr_plus"); + return B.CreateSelect(Cmp, NullPtr, SrcPlus, "memrchr.sel"); + } + } + + // Truncate the string to search at most EndOff characters. + Str = Str.substr(0, EndOff); + if (Str.find_first_not_of(Str[0]) != StringRef::npos) + return nullptr; + + // If the source array consists of all equal characters, then for any + // C and N (whether in bounds or not), fold memrchr(S, C, N) to + // N != 0 && *S == C ? S + N - 1 : null + Type *SizeTy = Size->getType(); + Type *Int8Ty = B.getInt8Ty(); + Value *NNeZ = B.CreateICmpNE(Size, ConstantInt::get(SizeTy, 0)); + // Slice off the sought character's high end bits. + CharVal = B.CreateTrunc(CharVal, Int8Ty); + Value *CEqS0 = B.CreateICmpEQ(ConstantInt::get(Int8Ty, Str[0]), CharVal); + Value *And = B.CreateLogicalAnd(NNeZ, CEqS0); + Value *SizeM1 = B.CreateSub(Size, ConstantInt::get(SizeTy, 1)); + Value *SrcPlus = + B.CreateInBoundsGEP(Int8Ty, SrcStr, SizeM1, "memrchr.ptr_plus"); + return B.CreateSelect(And, SrcPlus, NullPtr, "memrchr.sel"); } Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) { Value *SrcStr = CI->getArgOperand(0); Value *Size = CI->getArgOperand(2); - annotateNonNullAndDereferenceable(CI, 0, Size, DL); - ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + if (isKnownNonZero(Size, DL)) + annotateNonNullNoUndefBasedOnAccess(CI, 0); + + Value *CharVal = CI->getArgOperand(1); + ConstantInt *CharC = dyn_cast<ConstantInt>(CharVal); ConstantInt *LenC = dyn_cast<ConstantInt>(Size); + Value *NullPtr = Constant::getNullValue(CI->getType()); // memchr(x, y, 0) -> null if (LenC) { if (LenC->isZero()) - return Constant::getNullValue(CI->getType()); - } else { - // From now on we need at least constant length and string. - return nullptr; + return NullPtr; + + if (LenC->isOne()) { + // Fold memchr(x, y, 1) --> *x == y ? x : null for any x and y, + // constant or otherwise. + Value *Val = B.CreateLoad(B.getInt8Ty(), SrcStr, "memchr.char0"); + // Slice off the character's high end bits. + CharVal = B.CreateTrunc(CharVal, B.getInt8Ty()); + Value *Cmp = B.CreateICmpEQ(Val, CharVal, "memchr.char0cmp"); + return B.CreateSelect(Cmp, SrcStr, NullPtr, "memchr.sel"); + } } StringRef Str; if (!getConstantStringInfo(SrcStr, Str, 0, /*TrimAtNul=*/false)) return nullptr; - // Truncate the string to LenC. If Str is smaller than LenC we will still only - // scan the string, as reading past the end of it is undefined and we can just - // return null if we don't find the char. 
- Str = Str.substr(0, LenC->getZExtValue()); + if (CharC) { + size_t Pos = Str.find(CharC->getZExtValue()); + if (Pos == StringRef::npos) + // When the character is not in the source array fold the result + // to null regardless of Size. + return NullPtr; + + // Fold memchr(s, c, n) -> n <= Pos ? null : s + Pos + // When the constant Size is less than or equal to the character + // position also fold the result to null. + Value *Cmp = B.CreateICmpULE(Size, ConstantInt::get(Size->getType(), Pos), + "memchr.cmp"); + Value *SrcPlus = B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, B.getInt64(Pos), + "memchr.ptr"); + return B.CreateSelect(Cmp, NullPtr, SrcPlus); + } + + if (Str.size() == 0) + // If the array is empty fold memchr(A, C, N) to null for any value + // of C and N on the basis that the only valid value of N is zero + // (otherwise the call is undefined). + return NullPtr; + + if (LenC) + Str = substr(Str, LenC->getZExtValue()); + + size_t Pos = Str.find_first_not_of(Str[0]); + if (Pos == StringRef::npos + || Str.find_first_not_of(Str[Pos], Pos) == StringRef::npos) { + // If the source array consists of at most two consecutive sequences + // of the same characters, then for any C and N (whether in bounds or + // not), fold memchr(S, C, N) to + // N != 0 && *S == C ? S : null + // or for the two sequences to: + // N != 0 && *S == C ? S : (N > Pos && S[Pos] == C ? S + Pos : null) + // ^Sel2 ^Sel1 are denoted above. + // The latter makes it also possible to fold strchr() calls with strings + // of the same characters. + Type *SizeTy = Size->getType(); + Type *Int8Ty = B.getInt8Ty(); + + // Slice off the sought character's high end bits. + CharVal = B.CreateTrunc(CharVal, Int8Ty); + + Value *Sel1 = NullPtr; + if (Pos != StringRef::npos) { + // Handle two consecutive sequences of the same characters. + Value *PosVal = ConstantInt::get(SizeTy, Pos); + Value *StrPos = ConstantInt::get(Int8Ty, Str[Pos]); + Value *CEqSPos = B.CreateICmpEQ(CharVal, StrPos); + Value *NGtPos = B.CreateICmp(ICmpInst::ICMP_UGT, Size, PosVal); + Value *And = B.CreateAnd(CEqSPos, NGtPos); + Value *SrcPlus = B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, PosVal); + Sel1 = B.CreateSelect(And, SrcPlus, NullPtr, "memchr.sel1"); + } + + Value *Str0 = ConstantInt::get(Int8Ty, Str[0]); + Value *CEqS0 = B.CreateICmpEQ(Str0, CharVal); + Value *NNeZ = B.CreateICmpNE(Size, ConstantInt::get(SizeTy, 0)); + Value *And = B.CreateAnd(NNeZ, CEqS0); + return B.CreateSelect(And, SrcStr, Sel1, "memchr.sel2"); + } + + if (!LenC) + // From now on we need a constant length and constant array. + return nullptr; // If the char is variable but the input str and length are not we can turn // this memchr call into a simple bit field test. Of course this only works @@ -920,60 +1113,93 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) { // memchr("\r\n", C, 2) != nullptr -> (1 << C & ((1 << '\r') | (1 << '\n'))) // != 0 // after bounds check. - if (!CharC && !Str.empty() && isOnlyUsedInZeroEqualityComparison(CI)) { - unsigned char Max = - *std::max_element(reinterpret_cast<const unsigned char *>(Str.begin()), - reinterpret_cast<const unsigned char *>(Str.end())); - - // Make sure the bit field we're about to create fits in a register on the - // target. - // FIXME: On a 64 bit architecture this prevents us from using the - // interesting range of alpha ascii chars. We could do better by emitting - // two bitfields or shifting the range by 64 if no lower chars are used. 
- if (!DL.fitsInLegalInteger(Max + 1)) - return nullptr; + if (Str.empty() || !isOnlyUsedInZeroEqualityComparison(CI)) + return nullptr; + + unsigned char Max = + *std::max_element(reinterpret_cast<const unsigned char *>(Str.begin()), + reinterpret_cast<const unsigned char *>(Str.end())); - // For the bit field use a power-of-2 type with at least 8 bits to avoid - // creating unnecessary illegal types. - unsigned char Width = NextPowerOf2(std::max((unsigned char)7, Max)); + // Make sure the bit field we're about to create fits in a register on the + // target. + // FIXME: On a 64 bit architecture this prevents us from using the + // interesting range of alpha ascii chars. We could do better by emitting + // two bitfields or shifting the range by 64 if no lower chars are used. + if (!DL.fitsInLegalInteger(Max + 1)) + return nullptr; - // Now build the bit field. - APInt Bitfield(Width, 0); - for (char C : Str) - Bitfield.setBit((unsigned char)C); - Value *BitfieldC = B.getInt(Bitfield); + // For the bit field use a power-of-2 type with at least 8 bits to avoid + // creating unnecessary illegal types. + unsigned char Width = NextPowerOf2(std::max((unsigned char)7, Max)); - // Adjust width of "C" to the bitfield width, then mask off the high bits. - Value *C = B.CreateZExtOrTrunc(CI->getArgOperand(1), BitfieldC->getType()); - C = B.CreateAnd(C, B.getIntN(Width, 0xFF)); + // Now build the bit field. + APInt Bitfield(Width, 0); + for (char C : Str) + Bitfield.setBit((unsigned char)C); + Value *BitfieldC = B.getInt(Bitfield); - // First check that the bit field access is within bounds. - Value *Bounds = B.CreateICmp(ICmpInst::ICMP_ULT, C, B.getIntN(Width, Width), - "memchr.bounds"); + // Adjust width of "C" to the bitfield width, then mask off the high bits. + Value *C = B.CreateZExtOrTrunc(CharVal, BitfieldC->getType()); + C = B.CreateAnd(C, B.getIntN(Width, 0xFF)); - // Create code that checks if the given bit is set in the field. - Value *Shl = B.CreateShl(B.getIntN(Width, 1ULL), C); - Value *Bits = B.CreateIsNotNull(B.CreateAnd(Shl, BitfieldC), "memchr.bits"); + // First check that the bit field access is within bounds. + Value *Bounds = B.CreateICmp(ICmpInst::ICMP_ULT, C, B.getIntN(Width, Width), + "memchr.bounds"); - // Finally merge both checks and cast to pointer type. The inttoptr - // implicitly zexts the i1 to intptr type. - return B.CreateIntToPtr(B.CreateLogicalAnd(Bounds, Bits, "memchr"), - CI->getType()); - } + // Create code that checks if the given bit is set in the field. + Value *Shl = B.CreateShl(B.getIntN(Width, 1ULL), C); + Value *Bits = B.CreateIsNotNull(B.CreateAnd(Shl, BitfieldC), "memchr.bits"); - // Check if all arguments are constants. If so, we can constant fold. - if (!CharC) - return nullptr; + // Finally merge both checks and cast to pointer type. The inttoptr + // implicitly zexts the i1 to intptr type. + return B.CreateIntToPtr(B.CreateLogicalAnd(Bounds, Bits, "memchr"), + CI->getType()); +} - // Compute the offset. - size_t I = Str.find(CharC->getSExtValue() & 0xFF); - if (I == StringRef::npos) // Didn't find the char. memchr returns null. +// Optimize a memcmp or, when StrNCmp is true, strncmp call CI with constant +// arrays LHS and RHS and nonconstant Size. 
+static Value *optimizeMemCmpVarSize(CallInst *CI, Value *LHS, Value *RHS, + Value *Size, bool StrNCmp, + IRBuilderBase &B, const DataLayout &DL) { + if (LHS == RHS) // memcmp(s,s,x) -> 0 return Constant::getNullValue(CI->getType()); - // memchr(s+n,c,l) -> gep(s+n+i,c) - return B.CreateGEP(B.getInt8Ty(), SrcStr, B.getInt64(I), "memchr"); + StringRef LStr, RStr; + if (!getConstantStringInfo(LHS, LStr, 0, /*TrimAtNul=*/false) || + !getConstantStringInfo(RHS, RStr, 0, /*TrimAtNul=*/false)) + return nullptr; + + // If the contents of both constant arrays are known, fold a call to + // memcmp(A, B, N) to + // N <= Pos ? 0 : (A < B ? -1 : B < A ? +1 : 0) + // where Pos is the first mismatch between A and B, determined below. + + uint64_t Pos = 0; + Value *Zero = ConstantInt::get(CI->getType(), 0); + for (uint64_t MinSize = std::min(LStr.size(), RStr.size()); ; ++Pos) { + if (Pos == MinSize || + (StrNCmp && (LStr[Pos] == '\0' && RStr[Pos] == '\0'))) { + // One array is a leading part of the other of equal or greater + // size, or for strncmp, the arrays are equal strings. + // Fold the result to zero. Size is assumed to be in bounds, since + // otherwise the call would be undefined. + return Zero; + } + + if (LStr[Pos] != RStr[Pos]) + break; + } + + // Normalize the result. + typedef unsigned char UChar; + int IRes = UChar(LStr[Pos]) < UChar(RStr[Pos]) ? -1 : 1; + Value *MaxSize = ConstantInt::get(Size->getType(), Pos); + Value *Cmp = B.CreateICmp(ICmpInst::ICMP_ULE, Size, MaxSize); + Value *Res = ConstantInt::get(CI->getType(), IRes); + return B.CreateSelect(Cmp, Zero, Res); } +// Optimize a memcmp call CI with constant size Len. static Value *optimizeMemCmpConstantSize(CallInst *CI, Value *LHS, Value *RHS, uint64_t Len, IRBuilderBase &B, const DataLayout &DL) { @@ -1028,25 +1254,6 @@ static Value *optimizeMemCmpConstantSize(CallInst *CI, Value *LHS, Value *RHS, } } - // Constant folding: memcmp(x, y, Len) -> constant (all arguments are const). - // TODO: This is limited to i8 arrays. - StringRef LHSStr, RHSStr; - if (getConstantStringInfo(LHS, LHSStr) && - getConstantStringInfo(RHS, RHSStr)) { - // Make sure we're not reading out-of-bounds memory. - if (Len > LHSStr.size() || Len > RHSStr.size()) - return nullptr; - // Fold the memcmp and normalize the result. This way we get consistent - // results across multiple platforms. - uint64_t Ret = 0; - int Cmp = memcmp(LHSStr.data(), RHSStr.data(), Len); - if (Cmp < 0) - Ret = -1; - else if (Cmp > 0) - Ret = 1; - return ConstantInt::get(CI->getType(), Ret); - } - return nullptr; } @@ -1056,33 +1263,29 @@ Value *LibCallSimplifier::optimizeMemCmpBCmpCommon(CallInst *CI, Value *LHS = CI->getArgOperand(0), *RHS = CI->getArgOperand(1); Value *Size = CI->getArgOperand(2); - if (LHS == RHS) // memcmp(s,s,x) -> 0 - return Constant::getNullValue(CI->getType()); - annotateNonNullAndDereferenceable(CI, {0, 1}, Size, DL); - // Handle constant lengths. + + if (Value *Res = optimizeMemCmpVarSize(CI, LHS, RHS, Size, false, B, DL)) + return Res; + + // Handle constant Size. 
ConstantInt *LenC = dyn_cast<ConstantInt>(Size); if (!LenC) return nullptr; - // memcmp(d,s,0) -> 0 - if (LenC->getZExtValue() == 0) - return Constant::getNullValue(CI->getType()); - - if (Value *Res = - optimizeMemCmpConstantSize(CI, LHS, RHS, LenC->getZExtValue(), B, DL)) - return Res; - return nullptr; + return optimizeMemCmpConstantSize(CI, LHS, RHS, LenC->getZExtValue(), B, DL); } Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); if (Value *V = optimizeMemCmpBCmpCommon(CI, B)) return V; // memcmp(x, y, Len) == 0 -> bcmp(x, y, Len) == 0 // bcmp can be more efficient than memcmp because it only has to know that // there is a difference, not how different one is to the other. - if (TLI->has(LibFunc_bcmp) && isOnlyUsedInZeroEqualityComparison(CI)) { + if (isLibFuncEmittable(M, TLI, LibFunc_bcmp) && + isOnlyUsedInZeroEqualityComparison(CI)) { Value *LHS = CI->getArgOperand(0); Value *RHS = CI->getArgOperand(1); Value *Size = CI->getArgOperand(2); @@ -1125,6 +1328,7 @@ Value *LibCallSimplifier::optimizeMemCCpy(CallInst *CI, IRBuilderBase &B) { return Constant::getNullValue(CI->getType()); if (!getConstantStringInfo(Src, SrcStr, /*Offset=*/0, /*TrimAtNul=*/false) || + // TODO: Handle zeroinitializer. !StopChar) return nullptr; } else { @@ -1246,7 +1450,8 @@ static Value *valueHasFloatPrecision(Value *Val) { /// Shrink double -> float functions. static Value *optimizeDoubleFP(CallInst *CI, IRBuilderBase &B, - bool isBinary, bool isPrecise = false) { + bool isBinary, const TargetLibraryInfo *TLI, + bool isPrecise = false) { Function *CalleeFn = CI->getCalledFunction(); if (!CI->getType()->isDoubleTy() || !CalleeFn) return nullptr; @@ -1296,22 +1501,25 @@ static Value *optimizeDoubleFP(CallInst *CI, IRBuilderBase &B, R = isBinary ? B.CreateCall(Fn, V) : B.CreateCall(Fn, V[0]); } else { AttributeList CalleeAttrs = CalleeFn->getAttributes(); - R = isBinary ? emitBinaryFloatFnCall(V[0], V[1], CalleeName, B, CalleeAttrs) - : emitUnaryFloatFnCall(V[0], CalleeName, B, CalleeAttrs); + R = isBinary ? emitBinaryFloatFnCall(V[0], V[1], TLI, CalleeName, B, + CalleeAttrs) + : emitUnaryFloatFnCall(V[0], TLI, CalleeName, B, CalleeAttrs); } return B.CreateFPExt(R, B.getDoubleTy()); } /// Shrink double -> float for unary functions. static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilderBase &B, + const TargetLibraryInfo *TLI, bool isPrecise = false) { - return optimizeDoubleFP(CI, B, false, isPrecise); + return optimizeDoubleFP(CI, B, false, TLI, isPrecise); } /// Shrink double -> float for binary functions. static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilderBase &B, + const TargetLibraryInfo *TLI, bool isPrecise = false) { - return optimizeDoubleFP(CI, B, true, isPrecise); + return optimizeDoubleFP(CI, B, true, TLI, isPrecise); } // cabs(z) -> sqrt((creal(z)*creal(z)) + (cimag(z)*cimag(z))) @@ -1427,6 +1635,7 @@ static Value *getIntToFPVal(Value *I2F, IRBuilderBase &B, unsigned DstWidth) { /// ldexp(1.0, x) for pow(2.0, itofp(x)); exp2(n * x) for pow(2.0 ** n, x); /// exp10(x) for pow(10.0, x); exp2(log2(n) * x) for pow(n, x). 
Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { + Module *M = Pow->getModule(); Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); AttributeList Attrs; // Attributes are only meaningful on the original call Module *Mod = Pow->getModule(); @@ -1454,7 +1663,8 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { Function *CalleeFn = BaseFn->getCalledFunction(); if (CalleeFn && - TLI->getLibFunc(CalleeFn->getName(), LibFn) && TLI->has(LibFn)) { + TLI->getLibFunc(CalleeFn->getName(), LibFn) && + isLibFuncEmittable(M, TLI, LibFn)) { StringRef ExpName; Intrinsic::ID ID; Value *ExpFn; @@ -1506,7 +1716,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { // pow(2.0, itofp(x)) -> ldexp(1.0, x) if (match(Base, m_SpecificFP(2.0)) && (isa<SIToFPInst>(Expo) || isa<UIToFPInst>(Expo)) && - hasFloatFn(TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { + hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { if (Value *ExpoI = getIntToFPVal(Expo, B, TLI->getIntSize())) return copyFlags(*Pow, emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), ExpoI, @@ -1515,7 +1725,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { } // pow(2.0 ** n, x) -> exp2(n * x) - if (hasFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) { + if (hasFloatFn(M, TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) { APFloat BaseR = APFloat(1.0); BaseR.convert(BaseF->getSemantics(), APFloat::rmTowardZero, &Ignored); BaseR = BaseR / *BaseF; @@ -1542,7 +1752,7 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { // pow(10.0, x) -> exp10(x) // TODO: There is no exp10() intrinsic yet, but some day there shall be one. if (match(Base, m_SpecificFP(10.0)) && - hasFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) + hasFloatFn(M, TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) return copyFlags(*Pow, emitUnaryFloatFnCall(Expo, TLI, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l, B, Attrs)); @@ -1567,7 +1777,8 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { return copyFlags(*Pow, B.CreateCall(Intrinsic::getDeclaration( Mod, Intrinsic::exp2, Ty), FMul, "exp2")); - else if (hasFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) + else if (hasFloatFn(M, TLI, Ty, LibFunc_exp2, LibFunc_exp2f, + LibFunc_exp2l)) return copyFlags(*Pow, emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l, B, Attrs)); @@ -1588,7 +1799,8 @@ static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno, } // Otherwise, use the libcall for sqrt(). - if (hasFloatFn(TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) + if (hasFloatFn(M, TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf, + LibFunc_sqrtl)) // TODO: We also should check that the target can in fact lower the sqrt() // libcall. We currently have no way to ask this question, so we ask if // the target has a sqrt() libcall, which is not exactly the same. @@ -1778,8 +1990,8 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilderBase &B) { // Shrink pow() to powf() if the arguments are single precision, // unless the result is expected to be double precision. 
if (UnsafeFPShrink && Name == TLI->getName(LibFunc_pow) && - hasFloatVersion(Name)) { - if (Value *Shrunk = optimizeBinaryDoubleFP(Pow, B, true)) + hasFloatVersion(M, Name)) { + if (Value *Shrunk = optimizeBinaryDoubleFP(Pow, B, TLI, true)) return Shrunk; } @@ -1787,13 +1999,14 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilderBase &B) { } Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); AttributeList Attrs; // Attributes are only meaningful on the original call StringRef Name = Callee->getName(); Value *Ret = nullptr; if (UnsafeFPShrink && Name == TLI->getName(LibFunc_exp2) && - hasFloatVersion(Name)) - Ret = optimizeUnaryDoubleFP(CI, B, true); + hasFloatVersion(M, Name)) + Ret = optimizeUnaryDoubleFP(CI, B, TLI, true); Type *Ty = CI->getType(); Value *Op = CI->getArgOperand(0); @@ -1801,7 +2014,7 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilderBase &B) { // Turn exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= IntSize // Turn exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < IntSize if ((isa<SIToFPInst>(Op) || isa<UIToFPInst>(Op)) && - hasFloatFn(TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { + hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { if (Value *Exp = getIntToFPVal(Op, B, TLI->getIntSize())) return emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), Exp, TLI, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl, @@ -1812,12 +2025,14 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilderBase &B) { } Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); + // If we can shrink the call to a float function rather than a double // function, do that first. Function *Callee = CI->getCalledFunction(); StringRef Name = Callee->getName(); - if ((Name == "fmin" || Name == "fmax") && hasFloatVersion(Name)) - if (Value *Ret = optimizeBinaryDoubleFP(CI, B)) + if ((Name == "fmin" || Name == "fmax") && hasFloatVersion(M, Name)) + if (Value *Ret = optimizeBinaryDoubleFP(CI, B, TLI)) return Ret; // The LLVM intrinsics minnum/maxnum correspond to fmin/fmax. Canonicalize to @@ -1848,8 +2063,8 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { Type *Ty = Log->getType(); Value *Ret = nullptr; - if (UnsafeFPShrink && hasFloatVersion(LogNm)) - Ret = optimizeUnaryDoubleFP(Log, B, true); + if (UnsafeFPShrink && hasFloatVersion(Mod, LogNm)) + Ret = optimizeUnaryDoubleFP(Log, B, TLI, true); // The earlier call must also be 'fast' in order to do these transforms. CallInst *Arg = dyn_cast<CallInst>(Log->getArgOperand(0)); @@ -1957,7 +2172,7 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { Log->doesNotAccessMemory() ? B.CreateCall(Intrinsic::getDeclaration(Mod, LogID, Ty), Arg->getOperand(0), "log") - : emitUnaryFloatFnCall(Arg->getOperand(0), LogNm, B, Attrs); + : emitUnaryFloatFnCall(Arg->getOperand(0), TLI, LogNm, B, Attrs); Value *MulY = B.CreateFMul(Arg->getArgOperand(1), LogX, "mul"); // Since pow() may have side effects, e.g. errno, // dead code elimination may not be trusted to remove it. @@ -1980,7 +2195,7 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { Value *LogE = Log->doesNotAccessMemory() ? 
B.CreateCall(Intrinsic::getDeclaration(Mod, LogID, Ty), Eul, "log") - : emitUnaryFloatFnCall(Eul, LogNm, B, Attrs); + : emitUnaryFloatFnCall(Eul, TLI, LogNm, B, Attrs); Value *MulY = B.CreateFMul(Arg->getArgOperand(0), LogE, "mul"); // Since exp() may have side effects, e.g. errno, // dead code elimination may not be trusted to remove it. @@ -1992,14 +2207,16 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { } Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); Value *Ret = nullptr; // TODO: Once we have a way (other than checking for the existince of the // libcall) to tell whether our target can lower @llvm.sqrt, relax the // condition below. - if (TLI->has(LibFunc_sqrtf) && (Callee->getName() == "sqrt" || - Callee->getIntrinsicID() == Intrinsic::sqrt)) - Ret = optimizeUnaryDoubleFP(CI, B, true); + if (isLibFuncEmittable(M, TLI, LibFunc_sqrtf) && + (Callee->getName() == "sqrt" || + Callee->getIntrinsicID() == Intrinsic::sqrt)) + Ret = optimizeUnaryDoubleFP(CI, B, TLI, true); if (!CI->isFast()) return Ret; @@ -2044,7 +2261,6 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { // If we found a repeated factor, hoist it out of the square root and // replace it with the fabs of that factor. - Module *M = Callee->getParent(); Type *ArgType = I->getType(); Function *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType); Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs"); @@ -2061,11 +2277,12 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { // TODO: Generalize to handle any trig function and its inverse. Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); Value *Ret = nullptr; StringRef Name = Callee->getName(); - if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(Name)) - Ret = optimizeUnaryDoubleFP(CI, B, true); + if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name)) + Ret = optimizeUnaryDoubleFP(CI, B, TLI, true); Value *Op1 = CI->getArgOperand(0); auto *OpC = dyn_cast<CallInst>(Op1); @@ -2081,7 +2298,8 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) { // tanl(atanl(x)) -> x LibFunc Func; Function *F = OpC->getCalledFunction(); - if (F && TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) && + if (F && TLI->getLibFunc(F->getName(), Func) && + isLibFuncEmittable(M, TLI, Func) && ((Func == LibFunc_atan && Callee->getName() == "tan") || (Func == LibFunc_atanf && Callee->getName() == "tanf") || (Func == LibFunc_atanl && Callee->getName() == "tanl"))) @@ -2097,9 +2315,10 @@ static bool isTrigLibCall(CallInst *CI) { CI->hasFnAttr(Attribute::ReadNone); } -static void insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg, +static bool insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg, bool UseFloat, Value *&Sin, Value *&Cos, - Value *&SinCos) { + Value *&SinCos, const TargetLibraryInfo *TLI) { + Module *M = OrigCallee->getParent(); Type *ArgTy = Arg->getType(); Type *ResTy; StringRef Name; @@ -2119,9 +2338,12 @@ static void insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg, ResTy = StructType::get(ArgTy, ArgTy); } - Module *M = OrigCallee->getParent(); - FunctionCallee Callee = - M->getOrInsertFunction(Name, OrigCallee->getAttributes(), ResTy, ArgTy); + if (!isLibFuncEmittable(M, TLI, Name)) + return false; + LibFunc TheLibFunc; + 
TLI->getLibFunc(Name, TheLibFunc); + FunctionCallee Callee = getOrInsertLibFunc( + M, *TLI, TheLibFunc, OrigCallee->getAttributes(), ResTy, ArgTy); if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) { // If the argument is an instruction, it must dominate all uses so put our @@ -2145,6 +2367,8 @@ static void insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg, Cos = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 1), "cospi"); } + + return true; } Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilderBase &B) { @@ -2172,7 +2396,9 @@ Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilderBase &B) { return nullptr; Value *Sin, *Cos, *SinCos; - insertSinCosCall(B, CI->getCalledFunction(), Arg, IsFloat, Sin, Cos, SinCos); + if (!insertSinCosCall(B, CI->getCalledFunction(), Arg, IsFloat, Sin, Cos, + SinCos, TLI)) + return nullptr; auto replaceTrigInsts = [this](SmallVectorImpl<CallInst *> &Calls, Value *Res) { @@ -2193,6 +2419,7 @@ void LibCallSimplifier::classifyArgUse( SmallVectorImpl<CallInst *> &CosCalls, SmallVectorImpl<CallInst *> &SinCosCalls) { CallInst *CI = dyn_cast<CallInst>(Val); + Module *M = CI->getModule(); if (!CI || CI->use_empty()) return; @@ -2203,7 +2430,8 @@ void LibCallSimplifier::classifyArgUse( Function *Callee = CI->getCalledFunction(); LibFunc Func; - if (!Callee || !TLI->getLibFunc(*Callee, Func) || !TLI->has(Func) || + if (!Callee || !TLI->getLibFunc(*Callee, Func) || + !isLibFuncEmittable(M, TLI, Func) || !isTrigLibCall(CI)) return; @@ -2258,7 +2486,7 @@ Value *LibCallSimplifier::optimizeAbs(CallInst *CI, IRBuilderBase &B) { // abs(x) -> x <s 0 ? -x : x // The negation has 'nsw' because abs of INT_MIN is undefined. Value *X = CI->getArgOperand(0); - Value *IsNeg = B.CreateICmpSLT(X, Constant::getNullValue(X->getType())); + Value *IsNeg = B.CreateIsNeg(X); Value *NegX = B.CreateNSWNeg(X, "neg"); return B.CreateSelect(IsNeg, NegX, X); } @@ -2418,6 +2646,7 @@ Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilderBase &B) { Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); FunctionType *FT = Callee->getFunctionType(); if (Value *V = optimizePrintFString(CI, B)) { @@ -2426,10 +2655,10 @@ Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilderBase &B) { // printf(format, ...) -> iprintf(format, ...) if no floating point // arguments. - if (TLI->has(LibFunc_iprintf) && !callHasFloatingPointArgument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - FunctionCallee IPrintFFn = - M->getOrInsertFunction("iprintf", FT, Callee->getAttributes()); + if (isLibFuncEmittable(M, TLI, LibFunc_iprintf) && + !callHasFloatingPointArgument(CI)) { + FunctionCallee IPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_iprintf, FT, + Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(IPrintFFn); B.Insert(New); @@ -2438,11 +2667,10 @@ Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilderBase &B) { // printf(format, ...) -> __small_printf(format, ...) if no 128-bit floating point // arguments. 
- if (TLI->has(LibFunc_small_printf) && !callHasFP128Argument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - auto SmallPrintFFn = - M->getOrInsertFunction(TLI->getName(LibFunc_small_printf), - FT, Callee->getAttributes()); + if (isLibFuncEmittable(M, TLI, LibFunc_small_printf) && + !callHasFP128Argument(CI)) { + auto SmallPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_small_printf, FT, + Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(SmallPrintFFn); B.Insert(New); @@ -2489,7 +2717,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char"); Value *Ptr = castToCStr(Dest, B); B.CreateStore(V, Ptr); - Ptr = B.CreateGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); + Ptr = B.CreateInBoundsGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); B.CreateStore(B.getInt8(0), Ptr); return ConstantInt::get(CI->getType(), 1); @@ -2541,6 +2769,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI, } Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); FunctionType *FT = Callee->getFunctionType(); if (Value *V = optimizeSPrintFString(CI, B)) { @@ -2549,10 +2778,10 @@ Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilderBase &B) { // sprintf(str, format, ...) -> siprintf(str, format, ...) if no floating // point arguments. - if (TLI->has(LibFunc_siprintf) && !callHasFloatingPointArgument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - FunctionCallee SIPrintFFn = - M->getOrInsertFunction("siprintf", FT, Callee->getAttributes()); + if (isLibFuncEmittable(M, TLI, LibFunc_siprintf) && + !callHasFloatingPointArgument(CI)) { + FunctionCallee SIPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_siprintf, + FT, Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(SIPrintFFn); B.Insert(New); @@ -2561,11 +2790,10 @@ Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilderBase &B) { // sprintf(str, format, ...) -> __small_sprintf(str, format, ...) if no 128-bit // floating point arguments. 
- if (TLI->has(LibFunc_small_sprintf) && !callHasFP128Argument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - auto SmallSPrintFFn = - M->getOrInsertFunction(TLI->getName(LibFunc_small_sprintf), - FT, Callee->getAttributes()); + if (isLibFuncEmittable(M, TLI, LibFunc_small_sprintf) && + !callHasFP128Argument(CI)) { + auto SmallSPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_small_sprintf, FT, + Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(SmallSPrintFFn); B.Insert(New); @@ -2629,7 +2857,7 @@ Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI, Value *V = B.CreateTrunc(CI->getArgOperand(3), B.getInt8Ty(), "char"); Value *Ptr = castToCStr(CI->getArgOperand(0), B); B.CreateStore(V, Ptr); - Ptr = B.CreateGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); + Ptr = B.CreateInBoundsGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); B.CreateStore(B.getInt8(0), Ptr); return ConstantInt::get(CI->getType(), 1); @@ -2721,6 +2949,7 @@ Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI, } Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); FunctionType *FT = Callee->getFunctionType(); if (Value *V = optimizeFPrintFString(CI, B)) { @@ -2729,10 +2958,10 @@ Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilderBase &B) { // fprintf(stream, format, ...) -> fiprintf(stream, format, ...) if no // floating point arguments. - if (TLI->has(LibFunc_fiprintf) && !callHasFloatingPointArgument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); - FunctionCallee FIPrintFFn = - M->getOrInsertFunction("fiprintf", FT, Callee->getAttributes()); + if (isLibFuncEmittable(M, TLI, LibFunc_fiprintf) && + !callHasFloatingPointArgument(CI)) { + FunctionCallee FIPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_fiprintf, + FT, Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(FIPrintFFn); B.Insert(New); @@ -2741,11 +2970,11 @@ Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilderBase &B) { // fprintf(stream, format, ...) -> __small_fprintf(stream, format, ...) if no // 128-bit floating point arguments. - if (TLI->has(LibFunc_small_fprintf) && !callHasFP128Argument(CI)) { - Module *M = B.GetInsertBlock()->getParent()->getParent(); + if (isLibFuncEmittable(M, TLI, LibFunc_small_fprintf) && + !callHasFP128Argument(CI)) { auto SmallFPrintFFn = - M->getOrInsertFunction(TLI->getName(LibFunc_small_fprintf), - FT, Callee->getAttributes()); + getOrInsertLibFunc(M, *TLI, LibFunc_small_fprintf, FT, + Callee->getAttributes()); CallInst *New = cast<CallInst>(CI->clone()); New->setCalledFunction(SmallFPrintFFn); B.Insert(New); @@ -2830,21 +3059,19 @@ Value *LibCallSimplifier::optimizeBCopy(CallInst *CI, IRBuilderBase &B) { CI->getArgOperand(2))); } -bool LibCallSimplifier::hasFloatVersion(StringRef FuncName) { - LibFunc Func; +bool LibCallSimplifier::hasFloatVersion(const Module *M, StringRef FuncName) { SmallString<20> FloatFuncName = FuncName; FloatFuncName += 'f'; - if (TLI->getLibFunc(FloatFuncName, Func)) - return TLI->has(Func); - return false; + return isLibFuncEmittable(M, TLI, FloatFuncName); } Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, IRBuilderBase &Builder) { + Module *M = CI->getModule(); LibFunc Func; Function *Callee = CI->getCalledFunction(); // Check for string/memory library functions. 
- if (TLI->getLibFunc(*Callee, Func) && TLI->has(Func)) { + if (TLI->getLibFunc(*Callee, Func) && isLibFuncEmittable(M, TLI, Func)) { // Make sure we never change the calling convention. assert( (ignoreCallingConv(Func) || @@ -2871,6 +3098,8 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, return optimizeStrNCpy(CI, Builder); case LibFunc_strlen: return optimizeStrLen(CI, Builder); + case LibFunc_strnlen: + return optimizeStrNLen(CI, Builder); case LibFunc_strpbrk: return optimizeStrPBrk(CI, Builder); case LibFunc_strndup: @@ -2923,6 +3152,8 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, LibFunc Func, IRBuilderBase &Builder) { + const Module *M = CI->getModule(); + // Don't optimize calls that require strict floating point semantics. if (CI->isStrictFP()) return nullptr; @@ -3001,12 +3232,12 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_sin: case LibFunc_sinh: case LibFunc_tanh: - if (UnsafeFPShrink && hasFloatVersion(CI->getCalledFunction()->getName())) - return optimizeUnaryDoubleFP(CI, Builder, true); + if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName())) + return optimizeUnaryDoubleFP(CI, Builder, TLI, true); return nullptr; case LibFunc_copysign: - if (hasFloatVersion(CI->getCalledFunction()->getName())) - return optimizeBinaryDoubleFP(CI, Builder); + if (hasFloatVersion(M, CI->getCalledFunction()->getName())) + return optimizeBinaryDoubleFP(CI, Builder, TLI); return nullptr; case LibFunc_fminf: case LibFunc_fmin: @@ -3025,6 +3256,7 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, } Value *LibCallSimplifier::optimizeCall(CallInst *CI, IRBuilderBase &Builder) { + Module *M = CI->getModule(); assert(!CI->isMustTailCall() && "These transforms aren't musttail safe."); // TODO: Split out the code below that operates on FP calls so that @@ -3103,7 +3335,7 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI, IRBuilderBase &Builder) { } // Then check for known library functions. - if (TLI->getLibFunc(*Callee, Func) && TLI->has(Func)) { + if (TLI->getLibFunc(*Callee, Func) && isLibFuncEmittable(M, TLI, Func)) { // We never change the calling convention. if (!ignoreCallingConv(Func) && !IsCallingConvC) return nullptr; @@ -3170,7 +3402,7 @@ LibCallSimplifier::LibCallSimplifier( function_ref<void(Instruction *, Value *)> Replacer, function_ref<void(Instruction *)> Eraser) : FortifiedSimplifier(TLI), DL(DL), TLI(TLI), ORE(ORE), BFI(BFI), PSI(PSI), - UnsafeFPShrink(false), Replacer(Replacer), Eraser(Eraser) {} + Replacer(Replacer), Eraser(Eraser) {} void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) { // Indirect through the replacer used in this instance. @@ -3361,7 +3593,8 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, // If the function was an __stpcpy_chk, and we were able to fold it into // a __memcpy_chk, we still need to return the correct end pointer. 
if (Ret && Func == LibFunc_stpcpy_chk) - return B.CreateGEP(B.getInt8Ty(), Dst, ConstantInt::get(SizeTTy, Len - 1)); + return B.CreateInBoundsGEP(B.getInt8Ty(), Dst, + ConstantInt::get(SizeTTy, Len - 1)); return copyFlags(*CI, cast<CallInst>(Ret)); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SizeOpts.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SizeOpts.cpp index 08a29ea16ba1..1242380f73c1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SizeOpts.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SizeOpts.cpp @@ -48,12 +48,12 @@ cl::opt<bool> llvm::ForcePGSO( cl::desc("Force the (profiled-guided) size optimizations. ")); cl::opt<int> llvm::PgsoCutoffInstrProf( - "pgso-cutoff-instr-prof", cl::Hidden, cl::init(950000), cl::ZeroOrMore, + "pgso-cutoff-instr-prof", cl::Hidden, cl::init(950000), cl::desc("The profile guided size optimization profile summary cutoff " "for instrumentation profile.")); cl::opt<int> llvm::PgsoCutoffSampleProf( - "pgso-cutoff-sample-prof", cl::Hidden, cl::init(990000), cl::ZeroOrMore, + "pgso-cutoff-sample-prof", cl::Hidden, cl::init(990000), cl::desc("The profile guided size optimization profile summary cutoff " "for sample profile.")); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/StripGCRelocates.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/StripGCRelocates.cpp index 1fa574f04c37..0ff88e8b4612 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/StripGCRelocates.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/StripGCRelocates.cpp @@ -9,7 +9,7 @@ // This is a little utility pass that removes the gc.relocates inserted by // RewriteStatepointsForGC. Note that the generated IR is incorrect, // but this is useful as a single pass in itself, for analysis of IR, without -// the GC.relocates. The statepoint and gc.result instrinsics would still be +// the GC.relocates. The statepoint and gc.result intrinsics would still be // present. 
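Several hunks above (sprintf to __small_sprintf, fprintf to fiprintf and __small_fprintf) repeat one pattern: the old TLI->has(...) plus Module::getOrInsertFunction pair is replaced by isLibFuncEmittable and getOrInsertLibFunc, which also apply the TLI-adjusted name and default attributes. A condensed sketch of that pattern, using only calls that appear in the patch; the helper name retargetToLibFunc is invented for illustration and assumes a direct call:

#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
using namespace llvm;

// Clone CI and retarget it to NewFunc if that library routine may be emitted
// for this module; returns the new call or nullptr if the rewrite is not legal.
static CallInst *retargetToLibFunc(CallInst *CI, LibFunc NewFunc,
                                   const TargetLibraryInfo *TLI,
                                   IRBuilderBase &B) {
  Module *M = CI->getModule();
  Function *Callee = CI->getCalledFunction();
  if (!isLibFuncEmittable(M, TLI, NewFunc))
    return nullptr;
  FunctionCallee NewFn = getOrInsertLibFunc(M, *TLI, NewFunc,
                                            Callee->getFunctionType(),
                                            Callee->getAttributes());
  CallInst *New = cast<CallInst>(CI->clone());
  New->setCalledFunction(NewFn);
  B.Insert(New);
  return New;
}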
//===----------------------------------------------------------------------===// @@ -18,10 +18,8 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Statepoint.h" -#include "llvm/IR/Type.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" -#include "llvm/Support/raw_ostream.h" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SymbolRewriter.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SymbolRewriter.cpp index 6a0eb34a7999..4ad16d622e8d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SymbolRewriter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SymbolRewriter.cpp @@ -57,7 +57,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/SymbolRewriter.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/ilist.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp index 0b718ed6136e..832353741500 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp @@ -18,7 +18,9 @@ #include "llvm/Transforms/Utils/UnifyLoopExits.h" #include "llvm/ADT/MapVector.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/InitializePasses.h" #include "llvm/Transforms/Utils.h" @@ -143,6 +145,8 @@ static bool unifyLoopExits(DominatorTree &DT, LoopInfo &LI, Loop *L) { // locate the exit blocks. SetVector<BasicBlock *> ExitingBlocks; SetVector<BasicBlock *> Exits; + // Record the exit blocks that branch to the same block. + MapVector<BasicBlock *, SetVector<BasicBlock *> > CommonSuccs; // We need SetVectors, but the Loop API takes a vector, so we use a temporary. SmallVector<BasicBlock *, 8> Temp; @@ -156,6 +160,11 @@ static bool unifyLoopExits(DominatorTree &DT, LoopInfo &LI, Loop *L) { if (SL == L || L->contains(SL)) continue; Exits.insert(S); + // The typical case for reducing the number of guard blocks occurs when + // the exit block has a single predecessor and successor. + if (S->getSinglePredecessor()) + if (auto *Succ = S->getSingleSuccessor()) + CommonSuccs[Succ].insert(S); } } @@ -170,13 +179,39 @@ static bool unifyLoopExits(DominatorTree &DT, LoopInfo &LI, Loop *L) { for (auto EB : ExitingBlocks) { dbgs() << " " << EB->getName(); } - dbgs() << "\n";); + dbgs() << "\n"; + + dbgs() << "Exit blocks with a common successor:\n"; + for (auto CS : CommonSuccs) { + dbgs() << " Succ " << CS.first->getName() << ", exits:"; + for (auto Exit : CS.second) + dbgs() << " " << Exit->getName(); + dbgs() << "\n"; + }); if (Exits.size() <= 1) { LLVM_DEBUG(dbgs() << "loop does not have multiple exits; nothing to do\n"); return false; } + // When multiple exit blocks branch to the same block, change the control + // flow hub to after the exit blocks rather than before. This reduces the + // number of guard blocks needed after the loop. 
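The comment at the end of the UnifyLoopExits hunk above describes when the control-flow hub can be placed after the exit blocks. A rough source-level shape that tends to produce such a CFG, assuming the usual lowering where each early-exit path does its own work in a block whose only successor is the code after the loop; the helper names are hypothetical:

extern void logZero(int);     // assumed external helpers, for illustration only
extern void logNegative(int);

int scanArray(const int *A, int N) {
  for (int I = 0; I < N; ++I) {
    if (A[I] == 0) {
      logZero(I);             // exit block 1: single predecessor, single successor
      break;
    }
    if (A[I] < 0) {
      logNegative(I);         // exit block 2: branches to the same successor
      break;
    }
  }
  return 0;                   // the common successor both exit blocks reach
}

With both exit blocks funneling into one successor, routing control flow through a hub placed at that successor needs fewer guard blocks than guarding each exiting edge before the exits.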
+ for (auto CS : CommonSuccs) { + auto CB = CS.first; + auto Preds = CS.second; + if (Exits.contains(CB)) + continue; + if (Preds.size() < 2 || Preds.size() == Exits.size()) + continue; + for (auto Exit : Preds) { + Exits.remove(Exit); + ExitingBlocks.remove(Exit->getSinglePredecessor()); + ExitingBlocks.insert(Exit); + } + Exits.insert(CB); + } + SmallVector<BasicBlock *, 8> GuardBlocks; DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); auto LoopExitBlock = CreateControlFlowHub(&DTU, GuardBlocks, ExitingBlocks, @@ -196,6 +231,17 @@ static bool unifyLoopExits(DominatorTree &DT, LoopInfo &LI, Loop *L) { if (auto ParentLoop = L->getParentLoop()) { for (auto G : GuardBlocks) { ParentLoop->addBasicBlockToLoop(G, LI); + // Ensure the guard block predecessors are in a valid loop. After the + // change to the control flow hub for common successors, a guard block + // predecessor may not be in a loop or may be in an outer loop. + for (auto Pred : predecessors(G)) { + auto PredLoop = LI.getLoopFor(Pred); + if (!ParentLoop->contains(PredLoop)) { + if (PredLoop) + LI.removeBlock(Pred); + ParentLoop->addBasicBlockToLoop(Pred, LI); + } + } } ParentLoop->verifyLoop(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/Utils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/Utils.cpp index 43eb5c87acee..f34f2df971b1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/Utils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/Utils.cpp @@ -34,6 +34,7 @@ void llvm::initializeTransformUtils(PassRegistry &Registry) { initializeLCSSAWrapperPassPass(Registry); initializeLibCallsShrinkWrapLegacyPassPass(Registry); initializeLoopSimplifyPass(Registry); + initializeLowerGlobalDtorsLegacyPassPass(Registry); initializeLowerInvokeLegacyPassPass(Registry); initializeLowerSwitchLegacyPassPass(Registry); initializeNameAnonGlobalLegacyPassPass(Registry); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/VNCoercion.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/VNCoercion.cpp index 637181722f63..42be67f3cfc0 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/VNCoercion.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/VNCoercion.cpp @@ -64,10 +64,15 @@ bool canCoerceMustAliasedValueToLoad(Value *StoredVal, Type *LoadTy, return true; } -template <class T, class HelperClass> -static T *coerceAvailableValueToLoadTypeHelper(T *StoredVal, Type *LoadedTy, - HelperClass &Helper, - const DataLayout &DL) { +/// If we saw a store of a value to memory, and +/// then a load from a must-aliased pointer of a different type, try to coerce +/// the stored value. LoadedTy is the type of the load we want to replace. +/// IRB is IRBuilder used to insert new instructions. +/// +/// If we can't do it, return null. +Value *coerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy, + IRBuilderBase &Helper, + const DataLayout &DL) { assert(canCoerceMustAliasedValueToLoad(StoredVal, LoadedTy, DL) && "precondition violation - materialization can't fail"); if (auto *C = dyn_cast<Constant>(StoredVal)) @@ -154,18 +159,6 @@ static T *coerceAvailableValueToLoadTypeHelper(T *StoredVal, Type *LoadedTy, return StoredVal; } -/// If we saw a store of a value to memory, and -/// then a load from a must-aliased pointer of a different type, try to coerce -/// the stored value. LoadedTy is the type of the load we want to replace. -/// IRB is IRBuilder used to insert new instructions. -/// -/// If we can't do it, return null. 
-Value *coerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy, - IRBuilderBase &IRB, - const DataLayout &DL) { - return coerceAvailableValueToLoadTypeHelper(StoredVal, LoadedTy, IRB, DL); -} - /// This function is called when we have a memdep query of a load that ends up /// being a clobbering memory write (store, memset, memcpy, memmove). This /// means that the write *may* provide bits used by the load but we can't be @@ -277,7 +270,7 @@ static unsigned getLoadLoadClobberFullWidthSize(const Value *MemLocBase, // looking at an i8 load on x86-32 that is known 1024 byte aligned, we can // widen it up to an i32 load. If it is known 2-byte aligned, we can widen it // to i16. - unsigned LoadAlign = LI->getAlignment(); + unsigned LoadAlign = LI->getAlign().value(); int64_t MemLocEnd = MemLocOffs + MemLocSize; @@ -400,10 +393,9 @@ int analyzeLoadFromClobberingMemInst(Type *LoadTy, Value *LoadPtr, return -1; } -template <class T, class HelperClass> -static T *getStoreValueForLoadHelper(T *SrcVal, unsigned Offset, Type *LoadTy, - HelperClass &Helper, - const DataLayout &DL) { +static Value *getStoreValueForLoadHelper(Value *SrcVal, unsigned Offset, + Type *LoadTy, IRBuilderBase &Builder, + const DataLayout &DL) { LLVMContext &Ctx = SrcVal->getType()->getContext(); // If two pointers are in the same address space, they have the same size, @@ -421,9 +413,11 @@ static T *getStoreValueForLoadHelper(T *SrcVal, unsigned Offset, Type *LoadTy, // Compute which bits of the stored value are being used by the load. Convert // to an integer type to start with. if (SrcVal->getType()->isPtrOrPtrVectorTy()) - SrcVal = Helper.CreatePtrToInt(SrcVal, DL.getIntPtrType(SrcVal->getType())); + SrcVal = + Builder.CreatePtrToInt(SrcVal, DL.getIntPtrType(SrcVal->getType())); if (!SrcVal->getType()->isIntegerTy()) - SrcVal = Helper.CreateBitCast(SrcVal, IntegerType::get(Ctx, StoreSize * 8)); + SrcVal = + Builder.CreateBitCast(SrcVal, IntegerType::get(Ctx, StoreSize * 8)); // Shift the bits to the least significant depending on endianness. 
unsigned ShiftAmt; @@ -432,12 +426,12 @@ static T *getStoreValueForLoadHelper(T *SrcVal, unsigned Offset, Type *LoadTy, else ShiftAmt = (StoreSize - LoadSize - Offset) * 8; if (ShiftAmt) - SrcVal = Helper.CreateLShr(SrcVal, - ConstantInt::get(SrcVal->getType(), ShiftAmt)); + SrcVal = Builder.CreateLShr(SrcVal, + ConstantInt::get(SrcVal->getType(), ShiftAmt)); if (LoadSize != StoreSize) - SrcVal = Helper.CreateTruncOrBitCast(SrcVal, - IntegerType::get(Ctx, LoadSize * 8)); + SrcVal = Builder.CreateTruncOrBitCast(SrcVal, + IntegerType::get(Ctx, LoadSize * 8)); return SrcVal; } @@ -450,14 +444,12 @@ Value *getStoreValueForLoad(Value *SrcVal, unsigned Offset, Type *LoadTy, IRBuilder<> Builder(InsertPt); SrcVal = getStoreValueForLoadHelper(SrcVal, Offset, LoadTy, Builder, DL); - return coerceAvailableValueToLoadTypeHelper(SrcVal, LoadTy, Builder, DL); + return coerceAvailableValueToLoadType(SrcVal, LoadTy, Builder, DL); } Constant *getConstantStoreValueForLoad(Constant *SrcVal, unsigned Offset, Type *LoadTy, const DataLayout &DL) { - ConstantFolder F; - SrcVal = getStoreValueForLoadHelper(SrcVal, Offset, LoadTy, F, DL); - return coerceAvailableValueToLoadTypeHelper(SrcVal, LoadTy, F, DL); + return ConstantFoldLoadFromConst(SrcVal, LoadTy, APInt(32, Offset), DL); } /// This function is called when we have a memdep query of a load that ends up @@ -522,75 +514,77 @@ Constant *getConstantLoadValueForLoad(Constant *SrcVal, unsigned Offset, return getConstantStoreValueForLoad(SrcVal, Offset, LoadTy, DL); } -template <class T, class HelperClass> -T *getMemInstValueForLoadHelper(MemIntrinsic *SrcInst, unsigned Offset, - Type *LoadTy, HelperClass &Helper, - const DataLayout &DL) { +/// This function is called when we have a +/// memdep query of a load that ends up being a clobbering mem intrinsic. +Value *getMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset, + Type *LoadTy, Instruction *InsertPt, + const DataLayout &DL) { LLVMContext &Ctx = LoadTy->getContext(); uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy).getFixedSize() / 8; + IRBuilder<> Builder(InsertPt); // We know that this method is only called when the mem transfer fully // provides the bits for the load. if (MemSetInst *MSI = dyn_cast<MemSetInst>(SrcInst)) { // memset(P, 'x', 1234) -> splat('x'), even if x is a variable, and // independently of what the offset is. - T *Val = cast<T>(MSI->getValue()); + Value *Val = MSI->getValue(); if (LoadSize != 1) Val = - Helper.CreateZExtOrBitCast(Val, IntegerType::get(Ctx, LoadSize * 8)); - T *OneElt = Val; + Builder.CreateZExtOrBitCast(Val, IntegerType::get(Ctx, LoadSize * 8)); + Value *OneElt = Val; // Splat the value out to the right number of bits. for (unsigned NumBytesSet = 1; NumBytesSet != LoadSize;) { // If we can double the number of bytes set, do it. if (NumBytesSet * 2 <= LoadSize) { - T *ShVal = Helper.CreateShl( + Value *ShVal = Builder.CreateShl( Val, ConstantInt::get(Val->getType(), NumBytesSet * 8)); - Val = Helper.CreateOr(Val, ShVal); + Val = Builder.CreateOr(Val, ShVal); NumBytesSet <<= 1; continue; } // Otherwise insert one byte at a time. 
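The shift-and-truncate extraction in getStoreValueForLoadHelper above is plain bit arithmetic. A minimal stand-alone sketch with ordinary integers rather than IR values, little-endian case only; the function name is invented:

#include <cstdint>

// Extract an i8 load at a byte offset from a previously stored i32 value on a
// little-endian target: shift the stored word right by Offset * 8, then truncate.
uint8_t loadedByteLE(uint32_t StoredWord, unsigned Offset) {
  unsigned ShiftAmt = Offset * 8;               // little-endian case
  return static_cast<uint8_t>(StoredWord >> ShiftAmt);
}
// e.g. loadedByteLE(0xAABBCCDD, 1) == 0xCC; a big-endian target would instead
// use ShiftAmt = (StoreSize - LoadSize - Offset) * 8, as in the code above.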
- T *ShVal = Helper.CreateShl(Val, ConstantInt::get(Val->getType(), 1 * 8)); - Val = Helper.CreateOr(OneElt, ShVal); + Value *ShVal = + Builder.CreateShl(Val, ConstantInt::get(Val->getType(), 1 * 8)); + Val = Builder.CreateOr(OneElt, ShVal); ++NumBytesSet; } - return coerceAvailableValueToLoadTypeHelper(Val, LoadTy, Helper, DL); + return coerceAvailableValueToLoadType(Val, LoadTy, Builder, DL); } // Otherwise, this is a memcpy/memmove from a constant global. MemTransferInst *MTI = cast<MemTransferInst>(SrcInst); Constant *Src = cast<Constant>(MTI->getSource()); - - // Otherwise, see if we can constant fold a load from the constant with the - // offset applied as appropriate. unsigned IndexSize = DL.getIndexTypeSizeInBits(Src->getType()); - return ConstantFoldLoadFromConstPtr( - Src, LoadTy, APInt(IndexSize, Offset), DL); -} - -/// This function is called when we have a -/// memdep query of a load that ends up being a clobbering mem intrinsic. -Value *getMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset, - Type *LoadTy, Instruction *InsertPt, - const DataLayout &DL) { - IRBuilder<> Builder(InsertPt); - return getMemInstValueForLoadHelper<Value, IRBuilder<>>(SrcInst, Offset, - LoadTy, Builder, DL); + return ConstantFoldLoadFromConstPtr(Src, LoadTy, APInt(IndexSize, Offset), + DL); } Constant *getConstantMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset, Type *LoadTy, const DataLayout &DL) { - // The only case analyzeLoadFromClobberingMemInst cannot be converted to a - // constant is when it's a memset of a non-constant. - if (auto *MSI = dyn_cast<MemSetInst>(SrcInst)) - if (!isa<Constant>(MSI->getValue())) + LLVMContext &Ctx = LoadTy->getContext(); + uint64_t LoadSize = DL.getTypeSizeInBits(LoadTy).getFixedSize() / 8; + + // We know that this method is only called when the mem transfer fully + // provides the bits for the load. + if (MemSetInst *MSI = dyn_cast<MemSetInst>(SrcInst)) { + auto *Val = dyn_cast<ConstantInt>(MSI->getValue()); + if (!Val) return nullptr; - ConstantFolder F; - return getMemInstValueForLoadHelper<Constant, ConstantFolder>(SrcInst, Offset, - LoadTy, F, DL); + + Val = ConstantInt::get(Ctx, APInt::getSplat(LoadSize * 8, Val->getValue())); + return ConstantFoldLoadFromConst(Val, LoadTy, DL); + } + + // Otherwise, this is a memcpy/memmove from a constant global. 
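The memset path above splats the single stored byte across the loaded width, doubling the number of set bytes per step and falling back to one byte at a time for odd sizes; the rewritten constant path does the same thing in one go with APInt::getSplat. A plain-integer sketch of the splat loop, not LLVM code, for load sizes up to 8 bytes:

#include <cstdint>

// Splat one memset byte across LoadSizeInBytes bytes, e.g. 0x41 -> 0x4141 ->
// 0x41414141 for a 4-byte load.
uint64_t splatByte(uint8_t Byte, unsigned LoadSizeInBytes) {
  uint64_t Val = Byte;
  unsigned NumBytesSet = 1;
  while (NumBytesSet * 2 <= LoadSizeInBytes) {  // double the set bytes
    Val |= Val << (NumBytesSet * 8);
    NumBytesSet *= 2;
  }
  while (NumBytesSet < LoadSizeInBytes) {       // then one byte at a time
    Val = (Val << 8) | Byte;
    ++NumBytesSet;
  }
  return Val;
}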
+ MemTransferInst *MTI = cast<MemTransferInst>(SrcInst); + Constant *Src = cast<Constant>(MTI->getSource()); + unsigned IndexSize = DL.getIndexTypeSizeInBits(Src->getType()); + return ConstantFoldLoadFromConstPtr(Src, LoadTy, APInt(IndexSize, Offset), + DL); } } // namespace VNCoercion } // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index 97c2acb7d4c7..f59fc3a6dd60 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -62,14 +62,13 @@ #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" -#include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" @@ -497,7 +496,7 @@ bool Vectorizer::lookThroughComplexAddresses(Value *PtrA, Value *PtrB, if (PtrDelta.urem(Stride) != 0) return false; unsigned IdxBitWidth = OpA->getType()->getScalarSizeInBits(); - APInt IdxDiff = PtrDelta.udiv(Stride).zextOrSelf(IdxBitWidth); + APInt IdxDiff = PtrDelta.udiv(Stride).zext(IdxBitWidth); // Only look through a ZExt/SExt. if (!isa<SExtInst>(OpA) && !isa<ZExtInst>(OpA)) @@ -1298,10 +1297,16 @@ bool Vectorizer::vectorizeLoadChain( CV->replaceAllUsesWith(V); } - // Bitcast might not be an Instruction, if the value being loaded is a - // constant. In that case, no need to reorder anything. - if (Instruction *BitcastInst = dyn_cast<Instruction>(Bitcast)) - reorder(BitcastInst); + // Since we might have opaque pointers we might end up using the pointer + // operand of the first load (wrt. memory loaded) for the vector load. Since + // this first load might not be the first in the block we potentially need to + // reorder the pointer operand (and its operands). If we have a bitcast though + // it might be before the load and should be the reorder start instruction. + // "Might" because for opaque pointers the "bitcast" is just the first load's + // pointer operand, as opposed to something we inserted at the right position + // ourselves. + Instruction *BCInst = dyn_cast<Instruction>(Bitcast); + reorder((BCInst && BCInst != L0->getPointerOperand()) ?
BCInst : LI); eraseInstructions(Chain); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp index 81e5aa223c07..6242d9a93fc1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -17,7 +17,9 @@ #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/IntrinsicInst.h" @@ -31,8 +33,6 @@ using namespace PatternMatch; #define LV_NAME "loop-vectorize" #define DEBUG_TYPE LV_NAME -extern cl::opt<bool> EnableVPlanPredication; - static cl::opt<bool> EnableIfConversion("enable-if-conversion", cl::init(true), cl::Hidden, cl::desc("Enable if-conversion during vectorization.")); @@ -439,6 +439,26 @@ static bool hasOutsideLoopUser(const Loop *TheLoop, Instruction *Inst, return false; } +/// Returns true if A and B have same pointer operands or same SCEVs addresses +static bool storeToSameAddress(ScalarEvolution *SE, StoreInst *A, + StoreInst *B) { + // Compare store + if (A == B) + return true; + + // Otherwise Compare pointers + Value *APtr = A->getPointerOperand(); + Value *BPtr = B->getPointerOperand(); + if (APtr == BPtr) + return true; + + // Otherwise compare address SCEVs + if (SE->getSCEV(APtr) == SE->getSCEV(BPtr)) + return true; + + return false; +} + int LoopVectorizationLegality::isConsecutivePtr(Type *AccessTy, Value *Ptr) const { const ValueToValueMap &Strides = @@ -487,7 +507,7 @@ bool LoopVectorizationLegality::canVectorizeOuterLoop() { // FIXME: We skip these checks when VPlan predication is enabled as we // want to allow divergent branches. This whole check will be removed // once VPlan predication is on by default. - if (!EnableVPlanPredication && Br && Br->isConditional() && + if (Br && Br->isConditional() && !TheLoop->isLoopInvariant(Br->getCondition()) && !LI->isLoopHeader(Br->getSuccessor(0)) && !LI->isLoopHeader(Br->getSuccessor(1))) { @@ -572,7 +592,7 @@ void LoopVectorizationLegality::addInductionPhi( // on predicates that only hold within the loop, since allowing the exit // currently means re-using this SCEV outside the loop (see PR33706 for more // details). 
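storeToSameAddress above, together with the canVectorizeMemory changes in the hunk that follows, relaxes the old blanket rejection of stores to loop-invariant addresses: an unconditional store of a reduction's running value is now allowed, since only its final value is observable after the loop. A hypothetical source-level shape of the case that becomes vectorizable:

// Sketch only: the running sum is stored unconditionally to the same
// loop-invariant address on every iteration, so the store can be sunk to
// after the vectorized loop.
void sumInto(int *Dst, const int *A, int N) {
  int Sum = 0;
  for (int I = 0; I < N; ++I) {
    Sum += A[I];
    *Dst = Sum;  // unconditional store of the reduction to an invariant address
  }
}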
- if (PSE.getUnionPredicate().isAlwaysTrue()) { + if (PSE.getPredicate().isAlwaysTrue()) { AllowedExit.insert(Phi); AllowedExit.insert(Phi->getIncomingValueForBlock(TheLoop->getLoopLatch())); } @@ -676,7 +696,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { RecurrenceDescriptor RedDes; if (RecurrenceDescriptor::isReductionPHI(Phi, TheLoop, RedDes, DB, AC, - DT)) { + DT, PSE.getSE())) { Requirements->addExactFPMathInst(RedDes.getExactFPMathInst()); AllowedExit.insert(RedDes.getLoopExitInstr()); Reductions[Phi] = RedDes; @@ -770,7 +790,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { auto *SE = PSE.getSE(); Intrinsic::ID IntrinID = getVectorIntrinsicIDForCall(CI, TLI); for (unsigned i = 0, e = CI->arg_size(); i != e; ++i) - if (hasVectorInstrinsicScalarOpd(IntrinID, i)) { + if (isVectorIntrinsicWithScalarOpAtArg(IntrinID, i)) { if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(i)), TheLoop)) { reportVectorizationFailure("Found unvectorizable intrinsic", "intrinsic instruction cannot be vectorized", @@ -849,7 +869,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // used outside the loop only if the SCEV predicates within the loop is // same as outside the loop. Allowing the exit means reusing the SCEV // outside the loop. - if (PSE.getUnionPredicate().isAlwaysTrue()) { + if (PSE.getPredicate().isAlwaysTrue()) { AllowedExit.insert(&I); continue; } @@ -911,15 +931,70 @@ bool LoopVectorizationLegality::canVectorizeMemory() { if (!LAI->canVectorizeMemory()) return false; - if (LAI->hasDependenceInvolvingLoopInvariantAddress()) { - reportVectorizationFailure("Stores to a uniform address", - "write to a loop invariant address could not be vectorized", - "CantVectorizeStoreToLoopInvariantAddress", ORE, TheLoop); - return false; + // We can vectorize stores to invariant address when final reduction value is + // guaranteed to be stored at the end of the loop. Also, if decision to + // vectorize loop is made, runtime checks are added so as to make sure that + // invariant address won't alias with any other objects. + if (!LAI->getStoresToInvariantAddresses().empty()) { + // For each invariant address, check its last stored value is unconditional. + for (StoreInst *SI : LAI->getStoresToInvariantAddresses()) { + if (isInvariantStoreOfReduction(SI) && + blockNeedsPredication(SI->getParent())) { + reportVectorizationFailure( + "We don't allow storing to uniform addresses", + "write of conditional recurring variant value to a loop " + "invariant address could not be vectorized", + "CantVectorizeStoreToLoopInvariantAddress", ORE, TheLoop); + return false; + } + } + + if (LAI->hasDependenceInvolvingLoopInvariantAddress()) { + // For each invariant address, check its last stored value is the result + // of one of our reductions. + // + // We do not check if dependence with loads exists because they are + // currently rejected earlier in LoopAccessInfo::analyzeLoop. In case this + // behaviour changes we have to modify this code. + ScalarEvolution *SE = PSE.getSE(); + SmallVector<StoreInst *, 4> UnhandledStores; + for (StoreInst *SI : LAI->getStoresToInvariantAddresses()) { + if (isInvariantStoreOfReduction(SI)) { + // Earlier stores to this address are effectively deadcode. + // With opaque pointers it is possible for one pointer to be used with + // different sizes of stored values: + // store i32 0, ptr %x + // store i8 0, ptr %x + // The latest store doesn't complitely overwrite the first one in the + // example. 
That is why we have to make sure that types of stored + // values are same. + // TODO: Check that bitwidth of unhandled store is smaller then the + // one that overwrites it and add a test. + erase_if(UnhandledStores, [SE, SI](StoreInst *I) { + return storeToSameAddress(SE, SI, I) && + I->getValueOperand()->getType() == + SI->getValueOperand()->getType(); + }); + continue; + } + UnhandledStores.push_back(SI); + } + + bool IsOK = UnhandledStores.empty(); + // TODO: we should also validate against InvariantMemSets. + if (!IsOK) { + reportVectorizationFailure( + "We don't allow storing to uniform addresses", + "write to a loop invariant address could not " + "be vectorized", + "CantVectorizeStoreToLoopInvariantAddress", ORE, TheLoop); + return false; + } + } } Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); - PSE.addPredicate(LAI->getPSE().getUnionPredicate()); + PSE.addPredicate(LAI->getPSE().getPredicate()); return true; } @@ -949,6 +1024,26 @@ bool LoopVectorizationLegality::canVectorizeFPMath( })); } +bool LoopVectorizationLegality::isInvariantStoreOfReduction(StoreInst *SI) { + return any_of(getReductionVars(), [&](auto &Reduction) -> bool { + const RecurrenceDescriptor &RdxDesc = Reduction.second; + return RdxDesc.IntermediateStore == SI; + }); +} + +bool LoopVectorizationLegality::isInvariantAddressOfReduction(Value *V) { + return any_of(getReductionVars(), [&](auto &Reduction) -> bool { + const RecurrenceDescriptor &RdxDesc = Reduction.second; + if (!RdxDesc.IntermediateStore) + return false; + + ScalarEvolution *SE = PSE.getSE(); + Value *InvariantAddress = RdxDesc.IntermediateStore->getPointerOperand(); + return V == InvariantAddress || + SE->getSCEV(V) == SE->getSCEV(InvariantAddress); + }); +} + bool LoopVectorizationLegality::isInductionPhi(const Value *V) const { Value *In0 = const_cast<Value *>(V); PHINode *PN = dyn_cast_or_null<PHINode>(In0); @@ -969,6 +1064,16 @@ LoopVectorizationLegality::getIntOrFpInductionDescriptor(PHINode *Phi) const { return nullptr; } +const InductionDescriptor * +LoopVectorizationLegality::getPointerInductionDescriptor(PHINode *Phi) const { + if (!isInductionPhi(Phi)) + return nullptr; + auto &ID = getInductionVars().find(Phi)->second; + if (ID.getKind() == InductionDescriptor::IK_PtrInduction) + return &ID; + return nullptr; +} + bool LoopVectorizationLegality::isCastedInductionVariable( const Value *V) const { auto *Inst = dyn_cast<Instruction>(V); @@ -1266,7 +1371,7 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) { if (Hints->getForce() == LoopVectorizeHints::FK_Enabled) SCEVThreshold = PragmaVectorizeSCEVCheckThreshold; - if (PSE.getUnionPredicate().getComplexity() > SCEVThreshold) { + if (PSE.getPredicate().getComplexity() > SCEVThreshold) { reportVectorizationFailure("Too many SCEV checks needed", "Too many SCEV assumptions need to be made and checked at runtime", "TooManySCEVRunTimeChecks", ORE, TheLoop); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h index 71eb39a18d2f..0cb2032fa45a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -25,6 +25,7 @@ #define LLVM_TRANSFORMS_VECTORIZE_LOOPVECTORIZATIONPLANNER_H #include "VPlan.h" +#include "llvm/Support/InstructionCost.h" namespace llvm { @@ -59,7 +60,7 @@ class VPBuilder { } public: - VPBuilder() {} 
+ VPBuilder() = default; /// Clear the insertion point: created instructions will not be inserted into /// a block. @@ -187,12 +188,16 @@ struct VectorizationFactor { /// Cost of the loop with that width. InstructionCost Cost; - VectorizationFactor(ElementCount Width, InstructionCost Cost) - : Width(Width), Cost(Cost) {} + /// Cost of the scalar loop. + InstructionCost ScalarCost; + + VectorizationFactor(ElementCount Width, InstructionCost Cost, + InstructionCost ScalarCost) + : Width(Width), Cost(Cost), ScalarCost(ScalarCost) {} /// Width 1 means no vectorization, cost 0 means uncomputed cost. static VectorizationFactor Disabled() { - return {ElementCount::getFixed(1), 0}; + return {ElementCount::getFixed(1), 0, 0}; } bool operator==(const VectorizationFactor &rhs) const { @@ -298,8 +303,12 @@ public: /// Generate the IR code for the body of the vectorized loop according to the /// best selected \p VF, \p UF and VPlan \p BestPlan. + /// TODO: \p IsEpilogueVectorization is needed to avoid issues due to epilogue + /// vectorization re-using plans for both the main and epilogue vector loops. + /// It should be removed once the re-use issue has been fixed. void executePlan(ElementCount VF, unsigned UF, VPlan &BestPlan, - InnerLoopVectorizer &LB, DominatorTree *DT); + InnerLoopVectorizer &LB, DominatorTree *DT, + bool IsEpilogueVectorization); #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void printPlans(raw_ostream &O); @@ -319,6 +328,9 @@ public: getDecisionAndClampRange(const std::function<bool(ElementCount)> &Predicate, VFRange &Range); + /// Check if the number of runtime checks exceeds the threshold. + bool requiresTooManyRuntimeChecks() const; + protected: /// Collect the instructions from the original loop that would be trivially /// dead in the vectorized loop if generated. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 51d2c6237af1..b637b2d5ddae 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -58,7 +58,6 @@ #include "VPRecipeBuilder.h" #include "VPlan.h" #include "VPlanHCFGBuilder.h" -#include "VPlanPredicator.h" #include "VPlanTransforms.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" @@ -112,7 +111,6 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" @@ -144,10 +142,10 @@ #include <algorithm> #include <cassert> #include <cstdint> -#include <cstdlib> #include <functional> #include <iterator> #include <limits> +#include <map> #include <memory> #include <string> #include <tuple> @@ -346,13 +344,6 @@ cl::opt<bool> EnableVPlanNativePath( cl::desc("Enable VPlan-native vectorization path with " "support for outer loop vectorization.")); -// FIXME: Remove this switch once we have divergence analysis. Currently we -// assume divergent non-backedge branches when this switch is true. -cl::opt<bool> EnableVPlanPredication( - "enable-vplan-predication", cl::init(false), cl::Hidden, - cl::desc("Enable VPlan-native vectorization path predicator with " - "support for outer loop vectorization.")); - // This flag enables the stress testing of the VPlan H-CFG construction in the // VPlan-native vectorization path. It must be used in conjuction with // -enable-vplan-native-path. 
-vplan-verify-hcfg can also be used to enable the @@ -481,7 +472,7 @@ public: VPTransformState &State); /// Fix the vectorized code, taking care of header phi's, live-outs, and more. - void fixVectorizedLoop(VPTransformState &State); + void fixVectorizedLoop(VPTransformState &State, VPlan &Plan); // Return true if any runtime check is added. bool areSafetyChecksAdded() { return AddedSafetyChecks; } @@ -491,12 +482,6 @@ public: /// new unrolled loop, where UF is the unroll factor. using VectorParts = SmallVector<Value *, 2>; - /// Vectorize a single first-order recurrence or pointer induction PHINode in - /// a block. This method handles the induction variable canonicalization. It - /// supports both VF = 1 for unrolled loops and arbitrary length vectors. - void widenPHIInstruction(Instruction *PN, VPWidenPHIRecipe *PhiR, - VPTransformState &State); - /// A helper function to scalarize a single Instruction in the innermost loop. /// Generates a sequence of scalar instances for each lane between \p MinLane /// and \p MaxLane, times each part between \p MinPart and \p MaxPart, @@ -506,13 +491,6 @@ public: const VPIteration &Instance, bool IfPredicateInstr, VPTransformState &State); - /// Widen an integer or floating-point induction variable \p IV. If \p Trunc - /// is provided, the integer induction variable will first be truncated to - /// the corresponding type. \p CanonicalIV is the scalar value generated for - /// the canonical induction variable. - void widenIntOrFpInduction(PHINode *IV, VPWidenIntOrFpInductionRecipe *Def, - VPTransformState &State, Value *CanonicalIV); - /// Construct the vector value of a scalarized value \p V one lane at a time. void packScalarIntoVectorValue(VPValue *Def, const VPIteration &Instance, VPTransformState &State); @@ -527,13 +505,8 @@ public: ArrayRef<VPValue *> StoredValues, VPValue *BlockInMask = nullptr); - /// Set the debug location in the builder \p Ptr using the debug location in - /// \p V. If \p Ptr is None then it uses the class member's Builder. - void setDebugLocFromInst(const Value *V, - Optional<IRBuilder<> *> CustomBuilder = None); - - /// Fix the non-induction PHIs in the OrigPHIsToFix vector. - void fixNonInductionPHIs(VPTransformState &State); + /// Fix the non-induction PHIs in \p Plan. + void fixNonInductionPHIs(VPlan &Plan, VPTransformState &State); /// Returns true if the reordering of FP operations is not allowed, but we are /// able to vectorize with strict in-order reductions for the given RdxDesc. @@ -546,17 +519,6 @@ public: /// element. virtual Value *getBroadcastInstrs(Value *V); - /// Add metadata from one instruction to another. - /// - /// This includes both the original MDs from \p From and additional ones (\see - /// addNewMetadata). Use this for *newly created* instructions in the vector - /// loop. - void addMetadata(Instruction *To, Instruction *From); - - /// Similar to the previous function but it adds the metadata to a - /// vector of instructions. - void addMetadata(ArrayRef<Value *> To, Instruction *From); - // Returns the resume value (bc.merge.rdx) for a reduction as // generated by fixReduction. PHINode *getReductionResumeValue(const RecurrenceDescriptor &RdxDesc); @@ -575,13 +537,9 @@ protected: /// Set up the values of the IVs correctly when exiting the vector loop. 
void fixupIVUsers(PHINode *OrigPhi, const InductionDescriptor &II, - Value *CountRoundDown, Value *EndValue, - BasicBlock *MiddleBlock); - - /// Introduce a conditional branch (on true, condition to be set later) at the - /// end of the header=latch connecting it to itself (across the backedge) and - /// to the exit block of \p L. - void createHeaderBranch(Loop *L); + Value *VectorTripCount, Value *EndValue, + BasicBlock *MiddleBlock, BasicBlock *VectorHeader, + VPlan &Plan); /// Handle all cross-iteration phis in the header. void fixCrossIterationPHIs(VPTransformState &State); @@ -595,16 +553,9 @@ protected: void fixReduction(VPReductionPHIRecipe *Phi, VPTransformState &State); /// Clear NSW/NUW flags from reduction instructions if necessary. - void clearReductionWrapFlags(const RecurrenceDescriptor &RdxDesc, + void clearReductionWrapFlags(VPReductionPHIRecipe *PhiR, VPTransformState &State); - /// Fixup the LCSSA phi nodes in the unique exit block. This simply - /// means we need to add the appropriate incoming value from the middle - /// block as exiting edges from the scalar epilogue loop (if present) are - /// already in place, and we exit the vector loop exclusively to the middle - /// block. - void fixLCSSAPHIs(VPTransformState &State); - /// Iteratively sink the scalarized operands of a predicated instruction into /// the block that was created for it. void sinkScalarOperands(Instruction *PredInst); @@ -613,30 +564,11 @@ protected: /// represented as. void truncateToMinimalBitwidths(VPTransformState &State); - /// Compute scalar induction steps. \p ScalarIV is the scalar induction - /// variable on which to base the steps, \p Step is the size of the step, and - /// \p EntryVal is the value from the original loop that maps to the steps. - /// Note that \p EntryVal doesn't have to be an induction variable - it - /// can also be a truncate instruction. - void buildScalarSteps(Value *ScalarIV, Value *Step, Instruction *EntryVal, - const InductionDescriptor &ID, VPValue *Def, - VPTransformState &State); - - /// Create a vector induction phi node based on an existing scalar one. \p - /// EntryVal is the value from the original loop that maps to the vector phi - /// node, and \p Step is the loop-invariant step. If \p EntryVal is a - /// truncate instruction, instead of widening the original IV, we widen a - /// version of the IV truncated to \p EntryVal's type. - void createVectorIntOrFpInductionPHI(const InductionDescriptor &II, - Value *Step, Value *Start, - Instruction *EntryVal, VPValue *Def, - VPTransformState &State); - /// Returns (and creates if needed) the original loop trip count. - Value *getOrCreateTripCount(Loop *NewLoop); + Value *getOrCreateTripCount(BasicBlock *InsertBlock); /// Returns (and creates if needed) the trip count of the widened loop. - Value *getOrCreateVectorTripCount(Loop *NewLoop); + Value *getOrCreateVectorTripCount(BasicBlock *InsertBlock); /// Returns a bitcasted value to the requested vector type. /// Also handles bitcasts of vector<float> <-> vector<pointer> types. @@ -645,33 +577,21 @@ protected: /// Emit a bypass check to see if the vector trip count is zero, including if /// it overflows. - void emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass); + void emitIterationCountCheck(BasicBlock *Bypass); /// Emit a bypass check to see if all of the SCEV assumptions we've /// had to make are correct. Returns the block containing the checks or /// nullptr if no checks have been added. 
- BasicBlock *emitSCEVChecks(Loop *L, BasicBlock *Bypass); + BasicBlock *emitSCEVChecks(BasicBlock *Bypass); /// Emit bypass checks to check any memory assumptions we may have made. /// Returns the block containing the checks or nullptr if no checks have been /// added. - BasicBlock *emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass); - - /// Compute the transformed value of Index at offset StartValue using step - /// StepValue. - /// For integer induction, returns StartValue + Index * StepValue. - /// For pointer induction, returns StartValue[Index * StepValue]. - /// FIXME: The newly created binary instructions should contain nsw/nuw - /// flags, which can be found from the original scalar operations. - Value *emitTransformedIndex(IRBuilder<> &B, Value *Index, ScalarEvolution *SE, - const DataLayout &DL, - const InductionDescriptor &ID, - BasicBlock *VectorHeader) const; + BasicBlock *emitMemRuntimeChecks(BasicBlock *Bypass); /// Emit basic blocks (prefixed with \p Prefix) for the iteration check, - /// vector loop preheader, middle block and scalar preheader. Also - /// allocate a loop object for the new vector loop and return it. - Loop *createVectorLoopSkeleton(StringRef Prefix); + /// vector loop preheader, middle block and scalar preheader. + void createVectorLoopSkeleton(StringRef Prefix); /// Create new phi nodes for the induction variables to resume iteration count /// in the scalar epilogue, from where the vectorized loop left off. @@ -680,21 +600,12 @@ protected: /// block, the \p AdditionalBypass pair provides information about the bypass /// block and the end value on the edge from bypass to this loop. void createInductionResumeValues( - Loop *L, std::pair<BasicBlock *, Value *> AdditionalBypass = {nullptr, nullptr}); /// Complete the loop skeleton by adding debug MDs, creating appropriate /// conditional branches in the middle block, preparing the builder and - /// running the verifier. Take in the vector loop \p L as argument, and return - /// the preheader of the completed vector loop. - BasicBlock *completeLoopSkeleton(Loop *L, MDNode *OrigLoopID); - - /// Add additional metadata to \p To that was not present on \p Orig. - /// - /// Currently this is used to add the noalias annotations based on the - /// inserted memchecks. Use this for instructions that are *cloned* into the - /// vector loop. - void addNewMetadata(Instruction *To, const Instruction *Orig); + /// running the verifier. Return the preheader of the completed vector loop. + BasicBlock *completeLoopSkeleton(MDNode *OrigLoopID); /// Collect poison-generating recipes that may generate a poison value that is /// used after vectorization, even when their operands are not poison. Those @@ -741,13 +652,6 @@ protected: /// Interface to emit optimization remarks. OptimizationRemarkEmitter *ORE; - /// LoopVersioning. It's only set up (non-null) if memchecks were - /// used. - /// - /// This is currently only used to add no-alias metadata based on the - /// memchecks. The actually versioning is performed manually. - std::unique_ptr<LoopVersioning> LVer; - /// The vectorization SIMD factor to use. Each vector will have this many /// vector elements. ElementCount VF; @@ -774,9 +678,6 @@ protected: /// there can be multiple exiting edges reaching this block. BasicBlock *LoopExitBlock; - /// The vector loop body. - BasicBlock *LoopVectorBody; - /// The scalar loop body. BasicBlock *LoopScalarBody; @@ -805,10 +706,6 @@ protected: // so we can later fix-up the external users of the induction variables. 
DenseMap<PHINode *, Value *> IVEndValues; - // Vector of original scalar PHIs whose corresponding widened PHIs need to be - // fixed up at the end of vector code generation. - SmallVector<PHINode *, 8> OrigPHIsToFix; - /// BFI and PSI are used to check for profile guided size optimizations. BlockFrequencyInfo *BFI; ProfileSummaryInfo *PSI; @@ -936,8 +833,7 @@ protected: /// Emits an iteration count bypass check once for the main loop (when \p /// ForEpilogue is false) and once for the epilogue loop (when \p /// ForEpilogue is true). - BasicBlock *emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass, - bool ForEpilogue); + BasicBlock *emitIterationCountCheck(BasicBlock *Bypass, bool ForEpilogue); void printDebugTracesAtStart() override; void printDebugTracesAtEnd() override; }; @@ -956,7 +852,9 @@ public: BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, GeneratedRTChecks &Checks) : InnerLoopAndEpilogueVectorizer(OrigLoop, PSE, LI, DT, TLI, TTI, AC, ORE, - EPI, LVL, CM, BFI, PSI, Checks) {} + EPI, LVL, CM, BFI, PSI, Checks) { + TripCount = EPI.TripCount; + } /// Implements the interface for creating a vectorized skeleton using the /// *epilogue loop* strategy (ie the second pass of vplan execution). std::pair<BasicBlock *, Value *> @@ -966,7 +864,7 @@ protected: /// Emits an iteration count bypass check after the main vector loop has /// finished to see if there are any iterations left to execute by either /// the vector epilogue or the scalar epilogue. - BasicBlock *emitMinimumVectorEpilogueIterCountCheck(Loop *L, + BasicBlock *emitMinimumVectorEpilogueIterCountCheck( BasicBlock *Bypass, BasicBlock *Insert); void printDebugTracesAtStart() override; @@ -993,31 +891,6 @@ static Instruction *getDebugLocFromInstOrOperands(Instruction *I) { return I; } -void InnerLoopVectorizer::setDebugLocFromInst( - const Value *V, Optional<IRBuilder<> *> CustomBuilder) { - IRBuilder<> *B = (CustomBuilder == None) ? &Builder : *CustomBuilder; - if (const Instruction *Inst = dyn_cast_or_null<Instruction>(V)) { - const DILocation *DIL = Inst->getDebugLoc(); - - // When a FSDiscriminator is enabled, we don't need to add the multiply - // factors to the discriminators. - if (DIL && Inst->getFunction()->isDebugInfoForProfiling() && - !isa<DbgInfoIntrinsic>(Inst) && !EnableFSDiscriminator) { - // FIXME: For scalable vectors, assume vscale=1. - auto NewDIL = - DIL->cloneByMultiplyingDuplicationFactor(UF * VF.getKnownMinValue()); - if (NewDIL) - B->SetCurrentDebugLocation(NewDIL.getValue()); - else - LLVM_DEBUG(dbgs() - << "Failed to create new discriminator: " - << DIL->getFilename() << " Line: " << DIL->getLine()); - } else - B->SetCurrentDebugLocation(DIL); - } else - B->SetCurrentDebugLocation(DebugLoc()); -} - /// Write a \p DebugMsg about vectorization to the debug output stream. If \p I /// is passed, the message relates to that particular instruction. #ifndef NDEBUG @@ -1059,7 +932,7 @@ static OptimizationRemarkAnalysis createLVAnalysis(const char *PassName, namespace llvm { /// Return a value for Step multiplied by VF. -Value *createStepForVF(IRBuilder<> &B, Type *Ty, ElementCount VF, +Value *createStepForVF(IRBuilderBase &B, Type *Ty, ElementCount VF, int64_t Step) { assert(Ty->isIntegerTy() && "Expected an integer step"); Constant *StepVal = ConstantInt::get(Ty, Step * VF.getKnownMinValue()); @@ -1067,12 +940,13 @@ Value *createStepForVF(IRBuilder<> &B, Type *Ty, ElementCount VF, } /// Return the runtime value for VF. 
-Value *getRuntimeVF(IRBuilder<> &B, Type *Ty, ElementCount VF) { +Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF) { Constant *EC = ConstantInt::get(Ty, VF.getKnownMinValue()); return VF.isScalable() ? B.CreateVScale(EC) : EC; } -static Value *getRuntimeVFAsFloat(IRBuilder<> &B, Type *FTy, ElementCount VF) { +static Value *getRuntimeVFAsFloat(IRBuilderBase &B, Type *FTy, + ElementCount VF) { assert(FTy->isFloatingPointTy() && "Expected floating point type!"); Type *IntTy = IntegerType::get(FTy->getContext(), FTy->getScalarSizeInBits()); Value *RuntimeVF = getRuntimeVF(B, IntTy, VF); @@ -1119,14 +993,6 @@ static std::string getDebugLocString(const Loop *L) { } #endif -void InnerLoopVectorizer::addNewMetadata(Instruction *To, - const Instruction *Orig) { - // If the loop was versioned with memchecks, add the corresponding no-alias - // metadata. - if (LVer && (isa<LoadInst>(Orig) || isa<StoreInst>(Orig))) - LVer->annotateInstWithNoAlias(To, Orig); -} - void InnerLoopVectorizer::collectPoisonGeneratingRecipes( VPTransformState &State) { @@ -1151,6 +1017,7 @@ void InnerLoopVectorizer::collectPoisonGeneratingRecipes( // handled. if (isa<VPWidenMemoryInstructionRecipe>(CurRec) || isa<VPInterleaveRecipe>(CurRec) || + isa<VPScalarIVStepsRecipe>(CurRec) || isa<VPCanonicalIVPHIRecipe>(CurRec)) continue; @@ -1176,10 +1043,10 @@ void InnerLoopVectorizer::collectPoisonGeneratingRecipes( for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Iter)) { for (VPRecipeBase &Recipe : *VPBB) { if (auto *WidenRec = dyn_cast<VPWidenMemoryInstructionRecipe>(&Recipe)) { - Instruction *UnderlyingInstr = WidenRec->getUnderlyingInstr(); + Instruction &UnderlyingInstr = WidenRec->getIngredient(); VPDef *AddrDef = WidenRec->getAddr()->getDef(); - if (AddrDef && WidenRec->isConsecutive() && UnderlyingInstr && - Legal->blockNeedsPredication(UnderlyingInstr->getParent())) + if (AddrDef && WidenRec->isConsecutive() && + Legal->blockNeedsPredication(UnderlyingInstr.getParent())) collectPoisonGeneratingInstrsInBackwardSlice( cast<VPRecipeBase>(AddrDef)); } else if (auto *InterleaveRec = dyn_cast<VPInterleaveRecipe>(&Recipe)) { @@ -1206,20 +1073,6 @@ void InnerLoopVectorizer::collectPoisonGeneratingRecipes( } } -void InnerLoopVectorizer::addMetadata(Instruction *To, - Instruction *From) { - propagateMetadata(To, From); - addNewMetadata(To, From); -} - -void InnerLoopVectorizer::addMetadata(ArrayRef<Value *> To, - Instruction *From) { - for (Value *V : To) { - if (Instruction *I = dyn_cast<Instruction>(V)) - addMetadata(I, From); - } -} - PHINode *InnerLoopVectorizer::getReductionResumeValue( const RecurrenceDescriptor &RdxDesc) { auto It = ReductionResumeValues.find(&RdxDesc); @@ -1363,7 +1216,7 @@ public: /// RdxDesc. This is true if the -enable-strict-reductions flag is passed, /// the IsOrdered flag of RdxDesc is set and we do not allow reordering /// of FP operations. - bool useOrderedReductions(const RecurrenceDescriptor &RdxDesc) { + bool useOrderedReductions(const RecurrenceDescriptor &RdxDesc) const { return !Hints->allowReordering() && RdxDesc.isOrdered(); } @@ -1718,15 +1571,10 @@ private: /// \return the maximized element count based on the targets vector /// registers and the loop trip-count, but limited to a maximum safe VF. /// This is a helper function of computeFeasibleMaxVF. 
- /// FIXME: MaxSafeVF is currently passed by reference to avoid some obscure - /// issue that occurred on one of the buildbots which cannot be reproduced - /// without having access to the properietary compiler (see comments on - /// D98509). The issue is currently under investigation and this workaround - /// will be removed as soon as possible. ElementCount getMaximizedVFForTarget(unsigned ConstTripCount, unsigned SmallestType, unsigned WidestType, - const ElementCount &MaxSafeVF, + ElementCount MaxSafeVF, bool FoldTailByMasking); /// \return the maximum legal scalable VF, based on the safe max number @@ -2017,7 +1865,7 @@ public: /// there is no vector code generation, the check blocks are removed /// completely. void Create(Loop *L, const LoopAccessInfo &LAI, - const SCEVUnionPredicate &UnionPred) { + const SCEVPredicate &UnionPred, ElementCount VF, unsigned IC) { BasicBlock *LoopHeader = L->getHeader(); BasicBlock *Preheader = L->getLoopPreheader(); @@ -2040,9 +1888,19 @@ public: MemCheckBlock = SplitBlock(Pred, Pred->getTerminator(), DT, LI, nullptr, "vector.memcheck"); - MemRuntimeCheckCond = - addRuntimeChecks(MemCheckBlock->getTerminator(), L, - RtPtrChecking.getChecks(), MemCheckExp); + auto DiffChecks = RtPtrChecking.getDiffChecks(); + if (DiffChecks) { + MemRuntimeCheckCond = addDiffRuntimeChecks( + MemCheckBlock->getTerminator(), L, *DiffChecks, MemCheckExp, + [VF](IRBuilderBase &B, unsigned Bits) { + return getRuntimeVF(B, B.getIntNTy(Bits), VF); + }, + IC); + } else { + MemRuntimeCheckCond = + addRuntimeChecks(MemCheckBlock->getTerminator(), L, + RtPtrChecking.getChecks(), MemCheckExp); + } assert(MemRuntimeCheckCond && "no RT checks generated although RtPtrChecking " "claimed checks are required"); @@ -2114,12 +1972,16 @@ public: /// Adds the generated SCEVCheckBlock before \p LoopVectorPreHeader and /// adjusts the branches to branch to the vector preheader or \p Bypass, /// depending on the generated condition. - BasicBlock *emitSCEVChecks(Loop *L, BasicBlock *Bypass, + BasicBlock *emitSCEVChecks(BasicBlock *Bypass, BasicBlock *LoopVectorPreHeader, BasicBlock *LoopExitBlock) { if (!SCEVCheckCond) return nullptr; - if (auto *C = dyn_cast<ConstantInt>(SCEVCheckCond)) + + Value *Cond = SCEVCheckCond; + // Mark the check as used, to prevent it from being removed during cleanup. + SCEVCheckCond = nullptr; + if (auto *C = dyn_cast<ConstantInt>(Cond)) if (C->isZero()) return nullptr; @@ -2138,18 +2000,15 @@ public: DT->addNewBlock(SCEVCheckBlock, Pred); DT->changeImmediateDominator(LoopVectorPreHeader, SCEVCheckBlock); - ReplaceInstWithInst( - SCEVCheckBlock->getTerminator(), - BranchInst::Create(Bypass, LoopVectorPreHeader, SCEVCheckCond)); - // Mark the check as used, to prevent it from being removed during cleanup. - SCEVCheckCond = nullptr; + ReplaceInstWithInst(SCEVCheckBlock->getTerminator(), + BranchInst::Create(Bypass, LoopVectorPreHeader, Cond)); return SCEVCheckBlock; } /// Adds the generated MemCheckBlock before \p LoopVectorPreHeader and adjusts /// the branches to branch to the vector preheader or \p Bypass, depending on /// the generated condition. - BasicBlock *emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass, + BasicBlock *emitMemRuntimeChecks(BasicBlock *Bypass, BasicBlock *LoopVectorPreHeader) { // Check if we generated code that checks in runtime if arrays overlap. if (!MemRuntimeCheckCond) @@ -2346,7 +2205,7 @@ Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) { /// \p Opcode is relevant for FP induction variable. 
static Value *getStepVector(Value *Val, Value *StartIdx, Value *Step, Instruction::BinaryOps BinOp, ElementCount VF, - IRBuilder<> &Builder) { + IRBuilderBase &Builder) { assert(VF.isVector() && "only vector VFs are supported"); // Create and check the types. @@ -2362,9 +2221,8 @@ static Value *getStepVector(Value *Val, Value *StartIdx, Value *Step, // Create a vector of consecutive numbers from zero to VF. VectorType *InitVecValVTy = ValVTy; - Type *InitVecValSTy = STy; if (STy->isFloatingPointTy()) { - InitVecValSTy = + Type *InitVecValSTy = IntegerType::get(STy->getContext(), STy->getScalarSizeInBits()); InitVecValVTy = VectorType::get(InitVecValSTy, VLen); } @@ -2394,198 +2252,12 @@ static Value *getStepVector(Value *Val, Value *StartIdx, Value *Step, return Builder.CreateBinOp(BinOp, Val, MulOp, "induction"); } -void InnerLoopVectorizer::createVectorIntOrFpInductionPHI( - const InductionDescriptor &II, Value *Step, Value *Start, - Instruction *EntryVal, VPValue *Def, VPTransformState &State) { - IRBuilder<> &Builder = State.Builder; - assert((isa<PHINode>(EntryVal) || isa<TruncInst>(EntryVal)) && - "Expected either an induction phi-node or a truncate of it!"); - - // Construct the initial value of the vector IV in the vector loop preheader - auto CurrIP = Builder.saveIP(); - Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); - if (isa<TruncInst>(EntryVal)) { - assert(Start->getType()->isIntegerTy() && - "Truncation requires an integer type"); - auto *TruncType = cast<IntegerType>(EntryVal->getType()); - Step = Builder.CreateTrunc(Step, TruncType); - Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType); - } - - Value *Zero = getSignedIntOrFpConstant(Start->getType(), 0); - Value *SplatStart = Builder.CreateVectorSplat(State.VF, Start); - Value *SteppedStart = getStepVector( - SplatStart, Zero, Step, II.getInductionOpcode(), State.VF, State.Builder); - - // We create vector phi nodes for both integer and floating-point induction - // variables. Here, we determine the kind of arithmetic we will perform. - Instruction::BinaryOps AddOp; - Instruction::BinaryOps MulOp; - if (Step->getType()->isIntegerTy()) { - AddOp = Instruction::Add; - MulOp = Instruction::Mul; - } else { - AddOp = II.getInductionOpcode(); - MulOp = Instruction::FMul; - } - - // Multiply the vectorization factor by the step using integer or - // floating-point arithmetic as appropriate. - Type *StepType = Step->getType(); - Value *RuntimeVF; - if (Step->getType()->isFloatingPointTy()) - RuntimeVF = getRuntimeVFAsFloat(Builder, StepType, State.VF); - else - RuntimeVF = getRuntimeVF(Builder, StepType, State.VF); - Value *Mul = Builder.CreateBinOp(MulOp, Step, RuntimeVF); - - // Create a vector splat to use in the induction update. - // - // FIXME: If the step is non-constant, we create the vector splat with - // IRBuilder. IRBuilder can constant-fold the multiply, but it doesn't - // handle a constant vector splat. - Value *SplatVF = isa<Constant>(Mul) - ? ConstantVector::getSplat(State.VF, cast<Constant>(Mul)) - : Builder.CreateVectorSplat(State.VF, Mul); - Builder.restoreIP(CurrIP); - - // We may need to add the step a number of times, depending on the unroll - // factor. The last of those goes into the PHI. 
- PHINode *VecInd = PHINode::Create(SteppedStart->getType(), 2, "vec.ind", - &*LoopVectorBody->getFirstInsertionPt()); - VecInd->setDebugLoc(EntryVal->getDebugLoc()); - Instruction *LastInduction = VecInd; - for (unsigned Part = 0; Part < UF; ++Part) { - State.set(Def, LastInduction, Part); - - if (isa<TruncInst>(EntryVal)) - addMetadata(LastInduction, EntryVal); - - LastInduction = cast<Instruction>( - Builder.CreateBinOp(AddOp, LastInduction, SplatVF, "step.add")); - LastInduction->setDebugLoc(EntryVal->getDebugLoc()); - } - - // Move the last step to the end of the latch block. This ensures consistent - // placement of all induction updates. - auto *LoopVectorLatch = LI->getLoopFor(LoopVectorBody)->getLoopLatch(); - auto *Br = cast<BranchInst>(LoopVectorLatch->getTerminator()); - LastInduction->moveBefore(Br); - LastInduction->setName("vec.ind.next"); - - VecInd->addIncoming(SteppedStart, LoopVectorPreHeader); - VecInd->addIncoming(LastInduction, LoopVectorLatch); -} - -void InnerLoopVectorizer::widenIntOrFpInduction( - PHINode *IV, VPWidenIntOrFpInductionRecipe *Def, VPTransformState &State, - Value *CanonicalIV) { - Value *Start = Def->getStartValue()->getLiveInIRValue(); - const InductionDescriptor &ID = Def->getInductionDescriptor(); - TruncInst *Trunc = Def->getTruncInst(); - IRBuilder<> &Builder = State.Builder; - assert(IV->getType() == ID.getStartValue()->getType() && "Types must match"); - assert(!State.VF.isZero() && "VF must be non-zero"); - - // The value from the original loop to which we are mapping the new induction - // variable. - Instruction *EntryVal = Trunc ? cast<Instruction>(Trunc) : IV; - - auto &DL = EntryVal->getModule()->getDataLayout(); - - // Generate code for the induction step. Note that induction steps are - // required to be loop-invariant - auto CreateStepValue = [&](const SCEV *Step) -> Value * { - assert(PSE.getSE()->isLoopInvariant(Step, OrigLoop) && - "Induction step should be loop invariant"); - if (PSE.getSE()->isSCEVable(IV->getType())) { - SCEVExpander Exp(*PSE.getSE(), DL, "induction"); - return Exp.expandCodeFor(Step, Step->getType(), - State.CFG.VectorPreHeader->getTerminator()); - } - return cast<SCEVUnknown>(Step)->getValue(); - }; - - // The scalar value to broadcast. This is derived from the canonical - // induction variable. If a truncation type is given, truncate the canonical - // induction variable and step. Otherwise, derive these values from the - // induction descriptor. - auto CreateScalarIV = [&](Value *&Step) -> Value * { - Value *ScalarIV = CanonicalIV; - Type *NeededType = IV->getType(); - if (!Def->isCanonical() || ScalarIV->getType() != NeededType) { - ScalarIV = - NeededType->isIntegerTy() - ? Builder.CreateSExtOrTrunc(ScalarIV, NeededType) - : Builder.CreateCast(Instruction::SIToFP, ScalarIV, NeededType); - ScalarIV = emitTransformedIndex(Builder, ScalarIV, PSE.getSE(), DL, ID, - State.CFG.PrevBB); - ScalarIV->setName("offset.idx"); - } - if (Trunc) { - auto *TruncType = cast<IntegerType>(Trunc->getType()); - assert(Step->getType()->isIntegerTy() && - "Truncation requires an integer step"); - ScalarIV = Builder.CreateTrunc(ScalarIV, TruncType); - Step = Builder.CreateTrunc(Step, TruncType); - } - return ScalarIV; - }; - - // Fast-math-flags propagate from the original induction instruction. 
- IRBuilder<>::FastMathFlagGuard FMFG(Builder); - if (ID.getInductionBinOp() && isa<FPMathOperator>(ID.getInductionBinOp())) - Builder.setFastMathFlags(ID.getInductionBinOp()->getFastMathFlags()); - - // Now do the actual transformations, and start with creating the step value. - Value *Step = CreateStepValue(ID.getStep()); - if (State.VF.isScalar()) { - Value *ScalarIV = CreateScalarIV(Step); - Type *ScalarTy = IntegerType::get(ScalarIV->getContext(), - Step->getType()->getScalarSizeInBits()); - - for (unsigned Part = 0; Part < UF; ++Part) { - Value *StartIdx = ConstantInt::get(ScalarTy, Part); - Value *EntryPart; - if (Step->getType()->isFloatingPointTy()) { - StartIdx = Builder.CreateUIToFP(StartIdx, Step->getType()); - Value *MulOp = Builder.CreateFMul(StartIdx, Step); - EntryPart = Builder.CreateBinOp(ID.getInductionOpcode(), ScalarIV, - MulOp, "induction"); - } else { - EntryPart = Builder.CreateAdd( - ScalarIV, Builder.CreateMul(StartIdx, Step), "induction"); - } - State.set(Def, EntryPart, Part); - if (Trunc) { - assert(!Step->getType()->isFloatingPointTy() && - "fp inductions shouldn't be truncated"); - addMetadata(EntryPart, Trunc); - } - } - return; - } - - // Create a new independent vector induction variable, if one is needed. - if (Def->needsVectorIV()) - createVectorIntOrFpInductionPHI(ID, Step, Start, EntryVal, Def, State); - - if (Def->needsScalarIV()) { - // Create scalar steps that can be used by instructions we will later - // scalarize. Note that the addition of the scalar steps will not increase - // the number of instructions in the loop in the common case prior to - // InstCombine. We will be trading one vector extract for each scalar step. - Value *ScalarIV = CreateScalarIV(Step); - buildScalarSteps(ScalarIV, Step, EntryVal, ID, Def, State); - } -} - -void InnerLoopVectorizer::buildScalarSteps(Value *ScalarIV, Value *Step, - Instruction *EntryVal, - const InductionDescriptor &ID, - VPValue *Def, - VPTransformState &State) { - IRBuilder<> &Builder = State.Builder; +/// Compute scalar induction steps. \p ScalarIV is the scalar induction +/// variable on which to base the steps, \p Step is the size of the step. +static void buildScalarSteps(Value *ScalarIV, Value *Step, + const InductionDescriptor &ID, VPValue *Def, + VPTransformState &State) { + IRBuilderBase &Builder = State.Builder; // We shouldn't have to build scalar steps if we aren't vectorizing. assert(State.VF.isVector() && "VF should be greater than one"); // Get the value type and ensure it and the step have the same integer type. @@ -2656,6 +2328,103 @@ void InnerLoopVectorizer::buildScalarSteps(Value *ScalarIV, Value *Step, } } +// Generate code for the induction step. Note that induction steps are +// required to be loop-invariant +static Value *CreateStepValue(const SCEV *Step, ScalarEvolution &SE, + Instruction *InsertBefore, + Loop *OrigLoop = nullptr) { + const DataLayout &DL = SE.getDataLayout(); + assert((!OrigLoop || SE.isLoopInvariant(Step, OrigLoop)) && + "Induction step should be loop invariant"); + if (auto *E = dyn_cast<SCEVUnknown>(Step)) + return E->getValue(); + + SCEVExpander Exp(SE, DL, "induction"); + return Exp.expandCodeFor(Step, Step->getType(), InsertBefore); +} + +/// Compute the transformed value of Index at offset StartValue using step +/// StepValue. +/// For integer induction, returns StartValue + Index * StepValue. +/// For pointer induction, returns StartValue[Index * StepValue]. 
+/// FIXME: The newly created binary instructions should contain nsw/nuw +/// flags, which can be found from the original scalar operations. +static Value *emitTransformedIndex(IRBuilderBase &B, Value *Index, + Value *StartValue, Value *Step, + const InductionDescriptor &ID) { + assert(Index->getType()->getScalarType() == Step->getType() && + "Index scalar type does not match StepValue type"); + + // Note: the IR at this point is broken. We cannot use SE to create any new + // SCEV and then expand it, hoping that SCEV's simplification will give us + // a more optimal code. Unfortunately, attempt of doing so on invalid IR may + // lead to various SCEV crashes. So all we can do is to use builder and rely + // on InstCombine for future simplifications. Here we handle some trivial + // cases only. + auto CreateAdd = [&B](Value *X, Value *Y) { + assert(X->getType() == Y->getType() && "Types don't match!"); + if (auto *CX = dyn_cast<ConstantInt>(X)) + if (CX->isZero()) + return Y; + if (auto *CY = dyn_cast<ConstantInt>(Y)) + if (CY->isZero()) + return X; + return B.CreateAdd(X, Y); + }; + + // We allow X to be a vector type, in which case Y will potentially be + // splatted into a vector with the same element count. + auto CreateMul = [&B](Value *X, Value *Y) { + assert(X->getType()->getScalarType() == Y->getType() && + "Types don't match!"); + if (auto *CX = dyn_cast<ConstantInt>(X)) + if (CX->isOne()) + return Y; + if (auto *CY = dyn_cast<ConstantInt>(Y)) + if (CY->isOne()) + return X; + VectorType *XVTy = dyn_cast<VectorType>(X->getType()); + if (XVTy && !isa<VectorType>(Y->getType())) + Y = B.CreateVectorSplat(XVTy->getElementCount(), Y); + return B.CreateMul(X, Y); + }; + + switch (ID.getKind()) { + case InductionDescriptor::IK_IntInduction: { + assert(!isa<VectorType>(Index->getType()) && + "Vector indices not supported for integer inductions yet"); + assert(Index->getType() == StartValue->getType() && + "Index type does not match StartValue type"); + if (isa<ConstantInt>(Step) && cast<ConstantInt>(Step)->isMinusOne()) + return B.CreateSub(StartValue, Index); + auto *Offset = CreateMul(Index, Step); + return CreateAdd(StartValue, Offset); + } + case InductionDescriptor::IK_PtrInduction: { + assert(isa<Constant>(Step) && + "Expected constant step for pointer induction"); + return B.CreateGEP(ID.getElementType(), StartValue, CreateMul(Index, Step)); + } + case InductionDescriptor::IK_FpInduction: { + assert(!isa<VectorType>(Index->getType()) && + "Vector indices not supported for FP inductions yet"); + assert(Step->getType()->isFloatingPointTy() && "Expected FP Step value"); + auto InductionBinOp = ID.getInductionBinOp(); + assert(InductionBinOp && + (InductionBinOp->getOpcode() == Instruction::FAdd || + InductionBinOp->getOpcode() == Instruction::FSub) && + "Original bin op should be defined for FP induction"); + + Value *MulExp = B.CreateFMul(Step, Index); + return B.CreateBinOp(InductionBinOp->getOpcode(), StartValue, MulExp, + "induction"); + } + case InductionDescriptor::IK_NoInduction: + return nullptr; + } + llvm_unreachable("invalid enum"); +} + void InnerLoopVectorizer::packScalarIntoVectorValue(VPValue *Def, const VPIteration &Instance, VPTransformState &State) { @@ -2738,7 +2507,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( for (unsigned Part = 0; Part < UF; Part++) { Value *AddrPart = State.get(Addr, VPIteration(Part, 0)); - setDebugLocFromInst(AddrPart); + State.setDebugLocFromInst(AddrPart); // Notice current instruction could be any index. 
Need to adjust the address // to the member of index 0. @@ -2764,7 +2533,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup( AddrParts.push_back(Builder.CreateBitCast(AddrPart, PtrTy)); } - setDebugLocFromInst(Instr); + State.setDebugLocFromInst(Instr); Value *PoisonVec = PoisonValue::get(VecTy); Value *MaskForGaps = nullptr; @@ -2919,8 +2688,6 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, if (!Instance.isFirstIteration()) return; - setDebugLocFromInst(Instr); - // Does this instruction return a value ? bool IsVoidRetTy = Instr->getType()->isVoidTy(); @@ -2937,21 +2704,23 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, if (State.MayGeneratePoisonRecipes.contains(RepRecipe)) Cloned->dropPoisonGeneratingFlags(); - State.Builder.SetInsertPoint(Builder.GetInsertBlock(), - Builder.GetInsertPoint()); + if (Instr->getDebugLoc()) + State.setDebugLocFromInst(Instr); + // Replace the operands of the cloned instructions with their scalar // equivalents in the new loop. for (auto &I : enumerate(RepRecipe->operands())) { auto InputInstance = Instance; VPValue *Operand = I.value(); - if (State.Plan->isUniformAfterVectorization(Operand)) + VPReplicateRecipe *OperandR = dyn_cast<VPReplicateRecipe>(Operand); + if (OperandR && OperandR->isUniform()) InputInstance.Lane = VPLane::getFirstLane(); Cloned->setOperand(I.index(), State.get(Operand, InputInstance)); } - addNewMetadata(Cloned, Instr); + State.addNewMetadata(Cloned, Instr); // Place the cloned scalar in the new loop. - Builder.Insert(Cloned); + State.Builder.Insert(Cloned); State.set(RepRecipe, Cloned, Instance); @@ -2964,29 +2733,12 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, PredicatedInstructions.push_back(Cloned); } -void InnerLoopVectorizer::createHeaderBranch(Loop *L) { - BasicBlock *Header = L->getHeader(); - assert(!L->getLoopLatch() && "loop should not have a latch at this point"); - - IRBuilder<> B(Header->getTerminator()); - Instruction *OldInst = - getDebugLocFromInstOrOperands(Legal->getPrimaryInduction()); - setDebugLocFromInst(OldInst, &B); - - // Connect the header to the exit and header blocks and replace the old - // terminator. - B.CreateCondBr(B.getTrue(), L->getUniqueExitBlock(), Header); - - // Now we have two terminators. Remove the old one from the block. - Header->getTerminator()->eraseFromParent(); -} - -Value *InnerLoopVectorizer::getOrCreateTripCount(Loop *L) { +Value *InnerLoopVectorizer::getOrCreateTripCount(BasicBlock *InsertBlock) { if (TripCount) return TripCount; - assert(L && "Create Trip Count for null loop."); - IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); + assert(InsertBlock); + IRBuilder<> Builder(InsertBlock->getTerminator()); // Find the loop boundaries. ScalarEvolution *SE = PSE.getSE(); const SCEV *BackedgeTakenCount = PSE.getBackedgeTakenCount(); @@ -3010,7 +2762,7 @@ Value *InnerLoopVectorizer::getOrCreateTripCount(Loop *L) { const SCEV *ExitCount = SE->getAddExpr( BackedgeTakenCount, SE->getOne(BackedgeTakenCount->getType())); - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + const DataLayout &DL = InsertBlock->getModule()->getDataLayout(); // Expand the trip count and place the new instructions in the preheader. // Notice that the pre-header does not change, only the loop body. @@ -3018,22 +2770,23 @@ Value *InnerLoopVectorizer::getOrCreateTripCount(Loop *L) { // Count holds the overall loop count (N). 
TripCount = Exp.expandCodeFor(ExitCount, ExitCount->getType(), - L->getLoopPreheader()->getTerminator()); + InsertBlock->getTerminator()); if (TripCount->getType()->isPointerTy()) TripCount = CastInst::CreatePointerCast(TripCount, IdxTy, "exitcount.ptrcnt.to.int", - L->getLoopPreheader()->getTerminator()); + InsertBlock->getTerminator()); return TripCount; } -Value *InnerLoopVectorizer::getOrCreateVectorTripCount(Loop *L) { +Value * +InnerLoopVectorizer::getOrCreateVectorTripCount(BasicBlock *InsertBlock) { if (VectorTripCount) return VectorTripCount; - Value *TC = getOrCreateTripCount(L); - IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); + Value *TC = getOrCreateTripCount(InsertBlock); + IRBuilder<> Builder(InsertBlock->getTerminator()); Type *Ty = TC->getType(); // This is where we can make the step a runtime constant. @@ -3045,6 +2798,8 @@ Value *InnerLoopVectorizer::getOrCreateVectorTripCount(Loop *L) { // overflows: the vector induction variable will eventually wrap to zero given // that it starts at zero and its Step is a power of two; the loop will then // exit, with the last early-exit vector comparison also producing all-true. + // For scalable vectors the VF is not guaranteed to be a power of 2, but this + // is accounted for in emitIterationCountCheck that adds an overflow check. if (Cost->foldTailByMasking()) { assert(isPowerOf2_32(VF.getKnownMinValue() * UF) && "VF*UF must be a power of 2 when folding tail by masking"); @@ -3107,9 +2862,8 @@ Value *InnerLoopVectorizer::createBitOrPointerCast(Value *V, VectorType *DstVTy, return Builder.CreateBitOrPointerCast(CastVal, DstFVTy); } -void InnerLoopVectorizer::emitMinimumIterationCountCheck(Loop *L, - BasicBlock *Bypass) { - Value *Count = getOrCreateTripCount(L); +void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) { + Value *Count = getOrCreateTripCount(LoopVectorPreHeader); // Reuse existing vector loop preheader for TC checks. // Note that new preheader block is generated for vector loop. BasicBlock *const TCCheckBlock = LoopVectorPreHeader; @@ -3124,10 +2878,23 @@ void InnerLoopVectorizer::emitMinimumIterationCountCheck(Loop *L, : ICmpInst::ICMP_ULT; // If tail is to be folded, vector loop takes care of all iterations. + Type *CountTy = Count->getType(); Value *CheckMinIters = Builder.getFalse(); - if (!Cost->foldTailByMasking()) { - Value *Step = createStepForVF(Builder, Count->getType(), VF, UF); + Value *Step = createStepForVF(Builder, CountTy, VF, UF); + if (!Cost->foldTailByMasking()) CheckMinIters = Builder.CreateICmp(P, Count, Step, "min.iters.check"); + else if (VF.isScalable()) { + // vscale is not necessarily a power-of-2, which means we cannot guarantee + // an overflow to zero when updating induction variables and so an + // additional overflow check is required before entering the vector loop. + + // Get the maximum unsigned value for the type. + Value *MaxUIntTripCount = + ConstantInt::get(CountTy, cast<IntegerType>(CountTy)->getMask()); + Value *LHS = Builder.CreateSub(MaxUIntTripCount, Count); + + // Don't execute the vector loop if (UMax - n) < (VF * UF). + CheckMinIters = Builder.CreateICmp(ICmpInst::ICMP_ULT, LHS, Step); } // Create new preheader for vector loop. 
LoopVectorPreHeader = @@ -3152,10 +2919,10 @@ void InnerLoopVectorizer::emitMinimumIterationCountCheck(Loop *L, LoopBypassBlocks.push_back(TCCheckBlock); } -BasicBlock *InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) { +BasicBlock *InnerLoopVectorizer::emitSCEVChecks(BasicBlock *Bypass) { BasicBlock *const SCEVCheckBlock = - RTChecks.emitSCEVChecks(L, Bypass, LoopVectorPreHeader, LoopExitBlock); + RTChecks.emitSCEVChecks(Bypass, LoopVectorPreHeader, LoopExitBlock); if (!SCEVCheckBlock) return nullptr; @@ -3180,14 +2947,13 @@ BasicBlock *InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) { return SCEVCheckBlock; } -BasicBlock *InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, - BasicBlock *Bypass) { +BasicBlock *InnerLoopVectorizer::emitMemRuntimeChecks(BasicBlock *Bypass) { // VPlan-native path does not do any analysis for runtime checks currently. if (EnableVPlanNativePath) return nullptr; BasicBlock *const MemCheckBlock = - RTChecks.emitMemRuntimeChecks(L, Bypass, LoopVectorPreHeader); + RTChecks.emitMemRuntimeChecks(Bypass, LoopVectorPreHeader); // Check if we generated code that checks in runtime if arrays overlap. We put // the checks into a separate block to make the more common case of few @@ -3201,7 +2967,8 @@ BasicBlock *InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, "to vectorize."); ORE->emit([&]() { return OptimizationRemarkAnalysis(DEBUG_TYPE, "VectorizationCodeSize", - L->getStartLoc(), L->getHeader()) + OrigLoop->getStartLoc(), + OrigLoop->getHeader()) << "Code-size may be reduced by not forcing " "vectorization, or by source-code modifications " "eliminating the need for runtime checks " @@ -3213,116 +2980,10 @@ BasicBlock *InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, AddedSafetyChecks = true; - // We currently don't use LoopVersioning for the actual loop cloning but we - // still use it to add the noalias metadata. - LVer = std::make_unique<LoopVersioning>( - *Legal->getLAI(), - Legal->getLAI()->getRuntimePointerChecking()->getChecks(), OrigLoop, LI, - DT, PSE.getSE()); - LVer->prepareNoAliasMetadata(); return MemCheckBlock; } -Value *InnerLoopVectorizer::emitTransformedIndex( - IRBuilder<> &B, Value *Index, ScalarEvolution *SE, const DataLayout &DL, - const InductionDescriptor &ID, BasicBlock *VectorHeader) const { - - SCEVExpander Exp(*SE, DL, "induction"); - auto Step = ID.getStep(); - auto StartValue = ID.getStartValue(); - assert(Index->getType()->getScalarType() == Step->getType() && - "Index scalar type does not match StepValue type"); - - // Note: the IR at this point is broken. We cannot use SE to create any new - // SCEV and then expand it, hoping that SCEV's simplification will give us - // a more optimal code. Unfortunately, attempt of doing so on invalid IR may - // lead to various SCEV crashes. So all we can do is to use builder and rely - // on InstCombine for future simplifications. Here we handle some trivial - // cases only. - auto CreateAdd = [&B](Value *X, Value *Y) { - assert(X->getType() == Y->getType() && "Types don't match!"); - if (auto *CX = dyn_cast<ConstantInt>(X)) - if (CX->isZero()) - return Y; - if (auto *CY = dyn_cast<ConstantInt>(Y)) - if (CY->isZero()) - return X; - return B.CreateAdd(X, Y); - }; - - // We allow X to be a vector type, in which case Y will potentially be - // splatted into a vector with the same element count. 
- auto CreateMul = [&B](Value *X, Value *Y) { - assert(X->getType()->getScalarType() == Y->getType() && - "Types don't match!"); - if (auto *CX = dyn_cast<ConstantInt>(X)) - if (CX->isOne()) - return Y; - if (auto *CY = dyn_cast<ConstantInt>(Y)) - if (CY->isOne()) - return X; - VectorType *XVTy = dyn_cast<VectorType>(X->getType()); - if (XVTy && !isa<VectorType>(Y->getType())) - Y = B.CreateVectorSplat(XVTy->getElementCount(), Y); - return B.CreateMul(X, Y); - }; - - // Get a suitable insert point for SCEV expansion. For blocks in the vector - // loop, choose the end of the vector loop header (=VectorHeader), because - // the DomTree is not kept up-to-date for additional blocks generated in the - // vector loop. By using the header as insertion point, we guarantee that the - // expanded instructions dominate all their uses. - auto GetInsertPoint = [this, &B, VectorHeader]() { - BasicBlock *InsertBB = B.GetInsertPoint()->getParent(); - if (InsertBB != LoopVectorBody && - LI->getLoopFor(VectorHeader) == LI->getLoopFor(InsertBB)) - return VectorHeader->getTerminator(); - return &*B.GetInsertPoint(); - }; - - switch (ID.getKind()) { - case InductionDescriptor::IK_IntInduction: { - assert(!isa<VectorType>(Index->getType()) && - "Vector indices not supported for integer inductions yet"); - assert(Index->getType() == StartValue->getType() && - "Index type does not match StartValue type"); - if (ID.getConstIntStepValue() && ID.getConstIntStepValue()->isMinusOne()) - return B.CreateSub(StartValue, Index); - auto *Offset = CreateMul( - Index, Exp.expandCodeFor(Step, Index->getType(), GetInsertPoint())); - return CreateAdd(StartValue, Offset); - } - case InductionDescriptor::IK_PtrInduction: { - assert(isa<SCEVConstant>(Step) && - "Expected constant step for pointer induction"); - return B.CreateGEP( - ID.getElementType(), StartValue, - CreateMul(Index, - Exp.expandCodeFor(Step, Index->getType()->getScalarType(), - GetInsertPoint()))); - } - case InductionDescriptor::IK_FpInduction: { - assert(!isa<VectorType>(Index->getType()) && - "Vector indices not supported for FP inductions yet"); - assert(Step->getType()->isFloatingPointTy() && "Expected FP Step value"); - auto InductionBinOp = ID.getInductionBinOp(); - assert(InductionBinOp && - (InductionBinOp->getOpcode() == Instruction::FAdd || - InductionBinOp->getOpcode() == Instruction::FSub) && - "Original bin op should be defined for FP induction"); - - Value *StepValue = cast<SCEVUnknown>(Step)->getValue(); - Value *MulExp = B.CreateFMul(StepValue, Index); - return B.CreateBinOp(InductionBinOp->getOpcode(), StartValue, MulExp, - "induction"); - } - case InductionDescriptor::IK_NoInduction: - return nullptr; - } - llvm_unreachable("invalid enum"); -} - -Loop *InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) { +void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) { LoopScalarBody = OrigLoop->getHeader(); LoopVectorPreHeader = OrigLoop->getLoopPreheader(); assert(LoopVectorPreHeader && "Invalid loop structure"); @@ -3354,43 +3015,24 @@ Loop *InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) { BrInst->setDebugLoc(ScalarLatchTerm->getDebugLoc()); ReplaceInstWithInst(LoopMiddleBlock->getTerminator(), BrInst); - // We intentionally don't let SplitBlock to update LoopInfo since - // LoopVectorBody should belong to another loop than LoopVectorPreHeader. - // LoopVectorBody is explicitly added to the correct place few lines later. 
- LoopVectorBody = - SplitBlock(LoopVectorPreHeader, LoopVectorPreHeader->getTerminator(), DT, - nullptr, nullptr, Twine(Prefix) + "vector.body"); - - // Update dominator for loop exit. + // Update dominator for loop exit. During skeleton creation, only the vector + // pre-header and the middle block are created. The vector loop is entirely + // created during VPlan exection. if (!Cost->requiresScalarEpilogue(VF)) // If there is an epilogue which must run, there's no edge from the // middle block to exit blocks and thus no need to update the immediate // dominator of the exit blocks. DT->changeImmediateDominator(LoopExitBlock, LoopMiddleBlock); - - // Create and register the new vector loop. - Loop *Lp = LI->AllocateLoop(); - Loop *ParentLoop = OrigLoop->getParentLoop(); - - // Insert the new loop into the loop nest and register the new basic blocks - // before calling any utilities such as SCEV that require valid LoopInfo. - if (ParentLoop) { - ParentLoop->addChildLoop(Lp); - } else { - LI->addTopLevelLoop(Lp); - } - Lp->addBasicBlockToLoop(LoopVectorBody, *LI); - return Lp; } void InnerLoopVectorizer::createInductionResumeValues( - Loop *L, std::pair<BasicBlock *, Value *> AdditionalBypass) { + std::pair<BasicBlock *, Value *> AdditionalBypass) { assert(((AdditionalBypass.first && AdditionalBypass.second) || (!AdditionalBypass.first && !AdditionalBypass.second)) && "Inconsistent information about additional bypass."); - Value *VectorTripCount = getOrCreateVectorTripCount(L); - assert(VectorTripCount && L && "Expected valid arguments"); + Value *VectorTripCount = getOrCreateVectorTripCount(LoopVectorPreHeader); + assert(VectorTripCount && "Expected valid arguments"); // We are going to resume the execution of the scalar loop. // Go over all of the induction variables that we found and fix the // PHIs that are left in the scalar version of the loop. @@ -3403,19 +3045,13 @@ void InnerLoopVectorizer::createInductionResumeValues( PHINode *OrigPhi = InductionEntry.first; InductionDescriptor II = InductionEntry.second; - // Create phi nodes to merge from the backedge-taken check block. - PHINode *BCResumeVal = - PHINode::Create(OrigPhi->getType(), 3, "bc.resume.val", - LoopScalarPreHeader->getTerminator()); - // Copy original phi DL over to the new one. - BCResumeVal->setDebugLoc(OrigPhi->getDebugLoc()); Value *&EndValue = IVEndValues[OrigPhi]; Value *EndValueFromAdditionalBypass = AdditionalBypass.second; if (OrigPhi == OldInduction) { // We know what the end value is. EndValue = VectorTripCount; } else { - IRBuilder<> B(L->getLoopPreheader()->getTerminator()); + IRBuilder<> B(LoopVectorPreHeader->getTerminator()); // Fast-math-flags propagate from the original induction instruction. 
if (II.getInductionBinOp() && isa<FPMathOperator>(II.getInductionBinOp())) @@ -3424,10 +3060,10 @@ void InnerLoopVectorizer::createInductionResumeValues( Type *StepType = II.getStep()->getType(); Instruction::CastOps CastOp = CastInst::getCastOpcode(VectorTripCount, true, StepType, true); - Value *CRD = B.CreateCast(CastOp, VectorTripCount, StepType, "cast.crd"); - const DataLayout &DL = LoopScalarBody->getModule()->getDataLayout(); - EndValue = - emitTransformedIndex(B, CRD, PSE.getSE(), DL, II, LoopVectorBody); + Value *VTC = B.CreateCast(CastOp, VectorTripCount, StepType, "cast.vtc"); + Value *Step = + CreateStepValue(II.getStep(), *PSE.getSE(), &*B.GetInsertPoint()); + EndValue = emitTransformedIndex(B, VTC, II.getStartValue(), Step, II); EndValue->setName("ind.end"); // Compute the end value for the additional bypass (if applicable). @@ -3435,13 +3071,23 @@ void InnerLoopVectorizer::createInductionResumeValues( B.SetInsertPoint(&(*AdditionalBypass.first->getFirstInsertionPt())); CastOp = CastInst::getCastOpcode(AdditionalBypass.second, true, StepType, true); - CRD = - B.CreateCast(CastOp, AdditionalBypass.second, StepType, "cast.crd"); + Value *Step = + CreateStepValue(II.getStep(), *PSE.getSE(), &*B.GetInsertPoint()); + VTC = + B.CreateCast(CastOp, AdditionalBypass.second, StepType, "cast.vtc"); EndValueFromAdditionalBypass = - emitTransformedIndex(B, CRD, PSE.getSE(), DL, II, LoopVectorBody); + emitTransformedIndex(B, VTC, II.getStartValue(), Step, II); EndValueFromAdditionalBypass->setName("ind.end"); } } + + // Create phi nodes to merge from the backedge-taken check block. + PHINode *BCResumeVal = + PHINode::Create(OrigPhi->getType(), 3, "bc.resume.val", + LoopScalarPreHeader->getTerminator()); + // Copy original phi DL over to the new one. + BCResumeVal->setDebugLoc(OrigPhi->getDebugLoc()); + // The new PHI merges the original incoming value, in case of a bypass, // or the value at the end of the vectorized loop. BCResumeVal->addIncoming(EndValue, LoopMiddleBlock); @@ -3460,13 +3106,10 @@ void InnerLoopVectorizer::createInductionResumeValues( } } -BasicBlock *InnerLoopVectorizer::completeLoopSkeleton(Loop *L, - MDNode *OrigLoopID) { - assert(L && "Expected valid loop."); - +BasicBlock *InnerLoopVectorizer::completeLoopSkeleton(MDNode *OrigLoopID) { // The trip counts should be cached by now. - Value *Count = getOrCreateTripCount(L); - Value *VectorTripCount = getOrCreateVectorTripCount(L); + Value *Count = getOrCreateTripCount(LoopVectorPreHeader); + Value *VectorTripCount = getOrCreateVectorTripCount(LoopVectorPreHeader); auto *ScalarLatchTerm = OrigLoop->getLoopLatch()->getTerminator(); @@ -3491,14 +3134,8 @@ BasicBlock *InnerLoopVectorizer::completeLoopSkeleton(Loop *L, cast<BranchInst>(LoopMiddleBlock->getTerminator())->setCondition(CmpN); } - // Get ready to start creating new instructions into the vectorized body. - assert(LoopVectorPreHeader == L->getLoopPreheader() && - "Inconsistent vector loop preheader"); - Builder.SetInsertPoint(&*LoopVectorBody->getFirstInsertionPt()); - #ifdef EXPENSIVE_CHECKS assert(DT->verify(DominatorTree::VerificationLevel::Fast)); - LI->verify(*DT); #endif return LoopVectorPreHeader; @@ -3521,7 +3158,7 @@ InnerLoopVectorizer::createVectorizedLoopSkeleton() { |/ | | v | [ ] \ - | [ ]_| <-- vector loop. + | [ ]_| <-- vector loop (created during VPlan execution). | | | v \ -[ ] <--- middle-block. @@ -3548,34 +3185,32 @@ InnerLoopVectorizer::createVectorizedLoopSkeleton() { // simply happens to be prone to hitting this in practice. 
In theory, we // can hit the same issue for any SCEV, or ValueTracking query done during // mutation. See PR49900. - getOrCreateTripCount(OrigLoop); + getOrCreateTripCount(OrigLoop->getLoopPreheader()); // Create an empty vector loop, and prepare basic blocks for the runtime // checks. - Loop *Lp = createVectorLoopSkeleton(""); + createVectorLoopSkeleton(""); // Now, compare the new count to zero. If it is zero skip the vector loop and // jump to the scalar loop. This check also covers the case where the // backedge-taken count is uint##_max: adding one to it will overflow leading // to an incorrect trip count of zero. In this (rare) case we will also jump // to the scalar loop. - emitMinimumIterationCountCheck(Lp, LoopScalarPreHeader); + emitIterationCountCheck(LoopScalarPreHeader); // Generate the code to check any assumptions that we've made for SCEV // expressions. - emitSCEVChecks(Lp, LoopScalarPreHeader); + emitSCEVChecks(LoopScalarPreHeader); // Generate the code that checks in runtime if arrays overlap. We put the // checks into a separate block to make the more common case of few elements // faster. - emitMemRuntimeChecks(Lp, LoopScalarPreHeader); - - createHeaderBranch(Lp); + emitMemRuntimeChecks(LoopScalarPreHeader); // Emit phis for the new starting index of the scalar loop. - createInductionResumeValues(Lp); + createInductionResumeValues(); - return {completeLoopSkeleton(Lp, OrigLoopID), nullptr}; + return {completeLoopSkeleton(OrigLoopID), nullptr}; } // Fix up external users of the induction variable. At this point, we are @@ -3584,8 +3219,9 @@ InnerLoopVectorizer::createVectorizedLoopSkeleton() { // value for the IV when arriving directly from the middle block. void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, const InductionDescriptor &II, - Value *CountRoundDown, Value *EndValue, - BasicBlock *MiddleBlock) { + Value *VectorTripCount, Value *EndValue, + BasicBlock *MiddleBlock, + BasicBlock *VectorHeader, VPlan &Plan) { // There are two kinds of external IV usages - those that use the value // computed in the last iteration (the PHI) and those that use the penultimate // value (the value that feeds into the phi from the loop latch). @@ -3612,8 +3248,6 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, for (User *U : OrigPhi->users()) { auto *UI = cast<Instruction>(U); if (!OrigLoop->contains(UI)) { - const DataLayout &DL = - OrigLoop->getHeader()->getModule()->getDataLayout(); assert(isa<PHINode>(UI) && "Expected LCSSA form"); IRBuilder<> B(MiddleBlock->getTerminator()); @@ -3623,15 +3257,18 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, B.setFastMathFlags(II.getInductionBinOp()->getFastMathFlags()); Value *CountMinusOne = B.CreateSub( - CountRoundDown, ConstantInt::get(CountRoundDown->getType(), 1)); + VectorTripCount, ConstantInt::get(VectorTripCount->getType(), 1)); Value *CMO = !II.getStep()->getType()->isIntegerTy() ? 
B.CreateCast(Instruction::SIToFP, CountMinusOne, II.getStep()->getType()) : B.CreateSExtOrTrunc(CountMinusOne, II.getStep()->getType()); CMO->setName("cast.cmo"); + + Value *Step = CreateStepValue(II.getStep(), *PSE.getSE(), + VectorHeader->getTerminator()); Value *Escape = - emitTransformedIndex(B, CMO, PSE.getSE(), DL, II, LoopVectorBody); + emitTransformedIndex(B, CMO, II.getStartValue(), Step, II); Escape->setName("ind.escape"); MissingVals[UI] = Escape; } @@ -3644,8 +3281,10 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, // In this case, if IV1 has an external use, we need to avoid adding both // "last value of IV1" and "penultimate value of IV2". So, verify that we // don't already have an incoming value for the middle block. - if (PHI->getBasicBlockIndex(MiddleBlock) == -1) + if (PHI->getBasicBlockIndex(MiddleBlock) == -1) { PHI->addIncoming(I.second, MiddleBlock); + Plan.removeLiveOut(PHI); + } } } @@ -3924,18 +3563,16 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths(VPTransformState &State) { } } -void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) { +void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State, + VPlan &Plan) { // Insert truncates and extends for any truncated instructions as hints to // InstCombine. if (VF.isVector()) truncateToMinimalBitwidths(State); // Fix widened non-induction PHIs by setting up the PHI operands. - if (OrigPHIsToFix.size()) { - assert(EnableVPlanNativePath && - "Unexpected non-induction PHIs for fixup in non VPlan-native path"); - fixNonInductionPHIs(State); - } + if (EnableVPlanNativePath) + fixNonInductionPHIs(Plan, State); // At this point every instruction in the original loop is widened to a // vector form. Now we need to fix the recurrences in the loop. These PHI @@ -3946,24 +3583,37 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) { // Forget the original basic block. PSE.getSE()->forgetLoop(OrigLoop); - // If we inserted an edge from the middle block to the unique exit block, - // update uses outside the loop (phis) to account for the newly inserted - // edge. - if (!Cost->requiresScalarEpilogue(VF)) { + VPBasicBlock *LatchVPBB = Plan.getVectorLoopRegion()->getExitingBasicBlock(); + Loop *VectorLoop = LI->getLoopFor(State.CFG.VPBB2IRBB[LatchVPBB]); + if (Cost->requiresScalarEpilogue(VF)) { + // No edge from the middle block to the unique exit block has been inserted + // and there is nothing to fix from vector loop; phis should have incoming + // from scalar loop only. + Plan.clearLiveOuts(); + } else { + // If we inserted an edge from the middle block to the unique exit block, + // update uses outside the loop (phis) to account for the newly inserted + // edge. + // Fix-up external users of the induction variables. for (auto &Entry : Legal->getInductionVars()) fixupIVUsers(Entry.first, Entry.second, - getOrCreateVectorTripCount(LI->getLoopFor(LoopVectorBody)), - IVEndValues[Entry.first], LoopMiddleBlock); - - fixLCSSAPHIs(State); + getOrCreateVectorTripCount(VectorLoop->getLoopPreheader()), + IVEndValues[Entry.first], LoopMiddleBlock, + VectorLoop->getHeader(), Plan); } + // Fix LCSSA phis not already fixed earlier. Extracts may need to be generated + // in the exit block, so update the builder. + State.Builder.SetInsertPoint(State.CFG.ExitBB->getFirstNonPHI()); + for (auto &KV : Plan.getLiveOuts()) + KV.second->fixPhi(Plan, State); + for (Instruction *PI : PredicatedInstructions) sinkScalarOperands(&*PI); // Remove redundant induction instructions. 
- cse(LoopVectorBody); + cse(VectorLoop->getHeader()); // Set/update profile weights for the vector and remainder loops as original // loop iterations are now distributed among them. Note that original loop @@ -3978,9 +3628,9 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) { // For scalable vectorization we can't know at compile time how many iterations // of the loop are handled in one vector iteration, so instead assume a pessimistic // vscale of '1'. - setProfileInfoAfterUnrolling( - LI->getLoopFor(LoopScalarBody), LI->getLoopFor(LoopVectorBody), - LI->getLoopFor(LoopScalarBody), VF.getKnownMinValue() * UF); + setProfileInfoAfterUnrolling(LI->getLoopFor(LoopScalarBody), VectorLoop, + LI->getLoopFor(LoopScalarBody), + VF.getKnownMinValue() * UF); } void InnerLoopVectorizer::fixCrossIterationPHIs(VPTransformState &State) { @@ -3990,7 +3640,8 @@ void InnerLoopVectorizer::fixCrossIterationPHIs(VPTransformState &State) { // the currently empty PHI nodes. At this point every instruction in the // original loop is widened to a vector form so we can use them to construct // the incoming edges. - VPBasicBlock *Header = State.Plan->getEntry()->getEntryBasicBlock(); + VPBasicBlock *Header = + State.Plan->getVectorLoopRegion()->getEntryBasicBlock(); for (VPRecipeBase &R : Header->phis()) { if (auto *ReductionPhi = dyn_cast<VPReductionPHIRecipe>(&R)) fixReduction(ReductionPhi, State); @@ -4106,8 +3757,10 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence( // and thus no phis which needed updated. if (!Cost->requiresScalarEpilogue(VF)) for (PHINode &LCSSAPhi : LoopExitBlock->phis()) - if (llvm::is_contained(LCSSAPhi.incoming_values(), Phi)) + if (llvm::is_contained(LCSSAPhi.incoming_values(), Phi)) { LCSSAPhi.addIncoming(ExtractForPhiUsedOutsideLoop, LoopMiddleBlock); + State.Plan->removeLiveOut(&LCSSAPhi); + } } void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, @@ -4121,14 +3774,14 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, RecurKind RK = RdxDesc.getRecurrenceKind(); TrackingVH<Value> ReductionStartValue = RdxDesc.getRecurrenceStartValue(); Instruction *LoopExitInst = RdxDesc.getLoopExitInstr(); - setDebugLocFromInst(ReductionStartValue); + State.setDebugLocFromInst(ReductionStartValue); VPValue *LoopExitInstDef = PhiR->getBackedgeValue(); // This is the vector-clone of the value that leaves the loop. Type *VecTy = State.get(LoopExitInstDef, 0)->getType(); // Wrap flags are in general invalid after vectorization, clear them. - clearReductionWrapFlags(RdxDesc, State); + clearReductionWrapFlags(PhiR, State); // Before each round, move the insertion point right between // the PHIs and the values we are going to write. @@ -4136,9 +3789,13 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, // instructions. Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt()); - setDebugLocFromInst(LoopExitInst); + State.setDebugLocFromInst(LoopExitInst); Type *PhiTy = OrigPhi->getType(); + + VPBasicBlock *LatchVPBB = + PhiR->getParent()->getEnclosingLoopRegion()->getExitingBasicBlock(); + BasicBlock *VectorLoopLatch = State.CFG.VPBB2IRBB[LatchVPBB]; // If tail is folded by masking, the vector value to leave the loop should be // a Select choosing between the vectorized LoopExitInst and vectorized Phi, // instead of the former. 
For an inloop reduction the reduction will already @@ -4146,17 +3803,20 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, if (Cost->foldTailByMasking() && !PhiR->isInLoop()) { for (unsigned Part = 0; Part < UF; ++Part) { Value *VecLoopExitInst = State.get(LoopExitInstDef, Part); - Value *Sel = nullptr; + SelectInst *Sel = nullptr; for (User *U : VecLoopExitInst->users()) { if (isa<SelectInst>(U)) { assert(!Sel && "Reduction exit feeding two selects"); - Sel = U; + Sel = cast<SelectInst>(U); } else assert(isa<PHINode>(U) && "Reduction exit must feed Phi's or select"); } assert(Sel && "Reduction exit feeds no select"); State.reset(LoopExitInstDef, Sel, Part); + if (isa<FPMathOperator>(Sel)) + Sel->setFastMathFlags(RdxDesc.getFastMathFlags()); + // If the target can create a predicated operator for the reduction at no // extra cost in the loop (for example a predicated vadd), it can be // cheaper for the select to remain in the loop than be sunk out of it, @@ -4168,8 +3828,7 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, TargetTransformInfo::ReductionFlags())) { auto *VecRdxPhi = cast<PHINode>(State.get(PhiR, Part)); - VecRdxPhi->setIncomingValueForBlock( - LI->getLoopFor(LoopVectorBody)->getLoopLatch(), Sel); + VecRdxPhi->setIncomingValueForBlock(VectorLoopLatch, Sel); } } } @@ -4180,8 +3839,7 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, if (VF.isVector() && PhiTy != RdxDesc.getRecurrenceType()) { assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!"); Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF); - Builder.SetInsertPoint( - LI->getLoopFor(LoopVectorBody)->getLoopLatch()->getTerminator()); + Builder.SetInsertPoint(VectorLoopLatch->getTerminator()); VectorParts RdxParts(UF); for (unsigned Part = 0; Part < UF; ++Part) { RdxParts[Part] = State.get(LoopExitInstDef, Part); @@ -4212,7 +3870,7 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, // conditional branch, and (c) other passes may add new predecessors which // terminate on this line. This is the easiest way to ensure we don't // accidentally cause an extra step back into the loop while debugging. - setDebugLocFromInst(LoopMiddleBlock->getTerminator()); + State.setDebugLocFromInst(LoopMiddleBlock->getTerminator()); if (PhiR->isOrdered()) ReducedPartRdx = State.get(LoopExitInstDef, UF - 1); else { @@ -4269,6 +3927,17 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, // Set the resume value for this reduction ReductionResumeValues.insert({&RdxDesc, BCBlockPhi}); + // If there were stores of the reduction value to a uniform memory address + // inside the loop, create the final store here. + if (StoreInst *SI = RdxDesc.IntermediateStore) { + StoreInst *NewSI = + Builder.CreateStore(ReducedPartRdx, SI->getPointerOperand()); + propagateMetadata(NewSI, SI); + + // If the reduction value is used in other places, + // then let the code below create PHI's for that. + } + // Now, we need to fix the users of the reduction variable // inside and outside of the scalar remainder loop. @@ -4277,8 +3946,10 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, // fixFirstOrderRecurrence for a more complete explaination of the logic. 
if (!Cost->requiresScalarEpilogue(VF)) for (PHINode &LCSSAPhi : LoopExitBlock->phis()) - if (llvm::is_contained(LCSSAPhi.incoming_values(), LoopExitInst)) + if (llvm::is_contained(LCSSAPhi.incoming_values(), LoopExitInst)) { LCSSAPhi.addIncoming(ReducedPartRdx, LoopMiddleBlock); + State.Plan->removeLiveOut(&LCSSAPhi); + } // Fix the scalar loop reduction variable with the incoming reduction sum // from the vector body and from the backedge value. @@ -4291,63 +3962,35 @@ void InnerLoopVectorizer::fixReduction(VPReductionPHIRecipe *PhiR, OrigPhi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst); } -void InnerLoopVectorizer::clearReductionWrapFlags(const RecurrenceDescriptor &RdxDesc, +void InnerLoopVectorizer::clearReductionWrapFlags(VPReductionPHIRecipe *PhiR, VPTransformState &State) { + const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor(); RecurKind RK = RdxDesc.getRecurrenceKind(); if (RK != RecurKind::Add && RK != RecurKind::Mul) return; - Instruction *LoopExitInstr = RdxDesc.getLoopExitInstr(); - assert(LoopExitInstr && "null loop exit instruction"); - SmallVector<Instruction *, 8> Worklist; - SmallPtrSet<Instruction *, 8> Visited; - Worklist.push_back(LoopExitInstr); - Visited.insert(LoopExitInstr); + SmallVector<VPValue *, 8> Worklist; + SmallPtrSet<VPValue *, 8> Visited; + Worklist.push_back(PhiR); + Visited.insert(PhiR); while (!Worklist.empty()) { - Instruction *Cur = Worklist.pop_back_val(); - if (isa<OverflowingBinaryOperator>(Cur)) - for (unsigned Part = 0; Part < UF; ++Part) { - // FIXME: Should not rely on getVPValue at this point. - Value *V = State.get(State.Plan->getVPValue(Cur, true), Part); - cast<Instruction>(V)->dropPoisonGeneratingFlags(); + VPValue *Cur = Worklist.pop_back_val(); + for (unsigned Part = 0; Part < UF; ++Part) { + Value *V = State.get(Cur, Part); + if (!isa<OverflowingBinaryOperator>(V)) + break; + cast<Instruction>(V)->dropPoisonGeneratingFlags(); } - for (User *U : Cur->users()) { - Instruction *UI = cast<Instruction>(U); - if ((Cur != LoopExitInstr || OrigLoop->contains(UI->getParent())) && - Visited.insert(UI).second) - Worklist.push_back(UI); - } - } -} - -void InnerLoopVectorizer::fixLCSSAPHIs(VPTransformState &State) { - for (PHINode &LCSSAPhi : LoopExitBlock->phis()) { - if (LCSSAPhi.getBasicBlockIndex(LoopMiddleBlock) != -1) - // Some phis were already hand updated by the reduction and recurrence - // code above, leave them alone. - continue; - - auto *IncomingValue = LCSSAPhi.getIncomingValue(0); - // Non-instruction incoming values will have only one value. - - VPLane Lane = VPLane::getFirstLane(); - if (isa<Instruction>(IncomingValue) && - !Cost->isUniformAfterVectorization(cast<Instruction>(IncomingValue), - VF)) - Lane = VPLane::getLastLaneForVF(VF); - - // Can be a loop invariant incoming value or the last scalar value to be - // extracted from the vectorized loop. - // FIXME: Should not rely on getVPValue at this point. - Builder.SetInsertPoint(LoopMiddleBlock->getTerminator()); - Value *lastIncomingValue = - OrigLoop->isLoopInvariant(IncomingValue) - ? 
IncomingValue - : State.get(State.Plan->getVPValue(IncomingValue, true), - VPIteration(UF - 1, Lane)); - LCSSAPhi.addIncoming(lastIncomingValue, LoopMiddleBlock); + for (VPUser *U : Cur->users()) { + auto *UserRecipe = dyn_cast<VPRecipeBase>(U); + if (!UserRecipe) + continue; + for (VPValue *V : UserRecipe->definedValues()) + if (Visited.insert(V).second) + Worklist.push_back(V); + } } } @@ -4425,17 +4068,23 @@ void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) { } while (Changed); } -void InnerLoopVectorizer::fixNonInductionPHIs(VPTransformState &State) { - for (PHINode *OrigPhi : OrigPHIsToFix) { - VPWidenPHIRecipe *VPPhi = - cast<VPWidenPHIRecipe>(State.Plan->getVPValue(OrigPhi)); - PHINode *NewPhi = cast<PHINode>(State.get(VPPhi, 0)); - // Make sure the builder has a valid insert point. - Builder.SetInsertPoint(NewPhi); - for (unsigned i = 0; i < VPPhi->getNumOperands(); ++i) { - VPValue *Inc = VPPhi->getIncomingValue(i); - VPBasicBlock *VPBB = VPPhi->getIncomingBlock(i); - NewPhi->addIncoming(State.get(Inc, 0), State.CFG.VPBB2IRBB[VPBB]); +void InnerLoopVectorizer::fixNonInductionPHIs(VPlan &Plan, + VPTransformState &State) { + auto Iter = depth_first( + VPBlockRecursiveTraversalWrapper<VPBlockBase *>(Plan.getEntry())); + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Iter)) { + for (VPRecipeBase &P : VPBB->phis()) { + VPWidenPHIRecipe *VPPhi = dyn_cast<VPWidenPHIRecipe>(&P); + if (!VPPhi) + continue; + PHINode *NewPhi = cast<PHINode>(State.get(VPPhi, 0)); + // Make sure the builder has a valid insert point. + Builder.SetInsertPoint(NewPhi); + for (unsigned i = 0; i < VPPhi->getNumOperands(); ++i) { + VPValue *Inc = VPPhi->getIncomingValue(i); + VPBasicBlock *VPBB = VPPhi->getIncomingBlock(i); + NewPhi->addIncoming(State.get(Inc, 0), State.CFG.VPBB2IRBB[VPBB]); + } } } } @@ -4445,139 +4094,6 @@ bool InnerLoopVectorizer::useOrderedReductions( return Cost->useOrderedReductions(RdxDesc); } -void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, - VPWidenPHIRecipe *PhiR, - VPTransformState &State) { - PHINode *P = cast<PHINode>(PN); - if (EnableVPlanNativePath) { - // Currently we enter here in the VPlan-native path for non-induction - // PHIs where all control flow is uniform. We simply widen these PHIs. - // Create a vector phi with no operands - the vector phi operands will be - // set at the end of vector code generation. - Type *VecTy = (State.VF.isScalar()) - ? PN->getType() - : VectorType::get(PN->getType(), State.VF); - Value *VecPhi = Builder.CreatePHI(VecTy, PN->getNumOperands(), "vec.phi"); - State.set(PhiR, VecPhi, 0); - OrigPHIsToFix.push_back(P); - - return; - } - - assert(PN->getParent() == OrigLoop->getHeader() && - "Non-header phis should have been handled elsewhere"); - - // In order to support recurrences we need to be able to vectorize Phi nodes. - // Phi nodes have cycles, so we need to vectorize them in two stages. This is - // stage #1: We create a new vector PHI node with no incoming edges. We'll use - // this value when we vectorize all of the instructions that use the PHI. - - assert(!Legal->isReductionVariable(P) && - "reductions should be handled elsewhere"); - - setDebugLocFromInst(P); - - // This PHINode must be an induction variable. - // Make sure that we know about it. 
- assert(Legal->getInductionVars().count(P) && "Not an induction variable"); - - InductionDescriptor II = Legal->getInductionVars().lookup(P); - const DataLayout &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); - - auto *IVR = PhiR->getParent()->getPlan()->getCanonicalIV(); - PHINode *CanonicalIV = cast<PHINode>(State.get(IVR, 0)); - - // FIXME: The newly created binary instructions should contain nsw/nuw flags, - // which can be found from the original scalar operations. - switch (II.getKind()) { - case InductionDescriptor::IK_NoInduction: - llvm_unreachable("Unknown induction"); - case InductionDescriptor::IK_IntInduction: - case InductionDescriptor::IK_FpInduction: - llvm_unreachable("Integer/fp induction is handled elsewhere."); - case InductionDescriptor::IK_PtrInduction: { - // Handle the pointer induction variable case. - assert(P->getType()->isPointerTy() && "Unexpected type."); - - if (Cost->isScalarAfterVectorization(P, State.VF)) { - // This is the normalized GEP that starts counting at zero. - Value *PtrInd = - Builder.CreateSExtOrTrunc(CanonicalIV, II.getStep()->getType()); - // Determine the number of scalars we need to generate for each unroll - // iteration. If the instruction is uniform, we only need to generate the - // first lane. Otherwise, we generate all VF values. - bool IsUniform = vputils::onlyFirstLaneUsed(PhiR); - assert((IsUniform || !State.VF.isScalable()) && - "Cannot scalarize a scalable VF"); - unsigned Lanes = IsUniform ? 1 : State.VF.getFixedValue(); - - for (unsigned Part = 0; Part < UF; ++Part) { - Value *PartStart = - createStepForVF(Builder, PtrInd->getType(), VF, Part); - - for (unsigned Lane = 0; Lane < Lanes; ++Lane) { - Value *Idx = Builder.CreateAdd( - PartStart, ConstantInt::get(PtrInd->getType(), Lane)); - Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); - Value *SclrGep = emitTransformedIndex(Builder, GlobalIdx, PSE.getSE(), - DL, II, State.CFG.PrevBB); - SclrGep->setName("next.gep"); - State.set(PhiR, SclrGep, VPIteration(Part, Lane)); - } - } - return; - } - assert(isa<SCEVConstant>(II.getStep()) && - "Induction step not a SCEV constant!"); - Type *PhiType = II.getStep()->getType(); - - // Build a pointer phi - Value *ScalarStartValue = PhiR->getStartValue()->getLiveInIRValue(); - Type *ScStValueType = ScalarStartValue->getType(); - PHINode *NewPointerPhi = - PHINode::Create(ScStValueType, 2, "pointer.phi", CanonicalIV); - NewPointerPhi->addIncoming(ScalarStartValue, LoopVectorPreHeader); - - // A pointer induction, performed by using a gep - BasicBlock *LoopLatch = LI->getLoopFor(LoopVectorBody)->getLoopLatch(); - Instruction *InductionLoc = LoopLatch->getTerminator(); - const SCEV *ScalarStep = II.getStep(); - SCEVExpander Exp(*PSE.getSE(), DL, "induction"); - Value *ScalarStepValue = - Exp.expandCodeFor(ScalarStep, PhiType, InductionLoc); - Value *RuntimeVF = getRuntimeVF(Builder, PhiType, VF); - Value *NumUnrolledElems = - Builder.CreateMul(RuntimeVF, ConstantInt::get(PhiType, State.UF)); - Value *InductionGEP = GetElementPtrInst::Create( - II.getElementType(), NewPointerPhi, - Builder.CreateMul(ScalarStepValue, NumUnrolledElems), "ptr.ind", - InductionLoc); - NewPointerPhi->addIncoming(InductionGEP, LoopLatch); - - // Create UF many actual address geps that use the pointer - // phi as base and a vectorized version of the step value - // (<step*0, ..., step*N>) as offset. 
- for (unsigned Part = 0; Part < State.UF; ++Part) { - Type *VecPhiType = VectorType::get(PhiType, State.VF); - Value *StartOffsetScalar = - Builder.CreateMul(RuntimeVF, ConstantInt::get(PhiType, Part)); - Value *StartOffset = - Builder.CreateVectorSplat(State.VF, StartOffsetScalar); - // Create a vector of consecutive numbers from zero to VF. - StartOffset = - Builder.CreateAdd(StartOffset, Builder.CreateStepVector(VecPhiType)); - - Value *GEP = Builder.CreateGEP( - II.getElementType(), NewPointerPhi, - Builder.CreateMul( - StartOffset, Builder.CreateVectorSplat(State.VF, ScalarStepValue), - "vector.gep")); - State.set(PhiR, GEP, Part); - } - } - } -} - /// A helper function for checking whether an integer division-related /// instruction may divide by zero (in which case it must be predicated if /// executed conditionally in the scalar code). @@ -4601,7 +4117,7 @@ void InnerLoopVectorizer::widenCallInstruction(CallInst &I, VPValue *Def, VPTransformState &State) { assert(!isa<DbgInfoIntrinsic>(I) && "DbgInfoIntrinsic should have been dropped during VPlan construction"); - setDebugLocFromInst(&I); + State.setDebugLocFromInst(&I); Module *M = I.getParent()->getParent()->getParent(); auto *CI = cast<CallInst>(&I); @@ -4631,13 +4147,13 @@ void InnerLoopVectorizer::widenCallInstruction(CallInst &I, VPValue *Def, // Some intrinsics have a scalar argument - don't replace it with a // vector. Value *Arg; - if (!UseVectorIntrinsic || !hasVectorInstrinsicScalarOpd(ID, I.index())) + if (!UseVectorIntrinsic || + !isVectorIntrinsicWithScalarOpAtArg(ID, I.index())) Arg = State.get(I.value(), Part); - else { + else Arg = State.get(I.value(), VPIteration(0, 0)); - if (hasVectorInstrinsicOverloadedScalarOpd(ID, I.index())) - TysForDecl.push_back(Arg->getType()); - } + if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I.index())) + TysForDecl.push_back(Arg->getType()); Args.push_back(Arg); } @@ -4665,7 +4181,7 @@ void InnerLoopVectorizer::widenCallInstruction(CallInst &I, VPValue *Def, V->copyFastMathFlags(CI); State.set(Def, V, Part); - addMetadata(V, &I); + State.addMetadata(V, &I); } } @@ -4676,6 +4192,14 @@ void LoopVectorizationCostModel::collectLoopScalars(ElementCount VF) { assert(VF.isVector() && Scalars.find(VF) == Scalars.end() && "This function should not be visited twice for the same VF"); + // This avoids any chances of creating a REPLICATE recipe during planning + // since that would result in generation of scalarized code during execution, + // which is not supported for scalable vectors. + if (VF.isScalable()) { + Scalars[VF].insert(Uniforms[VF].begin(), Uniforms[VF].end()); + return; + } + SmallSetVector<Instruction *, 8> Worklist; // These sets are used to seed the analysis with pointers used by memory @@ -4765,7 +4289,7 @@ void LoopVectorizationCostModel::collectLoopScalars(ElementCount VF) { } // Insert the forced scalars. - // FIXME: Currently widenPHIInstruction() often creates a dead vector + // FIXME: Currently VPWidenPHIRecipe() often creates a dead vector // induction variable when the PHI user is scalarized. auto ForcedScalar = ForcedScalars.find(VF); if (ForcedScalar != ForcedScalars.end()) @@ -4892,6 +4416,27 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened( if (hasIrregularType(ScalarTy, DL)) return false; + // If the group involves a non-integral pointer, we may not be able to + // losslessly cast all values to a common type. 
+ unsigned InterleaveFactor = Group->getFactor(); + bool ScalarNI = DL.isNonIntegralPointerType(ScalarTy); + for (unsigned i = 0; i < InterleaveFactor; i++) { + Instruction *Member = Group->getMember(i); + if (!Member) + continue; + auto *MemberTy = getLoadStoreType(Member); + bool MemberNI = DL.isNonIntegralPointerType(MemberTy); + // Don't coerce non-integral pointers to integers or vice versa. + if (MemberNI != ScalarNI) { + // TODO: Consider adding special nullptr value case here + return false; + } else if (MemberNI && ScalarNI && + ScalarTy->getPointerAddressSpace() != + MemberTy->getPointerAddressSpace()) { + return false; + } + } + // Check if masking is required. // A Group may need masking for one of two reasons: it resides in a block that // needs predication, or it was decided to use masking to deal with gaps @@ -5174,7 +4719,7 @@ bool LoopVectorizationCostModel::runtimeChecksRequired() { return true; } - if (!PSE.getUnionPredicate().getPredicates().empty()) { + if (!PSE.getPredicate().isAlwaysTrue()) { reportVectorizationFailure("Runtime SCEV check is required with -Os/-Oz", "runtime SCEV checks needed. Enable vectorization of this " "loop with '#pragma clang loop vectorize(enable)' when " @@ -5465,14 +5010,6 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) { } } - // For scalable vectors don't use tail folding for low trip counts or - // optimizing for code size. We only permit this if the user has explicitly - // requested it. - if (ScalarEpilogueStatus != CM_ScalarEpilogueNotNeededUsePredicate && - ScalarEpilogueStatus != CM_ScalarEpilogueNotAllowedUsePredicate && - MaxFactors.ScalableVF.isVector()) - MaxFactors.ScalableVF = ElementCount::getScalable(0); - // If we don't know the precise trip count, or if the trip count that we // found modulo the vectorization factor is not zero, try to fold the tail // by masking. @@ -5515,7 +5052,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) { ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( unsigned ConstTripCount, unsigned SmallestType, unsigned WidestType, - const ElementCount &MaxSafeVF, bool FoldTailByMasking) { + ElementCount MaxSafeVF, bool FoldTailByMasking) { bool ComputeScalableMaxVF = MaxSafeVF.isScalable(); TypeSize WidestRegister = TTI.getRegisterBitWidth( ComputeScalableMaxVF ? TargetTransformInfo::RGK_ScalableVector @@ -5560,9 +5097,12 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( return ElementCount::getFixed(ClampedConstTripCount); } + TargetTransformInfo::RegisterKind RegKind = + ComputeScalableMaxVF ? TargetTransformInfo::RGK_ScalableVector + : TargetTransformInfo::RGK_FixedWidthVector; ElementCount MaxVF = MaxVectorElementCount; - if (TTI.shouldMaximizeVectorBandwidth() || - (MaximizeBandwidth && isScalarEpilogueAllowed())) { + if (MaximizeBandwidth || (MaximizeBandwidth.getNumOccurrences() == 0 && + TTI.shouldMaximizeVectorBandwidth(RegKind))) { auto MaxVectorElementCountMaxBW = ElementCount::get( PowerOf2Floor(WidestRegister.getKnownMinSize() / SmallestType), ComputeScalableMaxVF); @@ -5600,6 +5140,11 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( MaxVF = MinVF; } } + + // Invalidate any widening decisions we might have made, in case the loop + // requires prediction (decided later), but we have already made some + // load/store widening decisions. 
+ invalidateCostModelingDecisions(); } return MaxVF; } @@ -5667,7 +5212,8 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor( assert(VFCandidates.count(ElementCount::getFixed(1)) && "Expected Scalar VF to be a candidate"); - const VectorizationFactor ScalarCost(ElementCount::getFixed(1), ExpectedCost); + const VectorizationFactor ScalarCost(ElementCount::getFixed(1), ExpectedCost, + ExpectedCost); VectorizationFactor ChosenFactor = ScalarCost; bool ForceVectorization = Hints->getForce() == LoopVectorizeHints::FK_Enabled; @@ -5685,12 +5231,12 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor( continue; VectorizationCostTy C = expectedCost(i, &InvalidCosts); - VectorizationFactor Candidate(i, C.first); + VectorizationFactor Candidate(i, C.first, ScalarCost.ScalarCost); #ifndef NDEBUG unsigned AssumedMinimumVscale = 1; if (Optional<unsigned> VScale = getVScaleForTuning()) - AssumedMinimumVscale = VScale.getValue(); + AssumedMinimumVscale = *VScale; unsigned Width = Candidate.Width.isScalable() ? Candidate.Width.getKnownMinValue() * AssumedMinimumVscale @@ -5878,7 +5424,7 @@ LoopVectorizationCostModel::selectEpilogueVectorizationFactor( LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization factor is forced.\n";); ElementCount ForcedEC = ElementCount::getFixed(EpilogueVectorizationForceVF); if (LVP.hasPlanWithVF(ForcedEC)) - return {ForcedEC, 0}; + return {ForcedEC, 0, 0}; else { LLVM_DEBUG( dbgs() @@ -5908,7 +5454,7 @@ LoopVectorizationCostModel::selectEpilogueVectorizationFactor( if (MainLoopVF.isScalable()) { EstimatedRuntimeVF = ElementCount::getFixed(MainLoopVF.getKnownMinValue()); if (Optional<unsigned> VScale = getVScaleForTuning()) - EstimatedRuntimeVF *= VScale.getValue(); + EstimatedRuntimeVF *= *VScale; } for (auto &NextVF : ProfitableVFs) @@ -6144,9 +5690,15 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, return IC; } - // Note that if we've already vectorized the loop we will have done the - // runtime check and so interleaving won't require further checks. - bool InterleavingRequiresRuntimePointerCheck = + // For any scalar loop that either requires runtime checks or predication we + // are better off leaving this to the unroller. Note that if we've already + // vectorized the loop we will have done the runtime check and so interleaving + // won't require further checks. + bool ScalarInterleavingRequiresPredication = + (VF.isScalar() && any_of(TheLoop->blocks(), [this](BasicBlock *BB) { + return Legal->blockNeedsPredication(BB); + })); + bool ScalarInterleavingRequiresRuntimePointerCheck = (VF.isScalar() && Legal->getRuntimePointerChecking()->Need); // We want to interleave small loops in order to reduce the loop overhead and @@ -6156,7 +5708,8 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, << "LV: VF is " << VF << '\n'); const bool AggressivelyInterleaveReductions = TTI.enableAggressiveInterleaving(HasReductions); - if (!InterleavingRequiresRuntimePointerCheck && LoopCost < SmallLoopCost) { + if (!ScalarInterleavingRequiresRuntimePointerCheck && + !ScalarInterleavingRequiresPredication && LoopCost < SmallLoopCost) { // We assume that the cost overhead is 1 and we use the cost model // to estimate the cost of the loop and interleave until the cost of the // loop overhead is about 5% of the cost of the loop. 
@@ -6319,16 +5872,10 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<ElementCount> VFs) { LLVM_DEBUG(dbgs() << "LV(REG): Calculating max register usage:\n"); - // A lambda that gets the register usage for the given type and VF. - const auto &TTICapture = TTI; - auto GetRegUsage = [&TTICapture](Type *Ty, ElementCount VF) -> unsigned { + auto GetRegUsage = [&TTI = TTI](Type *Ty, ElementCount VF) -> unsigned { if (Ty->isTokenTy() || !VectorType::isValidElementType(Ty)) return 0; - InstructionCost::CostType RegUsage = - *TTICapture.getRegUsageForType(VectorType::get(Ty, VF)).getValue(); - assert(RegUsage >= 0 && RegUsage <= std::numeric_limits<unsigned>::max() && - "Nonsensical values for register usage."); - return RegUsage; + return TTI.getRegUsageForType(VectorType::get(Ty, VF)); }; for (unsigned int i = 0, s = IdxToInstr.size(); i < s; ++i) { @@ -7079,10 +6626,17 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, bool TypeNotScalarized = false; if (VF.isVector() && VectorTy->isVectorTy()) { - unsigned NumParts = TTI.getNumberOfParts(VectorTy); - if (NumParts) - TypeNotScalarized = NumParts < VF.getKnownMinValue(); - else + if (unsigned NumParts = TTI.getNumberOfParts(VectorTy)) { + if (VF.isScalable()) + // <vscale x 1 x iN> is assumed to be profitable over iN because + // scalable registers are a distinct register class from scalar ones. + // If we ever find a target which wants to lower scalable vectors + // back to scalars, we'll need to update this code to explicitly + // ask TTI about the register class uses for each part. + TypeNotScalarized = NumParts <= VF.getKnownMinValue(); + else + TypeNotScalarized = NumParts < VF.getKnownMinValue(); + } else C = InstructionCost::getInvalid(); } return VectorizationCostTy(C, TypeNotScalarized); @@ -7158,8 +6712,6 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(ElementCount VF) { Cost = getGatherScatterCost(&I, VF); setWideningDecision(&I, VF, CM_GatherScatter, Cost); } else { - assert((isa<LoadInst>(&I) || !VF.isScalable()) && - "Cannot yet scalarize uniform stores"); Cost = getUniformMemOpCost(&I, VF); setWideningDecision(&I, VF, CM_Scalarize, Cost); } @@ -7517,8 +7069,13 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, InstWidening Decision = getWideningDecision(I, Width); assert(Decision != CM_Unknown && "CM decision should be taken at this point"); - if (Decision == CM_Scalarize) + if (Decision == CM_Scalarize) { + if (VF.isScalable() && isa<StoreInst>(I)) + // We can't scalarize a scalable vector store (even a uniform one + // currently), return an invalid cost so as to prevent vectorization. + return InstructionCost::getInvalid(); Width = ElementCount::getFixed(1); + } } VectorTy = ToVectorTy(getLoadStoreType(I), Width); return getMemoryInstructionCost(I, VF); @@ -7686,6 +7243,16 @@ void LoopVectorizationCostModel::collectValuesToIgnore() { // Ignore ephemeral values. CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore); + // Find all stores to invariant variables. Since they are going to sink + // outside the loop we do not need calculate cost for them. + for (BasicBlock *BB : TheLoop->blocks()) + for (Instruction &I : *BB) { + StoreInst *SI; + if ((SI = dyn_cast<StoreInst>(&I)) && + Legal->isInvariantAddressOfReduction(SI->getPointerOperand())) + ValuesToIgnore.insert(&I); + } + // Ignore type-promoting instructions we identified during reduction // detection. 
for (auto &Reduction : Legal->getReductionVars()) { @@ -7787,7 +7354,7 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) { if (VPlanBuildStressTest) return VectorizationFactor::Disabled(); - return {VF, 0 /*Cost*/}; + return {VF, 0 /*Cost*/, 0 /* ScalarCost */}; } LLVM_DEBUG( @@ -7796,6 +7363,14 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) { return VectorizationFactor::Disabled(); } +bool LoopVectorizationPlanner::requiresTooManyRuntimeChecks() const { + unsigned NumRuntimePointerChecks = Requirements.getNumRuntimePointerChecks(); + return (NumRuntimePointerChecks > + VectorizerParams::RuntimeMemoryCheckThreshold && + !Hints.allowReordering()) || + NumRuntimePointerChecks > PragmaVectorizeMemoryCheckThreshold; +} + Optional<VectorizationFactor> LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { assert(OrigLoop->isInnermost() && "Inner loop expected."); @@ -7830,7 +7405,7 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { CM.collectInLoopReductions(); buildVPlansWithVPRecipes(UserVF, UserVF); LLVM_DEBUG(printPlans(dbgs())); - return {{UserVF, 0}}; + return {{UserVF, 0, 0}}; } else reportVectorizationInfo("UserVF ignored because of invalid costs.", "InvalidCost", ORE, OrigLoop); @@ -7864,30 +7439,7 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { return VectorizationFactor::Disabled(); // Select the optimal vectorization factor. - auto SelectedVF = CM.selectVectorizationFactor(VFCandidates); - - // Check if it is profitable to vectorize with runtime checks. - unsigned NumRuntimePointerChecks = Requirements.getNumRuntimePointerChecks(); - if (SelectedVF.Width.getKnownMinValue() > 1 && NumRuntimePointerChecks) { - bool PragmaThresholdReached = - NumRuntimePointerChecks > PragmaVectorizeMemoryCheckThreshold; - bool ThresholdReached = - NumRuntimePointerChecks > VectorizerParams::RuntimeMemoryCheckThreshold; - if ((ThresholdReached && !Hints.allowReordering()) || - PragmaThresholdReached) { - ORE->emit([&]() { - return OptimizationRemarkAnalysisAliasing( - DEBUG_TYPE, "CantReorderMemOps", OrigLoop->getStartLoc(), - OrigLoop->getHeader()) - << "loop not vectorized: cannot prove it is safe to reorder " - "memory operations"; - }); - LLVM_DEBUG(dbgs() << "LV: Too many memory checks needed.\n"); - Hints.emitRemarkWithHints(); - return VectorizationFactor::Disabled(); - } - } - return SelectedVF; + return CM.selectVectorizationFactor(VFCandidates); } VPlan &LoopVectorizationPlanner::getBestPlanFor(ElementCount VF) const { @@ -7940,17 +7492,36 @@ static void AddRuntimeUnrollDisableMetaData(Loop *L) { void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF, VPlan &BestVPlan, InnerLoopVectorizer &ILV, - DominatorTree *DT) { + DominatorTree *DT, + bool IsEpilogueVectorization) { LLVM_DEBUG(dbgs() << "Executing best plan with VF=" << BestVF << ", UF=" << BestUF << '\n'); // Perform the actual loop transformation. - // 1. Create a new empty loop. Unlink the old loop and connect the new one. + // 1. Set up the skeleton for vectorization, including vector pre-header and + // middle block. The vector loop is created during VPlan execution. VPTransformState State{BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan}; Value *CanonicalIVStartValue; std::tie(State.CFG.PrevBB, CanonicalIVStartValue) = ILV.createVectorizedLoopSkeleton(); + + // Only use noalias metadata when using memory checks guaranteeing no overlap + // across all iterations. 
+ const LoopAccessInfo *LAI = ILV.Legal->getLAI(); + if (LAI && !LAI->getRuntimePointerChecking()->getChecks().empty() && + !LAI->getRuntimePointerChecking()->getDiffChecks()) { + + // We currently don't use LoopVersioning for the actual loop cloning but we + // still use it to add the noalias metadata. + // TODO: Find a better way to re-use LoopVersioning functionality to add + // metadata. + State.LVer = std::make_unique<LoopVersioning>( + *LAI, LAI->getRuntimePointerChecking()->getChecks(), OrigLoop, LI, DT, + PSE.getSE()); + State.LVer->prepareNoAliasMetadata(); + } + ILV.collectPoisonGeneratingRecipes(State); ILV.printDebugTracesAtStart(); @@ -7966,7 +7537,9 @@ void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF, // 2. Copy and widen instructions from the old loop into the new loop. BestVPlan.prepareToExecute(ILV.getOrCreateTripCount(nullptr), ILV.getOrCreateVectorTripCount(nullptr), - CanonicalIVStartValue, State); + CanonicalIVStartValue, State, + IsEpilogueVectorization); + BestVPlan.execute(&State); // Keep all loop hints from the original loop on the vector loop (we'll @@ -7977,8 +7550,10 @@ void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF, makeFollowupLoopID(OrigLoopID, {LLVMLoopVectorizeFollowupAll, LLVMLoopVectorizeFollowupVectorized}); - Loop *L = LI->getLoopFor(State.CFG.PrevBB); - if (VectorizedLoopID.hasValue()) + VPBasicBlock *HeaderVPBB = + BestVPlan.getVectorLoopRegion()->getEntryBasicBlock(); + Loop *L = LI->getLoopFor(State.CFG.VPBB2IRBB[HeaderVPBB]); + if (VectorizedLoopID) L->setLoopID(VectorizedLoopID.getValue()); else { // Keep all loop hints from the original loop on the vector loop (we'll @@ -7995,7 +7570,7 @@ void LoopVectorizationPlanner::executePlan(ElementCount BestVF, unsigned BestUF, // 3. Fix the vectorized code: take care of header phi's, live-outs, // predication, updating analyses. - ILV.fixVectorizedLoop(State); + ILV.fixVectorizedLoop(State, BestVPlan); ILV.printDebugTracesAtEnd(); } @@ -8066,22 +7641,31 @@ Value *InnerLoopUnroller::getBroadcastInstrs(Value *V) { return V; } std::pair<BasicBlock *, Value *> EpilogueVectorizerMainLoop::createEpilogueVectorizedLoopSkeleton() { MDNode *OrigLoopID = OrigLoop->getLoopID(); - Loop *Lp = createVectorLoopSkeleton(""); + + // Workaround! Compute the trip count of the original loop and cache it + // before we start modifying the CFG. This code has a systemic problem + // wherein it tries to run analysis over partially constructed IR; this is + // wrong, and not simply for SCEV. The trip count of the original loop + // simply happens to be prone to hitting this in practice. In theory, we + // can hit the same issue for any SCEV, or ValueTracking query done during + // mutation. See PR49900. + getOrCreateTripCount(OrigLoop->getLoopPreheader()); + createVectorLoopSkeleton(""); // Generate the code to check the minimum iteration count of the vector // epilogue (see below). EPI.EpilogueIterationCountCheck = - emitMinimumIterationCountCheck(Lp, LoopScalarPreHeader, true); + emitIterationCountCheck(LoopScalarPreHeader, true); EPI.EpilogueIterationCountCheck->setName("iter.check"); // Generate the code to check any assumptions that we've made for SCEV // expressions. - EPI.SCEVSafetyCheck = emitSCEVChecks(Lp, LoopScalarPreHeader); + EPI.SCEVSafetyCheck = emitSCEVChecks(LoopScalarPreHeader); // Generate the code that checks at runtime if arrays overlap. We put the // checks into a separate block to make the more common case of few elements // faster. 
- EPI.MemSafetyCheck = emitMemRuntimeChecks(Lp, LoopScalarPreHeader); + EPI.MemSafetyCheck = emitMemRuntimeChecks(LoopScalarPreHeader); // Generate the iteration count check for the main loop, *after* the check // for the epilogue loop, so that the path-length is shorter for the case @@ -8090,19 +7674,17 @@ EpilogueVectorizerMainLoop::createEpilogueVectorizedLoopSkeleton() { // trip count. Note: the branch will get updated later on when we vectorize // the epilogue. EPI.MainLoopIterationCountCheck = - emitMinimumIterationCountCheck(Lp, LoopScalarPreHeader, false); + emitIterationCountCheck(LoopScalarPreHeader, false); // Generate the induction variable. - Value *CountRoundDown = getOrCreateVectorTripCount(Lp); - EPI.VectorTripCount = CountRoundDown; - createHeaderBranch(Lp); + EPI.VectorTripCount = getOrCreateVectorTripCount(LoopVectorPreHeader); // Skip induction resume value creation here because they will be created in // the second pass. If we created them here, they wouldn't be used anyway, // because the vplan in the second pass still contains the inductions from the // original loop. - return {completeLoopSkeleton(Lp, OrigLoopID), nullptr}; + return {completeLoopSkeleton(OrigLoopID), nullptr}; } void EpilogueVectorizerMainLoop::printDebugTracesAtStart() { @@ -8122,13 +7704,13 @@ void EpilogueVectorizerMainLoop::printDebugTracesAtEnd() { }); } -BasicBlock *EpilogueVectorizerMainLoop::emitMinimumIterationCountCheck( - Loop *L, BasicBlock *Bypass, bool ForEpilogue) { - assert(L && "Expected valid Loop."); +BasicBlock * +EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass, + bool ForEpilogue) { assert(Bypass && "Expected valid bypass basic block."); ElementCount VFactor = ForEpilogue ? EPI.EpilogueVF : VF; unsigned UFactor = ForEpilogue ? EPI.EpilogueUF : UF; - Value *Count = getOrCreateTripCount(L); + Value *Count = getOrCreateTripCount(LoopVectorPreHeader); // Reuse existing vector loop preheader for TC checks. // Note that new preheader block is generated for vector loop. BasicBlock *const TCCheckBlock = LoopVectorPreHeader; @@ -8187,7 +7769,7 @@ BasicBlock *EpilogueVectorizerMainLoop::emitMinimumIterationCountCheck( std::pair<BasicBlock *, Value *> EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton() { MDNode *OrigLoopID = OrigLoop->getLoopID(); - Loop *Lp = createVectorLoopSkeleton("vec.epilog."); + createVectorLoopSkeleton("vec.epilog."); // Now, compare the remaining count and if there aren't enough iterations to // execute the vectorized epilogue skip to the scalar part. @@ -8196,7 +7778,7 @@ EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton() { LoopVectorPreHeader = SplitBlock(LoopVectorPreHeader, LoopVectorPreHeader->getTerminator(), DT, LI, nullptr, "vec.epilog.ph"); - emitMinimumVectorEpilogueIterCountCheck(Lp, LoopScalarPreHeader, + emitMinimumVectorEpilogueIterCountCheck(LoopScalarPreHeader, VecEpilogueIterationCountCheck); // Adjust the control flow taking the state info from the main loop @@ -8268,9 +7850,6 @@ EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton() { EPResumeVal->addIncoming(ConstantInt::get(IdxTy, 0), EPI.MainLoopIterationCountCheck); - // Generate the induction variable. - createHeaderBranch(Lp); - // Generate induction resume values. These variables save the new starting // indexes for the scalar loop. They are used to test if there are any tail // iterations left once the vector loop has completed. 
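// A rough standalone sketch (not upstream code) of the dispatch the two
// iteration-count checks above set up -- "iter.check" against the epilogue
// factors first, then the main-loop check -- ignoring tail folding and the
// remainder iterations that still run after the main vector loop. Names and
// the exact comparison direction are simplifying assumptions.
enum class LoopPath { Scalar, EpilogueVector, MainVector };

static LoopPath pickEntryPath(unsigned TripCount, unsigned MainVFxUF,
                              unsigned EpilogueVFxUF) {
  if (TripCount < EpilogueVFxUF) // iter.check: too few iterations for even
    return LoopPath::Scalar;     // one epilogue vector iteration.
  if (TripCount < MainVFxUF)     // main-loop minimum iteration check.
    return LoopPath::EpilogueVector;
  return LoopPath::MainVector;
}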
@@ -8278,15 +7857,15 @@ EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton() { // check, then the resume value for the induction variable comes from // the trip count of the main vector loop, hence passing the AdditionalBypass // argument. - createInductionResumeValues(Lp, {VecEpilogueIterationCountCheck, - EPI.VectorTripCount} /* AdditionalBypass */); + createInductionResumeValues({VecEpilogueIterationCountCheck, + EPI.VectorTripCount} /* AdditionalBypass */); - return {completeLoopSkeleton(Lp, OrigLoopID), EPResumeVal}; + return {completeLoopSkeleton(OrigLoopID), EPResumeVal}; } BasicBlock * EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck( - Loop *L, BasicBlock *Bypass, BasicBlock *Insert) { + BasicBlock *Bypass, BasicBlock *Insert) { assert(EPI.TripCount && "Expected trip count to have been safed in the first pass."); @@ -8427,7 +8006,8 @@ VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlanPtr &Plan) { // constructing the desired canonical IV in the header block as its first // non-phi instructions. assert(CM.foldTailByMasking() && "must fold the tail"); - VPBasicBlock *HeaderVPBB = Plan->getEntry()->getEntryBasicBlock(); + VPBasicBlock *HeaderVPBB = + Plan->getVectorLoopRegion()->getEntryBasicBlock(); auto NewInsertionPoint = HeaderVPBB->getFirstNonPhi(); auto *IV = new VPWidenCanonicalIVRecipe(Plan->getCanonicalIV()); HeaderVPBB->insert(IV, HeaderVPBB->getFirstNonPhi()); @@ -8469,8 +8049,6 @@ VPRecipeBase *VPRecipeBuilder::tryToWidenMemory(Instruction *I, "Must be called with either a load or store"); auto willWiden = [&](ElementCount VF) -> bool { - if (VF.isScalar()) - return false; LoopVectorizationCostModel::InstWidening Decision = CM.getWideningDecision(I, VF); assert(Decision != LoopVectorizationCostModel::CM_Unknown && @@ -8507,11 +8085,12 @@ VPRecipeBase *VPRecipeBuilder::tryToWidenMemory(Instruction *I, Mask, Consecutive, Reverse); } -static VPWidenIntOrFpInductionRecipe * -createWidenInductionRecipe(PHINode *Phi, Instruction *PhiOrTrunc, - VPValue *Start, const InductionDescriptor &IndDesc, - LoopVectorizationCostModel &CM, Loop &OrigLoop, - VFRange &Range) { +/// Creates a VPWidenIntOrFpInductionRecpipe for \p Phi. If needed, it will also +/// insert a recipe to expand the step for the induction recipe. +static VPWidenIntOrFpInductionRecipe *createWidenInductionRecipes( + PHINode *Phi, Instruction *PhiOrTrunc, VPValue *Start, + const InductionDescriptor &IndDesc, LoopVectorizationCostModel &CM, + VPlan &Plan, ScalarEvolution &SE, Loop &OrigLoop, VFRange &Range) { // Returns true if an instruction \p I should be scalarized instead of // vectorized for the chosen vectorization factor. auto ShouldScalarizeInstruction = [&CM](Instruction *I, ElementCount VF) { @@ -8519,18 +8098,6 @@ createWidenInductionRecipe(PHINode *Phi, Instruction *PhiOrTrunc, CM.isProfitableToScalarize(I, VF); }; - bool NeedsScalarIV = LoopVectorizationPlanner::getDecisionAndClampRange( - [&](ElementCount VF) { - // Returns true if we should generate a scalar version of \p IV. 
- if (ShouldScalarizeInstruction(PhiOrTrunc, VF)) - return true; - auto isScalarInst = [&](User *U) -> bool { - auto *I = cast<Instruction>(U); - return OrigLoop.contains(I) && ShouldScalarizeInstruction(I, VF); - }; - return any_of(PhiOrTrunc->users(), isScalarInst); - }, - Range); bool NeedsScalarIVOnly = LoopVectorizationPlanner::getDecisionAndClampRange( [&](ElementCount VF) { return ShouldScalarizeInstruction(PhiOrTrunc, VF); @@ -8538,30 +8105,38 @@ createWidenInductionRecipe(PHINode *Phi, Instruction *PhiOrTrunc, Range); assert(IndDesc.getStartValue() == Phi->getIncomingValueForBlock(OrigLoop.getLoopPreheader())); + assert(SE.isLoopInvariant(IndDesc.getStep(), &OrigLoop) && + "step must be loop invariant"); + + VPValue *Step = + vputils::getOrCreateVPValueForSCEVExpr(Plan, IndDesc.getStep(), SE); if (auto *TruncI = dyn_cast<TruncInst>(PhiOrTrunc)) { - return new VPWidenIntOrFpInductionRecipe(Phi, Start, IndDesc, TruncI, - NeedsScalarIV, !NeedsScalarIVOnly); + return new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, IndDesc, TruncI, + !NeedsScalarIVOnly); } assert(isa<PHINode>(PhiOrTrunc) && "must be a phi node here"); - return new VPWidenIntOrFpInductionRecipe(Phi, Start, IndDesc, NeedsScalarIV, + return new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, IndDesc, !NeedsScalarIVOnly); } -VPWidenIntOrFpInductionRecipe *VPRecipeBuilder::tryToOptimizeInductionPHI( - PHINode *Phi, ArrayRef<VPValue *> Operands, VFRange &Range) const { +VPRecipeBase *VPRecipeBuilder::tryToOptimizeInductionPHI( + PHINode *Phi, ArrayRef<VPValue *> Operands, VPlan &Plan, VFRange &Range) { // Check if this is an integer or fp induction. If so, build the recipe that // produces its scalar and vector values. if (auto *II = Legal->getIntOrFpInductionDescriptor(Phi)) - return createWidenInductionRecipe(Phi, Phi, Operands[0], *II, CM, *OrigLoop, - Range); + return createWidenInductionRecipes(Phi, Phi, Operands[0], *II, CM, Plan, + *PSE.getSE(), *OrigLoop, Range); + // Check if this is pointer induction. If so, build the recipe for it. + if (auto *II = Legal->getPointerInductionDescriptor(Phi)) + return new VPWidenPointerInductionRecipe(Phi, Operands[0], *II, + *PSE.getSE()); return nullptr; } VPWidenIntOrFpInductionRecipe *VPRecipeBuilder::tryToOptimizeInductionTruncate( - TruncInst *I, ArrayRef<VPValue *> Operands, VFRange &Range, - VPlan &Plan) const { + TruncInst *I, ArrayRef<VPValue *> Operands, VFRange &Range, VPlan &Plan) { // Optimize the special case where the source is a constant integer // induction variable. Notice that we can only optimize the 'trunc' case // because (a) FP conversions lose precision, (b) sext/zext may wrap, and @@ -8582,7 +8157,8 @@ VPWidenIntOrFpInductionRecipe *VPRecipeBuilder::tryToOptimizeInductionTruncate( auto *Phi = cast<PHINode>(I->getOperand(0)); const InductionDescriptor &II = *Legal->getIntOrFpInductionDescriptor(Phi); VPValue *Start = Plan.getOrAddVPValue(II.getStartValue()); - return createWidenInductionRecipe(Phi, I, Start, II, CM, *OrigLoop, Range); + return createWidenInductionRecipes(Phi, I, Start, II, CM, Plan, + *PSE.getSE(), *OrigLoop, Range); } return nullptr; } @@ -8599,13 +8175,30 @@ VPRecipeOrVPValueTy VPRecipeBuilder::tryToBlend(PHINode *Phi, return Operands[0]; } + unsigned NumIncoming = Phi->getNumIncomingValues(); + // For in-loop reductions, we do not need to create an additional select. 
+ VPValue *InLoopVal = nullptr; + for (unsigned In = 0; In < NumIncoming; In++) { + PHINode *PhiOp = + dyn_cast_or_null<PHINode>(Operands[In]->getUnderlyingValue()); + if (PhiOp && CM.isInLoopReduction(PhiOp)) { + assert(!InLoopVal && "Found more than one in-loop reduction!"); + InLoopVal = Operands[In]; + } + } + + assert((!InLoopVal || NumIncoming == 2) && + "Found an in-loop reduction for PHI with unexpected number of " + "incoming values"); + if (InLoopVal) + return Operands[Operands[0] == InLoopVal ? 1 : 0]; + // We know that all PHIs in non-header blocks are converted into selects, so // we don't have to worry about the insertion order and we can just use the // builder. At this point we generate the predication tree. There may be // duplications since this is a simple recursive scan, but future // optimizations will clean it up. SmallVector<VPValue *, 2> OperandsWithMask; - unsigned NumIncoming = Phi->getNumIncomingValues(); for (unsigned In = 0; In < NumIncoming; In++) { VPValue *EdgeMask = @@ -8711,6 +8304,7 @@ VPWidenRecipe *VPRecipeBuilder::tryToWiden(Instruction *I, case Instruction::URem: case Instruction::Xor: case Instruction::ZExt: + case Instruction::Freeze: return true; } return false; @@ -8836,14 +8430,14 @@ VPRegionBlock *VPRecipeBuilder::createReplicateRegion(Instruction *Instr, Plan->removeVPValueFor(Instr); Plan->addVPValue(Instr, PHIRecipe); } - auto *Exit = new VPBasicBlock(Twine(RegionName) + ".continue", PHIRecipe); + auto *Exiting = new VPBasicBlock(Twine(RegionName) + ".continue", PHIRecipe); auto *Pred = new VPBasicBlock(Twine(RegionName) + ".if", PredRecipe); - VPRegionBlock *Region = new VPRegionBlock(Entry, Exit, RegionName, true); + VPRegionBlock *Region = new VPRegionBlock(Entry, Exiting, RegionName, true); // Note: first set Entry as region entry and then connect successors starting // from it in order, to propagate the "parent" of each VPBasicBlock. - VPBlockUtils::insertTwoBlocksAfter(Pred, Exit, BlockInMask, Entry); - VPBlockUtils::connectBlocks(Pred, Exit); + VPBlockUtils::insertTwoBlocksAfter(Pred, Exiting, Entry); + VPBlockUtils::connectBlocks(Pred, Exiting); return Region; } @@ -8852,52 +8446,37 @@ VPRecipeOrVPValueTy VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, ArrayRef<VPValue *> Operands, VFRange &Range, VPlanPtr &Plan) { - // First, check for specific widening recipes that deal with calls, memory - // operations, inductions and Phi nodes. - if (auto *CI = dyn_cast<CallInst>(Instr)) - return toVPRecipeResult(tryToWidenCall(CI, Operands, Range)); - - if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr)) - return toVPRecipeResult(tryToWidenMemory(Instr, Operands, Range, Plan)); - + // First, check for specific widening recipes that deal with inductions, Phi + // nodes, calls and memory operations. 
VPRecipeBase *Recipe; if (auto Phi = dyn_cast<PHINode>(Instr)) { if (Phi->getParent() != OrigLoop->getHeader()) return tryToBlend(Phi, Operands, Plan); - if ((Recipe = tryToOptimizeInductionPHI(Phi, Operands, Range))) + if ((Recipe = tryToOptimizeInductionPHI(Phi, Operands, *Plan, Range))) return toVPRecipeResult(Recipe); VPHeaderPHIRecipe *PhiRecipe = nullptr; - if (Legal->isReductionVariable(Phi) || Legal->isFirstOrderRecurrence(Phi)) { - VPValue *StartV = Operands[0]; - if (Legal->isReductionVariable(Phi)) { - const RecurrenceDescriptor &RdxDesc = - Legal->getReductionVars().find(Phi)->second; - assert(RdxDesc.getRecurrenceStartValue() == - Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader())); - PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV, - CM.isInLoopReduction(Phi), - CM.useOrderedReductions(RdxDesc)); - } else { - PhiRecipe = new VPFirstOrderRecurrencePHIRecipe(Phi, *StartV); - } - - // Record the incoming value from the backedge, so we can add the incoming - // value from the backedge after all recipes have been created. - recordRecipeOf(cast<Instruction>( - Phi->getIncomingValueForBlock(OrigLoop->getLoopLatch()))); - PhisToFix.push_back(PhiRecipe); + assert((Legal->isReductionVariable(Phi) || + Legal->isFirstOrderRecurrence(Phi)) && + "can only widen reductions and first-order recurrences here"); + VPValue *StartV = Operands[0]; + if (Legal->isReductionVariable(Phi)) { + const RecurrenceDescriptor &RdxDesc = + Legal->getReductionVars().find(Phi)->second; + assert(RdxDesc.getRecurrenceStartValue() == + Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader())); + PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV, + CM.isInLoopReduction(Phi), + CM.useOrderedReductions(RdxDesc)); } else { - // TODO: record backedge value for remaining pointer induction phis. - assert(Phi->getType()->isPointerTy() && - "only pointer phis should be handled here"); - assert(Legal->getInductionVars().count(Phi) && - "Not an induction variable"); - InductionDescriptor II = Legal->getInductionVars().lookup(Phi); - VPValue *Start = Plan->getOrAddVPValue(II.getStartValue()); - PhiRecipe = new VPWidenPHIRecipe(Phi, Start); + PhiRecipe = new VPFirstOrderRecurrencePHIRecipe(Phi, *StartV); } + // Record the incoming value from the backedge, so we can add the incoming + // value from the backedge after all recipes have been created. + recordRecipeOf(cast<Instruction>( + Phi->getIncomingValueForBlock(OrigLoop->getLoopLatch()))); + PhisToFix.push_back(PhiRecipe); return toVPRecipeResult(PhiRecipe); } @@ -8906,6 +8485,17 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, Range, *Plan))) return toVPRecipeResult(Recipe); + // All widen recipes below deal only with VF > 1. + if (LoopVectorizationPlanner::getDecisionAndClampRange( + [&](ElementCount VF) { return VF.isScalar(); }, Range)) + return nullptr; + + if (auto *CI = dyn_cast<CallInst>(Instr)) + return toVPRecipeResult(tryToWidenCall(CI, Operands, Range)); + + if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr)) + return toVPRecipeResult(tryToWidenMemory(Instr, Operands, Range, Plan)); + if (!shouldWiden(Instr, Range)) return nullptr; @@ -8979,15 +8569,13 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, // CanonicalIVIncrement{NUW} VPInstruction to increment it by VF * UF and a // BranchOnCount VPInstruction to the latch. 
static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL, - bool HasNUW, bool IsVPlanNative) { + bool HasNUW) { Value *StartIdx = ConstantInt::get(IdxTy, 0); auto *StartV = Plan.getOrAddVPValue(StartIdx); auto *CanonicalIVPHI = new VPCanonicalIVPHIRecipe(StartV, DL); VPRegionBlock *TopRegion = Plan.getVectorLoopRegion(); VPBasicBlock *Header = TopRegion->getEntryBasicBlock(); - if (IsVPlanNative) - Header = cast<VPBasicBlock>(Header->getSingleSuccessor()); Header->insert(CanonicalIVPHI, Header->begin()); auto *CanonicalIVIncrement = @@ -8996,11 +8584,7 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL, {CanonicalIVPHI}, DL); CanonicalIVPHI->addOperand(CanonicalIVIncrement); - VPBasicBlock *EB = TopRegion->getExitBasicBlock(); - if (IsVPlanNative) { - EB = cast<VPBasicBlock>(EB->getSinglePredecessor()); - EB->setCondBit(nullptr); - } + VPBasicBlock *EB = TopRegion->getExitingBasicBlock(); EB->appendRecipe(CanonicalIVIncrement); auto *BranchOnCount = @@ -9009,6 +8593,26 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, DebugLoc DL, EB->appendRecipe(BranchOnCount); } +// Add exit values to \p Plan. VPLiveOuts are added for each LCSSA phi in the +// original exit block. +static void addUsersInExitBlock(VPBasicBlock *HeaderVPBB, + VPBasicBlock *MiddleVPBB, Loop *OrigLoop, + VPlan &Plan) { + BasicBlock *ExitBB = OrigLoop->getUniqueExitBlock(); + BasicBlock *ExitingBB = OrigLoop->getExitingBlock(); + // Only handle single-exit loops with unique exit blocks for now. + if (!ExitBB || !ExitBB->getSinglePredecessor() || !ExitingBB) + return; + + // Introduce VPUsers modeling the exit values. + for (PHINode &ExitPhi : ExitBB->phis()) { + Value *IncomingValue = + ExitPhi.getIncomingValueForBlock(ExitingBB); + VPValue *V = Plan.getOrAddVPValue(IncomingValue, true); + Plan.addLiveOut(&ExitPhi, V); + } +} + VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( VFRange &Range, SmallPtrSetImpl<Instruction *> &DeadInstructions, const MapVector<Instruction *, Instruction *> &SinkAfter) { @@ -9037,7 +8641,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( RecipeBuilder.recordRecipeOf(Phi); for (auto &R : ReductionOperations) { RecipeBuilder.recordRecipeOf(R); - // For min/max reducitons, where we have a pair of icmp/select, we also + // For min/max reductions, where we have a pair of icmp/select, we also // need to record the ICmp recipe, so it can be removed later. assert(!RecurrenceDescriptor::isSelectCmpRecurrenceKind(Kind) && "Only min/max recurrences allowed for inloop reductions"); @@ -9069,18 +8673,25 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( // visit each basic block after having visited its predecessor basic blocks. // --------------------------------------------------------------------------- - // Create initial VPlan skeleton, with separate header and latch blocks. - VPBasicBlock *HeaderVPBB = new VPBasicBlock(); + // Create initial VPlan skeleton, starting with a block for the pre-header, + // followed by a region for the vector loop, followed by the middle block. The + // skeleton vector loop region contains a header and latch block. 
+ VPBasicBlock *Preheader = new VPBasicBlock("vector.ph"); + auto Plan = std::make_unique<VPlan>(Preheader); + + VPBasicBlock *HeaderVPBB = new VPBasicBlock("vector.body"); VPBasicBlock *LatchVPBB = new VPBasicBlock("vector.latch"); VPBlockUtils::insertBlockAfter(LatchVPBB, HeaderVPBB); auto *TopRegion = new VPRegionBlock(HeaderVPBB, LatchVPBB, "vector loop"); - auto Plan = std::make_unique<VPlan>(TopRegion); + VPBlockUtils::insertBlockAfter(TopRegion, Preheader); + VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block"); + VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion); Instruction *DLInst = getDebugLocFromInstOrOperands(Legal->getPrimaryInduction()); addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), DLInst ? DLInst->getDebugLoc() : DebugLoc(), - !CM.foldTailByMasking(), false); + !CM.foldTailByMasking()); // Scan the body of the loop in a topological order to visit each basic block // after having visited its predecessor basic blocks. @@ -9093,11 +8704,12 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( // Relevant instructions from basic block BB will be grouped into VPRecipe // ingredients and fill a new VPBasicBlock. unsigned VPBBsForBB = 0; - VPBB->setName(BB->getName()); + if (VPBB != HeaderVPBB) + VPBB->setName(BB->getName()); Builder.setInsertPoint(VPBB); // Introduce each ingredient into VPlan. - // TODO: Model and preserve debug instrinsics in VPlan. + // TODO: Model and preserve debug intrinsics in VPlan. for (Instruction &I : BB->instructionsWithoutDebug()) { Instruction *Instr = &I; @@ -9115,6 +8727,14 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( auto OpRange = Plan->mapToVPValues(Instr->operands()); Operands = {OpRange.begin(), OpRange.end()}; } + + // Invariant stores inside loop will be deleted and a single store + // with the final reduction value will be added to the exit block + StoreInst *SI; + if ((SI = dyn_cast<StoreInst>(&I)) && + Legal->isInvariantAddressOfReduction(SI->getPointerOperand())) + continue; + if (auto RecipeOrValue = RecipeBuilder.tryToCreateWidenRecipe( Instr, Operands, Range, Plan)) { // If Instr can be simplified to an existing VPValue, use it. @@ -9165,14 +8785,18 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( VPBB = cast<VPBasicBlock>(VPBB->getSingleSuccessor()); } + HeaderVPBB->setName("vector.body"); + // Fold the last, empty block into its predecessor. VPBB = VPBlockUtils::tryToMergeBlockIntoPredecessor(VPBB); assert(VPBB && "expected to fold last (empty) block"); // After here, VPBB should not be used. VPBB = nullptr; - assert(isa<VPRegionBlock>(Plan->getEntry()) && - !Plan->getEntry()->getEntryBasicBlock()->empty() && + addUsersInExitBlock(HeaderVPBB, MiddleVPBB, OrigLoop, *Plan); + + assert(isa<VPRegionBlock>(Plan->getVectorLoopRegion()) && + !Plan->getVectorLoopRegion()->getEntryBasicBlock()->empty() && "entry block must be set to a VPRegionBlock having a non-empty entry " "VPBasicBlock"); RecipeBuilder.fixHeaderPhis(); @@ -9252,12 +8876,13 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( Ind->moveBefore(*HeaderVPBB, HeaderVPBB->getFirstNonPhi()); // Adjust the recipes for any inloop reductions. - adjustRecipesForReductions(cast<VPBasicBlock>(TopRegion->getExit()), Plan, + adjustRecipesForReductions(cast<VPBasicBlock>(TopRegion->getExiting()), Plan, RecipeBuilder, Range.Start); // Introduce a recipe to combine the incoming and previous values of a // first-order recurrence. 
- for (VPRecipeBase &R : Plan->getEntry()->getEntryBasicBlock()->phis()) { + for (VPRecipeBase &R : + Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) { auto *RecurPhi = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R); if (!RecurPhi) continue; @@ -9317,13 +8942,6 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( } } - // From this point onwards, VPlan-to-VPlan transformations may change the plan - // in ways that accessing values using original IR values is incorrect. - Plan->disableValue2VPValue(); - - VPlanTransforms::sinkScalarOperands(*Plan); - VPlanTransforms::mergeReplicateRegions(*Plan); - std::string PlanName; raw_string_ostream RSO(PlanName); ElementCount VF = Range.Start; @@ -9337,10 +8955,20 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes( RSO.flush(); Plan->setName(PlanName); + // From this point onwards, VPlan-to-VPlan transformations may change the plan + // in ways that accessing values using original IR values is incorrect. + Plan->disableValue2VPValue(); + + VPlanTransforms::optimizeInductions(*Plan, *PSE.getSE()); + VPlanTransforms::sinkScalarOperands(*Plan); + VPlanTransforms::mergeReplicateRegions(*Plan); + VPlanTransforms::removeDeadRecipes(*Plan); + VPlanTransforms::removeRedundantExpandSCEVRecipes(*Plan); + // Fold Exit block into its predecessor if possible. // TODO: Fold block earlier once all VPlan transforms properly maintain a // VPBasicBlock as exit. - VPBlockUtils::tryToMergeBlockIntoPredecessor(TopRegion->getExit()); + VPBlockUtils::tryToMergeBlockIntoPredecessor(TopRegion->getExiting()); assert(VPlanVerifier::verifyPlanIsValid(*Plan) && "VPlan is invalid"); return Plan; @@ -9365,23 +8993,20 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { VF *= 2) Plan->addVF(VF); - if (EnableVPlanPredication) { - VPlanPredicator VPP(*Plan); - VPP.predicate(); - - // Avoid running transformation to recipes until masked code generation in - // VPlan-native path is in place. - return Plan; - } - SmallPtrSet<Instruction *, 1> DeadInstructions; VPlanTransforms::VPInstructionsToVPRecipes( OrigLoop, Plan, [this](PHINode *P) { return Legal->getIntOrFpInductionDescriptor(P); }, DeadInstructions, *PSE.getSE()); + // Remove the existing terminator of the exiting block of the top-most region. + // A BranchOnCount will be added instead when adding the canonical IV recipes. + auto *Term = + Plan->getVectorLoopRegion()->getExitingBasicBlock()->getTerminator(); + Term->eraseFromParent(); + addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), DebugLoc(), - true, true); + true); return Plan; } @@ -9433,7 +9058,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( R->getOperand(FirstOpId) == Chain ? FirstOpId + 1 : FirstOpId; VPValue *VecOp = Plan->getVPValue(R->getOperand(VecOpId)); - auto *CondOp = CM.foldTailByMasking() + auto *CondOp = CM.blockNeedsPredicationForAnyReason(R->getParent()) ? RecipeBuilder.createBlockInMask(R->getParent(), Plan) : nullptr; @@ -9453,9 +9078,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( WidenRecipe->getVPSingleValue()->replaceAllUsesWith(RedRecipe); Plan->removeVPValueFor(R); Plan->addVPValue(R, RedRecipe); - // Append the recipe to the end of the VPBasicBlock because we need to - // ensure that it comes after all of it's inputs, including CondOp. 
- WidenRecipe->getParent()->appendRecipe(RedRecipe); + WidenRecipe->getParent()->insert(RedRecipe, WidenRecipe->getIterator()); WidenRecipe->getVPSingleValue()->replaceAllUsesWith(RedRecipe); WidenRecipe->eraseFromParent(); @@ -9477,7 +9100,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( // dedicated latch block. if (CM.foldTailByMasking()) { Builder.setInsertPoint(LatchVPBB, LatchVPBB->begin()); - for (VPRecipeBase &R : Plan->getEntry()->getEntryBasicBlock()->phis()) { + for (VPRecipeBase &R : + Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) { VPReductionPHIRecipe *PhiR = dyn_cast<VPReductionPHIRecipe>(&R); if (!PhiR || PhiR->isInLoop()) continue; @@ -9529,7 +9153,7 @@ void VPWidenCallRecipe::execute(VPTransformState &State) { void VPWidenSelectRecipe::execute(VPTransformState &State) { auto &I = *cast<SelectInst>(getUnderlyingInstr()); - State.ILV->setDebugLocFromInst(&I); + State.setDebugLocFromInst(&I); // The condition can be loop invariant but still defined inside the // loop. This means that we can't just use the original 'cond' value. @@ -9544,7 +9168,7 @@ void VPWidenSelectRecipe::execute(VPTransformState &State) { Value *Op1 = State.get(getOperand(2), Part); Value *Sel = State.Builder.CreateSelect(Cond, Op0, Op1); State.set(this, Sel, Part); - State.ILV->addMetadata(Sel, &I); + State.addMetadata(Sel, &I); } } @@ -9578,7 +9202,7 @@ void VPWidenRecipe::execute(VPTransformState &State) { case Instruction::Or: case Instruction::Xor: { // Just widen unops and binops. - State.ILV->setDebugLocFromInst(&I); + State.setDebugLocFromInst(&I); for (unsigned Part = 0; Part < State.UF; ++Part) { SmallVector<Value *, 2> Ops; @@ -9601,17 +9225,28 @@ void VPWidenRecipe::execute(VPTransformState &State) { // Use this vector value for all users of the original instruction. State.set(this, V, Part); - State.ILV->addMetadata(V, &I); + State.addMetadata(V, &I); } break; } + case Instruction::Freeze: { + State.setDebugLocFromInst(&I); + + for (unsigned Part = 0; Part < State.UF; ++Part) { + Value *Op = State.get(getOperand(0), Part); + + Value *Freeze = Builder.CreateFreeze(Op); + State.set(this, Freeze, Part); + } + break; + } case Instruction::ICmp: case Instruction::FCmp: { // Widen compares. Generate vector compares. bool FCmp = (I.getOpcode() == Instruction::FCmp); auto *Cmp = cast<CmpInst>(&I); - State.ILV->setDebugLocFromInst(Cmp); + State.setDebugLocFromInst(Cmp); for (unsigned Part = 0; Part < State.UF; ++Part) { Value *A = State.get(getOperand(0), Part); Value *B = State.get(getOperand(1), Part); @@ -9625,7 +9260,7 @@ void VPWidenRecipe::execute(VPTransformState &State) { C = Builder.CreateICmp(Cmp->getPredicate(), A, B); } State.set(this, C, Part); - State.ILV->addMetadata(C, &I); + State.addMetadata(C, &I); } break; @@ -9644,7 +9279,7 @@ void VPWidenRecipe::execute(VPTransformState &State) { case Instruction::FPTrunc: case Instruction::BitCast: { auto *CI = cast<CastInst>(&I); - State.ILV->setDebugLocFromInst(CI); + State.setDebugLocFromInst(CI); /// Vectorize casts. 
Type *DestTy = (State.VF.isScalar()) @@ -9655,7 +9290,7 @@ void VPWidenRecipe::execute(VPTransformState &State) { Value *A = State.get(getOperand(0), Part); Value *Cast = Builder.CreateCast(CI->getOpcode(), A, DestTy); State.set(this, Cast, Part); - State.ILV->addMetadata(Cast, &I); + State.addMetadata(Cast, &I); } break; } @@ -9691,7 +9326,7 @@ void VPWidenGEPRecipe::execute(VPTransformState &State) { for (unsigned Part = 0; Part < State.UF; ++Part) { Value *EntryPart = State.Builder.CreateVectorSplat(State.VF, Clone); State.set(this, EntryPart, Part); - State.ILV->addMetadata(EntryPart, GEP); + State.addMetadata(EntryPart, GEP); } } else { // If the GEP has at least one loop-varying operand, we are sure to @@ -9729,32 +9364,276 @@ void VPWidenGEPRecipe::execute(VPTransformState &State) { // Create the new GEP. Note that this GEP may be a scalar if VF == 1, // but it should be a vector, otherwise. - auto *NewGEP = IsInBounds - ? State.Builder.CreateInBoundsGEP( - GEP->getSourceElementType(), Ptr, Indices) - : State.Builder.CreateGEP(GEP->getSourceElementType(), - Ptr, Indices); + auto *NewGEP = State.Builder.CreateGEP(GEP->getSourceElementType(), Ptr, + Indices, "", IsInBounds); assert((State.VF.isScalar() || NewGEP->getType()->isVectorTy()) && "NewGEP is not a pointer vector"); State.set(this, NewGEP, Part); - State.ILV->addMetadata(NewGEP, GEP); + State.addMetadata(NewGEP, GEP); } } } void VPWidenIntOrFpInductionRecipe::execute(VPTransformState &State) { assert(!State.Instance && "Int or FP induction being replicated."); - auto *CanonicalIV = State.get(getParent()->getPlan()->getCanonicalIV(), 0); - State.ILV->widenIntOrFpInduction(IV, this, State, CanonicalIV); + + Value *Start = getStartValue()->getLiveInIRValue(); + const InductionDescriptor &ID = getInductionDescriptor(); + TruncInst *Trunc = getTruncInst(); + IRBuilderBase &Builder = State.Builder; + assert(IV->getType() == ID.getStartValue()->getType() && "Types must match"); + assert(State.VF.isVector() && "must have vector VF"); + + // The value from the original loop to which we are mapping the new induction + // variable. + Instruction *EntryVal = Trunc ? cast<Instruction>(Trunc) : IV; + + // Fast-math-flags propagate from the original induction instruction. + IRBuilder<>::FastMathFlagGuard FMFG(Builder); + if (ID.getInductionBinOp() && isa<FPMathOperator>(ID.getInductionBinOp())) + Builder.setFastMathFlags(ID.getInductionBinOp()->getFastMathFlags()); + + // Now do the actual transformations, and start with fetching the step value. 
+ Value *Step = State.get(getStepValue(), VPIteration(0, 0)); + + assert((isa<PHINode>(EntryVal) || isa<TruncInst>(EntryVal)) && + "Expected either an induction phi-node or a truncate of it!"); + + // Construct the initial value of the vector IV in the vector loop preheader + auto CurrIP = Builder.saveIP(); + BasicBlock *VectorPH = State.CFG.getPreheaderBBFor(this); + Builder.SetInsertPoint(VectorPH->getTerminator()); + if (isa<TruncInst>(EntryVal)) { + assert(Start->getType()->isIntegerTy() && + "Truncation requires an integer type"); + auto *TruncType = cast<IntegerType>(EntryVal->getType()); + Step = Builder.CreateTrunc(Step, TruncType); + Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType); + } + + Value *Zero = getSignedIntOrFpConstant(Start->getType(), 0); + Value *SplatStart = Builder.CreateVectorSplat(State.VF, Start); + Value *SteppedStart = getStepVector( + SplatStart, Zero, Step, ID.getInductionOpcode(), State.VF, State.Builder); + + // We create vector phi nodes for both integer and floating-point induction + // variables. Here, we determine the kind of arithmetic we will perform. + Instruction::BinaryOps AddOp; + Instruction::BinaryOps MulOp; + if (Step->getType()->isIntegerTy()) { + AddOp = Instruction::Add; + MulOp = Instruction::Mul; + } else { + AddOp = ID.getInductionOpcode(); + MulOp = Instruction::FMul; + } + + // Multiply the vectorization factor by the step using integer or + // floating-point arithmetic as appropriate. + Type *StepType = Step->getType(); + Value *RuntimeVF; + if (Step->getType()->isFloatingPointTy()) + RuntimeVF = getRuntimeVFAsFloat(Builder, StepType, State.VF); + else + RuntimeVF = getRuntimeVF(Builder, StepType, State.VF); + Value *Mul = Builder.CreateBinOp(MulOp, Step, RuntimeVF); + + // Create a vector splat to use in the induction update. + // + // FIXME: If the step is non-constant, we create the vector splat with + // IRBuilder. IRBuilder can constant-fold the multiply, but it doesn't + // handle a constant vector splat. + Value *SplatVF = isa<Constant>(Mul) + ? ConstantVector::getSplat(State.VF, cast<Constant>(Mul)) + : Builder.CreateVectorSplat(State.VF, Mul); + Builder.restoreIP(CurrIP); + + // We may need to add the step a number of times, depending on the unroll + // factor. The last of those goes into the PHI. + PHINode *VecInd = PHINode::Create(SteppedStart->getType(), 2, "vec.ind", + &*State.CFG.PrevBB->getFirstInsertionPt()); + VecInd->setDebugLoc(EntryVal->getDebugLoc()); + Instruction *LastInduction = VecInd; + for (unsigned Part = 0; Part < State.UF; ++Part) { + State.set(this, LastInduction, Part); + + if (isa<TruncInst>(EntryVal)) + State.addMetadata(LastInduction, EntryVal); + + LastInduction = cast<Instruction>( + Builder.CreateBinOp(AddOp, LastInduction, SplatVF, "step.add")); + LastInduction->setDebugLoc(EntryVal->getDebugLoc()); + } + + LastInduction->setName("vec.ind.next"); + VecInd->addIncoming(SteppedStart, VectorPH); + // Add induction update using an incorrect block temporarily. The phi node + // will be fixed after VPlan execution. Note that at this point the latch + // block cannot be used, as it does not exist yet. + // TODO: Model increment value in VPlan, by turning the recipe into a + // multi-def and a subclass of VPHeaderPHIRecipe. 
+ VecInd->addIncoming(LastInduction, VectorPH); +} + +void VPWidenPointerInductionRecipe::execute(VPTransformState &State) { + assert(IndDesc.getKind() == InductionDescriptor::IK_PtrInduction && + "Not a pointer induction according to InductionDescriptor!"); + assert(cast<PHINode>(getUnderlyingInstr())->getType()->isPointerTy() && + "Unexpected type."); + + auto *IVR = getParent()->getPlan()->getCanonicalIV(); + PHINode *CanonicalIV = cast<PHINode>(State.get(IVR, 0)); + + if (onlyScalarsGenerated(State.VF)) { + // This is the normalized GEP that starts counting at zero. + Value *PtrInd = State.Builder.CreateSExtOrTrunc( + CanonicalIV, IndDesc.getStep()->getType()); + // Determine the number of scalars we need to generate for each unroll + // iteration. If the instruction is uniform, we only need to generate the + // first lane. Otherwise, we generate all VF values. + bool IsUniform = vputils::onlyFirstLaneUsed(this); + assert((IsUniform || !State.VF.isScalable()) && + "Cannot scalarize a scalable VF"); + unsigned Lanes = IsUniform ? 1 : State.VF.getFixedValue(); + + for (unsigned Part = 0; Part < State.UF; ++Part) { + Value *PartStart = + createStepForVF(State.Builder, PtrInd->getType(), State.VF, Part); + + for (unsigned Lane = 0; Lane < Lanes; ++Lane) { + Value *Idx = State.Builder.CreateAdd( + PartStart, ConstantInt::get(PtrInd->getType(), Lane)); + Value *GlobalIdx = State.Builder.CreateAdd(PtrInd, Idx); + + Value *Step = CreateStepValue(IndDesc.getStep(), SE, + State.CFG.PrevBB->getTerminator()); + Value *SclrGep = emitTransformedIndex( + State.Builder, GlobalIdx, IndDesc.getStartValue(), Step, IndDesc); + SclrGep->setName("next.gep"); + State.set(this, SclrGep, VPIteration(Part, Lane)); + } + } + return; + } + + assert(isa<SCEVConstant>(IndDesc.getStep()) && + "Induction step not a SCEV constant!"); + Type *PhiType = IndDesc.getStep()->getType(); + + // Build a pointer phi + Value *ScalarStartValue = getStartValue()->getLiveInIRValue(); + Type *ScStValueType = ScalarStartValue->getType(); + PHINode *NewPointerPhi = + PHINode::Create(ScStValueType, 2, "pointer.phi", CanonicalIV); + + BasicBlock *VectorPH = State.CFG.getPreheaderBBFor(this); + NewPointerPhi->addIncoming(ScalarStartValue, VectorPH); + + // A pointer induction, performed by using a gep + const DataLayout &DL = NewPointerPhi->getModule()->getDataLayout(); + Instruction *InductionLoc = &*State.Builder.GetInsertPoint(); + + const SCEV *ScalarStep = IndDesc.getStep(); + SCEVExpander Exp(SE, DL, "induction"); + Value *ScalarStepValue = Exp.expandCodeFor(ScalarStep, PhiType, InductionLoc); + Value *RuntimeVF = getRuntimeVF(State.Builder, PhiType, State.VF); + Value *NumUnrolledElems = + State.Builder.CreateMul(RuntimeVF, ConstantInt::get(PhiType, State.UF)); + Value *InductionGEP = GetElementPtrInst::Create( + IndDesc.getElementType(), NewPointerPhi, + State.Builder.CreateMul(ScalarStepValue, NumUnrolledElems), "ptr.ind", + InductionLoc); + // Add induction update using an incorrect block temporarily. The phi node + // will be fixed after VPlan execution. Note that at this point the latch + // block cannot be used, as it does not exist yet. + // TODO: Model increment value in VPlan, by turning the recipe into a + // multi-def and a subclass of VPHeaderPHIRecipe. + NewPointerPhi->addIncoming(InductionGEP, VectorPH); + + // Create UF many actual address geps that use the pointer + // phi as base and a vectorized version of the step value + // (<step*0, ..., step*N>) as offset. 
+ for (unsigned Part = 0; Part < State.UF; ++Part) { + Type *VecPhiType = VectorType::get(PhiType, State.VF); + Value *StartOffsetScalar = + State.Builder.CreateMul(RuntimeVF, ConstantInt::get(PhiType, Part)); + Value *StartOffset = + State.Builder.CreateVectorSplat(State.VF, StartOffsetScalar); + // Create a vector of consecutive numbers from zero to VF. + StartOffset = State.Builder.CreateAdd( + StartOffset, State.Builder.CreateStepVector(VecPhiType)); + + Value *GEP = State.Builder.CreateGEP( + IndDesc.getElementType(), NewPointerPhi, + State.Builder.CreateMul( + StartOffset, + State.Builder.CreateVectorSplat(State.VF, ScalarStepValue), + "vector.gep")); + State.set(this, GEP, Part); + } } -void VPWidenPHIRecipe::execute(VPTransformState &State) { - State.ILV->widenPHIInstruction(cast<PHINode>(getUnderlyingValue()), this, - State); +void VPScalarIVStepsRecipe::execute(VPTransformState &State) { + assert(!State.Instance && "VPScalarIVStepsRecipe being replicated."); + + // Fast-math-flags propagate from the original induction instruction. + IRBuilder<>::FastMathFlagGuard FMFG(State.Builder); + if (IndDesc.getInductionBinOp() && + isa<FPMathOperator>(IndDesc.getInductionBinOp())) + State.Builder.setFastMathFlags( + IndDesc.getInductionBinOp()->getFastMathFlags()); + + Value *Step = State.get(getStepValue(), VPIteration(0, 0)); + auto CreateScalarIV = [&](Value *&Step) -> Value * { + Value *ScalarIV = State.get(getCanonicalIV(), VPIteration(0, 0)); + auto *CanonicalIV = State.get(getParent()->getPlan()->getCanonicalIV(), 0); + if (!isCanonical() || CanonicalIV->getType() != Ty) { + ScalarIV = + Ty->isIntegerTy() + ? State.Builder.CreateSExtOrTrunc(ScalarIV, Ty) + : State.Builder.CreateCast(Instruction::SIToFP, ScalarIV, Ty); + ScalarIV = emitTransformedIndex(State.Builder, ScalarIV, + getStartValue()->getLiveInIRValue(), Step, + IndDesc); + ScalarIV->setName("offset.idx"); + } + if (TruncToTy) { + assert(Step->getType()->isIntegerTy() && + "Truncation requires an integer step"); + ScalarIV = State.Builder.CreateTrunc(ScalarIV, TruncToTy); + Step = State.Builder.CreateTrunc(Step, TruncToTy); + } + return ScalarIV; + }; + + Value *ScalarIV = CreateScalarIV(Step); + if (State.VF.isVector()) { + buildScalarSteps(ScalarIV, Step, IndDesc, this, State); + return; + } + + for (unsigned Part = 0; Part < State.UF; ++Part) { + assert(!State.VF.isScalable() && "scalable vectors not yet supported."); + Value *EntryPart; + if (Step->getType()->isFloatingPointTy()) { + Value *StartIdx = + getRuntimeVFAsFloat(State.Builder, Step->getType(), State.VF * Part); + // Floating-point operations inherit FMF via the builder's flags. + Value *MulOp = State.Builder.CreateFMul(StartIdx, Step); + EntryPart = State.Builder.CreateBinOp(IndDesc.getInductionOpcode(), + ScalarIV, MulOp); + } else { + Value *StartIdx = + getRuntimeVF(State.Builder, Step->getType(), State.VF * Part); + EntryPart = State.Builder.CreateAdd( + ScalarIV, State.Builder.CreateMul(StartIdx, Step), "induction"); + } + State.set(this, EntryPart, Part); + } } void VPBlendRecipe::execute(VPTransformState &State) { - State.ILV->setDebugLocFromInst(Phi, &State.Builder); + State.setDebugLocFromInst(Phi); // We know that all PHIs in non-header blocks are converted into // selects, so we don't have to worry about the insertion order and we // can just use the builder. 
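// A standalone sketch (not upstream code) of the per-part offsets produced by
// the widened pointer induction above, assuming a fixed VF (so RuntimeVF is
// simply VF): part P covers lanes P*VF+0 .. P*VF+VF-1, each scaled by the
// scalar step before being used as the GEP index.
#include <cstdint>
#include <vector>

static std::vector<int64_t> pointerInductionOffsets(unsigned VF, unsigned Part,
                                                    int64_t Step) {
  std::vector<int64_t> Offsets(VF);
  for (unsigned Lane = 0; Lane < VF; ++Lane)
    Offsets[Lane] = (int64_t(Part) * VF + Lane) * Step; // StartOffset * Step
  return Offsets;
}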
@@ -10015,7 +9894,7 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { // Handle Stores: if (SI) { - State.ILV->setDebugLocFromInst(SI); + State.setDebugLocFromInst(SI); for (unsigned Part = 0; Part < State.UF; ++Part) { Instruction *NewSI = nullptr; @@ -10041,14 +9920,14 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { else NewSI = Builder.CreateAlignedStore(StoredVal, VecPtr, Alignment); } - State.ILV->addMetadata(NewSI, SI); + State.addMetadata(NewSI, SI); } return; } // Handle loads. assert(LI && "Must have a load instruction"); - State.ILV->setDebugLocFromInst(LI); + State.setDebugLocFromInst(LI); for (unsigned Part = 0; Part < State.UF; ++Part) { Value *NewLI; if (CreateGatherScatter) { @@ -10056,7 +9935,7 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { Value *VectorGep = State.get(getAddr(), Part); NewLI = Builder.CreateMaskedGather(DataTy, VectorGep, Alignment, MaskPart, nullptr, "wide.masked.gather"); - State.ILV->addMetadata(NewLI, LI); + State.addMetadata(NewLI, LI); } else { auto *VecPtr = CreateVecPtr(Part, State.get(getAddr(), VPIteration(0, 0))); @@ -10069,12 +9948,12 @@ void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { Builder.CreateAlignedLoad(DataTy, VecPtr, Alignment, "wide.load"); // Add metadata to the load, but setVectorValue to the reverse shuffle. - State.ILV->addMetadata(NewLI, LI); + State.addMetadata(NewLI, LI); if (Reverse) NewLI = Builder.CreateVectorReverse(NewLI, "reverse"); } - State.set(this, NewLI, Part); + State.set(getVPSingleValue(), NewLI, Part); } } @@ -10155,7 +10034,8 @@ Value *VPTransformState::get(VPValue *Def, unsigned Part) { // Check if there is a scalar value for the selected lane. if (!hasScalarValue(Def, {Part, LastLane})) { // At the moment, VPWidenIntOrFpInductionRecipes can also be uniform. - assert(isa<VPWidenIntOrFpInductionRecipe>(Def->getDef()) && + assert((isa<VPWidenIntOrFpInductionRecipe>(Def->getDef()) || + isa<VPScalarIVStepsRecipe>(Def->getDef())) && "unexpected recipe found to be invariant"); IsUniform = true; LastLane = 0; @@ -10237,8 +10117,7 @@ static bool processLoopInVPlanNativePath( // If we are stress testing VPlan builds, do not attempt to generate vector // code. Masked vector code generation support will follow soon. // Also, do not attempt to vectorize if no vector code will be produced. - if (VPlanBuildStressTest || EnableVPlanPredication || - VectorizationFactor::Disabled() == VF) + if (VPlanBuildStressTest || VectorizationFactor::Disabled() == VF) return false; VPlan &BestPlan = LVP.getBestPlanFor(VF.Width); @@ -10250,7 +10129,7 @@ static bool processLoopInVPlanNativePath( &CM, BFI, PSI, Checks); LLVM_DEBUG(dbgs() << "Vectorizing outer loop in \"" << L->getHeader()->getParent()->getName() << "\"\n"); - LVP.executePlan(VF.Width, 1, BestPlan, LB, DT); + LVP.executePlan(VF.Width, 1, BestPlan, LB, DT, false); } // Mark the loop as already vectorized to avoid vectorizing again. 
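// A standalone sketch (not upstream code) of how the reverse consecutive case
// above behaves: one contiguous wide load followed by a lane reversal
// (CreateVectorReverse), rather than a per-element gather. It assumes the
// pointer handed in already refers to the lowest-addressed element of the
// group, which is a simplifying assumption of this sketch.
#include <algorithm>
#include <vector>

static std::vector<int> loadReversed(const int *LowestAddr, unsigned VF) {
  std::vector<int> Wide(LowestAddr, LowestAddr + VF); // wide.load
  std::reverse(Wide.begin(), Wide.end());             // "reverse" shuffle
  return Wide;
}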
@@ -10318,8 +10197,8 @@ bool LoopVectorizePass::processLoop(Loop *L) { const std::string DebugLocStr = getDebugLocString(L); #endif /* NDEBUG */ - LLVM_DEBUG(dbgs() << "\nLV: Checking a loop in \"" - << L->getHeader()->getParent()->getName() << "\" from " + LLVM_DEBUG(dbgs() << "\nLV: Checking a loop in '" + << L->getHeader()->getParent()->getName() << "' from " << DebugLocStr << "\n"); LoopVectorizeHints Hints(L, InterleaveOnlyWhenForced, *ORE, TTI); @@ -10474,10 +10353,30 @@ bool LoopVectorizePass::processLoop(Loop *L) { VectorizationFactor VF = VectorizationFactor::Disabled(); unsigned IC = 1; + GeneratedRTChecks Checks(*PSE.getSE(), DT, LI, + F->getParent()->getDataLayout()); if (MaybeVF) { + if (LVP.requiresTooManyRuntimeChecks()) { + ORE->emit([&]() { + return OptimizationRemarkAnalysisAliasing( + DEBUG_TYPE, "CantReorderMemOps", L->getStartLoc(), + L->getHeader()) + << "loop not vectorized: cannot prove it is safe to reorder " + "memory operations"; + }); + LLVM_DEBUG(dbgs() << "LV: Too many memory checks needed.\n"); + Hints.emitRemarkWithHints(); + return false; + } VF = *MaybeVF; // Select the interleave count. IC = CM.selectInterleaveCount(VF.Width, *VF.Cost.getValue()); + + unsigned SelectedIC = std::max(IC, UserIC); + // Optimistically generate runtime checks if they are needed. Drop them if + // they turn out to not be profitable. + if (VF.Width.isVector() || SelectedIC > 1) + Checks.Create(L, *LVL.getLAI(), PSE.getPredicate(), VF.Width, SelectedIC); } // Identify the diagnostic messages that should be produced. @@ -10565,14 +10464,6 @@ bool LoopVectorizePass::processLoop(Loop *L) { bool DisableRuntimeUnroll = false; MDNode *OrigLoopID = L->getLoopID(); { - // Optimistically generate runtime checks. Drop them if they turn out to not - // be profitable. Limit the scope of Checks, so the cleanup happens - // immediately after vector codegeneration is done. - GeneratedRTChecks Checks(*PSE.getSE(), DT, LI, - F->getParent()->getDataLayout()); - if (!VF.Width.isScalar() || IC > 1) - Checks.Create(L, *LVL.getLAI(), PSE.getUnionPredicate()); - using namespace ore; if (!VectorizeLoop) { assert(IC > 1 && "interleave count should not be 1 or 0"); @@ -10582,7 +10473,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { &CM, BFI, PSI, Checks); VPlan &BestPlan = LVP.getBestPlanFor(VF.Width); - LVP.executePlan(VF.Width, IC, BestPlan, Unroller, DT); + LVP.executePlan(VF.Width, IC, BestPlan, Unroller, DT, false); ORE->emit([&]() { return OptimizationRemark(LV_NAME, "Interleaved", L->getStartLoc(), @@ -10607,12 +10498,9 @@ bool LoopVectorizePass::processLoop(Loop *L) { VPlan &BestMainPlan = LVP.getBestPlanFor(EPI.MainLoopVF); LVP.executePlan(EPI.MainLoopVF, EPI.MainLoopUF, BestMainPlan, MainILV, - DT); + DT, true); ++LoopsVectorized; - simplifyLoop(L, DT, LI, SE, AC, nullptr, false /* PreserveLCSSA */); - formLCSSARecursively(*L, *DT, LI, SE); - // Second pass vectorizes the epilogue and adjusts the control flow // edges from the first pass. EPI.MainLoopVF = EPI.EpilogueVF; @@ -10622,23 +10510,24 @@ bool LoopVectorizePass::processLoop(Loop *L) { Checks); VPlan &BestEpiPlan = LVP.getBestPlanFor(EPI.EpilogueVF); + VPRegionBlock *VectorLoop = BestEpiPlan.getVectorLoopRegion(); + VPBasicBlock *Header = VectorLoop->getEntryBasicBlock(); + Header->setName("vec.epilog.vector.body"); // Ensure that the start values for any VPReductionPHIRecipes are // updated before vectorising the epilogue loop. 
- VPBasicBlock *Header = BestEpiPlan.getEntry()->getEntryBasicBlock(); for (VPRecipeBase &R : Header->phis()) { if (auto *ReductionPhi = dyn_cast<VPReductionPHIRecipe>(&R)) { if (auto *Resume = MainILV.getReductionResumeValue( ReductionPhi->getRecurrenceDescriptor())) { - VPValue *StartVal = new VPValue(Resume); - BestEpiPlan.addExternalDef(StartVal); + VPValue *StartVal = BestEpiPlan.getOrAddExternalDef(Resume); ReductionPhi->setOperand(0, StartVal); } } } LVP.executePlan(EPI.EpilogueVF, EPI.EpilogueUF, BestEpiPlan, EpilogILV, - DT); + DT, true); ++LoopsEpilogueVectorized; if (!MainILV.areSafetyChecksAdded()) @@ -10648,7 +10537,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { &LVL, &CM, BFI, PSI, Checks); VPlan &BestPlan = LVP.getBestPlanFor(VF.Width); - LVP.executePlan(VF.Width, IC, BestPlan, LB, DT); + LVP.executePlan(VF.Width, IC, BestPlan, LB, DT, false); ++LoopsVectorized; // Add metadata to disable runtime unrolling a scalar loop when there @@ -10674,7 +10563,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { Optional<MDNode *> RemainderLoopID = makeFollowupLoopID(OrigLoopID, {LLVMLoopVectorizeFollowupAll, LLVMLoopVectorizeFollowupEpilogue}); - if (RemainderLoopID.hasValue()) { + if (RemainderLoopID) { L->setLoopID(RemainderLoopID.getValue()); } else { if (DisableRuntimeUnroll) @@ -10756,8 +10645,12 @@ LoopVectorizeResult LoopVectorizePass::runImpl( PreservedAnalyses LoopVectorizePass::run(Function &F, FunctionAnalysisManager &AM) { - auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); auto &LI = AM.getResult<LoopAnalysis>(F); + // There are no loops in the function. Return before computing other expensive + // analyses. + if (LI.empty()) + return PreservedAnalyses::all(); + auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); auto &TTI = AM.getResult<TargetIRAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &BFI = AM.getResult<BlockFrequencyAnalysis>(F); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 644372483edd..019a09665a67 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -53,7 +53,6 @@ #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugLoc.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" @@ -64,7 +63,6 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" -#include "llvm/IR/NoFolder.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" @@ -72,8 +70,9 @@ #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" +#ifdef EXPENSIVE_CHECKS #include "llvm/IR/Verifier.h" -#include "llvm/InitializePasses.h" +#endif #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" @@ -87,6 +86,7 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/InjectTLIMappings.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Vectorize.h" #include <algorithm> @@ -164,13 +164,14 @@ static cl::opt<int> LookAheadMaxDepth( "slp-max-look-ahead-depth", cl::init(2), cl::Hidden, cl::desc("The maximum look-ahead depth for operand reordering scores")); -// The Look-ahead heuristic goes through 
the users of the bundle to calculate -// the users cost in getExternalUsesCost(). To avoid compilation time increase -// we limit the number of users visited to this value. -static cl::opt<unsigned> LookAheadUsersBudget( - "slp-look-ahead-users-budget", cl::init(2), cl::Hidden, - cl::desc("The maximum number of users to visit while visiting the " - "predecessors. This prevents compilation time increase.")); +// The maximum depth that the look-ahead score heuristic will explore +// when it probing among candidates for vectorization tree roots. +// The higher this value, the higher the compilation time overhead but unlike +// similar limit for operands ordering this is less frequently used, hence +// impact of higher value is less noticeable. +static cl::opt<int> RootLookAheadMaxDepth( + "slp-max-root-look-ahead-depth", cl::init(2), cl::Hidden, + cl::desc("The maximum look-ahead depth for searching best rooting option")); static cl::opt<bool> ViewSLPTree("view-slp-tree", cl::Hidden, @@ -471,17 +472,36 @@ static bool isValidForAlternation(unsigned Opcode) { return true; } +static InstructionsState getSameOpcode(ArrayRef<Value *> VL, + unsigned BaseIndex = 0); + +/// Checks if the provided operands of 2 cmp instructions are compatible, i.e. +/// compatible instructions or constants, or just some other regular values. +static bool areCompatibleCmpOps(Value *BaseOp0, Value *BaseOp1, Value *Op0, + Value *Op1) { + return (isConstant(BaseOp0) && isConstant(Op0)) || + (isConstant(BaseOp1) && isConstant(Op1)) || + (!isa<Instruction>(BaseOp0) && !isa<Instruction>(Op0) && + !isa<Instruction>(BaseOp1) && !isa<Instruction>(Op1)) || + getSameOpcode({BaseOp0, Op0}).getOpcode() || + getSameOpcode({BaseOp1, Op1}).getOpcode(); +} + /// \returns analysis of the Instructions in \p VL described in /// InstructionsState, the Opcode that we suppose the whole list /// could be vectorized even if its structure is diverse. static InstructionsState getSameOpcode(ArrayRef<Value *> VL, - unsigned BaseIndex = 0) { + unsigned BaseIndex) { // Make sure these are all Instructions. if (llvm::any_of(VL, [](Value *V) { return !isa<Instruction>(V); })) return InstructionsState(VL[BaseIndex], nullptr, nullptr); bool IsCastOp = isa<CastInst>(VL[BaseIndex]); bool IsBinOp = isa<BinaryOperator>(VL[BaseIndex]); + bool IsCmpOp = isa<CmpInst>(VL[BaseIndex]); + CmpInst::Predicate BasePred = + IsCmpOp ? cast<CmpInst>(VL[BaseIndex])->getPredicate() + : CmpInst::BAD_ICMP_PREDICATE; unsigned Opcode = cast<Instruction>(VL[BaseIndex])->getOpcode(); unsigned AltOpcode = Opcode; unsigned AltIndex = BaseIndex; @@ -514,6 +534,57 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL, continue; } } + } else if (IsCmpOp && isa<CmpInst>(VL[Cnt])) { + auto *BaseInst = cast<Instruction>(VL[BaseIndex]); + auto *Inst = cast<Instruction>(VL[Cnt]); + Type *Ty0 = BaseInst->getOperand(0)->getType(); + Type *Ty1 = Inst->getOperand(0)->getType(); + if (Ty0 == Ty1) { + Value *BaseOp0 = BaseInst->getOperand(0); + Value *BaseOp1 = BaseInst->getOperand(1); + Value *Op0 = Inst->getOperand(0); + Value *Op1 = Inst->getOperand(1); + CmpInst::Predicate CurrentPred = + cast<CmpInst>(VL[Cnt])->getPredicate(); + CmpInst::Predicate SwappedCurrentPred = + CmpInst::getSwappedPredicate(CurrentPred); + // Check for compatible operands. If the corresponding operands are not + // compatible - need to perform alternate vectorization. 
+ if (InstOpcode == Opcode) { + if (BasePred == CurrentPred && + areCompatibleCmpOps(BaseOp0, BaseOp1, Op0, Op1)) + continue; + if (BasePred == SwappedCurrentPred && + areCompatibleCmpOps(BaseOp0, BaseOp1, Op1, Op0)) + continue; + if (E == 2 && + (BasePred == CurrentPred || BasePred == SwappedCurrentPred)) + continue; + auto *AltInst = cast<CmpInst>(VL[AltIndex]); + CmpInst::Predicate AltPred = AltInst->getPredicate(); + Value *AltOp0 = AltInst->getOperand(0); + Value *AltOp1 = AltInst->getOperand(1); + // Check if operands are compatible with alternate operands. + if (AltPred == CurrentPred && + areCompatibleCmpOps(AltOp0, AltOp1, Op0, Op1)) + continue; + if (AltPred == SwappedCurrentPred && + areCompatibleCmpOps(AltOp0, AltOp1, Op1, Op0)) + continue; + } + if (BaseIndex == AltIndex && BasePred != CurrentPred) { + assert(isValidForAlternation(Opcode) && + isValidForAlternation(InstOpcode) && + "Cast isn't safe for alternation, logic needs to be updated!"); + AltIndex = Cnt; + continue; + } + auto *AltInst = cast<CmpInst>(VL[AltIndex]); + CmpInst::Predicate AltPred = AltInst->getPredicate(); + if (BasePred == CurrentPred || BasePred == SwappedCurrentPred || + AltPred == CurrentPred || AltPred == SwappedCurrentPred) + continue; + } } else if (InstOpcode == Opcode || InstOpcode == AltOpcode) continue; return InstructionsState(VL[BaseIndex], nullptr, nullptr); @@ -570,7 +641,7 @@ static bool InTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst, CallInst *CI = cast<CallInst>(UserInst); Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); for (unsigned i = 0, e = CI->arg_size(); i != e; ++i) { - if (hasVectorInstrinsicScalarOpd(ID, i)) + if (isVectorIntrinsicWithScalarOpAtArg(ID, i)) return (CI->getArgOperand(i) == Scalar); } LLVM_FALLTHROUGH; @@ -666,11 +737,11 @@ static void inversePermutation(ArrayRef<unsigned> Indices, /// \returns inserting index of InsertElement or InsertValue instruction, /// using Offset as base offset for index. -static Optional<unsigned> getInsertIndex(Value *InsertInst, +static Optional<unsigned> getInsertIndex(const Value *InsertInst, unsigned Offset = 0) { int Index = Offset; - if (auto *IE = dyn_cast<InsertElementInst>(InsertInst)) { - if (auto *CI = dyn_cast<ConstantInt>(IE->getOperand(2))) { + if (const auto *IE = dyn_cast<InsertElementInst>(InsertInst)) { + if (const auto *CI = dyn_cast<ConstantInt>(IE->getOperand(2))) { auto *VT = cast<FixedVectorType>(IE->getType()); if (CI->getValue().uge(VT->getNumElements())) return None; @@ -681,13 +752,13 @@ static Optional<unsigned> getInsertIndex(Value *InsertInst, return None; } - auto *IV = cast<InsertValueInst>(InsertInst); + const auto *IV = cast<InsertValueInst>(InsertInst); Type *CurrentType = IV->getType(); for (unsigned I : IV->indices()) { - if (auto *ST = dyn_cast<StructType>(CurrentType)) { + if (const auto *ST = dyn_cast<StructType>(CurrentType)) { Index *= ST->getNumElements(); CurrentType = ST->getElementType(I); - } else if (auto *AT = dyn_cast<ArrayType>(CurrentType)) { + } else if (const auto *AT = dyn_cast<ArrayType>(CurrentType)) { Index *= AT->getNumElements(); CurrentType = AT->getElementType(); } else { @@ -698,11 +769,7 @@ static Optional<unsigned> getInsertIndex(Value *InsertInst, return Index; } -/// Reorders the list of scalars in accordance with the given \p Order and then -/// the \p Mask. \p Order - is the original order of the scalars, need to -/// reorder scalars into an unordered state at first according to the given -/// order. 
Then the ordered scalars are shuffled once again in accordance with -/// the provided mask. +/// Reorders the list of scalars in accordance with the given \p Mask. static void reorderScalars(SmallVectorImpl<Value *> &Scalars, ArrayRef<int> Mask) { assert(!Mask.empty() && "Expected non-empty mask."); @@ -714,6 +781,58 @@ static void reorderScalars(SmallVectorImpl<Value *> &Scalars, Scalars[Mask[I]] = Prev[I]; } +/// Checks if the provided value does not require scheduling. It does not +/// require scheduling if this is not an instruction or it is an instruction +/// that does not read/write memory and all operands are either not instructions +/// or phi nodes or instructions from different blocks. +static bool areAllOperandsNonInsts(Value *V) { + auto *I = dyn_cast<Instruction>(V); + if (!I) + return true; + return !mayHaveNonDefUseDependency(*I) && + all_of(I->operands(), [I](Value *V) { + auto *IO = dyn_cast<Instruction>(V); + if (!IO) + return true; + return isa<PHINode>(IO) || IO->getParent() != I->getParent(); + }); +} + +/// Checks if the provided value does not require scheduling. It does not +/// require scheduling if this is not an instruction or it is an instruction +/// that does not read/write memory and all users are phi nodes or instructions +/// from the different blocks. +static bool isUsedOutsideBlock(Value *V) { + auto *I = dyn_cast<Instruction>(V); + if (!I) + return true; + // Limits the number of uses to save compile time. + constexpr int UsesLimit = 8; + return !I->mayReadOrWriteMemory() && !I->hasNUsesOrMore(UsesLimit) && + all_of(I->users(), [I](User *U) { + auto *IU = dyn_cast<Instruction>(U); + if (!IU) + return true; + return IU->getParent() != I->getParent() || isa<PHINode>(IU); + }); +} + +/// Checks if the specified value does not require scheduling. It does not +/// require scheduling if all operands and all users do not need to be scheduled +/// in the current basic block. +static bool doesNotNeedToBeScheduled(Value *V) { + return areAllOperandsNonInsts(V) && isUsedOutsideBlock(V); +} + +/// Checks if the specified array of instructions does not require scheduling. +/// It is so if all either instructions have operands that do not require +/// scheduling or their users do not require scheduling since they are phis or +/// in other basic blocks. +static bool doesNotNeedToSchedule(ArrayRef<Value *> VL) { + return !VL.empty() && + (all_of(VL, isUsedOutsideBlock) || all_of(VL, areAllOperandsNonInsts)); +} + namespace slpvectorizer { /// Bottom Up SLP Vectorizer. @@ -734,8 +853,8 @@ public: TargetLibraryInfo *TLi, AAResults *Aa, LoopInfo *Li, DominatorTree *Dt, AssumptionCache *AC, DemandedBits *DB, const DataLayout *DL, OptimizationRemarkEmitter *ORE) - : F(Func), SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), AC(AC), - DB(DB), DL(DL), ORE(ORE), Builder(Se->getContext()) { + : BatchAA(*Aa), F(Func), SE(Se), TTI(Tti), TLI(TLi), LI(Li), + DT(Dt), AC(AC), DB(DB), DL(DL), ORE(ORE), Builder(Se->getContext()) { CodeMetrics::collectEphemeralValues(F, AC, EphValues); // Use the vector register size specified by the target unless overridden // by a command-line option. @@ -776,7 +895,10 @@ public: /// Construct a vectorizable tree that starts at \p Roots, ignoring users for /// the purpose of scheduling and extraction in the \p UserIgnoreLst. void buildTree(ArrayRef<Value *> Roots, - ArrayRef<Value *> UserIgnoreLst = None); + const SmallDenseSet<Value *> &UserIgnoreLst); + + /// Construct a vectorizable tree that starts at \p Roots. 
+ void buildTree(ArrayRef<Value *> Roots); /// Builds external uses of the vectorized scalars, i.e. the list of /// vectorized scalars to be extracted, their lanes and their scalar users. \p @@ -797,6 +919,7 @@ public: } MinBWs.clear(); InstrElementSize.clear(); + UserIgnoreList = nullptr; } unsigned getTreeSize() const { return VectorizableTree.size(); } @@ -810,6 +933,9 @@ public: /// ExtractElement, ExtractValue), which can be part of the graph. Optional<OrdersType> findReusedOrderedScalars(const TreeEntry &TE); + /// Sort loads into increasing pointers offsets to allow greater clustering. + Optional<OrdersType> findPartiallyOrderedLoads(const TreeEntry &TE); + /// Gets reordering data for the given tree entry. If the entry is vectorized /// - just return ReorderIndices, otherwise check if the scalars can be /// reordered and return the most optimal order. @@ -924,96 +1050,18 @@ public: #endif }; - /// A helper data structure to hold the operands of a vector of instructions. - /// This supports a fixed vector length for all operand vectors. - class VLOperands { - /// For each operand we need (i) the value, and (ii) the opcode that it - /// would be attached to if the expression was in a left-linearized form. - /// This is required to avoid illegal operand reordering. - /// For example: - /// \verbatim - /// 0 Op1 - /// |/ - /// Op1 Op2 Linearized + Op2 - /// \ / ----------> |/ - /// - - - /// - /// Op1 - Op2 (0 + Op1) - Op2 - /// \endverbatim - /// - /// Value Op1 is attached to a '+' operation, and Op2 to a '-'. - /// - /// Another way to think of this is to track all the operations across the - /// path from the operand all the way to the root of the tree and to - /// calculate the operation that corresponds to this path. For example, the - /// path from Op2 to the root crosses the RHS of the '-', therefore the - /// corresponding operation is a '-' (which matches the one in the - /// linearized tree, as shown above). - /// - /// For lack of a better term, we refer to this operation as Accumulated - /// Path Operation (APO). - struct OperandData { - OperandData() = default; - OperandData(Value *V, bool APO, bool IsUsed) - : V(V), APO(APO), IsUsed(IsUsed) {} - /// The operand value. - Value *V = nullptr; - /// TreeEntries only allow a single opcode, or an alternate sequence of - /// them (e.g, +, -). Therefore, we can safely use a boolean value for the - /// APO. It is set to 'true' if 'V' is attached to an inverse operation - /// in the left-linearized form (e.g., Sub/Div), and 'false' otherwise - /// (e.g., Add/Mul) - bool APO = false; - /// Helper data for the reordering function. - bool IsUsed = false; - }; - - /// During operand reordering, we are trying to select the operand at lane - /// that matches best with the operand at the neighboring lane. Our - /// selection is based on the type of value we are looking for. For example, - /// if the neighboring lane has a load, we need to look for a load that is - /// accessing a consecutive address. These strategies are summarized in the - /// 'ReorderingMode' enumerator. - enum class ReorderingMode { - Load, ///< Matching loads to consecutive memory addresses - Opcode, ///< Matching instructions based on opcode (same or alternate) - Constant, ///< Matching constants - Splat, ///< Matching the same instruction multiple times (broadcast) - Failed, ///< We failed to create a vectorizable group - }; - - using OperandDataVec = SmallVector<OperandData, 2>; - - /// A vector of operand vectors. 
- SmallVector<OperandDataVec, 4> OpsVec; - + /// A helper class used for scoring candidates for two consecutive lanes. + class LookAheadHeuristics { const DataLayout &DL; ScalarEvolution &SE; const BoUpSLP &R; + int NumLanes; // Total number of lanes (aka vectorization factor). + int MaxLevel; // The maximum recursion depth for accumulating score. - /// \returns the operand data at \p OpIdx and \p Lane. - OperandData &getData(unsigned OpIdx, unsigned Lane) { - return OpsVec[OpIdx][Lane]; - } - - /// \returns the operand data at \p OpIdx and \p Lane. Const version. - const OperandData &getData(unsigned OpIdx, unsigned Lane) const { - return OpsVec[OpIdx][Lane]; - } - - /// Clears the used flag for all entries. - void clearUsed() { - for (unsigned OpIdx = 0, NumOperands = getNumOperands(); - OpIdx != NumOperands; ++OpIdx) - for (unsigned Lane = 0, NumLanes = getNumLanes(); Lane != NumLanes; - ++Lane) - OpsVec[OpIdx][Lane].IsUsed = false; - } - - /// Swap the operand at \p OpIdx1 with that one at \p OpIdx2. - void swap(unsigned OpIdx1, unsigned OpIdx2, unsigned Lane) { - std::swap(OpsVec[OpIdx1][Lane], OpsVec[OpIdx2][Lane]); - } + public: + LookAheadHeuristics(const DataLayout &DL, ScalarEvolution &SE, + const BoUpSLP &R, int NumLanes, int MaxLevel) + : DL(DL), SE(SE), R(R), NumLanes(NumLanes), MaxLevel(MaxLevel) {} // The hard-coded scores listed here are not very important, though it shall // be higher for better matches to improve the resulting cost. When @@ -1028,6 +1076,11 @@ public: /// Loads from consecutive memory addresses, e.g. load(A[i]), load(A[i+1]). static const int ScoreConsecutiveLoads = 4; + /// The same load multiple times. This should have a better score than + /// `ScoreSplat` because it in x86 for a 2-lane vector we can represent it + /// with `movddup (%reg), xmm0` which has a throughput of 0.5 versus 0.5 for + /// a vector load and 1.0 for a broadcast. + static const int ScoreSplatLoads = 3; /// Loads from reversed memory addresses, e.g. load(A[i+1]), load(A[i]). static const int ScoreReversedLoads = 3; /// ExtractElementInst from same vector and consecutive indexes. @@ -1046,43 +1099,67 @@ public: static const int ScoreUndef = 1; /// Score for failing to find a decent match. static const int ScoreFail = 0; - /// User exteranl to the vectorized code. - static const int ExternalUseCost = 1; - /// The user is internal but in a different lane. - static const int UserInDiffLaneCost = ExternalUseCost; + /// Score if all users are vectorized. + static const int ScoreAllUserVectorized = 1; /// \returns the score of placing \p V1 and \p V2 in consecutive lanes. - static int getShallowScore(Value *V1, Value *V2, const DataLayout &DL, - ScalarEvolution &SE, int NumLanes) { - if (V1 == V2) - return VLOperands::ScoreSplat; + /// \p U1 and \p U2 are the users of \p V1 and \p V2. + /// Also, checks if \p V1 and \p V2 are compatible with instructions in \p + /// MainAltOps. + int getShallowScore(Value *V1, Value *V2, Instruction *U1, Instruction *U2, + ArrayRef<Value *> MainAltOps) const { + if (V1 == V2) { + if (isa<LoadInst>(V1)) { + // Retruns true if the users of V1 and V2 won't need to be extracted. + auto AllUsersAreInternal = [U1, U2, this](Value *V1, Value *V2) { + // Bail out if we have too many uses to save compilation time. 
+ static constexpr unsigned Limit = 8; + if (V1->hasNUsesOrMore(Limit) || V2->hasNUsesOrMore(Limit)) + return false; + + auto AllUsersVectorized = [U1, U2, this](Value *V) { + return llvm::all_of(V->users(), [U1, U2, this](Value *U) { + return U == U1 || U == U2 || R.getTreeEntry(U) != nullptr; + }); + }; + return AllUsersVectorized(V1) && AllUsersVectorized(V2); + }; + // A broadcast of a load can be cheaper on some targets. + if (R.TTI->isLegalBroadcastLoad(V1->getType(), + ElementCount::getFixed(NumLanes)) && + ((int)V1->getNumUses() == NumLanes || + AllUsersAreInternal(V1, V2))) + return LookAheadHeuristics::ScoreSplatLoads; + } + return LookAheadHeuristics::ScoreSplat; + } auto *LI1 = dyn_cast<LoadInst>(V1); auto *LI2 = dyn_cast<LoadInst>(V2); if (LI1 && LI2) { if (LI1->getParent() != LI2->getParent()) - return VLOperands::ScoreFail; + return LookAheadHeuristics::ScoreFail; Optional<int> Dist = getPointersDiff( LI1->getType(), LI1->getPointerOperand(), LI2->getType(), LI2->getPointerOperand(), DL, SE, /*StrictCheck=*/true); - if (!Dist) - return VLOperands::ScoreFail; + if (!Dist || *Dist == 0) + return LookAheadHeuristics::ScoreFail; // The distance is too large - still may be profitable to use masked // loads/gathers. if (std::abs(*Dist) > NumLanes / 2) - return VLOperands::ScoreAltOpcodes; + return LookAheadHeuristics::ScoreAltOpcodes; // This still will detect consecutive loads, but we might have "holes" // in some cases. It is ok for non-power-2 vectorization and may produce // better results. It should not affect current vectorization. - return (*Dist > 0) ? VLOperands::ScoreConsecutiveLoads - : VLOperands::ScoreReversedLoads; + return (*Dist > 0) ? LookAheadHeuristics::ScoreConsecutiveLoads + : LookAheadHeuristics::ScoreReversedLoads; } auto *C1 = dyn_cast<Constant>(V1); auto *C2 = dyn_cast<Constant>(V2); if (C1 && C2) - return VLOperands::ScoreConstants; + return LookAheadHeuristics::ScoreConstants; // Extracts from consecutive indexes of the same vector better score as // the extracts could be optimized away. @@ -1091,7 +1168,7 @@ public: if (match(V1, m_ExtractElt(m_Value(EV1), m_ConstantInt(Ex1Idx)))) { // Undefs are always profitable for extractelements. if (isa<UndefValue>(V2)) - return VLOperands::ScoreConsecutiveExtracts; + return LookAheadHeuristics::ScoreConsecutiveExtracts; Value *EV2 = nullptr; ConstantInt *Ex2Idx = nullptr; if (match(V2, @@ -1099,108 +1176,62 @@ public: m_Undef())))) { // Undefs are always profitable for extractelements. if (!Ex2Idx) - return VLOperands::ScoreConsecutiveExtracts; + return LookAheadHeuristics::ScoreConsecutiveExtracts; if (isUndefVector(EV2) && EV2->getType() == EV1->getType()) - return VLOperands::ScoreConsecutiveExtracts; + return LookAheadHeuristics::ScoreConsecutiveExtracts; if (EV2 == EV1) { int Idx1 = Ex1Idx->getZExtValue(); int Idx2 = Ex2Idx->getZExtValue(); int Dist = Idx2 - Idx1; // The distance is too large - still may be profitable to use // shuffles. + if (std::abs(Dist) == 0) + return LookAheadHeuristics::ScoreSplat; if (std::abs(Dist) > NumLanes / 2) - return VLOperands::ScoreAltOpcodes; - return (Dist > 0) ? VLOperands::ScoreConsecutiveExtracts - : VLOperands::ScoreReversedExtracts; + return LookAheadHeuristics::ScoreSameOpcode; + return (Dist > 0) ? 
LookAheadHeuristics::ScoreConsecutiveExtracts + : LookAheadHeuristics::ScoreReversedExtracts; } + return LookAheadHeuristics::ScoreAltOpcodes; } + return LookAheadHeuristics::ScoreFail; } auto *I1 = dyn_cast<Instruction>(V1); auto *I2 = dyn_cast<Instruction>(V2); if (I1 && I2) { if (I1->getParent() != I2->getParent()) - return VLOperands::ScoreFail; - InstructionsState S = getSameOpcode({I1, I2}); + return LookAheadHeuristics::ScoreFail; + SmallVector<Value *, 4> Ops(MainAltOps.begin(), MainAltOps.end()); + Ops.push_back(I1); + Ops.push_back(I2); + InstructionsState S = getSameOpcode(Ops); // Note: Only consider instructions with <= 2 operands to avoid // complexity explosion. - if (S.getOpcode() && S.MainOp->getNumOperands() <= 2) - return S.isAltShuffle() ? VLOperands::ScoreAltOpcodes - : VLOperands::ScoreSameOpcode; + if (S.getOpcode() && + (S.MainOp->getNumOperands() <= 2 || !MainAltOps.empty() || + !S.isAltShuffle()) && + all_of(Ops, [&S](Value *V) { + return cast<Instruction>(V)->getNumOperands() == + S.MainOp->getNumOperands(); + })) + return S.isAltShuffle() ? LookAheadHeuristics::ScoreAltOpcodes + : LookAheadHeuristics::ScoreSameOpcode; } if (isa<UndefValue>(V2)) - return VLOperands::ScoreUndef; - - return VLOperands::ScoreFail; - } - - /// Holds the values and their lanes that are taking part in the look-ahead - /// score calculation. This is used in the external uses cost calculation. - /// Need to hold all the lanes in case of splat/broadcast at least to - /// correctly check for the use in the different lane. - SmallDenseMap<Value *, SmallSet<int, 4>> InLookAheadValues; - - /// \returns the additional cost due to uses of \p LHS and \p RHS that are - /// either external to the vectorized code, or require shuffling. - int getExternalUsesCost(const std::pair<Value *, int> &LHS, - const std::pair<Value *, int> &RHS) { - int Cost = 0; - std::array<std::pair<Value *, int>, 2> Values = {{LHS, RHS}}; - for (int Idx = 0, IdxE = Values.size(); Idx != IdxE; ++Idx) { - Value *V = Values[Idx].first; - if (isa<Constant>(V)) { - // Since this is a function pass, it doesn't make semantic sense to - // walk the users of a subclass of Constant. The users could be in - // another function, or even another module that happens to be in - // the same LLVMContext. - continue; - } + return LookAheadHeuristics::ScoreUndef; - // Calculate the absolute lane, using the minimum relative lane of LHS - // and RHS as base and Idx as the offset. - int Ln = std::min(LHS.second, RHS.second) + Idx; - assert(Ln >= 0 && "Bad lane calculation"); - unsigned UsersBudget = LookAheadUsersBudget; - for (User *U : V->users()) { - if (const TreeEntry *UserTE = R.getTreeEntry(U)) { - // The user is in the VectorizableTree. Check if we need to insert. - int UserLn = UserTE->findLaneForValue(U); - assert(UserLn >= 0 && "Bad lane"); - // If the values are different, check just the line of the current - // value. If the values are the same, need to add UserInDiffLaneCost - // only if UserLn does not match both line numbers. - if ((LHS.first != RHS.first && UserLn != Ln) || - (LHS.first == RHS.first && UserLn != LHS.second && - UserLn != RHS.second)) { - Cost += UserInDiffLaneCost; - break; - } - } else { - // Check if the user is in the look-ahead code. - auto It2 = InLookAheadValues.find(U); - if (It2 != InLookAheadValues.end()) { - // The user is in the look-ahead code. Check the lane. 
- if (!It2->getSecond().contains(Ln)) { - Cost += UserInDiffLaneCost; - break; - } - } else { - // The user is neither in SLP tree nor in the look-ahead code. - Cost += ExternalUseCost; - break; - } - } - // Limit the number of visited uses to cap compilation time. - if (--UsersBudget == 0) - break; - } - } - return Cost; + return LookAheadHeuristics::ScoreFail; } - /// Go through the operands of \p LHS and \p RHS recursively until \p - /// MaxLevel, and return the cummulative score. For example: + /// Go through the operands of \p LHS and \p RHS recursively until + /// MaxLevel, and return the cummulative score. \p U1 and \p U2 are + /// the users of \p LHS and \p RHS (that is \p LHS and \p RHS are operands + /// of \p U1 and \p U2), except at the beginning of the recursion where + /// these are set to nullptr. + /// + /// For example: /// \verbatim /// A[0] B[0] A[1] B[1] C[0] D[0] B[1] A[1] /// \ / \ / \ / \ / @@ -1211,8 +1242,8 @@ public: /// each level recursively, accumulating the score. It starts from matching /// the additions at level 0, then moves on to the loads (level 1). The /// score of G1 and G2 is higher than G1 and G3, because {A[0],A[1]} and - /// {B[0],B[1]} match with VLOperands::ScoreConsecutiveLoads, while - /// {A[0],C[0]} has a score of VLOperands::ScoreFail. + /// {B[0],B[1]} match with LookAheadHeuristics::ScoreConsecutiveLoads, while + /// {A[0],C[0]} has a score of LookAheadHeuristics::ScoreFail. /// Please note that the order of the operands does not matter, as we /// evaluate the score of all profitable combinations of operands. In /// other words the score of G1 and G4 is the same as G1 and G2. This @@ -1220,18 +1251,13 @@ public: /// Look-ahead SLP: Auto-vectorization in the presence of commutative /// operations, CGO 2018 by Vasileios Porpodas, Rodrigo C. O. Rocha, /// LuÃs F. W. Góes - int getScoreAtLevelRec(const std::pair<Value *, int> &LHS, - const std::pair<Value *, int> &RHS, int CurrLevel, - int MaxLevel) { + int getScoreAtLevelRec(Value *LHS, Value *RHS, Instruction *U1, + Instruction *U2, int CurrLevel, + ArrayRef<Value *> MainAltOps) const { - Value *V1 = LHS.first; - Value *V2 = RHS.first; // Get the shallow score of V1 and V2. - int ShallowScoreAtThisLevel = std::max( - (int)ScoreFail, getShallowScore(V1, V2, DL, SE, getNumLanes()) - - getExternalUsesCost(LHS, RHS)); - int Lane1 = LHS.second; - int Lane2 = RHS.second; + int ShallowScoreAtThisLevel = + getShallowScore(LHS, RHS, U1, U2, MainAltOps); // If reached MaxLevel, // or if V1 and V2 are not instructions, @@ -1239,20 +1265,17 @@ public: // or if they are not consecutive, // or if profitable to vectorize loads or extractelements, early return // the current cost. - auto *I1 = dyn_cast<Instruction>(V1); - auto *I2 = dyn_cast<Instruction>(V2); + auto *I1 = dyn_cast<Instruction>(LHS); + auto *I2 = dyn_cast<Instruction>(RHS); if (CurrLevel == MaxLevel || !(I1 && I2) || I1 == I2 || - ShallowScoreAtThisLevel == VLOperands::ScoreFail || + ShallowScoreAtThisLevel == LookAheadHeuristics::ScoreFail || (((isa<LoadInst>(I1) && isa<LoadInst>(I2)) || + (I1->getNumOperands() > 2 && I2->getNumOperands() > 2) || (isa<ExtractElementInst>(I1) && isa<ExtractElementInst>(I2))) && ShallowScoreAtThisLevel)) return ShallowScoreAtThisLevel; assert(I1 && I2 && "Should have early exited."); - // Keep track of in-tree values for determining the external-use cost. - InLookAheadValues[V1].insert(Lane1); - InLookAheadValues[V2].insert(Lane2); - // Contains the I2 operand indexes that got matched with I1 operands. 
SmallSet<unsigned, 4> Op2Used; @@ -1275,11 +1298,12 @@ public: if (Op2Used.count(OpIdx2)) continue; // Recursively calculate the cost at each level - int TmpScore = getScoreAtLevelRec({I1->getOperand(OpIdx1), Lane1}, - {I2->getOperand(OpIdx2), Lane2}, - CurrLevel + 1, MaxLevel); + int TmpScore = + getScoreAtLevelRec(I1->getOperand(OpIdx1), I2->getOperand(OpIdx2), + I1, I2, CurrLevel + 1, None); // Look for the best score. - if (TmpScore > VLOperands::ScoreFail && TmpScore > MaxTmpScore) { + if (TmpScore > LookAheadHeuristics::ScoreFail && + TmpScore > MaxTmpScore) { MaxTmpScore = TmpScore; MaxOpIdx2 = OpIdx2; FoundBest = true; @@ -1293,24 +1317,213 @@ public: } return ShallowScoreAtThisLevel; } + }; + /// A helper data structure to hold the operands of a vector of instructions. + /// This supports a fixed vector length for all operand vectors. + class VLOperands { + /// For each operand we need (i) the value, and (ii) the opcode that it + /// would be attached to if the expression was in a left-linearized form. + /// This is required to avoid illegal operand reordering. + /// For example: + /// \verbatim + /// 0 Op1 + /// |/ + /// Op1 Op2 Linearized + Op2 + /// \ / ----------> |/ + /// - - + /// + /// Op1 - Op2 (0 + Op1) - Op2 + /// \endverbatim + /// + /// Value Op1 is attached to a '+' operation, and Op2 to a '-'. + /// + /// Another way to think of this is to track all the operations across the + /// path from the operand all the way to the root of the tree and to + /// calculate the operation that corresponds to this path. For example, the + /// path from Op2 to the root crosses the RHS of the '-', therefore the + /// corresponding operation is a '-' (which matches the one in the + /// linearized tree, as shown above). + /// + /// For lack of a better term, we refer to this operation as Accumulated + /// Path Operation (APO). + struct OperandData { + OperandData() = default; + OperandData(Value *V, bool APO, bool IsUsed) + : V(V), APO(APO), IsUsed(IsUsed) {} + /// The operand value. + Value *V = nullptr; + /// TreeEntries only allow a single opcode, or an alternate sequence of + /// them (e.g, +, -). Therefore, we can safely use a boolean value for the + /// APO. It is set to 'true' if 'V' is attached to an inverse operation + /// in the left-linearized form (e.g., Sub/Div), and 'false' otherwise + /// (e.g., Add/Mul) + bool APO = false; + /// Helper data for the reordering function. + bool IsUsed = false; + }; + + /// During operand reordering, we are trying to select the operand at lane + /// that matches best with the operand at the neighboring lane. Our + /// selection is based on the type of value we are looking for. For example, + /// if the neighboring lane has a load, we need to look for a load that is + /// accessing a consecutive address. These strategies are summarized in the + /// 'ReorderingMode' enumerator. + enum class ReorderingMode { + Load, ///< Matching loads to consecutive memory addresses + Opcode, ///< Matching instructions based on opcode (same or alternate) + Constant, ///< Matching constants + Splat, ///< Matching the same instruction multiple times (broadcast) + Failed, ///< We failed to create a vectorizable group + }; + + using OperandDataVec = SmallVector<OperandData, 2>; + + /// A vector of operand vectors. + SmallVector<OperandDataVec, 4> OpsVec; + + const DataLayout &DL; + ScalarEvolution &SE; + const BoUpSLP &R; + + /// \returns the operand data at \p OpIdx and \p Lane. 
+ OperandData &getData(unsigned OpIdx, unsigned Lane) { + return OpsVec[OpIdx][Lane]; + } + + /// \returns the operand data at \p OpIdx and \p Lane. Const version. + const OperandData &getData(unsigned OpIdx, unsigned Lane) const { + return OpsVec[OpIdx][Lane]; + } + + /// Clears the used flag for all entries. + void clearUsed() { + for (unsigned OpIdx = 0, NumOperands = getNumOperands(); + OpIdx != NumOperands; ++OpIdx) + for (unsigned Lane = 0, NumLanes = getNumLanes(); Lane != NumLanes; + ++Lane) + OpsVec[OpIdx][Lane].IsUsed = false; + } + + /// Swap the operand at \p OpIdx1 with that one at \p OpIdx2. + void swap(unsigned OpIdx1, unsigned OpIdx2, unsigned Lane) { + std::swap(OpsVec[OpIdx1][Lane], OpsVec[OpIdx2][Lane]); + } + + /// \param Lane lane of the operands under analysis. + /// \param OpIdx operand index in \p Lane lane we're looking the best + /// candidate for. + /// \param Idx operand index of the current candidate value. + /// \returns The additional score due to possible broadcasting of the + /// elements in the lane. It is more profitable to have power-of-2 unique + /// elements in the lane, it will be vectorized with higher probability + /// after removing duplicates. Currently the SLP vectorizer supports only + /// vectorization of the power-of-2 number of unique scalars. + int getSplatScore(unsigned Lane, unsigned OpIdx, unsigned Idx) const { + Value *IdxLaneV = getData(Idx, Lane).V; + if (!isa<Instruction>(IdxLaneV) || IdxLaneV == getData(OpIdx, Lane).V) + return 0; + SmallPtrSet<Value *, 4> Uniques; + for (unsigned Ln = 0, E = getNumLanes(); Ln < E; ++Ln) { + if (Ln == Lane) + continue; + Value *OpIdxLnV = getData(OpIdx, Ln).V; + if (!isa<Instruction>(OpIdxLnV)) + return 0; + Uniques.insert(OpIdxLnV); + } + int UniquesCount = Uniques.size(); + int UniquesCntWithIdxLaneV = + Uniques.contains(IdxLaneV) ? UniquesCount : UniquesCount + 1; + Value *OpIdxLaneV = getData(OpIdx, Lane).V; + int UniquesCntWithOpIdxLaneV = + Uniques.contains(OpIdxLaneV) ? UniquesCount : UniquesCount + 1; + if (UniquesCntWithIdxLaneV == UniquesCntWithOpIdxLaneV) + return 0; + return (PowerOf2Ceil(UniquesCntWithOpIdxLaneV) - + UniquesCntWithOpIdxLaneV) - + (PowerOf2Ceil(UniquesCntWithIdxLaneV) - UniquesCntWithIdxLaneV); + } + + /// \param Lane lane of the operands under analysis. + /// \param OpIdx operand index in \p Lane lane we're looking the best + /// candidate for. + /// \param Idx operand index of the current candidate value. + /// \returns The additional score for the scalar which users are all + /// vectorized. + int getExternalUseScore(unsigned Lane, unsigned OpIdx, unsigned Idx) const { + Value *IdxLaneV = getData(Idx, Lane).V; + Value *OpIdxLaneV = getData(OpIdx, Lane).V; + // Do not care about number of uses for vector-like instructions + // (extractelement/extractvalue with constant indices), they are extracts + // themselves and already externally used. Vectorization of such + // instructions does not add extra extractelement instruction, just may + // remove it. + if (isVectorLikeInstWithConstOps(IdxLaneV) && + isVectorLikeInstWithConstOps(OpIdxLaneV)) + return LookAheadHeuristics::ScoreAllUserVectorized; + auto *IdxLaneI = dyn_cast<Instruction>(IdxLaneV); + if (!IdxLaneI || !isa<Instruction>(OpIdxLaneV)) + return 0; + return R.areAllUsersVectorized(IdxLaneI, None) + ? LookAheadHeuristics::ScoreAllUserVectorized + : 0; + } + + /// Score scaling factor for fully compatible instructions but with + /// different number of external uses. 
Allows better selection of the + /// instructions with less external uses. + static const int ScoreScaleFactor = 10; /// \Returns the look-ahead score, which tells us how much the sub-trees /// rooted at \p LHS and \p RHS match, the more they match the higher the /// score. This helps break ties in an informed way when we cannot decide on /// the order of the operands by just considering the immediate /// predecessors. - int getLookAheadScore(const std::pair<Value *, int> &LHS, - const std::pair<Value *, int> &RHS) { - InLookAheadValues.clear(); - return getScoreAtLevelRec(LHS, RHS, 1, LookAheadMaxDepth); + int getLookAheadScore(Value *LHS, Value *RHS, ArrayRef<Value *> MainAltOps, + int Lane, unsigned OpIdx, unsigned Idx, + bool &IsUsed) { + LookAheadHeuristics LookAhead(DL, SE, R, getNumLanes(), + LookAheadMaxDepth); + // Keep track of the instruction stack as we recurse into the operands + // during the look-ahead score exploration. + int Score = + LookAhead.getScoreAtLevelRec(LHS, RHS, /*U1=*/nullptr, /*U2=*/nullptr, + /*CurrLevel=*/1, MainAltOps); + if (Score) { + int SplatScore = getSplatScore(Lane, OpIdx, Idx); + if (Score <= -SplatScore) { + // Set the minimum score for splat-like sequence to avoid setting + // failed state. + Score = 1; + } else { + Score += SplatScore; + // Scale score to see the difference between different operands + // and similar operands but all vectorized/not all vectorized + // uses. It does not affect actual selection of the best + // compatible operand in general, just allows to select the + // operand with all vectorized uses. + Score *= ScoreScaleFactor; + Score += getExternalUseScore(Lane, OpIdx, Idx); + IsUsed = true; + } + } + return Score; } + /// Best defined scores per lanes between the passes. Used to choose the + /// best operand (with the highest score) between the passes. + /// The key - {Operand Index, Lane}. + /// The value - the best score between the passes for the lane and the + /// operand. + SmallDenseMap<std::pair<unsigned, unsigned>, unsigned, 8> + BestScoresPerLanes; + // Search all operands in Ops[*][Lane] for the one that matches best // Ops[OpIdx][LastLane] and return its opreand index. // If no good match can be found, return None. - Optional<unsigned> - getBestOperand(unsigned OpIdx, int Lane, int LastLane, - ArrayRef<ReorderingMode> ReorderingModes) { + Optional<unsigned> getBestOperand(unsigned OpIdx, int Lane, int LastLane, + ArrayRef<ReorderingMode> ReorderingModes, + ArrayRef<Value *> MainAltOps) { unsigned NumOperands = getNumOperands(); // The operand of the previous lane at OpIdx. @@ -1318,6 +1531,8 @@ public: // Our strategy mode for OpIdx. ReorderingMode RMode = ReorderingModes[OpIdx]; + if (RMode == ReorderingMode::Failed) + return None; // The linearized opcode of the operand at OpIdx, Lane. bool OpIdxAPO = getData(OpIdx, Lane).APO; @@ -1329,7 +1544,15 @@ public: Optional<unsigned> Idx = None; unsigned Score = 0; } BestOp; - + BestOp.Score = + BestScoresPerLanes.try_emplace(std::make_pair(OpIdx, Lane), 0) + .first->second; + + // Track if the operand must be marked as used. If the operand is set to + // Score 1 explicitly (because of non power-of-2 unique scalars, we may + // want to reestimate the operands again on the following iterations). + bool IsUsed = + RMode == ReorderingMode::Splat || RMode == ReorderingMode::Constant; // Iterate through all unused operands and look for the best. for (unsigned Idx = 0; Idx != NumOperands; ++Idx) { // Get the operand at Idx and Lane. 
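The rewritten getLookAheadScore above scales the structural score before adding the external-use bonus, so that bonus can only break ties between candidates that are otherwise equally compatible, and it falls back to a minimal score of 1 when the splat adjustment would otherwise push the lane into the failed state. Below is a standalone sketch of that combination, a paraphrase rather than a literal copy; combineScores and the sample values are invented, and only ScoreScaleFactor = 10 is taken from the hunk above.

#include <iostream>

// Sketch only: combine a structural look-ahead score with the splat and
// external-use adjustments, scaling first so the external-use bonus acts
// purely as a tie-breaker.
constexpr int ScoreScaleFactor = 10;

int combineScores(int LookAheadScore, int SplatScore, int ExternalUseScore) {
  if (LookAheadScore == 0)
    return 0; // ScoreFail: nothing to combine.
  if (LookAheadScore <= -SplatScore)
    return 1; // Keep a minimal non-failing score for splat-like sequences.
  return (LookAheadScore + SplatScore) * ScoreScaleFactor + ExternalUseScore;
}

int main() {
  // Two candidates that match equally well structurally (e.g. both score 4
  // for consecutive loads); the one whose users are already vectorized wins.
  std::cout << combineScores(4, 0, 0) << '\n'; // 40
  std::cout << combineScores(4, 0, 1) << '\n'; // 41
}

Because the base score is multiplied by ScoreScaleFactor, the external-use bonus can reorder only candidates whose structural scores are identical, which matches the comment in the hunk that it "does not affect actual selection of the best compatible operand in general".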
@@ -1355,11 +1578,12 @@ public: bool LeftToRight = Lane > LastLane; Value *OpLeft = (LeftToRight) ? OpLastLane : Op; Value *OpRight = (LeftToRight) ? Op : OpLastLane; - unsigned Score = - getLookAheadScore({OpLeft, LastLane}, {OpRight, Lane}); - if (Score > BestOp.Score) { + int Score = getLookAheadScore(OpLeft, OpRight, MainAltOps, Lane, + OpIdx, Idx, IsUsed); + if (Score > static_cast<int>(BestOp.Score)) { BestOp.Idx = Idx; BestOp.Score = Score; + BestScoresPerLanes[std::make_pair(OpIdx, Lane)] = Score; } break; } @@ -1368,12 +1592,12 @@ public: BestOp.Idx = Idx; break; case ReorderingMode::Failed: - return None; + llvm_unreachable("Not expected Failed reordering mode."); } } if (BestOp.Idx) { - getData(BestOp.Idx.getValue(), Lane).IsUsed = true; + getData(*BestOp.Idx, Lane).IsUsed = IsUsed; return BestOp.Idx; } // If we could not find a good match return None. @@ -1690,6 +1914,10 @@ public: // rest of the lanes. We are visiting the nodes in a circular fashion, // using FirstLane as the center point and increasing the radius // distance. + SmallVector<SmallVector<Value *, 2>> MainAltOps(NumOperands); + for (unsigned I = 0; I < NumOperands; ++I) + MainAltOps[I].push_back(getData(I, FirstLane).V); + for (unsigned Distance = 1; Distance != NumLanes; ++Distance) { // Visit the lane on the right and then the lane on the left. for (int Direction : {+1, -1}) { @@ -1702,21 +1930,29 @@ public: // Look for a good match for each operand. for (unsigned OpIdx = 0; OpIdx != NumOperands; ++OpIdx) { // Search for the operand that matches SortedOps[OpIdx][Lane-1]. - Optional<unsigned> BestIdx = - getBestOperand(OpIdx, Lane, LastLane, ReorderingModes); + Optional<unsigned> BestIdx = getBestOperand( + OpIdx, Lane, LastLane, ReorderingModes, MainAltOps[OpIdx]); // By not selecting a value, we allow the operands that follow to // select a better matching value. We will get a non-null value in // the next run of getBestOperand(). if (BestIdx) { // Swap the current operand with the one returned by // getBestOperand(). - swap(OpIdx, BestIdx.getValue(), Lane); + swap(OpIdx, *BestIdx, Lane); } else { // We failed to find a best operand, set mode to 'Failed'. ReorderingModes[OpIdx] = ReorderingMode::Failed; // Enable the second pass. StrategyFailed = true; } + // Try to get the alternate opcode and follow it during analysis. + if (MainAltOps[OpIdx].size() != 2) { + OperandData &AltOp = getData(OpIdx, Lane); + InstructionsState OpS = + getSameOpcode({MainAltOps[OpIdx].front(), AltOp.V}); + if (OpS.getOpcode() && OpS.isAltShuffle()) + MainAltOps[OpIdx].push_back(AltOp.V); + } } } } @@ -1780,15 +2016,109 @@ public: #endif }; + /// Evaluate each pair in \p Candidates and return index into \p Candidates + /// for a pair which have highest score deemed to have best chance to form + /// root of profitable tree to vectorize. Return None if no candidate scored + /// above the LookAheadHeuristics::ScoreFail. + /// \param Limit Lower limit of the cost, considered to be good enough score. 
+ Optional<int> + findBestRootPair(ArrayRef<std::pair<Value *, Value *>> Candidates, + int Limit = LookAheadHeuristics::ScoreFail) { + LookAheadHeuristics LookAhead(*DL, *SE, *this, /*NumLanes=*/2, + RootLookAheadMaxDepth); + int BestScore = Limit; + Optional<int> Index = None; + for (int I : seq<int>(0, Candidates.size())) { + int Score = LookAhead.getScoreAtLevelRec(Candidates[I].first, + Candidates[I].second, + /*U1=*/nullptr, /*U2=*/nullptr, + /*Level=*/1, None); + if (Score > BestScore) { + BestScore = Score; + Index = I; + } + } + return Index; + } + /// Checks if the instruction is marked for deletion. bool isDeleted(Instruction *I) const { return DeletedInstructions.count(I); } - /// Marks values operands for later deletion by replacing them with Undefs. - void eraseInstructions(ArrayRef<Value *> AV); + /// Removes an instruction from its block and eventually deletes it. + /// It's like Instruction::eraseFromParent() except that the actual deletion + /// is delayed until BoUpSLP is destructed. + void eraseInstruction(Instruction *I) { + DeletedInstructions.insert(I); + } + + /// Checks if the instruction was already analyzed for being possible + /// reduction root. + bool isAnalyzedReductionRoot(Instruction *I) const { + return AnalyzedReductionsRoots.count(I); + } + /// Register given instruction as already analyzed for being possible + /// reduction root. + void analyzedReductionRoot(Instruction *I) { + AnalyzedReductionsRoots.insert(I); + } + /// Checks if the provided list of reduced values was checked already for + /// vectorization. + bool areAnalyzedReductionVals(ArrayRef<Value *> VL) { + return AnalyzedReductionVals.contains(hash_value(VL)); + } + /// Adds the list of reduced values to list of already checked values for the + /// vectorization. + void analyzedReductionVals(ArrayRef<Value *> VL) { + AnalyzedReductionVals.insert(hash_value(VL)); + } + /// Clear the list of the analyzed reduction root instructions. + void clearReductionData() { + AnalyzedReductionsRoots.clear(); + AnalyzedReductionVals.clear(); + } + /// Checks if the given value is gathered in one of the nodes. + bool isAnyGathered(const SmallDenseSet<Value *> &Vals) const { + return any_of(MustGather, [&](Value *V) { return Vals.contains(V); }); + } ~BoUpSLP(); private: + /// Check if the operands on the edges \p Edges of the \p UserTE allows + /// reordering (i.e. the operands can be reordered because they have only one + /// user and reordarable). + /// \param ReorderableGathers List of all gather nodes that require reordering + /// (e.g., gather of extractlements or partially vectorizable loads). + /// \param GatherOps List of gather operand nodes for \p UserTE that require + /// reordering, subset of \p NonVectorized. + bool + canReorderOperands(TreeEntry *UserTE, + SmallVectorImpl<std::pair<unsigned, TreeEntry *>> &Edges, + ArrayRef<TreeEntry *> ReorderableGathers, + SmallVectorImpl<TreeEntry *> &GatherOps); + + /// Returns vectorized operand \p OpIdx of the node \p UserTE from the graph, + /// if any. If it is not vectorized (gather node), returns nullptr. + TreeEntry *getVectorizedOperand(TreeEntry *UserTE, unsigned OpIdx) { + ArrayRef<Value *> VL = UserTE->getOperand(OpIdx); + TreeEntry *TE = nullptr; + const auto *It = find_if(VL, [this, &TE](Value *V) { + TE = getTreeEntry(V); + return TE; + }); + if (It != VL.end() && TE->isSame(VL)) + return TE; + return nullptr; + } + + /// Returns vectorized operand \p OpIdx of the node \p UserTE from the graph, + /// if any. 
If it is not vectorized (gather node), returns nullptr. + const TreeEntry *getVectorizedOperand(const TreeEntry *UserTE, + unsigned OpIdx) const { + return const_cast<BoUpSLP *>(this)->getVectorizedOperand( + const_cast<TreeEntry *>(UserTE), OpIdx); + } + /// Checks if all users of \p I are the part of the vectorization tree. bool areAllUsersVectorized(Instruction *I, ArrayRef<Value *> VectorizedVals) const; @@ -1815,12 +2145,17 @@ private: /// Vectorize a single entry in the tree, starting in \p VL. Value *vectorizeTree(ArrayRef<Value *> VL); + /// Create a new vector from a list of scalar values. Produces a sequence + /// which exploits values reused across lanes, and arranges the inserts + /// for ease of later optimization. + Value *createBuildVector(ArrayRef<Value *> VL); + /// \returns the scalarization cost for this type. Scalarization in this /// context means the creation of vectors from a group of scalars. If \p /// NeedToShuffle is true, need to add a cost of reshuffling some of the /// vector elements. InstructionCost getGatherCost(FixedVectorType *Ty, - const DenseSet<unsigned> &ShuffledIndices, + const APInt &ShuffledIndices, bool NeedToShuffle) const; /// Checks if the gathered \p VL can be represented as shuffle(s) of previous @@ -1855,6 +2190,29 @@ private: const DataLayout &DL, ScalarEvolution &SE, const BoUpSLP &R); + + /// Helper for `findExternalStoreUsersReorderIndices()`. It iterates over the + /// users of \p TE and collects the stores. It returns the map from the store + /// pointers to the collected stores. + DenseMap<Value *, SmallVector<StoreInst *, 4>> + collectUserStores(const BoUpSLP::TreeEntry *TE) const; + + /// Helper for `findExternalStoreUsersReorderIndices()`. It checks if the + /// stores in \p StoresVec can for a vector instruction. If so it returns true + /// and populates \p ReorderIndices with the shuffle indices of the the stores + /// when compared to the sorted vector. + bool CanFormVector(const SmallVector<StoreInst *, 4> &StoresVec, + OrdersType &ReorderIndices) const; + + /// Iterates through the users of \p TE, looking for scalar stores that can be + /// potentially vectorized in a future SLP-tree. If found, it keeps track of + /// their order and builds an order index vector for each store bundle. It + /// returns all these order vectors found. + /// We run this after the tree has formed, otherwise we may come across user + /// instructions that are not yet in the tree. + SmallVector<OrdersType, 1> + findExternalStoreUsersReorderIndices(TreeEntry *TE) const; + struct TreeEntry { using VecTreeTy = SmallVector<std::unique_ptr<TreeEntry>, 8>; TreeEntry(VecTreeTy &Container) : Container(Container) {} @@ -2199,15 +2557,21 @@ private: ScalarToTreeEntry[V] = Last; } // Update the scheduler bundle to point to this TreeEntry. 
- unsigned Lane = 0; - for (ScheduleData *BundleMember = Bundle.getValue(); BundleMember; - BundleMember = BundleMember->NextInBundle) { - BundleMember->TE = Last; - BundleMember->Lane = Lane; - ++Lane; - } - assert((!Bundle.getValue() || Lane == VL.size()) && + ScheduleData *BundleMember = *Bundle; + assert((BundleMember || isa<PHINode>(S.MainOp) || + isVectorLikeInstWithConstOps(S.MainOp) || + doesNotNeedToSchedule(VL)) && "Bundle and VL out of sync"); + if (BundleMember) { + for (Value *V : VL) { + if (doesNotNeedToBeScheduled(V)) + continue; + assert(BundleMember && "Unexpected end of bundle."); + BundleMember->TE = Last; + BundleMember = BundleMember->NextInBundle; + } + } + assert(!BundleMember && "Bundle and VL out of sync"); } else { MustGather.insert(VL.begin(), VL.end()); } @@ -2241,7 +2605,7 @@ private: /// Maps a specific scalar to its tree entry. SmallDenseMap<Value*, TreeEntry *> ScalarToTreeEntry; - /// Maps a value to the proposed vectorizable size. + /// Maps a value to the proposed vectorizable size. SmallDenseMap<Value *, unsigned> InstrElementSize; /// A list of scalars that we found that we need to keep as scalars. @@ -2272,12 +2636,12 @@ private: // First check if the result is already in the cache. AliasCacheKey key = std::make_pair(Inst1, Inst2); Optional<bool> &result = AliasCache[key]; - if (result.hasValue()) { + if (result) { return result.getValue(); } bool aliased = true; if (Loc1.Ptr && isSimple(Inst1)) - aliased = isModOrRefSet(AA->getModRefInfo(Inst2, Loc1)); + aliased = isModOrRefSet(BatchAA.getModRefInfo(Inst2, Loc1)); // Store the result in the cache. result = aliased; return aliased; @@ -2289,20 +2653,23 @@ private: /// TODO: consider moving this to the AliasAnalysis itself. DenseMap<AliasCacheKey, Optional<bool>> AliasCache; - /// Removes an instruction from its block and eventually deletes it. - /// It's like Instruction::eraseFromParent() except that the actual deletion - /// is delayed until BoUpSLP is destructed. - /// This is required to ensure that there are no incorrect collisions in the - /// AliasCache, which can happen if a new instruction is allocated at the - /// same address as a previously deleted instruction. - void eraseInstruction(Instruction *I, bool ReplaceOpsWithUndef = false) { - auto It = DeletedInstructions.try_emplace(I, ReplaceOpsWithUndef).first; - It->getSecond() = It->getSecond() && ReplaceOpsWithUndef; - } + // Cache for pointerMayBeCaptured calls inside AA. This is preserved + // globally through SLP because we don't perform any action which + // invalidates capture results. + BatchAAResults BatchAA; /// Temporary store for deleted instructions. Instructions will be deleted - /// eventually when the BoUpSLP is destructed. - DenseMap<Instruction *, bool> DeletedInstructions; + /// eventually when the BoUpSLP is destructed. The deferral is required to + /// ensure that there are no incorrect collisions in the AliasCache, which + /// can happen if a new instruction is allocated at the same address as a + /// previously deleted instruction. + DenseSet<Instruction *> DeletedInstructions; + + /// Set of the instruction, being analyzed already for reductions. + SmallPtrSet<Instruction *, 16> AnalyzedReductionsRoots; + + /// Set of hashes for the list of reduction values already being analyzed. + DenseSet<size_t> AnalyzedReductionVals; /// A list of values that need to extracted out of the tree. /// This list holds pairs of (Internal Scalar : External User). 
External User @@ -2336,14 +2703,39 @@ private: NextLoadStore = nullptr; IsScheduled = false; SchedulingRegionID = BlockSchedulingRegionID; - UnscheduledDepsInBundle = UnscheduledDeps; clearDependencies(); OpValue = OpVal; TE = nullptr; - Lane = -1; + } + + /// Verify basic self consistency properties + void verify() { + if (hasValidDependencies()) { + assert(UnscheduledDeps <= Dependencies && "invariant"); + } else { + assert(UnscheduledDeps == Dependencies && "invariant"); + } + + if (IsScheduled) { + assert(isSchedulingEntity() && + "unexpected scheduled state"); + for (const ScheduleData *BundleMember = this; BundleMember; + BundleMember = BundleMember->NextInBundle) { + assert(BundleMember->hasValidDependencies() && + BundleMember->UnscheduledDeps == 0 && + "unexpected scheduled state"); + assert((BundleMember == this || !BundleMember->IsScheduled) && + "only bundle is marked scheduled"); + } + } + + assert(Inst->getParent() == FirstInBundle->Inst->getParent() && + "all bundle members must be in same basic block"); } /// Returns true if the dependency information has been calculated. + /// Note that depenendency validity can vary between instructions within + /// a single bundle. bool hasValidDependencies() const { return Dependencies != InvalidDeps; } /// Returns true for single instructions and for bundle representatives @@ -2353,7 +2745,7 @@ private: /// Returns true if it represents an instruction bundle and not only a /// single instruction. bool isPartOfBundle() const { - return NextInBundle != nullptr || FirstInBundle != this; + return NextInBundle != nullptr || FirstInBundle != this || TE; } /// Returns true if it is ready for scheduling, i.e. it has no more @@ -2361,20 +2753,23 @@ private: bool isReady() const { assert(isSchedulingEntity() && "can't consider non-scheduling entity for ready list"); - return UnscheduledDepsInBundle == 0 && !IsScheduled; + return unscheduledDepsInBundle() == 0 && !IsScheduled; } - /// Modifies the number of unscheduled dependencies, also updating it for - /// the whole bundle. + /// Modifies the number of unscheduled dependencies for this instruction, + /// and returns the number of remaining dependencies for the containing + /// bundle. int incrementUnscheduledDeps(int Incr) { + assert(hasValidDependencies() && + "increment of unscheduled deps would be meaningless"); UnscheduledDeps += Incr; - return FirstInBundle->UnscheduledDepsInBundle += Incr; + return FirstInBundle->unscheduledDepsInBundle(); } /// Sets the number of unscheduled dependencies to the number of /// dependencies. void resetUnscheduledDeps() { - incrementUnscheduledDeps(Dependencies - UnscheduledDeps); + UnscheduledDeps = Dependencies; } /// Clears all dependency information. @@ -2382,6 +2777,19 @@ private: Dependencies = InvalidDeps; resetUnscheduledDeps(); MemoryDependencies.clear(); + ControlDependencies.clear(); + } + + int unscheduledDepsInBundle() const { + assert(isSchedulingEntity() && "only meaningful on the bundle"); + int Sum = 0; + for (const ScheduleData *BundleMember = this; BundleMember; + BundleMember = BundleMember->NextInBundle) { + if (BundleMember->UnscheduledDeps == InvalidDeps) + return InvalidDeps; + Sum += BundleMember->UnscheduledDeps; + } + return Sum; } void dump(raw_ostream &os) const { @@ -2402,6 +2810,12 @@ private: Instruction *Inst = nullptr; + /// Opcode of the current instruction in the schedule data. + Value *OpValue = nullptr; + + /// The TreeEntry that this instruction corresponds to. 
+ TreeEntry *TE = nullptr; + /// Points to the head in an instruction bundle (and always to this for /// single instructions). ScheduleData *FirstInBundle = nullptr; @@ -2418,6 +2832,12 @@ private: /// This list is derived on demand in calculateDependencies(). SmallVector<ScheduleData *, 4> MemoryDependencies; + /// List of instructions which this instruction could be control dependent + /// on. Allowing such nodes to be scheduled below this one could introduce + /// a runtime fault which didn't exist in the original program. + /// ex: this is a load or udiv following a readonly call which inf loops + SmallVector<ScheduleData *, 4> ControlDependencies; + /// This ScheduleData is in the current scheduling region if this matches /// the current SchedulingRegionID of BlockScheduling. int SchedulingRegionID = 0; @@ -2437,22 +2857,9 @@ private: /// Note that this is negative as long as Dependencies is not calculated. int UnscheduledDeps = InvalidDeps; - /// The sum of UnscheduledDeps in a bundle. Equals to UnscheduledDeps for - /// single instructions. - int UnscheduledDepsInBundle = InvalidDeps; - /// True if this instruction is scheduled (or considered as scheduled in the /// dry-run). bool IsScheduled = false; - - /// Opcode of the current instruction in the schedule data. - Value *OpValue = nullptr; - - /// The TreeEntry that this instruction corresponds to. - TreeEntry *TE = nullptr; - - /// The lane of this node in the TreeEntry. - int Lane = -1; }; #ifndef NDEBUG @@ -2467,6 +2874,21 @@ private: friend struct DOTGraphTraits<BoUpSLP *>; /// Contains all scheduling data for a basic block. + /// It does not schedules instructions, which are not memory read/write + /// instructions and their operands are either constants, or arguments, or + /// phis, or instructions from others blocks, or their users are phis or from + /// the other blocks. The resulting vector instructions can be placed at the + /// beginning of the basic block without scheduling (if operands does not need + /// to be scheduled) or at the end of the block (if users are outside of the + /// block). It allows to save some compile time and memory used by the + /// compiler. + /// ScheduleData is assigned for each instruction in between the boundaries of + /// the tree entry, even for those, which are not part of the graph. It is + /// required to correctly follow the dependencies between the instructions and + /// their correct scheduling. The ScheduleData is not allocated for the + /// instructions, which do not require scheduling, like phis, nodes with + /// extractelements/insertelements only or nodes with instructions, with + /// uses/operands outside of the block. struct BlockScheduling { BlockScheduling(BasicBlock *BB) : BB(BB), ChunkSize(BB->size()), ChunkPos(ChunkSize) {} @@ -2477,6 +2899,7 @@ private: ScheduleEnd = nullptr; FirstLoadStoreInRegion = nullptr; LastLoadStoreInRegion = nullptr; + RegionHasStackSave = false; // Reduce the maximum schedule region size by the size of the // previous scheduling run. @@ -2490,20 +2913,29 @@ private: ++SchedulingRegionID; } - ScheduleData *getScheduleData(Value *V) { - ScheduleData *SD = ScheduleDataMap[V]; - if (SD && SD->SchedulingRegionID == SchedulingRegionID) + ScheduleData *getScheduleData(Instruction *I) { + if (BB != I->getParent()) + // Avoid lookup if can't possibly be in map. 
+ return nullptr; + ScheduleData *SD = ScheduleDataMap.lookup(I); + if (SD && isInSchedulingRegion(SD)) return SD; return nullptr; } + ScheduleData *getScheduleData(Value *V) { + if (auto *I = dyn_cast<Instruction>(V)) + return getScheduleData(I); + return nullptr; + } + ScheduleData *getScheduleData(Value *V, Value *Key) { if (V == Key) return getScheduleData(V); auto I = ExtraScheduleDataMap.find(V); if (I != ExtraScheduleDataMap.end()) { - ScheduleData *SD = I->second[Key]; - if (SD && SD->SchedulingRegionID == SchedulingRegionID) + ScheduleData *SD = I->second.lookup(Key); + if (SD && isInSchedulingRegion(SD)) return SD; } return nullptr; @@ -2524,7 +2956,7 @@ private: BundleMember = BundleMember->NextInBundle) { if (BundleMember->Inst != BundleMember->OpValue) continue; - + // Handle the def-use chain dependencies. // Decrement the unscheduled counter and insert to ready list if ready. @@ -2546,10 +2978,12 @@ private: }; // If BundleMember is a vector bundle, its operands may have been - // reordered duiring buildTree(). We therefore need to get its operands + // reordered during buildTree(). We therefore need to get its operands // through the TreeEntry. if (TreeEntry *TE = BundleMember->TE) { - int Lane = BundleMember->Lane; + // Need to search for the lane since the tree entry can be reordered. + int Lane = std::distance(TE->Scalars.begin(), + find(TE->Scalars, BundleMember->Inst)); assert(Lane >= 0 && "Lane not set"); // Since vectorization tree is being built recursively this assertion @@ -2558,7 +2992,7 @@ private: // where their second (immediate) operand is not added. Since // immediates do not affect scheduler behavior this is considered // okay. - auto *In = TE->getMainOp(); + auto *In = BundleMember->Inst; assert(In && (isa<ExtractValueInst>(In) || isa<ExtractElementInst>(In) || In->getNumOperands() == TE->getNumOperands()) && @@ -2578,7 +3012,8 @@ private: } // Handle the memory dependencies. for (ScheduleData *MemoryDepSD : BundleMember->MemoryDependencies) { - if (MemoryDepSD->incrementUnscheduledDeps(-1) == 0) { + if (MemoryDepSD->hasValidDependencies() && + MemoryDepSD->incrementUnscheduledDeps(-1) == 0) { // There are no more unscheduled dependencies after decrementing, // so we can put the dependent instruction into the ready list. ScheduleData *DepBundle = MemoryDepSD->FirstInBundle; @@ -2589,6 +3024,48 @@ private: << "SLP: gets ready (mem): " << *DepBundle << "\n"); } } + // Handle the control dependencies. + for (ScheduleData *DepSD : BundleMember->ControlDependencies) { + if (DepSD->incrementUnscheduledDeps(-1) == 0) { + // There are no more unscheduled dependencies after decrementing, + // so we can put the dependent instruction into the ready list. + ScheduleData *DepBundle = DepSD->FirstInBundle; + assert(!DepBundle->IsScheduled && + "already scheduled bundle gets ready"); + ReadyList.insert(DepBundle); + LLVM_DEBUG(dbgs() + << "SLP: gets ready (ctl): " << *DepBundle << "\n"); + } + } + + } + } + + /// Verify basic self consistency properties of the data structure. 
+ void verify() { + if (!ScheduleStart) + return; + + assert(ScheduleStart->getParent() == ScheduleEnd->getParent() && + ScheduleStart->comesBefore(ScheduleEnd) && + "Not a valid scheduling region?"); + + for (auto *I = ScheduleStart; I != ScheduleEnd; I = I->getNextNode()) { + auto *SD = getScheduleData(I); + if (!SD) + continue; + assert(isInSchedulingRegion(SD) && + "primary schedule data not in window?"); + assert(isInSchedulingRegion(SD->FirstInBundle) && + "entire bundle in window!"); + (void)SD; + doForAllOpcodes(I, [](ScheduleData *SD) { SD->verify(); }); + } + + for (auto *SD : ReadyInsts) { + assert(SD->isSchedulingEntity() && SD->isReady() && + "item in ready list not ready?"); + (void)SD; } } @@ -2599,7 +3076,7 @@ private: auto I = ExtraScheduleDataMap.find(V); if (I != ExtraScheduleDataMap.end()) for (auto &P : I->second) - if (P.second->SchedulingRegionID == SchedulingRegionID) + if (isInSchedulingRegion(P.second)) Action(P.second); } @@ -2608,10 +3085,11 @@ private: void initialFillReadyList(ReadyListType &ReadyList) { for (auto *I = ScheduleStart; I != ScheduleEnd; I = I->getNextNode()) { doForAllOpcodes(I, [&](ScheduleData *SD) { - if (SD->isSchedulingEntity() && SD->isReady()) { + if (SD->isSchedulingEntity() && SD->hasValidDependencies() && + SD->isReady()) { ReadyList.insert(SD); LLVM_DEBUG(dbgs() - << "SLP: initially in ready list: " << *I << "\n"); + << "SLP: initially in ready list: " << *SD << "\n"); } }); } @@ -2669,18 +3147,14 @@ private: /// Attaches ScheduleData to Instruction. /// Note that the mapping survives during all vectorization iterations, i.e. /// ScheduleData structures are recycled. - DenseMap<Value *, ScheduleData *> ScheduleDataMap; + DenseMap<Instruction *, ScheduleData *> ScheduleDataMap; /// Attaches ScheduleData to Instruction with the leading key. DenseMap<Value *, SmallDenseMap<Value *, ScheduleData *>> ExtraScheduleDataMap; - struct ReadyList : SmallVector<ScheduleData *, 8> { - void insert(ScheduleData *SD) { push_back(SD); } - }; - /// The ready-list for scheduling (only used for the dry-run). - ReadyList ReadyInsts; + SetVector<ScheduleData *> ReadyInsts; /// The first instruction of the scheduling region. Instruction *ScheduleStart = nullptr; @@ -2696,6 +3170,11 @@ private: /// (can be null). ScheduleData *LastLoadStoreInRegion = nullptr; + /// Is there an llvm.stacksave or llvm.stackrestore in the scheduling + /// region? Used to optimize the dependence calculation for the + /// common case where there isn't. + bool RegionHasStackSave = false; + /// The current size of the scheduling region. int ScheduleRegionSize = 0; @@ -2704,8 +3183,8 @@ private: /// The ID of the scheduling region. For a new vectorization iteration this /// is incremented which "removes" all ScheduleData from the region. - // Make sure that the initial SchedulingRegionID is greater than the - // initial SchedulingRegionID in ScheduleData (which is 0). + /// Make sure that the initial SchedulingRegionID is greater than the + /// initial SchedulingRegionID in ScheduleData (which is 0). int SchedulingRegionID = 1; }; @@ -2717,7 +3196,7 @@ private: void scheduleBlock(BlockScheduling *BS); /// List of users to ignore during scheduling and that don't need extracting. - ArrayRef<Value *> UserIgnoreList; + const SmallDenseSet<Value *> *UserIgnoreList = nullptr; /// A DenseMapInfo implementation for holding DenseMaps and DenseSets of /// sorted SmallVectors of unsigned. 
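The ready-list bookkeeping sketched in the scheduling code above follows a familiar pattern: every candidate keeps a count of dependencies that have not yet been scheduled, and it becomes ready exactly when that count drops to zero. Below is a small self-contained sketch of that pattern in plain C++ with invented names; it deliberately ignores bundles, memory dependencies and control dependencies, and is not the LLVM implementation.

  #include <vector>

  struct Node {
    int UnscheduledDeps = 0;        // dependencies not yet scheduled
    std::vector<Node *> Dependents; // nodes that wait on this one
    bool IsScheduled = false;
  };

  // Schedules all nodes in dependency order using a worklist of ready nodes.
  static void scheduleAll(std::vector<Node *> &Nodes) {
    std::vector<Node *> Ready;
    for (Node *N : Nodes)
      if (N->UnscheduledDeps == 0)
        Ready.push_back(N); // initially ready, cf. initialFillReadyList above

    while (!Ready.empty()) {
      Node *N = Ready.back();
      Ready.pop_back();
      N->IsScheduled = true;
      // Scheduling N releases one dependency of each dependent; a dependent
      // whose counter reaches zero becomes ready in turn.
      for (Node *D : N->Dependents)
        if (--D->UnscheduledDeps == 0)
          Ready.push_back(D);
    }
  }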
@@ -2748,7 +3227,6 @@ private: ScalarEvolution *SE; TargetTransformInfo *TTI; TargetLibraryInfo *TLI; - AAResults *AA; LoopInfo *LI; DominatorTree *DT; AssumptionCache *AC; @@ -2865,20 +3343,25 @@ template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits { } // end namespace llvm BoUpSLP::~BoUpSLP() { - for (const auto &Pair : DeletedInstructions) { - // Replace operands of ignored instructions with Undefs in case if they were - // marked for deletion. - if (Pair.getSecond()) { - Value *Undef = UndefValue::get(Pair.getFirst()->getType()); - Pair.getFirst()->replaceAllUsesWith(Undef); - } - Pair.getFirst()->dropAllReferences(); - } - for (const auto &Pair : DeletedInstructions) { - assert(Pair.getFirst()->use_empty() && + SmallVector<WeakTrackingVH> DeadInsts; + for (auto *I : DeletedInstructions) { + for (Use &U : I->operands()) { + auto *Op = dyn_cast<Instruction>(U.get()); + if (Op && !DeletedInstructions.count(Op) && Op->hasOneUser() && + wouldInstructionBeTriviallyDead(Op, TLI)) + DeadInsts.emplace_back(Op); + } + I->dropAllReferences(); + } + for (auto *I : DeletedInstructions) { + assert(I->use_empty() && "trying to erase instruction with users."); - Pair.getFirst()->eraseFromParent(); + I->eraseFromParent(); } + + // Cleanup any dead scalar code feeding the vectorized instructions + RecursivelyDeleteTriviallyDeadInstructions(DeadInsts, TLI); + #ifdef EXPENSIVE_CHECKS // If we could guarantee that this call is not extremely slow, we could // remove the ifdef limitation (see PR47712). @@ -2886,13 +3369,6 @@ BoUpSLP::~BoUpSLP() { #endif } -void BoUpSLP::eraseInstructions(ArrayRef<Value *> AV) { - for (auto *V : AV) { - if (auto *I = dyn_cast<Instruction>(V)) - eraseInstruction(I, /*ReplaceOpsWithUndef=*/true); - }; -} - /// Reorders the given \p Reuses mask according to the given \p Mask. \p Reuses /// contains original mask for the scalars reused in the node. Procedure /// transform this mask in accordance with the given \p Mask. @@ -2997,6 +3473,189 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) { return None; } +namespace { +/// Tracks the state we can represent the loads in the given sequence. +enum class LoadsState { Gather, Vectorize, ScatterVectorize }; +} // anonymous namespace + +/// Checks if the given array of loads can be represented as a vectorized, +/// scatter or just simple gather. +static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, + const TargetTransformInfo &TTI, + const DataLayout &DL, ScalarEvolution &SE, + LoopInfo &LI, + SmallVectorImpl<unsigned> &Order, + SmallVectorImpl<Value *> &PointerOps) { + // Check that a vectorized load would load the same memory as a scalar + // load. For example, we don't want to vectorize loads that are smaller + // than 8-bit. Even though we have a packed struct {<i2, i2, i2, i2>} LLVM + // treats loading/storing it as an i8 struct. If we vectorize loads/stores + // from such a struct, we read/write packed bits disagreeing with the + // unvectorized version. + Type *ScalarTy = VL0->getType(); + + if (DL.getTypeSizeInBits(ScalarTy) != DL.getTypeAllocSizeInBits(ScalarTy)) + return LoadsState::Gather; + + // Make sure all loads in the bundle are simple - we can't vectorize + // atomic or volatile loads. 
+ PointerOps.clear(); + PointerOps.resize(VL.size()); + auto *POIter = PointerOps.begin(); + for (Value *V : VL) { + auto *L = cast<LoadInst>(V); + if (!L->isSimple()) + return LoadsState::Gather; + *POIter = L->getPointerOperand(); + ++POIter; + } + + Order.clear(); + // Check the order of pointer operands or that all pointers are the same. + bool IsSorted = sortPtrAccesses(PointerOps, ScalarTy, DL, SE, Order); + if (IsSorted || all_of(PointerOps, [&PointerOps](Value *P) { + if (getUnderlyingObject(P) != getUnderlyingObject(PointerOps.front())) + return false; + auto *GEP = dyn_cast<GetElementPtrInst>(P); + if (!GEP) + return false; + auto *GEP0 = cast<GetElementPtrInst>(PointerOps.front()); + return GEP->getNumOperands() == 2 && + ((isConstant(GEP->getOperand(1)) && + isConstant(GEP0->getOperand(1))) || + getSameOpcode({GEP->getOperand(1), GEP0->getOperand(1)}) + .getOpcode()); + })) { + if (IsSorted) { + Value *Ptr0; + Value *PtrN; + if (Order.empty()) { + Ptr0 = PointerOps.front(); + PtrN = PointerOps.back(); + } else { + Ptr0 = PointerOps[Order.front()]; + PtrN = PointerOps[Order.back()]; + } + Optional<int> Diff = + getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, DL, SE); + // Check that the sorted loads are consecutive. + if (static_cast<unsigned>(*Diff) == VL.size() - 1) + return LoadsState::Vectorize; + } + // TODO: need to improve analysis of the pointers, if not all of them are + // GEPs or have > 2 operands, we end up with a gather node, which just + // increases the cost. + Loop *L = LI.getLoopFor(cast<LoadInst>(VL0)->getParent()); + bool ProfitableGatherPointers = + static_cast<unsigned>(count_if(PointerOps, [L](Value *V) { + return L && L->isLoopInvariant(V); + })) <= VL.size() / 2 && VL.size() > 2; + if (ProfitableGatherPointers || all_of(PointerOps, [IsSorted](Value *P) { + auto *GEP = dyn_cast<GetElementPtrInst>(P); + return (IsSorted && !GEP && doesNotNeedToBeScheduled(P)) || + (GEP && GEP->getNumOperands() == 2); + })) { + Align CommonAlignment = cast<LoadInst>(VL0)->getAlign(); + for (Value *V : VL) + CommonAlignment = + std::min(CommonAlignment, cast<LoadInst>(V)->getAlign()); + auto *VecTy = FixedVectorType::get(ScalarTy, VL.size()); + if (TTI.isLegalMaskedGather(VecTy, CommonAlignment) && + !TTI.forceScalarizeMaskedGather(VecTy, CommonAlignment)) + return LoadsState::ScatterVectorize; + } + } + + return LoadsState::Gather; +} + +bool clusterSortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy, + const DataLayout &DL, ScalarEvolution &SE, + SmallVectorImpl<unsigned> &SortedIndices) { + assert(llvm::all_of( + VL, [](const Value *V) { return V->getType()->isPointerTy(); }) && + "Expected list of pointer operands."); + // Map from bases to a vector of (Ptr, Offset, OrigIdx), which we insert each + // Ptr into, sort and return the sorted indices with values next to one + // another. + MapVector<Value *, SmallVector<std::tuple<Value *, int, unsigned>>> Bases; + Bases[VL[0]].push_back(std::make_tuple(VL[0], 0U, 0U)); + + unsigned Cnt = 1; + for (Value *Ptr : VL.drop_front()) { + bool Found = any_of(Bases, [&](auto &Base) { + Optional<int> Diff = + getPointersDiff(ElemTy, Base.first, ElemTy, Ptr, DL, SE, + /*StrictCheck=*/true); + if (!Diff) + return false; + + Base.second.emplace_back(Ptr, *Diff, Cnt++); + return true; + }); + + if (!Found) { + // If we haven't found enough to usefully cluster, return early. 
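+ // A rough reading of the early exit that follows (illustrative only; the
+ // threshold is the one in the check below): every pointer either joins an
+ // existing base, because it sits at a computable constant offset from it,
+ // or it opens a new singleton base. Once the number of bases grows past
+ // roughly half the number of pointers, too few pointers remain per base to
+ // ever form consecutive runs, so clustering is abandoned early. With
+ // invented numbers: 8 pointers spread over 5 bases gives 5 > 8/2 - 1 = 3,
+ // so the function gives up rather than keep sorting.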
+ if (Bases.size() > VL.size() / 2 - 1) + return false; + + // Not found already - add a new Base + Bases[Ptr].emplace_back(Ptr, 0, Cnt++); + } + } + + // For each of the bases sort the pointers by Offset and check if any of the + // base become consecutively allocated. + bool AnyConsecutive = false; + for (auto &Base : Bases) { + auto &Vec = Base.second; + if (Vec.size() > 1) { + llvm::stable_sort(Vec, [](const std::tuple<Value *, int, unsigned> &X, + const std::tuple<Value *, int, unsigned> &Y) { + return std::get<1>(X) < std::get<1>(Y); + }); + int InitialOffset = std::get<1>(Vec[0]); + AnyConsecutive |= all_of(enumerate(Vec), [InitialOffset](auto &P) { + return std::get<1>(P.value()) == int(P.index()) + InitialOffset; + }); + } + } + + // Fill SortedIndices array only if it looks worth-while to sort the ptrs. + SortedIndices.clear(); + if (!AnyConsecutive) + return false; + + for (auto &Base : Bases) { + for (auto &T : Base.second) + SortedIndices.push_back(std::get<2>(T)); + } + + assert(SortedIndices.size() == VL.size() && + "Expected SortedIndices to be the size of VL"); + return true; +} + +Optional<BoUpSLP::OrdersType> +BoUpSLP::findPartiallyOrderedLoads(const BoUpSLP::TreeEntry &TE) { + assert(TE.State == TreeEntry::NeedToGather && "Expected gather node only."); + Type *ScalarTy = TE.Scalars[0]->getType(); + + SmallVector<Value *> Ptrs; + Ptrs.reserve(TE.Scalars.size()); + for (Value *V : TE.Scalars) { + auto *L = dyn_cast<LoadInst>(V); + if (!L || !L->isSimple()) + return None; + Ptrs.push_back(L->getPointerOperand()); + } + + BoUpSLP::OrdersType Order; + if (clusterSortPtrAccesses(Ptrs, ScalarTy, *DL, *SE, Order)) + return Order; + return None; +} + Optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) { // No need to reorder if need to shuffle reuses, still need to shuffle the @@ -3037,6 +3696,9 @@ Optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &TE, } if (Optional<OrdersType> CurrentOrder = findReusedOrderedScalars(TE)) return CurrentOrder; + if (TE.Scalars.size() >= 4) + if (Optional<OrdersType> Order = findPartiallyOrderedLoads(TE)) + return Order; } return None; } @@ -3047,13 +3709,55 @@ void BoUpSLP::reorderTopToBottom() { // ExtractElement gather nodes which can be vectorized and need to handle // their ordering. DenseMap<const TreeEntry *, OrdersType> GathersToOrders; + + // AltShuffles can also have a preferred ordering that leads to fewer + // instructions, e.g., the addsub instruction in x86. + DenseMap<const TreeEntry *, OrdersType> AltShufflesToOrders; + + // Maps a TreeEntry to the reorder indices of external users. + DenseMap<const TreeEntry *, SmallVector<OrdersType, 1>> + ExternalUserReorderMap; + // FIXME: Workaround for syntax error reported by MSVC buildbots. + TargetTransformInfo &TTIRef = *TTI; // Find all reorderable nodes with the given VF. // Currently the are vectorized stores,loads,extracts + some gathering of // extracts. - for_each(VectorizableTree, [this, &VFToOrderedEntries, &GathersToOrders]( + for_each(VectorizableTree, [this, &TTIRef, &VFToOrderedEntries, + &GathersToOrders, &ExternalUserReorderMap, + &AltShufflesToOrders]( const std::unique_ptr<TreeEntry> &TE) { + // Look for external users that will probably be vectorized. 
+ SmallVector<OrdersType, 1> ExternalUserReorderIndices = + findExternalStoreUsersReorderIndices(TE.get()); + if (!ExternalUserReorderIndices.empty()) { + VFToOrderedEntries[TE->Scalars.size()].insert(TE.get()); + ExternalUserReorderMap.try_emplace(TE.get(), + std::move(ExternalUserReorderIndices)); + } + + // Patterns like [fadd,fsub] can be combined into a single instruction in + // x86. Reordering them into [fsub,fadd] blocks this pattern. So we need + // to take into account their order when looking for the most used order. + if (TE->isAltShuffle()) { + VectorType *VecTy = + FixedVectorType::get(TE->Scalars[0]->getType(), TE->Scalars.size()); + unsigned Opcode0 = TE->getOpcode(); + unsigned Opcode1 = TE->getAltOpcode(); + // The opcode mask selects between the two opcodes. + SmallBitVector OpcodeMask(TE->Scalars.size(), 0); + for (unsigned Lane : seq<unsigned>(0, TE->Scalars.size())) + if (cast<Instruction>(TE->Scalars[Lane])->getOpcode() == Opcode1) + OpcodeMask.set(Lane); + // If this pattern is supported by the target then we consider the order. + if (TTIRef.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask)) { + VFToOrderedEntries[TE->Scalars.size()].insert(TE.get()); + AltShufflesToOrders.try_emplace(TE.get(), OrdersType()); + } + // TODO: Check the reverse order too. + } + if (Optional<OrdersType> CurrentOrder = - getReorderingData(*TE.get(), /*TopToBottom=*/true)) { + getReorderingData(*TE, /*TopToBottom=*/true)) { // Do not include ordering for nodes used in the alt opcode vectorization, // better to reorder them during bottom-to-top stage. If follow the order // here, it causes reordering of the whole graph though actually it is @@ -3071,10 +3775,7 @@ void BoUpSLP::reorderTopToBottom() { EI.UserTE->isAltShuffle() && EI.UserTE->Idx != 0; })) return; - if (UserTE->UserTreeIndices.empty()) - UserTE = nullptr; - else - UserTE = UserTE->UserTreeIndices.back().UserTE; + UserTE = UserTE->UserTreeIndices.back().UserTE; ++Cnt; } VFToOrderedEntries[TE->Scalars.size()].insert(TE.get()); @@ -3105,11 +3806,30 @@ void BoUpSLP::reorderTopToBottom() { if (!OpTE->ReuseShuffleIndices.empty()) continue; // Count number of orders uses. - const auto &Order = [OpTE, &GathersToOrders]() -> const OrdersType & { - if (OpTE->State == TreeEntry::NeedToGather) - return GathersToOrders.find(OpTE)->second; + const auto &Order = [OpTE, &GathersToOrders, + &AltShufflesToOrders]() -> const OrdersType & { + if (OpTE->State == TreeEntry::NeedToGather) { + auto It = GathersToOrders.find(OpTE); + if (It != GathersToOrders.end()) + return It->second; + } + if (OpTE->isAltShuffle()) { + auto It = AltShufflesToOrders.find(OpTE); + if (It != AltShufflesToOrders.end()) + return It->second; + } return OpTE->ReorderIndices; }(); + // First consider the order of the external scalar users. + auto It = ExternalUserReorderMap.find(OpTE); + if (It != ExternalUserReorderMap.end()) { + const auto &ExternalUserReorderIndices = It->second; + for (const OrdersType &ExtOrder : ExternalUserReorderIndices) + ++OrdersUses.insert(std::make_pair(ExtOrder, 0)).first->second; + // No other useful reorder data in this entry. + if (Order.empty()) + continue; + } // Stores actually store the mask, not the order, need to invert. 
if (OpTE->State == TreeEntry::Vectorize && !OpTE->isAltShuffle() && OpTE->getOpcode() == Instruction::Store && !Order.empty()) { @@ -3199,6 +3919,57 @@ void BoUpSLP::reorderTopToBottom() { } } +bool BoUpSLP::canReorderOperands( + TreeEntry *UserTE, SmallVectorImpl<std::pair<unsigned, TreeEntry *>> &Edges, + ArrayRef<TreeEntry *> ReorderableGathers, + SmallVectorImpl<TreeEntry *> &GatherOps) { + for (unsigned I = 0, E = UserTE->getNumOperands(); I < E; ++I) { + if (any_of(Edges, [I](const std::pair<unsigned, TreeEntry *> &OpData) { + return OpData.first == I && + OpData.second->State == TreeEntry::Vectorize; + })) + continue; + if (TreeEntry *TE = getVectorizedOperand(UserTE, I)) { + // Do not reorder if operand node is used by many user nodes. + if (any_of(TE->UserTreeIndices, + [UserTE](const EdgeInfo &EI) { return EI.UserTE != UserTE; })) + return false; + // Add the node to the list of the ordered nodes with the identity + // order. + Edges.emplace_back(I, TE); + // Add ScatterVectorize nodes to the list of operands, where just + // reordering of the scalars is required. Similar to the gathers, so + // simply add to the list of gathered ops. + // If there are reused scalars, process this node as a regular vectorize + // node, just reorder reuses mask. + if (TE->State != TreeEntry::Vectorize && TE->ReuseShuffleIndices.empty()) + GatherOps.push_back(TE); + continue; + } + TreeEntry *Gather = nullptr; + if (count_if(ReorderableGathers, + [&Gather, UserTE, I](TreeEntry *TE) { + assert(TE->State != TreeEntry::Vectorize && + "Only non-vectorized nodes are expected."); + if (any_of(TE->UserTreeIndices, + [UserTE, I](const EdgeInfo &EI) { + return EI.UserTE == UserTE && EI.EdgeIdx == I; + })) { + assert(TE->isSame(UserTE->getOperand(I)) && + "Operand entry does not match operands."); + Gather = TE; + return true; + } + return false; + }) > 1 && + !all_of(UserTE->getOperand(I), isConstant)) + return false; + if (Gather) + GatherOps.push_back(Gather); + } + return true; +} + void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { SetVector<TreeEntry *> OrderedEntries; DenseMap<const TreeEntry *, OrdersType> GathersToOrders; @@ -3212,49 +3983,13 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { if (TE->State != TreeEntry::Vectorize) NonVectorized.push_back(TE.get()); if (Optional<OrdersType> CurrentOrder = - getReorderingData(*TE.get(), /*TopToBottom=*/false)) { + getReorderingData(*TE, /*TopToBottom=*/false)) { OrderedEntries.insert(TE.get()); if (TE->State != TreeEntry::Vectorize) GathersToOrders.try_emplace(TE.get(), *CurrentOrder); } }); - // Checks if the operands of the users are reordarable and have only single - // use. 
- auto &&CheckOperands = - [this, &NonVectorized](const auto &Data, - SmallVectorImpl<TreeEntry *> &GatherOps) { - for (unsigned I = 0, E = Data.first->getNumOperands(); I < E; ++I) { - if (any_of(Data.second, - [I](const std::pair<unsigned, TreeEntry *> &OpData) { - return OpData.first == I && - OpData.second->State == TreeEntry::Vectorize; - })) - continue; - ArrayRef<Value *> VL = Data.first->getOperand(I); - const TreeEntry *TE = nullptr; - const auto *It = find_if(VL, [this, &TE](Value *V) { - TE = getTreeEntry(V); - return TE; - }); - if (It != VL.end() && TE->isSame(VL)) - return false; - TreeEntry *Gather = nullptr; - if (count_if(NonVectorized, [VL, &Gather](TreeEntry *TE) { - assert(TE->State != TreeEntry::Vectorize && - "Only non-vectorized nodes are expected."); - if (TE->isSame(VL)) { - Gather = TE; - return true; - } - return false; - }) > 1) - return false; - if (Gather) - GatherOps.push_back(Gather); - } - return true; - }; // 1. Propagate order to the graph nodes, which use only reordered nodes. // I.e., if the node has operands, that are reordered, try to make at least // one operand order in the natural order and reorder others + reorder the @@ -3263,7 +3998,7 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { while (!OrderedEntries.empty()) { // 1. Filter out only reordered nodes. // 2. If the entry has multiple uses - skip it and jump to the next node. - MapVector<TreeEntry *, SmallVector<std::pair<unsigned, TreeEntry *>>> Users; + DenseMap<TreeEntry *, SmallVector<std::pair<unsigned, TreeEntry *>>> Users; SmallVector<TreeEntry *> Filtered; for (TreeEntry *TE : OrderedEntries) { if (!(TE->State == TreeEntry::Vectorize || @@ -3291,10 +4026,17 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { // Erase filtered entries. for_each(Filtered, [&OrderedEntries](TreeEntry *TE) { OrderedEntries.remove(TE); }); - for (const auto &Data : Users) { + SmallVector< + std::pair<TreeEntry *, SmallVector<std::pair<unsigned, TreeEntry *>>>> + UsersVec(Users.begin(), Users.end()); + sort(UsersVec, [](const auto &Data1, const auto &Data2) { + return Data1.first->Idx > Data2.first->Idx; + }); + for (auto &Data : UsersVec) { // Check that operands are used only in the User node. SmallVector<TreeEntry *> GatherOps; - if (!CheckOperands(Data, GatherOps)) { + if (!canReorderOperands(Data.first, Data.second, NonVectorized, + GatherOps)) { for_each(Data.second, [&OrderedEntries](const std::pair<unsigned, TreeEntry *> &Op) { OrderedEntries.remove(Op.second); @@ -3310,18 +4052,22 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { // the same node my be considered several times, though might be not // profitable. SmallPtrSet<const TreeEntry *, 4> VisitedOps; + SmallPtrSet<const TreeEntry *, 4> VisitedUsers; for (const auto &Op : Data.second) { TreeEntry *OpTE = Op.second; if (!VisitedOps.insert(OpTE).second) continue; - if (!OpTE->ReuseShuffleIndices.empty() || - (IgnoreReorder && OpTE == VectorizableTree.front().get())) + if (!OpTE->ReuseShuffleIndices.empty()) continue; const auto &Order = [OpTE, &GathersToOrders]() -> const OrdersType & { if (OpTE->State == TreeEntry::NeedToGather) return GathersToOrders.find(OpTE)->second; return OpTE->ReorderIndices; }(); + unsigned NumOps = count_if( + Data.second, [OpTE](const std::pair<unsigned, TreeEntry *> &P) { + return P.second == OpTE; + }); // Stores actually store the mask, not the order, need to invert. 
if (OpTE->State == TreeEntry::Vectorize && !OpTE->isAltShuffle() && OpTE->getOpcode() == Instruction::Store && !Order.empty()) { @@ -3333,14 +4079,52 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { return Idx == UndefMaskElem ? E : static_cast<unsigned>(Idx); }); fixupOrderingIndices(CurrentOrder); - ++OrdersUses.insert(std::make_pair(CurrentOrder, 0)).first->second; + OrdersUses.insert(std::make_pair(CurrentOrder, 0)).first->second += + NumOps; } else { - ++OrdersUses.insert(std::make_pair(Order, 0)).first->second; + OrdersUses.insert(std::make_pair(Order, 0)).first->second += NumOps; + } + auto Res = OrdersUses.insert(std::make_pair(OrdersType(), 0)); + const auto &&AllowsReordering = [IgnoreReorder, &GathersToOrders]( + const TreeEntry *TE) { + if (!TE->ReorderIndices.empty() || !TE->ReuseShuffleIndices.empty() || + (TE->State == TreeEntry::Vectorize && TE->isAltShuffle()) || + (IgnoreReorder && TE->Idx == 0)) + return true; + if (TE->State == TreeEntry::NeedToGather) { + auto It = GathersToOrders.find(TE); + if (It != GathersToOrders.end()) + return !It->second.empty(); + return true; + } + return false; + }; + for (const EdgeInfo &EI : OpTE->UserTreeIndices) { + TreeEntry *UserTE = EI.UserTE; + if (!VisitedUsers.insert(UserTE).second) + continue; + // May reorder user node if it requires reordering, has reused + // scalars, is an alternate op vectorize node or its op nodes require + // reordering. + if (AllowsReordering(UserTE)) + continue; + // Check if users allow reordering. + // Currently look up just 1 level of operands to avoid increase of + // the compile time. + // Profitable to reorder if definitely more operands allow + // reordering rather than those with natural order. + ArrayRef<std::pair<unsigned, TreeEntry *>> Ops = Users[UserTE]; + if (static_cast<unsigned>(count_if( + Ops, [UserTE, &AllowsReordering]( + const std::pair<unsigned, TreeEntry *> &Op) { + return AllowsReordering(Op.second) && + all_of(Op.second->UserTreeIndices, + [UserTE](const EdgeInfo &EI) { + return EI.UserTE == UserTE; + }); + })) <= Ops.size() / 2) + ++Res.first->second; } - OrdersUses.insert(std::make_pair(OrdersType(), 0)).first->second += - OpTE->UserTreeIndices.size(); - assert(OrdersUses[{}] > 0 && "Counter cannot be less than 0."); - --OrdersUses[{}]; } // If no orders - skip current nodes and jump to the next one, if any. if (OrdersUses.empty()) { @@ -3381,7 +4165,7 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { OrderedEntries.remove(TE); if (!VisitedOps.insert(TE).second) continue; - if (!TE->ReuseShuffleIndices.empty() && TE->ReorderIndices.empty()) { + if (TE->ReuseShuffleIndices.size() == BestOrder.size()) { // Just reorder reuses indices. reorderReuses(TE->ReuseShuffleIndices, Mask); continue; @@ -3393,6 +4177,8 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { TE->ReorderIndices.empty()) && "Non-matching sizes of user/operand entries."); reorderOrder(TE->ReorderIndices, Mask); + if (IgnoreReorder && TE == VectorizableTree.front().get()) + IgnoreReorder = false; } // For gathers just need to reorder its scalars. for (TreeEntry *Gather : GatherOps) { @@ -3484,7 +4270,7 @@ void BoUpSLP::buildExternalUses( } // Ignore users in the user ignore list. 
- if (is_contained(UserIgnoreList, UserInst)) + if (UserIgnoreList && UserIgnoreList->contains(UserInst)) continue; LLVM_DEBUG(dbgs() << "SLP: Need to extract:" << *U << " from lane " @@ -3495,78 +4281,270 @@ void BoUpSLP::buildExternalUses( } } +DenseMap<Value *, SmallVector<StoreInst *, 4>> +BoUpSLP::collectUserStores(const BoUpSLP::TreeEntry *TE) const { + DenseMap<Value *, SmallVector<StoreInst *, 4>> PtrToStoresMap; + for (unsigned Lane : seq<unsigned>(0, TE->Scalars.size())) { + Value *V = TE->Scalars[Lane]; + // To save compilation time we don't visit if we have too many users. + static constexpr unsigned UsersLimit = 4; + if (V->hasNUsesOrMore(UsersLimit)) + break; + + // Collect stores per pointer object. + for (User *U : V->users()) { + auto *SI = dyn_cast<StoreInst>(U); + if (SI == nullptr || !SI->isSimple() || + !isValidElementType(SI->getValueOperand()->getType())) + continue; + // Skip entry if already + if (getTreeEntry(U)) + continue; + + Value *Ptr = getUnderlyingObject(SI->getPointerOperand()); + auto &StoresVec = PtrToStoresMap[Ptr]; + // For now just keep one store per pointer object per lane. + // TODO: Extend this to support multiple stores per pointer per lane + if (StoresVec.size() > Lane) + continue; + // Skip if in different BBs. + if (!StoresVec.empty() && + SI->getParent() != StoresVec.back()->getParent()) + continue; + // Make sure that the stores are of the same type. + if (!StoresVec.empty() && + SI->getValueOperand()->getType() != + StoresVec.back()->getValueOperand()->getType()) + continue; + StoresVec.push_back(SI); + } + } + return PtrToStoresMap; +} + +bool BoUpSLP::CanFormVector(const SmallVector<StoreInst *, 4> &StoresVec, + OrdersType &ReorderIndices) const { + // We check whether the stores in StoreVec can form a vector by sorting them + // and checking whether they are consecutive. + + // To avoid calling getPointersDiff() while sorting we create a vector of + // pairs {store, offset from first} and sort this instead. + SmallVector<std::pair<StoreInst *, int>, 4> StoreOffsetVec(StoresVec.size()); + StoreInst *S0 = StoresVec[0]; + StoreOffsetVec[0] = {S0, 0}; + Type *S0Ty = S0->getValueOperand()->getType(); + Value *S0Ptr = S0->getPointerOperand(); + for (unsigned Idx : seq<unsigned>(1, StoresVec.size())) { + StoreInst *SI = StoresVec[Idx]; + Optional<int> Diff = + getPointersDiff(S0Ty, S0Ptr, SI->getValueOperand()->getType(), + SI->getPointerOperand(), *DL, *SE, + /*StrictCheck=*/true); + // We failed to compare the pointers so just abandon this StoresVec. + if (!Diff) + return false; + StoreOffsetVec[Idx] = {StoresVec[Idx], *Diff}; + } + + // Sort the vector based on the pointers. We create a copy because we may + // need the original later for calculating the reorder (shuffle) indices. + stable_sort(StoreOffsetVec, [](const std::pair<StoreInst *, int> &Pair1, + const std::pair<StoreInst *, int> &Pair2) { + int Offset1 = Pair1.second; + int Offset2 = Pair2.second; + return Offset1 < Offset2; + }); + + // Check if the stores are consecutive by checking if their difference is 1. + for (unsigned Idx : seq<unsigned>(1, StoreOffsetVec.size())) + if (StoreOffsetVec[Idx].second != StoreOffsetVec[Idx-1].second + 1) + return false; + + // Calculate the shuffle indices according to their offset against the sorted + // StoreOffsetVec. 
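+ // To make the index computation below concrete, a small worked example with
+ // invented offsets: suppose the collected stores appear in lane order at
+ // element offsets {2, 0, 3, 1} from the first store. Sorting by offset gives
+ // the memory order {0, 1, 2, 3}, and looking each original store up in that
+ // sorted sequence yields ReorderIndices = {2, 0, 3, 1}, i.e. lane 0 belongs
+ // at position 2 of the consecutive run, lane 1 at position 0, and so on.
+ // A standalone sketch of the same mapping in plain C++ (invented name):
+ //
+ //   std::vector<unsigned> reorderFromOffsets(const std::vector<int> &Offs) {
+ //     std::vector<int> Sorted(Offs);
+ //     std::sort(Sorted.begin(), Sorted.end());
+ //     std::vector<unsigned> Order;
+ //     for (int O : Offs)
+ //       Order.push_back(static_cast<unsigned>(
+ //           std::find(Sorted.begin(), Sorted.end(), O) - Sorted.begin()));
+ //     return Order; // {2, 0, 3, 1} for the offsets above
+ //   }
+ //
+ // If the lanes already come in memory order the result is {0, 1, 2, 3},
+ // which the code below then clears to the empty "identity" order.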
+ ReorderIndices.reserve(StoresVec.size()); + for (StoreInst *SI : StoresVec) { + unsigned Idx = find_if(StoreOffsetVec, + [SI](const std::pair<StoreInst *, int> &Pair) { + return Pair.first == SI; + }) - + StoreOffsetVec.begin(); + ReorderIndices.push_back(Idx); + } + // Identity order (e.g., {0,1,2,3}) is modeled as an empty OrdersType in + // reorderTopToBottom() and reorderBottomToTop(), so we are following the + // same convention here. + auto IsIdentityOrder = [](const OrdersType &Order) { + for (unsigned Idx : seq<unsigned>(0, Order.size())) + if (Idx != Order[Idx]) + return false; + return true; + }; + if (IsIdentityOrder(ReorderIndices)) + ReorderIndices.clear(); + + return true; +} + +#ifndef NDEBUG +LLVM_DUMP_METHOD static void dumpOrder(const BoUpSLP::OrdersType &Order) { + for (unsigned Idx : Order) + dbgs() << Idx << ", "; + dbgs() << "\n"; +} +#endif + +SmallVector<BoUpSLP::OrdersType, 1> +BoUpSLP::findExternalStoreUsersReorderIndices(TreeEntry *TE) const { + unsigned NumLanes = TE->Scalars.size(); + + DenseMap<Value *, SmallVector<StoreInst *, 4>> PtrToStoresMap = + collectUserStores(TE); + + // Holds the reorder indices for each candidate store vector that is a user of + // the current TreeEntry. + SmallVector<OrdersType, 1> ExternalReorderIndices; + + // Now inspect the stores collected per pointer and look for vectorization + // candidates. For each candidate calculate the reorder index vector and push + // it into `ExternalReorderIndices` + for (const auto &Pair : PtrToStoresMap) { + auto &StoresVec = Pair.second; + // If we have fewer than NumLanes stores, then we can't form a vector. + if (StoresVec.size() != NumLanes) + continue; + + // If the stores are not consecutive then abandon this StoresVec. + OrdersType ReorderIndices; + if (!CanFormVector(StoresVec, ReorderIndices)) + continue; + + // We now know that the scalars in StoresVec can form a vector instruction, + // so set the reorder indices. + ExternalReorderIndices.push_back(ReorderIndices); + } + return ExternalReorderIndices; +} + void BoUpSLP::buildTree(ArrayRef<Value *> Roots, - ArrayRef<Value *> UserIgnoreLst) { + const SmallDenseSet<Value *> &UserIgnoreLst) { deleteTree(); - UserIgnoreList = UserIgnoreLst; + UserIgnoreList = &UserIgnoreLst; if (!allSameType(Roots)) return; buildTree_rec(Roots, 0, EdgeInfo()); } -namespace { -/// Tracks the state we can represent the loads in the given sequence. -enum class LoadsState { Gather, Vectorize, ScatterVectorize }; -} // anonymous namespace - -/// Checks if the given array of loads can be represented as a vectorized, -/// scatter or just simple gather. -static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, - const TargetTransformInfo &TTI, - const DataLayout &DL, ScalarEvolution &SE, - SmallVectorImpl<unsigned> &Order, - SmallVectorImpl<Value *> &PointerOps) { - // Check that a vectorized load would load the same memory as a scalar - // load. For example, we don't want to vectorize loads that are smaller - // than 8-bit. Even though we have a packed struct {<i2, i2, i2, i2>} LLVM - // treats loading/storing it as an i8 struct. If we vectorize loads/stores - // from such a struct, we read/write packed bits disagreeing with the - // unvectorized version. 
- Type *ScalarTy = VL0->getType(); - - if (DL.getTypeSizeInBits(ScalarTy) != DL.getTypeAllocSizeInBits(ScalarTy)) - return LoadsState::Gather; +void BoUpSLP::buildTree(ArrayRef<Value *> Roots) { + deleteTree(); + if (!allSameType(Roots)) + return; + buildTree_rec(Roots, 0, EdgeInfo()); +} - // Make sure all loads in the bundle are simple - we can't vectorize - // atomic or volatile loads. - PointerOps.clear(); - PointerOps.resize(VL.size()); - auto *POIter = PointerOps.begin(); +/// \return true if the specified list of values has only one instruction that +/// requires scheduling, false otherwise. +#ifndef NDEBUG +static bool needToScheduleSingleInstruction(ArrayRef<Value *> VL) { + Value *NeedsScheduling = nullptr; for (Value *V : VL) { - auto *L = cast<LoadInst>(V); - if (!L->isSimple()) - return LoadsState::Gather; - *POIter = L->getPointerOperand(); - ++POIter; + if (doesNotNeedToBeScheduled(V)) + continue; + if (!NeedsScheduling) { + NeedsScheduling = V; + continue; + } + return false; } + return NeedsScheduling; +} +#endif - Order.clear(); - // Check the order of pointer operands. - if (llvm::sortPtrAccesses(PointerOps, ScalarTy, DL, SE, Order)) { - Value *Ptr0; - Value *PtrN; - if (Order.empty()) { - Ptr0 = PointerOps.front(); - PtrN = PointerOps.back(); +/// Generates key/subkey pair for the given value to provide effective sorting +/// of the values and better detection of the vectorizable values sequences. The +/// keys/subkeys can be used for better sorting of the values themselves (keys) +/// and in values subgroups (subkeys). +static std::pair<size_t, size_t> generateKeySubkey( + Value *V, const TargetLibraryInfo *TLI, + function_ref<hash_code(size_t, LoadInst *)> LoadsSubkeyGenerator, + bool AllowAlternate) { + hash_code Key = hash_value(V->getValueID() + 2); + hash_code SubKey = hash_value(0); + // Sort the loads by the distance between the pointers. + if (auto *LI = dyn_cast<LoadInst>(V)) { + Key = hash_combine(hash_value(Instruction::Load), Key); + if (LI->isSimple()) + SubKey = hash_value(LoadsSubkeyGenerator(Key, LI)); + else + SubKey = hash_value(LI); + } else if (isVectorLikeInstWithConstOps(V)) { + // Sort extracts by the vector operands. + if (isa<ExtractElementInst, UndefValue>(V)) + Key = hash_value(Value::UndefValueVal + 1); + if (auto *EI = dyn_cast<ExtractElementInst>(V)) { + if (!isUndefVector(EI->getVectorOperand()) && + !isa<UndefValue>(EI->getIndexOperand())) + SubKey = hash_value(EI->getVectorOperand()); + } + } else if (auto *I = dyn_cast<Instruction>(V)) { + // Sort other instructions just by the opcodes except for CMPInst. + // For CMP also sort by the predicate kind. + if ((isa<BinaryOperator>(I) || isa<CastInst>(I)) && + isValidForAlternation(I->getOpcode())) { + if (AllowAlternate) + Key = hash_value(isa<BinaryOperator>(I) ? 1 : 0); + else + Key = hash_combine(hash_value(I->getOpcode()), Key); + SubKey = hash_combine( + hash_value(I->getOpcode()), hash_value(I->getType()), + hash_value(isa<BinaryOperator>(I) + ? I->getType() + : cast<CastInst>(I)->getOperand(0)->getType())); + // For casts, look through the only operand to improve compile time. 
+ if (isa<CastInst>(I)) { + std::pair<size_t, size_t> OpVals = + generateKeySubkey(I->getOperand(0), TLI, LoadsSubkeyGenerator, + /*=AllowAlternate*/ true); + Key = hash_combine(OpVals.first, Key); + SubKey = hash_combine(OpVals.first, SubKey); + } + } else if (auto *CI = dyn_cast<CmpInst>(I)) { + CmpInst::Predicate Pred = CI->getPredicate(); + if (CI->isCommutative()) + Pred = std::min(Pred, CmpInst::getInversePredicate(Pred)); + CmpInst::Predicate SwapPred = CmpInst::getSwappedPredicate(Pred); + SubKey = hash_combine(hash_value(I->getOpcode()), hash_value(Pred), + hash_value(SwapPred), + hash_value(CI->getOperand(0)->getType())); + } else if (auto *Call = dyn_cast<CallInst>(I)) { + Intrinsic::ID ID = getVectorIntrinsicIDForCall(Call, TLI); + if (isTriviallyVectorizable(ID)) { + SubKey = hash_combine(hash_value(I->getOpcode()), hash_value(ID)); + } else if (!VFDatabase(*Call).getMappings(*Call).empty()) { + SubKey = hash_combine(hash_value(I->getOpcode()), + hash_value(Call->getCalledFunction())); + } else { + Key = hash_combine(hash_value(Call), Key); + SubKey = hash_combine(hash_value(I->getOpcode()), hash_value(Call)); + } + for (const CallBase::BundleOpInfo &Op : Call->bundle_op_infos()) + SubKey = hash_combine(hash_value(Op.Begin), hash_value(Op.End), + hash_value(Op.Tag), SubKey); + } else if (auto *Gep = dyn_cast<GetElementPtrInst>(I)) { + if (Gep->getNumOperands() == 2 && isa<ConstantInt>(Gep->getOperand(1))) + SubKey = hash_value(Gep->getPointerOperand()); + else + SubKey = hash_value(Gep); + } else if (BinaryOperator::isIntDivRem(I->getOpcode()) && + !isa<ConstantInt>(I->getOperand(1))) { + // Do not try to vectorize instructions with potentially high cost. + SubKey = hash_value(I); } else { - Ptr0 = PointerOps[Order.front()]; - PtrN = PointerOps[Order.back()]; + SubKey = hash_value(I->getOpcode()); } - Optional<int> Diff = - getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, DL, SE); - // Check that the sorted loads are consecutive. - if (static_cast<unsigned>(*Diff) == VL.size() - 1) - return LoadsState::Vectorize; - Align CommonAlignment = cast<LoadInst>(VL0)->getAlign(); - for (Value *V : VL) - CommonAlignment = - commonAlignment(CommonAlignment, cast<LoadInst>(V)->getAlign()); - if (TTI.isLegalMaskedGather(FixedVectorType::get(ScalarTy, VL.size()), - CommonAlignment)) - return LoadsState::ScatterVectorize; + Key = hash_combine(hash_value(I->getParent()), Key); } - - return LoadsState::Gather; + return std::make_pair(Key, SubKey); } void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, @@ -3651,10 +4629,84 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // If all of the operands are identical or constant we have a simple solution. // If we deal with insert/extract instructions, they all must have constant // indices, otherwise we should gather them, not try to vectorize. - if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !S.getOpcode() || - (isa<InsertElementInst, ExtractValueInst, ExtractElementInst>(S.MainOp) && - !all_of(VL, isVectorLikeInstWithConstOps))) { - LLVM_DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O. \n"); + // If alternate op node with 2 elements with gathered operands - do not + // vectorize. 
+ auto &&NotProfitableForVectorization = [&S, this, + Depth](ArrayRef<Value *> VL) { + if (!S.getOpcode() || !S.isAltShuffle() || VL.size() > 2) + return false; + if (VectorizableTree.size() < MinTreeSize) + return false; + if (Depth >= RecursionMaxDepth - 1) + return true; + // Check if all operands are extracts, part of vector node or can build a + // regular vectorize node. + SmallVector<unsigned, 2> InstsCount(VL.size(), 0); + for (Value *V : VL) { + auto *I = cast<Instruction>(V); + InstsCount.push_back(count_if(I->operand_values(), [](Value *Op) { + return isa<Instruction>(Op) || isVectorLikeInstWithConstOps(Op); + })); + } + bool IsCommutative = isCommutative(S.MainOp) || isCommutative(S.AltOp); + if ((IsCommutative && + std::accumulate(InstsCount.begin(), InstsCount.end(), 0) < 2) || + (!IsCommutative && + all_of(InstsCount, [](unsigned ICnt) { return ICnt < 2; }))) + return true; + assert(VL.size() == 2 && "Expected only 2 alternate op instructions."); + SmallVector<SmallVector<std::pair<Value *, Value *>>> Candidates; + auto *I1 = cast<Instruction>(VL.front()); + auto *I2 = cast<Instruction>(VL.back()); + for (int Op = 0, E = S.MainOp->getNumOperands(); Op < E; ++Op) + Candidates.emplace_back().emplace_back(I1->getOperand(Op), + I2->getOperand(Op)); + if (static_cast<unsigned>(count_if( + Candidates, [this](ArrayRef<std::pair<Value *, Value *>> Cand) { + return findBestRootPair(Cand, LookAheadHeuristics::ScoreSplat); + })) >= S.MainOp->getNumOperands() / 2) + return false; + if (S.MainOp->getNumOperands() > 2) + return true; + if (IsCommutative) { + // Check permuted operands. + Candidates.clear(); + for (int Op = 0, E = S.MainOp->getNumOperands(); Op < E; ++Op) + Candidates.emplace_back().emplace_back(I1->getOperand(Op), + I2->getOperand((Op + 1) % E)); + if (any_of( + Candidates, [this](ArrayRef<std::pair<Value *, Value *>> Cand) { + return findBestRootPair(Cand, LookAheadHeuristics::ScoreSplat); + })) + return false; + } + return true; + }; + SmallVector<unsigned> SortedIndices; + BasicBlock *BB = nullptr; + bool AreAllSameInsts = + (S.getOpcode() && allSameBlock(VL)) || + (S.OpValue->getType()->isPointerTy() && UserTreeIdx.UserTE && + UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize && + VL.size() > 2 && + all_of(VL, + [&BB](Value *V) { + auto *I = dyn_cast<GetElementPtrInst>(V); + if (!I) + return doesNotNeedToBeScheduled(V); + if (!BB) + BB = I->getParent(); + return BB == I->getParent() && I->getNumOperands() == 2; + }) && + BB && + sortPtrAccesses(VL, UserTreeIdx.UserTE->getMainOp()->getType(), *DL, *SE, + SortedIndices)); + if (allConstant(VL) || isSplat(VL) || !AreAllSameInsts || + (isa<InsertElementInst, ExtractValueInst, ExtractElementInst>( + S.OpValue) && + !all_of(VL, isVectorLikeInstWithConstOps)) || + NotProfitableForVectorization(VL)) { + LLVM_DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O, small shuffle. \n"); if (TryToFindDuplicates(S)) newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, ReuseShuffleIndicies); @@ -3665,12 +4717,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // the same block. // Don't vectorize ephemeral values. 
- for (Value *V : VL) { - if (EphValues.count(V)) { - LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V - << ") is ephemeral.\n"); - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); - return; + if (!EphValues.empty()) { + for (Value *V : VL) { + if (EphValues.count(V)) { + LLVM_DEBUG(dbgs() << "SLP: The instruction (" << *V + << ") is ephemeral.\n"); + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx); + return; + } } } @@ -3708,20 +4762,37 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } // The reduction nodes (stored in UserIgnoreList) also should stay scalar. - for (Value *V : VL) { - if (is_contained(UserIgnoreList, V)) { - LLVM_DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); - if (TryToFindDuplicates(S)) - newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - return; + if (UserIgnoreList && !UserIgnoreList->empty()) { + for (Value *V : VL) { + if (UserIgnoreList && UserIgnoreList->contains(V)) { + LLVM_DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); + if (TryToFindDuplicates(S)) + newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, + ReuseShuffleIndicies); + return; + } } } + // Special processing for sorted pointers for ScatterVectorize node with + // constant indeces only. + if (AreAllSameInsts && !(S.getOpcode() && allSameBlock(VL)) && + UserTreeIdx.UserTE && + UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize) { + assert(S.OpValue->getType()->isPointerTy() && + count_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }) >= + 2 && + "Expected pointers only."); + // Reset S to make it GetElementPtr kind of node. + const auto *It = find_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }); + assert(It != VL.end() && "Expected at least one GEP."); + S = getSameOpcode(*It); + } + // Check that all of the users of the scalars that we want to vectorize are // schedulable. auto *VL0 = cast<Instruction>(S.OpValue); - BasicBlock *BB = VL0->getParent(); + BB = VL0->getParent(); if (!DT->isReachableFromEntry(BB)) { // Don't go into unreachable blocks. They may contain instructions with @@ -3739,9 +4810,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (!BSRef) BSRef = std::make_unique<BlockScheduling>(BB); - BlockScheduling &BS = *BSRef.get(); + BlockScheduling &BS = *BSRef; Optional<ScheduleData *> Bundle = BS.tryScheduleBundle(VL, this, S); +#ifdef EXPENSIVE_CHECKS + // Make sure we didn't break any internal invariants + BS.verify(); +#endif if (!Bundle) { LLVM_DEBUG(dbgs() << "SLP: We are not able to schedule this bundle!\n"); assert((!BS.getScheduleData(VL0) || @@ -3761,10 +4836,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Check for terminator values (e.g. invoke). 
for (Value *V : VL) - for (unsigned I = 0, E = PH->getNumIncomingValues(); I < E; ++I) { - Instruction *Term = dyn_cast<Instruction>( - cast<PHINode>(V)->getIncomingValueForBlock( - PH->getIncomingBlock(I))); + for (Value *Incoming : cast<PHINode>(V)->incoming_values()) { + Instruction *Term = dyn_cast<Instruction>(Incoming); if (Term && Term->isTerminator()) { LLVM_DEBUG(dbgs() << "SLP: Need to swizzle PHINodes (terminator use).\n"); @@ -3908,7 +4981,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, SmallVector<Value *> PointerOps; OrdersType CurrentOrder; TreeEntry *TE = nullptr; - switch (canVectorizeLoads(VL, VL0, *TTI, *DL, *SE, CurrentOrder, + switch (canVectorizeLoads(VL, VL0, *TTI, *DL, *SE, *LI, CurrentOrder, PointerOps)) { case LoadsState::Vectorize: if (CurrentOrder.empty()) { @@ -4089,7 +5162,10 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, case Instruction::GetElementPtr: { // We don't combine GEPs with complicated (nested) indexing. for (Value *V : VL) { - if (cast<Instruction>(V)->getNumOperands() != 2) { + auto *I = dyn_cast<GetElementPtrInst>(V); + if (!I) + continue; + if (I->getNumOperands() != 2) { LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n"); BS.cancelScheduling(VL, VL0); newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx, @@ -4100,9 +5176,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // We can't combine several GEPs into one vector if they operate on // different types. - Type *Ty0 = VL0->getOperand(0)->getType(); + Type *Ty0 = cast<GEPOperator>(VL0)->getSourceElementType(); for (Value *V : VL) { - Type *CurTy = cast<Instruction>(V)->getOperand(0)->getType(); + auto *GEP = dyn_cast<GEPOperator>(V); + if (!GEP) + continue; + Type *CurTy = GEP->getSourceElementType(); if (Ty0 != CurTy) { LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (different types).\n"); @@ -4113,15 +5192,22 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } } + bool IsScatterUser = + UserTreeIdx.UserTE && + UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize; // We don't combine GEPs with non-constant indexes. Type *Ty1 = VL0->getOperand(1)->getType(); for (Value *V : VL) { - auto Op = cast<Instruction>(V)->getOperand(1); - if (!isa<ConstantInt>(Op) || + auto *I = dyn_cast<GetElementPtrInst>(V); + if (!I) + continue; + auto *Op = I->getOperand(1); + if ((!IsScatterUser && !isa<ConstantInt>(Op)) || (Op->getType() != Ty1 && - Op->getType()->getScalarSizeInBits() > - DL->getIndexSizeInBits( - V->getType()->getPointerAddressSpace()))) { + ((IsScatterUser && !isa<ConstantInt>(Op)) || + Op->getType()->getScalarSizeInBits() > + DL->getIndexSizeInBits( + V->getType()->getPointerAddressSpace())))) { LLVM_DEBUG(dbgs() << "SLP: not-vectorizable GEP (non-constant indexes).\n"); BS.cancelScheduling(VL, VL0); @@ -4136,9 +5222,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: added a vector of GEPs.\n"); SmallVector<ValueList, 2> Operands(2); // Prepare the operand vector for pointer operands. - for (Value *V : VL) - Operands.front().push_back( - cast<GetElementPtrInst>(V)->getPointerOperand()); + for (Value *V : VL) { + auto *GEP = dyn_cast<GetElementPtrInst>(V); + if (!GEP) { + Operands.front().push_back(V); + continue; + } + Operands.front().push_back(GEP->getPointerOperand()); + } TE->setOperand(0, Operands.front()); // Need to cast all indices to the same type before vectorization to // avoid crash. 
@@ -4149,9 +5240,10 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Type *VL0Ty = VL0->getOperand(IndexIdx)->getType(); Type *Ty = all_of(VL, [VL0Ty, IndexIdx](Value *V) { - return VL0Ty == cast<GetElementPtrInst>(V) - ->getOperand(IndexIdx) - ->getType(); + auto *GEP = dyn_cast<GetElementPtrInst>(V); + if (!GEP) + return true; + return VL0Ty == GEP->getOperand(IndexIdx)->getType(); }) ? VL0Ty : DL->getIndexType(cast<GetElementPtrInst>(VL0) @@ -4159,10 +5251,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, ->getScalarType()); // Prepare the operand vector. for (Value *V : VL) { - auto *Op = cast<Instruction>(V)->getOperand(IndexIdx); - auto *CI = cast<ConstantInt>(Op); - Operands.back().push_back(ConstantExpr::getIntegerCast( - CI, Ty, CI->getValue().isSignBitSet())); + auto *I = dyn_cast<GetElementPtrInst>(V); + if (!I) { + Operands.back().push_back( + ConstantInt::get(Ty, 0, /*isSigned=*/false)); + continue; + } + auto *Op = I->getOperand(IndexIdx); + auto *CI = dyn_cast<ConstantInt>(Op); + if (!CI) + Operands.back().push_back(Op); + else + Operands.back().push_back(ConstantExpr::getIntegerCast( + CI, Ty, CI->getValue().isSignBitSet())); } TE->setOperand(IndexIdx, Operands.back()); @@ -4268,7 +5369,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, unsigned NumArgs = CI->arg_size(); SmallVector<Value*, 4> ScalarArgs(NumArgs, nullptr); for (unsigned j = 0; j != NumArgs; ++j) - if (hasVectorInstrinsicScalarOpd(ID, j)) + if (isVectorIntrinsicWithScalarOpAtArg(ID, j)) ScalarArgs[j] = CI->getArgOperand(j); for (Value *V : VL) { CallInst *CI2 = dyn_cast<CallInst>(V); @@ -4287,7 +5388,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Some intrinsics have scalar arguments and should be same in order for // them to be vectorized. for (unsigned j = 0; j != NumArgs; ++j) { - if (hasVectorInstrinsicScalarOpd(ID, j)) { + if (isVectorIntrinsicWithScalarOpAtArg(ID, j)) { Value *A1J = CI2->getArgOperand(j); if (ScalarArgs[j] != A1J) { BS.cancelScheduling(VL, VL0); @@ -4320,7 +5421,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (unsigned i = 0, e = CI->arg_size(); i != e; ++i) { // For scalar operands no need to to create an entry since no need to // vectorize it. - if (hasVectorInstrinsicScalarOpd(ID, i)) + if (isVectorIntrinsicWithScalarOpAtArg(ID, i)) continue; ValueList Operands; // Prepare the operand vector. @@ -4347,9 +5448,42 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n"); // Reorder operands if reordering would enable vectorization. - if (isa<BinaryOperator>(VL0)) { + auto *CI = dyn_cast<CmpInst>(VL0); + if (isa<BinaryOperator>(VL0) || CI) { ValueList Left, Right; - reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE, *this); + if (!CI || all_of(VL, [](Value *V) { + return cast<CmpInst>(V)->isCommutative(); + })) { + reorderInputsAccordingToOpcode(VL, Left, Right, *DL, *SE, *this); + } else { + CmpInst::Predicate P0 = CI->getPredicate(); + CmpInst::Predicate AltP0 = cast<CmpInst>(S.AltOp)->getPredicate(); + assert(P0 != AltP0 && + "Expected different main/alternate predicates."); + CmpInst::Predicate AltP0Swapped = CmpInst::getSwappedPredicate(AltP0); + Value *BaseOp0 = VL0->getOperand(0); + Value *BaseOp1 = VL0->getOperand(1); + // Collect operands - commute if it uses the swapped predicate or + // alternate operation. 
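+ // The commuting rule above is easier to see on a tiny example with invented
+ // values. Two scalar compares can express the same test with their operands
+ // written in either order:
+ //
+ //   %c1 = icmp slt i32 %a, %b   ; uses the bundle's main predicate directly
+ //   %c2 = icmp sgt i32 %b, %a   ; the swapped form of the same comparison
+ //
+ // Swapping the operands of %c2 rewrites it as icmp slt %a, %b, so both lanes
+ // feed the eventual vector compares with operands in a consistent order. The
+ // loop below appears to apply the same idea per lane: when a lane's predicate
+ // is neither the main nor the alternate predicate as written, or when operand
+ // compatibility says the other order lines up better, its LHS/RHS are swapped
+ // before being pushed into Left/Right.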
+ for (Value *V : VL) { + auto *Cmp = cast<CmpInst>(V); + Value *LHS = Cmp->getOperand(0); + Value *RHS = Cmp->getOperand(1); + CmpInst::Predicate CurrentPred = Cmp->getPredicate(); + if (P0 == AltP0Swapped) { + if (CI != Cmp && S.AltOp != Cmp && + ((P0 == CurrentPred && + !areCompatibleCmpOps(BaseOp0, BaseOp1, LHS, RHS)) || + (AltP0 == CurrentPred && + areCompatibleCmpOps(BaseOp0, BaseOp1, LHS, RHS)))) + std::swap(LHS, RHS); + } else if (P0 != CurrentPred && AltP0 != CurrentPred) { + std::swap(LHS, RHS); + } + Left.push_back(LHS); + Right.push_back(RHS); + } + } TE->setOperand(0, Left); TE->setOperand(1, Right); buildTree_rec(Left, Depth + 1, {TE, 0}); @@ -4493,7 +5627,9 @@ bool BoUpSLP::areAllUsersVectorized(Instruction *I, ArrayRef<Value *> VectorizedVals) const { return (I->hasOneUse() && is_contained(VectorizedVals, I)) || all_of(I->users(), [this](User *U) { - return ScalarToTreeEntry.count(U) > 0 || MustGather.contains(U); + return ScalarToTreeEntry.count(U) > 0 || + isVectorLikeInstWithConstOps(U) || + (isa<ExtractElementInst>(U) && MustGather.contains(U)); }); } @@ -4550,19 +5686,21 @@ computeExtractCost(ArrayRef<Value *> VL, FixedVectorType *VecTy, // Process extracts in blocks of EltsPerVector to check if the source vector // operand can be re-used directly. If not, add the cost of creating a shuffle // to extract the values into a vector register. + SmallVector<int> RegMask(EltsPerVector, UndefMaskElem); for (auto *V : VL) { ++Idx; - // Need to exclude undefs from analysis. - if (isa<UndefValue>(V) || Mask[Idx] == UndefMaskElem) - continue; - // Reached the start of a new vector registers. if (Idx % EltsPerVector == 0) { + RegMask.assign(EltsPerVector, UndefMaskElem); AllConsecutive = true; continue; } + // Need to exclude undefs from analysis. + if (isa<UndefValue>(V) || Mask[Idx] == UndefMaskElem) + continue; + // Check all extracts for a vector register on the target directly // extract values in order. unsigned CurrentIdx = *getExtractIndex(cast<Instruction>(V)); @@ -4570,6 +5708,7 @@ computeExtractCost(ArrayRef<Value *> VL, FixedVectorType *VecTy, unsigned PrevIdx = *getExtractIndex(cast<Instruction>(VL[Idx - 1])); AllConsecutive &= PrevIdx + 1 == CurrentIdx && CurrentIdx % EltsPerVector == Idx % EltsPerVector; + RegMask[Idx % EltsPerVector] = CurrentIdx % EltsPerVector; } if (AllConsecutive) @@ -4581,10 +5720,10 @@ computeExtractCost(ArrayRef<Value *> VL, FixedVectorType *VecTy, // If we have a series of extracts which are not consecutive and hence // cannot re-use the source vector register directly, compute the shuffle - // cost to extract the a vector with EltsPerVector elements. + // cost to extract the vector with EltsPerVector elements. Cost += TTI.getShuffleCost( TargetTransformInfo::SK_PermuteSingleSrc, - FixedVectorType::get(VecTy->getElementType(), EltsPerVector)); + FixedVectorType::get(VecTy->getElementType(), EltsPerVector), RegMask); } return Cost; } @@ -4592,12 +5731,12 @@ computeExtractCost(ArrayRef<Value *> VL, FixedVectorType *VecTy, /// Build shuffle mask for shuffle graph entries and lists of main and alternate /// operations operands. 
static void -buildSuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices, - ArrayRef<int> ReusesIndices, - const function_ref<bool(Instruction *)> IsAltOp, - SmallVectorImpl<int> &Mask, - SmallVectorImpl<Value *> *OpScalars = nullptr, - SmallVectorImpl<Value *> *AltScalars = nullptr) { +buildShuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices, + ArrayRef<int> ReusesIndices, + const function_ref<bool(Instruction *)> IsAltOp, + SmallVectorImpl<int> &Mask, + SmallVectorImpl<Value *> *OpScalars = nullptr, + SmallVectorImpl<Value *> *AltScalars = nullptr) { unsigned Sz = VL.size(); Mask.assign(Sz, UndefMaskElem); SmallVector<int> OrderMask; @@ -4627,6 +5766,29 @@ buildSuffleEntryMask(ArrayRef<Value *> VL, ArrayRef<unsigned> ReorderIndices, } } +/// Checks if the specified instruction \p I is an alternate operation for the +/// given \p MainOp and \p AltOp instructions. +static bool isAlternateInstruction(const Instruction *I, + const Instruction *MainOp, + const Instruction *AltOp) { + if (auto *CI0 = dyn_cast<CmpInst>(MainOp)) { + auto *AltCI0 = cast<CmpInst>(AltOp); + auto *CI = cast<CmpInst>(I); + CmpInst::Predicate P0 = CI0->getPredicate(); + CmpInst::Predicate AltP0 = AltCI0->getPredicate(); + assert(P0 != AltP0 && "Expected different main/alternate predicates."); + CmpInst::Predicate AltP0Swapped = CmpInst::getSwappedPredicate(AltP0); + CmpInst::Predicate CurrentPred = CI->getPredicate(); + if (P0 == AltP0Swapped) + return I == AltCI0 || + (I != MainOp && + !areCompatibleCmpOps(CI0->getOperand(0), CI0->getOperand(1), + CI->getOperand(0), CI->getOperand(1))); + return AltP0 == CurrentPred || AltP0Swapped == CurrentPred; + } + return I->getOpcode() == AltOp->getOpcode(); +} + InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals) { ArrayRef<Value*> VL = E->Scalars; @@ -4740,7 +5902,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, SmallVector<const TreeEntry *> Entries; Optional<TargetTransformInfo::ShuffleKind> Shuffle = isGatherShuffledEntry(E, Mask, Entries); - if (Shuffle.hasValue()) { + if (Shuffle) { InstructionCost GatherCost = 0; if (ShuffleVectorInst::isIdentityMask(Mask)) { // Perfect match in the graph, will reuse the previously vectorized @@ -4776,7 +5938,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, SmallVector<int> Mask; Optional<TargetTransformInfo::ShuffleKind> ShuffleKind = isFixedVectorShuffle(VL, Mask); - if (ShuffleKind.hasValue()) { + if (ShuffleKind) { // Found the bunch of extractelement instructions that must be gathered // into a vector and can be represented as a permutation elements in a // single input vector or of 2 input vectors. @@ -4794,7 +5956,9 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, // broadcast. 
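The alternate-compare support in this change (the operand commuting in buildTree_rec and isAlternateInstruction above) rests on one identity: swapping a compare's operands is the same as swapping its predicate, e.g. icmp slt a, b is icmp sgt b, a. A standalone model of the classification, with a reduced predicate enum and without the areCompatibleCmpOps operand check the real code adds when the main predicate equals the swapped alternate one:

#include <cassert>

enum class Pred { SLT, SGT, SLE, SGE, EQ, NE };

// Swapping the operands of a compare corresponds to swapping its predicate;
// EQ and NE are symmetric and map to themselves.
static Pred swapPred(Pred P) {
  switch (P) {
  case Pred::SLT: return Pred::SGT;
  case Pred::SGT: return Pred::SLT;
  case Pred::SLE: return Pred::SGE;
  case Pred::SGE: return Pred::SLE;
  default:        return P;
  }
}

// Decide whether a compare with predicate Cur belongs to the alternate group
// of a bundle whose main/alternate predicates are P0/AltP0. Simplified: the
// real isAlternateInstruction also inspects operand order in the ambiguous
// P0 == swapPred(AltP0) case.
static bool isAlternatePred(Pred P0, Pred AltP0, Pred Cur) {
  assert(P0 != AltP0 && "main and alternate predicates must differ");
  return Cur == AltP0 || Cur == swapPred(AltP0);
}

With operands commuted where needed, the whole bundle is then emitted as one compare with the main predicate, one with the alternate predicate, and a shuffle that blends the two results lane by lane.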
assert(VecTy == FinalVecTy && "No reused scalars expected for broadcast."); - return TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy); + return TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, + /*Mask=*/None, /*Index=*/0, + /*SubTp=*/nullptr, /*Args=*/VL[0]); } InstructionCost ReuseShuffleCost = 0; if (NeedToShuffleReuses) @@ -4818,8 +5982,9 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, !VectorizedLoads.count(Slice.back()) && allSameBlock(Slice)) { SmallVector<Value *> PointerOps; OrdersType CurrentOrder; - LoadsState LS = canVectorizeLoads(Slice, Slice.front(), *TTI, *DL, - *SE, CurrentOrder, PointerOps); + LoadsState LS = + canVectorizeLoads(Slice, Slice.front(), *TTI, *DL, *SE, *LI, + CurrentOrder, PointerOps); switch (LS) { case LoadsState::Vectorize: case LoadsState::ScatterVectorize: @@ -4909,7 +6074,11 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, assert((E->State == TreeEntry::Vectorize || E->State == TreeEntry::ScatterVectorize) && "Unhandled state"); - assert(E->getOpcode() && allSameType(VL) && allSameBlock(VL) && "Invalid VL"); + assert(E->getOpcode() && + ((allSameType(VL) && allSameBlock(VL)) || + (E->getOpcode() == Instruction::GetElementPtr && + E->getMainOp()->getType()->isPointerTy())) && + "Invalid VL"); Instruction *VL0 = E->getMainOp(); unsigned ShuffleOrOp = E->isAltShuffle() ? (unsigned)Instruction::ShuffleVector : E->getOpcode(); @@ -4981,28 +6150,60 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, assert(E->ReuseShuffleIndices.empty() && "Unique insertelements only are expected."); auto *SrcVecTy = cast<FixedVectorType>(VL0->getType()); - unsigned const NumElts = SrcVecTy->getNumElements(); unsigned const NumScalars = VL.size(); + + unsigned NumOfParts = TTI->getNumberOfParts(SrcVecTy); + + unsigned OffsetBeg = *getInsertIndex(VL.front()); + unsigned OffsetEnd = OffsetBeg; + for (Value *V : VL.drop_front()) { + unsigned Idx = *getInsertIndex(V); + if (OffsetBeg > Idx) + OffsetBeg = Idx; + else if (OffsetEnd < Idx) + OffsetEnd = Idx; + } + unsigned VecScalarsSz = PowerOf2Ceil(NumElts); + if (NumOfParts > 0) + VecScalarsSz = PowerOf2Ceil((NumElts + NumOfParts - 1) / NumOfParts); + unsigned VecSz = + (1 + OffsetEnd / VecScalarsSz - OffsetBeg / VecScalarsSz) * + VecScalarsSz; + unsigned Offset = VecScalarsSz * (OffsetBeg / VecScalarsSz); + unsigned InsertVecSz = std::min<unsigned>( + PowerOf2Ceil(OffsetEnd - OffsetBeg + 1), + ((OffsetEnd - OffsetBeg + VecScalarsSz) / VecScalarsSz) * + VecScalarsSz); + bool IsWholeSubvector = + OffsetBeg == Offset && ((OffsetEnd + 1) % VecScalarsSz == 0); + // Check if we can safely insert a subvector. If it is not possible, just + // generate a whole-sized vector and shuffle the source vector and the new + // subvector. + if (OffsetBeg + InsertVecSz > VecSz) { + // Align OffsetBeg to generate correct mask. + OffsetBeg = alignDown(OffsetBeg, VecSz, Offset); + InsertVecSz = VecSz; + } + APInt DemandedElts = APInt::getZero(NumElts); // TODO: Add support for Instruction::InsertValue. 
SmallVector<int> Mask; if (!E->ReorderIndices.empty()) { inversePermutation(E->ReorderIndices, Mask); - Mask.append(NumElts - NumScalars, UndefMaskElem); + Mask.append(InsertVecSz - Mask.size(), UndefMaskElem); } else { - Mask.assign(NumElts, UndefMaskElem); - std::iota(Mask.begin(), std::next(Mask.begin(), NumScalars), 0); + Mask.assign(VecSz, UndefMaskElem); + std::iota(Mask.begin(), std::next(Mask.begin(), InsertVecSz), 0); } - unsigned Offset = *getInsertIndex(VL0, 0); bool IsIdentity = true; - SmallVector<int> PrevMask(NumElts, UndefMaskElem); + SmallVector<int> PrevMask(InsertVecSz, UndefMaskElem); Mask.swap(PrevMask); for (unsigned I = 0; I < NumScalars; ++I) { unsigned InsertIdx = *getInsertIndex(VL[PrevMask[I]]); DemandedElts.setBit(InsertIdx); - IsIdentity &= InsertIdx - Offset == I; - Mask[InsertIdx - Offset] = I; + IsIdentity &= InsertIdx - OffsetBeg == I; + Mask[InsertIdx - OffsetBeg] = I; } assert(Offset < NumElts && "Failed to find vector index offset"); @@ -5010,32 +6211,41 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, Cost -= TTI->getScalarizationOverhead(SrcVecTy, DemandedElts, /*Insert*/ true, /*Extract*/ false); - if (IsIdentity && NumElts != NumScalars && Offset % NumScalars != 0) { - // FIXME: Replace with SK_InsertSubvector once it is properly supported. - unsigned Sz = PowerOf2Ceil(Offset + NumScalars); - Cost += TTI->getShuffleCost( - TargetTransformInfo::SK_PermuteSingleSrc, - FixedVectorType::get(SrcVecTy->getElementType(), Sz)); - } else if (!IsIdentity) { - auto *FirstInsert = - cast<Instruction>(*find_if(E->Scalars, [E](Value *V) { - return !is_contained(E->Scalars, - cast<Instruction>(V)->getOperand(0)); - })); - if (isUndefVector(FirstInsert->getOperand(0))) { - Cost += TTI->getShuffleCost(TTI::SK_PermuteSingleSrc, SrcVecTy, Mask); + // First cost - resize to actual vector size if not identity shuffle or + // need to shift the vector. + // Do not calculate the cost if the actual size is the register size and + // we can merge this shuffle with the following SK_Select. + auto *InsertVecTy = + FixedVectorType::get(SrcVecTy->getElementType(), InsertVecSz); + if (!IsIdentity) + Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, + InsertVecTy, Mask); + auto *FirstInsert = cast<Instruction>(*find_if(E->Scalars, [E](Value *V) { + return !is_contained(E->Scalars, cast<Instruction>(V)->getOperand(0)); + })); + // Second cost - permutation with subvector, if some elements are from the + // initial vector or inserting a subvector. + // TODO: Implement the analysis of the FirstInsert->getOperand(0) + // subvector of ActualVecTy. 
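The reworked insertelement cost above no longer prices shuffles over the whole destination vector; it derives the smallest window of register-sized parts that covers the insert positions, as reported by getNumberOfParts. A standalone rendition of just that index arithmetic, with PowerOf2Ceil spelled out and the final alignDown overflow adjustment omitted:

#include <algorithm>
#include <vector>

static unsigned powerOf2Ceil(unsigned X) {
  unsigned P = 1;
  while (P < X)
    P <<= 1;
  return P;
}

// Result of the window computation: the offset and size of the covered
// register parts, and the width of the subvector actually inserted.
struct InsertWindow {
  unsigned Offset;
  unsigned VecSz;
  unsigned InsertVecSz;
};

// Given the insert positions of the bundle, the destination width NumElts and
// the number of register parts it splits into, mirror the arithmetic from the
// diff (the TTI queries and the cost calls themselves are left out).
static InsertWindow computeInsertWindow(const std::vector<unsigned> &InsertIdx,
                                        unsigned NumElts, unsigned NumOfParts) {
  unsigned OffsetBeg = InsertIdx.front(), OffsetEnd = InsertIdx.front();
  for (unsigned Idx : InsertIdx) {
    OffsetBeg = std::min(OffsetBeg, Idx);
    OffsetEnd = std::max(OffsetEnd, Idx);
  }
  unsigned VecScalarsSz = powerOf2Ceil(NumElts);
  if (NumOfParts > 0)
    VecScalarsSz = powerOf2Ceil((NumElts + NumOfParts - 1) / NumOfParts);
  unsigned VecSz =
      (1 + OffsetEnd / VecScalarsSz - OffsetBeg / VecScalarsSz) * VecScalarsSz;
  unsigned Offset = VecScalarsSz * (OffsetBeg / VecScalarsSz);
  unsigned InsertVecSz = std::min(
      powerOf2Ceil(OffsetEnd - OffsetBeg + 1),
      ((OffsetEnd - OffsetBeg + VecScalarsSz) / VecScalarsSz) * VecScalarsSz);
  return {Offset, VecSz, InsertVecSz};
}

For example, inserts at positions 1 through 4 of a 16-element vector that splits into two parts give Offset = 0, VecSz = 8 and InsertVecSz = 4, so only an 8-wide window is priced instead of the full 16 lanes.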
+ if (!isUndefVector(FirstInsert->getOperand(0)) && NumScalars != NumElts && + !IsWholeSubvector) { + if (InsertVecSz != VecSz) { + auto *ActualVecTy = + FixedVectorType::get(SrcVecTy->getElementType(), VecSz); + Cost += TTI->getShuffleCost(TTI::SK_InsertSubvector, ActualVecTy, + None, OffsetBeg - Offset, InsertVecTy); } else { - SmallVector<int> InsertMask(NumElts); - std::iota(InsertMask.begin(), InsertMask.end(), 0); - for (unsigned I = 0; I < NumElts; I++) { + for (unsigned I = 0, End = OffsetBeg - Offset; I < End; ++I) + Mask[I] = I; + for (unsigned I = OffsetBeg - Offset, End = OffsetEnd - Offset; + I <= End; ++I) if (Mask[I] != UndefMaskElem) - InsertMask[Offset + I] = NumElts + I; - } - Cost += - TTI->getShuffleCost(TTI::SK_PermuteTwoSrc, SrcVecTy, InsertMask); + Mask[I] = I + VecSz; + for (unsigned I = OffsetEnd + 1 - Offset; I < VecSz; ++I) + Mask[I] = I; + Cost += TTI->getShuffleCost(TTI::SK_PermuteTwoSrc, InsertVecTy, Mask); } } - return Cost; } case Instruction::ZExt: @@ -5116,9 +6326,8 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, // If the selects are the only uses of the compares, they will be dead // and we can adjust the cost by removing their cost. if (IntrinsicAndUse.second) - IntrinsicCost -= - TTI->getCmpSelInstrCost(Instruction::ICmp, VecTy, MaskTy, - CmpInst::BAD_ICMP_PREDICATE, CostKind); + IntrinsicCost -= TTI->getCmpSelInstrCost(Instruction::ICmp, VecTy, + MaskTy, VecPred, CostKind); VecCost = std::min(VecCost, IntrinsicCost); } LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecCost, ScalarCost)); @@ -5198,7 +6407,14 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, TargetTransformInfo::OperandValueKind Op1VK = TargetTransformInfo::OK_AnyValue; TargetTransformInfo::OperandValueKind Op2VK = - TargetTransformInfo::OK_UniformConstantValue; + any_of(VL, + [](Value *V) { + return isa<GetElementPtrInst>(V) && + !isConstant( + cast<GetElementPtrInst>(V)->getOperand(1)); + }) + ? 
TargetTransformInfo::OK_AnyValue + : TargetTransformInfo::OK_UniformConstantValue; InstructionCost ScalarEltCost = TTI->getArithmeticInstrCost( Instruction::Add, ScalarTy, CostKind, Op1VK, Op2VK); @@ -5229,7 +6445,7 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, Align CommonAlignment = Alignment; for (Value *V : VL) CommonAlignment = - commonAlignment(CommonAlignment, cast<LoadInst>(V)->getAlign()); + std::min(CommonAlignment, cast<LoadInst>(V)->getAlign()); VecLdCost = TTI->getGatherScatterOpCost( Instruction::Load, VecTy, cast<LoadInst>(VL0)->getPointerOperand(), /*VariableMask=*/false, CommonAlignment, CostKind, VL0); @@ -5279,7 +6495,8 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, ((Instruction::isBinaryOp(E->getOpcode()) && Instruction::isBinaryOp(E->getAltOpcode())) || (Instruction::isCast(E->getOpcode()) && - Instruction::isCast(E->getAltOpcode()))) && + Instruction::isCast(E->getAltOpcode())) || + (isa<CmpInst>(VL0) && isa<CmpInst>(E->getAltOp()))) && "Invalid Shuffle Vector Operand"); InstructionCost ScalarCost = 0; if (NeedToShuffleReuses) { @@ -5327,6 +6544,14 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, VecCost = TTI->getArithmeticInstrCost(E->getOpcode(), VecTy, CostKind); VecCost += TTI->getArithmeticInstrCost(E->getAltOpcode(), VecTy, CostKind); + } else if (auto *CI0 = dyn_cast<CmpInst>(VL0)) { + VecCost = TTI->getCmpSelInstrCost(E->getOpcode(), ScalarTy, + Builder.getInt1Ty(), + CI0->getPredicate(), CostKind, VL0); + VecCost += TTI->getCmpSelInstrCost( + E->getOpcode(), ScalarTy, Builder.getInt1Ty(), + cast<CmpInst>(E->getAltOp())->getPredicate(), CostKind, + E->getAltOp()); } else { Type *Src0SclTy = E->getMainOp()->getOperand(0)->getType(); Type *Src1SclTy = E->getAltOp()->getOperand(0)->getType(); @@ -5338,16 +6563,21 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, TTI::CastContextHint::None, CostKind); } - SmallVector<int> Mask; - buildSuffleEntryMask( - E->Scalars, E->ReorderIndices, E->ReuseShuffleIndices, - [E](Instruction *I) { - assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode"); - return I->getOpcode() == E->getAltOpcode(); - }, - Mask); - CommonCost = - TTI->getShuffleCost(TargetTransformInfo::SK_Select, FinalVecTy, Mask); + if (E->ReuseShuffleIndices.empty()) { + CommonCost = + TTI->getShuffleCost(TargetTransformInfo::SK_Select, FinalVecTy); + } else { + SmallVector<int> Mask; + buildShuffleEntryMask( + E->Scalars, E->ReorderIndices, E->ReuseShuffleIndices, + [E](Instruction *I) { + assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode"); + return I->getOpcode() == E->getAltOpcode(); + }, + Mask); + CommonCost = TTI->getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, + FinalVecTy, Mask); + } LLVM_DEBUG(dumpTreeCosts(E, CommonCost, VecCost, ScalarCost)); return CommonCost + VecCost - ScalarCost; } @@ -5475,7 +6705,10 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const { // No need to vectorize inserts of gathered values. 
if (VectorizableTree.size() == 2 && isa<InsertElementInst>(VectorizableTree[0]->Scalars[0]) && - VectorizableTree[1]->State == TreeEntry::NeedToGather) + VectorizableTree[1]->State == TreeEntry::NeedToGather && + (VectorizableTree[1]->getVectorFactor() <= 2 || + !(isSplat(VectorizableTree[1]->Scalars) || + allConstant(VectorizableTree[1]->Scalars)))) return true; // We can vectorize the tree if its size is greater than or equal to the @@ -5605,20 +6838,26 @@ static bool areTwoInsertFromSameBuildVector(InsertElementInst *VU, return false; auto *IE1 = VU; auto *IE2 = V; + unsigned Idx1 = *getInsertIndex(IE1); + unsigned Idx2 = *getInsertIndex(IE2); // Go through the vector operand of insertelement instructions trying to find // either VU as the original vector for IE2 or V as the original vector for // IE1. do { - if (IE2 == VU || IE1 == V) - return true; + if (IE2 == VU) + return VU->hasOneUse(); + if (IE1 == V) + return V->hasOneUse(); if (IE1) { - if (IE1 != VU && !IE1->hasOneUse()) + if ((IE1 != VU && !IE1->hasOneUse()) || + getInsertIndex(IE1).value_or(Idx2) == Idx2) IE1 = nullptr; else IE1 = dyn_cast<InsertElementInst>(IE1->getOperand(0)); } if (IE2) { - if (IE2 != V && !IE2->hasOneUse()) + if ((IE2 != V && !IE2->hasOneUse()) || + getInsertIndex(IE2).value_or(Idx1) == Idx1) IE2 = nullptr; else IE2 = dyn_cast<InsertElementInst>(IE2->getOperand(0)); @@ -5627,6 +6866,153 @@ static bool areTwoInsertFromSameBuildVector(InsertElementInst *VU, return false; } +/// Checks if the \p IE1 instructions is followed by \p IE2 instruction in the +/// buildvector sequence. +static bool isFirstInsertElement(const InsertElementInst *IE1, + const InsertElementInst *IE2) { + if (IE1 == IE2) + return false; + const auto *I1 = IE1; + const auto *I2 = IE2; + const InsertElementInst *PrevI1; + const InsertElementInst *PrevI2; + unsigned Idx1 = *getInsertIndex(IE1); + unsigned Idx2 = *getInsertIndex(IE2); + do { + if (I2 == IE1) + return true; + if (I1 == IE2) + return false; + PrevI1 = I1; + PrevI2 = I2; + if (I1 && (I1 == IE1 || I1->hasOneUse()) && + getInsertIndex(I1).value_or(Idx2) != Idx2) + I1 = dyn_cast<InsertElementInst>(I1->getOperand(0)); + if (I2 && ((I2 == IE2 || I2->hasOneUse())) && + getInsertIndex(I2).value_or(Idx1) != Idx1) + I2 = dyn_cast<InsertElementInst>(I2->getOperand(0)); + } while ((I1 && PrevI1 != I1) || (I2 && PrevI2 != I2)); + llvm_unreachable("Two different buildvectors not expected."); +} + +namespace { +/// Returns incoming Value *, if the requested type is Value * too, or a default +/// value, otherwise. +struct ValueSelect { + template <typename U> + static typename std::enable_if<std::is_same<Value *, U>::value, Value *>::type + get(Value *V) { + return V; + } + template <typename U> + static typename std::enable_if<!std::is_same<Value *, U>::value, U>::type + get(Value *) { + return U(); + } +}; +} // namespace + +/// Does the analysis of the provided shuffle masks and performs the requested +/// actions on the vectors with the given shuffle masks. It tries to do it in +/// several steps. +/// 1. If the Base vector is not undef vector, resizing the very first mask to +/// have common VF and perform action for 2 input vectors (including non-undef +/// Base). Other shuffle masks are combined with the resulting after the 1 stage +/// and processed as a shuffle of 2 elements. +/// 2. If the Base is undef vector and have only 1 shuffle mask, perform the +/// action only for 1 vector with the given mask, if it is not the identity +/// mask. +/// 3. 
If > 2 masks are used, perform the remaining shuffle actions for 2 +/// vectors, combing the masks properly between the steps. +template <typename T> +static T *performExtractsShuffleAction( + MutableArrayRef<std::pair<T *, SmallVector<int>>> ShuffleMask, Value *Base, + function_ref<unsigned(T *)> GetVF, + function_ref<std::pair<T *, bool>(T *, ArrayRef<int>)> ResizeAction, + function_ref<T *(ArrayRef<int>, ArrayRef<T *>)> Action) { + assert(!ShuffleMask.empty() && "Empty list of shuffles for inserts."); + SmallVector<int> Mask(ShuffleMask.begin()->second); + auto VMIt = std::next(ShuffleMask.begin()); + T *Prev = nullptr; + bool IsBaseNotUndef = !isUndefVector(Base); + if (IsBaseNotUndef) { + // Base is not undef, need to combine it with the next subvectors. + std::pair<T *, bool> Res = ResizeAction(ShuffleMask.begin()->first, Mask); + for (unsigned Idx = 0, VF = Mask.size(); Idx < VF; ++Idx) { + if (Mask[Idx] == UndefMaskElem) + Mask[Idx] = Idx; + else + Mask[Idx] = (Res.second ? Idx : Mask[Idx]) + VF; + } + auto *V = ValueSelect::get<T *>(Base); + (void)V; + assert((!V || GetVF(V) == Mask.size()) && + "Expected base vector of VF number of elements."); + Prev = Action(Mask, {nullptr, Res.first}); + } else if (ShuffleMask.size() == 1) { + // Base is undef and only 1 vector is shuffled - perform the action only for + // single vector, if the mask is not the identity mask. + std::pair<T *, bool> Res = ResizeAction(ShuffleMask.begin()->first, Mask); + if (Res.second) + // Identity mask is found. + Prev = Res.first; + else + Prev = Action(Mask, {ShuffleMask.begin()->first}); + } else { + // Base is undef and at least 2 input vectors shuffled - perform 2 vectors + // shuffles step by step, combining shuffle between the steps. + unsigned Vec1VF = GetVF(ShuffleMask.begin()->first); + unsigned Vec2VF = GetVF(VMIt->first); + if (Vec1VF == Vec2VF) { + // No need to resize the input vectors since they are of the same size, we + // can shuffle them directly. + ArrayRef<int> SecMask = VMIt->second; + for (unsigned I = 0, VF = Mask.size(); I < VF; ++I) { + if (SecMask[I] != UndefMaskElem) { + assert(Mask[I] == UndefMaskElem && "Multiple uses of scalars."); + Mask[I] = SecMask[I] + Vec1VF; + } + } + Prev = Action(Mask, {ShuffleMask.begin()->first, VMIt->first}); + } else { + // Vectors of different sizes - resize and reshuffle. + std::pair<T *, bool> Res1 = + ResizeAction(ShuffleMask.begin()->first, Mask); + std::pair<T *, bool> Res2 = ResizeAction(VMIt->first, VMIt->second); + ArrayRef<int> SecMask = VMIt->second; + for (unsigned I = 0, VF = Mask.size(); I < VF; ++I) { + if (Mask[I] != UndefMaskElem) { + assert(SecMask[I] == UndefMaskElem && "Multiple uses of scalars."); + if (Res1.second) + Mask[I] = I; + } else if (SecMask[I] != UndefMaskElem) { + assert(Mask[I] == UndefMaskElem && "Multiple uses of scalars."); + Mask[I] = (Res2.second ? I : SecMask[I]) + VF; + } + } + Prev = Action(Mask, {Res1.first, Res2.first}); + } + VMIt = std::next(VMIt); + } + // Perform requested actions for the remaining masks/vectors. + for (auto E = ShuffleMask.end(); VMIt != E; ++VMIt) { + // Shuffle other input vectors, if any. + std::pair<T *, bool> Res = ResizeAction(VMIt->first, VMIt->second); + ArrayRef<int> SecMask = VMIt->second; + for (unsigned I = 0, VF = Mask.size(); I < VF; ++I) { + if (SecMask[I] != UndefMaskElem) { + assert((Mask[I] == UndefMaskElem || IsBaseNotUndef) && + "Multiple uses of scalars."); + Mask[I] = (Res.second ? 
I : SecMask[I]) + VF; + } else if (Mask[I] != UndefMaskElem) { + Mask[I] = I; + } + } + Prev = Action(Mask, {Prev, Res.first}); + } + return Prev; +} + InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { InstructionCost Cost = 0; LLVM_DEBUG(dbgs() << "SLP: Calculating cost for tree of size " @@ -5635,7 +7021,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { unsigned BundleWidth = VectorizableTree[0]->Scalars.size(); for (unsigned I = 0, E = VectorizableTree.size(); I < E; ++I) { - TreeEntry &TE = *VectorizableTree[I].get(); + TreeEntry &TE = *VectorizableTree[I]; InstructionCost C = getEntryCost(&TE, VectorizedVals); Cost += C; @@ -5647,9 +7033,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { SmallPtrSet<Value *, 16> ExtractCostCalculated; InstructionCost ExtractCost = 0; - SmallVector<unsigned> VF; - SmallVector<SmallVector<int>> ShuffleMask; - SmallVector<Value *> FirstUsers; + SmallVector<MapVector<const TreeEntry *, SmallVector<int>>> ShuffleMasks; + SmallVector<std::pair<Value *, const TreeEntry *>> FirstUsers; SmallVector<APInt> DemandedElts; for (ExternalUser &EU : ExternalUses) { // We only add extract cost once for the same scalar. @@ -5678,37 +7063,55 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { if (auto *FTy = dyn_cast<FixedVectorType>(VU->getType())) { Optional<unsigned> InsertIdx = getInsertIndex(VU); if (InsertIdx) { - auto *It = find_if(FirstUsers, [VU](Value *V) { - return areTwoInsertFromSameBuildVector(VU, - cast<InsertElementInst>(V)); - }); + const TreeEntry *ScalarTE = getTreeEntry(EU.Scalar); + auto *It = + find_if(FirstUsers, + [VU](const std::pair<Value *, const TreeEntry *> &Pair) { + return areTwoInsertFromSameBuildVector( + VU, cast<InsertElementInst>(Pair.first)); + }); int VecId = -1; if (It == FirstUsers.end()) { - VF.push_back(FTy->getNumElements()); - ShuffleMask.emplace_back(VF.back(), UndefMaskElem); + (void)ShuffleMasks.emplace_back(); + SmallVectorImpl<int> &Mask = ShuffleMasks.back()[ScalarTE]; + if (Mask.empty()) + Mask.assign(FTy->getNumElements(), UndefMaskElem); // Find the insertvector, vectorized in tree, if any. Value *Base = VU; - while (isa<InsertElementInst>(Base)) { + while (auto *IEBase = dyn_cast<InsertElementInst>(Base)) { + if (IEBase != EU.User && + (!IEBase->hasOneUse() || + getInsertIndex(IEBase).value_or(*InsertIdx) == *InsertIdx)) + break; // Build the mask for the vectorized insertelement instructions. 
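performExtractsShuffleAction above folds a list of per-entry masks into a chain of one- and two-source shuffles. Its recurring primitive is merging two disjoint masks over vectors of the same width into a single two-source mask, with the second vector's lanes biased by the first vector's width. A self-contained sketch of that merge step, using plain std::vector<int> and -1 as the undef sentinel:

#include <cassert>
#include <vector>

// Merge two disjoint shuffle masks into one two-source mask: lanes taken
// from the second vector are shifted by the first vector's width VF1, the
// same convention the SLP cost and codegen paths use when chaining shuffles.
static std::vector<int> mergeTwoSourceMask(const std::vector<int> &Mask1,
                                           const std::vector<int> &Mask2,
                                           int VF1) {
  assert(Mask1.size() == Mask2.size() && "masks must cover the same lanes");
  std::vector<int> Combined(Mask1);
  for (size_t I = 0; I < Mask2.size(); ++I) {
    if (Mask2[I] == -1)
      continue;
    assert(Combined[I] == -1 && "each lane must come from exactly one source");
    Combined[I] = Mask2[I] + VF1;
  }
  return Combined;
}

For instance, mergeTwoSourceMask({0, -1, 2, -1}, {-1, 1, -1, 3}, 4) yields {0, 5, 2, 7}, i.e. a single shuffle reading lanes 0 and 2 from the first 4-wide vector and lanes 1 and 3 from the second.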
- if (const TreeEntry *E = getTreeEntry(Base)) { - VU = cast<InsertElementInst>(Base); + if (const TreeEntry *E = getTreeEntry(IEBase)) { + VU = IEBase; do { - int Idx = E->findLaneForValue(Base); - ShuffleMask.back()[Idx] = Idx; - Base = cast<InsertElementInst>(Base)->getOperand(0); + IEBase = cast<InsertElementInst>(Base); + int Idx = *getInsertIndex(IEBase); + assert(Mask[Idx] == UndefMaskElem && + "InsertElementInstruction used already."); + Mask[Idx] = Idx; + Base = IEBase->getOperand(0); } while (E == getTreeEntry(Base)); break; } Base = cast<InsertElementInst>(Base)->getOperand(0); } - FirstUsers.push_back(VU); - DemandedElts.push_back(APInt::getZero(VF.back())); + FirstUsers.emplace_back(VU, ScalarTE); + DemandedElts.push_back(APInt::getZero(FTy->getNumElements())); VecId = FirstUsers.size() - 1; } else { + if (isFirstInsertElement(VU, cast<InsertElementInst>(It->first))) + It->first = VU; VecId = std::distance(FirstUsers.begin(), It); } - ShuffleMask[VecId][*InsertIdx] = EU.Lane; - DemandedElts[VecId].setBit(*InsertIdx); + int InIdx = *InsertIdx; + SmallVectorImpl<int> &Mask = ShuffleMasks[VecId][ScalarTE]; + if (Mask.empty()) + Mask.assign(FTy->getNumElements(), UndefMaskElem); + Mask[InIdx] = EU.Lane; + DemandedElts[VecId].setBit(InIdx); continue; } } @@ -5734,86 +7137,75 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { InstructionCost SpillCost = getSpillCost(); Cost += SpillCost + ExtractCost; - if (FirstUsers.size() == 1) { - int Limit = ShuffleMask.front().size() * 2; - if (all_of(ShuffleMask.front(), [Limit](int Idx) { return Idx < Limit; }) && - !ShuffleVectorInst::isIdentityMask(ShuffleMask.front())) { - InstructionCost C = TTI->getShuffleCost( + auto &&ResizeToVF = [this, &Cost](const TreeEntry *TE, ArrayRef<int> Mask) { + InstructionCost C = 0; + unsigned VF = Mask.size(); + unsigned VecVF = TE->getVectorFactor(); + if (VF != VecVF && + (any_of(Mask, [VF](int Idx) { return Idx >= static_cast<int>(VF); }) || + (all_of(Mask, + [VF](int Idx) { return Idx < 2 * static_cast<int>(VF); }) && + !ShuffleVectorInst::isIdentityMask(Mask)))) { + SmallVector<int> OrigMask(VecVF, UndefMaskElem); + std::copy(Mask.begin(), std::next(Mask.begin(), std::min(VF, VecVF)), + OrigMask.begin()); + C = TTI->getShuffleCost( TTI::SK_PermuteSingleSrc, - cast<FixedVectorType>(FirstUsers.front()->getType()), - ShuffleMask.front()); - LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C - << " for final shuffle of insertelement external users " - << *VectorizableTree.front()->Scalars.front() << ".\n" - << "SLP: Current total cost = " << Cost << "\n"); + FixedVectorType::get(TE->getMainOp()->getType(), VecVF), OrigMask); + LLVM_DEBUG( + dbgs() << "SLP: Adding cost " << C + << " for final shuffle of insertelement external users.\n"; + TE->dump(); dbgs() << "SLP: Current total cost = " << Cost << "\n"); Cost += C; + return std::make_pair(TE, true); } + return std::make_pair(TE, false); + }; + // Calculate the cost of the reshuffled vectors, if any. 
+ for (int I = 0, E = FirstUsers.size(); I < E; ++I) { + Value *Base = cast<Instruction>(FirstUsers[I].first)->getOperand(0); + unsigned VF = ShuffleMasks[I].begin()->second.size(); + auto *FTy = FixedVectorType::get( + cast<VectorType>(FirstUsers[I].first->getType())->getElementType(), VF); + auto Vector = ShuffleMasks[I].takeVector(); + auto &&EstimateShufflesCost = [this, FTy, + &Cost](ArrayRef<int> Mask, + ArrayRef<const TreeEntry *> TEs) { + assert((TEs.size() == 1 || TEs.size() == 2) && + "Expected exactly 1 or 2 tree entries."); + if (TEs.size() == 1) { + int Limit = 2 * Mask.size(); + if (!all_of(Mask, [Limit](int Idx) { return Idx < Limit; }) || + !ShuffleVectorInst::isIdentityMask(Mask)) { + InstructionCost C = + TTI->getShuffleCost(TTI::SK_PermuteSingleSrc, FTy, Mask); + LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C + << " for final shuffle of insertelement " + "external users.\n"; + TEs.front()->dump(); + dbgs() << "SLP: Current total cost = " << Cost << "\n"); + Cost += C; + } + } else { + InstructionCost C = + TTI->getShuffleCost(TTI::SK_PermuteTwoSrc, FTy, Mask); + LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C + << " for final shuffle of vector node and external " + "insertelement users.\n"; + if (TEs.front()) { TEs.front()->dump(); } TEs.back()->dump(); + dbgs() << "SLP: Current total cost = " << Cost << "\n"); + Cost += C; + } + return TEs.back(); + }; + (void)performExtractsShuffleAction<const TreeEntry>( + makeMutableArrayRef(Vector.data(), Vector.size()), Base, + [](const TreeEntry *E) { return E->getVectorFactor(); }, ResizeToVF, + EstimateShufflesCost); InstructionCost InsertCost = TTI->getScalarizationOverhead( - cast<FixedVectorType>(FirstUsers.front()->getType()), - DemandedElts.front(), /*Insert*/ true, /*Extract*/ false); - LLVM_DEBUG(dbgs() << "SLP: subtracting the cost " << InsertCost - << " for insertelements gather.\n" - << "SLP: Current total cost = " << Cost << "\n"); - Cost -= InsertCost; - } else if (FirstUsers.size() >= 2) { - unsigned MaxVF = *std::max_element(VF.begin(), VF.end()); - // Combined masks of the first 2 vectors. - SmallVector<int> CombinedMask(MaxVF, UndefMaskElem); - copy(ShuffleMask.front(), CombinedMask.begin()); - APInt CombinedDemandedElts = DemandedElts.front().zextOrSelf(MaxVF); - auto *VecTy = FixedVectorType::get( - cast<VectorType>(FirstUsers.front()->getType())->getElementType(), - MaxVF); - for (int I = 0, E = ShuffleMask[1].size(); I < E; ++I) { - if (ShuffleMask[1][I] != UndefMaskElem) { - CombinedMask[I] = ShuffleMask[1][I] + MaxVF; - CombinedDemandedElts.setBit(I); - } - } - InstructionCost C = - TTI->getShuffleCost(TTI::SK_PermuteTwoSrc, VecTy, CombinedMask); - LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C - << " for final shuffle of vector node and external " - "insertelement users " - << *VectorizableTree.front()->Scalars.front() << ".\n" - << "SLP: Current total cost = " << Cost << "\n"); - Cost += C; - InstructionCost InsertCost = TTI->getScalarizationOverhead( - VecTy, CombinedDemandedElts, /*Insert*/ true, /*Extract*/ false); - LLVM_DEBUG(dbgs() << "SLP: subtracting the cost " << InsertCost - << " for insertelements gather.\n" - << "SLP: Current total cost = " << Cost << "\n"); + cast<FixedVectorType>(FirstUsers[I].first->getType()), DemandedElts[I], + /*Insert*/ true, /*Extract*/ false); Cost -= InsertCost; - for (int I = 2, E = FirstUsers.size(); I < E; ++I) { - // Other elements - permutation of 2 vectors (the initial one and the - // next Ith incoming vector). 
- unsigned VF = ShuffleMask[I].size(); - for (unsigned Idx = 0; Idx < VF; ++Idx) { - int Mask = ShuffleMask[I][Idx]; - if (Mask != UndefMaskElem) - CombinedMask[Idx] = MaxVF + Mask; - else if (CombinedMask[Idx] != UndefMaskElem) - CombinedMask[Idx] = Idx; - } - for (unsigned Idx = VF; Idx < MaxVF; ++Idx) - if (CombinedMask[Idx] != UndefMaskElem) - CombinedMask[Idx] = Idx; - InstructionCost C = - TTI->getShuffleCost(TTI::SK_PermuteTwoSrc, VecTy, CombinedMask); - LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C - << " for final shuffle of vector node and external " - "insertelement users " - << *VectorizableTree.front()->Scalars.front() << ".\n" - << "SLP: Current total cost = " << Cost << "\n"); - Cost += C; - InstructionCost InsertCost = TTI->getScalarizationOverhead( - cast<FixedVectorType>(FirstUsers[I]->getType()), DemandedElts[I], - /*Insert*/ true, /*Extract*/ false); - LLVM_DEBUG(dbgs() << "SLP: subtracting the cost " << InsertCost - << " for insertelements gather.\n" - << "SLP: Current total cost = " << Cost << "\n"); - Cost -= InsertCost; - } } #ifndef NDEBUG @@ -5906,6 +7298,12 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, SmallVectorImpl<int> &Mask, } } + if (UsedTEs.empty()) { + assert(all_of(TE->Scalars, UndefValue::classof) && + "Expected vector of undefs only."); + return None; + } + unsigned VF = 0; if (UsedTEs.size() == 1) { // Try to find the perfect match in another gather node at first. @@ -5965,17 +7363,11 @@ BoUpSLP::isGatherShuffledEntry(const TreeEntry *TE, SmallVectorImpl<int> &Mask, return None; } -InstructionCost -BoUpSLP::getGatherCost(FixedVectorType *Ty, - const DenseSet<unsigned> &ShuffledIndices, - bool NeedToShuffle) const { - unsigned NumElts = Ty->getNumElements(); - APInt DemandedElts = APInt::getZero(NumElts); - for (unsigned I = 0; I < NumElts; ++I) - if (!ShuffledIndices.count(I)) - DemandedElts.setBit(I); +InstructionCost BoUpSLP::getGatherCost(FixedVectorType *Ty, + const APInt &ShuffledIndices, + bool NeedToShuffle) const { InstructionCost Cost = - TTI->getScalarizationOverhead(Ty, DemandedElts, /*Insert*/ true, + TTI->getScalarizationOverhead(Ty, ~ShuffledIndices, /*Insert*/ true, /*Extract*/ false); if (NeedToShuffle) Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, Ty); @@ -5992,19 +7384,19 @@ InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL) const { // Find the cost of inserting/extracting values from the vector. // Check if the same elements are inserted several times and count them as // shuffle candidates. - DenseSet<unsigned> ShuffledElements; + APInt ShuffledElements = APInt::getZero(VL.size()); DenseSet<Value *> UniqueElements; // Iterate in reverse order to consider insert elements with the high cost. for (unsigned I = VL.size(); I > 0; --I) { unsigned Idx = I - 1; // No need to shuffle duplicates for constants. if (isConstant(VL[Idx])) { - ShuffledElements.insert(Idx); + ShuffledElements.setBit(Idx); continue; } if (!UniqueElements.insert(VL[Idx]).second) { DuplicateNonConst = true; - ShuffledElements.insert(Idx); + ShuffledElements.setBit(Idx); } } return getGatherCost(VecTy, ShuffledElements, DuplicateNonConst); @@ -6029,14 +7421,83 @@ void BoUpSLP::reorderInputsAccordingToOpcode(ArrayRef<Value *> VL, void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) { // Get the basic block this bundle is in. All instructions in the bundle - // should be in this block. + // should be in this block (except for extractelement-like instructions with + // constant indeces). 
auto *Front = E->getMainOp(); auto *BB = Front->getParent(); assert(llvm::all_of(E->Scalars, [=](Value *V) -> bool { + if (E->getOpcode() == Instruction::GetElementPtr && + !isa<GetElementPtrInst>(V)) + return true; auto *I = cast<Instruction>(V); - return !E->isOpcodeOrAlt(I) || I->getParent() == BB; + return !E->isOpcodeOrAlt(I) || I->getParent() == BB || + isVectorLikeInstWithConstOps(I); })); + auto &&FindLastInst = [E, Front, this, &BB]() { + Instruction *LastInst = Front; + for (Value *V : E->Scalars) { + auto *I = dyn_cast<Instruction>(V); + if (!I) + continue; + if (LastInst->getParent() == I->getParent()) { + if (LastInst->comesBefore(I)) + LastInst = I; + continue; + } + assert(isVectorLikeInstWithConstOps(LastInst) && + isVectorLikeInstWithConstOps(I) && + "Expected vector-like insts only."); + if (!DT->isReachableFromEntry(LastInst->getParent())) { + LastInst = I; + continue; + } + if (!DT->isReachableFromEntry(I->getParent())) + continue; + auto *NodeA = DT->getNode(LastInst->getParent()); + auto *NodeB = DT->getNode(I->getParent()); + assert(NodeA && "Should only process reachable instructions"); + assert(NodeB && "Should only process reachable instructions"); + assert((NodeA == NodeB) == + (NodeA->getDFSNumIn() == NodeB->getDFSNumIn()) && + "Different nodes should have different DFS numbers"); + if (NodeA->getDFSNumIn() < NodeB->getDFSNumIn()) + LastInst = I; + } + BB = LastInst->getParent(); + return LastInst; + }; + + auto &&FindFirstInst = [E, Front]() { + Instruction *FirstInst = Front; + for (Value *V : E->Scalars) { + auto *I = dyn_cast<Instruction>(V); + if (!I) + continue; + if (I->comesBefore(FirstInst)) + FirstInst = I; + } + return FirstInst; + }; + + // Set the insert point to the beginning of the basic block if the entry + // should not be scheduled. + if (E->State != TreeEntry::NeedToGather && + doesNotNeedToSchedule(E->Scalars)) { + Instruction *InsertInst; + if (all_of(E->Scalars, isUsedOutsideBlock)) + InsertInst = FindLastInst(); + else + InsertInst = FindFirstInst(); + // If the instruction is PHI, set the insert point after all the PHIs. + if (isa<PHINode>(InsertInst)) + InsertInst = BB->getFirstNonPHI(); + BasicBlock::iterator InsertPt = InsertInst->getIterator(); + Builder.SetInsertPoint(BB, InsertPt); + Builder.SetCurrentDebugLocation(Front->getDebugLoc()); + return; + } + // The last instruction in the bundle in program order. Instruction *LastInst = nullptr; @@ -6045,8 +7506,10 @@ void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) { // VL.back() and iterate over schedule data until we reach the end of the // bundle. The end of the bundle is marked by null ScheduleData. if (BlocksSchedules.count(BB)) { - auto *Bundle = - BlocksSchedules[BB]->getScheduleData(E->isOneOf(E->Scalars.back())); + Value *V = E->isOneOf(E->Scalars.back()); + if (doesNotNeedToBeScheduled(V)) + V = *find_if_not(E->Scalars, doesNotNeedToBeScheduled); + auto *Bundle = BlocksSchedules[BB]->getScheduleData(V); if (Bundle && Bundle->isPartOfBundle()) for (; Bundle; Bundle = Bundle->NextInBundle) if (Bundle->OpValue == Bundle->Inst) @@ -6072,19 +7535,16 @@ void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) { // we both exit early from buildTree_rec and that the bundle be out-of-order // (causing us to iterate all the way to the end of the block). 
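setInsertPointAfterBundle now has to cope with bundles whose scalars are not all in one block, because vector-like instructions with constant operands need no scheduling. The FindLastInst helper above orders candidates within a block by comesBefore and across blocks by the dominator tree's DFS-in numbers, never letting an unreachable block displace a reachable one. A simplified standalone model of that ordering, with block identity reduced to its DFS-in number:

#include <vector>

// Minimal stand-in for what the ordering needs to know about an instruction:
// its block (identified by the dominator tree DFS-in number), its position
// inside that block, and whether the block is reachable from the entry.
struct InstPos {
  unsigned BlockDFSIn;
  unsigned PosInBlock;
  bool Reachable;
};

// Pick the instruction the insert point should follow: a later position wins
// inside the same block, a larger DFS-in number wins across blocks, and an
// unreachable candidate never replaces a reachable one.
static const InstPos *findLastInst(const std::vector<InstPos> &Bundle) {
  const InstPos *Last = &Bundle.front();
  for (const InstPos &I : Bundle) {
    if (!Last->Reachable) {
      Last = &I;
      continue;
    }
    if (!I.Reachable)
      continue;
    if (I.BlockDFSIn > Last->BlockDFSIn ||
        (I.BlockDFSIn == Last->BlockDFSIn && Last->PosInBlock < I.PosInBlock))
      Last = &I;
  }
  return Last;
}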
if (!LastInst) { - SmallPtrSet<Value *, 16> Bundle(E->Scalars.begin(), E->Scalars.end()); - for (auto &I : make_range(BasicBlock::iterator(Front), BB->end())) { - if (Bundle.erase(&I) && E->isOpcodeOrAlt(&I)) - LastInst = &I; - if (Bundle.empty()) - break; - } + LastInst = FindLastInst(); + // If the instruction is PHI, set the insert point after all the PHIs. + if (isa<PHINode>(LastInst)) + LastInst = BB->getFirstNonPHI()->getPrevNode(); } assert(LastInst && "Failed to find last instruction in bundle"); // Set the insertion point after the last instruction in the bundle. Set the // debug location to Front. - Builder.SetInsertPoint(BB, ++LastInst->getIterator()); + Builder.SetInsertPoint(BB, std::next(LastInst->getIterator())); Builder.SetCurrentDebugLocation(Front->getDebugLoc()); } @@ -6214,8 +7674,15 @@ public: } // namespace Value *BoUpSLP::vectorizeTree(ArrayRef<Value *> VL) { - unsigned VF = VL.size(); + const unsigned VF = VL.size(); InstructionsState S = getSameOpcode(VL); + // Special processing for GEPs bundle, which may include non-gep values. + if (!S.getOpcode() && VL.front()->getType()->isPointerTy()) { + const auto *It = + find_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }); + if (It != VL.end()) + S = getSameOpcode(*It); + } if (S.getOpcode()) { if (TreeEntry *E = getTreeEntry(S.OpValue)) if (E->isSame(VL)) { @@ -6270,7 +7737,18 @@ Value *BoUpSLP::vectorizeTree(ArrayRef<Value *> VL) { } } - // Check that every instruction appears once in this bundle. + // Can't vectorize this, so simply build a new vector with each lane + // corresponding to the requested value. + return createBuildVector(VL); +} +Value *BoUpSLP::createBuildVector(ArrayRef<Value *> VL) { + assert(any_of(VectorizableTree, + [VL](const std::unique_ptr<TreeEntry> &TE) { + return TE->State == TreeEntry::NeedToGather && TE->isSame(VL); + }) && + "Non-matching gather node."); + unsigned VF = VL.size(); + // Exploit possible reuse of values across lanes. 
SmallVector<int> ReuseShuffleIndicies; SmallVector<Value *> UniqueValues; if (VL.size() > 2) { @@ -6303,6 +7781,10 @@ Value *BoUpSLP::vectorizeTree(ArrayRef<Value *> VL) { ReuseShuffleIndicies.append(VF - ReuseShuffleIndicies.size(), UndefMaskElem); } else if (UniqueValues.size() >= VF - 1 || UniqueValues.size() <= 1) { + if (UniqueValues.empty()) { + assert(all_of(VL, UndefValue::classof) && "Expected list of undefs."); + NumValues = VF; + } ReuseShuffleIndicies.clear(); UniqueValues.clear(); UniqueValues.append(VL.begin(), std::next(VL.begin(), NumValues)); @@ -6342,7 +7824,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { SmallVector<const TreeEntry *> Entries; Optional<TargetTransformInfo::ShuffleKind> Shuffle = isGatherShuffledEntry(E, Mask, Entries); - if (Shuffle.hasValue()) { + if (Shuffle) { assert((Entries.size() == 1 || Entries.size() == 2) && "Expected shuffle of 1 or 2 entries."); Vec = Builder.CreateShuffleVector(Entries.front()->VectorizedValue, @@ -6376,14 +7858,20 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { auto *VecTy = FixedVectorType::get(ScalarTy, E->Scalars.size()); switch (ShuffleOrOp) { case Instruction::PHI: { - assert( - (E->ReorderIndices.empty() || E != VectorizableTree.front().get()) && - "PHI reordering is free."); + assert((E->ReorderIndices.empty() || + E != VectorizableTree.front().get() || + !E->UserTreeIndices.empty()) && + "PHI reordering is free."); auto *PH = cast<PHINode>(VL0); Builder.SetInsertPoint(PH->getParent()->getFirstNonPHI()); Builder.SetCurrentDebugLocation(PH->getDebugLoc()); PHINode *NewPhi = Builder.CreatePHI(VecTy, PH->getNumIncomingValues()); Value *V = NewPhi; + + // Adjust insertion point once all PHI's have been generated. + Builder.SetInsertPoint(&*PH->getParent()->getFirstInsertionPt()); + Builder.SetCurrentDebugLocation(PH->getDebugLoc()); + ShuffleBuilder.addInversedMask(E->ReorderIndices); ShuffleBuilder.addMask(E->ReuseShuffleIndices); V = ShuffleBuilder.finalize(V); @@ -6449,7 +7937,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { cast<FixedVectorType>(FirstInsert->getType())->getNumElements(); const unsigned NumScalars = E->Scalars.size(); - unsigned Offset = *getInsertIndex(VL0, 0); + unsigned Offset = *getInsertIndex(VL0); assert(Offset < NumElts && "Failed to find vector index offset"); // Create shuffle to resize vector @@ -6656,19 +8144,18 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { unsigned AS = LI->getPointerAddressSpace(); Value *PO = LI->getPointerOperand(); if (E->State == TreeEntry::Vectorize) { - Value *VecPtr = Builder.CreateBitCast(PO, VecTy->getPointerTo(AS)); + NewLI = Builder.CreateAlignedLoad(VecTy, VecPtr, LI->getAlign()); // The pointer operand uses an in-tree scalar so we add the new BitCast - // to ExternalUses list to make sure that an extract will be generated - // in the future. + // or LoadInst to ExternalUses list to make sure that an extract will + // be generated in the future. if (TreeEntry *Entry = getTreeEntry(PO)) { // Find which lane we need to extract. unsigned FoundLane = Entry->findLaneForValue(PO); - ExternalUses.emplace_back(PO, cast<User>(VecPtr), FoundLane); + ExternalUses.emplace_back( + PO, PO != VecPtr ? 
cast<User>(VecPtr) : NewLI, FoundLane); } - - NewLI = Builder.CreateAlignedLoad(VecTy, VecPtr, LI->getAlign()); } else { assert(E->State == TreeEntry::ScatterVectorize && "Unhandled state"); Value *VecPtr = vectorizeTree(E->getOperand(0)); @@ -6676,7 +8163,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Align CommonAlignment = LI->getAlign(); for (Value *V : E->Scalars) CommonAlignment = - commonAlignment(CommonAlignment, cast<LoadInst>(V)->getAlign()); + std::min(CommonAlignment, cast<LoadInst>(V)->getAlign()); NewLI = Builder.CreateMaskedGather(VecTy, VecPtr, CommonAlignment); } Value *V = propagateMetadata(NewLI, E->Scalars); @@ -6701,17 +8188,18 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *ScalarPtr = SI->getPointerOperand(); Value *VecPtr = Builder.CreateBitCast( ScalarPtr, VecValue->getType()->getPointerTo(AS)); - StoreInst *ST = Builder.CreateAlignedStore(VecValue, VecPtr, - SI->getAlign()); + StoreInst *ST = + Builder.CreateAlignedStore(VecValue, VecPtr, SI->getAlign()); - // The pointer operand uses an in-tree scalar, so add the new BitCast to - // ExternalUses to make sure that an extract will be generated in the - // future. + // The pointer operand uses an in-tree scalar, so add the new BitCast or + // StoreInst to ExternalUses to make sure that an extract will be + // generated in the future. if (TreeEntry *Entry = getTreeEntry(ScalarPtr)) { // Find which lane we need to extract. unsigned FoundLane = Entry->findLaneForValue(ScalarPtr); - ExternalUses.push_back( - ExternalUser(ScalarPtr, cast<User>(VecPtr), FoundLane)); + ExternalUses.push_back(ExternalUser( + ScalarPtr, ScalarPtr != VecPtr ? cast<User>(VecPtr) : ST, + FoundLane)); } Value *V = propagateMetadata(ST, E->Scalars); @@ -6733,8 +8221,14 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } Value *V = Builder.CreateGEP(GEP0->getSourceElementType(), Op0, OpVecs); - if (Instruction *I = dyn_cast<Instruction>(V)) - V = propagateMetadata(I, E->Scalars); + if (Instruction *I = dyn_cast<GetElementPtrInst>(V)) { + SmallVector<Value *> GEPs; + for (Value *V : E->Scalars) { + if (isa<GetElementPtrInst>(V)) + GEPs.push_back(V); + } + V = propagateMetadata(I, GEPs); + } ShuffleBuilder.addInversedMask(E->ReorderIndices); ShuffleBuilder.addMask(E->ReuseShuffleIndices); @@ -6767,11 +8261,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { ValueList OpVL; // Some intrinsics have scalar arguments. This argument should not be // vectorized. 
- if (UseIntrinsic && hasVectorInstrinsicScalarOpd(IID, j)) { + if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(IID, j)) { CallInst *CEI = cast<CallInst>(VL0); ScalarArg = CEI->getArgOperand(j); OpVecs.push_back(CEI->getArgOperand(j)); - if (hasVectorInstrinsicOverloadedScalarOpd(IID, j)) + if (isVectorIntrinsicWithOverloadTypeAtArg(IID, j)) TysForDecl.push_back(ScalarArg->getType()); continue; } @@ -6779,6 +8273,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *OpVec = vectorizeTree(E->getOperand(j)); LLVM_DEBUG(dbgs() << "SLP: OpVec[" << j << "]: " << *OpVec << "\n"); OpVecs.push_back(OpVec); + if (isVectorIntrinsicWithOverloadTypeAtArg(IID, j)) + TysForDecl.push_back(OpVec->getType()); } Function *CF; @@ -6822,11 +8318,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { ((Instruction::isBinaryOp(E->getOpcode()) && Instruction::isBinaryOp(E->getAltOpcode())) || (Instruction::isCast(E->getOpcode()) && - Instruction::isCast(E->getAltOpcode()))) && + Instruction::isCast(E->getAltOpcode())) || + (isa<CmpInst>(VL0) && isa<CmpInst>(E->getAltOp()))) && "Invalid Shuffle Vector Operand"); Value *LHS = nullptr, *RHS = nullptr; - if (Instruction::isBinaryOp(E->getOpcode())) { + if (Instruction::isBinaryOp(E->getOpcode()) || isa<CmpInst>(VL0)) { setInsertPointAfterBundle(E); LHS = vectorizeTree(E->getOperand(0)); RHS = vectorizeTree(E->getOperand(1)); @@ -6846,6 +8343,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS, RHS); V1 = Builder.CreateBinOp( static_cast<Instruction::BinaryOps>(E->getAltOpcode()), LHS, RHS); + } else if (auto *CI0 = dyn_cast<CmpInst>(VL0)) { + V0 = Builder.CreateCmp(CI0->getPredicate(), LHS, RHS); + auto *AltCI = cast<CmpInst>(E->getAltOp()); + CmpInst::Predicate AltPred = AltCI->getPredicate(); + V1 = Builder.CreateCmp(AltPred, LHS, RHS); } else { V0 = Builder.CreateCast( static_cast<Instruction::CastOps>(E->getOpcode()), LHS, VecTy); @@ -6866,11 +8368,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { // each vector operation. ValueList OpScalars, AltScalars; SmallVector<int> Mask; - buildSuffleEntryMask( + buildShuffleEntryMask( E->Scalars, E->ReorderIndices, E->ReuseShuffleIndices, [E](Instruction *I) { assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode"); - return I->getOpcode() == E->getAltOpcode(); + return isAlternateInstruction(I, E->getMainOp(), E->getAltOp()); }, Mask, &OpScalars, &AltScalars); @@ -6901,6 +8403,17 @@ Value *BoUpSLP::vectorizeTree() { return vectorizeTree(ExternallyUsedValues); } +namespace { +/// Data type for handling buildvector sequences with the reused scalars from +/// other tree entries. +struct ShuffledInsertData { + /// List of insertelements to be replaced by shuffles. + SmallVector<InsertElementInst *> InsertElements; + /// The parent vectors and shuffle mask for the given list of inserts. + MapVector<Value *, SmallVector<int>> ValueMasks; +}; +} // namespace + Value * BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { // All blocks must be scheduled before any instructions are inserted. @@ -6934,6 +8447,9 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { LLVM_DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() << " values .\n"); + SmallVector<ShuffledInsertData> ShuffledInserts; + // Maps vector instruction to original insertelement instruction + DenseMap<Value *, InsertElementInst *> VectorToInsertElement; // Extract all of the elements with the external uses. 
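The new ShuffledInsertData record lets the extract-rewriting phase below turn a chain of insertelements fed by vectorized scalars into one shuffle per source vector: every external insert use stores, for its destination buildvector, a per-source mask that maps insert position to the lane of the vectorized value. A standalone sketch of that bookkeeping, with an opaque integer standing in for the source vector and -1 as undef:

#include <map>
#include <vector>

// One record per buildvector being rebuilt: for every source vector feeding
// it, a mask whose entry at an insert position names the lane to take from
// that source (-1 means the position is untouched by this source).
struct ShuffledInsertSketch {
  unsigned NumElements;                       // width of the buildvector
  std::map<int, std::vector<int>> ValueMasks; // source-vector id -> mask
};

// Record that lane Lane of source vector VecId lands at position InsertIdx of
// the buildvector, growing the mask lazily the way the diff does.
static void noteInsert(ShuffledInsertSketch &Data, int VecId,
                       unsigned InsertIdx, int Lane) {
  std::vector<int> &Mask = Data.ValueMasks[VecId];
  if (Mask.empty())
    Mask.assign(Data.NumElements, -1);
  Mask[InsertIdx] = Lane;
}

Once every external insert use is recorded this way, the masks are handed to performExtractsShuffleAction, which emits the final shuffles and allows the original insertelement chain to be replaced and erased.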
for (const auto &ExternalUse : ExternalUses) { Value *Scalar = ExternalUse.Scalar; @@ -6947,6 +8463,10 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { assert(E && "Invalid scalar"); assert(E->State != TreeEntry::NeedToGather && "Extracting from a gather list"); + // Non-instruction pointers are not deleted, just skip them. + if (E->getOpcode() == Instruction::GetElementPtr && + !isa<GetElementPtrInst>(Scalar)) + continue; Value *Vec = E->VectorizedValue; assert(Vec && "Can't find vectorizable value"); @@ -6973,6 +8493,8 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { assert(isa<FixedVectorType>(Scalar->getType()) && isa<InsertElementInst>(Scalar) && "In-tree scalar of vector type is not insertelement?"); + auto *IE = cast<InsertElementInst>(Scalar); + VectorToInsertElement.try_emplace(Vec, IE); return Vec; }; // If User == nullptr, the Scalar is used as extra arg. Generate @@ -7001,6 +8523,69 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { continue; } + if (auto *VU = dyn_cast<InsertElementInst>(User)) { + // Skip if the scalar is another vector op or Vec is not an instruction. + if (!Scalar->getType()->isVectorTy() && isa<Instruction>(Vec)) { + if (auto *FTy = dyn_cast<FixedVectorType>(User->getType())) { + Optional<unsigned> InsertIdx = getInsertIndex(VU); + if (InsertIdx) { + // Need to use original vector, if the root is truncated. + if (MinBWs.count(Scalar) && + VectorizableTree[0]->VectorizedValue == Vec) + Vec = VectorRoot; + auto *It = + find_if(ShuffledInserts, [VU](const ShuffledInsertData &Data) { + // Checks if 2 insertelements are from the same buildvector. + InsertElementInst *VecInsert = Data.InsertElements.front(); + return areTwoInsertFromSameBuildVector(VU, VecInsert); + }); + unsigned Idx = *InsertIdx; + if (It == ShuffledInserts.end()) { + (void)ShuffledInserts.emplace_back(); + It = std::next(ShuffledInserts.begin(), + ShuffledInserts.size() - 1); + SmallVectorImpl<int> &Mask = It->ValueMasks[Vec]; + if (Mask.empty()) + Mask.assign(FTy->getNumElements(), UndefMaskElem); + // Find the insertvector, vectorized in tree, if any. + Value *Base = VU; + while (auto *IEBase = dyn_cast<InsertElementInst>(Base)) { + if (IEBase != User && + (!IEBase->hasOneUse() || + getInsertIndex(IEBase).value_or(Idx) == Idx)) + break; + // Build the mask for the vectorized insertelement instructions. + if (const TreeEntry *E = getTreeEntry(IEBase)) { + do { + IEBase = cast<InsertElementInst>(Base); + int IEIdx = *getInsertIndex(IEBase); + assert(Mask[Idx] == UndefMaskElem && + "InsertElementInstruction used already."); + Mask[IEIdx] = IEIdx; + Base = IEBase->getOperand(0); + } while (E == getTreeEntry(Base)); + break; + } + Base = cast<InsertElementInst>(Base)->getOperand(0); + // After the vectorization the def-use chain has changed, need + // to look through original insertelement instructions, if they + // get replaced by vector instructions. + auto It = VectorToInsertElement.find(Base); + if (It != VectorToInsertElement.end()) + Base = It->second; + } + } + SmallVectorImpl<int> &Mask = It->ValueMasks[Vec]; + if (Mask.empty()) + Mask.assign(FTy->getNumElements(), UndefMaskElem); + Mask[Idx] = ExternalUse.Lane; + It->InsertElements.push_back(cast<InsertElementInst>(User)); + continue; + } + } + } + } + // Generate extracts for out-of-tree users. // Find the insertion point for the extractelement lane. 
if (auto *VecI = dyn_cast<Instruction>(Vec)) { @@ -7036,6 +8621,221 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { LLVM_DEBUG(dbgs() << "SLP: Replaced:" << *User << ".\n"); } + // Checks if the mask is an identity mask. + auto &&IsIdentityMask = [](ArrayRef<int> Mask, FixedVectorType *VecTy) { + int Limit = Mask.size(); + return VecTy->getNumElements() == Mask.size() && + all_of(Mask, [Limit](int Idx) { return Idx < Limit; }) && + ShuffleVectorInst::isIdentityMask(Mask); + }; + // Tries to combine 2 different masks into single one. + auto &&CombineMasks = [](SmallVectorImpl<int> &Mask, ArrayRef<int> ExtMask) { + SmallVector<int> NewMask(ExtMask.size(), UndefMaskElem); + for (int I = 0, Sz = ExtMask.size(); I < Sz; ++I) { + if (ExtMask[I] == UndefMaskElem) + continue; + NewMask[I] = Mask[ExtMask[I]]; + } + Mask.swap(NewMask); + }; + // Peek through shuffles, trying to simplify the final shuffle code. + auto &&PeekThroughShuffles = + [&IsIdentityMask, &CombineMasks](Value *&V, SmallVectorImpl<int> &Mask, + bool CheckForLengthChange = false) { + while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) { + // Exit if not a fixed vector type or changing size shuffle. + if (!isa<FixedVectorType>(SV->getType()) || + (CheckForLengthChange && SV->changesLength())) + break; + // Exit if the identity or broadcast mask is found. + if (IsIdentityMask(Mask, cast<FixedVectorType>(SV->getType())) || + SV->isZeroEltSplat()) + break; + bool IsOp1Undef = isUndefVector(SV->getOperand(0)); + bool IsOp2Undef = isUndefVector(SV->getOperand(1)); + if (!IsOp1Undef && !IsOp2Undef) + break; + SmallVector<int> ShuffleMask(SV->getShuffleMask().begin(), + SV->getShuffleMask().end()); + CombineMasks(ShuffleMask, Mask); + Mask.swap(ShuffleMask); + if (IsOp2Undef) + V = SV->getOperand(0); + else + V = SV->getOperand(1); + } + }; + // Smart shuffle instruction emission, walks through shuffles trees and + // tries to find the best matching vector for the actual shuffle + // instruction. + auto &&CreateShuffle = [this, &IsIdentityMask, &PeekThroughShuffles, + &CombineMasks](Value *V1, Value *V2, + ArrayRef<int> Mask) -> Value * { + assert(V1 && "Expected at least one vector value."); + if (V2 && !isUndefVector(V2)) { + // Peek through shuffles. + Value *Op1 = V1; + Value *Op2 = V2; + int VF = + cast<VectorType>(V1->getType())->getElementCount().getKnownMinValue(); + SmallVector<int> CombinedMask1(Mask.size(), UndefMaskElem); + SmallVector<int> CombinedMask2(Mask.size(), UndefMaskElem); + for (int I = 0, E = Mask.size(); I < E; ++I) { + if (Mask[I] < VF) + CombinedMask1[I] = Mask[I]; + else + CombinedMask2[I] = Mask[I] - VF; + } + Value *PrevOp1; + Value *PrevOp2; + do { + PrevOp1 = Op1; + PrevOp2 = Op2; + PeekThroughShuffles(Op1, CombinedMask1, /*CheckForLengthChange=*/true); + PeekThroughShuffles(Op2, CombinedMask2, /*CheckForLengthChange=*/true); + // Check if we have 2 resizing shuffles - need to peek through operands + // again. 
+ if (auto *SV1 = dyn_cast<ShuffleVectorInst>(Op1)) + if (auto *SV2 = dyn_cast<ShuffleVectorInst>(Op2)) + if (SV1->getOperand(0)->getType() == + SV2->getOperand(0)->getType() && + SV1->getOperand(0)->getType() != SV1->getType() && + isUndefVector(SV1->getOperand(1)) && + isUndefVector(SV2->getOperand(1))) { + Op1 = SV1->getOperand(0); + Op2 = SV2->getOperand(0); + SmallVector<int> ShuffleMask1(SV1->getShuffleMask().begin(), + SV1->getShuffleMask().end()); + CombineMasks(ShuffleMask1, CombinedMask1); + CombinedMask1.swap(ShuffleMask1); + SmallVector<int> ShuffleMask2(SV2->getShuffleMask().begin(), + SV2->getShuffleMask().end()); + CombineMasks(ShuffleMask2, CombinedMask2); + CombinedMask2.swap(ShuffleMask2); + } + } while (PrevOp1 != Op1 || PrevOp2 != Op2); + VF = cast<VectorType>(Op1->getType()) + ->getElementCount() + .getKnownMinValue(); + for (int I = 0, E = Mask.size(); I < E; ++I) { + if (CombinedMask2[I] != UndefMaskElem) { + assert(CombinedMask1[I] == UndefMaskElem && + "Expected undefined mask element"); + CombinedMask1[I] = CombinedMask2[I] + (Op1 == Op2 ? 0 : VF); + } + } + Value *Vec = Builder.CreateShuffleVector( + Op1, Op1 == Op2 ? PoisonValue::get(Op1->getType()) : Op2, + CombinedMask1); + if (auto *I = dyn_cast<Instruction>(Vec)) { + GatherShuffleSeq.insert(I); + CSEBlocks.insert(I->getParent()); + } + return Vec; + } + if (isa<PoisonValue>(V1)) + return PoisonValue::get(FixedVectorType::get( + cast<VectorType>(V1->getType())->getElementType(), Mask.size())); + Value *Op = V1; + SmallVector<int> CombinedMask(Mask.begin(), Mask.end()); + PeekThroughShuffles(Op, CombinedMask); + if (!isa<FixedVectorType>(Op->getType()) || + !IsIdentityMask(CombinedMask, cast<FixedVectorType>(Op->getType()))) { + Value *Vec = Builder.CreateShuffleVector(Op, CombinedMask); + if (auto *I = dyn_cast<Instruction>(Vec)) { + GatherShuffleSeq.insert(I); + CSEBlocks.insert(I->getParent()); + } + return Vec; + } + return Op; + }; + + auto &&ResizeToVF = [&CreateShuffle](Value *Vec, ArrayRef<int> Mask) { + unsigned VF = Mask.size(); + unsigned VecVF = cast<FixedVectorType>(Vec->getType())->getNumElements(); + if (VF != VecVF) { + if (any_of(Mask, [VF](int Idx) { return Idx >= static_cast<int>(VF); })) { + Vec = CreateShuffle(Vec, nullptr, Mask); + return std::make_pair(Vec, true); + } + SmallVector<int> ResizeMask(VF, UndefMaskElem); + for (unsigned I = 0; I < VF; ++I) { + if (Mask[I] != UndefMaskElem) + ResizeMask[Mask[I]] = Mask[I]; + } + Vec = CreateShuffle(Vec, nullptr, ResizeMask); + } + + return std::make_pair(Vec, false); + }; + // Perform shuffling of the vectorize tree entries for better handling of + // external extracts. + for (int I = 0, E = ShuffledInserts.size(); I < E; ++I) { + // Find the first and the last instruction in the list of insertelements. 
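CreateShuffle handles the two-operand case by splitting the combined mask at the element count of the first source: indices below VF address operand one, indices at or above VF address operand two shifted down by VF, and each half is then simplified independently before being recombined. A standalone sketch of just that split (illustrative names, -1 for undef):

#include <cstdio>
#include <utility>
#include <vector>

// A mask over the concatenation of two VF-wide vectors is split into one mask
// per operand, the form the code above works with while it peeks through each
// operand's shuffle chain independently.
static std::pair<std::vector<int>, std::vector<int>>
splitTwoSourceMask(const std::vector<int> &Mask, int VF) {
  std::vector<int> M1(Mask.size(), -1), M2(Mask.size(), -1);
  for (size_t I = 0; I < Mask.size(); ++I) {
    if (Mask[I] < 0)
      continue;
    if (Mask[I] < VF)
      M1[I] = Mask[I];      // element comes from the first operand
    else
      M2[I] = Mask[I] - VF; // element comes from the second operand
  }
  return {M1, M2};
}

int main() {
  auto Masks = splitTwoSourceMask({0, 5, 2, 7}, 4);
  for (int M : Masks.first)
    std::printf("%d ", M);  // prints: 0 -1 2 -1
  std::printf("| ");
  for (int M : Masks.second)
    std::printf("%d ", M);  // prints: -1 1 -1 3
  std::printf("\n");
}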
+ sort(ShuffledInserts[I].InsertElements, isFirstInsertElement); + InsertElementInst *FirstInsert = ShuffledInserts[I].InsertElements.front(); + InsertElementInst *LastInsert = ShuffledInserts[I].InsertElements.back(); + Builder.SetInsertPoint(LastInsert); + auto Vector = ShuffledInserts[I].ValueMasks.takeVector(); + Value *NewInst = performExtractsShuffleAction<Value>( + makeMutableArrayRef(Vector.data(), Vector.size()), + FirstInsert->getOperand(0), + [](Value *Vec) { + return cast<VectorType>(Vec->getType()) + ->getElementCount() + .getKnownMinValue(); + }, + ResizeToVF, + [FirstInsert, &CreateShuffle](ArrayRef<int> Mask, + ArrayRef<Value *> Vals) { + assert((Vals.size() == 1 || Vals.size() == 2) && + "Expected exactly 1 or 2 input values."); + if (Vals.size() == 1) { + // Do not create shuffle if the mask is a simple identity + // non-resizing mask. + if (Mask.size() != cast<FixedVectorType>(Vals.front()->getType()) + ->getNumElements() || + !ShuffleVectorInst::isIdentityMask(Mask)) + return CreateShuffle(Vals.front(), nullptr, Mask); + return Vals.front(); + } + return CreateShuffle(Vals.front() ? Vals.front() + : FirstInsert->getOperand(0), + Vals.back(), Mask); + }); + auto It = ShuffledInserts[I].InsertElements.rbegin(); + // Rebuild buildvector chain. + InsertElementInst *II = nullptr; + if (It != ShuffledInserts[I].InsertElements.rend()) + II = *It; + SmallVector<Instruction *> Inserts; + while (It != ShuffledInserts[I].InsertElements.rend()) { + assert(II && "Must be an insertelement instruction."); + if (*It == II) + ++It; + else + Inserts.push_back(cast<Instruction>(II)); + II = dyn_cast<InsertElementInst>(II->getOperand(0)); + } + for (Instruction *II : reverse(Inserts)) { + II->replaceUsesOfWith(II->getOperand(0), NewInst); + if (auto *NewI = dyn_cast<Instruction>(NewInst)) + if (II->getParent() == NewI->getParent() && II->comesBefore(NewI)) + II->moveAfter(NewI); + NewInst = II; + } + LastInsert->replaceAllUsesWith(NewInst); + for (InsertElementInst *IE : reverse(ShuffledInserts[I].InsertElements)) { + IE->replaceUsesOfWith(IE->getOperand(0), + PoisonValue::get(IE->getOperand(0)->getType())); + IE->replaceUsesOfWith(IE->getOperand(1), + PoisonValue::get(IE->getOperand(1)->getType())); + eraseInstruction(IE); + } + CSEBlocks.insert(LastInsert->getParent()); + } + // For each vectorized value: for (auto &TEPtr : VectorizableTree) { TreeEntry *Entry = TEPtr.get(); @@ -7050,6 +8850,9 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { for (int Lane = 0, LE = Entry->Scalars.size(); Lane != LE; ++Lane) { Value *Scalar = Entry->Scalars[Lane]; + if (Entry->getOpcode() == Instruction::GetElementPtr && + !isa<GetElementPtrInst>(Scalar)) + continue; #ifndef NDEBUG Type *Ty = Scalar->getType(); if (!Ty->isVoidTy()) { @@ -7057,7 +8860,8 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { LLVM_DEBUG(dbgs() << "SLP: \tvalidating user:" << *U << ".\n"); // It is legal to delete users in the ignorelist. 
- assert((getTreeEntry(U) || is_contained(UserIgnoreList, U) || + assert((getTreeEntry(U) || + (UserIgnoreList && UserIgnoreList->contains(U)) || (isa_and_nonnull<Instruction>(U) && isDeleted(cast<Instruction>(U)))) && "Deleting out-of-tree value"); @@ -7225,9 +9029,11 @@ void BoUpSLP::optimizeGatherSequence() { BoUpSLP::ScheduleData * BoUpSLP::BlockScheduling::buildBundle(ArrayRef<Value *> VL) { - ScheduleData *Bundle = nullptr; + ScheduleData *Bundle = nullptr; ScheduleData *PrevInBundle = nullptr; for (Value *V : VL) { + if (doesNotNeedToBeScheduled(V)) + continue; ScheduleData *BundleMember = getScheduleData(V); assert(BundleMember && "no ScheduleData for bundle member " @@ -7239,8 +9045,6 @@ BoUpSLP::BlockScheduling::buildBundle(ArrayRef<Value *> VL) { } else { Bundle = BundleMember; } - BundleMember->UnscheduledDepsInBundle = 0; - Bundle->UnscheduledDepsInBundle += BundleMember->UnscheduledDeps; // Group the instructions to a bundle. BundleMember->FirstInBundle = Bundle; @@ -7257,7 +9061,8 @@ BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, const InstructionsState &S) { // No need to schedule PHIs, insertelement, extractelement and extractvalue // instructions. - if (isa<PHINode>(S.OpValue) || isVectorLikeInstWithConstOps(S.OpValue)) + if (isa<PHINode>(S.OpValue) || isVectorLikeInstWithConstOps(S.OpValue) || + doesNotNeedToSchedule(VL)) return nullptr; // Initialize the instruction bundle. @@ -7276,16 +9081,17 @@ BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, doForAllOpcodes(I, [](ScheduleData *SD) { SD->clearDependencies(); }); ReSchedule = true; } - if (ReSchedule) { - resetSchedule(); - initialFillReadyList(ReadyInsts); - } if (Bundle) { LLVM_DEBUG(dbgs() << "SLP: try schedule bundle " << *Bundle << " in block " << BB->getName() << "\n"); calculateDependencies(Bundle, /*InsertInReadyList=*/true, SLP); } + if (ReSchedule) { + resetSchedule(); + initialFillReadyList(ReadyInsts); + } + // Now try to schedule the new bundle or (if no bundle) just calculate // dependencies. As soon as the bundle is "ready" it means that there are no // cyclic dependencies and we can schedule it. Note that's important that we @@ -7293,14 +9099,17 @@ BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, while (((!Bundle && ReSchedule) || (Bundle && !Bundle->isReady())) && !ReadyInsts.empty()) { ScheduleData *Picked = ReadyInsts.pop_back_val(); - if (Picked->isSchedulingEntity() && Picked->isReady()) - schedule(Picked, ReadyInsts); + assert(Picked->isSchedulingEntity() && Picked->isReady() && + "must be ready to schedule"); + schedule(Picked, ReadyInsts); } }; // Make sure that the scheduling region contains all // instructions of the bundle. for (Value *V : VL) { + if (doesNotNeedToBeScheduled(V)) + continue; if (!extendSchedulingRegion(V, S)) { // If the scheduling region got new instructions at the lower end (or it // is a new region for the first bundle). This makes it necessary to @@ -7315,9 +9124,16 @@ BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, bool ReSchedule = false; for (Value *V : VL) { + if (doesNotNeedToBeScheduled(V)) + continue; ScheduleData *BundleMember = getScheduleData(V); assert(BundleMember && "no ScheduleData for bundle member (maybe not in same basic block)"); + + // Make sure we don't leave the pieces of the bundle in the ready list when + // whole bundle might not be ready. 
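buildBundle and tryScheduleBundle above keep the usual ready-list discipline, now skipping values that do not need to be scheduled at all: a bundle is ready once none of its members has unscheduled dependencies, and scheduling it may in turn make its users ready. A minimal standalone model of that discipline (not the LLVM scheduler; Bundle, scheduleAll, and the dependency counts are made up):

#include <cstdio>
#include <deque>
#include <vector>

// Each bundle tracks how many of its dependencies are still unscheduled; it
// becomes ready when the counter reaches zero, and scheduling it decrements
// the counters of its users, possibly making them ready in turn.
struct Bundle {
  int UnscheduledDeps = 0;
  std::vector<int> Users; // indices of bundles that depend on this one
  bool Scheduled = false;
};

static void scheduleAll(std::vector<Bundle> &Bundles) {
  std::deque<int> Ready;
  for (int I = 0, E = Bundles.size(); I < E; ++I)
    if (Bundles[I].UnscheduledDeps == 0)
      Ready.push_back(I);
  while (!Ready.empty()) {
    int Picked = Ready.front();
    Ready.pop_front();
    Bundles[Picked].Scheduled = true;
    std::printf("scheduled bundle %d\n", Picked);
    for (int U : Bundles[Picked].Users)
      if (--Bundles[U].UnscheduledDeps == 0)
        Ready.push_back(U);
  }
}

int main() {
  // Bundle 2 depends on bundles 0 and 1.
  std::vector<Bundle> Bundles(3);
  Bundles[0].Users = {2};
  Bundles[1].Users = {2};
  Bundles[2].UnscheduledDeps = 2;
  scheduleAll(Bundles);
}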
+ ReadyInsts.remove(BundleMember); + if (!BundleMember->IsScheduled) continue; // A bundle member was scheduled as single instruction before and now @@ -7339,16 +9155,24 @@ BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, BoUpSLP *SLP, void BoUpSLP::BlockScheduling::cancelScheduling(ArrayRef<Value *> VL, Value *OpValue) { - if (isa<PHINode>(OpValue) || isVectorLikeInstWithConstOps(OpValue)) + if (isa<PHINode>(OpValue) || isVectorLikeInstWithConstOps(OpValue) || + doesNotNeedToSchedule(VL)) return; + if (doesNotNeedToBeScheduled(OpValue)) + OpValue = *find_if_not(VL, doesNotNeedToBeScheduled); ScheduleData *Bundle = getScheduleData(OpValue); LLVM_DEBUG(dbgs() << "SLP: cancel scheduling of " << *Bundle << "\n"); assert(!Bundle->IsScheduled && "Can't cancel bundle which is already scheduled"); - assert(Bundle->isSchedulingEntity() && Bundle->isPartOfBundle() && + assert(Bundle->isSchedulingEntity() && + (Bundle->isPartOfBundle() || needToScheduleSingleInstruction(VL)) && "tried to unbundle something which is not a bundle"); + // Remove the bundle from the ready list. + if (Bundle->isReady()) + ReadyInsts.remove(Bundle); + // Un-bundle: make single instructions out of the bundle. ScheduleData *BundleMember = Bundle; while (BundleMember) { @@ -7356,8 +9180,8 @@ void BoUpSLP::BlockScheduling::cancelScheduling(ArrayRef<Value *> VL, BundleMember->FirstInBundle = BundleMember; ScheduleData *Next = BundleMember->NextInBundle; BundleMember->NextInBundle = nullptr; - BundleMember->UnscheduledDepsInBundle = BundleMember->UnscheduledDeps; - if (BundleMember->UnscheduledDepsInBundle == 0) { + BundleMember->TE = nullptr; + if (BundleMember->unscheduledDepsInBundle() == 0) { ReadyInsts.insert(BundleMember); } BundleMember = Next; @@ -7380,9 +9204,10 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, Instruction *I = dyn_cast<Instruction>(V); assert(I && "bundle member must be an instruction"); assert(!isa<PHINode>(I) && !isVectorLikeInstWithConstOps(I) && + !doesNotNeedToBeScheduled(I) && "phi nodes/insertelements/extractelements/extractvalues don't need to " "be scheduled"); - auto &&CheckSheduleForI = [this, &S](Instruction *I) -> bool { + auto &&CheckScheduleForI = [this, &S](Instruction *I) -> bool { ScheduleData *ISD = getScheduleData(I); if (!ISD) return false; @@ -7394,7 +9219,7 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, ExtraScheduleDataMap[I][S.OpValue] = SD; return true; }; - if (CheckSheduleForI(I)) + if (CheckScheduleForI(I)) return true; if (!ScheduleStart) { // It's the first instruction in the new region. 
@@ -7402,7 +9227,7 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, ScheduleStart = I; ScheduleEnd = I->getNextNode(); if (isOneOf(S, I) != I) - CheckSheduleForI(I); + CheckScheduleForI(I); assert(ScheduleEnd && "tried to vectorize a terminator?"); LLVM_DEBUG(dbgs() << "SLP: initialize schedule region to " << *I << "\n"); return true; @@ -7430,7 +9255,7 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, initScheduleData(I, ScheduleStart, nullptr, FirstLoadStoreInRegion); ScheduleStart = I; if (isOneOf(S, I) != I) - CheckSheduleForI(I); + CheckScheduleForI(I); LLVM_DEBUG(dbgs() << "SLP: extend schedule region start to " << *I << "\n"); return true; @@ -7444,7 +9269,7 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, nullptr); ScheduleEnd = I->getNextNode(); if (isOneOf(S, I) != I) - CheckSheduleForI(I); + CheckScheduleForI(I); assert(ScheduleEnd && "tried to vectorize a terminator?"); LLVM_DEBUG(dbgs() << "SLP: extend schedule region end to " << *I << "\n"); return true; @@ -7456,7 +9281,10 @@ void BoUpSLP::BlockScheduling::initScheduleData(Instruction *FromI, ScheduleData *NextLoadStore) { ScheduleData *CurrentLoadStore = PrevLoadStore; for (Instruction *I = FromI; I != ToI; I = I->getNextNode()) { - ScheduleData *SD = ScheduleDataMap[I]; + // No need to allocate data for non-schedulable instructions. + if (doesNotNeedToBeScheduled(I)) + continue; + ScheduleData *SD = ScheduleDataMap.lookup(I); if (!SD) { SD = allocateScheduleDataChunks(); ScheduleDataMap[I] = SD; @@ -7479,6 +9307,10 @@ void BoUpSLP::BlockScheduling::initScheduleData(Instruction *FromI, } CurrentLoadStore = SD; } + + if (match(I, m_Intrinsic<Intrinsic::stacksave>()) || + match(I, m_Intrinsic<Intrinsic::stackrestore>())) + RegionHasStackSave = true; } if (NextLoadStore) { if (CurrentLoadStore) @@ -7511,8 +9343,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, // Handle def-use chain dependencies. if (BundleMember->OpValue != BundleMember->Inst) { - ScheduleData *UseSD = getScheduleData(BundleMember->Inst); - if (UseSD && isInSchedulingRegion(UseSD->FirstInBundle)) { + if (ScheduleData *UseSD = getScheduleData(BundleMember->Inst)) { BundleMember->Dependencies++; ScheduleData *DestBundle = UseSD->FirstInBundle; if (!DestBundle->IsScheduled) @@ -7522,10 +9353,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, } } else { for (User *U : BundleMember->Inst->users()) { - assert(isa<Instruction>(U) && - "user of instruction must be instruction"); - ScheduleData *UseSD = getScheduleData(U); - if (UseSD && isInSchedulingRegion(UseSD->FirstInBundle)) { + if (ScheduleData *UseSD = getScheduleData(cast<Instruction>(U))) { BundleMember->Dependencies++; ScheduleData *DestBundle = UseSD->FirstInBundle; if (!DestBundle->IsScheduled) @@ -7536,6 +9364,75 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, } } + auto makeControlDependent = [&](Instruction *I) { + auto *DepDest = getScheduleData(I); + assert(DepDest && "must be in schedule window"); + DepDest->ControlDependencies.push_back(BundleMember); + BundleMember->Dependencies++; + ScheduleData *DestBundle = DepDest->FirstInBundle; + if (!DestBundle->IsScheduled) + BundleMember->incrementUnscheduledDeps(1); + if (!DestBundle->hasValidDependencies()) + WorkList.push_back(DestBundle); + }; + + // Any instruction which isn't safe to speculate at the begining of the + // block is control dependend on any early exit or non-willreturn call + // which proceeds it. 
+ if (!isGuaranteedToTransferExecutionToSuccessor(BundleMember->Inst)) { + for (Instruction *I = BundleMember->Inst->getNextNode(); + I != ScheduleEnd; I = I->getNextNode()) { + if (isSafeToSpeculativelyExecute(I, &*BB->begin())) + continue; + + // Add the dependency + makeControlDependent(I); + + if (!isGuaranteedToTransferExecutionToSuccessor(I)) + // Everything past here must be control dependent on I. + break; + } + } + + if (RegionHasStackSave) { + // If we have an inalloc alloca instruction, it needs to be scheduled + // after any preceeding stacksave. We also need to prevent any alloca + // from reordering above a preceeding stackrestore. + if (match(BundleMember->Inst, m_Intrinsic<Intrinsic::stacksave>()) || + match(BundleMember->Inst, m_Intrinsic<Intrinsic::stackrestore>())) { + for (Instruction *I = BundleMember->Inst->getNextNode(); + I != ScheduleEnd; I = I->getNextNode()) { + if (match(I, m_Intrinsic<Intrinsic::stacksave>()) || + match(I, m_Intrinsic<Intrinsic::stackrestore>())) + // Any allocas past here must be control dependent on I, and I + // must be memory dependend on BundleMember->Inst. + break; + + if (!isa<AllocaInst>(I)) + continue; + + // Add the dependency + makeControlDependent(I); + } + } + + // In addition to the cases handle just above, we need to prevent + // allocas from moving below a stacksave. The stackrestore case + // is currently thought to be conservatism. + if (isa<AllocaInst>(BundleMember->Inst)) { + for (Instruction *I = BundleMember->Inst->getNextNode(); + I != ScheduleEnd; I = I->getNextNode()) { + if (!match(I, m_Intrinsic<Intrinsic::stacksave>()) && + !match(I, m_Intrinsic<Intrinsic::stackrestore>())) + continue; + + // Add the dependency + makeControlDependent(I); + break; + } + } + } + // Handle the memory dependencies (if any). ScheduleData *DepDest = BundleMember->NextLoadStore; if (!DepDest) @@ -7598,7 +9495,7 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, } } if (InsertInReadyList && SD->isReady()) { - ReadyInsts.push_back(SD); + ReadyInsts.insert(SD); LLVM_DEBUG(dbgs() << "SLP: gets ready on update: " << *SD->Inst << "\n"); } @@ -7625,11 +9522,18 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) { LLVM_DEBUG(dbgs() << "SLP: schedule block " << BS->BB->getName() << "\n"); + // A key point - if we got here, pre-scheduling was able to find a valid + // scheduling of the sub-graph of the scheduling window which consists + // of all vector bundles and their transitive users. As such, we do not + // need to reschedule anything *outside of* that subgraph. + BS->resetSchedule(); // For the real scheduling we use a more sophisticated ready-list: it is // sorted by the original instruction location. This lets the final schedule // be as close as possible to the original instruction order. + // WARNING: If changing this order causes a correctness issue, that means + // there is some missing dependence edge in the schedule data graph. struct ScheduleDataCompare { bool operator()(ScheduleData *SD1, ScheduleData *SD2) const { return SD2->SchedulingPriority < SD1->SchedulingPriority; @@ -7637,21 +9541,22 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) { }; std::set<ScheduleData *, ScheduleDataCompare> ReadyInsts; - // Ensure that all dependency data is updated and fill the ready-list with - // initial instructions. + // Ensure that all dependency data is updated (for nodes in the sub-graph) + // and fill the ready-list with initial instructions. 
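The calculateDependencies additions above introduce explicit control-dependence edges: once an instruction may fail to transfer execution to its successor (an early exit or a call not known to return), anything after it that is not safe to speculate must stay behind it, up to and including the next such barrier; the stacksave/stackrestore and alloca rules in the same hunk follow the same walk with a different filter. A simplified standalone sketch of the barrier walk (Inst, its flags, and addControlDeps are illustrative, not LLVM API):

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// If the instruction at Barrier may not reach the next instruction, every
// later instruction that is not safe to speculate gets a control-dependence
// edge on it, stopping once another barrier is reached (everything past that
// point depends on the new barrier instead).
struct Inst {
  std::string Name;
  bool TransfersExecution; // guaranteed to reach the next instruction?
  bool SafeToSpeculate;    // free of side effects / UB if hoisted?
};

static void addControlDeps(const std::vector<Inst> &Block, size_t Barrier,
                           std::vector<std::pair<size_t, size_t>> &Deps) {
  if (Block[Barrier].TransfersExecution)
    return;
  for (size_t I = Barrier + 1; I < Block.size(); ++I) {
    if (Block[I].SafeToSpeculate)
      continue;                    // may still float freely
    Deps.emplace_back(Barrier, I); // I must stay after Barrier
    if (!Block[I].TransfersExecution)
      break;
  }
}

int main() {
  std::vector<Inst> Block = {
      {"call @may_throw", false, false},
      {"add", true, true},
      {"store", true, false},
      {"call @may_exit", false, false},
  };
  std::vector<std::pair<size_t, size_t>> Deps;
  addControlDeps(Block, 0, Deps);
  for (auto &D : Deps)
    std::printf("%s must stay before %s\n", Block[D.first].Name.c_str(),
                Block[D.second].Name.c_str());
}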
int Idx = 0; - int NumToSchedule = 0; for (auto *I = BS->ScheduleStart; I != BS->ScheduleEnd; I = I->getNextNode()) { - BS->doForAllOpcodes(I, [this, &Idx, &NumToSchedule, BS](ScheduleData *SD) { + BS->doForAllOpcodes(I, [this, &Idx, BS](ScheduleData *SD) { + TreeEntry *SDTE = getTreeEntry(SD->Inst); + (void)SDTE; assert((isVectorLikeInstWithConstOps(SD->Inst) || - SD->isPartOfBundle() == (getTreeEntry(SD->Inst) != nullptr)) && + SD->isPartOfBundle() == + (SDTE && !doesNotNeedToSchedule(SDTE->Scalars))) && "scheduler and vectorizer bundle mismatch"); SD->FirstInBundle->SchedulingPriority = Idx++; - if (SD->isSchedulingEntity()) { + + if (SD->isSchedulingEntity() && SD->isPartOfBundle()) BS->calculateDependencies(SD, false, this); - NumToSchedule++; - } }); } BS->initialFillReadyList(ReadyInsts); @@ -7674,9 +9579,23 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) { } BS->schedule(picked, ReadyInsts); - NumToSchedule--; } - assert(NumToSchedule == 0 && "could not schedule all instructions"); + + // Check that we didn't break any of our invariants. +#ifdef EXPENSIVE_CHECKS + BS->verify(); +#endif + +#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS) + // Check that all schedulable entities got scheduled + for (auto *I = BS->ScheduleStart; I != BS->ScheduleEnd; I = I->getNextNode()) { + BS->doForAllOpcodes(I, [&](ScheduleData *SD) { + if (SD->isSchedulingEntity() && SD->hasValidDependencies()) { + assert(SD->IsScheduled && "must be scheduled at this point"); + } + }); + } +#endif // Avoid duplicate scheduling of the block. BS->ScheduleStart = nullptr; @@ -7686,11 +9605,8 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) { // If V is a store, just return the width of the stored value (or value // truncated just before storing) without traversing the expression tree. // This is the common case. - if (auto *Store = dyn_cast<StoreInst>(V)) { - if (auto *Trunc = dyn_cast<TruncInst>(Store->getValueOperand())) - return DL->getTypeSizeInBits(Trunc->getSrcTy()); + if (auto *Store = dyn_cast<StoreInst>(V)) return DL->getTypeSizeInBits(Store->getValueOperand()->getType()); - } if (auto *IEI = dyn_cast<InsertElementInst>(V)) return getVectorElementSize(IEI->getOperand(1)); @@ -8092,6 +10008,8 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_, // Scan the blocks in the function in post order. for (auto BB : post_order(&F.getEntryBlock())) { + // Start new block - clear the list of reduction roots. + R.clearReductionData(); collectSeedInstructions(BB); // Vectorize trees that end at stores. 
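scheduleBlock still assigns each node a priority equal to its original position in the block and keeps the ready set ordered by that priority, so the emitted schedule stays as close as possible to the source order; the change above restricts dependence recalculation to nodes that belong to the bundle sub-graph and replaces the NumToSchedule counter with explicit verification. A conceptual sketch of the priority-ordered ready set (simplified; Node, ByPriority, and the example graph are made up):

#include <cstdio>
#include <set>
#include <vector>

struct Node {
  int SchedulingPriority; // original position in the basic block
  int UnscheduledDeps;
  std::vector<int> Users;
};

// Orders ready nodes by their original position so the earliest one is
// emitted first whenever there is a choice.
struct ByPriority {
  const std::vector<Node> *Nodes;
  bool operator()(int A, int B) const {
    return (*Nodes)[A].SchedulingPriority < (*Nodes)[B].SchedulingPriority;
  }
};

int main() {
  std::vector<Node> Nodes = {
      {0, 0, {2}}, {1, 0, {2}}, {2, 2, {3}}, {3, 1, {}}};
  std::set<int, ByPriority> Ready(ByPriority{&Nodes});
  for (int I = 0, E = Nodes.size(); I < E; ++I)
    if (Nodes[I].UnscheduledDeps == 0)
      Ready.insert(I);
  while (!Ready.empty()) {
    int Picked = *Ready.begin(); // closest to the original order
    Ready.erase(Ready.begin());
    std::printf("emit node %d\n", Picked);
    for (int U : Nodes[Picked].Users)
      if (--Nodes[U].UnscheduledDeps == 0)
        Ready.insert(U);
  }
}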
@@ -8122,11 +10040,10 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_, } bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, - unsigned Idx) { + unsigned Idx, unsigned MinVF) { LLVM_DEBUG(dbgs() << "SLP: Analyzing a store chain of length " << Chain.size() << "\n"); const unsigned Sz = R.getVectorElementSize(Chain[0]); - const unsigned MinVF = R.getMinVecRegSize() / Sz; unsigned VF = Chain.size(); if (!isPowerOf2_32(Sz) || !isPowerOf2_32(VF) || VF < 2 || VF < MinVF) @@ -8265,9 +10182,15 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, unsigned EltSize = R.getVectorElementSize(Operands[0]); unsigned MaxElts = llvm::PowerOf2Floor(MaxVecRegSize / EltSize); - unsigned MinVF = R.getMinVF(EltSize); unsigned MaxVF = std::min(R.getMaximumVF(EltSize, Instruction::Store), MaxElts); + auto *Store = cast<StoreInst>(Operands[0]); + Type *StoreTy = Store->getValueOperand()->getType(); + Type *ValueTy = StoreTy; + if (auto *Trunc = dyn_cast<TruncInst>(Store->getValueOperand())) + ValueTy = Trunc->getSrcTy(); + unsigned MinVF = TTI->getStoreMinimumVF( + R.getMinVF(DL->getTypeSizeInBits(ValueTy)), StoreTy, ValueTy); // FIXME: Is division-by-2 the correct step? Should we assert that the // register size is a power-of-2? @@ -8277,7 +10200,7 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, ArrayRef<Value *> Slice = makeArrayRef(Operands).slice(Cnt, Size); if (!VectorizedStores.count(Slice.front()) && !VectorizedStores.count(Slice.back()) && - vectorizeStoreChain(Slice, R, Cnt)) { + vectorizeStoreChain(Slice, R, Cnt, MinVF)) { // Mark the vectorized stores so that we don't vectorize them again. VectorizedStores.insert(Slice.begin(), Slice.end()); Changed = true; @@ -8481,7 +10404,8 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) { if (!I) return false; - if (!isa<BinaryOperator>(I) && !isa<CmpInst>(I)) + if ((!isa<BinaryOperator>(I) && !isa<CmpInst>(I)) || + isa<VectorType>(I->getType())) return false; Value *P = I->getParent(); @@ -8492,32 +10416,40 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) { if (!Op0 || !Op1 || Op0->getParent() != P || Op1->getParent() != P) return false; - // Try to vectorize V. - if (tryToVectorizePair(Op0, Op1, R)) - return true; + // First collect all possible candidates + SmallVector<std::pair<Value *, Value *>, 4> Candidates; + Candidates.emplace_back(Op0, Op1); auto *A = dyn_cast<BinaryOperator>(Op0); auto *B = dyn_cast<BinaryOperator>(Op1); // Try to skip B. - if (B && B->hasOneUse()) { + if (A && B && B->hasOneUse()) { auto *B0 = dyn_cast<BinaryOperator>(B->getOperand(0)); auto *B1 = dyn_cast<BinaryOperator>(B->getOperand(1)); - if (B0 && B0->getParent() == P && tryToVectorizePair(A, B0, R)) - return true; - if (B1 && B1->getParent() == P && tryToVectorizePair(A, B1, R)) - return true; + if (B0 && B0->getParent() == P) + Candidates.emplace_back(A, B0); + if (B1 && B1->getParent() == P) + Candidates.emplace_back(A, B1); } - // Try to skip A. 
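vectorizeStores above now derives the minimum VF from the stored value type (looking through a trunc) via getStoreMinimumVF, and then, as before, tries progressively smaller power-of-two slices of the consecutive-store chain. A sketch of that slicing loop in isolation (sliceStoreChain and the bounds are illustrative; the real code also checks costs, legality, and overlap with already vectorized stores):

#include <cstdio>

// Starting from the largest power-of-two VF and halving, carve a chain of
// consecutive stores into VF-sized slices and "try" each slice.
static void sliceStoreChain(unsigned ChainSize, unsigned MaxVF,
                            unsigned MinVF) {
  for (unsigned Size = MaxVF; Size >= MinVF; Size /= 2)
    for (unsigned Cnt = 0; Cnt + Size <= ChainSize; ++Cnt)
      std::printf("try stores [%u, %u) with VF=%u\n", Cnt, Cnt + Size, Size);
}

int main() {
  // A chain of 6 consecutive stores, register budget allows VF up to 4,
  // assumed minimum profitable VF of 2.
  sliceStoreChain(/*ChainSize=*/6, /*MaxVF=*/4, /*MinVF=*/2);
}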
- if (A && A->hasOneUse()) { + if (B && A && A->hasOneUse()) { auto *A0 = dyn_cast<BinaryOperator>(A->getOperand(0)); auto *A1 = dyn_cast<BinaryOperator>(A->getOperand(1)); - if (A0 && A0->getParent() == P && tryToVectorizePair(A0, B, R)) - return true; - if (A1 && A1->getParent() == P && tryToVectorizePair(A1, B, R)) - return true; + if (A0 && A0->getParent() == P) + Candidates.emplace_back(A0, B); + if (A1 && A1->getParent() == P) + Candidates.emplace_back(A1, B); } - return false; + + if (Candidates.size() == 1) + return tryToVectorizePair(Op0, Op1, R); + + // We have multiple options. Try to pick the single best. + Optional<int> BestCandidate = R.findBestRootPair(Candidates); + if (!BestCandidate) + return false; + return tryToVectorizePair(Candidates[*BestCandidate].first, + Candidates[*BestCandidate].second, R); } namespace { @@ -8552,15 +10484,16 @@ class HorizontalReduction { using ReductionOpsType = SmallVector<Value *, 16>; using ReductionOpsListType = SmallVector<ReductionOpsType, 2>; ReductionOpsListType ReductionOps; - SmallVector<Value *, 32> ReducedVals; + /// List of possibly reduced values. + SmallVector<SmallVector<Value *>> ReducedVals; + /// Maps reduced value to the corresponding reduction operation. + DenseMap<Value *, SmallVector<Instruction *>> ReducedValsToOps; // Use map vector to make stable output. MapVector<Instruction *, Value *> ExtraArgs; WeakTrackingVH ReductionRoot; /// The type of reduction operation. RecurKind RdxKind; - const unsigned INVALID_OPERAND_INDEX = std::numeric_limits<unsigned>::max(); - static bool isCmpSelMinMax(Instruction *I) { return match(I, m_Select(m_Cmp(), m_Value(), m_Value())) && RecurrenceDescriptor::isMinMaxRecurrenceKind(getRdxKind(I)); @@ -8604,26 +10537,6 @@ class HorizontalReduction { return I->getOperand(Index); } - /// Checks if the ParentStackElem.first should be marked as a reduction - /// operation with an extra argument or as extra argument itself. - void markExtraArg(std::pair<Instruction *, unsigned> &ParentStackElem, - Value *ExtraArg) { - if (ExtraArgs.count(ParentStackElem.first)) { - ExtraArgs[ParentStackElem.first] = nullptr; - // We ran into something like: - // ParentStackElem.first = ExtraArgs[ParentStackElem.first] + ExtraArg. - // The whole ParentStackElem.first should be considered as an extra value - // in this case. - // Do not perform analysis of remaining operands of ParentStackElem.first - // instruction, this whole instruction is an extra argument. - ParentStackElem.second = INVALID_OPERAND_INDEX; - } else { - // We ran into something like: - // ParentStackElem.first += ... + ExtraArg + ... - ExtraArgs[ParentStackElem.first] = ExtraArg; - } - } - /// Creates reduction operation with the current opcode. static Value *createOp(IRBuilder<> &Builder, RecurKind Kind, Value *LHS, Value *RHS, const Twine &Name, bool UseSelect) { @@ -8682,7 +10595,7 @@ class HorizontalReduction { } /// Creates reduction operation with the current opcode with the IR flags - /// from \p ReductionOps. + /// from \p ReductionOps, dropping nuw/nsw flags. 
static Value *createOp(IRBuilder<> &Builder, RecurKind RdxKind, Value *LHS, Value *RHS, const Twine &Name, const ReductionOpsListType &ReductionOps) { @@ -8696,31 +10609,21 @@ class HorizontalReduction { Value *Op = createOp(Builder, RdxKind, LHS, RHS, Name, UseSelect); if (RecurrenceDescriptor::isIntMinMaxRecurrenceKind(RdxKind)) { if (auto *Sel = dyn_cast<SelectInst>(Op)) { - propagateIRFlags(Sel->getCondition(), ReductionOps[0]); - propagateIRFlags(Op, ReductionOps[1]); + propagateIRFlags(Sel->getCondition(), ReductionOps[0], nullptr, + /*IncludeWrapFlags=*/false); + propagateIRFlags(Op, ReductionOps[1], nullptr, + /*IncludeWrapFlags=*/false); return Op; } } - propagateIRFlags(Op, ReductionOps[0]); - return Op; - } - - /// Creates reduction operation with the current opcode with the IR flags - /// from \p I. - static Value *createOp(IRBuilder<> &Builder, RecurKind RdxKind, Value *LHS, - Value *RHS, const Twine &Name, Instruction *I) { - auto *SelI = dyn_cast<SelectInst>(I); - Value *Op = createOp(Builder, RdxKind, LHS, RHS, Name, SelI != nullptr); - if (SelI && RecurrenceDescriptor::isIntMinMaxRecurrenceKind(RdxKind)) { - if (auto *Sel = dyn_cast<SelectInst>(Op)) - propagateIRFlags(Sel->getCondition(), SelI->getCondition()); - } - propagateIRFlags(Op, I); + propagateIRFlags(Op, ReductionOps[0], nullptr, /*IncludeWrapFlags=*/false); return Op; } - static RecurKind getRdxKind(Instruction *I) { - assert(I && "Expected instruction for reduction matching"); + static RecurKind getRdxKind(Value *V) { + auto *I = dyn_cast<Instruction>(V); + if (!I) + return RecurKind::None; if (match(I, m_Add(m_Value(), m_Value()))) return RecurKind::Add; if (match(I, m_Mul(m_Value(), m_Value()))) @@ -8882,7 +10785,9 @@ public: HorizontalReduction() = default; /// Try to find a reduction tree. - bool matchAssociativeReduction(PHINode *Phi, Instruction *Inst) { + bool matchAssociativeReduction(PHINode *Phi, Instruction *Inst, + ScalarEvolution &SE, const DataLayout &DL, + const TargetLibraryInfo &TLI) { assert((!Phi || is_contained(Phi->operands(), Inst)) && "Phi needs to use the binary operator"); assert((isa<BinaryOperator>(Inst) || isa<SelectInst>(Inst) || @@ -8926,124 +10831,178 @@ public: ReductionRoot = Inst; - // The opcode for leaf values that we perform a reduction on. - // For example: load(x) + load(y) + load(z) + fptoui(w) - // The leaf opcode for 'w' does not match, so we don't include it as a - // potential candidate for the reduction. - unsigned LeafOpcode = 0; - - // Post-order traverse the reduction tree starting at Inst. We only handle - // true trees containing binary operators or selects. - SmallVector<std::pair<Instruction *, unsigned>, 32> Stack; - Stack.push_back(std::make_pair(Inst, getFirstOperandIndex(Inst))); - initReductionOps(Inst); - while (!Stack.empty()) { - Instruction *TreeN = Stack.back().first; - unsigned EdgeToVisit = Stack.back().second++; - const RecurKind TreeRdxKind = getRdxKind(TreeN); - bool IsReducedValue = TreeRdxKind != RdxKind; - - // Postorder visit. - if (IsReducedValue || EdgeToVisit >= getNumberOfOperands(TreeN)) { - if (IsReducedValue) - ReducedVals.push_back(TreeN); - else { - auto ExtraArgsIter = ExtraArgs.find(TreeN); - if (ExtraArgsIter != ExtraArgs.end() && !ExtraArgsIter->second) { - // Check if TreeN is an extra argument of its parent operation. - if (Stack.size() <= 1) { - // TreeN can't be an extra argument as it is a root reduction - // operation. 
- return false; - } - // Yes, TreeN is an extra argument, do not add it to a list of - // reduction operations. - // Stack[Stack.size() - 2] always points to the parent operation. - markExtraArg(Stack[Stack.size() - 2], TreeN); - ExtraArgs.erase(TreeN); - } else - addReductionOps(TreeN); - } - // Retract. - Stack.pop_back(); - continue; - } - - // Visit operands. - Value *EdgeVal = getRdxOperand(TreeN, EdgeToVisit); - auto *EdgeInst = dyn_cast<Instruction>(EdgeVal); - if (!EdgeInst) { - // Edge value is not a reduction instruction or a leaf instruction. - // (It may be a constant, function argument, or something else.) - markExtraArg(Stack.back(), EdgeVal); - continue; + // Iterate through all the operands of the possible reduction tree and + // gather all the reduced values, sorting them by their value id. + BasicBlock *BB = Inst->getParent(); + bool IsCmpSelMinMax = isCmpSelMinMax(Inst); + SmallVector<Instruction *> Worklist(1, Inst); + // Checks if the operands of the \p TreeN instruction are also reduction + // operations or should be treated as reduced values or an extra argument, + // which is not part of the reduction. + auto &&CheckOperands = [this, IsCmpSelMinMax, + BB](Instruction *TreeN, + SmallVectorImpl<Value *> &ExtraArgs, + SmallVectorImpl<Value *> &PossibleReducedVals, + SmallVectorImpl<Instruction *> &ReductionOps) { + for (int I = getFirstOperandIndex(TreeN), + End = getNumberOfOperands(TreeN); + I < End; ++I) { + Value *EdgeVal = getRdxOperand(TreeN, I); + ReducedValsToOps[EdgeVal].push_back(TreeN); + auto *EdgeInst = dyn_cast<Instruction>(EdgeVal); + // Edge has wrong parent - mark as an extra argument. + if (EdgeInst && !isVectorLikeInstWithConstOps(EdgeInst) && + !hasSameParent(EdgeInst, BB)) { + ExtraArgs.push_back(EdgeVal); + continue; + } + // If the edge is not an instruction, or it is different from the main + // reduction opcode or has too many uses - possible reduced value. + if (!EdgeInst || getRdxKind(EdgeInst) != RdxKind || + IsCmpSelMinMax != isCmpSelMinMax(EdgeInst) || + !hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) || + !isVectorizable(getRdxKind(EdgeInst), EdgeInst)) { + PossibleReducedVals.push_back(EdgeVal); + continue; + } + ReductionOps.push_back(EdgeInst); } - RecurKind EdgeRdxKind = getRdxKind(EdgeInst); - // Continue analysis if the next operand is a reduction operation or - // (possibly) a leaf value. If the leaf value opcode is not set, - // the first met operation != reduction operation is considered as the - // leaf opcode. - // Only handle trees in the current basic block. - // Each tree node needs to have minimal number of users except for the - // ultimate reduction. - const bool IsRdxInst = EdgeRdxKind == RdxKind; - if (EdgeInst != Phi && EdgeInst != Inst && - hasSameParent(EdgeInst, Inst->getParent()) && - hasRequiredNumberOfUses(isCmpSelMinMax(Inst), EdgeInst) && - (!LeafOpcode || LeafOpcode == EdgeInst->getOpcode() || IsRdxInst)) { - if (IsRdxInst) { - // We need to be able to reassociate the reduction operations. - if (!isVectorizable(EdgeRdxKind, EdgeInst)) { - // I is an extra argument for TreeN (its parent operation). - markExtraArg(Stack.back(), EdgeInst); - continue; - } - } else if (!LeafOpcode) { - LeafOpcode = EdgeInst->getOpcode(); + }; + // Try to regroup reduced values so that it gets more profitable to try to + // reduce them. 
Values are grouped by their value ids, instructions - by + // instruction op id and/or alternate op id, plus do extra analysis for + // loads (grouping them by the distabce between pointers) and cmp + // instructions (grouping them by the predicate). + MapVector<size_t, MapVector<size_t, MapVector<Value *, unsigned>>> + PossibleReducedVals; + initReductionOps(Inst); + while (!Worklist.empty()) { + Instruction *TreeN = Worklist.pop_back_val(); + SmallVector<Value *> Args; + SmallVector<Value *> PossibleRedVals; + SmallVector<Instruction *> PossibleReductionOps; + CheckOperands(TreeN, Args, PossibleRedVals, PossibleReductionOps); + // If too many extra args - mark the instruction itself as a reduction + // value, not a reduction operation. + if (Args.size() < 2) { + addReductionOps(TreeN); + // Add extra args. + if (!Args.empty()) { + assert(Args.size() == 1 && "Expected only single argument."); + ExtraArgs[TreeN] = Args.front(); } - Stack.push_back( - std::make_pair(EdgeInst, getFirstOperandIndex(EdgeInst))); - continue; + // Add reduction values. The values are sorted for better vectorization + // results. + for (Value *V : PossibleRedVals) { + size_t Key, Idx; + std::tie(Key, Idx) = generateKeySubkey( + V, &TLI, + [&PossibleReducedVals, &DL, &SE](size_t Key, LoadInst *LI) { + auto It = PossibleReducedVals.find(Key); + if (It != PossibleReducedVals.end()) { + for (const auto &LoadData : It->second) { + auto *RLI = cast<LoadInst>(LoadData.second.front().first); + if (getPointersDiff(RLI->getType(), + RLI->getPointerOperand(), LI->getType(), + LI->getPointerOperand(), DL, SE, + /*StrictCheck=*/true)) + return hash_value(RLI->getPointerOperand()); + } + } + return hash_value(LI->getPointerOperand()); + }, + /*AllowAlternate=*/false); + ++PossibleReducedVals[Key][Idx] + .insert(std::make_pair(V, 0)) + .first->second; + } + Worklist.append(PossibleReductionOps.rbegin(), + PossibleReductionOps.rend()); + } else { + size_t Key, Idx; + std::tie(Key, Idx) = generateKeySubkey( + TreeN, &TLI, + [&PossibleReducedVals, &DL, &SE](size_t Key, LoadInst *LI) { + auto It = PossibleReducedVals.find(Key); + if (It != PossibleReducedVals.end()) { + for (const auto &LoadData : It->second) { + auto *RLI = cast<LoadInst>(LoadData.second.front().first); + if (getPointersDiff(RLI->getType(), RLI->getPointerOperand(), + LI->getType(), LI->getPointerOperand(), + DL, SE, /*StrictCheck=*/true)) + return hash_value(RLI->getPointerOperand()); + } + } + return hash_value(LI->getPointerOperand()); + }, + /*AllowAlternate=*/false); + ++PossibleReducedVals[Key][Idx] + .insert(std::make_pair(TreeN, 0)) + .first->second; + } + } + auto PossibleReducedValsVect = PossibleReducedVals.takeVector(); + // Sort values by the total number of values kinds to start the reduction + // from the longest possible reduced values sequences. + for (auto &PossibleReducedVals : PossibleReducedValsVect) { + auto PossibleRedVals = PossibleReducedVals.second.takeVector(); + SmallVector<SmallVector<Value *>> PossibleRedValsVect; + for (auto It = PossibleRedVals.begin(), E = PossibleRedVals.end(); + It != E; ++It) { + PossibleRedValsVect.emplace_back(); + auto RedValsVect = It->second.takeVector(); + stable_sort(RedValsVect, [](const auto &P1, const auto &P2) { + return P1.second < P2.second; + }); + for (const std::pair<Value *, unsigned> &Data : RedValsVect) + PossibleRedValsVect.back().append(Data.second, Data.first); } - // I is an extra argument for TreeN (its parent operation). 
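The rewritten matchAssociativeReduction no longer walks a single operand stack; it gathers all candidate reduced values with a worklist and buckets them by a (key, sub-key) pair encoding value kind, opcode, load pointer base, or compare predicate, then sorts the buckets so the largest group of mutually compatible values is reduced first. A standalone sketch of the bucket-and-sort step (strings stand in for the hash keys; not LLVM code):

#include <algorithm>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

int main() {
  // Each candidate carries a key describing its "kind" (here just a string).
  std::vector<std::pair<std::string, std::string>> Candidates = {
      {"load a[0]", "load"}, {"x * y", "mul"},     {"load a[1]", "load"},
      {"load a[2]", "load"}, {"p * q", "mul"},     {"argument n", "other"},
  };
  std::map<std::string, std::vector<std::string>> Groups;
  for (auto &C : Candidates)
    Groups[C.second].push_back(C.first);

  // Largest group of compatible values first, mirroring the stable_sort above.
  std::vector<std::vector<std::string>> ReducedVals;
  for (auto &G : Groups)
    ReducedVals.push_back(G.second);
  std::stable_sort(
      ReducedVals.begin(), ReducedVals.end(),
      [](const auto &A, const auto &B) { return A.size() > B.size(); });

  for (auto &Group : ReducedVals) {
    for (auto &V : Group)
      std::printf("[%s] ", V.c_str());
    std::printf("\n");
  }
}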
- markExtraArg(Stack.back(), EdgeInst); - } + stable_sort(PossibleRedValsVect, [](const auto &P1, const auto &P2) { + return P1.size() > P2.size(); + }); + ReducedVals.emplace_back(); + for (ArrayRef<Value *> Data : PossibleRedValsVect) + ReducedVals.back().append(Data.rbegin(), Data.rend()); + } + // Sort the reduced values by number of same/alternate opcode and/or pointer + // operand. + stable_sort(ReducedVals, [](ArrayRef<Value *> P1, ArrayRef<Value *> P2) { + return P1.size() > P2.size(); + }); return true; } /// Attempt to vectorize the tree found by matchAssociativeReduction. Value *tryToReduce(BoUpSLP &V, TargetTransformInfo *TTI) { + constexpr int ReductionLimit = 4; + constexpr unsigned RegMaxNumber = 4; + constexpr unsigned RedValsMaxNumber = 128; // If there are a sufficient number of reduction values, reduce // to a nearby power-of-2. We can safely generate oversized // vectors and rely on the backend to split them to legal sizes. - unsigned NumReducedVals = ReducedVals.size(); - if (NumReducedVals < 4) + unsigned NumReducedVals = std::accumulate( + ReducedVals.begin(), ReducedVals.end(), 0, + [](int Num, ArrayRef<Value *> Vals) { return Num + Vals.size(); }); + if (NumReducedVals < ReductionLimit) return nullptr; - // Intersect the fast-math-flags from all reduction operations. - FastMathFlags RdxFMF; - RdxFMF.set(); - for (ReductionOpsType &RdxOp : ReductionOps) { - for (Value *RdxVal : RdxOp) { - if (auto *FPMO = dyn_cast<FPMathOperator>(RdxVal)) - RdxFMF &= FPMO->getFastMathFlags(); - } - } - IRBuilder<> Builder(cast<Instruction>(ReductionRoot)); - Builder.setFastMathFlags(RdxFMF); + // Track the reduced values in case if they are replaced by extractelement + // because of the vectorization. + DenseMap<Value *, WeakTrackingVH> TrackedVals; BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues; // The same extra argument may be used several times, so log each attempt // to use it. for (const std::pair<Instruction *, Value *> &Pair : ExtraArgs) { assert(Pair.first && "DebugLoc must be set."); ExternallyUsedValues[Pair.second].push_back(Pair.first); + TrackedVals.try_emplace(Pair.second, Pair.second); } // The compare instruction of a min/max is the insertion point for new // instructions and may be replaced with a new compare instruction. - auto getCmpForMinMaxReduction = [](Instruction *RdxRootInst) { + auto &&GetCmpForMinMaxReduction = [](Instruction *RdxRootInst) { assert(isa<SelectInst>(RdxRootInst) && "Expected min/max reduction to have select root instruction"); Value *ScalarCond = cast<SelectInst>(RdxRootInst)->getCondition(); @@ -9055,164 +11014,390 @@ public: // The reduction root is used as the insertion point for new instructions, // so set it as externally used to prevent it from being deleted. ExternallyUsedValues[ReductionRoot]; - SmallVector<Value *, 16> IgnoreList; - for (ReductionOpsType &RdxOp : ReductionOps) - IgnoreList.append(RdxOp.begin(), RdxOp.end()); - - unsigned ReduxWidth = PowerOf2Floor(NumReducedVals); - if (NumReducedVals > ReduxWidth) { - // In the loop below, we are building a tree based on a window of - // 'ReduxWidth' values. - // If the operands of those values have common traits (compare predicate, - // constant operand, etc), then we want to group those together to - // minimize the cost of the reduction. - - // TODO: This should be extended to count common operands for - // compares and binops. - - // Step 1: Count the number of times each compare predicate occurs. 
- SmallDenseMap<unsigned, unsigned> PredCountMap; - for (Value *RdxVal : ReducedVals) { - CmpInst::Predicate Pred; - if (match(RdxVal, m_Cmp(Pred, m_Value(), m_Value()))) - ++PredCountMap[Pred]; - } - // Step 2: Sort the values so the most common predicates come first. - stable_sort(ReducedVals, [&PredCountMap](Value *A, Value *B) { - CmpInst::Predicate PredA, PredB; - if (match(A, m_Cmp(PredA, m_Value(), m_Value())) && - match(B, m_Cmp(PredB, m_Value(), m_Value()))) { - return PredCountMap[PredA] > PredCountMap[PredB]; - } - return false; - }); - } + SmallDenseSet<Value *> IgnoreList; + for (ReductionOpsType &RdxOps : ReductionOps) + for (Value *RdxOp : RdxOps) { + if (!RdxOp) + continue; + IgnoreList.insert(RdxOp); + } + bool IsCmpSelMinMax = isCmpSelMinMax(cast<Instruction>(ReductionRoot)); + + // Need to track reduced vals, they may be changed during vectorization of + // subvectors. + for (ArrayRef<Value *> Candidates : ReducedVals) + for (Value *V : Candidates) + TrackedVals.try_emplace(V, V); + DenseMap<Value *, unsigned> VectorizedVals; Value *VectorizedTree = nullptr; - unsigned i = 0; - while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) { - ArrayRef<Value *> VL(&ReducedVals[i], ReduxWidth); - V.buildTree(VL, IgnoreList); - if (V.isTreeTinyAndNotFullyVectorizable(/*ForReduction=*/true)) - break; - if (V.isLoadCombineReductionCandidate(RdxKind)) - break; - V.reorderTopToBottom(); - V.reorderBottomToTop(/*IgnoreReorder=*/true); - V.buildExternalUses(ExternallyUsedValues); - - // For a poison-safe boolean logic reduction, do not replace select - // instructions with logic ops. All reduced values will be frozen (see - // below) to prevent leaking poison. - if (isa<SelectInst>(ReductionRoot) && - isBoolLogicOp(cast<Instruction>(ReductionRoot)) && - NumReducedVals != ReduxWidth) - break; + bool CheckForReusedReductionOps = false; + // Try to vectorize elements based on their type. + for (unsigned I = 0, E = ReducedVals.size(); I < E; ++I) { + ArrayRef<Value *> OrigReducedVals = ReducedVals[I]; + InstructionsState S = getSameOpcode(OrigReducedVals); + SmallVector<Value *> Candidates; + DenseMap<Value *, Value *> TrackedToOrig; + for (unsigned Cnt = 0, Sz = OrigReducedVals.size(); Cnt < Sz; ++Cnt) { + Value *RdxVal = TrackedVals.find(OrigReducedVals[Cnt])->second; + // Check if the reduction value was not overriden by the extractelement + // instruction because of the vectorization and exclude it, if it is not + // compatible with other values. + if (auto *Inst = dyn_cast<Instruction>(RdxVal)) + if (isVectorLikeInstWithConstOps(Inst) && + (!S.getOpcode() || !S.isOpcodeOrAlt(Inst))) + continue; + Candidates.push_back(RdxVal); + TrackedToOrig.try_emplace(RdxVal, OrigReducedVals[Cnt]); + } + bool ShuffledExtracts = false; + // Try to handle shuffled extractelements. + if (S.getOpcode() == Instruction::ExtractElement && !S.isAltShuffle() && + I + 1 < E) { + InstructionsState NextS = getSameOpcode(ReducedVals[I + 1]); + if (NextS.getOpcode() == Instruction::ExtractElement && + !NextS.isAltShuffle()) { + SmallVector<Value *> CommonCandidates(Candidates); + for (Value *RV : ReducedVals[I + 1]) { + Value *RdxVal = TrackedVals.find(RV)->second; + // Check if the reduction value was not overriden by the + // extractelement instruction because of the vectorization and + // exclude it, if it is not compatible with other values. 
+ if (auto *Inst = dyn_cast<Instruction>(RdxVal)) + if (!NextS.getOpcode() || !NextS.isOpcodeOrAlt(Inst)) + continue; + CommonCandidates.push_back(RdxVal); + TrackedToOrig.try_emplace(RdxVal, RV); + } + SmallVector<int> Mask; + if (isFixedVectorShuffle(CommonCandidates, Mask)) { + ++I; + Candidates.swap(CommonCandidates); + ShuffledExtracts = true; + } + } + } + unsigned NumReducedVals = Candidates.size(); + if (NumReducedVals < ReductionLimit) + continue; - V.computeMinimumValueSizes(); + unsigned MaxVecRegSize = V.getMaxVecRegSize(); + unsigned EltSize = V.getVectorElementSize(Candidates[0]); + unsigned MaxElts = RegMaxNumber * PowerOf2Floor(MaxVecRegSize / EltSize); + + unsigned ReduxWidth = std::min<unsigned>( + PowerOf2Floor(NumReducedVals), std::max(RedValsMaxNumber, MaxElts)); + unsigned Start = 0; + unsigned Pos = Start; + // Restarts vectorization attempt with lower vector factor. + unsigned PrevReduxWidth = ReduxWidth; + bool CheckForReusedReductionOpsLocal = false; + auto &&AdjustReducedVals = [&Pos, &Start, &ReduxWidth, NumReducedVals, + &CheckForReusedReductionOpsLocal, + &PrevReduxWidth, &V, + &IgnoreList](bool IgnoreVL = false) { + bool IsAnyRedOpGathered = !IgnoreVL && V.isAnyGathered(IgnoreList); + if (!CheckForReusedReductionOpsLocal && PrevReduxWidth == ReduxWidth) { + // Check if any of the reduction ops are gathered. If so, worth + // trying again with less number of reduction ops. + CheckForReusedReductionOpsLocal |= IsAnyRedOpGathered; + } + ++Pos; + if (Pos < NumReducedVals - ReduxWidth + 1) + return IsAnyRedOpGathered; + Pos = Start; + ReduxWidth /= 2; + return IsAnyRedOpGathered; + }; + while (Pos < NumReducedVals - ReduxWidth + 1 && + ReduxWidth >= ReductionLimit) { + // Dependency in tree of the reduction ops - drop this attempt, try + // later. + if (CheckForReusedReductionOpsLocal && PrevReduxWidth != ReduxWidth && + Start == 0) { + CheckForReusedReductionOps = true; + break; + } + PrevReduxWidth = ReduxWidth; + ArrayRef<Value *> VL(std::next(Candidates.begin(), Pos), ReduxWidth); + // Beeing analyzed already - skip. + if (V.areAnalyzedReductionVals(VL)) { + (void)AdjustReducedVals(/*IgnoreVL=*/true); + continue; + } + // Early exit if any of the reduction values were deleted during + // previous vectorization attempts. + if (any_of(VL, [&V](Value *RedVal) { + auto *RedValI = dyn_cast<Instruction>(RedVal); + if (!RedValI) + return false; + return V.isDeleted(RedValI); + })) + break; + V.buildTree(VL, IgnoreList); + if (V.isTreeTinyAndNotFullyVectorizable(/*ForReduction=*/true)) { + if (!AdjustReducedVals()) + V.analyzedReductionVals(VL); + continue; + } + if (V.isLoadCombineReductionCandidate(RdxKind)) { + if (!AdjustReducedVals()) + V.analyzedReductionVals(VL); + continue; + } + V.reorderTopToBottom(); + // No need to reorder the root node at all. + V.reorderBottomToTop(/*IgnoreReorder=*/true); + // Keep extracted other reduction values, if they are used in the + // vectorization trees. + BoUpSLP::ExtraValueToDebugLocsMap LocalExternallyUsedValues( + ExternallyUsedValues); + for (unsigned Cnt = 0, Sz = ReducedVals.size(); Cnt < Sz; ++Cnt) { + if (Cnt == I || (ShuffledExtracts && Cnt == I - 1)) + continue; + for_each(ReducedVals[Cnt], + [&LocalExternallyUsedValues, &TrackedVals](Value *V) { + if (isa<Instruction>(V)) + LocalExternallyUsedValues[TrackedVals[V]]; + }); + } + // Number of uses of the candidates in the vector of values. 
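tryToReduce then iterates over each group with a sliding window: it attempts to build a tree from ReduxWidth consecutive candidates, advances the window on failure, and once every position at the current width has been tried it restarts from the beginning with half the width, giving up below the minimum profitable reduction size. A sketch of just that search loop (the widths and counts are illustrative; AdjustReducedVals in the real code also tracks gathered reduction ops and already-analyzed value sets):

#include <cstdio>

int main() {
  const unsigned NumReducedVals = 10;
  const unsigned ReductionLimit = 4;
  unsigned ReduxWidth = 8; // PowerOf2Floor of the candidate count
  unsigned Pos = 0;
  while (Pos < NumReducedVals - ReduxWidth + 1 && ReduxWidth >= ReductionLimit) {
    std::printf("attempt: candidates [%u, %u) width %u\n", Pos,
                Pos + ReduxWidth, ReduxWidth);
    // Pretend the attempt failed: slide the window; once it has slid past the
    // end, restart from position 0 with half the width.
    ++Pos;
    if (Pos >= NumReducedVals - ReduxWidth + 1) {
      Pos = 0;
      ReduxWidth /= 2;
    }
  }
}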
+ SmallDenseMap<Value *, unsigned> NumUses; + for (unsigned Cnt = 0; Cnt < Pos; ++Cnt) { + Value *V = Candidates[Cnt]; + if (NumUses.count(V) > 0) + continue; + NumUses[V] = std::count(VL.begin(), VL.end(), V); + } + for (unsigned Cnt = Pos + ReduxWidth; Cnt < NumReducedVals; ++Cnt) { + Value *V = Candidates[Cnt]; + if (NumUses.count(V) > 0) + continue; + NumUses[V] = std::count(VL.begin(), VL.end(), V); + } + // Gather externally used values. + SmallPtrSet<Value *, 4> Visited; + for (unsigned Cnt = 0; Cnt < Pos; ++Cnt) { + Value *V = Candidates[Cnt]; + if (!Visited.insert(V).second) + continue; + unsigned NumOps = VectorizedVals.lookup(V) + NumUses[V]; + if (NumOps != ReducedValsToOps.find(V)->second.size()) + LocalExternallyUsedValues[V]; + } + for (unsigned Cnt = Pos + ReduxWidth; Cnt < NumReducedVals; ++Cnt) { + Value *V = Candidates[Cnt]; + if (!Visited.insert(V).second) + continue; + unsigned NumOps = VectorizedVals.lookup(V) + NumUses[V]; + if (NumOps != ReducedValsToOps.find(V)->second.size()) + LocalExternallyUsedValues[V]; + } + V.buildExternalUses(LocalExternallyUsedValues); + + V.computeMinimumValueSizes(); + + // Intersect the fast-math-flags from all reduction operations. + FastMathFlags RdxFMF; + RdxFMF.set(); + for (Value *U : IgnoreList) + if (auto *FPMO = dyn_cast<FPMathOperator>(U)) + RdxFMF &= FPMO->getFastMathFlags(); + // Estimate cost. + InstructionCost TreeCost = V.getTreeCost(VL); + InstructionCost ReductionCost = + getReductionCost(TTI, VL, ReduxWidth, RdxFMF); + InstructionCost Cost = TreeCost + ReductionCost; + if (!Cost.isValid()) { + LLVM_DEBUG(dbgs() << "Encountered invalid baseline cost.\n"); + return nullptr; + } + if (Cost >= -SLPCostThreshold) { + V.getORE()->emit([&]() { + return OptimizationRemarkMissed( + SV_NAME, "HorSLPNotBeneficial", + ReducedValsToOps.find(VL[0])->second.front()) + << "Vectorizing horizontal reduction is possible" + << "but not beneficial with cost " << ore::NV("Cost", Cost) + << " and threshold " + << ore::NV("Threshold", -SLPCostThreshold); + }); + if (!AdjustReducedVals()) + V.analyzedReductionVals(VL); + continue; + } - // Estimate cost. - InstructionCost TreeCost = - V.getTreeCost(makeArrayRef(&ReducedVals[i], ReduxWidth)); - InstructionCost ReductionCost = - getReductionCost(TTI, ReducedVals[i], ReduxWidth, RdxFMF); - InstructionCost Cost = TreeCost + ReductionCost; - if (!Cost.isValid()) { - LLVM_DEBUG(dbgs() << "Encountered invalid baseline cost.\n"); - return nullptr; - } - if (Cost >= -SLPCostThreshold) { + LLVM_DEBUG(dbgs() << "SLP: Vectorizing horizontal reduction at cost:" + << Cost << ". (HorRdx)\n"); V.getORE()->emit([&]() { - return OptimizationRemarkMissed(SV_NAME, "HorSLPNotBeneficial", - cast<Instruction>(VL[0])) - << "Vectorizing horizontal reduction is possible" - << "but not beneficial with cost " << ore::NV("Cost", Cost) - << " and threshold " - << ore::NV("Threshold", -SLPCostThreshold); + return OptimizationRemark( + SV_NAME, "VectorizedHorizontalReduction", + ReducedValsToOps.find(VL[0])->second.front()) + << "Vectorized horizontal reduction with cost " + << ore::NV("Cost", Cost) << " and with tree size " + << ore::NV("TreeSize", V.getTreeSize()); }); - break; - } - LLVM_DEBUG(dbgs() << "SLP: Vectorizing horizontal reduction at cost:" - << Cost << ". 
(HorRdx)\n"); - V.getORE()->emit([&]() { - return OptimizationRemark(SV_NAME, "VectorizedHorizontalReduction", - cast<Instruction>(VL[0])) - << "Vectorized horizontal reduction with cost " - << ore::NV("Cost", Cost) << " and with tree size " - << ore::NV("TreeSize", V.getTreeSize()); - }); + Builder.setFastMathFlags(RdxFMF); - // Vectorize a tree. - DebugLoc Loc = cast<Instruction>(ReducedVals[i])->getDebugLoc(); - Value *VectorizedRoot = V.vectorizeTree(ExternallyUsedValues); + // Vectorize a tree. + Value *VectorizedRoot = V.vectorizeTree(LocalExternallyUsedValues); - // Emit a reduction. If the root is a select (min/max idiom), the insert - // point is the compare condition of that select. - Instruction *RdxRootInst = cast<Instruction>(ReductionRoot); - if (isCmpSelMinMax(RdxRootInst)) - Builder.SetInsertPoint(getCmpForMinMaxReduction(RdxRootInst)); - else - Builder.SetInsertPoint(RdxRootInst); + // Emit a reduction. If the root is a select (min/max idiom), the insert + // point is the compare condition of that select. + Instruction *RdxRootInst = cast<Instruction>(ReductionRoot); + if (IsCmpSelMinMax) + Builder.SetInsertPoint(GetCmpForMinMaxReduction(RdxRootInst)); + else + Builder.SetInsertPoint(RdxRootInst); - // To prevent poison from leaking across what used to be sequential, safe, - // scalar boolean logic operations, the reduction operand must be frozen. - if (isa<SelectInst>(RdxRootInst) && isBoolLogicOp(RdxRootInst)) - VectorizedRoot = Builder.CreateFreeze(VectorizedRoot); + // To prevent poison from leaking across what used to be sequential, + // safe, scalar boolean logic operations, the reduction operand must be + // frozen. + if (isa<SelectInst>(RdxRootInst) && isBoolLogicOp(RdxRootInst)) + VectorizedRoot = Builder.CreateFreeze(VectorizedRoot); - Value *ReducedSubTree = - emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI); + Value *ReducedSubTree = + emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI); - if (!VectorizedTree) { - // Initialize the final value in the reduction. - VectorizedTree = ReducedSubTree; - } else { - // Update the final value in the reduction. - Builder.SetCurrentDebugLocation(Loc); - VectorizedTree = createOp(Builder, RdxKind, VectorizedTree, - ReducedSubTree, "op.rdx", ReductionOps); + if (!VectorizedTree) { + // Initialize the final value in the reduction. + VectorizedTree = ReducedSubTree; + } else { + // Update the final value in the reduction. + Builder.SetCurrentDebugLocation( + cast<Instruction>(ReductionOps.front().front())->getDebugLoc()); + VectorizedTree = createOp(Builder, RdxKind, VectorizedTree, + ReducedSubTree, "op.rdx", ReductionOps); + } + // Count vectorized reduced values to exclude them from final reduction. + for (Value *V : VL) + ++VectorizedVals.try_emplace(TrackedToOrig.find(V)->second, 0) + .first->getSecond(); + Pos += ReduxWidth; + Start = Pos; + ReduxWidth = PowerOf2Floor(NumReducedVals - Pos); } - i += ReduxWidth; - ReduxWidth = PowerOf2Floor(NumReducedVals - i); } - if (VectorizedTree) { // Finish the reduction. - for (; i < NumReducedVals; ++i) { - auto *I = cast<Instruction>(ReducedVals[i]); - Builder.SetCurrentDebugLocation(I->getDebugLoc()); - VectorizedTree = - createOp(Builder, RdxKind, VectorizedTree, I, "", ReductionOps); + // Need to add extra arguments and not vectorized possible reduction + // values. + // Try to avoid dependencies between the scalar remainders after + // reductions. 
+ auto &&FinalGen = + [this, &Builder, + &TrackedVals](ArrayRef<std::pair<Instruction *, Value *>> InstVals) { + unsigned Sz = InstVals.size(); + SmallVector<std::pair<Instruction *, Value *>> ExtraReds(Sz / 2 + + Sz % 2); + for (unsigned I = 0, E = (Sz / 2) * 2; I < E; I += 2) { + Instruction *RedOp = InstVals[I + 1].first; + Builder.SetCurrentDebugLocation(RedOp->getDebugLoc()); + Value *RdxVal1 = InstVals[I].second; + Value *StableRdxVal1 = RdxVal1; + auto It1 = TrackedVals.find(RdxVal1); + if (It1 != TrackedVals.end()) + StableRdxVal1 = It1->second; + Value *RdxVal2 = InstVals[I + 1].second; + Value *StableRdxVal2 = RdxVal2; + auto It2 = TrackedVals.find(RdxVal2); + if (It2 != TrackedVals.end()) + StableRdxVal2 = It2->second; + Value *ExtraRed = createOp(Builder, RdxKind, StableRdxVal1, + StableRdxVal2, "op.rdx", ReductionOps); + ExtraReds[I / 2] = std::make_pair(InstVals[I].first, ExtraRed); + } + if (Sz % 2 == 1) + ExtraReds[Sz / 2] = InstVals.back(); + return ExtraReds; + }; + SmallVector<std::pair<Instruction *, Value *>> ExtraReductions; + SmallPtrSet<Value *, 8> Visited; + for (ArrayRef<Value *> Candidates : ReducedVals) { + for (Value *RdxVal : Candidates) { + if (!Visited.insert(RdxVal).second) + continue; + unsigned NumOps = VectorizedVals.lookup(RdxVal); + for (Instruction *RedOp : + makeArrayRef(ReducedValsToOps.find(RdxVal)->second) + .drop_back(NumOps)) + ExtraReductions.emplace_back(RedOp, RdxVal); + } } for (auto &Pair : ExternallyUsedValues) { // Add each externally used value to the final reduction. - for (auto *I : Pair.second) { - Builder.SetCurrentDebugLocation(I->getDebugLoc()); - VectorizedTree = createOp(Builder, RdxKind, VectorizedTree, - Pair.first, "op.extra", I); - } + for (auto *I : Pair.second) + ExtraReductions.emplace_back(I, Pair.first); + } + // Iterate through all not-vectorized reduction values/extra arguments. + while (ExtraReductions.size() > 1) { + SmallVector<std::pair<Instruction *, Value *>> NewReds = + FinalGen(ExtraReductions); + ExtraReductions.swap(NewReds); + } + // Final reduction. + if (ExtraReductions.size() == 1) { + Instruction *RedOp = ExtraReductions.back().first; + Builder.SetCurrentDebugLocation(RedOp->getDebugLoc()); + Value *RdxVal = ExtraReductions.back().second; + Value *StableRdxVal = RdxVal; + auto It = TrackedVals.find(RdxVal); + if (It != TrackedVals.end()) + StableRdxVal = It->second; + VectorizedTree = createOp(Builder, RdxKind, VectorizedTree, + StableRdxVal, "op.rdx", ReductionOps); } ReductionRoot->replaceAllUsesWith(VectorizedTree); - // Mark all scalar reduction ops for deletion, they are replaced by the - // vector reductions. - V.eraseInstructions(IgnoreList); + // The original scalar reduction is expected to have no remaining + // uses outside the reduction tree itself. Assert that we got this + // correct, replace internal uses with undef, and mark for eventual + // deletion. 
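FinalGen above combines whatever is left after vectorization (extra arguments and not-vectorized reduced values) pairwise, halving the list each round, rather than appending each leftover to the vectorized result one by one. A standalone sketch of that pairwise strategy (reduceOneRound and the string values are illustrative):

#include <cstdio>
#include <string>
#include <vector>

// One round of the pairwise fold: adjacent values are combined, and an odd
// trailing value is carried over to the next round unchanged.
static std::vector<std::string>
reduceOneRound(const std::vector<std::string> &Vals) {
  std::vector<std::string> Out;
  for (size_t I = 0; I + 1 < Vals.size(); I += 2)
    Out.push_back("(" + Vals[I] + " + " + Vals[I + 1] + ")");
  if (Vals.size() % 2 == 1)
    Out.push_back(Vals.back());
  return Out;
}

int main() {
  std::vector<std::string> Extra = {"a", "b", "c", "d", "e"};
  while (Extra.size() > 1)
    Extra = reduceOneRound(Extra);
  std::printf("%s\n", Extra.front().c_str()); // (((a + b) + (c + d)) + e)
}

Building a balanced tree here keeps the dependency chain of the scalar remainder logarithmic instead of linear.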
+#ifndef NDEBUG + SmallSet<Value *, 4> IgnoreSet; + for (ArrayRef<Value *> RdxOps : ReductionOps) + IgnoreSet.insert(RdxOps.begin(), RdxOps.end()); +#endif + for (ArrayRef<Value *> RdxOps : ReductionOps) { + for (Value *Ignore : RdxOps) { + if (!Ignore) + continue; +#ifndef NDEBUG + for (auto *U : Ignore->users()) { + assert(IgnoreSet.count(U) && + "All users must be either in the reduction ops list."); + } +#endif + if (!Ignore->use_empty()) { + Value *Undef = UndefValue::get(Ignore->getType()); + Ignore->replaceAllUsesWith(Undef); + } + V.eraseInstruction(cast<Instruction>(Ignore)); + } + } + } else if (!CheckForReusedReductionOps) { + for (ReductionOpsType &RdxOps : ReductionOps) + for (Value *RdxOp : RdxOps) + V.analyzedReductionRoot(cast<Instruction>(RdxOp)); } return VectorizedTree; } - unsigned numReductionValues() const { return ReducedVals.size(); } - private: /// Calculate the cost of a reduction. InstructionCost getReductionCost(TargetTransformInfo *TTI, - Value *FirstReducedVal, unsigned ReduxWidth, - FastMathFlags FMF) { + ArrayRef<Value *> ReducedVals, + unsigned ReduxWidth, FastMathFlags FMF) { TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + Value *FirstReducedVal = ReducedVals.front(); Type *ScalarTy = FirstReducedVal->getType(); FixedVectorType *VectorTy = FixedVectorType::get(ScalarTy, ReduxWidth); - InstructionCost VectorCost, ScalarCost; + InstructionCost VectorCost = 0, ScalarCost; + // If all of the reduced values are constant, the vector cost is 0, since + // the reduction value can be calculated at the compile time. + bool AllConsts = all_of(ReducedVals, isConstant); switch (RdxKind) { case RecurKind::Add: case RecurKind::Mul: @@ -9222,17 +11407,22 @@ private: case RecurKind::FAdd: case RecurKind::FMul: { unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(RdxKind); - VectorCost = - TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF, CostKind); + if (!AllConsts) + VectorCost = + TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF, CostKind); ScalarCost = TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy, CostKind); break; } case RecurKind::FMax: case RecurKind::FMin: { auto *SclCondTy = CmpInst::makeCmpResultType(ScalarTy); - auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VectorTy)); - VectorCost = TTI->getMinMaxReductionCost(VectorTy, VecCondTy, - /*IsUnsigned=*/false, CostKind); + if (!AllConsts) { + auto *VecCondTy = + cast<VectorType>(CmpInst::makeCmpResultType(VectorTy)); + VectorCost = + TTI->getMinMaxReductionCost(VectorTy, VecCondTy, + /*IsUnsigned=*/false, CostKind); + } CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind); ScalarCost = TTI->getCmpSelInstrCost(Instruction::FCmp, ScalarTy, SclCondTy, RdxPred, CostKind) + @@ -9245,11 +11435,14 @@ private: case RecurKind::UMax: case RecurKind::UMin: { auto *SclCondTy = CmpInst::makeCmpResultType(ScalarTy); - auto *VecCondTy = cast<VectorType>(CmpInst::makeCmpResultType(VectorTy)); - bool IsUnsigned = - RdxKind == RecurKind::UMax || RdxKind == RecurKind::UMin; - VectorCost = TTI->getMinMaxReductionCost(VectorTy, VecCondTy, IsUnsigned, - CostKind); + if (!AllConsts) { + auto *VecCondTy = + cast<VectorType>(CmpInst::makeCmpResultType(VectorTy)); + bool IsUnsigned = + RdxKind == RecurKind::UMax || RdxKind == RecurKind::UMin; + VectorCost = TTI->getMinMaxReductionCost(VectorTy, VecCondTy, + IsUnsigned, CostKind); + } CmpInst::Predicate RdxPred = getMinMaxReductionPredicate(RdxKind); ScalarCost = TTI->getCmpSelInstrCost(Instruction::ICmp, ScalarTy, SclCondTy, RdxPred, 
CostKind) + @@ -9463,7 +11656,8 @@ static bool matchRdxBop(Instruction *I, Value *&V0, Value *&V1) { /// performed. static bool tryToVectorizeHorReductionOrInstOperands( PHINode *P, Instruction *Root, BasicBlock *BB, BoUpSLP &R, - TargetTransformInfo *TTI, + TargetTransformInfo *TTI, ScalarEvolution &SE, const DataLayout &DL, + const TargetLibraryInfo &TLI, const function_ref<bool(Instruction *, BoUpSLP &)> Vectorize) { if (!ShouldVectorizeHor) return false; @@ -9482,7 +11676,7 @@ static bool tryToVectorizeHorReductionOrInstOperands( // horizontal reduction. // Interrupt the process if the Root instruction itself was vectorized or all // sub-trees not higher that RecursionMaxDepth were analyzed/vectorized. - // Skip the analysis of CmpInsts.Compiler implements postanalysis of the + // Skip the analysis of CmpInsts. Compiler implements postanalysis of the // CmpInsts so we can skip extra attempts in // tryToVectorizeHorReductionOrInstOperands and save compile time. std::queue<std::pair<Instruction *, unsigned>> Stack; @@ -9490,13 +11684,16 @@ static bool tryToVectorizeHorReductionOrInstOperands( SmallPtrSet<Value *, 8> VisitedInstrs; SmallVector<WeakTrackingVH> PostponedInsts; bool Res = false; - auto &&TryToReduce = [TTI, &P, &R](Instruction *Inst, Value *&B0, - Value *&B1) -> Value * { + auto &&TryToReduce = [TTI, &SE, &DL, &P, &R, &TLI](Instruction *Inst, + Value *&B0, + Value *&B1) -> Value * { + if (R.isAnalyzedReductionRoot(Inst)) + return nullptr; bool IsBinop = matchRdxBop(Inst, B0, B1); bool IsSelect = match(Inst, m_Select(m_Value(), m_Value(), m_Value())); if (IsBinop || IsSelect) { HorizontalReduction HorRdx; - if (HorRdx.matchAssociativeReduction(P, Inst)) + if (HorRdx.matchAssociativeReduction(P, Inst, SE, DL, TLI)) return HorRdx.tryToReduce(R, TTI); } return nullptr; @@ -9541,7 +11738,7 @@ static bool tryToVectorizeHorReductionOrInstOperands( // Do not try to vectorize CmpInst operands, this is done separately. // Final attempt for binop args vectorization should happen after the loop // to try to find reductions. - if (!isa<CmpInst>(Inst)) + if (!isa<CmpInst, InsertElementInst, InsertValueInst>(Inst)) PostponedInsts.push_back(Inst); } @@ -9554,8 +11751,8 @@ static bool tryToVectorizeHorReductionOrInstOperands( if (auto *I = dyn_cast<Instruction>(Op)) // Do not try to vectorize CmpInst operands, this is done // separately. - if (!isa<PHINode>(I) && !isa<CmpInst>(I) && !R.isDeleted(I) && - I->getParent() == BB) + if (!isa<PHINode, CmpInst, InsertElementInst, InsertValueInst>(I) && + !R.isDeleted(I) && I->getParent() == BB) Stack.emplace(I, Level); } // Try to vectorized binops where reductions were not found. 
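[Editorial aside, not part of the patch] The `FinalGen` helper added earlier in this file's diff combines the leftover (not-vectorized) reduction values two at a time, halving the worklist each round, instead of folding them into the accumulator one by one as the removed code did. The standalone sketch below shows only the shape of that pairwise strategy under stated assumptions: the function name `reducePairwise`, the `int` element type, and `std::plus` are illustrative placeholders, not LLVM APIs, and the real helper additionally tracks the originating instructions for debug locations.

#include <functional>
#include <utility>
#include <vector>

// Minimal sketch of pairwise combination of leftover scalar reduction values.
// Adjacent elements are combined each round; an odd trailing element is
// carried over unchanged to the next round (as FinalGen does with
// InstVals.back()), until a single value remains.
int reducePairwise(std::vector<int> Leftovers,
                   const std::function<int(int, int)> &Op) {
  while (Leftovers.size() > 1) {
    std::vector<int> Next;
    Next.reserve(Leftovers.size() / 2 + Leftovers.size() % 2);
    for (size_t I = 0, E = (Leftovers.size() / 2) * 2; I < E; I += 2)
      Next.push_back(Op(Leftovers[I], Leftovers[I + 1]));
    if (Leftovers.size() % 2 == 1)
      Next.push_back(Leftovers.back());
    Leftovers = std::move(Next);
  }
  return Leftovers.empty() ? 0 : Leftovers.front();
}

// Usage example: reducePairwise({1, 2, 3, 4, 5}, std::plus<int>{}) == 15,
// computed as (1+2), (3+4), 5 -> (3+7), 5 -> (10+5). The pairwise order
// shortens the dependency chain of the remaining scalar ops compared to a
// strictly serial accumulation.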
@@ -9579,8 +11776,8 @@ bool SLPVectorizerPass::vectorizeRootInstruction(PHINode *P, Value *V, auto &&ExtraVectorization = [this](Instruction *I, BoUpSLP &R) -> bool { return tryToVectorize(I, R); }; - return tryToVectorizeHorReductionOrInstOperands(P, I, BB, R, TTI, - ExtraVectorization); + return tryToVectorizeHorReductionOrInstOperands(P, I, BB, R, TTI, *SE, *DL, + *TLI, ExtraVectorization); } bool SLPVectorizerPass::vectorizeInsertValueInst(InsertValueInst *IVI, @@ -9748,12 +11945,16 @@ bool SLPVectorizerPass::vectorizeSimpleInstructions( for (auto *I : reverse(Instructions)) { if (R.isDeleted(I)) continue; - if (auto *LastInsertValue = dyn_cast<InsertValueInst>(I)) + if (auto *LastInsertValue = dyn_cast<InsertValueInst>(I)) { OpsChanged |= vectorizeInsertValueInst(LastInsertValue, BB, R); - else if (auto *LastInsertElem = dyn_cast<InsertElementInst>(I)) + } else if (auto *LastInsertElem = dyn_cast<InsertElementInst>(I)) { OpsChanged |= vectorizeInsertElementInst(LastInsertElem, BB, R); - else if (isa<CmpInst>(I)) + } else if (isa<CmpInst>(I)) { PostponedCmps.push_back(I); + continue; + } + // Try to find reductions in buildvector sequnces. + OpsChanged |= vectorizeRootInstruction(nullptr, I, BB, R, TTI); } if (AtTerminator) { // Try to find reductions first. @@ -10171,7 +12372,7 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { DomTreeNodeBase<llvm::BasicBlock> *NodeI2 = DT->getNode(I2->getParent()); assert(NodeI1 && "Should only process reachable instructions"); - assert(NodeI1 && "Should only process reachable instructions"); + assert(NodeI2 && "Should only process reachable instructions"); assert((NodeI1 == NodeI2) == (NodeI1->getDFSNumIn() == NodeI2->getDFSNumIn()) && "Different nodes should have different DFS numbers"); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h index 8822c0004eb2..97f2b1a93815 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h @@ -72,17 +72,17 @@ class VPRecipeBuilder { VPRecipeBase *tryToWidenMemory(Instruction *I, ArrayRef<VPValue *> Operands, VFRange &Range, VPlanPtr &Plan); - /// Check if an induction recipe should be constructed for \I. If so build and - /// return it. If not, return null. - VPWidenIntOrFpInductionRecipe * - tryToOptimizeInductionPHI(PHINode *Phi, ArrayRef<VPValue *> Operands, - VFRange &Range) const; + /// Check if an induction recipe should be constructed for \p Phi. If so build + /// and return it. If not, return null. + VPRecipeBase *tryToOptimizeInductionPHI(PHINode *Phi, + ArrayRef<VPValue *> Operands, + VPlan &Plan, VFRange &Range); /// Optimize the special case where the operand of \p I is a constant integer /// induction variable. VPWidenIntOrFpInductionRecipe * tryToOptimizeInductionTruncate(TruncInst *I, ArrayRef<VPValue *> Operands, - VFRange &Range, VPlan &Plan) const; + VFRange &Range, VPlan &Plan); /// Handle non-loop phi nodes. Return a VPValue, if all incoming values match /// or a new VPBlendRecipe otherwise. 
Currently all such phi nodes are turned diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.cpp index 342d4a074e10..4d709097c306 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -23,11 +23,10 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" -#include "llvm/Analysis/IVDescriptors.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" -#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Type.h" @@ -35,13 +34,13 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/GenericDomTreeConstruction.h" #include "llvm/Support/GraphWriter.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/LoopVersioning.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include <cassert> -#include <iterator> #include <string> #include <vector> @@ -60,7 +59,7 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, const VPValue &V) { } #endif -Value *VPLane::getAsRuntimeExpr(IRBuilder<> &Builder, +Value *VPLane::getAsRuntimeExpr(IRBuilderBase &Builder, const ElementCount &VF) const { switch (LaneKind) { case VPLane::Kind::ScalableLast: @@ -158,25 +157,25 @@ void VPBlockBase::setPlan(VPlan *ParentPlan) { } /// \return the VPBasicBlock that is the exit of Block, possibly indirectly. -const VPBasicBlock *VPBlockBase::getExitBasicBlock() const { +const VPBasicBlock *VPBlockBase::getExitingBasicBlock() const { const VPBlockBase *Block = this; while (const VPRegionBlock *Region = dyn_cast<VPRegionBlock>(Block)) - Block = Region->getExit(); + Block = Region->getExiting(); return cast<VPBasicBlock>(Block); } -VPBasicBlock *VPBlockBase::getExitBasicBlock() { +VPBasicBlock *VPBlockBase::getExitingBasicBlock() { VPBlockBase *Block = this; while (VPRegionBlock *Region = dyn_cast<VPRegionBlock>(Block)) - Block = Region->getExit(); + Block = Region->getExiting(); return cast<VPBasicBlock>(Block); } VPBlockBase *VPBlockBase::getEnclosingBlockWithSuccessors() { if (!Successors.empty() || !Parent) return this; - assert(Parent->getExit() == this && - "Block w/o successors not the exit of its parent."); + assert(Parent->getExiting() == this && + "Block w/o successors not the exiting block of its parent."); return Parent->getEnclosingBlockWithSuccessors(); } @@ -188,28 +187,6 @@ VPBlockBase *VPBlockBase::getEnclosingBlockWithPredecessors() { return Parent->getEnclosingBlockWithPredecessors(); } -VPValue *VPBlockBase::getCondBit() { - return CondBitUser.getSingleOperandOrNull(); -} - -const VPValue *VPBlockBase::getCondBit() const { - return CondBitUser.getSingleOperandOrNull(); -} - -void VPBlockBase::setCondBit(VPValue *CV) { CondBitUser.resetSingleOpUser(CV); } - -VPValue *VPBlockBase::getPredicate() { - return PredicateUser.getSingleOperandOrNull(); -} - -const VPValue *VPBlockBase::getPredicate() const { - return PredicateUser.getSingleOperandOrNull(); -} - -void VPBlockBase::setPredicate(VPValue *CV) { - PredicateUser.resetSingleOpUser(CV); -} - void VPBlockBase::deleteCFG(VPBlockBase *Entry) { SmallVector<VPBlockBase *, 8> Blocks(depth_first(Entry)); @@ -245,6 +222,52 @@ Value 
*VPTransformState::get(VPValue *Def, const VPIteration &Instance) { // set(Def, Extract, Instance); return Extract; } +BasicBlock *VPTransformState::CFGState::getPreheaderBBFor(VPRecipeBase *R) { + VPRegionBlock *LoopRegion = R->getParent()->getEnclosingLoopRegion(); + return VPBB2IRBB[LoopRegion->getPreheaderVPBB()]; +} + +void VPTransformState::addNewMetadata(Instruction *To, + const Instruction *Orig) { + // If the loop was versioned with memchecks, add the corresponding no-alias + // metadata. + if (LVer && (isa<LoadInst>(Orig) || isa<StoreInst>(Orig))) + LVer->annotateInstWithNoAlias(To, Orig); +} + +void VPTransformState::addMetadata(Instruction *To, Instruction *From) { + propagateMetadata(To, From); + addNewMetadata(To, From); +} + +void VPTransformState::addMetadata(ArrayRef<Value *> To, Instruction *From) { + for (Value *V : To) { + if (Instruction *I = dyn_cast<Instruction>(V)) + addMetadata(I, From); + } +} + +void VPTransformState::setDebugLocFromInst(const Value *V) { + if (const Instruction *Inst = dyn_cast_or_null<Instruction>(V)) { + const DILocation *DIL = Inst->getDebugLoc(); + + // When a FSDiscriminator is enabled, we don't need to add the multiply + // factors to the discriminators. + if (DIL && Inst->getFunction()->isDebugInfoForProfiling() && + !isa<DbgInfoIntrinsic>(Inst) && !EnableFSDiscriminator) { + // FIXME: For scalable vectors, assume vscale=1. + auto NewDIL = + DIL->cloneByMultiplyingDuplicationFactor(UF * VF.getKnownMinValue()); + if (NewDIL) + Builder.SetCurrentDebugLocation(*NewDIL); + else + LLVM_DEBUG(dbgs() << "Failed to create new discriminator: " + << DIL->getFilename() << " Line: " << DIL->getLine()); + } else + Builder.SetCurrentDebugLocation(DIL); + } else + Builder.SetCurrentDebugLocation(DebugLoc()); +} BasicBlock * VPBasicBlock::createEmptyBasicBlock(VPTransformState::CFGState &CFG) { @@ -252,43 +275,36 @@ VPBasicBlock::createEmptyBasicBlock(VPTransformState::CFGState &CFG) { // Pred stands for Predessor. Prev stands for Previous - last visited/created. BasicBlock *PrevBB = CFG.PrevBB; BasicBlock *NewBB = BasicBlock::Create(PrevBB->getContext(), getName(), - PrevBB->getParent(), CFG.LastBB); + PrevBB->getParent(), CFG.ExitBB); LLVM_DEBUG(dbgs() << "LV: created " << NewBB->getName() << '\n'); // Hook up the new basic block to its predecessors. for (VPBlockBase *PredVPBlock : getHierarchicalPredecessors()) { - VPBasicBlock *PredVPBB = PredVPBlock->getExitBasicBlock(); - auto &PredVPSuccessors = PredVPBB->getSuccessors(); + VPBasicBlock *PredVPBB = PredVPBlock->getExitingBasicBlock(); + auto &PredVPSuccessors = PredVPBB->getHierarchicalSuccessors(); BasicBlock *PredBB = CFG.VPBB2IRBB[PredVPBB]; - // In outer loop vectorization scenario, the predecessor BBlock may not yet - // be visited(backedge). Mark the VPBasicBlock for fixup at the end of - // vectorization. We do not encounter this case in inner loop vectorization - // as we start out by building a loop skeleton with the vector loop header - // and latch blocks. As a result, we never enter this function for the - // header block in the non VPlan-native path. 
- if (!PredBB) { - assert(EnableVPlanNativePath && - "Unexpected null predecessor in non VPlan-native path"); - CFG.VPBBsToFix.push_back(PredVPBB); - continue; - } - assert(PredBB && "Predecessor basic-block not found building successor."); auto *PredBBTerminator = PredBB->getTerminator(); LLVM_DEBUG(dbgs() << "LV: draw edge from" << PredBB->getName() << '\n'); + + auto *TermBr = dyn_cast<BranchInst>(PredBBTerminator); if (isa<UnreachableInst>(PredBBTerminator)) { assert(PredVPSuccessors.size() == 1 && "Predecessor ending w/o branch must have single successor."); + DebugLoc DL = PredBBTerminator->getDebugLoc(); PredBBTerminator->eraseFromParent(); - BranchInst::Create(NewBB, PredBB); + auto *Br = BranchInst::Create(NewBB, PredBB); + Br->setDebugLoc(DL); + } else if (TermBr && !TermBr->isConditional()) { + TermBr->setSuccessor(0, NewBB); } else { - assert(PredVPSuccessors.size() == 2 && - "Predecessor ending with branch must have two successors."); + // Set each forward successor here when it is created, excluding + // backedges. A backward successor is set when the branch is created. unsigned idx = PredVPSuccessors.front() == this ? 0 : 1; - assert(!PredBBTerminator->getSuccessor(idx) && + assert(!TermBr->getSuccessor(idx) && "Trying to reset an existing successor block."); - PredBBTerminator->setSuccessor(idx, NewBB); + TermBr->setSuccessor(idx, NewBB); } } return NewBB; @@ -300,27 +316,51 @@ void VPBasicBlock::execute(VPTransformState *State) { VPBlockBase *SingleHPred = nullptr; BasicBlock *NewBB = State->CFG.PrevBB; // Reuse it if possible. - // 1. Create an IR basic block, or reuse the last one if possible. - // The last IR basic block is reused, as an optimization, in three cases: - // A. the first VPBB reuses the loop header BB - when PrevVPBB is null; - // B. when the current VPBB has a single (hierarchical) predecessor which - // is PrevVPBB and the latter has a single (hierarchical) successor; and - // C. when the current VPBB is an entry of a region replica - where PrevVPBB - // is the exit of this region from a previous instance, or the predecessor - // of this region. - if (PrevVPBB && /* A */ - !((SingleHPred = getSingleHierarchicalPredecessor()) && - SingleHPred->getExitBasicBlock() == PrevVPBB && - PrevVPBB->getSingleHierarchicalSuccessor()) && /* B */ - !(Replica && getPredecessors().empty())) { /* C */ + auto IsLoopRegion = [](VPBlockBase *BB) { + auto *R = dyn_cast<VPRegionBlock>(BB); + return R && !R->isReplicator(); + }; + + // 1. Create an IR basic block, or reuse the last one or ExitBB if possible. + if (getPlan()->getVectorLoopRegion()->getSingleSuccessor() == this) { + // ExitBB can be re-used for the exit block of the Plan. + NewBB = State->CFG.ExitBB; + State->CFG.PrevBB = NewBB; + + // Update the branch instruction in the predecessor to branch to ExitBB. + VPBlockBase *PredVPB = getSingleHierarchicalPredecessor(); + VPBasicBlock *ExitingVPBB = PredVPB->getExitingBasicBlock(); + assert(PredVPB->getSingleSuccessor() == this && + "predecessor must have the current block as only successor"); + BasicBlock *ExitingBB = State->CFG.VPBB2IRBB[ExitingVPBB]; + // The Exit block of a loop is always set to be successor 0 of the Exiting + // block. 
+ cast<BranchInst>(ExitingBB->getTerminator())->setSuccessor(0, NewBB); + } else if (PrevVPBB && /* A */ + !((SingleHPred = getSingleHierarchicalPredecessor()) && + SingleHPred->getExitingBasicBlock() == PrevVPBB && + PrevVPBB->getSingleHierarchicalSuccessor() && + (SingleHPred->getParent() == getEnclosingLoopRegion() && + !IsLoopRegion(SingleHPred))) && /* B */ + !(Replica && getPredecessors().empty())) { /* C */ + // The last IR basic block is reused, as an optimization, in three cases: + // A. the first VPBB reuses the loop pre-header BB - when PrevVPBB is null; + // B. when the current VPBB has a single (hierarchical) predecessor which + // is PrevVPBB and the latter has a single (hierarchical) successor which + // both are in the same non-replicator region; and + // C. when the current VPBB is an entry of a region replica - where PrevVPBB + // is the exiting VPBB of this region from a previous instance, or the + // predecessor of this region. + NewBB = createEmptyBasicBlock(State->CFG); State->Builder.SetInsertPoint(NewBB); // Temporarily terminate with unreachable until CFG is rewired. UnreachableInst *Terminator = State->Builder.CreateUnreachable(); + // Register NewBB in its loop. In innermost loops its the same for all + // BB's. + if (State->CurrentVectorLoop) + State->CurrentVectorLoop->addBasicBlockToLoop(NewBB, *State->LI); State->Builder.SetInsertPoint(Terminator); - // Register NewBB in its loop. In innermost loops its the same for all BB's. - Loop *L = State->LI->getLoopFor(State->CFG.LastBB); - L->addBasicBlockToLoop(NewBB, *State->LI); State->CFG.PrevBB = NewBB; } @@ -334,29 +374,6 @@ void VPBasicBlock::execute(VPTransformState *State) { for (VPRecipeBase &Recipe : Recipes) Recipe.execute(*State); - VPValue *CBV; - if (EnableVPlanNativePath && (CBV = getCondBit())) { - assert(CBV->getUnderlyingValue() && - "Unexpected null underlying value for condition bit"); - - // Condition bit value in a VPBasicBlock is used as the branch selector. In - // the VPlan-native path case, since all branches are uniform we generate a - // branch instruction using the condition value from vector lane 0 and dummy - // successors. The successors are fixed later when the successor blocks are - // visited. - Value *NewCond = State->get(CBV, {0, 0}); - - // Replace the temporary unreachable terminator with the new conditional - // branch. 
- auto *CurrentTerminator = NewBB->getTerminator(); - assert(isa<UnreachableInst>(CurrentTerminator) && - "Expected to replace unreachable terminator with conditional " - "branch."); - auto *CondBr = BranchInst::Create(NewBB, nullptr, NewCond); - CondBr->setSuccessor(0, nullptr); - ReplaceInstWithInst(CurrentTerminator, CondBr); - } - LLVM_DEBUG(dbgs() << "LV: filled BB:" << *NewBB); } @@ -395,6 +412,61 @@ VPBasicBlock *VPBasicBlock::splitAt(iterator SplitAt) { return SplitBlock; } +VPRegionBlock *VPBasicBlock::getEnclosingLoopRegion() { + VPRegionBlock *P = getParent(); + if (P && P->isReplicator()) { + P = P->getParent(); + assert(!cast<VPRegionBlock>(P)->isReplicator() && + "unexpected nested replicate regions"); + } + return P; +} + +static bool hasConditionalTerminator(const VPBasicBlock *VPBB) { + if (VPBB->empty()) { + assert( + VPBB->getNumSuccessors() < 2 && + "block with multiple successors doesn't have a recipe as terminator"); + return false; + } + + const VPRecipeBase *R = &VPBB->back(); + auto *VPI = dyn_cast<VPInstruction>(R); + bool IsCondBranch = + isa<VPBranchOnMaskRecipe>(R) || + (VPI && (VPI->getOpcode() == VPInstruction::BranchOnCond || + VPI->getOpcode() == VPInstruction::BranchOnCount)); + (void)IsCondBranch; + + if (VPBB->getNumSuccessors() >= 2 || VPBB->isExiting()) { + assert(IsCondBranch && "block with multiple successors not terminated by " + "conditional branch recipe"); + + return true; + } + + assert( + !IsCondBranch && + "block with 0 or 1 successors terminated by conditional branch recipe"); + return false; +} + +VPRecipeBase *VPBasicBlock::getTerminator() { + if (hasConditionalTerminator(this)) + return &back(); + return nullptr; +} + +const VPRecipeBase *VPBasicBlock::getTerminator() const { + if (hasConditionalTerminator(this)) + return &back(); + return nullptr; +} + +bool VPBasicBlock::isExiting() const { + return getParent()->getExitingBasicBlock() == this; +} + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPBlockBase::printSuccessors(raw_ostream &O, const Twine &Indent) const { if (getSuccessors().empty()) { @@ -411,13 +483,6 @@ void VPBlockBase::printSuccessors(raw_ostream &O, const Twine &Indent) const { void VPBasicBlock::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent << getName() << ":\n"; - if (const VPValue *Pred = getPredicate()) { - O << Indent << "BlockPredicate:"; - Pred->printAsOperand(O, SlotTracker); - if (const auto *PredInst = dyn_cast<VPInstruction>(Pred)) - O << " (" << PredInst->getParent()->getName() << ")"; - O << '\n'; - } auto RecipeIndent = Indent + " "; for (const VPRecipeBase &Recipe : *this) { @@ -426,14 +491,6 @@ void VPBasicBlock::print(raw_ostream &O, const Twine &Indent, } printSuccessors(O, Indent); - - if (const VPValue *CBV = getCondBit()) { - O << Indent << "CondBit: "; - CBV->printAsOperand(O, SlotTracker); - if (const auto *CBI = dyn_cast<VPInstruction>(CBV)) - O << " (" << CBI->getParent()->getName() << ")"; - O << '\n'; - } } #endif @@ -448,25 +505,26 @@ void VPRegionBlock::execute(VPTransformState *State) { ReversePostOrderTraversal<VPBlockBase *> RPOT(Entry); if (!isReplicator()) { + // Create and register the new vector loop. 
+ Loop *PrevLoop = State->CurrentVectorLoop; + State->CurrentVectorLoop = State->LI->AllocateLoop(); + BasicBlock *VectorPH = State->CFG.VPBB2IRBB[getPreheaderVPBB()]; + Loop *ParentLoop = State->LI->getLoopFor(VectorPH); + + // Insert the new loop into the loop nest and register the new basic blocks + // before calling any utilities such as SCEV that require valid LoopInfo. + if (ParentLoop) + ParentLoop->addChildLoop(State->CurrentVectorLoop); + else + State->LI->addTopLevelLoop(State->CurrentVectorLoop); + // Visit the VPBlocks connected to "this", starting from it. for (VPBlockBase *Block : RPOT) { - if (EnableVPlanNativePath) { - // The inner loop vectorization path does not represent loop preheader - // and exit blocks as part of the VPlan. In the VPlan-native path, skip - // vectorizing loop preheader block. In future, we may replace this - // check with the check for loop preheader. - if (Block->getNumPredecessors() == 0) - continue; - - // Skip vectorizing loop exit block. In future, we may replace this - // check with the check for loop exit. - if (Block->getNumSuccessors() == 0) - continue; - } - LLVM_DEBUG(dbgs() << "LV: VPBlock in RPO " << Block->getName() << '\n'); Block->execute(State); } + + State->CurrentVectorLoop = PrevLoop; return; } @@ -508,341 +566,32 @@ void VPRegionBlock::print(raw_ostream &O, const Twine &Indent, } #endif -bool VPRecipeBase::mayWriteToMemory() const { - switch (getVPDefID()) { - case VPWidenMemoryInstructionSC: { - return cast<VPWidenMemoryInstructionRecipe>(this)->isStore(); - } - case VPReplicateSC: - case VPWidenCallSC: - return cast<Instruction>(getVPSingleValue()->getUnderlyingValue()) - ->mayWriteToMemory(); - case VPBranchOnMaskSC: - return false; - case VPWidenIntOrFpInductionSC: - case VPWidenCanonicalIVSC: - case VPWidenPHISC: - case VPBlendSC: - case VPWidenSC: - case VPWidenGEPSC: - case VPReductionSC: - case VPWidenSelectSC: { - const Instruction *I = - dyn_cast_or_null<Instruction>(getVPSingleValue()->getUnderlyingValue()); - (void)I; - assert((!I || !I->mayWriteToMemory()) && - "underlying instruction may write to memory"); - return false; - } - default: - return true; - } -} - -bool VPRecipeBase::mayReadFromMemory() const { - switch (getVPDefID()) { - case VPWidenMemoryInstructionSC: { - return !cast<VPWidenMemoryInstructionRecipe>(this)->isStore(); - } - case VPReplicateSC: - case VPWidenCallSC: - return cast<Instruction>(getVPSingleValue()->getUnderlyingValue()) - ->mayReadFromMemory(); - case VPBranchOnMaskSC: - return false; - case VPWidenIntOrFpInductionSC: - case VPWidenCanonicalIVSC: - case VPWidenPHISC: - case VPBlendSC: - case VPWidenSC: - case VPWidenGEPSC: - case VPReductionSC: - case VPWidenSelectSC: { - const Instruction *I = - dyn_cast_or_null<Instruction>(getVPSingleValue()->getUnderlyingValue()); - (void)I; - assert((!I || !I->mayReadFromMemory()) && - "underlying instruction may read from memory"); - return false; - } - default: - return true; - } -} - -bool VPRecipeBase::mayHaveSideEffects() const { - switch (getVPDefID()) { - case VPBranchOnMaskSC: - return false; - case VPWidenIntOrFpInductionSC: - case VPWidenCanonicalIVSC: - case VPWidenPHISC: - case VPBlendSC: - case VPWidenSC: - case VPWidenGEPSC: - case VPReductionSC: - case VPWidenSelectSC: { - const Instruction *I = - dyn_cast_or_null<Instruction>(getVPSingleValue()->getUnderlyingValue()); - (void)I; - assert((!I || !I->mayHaveSideEffects()) && - "underlying instruction has side-effects"); - return false; - } - case VPReplicateSC: { - auto *R = 
cast<VPReplicateRecipe>(this); - return R->getUnderlyingInstr()->mayHaveSideEffects(); - } - default: - return true; - } -} - -void VPRecipeBase::insertBefore(VPRecipeBase *InsertPos) { - assert(!Parent && "Recipe already in some VPBasicBlock"); - assert(InsertPos->getParent() && - "Insertion position not in any VPBasicBlock"); - Parent = InsertPos->getParent(); - Parent->getRecipeList().insert(InsertPos->getIterator(), this); -} - -void VPRecipeBase::insertAfter(VPRecipeBase *InsertPos) { - assert(!Parent && "Recipe already in some VPBasicBlock"); - assert(InsertPos->getParent() && - "Insertion position not in any VPBasicBlock"); - Parent = InsertPos->getParent(); - Parent->getRecipeList().insertAfter(InsertPos->getIterator(), this); -} - -void VPRecipeBase::removeFromParent() { - assert(getParent() && "Recipe not in any VPBasicBlock"); - getParent()->getRecipeList().remove(getIterator()); - Parent = nullptr; -} - -iplist<VPRecipeBase>::iterator VPRecipeBase::eraseFromParent() { - assert(getParent() && "Recipe not in any VPBasicBlock"); - return getParent()->getRecipeList().erase(getIterator()); -} - -void VPRecipeBase::moveAfter(VPRecipeBase *InsertPos) { - removeFromParent(); - insertAfter(InsertPos); -} - -void VPRecipeBase::moveBefore(VPBasicBlock &BB, - iplist<VPRecipeBase>::iterator I) { - assert(I == BB.end() || I->getParent() == &BB); - removeFromParent(); - Parent = &BB; - BB.getRecipeList().insert(I, this); -} - -void VPInstruction::generateInstruction(VPTransformState &State, - unsigned Part) { - IRBuilder<> &Builder = State.Builder; - Builder.SetCurrentDebugLocation(DL); - - if (Instruction::isBinaryOp(getOpcode())) { - Value *A = State.get(getOperand(0), Part); - Value *B = State.get(getOperand(1), Part); - Value *V = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B); - State.set(this, V, Part); - return; - } - - switch (getOpcode()) { - case VPInstruction::Not: { - Value *A = State.get(getOperand(0), Part); - Value *V = Builder.CreateNot(A); - State.set(this, V, Part); - break; - } - case VPInstruction::ICmpULE: { - Value *IV = State.get(getOperand(0), Part); - Value *TC = State.get(getOperand(1), Part); - Value *V = Builder.CreateICmpULE(IV, TC); - State.set(this, V, Part); - break; - } - case Instruction::Select: { - Value *Cond = State.get(getOperand(0), Part); - Value *Op1 = State.get(getOperand(1), Part); - Value *Op2 = State.get(getOperand(2), Part); - Value *V = Builder.CreateSelect(Cond, Op1, Op2); - State.set(this, V, Part); - break; - } - case VPInstruction::ActiveLaneMask: { - // Get first lane of vector induction variable. - Value *VIVElem0 = State.get(getOperand(0), VPIteration(Part, 0)); - // Get the original loop tripcount. - Value *ScalarTC = State.get(getOperand(1), Part); - - auto *Int1Ty = Type::getInt1Ty(Builder.getContext()); - auto *PredTy = VectorType::get(Int1Ty, State.VF); - Instruction *Call = Builder.CreateIntrinsic( - Intrinsic::get_active_lane_mask, {PredTy, ScalarTC->getType()}, - {VIVElem0, ScalarTC}, nullptr, "active.lane.mask"); - State.set(this, Call, Part); - break; - } - case VPInstruction::FirstOrderRecurrenceSplice: { - // Generate code to combine the previous and current values in vector v3. 
- // - // vector.ph: - // v_init = vector(..., ..., ..., a[-1]) - // br vector.body - // - // vector.body - // i = phi [0, vector.ph], [i+4, vector.body] - // v1 = phi [v_init, vector.ph], [v2, vector.body] - // v2 = a[i, i+1, i+2, i+3]; - // v3 = vector(v1(3), v2(0, 1, 2)) - - // For the first part, use the recurrence phi (v1), otherwise v2. - auto *V1 = State.get(getOperand(0), 0); - Value *PartMinus1 = Part == 0 ? V1 : State.get(getOperand(1), Part - 1); - if (!PartMinus1->getType()->isVectorTy()) { - State.set(this, PartMinus1, Part); - } else { - Value *V2 = State.get(getOperand(1), Part); - State.set(this, Builder.CreateVectorSplice(PartMinus1, V2, -1), Part); - } - break; - } - - case VPInstruction::CanonicalIVIncrement: - case VPInstruction::CanonicalIVIncrementNUW: { - Value *Next = nullptr; - if (Part == 0) { - bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementNUW; - auto *Phi = State.get(getOperand(0), 0); - // The loop step is equal to the vectorization factor (num of SIMD - // elements) times the unroll factor (num of SIMD instructions). - Value *Step = - createStepForVF(Builder, Phi->getType(), State.VF, State.UF); - Next = Builder.CreateAdd(Phi, Step, "index.next", IsNUW, false); - } else { - Next = State.get(this, 0); - } - - State.set(this, Next, Part); - break; - } - case VPInstruction::BranchOnCount: { - if (Part != 0) - break; - // First create the compare. - Value *IV = State.get(getOperand(0), Part); - Value *TC = State.get(getOperand(1), Part); - Value *Cond = Builder.CreateICmpEQ(IV, TC); - - // Now create the branch. - auto *Plan = getParent()->getPlan(); - VPRegionBlock *TopRegion = Plan->getVectorLoopRegion(); - VPBasicBlock *Header = TopRegion->getEntry()->getEntryBasicBlock(); - if (Header->empty()) { - assert(EnableVPlanNativePath && - "empty entry block only expected in VPlanNativePath"); - Header = cast<VPBasicBlock>(Header->getSingleSuccessor()); +void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, + Value *CanonicalIVStartValue, + VPTransformState &State, + bool IsEpilogueVectorization) { + + VPBasicBlock *ExitingVPBB = getVectorLoopRegion()->getExitingBasicBlock(); + auto *Term = dyn_cast<VPInstruction>(&ExitingVPBB->back()); + // Try to simplify BranchOnCount to 'BranchOnCond true' if TC <= VF * UF when + // preparing to execute the plan for the main vector loop. + if (!IsEpilogueVectorization && Term && + Term->getOpcode() == VPInstruction::BranchOnCount && + isa<ConstantInt>(TripCountV)) { + ConstantInt *C = cast<ConstantInt>(TripCountV); + uint64_t TCVal = C->getZExtValue(); + if (TCVal && TCVal <= State.VF.getKnownMinValue() * State.UF) { + auto *BOC = + new VPInstruction(VPInstruction::BranchOnCond, + {getOrAddExternalDef(State.Builder.getTrue())}); + Term->eraseFromParent(); + ExitingVPBB->appendRecipe(BOC); + // TODO: Further simplifications are possible + // 1. Replace inductions with constants. + // 2. Replace vector loop region with VPBasicBlock. } - // TODO: Once the exit block is modeled in VPlan, use it instead of going - // through State.CFG.LastBB. 
- BasicBlock *Exit = - cast<BranchInst>(State.CFG.LastBB->getTerminator())->getSuccessor(0); - - Builder.CreateCondBr(Cond, Exit, State.CFG.VPBB2IRBB[Header]); - Builder.GetInsertBlock()->getTerminator()->eraseFromParent(); - break; - } - default: - llvm_unreachable("Unsupported opcode for instruction"); - } -} - -void VPInstruction::execute(VPTransformState &State) { - assert(!State.Instance && "VPInstruction executing an Instance"); - IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder); - State.Builder.setFastMathFlags(FMF); - for (unsigned Part = 0; Part < State.UF; ++Part) - generateInstruction(State, Part); -} - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void VPInstruction::dump() const { - VPSlotTracker SlotTracker(getParent()->getPlan()); - print(dbgs(), "", SlotTracker); -} - -void VPInstruction::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "EMIT "; - - if (hasResult()) { - printAsOperand(O, SlotTracker); - O << " = "; - } - - switch (getOpcode()) { - case VPInstruction::Not: - O << "not"; - break; - case VPInstruction::ICmpULE: - O << "icmp ule"; - break; - case VPInstruction::SLPLoad: - O << "combined load"; - break; - case VPInstruction::SLPStore: - O << "combined store"; - break; - case VPInstruction::ActiveLaneMask: - O << "active lane mask"; - break; - case VPInstruction::FirstOrderRecurrenceSplice: - O << "first-order splice"; - break; - case VPInstruction::CanonicalIVIncrement: - O << "VF * UF + "; - break; - case VPInstruction::CanonicalIVIncrementNUW: - O << "VF * UF +(nuw) "; - break; - case VPInstruction::BranchOnCount: - O << "branch-on-count "; - break; - default: - O << Instruction::getOpcodeName(getOpcode()); - } - - O << FMF; - - for (const VPValue *Operand : operands()) { - O << " "; - Operand->printAsOperand(O, SlotTracker); } - if (DL) { - O << ", !dbg "; - DL.print(O); - } -} -#endif - -void VPInstruction::setFastMathFlags(FastMathFlags FMFNew) { - // Make sure the VPInstruction is a floating-point operation. - assert((Opcode == Instruction::FAdd || Opcode == Instruction::FMul || - Opcode == Instruction::FNeg || Opcode == Instruction::FSub || - Opcode == Instruction::FDiv || Opcode == Instruction::FRem || - Opcode == Instruction::FCmp) && - "this op can't take fast-math flags"); - FMF = FMFNew; -} - -void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, - Value *CanonicalIVStartValue, - VPTransformState &State) { // Check if the trip count is needed, and if so build it. if (TripCount && TripCount->getNumUsers()) { for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part) @@ -868,111 +617,78 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, // When vectorizing the epilogue loop, the canonical induction start value // needs to be changed from zero to the value after the main vector loop. 
if (CanonicalIVStartValue) { - VPValue *VPV = new VPValue(CanonicalIVStartValue); - addExternalDef(VPV); + VPValue *VPV = getOrAddExternalDef(CanonicalIVStartValue); auto *IV = getCanonicalIV(); assert(all_of(IV->users(), [](const VPUser *U) { + if (isa<VPScalarIVStepsRecipe>(U)) + return true; auto *VPI = cast<VPInstruction>(U); return VPI->getOpcode() == VPInstruction::CanonicalIVIncrement || VPI->getOpcode() == VPInstruction::CanonicalIVIncrementNUW; }) && - "the canonical IV should only be used by its increments when " + "the canonical IV should only be used by its increments or " + "ScalarIVSteps when " "resetting the start value"); IV->setOperand(0, VPV); } } -/// Generate the code inside the body of the vectorized loop. Assumes a single -/// LoopVectorBody basic-block was created for this. Introduce additional -/// basic-blocks as needed, and fill them all. +/// Generate the code inside the preheader and body of the vectorized loop. +/// Assumes a single pre-header basic-block was created for this. Introduce +/// additional basic-blocks as needed, and fill them all. void VPlan::execute(VPTransformState *State) { - // 0. Set the reverse mapping from VPValues to Values for code generation. + // Set the reverse mapping from VPValues to Values for code generation. for (auto &Entry : Value2VPValue) State->VPValue2Value[Entry.second] = Entry.first; - BasicBlock *VectorPreHeaderBB = State->CFG.PrevBB; - State->CFG.VectorPreHeader = VectorPreHeaderBB; - BasicBlock *VectorHeaderBB = VectorPreHeaderBB->getSingleSuccessor(); - assert(VectorHeaderBB && "Loop preheader does not have a single successor."); - - // 1. Make room to generate basic-blocks inside loop body if needed. - BasicBlock *VectorLatchBB = VectorHeaderBB->splitBasicBlock( - VectorHeaderBB->getFirstInsertionPt(), "vector.body.latch"); - Loop *L = State->LI->getLoopFor(VectorHeaderBB); - L->addBasicBlockToLoop(VectorLatchBB, *State->LI); - // Remove the edge between Header and Latch to allow other connections. - // Temporarily terminate with unreachable until CFG is rewired. - // Note: this asserts the generated code's assumption that - // getFirstInsertionPt() can be dereferenced into an Instruction. - VectorHeaderBB->getTerminator()->eraseFromParent(); - State->Builder.SetInsertPoint(VectorHeaderBB); - UnreachableInst *Terminator = State->Builder.CreateUnreachable(); - State->Builder.SetInsertPoint(Terminator); - - // 2. Generate code in loop body. + // Initialize CFG state. State->CFG.PrevVPBB = nullptr; - State->CFG.PrevBB = VectorHeaderBB; - State->CFG.LastBB = VectorLatchBB; + State->CFG.ExitBB = State->CFG.PrevBB->getSingleSuccessor(); + BasicBlock *VectorPreHeader = State->CFG.PrevBB; + State->Builder.SetInsertPoint(VectorPreHeader->getTerminator()); + // Generate code in the loop pre-header and body. for (VPBlockBase *Block : depth_first(Entry)) Block->execute(State); - // Setup branch terminator successors for VPBBs in VPBBsToFix based on - // VPBB's successors. - for (auto VPBB : State->CFG.VPBBsToFix) { - assert(EnableVPlanNativePath && - "Unexpected VPBBsToFix in non VPlan-native path"); - BasicBlock *BB = State->CFG.VPBB2IRBB[VPBB]; - assert(BB && "Unexpected null basic block for VPBB"); - - unsigned Idx = 0; - auto *BBTerminator = BB->getTerminator(); - - for (VPBlockBase *SuccVPBlock : VPBB->getHierarchicalSuccessors()) { - VPBasicBlock *SuccVPBB = SuccVPBlock->getEntryBasicBlock(); - BBTerminator->setSuccessor(Idx, State->CFG.VPBB2IRBB[SuccVPBB]); - ++Idx; - } - } - - // 3. 
Merge the temporary latch created with the last basic-block filled. - BasicBlock *LastBB = State->CFG.PrevBB; - assert(isa<BranchInst>(LastBB->getTerminator()) && - "Expected VPlan CFG to terminate with branch"); - - // Move both the branch and check from LastBB to VectorLatchBB. - auto *LastBranch = cast<BranchInst>(LastBB->getTerminator()); - LastBranch->moveBefore(VectorLatchBB->getTerminator()); - VectorLatchBB->getTerminator()->eraseFromParent(); - // Move condition so it is guaranteed to be next to branch. This is only done - // to avoid excessive test updates. - // TODO: Remove special handling once the increments for all inductions are - // modeled explicitly in VPlan. - cast<Instruction>(LastBranch->getCondition())->moveBefore(LastBranch); - // Connect LastBB to VectorLatchBB to facilitate their merge. - BranchInst::Create(VectorLatchBB, LastBB); - - // Merge LastBB with Latch. - bool Merged = MergeBlockIntoPredecessor(VectorLatchBB, nullptr, State->LI); - (void)Merged; - assert(Merged && "Could not merge last basic block with latch."); - VectorLatchBB = LastBB; + VPBasicBlock *LatchVPBB = getVectorLoopRegion()->getExitingBasicBlock(); + BasicBlock *VectorLatchBB = State->CFG.VPBB2IRBB[LatchVPBB]; // Fix the latch value of canonical, reduction and first-order recurrences // phis in the vector loop. - VPBasicBlock *Header = Entry->getEntryBasicBlock(); - if (Header->empty()) { - assert(EnableVPlanNativePath); - Header = cast<VPBasicBlock>(Header->getSingleSuccessor()); - } + VPBasicBlock *Header = getVectorLoopRegion()->getEntryBasicBlock(); for (VPRecipeBase &R : Header->phis()) { // Skip phi-like recipes that generate their backedege values themselves. - // TODO: Model their backedge values explicitly. - if (isa<VPWidenIntOrFpInductionRecipe>(&R) || isa<VPWidenPHIRecipe>(&R)) + if (isa<VPWidenPHIRecipe>(&R)) + continue; + + if (isa<VPWidenPointerInductionRecipe>(&R) || + isa<VPWidenIntOrFpInductionRecipe>(&R)) { + PHINode *Phi = nullptr; + if (isa<VPWidenIntOrFpInductionRecipe>(&R)) { + Phi = cast<PHINode>(State->get(R.getVPSingleValue(), 0)); + } else { + auto *WidenPhi = cast<VPWidenPointerInductionRecipe>(&R); + // TODO: Split off the case that all users of a pointer phi are scalar + // from the VPWidenPointerInductionRecipe. + if (WidenPhi->onlyScalarsGenerated(State->VF)) + continue; + + auto *GEP = cast<GetElementPtrInst>(State->get(WidenPhi, 0)); + Phi = cast<PHINode>(GEP->getPointerOperand()); + } + + Phi->setIncomingBlock(1, VectorLatchBB); + + // Move the last step to the end of the latch block. This ensures + // consistent placement of all induction updates. + Instruction *Inc = cast<Instruction>(Phi->getIncomingValue(1)); + Inc->moveBefore(VectorLatchBB->getTerminator()->getPrevNode()); continue; + } auto *PhiR = cast<VPHeaderPHIRecipe>(&R); // For canonical IV, first-order recurrences and in-order reduction phis, @@ -993,9 +709,12 @@ void VPlan::execute(VPTransformState *State) { } // We do not attempt to preserve DT for outer loop vectorization currently. 
- if (!EnableVPlanNativePath) - updateDominatorTree(State->DT, VectorPreHeaderBB, VectorLatchBB, - L->getExitBlock()); + if (!EnableVPlanNativePath) { + BasicBlock *VectorHeaderBB = State->CFG.VPBB2IRBB[Header]; + State->DT->addNewBlock(VectorHeaderBB, VectorPreHeader); + updateDominatorTree(State->DT, VectorHeaderBB, VectorLatchBB, + State->CFG.ExitBB); + } } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -1021,6 +740,17 @@ void VPlan::print(raw_ostream &O) const { O << '\n'; Block->print(O, "", SlotTracker); } + + if (!LiveOuts.empty()) + O << "\n"; + for (auto &KV : LiveOuts) { + O << "Live-out "; + KV.second->getPhi()->printAsOperand(O); + O << " = "; + KV.second->getOperand(0)->printAsOperand(O, SlotTracker); + O << "\n"; + } + O << "}\n"; } @@ -1034,11 +764,14 @@ LLVM_DUMP_METHOD void VPlan::dump() const { print(dbgs()); } #endif -void VPlan::updateDominatorTree(DominatorTree *DT, BasicBlock *LoopPreHeaderBB, +void VPlan::addLiveOut(PHINode *PN, VPValue *V) { + assert(LiveOuts.count(PN) == 0 && "an exit value for PN already exists"); + LiveOuts.insert({PN, new VPLiveOut(PN, V)}); +} + +void VPlan::updateDominatorTree(DominatorTree *DT, BasicBlock *LoopHeaderBB, BasicBlock *LoopLatchBB, BasicBlock *LoopExitBB) { - BasicBlock *LoopHeaderBB = LoopPreHeaderBB->getSingleSuccessor(); - assert(LoopHeaderBB && "Loop preheader does not have a single successor."); // The vector body may be more than a single basic-block by this point. // Update the dominator tree information inside the vector body by propagating // it from header to latch, expecting only triangular control-flow, if any. @@ -1075,6 +808,7 @@ void VPlan::updateDominatorTree(DominatorTree *DT, BasicBlock *LoopPreHeaderBB, } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + Twine VPlanPrinter::getUID(const VPBlockBase *Block) { return (isa<VPRegionBlock>(Block) ? "cluster_N" : "N") + Twine(getOrCreateBID(Block)); @@ -1122,8 +856,8 @@ void VPlanPrinter::dumpBlock(const VPBlockBase *Block) { void VPlanPrinter::drawEdge(const VPBlockBase *From, const VPBlockBase *To, bool Hidden, const Twine &Label) { // Due to "dot" we print an edge between two regions as an edge between the - // exit basic block and the entry basic of the respective regions. - const VPBlockBase *Tail = From->getExitBasicBlock(); + // exiting basic block and the entry basic of the respective regions. + const VPBlockBase *Tail = From->getExitingBasicBlock(); const VPBlockBase *Head = To->getEntryBasicBlock(); OS << Indent << getUID(Tail) << " -> " << getUID(Head); OS << " [ label=\"" << Label << '\"'; @@ -1213,328 +947,6 @@ void VPlanIngredient::print(raw_ostream &O) const { V->printAsOperand(O, false); } -void VPWidenCallRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "WIDEN-CALL "; - - auto *CI = cast<CallInst>(getUnderlyingInstr()); - if (CI->getType()->isVoidTy()) - O << "void "; - else { - printAsOperand(O, SlotTracker); - O << " = "; - } - - O << "call @" << CI->getCalledFunction()->getName() << "("; - printOperands(O, SlotTracker); - O << ")"; -} - -void VPWidenSelectRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "WIDEN-SELECT "; - printAsOperand(O, SlotTracker); - O << " = select "; - getOperand(0)->printAsOperand(O, SlotTracker); - O << ", "; - getOperand(1)->printAsOperand(O, SlotTracker); - O << ", "; - getOperand(2)->printAsOperand(O, SlotTracker); - O << (InvariantCond ? 
" (condition is loop invariant)" : ""); -} - -void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "WIDEN "; - printAsOperand(O, SlotTracker); - O << " = " << getUnderlyingInstr()->getOpcodeName() << " "; - printOperands(O, SlotTracker); -} - -void VPWidenIntOrFpInductionRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "WIDEN-INDUCTION"; - if (getTruncInst()) { - O << "\\l\""; - O << " +\n" << Indent << "\" " << VPlanIngredient(IV) << "\\l\""; - O << " +\n" << Indent << "\" "; - getVPValue(0)->printAsOperand(O, SlotTracker); - } else - O << " " << VPlanIngredient(IV); -} -#endif - -bool VPWidenIntOrFpInductionRecipe::isCanonical() const { - auto *StartC = dyn_cast<ConstantInt>(getStartValue()->getLiveInIRValue()); - auto *StepC = dyn_cast<SCEVConstant>(getInductionDescriptor().getStep()); - return StartC && StartC->isZero() && StepC && StepC->isOne(); -} - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void VPWidenGEPRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "WIDEN-GEP "; - O << (IsPtrLoopInvariant ? "Inv" : "Var"); - size_t IndicesNumber = IsIndexLoopInvariant.size(); - for (size_t I = 0; I < IndicesNumber; ++I) - O << "[" << (IsIndexLoopInvariant[I] ? "Inv" : "Var") << "]"; - - O << " "; - printAsOperand(O, SlotTracker); - O << " = getelementptr "; - printOperands(O, SlotTracker); -} - -void VPWidenPHIRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "WIDEN-PHI "; - - auto *OriginalPhi = cast<PHINode>(getUnderlyingValue()); - // Unless all incoming values are modeled in VPlan print the original PHI - // directly. - // TODO: Remove once all VPWidenPHIRecipe instances keep all relevant incoming - // values as VPValues. - if (getNumOperands() != OriginalPhi->getNumOperands()) { - O << VPlanIngredient(OriginalPhi); - return; - } - - printAsOperand(O, SlotTracker); - O << " = phi "; - printOperands(O, SlotTracker); -} - -void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "BLEND "; - Phi->printAsOperand(O, false); - O << " ="; - if (getNumIncomingValues() == 1) { - // Not a User of any mask: not really blending, this is a - // single-predecessor phi. - O << " "; - getIncomingValue(0)->printAsOperand(O, SlotTracker); - } else { - for (unsigned I = 0, E = getNumIncomingValues(); I < E; ++I) { - O << " "; - getIncomingValue(I)->printAsOperand(O, SlotTracker); - O << "/"; - getMask(I)->printAsOperand(O, SlotTracker); - } - } -} - -void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "REDUCE "; - printAsOperand(O, SlotTracker); - O << " = "; - getChainOp()->printAsOperand(O, SlotTracker); - O << " +"; - if (isa<FPMathOperator>(getUnderlyingInstr())) - O << getUnderlyingInstr()->getFastMathFlags(); - O << " reduce." << Instruction::getOpcodeName(RdxDesc->getOpcode()) << " ("; - getVecOp()->printAsOperand(O, SlotTracker); - if (getCondOp()) { - O << ", "; - getCondOp()->printAsOperand(O, SlotTracker); - } - O << ")"; -} - -void VPReplicateRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << (IsUniform ? 
"CLONE " : "REPLICATE "); - - if (!getUnderlyingInstr()->getType()->isVoidTy()) { - printAsOperand(O, SlotTracker); - O << " = "; - } - O << Instruction::getOpcodeName(getUnderlyingInstr()->getOpcode()) << " "; - printOperands(O, SlotTracker); - - if (AlsoPack) - O << " (S->V)"; -} - -void VPPredInstPHIRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "PHI-PREDICATED-INSTRUCTION "; - printAsOperand(O, SlotTracker); - O << " = "; - printOperands(O, SlotTracker); -} - -void VPWidenMemoryInstructionRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "WIDEN "; - - if (!isStore()) { - printAsOperand(O, SlotTracker); - O << " = "; - } - O << Instruction::getOpcodeName(Ingredient.getOpcode()) << " "; - - printOperands(O, SlotTracker); -} -#endif - -void VPCanonicalIVPHIRecipe::execute(VPTransformState &State) { - Value *Start = getStartValue()->getLiveInIRValue(); - PHINode *EntryPart = PHINode::Create( - Start->getType(), 2, "index", &*State.CFG.PrevBB->getFirstInsertionPt()); - EntryPart->addIncoming(Start, State.CFG.VectorPreHeader); - EntryPart->setDebugLoc(DL); - for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part) - State.set(this, EntryPart, Part); -} - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void VPCanonicalIVPHIRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "EMIT "; - printAsOperand(O, SlotTracker); - O << " = CANONICAL-INDUCTION"; -} -#endif - -void VPWidenCanonicalIVRecipe::execute(VPTransformState &State) { - Value *CanonicalIV = State.get(getOperand(0), 0); - Type *STy = CanonicalIV->getType(); - IRBuilder<> Builder(State.CFG.PrevBB->getTerminator()); - ElementCount VF = State.VF; - Value *VStart = VF.isScalar() - ? CanonicalIV - : Builder.CreateVectorSplat(VF, CanonicalIV, "broadcast"); - for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part) { - Value *VStep = createStepForVF(Builder, STy, VF, Part); - if (VF.isVector()) { - VStep = Builder.CreateVectorSplat(VF, VStep); - VStep = Builder.CreateAdd(VStep, Builder.CreateStepVector(VStep->getType())); - } - Value *CanonicalVectorIV = Builder.CreateAdd(VStart, VStep, "vec.iv"); - State.set(this, CanonicalVectorIV, Part); - } -} - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void VPWidenCanonicalIVRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "EMIT "; - printAsOperand(O, SlotTracker); - O << " = WIDEN-CANONICAL-INDUCTION "; - printOperands(O, SlotTracker); -} -#endif - -void VPFirstOrderRecurrencePHIRecipe::execute(VPTransformState &State) { - auto &Builder = State.Builder; - // Create a vector from the initial value. - auto *VectorInit = getStartValue()->getLiveInIRValue(); - - Type *VecTy = State.VF.isScalar() - ? VectorInit->getType() - : VectorType::get(VectorInit->getType(), State.VF); - - if (State.VF.isVector()) { - auto *IdxTy = Builder.getInt32Ty(); - auto *One = ConstantInt::get(IdxTy, 1); - IRBuilder<>::InsertPointGuard Guard(Builder); - Builder.SetInsertPoint(State.CFG.VectorPreHeader->getTerminator()); - auto *RuntimeVF = getRuntimeVF(Builder, IdxTy, State.VF); - auto *LastIdx = Builder.CreateSub(RuntimeVF, One); - VectorInit = Builder.CreateInsertElement( - PoisonValue::get(VecTy), VectorInit, LastIdx, "vector.recur.init"); - } - - // Create a phi node for the new recurrence. 
- PHINode *EntryPart = PHINode::Create( - VecTy, 2, "vector.recur", &*State.CFG.PrevBB->getFirstInsertionPt()); - EntryPart->addIncoming(VectorInit, State.CFG.VectorPreHeader); - State.set(this, EntryPart, 0); -} - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void VPFirstOrderRecurrencePHIRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "FIRST-ORDER-RECURRENCE-PHI "; - printAsOperand(O, SlotTracker); - O << " = phi "; - printOperands(O, SlotTracker); -} -#endif - -void VPReductionPHIRecipe::execute(VPTransformState &State) { - PHINode *PN = cast<PHINode>(getUnderlyingValue()); - auto &Builder = State.Builder; - - // In order to support recurrences we need to be able to vectorize Phi nodes. - // Phi nodes have cycles, so we need to vectorize them in two stages. This is - // stage #1: We create a new vector PHI node with no incoming edges. We'll use - // this value when we vectorize all of the instructions that use the PHI. - bool ScalarPHI = State.VF.isScalar() || IsInLoop; - Type *VecTy = - ScalarPHI ? PN->getType() : VectorType::get(PN->getType(), State.VF); - - BasicBlock *HeaderBB = State.CFG.PrevBB; - assert(State.LI->getLoopFor(HeaderBB)->getHeader() == HeaderBB && - "recipe must be in the vector loop header"); - unsigned LastPartForNewPhi = isOrdered() ? 1 : State.UF; - for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) { - Value *EntryPart = - PHINode::Create(VecTy, 2, "vec.phi", &*HeaderBB->getFirstInsertionPt()); - State.set(this, EntryPart, Part); - } - - // Reductions do not have to start at zero. They can start with - // any loop invariant values. - VPValue *StartVPV = getStartValue(); - Value *StartV = StartVPV->getLiveInIRValue(); - - Value *Iden = nullptr; - RecurKind RK = RdxDesc.getRecurrenceKind(); - if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RK) || - RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK)) { - // MinMax reduction have the start value as their identify. - if (ScalarPHI) { - Iden = StartV; - } else { - IRBuilderBase::InsertPointGuard IPBuilder(Builder); - Builder.SetInsertPoint(State.CFG.VectorPreHeader->getTerminator()); - StartV = Iden = - Builder.CreateVectorSplat(State.VF, StartV, "minmax.ident"); - } - } else { - Iden = RdxDesc.getRecurrenceIdentity(RK, VecTy->getScalarType(), - RdxDesc.getFastMathFlags()); - - if (!ScalarPHI) { - Iden = Builder.CreateVectorSplat(State.VF, Iden); - IRBuilderBase::InsertPointGuard IPBuilder(Builder); - Builder.SetInsertPoint(State.CFG.VectorPreHeader->getTerminator()); - Constant *Zero = Builder.getInt32(0); - StartV = Builder.CreateInsertElement(Iden, StartV, Zero); - } - } - - for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) { - Value *EntryPart = State.get(this, Part); - // Make sure to add the reduction start value only to the - // first unroll part. - Value *StartVal = (Part == 0) ? 
StartV : Iden; - cast<PHINode>(EntryPart)->addIncoming(StartVal, State.CFG.VectorPreHeader); - } -} - -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void VPReductionPHIRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "WIDEN-REDUCTION-PHI "; - - printAsOperand(O, SlotTracker); - O << " = phi "; - printOperands(O, SlotTracker); -} #endif template void DomTreeBuilder::Calculate<VPDominatorTree>(VPDominatorTree &DT); @@ -1594,7 +1006,10 @@ void VPInterleavedAccessInfo::visitBlock(VPBlockBase *Block, Old2NewTy &Old2New, continue; assert(isa<VPInstruction>(&VPI) && "Can only handle VPInstructions"); auto *VPInst = cast<VPInstruction>(&VPI); - auto *Inst = cast<Instruction>(VPInst->getUnderlyingValue()); + + auto *Inst = dyn_cast_or_null<Instruction>(VPInst->getUnderlyingValue()); + if (!Inst) + continue; auto *IG = IAI.getInterleaveGroup(Inst); if (!IG) continue; @@ -1622,7 +1037,7 @@ void VPInterleavedAccessInfo::visitBlock(VPBlockBase *Block, Old2NewTy &Old2New, VPInterleavedAccessInfo::VPInterleavedAccessInfo(VPlan &Plan, InterleavedAccessInfo &IAI) { Old2NewTy Old2New; - visitRegion(cast<VPRegionBlock>(Plan.getEntry()), Old2New, IAI); + visitRegion(Plan.getVectorLoopRegion(), Old2New, IAI); } void VPSlotTracker::assignSlot(const VPValue *V) { @@ -1632,8 +1047,8 @@ void VPSlotTracker::assignSlot(const VPValue *V) { void VPSlotTracker::assignSlots(const VPlan &Plan) { - for (const VPValue *V : Plan.VPExternalDefs) - assignSlot(V); + for (const auto &P : Plan.VPExternalDefs) + assignSlot(P.second); assignSlot(&Plan.VectorTripCount); if (Plan.BackedgeTakenCount) @@ -1651,7 +1066,19 @@ void VPSlotTracker::assignSlots(const VPlan &Plan) { } bool vputils::onlyFirstLaneUsed(VPValue *Def) { - return all_of(Def->users(), [Def](VPUser *U) { - return cast<VPRecipeBase>(U)->onlyFirstLaneUsed(Def); - }); + return all_of(Def->users(), + [Def](VPUser *U) { return U->onlyFirstLaneUsed(Def); }); +} + +VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr, + ScalarEvolution &SE) { + if (auto *E = dyn_cast<SCEVConstant>(Expr)) + return Plan.getOrAddExternalDef(E->getValue()); + if (auto *E = dyn_cast<SCEVUnknown>(Expr)) + return Plan.getOrAddExternalDef(E->getValue()); + + VPBasicBlock *Preheader = Plan.getEntry()->getEntryBasicBlock(); + VPValue *Step = new VPExpandSCEVRecipe(Expr, SE); + Preheader->appendRecipe(cast<VPRecipeBase>(Step->getDef())); + return Step; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.h index bcaabca692cc..09da4a545d0d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.h @@ -25,27 +25,26 @@ #ifndef LLVM_TRANSFORMS_VECTORIZE_VPLAN_H #define LLVM_TRANSFORMS_VECTORIZE_VPLAN_H -#include "VPlanLoopInfo.h" #include "VPlanValue.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/ilist.h" #include "llvm/ADT/ilist_node.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/DebugLoc.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/Support/InstructionCost.h" +#include "llvm/IR/FMF.h" +#include 
"llvm/Transforms/Utils/LoopVersioning.h" #include <algorithm> #include <cassert> #include <cstddef> -#include <map> #include <string> namespace llvm { @@ -54,6 +53,7 @@ class BasicBlock; class DominatorTree; class InductionDescriptor; class InnerLoopVectorizer; +class IRBuilderBase; class LoopInfo; class raw_ostream; class RecurrenceDescriptor; @@ -67,10 +67,11 @@ class VPlanSlp; /// Returns a calculation for the total number of elements for a given \p VF. /// For fixed width vectors this value is a constant, whereas for scalable /// vectors it is an expression determined at runtime. -Value *getRuntimeVF(IRBuilder<> &B, Type *Ty, ElementCount VF); +Value *getRuntimeVF(IRBuilderBase &B, Type *Ty, ElementCount VF); /// Return a value for Step multiplied by VF. -Value *createStepForVF(IRBuilder<> &B, Type *Ty, ElementCount VF, int64_t Step); +Value *createStepForVF(IRBuilderBase &B, Type *Ty, ElementCount VF, + int64_t Step); /// A range of powers-of-2 vectorization factors with fixed start and /// adjustable end. The range includes start and excludes end, e.g.,: @@ -151,7 +152,7 @@ public: /// Returns an expression describing the lane index that can be used at /// runtime. - Value *getAsRuntimeExpr(IRBuilder<> &Builder, const ElementCount &VF) const; + Value *getAsRuntimeExpr(IRBuilderBase &Builder, const ElementCount &VF) const; /// Returns the Kind of lane offset. Kind getKind() const { return LaneKind; } @@ -199,10 +200,10 @@ struct VPIteration { /// needed for generating the output IR. struct VPTransformState { VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI, - DominatorTree *DT, IRBuilder<> &Builder, + DominatorTree *DT, IRBuilderBase &Builder, InnerLoopVectorizer *ILV, VPlan *Plan) - : VF(VF), UF(UF), LI(LI), DT(DT), Builder(Builder), ILV(ILV), Plan(Plan) { - } + : VF(VF), UF(UF), LI(LI), DT(DT), Builder(Builder), ILV(ILV), Plan(Plan), + LVer(nullptr) {} /// The chosen Vectorization and Unroll Factors of the loop being vectorized. ElementCount VF; @@ -298,6 +299,27 @@ struct VPTransformState { Iter->second[Instance.Part][CacheIdx] = V; } + /// Add additional metadata to \p To that was not present on \p Orig. + /// + /// Currently this is used to add the noalias annotations based on the + /// inserted memchecks. Use this for instructions that are *cloned* into the + /// vector loop. + void addNewMetadata(Instruction *To, const Instruction *Orig); + + /// Add metadata from one instruction to another. + /// + /// This includes both the original MDs from \p From and additional ones (\see + /// addNewMetadata). Use this for *newly created* instructions in the vector + /// loop. + void addMetadata(Instruction *To, Instruction *From); + + /// Similar to the previous function but it adds the metadata to a + /// vector of instructions. + void addMetadata(ArrayRef<Value *> To, Instruction *From); + + /// Set the debug location in the builder using the debug location in \p V. + void setDebugLocFromInst(const Value *V); + /// Hold state information used when constructing the CFG of the output IR, /// traversing the VPBasicBlocks and generating corresponding IR BasicBlocks. struct CFGState { @@ -308,26 +330,19 @@ struct VPTransformState { /// header BasicBlock. BasicBlock *PrevBB = nullptr; - /// The last IR BasicBlock in the output IR. Set to the new latch - /// BasicBlock, used for placing the newly created BasicBlocks. - BasicBlock *LastBB = nullptr; - - /// The IR BasicBlock that is the preheader of the vector loop in the output - /// IR. 
- /// FIXME: The vector preheader should also be modeled in VPlan, so any code - /// that needs to be added to the preheader gets directly generated by - /// VPlan. There should be no need to manage a pointer to the IR BasicBlock. - BasicBlock *VectorPreHeader = nullptr; + /// The last IR BasicBlock in the output IR. Set to the exit block of the + /// vector loop. + BasicBlock *ExitBB = nullptr; /// A mapping of each VPBasicBlock to the corresponding BasicBlock. In case /// of replication, maps the BasicBlock of the last replica created. SmallDenseMap<VPBasicBlock *, BasicBlock *> VPBB2IRBB; - /// Vector of VPBasicBlocks whose terminator instruction needs to be fixed - /// up at the end of vector code generation. - SmallVector<VPBasicBlock *, 8> VPBBsToFix; - CFGState() = default; + + /// Returns the BasicBlock* mapped to the pre-header of the loop region + /// containing \p R. + BasicBlock *getPreheaderBBFor(VPRecipeBase *R); } CFG; /// Hold a pointer to LoopInfo to register new basic blocks in the loop. @@ -337,7 +352,7 @@ struct VPTransformState { DominatorTree *DT; /// Hold a reference to the IRBuilder used to generate output IR code. - IRBuilder<> &Builder; + IRBuilderBase &Builder; VPValue2ValueTy VPValue2Value; @@ -353,41 +368,16 @@ struct VPTransformState { /// Holds recipes that may generate a poison value that is used after /// vectorization, even when their operands are not poison. SmallPtrSet<VPRecipeBase *, 16> MayGeneratePoisonRecipes; -}; - -/// VPUsers instance used by VPBlockBase to manage CondBit and the block -/// predicate. Currently VPBlockUsers are used in VPBlockBase for historical -/// reasons, but in the future the only VPUsers should either be recipes or -/// live-outs.VPBlockBase uses. -struct VPBlockUser : public VPUser { - VPBlockUser() : VPUser({}, VPUserID::Block) {} - VPValue *getSingleOperandOrNull() { - if (getNumOperands() == 1) - return getOperand(0); + /// The loop object for the current parent region, or nullptr. + Loop *CurrentVectorLoop = nullptr; - return nullptr; - } - const VPValue *getSingleOperandOrNull() const { - if (getNumOperands() == 1) - return getOperand(0); - - return nullptr; - } - - void resetSingleOpUser(VPValue *NewVal) { - assert(getNumOperands() <= 1 && "Didn't expect more than one operand!"); - if (!NewVal) { - if (getNumOperands() == 1) - removeLastOperand(); - return; - } - - if (getNumOperands() == 1) - setOperand(0, NewVal); - else - addOperand(NewVal); - } + /// LoopVersioning. It's only set up (non-null) if memchecks were + /// used. + /// + /// This is currently only used to add no-alias metadata based on the + /// memchecks. The actually versioning is performed manually. + std::unique_ptr<LoopVersioning> LVer; }; /// VPBlockBase is the building block of the Hierarchical Control-Flow Graph. @@ -410,16 +400,6 @@ class VPBlockBase { /// List of successor blocks. SmallVector<VPBlockBase *, 1> Successors; - /// Successor selector managed by a VPUser. For blocks with zero or one - /// successors, there is no operand. Otherwise there is exactly one operand - /// which is the branch condition. - VPBlockUser CondBitUser; - - /// If the block is predicated, its predicate is stored as an operand of this - /// VPUser to maintain the def-use relations. Otherwise there is no operand - /// here. - VPBlockUser PredicateUser; - /// VPlan containing the block. Can only be set on the entry block of the /// plan. 
VPlan *Plan = nullptr; @@ -493,11 +473,11 @@ public: const VPBasicBlock *getEntryBasicBlock() const; VPBasicBlock *getEntryBasicBlock(); - /// \return the VPBasicBlock that is the exit of this VPBlockBase, + /// \return the VPBasicBlock that is the exiting this VPBlockBase, /// recursively, if the latter is a VPRegionBlock. Otherwise, if this /// VPBlockBase is a VPBasicBlock, it is returned. - const VPBasicBlock *getExitBasicBlock() const; - VPBasicBlock *getExitBasicBlock(); + const VPBasicBlock *getExitingBasicBlock() const; + VPBasicBlock *getExitingBasicBlock(); const VPBlocksTy &getSuccessors() const { return Successors; } VPBlocksTy &getSuccessors() { return Successors; } @@ -565,20 +545,6 @@ public: return getEnclosingBlockWithPredecessors()->getSinglePredecessor(); } - /// \return the condition bit selecting the successor. - VPValue *getCondBit(); - /// \return the condition bit selecting the successor. - const VPValue *getCondBit() const; - /// Set the condition bit selecting the successor. - void setCondBit(VPValue *CV); - - /// \return the block's predicate. - VPValue *getPredicate(); - /// \return the block's predicate. - const VPValue *getPredicate() const; - /// Set the block's predicate. - void setPredicate(VPValue *Pred); - /// Set a given VPBlockBase \p Successor as the single successor of this /// VPBlockBase. This VPBlockBase is not added as predecessor of \p Successor. /// This VPBlockBase must have no successors. @@ -588,14 +554,11 @@ public: } /// Set two given VPBlockBases \p IfTrue and \p IfFalse to be the two - /// successors of this VPBlockBase. \p Condition is set as the successor - /// selector. This VPBlockBase is not added as predecessor of \p IfTrue or \p - /// IfFalse. This VPBlockBase must have no successors. - void setTwoSuccessors(VPBlockBase *IfTrue, VPBlockBase *IfFalse, - VPValue *Condition) { + /// successors of this VPBlockBase. This VPBlockBase is not added as + /// predecessor of \p IfTrue or \p IfFalse. This VPBlockBase must have no + /// successors. + void setTwoSuccessors(VPBlockBase *IfTrue, VPBlockBase *IfFalse) { assert(Successors.empty() && "Setting two successors when others exist."); - assert(Condition && "Setting two successors without condition!"); - setCondBit(Condition); appendSuccessor(IfTrue); appendSuccessor(IfFalse); } @@ -612,11 +575,8 @@ public: /// Remove all the predecessor of this block. void clearPredecessors() { Predecessors.clear(); } - /// Remove all the successors of this block and set to null its condition bit - void clearSuccessors() { - Successors.clear(); - setCondBit(nullptr); - } + /// Remove all the successors of this block. + void clearSuccessors() { Successors.clear(); } /// The method which generates the output IR that correspond to this /// VPBlockBase, thereby "executing" the VPlan. @@ -665,6 +625,32 @@ public: #endif }; +/// A value that is used outside the VPlan. The operand of the user needs to be +/// added to the associated LCSSA phi node. +class VPLiveOut : public VPUser { + PHINode *Phi; + +public: + VPLiveOut(PHINode *Phi, VPValue *Op) + : VPUser({Op}, VPUser::VPUserID::LiveOut), Phi(Phi) {} + + /// Fixup the wrapped LCSSA phi node in the unique exit block. This simply + /// means we need to add the appropriate incoming value from the middle + /// block as exiting edges from the scalar epilogue loop (if present) are + /// already in place, and we exit the vector loop exclusively to the middle + /// block. 
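The new VPLiveOut recipe above patches an already-existing LCSSA phi rather than creating one: the scalar-epilogue edge is already wired up, and the recipe only appends the value arriving from the vector loop's middle block. The following standalone sketch models just that fix-up step; PhiModel, fixLiveOutPhi and the block names are invented for illustration and are not the actual VPLiveOut::fixPhi implementation.

#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for an LCSSA phi: a list of (incoming value, incoming block).
struct PhiModel {
  std::vector<std::pair<std::string, std::string>> Incoming;
  void addIncoming(std::string Val, std::string BB) {
    Incoming.emplace_back(std::move(Val), std::move(BB));
  }
};

// Models the fix-up: the edge from the scalar epilogue already exists, so
// only the result produced by the vector loop's middle block is appended.
void fixLiveOutPhi(PhiModel &ExitPhi, const std::string &VectorResult) {
  ExitPhi.addIncoming(VectorResult, "middle.block");
}

int main() {
  PhiModel ExitPhi;
  ExitPhi.addIncoming("%scalar.result", "scalar.loop.exit"); // pre-existing edge
  fixLiveOutPhi(ExitPhi, "%vector.result");
  assert(ExitPhi.Incoming.size() == 2);
  return 0;
}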
+ void fixPhi(VPlan &Plan, VPTransformState &State); + + /// Returns true if the VPLiveOut uses scalars of operand \p Op. + bool usesScalars(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + return true; + } + + PHINode *getPhi() const { return Phi; } +}; + /// VPRecipeBase is a base class modeling a sequence of one or more output IR /// instructions. VPRecipeBase owns the the VPValues it defines through VPDef /// and is responsible for deleting its defined values. Single-value @@ -699,6 +685,9 @@ public: /// Insert an unlinked recipe into a basic block immediately before /// the specified recipe. void insertBefore(VPRecipeBase *InsertPos); + /// Insert an unlinked recipe into \p BB immediately before the insertion + /// point \p IP; + void insertBefore(VPBasicBlock &BB, iplist<VPRecipeBase>::iterator IP); /// Insert an unlinked Recipe into a basic block immediately after /// the specified Recipe. @@ -759,14 +748,6 @@ public: bool mayReadOrWriteMemory() const { return mayReadFromMemory() || mayWriteToMemory(); } - - /// Returns true if the recipe only uses the first lane of operand \p Op. - /// Conservatively returns false. - virtual bool onlyFirstLaneUsed(const VPValue *Op) const { - assert(is_contained(operands(), Op) && - "Op must be an operand of the recipe"); - return false; - } }; inline bool VPUser::classof(const VPDef *Def) { @@ -804,6 +785,7 @@ public: CanonicalIVIncrement, CanonicalIVIncrementNUW, BranchOnCount, + BranchOnCond }; private: @@ -892,6 +874,7 @@ public: case Instruction::Unreachable: case Instruction::Fence: case Instruction::AtomicRMW: + case VPInstruction::BranchOnCond: case VPInstruction::BranchOnCount: return false; default: @@ -1049,27 +1032,25 @@ public: }; /// A recipe for handling phi nodes of integer and floating-point inductions, -/// producing their vector and scalar values. +/// producing their vector values. class VPWidenIntOrFpInductionRecipe : public VPRecipeBase, public VPValue { PHINode *IV; const InductionDescriptor &IndDesc; - bool NeedsScalarIV; bool NeedsVectorIV; public: - VPWidenIntOrFpInductionRecipe(PHINode *IV, VPValue *Start, + VPWidenIntOrFpInductionRecipe(PHINode *IV, VPValue *Start, VPValue *Step, const InductionDescriptor &IndDesc, - bool NeedsScalarIV, bool NeedsVectorIV) - : VPRecipeBase(VPWidenIntOrFpInductionSC, {Start}), VPValue(IV, this), - IV(IV), IndDesc(IndDesc), NeedsScalarIV(NeedsScalarIV), + bool NeedsVectorIV) + : VPRecipeBase(VPWidenIntOrFpInductionSC, {Start, Step}), + VPValue(IV, this), IV(IV), IndDesc(IndDesc), NeedsVectorIV(NeedsVectorIV) {} - VPWidenIntOrFpInductionRecipe(PHINode *IV, VPValue *Start, + VPWidenIntOrFpInductionRecipe(PHINode *IV, VPValue *Start, VPValue *Step, const InductionDescriptor &IndDesc, - TruncInst *Trunc, bool NeedsScalarIV, - bool NeedsVectorIV) - : VPRecipeBase(VPWidenIntOrFpInductionSC, {Start}), VPValue(Trunc, this), - IV(IV), IndDesc(IndDesc), NeedsScalarIV(NeedsScalarIV), + TruncInst *Trunc, bool NeedsVectorIV) + : VPRecipeBase(VPWidenIntOrFpInductionSC, {Start, Step}), + VPValue(Trunc, this), IV(IV), IndDesc(IndDesc), NeedsVectorIV(NeedsVectorIV) {} ~VPWidenIntOrFpInductionRecipe() override = default; @@ -1093,6 +1074,10 @@ public: VPValue *getStartValue() { return getOperand(0); } const VPValue *getStartValue() const { return getOperand(0); } + /// Returns the step value of the induction. 
+ VPValue *getStepValue() { return getOperand(1); } + const VPValue *getStepValue() const { return getOperand(1); } + /// Returns the first defined value as TruncInst, if it is one or nullptr /// otherwise. TruncInst *getTruncInst() { @@ -1102,6 +1087,8 @@ public: return dyn_cast_or_null<TruncInst>(getVPValue(0)->getUnderlyingValue()); } + PHINode *getPHINode() { return IV; } + /// Returns the induction descriptor for the recipe. const InductionDescriptor &getInductionDescriptor() const { return IndDesc; } @@ -1115,9 +1102,6 @@ public: return TruncI ? TruncI->getType() : IV->getType(); } - /// Returns true if a scalar phi needs to be created for the induction. - bool needsScalarIV() const { return NeedsScalarIV; } - /// Returns true if a vector phi needs to be created for the induction. bool needsVectorIV() const { return NeedsVectorIV; } }; @@ -1167,6 +1151,9 @@ public: VPValue *getStartValue() { return getNumOperands() == 0 ? nullptr : getOperand(0); } + VPValue *getStartValue() const { + return getNumOperands() == 0 ? nullptr : getOperand(0); + } /// Returns the incoming value from the loop backedge. VPValue *getBackedgeValue() { @@ -1180,6 +1167,52 @@ public: } }; +class VPWidenPointerInductionRecipe : public VPHeaderPHIRecipe { + const InductionDescriptor &IndDesc; + + /// SCEV used to expand step. + /// FIXME: move expansion of step to the pre-header, once it is modeled + /// explicitly. + ScalarEvolution &SE; + +public: + /// Create a new VPWidenPointerInductionRecipe for \p Phi with start value \p + /// Start. + VPWidenPointerInductionRecipe(PHINode *Phi, VPValue *Start, + const InductionDescriptor &IndDesc, + ScalarEvolution &SE) + : VPHeaderPHIRecipe(VPVWidenPointerInductionSC, VPWidenPointerInductionSC, + Phi), + IndDesc(IndDesc), SE(SE) { + addOperand(Start); + } + + ~VPWidenPointerInductionRecipe() override = default; + + /// Method to support type inquiry through isa, cast, and dyn_cast. + static inline bool classof(const VPRecipeBase *B) { + return B->getVPDefID() == VPRecipeBase::VPWidenPointerInductionSC; + } + static inline bool classof(const VPHeaderPHIRecipe *R) { + return R->getVPDefID() == VPRecipeBase::VPWidenPointerInductionSC; + } + static inline bool classof(const VPValue *V) { + return V->getVPValueID() == VPValue::VPVWidenPointerInductionSC; + } + + /// Generate vector values for the pointer induction. + void execute(VPTransformState &State) override; + + /// Returns true if only scalar values will be generated. + bool onlyScalarsGenerated(ElementCount VF); + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const override; +#endif +}; + /// A recipe for handling header phis that are widened in the vector loop. /// In the VPlan native path, all incoming VPValues & VPBasicBlock pairs are /// managed in the recipe directly. @@ -1363,9 +1396,8 @@ public: "Op must be an operand of the recipe"); // Recursing through Blend recipes only, must terminate at header phi's the // latest. - return all_of(users(), [this](VPUser *U) { - return cast<VPRecipeBase>(U)->onlyFirstLaneUsed(this); - }); + return all_of(users(), + [this](VPUser *U) { return U->onlyFirstLaneUsed(this); }); } }; @@ -1440,6 +1472,15 @@ public: unsigned getNumStoreOperands() const { return getNumOperands() - (HasMask ? 2 : 1); } + + /// The recipe only uses the first lane of the address. 
+ bool onlyFirstLaneUsed(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + return Op == getAddr() && all_of(getStoredValues(), [Op](VPValue *StoredV) { + return Op != StoredV; + }); + } }; /// A recipe to represent inloop reduction operations, performing a reduction on @@ -1551,6 +1592,13 @@ public: "Op must be an operand of the recipe"); return isUniform(); } + + /// Returns true if the recipe uses scalars of operand \p Op. + bool usesScalars(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + return true; + } }; /// A recipe for generating conditional branches on the bits of a mask. @@ -1590,6 +1638,13 @@ public: // Mask is optional. return getNumOperands() == 1 ? getOperand(0) : nullptr; } + + /// Returns true if the recipe uses scalars of operand \p Op. + bool usesScalars(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + return true; + } }; /// VPPredInstPHIRecipe is a recipe for generating the phi nodes needed when @@ -1619,6 +1674,13 @@ public: void print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const override; #endif + + /// Returns true if the recipe uses scalars of operand \p Op. + bool usesScalars(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + return true; + } }; /// A Recipe for widening load/store operations. @@ -1627,7 +1689,7 @@ public: /// - For store: Address, stored value, optional mask /// TODO: We currently execute only per-part unless a specific instance is /// provided. -class VPWidenMemoryInstructionRecipe : public VPRecipeBase, public VPValue { +class VPWidenMemoryInstructionRecipe : public VPRecipeBase { Instruction &Ingredient; // Whether the loaded-from / stored-to addresses are consecutive. @@ -1649,10 +1711,10 @@ class VPWidenMemoryInstructionRecipe : public VPRecipeBase, public VPValue { public: VPWidenMemoryInstructionRecipe(LoadInst &Load, VPValue *Addr, VPValue *Mask, bool Consecutive, bool Reverse) - : VPRecipeBase(VPWidenMemoryInstructionSC, {Addr}), - VPValue(VPValue::VPVMemoryInstructionSC, &Load, this), Ingredient(Load), + : VPRecipeBase(VPWidenMemoryInstructionSC, {Addr}), Ingredient(Load), Consecutive(Consecutive), Reverse(Reverse) { assert((Consecutive || !Reverse) && "Reverse implies consecutive"); + new VPValue(VPValue::VPVMemoryInstructionSC, &Load, this); setMask(Mask); } @@ -1660,7 +1722,6 @@ public: VPValue *StoredValue, VPValue *Mask, bool Consecutive, bool Reverse) : VPRecipeBase(VPWidenMemoryInstructionSC, {Addr, StoredValue}), - VPValue(VPValue::VPVMemoryInstructionSC, &Store, this), Ingredient(Store), Consecutive(Consecutive), Reverse(Reverse) { assert((Consecutive || !Reverse) && "Reverse implies consecutive"); setMask(Mask); @@ -1714,9 +1775,42 @@ public: "Op must be an operand of the recipe"); // Widened, consecutive memory operations only demand the first lane of - // their address. - return Op == getAddr() && isConsecutive(); + // their address, unless the same operand is also stored. That latter can + // happen with opaque pointers. + return Op == getAddr() && isConsecutive() && + (!isStore() || Op != getStoredValue()); + } + + Instruction &getIngredient() const { return Ingredient; } +}; + +/// Recipe to expand a SCEV expression. 
+class VPExpandSCEVRecipe : public VPRecipeBase, public VPValue { + const SCEV *Expr; + ScalarEvolution &SE; + +public: + VPExpandSCEVRecipe(const SCEV *Expr, ScalarEvolution &SE) + : VPRecipeBase(VPExpandSCEVSC, {}), VPValue(nullptr, this), Expr(Expr), + SE(SE) {} + + ~VPExpandSCEVRecipe() override = default; + + /// Method to support type inquiry through isa, cast, and dyn_cast. + static inline bool classof(const VPDef *D) { + return D->getVPDefID() == VPExpandSCEVSC; } + + /// Generate a canonical vector induction variable of the vector loop, with + void execute(VPTransformState &State) override; + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const override; +#endif + + const SCEV *getSCEV() const { return Expr; } }; /// Canonical scalar induction phi of the vector loop. Starting at the specified @@ -1738,6 +1832,12 @@ public: static inline bool classof(const VPDef *D) { return D->getVPDefID() == VPCanonicalIVPHISC; } + static inline bool classof(const VPHeaderPHIRecipe *D) { + return D->getVPDefID() == VPCanonicalIVPHISC; + } + static inline bool classof(const VPValue *V) { + return V->getVPValueID() == VPValue::VPVCanonicalIVPHISC; + } /// Generate the canonical scalar induction phi of the vector loop. void execute(VPTransformState &State) override; @@ -1803,6 +1903,64 @@ public: } }; +/// A recipe for handling phi nodes of integer and floating-point inductions, +/// producing their scalar values. +class VPScalarIVStepsRecipe : public VPRecipeBase, public VPValue { + /// Scalar type to use for the generated values. + Type *Ty; + /// If not nullptr, truncate the generated values to TruncToTy. + Type *TruncToTy; + const InductionDescriptor &IndDesc; + +public: + VPScalarIVStepsRecipe(Type *Ty, const InductionDescriptor &IndDesc, + VPValue *CanonicalIV, VPValue *Start, VPValue *Step, + Type *TruncToTy) + : VPRecipeBase(VPScalarIVStepsSC, {CanonicalIV, Start, Step}), + VPValue(nullptr, this), Ty(Ty), TruncToTy(TruncToTy), IndDesc(IndDesc) { + } + + ~VPScalarIVStepsRecipe() override = default; + + /// Method to support type inquiry through isa, cast, and dyn_cast. + static inline bool classof(const VPDef *D) { + return D->getVPDefID() == VPRecipeBase::VPScalarIVStepsSC; + } + /// Extra classof implementations to allow directly casting from VPUser -> + /// VPScalarIVStepsRecipe. + static inline bool classof(const VPUser *U) { + auto *R = dyn_cast<VPRecipeBase>(U); + return R && R->getVPDefID() == VPRecipeBase::VPScalarIVStepsSC; + } + static inline bool classof(const VPRecipeBase *R) { + return R->getVPDefID() == VPRecipeBase::VPScalarIVStepsSC; + } + + /// Generate the scalarized versions of the phi node as needed by their users. + void execute(VPTransformState &State) override; + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const override; +#endif + + /// Returns true if the induction is canonical, i.e. starting at 0 and + /// incremented by UF * VF (= the original IV is incremented by 1). + bool isCanonical() const; + + VPCanonicalIVPHIRecipe *getCanonicalIV() const; + VPValue *getStartValue() const { return getOperand(1); } + VPValue *getStepValue() const { return getOperand(2); } + + /// Returns true if the recipe only uses the first lane of operand \p Op. 
+ bool onlyFirstLaneUsed(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + return true; + } +}; + /// VPBasicBlock serves as the leaf of the Hierarchical Control-Flow Graph. It /// holds a sequence of zero or more VPRecipe's each representing a sequence of /// output IR instructions. All PHI-like recipes must come before any non-PHI recipes. @@ -1895,6 +2053,8 @@ public: /// SplitAt to the new block. Returns the new block. VPBasicBlock *splitAt(iterator SplitAt); + VPRegionBlock *getEnclosingLoopRegion(); + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) /// Print this VPBsicBlock to \p O, prefixing all lines with \p Indent. \p /// SlotTracker is used to print unnamed VPValue's using consequtive numbers. @@ -1906,6 +2066,14 @@ public: using VPBlockBase::print; // Get the print(raw_stream &O) version. #endif + /// If the block has multiple successors, return the branch recipe terminating + /// the block. If there are no or only a single successor, return nullptr; + VPRecipeBase *getTerminator(); + const VPRecipeBase *getTerminator() const; + + /// Returns true if the block is exiting it's parent region. + bool isExiting() const; + private: /// Create an IR BasicBlock to hold the output instructions generated by this /// VPBasicBlock, and return it. Update the CFGState accordingly. @@ -1913,7 +2081,7 @@ private: }; /// VPRegionBlock represents a collection of VPBasicBlocks and VPRegionBlocks -/// which form a Single-Entry-Single-Exit subgraph of the output IR CFG. +/// which form a Single-Entry-Single-Exiting subgraph of the output IR CFG. /// A VPRegionBlock may indicate that its contents are to be replicated several /// times. This is designed to support predicated scalarization, in which a /// scalar if-then code structure needs to be generated VF * UF times. Having @@ -1924,25 +2092,26 @@ class VPRegionBlock : public VPBlockBase { /// Hold the Single Entry of the SESE region modelled by the VPRegionBlock. VPBlockBase *Entry; - /// Hold the Single Exit of the SESE region modelled by the VPRegionBlock. - VPBlockBase *Exit; + /// Hold the Single Exiting block of the SESE region modelled by the + /// VPRegionBlock. + VPBlockBase *Exiting; /// An indicator whether this region is to generate multiple replicated /// instances of output IR corresponding to its VPBlockBases. bool IsReplicator; public: - VPRegionBlock(VPBlockBase *Entry, VPBlockBase *Exit, + VPRegionBlock(VPBlockBase *Entry, VPBlockBase *Exiting, const std::string &Name = "", bool IsReplicator = false) - : VPBlockBase(VPRegionBlockSC, Name), Entry(Entry), Exit(Exit), + : VPBlockBase(VPRegionBlockSC, Name), Entry(Entry), Exiting(Exiting), IsReplicator(IsReplicator) { assert(Entry->getPredecessors().empty() && "Entry block has predecessors."); - assert(Exit->getSuccessors().empty() && "Exit block has successors."); + assert(Exiting->getSuccessors().empty() && "Exit block has successors."); Entry->setParent(this); - Exit->setParent(this); + Exiting->setParent(this); } VPRegionBlock(const std::string &Name = "", bool IsReplicator = false) - : VPBlockBase(VPRegionBlockSC, Name), Entry(nullptr), Exit(nullptr), + : VPBlockBase(VPRegionBlockSC, Name), Entry(nullptr), Exiting(nullptr), IsReplicator(IsReplicator) {} ~VPRegionBlock() override { @@ -1976,16 +2145,22 @@ public: // DominatorTreeBase representing the Graph type. 
VPBlockBase &front() const { return *Entry; } - const VPBlockBase *getExit() const { return Exit; } - VPBlockBase *getExit() { return Exit; } + const VPBlockBase *getExiting() const { return Exiting; } + VPBlockBase *getExiting() { return Exiting; } - /// Set \p ExitBlock as the exit VPBlockBase of this VPRegionBlock. \p - /// ExitBlock must have no successors. - void setExit(VPBlockBase *ExitBlock) { - assert(ExitBlock->getSuccessors().empty() && + /// Set \p ExitingBlock as the exiting VPBlockBase of this VPRegionBlock. \p + /// ExitingBlock must have no successors. + void setExiting(VPBlockBase *ExitingBlock) { + assert(ExitingBlock->getSuccessors().empty() && "Exit block cannot have successors."); - Exit = ExitBlock; - ExitBlock->setParent(this); + Exiting = ExitingBlock; + ExitingBlock->setParent(this); + } + + /// Returns the pre-header VPBasicBlock of the loop region. + VPBasicBlock *getPreheaderVPBB() { + assert(!isReplicator() && "should only get pre-header of loop regions"); + return getSinglePredecessor()->getExitingBasicBlock(); } /// An indicator whether this region is to generate multiple replicated @@ -2119,11 +2294,11 @@ struct GraphTraits<Inverse<VPRegionBlock *>> using nodes_iterator = df_iterator<NodeRef>; static NodeRef getEntryNode(Inverse<GraphRef> N) { - return N.Graph->getExit(); + return N.Graph->getExiting(); } static nodes_iterator nodes_begin(GraphRef N) { - return nodes_iterator::begin(N->getExit()); + return nodes_iterator::begin(N->getExiting()); } static nodes_iterator nodes_end(GraphRef N) { @@ -2281,12 +2456,9 @@ class VPlan { /// Holds the name of the VPlan, for printing. std::string Name; - /// Holds all the external definitions created for this VPlan. - // TODO: Introduce a specific representation for external definitions in - // VPlan. External definitions must be immutable and hold a pointer to its - // underlying IR that will be used to implement its structural comparison - // (operators '==' and '<'). - SetVector<VPValue *> VPExternalDefs; + /// Holds all the external definitions created for this VPlan. External + /// definitions must be immutable and hold a pointer to their underlying IR. + DenseMap<Value *, VPValue *> VPExternalDefs; /// Represents the trip count of the original loop, for folding /// the tail. @@ -2307,13 +2479,13 @@ class VPlan { /// to be free when the plan's destructor is called. SmallVector<VPValue *, 16> VPValuesToFree; - /// Holds the VPLoopInfo analysis for this VPlan. - VPLoopInfo VPLInfo; - /// Indicates whether it is safe use the Value2VPValue mapping or if the /// mapping cannot be used any longer, because it is stale. bool Value2VPValueEnabled = true; + /// Values used outside the plan. + MapVector<PHINode *, VPLiveOut *> LiveOuts; + public: VPlan(VPBlockBase *Entry = nullptr) : Entry(Entry) { if (Entry) @@ -2321,6 +2493,8 @@ public: } ~VPlan() { + clearLiveOuts(); + if (Entry) { VPValue DummyValue; for (VPBlockBase *Block : depth_first(Entry)) @@ -2334,13 +2508,14 @@ public: delete TripCount; if (BackedgeTakenCount) delete BackedgeTakenCount; - for (VPValue *Def : VPExternalDefs) - delete Def; + for (auto &P : VPExternalDefs) + delete P.second; } /// Prepare the plan for execution, setting up the required live-in values. void prepareToExecute(Value *TripCount, Value *VectorTripCount, - Value *CanonicalIVStartValue, VPTransformState &State); + Value *CanonicalIVStartValue, VPTransformState &State, + bool IsEpilogueVectorization); /// Generate the IR code for this VPlan. 
void execute(struct VPTransformState *State); @@ -2383,9 +2558,13 @@ public: void setName(const Twine &newName) { Name = newName.str(); } - /// Add \p VPVal to the pool of external definitions if it's not already - /// in the pool. - void addExternalDef(VPValue *VPVal) { VPExternalDefs.insert(VPVal); } + /// Get the existing or add a new external definition for \p V. + VPValue *getOrAddExternalDef(Value *V) { + auto I = VPExternalDefs.insert({V, nullptr}); + if (I.second) + I.first->second = new VPValue(V); + return I.first->second; + } void addVPValue(Value *V) { assert(Value2VPValueEnabled && @@ -2432,10 +2611,6 @@ public: Value2VPValue.erase(V); } - /// Return the VPLoopInfo analysis for this VPlan. - VPLoopInfo &getVPLoopInfo() { return VPLInfo; } - const VPLoopInfo &getVPLoopInfo() const { return VPLInfo; } - #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) /// Print this VPlan to \p O. void print(raw_ostream &O) const; @@ -2465,7 +2640,10 @@ public: /// Returns the VPRegionBlock of the vector loop. VPRegionBlock *getVectorLoopRegion() { - return cast<VPRegionBlock>(getEntry()); + return cast<VPRegionBlock>(getEntry()->getSingleSuccessor()); + } + const VPRegionBlock *getVectorLoopRegion() const { + return cast<VPRegionBlock>(getEntry()->getSingleSuccessor()); } /// Returns the canonical induction recipe of the vector loop. @@ -2478,6 +2656,23 @@ public: return cast<VPCanonicalIVPHIRecipe>(&*EntryVPBB->begin()); } + void addLiveOut(PHINode *PN, VPValue *V); + + void clearLiveOuts() { + for (auto &KV : LiveOuts) + delete KV.second; + LiveOuts.clear(); + } + + void removeLiveOut(PHINode *PN) { + delete LiveOuts[PN]; + LiveOuts.erase(PN); + } + + const MapVector<PHINode *, VPLiveOut *> &getLiveOuts() const { + return LiveOuts; + } + private: /// Add to the given dominator tree the header block and every new basic block /// that was created between it and the latch block, inclusive. @@ -2567,9 +2762,8 @@ public: /// Insert disconnected VPBlockBase \p NewBlock after \p BlockPtr. Add \p /// NewBlock as successor of \p BlockPtr and \p BlockPtr as predecessor of \p /// NewBlock, and propagate \p BlockPtr parent to \p NewBlock. \p BlockPtr's - /// successors are moved from \p BlockPtr to \p NewBlock and \p BlockPtr's - /// conditional bit is propagated to \p NewBlock. \p NewBlock must have - /// neither successors nor predecessors. + /// successors are moved from \p BlockPtr to \p NewBlock. \p NewBlock must + /// have neither successors nor predecessors. static void insertBlockAfter(VPBlockBase *NewBlock, VPBlockBase *BlockPtr) { assert(NewBlock->getSuccessors().empty() && NewBlock->getPredecessors().empty() && @@ -2580,24 +2774,22 @@ public: disconnectBlocks(BlockPtr, Succ); connectBlocks(NewBlock, Succ); } - NewBlock->setCondBit(BlockPtr->getCondBit()); - BlockPtr->setCondBit(nullptr); connectBlocks(BlockPtr, NewBlock); } /// Insert disconnected VPBlockBases \p IfTrue and \p IfFalse after \p /// BlockPtr. Add \p IfTrue and \p IfFalse as succesors of \p BlockPtr and \p /// BlockPtr as predecessor of \p IfTrue and \p IfFalse. Propagate \p BlockPtr - /// parent to \p IfTrue and \p IfFalse. \p Condition is set as the successor - /// selector. \p BlockPtr must have no successors and \p IfTrue and \p IfFalse - /// must have neither successors nor predecessors. + /// parent to \p IfTrue and \p IfFalse. \p BlockPtr must have no successors + /// and \p IfTrue and \p IfFalse must have neither successors nor + /// predecessors. 
static void insertTwoBlocksAfter(VPBlockBase *IfTrue, VPBlockBase *IfFalse, - VPValue *Condition, VPBlockBase *BlockPtr) { + VPBlockBase *BlockPtr) { assert(IfTrue->getSuccessors().empty() && "Can't insert IfTrue with successors."); assert(IfFalse->getSuccessors().empty() && "Can't insert IfFalse with successors."); - BlockPtr->setTwoSuccessors(IfTrue, IfFalse, Condition); + BlockPtr->setTwoSuccessors(IfTrue, IfFalse); IfTrue->setPredecessors({BlockPtr}); IfFalse->setPredecessors({BlockPtr}); IfTrue->setParent(BlockPtr->getParent()); @@ -2639,8 +2831,8 @@ public: R.moveBefore(*PredVPBB, PredVPBB->end()); VPBlockUtils::disconnectBlocks(PredVPBB, VPBB); auto *ParentRegion = cast<VPRegionBlock>(Block->getParent()); - if (ParentRegion->getExit() == Block) - ParentRegion->setExit(PredVPBB); + if (ParentRegion->getExiting() == Block) + ParentRegion->setExiting(PredVPBB); SmallVector<VPBlockBase *> Successors(Block->successors()); for (auto *Succ : Successors) { VPBlockUtils::disconnectBlocks(Block, Succ); @@ -2650,41 +2842,6 @@ public: return PredVPBB; } - /// Returns true if the edge \p FromBlock -> \p ToBlock is a back-edge. - static bool isBackEdge(const VPBlockBase *FromBlock, - const VPBlockBase *ToBlock, const VPLoopInfo *VPLI) { - assert(FromBlock->getParent() == ToBlock->getParent() && - FromBlock->getParent() && "Must be in same region"); - const VPLoop *FromLoop = VPLI->getLoopFor(FromBlock); - const VPLoop *ToLoop = VPLI->getLoopFor(ToBlock); - if (!FromLoop || !ToLoop || FromLoop != ToLoop) - return false; - - // A back-edge is a branch from the loop latch to its header. - return ToLoop->isLoopLatch(FromBlock) && ToBlock == ToLoop->getHeader(); - } - - /// Returns true if \p Block is a loop latch - static bool blockIsLoopLatch(const VPBlockBase *Block, - const VPLoopInfo *VPLInfo) { - if (const VPLoop *ParentVPL = VPLInfo->getLoopFor(Block)) - return ParentVPL->isLoopLatch(Block); - - return false; - } - - /// Count and return the number of succesors of \p PredBlock excluding any - /// backedges. - static unsigned countSuccessorsNoBE(VPBlockBase *PredBlock, - VPLoopInfo *VPLI) { - unsigned Count = 0; - for (VPBlockBase *SuccBlock : PredBlock->getSuccessors()) { - if (!VPBlockUtils::isBackEdge(PredBlock, SuccBlock, VPLI)) - Count++; - } - return Count; - } - /// Return an iterator range over \p Range which only includes \p BlockTy /// blocks. The accesses are casted to \p BlockTy. template <typename BlockTy, typename T> @@ -2845,6 +3002,13 @@ namespace vputils { /// Returns true if only the first lane of \p Def is used. bool onlyFirstLaneUsed(VPValue *Def); +/// Get or create a VPValue that corresponds to the expansion of \p Expr. If \p +/// Expr is a SCEVConstant or SCEVUnknown, return a VPValue wrapping the live-in +/// value. Otherwise return a VPExpandSCEVRecipe to expand \p Expr. If \p Plan's +/// pre-header already contains a recipe expanding \p Expr, return it. If not, +/// create a new one. +VPValue *getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr, + ScalarEvolution &SE); } // end namespace vputils } // end namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp index 379988733312..84b0dac862b6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp @@ -42,9 +42,6 @@ private: // Vectorization plan that we are working on. 
VPlan &Plan; - // Output Top Region. - VPRegionBlock *TopRegion = nullptr; - // Builder of the VPlan instruction-level representation. VPBuilder VPIRBuilder; @@ -59,6 +56,9 @@ private: // Hold phi node's that need to be fixed once the plain CFG has been built. SmallVector<PHINode *, 8> PhisToFix; + /// Maps loops in the original IR to their corresponding region. + DenseMap<Loop *, VPRegionBlock *> Loop2Region; + // Utility functions. void setVPBBPredsFromBB(VPBasicBlock *VPBB, BasicBlock *BB); void fixPhiNodes(); @@ -73,8 +73,9 @@ public: PlainCFGBuilder(Loop *Lp, LoopInfo *LI, VPlan &P) : TheLoop(Lp), LI(LI), Plan(P) {} - // Build the plain CFG and return its Top Region. - VPRegionBlock *buildPlainCFG(); + /// Build plain CFG for TheLoop. Return the pre-header VPBasicBlock connected + /// to a new VPRegionBlock (TopRegion) enclosing the plain CFG. + VPBasicBlock *buildPlainCFG(); }; } // anonymous namespace @@ -106,19 +107,32 @@ void PlainCFGBuilder::fixPhiNodes() { } } -// Create a new empty VPBasicBlock for an incoming BasicBlock or retrieve an -// existing one if it was already created. +// Create a new empty VPBasicBlock for an incoming BasicBlock in the region +// corresponding to the containing loop or retrieve an existing one if it was +// already created. If no region exists yet for the loop containing \p BB, a new +// one is created. VPBasicBlock *PlainCFGBuilder::getOrCreateVPBB(BasicBlock *BB) { auto BlockIt = BB2VPBB.find(BB); if (BlockIt != BB2VPBB.end()) // Retrieve existing VPBB. return BlockIt->second; + // Get or create a region for the loop containing BB. + Loop *CurrentLoop = LI->getLoopFor(BB); + VPRegionBlock *ParentR = nullptr; + if (CurrentLoop) { + auto Iter = Loop2Region.insert({CurrentLoop, nullptr}); + if (Iter.second) + Iter.first->second = new VPRegionBlock( + CurrentLoop->getHeader()->getName().str(), false /*isReplicator*/); + ParentR = Iter.first->second; + } + // Create new VPBB. LLVM_DEBUG(dbgs() << "Creating VPBasicBlock for " << BB->getName() << "\n"); VPBasicBlock *VPBB = new VPBasicBlock(BB->getName()); BB2VPBB[BB] = VPBB; - VPBB->setParent(TopRegion); + VPBB->setParent(ParentR); return VPBB; } @@ -182,8 +196,7 @@ VPValue *PlainCFGBuilder::getOrCreateVPOperand(Value *IRVal) { // A and B: Create VPValue and add it to the pool of external definitions and // to the Value->VPValue map. - VPValue *NewVPVal = new VPValue(IRVal); - Plan.addExternalDef(NewVPVal); + VPValue *NewVPVal = Plan.getOrAddExternalDef(IRVal); IRDef2VPValue[IRVal] = NewVPVal; return NewVPVal; } @@ -203,10 +216,13 @@ void PlainCFGBuilder::createVPInstructionsForVPBB(VPBasicBlock *VPBB, "Instruction shouldn't have been visited."); if (auto *Br = dyn_cast<BranchInst>(Inst)) { - // Branch instruction is not explicitly represented in VPlan but we need - // to represent its condition bit when it's conditional. - if (Br->isConditional()) - getOrCreateVPOperand(Br->getCondition()); + // Conditional branch instruction are represented using BranchOnCond + // recipes. + if (Br->isConditional()) { + VPValue *Cond = getOrCreateVPOperand(Br->getCondition()); + VPBB->appendRecipe( + new VPInstruction(VPInstruction::BranchOnCond, {Cond})); + } // Skip the rest of the Instruction processing for Branch instructions. continue; @@ -238,11 +254,8 @@ void PlainCFGBuilder::createVPInstructionsForVPBB(VPBasicBlock *VPBB, } // Main interface to build the plain CFG. -VPRegionBlock *PlainCFGBuilder::buildPlainCFG() { - // 1. Create the Top Region. It will be the parent of all VPBBs. 
- TopRegion = new VPRegionBlock("TopRegion", false /*isReplicator*/); - - // 2. Scan the body of the loop in a topological order to visit each basic +VPBasicBlock *PlainCFGBuilder::buildPlainCFG() { + // 1. Scan the body of the loop in a topological order to visit each basic // block after having visited its predecessor basic blocks. Create a VPBB for // each BB and link it to its successor and predecessor VPBBs. Note that // predecessors must be set in the same order as they are in the incomming IR. @@ -251,21 +264,20 @@ VPRegionBlock *PlainCFGBuilder::buildPlainCFG() { // Loop PH needs to be explicitly visited since it's not taken into account by // LoopBlocksDFS. - BasicBlock *PreheaderBB = TheLoop->getLoopPreheader(); - assert((PreheaderBB->getTerminator()->getNumSuccessors() == 1) && + BasicBlock *ThePreheaderBB = TheLoop->getLoopPreheader(); + assert((ThePreheaderBB->getTerminator()->getNumSuccessors() == 1) && "Unexpected loop preheader"); - VPBasicBlock *PreheaderVPBB = getOrCreateVPBB(PreheaderBB); - for (auto &I : *PreheaderBB) { + VPBasicBlock *ThePreheaderVPBB = getOrCreateVPBB(ThePreheaderBB); + ThePreheaderVPBB->setName("vector.ph"); + for (auto &I : *ThePreheaderBB) { if (I.getType()->isVoidTy()) continue; - VPValue *VPV = new VPValue(&I); - Plan.addExternalDef(VPV); - IRDef2VPValue[&I] = VPV; + IRDef2VPValue[&I] = Plan.getOrAddExternalDef(&I); } // Create empty VPBB for Loop H so that we can link PH->H. VPBlockBase *HeaderVPBB = getOrCreateVPBB(TheLoop->getHeader()); - // Preheader's predecessors will be set during the loop RPO traversal below. - PreheaderVPBB->setOneSuccessor(HeaderVPBB); + HeaderVPBB->setName("vector.body"); + ThePreheaderVPBB->setOneSuccessor(HeaderVPBB); LoopBlocksRPO RPO(TheLoop); RPO.perform(LI); @@ -295,16 +307,13 @@ VPRegionBlock *PlainCFGBuilder::buildPlainCFG() { // Get VPBB's condition bit. assert(isa<BranchInst>(TI) && "Unsupported terminator!"); - auto *Br = cast<BranchInst>(TI); - Value *BrCond = Br->getCondition(); // Look up the branch condition to get the corresponding VPValue // representing the condition bit in VPlan (which may be in another VPBB). - assert(IRDef2VPValue.count(BrCond) && + assert(IRDef2VPValue.count(cast<BranchInst>(TI)->getCondition()) && "Missing condition bit in IRDef2VPValue!"); - VPValue *VPCondBit = IRDef2VPValue[BrCond]; - // Link successors using condition bit. - VPBB->setTwoSuccessors(SuccVPBB0, SuccVPBB1, VPCondBit); + // Link successors. + VPBB->setTwoSuccessors(SuccVPBB0, SuccVPBB1); } else llvm_unreachable("Number of successors not supported."); @@ -312,30 +321,61 @@ VPRegionBlock *PlainCFGBuilder::buildPlainCFG() { setVPBBPredsFromBB(VPBB, BB); } - // 3. Process outermost loop exit. We created an empty VPBB for the loop + // 2. Process outermost loop exit. We created an empty VPBB for the loop // single exit BB during the RPO traversal of the loop body but Instructions // weren't visited because it's not part of the the loop. BasicBlock *LoopExitBB = TheLoop->getUniqueExitBlock(); assert(LoopExitBB && "Loops with multiple exits are not supported."); VPBasicBlock *LoopExitVPBB = BB2VPBB[LoopExitBB]; - createVPInstructionsForVPBB(LoopExitVPBB, LoopExitBB); // Loop exit was already set as successor of the loop exiting BB. // We only set its predecessor VPBB now. setVPBBPredsFromBB(LoopExitVPBB, LoopExitBB); + // 3. Fix up region blocks for loops. 
For each loop, + // * use the header block as entry to the corresponding region, + // * use the latch block as exit of the corresponding region, + // * set the region as successor of the loop pre-header, and + // * set the exit block as successor to the region. + SmallVector<Loop *> LoopWorkList; + LoopWorkList.push_back(TheLoop); + while (!LoopWorkList.empty()) { + Loop *L = LoopWorkList.pop_back_val(); + BasicBlock *Header = L->getHeader(); + BasicBlock *Exiting = L->getLoopLatch(); + assert(Exiting == L->getExitingBlock() && + "Latch must be the only exiting block"); + VPRegionBlock *Region = Loop2Region[L]; + VPBasicBlock *HeaderVPBB = getOrCreateVPBB(Header); + VPBasicBlock *ExitingVPBB = getOrCreateVPBB(Exiting); + + // Disconnect backedge and pre-header from header. + VPBasicBlock *PreheaderVPBB = getOrCreateVPBB(L->getLoopPreheader()); + VPBlockUtils::disconnectBlocks(PreheaderVPBB, HeaderVPBB); + VPBlockUtils::disconnectBlocks(ExitingVPBB, HeaderVPBB); + + Region->setParent(PreheaderVPBB->getParent()); + Region->setEntry(HeaderVPBB); + VPBlockUtils::connectBlocks(PreheaderVPBB, Region); + + // Disconnect exit block from exiting (=latch) block, set exiting block and + // connect region to exit block. + VPBasicBlock *ExitVPBB = getOrCreateVPBB(L->getExitBlock()); + VPBlockUtils::disconnectBlocks(ExitingVPBB, ExitVPBB); + Region->setExiting(ExitingVPBB); + VPBlockUtils::connectBlocks(Region, ExitVPBB); + + // Queue sub-loops for processing. + LoopWorkList.append(L->begin(), L->end()); + } // 4. The whole CFG has been built at this point so all the input Values must // have a VPlan couterpart. Fix VPlan phi nodes by adding their corresponding // VPlan operands. fixPhiNodes(); - // 5. Final Top Region setup. Set outermost loop pre-header and single exit as - // Top Region entry and exit. - TopRegion->setEntry(PreheaderVPBB); - TopRegion->setExit(LoopExitVPBB); - return TopRegion; + return ThePreheaderVPBB; } -VPRegionBlock *VPlanHCFGBuilder::buildPlainCFG() { +VPBasicBlock *VPlanHCFGBuilder::buildPlainCFG() { PlainCFGBuilder PCFGBuilder(TheLoop, LI, Plan); return PCFGBuilder.buildPlainCFG(); } @@ -343,20 +383,15 @@ VPRegionBlock *VPlanHCFGBuilder::buildPlainCFG() { // Public interface to build a H-CFG. void VPlanHCFGBuilder::buildHierarchicalCFG() { // Build Top Region enclosing the plain CFG and set it as VPlan entry. - VPRegionBlock *TopRegion = buildPlainCFG(); - Plan.setEntry(TopRegion); + VPBasicBlock *EntryVPBB = buildPlainCFG(); + Plan.setEntry(EntryVPBB); LLVM_DEBUG(Plan.setName("HCFGBuilder: Plain CFG\n"); dbgs() << Plan); + VPRegionBlock *TopRegion = Plan.getVectorLoopRegion(); Verifier.verifyHierarchicalCFG(TopRegion); // Compute plain CFG dom tree for VPLInfo. VPDomTree.recalculate(*TopRegion); LLVM_DEBUG(dbgs() << "Dominator Tree after building the plain CFG.\n"; VPDomTree.print(dbgs())); - - // Compute VPLInfo and keep it in Plan. 
- VPLoopInfo &VPLInfo = Plan.getVPLoopInfo(); - VPLInfo.analyze(VPDomTree); - LLVM_DEBUG(dbgs() << "VPLoop Info After buildPlainCFG:\n"; - VPLInfo.print(dbgs())); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h index 238ee7e6347c..2d52990af268 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h @@ -24,13 +24,15 @@ #ifndef LLVM_TRANSFORMS_VECTORIZE_VPLAN_VPLANHCFGBUILDER_H #define LLVM_TRANSFORMS_VECTORIZE_VPLAN_VPLANHCFGBUILDER_H -#include "VPlan.h" #include "VPlanDominatorTree.h" #include "VPlanVerifier.h" namespace llvm { class Loop; +class LoopInfo; +class VPRegionBlock; +class VPlan; class VPlanTestBase; /// Main class to build the VPlan H-CFG for an incoming IR. @@ -55,9 +57,9 @@ private: // are introduced. VPDominatorTree VPDomTree; - /// Build plain CFG for TheLoop. Return a new VPRegionBlock (TopRegion) - /// enclosing the plain CFG. - VPRegionBlock *buildPlainCFG(); + /// Build plain CFG for TheLoop. Return the pre-header VPBasicBlock connected + /// to a new VPRegionBlock (TopRegion) enclosing the plain CFG. + VPBasicBlock *buildPlainCFG(); public: VPlanHCFGBuilder(Loop *Lp, LoopInfo *LI, VPlan &P) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanLoopInfo.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanLoopInfo.h deleted file mode 100644 index 5208f2d58e2b..000000000000 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanLoopInfo.h +++ /dev/null @@ -1,44 +0,0 @@ -//===-- VPLoopInfo.h --------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines VPLoopInfo analysis and VPLoop class. VPLoopInfo is a -/// specialization of LoopInfoBase for VPBlockBase. VPLoops is a specialization -/// of LoopBase that is used to hold loop metadata from VPLoopInfo. Further -/// information can be found in VectorizationPlanner.rst. -/// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLOOPINFO_H -#define LLVM_TRANSFORMS_VECTORIZE_VPLOOPINFO_H - -#include "llvm/Analysis/LoopInfoImpl.h" - -namespace llvm { -class VPBlockBase; - -/// Hold analysis information for every loop detected by VPLoopInfo. It is an -/// instantiation of LoopBase. -class VPLoop : public LoopBase<VPBlockBase, VPLoop> { -private: - friend class LoopInfoBase<VPBlockBase, VPLoop>; - explicit VPLoop(VPBlockBase *VPB) : LoopBase<VPBlockBase, VPLoop>(VPB) {} -}; - -/// VPLoopInfo provides analysis of natural loop for VPBlockBase-based -/// Hierarchical CFG. It is a specialization of LoopInfoBase class. -// TODO: VPLoopInfo is initially computed on top of the VPlan plain CFG, which -// is the same as the incoming IR CFG. If it's more efficient than running the -// whole loop detection algorithm, we may want to create a mechanism to -// translate LoopInfo into VPLoopInfo. However, that would require significant -// changes in LoopInfoBase class. 
-typedef LoopInfoBase<VPBlockBase, VPLoop> VPLoopInfo; - -} // namespace llvm - -#endif // LLVM_TRANSFORMS_VECTORIZE_VPLOOPINFO_H diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp deleted file mode 100644 index e879a33db6ee..000000000000 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp +++ /dev/null @@ -1,248 +0,0 @@ -//===-- VPlanPredicator.cpp -------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements the VPlanPredicator class which contains the public -/// interfaces to predicate and linearize the VPlan region. -/// -//===----------------------------------------------------------------------===// - -#include "VPlanPredicator.h" -#include "VPlan.h" -#include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/GraphTraits.h" -#include "llvm/ADT/PostOrderIterator.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" - -#define DEBUG_TYPE "VPlanPredicator" - -using namespace llvm; - -// Generate VPInstructions at the beginning of CurrBB that calculate the -// predicate being propagated from PredBB to CurrBB depending on the edge type -// between them. For example if: -// i. PredBB is controlled by predicate %BP, and -// ii. The edge PredBB->CurrBB is the false edge, controlled by the condition -// bit value %CBV then this function will generate the following two -// VPInstructions at the start of CurrBB: -// %IntermediateVal = not %CBV -// %FinalVal = and %BP %IntermediateVal -// It returns %FinalVal. -VPValue *VPlanPredicator::getOrCreateNotPredicate(VPBasicBlock *PredBB, - VPBasicBlock *CurrBB) { - VPValue *CBV = PredBB->getCondBit(); - - // Set the intermediate value - this is either 'CBV', or 'not CBV' - // depending on the edge type. - EdgeType ET = getEdgeTypeBetween(PredBB, CurrBB); - VPValue *IntermediateVal = nullptr; - switch (ET) { - case EdgeType::TRUE_EDGE: - // CurrBB is the true successor of PredBB - nothing to do here. - IntermediateVal = CBV; - break; - - case EdgeType::FALSE_EDGE: - // CurrBB is the False successor of PredBB - compute not of CBV. - IntermediateVal = Builder.createNot(CBV, {}); - break; - } - - // Now AND intermediate value with PredBB's block predicate if it has one. - VPValue *BP = PredBB->getPredicate(); - if (BP) - return Builder.createAnd(BP, IntermediateVal, {}); - else - return IntermediateVal; -} - -// Generate a tree of ORs for all IncomingPredicates in WorkList. -// Note: This function destroys the original Worklist. -// -// P1 P2 P3 P4 P5 -// \ / \ / / -// OR1 OR2 / -// \ | / -// \ +/-+ -// \ / | -// OR3 | -// \ | -// OR4 <- Returns this -// | -// -// The algorithm uses a worklist of predicates as its main data structure. -// We pop a pair of values from the front (e.g. P1 and P2), generate an OR -// (in this example OR1), and push it back. In this example the worklist -// contains {P3, P4, P5, OR1}. -// The process iterates until we have only one element in the Worklist (OR4). -// The last element is the root predicate which is returned. 
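The deleted genPredicateTree comment above describes a plain worklist reduction: pop two predicates from the front, OR them, push the result to the back, and repeat until one root remains. As a standalone illustration only (plain std::string stands in for VPValue so it compiles outside LLVM; this is not the removed LLVM code itself), the same pattern looks like this:

#include <cassert>
#include <list>
#include <string>

// Pop pairs from the front, OR them, push the result to the back, until a
// single root predicate remains - mirroring the P1..P5 example above.
std::string genPredicateTree(std::list<std::string> Worklist) {
  if (Worklist.empty())
    return "";
  while (Worklist.size() >= 2) {
    std::string LHS = Worklist.front();
    Worklist.pop_front();
    std::string RHS = Worklist.front();
    Worklist.pop_front();
    Worklist.push_back("(" + LHS + " | " + RHS + ")");
  }
  return Worklist.front();
}

int main() {
  // OR1 = P1|P2, OR2 = P3|P4, OR3 = P5|OR1, OR4 = OR2|OR3 is the root.
  std::string Root = genPredicateTree({"P1", "P2", "P3", "P4", "P5"});
  assert(Root == "((P3 | P4) | (P5 | (P1 | P2)))");
  return 0;
}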
-VPValue *VPlanPredicator::genPredicateTree(std::list<VPValue *> &Worklist) { - if (Worklist.empty()) - return nullptr; - - // The worklist initially contains all the leaf nodes. Initialize the tree - // using them. - while (Worklist.size() >= 2) { - // Pop a pair of values from the front. - VPValue *LHS = Worklist.front(); - Worklist.pop_front(); - VPValue *RHS = Worklist.front(); - Worklist.pop_front(); - - // Create an OR of these values. - VPValue *Or = Builder.createOr(LHS, RHS, {}); - - // Push OR to the back of the worklist. - Worklist.push_back(Or); - } - - assert(Worklist.size() == 1 && "Expected 1 item in worklist"); - - // The root is the last node in the worklist. - VPValue *Root = Worklist.front(); - - // This root needs to replace the existing block predicate. This is done in - // the caller function. - return Root; -} - -// Return whether the edge FromBlock -> ToBlock is a TRUE_EDGE or FALSE_EDGE -VPlanPredicator::EdgeType -VPlanPredicator::getEdgeTypeBetween(VPBlockBase *FromBlock, - VPBlockBase *ToBlock) { - unsigned Count = 0; - for (VPBlockBase *SuccBlock : FromBlock->getSuccessors()) { - if (SuccBlock == ToBlock) { - assert(Count < 2 && "Switch not supported currently"); - return (Count == 0) ? EdgeType::TRUE_EDGE : EdgeType::FALSE_EDGE; - } - Count++; - } - - llvm_unreachable("Broken getEdgeTypeBetween"); -} - -// Generate all predicates needed for CurrBlock by going through its immediate -// predecessor blocks. -void VPlanPredicator::createOrPropagatePredicates(VPBlockBase *CurrBlock, - VPRegionBlock *Region) { - // Blocks that dominate region exit inherit the predicate from the region. - // Return after setting the predicate. - if (VPDomTree.dominates(CurrBlock, Region->getExit())) { - VPValue *RegionBP = Region->getPredicate(); - CurrBlock->setPredicate(RegionBP); - return; - } - - // Collect all incoming predicates in a worklist. - std::list<VPValue *> IncomingPredicates; - - // Set the builder's insertion point to the top of the current BB - VPBasicBlock *CurrBB = cast<VPBasicBlock>(CurrBlock->getEntryBasicBlock()); - Builder.setInsertPoint(CurrBB, CurrBB->begin()); - - // For each predecessor, generate the VPInstructions required for - // computing 'BP AND (not) CBV" at the top of CurrBB. - // Collect the outcome of this calculation for all predecessors - // into IncomingPredicates. - for (VPBlockBase *PredBlock : CurrBlock->getPredecessors()) { - // Skip back-edges - if (VPBlockUtils::isBackEdge(PredBlock, CurrBlock, VPLI)) - continue; - - VPValue *IncomingPredicate = nullptr; - unsigned NumPredSuccsNoBE = - VPBlockUtils::countSuccessorsNoBE(PredBlock, VPLI); - - // If there is an unconditional branch to the currBB, then we don't create - // edge predicates. We use the predecessor's block predicate instead. - if (NumPredSuccsNoBE == 1) - IncomingPredicate = PredBlock->getPredicate(); - else if (NumPredSuccsNoBE == 2) { - // Emit recipes into CurrBlock if required - assert(isa<VPBasicBlock>(PredBlock) && "Only BBs have multiple exits"); - IncomingPredicate = - getOrCreateNotPredicate(cast<VPBasicBlock>(PredBlock), CurrBB); - } else - llvm_unreachable("FIXME: switch statement ?"); - - if (IncomingPredicate) - IncomingPredicates.push_back(IncomingPredicate); - } - - // Logically OR all incoming predicates by building the Predicate Tree. - VPValue *Predicate = genPredicateTree(IncomingPredicates); - - // Now update the block's predicate with the new one. - CurrBlock->setPredicate(Predicate); -} - -// Generate all predicates needed for Region. 
-void VPlanPredicator::predicateRegionRec(VPRegionBlock *Region) { - VPBasicBlock *EntryBlock = cast<VPBasicBlock>(Region->getEntry()); - ReversePostOrderTraversal<VPBlockBase *> RPOT(EntryBlock); - - // Generate edge predicates and append them to the block predicate. RPO is - // necessary since the predecessor blocks' block predicate needs to be set - // before the current block's block predicate can be computed. - for (VPBlockBase *Block : RPOT) { - // TODO: Handle nested regions once we start generating the same. - assert(!isa<VPRegionBlock>(Block) && "Nested region not expected"); - createOrPropagatePredicates(Block, Region); - } -} - -// Linearize the CFG within Region. -// TODO: Predication and linearization need RPOT for every region. -// This traversal is expensive. Since predication is not adding new -// blocks, we should be able to compute RPOT once in predication and -// reuse it here. This becomes even more important once we have nested -// regions. -void VPlanPredicator::linearizeRegionRec(VPRegionBlock *Region) { - ReversePostOrderTraversal<VPBlockBase *> RPOT(Region->getEntry()); - VPBlockBase *PrevBlock = nullptr; - - for (VPBlockBase *CurrBlock : RPOT) { - // TODO: Handle nested regions once we start generating the same. - assert(!isa<VPRegionBlock>(CurrBlock) && "Nested region not expected"); - - // Linearize control flow by adding an unconditional edge between PrevBlock - // and CurrBlock skipping loop headers and latches to keep intact loop - // header predecessors and loop latch successors. - if (PrevBlock && !VPLI->isLoopHeader(CurrBlock) && - !VPBlockUtils::blockIsLoopLatch(PrevBlock, VPLI)) { - - LLVM_DEBUG(dbgs() << "Linearizing: " << PrevBlock->getName() << "->" - << CurrBlock->getName() << "\n"); - - PrevBlock->clearSuccessors(); - CurrBlock->clearPredecessors(); - VPBlockUtils::connectBlocks(PrevBlock, CurrBlock); - } - - PrevBlock = CurrBlock; - } -} - -// Entry point. The driver function for the predicator. -void VPlanPredicator::predicate() { - // Predicate the blocks within Region. - predicateRegionRec(cast<VPRegionBlock>(Plan.getEntry())); - - // Linearlize the blocks with Region. - linearizeRegionRec(cast<VPRegionBlock>(Plan.getEntry())); -} - -VPlanPredicator::VPlanPredicator(VPlan &Plan) - : Plan(Plan), VPLI(&(Plan.getVPLoopInfo())) { - // FIXME: Predicator is currently computing the dominator information for the - // top region. Once we start storing dominator information in a VPRegionBlock, - // we can avoid this recalculation. - VPDomTree.recalculate(*(cast<VPRegionBlock>(Plan.getEntry()))); -} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanPredicator.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanPredicator.h deleted file mode 100644 index a5db9a54da3c..000000000000 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanPredicator.h +++ /dev/null @@ -1,74 +0,0 @@ -//===-- VPlanPredicator.h ---------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines the VPlanPredicator class which contains the public -/// interfaces to predicate and linearize the VPlan region. 
-/// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLAN_PREDICATOR_H -#define LLVM_TRANSFORMS_VECTORIZE_VPLAN_PREDICATOR_H - -#include "LoopVectorizationPlanner.h" -#include "VPlan.h" -#include "VPlanDominatorTree.h" - -namespace llvm { - -class VPlanPredicator { -private: - enum class EdgeType { - TRUE_EDGE, - FALSE_EDGE, - }; - - // VPlan being predicated. - VPlan &Plan; - - // VPLoopInfo for Plan's HCFG. - VPLoopInfo *VPLI; - - // Dominator tree for Plan's HCFG. - VPDominatorTree VPDomTree; - - // VPlan builder used to generate VPInstructions for block predicates. - VPBuilder Builder; - - /// Get the type of edge from \p FromBlock to \p ToBlock. Returns TRUE_EDGE if - /// \p ToBlock is either the unconditional successor or the conditional true - /// successor of \p FromBlock and FALSE_EDGE otherwise. - EdgeType getEdgeTypeBetween(VPBlockBase *FromBlock, VPBlockBase *ToBlock); - - /// Create and return VPValue corresponding to the predicate for the edge from - /// \p PredBB to \p CurrentBlock. - VPValue *getOrCreateNotPredicate(VPBasicBlock *PredBB, VPBasicBlock *CurrBB); - - /// Generate and return the result of ORing all the predicate VPValues in \p - /// Worklist. - VPValue *genPredicateTree(std::list<VPValue *> &Worklist); - - /// Create or propagate predicate for \p CurrBlock in region \p Region using - /// predicate(s) of its predecessor(s) - void createOrPropagatePredicates(VPBlockBase *CurrBlock, - VPRegionBlock *Region); - - /// Predicate the CFG within \p Region. - void predicateRegionRec(VPRegionBlock *Region); - - /// Linearize the CFG within \p Region. - void linearizeRegionRec(VPRegionBlock *Region); - -public: - VPlanPredicator(VPlan &Plan); - - /// Predicate Plan's HCFG. - void predicate(); -}; -} // end namespace llvm -#endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_PREDICATOR_H diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp new file mode 100644 index 000000000000..92422b17457c --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -0,0 +1,840 @@ +//===- VPlanRecipes.cpp - Implementations for VPlan recipes ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains implementations for different VPlan recipes. 
+/// +//===----------------------------------------------------------------------===// + +#include "VPlan.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/IVDescriptors.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" +#include <cassert> + +using namespace llvm; + +extern cl::opt<bool> EnableVPlanNativePath; + +bool VPRecipeBase::mayWriteToMemory() const { + switch (getVPDefID()) { + case VPWidenMemoryInstructionSC: { + return cast<VPWidenMemoryInstructionRecipe>(this)->isStore(); + } + case VPReplicateSC: + case VPWidenCallSC: + return cast<Instruction>(getVPSingleValue()->getUnderlyingValue()) + ->mayWriteToMemory(); + case VPBranchOnMaskSC: + return false; + case VPWidenIntOrFpInductionSC: + case VPWidenCanonicalIVSC: + case VPWidenPHISC: + case VPBlendSC: + case VPWidenSC: + case VPWidenGEPSC: + case VPReductionSC: + case VPWidenSelectSC: { + const Instruction *I = + dyn_cast_or_null<Instruction>(getVPSingleValue()->getUnderlyingValue()); + (void)I; + assert((!I || !I->mayWriteToMemory()) && + "underlying instruction may write to memory"); + return false; + } + default: + return true; + } +} + +bool VPRecipeBase::mayReadFromMemory() const { + switch (getVPDefID()) { + case VPWidenMemoryInstructionSC: { + return !cast<VPWidenMemoryInstructionRecipe>(this)->isStore(); + } + case VPReplicateSC: + case VPWidenCallSC: + return cast<Instruction>(getVPSingleValue()->getUnderlyingValue()) + ->mayReadFromMemory(); + case VPBranchOnMaskSC: + return false; + case VPWidenIntOrFpInductionSC: + case VPWidenCanonicalIVSC: + case VPWidenPHISC: + case VPBlendSC: + case VPWidenSC: + case VPWidenGEPSC: + case VPReductionSC: + case VPWidenSelectSC: { + const Instruction *I = + dyn_cast_or_null<Instruction>(getVPSingleValue()->getUnderlyingValue()); + (void)I; + assert((!I || !I->mayReadFromMemory()) && + "underlying instruction may read from memory"); + return false; + } + default: + return true; + } +} + +bool VPRecipeBase::mayHaveSideEffects() const { + switch (getVPDefID()) { + case VPWidenIntOrFpInductionSC: + case VPWidenPointerInductionSC: + case VPWidenCanonicalIVSC: + case VPWidenPHISC: + case VPBlendSC: + case VPWidenSC: + case VPWidenGEPSC: + case VPReductionSC: + case VPWidenSelectSC: + case VPScalarIVStepsSC: { + const Instruction *I = + dyn_cast_or_null<Instruction>(getVPSingleValue()->getUnderlyingValue()); + (void)I; + assert((!I || !I->mayHaveSideEffects()) && + "underlying instruction has side-effects"); + return false; + } + case VPReplicateSC: { + auto *R = cast<VPReplicateRecipe>(this); + return R->getUnderlyingInstr()->mayHaveSideEffects(); + } + default: + return true; + } +} + +void VPLiveOut::fixPhi(VPlan &Plan, VPTransformState &State) { + auto Lane = VPLane::getLastLaneForVF(State.VF); + VPValue *ExitValue = getOperand(0); + if (Plan.isUniformAfterVectorization(ExitValue)) + Lane = VPLane::getFirstLane(); + Phi->addIncoming(State.get(ExitValue, VPIteration(State.UF - 1, Lane)), + State.Builder.GetInsertBlock()); +} + +void VPRecipeBase::insertBefore(VPRecipeBase *InsertPos) { + assert(!Parent && "Recipe already in some VPBasicBlock"); + 
assert(InsertPos->getParent() && + "Insertion position not in any VPBasicBlock"); + Parent = InsertPos->getParent(); + Parent->getRecipeList().insert(InsertPos->getIterator(), this); +} + +void VPRecipeBase::insertBefore(VPBasicBlock &BB, + iplist<VPRecipeBase>::iterator I) { + assert(!Parent && "Recipe already in some VPBasicBlock"); + assert(I == BB.end() || I->getParent() == &BB); + Parent = &BB; + BB.getRecipeList().insert(I, this); +} + +void VPRecipeBase::insertAfter(VPRecipeBase *InsertPos) { + assert(!Parent && "Recipe already in some VPBasicBlock"); + assert(InsertPos->getParent() && + "Insertion position not in any VPBasicBlock"); + Parent = InsertPos->getParent(); + Parent->getRecipeList().insertAfter(InsertPos->getIterator(), this); +} + +void VPRecipeBase::removeFromParent() { + assert(getParent() && "Recipe not in any VPBasicBlock"); + getParent()->getRecipeList().remove(getIterator()); + Parent = nullptr; +} + +iplist<VPRecipeBase>::iterator VPRecipeBase::eraseFromParent() { + assert(getParent() && "Recipe not in any VPBasicBlock"); + return getParent()->getRecipeList().erase(getIterator()); +} + +void VPRecipeBase::moveAfter(VPRecipeBase *InsertPos) { + removeFromParent(); + insertAfter(InsertPos); +} + +void VPRecipeBase::moveBefore(VPBasicBlock &BB, + iplist<VPRecipeBase>::iterator I) { + removeFromParent(); + insertBefore(BB, I); +} + +void VPInstruction::generateInstruction(VPTransformState &State, + unsigned Part) { + IRBuilderBase &Builder = State.Builder; + Builder.SetCurrentDebugLocation(DL); + + if (Instruction::isBinaryOp(getOpcode())) { + Value *A = State.get(getOperand(0), Part); + Value *B = State.get(getOperand(1), Part); + Value *V = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B); + State.set(this, V, Part); + return; + } + + switch (getOpcode()) { + case VPInstruction::Not: { + Value *A = State.get(getOperand(0), Part); + Value *V = Builder.CreateNot(A); + State.set(this, V, Part); + break; + } + case VPInstruction::ICmpULE: { + Value *IV = State.get(getOperand(0), Part); + Value *TC = State.get(getOperand(1), Part); + Value *V = Builder.CreateICmpULE(IV, TC); + State.set(this, V, Part); + break; + } + case Instruction::Select: { + Value *Cond = State.get(getOperand(0), Part); + Value *Op1 = State.get(getOperand(1), Part); + Value *Op2 = State.get(getOperand(2), Part); + Value *V = Builder.CreateSelect(Cond, Op1, Op2); + State.set(this, V, Part); + break; + } + case VPInstruction::ActiveLaneMask: { + // Get first lane of vector induction variable. + Value *VIVElem0 = State.get(getOperand(0), VPIteration(Part, 0)); + // Get the original loop tripcount. + Value *ScalarTC = State.get(getOperand(1), Part); + + auto *Int1Ty = Type::getInt1Ty(Builder.getContext()); + auto *PredTy = VectorType::get(Int1Ty, State.VF); + Instruction *Call = Builder.CreateIntrinsic( + Intrinsic::get_active_lane_mask, {PredTy, ScalarTC->getType()}, + {VIVElem0, ScalarTC}, nullptr, "active.lane.mask"); + State.set(this, Call, Part); + break; + } + case VPInstruction::FirstOrderRecurrenceSplice: { + // Generate code to combine the previous and current values in vector v3. + // + // vector.ph: + // v_init = vector(..., ..., ..., a[-1]) + // br vector.body + // + // vector.body + // i = phi [0, vector.ph], [i+4, vector.body] + // v1 = phi [v_init, vector.ph], [v2, vector.body] + // v2 = a[i, i+1, i+2, i+3]; + // v3 = vector(v1(3), v2(0, 1, 2)) + + // For the first part, use the recurrence phi (v1), otherwise v2. 
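As a rough illustration of the splice the comment above describes (not part of this diff, and side-stepping IRBuilder entirely), the -1 splice of two consecutive 4-lane parts can be modeled on plain arrays, mirroring the v3 = (v1(3), v2(0, 1, 2)) example.

// Sketch only: last lane of the previous part followed by the first N-1
// lanes of the current part.
#include <array>
#include <cstddef>
#include <cstdio>

template <std::size_t N>
std::array<int, N> spliceMinusOne(const std::array<int, N> &Prev,
                                  const std::array<int, N> &Cur) {
  std::array<int, N> Out{};
  Out[0] = Prev[N - 1];            // last lane of the previous part
  for (std::size_t I = 1; I < N; ++I)
    Out[I] = Cur[I - 1];           // first N-1 lanes of the current part
  return Out;
}

int main() {
  std::array<int, 4> V1 = {0, 1, 2, 3}; // previous part, ending in a[-1+4]
  std::array<int, 4> V2 = {4, 5, 6, 7}; // current part
  for (int X : spliceMinusOne(V1, V2))
    std::printf("%d ", X);              // prints: 3 4 5 6
  std::printf("\n");
  return 0;
}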
+ auto *V1 = State.get(getOperand(0), 0); + Value *PartMinus1 = Part == 0 ? V1 : State.get(getOperand(1), Part - 1); + if (!PartMinus1->getType()->isVectorTy()) { + State.set(this, PartMinus1, Part); + } else { + Value *V2 = State.get(getOperand(1), Part); + State.set(this, Builder.CreateVectorSplice(PartMinus1, V2, -1), Part); + } + break; + } + case VPInstruction::CanonicalIVIncrement: + case VPInstruction::CanonicalIVIncrementNUW: { + Value *Next = nullptr; + if (Part == 0) { + bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementNUW; + auto *Phi = State.get(getOperand(0), 0); + // The loop step is equal to the vectorization factor (num of SIMD + // elements) times the unroll factor (num of SIMD instructions). + Value *Step = + createStepForVF(Builder, Phi->getType(), State.VF, State.UF); + Next = Builder.CreateAdd(Phi, Step, "index.next", IsNUW, false); + } else { + Next = State.get(this, 0); + } + + State.set(this, Next, Part); + break; + } + case VPInstruction::BranchOnCond: { + if (Part != 0) + break; + + Value *Cond = State.get(getOperand(0), VPIteration(Part, 0)); + VPRegionBlock *ParentRegion = getParent()->getParent(); + VPBasicBlock *Header = ParentRegion->getEntryBasicBlock(); + + // Replace the temporary unreachable terminator with a new conditional + // branch, hooking it up to backward destination for exiting blocks now and + // to forward destination(s) later when they are created. + BranchInst *CondBr = + Builder.CreateCondBr(Cond, Builder.GetInsertBlock(), nullptr); + + if (getParent()->isExiting()) + CondBr->setSuccessor(1, State.CFG.VPBB2IRBB[Header]); + + CondBr->setSuccessor(0, nullptr); + Builder.GetInsertBlock()->getTerminator()->eraseFromParent(); + break; + } + case VPInstruction::BranchOnCount: { + if (Part != 0) + break; + // First create the compare. + Value *IV = State.get(getOperand(0), Part); + Value *TC = State.get(getOperand(1), Part); + Value *Cond = Builder.CreateICmpEQ(IV, TC); + + // Now create the branch. + auto *Plan = getParent()->getPlan(); + VPRegionBlock *TopRegion = Plan->getVectorLoopRegion(); + VPBasicBlock *Header = TopRegion->getEntry()->getEntryBasicBlock(); + + // Replace the temporary unreachable terminator with a new conditional + // branch, hooking it up to backward destination (the header) now and to the + // forward destination (the exit/middle block) later when it is created. + // Note that CreateCondBr expects a valid BB as first argument, so we need + // to set it to nullptr later. 
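A scalar model of the latch that this CanonicalIVIncrement/BranchOnCount pair produces (illustrative only, not part of this diff; the step and trip count below are made up):

// Sketch only: increment the canonical IV by VF * UF, exit when it reaches
// the vector trip count, otherwise branch back to the header.
#include <cstdio>

int main() {
  const int TripCount = 16; // vector trip count
  const int Step = 4;       // VF * UF
  int IV = 0;
  do {
    std::printf("vector body at index %d\n", IV);
    IV += Step;                // CanonicalIVIncrement
  } while (IV != TripCount);   // BranchOnCount: icmp eq + conditional branch
  return 0;
}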
+ BranchInst *CondBr = Builder.CreateCondBr(Cond, Builder.GetInsertBlock(), + State.CFG.VPBB2IRBB[Header]); + CondBr->setSuccessor(0, nullptr); + Builder.GetInsertBlock()->getTerminator()->eraseFromParent(); + break; + } + default: + llvm_unreachable("Unsupported opcode for instruction"); + } +} + +void VPInstruction::execute(VPTransformState &State) { + assert(!State.Instance && "VPInstruction executing an Instance"); + IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder); + State.Builder.setFastMathFlags(FMF); + for (unsigned Part = 0; Part < State.UF; ++Part) + generateInstruction(State, Part); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPInstruction::dump() const { + VPSlotTracker SlotTracker(getParent()->getPlan()); + print(dbgs(), "", SlotTracker); +} + +void VPInstruction::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "EMIT "; + + if (hasResult()) { + printAsOperand(O, SlotTracker); + O << " = "; + } + + switch (getOpcode()) { + case VPInstruction::Not: + O << "not"; + break; + case VPInstruction::ICmpULE: + O << "icmp ule"; + break; + case VPInstruction::SLPLoad: + O << "combined load"; + break; + case VPInstruction::SLPStore: + O << "combined store"; + break; + case VPInstruction::ActiveLaneMask: + O << "active lane mask"; + break; + case VPInstruction::FirstOrderRecurrenceSplice: + O << "first-order splice"; + break; + case VPInstruction::CanonicalIVIncrement: + O << "VF * UF + "; + break; + case VPInstruction::CanonicalIVIncrementNUW: + O << "VF * UF +(nuw) "; + break; + case VPInstruction::BranchOnCond: + O << "branch-on-cond"; + break; + case VPInstruction::BranchOnCount: + O << "branch-on-count "; + break; + default: + O << Instruction::getOpcodeName(getOpcode()); + } + + O << FMF; + + for (const VPValue *Operand : operands()) { + O << " "; + Operand->printAsOperand(O, SlotTracker); + } + + if (DL) { + O << ", !dbg "; + DL.print(O); + } +} +#endif + +void VPInstruction::setFastMathFlags(FastMathFlags FMFNew) { + // Make sure the VPInstruction is a floating-point operation. + assert((Opcode == Instruction::FAdd || Opcode == Instruction::FMul || + Opcode == Instruction::FNeg || Opcode == Instruction::FSub || + Opcode == Instruction::FDiv || Opcode == Instruction::FRem || + Opcode == Instruction::FCmp) && + "this op can't take fast-math flags"); + FMF = FMFNew; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPWidenCallRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "WIDEN-CALL "; + + auto *CI = cast<CallInst>(getUnderlyingInstr()); + if (CI->getType()->isVoidTy()) + O << "void "; + else { + printAsOperand(O, SlotTracker); + O << " = "; + } + + O << "call @" << CI->getCalledFunction()->getName() << "("; + printOperands(O, SlotTracker); + O << ")"; +} + +void VPWidenSelectRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "WIDEN-SELECT "; + printAsOperand(O, SlotTracker); + O << " = select "; + getOperand(0)->printAsOperand(O, SlotTracker); + O << ", "; + getOperand(1)->printAsOperand(O, SlotTracker); + O << ", "; + getOperand(2)->printAsOperand(O, SlotTracker); + O << (InvariantCond ? 
" (condition is loop invariant)" : ""); +} + +void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "WIDEN "; + printAsOperand(O, SlotTracker); + O << " = " << getUnderlyingInstr()->getOpcodeName() << " "; + printOperands(O, SlotTracker); +} + +void VPWidenIntOrFpInductionRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "WIDEN-INDUCTION"; + if (getTruncInst()) { + O << "\\l\""; + O << " +\n" << Indent << "\" " << VPlanIngredient(IV) << "\\l\""; + O << " +\n" << Indent << "\" "; + getVPValue(0)->printAsOperand(O, SlotTracker); + } else + O << " " << VPlanIngredient(IV); + + O << ", "; + getStepValue()->printAsOperand(O, SlotTracker); +} +#endif + +bool VPWidenIntOrFpInductionRecipe::isCanonical() const { + auto *StartC = dyn_cast<ConstantInt>(getStartValue()->getLiveInIRValue()); + auto *StepC = dyn_cast<SCEVConstant>(getInductionDescriptor().getStep()); + return StartC && StartC->isZero() && StepC && StepC->isOne(); +} + +VPCanonicalIVPHIRecipe *VPScalarIVStepsRecipe::getCanonicalIV() const { + return cast<VPCanonicalIVPHIRecipe>(getOperand(0)); +} + +bool VPScalarIVStepsRecipe::isCanonical() const { + auto *CanIV = getCanonicalIV(); + // The start value of the steps-recipe must match the start value of the + // canonical induction and it must step by 1. + if (CanIV->getStartValue() != getStartValue()) + return false; + auto *StepVPV = getStepValue(); + if (StepVPV->getDef()) + return false; + auto *StepC = dyn_cast_or_null<ConstantInt>(StepVPV->getLiveInIRValue()); + return StepC && StepC->isOne(); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPScalarIVStepsRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent; + printAsOperand(O, SlotTracker); + O << Indent << "= SCALAR-STEPS "; + printOperands(O, SlotTracker); +} + +void VPWidenGEPRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "WIDEN-GEP "; + O << (IsPtrLoopInvariant ? "Inv" : "Var"); + size_t IndicesNumber = IsIndexLoopInvariant.size(); + for (size_t I = 0; I < IndicesNumber; ++I) + O << "[" << (IsIndexLoopInvariant[I] ? "Inv" : "Var") << "]"; + + O << " "; + printAsOperand(O, SlotTracker); + O << " = getelementptr "; + printOperands(O, SlotTracker); +} + +void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "BLEND "; + Phi->printAsOperand(O, false); + O << " ="; + if (getNumIncomingValues() == 1) { + // Not a User of any mask: not really blending, this is a + // single-predecessor phi. + O << " "; + getIncomingValue(0)->printAsOperand(O, SlotTracker); + } else { + for (unsigned I = 0, E = getNumIncomingValues(); I < E; ++I) { + O << " "; + getIncomingValue(I)->printAsOperand(O, SlotTracker); + O << "/"; + getMask(I)->printAsOperand(O, SlotTracker); + } + } +} + +void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "REDUCE "; + printAsOperand(O, SlotTracker); + O << " = "; + getChainOp()->printAsOperand(O, SlotTracker); + O << " +"; + if (isa<FPMathOperator>(getUnderlyingInstr())) + O << getUnderlyingInstr()->getFastMathFlags(); + O << " reduce." 
<< Instruction::getOpcodeName(RdxDesc->getOpcode()) << " ("; + getVecOp()->printAsOperand(O, SlotTracker); + if (getCondOp()) { + O << ", "; + getCondOp()->printAsOperand(O, SlotTracker); + } + O << ")"; + if (RdxDesc->IntermediateStore) + O << " (with final reduction value stored in invariant address sank " + "outside of loop)"; +} + +void VPReplicateRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << (IsUniform ? "CLONE " : "REPLICATE "); + + if (!getUnderlyingInstr()->getType()->isVoidTy()) { + printAsOperand(O, SlotTracker); + O << " = "; + } + if (auto *CB = dyn_cast<CallBase>(getUnderlyingInstr())) { + O << "call @" << CB->getCalledFunction()->getName() << "("; + interleaveComma(make_range(op_begin(), op_begin() + (getNumOperands() - 1)), + O, [&O, &SlotTracker](VPValue *Op) { + Op->printAsOperand(O, SlotTracker); + }); + O << ")"; + } else { + O << Instruction::getOpcodeName(getUnderlyingInstr()->getOpcode()) << " "; + printOperands(O, SlotTracker); + } + + if (AlsoPack) + O << " (S->V)"; +} + +void VPPredInstPHIRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "PHI-PREDICATED-INSTRUCTION "; + printAsOperand(O, SlotTracker); + O << " = "; + printOperands(O, SlotTracker); +} + +void VPWidenMemoryInstructionRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "WIDEN "; + + if (!isStore()) { + getVPSingleValue()->printAsOperand(O, SlotTracker); + O << " = "; + } + O << Instruction::getOpcodeName(Ingredient.getOpcode()) << " "; + + printOperands(O, SlotTracker); +} +#endif + +void VPCanonicalIVPHIRecipe::execute(VPTransformState &State) { + Value *Start = getStartValue()->getLiveInIRValue(); + PHINode *EntryPart = PHINode::Create( + Start->getType(), 2, "index", &*State.CFG.PrevBB->getFirstInsertionPt()); + + BasicBlock *VectorPH = State.CFG.getPreheaderBBFor(this); + EntryPart->addIncoming(Start, VectorPH); + EntryPart->setDebugLoc(DL); + for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part) + State.set(this, EntryPart, Part); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPCanonicalIVPHIRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "EMIT "; + printAsOperand(O, SlotTracker); + O << " = CANONICAL-INDUCTION"; +} +#endif + +bool VPWidenPointerInductionRecipe::onlyScalarsGenerated(ElementCount VF) { + bool IsUniform = vputils::onlyFirstLaneUsed(this); + return all_of(users(), + [&](const VPUser *U) { return U->usesScalars(this); }) && + (IsUniform || !VF.isScalable()); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPWidenPointerInductionRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "EMIT "; + printAsOperand(O, SlotTracker); + O << " = WIDEN-POINTER-INDUCTION "; + getStartValue()->printAsOperand(O, SlotTracker); + O << ", " << *IndDesc.getStep(); +} +#endif + +void VPExpandSCEVRecipe::execute(VPTransformState &State) { + assert(!State.Instance && "cannot be used in per-lane"); + const DataLayout &DL = State.CFG.PrevBB->getModule()->getDataLayout(); + SCEVExpander Exp(SE, DL, "induction"); + + Value *Res = Exp.expandCodeFor(Expr, Expr->getType(), + &*State.Builder.GetInsertPoint()); + + for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part) + State.set(this, Res, Part); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPExpandSCEVRecipe::print(raw_ostream &O, 
const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "EMIT "; + getVPSingleValue()->printAsOperand(O, SlotTracker); + O << " = EXPAND SCEV " << *Expr; +} +#endif + +void VPWidenCanonicalIVRecipe::execute(VPTransformState &State) { + Value *CanonicalIV = State.get(getOperand(0), 0); + Type *STy = CanonicalIV->getType(); + IRBuilder<> Builder(State.CFG.PrevBB->getTerminator()); + ElementCount VF = State.VF; + Value *VStart = VF.isScalar() + ? CanonicalIV + : Builder.CreateVectorSplat(VF, CanonicalIV, "broadcast"); + for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part) { + Value *VStep = createStepForVF(Builder, STy, VF, Part); + if (VF.isVector()) { + VStep = Builder.CreateVectorSplat(VF, VStep); + VStep = + Builder.CreateAdd(VStep, Builder.CreateStepVector(VStep->getType())); + } + Value *CanonicalVectorIV = Builder.CreateAdd(VStart, VStep, "vec.iv"); + State.set(this, CanonicalVectorIV, Part); + } +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPWidenCanonicalIVRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "EMIT "; + printAsOperand(O, SlotTracker); + O << " = WIDEN-CANONICAL-INDUCTION "; + printOperands(O, SlotTracker); +} +#endif + +void VPFirstOrderRecurrencePHIRecipe::execute(VPTransformState &State) { + auto &Builder = State.Builder; + // Create a vector from the initial value. + auto *VectorInit = getStartValue()->getLiveInIRValue(); + + Type *VecTy = State.VF.isScalar() + ? VectorInit->getType() + : VectorType::get(VectorInit->getType(), State.VF); + + BasicBlock *VectorPH = State.CFG.getPreheaderBBFor(this); + if (State.VF.isVector()) { + auto *IdxTy = Builder.getInt32Ty(); + auto *One = ConstantInt::get(IdxTy, 1); + IRBuilder<>::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(VectorPH->getTerminator()); + auto *RuntimeVF = getRuntimeVF(Builder, IdxTy, State.VF); + auto *LastIdx = Builder.CreateSub(RuntimeVF, One); + VectorInit = Builder.CreateInsertElement( + PoisonValue::get(VecTy), VectorInit, LastIdx, "vector.recur.init"); + } + + // Create a phi node for the new recurrence. + PHINode *EntryPart = PHINode::Create( + VecTy, 2, "vector.recur", &*State.CFG.PrevBB->getFirstInsertionPt()); + EntryPart->addIncoming(VectorInit, VectorPH); + State.set(this, EntryPart, 0); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPFirstOrderRecurrencePHIRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "FIRST-ORDER-RECURRENCE-PHI "; + printAsOperand(O, SlotTracker); + O << " = phi "; + printOperands(O, SlotTracker); +} +#endif + +void VPReductionPHIRecipe::execute(VPTransformState &State) { + PHINode *PN = cast<PHINode>(getUnderlyingValue()); + auto &Builder = State.Builder; + + // In order to support recurrences we need to be able to vectorize Phi nodes. + // Phi nodes have cycles, so we need to vectorize them in two stages. This is + // stage #1: We create a new vector PHI node with no incoming edges. We'll use + // this value when we vectorize all of the instructions that use the PHI. + bool ScalarPHI = State.VF.isScalar() || IsInLoop; + Type *VecTy = + ScalarPHI ? PN->getType() : VectorType::get(PN->getType(), State.VF); + + BasicBlock *HeaderBB = State.CFG.PrevBB; + assert(State.CurrentVectorLoop->getHeader() == HeaderBB && + "recipe must be in the vector loop header"); + unsigned LastPartForNewPhi = isOrdered() ? 
1 : State.UF; + for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) { + Value *EntryPart = + PHINode::Create(VecTy, 2, "vec.phi", &*HeaderBB->getFirstInsertionPt()); + State.set(this, EntryPart, Part); + } + + BasicBlock *VectorPH = State.CFG.getPreheaderBBFor(this); + + // Reductions do not have to start at zero. They can start with + // any loop invariant values. + VPValue *StartVPV = getStartValue(); + Value *StartV = StartVPV->getLiveInIRValue(); + + Value *Iden = nullptr; + RecurKind RK = RdxDesc.getRecurrenceKind(); + if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RK) || + RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK)) { + // MinMax reduction have the start value as their identify. + if (ScalarPHI) { + Iden = StartV; + } else { + IRBuilderBase::InsertPointGuard IPBuilder(Builder); + Builder.SetInsertPoint(VectorPH->getTerminator()); + StartV = Iden = + Builder.CreateVectorSplat(State.VF, StartV, "minmax.ident"); + } + } else { + Iden = RdxDesc.getRecurrenceIdentity(RK, VecTy->getScalarType(), + RdxDesc.getFastMathFlags()); + + if (!ScalarPHI) { + Iden = Builder.CreateVectorSplat(State.VF, Iden); + IRBuilderBase::InsertPointGuard IPBuilder(Builder); + Builder.SetInsertPoint(VectorPH->getTerminator()); + Constant *Zero = Builder.getInt32(0); + StartV = Builder.CreateInsertElement(Iden, StartV, Zero); + } + } + + for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) { + Value *EntryPart = State.get(this, Part); + // Make sure to add the reduction start value only to the + // first unroll part. + Value *StartVal = (Part == 0) ? StartV : Iden; + cast<PHINode>(EntryPart)->addIncoming(StartVal, VectorPH); + } +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPReductionPHIRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "WIDEN-REDUCTION-PHI "; + + printAsOperand(O, SlotTracker); + O << " = phi "; + printOperands(O, SlotTracker); +} +#endif + +void VPWidenPHIRecipe::execute(VPTransformState &State) { + assert(EnableVPlanNativePath && + "Non-native vplans are not expected to have VPWidenPHIRecipes."); + + // Currently we enter here in the VPlan-native path for non-induction + // PHIs where all control flow is uniform. We simply widen these PHIs. + // Create a vector phi with no operands - the vector phi operands will be + // set at the end of vector code generation. + VPBasicBlock *Parent = getParent(); + VPRegionBlock *LoopRegion = Parent->getEnclosingLoopRegion(); + unsigned StartIdx = 0; + // For phis in header blocks of loop regions, use the index of the value + // coming from the preheader. + if (LoopRegion->getEntryBasicBlock() == Parent) { + for (unsigned I = 0; I < getNumOperands(); ++I) { + if (getIncomingBlock(I) == + LoopRegion->getSinglePredecessor()->getExitingBasicBlock()) + StartIdx = I; + } + } + Value *Op0 = State.get(getOperand(StartIdx), 0); + Type *VecTy = Op0->getType(); + Value *VecPhi = State.Builder.CreatePHI(VecTy, 2, "vec.phi"); + State.set(this, VecPhi, 0); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPWidenPHIRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "WIDEN-PHI "; + + auto *OriginalPhi = cast<PHINode>(getUnderlyingValue()); + // Unless all incoming values are modeled in VPlan print the original PHI + // directly. + // TODO: Remove once all VPWidenPHIRecipe instances keep all relevant incoming + // values as VPValues. 
+ if (getNumOperands() != OriginalPhi->getNumOperands()) { + O << VPlanIngredient(OriginalPhi); + return; + } + + printAsOperand(O, SlotTracker); + O << " = phi "; + printOperands(O, SlotTracker); +} +#endif diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp index 9e19e172dea5..3a7e77fd9efd 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp @@ -15,16 +15,10 @@ //===----------------------------------------------------------------------===// #include "VPlan.h" -#include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/PostOrderIterator.h" +#include "VPlanValue.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/Twine.h" -#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/VectorUtils.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CFG.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Type.h" @@ -32,12 +26,9 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/GraphWriter.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include <algorithm> #include <cassert> -#include <iterator> #include <utility> using namespace llvm; @@ -396,7 +387,7 @@ VPInstruction *VPlanSlp::buildGraph(ArrayRef<VPValue *> Values) { return markFailed(); assert(getOpcode(Values) && "Opcodes for all values must match"); - unsigned ValuesOpcode = getOpcode(Values).getValue(); + unsigned ValuesOpcode = *getOpcode(Values); SmallVector<VPValue *, 4> CombinedOperands; if (areCommutative(Values)) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 70ce773a8a85..cca484e13bf1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -13,6 +13,8 @@ #include "VPlanTransforms.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Analysis/IVDescriptors.h" using namespace llvm; @@ -22,17 +24,15 @@ void VPlanTransforms::VPInstructionsToVPRecipes( GetIntOrFpInductionDescriptor, SmallPtrSetImpl<Instruction *> &DeadInstructions, ScalarEvolution &SE) { - auto *TopRegion = cast<VPRegionBlock>(Plan->getEntry()); - ReversePostOrderTraversal<VPBlockBase *> RPOT(TopRegion->getEntry()); - - for (VPBlockBase *Base : RPOT) { - // Do not widen instructions in pre-header and exit blocks. - if (Base->getNumPredecessors() == 0 || Base->getNumSuccessors() == 0) - continue; - - VPBasicBlock *VPBB = Base->getEntryBasicBlock(); + ReversePostOrderTraversal<VPBlockRecursiveTraversalWrapper<VPBlockBase *>> + RPOT(Plan->getEntry()); + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) { + VPRecipeBase *Term = VPBB->getTerminator(); + auto EndIter = Term ? Term->getIterator() : VPBB->end(); // Introduce each ingredient into VPlan. 
- for (VPRecipeBase &Ingredient : llvm::make_early_inc_range(*VPBB)) { + for (VPRecipeBase &Ingredient : + make_early_inc_range(make_range(VPBB->begin(), EndIter))) { + VPValue *VPV = Ingredient.getVPSingleValue(); Instruction *Inst = cast<Instruction>(VPV->getUnderlyingValue()); if (DeadInstructions.count(Inst)) { @@ -47,8 +47,10 @@ void VPlanTransforms::VPInstructionsToVPRecipes( auto *Phi = cast<PHINode>(VPPhi->getUnderlyingValue()); if (const auto *II = GetIntOrFpInductionDescriptor(Phi)) { VPValue *Start = Plan->getOrAddVPValue(II->getStartValue()); + VPValue *Step = + vputils::getOrCreateVPValueForSCEVExpr(*Plan, II->getStep(), SE); NewRecipe = - new VPWidenIntOrFpInductionRecipe(Phi, Start, *II, false, true); + new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, *II, true); } else { Plan->addVPValue(Phi, VPPhi); continue; @@ -295,14 +297,19 @@ bool VPlanTransforms::mergeReplicateRegions(VPlan &Plan) { } void VPlanTransforms::removeRedundantInductionCasts(VPlan &Plan) { - SmallVector<std::pair<VPRecipeBase *, VPValue *>> CastsToRemove; - for (auto &Phi : Plan.getEntry()->getEntryBasicBlock()->phis()) { + for (auto &Phi : Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis()) { auto *IV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&Phi); if (!IV || IV->getTruncInst()) continue; - // Visit all casts connected to IV and in Casts. Collect them. - // remember them for removal. + // A sequence of IR Casts has potentially been recorded for IV, which + // *must be bypassed* when the IV is vectorized, because the vectorized IV + // will produce the desired casted value. This sequence forms a def-use + // chain and is provided in reverse order, ending with the cast that uses + // the IV phi. Search for the recipe of the last cast in the chain and + // replace it with the original IV. Note that only the final cast is + // expected to have users outside the cast-chain and the dead casts left + // over will be cleaned up later. auto &Casts = IV->getInductionDescriptor().getCastInsts(); VPValue *FindMyCast = IV; for (Instruction *IRCast : reverse(Casts)) { @@ -315,14 +322,9 @@ void VPlanTransforms::removeRedundantInductionCasts(VPlan &Plan) { break; } } - assert(FoundUserCast && "Missing a cast to remove"); - CastsToRemove.emplace_back(FoundUserCast, IV); FindMyCast = FoundUserCast->getVPSingleValue(); } - } - for (auto &E : CastsToRemove) { - E.first->getVPSingleValue()->replaceAllUsesWith(E.second); - E.first->eraseFromParent(); + FindMyCast->replaceAllUsesWith(IV); } } @@ -358,3 +360,73 @@ void VPlanTransforms::removeRedundantCanonicalIVs(VPlan &Plan) { } } } + +void VPlanTransforms::removeDeadRecipes(VPlan &Plan) { + ReversePostOrderTraversal<VPBlockRecursiveTraversalWrapper<VPBlockBase *>> + RPOT(Plan.getEntry()); + + for (VPBasicBlock *VPBB : reverse(VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT))) { + // The recipes in the block are processed in reverse order, to catch chains + // of dead recipes. 
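A toy model (not part of this diff) of why the reverse walk collapses whole chains in a single pass: erasing a dead node may drop the last use of its operands, and those operands are visited afterwards because the walk goes backwards.

// Sketch only: a three-node chain %0 -> %1 -> %2 where only %2 is dead at
// the start; one reverse pass erases all three.
#include <cstdio>
#include <vector>

struct Node {
  std::vector<int> Operands; // indices of the defs this node uses
  int NumUsers = 0;
  bool SideEffects = false;
  bool Erased = false;
};

int main() {
  std::vector<Node> Nodes(3);
  Nodes[1].Operands = {0}; Nodes[0].NumUsers = 1;
  Nodes[2].Operands = {1}; Nodes[1].NumUsers = 1;

  for (int I = (int)Nodes.size() - 1; I >= 0; --I) {
    Node &N = Nodes[I];
    if (N.SideEffects || N.NumUsers > 0)
      continue;
    N.Erased = true;
    for (int Op : N.Operands)
      --Nodes[Op].NumUsers; // operands may now become dead in turn
  }
  for (std::size_t I = 0; I < Nodes.size(); ++I)
    std::printf("%%%zu %s\n", I, Nodes[I].Erased ? "erased" : "kept");
  return 0;
}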
+ for (VPRecipeBase &R : make_early_inc_range(reverse(*VPBB))) { + if (R.mayHaveSideEffects() || any_of(R.definedValues(), [](VPValue *V) { + return V->getNumUsers() > 0; + })) + continue; + R.eraseFromParent(); + } + } +} + +void VPlanTransforms::optimizeInductions(VPlan &Plan, ScalarEvolution &SE) { + SmallVector<VPRecipeBase *> ToRemove; + VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); + bool HasOnlyVectorVFs = !Plan.hasVF(ElementCount::getFixed(1)); + for (VPRecipeBase &Phi : HeaderVPBB->phis()) { + auto *IV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&Phi); + if (!IV) + continue; + if (HasOnlyVectorVFs && + none_of(IV->users(), [IV](VPUser *U) { return U->usesScalars(IV); })) + continue; + + const InductionDescriptor &ID = IV->getInductionDescriptor(); + VPValue *Step = + vputils::getOrCreateVPValueForSCEVExpr(Plan, ID.getStep(), SE); + Instruction *TruncI = IV->getTruncInst(); + VPScalarIVStepsRecipe *Steps = new VPScalarIVStepsRecipe( + IV->getPHINode()->getType(), ID, Plan.getCanonicalIV(), + IV->getStartValue(), Step, TruncI ? TruncI->getType() : nullptr); + HeaderVPBB->insert(Steps, HeaderVPBB->getFirstNonPhi()); + + // Update scalar users of IV to use Step instead. Use SetVector to ensure + // the list of users doesn't contain duplicates. + SetVector<VPUser *> Users(IV->user_begin(), IV->user_end()); + for (VPUser *U : Users) { + if (HasOnlyVectorVFs && !U->usesScalars(IV)) + continue; + for (unsigned I = 0, E = U->getNumOperands(); I != E; I++) { + if (U->getOperand(I) != IV) + continue; + U->setOperand(I, Steps); + } + } + } +} + +void VPlanTransforms::removeRedundantExpandSCEVRecipes(VPlan &Plan) { + DenseMap<const SCEV *, VPValue *> SCEV2VPV; + + for (VPRecipeBase &R : + make_early_inc_range(*Plan.getEntry()->getEntryBasicBlock())) { + auto *ExpR = dyn_cast<VPExpandSCEVRecipe>(&R); + if (!ExpR) + continue; + + auto I = SCEV2VPV.insert({ExpR->getSCEV(), ExpR}); + if (I.second) + continue; + ExpR->replaceAllUsesWith(I.first->second); + ExpR->eraseFromParent(); + } +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.h index e74409a86466..3372e255dff7 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.h @@ -14,8 +14,7 @@ #define LLVM_TRANSFORMS_VECTORIZE_VPLANTRANSFORMS_H #include "VPlan.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h" +#include "llvm/ADT/STLFunctionalExtras.h" namespace llvm { @@ -23,6 +22,7 @@ class InductionDescriptor; class Instruction; class PHINode; class ScalarEvolution; +class Loop; struct VPlanTransforms { /// Replaces the VPInstructions in \p Plan with corresponding @@ -49,6 +49,18 @@ struct VPlanTransforms { /// Try to replace VPWidenCanonicalIVRecipes with a widened canonical IV /// recipe, if it exists. static void removeRedundantCanonicalIVs(VPlan &Plan); + + static void removeDeadRecipes(VPlan &Plan); + + /// If any user of a VPWidenIntOrFpInductionRecipe needs scalar values, + /// provide them by building scalar steps off of the canonical scalar IV and + /// update the original IV's users. This is an optional optimization to reduce + /// the needs of vector extracts. 
+ static void optimizeInductions(VPlan &Plan, ScalarEvolution &SE); + + /// Remove redundant EpxandSCEVRecipes in \p Plan's entry block by replacing + /// them with already existing recipes expanding the same SCEV expression. + static void removeRedundantExpandSCEVRecipes(VPlan &Plan); }; } // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanValue.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanValue.h index 5296d2b9485c..5fc676834331 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanValue.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanValue.h @@ -106,6 +106,7 @@ public: VPVFirstOrderRecurrencePHISC, VPVWidenPHISC, VPVWidenIntOrFpInductionSC, + VPVWidenPointerInductionSC, VPVPredInstPHI, VPVReductionPHISC, }; @@ -207,9 +208,7 @@ public: /// Subclass identifier (for isa/dyn_cast). enum class VPUserID { Recipe, - // TODO: Currently VPUsers are used in VPBlockBase, but in the future the - // only VPUsers should either be recipes or live-outs. - Block + LiveOut, }; private: @@ -286,6 +285,22 @@ public: /// Method to support type inquiry through isa, cast, and dyn_cast. static inline bool classof(const VPDef *Recipe); + + /// Returns true if the VPUser uses scalars of operand \p Op. Conservatively + /// returns if only first (scalar) lane is used, as default. + virtual bool usesScalars(const VPValue *Op) const { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + return onlyFirstLaneUsed(Op); + } + + /// Returns true if the VPUser only uses the first lane of operand \p Op. + /// Conservatively returns false. + virtual bool onlyFirstLaneUsed(const VPValue *Op) const { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + return false; + } }; /// This class augments a recipe with a set of VPValues defined by the recipe. @@ -327,10 +342,12 @@ public: /// type identification. using VPRecipeTy = enum { VPBranchOnMaskSC, + VPExpandSCEVSC, VPInstructionSC, VPInterleaveSC, VPReductionSC, VPReplicateSC, + VPScalarIVStepsSC, VPWidenCallSC, VPWidenCanonicalIVSC, VPWidenGEPSC, @@ -344,6 +361,7 @@ public: VPFirstOrderRecurrencePHISC, VPWidenPHISC, VPWidenIntOrFpInductionSC, + VPWidenPointerInductionSC, VPPredInstPHISC, VPReductionPHISC, VPFirstPHISC = VPBlendSC, diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp index d36f250995e1..f917883145c0 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp @@ -43,17 +43,20 @@ static bool hasDuplicates(const SmallVectorImpl<VPBlockBase *> &VPBlockVec) { /// \p Region. Checks in this function are generic for VPBlockBases. They are /// not specific for VPBasicBlocks or VPRegionBlocks. static void verifyBlocksInRegion(const VPRegionBlock *Region) { - for (const VPBlockBase *VPB : - make_range(df_iterator<const VPBlockBase *>::begin(Region->getEntry()), - df_iterator<const VPBlockBase *>::end(Region->getExit()))) { + for (const VPBlockBase *VPB : make_range( + df_iterator<const VPBlockBase *>::begin(Region->getEntry()), + df_iterator<const VPBlockBase *>::end(Region->getExiting()))) { // Check block's parent. assert(VPB->getParent() == Region && "VPBlockBase has wrong parent"); + auto *VPBB = dyn_cast<VPBasicBlock>(VPB); // Check block's condition bit. 
- if (VPB->getNumSuccessors() > 1) - assert(VPB->getCondBit() && "Missing condition bit!"); + if (VPB->getNumSuccessors() > 1 || (VPBB && VPBB->isExiting())) + assert(VPBB && VPBB->getTerminator() && + "Block has multiple successors but doesn't " + "have a proper branch recipe!"); else - assert(!VPB->getCondBit() && "Unexpected condition bit!"); + assert((!VPBB || !VPBB->getTerminator()) && "Unexpected branch recipe!"); // Check block's successors. const auto &Successors = VPB->getSuccessors(); @@ -94,13 +97,14 @@ static void verifyBlocksInRegion(const VPRegionBlock *Region) { /// VPBlockBases. Do not recurse inside nested VPRegionBlocks. static void verifyRegion(const VPRegionBlock *Region) { const VPBlockBase *Entry = Region->getEntry(); - const VPBlockBase *Exit = Region->getExit(); + const VPBlockBase *Exiting = Region->getExiting(); - // Entry and Exit shouldn't have any predecessor/successor, respectively. + // Entry and Exiting shouldn't have any predecessor/successor, respectively. assert(!Entry->getNumPredecessors() && "Region entry has predecessors."); - assert(!Exit->getNumSuccessors() && "Region exit has successors."); + assert(!Exiting->getNumSuccessors() && + "Region exiting block has successors."); (void)Entry; - (void)Exit; + (void)Exiting; verifyBlocksInRegion(Region); } @@ -111,9 +115,9 @@ static void verifyRegionRec(const VPRegionBlock *Region) { verifyRegion(Region); // Recurse inside nested regions. - for (const VPBlockBase *VPB : - make_range(df_iterator<const VPBlockBase *>::begin(Region->getEntry()), - df_iterator<const VPBlockBase *>::end(Region->getExit()))) { + for (const VPBlockBase *VPB : make_range( + df_iterator<const VPBlockBase *>::begin(Region->getEntry()), + df_iterator<const VPBlockBase *>::end(Region->getExiting()))) { if (const auto *SubRegion = dyn_cast<VPRegionBlock>(VPB)) verifyRegionRec(SubRegion); } @@ -157,7 +161,7 @@ bool VPlanVerifier::verifyPlanIsValid(const VPlan &Plan) { } } - const VPRegionBlock *TopRegion = cast<VPRegionBlock>(Plan.getEntry()); + const VPRegionBlock *TopRegion = Plan.getVectorLoopRegion(); const VPBasicBlock *Entry = dyn_cast<VPBasicBlock>(TopRegion->getEntry()); if (!Entry) { errs() << "VPlan entry block is not a VPBasicBlock\n"; @@ -170,19 +174,19 @@ bool VPlanVerifier::verifyPlanIsValid(const VPlan &Plan) { return false; } - const VPBasicBlock *Exit = dyn_cast<VPBasicBlock>(TopRegion->getExit()); - if (!Exit) { - errs() << "VPlan exit block is not a VPBasicBlock\n"; + const VPBasicBlock *Exiting = dyn_cast<VPBasicBlock>(TopRegion->getExiting()); + if (!Exiting) { + errs() << "VPlan exiting block is not a VPBasicBlock\n"; return false; } - if (Exit->empty()) { - errs() << "VPlan vector loop exit must end with BranchOnCount " + if (Exiting->empty()) { + errs() << "VPlan vector loop exiting block must end with BranchOnCount " "VPInstruction but is empty\n"; return false; } - auto *LastInst = dyn_cast<VPInstruction>(std::prev(Exit->end())); + auto *LastInst = dyn_cast<VPInstruction>(std::prev(Exiting->end())); if (!LastInst || LastInst->getOpcode() != VPInstruction::BranchOnCount) { errs() << "VPlan vector loop exit must end with BranchOnCount " "VPInstruction\n"; @@ -197,10 +201,17 @@ bool VPlanVerifier::verifyPlanIsValid(const VPlan &Plan) { errs() << "region entry block has predecessors\n"; return false; } - if (Region->getExit()->getNumSuccessors() != 0) { - errs() << "region exit block has successors\n"; + if (Region->getExiting()->getNumSuccessors() != 0) { + errs() << "region exiting block has successors\n"; return 
false; } } + + for (auto &KV : Plan.getLiveOuts()) + if (KV.second->getNumOperands() != 1) { + errs() << "live outs must have a single operand\n"; + return false; + } + return true; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index 258f6c67e54d..90598937affc 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -103,11 +103,13 @@ private: bool foldSingleElementStore(Instruction &I); bool scalarizeLoadExtract(Instruction &I); bool foldShuffleOfBinops(Instruction &I); + bool foldShuffleFromReductions(Instruction &I); + bool foldSelectShuffle(Instruction &I, bool FromReduction = false); void replaceValue(Value &Old, Value &New) { Old.replaceAllUsesWith(&New); - New.takeName(&Old); if (auto *NewI = dyn_cast<Instruction>(&New)) { + New.takeName(&Old); Worklist.pushUsersToWorkList(*NewI); Worklist.pushValue(NewI); } @@ -255,12 +257,12 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) { ExtractElementInst *VectorCombine::getShuffleExtract( ExtractElementInst *Ext0, ExtractElementInst *Ext1, unsigned PreferredExtractIndex = InvalidIndex) const { - assert(isa<ConstantInt>(Ext0->getIndexOperand()) && - isa<ConstantInt>(Ext1->getIndexOperand()) && - "Expected constant extract indexes"); + auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand()); + auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand()); + assert(Index0C && Index1C && "Expected constant extract indexes"); - unsigned Index0 = cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue(); - unsigned Index1 = cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue(); + unsigned Index0 = Index0C->getZExtValue(); + unsigned Index1 = Index1C->getZExtValue(); // If the extract indexes are identical, no shuffle is needed. if (Index0 == Index1) @@ -306,9 +308,10 @@ bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0, const Instruction &I, ExtractElementInst *&ConvertToShuffle, unsigned PreferredExtractIndex) { - assert(isa<ConstantInt>(Ext0->getOperand(1)) && - isa<ConstantInt>(Ext1->getOperand(1)) && - "Expected constant extract indexes"); + auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getOperand(1)); + auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getOperand(1)); + assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes"); + unsigned Opcode = I.getOpcode(); Type *ScalarTy = Ext0->getType(); auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType()); @@ -331,8 +334,8 @@ bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0, // Get cost estimates for the extract elements. These costs will factor into // both sequences. - unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue(); - unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue(); + unsigned Ext0Index = Ext0IndexC->getZExtValue(); + unsigned Ext1Index = Ext1IndexC->getZExtValue(); InstructionCost Extract0Cost = TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext0Index); @@ -694,8 +697,9 @@ bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) { ScalarInst->copyIRFlags(&I); // Fold the vector constants in the original vectors into a new base vector. - Constant *NewVecC = IsCmp ? ConstantExpr::getCompare(Pred, VecC0, VecC1) - : ConstantExpr::get(Opcode, VecC0, VecC1); + Value *NewVecC = + IsCmp ? 
Builder.CreateCmp(Pred, VecC0, VecC1) + : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, VecC0, VecC1); Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index); replaceValue(I, *Insert); return true; @@ -1015,12 +1019,8 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { return false; NumInstChecked++; } - } - - if (!LastCheckedInst) - LastCheckedInst = UI; - else if (LastCheckedInst->comesBefore(UI)) LastCheckedInst = UI; + } auto ScalarIdx = canScalarizeAccess(FixedVT, UI->getOperand(1), &I, AC, DT); if (!ScalarIdx.isSafe()) { @@ -1117,6 +1117,339 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) { return true; } +/// Given a commutative reduction, the order of the input lanes does not alter +/// the results. We can use this to remove certain shuffles feeding the +/// reduction, removing the need to shuffle at all. +bool VectorCombine::foldShuffleFromReductions(Instruction &I) { + auto *II = dyn_cast<IntrinsicInst>(&I); + if (!II) + return false; + switch (II->getIntrinsicID()) { + case Intrinsic::vector_reduce_add: + case Intrinsic::vector_reduce_mul: + case Intrinsic::vector_reduce_and: + case Intrinsic::vector_reduce_or: + case Intrinsic::vector_reduce_xor: + case Intrinsic::vector_reduce_smin: + case Intrinsic::vector_reduce_smax: + case Intrinsic::vector_reduce_umin: + case Intrinsic::vector_reduce_umax: + break; + default: + return false; + } + + // Find all the inputs when looking through operations that do not alter the + // lane order (binops, for example). Currently we look for a single shuffle, + // and can ignore splat values. + std::queue<Value *> Worklist; + SmallPtrSet<Value *, 4> Visited; + ShuffleVectorInst *Shuffle = nullptr; + if (auto *Op = dyn_cast<Instruction>(I.getOperand(0))) + Worklist.push(Op); + + while (!Worklist.empty()) { + Value *CV = Worklist.front(); + Worklist.pop(); + if (Visited.contains(CV)) + continue; + + // Splats don't change the order, so can be safely ignored. + if (isSplatValue(CV)) + continue; + + Visited.insert(CV); + + if (auto *CI = dyn_cast<Instruction>(CV)) { + if (CI->isBinaryOp()) { + for (auto *Op : CI->operand_values()) + Worklist.push(Op); + continue; + } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) { + if (Shuffle && Shuffle != SV) + return false; + Shuffle = SV; + continue; + } + } + + // Anything else is currently an unknown node. + return false; + } + + if (!Shuffle) + return false; + + // Check all uses of the binary ops and shuffles are also included in the + // lane-invariant operations (Visited should be the list of lanewise + // instructions, including the shuffle that we found). + for (auto *V : Visited) + for (auto *U : V->users()) + if (!Visited.contains(U) && U != &I) + return false; + + FixedVectorType *VecType = + dyn_cast<FixedVectorType>(II->getOperand(0)->getType()); + if (!VecType) + return false; + FixedVectorType *ShuffleInputType = + dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType()); + if (!ShuffleInputType) + return false; + int NumInputElts = ShuffleInputType->getNumElements(); + + // Find the mask from sorting the lanes into order. This is most likely to + // become a identity or concat mask. Undef elements are pushed to the end. + SmallVector<int> ConcatMask; + Shuffle->getShuffleMask(ConcatMask); + sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; }); + bool UsesSecondVec = + any_of(ConcatMask, [&](int M) { return M >= NumInputElts; }); + InstructionCost OldCost = TTI.getShuffleCost( + UsesSecondVec ? 
TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType, + Shuffle->getShuffleMask()); + InstructionCost NewCost = TTI.getShuffleCost( + UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType, + ConcatMask); + + LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle + << "\n"); + LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost + << "\n"); + if (NewCost < OldCost) { + Builder.SetInsertPoint(Shuffle); + Value *NewShuffle = Builder.CreateShuffleVector( + Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask); + LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n"); + replaceValue(*Shuffle, *NewShuffle); + } + + // See if we can re-use foldSelectShuffle, getting it to reduce the size of + // the shuffle into a nicer order, as it can ignore the order of the shuffles. + return foldSelectShuffle(*Shuffle, true); +} + +/// This method looks for groups of shuffles acting on binops, of the form: +/// %x = shuffle ... +/// %y = shuffle ... +/// %a = binop %x, %y +/// %b = binop %x, %y +/// shuffle %a, %b, selectmask +/// We may, especially if the shuffle is wider than legal, be able to convert +/// the shuffle to a form where only parts of a and b need to be computed. On +/// architectures with no obvious "select" shuffle, this can reduce the total +/// number of operations if the target reports them as cheaper. +bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) { + auto *SVI = dyn_cast<ShuffleVectorInst>(&I); + auto *VT = dyn_cast<FixedVectorType>(I.getType()); + if (!SVI || !VT) + return false; + auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0)); + auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1)); + if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() || + VT != Op0->getType()) + return false; + auto *SVI0A = dyn_cast<ShuffleVectorInst>(Op0->getOperand(0)); + auto *SVI0B = dyn_cast<ShuffleVectorInst>(Op0->getOperand(1)); + auto *SVI1A = dyn_cast<ShuffleVectorInst>(Op1->getOperand(0)); + auto *SVI1B = dyn_cast<ShuffleVectorInst>(Op1->getOperand(1)); + auto checkSVNonOpUses = [&](Instruction *I) { + if (!I || I->getOperand(0)->getType() != VT) + return true; + return any_of(I->users(), [&](User *U) { return U != Op0 && U != Op1; }); + }; + if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) || + checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B)) + return false; + + // Collect all the uses that are shuffles that we can transform together. We + // may not have a single shuffle, but a group that can all be transformed + // together profitably. + SmallVector<ShuffleVectorInst *> Shuffles; + auto collectShuffles = [&](Instruction *I) { + for (auto *U : I->users()) { + auto *SV = dyn_cast<ShuffleVectorInst>(U); + if (!SV || SV->getType() != VT) + return false; + if (!llvm::is_contained(Shuffles, SV)) + Shuffles.push_back(SV); + } + return true; + }; + if (!collectShuffles(Op0) || !collectShuffles(Op1)) + return false; + // From a reduction, we need to be processing a single shuffle, otherwise the + // other uses will not be lane-invariant. + if (FromReduction && Shuffles.size() > 1) + return false; + + // For each of the output shuffles, we try to sort all the first vector + // elements to the beginning, followed by the second array elements at the + // end. If the binops are legalized to smaller vectors, this may reduce total + // number of binops. We compute the ReconstructMask mask needed to convert + // back to the original lane order. 
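// Illustration (hypothetical values, not part of the original change): with
// NumElts == 4 and an output shuffle mask of <0, 5, 2, 7>, lanes 0 and 2 are
// taken from Op0, so V1 = {0, 2} and they pack into positions 0 and 1 of the
// narrowed Op0; lanes 5 and 7 are taken from Op1 (its lanes 1 and 3), so
// V2 = {1, 3} and they pack into positions 0 and 1 of the narrowed Op1. The
// ReconstructMask that restores the original order is then <0, 4, 1, 5>,
// where indices >= NumElts select from the new Op1. Since MaxV1Elt (2) and
// MaxV2Elt (3) are larger than the packed sizes minus one, the early exit
// below is not taken and the transform is costed.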
+ SmallVector<int> V1, V2;
+ SmallVector<SmallVector<int>> ReconstructMasks;
+ int MaxV1Elt = 0, MaxV2Elt = 0;
+ unsigned NumElts = VT->getNumElements();
+ for (ShuffleVectorInst *SVN : Shuffles) {
+ SmallVector<int> Mask;
+ SVN->getShuffleMask(Mask);
+
+ // Check the operands are the same as the original, or reversed (in which
+ // case we need to commute the mask).
+ Value *SVOp0 = SVN->getOperand(0);
+ Value *SVOp1 = SVN->getOperand(1);
+ if (SVOp0 == Op1 && SVOp1 == Op0) {
+ std::swap(SVOp0, SVOp1);
+ ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
+ }
+ if (SVOp0 != Op0 || SVOp1 != Op1)
+ return false;
+
+ // Calculate the reconstruction mask for this shuffle, as the mask needed to
+ // take the packed values from Op0/Op1 and reconstruct the original order.
+ SmallVector<int> ReconstructMask;
+ for (unsigned I = 0; I < Mask.size(); I++) {
+ if (Mask[I] < 0) {
+ ReconstructMask.push_back(-1);
+ } else if (Mask[I] < static_cast<int>(NumElts)) {
+ MaxV1Elt = std::max(MaxV1Elt, Mask[I]);
+ auto It = find(V1, Mask[I]);
+ if (It != V1.end())
+ ReconstructMask.push_back(It - V1.begin());
+ else {
+ ReconstructMask.push_back(V1.size());
+ V1.push_back(Mask[I]);
+ }
+ } else {
+ MaxV2Elt = std::max<int>(MaxV2Elt, Mask[I] - NumElts);
+ auto It = find(V2, Mask[I] - NumElts);
+ if (It != V2.end())
+ ReconstructMask.push_back(NumElts + It - V2.begin());
+ else {
+ ReconstructMask.push_back(NumElts + V2.size());
+ V2.push_back(Mask[I] - NumElts);
+ }
+ }
+ }
+
+ // For reductions, we know that the output lane ordering doesn't alter the
+ // result. An in-order mask can help simplify the shuffle away.
+ if (FromReduction)
+ sort(ReconstructMask);
+ ReconstructMasks.push_back(ReconstructMask);
+ }
+
+ // If the maximum elements used from V1 and V2 are not larger than the new
+ // vectors, the vectors are already packed and performing the optimization
+ // again will likely not help any further. This also prevents us from getting
+ // stuck in a cycle in case the costs do not also rule it out.
+ if (V1.empty() || V2.empty() ||
+ (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
+ MaxV2Elt == static_cast<int>(V2.size()) - 1))
+ return false;
+
+ // Calculate the masks needed for the new input shuffles, which get padded
+ // with undef.
+ SmallVector<int> V1A, V1B, V2A, V2B;
+ for (unsigned I = 0; I < V1.size(); I++) {
+ V1A.push_back(SVI0A->getMaskValue(V1[I]));
+ V1B.push_back(SVI0B->getMaskValue(V1[I]));
+ }
+ for (unsigned I = 0; I < V2.size(); I++) {
+ V2A.push_back(SVI1A->getMaskValue(V2[I]));
+ V2B.push_back(SVI1B->getMaskValue(V2[I]));
+ }
+ while (V1A.size() < NumElts) {
+ V1A.push_back(UndefMaskElem);
+ V1B.push_back(UndefMaskElem);
+ }
+ while (V2A.size() < NumElts) {
+ V2A.push_back(UndefMaskElem);
+ V2B.push_back(UndefMaskElem);
+ }
+
+ auto AddShuffleCost = [&](InstructionCost C, ShuffleVectorInst *SV) {
+ return C +
+ TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, VT, SV->getShuffleMask());
+ };
+ auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
+ return C + TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, VT, Mask);
+ };
+
+ // Get the costs of the shuffles + binops before and after with the new
+ // shuffle masks.
+ InstructionCost CostBefore =
+ TTI.getArithmeticInstrCost(Op0->getOpcode(), VT) +
+ TTI.getArithmeticInstrCost(Op1->getOpcode(), VT);
+ CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(),
+ InstructionCost(0), AddShuffleCost);
+ // This set helps us only cost each unique shuffle once.
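// (Illustration, not part of the original change: if Op0 uses the same
// shuffle for both of its operands, i.e. SVI0A == SVI0B, the pointer set
// below collapses the duplicates, so that shuffle's cost is added to
// CostBefore only once rather than twice.)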
+ SmallPtrSet<ShuffleVectorInst *, 4> InputShuffles( + {SVI0A, SVI0B, SVI1A, SVI1B}); + CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(), + InstructionCost(0), AddShuffleCost); + + // The new binops will be unused for lanes past the used shuffle lengths. + // These types attempt to get the correct cost for that from the target. + FixedVectorType *Op0SmallVT = + FixedVectorType::get(VT->getScalarType(), V1.size()); + FixedVectorType *Op1SmallVT = + FixedVectorType::get(VT->getScalarType(), V2.size()); + InstructionCost CostAfter = + TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT) + + TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT); + CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(), + InstructionCost(0), AddShuffleMaskCost); + std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B}); + CostAfter += + std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(), + InstructionCost(0), AddShuffleMaskCost); + + if (CostBefore <= CostAfter) + return false; + + // The cost model has passed, create the new instructions. + Builder.SetInsertPoint(SVI0A); + Value *NSV0A = Builder.CreateShuffleVector(SVI0A->getOperand(0), + SVI0A->getOperand(1), V1A); + Builder.SetInsertPoint(SVI0B); + Value *NSV0B = Builder.CreateShuffleVector(SVI0B->getOperand(0), + SVI0B->getOperand(1), V1B); + Builder.SetInsertPoint(SVI1A); + Value *NSV1A = Builder.CreateShuffleVector(SVI1A->getOperand(0), + SVI1A->getOperand(1), V2A); + Builder.SetInsertPoint(SVI1B); + Value *NSV1B = Builder.CreateShuffleVector(SVI1B->getOperand(0), + SVI1B->getOperand(1), V2B); + Builder.SetInsertPoint(Op0); + Value *NOp0 = Builder.CreateBinOp((Instruction::BinaryOps)Op0->getOpcode(), + NSV0A, NSV0B); + if (auto *I = dyn_cast<Instruction>(NOp0)) + I->copyIRFlags(Op0, true); + Builder.SetInsertPoint(Op1); + Value *NOp1 = Builder.CreateBinOp((Instruction::BinaryOps)Op1->getOpcode(), + NSV1A, NSV1B); + if (auto *I = dyn_cast<Instruction>(NOp1)) + I->copyIRFlags(Op1, true); + + for (int S = 0, E = ReconstructMasks.size(); S != E; S++) { + Builder.SetInsertPoint(Shuffles[S]); + Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]); + replaceValue(*Shuffles[S], *NSV); + } + + Worklist.pushValue(NSV0A); + Worklist.pushValue(NSV0B); + Worklist.pushValue(NSV1A); + Worklist.pushValue(NSV1B); + for (auto *S : Shuffles) + Worklist.add(S); + return true; +} + /// This is the entry point for all transforms. Pass manager differences are /// handled in the callers of this function. bool VectorCombine::run() { @@ -1136,6 +1469,8 @@ bool VectorCombine::run() { MadeChange |= foldBitcastShuf(I); MadeChange |= foldExtractedCmps(I); MadeChange |= foldShuffleOfBinops(I); + MadeChange |= foldShuffleFromReductions(I); + MadeChange |= foldSelectShuffle(I); } MadeChange |= scalarizeBinopOrCmp(I); MadeChange |= scalarizeLoadExtract(I); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/Vectorize.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/Vectorize.cpp index 010ca28fc237..208e5eeea864 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/Vectorize.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/Vectorize.cpp @@ -15,7 +15,6 @@ #include "llvm/Transforms/Vectorize.h" #include "llvm-c/Initialization.h" #include "llvm-c/Transforms/Vectorize.h" -#include "llvm/Analysis/Passes.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/InitializePasses.h" #include "llvm/PassRegistry.h" |
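A side note on the foldShuffleFromReductions change in VectorCombine.cpp above: because the listed reductions are commutative, the shuffle mask feeding them can be freely re-sorted, and sorting it in ascending order (with undef lanes pushed to the end by comparing as unsigned) tends to yield an identity or concat mask. The snippet below is a minimal, self-contained sketch of that canonicalization using plain STL containers instead of LLVM's SmallVector; it illustrates the idea and is not code from this commit.

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // A permuted mask feeding a commutative reduction; -1 marks undef lanes.
  std::vector<int> Mask = {3, -1, 0, 2, 1, -1, 5, 4};

  // Same comparator as the pass: unsigned comparison sends -1 to the end.
  std::sort(Mask.begin(), Mask.end(),
            [](int X, int Y) { return (unsigned)X < (unsigned)Y; });

  // Prints "0 1 2 3 4 5 -1 -1": an in-order (concat-like) mask, which the
  // cost model can then often show to be cheaper than the original shuffle.
  for (int M : Mask)
    std::printf("%d ", M);
  std::printf("\n");
  return 0;
}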