Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms')
234 files changed, 32817 insertions, 13951 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index 49fa0f59d488..1d23ec8ced20 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -19,6 +19,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -28,6 +29,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" @@ -47,6 +49,16 @@ static cl::opt<unsigned> MaxInstrsToScan( "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden, cl::desc("Max number of instructions to scan for aggressive instcombine.")); +static cl::opt<unsigned> StrNCmpInlineThreshold( + "strncmp-inline-threshold", cl::init(3), cl::Hidden, + cl::desc("The maximum length of a constant string for a builtin string cmp " + "call eligible for inlining. The default value is 3.")); + +static cl::opt<unsigned> + MemChrInlineThreshold("memchr-inline-threshold", cl::init(3), cl::Hidden, + cl::desc("The maximum length of a constant string to " + "inline a memchr call.")); + /// Match a pattern for a bitwise funnel/rotate operation that partially guards /// against undefined behavior by branching around the funnel-shift/rotation /// when the shift amount is 0. @@ -73,7 +85,7 @@ static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) { m_Shl(m_Value(ShVal0), m_Value(ShAmt)), m_LShr(m_Value(ShVal1), m_Sub(m_SpecificInt(Width), m_Deferred(ShAmt))))))) { - return Intrinsic::fshl; + return Intrinsic::fshl; } // fshr(ShVal0, ShVal1, ShAmt) @@ -82,7 +94,7 @@ static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) { m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0), m_Sub(m_SpecificInt(Width), m_Value(ShAmt))), m_LShr(m_Value(ShVal1), m_Deferred(ShAmt)))))) { - return Intrinsic::fshr; + return Intrinsic::fshr; } return Intrinsic::not_intrinsic; @@ -399,21 +411,11 @@ static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) { /// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids /// pessimistic codegen that has to account for setting errno and can enable /// vectorization. -static bool foldSqrt(Instruction &I, TargetTransformInfo &TTI, +static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT) { - // Match a call to sqrt mathlib function. 
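The foldSqrt change above reroutes the fold through the new foldLibCalls dispatcher; the fold itself still replaces an errno-setting sqrt libcall with the llvm.sqrt intrinsic when a NaN result can be ruled out. A minimal source-level sketch of the kind of call that qualifies (illustrative only, not taken from the diff):

    #include <cmath>
    // With nnan fast-math flags, or an argument that cannot be ordered less
    // than -0.0 (x * x below), the libcall may be rewritten to the llvm.sqrt
    // intrinsic, which has no errno side effect and does not block
    // vectorization.
    double root(double x) { return std::sqrt(x * x); }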
- auto *Call = dyn_cast<CallInst>(&I); - if (!Call) - return false; Module *M = Call->getModule(); - LibFunc Func; - if (!TLI.getLibFunc(*Call, Func) || !isLibFuncEmittable(M, &TLI, Func)) - return false; - - if (Func != LibFunc_sqrt && Func != LibFunc_sqrtf && Func != LibFunc_sqrtl) - return false; // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created // (because NNAN or the operand arg must not be less than -0.0) and (2) we @@ -425,19 +427,20 @@ static bool foldSqrt(Instruction &I, TargetTransformInfo &TTI, Value *Arg = Call->getArgOperand(0); if (TTI.haveFastSqrt(Ty) && (Call->hasNoNaNs() || - cannotBeOrderedLessThanZero(Arg, M->getDataLayout(), &TLI, 0, &AC, &I, - &DT))) { - IRBuilder<> Builder(&I); + cannotBeOrderedLessThanZero( + Arg, 0, + SimplifyQuery(Call->getDataLayout(), &TLI, &DT, &AC, Call)))) { + IRBuilder<> Builder(Call); IRBuilderBase::FastMathFlagGuard Guard(Builder); Builder.setFastMathFlags(Call->getFastMathFlags()); Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty); Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt"); - I.replaceAllUsesWith(NewSqrt); + Call->replaceAllUsesWith(NewSqrt); // Explicitly erase the old call because a call with side effects is not // trivially dead. - I.eraseFromParent(); + Call->eraseFromParent(); return true; } @@ -808,8 +811,7 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL, APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0); Load1Ptr = Load1Ptr->stripAndAccumulateConstantOffsets( DL, Offset1, /* AllowNonInbounds */ true); - Load1Ptr = Builder.CreatePtrAdd(Load1Ptr, - Builder.getInt32(Offset1.getZExtValue())); + Load1Ptr = Builder.CreatePtrAdd(Load1Ptr, Builder.getInt(Offset1)); } // Generate wider load. NewLoad = Builder.CreateAlignedLoad(WiderType, Load1Ptr, LI1->getAlign(), @@ -922,20 +924,321 @@ static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) { return true; } +namespace { +class StrNCmpInliner { +public: + StrNCmpInliner(CallInst *CI, LibFunc Func, DomTreeUpdater *DTU, + const DataLayout &DL) + : CI(CI), Func(Func), DTU(DTU), DL(DL) {} + + bool optimizeStrNCmp(); + +private: + void inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Swapped); + + CallInst *CI; + LibFunc Func; + DomTreeUpdater *DTU; + const DataLayout &DL; +}; + +} // namespace + +/// First we normalize calls to strncmp/strcmp to the form of +/// compare(s1, s2, N), which means comparing first N bytes of s1 and s2 +/// (without considering '\0'). +/// +/// Examples: +/// +/// \code +/// strncmp(s, "a", 3) -> compare(s, "a", 2) +/// strncmp(s, "abc", 3) -> compare(s, "abc", 3) +/// strncmp(s, "a\0b", 3) -> compare(s, "a\0b", 2) +/// strcmp(s, "a") -> compare(s, "a", 2) +/// +/// char s2[] = {'a'} +/// strncmp(s, s2, 3) -> compare(s, s2, 3) +/// +/// char s2[] = {'a', 'b', 'c', 'd'} +/// strncmp(s, s2, 3) -> compare(s, s2, 3) +/// \endcode +/// +/// We only handle cases where N and exactly one of s1 and s2 are constant. +/// Cases that s1 and s2 are both constant are already handled by the +/// instcombine pass. +/// +/// We do not handle cases where N > StrNCmpInlineThreshold. +/// +/// We also do not handles cases where N < 2, which are already +/// handled by the instcombine pass. +/// +bool StrNCmpInliner::optimizeStrNCmp() { + if (StrNCmpInlineThreshold < 2) + return false; + + if (!isOnlyUsedInZeroComparison(CI)) + return false; + + Value *Str1P = CI->getArgOperand(0); + Value *Str2P = CI->getArgOperand(1); + // Should be handled elsewhere. 
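The doc comment above normalizes strcmp/strncmp calls with exactly one constant operand into compare(s1, s2, N); inlineCompare (further below) then expands the call into a chain of byte-wise subtractions. A rough C++ equivalent of the expansion for strncmp(s, "ab", 2) under the default strncmp-inline-threshold of 3 (illustrative only):

    // Bytes are zero-extended before subtracting, matching the zext in
    // inlineCompare; only the comparison against zero is ever consumed,
    // because the fold requires isOnlyUsedInZeroComparison().
    int cmp_s_ab(const char *s) {
      int d = (unsigned char)s[0] - (unsigned char)'a';
      if (d != 0)
        return d;
      return (unsigned char)s[1] - (unsigned char)'b';
    }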
+ if (Str1P == Str2P) + return false; + + StringRef Str1, Str2; + bool HasStr1 = getConstantStringInfo(Str1P, Str1, /*TrimAtNul=*/false); + bool HasStr2 = getConstantStringInfo(Str2P, Str2, /*TrimAtNul=*/false); + if (HasStr1 == HasStr2) + return false; + + // Note that '\0' and characters after it are not trimmed. + StringRef Str = HasStr1 ? Str1 : Str2; + Value *StrP = HasStr1 ? Str2P : Str1P; + + size_t Idx = Str.find('\0'); + uint64_t N = Idx == StringRef::npos ? UINT64_MAX : Idx + 1; + if (Func == LibFunc_strncmp) { + if (auto *ConstInt = dyn_cast<ConstantInt>(CI->getArgOperand(2))) + N = std::min(N, ConstInt->getZExtValue()); + else + return false; + } + // Now N means how many bytes we need to compare at most. + if (N > Str.size() || N < 2 || N > StrNCmpInlineThreshold) + return false; + + // Cases where StrP has two or more dereferenceable bytes might be better + // optimized elsewhere. + bool CanBeNull = false, CanBeFreed = false; + if (StrP->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed) > 1) + return false; + inlineCompare(StrP, Str, N, HasStr1); + return true; +} + +/// Convert +/// +/// \code +/// ret = compare(s1, s2, N) +/// \endcode +/// +/// into +/// +/// \code +/// ret = (int)s1[0] - (int)s2[0] +/// if (ret != 0) +/// goto NE +/// ... +/// ret = (int)s1[N-2] - (int)s2[N-2] +/// if (ret != 0) +/// goto NE +/// ret = (int)s1[N-1] - (int)s2[N-1] +/// NE: +/// \endcode +/// +/// CFG before and after the transformation: +/// +/// (before) +/// BBCI +/// +/// (after) +/// BBCI -> BBSubs[0] (sub,icmp) --NE-> BBNE -> BBTail +/// | ^ +/// E | +/// | | +/// BBSubs[1] (sub,icmp) --NE-----+ +/// ... | +/// BBSubs[N-1] (sub) ---------+ +/// +void StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N, + bool Swapped) { + auto &Ctx = CI->getContext(); + IRBuilder<> B(Ctx); + + BasicBlock *BBCI = CI->getParent(); + BasicBlock *BBTail = + SplitBlock(BBCI, CI, DTU, nullptr, nullptr, BBCI->getName() + ".tail"); + + SmallVector<BasicBlock *> BBSubs; + for (uint64_t I = 0; I < N; ++I) + BBSubs.push_back( + BasicBlock::Create(Ctx, "sub_" + Twine(I), BBCI->getParent(), BBTail)); + BasicBlock *BBNE = BasicBlock::Create(Ctx, "ne", BBCI->getParent(), BBTail); + + cast<BranchInst>(BBCI->getTerminator())->setSuccessor(0, BBSubs[0]); + + B.SetInsertPoint(BBNE); + PHINode *Phi = B.CreatePHI(CI->getType(), N); + B.CreateBr(BBTail); + + Value *Base = LHS; + for (uint64_t i = 0; i < N; ++i) { + B.SetInsertPoint(BBSubs[i]); + Value *VL = + B.CreateZExt(B.CreateLoad(B.getInt8Ty(), + B.CreateInBoundsPtrAdd(Base, B.getInt64(i))), + CI->getType()); + Value *VR = + ConstantInt::get(CI->getType(), static_cast<unsigned char>(RHS[i])); + Value *Sub = Swapped ? 
B.CreateSub(VR, VL) : B.CreateSub(VL, VR); + if (i < N - 1) + B.CreateCondBr(B.CreateICmpNE(Sub, ConstantInt::get(CI->getType(), 0)), + BBNE, BBSubs[i + 1]); + else + B.CreateBr(BBNE); + + Phi->addIncoming(Sub, BBSubs[i]); + } + + CI->replaceAllUsesWith(Phi); + CI->eraseFromParent(); + + if (DTU) { + SmallVector<DominatorTree::UpdateType, 8> Updates; + Updates.push_back({DominatorTree::Insert, BBCI, BBSubs[0]}); + for (uint64_t i = 0; i < N; ++i) { + if (i < N - 1) + Updates.push_back({DominatorTree::Insert, BBSubs[i], BBSubs[i + 1]}); + Updates.push_back({DominatorTree::Insert, BBSubs[i], BBNE}); + } + Updates.push_back({DominatorTree::Insert, BBNE, BBTail}); + Updates.push_back({DominatorTree::Delete, BBCI, BBTail}); + DTU->applyUpdates(Updates); + } +} + +/// Convert memchr with a small constant string into a switch +static bool foldMemChr(CallInst *Call, DomTreeUpdater *DTU, + const DataLayout &DL) { + if (isa<Constant>(Call->getArgOperand(1))) + return false; + + StringRef Str; + Value *Base = Call->getArgOperand(0); + if (!getConstantStringInfo(Base, Str, /*TrimAtNul=*/false)) + return false; + + uint64_t N = Str.size(); + if (auto *ConstInt = dyn_cast<ConstantInt>(Call->getArgOperand(2))) { + uint64_t Val = ConstInt->getZExtValue(); + // Ignore the case that n is larger than the size of string. + if (Val > N) + return false; + N = Val; + } else + return false; + + if (N > MemChrInlineThreshold) + return false; + + BasicBlock *BB = Call->getParent(); + BasicBlock *BBNext = SplitBlock(BB, Call, DTU); + IRBuilder<> IRB(BB); + IntegerType *ByteTy = IRB.getInt8Ty(); + BB->getTerminator()->eraseFromParent(); + SwitchInst *SI = IRB.CreateSwitch( + IRB.CreateTrunc(Call->getArgOperand(1), ByteTy), BBNext, N); + Type *IndexTy = DL.getIndexType(Call->getType()); + SmallVector<DominatorTree::UpdateType, 8> Updates; + + BasicBlock *BBSuccess = BasicBlock::Create( + Call->getContext(), "memchr.success", BB->getParent(), BBNext); + IRB.SetInsertPoint(BBSuccess); + PHINode *IndexPHI = IRB.CreatePHI(IndexTy, N, "memchr.idx"); + Value *FirstOccursLocation = IRB.CreateInBoundsPtrAdd(Base, IndexPHI); + IRB.CreateBr(BBNext); + if (DTU) + Updates.push_back({DominatorTree::Insert, BBSuccess, BBNext}); + + SmallPtrSet<ConstantInt *, 4> Cases; + for (uint64_t I = 0; I < N; ++I) { + ConstantInt *CaseVal = ConstantInt::get(ByteTy, Str[I]); + if (!Cases.insert(CaseVal).second) + continue; + + BasicBlock *BBCase = BasicBlock::Create(Call->getContext(), "memchr.case", + BB->getParent(), BBSuccess); + SI->addCase(CaseVal, BBCase); + IRB.SetInsertPoint(BBCase); + IndexPHI->addIncoming(ConstantInt::get(IndexTy, I), BBCase); + IRB.CreateBr(BBSuccess); + if (DTU) { + Updates.push_back({DominatorTree::Insert, BB, BBCase}); + Updates.push_back({DominatorTree::Insert, BBCase, BBSuccess}); + } + } + + PHINode *PHI = + PHINode::Create(Call->getType(), 2, Call->getName(), BBNext->begin()); + PHI->addIncoming(Constant::getNullValue(Call->getType()), BB); + PHI->addIncoming(FirstOccursLocation, BBSuccess); + + Call->replaceAllUsesWith(PHI); + Call->eraseFromParent(); + + if (DTU) + DTU->applyUpdates(Updates); + + return true; +} + +static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI, + TargetLibraryInfo &TLI, AssumptionCache &AC, + DominatorTree &DT, const DataLayout &DL, + bool &MadeCFGChange) { + + auto *CI = dyn_cast<CallInst>(&I); + if (!CI || CI->isNoBuiltin()) + return false; + + Function *CalledFunc = CI->getCalledFunction(); + if (!CalledFunc) + return false; + + LibFunc LF; + if 
(!TLI.getLibFunc(*CalledFunc, LF) || + !isLibFuncEmittable(CI->getModule(), &TLI, LF)) + return false; + + DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy); + + switch (LF) { + case LibFunc_sqrt: + case LibFunc_sqrtf: + case LibFunc_sqrtl: + return foldSqrt(CI, LF, TTI, TLI, AC, DT); + case LibFunc_strcmp: + case LibFunc_strncmp: + if (StrNCmpInliner(CI, LF, &DTU, DL).optimizeStrNCmp()) { + MadeCFGChange = true; + return true; + } + break; + case LibFunc_memchr: + if (foldMemChr(CI, &DTU, DL)) { + MadeCFGChange = true; + return true; + } + break; + default:; + } + return false; +} + /// This is the entry point for folds that could be implemented in regular /// InstCombine, but they are separated because they are not expected to /// occur frequently and/or have more than a constant-length pattern match. static bool foldUnusualPatterns(Function &F, DominatorTree &DT, TargetTransformInfo &TTI, TargetLibraryInfo &TLI, AliasAnalysis &AA, - AssumptionCache &AC) { + AssumptionCache &AC, bool &MadeCFGChange) { bool MadeChange = false; for (BasicBlock &BB : F) { // Ignore unreachable basic blocks. if (!DT.isReachableFromEntry(&BB)) continue; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); // Walk the block backwards for efficiency. We're matching a chain of // use->defs, so we're more likely to succeed by starting from the bottom. @@ -953,7 +1256,7 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT, // NOTE: This function introduces erasing of the instruction `I`, so it // needs to be called at the end of this sequence, otherwise we may make // bugs. - MadeChange |= foldSqrt(I, TTI, TLI, AC, DT); + MadeChange |= foldLibCalls(I, TTI, TLI, AC, DT, DL, MadeCFGChange); } } @@ -969,12 +1272,12 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT, /// handled in the callers of this function. static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI, TargetLibraryInfo &TLI, DominatorTree &DT, - AliasAnalysis &AA) { + AliasAnalysis &AA, bool &MadeCFGChange) { bool MadeChange = false; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); TruncInstCombine TIC(AC, TLI, DL, DT); MadeChange |= TIC.run(F); - MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC); + MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, MadeCFGChange); return MadeChange; } @@ -985,12 +1288,16 @@ PreservedAnalyses AggressiveInstCombinePass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TTI = AM.getResult<TargetIRAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); - if (!runImpl(F, AC, TTI, TLI, DT, AA)) { + bool MadeCFGChange = false; + if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) { // No changes, all analyses are preserved. return PreservedAnalyses::all(); } // Mark all the analyses that instcombine updates as preserved. 
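foldMemChr above turns a memchr over a short constant string into a switch on the needle byte, provided the needle is not itself a constant and the length fits under memchr-inline-threshold (default 3). A source-level sketch of the effect for memchr("ab", c, 2) (illustrative only, not taken from the diff):

    // Conceptually, the emitted switch selects the index of the first
    // occurrence of the needle, or falls through to a null result.
    const char *find_in_ab(const char *base /* == "ab" */, int c) {
      switch ((unsigned char)c) {
      case 'a': return base + 0;   // memchr.case -> memchr.success, idx 0
      case 'b': return base + 1;   // memchr.case -> memchr.success, idx 1
      default:  return nullptr;    // default edge to the tail block
      }
    }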
PreservedAnalyses PA; - PA.preserveSet<CFGAnalyses>(); + if (MadeCFGChange) + PA.preserve<DominatorTreeAnalysis>(); + else + PA.preserveSet<CFGAnalyses>(); return PA; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/CFGuard/CFGuard.cpp b/contrib/llvm-project/llvm/lib/Transforms/CFGuard/CFGuard.cpp index 4d4306576017..0e1a0a6ed947 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/CFGuard/CFGuard.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/CFGuard/CFGuard.cpp @@ -18,6 +18,7 @@ #include "llvm/IR/CallingConv.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/TargetParser/Triple.h" @@ -219,7 +220,7 @@ void CFGuardImpl::insertCFGuardDispatch(CallBase *CB) { // Create a copy of the call/invoke instruction and add the new bundle. assert((isa<CallInst>(CB) || isa<InvokeInst>(CB)) && "Unknown indirect call type"); - CallBase *NewCB = CallBase::Create(CB, Bundles, CB); + CallBase *NewCB = CallBase::Create(CB, Bundles, CB->getIterator()); // Change the target of the call to be the guard dispatch function. NewCB->setCalledOperand(GuardDispatchLoad); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp index 3e3825fcd50e..dd92b3593af9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp @@ -8,10 +8,11 @@ #include "llvm/Transforms/Coroutines/CoroCleanup.h" #include "CoroInternal.h" +#include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" +#include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" -#include "llvm/IR/Function.h" #include "llvm/Transforms/Scalar/SimplifyCFG.h" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroEarly.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroEarly.cpp index 489106422e19..d8e827e9cebc 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroEarly.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroEarly.cpp @@ -148,6 +148,7 @@ void Lowerer::lowerCoroNoop(IntrinsicInst *II) { NoopCoro = new GlobalVariable(M, NoopCoroConst->getType(), /*isConstant=*/true, GlobalVariable::PrivateLinkage, NoopCoroConst, "NoopCoro.Frame.Const"); + cast<GlobalVariable>(NoopCoro)->setNoSanitizeMetadata(); } Builder.SetInsertPoint(II); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroElide.cpp index 2f4083028ae0..598ef7779d77 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroElide.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroElide.cpp @@ -33,24 +33,47 @@ static cl::opt<std::string> CoroElideInfoOutputFilename( namespace { // Created on demand if the coro-elide pass has work to do. -struct Lowerer : coro::LowererBase { +class FunctionElideInfo { +public: + FunctionElideInfo(Function *F) : ContainingFunction(F) { + this->collectPostSplitCoroIds(); + } + + bool hasCoroIds() const { return !CoroIds.empty(); } + + const SmallVectorImpl<CoroIdInst *> &getCoroIds() const { return CoroIds; } + +private: + Function *ContainingFunction; SmallVector<CoroIdInst *, 4> CoroIds; + // Used in canCoroBeginEscape to distinguish coro.suspend switchs. 
+ SmallPtrSet<const SwitchInst *, 4> CoroSuspendSwitches; + + void collectPostSplitCoroIds(); + friend class CoroIdElider; +}; + +class CoroIdElider { +public: + CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI, AAResults &AA, + DominatorTree &DT, OptimizationRemarkEmitter &ORE); + void elideHeapAllocations(uint64_t FrameSize, Align FrameAlign); + bool lifetimeEligibleForElide() const; + bool attemptElide(); + bool canCoroBeginEscape(const CoroBeginInst *, + const SmallPtrSetImpl<BasicBlock *> &) const; + +private: + CoroIdInst *CoroId; + FunctionElideInfo &FEI; + AAResults &AA; + DominatorTree &DT; + OptimizationRemarkEmitter &ORE; + SmallVector<CoroBeginInst *, 1> CoroBegins; SmallVector<CoroAllocInst *, 1> CoroAllocs; SmallVector<CoroSubFnInst *, 4> ResumeAddr; DenseMap<CoroBeginInst *, SmallVector<CoroSubFnInst *, 4>> DestroyAddr; - SmallPtrSet<const SwitchInst *, 4> CoroSuspendSwitches; - - Lowerer(Module &M) : LowererBase(M) {} - - void elideHeapAllocations(Function *F, uint64_t FrameSize, Align FrameAlign, - AAResults &AA); - bool shouldElide(Function *F, DominatorTree &DT) const; - void collectPostSplitCoroIds(Function *F); - bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT, - OptimizationRemarkEmitter &ORE); - bool hasEscapePath(const CoroBeginInst *, - const SmallPtrSetImpl<BasicBlock *> &) const; }; } // end anonymous namespace @@ -136,13 +159,66 @@ static std::unique_ptr<raw_fd_ostream> getOrCreateLogFile() { } #endif +void FunctionElideInfo::collectPostSplitCoroIds() { + for (auto &I : instructions(this->ContainingFunction)) { + if (auto *CII = dyn_cast<CoroIdInst>(&I)) + if (CII->getInfo().isPostSplit()) + // If it is the coroutine itself, don't touch it. + if (CII->getCoroutine() != CII->getFunction()) + CoroIds.push_back(CII); + + // Consider case like: + // %0 = call i8 @llvm.coro.suspend(...) + // switch i8 %0, label %suspend [i8 0, label %resume + // i8 1, label %cleanup] + // and collect the SwitchInsts which are used by escape analysis later. + if (auto *CSI = dyn_cast<CoroSuspendInst>(&I)) + if (CSI->hasOneUse() && isa<SwitchInst>(CSI->use_begin()->getUser())) { + SwitchInst *SWI = cast<SwitchInst>(CSI->use_begin()->getUser()); + if (SWI->getNumCases() == 2) + CoroSuspendSwitches.insert(SWI); + } + } +} + +CoroIdElider::CoroIdElider(CoroIdInst *CoroId, FunctionElideInfo &FEI, + AAResults &AA, DominatorTree &DT, + OptimizationRemarkEmitter &ORE) + : CoroId(CoroId), FEI(FEI), AA(AA), DT(DT), ORE(ORE) { + // Collect all coro.begin and coro.allocs associated with this coro.id. + for (User *U : CoroId->users()) { + if (auto *CB = dyn_cast<CoroBeginInst>(U)) + CoroBegins.push_back(CB); + else if (auto *CA = dyn_cast<CoroAllocInst>(U)) + CoroAllocs.push_back(CA); + } + + // Collect all coro.subfn.addrs associated with coro.begin. + // Note, we only devirtualize the calls if their coro.subfn.addr refers to + // coro.begin directly. If we run into cases where this check is too + // conservative, we can consider relaxing the check. + for (CoroBeginInst *CB : CoroBegins) { + for (User *U : CB->users()) + if (auto *II = dyn_cast<CoroSubFnInst>(U)) + switch (II->getIndex()) { + case CoroSubFnInst::ResumeIndex: + ResumeAddr.push_back(II); + break; + case CoroSubFnInst::DestroyIndex: + DestroyAddr[CB].push_back(II); + break; + default: + llvm_unreachable("unexpected coro.subfn.addr constant"); + } + } +} + // To elide heap allocations we need to suppress code blocks guarded by // llvm.coro.alloc and llvm.coro.free instructions. 
-void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize, - Align FrameAlign, AAResults &AA) { - LLVMContext &C = F->getContext(); - auto *InsertPt = - getFirstNonAllocaInTheEntryBlock(CoroIds.front()->getFunction()); +void CoroIdElider::elideHeapAllocations(uint64_t FrameSize, Align FrameAlign) { + LLVMContext &C = FEI.ContainingFunction->getContext(); + BasicBlock::iterator InsertPt = + getFirstNonAllocaInTheEntryBlock(FEI.ContainingFunction)->getIterator(); // Replacing llvm.coro.alloc with false will suppress dynamic // allocation as it is expected for the frontend to generate the code that @@ -160,7 +236,7 @@ void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize, // is spilled into the coroutine frame and recreate the alignment information // here. Possibly we will need to do a mini SROA here and break the coroutine // frame into individual AllocaInst recreating the original alignment. - const DataLayout &DL = F->getParent()->getDataLayout(); + const DataLayout &DL = FEI.ContainingFunction->getDataLayout(); auto FrameTy = ArrayType::get(Type::getInt8Ty(C), FrameSize); auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt); Frame->setAlignment(FrameAlign); @@ -177,8 +253,8 @@ void Lowerer::elideHeapAllocations(Function *F, uint64_t FrameSize, removeTailCallAttribute(Frame, AA); } -bool Lowerer::hasEscapePath(const CoroBeginInst *CB, - const SmallPtrSetImpl<BasicBlock *> &TIs) const { +bool CoroIdElider::canCoroBeginEscape( + const CoroBeginInst *CB, const SmallPtrSetImpl<BasicBlock *> &TIs) const { const auto &It = DestroyAddr.find(CB); assert(It != DestroyAddr.end()); @@ -247,7 +323,7 @@ bool Lowerer::hasEscapePath(const CoroBeginInst *CB, // which means a escape path to normal terminator, it is reasonable to skip // it since coroutine frame doesn't change outside the coroutine body. if (isa<SwitchInst>(TI) && - CoroSuspendSwitches.count(cast<SwitchInst>(TI))) { + FEI.CoroSuspendSwitches.count(cast<SwitchInst>(TI))) { Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(1)); Worklist.push_back(cast<SwitchInst>(TI)->getSuccessor(2)); } else @@ -260,7 +336,7 @@ bool Lowerer::hasEscapePath(const CoroBeginInst *CB, return false; } -bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const { +bool CoroIdElider::lifetimeEligibleForElide() const { // If no CoroAllocs, we cannot suppress allocation, so elision is not // possible. if (CoroAllocs.empty()) @@ -269,6 +345,7 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const { // Check that for every coro.begin there is at least one coro.destroy directly // referencing the SSA value of that coro.begin along each // non-exceptional path. + // // If the value escaped, then coro.destroy would have been referencing a // memory location storing that value and not the virtual register. @@ -276,7 +353,7 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const { // First gather all of the terminators for the function. // Consider the final coro.suspend as the real terminator when the current // function is a coroutine. - for (BasicBlock &B : *F) { + for (BasicBlock &B : *FEI.ContainingFunction) { auto *TI = B.getTerminator(); if (TI->getNumSuccessors() != 0 || isa<UnreachableInst>(TI)) @@ -286,91 +363,43 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const { } // Filter out the coro.destroy that lie along exceptional paths. 
- SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins; - for (const auto &It : DestroyAddr) { + for (const auto *CB : CoroBegins) { + auto It = DestroyAddr.find(CB); + + // FIXME: If we have not found any destroys for this coro.begin, we + // disqualify this elide. + if (It == DestroyAddr.end()) + return false; + + const auto &CorrespondingDestroyAddrs = It->second; + // If every terminators is dominated by coro.destroy, we could know the // corresponding coro.begin wouldn't escape. - // - // Otherwise hasEscapePath would decide whether there is any paths from + auto DominatesTerminator = [&](auto *TI) { + return llvm::any_of(CorrespondingDestroyAddrs, [&](auto *Destroy) { + return DT.dominates(Destroy, TI->getTerminator()); + }); + }; + + if (llvm::all_of(Terminators, DominatesTerminator)) + continue; + + // Otherwise canCoroBeginEscape would decide whether there is any paths from // coro.begin to Terminators which not pass through any of the - // coro.destroys. + // coro.destroys. This is a slower analysis. // - // hasEscapePath is relatively slow, so we avoid to run it as much as + // canCoroBeginEscape is relatively slow, so we avoid to run it as much as // possible. - if (llvm::all_of(Terminators, - [&](auto *TI) { - return llvm::any_of(It.second, [&](auto *DA) { - return DT.dominates(DA, TI->getTerminator()); - }); - }) || - !hasEscapePath(It.first, Terminators)) - ReferencedCoroBegins.insert(It.first); + if (canCoroBeginEscape(CB, Terminators)) + return false; } - // If size of the set is the same as total number of coro.begin, that means we - // found a coro.free or coro.destroy referencing each coro.begin, so we can - // perform heap elision. - return ReferencedCoroBegins.size() == CoroBegins.size(); -} - -void Lowerer::collectPostSplitCoroIds(Function *F) { - CoroIds.clear(); - CoroSuspendSwitches.clear(); - for (auto &I : instructions(F)) { - if (auto *CII = dyn_cast<CoroIdInst>(&I)) - if (CII->getInfo().isPostSplit()) - // If it is the coroutine itself, don't touch it. - if (CII->getCoroutine() != CII->getFunction()) - CoroIds.push_back(CII); - - // Consider case like: - // %0 = call i8 @llvm.coro.suspend(...) - // switch i8 %0, label %suspend [i8 0, label %resume - // i8 1, label %cleanup] - // and collect the SwitchInsts which are used by escape analysis later. - if (auto *CSI = dyn_cast<CoroSuspendInst>(&I)) - if (CSI->hasOneUse() && isa<SwitchInst>(CSI->use_begin()->getUser())) { - SwitchInst *SWI = cast<SwitchInst>(CSI->use_begin()->getUser()); - if (SWI->getNumCases() == 2) - CoroSuspendSwitches.insert(SWI); - } - } + // We have checked all CoroBegins and their paths to the terminators without + // finding disqualifying code patterns, so we can perform heap allocations. + return true; } -bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA, - DominatorTree &DT, OptimizationRemarkEmitter &ORE) { - CoroBegins.clear(); - CoroAllocs.clear(); - ResumeAddr.clear(); - DestroyAddr.clear(); - - // Collect all coro.begin and coro.allocs associated with this coro.id. - for (User *U : CoroId->users()) { - if (auto *CB = dyn_cast<CoroBeginInst>(U)) - CoroBegins.push_back(CB); - else if (auto *CA = dyn_cast<CoroAllocInst>(U)) - CoroAllocs.push_back(CA); - } - - // Collect all coro.subfn.addrs associated with coro.begin. - // Note, we only devirtualize the calls if their coro.subfn.addr refers to - // coro.begin directly. If we run into cases where this check is too - // conservative, we can consider relaxing the check. 
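lifetimeEligibleForElide above accepts a coro.begin only when every function terminator is dominated by one of its coro.destroys, and only falls back to the slower canCoroBeginEscape walk otherwise. A hypothetical caller that satisfies the cheap dominance check (Task and makeTask are assumed names, not from the diff):

    void caller() {
      Task t = makeTask();   // coro.id / coro.begin for the callee's frame
      t.resume();            // devirtualized through coro.subfn.addr
    }                        // ~Task() issues coro.destroy on every exit path,
                             // so elideHeapAllocations() can replace the heap
                             // frame with a caller-local alloca.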
- for (CoroBeginInst *CB : CoroBegins) { - for (User *U : CB->users()) - if (auto *II = dyn_cast<CoroSubFnInst>(U)) - switch (II->getIndex()) { - case CoroSubFnInst::ResumeIndex: - ResumeAddr.push_back(II); - break; - case CoroSubFnInst::DestroyIndex: - DestroyAddr[CB].push_back(II); - break; - default: - llvm_unreachable("unexpected coro.subfn.addr constant"); - } - } - +bool CoroIdElider::attemptElide() { // PostSplit coro.id refers to an array of subfunctions in its Info // argument. ConstantArray *Resumers = CoroId->getInfo().Resumers; @@ -381,82 +410,68 @@ bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA, replaceWithConstant(ResumeAddrConstant, ResumeAddr); - bool ShouldElide = shouldElide(CoroId->getFunction(), DT); - if (!ShouldElide) - ORE.emit([&]() { - if (auto FrameSizeAndAlign = - getFrameLayout(cast<Function>(ResumeAddrConstant))) - return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId) - << "'" << ore::NV("callee", CoroId->getCoroutine()->getName()) - << "' not elided in '" - << ore::NV("caller", CoroId->getFunction()->getName()) - << "' (frame_size=" - << ore::NV("frame_size", FrameSizeAndAlign->first) << ", align=" - << ore::NV("align", FrameSizeAndAlign->second.value()) << ")"; - else - return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId) - << "'" << ore::NV("callee", CoroId->getCoroutine()->getName()) - << "' not elided in '" - << ore::NV("caller", CoroId->getFunction()->getName()) - << "' (frame_size=unknown, align=unknown)"; - }); + bool EligibleForElide = lifetimeEligibleForElide(); auto *DestroyAddrConstant = Resumers->getAggregateElement( - ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex); + EligibleForElide ? CoroSubFnInst::CleanupIndex + : CoroSubFnInst::DestroyIndex); for (auto &It : DestroyAddr) replaceWithConstant(DestroyAddrConstant, It.second); - if (ShouldElide) { - if (auto FrameSizeAndAlign = - getFrameLayout(cast<Function>(ResumeAddrConstant))) { - elideHeapAllocations(CoroId->getFunction(), FrameSizeAndAlign->first, - FrameSizeAndAlign->second, AA); - coro::replaceCoroFree(CoroId, /*Elide=*/true); - NumOfCoroElided++; + auto FrameSizeAndAlign = getFrameLayout(cast<Function>(ResumeAddrConstant)); + + auto CallerFunctionName = FEI.ContainingFunction->getName(); + auto CalleeCoroutineName = CoroId->getCoroutine()->getName(); + + if (EligibleForElide && FrameSizeAndAlign) { + elideHeapAllocations(FrameSizeAndAlign->first, FrameSizeAndAlign->second); + coro::replaceCoroFree(CoroId, /*Elide=*/true); + NumOfCoroElided++; + #ifndef NDEBUG if (!CoroElideInfoOutputFilename.empty()) - *getOrCreateLogFile() - << "Elide " << CoroId->getCoroutine()->getName() << " in " - << CoroId->getFunction()->getName() << "\n"; + *getOrCreateLogFile() << "Elide " << CalleeCoroutineName << " in " + << FEI.ContainingFunction->getName() << "\n"; #endif + ORE.emit([&]() { return OptimizationRemark(DEBUG_TYPE, "CoroElide", CoroId) - << "'" << ore::NV("callee", CoroId->getCoroutine()->getName()) - << "' elided in '" - << ore::NV("caller", CoroId->getFunction()->getName()) + << "'" << ore::NV("callee", CalleeCoroutineName) + << "' elided in '" << ore::NV("caller", CallerFunctionName) << "' (frame_size=" << ore::NV("frame_size", FrameSizeAndAlign->first) << ", align=" << ore::NV("align", FrameSizeAndAlign->second.value()) << ")"; }); - } else { - ORE.emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId) - << "'" << ore::NV("callee", CoroId->getCoroutine()->getName()) - << "' not elided in '" - << 
ore::NV("caller", CoroId->getFunction()->getName()) - << "' (frame_size=unknown, align=unknown)"; - }); - } + } else { + ORE.emit([&]() { + auto Remark = OptimizationRemarkMissed(DEBUG_TYPE, "CoroElide", CoroId) + << "'" << ore::NV("callee", CalleeCoroutineName) + << "' not elided in '" + << ore::NV("caller", CallerFunctionName); + + if (FrameSizeAndAlign) + return Remark << "' (frame_size=" + << ore::NV("frame_size", FrameSizeAndAlign->first) + << ", align=" + << ore::NV("align", FrameSizeAndAlign->second.value()) + << ")"; + else + return Remark << "' (frame_size=unknown, align=unknown)"; + }); } return true; } -static bool declaresCoroElideIntrinsics(Module &M) { - return coro::declaresIntrinsics(M, {"llvm.coro.id", "llvm.coro.id.async"}); -} - PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) { auto &M = *F.getParent(); - if (!declaresCoroElideIntrinsics(M)) + if (!coro::declaresIntrinsics(M, {"llvm.coro.id"})) return PreservedAnalyses::all(); - Lowerer L(M); - L.CoroIds.clear(); - L.collectPostSplitCoroIds(&F); - // If we did not find any coro.id, there is nothing to do. - if (L.CoroIds.empty()) + FunctionElideInfo FEI{&F}; + // Elide is not necessary if there's no coro.id within the function. + if (!FEI.hasCoroIds()) return PreservedAnalyses::all(); AAResults &AA = AM.getResult<AAManager>(F); @@ -464,8 +479,10 @@ PreservedAnalyses CoroElidePass::run(Function &F, FunctionAnalysisManager &AM) { auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); bool Changed = false; - for (auto *CII : L.CoroIds) - Changed |= L.processCoroId(CII, AA, DT, ORE); + for (auto *CII : FEI.getCoroIds()) { + CoroIdElider CIE(CII, FEI, AA, DT, ORE); + Changed |= CIE.attemptElide(); + } return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroFrame.cpp index e69c718f0ae3..73e30ea00a0e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallString.h" +#include "llvm/Analysis/CFG.h" #include "llvm/Analysis/PtrUseVisitor.h" #include "llvm/Analysis/StackLifetime.h" #include "llvm/Config/llvm-config.h" @@ -43,6 +44,8 @@ using namespace llvm; +extern cl::opt<bool> UseNewDbgInfoFormat; + // The "coro-suspend-crossing" flag is very noisy. There is another debug type, // "coro-frame", which results in leaner debug spew. 
#define DEBUG_TYPE "coro-suspend-crossing" @@ -971,7 +974,7 @@ static void cacheDIVar(FrameDataInfo &FrameData, DIVarCache.insert({V, (*I)->getVariable()}); }; CacheIt(findDbgDeclares(V)); - CacheIt(findDPVDeclares(V)); + CacheIt(findDVRDeclares(V)); } } @@ -1123,7 +1126,7 @@ static void buildFrameDebugInfo(Function &F, coro::Shape &Shape, "Coroutine with switch ABI should own Promise alloca"); TinyPtrVector<DbgDeclareInst *> DIs = findDbgDeclares(PromiseAlloca); - TinyPtrVector<DPValue *> DPVs = findDPVDeclares(PromiseAlloca); + TinyPtrVector<DbgVariableRecord *> DVRs = findDVRDeclares(PromiseAlloca); DILocalVariable *PromiseDIVariable = nullptr; DILocation *DILoc = nullptr; @@ -1131,10 +1134,10 @@ static void buildFrameDebugInfo(Function &F, coro::Shape &Shape, DbgDeclareInst *PromiseDDI = DIs.front(); PromiseDIVariable = PromiseDDI->getVariable(); DILoc = PromiseDDI->getDebugLoc().get(); - } else if (!DPVs.empty()) { - DPValue *PromiseDPV = DPVs.front(); - PromiseDIVariable = PromiseDPV->getVariable(); - DILoc = PromiseDPV->getDebugLoc().get(); + } else if (!DVRs.empty()) { + DbgVariableRecord *PromiseDVR = DVRs.front(); + PromiseDIVariable = PromiseDVR->getVariable(); + DILoc = PromiseDVR->getDebugLoc().get(); } else { return; } @@ -1150,7 +1153,7 @@ static void buildFrameDebugInfo(Function &F, coro::Shape &Shape, llvm::DINodeArray()); StructType *FrameTy = Shape.FrameTy; SmallVector<Metadata *, 16> Elements; - DataLayout Layout = F.getParent()->getDataLayout(); + DataLayout Layout = F.getDataLayout(); DenseMap<Value *, DILocalVariable *> DIVarCache; cacheDIVar(FrameData, DIVarCache); @@ -1273,11 +1276,12 @@ static void buildFrameDebugInfo(Function &F, coro::Shape &Shape, } if (UseNewDbgInfoFormat) { - DPValue *NewDPV = new DPValue(ValueAsMetadata::get(Shape.FramePtr), - FrameDIVar, DBuilder.createExpression(), - DILoc, DPValue::LocationType::Declare); + DbgVariableRecord *NewDVR = + new DbgVariableRecord(ValueAsMetadata::get(Shape.FramePtr), FrameDIVar, + DBuilder.createExpression(), DILoc, + DbgVariableRecord::LocationType::Declare); BasicBlock::iterator It = Shape.getInsertPtAfterFramePtr(); - It->getParent()->insertDPValueBefore(NewDPV, It); + It->getParent()->insertDbgRecordBefore(NewDVR, It); } else { DBuilder.insertDeclare(Shape.FramePtr, FrameDIVar, DBuilder.createExpression(), DILoc, @@ -1296,7 +1300,7 @@ static void buildFrameDebugInfo(Function &F, coro::Shape &Shape, static StructType *buildFrameType(Function &F, coro::Shape &Shape, FrameDataInfo &FrameData) { LLVMContext &C = F.getContext(); - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); StructType *FrameTy = [&] { SmallString<32> Name(F.getName()); Name.append(".Frame"); @@ -1439,17 +1443,22 @@ namespace { struct AllocaUseVisitor : PtrUseVisitor<AllocaUseVisitor> { using Base = PtrUseVisitor<AllocaUseVisitor>; AllocaUseVisitor(const DataLayout &DL, const DominatorTree &DT, - const CoroBeginInst &CB, const SuspendCrossingInfo &Checker, + const coro::Shape &CoroShape, + const SuspendCrossingInfo &Checker, bool ShouldUseLifetimeStartInfo) - : PtrUseVisitor(DL), DT(DT), CoroBegin(CB), Checker(Checker), - ShouldUseLifetimeStartInfo(ShouldUseLifetimeStartInfo) {} + : PtrUseVisitor(DL), DT(DT), CoroShape(CoroShape), Checker(Checker), + ShouldUseLifetimeStartInfo(ShouldUseLifetimeStartInfo) { + for (AnyCoroSuspendInst *SuspendInst : CoroShape.CoroSuspends) + CoroSuspendBBs.insert(SuspendInst->getParent()); + } void visit(Instruction &I) { Users.insert(&I); Base::visit(I); // If 
the pointer is escaped prior to CoroBegin, we have to assume it would // be written into before CoroBegin as well. - if (PI.isEscaped() && !DT.dominates(&CoroBegin, PI.getEscapingInst())) { + if (PI.isEscaped() && + !DT.dominates(CoroShape.CoroBegin, PI.getEscapingInst())) { MayWriteBeforeCoroBegin = true; } } @@ -1552,10 +1561,19 @@ struct AllocaUseVisitor : PtrUseVisitor<AllocaUseVisitor> { // When we found the lifetime markers refers to a // subrange of the original alloca, ignore the lifetime // markers to avoid misleading the analysis. - if (II.getIntrinsicID() != Intrinsic::lifetime_start || !IsOffsetKnown || - !Offset.isZero()) + if (!IsOffsetKnown || !Offset.isZero()) return Base::visitIntrinsicInst(II); - LifetimeStarts.insert(&II); + switch (II.getIntrinsicID()) { + default: + return Base::visitIntrinsicInst(II); + case Intrinsic::lifetime_start: + LifetimeStarts.insert(&II); + LifetimeStartBBs.push_back(II.getParent()); + break; + case Intrinsic::lifetime_end: + LifetimeEndBBs.insert(II.getParent()); + break; + } } void visitCallBase(CallBase &CB) { @@ -1585,7 +1603,7 @@ struct AllocaUseVisitor : PtrUseVisitor<AllocaUseVisitor> { private: const DominatorTree &DT; - const CoroBeginInst &CoroBegin; + const coro::Shape &CoroShape; const SuspendCrossingInfo &Checker; // All alias to the original AllocaInst, created before CoroBegin and used // after CoroBegin. Each entry contains the instruction and the offset in the @@ -1593,6 +1611,9 @@ private: DenseMap<Instruction *, std::optional<APInt>> AliasOffetMap{}; SmallPtrSet<Instruction *, 4> Users{}; SmallPtrSet<IntrinsicInst *, 2> LifetimeStarts{}; + SmallVector<BasicBlock *> LifetimeStartBBs{}; + SmallPtrSet<BasicBlock *, 2> LifetimeEndBBs{}; + SmallPtrSet<const BasicBlock *, 2> CoroSuspendBBs{}; bool MayWriteBeforeCoroBegin{false}; bool ShouldUseLifetimeStartInfo{true}; @@ -1604,10 +1625,19 @@ private: // every basic block that uses the pointer to see if they cross suspension // points. The uses cover both direct uses as well as indirect uses. if (ShouldUseLifetimeStartInfo && !LifetimeStarts.empty()) { - for (auto *I : Users) - for (auto *S : LifetimeStarts) - if (Checker.isDefinitionAcrossSuspend(*S, I)) - return true; + // If there is no explicit lifetime.end, then assume the address can + // cross suspension points. + if (LifetimeEndBBs.empty()) + return true; + + // If there is a path from a lifetime.start to a suspend without a + // corresponding lifetime.end, then the alloca's lifetime persists + // beyond that suspension point and the alloca must go on the frame. + llvm::SmallVector<BasicBlock *> Worklist(LifetimeStartBBs); + if (isManyPotentiallyReachableFromMany(Worklist, CoroSuspendBBs, + &LifetimeEndBBs, &DT)) + return true; + // Addresses are guaranteed to be identical after every lifetime.start so // we cannot use the local stack if the address escaped and there is a // suspend point between lifetime markers. This should also cover the @@ -1645,13 +1675,13 @@ private: } void handleMayWrite(const Instruction &I) { - if (!DT.dominates(&CoroBegin, &I)) + if (!DT.dominates(CoroShape.CoroBegin, &I)) MayWriteBeforeCoroBegin = true; } bool usedAfterCoroBegin(Instruction &I) { for (auto &U : I.uses()) - if (DT.dominates(&CoroBegin, U)) + if (DT.dominates(CoroShape.CoroBegin, U)) return true; return false; } @@ -1660,7 +1690,7 @@ private: // We track all aliases created prior to CoroBegin but used after. // These aliases may need to be recreated after CoroBegin if the alloca // need to live on the frame. 
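The reachability check added above keeps an alloca off the coroutine frame only when every path from one of its lifetime.start blocks to a suspend point passes through a lifetime.end. A hypothetical coroutine body illustrating both outcomes (Task, use, compute, and g are assumed names):

    Task f() {
      {
        char scratch[32];    // lifetime.end precedes every co_await, so no
        use(scratch);        // start->suspend path misses an end: stays local
      }
      int n = compute();     // still live across the suspend below, so a
      co_await g();          // suspend block is reachable from its
      co_return n;           // lifetime.start without crossing an end:
    }                        // n is spilled to the coroutine frame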
- if (DT.dominates(&CoroBegin, &I) || !usedAfterCoroBegin(I)) + if (DT.dominates(CoroShape.CoroBegin, &I) || !usedAfterCoroBegin(I)) return; if (!IsOffsetKnown) { @@ -1862,13 +1892,13 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { SpillAlignment, E.first->getName() + Twine(".reload")); TinyPtrVector<DbgDeclareInst *> DIs = findDbgDeclares(Def); - TinyPtrVector<DPValue *> DPVs = findDPVDeclares(Def); + TinyPtrVector<DbgVariableRecord *> DVRs = findDVRDeclares(Def); // Try best to find dbg.declare. If the spill is a temp, there may not // be a direct dbg.declare. Walk up the load chain to find one from an // alias. if (F->getSubprogram()) { auto *CurDef = Def; - while (DIs.empty() && DPVs.empty() && isa<LoadInst>(CurDef)) { + while (DIs.empty() && DVRs.empty() && isa<LoadInst>(CurDef)) { auto *LdInst = cast<LoadInst>(CurDef); // Only consider ptr to ptr same type load. if (LdInst->getPointerOperandType() != LdInst->getType()) @@ -1877,7 +1907,7 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { if (!isa<AllocaInst, LoadInst>(CurDef)) break; DIs = findDbgDeclares(CurDef); - DPVs = findDPVDeclares(CurDef); + DVRs = findDVRDeclares(CurDef); } } @@ -1887,12 +1917,12 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { // fragments. It will be unreachable in the main function, and // processed by coro::salvageDebugInfo() by CoroCloner. if (UseNewDbgInfoFormat) { - DPValue *NewDPV = - new DPValue(ValueAsMetadata::get(CurrentReload), - DDI->getVariable(), DDI->getExpression(), - DDI->getDebugLoc(), DPValue::LocationType::Declare); - Builder.GetInsertPoint()->getParent()->insertDPValueBefore( - NewDPV, Builder.GetInsertPoint()); + DbgVariableRecord *NewDVR = new DbgVariableRecord( + ValueAsMetadata::get(CurrentReload), DDI->getVariable(), + DDI->getExpression(), DDI->getDebugLoc(), + DbgVariableRecord::LocationType::Declare); + Builder.GetInsertPoint()->getParent()->insertDbgRecordBefore( + NewDVR, Builder.GetInsertPoint()); } else { DIBuilder(*CurrentBlock->getParent()->getParent(), AllowUnresolved) .insertDeclare(CurrentReload, DDI->getVariable(), @@ -1905,7 +1935,7 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { false /*UseEntryValue*/); }; for_each(DIs, SalvageOne); - for_each(DPVs, SalvageOne); + for_each(DVRs, SalvageOne); } // If we have a single edge PHINode, remove it and replace it with a @@ -1925,8 +1955,8 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { U->replaceUsesOfWith(Def, CurrentReload); // Instructions are added to Def's user list if the attached // debug records use Def. Update those now. 
- for (auto &DPV : U->getDbgValueRange()) - DPV.replaceVariableLocationOp(Def, CurrentReload, true); + for (DbgVariableRecord &DVR : filterDbgVars(U->getDbgRecordRange())) + DVR.replaceVariableLocationOp(Def, CurrentReload, true); } } @@ -1977,12 +2007,12 @@ static void insertSpills(const FrameDataInfo &FrameData, coro::Shape &Shape) { G->setName(Alloca->getName() + Twine(".reload.addr")); SmallVector<DbgVariableIntrinsic *, 4> DIs; - SmallVector<DPValue *> DPValues; - findDbgUsers(DIs, Alloca, &DPValues); + SmallVector<DbgVariableRecord *> DbgVariableRecords; + findDbgUsers(DIs, Alloca, &DbgVariableRecords); for (auto *DVI : DIs) DVI->replaceUsesOfWith(Alloca, G); - for (auto *DPV : DPValues) - DPV->replaceVariableLocationOp(Alloca, G); + for (auto *DVR : DbgVariableRecords) + DVR->replaceVariableLocationOp(Alloca, G); for (Instruction *I : UsersToUpdate) { // It is meaningless to retain the lifetime intrinsics refer for the @@ -2728,12 +2758,11 @@ static void sinkSpillUsesAfterCoroBegin(Function &F, /// after the suspend block. Doing so minimizes the lifetime of each variable, /// hence minimizing the amount of data we end up putting on the frame. static void sinkLifetimeStartMarkers(Function &F, coro::Shape &Shape, - SuspendCrossingInfo &Checker) { + SuspendCrossingInfo &Checker, + const DominatorTree &DT) { if (F.hasOptNone()) return; - DominatorTree DT(F); - // Collect all possible basic blocks which may dominate all uses of allocas. SmallPtrSet<BasicBlock *, 4> DomSet; DomSet.insert(&F.getEntryBlock()); @@ -2829,8 +2858,7 @@ static void collectFrameAlloca(AllocaInst *AI, coro::Shape &Shape, bool ShouldUseLifetimeStartInfo = (Shape.ABI != coro::ABI::Async && Shape.ABI != coro::ABI::Retcon && Shape.ABI != coro::ABI::RetconOnce); - AllocaUseVisitor Visitor{AI->getModule()->getDataLayout(), DT, - *Shape.CoroBegin, Checker, + AllocaUseVisitor Visitor{AI->getDataLayout(), DT, Shape, Checker, ShouldUseLifetimeStartInfo}; Visitor.visitPtr(*AI); if (!Visitor.getShouldLiveOnFrame()) @@ -2947,10 +2975,12 @@ void coro::salvageDebugInfo( std::optional<BasicBlock::iterator> InsertPt; if (auto *I = dyn_cast<Instruction>(Storage)) { InsertPt = I->getInsertionPointAfterDef(); - // Update DILocation only in O0 since it is easy to get out of sync in - // optimizations. See https://github.com/llvm/llvm-project/pull/75104 for - // an example. - if (!OptimizeFrame && I->getDebugLoc()) + // Update DILocation only if variable was not inlined. + DebugLoc ILoc = I->getDebugLoc(); + DebugLoc DVILoc = DVI.getDebugLoc(); + if (ILoc && DVILoc && + DVILoc->getScope()->getSubprogram() == + ILoc->getScope()->getSubprogram()) DVI.setDebugLoc(I->getDebugLoc()); } else if (isa<Argument>(Storage)) InsertPt = F->getEntryBlock().begin(); @@ -2960,43 +2990,45 @@ void coro::salvageDebugInfo( } void coro::salvageDebugInfo( - SmallDenseMap<Argument *, AllocaInst *, 4> &ArgToAllocaMap, DPValue &DPV, - bool OptimizeFrame, bool UseEntryValue) { + SmallDenseMap<Argument *, AllocaInst *, 4> &ArgToAllocaMap, + DbgVariableRecord &DVR, bool OptimizeFrame, bool UseEntryValue) { - Function *F = DPV.getFunction(); + Function *F = DVR.getFunction(); // Follow the pointer arithmetic all the way to the incoming // function argument and convert into a DIExpression. 
- bool SkipOutermostLoad = DPV.isDbgDeclare(); - Value *OriginalStorage = DPV.getVariableLocationOp(0); + bool SkipOutermostLoad = DVR.isDbgDeclare(); + Value *OriginalStorage = DVR.getVariableLocationOp(0); auto SalvagedInfo = ::salvageDebugInfoImpl( ArgToAllocaMap, OptimizeFrame, UseEntryValue, F, OriginalStorage, - DPV.getExpression(), SkipOutermostLoad); + DVR.getExpression(), SkipOutermostLoad); if (!SalvagedInfo) return; Value *Storage = &SalvagedInfo->first; DIExpression *Expr = &SalvagedInfo->second; - DPV.replaceVariableLocationOp(OriginalStorage, Storage); - DPV.setExpression(Expr); + DVR.replaceVariableLocationOp(OriginalStorage, Storage); + DVR.setExpression(Expr); // We only hoist dbg.declare today since it doesn't make sense to hoist // dbg.value since it does not have the same function wide guarantees that // dbg.declare does. - if (DPV.getType() == DPValue::LocationType::Declare) { + if (DVR.getType() == DbgVariableRecord::LocationType::Declare) { std::optional<BasicBlock::iterator> InsertPt; if (auto *I = dyn_cast<Instruction>(Storage)) { InsertPt = I->getInsertionPointAfterDef(); - // Update DILocation only in O0 since it is easy to get out of sync in - // optimizations. See https://github.com/llvm/llvm-project/pull/75104 for - // an example. - if (!OptimizeFrame && I->getDebugLoc()) - DPV.setDebugLoc(I->getDebugLoc()); + // Update DILocation only if variable was not inlined. + DebugLoc ILoc = I->getDebugLoc(); + DebugLoc DVRLoc = DVR.getDebugLoc(); + if (ILoc && DVRLoc && + DVRLoc->getScope()->getSubprogram() == + ILoc->getScope()->getSubprogram()) + DVR.setDebugLoc(ILoc); } else if (isa<Argument>(Storage)) InsertPt = F->getEntryBlock().begin(); if (InsertPt) { - DPV.removeFromParent(); - (*InsertPt)->getParent()->insertDPValueBefore(&DPV, *InsertPt); + DVR.removeFromParent(); + (*InsertPt)->getParent()->insertDbgRecordBefore(&DVR, *InsertPt); } } } @@ -3064,7 +3096,7 @@ static void doRematerializations( } void coro::buildCoroutineFrame( - Function &F, Shape &Shape, + Function &F, Shape &Shape, TargetTransformInfo &TTI, const std::function<bool(Instruction &)> &MaterializableCallback) { // Don't eliminate swifterror in async functions that won't be split. if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty()) @@ -3100,7 +3132,7 @@ void coro::buildCoroutineFrame( SmallVector<Value *, 8> Args(AsyncEnd->args()); auto Arguments = ArrayRef<Value *>(Args).drop_front(3); auto *Call = createMustTailCall(AsyncEnd->getDebugLoc(), MustTailCallFn, - Arguments, Builder); + TTI, Arguments, Builder); splitAround(Call, "MustTailCall.Before.CoroEnd"); } } @@ -3118,12 +3150,13 @@ void coro::buildCoroutineFrame( doRematerializations(F, Checker, MaterializableCallback); + const DominatorTree DT(F); FrameDataInfo FrameData; SmallVector<CoroAllocaAllocInst*, 4> LocalAllocas; SmallVector<Instruction*, 4> DeadInstructions; if (Shape.ABI != coro::ABI::Async && Shape.ABI != coro::ABI::Retcon && Shape.ABI != coro::ABI::RetconOnce) - sinkLifetimeStartMarkers(F, Shape, Checker); + sinkLifetimeStartMarkers(F, Shape, Checker, DT); // Collect the spills for arguments and other not-materializable values. for (Argument &A : F.args()) @@ -3131,7 +3164,6 @@ void coro::buildCoroutineFrame( if (Checker.isDefinitionAcrossSuspend(A, U)) FrameData.Spills[&A].push_back(cast<Instruction>(U)); - const DominatorTree DT(F); for (Instruction &I : instructions(F)) { // Values returned from coroutine structure intrinsics should not be part // of the Coroutine Frame. 
@@ -3188,15 +3220,15 @@ void coro::buildCoroutineFrame( for (auto &Iter : FrameData.Spills) { auto *V = Iter.first; SmallVector<DbgValueInst *, 16> DVIs; - SmallVector<DPValue *, 16> DPVs; - findDbgValues(DVIs, V, &DPVs); + SmallVector<DbgVariableRecord *, 16> DVRs; + findDbgValues(DVIs, V, &DVRs); for (DbgValueInst *DVI : DVIs) if (Checker.isDefinitionAcrossSuspend(*V, DVI)) FrameData.Spills[V].push_back(DVI); // Add the instructions which carry debug info that is in the frame. - for (DPValue *DPV : DPVs) - if (Checker.isDefinitionAcrossSuspend(*V, DPV->Marker->MarkedInstr)) - FrameData.Spills[V].push_back(DPV->Marker->MarkedInstr); + for (DbgVariableRecord *DVR : DVRs) + if (Checker.isDefinitionAcrossSuspend(*V, DVR->Marker->MarkedInstr)) + FrameData.Spills[V].push_back(DVR->Marker->MarkedInstr); } LLVM_DEBUG(dumpSpills("Spills", FrameData.Spills)); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInstr.h b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInstr.h index f01aa58eb899..a31703fe0130 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInstr.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInstr.h @@ -78,6 +78,39 @@ public: } }; +/// This represents the llvm.coro.await.suspend.{void,bool,handle} instructions. +// FIXME: add callback metadata +// FIXME: make a proper IntrinisicInst. Currently this is not possible, +// because llvm.coro.await.suspend.* can be invoked. +class LLVM_LIBRARY_VISIBILITY CoroAwaitSuspendInst : public CallBase { + enum { AwaiterArg, FrameArg, WrapperArg }; + +public: + Value *getAwaiter() const { return getArgOperand(AwaiterArg); } + + Value *getFrame() const { return getArgOperand(FrameArg); } + + Function *getWrapperFunction() const { + return cast<Function>(getArgOperand(WrapperArg)); + } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static bool classof(const CallBase *CB) { + if (const Function *CF = CB->getCalledFunction()) { + auto IID = CF->getIntrinsicID(); + return IID == Intrinsic::coro_await_suspend_void || + IID == Intrinsic::coro_await_suspend_bool || + IID == Intrinsic::coro_await_suspend_handle; + } + + return false; + } + + static bool classof(const Value *V) { + return isa<CallBase>(V) && classof(cast<CallBase>(V)); + } +}; + /// This represents a common base class for llvm.coro.id instructions. class LLVM_LIBRARY_VISIBILITY AnyCoroIdInst : public IntrinsicInst { public: diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInternal.h b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInternal.h index fb16a4090689..5716fd0ea4ab 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInternal.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroInternal.h @@ -12,6 +12,7 @@ #define LLVM_LIB_TRANSFORMS_COROUTINES_COROINTERNAL_H #include "CoroInstr.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/IRBuilder.h" namespace llvm { @@ -34,8 +35,8 @@ void salvageDebugInfo( SmallDenseMap<Argument *, AllocaInst *, 4> &ArgToAllocaMap, DbgVariableIntrinsic &DVI, bool OptimizeFrame, bool IsEntryPoint); void salvageDebugInfo( - SmallDenseMap<Argument *, AllocaInst *, 4> &ArgToAllocaMap, DPValue &DPV, - bool OptimizeFrame, bool UseEntryValue); + SmallDenseMap<Argument *, AllocaInst *, 4> &ArgToAllocaMap, + DbgVariableRecord &DVR, bool OptimizeFrame, bool UseEntryValue); // Keeps data and helper functions for lowering coroutine intrinsics. 
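CoroAwaitSuspendInst above derives from CallBase rather than IntrinsicInst because, as the FIXME notes, llvm.coro.await.suspend.* can also be invoked; its classof therefore keys off the called function's intrinsic ID. A minimal sketch of how such calls could be gathered into the new Shape::CoroAwaitSuspends list (illustrative only; the actual collection code is not part of this hunk):

    // Works for both CallInst and InvokeInst thanks to the CallBase classof.
    for (Instruction &I : instructions(F))
      if (auto *AWS = dyn_cast<CoroAwaitSuspendInst>(&I))
        Shape.CoroAwaitSuspends.push_back(AWS);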
struct LowererBase { @@ -46,7 +47,7 @@ struct LowererBase { ConstantPointerNull *const NullPtr; LowererBase(Module &M); - Value *makeSubFnCall(Value *Arg, int Index, Instruction *InsertPt); + CallInst *makeSubFnCall(Value *Arg, int Index, Instruction *InsertPt); }; enum class ABI { @@ -83,6 +84,8 @@ struct LLVM_LIBRARY_VISIBILITY Shape { SmallVector<CoroAlignInst *, 2> CoroAligns; SmallVector<AnyCoroSuspendInst *, 4> CoroSuspends; SmallVector<CallInst*, 2> SwiftErrorOps; + SmallVector<CoroAwaitSuspendInst *, 4> CoroAwaitSuspends; + SmallVector<CallInst *, 2> SymmetricTransfers; // Field indexes for special fields in the switch lowering. struct SwitchFieldIndex { @@ -272,9 +275,10 @@ struct LLVM_LIBRARY_VISIBILITY Shape { bool defaultMaterializable(Instruction &V); void buildCoroutineFrame( - Function &F, Shape &Shape, + Function &F, Shape &Shape, TargetTransformInfo &TTI, const std::function<bool(Instruction &)> &MaterializableCallback); CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn, + TargetTransformInfo &TTI, ArrayRef<Value *> Arguments, IRBuilder<> &); } // End namespace coro. } // End namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp index 7758b52abc20..9e4da5f8ca96 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -113,21 +113,24 @@ private: /// ABIs. AnyCoroSuspendInst *ActiveSuspend = nullptr; + TargetTransformInfo &TTI; + public: /// Create a cloner for a switch lowering. CoroCloner(Function &OrigF, const Twine &Suffix, coro::Shape &Shape, - Kind FKind) - : OrigF(OrigF), NewF(nullptr), Suffix(Suffix), Shape(Shape), - FKind(FKind), Builder(OrigF.getContext()) { + Kind FKind, TargetTransformInfo &TTI) + : OrigF(OrigF), NewF(nullptr), Suffix(Suffix), Shape(Shape), FKind(FKind), + Builder(OrigF.getContext()), TTI(TTI) { assert(Shape.ABI == coro::ABI::Switch); } /// Create a cloner for a continuation lowering. CoroCloner(Function &OrigF, const Twine &Suffix, coro::Shape &Shape, - Function *NewF, AnyCoroSuspendInst *ActiveSuspend) + Function *NewF, AnyCoroSuspendInst *ActiveSuspend, + TargetTransformInfo &TTI) : OrigF(OrigF), NewF(NewF), Suffix(Suffix), Shape(Shape), FKind(Shape.ABI == coro::ABI::Async ? Kind::Async : Kind::Continuation), - Builder(OrigF.getContext()), ActiveSuspend(ActiveSuspend) { + Builder(OrigF.getContext()), ActiveSuspend(ActiveSuspend), TTI(TTI) { assert(Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce || Shape.ABI == coro::ABI::Async); assert(NewF && "need existing function for continuation"); @@ -167,11 +170,86 @@ private: } // end anonymous namespace +// FIXME: +// Lower the intrinisc in CoroEarly phase if coroutine frame doesn't escape +// and it is known that other transformations, for example, sanitizers +// won't lead to incorrect code. +static void lowerAwaitSuspend(IRBuilder<> &Builder, CoroAwaitSuspendInst *CB, + coro::Shape &Shape) { + auto Wrapper = CB->getWrapperFunction(); + auto Awaiter = CB->getAwaiter(); + auto FramePtr = CB->getFrame(); + + Builder.SetInsertPoint(CB); + + CallBase *NewCall = nullptr; + // await_suspend has only 2 parameters, awaiter and handle. + // Copy parameter attributes from the intrinsic call, but remove the last, + // because the last parameter now becomes the function that is being called. 
+ AttributeList NewAttributes = + CB->getAttributes().removeParamAttributes(CB->getContext(), 2); + + if (auto Invoke = dyn_cast<InvokeInst>(CB)) { + auto WrapperInvoke = + Builder.CreateInvoke(Wrapper, Invoke->getNormalDest(), + Invoke->getUnwindDest(), {Awaiter, FramePtr}); + + WrapperInvoke->setCallingConv(Invoke->getCallingConv()); + std::copy(Invoke->bundle_op_info_begin(), Invoke->bundle_op_info_end(), + WrapperInvoke->bundle_op_info_begin()); + WrapperInvoke->setAttributes(NewAttributes); + WrapperInvoke->setDebugLoc(Invoke->getDebugLoc()); + NewCall = WrapperInvoke; + } else if (auto Call = dyn_cast<CallInst>(CB)) { + auto WrapperCall = Builder.CreateCall(Wrapper, {Awaiter, FramePtr}); + + WrapperCall->setAttributes(NewAttributes); + WrapperCall->setDebugLoc(Call->getDebugLoc()); + NewCall = WrapperCall; + } else { + llvm_unreachable("Unexpected coro_await_suspend invocation method"); + } + + if (CB->getCalledFunction()->getIntrinsicID() == + Intrinsic::coro_await_suspend_handle) { + // Follow the lowered await_suspend call above with a lowered resume call + // to the returned coroutine. + if (auto *Invoke = dyn_cast<InvokeInst>(CB)) { + // If the await_suspend call is an invoke, we continue in the next block. + Builder.SetInsertPoint(Invoke->getNormalDest()->getFirstInsertionPt()); + } + + coro::LowererBase LB(*Wrapper->getParent()); + auto *ResumeAddr = LB.makeSubFnCall(NewCall, CoroSubFnInst::ResumeIndex, + &*Builder.GetInsertPoint()); + + LLVMContext &Ctx = Builder.getContext(); + FunctionType *ResumeTy = FunctionType::get( + Type::getVoidTy(Ctx), PointerType::getUnqual(Ctx), false); + auto *ResumeCall = Builder.CreateCall(ResumeTy, ResumeAddr, {NewCall}); + ResumeCall->setCallingConv(CallingConv::Fast); + + // We can't insert the 'ret' instruction and adjust the cc until the + // function has been split, so remember this for later. 
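For the handle variant, the code above calls the wrapper, looks up the resume pointer of the returned handle through llvm.coro.subfn.addr, and calls it with fastcc; CoroCloner::create later turns that remembered call into a musttail call where the target supports it. At the source level this is symmetric transfer, roughly as in the following hedged sketch (the awaiter and member names are illustrative):

#include <coroutine>

struct SymmetricTransferAwaiter {
  std::coroutine_handle<> continuation;

  bool await_ready() const noexcept { return false; }

  // Corresponds to llvm.coro.await.suspend.handle: the handle returned here
  // is resumed immediately, and once the coroutine has been split that resume
  // becomes a guaranteed tail call, so long chains of transfers do not grow
  // the stack.
  std::coroutine_handle<>
  await_suspend(std::coroutine_handle<>) const noexcept {
    return continuation;
  }

  void await_resume() const noexcept {}
};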
+ Shape.SymmetricTransfers.push_back(ResumeCall); + + NewCall = ResumeCall; + } + + CB->replaceAllUsesWith(NewCall); + CB->eraseFromParent(); +} + +static void lowerAwaitSuspends(Function &F, coro::Shape &Shape) { + IRBuilder<> Builder(F.getContext()); + for (auto *AWS : Shape.CoroAwaitSuspends) + lowerAwaitSuspend(Builder, AWS, Shape); +} + static void maybeFreeRetconStorage(IRBuilder<> &Builder, const coro::Shape &Shape, Value *FramePtr, CallGraph *CG) { - assert(Shape.ABI == coro::ABI::Retcon || - Shape.ABI == coro::ABI::RetconOnce); + assert(Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce); if (Shape.RetconLowering.IsFrameInlineInStorage) return; @@ -269,7 +347,7 @@ static void replaceFallthroughCoroEnd(AnyCoroEndInst *End, if (auto *RetStructTy = dyn_cast<StructType>(RetTy)) { assert(RetStructTy->getNumElements() == NumReturns && - "numbers of returns should match resume function singature"); + "numbers of returns should match resume function singature"); Value *ReturnValue = UndefValue::get(RetStructTy); unsigned Idx = 0; for (Value *RetValEl : CoroResults->return_values()) @@ -282,7 +360,8 @@ static void replaceFallthroughCoroEnd(AnyCoroEndInst *End, assert(NumReturns == 1); Builder.CreateRet(*CoroResults->retval_begin()); } - CoroResults->replaceAllUsesWith(ConstantTokenNone::get(CoroResults->getContext())); + CoroResults->replaceAllUsesWith( + ConstantTokenNone::get(CoroResults->getContext())); CoroResults->eraseFromParent(); break; } @@ -296,7 +375,7 @@ static void replaceFallthroughCoroEnd(AnyCoroEndInst *End, auto RetTy = Shape.getResumeFunctionType()->getReturnType(); auto RetStructTy = dyn_cast<StructType>(RetTy); PointerType *ContinuationTy = - cast<PointerType>(RetStructTy ? RetStructTy->getElementType(0) : RetTy); + cast<PointerType>(RetStructTy ? RetStructTy->getElementType(0) : RetTy); Value *ReturnValue = ConstantPointerNull::get(ContinuationTy); if (RetStructTy) { @@ -407,104 +486,6 @@ static void replaceCoroEnd(AnyCoroEndInst *End, const coro::Shape &Shape, End->eraseFromParent(); } -// Create an entry block for a resume function with a switch that will jump to -// suspend points. -static void createResumeEntryBlock(Function &F, coro::Shape &Shape) { - assert(Shape.ABI == coro::ABI::Switch); - LLVMContext &C = F.getContext(); - - // resume.entry: - // %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0, - // i32 2 - // % index = load i32, i32* %index.addr - // switch i32 %index, label %unreachable [ - // i32 0, label %resume.0 - // i32 1, label %resume.1 - // ... - // ] - - auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F); - auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F); - - IRBuilder<> Builder(NewEntry); - auto *FramePtr = Shape.FramePtr; - auto *FrameTy = Shape.FrameTy; - auto *GepIndex = Builder.CreateStructGEP( - FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr"); - auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index"); - auto *Switch = - Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size()); - Shape.SwitchLowering.ResumeSwitch = Switch; - - size_t SuspendIndex = 0; - for (auto *AnyS : Shape.CoroSuspends) { - auto *S = cast<CoroSuspendInst>(AnyS); - ConstantInt *IndexVal = Shape.getIndex(SuspendIndex); - - // Replace CoroSave with a store to Index: - // %index.addr = getelementptr %f.frame... 
(index field number) - // store i32 %IndexVal, i32* %index.addr1 - auto *Save = S->getCoroSave(); - Builder.SetInsertPoint(Save); - if (S->isFinal()) { - // The coroutine should be marked done if it reaches the final suspend - // point. - markCoroutineAsDone(Builder, Shape, FramePtr); - } else { - auto *GepIndex = Builder.CreateStructGEP( - FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr"); - Builder.CreateStore(IndexVal, GepIndex); - } - - Save->replaceAllUsesWith(ConstantTokenNone::get(C)); - Save->eraseFromParent(); - - // Split block before and after coro.suspend and add a jump from an entry - // switch: - // - // whateverBB: - // whatever - // %0 = call i8 @llvm.coro.suspend(token none, i1 false) - // switch i8 %0, label %suspend[i8 0, label %resume - // i8 1, label %cleanup] - // becomes: - // - // whateverBB: - // whatever - // br label %resume.0.landing - // - // resume.0: ; <--- jump from the switch in the resume.entry - // %0 = tail call i8 @llvm.coro.suspend(token none, i1 false) - // br label %resume.0.landing - // - // resume.0.landing: - // %1 = phi i8[-1, %whateverBB], [%0, %resume.0] - // switch i8 % 1, label %suspend [i8 0, label %resume - // i8 1, label %cleanup] - - auto *SuspendBB = S->getParent(); - auto *ResumeBB = - SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex)); - auto *LandingBB = ResumeBB->splitBasicBlock( - S->getNextNode(), ResumeBB->getName() + Twine(".landing")); - Switch->addCase(IndexVal, ResumeBB); - - cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB); - auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, ""); - PN->insertBefore(LandingBB->begin()); - S->replaceAllUsesWith(PN); - PN->addIncoming(Builder.getInt8(-1), SuspendBB); - PN->addIncoming(S, ResumeBB); - - ++SuspendIndex; - } - - Builder.SetInsertPoint(UnreachBB); - Builder.CreateUnreachable(); - - Shape.SwitchLowering.ResumeEntryBlock = NewEntry; -} - // In the resume function, we remove the last case (when coro::Shape is built, // the final suspend point (if present) is always the last element of // CoroSuspends array) since it is an undefined behavior to resume a coroutine @@ -583,11 +564,12 @@ void CoroCloner::replaceRetconOrAsyncSuspendUses() { Shape.ABI == coro::ABI::Async); auto NewS = VMap[ActiveSuspend]; - if (NewS->use_empty()) return; + if (NewS->use_empty()) + return; // Copy out all the continuation arguments after the buffer pointer into // an easily-indexed data structure for convenience. - SmallVector<Value*, 8> Args; + SmallVector<Value *, 8> Args; // The async ABI includes all arguments -- including the first argument. bool IsAsyncABI = Shape.ABI == coro::ABI::Async; for (auto I = IsAsyncABI ? NewF->arg_begin() : std::next(NewF->arg_begin()), @@ -614,7 +596,8 @@ void CoroCloner::replaceRetconOrAsyncSuspendUses() { } // If we have no remaining uses, we're done. - if (NewS->use_empty()) return; + if (NewS->use_empty()) + return; // Otherwise, we need to create an aggregate. Value *Agg = PoisonValue::get(NewS->getType()); @@ -652,7 +635,8 @@ void CoroCloner::replaceCoroSuspends() { for (AnyCoroSuspendInst *CS : Shape.CoroSuspends) { // The active suspend was handled earlier. - if (CS == ActiveSuspend) continue; + if (CS == ActiveSuspend) + continue; auto *MappedCS = cast<AnyCoroSuspendInst>(VMap[CS]); MappedCS->replaceAllUsesWith(SuspendResult); @@ -725,17 +709,18 @@ static void replaceSwiftErrorOps(Function &F, coro::Shape &Shape, } /// Returns all DbgVariableIntrinsic in F. 
-static std::pair<SmallVector<DbgVariableIntrinsic *, 8>, SmallVector<DPValue *>> +static std::pair<SmallVector<DbgVariableIntrinsic *, 8>, + SmallVector<DbgVariableRecord *>> collectDbgVariableIntrinsics(Function &F) { SmallVector<DbgVariableIntrinsic *, 8> Intrinsics; - SmallVector<DPValue *> DPValues; + SmallVector<DbgVariableRecord *> DbgVariableRecords; for (auto &I : instructions(F)) { - for (DPValue &DPV : I.getDbgValueRange()) - DPValues.push_back(&DPV); + for (DbgVariableRecord &DVR : filterDbgVars(I.getDbgRecordRange())) + DbgVariableRecords.push_back(&DVR); if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I)) Intrinsics.push_back(DVI); } - return {Intrinsics, DPValues}; + return {Intrinsics, DbgVariableRecords}; } void CoroCloner::replaceSwiftErrorOps() { @@ -743,7 +728,7 @@ void CoroCloner::replaceSwiftErrorOps() { } void CoroCloner::salvageDebugInfo() { - auto [Worklist, DPValues] = collectDbgVariableIntrinsics(*NewF); + auto [Worklist, DbgVariableRecords] = collectDbgVariableIntrinsics(*NewF); SmallDenseMap<Argument *, AllocaInst *, 4> ArgToAllocaMap; // Only 64-bit ABIs have a register we can refer to with the entry value. @@ -752,8 +737,8 @@ void CoroCloner::salvageDebugInfo() { for (DbgVariableIntrinsic *DVI : Worklist) coro::salvageDebugInfo(ArgToAllocaMap, *DVI, Shape.OptimizeFrame, UseEntryValue); - for (DPValue *DPV : DPValues) - coro::salvageDebugInfo(ArgToAllocaMap, *DPV, Shape.OptimizeFrame, + for (DbgVariableRecord *DVR : DbgVariableRecords) + coro::salvageDebugInfo(ArgToAllocaMap, *DVR, Shape.OptimizeFrame, UseEntryValue); // Remove all salvaged dbg.declare intrinsics that became @@ -778,7 +763,7 @@ void CoroCloner::salvageDebugInfo() { } }; for_each(Worklist, RemoveOne); - for_each(DPValues, RemoveOne); + for_each(DbgVariableRecords, RemoveOne); } void CoroCloner::replaceEntryBlock() { @@ -810,7 +795,7 @@ void CoroCloner::replaceEntryBlock() { // In switch-lowering, we built a resume-entry block in the original // function. Make the entry block branch to this. auto *SwitchBB = - cast<BasicBlock>(VMap[Shape.SwitchLowering.ResumeEntryBlock]); + cast<BasicBlock>(VMap[Shape.SwitchLowering.ResumeEntryBlock]); Builder.CreateBr(SwitchBB); break; } @@ -1101,6 +1086,24 @@ void CoroCloner::create() { // Set up the new entry block. replaceEntryBlock(); + // Turn symmetric transfers into musttail calls. + for (CallInst *ResumeCall : Shape.SymmetricTransfers) { + ResumeCall = cast<CallInst>(VMap[ResumeCall]); + if (TTI.supportsTailCallFor(ResumeCall)) { + // FIXME: Could we support symmetric transfer effectively without + // musttail? + ResumeCall->setTailCallKind(CallInst::TCK_MustTail); + } + + // Put a 'ret void' after the call, and split any remaining instructions to + // an unreachable block. + BasicBlock *BB = ResumeCall->getParent(); + BB->splitBasicBlock(ResumeCall->getNextNode()); + Builder.SetInsertPoint(BB->getTerminator()); + Builder.CreateRetVoid(); + BB->getTerminator()->eraseFromParent(); + } + Builder.SetInsertPoint(&NewF->getEntryBlock().front()); NewFramePtr = deriveNewFramePointer(); @@ -1158,17 +1161,7 @@ void CoroCloner::create() { // to suppress deallocation code. if (Shape.ABI == coro::ABI::Switch) coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]), - /*Elide=*/ FKind == CoroCloner::Kind::SwitchCleanup); -} - -// Create a resume clone by cloning the body of the original function, setting -// new entry block and replacing coro.suspend an appropriate value to force -// resume or cleanup pass for every suspend point. 
-static Function *createClone(Function &F, const Twine &Suffix, - coro::Shape &Shape, CoroCloner::Kind FKind) { - CoroCloner Cloner(F, Suffix, Shape, FKind); - Cloner.create(); - return Cloner.getFunction(); + /*Elide=*/FKind == CoroCloner::Kind::SwitchCleanup); } static void updateAsyncFuncPointerContextSize(coro::Shape &Shape) { @@ -1212,67 +1205,6 @@ static void replaceFrameSizeAndAlignment(coro::Shape &Shape) { } } -// Create a global constant array containing pointers to functions provided and -// set Info parameter of CoroBegin to point at this constant. Example: -// -// @f.resumers = internal constant [2 x void(%f.frame*)*] -// [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy] -// define void @f() { -// ... -// call i8* @llvm.coro.begin(i8* null, i32 0, i8* null, -// i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*)) -// -// Assumes that all the functions have the same signature. -static void setCoroInfo(Function &F, coro::Shape &Shape, - ArrayRef<Function *> Fns) { - // This only works under the switch-lowering ABI because coro elision - // only works on the switch-lowering ABI. - assert(Shape.ABI == coro::ABI::Switch); - - SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end()); - assert(!Args.empty()); - Function *Part = *Fns.begin(); - Module *M = Part->getParent(); - auto *ArrTy = ArrayType::get(Part->getType(), Args.size()); - - auto *ConstVal = ConstantArray::get(ArrTy, Args); - auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true, - GlobalVariable::PrivateLinkage, ConstVal, - F.getName() + Twine(".resumers")); - - // Update coro.begin instruction to refer to this constant. - LLVMContext &C = F.getContext(); - auto *BC = ConstantExpr::getPointerCast(GV, PointerType::getUnqual(C)); - Shape.getSwitchCoroId()->setInfo(BC); -} - -// Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame. -static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn, - Function *DestroyFn, Function *CleanupFn) { - assert(Shape.ABI == coro::ABI::Switch); - - IRBuilder<> Builder(&*Shape.getInsertPtAfterFramePtr()); - - auto *ResumeAddr = Builder.CreateStructGEP( - Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Resume, - "resume.addr"); - Builder.CreateStore(ResumeFn, ResumeAddr); - - Value *DestroyOrCleanupFn = DestroyFn; - - CoroIdInst *CoroId = Shape.getSwitchCoroId(); - if (CoroAllocInst *CA = CoroId->getCoroAlloc()) { - // If there is a CoroAlloc and it returns false (meaning we elide the - // allocation, use CleanupFn instead of DestroyFn). - DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn); - } - - auto *DestroyAddr = Builder.CreateStructGEP( - Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Destroy, - "destroy.addr"); - Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr); -} - static void postSplitCleanup(Function &F) { removeUnreachableBlocks(F); @@ -1285,196 +1217,6 @@ static void postSplitCleanup(Function &F) { #endif } -// Assuming we arrived at the block NewBlock from Prev instruction, store -// PHI's incoming values in the ResolvedValues map. -static void -scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock, - DenseMap<Value *, Value *> &ResolvedValues) { - auto *PrevBB = Prev->getParent(); - for (PHINode &PN : NewBlock->phis()) { - auto V = PN.getIncomingValueForBlock(PrevBB); - // See if we already resolved it. - auto VI = ResolvedValues.find(V); - if (VI != ResolvedValues.end()) - V = VI->second; - // Remember the value. 
- ResolvedValues[&PN] = V; - } -} - -// Replace a sequence of branches leading to a ret, with a clone of a ret -// instruction. Suspend instruction represented by a switch, track the PHI -// values and select the correct case successor when possible. -static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) { - // There is nothing to simplify. - if (isa<ReturnInst>(InitialInst)) - return false; - - DenseMap<Value *, Value *> ResolvedValues; - assert(InitialInst->getModule()); - const DataLayout &DL = InitialInst->getModule()->getDataLayout(); - - auto GetFirstValidInstruction = [](Instruction *I) { - while (I) { - // BitCastInst wouldn't generate actual code so that we could skip it. - if (isa<BitCastInst>(I) || I->isDebugOrPseudoInst() || - I->isLifetimeStartOrEnd()) - I = I->getNextNode(); - else if (isInstructionTriviallyDead(I)) - // Duing we are in the middle of the transformation, we need to erase - // the dead instruction manually. - I = &*I->eraseFromParent(); - else - break; - } - return I; - }; - - auto TryResolveConstant = [&ResolvedValues](Value *V) { - auto It = ResolvedValues.find(V); - if (It != ResolvedValues.end()) - V = It->second; - return dyn_cast<ConstantInt>(V); - }; - - Instruction *I = InitialInst; - while (I->isTerminator() || isa<CmpInst>(I)) { - if (isa<ReturnInst>(I)) { - ReplaceInstWithInst(InitialInst, I->clone()); - return true; - } - - if (auto *BR = dyn_cast<BranchInst>(I)) { - unsigned SuccIndex = 0; - if (BR->isConditional()) { - // Handle the case the condition of the conditional branch is constant. - // e.g., - // - // br i1 false, label %cleanup, label %CoroEnd - // - // It is possible during the transformation. We could continue the - // simplifying in this case. - ConstantInt *Cond = TryResolveConstant(BR->getCondition()); - if (!Cond) - return false; - - SuccIndex = Cond->isOne() ? 0 : 1; - } - - BasicBlock *Succ = BR->getSuccessor(SuccIndex); - scanPHIsAndUpdateValueMap(I, Succ, ResolvedValues); - I = GetFirstValidInstruction(Succ->getFirstNonPHIOrDbgOrLifetime()); - - continue; - } - - if (auto *CondCmp = dyn_cast<CmpInst>(I)) { - // If the case number of suspended switch instruction is reduced to - // 1, then it is simplified to CmpInst in llvm::ConstantFoldTerminator. - auto *BR = dyn_cast<BranchInst>( - GetFirstValidInstruction(CondCmp->getNextNode())); - if (!BR || !BR->isConditional() || CondCmp != BR->getCondition()) - return false; - - // And the comparsion looks like : %cond = icmp eq i8 %V, constant. - // So we try to resolve constant for the first operand only since the - // second operand should be literal constant by design. - ConstantInt *Cond0 = TryResolveConstant(CondCmp->getOperand(0)); - auto *Cond1 = dyn_cast<ConstantInt>(CondCmp->getOperand(1)); - if (!Cond0 || !Cond1) - return false; - - // Both operands of the CmpInst are Constant. So that we could evaluate - // it immediately to get the destination. - auto *ConstResult = - dyn_cast_or_null<ConstantInt>(ConstantFoldCompareInstOperands( - CondCmp->getPredicate(), Cond0, Cond1, DL)); - if (!ConstResult) - return false; - - ResolvedValues[BR->getCondition()] = ConstResult; - - // Handle this branch in next iteration. 
- I = BR; - continue; - } - - if (auto *SI = dyn_cast<SwitchInst>(I)) { - ConstantInt *Cond = TryResolveConstant(SI->getCondition()); - if (!Cond) - return false; - - BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor(); - scanPHIsAndUpdateValueMap(I, BB, ResolvedValues); - I = GetFirstValidInstruction(BB->getFirstNonPHIOrDbgOrLifetime()); - continue; - } - - return false; - } - - return false; -} - -// Check whether CI obeys the rules of musttail attribute. -static bool shouldBeMustTail(const CallInst &CI, const Function &F) { - if (CI.isInlineAsm()) - return false; - - // Match prototypes and calling conventions of resume function. - FunctionType *CalleeTy = CI.getFunctionType(); - if (!CalleeTy->getReturnType()->isVoidTy() || (CalleeTy->getNumParams() != 1)) - return false; - - Type *CalleeParmTy = CalleeTy->getParamType(0); - if (!CalleeParmTy->isPointerTy() || - (CalleeParmTy->getPointerAddressSpace() != 0)) - return false; - - if (CI.getCallingConv() != F.getCallingConv()) - return false; - - // CI should not has any ABI-impacting function attributes. - static const Attribute::AttrKind ABIAttrs[] = { - Attribute::StructRet, Attribute::ByVal, Attribute::InAlloca, - Attribute::Preallocated, Attribute::InReg, Attribute::Returned, - Attribute::SwiftSelf, Attribute::SwiftError}; - AttributeList Attrs = CI.getAttributes(); - for (auto AK : ABIAttrs) - if (Attrs.hasParamAttr(0, AK)) - return false; - - return true; -} - -// Add musttail to any resume instructions that is immediately followed by a -// suspend (i.e. ret). We do this even in -O0 to support guaranteed tail call -// for symmetrical coroutine control transfer (C++ Coroutines TS extension). -// This transformation is done only in the resume part of the coroutine that has -// identical signature and calling convention as the coro.resume call. -static void addMustTailToCoroResumes(Function &F, TargetTransformInfo &TTI) { - bool changed = false; - - // Collect potential resume instructions. - SmallVector<CallInst *, 4> Resumes; - for (auto &I : instructions(F)) - if (auto *Call = dyn_cast<CallInst>(&I)) - if (shouldBeMustTail(*Call, F)) - Resumes.push_back(Call); - - // Set musttail on those that are followed by a ret instruction. - for (CallInst *Call : Resumes) - // Skip targets which don't support tail call on the specific case. - if (TTI.supportsTailCallFor(Call) && - simplifyTerminatorLeadingToRet(Call->getNextNode())) { - Call->setTailCallKind(CallInst::TCK_MustTail); - changed = true; - } - - if (changed) - removeUnreachableBlocks(F); -} - // Coroutine has no suspend points. Remove heap allocation for the coroutine // frame if possible. static void handleNoSuspendCoroutine(coro::Shape &Shape) { @@ -1506,6 +1248,7 @@ static void handleNoSuspendCoroutine(coro::Shape &Shape) { } CoroBegin->eraseFromParent(); + Shape.CoroBegin = nullptr; } // SimplifySuspendPoint needs to check that there is no calls between @@ -1617,7 +1360,7 @@ static bool simplifySuspendPoint(CoroSuspendInst *Suspend, // No longer need a call to coro.resume or coro.destroy. if (auto *Invoke = dyn_cast<InvokeInst>(CB)) { - BranchInst::Create(Invoke->getNormalDest(), Invoke); + BranchInst::Create(Invoke->getNormalDest(), Invoke->getIterator()); } // Grab the CalledValue from CB before erasing the CallInstr. 
@@ -1678,44 +1421,209 @@ static void simplifySuspendPoints(coro::Shape &Shape) { } } -static void splitSwitchCoroutine(Function &F, coro::Shape &Shape, - SmallVectorImpl<Function *> &Clones, - TargetTransformInfo &TTI) { - assert(Shape.ABI == coro::ABI::Switch); +namespace { - createResumeEntryBlock(F, Shape); - auto ResumeClone = createClone(F, ".resume", Shape, - CoroCloner::Kind::SwitchResume); - auto DestroyClone = createClone(F, ".destroy", Shape, - CoroCloner::Kind::SwitchUnwind); - auto CleanupClone = createClone(F, ".cleanup", Shape, - CoroCloner::Kind::SwitchCleanup); +struct SwitchCoroutineSplitter { + static void split(Function &F, coro::Shape &Shape, + SmallVectorImpl<Function *> &Clones, + TargetTransformInfo &TTI) { + assert(Shape.ABI == coro::ABI::Switch); - postSplitCleanup(*ResumeClone); - postSplitCleanup(*DestroyClone); - postSplitCleanup(*CleanupClone); + createResumeEntryBlock(F, Shape); + auto *ResumeClone = + createClone(F, ".resume", Shape, CoroCloner::Kind::SwitchResume, TTI); + auto *DestroyClone = + createClone(F, ".destroy", Shape, CoroCloner::Kind::SwitchUnwind, TTI); + auto *CleanupClone = + createClone(F, ".cleanup", Shape, CoroCloner::Kind::SwitchCleanup, TTI); - // Adding musttail call to support symmetric transfer. - // Skip targets which don't support tail call. - // - // FIXME: Could we support symmetric transfer effectively without musttail - // call? - if (TTI.supportsTailCalls()) - addMustTailToCoroResumes(*ResumeClone, TTI); + postSplitCleanup(*ResumeClone); + postSplitCleanup(*DestroyClone); + postSplitCleanup(*CleanupClone); - // Store addresses resume/destroy/cleanup functions in the coroutine frame. - updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone); + // Store addresses resume/destroy/cleanup functions in the coroutine frame. + updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone); - assert(Clones.empty()); - Clones.push_back(ResumeClone); - Clones.push_back(DestroyClone); - Clones.push_back(CleanupClone); - - // Create a constant array referring to resume/destroy/clone functions pointed - // by the last argument of @llvm.coro.info, so that CoroElide pass can - // determined correct function to call. - setCoroInfo(F, Shape, Clones); -} + assert(Clones.empty()); + Clones.push_back(ResumeClone); + Clones.push_back(DestroyClone); + Clones.push_back(CleanupClone); + + // Create a constant array referring to resume/destroy/clone functions + // pointed by the last argument of @llvm.coro.info, so that CoroElide pass + // can determined correct function to call. + setCoroInfo(F, Shape, Clones); + } + +private: + // Create a resume clone by cloning the body of the original function, setting + // new entry block and replacing coro.suspend an appropriate value to force + // resume or cleanup pass for every suspend point. + static Function *createClone(Function &F, const Twine &Suffix, + coro::Shape &Shape, CoroCloner::Kind FKind, + TargetTransformInfo &TTI) { + CoroCloner Cloner(F, Suffix, Shape, FKind, TTI); + Cloner.create(); + return Cloner.getFunction(); + } + + // Create an entry block for a resume function with a switch that will jump to + // suspend points. + static void createResumeEntryBlock(Function &F, coro::Shape &Shape) { + LLVMContext &C = F.getContext(); + + // resume.entry: + // %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 + // 0, i32 2 % index = load i32, i32* %index.addr switch i32 %index, label + // %unreachable [ + // i32 0, label %resume.0 + // i32 1, label %resume.1 + // ... 
+ // ] + + auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F); + auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F); + + IRBuilder<> Builder(NewEntry); + auto *FramePtr = Shape.FramePtr; + auto *FrameTy = Shape.FrameTy; + auto *GepIndex = Builder.CreateStructGEP( + FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr"); + auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index"); + auto *Switch = + Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size()); + Shape.SwitchLowering.ResumeSwitch = Switch; + + size_t SuspendIndex = 0; + for (auto *AnyS : Shape.CoroSuspends) { + auto *S = cast<CoroSuspendInst>(AnyS); + ConstantInt *IndexVal = Shape.getIndex(SuspendIndex); + + // Replace CoroSave with a store to Index: + // %index.addr = getelementptr %f.frame... (index field number) + // store i32 %IndexVal, i32* %index.addr1 + auto *Save = S->getCoroSave(); + Builder.SetInsertPoint(Save); + if (S->isFinal()) { + // The coroutine should be marked done if it reaches the final suspend + // point. + markCoroutineAsDone(Builder, Shape, FramePtr); + } else { + auto *GepIndex = Builder.CreateStructGEP( + FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr"); + Builder.CreateStore(IndexVal, GepIndex); + } + + Save->replaceAllUsesWith(ConstantTokenNone::get(C)); + Save->eraseFromParent(); + + // Split block before and after coro.suspend and add a jump from an entry + // switch: + // + // whateverBB: + // whatever + // %0 = call i8 @llvm.coro.suspend(token none, i1 false) + // switch i8 %0, label %suspend[i8 0, label %resume + // i8 1, label %cleanup] + // becomes: + // + // whateverBB: + // whatever + // br label %resume.0.landing + // + // resume.0: ; <--- jump from the switch in the resume.entry + // %0 = tail call i8 @llvm.coro.suspend(token none, i1 false) + // br label %resume.0.landing + // + // resume.0.landing: + // %1 = phi i8[-1, %whateverBB], [%0, %resume.0] + // switch i8 % 1, label %suspend [i8 0, label %resume + // i8 1, label %cleanup] + + auto *SuspendBB = S->getParent(); + auto *ResumeBB = + SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex)); + auto *LandingBB = ResumeBB->splitBasicBlock( + S->getNextNode(), ResumeBB->getName() + Twine(".landing")); + Switch->addCase(IndexVal, ResumeBB); + + cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB); + auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, ""); + PN->insertBefore(LandingBB->begin()); + S->replaceAllUsesWith(PN); + PN->addIncoming(Builder.getInt8(-1), SuspendBB); + PN->addIncoming(S, ResumeBB); + + ++SuspendIndex; + } + + Builder.SetInsertPoint(UnreachBB); + Builder.CreateUnreachable(); + + Shape.SwitchLowering.ResumeEntryBlock = NewEntry; + } + + // Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame. + static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn, + Function *DestroyFn, Function *CleanupFn) { + IRBuilder<> Builder(&*Shape.getInsertPtAfterFramePtr()); + + auto *ResumeAddr = Builder.CreateStructGEP( + Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Resume, + "resume.addr"); + Builder.CreateStore(ResumeFn, ResumeAddr); + + Value *DestroyOrCleanupFn = DestroyFn; + + CoroIdInst *CoroId = Shape.getSwitchCoroId(); + if (CoroAllocInst *CA = CoroId->getCoroAlloc()) { + // If there is a CoroAlloc and it returns false (meaning we elide the + // allocation, use CleanupFn instead of DestroyFn). 
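updateCoroFrame below stores the resume pointer, and either the destroy or the cleanup clone, into fixed slots of the coroutine frame, while the index written at each coro.save drives the resume.entry switch built above. A rough, illustrative picture of the switch-lowered frame (field order and index width are simplified; the real layout is computed per coroutine):

#include <cstdint>

struct Frame {
  void (*Resume)(Frame *);  // written by updateCoroFrame
  void (*Destroy)(Frame *); // destroy clone, or the cleanup clone when the
                            // frame allocation was elided
  uint32_t Index;           // selects the resume.N case in resume.entry
  // The promise object and every value that lives across a suspend point are
  // spilled into the frame as well.
};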
+ DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn); + } + + auto *DestroyAddr = Builder.CreateStructGEP( + Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Destroy, + "destroy.addr"); + Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr); + } + + // Create a global constant array containing pointers to functions provided + // and set Info parameter of CoroBegin to point at this constant. Example: + // + // @f.resumers = internal constant [2 x void(%f.frame*)*] + // [void(%f.frame*)* @f.resume, void(%f.frame*)* + // @f.destroy] + // define void @f() { + // ... + // call i8* @llvm.coro.begin(i8* null, i32 0, i8* null, + // i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to + // i8*)) + // + // Assumes that all the functions have the same signature. + static void setCoroInfo(Function &F, coro::Shape &Shape, + ArrayRef<Function *> Fns) { + // This only works under the switch-lowering ABI because coro elision + // only works on the switch-lowering ABI. + SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end()); + assert(!Args.empty()); + Function *Part = *Fns.begin(); + Module *M = Part->getParent(); + auto *ArrTy = ArrayType::get(Part->getType(), Args.size()); + + auto *ConstVal = ConstantArray::get(ArrTy, Args); + auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true, + GlobalVariable::PrivateLinkage, ConstVal, + F.getName() + Twine(".resumers")); + + // Update coro.begin instruction to refer to this constant. + LLVMContext &C = F.getContext(); + auto *BC = ConstantExpr::getPointerCast(GV, PointerType::getUnqual(C)); + Shape.getSwitchCoroId()->setInfo(BC); + } +}; + +} // namespace static void replaceAsyncResumeFunction(CoroSuspendAsyncInst *Suspend, Value *Continuation) { @@ -1748,6 +1656,7 @@ static void coerceArguments(IRBuilder<> &Builder, FunctionType *FnTy, } CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn, + TargetTransformInfo &TTI, ArrayRef<Value *> Arguments, IRBuilder<> &Builder) { auto *FnTy = MustTailCallFn->getFunctionType(); @@ -1757,14 +1666,18 @@ CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn, coerceArguments(Builder, FnTy, Arguments, CallArgs); auto *TailCall = Builder.CreateCall(FnTy, MustTailCallFn, CallArgs); - TailCall->setTailCallKind(CallInst::TCK_MustTail); + // Skip targets which don't support tail call. 
+ if (TTI.supportsTailCallFor(TailCall)) { + TailCall->setTailCallKind(CallInst::TCK_MustTail); + } TailCall->setDebugLoc(Loc); TailCall->setCallingConv(MustTailCallFn->getCallingConv()); return TailCall; } static void splitAsyncCoroutine(Function &F, coro::Shape &Shape, - SmallVectorImpl<Function *> &Clones) { + SmallVectorImpl<Function *> &Clones, + TargetTransformInfo &TTI) { assert(Shape.ABI == coro::ABI::Async); assert(Clones.empty()); // Reset various things that the optimizer might have decided it @@ -1806,11 +1719,10 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape, auto ProjectionFunctionName = Suspend->getAsyncContextProjectionFunction()->getName(); bool UseSwiftMangling = false; - if (ProjectionFunctionName.equals("__swift_async_resume_project_context")) { + if (ProjectionFunctionName == "__swift_async_resume_project_context") { ResumeNameSuffix = "TQ"; UseSwiftMangling = true; - } else if (ProjectionFunctionName.equals( - "__swift_async_resume_get_context")) { + } else if (ProjectionFunctionName == "__swift_async_resume_get_context") { ResumeNameSuffix = "TY"; UseSwiftMangling = true; } @@ -1839,13 +1751,11 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape, SmallVector<Value *, 8> Args(Suspend->args()); auto FnArgs = ArrayRef<Value *>(Args).drop_front( CoroSuspendAsyncInst::MustTailCallFuncArg + 1); - auto *TailCall = - coro::createMustTailCall(Suspend->getDebugLoc(), Fn, FnArgs, Builder); + auto *TailCall = coro::createMustTailCall(Suspend->getDebugLoc(), Fn, TTI, + FnArgs, Builder); Builder.CreateRetVoid(); InlineFunctionInfo FnInfo; - auto InlineRes = InlineFunction(*TailCall, FnInfo); - assert(InlineRes.isSuccess() && "Expected inlining to succeed"); - (void)InlineRes; + (void)InlineFunction(*TailCall, FnInfo); // Replace the lvm.coro.async.resume intrisic call. replaceAsyncResumeFunction(Suspend, Continuation); @@ -1856,14 +1766,14 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape, auto *Suspend = Shape.CoroSuspends[Idx]; auto *Clone = Clones[Idx]; - CoroCloner(F, "resume." + Twine(Idx), Shape, Clone, Suspend).create(); + CoroCloner(F, "resume." + Twine(Idx), Shape, Clone, Suspend, TTI).create(); } } static void splitRetconCoroutine(Function &F, coro::Shape &Shape, - SmallVectorImpl<Function *> &Clones) { - assert(Shape.ABI == coro::ABI::Retcon || - Shape.ABI == coro::ABI::RetconOnce); + SmallVectorImpl<Function *> &Clones, + TargetTransformInfo &TTI) { + assert(Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce); assert(Clones.empty()); // Reset various things that the optimizer might have decided it @@ -1881,7 +1791,7 @@ static void splitRetconCoroutine(Function &F, coro::Shape &Shape, IRBuilder<> Builder(Id); // Determine the size of the frame. - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); auto Size = DL.getTypeAllocSize(Shape.FrameTy); // Allocate. We don't need to update the call graph node because we're @@ -1889,7 +1799,7 @@ static void splitRetconCoroutine(Function &F, coro::Shape &Shape, // FIXME: pass the required alignment RawFramePtr = Shape.emitAlloc(Builder, Builder.getInt64(Size), nullptr); RawFramePtr = - Builder.CreateBitCast(RawFramePtr, Shape.CoroBegin->getType()); + Builder.CreateBitCast(RawFramePtr, Shape.CoroBegin->getType()); // Stash the allocated frame pointer in the continuation storage. 
Builder.CreateStore(RawFramePtr, Id->getStorage()); @@ -1929,8 +1839,8 @@ static void splitRetconCoroutine(Function &F, coro::Shape &Shape, // Create the unified return block. if (!ReturnBB) { // Place it before the first suspend. - ReturnBB = BasicBlock::Create(F.getContext(), "coro.return", &F, - NewSuspendBB); + ReturnBB = + BasicBlock::Create(F.getContext(), "coro.return", &F, NewSuspendBB); Shape.RetconLowering.ReturnBlock = ReturnBB; IRBuilder<> Builder(ReturnBB); @@ -1944,8 +1854,8 @@ static void splitRetconCoroutine(Function &F, coro::Shape &Shape, // Next, all the directly-yielded values. for (auto *ResultTy : Shape.getRetconResultTypes()) - ReturnPHIs.push_back(Builder.CreatePHI(ResultTy, - Shape.CoroSuspends.size())); + ReturnPHIs.push_back( + Builder.CreatePHI(ResultTy, Shape.CoroSuspends.size())); // Build the return value. auto RetTy = F.getReturnType(); @@ -1954,9 +1864,9 @@ static void splitRetconCoroutine(Function &F, coro::Shape &Shape, // We can't rely on the types matching up because that type would // have to be infinite. auto CastedContinuationTy = - (ReturnPHIs.size() == 1 ? RetTy : RetTy->getStructElementType(0)); + (ReturnPHIs.size() == 1 ? RetTy : RetTy->getStructElementType(0)); auto *CastedContinuation = - Builder.CreateBitCast(ReturnPHIs[0], CastedContinuationTy); + Builder.CreateBitCast(ReturnPHIs[0], CastedContinuationTy); Value *RetV; if (ReturnPHIs.size() == 1) { @@ -1985,22 +1895,23 @@ static void splitRetconCoroutine(Function &F, coro::Shape &Shape, auto Suspend = Shape.CoroSuspends[i]; auto Clone = Clones[i]; - CoroCloner(F, "resume." + Twine(i), Shape, Clone, Suspend).create(); + CoroCloner(F, "resume." + Twine(i), Shape, Clone, Suspend, TTI).create(); } } namespace { - class PrettyStackTraceFunction : public PrettyStackTraceEntry { - Function &F; - public: - PrettyStackTraceFunction(Function &F) : F(F) {} - void print(raw_ostream &OS) const override { - OS << "While splitting coroutine "; - F.printAsOperand(OS, /*print type*/ false, F.getParent()); - OS << "\n"; - } - }; -} +class PrettyStackTraceFunction : public PrettyStackTraceEntry { + Function &F; + +public: + PrettyStackTraceFunction(Function &F) : F(F) {} + void print(raw_ostream &OS) const override { + OS << "While splitting coroutine "; + F.printAsOperand(OS, /*print type*/ false, F.getParent()); + OS << "\n"; + } +}; +} // namespace static coro::Shape splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones, @@ -2016,8 +1927,10 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones, if (!Shape.CoroBegin) return Shape; + lowerAwaitSuspends(F, Shape); + simplifySuspendPoints(Shape); - buildCoroutineFrame(F, Shape, MaterializableCallback); + buildCoroutineFrame(F, Shape, TTI, MaterializableCallback); replaceFrameSizeAndAlignment(Shape); // If there are no suspend points, no split required, just remove @@ -2027,14 +1940,14 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones, } else { switch (Shape.ABI) { case coro::ABI::Switch: - splitSwitchCoroutine(F, Shape, Clones, TTI); + SwitchCoroutineSplitter::split(F, Shape, Clones, TTI); break; case coro::ABI::Async: - splitAsyncCoroutine(F, Shape, Clones); + splitAsyncCoroutine(F, Shape, Clones, TTI); break; case coro::ABI::Retcon: case coro::ABI::RetconOnce: - splitRetconCoroutine(F, Shape, Clones); + splitRetconCoroutine(F, Shape, Clones, TTI); break; } } @@ -2047,20 +1960,28 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones, // original function. 
The Cloner has already salvaged debug info in the new // coroutine funclets. SmallDenseMap<Argument *, AllocaInst *, 4> ArgToAllocaMap; - auto [DbgInsts, DPValues] = collectDbgVariableIntrinsics(F); + auto [DbgInsts, DbgVariableRecords] = collectDbgVariableIntrinsics(F); for (auto *DDI : DbgInsts) coro::salvageDebugInfo(ArgToAllocaMap, *DDI, Shape.OptimizeFrame, false /*UseEntryValue*/); - for (DPValue *DPV : DPValues) - coro::salvageDebugInfo(ArgToAllocaMap, *DPV, Shape.OptimizeFrame, + for (DbgVariableRecord *DVR : DbgVariableRecords) + coro::salvageDebugInfo(ArgToAllocaMap, *DVR, Shape.OptimizeFrame, false /*UseEntryValue*/); return Shape; } /// Remove calls to llvm.coro.end in the original function. -static void removeCoroEnds(const coro::Shape &Shape) { - for (auto *End : Shape.CoroEnds) { - replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, nullptr); +static void removeCoroEndsFromRampFunction(const coro::Shape &Shape) { + if (Shape.ABI != coro::ABI::Switch) { + for (auto *End : Shape.CoroEnds) { + replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, nullptr); + } + } else { + for (llvm::AnyCoroEndInst *End : Shape.CoroEnds) { + auto &Context = End->getContext(); + End->replaceAllUsesWith(ConstantInt::getFalse(Context)); + End->eraseFromParent(); + } } } @@ -2069,18 +1990,6 @@ static void updateCallGraphAfterCoroutineSplit( const SmallVectorImpl<Function *> &Clones, LazyCallGraph::SCC &C, LazyCallGraph &CG, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR, FunctionAnalysisManager &FAM) { - if (!Shape.CoroBegin) - return; - - if (Shape.ABI != coro::ABI::Switch) - removeCoroEnds(Shape); - else { - for (llvm::AnyCoroEndInst *End : Shape.CoroEnds) { - auto &Context = End->getContext(); - End->replaceAllUsesWith(ConstantInt::getFalse(Context)); - End->eraseFromParent(); - } - } if (!Clones.empty()) { switch (Shape.ABI) { @@ -2196,12 +2105,6 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, if (Coroutines.empty() && PrepareFns.empty()) return PreservedAnalyses::all(); - if (Coroutines.empty()) { - for (auto *PrepareFn : PrepareFns) { - replaceAllPrepares(PrepareFn, CG, C); - } - } - // Split all the coroutines. 
for (LazyCallGraph::Node *N : Coroutines) { Function &F = N->getFunction(); @@ -2214,6 +2117,7 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, const coro::Shape Shape = splitCoroutine(F, Clones, FAM.getResult<TargetIRAnalysis>(F), OptimizeFrame, MaterializableCallback); + removeCoroEndsFromRampFunction(Shape); updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM); ORE.emit([&]() { @@ -2231,11 +2135,9 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, } } - if (!PrepareFns.empty()) { for (auto *PrepareFn : PrepareFns) { replaceAllPrepares(PrepareFn, CG, C); } - } return PreservedAnalyses::none(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp index eef5543bae24..1a92bc163625 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -47,15 +47,15 @@ coro::LowererBase::LowererBase(Module &M) // // call ptr @llvm.coro.subfn.addr(ptr %Arg, i8 %index) -Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index, - Instruction *InsertPt) { +CallInst *coro::LowererBase::makeSubFnCall(Value *Arg, int Index, + Instruction *InsertPt) { auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index); auto *Fn = Intrinsic::getDeclaration(&TheModule, Intrinsic::coro_subfn_addr); assert(Index >= CoroSubFnInst::IndexFirst && Index < CoroSubFnInst::IndexLast && "makeSubFnCall: Index value out of range"); - return CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt); + return CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt->getIterator()); } // NOTE: Must be sorted! @@ -67,6 +67,9 @@ static const char *const CoroIntrinsics[] = { "llvm.coro.async.resume", "llvm.coro.async.size.replace", "llvm.coro.async.store_resume", + "llvm.coro.await.suspend.bool", + "llvm.coro.await.suspend.handle", + "llvm.coro.await.suspend.void", "llvm.coro.begin", "llvm.coro.destroy", "llvm.coro.done", @@ -157,8 +160,8 @@ static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin, CoroSuspendInst *SuspendInst) { Module *M = SuspendInst->getModule(); auto *Fn = Intrinsic::getDeclaration(M, Intrinsic::coro_save); - auto *SaveInst = - cast<CoroSaveInst>(CallInst::Create(Fn, CoroBegin, "", SuspendInst)); + auto *SaveInst = cast<CoroSaveInst>( + CallInst::Create(Fn, CoroBegin, "", SuspendInst->getIterator())); assert(!SuspendInst->getCoroSave()); SuspendInst->setArgOperand(0, SaveInst); return SaveInst; @@ -174,7 +177,11 @@ void coro::Shape::buildFrom(Function &F) { SmallVector<CoroSaveInst *, 2> UnusedCoroSaves; for (Instruction &I : instructions(F)) { - if (auto II = dyn_cast<IntrinsicInst>(&I)) { + // FIXME: coro_await_suspend_* are not proper `IntrinisicInst`s + // because they might be invoked + if (auto AWS = dyn_cast<CoroAwaitSuspendInst>(&I)) { + CoroAwaitSuspends.push_back(AWS); + } else if (auto II = dyn_cast<IntrinsicInst>(&I)) { switch (II->getIntrinsicID()) { default: continue; @@ -362,7 +369,7 @@ void coro::Shape::buildFrom(Function &F) { // calls, but that messes with our invariants. Re-insert the // bitcast and ignore this type mismatch. 
if (CastInst::isBitCastable(SrcTy, *RI)) { - auto BCI = new BitCastInst(*SI, *RI, "", Suspend); + auto BCI = new BitCastInst(*SI, *RI, "", Suspend->getIterator()); SI->set(BCI); continue; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp b/contrib/llvm-project/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp index fb7cba9edbdb..1a8096f647d8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp @@ -133,6 +133,7 @@ static inline void maybeHandleGlobals(Module &M) { continue; G.setLinkage(GlobalVariable::ExternalWeakLinkage); + G.setInitializer(nullptr); G.setExternallyInitialized(true); } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp index cc375f9badcd..1f787c733079 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/AlwaysInliner.cpp @@ -15,12 +15,12 @@ #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/InlineAdvisor.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" -#include "llvm/Transforms/IPO/Inliner.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -37,86 +37,73 @@ bool AlwaysInlineImpl( function_ref<BlockFrequencyInfo &(Function &)> GetBFI) { SmallSetVector<CallBase *, 16> Calls; bool Changed = false; - SmallVector<Function *, 16> InlinedFunctions; - for (Function &F : M) { - // When callee coroutine function is inlined into caller coroutine function - // before coro-split pass, - // coro-early pass can not handle this quiet well. - // So we won't inline the coroutine function if it have not been unsplited + SmallVector<Function *, 16> InlinedComdatFunctions; + + for (Function &F : make_early_inc_range(M)) { if (F.isPresplitCoroutine()) continue; - if (!F.isDeclaration() && isInlineViable(F).isSuccess()) { - Calls.clear(); - - for (User *U : F.users()) - if (auto *CB = dyn_cast<CallBase>(U)) - if (CB->getCalledFunction() == &F && - CB->hasFnAttr(Attribute::AlwaysInline) && - !CB->getAttributes().hasFnAttr(Attribute::NoInline)) - Calls.insert(CB); - - for (CallBase *CB : Calls) { - Function *Caller = CB->getCaller(); - OptimizationRemarkEmitter ORE(Caller); - DebugLoc DLoc = CB->getDebugLoc(); - BasicBlock *Block = CB->getParent(); - - InlineFunctionInfo IFI(GetAssumptionCache, &PSI, - GetBFI ? &GetBFI(*Caller) : nullptr, - GetBFI ? 
&GetBFI(F) : nullptr); - - InlineResult Res = InlineFunction(*CB, IFI, /*MergeAttributes=*/true, - &GetAAR(F), InsertLifetime); - if (!Res.isSuccess()) { - ORE.emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, - Block) - << "'" << ore::NV("Callee", &F) << "' is not inlined into '" - << ore::NV("Caller", Caller) - << "': " << ore::NV("Reason", Res.getFailureReason()); - }); - continue; - } - - emitInlinedIntoBasedOnCost( - ORE, DLoc, Block, F, *Caller, - InlineCost::getAlways("always inline attribute"), - /*ForProfileContext=*/false, DEBUG_TYPE); + if (F.isDeclaration() || !isInlineViable(F).isSuccess()) + continue; - Changed = true; + Calls.clear(); + + for (User *U : F.users()) + if (auto *CB = dyn_cast<CallBase>(U)) + if (CB->getCalledFunction() == &F && + CB->hasFnAttr(Attribute::AlwaysInline) && + !CB->getAttributes().hasFnAttr(Attribute::NoInline)) + Calls.insert(CB); + + for (CallBase *CB : Calls) { + Function *Caller = CB->getCaller(); + OptimizationRemarkEmitter ORE(Caller); + DebugLoc DLoc = CB->getDebugLoc(); + BasicBlock *Block = CB->getParent(); + + InlineFunctionInfo IFI(GetAssumptionCache, &PSI, + GetBFI ? &GetBFI(*Caller) : nullptr, + GetBFI ? &GetBFI(F) : nullptr); + + InlineResult Res = InlineFunction(*CB, IFI, /*MergeAttributes=*/true, + &GetAAR(F), InsertLifetime); + if (!Res.isSuccess()) { + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, Block) + << "'" << ore::NV("Callee", &F) << "' is not inlined into '" + << ore::NV("Caller", Caller) + << "': " << ore::NV("Reason", Res.getFailureReason()); + }); + continue; } - if (F.hasFnAttribute(Attribute::AlwaysInline)) { - // Remember to try and delete this function afterward. This both avoids - // re-walking the rest of the module and avoids dealing with any - // iterator invalidation issues while deleting functions. - InlinedFunctions.push_back(&F); - } + emitInlinedIntoBasedOnCost( + ORE, DLoc, Block, F, *Caller, + InlineCost::getAlways("always inline attribute"), + /*ForProfileContext=*/false, DEBUG_TYPE); + + Changed = true; } - } - // Remove any live functions. - erase_if(InlinedFunctions, [&](Function *F) { - F->removeDeadConstantUsers(); - return !F->isDefTriviallyDead(); - }); - - // Delete the non-comdat ones from the module and also from our vector. - auto NonComdatBegin = partition( - InlinedFunctions, [&](Function *F) { return F->hasComdat(); }); - for (Function *F : make_range(NonComdatBegin, InlinedFunctions.end())) { - M.getFunctionList().erase(F); - Changed = true; + F.removeDeadConstantUsers(); + if (F.hasFnAttribute(Attribute::AlwaysInline) && F.isDefTriviallyDead()) { + // Remember to try and delete this function afterward. This allows to call + // filterDeadComdatFunctions() only once. + if (F.hasComdat()) { + InlinedComdatFunctions.push_back(&F); + } else { + M.getFunctionList().erase(F); + Changed = true; + } + } } - InlinedFunctions.erase(NonComdatBegin, InlinedFunctions.end()); - if (!InlinedFunctions.empty()) { + if (!InlinedComdatFunctions.empty()) { // Now we just have the comdat functions. Filter out the ones whose comdats // are not actually dead. - filterDeadComdatFunctions(InlinedFunctions); + filterDeadComdatFunctions(InlinedComdatFunctions); // The remaining functions are actually dead. 
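The rewritten AlwaysInliner iterates the module with make_early_inc_range so the function just visited can be deleted without invalidating the walk. A hedged sketch of the idiom, with shouldDrop standing in for whatever predicate decides deletion:

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"

static void dropSelectedFunctions(llvm::Module &M,
                                  bool (*shouldDrop)(const llvm::Function &)) {
  // make_early_inc_range advances the underlying iterator before the loop
  // body runs, so erasing F inside the loop is safe.
  for (llvm::Function &F : llvm::make_early_inc_range(M)) {
    if (shouldDrop(F))
      F.eraseFromParent();
  }
}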
- for (Function *F : InlinedFunctions) { + for (Function *F : InlinedComdatFunctions) { M.getFunctionList().erase(F); Changed = true; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp index 062a3d341007..99ec50aa4775 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -58,6 +58,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" #include "llvm/IR/NoFolder.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" @@ -203,7 +204,7 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM, // Loop over all the callers of the function, transforming the call sites to // pass in the loaded pointers. SmallVector<Value *, 16> Args; - const DataLayout &DL = F->getParent()->getDataLayout(); + const DataLayout &DL = F->getDataLayout(); SmallVector<WeakTrackingVH, 16> DeadArgs; while (!F->use_empty()) { @@ -266,9 +267,10 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM, CallBase *NewCS = nullptr; if (InvokeInst *II = dyn_cast<InvokeInst>(&CB)) { NewCS = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), - Args, OpBundles, "", &CB); + Args, OpBundles, "", CB.getIterator()); } else { - auto *NewCall = CallInst::Create(NF, Args, OpBundles, "", &CB); + auto *NewCall = + CallInst::Create(NF, Args, OpBundles, "", CB.getIterator()); NewCall->setTailCallKind(cast<CallInst>(&CB)->getTailCallKind()); NewCS = NewCall; } @@ -421,11 +423,11 @@ doPromotion(Function *F, FunctionAnalysisManager &FAM, /// Return true if we can prove that all callees pass in a valid pointer for the /// specified function argument. -static bool allCallersPassValidPointerForArgument(Argument *Arg, - Align NeededAlign, - uint64_t NeededDerefBytes) { +static bool allCallersPassValidPointerForArgument( + Argument *Arg, SmallPtrSetImpl<CallBase *> &RecursiveCalls, + Align NeededAlign, uint64_t NeededDerefBytes) { Function *Callee = Arg->getParent(); - const DataLayout &DL = Callee->getParent()->getDataLayout(); + const DataLayout &DL = Callee->getDataLayout(); APInt Bytes(64, NeededDerefBytes); // Check if the argument itself is marked dereferenceable and aligned. @@ -436,6 +438,33 @@ static bool allCallersPassValidPointerForArgument(Argument *Arg, // direct callees. return all_of(Callee->users(), [&](User *U) { CallBase &CB = cast<CallBase>(*U); + // In case of functions with recursive calls, this check + // (isDereferenceableAndAlignedPointer) will fail when it tries to look at + // the first caller of this function. The caller may or may not have a load, + // incase it doesn't load the pointer being passed, this check will fail. + // So, it's safe to skip the check incase we know that we are dealing with a + // recursive call. For example we have a IR given below. + // + // def fun(ptr %a) { + // ... + // %loadres = load i32, ptr %a, align 4 + // %res = call i32 @fun(ptr %a) + // ... + // } + // + // def bar(ptr %x) { + // ... + // %resbar = call i32 @fun(ptr %x) + // ... + // } + // + // Since we record processed recursive calls, we check if the current + // CallBase has been processed before. If yes it means that it is a + // recursive call and we can skip the check just for this call. So, just + // return true. 
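The pseudo-IR in the comment above corresponds to source roughly like the following (names are illustrative); the pointer is dereferenced inside the recursive function itself, so only the non-recursive callers need to be proven to pass a valid, aligned pointer:

int fun(int *a, int n) {
  int v = *a;               // the load that motivates promotion
  if (n <= 0)
    return v;
  return v + fun(a, n - 1); // recursive call forwarding the same pointer
}

int bar(int *x) {
  return fun(x, 4);         // external caller whose pointer must be valid
}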
+ if (RecursiveCalls.contains(&CB)) + return true; + return isDereferenceableAndAlignedPointer(CB.getArgOperand(Arg->getArgNo()), NeededAlign, Bytes, DL); }); @@ -569,6 +598,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR, SmallVector<const Use *, 16> Worklist; SmallPtrSet<const Use *, 16> Visited; SmallVector<LoadInst *, 16> Loads; + SmallPtrSet<CallBase *, 4> RecursiveCalls; auto AppendUses = [&](const Value *V) { for (const Use &U : V->uses()) if (Visited.insert(&U).second) @@ -609,6 +639,33 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR, // unknown users } + auto *CB = dyn_cast<CallBase>(V); + Value *PtrArg = U->get(); + if (CB && CB->getCalledFunction() == CB->getFunction()) { + if (PtrArg != Arg) { + LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: " + << "pointer offset is not equal to zero\n"); + return false; + } + + unsigned int ArgNo = Arg->getArgNo(); + if (U->getOperandNo() != ArgNo) { + LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: " + << "arg position is different in callee\n"); + return false; + } + + // We limit promotion to only promoting up to a fixed number of elements + // of the aggregate. + if (MaxElements > 0 && ArgParts.size() > MaxElements) { + LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: " + << "more than " << MaxElements << " parts\n"); + return false; + } + + RecursiveCalls.insert(CB); + continue; + } // Unknown user. LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: " << "unknown user " << *V << "\n"); @@ -617,7 +674,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR, if (NeededDerefBytes || NeededAlign > 1) { // Try to prove a required deref / aligned requirement. - if (!allCallersPassValidPointerForArgument(Arg, NeededAlign, + if (!allCallersPassValidPointerForArgument(Arg, RecursiveCalls, NeededAlign, NeededDerefBytes)) { LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: " << "not dereferenceable or aligned\n"); @@ -753,7 +810,7 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM, if (BB.getTerminatingMustTailCall()) return nullptr; - const DataLayout &DL = F->getParent()->getDataLayout(); + const DataLayout &DL = F->getDataLayout(); auto &AAR = FAM.getResult<AAManager>(*F); const auto &TTI = FAM.getResult<TargetIRAnalysis>(*F); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp index d8e290cbc8a4..910c0aeacc42 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/Attributor.cpp @@ -132,13 +132,13 @@ static cl::opt<bool> #ifndef NDEBUG static cl::list<std::string> SeedAllowList("attributor-seed-allow-list", cl::Hidden, - cl::desc("Comma seperated list of attribute names that are " + cl::desc("Comma separated list of attribute names that are " "allowed to be seeded."), cl::CommaSeparated); static cl::list<std::string> FunctionSeedAllowList( "attributor-function-seed-allow-list", cl::Hidden, - cl::desc("Comma seperated list of function names that are " + cl::desc("Comma separated list of function names that are " "allowed to be seeded."), cl::CommaSeparated); #endif @@ -275,7 +275,7 @@ AA::getInitialValueForObj(Attributor &A, const AbstractAttribute &QueryingAA, return ConstantFoldLoadFromConst(Initializer, &Ty, Offset, DL); } - return ConstantFoldLoadFromUniformValue(Initializer, &Ty); + return 
ConstantFoldLoadFromUniformValue(Initializer, &Ty, DL); } bool AA::isValidInScope(const Value &V, const Function *Scope) { @@ -1749,6 +1749,10 @@ bool Attributor::checkForAllCallees( return Pred(Callees.getArrayRef()); } +bool canMarkAsVisited(const User *Usr) { + return isa<PHINode>(Usr) || !isa<Instruction>(Usr); +} + bool Attributor::checkForAllUses( function_ref<bool(const Use &, bool &)> Pred, const AbstractAttribute &QueryingAA, const Value &V, @@ -1796,7 +1800,7 @@ bool Attributor::checkForAllUses( while (!Worklist.empty()) { const Use *U = Worklist.pop_back_val(); - if (isa<PHINode>(U->getUser()) && !Visited.insert(U).second) + if (canMarkAsVisited(U->getUser()) && !Visited.insert(U).second) continue; DEBUG_WITH_TYPE(VERBOSE_DEBUG_TYPE, { if (auto *Fn = dyn_cast<Function>(U->getUser())) @@ -2381,8 +2385,7 @@ void Attributor::identifyDeadInternalFunctions() { bool FoundLiveInternal = true; while (FoundLiveInternal) { FoundLiveInternal = false; - for (unsigned u = 0, e = InternalFns.size(); u < e; ++u) { - Function *F = InternalFns[u]; + for (Function *&F : InternalFns) { if (!F) continue; @@ -2399,13 +2402,13 @@ void Attributor::identifyDeadInternalFunctions() { } LiveInternalFns.insert(F); - InternalFns[u] = nullptr; + F = nullptr; FoundLiveInternal = true; } } - for (unsigned u = 0, e = InternalFns.size(); u < e; ++u) - if (Function *F = InternalFns[u]) + for (Function *F : InternalFns) + if (F) ToBeDeletedFunctions.insert(F); } @@ -2551,12 +2554,9 @@ ChangeStatus Attributor::cleanupIR() { for (const auto &V : ToBeDeletedInsts) { if (Instruction *I = dyn_cast_or_null<Instruction>(V)) { - if (auto *CB = dyn_cast<CallBase>(I)) { - assert((isa<IntrinsicInst>(CB) || isRunOn(*I->getFunction())) && - "Cannot delete an instruction outside the current SCC!"); - if (!isa<IntrinsicInst>(CB)) - Configuration.CGUpdater.removeCallSite(*CB); - } + assert((!isa<CallBase>(I) || isa<IntrinsicInst>(I) || + isRunOn(*I->getFunction())) && + "Cannot delete an instruction outside the current SCC!"); I->dropDroppableUses(); CGModifiedFunctions.insert(I->getFunction()); if (!I->getType()->isVoidTy()) @@ -2738,6 +2738,8 @@ void Attributor::createShallowWrapper(Function &F) { Function::Create(FnTy, F.getLinkage(), F.getAddressSpace(), F.getName()); F.setName(""); // set the inside function anonymous M.getFunctionList().insert(F.getIterator(), Wrapper); + // Flag whether the function is using new-debug-info or not. + Wrapper->IsNewDbgInfoFormat = M.IsNewDbgInfoFormat; F.setLinkage(GlobalValue::InternalLinkage); @@ -2818,6 +2820,8 @@ bool Attributor::internalizeFunctions(SmallPtrSetImpl<Function *> &FnSet, VMap[&Arg] = &(*NewFArgIt++); } SmallVector<ReturnInst *, 8> Returns; + // Flag whether the function is using new-debug-info or not. + Copied->IsNewDbgInfoFormat = F->IsNewDbgInfoFormat; // Copy the body of the original function to the new one CloneFunctionInto(Copied, F, VMap, @@ -3035,6 +3039,8 @@ ChangeStatus Attributor::rewriteFunctionSignatures( OldFn->getParent()->getFunctionList().insert(OldFn->getIterator(), NewFn); NewFn->takeName(OldFn); NewFn->copyAttributesFrom(OldFn); + // Flag whether the function is using new-debug-info or not. + NewFn->IsNewDbgInfoFormat = OldFn->IsNewDbgInfoFormat; // Patch the pointer to LLVM function in debug info descriptor. NewFn->setSubprogram(OldFn->getSubprogram()); @@ -3117,12 +3123,12 @@ ChangeStatus Attributor::rewriteFunctionSignatures( // Create a new call or invoke instruction to replace the old one. 
CallBase *NewCB; if (InvokeInst *II = dyn_cast<InvokeInst>(OldCB)) { - NewCB = - InvokeInst::Create(NewFn, II->getNormalDest(), II->getUnwindDest(), - NewArgOperands, OperandBundleDefs, "", OldCB); + NewCB = InvokeInst::Create(NewFn, II->getNormalDest(), + II->getUnwindDest(), NewArgOperands, + OperandBundleDefs, "", OldCB->getIterator()); } else { auto *NewCI = CallInst::Create(NewFn, NewArgOperands, OperandBundleDefs, - "", OldCB); + "", OldCB->getIterator()); NewCI->setTailCallKind(cast<CallInst>(OldCB)->getTailCallKind()); NewCB = NewCI; } @@ -3177,7 +3183,6 @@ ChangeStatus Attributor::rewriteFunctionSignatures( assert(OldCB.getType() == NewCB.getType() && "Cannot handle call sites with different types!"); ModifiedFns.insert(OldCB.getFunction()); - Configuration.CGUpdater.replaceCallSite(OldCB, NewCB); OldCB.replaceAllUsesWith(&NewCB); OldCB.eraseFromParent(); } @@ -3948,7 +3953,7 @@ static bool runAttributorLightOnFunctions(InformationCache &InfoCache, // We look at internal functions only on-demand but if any use is not a // direct call or outside the current set of analyzed functions, we have // to do it eagerly. - if (F->hasLocalLinkage()) { + if (AC.UseLiveness && F->hasLocalLinkage()) { if (llvm::all_of(F->uses(), [&Functions](const Use &U) { const auto *CB = dyn_cast<CallBase>(U.getUser()); return CB && CB->isCallee(&U) && diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp index 585364dd7aa2..2816a85743fa 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/AttributorAttributes.cpp @@ -419,7 +419,8 @@ struct AAReturnedFromReturnedValues : public BaseType { /// See AbstractAttribute::updateImpl(...). ChangeStatus updateImpl(Attributor &A) override { StateType S(StateType::getBestState(this->getState())); - clampReturnedValueStates<AAType, StateType, IRAttributeKind, RecurseForSelectAndPHI>( + clampReturnedValueStates<AAType, StateType, IRAttributeKind, + RecurseForSelectAndPHI>( A, *this, S, PropagateCallBaseContext ? this->getCallBaseContext() : nullptr); // TODO: If we know we visited all returned values, thus no are assumed @@ -1472,7 +1473,7 @@ struct AAPointerInfoFloating : public AAPointerInfoImpl { // Make a strictly ascending list of offsets as required by addAccess() llvm::sort(Offsets); - auto *Last = std::unique(Offsets.begin(), Offsets.end()); + auto *Last = llvm::unique(Offsets); Offsets.erase(Last, Offsets.end()); VectorType *VT = dyn_cast<VectorType>(&Ty); @@ -1607,22 +1608,19 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) { // // The RHS is a reference that may be invalidated by an insertion caused by // the LHS. So we ensure that the side-effect of the LHS happens first. + + assert(OffsetInfoMap.contains(CurPtr) && + "CurPtr does not exist in the map!"); + auto &UsrOI = OffsetInfoMap[Usr]; auto &PtrOI = OffsetInfoMap[CurPtr]; assert(!PtrOI.isUnassigned() && "Cannot pass through if the input Ptr was not visited!"); - UsrOI = PtrOI; + UsrOI.merge(PtrOI); Follow = true; return true; }; - const auto *F = getAnchorScope(); - const auto *CI = - F ? A.getInfoCache().getAnalysisResultForFunction<CycleAnalysis>(*F) - : nullptr; - const auto *TLI = - F ? 
A.getInfoCache().getTargetLibraryInfoForFunction(*F) : nullptr; - auto UsePred = [&](const Use &U, bool &Follow) -> bool { Value *CurPtr = U.get(); User *Usr = U.getUser(); @@ -1634,8 +1632,6 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) { if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Usr)) { if (CE->isCast()) return HandlePassthroughUser(Usr, CurPtr, Follow); - if (CE->isCompare()) - return true; if (!isa<GEPOperator>(CE)) { LLVM_DEBUG(dbgs() << "[AAPointerInfo] Unhandled constant user " << *CE << "\n"); @@ -1668,18 +1664,18 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) { // For PHIs we need to take care of the recurrence explicitly as the value // might change while we iterate through a loop. For now, we give up if // the PHI is not invariant. - if (isa<PHINode>(Usr)) { + if (auto *PHI = dyn_cast<PHINode>(Usr)) { // Note the order here, the Usr access might change the map, CurPtr is // already in it though. - bool IsFirstPHIUser = !OffsetInfoMap.count(Usr); - auto &UsrOI = OffsetInfoMap[Usr]; + bool IsFirstPHIUser = !OffsetInfoMap.count(PHI); + auto &UsrOI = OffsetInfoMap[PHI]; auto &PtrOI = OffsetInfoMap[CurPtr]; // Check if the PHI operand has already an unknown offset as we can't // improve on that anymore. if (PtrOI.isUnknown()) { LLVM_DEBUG(dbgs() << "[AAPointerInfo] PHI operand offset unknown " - << *CurPtr << " in " << *Usr << "\n"); + << *CurPtr << " in " << *PHI << "\n"); Follow = !UsrOI.isUnknown(); UsrOI.setUnknown(); return true; @@ -1702,7 +1698,8 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) { auto It = OffsetInfoMap.find(CurPtrBase); if (It == OffsetInfoMap.end()) { LLVM_DEBUG(dbgs() << "[AAPointerInfo] PHI operand is too complex " - << *CurPtr << " in " << *Usr << "\n"); + << *CurPtr << " in " << *PHI + << " (base: " << *CurPtrBase << ")\n"); UsrOI.setUnknown(); Follow = true; return true; @@ -1715,6 +1712,9 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) { // Cycles reported by CycleInfo. It is sufficient to check the PHIs in // every Cycle header; if such a node is marked unknown, this will // eventually propagate through the whole net of PHIs in the recurrence. 
+ const auto *CI = + A.getInfoCache().getAnalysisResultForFunction<CycleAnalysis>( + *PHI->getFunction()); if (mayBeInCycle(CI, cast<Instruction>(Usr), /* HeaderOnly */ true)) { auto BaseOI = It->getSecond(); BaseOI.addToAll(Offset.getZExtValue()); @@ -1726,7 +1726,7 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) { LLVM_DEBUG( dbgs() << "[AAPointerInfo] PHI operand pointer offset mismatch " - << *CurPtr << " in " << *Usr << "\n"); + << *CurPtr << " in " << *PHI << "\n"); UsrOI.setUnknown(); Follow = true; return true; @@ -1879,6 +1879,8 @@ ChangeStatus AAPointerInfoFloating::updateImpl(Attributor &A) { if (auto *CB = dyn_cast<CallBase>(Usr)) { if (CB->isLifetimeStartOrEnd()) return true; + const auto *TLI = + A.getInfoCache().getTargetLibraryInfoForFunction(*CB->getFunction()); if (getFreedOperand(CB, TLI) == U) return true; if (CB->isArgOperand(&U)) { @@ -2447,13 +2449,14 @@ bool AANonNull::isImpliedByIR(Attributor &A, const IRPosition &IRP, return true; }, IRP.getAssociatedFunction(), nullptr, {Instruction::Ret}, - UsedAssumedInformation)) + UsedAssumedInformation, false, /*CheckPotentiallyDead=*/true)) return false; } if (llvm::any_of(Worklist, [&](AA::ValueAndContext VAC) { - return !isKnownNonZero(VAC.getValue(), A.getDataLayout(), 0, AC, - VAC.getCtxI(), DT); + return !isKnownNonZero( + VAC.getValue(), + SimplifyQuery(A.getDataLayout(), DT, AC, VAC.getCtxI())); })) return false; @@ -5190,6 +5193,12 @@ static unsigned getKnownAlignForUse(Attributor &A, AAAlign &QueryingAA, } else if (auto *LI = dyn_cast<LoadInst>(I)) { if (LI->getPointerOperand() == UseV) MA = LI->getAlign(); + } else if (auto *AI = dyn_cast<AtomicRMWInst>(I)) { + if (AI->getPointerOperand() == UseV) + MA = AI->getAlign(); + } else if (auto *AI = dyn_cast<AtomicCmpXchgInst>(I)) { + if (AI->getPointerOperand() == UseV) + MA = AI->getAlign(); } if (!MA || *MA <= QueryingAA.getKnownAlign()) @@ -5683,6 +5692,9 @@ bool AANoCapture::isImpliedByIR(Attributor &A, const IRPosition &IRP, return V.use_empty(); // You cannot "capture" null in the default address space. + // + // FIXME: This should use NullPointerIsDefined to account for the function + // attribute. if (isa<UndefValue>(V) || (isa<ConstantPointerNull>(V) && V.getType()->getPointerAddressSpace() == 0)) { return true; @@ -5892,10 +5904,13 @@ ChangeStatus AANoCaptureImpl::updateImpl(Attributor &A) { const Function *F = isArgumentPosition() ? IRP.getAssociatedFunction() : IRP.getAnchorScope(); - assert(F && "Expected a function!"); - const IRPosition &FnPos = IRPosition::function(*F); + + // TODO: Is the checkForAllUses below useful for constants? + if (!F) + return indicatePessimisticFixpoint(); AANoCapture::StateType T; + const IRPosition &FnPos = IRPosition::function(*F); // Readonly means we cannot capture through memory. bool IsKnown; @@ -6128,8 +6143,8 @@ struct AAValueSimplifyImpl : AAValueSimplify { return TypedV; if (CtxI && V.getType()->canLosslesslyBitCastTo(&Ty)) return Check ? &V - : BitCastInst::CreatePointerBitCastOrAddrSpaceCast(&V, &Ty, - "", CtxI); + : BitCastInst::CreatePointerBitCastOrAddrSpaceCast( + &V, &Ty, "", CtxI->getIterator()); return nullptr; } @@ -6731,8 +6746,9 @@ struct AAHeapToStackFunction final : public AAHeapToStack { Size = SizeOffsetPair.Size; } - Instruction *IP = - AI.MoveAllocaIntoEntry ? &F->getEntryBlock().front() : AI.CB; + BasicBlock::iterator IP = AI.MoveAllocaIntoEntry + ? 
F->getEntryBlock().begin() + : AI.CB->getIterator(); Align Alignment(1); if (MaybeAlign RetAlign = AI.CB->getRetAlign()) @@ -6753,7 +6769,7 @@ struct AAHeapToStackFunction final : public AAHeapToStack { if (Alloca->getType() != AI.CB->getType()) Alloca = BitCastInst::CreatePointerBitCastOrAddrSpaceCast( - Alloca, AI.CB->getType(), "malloc_cast", AI.CB); + Alloca, AI.CB->getType(), "malloc_cast", AI.CB->getIterator()); auto *I8Ty = Type::getInt8Ty(F->getContext()); auto *InitVal = getInitialValueOfAllocation(AI.CB, TLI, I8Ty); @@ -6961,10 +6977,9 @@ ChangeStatus AAHeapToStackFunction::updateImpl(Attributor &A) { if (AI.LibraryFunctionId != LibFunc___kmpc_alloc_shared) { Instruction *CtxI = isa<InvokeInst>(AI.CB) ? AI.CB : AI.CB->getNextNode(); if (!Explorer || !Explorer->findInContextOf(UniqueFree, CtxI)) { - LLVM_DEBUG( - dbgs() - << "[H2S] unique free call might not be executed with the allocation " - << *UniqueFree << "\n"); + LLVM_DEBUG(dbgs() << "[H2S] unique free call might not be executed " + "with the allocation " + << *UniqueFree << "\n"); return false; } } @@ -7450,11 +7465,11 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { /// The values needed are taken from the arguments of \p F starting at /// position \p ArgNo. static void createInitialization(Type *PrivType, Value &Base, Function &F, - unsigned ArgNo, Instruction &IP) { + unsigned ArgNo, BasicBlock::iterator IP) { assert(PrivType && "Expected privatizable type!"); - IRBuilder<NoFolder> IRB(&IP); - const DataLayout &DL = F.getParent()->getDataLayout(); + IRBuilder<NoFolder> IRB(IP->getParent(), IP); + const DataLayout &DL = F.getDataLayout(); // Traverse the type, build GEPs and stores. if (auto *PrivStructType = dyn_cast<StructType>(PrivType)) { @@ -7462,17 +7477,17 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { for (unsigned u = 0, e = PrivStructType->getNumElements(); u < e; u++) { Value *Ptr = constructPointer(&Base, PrivStructLayout->getElementOffset(u), IRB); - new StoreInst(F.getArg(ArgNo + u), Ptr, &IP); + new StoreInst(F.getArg(ArgNo + u), Ptr, IP); } } else if (auto *PrivArrayType = dyn_cast<ArrayType>(PrivType)) { Type *PointeeTy = PrivArrayType->getElementType(); uint64_t PointeeTySize = DL.getTypeStoreSize(PointeeTy); for (unsigned u = 0, e = PrivArrayType->getNumElements(); u < e; u++) { Value *Ptr = constructPointer(&Base, u * PointeeTySize, IRB); - new StoreInst(F.getArg(ArgNo + u), Ptr, &IP); + new StoreInst(F.getArg(ArgNo + u), Ptr, IP); } } else { - new StoreInst(F.getArg(ArgNo), &Base, &IP); + new StoreInst(F.getArg(ArgNo), &Base, IP); } } @@ -7486,7 +7501,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { Instruction *IP = ACS.getInstruction(); IRBuilder<NoFolder> IRB(IP); - const DataLayout &DL = IP->getModule()->getDataLayout(); + const DataLayout &DL = IP->getDataLayout(); // Traverse the type, build GEPs and loads. 
if (auto *PrivStructType = dyn_cast<StructType>(PrivType)) { @@ -7495,7 +7510,7 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { Type *PointeeTy = PrivStructType->getElementType(u); Value *Ptr = constructPointer(Base, PrivStructLayout->getElementOffset(u), IRB); - LoadInst *L = new LoadInst(PointeeTy, Ptr, "", IP); + LoadInst *L = new LoadInst(PointeeTy, Ptr, "", IP->getIterator()); L->setAlignment(Alignment); ReplacementValues.push_back(L); } @@ -7504,12 +7519,12 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { uint64_t PointeeTySize = DL.getTypeStoreSize(PointeeTy); for (unsigned u = 0, e = PrivArrayType->getNumElements(); u < e; u++) { Value *Ptr = constructPointer(Base, u * PointeeTySize, IRB); - LoadInst *L = new LoadInst(PointeeTy, Ptr, "", IP); + LoadInst *L = new LoadInst(PointeeTy, Ptr, "", IP->getIterator()); L->setAlignment(Alignment); ReplacementValues.push_back(L); } } else { - LoadInst *L = new LoadInst(PrivType, Base, "", IP); + LoadInst *L = new LoadInst(PrivType, Base, "", IP->getIterator()); L->setAlignment(Alignment); ReplacementValues.push_back(L); } @@ -7549,13 +7564,13 @@ struct AAPrivatizablePtrArgument final : public AAPrivatizablePtrImpl { [=](const Attributor::ArgumentReplacementInfo &ARI, Function &ReplacementFn, Function::arg_iterator ArgIt) { BasicBlock &EntryBB = ReplacementFn.getEntryBlock(); - Instruction *IP = &*EntryBB.getFirstInsertionPt(); - const DataLayout &DL = IP->getModule()->getDataLayout(); + BasicBlock::iterator IP = EntryBB.getFirstInsertionPt(); + const DataLayout &DL = IP->getDataLayout(); unsigned AS = DL.getAllocaAddrSpace(); Instruction *AI = new AllocaInst(*PrivatizableType, AS, Arg->getName() + ".priv", IP); createInitialization(*PrivatizableType, *AI, ReplacementFn, - ArgIt->getArgNo(), *IP); + ArgIt->getArgNo(), IP); if (AI->getType() != Arg->getType()) AI = BitCastInst::CreatePointerBitCastOrAddrSpaceCast( @@ -8850,7 +8865,7 @@ struct AADenormalFPMathImpl : public AADenormalFPMath { if (Known.ModeF32.isValid()) OS << " denormal-fp-math-f32=" << Known.ModeF32; OS << ']'; - return OS.str(); + return Str; } }; @@ -8963,7 +8978,7 @@ struct AAValueConstantRangeImpl : AAValueConstantRange { OS << " / "; getAssumed().print(OS); OS << ">"; - return OS.str(); + return Str; } /// Helper function to get a SCEV expr for the associated value at program @@ -9640,7 +9655,7 @@ struct AAPotentialConstantValuesImpl : AAPotentialConstantValues { std::string Str; llvm::raw_string_ostream OS(Str); OS << getState(); - return OS.str(); + return Str; } /// See AbstractAttribute::updateImpl(...). 
@@ -10324,48 +10339,25 @@ struct AANoFPClassImpl : AANoFPClass { /// See followUsesInMBEC bool followUseInMBEC(Attributor &A, const Use *U, const Instruction *I, AANoFPClass::StateType &State) { - const Value *UseV = U->get(); - const DominatorTree *DT = nullptr; - AssumptionCache *AC = nullptr; - const TargetLibraryInfo *TLI = nullptr; - InformationCache &InfoCache = A.getInfoCache(); - - if (Function *F = getAnchorScope()) { - DT = InfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(*F); - AC = InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*F); - TLI = InfoCache.getTargetLibraryInfoForFunction(*F); - } - - const DataLayout &DL = A.getDataLayout(); - - KnownFPClass KnownFPClass = - computeKnownFPClass(UseV, DL, - /*InterestedClasses=*/fcAllFlags, - /*Depth=*/0, TLI, AC, I, DT); - State.addKnownBits(~KnownFPClass.KnownFPClasses); - - if (auto *CI = dyn_cast<CallInst>(UseV)) { - // Special case FP intrinsic with struct return type. - switch (CI->getIntrinsicID()) { - case Intrinsic::frexp: - return true; - case Intrinsic::not_intrinsic: - // TODO: Could recognize math libcalls - return false; - default: - break; - } - } + // TODO: Determine what instructions can be looked through. + auto *CB = dyn_cast<CallBase>(I); + if (!CB) + return false; - if (!UseV->getType()->isFPOrFPVectorTy()) + if (!CB->isArgOperand(U)) return false; - return !isa<LoadInst, AtomicRMWInst>(UseV); + + unsigned ArgNo = CB->getArgOperandNo(U); + IRPosition IRP = IRPosition::callsite_argument(*CB, ArgNo); + if (auto *NoFPAA = A.getAAFor<AANoFPClass>(*this, IRP, DepClassTy::NONE)) + State.addKnownBits(NoFPAA->getState().getKnown()); + return false; } const std::string getAsStr(Attributor *A) const override { std::string Result = "nofpclass"; raw_string_ostream OS(Result); - OS << getAssumedNoFPClass(); + OS << getKnownNoFPClass() << '/' << getAssumedNoFPClass(); return Result; } @@ -10417,11 +10409,12 @@ struct AANoFPClassFloating : public AANoFPClassImpl { struct AANoFPClassReturned final : AAReturnedFromReturnedValues<AANoFPClass, AANoFPClassImpl, - AANoFPClassImpl::StateType, false, Attribute::None, false> { + AANoFPClassImpl::StateType, false, + Attribute::None, false> { AANoFPClassReturned(const IRPosition &IRP, Attributor &A) : AAReturnedFromReturnedValues<AANoFPClass, AANoFPClassImpl, - AANoFPClassImpl::StateType, false, Attribute::None, false>( - IRP, A) {} + AANoFPClassImpl::StateType, false, + Attribute::None, false>(IRP, A) {} /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { @@ -10770,9 +10763,7 @@ struct AAPotentialValuesImpl : AAPotentialValues { return; } Value *Stripped = getAssociatedValue().stripPointerCasts(); - auto *CE = dyn_cast<ConstantExpr>(Stripped); - if (isa<Constant>(Stripped) && - (!CE || CE->getOpcode() != Instruction::ICmp)) { + if (isa<Constant>(Stripped) && !isa<ConstantExpr>(Stripped)) { addValue(A, getState(), *Stripped, getCtxI(), AA::AnyScope, getAnchorScope()); indicateOptimisticFixpoint(); @@ -10786,7 +10777,7 @@ struct AAPotentialValuesImpl : AAPotentialValues { std::string Str; llvm::raw_string_ostream OS(Str); OS << getState(); - return OS.str(); + return Str; } template <typename AAType> @@ -11284,7 +11275,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { const auto *TLI = A.getInfoCache().getTargetLibraryInfoForFunction(*F); auto *AC = InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*F); - const DataLayout &DL = I.getModule()->getDataLayout(); + const DataLayout &DL = I.getDataLayout(); 
SimplifyQuery Q(DL, TLI, DT, AC, &I); Value *NewV = simplifyInstructionWithOperands(&I, NewOps, Q); if (!NewV || NewV == &I) @@ -11366,13 +11357,6 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl { continue; } - if (auto *CE = dyn_cast<ConstantExpr>(V)) { - if (CE->getOpcode() == Instruction::ICmp) - if (handleCmp(A, *CE, CE->getOperand(0), CE->getOperand(1), - CmpInst::Predicate(CE->getPredicate()), II, Worklist)) - continue; - } - if (auto *I = dyn_cast<Instruction>(V)) { if (simplifyInstruction(A, *I, II, Worklist, LivenessAAs)) continue; @@ -11754,11 +11738,14 @@ struct AAAssumptionInfoImpl : public AAAssumptionInfo { return ChangeStatus::UNCHANGED; const IRPosition &IRP = getIRPosition(); - return A.manifestAttrs( - IRP, - Attribute::get(IRP.getAnchorValue().getContext(), AssumptionAttrKey, - llvm::join(getAssumed().getSet(), ",")), - /* ForceReplace */ true); + SmallVector<StringRef, 0> Set(getAssumed().getSet().begin(), + getAssumed().getSet().end()); + llvm::sort(Set); + return A.manifestAttrs(IRP, + Attribute::get(IRP.getAnchorValue().getContext(), + AssumptionAttrKey, + llvm::join(Set, ",")), + /*ForceReplace=*/true); } bool hasAssumption(const StringRef Assumption) const override { @@ -11770,13 +11757,15 @@ struct AAAssumptionInfoImpl : public AAAssumptionInfo { const SetContents &Known = getKnown(); const SetContents &Assumed = getAssumed(); - const std::string KnownStr = - llvm::join(Known.getSet().begin(), Known.getSet().end(), ","); - const std::string AssumedStr = - (Assumed.isUniversal()) - ? "Universal" - : llvm::join(Assumed.getSet().begin(), Assumed.getSet().end(), ","); + SmallVector<StringRef, 0> Set(Known.getSet().begin(), Known.getSet().end()); + llvm::sort(Set); + const std::string KnownStr = llvm::join(Set, ","); + std::string AssumedStr = "Universal"; + if (!Assumed.isUniversal()) { + Set.assign(Assumed.getSet().begin(), Assumed.getSet().end()); + AssumedStr = llvm::join(Set, ","); + } return "Known [" + KnownStr + "]," + " Assumed [" + AssumedStr + "]"; } }; @@ -12313,10 +12302,10 @@ struct AAIndirectCallInfoCallSite : public AAIndirectCallInfo { Value *FP = CB->getCalledOperand(); if (FP->getType()->getPointerAddressSpace()) FP = new AddrSpaceCastInst(FP, PointerType::get(FP->getType(), 0), - FP->getName() + ".as0", CB); + FP->getName() + ".as0", CB->getIterator()); bool CBIsVoid = CB->getType()->isVoidTy(); - Instruction *IP = CB; + BasicBlock::iterator IP = CB->getIterator(); FunctionType *CSFT = CB->getFunctionType(); SmallVector<Value *> CSArgs(CB->arg_begin(), CB->arg_end()); @@ -12336,8 +12325,9 @@ struct AAIndirectCallInfoCallSite : public AAIndirectCallInfo { promoteCall(*CB, NewCallee, nullptr); return ChangeStatus::CHANGED; } - Instruction *NewCall = CallInst::Create(FunctionCallee(CSFT, NewCallee), - CSArgs, CB->getName(), CB); + Instruction *NewCall = + CallInst::Create(FunctionCallee(CSFT, NewCallee), CSArgs, + CB->getName(), CB->getIterator()); if (!CBIsVoid) A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewCall); A.deleteAfterManifest(*CB); @@ -12369,14 +12359,14 @@ struct AAIndirectCallInfoCallSite : public AAIndirectCallInfo { SplitBlockAndInsertIfThen(LastCmp, IP, /* Unreachable */ false); BasicBlock *CBBB = CB->getParent(); A.registerManifestAddedBasicBlock(*ThenTI->getParent()); - A.registerManifestAddedBasicBlock(*CBBB); + A.registerManifestAddedBasicBlock(*IP->getParent()); auto *SplitTI = cast<BranchInst>(LastCmp->getNextNode()); BasicBlock *ElseBB; - if (IP == CB) { + if (&*IP == CB) { ElseBB = 
BasicBlock::Create(ThenTI->getContext(), "", ThenTI->getFunction(), CBBB); A.registerManifestAddedBasicBlock(*ElseBB); - IP = BranchInst::Create(CBBB, ElseBB); + IP = BranchInst::Create(CBBB, ElseBB)->getIterator(); SplitTI->replaceUsesOfWith(CBBB, ElseBB); } else { ElseBB = IP->getParent(); @@ -12390,7 +12380,7 @@ struct AAIndirectCallInfoCallSite : public AAIndirectCallInfo { NewCall = &cast<CallInst>(promoteCall(*CBClone, NewCallee, &RetBC)); } else { NewCall = CallInst::Create(FunctionCallee(CSFT, NewCallee), CSArgs, - CB->getName(), ThenTI); + CB->getName(), ThenTI->getIterator()); } NewCalls.push_back({NewCall, RetBC}); } @@ -12416,7 +12406,7 @@ struct AAIndirectCallInfoCallSite : public AAIndirectCallInfo { } else { auto *CBClone = cast<CallInst>(CB->clone()); CBClone->setName(CB->getName()); - CBClone->insertBefore(IP); + CBClone->insertBefore(*IP->getParent(), IP); NewCalls.push_back({CBClone, nullptr}); AttachCalleeMetadata(*CBClone); } @@ -12425,7 +12415,7 @@ struct AAIndirectCallInfoCallSite : public AAIndirectCallInfo { if (!CBIsVoid) { auto *PHI = PHINode::Create(CB->getType(), NewCalls.size(), CB->getName() + ".phi", - &*CB->getParent()->getFirstInsertionPt()); + CB->getParent()->getFirstInsertionPt()); for (auto &It : NewCalls) { CallBase *NewCall = It.first; Instruction *CallRet = It.second ? It.second : It.first; @@ -12783,9 +12773,11 @@ struct AAAllocationInfoImpl : public AAAllocationInfo { auto *NumBytesToValue = ConstantInt::get(I->getContext(), APInt(32, NumBytesToAllocate)); + BasicBlock::iterator insertPt = AI->getIterator(); + insertPt = std::next(insertPt); AllocaInst *NewAllocaInst = new AllocaInst(CharType, AI->getAddressSpace(), NumBytesToValue, - AI->getAlign(), AI->getName(), AI->getNextNode()); + AI->getAlign(), AI->getName(), insertPt); if (A.changeAfterManifest(IRPosition::inst(*AI), *NewAllocaInst)) return ChangeStatus::CHANGED; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/BlockExtractor.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/BlockExtractor.cpp index 0c406aa9822e..ec1be35a3316 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/BlockExtractor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/BlockExtractor.cpp @@ -142,9 +142,8 @@ bool BlockExtractor::runOnModule(Module &M) { report_fatal_error("Invalid function name specified in the input file", /*GenCrashDiag=*/false); for (const auto &BBInfo : BInfo.second) { - auto Res = llvm::find_if(*F, [&](const BasicBlock &BB) { - return BB.getName().equals(BBInfo); - }); + auto Res = llvm::find_if( + *F, [&](const BasicBlock &BB) { return BB.getName() == BBInfo; }); if (Res == F->end()) report_fatal_error("Invalid block name specified in the input file", /*GenCrashDiag=*/false); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp index 2c8756c07f87..acc10f57c29a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/CalledValuePropagation.cpp @@ -21,6 +21,7 @@ #include "llvm/Analysis/ValueLatticeUtils.h" #include "llvm/IR/Constants.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ConstantMerge.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ConstantMerge.cpp index 29052c8d997e..a1face0a6a9c 100644 --- 
a/contrib/llvm-project/llvm/lib/Transforms/IPO/ConstantMerge.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ConstantMerge.cpp @@ -29,6 +29,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" #include "llvm/Transforms/IPO.h" #include <algorithm> #include <cassert> @@ -84,7 +85,7 @@ static void copyDebugLocMetadata(const GlobalVariable *From, static Align getAlign(GlobalVariable *GV) { return GV->getAlign().value_or( - GV->getParent()->getDataLayout().getPreferredAlign(GV)); + GV->getDataLayout().getPreferredAlign(GV)); } static bool diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp index 5cc8258a495a..91d445dfc4c7 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp @@ -139,8 +139,7 @@ void CrossDSOCFI::buildCFICheck(Module &M) { } bool CrossDSOCFI::runOnModule(Module &M) { - VeryLikelyWeights = - MDBuilder(M.getContext()).createBranchWeights((1U << 20) - 1, 1); + VeryLikelyWeights = MDBuilder(M.getContext()).createLikelyBranchWeights(); if (M.getModuleFlag("Cross-DSO CFI") == nullptr) return false; buildCFICheck(M); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp index 4f65748c19e6..a164c82bdf75 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -204,9 +204,9 @@ bool DeadArgumentEliminationPass::deleteDeadVarargs(Function &F) { CallBase *NewCB = nullptr; if (InvokeInst *II = dyn_cast<InvokeInst>(CB)) { NewCB = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), - Args, OpBundles, "", CB); + Args, OpBundles, "", CB->getIterator()); } else { - NewCB = CallInst::Create(NF, Args, OpBundles, "", CB); + NewCB = CallInst::Create(NF, Args, OpBundles, "", CB->getIterator()); cast<CallInst>(NewCB)->setTailCallKind( cast<CallInst>(CB)->getTailCallKind()); } @@ -319,9 +319,7 @@ bool DeadArgumentEliminationPass::removeDeadArgumentsFromCallers(Function &F) { continue; // Now go through all unused args and replace them with poison. - for (unsigned I = 0, E = UnusedArgs.size(); I != E; ++I) { - unsigned ArgNo = UnusedArgs[I]; - + for (unsigned ArgNo : UnusedArgs) { Value *Arg = CB->getArgOperand(ArgNo); CB->setArgOperand(ArgNo, PoisonValue::get(Arg->getType())); CB->removeParamAttrs(ArgNo, UBImplyingAttributes); @@ -946,7 +944,7 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) { NewCB = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(), Args, OpBundles, "", CB.getParent()); } else { - NewCB = CallInst::Create(NFTy, NF, Args, OpBundles, "", &CB); + NewCB = CallInst::Create(NFTy, NF, Args, OpBundles, "", CB.getIterator()); cast<CallInst>(NewCB)->setTailCallKind( cast<CallInst>(&CB)->getTailCallKind()); } @@ -1070,7 +1068,8 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) { } // Replace the return instruction with one returning the new return // value (possibly 0 if we became void). 
- auto *NewRet = ReturnInst::Create(F->getContext(), RetVal, RI); + auto *NewRet = + ReturnInst::Create(F->getContext(), RetVal, RI->getIterator()); NewRet->setDebugLoc(RI->getDebugLoc()); RI->eraseFromParent(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ExpandVariadics.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ExpandVariadics.cpp new file mode 100644 index 000000000000..b5b590e2b7ac --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ExpandVariadics.cpp @@ -0,0 +1,1044 @@ +//===-- ExpandVariadicsPass.cpp --------------------------------*- C++ -*-=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is an optimization pass for variadic functions. If called from codegen, +// it can serve as the implementation of variadic functions for a given target. +// +// The strategy is to turn the ... part of a variadic function into a va_list +// and fix up the call sites. The majority of the pass is target independent. +// The exceptions are the va_list type itself and the rules for where to store +// variables in memory such that va_arg can iterate over them given a va_list. +// +// The majority of the plumbing is splitting the variadic function into a +// single basic block that packs the variadic arguments into a va_list and +// a second function that does the work of the original. That packing is +// exactly what is done by va_start. Further, the transform from ... to va_list +// replaces va_start with an operation to copy a va_list from the new argument, +// which is exactly a va_copy. This is useful for reducing target-dependence. +// +// A va_list instance is a forward iterator, where the primary operation va_arg +// is dereference-then-increment. This interface forces significant convergent +// evolution between target specific implementations. The variation in runtime +// data layout is limited to that representable by the iterator, parameterised +// by the type passed to the va_arg instruction. +// +// Therefore the majority of the target specific subtlety is packing arguments +// into a stack allocated buffer such that a va_list can be initialised with it +// and the va_arg expansion for the target will find the arguments at runtime. +// +// The aggregate effect is to unblock other transforms, most critically the +// general purpose inliner. Known calls to variadic functions become zero cost. +// +// Consistency with clang is primarily tested by emitting va_arg using clang +// then expanding the variadic functions using this pass, followed by trying +// to constant fold the functions to no-ops. +// +// Target specific behaviour is tested in IR - mainly checking that values are +// put into positions in call frames that make sense for that particular target. +// +// There is one "clever" invariant in use. va_start intrinsics that are not +// within a variadic function are an error in the IR verifier. When this +// transform moves blocks from a variadic function into a fixed arity one, it +// moves va_start intrinsics along with everything else. That means that the +// va_start intrinsics that need to be rewritten to use the trailing argument +// are exactly those that are in non-variadic functions, so no further state +// is needed to distinguish those that need to be rewritten.
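For orientation, the following is a rough source-level sketch of the split described in the header comment above. It is illustrative only and not part of the patch: the names sum and sum_valist are invented, and the real pass operates on LLVM IR with a target-specific va_list layout rather than on C source.

#include <stdarg.h>

/* Original (before the pass): a normal variadic function, e.g.
 *   int sum(int n, ...) { walk the ... with va_start/va_arg and add them up }
 * Conceptually, the pass replaces it with the pair below. */

/* 1. A fixed-arity clone holding the original body, taking a va_list. */
static int sum_valist(int n, va_list ap) {
  int s = 0;
  for (int i = 0; i < n; ++i)
    s += va_arg(ap, int);
  return s;
}

/* 2. A thin variadic wrapper that packs the trailing arguments into a
 *    va_list (exactly what va_start does) and forwards it. Known calls to
 *    sum() can then be redirected at sum_valist() and inlined. */
int sum(int n, ...) {
  va_list ap;
  va_start(ap, n);
  int r = sum_valist(n, ap);
  va_end(ap);
  return r;
}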
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/ExpandVariadics.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/TargetParser/Triple.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +#define DEBUG_TYPE "expand-variadics" + +using namespace llvm; + +namespace { + +cl::opt<ExpandVariadicsMode> ExpandVariadicsModeOption( + DEBUG_TYPE "-override", cl::desc("Override the behaviour of " DEBUG_TYPE), + cl::init(ExpandVariadicsMode::Unspecified), + cl::values(clEnumValN(ExpandVariadicsMode::Unspecified, "unspecified", + "Use the implementation defaults"), + clEnumValN(ExpandVariadicsMode::Disable, "disable", + "Disable the pass entirely"), + clEnumValN(ExpandVariadicsMode::Optimize, "optimize", + "Optimise without changing ABI"), + clEnumValN(ExpandVariadicsMode::Lowering, "lowering", + "Change variadic calling convention"))); + +bool commandLineOverride() { + return ExpandVariadicsModeOption != ExpandVariadicsMode::Unspecified; +} + +// Instances of this class encapsulate the target-dependant behaviour as a +// function of triple. Implementing a new ABI is adding a case to the switch +// in create(llvm::Triple) at the end of this file. +// This class may end up instantiated in TargetMachine instances, keeping it +// here for now until enough targets are implemented for the API to evolve. +class VariadicABIInfo { +protected: + VariadicABIInfo() = default; + +public: + static std::unique_ptr<VariadicABIInfo> create(const Triple &T); + + // Allow overriding whether the pass runs on a per-target basis + virtual bool enableForTarget() = 0; + + // Whether a valist instance is passed by value or by address + // I.e. does it need to be alloca'ed and stored into, or can + // it be passed directly in a SSA register + virtual bool vaListPassedInSSARegister() = 0; + + // The type of a va_list iterator object + virtual Type *vaListType(LLVMContext &Ctx) = 0; + + // The type of a va_list as a function argument as lowered by C + virtual Type *vaListParameterType(Module &M) = 0; + + // Initialize an allocated va_list object to point to an already + // initialized contiguous memory region. + // Return the value to pass as the va_list argument + virtual Value *initializeVaList(Module &M, LLVMContext &Ctx, + IRBuilder<> &Builder, AllocaInst *VaList, + Value *Buffer) = 0; + + struct VAArgSlotInfo { + Align DataAlign; // With respect to the call frame + bool Indirect; // Passed via a pointer + }; + virtual VAArgSlotInfo slotInfo(const DataLayout &DL, Type *Parameter) = 0; + + // Targets implemented so far all have the same trivial lowering for these + bool vaEndIsNop() { return true; } + bool vaCopyIsMemcpy() { return true; } + + virtual ~VariadicABIInfo() = default; +}; + +// Module implements getFunction() which returns nullptr on missing declaration +// and getOrInsertFunction which creates one when absent. Intrinsics.h only +// implements getDeclaration which creates one when missing. Checking whether +// an intrinsic exists thus inserts it in the module and it then needs to be +// deleted again to clean up. 
+// The right name for the two functions on intrinsics would match Module::, +// but doing that in a single change would introduce nullptr dereferences +// where currently there are none. The minimal collateral damage approach +// would split the change over a release to help downstream branches. As it +// is unclear what approach will be preferred, implementing the trivial +// function here in the meantime to decouple from that discussion. +Function *getPreexistingDeclaration(Module *M, Intrinsic::ID Id, + ArrayRef<Type *> Tys = {}) { + auto *FT = Intrinsic::getType(M->getContext(), Id, Tys); + return M->getFunction(Tys.empty() ? Intrinsic::getName(Id) + : Intrinsic::getName(Id, Tys, M, FT)); +} + +class ExpandVariadics : public ModulePass { + + // The pass construction sets the default to optimize when called from middle + // end and lowering when called from the backend. The command line variable + // overrides that. This is useful for testing and debugging. It also allows + // building an applications with variadic functions wholly removed if one + // has sufficient control over the dependencies, e.g. a statically linked + // clang that has no variadic function calls remaining in the binary. + +public: + static char ID; + const ExpandVariadicsMode Mode; + std::unique_ptr<VariadicABIInfo> ABI; + + ExpandVariadics(ExpandVariadicsMode Mode) + : ModulePass(ID), + Mode(commandLineOverride() ? ExpandVariadicsModeOption : Mode) {} + + StringRef getPassName() const override { return "Expand variadic functions"; } + + bool rewriteABI() { return Mode == ExpandVariadicsMode::Lowering; } + + bool runOnModule(Module &M) override; + + bool runOnFunction(Module &M, IRBuilder<> &Builder, Function *F); + + Function *replaceAllUsesWithNewDeclaration(Module &M, + Function *OriginalFunction); + + Function *deriveFixedArityReplacement(Module &M, IRBuilder<> &Builder, + Function *OriginalFunction); + + Function *defineVariadicWrapper(Module &M, IRBuilder<> &Builder, + Function *VariadicWrapper, + Function *FixedArityReplacement); + + bool expandCall(Module &M, IRBuilder<> &Builder, CallBase *CB, FunctionType *, + Function *NF); + + // The intrinsic functions va_copy and va_end are removed unconditionally. + // They correspond to a memcpy and a no-op on all implemented targets. + // The va_start intrinsic is removed from basic blocks that were not created + // by this pass, some may remain if needed to maintain the external ABI. 
+ + template <Intrinsic::ID ID, typename InstructionType> + bool expandIntrinsicUsers(Module &M, IRBuilder<> &Builder, + PointerType *IntrinsicArgType) { + bool Changed = false; + const DataLayout &DL = M.getDataLayout(); + if (Function *Intrinsic = + getPreexistingDeclaration(&M, ID, {IntrinsicArgType})) { + for (User *U : make_early_inc_range(Intrinsic->users())) + if (auto *I = dyn_cast<InstructionType>(U)) + Changed |= expandVAIntrinsicCall(Builder, DL, I); + + if (Intrinsic->use_empty()) + Intrinsic->eraseFromParent(); + } + return Changed; + } + + bool expandVAIntrinsicUsersWithAddrspace(Module &M, IRBuilder<> &Builder, + unsigned Addrspace) { + auto &Ctx = M.getContext(); + PointerType *IntrinsicArgType = PointerType::get(Ctx, Addrspace); + bool Changed = false; + + // expand vastart before vacopy as vastart may introduce a vacopy + Changed |= expandIntrinsicUsers<Intrinsic::vastart, VAStartInst>( + M, Builder, IntrinsicArgType); + Changed |= expandIntrinsicUsers<Intrinsic::vaend, VAEndInst>( + M, Builder, IntrinsicArgType); + Changed |= expandIntrinsicUsers<Intrinsic::vacopy, VACopyInst>( + M, Builder, IntrinsicArgType); + return Changed; + } + + bool expandVAIntrinsicCall(IRBuilder<> &Builder, const DataLayout &DL, + VAStartInst *Inst); + + bool expandVAIntrinsicCall(IRBuilder<> &, const DataLayout &, + VAEndInst *Inst); + + bool expandVAIntrinsicCall(IRBuilder<> &Builder, const DataLayout &DL, + VACopyInst *Inst); + + FunctionType *inlinableVariadicFunctionType(Module &M, FunctionType *FTy) { + // The type of "FTy" with the ... removed and a va_list appended + SmallVector<Type *> ArgTypes(FTy->param_begin(), FTy->param_end()); + ArgTypes.push_back(ABI->vaListParameterType(M)); + return FunctionType::get(FTy->getReturnType(), ArgTypes, + /*IsVarArgs=*/false); + } + + static ConstantInt *sizeOfAlloca(LLVMContext &Ctx, const DataLayout &DL, + AllocaInst *Alloced) { + std::optional<TypeSize> AllocaTypeSize = Alloced->getAllocationSize(DL); + uint64_t AsInt = AllocaTypeSize ? AllocaTypeSize->getFixedValue() : 0; + return ConstantInt::get(Type::getInt64Ty(Ctx), AsInt); + } + + bool expansionApplicableToFunction(Module &M, Function *F) { + if (F->isIntrinsic() || !F->isVarArg() || + F->hasFnAttribute(Attribute::Naked)) + return false; + + if (F->getCallingConv() != CallingConv::C) + return false; + + if (rewriteABI()) + return true; + + if (!F->hasExactDefinition()) + return false; + + return true; + } + + bool expansionApplicableToFunctionCall(CallBase *CB) { + if (CallInst *CI = dyn_cast<CallInst>(CB)) { + if (CI->isMustTailCall()) { + // Cannot expand musttail calls + return false; + } + + if (CI->getCallingConv() != CallingConv::C) + return false; + + return true; + } + + if (isa<InvokeInst>(CB)) { + // Invoke not implemented in initial implementation of pass + return false; + } + + // Other unimplemented derivative of CallBase + return false; + } + + class ExpandedCallFrame { + // Helper for constructing an alloca instance containing the arguments bound + // to the variadic ... 
parameter, rearranged to allow indexing through a + // va_list iterator + enum { N = 4 }; + SmallVector<Type *, N> FieldTypes; + enum Tag { Store, Memcpy, Padding }; + SmallVector<std::tuple<Value *, uint64_t, Tag>, N> Source; + + template <Tag tag> void append(Type *FieldType, Value *V, uint64_t Bytes) { + FieldTypes.push_back(FieldType); + Source.push_back({V, Bytes, tag}); + } + + public: + void store(LLVMContext &Ctx, Type *T, Value *V) { append<Store>(T, V, 0); } + + void memcpy(LLVMContext &Ctx, Type *T, Value *V, uint64_t Bytes) { + append<Memcpy>(T, V, Bytes); + } + + void padding(LLVMContext &Ctx, uint64_t By) { + append<Padding>(ArrayType::get(Type::getInt8Ty(Ctx), By), nullptr, 0); + } + + size_t size() const { return FieldTypes.size(); } + bool empty() const { return FieldTypes.empty(); } + + StructType *asStruct(LLVMContext &Ctx, StringRef Name) { + const bool IsPacked = true; + return StructType::create(Ctx, FieldTypes, + (Twine(Name) + ".vararg").str(), IsPacked); + } + + void initializeStructAlloca(const DataLayout &DL, IRBuilder<> &Builder, + AllocaInst *Alloced) { + + StructType *VarargsTy = cast<StructType>(Alloced->getAllocatedType()); + + for (size_t I = 0; I < size(); I++) { + + auto [V, bytes, tag] = Source[I]; + + if (tag == Padding) { + assert(V == nullptr); + continue; + } + + auto Dst = Builder.CreateStructGEP(VarargsTy, Alloced, I); + + assert(V != nullptr); + + if (tag == Store) + Builder.CreateStore(V, Dst); + + if (tag == Memcpy) + Builder.CreateMemCpy(Dst, {}, V, {}, bytes); + } + } + }; +}; + +bool ExpandVariadics::runOnModule(Module &M) { + bool Changed = false; + if (Mode == ExpandVariadicsMode::Disable) + return Changed; + + Triple TT(M.getTargetTriple()); + ABI = VariadicABIInfo::create(TT); + if (!ABI) + return Changed; + + if (!ABI->enableForTarget()) + return Changed; + + auto &Ctx = M.getContext(); + const DataLayout &DL = M.getDataLayout(); + IRBuilder<> Builder(Ctx); + + // Lowering needs to run on all functions exactly once. + // Optimize could run on functions containing va_start exactly once. + for (Function &F : make_early_inc_range(M)) + Changed |= runOnFunction(M, Builder, &F); + + // After runOnFunction, all known calls to known variadic functions have been + // replaced. va_start intrinsics are presently (and invalidly!) only present + // in functions that used to be variadic and have now been replaced to take a + // va_list instead. If lowering as opposed to optimising, calls to unknown + // variadic functions have also been replaced. + + { + // 0 and AllocaAddrSpace are sufficient for the targets implemented so far + unsigned Addrspace = 0; + Changed |= expandVAIntrinsicUsersWithAddrspace(M, Builder, Addrspace); + + Addrspace = DL.getAllocaAddrSpace(); + if (Addrspace != 0) + Changed |= expandVAIntrinsicUsersWithAddrspace(M, Builder, Addrspace); + } + + if (Mode != ExpandVariadicsMode::Lowering) + return Changed; + + for (Function &F : make_early_inc_range(M)) { + if (F.isDeclaration()) + continue; + + // Now need to track down indirect calls. Can't find those + // by walking uses of variadic functions, need to crawl the instruction + // stream. Fortunately this is only necessary for the ABI rewrite case. 
+ for (BasicBlock &BB : F) { + for (Instruction &I : make_early_inc_range(BB)) { + if (CallBase *CB = dyn_cast<CallBase>(&I)) { + if (CB->isIndirectCall()) { + FunctionType *FTy = CB->getFunctionType(); + if (FTy->isVarArg()) + Changed |= expandCall(M, Builder, CB, FTy, 0); + } + } + } + } + } + + return Changed; +} + +bool ExpandVariadics::runOnFunction(Module &M, IRBuilder<> &Builder, + Function *OriginalFunction) { + bool Changed = false; + + if (!expansionApplicableToFunction(M, OriginalFunction)) + return Changed; + + [[maybe_unused]] const bool OriginalFunctionIsDeclaration = + OriginalFunction->isDeclaration(); + assert(rewriteABI() || !OriginalFunctionIsDeclaration); + + // Declare a new function and redirect every use to that new function + Function *VariadicWrapper = + replaceAllUsesWithNewDeclaration(M, OriginalFunction); + assert(VariadicWrapper->isDeclaration()); + assert(OriginalFunction->use_empty()); + + // Create a new function taking va_list containing the implementation of the + // original + Function *FixedArityReplacement = + deriveFixedArityReplacement(M, Builder, OriginalFunction); + assert(OriginalFunction->isDeclaration()); + assert(FixedArityReplacement->isDeclaration() == + OriginalFunctionIsDeclaration); + assert(VariadicWrapper->isDeclaration()); + + // Create a single block forwarding wrapper that turns a ... into a va_list + [[maybe_unused]] Function *VariadicWrapperDefine = + defineVariadicWrapper(M, Builder, VariadicWrapper, FixedArityReplacement); + assert(VariadicWrapperDefine == VariadicWrapper); + assert(!VariadicWrapper->isDeclaration()); + + // We now have: + // 1. the original function, now as a declaration with no uses + // 2. a variadic function that unconditionally calls a fixed arity replacement + // 3. a fixed arity function equivalent to the original function + + // Replace known calls to the variadic with calls to the va_list equivalent + for (User *U : make_early_inc_range(VariadicWrapper->users())) { + if (CallBase *CB = dyn_cast<CallBase>(U)) { + Value *CalledOperand = CB->getCalledOperand(); + if (VariadicWrapper == CalledOperand) + Changed |= + expandCall(M, Builder, CB, VariadicWrapper->getFunctionType(), + FixedArityReplacement); + } + } + + // The original function will be erased. + // One of the two new functions will become a replacement for the original. + // When preserving the ABI, the other is an internal implementation detail. + // When rewriting the ABI, RAUW then the variadic one. + Function *const ExternallyAccessible = + rewriteABI() ? FixedArityReplacement : VariadicWrapper; + Function *const InternalOnly = + rewriteABI() ? 
VariadicWrapper : FixedArityReplacement; + + // The external function is the replacement for the original + ExternallyAccessible->setLinkage(OriginalFunction->getLinkage()); + ExternallyAccessible->setVisibility(OriginalFunction->getVisibility()); + ExternallyAccessible->setComdat(OriginalFunction->getComdat()); + ExternallyAccessible->takeName(OriginalFunction); + + // Annotate the internal one as internal + InternalOnly->setVisibility(GlobalValue::DefaultVisibility); + InternalOnly->setLinkage(GlobalValue::InternalLinkage); + + // The original is unused and obsolete + OriginalFunction->eraseFromParent(); + + InternalOnly->removeDeadConstantUsers(); + + if (rewriteABI()) { + // All known calls to the function have been removed by expandCall + // Resolve everything else by replaceAllUsesWith + VariadicWrapper->replaceAllUsesWith(FixedArityReplacement); + VariadicWrapper->eraseFromParent(); + } + + return Changed; +} + +Function * +ExpandVariadics::replaceAllUsesWithNewDeclaration(Module &M, + Function *OriginalFunction) { + auto &Ctx = M.getContext(); + Function &F = *OriginalFunction; + FunctionType *FTy = F.getFunctionType(); + Function *NF = Function::Create(FTy, F.getLinkage(), F.getAddressSpace()); + + NF->setName(F.getName() + ".varargs"); + NF->IsNewDbgInfoFormat = F.IsNewDbgInfoFormat; + + F.getParent()->getFunctionList().insert(F.getIterator(), NF); + + AttrBuilder ParamAttrs(Ctx); + AttributeList Attrs = NF->getAttributes(); + Attrs = Attrs.addParamAttributes(Ctx, FTy->getNumParams(), ParamAttrs); + NF->setAttributes(Attrs); + + OriginalFunction->replaceAllUsesWith(NF); + return NF; +} + +Function * +ExpandVariadics::deriveFixedArityReplacement(Module &M, IRBuilder<> &Builder, + Function *OriginalFunction) { + Function &F = *OriginalFunction; + // The purpose here is to split the variadic function F into two functions. + // One is a variadic function that bundles the passed arguments into a va_list + // and passes it to the second function. The second function does whatever + // the original F does, except that it takes a va_list instead of the ...
+ + assert(expansionApplicableToFunction(M, &F)); + + auto &Ctx = M.getContext(); + + // Returned value isDeclaration() is equal to F.isDeclaration() + // but that property is not invariant throughout this function + const bool FunctionIsDefinition = !F.isDeclaration(); + + FunctionType *FTy = F.getFunctionType(); + SmallVector<Type *> ArgTypes(FTy->param_begin(), FTy->param_end()); + ArgTypes.push_back(ABI->vaListParameterType(M)); + + FunctionType *NFTy = inlinableVariadicFunctionType(M, FTy); + Function *NF = Function::Create(NFTy, F.getLinkage(), F.getAddressSpace()); + + // Note - same attribute handling as DeadArgumentElimination + NF->copyAttributesFrom(&F); + NF->setComdat(F.getComdat()); + F.getParent()->getFunctionList().insert(F.getIterator(), NF); + NF->setName(F.getName() + ".valist"); + NF->IsNewDbgInfoFormat = F.IsNewDbgInfoFormat; + + AttrBuilder ParamAttrs(Ctx); + + AttributeList Attrs = NF->getAttributes(); + Attrs = Attrs.addParamAttributes(Ctx, NFTy->getNumParams() - 1, ParamAttrs); + NF->setAttributes(Attrs); + + // Splice the implementation into the new function with minimal changes + if (FunctionIsDefinition) { + NF->splice(NF->begin(), &F); + + auto NewArg = NF->arg_begin(); + for (Argument &Arg : F.args()) { + Arg.replaceAllUsesWith(NewArg); + NewArg->setName(Arg.getName()); // takeName without killing the old one + ++NewArg; + } + NewArg->setName("varargs"); + } + + SmallVector<std::pair<unsigned, MDNode *>, 1> MDs; + F.getAllMetadata(MDs); + for (auto [KindID, Node] : MDs) + NF->addMetadata(KindID, *Node); + F.clearMetadata(); + + return NF; +} + +Function * +ExpandVariadics::defineVariadicWrapper(Module &M, IRBuilder<> &Builder, + Function *VariadicWrapper, + Function *FixedArityReplacement) { + auto &Ctx = Builder.getContext(); + const DataLayout &DL = M.getDataLayout(); + assert(VariadicWrapper->isDeclaration()); + Function &F = *VariadicWrapper; + + assert(F.isDeclaration()); + Type *VaListTy = ABI->vaListType(Ctx); + + auto *BB = BasicBlock::Create(Ctx, "entry", &F); + Builder.SetInsertPoint(BB); + + AllocaInst *VaListInstance = + Builder.CreateAlloca(VaListTy, nullptr, "va_start"); + + Builder.CreateLifetimeStart(VaListInstance, + sizeOfAlloca(Ctx, DL, VaListInstance)); + + Builder.CreateIntrinsic(Intrinsic::vastart, {DL.getAllocaPtrType(Ctx)}, + {VaListInstance}); + + SmallVector<Value *> Args; + for (Argument &A : F.args()) + Args.push_back(&A); + + Type *ParameterType = ABI->vaListParameterType(M); + if (ABI->vaListPassedInSSARegister()) + Args.push_back(Builder.CreateLoad(ParameterType, VaListInstance)); + else + Args.push_back(Builder.CreateAddrSpaceCast(VaListInstance, ParameterType)); + + CallInst *Result = Builder.CreateCall(FixedArityReplacement, Args); + + Builder.CreateIntrinsic(Intrinsic::vaend, {DL.getAllocaPtrType(Ctx)}, + {VaListInstance}); + Builder.CreateLifetimeEnd(VaListInstance, + sizeOfAlloca(Ctx, DL, VaListInstance)); + + if (Result->getType()->isVoidTy()) + Builder.CreateRetVoid(); + else + Builder.CreateRet(Result); + + return VariadicWrapper; +} + +bool ExpandVariadics::expandCall(Module &M, IRBuilder<> &Builder, CallBase *CB, + FunctionType *VarargFunctionType, + Function *NF) { + bool Changed = false; + const DataLayout &DL = M.getDataLayout(); + + if (!expansionApplicableToFunctionCall(CB)) { + if (rewriteABI()) + report_fatal_error("Cannot lower callbase instruction"); + return Changed; + } + + // This is tricky. The call instruction's function type might not match + // the type of the caller. 
When optimising, can leave it unchanged. + // Webassembly detects that inconsistency and repairs it. + FunctionType *FuncType = CB->getFunctionType(); + if (FuncType != VarargFunctionType) { + if (!rewriteABI()) + return Changed; + FuncType = VarargFunctionType; + } + + auto &Ctx = CB->getContext(); + + Align MaxFieldAlign(1); + + // The strategy is to allocate a call frame containing the variadic + // arguments laid out such that a target specific va_list can be initialized + // with it, such that target specific va_arg instructions will correctly + // iterate over it. This means getting the alignment right and sometimes + // embedding a pointer to the value instead of embedding the value itself. + + Function *CBF = CB->getParent()->getParent(); + + ExpandedCallFrame Frame; + + uint64_t CurrentOffset = 0; + + for (unsigned I = FuncType->getNumParams(), E = CB->arg_size(); I < E; ++I) { + Value *ArgVal = CB->getArgOperand(I); + const bool IsByVal = CB->paramHasAttr(I, Attribute::ByVal); + const bool IsByRef = CB->paramHasAttr(I, Attribute::ByRef); + + // The type of the value being passed, decoded from byval/byref metadata if + // required + Type *const UnderlyingType = IsByVal ? CB->getParamByValType(I) + : IsByRef ? CB->getParamByRefType(I) + : ArgVal->getType(); + const uint64_t UnderlyingSize = + DL.getTypeAllocSize(UnderlyingType).getFixedValue(); + + // The type to be written into the call frame + Type *FrameFieldType = UnderlyingType; + + // The value to copy from when initialising the frame alloca + Value *SourceValue = ArgVal; + + VariadicABIInfo::VAArgSlotInfo SlotInfo = ABI->slotInfo(DL, UnderlyingType); + + if (SlotInfo.Indirect) { + // The va_arg lowering loads through a pointer. Set up an alloca to aim + // that pointer at. + Builder.SetInsertPointPastAllocas(CBF); + Builder.SetCurrentDebugLocation(CB->getStableDebugLoc()); + Value *CallerCopy = + Builder.CreateAlloca(UnderlyingType, nullptr, "IndirectAlloca"); + + Builder.SetInsertPoint(CB); + if (IsByVal) + Builder.CreateMemCpy(CallerCopy, {}, ArgVal, {}, UnderlyingSize); + else + Builder.CreateStore(ArgVal, CallerCopy); + + // Indirection now handled, pass the alloca ptr by value + FrameFieldType = DL.getAllocaPtrType(Ctx); + SourceValue = CallerCopy; + } + + // Alignment of the value within the frame + // This probably needs to be controllable as a function of type + Align DataAlign = SlotInfo.DataAlign; + + MaxFieldAlign = std::max(MaxFieldAlign, DataAlign); + + uint64_t DataAlignV = DataAlign.value(); + if (uint64_t Rem = CurrentOffset % DataAlignV) { + // Inject explicit padding to deal with alignment requirements + uint64_t Padding = DataAlignV - Rem; + Frame.padding(Ctx, Padding); + CurrentOffset += Padding; + } + + if (SlotInfo.Indirect) { + Frame.store(Ctx, FrameFieldType, SourceValue); + } else { + if (IsByVal) + Frame.memcpy(Ctx, FrameFieldType, SourceValue, UnderlyingSize); + else + Frame.store(Ctx, FrameFieldType, SourceValue); + } + + CurrentOffset += DL.getTypeAllocSize(FrameFieldType).getFixedValue(); + } + + if (Frame.empty()) { + // Not passing any arguments, hopefully va_arg won't try to read any + // Creating a single byte frame containing nothing to point the va_list + // instance as that is less special-casey in the compiler and probably + // easier to interpret in a debugger. + Frame.padding(Ctx, 1); + } + + StructType *VarargsTy = Frame.asStruct(Ctx, CBF->getName()); + + // The struct instance needs to be at least MaxFieldAlign for the alignment of + // the fields to be correct at runtime. 
Use the native stack alignment instead + // if that's greater as that tends to give better codegen. + // This is an awkward way to guess whether there is a known stack alignment + // without hitting an assert in DL.getStackAlignment, 1024 is an arbitrary + // number likely to be greater than the natural stack alignment. + // TODO: DL.getStackAlignment could return a MaybeAlign instead of assert + Align AllocaAlign = MaxFieldAlign; + if (DL.exceedsNaturalStackAlignment(Align(1024))) + AllocaAlign = std::max(AllocaAlign, DL.getStackAlignment()); + + // Put the alloca to hold the variadic args in the entry basic block. + Builder.SetInsertPointPastAllocas(CBF); + + // SetCurrentDebugLocation when the builder SetInsertPoint method does not + Builder.SetCurrentDebugLocation(CB->getStableDebugLoc()); + + // The awkward construction here is to set the alignment on the instance + AllocaInst *Alloced = Builder.Insert( + new AllocaInst(VarargsTy, DL.getAllocaAddrSpace(), nullptr, AllocaAlign), + "vararg_buffer"); + Changed = true; + assert(Alloced->getAllocatedType() == VarargsTy); + + // Initialize the fields in the struct + Builder.SetInsertPoint(CB); + Builder.CreateLifetimeStart(Alloced, sizeOfAlloca(Ctx, DL, Alloced)); + Frame.initializeStructAlloca(DL, Builder, Alloced); + + const unsigned NumArgs = FuncType->getNumParams(); + SmallVector<Value *> Args(CB->arg_begin(), CB->arg_begin() + NumArgs); + + // Initialize a va_list pointing to that struct and pass it as the last + // argument + AllocaInst *VaList = nullptr; + { + if (!ABI->vaListPassedInSSARegister()) { + Type *VaListTy = ABI->vaListType(Ctx); + Builder.SetInsertPointPastAllocas(CBF); + Builder.SetCurrentDebugLocation(CB->getStableDebugLoc()); + VaList = Builder.CreateAlloca(VaListTy, nullptr, "va_argument"); + Builder.SetInsertPoint(CB); + Builder.CreateLifetimeStart(VaList, sizeOfAlloca(Ctx, DL, VaList)); + } + Builder.SetInsertPoint(CB); + Args.push_back(ABI->initializeVaList(M, Ctx, Builder, VaList, Alloced)); + } + + // Attributes excluding any on the vararg arguments + AttributeList PAL = CB->getAttributes(); + if (!PAL.isEmpty()) { + SmallVector<AttributeSet, 8> ArgAttrs; + for (unsigned ArgNo = 0; ArgNo < NumArgs; ArgNo++) + ArgAttrs.push_back(PAL.getParamAttrs(ArgNo)); + PAL = + AttributeList::get(Ctx, PAL.getFnAttrs(), PAL.getRetAttrs(), ArgAttrs); + } + + SmallVector<OperandBundleDef, 1> OpBundles; + CB->getOperandBundlesAsDefs(OpBundles); + + CallBase *NewCB = nullptr; + + if (CallInst *CI = dyn_cast<CallInst>(CB)) { + Value *Dst = NF ? 
NF : CI->getCalledOperand(); + FunctionType *NFTy = inlinableVariadicFunctionType(M, VarargFunctionType); + + NewCB = CallInst::Create(NFTy, Dst, Args, OpBundles, "", CI); + + CallInst::TailCallKind TCK = CI->getTailCallKind(); + assert(TCK != CallInst::TCK_MustTail); + + // Can't tail call a function that is being passed a pointer to an alloca + if (TCK == CallInst::TCK_Tail) + TCK = CallInst::TCK_None; + CI->setTailCallKind(TCK); + + } else { + llvm_unreachable("Unreachable when !expansionApplicableToFunctionCall()"); + } + + if (VaList) + Builder.CreateLifetimeEnd(VaList, sizeOfAlloca(Ctx, DL, VaList)); + + Builder.CreateLifetimeEnd(Alloced, sizeOfAlloca(Ctx, DL, Alloced)); + + NewCB->setAttributes(PAL); + NewCB->takeName(CB); + NewCB->setCallingConv(CB->getCallingConv()); + NewCB->setDebugLoc(DebugLoc()); + + // DeadArgElim and ArgPromotion copy exactly this metadata + NewCB->copyMetadata(*CB, {LLVMContext::MD_prof, LLVMContext::MD_dbg}); + + CB->replaceAllUsesWith(NewCB); + CB->eraseFromParent(); + return Changed; +} + +bool ExpandVariadics::expandVAIntrinsicCall(IRBuilder<> &Builder, + const DataLayout &DL, + VAStartInst *Inst) { + // Only removing va_start instructions that are not in variadic functions. + // Those would be rejected by the IR verifier before this pass. + // After splicing basic blocks from a variadic function into a fixed arity + // one the va_start that used to refer to the ... parameter still exist. + // There are also variadic functions that this pass did not change and + // va_start instances in the created single block wrapper functions. + // Replace exactly the instances in non-variadic functions as those are + // the ones to be fixed up to use the va_list passed as the final argument. + + Function *ContainingFunction = Inst->getFunction(); + if (ContainingFunction->isVarArg()) { + return false; + } + + // The last argument is a vaListParameterType, either a va_list + // or a pointer to one depending on the target. + bool PassedByValue = ABI->vaListPassedInSSARegister(); + Argument *PassedVaList = + ContainingFunction->getArg(ContainingFunction->arg_size() - 1); + + // va_start takes a pointer to a va_list, e.g. one on the stack + Value *VaStartArg = Inst->getArgList(); + + Builder.SetInsertPoint(Inst); + + if (PassedByValue) { + // The general thing to do is create an alloca, store the va_list argument + // to it, then create a va_copy. When vaCopyIsMemcpy(), this optimises to a + // store to the VaStartArg. 
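A minimal source-level sketch of what this lowering amounts to, using hypothetical names (sum, sum_valist): the body of a variadic function is moved into a fixed-arity clone that takes the va_list as a trailing parameter, and the original becomes a thin wrapper, so any va_start left behind in the clone only needs to be repointed at that parameter.

#include <cstdarg>

// Fixed-arity clone: the spliced body reads its arguments through the
// trailing va_list parameter instead of setting one up itself.
static int sum_valist(int n, va_list ap) {
  int s = 0;
  for (int i = 0; i < n; ++i)
    s += va_arg(ap, int);
  return s;
}

// Thin variadic wrapper: va_start, forward, va_end.
int sum(int n, ...) {
  va_list ap;
  va_start(ap, n);
  int r = sum_valist(n, ap);
  va_end(ap);
  return r;
}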
+ assert(ABI->vaCopyIsMemcpy()); + Builder.CreateStore(PassedVaList, VaStartArg); + } else { + + // Otherwise emit a vacopy to pick up target-specific handling if any + auto &Ctx = Builder.getContext(); + + Builder.CreateIntrinsic(Intrinsic::vacopy, {DL.getAllocaPtrType(Ctx)}, + {VaStartArg, PassedVaList}); + } + + Inst->eraseFromParent(); + return true; +} + +bool ExpandVariadics::expandVAIntrinsicCall(IRBuilder<> &, const DataLayout &, + VAEndInst *Inst) { + assert(ABI->vaEndIsNop()); + Inst->eraseFromParent(); + return true; +} + +bool ExpandVariadics::expandVAIntrinsicCall(IRBuilder<> &Builder, + const DataLayout &DL, + VACopyInst *Inst) { + assert(ABI->vaCopyIsMemcpy()); + Builder.SetInsertPoint(Inst); + + auto &Ctx = Builder.getContext(); + Type *VaListTy = ABI->vaListType(Ctx); + uint64_t Size = DL.getTypeAllocSize(VaListTy).getFixedValue(); + + Builder.CreateMemCpy(Inst->getDest(), {}, Inst->getSrc(), {}, + Builder.getInt32(Size)); + + Inst->eraseFromParent(); + return true; +} + +struct Amdgpu final : public VariadicABIInfo { + + bool enableForTarget() override { return true; } + + bool vaListPassedInSSARegister() override { return true; } + + Type *vaListType(LLVMContext &Ctx) override { + return PointerType::getUnqual(Ctx); + } + + Type *vaListParameterType(Module &M) override { + return PointerType::getUnqual(M.getContext()); + } + + Value *initializeVaList(Module &M, LLVMContext &Ctx, IRBuilder<> &Builder, + AllocaInst * /*va_list*/, Value *Buffer) override { + // Given Buffer, which is an AllocInst of vararg_buffer + // need to return something usable as parameter type + return Builder.CreateAddrSpaceCast(Buffer, vaListParameterType(M)); + } + + VAArgSlotInfo slotInfo(const DataLayout &DL, Type *Parameter) override { + return {Align(4), false}; + } +}; + +struct NVPTX final : public VariadicABIInfo { + + bool enableForTarget() override { return true; } + + bool vaListPassedInSSARegister() override { return true; } + + Type *vaListType(LLVMContext &Ctx) override { + return PointerType::getUnqual(Ctx); + } + + Type *vaListParameterType(Module &M) override { + return PointerType::getUnqual(M.getContext()); + } + + Value *initializeVaList(Module &M, LLVMContext &Ctx, IRBuilder<> &Builder, + AllocaInst *, Value *Buffer) override { + return Builder.CreateAddrSpaceCast(Buffer, vaListParameterType(M)); + } + + VAArgSlotInfo slotInfo(const DataLayout &DL, Type *Parameter) override { + // NVPTX expects natural alignment in all cases. The variadic call ABI will + // handle promoting types to their appropriate size and alignment. + Align A = DL.getABITypeAlign(Parameter); + return {A, false}; + } +}; + +struct Wasm final : public VariadicABIInfo { + + bool enableForTarget() override { + // Currently wasm is only used for testing. 
+ return commandLineOverride(); + } + + bool vaListPassedInSSARegister() override { return true; } + + Type *vaListType(LLVMContext &Ctx) override { + return PointerType::getUnqual(Ctx); + } + + Type *vaListParameterType(Module &M) override { + return PointerType::getUnqual(M.getContext()); + } + + Value *initializeVaList(Module &M, LLVMContext &Ctx, IRBuilder<> &Builder, + AllocaInst * /*va_list*/, Value *Buffer) override { + return Buffer; + } + + VAArgSlotInfo slotInfo(const DataLayout &DL, Type *Parameter) override { + LLVMContext &Ctx = Parameter->getContext(); + const unsigned MinAlign = 4; + Align A = DL.getABITypeAlign(Parameter); + if (A < MinAlign) + A = Align(MinAlign); + + if (auto *S = dyn_cast<StructType>(Parameter)) { + if (S->getNumElements() > 1) { + return {DL.getABITypeAlign(PointerType::getUnqual(Ctx)), true}; + } + } + + return {A, false}; + } +}; + +std::unique_ptr<VariadicABIInfo> VariadicABIInfo::create(const Triple &T) { + switch (T.getArch()) { + case Triple::r600: + case Triple::amdgcn: { + return std::make_unique<Amdgpu>(); + } + + case Triple::wasm32: { + return std::make_unique<Wasm>(); + } + + case Triple::nvptx: + case Triple::nvptx64: { + return std::make_unique<NVPTX>(); + } + + default: + return {}; + } +} + +} // namespace + +char ExpandVariadics::ID = 0; + +INITIALIZE_PASS(ExpandVariadics, DEBUG_TYPE, "Expand variadic functions", false, + false) + +ModulePass *llvm::createExpandVariadicsPass(ExpandVariadicsMode M) { + return new ExpandVariadics(M); +} + +PreservedAnalyses ExpandVariadicsPass::run(Module &M, ModuleAnalysisManager &) { + return ExpandVariadics(Mode).runOnModule(M) ? PreservedAnalyses::none() + : PreservedAnalyses::all(); +} + +ExpandVariadicsPass::ExpandVariadicsPass(ExpandVariadicsMode M) : Mode(M) {} diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp index 7ebf265e17ba..7b419d0f098b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionAttrs.cpp @@ -1169,7 +1169,7 @@ static bool isReturnNonNull(Function *F, const SCCNodeSet &SCCNodes, if (auto *Ret = dyn_cast<ReturnInst>(BB.getTerminator())) FlowsToReturn.insert(Ret->getReturnValue()); - auto &DL = F->getParent()->getDataLayout(); + auto &DL = F->getDataLayout(); for (unsigned i = 0; i != FlowsToReturn.size(); ++i) { Value *RetVal = FlowsToReturn[i]; @@ -1186,10 +1186,15 @@ static bool isReturnNonNull(Function *F, const SCCNodeSet &SCCNodes, switch (RVI->getOpcode()) { // Extend the analysis by looking upwards. case Instruction::BitCast: - case Instruction::GetElementPtr: case Instruction::AddrSpaceCast: FlowsToReturn.insert(RVI->getOperand(0)); continue; + case Instruction::GetElementPtr: + if (cast<GEPOperator>(RVI)->isInBounds()) { + FlowsToReturn.insert(RVI->getOperand(0)); + continue; + } + return false; case Instruction::Select: { SelectInst *SI = cast<SelectInst>(RVI); FlowsToReturn.insert(SI->getTrueValue()); @@ -1287,7 +1292,8 @@ static void addNoUndefAttrs(const SCCNodeSet &SCCNodes, // values. for (Function *F : SCCNodes) { // Already noundef. 
- if (F->getAttributes().hasRetAttr(Attribute::NoUndef)) + AttributeList Attrs = F->getAttributes(); + if (Attrs.hasRetAttr(Attribute::NoUndef)) continue; // We can infer and propagate function attributes only when we know that the @@ -1305,10 +1311,30 @@ static void addNoUndefAttrs(const SCCNodeSet &SCCNodes, if (F->getReturnType()->isVoidTy()) continue; - if (all_of(*F, [](BasicBlock &BB) { + const DataLayout &DL = F->getDataLayout(); + if (all_of(*F, [&](BasicBlock &BB) { if (auto *Ret = dyn_cast<ReturnInst>(BB.getTerminator())) { // TODO: perform context-sensitive analysis? - return isGuaranteedNotToBeUndefOrPoison(Ret->getReturnValue()); + Value *RetVal = Ret->getReturnValue(); + if (!isGuaranteedNotToBeUndefOrPoison(RetVal)) + return false; + + // We know the original return value is not poison now, but it + // could still be converted to poison by another return attribute. + // Try to explicitly re-prove the relevant attributes. + if (Attrs.hasRetAttr(Attribute::NonNull) && + !isKnownNonZero(RetVal, DL)) + return false; + + if (MaybeAlign Align = Attrs.getRetAlignment()) + if (RetVal->getPointerAlignment(DL) < *Align) + return false; + + Attribute Attr = Attrs.getRetAttr(Attribute::Range); + if (Attr.isValid() && + !Attr.getRange().contains( + computeConstantRange(RetVal, /*ForSigned=*/false))) + return false; } return true; })) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp index 49b3f2b085e1..038785114a0c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionImport.cpp @@ -125,7 +125,8 @@ static cl::opt<bool> ComputeDead("compute-dead", cl::init(true), cl::Hidden, static cl::opt<bool> EnableImportMetadata( "enable-import-metadata", cl::init(false), cl::Hidden, - cl::desc("Enable import metadata like 'thinlto_src_module'")); + cl::desc("Enable import metadata like 'thinlto_src_module' and " + "'thinlto_src_file'")); /// Summary file to use for function importing when using -function-import from /// the command line. @@ -139,6 +140,17 @@ static cl::opt<bool> ImportAllIndex("import-all-index", cl::desc("Import all external functions in index.")); +/// This is a test-only option. +/// If this option is enabled, the ThinLTO indexing step will import each +/// function declaration as a fallback. In a real build this may increase ram +/// usage of the indexing step unnecessarily. +/// TODO: Implement selective import (based on combined summary analysis) to +/// ensure the imported function has a use case in the postlink pipeline. +static cl::opt<bool> ImportDeclaration( + "import-declaration", cl::init(false), cl::Hidden, + cl::desc("If true, import function declaration as fallback if the function " + "definition is not imported.")); + /// Pass a workload description file - an example of workload would be the /// functions executed to satisfy a RPC request. A workload is defined by a root /// function and the list of functions that are (frequently) needed to satisfy @@ -162,6 +174,10 @@ static cl::opt<std::string> WorkloadDefinitions( "}"), cl::Hidden); +namespace llvm { +extern cl::opt<bool> EnableMemProfContextDisambiguation; +} + // Load lazily a module from \p FileName in \p Context. 
static std::unique_ptr<Module> loadFile(const std::string &FileName, LLVMContext &Context) { @@ -240,8 +256,12 @@ static auto qualifyCalleeCandidates( } /// Given a list of possible callee implementation for a call site, select one -/// that fits the \p Threshold. If none are found, the Reason will give the last -/// reason for the failure (last, in the order of CalleeSummaryList entries). +/// that fits the \p Threshold for function definition import. If none are +/// found, the Reason will give the last reason for the failure (last, in the +/// order of CalleeSummaryList entries). While looking for a callee definition, +/// sets \p TooLargeOrNoInlineSummary to the last seen too-large or noinline +/// candidate; other modules may want to know the function summary or +/// declaration even if a definition is not needed. /// /// FIXME: select "best" instead of first that fits. But what is "best"? /// - The smallest: more likely to be inlined. @@ -254,24 +274,32 @@ static const GlobalValueSummary * selectCallee(const ModuleSummaryIndex &Index, ArrayRef<std::unique_ptr<GlobalValueSummary>> CalleeSummaryList, unsigned Threshold, StringRef CallerModulePath, + const GlobalValueSummary *&TooLargeOrNoInlineSummary, FunctionImporter::ImportFailureReason &Reason) { + // Records the last summary with reason noinline or too-large. + TooLargeOrNoInlineSummary = nullptr; auto QualifiedCandidates = qualifyCalleeCandidates(Index, CalleeSummaryList, CallerModulePath); for (auto QualifiedValue : QualifiedCandidates) { Reason = QualifiedValue.first; + // Skip a summary if its import is not (proved to be) legal. if (Reason != FunctionImporter::ImportFailureReason::None) continue; auto *Summary = cast<FunctionSummary>(QualifiedValue.second->getBaseObject()); + // Don't bother importing the definition if the chance of inlining it is + // not high enough (except under `--force-import-all`). if ((Summary->instCount() > Threshold) && !Summary->fflags().AlwaysInline && !ForceImportAll) { + TooLargeOrNoInlineSummary = Summary; Reason = FunctionImporter::ImportFailureReason::TooLarge; continue; } - // Don't bother importing if we can't inline it anyway. + // Don't bother importing the definition if we can't inline it anyway. if (Summary->fflags().NoInline && !ForceImportAll) { + TooLargeOrNoInlineSummary = Summary; Reason = FunctionImporter::ImportFailureReason::NoInline; continue; } @@ -353,11 +381,20 @@ class GlobalsImporter final { if (!GVS || !Index.canImportGlobalVar(GVS, /* AnalyzeRefs */ true) || LocalNotInModule(GVS)) continue; - auto ILI = ImportList[RefSummary->modulePath()].insert(VI.getGUID()); + + // If there isn't an entry for GUID, insert <GUID, Definition> pair. + // Otherwise, definition should take precedence over declaration. + auto [Iter, Inserted] = + ImportList[RefSummary->modulePath()].try_emplace( + VI.getGUID(), GlobalValueSummary::Definition); // Only update stat and exports if we haven't already imported this // variable. - if (!ILI.second) + if (!Inserted) { + // Set the value to 'std::min(existing-value, new-value)' to make + // sure a definition takes precedence over a declaration. 
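The same precedence rule recurs throughout these import maps; a minimal sketch of it in isolation, with hypothetical stand-ins for GlobalValue::GUID and GlobalValueSummary::ImportKind (the enum orders Definition before Declaration, so merging two decisions reduces to std::min):

#include <algorithm>
#include <cstdint>
#include <map>

enum ImportKind { Definition = 0, Declaration = 1 };

void recordImport(std::map<uint64_t, ImportKind> &ImportList, uint64_t GUID,
                  ImportKind Kind) {
  auto [It, Inserted] = ImportList.try_emplace(GUID, Kind);
  if (!Inserted)
    It->second = std::min(It->second, Kind); // a definition always wins
}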
+ Iter->second = std::min(GlobalValueSummary::Definition, Iter->second); break; + } NumImportedGlobalVarsThinLink++; // Any references made by this variable will be marked exported // later, in ComputeCrossModuleImport, after import decisions are @@ -540,7 +577,8 @@ class WorkloadImportsManager : public ModuleImportsManager { LLVM_DEBUG(dbgs() << "[Workload][Including]" << VI.name() << " from " << ExportingModule << " : " << Function::getGUID(VI.name()) << "\n"); - ImportList[ExportingModule].insert(VI.getGUID()); + ImportList[ExportingModule][VI.getGUID()] = + GlobalValueSummary::Definition; GVI.onImportingSummary(*GVS); if (ExportLists) (*ExportLists)[ExportingModule].insert(VI); @@ -764,9 +802,26 @@ static void computeImportForFunction( } FunctionImporter::ImportFailureReason Reason{}; - CalleeSummary = selectCallee(Index, VI.getSummaryList(), NewThreshold, - Summary.modulePath(), Reason); + + // `SummaryForDeclImport` is an summary eligible for declaration import. + const GlobalValueSummary *SummaryForDeclImport = nullptr; + CalleeSummary = + selectCallee(Index, VI.getSummaryList(), NewThreshold, + Summary.modulePath(), SummaryForDeclImport, Reason); if (!CalleeSummary) { + // There isn't a callee for definition import but one for declaration + // import. + if (ImportDeclaration && SummaryForDeclImport) { + StringRef DeclSourceModule = SummaryForDeclImport->modulePath(); + + // Since definition takes precedence over declaration for the same VI, + // try emplace <VI, declaration> pair without checking insert result. + // If insert doesn't happen, there must be an existing entry keyed by + // VI. Note `ExportLists` only keeps track of exports due to imported + // definitions. + ImportList[DeclSourceModule].try_emplace( + VI.getGUID(), GlobalValueSummary::Declaration); + } // Update with new larger threshold if this was a retry (otherwise // we would have already inserted with NewThreshold above). Also // update failure info if requested. @@ -811,11 +866,15 @@ static void computeImportForFunction( "selectCallee() didn't honor the threshold"); auto ExportModulePath = ResolvedCalleeSummary->modulePath(); - auto ILI = ImportList[ExportModulePath].insert(VI.getGUID()); + + // Try emplace the definition entry, and update stats based on insertion + // status. + auto [Iter, Inserted] = ImportList[ExportModulePath].try_emplace( + VI.getGUID(), GlobalValueSummary::Definition); + // We previously decided to import this GUID definition if it was already // inserted in the set of imports from the exporting module. - bool PreviouslyImported = !ILI.second; - if (!PreviouslyImported) { + if (Inserted || Iter->second == GlobalValueSummary::Declaration) { NumImportedFunctionsThinLink++; if (IsHotCallsite) NumImportedHotFunctionsThinLink++; @@ -823,6 +882,9 @@ static void computeImportForFunction( NumImportedCriticalFunctionsThinLink++; } + if (Iter->second == GlobalValueSummary::Declaration) + Iter->second = GlobalValueSummary::Definition; + // Any calls/references made by this function will be marked exported // later, in ComputeCrossModuleImport, after import decisions are // complete, which is more efficient than adding them here. @@ -933,15 +995,33 @@ static bool isGlobalVarSummary(const ModuleSummaryIndex &Index, return false; } -template <class T> -static unsigned numGlobalVarSummaries(const ModuleSummaryIndex &Index, - T &Cont) { +// Return the number of global variable summaries in ExportSet. 
+static unsigned +numGlobalVarSummaries(const ModuleSummaryIndex &Index, + FunctionImporter::ExportSetTy &ExportSet) { unsigned NumGVS = 0; - for (auto &V : Cont) - if (isGlobalVarSummary(Index, V)) + for (auto &VI : ExportSet) + if (isGlobalVarSummary(Index, VI.getGUID())) ++NumGVS; return NumGVS; } + +// Given ImportMap, return the number of global variable summaries and record +// the number of defined function summaries as output parameter. +static unsigned +numGlobalVarSummaries(const ModuleSummaryIndex &Index, + FunctionImporter::FunctionsToImportTy &ImportMap, + unsigned &DefinedFS) { + unsigned NumGVS = 0; + DefinedFS = 0; + for (auto &[GUID, Type] : ImportMap) { + if (isGlobalVarSummary(Index, GUID)) + ++NumGVS; + else if (Type == GlobalValueSummary::Definition) + ++DefinedFS; + } + return NumGVS; +} #endif #ifndef NDEBUG @@ -949,13 +1029,12 @@ static bool checkVariableImport( const ModuleSummaryIndex &Index, DenseMap<StringRef, FunctionImporter::ImportMapTy> &ImportLists, DenseMap<StringRef, FunctionImporter::ExportSetTy> &ExportLists) { - DenseSet<GlobalValue::GUID> FlattenedImports; for (auto &ImportPerModule : ImportLists) for (auto &ExportPerModule : ImportPerModule.second) - FlattenedImports.insert(ExportPerModule.second.begin(), - ExportPerModule.second.end()); + for (auto &[GUID, Type] : ExportPerModule.second) + FlattenedImports.insert(GUID); // Checks that all GUIDs of read/writeonly vars we see in export lists // are also in the import lists. Otherwise we my face linker undefs, @@ -1007,6 +1086,8 @@ void llvm::ComputeCrossModuleImport( // since we may import the same values multiple times into different modules // during the import computation. for (auto &ELI : ExportLists) { + // `NewExports` tracks the VI that gets exported because the full definition + // of its user/referencer gets exported. FunctionImporter::ExportSetTy NewExports; const auto &DefinedGVSummaries = ModuleToDefinedGVSummaries.lookup(ELI.first); @@ -1039,10 +1120,10 @@ void llvm::ComputeCrossModuleImport( NewExports.insert(Ref); } } - // Prune list computed above to only include values defined in the exporting - // module. We do this after the above insertion since we may hit the same - // ref/call target multiple times in above loop, and it is more efficient to - // avoid a set lookup each time. + // Prune list computed above to only include values defined in the + // exporting module. We do this after the above insertion since we may hit + // the same ref/call target multiple times in above loop, and it is more + // efficient to avoid a set lookup each time. 
for (auto EI = NewExports.begin(); EI != NewExports.end();) { if (!DefinedGVSummaries.count(EI->getGUID())) NewExports.erase(EI++); @@ -1066,9 +1147,13 @@ void llvm::ComputeCrossModuleImport( << " modules.\n"); for (auto &Src : ModuleImports.second) { auto SrcModName = Src.first; - unsigned NumGVSPerMod = numGlobalVarSummaries(Index, Src.second); - LLVM_DEBUG(dbgs() << " - " << Src.second.size() - NumGVSPerMod - << " functions imported from " << SrcModName << "\n"); + unsigned DefinedFS = 0; + unsigned NumGVSPerMod = + numGlobalVarSummaries(Index, Src.second, DefinedFS); + LLVM_DEBUG(dbgs() << " - " << DefinedFS << " function definitions and " + << Src.second.size() - NumGVSPerMod - DefinedFS + << " function declarations imported from " << SrcModName + << "\n"); LLVM_DEBUG(dbgs() << " - " << NumGVSPerMod << " global vars imported from " << SrcModName << "\n"); } @@ -1084,9 +1169,12 @@ static void dumpImportListForModule(const ModuleSummaryIndex &Index, << ImportList.size() << " modules.\n"); for (auto &Src : ImportList) { auto SrcModName = Src.first; - unsigned NumGVSPerMod = numGlobalVarSummaries(Index, Src.second); - LLVM_DEBUG(dbgs() << " - " << Src.second.size() - NumGVSPerMod - << " functions imported from " << SrcModName << "\n"); + unsigned DefinedFS = 0; + unsigned NumGVSPerMod = numGlobalVarSummaries(Index, Src.second, DefinedFS); + LLVM_DEBUG(dbgs() << " - " << DefinedFS << " function definitions and " + << Src.second.size() - DefinedFS - NumGVSPerMod + << " function declarations imported from " << SrcModName + << "\n"); LLVM_DEBUG(dbgs() << " - " << NumGVSPerMod << " vars imported from " << SrcModName << "\n"); } @@ -1144,7 +1232,13 @@ static void ComputeCrossModuleImportForModuleFromIndexForTest( if (Summary->modulePath() == ModulePath) continue; // Add an entry to provoke importing by thinBackend. - ImportList[Summary->modulePath()].insert(GUID); + auto [Iter, Inserted] = ImportList[Summary->modulePath()].try_emplace( + GUID, Summary->importType()); + if (!Inserted) { + // Use 'std::min' to make sure definition (with enum value 0) takes + // precedence over declaration (with enum value 1). + Iter->second = std::min(Iter->second, Summary->importType()); + } } #ifndef NDEBUG dumpImportListForModule(Index, ModulePath, ImportList); @@ -1327,20 +1421,25 @@ void llvm::gatherImportedSummariesForModule( StringRef ModulePath, const DenseMap<StringRef, GVSummaryMapTy> &ModuleToDefinedGVSummaries, const FunctionImporter::ImportMapTy &ImportList, - std::map<std::string, GVSummaryMapTy> &ModuleToSummariesForIndex) { + std::map<std::string, GVSummaryMapTy> &ModuleToSummariesForIndex, + GVSummaryPtrSet &DecSummaries) { // Include all summaries from the importing module. ModuleToSummariesForIndex[std::string(ModulePath)] = ModuleToDefinedGVSummaries.lookup(ModulePath); // Include summaries for imports. 
for (const auto &ILI : ImportList) { auto &SummariesForIndex = ModuleToSummariesForIndex[std::string(ILI.first)]; + const auto &DefinedGVSummaries = ModuleToDefinedGVSummaries.lookup(ILI.first); - for (const auto &GI : ILI.second) { - const auto &DS = DefinedGVSummaries.find(GI); + for (const auto &[GUID, Type] : ILI.second) { + const auto &DS = DefinedGVSummaries.find(GUID); assert(DS != DefinedGVSummaries.end() && "Expected a defined summary for imported global value"); - SummariesForIndex[GI] = DS->second; + if (Type == GlobalValueSummary::Declaration) + DecSummaries.insert(DS->second); + + SummariesForIndex[GUID] = DS->second; } } } @@ -1350,7 +1449,7 @@ std::error_code llvm::EmitImportsFiles( StringRef ModulePath, StringRef OutputFilename, const std::map<std::string, GVSummaryMapTy> &ModuleToSummariesForIndex) { std::error_code EC; - raw_fd_ostream ImportsOS(OutputFilename, EC, sys::fs::OpenFlags::OF_None); + raw_fd_ostream ImportsOS(OutputFilename, EC, sys::fs::OpenFlags::OF_Text); if (EC) return EC; for (const auto &ILI : ModuleToSummariesForIndex) @@ -1612,6 +1711,16 @@ Expected<bool> FunctionImporter::importFunctions( for (const auto &FunctionsToImportPerModule : ImportList) { ModuleNameOrderedList.insert(FunctionsToImportPerModule.first); } + + auto getImportType = [&](const FunctionsToImportTy &GUIDToImportType, + GlobalValue::GUID GUID) + -> std::optional<GlobalValueSummary::ImportKind> { + auto Iter = GUIDToImportType.find(GUID); + if (Iter == GUIDToImportType.end()) + return std::nullopt; + return Iter->second; + }; + for (const auto &Name : ModuleNameOrderedList) { // Get the module for the import const auto &FunctionsToImportPerModule = ImportList.find(Name); @@ -1629,25 +1738,43 @@ Expected<bool> FunctionImporter::importFunctions( return std::move(Err); auto &ImportGUIDs = FunctionsToImportPerModule->second; + // Find the globals to import SetVector<GlobalValue *> GlobalsToImport; for (Function &F : *SrcModule) { if (!F.hasName()) continue; auto GUID = F.getGUID(); - auto Import = ImportGUIDs.count(GUID); - LLVM_DEBUG(dbgs() << (Import ? "Is" : "Not") << " importing function " + auto MaybeImportType = getImportType(ImportGUIDs, GUID); + + bool ImportDefinition = + (MaybeImportType && + (*MaybeImportType == GlobalValueSummary::Definition)); + + LLVM_DEBUG(dbgs() << (MaybeImportType ? "Is" : "Not") + << " importing function" + << (ImportDefinition + ? " definition " + : (MaybeImportType ? " declaration " : " ")) << GUID << " " << F.getName() << " from " << SrcModule->getSourceFileName() << "\n"); - if (Import) { + if (ImportDefinition) { if (Error Err = F.materialize()) return std::move(Err); - if (EnableImportMetadata) { - // Add 'thinlto_src_module' metadata for statistics and debugging. + // MemProf should match function's definition and summary, + // 'thinlto_src_module' is needed. + if (EnableImportMetadata || EnableMemProfContextDisambiguation) { + // Add 'thinlto_src_module' and 'thinlto_src_file' metadata for + // statistics and debugging. 
F.setMetadata( "thinlto_src_module", MDNode::get(DestModule.getContext(), {MDString::get(DestModule.getContext(), + SrcModule->getModuleIdentifier())})); + F.setMetadata( + "thinlto_src_file", + MDNode::get(DestModule.getContext(), + {MDString::get(DestModule.getContext(), SrcModule->getSourceFileName())})); } GlobalsToImport.insert(&F); @@ -1657,11 +1784,20 @@ Expected<bool> FunctionImporter::importFunctions( if (!GV.hasName()) continue; auto GUID = GV.getGUID(); - auto Import = ImportGUIDs.count(GUID); - LLVM_DEBUG(dbgs() << (Import ? "Is" : "Not") << " importing global " + auto MaybeImportType = getImportType(ImportGUIDs, GUID); + + bool ImportDefinition = + (MaybeImportType && + (*MaybeImportType == GlobalValueSummary::Definition)); + + LLVM_DEBUG(dbgs() << (MaybeImportType ? "Is" : "Not") + << " importing global" + << (ImportDefinition + ? " definition " + : (MaybeImportType ? " declaration " : " ")) << GUID << " " << GV.getName() << " from " << SrcModule->getSourceFileName() << "\n"); - if (Import) { + if (ImportDefinition) { if (Error Err = GV.materialize()) return std::move(Err); ImportedGVCount += GlobalsToImport.insert(&GV); @@ -1671,11 +1807,20 @@ Expected<bool> FunctionImporter::importFunctions( if (!GA.hasName() || isa<GlobalIFunc>(GA.getAliaseeObject())) continue; auto GUID = GA.getGUID(); - auto Import = ImportGUIDs.count(GUID); - LLVM_DEBUG(dbgs() << (Import ? "Is" : "Not") << " importing alias " + auto MaybeImportType = getImportType(ImportGUIDs, GUID); + + bool ImportDefinition = + (MaybeImportType && + (*MaybeImportType == GlobalValueSummary::Definition)); + + LLVM_DEBUG(dbgs() << (MaybeImportType ? "Is" : "Not") + << " importing alias" + << (ImportDefinition + ? " definition " + : (MaybeImportType ? " declaration " : " ")) << GUID << " " << GA.getName() << " from " << SrcModule->getSourceFileName() << "\n"); - if (Import) { + if (ImportDefinition) { if (Error Err = GA.materialize()) return std::move(Err); // Import alias as a copy of its aliasee. @@ -1686,12 +1831,18 @@ Expected<bool> FunctionImporter::importFunctions( LLVM_DEBUG(dbgs() << "Is importing aliasee fn " << GO->getGUID() << " " << GO->getName() << " from " << SrcModule->getSourceFileName() << "\n"); - if (EnableImportMetadata) { - // Add 'thinlto_src_module' metadata for statistics and debugging. + if (EnableImportMetadata || EnableMemProfContextDisambiguation) { + // Add 'thinlto_src_module' and 'thinlto_src_file' metadata for + // statistics and debugging. Fn->setMetadata( "thinlto_src_module", MDNode::get(DestModule.getContext(), {MDString::get(DestModule.getContext(), + SrcModule->getModuleIdentifier())})); + Fn->setMetadata( + "thinlto_src_file", + MDNode::get(DestModule.getContext(), + {MDString::get(DestModule.getContext(), SrcModule->getSourceFileName())})); } GlobalsToImport.insert(Fn); @@ -1735,6 +1886,7 @@ Expected<bool> FunctionImporter::importFunctions( NumImportedFunctions += (ImportedCount - ImportedGVCount); NumImportedGlobalVars += ImportedGVCount; + // TODO: Print counters for definitions and declarations in the debugging log. 
LLVM_DEBUG(dbgs() << "Imported " << ImportedCount - ImportedGVCount << " functions for Module " << DestModule.getModuleIdentifier() << "\n"); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp index a4c12006ee24..2d7b7355229e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp @@ -689,7 +689,9 @@ bool FunctionSpecializer::run() { // specialization budget, which is derived from maximum number of // specializations per specialization candidate function. auto CompareScore = [&AllSpecs](unsigned I, unsigned J) { - return AllSpecs[I].Score > AllSpecs[J].Score; + if (AllSpecs[I].Score != AllSpecs[J].Score) + return AllSpecs[I].Score > AllSpecs[J].Score; + return I > J; }; const unsigned NSpecs = std::min(NumCandidates * MaxClones, unsigned(AllSpecs.size())); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp index 951372adcfa9..ab1e41ebf9a9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalOpt.cpp @@ -87,8 +87,11 @@ STATISTIC(NumNestRemoved , "Number of nest attributes removed"); STATISTIC(NumAliasesResolved, "Number of global aliases resolved"); STATISTIC(NumAliasesRemoved, "Number of global aliases eliminated"); STATISTIC(NumCXXDtorsRemoved, "Number of global C++ destructors removed"); +STATISTIC(NumAtExitRemoved, "Number of atexit handlers removed"); STATISTIC(NumInternalFunc, "Number of internal functions"); STATISTIC(NumColdCC, "Number of functions marked coldcc"); +STATISTIC(NumIFuncsResolved, "Number of statically resolved IFuncs"); +STATISTIC(NumIFuncsDeleted, "Number of IFuncs removed"); static cl::opt<bool> EnableColdCCStressTest("enable-coldcc-stress-test", @@ -294,7 +297,7 @@ static bool CleanupConstantGlobalUsers(GlobalVariable *GV, // A load from a uniform value is always the same, regardless of any // applied offset. 
Type *Ty = LI->getType(); - if (Constant *Res = ConstantFoldLoadFromUniformValue(Init, Ty)) { + if (Constant *Res = ConstantFoldLoadFromUniformValue(Init, Ty, DL)) { LI->replaceAllUsesWith(Res); EraseFromParent(LI); continue; @@ -304,6 +307,10 @@ static bool CleanupConstantGlobalUsers(GlobalVariable *GV, APInt Offset(DL.getIndexTypeSizeInBits(PtrOp->getType()), 0); PtrOp = PtrOp->stripAndAccumulateConstantOffsets( DL, Offset, /* AllowNonInbounds */ true); + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(PtrOp)) { + if (II->getIntrinsicID() == Intrinsic::threadlocal_address) + PtrOp = II->getArgOperand(0); + } if (PtrOp == GV) { if (auto *Value = ConstantFoldLoadFromConst(Init, Ty, Offset, DL)) { LI->replaceAllUsesWith(Value); @@ -316,6 +323,9 @@ static bool CleanupConstantGlobalUsers(GlobalVariable *GV, } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(U)) { // memset/cpy/mv if (getUnderlyingObject(MI->getRawDest()) == GV) EraseFromParent(MI); + } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U)) { + if (II->getIntrinsicID() == Intrinsic::threadlocal_address) + append_range(WorkList, II->users()); } } @@ -951,7 +961,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI, GV->getContext(), !isa<ConstantPointerNull>(SI->getValueOperand())), InitBool, false, Align(1), SI->getOrdering(), - SI->getSyncScopeID(), SI); + SI->getSyncScopeID(), SI->getIterator()); SI->eraseFromParent(); continue; } @@ -968,7 +978,8 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI, // Replace the cmp X, 0 with a use of the bool value. Value *LV = new LoadInst(InitBool->getValueType(), InitBool, InitBool->getName() + ".val", false, Align(1), - LI->getOrdering(), LI->getSyncScopeID(), LI); + LI->getOrdering(), LI->getSyncScopeID(), + LI->getIterator()); InitBoolUsed = true; switch (ICI->getPredicate()) { default: llvm_unreachable("Unknown ICmp Predicate!"); @@ -980,7 +991,7 @@ OptimizeGlobalAddressOfAllocation(GlobalVariable *GV, CallInst *CI, break; case ICmpInst::ICMP_ULE: case ICmpInst::ICMP_EQ: - LV = BinaryOperator::CreateNot(LV, "notinit", ICI); + LV = BinaryOperator::CreateNot(LV, "notinit", ICI->getIterator()); break; case ICmpInst::ICMP_NE: case ICmpInst::ICMP_UGT: @@ -1202,7 +1213,7 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { for(auto *GVe : GVs){ DIGlobalVariable *DGV = GVe->getVariable(); DIExpression *E = GVe->getExpression(); - const DataLayout &DL = GV->getParent()->getDataLayout(); + const DataLayout &DL = GV->getDataLayout(); unsigned SizeInOctets = DL.getTypeAllocSizeInBits(NewGV->getValueType()) / 8; @@ -1258,9 +1269,10 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { if (LoadInst *LI = dyn_cast<LoadInst>(StoredVal)) { assert(LI->getOperand(0) == GV && "Not a copy!"); // Insert a new load, to preserve the saved value. 
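For context, a rough sketch (hypothetical names) of the overall rewrite this hunk belongs to: a global that only ever holds its initializer or one other constant is stored as a single flag, and each load is rebuilt as a select of those two values (or a zext in the one/zero case).

// Before: 'Mode' is only ever 0 (its initializer) or 7.
static int Mode = 0;
void set() { Mode = 7; }
int get() { return Mode; }

// After the shrink-to-boolean rewrite, conceptually:
static bool ModeFlag = false;
void setShrunk() { ModeFlag = true; }
int getShrunk() { return ModeFlag ? 7 : 0; } // select of OtherVal / InitVal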
- StoreVal = new LoadInst(NewGV->getValueType(), NewGV, - LI->getName() + ".b", false, Align(1), - LI->getOrdering(), LI->getSyncScopeID(), LI); + StoreVal = + new LoadInst(NewGV->getValueType(), NewGV, LI->getName() + ".b", + false, Align(1), LI->getOrdering(), + LI->getSyncScopeID(), LI->getIterator()); } else { assert((isa<CastInst>(StoredVal) || isa<SelectInst>(StoredVal)) && "This is not a form that we understand!"); @@ -1270,19 +1282,19 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { } StoreInst *NSI = new StoreInst(StoreVal, NewGV, false, Align(1), SI->getOrdering(), - SI->getSyncScopeID(), SI); + SI->getSyncScopeID(), SI->getIterator()); NSI->setDebugLoc(SI->getDebugLoc()); } else { // Change the load into a load of bool then a select. LoadInst *LI = cast<LoadInst>(UI); - LoadInst *NLI = new LoadInst(NewGV->getValueType(), NewGV, - LI->getName() + ".b", false, Align(1), - LI->getOrdering(), LI->getSyncScopeID(), LI); + LoadInst *NLI = new LoadInst( + NewGV->getValueType(), NewGV, LI->getName() + ".b", false, Align(1), + LI->getOrdering(), LI->getSyncScopeID(), LI->getIterator()); Instruction *NSI; if (IsOneZero) - NSI = new ZExtInst(NLI, LI->getType(), "", LI); + NSI = new ZExtInst(NLI, LI->getType(), "", LI->getIterator()); else - NSI = SelectInst::Create(NLI, OtherVal, InitVal, "", LI); + NSI = SelectInst::Create(NLI, OtherVal, InitVal, "", LI->getIterator()); NSI->takeName(LI); // Since LI is split into two instructions, NLI and NSI both inherit the // same DebugLoc @@ -1344,7 +1356,7 @@ static bool isPointerValueDeadOnEntryToFunction( // // We don't do an exhaustive search for memory operations - simply look // through bitcasts as they're quite common and benign. - const DataLayout &DL = GV->getParent()->getDataLayout(); + const DataLayout &DL = GV->getDataLayout(); SmallVector<LoadInst *, 4> Loads; SmallVector<StoreInst *, 4> Stores; for (auto *U : GV->users()) { @@ -1440,7 +1452,7 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, function_ref<TargetTransformInfo &(Function &)> GetTTI, function_ref<TargetLibraryInfo &(Function &)> GetTLI, function_ref<DominatorTree &(Function &)> LookupDomTree) { - auto &DL = GV->getParent()->getDataLayout(); + auto &DL = GV->getDataLayout(); // If this is a first class global and has only one accessing function and // this function is non-recursive, we replace the global with a local alloca // in this function. 
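A minimal sketch, with hypothetical names, of the localization described above: an internal global whose incoming value is dead on entry to its single, non-recursive accessor can be demoted to a stack slot in that function.

// Before: internal global, accessed only by compute(), always written before read.
static int Scratch;

int compute(int n) {
  Scratch = 0;
  for (int i = 0; i < n; ++i)
    Scratch += i;
  return Scratch;
}
// Afterwards 'Scratch' can live as an alloca at the start of compute() and the
// global itself can be deleted.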
@@ -1457,17 +1469,17 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, GS.AccessingFunction->doesNotRecurse() && isPointerValueDeadOnEntryToFunction(GS.AccessingFunction, GV, LookupDomTree)) { - const DataLayout &DL = GV->getParent()->getDataLayout(); + const DataLayout &DL = GV->getDataLayout(); LLVM_DEBUG(dbgs() << "LOCALIZING GLOBAL: " << *GV << "\n"); - Instruction &FirstI = const_cast<Instruction&>(*GS.AccessingFunction - ->getEntryBlock().begin()); + BasicBlock::iterator FirstI = + GS.AccessingFunction->getEntryBlock().begin().getNonConst(); Type *ElemTy = GV->getValueType(); // FIXME: Pass Global's alignment when globals have alignment - AllocaInst *Alloca = new AllocaInst(ElemTy, DL.getAllocaAddrSpace(), nullptr, - GV->getName(), &FirstI); + AllocaInst *Alloca = new AllocaInst(ElemTy, DL.getAllocaAddrSpace(), + nullptr, GV->getName(), FirstI); if (!isa<UndefValue>(GV->getInitializer())) - new StoreInst(GV->getInitializer(), Alloca, &FirstI); + new StoreInst(GV->getInitializer(), Alloca, FirstI); GV->replaceAllUsesWith(Alloca); GV->eraseFromParent(); @@ -1528,7 +1540,7 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS, ++NumMarked; } if (!GV->getInitializer()->getType()->isSingleValueType()) { - const DataLayout &DL = GV->getParent()->getDataLayout(); + const DataLayout &DL = GV->getDataLayout(); if (SRAGlobal(GV, DL)) return true; } @@ -1857,7 +1869,7 @@ static void RemovePreallocated(Function *F) { assert((isa<CallInst>(CB) || isa<InvokeInst>(CB)) && "Unknown indirect call type"); - CallBase *NewCB = CallBase::Create(CB, OpBundles, CB); + CallBase *NewCB = CallBase::Create(CB, OpBundles, CB->getIterator()); CB->replaceAllUsesWith(NewCB); NewCB->takeName(CB); CB->eraseFromParent(); @@ -2212,6 +2224,9 @@ static bool mayHaveOtherReferences(GlobalValue &GV, const LLVMUsed &U) { static bool hasUsesToReplace(GlobalAlias &GA, const LLVMUsed &U, bool &RenameTarget) { + if (GA.isWeakForLinker()) + return false; + RenameTarget = false; bool Ret = false; if (hasUseOtherThanLLVMUsed(GA, U)) @@ -2317,18 +2332,19 @@ OptimizeGlobalAliases(Module &M, } static Function * -FindCXAAtExit(Module &M, function_ref<TargetLibraryInfo &(Function &)> GetTLI) { +FindAtExitLibFunc(Module &M, + function_ref<TargetLibraryInfo &(Function &)> GetTLI, + LibFunc Func) { // Hack to get a default TLI before we have actual Function. auto FuncIter = M.begin(); if (FuncIter == M.end()) return nullptr; auto *TLI = &GetTLI(*FuncIter); - LibFunc F = LibFunc_cxa_atexit; - if (!TLI->has(F)) + if (!TLI->has(Func)) return nullptr; - Function *Fn = M.getFunction(TLI->getName(F)); + Function *Fn = M.getFunction(TLI->getName(Func)); if (!Fn) return nullptr; @@ -2336,17 +2352,18 @@ FindCXAAtExit(Module &M, function_ref<TargetLibraryInfo &(Function &)> GetTLI) { TLI = &GetTLI(*Fn); // Make sure that the function has the correct prototype. - if (!TLI->getLibFunc(*Fn, F) || F != LibFunc_cxa_atexit) + LibFunc F; + if (!TLI->getLibFunc(*Fn, F) || F != Func) return nullptr; return Fn; } -/// Returns whether the given function is an empty C++ destructor and can -/// therefore be eliminated. -/// Note that we assume that other optimization passes have already simplified -/// the code so we simply check for 'ret'. -static bool cxxDtorIsEmpty(const Function &Fn) { +/// Returns whether the given function is an empty C++ destructor or atexit +/// handler and can therefore be eliminated. 
Note that we assume that other +/// optimization passes have already simplified the code so we simply check for +/// 'ret'. +static bool IsEmptyAtExitFunction(const Function &Fn) { // FIXME: We could eliminate C++ destructors if they're readonly/readnone and // nounwind, but that doesn't seem worth doing. if (Fn.isDeclaration()) @@ -2362,7 +2379,7 @@ static bool cxxDtorIsEmpty(const Function &Fn) { return false; } -static bool OptimizeEmptyGlobalCXXDtors(Function *CXAAtExitFn) { +static bool OptimizeEmptyGlobalAtExitDtors(Function *CXAAtExitFn, bool isCXX) { /// Itanium C++ ABI p3.3.5: /// /// After constructing a global (or local static) object, that will require @@ -2375,8 +2392,8 @@ static bool OptimizeEmptyGlobalCXXDtors(Function *CXAAtExitFn) { /// registered before this one. It returns zero if registration is /// successful, nonzero on failure. - // This pass will look for calls to __cxa_atexit where the function is trivial - // and remove them. + // This pass will look for calls to __cxa_atexit or atexit where the function + // is trivial and remove them. bool Changed = false; for (User *U : llvm::make_early_inc_range(CXAAtExitFn->users())) { @@ -2389,14 +2406,17 @@ static bool OptimizeEmptyGlobalCXXDtors(Function *CXAAtExitFn) { Function *DtorFn = dyn_cast<Function>(CI->getArgOperand(0)->stripPointerCasts()); - if (!DtorFn || !cxxDtorIsEmpty(*DtorFn)) + if (!DtorFn || !IsEmptyAtExitFunction(*DtorFn)) continue; // Just remove the call. CI->replaceAllUsesWith(Constant::getNullValue(CI->getType())); CI->eraseFromParent(); - ++NumCXXDtorsRemoved; + if (isCXX) + ++NumCXXDtorsRemoved; + else + ++NumAtExitRemoved; Changed |= true; } @@ -2404,6 +2424,62 @@ static bool OptimizeEmptyGlobalCXXDtors(Function *CXAAtExitFn) { return Changed; } +static Function *hasSideeffectFreeStaticResolution(GlobalIFunc &IF) { + if (IF.isInterposable()) + return nullptr; + + Function *Resolver = IF.getResolverFunction(); + if (!Resolver) + return nullptr; + + if (Resolver->isInterposable()) + return nullptr; + + // Only handle functions that have been optimized into a single basic block. + auto It = Resolver->begin(); + if (++It != Resolver->end()) + return nullptr; + + BasicBlock &BB = Resolver->getEntryBlock(); + + if (any_of(BB, [](Instruction &I) { return I.mayHaveSideEffects(); })) + return nullptr; + + auto *Ret = dyn_cast<ReturnInst>(BB.getTerminator()); + if (!Ret) + return nullptr; + + return dyn_cast<Function>(Ret->getReturnValue()); +} + +/// Find IFuncs that have resolvers that always point at the same statically +/// known callee, and replace their callers with a direct call. 
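A rough sketch of the pattern this targets, written with GCC/Clang ifunc syntax and hypothetical names; when the resolver is a single side-effect-free block that always returns the same function, every caller can be redirected to that function.

static int fast_add(int x) { return x + 1; }

// Resolver: one block, no side effects, fixed return value. extern "C" keeps
// the symbol name matching the attribute string in this C++ sketch.
extern "C" void *resolve_add(void) {
  return reinterpret_cast<void *>(&fast_add);
}

int add(int) __attribute__((ifunc("resolve_add")));

int caller(int x) {
  return add(x); // with a statically known resolver, effectively fast_add(x)
}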
+static bool OptimizeStaticIFuncs(Module &M) { + bool Changed = false; + for (GlobalIFunc &IF : M.ifuncs()) + if (Function *Callee = hasSideeffectFreeStaticResolution(IF)) + if (!IF.use_empty() && + (!Callee->isDeclaration() || + none_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))) { + IF.replaceAllUsesWith(Callee); + NumIFuncsResolved++; + Changed = true; + } + return Changed; +} + +static bool +DeleteDeadIFuncs(Module &M, + SmallPtrSetImpl<const Comdat *> &NotDiscardableComdats) { + bool Changed = false; + for (GlobalIFunc &IF : make_early_inc_range(M.ifuncs())) + if (deleteIfDead(IF, NotDiscardableComdats)) { + NumIFuncsDeleted++; + Changed = true; + } + return Changed; +} + static bool optimizeGlobalsInModule(Module &M, const DataLayout &DL, function_ref<TargetLibraryInfo &(Function &)> GetTLI, @@ -2460,9 +2536,18 @@ optimizeGlobalsInModule(Module &M, const DataLayout &DL, // Try to remove trivial global destructors if they are not removed // already. - Function *CXAAtExitFn = FindCXAAtExit(M, GetTLI); - if (CXAAtExitFn) - LocalChange |= OptimizeEmptyGlobalCXXDtors(CXAAtExitFn); + if (Function *CXAAtExitFn = + FindAtExitLibFunc(M, GetTLI, LibFunc_cxa_atexit)) + LocalChange |= OptimizeEmptyGlobalAtExitDtors(CXAAtExitFn, true); + + if (Function *AtExitFn = FindAtExitLibFunc(M, GetTLI, LibFunc_atexit)) + LocalChange |= OptimizeEmptyGlobalAtExitDtors(AtExitFn, false); + + // Optimize IFuncs whose callee's are statically known. + LocalChange |= OptimizeStaticIFuncs(M); + + // Remove any IFuncs that are now dead. + LocalChange |= DeleteDeadIFuncs(M, NotDiscardableComdats); Changed |= LocalChange; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalSplit.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalSplit.cpp index 84e9c219f935..fd49b745fd75 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalSplit.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/GlobalSplit.cpp @@ -47,28 +47,66 @@ static bool splitGlobal(GlobalVariable &GV) { if (!Init) return false; - // Verify that each user of the global is an inrange getelementptr constant. - // From this it follows that any loads from or stores to that global must use - // a pointer derived from an inrange getelementptr constant, which is - // sufficient to allow us to apply the splitting transform. + const DataLayout &DL = GV.getDataLayout(); + const StructLayout *SL = DL.getStructLayout(Init->getType()); + ArrayRef<TypeSize> MemberOffsets = SL->getMemberOffsets(); + unsigned IndexWidth = DL.getIndexTypeSizeInBits(GV.getType()); + + // Verify that each user of the global is an inrange getelementptr constant, + // and collect information on how it relates to the global. 
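A rough picture of the transformation, in source terms with hypothetical names: a combined constant global whose users each stay, via inrange, within one member is replaced by a separate global per member, and the users are rebased onto the matching piece with member-relative offsets.

// Before splitting: one combined constant; each user addresses a single member.
struct Combined {
  int A[3];
  int B[2];
};
static const Combined G = {{1, 2, 3}, {4, 5}};
const int *UseA = &G.A[1]; // confined to member A
const int *UseB = &G.B[0]; // confined to member B

// After splitting, conceptually: one global per member, users rewritten.
static const int G_A[3] = {1, 2, 3};
static const int G_B[2] = {4, 5};
const int *UseA2 = &G_A[1];
const int *UseB2 = &G_B[0];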
+ struct GEPInfo { + GEPOperator *GEP; + unsigned MemberIndex; + APInt MemberRelativeOffset; + + GEPInfo(GEPOperator *GEP, unsigned MemberIndex, APInt MemberRelativeOffset) + : GEP(GEP), MemberIndex(MemberIndex), + MemberRelativeOffset(std::move(MemberRelativeOffset)) {} + }; + SmallVector<GEPInfo> Infos; for (User *U : GV.users()) { - if (!isa<Constant>(U)) + auto *GEP = dyn_cast<GEPOperator>(U); + if (!GEP) return false; - auto *GEP = dyn_cast<GEPOperator>(U); - if (!GEP || !GEP->getInRangeIndex() || *GEP->getInRangeIndex() != 1 || - !isa<ConstantInt>(GEP->getOperand(1)) || - !cast<ConstantInt>(GEP->getOperand(1))->isZero() || - !isa<ConstantInt>(GEP->getOperand(2))) + std::optional<ConstantRange> InRange = GEP->getInRange(); + if (!InRange) + return false; + + APInt Offset(IndexWidth, 0); + if (!GEP->accumulateConstantOffset(DL, Offset)) + return false; + + // Determine source-relative inrange. + ConstantRange SrcInRange = InRange->sextOrTrunc(IndexWidth).add(Offset); + + // Check that the GEP offset is in the range (treating upper bound as + // inclusive here). + if (!SrcInRange.contains(Offset) && SrcInRange.getUpper() != Offset) + return false; + + // Find which struct member the range corresponds to. + if (SrcInRange.getLower().uge(SL->getSizeInBytes())) return false; + + unsigned MemberIndex = + SL->getElementContainingOffset(SrcInRange.getLower().getZExtValue()); + TypeSize MemberStart = MemberOffsets[MemberIndex]; + TypeSize MemberEnd = MemberIndex == MemberOffsets.size() - 1 + ? SL->getSizeInBytes() + : MemberOffsets[MemberIndex + 1]; + + // Verify that the range matches that struct member. + if (SrcInRange.getLower() != MemberStart || + SrcInRange.getUpper() != MemberEnd) + return false; + + Infos.emplace_back(GEP, MemberIndex, Offset - MemberStart); } SmallVector<MDNode *, 2> Types; GV.getMetadata(LLVMContext::MD_type, Types); - const DataLayout &DL = GV.getParent()->getDataLayout(); - const StructLayout *SL = DL.getStructLayout(Init->getType()); - IntegerType *Int32Ty = Type::getInt32Ty(GV.getContext()); std::vector<GlobalVariable *> SplitGlobals(Init->getNumOperands()); @@ -114,21 +152,13 @@ static bool splitGlobal(GlobalVariable &GV) { SplitGV->setVCallVisibilityMetadata(GV.getVCallVisibility()); } - for (User *U : GV.users()) { - auto *GEP = cast<GEPOperator>(U); - unsigned I = cast<ConstantInt>(GEP->getOperand(2))->getZExtValue(); - if (I >= SplitGlobals.size()) - continue; - - SmallVector<Value *, 4> Ops; - Ops.push_back(ConstantInt::get(Int32Ty, 0)); - for (unsigned I = 3; I != GEP->getNumOperands(); ++I) - Ops.push_back(GEP->getOperand(I)); - + for (const GEPInfo &Info : Infos) { + assert(Info.MemberIndex < SplitGlobals.size() && "Invalid member"); auto *NewGEP = ConstantExpr::getGetElementPtr( - SplitGlobals[I]->getInitializer()->getType(), SplitGlobals[I], Ops, - GEP->isInBounds()); - GEP->replaceAllUsesWith(NewGEP); + Type::getInt8Ty(GV.getContext()), SplitGlobals[Info.MemberIndex], + ConstantInt::get(GV.getContext(), Info.MemberRelativeOffset), + Info.GEP->isInBounds()); + Info.GEP->replaceAllUsesWith(NewGEP); } // Finally, remove the original global. 
Any remaining uses refer to invalid diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp index fabb3c5fb921..2ec5da488683 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/HotColdSplitting.cpp @@ -39,6 +39,7 @@ #include "llvm/IR/CFG.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/EHPersonalities.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" @@ -168,10 +169,24 @@ static bool mayExtractBlock(const BasicBlock &BB) { // // Resumes that are not reachable from a cleanup landing pad are considered to // be unreachable. It’s not safe to split them out either. + if (BB.hasAddressTaken() || BB.isEHPad()) return false; auto Term = BB.getTerminator(); - return !isa<InvokeInst>(Term) && !isa<ResumeInst>(Term); + if (isa<InvokeInst>(Term) || isa<ResumeInst>(Term)) + return false; + + // Do not outline basic blocks that have token type instructions. e.g., + // exception: + // %0 = cleanuppad within none [] + // call void @"?terminate@@YAXXZ"() [ "funclet"(token %0) ] + // br label %continue-exception + if (llvm::any_of( + BB, [](const Instruction &I) { return I.getType()->isTokenTy(); })) { + return false; + } + + return true; } /// Mark \p F cold. Based on this assumption, also optimize it for minimum size. @@ -215,15 +230,10 @@ bool HotColdSplitting::isFunctionCold(const Function &F) const { return false; } -bool HotColdSplitting::isBasicBlockCold(BasicBlock *BB, - BranchProbability ColdProbThresh, - SmallPtrSetImpl<BasicBlock *> &ColdBlocks, - SmallPtrSetImpl<BasicBlock *> &AnnotatedColdBlocks, - BlockFrequencyInfo *BFI) const { - // This block is already part of some outlining region. - if (ColdBlocks.count(BB)) - return true; - +bool HotColdSplitting::isBasicBlockCold( + BasicBlock *BB, BranchProbability ColdProbThresh, + SmallPtrSetImpl<BasicBlock *> &AnnotatedColdBlocks, + BlockFrequencyInfo *BFI) const { if (BFI) { if (PSI->isColdBlock(BB, BFI)) return true; @@ -263,6 +273,11 @@ bool HotColdSplitting::shouldOutlineFrom(const Function &F) const { F.hasFnAttribute(Attribute::SanitizeMemory)) return false; + // Do not outline scoped EH personality functions. + if (F.hasPersonalityFn()) + if (isScopedEHPersonality(classifyEHPersonality(F.getPersonalityFn()))) + return false; + return true; } @@ -372,18 +387,12 @@ static int getOutliningPenalty(ArrayRef<BasicBlock *> Region, return Penalty; } -Function *HotColdSplitting::extractColdRegion( - const BlockSequence &Region, const CodeExtractorAnalysisCache &CEAC, - DominatorTree &DT, BlockFrequencyInfo *BFI, TargetTransformInfo &TTI, - OptimizationRemarkEmitter &ORE, AssumptionCache *AC, unsigned Count) { +// Determine if it is beneficial to split the \p Region. +bool HotColdSplitting::isSplittingBeneficial(CodeExtractor &CE, + const BlockSequence &Region, + TargetTransformInfo &TTI) { assert(!Region.empty()); - // TODO: Pass BFI and BPI to update profile information. - CodeExtractor CE(Region, &DT, /* AggregateArgs */ false, /* BFI */ nullptr, - /* BPI */ nullptr, AC, /* AllowVarArgs */ false, - /* AllowAlloca */ false, /* AllocaBlock */ nullptr, - /* Suffix */ "cold." + std::to_string(Count)); - // Perform a simple cost/benefit analysis to decide whether or not to permit // splitting. 
SetVector<Value *> Inputs, Outputs, Sinks; @@ -394,9 +403,18 @@ Function *HotColdSplitting::extractColdRegion( LLVM_DEBUG(dbgs() << "Split profitability: benefit = " << OutliningBenefit << ", penalty = " << OutliningPenalty << "\n"); if (!OutliningBenefit.isValid() || OutliningBenefit <= OutliningPenalty) - return nullptr; + return false; - Function *OrigF = Region[0]->getParent(); + return true; +} + +// Split the single \p EntryPoint cold region. \p CE is the region code +// extractor. +Function *HotColdSplitting::extractColdRegion( + BasicBlock &EntryPoint, CodeExtractor &CE, + const CodeExtractorAnalysisCache &CEAC, BlockFrequencyInfo *BFI, + TargetTransformInfo &TTI, OptimizationRemarkEmitter &ORE) { + Function *OrigF = EntryPoint.getParent(); if (Function *OutF = CE.extractCodeRegion(CEAC)) { User *U = *OutF->user_begin(); CallInst *CI = cast<CallInst>(U); @@ -419,7 +437,7 @@ Function *HotColdSplitting::extractColdRegion( LLVM_DEBUG(llvm::dbgs() << "Outlined Region: " << *OutF); ORE.emit([&]() { return OptimizationRemark(DEBUG_TYPE, "HotColdSplit", - &*Region[0]->begin()) + &*EntryPoint.begin()) << ore::NV("Original", OrigF) << " split cold code into " << ore::NV("Split", OutF); }); @@ -428,9 +446,9 @@ Function *HotColdSplitting::extractColdRegion( ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "ExtractFailed", - &*Region[0]->begin()) + &*EntryPoint.begin()) << "Failed to extract region at block " - << ore::NV("Block", Region.front()); + << ore::NV("Block", &EntryPoint); }); return nullptr; } @@ -620,16 +638,18 @@ public: } // namespace bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { - bool Changed = false; - - // The set of cold blocks. + // The set of cold blocks outlined. SmallPtrSet<BasicBlock *, 4> ColdBlocks; + // The set of cold blocks cannot be outlined. + SmallPtrSet<BasicBlock *, 4> CannotBeOutlinedColdBlocks; + // Set of cold blocks obtained with RPOT. SmallPtrSet<BasicBlock *, 4> AnnotatedColdBlocks; - // The worklist of non-intersecting regions left to outline. - SmallVector<OutliningRegion, 2> OutliningWorklist; + // The worklist of non-intersecting regions left to outline. The first member + // of the pair is the entry point into the region to be outlined. + SmallVector<std::pair<BasicBlock *, CodeExtractor>, 2> OutliningWorklist; // Set up an RPO traversal. Experimentally, this performs better (outlines // more) than a PO traversal, because we prevent region overlap by keeping @@ -655,10 +675,18 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { if (ColdBranchProbDenom.getNumOccurrences()) ColdProbThresh = BranchProbability(1, ColdBranchProbDenom.getValue()); + unsigned OutlinedFunctionID = 1; // Find all cold regions. for (BasicBlock *BB : RPOT) { - if (!isBasicBlockCold(BB, ColdProbThresh, ColdBlocks, AnnotatedColdBlocks, - BFI)) + // This block is already part of some outlining region. + if (ColdBlocks.count(BB)) + continue; + + // This block is already part of some region cannot be outlined. + if (CannotBeOutlinedColdBlocks.count(BB)) + continue; + + if (!isBasicBlockCold(BB, ColdProbThresh, AnnotatedColdBlocks, BFI)) continue; LLVM_DEBUG({ @@ -681,50 +709,69 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { return markFunctionCold(F); } - // If this outlining region intersects with another, drop the new region. 
- // - // TODO: It's theoretically possible to outline more by only keeping the - // largest region which contains a block, but the extra bookkeeping to do - // this is tricky/expensive. - bool RegionsOverlap = any_of(Region.blocks(), [&](const BlockTy &Block) { - return !ColdBlocks.insert(Block.first).second; - }); - if (RegionsOverlap) - continue; + do { + BlockSequence SubRegion = Region.takeSingleEntrySubRegion(*DT); + LLVM_DEBUG({ + dbgs() << "Hot/cold splitting attempting to outline these blocks:\n"; + for (BasicBlock *BB : SubRegion) + BB->dump(); + }); + + // TODO: Pass BFI and BPI to update profile information. + CodeExtractor CE( + SubRegion, &*DT, /* AggregateArgs */ false, /* BFI */ nullptr, + /* BPI */ nullptr, AC, /* AllowVarArgs */ false, + /* AllowAlloca */ false, /* AllocaBlock */ nullptr, + /* Suffix */ "cold." + std::to_string(OutlinedFunctionID)); + + if (CE.isEligible() && isSplittingBeneficial(CE, SubRegion, TTI) && + // If this outlining region intersects with another, drop the new + // region. + // + // TODO: It's theoretically possible to outline more by only keeping + // the largest region which contains a block, but the extra + // bookkeeping to do this is tricky/expensive. + none_of(SubRegion, [&](BasicBlock *Block) { + return ColdBlocks.contains(Block); + })) { + ColdBlocks.insert(SubRegion.begin(), SubRegion.end()); + + LLVM_DEBUG({ + for (auto *Block : SubRegion) + dbgs() << " contains cold block:" << Block->getName() << "\n"; + }); + + OutliningWorklist.emplace_back( + std::make_pair(SubRegion[0], std::move(CE))); + ++OutlinedFunctionID; + } else { + // The cold block region cannot be outlined. + for (auto *Block : SubRegion) + if ((DT->dominates(BB, Block) && PDT->dominates(Block, BB)) || + (PDT->dominates(BB, Block) && DT->dominates(Block, BB))) + // Will skip this cold block in the loop to save the compile time + CannotBeOutlinedColdBlocks.insert(Block); + } + } while (!Region.empty()); - OutliningWorklist.emplace_back(std::move(Region)); ++NumColdRegionsFound; } } if (OutliningWorklist.empty()) - return Changed; + return false; // Outline single-entry cold regions, splitting up larger regions as needed. - unsigned OutlinedFunctionID = 1; // Cache and recycle the CodeExtractor analysis to avoid O(n^2) compile-time. 
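The restructuring above separates deciding which cold sub-regions to outline from actually extracting them: accepted candidates are queued as (entry block, extractor) pairs, and extraction runs afterwards against one shared analysis cache, which is what the "avoid O(n^2) compile-time" comment refers to. A minimal standalone sketch of that two-phase pattern, with a stand-in cache rather than the real CodeExtractorAnalysisCache:

#include <string>
#include <utility>
#include <vector>

// Stand-ins: a per-function analysis that is expensive to build, and an
// extractor object configured for one cold sub-region.
struct AnalysisCache { /* built once per function */ };
struct Extractor { std::string Suffix; };

// Phase 1: decide. Each accepted sub-region is queued as a pair of its entry
// block and the extractor configured for it ("cold.1", "cold.2", ...).
std::vector<std::pair<std::string, Extractor>>
collectColdRegions(const std::vector<std::string> &ColdEntries) {
  std::vector<std::pair<std::string, Extractor>> Worklist;
  unsigned ID = 1;
  for (const auto &Entry : ColdEntries)
    Worklist.emplace_back(Entry, Extractor{"cold." + std::to_string(ID++)});
  return Worklist;
}

// Phase 2: extract. The analysis is built once and reused for every queued
// region, keeping the pass linear instead of quadratic in the region count.
void extractAll(const std::vector<std::pair<std::string, Extractor>> &Worklist) {
  AnalysisCache CEAC; // one shared cache, as in the hunk above
  for (const auto &[Entry, CE] : Worklist) {
    (void)Entry; (void)CE; (void)CEAC;
    // ... extract the region rooted at Entry using CE and the cached analysis ...
  }
}
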
CodeExtractorAnalysisCache CEAC(F); - do { - OutliningRegion Region = OutliningWorklist.pop_back_val(); - assert(!Region.empty() && "Empty outlining region in worklist"); - do { - BlockSequence SubRegion = Region.takeSingleEntrySubRegion(*DT); - LLVM_DEBUG({ - dbgs() << "Hot/cold splitting attempting to outline these blocks:\n"; - for (BasicBlock *BB : SubRegion) - BB->dump(); - }); - - Function *Outlined = extractColdRegion(SubRegion, CEAC, *DT, BFI, TTI, - ORE, AC, OutlinedFunctionID); - if (Outlined) { - ++OutlinedFunctionID; - Changed = true; - } - } while (!Region.empty()); - } while (!OutliningWorklist.empty()); + for (auto &BCE : OutliningWorklist) { + Function *Outlined = + extractColdRegion(*BCE.first, BCE.second, CEAC, BFI, TTI, ORE); + assert(Outlined && "Should be outlined"); + (void)Outlined; + } - return Changed; + return true; } bool HotColdSplitting::run(Module &M) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp index 8e6d0e814372..96c803c0186e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/IROutliner.cpp @@ -678,7 +678,7 @@ Function *IROutliner::createFunction(Module &M, OutlinableGroup &Group, Mg.getNameWithPrefix(MangledNameStream, F, false); DISubprogram *OutlinedSP = DB.createFunction( - Unit /* Context */, F->getName(), MangledNameStream.str(), + Unit /* Context */, F->getName(), Dummy, Unit /* File */, 0 /* Line 0 is reserved for compiler-generated code. */, DB.createSubroutineType( @@ -721,6 +721,12 @@ static void moveFunctionData(Function &Old, Function &New, std::vector<Instruction *> DebugInsts; for (Instruction &Val : CurrBB) { + // Since debug-info originates from many different locations in the + // program, it will cause incorrect reporting from a debugger if we keep + // the same debug instructions. Drop non-intrinsic DbgVariableRecords + // here, collect intrinsics for removal later. + Val.dropDbgRecords(); + // We must handle the scoping of called functions differently than // other outlined instructions. if (!isa<CallInst>(&Val)) { @@ -744,10 +750,7 @@ static void moveFunctionData(Function &Old, Function &New, // From this point we are only handling call instructions. CallInst *CI = cast<CallInst>(&Val); - // We add any debug statements here, to be removed after. Since the - // instructions originate from many different locations in the program, - // it will cause incorrect reporting from a debugger if we keep the - // same debug instructions. + // Collect debug intrinsics for later removal. if (isa<DbgInfoIntrinsic>(CI)) { DebugInsts.push_back(&Val); continue; @@ -1498,7 +1501,7 @@ CallInst *replaceCalledFunction(Module &M, OutlinableRegion &Region) { << *AggFunc << " with new set of arguments\n"); // Create the new call instruction and erase the old one. Call = CallInst::Create(AggFunc->getFunctionType(), AggFunc, NewCallArgs, "", - Call); + Call->getIterator()); // It is possible that the call to the outlined function is either the first // instruction is in the new block, the last instruction, or both. 
If either diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp index a9747aebf67b..23ee23eb047f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/Inliner.cpp @@ -197,6 +197,14 @@ InlinerPass::getAdvisor(const ModuleAnalysisManagerCGSCCProxy::Result &MAM, return *IAA->getAdvisor(); } +void makeFunctionBodyUnreachable(Function &F) { + F.dropAllReferences(); + for (BasicBlock &BB : make_early_inc_range(F)) + BB.eraseFromParent(); + BasicBlock *BB = BasicBlock::Create(F.getContext(), "", &F); + new UnreachableInst(F.getContext(), BB); +} + PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR) { @@ -215,8 +223,6 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, InlineAdvisor &Advisor = getAdvisor(MAMProxy, FAM, M); Advisor.onPassEntry(&InitialC); - auto AdvisorOnExit = make_scope_exit([&] { Advisor.onPassExit(&InitialC); }); - // We use a single common worklist for calls across the entire SCC. We // process these in-order and append new calls introduced during inlining to // the end. The PriorityInlineOrder is optional here, in which the smaller @@ -271,12 +277,15 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, } } } - if (Calls.empty()) - return PreservedAnalyses::all(); // Capture updatable variable for the current SCC. auto *C = &InitialC; + auto AdvisorOnExit = make_scope_exit([&] { Advisor.onPassExit(C); }); + + if (Calls.empty()) + return PreservedAnalyses::all(); + // When inlining a callee produces new call sites, we want to keep track of // the fact that they were inlined from the callee. This allows us to avoid // infinite inlining in some obscure cases. To represent this, we use an @@ -448,11 +457,9 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, }), Calls.end()); - // Clear the body and queue the function itself for deletion when we - // finish inlining and call graph updates. - // Note that after this point, it is an error to do anything other - // than use the callee's address or delete it. - Callee.dropAllReferences(); + // Clear the body and queue the function itself for call graph + // updating when we finish inlining. + makeFunctionBodyUnreachable(Callee); assert(!is_contained(DeadFunctions, &Callee) && "Cannot put cause a function to become dead twice!"); DeadFunctions.push_back(&Callee); @@ -530,7 +537,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, if (!DeadFunctionsInComdats.empty()) { filterDeadComdatFunctions(DeadFunctionsInComdats); for (auto *Callee : DeadFunctionsInComdats) - Callee->dropAllReferences(); + makeFunctionBodyUnreachable(*Callee); DeadFunctions.append(DeadFunctionsInComdats); } @@ -542,25 +549,18 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // that is OK as all we do is delete things and add pointers to unordered // sets. for (Function *DeadF : DeadFunctions) { + CG.markDeadFunction(*DeadF); // Get the necessary information out of the call graph and nuke the // function there. Also, clear out any cached analyses. auto &DeadC = *CG.lookupSCC(*CG.lookup(*DeadF)); FAM.clear(*DeadF, DeadF->getName()); AM.clear(DeadC, DeadC.getName()); - auto &DeadRC = DeadC.getOuterRefSCC(); - CG.removeDeadFunction(*DeadF); // Mark the relevant parts of the call graph as invalid so we don't visit // them. 
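One subtle point in the Inliner hunk above: the onPassExit callback is now registered after the updatable SCC pointer C is introduced and captures that pointer, so the advisor observes whichever SCC is current when the pass returns, including on the early return when there are no calls. A minimal standalone sketch of the pattern, using a hand-rolled scope guard instead of llvm::make_scope_exit:

#include <cstdio>
#include <utility>

// Tiny scope guard: runs the stored callable when it goes out of scope.
template <typename Fn> struct ScopeExit {
  Fn F;
  ~ScopeExit() { F(); }
};
template <typename Fn> ScopeExit<Fn> makeScopeExit(Fn F) {
  return ScopeExit<Fn>{std::move(F)};
}

struct SCC { const char *Name; };

void runPass(SCC &Initial) {
  SCC *C = &Initial;                      // updatable current-SCC pointer
  auto OnExit = makeScopeExit([&] {       // captures C by reference, so the
    std::printf("exit on %s\n", C->Name); // exit hook sees the final value
  });
  static SCC Refined{"refined"};
  C = &Refined;                           // later updates are observed at exit
}
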
UR.InvalidatedSCCs.insert(&DeadC); - UR.InvalidatedRefSCCs.insert(&DeadRC); - - // If the updated SCC was the one containing the deleted function, clear it. - if (&DeadC == UR.UpdatedC) - UR.UpdatedC = nullptr; - // And delete the actual function from the module. - M.getFunctionList().erase(DeadF); + UR.DeadFunctions.push_back(DeadF); ++NumDeleted; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp index 633fcb3314c4..0742b259c489 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/LowerTypeTests.cpp @@ -879,7 +879,7 @@ void LowerTypeTestsModule::buildBitSetsFromGlobalVariables( // Multiply by 2 to account for padding elements. Constant *CombinedGlobalIdxs[] = {ConstantInt::get(Int32Ty, 0), ConstantInt::get(Int32Ty, I * 2)}; - Constant *CombinedGlobalElemPtr = ConstantExpr::getGetElementPtr( + Constant *CombinedGlobalElemPtr = ConstantExpr::getInBoundsGetElementPtr( NewInit->getType(), CombinedGlobal, CombinedGlobalIdxs); assert(GV->getType()->getAddressSpace() == 0); GlobalAlias *GAlias = @@ -1519,8 +1519,10 @@ void LowerTypeTestsModule::createJumpTable( // for the function to avoid double BTI. This is a no-op without // -mbranch-protection=. if (JumpTableArch == Triple::aarch64 || JumpTableArch == Triple::thumb) { - F->addFnAttr("branch-target-enforcement", "false"); - F->addFnAttr("sign-return-address", "none"); + if (F->hasFnAttribute("branch-target-enforcement")) + F->removeFnAttr("branch-target-enforcement"); + if (F->hasFnAttribute("sign-return-address")) + F->removeFnAttr("sign-return-address"); } if (JumpTableArch == Triple::riscv32 || JumpTableArch == Triple::riscv64) { // Make sure the jump table assembly is not modified by the assembler or @@ -2074,7 +2076,7 @@ bool LowerTypeTestsModule::lower() { CfiFunctionLinkage Linkage; MDNode *FuncMD; // {name, linkage, type[, type...]} }; - DenseMap<StringRef, ExportedFunctionInfo> ExportedFunctions; + MapVector<StringRef, ExportedFunctionInfo> ExportedFunctions; if (ExportSummary) { // A set of all functions that are address taken by a live global object. DenseSet<GlobalValue::GUID> AddressTaken; diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp index e10b3c56ae14..66bd786c85df 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp @@ -44,6 +44,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/Cloning.h" +#include <deque> #include <sstream> #include <unordered_map> #include <vector> @@ -122,6 +123,10 @@ static cl::opt<unsigned> "frames through tail calls.")); namespace llvm { +cl::opt<bool> EnableMemProfContextDisambiguation( + "enable-memprof-context-disambiguation", cl::init(false), cl::Hidden, + cl::ZeroOrMore, cl::desc("Enable MemProf context disambiguation")); + // Indicate we are linking with an allocator that supports hot/cold operator // new interfaces. cl::opt<bool> SupportsHotColdNew( @@ -129,6 +134,8 @@ cl::opt<bool> SupportsHotColdNew( cl::desc("Linking with hot/cold operator new interfaces")); } // namespace llvm +extern cl::opt<bool> MemProfReportHintedSizes; + namespace { /// CRTP base for graphs built from either IR or ThinLTO summary index. 
/// @@ -167,6 +174,7 @@ public: void dump() const; void print(raw_ostream &OS) const; + void printTotalSizes(raw_ostream &OS) const; friend raw_ostream &operator<<(raw_ostream &OS, const CallsiteContextGraph &CCG) { @@ -258,8 +266,70 @@ public: // TODO: Should this be a map (from Caller node) for more efficient lookup? std::vector<std::shared_ptr<ContextEdge>> CallerEdges; - // The set of IDs for contexts including this node. - DenseSet<uint32_t> ContextIds; + // Get the list of edges from which we can compute allocation information + // such as the context ids and allocation type of this node. + const std::vector<std::shared_ptr<ContextEdge>> * + getEdgesWithAllocInfo() const { + // If node has any callees, compute from those, otherwise compute from + // callers (i.e. if this is the leaf allocation node). + if (!CalleeEdges.empty()) + return &CalleeEdges; + if (!CallerEdges.empty()) { + // A node with caller edges but no callee edges must be the allocation + // node. + assert(IsAllocation); + return &CallerEdges; + } + return nullptr; + } + + // Compute the context ids for this node from the union of its edge context + // ids. + DenseSet<uint32_t> getContextIds() const { + DenseSet<uint32_t> ContextIds; + auto *Edges = getEdgesWithAllocInfo(); + if (!Edges) + return {}; + unsigned Count = 0; + for (auto &Edge : *Edges) + Count += Edge->getContextIds().size(); + ContextIds.reserve(Count); + for (auto &Edge : *Edges) + ContextIds.insert(Edge->getContextIds().begin(), + Edge->getContextIds().end()); + return ContextIds; + } + + // Compute the allocation type for this node from the OR of its edge + // allocation types. + uint8_t computeAllocType() const { + auto *Edges = getEdgesWithAllocInfo(); + if (!Edges) + return (uint8_t)AllocationType::None; + uint8_t BothTypes = + (uint8_t)AllocationType::Cold | (uint8_t)AllocationType::NotCold; + uint8_t AllocType = (uint8_t)AllocationType::None; + for (auto &Edge : *Edges) { + AllocType |= Edge->AllocTypes; + // Bail early if alloc type reached both, no further refinement. + if (AllocType == BothTypes) + return AllocType; + } + return AllocType; + } + + // The context ids set for this node is empty if its edge context ids are + // also all empty. + bool emptyContextIds() const { + auto *Edges = getEdgesWithAllocInfo(); + if (!Edges) + return true; + for (auto &Edge : *Edges) { + if (!Edge->getContextIds().empty()) + return false; + } + return true; + } // List of clones of this ContextNode, initially empty. std::vector<ContextNode *> Clones; @@ -304,11 +374,11 @@ public: void printCall(raw_ostream &OS) const { Call.print(OS); } // True if this node was effectively removed from the graph, in which case - // its context id set, caller edges, and callee edges should all be empty. + // it should have an allocation type of None and empty context ids. 
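The ContextNode changes above drop the stored per-node ContextIds set: ids and allocation types are now recomputed on demand from whichever edge list carries the information (callee edges normally, caller edges only for the leaf allocation node). A minimal standalone model of that derivation, using plain STL containers in place of the DenseSet and shared_ptr machinery:

#include <cstdint>
#include <memory>
#include <set>
#include <vector>

enum AllocKind : uint8_t { None = 0, NotCold = 1, Cold = 2, Both = 3 };

struct Edge {
  std::set<uint32_t> ContextIds;
  uint8_t AllocTypes = None;
};

struct Node {
  bool IsAllocation = false;
  std::vector<std::shared_ptr<Edge>> CalleeEdges, CallerEdges;

  // Pick the edge list that carries the allocation information: callees if
  // any exist, otherwise callers (only the leaf allocation node has caller
  // edges but no callee edges).
  const std::vector<std::shared_ptr<Edge>> *edgesWithAllocInfo() const {
    if (!CalleeEdges.empty())
      return &CalleeEdges;
    if (!CallerEdges.empty())
      return &CallerEdges;
    return nullptr;
  }

  // Context ids are the union of the chosen edges' ids, computed on demand.
  std::set<uint32_t> contextIds() const {
    std::set<uint32_t> Ids;
    if (const auto *Edges = edgesWithAllocInfo())
      for (const auto &E : *Edges)
        Ids.insert(E->ContextIds.begin(), E->ContextIds.end());
    return Ids;
  }

  // Allocation type is the OR of the edges' types, with an early exit once
  // both cold and not-cold have been seen.
  uint8_t computeAllocType() const {
    uint8_t Type = None;
    if (const auto *Edges = edgesWithAllocInfo())
      for (const auto &E : *Edges) {
        Type |= E->AllocTypes;
        if (Type == Both)
          break;
      }
    return Type;
  }
};
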
bool isRemoved() const { - assert(ContextIds.empty() == - (CalleeEdges.empty() && CallerEdges.empty())); - return ContextIds.empty(); + assert((AllocTypes == (uint8_t)AllocationType::None) == + emptyContextIds()); + return AllocTypes == (uint8_t)AllocationType::None; } void dump() const; @@ -336,7 +406,7 @@ public: ContextEdge(ContextNode *Callee, ContextNode *Caller, uint8_t AllocType, DenseSet<uint32_t> ContextIds) : Callee(Callee), Caller(Caller), AllocTypes(AllocType), - ContextIds(ContextIds) {} + ContextIds(std::move(ContextIds)) {} DenseSet<uint32_t> &getContextIds() { return ContextIds; } @@ -349,9 +419,12 @@ public: } }; - /// Helper to remove callee edges that have allocation type None (due to not + /// Helpers to remove callee edges that have allocation type None (due to not /// carrying any context ids) after transformations. void removeNoneTypeCalleeEdges(ContextNode *Node); + void + recursivelyRemoveNoneTypeCalleeEdges(ContextNode *Node, + DenseSet<const ContextNode *> &Visited); protected: /// Get a list of nodes corresponding to the stack ids in the given callsite @@ -369,7 +442,7 @@ protected: void addStackNodesForMIB(ContextNode *AllocNode, CallStack<NodeT, IteratorT> &StackContext, CallStack<NodeT, IteratorT> &CallsiteContext, - AllocationType AllocType); + AllocationType AllocType, uint64_t TotalSize); /// Matches all callsite metadata (or summary) to the nodes created for /// allocation memprof MIB metadata, synthesizing new nodes to reflect any @@ -418,7 +491,8 @@ private: /// else to its callers. Also updates OrigNode's edges to remove any context /// ids moved to the newly created edge. void connectNewNode(ContextNode *NewNode, ContextNode *OrigNode, - bool TowardsCallee); + bool TowardsCallee, + DenseSet<uint32_t> RemainingContextIds); /// Get the stack id corresponding to the given Id or Index (for IR this will /// return itself, for a summary index this will return the id recorded in the @@ -431,9 +505,8 @@ private: /// we were able to identify the call chain through intermediate tail calls. /// In the latter case new context nodes are added to the graph for the /// identified tail calls, and their synthesized nodes are added to - /// TailCallToContextNodeMap. The EdgeIter is updated in either case to the - /// next element after the input position (either incremented or updated after - /// removing the old edge). + /// TailCallToContextNodeMap. The EdgeIter is updated in the latter case for + /// the updated edges and to prepare it for an increment in the caller. bool calleesMatch(CallTy Call, EdgeIter &EI, MapVector<CallInfo, ContextNode *> &TailCallToContextNodeMap); @@ -494,20 +567,17 @@ private: ContextNode *getNodeForAlloc(const CallInfo &C); ContextNode *getNodeForStackId(uint64_t StackId); - /// Removes the node information recorded for the given call. - void unsetNodeForInst(const CallInfo &C); - /// Computes the alloc type corresponding to the given context ids, by /// unioning their recorded alloc types. uint8_t computeAllocType(DenseSet<uint32_t> &ContextIds); - /// Returns the alloction type of the intersection of the contexts of two + /// Returns the allocation type of the intersection of the contexts of two /// nodes (based on their provided context id sets), optimized for the case /// when Node1Ids is smaller than Node2Ids. 
uint8_t intersectAllocTypesImpl(const DenseSet<uint32_t> &Node1Ids, const DenseSet<uint32_t> &Node2Ids); - /// Returns the alloction type of the intersection of the contexts of two + /// Returns the allocation type of the intersection of the contexts of two /// nodes (based on their provided context id sets). uint8_t intersectAllocTypes(const DenseSet<uint32_t> &Node1Ids, const DenseSet<uint32_t> &Node2Ids); @@ -515,34 +585,43 @@ private: /// Create a clone of Edge's callee and move Edge to that new callee node, /// performing the necessary context id and allocation type updates. /// If callee's caller edge iterator is supplied, it is updated when removing - /// the edge from that list. + /// the edge from that list. If ContextIdsToMove is non-empty, only that + /// subset of Edge's ids are moved to an edge to the new callee. ContextNode * moveEdgeToNewCalleeClone(const std::shared_ptr<ContextEdge> &Edge, - EdgeIter *CallerEdgeI = nullptr); + EdgeIter *CallerEdgeI = nullptr, + DenseSet<uint32_t> ContextIdsToMove = {}); /// Change the callee of Edge to existing callee clone NewCallee, performing /// the necessary context id and allocation type updates. /// If callee's caller edge iterator is supplied, it is updated when removing - /// the edge from that list. + /// the edge from that list. If ContextIdsToMove is non-empty, only that + /// subset of Edge's ids are moved to an edge to the new callee. void moveEdgeToExistingCalleeClone(const std::shared_ptr<ContextEdge> &Edge, ContextNode *NewCallee, EdgeIter *CallerEdgeI = nullptr, - bool NewClone = false); + bool NewClone = false, + DenseSet<uint32_t> ContextIdsToMove = {}); /// Recursively perform cloning on the graph for the given Node and its /// callers, in order to uniquely identify the allocation behavior of an - /// allocation given its context. - void identifyClones(ContextNode *Node, - DenseSet<const ContextNode *> &Visited); + /// allocation given its context. The context ids of the allocation being + /// processed are given in AllocContextIds. + void identifyClones(ContextNode *Node, DenseSet<const ContextNode *> &Visited, + const DenseSet<uint32_t> &AllocContextIds); /// Map from each context ID to the AllocationType assigned to that context. - std::map<uint32_t, AllocationType> ContextIdToAllocationType; + DenseMap<uint32_t, AllocationType> ContextIdToAllocationType; + + /// Map from each contextID to the profiled aggregate allocation size, + /// optionally populated when requested (via MemProfReportHintedSizes). + DenseMap<uint32_t, uint64_t> ContextIdToTotalSize; /// Identifies the context node created for a stack id when adding the MIB /// contexts to the graph. This is used to locate the context nodes when /// trying to assign the corresponding callsites with those stack ids to these /// nodes. - std::map<uint64_t, ContextNode *> StackEntryIdToContextNodeMap; + DenseMap<uint64_t, ContextNode *> StackEntryIdToContextNodeMap; /// Maps to track the calls to their corresponding nodes in the graph. 
MapVector<CallInfo, ContextNode *> AllocationCallToContextNodeMap; @@ -798,15 +877,6 @@ CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::getNodeForStackId( } template <typename DerivedCCG, typename FuncTy, typename CallTy> -void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::unsetNodeForInst( - const CallInfo &C) { - AllocationCallToContextNodeMap.erase(C) || - NonAllocationCallToContextNodeMap.erase(C); - assert(!AllocationCallToContextNodeMap.count(C) && - !NonAllocationCallToContextNodeMap.count(C)); -} - -template <typename DerivedCCG, typename FuncTy, typename CallTy> void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode:: addOrUpdateCallerEdge(ContextNode *Caller, AllocationType AllocType, unsigned int ContextId) { @@ -940,11 +1010,24 @@ CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::addAllocNode( return AllocNode; } +static std::string getAllocTypeString(uint8_t AllocTypes) { + if (!AllocTypes) + return "None"; + std::string Str; + if (AllocTypes & (uint8_t)AllocationType::NotCold) + Str += "NotCold"; + if (AllocTypes & (uint8_t)AllocationType::Cold) + Str += "Cold"; + return Str; +} + template <typename DerivedCCG, typename FuncTy, typename CallTy> template <class NodeT, class IteratorT> void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::addStackNodesForMIB( ContextNode *AllocNode, CallStack<NodeT, IteratorT> &StackContext, - CallStack<NodeT, IteratorT> &CallsiteContext, AllocationType AllocType) { + CallStack<NodeT, IteratorT> &CallsiteContext, AllocationType AllocType, + uint64_t TotalSize) { + assert(!MemProfReportHintedSizes || TotalSize > 0); // Treating the hot alloc type as NotCold before the disambiguation for "hot" // is done. if (AllocType == AllocationType::Hot) @@ -952,9 +1035,13 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::addStackNodesForMIB( ContextIdToAllocationType[++LastContextId] = AllocType; + if (MemProfReportHintedSizes) { + assert(TotalSize); + ContextIdToTotalSize[LastContextId] = TotalSize; + } + // Update alloc type and context ids for this MIB. AllocNode->AllocTypes |= (uint8_t)AllocType; - AllocNode->ContextIds.insert(LastContextId); // Now add or update nodes for each stack id in alloc's context. // Later when processing the stack ids on non-alloc callsites we will adjust @@ -979,7 +1066,6 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::addStackNodesForMIB( auto Ins = StackIdSet.insert(StackId); if (!Ins.second) StackNode->Recursive = true; - StackNode->ContextIds.insert(LastContextId); StackNode->AllocTypes |= (uint8_t)AllocType; PrevNode->addOrUpdateCallerEdge(StackNode, AllocType, LastContextId); PrevNode = StackNode; @@ -998,6 +1084,10 @@ CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::duplicateContextIds( assert(ContextIdToAllocationType.count(OldId)); // The new context has the same allocation type as original. ContextIdToAllocationType[LastContextId] = ContextIdToAllocationType[OldId]; + // For now set this to 0 so we don't duplicate sizes. Not clear how to divvy + // up the size. Assume that if we are able to duplicate context ids that we + // will be able to disambiguate all copies. + ContextIdToTotalSize[LastContextId] = 0; } return NewContextIds; } @@ -1030,7 +1120,6 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>:: // it resulted in any added ids to NextNode. 
if (!NewIdsToAdd.empty()) { Edge->getContextIds().insert(NewIdsToAdd.begin(), NewIdsToAdd.end()); - NextNode->ContextIds.insert(NewIdsToAdd.begin(), NewIdsToAdd.end()); UpdateCallers(NextNode, Visited, UpdateCallers); } } @@ -1039,21 +1128,16 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>:: DenseSet<const ContextEdge *> Visited; for (auto &Entry : AllocationCallToContextNodeMap) { auto *Node = Entry.second; - // Update ids on the allocation nodes before calling the recursive - // update along caller edges, since this simplifies the logic during - // that traversal. - DenseSet<uint32_t> NewIdsToAdd = GetNewIds(Node->ContextIds); - Node->ContextIds.insert(NewIdsToAdd.begin(), NewIdsToAdd.end()); UpdateCallers(Node, Visited, UpdateCallers); } } template <typename DerivedCCG, typename FuncTy, typename CallTy> void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::connectNewNode( - ContextNode *NewNode, ContextNode *OrigNode, bool TowardsCallee) { - // Make a copy of the context ids, since this will be adjusted below as they - // are moved. - DenseSet<uint32_t> RemainingContextIds = NewNode->ContextIds; + ContextNode *NewNode, ContextNode *OrigNode, bool TowardsCallee, + // This must be passed by value to make a copy since it will be adjusted + // as ids are moved. + DenseSet<uint32_t> RemainingContextIds) { auto &OrigEdges = TowardsCallee ? OrigNode->CalleeEdges : OrigNode->CallerEdges; // Increment iterator in loop so that we can remove edges as needed. @@ -1072,15 +1156,15 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::connectNewNode( continue; } if (TowardsCallee) { + uint8_t NewAllocType = computeAllocType(NewEdgeContextIds); auto NewEdge = std::make_shared<ContextEdge>( - Edge->Callee, NewNode, computeAllocType(NewEdgeContextIds), - NewEdgeContextIds); + Edge->Callee, NewNode, NewAllocType, std::move(NewEdgeContextIds)); NewNode->CalleeEdges.push_back(NewEdge); NewEdge->Callee->CallerEdges.push_back(NewEdge); } else { + uint8_t NewAllocType = computeAllocType(NewEdgeContextIds); auto NewEdge = std::make_shared<ContextEdge>( - NewNode, Edge->Caller, computeAllocType(NewEdgeContextIds), - NewEdgeContextIds); + NewNode, Edge->Caller, NewAllocType, std::move(NewEdgeContextIds)); NewNode->CallerEdges.push_back(NewEdge); NewEdge->Caller->CalleeEdges.push_back(NewEdge); } @@ -1100,6 +1184,51 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::connectNewNode( } template <typename DerivedCCG, typename FuncTy, typename CallTy> +static void checkEdge( + const std::shared_ptr<ContextEdge<DerivedCCG, FuncTy, CallTy>> &Edge) { + // Confirm that alloc type is not None and that we have at least one context + // id. + assert(Edge->AllocTypes != (uint8_t)AllocationType::None); + assert(!Edge->ContextIds.empty()); +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> +static void checkNode(const ContextNode<DerivedCCG, FuncTy, CallTy> *Node, + bool CheckEdges = true) { + if (Node->isRemoved()) + return; +#ifndef NDEBUG + // Compute node's context ids once for use in asserts. + auto NodeContextIds = Node->getContextIds(); +#endif + // Node's context ids should be the union of both its callee and caller edge + // context ids. 
+ if (Node->CallerEdges.size()) { + DenseSet<uint32_t> CallerEdgeContextIds( + Node->CallerEdges.front()->ContextIds); + for (const auto &Edge : llvm::drop_begin(Node->CallerEdges)) { + if (CheckEdges) + checkEdge<DerivedCCG, FuncTy, CallTy>(Edge); + set_union(CallerEdgeContextIds, Edge->ContextIds); + } + // Node can have more context ids than callers if some contexts terminate at + // node and some are longer. + assert(NodeContextIds == CallerEdgeContextIds || + set_is_subset(CallerEdgeContextIds, NodeContextIds)); + } + if (Node->CalleeEdges.size()) { + DenseSet<uint32_t> CalleeEdgeContextIds( + Node->CalleeEdges.front()->ContextIds); + for (const auto &Edge : llvm::drop_begin(Node->CalleeEdges)) { + if (CheckEdges) + checkEdge<DerivedCCG, FuncTy, CallTy>(Edge); + set_union(CalleeEdgeContextIds, Edge->getContextIds()); + } + assert(NodeContextIds == CalleeEdgeContextIds); + } +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>:: assignStackNodesPostOrder(ContextNode *Node, DenseSet<const ContextNode *> &Visited, @@ -1174,7 +1303,7 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>:: // duplicated context ids. We have to recompute as we might have overlap // overlap between the saved context ids for different last nodes, and // removed them already during the post order traversal. - set_intersect(SavedContextIds, FirstNode->ContextIds); + set_intersect(SavedContextIds, FirstNode->getContextIds()); ContextNode *PrevNode = nullptr; for (auto Id : Ids) { ContextNode *CurNode = getNodeForStackId(Id); @@ -1207,18 +1336,17 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>:: ContextNode *NewNode = NodeOwner.back().get(); NodeToCallingFunc[NewNode] = Func; NonAllocationCallToContextNodeMap[Call] = NewNode; - NewNode->ContextIds = SavedContextIds; - NewNode->AllocTypes = computeAllocType(NewNode->ContextIds); + NewNode->AllocTypes = computeAllocType(SavedContextIds); // Connect to callees of innermost stack frame in inlined call chain. // This updates context ids for FirstNode's callee's to reflect those // moved to NewNode. - connectNewNode(NewNode, FirstNode, /*TowardsCallee=*/true); + connectNewNode(NewNode, FirstNode, /*TowardsCallee=*/true, SavedContextIds); // Connect to callers of outermost stack frame in inlined call chain. // This updates context ids for FirstNode's caller's to reflect those // moved to NewNode. - connectNewNode(NewNode, LastNode, /*TowardsCallee=*/false); + connectNewNode(NewNode, LastNode, /*TowardsCallee=*/false, SavedContextIds); // Now we need to remove context ids from edges/nodes between First and // Last Node. @@ -1230,18 +1358,32 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>:: // Remove the context ids moved to NewNode from CurNode, and the // edge from the prior node. - set_subtract(CurNode->ContextIds, NewNode->ContextIds); if (PrevNode) { auto *PrevEdge = CurNode->findEdgeFromCallee(PrevNode); assert(PrevEdge); - set_subtract(PrevEdge->getContextIds(), NewNode->ContextIds); + set_subtract(PrevEdge->getContextIds(), SavedContextIds); if (PrevEdge->getContextIds().empty()) { PrevNode->eraseCallerEdge(PrevEdge); CurNode->eraseCalleeEdge(PrevEdge); } } + // Since we update the edges from leaf to tail, only look at the callee + // edges. This isn't an alloc node, so if there are no callee edges, the + // alloc type is None. + CurNode->AllocTypes = CurNode->CalleeEdges.empty() + ? 
(uint8_t)AllocationType::None + : CurNode->computeAllocType(); PrevNode = CurNode; } + if (VerifyNodes) { + checkNode<DerivedCCG, FuncTy, CallTy>(NewNode, /*CheckEdges=*/true); + for (auto Id : Ids) { + ContextNode *CurNode = getNodeForStackId(Id); + // We should only have kept stack ids that had nodes. + assert(CurNode); + checkNode<DerivedCCG, FuncTy, CallTy>(CurNode, /*CheckEdges=*/true); + } + } } } @@ -1315,7 +1457,7 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::updateStackNodes() { // Initialize the context ids with the last node's. We will subsequently // refine the context ids by computing the intersection along all edges. - DenseSet<uint32_t> LastNodeContextIds = LastNode->ContextIds; + DenseSet<uint32_t> LastNodeContextIds = LastNode->getContextIds(); assert(!LastNodeContextIds.empty()); for (unsigned I = 0; I < Calls.size(); I++) { @@ -1381,7 +1523,7 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::updateStackNodes() { // not fully matching stack contexts. To do this, subtract any context ids // found in caller nodes of the last node found above. if (Ids.back() != getLastStackId(Call)) { - for (const auto &PE : CurNode->CallerEdges) { + for (const auto &PE : LastNode->CallerEdges) { set_subtract(StackSequenceContextIds, PE->getContextIds()); if (StackSequenceContextIds.empty()) break; @@ -1438,6 +1580,8 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::updateStackNodes() { DenseSet<const ContextNode *> Visited; for (auto &Entry : AllocationCallToContextNodeMap) assignStackNodesPostOrder(Entry.second, Visited, StackIdToMatchingCalls); + if (VerifyCCG) + check(); } uint64_t ModuleCallsiteContextGraph::getLastStackId(Instruction *Call) { @@ -1547,7 +1691,7 @@ ModuleCallsiteContextGraph::ModuleCallsiteContextGraph( CallStack<MDNode, MDNode::op_iterator> StackContext(StackNode); addStackNodesForMIB<MDNode, MDNode::op_iterator>( AllocNode, StackContext, CallsiteContext, - getMIBAllocType(MIBMD)); + getMIBAllocType(MIBMD), getMIBTotalSize(MIBMD)); } assert(AllocNode->AllocTypes != (uint8_t)AllocationType::None); // Memprof and callsite metadata on memory allocations no longer @@ -1619,12 +1763,20 @@ IndexCallsiteContextGraph::IndexCallsiteContextGraph( // stack ids on the allocation call during ModuleSummaryAnalysis. CallStack<MIBInfo, SmallVector<unsigned>::const_iterator> EmptyContext; + unsigned I = 0; + assert(!MemProfReportHintedSizes || + AN.TotalSizes.size() == AN.MIBs.size()); // Now add all of the MIBs and their stack nodes. for (auto &MIB : AN.MIBs) { CallStack<MIBInfo, SmallVector<unsigned>::const_iterator> StackContext(&MIB); + uint64_t TotalSize = 0; + if (MemProfReportHintedSizes) + TotalSize = AN.TotalSizes[I]; addStackNodesForMIB<MIBInfo, SmallVector<unsigned>::const_iterator>( - AllocNode, StackContext, EmptyContext, MIB.AllocType); + AllocNode, StackContext, EmptyContext, MIB.AllocType, + TotalSize); + I++; } assert(AllocNode->AllocTypes != (uint8_t)AllocationType::None); // Initialize version 0 on the summary alloc node to the current alloc @@ -1677,19 +1829,16 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, // from the profiled contexts. MapVector<CallInfo, ContextNode *> TailCallToContextNodeMap; - for (auto Entry = NonAllocationCallToContextNodeMap.begin(); - Entry != NonAllocationCallToContextNodeMap.end();) { - auto *Node = Entry->second; + for (auto &Entry : NonAllocationCallToContextNodeMap) { + auto *Node = Entry.second; assert(Node->Clones.empty()); // Check all node callees and see if in the same function. 
- bool Removed = false; auto Call = Node->Call.call(); - for (auto EI = Node->CalleeEdges.begin(); EI != Node->CalleeEdges.end();) { + for (auto EI = Node->CalleeEdges.begin(); EI != Node->CalleeEdges.end(); + ++EI) { auto Edge = *EI; - if (!Edge->Callee->hasCall()) { - ++EI; + if (!Edge->Callee->hasCall()) continue; - } assert(NodeToCallingFunc.count(Edge->Callee)); // Check if the called function matches that of the callee node. if (calleesMatch(Call, EI, TailCallToContextNodeMap)) @@ -1698,15 +1847,18 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, // Work around by setting Node to have a null call, so it gets // skipped during cloning. Otherwise assignFunctions will assert // because its data structures are not designed to handle this case. - Entry = NonAllocationCallToContextNodeMap.erase(Entry); Node->setCall(CallInfo()); - Removed = true; break; } - if (!Removed) - Entry++; } + // Remove all mismatched nodes identified in the above loop from the node map + // (checking whether they have a null call which is set above). For a + // MapVector like NonAllocationCallToContextNodeMap it is much more efficient + // to do the removal via remove_if than by individually erasing entries above. + NonAllocationCallToContextNodeMap.remove_if( + [](const auto &it) { return !it.second->hasCall(); }); + // Add the new nodes after the above loop so that the iteration is not // invalidated. for (auto &[Call, Node] : TailCallToContextNodeMap) @@ -1735,16 +1887,12 @@ bool CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::calleesMatch( // calls between the profiled caller and callee. std::vector<std::pair<CallTy, FuncTy *>> FoundCalleeChain; if (!calleeMatchesFunc(Call, ProfiledCalleeFunc, CallerFunc, - FoundCalleeChain)) { - ++EI; + FoundCalleeChain)) return false; - } // The usual case where the profiled callee matches that of the IR/summary. - if (FoundCalleeChain.empty()) { - ++EI; + if (FoundCalleeChain.empty()) return true; - } auto AddEdge = [Edge, &EI](ContextNode *Caller, ContextNode *Callee) { auto *CurEdge = Callee->findEdgeFromCaller(Caller); @@ -1781,8 +1929,6 @@ bool CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::calleesMatch( // First check if we have already synthesized a node for this tail call. if (TailCallToContextNodeMap.count(NewCall)) { NewNode = TailCallToContextNodeMap[NewCall]; - NewNode->ContextIds.insert(Edge->ContextIds.begin(), - Edge->ContextIds.end()); NewNode->AllocTypes |= Edge->AllocTypes; } else { FuncToCallsWithMetadata[Func].push_back({NewCall}); @@ -1792,7 +1938,6 @@ bool CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::calleesMatch( NewNode = NodeOwner.back().get(); NodeToCallingFunc[NewNode] = Func; TailCallToContextNodeMap[NewCall] = NewNode; - NewNode->ContextIds = Edge->ContextIds; NewNode->AllocTypes = Edge->AllocTypes; } @@ -1809,6 +1954,13 @@ bool CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::calleesMatch( Edge->Callee->eraseCallerEdge(Edge.get()); EI = Edge->Caller->CalleeEdges.erase(EI); + // To simplify the increment of EI in the caller, subtract one from EI. + // In the final AddEdge call we would have either added a new callee edge, + // to Edge->Caller, or found an existing one. Either way we are guaranteed + // that there is at least one callee edge. 
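The loop rewrite above defers removal of mismatched callsite nodes: each one is only given a null call inside the loop, and a single remove_if over the MapVector drops them all afterwards, which the comment notes is much cheaper than erasing vector-backed entries one at a time. A minimal standalone sketch of the same idea on a plain vector of key/value pairs (invented names, not the MapVector API):

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

struct NodeInfo {
  std::string Call; // empty string models the "null call" marker
};

using CallMap = std::vector<std::pair<int, NodeInfo>>;

// Pass 1: mark entries instead of erasing them inside the loop. Erasing from
// a vector-backed map here would shift the tail on every removal (O(n^2)).
void markMismatched(CallMap &Map) {
  for (auto &Entry : Map)
    if (Entry.second.Call == "mismatch")
      Entry.second.Call.clear(); // "set a null call"
}

// Pass 2: one bulk removal compacts the container in a single O(n) sweep.
void removeMarked(CallMap &Map) {
  Map.erase(std::remove_if(Map.begin(), Map.end(),
                           [](const auto &It) { return It.second.Call.empty(); }),
            Map.end());
}
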
+ assert(!Edge->Caller->CalleeEdges.empty()); + --EI; + return true; } @@ -1873,15 +2025,17 @@ bool ModuleCallsiteContextGraph::findProfiledCalleeThroughTailCalls( } else if (findProfiledCalleeThroughTailCalls( ProfiledCallee, CalledFunction, Depth + 1, FoundCalleeChain, FoundMultipleCalleeChains)) { - if (FoundMultipleCalleeChains) - return false; + // findProfiledCalleeThroughTailCalls should not have returned + // true if FoundMultipleCalleeChains. + assert(!FoundMultipleCalleeChains); if (FoundSingleCalleeChain) { FoundMultipleCalleeChains = true; return false; } FoundSingleCalleeChain = true; SaveCallsiteInfo(&I, CalleeFunc); - } + } else if (FoundMultipleCalleeChains) + return false; } } @@ -1988,8 +2142,9 @@ bool IndexCallsiteContextGraph::findProfiledCalleeThroughTailCalls( } else if (findProfiledCalleeThroughTailCalls( ProfiledCallee, CallEdge.first, Depth + 1, FoundCalleeChain, FoundMultipleCalleeChains)) { - if (FoundMultipleCalleeChains) - return false; + // findProfiledCalleeThroughTailCalls should not have returned + // true if FoundMultipleCalleeChains. + assert(!FoundMultipleCalleeChains); if (FoundSingleCalleeChain) { FoundMultipleCalleeChains = true; return false; @@ -1999,7 +2154,8 @@ bool IndexCallsiteContextGraph::findProfiledCalleeThroughTailCalls( // Add FS to FSToVIMap in case it isn't already there. assert(!FSToVIMap.count(FS) || FSToVIMap[FS] == FSVI); FSToVIMap[FS] = FSVI; - } + } else if (FoundMultipleCalleeChains) + return false; } } @@ -2053,17 +2209,6 @@ bool IndexCallsiteContextGraph::calleeMatchesFunc( return true; } -static std::string getAllocTypeString(uint8_t AllocTypes) { - if (!AllocTypes) - return "None"; - std::string Str; - if (AllocTypes & (uint8_t)AllocationType::NotCold) - Str += "NotCold"; - if (AllocTypes & (uint8_t)AllocationType::Cold) - Str += "Cold"; - return Str; -} - template <typename DerivedCCG, typename FuncTy, typename CallTy> void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::dump() const { @@ -2082,6 +2227,8 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode::print( OS << "\n"; OS << "\tAllocTypes: " << getAllocTypeString(AllocTypes) << "\n"; OS << "\tContextIds:"; + // Make a copy of the computed context ids that we can sort for stability. + auto ContextIds = getContextIds(); std::vector<uint32_t> SortedIds(ContextIds.begin(), ContextIds.end()); std::sort(SortedIds.begin(), SortedIds.end()); for (auto Id : SortedIds) @@ -2142,49 +2289,26 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::print( } template <typename DerivedCCG, typename FuncTy, typename CallTy> -static void checkEdge( - const std::shared_ptr<ContextEdge<DerivedCCG, FuncTy, CallTy>> &Edge) { - // Confirm that alloc type is not None and that we have at least one context - // id. - assert(Edge->AllocTypes != (uint8_t)AllocationType::None); - assert(!Edge->ContextIds.empty()); -} - -template <typename DerivedCCG, typename FuncTy, typename CallTy> -static void checkNode(const ContextNode<DerivedCCG, FuncTy, CallTy> *Node, - bool CheckEdges = true) { - if (Node->isRemoved()) - return; - // Node's context ids should be the union of both its callee and caller edge - // context ids. 
- if (Node->CallerEdges.size()) { - auto EI = Node->CallerEdges.begin(); - auto &FirstEdge = *EI; - EI++; - DenseSet<uint32_t> CallerEdgeContextIds(FirstEdge->ContextIds); - for (; EI != Node->CallerEdges.end(); EI++) { - const auto &Edge = *EI; - if (CheckEdges) - checkEdge<DerivedCCG, FuncTy, CallTy>(Edge); - set_union(CallerEdgeContextIds, Edge->ContextIds); - } - // Node can have more context ids than callers if some contexts terminate at - // node and some are longer. - assert(Node->ContextIds == CallerEdgeContextIds || - set_is_subset(CallerEdgeContextIds, Node->ContextIds)); - } - if (Node->CalleeEdges.size()) { - auto EI = Node->CalleeEdges.begin(); - auto &FirstEdge = *EI; - EI++; - DenseSet<uint32_t> CalleeEdgeContextIds(FirstEdge->ContextIds); - for (; EI != Node->CalleeEdges.end(); EI++) { - const auto &Edge = *EI; - if (CheckEdges) - checkEdge<DerivedCCG, FuncTy, CallTy>(Edge); - set_union(CalleeEdgeContextIds, Edge->ContextIds); +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::printTotalSizes( + raw_ostream &OS) const { + using GraphType = const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *; + for (const auto Node : nodes<GraphType>(this)) { + if (Node->isRemoved()) + continue; + if (!Node->IsAllocation) + continue; + DenseSet<uint32_t> ContextIds = Node->getContextIds(); + std::vector<uint32_t> SortedIds(ContextIds.begin(), ContextIds.end()); + std::sort(SortedIds.begin(), SortedIds.end()); + for (auto Id : SortedIds) { + auto SizeI = ContextIdToTotalSize.find(Id); + assert(SizeI != ContextIdToTotalSize.end()); + auto TypeI = ContextIdToAllocationType.find(Id); + assert(TypeI != ContextIdToAllocationType.end()); + OS << getAllocTypeString((uint8_t)TypeI->second) << " context " << Id + << " with total size " << SizeI->second << " is " + << getAllocTypeString(Node->AllocTypes) << " after cloning\n"; } - assert(Node->ContextIds == CalleeEdgeContextIds); } } @@ -2275,7 +2399,7 @@ struct DOTGraphTraits<const CallsiteContextGraph<DerivedCCG, FuncTy, CallTy> *> static std::string getNodeAttributes(NodeRef Node, GraphType) { std::string AttributeString = (Twine("tooltip=\"") + getNodeId(Node) + " " + - getContextIds(Node->ContextIds) + "\"") + getContextIds(Node->getContextIds()) + "\"") .str(); AttributeString += (Twine(",fillcolor=\"") + getColor(Node->AllocTypes) + "\"").str(); @@ -2347,7 +2471,8 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::exportToDot( template <typename DerivedCCG, typename FuncTy, typename CallTy> typename CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::ContextNode * CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::moveEdgeToNewCalleeClone( - const std::shared_ptr<ContextEdge> &Edge, EdgeIter *CallerEdgeI) { + const std::shared_ptr<ContextEdge> &Edge, EdgeIter *CallerEdgeI, + DenseSet<uint32_t> ContextIdsToMove) { ContextNode *Node = Edge->Callee; NodeOwner.push_back( std::make_unique<ContextNode>(Node->IsAllocation, Node->Call)); @@ -2355,7 +2480,8 @@ CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::moveEdgeToNewCalleeClone( Node->addClone(Clone); assert(NodeToCallingFunc.count(Node)); NodeToCallingFunc[Clone] = NodeToCallingFunc[Node]; - moveEdgeToExistingCalleeClone(Edge, Clone, CallerEdgeI, /*NewClone=*/true); + moveEdgeToExistingCalleeClone(Edge, Clone, CallerEdgeI, /*NewClone=*/true, + ContextIdsToMove); return Clone; } @@ -2363,27 +2489,75 @@ template <typename DerivedCCG, typename FuncTy, typename CallTy> void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>:: moveEdgeToExistingCalleeClone(const std::shared_ptr<ContextEdge> 
&Edge, ContextNode *NewCallee, EdgeIter *CallerEdgeI, - bool NewClone) { + bool NewClone, + DenseSet<uint32_t> ContextIdsToMove) { // NewCallee and Edge's current callee must be clones of the same original // node (Edge's current callee may be the original node too). assert(NewCallee->getOrigNode() == Edge->Callee->getOrigNode()); - auto &EdgeContextIds = Edge->getContextIds(); + ContextNode *OldCallee = Edge->Callee; - if (CallerEdgeI) - *CallerEdgeI = OldCallee->CallerEdges.erase(*CallerEdgeI); - else - OldCallee->eraseCallerEdge(Edge.get()); - Edge->Callee = NewCallee; - NewCallee->CallerEdges.push_back(Edge); - // Don't need to update Edge's context ids since we are simply reconnecting - // it. - set_subtract(OldCallee->ContextIds, EdgeContextIds); - NewCallee->ContextIds.insert(EdgeContextIds.begin(), EdgeContextIds.end()); - NewCallee->AllocTypes |= Edge->AllocTypes; - OldCallee->AllocTypes = computeAllocType(OldCallee->ContextIds); - // OldCallee alloc type should be None iff its context id set is now empty. - assert((OldCallee->AllocTypes == (uint8_t)AllocationType::None) == - OldCallee->ContextIds.empty()); + + // We might already have an edge to the new callee from earlier cloning for a + // different allocation. If one exists we will reuse it. + auto ExistingEdgeToNewCallee = NewCallee->findEdgeFromCaller(Edge->Caller); + + // Callers will pass an empty ContextIdsToMove set when they want to move the + // edge. Copy in Edge's ids for simplicity. + if (ContextIdsToMove.empty()) + ContextIdsToMove = Edge->getContextIds(); + + // If we are moving all of Edge's ids, then just move the whole Edge. + // Otherwise only move the specified subset, to a new edge if needed. + if (Edge->getContextIds().size() == ContextIdsToMove.size()) { + // Moving the whole Edge. + if (CallerEdgeI) + *CallerEdgeI = OldCallee->CallerEdges.erase(*CallerEdgeI); + else + OldCallee->eraseCallerEdge(Edge.get()); + if (ExistingEdgeToNewCallee) { + // Since we already have an edge to NewCallee, simply move the ids + // onto it, and remove the existing Edge. + ExistingEdgeToNewCallee->getContextIds().insert(ContextIdsToMove.begin(), + ContextIdsToMove.end()); + ExistingEdgeToNewCallee->AllocTypes |= Edge->AllocTypes; + assert(Edge->ContextIds == ContextIdsToMove); + Edge->ContextIds.clear(); + Edge->AllocTypes = (uint8_t)AllocationType::None; + Edge->Caller->eraseCalleeEdge(Edge.get()); + } else { + // Otherwise just reconnect Edge to NewCallee. + Edge->Callee = NewCallee; + NewCallee->CallerEdges.push_back(Edge); + // Don't need to update Edge's context ids since we are simply + // reconnecting it. + } + // In either case, need to update the alloc types on New Callee. + NewCallee->AllocTypes |= Edge->AllocTypes; + } else { + // Only moving a subset of Edge's ids. + if (CallerEdgeI) + ++CallerEdgeI; + // Compute the alloc type of the subset of ids being moved. + auto CallerEdgeAllocType = computeAllocType(ContextIdsToMove); + if (ExistingEdgeToNewCallee) { + // Since we already have an edge to NewCallee, simply move the ids + // onto it. + ExistingEdgeToNewCallee->getContextIds().insert(ContextIdsToMove.begin(), + ContextIdsToMove.end()); + ExistingEdgeToNewCallee->AllocTypes |= CallerEdgeAllocType; + } else { + // Otherwise, create a new edge to NewCallee for the ids being moved. 
+ auto NewEdge = std::make_shared<ContextEdge>( + NewCallee, Edge->Caller, CallerEdgeAllocType, ContextIdsToMove); + Edge->Caller->CalleeEdges.push_back(NewEdge); + NewCallee->CallerEdges.push_back(NewEdge); + } + // In either case, need to update the alloc types on NewCallee, and remove + // those ids and update the alloc type on the original Edge. + NewCallee->AllocTypes |= CallerEdgeAllocType; + set_subtract(Edge->ContextIds, ContextIdsToMove); + Edge->AllocTypes = computeAllocType(Edge->ContextIds); + } // Now walk the old callee node's callee edges and move Edge's context ids // over to the corresponding edge into the clone (which is created here if // this is a newly created clone). @@ -2391,7 +2565,7 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>:: // The context ids moving to the new callee are the subset of this edge's // context ids and the context ids on the caller edge being moved. DenseSet<uint32_t> EdgeContextIdsToMove = - set_intersection(OldCalleeEdge->getContextIds(), EdgeContextIds); + set_intersection(OldCalleeEdge->getContextIds(), ContextIdsToMove); set_subtract(OldCalleeEdge->getContextIds(), EdgeContextIdsToMove); OldCalleeEdge->AllocTypes = computeAllocType(OldCalleeEdge->getContextIds()); @@ -2415,6 +2589,12 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>:: NewCallee->CalleeEdges.push_back(NewEdge); NewEdge->Callee->CallerEdges.push_back(NewEdge); } + // Recompute the node alloc type now that its callee edges have been + // updated (since we will compute from those edges). + OldCallee->AllocTypes = OldCallee->computeAllocType(); + // OldCallee alloc type should be None iff its context id set is now empty. + assert((OldCallee->AllocTypes == (uint8_t)AllocationType::None) == + OldCallee->emptyContextIds()); if (VerifyCCG) { checkNode<DerivedCCG, FuncTy, CallTy>(OldCallee, /*CheckEdges=*/false); checkNode<DerivedCCG, FuncTy, CallTy>(NewCallee, /*CheckEdges=*/false); @@ -2428,10 +2608,43 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>:: } template <typename DerivedCCG, typename FuncTy, typename CallTy> +void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>:: + recursivelyRemoveNoneTypeCalleeEdges( + ContextNode *Node, DenseSet<const ContextNode *> &Visited) { + auto Inserted = Visited.insert(Node); + if (!Inserted.second) + return; + + removeNoneTypeCalleeEdges(Node); + + for (auto *Clone : Node->Clones) + recursivelyRemoveNoneTypeCalleeEdges(Clone, Visited); + + // The recursive call may remove some of this Node's caller edges. + // Iterate over a copy and skip any that were removed. + auto CallerEdges = Node->CallerEdges; + for (auto &Edge : CallerEdges) { + // Skip any that have been removed by an earlier recursive call. + if (Edge->Callee == nullptr && Edge->Caller == nullptr) { + assert(!is_contained(Node->CallerEdges, Edge)); + continue; + } + recursivelyRemoveNoneTypeCalleeEdges(Edge->Caller, Visited); + } +} + +template <typename DerivedCCG, typename FuncTy, typename CallTy> void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones() { DenseSet<const ContextNode *> Visited; + for (auto &Entry : AllocationCallToContextNodeMap) { + Visited.clear(); + identifyClones(Entry.second, Visited, Entry.second->getContextIds()); + } + Visited.clear(); for (auto &Entry : AllocationCallToContextNodeMap) - identifyClones(Entry.second, Visited); + recursivelyRemoveNoneTypeCalleeEdges(Entry.second, Visited); + if (VerifyCCG) + check(); } // helper function to check an AllocType is cold or notcold or both. 
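The moveEdgeToExistingCalleeClone rewrite above distinguishes moving a whole edge from peeling off only the context ids that belong to the allocation being cloned: a full move reconnects or merges the existing edge, while a partial move creates (or extends) an edge to the clone carrying just the subset and subtracts those ids from the original edge. A minimal standalone model of the partial case, with plain std::set operations standing in for the DenseSet helpers:

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <set>

struct SimpleEdge {
  std::set<uint32_t> ContextIds;
};

// Move the ids in ToMove from Old onto an edge to the clone. If ToMove covers
// every id on Old, the whole edge would simply be reconnected (not shown);
// this models the subset case: the clone's edge gains the ids, Old loses them.
void moveSubset(SimpleEdge &Old, SimpleEdge &CloneEdge,
                const std::set<uint32_t> &ToMove) {
  // Only ids actually present on Old can move.
  std::set<uint32_t> Moving;
  std::set_intersection(Old.ContextIds.begin(), Old.ContextIds.end(),
                        ToMove.begin(), ToMove.end(),
                        std::inserter(Moving, Moving.end()));
  CloneEdge.ContextIds.insert(Moving.begin(), Moving.end());
  for (uint32_t Id : Moving)
    Old.ContextIds.erase(Id);
  // In the real pass the allocation types of both edges are recomputed from
  // their remaining / gained context ids at this point.
}
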
@@ -2444,9 +2657,10 @@ bool checkColdOrNotCold(uint8_t AllocType) { template <typename DerivedCCG, typename FuncTy, typename CallTy> void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones( - ContextNode *Node, DenseSet<const ContextNode *> &Visited) { + ContextNode *Node, DenseSet<const ContextNode *> &Visited, + const DenseSet<uint32_t> &AllocContextIds) { if (VerifyNodes) - checkNode<DerivedCCG, FuncTy, CallTy>(Node); + checkNode<DerivedCCG, FuncTy, CallTy>(Node, /*CheckEdges=*/false); assert(!Node->CloneOf); // If Node as a null call, then either it wasn't found in the module (regular @@ -2478,7 +2692,7 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones( } // Ignore any caller we previously visited via another edge. if (!Visited.count(Edge->Caller) && !Edge->Caller->CloneOf) { - identifyClones(Edge->Caller, Visited); + identifyClones(Edge->Caller, Visited, AllocContextIds); } } } @@ -2507,8 +2721,16 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones( std::stable_sort(Node->CallerEdges.begin(), Node->CallerEdges.end(), [&](const std::shared_ptr<ContextEdge> &A, const std::shared_ptr<ContextEdge> &B) { - assert(checkColdOrNotCold(A->AllocTypes) && - checkColdOrNotCold(B->AllocTypes)); + // Nodes with non-empty context ids should be sorted before + // those with empty context ids. + if (A->ContextIds.empty()) + // Either B ContextIds are non-empty (in which case we + // should return false because B < A), or B ContextIds + // are empty, in which case they are equal, and we should + // maintain the original relative ordering. + return false; + if (B->ContextIds.empty()) + return true; if (A->AllocTypes == B->AllocTypes) // Use the first context id for each edge as a @@ -2533,13 +2755,23 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones( if (hasSingleAllocType(Node->AllocTypes) || Node->CallerEdges.size() <= 1) break; + // Only need to process the ids along this edge pertaining to the given + // allocation. + auto CallerEdgeContextsForAlloc = + set_intersection(CallerEdge->getContextIds(), AllocContextIds); + if (CallerEdgeContextsForAlloc.empty()) { + ++EI; + continue; + } + auto CallerAllocTypeForAlloc = computeAllocType(CallerEdgeContextsForAlloc); + // Compute the node callee edge alloc types corresponding to the context ids // for this caller edge. std::vector<uint8_t> CalleeEdgeAllocTypesForCallerEdge; CalleeEdgeAllocTypesForCallerEdge.reserve(Node->CalleeEdges.size()); for (auto &CalleeEdge : Node->CalleeEdges) CalleeEdgeAllocTypesForCallerEdge.push_back(intersectAllocTypes( - CalleeEdge->getContextIds(), CallerEdge->getContextIds())); + CalleeEdge->getContextIds(), CallerEdgeContextsForAlloc)); // Don't clone if doing so will not disambiguate any alloc types amongst // caller edges (including the callee edges that would be cloned). @@ -2554,7 +2786,7 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones( // disambiguated by splitting out different context ids. 
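In the identifyClones hunk above, caller edges are now stable-sorted so that edges whose context ids were already moved away (and are therefore empty) come last; the inline comment spells out why returning false when A is empty keeps the ordering strict. A simplified standalone comparator in the same spirit (the real one also orders by allocation type before falling back to the lowest context id):

#include <algorithm>
#include <cstdint>
#include <set>
#include <vector>

struct EdgeInfo {
  std::set<uint32_t> ContextIds;
};

// Used with stable_sort: edges that still carry context ids sort before edges
// whose ids were all moved away. Returning false when A is empty and true
// when only B is empty gives a strict weak ordering; ties keep their original
// relative order because the sort is stable.
bool carriesIdsFirst(const EdgeInfo &A, const EdgeInfo &B) {
  if (A.ContextIds.empty())
    return false; // A is empty: never ordered ahead of B
  if (B.ContextIds.empty())
    return true;  // A non-empty, B empty: A first
  // Both non-empty: order by smallest context id for determinism.
  return *A.ContextIds.begin() < *B.ContextIds.begin();
}

void sortEdges(std::vector<EdgeInfo> &Edges) {
  std::stable_sort(Edges.begin(), Edges.end(), carriesIdsFirst);
}
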
assert(CallerEdge->AllocTypes != (uint8_t)AllocationType::None); assert(Node->AllocTypes != (uint8_t)AllocationType::None); - if (allocTypeToUse(CallerEdge->AllocTypes) == + if (allocTypeToUse(CallerAllocTypeForAlloc) == allocTypeToUse(Node->AllocTypes) && allocTypesMatch<DerivedCCG, FuncTy, CallTy>( CalleeEdgeAllocTypesForCallerEdge, Node->CalleeEdges)) { @@ -2567,7 +2799,7 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones( ContextNode *Clone = nullptr; for (auto *CurClone : Node->Clones) { if (allocTypeToUse(CurClone->AllocTypes) != - allocTypeToUse(CallerEdge->AllocTypes)) + allocTypeToUse(CallerAllocTypeForAlloc)) continue; if (!allocTypesMatch<DerivedCCG, FuncTy, CallTy>( @@ -2579,47 +2811,26 @@ void CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::identifyClones( // The edge iterator is adjusted when we move the CallerEdge to the clone. if (Clone) - moveEdgeToExistingCalleeClone(CallerEdge, Clone, &EI); + moveEdgeToExistingCalleeClone(CallerEdge, Clone, &EI, /*NewClone=*/false, + CallerEdgeContextsForAlloc); else - Clone = moveEdgeToNewCalleeClone(CallerEdge, &EI); + Clone = + moveEdgeToNewCalleeClone(CallerEdge, &EI, CallerEdgeContextsForAlloc); assert(EI == Node->CallerEdges.end() || Node->AllocTypes != (uint8_t)AllocationType::None); // Sanity check that no alloc types on clone or its edges are None. assert(Clone->AllocTypes != (uint8_t)AllocationType::None); - assert(llvm::none_of( - Clone->CallerEdges, [&](const std::shared_ptr<ContextEdge> &E) { - return E->AllocTypes == (uint8_t)AllocationType::None; - })); } - // Cloning may have resulted in some cloned callee edges with type None, - // because they aren't carrying any contexts. Remove those edges. - for (auto *Clone : Node->Clones) { - removeNoneTypeCalleeEdges(Clone); - if (VerifyNodes) - checkNode<DerivedCCG, FuncTy, CallTy>(Clone); - } // We should still have some context ids on the original Node. - assert(!Node->ContextIds.empty()); - - // Remove any callee edges that ended up with alloc type None after creating - // clones and updating callee edges. - removeNoneTypeCalleeEdges(Node); + assert(!Node->emptyContextIds()); // Sanity check that no alloc types on node or edges are None. assert(Node->AllocTypes != (uint8_t)AllocationType::None); - assert(llvm::none_of(Node->CalleeEdges, - [&](const std::shared_ptr<ContextEdge> &E) { - return E->AllocTypes == (uint8_t)AllocationType::None; - })); - assert(llvm::none_of(Node->CallerEdges, - [&](const std::shared_ptr<ContextEdge> &E) { - return E->AllocTypes == (uint8_t)AllocationType::None; - })); if (VerifyNodes) - checkNode<DerivedCCG, FuncTy, CallTy>(Node); + checkNode<DerivedCCG, FuncTy, CallTy>(Node, /*CheckEdges=*/false); } void ModuleCallsiteContextGraph::updateAllocationCall( @@ -2817,7 +3028,7 @@ bool CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::assignFunctions() { // find additional cloning is required. std::deque<ContextNode *> ClonesWorklist; // Ignore original Node if we moved all of its contexts to clones. - if (!Node->ContextIds.empty()) + if (!Node->emptyContextIds()) ClonesWorklist.push_back(Node); ClonesWorklist.insert(ClonesWorklist.end(), Node->Clones.begin(), Node->Clones.end()); @@ -3157,7 +3368,7 @@ bool CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::assignFunctions() { // Skip if either no call to update, or if we ended up with no context ids // (we moved all edges onto other clones). 
- if (!Node->hasCall() || Node->ContextIds.empty()) + if (!Node->hasCall() || Node->emptyContextIds()) return; if (Node->IsAllocation) { @@ -3375,10 +3586,22 @@ bool MemProfContextDisambiguation::applyImport(Module &M) { auto *GVSummary = ImportSummary->findSummaryInModule(TheFnVI, M.getModuleIdentifier()); - if (!GVSummary) - // Must have been imported, use the first summary (might be multiple if - // this was a linkonce_odr). - GVSummary = TheFnVI.getSummaryList().front().get(); + if (!GVSummary) { + // Must have been imported, use the summary which matches the definition。 + // (might be multiple if this was a linkonce_odr). + auto SrcModuleMD = F.getMetadata("thinlto_src_module"); + assert(SrcModuleMD && + "enable-import-metadata is needed to emit thinlto_src_module"); + StringRef SrcModule = + dyn_cast<MDString>(SrcModuleMD->getOperand(0))->getString(); + for (auto &GVS : TheFnVI.getSummaryList()) { + if (GVS->modulePath() == SrcModule) { + GVSummary = GVS.get(); + break; + } + } + assert(GVSummary && GVSummary->modulePath() == SrcModule); + } // If this was an imported alias skip it as we won't have the function // summary, and it should be cloned in the original module. @@ -3471,17 +3694,23 @@ bool MemProfContextDisambiguation::applyImport(Module &M) { auto *MIBMD = cast<const MDNode>(MDOp); MDNode *StackMDNode = getMIBStackNode(MIBMD); assert(StackMDNode); - SmallVector<unsigned> StackIdsFromMetadata; CallStack<MDNode, MDNode::op_iterator> StackContext(StackMDNode); - for (auto ContextIter = - StackContext.beginAfterSharedPrefix(CallsiteContext); + auto ContextIterBegin = + StackContext.beginAfterSharedPrefix(CallsiteContext); + // Skip the checking on the first iteration. + uint64_t LastStackContextId = + (ContextIterBegin != StackContext.end() && + *ContextIterBegin == 0) + ? 1 + : 0; + for (auto ContextIter = ContextIterBegin; ContextIter != StackContext.end(); ++ContextIter) { // If this is a direct recursion, simply skip the duplicate // entries, to be consistent with how the summary ids were // generated during ModuleSummaryAnalysis. - if (!StackIdsFromMetadata.empty() && - StackIdsFromMetadata.back() == *ContextIter) + if (LastStackContextId == *ContextIter) continue; + LastStackContextId = *ContextIter; assert(StackIdIndexIter != MIBIter->StackIdIndices.end()); assert(ImportSummary->getStackIdAtIndex(*StackIdIndexIter) == *ContextIter); @@ -3619,6 +3848,9 @@ bool CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::process() { if (ExportToDot) exportToDot("clonefuncassign"); + if (MemProfReportHintedSizes) + printTotalSizes(errs()); + return Changed; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp index c8c011d94e4a..b50a700e0903 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/MergeFunctions.cpp @@ -256,15 +256,22 @@ private: /// Fill PDIUnrelatedWL with instructions from the entry block that are /// unrelated to parameter related debug info. - void filterInstsUnrelatedToPDI(BasicBlock *GEntryBlock, - std::vector<Instruction *> &PDIUnrelatedWL); + /// \param PDVRUnrelatedWL The equivalent non-intrinsic debug records. + void + filterInstsUnrelatedToPDI(BasicBlock *GEntryBlock, + std::vector<Instruction *> &PDIUnrelatedWL, + std::vector<DbgVariableRecord *> &PDVRUnrelatedWL); /// Erase the rest of the CFG (i.e. barring the entry block). 
void eraseTail(Function *G); /// Erase the instructions in PDIUnrelatedWL as they are unrelated to the /// parameter debug info, from the entry block. - void eraseInstsUnrelatedToPDI(std::vector<Instruction *> &PDIUnrelatedWL); + /// \param PDVRUnrelatedWL contains the equivalent set of non-instruction + /// debug-info records. + void + eraseInstsUnrelatedToPDI(std::vector<Instruction *> &PDIUnrelatedWL, + std::vector<DbgVariableRecord *> &PDVRUnrelatedWL); /// Replace G with a simple tail call to bitcast(F). Also (unless /// MergeFunctionsPDI holds) replace direct uses of G with bitcast(F), @@ -506,7 +513,8 @@ static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) { // Erase the instructions in PDIUnrelatedWL as they are unrelated to the // parameter debug info, from the entry block. void MergeFunctions::eraseInstsUnrelatedToPDI( - std::vector<Instruction *> &PDIUnrelatedWL) { + std::vector<Instruction *> &PDIUnrelatedWL, + std::vector<DbgVariableRecord *> &PDVRUnrelatedWL) { LLVM_DEBUG( dbgs() << " Erasing instructions (in reverse order of appearance in " "entry block) unrelated to parameter debug info from entry " @@ -519,6 +527,16 @@ void MergeFunctions::eraseInstsUnrelatedToPDI( I->eraseFromParent(); PDIUnrelatedWL.pop_back(); } + + while (!PDVRUnrelatedWL.empty()) { + DbgVariableRecord *DVR = PDVRUnrelatedWL.back(); + LLVM_DEBUG(dbgs() << " Deleting DbgVariableRecord "); + LLVM_DEBUG(DVR->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); + DVR->eraseFromParent(); + PDVRUnrelatedWL.pop_back(); + } + LLVM_DEBUG(dbgs() << " } // Done erasing instructions unrelated to parameter " "debug info from entry block. \n"); } @@ -547,75 +565,99 @@ void MergeFunctions::eraseTail(Function *G) { // The rest are unrelated to debug info for the parameters; fill up // PDIUnrelatedWL with such instructions. void MergeFunctions::filterInstsUnrelatedToPDI( - BasicBlock *GEntryBlock, std::vector<Instruction *> &PDIUnrelatedWL) { + BasicBlock *GEntryBlock, std::vector<Instruction *> &PDIUnrelatedWL, + std::vector<DbgVariableRecord *> &PDVRUnrelatedWL) { std::set<Instruction *> PDIRelated; - for (BasicBlock::iterator BI = GEntryBlock->begin(), BIE = GEntryBlock->end(); - BI != BIE; ++BI) { - if (auto *DVI = dyn_cast<DbgValueInst>(&*BI)) { - LLVM_DEBUG(dbgs() << " Deciding: "); - LLVM_DEBUG(BI->print(dbgs())); + std::set<DbgVariableRecord *> PDVRRelated; + + // Work out whether a dbg.value intrinsic or an equivalent DbgVariableRecord + // is a parameter to be preserved. 
+ auto ExamineDbgValue = [](auto *DbgVal, auto &Container) { + LLVM_DEBUG(dbgs() << " Deciding: "); + LLVM_DEBUG(DbgVal->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); + DILocalVariable *DILocVar = DbgVal->getVariable(); + if (DILocVar->isParameter()) { + LLVM_DEBUG(dbgs() << " Include (parameter): "); + LLVM_DEBUG(DbgVal->print(dbgs())); LLVM_DEBUG(dbgs() << "\n"); - DILocalVariable *DILocVar = DVI->getVariable(); - if (DILocVar->isParameter()) { - LLVM_DEBUG(dbgs() << " Include (parameter): "); - LLVM_DEBUG(BI->print(dbgs())); - LLVM_DEBUG(dbgs() << "\n"); - PDIRelated.insert(&*BI); - } else { - LLVM_DEBUG(dbgs() << " Delete (!parameter): "); - LLVM_DEBUG(BI->print(dbgs())); - LLVM_DEBUG(dbgs() << "\n"); - } - } else if (auto *DDI = dyn_cast<DbgDeclareInst>(&*BI)) { - LLVM_DEBUG(dbgs() << " Deciding: "); - LLVM_DEBUG(BI->print(dbgs())); + Container.insert(DbgVal); + } else { + LLVM_DEBUG(dbgs() << " Delete (!parameter): "); + LLVM_DEBUG(DbgVal->print(dbgs())); LLVM_DEBUG(dbgs() << "\n"); - DILocalVariable *DILocVar = DDI->getVariable(); - if (DILocVar->isParameter()) { - LLVM_DEBUG(dbgs() << " Parameter: "); - LLVM_DEBUG(DILocVar->print(dbgs())); - AllocaInst *AI = dyn_cast_or_null<AllocaInst>(DDI->getAddress()); - if (AI) { - LLVM_DEBUG(dbgs() << " Processing alloca users: "); - LLVM_DEBUG(dbgs() << "\n"); - for (User *U : AI->users()) { - if (StoreInst *SI = dyn_cast<StoreInst>(U)) { - if (Value *Arg = SI->getValueOperand()) { - if (isa<Argument>(Arg)) { - LLVM_DEBUG(dbgs() << " Include: "); - LLVM_DEBUG(AI->print(dbgs())); - LLVM_DEBUG(dbgs() << "\n"); - PDIRelated.insert(AI); - LLVM_DEBUG(dbgs() << " Include (parameter): "); - LLVM_DEBUG(SI->print(dbgs())); - LLVM_DEBUG(dbgs() << "\n"); - PDIRelated.insert(SI); - LLVM_DEBUG(dbgs() << " Include: "); - LLVM_DEBUG(BI->print(dbgs())); - LLVM_DEBUG(dbgs() << "\n"); - PDIRelated.insert(&*BI); - } else { - LLVM_DEBUG(dbgs() << " Delete (!parameter): "); - LLVM_DEBUG(SI->print(dbgs())); - LLVM_DEBUG(dbgs() << "\n"); - } + } + }; + + auto ExamineDbgDeclare = [&PDIRelated](auto *DbgDecl, auto &Container) { + LLVM_DEBUG(dbgs() << " Deciding: "); + LLVM_DEBUG(DbgDecl->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); + DILocalVariable *DILocVar = DbgDecl->getVariable(); + if (DILocVar->isParameter()) { + LLVM_DEBUG(dbgs() << " Parameter: "); + LLVM_DEBUG(DILocVar->print(dbgs())); + AllocaInst *AI = dyn_cast_or_null<AllocaInst>(DbgDecl->getAddress()); + if (AI) { + LLVM_DEBUG(dbgs() << " Processing alloca users: "); + LLVM_DEBUG(dbgs() << "\n"); + for (User *U : AI->users()) { + if (StoreInst *SI = dyn_cast<StoreInst>(U)) { + if (Value *Arg = SI->getValueOperand()) { + if (isa<Argument>(Arg)) { + LLVM_DEBUG(dbgs() << " Include: "); + LLVM_DEBUG(AI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); + PDIRelated.insert(AI); + LLVM_DEBUG(dbgs() << " Include (parameter): "); + LLVM_DEBUG(SI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); + PDIRelated.insert(SI); + LLVM_DEBUG(dbgs() << " Include: "); + LLVM_DEBUG(DbgDecl->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); + Container.insert(DbgDecl); + } else { + LLVM_DEBUG(dbgs() << " Delete (!parameter): "); + LLVM_DEBUG(SI->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); } - } else { - LLVM_DEBUG(dbgs() << " Defer: "); - LLVM_DEBUG(U->print(dbgs())); - LLVM_DEBUG(dbgs() << "\n"); } + } else { + LLVM_DEBUG(dbgs() << " Defer: "); + LLVM_DEBUG(U->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); } - } else { - LLVM_DEBUG(dbgs() << " Delete (alloca NULL): "); - LLVM_DEBUG(BI->print(dbgs())); - LLVM_DEBUG(dbgs() << "\n"); } 
} else { - LLVM_DEBUG(dbgs() << " Delete (!parameter): "); - LLVM_DEBUG(BI->print(dbgs())); + LLVM_DEBUG(dbgs() << " Delete (alloca NULL): "); + LLVM_DEBUG(DbgDecl->print(dbgs())); LLVM_DEBUG(dbgs() << "\n"); } + } else { + LLVM_DEBUG(dbgs() << " Delete (!parameter): "); + LLVM_DEBUG(DbgDecl->print(dbgs())); + LLVM_DEBUG(dbgs() << "\n"); + } + }; + + for (BasicBlock::iterator BI = GEntryBlock->begin(), BIE = GEntryBlock->end(); + BI != BIE; ++BI) { + // Examine DbgVariableRecords as they happen "before" the instruction. Are + // they connected to parameters? + for (DbgVariableRecord &DVR : filterDbgVars(BI->getDbgRecordRange())) { + if (DVR.isDbgValue() || DVR.isDbgAssign()) { + ExamineDbgValue(&DVR, PDVRRelated); + } else { + assert(DVR.isDbgDeclare()); + ExamineDbgDeclare(&DVR, PDVRRelated); + } + } + + if (auto *DVI = dyn_cast<DbgValueInst>(&*BI)) { + ExamineDbgValue(DVI, PDIRelated); + } else if (auto *DDI = dyn_cast<DbgDeclareInst>(&*BI)) { + ExamineDbgDeclare(DDI, PDIRelated); } else if (BI->isTerminator() && &*BI == GEntryBlock->getTerminator()) { LLVM_DEBUG(dbgs() << " Will Include Terminator: "); LLVM_DEBUG(BI->print(dbgs())); @@ -630,17 +672,25 @@ void MergeFunctions::filterInstsUnrelatedToPDI( LLVM_DEBUG( dbgs() << " Report parameter debug info related/related instructions: {\n"); - for (Instruction &I : *GEntryBlock) { - if (PDIRelated.find(&I) == PDIRelated.end()) { + + auto IsPDIRelated = [](auto *Rec, auto &Container, auto &UnrelatedCont) { + if (Container.find(Rec) == Container.end()) { LLVM_DEBUG(dbgs() << " !PDIRelated: "); - LLVM_DEBUG(I.print(dbgs())); + LLVM_DEBUG(Rec->print(dbgs())); LLVM_DEBUG(dbgs() << "\n"); - PDIUnrelatedWL.push_back(&I); + UnrelatedCont.push_back(Rec); } else { LLVM_DEBUG(dbgs() << " PDIRelated: "); - LLVM_DEBUG(I.print(dbgs())); + LLVM_DEBUG(Rec->print(dbgs())); LLVM_DEBUG(dbgs() << "\n"); } + }; + + // Collect the set of unrelated instructions and debug records. + for (Instruction &I : *GEntryBlock) { + for (DbgVariableRecord &DVR : filterDbgVars(I.getDbgRecordRange())) + IsPDIRelated(&DVR, PDVRRelated, PDVRUnrelatedWL); + IsPDIRelated(&I, PDIRelated, PDIUnrelatedWL); } LLVM_DEBUG(dbgs() << " }\n"); } @@ -662,11 +712,13 @@ static bool canCreateThunkFor(Function *F) { return true; } -/// Copy metadata from one function to another. -static void copyMetadataIfPresent(Function *From, Function *To, StringRef Key) { - if (MDNode *MD = From->getMetadata(Key)) { - To->setMetadata(Key, MD); - } +/// Copy all metadata of a specific kind from one function to another. +static void copyMetadataIfPresent(Function *From, Function *To, + StringRef Kind) { + SmallVector<MDNode *, 4> MDs; + From->getMetadata(Kind, MDs); + for (MDNode *MD : MDs) + To->addMetadata(Kind, *MD); } // Replace G with a simple tail call to bitcast(F). 
Also (unless @@ -680,6 +732,7 @@ static void copyMetadataIfPresent(Function *From, Function *To, StringRef Key) { void MergeFunctions::writeThunk(Function *F, Function *G) { BasicBlock *GEntryBlock = nullptr; std::vector<Instruction *> PDIUnrelatedWL; + std::vector<DbgVariableRecord *> PDVRUnrelatedWL; BasicBlock *BB = nullptr; Function *NewG = nullptr; if (MergeFunctionsPDI) { @@ -691,13 +744,14 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { dbgs() << "writeThunk: (MergeFunctionsPDI) filter parameter related " "debug info for " << G->getName() << "() {\n"); - filterInstsUnrelatedToPDI(GEntryBlock, PDIUnrelatedWL); + filterInstsUnrelatedToPDI(GEntryBlock, PDIUnrelatedWL, PDVRUnrelatedWL); GEntryBlock->getTerminator()->eraseFromParent(); BB = GEntryBlock; } else { NewG = Function::Create(G->getFunctionType(), G->getLinkage(), G->getAddressSpace(), "", G->getParent()); NewG->setComdat(G->getComdat()); + NewG->IsNewDbgInfoFormat = G->IsNewDbgInfoFormat; BB = BasicBlock::Create(F->getContext(), "", NewG); } @@ -740,7 +794,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { << G->getName() << "()\n"); } eraseTail(G); - eraseInstsUnrelatedToPDI(PDIUnrelatedWL); + eraseInstsUnrelatedToPDI(PDIUnrelatedWL, PDVRUnrelatedWL); LLVM_DEBUG( dbgs() << "} // End of parameter related debug info filtering for: " << G->getName() << "()\n"); @@ -825,6 +879,7 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { F->getAddressSpace(), "", F->getParent()); NewF->copyAttributesFrom(F); NewF->takeName(F); + NewF->IsNewDbgInfoFormat = F->IsNewDbgInfoFormat; // Ensure CFI type metadata is propagated to the new function. copyMetadataIfPresent(F, NewF, "type"); copyMetadataIfPresent(F, NewF, "kcfi_type"); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index 4176d561363f..b290651d66c5 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -880,7 +880,7 @@ private: if (BB != Before.getParent()) return false; - const DataLayout &DL = Array.getModule()->getDataLayout(); + const DataLayout &DL = Array.getDataLayout(); const unsigned int PointerSize = DL.getPointerSize(); for (Instruction &I : *BB) { @@ -1146,17 +1146,18 @@ private: const DataLayout &DL = M.getDataLayout(); AllocaInst *AllocaI = new AllocaInst( I.getType(), DL.getAllocaAddrSpace(), nullptr, - I.getName() + ".seq.output.alloc", &OuterFn->front().front()); + I.getName() + ".seq.output.alloc", OuterFn->front().begin()); // Emit a store instruction in the sequential BB to update the // value. - new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()); + new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator()); // Emit a load instruction and replace the use of the output value // with it. 
for (Instruction *UsrI : OutsideUsers) { - LoadInst *LoadI = new LoadInst( - I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI); + LoadInst *LoadI = new LoadInst(I.getType(), AllocaI, + I.getName() + ".seq.output.load", + UsrI->getIterator()); UsrI->replaceUsesOfWith(&I, LoadI); } } @@ -1261,7 +1262,8 @@ private: ++U) Args.push_back(CI->getArgOperand(U)); - CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI); + CallInst *NewCI = + CallInst::Create(FT, Callee, Args, "", CI->getIterator()); if (CI->getDebugLoc()) NewCI->setDebugLoc(CI->getDebugLoc()); @@ -1451,7 +1453,6 @@ private: }; emitRemark<OptimizationRemark>(CI, "OMP160", Remark); - CGUpdater.removeCallSite(*CI); CI->eraseFromParent(); Changed = true; ++NumOpenMPParallelRegionsDeleted; @@ -1471,7 +1472,6 @@ private: OMPRTL_omp_get_num_threads, OMPRTL_omp_in_parallel, OMPRTL_omp_get_cancellation, - OMPRTL_omp_get_thread_limit, OMPRTL_omp_get_supported_active_levels, OMPRTL_omp_get_level, OMPRTL_omp_get_ancestor_thread_num, @@ -1666,21 +1666,21 @@ private: BP->print(Printer); Printer << Separator; } - LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n"); + LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n"); ValuesStr.clear(); for (auto *P : OAs[1].StoredValues) { P->print(Printer); Printer << Separator; } - LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n"); + LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << ValuesStr << "\n"); ValuesStr.clear(); for (auto *S : OAs[2].StoredValues) { S->print(Printer); Printer << Separator; } - LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n"); + LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n"); } /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be @@ -1740,8 +1740,8 @@ private: Args.push_back(Arg.get()); Args.push_back(Handle); - CallInst *IssueCallsite = - CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall); + CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"", + RuntimeCall.getIterator()); OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite); RuntimeCall.eraseFromParent(); @@ -1756,7 +1756,7 @@ private: Handle // handle to wait on. }; CallInst *WaitCallsite = CallInst::Create( - WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint); + WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator()); OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite); return true; @@ -1894,7 +1894,6 @@ private: else emitRemark<OptimizationRemark>(&F, "OMP170", Remark); - CGUpdater.removeCallSite(*CI); CI->replaceAllUsesWith(ReplVal); CI->eraseFromParent(); ++NumOpenMPRuntimeCallsDeduplicated; @@ -4026,11 +4025,12 @@ struct AAKernelInfoFunction : AAKernelInfo { static_cast<unsigned>(AddressSpace::Shared)); // Emit a store instruction to update the value. - new StoreInst(&I, SharedMem, RegionEndBB->getTerminator()); + new StoreInst(&I, SharedMem, + RegionEndBB->getTerminator()->getIterator()); - LoadInst *LoadI = new LoadInst(I.getType(), SharedMem, - I.getName() + ".guarded.output.load", - RegionBarrierBB->getTerminator()); + LoadInst *LoadI = new LoadInst( + I.getType(), SharedMem, I.getName() + ".guarded.output.load", + RegionBarrierBB->getTerminator()->getIterator()); // Emit a load instruction and replace uses of the output value. for (Use *U : OutsideUses) @@ -4083,8 +4083,9 @@ struct AAKernelInfoFunction : AAKernelInfo { // Second barrier ensures workers have read broadcast values. 
if (HasBroadcastValues) { - CallInst *Barrier = CallInst::Create(BarrierFn, {Ident, Tid}, "", - RegionBarrierBB->getTerminator()); + CallInst *Barrier = + CallInst::Create(BarrierFn, {Ident, Tid}, "", + RegionBarrierBB->getTerminator()->getIterator()); Barrier->setDebugLoc(DL); OMPInfoCache.setCallingConvention(BarrierFn, Barrier); } @@ -4235,7 +4236,7 @@ struct AAKernelInfoFunction : AAKernelInfo { ORA << "Value has potential side effects preventing SPMD-mode " "execution"; if (isa<CallBase>(NonCompatibleI)) { - ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to " + ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to " "the called function to override"; } return ORA << "."; @@ -4377,7 +4378,7 @@ struct AAKernelInfoFunction : AAKernelInfo { continue; auto Remark = [&](OptimizationRemarkAnalysis ORA) { return ORA << "Call may contain unknown parallel regions. Use " - << "`__attribute__((assume(\"omp_no_parallelism\")))` to " + << "`[[omp::assume(\"omp_no_parallelism\")]]` to " "override."; }; A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB, @@ -4489,7 +4490,7 @@ struct AAKernelInfoFunction : AAKernelInfo { Type *VoidPtrTy = PointerType::getUnqual(Ctx); Instruction *WorkFnAI = new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr, - "worker.work_fn.addr", &Kernel->getEntryBlock().front()); + "worker.work_fn.addr", Kernel->getEntryBlock().begin()); WorkFnAI->setDebugLoc(DLoc); OMPInfoCache.OMPBuilder.updateToLocation( diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp index aa4f205ec5bd..3ca095e1520f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/PartialInlining.cpp @@ -764,7 +764,7 @@ bool PartialInlinerImpl::shouldPartialInline( }); return false; } - const DataLayout &DL = Caller->getParent()->getDataLayout(); + const DataLayout &DL = Caller->getDataLayout(); // The savings of eliminating the call: int NonWeightedSavings = getCallsiteCost(CalleeTTI, CB, DL); @@ -804,7 +804,7 @@ InstructionCost PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB, TargetTransformInfo *TTI) { InstructionCost InlineCost = 0; - const DataLayout &DL = BB->getParent()->getParent()->getDataLayout(); + const DataLayout &DL = BB->getDataLayout(); int InstrCost = InlineConstants::getInstrCost(); for (Instruction &I : BB->instructionsWithoutDebug()) { // Skip free instructions. diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp index b1f9b827dcba..94ae511b2e4a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SCCP.cpp @@ -144,9 +144,8 @@ static bool runIPSCCP( // Assume the function is called. Solver.markBlockExecutable(&F.front()); - // Assume nothing about the incoming arguments. for (Argument &AI : F.args()) - Solver.markOverdefined(&AI); + Solver.trackValueOfArgument(&AI); } // Determine if we can track any of the module's global variables. If so, add @@ -282,32 +281,21 @@ static bool runIPSCCP( Function *F = I.first; const ValueLatticeElement &ReturnValue = I.second; - // If there is a known constant range for the return value, add !range - // metadata to the function's call sites. + // If there is a known constant range for the return value, add range + // attribute to the return value. 
if (ReturnValue.isConstantRange() && !ReturnValue.getConstantRange().isSingleElement()) { // Do not add range metadata if the return value may include undef. if (ReturnValue.isConstantRangeIncludingUndef()) continue; + // Do not touch existing attribute for now. + // TODO: We should be able to take the intersection of the existing + // attribute and the inferred range. + if (F->hasRetAttribute(Attribute::Range)) + continue; auto &CR = ReturnValue.getConstantRange(); - for (User *User : F->users()) { - auto *CB = dyn_cast<CallBase>(User); - if (!CB || CB->getCalledFunction() != F) - continue; - - // Do not touch existing metadata for now. - // TODO: We should be able to take the intersection of the existing - // metadata and the inferred range. - if (CB->getMetadata(LLVMContext::MD_range)) - continue; - - LLVMContext &Context = CB->getParent()->getContext(); - Metadata *RangeMD[] = { - ConstantAsMetadata::get(ConstantInt::get(Context, CR.getLower())), - ConstantAsMetadata::get(ConstantInt::get(Context, CR.getUpper()))}; - CB->setMetadata(LLVMContext::MD_range, MDNode::get(Context, RangeMD)); - } + F->addRangeRetAttr(CR); continue; } if (F->getReturnType()->isVoidTy()) @@ -328,7 +316,7 @@ static bool runIPSCCP( SmallSetVector<Function *, 8> FuncZappedReturn; for (ReturnInst *RI : ReturnsToZap) { Function *F = RI->getParent()->getParent(); - RI->setOperand(0, UndefValue::get(F->getReturnType())); + RI->setOperand(0, PoisonValue::get(F->getReturnType())); // Record all functions that are zapped. FuncZappedReturn.insert(F); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp index f7a54d428f20..f878e3e591a0 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleContextTracker.cpp @@ -624,7 +624,7 @@ void SampleContextTracker::createContextLessProfileMap( FunctionSamples *FProfile = Node->getFunctionSamples(); // Profile's context can be empty, use ContextNode's func name. 
if (FProfile) - ContextLessProfiles.Create(Node->getFuncName()).merge(*FProfile); + ContextLessProfiles.create(Node->getFuncName()).merge(*FProfile); } } } // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp index 2fd8668d15e2..6af284d513ef 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfile.cpp @@ -71,6 +71,7 @@ #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/ProfiledCallGraph.h" #include "llvm/Transforms/IPO/SampleContextTracker.h" +#include "llvm/Transforms/IPO/SampleProfileMatcher.h" #include "llvm/Transforms/IPO/SampleProfileProbe.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/CallPromotionUtils.h" @@ -129,16 +130,20 @@ static cl::opt<std::string> SampleProfileRemappingFile( "sample-profile-remapping-file", cl::init(""), cl::value_desc("filename"), cl::desc("Profile remapping file loaded by -sample-profile"), cl::Hidden); -static cl::opt<bool> SalvageStaleProfile( +cl::opt<bool> SalvageStaleProfile( "salvage-stale-profile", cl::Hidden, cl::init(false), cl::desc("Salvage stale profile by fuzzy matching and use the remapped " "location for sample profile query.")); +cl::opt<bool> + SalvageUnusedProfile("salvage-unused-profile", cl::Hidden, cl::init(false), + cl::desc("Salvage unused profile by matching with new " + "functions on call graph.")); -static cl::opt<bool> ReportProfileStaleness( +cl::opt<bool> ReportProfileStaleness( "report-profile-staleness", cl::Hidden, cl::init(false), cl::desc("Compute and report stale profile statistical metrics.")); -static cl::opt<bool> PersistProfileStaleness( +cl::opt<bool> PersistProfileStaleness( "persist-profile-staleness", cl::Hidden, cl::init(false), cl::desc("Compute stale profile statistical metrics and write it into the " "native object file(.llvm_stats section).")); @@ -234,22 +239,38 @@ static cl::opt<unsigned> ProfileICPRelativeHotnessSkip( cl::desc( "Skip relative hotness check for ICP up to given number of targets.")); +static cl::opt<unsigned> HotFuncCutoffForStalenessError( + "hot-func-cutoff-for-staleness-error", cl::Hidden, cl::init(800000), + cl::desc("A function is considered hot for staleness error check if its " + "total sample count is above the specified percentile")); + +static cl::opt<unsigned> MinfuncsForStalenessError( + "min-functions-for-staleness-error", cl::Hidden, cl::init(50), + cl::desc("Skip the check if the number of hot functions is smaller than " + "the specified number.")); + +static cl::opt<unsigned> PrecentMismatchForStalenessError( + "precent-mismatch-for-staleness-error", cl::Hidden, cl::init(80), + cl::desc("Reject the profile if the mismatch percent is higher than the " + "given number.")); + static cl::opt<bool> CallsitePrioritizedInline( "sample-profile-prioritized-inline", cl::Hidden, - cl::desc("Use call site prioritized inlining for sample profile loader." 
"Currently only CSSPGO is supported.")); static cl::opt<bool> UsePreInlinerDecision( "sample-profile-use-preinliner", cl::Hidden, - cl::desc("Use the preinliner decisions stored in profile context.")); static cl::opt<bool> AllowRecursiveInline( "sample-profile-recursive-inline", cl::Hidden, - cl::desc("Allow sample loader inliner to inline recursive calls.")); +static cl::opt<bool> RemoveProbeAfterProfileAnnotation( + "sample-profile-remove-probe", cl::Hidden, cl::init(false), + cl::desc("Remove pseudo-probe after sample profile annotation.")); + static cl::opt<std::string> ProfileInlineReplayFile( "sample-profile-inline-replay", cl::init(""), cl::value_desc("filename"), cl::desc( @@ -418,7 +439,10 @@ struct CandidateComparer { const FunctionSamples *LCS = LHS.CalleeSamples; const FunctionSamples *RCS = RHS.CalleeSamples; - assert(LCS && RCS && "Expect non-null FunctionSamples"); + // In inline replay mode, CalleeSamples may be null and the order doesn't + // matter. + if (!LCS || !RCS) + return LCS; // Tie breaker using number of samples try to favor smaller functions first if (LCS->getBodySamples().size() != RCS->getBodySamples().size()) @@ -433,79 +457,6 @@ using CandidateQueue = PriorityQueue<InlineCandidate, std::vector<InlineCandidate>, CandidateComparer>; -// Sample profile matching - fuzzy match. -class SampleProfileMatcher { - Module &M; - SampleProfileReader &Reader; - const PseudoProbeManager *ProbeManager; - SampleProfileMap FlattenedProfiles; - // For each function, the matcher generates a map, of which each entry is a - // mapping from the source location of current build to the source location in - // the profile. - StringMap<LocToLocMap> FuncMappings; - - // Profile mismatching statstics. - uint64_t TotalProfiledCallsites = 0; - uint64_t NumMismatchedCallsites = 0; - uint64_t MismatchedCallsiteSamples = 0; - uint64_t TotalCallsiteSamples = 0; - uint64_t TotalProfiledFunc = 0; - uint64_t NumMismatchedFuncHash = 0; - uint64_t MismatchedFuncHashSamples = 0; - uint64_t TotalFuncHashSamples = 0; - - // A dummy name for unknown indirect callee, used to differentiate from a - // non-call instruction that also has an empty callee name. 
- static constexpr const char *UnknownIndirectCallee = - "unknown.indirect.callee"; - -public: - SampleProfileMatcher(Module &M, SampleProfileReader &Reader, - const PseudoProbeManager *ProbeManager) - : M(M), Reader(Reader), ProbeManager(ProbeManager){}; - void runOnModule(); - -private: - FunctionSamples *getFlattenedSamplesFor(const Function &F) { - StringRef CanonFName = FunctionSamples::getCanonicalFnName(F); - auto It = FlattenedProfiles.find(FunctionId(CanonFName)); - if (It != FlattenedProfiles.end()) - return &It->second; - return nullptr; - } - void runOnFunction(const Function &F); - void findIRAnchors(const Function &F, - std::map<LineLocation, StringRef> &IRAnchors); - void findProfileAnchors( - const FunctionSamples &FS, - std::map<LineLocation, std::unordered_set<FunctionId>> - &ProfileAnchors); - void countMismatchedSamples(const FunctionSamples &FS); - void countProfileMismatches( - const Function &F, const FunctionSamples &FS, - const std::map<LineLocation, StringRef> &IRAnchors, - const std::map<LineLocation, std::unordered_set<FunctionId>> - &ProfileAnchors); - void countProfileCallsiteMismatches( - const FunctionSamples &FS, - const std::map<LineLocation, StringRef> &IRAnchors, - const std::map<LineLocation, std::unordered_set<FunctionId>> - &ProfileAnchors, - uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites); - LocToLocMap &getIRToProfileLocationMap(const Function &F) { - auto Ret = FuncMappings.try_emplace( - FunctionSamples::getCanonicalFnName(F.getName()), LocToLocMap()); - return Ret.first->second; - } - void distributeIRToProfileLocationMap(); - void distributeIRToProfileLocationMap(FunctionSamples &FS); - void runStaleProfileMatching( - const Function &F, const std::map<LineLocation, StringRef> &IRAnchors, - const std::map<LineLocation, std::unordered_set<FunctionId>> - &ProfileAnchors, - LocToLocMap &IRToProfileLocationMap); -}; - /// Sample profile pass. /// /// This pass reads profile data from the file specified by @@ -518,12 +469,13 @@ public: IntrusiveRefCntPtr<vfs::FileSystem> FS, std::function<AssumptionCache &(Function &)> GetAssumptionCache, std::function<TargetTransformInfo &(Function &)> GetTargetTransformInfo, - std::function<const TargetLibraryInfo &(Function &)> GetTLI) + std::function<const TargetLibraryInfo &(Function &)> GetTLI, + LazyCallGraph &CG) : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName), std::move(FS)), GetAC(std::move(GetAssumptionCache)), GetTTI(std::move(GetTargetTransformInfo)), GetTLI(std::move(GetTLI)), - LTOPhase(LTOPhase), + CG(CG), LTOPhase(LTOPhase), AnnotatedPassName(AnnotateSampleProfileInlinePhase ? llvm::AnnotateInlinePassName(InlineContext{ LTOPhase, InlinePass::SampleProfileInliner}) @@ -531,7 +483,7 @@ public: bool doInitialization(Module &M, FunctionAnalysisManager *FAM = nullptr); bool runOnModule(Module &M, ModuleAnalysisManager *AM, - ProfileSummaryInfo *_PSI, LazyCallGraph &CG); + ProfileSummaryInfo *_PSI); protected: bool runOnFunction(Function &F, ModuleAnalysisManager *AM); @@ -573,6 +525,9 @@ protected: std::vector<Function *> buildFunctionOrder(Module &M, LazyCallGraph &CG); std::unique_ptr<ProfiledCallGraph> buildProfiledCallGraph(Module &M); void generateMDProfMetadata(Function &F); + bool rejectHighStalenessProfile(Module &M, ProfileSummaryInfo *PSI, + const SampleProfileMap &Profiles); + void removePseudoProbeInsts(Module &M); /// Map from function name to Function *. Used to find the function from /// the function name. 
If the function name contains suffix, additional @@ -580,9 +535,14 @@ protected: /// is one-to-one mapping. HashKeyMap<std::unordered_map, FunctionId, Function *> SymbolMap; + /// Map from function name to profile name generated by call-graph based + /// profile fuzzy matching(--salvage-unused-profile). + HashKeyMap<std::unordered_map, FunctionId, FunctionId> FuncNameToProfNameMap; + std::function<AssumptionCache &(Function &)> GetAC; std::function<TargetTransformInfo &(Function &)> GetTTI; std::function<const TargetLibraryInfo &(Function &)> GetTLI; + LazyCallGraph &CG; /// Profile tracker for different context. std::unique_ptr<SampleContextTracker> ContextTracker; @@ -597,7 +557,7 @@ protected: /// Profle Symbol list tells whether a function name appears in the binary /// used to generate the current profile. - std::unique_ptr<ProfileSymbolList> PSL; + std::shared_ptr<ProfileSymbolList> PSL; /// Total number of samples collected in this profile. /// @@ -749,7 +709,8 @@ SampleProfileLoader::findCalleeFunctionSamples(const CallBase &Inst) const { return nullptr; return FS->findFunctionSamplesAt(FunctionSamples::getCallSiteIdentifier(DIL), - CalleeName, Reader->getRemapper()); + CalleeName, Reader->getRemapper(), + &FuncNameToProfNameMap); } /// Returns a vector of FunctionSamples that are the indirect call targets @@ -827,8 +788,8 @@ SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { if (FunctionSamples::ProfileIsCS) it.first->second = ContextTracker->getContextSamplesFor(DIL); else - it.first->second = - Samples->findFunctionSamples(DIL, Reader->getRemapper()); + it.first->second = Samples->findFunctionSamples( + DIL, Reader->getRemapper(), &FuncNameToProfNameMap); } return it.first->second; } @@ -841,27 +802,23 @@ SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { /// NOMORE_ICP_MAGICNUM count values in the value profile of \p Inst, we /// cannot promote for \p Inst anymore. static bool doesHistoryAllowICP(const Instruction &Inst, StringRef Candidate) { - uint32_t NumVals = 0; uint64_t TotalCount = 0; - std::unique_ptr<InstrProfValueData[]> ValueData = - std::make_unique<InstrProfValueData[]>(MaxNumPromotions); - bool Valid = - getValueProfDataFromInst(Inst, IPVK_IndirectCallTarget, MaxNumPromotions, - ValueData.get(), NumVals, TotalCount, true); + auto ValueData = getValueProfDataFromInst(Inst, IPVK_IndirectCallTarget, + MaxNumPromotions, TotalCount, true); // No valid value profile so no promoted targets have been recorded // before. Ok to do ICP. - if (!Valid) + if (ValueData.empty()) return true; unsigned NumPromoted = 0; - for (uint32_t I = 0; I < NumVals; I++) { - if (ValueData[I].Count != NOMORE_ICP_MAGICNUM) + for (const auto &V : ValueData) { + if (V.Count != NOMORE_ICP_MAGICNUM) continue; // If the promotion candidate has NOMORE_ICP_MAGICNUM count in the // metadata, it means the candidate has been promoted for this // indirect call. - if (ValueData[I].Value == Function::getGUID(Candidate)) + if (V.Value == Function::getGUID(Candidate)) return false; NumPromoted++; // If already have MaxNumPromotions promotion, don't do it anymore. @@ -887,14 +844,10 @@ updateIDTMetaData(Instruction &Inst, // `MaxNumPromotions` inside it. if (MaxNumPromotions == 0) return; - uint32_t NumVals = 0; // OldSum is the existing total count in the value profile data. 
uint64_t OldSum = 0; - std::unique_ptr<InstrProfValueData[]> ValueData = - std::make_unique<InstrProfValueData[]>(MaxNumPromotions); - bool Valid = - getValueProfDataFromInst(Inst, IPVK_IndirectCallTarget, MaxNumPromotions, - ValueData.get(), NumVals, OldSum, true); + auto ValueData = getValueProfDataFromInst(Inst, IPVK_IndirectCallTarget, + MaxNumPromotions, OldSum, true); DenseMap<uint64_t, uint64_t> ValueCountMap; if (Sum == 0) { @@ -903,10 +856,8 @@ updateIDTMetaData(Instruction &Inst, "If sum is 0, assume only one element in CallTargets " "with count being NOMORE_ICP_MAGICNUM"); // Initialize ValueCountMap with existing value profile data. - if (Valid) { - for (uint32_t I = 0; I < NumVals; I++) - ValueCountMap[ValueData[I].Value] = ValueData[I].Count; - } + for (const auto &V : ValueData) + ValueCountMap[V.Value] = V.Count; auto Pair = ValueCountMap.try_emplace(CallTargets[0].Value, CallTargets[0].Count); // If the target already exists in value profile, decrease the total @@ -919,11 +870,9 @@ updateIDTMetaData(Instruction &Inst, } else { // Initialize ValueCountMap with existing NOMORE_ICP_MAGICNUM // counts in the value profile. - if (Valid) { - for (uint32_t I = 0; I < NumVals; I++) { - if (ValueData[I].Count == NOMORE_ICP_MAGICNUM) - ValueCountMap[ValueData[I].Value] = ValueData[I].Count; - } + for (const auto &V : ValueData) { + if (V.Count == NOMORE_ICP_MAGICNUM) + ValueCountMap[V.Value] = V.Count; } for (const auto &Data : CallTargets) { @@ -1106,6 +1055,9 @@ void SampleProfileLoader::findExternalInlineCandidate( // For AutoFDO profile, retrieve candidate profiles by walking over // the nested inlinee profiles. if (!FunctionSamples::ProfileIsCS) { + // Set threshold to zero to honor pre-inliner decision. + if (UsePreInlinerDecision) + Threshold = 0; Samples->findInlinedFunctions(InlinedGUIDs, SymbolMap, Threshold); return; } @@ -1129,7 +1081,7 @@ void SampleProfileLoader::findExternalInlineCandidate( CalleeSample->getContext().hasAttribute(ContextShouldBeInlined); if (!PreInline && CalleeSample->getHeadSamplesEstimate() < Threshold) continue; - + Function *Func = SymbolMap.lookup(CalleeSample->getFunction()); // Add to the import list only when it's defined out of module. if (!Func || Func->isDeclaration()) @@ -1441,10 +1393,11 @@ SampleProfileLoader::shouldInlineCandidate(InlineCandidate &Candidate) { return InlineCost::getAlways("preinliner"); } - // For old FDO inliner, we inline the call site as long as cost is not - // "Never". The cost-benefit check is done earlier. + // For old FDO inliner, we inline the call site if it is below hot threshold, + // even if the function is hot based on sample profile data. This is to + // prevent huge functions from being inlined. if (!CallsitePrioritizedInline) { - return InlineCost::get(Cost.getCost(), INT_MAX); + return InlineCost::get(Cost.getCost(), SampleHotCallSiteThreshold); } // Otherwise only use the cost from call analyzer, but overwite threshold with @@ -1633,7 +1586,7 @@ void SampleProfileLoader::promoteMergeNotInlinedContextSamples( FunctionId(FunctionSamples::getCanonicalFnName(Callee->getName()))]; OutlineFS->merge(*FS, 1); // Set outlined profile to be synthetic to not bias the inliner. 
- OutlineFS->SetContextSynthetic(); + OutlineFS->setContextSynthetic(); } } else { auto pair = @@ -1647,7 +1600,7 @@ void SampleProfileLoader::promoteMergeNotInlinedContextSamples( static SmallVector<InstrProfValueData, 2> GetSortedValueDataFromCallTargets(const SampleRecord::CallTargetMap &M) { SmallVector<InstrProfValueData, 2> R; - for (const auto &I : SampleRecord::SortCallTargets(M)) { + for (const auto &I : SampleRecord::sortCallTargets(M)) { R.emplace_back( InstrProfValueData{I.first.getHashCode(), I.second}); } @@ -1711,7 +1664,8 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { else if (OverwriteExistingWeights) I.setMetadata(LLVMContext::MD_prof, nullptr); } else if (!isa<IntrinsicInst>(&I)) { - setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])}); + setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])}, + /*IsExpected=*/false); } } } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) { @@ -1722,7 +1676,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { if (cast<CallBase>(I).isIndirectCall()) { I.setMetadata(LLVMContext::MD_prof, nullptr); } else { - setBranchWeights(I, {uint32_t(0)}); + setBranchWeights(I, {uint32_t(0)}, /*IsExpected=*/false); } } } @@ -1765,13 +1719,15 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { // if needed. Sample counts in profiles are 64-bit unsigned values, // but internally branch weights are expressed as 32-bit values. if (Weight > std::numeric_limits<uint32_t>::max()) { - LLVM_DEBUG(dbgs() << " (saturated due to uint32_t overflow)"); + LLVM_DEBUG(dbgs() << " (saturated due to uint32_t overflow)\n"); Weight = std::numeric_limits<uint32_t>::max(); } if (!SampleProfileUseProfi) { // Weight is added by one to avoid propagation errors introduced by // 0 weights. - Weights.push_back(static_cast<uint32_t>(Weight + 1)); + Weights.push_back(static_cast<uint32_t>( + Weight == std::numeric_limits<uint32_t>::max() ? Weight + : Weight + 1)); } else { // Profi creates proper weights that do not require "+1" adjustments but // we evenly split the weight among branches with the same destination. @@ -1803,7 +1759,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) { if (MaxWeight > 0 && (!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) { LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n"); - setBranchWeights(*TI, Weights); + setBranchWeights(*TI, Weights, /*IsExpected=*/false); ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst) << "most popular destination for conditional branches at " @@ -1830,15 +1786,22 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { bool Changed = false; if (FunctionSamples::ProfileIsProbeBased) { - if (!ProbeManager->profileIsValid(F, *Samples)) { + LLVM_DEBUG({ + if (!ProbeManager->getDesc(F)) + dbgs() << "Probe descriptor missing for Function " << F.getName() + << "\n"; + }); + + if (ProbeManager->profileIsValid(F, *Samples)) { + ++NumMatchedProfile; + } else { + ++NumMismatchedProfile; LLVM_DEBUG( dbgs() << "Profile is invalid due to CFG mismatch for Function " << F.getName() << "\n"); - ++NumMismatchedProfile; if (!SalvageStaleProfile) return false; } - ++NumMatchedProfile; } else { if (getFunctionLoc(F) == 0) return false; @@ -1874,7 +1837,7 @@ SampleProfileLoader::buildProfiledCallGraph(Module &M) { // the profile. This makes sure functions missing from the profile still // gets a chance to be processed. 
for (Function &F : M) { - if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile")) + if (skipProfileForFunction(F)) continue; ProfiledCG->addProfiledFunction( getRepInFormat(FunctionSamples::getCanonicalFnName(F))); @@ -1903,7 +1866,7 @@ SampleProfileLoader::buildFunctionOrder(Module &M, LazyCallGraph &CG) { } for (Function &F : M) - if (!F.isDeclaration() && F.hasFnAttribute("use-sample-profile")) + if (!skipProfileForFunction(F)) FunctionOrderList.push_back(&F); return FunctionOrderList; } @@ -1969,25 +1932,14 @@ SampleProfileLoader::buildFunctionOrder(Module &M, LazyCallGraph &CG) { } for (auto *Node : Range) { Function *F = SymbolMap.lookup(Node->Name); - if (F && !F->isDeclaration() && F->hasFnAttribute("use-sample-profile")) + if (F && !skipProfileForFunction(*F)) FunctionOrderList.push_back(F); } ++CGI; } - } else { - CG.buildRefSCCs(); - for (LazyCallGraph::RefSCC &RC : CG.postorder_ref_sccs()) { - for (LazyCallGraph::SCC &C : RC) { - for (LazyCallGraph::Node &N : C) { - Function &F = N.getFunction(); - if (!F.isDeclaration() && F.hasFnAttribute("use-sample-profile")) - FunctionOrderList.push_back(&F); - } - } - } - } - - std::reverse(FunctionOrderList.begin(), FunctionOrderList.end()); + std::reverse(FunctionOrderList.begin(), FunctionOrderList.end()); + } else + buildTopDownFuncOrder(CG, FunctionOrderList); LLVM_DEBUG({ dbgs() << "Function processing order:\n"; @@ -2116,432 +2068,79 @@ bool SampleProfileLoader::doInitialization(Module &M, if (ReportProfileStaleness || PersistProfileStaleness || SalvageStaleProfile) { - MatchingManager = - std::make_unique<SampleProfileMatcher>(M, *Reader, ProbeManager.get()); + MatchingManager = std::make_unique<SampleProfileMatcher>( + M, *Reader, CG, ProbeManager.get(), LTOPhase, SymbolMap, PSL, + FuncNameToProfNameMap); } return true; } -void SampleProfileMatcher::findIRAnchors( - const Function &F, std::map<LineLocation, StringRef> &IRAnchors) { - // For inlined code, recover the original callsite and callee by finding the - // top-level inline frame. e.g. For frame stack "main:1 @ foo:2 @ bar:3", the - // top-level frame is "main:1", the callsite is "1" and the callee is "foo". - auto FindTopLevelInlinedCallsite = [](const DILocation *DIL) { - assert((DIL && DIL->getInlinedAt()) && "No inlined callsite"); - const DILocation *PrevDIL = nullptr; - do { - PrevDIL = DIL; - DIL = DIL->getInlinedAt(); - } while (DIL->getInlinedAt()); - - LineLocation Callsite = FunctionSamples::getCallSiteIdentifier(DIL); - StringRef CalleeName = PrevDIL->getSubprogramLinkageName(); - return std::make_pair(Callsite, CalleeName); - }; - - auto GetCanonicalCalleeName = [](const CallBase *CB) { - StringRef CalleeName = UnknownIndirectCallee; - if (Function *Callee = CB->getCalledFunction()) - CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName()); - return CalleeName; - }; - - // Extract profile matching anchors in the IR. - for (auto &BB : F) { - for (auto &I : BB) { - DILocation *DIL = I.getDebugLoc(); - if (!DIL) - continue; - - if (FunctionSamples::ProfileIsProbeBased) { - if (auto Probe = extractProbe(I)) { - // Flatten inlined IR for the matching. - if (DIL->getInlinedAt()) { - IRAnchors.emplace(FindTopLevelInlinedCallsite(DIL)); - } else { - // Use empty StringRef for basic block probe. - StringRef CalleeName; - if (const auto *CB = dyn_cast<CallBase>(&I)) { - // Skip the probe inst whose callee name is "llvm.pseudoprobe". 
- if (!isa<IntrinsicInst>(&I)) - CalleeName = GetCanonicalCalleeName(CB); - } - IRAnchors.emplace(LineLocation(Probe->Id, 0), CalleeName); - } - } - } else { - // TODO: For line-number based profile(AutoFDO), currently only support - // find callsite anchors. In future, we need to parse all the non-call - // instructions to extract the line locations for profile matching. - if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I)) - continue; - - if (DIL->getInlinedAt()) { - IRAnchors.emplace(FindTopLevelInlinedCallsite(DIL)); - } else { - LineLocation Callsite = FunctionSamples::getCallSiteIdentifier(DIL); - StringRef CalleeName = GetCanonicalCalleeName(dyn_cast<CallBase>(&I)); - IRAnchors.emplace(Callsite, CalleeName); - } - } - } - } -} - -void SampleProfileMatcher::countMismatchedSamples(const FunctionSamples &FS) { - const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID()); - // Skip the function that is external or renamed. - if (!FuncDesc) - return; - - if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) { - MismatchedFuncHashSamples += FS.getTotalSamples(); - return; - } - for (const auto &I : FS.getCallsiteSamples()) - for (const auto &CS : I.second) - countMismatchedSamples(CS.second); -} - -void SampleProfileMatcher::countProfileMismatches( - const Function &F, const FunctionSamples &FS, - const std::map<LineLocation, StringRef> &IRAnchors, - const std::map<LineLocation, std::unordered_set<FunctionId>> - &ProfileAnchors) { - [[maybe_unused]] bool IsFuncHashMismatch = false; - if (FunctionSamples::ProfileIsProbeBased) { - TotalFuncHashSamples += FS.getTotalSamples(); - TotalProfiledFunc++; - const auto *FuncDesc = ProbeManager->getDesc(F); - if (FuncDesc) { - if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) { - NumMismatchedFuncHash++; - IsFuncHashMismatch = true; - } - countMismatchedSamples(FS); - } - } - - uint64_t FuncMismatchedCallsites = 0; - uint64_t FuncProfiledCallsites = 0; - countProfileCallsiteMismatches(FS, IRAnchors, ProfileAnchors, - FuncMismatchedCallsites, - FuncProfiledCallsites); - TotalProfiledCallsites += FuncProfiledCallsites; - NumMismatchedCallsites += FuncMismatchedCallsites; - LLVM_DEBUG({ - if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch && - FuncMismatchedCallsites) - dbgs() << "Function checksum is matched but there are " - << FuncMismatchedCallsites << "/" << FuncProfiledCallsites - << " mismatched callsites.\n"; - }); -} - -void SampleProfileMatcher::countProfileCallsiteMismatches( - const FunctionSamples &FS, - const std::map<LineLocation, StringRef> &IRAnchors, - const std::map<LineLocation, std::unordered_set<FunctionId>> - &ProfileAnchors, - uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites) { - - // Check if there are any callsites in the profile that does not match to any - // IR callsites, those callsite samples will be discarded. - for (const auto &I : ProfileAnchors) { - const auto &Loc = I.first; - const auto &Callees = I.second; - assert(!Callees.empty() && "Callees should not be empty"); - - StringRef IRCalleeName; - const auto &IR = IRAnchors.find(Loc); - if (IR != IRAnchors.end()) - IRCalleeName = IR->second; - - // Compute number of samples in the original profile. 
- uint64_t CallsiteSamples = 0; - if (auto CTM = FS.findCallTargetMapAt(Loc)) { - for (const auto &I : *CTM) - CallsiteSamples += I.second; - } - const auto *FSMap = FS.findFunctionSamplesMapAt(Loc); - if (FSMap) { - for (const auto &I : *FSMap) - CallsiteSamples += I.second.getTotalSamples(); - } - - bool CallsiteIsMatched = false; - // Since indirect call does not have CalleeName, check conservatively if - // callsite in the profile is a callsite location. This is to reduce num of - // false positive since otherwise all the indirect call samples will be - // reported as mismatching. - if (IRCalleeName == UnknownIndirectCallee) - CallsiteIsMatched = true; - else if (Callees.size() == 1 && Callees.count(getRepInFormat(IRCalleeName))) - CallsiteIsMatched = true; - - FuncProfiledCallsites++; - TotalCallsiteSamples += CallsiteSamples; - if (!CallsiteIsMatched) { - FuncMismatchedCallsites++; - MismatchedCallsiteSamples += CallsiteSamples; - } - } -} - -void SampleProfileMatcher::findProfileAnchors(const FunctionSamples &FS, - std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors) { - auto isInvalidLineOffset = [](uint32_t LineOffset) { - return LineOffset & 0x8000; - }; - - for (const auto &I : FS.getBodySamples()) { - const LineLocation &Loc = I.first; - if (isInvalidLineOffset(Loc.LineOffset)) +// Note that this is a module-level check. Even if one module is errored out, +// the entire build will be errored out. However, the user could make big +// changes to functions in single module but those changes might not be +// performance significant to the whole binary. Therefore, to avoid those false +// positives, we select a reasonable big set of hot functions that are supposed +// to be globally performance significant, only compute and check the mismatch +// within those functions. The function selection is based on two criteria: +// 1) The function is hot enough, which is tuned by a hotness-based +// flag(HotFuncCutoffForStalenessError). 2) The num of function is large enough +// which is tuned by the MinfuncsForStalenessError flag. +bool SampleProfileLoader::rejectHighStalenessProfile( + Module &M, ProfileSummaryInfo *PSI, const SampleProfileMap &Profiles) { + assert(FunctionSamples::ProfileIsProbeBased && + "Only support for probe-based profile"); + uint64_t TotalHotFunc = 0; + uint64_t NumMismatchedFunc = 0; + for (const auto &I : Profiles) { + const auto &FS = I.second; + const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID()); + if (!FuncDesc) continue; - for (const auto &I : I.second.getCallTargets()) { - auto Ret = ProfileAnchors.try_emplace(Loc, - std::unordered_set<FunctionId>()); - Ret.first->second.insert(I.first); - } - } - for (const auto &I : FS.getCallsiteSamples()) { - const LineLocation &Loc = I.first; - if (isInvalidLineOffset(Loc.LineOffset)) + // Use a hotness-based threshold to control the function selection. + if (!PSI->isHotCountNthPercentile(HotFuncCutoffForStalenessError, + FS.getTotalSamples())) continue; - const auto &CalleeMap = I.second; - for (const auto &I : CalleeMap) { - auto Ret = ProfileAnchors.try_emplace(Loc, - std::unordered_set<FunctionId>()); - Ret.first->second.insert(I.first); - } - } -} - -// Call target name anchor based profile fuzzy matching. -// Input: -// For IR locations, the anchor is the callee name of direct callsite; For -// profile locations, it's the call target name for BodySamples or inlinee's -// profile name for CallsiteSamples. 
-// Matching heuristic: -// First match all the anchors in lexical order, then split the non-anchor -// locations between the two anchors evenly, first half are matched based on the -// start anchor, second half are matched based on the end anchor. -// For example, given: -// IR locations: [1, 2(foo), 3, 5, 6(bar), 7] -// Profile locations: [1, 2, 3(foo), 4, 7, 8(bar), 9] -// The matching gives: -// [1, 2(foo), 3, 5, 6(bar), 7] -// | | | | | | -// [1, 2, 3(foo), 4, 7, 8(bar), 9] -// The output mapping: [2->3, 3->4, 5->7, 6->8, 7->9]. -void SampleProfileMatcher::runStaleProfileMatching( - const Function &F, - const std::map<LineLocation, StringRef> &IRAnchors, - const std::map<LineLocation, std::unordered_set<FunctionId>> - &ProfileAnchors, - LocToLocMap &IRToProfileLocationMap) { - LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName() - << "\n"); - assert(IRToProfileLocationMap.empty() && - "Run stale profile matching only once per function"); - - std::unordered_map<FunctionId, std::set<LineLocation>> - CalleeToCallsitesMap; - for (const auto &I : ProfileAnchors) { - const auto &Loc = I.first; - const auto &Callees = I.second; - // Filter out possible indirect calls, use direct callee name as anchor. - if (Callees.size() == 1) { - FunctionId CalleeName = *Callees.begin(); - const auto &Candidates = CalleeToCallsitesMap.try_emplace( - CalleeName, std::set<LineLocation>()); - Candidates.first->second.insert(Loc); - } - } - - auto InsertMatching = [&](const LineLocation &From, const LineLocation &To) { - // Skip the unchanged location mapping to save memory. - if (From != To) - IRToProfileLocationMap.insert({From, To}); - }; - - // Use function's beginning location as the initial anchor. - int32_t LocationDelta = 0; - SmallVector<LineLocation> LastMatchedNonAnchors; - - for (const auto &IR : IRAnchors) { - const auto &Loc = IR.first; - auto CalleeName = IR.second; - bool IsMatchedAnchor = false; - // Match the anchor location in lexical order. - if (!CalleeName.empty()) { - auto CandidateAnchors = CalleeToCallsitesMap.find( - getRepInFormat(CalleeName)); - if (CandidateAnchors != CalleeToCallsitesMap.end() && - !CandidateAnchors->second.empty()) { - auto CI = CandidateAnchors->second.begin(); - const auto Candidate = *CI; - CandidateAnchors->second.erase(CI); - InsertMatching(Loc, Candidate); - LLVM_DEBUG(dbgs() << "Callsite with callee:" << CalleeName - << " is matched from " << Loc << " to " << Candidate - << "\n"); - LocationDelta = Candidate.LineOffset - Loc.LineOffset; - - // Match backwards for non-anchor locations. - // The locations in LastMatchedNonAnchors have been matched forwards - // based on the previous anchor, spilt it evenly and overwrite the - // second half based on the current anchor. - for (size_t I = (LastMatchedNonAnchors.size() + 1) / 2; - I < LastMatchedNonAnchors.size(); I++) { - const auto &L = LastMatchedNonAnchors[I]; - uint32_t CandidateLineOffset = L.LineOffset + LocationDelta; - LineLocation Candidate(CandidateLineOffset, L.Discriminator); - InsertMatching(L, Candidate); - LLVM_DEBUG(dbgs() << "Location is rematched backwards from " << L - << " to " << Candidate << "\n"); - } - - IsMatchedAnchor = true; - LastMatchedNonAnchors.clear(); - } - } - - // Match forwards for non-anchor locations. 
- if (!IsMatchedAnchor) { - uint32_t CandidateLineOffset = Loc.LineOffset + LocationDelta; - LineLocation Candidate(CandidateLineOffset, Loc.Discriminator); - InsertMatching(Loc, Candidate); - LLVM_DEBUG(dbgs() << "Location is matched from " << Loc << " to " - << Candidate << "\n"); - LastMatchedNonAnchors.emplace_back(Loc); - } - } -} - -void SampleProfileMatcher::runOnFunction(const Function &F) { - // We need to use flattened function samples for matching. - // Unlike IR, which includes all callsites from the source code, the callsites - // in profile only show up when they are hit by samples, i,e. the profile - // callsites in one context may differ from those in another context. To get - // the maximum number of callsites, we merge the function profiles from all - // contexts, aka, the flattened profile to find profile anchors. - const auto *FSFlattened = getFlattenedSamplesFor(F); - if (!FSFlattened) - return; - // Anchors for IR. It's a map from IR location to callee name, callee name is - // empty for non-call instruction and use a dummy name(UnknownIndirectCallee) - // for unknown indrect callee name. - std::map<LineLocation, StringRef> IRAnchors; - findIRAnchors(F, IRAnchors); - // Anchors for profile. It's a map from callsite location to a set of callee - // name. - std::map<LineLocation, std::unordered_set<FunctionId>> ProfileAnchors; - findProfileAnchors(*FSFlattened, ProfileAnchors); - - // Detect profile mismatch for profile staleness metrics report. - // Skip reporting the metrics for imported functions. - if (!GlobalValue::isAvailableExternallyLinkage(F.getLinkage()) && - (ReportProfileStaleness || PersistProfileStaleness)) { - // Use top-level nested FS for counting profile mismatch metrics since - // currently once a callsite is mismatched, all its children profiles are - // dropped. - if (const auto *FS = Reader.getSamplesFor(F)) - countProfileMismatches(F, *FS, IRAnchors, ProfileAnchors); + TotalHotFunc++; + if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) + NumMismatchedFunc++; } + // Make sure that the num of selected function is not too small to distinguish + // from the user's benign changes. + if (TotalHotFunc < MinfuncsForStalenessError) + return false; - // Run profile matching for checksum mismatched profile, currently only - // support for pseudo-probe. - if (SalvageStaleProfile && FunctionSamples::ProfileIsProbeBased && - !ProbeManager->profileIsValid(F, *FSFlattened)) { - // The matching result will be saved to IRToProfileLocationMap, create a new - // map for each function. - runStaleProfileMatching(F, IRAnchors, ProfileAnchors, - getIRToProfileLocationMap(F)); + // Finally check the mismatch percentage against the threshold. + if (NumMismatchedFunc * 100 >= + TotalHotFunc * PrecentMismatchForStalenessError) { + auto &Ctx = M.getContext(); + const char *Msg = + "The input profile significantly mismatches current source code. 
" + "Please recollect profile to avoid performance regression."; + Ctx.diagnose(DiagnosticInfoSampleProfile(M.getModuleIdentifier(), Msg)); + return true; } + return false; } -void SampleProfileMatcher::runOnModule() { - ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles, - FunctionSamples::ProfileIsCS); +void SampleProfileLoader::removePseudoProbeInsts(Module &M) { for (auto &F : M) { - if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile")) - continue; - runOnFunction(F); - } - if (SalvageStaleProfile) - distributeIRToProfileLocationMap(); - - if (ReportProfileStaleness) { - if (FunctionSamples::ProfileIsProbeBased) { - errs() << "(" << NumMismatchedFuncHash << "/" << TotalProfiledFunc << ")" - << " of functions' profile are invalid and " - << " (" << MismatchedFuncHashSamples << "/" << TotalFuncHashSamples - << ")" - << " of samples are discarded due to function hash mismatch.\n"; - } - errs() << "(" << NumMismatchedCallsites << "/" << TotalProfiledCallsites - << ")" - << " of callsites' profile are invalid and " - << "(" << MismatchedCallsiteSamples << "/" << TotalCallsiteSamples - << ")" - << " of samples are discarded due to callsite location mismatch.\n"; - } - - if (PersistProfileStaleness) { - LLVMContext &Ctx = M.getContext(); - MDBuilder MDB(Ctx); - - SmallVector<std::pair<StringRef, uint64_t>> ProfStatsVec; - if (FunctionSamples::ProfileIsProbeBased) { - ProfStatsVec.emplace_back("NumMismatchedFuncHash", NumMismatchedFuncHash); - ProfStatsVec.emplace_back("TotalProfiledFunc", TotalProfiledFunc); - ProfStatsVec.emplace_back("MismatchedFuncHashSamples", - MismatchedFuncHashSamples); - ProfStatsVec.emplace_back("TotalFuncHashSamples", TotalFuncHashSamples); - } - - ProfStatsVec.emplace_back("NumMismatchedCallsites", NumMismatchedCallsites); - ProfStatsVec.emplace_back("TotalProfiledCallsites", TotalProfiledCallsites); - ProfStatsVec.emplace_back("MismatchedCallsiteSamples", - MismatchedCallsiteSamples); - ProfStatsVec.emplace_back("TotalCallsiteSamples", TotalCallsiteSamples); - - auto *MD = MDB.createLLVMStats(ProfStatsVec); - auto *NMD = M.getOrInsertNamedMetadata("llvm.stats"); - NMD->addOperand(MD); - } -} - -void SampleProfileMatcher::distributeIRToProfileLocationMap( - FunctionSamples &FS) { - const auto ProfileMappings = FuncMappings.find(FS.getFuncName()); - if (ProfileMappings != FuncMappings.end()) { - FS.setIRToProfileLocationMap(&(ProfileMappings->second)); - } - - for (auto &Inlinees : FS.getCallsiteSamples()) { - for (auto FS : Inlinees.second) { - distributeIRToProfileLocationMap(FS.second); + std::vector<Instruction *> InstsToDel; + for (auto &BB : F) { + for (auto &I : BB) { + if (isa<PseudoProbeInst>(&I)) + InstsToDel.push_back(&I); + } } - } -} - -// Use a central place to distribute the matching results. Outlined and inlined -// profile with the function name will be set to the same pointer. 
-void SampleProfileMatcher::distributeIRToProfileLocationMap() { - for (auto &I : Reader.getProfiles()) { - distributeIRToProfileLocationMap(I.second); + for (auto *I : InstsToDel) + I->eraseFromParent(); } } bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, - ProfileSummaryInfo *_PSI, - LazyCallGraph &CG) { + ProfileSummaryInfo *_PSI) { GUIDToFuncNameMapper Mapper(M, *Reader, GUIDToFuncNameMap); PSI = _PSI; @@ -2550,6 +2149,11 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, ProfileSummary::PSK_Sample); PSI->refresh(); } + + if (FunctionSamples::ProfileIsProbeBased && + rejectHighStalenessProfile(M, PSI, Reader->getProfiles())) + return false; + // Compute the total number of samples collected in this profile. for (const auto &I : Reader->getProfiles()) TotalCollectedSamples += I.second.getTotalSamples(); @@ -2581,13 +2185,18 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, } } } - assert(SymbolMap.count(FunctionId()) == 0 && - "No empty StringRef should be added in SymbolMap"); + // Stale profile matching. if (ReportProfileStaleness || PersistProfileStaleness || SalvageStaleProfile) { MatchingManager->runOnModule(); + MatchingManager->clearMatchingData(); } + assert(SymbolMap.count(FunctionId()) == 0 && + "No empty StringRef should be added in SymbolMap"); + assert((SalvageUnusedProfile || FuncNameToProfNameMap.empty()) && + "FuncNameToProfNameMap is not empty when --salvage-unused-profile is " + "not enabled"); bool retval = false; for (auto *F : buildFunctionOrder(M, CG)) { @@ -2602,6 +2211,9 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, notInlinedCallInfo) updateProfileCallee(pair.first, pair.second.entryCount); + if (RemoveProbeAfterProfileAnnotation && FunctionSamples::ProfileIsProbeBased) + removePseudoProbeInsts(M); + return retval; } @@ -2714,19 +2326,18 @@ PreservedAnalyses SampleProfileLoaderPass::run(Module &M, if (!FS) FS = vfs::getRealFileSystem(); + LazyCallGraph &CG = AM.getResult<LazyCallGraphAnalysis>(M); SampleProfileLoader SampleLoader( ProfileFileName.empty() ? SampleProfileFile : ProfileFileName, ProfileRemappingFileName.empty() ? SampleProfileRemappingFile : ProfileRemappingFileName, - LTOPhase, FS, GetAssumptionCache, GetTTI, GetTLI); - + LTOPhase, FS, GetAssumptionCache, GetTTI, GetTLI, CG); if (!SampleLoader.doInitialization(M, &FAM)) return PreservedAnalyses::all(); ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); - LazyCallGraph &CG = AM.getResult<LazyCallGraphAnalysis>(M); - if (!SampleLoader.runOnModule(M, &AM, PSI, CG)) + if (!SampleLoader.runOnModule(M, &AM, PSI)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileMatcher.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileMatcher.cpp new file mode 100644 index 000000000000..312672e56b01 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileMatcher.cpp @@ -0,0 +1,922 @@ +//===- SampleProfileMatcher.cpp - Sampling-based Stale Profile Matcher ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the SampleProfileMatcher used for stale +// profile matching. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/SampleProfileMatcher.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/Support/CommandLine.h" + +using namespace llvm; +using namespace sampleprof; + +#define DEBUG_TYPE "sample-profile-matcher" + +static cl::opt<unsigned> FuncProfileSimilarityThreshold( + "func-profile-similarity-threshold", cl::Hidden, cl::init(80), + cl::desc("Consider a profile matches a function if the similarity of their " + "callee sequences is above the specified percentile.")); + +static cl::opt<unsigned> MinFuncCountForCGMatching( + "min-func-count-for-cg-matching", cl::Hidden, cl::init(5), + cl::desc("The minimum number of basic blocks required for a function to " + "run stale profile call graph matching.")); + +static cl::opt<unsigned> MinCallCountForCGMatching( + "min-call-count-for-cg-matching", cl::Hidden, cl::init(3), + cl::desc("The minimum number of call anchors required for a function to " + "run stale profile call graph matching.")); + +extern cl::opt<bool> SalvageStaleProfile; +extern cl::opt<bool> SalvageUnusedProfile; +extern cl::opt<bool> PersistProfileStaleness; +extern cl::opt<bool> ReportProfileStaleness; + +static cl::opt<unsigned> SalvageStaleProfileMaxCallsites( + "salvage-stale-profile-max-callsites", cl::Hidden, cl::init(UINT_MAX), + cl::desc("The maximum number of callsites in a function, above which stale " + "profile matching will be skipped.")); + +void SampleProfileMatcher::findIRAnchors(const Function &F, + AnchorMap &IRAnchors) const { + // For inlined code, recover the original callsite and callee by finding the + // top-level inline frame. e.g. For frame stack "main:1 @ foo:2 @ bar:3", the + // top-level frame is "main:1", the callsite is "1" and the callee is "foo". + auto FindTopLevelInlinedCallsite = [](const DILocation *DIL) { + assert((DIL && DIL->getInlinedAt()) && "No inlined callsite"); + const DILocation *PrevDIL = nullptr; + do { + PrevDIL = DIL; + DIL = DIL->getInlinedAt(); + } while (DIL->getInlinedAt()); + + LineLocation Callsite = FunctionSamples::getCallSiteIdentifier( + DIL, FunctionSamples::ProfileIsFS); + StringRef CalleeName = PrevDIL->getSubprogramLinkageName(); + return std::make_pair(Callsite, FunctionId(CalleeName)); + }; + + auto GetCanonicalCalleeName = [](const CallBase *CB) { + StringRef CalleeName = UnknownIndirectCallee; + if (Function *Callee = CB->getCalledFunction()) + CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName()); + return CalleeName; + }; + + // Extract profile matching anchors in the IR. + for (auto &BB : F) { + for (auto &I : BB) { + DILocation *DIL = I.getDebugLoc(); + if (!DIL) + continue; + + if (FunctionSamples::ProfileIsProbeBased) { + if (auto Probe = extractProbe(I)) { + // Flatten inlined IR for the matching. + if (DIL->getInlinedAt()) { + IRAnchors.emplace(FindTopLevelInlinedCallsite(DIL)); + } else { + // Use empty StringRef for basic block probe. + StringRef CalleeName; + if (const auto *CB = dyn_cast<CallBase>(&I)) { + // Skip the probe inst whose callee name is "llvm.pseudoprobe". + if (!isa<IntrinsicInst>(&I)) + CalleeName = GetCanonicalCalleeName(CB); + } + LineLocation Loc = LineLocation(Probe->Id, 0); + IRAnchors.emplace(Loc, FunctionId(CalleeName)); + } + } + } else { + // TODO: For line-number based profile(AutoFDO), currently only support + // find callsite anchors. 
In future, we need to parse all the non-call + // instructions to extract the line locations for profile matching. + if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I)) + continue; + + if (DIL->getInlinedAt()) { + IRAnchors.emplace(FindTopLevelInlinedCallsite(DIL)); + } else { + LineLocation Callsite = FunctionSamples::getCallSiteIdentifier( + DIL, FunctionSamples::ProfileIsFS); + StringRef CalleeName = GetCanonicalCalleeName(dyn_cast<CallBase>(&I)); + IRAnchors.emplace(Callsite, FunctionId(CalleeName)); + } + } + } + } +} + +void SampleProfileMatcher::findProfileAnchors(const FunctionSamples &FS, + AnchorMap &ProfileAnchors) const { + auto isInvalidLineOffset = [](uint32_t LineOffset) { + return LineOffset & 0x8000; + }; + + auto InsertAnchor = [](const LineLocation &Loc, const FunctionId &CalleeName, + AnchorMap &ProfileAnchors) { + auto Ret = ProfileAnchors.try_emplace(Loc, CalleeName); + if (!Ret.second) { + // For multiple callees, which indicates it's an indirect call, we use a + // dummy name(UnknownIndirectCallee) as the indrect callee name. + Ret.first->second = FunctionId(UnknownIndirectCallee); + } + }; + + for (const auto &I : FS.getBodySamples()) { + const LineLocation &Loc = I.first; + if (isInvalidLineOffset(Loc.LineOffset)) + continue; + for (const auto &C : I.second.getCallTargets()) + InsertAnchor(Loc, C.first, ProfileAnchors); + } + + for (const auto &I : FS.getCallsiteSamples()) { + const LineLocation &Loc = I.first; + if (isInvalidLineOffset(Loc.LineOffset)) + continue; + for (const auto &C : I.second) + InsertAnchor(Loc, C.first, ProfileAnchors); + } +} + +bool SampleProfileMatcher::functionHasProfile(const FunctionId &IRFuncName, + Function *&FuncWithoutProfile) { + FuncWithoutProfile = nullptr; + auto R = FunctionsWithoutProfile.find(IRFuncName); + if (R != FunctionsWithoutProfile.end()) + FuncWithoutProfile = R->second; + return !FuncWithoutProfile; +} + +bool SampleProfileMatcher::isProfileUnused(const FunctionId &ProfileFuncName) { + return SymbolMap->find(ProfileFuncName) == SymbolMap->end(); +} + +bool SampleProfileMatcher::functionMatchesProfile( + const FunctionId &IRFuncName, const FunctionId &ProfileFuncName, + bool FindMatchedProfileOnly) { + if (IRFuncName == ProfileFuncName) + return true; + if (!SalvageUnusedProfile) + return false; + + // If IR function doesn't have profile and the profile is unused, try + // matching them. + Function *IRFunc = nullptr; + if (functionHasProfile(IRFuncName, IRFunc) || + !isProfileUnused(ProfileFuncName)) + return false; + + assert(FunctionId(IRFunc->getName()) != ProfileFuncName && + "IR function should be different from profile function to match"); + return functionMatchesProfile(*IRFunc, ProfileFuncName, + FindMatchedProfileOnly); +} + +LocToLocMap +SampleProfileMatcher::longestCommonSequence(const AnchorList &AnchorList1, + const AnchorList &AnchorList2, + bool MatchUnusedFunction) { + int32_t Size1 = AnchorList1.size(), Size2 = AnchorList2.size(), + MaxDepth = Size1 + Size2; + auto Index = [&](int32_t I) { return I + MaxDepth; }; + + LocToLocMap EqualLocations; + if (MaxDepth == 0) + return EqualLocations; + + // Backtrack the SES result. 
+ auto Backtrack = [&](const std::vector<std::vector<int32_t>> &Trace, + const AnchorList &AnchorList1, + const AnchorList &AnchorList2, + LocToLocMap &EqualLocations) { + int32_t X = Size1, Y = Size2; + for (int32_t Depth = Trace.size() - 1; X > 0 || Y > 0; Depth--) { + const auto &P = Trace[Depth]; + int32_t K = X - Y; + int32_t PrevK = K; + if (K == -Depth || (K != Depth && P[Index(K - 1)] < P[Index(K + 1)])) + PrevK = K + 1; + else + PrevK = K - 1; + + int32_t PrevX = P[Index(PrevK)]; + int32_t PrevY = PrevX - PrevK; + while (X > PrevX && Y > PrevY) { + X--; + Y--; + EqualLocations.insert({AnchorList1[X].first, AnchorList2[Y].first}); + } + + if (Depth == 0) + break; + + if (Y == PrevY) + X--; + else if (X == PrevX) + Y--; + X = PrevX; + Y = PrevY; + } + }; + + // The greedy LCS/SES algorithm. + + // An array contains the endpoints of the furthest reaching D-paths. + std::vector<int32_t> V(2 * MaxDepth + 1, -1); + V[Index(1)] = 0; + // Trace is used to backtrack the SES result. + std::vector<std::vector<int32_t>> Trace; + for (int32_t Depth = 0; Depth <= MaxDepth; Depth++) { + Trace.push_back(V); + for (int32_t K = -Depth; K <= Depth; K += 2) { + int32_t X = 0, Y = 0; + if (K == -Depth || (K != Depth && V[Index(K - 1)] < V[Index(K + 1)])) + X = V[Index(K + 1)]; + else + X = V[Index(K - 1)] + 1; + Y = X - K; + while (X < Size1 && Y < Size2 && + functionMatchesProfile( + AnchorList1[X].second, AnchorList2[Y].second, + !MatchUnusedFunction /* Find matched function only */)) + X++, Y++; + + V[Index(K)] = X; + + if (X >= Size1 && Y >= Size2) { + // Length of an SES is D. + Backtrack(Trace, AnchorList1, AnchorList2, EqualLocations); + return EqualLocations; + } + } + } + // Length of an SES is greater than MaxDepth. + return EqualLocations; +} + +void SampleProfileMatcher::matchNonCallsiteLocs( + const LocToLocMap &MatchedAnchors, const AnchorMap &IRAnchors, + LocToLocMap &IRToProfileLocationMap) { + auto InsertMatching = [&](const LineLocation &From, const LineLocation &To) { + // Skip the unchanged location mapping to save memory. + if (From != To) + IRToProfileLocationMap.insert({From, To}); + }; + + // Use function's beginning location as the initial anchor. + int32_t LocationDelta = 0; + SmallVector<LineLocation> LastMatchedNonAnchors; + for (const auto &IR : IRAnchors) { + const auto &Loc = IR.first; + bool IsMatchedAnchor = false; + // Match the anchor location in lexical order. + auto R = MatchedAnchors.find(Loc); + if (R != MatchedAnchors.end()) { + const auto &Candidate = R->second; + InsertMatching(Loc, Candidate); + LLVM_DEBUG(dbgs() << "Callsite with callee:" << IR.second.stringRef() + << " is matched from " << Loc << " to " << Candidate + << "\n"); + LocationDelta = Candidate.LineOffset - Loc.LineOffset; + + // Match backwards for non-anchor locations. + // The locations in LastMatchedNonAnchors have been matched forwards + // based on the previous anchor, spilt it evenly and overwrite the + // second half based on the current anchor. + for (size_t I = (LastMatchedNonAnchors.size() + 1) / 2; + I < LastMatchedNonAnchors.size(); I++) { + const auto &L = LastMatchedNonAnchors[I]; + uint32_t CandidateLineOffset = L.LineOffset + LocationDelta; + LineLocation Candidate(CandidateLineOffset, L.Discriminator); + InsertMatching(L, Candidate); + LLVM_DEBUG(dbgs() << "Location is rematched backwards from " << L + << " to " << Candidate << "\n"); + } + + IsMatchedAnchor = true; + LastMatchedNonAnchors.clear(); + } + + // Match forwards for non-anchor locations. 
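The loop above is Myers' greedy O(ND) shortest-edit-script search, specialized to compare callee anchors with functionMatchesProfile() and to record a trace for backtracking. As a reference point, here is a minimal standalone sketch of the same furthest-reaching D-path idea over plain strings; it only returns the edit distance and uses exact equality, so it is an illustration rather than a drop-in piece of the pass.

#include <string>
#include <vector>

// Returns D, the length of the shortest edit script turning A into B.
int shortestEditScript(const std::string &A, const std::string &B) {
  const int N = static_cast<int>(A.size()), M = static_cast<int>(B.size());
  const int MaxD = N + M;
  if (MaxD == 0)
    return 0;
  // V[K + MaxD] is the furthest X reached so far on diagonal K = X - Y.
  std::vector<int> V(2 * MaxD + 1, 0);
  for (int D = 0; D <= MaxD; ++D) {
    for (int K = -D; K <= D; K += 2) {
      int X;
      if (K == -D || (K != D && V[K - 1 + MaxD] < V[K + 1 + MaxD]))
        X = V[K + 1 + MaxD]; // Extend the further-reaching neighbor downwards.
      else
        X = V[K - 1 + MaxD] + 1; // Extend the other neighbor to the right.
      int Y = X - K;
      // Follow the diagonal while elements are equal; these pairs are the
      // "anchors" the matcher keeps.
      while (X < N && Y < M && A[X] == B[Y])
        ++X, ++Y;
      V[K + MaxD] = X;
      if (X >= N && Y >= M)
        return D;
    }
  }
  return MaxD;
}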
+ if (!IsMatchedAnchor) { + uint32_t CandidateLineOffset = Loc.LineOffset + LocationDelta; + LineLocation Candidate(CandidateLineOffset, Loc.Discriminator); + InsertMatching(Loc, Candidate); + LLVM_DEBUG(dbgs() << "Location is matched from " << Loc << " to " + << Candidate << "\n"); + LastMatchedNonAnchors.emplace_back(Loc); + } + } +} + +// Filter the non-call locations from IRAnchors and ProfileAnchors and write +// them into a list for random access later. +void SampleProfileMatcher::getFilteredAnchorList( + const AnchorMap &IRAnchors, const AnchorMap &ProfileAnchors, + AnchorList &FilteredIRAnchorsList, AnchorList &FilteredProfileAnchorList) { + for (const auto &I : IRAnchors) { + if (I.second.stringRef().empty()) + continue; + FilteredIRAnchorsList.emplace_back(I); + } + + for (const auto &I : ProfileAnchors) + FilteredProfileAnchorList.emplace_back(I); +} + +// Call target name anchor based profile fuzzy matching. +// Input: +// For IR locations, the anchor is the callee name of direct callsite; For +// profile locations, it's the call target name for BodySamples or inlinee's +// profile name for CallsiteSamples. +// Matching heuristic: +// First match all the anchors using the diff algorithm, then split the +// non-anchor locations between the two anchors evenly, first half are matched +// based on the start anchor, second half are matched based on the end anchor. +// For example, given: +// IR locations: [1, 2(foo), 3, 5, 6(bar), 7] +// Profile locations: [1, 2, 3(foo), 4, 7, 8(bar), 9] +// The matching gives: +// [1, 2(foo), 3, 5, 6(bar), 7] +// | | | | | | +// [1, 2, 3(foo), 4, 7, 8(bar), 9] +// The output mapping: [2->3, 3->4, 5->7, 6->8, 7->9]. +void SampleProfileMatcher::runStaleProfileMatching( + const Function &F, const AnchorMap &IRAnchors, + const AnchorMap &ProfileAnchors, LocToLocMap &IRToProfileLocationMap, + bool RunCFGMatching, bool RunCGMatching) { + if (!RunCFGMatching && !RunCGMatching) + return; + LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName() + << "\n"); + assert(IRToProfileLocationMap.empty() && + "Run stale profile matching only once per function"); + + AnchorList FilteredProfileAnchorList; + AnchorList FilteredIRAnchorsList; + getFilteredAnchorList(IRAnchors, ProfileAnchors, FilteredIRAnchorsList, + FilteredProfileAnchorList); + + if (FilteredIRAnchorsList.empty() || FilteredProfileAnchorList.empty()) + return; + + if (FilteredIRAnchorsList.size() > SalvageStaleProfileMaxCallsites || + FilteredProfileAnchorList.size() > SalvageStaleProfileMaxCallsites) { + LLVM_DEBUG(dbgs() << "Skip stale profile matching for " << F.getName() + << " because the number of callsites in the IR is " + << FilteredIRAnchorsList.size() + << " and in the profile is " + << FilteredProfileAnchorList.size() << "\n"); + return; + } + + // Match the callsite anchors by finding the longest common subsequence + // between IR and profile. + // Define a match between two anchors as follows: + // 1) The function names of anchors are the same. + // 2) The similarity between the anchor functions is above a threshold if + // RunCGMatching is set. + // For 2), we only consider the anchor functions from IR and profile don't + // appear on either side to reduce the matching scope. Note that we need to + // use IR anchor as base(A side) to align with the order of + // IRToProfileLocationMap. 
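A self-contained sketch of the even-split heuristic, replayed on the example in the comment above. LineLocations are reduced to plain integer offsets and the helper name matchWithAnchors is illustrative; unlike the pass, identity mappings are kept in the result so they can be checked.

#include <cassert>
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

std::map<int32_t, int32_t>
matchWithAnchors(const std::vector<std::pair<int32_t, bool>> &IRLocs,
                 const std::map<int32_t, int32_t> &AnchorMatch) {
  std::map<int32_t, int32_t> Mapping;
  int32_t Delta = 0;
  std::vector<int32_t> LastNonAnchors;
  for (auto [Loc, IsAnchor] : IRLocs) {
    auto It = IsAnchor ? AnchorMatch.find(Loc) : AnchorMatch.end();
    if (It != AnchorMatch.end()) {
      Mapping[Loc] = It->second;
      Delta = It->second - Loc;
      // Re-match the second half of the preceding non-anchors backwards,
      // using the delta of the anchor that was just matched.
      for (size_t I = (LastNonAnchors.size() + 1) / 2;
           I < LastNonAnchors.size(); ++I)
        Mapping[LastNonAnchors[I]] = LastNonAnchors[I] + Delta;
      LastNonAnchors.clear();
    } else {
      // Match forwards using the delta of the previous anchor.
      Mapping[Loc] = Loc + Delta;
      LastNonAnchors.push_back(Loc);
    }
  }
  return Mapping;
}

int main() {
  // IR locations 1, 2(foo), 3, 5, 6(bar), 7; anchors matched as 2->3, 6->8.
  auto M = matchWithAnchors({{1, false}, {2, true}, {3, false},
                             {5, false}, {6, true}, {7, false}},
                            {{2, 3}, {6, 8}});
  // Expected mapping from the comment above: 2->3, 3->4, 5->7, 6->8, 7->9.
  // (1 maps to itself and would be dropped by the pass to save memory.)
  assert(M[1] == 1 && M[2] == 3 && M[3] == 4);
  assert(M[5] == 7 && M[6] == 8 && M[7] == 9);
  return 0;
}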
+ LocToLocMap MatchedAnchors = + longestCommonSequence(FilteredIRAnchorsList, FilteredProfileAnchorList, + RunCGMatching /* Match unused functions */); + + // CFG level matching: + // Apply the callsite matchings to infer matching for the basic + // block(non-callsite) locations and write the result to + // IRToProfileLocationMap. + if (RunCFGMatching) + matchNonCallsiteLocs(MatchedAnchors, IRAnchors, IRToProfileLocationMap); +} + +void SampleProfileMatcher::runOnFunction(Function &F) { + // We need to use flattened function samples for matching. + // Unlike IR, which includes all callsites from the source code, the callsites + // in profile only show up when they are hit by samples, i,e. the profile + // callsites in one context may differ from those in another context. To get + // the maximum number of callsites, we merge the function profiles from all + // contexts, aka, the flattened profile to find profile anchors. + const auto *FSFlattened = getFlattenedSamplesFor(F); + if (SalvageUnusedProfile && !FSFlattened) { + // Apply the matching in place to find the new function's matched profile. + // TODO: For extended profile format, if a function profile is unused and + // it's top-level, even if the profile is matched, it's not found in the + // profile. This is because sample reader only read the used profile at the + // beginning, we need to support loading the profile on-demand in future. + auto R = FuncToProfileNameMap.find(&F); + if (R != FuncToProfileNameMap.end()) + FSFlattened = getFlattenedSamplesFor(R->second); + } + if (!FSFlattened) + return; + + // Anchors for IR. It's a map from IR location to callee name, callee name is + // empty for non-call instruction and use a dummy name(UnknownIndirectCallee) + // for unknown indrect callee name. + AnchorMap IRAnchors; + findIRAnchors(F, IRAnchors); + // Anchors for profile. It's a map from callsite location to a set of callee + // name. + AnchorMap ProfileAnchors; + findProfileAnchors(*FSFlattened, ProfileAnchors); + + // Compute the callsite match states for profile staleness report. + if (ReportProfileStaleness || PersistProfileStaleness) + recordCallsiteMatchStates(F, IRAnchors, ProfileAnchors, nullptr); + + if (!SalvageStaleProfile) + return; + // For probe-based profiles, run matching only when profile checksum is + // mismatched. + bool ChecksumMismatch = FunctionSamples::ProfileIsProbeBased && + !ProbeManager->profileIsValid(F, *FSFlattened); + bool RunCFGMatching = + !FunctionSamples::ProfileIsProbeBased || ChecksumMismatch; + bool RunCGMatching = SalvageUnusedProfile; + // For imported functions, the checksum metadata(pseudo_probe_desc) are + // dropped, so we leverage function attribute(profile-checksum-mismatch) to + // transfer the info: add the attribute during pre-link phase and check it + // during post-link phase(see "profileIsValid"). + if (ChecksumMismatch && LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink) + F.addFnAttr("profile-checksum-mismatch"); + + // The matching result will be saved to IRToProfileLocationMap, create a + // new map for each function. + auto &IRToProfileLocationMap = getIRToProfileLocationMap(F); + runStaleProfileMatching(F, IRAnchors, ProfileAnchors, IRToProfileLocationMap, + RunCFGMatching, RunCGMatching); + // Find and update callsite match states after matching. 
+ if (RunCFGMatching && (ReportProfileStaleness || PersistProfileStaleness)) + recordCallsiteMatchStates(F, IRAnchors, ProfileAnchors, + &IRToProfileLocationMap); +} + +void SampleProfileMatcher::recordCallsiteMatchStates( + const Function &F, const AnchorMap &IRAnchors, + const AnchorMap &ProfileAnchors, + const LocToLocMap *IRToProfileLocationMap) { + bool IsPostMatch = IRToProfileLocationMap != nullptr; + auto &CallsiteMatchStates = + FuncCallsiteMatchStates[FunctionSamples::getCanonicalFnName(F.getName())]; + + auto MapIRLocToProfileLoc = [&](const LineLocation &IRLoc) { + // IRToProfileLocationMap is null in the pre-match phase. + if (!IRToProfileLocationMap) + return IRLoc; + const auto &ProfileLoc = IRToProfileLocationMap->find(IRLoc); + if (ProfileLoc != IRToProfileLocationMap->end()) + return ProfileLoc->second; + else + return IRLoc; + }; + + for (const auto &I : IRAnchors) { + // After fuzzy profile matching, use the matching result to remap the + // current IR callsite. + const auto &ProfileLoc = MapIRLocToProfileLoc(I.first); + const auto &IRCalleeId = I.second; + const auto &It = ProfileAnchors.find(ProfileLoc); + if (It == ProfileAnchors.end()) + continue; + const auto &ProfCalleeId = It->second; + if (IRCalleeId == ProfCalleeId) { + auto It = CallsiteMatchStates.find(ProfileLoc); + if (It == CallsiteMatchStates.end()) + CallsiteMatchStates.emplace(ProfileLoc, MatchState::InitialMatch); + else if (IsPostMatch) { + if (It->second == MatchState::InitialMatch) + It->second = MatchState::UnchangedMatch; + else if (It->second == MatchState::InitialMismatch) + It->second = MatchState::RecoveredMismatch; + } + } + } + + // Check if there are any callsites in the profile that do not match any + // IR callsites. + for (const auto &I : ProfileAnchors) { + const auto &Loc = I.first; + assert(!I.second.stringRef().empty() && "Callees should not be empty"); + auto It = CallsiteMatchStates.find(Loc); + if (It == CallsiteMatchStates.end()) + CallsiteMatchStates.emplace(Loc, MatchState::InitialMismatch); + else if (IsPostMatch) { + // Update the state if it's not matched(UnchangedMatch or + // RecoveredMismatch). + if (It->second == MatchState::InitialMismatch) + It->second = MatchState::UnchangedMismatch; + else if (It->second == MatchState::InitialMatch) + It->second = MatchState::RemovedMatch; + } + } +} + +void SampleProfileMatcher::countMismatchedFuncSamples(const FunctionSamples &FS, + bool IsTopLevel) { + const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID()); + // Skip the function that is external or renamed. + if (!FuncDesc) + return; + + if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) { + if (IsTopLevel) + NumStaleProfileFunc++; + // Given currently all probe ids are after block probe ids, once the + // checksum is mismatched, it's likely all the callsites are mismatched and + // dropped. We conservatively count all the samples as mismatched and stop + // counting the inlinees' profiles. + MismatchedFunctionSamples += FS.getTotalSamples(); + return; + } + + // Even if the current-level function checksum is matched, it's possible that + // the nested inlinees' checksums are mismatched, which affects the inlinees' + // sample loading; we need to go deeper to check the inlinees' function samples. + // Similarly, count all the samples as mismatched if the inlinee's checksum is + // mismatched using this recursive function.
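The two passes over the anchors amount to a small state machine: the pre-match pass records InitialMatch or InitialMismatch, and the post-match pass promotes those states once the fuzzy location remapping is known. A compact sketch of the promotion rule, using the same state names as the pass (the promote() helper itself is illustrative):

// States mirror the enum used by the pass.
enum class MatchState {
  Unknown,
  InitialMatch,      // pre-match: IR callee == profile callee
  InitialMismatch,   // pre-match: no matching profile callee at the location
  UnchangedMatch,    // post-match: still matched after remapping
  UnchangedMismatch, // post-match: still mismatched after remapping
  RecoveredMismatch, // post-match: mismatch fixed by the location remapping
  RemovedMatch       // post-match: a pre-match hit no longer matches
};

// Promotion applied by the post-match pass.
MatchState promote(MatchState Pre, bool MatchedAfterRemap) {
  if (Pre == MatchState::InitialMatch)
    return MatchedAfterRemap ? MatchState::UnchangedMatch
                             : MatchState::RemovedMatch;
  if (Pre == MatchState::InitialMismatch)
    return MatchedAfterRemap ? MatchState::RecoveredMismatch
                             : MatchState::UnchangedMismatch;
  return Pre;
}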
+ for (const auto &I : FS.getCallsiteSamples()) + for (const auto &CS : I.second) + countMismatchedFuncSamples(CS.second, false); +} + +void SampleProfileMatcher::countMismatchedCallsiteSamples( + const FunctionSamples &FS) { + auto It = FuncCallsiteMatchStates.find(FS.getFuncName()); + // Skip it if no mismatched callsite or this is an external function. + if (It == FuncCallsiteMatchStates.end() || It->second.empty()) + return; + const auto &CallsiteMatchStates = It->second; + + auto findMatchState = [&](const LineLocation &Loc) { + auto It = CallsiteMatchStates.find(Loc); + if (It == CallsiteMatchStates.end()) + return MatchState::Unknown; + return It->second; + }; + + auto AttributeMismatchedSamples = [&](const enum MatchState &State, + uint64_t Samples) { + if (isMismatchState(State)) + MismatchedCallsiteSamples += Samples; + else if (State == MatchState::RecoveredMismatch) + RecoveredCallsiteSamples += Samples; + }; + + // The non-inlined callsites are saved in the body samples of function + // profile, go through it to count the non-inlined callsite samples. + for (const auto &I : FS.getBodySamples()) + AttributeMismatchedSamples(findMatchState(I.first), I.second.getSamples()); + + // Count the inlined callsite samples. + for (const auto &I : FS.getCallsiteSamples()) { + auto State = findMatchState(I.first); + uint64_t CallsiteSamples = 0; + for (const auto &CS : I.second) + CallsiteSamples += CS.second.getTotalSamples(); + AttributeMismatchedSamples(State, CallsiteSamples); + + if (isMismatchState(State)) + continue; + + // When the current level of inlined call site matches the profiled call + // site, we need to go deeper along the inline tree to count mismatches from + // lower level inlinees. + for (const auto &CS : I.second) + countMismatchedCallsiteSamples(CS.second); + } +} + +void SampleProfileMatcher::countMismatchCallsites(const FunctionSamples &FS) { + auto It = FuncCallsiteMatchStates.find(FS.getFuncName()); + // Skip it if no mismatched callsite or this is an external function. + if (It == FuncCallsiteMatchStates.end() || It->second.empty()) + return; + const auto &MatchStates = It->second; + [[maybe_unused]] bool OnInitialState = + isInitialState(MatchStates.begin()->second); + for (const auto &I : MatchStates) { + TotalProfiledCallsites++; + assert( + (OnInitialState ? 
isInitialState(I.second) : isFinalState(I.second)) && + "Profile matching state is inconsistent"); + + if (isMismatchState(I.second)) + NumMismatchedCallsites++; + else if (I.second == MatchState::RecoveredMismatch) + NumRecoveredCallsites++; + } +} + +void SampleProfileMatcher::countCallGraphRecoveredSamples( + const FunctionSamples &FS, + std::unordered_set<FunctionId> &CallGraphRecoveredProfiles) { + if (CallGraphRecoveredProfiles.count(FS.getFunction())) { + NumCallGraphRecoveredFuncSamples += FS.getTotalSamples(); + return; + } + + for (const auto &CM : FS.getCallsiteSamples()) { + for (const auto &CS : CM.second) { + countCallGraphRecoveredSamples(CS.second, CallGraphRecoveredProfiles); + } + } +} + +void SampleProfileMatcher::computeAndReportProfileStaleness() { + if (!ReportProfileStaleness && !PersistProfileStaleness) + return; + + std::unordered_set<FunctionId> CallGraphRecoveredProfiles; + if (SalvageUnusedProfile) { + for (const auto &I : FuncToProfileNameMap) { + CallGraphRecoveredProfiles.insert(I.second); + if (GlobalValue::isAvailableExternallyLinkage(I.first->getLinkage())) + continue; + NumCallGraphRecoveredProfiledFunc++; + } + } + + // Count profile mismatches for profile staleness report. + for (const auto &F : M) { + if (skipProfileForFunction(F)) + continue; + // As the stats will be merged by linker, skip reporting the metrics for + // imported functions to avoid repeated counting. + if (GlobalValue::isAvailableExternallyLinkage(F.getLinkage())) + continue; + const auto *FS = Reader.getSamplesFor(F); + if (!FS) + continue; + TotalProfiledFunc++; + TotalFunctionSamples += FS->getTotalSamples(); + + if (SalvageUnusedProfile && !CallGraphRecoveredProfiles.empty()) + countCallGraphRecoveredSamples(*FS, CallGraphRecoveredProfiles); + + // Checksum mismatch is only used in pseudo-probe mode. + if (FunctionSamples::ProfileIsProbeBased) + countMismatchedFuncSamples(*FS, true); + + // Count mismatches and samples for callsites.
+ countMismatchCallsites(*FS); + countMismatchedCallsiteSamples(*FS); + } + + if (ReportProfileStaleness) { + if (FunctionSamples::ProfileIsProbeBased) { + errs() << "(" << NumStaleProfileFunc << "/" << TotalProfiledFunc + << ") of functions' profile are invalid and (" + << MismatchedFunctionSamples << "/" << TotalFunctionSamples + << ") of samples are discarded due to function hash mismatch.\n"; + } + if (SalvageUnusedProfile) { + errs() << "(" << NumCallGraphRecoveredProfiledFunc << "/" + << TotalProfiledFunc << ") of functions' profile are matched and (" + << NumCallGraphRecoveredFuncSamples << "/" << TotalFunctionSamples + << ") of samples are reused by call graph matching.\n"; + } + + errs() << "(" << (NumMismatchedCallsites + NumRecoveredCallsites) << "/" + << TotalProfiledCallsites + << ") of callsites' profile are invalid and (" + << (MismatchedCallsiteSamples + RecoveredCallsiteSamples) << "/" + << TotalFunctionSamples + << ") of samples are discarded due to callsite location mismatch.\n"; + errs() << "(" << NumRecoveredCallsites << "/" + << (NumRecoveredCallsites + NumMismatchedCallsites) + << ") of callsites and (" << RecoveredCallsiteSamples << "/" + << (RecoveredCallsiteSamples + MismatchedCallsiteSamples) + << ") of samples are recovered by stale profile matching.\n"; + } + + if (PersistProfileStaleness) { + LLVMContext &Ctx = M.getContext(); + MDBuilder MDB(Ctx); + + SmallVector<std::pair<StringRef, uint64_t>> ProfStatsVec; + if (FunctionSamples::ProfileIsProbeBased) { + ProfStatsVec.emplace_back("NumStaleProfileFunc", NumStaleProfileFunc); + ProfStatsVec.emplace_back("TotalProfiledFunc", TotalProfiledFunc); + ProfStatsVec.emplace_back("MismatchedFunctionSamples", + MismatchedFunctionSamples); + ProfStatsVec.emplace_back("TotalFunctionSamples", TotalFunctionSamples); + } + + if (SalvageUnusedProfile) { + ProfStatsVec.emplace_back("NumCallGraphRecoveredProfiledFunc", + NumCallGraphRecoveredProfiledFunc); + ProfStatsVec.emplace_back("NumCallGraphRecoveredFuncSamples", + NumCallGraphRecoveredFuncSamples); + } + + ProfStatsVec.emplace_back("NumMismatchedCallsites", NumMismatchedCallsites); + ProfStatsVec.emplace_back("NumRecoveredCallsites", NumRecoveredCallsites); + ProfStatsVec.emplace_back("TotalProfiledCallsites", TotalProfiledCallsites); + ProfStatsVec.emplace_back("MismatchedCallsiteSamples", + MismatchedCallsiteSamples); + ProfStatsVec.emplace_back("RecoveredCallsiteSamples", + RecoveredCallsiteSamples); + + auto *MD = MDB.createLLVMStats(ProfStatsVec); + auto *NMD = M.getOrInsertNamedMetadata("llvm.stats"); + NMD->addOperand(MD); + } +} + +void SampleProfileMatcher::findFunctionsWithoutProfile() { + // TODO: Support MD5 profile. + if (FunctionSamples::UseMD5) + return; + StringSet<> NamesInProfile; + if (auto NameTable = Reader.getNameTable()) { + for (auto Name : *NameTable) + NamesInProfile.insert(Name.stringRef()); + } + + for (auto &F : M) { + // Skip declarations, as even if the function can be matched, we have + // nothing to do with it. + if (F.isDeclaration()) + continue; + + StringRef CanonFName = FunctionSamples::getCanonicalFnName(F.getName()); + const auto *FS = getFlattenedSamplesFor(F); + if (FS) + continue; + + // For extended binary, functions fully inlined may not be loaded in the + // top-level profile, so check the NameTable which has the all symbol names + // in profile. + if (NamesInProfile.count(CanonFName)) + continue; + + // For extended binary, non-profiled function symbols are in the profile + // symbol list table. 
+ if (PSL && PSL->contains(CanonFName)) + continue; + + LLVM_DEBUG(dbgs() << "Function " << CanonFName + << " is not in profile or profile symbol list.\n"); + FunctionsWithoutProfile[FunctionId(CanonFName)] = &F; + } +} + +bool SampleProfileMatcher::functionMatchesProfileHelper( + const Function &IRFunc, const FunctionId &ProfFunc) { + // The value is in the range [0, 1]. The bigger the value is, the more similar + // two sequences are. + float Similarity = 0.0; + + const auto *FSFlattened = getFlattenedSamplesFor(ProfFunc); + if (!FSFlattened) + return false; + // The check for similarity or checksum may not be reliable if the function is + // tiny, we use the number of basic block as a proxy for the function + // complexity and skip the matching if it's too small. + if (IRFunc.size() < MinFuncCountForCGMatching || + FSFlattened->getBodySamples().size() < MinFuncCountForCGMatching) + return false; + + // For probe-based function, we first trust the checksum info. If the checksum + // doesn't match, we continue checking for similarity. + if (FunctionSamples::ProfileIsProbeBased) { + const auto *FuncDesc = ProbeManager->getDesc(IRFunc); + if (FuncDesc && + !ProbeManager->profileIsHashMismatched(*FuncDesc, *FSFlattened)) { + LLVM_DEBUG(dbgs() << "The checksums for " << IRFunc.getName() + << "(IR) and " << ProfFunc << "(Profile) match.\n"); + + return true; + } + } + + AnchorMap IRAnchors; + findIRAnchors(IRFunc, IRAnchors); + AnchorMap ProfileAnchors; + findProfileAnchors(*FSFlattened, ProfileAnchors); + + AnchorList FilteredIRAnchorsList; + AnchorList FilteredProfileAnchorList; + getFilteredAnchorList(IRAnchors, ProfileAnchors, FilteredIRAnchorsList, + FilteredProfileAnchorList); + + // Similarly skip the matching if the num of anchors is not enough. + if (FilteredIRAnchorsList.size() < MinCallCountForCGMatching || + FilteredProfileAnchorList.size() < MinCallCountForCGMatching) + return false; + + // Use the diff algorithm to find the LCS between IR and profile. + + // Don't recursively match the callee function to avoid infinite matching, + // callee functions will be handled later since it's processed in top-down + // order . + LocToLocMap MatchedAnchors = + longestCommonSequence(FilteredIRAnchorsList, FilteredProfileAnchorList, + false /* Match unused functions */); + + Similarity = + static_cast<float>(MatchedAnchors.size()) * 2 / + (FilteredIRAnchorsList.size() + FilteredProfileAnchorList.size()); + + LLVM_DEBUG(dbgs() << "The similarity between " << IRFunc.getName() + << "(IR) and " << ProfFunc << "(profile) is " + << format("%.2f", Similarity) << "\n"); + assert((Similarity >= 0 && Similarity <= 1.0) && + "Similarity value should be in [0, 1]"); + return Similarity * 100 > FuncProfileSimilarityThreshold; +} + +// If FindMatchedProfileOnly is set to true, only use the processed function +// results. This is used for skipping the repeated recursive matching. 
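For a concrete feel of the threshold: the similarity is twice the number of matched anchors divided by the total anchor count on both sides, and the profile is accepted only when that ratio strictly exceeds -func-profile-similarity-threshold (80 by default, per the option declared above). A small self-contained check; the helper name is illustrative.

#include <cassert>
#include <cstddef>

bool similarEnough(size_t MatchedAnchors, size_t IRAnchors, size_t ProfAnchors,
                   unsigned ThresholdPercent = 80) {
  // 2 * |matched| / (|IR anchors| + |profile anchors|), as in the pass.
  float Similarity =
      static_cast<float>(MatchedAnchors) * 2 / (IRAnchors + ProfAnchors);
  return Similarity * 100 > ThresholdPercent;
}

int main() {
  // 5 of (6 IR, 8 profile) anchors line up: 10/14 = ~71%, below the threshold.
  assert(!similarEnough(5, 6, 8));
  // 6 matched anchors: 12/14 = ~86%, so the profile is considered a match.
  assert(similarEnough(6, 6, 8));
  return 0;
}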
+bool SampleProfileMatcher::functionMatchesProfile(Function &IRFunc, + const FunctionId &ProfFunc, + bool FindMatchedProfileOnly) { + auto R = FuncProfileMatchCache.find({&IRFunc, ProfFunc}); + if (R != FuncProfileMatchCache.end()) + return R->second; + + if (FindMatchedProfileOnly) + return false; + + bool Matched = functionMatchesProfileHelper(IRFunc, ProfFunc); + FuncProfileMatchCache[{&IRFunc, ProfFunc}] = Matched; + if (Matched) { + FuncToProfileNameMap[&IRFunc] = ProfFunc; + LLVM_DEBUG(dbgs() << "Function:" << IRFunc.getName() + << " matches profile:" << ProfFunc << "\n"); + } + + return Matched; +} + +void SampleProfileMatcher::runOnModule() { + ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles, + FunctionSamples::ProfileIsCS); + if (SalvageUnusedProfile) + findFunctionsWithoutProfile(); + + // Process the matching in top-down order so that the caller matching result + // can be used to the callee matching. + std::vector<Function *> TopDownFunctionList; + TopDownFunctionList.reserve(M.size()); + buildTopDownFuncOrder(CG, TopDownFunctionList); + for (auto *F : TopDownFunctionList) { + if (skipProfileForFunction(*F)) + continue; + runOnFunction(*F); + } + + // Update the data in SampleLoader. + if (SalvageUnusedProfile) + for (auto &I : FuncToProfileNameMap) { + assert(I.first && "New function is null"); + FunctionId FuncName(I.first->getName()); + FuncNameToProfNameMap->emplace(FuncName, I.second); + // We need to remove the old entry to avoid duplicating the function + // processing. + SymbolMap->erase(FuncName); + SymbolMap->emplace(I.second, I.first); + } + + if (SalvageStaleProfile) + distributeIRToProfileLocationMap(); + + computeAndReportProfileStaleness(); +} + +void SampleProfileMatcher::distributeIRToProfileLocationMap( + FunctionSamples &FS) { + const auto ProfileMappings = FuncMappings.find(FS.getFuncName()); + if (ProfileMappings != FuncMappings.end()) { + FS.setIRToProfileLocationMap(&(ProfileMappings->second)); + } + + for (auto &Callees : + const_cast<CallsiteSampleMap &>(FS.getCallsiteSamples())) { + for (auto &FS : Callees.second) { + distributeIRToProfileLocationMap(FS.second); + } + } +} + +// Use a central place to distribute the matching results. Outlined and inlined +// profile with the function name will be set to the same pointer. 
+void SampleProfileMatcher::distributeIRToProfileLocationMap() { + for (auto &I : Reader.getProfiles()) { + distributeIRToProfileLocationMap(I.second); + } +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp index 090e5560483e..b489d4fdaa21 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/SampleProfileProbe.cpp @@ -23,6 +23,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" #include "llvm/IR/PseudoProbe.h" #include "llvm/ProfileData/SampleProf.h" #include "llvm/Support/CRC.h" @@ -173,21 +174,113 @@ SampleProfileProber::SampleProfileProber(Function &Func, BlockProbeIds.clear(); CallProbeIds.clear(); LastProbeId = (uint32_t)PseudoProbeReservedId::Last; - computeProbeIdForBlocks(); - computeProbeIdForCallsites(); - computeCFGHash(); + + DenseSet<BasicBlock *> BlocksToIgnore; + DenseSet<BasicBlock *> BlocksAndCallsToIgnore; + computeBlocksToIgnore(BlocksToIgnore, BlocksAndCallsToIgnore); + + computeProbeId(BlocksToIgnore, BlocksAndCallsToIgnore); + computeCFGHash(BlocksToIgnore); +} + +// Two purposes to compute the blocks to ignore: +// 1. Reduce the IR size. +// 2. Make the instrumentation(checksum) stable. e.g. the frondend may +// generate unstable IR while optimizing nounwind attribute, some versions are +// optimized with the call-to-invoke conversion, while other versions do not. +// This discrepancy in probe ID could cause profile mismatching issues. +// Note that those ignored blocks are either cold blocks or new split blocks +// whose original blocks are instrumented, so it shouldn't degrade the profile +// quality. +void SampleProfileProber::computeBlocksToIgnore( + DenseSet<BasicBlock *> &BlocksToIgnore, + DenseSet<BasicBlock *> &BlocksAndCallsToIgnore) { + // Ignore the cold EH and unreachable blocks and calls. + computeEHOnlyBlocks(*F, BlocksAndCallsToIgnore); + findUnreachableBlocks(BlocksAndCallsToIgnore); + + BlocksToIgnore.insert(BlocksAndCallsToIgnore.begin(), + BlocksAndCallsToIgnore.end()); + + // Handle the call-to-invoke conversion case: make sure that the probe id and + // callsite id are consistent before and after the block split. For block + // probe, we only keep the head block probe id and ignore the block ids of the + // normal dests. For callsite probe, it's different to block probe, there is + // no additional callsite in the normal dests, so we don't ignore the + // callsites. + findInvokeNormalDests(BlocksToIgnore); +} + +// Unreachable blocks and calls are always cold, ignore them. +void SampleProfileProber::findUnreachableBlocks( + DenseSet<BasicBlock *> &BlocksToIgnore) { + for (auto &BB : *F) { + if (&BB != &F->getEntryBlock() && pred_size(&BB) == 0) + BlocksToIgnore.insert(&BB); + } +} + +// In call-to-invoke conversion, basic block can be split into multiple blocks, +// only instrument probe in the head block, ignore the normal dests. +void SampleProfileProber::findInvokeNormalDests( + DenseSet<BasicBlock *> &InvokeNormalDests) { + for (auto &BB : *F) { + auto *TI = BB.getTerminator(); + if (auto *II = dyn_cast<InvokeInst>(TI)) { + auto *ND = II->getNormalDest(); + InvokeNormalDests.insert(ND); + + // The normal dest and the try/catch block are connected by an + // unconditional branch. 
+ while (pred_size(ND) == 1) { + auto *Pred = *pred_begin(ND); + if (succ_size(Pred) == 1) { + InvokeNormalDests.insert(Pred); + ND = Pred; + } else + break; + } + } + } +} + +// The call-to-invoke conversion splits the original block into a list of blocks, +// so we need to compute the hash using the original block's successors to keep +// the CFG hash consistent. For a given head block, we keep searching the +// successor (normal dest or unconditional branch dest) to find the tail block; +// the tail block's successors are the original block's successors. +const Instruction *SampleProfileProber::getOriginalTerminator( + const BasicBlock *Head, const DenseSet<BasicBlock *> &BlocksToIgnore) { + auto *TI = Head->getTerminator(); + if (auto *II = dyn_cast<InvokeInst>(TI)) { + return getOriginalTerminator(II->getNormalDest(), BlocksToIgnore); + } else if (succ_size(Head) == 1 && + BlocksToIgnore.contains(*succ_begin(Head))) { + // Go to the unconditional branch dest. + return getOriginalTerminator(*succ_begin(Head), BlocksToIgnore); + } + return TI; } // Compute Hash value for the CFG: the lower 32 bits are CRC32 of the index // value of each BB in the CFG. The higher 32 bits record the number of edges // preceded by the number of indirect calls. // This is derived from FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash(). -void SampleProfileProber::computeCFGHash() { +void SampleProfileProber::computeCFGHash( + const DenseSet<BasicBlock *> &BlocksToIgnore) { std::vector<uint8_t> Indexes; JamCRC JC; for (auto &BB : *F) { - for (BasicBlock *Succ : successors(&BB)) { + if (BlocksToIgnore.contains(&BB)) + continue; + + auto *TI = getOriginalTerminator(&BB, BlocksToIgnore); + for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) { + auto *Succ = TI->getSuccessor(I); auto Index = getBlockId(Succ); + // Ignore ignored blocks (zero ID) to avoid an unstable checksum. + if (Index == 0) + continue; for (int J = 0; J < 4; J++) Indexes.push_back((uint8_t)(Index >> (J * 8))); } @@ -207,27 +300,20 @@ void SampleProfileProber::computeCFGHash() { << ", Hash = " << FunctionHash << "\n"); } -void SampleProfileProber::computeProbeIdForBlocks() { - DenseSet<BasicBlock *> KnownColdBlocks; - computeEHOnlyBlocks(*F, KnownColdBlocks); - // Insert pseudo probe to non-cold blocks only. This will reduce IR size as - // well as the binary size while retaining the profile quality.
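Only the lower 32 bits of the checksum are visible in this hunk: a JamCRC over the little-endian bytes of each surviving successor's block id, with ignored (zero-id) blocks skipped so a call-to-invoke split does not perturb the value. A stripped-down sketch of just that part; the helper name is illustrative, and the packing of the upper 32 bits described in the comment is omitted.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/CRC.h"
#include <cstdint>
#include <vector>

// CRC over the surviving successor block ids, mirroring the loop above.
uint32_t hashSuccessorIds(const std::vector<uint32_t> &SuccBlockIds) {
  std::vector<uint8_t> Indexes;
  for (uint32_t Index : SuccBlockIds) {
    if (Index == 0) // Ignored block: leave it out of the checksum.
      continue;
    for (int J = 0; J < 4; ++J)
      Indexes.push_back(static_cast<uint8_t>(Index >> (J * 8)));
  }
  llvm::JamCRC JC;
  JC.update(Indexes);
  return JC.getCRC();
}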
- for (auto &BB : *F) { - ++LastProbeId; - if (!KnownColdBlocks.contains(&BB)) - BlockProbeIds[&BB] = LastProbeId; - } -} - -void SampleProfileProber::computeProbeIdForCallsites() { +void SampleProfileProber::computeProbeId( + const DenseSet<BasicBlock *> &BlocksToIgnore, + const DenseSet<BasicBlock *> &BlocksAndCallsToIgnore) { LLVMContext &Ctx = F->getContext(); Module *M = F->getParent(); for (auto &BB : *F) { + if (!BlocksToIgnore.contains(&BB)) + BlockProbeIds[&BB] = ++LastProbeId; + + if (BlocksAndCallsToIgnore.contains(&BB)) + continue; for (auto &I : BB) { - if (!isa<CallBase>(I)) - continue; - if (isa<IntrinsicInst>(&I)) + if (!isa<CallBase>(I) || isa<IntrinsicInst>(&I)) continue; // The current implementation uses the lower 16 bits of the discriminator @@ -258,7 +344,7 @@ uint32_t SampleProfileProber::getCallsiteId(const Instruction *Call) const { void SampleProfileProber::instrumentOneFunc(Function &F, TargetMachine *TM) { Module *M = F.getParent(); MDBuilder MDB(F.getContext()); - // Since the GUID from probe desc and inline stack are computed seperately, we + // Since the GUID from probe desc and inline stack are computed separately, we // need to make sure their names are consistent, so here also use the name // from debug info. StringRef FName = F.getName(); @@ -346,8 +432,8 @@ void SampleProfileProber::instrumentOneFunc(Function &F, TargetMachine *TM) { // and type of a callsite probe. This gets rid of the dependency on // plumbing a customized metadata through the codegen pipeline. uint32_t V = PseudoProbeDwarfDiscriminator::packProbeData( - Index, Type, 0, - PseudoProbeDwarfDiscriminator::FullDistributionFactor); + Index, Type, 0, PseudoProbeDwarfDiscriminator::FullDistributionFactor, + DIL->getBaseDiscriminator()); DIL = DIL->cloneWithDiscriminator(V); Call->setDebugLoc(DIL); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index e5f9fa1dda88..9bf29c46938e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -575,16 +575,23 @@ bool writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS, } } // anonymous namespace - +extern bool WriteNewDbgInfoFormatToBitcode; PreservedAnalyses llvm::ThinLTOBitcodeWriterPass::run(Module &M, ModuleAnalysisManager &AM) { FunctionAnalysisManager &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + + ScopedDbgInfoFormatSetter FormatSetter(M, M.IsNewDbgInfoFormat && + WriteNewDbgInfoFormatToBitcode); + if (M.IsNewDbgInfoFormat) + M.removeDebugIntrinsicDeclarations(); + bool Changed = writeThinLTOBitcode( OS, ThinLinkOS, [&FAM](Function &F) -> AAResults & { return FAM.getResult<AAManager>(F); }, M, &AM.getResult<ModuleSummaryIndexAnalysis>(M)); + return Changed ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index 01aba47cdbff..19bc841b1052 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -311,7 +311,7 @@ void wholeprogramdevirt::setAfterReturnValues( VirtualCallTarget::VirtualCallTarget(GlobalValue *Fn, const TypeMemberInfo *TM) : Fn(Fn), TM(TM), - IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), + IsBigEndian(Fn->getDataLayout().isBigEndian()), WasDevirt(false) {} namespace { @@ -434,7 +434,7 @@ struct VirtualCallSite { emitRemark(OptName, TargetName, OREGetter); CB.replaceAllUsesWith(New); if (auto *II = dyn_cast<InvokeInst>(&CB)) { - BranchInst::Create(II->getNormalDest(), &CB); + BranchInst::Create(II->getNormalDest(), CB.getIterator()); II->getUnwindDest()->removePredecessor(II->getParent()); } CB.eraseFromParent(); @@ -861,7 +861,7 @@ void llvm::updatePublicTypeTestCalls(Module &M, auto *CI = cast<CallInst>(U.getUser()); auto *NewCI = CallInst::Create( TypeTestFunc, {CI->getArgOperand(0), CI->getArgOperand(1)}, - std::nullopt, "", CI); + std::nullopt, "", CI->getIterator()); CI->replaceAllUsesWith(NewCI); CI->eraseFromParent(); } @@ -1066,17 +1066,10 @@ bool DevirtModule::tryFindVirtualCallTargets( GlobalObject::VCallVisibilityPublic) return false; - Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(), - TM.Offset + ByteOffset, M, TM.Bits->GV); - if (!Ptr) - return false; - - auto C = Ptr->stripPointerCasts(); - // Make sure this is a function or alias to a function. - auto Fn = dyn_cast<Function>(C); - auto A = dyn_cast<GlobalAlias>(C); - if (!Fn && A) - Fn = dyn_cast<Function>(A->getAliasee()); + Function *Fn = nullptr; + Constant *C = nullptr; + std::tie(Fn, C) = + getFunctionAtVTableOffset(TM.Bits->GV, TM.Offset + ByteOffset, M); if (!Fn) return false; @@ -1203,8 +1196,7 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, // function pointer to the devirtualized target. In case of a mismatch, // fall back to indirect call. if (DevirtCheckMode == WPDCheckMode::Fallback) { - MDNode *Weights = - MDBuilder(M.getContext()).createBranchWeights((1U << 20) - 1, 1); + MDNode *Weights = MDBuilder(M.getContext()).createLikelyBranchWeights(); // Version the indirect call site. If the called value is equal to the // given callee, 'NewInst' will be executed, otherwise the original call // site will be executed. @@ -1232,8 +1224,8 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, CB.setMetadata(LLVMContext::MD_callees, nullptr); if (CB.getCalledOperand() && CB.getOperandBundle(LLVMContext::OB_ptrauth)) { - auto *NewCS = - CallBase::removeOperandBundle(&CB, LLVMContext::OB_ptrauth, &CB); + auto *NewCS = CallBase::removeOperandBundle( + &CB, LLVMContext::OB_ptrauth, CB.getIterator()); CB.replaceAllUsesWith(NewCS); // Schedule for deletion at the end of pass run. CallsWithPtrAuthBundleRemoved.push_back(&CB); @@ -1624,7 +1616,7 @@ std::string DevirtModule::getGlobalName(VTableSlot Slot, for (uint64_t Arg : Args) OS << '_' << Arg; OS << '_' << Name; - return OS.str(); + return FullName; } bool DevirtModule::shouldExportConstantsAsAbsoluteSymbols() { @@ -1935,7 +1927,7 @@ void DevirtModule::rebuildGlobal(VTableBits &B) { // element (the original initializer). 
auto Alias = GlobalAlias::create( B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "", - ConstantExpr::getGetElementPtr( + ConstantExpr::getInBoundsGetElementPtr( NewInit->getType(), NewGV, ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0), ConstantInt::get(Int32Ty, 1)}), diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 8a00b75a1f74..0a55f4762fdf 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -819,30 +819,37 @@ static Instruction *foldNoWrapAdd(BinaryOperator &Add, Value *X; const APInt *C1, *C2; if (match(Op1, m_APInt(C1)) && - match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C2))))) && + match(Op0, m_ZExt(m_NUWAddLike(m_Value(X), m_APInt(C2)))) && C1->isNegative() && C1->sge(-C2->sext(C1->getBitWidth()))) { - Constant *NewC = - ConstantInt::get(X->getType(), *C2 + C1->trunc(C2->getBitWidth())); - return new ZExtInst(Builder.CreateNUWAdd(X, NewC), Ty); + APInt NewC = *C2 + C1->trunc(C2->getBitWidth()); + // If the smaller add will fold to zero, we don't need to check one use. + if (NewC.isZero()) + return new ZExtInst(X, Ty); + // Otherwise only do this if the existing zero extend will be removed. + if (Op0->hasOneUse()) + return new ZExtInst( + Builder.CreateNUWAdd(X, ConstantInt::get(X->getType(), NewC)), Ty); } // More general combining of constants in the wide type. // (sext (X +nsw NarrowC)) + C --> (sext X) + (sext(NarrowC) + C) + // or (zext nneg (X +nsw NarrowC)) + C --> (sext X) + (sext(NarrowC) + C) Constant *NarrowC; - if (match(Op0, m_OneUse(m_SExt(m_NSWAdd(m_Value(X), m_Constant(NarrowC)))))) { + if (match(Op0, m_OneUse(m_SExtLike( + m_NSWAddLike(m_Value(X), m_Constant(NarrowC)))))) { Value *WideC = Builder.CreateSExt(NarrowC, Ty); Value *NewC = Builder.CreateAdd(WideC, Op1C); Value *WideX = Builder.CreateSExt(X, Ty); return BinaryOperator::CreateAdd(WideX, NewC); } // (zext (X +nuw NarrowC)) + C --> (zext X) + (zext(NarrowC) + C) - if (match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_Constant(NarrowC)))))) { + if (match(Op0, + m_OneUse(m_ZExt(m_NUWAddLike(m_Value(X), m_Constant(NarrowC)))))) { Value *WideC = Builder.CreateZExt(NarrowC, Ty); Value *NewC = Builder.CreateAdd(WideC, Op1C); Value *WideX = Builder.CreateZExt(X, Ty); return BinaryOperator::CreateAdd(WideX, NewC); } - return nullptr; } @@ -894,7 +901,7 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { const APInt *C; unsigned BitWidth = Ty->getScalarSizeInBits(); if (match(Op0, m_OneUse(m_AShr(m_Value(X), - m_SpecificIntAllowUndef(BitWidth - 1)))) && + m_SpecificIntAllowPoison(BitWidth - 1)))) && match(Op1, m_One())) return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty); @@ -903,8 +910,14 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { // (X | Op01C) + Op1C --> X + (Op01C + Op1C) iff the `or` is actually an `add` Constant *Op01C; - if (match(Op0, m_DisjointOr(m_Value(X), m_ImmConstant(Op01C)))) - return BinaryOperator::CreateAdd(X, ConstantExpr::getAdd(Op01C, Op1C)); + if (match(Op0, m_DisjointOr(m_Value(X), m_ImmConstant(Op01C)))) { + BinaryOperator *NewAdd = + BinaryOperator::CreateAdd(X, ConstantExpr::getAdd(Op01C, Op1C)); + NewAdd->setHasNoSignedWrap(Add.hasNoSignedWrap() && + willNotOverflowSignedAdd(Op01C, Op1C, Add)); + NewAdd->setHasNoUnsignedWrap(Add.hasNoUnsignedWrap()); + return 
NewAdd; + } // (X | C2) + C --> (X | C2) ^ C2 iff (C2 == -C) const APInt *C2; @@ -986,7 +999,7 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { if (C->isOne()) { if (match(Op0, m_ZExt(m_Add(m_Value(X), m_AllOnes())))) { const SimplifyQuery Q = SQ.getWithInstruction(&Add); - if (llvm::isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (llvm::isKnownNonZero(X, Q)) return new ZExtInst(X, Ty); } } @@ -1012,7 +1025,7 @@ static bool matchesSquareSum(BinaryOperator &I, Mul2Rhs M2Rhs, Value *&A, // (a * a) + (((a * 2) + b) * b) if (match(&I, m_c_BinOp( AddOp, m_OneUse(m_BinOp(MulOp, m_Value(A), m_Deferred(A))), - m_OneUse(m_BinOp( + m_OneUse(m_c_BinOp( MulOp, m_c_BinOp(AddOp, m_BinOp(Mul2Op, m_Deferred(A), M2Rhs), m_Value(B)), @@ -1023,16 +1036,16 @@ static bool matchesSquareSum(BinaryOperator &I, Mul2Rhs M2Rhs, Value *&A, // + // (a * a + b * b) or (b * b + a * a) return match( - &I, - m_c_BinOp(AddOp, - m_CombineOr( - m_OneUse(m_BinOp( - Mul2Op, m_BinOp(MulOp, m_Value(A), m_Value(B)), M2Rhs)), - m_OneUse(m_BinOp(MulOp, m_BinOp(Mul2Op, m_Value(A), M2Rhs), + &I, m_c_BinOp( + AddOp, + m_CombineOr( + m_OneUse(m_BinOp( + Mul2Op, m_BinOp(MulOp, m_Value(A), m_Value(B)), M2Rhs)), + m_OneUse(m_c_BinOp(MulOp, m_BinOp(Mul2Op, m_Value(A), M2Rhs), m_Value(B)))), - m_OneUse(m_c_BinOp( - AddOp, m_BinOp(MulOp, m_Deferred(A), m_Deferred(A)), - m_BinOp(MulOp, m_Deferred(B), m_Deferred(B)))))); + m_OneUse( + m_c_BinOp(AddOp, m_BinOp(MulOp, m_Deferred(A), m_Deferred(A)), + m_BinOp(MulOp, m_Deferred(B), m_Deferred(B)))))); } // Fold integer variations of a^2 + 2*a*b + b^2 -> (a + b)^2 @@ -1132,6 +1145,8 @@ static bool MulWillOverflow(APInt &C0, APInt &C1, bool IsSigned) { // Simplifies X % C0 + (( X / C0 ) % C1) * C0 to X % (C0 * C1), where (C0 * C1) // does not overflow. 
+// Simplifies (X / C0) * C1 + (X % C0) * C2 to +// (X / C0) * (C1 - C2 * C0) + X * C2 Value *InstCombinerImpl::SimplifyAddWithRemainder(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); Value *X, *MulOpV; @@ -1159,6 +1174,33 @@ Value *InstCombinerImpl::SimplifyAddWithRemainder(BinaryOperator &I) { } } + // Match I = (X / C0) * C1 + (X % C0) * C2 + Value *Div, *Rem; + APInt C1, C2; + if (!LHS->hasOneUse() || !MatchMul(LHS, Div, C1)) + Div = LHS, C1 = APInt(I.getType()->getScalarSizeInBits(), 1); + if (!RHS->hasOneUse() || !MatchMul(RHS, Rem, C2)) + Rem = RHS, C2 = APInt(I.getType()->getScalarSizeInBits(), 1); + if (match(Div, m_IRem(m_Value(), m_Value()))) { + std::swap(Div, Rem); + std::swap(C1, C2); + } + Value *DivOpV; + APInt DivOpC; + if (MatchRem(Rem, X, C0, IsSigned) && + MatchDiv(Div, DivOpV, DivOpC, IsSigned) && X == DivOpV && C0 == DivOpC) { + APInt NewC = C1 - C2 * C0; + if (!NewC.isZero() && !Rem->hasOneUse()) + return nullptr; + if (!isGuaranteedNotToBeUndef(X, &AC, &I, &DT)) + return nullptr; + Value *MulXC2 = Builder.CreateMul(X, ConstantInt::get(X->getType(), C2)); + if (NewC.isZero()) + return MulXC2; + return Builder.CreateAdd( + Builder.CreateMul(Div, ConstantInt::get(X->getType(), NewC)), MulXC2); + } + return nullptr; } @@ -1654,7 +1696,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { // (A s>> (BW - 1)) + (zext (A s> 0)) --> (A s>> (BW - 1)) | (zext (A != 0)) ICmpInst::Predicate Pred; uint64_t BitWidth = Ty->getScalarSizeInBits(); - if (match(LHS, m_AShr(m_Value(A), m_SpecificIntAllowUndef(BitWidth - 1))) && + if (match(LHS, m_AShr(m_Value(A), m_SpecificIntAllowPoison(BitWidth - 1))) && match(RHS, m_OneUse(m_ZExt( m_OneUse(m_ICmp(Pred, m_Specific(A), m_ZeroInt()))))) && Pred == CmpInst::ICMP_SGT) { @@ -1663,6 +1705,24 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { return BinaryOperator::CreateOr(LHS, Zext); } + { + Value *Cond, *Ext; + Constant *C; + // (add X, (sext/zext (icmp eq X, C))) + // -> (select (icmp eq X, C), (add C, (sext/zext 1)), X) + auto CondMatcher = m_CombineAnd( + m_Value(Cond), m_ICmp(Pred, m_Deferred(A), m_ImmConstant(C))); + + if (match(&I, + m_c_Add(m_Value(A), + m_CombineAnd(m_Value(Ext), m_ZExtOrSExt(CondMatcher)))) && + Pred == ICmpInst::ICMP_EQ && Ext->hasOneUse()) { + Value *Add = isa<ZExtInst>(Ext) ? InstCombiner::AddOne(C) + : InstCombiner::SubOne(C); + return replaceInstUsesWith(I, Builder.CreateSelect(Cond, Add, A)); + } + } + if (Instruction *Ashr = foldAddToAshr(I)) return Ashr; @@ -1867,64 +1927,10 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) { // Check for (fadd double (sitofp x), y), see if we can merge this into an // integer add followed by a promotion. - Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); - if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) { - Value *LHSIntVal = LHSConv->getOperand(0); - Type *FPType = LHSConv->getType(); - - // TODO: This check is overly conservative. In many cases known bits - // analysis can tell us that the result of the addition has less significant - // bits than the integer type can hold. - auto IsValidPromotion = [](Type *FTy, Type *ITy) { - Type *FScalarTy = FTy->getScalarType(); - Type *IScalarTy = ITy->getScalarType(); - - // Do we have enough bits in the significand to represent the result of - // the integer addition? 
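Stepping back to the SimplifyAddWithRemainder change above: the new pattern relies on the identity X % C0 == X - (X / C0) * C0, so (X / C0) * C1 + (X % C0) * C2 == (X / C0) * (C1 - C2 * C0) + X * C2 under the wrapping arithmetic the IR uses. A minimal standalone sketch of that identity (not part of the patch; constants are arbitrary, unsigned 32-bit wrap-around stands in for the modular IR arithmetic):

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t C0 = 12, C1 = 5, C2 = 3;
      for (uint32_t X : {0u, 1u, 7u, 100u, 4096u, 0xFFFFFFF0u}) {
        uint32_t Orig = (X / C0) * C1 + (X % C0) * C2;
        // C1 - C2 * C0 wraps, but the identity still holds modulo 2^32.
        uint32_t Folded = (X / C0) * (C1 - C2 * C0) + X * C2;
        assert(Orig == Folded);
      }
      return 0;
    }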
- unsigned MaxRepresentableBits = - APFloat::semanticsPrecision(FScalarTy->getFltSemantics()); - return IScalarTy->getIntegerBitWidth() <= MaxRepresentableBits; - }; - - // (fadd double (sitofp x), fpcst) --> (sitofp (add int x, intcst)) - // ... if the constant fits in the integer value. This is useful for things - // like (double)(x & 1234) + 4.0 -> (double)((X & 1234)+4) which no longer - // requires a constant pool load, and generally allows the add to be better - // instcombined. - if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS)) - if (IsValidPromotion(FPType, LHSIntVal->getType())) { - Constant *CI = ConstantFoldCastOperand(Instruction::FPToSI, CFP, - LHSIntVal->getType(), DL); - if (LHSConv->hasOneUse() && - ConstantFoldCastOperand(Instruction::SIToFP, CI, I.getType(), DL) == - CFP && - willNotOverflowSignedAdd(LHSIntVal, CI, I)) { - // Insert the new integer add. - Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, CI, "addconv"); - return new SIToFPInst(NewAdd, I.getType()); - } - } - - // (fadd double (sitofp x), (sitofp y)) --> (sitofp (add int x, y)) - if (SIToFPInst *RHSConv = dyn_cast<SIToFPInst>(RHS)) { - Value *RHSIntVal = RHSConv->getOperand(0); - // It's enough to check LHS types only because we require int types to - // be the same for this transform. - if (IsValidPromotion(FPType, LHSIntVal->getType())) { - // Only do this if x/y have the same type, if at least one of them has a - // single use (so we don't increase the number of int->fp conversions), - // and if the integer add will not overflow. - if (LHSIntVal->getType() == RHSIntVal->getType() && - (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && - willNotOverflowSignedAdd(LHSIntVal, RHSIntVal, I)) { - // Insert the new integer add. - Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, RHSIntVal, "addconv"); - return new SIToFPInst(NewAdd, I.getType()); - } - } - } - } + if (Instruction *R = foldFBinOpOfIntCasts(I)) + return R; + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); // Handle specials cases for FAdd with selects feeding the operation if (Value *V = SimplifySelectsFeedingBinaryOp(I, LHS, RHS)) return replaceInstUsesWith(I, V); @@ -2024,43 +2030,30 @@ Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS, if (!GEP1) return nullptr; - if (GEP2) { - // (gep X, ...) - (gep X, ...) - // - // Avoid duplicating the arithmetic if there are more than one non-constant - // indices between the two GEPs and either GEP has a non-constant index and - // multiple users. If zero non-constant index, the result is a constant and - // there is no duplication. If one non-constant index, the result is an add - // or sub with a constant, which is no larger than the original code, and - // there's no duplicated arithmetic, even if either GEP has multiple - // users. If more than one non-constant indices combined, as long as the GEP - // with at least one non-constant index doesn't have multiple users, there - // is no duplication. - unsigned NumNonConstantIndices1 = GEP1->countNonConstantIndices(); - unsigned NumNonConstantIndices2 = GEP2->countNonConstantIndices(); - if (NumNonConstantIndices1 + NumNonConstantIndices2 > 1 && - ((NumNonConstantIndices1 > 0 && !GEP1->hasOneUse()) || - (NumNonConstantIndices2 > 0 && !GEP2->hasOneUse()))) { - return nullptr; - } - } + // To avoid duplicating the offset arithmetic, rewrite the GEP to use the + // computed offset. This may erase the original GEP, so be sure to cache the + // inbounds flag before emitting the offset. 
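For context on the OptimizePointerDifference hunk: when both operands are GEPs off the same base pointer, the pointer subtraction reduces to a subtraction of the two accumulated offsets, which is what EmitGEPOffset materializes. A hypothetical C-level illustration of the shape this targets (the helper name is made up, not from the patch):

    #include <cstddef>

    // &Base[J] - &Base[I] needs no pointer values at all; only the two
    // offsets survive once the subtraction of the GEP offsets is folded.
    std::ptrdiff_t elemDistance(const int *Base, std::size_t I, std::size_t J) {
      return &Base[J] - &Base[I];
    }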
+ // TODO: We should probably do this even if there is only one GEP. + bool RewriteGEPs = GEP2 != nullptr; // Emit the offset of the GEP and an intptr_t. - Value *Result = EmitGEPOffset(GEP1); + bool GEP1IsInBounds = GEP1->isInBounds(); + Value *Result = EmitGEPOffset(GEP1, RewriteGEPs); // If this is a single inbounds GEP and the original sub was nuw, // then the final multiplication is also nuw. if (auto *I = dyn_cast<Instruction>(Result)) - if (IsNUW && !GEP2 && !Swapped && GEP1->isInBounds() && + if (IsNUW && !GEP2 && !Swapped && GEP1IsInBounds && I->getOpcode() == Instruction::Mul) I->setHasNoUnsignedWrap(); // If we have a 2nd GEP of the same base pointer, subtract the offsets. // If both GEPs are inbounds, then the subtract does not have signed overflow. if (GEP2) { - Value *Offset = EmitGEPOffset(GEP2); + bool GEP2IsInBounds = GEP2->isInBounds(); + Value *Offset = EmitGEPOffset(GEP2, RewriteGEPs); Result = Builder.CreateSub(Result, Offset, "gepdiff", /* NUW */ false, - GEP1->isInBounds() && GEP2->isInBounds()); + GEP1IsInBounds && GEP2IsInBounds); } // If we have p - gep(p, ...) then we have to negate the result. @@ -2333,8 +2326,10 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { if (match(Op0, m_APInt(Op0C))) { if (Op0C->isMask()) { // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known - // zero. - KnownBits RHSKnown = computeKnownBits(Op1, 0, &I); + // zero. We don't use information from dominating conditions so this + // transform is easier to reverse if necessary. + KnownBits RHSKnown = llvm::computeKnownBits( + Op1, 0, SQ.getWithInstruction(&I).getWithoutDomCondCache()); if ((*Op0C | RHSKnown.Zero).isAllOnes()) return BinaryOperator::CreateXor(Op1, Op0); } @@ -2448,6 +2443,21 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { } } + { + // (sub (xor X, (sext C)), (sext C)) => (select C, (neg X), X) + // (sub (sext C), (xor X, (sext C))) => (select C, X, (neg X)) + Value *C, *X; + auto m_SubXorCmp = [&C, &X](Value *LHS, Value *RHS) { + return match(LHS, m_OneUse(m_c_Xor(m_Value(X), m_Specific(RHS)))) && + match(RHS, m_SExt(m_Value(C))) && + (C->getType()->getScalarSizeInBits() == 1); + }; + if (m_SubXorCmp(Op0, Op1)) + return SelectInst::Create(C, Builder.CreateNeg(X), X); + if (m_SubXorCmp(Op1, Op0)) + return SelectInst::Create(C, X, Builder.CreateNeg(X)); + } + if (Instruction *R = tryFoldInstWithCtpopWithNot(&I)) return R; @@ -2561,9 +2571,10 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { // sub (xor A, B), B ; flip bits if negative and subtract -1 (add 1) // --> (A < 0) ? -A : A Value *IsNeg = Builder.CreateIsNeg(A); - // Copy the nuw/nsw flags from the sub to the negate. - Value *NegA = Builder.CreateNeg(A, "", I.hasNoUnsignedWrap(), - I.hasNoSignedWrap()); + // Copy the nsw flags from the sub to the negate. + Value *NegA = I.hasNoUnsignedWrap() + ? Constant::getNullValue(A->getType()) + : Builder.CreateNeg(A, "", I.hasNoSignedWrap()); return SelectInst::Create(IsNeg, NegA, A); } @@ -2786,6 +2797,16 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { propagateSelectFMF(NewSel, P == X); return NewSel; } + + // -(Cond ? X : C) --> Cond ? -X : -C + // -(Cond ? C : Y) --> Cond ? 
-C : -Y + if (match(X, m_ImmConstant()) || match(Y, m_ImmConstant())) { + Value *NegX = Builder.CreateFNegFMF(X, &I, X->getName() + ".neg"); + Value *NegY = Builder.CreateFNegFMF(Y, &I, Y->getName() + ".neg"); + SelectInst *NewSel = SelectInst::Create(Cond, NegX, NegY); + propagateSelectFMF(NewSel, /*CommonOperand=*/true); + return NewSel; + } } // fneg (copysign x, y) -> copysign x, (fneg y) @@ -2832,6 +2853,9 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { if (Instruction *X = foldFNegIntoConstant(I, DL)) return X; + if (Instruction *R = foldFBinOpOfIntCasts(I)) + return R; + Value *X, *Y; Constant *C; @@ -2842,7 +2866,8 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { // Note that if this fsub was really an fneg, the fadd with -0.0 will get // killed later. We still limit that particular transform with 'hasOneUse' // because an fneg is assumed better/cheaper than a generic fsub. - if (I.hasNoSignedZeros() || cannotBeNegativeZero(Op0, SQ.DL, SQ.TLI)) { + if (I.hasNoSignedZeros() || + cannotBeNegativeZero(Op0, 0, getSimplifyQuery().getWithInstruction(&I))) { if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { Value *NewSub = Builder.CreateFSubFMF(Y, X, &I); return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I); diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 5fd944a859ef..3222e8298c3f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -701,6 +701,45 @@ Value *InstCombinerImpl::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, return Builder.CreateICmp(NewPred, Input, RangeEnd); } +// (or (icmp eq X, 0), (icmp eq X, Pow2OrZero)) +// -> (icmp eq (and X, Pow2OrZero), X) +// (and (icmp ne X, 0), (icmp ne X, Pow2OrZero)) +// -> (icmp ne (and X, Pow2OrZero), X) +static Value * +foldAndOrOfICmpsWithPow2AndWithZero(InstCombiner::BuilderTy &Builder, + ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, + const SimplifyQuery &Q) { + CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ; + // Make sure we have right compares for our op. + if (LHS->getPredicate() != Pred || RHS->getPredicate() != Pred) + return nullptr; + + // Make it so we can match LHS against the (icmp eq/ne X, 0) just for + // simplicity. + if (match(RHS->getOperand(1), m_Zero())) + std::swap(LHS, RHS); + + Value *Pow2, *Op; + // Match the desired pattern: + // LHS: (icmp eq/ne X, 0) + // RHS: (icmp eq/ne X, Pow2OrZero) + // Skip if Pow2OrZero is 1. Either way it gets folded to (icmp ugt X, 1) but + // this form ends up slightly less canonical. + // We could potentially be more sophisticated than requiring LHS/RHS + // be one-use. We don't create additional instructions if only one + // of them is one-use. So cases where one is one-use and the other + // is two-use might be profitable. 
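The rewrite performed by foldAndOrOfICmpsWithPow2AndWithZero is justified because the only submasks of a value P that is a power of two (or zero) are 0 and P itself, so X == 0 || X == P is the same as (X & P) == X. A small self-contained check (not part of the patch), with plain C++ comparisons standing in for the icmps:

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t P = 1u << 5; // any power of two (or zero) works
      for (uint32_t X = 0; X < 1024; ++X) {
        bool Orig = (X == 0) || (X == P);
        bool Folded = (X & P) == X;
        assert(Orig == Folded);
      }
      return 0;
    }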
+ if (!match(LHS, m_OneUse(m_ICmp(Pred, m_Value(Op), m_Zero()))) || + !match(RHS, m_OneUse(m_c_ICmp(Pred, m_Specific(Op), m_Value(Pow2)))) || + match(Pow2, m_One()) || + !isKnownToBeAPowerOfTwo(Pow2, Q.DL, /*OrZero=*/true, /*Depth=*/0, Q.AC, + Q.CxtI, Q.DT)) + return nullptr; + + Value *And = Builder.CreateAnd(Op, Pow2); + return Builder.CreateICmp(Pred, And, Op); +} + // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) Value *InstCombinerImpl::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, @@ -887,9 +926,11 @@ static Value *foldIsPowerOf2OrZero(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd, } /// Reduce a pair of compares that check if a value has exactly 1 bit set. -/// Also used for logical and/or, must be poison safe. +/// Also used for logical and/or, must be poison safe if range attributes are +/// dropped. static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd, - InstCombiner::BuilderTy &Builder) { + InstCombiner::BuilderTy &Builder, + InstCombinerImpl &IC) { // Handle 'and' / 'or' commutation: make the equality check the first operand. if (JoinedByAnd && Cmp1->getPredicate() == ICmpInst::ICMP_NE) std::swap(Cmp0, Cmp1); @@ -903,7 +944,10 @@ static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd, match(Cmp1, m_ICmp(Pred1, m_Intrinsic<Intrinsic::ctpop>(m_Specific(X)), m_SpecificInt(2))) && Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_ULT) { - Value *CtPop = Cmp1->getOperand(0); + auto *CtPop = cast<Instruction>(Cmp1->getOperand(0)); + // Drop range attributes and re-infer them in the next iteration. + CtPop->dropPoisonGeneratingAnnotations(); + IC.addToWorklist(CtPop); return Builder.CreateICmpEQ(CtPop, ConstantInt::get(CtPop->getType(), 1)); } // (X == 0) || (ctpop(X) u> 1) --> ctpop(X) != 1 @@ -911,7 +955,10 @@ static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd, match(Cmp1, m_ICmp(Pred1, m_Intrinsic<Intrinsic::ctpop>(m_Specific(X)), m_SpecificInt(1))) && Pred0 == ICmpInst::ICMP_EQ && Pred1 == ICmpInst::ICMP_UGT) { - Value *CtPop = Cmp1->getOperand(0); + auto *CtPop = cast<Instruction>(Cmp1->getOperand(0)); + // Drop range attributes and re-infer them in the next iteration. + CtPop->dropPoisonGeneratingAnnotations(); + IC.addToWorklist(CtPop); return Builder.CreateICmpNE(CtPop, ConstantInt::get(CtPop->getType(), 1)); } return nullptr; @@ -947,9 +994,9 @@ static Value *foldNegativePower2AndShiftedMask( // bits (0). 
auto isReducible = [](const Value *B, const Value *D, const Value *E) { const APInt *BCst, *DCst, *ECst; - return match(B, m_APIntAllowUndef(BCst)) && match(D, m_APInt(DCst)) && + return match(B, m_APIntAllowPoison(BCst)) && match(D, m_APInt(DCst)) && match(E, m_APInt(ECst)) && *DCst == *ECst && - (isa<UndefValue>(B) || + (isa<PoisonValue>(B) || (BCst->countLeadingOnes() == DCst->countLeadingZeros())); }; @@ -1031,10 +1078,6 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp, !ICmpInst::isEquality(EqPred)) return nullptr; - auto IsKnownNonZero = [&](Value *V) { - return isKnownNonZero(V, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); - }; - ICmpInst::Predicate UnsignedPred; Value *A, *B; @@ -1043,9 +1086,9 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp, match(ZeroCmpOp, m_c_Add(m_Specific(A), m_Value(B))) && (ZeroICmp->hasOneUse() || UnsignedICmp->hasOneUse())) { auto GetKnownNonZeroAndOther = [&](Value *&NonZero, Value *&Other) { - if (!IsKnownNonZero(NonZero)) + if (!isKnownNonZero(NonZero, Q)) std::swap(NonZero, Other); - return IsKnownNonZero(NonZero); + return isKnownNonZero(NonZero, Q); }; // Given ZeroCmpOp = (A + B) @@ -1196,7 +1239,7 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, // operand 0). Value *Y; ICmpInst::Predicate Pred1; - if (!match(Cmp1, m_c_ICmp(Pred1, m_Value(Y), m_Deferred(X)))) + if (!match(Cmp1, m_c_ICmp(Pred1, m_Value(Y), m_Specific(X)))) return nullptr; // Replace variable with constant value equivalence to remove a variable use: @@ -1419,6 +1462,44 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, } } + // Canonicalize the range check idiom: + // and (fcmp olt/ole/ult/ule x, C), (fcmp ogt/oge/ugt/uge x, -C) + // --> fabs(x) olt/ole/ult/ule C + // or (fcmp ogt/oge/ugt/uge x, C), (fcmp olt/ole/ult/ule x, -C) + // --> fabs(x) ogt/oge/ugt/uge C + // TODO: Generalize to handle a negated variable operand? + const APFloat *LHSC, *RHSC; + if (LHS0 == RHS0 && LHS->hasOneUse() && RHS->hasOneUse() && + FCmpInst::getSwappedPredicate(PredL) == PredR && + match(LHS1, m_APFloatAllowPoison(LHSC)) && + match(RHS1, m_APFloatAllowPoison(RHSC)) && + LHSC->bitwiseIsEqual(neg(*RHSC))) { + auto IsLessThanOrLessEqual = [](FCmpInst::Predicate Pred) { + switch (Pred) { + case FCmpInst::FCMP_OLT: + case FCmpInst::FCMP_OLE: + case FCmpInst::FCMP_ULT: + case FCmpInst::FCMP_ULE: + return true; + default: + return false; + } + }; + if (IsLessThanOrLessEqual(IsAnd ? PredR : PredL)) { + std::swap(LHSC, RHSC); + std::swap(PredL, PredR); + } + if (IsLessThanOrLessEqual(IsAnd ? 
PredL : PredR)) { + BuilderTy::FastMathFlagGuard Guard(Builder); + Builder.setFastMathFlags(LHS->getFastMathFlags() | + RHS->getFastMathFlags()); + + Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, LHS0); + return Builder.CreateFCmp(PredL, FAbs, + ConstantFP::get(LHS0->getType(), *LHSC)); + } + } + return nullptr; } @@ -1516,7 +1597,7 @@ Instruction *InstCombinerImpl::canonicalizeConditionalNegationViaMathToSelect( if (!match(&I, m_c_BinOp(m_OneUse(m_Value()), m_Value())) || !match(I.getOperand(1), m_SExt(m_Value(Cond))) || !Cond->getType()->isIntOrIntVectorTy(1) || - !match(I.getOperand(0), m_c_Add(m_SExt(m_Deferred(Cond)), m_Value(X)))) + !match(I.getOperand(0), m_c_Add(m_SExt(m_Specific(Cond)), m_Value(X)))) return nullptr; return SelectInst::Create(Cond, Builder.CreateNeg(X, X->getName() + ".neg"), X); @@ -1647,7 +1728,7 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast, } } - if (match(Cast, m_OneUse(m_SExt(m_Value(X))))) { + if (match(Cast, m_OneUse(m_SExtLike(m_Value(X))))) { if (Constant *TruncC = IC.getLosslessSignedTrunc(C, SrcTy)) { // LogicOpc (sext X), C --> sext (LogicOpc X, C) Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC); @@ -2179,6 +2260,49 @@ foldBitwiseLogicWithIntrinsics(BinaryOperator &I, } } +// Try to simplify V by replacing occurrences of Op with RepOp, but only look +// through bitwise operations. In particular, for X | Y we try to replace Y with +// 0 inside X and for X & Y we try to replace Y with -1 inside X. +// Return the simplified result of X if successful, and nullptr otherwise. +// If SimplifyOnly is true, no new instructions will be created. +static Value *simplifyAndOrWithOpReplaced(Value *V, Value *Op, Value *RepOp, + bool SimplifyOnly, + InstCombinerImpl &IC, + unsigned Depth = 0) { + if (Op == RepOp) + return nullptr; + + if (V == Op) + return RepOp; + + auto *I = dyn_cast<BinaryOperator>(V); + if (!I || !I->isBitwiseLogicOp() || Depth >= 3) + return nullptr; + + if (!I->hasOneUse()) + SimplifyOnly = true; + + Value *NewOp0 = simplifyAndOrWithOpReplaced(I->getOperand(0), Op, RepOp, + SimplifyOnly, IC, Depth + 1); + Value *NewOp1 = simplifyAndOrWithOpReplaced(I->getOperand(1), Op, RepOp, + SimplifyOnly, IC, Depth + 1); + if (!NewOp0 && !NewOp1) + return nullptr; + + if (!NewOp0) + NewOp0 = I->getOperand(0); + if (!NewOp1) + NewOp1 = I->getOperand(1); + + if (Value *Res = simplifyBinOp(I->getOpcode(), NewOp0, NewOp1, + IC.getSimplifyQuery().getWithInstruction(I))) + return Res; + + if (SimplifyOnly) + return nullptr; + return IC.Builder.CreateBinOp(I->getOpcode(), NewOp0, NewOp1); +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. 
@@ -2220,10 +2344,12 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Value *X, *Y; - if (match(Op0, m_OneUse(m_LogicalShift(m_One(), m_Value(X)))) && + const APInt *C; + if ((match(Op0, m_OneUse(m_LogicalShift(m_One(), m_Value(X)))) || + (match(Op0, m_OneUse(m_Shl(m_APInt(C), m_Value(X)))) && (*C)[0])) && match(Op1, m_One())) { - // (1 << X) & 1 --> zext(X == 0) // (1 >> X) & 1 --> zext(X == 0) + // (C << X) & 1 --> zext(X == 0), when C is odd Value *IsZero = Builder.CreateICmpEQ(X, ConstantInt::get(Ty, 0)); return new ZExtInst(IsZero, Ty); } @@ -2246,7 +2372,6 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, /*Depth*/ 0, &I)) return BinaryOperator::CreateAnd(Builder.CreateNot(X), Y); - const APInt *C; if (match(Op1, m_APInt(C))) { const APInt *XorC; if (match(Op0, m_OneUse(m_Xor(m_Value(X), m_APInt(XorC))))) { @@ -2426,8 +2551,8 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { match(C1, m_Power2())) { Constant *Log2C1 = ConstantExpr::getExactLogBase2(C1); Constant *Cmp = - ConstantExpr::getCompare(ICmpInst::ICMP_ULT, Log2C3, C2); - if (Cmp->isZeroValue()) { + ConstantFoldCompareInstOperands(ICmpInst::ICMP_ULT, Log2C3, C2, DL); + if (Cmp && Cmp->isZeroValue()) { // iff C1,C3 is pow2 and Log2(C3) >= C2: // ((C1 >> X) << C2) & C3 -> X == (cttz(C1)+C2-cttz(C3)) ? C3 : 0 Constant *ShlC = ConstantExpr::getAdd(C2, Log2C1); @@ -2449,19 +2574,19 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { // Assumes any IEEE-represented type has the sign bit in the high bit. // TODO: Unify with APInt matcher. This version allows undef unlike m_APInt Value *CastOp; - if (match(Op0, m_BitCast(m_Value(CastOp))) && + if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) && match(Op1, m_MaxSignedValue()) && !Builder.GetInsertBlock()->getParent()->hasFnAttribute( - Attribute::NoImplicitFloat)) { + Attribute::NoImplicitFloat)) { Type *EltTy = CastOp->getType()->getScalarType(); - if (EltTy->isFloatingPointTy() && EltTy->isIEEE() && - EltTy->getPrimitiveSizeInBits() == - I.getType()->getScalarType()->getPrimitiveSizeInBits()) { + if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) { Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp); return new BitCastInst(FAbs, I.getType()); } } + // and(shl(zext(X), Y), SignMask) -> and(sext(X), SignMask) + // where Y is a valid shift amount. if (match(&I, m_And(m_OneUse(m_Shl(m_ZExt(m_Value(X)), m_Value(Y))), m_SignMask())) && match(Y, m_SpecificInt_ICMP( @@ -2470,15 +2595,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { Ty->getScalarSizeInBits() - X->getType()->getScalarSizeInBits())))) { auto *SExt = Builder.CreateSExt(X, Ty, X->getName() + ".signext"); - auto *SanitizedSignMask = cast<Constant>(Op1); - // We must be careful with the undef elements of the sign bit mask, however: - // the mask elt can be undef iff the shift amount for that lane was undef, - // otherwise we need to sanitize undef masks to zero. 
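Earlier in this visitAnd hunk, the "(C << X) & 1 --> zext(X == 0)" extension holds because shifting an odd constant left by any nonzero amount clears bit 0; only X == 0 keeps it set. A minimal check (the odd constant is arbitrary, chosen only for illustration):

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t C = 0x2D; // any odd constant
      for (uint32_t X = 0; X < 32; ++X) {
        uint32_t Orig = (C << X) & 1u;
        uint32_t Folded = (X == 0) ? 1u : 0u;
        assert(Orig == Folded);
      }
      return 0;
    }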
- SanitizedSignMask = Constant::replaceUndefsWith( - SanitizedSignMask, ConstantInt::getNullValue(Ty->getScalarType())); - SanitizedSignMask = - Constant::mergeUndefsWith(SanitizedSignMask, cast<Constant>(Y)); - return BinaryOperator::CreateAnd(SExt, SanitizedSignMask); + return BinaryOperator::CreateAnd(SExt, Op1); } if (Instruction *Z = narrowMaskedBinOp(I)) @@ -2505,13 +2622,6 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { { Value *A, *B, *C; - // A & (A ^ B) --> A & ~B - if (match(Op1, m_OneUse(m_c_Xor(m_Specific(Op0), m_Value(B))))) - return BinaryOperator::CreateAnd(Op0, Builder.CreateNot(B)); - // (A ^ B) & A --> A & ~B - if (match(Op0, m_OneUse(m_c_Xor(m_Specific(Op1), m_Value(B))))) - return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(B)); - // A & ~(A ^ B) --> A & B if (match(Op1, m_Not(m_c_Xor(m_Specific(Op0), m_Value(B))))) return BinaryOperator::CreateAnd(Op0, B); @@ -2637,7 +2747,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { // arm may not be reversible due to poison semantics. Is that a good // canonicalization? Value *A, *B; - if (match(&I, m_c_And(m_OneUse(m_SExt(m_Value(A))), m_Value(B))) && + if (match(&I, m_c_And(m_SExt(m_Value(A)), m_Value(B))) && A->getType()->isIntOrIntVectorTy(1)) return SelectInst::Create(A, B, Constant::getNullValue(Ty)); @@ -2667,7 +2777,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { // (iN X s>> (N-1)) & Y --> (X s< 0) ? Y : 0 -- with optional sext if (match(&I, m_c_And(m_OneUse(m_SExtOrSelf( - m_AShr(m_Value(X), m_APIntAllowUndef(C)))), + m_AShr(m_Value(X), m_APIntAllowPoison(C)))), m_Value(Y))) && *C == X->getType()->getScalarSizeInBits() - 1) { Value *IsNeg = Builder.CreateIsNeg(X, "isneg"); @@ -2676,7 +2786,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { // If there's a 'not' of the shifted value, swap the select operands: // ~(iN X s>> (N-1)) & Y --> (X s< 0) ? 0 : Y -- with optional sext if (match(&I, m_c_And(m_OneUse(m_SExtOrSelf( - m_Not(m_AShr(m_Value(X), m_APIntAllowUndef(C))))), + m_Not(m_AShr(m_Value(X), m_APIntAllowPoison(C))))), m_Value(Y))) && *C == X->getType()->getScalarSizeInBits() - 1) { Value *IsNeg = Builder.CreateIsNeg(X, "isneg"); @@ -2708,6 +2818,15 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { if (Instruction *Res = foldBitwiseLogicWithIntrinsics(I, Builder)) return Res; + if (Value *V = + simplifyAndOrWithOpReplaced(Op0, Op1, Constant::getAllOnesValue(Ty), + /*SimplifyOnly*/ false, *this)) + return BinaryOperator::CreateAnd(V, Op1); + if (Value *V = + simplifyAndOrWithOpReplaced(Op1, Op0, Constant::getAllOnesValue(Ty), + /*SimplifyOnly*/ false, *this)) + return BinaryOperator::CreateAnd(Op0, V); + return nullptr; } @@ -2726,17 +2845,18 @@ Instruction *InstCombinerImpl::matchBSwapOrBitReverse(Instruction &I, return LastInst; } -/// Match UB-safe variants of the funnel shift intrinsic. -static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, - const DominatorTree &DT) { +std::optional<std::pair<Intrinsic::ID, SmallVector<Value *, 3>>> +InstCombinerImpl::convertOrOfShiftsToFunnelShift(Instruction &Or) { // TODO: Can we reduce the code duplication between this and the related // rotate matching code under visitSelect and visitTrunc? 
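The or-of-shifts matching that convertOrOfShiftsToFunnelShift factors out targets the familiar rotate idiom, where the two shift amounts sum to the bit width. A representative source-level pattern of that shape (a generic rotate helper written for illustration, not taken from the patch):

    #include <cstdint>

    // rotl32(x, s) == (x << s) | (x >> (32 - s)), with both amounts kept in
    // range; the two shift amounts summing to the bit width is exactly the
    // relationship the matcher checks for.
    uint32_t rotl32(uint32_t X, uint32_t S) {
      S &= 31;
      return (X << S) | (X >> ((32 - S) & 31));
    }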
+ assert(Or.getOpcode() == BinaryOperator::Or && "Expecting or instruction"); + unsigned Width = Or.getType()->getScalarSizeInBits(); Instruction *Or0, *Or1; if (!match(Or.getOperand(0), m_Instruction(Or0)) || !match(Or.getOperand(1), m_Instruction(Or1))) - return nullptr; + return std::nullopt; bool IsFshl = true; // Sub on LSHR. SmallVector<Value *, 3> FShiftArgs; @@ -2750,7 +2870,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) || Or0->getOpcode() == Or1->getOpcode()) - return nullptr; + return std::nullopt; // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)). if (Or0->getOpcode() == BinaryOperator::LShr) { @@ -2767,7 +2887,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * { // Check for constant shift amounts that sum to the bitwidth. const APInt *LI, *RI; - if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI))) + if (match(L, m_APIntAllowPoison(LI)) && match(R, m_APIntAllowPoison(RI))) if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width) return ConstantInt::get(L->getType(), *LI); @@ -2777,7 +2897,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) && match(R, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) && - match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width))) + match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowPoison(Width))) return ConstantExpr::mergeUndefsWith(LC, RC); // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width. @@ -2786,7 +2906,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, // might remove it after this fold). This still doesn't guarantee that the // final codegen will match this original pattern. if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) { - KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or); + KnownBits KnownL = computeKnownBits(L, /*Depth*/ 0, &Or); return KnownL.getMaxValue().ult(Width) ? L : nullptr; } @@ -2834,7 +2954,7 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, IsFshl = false; // Sub on SHL. } if (!ShAmt) - return nullptr; + return std::nullopt; FShiftArgs = {ShVal0, ShVal1, ShAmt}; } else if (isa<ZExtInst>(Or0) || isa<ZExtInst>(Or1)) { @@ -2856,18 +2976,18 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, const APInt *ZextHighShlAmt; if (!match(Or0, m_OneUse(m_Shl(m_Value(ZextHigh), m_APInt(ZextHighShlAmt))))) - return nullptr; + return std::nullopt; if (!match(Or1, m_ZExt(m_Value(Low))) || !match(ZextHigh, m_ZExt(m_Value(High)))) - return nullptr; + return std::nullopt; unsigned HighSize = High->getType()->getScalarSizeInBits(); unsigned LowSize = Low->getType()->getScalarSizeInBits(); // Make sure High does not overlap with Low and most significant bits of // High aren't shifted out. if (ZextHighShlAmt->ult(LowSize) || ZextHighShlAmt->ugt(Width - HighSize)) - return nullptr; + return std::nullopt; for (User *U : ZextHigh->users()) { Value *X, *Y; @@ -2898,11 +3018,21 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC, } if (FShiftArgs.empty()) - return nullptr; + return std::nullopt; Intrinsic::ID IID = IsFshl ? 
Intrinsic::fshl : Intrinsic::fshr; - Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType()); - return CallInst::Create(F, FShiftArgs); + return std::make_pair(IID, FShiftArgs); +} + +/// Match UB-safe variants of the funnel shift intrinsic. +static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC) { + if (auto Opt = IC.convertOrOfShiftsToFunnelShift(Or)) { + auto [IID, FShiftArgs] = *Opt; + Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType()); + return CallInst::Create(F, FShiftArgs); + } + + return nullptr; } /// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns. @@ -3058,20 +3188,20 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B, return nullptr; } -/// We have an expression of the form (A & C) | (B & D). Try to simplify this -/// to "A' ? C : D", where A' is a boolean or vector of booleans. +/// We have an expression of the form (A & B) | (C & D). Try to simplify this +/// to "A' ? B : D", where A' is a boolean or vector of booleans. /// When InvertFalseVal is set to true, we try to match the pattern -/// where we have peeked through a 'not' op and A and B are the same: -/// (A & C) | ~(A | D) --> (A & C) | (~A & ~D) --> A' ? C : ~D -Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, +/// where we have peeked through a 'not' op and A and C are the same: +/// (A & B) | ~(A | D) --> (A & B) | (~A & ~D) --> A' ? B : ~D +Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D, bool InvertFalseVal) { // The potential condition of the select may be bitcasted. In that case, look // through its bitcast and the corresponding bitcast of the 'not' condition. Type *OrigType = A->getType(); A = peekThroughBitcast(A, true); - B = peekThroughBitcast(B, true); - if (Value *Cond = getSelectCondition(A, B, InvertFalseVal)) { - // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D)) + C = peekThroughBitcast(C, true); + if (Value *Cond = getSelectCondition(A, C, InvertFalseVal)) { + // ((bc Cond) & B) | ((bc ~Cond) & D) --> bc (select Cond, (bc B), (bc D)) // If this is a vector, we may need to cast to match the condition's length. // The bitcasts will either all exist or all not exist. The builder will // not create unnecessary casts if the types already match. 
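matchSelectFromAndOr rests on the identity (A & B) | (~A & D) == A' ? B : D when A is an all-ones/all-zeros mask (for example a sign-extended i1). A small check of that bitwise select identity, with the mask built by hand (sketch only, not from the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t B = 0x12345678u, D = 0x9ABCDEF0u;
      for (bool Cond : {false, true}) {
        uint32_t A = Cond ? 0xFFFFFFFFu : 0u; // sign-extended i1 mask
        uint32_t Orig = (A & B) | (~A & D);
        uint32_t Folded = Cond ? B : D;
        assert(Orig == Folded);
      }
      return 0;
    }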
@@ -3085,11 +3215,11 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, Type *EltTy = Builder.getIntNTy(SelEltSize / Elts); SelTy = VectorType::get(EltTy, VecTy->getElementCount()); } - Value *BitcastC = Builder.CreateBitCast(C, SelTy); + Value *BitcastB = Builder.CreateBitCast(B, SelTy); if (InvertFalseVal) D = Builder.CreateNot(D); Value *BitcastD = Builder.CreateBitCast(D, SelTy); - Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD); + Value *Select = Builder.CreateSelect(Cond, BitcastB, BitcastD); return Builder.CreateBitCast(Select, OrigType); } @@ -3112,14 +3242,14 @@ static Value *foldAndOrOfICmpEqConstantAndICmp(ICmpInst *LHS, ICmpInst *RHS, const APInt *CInt; if (LPred != ICmpInst::ICMP_EQ || - !match(LHS->getOperand(1), m_APIntAllowUndef(CInt)) || + !match(LHS->getOperand(1), m_APIntAllowPoison(CInt)) || !LHS0->getType()->isIntOrIntVectorTy() || !(LHS->hasOneUse() || RHS->hasOneUse())) return nullptr; auto MatchRHSOp = [LHS0, CInt](const Value *RHSOp) { return match(RHSOp, - m_Add(m_Specific(LHS0), m_SpecificIntAllowUndef(-*CInt))) || + m_Add(m_Specific(LHS0), m_SpecificIntAllowPoison(-*CInt))) || (CInt->isZero() && RHSOp == LHS0); }; @@ -3157,6 +3287,7 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1); + const APInt *LHSC = nullptr, *RHSC = nullptr; match(LHS1, m_APInt(LHSC)); match(RHS1, m_APInt(RHSC)); @@ -3224,7 +3355,7 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldSignedTruncationCheck(LHS, RHS, I, Builder)) return V; - if (Value *V = foldIsPowerOf2(LHS, RHS, IsAnd, Builder)) + if (Value *V = foldIsPowerOf2(LHS, RHS, IsAnd, Builder, *this)) return V; if (Value *V = foldPowerOf2AndShiftedMask(LHS, RHS, IsAnd, Builder)) @@ -3262,6 +3393,11 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Constant::getAllOnesValue(LHS0->getType())); } + if (!IsLogical) + if (Value *V = + foldAndOrOfICmpsWithPow2AndWithZero(Builder, LHS, RHS, IsAnd, Q)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). if (!LHSC || !RHSC) return nullptr; @@ -3338,6 +3474,25 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, return foldAndOrOfICmpsUsingRanges(LHS, RHS, IsAnd); } +static Value *foldOrOfInversions(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert(I.getOpcode() == Instruction::Or && + "Simplification only supports or at the moment."); + + Value *Cmp1, *Cmp2, *Cmp3, *Cmp4; + if (!match(I.getOperand(0), m_And(m_Value(Cmp1), m_Value(Cmp2))) || + !match(I.getOperand(1), m_And(m_Value(Cmp3), m_Value(Cmp4)))) + return nullptr; + + // Check if any two pairs of the and operations are inversions of each other. + if (isKnownInversion(Cmp1, Cmp3) && isKnownInversion(Cmp2, Cmp4)) + return Builder.CreateXor(Cmp1, Cmp4); + if (isKnownInversion(Cmp1, Cmp4) && isKnownInversion(Cmp2, Cmp3)) + return Builder.CreateXor(Cmp1, Cmp3); + + return nullptr; +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. 
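The new foldOrOfInversions just above handles the case where the two and-operands are pairwise inversions: with C == ~A and D == ~B, (A & B) | (C & D) is the XNOR of A and B, which equals A ^ D. A quick check of that identity on plain integers (illustrative; the fold itself applies to values known to be inversions, typically i1 conditions):

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint8_t A = 0xCA, B = 0xA6;
      const uint8_t C = static_cast<uint8_t>(~A); // inversion of A
      const uint8_t D = static_cast<uint8_t>(~B); // inversion of B
      uint8_t Orig = (A & B) | (C & D);           // bitwise XNOR of A and B
      uint8_t Folded = A ^ D;
      assert(Orig == Folded);
      return 0;
    }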
@@ -3367,6 +3522,11 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { if (Instruction *X = foldComplexAndOrPatterns(I, Builder)) return X; + // (A & B) | (C & D) -> A ^ D where A == ~C && B == ~D + // (A & B) | (C & D) -> A ^ C where A == ~D && B == ~C + if (Value *V = foldOrOfInversions(I, Builder)) + return replaceInstUsesWith(I, V); + // (A&B)|(A&C) -> A&(B|C) etc if (Value *V = foldUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); @@ -3393,7 +3553,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { /*MatchBitReversals*/ true)) return BitOp; - if (Instruction *Funnel = matchFunnelShift(I, *this, DT)) + if (Instruction *Funnel = matchFunnelShift(I, *this)) return Funnel; if (Instruction *Concat = matchOrConcat(I, Builder)) @@ -3423,10 +3583,6 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { return BinaryOperator::CreateMul(X, IncrementY); } - // X | (X ^ Y) --> X | Y (4 commuted patterns) - if (match(&I, m_c_Or(m_Value(X), m_c_Xor(m_Deferred(X), m_Value(Y))))) - return BinaryOperator::CreateOr(X, Y); - // (A & C) | (B & D) Value *A, *B, *C, *D; if (match(Op0, m_And(m_Value(A), m_Value(C))) && @@ -3520,30 +3676,18 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { // (A ^ B) | ((B ^ C) ^ A) -> (A ^ B) | C if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) - if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) + if (match(Op1, + m_c_Xor(m_c_Xor(m_Specific(B), m_Value(C)), m_Specific(A))) || + match(Op1, m_c_Xor(m_c_Xor(m_Specific(A), m_Value(C)), m_Specific(B)))) return BinaryOperator::CreateOr(Op0, C); - // ((A ^ C) ^ B) | (B ^ A) -> (B ^ A) | C - if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B)))) - if (match(Op1, m_Xor(m_Specific(B), m_Specific(A)))) + // ((B ^ C) ^ A) | (A ^ B) -> (A ^ B) | C + if (match(Op1, m_Xor(m_Value(A), m_Value(B)))) + if (match(Op0, + m_c_Xor(m_c_Xor(m_Specific(B), m_Value(C)), m_Specific(A))) || + match(Op0, m_c_Xor(m_c_Xor(m_Specific(A), m_Value(C)), m_Specific(B)))) return BinaryOperator::CreateOr(Op1, C); - // ((A & B) ^ C) | B -> C | B - if (match(Op0, m_c_Xor(m_c_And(m_Value(A), m_Specific(Op1)), m_Value(C)))) - return BinaryOperator::CreateOr(C, Op1); - - // B | ((A & B) ^ C) -> B | C - if (match(Op1, m_c_Xor(m_c_And(m_Value(A), m_Specific(Op0)), m_Value(C)))) - return BinaryOperator::CreateOr(Op0, C); - - // ((B | C) & A) | B -> B | (A & C) - if (match(Op0, m_c_And(m_c_Or(m_Specific(Op1), m_Value(C)), m_Value(A)))) - return BinaryOperator::CreateOr(Op1, Builder.CreateAnd(A, C)); - - // B | ((B | C) & A) -> B | (A & C) - if (match(Op1, m_c_And(m_c_Or(m_Specific(Op0), m_Value(C)), m_Value(A)))) - return BinaryOperator::CreateOr(Op0, Builder.CreateAnd(A, C)); - if (Instruction *DeMorgan = matchDeMorgansLaws(I, *this)) return DeMorgan; @@ -3564,8 +3708,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { // (A & B) | (A ^ B) --> A | B // (B & A) | (A ^ B) --> A | B - if (match(Op0, m_And(m_Specific(A), m_Specific(B))) || - match(Op0, m_And(m_Specific(B), m_Specific(A)))) + if (match(Op0, m_c_And(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateOr(A, B); // ~A | (A ^ B) --> ~(A & B) @@ -3596,32 +3739,8 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { Value *Nand = Builder.CreateNot(Builder.CreateAnd(A, B), "nand"); return BinaryOperator::CreateOr(Nand, C); } - - // A | (~A ^ B) --> ~B | A - // B | (A ^ ~B) --> ~A | B - if (Op1->hasOneUse() && match(A, m_Not(m_Specific(Op0)))) { - Value *NotB = Builder.CreateNot(B, B->getName() + 
".not"); - return BinaryOperator::CreateOr(NotB, Op0); - } - if (Op1->hasOneUse() && match(B, m_Not(m_Specific(Op0)))) { - Value *NotA = Builder.CreateNot(A, A->getName() + ".not"); - return BinaryOperator::CreateOr(NotA, Op0); - } } - // A | ~(A | B) -> A | ~B - // A | ~(A ^ B) -> A | ~B - if (match(Op1, m_Not(m_Value(A)))) - if (BinaryOperator *B = dyn_cast<BinaryOperator>(A)) - if ((Op0 == B->getOperand(0) || Op0 == B->getOperand(1)) && - Op1->hasOneUse() && (B->getOpcode() == Instruction::Or || - B->getOpcode() == Instruction::Xor)) { - Value *NotOp = Op0 == B->getOperand(0) ? B->getOperand(1) : - B->getOperand(0); - Value *Not = Builder.CreateNot(NotOp, NotOp->getName() + ".not"); - return BinaryOperator::CreateOr(Not, Op0); - } - if (SwappedForXor) std::swap(Op0, Op1); @@ -3742,7 +3861,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { const auto TryXorOpt = [&](Value *Lhs, Value *Rhs) -> Instruction * { if (match(Lhs, m_c_Xor(m_And(m_Value(A), m_Value(B)), m_Deferred(A))) && match(Rhs, - m_c_Xor(m_And(m_Specific(A), m_Specific(B)), m_Deferred(B)))) { + m_c_Xor(m_And(m_Specific(A), m_Specific(B)), m_Specific(B)))) { return BinaryOperator::CreateXor(A, B); } return nullptr; @@ -3887,13 +4006,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { // This is generous interpretation of noimplicitfloat, this is not a true // floating-point operation. Value *CastOp; - if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) && + if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) && + match(Op1, m_SignMask()) && !Builder.GetInsertBlock()->getParent()->hasFnAttribute( Attribute::NoImplicitFloat)) { Type *EltTy = CastOp->getType()->getScalarType(); - if (EltTy->isFloatingPointTy() && EltTy->isIEEE() && - EltTy->getPrimitiveSizeInBits() == - I.getType()->getScalarType()->getPrimitiveSizeInBits()) { + if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) { Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp); Value *FNegFAbs = Builder.CreateFNeg(FAbs); return new BitCastInst(FNegFAbs, I.getType()); @@ -3911,6 +4029,19 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { if (Instruction *Res = foldBitwiseLogicWithIntrinsics(I, Builder)) return Res; + if (Value *V = + simplifyAndOrWithOpReplaced(Op0, Op1, Constant::getNullValue(Ty), + /*SimplifyOnly*/ false, *this)) + return BinaryOperator::CreateOr(V, Op1); + if (Value *V = + simplifyAndOrWithOpReplaced(Op1, Op0, Constant::getNullValue(Ty), + /*SimplifyOnly*/ false, *this)) + return BinaryOperator::CreateOr(Op0, V); + + if (cast<PossiblyDisjointInst>(I).isDisjoint()) + if (Value *V = SimplifyAddWithRemainder(I)) + return replaceInstUsesWith(I, V); + return nullptr; } @@ -4193,10 +4324,11 @@ static Instruction *canonicalizeAbs(BinaryOperator &Xor, // xor (add A, Op1), Op1 ; add -1 and flip bits if negative // --> (A < 0) ? -A : A Value *IsNeg = Builder.CreateIsNeg(A); - // Copy the nuw/nsw flags from the add to the negate. + // Copy the nsw flags from the add to the negate. auto *Add = cast<BinaryOperator>(Op0); - Value *NegA = Builder.CreateNeg(A, "", Add->hasNoUnsignedWrap(), - Add->hasNoSignedWrap()); + Value *NegA = Add->hasNoUnsignedWrap() + ? 
Constant::getNullValue(A->getType()) + : Builder.CreateNeg(A, "", Add->hasNoSignedWrap()); return SelectInst::Create(IsNeg, NegA, A); } return nullptr; @@ -4385,26 +4517,16 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { // ~(C >>s Y) --> ~C >>u Y (when inverting the replicated sign bits) Constant *C; if (match(NotVal, m_AShr(m_Constant(C), m_Value(Y))) && - match(C, m_Negative())) { - // We matched a negative constant, so propagating undef is unsafe. - // Clamp undef elements to -1. - Type *EltTy = Ty->getScalarType(); - C = Constant::replaceUndefsWith(C, ConstantInt::getAllOnesValue(EltTy)); + match(C, m_Negative())) return BinaryOperator::CreateLShr(ConstantExpr::getNot(C), Y); - } // ~(C >>u Y) --> ~C >>s Y (when inverting the replicated sign bits) if (match(NotVal, m_LShr(m_Constant(C), m_Value(Y))) && - match(C, m_NonNegative())) { - // We matched a non-negative constant, so propagating undef is unsafe. - // Clamp undef elements to 0. - Type *EltTy = Ty->getScalarType(); - C = Constant::replaceUndefsWith(C, ConstantInt::getNullValue(EltTy)); + match(C, m_NonNegative())) return BinaryOperator::CreateAShr(ConstantExpr::getNot(C), Y); - } // ~(X + C) --> ~C - X - if (match(NotVal, m_c_Add(m_Value(X), m_ImmConstant(C)))) + if (match(NotVal, m_Add(m_Value(X), m_ImmConstant(C)))) return BinaryOperator::CreateSub(ConstantExpr::getNot(C), X); // ~(X - Y) --> ~X + Y @@ -4547,8 +4669,12 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Value *M; if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(M)), m_Value()), - m_c_And(m_Deferred(M), m_Value())))) - return BinaryOperator::CreateDisjointOr(Op0, Op1); + m_c_And(m_Deferred(M), m_Value())))) { + if (isGuaranteedNotToBeUndef(M)) + return BinaryOperator::CreateDisjointOr(Op0, Op1); + else + return BinaryOperator::CreateOr(Op0, Op1); + } if (Instruction *Xor = visitMaskedMerge(I, Builder)) return Xor; @@ -4587,7 +4713,7 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { // constant depending on whether this input is less than 0. const APInt *CA; if (match(Op0, m_OneUse(m_TruncOrSelf( - m_AShr(m_Value(X), m_APIntAllowUndef(CA))))) && + m_AShr(m_Value(X), m_APIntAllowPoison(CA))))) && *CA == X->getType()->getScalarSizeInBits() - 1 && !match(C1, m_AllOnes())) { assert(!C1->isZeroValue() && "Unexpected xor with 0"); @@ -4658,13 +4784,12 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { // Assumes any IEEE-represented type has the sign bit in the high bit. // TODO: Unify with APInt matcher. 
This version allows undef unlike m_APInt Value *CastOp; - if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) && + if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) && + match(Op1, m_SignMask()) && !Builder.GetInsertBlock()->getParent()->hasFnAttribute( Attribute::NoImplicitFloat)) { Type *EltTy = CastOp->getType()->getScalarType(); - if (EltTy->isFloatingPointTy() && EltTy->isIEEE() && - EltTy->getPrimitiveSizeInBits() == - I.getType()->getScalarType()->getPrimitiveSizeInBits()) { + if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) { Value *FNeg = Builder.CreateFNeg(CastOp); return new BitCastInst(FNeg, I.getType()); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index bc43edb5e620..3223fccbcf49 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -171,21 +171,8 @@ Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { IntegerType* IntType = IntegerType::get(MI->getContext(), Size<<3); // If the memcpy has metadata describing the members, see if we can get the - // TBAA tag describing our copy. - AAMDNodes AACopyMD = MI->getAAMetadata(); - - if (MDNode *M = AACopyMD.TBAAStruct) { - AACopyMD.TBAAStruct = nullptr; - if (M->getNumOperands() == 3 && M->getOperand(0) && - mdconst::hasa<ConstantInt>(M->getOperand(0)) && - mdconst::extract<ConstantInt>(M->getOperand(0))->isZero() && - M->getOperand(1) && - mdconst::hasa<ConstantInt>(M->getOperand(1)) && - mdconst::extract<ConstantInt>(M->getOperand(1))->getValue() == - Size && - M->getOperand(2) && isa<MDNode>(M->getOperand(2))) - AACopyMD.TBAA = cast<MDNode>(M->getOperand(2)); - } + // TBAA, scope and noalias tags describing our copy. + AAMDNodes AACopyMD = MI->getAAMetadata().adjustForAccess(Size); Value *Src = MI->getArgOperand(1); Value *Dest = MI->getArgOperand(0); @@ -287,7 +274,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) { DbgAssign->replaceVariableLocationOp(FillC, FillVal); }; for_each(at::getAssignmentMarkers(S), replaceOpForAssignmentMarkers); - for_each(at::getDPVAssignmentMarkers(S), replaceOpForAssignmentMarkers); + for_each(at::getDVRAssignmentMarkers(S), replaceOpForAssignmentMarkers); S->setAlignment(Alignment); if (isa<AtomicMemSetInst>(MI)) @@ -320,7 +307,7 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) { // If we can unconditionally load from this address, replace with a // load/select idiom. TODO: use DT for context sensitive query if (isDereferenceablePointer(LoadPtr, II.getType(), - II.getModule()->getDataLayout(), &II, &AC)) { + II.getDataLayout(), &II, &AC)) { LoadInst *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment, "unmaskedload"); LI->copyMetadata(II); @@ -517,6 +504,13 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { return IC.replaceInstUsesWith(II, ConstantInt::getNullValue(II.getType())); } + // If ctlz/cttz is only used as a shift amount, set is_zero_poison to true. 
+ if (II.hasOneUse() && match(Op1, m_Zero()) && + match(II.user_back(), m_Shift(m_Value(), m_Specific(&II)))) { + II.dropUBImplyingAttrsAndMetadata(); + return IC.replaceOperand(II, 1, IC.Builder.getTrue()); + } + Constant *C; if (IsTZ) { @@ -570,6 +564,13 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { IC.Builder.CreateBinaryIntrinsic(Intrinsic::cttz, C, Op1); return BinaryOperator::CreateSub(ConstCttz, X); } + + // cttz(add(lshr(UINT_MAX, %val), 1)) --> sub(width, %val) + if (match(Op0, m_Add(m_LShr(m_AllOnes(), m_Value(X)), m_One()))) { + Value *Width = + ConstantInt::get(II.getType(), II.getType()->getScalarSizeInBits()); + return BinaryOperator::CreateSub(Width, X); + } } else { // ctlz(lshr(%const, %val), 1) --> add(ctlz(%const, 1), %val) if (match(Op0, m_LShr(m_ImmConstant(C), m_Value(X))) && @@ -609,20 +610,18 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { // then change the 'ZeroIsPoison' parameter to 'true' // because we know the zero behavior can't affect the result. if (!Known.One.isZero() || - isKnownNonZero(Op0, IC.getDataLayout(), 0, &IC.getAssumptionCache(), &II, - &IC.getDominatorTree())) { + isKnownNonZero(Op0, IC.getSimplifyQuery().getWithInstruction(&II))) { if (!match(II.getArgOperand(1), m_One())) return IC.replaceOperand(II, 1, IC.Builder.getTrue()); } - // Add range metadata since known bits can't completely reflect what we know. - auto *IT = cast<IntegerType>(Op0->getType()->getScalarType()); - if (IT && IT->getBitWidth() != 1 && !II.getMetadata(LLVMContext::MD_range)) { - Metadata *LowAndHigh[] = { - ConstantAsMetadata::get(ConstantInt::get(IT, DefiniteZeros)), - ConstantAsMetadata::get(ConstantInt::get(IT, PossibleZeros + 1))}; - II.setMetadata(LLVMContext::MD_range, - MDNode::get(II.getContext(), LowAndHigh)); + // Add range attribute since known bits can't completely reflect what we know. + unsigned BitWidth = Op0->getType()->getScalarSizeInBits(); + if (BitWidth != 1 && !II.hasRetAttr(Attribute::Range) && + !II.getMetadata(LLVMContext::MD_range)) { + ConstantRange Range(APInt(BitWidth, DefiniteZeros), + APInt(BitWidth, PossibleZeros + 1)); + II.addRangeRetAttr(Range); return &II; } @@ -694,16 +693,12 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) { Constant::getNullValue(Ty)), Ty); - // Add range metadata since known bits can't completely reflect what we know. - auto *IT = cast<IntegerType>(Ty->getScalarType()); - unsigned MinCount = Known.countMinPopulation(); - unsigned MaxCount = Known.countMaxPopulation(); - if (IT->getBitWidth() != 1 && !II.getMetadata(LLVMContext::MD_range)) { - Metadata *LowAndHigh[] = { - ConstantAsMetadata::get(ConstantInt::get(IT, MinCount)), - ConstantAsMetadata::get(ConstantInt::get(IT, MaxCount + 1))}; - II.setMetadata(LLVMContext::MD_range, - MDNode::get(II.getContext(), LowAndHigh)); + // Add range attribute since known bits can't completely reflect what we know. 
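Also in this foldCttzCtlz hunk, the new "cttz(add(lshr(UINT_MAX, %val), 1)) --> sub(width, %val)" fold follows because UINT_MAX >> v has exactly the low (width - v) bits set, so adding one produces a single bit at position width - v (or wraps to zero when v == 0). A check with std::countr_zero standing in for llvm.cttz with is_zero_poison == false (sketch only; assumes C++20):

    #include <bit>
    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned V = 0; V < 32; ++V) {
        uint32_t X = (UINT32_MAX >> V) + 1u; // wraps to 0 when V == 0
        assert(std::countr_zero(X) == static_cast<int>(32 - V));
      }
      return 0;
    }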
+ if (BitWidth != 1 && !II.hasRetAttr(Attribute::Range) && + !II.getMetadata(LLVMContext::MD_range)) { + ConstantRange Range(APInt(BitWidth, Known.countMinPopulation()), + APInt(BitWidth, Known.countMaxPopulation() + 1)); + II.addRangeRetAttr(Range); return &II; } @@ -918,7 +913,8 @@ Instruction *InstCombinerImpl::foldIntrinsicIsFPClass(IntrinsicInst &II) { const FPClassTest OrderedMask = Mask & ~fcNan; const FPClassTest OrderedInvertedMask = ~OrderedMask & ~fcNan; - const bool IsStrict = II.isStrictFP(); + const bool IsStrict = + II.getFunction()->getAttributes().hasFnAttr(Attribute::StrictFP); Value *FNegSrc; if (match(Src0, m_FNeg(m_Value(FNegSrc)))) { @@ -1047,10 +1043,8 @@ Instruction *InstCombinerImpl::foldIntrinsicIsFPClass(IntrinsicInst &II) { return nullptr; } -static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI, - const DataLayout &DL, AssumptionCache *AC, - DominatorTree *DT) { - KnownBits Known = computeKnownBits(Op, DL, 0, AC, CxtI, DT); +static std::optional<bool> getKnownSign(Value *Op, const SimplifyQuery &SQ) { + KnownBits Known = computeKnownBits(Op, /*Depth=*/0, SQ); if (Known.isNonNegative()) return false; if (Known.isNegative()) @@ -1058,34 +1052,30 @@ static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI, Value *X, *Y; if (match(Op, m_NSWSub(m_Value(X), m_Value(Y)))) - return isImpliedByDomCondition(ICmpInst::ICMP_SLT, X, Y, CxtI, DL); + return isImpliedByDomCondition(ICmpInst::ICMP_SLT, X, Y, SQ.CxtI, SQ.DL); - return isImpliedByDomCondition( - ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL); + return std::nullopt; } -static std::optional<bool> getKnownSignOrZero(Value *Op, Instruction *CxtI, - const DataLayout &DL, - AssumptionCache *AC, - DominatorTree *DT) { - if (std::optional<bool> Sign = getKnownSign(Op, CxtI, DL, AC, DT)) +static std::optional<bool> getKnownSignOrZero(Value *Op, + const SimplifyQuery &SQ) { + if (std::optional<bool> Sign = getKnownSign(Op, SQ)) return Sign; Value *X, *Y; if (match(Op, m_NSWSub(m_Value(X), m_Value(Y)))) - return isImpliedByDomCondition(ICmpInst::ICMP_SLE, X, Y, CxtI, DL); + return isImpliedByDomCondition(ICmpInst::ICMP_SLE, X, Y, SQ.CxtI, SQ.DL); return std::nullopt; } /// Return true if two values \p Op0 and \p Op1 are known to have the same sign. -static bool signBitMustBeTheSame(Value *Op0, Value *Op1, Instruction *CxtI, - const DataLayout &DL, AssumptionCache *AC, - DominatorTree *DT) { - std::optional<bool> Known1 = getKnownSign(Op1, CxtI, DL, AC, DT); +static bool signBitMustBeTheSame(Value *Op0, Value *Op1, + const SimplifyQuery &SQ) { + std::optional<bool> Known1 = getKnownSign(Op1, SQ); if (!Known1) return false; - std::optional<bool> Known0 = getKnownSign(Op0, CxtI, DL, AC, DT); + std::optional<bool> Known0 = getKnownSign(Op0, SQ); if (!Known0) return false; return *Known0 == *Known1; @@ -1235,10 +1225,11 @@ static Instruction *foldClampRangeOfTwo(IntrinsicInst *II, /// If this min/max has a constant operand and an operand that is a matching /// min/max with a constant operand, constant-fold the 2 constant operands. 
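That constant-folding of nested min/max calls relies on min/max being associative, so two nested clamps against constants collapse into a single min/max against the folded constant. A trivial check for the smax case (illustrative only):

    #include <algorithm>
    #include <cassert>

    int main() {
      const int C0 = 7, C1 = 42;
      for (int X : {-5, 0, 7, 41, 42, 100}) {
        int Orig = std::max(std::max(X, C0), C1);
        int Folded = std::max(X, std::max(C0, C1));
        assert(Orig == Folded);
      }
      return 0;
    }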
static Value *reassociateMinMaxWithConstants(IntrinsicInst *II, - IRBuilderBase &Builder) { + IRBuilderBase &Builder, + const SimplifyQuery &SQ) { Intrinsic::ID MinMaxID = II->getIntrinsicID(); - auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0)); - if (!LHS || LHS->getIntrinsicID() != MinMaxID) + auto *LHS = dyn_cast<MinMaxIntrinsic>(II->getArgOperand(0)); + if (!LHS) return nullptr; Constant *C0, *C1; @@ -1246,11 +1237,21 @@ static Value *reassociateMinMaxWithConstants(IntrinsicInst *II, !match(II->getArgOperand(1), m_ImmConstant(C1))) return nullptr; - // max (max X, C0), C1 --> max X, (max C0, C1) --> max X, NewC + // max (max X, C0), C1 --> max X, (max C0, C1) + // min (min X, C0), C1 --> min X, (min C0, C1) + // umax (smax X, nneg C0), nneg C1 --> smax X, (umax C0, C1) + // smin (umin X, nneg C0), nneg C1 --> umin X, (smin C0, C1) + Intrinsic::ID InnerMinMaxID = LHS->getIntrinsicID(); + if (InnerMinMaxID != MinMaxID && + !(((MinMaxID == Intrinsic::umax && InnerMinMaxID == Intrinsic::smax) || + (MinMaxID == Intrinsic::smin && InnerMinMaxID == Intrinsic::umin)) && + isKnownNonNegative(C0, SQ) && isKnownNonNegative(C1, SQ))) + return nullptr; + ICmpInst::Predicate Pred = MinMaxIntrinsic::getPredicate(MinMaxID); Value *CondC = Builder.CreateICmp(Pred, C0, C1); Value *NewC = Builder.CreateSelect(CondC, C0, C1); - return Builder.CreateIntrinsic(MinMaxID, II->getType(), + return Builder.CreateIntrinsic(InnerMinMaxID, II->getType(), {LHS->getArgOperand(0), NewC}); } @@ -1430,6 +1431,70 @@ static Instruction *foldBitOrderCrossLogicOp(Value *V, return nullptr; } +static Value *simplifyReductionOperand(Value *Arg, bool CanReorderLanes) { + if (!CanReorderLanes) + return nullptr; + + Value *V; + if (match(Arg, m_VecReverse(m_Value(V)))) + return V; + + ArrayRef<int> Mask; + if (!isa<FixedVectorType>(Arg->getType()) || + !match(Arg, m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))) || + !cast<ShuffleVectorInst>(Arg)->isSingleSource()) + return nullptr; + + int Sz = Mask.size(); + SmallBitVector UsedIndices(Sz); + for (int Idx : Mask) { + if (Idx == PoisonMaskElem || UsedIndices.test(Idx)) + return nullptr; + UsedIndices.set(Idx); + } + + // Can remove shuffle iff just shuffled elements, no repeats, undefs, or + // other changes. + return UsedIndices.all() ? V : nullptr; +} + +/// Fold an unsigned minimum of trailing or leading zero bits counts: +/// umin(cttz(CtOp, ZeroUndef), ConstOp) --> cttz(CtOp | (1 << ConstOp)) +/// umin(ctlz(CtOp, ZeroUndef), ConstOp) --> ctlz(CtOp | (SignedMin +/// >> ConstOp)) +template <Intrinsic::ID IntrID> +static Value * +foldMinimumOverTrailingOrLeadingZeroCount(Value *I0, Value *I1, + const DataLayout &DL, + InstCombiner::BuilderTy &Builder) { + static_assert(IntrID == Intrinsic::cttz || IntrID == Intrinsic::ctlz, + "This helper only supports cttz and ctlz intrinsics"); + + Value *CtOp; + Value *ZeroUndef; + if (!match(I0, + m_OneUse(m_Intrinsic<IntrID>(m_Value(CtOp), m_Value(ZeroUndef))))) + return nullptr; + + unsigned BitWidth = I1->getType()->getScalarSizeInBits(); + auto LessBitWidth = [BitWidth](auto &C) { return C.ult(BitWidth); }; + if (!match(I1, m_CheckedInt(LessBitWidth))) + // We have a constant >= BitWidth (which can be handled by CVP) + // or a non-splat vector with elements < and >= BitWidth + return nullptr; + + Type *Ty = I1->getType(); + Constant *NewConst = ConstantFoldBinaryOpOperands( + IntrID == Intrinsic::cttz ? Instruction::Shl : Instruction::LShr, + IntrID == Intrinsic::cttz + ? 
ConstantInt::get(Ty, 1) + : ConstantInt::get(Ty, APInt::getSignedMinValue(BitWidth)), + cast<Constant>(I1), DL); + return Builder.CreateBinaryIntrinsic( + IntrID, Builder.CreateOr(CtOp, NewConst), + ConstantInt::getTrue(ZeroUndef->getType())); +} + /// CallInst simplification. This mostly only handles folding of intrinsic /// instructions. For normal calls, it allows visitCallBase to do the heavy /// lifting. @@ -1584,8 +1649,19 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X)))) return replaceOperand(*II, 0, X); + Value *Y; + // abs(a * abs(b)) -> abs(a * b) + if (match(IIOperand, + m_OneUse(m_c_Mul(m_Value(X), + m_Intrinsic<Intrinsic::abs>(m_Value(Y)))))) { + bool NSW = + cast<Instruction>(IIOperand)->hasNoSignedWrap() && IntMinIsPoison; + auto *XY = NSW ? Builder.CreateNSWMul(X, Y) : Builder.CreateMul(X, Y); + return replaceOperand(*II, 0, XY); + } + if (std::optional<bool> Known = - getKnownSignOrZero(IIOperand, II, DL, &AC, &DT)) { + getKnownSignOrZero(IIOperand, SQ.getWithInstruction(II))) { // abs(x) -> x if x >= 0 (include abs(x-y) --> x - y where x >= y) // abs(x) -> x if x > 0 (include abs(x-y) --> x - y where x > y) if (!*Known) @@ -1624,6 +1700,16 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *Cmp = Builder.CreateICmpNE(I0, Zero); return CastInst::Create(Instruction::ZExt, Cmp, II->getType()); } + // umin(cttz(x), const) --> cttz(x | (1 << const)) + if (Value *FoldedCttz = + foldMinimumOverTrailingOrLeadingZeroCount<Intrinsic::cttz>( + I0, I1, DL, Builder)) + return replaceInstUsesWith(*II, FoldedCttz); + // umin(ctlz(x), const) --> ctlz(x | (SignedMin >> const)) + if (Value *FoldedCtlz = + foldMinimumOverTrailingOrLeadingZeroCount<Intrinsic::ctlz>( + I0, I1, DL, Builder)) + return replaceInstUsesWith(*II, FoldedCtlz); [[fallthrough]]; } case Intrinsic::umax: { @@ -1710,7 +1796,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { bool UseAndN = IID == Intrinsic::smin || IID == Intrinsic::umin; if (IID == Intrinsic::smax || IID == Intrinsic::smin) { - auto KnownSign = getKnownSign(X, II, DL, &AC, &DT); + auto KnownSign = getKnownSign(X, SQ.getWithInstruction(II)); if (KnownSign == std::nullopt) { UseOr = false; UseAndN = false; @@ -1759,6 +1845,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Instruction *I = moveAddAfterMinMax(II, Builder)) return I; + // minmax (X & NegPow2C, Y & NegPow2C) --> minmax(X, Y) & NegPow2C + const APInt *RHSC; + if (match(I0, m_OneUse(m_And(m_Value(X), m_NegatedPower2(RHSC)))) && + match(I1, m_OneUse(m_And(m_Value(Y), m_SpecificInt(*RHSC))))) + return BinaryOperator::CreateAnd(Builder.CreateBinaryIntrinsic(IID, X, Y), + ConstantInt::get(II->getType(), *RHSC)); + // smax(X, -X) --> abs(X) // smin(X, -X) --> -abs(X) // umax(X, -X) --> -abs(X) @@ -1780,7 +1873,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // We don't have a "nabs" intrinsic, so negate if needed based on the // max/min operation. 
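// Illustrative aside, a minimal standalone sketch and not part of this patch:
// the smax/smin/umax/umin(X, -X) --> +/-abs(X) rewrites noted just above rest
// on these two's-complement identities (helper name is ours; X == INT32_MIN is
// excluded, matching the IntMinIsPoison handling in this code).
#include <algorithm>
#include <cstdint>
static bool checkMinMaxAbsIdentities(int32_t X) {
  if (X == INT32_MIN)
    return true; // abs(INT32_MIN) overflows, skip it
  int32_t Abs = X < 0 ? -X : X;
  uint32_t UX = static_cast<uint32_t>(X);
  uint32_t UNeg = static_cast<uint32_t>(-X);
  return std::max(X, -X) == Abs &&                            // smax --> abs
         std::min(X, -X) == -Abs &&                           // smin --> -abs
         std::max(UX, UNeg) == static_cast<uint32_t>(-Abs) && // umax --> -abs
         std::min(UX, UNeg) == static_cast<uint32_t>(Abs);    // umin --> abs
}
// End of aside.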
if (IID == Intrinsic::smin || IID == Intrinsic::umax) - Abs = Builder.CreateNeg(Abs, "nabs", /* NUW */ false, IntMinIsPoison); + Abs = Builder.CreateNeg(Abs, "nabs", IntMinIsPoison); return replaceInstUsesWith(CI, Abs); } @@ -1790,7 +1883,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Instruction *SAdd = matchSAddSubSat(*II)) return SAdd; - if (Value *NewMinMax = reassociateMinMaxWithConstants(II, Builder)) + if (Value *NewMinMax = reassociateMinMaxWithConstants(II, Builder, SQ)) return replaceInstUsesWith(*II, NewMinMax); if (Instruction *R = reassociateMinMaxWithConstantInOperand(II, Builder)) @@ -1800,8 +1893,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { return NewMinMax; // Try to fold minmax with constant RHS based on range information - const APInt *RHSC; - if (match(I1, m_APIntAllowUndef(RHSC))) { + if (match(I1, m_APIntAllowPoison(RHSC))) { ICmpInst::Predicate Pred = ICmpInst::getNonStrictPredicate(MinMaxIntrinsic::getPredicate(IID)); bool IsSigned = MinMaxIntrinsic::isSigned(IID); @@ -1845,12 +1937,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // bswap (lshr X, Y) --> shl (bswap X), Y Value *X, *Y; if (match(IIOperand, m_OneUse(m_LogicalShift(m_Value(X), m_Value(Y))))) { - // The transform allows undef vector elements, so try a constant match - // first. If knownbits can handle that case, that clause could be removed. unsigned BitWidth = IIOperand->getType()->getScalarSizeInBits(); - const APInt *C; - if ((match(Y, m_APIntAllowUndef(C)) && (*C & 7) == 0) || - MaskedValueIsZero(Y, APInt::getLowBitsSet(BitWidth, 3))) { + if (MaskedValueIsZero(Y, APInt::getLowBitsSet(BitWidth, 3))) { Value *NewSwap = Builder.CreateUnaryIntrinsic(Intrinsic::bswap, X); BinaryOperator::BinaryOps InverseShift = cast<BinaryOperator>(IIOperand)->getOpcode() == Instruction::Shl @@ -1964,15 +2052,19 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (ModuloC != ShAmtC) return replaceOperand(*II, 2, ModuloC); - assert(ConstantExpr::getICmp(ICmpInst::ICMP_UGT, WidthC, ShAmtC) == - ConstantInt::getTrue(CmpInst::makeCmpResultType(Ty)) && + assert(match(ConstantFoldCompareInstOperands(ICmpInst::ICMP_UGT, WidthC, + ShAmtC, DL), + m_One()) && "Shift amount expected to be modulo bitwidth"); // Canonicalize funnel shift right by constant to funnel shift left. This // is not entirely arbitrary. For historical reasons, the backend may // recognize rotate left patterns but miss rotate right patterns. if (IID == Intrinsic::fshr) { - // fshr X, Y, C --> fshl X, Y, (BitWidth - C) + // fshr X, Y, C --> fshl X, Y, (BitWidth - C) if C is not zero. + if (!isKnownNonZero(ShAmtC, SQ.getWithInstruction(II))) + return nullptr; + Constant *LeftShiftC = ConstantExpr::getSub(WidthC, ShAmtC); Module *Mod = II->getModule(); Function *Fshl = Intrinsic::getDeclaration(Mod, Intrinsic::fshl, Ty); @@ -2046,7 +2138,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // See if we can deduce non-null. if (!CI.hasRetAttr(Attribute::NonNull) && (Known.isNonZero() || - isKnownNonZero(II, DL, /*Depth*/ 0, &AC, II, &DT))) { + isKnownNonZero(II, getSimplifyQuery().getWithInstruction(II)))) { CI.addRetAttr(Attribute::NonNull); Changed = true; } @@ -2078,8 +2170,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); bool IsSigned = IID == Intrinsic::sadd_with_overflow; - bool HasNWAdd = IsSigned ? 
match(Arg0, m_NSWAdd(m_Value(X), m_APInt(C0))) - : match(Arg0, m_NUWAdd(m_Value(X), m_APInt(C0))); + bool HasNWAdd = IsSigned + ? match(Arg0, m_NSWAddLike(m_Value(X), m_APInt(C0))) + : match(Arg0, m_NUWAddLike(m_Value(X), m_APInt(C0))); if (HasNWAdd && match(Arg1, m_APInt(C1))) { bool Overflow; APInt NewC = @@ -2154,8 +2247,22 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } + // usub_sat((sub nuw C, A), C1) -> usub_sat(usub_sat(C, C1), A) + // which after that: + // usub_sat((sub nuw C, A), C1) -> usub_sat(C - C1, A) if C1 u< C + // usub_sat((sub nuw C, A), C1) -> 0 otherwise + Constant *C, *C1; + Value *A; + if (IID == Intrinsic::usub_sat && + match(Arg0, m_NUWSub(m_ImmConstant(C), m_Value(A))) && + match(Arg1, m_ImmConstant(C1))) { + auto *NewC = Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, C, C1); + auto *NewSub = + Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, NewC, A); + return replaceInstUsesWith(*SI, NewSub); + } + // ssub.sat(X, C) -> sadd.sat(X, -C) if C != MIN - Constant *C; if (IID == Intrinsic::ssub_sat && match(Arg1, m_Constant(C)) && C->isNotMinSignedValue()) { Value *NegVal = ConstantExpr::getNeg(C); @@ -2260,13 +2367,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { default: llvm_unreachable("unexpected intrinsic ID"); } - Instruction *NewCall = Builder.CreateBinaryIntrinsic( + Value *V = Builder.CreateBinaryIntrinsic( IID, X, ConstantFP::get(Arg0->getType(), Res), II); // TODO: Conservatively intersecting FMF. If Res == C2, the transform // was a simplification (so Arg0 and its original flags could // propagate?) - NewCall->andIRFlags(M); - return replaceInstUsesWith(*II, NewCall); + if (auto *CI = dyn_cast<CallInst>(V)) + CI->andIRFlags(M); + return replaceInstUsesWith(*II, V); } } @@ -2281,11 +2389,18 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // max X, -X --> fabs X // min X, -X --> -(fabs X) - // TODO: Remove one-use limitation? That is obviously better for max. - // It would be an extra instruction for min (fnabs), but that is - // still likely better for analysis and codegen. - if ((match(Arg0, m_OneUse(m_FNeg(m_Value(X)))) && Arg1 == X) || - (match(Arg1, m_OneUse(m_FNeg(m_Value(X)))) && Arg0 == X)) { + // TODO: Remove one-use limitation? That is obviously better for max, + // hence why we don't check for one-use for that. However, + // it would be an extra instruction for min (fnabs), but + // that is still likely better for analysis and codegen. + auto IsMinMaxOrXNegX = [IID, &X](Value *Op0, Value *Op1) { + if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_Specific(X))) + return Op0->hasOneUse() || + (IID != Intrinsic::minimum && IID != Intrinsic::minnum); + return false; + }; + + if (IsMinMaxOrXNegX(Arg0, Arg1) || IsMinMaxOrXNegX(Arg1, Arg0)) { Value *R = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, II); if (IID == Intrinsic::minimum || IID == Intrinsic::minnum) R = Builder.CreateFNegFMF(R, II); @@ -2352,17 +2467,6 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { break; } case Intrinsic::fmuladd: { - // Canonicalize fast fmuladd to the separate fmul + fadd. - if (II->isFast()) { - BuilderTy::FastMathFlagGuard Guard(Builder); - Builder.setFastMathFlags(II->getFastMathFlags()); - Value *Mul = Builder.CreateFMul(II->getArgOperand(0), - II->getArgOperand(1)); - Value *Add = Builder.CreateFAdd(Mul, II->getArgOperand(2)); - Add->takeName(II); - return replaceInstUsesWith(*II, Add); - } - // Try to simplify the underlying FMul. 
if (Value *V = simplifyFMulInst(II->getArgOperand(0), II->getArgOperand(1), II->getFastMathFlags(), @@ -2415,20 +2519,20 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } case Intrinsic::copysign: { Value *Mag = II->getArgOperand(0), *Sign = II->getArgOperand(1); - if (SignBitMustBeZero(Sign, DL, &TLI)) { + if (std::optional<bool> KnownSignBit = computeKnownFPSignBit( + Sign, /*Depth=*/0, getSimplifyQuery().getWithInstruction(II))) { + if (*KnownSignBit) { + // If we know that the sign argument is negative, reduce to FNABS: + // copysign Mag, -Sign --> fneg (fabs Mag) + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Mag, II); + return replaceInstUsesWith(*II, Builder.CreateFNegFMF(Fabs, II)); + } + // If we know that the sign argument is positive, reduce to FABS: // copysign Mag, +Sign --> fabs Mag Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Mag, II); return replaceInstUsesWith(*II, Fabs); } - // TODO: There should be a ValueTracking sibling like SignBitMustBeOne. - const APFloat *C; - if (match(Sign, m_APFloat(C)) && C->isNegative()) { - // If we know that the sign argument is negative, reduce to FNABS: - // copysign Mag, -Sign --> fneg (fabs Mag) - Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Mag, II); - return replaceInstUsesWith(*II, Builder.CreateFNegFMF(Fabs, II)); - } // Propagate sign argument through nested calls: // copysign Mag, (copysign ?, X) --> copysign Mag, X @@ -2436,6 +2540,16 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (match(Sign, m_Intrinsic<Intrinsic::copysign>(m_Value(), m_Value(X)))) return replaceOperand(*II, 1, X); + // Clear sign-bit of constant magnitude: + // copysign -MagC, X --> copysign MagC, X + // TODO: Support constant folding for fabs + const APFloat *MagC; + if (match(Mag, m_APFloat(MagC)) && MagC->isNegative()) { + APFloat PosMagC = *MagC; + PosMagC.clearSign(); + return replaceOperand(*II, 0, ConstantFP::get(Mag->getType(), PosMagC)); + } + // Peek through changes of magnitude's sign-bit. 
This call rewrites those: // copysign (fabs X), Sign --> copysign X, Sign // copysign (fneg X), Sign --> copysign X, Sign @@ -2446,13 +2560,25 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } case Intrinsic::fabs: { Value *Cond, *TVal, *FVal; - if (match(II->getArgOperand(0), - m_Select(m_Value(Cond), m_Value(TVal), m_Value(FVal)))) { + Value *Arg = II->getArgOperand(0); + Value *X; + // fabs (-X) --> fabs (X) + if (match(Arg, m_FNeg(m_Value(X)))) { + CallInst *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, II); + return replaceInstUsesWith(CI, Fabs); + } + + if (match(Arg, m_Select(m_Value(Cond), m_Value(TVal), m_Value(FVal)))) { // fabs (select Cond, TrueC, FalseC) --> select Cond, AbsT, AbsF - if (isa<Constant>(TVal) && isa<Constant>(FVal)) { + if (isa<Constant>(TVal) || isa<Constant>(FVal)) { CallInst *AbsT = Builder.CreateCall(II->getCalledFunction(), {TVal}); CallInst *AbsF = Builder.CreateCall(II->getCalledFunction(), {FVal}); - return SelectInst::Create(Cond, AbsT, AbsF); + SelectInst *SI = SelectInst::Create(Cond, AbsT, AbsF); + FastMathFlags FMF1 = II->getFastMathFlags(); + FastMathFlags FMF2 = cast<SelectInst>(Arg)->getFastMathFlags(); + FMF2.setNoSignedZeros(false); + SI->setFastMathFlags(FMF1 | FMF2); + return SI; } // fabs (select Cond, -FVal, FVal) --> fabs FVal if (match(TVal, m_FNeg(m_Specific(FVal)))) @@ -2491,23 +2617,24 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } case Intrinsic::cos: case Intrinsic::amdgcn_cos: { - Value *X; + Value *X, *Sign; Value *Src = II->getArgOperand(0); - if (match(Src, m_FNeg(m_Value(X))) || match(Src, m_FAbs(m_Value(X)))) { - // cos(-x) -> cos(x) - // cos(fabs(x)) -> cos(x) + if (match(Src, m_FNeg(m_Value(X))) || match(Src, m_FAbs(m_Value(X))) || + match(Src, m_CopySign(m_Value(X), m_Value(Sign)))) { + // cos(-x) --> cos(x) + // cos(fabs(x)) --> cos(x) + // cos(copysign(x, y)) --> cos(x) return replaceOperand(*II, 0, X); } break; } - case Intrinsic::sin: { + case Intrinsic::sin: + case Intrinsic::amdgcn_sin: { Value *X; if (match(II->getArgOperand(0), m_OneUse(m_FNeg(m_Value(X))))) { // sin(-x) --> -sin(x) - Value *NewSin = Builder.CreateUnaryIntrinsic(Intrinsic::sin, X, II); - Instruction *FNeg = UnaryOperator::CreateFNeg(NewSin); - FNeg->copyFastMathFlags(II); - return FNeg; + Value *NewSin = Builder.CreateUnaryIntrinsic(IID, X, II); + return UnaryOperator::CreateFNegFMF(NewSin, II); } break; } @@ -2535,7 +2662,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { FastMathFlags InnerFlags = cast<FPMathOperator>(Src)->getFastMathFlags(); if ((FMF.allowReassoc() && InnerFlags.allowReassoc()) || - signBitMustBeTheSame(Exp, InnerExp, II, DL, &AC, &DT)) { + signBitMustBeTheSame(Exp, InnerExp, SQ.getWithInstruction(II))) { // TODO: Add nsw/nuw probably safe if integer type exceeds exponent // width. 
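// Illustrative aside, a minimal standalone sketch and not part of this patch:
// the ldexp(ldexp(X, A), B) --> ldexp(X, A + B) fold guarded above is the
// usual exponent-addition identity; it is only safe when the inner result
// does not overflow or underflow, which is why the reassoc / matching
// sign-bit check is required. One concrete in-range instance (helper name
// is ours):
#include <cmath>
static bool checkLdexpExponentAdd() {
  double X = 1.5;
  return std::ldexp(std::ldexp(X, 10), 5) == std::ldexp(X, 15); // both 49152.0
}
// End of aside.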
Value *NewExp = Builder.CreateAdd(InnerExp, Exp); @@ -2545,6 +2672,49 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } + // ldexp(x, zext(i1 y)) -> fmul x, (select y, 2.0, 1.0) + // ldexp(x, sext(i1 y)) -> fmul x, (select y, 0.5, 1.0) + Value *ExtSrc; + if (match(Exp, m_ZExt(m_Value(ExtSrc))) && + ExtSrc->getType()->getScalarSizeInBits() == 1) { + Value *Select = + Builder.CreateSelect(ExtSrc, ConstantFP::get(II->getType(), 2.0), + ConstantFP::get(II->getType(), 1.0)); + return BinaryOperator::CreateFMulFMF(Src, Select, II); + } + if (match(Exp, m_SExt(m_Value(ExtSrc))) && + ExtSrc->getType()->getScalarSizeInBits() == 1) { + Value *Select = + Builder.CreateSelect(ExtSrc, ConstantFP::get(II->getType(), 0.5), + ConstantFP::get(II->getType(), 1.0)); + return BinaryOperator::CreateFMulFMF(Src, Select, II); + } + + // ldexp(x, c ? exp : 0) -> c ? ldexp(x, exp) : x + // ldexp(x, c ? 0 : exp) -> c ? x : ldexp(x, exp) + /// + // TODO: If we cared, should insert a canonicalize for x + Value *SelectCond, *SelectLHS, *SelectRHS; + if (match(II->getArgOperand(1), + m_OneUse(m_Select(m_Value(SelectCond), m_Value(SelectLHS), + m_Value(SelectRHS))))) { + Value *NewLdexp = nullptr; + Value *Select = nullptr; + if (match(SelectRHS, m_ZeroInt())) { + NewLdexp = Builder.CreateLdexp(Src, SelectLHS); + Select = Builder.CreateSelect(SelectCond, NewLdexp, Src); + } else if (match(SelectLHS, m_ZeroInt())) { + NewLdexp = Builder.CreateLdexp(Src, SelectRHS); + Select = Builder.CreateSelect(SelectCond, Src, NewLdexp); + } + + if (NewLdexp) { + Select->takeName(II); + cast<Instruction>(NewLdexp)->copyFastMathFlags(II); + return replaceInstUsesWith(*II, Select); + } + } + break; } case Intrinsic::ptrauth_auth: @@ -2552,13 +2722,14 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // (sign|resign) + (auth|resign) can be folded by omitting the middle // sign+auth component if the key and discriminator match. bool NeedSign = II->getIntrinsicID() == Intrinsic::ptrauth_resign; + Value *Ptr = II->getArgOperand(0); Value *Key = II->getArgOperand(1); Value *Disc = II->getArgOperand(2); // AuthKey will be the key we need to end up authenticating against in // whatever we replace this sequence with. Value *AuthKey = nullptr, *AuthDisc = nullptr, *BasePtr; - if (auto CI = dyn_cast<CallBase>(II->getArgOperand(0))) { + if (const auto *CI = dyn_cast<CallBase>(Ptr)) { BasePtr = CI->getArgOperand(0); if (CI->getIntrinsicID() == Intrinsic::ptrauth_sign) { if (CI->getArgOperand(1) != Key || CI->getArgOperand(2) != Disc) @@ -2570,6 +2741,27 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { AuthDisc = CI->getArgOperand(2); } else break; + } else if (const auto *PtrToInt = dyn_cast<PtrToIntOperator>(Ptr)) { + // ptrauth constants are equivalent to a call to @llvm.ptrauth.sign for + // our purposes, so check for that too. 
+ const auto *CPA = dyn_cast<ConstantPtrAuth>(PtrToInt->getOperand(0)); + if (!CPA || !CPA->isKnownCompatibleWith(Key, Disc, DL)) + break; + + // resign(ptrauth(p,ks,ds),ks,ds,kr,dr) -> ptrauth(p,kr,dr) + if (NeedSign && isa<ConstantInt>(II->getArgOperand(4))) { + auto *SignKey = cast<ConstantInt>(II->getArgOperand(3)); + auto *SignDisc = cast<ConstantInt>(II->getArgOperand(4)); + auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy()); + auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey, + SignDisc, SignAddrDisc); + replaceInstUsesWith( + *II, ConstantExpr::getPointerCast(NewCPA, II->getType())); + return eraseInstFromFunction(*II); + } + + // auth(ptrauth(p,k,d),k,d) -> p + BasePtr = Builder.CreatePtrToInt(CPA->getPointer(), II->getType()); } else break; @@ -2586,8 +2778,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } else { // sign(0) + auth(0) = nop replaceInstUsesWith(*II, BasePtr); - eraseInstFromFunction(*II); - return nullptr; + return eraseInstFromFunction(*II); } SmallVector<Value *, 4> CallArgs; @@ -3113,7 +3304,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } break; } - case Intrinsic::experimental_vector_reverse: { + case Intrinsic::vector_reverse: { Value *BO0, *BO1, *X, *Y; Value *Vec = II->getArgOperand(0); if (match(Vec, m_OneUse(m_BinOp(m_Value(BO0), m_Value(BO1))))) { @@ -3121,28 +3312,30 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (match(BO0, m_VecReverse(m_Value(X)))) { // rev(binop rev(X), rev(Y)) --> binop X, Y if (match(BO1, m_VecReverse(m_Value(Y)))) - return replaceInstUsesWith(CI, - BinaryOperator::CreateWithCopiedFlags( - OldBinOp->getOpcode(), X, Y, OldBinOp, - OldBinOp->getName(), II)); + return replaceInstUsesWith(CI, BinaryOperator::CreateWithCopiedFlags( + OldBinOp->getOpcode(), X, Y, + OldBinOp, OldBinOp->getName(), + II->getIterator())); // rev(binop rev(X), BO1Splat) --> binop X, BO1Splat if (isSplatValue(BO1)) - return replaceInstUsesWith(CI, - BinaryOperator::CreateWithCopiedFlags( - OldBinOp->getOpcode(), X, BO1, - OldBinOp, OldBinOp->getName(), II)); + return replaceInstUsesWith(CI, BinaryOperator::CreateWithCopiedFlags( + OldBinOp->getOpcode(), X, BO1, + OldBinOp, OldBinOp->getName(), + II->getIterator())); } // rev(binop BO0Splat, rev(Y)) --> binop BO0Splat, Y if (match(BO1, m_VecReverse(m_Value(Y))) && isSplatValue(BO0)) - return replaceInstUsesWith(CI, BinaryOperator::CreateWithCopiedFlags( - OldBinOp->getOpcode(), BO0, Y, - OldBinOp, OldBinOp->getName(), II)); + return replaceInstUsesWith(CI, + BinaryOperator::CreateWithCopiedFlags( + OldBinOp->getOpcode(), BO0, Y, OldBinOp, + OldBinOp->getName(), II->getIterator())); } // rev(unop rev(X)) --> unop X if (match(Vec, m_OneUse(m_UnOp(m_VecReverse(m_Value(X)))))) { auto *OldUnOp = cast<UnaryOperator>(Vec); auto *NewUnOp = UnaryOperator::CreateWithCopiedFlags( - OldUnOp->getOpcode(), X, OldUnOp, OldUnOp->getName(), II); + OldUnOp->getOpcode(), X, OldUnOp, OldUnOp->getName(), + II->getIterator()); return replaceInstUsesWith(CI, NewUnOp); } break; @@ -3158,6 +3351,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // %res = cmp eq iReduxWidth %val, 11111 Value *Arg = II->getArgOperand(0); Value *Vect; + + if (Value *NewOp = + simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) { + replaceUse(II->getOperandUse(0), NewOp); + return II; + } + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) if (FTy->getElementType() == 
Builder.getInt1Ty()) { @@ -3189,6 +3389,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // Trunc(ctpop(bitcast <n x i1> to in)). Value *Arg = II->getArgOperand(0); Value *Vect; + + if (Value *NewOp = + simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) { + replaceUse(II->getOperandUse(0), NewOp); + return II; + } + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) if (FTy->getElementType() == Builder.getInt1Ty()) { @@ -3217,9 +3424,16 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // ?ext(vector_reduce_add(<n x i1>)) Value *Arg = II->getArgOperand(0); Value *Vect; + + if (Value *NewOp = + simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) { + replaceUse(II->getOperandUse(0), NewOp); + return II; + } + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { - if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) - if (FTy->getElementType() == Builder.getInt1Ty()) { + if (auto *VTy = dyn_cast<VectorType>(Vect->getType())) + if (VTy->getElementType() == Builder.getInt1Ty()) { Value *Res = Builder.CreateAddReduce(Vect); if (Arg != Vect) Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res, @@ -3240,9 +3454,16 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // zext(vector_reduce_and(<n x i1>)) Value *Arg = II->getArgOperand(0); Value *Vect; + + if (Value *NewOp = + simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) { + replaceUse(II->getOperandUse(0), NewOp); + return II; + } + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { - if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) - if (FTy->getElementType() == Builder.getInt1Ty()) { + if (auto *VTy = dyn_cast<VectorType>(Vect->getType())) + if (VTy->getElementType() == Builder.getInt1Ty()) { Value *Res = Builder.CreateAndReduce(Vect); if (Res->getType() != II->getType()) Res = Builder.CreateZExt(Res, II->getType()); @@ -3264,9 +3485,16 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // ?ext(vector_reduce_{and,or}(<n x i1>)) Value *Arg = II->getArgOperand(0); Value *Vect; + + if (Value *NewOp = + simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) { + replaceUse(II->getOperandUse(0), NewOp); + return II; + } + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { - if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) - if (FTy->getElementType() == Builder.getInt1Ty()) { + if (auto *VTy = dyn_cast<VectorType>(Vect->getType())) + if (VTy->getElementType() == Builder.getInt1Ty()) { Value *Res = IID == Intrinsic::vector_reduce_umin ? 
Builder.CreateAndReduce(Vect) : Builder.CreateOrReduce(Vect); @@ -3299,9 +3527,16 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // zext(vector_reduce_{and,or}(<n x i1>)) Value *Arg = II->getArgOperand(0); Value *Vect; + + if (Value *NewOp = + simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) { + replaceUse(II->getOperandUse(0), NewOp); + return II; + } + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { - if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) - if (FTy->getElementType() == Builder.getInt1Ty()) { + if (auto *VTy = dyn_cast<VectorType>(Vect->getType())) + if (VTy->getElementType() == Builder.getInt1Ty()) { Instruction::CastOps ExtOpc = Instruction::CastOps::CastOpsEnd; if (Arg != Vect) ExtOpc = cast<CastInst>(Arg)->getOpcode(); @@ -3321,31 +3556,16 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { case Intrinsic::vector_reduce_fmin: case Intrinsic::vector_reduce_fadd: case Intrinsic::vector_reduce_fmul: { - bool CanBeReassociated = (IID != Intrinsic::vector_reduce_fadd && - IID != Intrinsic::vector_reduce_fmul) || - II->hasAllowReassoc(); + bool CanReorderLanes = (IID != Intrinsic::vector_reduce_fadd && + IID != Intrinsic::vector_reduce_fmul) || + II->hasAllowReassoc(); const unsigned ArgIdx = (IID == Intrinsic::vector_reduce_fadd || IID == Intrinsic::vector_reduce_fmul) ? 1 : 0; Value *Arg = II->getArgOperand(ArgIdx); - Value *V; - ArrayRef<int> Mask; - if (!isa<FixedVectorType>(Arg->getType()) || !CanBeReassociated || - !match(Arg, m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))) || - !cast<ShuffleVectorInst>(Arg)->isSingleSource()) - break; - int Sz = Mask.size(); - SmallBitVector UsedIndices(Sz); - for (int Idx : Mask) { - if (Idx == PoisonMaskElem || UsedIndices.test(Idx)) - break; - UsedIndices.set(Idx); - } - // Can remove shuffle iff just shuffled elements, no repeats, undefs, or - // other changes. - if (UsedIndices.all()) { - replaceUse(II->getOperandUse(ArgIdx), V); + if (Value *NewOp = simplifyReductionOperand(Arg, CanReorderLanes)) { + replaceUse(II->getOperandUse(ArgIdx), NewOp); return nullptr; } break; @@ -3355,6 +3575,15 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { return I; break; } + case Intrinsic::threadlocal_address: { + Align MinAlign = getKnownAlignment(II->getArgOperand(0), DL, II, &AC, &DT); + MaybeAlign Align = II->getRetAlign(); + if (MinAlign > Align.valueOrOne()) { + II->addRetAttr(Attribute::getWithAlignment(II->getContext(), MinAlign)); + return II; + } + break; + } default: { // Handle target specific intrinsics std::optional<Instruction *> V = targetInstCombineIntrinsic(*II); @@ -3596,7 +3825,7 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { for (Value *V : Call.args()) { if (V->getType()->isPointerTy() && !Call.paramHasAttr(ArgNo, Attribute::NonNull) && - isKnownNonZero(V, DL, 0, &AC, &Call, &DT)) + isKnownNonZero(V, getSimplifyQuery().getWithInstruction(&Call))) ArgNos.push_back(ArgNo); ArgNo++; } @@ -3776,7 +4005,8 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { // isKnownNonNull -> nonnull attribute if (!GCR.hasRetAttr(Attribute::NonNull) && - isKnownNonZero(DerivedPtr, DL, 0, &AC, &Call, &DT)) { + isKnownNonZero(DerivedPtr, + getSimplifyQuery().getWithInstruction(&Call))) { GCR.addRetAttr(Attribute::NonNull); // We discovered new fact, re-check users. 
Worklist.pushUsersToWorkList(GCR); diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index c5d3f60176a8..8f8304702093 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -182,9 +182,15 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) { if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType() || (CI.getOpcode() == Instruction::Trunc && shouldChangeType(CI.getSrcTy(), CI.getType()))) { - if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) { - replaceAllDbgUsesWith(*Sel, *NV, CI, DT); - return NV; + + // If it's a bitcast involving vectors, make sure it has the same number + // of elements on both sides. + if (CI.getOpcode() != Instruction::BitCast || + match(&CI, m_ElementWiseBitCast(m_Value()))) { + if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) { + replaceAllDbgUsesWith(*Sel, *NV, CI, DT); + return NV; + } } } } @@ -285,10 +291,12 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, uint32_t BitWidth = Ty->getScalarSizeInBits(); assert(BitWidth < OrigBitWidth && "Unexpected bitwidths!"); APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); - if (IC.MaskedValueIsZero(I->getOperand(0), Mask, 0, CxtI) && - IC.MaskedValueIsZero(I->getOperand(1), Mask, 0, CxtI)) { - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && - canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); + // Do not preserve the original context instruction. Simplifying div/rem + // based on later context may introduce a trap. + if (IC.MaskedValueIsZero(I->getOperand(0), Mask, 0, I) && + IC.MaskedValueIsZero(I->getOperand(1), Mask, 0, I)) { + return canEvaluateTruncated(I->getOperand(0), Ty, IC, I) && + canEvaluateTruncated(I->getOperand(1), Ty, IC, I); } break; } @@ -728,24 +736,23 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { if (DestWidth == 1) { Value *Zero = Constant::getNullValue(SrcTy); - if (DestTy->isIntegerTy()) { - // Canonicalize trunc x to i1 -> icmp ne (and x, 1), 0 (scalar only). - // TODO: We canonicalize to more instructions here because we are probably - // lacking equivalent analysis for trunc relative to icmp. There may also - // be codegen concerns. If those trunc limitations were removed, we could - // remove this transform. - Value *And = Builder.CreateAnd(Src, ConstantInt::get(SrcTy, 1)); - return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); - } - // For vectors, we do not canonicalize all truncs to icmp, so optimize - // patterns that would be covered within visitICmpInst. 
Value *X; + const APInt *C1; + Constant *C2; + if (match(Src, m_OneUse(m_Shr(m_Shl(m_Power2(C1), m_Value(X)), + m_ImmConstant(C2))))) { + // trunc ((C1 << X) >> C2) to i1 --> X == (C2-cttz(C1)), where C1 is pow2 + Constant *Log2C1 = ConstantInt::get(SrcTy, C1->exactLogBase2()); + Constant *CmpC = ConstantExpr::getSub(C2, Log2C1); + return new ICmpInst(ICmpInst::ICMP_EQ, X, CmpC); + } + Constant *C; - if (match(Src, m_OneUse(m_LShr(m_Value(X), m_Constant(C))))) { + if (match(Src, m_OneUse(m_LShr(m_Value(X), m_ImmConstant(C))))) { // trunc (lshr X, C) to i1 --> icmp ne (and X, C'), 0 Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1)); - Constant *MaskC = ConstantExpr::getShl(One, C); + Value *MaskC = Builder.CreateShl(One, C); Value *And = Builder.CreateAnd(X, MaskC); return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); } @@ -753,10 +760,24 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { m_Deferred(X))))) { // trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0 Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1)); - Constant *MaskC = ConstantExpr::getShl(One, C); + Value *MaskC = Builder.CreateShl(One, C); Value *And = Builder.CreateAnd(X, Builder.CreateOr(MaskC, One)); return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); } + + { + const APInt *C; + if (match(Src, m_Shl(m_APInt(C), m_Value(X))) && (*C)[0] == 1) { + // trunc (C << X) to i1 --> X == 0, where C is odd + return new ICmpInst(ICmpInst::Predicate::ICMP_EQ, X, Zero); + } + } + + if (Trunc.hasNoUnsignedWrap() || Trunc.hasNoSignedWrap()) { + Value *X, *Y; + if (match(Src, m_Xor(m_Value(X), m_Value(Y)))) + return new ICmpInst(ICmpInst::ICMP_NE, X, Y); + } } Value *A, *B; @@ -884,7 +905,20 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { } } - return nullptr; + bool Changed = false; + if (!Trunc.hasNoSignedWrap() && + ComputeMaxSignificantBits(Src, /*Depth=*/0, &Trunc) <= DestWidth) { + Trunc.setHasNoSignedWrap(true); + Changed = true; + } + if (!Trunc.hasNoUnsignedWrap() && + MaskedValueIsZero(Src, APInt::getBitsSetFrom(SrcWidth, DestWidth), + /*Depth=*/0, &Trunc)) { + Trunc.setHasNoUnsignedWrap(true); + Changed = true; + } + + return Changed ? &Trunc : nullptr; } Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, @@ -1115,6 +1149,10 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { Value *Src = Zext.getOperand(0); Type *SrcTy = Src->getType(), *DestTy = Zext.getType(); + // zext nneg bool x -> 0 + if (SrcTy->isIntOrIntVectorTy(1) && Zext.hasNonNeg()) + return replaceInstUsesWith(Zext, Constant::getNullValue(Zext.getType())); + // Try to extend the entire expression tree to the wide destination type. unsigned BitsToClear; if (shouldChangeType(SrcTy, DestTy) && @@ -1451,7 +1489,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { Value *Y; if (Src->hasOneUse() && match(X, m_LShr(m_Value(Y), - m_SpecificIntAllowUndef(XBitSize - SrcBitSize)))) { + m_SpecificIntAllowPoison(XBitSize - SrcBitSize)))) { Value *Ashr = Builder.CreateAShr(Y, XBitSize - SrcBitSize); return CastInst::CreateIntegerCast(Ashr, DestTy, /* isSigned */ true); } @@ -1537,11 +1575,14 @@ static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { return !losesInfo; } -static Type *shrinkFPConstant(ConstantFP *CFP) { +static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) { if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext())) return nullptr; // No constant folding of this. + // See if the value can be truncated to bfloat and then reextended. 
+ if (PreferBFloat && fitsInFPType(CFP, APFloat::BFloat())) + return Type::getBFloatTy(CFP->getContext()); // See if the value can be truncated to half and then reextended. - if (fitsInFPType(CFP, APFloat::IEEEhalf())) + if (!PreferBFloat && fitsInFPType(CFP, APFloat::IEEEhalf())) return Type::getHalfTy(CFP->getContext()); // See if the value can be truncated to float and then reextended. if (fitsInFPType(CFP, APFloat::IEEEsingle())) @@ -1556,7 +1597,7 @@ static Type *shrinkFPConstant(ConstantFP *CFP) { // Determine if this is a vector of ConstantFPs and if so, return the minimal // type we can safely truncate all elements to. -static Type *shrinkFPConstantVector(Value *V) { +static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) { auto *CV = dyn_cast<Constant>(V); auto *CVVTy = dyn_cast<FixedVectorType>(V->getType()); if (!CV || !CVVTy) @@ -1576,7 +1617,7 @@ static Type *shrinkFPConstantVector(Value *V) { if (!CFP) return nullptr; - Type *T = shrinkFPConstant(CFP); + Type *T = shrinkFPConstant(CFP, PreferBFloat); if (!T) return nullptr; @@ -1591,7 +1632,7 @@ static Type *shrinkFPConstantVector(Value *V) { } /// Find the minimum FP type we can safely truncate to. -static Type *getMinimumFPType(Value *V) { +static Type *getMinimumFPType(Value *V, bool PreferBFloat) { if (auto *FPExt = dyn_cast<FPExtInst>(V)) return FPExt->getOperand(0)->getType(); @@ -1599,7 +1640,7 @@ static Type *getMinimumFPType(Value *V) { // that can accurately represent it. This allows us to turn // (float)((double)X+2.0) into x+2.0f. if (auto *CFP = dyn_cast<ConstantFP>(V)) - if (Type *T = shrinkFPConstant(CFP)) + if (Type *T = shrinkFPConstant(CFP, PreferBFloat)) return T; // We can only correctly find a minimum type for a scalable vector when it is @@ -1611,7 +1652,7 @@ static Type *getMinimumFPType(Value *V) { // Try to shrink a vector of FP constants. This returns nullptr on scalable // vectors - if (Type *T = shrinkFPConstantVector(V)) + if (Type *T = shrinkFPConstantVector(V, PreferBFloat)) return T; return V->getType(); @@ -1680,8 +1721,10 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) { Type *Ty = FPT.getType(); auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0)); if (BO && BO->hasOneUse()) { - Type *LHSMinType = getMinimumFPType(BO->getOperand(0)); - Type *RHSMinType = getMinimumFPType(BO->getOperand(1)); + Type *LHSMinType = + getMinimumFPType(BO->getOperand(0), /*PreferBFloat=*/Ty->isBFloatTy()); + Type *RHSMinType = + getMinimumFPType(BO->getOperand(1), /*PreferBFloat=*/Ty->isBFloatTy()); unsigned OpWidth = BO->getType()->getFPMantissaWidth(); unsigned LHSWidth = LHSMinType->getFPMantissaWidth(); unsigned RHSWidth = RHSMinType->getFPMantissaWidth(); @@ -1908,10 +1951,26 @@ Instruction *InstCombinerImpl::foldItoFPtoI(CastInst &FI) { return replaceInstUsesWith(FI, X); } +static Instruction *foldFPtoI(Instruction &FI, InstCombiner &IC) { + // fpto{u/s}i non-norm --> 0 + FPClassTest Mask = + FI.getOpcode() == Instruction::FPToUI ? 
fcPosNormal : fcNormal; + KnownFPClass FPClass = + computeKnownFPClass(FI.getOperand(0), Mask, /*Depth=*/0, + IC.getSimplifyQuery().getWithInstruction(&FI)); + if (FPClass.isKnownNever(Mask)) + return IC.replaceInstUsesWith(FI, ConstantInt::getNullValue(FI.getType())); + + return nullptr; +} + Instruction *InstCombinerImpl::visitFPToUI(FPToUIInst &FI) { if (Instruction *I = foldItoFPtoI(FI)) return I; + if (Instruction *I = foldFPtoI(FI, *this)) + return I; + return commonCastTransforms(FI); } @@ -1919,15 +1978,32 @@ Instruction *InstCombinerImpl::visitFPToSI(FPToSIInst &FI) { if (Instruction *I = foldItoFPtoI(FI)) return I; + if (Instruction *I = foldFPtoI(FI, *this)) + return I; + return commonCastTransforms(FI); } Instruction *InstCombinerImpl::visitUIToFP(CastInst &CI) { - return commonCastTransforms(CI); + if (Instruction *R = commonCastTransforms(CI)) + return R; + if (!CI.hasNonNeg() && isKnownNonNegative(CI.getOperand(0), SQ)) { + CI.setNonNeg(); + return &CI; + } + return nullptr; } Instruction *InstCombinerImpl::visitSIToFP(CastInst &CI) { - return commonCastTransforms(CI); + if (Instruction *R = commonCastTransforms(CI)) + return R; + if (isKnownNonNegative(CI.getOperand(0), SQ)) { + auto *UI = + CastInst::Create(Instruction::UIToFP, CI.getOperand(0), CI.getType()); + UI->setNonNeg(true); + return UI; + } + return nullptr; } Instruction *InstCombinerImpl::visitIntToPtr(IntToPtrInst &CI) { @@ -1975,7 +2051,7 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) { Mask->getType() == Ty) return BinaryOperator::CreateAnd(Builder.CreatePtrToInt(Ptr, Ty), Mask); - if (auto *GEP = dyn_cast<GetElementPtrInst>(SrcOp)) { + if (auto *GEP = dyn_cast<GEPOperator>(SrcOp)) { // Fold ptrtoint(gep null, x) to multiply + constant if the GEP has one use. // While this can increase the number of instructions it doesn't actually // increase the overall complexity since the arithmetic is just part of @@ -1986,6 +2062,20 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) { Builder.CreateIntCast(EmitGEPOffset(GEP), Ty, /*isSigned=*/false)); } + + // (ptrtoint (gep (inttoptr Base), ...)) -> Base + Offset + Value *Base; + if (GEP->hasOneUse() && + match(GEP->getPointerOperand(), m_OneUse(m_IntToPtr(m_Value(Base)))) && + Base->getType() == Ty) { + Value *Offset = EmitGEPOffset(GEP); + auto *NewOp = BinaryOperator::CreateAdd(Base, Offset); + if (GEP->hasNoUnsignedWrap() || + (GEP->hasNoUnsignedSignedWrap() && + isKnownNonNegative(Offset, SQ.getWithInstruction(&CI)))) + NewOp->setHasNoUnsignedWrap(true); + return NewOp; + } } Value *Vec, *Scalar, *Index; diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 9973a80a7db9..abadf54a9676 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -23,6 +23,7 @@ #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" @@ -213,6 +214,9 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( // Find out if the comparison would be true or false for the i'th element. 
Constant *C = ConstantFoldCompareInstOperands(ICI.getPredicate(), Elt, CompareRHS, DL, &TLI); + if (!C) + return nullptr; + // If the result is undef for this element, ignore it. if (isa<UndefValue>(C)) { // Extend range state machines to cover this element in case there is an @@ -556,8 +560,9 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base, // Create empty phi nodes. This avoids cyclic dependencies when creating // the remaining instructions. if (auto *PHI = dyn_cast<PHINode>(Val)) - NewInsts[PHI] = PHINode::Create(IndexType, PHI->getNumIncomingValues(), - PHI->getName() + ".idx", PHI); + NewInsts[PHI] = + PHINode::Create(IndexType, PHI->getNumIncomingValues(), + PHI->getName() + ".idx", PHI->getIterator()); } IRBuilder<> Builder(Base->getContext()); @@ -813,14 +818,10 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, } } - // Only lower this if the icmp is the only user of the GEP or if we expect - // the result to fold to a constant! - if ((GEPsInBounds || CmpInst::isEquality(Cond)) && - (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) && - (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse())) { + if (GEPsInBounds || CmpInst::isEquality(Cond)) { // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) ---> (OFFSET1 cmp OFFSET2) - Value *L = EmitGEPOffset(GEPLHS); - Value *R = EmitGEPOffset(GEPRHS); + Value *L = EmitGEPOffset(GEPLHS, /*RewriteGEP=*/true); + Value *R = EmitGEPOffset(GEPRHS, /*RewriteGEP=*/true); return new ICmpInst(ICmpInst::getSignedPredicate(Cond), L, R); } } @@ -1255,12 +1256,12 @@ Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) { // if X non-zero and NoOverflow(X * Y) // (icmp eq/ne Y) - if (!XKnown.One.isZero() || isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (!XKnown.One.isZero() || isKnownNonZero(X, Q)) return new ICmpInst(Pred, Y, Cmp.getOperand(1)); // if Y non-zero and NoOverflow(X * Y) // (icmp eq/ne X) - if (!YKnown.One.isZero() || isKnownNonZero(Y, DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (!YKnown.One.isZero() || isKnownNonZero(Y, Q)) return new ICmpInst(Pred, X, Cmp.getOperand(1)); } // Note, we are skipping cases: @@ -1334,7 +1335,6 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) { // We already checked simple implication in InstSimplify, only handle complex // cases here. Value *X = Cmp.getOperand(0), *Y = Cmp.getOperand(1); - ICmpInst::Predicate DomPred; const APInt *C; if (!match(Y, m_APInt(C))) return nullptr; @@ -1342,10 +1342,8 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) { CmpInst::Predicate Pred = Cmp.getPredicate(); ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *C); - auto handleDomCond = [&](Value *DomCond, bool CondIsTrue) -> Instruction * { - const APInt *DomC; - if (!match(DomCond, m_ICmp(DomPred, m_Specific(X), m_APInt(DomC)))) - return nullptr; + auto handleDomCond = [&](ICmpInst::Predicate DomPred, + const APInt *DomC) -> Instruction * { // We have 2 compares of a variable with constants. 
Calculate the constant // ranges of those compares to see if we can transform the 2nd compare: // DomBB: @@ -1353,8 +1351,6 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) { // br DomCond, CmpBB, FalseBB // CmpBB: // Cmp = icmp Pred X, C - if (!CondIsTrue) - DomPred = CmpInst::getInversePredicate(DomPred); ConstantRange DominatingCR = ConstantRange::makeExactICmpRegion(DomPred, *DomC); ConstantRange Intersection = DominatingCR.intersectWith(CR); @@ -1388,15 +1384,21 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) { }; for (BranchInst *BI : DC.conditionsFor(X)) { - auto *Cond = BI->getCondition(); + ICmpInst::Predicate DomPred; + const APInt *DomC; + if (!match(BI->getCondition(), + m_ICmp(DomPred, m_Specific(X), m_APInt(DomC)))) + continue; + BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0)); if (DT.dominates(Edge0, Cmp.getParent())) { - if (auto *V = handleDomCond(Cond, true)) + if (auto *V = handleDomCond(DomPred, DomC)) return V; } else { BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1)); if (DT.dominates(Edge1, Cmp.getParent())) - if (auto *V = handleDomCond(Cond, false)) + if (auto *V = + handleDomCond(CmpInst::getInversePredicate(DomPred), DomC)) return V; } } @@ -1410,6 +1412,19 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, const APInt &C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); Value *X = Trunc->getOperand(0); + Type *SrcTy = X->getType(); + unsigned DstBits = Trunc->getType()->getScalarSizeInBits(), + SrcBits = SrcTy->getScalarSizeInBits(); + + // Match (icmp pred (trunc nuw/nsw X), C) + // Which we can convert to (icmp pred X, (sext/zext C)) + if (shouldChangeType(Trunc->getType(), SrcTy)) { + if (Trunc->hasNoSignedWrap()) + return new ICmpInst(Pred, X, ConstantInt::get(SrcTy, C.sext(SrcBits))); + if (!Cmp.isSigned() && Trunc->hasNoUnsignedWrap()) + return new ICmpInst(Pred, X, ConstantInt::get(SrcTy, C.zext(SrcBits))); + } + if (C.isOne() && C.getBitWidth() > 1) { // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 Value *V = nullptr; @@ -1418,10 +1433,6 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, ConstantInt::get(V->getType(), 1)); } - Type *SrcTy = X->getType(); - unsigned DstBits = Trunc->getType()->getScalarSizeInBits(), - SrcBits = SrcTy->getScalarSizeInBits(); - // TODO: Handle any shifted constant by subtracting trailing zeros. // TODO: Handle non-equality predicates. Value *Y; @@ -1480,19 +1491,29 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, return nullptr; } -/// Fold icmp (trunc X), (trunc Y). -/// Fold icmp (trunc X), (zext Y). +/// Fold icmp (trunc nuw/nsw X), (trunc nuw/nsw Y). +/// Fold icmp (trunc nuw/nsw X), (zext/sext Y). Instruction * InstCombinerImpl::foldICmpTruncWithTruncOrExt(ICmpInst &Cmp, const SimplifyQuery &Q) { - if (Cmp.isSigned()) - return nullptr; - Value *X, *Y; ICmpInst::Predicate Pred; - bool YIsZext = false; + bool YIsSExt = false; // Try to match icmp (trunc X), (trunc Y) if (match(&Cmp, m_ICmp(Pred, m_Trunc(m_Value(X)), m_Trunc(m_Value(Y))))) { + unsigned NoWrapFlags = cast<TruncInst>(Cmp.getOperand(0))->getNoWrapKind() & + cast<TruncInst>(Cmp.getOperand(1))->getNoWrapKind(); + if (Cmp.isSigned()) { + // For signed comparisons, both truncs must be nsw. + if (!(NoWrapFlags & TruncInst::NoSignedWrap)) + return nullptr; + } else { + // For unsigned and equality comparisons, either both must be nuw or + // both must be nsw, we don't care which. 
+ if (!NoWrapFlags) + return nullptr; + } + if (X->getType() != Y->getType() && (!Cmp.getOperand(0)->hasOneUse() || !Cmp.getOperand(1)->hasOneUse())) return nullptr; @@ -1501,13 +1522,21 @@ InstCombinerImpl::foldICmpTruncWithTruncOrExt(ICmpInst &Cmp, std::swap(X, Y); Pred = Cmp.getSwappedPredicate(Pred); } + YIsSExt = !(NoWrapFlags & TruncInst::NoUnsignedWrap); } - // Try to match icmp (trunc X), (zext Y) - else if (match(&Cmp, m_c_ICmp(Pred, m_Trunc(m_Value(X)), - m_OneUse(m_ZExt(m_Value(Y)))))) - - YIsZext = true; - else + // Try to match icmp (trunc nuw X), (zext Y) + else if (!Cmp.isSigned() && + match(&Cmp, m_c_ICmp(Pred, m_NUWTrunc(m_Value(X)), + m_OneUse(m_ZExt(m_Value(Y)))))) { + // Can fold trunc nuw + zext for unsigned and equality predicates. + } + // Try to match icmp (trunc nsw X), (sext Y) + else if (match(&Cmp, m_c_ICmp(Pred, m_NSWTrunc(m_Value(X)), + m_OneUse(m_ZExtOrSExt(m_Value(Y)))))) { + // Can fold trunc nsw + zext/sext for all predicates. + YIsSExt = + isa<SExtInst>(Cmp.getOperand(0)) || isa<SExtInst>(Cmp.getOperand(1)); + } else return nullptr; Type *TruncTy = Cmp.getOperand(0)->getType(); @@ -1519,19 +1548,7 @@ InstCombinerImpl::foldICmpTruncWithTruncOrExt(ICmpInst &Cmp, !isDesirableIntType(X->getType()->getScalarSizeInBits())) return nullptr; - // Check if the trunc is unneeded. - KnownBits KnownX = llvm::computeKnownBits(X, /*Depth*/ 0, Q); - if (KnownX.countMaxActiveBits() > TruncBits) - return nullptr; - - if (!YIsZext) { - // If Y is also a trunc, make sure it is unneeded. - KnownBits KnownY = llvm::computeKnownBits(Y, /*Depth*/ 0, Q); - if (KnownY.countMaxActiveBits() > TruncBits) - return nullptr; - } - - Value *NewY = Builder.CreateZExtOrTrunc(Y, X->getType()); + Value *NewY = Builder.CreateIntCast(Y, X->getType(), YIsSExt); return new ICmpInst(Pred, X, NewY); } @@ -1827,6 +1844,28 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, } } + // (icmp eq (and (bitcast X to int), ExponentMask), ExponentMask) --> + // llvm.is.fpclass(X, fcInf|fcNan) + // (icmp ne (and (bitcast X to int), ExponentMask), ExponentMask) --> + // llvm.is.fpclass(X, ~(fcInf|fcNan)) + Value *V; + if (!Cmp.getParent()->getParent()->hasFnAttribute( + Attribute::NoImplicitFloat) && + Cmp.isEquality() && + match(X, m_OneUse(m_ElementWiseBitCast(m_Value(V))))) { + Type *FPType = V->getType()->getScalarType(); + if (FPType->isIEEELikeFPTy() && C1 == *C2) { + APInt ExponentMask = + APFloat::getInf(FPType->getFltSemantics()).bitcastToAPInt(); + if (C1 == ExponentMask) { + unsigned Mask = FPClassTest::fcNan | FPClassTest::fcInf; + if (isICMP_NE) + Mask = ~Mask & fcAllFlags; + return replaceInstUsesWith(Cmp, Builder.createIsFPClass(V, Mask)); + } + } + } + return nullptr; } @@ -1848,8 +1887,8 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, auto NewPred = TrueIfNeg ? 
CmpInst::ICMP_EQ : CmpInst::ICMP_NE; return new ICmpInst(NewPred, X, ConstantInt::getNullValue(X->getType())); } - // (X & X) < 0 --> X == MinSignedC - // (X & X) > -1 --> X != MinSignedC + // (X & -X) < 0 --> X == MinSignedC + // (X & -X) > -1 --> X != MinSignedC if (match(And, m_c_And(m_Neg(m_Value(X)), m_Deferred(X)))) { Constant *MinSignedC = ConstantInt::get( X->getType(), @@ -1922,6 +1961,17 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, return BinaryOperator::CreateAnd(TruncY, X); } + // (icmp eq/ne (and (shl -1, X), Y), 0) + // -> (icmp eq/ne (lshr Y, X), 0) + // We could technically handle any C == 0 or (C < 0 && isOdd(C)) but it seems + // highly unlikely the non-zero case will ever show up in code. + if (C.isZero() && + match(And, m_OneUse(m_c_And(m_OneUse(m_Shl(m_AllOnes(), m_Value(X))), + m_Value(Y))))) { + Value *LShr = Builder.CreateLShr(Y, X); + return new ICmpInst(Pred, LShr, Constant::getNullValue(LShr->getType())); + } + return nullptr; } @@ -1998,6 +2048,16 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, } Value *OrOp0 = Or->getOperand(0), *OrOp1 = Or->getOperand(1); + + // (icmp eq/ne (or disjoint x, C0), C1) + // -> (icmp eq/ne x, C0^C1) + if (Cmp.isEquality() && match(OrOp1, m_ImmConstant()) && + cast<PossiblyDisjointInst>(Or)->isDisjoint()) { + Value *NewC = + Builder.CreateXor(OrOp1, ConstantInt::get(OrOp1->getType(), C)); + return new ICmpInst(Pred, OrOp0, NewC); + } + const APInt *MaskC; if (match(OrOp1, m_APInt(MaskC)) && Cmp.isEquality()) { if (*MaskC == C && (C + 1).isPowerOf2()) { @@ -2357,14 +2417,35 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp, // free on the target. It has the additional benefit of comparing to a // smaller constant that may be more target-friendly. unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1); - if (Shl->hasOneUse() && Amt != 0 && C.countr_zero() >= Amt && - DL.isLegalInteger(TypeBits - Amt)) { - Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); - if (auto *ShVTy = dyn_cast<VectorType>(ShType)) - TruncTy = VectorType::get(TruncTy, ShVTy->getElementCount()); - Constant *NewC = - ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt)); - return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC); + if (Shl->hasOneUse() && Amt != 0 && + shouldChangeType(ShType->getScalarSizeInBits(), TypeBits - Amt)) { + ICmpInst::Predicate CmpPred = Pred; + APInt RHSC = C; + + if (RHSC.countr_zero() < Amt && ICmpInst::isStrictPredicate(CmpPred)) { + // Try the flipped strictness predicate. + // e.g.: + // icmp ult i64 (shl X, 32), 8589934593 -> + // icmp ule i64 (shl X, 32), 8589934592 -> + // icmp ule i32 (trunc X, i32), 2 -> + // icmp ult i32 (trunc X, i32), 3 + if (auto FlippedStrictness = + InstCombiner::getFlippedStrictnessPredicateAndConstant( + Pred, ConstantInt::get(ShType->getContext(), C))) { + CmpPred = FlippedStrictness->first; + RHSC = cast<ConstantInt>(FlippedStrictness->second)->getValue(); + } + } + + if (RHSC.countr_zero() >= Amt) { + Type *TruncTy = ShType->getWithNewBitWidth(TypeBits - Amt); + Constant *NewC = + ConstantInt::get(TruncTy, RHSC.ashr(*ShiftAmt).trunc(TypeBits - Amt)); + return new ICmpInst(CmpPred, + Builder.CreateTrunc(X, TruncTy, "", /*IsNUW=*/false, + Shl->hasNoSignedWrap()), + NewC); + } } return nullptr; @@ -2431,6 +2512,16 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, // those conditions rather than checking them. This is difficult because of // undef/poison (PR34838). 
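// Illustrative aside, a minimal standalone sketch and not part of this patch:
// the flipped-strictness narrowing worked through in the foldICmpShlConstant
// comment above ("icmp ult i64 (shl X, 32), 8589934593" down to
// "icmp ult i32 (trunc X), 3") can be spot-checked directly; the helper name
// is ours, and the equivalence holds for every 64-bit X.
#include <cstdint>
static bool checkShlCmpNarrowing(uint64_t X) {
  bool Wide = (X << 32) < 8589934593ULL;       // icmp ult i64 (shl X, 32), C
  bool Narrow = static_cast<uint32_t>(X) < 3u; // icmp ult i32 (trunc X), 3
  return Wide == Narrow;
}
// End of aside.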
if (IsAShr && Shr->hasOneUse()) { + if (IsExact && (Pred == CmpInst::ICMP_SLT || Pred == CmpInst::ICMP_ULT) && + (C - 1).isPowerOf2() && C.countLeadingZeros() > ShAmtVal) { + // When C - 1 is a power of two and the transform can be legally + // performed, prefer this form so the produced constant is close to a + // power of two. + // icmp slt/ult (ashr exact X, ShAmtC), C + // --> icmp slt/ult X, (C - 1) << ShAmtC) + 1 + APInt ShiftedC = (C - 1).shl(ShAmtVal) + 1; + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } if (IsExact || Pred == CmpInst::ICMP_SLT || Pred == CmpInst::ICMP_ULT) { // When ShAmtC can be shifted losslessly: // icmp PRED (ashr exact X, ShAmtC), C --> icmp PRED X, (C << ShAmtC) @@ -3025,7 +3116,7 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, // (X + -1) <u C --> X <=u C (if X is never null) if (Pred == CmpInst::ICMP_ULT && C2->isAllOnes()) { const SimplifyQuery Q = SQ.getWithInstruction(&Cmp); - if (llvm::isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (llvm::isKnownNonZero(X, Q)) return new ICmpInst(ICmpInst::ICMP_ULE, X, ConstantInt::get(Ty, C)); } @@ -3039,6 +3130,13 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateAnd(X, -C), ConstantExpr::getNeg(cast<Constant>(Y))); + // X+C2 <u C -> (X & C) == 2C + // iff C == -(C2) + // C2 is a power of 2 + if (Pred == ICmpInst::ICMP_ULT && C2->isPowerOf2() && C == -*C2) + return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, C), + ConstantInt::get(Ty, C * 2)); + // X+C >u C2 -> (X & ~C2) != C // iff C & C2 == 0 // C2+1 is a power of 2 @@ -3128,15 +3226,12 @@ Instruction *InstCombinerImpl::foldICmpSelectConstant(ICmpInst &Cmp, C3GreaterThan)) { assert(C1LessThan && C2Equal && C3GreaterThan); - bool TrueWhenLessThan = - ConstantExpr::getCompare(Cmp.getPredicate(), C1LessThan, C) - ->isAllOnesValue(); - bool TrueWhenEqual = - ConstantExpr::getCompare(Cmp.getPredicate(), C2Equal, C) - ->isAllOnesValue(); - bool TrueWhenGreaterThan = - ConstantExpr::getCompare(Cmp.getPredicate(), C3GreaterThan, C) - ->isAllOnesValue(); + bool TrueWhenLessThan = ICmpInst::compare( + C1LessThan->getValue(), C->getValue(), Cmp.getPredicate()); + bool TrueWhenEqual = ICmpInst::compare(C2Equal->getValue(), C->getValue(), + Cmp.getPredicate()); + bool TrueWhenGreaterThan = ICmpInst::compare( + C3GreaterThan->getValue(), C->getValue(), Cmp.getPredicate()); // This generates the new instruction that will replace the original Cmp // Instruction. Instead of enumerating the various combinations when @@ -3206,16 +3301,16 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { if (Cmp.isEquality() && match(Op1, m_Zero())) return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); - // If this is a sign-bit test of a bitcast of a casted FP value, eliminate - // the FP extend/truncate because that cast does not change the sign-bit. - // This is true for all standard IEEE-754 types and the X86 80-bit type. - // The sign-bit is always the most significant bit in those types. const APInt *C; bool TrueIfSigned; - if (match(Op1, m_APInt(C)) && Bitcast->hasOneUse() && - isSignBitCheck(Pred, *C, TrueIfSigned)) { - if (match(BCSrcOp, m_FPExt(m_Value(X))) || - match(BCSrcOp, m_FPTrunc(m_Value(X)))) { + if (match(Op1, m_APInt(C)) && Bitcast->hasOneUse()) { + // If this is a sign-bit test of a bitcast of a casted FP value, eliminate + // the FP extend/truncate because that cast does not change the sign-bit. 
+ // This is true for all standard IEEE-754 types and the X86 80-bit type. + // The sign-bit is always the most significant bit in those types. + if (isSignBitCheck(Pred, *C, TrueIfSigned) && + (match(BCSrcOp, m_FPExt(m_Value(X))) || + match(BCSrcOp, m_FPTrunc(m_Value(X))))) { // (bitcast (fpext/fptrunc X)) to iX) < 0 --> (bitcast X to iY) < 0 // (bitcast (fpext/fptrunc X)) to iX) > -1 --> (bitcast X to iY) > -1 Type *XType = X->getType(); @@ -3234,6 +3329,20 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { ConstantInt::getAllOnesValue(NewType)); } } + + // icmp eq/ne (bitcast X to int), special fp -> llvm.is.fpclass(X, class) + Type *FPType = SrcType->getScalarType(); + if (!Cmp.getParent()->getParent()->hasFnAttribute( + Attribute::NoImplicitFloat) && + Cmp.isEquality() && FPType->isIEEELikeFPTy()) { + FPClassTest Mask = APFloat(FPType->getFltSemantics(), *C).classify(); + if (Mask & (fcInf | fcZero)) { + if (Pred == ICmpInst::ICMP_NE) + Mask = ~Mask; + return replaceInstUsesWith(Cmp, + Builder.createIsFPClass(BCSrcOp, Mask)); + } + } } } @@ -3341,8 +3450,8 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstant(ICmpInst &Cmp) { return new ICmpInst(Cmp.getPredicate(), X, Y); } - if (match(Cmp.getOperand(1), m_APIntAllowUndef(C))) - return foldICmpInstWithConstantAllowUndef(Cmp, *C); + if (match(Cmp.getOperand(1), m_APIntAllowPoison(C))) + return foldICmpInstWithConstantAllowPoison(Cmp, *C); return nullptr; } @@ -3388,6 +3497,11 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( if (Value *NegVal = dyn_castNegVal(BOp0)) return new ICmpInst(Pred, NegVal, BOp1); if (BO->hasOneUse()) { + // (add nuw A, B) != 0 -> (or A, B) != 0 + if (match(BO, m_NUWAdd(m_Value(), m_Value()))) { + Value *Or = Builder.CreateOr(BOp0, BOp1); + return new ICmpInst(Pred, Or, Constant::getNullValue(BO->getType())); + } Value *Neg = Builder.CreateNeg(BOp1); Neg->takeName(BO); return new ICmpInst(Pred, BOp0, Neg); @@ -3396,15 +3510,13 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( break; } case Instruction::Xor: - if (BO->hasOneUse()) { - if (Constant *BOC = dyn_cast<Constant>(BOp1)) { - // For the xor case, we can xor two constants together, eliminating - // the explicit xor. - return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC)); - } else if (C.isZero()) { - // Replace ((xor A, B) != 0) with (A != B) - return new ICmpInst(Pred, BOp0, BOp1); - } + if (Constant *BOC = dyn_cast<Constant>(BOp1)) { + // For the xor case, we can xor two constants together, eliminating + // the explicit xor. + return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC)); + } else if (C.isZero()) { + // Replace ((xor A, B) != 0) with (A != B) + return new ICmpInst(Pred, BOp0, BOp1); } break; case Instruction::Or: { @@ -3654,11 +3766,11 @@ foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp, } /// Try to fold integer comparisons with a constant operand: icmp Pred X, C -/// where X is some kind of instruction and C is AllowUndef. -/// TODO: Move more folds which allow undef to this function. +/// where X is some kind of instruction and C is AllowPoison. +/// TODO: Move more folds which allow poison to this function. 
Instruction * -InstCombinerImpl::foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp, - const APInt &C) { +InstCombinerImpl::foldICmpInstWithConstantAllowPoison(ICmpInst &Cmp, + const APInt &C) { const ICmpInst::Predicate Pred = Cmp.getPredicate(); if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0))) { switch (II->getIntrinsicID()) { @@ -3821,6 +3933,52 @@ foldICmpUSubSatOrUAddSatWithConstant(ICmpInst::Predicate Pred, ConstantInt::get(Op1->getType(), EquivInt)); } +static Instruction * +foldICmpOfCmpIntrinsicWithConstant(ICmpInst::Predicate Pred, IntrinsicInst *I, + const APInt &C, + InstCombiner::BuilderTy &Builder) { + std::optional<ICmpInst::Predicate> NewPredicate = std::nullopt; + switch (Pred) { + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_NE: + if (C.isZero()) + NewPredicate = Pred; + else if (C.isOne()) + NewPredicate = + Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_UGT : ICmpInst::ICMP_ULE; + else if (C.isAllOnes()) + NewPredicate = + Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE; + break; + + case ICmpInst::ICMP_SGT: + if (C.isAllOnes()) + NewPredicate = ICmpInst::ICMP_UGE; + else if (C.isZero()) + NewPredicate = ICmpInst::ICMP_UGT; + break; + + case ICmpInst::ICMP_SLT: + if (C.isZero()) + NewPredicate = ICmpInst::ICMP_ULT; + else if (C.isOne()) + NewPredicate = ICmpInst::ICMP_ULE; + break; + + default: + break; + } + + if (!NewPredicate) + return nullptr; + + if (I->getIntrinsicID() == Intrinsic::scmp) + NewPredicate = ICmpInst::getSignedPredicate(*NewPredicate); + Value *LHS = I->getOperand(0); + Value *RHS = I->getOperand(1); + return new ICmpInst(*NewPredicate, LHS, RHS); +} + /// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, IntrinsicInst *II, @@ -3842,6 +4000,11 @@ Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, if (Instruction *R = foldCtpopPow2Test(Cmp, II, C, Builder, Q)) return R; } break; + case Intrinsic::scmp: + case Intrinsic::ucmp: + if (auto *Folded = foldICmpOfCmpIntrinsicWithConstant(Pred, II, C, Builder)) + return Folded; + break; } if (Cmp.isEquality()) @@ -4015,11 +4178,106 @@ Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred, return nullptr; } +// Returns whether V is a Mask ((X + 1) & X == 0) or ~Mask (-Pow2OrZero) +static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q, + unsigned Depth = 0) { + if (Not ? match(V, m_NegatedPower2OrZero()) : match(V, m_LowBitMaskOrZero())) + return true; + if (V->getType()->getScalarSizeInBits() == 1) + return true; + if (Depth++ >= MaxAnalysisRecursionDepth) + return false; + Value *X; + const Instruction *I = dyn_cast<Instruction>(V); + if (!I) + return false; + switch (I->getOpcode()) { + case Instruction::ZExt: + // ZExt(Mask) is a Mask. + return !Not && isMaskOrZero(I->getOperand(0), Not, Q, Depth); + case Instruction::SExt: + // SExt(Mask) is a Mask. + // SExt(~Mask) is a ~Mask. + return isMaskOrZero(I->getOperand(0), Not, Q, Depth); + case Instruction::And: + case Instruction::Or: + // Mask0 | Mask1 is a Mask. + // Mask0 & Mask1 is a Mask. + // ~Mask0 | ~Mask1 is a ~Mask. + // ~Mask0 & ~Mask1 is a ~Mask. 
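The closure properties listed in the comments above can be checked exhaustively over a narrow integer type. The sketch below is a standalone illustration, not part of the patch; isMask/isNotMask are ad-hoc helpers mirroring the "(X + 1) & X == 0" and "-Pow2OrZero" tests used here.

#include <cassert>
#include <cstdint>

// "Mask"  = 2^k - 1 or 0, i.e. (x + 1) & x == 0.
// "~Mask" = -(2^k) or 0,  i.e. -x is a power of two or zero.
static bool isMask(uint8_t x) { return ((uint8_t)(x + 1) & x) == 0; }
static bool isNotMask(uint8_t x) {
  uint8_t n = (uint8_t)-x;
  return (n & (uint8_t)(n - 1)) == 0;
}

int main() {
  for (unsigned a = 0; a < 256; ++a)
    for (unsigned b = 0; b < 256; ++b) {
      uint8_t x = (uint8_t)a, y = (uint8_t)b;
      if (isMask(x) && isMask(y)) {
        assert(isMask(x | y)); // Mask0 | Mask1 is a Mask.
        assert(isMask(x & y)); // Mask0 & Mask1 is a Mask.
      }
      if (isNotMask(x) && isNotMask(y)) {
        assert(isNotMask(x | y)); // ~Mask0 | ~Mask1 is a ~Mask.
        assert(isNotMask(x & y)); // ~Mask0 & ~Mask1 is a ~Mask.
      }
    }
  return 0;
}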
+ return isMaskOrZero(I->getOperand(1), Not, Q, Depth) && + isMaskOrZero(I->getOperand(0), Not, Q, Depth); + case Instruction::Xor: + if (match(V, m_Not(m_Value(X)))) + return isMaskOrZero(X, !Not, Q, Depth); + + // (X ^ -X) is a ~Mask + if (Not) + return match(V, m_c_Xor(m_Value(X), m_Neg(m_Deferred(X)))); + // (X ^ (X - 1)) is a Mask + else + return match(V, m_c_Xor(m_Value(X), m_Add(m_Deferred(X), m_AllOnes()))); + case Instruction::Select: + // c ? Mask0 : Mask1 is a Mask. + return isMaskOrZero(I->getOperand(1), Not, Q, Depth) && + isMaskOrZero(I->getOperand(2), Not, Q, Depth); + case Instruction::Shl: + // (~Mask) << X is a ~Mask. + return Not && isMaskOrZero(I->getOperand(0), Not, Q, Depth); + case Instruction::LShr: + // Mask >> X is a Mask. + return !Not && isMaskOrZero(I->getOperand(0), Not, Q, Depth); + case Instruction::AShr: + // Mask s>> X is a Mask. + // ~Mask s>> X is a ~Mask. + return isMaskOrZero(I->getOperand(0), Not, Q, Depth); + case Instruction::Add: + // Pow2 - 1 is a Mask. + if (!Not && match(I->getOperand(1), m_AllOnes())) + return isKnownToBeAPowerOfTwo(I->getOperand(0), Q.DL, /*OrZero*/ true, + Depth, Q.AC, Q.CxtI, Q.DT); + break; + case Instruction::Sub: + // -Pow2 is a ~Mask. + if (Not && match(I->getOperand(0), m_Zero())) + return isKnownToBeAPowerOfTwo(I->getOperand(1), Q.DL, /*OrZero*/ true, + Depth, Q.AC, Q.CxtI, Q.DT); + break; + case Instruction::Call: { + if (auto *II = dyn_cast<IntrinsicInst>(I)) { + switch (II->getIntrinsicID()) { + // min/max(Mask0, Mask1) is a Mask. + // min/max(~Mask0, ~Mask1) is a ~Mask. + case Intrinsic::umax: + case Intrinsic::smax: + case Intrinsic::umin: + case Intrinsic::smin: + return isMaskOrZero(II->getArgOperand(1), Not, Q, Depth) && + isMaskOrZero(II->getArgOperand(0), Not, Q, Depth); + + // In the context of masks, bitreverse(Mask) == ~Mask + case Intrinsic::bitreverse: + return isMaskOrZero(II->getArgOperand(0), !Not, Q, Depth); + default: + break; + } + } + break; + } + default: + break; + } + return false; +} + /// Some comparisons can be simplified. /// In this case, we are looking for comparisons that look like /// a check for a lossy truncation. /// Folds: /// icmp SrcPred (x & Mask), x to icmp DstPred x, Mask +/// icmp SrcPred (x & ~Mask), ~Mask to icmp DstPred x, ~Mask +/// icmp eq/ne (x & ~Mask), 0 to icmp DstPred x, Mask +/// icmp eq/ne (~x | Mask), -1 to icmp DstPred x, Mask /// Where Mask is some pattern that produces all-ones in low bits: /// (-1 >> y) /// ((-1 << y) >> y) <- non-canonical, has extra uses @@ -4028,89 +4286,125 @@ Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred, /// The Mask can be a constant, too. /// For some predicates, the operands are commutative. /// For others, x can only be on a specific side. 
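The equivalences behind this fold are easy to verify exhaustively for a narrow type. The following standalone sketch is illustrative only and assumes Mask has the low-bit form 2^k - 1 described in the comment above (with ~Mask the corresponding negated power of two):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned xi = 0; xi < 256; ++xi)
    for (unsigned k = 0; k <= 8; ++k) {
      uint8_t x = (uint8_t)xi;
      uint8_t mask = (uint8_t)((1u << k) - 1); // 2^k - 1 (all-ones in low bits)
      uint8_t notMask = (uint8_t)~mask;        // -(2^k)
      // icmp eq (x & Mask), x      -->  icmp ule x, Mask
      assert(((x & mask) == x) == (x <= mask));
      // icmp ne (x & Mask), x      -->  icmp ugt x, Mask
      assert(((x & mask) != x) == (x > mask));
      // icmp eq (x & ~Mask), ~Mask -->  icmp ule ~Mask, x
      assert(((x & notMask) == notMask) == (notMask <= x));
      // icmp eq (x & ~Mask), 0     -->  icmp ule x, Mask
      assert(((x & notMask) == 0) == (x <= mask));
      // icmp eq (~x | Mask), -1    -->  icmp ule x, Mask
      assert(((uint8_t)(~x | mask) == 0xFF) == (x <= mask));
    }
  return 0;
}

The same comparisons show why the signed forms need the extra non-negative / non-zero side conditions spelled out in the switch below.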
-static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, - InstCombiner::BuilderTy &Builder) { - ICmpInst::Predicate SrcPred; - Value *X, *M, *Y; - auto m_VariableMask = m_CombineOr( - m_CombineOr(m_Not(m_Shl(m_AllOnes(), m_Value())), - m_Add(m_Shl(m_One(), m_Value()), m_AllOnes())), - m_CombineOr(m_LShr(m_AllOnes(), m_Value()), - m_LShr(m_Shl(m_AllOnes(), m_Value(Y)), m_Deferred(Y)))); - auto m_Mask = m_CombineOr(m_VariableMask, m_LowBitMask()); - if (!match(&I, m_c_ICmp(SrcPred, - m_c_And(m_CombineAnd(m_Mask, m_Value(M)), m_Value(X)), - m_Deferred(X)))) - return nullptr; +static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0, + Value *Op1, const SimplifyQuery &Q, + InstCombiner &IC) { ICmpInst::Predicate DstPred; - switch (SrcPred) { + switch (Pred) { case ICmpInst::Predicate::ICMP_EQ: - // x & (-1 >> y) == x -> x u<= (-1 >> y) + // x & Mask == x + // x & ~Mask == 0 + // ~x | Mask == -1 + // -> x u<= Mask + // x & ~Mask == ~Mask + // -> ~Mask u<= x DstPred = ICmpInst::Predicate::ICMP_ULE; break; case ICmpInst::Predicate::ICMP_NE: - // x & (-1 >> y) != x -> x u> (-1 >> y) + // x & Mask != x + // x & ~Mask != 0 + // ~x | Mask != -1 + // -> x u> Mask + // x & ~Mask != ~Mask + // -> ~Mask u> x DstPred = ICmpInst::Predicate::ICMP_UGT; break; case ICmpInst::Predicate::ICMP_ULT: - // x & (-1 >> y) u< x -> x u> (-1 >> y) - // x u> x & (-1 >> y) -> x u> (-1 >> y) + // x & Mask u< x + // -> x u> Mask + // x & ~Mask u< ~Mask + // -> ~Mask u> x DstPred = ICmpInst::Predicate::ICMP_UGT; break; case ICmpInst::Predicate::ICMP_UGE: - // x & (-1 >> y) u>= x -> x u<= (-1 >> y) - // x u<= x & (-1 >> y) -> x u<= (-1 >> y) + // x & Mask u>= x + // -> x u<= Mask + // x & ~Mask u>= ~Mask + // -> ~Mask u<= x DstPred = ICmpInst::Predicate::ICMP_ULE; break; case ICmpInst::Predicate::ICMP_SLT: - // x & (-1 >> y) s< x -> x s> (-1 >> y) - // x s> x & (-1 >> y) -> x s> (-1 >> y) - if (!match(M, m_Constant())) // Can not do this fold with non-constant. - return nullptr; - if (!match(M, m_NonNegative())) // Must not have any -1 vector elements. - return nullptr; + // x & Mask s< x [iff Mask s>= 0] + // -> x s> Mask + // x & ~Mask s< ~Mask [iff ~Mask != 0] + // -> ~Mask s> x DstPred = ICmpInst::Predicate::ICMP_SGT; break; case ICmpInst::Predicate::ICMP_SGE: - // x & (-1 >> y) s>= x -> x s<= (-1 >> y) - // x s<= x & (-1 >> y) -> x s<= (-1 >> y) - if (!match(M, m_Constant())) // Can not do this fold with non-constant. - return nullptr; - if (!match(M, m_NonNegative())) // Must not have any -1 vector elements. - return nullptr; + // x & Mask s>= x [iff Mask s>= 0] + // -> x s<= Mask + // x & ~Mask s>= ~Mask [iff ~Mask != 0] + // -> ~Mask s<= x DstPred = ICmpInst::Predicate::ICMP_SLE; break; - case ICmpInst::Predicate::ICMP_SGT: - case ICmpInst::Predicate::ICMP_SLE: - return nullptr; - case ICmpInst::Predicate::ICMP_UGT: - case ICmpInst::Predicate::ICMP_ULE: - llvm_unreachable("Instsimplify took care of commut. variant"); - break; default: - llvm_unreachable("All possible folds are handled."); - } - - // The mask value may be a vector constant that has undefined elements. But it - // may not be safe to propagate those undefs into the new compare, so replace - // those elements by copying an existing, defined, and safe scalar constant. 
- Type *OpTy = M->getType(); - auto *VecC = dyn_cast<Constant>(M); - auto *OpVTy = dyn_cast<FixedVectorType>(OpTy); - if (OpVTy && VecC && VecC->containsUndefOrPoisonElement()) { - Constant *SafeReplacementConstant = nullptr; - for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) { - if (!isa<UndefValue>(VecC->getAggregateElement(i))) { - SafeReplacementConstant = VecC->getAggregateElement(i); - break; + // We don't support sgt,sle + // ult/ugt are simplified to true/false respectively. + return nullptr; + } + + Value *X, *M; + // Put search code in lambda for early positive returns. + auto IsLowBitMask = [&]() { + if (match(Op0, m_c_And(m_Specific(Op1), m_Value(M)))) { + X = Op1; + // Look for: x & Mask pred x + if (isMaskOrZero(M, /*Not=*/false, Q)) { + return !ICmpInst::isSigned(Pred) || + (match(M, m_NonNegative()) || isKnownNonNegative(M, Q)); + } + + // Look for: x & ~Mask pred ~Mask + if (isMaskOrZero(X, /*Not=*/true, Q)) { + return !ICmpInst::isSigned(Pred) || isKnownNonZero(X, Q); } + return false; } - assert(SafeReplacementConstant && "Failed to find undef replacement"); - M = Constant::replaceUndefsWith(VecC, SafeReplacementConstant); - } + if (ICmpInst::isEquality(Pred) && match(Op1, m_AllOnes()) && + match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(M))))) { + + auto Check = [&]() { + // Look for: ~x | Mask == -1 + if (isMaskOrZero(M, /*Not=*/false, Q)) { + if (Value *NotX = + IC.getFreelyInverted(X, X->hasOneUse(), &IC.Builder)) { + X = NotX; + return true; + } + } + return false; + }; + if (Check()) + return true; + std::swap(X, M); + return Check(); + } + if (ICmpInst::isEquality(Pred) && match(Op1, m_Zero()) && + match(Op0, m_OneUse(m_And(m_Value(X), m_Value(M))))) { + auto Check = [&]() { + // Look for: x & ~Mask == 0 + if (isMaskOrZero(M, /*Not=*/true, Q)) { + if (Value *NotM = + IC.getFreelyInverted(M, M->hasOneUse(), &IC.Builder)) { + M = NotM; + return true; + } + } + return false; + }; + if (Check()) + return true; + std::swap(X, M); + return Check(); + } + return false; + }; + + if (!IsLowBitMask()) + return nullptr; - return Builder.CreateICmp(DstPred, X, M); + return IC.Builder.CreateICmp(DstPred, X, M); } /// Some comparisons can be simplified. @@ -4493,6 +4787,44 @@ static Instruction *foldICmpAndXX(ICmpInst &I, const SimplifyQuery &Q, if (Pred == ICmpInst::ICMP_UGE) return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); + if (ICmpInst::isEquality(Pred) && Op0->hasOneUse()) { + // icmp (X & Y) eq/ne Y --> (X | ~Y) eq/ne -1 if Y is freely invertible and + // Y is non-constant. If Y is constant the `X & C == C` form is preferable + // so don't do this fold. + if (!match(Op1, m_ImmConstant())) + if (auto *NotOp1 = + IC.getFreelyInverted(Op1, !Op1->hasNUsesOrMore(3), &IC.Builder)) + return new ICmpInst(Pred, IC.Builder.CreateOr(A, NotOp1), + Constant::getAllOnesValue(Op1->getType())); + // icmp (X & Y) eq/ne Y --> (~X & Y) eq/ne 0 if X is freely invertible. 
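Both rewrites rest on the same subset relation: (X & Y) == Y says every bit of Y is set in X. A standalone sketch (illustrative only, not from the patch) checking that the three forms agree:

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned xi = 0; xi < 256; ++xi)
    for (unsigned yi = 0; yi < 256; ++yi) {
      uint8_t x = (uint8_t)xi, y = (uint8_t)yi;
      bool orig   = (uint8_t)(x & y) == y;              // icmp eq (X & Y), Y
      bool viaOr  = (uint8_t)(x | (uint8_t)~y) == 0xFF; // icmp eq (X | ~Y), -1
      bool viaAnd = (uint8_t)((uint8_t)~x & y) == 0;    // icmp eq (~X & Y), 0
      assert(orig == viaOr);
      assert(orig == viaAnd);
    }
  return 0;
}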
+ if (auto *NotA = IC.getFreelyInverted(A, A->hasOneUse(), &IC.Builder)) + return new ICmpInst(Pred, IC.Builder.CreateAnd(Op1, NotA), + Constant::getNullValue(Op1->getType())); + } + + if (!ICmpInst::isSigned(Pred)) + return nullptr; + + KnownBits KnownY = IC.computeKnownBits(A, /*Depth=*/0, &I); + // (X & NegY) spred X --> (X & NegY) upred X + if (KnownY.isNegative()) + return new ICmpInst(ICmpInst::getUnsignedPredicate(Pred), Op0, Op1); + + if (Pred != ICmpInst::ICMP_SLE && Pred != ICmpInst::ICMP_SGT) + return nullptr; + + if (KnownY.isNonNegative()) + // (X & PosY) s<= X --> X s>= 0 + // (X & PosY) s> X --> X s< 0 + return new ICmpInst(ICmpInst::getSwappedPredicate(Pred), Op1, + Constant::getNullValue(Op1->getType())); + + if (isKnownNegative(Op1, IC.getSimplifyQuery().getWithInstruction(&I))) + // (NegX & Y) s<= NegX --> Y s< 0 + // (NegX & Y) s> NegX --> Y s>= 0 + return new ICmpInst(ICmpInst::getFlippedStrictnessPredicate(Pred), A, + Constant::getNullValue(A->getType())); + return nullptr; } @@ -4520,7 +4852,7 @@ static Instruction *foldICmpOrXX(ICmpInst &I, const SimplifyQuery &Q, if (ICmpInst::isEquality(Pred) && Op0->hasOneUse()) { // icmp (X | Y) eq/ne Y --> (X & ~Y) eq/ne 0 if Y is freely invertible if (Value *NotOp1 = - IC.getFreelyInverted(Op1, Op1->hasOneUse(), &IC.Builder)) + IC.getFreelyInverted(Op1, !Op1->hasNUsesOrMore(3), &IC.Builder)) return new ICmpInst(Pred, IC.Builder.CreateAnd(A, NotOp1), Constant::getNullValue(Op1->getType())); // icmp (X | Y) eq/ne Y --> (~X | Y) eq/ne -1 if X is freely invertible. @@ -4548,8 +4880,7 @@ static Instruction *foldICmpXorXX(ICmpInst &I, const SimplifyQuery &Q, // icmp (X ^ Y_NonZero) s>= X --> icmp (X ^ Y_NonZero) s> X // icmp (X ^ Y_NonZero) s<= X --> icmp (X ^ Y_NonZero) s< X CmpInst::Predicate PredOut = CmpInst::getStrictPredicate(Pred); - if (PredOut != Pred && - isKnownNonZero(A, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT)) + if (PredOut != Pred && isKnownNonZero(A, Q)) return new ICmpInst(PredOut, Op0, Op1); return nullptr; @@ -4614,7 +4945,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, const APInt *C; if ((Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE) && match(Op0, m_And(m_BinOp(BO), m_LowBitMask(C))) && - match(BO, m_Add(m_Specific(Op1), m_SpecificIntAllowUndef(*C)))) { + match(BO, m_Add(m_Specific(Op1), m_SpecificIntAllowPoison(*C)))) { CmpInst::Predicate NewPred = Pred == ICmpInst::ICMP_ULT ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ; Constant *Zero = ConstantInt::getNullValue(Op1->getType()); @@ -4623,7 +4954,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, if ((Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE) && match(Op1, m_And(m_BinOp(BO), m_LowBitMask(C))) && - match(BO, m_Add(m_Specific(Op0), m_SpecificIntAllowUndef(*C)))) { + match(BO, m_Add(m_Specific(Op0), m_SpecificIntAllowPoison(*C)))) { CmpInst::Predicate NewPred = Pred == ICmpInst::ICMP_UGT ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ; Constant *Zero = ConstantInt::getNullValue(Op1->getType()); @@ -4773,8 +5104,9 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, (BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned()) { const APInt *AP1, *AP2; // TODO: Support non-uniform vectors. - // TODO: Allow undef passthrough if B AND D's element is undef. - if (match(B, m_APIntAllowUndef(AP1)) && match(D, m_APIntAllowUndef(AP2)) && + // TODO: Allow poison passthrough if B or D's element is poison. 
+ if (match(B, m_APIntAllowPoison(AP1)) && + match(D, m_APIntAllowPoison(AP2)) && AP1->isNegative() == AP2->isNegative()) { APInt AP1Abs = AP1->abs(); APInt AP2Abs = AP2->abs(); @@ -4832,11 +5164,11 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, return new ICmpInst(Pred, C, D); // (A - B) u>=/u< A --> B u>/u<= A iff B != 0 if (A == Op1 && (Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_ULT) && - isKnownNonZero(B, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT)) + isKnownNonZero(B, Q)) return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), B, A); // C u<=/u> (C - D) --> C u</u>= D iff B != 0 if (C == Op0 && (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) && - isKnownNonZero(D, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT)) + isKnownNonZero(D, Q)) return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), C, D); // icmp (A-B), (C-B) -> icmp A, C for equalities or if there is no overflow. @@ -4878,14 +5210,13 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, // X * Z eq/ne Y * Z -> X eq/ne Y if (ZKnown.countMaxTrailingZeros() == 0) return new ICmpInst(Pred, X, Y); - NonZero = !ZKnown.One.isZero() || - isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + NonZero = !ZKnown.One.isZero() || isKnownNonZero(Z, Q); // if Z != 0 and nsw(X * Z) and nsw(Y * Z) // X * Z eq/ne Y * Z -> X eq/ne Y if (NonZero && BO0 && BO1 && Op0HasNSW && Op1HasNSW) return new ICmpInst(Pred, X, Y); } else - NonZero = isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + NonZero = isKnownNonZero(Z, Q); // If Z != 0 and nuw(X * Z) and nuw(Y * Z) // X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y @@ -5027,9 +5358,6 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, if (Value *V = foldMultiplicationOverflowCheck(I)) return replaceInstUsesWith(I, V); - if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder)) - return replaceInstUsesWith(I, V); - if (Instruction *R = foldICmpAndXX(I, Q, *this)) return R; @@ -5272,21 +5600,6 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { } } - // canoncalize: - // (icmp eq/ne (and X, C), X) - // -> (icmp eq/ne (and X, ~C), 0) - { - Constant *CMask; - A = nullptr; - if (match(Op0, m_OneUse(m_And(m_Specific(Op1), m_ImmConstant(CMask))))) - A = Op1; - else if (match(Op1, m_OneUse(m_And(m_Specific(Op0), m_ImmConstant(CMask))))) - A = Op0; - if (A) - return new ICmpInst(Pred, Builder.CreateAnd(A, Builder.CreateNot(CMask)), - Constant::getNullValue(A->getType())); - } - if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && (A == Op0 || B == Op0)) { // A == (A^B) -> B == 0 Value *OtherVal = A == Op0 ? B : A; @@ -5294,8 +5607,8 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { } // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 - if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B)))) && - match(Op1, m_OneUse(m_And(m_Value(C), m_Value(D))))) { + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_And(m_Value(C), m_Value(D)))) { Value *X = nullptr, *Y = nullptr, *Z = nullptr; if (A == C) { @@ -5316,10 +5629,26 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { Z = B; } - if (X) { // Build (X^Y) & Z - Op1 = Builder.CreateXor(X, Y); - Op1 = Builder.CreateAnd(Op1, Z); - return new ICmpInst(Pred, Op1, Constant::getNullValue(Op1->getType())); + if (X) { + // If X^Y is a negative power of two, then `icmp eq/ne (Z & NegP2), 0` + // will fold to `icmp ult/uge Z, -NegP2` incurringb no additional + // instructions. 
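The comment above combines two facts: masking both sides by Z makes the equality depend only on (X ^ Y) & Z, and when that xor is a negated power of two the resulting compare with zero is just an unsigned range check. A standalone sketch (illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  // (X & Z) == (Y & Z)  <=>  ((X ^ Y) & Z) == 0, for any Z.
  const uint8_t z = 0xA5; // arbitrary choice of Z
  for (unsigned xi = 0; xi < 256; ++xi)
    for (unsigned yi = 0; yi < 256; ++yi) {
      uint8_t x = (uint8_t)xi, y = (uint8_t)yi;
      assert(((x & z) == (y & z)) == (((uint8_t)(x ^ y) & z) == 0));
    }
  // (Z & NegP2) == 0  <=>  Z u< -NegP2, when NegP2 is a negated power of two.
  for (unsigned zi = 0; zi < 256; ++zi)
    for (unsigned k = 0; k < 8; ++k) {
      uint8_t v = (uint8_t)zi;
      uint8_t negP2 = (uint8_t)-(uint8_t)(1u << k); // -(2^k)
      assert(((v & negP2) == 0) == (v < (uint8_t)-negP2));
    }
  return 0;
}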
+ const APInt *C0, *C1; + bool XorIsNegP2 = match(X, m_APInt(C0)) && match(Y, m_APInt(C1)) && + (*C0 ^ *C1).isNegatedPowerOf2(); + + // If either Op0/Op1 are both one use or X^Y will constant fold and one of + // Op0/Op1 are one use, proceed. In those cases we are instruction neutral + // but `icmp eq/ne A, 0` is easier to analyze than `icmp eq/ne A, B`. + int UseCnt = + int(Op0->hasOneUse()) + int(Op1->hasOneUse()) + + (int(match(X, m_ImmConstant()) && match(Y, m_ImmConstant()))); + if (XorIsNegP2 || UseCnt >= 2) { + // Build (X^Y) & Z + Op1 = Builder.CreateXor(X, Y); + Op1 = Builder.CreateAnd(Op1, Z); + return new ICmpInst(Pred, Op1, Constant::getNullValue(Op1->getType())); + } } } @@ -5349,10 +5678,10 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { // (A >> C) == (B >> C) --> (A^B) u< (1 << C) // For lshr and ashr pairs. const APInt *AP1, *AP2; - if ((match(Op0, m_OneUse(m_LShr(m_Value(A), m_APIntAllowUndef(AP1)))) && - match(Op1, m_OneUse(m_LShr(m_Value(B), m_APIntAllowUndef(AP2))))) || - (match(Op0, m_OneUse(m_AShr(m_Value(A), m_APIntAllowUndef(AP1)))) && - match(Op1, m_OneUse(m_AShr(m_Value(B), m_APIntAllowUndef(AP2)))))) { + if ((match(Op0, m_OneUse(m_LShr(m_Value(A), m_APIntAllowPoison(AP1)))) && + match(Op1, m_OneUse(m_LShr(m_Value(B), m_APIntAllowPoison(AP2))))) || + (match(Op0, m_OneUse(m_AShr(m_Value(A), m_APIntAllowPoison(AP1)))) && + match(Op1, m_OneUse(m_AShr(m_Value(B), m_APIntAllowPoison(AP2)))))) { if (AP1 != AP2) return nullptr; unsigned TypeBits = AP1->getBitWidth(); @@ -5968,7 +6297,7 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) { // If this is a normal comparison, it demands all bits. If it is a sign bit // comparison, it only demands the sign bit. bool UnusedBit; - if (InstCombiner::isSignBitCheck(I.getPredicate(), *RHS, UnusedBit)) + if (isSignBitCheck(I.getPredicate(), *RHS, UnusedBit)) return APInt::getSignMask(BitWidth); switch (I.getPredicate()) { @@ -6122,13 +6451,13 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { // Don't use dominating conditions when folding icmp using known bits. This // may convert signed into unsigned predicates in ways that other passes // (especially IndVarSimplify) may not be able to reliably undo. 
- SQ.DC = nullptr; - auto _ = make_scope_exit([&]() { SQ.DC = &DC; }); + SimplifyQuery Q = SQ.getWithoutDomCondCache().getWithInstruction(&I); if (SimplifyDemandedBits(&I, 0, getDemandedBitsLHSMask(I, BitWidth), - Op0Known, 0)) + Op0Known, /*Depth=*/0, Q)) return &I; - if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0)) + if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, + /*Depth=*/0, Q)) return &I; } @@ -6701,8 +7030,8 @@ static Instruction *foldVectorCmp(CmpInst &Cmp, if (auto *I = dyn_cast<Instruction>(V)) I->copyIRFlags(&Cmp); Module *M = Cmp.getModule(); - Function *F = Intrinsic::getDeclaration( - M, Intrinsic::experimental_vector_reverse, V->getType()); + Function *F = + Intrinsic::getDeclaration(M, Intrinsic::vector_reverse, V->getType()); return CallInst::Create(F, V); }; @@ -6743,10 +7072,10 @@ static Instruction *foldVectorCmp(CmpInst &Cmp, // Length-changing splats are ok, so adjust the constants as needed: // cmp (shuffle V1, M), C --> shuffle (cmp V1, C'), M - Constant *ScalarC = C->getSplatValue(/* AllowUndefs */ true); + Constant *ScalarC = C->getSplatValue(/* AllowPoison */ true); int MaskSplatIndex; - if (ScalarC && match(M, m_SplatOrUndefMask(MaskSplatIndex))) { - // We allow undefs in matching, but this transform removes those for safety. + if (ScalarC && match(M, m_SplatOrPoisonMask(MaskSplatIndex))) { + // We allow poison in matching, but this transform removes it for safety. // Demanded elements analysis should be able to recover some/all of that. C = ConstantVector::getSplat(cast<VectorType>(V1Ty)->getElementCount(), ScalarC); @@ -6930,6 +7259,40 @@ Instruction *InstCombinerImpl::foldICmpCommutative(ICmpInst::Predicate Pred, } } + const SimplifyQuery Q = SQ.getWithInstruction(&CxtI); + if (Value *V = foldICmpWithLowBitMaskedVal(Pred, Op0, Op1, Q, *this)) + return replaceInstUsesWith(CxtI, V); + + // Folding (X / Y) pred X => X swap(pred) 0 for constant Y other than 0 or 1 + auto CheckUGT1 = [](const APInt &Divisor) { return Divisor.ugt(1); }; + { + if (match(Op0, m_UDiv(m_Specific(Op1), m_CheckedInt(CheckUGT1)))) { + return new ICmpInst(ICmpInst::getSwappedPredicate(Pred), Op1, + Constant::getNullValue(Op1->getType())); + } + + if (!ICmpInst::isUnsigned(Pred) && + match(Op0, m_SDiv(m_Specific(Op1), m_CheckedInt(CheckUGT1)))) { + return new ICmpInst(ICmpInst::getSwappedPredicate(Pred), Op1, + Constant::getNullValue(Op1->getType())); + } + } + + // Another case of this fold is (X >> Y) pred X => X swap(pred) 0 if Y != 0 + auto CheckNE0 = [](const APInt &Shift) { return !Shift.isZero(); }; + { + if (match(Op0, m_LShr(m_Specific(Op1), m_CheckedInt(CheckNE0)))) { + return new ICmpInst(ICmpInst::getSwappedPredicate(Pred), Op1, + Constant::getNullValue(Op1->getType())); + } + + if ((Pred == CmpInst::ICMP_SLT || Pred == CmpInst::ICMP_SGE) && + match(Op0, m_AShr(m_Specific(Op1), m_CheckedInt(CheckNE0)))) { + return new ICmpInst(ICmpInst::getSwappedPredicate(Pred), Op1, + Constant::getNullValue(Op1->getType())); + } + } + return nullptr; } @@ -7062,6 +7425,14 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { foldICmpCommutative(I.getSwappedPredicate(), Op1, Op0, I)) return Res; + if (I.isCommutative()) { + if (auto Pair = matchSymmetricPair(I.getOperand(0), I.getOperand(1))) { + replaceOperand(I, 0, Pair->first); + replaceOperand(I, 1, Pair->second); + return &I; + } + } + // In case of a comparison with two select instructions having the same // condition, check whether one of the resulting branches can be simplified. 
// If so, just compare the other branch and select the appropriate result. @@ -7172,7 +7543,7 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { unsigned OpWidth = Op0->getType()->getScalarSizeInBits(); Instruction *ShiftI; if (match(Op0, m_CombineAnd(m_Instruction(ShiftI), - m_Shr(m_Value(X), m_SpecificIntAllowUndef( + m_Shr(m_Value(X), m_SpecificIntAllowPoison( OpWidth - 1))))) { unsigned ExtOpc = ExtI->getOpcode(); unsigned ShiftOpc = ShiftI->getOpcode(); @@ -7232,36 +7603,37 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, Constant *RHSC) { - if (!isa<ConstantFP>(RHSC)) return nullptr; - const APFloat &RHS = cast<ConstantFP>(RHSC)->getValueAPF(); + const APFloat *RHS; + if (!match(RHSC, m_APFloat(RHS))) + return nullptr; // Get the width of the mantissa. We don't want to hack on conversions that // might lose information from the integer, e.g. "i64 -> float" int MantissaWidth = LHSI->getType()->getFPMantissaWidth(); if (MantissaWidth == -1) return nullptr; // Unknown. - IntegerType *IntTy = cast<IntegerType>(LHSI->getOperand(0)->getType()); - + Type *IntTy = LHSI->getOperand(0)->getType(); + unsigned IntWidth = IntTy->getScalarSizeInBits(); bool LHSUnsigned = isa<UIToFPInst>(LHSI); if (I.isEquality()) { FCmpInst::Predicate P = I.getPredicate(); bool IsExact = false; - APSInt RHSCvt(IntTy->getBitWidth(), LHSUnsigned); - RHS.convertToInteger(RHSCvt, APFloat::rmNearestTiesToEven, &IsExact); + APSInt RHSCvt(IntWidth, LHSUnsigned); + RHS->convertToInteger(RHSCvt, APFloat::rmNearestTiesToEven, &IsExact); // If the floating point constant isn't an integer value, we know if we will // ever compare equal / not equal to it. if (!IsExact) { // TODO: Can never be -0.0 and other non-representable values - APFloat RHSRoundInt(RHS); + APFloat RHSRoundInt(*RHS); RHSRoundInt.roundToIntegral(APFloat::rmNearestTiesToEven); - if (RHS != RHSRoundInt) { + if (*RHS != RHSRoundInt) { if (P == FCmpInst::FCMP_OEQ || P == FCmpInst::FCMP_UEQ) - return replaceInstUsesWith(I, Builder.getFalse()); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); assert(P == FCmpInst::FCMP_ONE || P == FCmpInst::FCMP_UNE); - return replaceInstUsesWith(I, Builder.getTrue()); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); } } @@ -7272,23 +7644,22 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I, // Check to see that the input is converted from an integer type that is small // enough that preserves all bits. TODO: check here for "known" sign bits. // This would allow us to handle (fptosi (x >>s 62) to float) if x is i64 f.e. - unsigned InputSize = IntTy->getScalarSizeInBits(); - // Following test does NOT adjust InputSize downwards for signed inputs, + // Following test does NOT adjust IntWidth downwards for signed inputs, // because the most negative value still requires all the mantissa bits // to distinguish it from one less than that value. - if ((int)InputSize > MantissaWidth) { + if ((int)IntWidth > MantissaWidth) { // Conversion would lose accuracy. Check if loss can impact comparison. - int Exp = ilogb(RHS); + int Exp = ilogb(*RHS); if (Exp == APFloat::IEK_Inf) { - int MaxExponent = ilogb(APFloat::getLargest(RHS.getSemantics())); - if (MaxExponent < (int)InputSize - !LHSUnsigned) + int MaxExponent = ilogb(APFloat::getLargest(RHS->getSemantics())); + if (MaxExponent < (int)IntWidth - !LHSUnsigned) // Conversion could create infinity. 
return nullptr; } else { // Note that if RHS is zero or NaN, then Exp is negative // and first condition is trivially false. - if (MantissaWidth <= Exp && Exp <= (int)InputSize - !LHSUnsigned) + if (MantissaWidth <= Exp && Exp <= (int)IntWidth - !LHSUnsigned) // Conversion could affect comparison. return nullptr; } @@ -7297,7 +7668,7 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I, // Otherwise, we can potentially simplify the comparison. We know that it // will always come through as an integer value and we know the constant is // not a NAN (it would have been previously simplified). - assert(!RHS.isNaN() && "NaN comparison not already folded!"); + assert(!RHS->isNaN() && "NaN comparison not already folded!"); ICmpInst::Predicate Pred; switch (I.getPredicate()) { @@ -7327,64 +7698,62 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I, Pred = ICmpInst::ICMP_NE; break; case FCmpInst::FCMP_ORD: - return replaceInstUsesWith(I, Builder.getTrue()); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); case FCmpInst::FCMP_UNO: - return replaceInstUsesWith(I, Builder.getFalse()); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); } // Now we know that the APFloat is a normal number, zero or inf. // See if the FP constant is too large for the integer. For example, // comparing an i8 to 300.0. - unsigned IntWidth = IntTy->getScalarSizeInBits(); - if (!LHSUnsigned) { // If the RHS value is > SignedMax, fold the comparison. This handles +INF // and large values. - APFloat SMax(RHS.getSemantics()); + APFloat SMax(RHS->getSemantics()); SMax.convertFromAPInt(APInt::getSignedMaxValue(IntWidth), true, APFloat::rmNearestTiesToEven); - if (SMax < RHS) { // smax < 13123.0 + if (SMax < *RHS) { // smax < 13123.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) - return replaceInstUsesWith(I, Builder.getTrue()); - return replaceInstUsesWith(I, Builder.getFalse()); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); } } else { // If the RHS value is > UnsignedMax, fold the comparison. This handles // +INF and large values. - APFloat UMax(RHS.getSemantics()); + APFloat UMax(RHS->getSemantics()); UMax.convertFromAPInt(APInt::getMaxValue(IntWidth), false, APFloat::rmNearestTiesToEven); - if (UMax < RHS) { // umax < 13123.0 + if (UMax < *RHS) { // umax < 13123.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) - return replaceInstUsesWith(I, Builder.getTrue()); - return replaceInstUsesWith(I, Builder.getFalse()); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); } } if (!LHSUnsigned) { // See if the RHS value is < SignedMin. - APFloat SMin(RHS.getSemantics()); + APFloat SMin(RHS->getSemantics()); SMin.convertFromAPInt(APInt::getSignedMinValue(IntWidth), true, APFloat::rmNearestTiesToEven); - if (SMin > RHS) { // smin > 12312.0 + if (SMin > *RHS) { // smin > 12312.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) - return replaceInstUsesWith(I, Builder.getTrue()); - return replaceInstUsesWith(I, Builder.getFalse()); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); } } else { // See if the RHS value is < UnsignedMin. 
- APFloat UMin(RHS.getSemantics()); + APFloat UMin(RHS->getSemantics()); UMin.convertFromAPInt(APInt::getMinValue(IntWidth), false, APFloat::rmNearestTiesToEven); - if (UMin > RHS) { // umin > 12312.0 + if (UMin > *RHS) { // umin > 12312.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) - return replaceInstUsesWith(I, Builder.getTrue()); - return replaceInstUsesWith(I, Builder.getFalse()); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); } } @@ -7394,8 +7763,8 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I, // Don't do this for zero, because -0.0 is not fractional. APSInt RHSInt(IntWidth, LHSUnsigned); bool IsExact; - RHS.convertToInteger(RHSInt, APFloat::rmTowardZero, &IsExact); - if (!RHS.isZero()) { + RHS->convertToInteger(RHSInt, APFloat::rmTowardZero, &IsExact); + if (!RHS->isZero()) { if (!IsExact) { // If we had a comparison against a fractional value, we have to adjust // the compare predicate and sometimes the value. RHSC is rounded towards @@ -7403,57 +7772,57 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I, switch (Pred) { default: llvm_unreachable("Unexpected integer comparison!"); case ICmpInst::ICMP_NE: // (float)int != 4.4 --> true - return replaceInstUsesWith(I, Builder.getTrue()); + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); case ICmpInst::ICMP_EQ: // (float)int == 4.4 --> false - return replaceInstUsesWith(I, Builder.getFalse()); + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); case ICmpInst::ICMP_ULE: // (float)int <= 4.4 --> int <= 4 // (float)int <= -4.4 --> false - if (RHS.isNegative()) - return replaceInstUsesWith(I, Builder.getFalse()); + if (RHS->isNegative()) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); break; case ICmpInst::ICMP_SLE: // (float)int <= 4.4 --> int <= 4 // (float)int <= -4.4 --> int < -4 - if (RHS.isNegative()) + if (RHS->isNegative()) Pred = ICmpInst::ICMP_SLT; break; case ICmpInst::ICMP_ULT: // (float)int < -4.4 --> false // (float)int < 4.4 --> int <= 4 - if (RHS.isNegative()) - return replaceInstUsesWith(I, Builder.getFalse()); + if (RHS->isNegative()) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); Pred = ICmpInst::ICMP_ULE; break; case ICmpInst::ICMP_SLT: // (float)int < -4.4 --> int < -4 // (float)int < 4.4 --> int <= 4 - if (!RHS.isNegative()) + if (!RHS->isNegative()) Pred = ICmpInst::ICMP_SLE; break; case ICmpInst::ICMP_UGT: // (float)int > 4.4 --> int > 4 // (float)int > -4.4 --> true - if (RHS.isNegative()) - return replaceInstUsesWith(I, Builder.getTrue()); + if (RHS->isNegative()) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); break; case ICmpInst::ICMP_SGT: // (float)int > 4.4 --> int > 4 // (float)int > -4.4 --> int >= -4 - if (RHS.isNegative()) + if (RHS->isNegative()) Pred = ICmpInst::ICMP_SGE; break; case ICmpInst::ICMP_UGE: // (float)int >= -4.4 --> true // (float)int >= 4.4 --> int > 4 - if (RHS.isNegative()) - return replaceInstUsesWith(I, Builder.getTrue()); + if (RHS->isNegative()) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); Pred = ICmpInst::ICMP_UGT; break; case ICmpInst::ICMP_SGE: // (float)int >= -4.4 --> int >= -4 // (float)int >= 4.4 --> int > 4 - if (!RHS.isNegative()) + if (!RHS->isNegative()) Pred = ICmpInst::ICMP_SGT; break; } @@ -7462,7 +7831,8 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I, // 
Lower this FP comparison into an appropriate integer version of the // comparison. - return new ICmpInst(Pred, LHSI->getOperand(0), Builder.getInt(RHSInt)); + return new ICmpInst(Pred, LHSI->getOperand(0), + ConstantInt::get(LHSI->getOperand(0)->getType(), RHSInt)); } /// Fold (C / X) < 0.0 --> X < 0.0 if possible. Swap predicate if necessary. @@ -7632,6 +8002,53 @@ static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) { return new FCmpInst(Pred, Op0, Zero, "", &I); } +static Instruction *foldFCmpFSubIntoFCmp(FCmpInst &I, Instruction *LHSI, + Constant *RHSC, InstCombinerImpl &CI) { + const CmpInst::Predicate Pred = I.getPredicate(); + Value *X = LHSI->getOperand(0); + Value *Y = LHSI->getOperand(1); + switch (Pred) { + default: + break; + case FCmpInst::FCMP_UGT: + case FCmpInst::FCMP_ULT: + case FCmpInst::FCMP_UNE: + case FCmpInst::FCMP_OEQ: + case FCmpInst::FCMP_OGE: + case FCmpInst::FCMP_OLE: + // The optimization is not valid if X and Y are infinities of the same + // sign, i.e. the inf - inf = nan case. If the fsub has the ninf or nnan + // flag then we can assume we do not have that case. Otherwise we might be + // able to prove that either X or Y is not infinity. + if (!LHSI->hasNoNaNs() && !LHSI->hasNoInfs() && + !isKnownNeverInfinity(Y, /*Depth=*/0, + CI.getSimplifyQuery().getWithInstruction(&I)) && + !isKnownNeverInfinity(X, /*Depth=*/0, + CI.getSimplifyQuery().getWithInstruction(&I))) + break; + + [[fallthrough]]; + case FCmpInst::FCMP_OGT: + case FCmpInst::FCMP_OLT: + case FCmpInst::FCMP_ONE: + case FCmpInst::FCMP_UEQ: + case FCmpInst::FCMP_UGE: + case FCmpInst::FCMP_ULE: + // fcmp pred (x - y), 0 --> fcmp pred x, y + if (match(RHSC, m_AnyZeroFP()) && + I.getFunction()->getDenormalMode( + LHSI->getType()->getScalarType()->getFltSemantics()) == + DenormalMode::getIEEE()) { + CI.replaceOperand(I, 0, X); + CI.replaceOperand(I, 1, Y); + return &I; + } + break; + } + + return nullptr; +} + Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { bool Changed = false; @@ -7675,15 +8092,23 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { } } + if (I.isCommutative()) { + if (auto Pair = matchSymmetricPair(I.getOperand(0), I.getOperand(1))) { + replaceOperand(I, 0, Pair->first); + replaceOperand(I, 1, Pair->second); + return &I; + } + } + // If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand, // then canonicalize the operand to 0.0. 
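ord/uno only ask whether either operand is a NaN, so once one operand is known not to be NaN its value is irrelevant and +0.0 is as good a choice as any. A standalone sketch of that equivalence (illustrative only; ord/uno are ad-hoc helpers spelling out the predicate semantics):

#include <cassert>
#include <cmath>
#include <limits>

// fcmp ord X, Y  <=>  neither X nor Y is NaN
// fcmp uno X, Y  <=>  X or Y is NaN
static bool ord(double x, double y) { return !std::isnan(x) && !std::isnan(y); }
static bool uno(double x, double y) { return std::isnan(x) || std::isnan(y); }

int main() {
  const double inf = std::numeric_limits<double>::infinity();
  const double nan = std::numeric_limits<double>::quiet_NaN();
  const double nonNaN[] = {0.0, -0.0, 1.5, -2.25, inf, -inf};
  for (double x : {0.0, -3.5, inf, nan})
    for (double y : nonNaN) { // Y is known never to be NaN here
      assert(ord(x, y) == ord(x, 0.0));
      assert(uno(x, y) == uno(x, 0.0));
    }
  return 0;
}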
if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) { - if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, DL, &TLI, 0, - &AC, &I, &DT)) + if (!match(Op0, m_PosZeroFP()) && + isKnownNeverNaN(Op0, 0, getSimplifyQuery().getWithInstruction(&I))) return replaceOperand(I, 0, ConstantFP::getZero(OpType)); if (!match(Op1, m_PosZeroFP()) && - isKnownNeverNaN(Op1, DL, &TLI, 0, &AC, &I, &DT)) + isKnownNeverNaN(Op1, 0, getSimplifyQuery().getWithInstruction(&I))) return replaceOperand(I, 1, ConstantFP::getZero(OpType)); } @@ -7715,12 +8140,56 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP())) return replaceOperand(I, 1, ConstantFP::getZero(OpType)); + // Canonicalize: + // fcmp olt X, +inf -> fcmp one X, +inf + // fcmp ole X, +inf -> fcmp ord X, 0 + // fcmp ogt X, +inf -> false + // fcmp oge X, +inf -> fcmp oeq X, +inf + // fcmp ult X, +inf -> fcmp une X, +inf + // fcmp ule X, +inf -> true + // fcmp ugt X, +inf -> fcmp uno X, 0 + // fcmp uge X, +inf -> fcmp ueq X, +inf + // fcmp olt X, -inf -> false + // fcmp ole X, -inf -> fcmp oeq X, -inf + // fcmp ogt X, -inf -> fcmp one X, -inf + // fcmp oge X, -inf -> fcmp ord X, 0 + // fcmp ult X, -inf -> fcmp uno X, 0 + // fcmp ule X, -inf -> fcmp ueq X, -inf + // fcmp ugt X, -inf -> fcmp une X, -inf + // fcmp uge X, -inf -> true + const APFloat *C; + if (match(Op1, m_APFloat(C)) && C->isInfinity()) { + switch (C->isNegative() ? FCmpInst::getSwappedPredicate(Pred) : Pred) { + default: + break; + case FCmpInst::FCMP_ORD: + case FCmpInst::FCMP_UNO: + case FCmpInst::FCMP_TRUE: + case FCmpInst::FCMP_FALSE: + case FCmpInst::FCMP_OGT: + case FCmpInst::FCMP_ULE: + llvm_unreachable("Should be simplified by InstSimplify"); + case FCmpInst::FCMP_OLT: + return new FCmpInst(FCmpInst::FCMP_ONE, Op0, Op1, "", &I); + case FCmpInst::FCMP_OLE: + return new FCmpInst(FCmpInst::FCMP_ORD, Op0, ConstantFP::getZero(OpType), + "", &I); + case FCmpInst::FCMP_OGE: + return new FCmpInst(FCmpInst::FCMP_OEQ, Op0, Op1, "", &I); + case FCmpInst::FCMP_ULT: + return new FCmpInst(FCmpInst::FCMP_UNE, Op0, Op1, "", &I); + case FCmpInst::FCMP_UGT: + return new FCmpInst(FCmpInst::FCMP_UNO, Op0, ConstantFP::getZero(OpType), + "", &I); + case FCmpInst::FCMP_UGE: + return new FCmpInst(FCmpInst::FCMP_UEQ, Op0, Op1, "", &I); + } + } + // Ignore signbit of bitcasted int when comparing equality to FP 0.0: // fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0 if (match(Op1, m_PosZeroFP()) && - match(Op0, m_OneUse(m_BitCast(m_Value(X)))) && - X->getType()->isVectorTy() == OpType->isVectorTy() && - X->getType()->getScalarSizeInBits() == OpType->getScalarSizeInBits()) { + match(Op0, m_OneUse(m_ElementWiseBitCast(m_Value(X))))) { ICmpInst::Predicate IntPred = ICmpInst::BAD_ICMP_PREDICATE; if (Pred == FCmpInst::FCMP_OEQ) IntPred = ICmpInst::ICMP_EQ; @@ -7740,6 +8209,21 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { Constant *RHSC; if (match(Op0, m_Instruction(LHSI)) && match(Op1, m_Constant(RHSC))) { switch (LHSI->getOpcode()) { + case Instruction::Select: + // fcmp eq (cond ? 
x : -x), 0 --> fcmp eq x, 0 + if (FCmpInst::isEquality(Pred) && match(RHSC, m_AnyZeroFP()) && + (match(LHSI, + m_Select(m_Value(), m_Value(X), m_FNeg(m_Deferred(X)))) || + match(LHSI, m_Select(m_Value(), m_FNeg(m_Value(X)), m_Deferred(X))))) + return replaceOperand(I, 0, X); + if (Instruction *NV = FoldOpIntoSelect(I, cast<SelectInst>(LHSI))) + return NV; + break; + case Instruction::FSub: + if (LHSI->hasOneUse()) + if (Instruction *NV = foldFCmpFSubIntoFCmp(I, LHSI, RHSC, *this)) + return NV; + break; case Instruction::PHI: if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) return NV; @@ -7774,6 +8258,14 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { return new FCmpInst(I.getSwappedPredicate(), X, NegC, "", &I); } + // fcmp (fadd X, 0.0), Y --> fcmp X, Y + if (match(Op0, m_FAdd(m_Value(X), m_AnyZeroFP()))) + return new FCmpInst(Pred, X, Op1, "", &I); + + // fcmp X, (fadd Y, 0.0) --> fcmp X, Y + if (match(Op1, m_FAdd(m_Value(Y), m_AnyZeroFP()))) + return new FCmpInst(Pred, Op0, Y, "", &I); + if (match(Op0, m_FPExt(m_Value(X)))) { // fcmp (fpext X), (fpext Y) -> fcmp X, Y if (match(Op1, m_FPExt(m_Value(Y))) && X->getType() == Y->getType()) @@ -7828,7 +8320,6 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { // TODO: Simplify if the copysign constant is 0.0 or NaN. // TODO: Handle non-zero compare constants. // TODO: Handle other predicates. - const APFloat *C; if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::copysign>(m_APFloat(C), m_Value(X)))) && match(Op1, m_AnyZeroFP()) && !C->isZero() && !C->isNaN()) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index c24b6e3a5b33..64fbcc80e0ed 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -67,10 +67,10 @@ public: bool MinimizeSize, AAResults *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, TargetTransformInfo &TTI, DominatorTree &DT, OptimizationRemarkEmitter &ORE, - BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, - const DataLayout &DL, LoopInfo *LI) + BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI, + ProfileSummaryInfo *PSI, const DataLayout &DL, LoopInfo *LI) : InstCombiner(Worklist, Builder, MinimizeSize, AA, AC, TLI, TTI, DT, ORE, - BFI, PSI, DL, LI) {} + BFI, BPI, PSI, DL, LI) {} virtual ~InstCombinerImpl() = default; @@ -98,6 +98,7 @@ public: Instruction *visitSub(BinaryOperator &I); Instruction *visitFSub(BinaryOperator &I); Instruction *visitMul(BinaryOperator &I); + Instruction *foldPowiReassoc(BinaryOperator &I); Instruction *foldFMulReassoc(BinaryOperator &I); Instruction *visitFMul(BinaryOperator &I); Instruction *visitURem(BinaryOperator &I); @@ -202,16 +203,17 @@ public: FPClassTest Interested = fcAllFlags, const Instruction *CtxI = nullptr, unsigned Depth = 0) const { - return llvm::computeKnownFPClass(Val, FMF, DL, Interested, Depth, &TLI, &AC, - CtxI, &DT); + return llvm::computeKnownFPClass( + Val, FMF, Interested, Depth, + getSimplifyQuery().getWithInstruction(CtxI)); } KnownFPClass computeKnownFPClass(Value *Val, FPClassTest Interested = fcAllFlags, const Instruction *CtxI = nullptr, unsigned Depth = 0) const { - return llvm::computeKnownFPClass(Val, DL, Interested, Depth, &TLI, &AC, - CtxI, &DT); + return llvm::computeKnownFPClass( + Val, Interested, Depth, getSimplifyQuery().getWithInstruction(CtxI)); } /// Check if fmul \p MulVal, +0.0 will 
yield +0.0 (or signed zero is @@ -236,6 +238,9 @@ public: return getLosslessTrunc(C, TruncTy, Instruction::SExt); } + std::optional<std::pair<Intrinsic::ID, SmallVector<Value *, 3>>> + convertOrOfShiftsToFunnelShift(Instruction &Or); + private: bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI); bool isDesirableIntType(unsigned BitWidth) const; @@ -349,8 +354,9 @@ private: } bool willNotOverflowUnsignedMul(const Value *LHS, const Value *RHS, - const Instruction &CxtI) const { - return computeOverflowForUnsignedMul(LHS, RHS, &CxtI) == + const Instruction &CxtI, + bool IsNSW = false) const { + return computeOverflowForUnsignedMul(LHS, RHS, &CxtI, IsNSW) == OverflowResult::NeverOverflows; } @@ -371,10 +377,15 @@ private: } } - Value *EmitGEPOffset(User *GEP); + Value *EmitGEPOffset(GEPOperator *GEP, bool RewriteGEP = false); Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN); Instruction *foldBitcastExtElt(ExtractElementInst &ExtElt); Instruction *foldCastedBitwiseLogic(BinaryOperator &I); + Instruction *foldFBinOpOfIntCasts(BinaryOperator &I); + // Should only be called by `foldFBinOpOfIntCasts`. + Instruction *foldFBinOpOfIntCastsFromSign( + BinaryOperator &BO, bool OpsFromSigned, std::array<Value *, 2> IntOps, + Constant *Op1FpC, SmallVectorImpl<WithCache<const Value *>> &OpsKnown); Instruction *foldBinopOfSextBoolToSelect(BinaryOperator &I); Instruction *narrowBinOp(TruncInst &Trunc); Instruction *narrowMaskedBinOp(BinaryOperator &And); @@ -451,7 +462,7 @@ public: auto *SI = new StoreInst(ConstantInt::getTrue(Ctx), PoisonValue::get(PointerType::getUnqual(Ctx)), /*isVolatile*/ false, Align(1)); - InsertNewInstBefore(SI, InsertAt->getIterator()); + InsertNewInstWith(SI, InsertAt->getIterator()); } /// Combiner aware instruction erasure. @@ -534,21 +545,23 @@ public: ConstantInt *&Less, ConstantInt *&Equal, ConstantInt *&Greater); - /// Attempts to replace V with a simpler value based on the demanded + /// Attempts to replace I with a simpler value based on the demanded /// bits. - Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBits &Known, - unsigned Depth, Instruction *CxtI); + Value *SimplifyDemandedUseBits(Instruction *I, const APInt &DemandedMask, + KnownBits &Known, unsigned Depth, + const SimplifyQuery &Q); + using InstCombiner::SimplifyDemandedBits; bool SimplifyDemandedBits(Instruction *I, unsigned Op, const APInt &DemandedMask, KnownBits &Known, - unsigned Depth = 0) override; + unsigned Depth, const SimplifyQuery &Q) override; /// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne /// bits. It also tries to handle simplifications that can be done based on /// DemandedMask, but without modifying the Instruction. Value *SimplifyMultipleUseDemandedBits(Instruction *I, const APInt &DemandedMask, - KnownBits &Known, - unsigned Depth, Instruction *CxtI); + KnownBits &Known, unsigned Depth, + const SimplifyQuery &Q); /// Helper routine of SimplifyDemandedUseBits. It tries to simplify demanded /// bit for "r1 = shr x, c1; r2 = shl r1, c2" instruction sequence. 
@@ -565,6 +578,15 @@ public: APInt &PoisonElts, unsigned Depth = 0, bool AllowMultipleUsers = false) override; + /// Attempts to replace V with a simpler value based on the demanded + /// floating-point classes + Value *SimplifyDemandedUseFPClass(Value *V, FPClassTest DemandedMask, + KnownFPClass &Known, unsigned Depth, + Instruction *CxtI); + bool SimplifyDemandedFPClass(Instruction *I, unsigned Op, + FPClassTest DemandedMask, KnownFPClass &Known, + unsigned Depth = 0); + /// Canonicalize the position of binops relative to shufflevector. Instruction *foldVectorBinop(BinaryOperator &Inst); Instruction *foldVectorSelect(SelectInst &Sel); @@ -642,8 +664,8 @@ public: Instruction *foldICmpUsingBoolRange(ICmpInst &I); Instruction *foldICmpInstWithConstant(ICmpInst &Cmp); Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp); - Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp, - const APInt &C); + Instruction *foldICmpInstWithConstantAllowPoison(ICmpInst &Cmp, + const APInt &C); Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ); Instruction *foldICmpWithMinMax(Instruction &I, MinMaxIntrinsic *MinMax, Value *Z, ICmpInst::Predicate Pred); @@ -736,6 +758,12 @@ public: Value *EvaluateInDifferentType(Value *V, Type *Ty, bool isSigned); bool tryToSinkInstruction(Instruction *I, BasicBlock *DestBlock); + void tryToSinkInstructionDbgValues( + Instruction *I, BasicBlock::iterator InsertPos, BasicBlock *SrcBlock, + BasicBlock *DestBlock, SmallVectorImpl<DbgVariableIntrinsic *> &DbgUsers); + void tryToSinkInstructionDbgVariableRecords( + Instruction *I, BasicBlock::iterator InsertPos, BasicBlock *SrcBlock, + BasicBlock *DestBlock, SmallVectorImpl<DbgVariableRecord *> &DPUsers); bool removeInstructionsBeforeUnreachable(Instruction &I); void addDeadEdge(BasicBlock *From, BasicBlock *To, diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 1254a050027a..1661fa564c65 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -332,7 +332,7 @@ bool PointerReplacer::collectUsersRecursive(Instruction &I) { Worklist.insert(SI); if (!collectUsersRecursive(*SI)) return false; - } else if (isa<GetElementPtrInst, BitCastInst>(Inst)) { + } else if (isa<GetElementPtrInst>(Inst)) { Worklist.insert(Inst); if (!collectUsersRecursive(*Inst)) return false; @@ -342,9 +342,13 @@ bool PointerReplacer::collectUsersRecursive(Instruction &I) { Worklist.insert(Inst); } else if (isEqualOrValidAddrSpaceCast(Inst, FromAS)) { Worklist.insert(Inst); + if (!collectUsersRecursive(*Inst)) + return false; } else if (Inst->isLifetimeStartOrEnd()) { continue; } else { + // TODO: For arbitrary uses with address space mismatches, should we check + // if we can introduce a valid addrspacecast? 
LLVM_DEBUG(dbgs() << "Cannot handle pointer user: " << *U << '\n'); return false; } @@ -374,7 +378,7 @@ void PointerReplacer::replace(Instruction *I) { } else if (auto *PHI = dyn_cast<PHINode>(I)) { Type *NewTy = getReplacement(PHI->getIncomingValue(0))->getType(); auto *NewPHI = PHINode::Create(NewTy, PHI->getNumIncomingValues(), - PHI->getName(), PHI); + PHI->getName(), PHI->getIterator()); for (unsigned int I = 0; I < PHI->getNumIncomingValues(); ++I) NewPHI->addIncoming(getReplacement(PHI->getIncomingValue(I)), PHI->getIncomingBlock(I)); @@ -382,44 +386,38 @@ void PointerReplacer::replace(Instruction *I) { } else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { auto *V = getReplacement(GEP->getPointerOperand()); assert(V && "Operand not replaced"); - SmallVector<Value *, 8> Indices; - Indices.append(GEP->idx_begin(), GEP->idx_end()); + SmallVector<Value *, 8> Indices(GEP->indices()); auto *NewI = GetElementPtrInst::Create(GEP->getSourceElementType(), V, Indices); IC.InsertNewInstWith(NewI, GEP->getIterator()); NewI->takeName(GEP); + NewI->setNoWrapFlags(GEP->getNoWrapFlags()); WorkMap[GEP] = NewI; - } else if (auto *BC = dyn_cast<BitCastInst>(I)) { - auto *V = getReplacement(BC->getOperand(0)); - assert(V && "Operand not replaced"); - auto *NewT = PointerType::get(BC->getType()->getContext(), - V->getType()->getPointerAddressSpace()); - auto *NewI = new BitCastInst(V, NewT); - IC.InsertNewInstWith(NewI, BC->getIterator()); - NewI->takeName(BC); - WorkMap[BC] = NewI; } else if (auto *SI = dyn_cast<SelectInst>(I)) { - auto *NewSI = SelectInst::Create( - SI->getCondition(), getReplacement(SI->getTrueValue()), - getReplacement(SI->getFalseValue()), SI->getName(), nullptr, SI); + Value *TrueValue = SI->getTrueValue(); + Value *FalseValue = SI->getFalseValue(); + if (Value *Replacement = getReplacement(TrueValue)) + TrueValue = Replacement; + if (Value *Replacement = getReplacement(FalseValue)) + FalseValue = Replacement; + auto *NewSI = SelectInst::Create(SI->getCondition(), TrueValue, FalseValue, + SI->getName(), nullptr, SI); IC.InsertNewInstWith(NewSI, SI->getIterator()); NewSI->takeName(SI); WorkMap[SI] = NewSI; } else if (auto *MemCpy = dyn_cast<MemTransferInst>(I)) { - auto *SrcV = getReplacement(MemCpy->getRawSource()); - // The pointer may appear in the destination of a copy, but we don't want to - // replace it. 
- if (!SrcV) { - assert(getReplacement(MemCpy->getRawDest()) && - "destination not in replace list"); - return; - } + auto *DestV = MemCpy->getRawDest(); + auto *SrcV = MemCpy->getRawSource(); + + if (auto *DestReplace = getReplacement(DestV)) + DestV = DestReplace; + if (auto *SrcReplace = getReplacement(SrcV)) + SrcV = SrcReplace; IC.Builder.SetInsertPoint(MemCpy); auto *NewI = IC.Builder.CreateMemTransferInst( - MemCpy->getIntrinsicID(), MemCpy->getRawDest(), MemCpy->getDestAlign(), - SrcV, MemCpy->getSourceAlign(), MemCpy->getLength(), - MemCpy->isVolatile()); + MemCpy->getIntrinsicID(), DestV, MemCpy->getDestAlign(), SrcV, + MemCpy->getSourceAlign(), MemCpy->getLength(), MemCpy->isVolatile()); AAMDNodes AAMD = MemCpy->getAAMetadata(); if (AAMD) NewI->setAAMetadata(AAMD); @@ -432,16 +430,17 @@ void PointerReplacer::replace(Instruction *I) { assert(isEqualOrValidAddrSpaceCast( ASC, V->getType()->getPointerAddressSpace()) && "Invalid address space cast!"); - auto *NewV = V; + if (V->getType()->getPointerAddressSpace() != ASC->getType()->getPointerAddressSpace()) { auto *NewI = new AddrSpaceCastInst(V, ASC->getType(), ""); NewI->takeName(ASC); IC.InsertNewInstWith(NewI, ASC->getIterator()); - NewV = NewI; + WorkMap[ASC] = NewI; + } else { + WorkMap[ASC] = V; } - IC.replaceInstUsesWith(*ASC, NewV); - IC.eraseInstFromFunction(*ASC); + } else { llvm_unreachable("should never reach here"); } @@ -777,7 +776,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { auto *Zero = ConstantInt::get(IdxType, 0); Value *V = PoisonValue::get(T); - TypeSize Offset = TypeSize::get(0, ET->isScalableTy()); + TypeSize Offset = TypeSize::getZero(); for (uint64_t i = 0; i < NumElements; i++) { Value *Indices[2] = { Zero, @@ -1303,7 +1302,7 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { auto *IdxType = Type::getInt64Ty(T->getContext()); auto *Zero = ConstantInt::get(IdxType, 0); - TypeSize Offset = TypeSize::get(0, AT->getElementType()->isScalableTy()); + TypeSize Offset = TypeSize::getZero(); for (uint64_t i = 0; i < NumElements; i++) { Value *Indices[2] = { Zero, diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 6c3adf00c189..f4f3644acfe5 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -105,7 +105,7 @@ static Value *foldMulSelectToNegate(BinaryOperator &I, if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_One(), m_AllOnes())), m_Value(OtherOp)))) { bool HasAnyNoWrap = I.hasNoSignedWrap() || I.hasNoUnsignedWrap(); - Value *Neg = Builder.CreateNeg(OtherOp, "", false, HasAnyNoWrap); + Value *Neg = Builder.CreateNeg(OtherOp, "", HasAnyNoWrap); return Builder.CreateSelect(Cond, OtherOp, Neg); } // mul (select Cond, -1, 1), OtherOp --> select Cond, -OtherOp, OtherOp @@ -113,7 +113,7 @@ static Value *foldMulSelectToNegate(BinaryOperator &I, if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_AllOnes(), m_One())), m_Value(OtherOp)))) { bool HasAnyNoWrap = I.hasNoSignedWrap() || I.hasNoUnsignedWrap(); - Value *Neg = Builder.CreateNeg(OtherOp, "", false, HasAnyNoWrap); + Value *Neg = Builder.CreateNeg(OtherOp, "", HasAnyNoWrap); return Builder.CreateSelect(Cond, Neg, OtherOp); } @@ -166,7 +166,9 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands, if (match(Y, 
m_OneUse(m_Add(m_BinOp(Shift), m_One()))) && match(Shift, m_OneUse(m_Shl(m_One(), m_Value(Z))))) { bool PropagateNSW = HasNSW && Shift->hasNoSignedWrap(); - Value *FrX = Builder.CreateFreeze(X, X->getName() + ".fr"); + Value *FrX = X; + if (!isGuaranteedNotToBeUndef(X)) + FrX = Builder.CreateFreeze(X, X->getName() + ".fr"); Value *Shl = Builder.CreateShl(FrX, Z, "mulshl", HasNUW, PropagateNSW); return Builder.CreateAdd(Shl, FrX, Mul.getName(), HasNUW, PropagateNSW); } @@ -177,7 +179,9 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands, // This increases uses of X, so it may require a freeze, but that is still // expected to be an improvement because it removes the multiply. if (match(Y, m_OneUse(m_Not(m_OneUse(m_Shl(m_AllOnes(), m_Value(Z))))))) { - Value *FrX = Builder.CreateFreeze(X, X->getName() + ".fr"); + Value *FrX = X; + if (!isGuaranteedNotToBeUndef(X)) + FrX = Builder.CreateFreeze(X, X->getName() + ".fr"); Value *Shl = Builder.CreateShl(FrX, Z, "mulshl"); return Builder.CreateSub(Shl, FrX, Mul.getName()); } @@ -223,11 +227,13 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { Value *NewOp; Constant *C1, *C2; const APInt *IVal; - if (match(&I, m_Mul(m_Shl(m_Value(NewOp), m_Constant(C2)), - m_Constant(C1))) && + if (match(&I, m_Mul(m_Shl(m_Value(NewOp), m_ImmConstant(C2)), + m_ImmConstant(C1))) && match(C1, m_APInt(IVal))) { // ((X << C2)*C1) == (X * (C1 << C2)) - Constant *Shl = ConstantExpr::getShl(C1, C2); + Constant *Shl = + ConstantFoldBinaryOpOperands(Instruction::Shl, C1, C2, DL); + assert(Shl && "Constant folding of immediate constants failed"); BinaryOperator *Mul = cast<BinaryOperator>(I.getOperand(0)); BinaryOperator *BO = BinaryOperator::CreateMul(NewOp, Shl); if (HasNUW && Mul->hasNoUnsignedWrap()) @@ -276,7 +282,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { const APInt *NegPow2C; Value *X; if (match(Op0, m_ZExtOrSExt(m_Value(X))) && - match(Op1, m_APIntAllowUndef(NegPow2C))) { + match(Op1, m_APIntAllowPoison(NegPow2C))) { unsigned SrcWidth = X->getType()->getScalarSizeInBits(); unsigned ShiftAmt = NegPow2C->countr_zero(); if (ShiftAmt >= BitWidth - SrcWidth) { @@ -319,19 +325,12 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { } // abs(X) * abs(X) -> X * X - // nabs(X) * nabs(X) -> X * X - if (Op0 == Op1) { - Value *X, *Y; - SelectPatternFlavor SPF = matchSelectPattern(Op0, X, Y).Flavor; - if (SPF == SPF_ABS || SPF == SPF_NABS) - return BinaryOperator::CreateMul(X, X); - - if (match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(X)))) - return BinaryOperator::CreateMul(X, X); - } + Value *X; + if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(X)))) + return BinaryOperator::CreateMul(X, X); { - Value *X, *Y; + Value *Y; // abs(X) * abs(Y) -> abs(X * Y) if (I.hasNoSignedWrap() && match(Op0, @@ -344,7 +343,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { } // -X * C --> X * -C - Value *X, *Y; + Value *Y; Constant *Op1C; if (match(Op0, m_Neg(m_Value(X))) && match(Op1, m_Constant(Op1C))) return BinaryOperator::CreateMul(X, ConstantExpr::getNeg(Op1C)); @@ -370,6 +369,28 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { return BinaryOperator::CreateMul(NegOp0, X); } + if (Op0->hasOneUse()) { + // (mul (div exact X, C0), C1) + // -> (div exact X, C0 / C1) + // iff C0 % C1 == 0 and X / (C0 / C1) doesn't create UB. 
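In the exact-division case X is already known to be a multiple of C0, so multiplying the quotient by C1 is the same as dividing by the smaller factor C0 / C1. A standalone sketch with made-up constants (illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C0 = 12, C1 = 3; // C0 % C1 == 0, as the fold requires
  for (uint32_t k = 0; k < 1000; ++k) {
    uint32_t x = k * C0;          // "udiv exact" means X is a multiple of C0
    assert((x / C0) * C1 == x / (C0 / C1));
  }
  return 0;
}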
+ const APInt *C1; + auto UDivCheck = [&C1](const APInt &C) { return C.urem(*C1).isZero(); }; + auto SDivCheck = [&C1](const APInt &C) { + APInt Quot, Rem; + APInt::sdivrem(C, *C1, Quot, Rem); + return Rem.isZero() && !Quot.isAllOnes(); + }; + if (match(Op1, m_APInt(C1)) && + (match(Op0, m_Exact(m_UDiv(m_Value(X), m_CheckedInt(UDivCheck)))) || + match(Op0, m_Exact(m_SDiv(m_Value(X), m_CheckedInt(SDivCheck)))))) { + auto BOpc = cast<BinaryOperator>(Op0)->getOpcode(); + return BinaryOperator::CreateExact( + BOpc, X, + Builder.CreateBinOp(BOpc, cast<BinaryOperator>(Op0)->getOperand(1), + Op1)); + } + } + // (X / Y) * Y = X - (X % Y) // (X / Y) * -Y = (X % Y) - X { @@ -397,7 +418,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { auto RemOpc = Div->getOpcode() == Instruction::UDiv ? Instruction::URem : Instruction::SRem; // X must be frozen because we are increasing its number of uses. - Value *XFreeze = Builder.CreateFreeze(X, X->getName() + ".fr"); + Value *XFreeze = X; + if (!isGuaranteedNotToBeUndef(X)) + XFreeze = Builder.CreateFreeze(X, X->getName() + ".fr"); Value *Rem = Builder.CreateBinOp(RemOpc, XFreeze, DivOp1); if (DivOp1 == Y) return BinaryOperator::CreateSub(XFreeze, Rem); @@ -448,6 +471,13 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) return SelectInst::Create(X, Op0, ConstantInt::getNullValue(Ty)); + // mul (sext X), Y -> select X, -Y, 0 + // mul Y, (sext X) -> select X, -Y, 0 + if (match(&I, m_c_Mul(m_OneUse(m_SExt(m_Value(X))), m_Value(Y))) && + X->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(X, Builder.CreateNeg(Y, "", I.hasNoSignedWrap()), + ConstantInt::getNullValue(Op0->getType())); + Constant *ImmC; if (match(Op1, m_ImmConstant(ImmC))) { // (sext bool X) * C --> X ? 
-C : 0 @@ -485,7 +515,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { // ((ashr X, 31) | 1) * X --> abs(X) // X * ((ashr X, 31) | 1) --> abs(X) if (match(&I, m_c_BinOp(m_Or(m_AShr(m_Value(X), - m_SpecificIntAllowUndef(BitWidth - 1)), + m_SpecificIntAllowPoison(BitWidth - 1)), m_One()), m_Deferred(X)))) { Value *Abs = Builder.CreateBinaryIntrinsic( @@ -530,7 +560,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { I.setHasNoSignedWrap(true); } - if (!HasNUW && willNotOverflowUnsignedMul(Op0, Op1, I)) { + if (!HasNUW && willNotOverflowUnsignedMul(Op0, Op1, I, I.hasNoSignedWrap())) { Changed = true; I.setHasNoUnsignedWrap(true); } @@ -571,36 +601,114 @@ Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) { return nullptr; } +Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) { + auto createPowiExpr = [](BinaryOperator &I, InstCombinerImpl &IC, Value *X, + Value *Y, Value *Z) { + InstCombiner::BuilderTy &Builder = IC.Builder; + Value *YZ = Builder.CreateAdd(Y, Z); + Instruction *NewPow = Builder.CreateIntrinsic( + Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I); + + return NewPow; + }; + + Value *X, *Y, *Z; + unsigned Opcode = I.getOpcode(); + assert((Opcode == Instruction::FMul || Opcode == Instruction::FDiv) && + "Unexpected opcode"); + + // powi(X, Y) * X --> powi(X, Y+1) + // X * powi(X, Y) --> powi(X, Y+1) + if (match(&I, m_c_FMul(m_OneUse(m_AllowReassoc(m_Intrinsic<Intrinsic::powi>( + m_Value(X), m_Value(Y)))), + m_Deferred(X)))) { + Constant *One = ConstantInt::get(Y->getType(), 1); + if (willNotOverflowSignedAdd(Y, One, I)) { + Instruction *NewPow = createPowiExpr(I, *this, X, Y, One); + return replaceInstUsesWith(I, NewPow); + } + } + + // powi(x, y) * powi(x, z) -> powi(x, y + z) + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + if (Opcode == Instruction::FMul && I.isOnlyUserOfAnyOperand() && + match(Op0, m_AllowReassoc( + m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y)))) && + match(Op1, m_AllowReassoc(m_Intrinsic<Intrinsic::powi>(m_Specific(X), + m_Value(Z)))) && + Y->getType() == Z->getType()) { + Instruction *NewPow = createPowiExpr(I, *this, X, Y, Z); + return replaceInstUsesWith(I, NewPow); + } + + if (Opcode == Instruction::FDiv && I.hasAllowReassoc() && I.hasNoNaNs()) { + // powi(X, Y) / X --> powi(X, Y-1) + // This is legal when (Y - 1) can't wraparound, in which case reassoc and + // nnan are required. + // TODO: Multi-use may be also better off creating Powi(x,y-1) + if (match(Op0, m_OneUse(m_AllowReassoc(m_Intrinsic<Intrinsic::powi>( + m_Specific(Op1), m_Value(Y))))) && + willNotOverflowSignedSub(Y, ConstantInt::get(Y->getType(), 1), I)) { + Constant *NegOne = ConstantInt::getAllOnesValue(Y->getType()); + Instruction *NewPow = createPowiExpr(I, *this, Op1, Y, NegOne); + return replaceInstUsesWith(I, NewPow); + } + + // powi(X, Y) / (X * Z) --> powi(X, Y-1) / Z + // This is legal when (Y - 1) can't wraparound, in which case reassoc and + // nnan are required. 
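A standalone sketch of the powi reassociation identities that foldPowiReassoc targets. The powi helper below is an illustrative stand-in (repeated multiplication), and the inputs are chosen so the double results are exact; in IR the folds additionally require reassoc/nnan and the overflow checks shown above:

#include <cassert>

static double powi(double X, int N) {
  double R = 1.0;
  for (int I = 0; I < (N < 0 ? -N : N); ++I)
    R *= X;
  return N < 0 ? 1.0 / R : R;
}

int main() {
  double X = 2.0;
  assert(powi(X, 3) * X == powi(X, 4));          // powi(X, Y) * X --> powi(X, Y+1)
  assert(powi(X, 3) * powi(X, 5) == powi(X, 8)); // powi(X, Y) * powi(X, Z) --> powi(X, Y+Z)
  assert(powi(X, 3) / X == powi(X, 2));          // powi(X, Y) / X --> powi(X, Y-1)
  return 0;
}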
+ // TODO: Multi-use may be also better off creating Powi(x,y-1) + if (match(Op0, m_OneUse(m_AllowReassoc(m_Intrinsic<Intrinsic::powi>( + m_Value(X), m_Value(Y))))) && + match(Op1, m_AllowReassoc(m_c_FMul(m_Specific(X), m_Value(Z)))) && + willNotOverflowSignedSub(Y, ConstantInt::get(Y->getType(), 1), I)) { + Constant *NegOne = ConstantInt::getAllOnesValue(Y->getType()); + auto *NewPow = createPowiExpr(I, *this, X, Y, NegOne); + return BinaryOperator::CreateFDivFMF(NewPow, Z, &I); + } + } + + return nullptr; +} + Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) { Value *Op0 = I.getOperand(0); Value *Op1 = I.getOperand(1); Value *X, *Y; Constant *C; + BinaryOperator *Op0BinOp; // Reassociate constant RHS with another constant to form constant // expression. - if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) { + if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP() && + match(Op0, m_AllowReassoc(m_BinOp(Op0BinOp)))) { + // Everything in this scope folds I with Op0, intersecting their FMF. + FastMathFlags FMF = I.getFastMathFlags() & Op0BinOp->getFastMathFlags(); + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(FMF); Constant *C1; if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) { // (C1 / X) * C --> (C * C1) / X Constant *CC1 = ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL); if (CC1 && CC1->isNormalFP()) - return BinaryOperator::CreateFDivFMF(CC1, X, &I); + return BinaryOperator::CreateFDivFMF(CC1, X, FMF); } if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) { + // FIXME: This seems like it should also be checking for arcp // (X / C1) * C --> X * (C / C1) Constant *CDivC1 = ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C1, DL); if (CDivC1 && CDivC1->isNormalFP()) - return BinaryOperator::CreateFMulFMF(X, CDivC1, &I); + return BinaryOperator::CreateFMulFMF(X, CDivC1, FMF); // If the constant was a denormal, try reassociating differently. // (X / C1) * C --> X / (C1 / C) Constant *C1DivC = ConstantFoldBinaryOpOperands(Instruction::FDiv, C1, C, DL); if (C1DivC && Op0->hasOneUse() && C1DivC->isNormalFP()) - return BinaryOperator::CreateFDivFMF(X, C1DivC, &I); + return BinaryOperator::CreateFDivFMF(X, C1DivC, FMF); } // We do not need to match 'fadd C, X' and 'fsub X, C' because they are @@ -610,26 +718,33 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) { // (X + C1) * C --> (X * C) + (C * C1) if (Constant *CC1 = ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) { - Value *XC = Builder.CreateFMulFMF(X, C, &I); - return BinaryOperator::CreateFAddFMF(XC, CC1, &I); + Value *XC = Builder.CreateFMul(X, C); + return BinaryOperator::CreateFAddFMF(XC, CC1, FMF); } } if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) { // (C1 - X) * C --> (C * C1) - (X * C) if (Constant *CC1 = ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) { - Value *XC = Builder.CreateFMulFMF(X, C, &I); - return BinaryOperator::CreateFSubFMF(CC1, XC, &I); + Value *XC = Builder.CreateFMul(X, C); + return BinaryOperator::CreateFSubFMF(CC1, XC, FMF); } } } Value *Z; if (match(&I, - m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))), m_Value(Z)))) { - // Sink division: (X / Y) * Z --> (X * Z) / Y - Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I); - return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I); + m_c_FMul(m_AllowReassoc(m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))), + m_Value(Z)))) { + BinaryOperator *DivOp = cast<BinaryOperator>(((Z == Op0) ? 
Op1 : Op0)); + FastMathFlags FMF = I.getFastMathFlags() & DivOp->getFastMathFlags(); + if (FMF.allowReassoc()) { + // Sink division: (X / Y) * Z --> (X * Z) / Y + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(FMF); + auto *NewFMul = Builder.CreateFMul(X, Z); + return BinaryOperator::CreateFDivFMF(NewFMul, Y, FMF); + } } // sqrt(X) * sqrt(Y) -> sqrt(X * Y) @@ -683,6 +798,9 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) { return replaceInstUsesWith(I, Pow); } + if (Instruction *FoldedPowi = foldPowiReassoc(I)) + return FoldedPowi; + if (I.isOnlyUserOfAnyOperand()) { // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z) if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) && @@ -699,16 +817,6 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) { return replaceInstUsesWith(I, NewPow); } - // powi(x, y) * powi(x, z) -> powi(x, y + z) - if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) && - match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) && - Y->getType() == Z->getType()) { - auto *YZ = Builder.CreateAdd(Y, Z); - auto *NewPow = Builder.CreateIntrinsic( - Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I); - return replaceInstUsesWith(I, NewPow); - } - // exp(X) * exp(Y) -> exp(X + Y) if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) && match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) { @@ -769,13 +877,24 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { if (Instruction *R = foldFPSignBitOps(I)) return R; + if (Instruction *R = foldFBinOpOfIntCasts(I)) + return R; + // X * -1.0 --> -X Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (match(Op1, m_SpecificFP(-1.0))) return UnaryOperator::CreateFNegFMF(Op0, &I); - // With no-nans: X * 0.0 --> copysign(0.0, X) - if (I.hasNoNaNs() && match(Op1, m_PosZeroFP())) { + // With no-nans/no-infs: + // X * 0.0 --> copysign(0.0, X) + // X * -0.0 --> copysign(0.0, -X) + const APFloat *FPC; + if (match(Op1, m_APFloatAllowPoison(FPC)) && FPC->isZero() && + ((I.hasNoInfs() && + isKnownNeverNaN(Op0, /*Depth=*/0, SQ.getWithInstruction(&I))) || + isKnownNeverNaN(&I, /*Depth=*/0, SQ.getWithInstruction(&I)))) { + if (FPC->isNegative()) + Op0 = Builder.CreateFNegFMF(Op0, &I); CallInst *CopySign = Builder.CreateIntrinsic(Intrinsic::copysign, {I.getType()}, {Op1, Op0}, &I); return replaceInstUsesWith(I, CopySign); @@ -788,6 +907,24 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) return BinaryOperator::CreateFMulFMF(X, NegC, &I); + if (I.hasNoNaNs() && I.hasNoSignedZeros()) { + // (uitofp bool X) * Y --> X ? Y : 0 + // Y * (uitofp bool X) --> X ? Y : 0 + // Note INF * 0 is NaN. 
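The X * 0.0 --> copysign(0.0, X) rewrite in the visitFMul hunk above hinges on the sign of a zero product matching the sign of X when X is neither NaN nor infinity, which is why the fold is gated on those facts. A small standalone check of that behaviour (values are illustrative):

#include <cassert>
#include <cmath>

int main() {
  // For finite, non-NaN X, X * 0.0 is a zero carrying the sign of X,
  // which is exactly copysign(0.0, X).
  for (double X : {3.5, -3.5, 0.0, -0.0}) {
    double Prod = X * 0.0;
    double CS = std::copysign(0.0, X);
    assert(Prod == CS && std::signbit(Prod) == std::signbit(CS));
  }
  return 0;
}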
+ if (match(Op0, m_UIToFP(m_Value(X))) && + X->getType()->isIntOrIntVectorTy(1)) { + auto *SI = SelectInst::Create(X, Op1, ConstantFP::get(I.getType(), 0.0)); + SI->copyFastMathFlags(I.getFastMathFlags()); + return SI; + } + if (match(Op1, m_UIToFP(m_Value(X))) && + X->getType()->isIntOrIntVectorTy(1)) { + auto *SI = SelectInst::Create(X, Op0, ConstantFP::get(I.getType(), 0.0)); + SI->copyFastMathFlags(I.getFastMathFlags()); + return SI; + } + } + // (select A, B, C) * (select A, D, E) --> select A, (B*D), (C*E) if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) return replaceInstUsesWith(I, V); @@ -1120,14 +1257,14 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { // We need a multiple of the divisor for a signed add constant, but // unsigned is fine with any constant pair. if (IsSigned && - match(Op0, m_NSWAdd(m_NSWMul(m_Value(X), m_SpecificInt(*C2)), - m_APInt(C1))) && + match(Op0, m_NSWAddLike(m_NSWMul(m_Value(X), m_SpecificInt(*C2)), + m_APInt(C1))) && isMultiple(*C1, *C2, Quotient, IsSigned)) { return BinaryOperator::CreateNSWAdd(X, ConstantInt::get(Ty, Quotient)); } if (!IsSigned && - match(Op0, m_NUWAdd(m_NUWMul(m_Value(X), m_SpecificInt(*C2)), - m_APInt(C1)))) { + match(Op0, m_NUWAddLike(m_NUWMul(m_Value(X), m_SpecificInt(*C2)), + m_APInt(C1)))) { return BinaryOperator::CreateNUWAdd(X, ConstantInt::get(Ty, C1->udiv(*C2))); } @@ -1143,7 +1280,9 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { // 1 / 0 --> undef ; 1 / 1 --> 1 ; 1 / -1 --> -1 ; 1 / anything else --> 0 // (Op1 + 1) u< 3 ? Op1 : 0 // Op1 must be frozen because we are increasing its number of uses. - Value *F1 = Builder.CreateFreeze(Op1, Op1->getName() + ".fr"); + Value *F1 = Op1; + if (!isGuaranteedNotToBeUndef(Op1)) + F1 = Builder.CreateFreeze(Op1, Op1->getName() + ".fr"); Value *Inc = Builder.CreateAdd(F1, Op0); Value *Cmp = Builder.CreateICmpULT(Inc, ConstantInt::get(Ty, 3)); return SelectInst::Create(Cmp, F1, ConstantInt::get(Ty, 0)); @@ -1299,9 +1438,6 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, } // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y) - // FIXME: missed optimization: if one of the hands of select is/contains - // undef, just directly pick the other one. - // FIXME: can both hands contain undef? // FIXME: Require one use? if (SelectInst *SI = dyn_cast<SelectInst>(Op)) if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, @@ -1513,8 +1649,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { // -X / C --> X / -C (if the negation doesn't overflow). // TODO: This could be enhanced to handle arbitrary vector constants by // checking if all elements are not the min-signed-val. 
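A standalone sketch of the (uitofp i1 X) * Y --> X ? Y : 0.0 rewrite added above. The helper names are made up, and the Y values avoid NaN and -0.0, which is exactly why the fold is gated on nnan and nsz:

#include <cassert>

static double mulForm(bool X, double Y) { return (X ? 1.0 : 0.0) * Y; }
static double selForm(bool X, double Y) { return X ? Y : 0.0; }

int main() {
  for (bool X : {false, true})
    for (double Y : {0.0, 2.5, 7.25})
      assert(mulForm(X, Y) == selForm(X, Y));
  return 0;
}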
- if (!Op1C->isMinSignedValue() && - match(Op0, m_NSWSub(m_Zero(), m_Value(X)))) { + if (!Op1C->isMinSignedValue() && match(Op0, m_NSWNeg(m_Value(X)))) { Constant *NegC = ConstantInt::get(Ty, -(*Op1C)); Instruction *BO = BinaryOperator::CreateSDiv(X, NegC); BO->setIsExact(I.isExact()); @@ -1524,7 +1659,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { // -X / Y --> -(X / Y) Value *Y; - if (match(&I, m_SDiv(m_OneUse(m_NSWSub(m_Zero(), m_Value(X))), m_Value(Y)))) + if (match(&I, m_SDiv(m_OneUse(m_NSWNeg(m_Value(X))), m_Value(Y)))) return BinaryOperator::CreateNSWNeg( Builder.CreateSDiv(X, Y, I.getName(), I.isExact())); @@ -1592,15 +1727,17 @@ Instruction *InstCombinerImpl::foldFDivConstantDivisor(BinaryOperator &I) { // -X / C --> X / -C Value *X; - const DataLayout &DL = I.getModule()->getDataLayout(); + const DataLayout &DL = I.getDataLayout(); if (match(I.getOperand(0), m_FNeg(m_Value(X)))) if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) return BinaryOperator::CreateFDivFMF(X, NegC, &I); // nnan X / +0.0 -> copysign(inf, X) - if (I.hasNoNaNs() && match(I.getOperand(1), m_Zero())) { + // nnan nsz X / -0.0 -> copysign(inf, X) + if (I.hasNoNaNs() && + (match(I.getOperand(1), m_PosZeroFP()) || + (I.hasNoSignedZeros() && match(I.getOperand(1), m_AnyZeroFP())))) { IRBuilder<> B(&I); - // TODO: nnan nsz X / -0.0 -> copysign(inf, X) CallInst *CopySign = B.CreateIntrinsic( Intrinsic::copysign, {C->getType()}, {ConstantFP::getInfinity(I.getType()), I.getOperand(0)}, &I); @@ -1635,7 +1772,7 @@ static Instruction *foldFDivConstantDividend(BinaryOperator &I) { // C / -X --> -C / X Value *X; - const DataLayout &DL = I.getModule()->getDataLayout(); + const DataLayout &DL = I.getDataLayout(); if (match(I.getOperand(1), m_FNeg(m_Value(X)))) if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) return BinaryOperator::CreateFDivFMF(NegC, X, &I); @@ -1707,6 +1844,34 @@ static Instruction *foldFDivPowDivisor(BinaryOperator &I, return BinaryOperator::CreateFMulFMF(Op0, Pow, &I); } +/// Convert div to mul if we have an sqrt divisor iff sqrt's operand is a fdiv +/// instruction. 
+static Instruction *foldFDivSqrtDivisor(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + // X / sqrt(Y / Z) --> X * sqrt(Z / Y) + if (!I.hasAllowReassoc() || !I.hasAllowReciprocal()) + return nullptr; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + auto *II = dyn_cast<IntrinsicInst>(Op1); + if (!II || II->getIntrinsicID() != Intrinsic::sqrt || !II->hasOneUse() || + !II->hasAllowReassoc() || !II->hasAllowReciprocal()) + return nullptr; + + Value *Y, *Z; + auto *DivOp = dyn_cast<Instruction>(II->getOperand(0)); + if (!DivOp) + return nullptr; + if (!match(DivOp, m_FDiv(m_Value(Y), m_Value(Z)))) + return nullptr; + if (!DivOp->hasAllowReassoc() || !I.hasAllowReciprocal() || + !DivOp->hasOneUse()) + return nullptr; + Value *SwapDiv = Builder.CreateFDivFMF(Z, Y, DivOp); + Value *NewSqrt = + Builder.CreateUnaryIntrinsic(II->getIntrinsicID(), SwapDiv, II); + return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I); +} + Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) { Module *M = I.getModule(); @@ -1814,6 +1979,9 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) { if (Instruction *Mul = foldFDivPowDivisor(I, Builder)) return Mul; + if (Instruction *Mul = foldFDivSqrtDivisor(I, Builder)) + return Mul; + // pow(X, Y) / X --> pow(X, Y-1) if (I.hasAllowReassoc() && match(Op0, m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Specific(Op1), @@ -1824,20 +1992,8 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) { return replaceInstUsesWith(I, Pow); } - // powi(X, Y) / X --> powi(X, Y-1) - // This is legal when (Y - 1) can't wraparound, in which case reassoc and nnan - // are required. - // TODO: Multi-use may be also better off creating Powi(x,y-1) - if (I.hasAllowReassoc() && I.hasNoNaNs() && - match(Op0, m_OneUse(m_Intrinsic<Intrinsic::powi>(m_Specific(Op1), - m_Value(Y)))) && - willNotOverflowSignedSub(Y, ConstantInt::get(Y->getType(), 1), I)) { - Constant *NegOne = ConstantInt::getAllOnesValue(Y->getType()); - Value *Y1 = Builder.CreateAdd(Y, NegOne); - Type *Types[] = {Op1->getType(), Y1->getType()}; - Value *Pow = Builder.CreateIntrinsic(Intrinsic::powi, Types, {Op1, Y1}, &I); - return replaceInstUsesWith(I, Pow); - } + if (Instruction *FoldedPowi = foldPowiReassoc(I)) + return FoldedPowi; return nullptr; } @@ -2039,7 +2195,9 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) { // Op0 urem C -> Op0 < C ? Op0 : Op0 - C, where C >= signbit. // Op0 must be frozen because we are increasing its number of uses. if (match(Op1, m_Negative())) { - Value *F0 = Builder.CreateFreeze(Op0, Op0->getName() + ".fr"); + Value *F0 = Op0; + if (!isGuaranteedNotToBeUndef(Op0)) + F0 = Builder.CreateFreeze(Op0, Op0->getName() + ".fr"); Value *Cmp = Builder.CreateICmpULT(F0, Op1); Value *Sub = Builder.CreateSub(F0, Op1); return SelectInst::Create(Cmp, F0, Sub); @@ -2051,7 +2209,9 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) { // urem Op0, (sext i1 X) --> (Op0 == -1) ? 
0 : Op0 Value *X; if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { - Value *FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen"); + Value *FrozenOp0 = Op0; + if (!isGuaranteedNotToBeUndef(Op0)) + FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen"); Value *Cmp = Builder.CreateICmpEQ(FrozenOp0, ConstantInt::getAllOnesValue(Ty)); return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), FrozenOp0); @@ -2062,7 +2222,9 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) { Value *Val = simplifyICmpInst(ICmpInst::ICMP_ULT, X, Op1, SQ.getWithInstruction(&I)); if (Val && match(Val, m_One())) { - Value *FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen"); + Value *FrozenOp0 = Op0; + if (!isGuaranteedNotToBeUndef(Op0)) + FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen"); Value *Cmp = Builder.CreateICmpEQ(FrozenOp0, Op1); return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), FrozenOp0); } @@ -2093,7 +2255,7 @@ Instruction *InstCombinerImpl::visitSRem(BinaryOperator &I) { // -X srem Y --> -(X srem Y) Value *X, *Y; - if (match(&I, m_SRem(m_OneUse(m_NSWSub(m_Zero(), m_Value(X))), m_Value(Y)))) + if (match(&I, m_SRem(m_OneUse(m_NSWNeg(m_Value(X))), m_Value(Y)))) return BinaryOperator::CreateNSWNeg(Builder.CreateSRem(X, Y)); // If the sign bits of both operands are zero (i.e. we can prove they are diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp index 62e49469cb01..cb052da79bb3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp @@ -140,7 +140,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { // Integral constants can be freely negated. if (match(V, m_AnyIntegralConstant())) - return ConstantExpr::getNeg(cast<Constant>(V), /*HasNUW=*/false, + return ConstantExpr::getNeg(cast<Constant>(V), /*HasNSW=*/false); // If we have a non-instruction, then give up. @@ -222,6 +222,11 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { } break; } + case Instruction::Call: + if (auto *CI = dyn_cast<CmpIntrinsic>(I); CI && CI->hasOneUse()) + return Builder.CreateIntrinsic(CI->getType(), CI->getIntrinsicID(), + {CI->getRHS(), CI->getLHS()}); + break; default: break; // Other instructions require recursive reasoning. 
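The new Instruction::Call case in the Negator above negates a ucmp/scmp intrinsic by swapping its operands. A standalone sketch of the underlying three-way-compare identity (scmp here is an illustrative stand-in for the intrinsic, not LLVM API):

#include <cassert>

static int scmp(int A, int B) { return (A > B) - (A < B); } // returns -1, 0 or 1

int main() {
  const int Vals[] = {-3, 0, 0, 7};
  for (int A : Vals)
    for (int B : Vals)
      assert(-scmp(A, B) == scmp(B, A));
  return 0;
}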
} @@ -249,7 +254,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { unsigned SrcWidth = SrcOp->getType()->getScalarSizeInBits(); const APInt &FullShift = APInt(SrcWidth, SrcWidth - 1); if (IsTrulyNegation && - match(SrcOp, m_LShr(m_Value(X), m_SpecificIntAllowUndef(FullShift)))) { + match(SrcOp, m_LShr(m_Value(X), m_SpecificIntAllowPoison(FullShift)))) { Value *Ashr = Builder.CreateAShr(X, FullShift); return Builder.CreateSExt(Ashr, I->getType()); } @@ -258,9 +263,9 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { case Instruction::And: { Constant *ShAmt; // sub(y,and(lshr(x,C),1)) --> add(ashr(shl(x,(BW-1)-C),BW-1),y) - if (match(I, m_c_And(m_OneUse(m_TruncOrSelf( - m_LShr(m_Value(X), m_ImmConstant(ShAmt)))), - m_One()))) { + if (match(I, m_And(m_OneUse(m_TruncOrSelf( + m_LShr(m_Value(X), m_ImmConstant(ShAmt)))), + m_One()))) { unsigned BW = X->getType()->getScalarSizeInBits(); Constant *BWMinusOne = ConstantInt::get(X->getType(), BW - 1); Value *R = Builder.CreateShl(X, Builder.CreateSub(BWMinusOne, ShAmt)); @@ -320,7 +325,8 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { return NegatedPHI; } case Instruction::Select: { - if (isKnownNegation(I->getOperand(1), I->getOperand(2))) { + if (isKnownNegation(I->getOperand(1), I->getOperand(2), /*NeedNSW=*/false, + /*AllowPoison=*/false)) { // Of one hand of select is known to be negation of another hand, // just swap the hands around. auto *NewSelect = cast<SelectInst>(I->clone()); @@ -328,6 +334,17 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { NewSelect->swapValues(); // Don't swap prof metadata, we didn't change the branch behavior. NewSelect->setName(I->getName() + ".neg"); + // Poison-generating flags should be dropped + Value *TV = NewSelect->getTrueValue(); + Value *FV = NewSelect->getFalseValue(); + if (match(TV, m_Neg(m_Specific(FV)))) + cast<Instruction>(TV)->dropPoisonGeneratingFlags(); + else if (match(FV, m_Neg(m_Specific(TV)))) + cast<Instruction>(FV)->dropPoisonGeneratingFlags(); + else { + cast<Instruction>(TV)->dropPoisonGeneratingFlags(); + cast<Instruction>(FV)->dropPoisonGeneratingFlags(); + } Builder.Insert(NewSelect); return NewSelect; } @@ -390,12 +407,12 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg", /* HasNUW */ false, IsNSW); // Otherwise, `shl %x, C` can be interpreted as `mul %x, 1<<C`. 
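For the `shl %x, C` case noted above, negation treats the shift as a multiply by 1<<C, so the negated value is a multiply by -(1<<C), i.e. by an all-ones value shifted left by C. A standalone wrapping-arithmetic check (constants are illustrative):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C = 3;
  for (uint32_t X : {0u, 1u, 0x12345678u, 0xFFFFFFFFu})
    assert(0u - (X << C) == X * (~0u << C));
  return 0;
}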
- auto *Op1C = dyn_cast<Constant>(I->getOperand(1)); - if (!Op1C || !IsTrulyNegation) + Constant *Op1C; + if (!match(I->getOperand(1), m_ImmConstant(Op1C)) || !IsTrulyNegation) return nullptr; return Builder.CreateMul( I->getOperand(0), - ConstantExpr::getShl(Constant::getAllOnesValue(Op1C->getType()), Op1C), + Builder.CreateShl(Constant::getAllOnesValue(Op1C->getType()), Op1C), I->getName() + ".neg", /* HasNUW */ false, IsNSW); } case Instruction::Or: { diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 20b34c1379d5..b05a33c68889 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -331,7 +331,7 @@ Instruction * InstCombinerImpl::foldPHIArgInsertValueInstructionIntoPHI(PHINode &PN) { auto *FirstIVI = cast<InsertValueInst>(PN.getIncomingValue(0)); - // Scan to see if all operands are `insertvalue`'s with the same indicies, + // Scan to see if all operands are `insertvalue`'s with the same indices, // and all have a single use. for (Value *V : drop_begin(PN.incoming_values())) { auto *I = dyn_cast<InsertValueInst>(V); @@ -371,7 +371,7 @@ Instruction * InstCombinerImpl::foldPHIArgExtractValueInstructionIntoPHI(PHINode &PN) { auto *FirstEVI = cast<ExtractValueInst>(PN.getIncomingValue(0)); - // Scan to see if all operands are `extractvalue`'s with the same indicies, + // Scan to see if all operands are `extractvalue`'s with the same indices, // and all have a single use. for (Value *V : drop_begin(PN.incoming_values())) { auto *I = dyn_cast<ExtractValueInst>(V); @@ -513,7 +513,8 @@ Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) { // especially bad when the PHIs are in the header of a loop. bool NeededPhi = false; - bool AllInBounds = true; + // Remember flags of the first phi-operand getelementptr. + GEPNoWrapFlags NW = FirstInst->getNoWrapFlags(); // Scan to see if all operands are the same opcode, and all have one user. for (Value *V : drop_begin(PN.incoming_values())) { @@ -523,7 +524,7 @@ Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) { GEP->getNumOperands() != FirstInst->getNumOperands()) return nullptr; - AllInBounds &= GEP->isInBounds(); + NW &= GEP->getNoWrapFlags(); // Keep track of whether or not all GEPs are of alloca pointers. if (AllBasePointersAreAllocas && @@ -605,8 +606,7 @@ Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) { Value *Base = FixedOperands[0]; GetElementPtrInst *NewGEP = GetElementPtrInst::Create(FirstInst->getSourceElementType(), Base, - ArrayRef(FixedOperands).slice(1)); - if (AllInBounds) NewGEP->setIsInBounds(); + ArrayRef(FixedOperands).slice(1), NW); PHIArgMergedDebugLoc(NewGEP, PN); return NewGEP; } @@ -1205,7 +1205,8 @@ Instruction *InstCombinerImpl::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // Otherwise, Create the new PHI node for this user. 
EltPHI = PHINode::Create(Ty, PN->getNumIncomingValues(), - PN->getName()+".off"+Twine(Offset), PN); + PN->getName() + ".off" + Twine(Offset), + PN->getIterator()); assert(EltPHI->getType() != PN->getType() && "Truncate didn't shrink phi?"); @@ -1378,6 +1379,58 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN, return nullptr; } +// Fold iv = phi(start, iv.next = iv2.next op start) +// where iv2 = phi(iv2.start, iv2.next = iv2 + iv2.step) +// and iv2.start op start = start +// to iv = iv2 op start +static Value *foldDependentIVs(PHINode &PN, IRBuilderBase &Builder) { + BasicBlock *BB = PN.getParent(); + if (PN.getNumIncomingValues() != 2) + return nullptr; + + Value *Start; + Instruction *IvNext; + BinaryOperator *Iv2Next; + auto MatchOuterIV = [&](Value *V1, Value *V2) { + if (match(V2, m_c_BinOp(m_Specific(V1), m_BinOp(Iv2Next))) || + match(V2, m_GEP(m_Specific(V1), m_BinOp(Iv2Next)))) { + Start = V1; + IvNext = cast<Instruction>(V2); + return true; + } + return false; + }; + + if (!MatchOuterIV(PN.getIncomingValue(0), PN.getIncomingValue(1)) && + !MatchOuterIV(PN.getIncomingValue(1), PN.getIncomingValue(0))) + return nullptr; + + PHINode *Iv2; + Value *Iv2Start, *Iv2Step; + if (!matchSimpleRecurrence(Iv2Next, Iv2, Iv2Start, Iv2Step) || + Iv2->getParent() != BB) + return nullptr; + + auto *BO = dyn_cast<BinaryOperator>(IvNext); + Constant *Identity = + BO ? ConstantExpr::getBinOpIdentity(BO->getOpcode(), Iv2Start->getType()) + : Constant::getNullValue(Iv2Start->getType()); + if (Iv2Start != Identity) + return nullptr; + + Builder.SetInsertPoint(&*BB, BB->getFirstInsertionPt()); + if (!BO) { + auto *GEP = cast<GEPOperator>(IvNext); + return Builder.CreateGEP(GEP->getSourceElementType(), Start, Iv2, "", + cast<GEPOperator>(IvNext)->getNoWrapFlags()); + } + + assert(BO->isCommutative() && "Must be commutative"); + Value *Res = Builder.CreateBinOp(BO->getOpcode(), Iv2, Start); + cast<Instruction>(Res)->copyIRFlags(BO); + return Res; +} + // PHINode simplification // Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { @@ -1484,7 +1537,7 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) { Instruction *CtxI = PN.getIncomingBlock(I)->getTerminator(); Value *VA = PN.getIncomingValue(I); - if (isKnownNonZero(VA, DL, 0, &AC, CtxI, &DT)) { + if (isKnownNonZero(VA, getSimplifyQuery().getWithInstruction(CtxI))) { if (!NonZeroConst) NonZeroConst = getAnyNonZeroConstInt(PN); if (NonZeroConst != VA) { @@ -1595,5 +1648,8 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { if (auto *V = simplifyUsingControlFlow(*this, PN, DT)) return replaceInstUsesWith(PN, V); + if (Value *Res = foldDependentIVs(PN, Builder)) + return replaceInstUsesWith(PN, Res); + return nullptr; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 9f220ec003ec..aaf4ece3249a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -99,7 +99,8 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, // transform. Bail out if we can not exclude that possibility. 
if (isa<FPMathOperator>(BO)) if (!BO->hasNoSignedZeros() && - !cannotBeNegativeZero(Y, IC.getDataLayout(), &TLI)) + !cannotBeNegativeZero(Y, 0, + IC.getSimplifyQuery().getWithInstruction(&Sel))) return nullptr; // BO = binop Y, X @@ -201,6 +202,14 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, const APInt &ValC = !TC.isZero() ? TC : FC; unsigned ValZeros = ValC.logBase2(); unsigned AndZeros = AndMask.logBase2(); + bool ShouldNotVal = !TC.isZero(); + ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; + + // If we would need to create an 'and' + 'shift' + 'xor' to replace a 'select' + // + 'icmp', then this transformation would result in more instructions and + // potentially interfere with other folding. + if (CreateAnd && ShouldNotVal && ValZeros != AndZeros) + return nullptr; // Insert the 'and' instruction on the input to the truncate. if (CreateAnd) @@ -220,8 +229,6 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, // Okay, now we know that everything is set up, we just don't know whether we // have a icmp_ne or icmp_eq and whether the true or false val is the zero. - bool ShouldNotVal = !TC.isZero(); - ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; if (ShouldNotVal) V = Builder.CreateXor(V, ValC); @@ -484,10 +491,9 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, } if (auto *TGEP = dyn_cast<GetElementPtrInst>(TI)) { auto *FGEP = cast<GetElementPtrInst>(FI); - Type *ElementType = TGEP->getResultElementType(); - return TGEP->isInBounds() && FGEP->isInBounds() - ? GetElementPtrInst::CreateInBounds(ElementType, Op0, {Op1}) - : GetElementPtrInst::Create(ElementType, Op0, {Op1}); + Type *ElementType = TGEP->getSourceElementType(); + return GetElementPtrInst::Create( + ElementType, Op0, Op1, TGEP->getNoWrapFlags() & FGEP->getNoWrapFlags()); } llvm_unreachable("Expected BinaryOperator or GEP"); return nullptr; @@ -535,19 +541,29 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, // between 0, 1 and -1. const APInt *OOpC; bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); - if (!isa<Constant>(OOp) || - (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { - Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp, - Swapped ? OOp : C, "", &SI); - if (isa<FPMathOperator>(&SI)) - cast<Instruction>(NewSel)->setFastMathFlags(FMF); - NewSel->takeName(TVI); - BinaryOperator *BO = - BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); - BO->copyIRFlags(TVI); - return BO; - } - return nullptr; + if (isa<Constant>(OOp) && + (!OOpIsAPInt || !isSelect01(C->getUniqueInteger(), *OOpC))) + return nullptr; + + // If the false value is a NaN then we have that the floating point math + // operation in the transformed code may not preserve the exact NaN + // bit-pattern -- e.g. `fadd sNaN, 0.0 -> qNaN`. + // This makes the transformation incorrect since the original program would + // have preserved the exact NaN bit-pattern. + // Avoid the folding if the false value might be a NaN. + if (isa<FPMathOperator>(&SI) && + !computeKnownFPClass(FalseVal, FMF, fcNan, &SI).isKnownNeverNaN()) + return nullptr; + + Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp, + Swapped ? 
OOp : C, "", &SI); + if (isa<FPMathOperator>(&SI)) + cast<Instruction>(NewSel)->setFastMathFlags(FMF); + NewSel->takeName(TVI); + BinaryOperator *BO = + BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); + BO->copyIRFlags(TVI); + return BO; }; if (Instruction *R = TryFoldSelectIntoOp(SI, TrueVal, FalseVal, false)) @@ -1116,7 +1132,7 @@ static Instruction *foldSelectCtlzToCttz(ICmpInst *ICI, Value *TrueVal, /// into: /// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false) static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, - InstCombiner::BuilderTy &Builder) { + InstCombinerImpl &IC) { ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); @@ -1158,6 +1174,9 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, // Explicitly clear the 'is_zero_poison' flag. It's always valid to go from // true to false on this flag, so we can replace it for all users. II->setArgOperand(1, ConstantInt::getFalse(II->getContext())); + // A range annotation on the intrinsic may no longer be valid. + II->dropPoisonGeneratingAnnotations(); + IC.addToWorklist(II); return SelectArg; } @@ -1190,7 +1209,7 @@ static Value *canonicalizeSPF(ICmpInst &Cmp, Value *TrueVal, Value *FalseVal, match(RHS, m_NSWNeg(m_Specific(LHS))); Constant *IntMinIsPoisonC = ConstantInt::get(Type::getInt1Ty(Cmp.getContext()), IntMinIsPoison); - Instruction *Abs = + Value *Abs = IC.Builder.CreateBinaryIntrinsic(Intrinsic::abs, LHS, IntMinIsPoisonC); if (SPF == SelectPatternFlavor::SPF_NABS) @@ -1228,8 +1247,11 @@ bool InstCombinerImpl::replaceInInstruction(Value *V, Value *Old, Value *New, if (Depth == 2) return false; + assert(!isa<Constant>(Old) && "Only replace non-constant values"); + auto *I = dyn_cast<Instruction>(V); - if (!I || !I->hasOneUse() || !isSafeToSpeculativelyExecute(I)) + if (!I || !I->hasOneUse() || + !isSafeToSpeculativelyExecuteWithVariableReplaced(I)) return false; bool Changed = false; @@ -1274,22 +1296,36 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, Swapped = true; } - // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand. - // Make sure Y cannot be undef though, as we might pick different values for - // undef in the icmp and in f(Y). Additionally, take care to avoid replacing - // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite - // replacement cycle. Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1); - if (TrueVal != CmpLHS && - isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) { - if (Value *V = simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ, - /* AllowRefinement */ true)) - // Require either the replacement or the simplification result to be a - // constant to avoid infinite loops. - // FIXME: Make this check more precise. - if (isa<Constant>(CmpRHS) || isa<Constant>(V)) + auto ReplaceOldOpWithNewOp = [&](Value *OldOp, + Value *NewOp) -> Instruction * { + // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand. + // Take care to avoid replacing X == Y ? X : Z with X == Y ? Y : Z, as that + // would lead to an infinite replacement cycle. + // If we will be able to evaluate f(Y) to a constant, we can allow undef, + // otherwise Y cannot be undef as we might pick different values for undef + // in the icmp and in f(Y). 
+ if (TrueVal == OldOp) + return nullptr; + + if (Value *V = simplifyWithOpReplaced(TrueVal, OldOp, NewOp, SQ, + /* AllowRefinement=*/true)) { + // Need some guarantees about the new simplified op to ensure we don't inf + // loop. + // If we simplify to a constant, replace if we aren't creating new undef. + if (match(V, m_ImmConstant()) && + isGuaranteedNotToBeUndef(V, SQ.AC, &Sel, &DT)) return replaceOperand(Sel, Swapped ? 2 : 1, V); + // If NewOp is a constant and OldOp is not replace iff NewOp doesn't + // contain and undef elements. + if (match(NewOp, m_ImmConstant()) || NewOp == V) { + if (isGuaranteedNotToBeUndef(NewOp, SQ.AC, &Sel, &DT)) + return replaceOperand(Sel, Swapped ? 2 : 1, V); + return nullptr; + } + } + // Even if TrueVal does not simplify, we can directly replace a use of // CmpLHS with CmpRHS, as long as the instruction is not used anywhere // else and is safe to speculatively execute (we may end up executing it @@ -1297,17 +1333,18 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, // undefined behavior). Only do this if CmpRHS is a constant, as // profitability is not clear for other cases. // FIXME: Support vectors. - if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) && - !Cmp.getType()->isVectorTy()) - if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS)) + if (OldOp == CmpLHS && match(NewOp, m_ImmConstant()) && + !match(OldOp, m_Constant()) && !Cmp.getType()->isVectorTy() && + isGuaranteedNotToBeUndef(NewOp, SQ.AC, &Sel, &DT)) + if (replaceInInstruction(TrueVal, OldOp, NewOp)) return &Sel; - } - if (TrueVal != CmpRHS && - isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT)) - if (Value *V = simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ, - /* AllowRefinement */ true)) - if (isa<Constant>(CmpLHS) || isa<Constant>(V)) - return replaceOperand(Sel, Swapped ? 2 : 1, V); + return nullptr; + }; + + if (Instruction *R = ReplaceOldOpWithNewOp(CmpLHS, CmpRHS)) + return R; + if (Instruction *R = ReplaceOldOpWithNewOp(CmpRHS, CmpLHS)) + return R; auto *FalseInst = dyn_cast<Instruction>(FalseVal); if (!FalseInst) @@ -1329,7 +1366,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, /* AllowRefinement */ false, &DropFlags) == TrueVal) { for (Instruction *I : DropFlags) { - I->dropPoisonGeneratingFlagsAndMetadata(); + I->dropPoisonGeneratingAnnotations(); Worklist.add(I); } @@ -1354,7 +1391,8 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, // Also ULT predicate can also be UGT iff C0 != -1 (+invert result) // SLT predicate can also be SGT iff C2 != INT_MAX (+invert res.) 
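The ReplaceOldOpWithNewOp path above implements the "X == Y ? f(X) : Z" simplification: once the condition pins one operand to the other (here a constant), f can be evaluated at that value in the true arm. A minimal standalone illustration (function names and constants are made up):

#include <cassert>

static int original(int X, int Z) { return X == 5 ? X + 2 : Z; }
static int folded(int X, int Z) { return X == 5 ? 7 : Z; }

int main() {
  for (int X : {-1, 5, 9})
    for (int Z : {0, 42})
      assert(original(X, Z) == folded(X, Z));
  return 0;
}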
static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, - InstCombiner::BuilderTy &Builder) { + InstCombiner::BuilderTy &Builder, + InstCombiner &IC) { Value *X = Sel0.getTrueValue(); Value *Sel1 = Sel0.getFalseValue(); @@ -1482,14 +1520,14 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, std::swap(ThresholdLowIncl, ThresholdHighExcl); // The fold has a precondition 1: C2 s>= ThresholdLow - auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2, - ThresholdLowIncl); - if (!match(Precond1, m_One())) + auto *Precond1 = ConstantFoldCompareInstOperands( + ICmpInst::Predicate::ICMP_SGE, C2, ThresholdLowIncl, IC.getDataLayout()); + if (!Precond1 || !match(Precond1, m_One())) return nullptr; // The fold has a precondition 2: C2 s<= ThresholdHigh - auto *Precond2 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SLE, C2, - ThresholdHighExcl); - if (!match(Precond2, m_One())) + auto *Precond2 = ConstantFoldCompareInstOperands( + ICmpInst::Predicate::ICMP_SLE, C2, ThresholdHighExcl, IC.getDataLayout()); + if (!Precond2 || !match(Precond2, m_One())) return nullptr; // If we are matching from a truncated input, we need to sext the @@ -1500,7 +1538,7 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, if (!match(ReplacementLow, m_ImmConstant(LowC)) || !match(ReplacementHigh, m_ImmConstant(HighC))) return nullptr; - const DataLayout &DL = Sel0.getModule()->getDataLayout(); + const DataLayout &DL = Sel0.getDataLayout(); ReplacementLow = ConstantFoldCastOperand(Instruction::SExt, LowC, X->getType(), DL); ReplacementHigh = @@ -1610,7 +1648,7 @@ static Instruction *foldSelectZeroOrOnes(ICmpInst *Cmp, Value *TVal, return nullptr; const APInt *CmpC; - if (!match(Cmp->getOperand(1), m_APIntAllowUndef(CmpC))) + if (!match(Cmp->getOperand(1), m_APIntAllowPoison(CmpC))) return nullptr; // (X u< 2) ? -X : -1 --> sext (X != 0) @@ -1676,6 +1714,109 @@ static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI, return nullptr; } +static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI, + InstCombinerImpl &IC) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + if (!ICmpInst::isEquality(Pred)) + return nullptr; + + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + Value *CmpLHS = ICI->getOperand(0); + Value *CmpRHS = ICI->getOperand(1); + + if (Pred == ICmpInst::ICMP_NE) + std::swap(TrueVal, FalseVal); + + // Transform (X == C) ? X : Y -> (X == C) ? C : Y + // specific handling for Bitwise operation. 
+ // x&y -> (x|y) ^ (x^y) or (x|y) & ~(x^y) + // x|y -> (x&y) | (x^y) or (x&y) ^ (x^y) + // x^y -> (x|y) ^ (x&y) or (x|y) & ~(x&y) + Value *X, *Y; + if (!match(CmpLHS, m_BitwiseLogic(m_Value(X), m_Value(Y))) || + !match(TrueVal, m_c_BitwiseLogic(m_Specific(X), m_Specific(Y)))) + return nullptr; + + const unsigned AndOps = Instruction::And, OrOps = Instruction::Or, + XorOps = Instruction::Xor, NoOps = 0; + enum NotMask { None = 0, NotInner, NotRHS }; + + auto matchFalseVal = [&](unsigned OuterOpc, unsigned InnerOpc, + unsigned NotMask) { + auto matchInner = m_c_BinOp(InnerOpc, m_Specific(X), m_Specific(Y)); + if (OuterOpc == NoOps) + return match(CmpRHS, m_Zero()) && match(FalseVal, matchInner); + + if (NotMask == NotInner) { + return match(FalseVal, m_c_BinOp(OuterOpc, m_NotForbidPoison(matchInner), + m_Specific(CmpRHS))); + } else if (NotMask == NotRHS) { + return match(FalseVal, m_c_BinOp(OuterOpc, matchInner, + m_NotForbidPoison(m_Specific(CmpRHS)))); + } else { + return match(FalseVal, + m_c_BinOp(OuterOpc, matchInner, m_Specific(CmpRHS))); + } + }; + + // (X&Y)==C ? X|Y : X^Y -> (X^Y)|C : X^Y or (X^Y)^ C : X^Y + // (X&Y)==C ? X^Y : X|Y -> (X|Y)^C : X|Y or (X|Y)&~C : X|Y + if (match(CmpLHS, m_And(m_Value(X), m_Value(Y)))) { + if (match(TrueVal, m_c_Or(m_Specific(X), m_Specific(Y)))) { + // (X&Y)==C ? X|Y : (X^Y)|C -> (X^Y)|C : (X^Y)|C -> (X^Y)|C + // (X&Y)==C ? X|Y : (X^Y)^C -> (X^Y)^C : (X^Y)^C -> (X^Y)^C + if (matchFalseVal(OrOps, XorOps, None) || + matchFalseVal(XorOps, XorOps, None)) + return IC.replaceInstUsesWith(SI, FalseVal); + } else if (match(TrueVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) { + // (X&Y)==C ? X^Y : (X|Y)^ C -> (X|Y)^ C : (X|Y)^ C -> (X|Y)^ C + // (X&Y)==C ? X^Y : (X|Y)&~C -> (X|Y)&~C : (X|Y)&~C -> (X|Y)&~C + if (matchFalseVal(XorOps, OrOps, None) || + matchFalseVal(AndOps, OrOps, NotRHS)) + return IC.replaceInstUsesWith(SI, FalseVal); + } + } + + // (X|Y)==C ? X&Y : X^Y -> (X^Y)^C : X^Y or ~(X^Y)&C : X^Y + // (X|Y)==C ? X^Y : X&Y -> (X&Y)^C : X&Y or ~(X&Y)&C : X&Y + if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y)))) { + if (match(TrueVal, m_c_And(m_Specific(X), m_Specific(Y)))) { + // (X|Y)==C ? X&Y: (X^Y)^C -> (X^Y)^C: (X^Y)^C -> (X^Y)^C + // (X|Y)==C ? X&Y:~(X^Y)&C ->~(X^Y)&C:~(X^Y)&C -> ~(X^Y)&C + if (matchFalseVal(XorOps, XorOps, None) || + matchFalseVal(AndOps, XorOps, NotInner)) + return IC.replaceInstUsesWith(SI, FalseVal); + } else if (match(TrueVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) { + // (X|Y)==C ? X^Y : (X&Y)^C -> (X&Y)^C : (X&Y)^C -> (X&Y)^C + // (X|Y)==C ? X^Y :~(X&Y)&C -> ~(X&Y)&C :~(X&Y)&C -> ~(X&Y)&C + if (matchFalseVal(XorOps, AndOps, None) || + matchFalseVal(AndOps, AndOps, NotInner)) + return IC.replaceInstUsesWith(SI, FalseVal); + } + } + + // (X^Y)==C ? X&Y : X|Y -> (X|Y)^C : X|Y or (X|Y)&~C : X|Y + // (X^Y)==C ? X|Y : X&Y -> (X&Y)|C : X&Y or (X&Y)^ C : X&Y + if (match(CmpLHS, m_Xor(m_Value(X), m_Value(Y)))) { + if ((match(TrueVal, m_c_And(m_Specific(X), m_Specific(Y))))) { + // (X^Y)==C ? X&Y : (X|Y)^C -> (X|Y)^C + // (X^Y)==C ? X&Y : (X|Y)&~C -> (X|Y)&~C + if (matchFalseVal(XorOps, OrOps, None) || + matchFalseVal(AndOps, OrOps, NotRHS)) + return IC.replaceInstUsesWith(SI, FalseVal); + } else if (match(TrueVal, m_c_Or(m_Specific(X), m_Specific(Y)))) { + // (X^Y)==C ? (X|Y) : (X&Y)|C -> (X&Y)|C + // (X^Y)==C ? 
(X|Y) : (X&Y)^C -> (X&Y)^C + if (matchFalseVal(OrOps, AndOps, None) || + matchFalseVal(XorOps, AndOps, None)) + return IC.replaceInstUsesWith(SI, FalseVal); + } + } + + return nullptr; +} + /// Visit a SelectInst that has an ICmpInst as its first operand. Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { @@ -1689,7 +1830,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Value *V = foldSelectInstWithICmpConst(SI, ICI, Builder)) return replaceInstUsesWith(SI, V); - if (Value *V = canonicalizeClampLike(SI, *ICI, Builder)) + if (Value *V = canonicalizeClampLike(SI, *ICI, Builder, *this)) return replaceInstUsesWith(SI, V); if (Instruction *NewSel = @@ -1718,6 +1859,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, } } + if (Instruction *NewSel = foldSelectICmpEq(SI, ICI, *this)) + return NewSel; + // Canonicalize a signbit condition to use zero constant by swapping: // (CmpLHS > -1) ? TV : FV --> (CmpLHS < 0) ? FV : TV // To avoid conflicts (infinite loops) with other canonicalizations, this is @@ -1803,7 +1947,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); - if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) + if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, *this)) return replaceInstUsesWith(SI, V); if (Value *V = canonicalizeSaturatedSubtract(ICI, TrueVal, FalseVal, Builder)) @@ -2223,20 +2367,20 @@ static Instruction *foldSelectCmpBitcasts(SelectInst &Sel, /// operand, the result of the select will always be equal to its false value. /// For example: /// -/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst -/// %1 = extractvalue { i64, i1 } %0, 1 -/// %2 = extractvalue { i64, i1 } %0, 0 -/// %3 = select i1 %1, i64 %compare, i64 %2 -/// ret i64 %3 +/// %cmpxchg = cmpxchg ptr %ptr, i64 %compare, i64 %new_value seq_cst seq_cst +/// %val = extractvalue { i64, i1 } %cmpxchg, 0 +/// %success = extractvalue { i64, i1 } %cmpxchg, 1 +/// %sel = select i1 %success, i64 %compare, i64 %val +/// ret i64 %sel /// -/// The returned value of the cmpxchg instruction (%2) is the original value -/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %2 +/// The returned value of the cmpxchg instruction (%val) is the original value +/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %val /// must have been equal to %compare. Thus, the result of the select is always -/// equal to %2, and the code can be simplified to: +/// equal to %val, and the code can be simplified to: /// -/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst -/// %1 = extractvalue { i64, i1 } %0, 0 -/// ret i64 %1 +/// %cmpxchg = cmpxchg ptr %ptr, i64 %compare, i64 %new_value seq_cst seq_cst +/// %val = extractvalue { i64, i1 } %cmpxchg, 0 +/// ret i64 %val /// static Value *foldSelectCmpXchg(SelectInst &SI) { // A helper that determines if V is an extractvalue instruction whose @@ -2369,14 +2513,11 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel, Value *FVal = Sel.getFalseValue(); Type *SelType = Sel.getType(); - if (ICmpInst::makeCmpResultType(TVal->getType()) != Cond->getType()) - return nullptr; - // Match select ?, TC, FC where the constants are equal but negated. // TODO: Generalize to handle a negated variable operand? 
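foldSelectICmpEq above leans on the fact that any one of X&Y, X|Y, X^Y is recoverable from the other two, so once the compared value is pinned to the constant C the select collapses to its false arm. A standalone check of those bitwise identities (values are illustrative):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X : {0u, 0xF0F0u, 0xDEADBEEFu})
    for (uint32_t Y : {0u, 0x0FF0u, 0x12345678u}) {
      assert((X | Y) == ((X ^ Y) | (X & Y)));
      assert((X | Y) == ((X ^ Y) ^ (X & Y))); // the two halves are disjoint
      assert((X ^ Y) == ((X | Y) ^ (X & Y)));
      assert((X & Y) == ((X | Y) ^ (X ^ Y)));
    }
  return 0;
}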
const APFloat *TC, *FC; - if (!match(TVal, m_APFloatAllowUndef(TC)) || - !match(FVal, m_APFloatAllowUndef(FC)) || + if (!match(TVal, m_APFloatAllowPoison(TC)) || + !match(FVal, m_APFloatAllowPoison(FC)) || !abs(*TC).bitwiseIsEqual(abs(*FC))) return nullptr; @@ -2386,9 +2527,9 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel, const APInt *C; bool IsTrueIfSignSet; ICmpInst::Predicate Pred; - if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) || - !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) || - X->getType() != SelType) + if (!match(Cond, m_OneUse(m_ICmp(Pred, m_ElementWiseBitCast(m_Value(X)), + m_APInt(C)))) || + !isSignBitCheck(Pred, *C, IsTrueIfSignSet) || X->getType() != SelType) return nullptr; // If needed, negate the value that will be the sign argument of the copysign: @@ -2423,8 +2564,8 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { if (auto *I = dyn_cast<Instruction>(V)) I->copyIRFlags(&Sel); Module *M = Sel.getModule(); - Function *F = Intrinsic::getDeclaration( - M, Intrinsic::experimental_vector_reverse, V->getType()); + Function *F = + Intrinsic::getDeclaration(M, Intrinsic::vector_reverse, V->getType()); return CallInst::Create(F, V); }; @@ -2587,7 +2728,7 @@ static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC, bool TrueIfSigned = false; if (!(match(CondVal, m_ICmp(Pred, m_Value(RemRes), m_APInt(C))) && - IC.isSignBitCheck(Pred, *C, TrueIfSigned))) + isSignBitCheck(Pred, *C, TrueIfSigned))) return nullptr; // If the sign bit is not set, we have a SGE/SGT comparison, and the operands @@ -2606,7 +2747,7 @@ static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC, // %cnd = icmp slt i32 %rem, 0 // %add = add i32 %rem, %n // %sel = select i1 %cnd, i32 %add, i32 %rem - if (match(TrueVal, m_Add(m_Value(RemRes), m_Value(Remainder))) && + if (match(TrueVal, m_Add(m_Specific(RemRes), m_Value(Remainder))) && match(RemRes, m_SRem(m_Value(Op), m_Specific(Remainder))) && IC.isKnownToBeAPowerOfTwo(Remainder, /*OrZero*/ true) && FalseVal == RemRes) @@ -2650,46 +2791,33 @@ static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy return nullptr; } +/// Given that \p CondVal is known to be \p CondIsTrue, try to simplify \p SI. +static Value *simplifyNestedSelectsUsingImpliedCond(SelectInst &SI, + Value *CondVal, + bool CondIsTrue, + const DataLayout &DL) { + Value *InnerCondVal = SI.getCondition(); + Value *InnerTrueVal = SI.getTrueValue(); + Value *InnerFalseVal = SI.getFalseValue(); + assert(CondVal->getType() == InnerCondVal->getType() && + "The type of inner condition must match with the outer."); + if (auto Implied = isImpliedCondition(CondVal, InnerCondVal, DL, CondIsTrue)) + return *Implied ? 
InnerTrueVal : InnerFalseVal; + return nullptr; +} + Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op, SelectInst &SI, bool IsAnd) { - Value *CondVal = SI.getCondition(); - Value *A = SI.getTrueValue(); - Value *B = SI.getFalseValue(); - assert(Op->getType()->isIntOrIntVectorTy(1) && "Op must be either i1 or vector of i1."); - - std::optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd); - if (!Res) + if (SI.getCondition()->getType() != Op->getType()) return nullptr; - - Value *Zero = Constant::getNullValue(A->getType()); - Value *One = Constant::getAllOnesValue(A->getType()); - - if (*Res == true) { - if (IsAnd) - // select op, (select cond, A, B), false => select op, A, false - // and op, (select cond, A, B) => select op, A, false - // if op = true implies condval = true. - return SelectInst::Create(Op, A, Zero); - else - // select op, true, (select cond, A, B) => select op, true, A - // or op, (select cond, A, B) => select op, true, A - // if op = false implies condval = true. - return SelectInst::Create(Op, One, A); - } else { - if (IsAnd) - // select op, (select cond, A, B), false => select op, B, false - // and op, (select cond, A, B) => select op, B, false - // if op = true implies condval = false. - return SelectInst::Create(Op, B, Zero); - else - // select op, true, (select cond, A, B) => select op, true, B - // or op, (select cond, A, B) => select op, true, B - // if op = false implies condval = false. - return SelectInst::Create(Op, One, B); - } + if (Value *V = simplifyNestedSelectsUsingImpliedCond(SI, Op, IsAnd, DL)) + return SelectInst::Create(Op, + IsAnd ? V : ConstantInt::getTrue(Op->getType()), + IsAnd ? ConstantInt::getFalse(Op->getType()) : V); + return nullptr; } // Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need @@ -2772,6 +2900,36 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI, } } + // Match select with (icmp slt (bitcast X to int), 0) + // or (icmp sgt (bitcast X to int), -1) + + for (bool Swap : {false, true}) { + Value *TrueVal = SI.getTrueValue(); + Value *X = SI.getFalseValue(); + + if (Swap) + std::swap(TrueVal, X); + + CmpInst::Predicate Pred; + const APInt *C; + bool TrueIfSigned; + if (!match(CondVal, + m_ICmp(Pred, m_ElementWiseBitCast(m_Specific(X)), m_APInt(C))) || + !isSignBitCheck(Pred, *C, TrueIfSigned)) + continue; + if (!match(TrueVal, m_FNeg(m_Specific(X)))) + return nullptr; + if (Swap == TrueIfSigned && !CondVal->hasOneUse() && !TrueVal->hasOneUse()) + return nullptr; + + // Fold (IsNeg ? -X : X) or (!IsNeg ? X : -X) to fabs(X) + // Fold (IsNeg ? X : -X) or (!IsNeg ? -X : X) to -fabs(X) + Value *Fabs = IC.Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, &SI); + if (Swap != TrueIfSigned) + return IC.replaceInstUsesWith(SI, Fabs); + return UnaryOperator::CreateFNegFMF(Fabs, &SI); + } + return ChangedFMF ? &SI : nullptr; } @@ -2808,17 +2966,17 @@ foldRoundUpIntegerWithPow2Alignment(SelectInst &SI, // FIXME: we could support non non-splats here. const APInt *LowBitMaskCst; - if (!match(XLowBits, m_And(m_Specific(X), m_APIntAllowUndef(LowBitMaskCst)))) + if (!match(XLowBits, m_And(m_Specific(X), m_APIntAllowPoison(LowBitMaskCst)))) return nullptr; // Match even if the AND and ADD are swapped. 
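simplifyNestedSelectsUsingImpliedCond above folds a nested select whose inner condition is implied by the outer one. A standalone sketch of the "outer true implies inner true" case (the predicates below are illustrative):

#include <cassert>

int main() {
  for (int X : {-5, 0, 3, 10}) {
    bool Op = X > 4;   // outer condition ...
    bool Cond = X > 0; // ... implies the inner one
    const int A = 1, B = 2;
    // select Op, (select Cond, A, B), 0  -->  select Op, A, 0
    assert((Op ? (Cond ? A : B) : 0) == (Op ? A : 0));
  }
  return 0;
}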
const APInt *BiasCst, *HighBitMaskCst; if (!match(XBiasedHighBits, - m_And(m_Add(m_Specific(X), m_APIntAllowUndef(BiasCst)), - m_APIntAllowUndef(HighBitMaskCst))) && + m_And(m_Add(m_Specific(X), m_APIntAllowPoison(BiasCst)), + m_APIntAllowPoison(HighBitMaskCst))) && !match(XBiasedHighBits, - m_Add(m_And(m_Specific(X), m_APIntAllowUndef(HighBitMaskCst)), - m_APIntAllowUndef(BiasCst)))) + m_Add(m_And(m_Specific(X), m_APIntAllowPoison(HighBitMaskCst)), + m_APIntAllowPoison(BiasCst)))) return nullptr; if (!LowBitMaskCst->isMask()) @@ -2834,7 +2992,8 @@ foldRoundUpIntegerWithPow2Alignment(SelectInst &SI, return nullptr; if (!XBiasedHighBits->hasOneUse()) { - if (*BiasCst == *LowBitMaskCst) + // We can't directly return XBiasedHighBits if it is more poisonous. + if (*BiasCst == *LowBitMaskCst && impliesPoison(XBiasedHighBits, X)) return XBiasedHighBits; return nullptr; } @@ -2856,6 +3015,32 @@ struct DecomposedSelect { }; } // namespace +/// Folds patterns like: +/// select c2 (select c1 a b) (select c1 b a) +/// into: +/// select (xor c1 c2) b a +static Instruction * +foldSelectOfSymmetricSelect(SelectInst &OuterSelVal, + InstCombiner::BuilderTy &Builder) { + + Value *OuterCond, *InnerCond, *InnerTrueVal, *InnerFalseVal; + if (!match( + &OuterSelVal, + m_Select(m_Value(OuterCond), + m_OneUse(m_Select(m_Value(InnerCond), m_Value(InnerTrueVal), + m_Value(InnerFalseVal))), + m_OneUse(m_Select(m_Deferred(InnerCond), + m_Deferred(InnerFalseVal), + m_Deferred(InnerTrueVal)))))) + return nullptr; + + if (OuterCond->getType() != InnerCond->getType()) + return nullptr; + + Value *Xor = Builder.CreateXor(InnerCond, OuterCond); + return SelectInst::Create(Xor, InnerFalseVal, InnerTrueVal); +} + /// Look for patterns like /// %outer.cond = select i1 %inner.cond, i1 %alt.cond, i1 false /// %inner.sel = select i1 %inner.cond, i8 %inner.sel.t, i8 %inner.sel.f @@ -2960,6 +3145,13 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { return BinaryOperator::CreateOr(CondVal, FalseVal); } + if (match(CondVal, m_OneUse(m_Select(m_Value(A), m_One(), m_Value(B)))) && + impliesPoison(FalseVal, B)) { + // (A || B) || C --> A || (B | C) + return replaceInstUsesWith( + SI, Builder.CreateLogicalOr(A, Builder.CreateOr(B, FalseVal))); + } + if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) if (auto *RHS = dyn_cast<FCmpInst>(FalseVal)) if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false, @@ -3001,6 +3193,13 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { return BinaryOperator::CreateAnd(CondVal, TrueVal); } + if (match(CondVal, m_OneUse(m_Select(m_Value(A), m_Value(B), m_Zero()))) && + impliesPoison(TrueVal, B)) { + // (A && B) && C --> A && (B & C) + return replaceInstUsesWith( + SI, Builder.CreateLogicalAnd(A, Builder.CreateAnd(B, TrueVal))); + } + if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) if (auto *RHS = dyn_cast<FCmpInst>(TrueVal)) if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true, @@ -3115,11 +3314,6 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { return replaceInstUsesWith(SI, Op1); } - if (auto *Op1SI = dyn_cast<SelectInst>(Op1)) - if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI, - /* IsAnd */ IsAnd)) - return I; - if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal)) if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1)) if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd, @@ -3201,7 +3395,8 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { // pattern. 
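foldRoundUpIntegerWithPow2Alignment, updated above, replaces the guarded round-up select with its unconditional biased-and-masked form. A standalone sketch of why the two agree for a power-of-two alignment (constants are illustrative):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Align = 16, Mask = Align - 1;
  for (uint32_t X : {0u, 1u, 15u, 16u, 17u, 1000u}) {
    uint32_t Guarded = (X & Mask) == 0 ? X : (X + Mask) & ~Mask;
    uint32_t Branchless = (X + Mask) & ~Mask;
    assert(Guarded == Branchless);
  }
  return 0;
}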
static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0, const APInt *Cond1, Value *CtlzOp, - unsigned BitWidth) { + unsigned BitWidth, + bool &ShouldDropNUW) { // The challenge in recognizing std::bit_ceil(X) is that the operand is used // for the CTLZ proper and select condition, each possibly with some // operation like add and sub. @@ -3224,6 +3419,8 @@ static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0, ConstantRange CR = ConstantRange::makeExactICmpRegion( CmpInst::getInversePredicate(Pred), *Cond1); + ShouldDropNUW = false; + // Match the operation that's used to compute CtlzOp from CommonAncestor. If // CtlzOp == CommonAncestor, return true as no operation is needed. If a // match is found, execute the operation on CR, update CR, and return true. @@ -3237,6 +3434,7 @@ static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0, return true; } if (match(CtlzOp, m_Sub(m_APInt(C), m_Specific(CommonAncestor)))) { + ShouldDropNUW = true; CR = ConstantRange(*C).sub(CR); return true; } @@ -3306,14 +3504,20 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) { Pred = CmpInst::getInversePredicate(Pred); } + bool ShouldDropNUW; + if (!match(FalseVal, m_One()) || !match(TrueVal, m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth), m_Value(Ctlz)))))) || !match(Ctlz, m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_Zero())) || - !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth)) + !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth, + ShouldDropNUW)) return nullptr; + if (ShouldDropNUW) + cast<Instruction>(CtlzOp)->setHasNoUnsignedWrap(false); + // Build 1 << (-CTLZ & (BitWidth-1)). The negation likely corresponds to a // single hardware instruction as opposed to BitWidth - CTLZ, where BitWidth // is an integer constant. Masking with BitWidth-1 comes free on some @@ -3350,6 +3554,33 @@ static bool matchFMulByZeroIfResultEqZero(InstCombinerImpl &IC, Value *Cmp0, return false; } +/// Check whether the KnownBits of a select arm may be affected by the +/// select condition. +static bool hasAffectedValue(Value *V, SmallPtrSetImpl<Value *> &Affected, + unsigned Depth) { + if (Depth == MaxAnalysisRecursionDepth) + return false; + + // Ignore the case where the select arm itself is affected. These cases + // are handled more efficiently by other optimizations. 
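foldBitCeil replaces the guarded 1 << (BitWidth - ctlz(X - 1)) form of std::bit_ceil with the branchless 1 << (-ctlz(X - 1) & (BitWidth - 1)). A standalone C++20 sketch comparing that shape against std::bit_ceil for the 32-bit case (illustrative only, not part of the patch):

#include <bit>
#include <cassert>
#include <cstdint>

// Branchless form produced after the fold: 1 << (-CTLZ & (BitWidth - 1)).
static uint32_t bitCeilBranchless(uint32_t X) {
  unsigned Ctlz = std::countl_zero(X - 1u); // ctlz, with ctlz(0) == 32
  return 1u << (-Ctlz & 31u);
}

int main() {
  for (uint32_t X = 0; X <= 4096; ++X)
    assert(bitCeilBranchless(X) == std::bit_ceil(X));
  assert(bitCeilBranchless(1u << 31) == (1u << 31));
  return 0;
}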
+ if (Depth != 0 && Affected.contains(V)) + return true; + + if (auto *I = dyn_cast<Instruction>(V)) { + if (isa<PHINode>(I)) { + if (Depth == MaxAnalysisRecursionDepth - 1) + return false; + Depth = MaxAnalysisRecursionDepth - 2; + } + return any_of(I->operands(), [&](Value *Op) { + return Op->getType()->isIntOrIntVectorTy() && + hasAffectedValue(Op, Affected, Depth + 1); + }); + } + + return false; +} + Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -3536,16 +3767,15 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *Idx = Gep->getOperand(1); if (isa<VectorType>(CondVal->getType()) && !isa<VectorType>(Idx->getType())) return nullptr; - Type *ElementType = Gep->getResultElementType(); + Type *ElementType = Gep->getSourceElementType(); Value *NewT = Idx; Value *NewF = Constant::getNullValue(Idx->getType()); if (Swap) std::swap(NewT, NewF); Value *NewSI = Builder.CreateSelect(CondVal, NewT, NewF, SI.getName() + ".idx", &SI); - if (Gep->isInBounds()) - return GetElementPtrInst::CreateInBounds(ElementType, Ptr, {NewSI}); - return GetElementPtrInst::Create(ElementType, Ptr, {NewSI}); + return GetElementPtrInst::Create(ElementType, Ptr, NewSI, + Gep->getNoWrapFlags()); }; if (auto *TrueGep = dyn_cast<GetElementPtrInst>(TrueVal)) if (auto *NewGep = SelectGepWithBase(TrueGep, FalseVal, false)) @@ -3620,12 +3850,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) { if (TrueSI->getCondition()->getType() == CondVal->getType()) { - // select(C, select(C, a, b), c) -> select(C, a, c) - if (TrueSI->getCondition() == CondVal) { - if (SI.getTrueValue() == TrueSI->getTrueValue()) - return nullptr; - return replaceOperand(SI, 1, TrueSI->getTrueValue()); - } + // Fold nested selects if the inner condition can be implied by the outer + // condition. + if (Value *V = simplifyNestedSelectsUsingImpliedCond( + *TrueSI, CondVal, /*CondIsTrue=*/true, DL)) + return replaceOperand(SI, 1, V); + // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b) // We choose this as normal form to enable folding on the And and // shortening paths for the values (this helps getUnderlyingObjects() for @@ -3640,12 +3870,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) { if (FalseSI->getCondition()->getType() == CondVal->getType()) { - // select(C, a, select(C, b, c)) -> select(C, a, c) - if (FalseSI->getCondition() == CondVal) { - if (SI.getFalseValue() == FalseSI->getFalseValue()) - return nullptr; - return replaceOperand(SI, 2, FalseSI->getFalseValue()); - } + // Fold nested selects if the inner condition can be implied by the outer + // condition. 
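SelectGepWithBase above keeps the GEP and pushes the select onto the index, now carrying over all no-wrap flags instead of only inbounds: select C, (gep P, Idx), P becomes gep P, (select C, Idx, 0). A plain pointer-arithmetic sketch of the same equivalence (illustrative only):

#include <cassert>

int main() {
  int Buf[8] = {};
  int *P = Buf;
  long Idx = 5;
  for (bool Cond : {false, true}) {
    int *Orig = Cond ? (P + Idx) : P; // select C, (gep P, Idx), P
    int *Fold = P + (Cond ? Idx : 0); // gep P, (select C, Idx, 0)
    assert(Orig == Fold);
  }
  return 0;
}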
+ if (Value *V = simplifyNestedSelectsUsingImpliedCond( + *FalseSI, CondVal, /*CondIsTrue=*/false, DL)) + return replaceOperand(SI, 2, V); + // select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b) if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) { Value *Or = Builder.CreateLogicalOr(CondVal, FalseSI->getCondition()); @@ -3786,6 +4016,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } } + if (Instruction *I = foldSelectOfSymmetricSelect(SI, Builder)) + return I; + if (Instruction *I = foldNestedSelects(SI, Builder)) return I; @@ -3844,5 +4077,39 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } } + // select Cond, !X, X -> xor Cond, X + if (CondVal->getType() == SI.getType() && isKnownInversion(FalseVal, TrueVal)) + return BinaryOperator::CreateXor(CondVal, FalseVal); + + // For vectors, this transform is only safe if the simplification does not + // look through any lane-crossing operations. For now, limit to scalars only. + if (SelType->isIntegerTy() && + (!isa<Constant>(TrueVal) || !isa<Constant>(FalseVal))) { + // Try to simplify select arms based on KnownBits implied by the condition. + CondContext CC(CondVal); + findValuesAffectedByCondition(CondVal, /*IsAssume=*/false, [&](Value *V) { + CC.AffectedValues.insert(V); + }); + SimplifyQuery Q = SQ.getWithInstruction(&SI).getWithCondContext(CC); + if (!CC.AffectedValues.empty()) { + if (!isa<Constant>(TrueVal) && + hasAffectedValue(TrueVal, CC.AffectedValues, /*Depth=*/0)) { + KnownBits Known = llvm::computeKnownBits(TrueVal, /*Depth=*/0, Q); + if (Known.isConstant()) + return replaceOperand(SI, 1, + ConstantInt::get(SelType, Known.getConstant())); + } + + CC.Invert = true; + if (!isa<Constant>(FalseVal) && + hasAffectedValue(FalseVal, CC.AffectedValues, /*Depth=*/0)) { + KnownBits Known = llvm::computeKnownBits(FalseVal, /*Depth=*/0, Q); + if (Known.isConstant()) + return replaceOperand(SI, 2, + ConstantInt::get(SelType, Known.getConstant())); + } + } + } + return nullptr; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 54490c46dfae..38f8a41214b6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -216,7 +216,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, // ((1 << MaskShAmt) - 1) auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes()); // (~(-1 << maskNbits)) - auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes()); + auto MaskB = m_Not(m_Shl(m_AllOnes(), m_Value(MaskShAmt))); // (-1 l>> MaskShAmt) auto MaskC = m_LShr(m_AllOnes(), m_Value(MaskShAmt)); // ((-1 << MaskShAmt) l>> MaskShAmt) @@ -257,8 +257,11 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, // And compute the mask as usual: ~(-1 << (SumOfShAmts)) auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); - auto *ExtendedInvertedMask = - ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts); + Constant *ExtendedInvertedMask = ConstantFoldBinaryOpOperands( + Instruction::Shl, ExtendedAllOnes, ExtendedSumOfShAmts, Q.DL); + if (!ExtendedInvertedMask) + return nullptr; + NewMask = ConstantExpr::getNot(ExtendedInvertedMask); } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) || match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)), @@ -437,9 +440,16 @@ Instruction 
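The new select Cond, !X, X -> xor Cond, X fold rests on the fact that selecting between a boolean and its inversion is an xor with the condition. Quick truth-table check (illustrative only):

#include <cassert>

int main() {
  for (bool Cond : {false, true})
    for (bool X : {false, true}) {
      bool Sel = Cond ? !X : X; // select Cond, (not X), X
      bool Xor = Cond ^ X;      // xor Cond, X
      assert(Sel == Xor);
    }
  return 0;
}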
*InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) { Value *A; Constant *C, *C1; if (match(Op0, m_Constant(C)) && - match(Op1, m_NUWAdd(m_Value(A), m_Constant(C1)))) { + match(Op1, m_NUWAddLike(m_Value(A), m_Constant(C1)))) { Value *NewC = Builder.CreateBinOp(I.getOpcode(), C, C1); - return BinaryOperator::Create(I.getOpcode(), NewC, A); + BinaryOperator *NewShiftOp = BinaryOperator::Create(I.getOpcode(), NewC, A); + if (I.getOpcode() == Instruction::Shl) { + NewShiftOp->setHasNoSignedWrap(I.hasNoSignedWrap()); + NewShiftOp->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + } else { + NewShiftOp->setIsExact(I.isExact()); + } + return NewShiftOp; } unsigned BitWidth = Ty->getScalarSizeInBits(); @@ -760,18 +770,27 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, // (C2 >> X) >> C1 --> (C2 >> C1) >> X Constant *C2; Value *X; - if (match(Op0, m_BinOp(I.getOpcode(), m_ImmConstant(C2), m_Value(X)))) - return BinaryOperator::Create( + bool IsLeftShift = I.getOpcode() == Instruction::Shl; + if (match(Op0, m_BinOp(I.getOpcode(), m_ImmConstant(C2), m_Value(X)))) { + Instruction *R = BinaryOperator::Create( I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X); + BinaryOperator *BO0 = cast<BinaryOperator>(Op0); + if (IsLeftShift) { + R->setHasNoUnsignedWrap(I.hasNoUnsignedWrap() && + BO0->hasNoUnsignedWrap()); + R->setHasNoSignedWrap(I.hasNoSignedWrap() && BO0->hasNoSignedWrap()); + } else + R->setIsExact(I.isExact() && BO0->isExact()); + return R; + } - bool IsLeftShift = I.getOpcode() == Instruction::Shl; Type *Ty = I.getType(); unsigned TypeBits = Ty->getScalarSizeInBits(); // (X / +DivC) >> (Width - 1) --> ext (X <= -DivC) // (X / -DivC) >> (Width - 1) --> ext (X >= +DivC) const APInt *DivC; - if (!IsLeftShift && match(C1, m_SpecificIntAllowUndef(TypeBits - 1)) && + if (!IsLeftShift && match(C1, m_SpecificIntAllowPoison(TypeBits - 1)) && match(Op0, m_SDiv(m_Value(X), m_APInt(DivC))) && !DivC->isZero() && !DivC->isMinSignedValue()) { Constant *NegDivC = ConstantInt::get(Ty, -(*DivC)); @@ -1113,14 +1132,6 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, Mask)); } - if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) { - unsigned AmtSum = ShAmtC + C1->getZExtValue(); - // Oversized shifts are simplified to zero in InstSimplify. - if (AmtSum < BitWidth) - // (X << C1) << C2 --> X << (C1 + C2) - return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum)); - } - // If we have an opposite shift by the same amount, we may be able to // reorder binops and shifts to eliminate math/logic. 
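commonShiftTransforms now matches the add through m_NUWAddLike and copies the nsw/nuw/exact flags onto the rebuilt shift. The reassociation itself, C shift (A + C1) == (C shift C1) shift A when the add cannot wrap and the total amount stays below the bit width, can be sanity-checked like this (illustrative only, 32-bit):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C = 0xDEADBEEFu;
  for (uint32_t C1 = 0; C1 < 8; ++C1)
    for (uint32_t A = 0; A + C1 < 32; ++A) {
      assert((C >> (A + C1)) == ((C >> C1) >> A)); // lshr case
      assert((C << (A + C1)) == ((C << C1) << A)); // shl case
    }
  return 0;
}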
auto isSuitableBinOpcode = [](Instruction::BinaryOps BinOpcode) { @@ -1175,7 +1186,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { // X & (CC << C) Value *M = Builder.CreateAnd(X, ConstantInt::get(Ty, CC->shl(*C)), X->getName() + ".mask"); - return BinaryOperator::Create(Op0BO->getOpcode(), M, YS); + auto *NewOp = BinaryOperator::Create(Op0BO->getOpcode(), M, YS); + if (auto *Disjoint = dyn_cast<PossiblyDisjointInst>(Op0BO); + Disjoint && Disjoint->isDisjoint()) + cast<PossiblyDisjointInst>(NewOp)->setIsDisjoint(true); + return NewOp; } } @@ -1199,17 +1214,23 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { return BinaryOperator::CreateAnd(Mask, X); } + // Transform (-1 >> y) << y to -1 << y + if (match(Op0, m_LShr(m_AllOnes(), m_Specific(Op1)))) { + Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); + return BinaryOperator::CreateShl(AllOnes, Op1); + } + Constant *C1; - if (match(Op1, m_Constant(C1))) { + if (match(Op1, m_ImmConstant(C1))) { Constant *C2; Value *X; // (X * C2) << C1 --> X * (C2 << C1) - if (match(Op0, m_Mul(m_Value(X), m_Constant(C2)))) - return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1)); + if (match(Op0, m_Mul(m_Value(X), m_ImmConstant(C2)))) + return BinaryOperator::CreateMul(X, Builder.CreateShl(C2, C1)); // shl (zext i1 X), C1 --> select (X, 1 << C1, 0) if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { - auto *NewC = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C1); + auto *NewC = Builder.CreateShl(ConstantInt::get(Ty, 1), C1); return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty)); } } @@ -1251,9 +1272,74 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { // (iN (~X) u>> (N - 1)) --> zext (X > -1) if (match(Op0, m_OneUse(m_Not(m_Value(X)))) && - match(Op1, m_SpecificIntAllowUndef(BitWidth - 1))) + match(Op1, m_SpecificIntAllowPoison(BitWidth - 1))) return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty); + // ((X << nuw Z) sub nuw Y) >>u exact Z --> X sub nuw (Y >>u exact Z) + Value *Y; + if (I.isExact() && + match(Op0, m_OneUse(m_NUWSub(m_NUWShl(m_Value(X), m_Specific(Op1)), + m_Value(Y))))) { + Value *NewLshr = Builder.CreateLShr(Y, Op1, "", /*isExact=*/true); + auto *NewSub = BinaryOperator::CreateNUWSub(X, NewLshr); + NewSub->setHasNoSignedWrap( + cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap()); + return NewSub; + } + + // Fold (X + Y) / 2 --> (X & Y) iff (X u<= 1) && (Y u<= 1) + if (match(Op0, m_Add(m_Value(X), m_Value(Y))) && match(Op1, m_One()) && + computeKnownBits(X, /*Depth=*/0, &I).countMaxActiveBits() <= 1 && + computeKnownBits(Y, /*Depth=*/0, &I).countMaxActiveBits() <= 1) + return BinaryOperator::CreateAnd(X, Y); + + // (sub nuw X, (Y << nuw Z)) >>u exact Z --> (X >>u exact Z) sub nuw Y + if (I.isExact() && + match(Op0, m_OneUse(m_NUWSub(m_Value(X), + m_NUWShl(m_Value(Y), m_Specific(Op1)))))) { + Value *NewLshr = Builder.CreateLShr(X, Op1, "", /*isExact=*/true); + auto *NewSub = BinaryOperator::CreateNUWSub(NewLshr, Y); + NewSub->setHasNoSignedWrap( + cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap()); + return NewSub; + } + + auto isSuitableBinOpcode = [](Instruction::BinaryOps BinOpcode) { + switch (BinOpcode) { + default: + return false; + case Instruction::Add: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + // Sub is handled separately. 
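The new fold (X + Y) >> 1 --> X & Y only holds when both operands are known to fit in a single bit, in which case bit 1 of the sum is exactly the carry X AND Y. Two-bit truth table (illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X = 0; X <= 1; ++X)
    for (uint32_t Y = 0; Y <= 1; ++Y)
      assert(((X + Y) >> 1) == (X & Y)); // bit 1 of X + Y is the carry out of bit 0
  return 0;
}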
+ return true; + } + }; + + // If both the binop and the shift are nuw, then: + // ((X << nuw Z) binop nuw Y) >>u Z --> X binop nuw (Y >>u Z) + if (match(Op0, m_OneUse(m_c_BinOp(m_NUWShl(m_Value(X), m_Specific(Op1)), + m_Value(Y))))) { + BinaryOperator *Op0OB = cast<BinaryOperator>(Op0); + if (isSuitableBinOpcode(Op0OB->getOpcode())) { + if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op0); + !OBO || OBO->hasNoUnsignedWrap()) { + Value *NewLshr = Builder.CreateLShr( + Y, Op1, "", I.isExact() && Op0OB->getOpcode() != Instruction::And); + auto *NewBinOp = BinaryOperator::Create(Op0OB->getOpcode(), NewLshr, X); + if (OBO) { + NewBinOp->setHasNoUnsignedWrap(true); + NewBinOp->setHasNoSignedWrap(OBO->hasNoSignedWrap()); + } else if (auto *Disjoint = dyn_cast<PossiblyDisjointInst>(Op0)) { + cast<PossiblyDisjointInst>(NewBinOp)->setIsDisjoint( + Disjoint->isDisjoint()); + } + return NewBinOp; + } + } + } + if (match(Op1, m_APInt(C))) { unsigned ShAmtC = C->getZExtValue(); auto *II = dyn_cast<IntrinsicInst>(Op0); @@ -1270,7 +1356,6 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { return new ZExtInst(Cmp, Ty); } - Value *X; const APInt *C1; if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) { if (C1->ult(ShAmtC)) { @@ -1315,7 +1400,6 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { // ((X << C) + Y) >>u C --> (X + (Y >>u C)) & (-1 >>u C) // TODO: Consolidate with the more general transform that starts from shl // (the shifts are in the opposite order). - Value *Y; if (match(Op0, m_OneUse(m_c_Add(m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))), m_Value(Y))))) { @@ -1381,14 +1465,6 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { } } - // (X >>u C1) >>u C --> X >>u (C1 + C) - if (match(Op0, m_LShr(m_Value(X), m_APInt(C1)))) { - // Oversized shifts are simplified to zero in InstSimplify. - unsigned AmtSum = ShAmtC + C1->getZExtValue(); - if (AmtSum < BitWidth) - return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum)); - } - Instruction *TruncSrc; if (match(Op0, m_OneUse(m_Trunc(m_Instruction(TruncSrc)))) && match(TruncSrc, m_LShr(m_Value(X), m_APInt(C1)))) { @@ -1414,13 +1490,24 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { const APInt *MulC; if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC)))) { - // Look for a "splat" mul pattern - it replicates bits across each half of - // a value, so a right shift is just a mask of the low bits: - // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1 - // TODO: Generalize to allow more than just half-width shifts? 
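For the lshr-over-binop fold above: when both the shl and the binop carry nuw, the low Z bits of the combined value come entirely from Y, so shifting the result right by Z is the same as combining X with Y >>u Z. A check of the add case, with ranges chosen so nothing wraps (illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t Z = 1; Z < 8; ++Z)
    for (uint32_t X = 0; X < 64; ++X)
      for (uint32_t Y = 0; Y < 1024; ++Y) {
        uint32_t Orig = ((X << Z) + Y) >> Z; // ((X << nuw Z) add nuw Y) >>u Z
        uint32_t Fold = (Y >> Z) + X;        // (Y >>u Z) add nuw X
        assert(Orig == Fold);
      }
  return 0;
}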
- if (BitWidth > 2 && ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() && - MulC->logBase2() == ShAmtC) - return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2)); + if (BitWidth > 2 && (*MulC - 1).isPowerOf2() && + MulC->logBase2() == ShAmtC) { + // Look for a "splat" mul pattern - it replicates bits across each half + // of a value, so a right shift simplifies back to just X: + // lshr i[2N] (mul nuw X, (2^N)+1), N --> X + if (ShAmtC * 2 == BitWidth) + return replaceInstUsesWith(I, X); + + // lshr (mul nuw (X, 2^N + 1)), N -> add nuw (X, lshr(X, N)) + if (Op0->hasOneUse()) { + auto *NewAdd = BinaryOperator::CreateNUWAdd( + X, Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "", + I.isExact())); + NewAdd->setHasNoSignedWrap( + cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap()); + return NewAdd; + } + } // The one-use check is not strictly necessary, but codegen may not be // able to invert the transform and perf may suffer with an extra mul @@ -1440,6 +1527,16 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { } } + // lshr (mul nsw (X, 2^N + 1)), N -> add nsw (X, lshr(X, N)) + if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_APInt(MulC))))) { + if (BitWidth > 2 && (*MulC - 1).isPowerOf2() && + MulC->logBase2() == ShAmtC) { + return BinaryOperator::CreateNSWAdd( + X, Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "", + I.isExact())); + } + } + // Try to narrow bswap. // In the case where the shift amount equals the bitwidth difference, the // shift is eliminated. @@ -1486,6 +1583,12 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { return BinaryOperator::CreateAnd(Mask, X); } + // Transform (-1 << y) >> y to -1 >> y + if (match(Op0, m_Shl(m_AllOnes(), m_Specific(Op1)))) { + Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); + return BinaryOperator::CreateLShr(AllOnes, Op1); + } + if (Instruction *Overflow = foldLShrOverflowBit(I)) return Overflow; @@ -1637,6 +1740,21 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y))))) return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty); } + + const APInt *MulC; + if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_APInt(MulC)))) && + (BitWidth > 2 && (*MulC - 1).isPowerOf2() && + MulC->logBase2() == ShAmt && + (ShAmt < BitWidth - 1))) /* Minus 1 for the sign bit */ { + + // ashr (mul nsw (X, 2^N + 1)), N -> add nsw (X, ashr(X, N)) + auto *NewAdd = BinaryOperator::CreateNSWAdd( + X, + Builder.CreateAShr(X, ConstantInt::get(Ty, ShAmt), "", I.isExact())); + NewAdd->setHasNoUnsignedWrap( + cast<OverflowingBinaryOperator>(Op0)->hasNoUnsignedWrap()); + return NewAdd; + } } const SimplifyQuery Q = SQ.getWithInstruction(&I); @@ -1647,9 +1765,9 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { // as the pattern to splat the lowest bit. // FIXME: iff X is already masked, we don't need the one-use check. Value *X; - if (match(Op1, m_SpecificIntAllowUndef(BitWidth - 1)) && + if (match(Op1, m_SpecificIntAllowPoison(BitWidth - 1)) && match(Op0, m_OneUse(m_Shl(m_Value(X), - m_SpecificIntAllowUndef(BitWidth - 1))))) { + m_SpecificIntAllowPoison(BitWidth - 1))))) { Constant *Mask = ConstantInt::get(Ty, 1); // Retain the knowledge about the ignored lanes. 
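The reworked splat-mul handling relies on mul nuw X, (2^N + 1) being (X << N) + X: shifting that right by N yields X + (X >> N), and when N is half the bit width nuw forces X < 2^N, so X >> N is zero and the result is just X. Sketch of the add form (illustrative only, values kept small so the multiply cannot wrap):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned N = 1; N < 16; ++N) {
    const uint32_t MulC = (1u << N) + 1u; // 2^N + 1
    for (uint32_t X = 0; X < 4096; ++X) {
      uint32_t Orig = (X * MulC) >> N; // lshr (mul nuw X, 2^N + 1), N
      uint32_t Fold = X + (X >> N);    // add nuw X, (lshr X, N)
      assert(Orig == Fold);
    }
  }
  return 0;
}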
Mask = Constant::mergeUndefsWith( diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 79873a9b4cbb..b9d06b593685 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -69,7 +69,7 @@ bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst, KnownBits &Known) { APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth())); Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known, - 0, &Inst); + 0, SQ.getWithInstruction(&Inst)); if (!V) return false; if (V == &Inst) return true; replaceInstUsesWith(Inst, V); @@ -88,10 +88,41 @@ bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { /// change and false otherwise. bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo, const APInt &DemandedMask, - KnownBits &Known, unsigned Depth) { + KnownBits &Known, unsigned Depth, + const SimplifyQuery &Q) { Use &U = I->getOperandUse(OpNo); - Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, Known, - Depth, I); + Value *V = U.get(); + if (isa<Constant>(V)) { + llvm::computeKnownBits(V, Known, Depth, Q); + return false; + } + + Known.resetAll(); + if (DemandedMask.isZero()) { + // Not demanding any bits from V. + replaceUse(U, UndefValue::get(V->getType())); + return true; + } + + if (Depth == MaxAnalysisRecursionDepth) + return false; + + Instruction *VInst = dyn_cast<Instruction>(V); + if (!VInst) { + llvm::computeKnownBits(V, Known, Depth, Q); + return false; + } + + Value *NewVal; + if (VInst->hasOneUse()) { + // If the instruction has one use, we can directly simplify it. + NewVal = SimplifyDemandedUseBits(VInst, DemandedMask, Known, Depth, Q); + } else { + // If there are multiple uses of this instruction, then we can simplify + // VInst to some other value, but not modify the instruction. + NewVal = + SimplifyMultipleUseDemandedBits(VInst, DemandedMask, Known, Depth, Q); + } if (!NewVal) return false; if (Instruction* OpInst = dyn_cast<Instruction>(U)) salvageDebugInfo(*OpInst); @@ -123,50 +154,21 @@ bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo, /// operands based on the information about what bits are demanded. This returns /// some other non-null value if it found out that V is equal to another value /// in the context where the specified bits are demanded, but not for all users. -Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, +Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I, + const APInt &DemandedMask, KnownBits &Known, unsigned Depth, - Instruction *CxtI) { - assert(V != nullptr && "Null pointer of Value???"); + const SimplifyQuery &Q) { + assert(I != nullptr && "Null pointer of Value???"); assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth"); uint32_t BitWidth = DemandedMask.getBitWidth(); - Type *VTy = V->getType(); + Type *VTy = I->getType(); assert( (!VTy->isIntOrIntVectorTy() || VTy->getScalarSizeInBits() == BitWidth) && Known.getBitWidth() == BitWidth && "Value *V, DemandedMask and Known must have same BitWidth"); - if (isa<Constant>(V)) { - computeKnownBits(V, Known, Depth, CxtI); - return nullptr; - } - - Known.resetAll(); - if (DemandedMask.isZero()) // Not demanding any bits from V. 
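The add/sub handling above only demands the operand bits at or below the highest demanded result bit (DemandedFromOps), since carries never propagate downward; whatever an operand holds in higher positions is irrelevant. Numeric illustration (not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t DemandedMask = 0x000000FFu; // only the low 8 bits of the add are demanded
  for (uint32_t X = 0; X < 4096; X += 3)
    for (uint32_t Y : {0x00000012u, 0xFFFFFF12u, 0x7F000012u}) // same low byte, different high bits
      assert(((X + Y) & DemandedMask) == ((X + (Y & DemandedMask)) & DemandedMask));
  return 0;
}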
- return UndefValue::get(VTy); - - if (Depth == MaxAnalysisRecursionDepth) - return nullptr; - - Instruction *I = dyn_cast<Instruction>(V); - if (!I) { - computeKnownBits(V, Known, Depth, CxtI); - return nullptr; // Only analyze instructions. - } - - // If there are multiple uses of this value and we aren't at the root, then - // we can't do any simplifications of the operands, because DemandedMask - // only reflects the bits demanded by *one* of the users. - if (Depth != 0 && !I->hasOneUse()) - return SimplifyMultipleUseDemandedBits(I, DemandedMask, Known, Depth, CxtI); - KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth); - // If this is the root being simplified, allow it to have multiple uses, - // just set the DemandedMask to all bits so that we can try to simplify the - // operands. This allows visitTruncInst (for example) to simplify the - // operand of a trunc without duplicating all the logic below. - if (Depth == 0 && !V->hasOneUse()) - DemandedMask.setAllBits(); // Update flags after simplifying an operand based on the fact that some high // order bits are not demanded. @@ -190,9 +192,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // significant bit and all those below it. DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || - SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) || + SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1, Q) || ShrinkDemandedConstant(I, 1, DemandedFromOps) || - SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) { + SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q)) { disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); return true; } @@ -201,19 +203,17 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, switch (I->getOpcode()) { default: - computeKnownBits(I, Known, Depth, CxtI); + llvm::computeKnownBits(I, Known, Depth, Q); break; case Instruction::And: { // If either the LHS or the RHS are Zero, the result is zero. - if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) || + if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) || SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.Zero, LHSKnown, - Depth + 1)) + Depth + 1, Q)) return I; - assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); - assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, - Depth, SQ.getWithInstruction(CxtI)); + Depth, Q); // If the client is only demanding bits that we know, return the known // constant. @@ -235,18 +235,16 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } case Instruction::Or: { // If either the LHS or the RHS are One, the result is One. - if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) || + if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) || SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown, - Depth + 1)) { + Depth + 1, Q)) { // Disjoint flag may not longer hold. I->dropPoisonGeneratingFlags(); return I; } - assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); - assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, - Depth, SQ.getWithInstruction(CxtI)); + Depth, Q); // If the client is only demanding bits that we know, return the known // constant. 
@@ -268,7 +266,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (!cast<PossiblyDisjointInst>(I)->isDisjoint()) { WithCache<const Value *> LHSCache(I->getOperand(0), LHSKnown), RHSCache(I->getOperand(1), RHSKnown); - if (haveNoCommonBitsSet(LHSCache, RHSCache, SQ.getWithInstruction(I))) { + if (haveNoCommonBitsSet(LHSCache, RHSCache, Q)) { cast<PossiblyDisjointInst>(I)->setIsDisjoint(true); return I; } @@ -277,8 +275,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, break; } case Instruction::Xor: { - if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) || - SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1)) + if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) || + SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1, Q)) return I; Value *LHS, *RHS; if (DemandedMask == 1 && @@ -291,11 +289,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, Xor); } - assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); - assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); - Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, - Depth, SQ.getWithInstruction(CxtI)); + Depth, Q); // If the client is only demanding bits that we know, return the known // constant. @@ -372,11 +367,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, break; } case Instruction::Select: { - if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1) || - SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Depth + 1)) + if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1, Q) || + SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Depth + 1, Q)) return I; - assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); - assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); // If the operands are constants, see if we can simplify them. // This is similar to ShrinkDemandedConstant, but for a select we want to @@ -416,6 +409,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return I; // Only known if known in both the LHS and RHS. + adjustKnownBitsForSelectArm(LHSKnown, I->getOperand(0), I->getOperand(1), + /*Invert=*/false, Depth, Q); + adjustKnownBitsForSelectArm(RHSKnown, I->getOperand(0), I->getOperand(2), + /*Invert=*/true, Depth, Q); Known = LHSKnown.intersectWith(RHSKnown); break; } @@ -443,7 +440,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth); KnownBits InputKnown(SrcBitWidth); - if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1)) { + if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1, + Q)) { // For zext nneg, we may have dropped the instruction which made the // input non-negative. 
I->dropPoisonGeneratingFlags(); @@ -455,7 +453,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, InputKnown.makeNonNegative(); Known = InputKnown.zextOrTrunc(BitWidth); - assert(!Known.hasConflict() && "Bits known to be one AND zero?"); break; } case Instruction::SExt: { @@ -470,7 +467,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, InputDemandedBits.setBit(SrcBitWidth-1); KnownBits InputKnown(SrcBitWidth); - if (SimplifyDemandedBits(I, 0, InputDemandedBits, InputKnown, Depth + 1)) + if (SimplifyDemandedBits(I, 0, InputDemandedBits, InputKnown, Depth + 1, Q)) return I; // If the input sign bit is known zero, or if the NewBits are not demanded @@ -481,12 +478,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy); NewCast->takeName(I); return InsertNewInstWith(NewCast, I->getIterator()); - } + } // If the sign bit of the input is known set or clear, then we know the // top bits of the result. Known = InputKnown.sext(BitWidth); - assert(!Known.hasConflict() && "Bits known to be one AND zero?"); break; } case Instruction::Add: { @@ -510,10 +506,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } // add iN (sext i1 X), (sext i1 Y) --> sext (X | Y) to iN - // TODO: Relax the one-use checks because we are removing an instruction? - if (match(I, m_Add(m_OneUse(m_SExt(m_Value(X))), - m_OneUse(m_SExt(m_Value(Y))))) && - X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType()) { + if (match(I, m_Add(m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) && + X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() && + (I->getOperand(0)->hasOneUse() || I->getOperand(1)->hasOneUse())) { + // Truth table for inputs and output signbits: // X:0 | X:1 // ----------- @@ -532,7 +528,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, unsigned NLZ = DemandedMask.countl_zero(); APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); if (ShrinkDemandedConstant(I, 1, DemandedFromOps) || - SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) + SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q)) return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); // If low order bits are not demanded and known to be zero in one operand, @@ -542,7 +538,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt DemandedFromLHS = DemandedFromOps; DemandedFromLHS.clearLowBits(NTZ); if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) || - SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1)) + SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1, Q)) return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); // If we are known to be adding zeros to every bit below @@ -565,7 +561,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Otherwise just compute the known bits of the result. 
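The relaxed one-use check is for add iN (sext i1 X), (sext i1 Y) --> sext (X | Y): the two sides can only disagree in bit 0 (the sum -1 + -1 is -2, while the or gives -1), so the rewrite is sound whenever bit 0 of the result is not demanded. Check ignoring bit 0 (illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  for (bool X : {false, true})
    for (bool Y : {false, true}) {
      int32_t SExtX = X ? -1 : 0, SExtY = Y ? -1 : 0; // sext i1 to i32
      int32_t Add = SExtX + SExtY;                    // add (sext X), (sext Y)
      int32_t Or = (X || Y) ? -1 : 0;                 // sext (X | Y)
      assert((Add & ~1) == (Or & ~1));                // equal on every bit except bit 0
    }
  return 0;
}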
bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); - Known = KnownBits::computeForAddSub(true, NSW, LHSKnown, RHSKnown); + bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); + Known = KnownBits::computeForAddSub(true, NSW, NUW, LHSKnown, RHSKnown); break; } case Instruction::Sub: { @@ -574,7 +571,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, unsigned NLZ = DemandedMask.countl_zero(); APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); if (ShrinkDemandedConstant(I, 1, DemandedFromOps) || - SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) + SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q)) return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); // If low order bits are not demanded and are known to be zero in RHS, @@ -584,7 +581,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt DemandedFromLHS = DemandedFromOps; DemandedFromLHS.clearLowBits(NTZ); if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) || - SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1)) + SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1, Q)) return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); // If we are known to be subtracting zeros from every bit below @@ -598,7 +595,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Otherwise just compute the known bits of the result. bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); - Known = KnownBits::computeForAddSub(false, NSW, LHSKnown, RHSKnown); + bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); + Known = KnownBits::computeForAddSub(false, NSW, NUW, LHSKnown, RHSKnown); break; } case Instruction::Mul: { @@ -627,7 +625,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return InsertNewInstWith(And1, I->getIterator()); } - computeKnownBits(I, Known, Depth, CxtI); + llvm::computeKnownBits(I, Known, Depth, Q); break; } case Instruction::Shl: { @@ -640,25 +638,48 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, DemandedMask, Known)) return R; - // TODO: If we only want bits that already match the signbit then we don't + // Do not simplify if shl is part of funnel-shift pattern + if (I->hasOneUse()) { + auto *Inst = dyn_cast<Instruction>(I->user_back()); + if (Inst && Inst->getOpcode() == BinaryOperator::Or) { + if (auto Opt = convertOrOfShiftsToFunnelShift(*Inst)) { + auto [IID, FShiftArgs] = *Opt; + if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) && + FShiftArgs[0] == FShiftArgs[1]) { + llvm::computeKnownBits(I, Known, Depth, Q); + break; + } + } + } + } + + // We only want bits that already match the signbit then we don't // need to shift. 
+ uint64_t ShiftAmt = SA->getLimitedValue(BitWidth - 1); + if (DemandedMask.countr_zero() >= ShiftAmt) { + if (I->hasNoSignedWrap()) { + unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero(); + unsigned SignBits = + ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI); + if (SignBits > ShiftAmt && SignBits - ShiftAmt >= NumHiDemandedBits) + return I->getOperand(0); + } - // If we can pre-shift a right-shifted constant to the left without - // losing any high bits amd we don't demand the low bits, then eliminate - // the left-shift: - // (C >> X) << LeftShiftAmtC --> (C << RightShiftAmtC) >> X - uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); - Value *X; - Constant *C; - if (DemandedMask.countr_zero() >= ShiftAmt && - match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) { - Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt); - Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C, - LeftShiftAmtC, DL); - if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC, LeftShiftAmtC, - DL) == C) { - Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X); - return InsertNewInstWith(Lshr, I->getIterator()); + // If we can pre-shift a right-shifted constant to the left without + // losing any high bits and we don't demand the low bits, then eliminate + // the left-shift: + // (C >> X) << LeftShiftAmtC --> (C << LeftShiftAmtC) >> X + Value *X; + Constant *C; + if (match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) { + Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt); + Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C, + LeftShiftAmtC, DL); + if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC, + LeftShiftAmtC, DL) == C) { + Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X); + return InsertNewInstWith(Lshr, I->getIterator()); + } } } @@ -671,9 +692,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, else if (IOp->hasNoUnsignedWrap()) DemandedMaskIn.setHighBits(ShiftAmt); - if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q)) return I; - assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known = KnownBits::shl(Known, KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)), @@ -685,13 +705,13 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // demanding those bits from the pre-shifted operand either. if (unsigned CTLZ = DemandedMask.countl_zero()) { APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ)); - if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Depth + 1)) { + if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Depth + 1, Q)) { // We can't guarantee that nsw/nuw hold after simplifying the operand. 
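The pre-shift above, (C >> X) << K --> (C << K) >> X, is valid when C << K loses no high bits and the low K result bits are not demanded: both sides then carry floor(C / 2^X) in the bits at and above position K and differ only below it. Sketch (illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C = 0x00FFAA55u; // (C << K) >> K == C, so no high bits are lost
  const unsigned K = 4;           // left-shift amount; the low K bits are not demanded
  const uint32_t Demanded = ~0u << K;
  assert(((C << K) >> K) == C);
  for (unsigned X = 0; X < 32; ++X) {
    uint32_t Orig = (C >> X) << K; // (C >> X) << K
    uint32_t Fold = (C << K) >> X; // (C << K) >> X
    assert((Orig & Demanded) == (Fold & Demanded));
  }
  return 0;
}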
I->dropPoisonGeneratingFlags(); return I; } } - computeKnownBits(I, Known, Depth, CxtI); + llvm::computeKnownBits(I, Known, Depth, Q); } break; } @@ -700,6 +720,21 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (match(I->getOperand(1), m_APInt(SA))) { uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); + // Do not simplify if lshr is part of funnel-shift pattern + if (I->hasOneUse()) { + auto *Inst = dyn_cast<Instruction>(I->user_back()); + if (Inst && Inst->getOpcode() == BinaryOperator::Or) { + if (auto Opt = convertOrOfShiftsToFunnelShift(*Inst)) { + auto [IID, FShiftArgs] = *Opt; + if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) && + FShiftArgs[0] == FShiftArgs[1]) { + llvm::computeKnownBits(I, Known, Depth, Q); + break; + } + } + } + } + // If we are just demanding the shifted sign bit and below, then this can // be treated as an ASHR in disguise. if (DemandedMask.countl_zero() >= ShiftAmt) { @@ -707,7 +742,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // need to shift. unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero(); unsigned SignBits = - ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); + ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI); if (SignBits >= NumHiDemandedBits) return I->getOperand(0); @@ -731,23 +766,22 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Unsigned shift right. APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); - if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) { + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q)) { // exact flag may not longer hold. I->dropPoisonGeneratingFlags(); return I; } - assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known.Zero.lshrInPlace(ShiftAmt); Known.One.lshrInPlace(ShiftAmt); if (ShiftAmt) Known.Zero.setHighBits(ShiftAmt); // high bits known zero. } else { - computeKnownBits(I, Known, Depth, CxtI); + llvm::computeKnownBits(I, Known, Depth, Q); } break; } case Instruction::AShr: { - unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); + unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI); // If we only want bits that already match the signbit then we don't need // to shift. @@ -772,42 +806,32 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Signed shift right. APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); - // If any of the high bits are demanded, we should set the sign bit as - // demanded. - if (DemandedMask.countl_zero() <= ShiftAmt) + // If any of the bits being shifted in are demanded, then we should set + // the sign bit as demanded. + bool ShiftedInBitsDemanded = DemandedMask.countl_zero() < ShiftAmt; + if (ShiftedInBitsDemanded) DemandedMaskIn.setSignBit(); - - if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) { + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q)) { // exact flag may not longer hold. I->dropPoisonGeneratingFlags(); return I; } - assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - // Compute the new bits that are at the top now plus sign bits. - APInt HighBits(APInt::getHighBitsSet( - BitWidth, std::min(SignBits + ShiftAmt - 1, BitWidth))); - Known.Zero.lshrInPlace(ShiftAmt); - Known.One.lshrInPlace(ShiftAmt); - - // If the input sign bit is known to be zero, or if none of the top bits - // are demanded, turn this into an unsigned shift right. 
- assert(BitWidth > ShiftAmt && "Shift amount not saturated?"); - if (Known.Zero[BitWidth-ShiftAmt-1] || - !DemandedMask.intersects(HighBits)) { + // If the input sign bit is known to be zero, or if none of the shifted in + // bits are demanded, turn this into an unsigned shift right. + if (Known.Zero[BitWidth - 1] || !ShiftedInBitsDemanded) { BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0), I->getOperand(1)); LShr->setIsExact(cast<BinaryOperator>(I)->isExact()); LShr->takeName(I); return InsertNewInstWith(LShr, I->getIterator()); - } else if (Known.One[BitWidth-ShiftAmt-1]) { // New bits are known one. - Known.One |= HighBits; - // SignBits may be out-of-sync with Known.countMinSignBits(). Mask out - // high bits of Known.Zero to avoid conflicts. - Known.Zero &= ~HighBits; } + + Known = KnownBits::ashr( + Known, KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)), + ShiftAmt != 0, I->isExact()); } else { - computeKnownBits(I, Known, Depth, CxtI); + llvm::computeKnownBits(I, Known, Depth, Q); } break; } @@ -819,7 +843,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, unsigned RHSTrailingZeros = SA->countr_zero(); APInt DemandedMaskIn = APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros); - if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1)) { + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1, Q)) { // We can't guarantee that "exact" is still true after changing the // the dividend. I->dropPoisonGeneratingFlags(); @@ -829,7 +853,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Known = KnownBits::udiv(LHSKnown, KnownBits::makeConstant(*SA), cast<BinaryOperator>(I)->isExact()); } else { - computeKnownBits(I, Known, Depth, CxtI); + llvm::computeKnownBits(I, Known, Depth, Q); } break; } @@ -847,7 +871,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt LowBits = RA - 1; APInt Mask2 = LowBits | APInt::getSignMask(BitWidth); - if (SimplifyDemandedBits(I, 0, Mask2, LHSKnown, Depth + 1)) + if (SimplifyDemandedBits(I, 0, Mask2, LHSKnown, Depth + 1, Q)) return I; // The low bits of LHS are unchanged by the srem. @@ -864,21 +888,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (LHSKnown.isNegative() && LowBits.intersects(LHSKnown.One)) Known.One |= ~LowBits; - assert(!Known.hasConflict() && "Bits known to be one AND zero?"); break; } } - computeKnownBits(I, Known, Depth, CxtI); - break; - } - case Instruction::URem: { - APInt AllOnes = APInt::getAllOnes(BitWidth); - if (SimplifyDemandedBits(I, 0, AllOnes, LHSKnown, Depth + 1) || - SimplifyDemandedBits(I, 1, AllOnes, RHSKnown, Depth + 1)) - return I; - - Known = KnownBits::urem(LHSKnown, RHSKnown); + llvm::computeKnownBits(I, Known, Depth, Q); break; } case Instruction::Call: { @@ -934,16 +948,14 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits(); RHSKnown = KnownBits(MaskWidth); // If either the LHS or the RHS are Zero, the result is zero. 
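The simplified AShr logic turns the shift into LShr when the sign bit is known zero or when none of the shifted-in bits are demanded: arithmetic and logical right shifts agree on all but the top ShiftAmt bits. Sketch (illustrative only; relies on the C++20 guarantee that signed right shift is arithmetic):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned ShAmt = 1; ShAmt < 32; ++ShAmt) {
    uint32_t DemandedLow = ~0u >> ShAmt; // everything below the shifted-in bits
    for (int32_t X : {0, 1, -1, 12345, -12345, INT32_MIN, INT32_MAX}) {
      uint32_t AShr = (uint32_t)(X >> ShAmt); // ashr
      uint32_t LShr = (uint32_t)X >> ShAmt;   // lshr
      assert((AShr & DemandedLow) == (LShr & DemandedLow));
    }
  }
  return 0;
}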
- if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1) || + if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1, Q) || SimplifyDemandedBits( I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth), - RHSKnown, Depth + 1)) + RHSKnown, Depth + 1, Q)) return I; // TODO: Should be 1-extend RHSKnown = RHSKnown.anyextOrTrunc(BitWidth); - assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); - assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); Known = LHSKnown & RHSKnown; KnownBitsComputed = true; @@ -969,6 +981,44 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth))) return I; + // Combine: + // (ptrmask (getelementptr i8, ptr p, imm i), imm mask) + // -> (ptrmask (getelementptr i8, ptr p, imm (i & mask)), imm mask) + // where only the low bits known to be zero in the pointer are changed + Value *InnerPtr; + uint64_t GEPIndex; + uint64_t PtrMaskImmediate; + if (match(I, m_Intrinsic<Intrinsic::ptrmask>( + m_PtrAdd(m_Value(InnerPtr), m_ConstantInt(GEPIndex)), + m_ConstantInt(PtrMaskImmediate)))) { + + LHSKnown = computeKnownBits(InnerPtr, Depth + 1, I); + if (!LHSKnown.isZero()) { + const unsigned trailingZeros = LHSKnown.countMinTrailingZeros(); + uint64_t PointerAlignBits = (uint64_t(1) << trailingZeros) - 1; + + uint64_t HighBitsGEPIndex = GEPIndex & ~PointerAlignBits; + uint64_t MaskedLowBitsGEPIndex = + GEPIndex & PointerAlignBits & PtrMaskImmediate; + + uint64_t MaskedGEPIndex = HighBitsGEPIndex | MaskedLowBitsGEPIndex; + + if (MaskedGEPIndex != GEPIndex) { + auto *GEP = cast<GEPOperator>(II->getArgOperand(0)); + Builder.SetInsertPoint(I); + Type *GEPIndexType = + DL.getIndexType(GEP->getPointerOperand()->getType()); + Value *MaskedGEP = Builder.CreateGEP( + GEP->getSourceElementType(), InnerPtr, + ConstantInt::get(GEPIndexType, MaskedGEPIndex), + GEP->getName(), GEP->isInBounds()); + + replaceOperand(*I, 0, MaskedGEP); + return I; + } + } + } + break; } @@ -988,8 +1038,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt)); if (I->getOperand(0) != I->getOperand(1)) { if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, - Depth + 1) || - SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1)) + Depth + 1, Q) || + SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1, + Q)) return I; } else { // fshl is a rotate // Avoid converting rotate into funnel shift. @@ -1051,13 +1102,13 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } if (!KnownBitsComputed) - computeKnownBits(V, Known, Depth, CxtI); + llvm::computeKnownBits(I, Known, Depth, Q); break; } } - if (V->getType()->isPointerTy()) { - Align Alignment = V->getPointerAlignment(DL); + if (I->getType()->isPointerTy()) { + Align Alignment = I->getPointerAlignment(DL); Known.Zero.setLowBits(Log2(Alignment)); } @@ -1065,13 +1116,14 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // constant. We can't directly simplify pointers as a constant because of // pointer provenance. // TODO: We could return `(inttoptr const)` for pointers. 
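The new ptrmask combine masks away the low bits of a constant GEP offset that fall inside the base pointer's known alignment, because the trailing ptrmask clears them anyway. The address-level identity, assuming an aligned base and written with plain integer addresses (illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  const uint64_t Mask = ~uint64_t(15); // ptrmask constant: clear the low 4 bits
  const uint64_t Align = 64;           // base pointer is known 64-byte aligned
  const uint64_t AlignMask = Align - 1;
  for (uint64_t P = 0; P < (1u << 16); P += Align)   // aligned base addresses
    for (uint64_t I = 0; I < 256; ++I) {             // constant GEP offset
      uint64_t MaskedI = (I & ~AlignMask) | (I & AlignMask & Mask);
      assert(((P + I) & Mask) == ((P + MaskedI) & Mask)); // same masked address
    }
  return 0;
}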
- if (!V->getType()->isPointerTy() && DemandedMask.isSubsetOf(Known.Zero | Known.One)) + if (!I->getType()->isPointerTy() && + DemandedMask.isSubsetOf(Known.Zero | Known.One)) return Constant::getIntegerValue(VTy, Known.One); if (VerifyKnownBits) { - KnownBits ReferenceKnown = computeKnownBits(V, Depth, CxtI); + KnownBits ReferenceKnown = llvm::computeKnownBits(I, Depth, Q); if (Known != ReferenceKnown) { - errs() << "Mismatched known bits for " << *V << " in " + errs() << "Mismatched known bits for " << *I << " in " << I->getFunction()->getName() << "\n"; errs() << "computeKnownBits(): " << ReferenceKnown << "\n"; errs() << "SimplifyDemandedBits(): " << Known << "\n"; @@ -1087,7 +1139,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, /// DemandedMask, but without modifying the Instruction. Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( Instruction *I, const APInt &DemandedMask, KnownBits &Known, unsigned Depth, - Instruction *CxtI) { + const SimplifyQuery &Q) { unsigned BitWidth = DemandedMask.getBitWidth(); Type *ITy = I->getType(); @@ -1100,11 +1152,11 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( // this instruction has a simpler value in that context. switch (I->getOpcode()) { case Instruction::And: { - computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); - computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); + llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q); + llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q); Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, - Depth, SQ.getWithInstruction(CxtI)); - computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); + Depth, Q); + computeKnownBitsFromContext(I, Known, Depth, Q); // If the client is only demanding bits that we know, return the known // constant. @@ -1121,11 +1173,11 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( break; } case Instruction::Or: { - computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); - computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); + llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q); + llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q); Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, - Depth, SQ.getWithInstruction(CxtI)); - computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); + Depth, Q); + computeKnownBitsFromContext(I, Known, Depth, Q); // If the client is only demanding bits that we know, return the known // constant. @@ -1144,11 +1196,11 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( break; } case Instruction::Xor: { - computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); - computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); + llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q); + llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q); Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, - Depth, SQ.getWithInstruction(CxtI)); - computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); + Depth, Q); + computeKnownBitsFromContext(I, Known, Depth, Q); // If the client is only demanding bits that we know, return the known // constant. @@ -1171,17 +1223,19 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( // If an operand adds zeros to every bit below the highest demanded bit, // that operand doesn't change the result. 
Return the other side. - computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); + llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q); if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) return I->getOperand(0); - computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); + llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q); if (DemandedFromOps.isSubsetOf(LHSKnown.Zero)) return I->getOperand(1); bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); - Known = KnownBits::computeForAddSub(/*Add*/ true, NSW, LHSKnown, RHSKnown); - computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); + bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); + Known = + KnownBits::computeForAddSub(/*Add=*/true, NSW, NUW, LHSKnown, RHSKnown); + computeKnownBitsFromContext(I, Known, Depth, Q); break; } case Instruction::Sub: { @@ -1190,19 +1244,21 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( // If an operand subtracts zeros from every bit below the highest demanded // bit, that operand doesn't change the result. Return the other side. - computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); + llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q); if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) return I->getOperand(0); bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); - computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); - Known = KnownBits::computeForAddSub(/*Add*/ false, NSW, LHSKnown, RHSKnown); - computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); + bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap(); + llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q); + Known = KnownBits::computeForAddSub(/*Add=*/false, NSW, NUW, LHSKnown, + RHSKnown); + computeKnownBitsFromContext(I, Known, Depth, Q); break; } case Instruction::AShr: { // Compute the Known bits to simplify things downstream. - computeKnownBits(I, Known, Depth, CxtI); + llvm::computeKnownBits(I, Known, Depth, Q); // If this user is only demanding bits that we know, return the known // constant. @@ -1229,7 +1285,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( } default: // Compute the Known bits to simplify things downstream. - computeKnownBits(I, Known, Depth, CxtI); + llvm::computeKnownBits(I, Known, Depth, Q); // If this user is only demanding bits that we know, return the known // constant. @@ -1844,14 +1900,16 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, Value *ShufOp = MatchShufAsOp0 ? X : Y; Value *OtherOp = MatchShufAsOp0 ? Y : X; for (User *U : OtherOp->users()) { - auto Shuf = m_Shuffle(m_Specific(ShufOp), m_Value(), m_ZeroMask()); + ArrayRef<int> Mask; + auto Shuf = m_Shuffle(m_Specific(ShufOp), m_Value(), m_Mask(Mask)); if (BO->isCommutative() ? match(U, m_c_BinOp(Opcode, Shuf, m_Specific(OtherOp))) : MatchShufAsOp0 ? match(U, m_BinOp(Opcode, Shuf, m_Specific(OtherOp))) : match(U, m_BinOp(Opcode, m_Specific(OtherOp), Shuf))) - if (DT.dominates(U, I)) - return U; + if (match(Mask, m_ZeroMask()) && Mask[0] != PoisonMaskElem) + if (DT.dominates(U, I)) + return U; } return nullptr; }; @@ -1877,3 +1935,139 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, return MadeChange ? I : nullptr; } + +/// For floating-point classes that resolve to a single bit pattern, return that +/// value. 
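In SimplifyMultipleUseDemandedBits, if one add operand is known zero in every bit at or below the highest demanded bit, the add cannot change any demanded bit and the other operand is returned as-is. Numeric illustration (not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t DemandedMask = 0x0000FFFFu; // only the low 16 bits are demanded
  const uint32_t Y = 0xABCD0000u;            // known zero in all demanded bits and below
  for (uint32_t X = 0; X < (1u << 18); X += 7)
    assert(((X + Y) & DemandedMask) == (X & DemandedMask)); // Y is invisible to the demanded bits
  return 0;
}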
+static Constant *getFPClassConstant(Type *Ty, FPClassTest Mask) { + switch (Mask) { + case fcPosZero: + return ConstantFP::getZero(Ty); + case fcNegZero: + return ConstantFP::getZero(Ty, true); + case fcPosInf: + return ConstantFP::getInfinity(Ty); + case fcNegInf: + return ConstantFP::getInfinity(Ty, true); + case fcNone: + return PoisonValue::get(Ty); + default: + return nullptr; + } +} + +Value *InstCombinerImpl::SimplifyDemandedUseFPClass( + Value *V, const FPClassTest DemandedMask, KnownFPClass &Known, + unsigned Depth, Instruction *CxtI) { + assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth"); + Type *VTy = V->getType(); + + assert(Known == KnownFPClass() && "expected uninitialized state"); + + if (DemandedMask == fcNone) + return isa<UndefValue>(V) ? nullptr : PoisonValue::get(VTy); + + if (Depth == MaxAnalysisRecursionDepth) + return nullptr; + + Instruction *I = dyn_cast<Instruction>(V); + if (!I) { + // Handle constants and arguments + Known = computeKnownFPClass(V, fcAllFlags, CxtI, Depth + 1); + Value *FoldedToConst = + getFPClassConstant(VTy, DemandedMask & Known.KnownFPClasses); + return FoldedToConst == V ? nullptr : FoldedToConst; + } + + if (!I->hasOneUse()) + return nullptr; + + // TODO: Should account for nofpclass/FastMathFlags on current instruction + switch (I->getOpcode()) { + case Instruction::FNeg: { + if (SimplifyDemandedFPClass(I, 0, llvm::fneg(DemandedMask), Known, + Depth + 1)) + return I; + Known.fneg(); + break; + } + case Instruction::Call: { + CallInst *CI = cast<CallInst>(I); + switch (CI->getIntrinsicID()) { + case Intrinsic::fabs: + if (SimplifyDemandedFPClass(I, 0, llvm::inverse_fabs(DemandedMask), Known, + Depth + 1)) + return I; + Known.fabs(); + break; + case Intrinsic::arithmetic_fence: + if (SimplifyDemandedFPClass(I, 0, DemandedMask, Known, Depth + 1)) + return I; + break; + case Intrinsic::copysign: { + // Flip on more potentially demanded classes + const FPClassTest DemandedMaskAnySign = llvm::unknown_sign(DemandedMask); + if (SimplifyDemandedFPClass(I, 0, DemandedMaskAnySign, Known, Depth + 1)) + return I; + + if ((DemandedMask & fcPositive) == fcNone) { + // Roundabout way of replacing with fneg(fabs) + I->setOperand(1, ConstantFP::get(VTy, -1.0)); + return I; + } + + if ((DemandedMask & fcNegative) == fcNone) { + // Roundabout way of replacing with fabs + I->setOperand(1, ConstantFP::getZero(VTy)); + return I; + } + + KnownFPClass KnownSign = + computeKnownFPClass(I->getOperand(1), fcAllFlags, CxtI, Depth + 1); + Known.copysign(KnownSign); + break; + } + default: + Known = computeKnownFPClass(I, ~DemandedMask, CxtI, Depth + 1); + break; + } + + break; + } + case Instruction::Select: { + KnownFPClass KnownLHS, KnownRHS; + if (SimplifyDemandedFPClass(I, 2, DemandedMask, KnownRHS, Depth + 1) || + SimplifyDemandedFPClass(I, 1, DemandedMask, KnownLHS, Depth + 1)) + return I; + + if (KnownLHS.isKnownNever(DemandedMask)) + return I->getOperand(2); + if (KnownRHS.isKnownNever(DemandedMask)) + return I->getOperand(1); + + // TODO: Recognize clamping patterns + Known = KnownLHS | KnownRHS; + break; + } + default: + Known = computeKnownFPClass(I, ~DemandedMask, CxtI, Depth + 1); + break; + } + + return getFPClassConstant(VTy, DemandedMask & Known.KnownFPClasses); +} + +bool InstCombinerImpl::SimplifyDemandedFPClass(Instruction *I, unsigned OpNo, + FPClassTest DemandedMask, + KnownFPClass &Known, + unsigned Depth) { + Use &U = I->getOperandUse(OpNo); + Value *NewVal = + SimplifyDemandedUseFPClass(U.get(), DemandedMask, Known, Depth, I); + 
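SimplifyDemandedUseFPClass handles copysign by pinning the sign operand once one sign is not demanded at all: a zero second operand makes copysign behave as fabs, and -1.0 as fneg(fabs). The fabs direction is easy to check directly (illustrative only):

#include <cassert>
#include <cmath>

int main() {
  for (double X : {0.0, -0.0, 2.5, -2.5, 1e300, -1e300}) {
    assert(std::copysign(X, 0.0) == std::fabs(X));   // positive sign operand acts as fabs
    assert(std::copysign(X, -1.0) == -std::fabs(X)); // negative sign operand acts as fneg(fabs)
  }
  return 0;
}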
if (!NewVal) + return false; + if (Instruction *OpInst = dyn_cast<Instruction>(U)) + salvageDebugInfo(*OpInst); + + replaceUse(U, NewVal); + return true; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 18ab510aae7f..753ed55523c8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -419,6 +419,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { // If extracting a specified index from the vector, see if we can recursively // find a previously computed scalar that was inserted into the vector. auto *IndexC = dyn_cast<ConstantInt>(Index); + bool HasKnownValidIndex = false; if (IndexC) { // Canonicalize type of constant indices to i64 to simplify CSE if (auto *NewIdx = getPreferredVectorIndex(IndexC)) @@ -426,6 +427,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { ElementCount EC = EI.getVectorOperandType()->getElementCount(); unsigned NumElts = EC.getKnownMinValue(); + HasKnownValidIndex = IndexC->getValue().ult(NumElts); if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(SrcVec)) { Intrinsic::ID IID = II->getIntrinsicID(); @@ -471,8 +473,11 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { return UnaryOperator::CreateWithCopiedFlags(UO->getOpcode(), E, UO); } + // If the binop is not speculatable, we cannot hoist the extractelement if + // it may make the operand poison. BinaryOperator *BO; - if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, Index)) { + if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, Index) && + (HasKnownValidIndex || isSafeToSpeculativelyExecute(BO))) { // extelt (binop X, Y), Index --> binop (extelt X, Index), (extelt Y, Index) Value *X = BO->getOperand(0), *Y = BO->getOperand(1); Value *E0 = Builder.CreateExtractElement(X, Index); @@ -487,7 +492,9 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { // extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index) Value *E0 = Builder.CreateExtractElement(X, Index); Value *E1 = Builder.CreateExtractElement(Y, Index); - return CmpInst::Create(cast<CmpInst>(SrcVec)->getOpcode(), Pred, E0, E1); + CmpInst *SrcCmpInst = cast<CmpInst>(SrcVec); + return CmpInst::CreateWithCopiedFlags(SrcCmpInst->getOpcode(), Pred, E0, E1, + SrcCmpInst); } if (auto *I = dyn_cast<Instruction>(SrcVec)) { @@ -617,7 +624,7 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS, "Invalid CollectSingleShuffleElements"); unsigned NumElts = cast<FixedVectorType>(V->getType())->getNumElements(); - if (match(V, m_Undef())) { + if (match(V, m_Poison())) { Mask.assign(NumElts, -1); return true; } @@ -1263,7 +1270,8 @@ static Instruction *foldInsSequenceIntoSplat(InsertElementInst &InsElt) { PoisonValue *PoisonVec = PoisonValue::get(VecTy); Constant *Zero = ConstantInt::get(Int64Ty, 0); if (!cast<ConstantInt>(FirstIE->getOperand(2))->isZero()) - FirstIE = InsertElementInst::Create(PoisonVec, SplatVal, Zero, "", &InsElt); + FirstIE = InsertElementInst::Create(PoisonVec, SplatVal, Zero, "", + InsElt.getIterator()); // Splat from element 0, but replace absent elements with poison in the mask. 
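As a hedged sketch of the mask shape described above (the lane choices are assumed for illustration, not taken from this patch):
 // Sketch: if the insertelement chain wrote X only to lanes 0 and 2 of a
 // 4-lane vector, the result would be
 //   shuf (inselt poison, X, 0), poison, <0, poison, 0, poison>
 // i.e. lanes the chain never wrote stay poison in the mask.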
SmallVector<int, 16> Mask(NumElements, 0); @@ -1316,7 +1324,7 @@ static Instruction *foldInsEltIntoSplat(InsertElementInst &InsElt) { static Instruction *foldInsEltIntoIdentityShuffle(InsertElementInst &InsElt) { // Check if the vector operand of this insert is an identity shuffle. auto *Shuf = dyn_cast<ShuffleVectorInst>(InsElt.getOperand(0)); - if (!Shuf || !match(Shuf->getOperand(1), m_Undef()) || + if (!Shuf || !match(Shuf->getOperand(1), m_Poison()) || !(Shuf->isIdentityWithExtract() || Shuf->isIdentityWithPadding())) return nullptr; @@ -2060,16 +2068,17 @@ static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) { case Instruction::Shl: { // shl X, C --> mul X, (1 << C) Constant *C; - if (match(BO1, m_Constant(C))) { - Constant *ShlOne = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C); + if (match(BO1, m_ImmConstant(C))) { + Constant *ShlOne = ConstantFoldBinaryOpOperands( + Instruction::Shl, ConstantInt::get(Ty, 1), C, DL); + assert(ShlOne && "Constant folding of immediate constants failed"); return {Instruction::Mul, BO0, ShlOne}; } break; } case Instruction::Or: { - // or X, C --> add X, C (when X and C have no common bits set) - const APInt *C; - if (match(BO1, m_APInt(C)) && MaskedValueIsZero(BO0, *C, DL)) + // or disjoin X, C --> add X, C + if (cast<PossiblyDisjointInst>(BO)->isDisjoint()) return {Instruction::Add, BO0, BO1}; break; } @@ -2134,7 +2143,8 @@ static Instruction *foldSelectShuffleOfSelectShuffle(ShuffleVectorInst &Shuf) { return new ShuffleVectorInst(X, Y, NewMask); } -static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { +static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf, + const SimplifyQuery &SQ) { assert(Shuf.isSelect() && "Must have select-equivalent shuffle"); // Are we shuffling together some value and that same value after it has been @@ -2158,6 +2168,19 @@ static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { if (!IdC) return nullptr; + Value *X = Op0IsBinop ? Op1 : Op0; + + // Prevent folding in the case the non-binop operand might have NaN values. + // If X can have NaN elements then we have that the floating point math + // operation in the transformed code may not preserve the exact NaN + // bit-pattern -- e.g. `fadd sNaN, 0.0 -> qNaN`. + // This makes the transformation incorrect since the original program would + // have preserved the exact NaN bit-pattern. + // Avoid the folding if X can have NaN elements. + if (Shuf.getType()->getElementType()->isFloatingPointTy() && + !isKnownNeverNaN(X, 0, SQ)) + return nullptr; + // Shuffle identity constants into the lanes that return the original value. // Example: shuf (mul X, {-1,-2,-3,-4}), X, {0,5,6,3} --> mul X, {-1,1,1,-4} // Example: shuf X, (add X, {-1,-2,-3,-4}), {0,1,6,7} --> add X, {0,0,-3,-4} @@ -2174,7 +2197,6 @@ static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { // shuf (bop X, C), X, M --> bop X, C' // shuf X, (bop X, C), M --> bop X, C' - Value *X = Op0IsBinop ? Op1 : Op0; Instruction *NewBO = BinaryOperator::Create(BOpcode, X, NewC); NewBO->copyIRFlags(BO); @@ -2198,19 +2220,19 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, uint64_t IndexC; // Match a shuffle that is a splat to a non-zero element. 
- if (!match(Op0, m_OneUse(m_InsertElt(m_Undef(), m_Value(X), + if (!match(Op0, m_OneUse(m_InsertElt(m_Poison(), m_Value(X), m_ConstantInt(IndexC)))) || - !match(Op1, m_Undef()) || match(Mask, m_ZeroMask()) || IndexC == 0) + !match(Op1, m_Poison()) || match(Mask, m_ZeroMask()) || IndexC == 0) return nullptr; // Insert into element 0 of a poison vector. PoisonValue *PoisonVec = PoisonValue::get(Shuf.getType()); Value *NewIns = Builder.CreateInsertElement(PoisonVec, X, (uint64_t)0); - // Splat from element 0. Any mask element that is undefined remains undefined. + // Splat from element 0. Any mask element that is poison remains poison. // For example: - // shuf (inselt undef, X, 2), _, <2,2,undef> - // --> shuf (inselt undef, X, 0), poison, <0,0,undef> + // shuf (inselt poison, X, 2), _, <2,2,undef> + // --> shuf (inselt poison, X, 0), poison, <0,0,undef> unsigned NumMaskElts = cast<FixedVectorType>(Shuf.getType())->getNumElements(); SmallVector<int, 16> NewMask(NumMaskElts, 0); @@ -2240,7 +2262,8 @@ Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) { if (Instruction *I = foldSelectShuffleOfSelectShuffle(Shuf)) return I; - if (Instruction *I = foldSelectShuffleWith1Binop(Shuf)) + if (Instruction *I = foldSelectShuffleWith1Binop( + Shuf, getSimplifyQuery().getWithInstruction(&Shuf))) return I; BinaryOperator *B0, *B1; @@ -2366,7 +2389,7 @@ static Instruction *foldTruncShuffle(ShuffleVectorInst &Shuf, Type *DestType = Shuf.getType(); Value *X; if (!match(Shuf.getOperand(0), m_BitCast(m_Value(X))) || - !match(Shuf.getOperand(1), m_Undef()) || !DestType->isIntOrIntVectorTy()) + !match(Shuf.getOperand(1), m_Poison()) || !DestType->isIntOrIntVectorTy()) return nullptr; // The source type must have the same number of elements as the shuffle, @@ -2399,13 +2422,13 @@ static Instruction *foldTruncShuffle(ShuffleVectorInst &Shuf, } /// Match a shuffle-select-shuffle pattern where the shuffles are widening and -/// narrowing (concatenating with undef and extracting back to the original +/// narrowing (concatenating with poison and extracting back to the original /// length). This allows replacing the wide select with a narrow select. static Instruction *narrowVectorSelect(ShuffleVectorInst &Shuf, InstCombiner::BuilderTy &Builder) { // This must be a narrowing identity shuffle. It extracts the 1st N elements // of the 1st vector operand of a shuffle. - if (!match(Shuf.getOperand(1), m_Undef()) || !Shuf.isIdentityWithExtract()) + if (!match(Shuf.getOperand(1), m_Poison()) || !Shuf.isIdentityWithExtract()) return nullptr; // The vector being shuffled must be a vector select that we can eliminate. @@ -2415,19 +2438,20 @@ static Instruction *narrowVectorSelect(ShuffleVectorInst &Shuf, m_OneUse(m_Select(m_Value(Cond), m_Value(X), m_Value(Y))))) return nullptr; - // We need a narrow condition value. It must be extended with undef elements + // We need a narrow condition value. It must be extended with poison elements // and have the same number of elements as this shuffle. 
unsigned NarrowNumElts = cast<FixedVectorType>(Shuf.getType())->getNumElements(); Value *NarrowCond; - if (!match(Cond, m_OneUse(m_Shuffle(m_Value(NarrowCond), m_Undef()))) || + if (!match(Cond, m_OneUse(m_Shuffle(m_Value(NarrowCond), m_Poison()))) || cast<FixedVectorType>(NarrowCond->getType())->getNumElements() != NarrowNumElts || !cast<ShuffleVectorInst>(Cond)->isIdentityWithPadding()) return nullptr; - // shuf (sel (shuf NarrowCond, undef, WideMask), X, Y), undef, NarrowMask) --> - // sel NarrowCond, (shuf X, undef, NarrowMask), (shuf Y, undef, NarrowMask) + // shuf (sel (shuf NarrowCond, poison, WideMask), X, Y), poison, NarrowMask) + // --> + // sel NarrowCond, (shuf X, poison, NarrowMask), (shuf Y, poison, NarrowMask) Value *NarrowX = Builder.CreateShuffleVector(X, Shuf.getShuffleMask()); Value *NarrowY = Builder.CreateShuffleVector(Y, Shuf.getShuffleMask()); return SelectInst::Create(NarrowCond, NarrowX, NarrowY); @@ -2445,7 +2469,7 @@ static Instruction *foldShuffleOfUnaryOps(ShuffleVectorInst &Shuf, // Match 1-input (unary) shuffle. // shuffle (fneg/fabs X), Mask --> fneg/fabs (shuffle X, Mask) - if (S0->hasOneUse() && match(Shuf.getOperand(1), m_Undef())) { + if (S0->hasOneUse() && match(Shuf.getOperand(1), m_Poison())) { Value *NewShuf = Builder.CreateShuffleVector(X, Shuf.getShuffleMask()); if (IsFNeg) return UnaryOperator::CreateFNegFMF(NewShuf, S0); @@ -2532,7 +2556,7 @@ static Instruction *foldCastShuffle(ShuffleVectorInst &Shuf, /// Try to fold an extract subvector operation. static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) { Value *Op0 = Shuf.getOperand(0), *Op1 = Shuf.getOperand(1); - if (!Shuf.isIdentityWithExtract() || !match(Op1, m_Undef())) + if (!Shuf.isIdentityWithExtract() || !match(Op1, m_Poison())) return nullptr; // Check if we are extracting all bits of an inserted scalar: @@ -2561,10 +2585,10 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) { // not allow arbitrary shuffle mask creation as a target-independent transform // (because we can't guarantee that will lower efficiently). // - // If the extracting shuffle has an undef mask element, it transfers to the + // If the extracting shuffle has an poison mask element, it transfers to the // new shuffle mask. Otherwise, copy the original mask element. 
Example: - // shuf (shuf X, Y, <C0, C1, C2, undef, C4>), undef, <0, undef, 2, 3> --> - // shuf X, Y, <C0, undef, C2, undef> + // shuf (shuf X, Y, <C0, C1, C2, poison, C4>), poison, <0, poison, 2, 3> --> + // shuf X, Y, <C0, poison, C2, poison> unsigned NumElts = cast<FixedVectorType>(Shuf.getType())->getNumElements(); SmallVector<int, 16> NewMask(NumElts); assert(NumElts < Mask.size() && @@ -2738,17 +2762,17 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) { // BinOp's operands are the result of a first element splat can be simplified to // splatting the first element of the result of the BinOp Instruction *InstCombinerImpl::simplifyBinOpSplats(ShuffleVectorInst &SVI) { - if (!match(SVI.getOperand(1), m_Undef()) || + if (!match(SVI.getOperand(1), m_Poison()) || !match(SVI.getShuffleMask(), m_ZeroMask()) || !SVI.getOperand(0)->hasOneUse()) return nullptr; Value *Op0 = SVI.getOperand(0); Value *X, *Y; - if (!match(Op0, m_BinOp(m_Shuffle(m_Value(X), m_Undef(), m_ZeroMask()), + if (!match(Op0, m_BinOp(m_Shuffle(m_Value(X), m_Poison(), m_ZeroMask()), m_Value(Y))) && !match(Op0, m_BinOp(m_Value(X), - m_Shuffle(m_Value(Y), m_Undef(), m_ZeroMask())))) + m_Shuffle(m_Value(Y), m_Poison(), m_ZeroMask())))) return nullptr; if (X->getType() != Y->getType()) return nullptr; @@ -2818,15 +2842,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { auto *XType = cast<FixedVectorType>(X->getType()); unsigned XNumElts = XType->getNumElements(); SmallVector<int, 16> ScaledMask; - if (XNumElts >= VWidth) { - assert(XNumElts % VWidth == 0 && "Unexpected vector bitcast"); - narrowShuffleMaskElts(XNumElts / VWidth, Mask, ScaledMask); - } else { - assert(VWidth % XNumElts == 0 && "Unexpected vector bitcast"); - if (!widenShuffleMaskElts(VWidth / XNumElts, Mask, ScaledMask)) - ScaledMask.clear(); - } - if (!ScaledMask.empty()) { + if (scaleShuffleMaskElts(XNumElts, Mask, ScaledMask)) { // If the shuffled source vector simplifies, cast that value to this // shuffle's type. 
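A small assumed example of the scaled-mask path above (types and mask chosen for illustration):
 // Sketch: shuf (bitcast <2 x i64> %x to <4 x i32>), poison, <0,1,2,3>
 //   scales the i32-level mask <0,1,2,3> down to the i64-level mask <0,1>;
 //   the inner shuffle then simplifies to %x, and the whole shuffle becomes
 //   bitcast <2 x i64> %x to <4 x i32>.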
if (auto *V = simplifyShuffleVectorInst(X, UndefValue::get(XType), @@ -2884,7 +2900,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (Instruction *I = foldIdentityPaddedShuffles(SVI)) return I; - if (match(RHS, m_Undef()) && canEvaluateShuffled(LHS, Mask)) { + if (match(RHS, m_Poison()) && canEvaluateShuffled(LHS, Mask)) { Value *V = evaluateInDifferentElementOrder(LHS, Mask, Builder); return replaceInstUsesWith(SVI, V); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 6f0cf9d9c8f1..0d8e7e92c5c8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -190,8 +190,26 @@ bool InstCombiner::isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const { return TTI.isValidAddrSpaceCast(FromAS, ToAS); } -Value *InstCombinerImpl::EmitGEPOffset(User *GEP) { - return llvm::emitGEPOffset(&Builder, DL, GEP); +Value *InstCombinerImpl::EmitGEPOffset(GEPOperator *GEP, bool RewriteGEP) { + if (!RewriteGEP) + return llvm::emitGEPOffset(&Builder, DL, GEP); + + IRBuilderBase::InsertPointGuard Guard(Builder); + auto *Inst = dyn_cast<Instruction>(GEP); + if (Inst) + Builder.SetInsertPoint(Inst); + + Value *Offset = EmitGEPOffset(GEP); + // If a non-trivial GEP has other uses, rewrite it to avoid duplicating + // the offset arithmetic. + if (Inst && !GEP->hasOneUse() && !GEP->hasAllConstantIndices() && + !GEP->getSourceElementType()->isIntegerTy(8)) { + replaceInstUsesWith( + *Inst, Builder.CreateGEP(Builder.getInt8Ty(), GEP->getPointerOperand(), + Offset, "", GEP->getNoWrapFlags())); + eraseInstFromFunction(*Inst); + } + return Offset; } /// Legal integers and common types are considered desirable. This is used to @@ -624,9 +642,11 @@ getBinOpsForFactorization(Instruction::BinaryOps TopOpcode, BinaryOperator *Op, RHS = Op->getOperand(1); if (TopOpcode == Instruction::Add || TopOpcode == Instruction::Sub) { Constant *C; - if (match(Op, m_Shl(m_Value(), m_Constant(C)))) { + if (match(Op, m_Shl(m_Value(), m_ImmConstant(C)))) { // X << C --> X * (1 << C) - RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), C); + RHS = ConstantFoldBinaryInstruction( + Instruction::Shl, ConstantInt::get(Op->getType(), 1), C); + assert(RHS && "Constant folding of immediate constants failed"); return Instruction::Mul; } // TODO: We can add other conversions e.g. shr => div etc. @@ -790,9 +810,12 @@ Instruction *InstCombinerImpl::tryFoldInstWithCtpopWithNot(Instruction *I) { Constant *BitWidthC = ConstantInt::get(Ty, Ty->getScalarSizeInBits()); // Need extra check for icmp. Note if this check is true, it generally means // the icmp will simplify to true/false. - if (Opc == Instruction::ICmp && !cast<ICmpInst>(I)->isEquality() && - !ConstantExpr::getICmp(ICmpInst::ICMP_UGT, C, BitWidthC)->isZeroValue()) - return nullptr; + if (Opc == Instruction::ICmp && !cast<ICmpInst>(I)->isEquality()) { + Constant *Cmp = + ConstantFoldCompareInstOperands(ICmpInst::ICMP_UGT, C, BitWidthC, DL); + if (!Cmp || !Cmp->isZeroValue()) + return nullptr; + } // Check we can invert `(not x)` for free. 
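For intuition, a sketch of the identity this fold builds on (the concrete icmp is an assumed example):
 // Sketch: ctpop(~x) == BitWidth - ctpop(x), so for i8 e.g.
 //   icmp eq (ctpop (not x)), 0  -->  icmp eq (ctpop x), 8
 // provided the `not` really is free to invert, which is what the check
 // below establishes.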
bool Consumes = false; @@ -851,7 +874,7 @@ Instruction *InstCombinerImpl::tryFoldInstWithCtpopWithNot(Instruction *I) { // -> (arithmetic_shift Binop1((not X), Y), Amt) Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) { - const DataLayout &DL = I.getModule()->getDataLayout(); + const DataLayout &DL = I.getDataLayout(); auto IsValidBinOpc = [](unsigned Opc) { switch (Opc) { default: @@ -1347,9 +1370,13 @@ void InstCombinerImpl::freelyInvertAllUsersOf(Value *I, Value *IgnoredUser) { SI->swapProfMetadata(); break; } - case Instruction::Br: - cast<BranchInst>(U)->swapSuccessors(); // swaps prof metadata too + case Instruction::Br: { + BranchInst *BI = cast<BranchInst>(U); + BI->swapSuccessors(); // swaps prof metadata too + if (BPI) + BPI->swapSuccEdgesProbabilities(BI->getParent()); break; + } case Instruction::Xor: replaceInstUsesWith(cast<Instruction>(*U), I); // Add to worklist for DCE. @@ -1401,6 +1428,201 @@ Value *InstCombinerImpl::dyn_castNegVal(Value *V) const { return nullptr; } +// Try to fold: +// 1) (fp_binop ({s|u}itofp x), ({s|u}itofp y)) +// -> ({s|u}itofp (int_binop x, y)) +// 2) (fp_binop ({s|u}itofp x), FpC) +// -> ({s|u}itofp (int_binop x, (fpto{s|u}i FpC))) +// +// Assuming the sign of the cast for x/y is `OpsFromSigned`. +Instruction *InstCombinerImpl::foldFBinOpOfIntCastsFromSign( + BinaryOperator &BO, bool OpsFromSigned, std::array<Value *, 2> IntOps, + Constant *Op1FpC, SmallVectorImpl<WithCache<const Value *>> &OpsKnown) { + + Type *FPTy = BO.getType(); + Type *IntTy = IntOps[0]->getType(); + + unsigned IntSz = IntTy->getScalarSizeInBits(); + // This is the maximum number of inuse bits by the integer where the int -> fp + // casts are exact. + unsigned MaxRepresentableBits = + APFloat::semanticsPrecision(FPTy->getScalarType()->getFltSemantics()); + + // Preserve known number of leading bits. This can allow us to trivial nsw/nuw + // checks later on. + unsigned NumUsedLeadingBits[2] = {IntSz, IntSz}; + + // NB: This only comes up if OpsFromSigned is true, so there is no need to + // cache if between calls to `foldFBinOpOfIntCastsFromSign`. + auto IsNonZero = [&](unsigned OpNo) -> bool { + if (OpsKnown[OpNo].hasKnownBits() && + OpsKnown[OpNo].getKnownBits(SQ).isNonZero()) + return true; + return isKnownNonZero(IntOps[OpNo], SQ); + }; + + auto IsNonNeg = [&](unsigned OpNo) -> bool { + // NB: This matches the impl in ValueTracking, we just try to use cached + // knownbits here. If we ever start supporting WithCache for + // `isKnownNonNegative`, change this to an explicit call. + return OpsKnown[OpNo].getKnownBits(SQ).isNonNegative(); + }; + + // Check if we know for certain that ({s|u}itofp op) is exact. + auto IsValidPromotion = [&](unsigned OpNo) -> bool { + // Can we treat this operand as the desired sign? + if (OpsFromSigned != isa<SIToFPInst>(BO.getOperand(OpNo)) && + !IsNonNeg(OpNo)) + return false; + + // If fp precision >= bitwidth(op) then its exact. + // NB: This is slightly conservative for `sitofp`. For signed conversion, we + // can handle `MaxRepresentableBits == IntSz - 1` as the sign bit will be + // handled specially. We can't, however, increase the bound arbitrarily for + // `sitofp` as for larger sizes, it won't sign extend. + if (MaxRepresentableBits < IntSz) { + // Otherwise if its signed cast check that fp precisions >= bitwidth(op) - + // numSignBits(op). + // TODO: If we add support for `WithCache` in `ComputeNumSignBits`, change + // `IntOps[OpNo]` arguments to `KnownOps[OpNo]`. 
+ if (OpsFromSigned) + NumUsedLeadingBits[OpNo] = IntSz - ComputeNumSignBits(IntOps[OpNo]); + // Finally for unsigned check that fp precision >= bitwidth(op) - + // numLeadingZeros(op). + else { + NumUsedLeadingBits[OpNo] = + IntSz - OpsKnown[OpNo].getKnownBits(SQ).countMinLeadingZeros(); + } + } + // NB: We could also check if op is known to be a power of 2 or zero (which + // will always be representable). Its unlikely, however, that is we are + // unable to bound op in any way we will be able to pass the overflow checks + // later on. + + if (MaxRepresentableBits < NumUsedLeadingBits[OpNo]) + return false; + // Signed + Mul also requires that op is non-zero to avoid -0 cases. + return !OpsFromSigned || BO.getOpcode() != Instruction::FMul || + IsNonZero(OpNo); + }; + + // If we have a constant rhs, see if we can losslessly convert it to an int. + if (Op1FpC != nullptr) { + // Signed + Mul req non-zero + if (OpsFromSigned && BO.getOpcode() == Instruction::FMul && + !match(Op1FpC, m_NonZeroFP())) + return nullptr; + + Constant *Op1IntC = ConstantFoldCastOperand( + OpsFromSigned ? Instruction::FPToSI : Instruction::FPToUI, Op1FpC, + IntTy, DL); + if (Op1IntC == nullptr) + return nullptr; + if (ConstantFoldCastOperand(OpsFromSigned ? Instruction::SIToFP + : Instruction::UIToFP, + Op1IntC, FPTy, DL) != Op1FpC) + return nullptr; + + // First try to keep sign of cast the same. + IntOps[1] = Op1IntC; + } + + // Ensure lhs/rhs integer types match. + if (IntTy != IntOps[1]->getType()) + return nullptr; + + if (Op1FpC == nullptr) { + if (!IsValidPromotion(1)) + return nullptr; + } + if (!IsValidPromotion(0)) + return nullptr; + + // Final we check if the integer version of the binop will not overflow. + BinaryOperator::BinaryOps IntOpc; + // Because of the precision check, we can often rule out overflows. + bool NeedsOverflowCheck = true; + // Try to conservatively rule out overflow based on the already done precision + // checks. + unsigned OverflowMaxOutputBits = OpsFromSigned ? 2 : 1; + unsigned OverflowMaxCurBits = + std::max(NumUsedLeadingBits[0], NumUsedLeadingBits[1]); + bool OutputSigned = OpsFromSigned; + switch (BO.getOpcode()) { + case Instruction::FAdd: + IntOpc = Instruction::Add; + OverflowMaxOutputBits += OverflowMaxCurBits; + break; + case Instruction::FSub: + IntOpc = Instruction::Sub; + OverflowMaxOutputBits += OverflowMaxCurBits; + break; + case Instruction::FMul: + IntOpc = Instruction::Mul; + OverflowMaxOutputBits += OverflowMaxCurBits * 2; + break; + default: + llvm_unreachable("Unsupported binop"); + } + // The precision check may have already ruled out overflow. + if (OverflowMaxOutputBits < IntSz) { + NeedsOverflowCheck = false; + // We can bound unsigned overflow from sub to in range signed value (this is + // what allows us to avoid the overflow check for sub). + if (IntOpc == Instruction::Sub) + OutputSigned = true; + } + + // Precision check did not rule out overflow, so need to check. + // TODO: If we add support for `WithCache` in `willNotOverflow`, change + // `IntOps[...]` arguments to `KnownOps[...]`. 
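A worked instance of the bound above, with the operand widths assumed for illustration (IEEE float carries 24 bits of precision):
 // Sketch: for `fadd float (uitofp i32 %x to float), (uitofp i32 %y to float)`
 // where %x and %y are each known to fit in 20 bits, both casts are exact
 // (20 <= 24) and the sum needs at most 21 bits, so 21 < 32 rules out
 // overflow and the fold can emit
 //   uitofp (add nuw i32 %x, %y) to float
 // without querying willNotOverflow().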
+ if (NeedsOverflowCheck && + !willNotOverflow(IntOpc, IntOps[0], IntOps[1], BO, OutputSigned)) + return nullptr; + + Value *IntBinOp = Builder.CreateBinOp(IntOpc, IntOps[0], IntOps[1]); + if (auto *IntBO = dyn_cast<BinaryOperator>(IntBinOp)) { + IntBO->setHasNoSignedWrap(OutputSigned); + IntBO->setHasNoUnsignedWrap(!OutputSigned); + } + if (OutputSigned) + return new SIToFPInst(IntBinOp, FPTy); + return new UIToFPInst(IntBinOp, FPTy); +} + +// Try to fold: +// 1) (fp_binop ({s|u}itofp x), ({s|u}itofp y)) +// -> ({s|u}itofp (int_binop x, y)) +// 2) (fp_binop ({s|u}itofp x), FpC) +// -> ({s|u}itofp (int_binop x, (fpto{s|u}i FpC))) +Instruction *InstCombinerImpl::foldFBinOpOfIntCasts(BinaryOperator &BO) { + std::array<Value *, 2> IntOps = {nullptr, nullptr}; + Constant *Op1FpC = nullptr; + // Check for: + // 1) (binop ({s|u}itofp x), ({s|u}itofp y)) + // 2) (binop ({s|u}itofp x), FpC) + if (!match(BO.getOperand(0), m_SIToFP(m_Value(IntOps[0]))) && + !match(BO.getOperand(0), m_UIToFP(m_Value(IntOps[0])))) + return nullptr; + + if (!match(BO.getOperand(1), m_Constant(Op1FpC)) && + !match(BO.getOperand(1), m_SIToFP(m_Value(IntOps[1]))) && + !match(BO.getOperand(1), m_UIToFP(m_Value(IntOps[1])))) + return nullptr; + + // Cache KnownBits a bit to potentially save some analysis. + SmallVector<WithCache<const Value *>, 2> OpsKnown = {IntOps[0], IntOps[1]}; + + // Try treating x/y as coming from both `uitofp` and `sitofp`. There are + // different constraints depending on the sign of the cast. + // NB: `(uitofp nneg X)` == `(sitofp nneg X)`. + if (Instruction *R = foldFBinOpOfIntCastsFromSign(BO, /*OpsFromSigned=*/false, + IntOps, Op1FpC, OpsKnown)) + return R; + return foldFBinOpOfIntCastsFromSign(BO, /*OpsFromSigned=*/true, IntOps, + Op1FpC, OpsKnown); +} + /// A binop with a constant operand and a sign-extended boolean operand may be /// converted into a select of constants by applying the binary operation to /// the constant with the two possible values of the extended boolean (0 or -1). @@ -1448,7 +1670,7 @@ static Constant *constantFoldOperationIntoSelectOperand(Instruction &I, ConstOps.push_back(C); } - return ConstantFoldInstOperands(&I, ConstOps, I.getModule()->getDataLayout()); + return ConstantFoldInstOperands(&I, ConstOps, I.getDataLayout()); } static Value *foldOperationIntoSelectOperand(Instruction &I, SelectInst *SI, @@ -1475,21 +1697,6 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, if (SI->getType()->isIntOrIntVectorTy(1)) return nullptr; - // If it's a bitcast involving vectors, make sure it has the same number of - // elements on both sides. - if (auto *BC = dyn_cast<BitCastInst>(&Op)) { - VectorType *DestTy = dyn_cast<VectorType>(BC->getDestTy()); - VectorType *SrcTy = dyn_cast<VectorType>(BC->getSrcTy()); - - // Verify that either both or neither are vectors. - if ((SrcTy == nullptr) != (DestTy == nullptr)) - return nullptr; - - // If vectors, verify that they have the same number of elements. - if (SrcTy && SrcTy->getElementCount() != DestTy->getElementCount()) - return nullptr; - } - // Test if a FCmpInst instruction is used exclusively by a select as // part of a minimum or maximum operation. If so, refrain from doing // any other folding. 
This helps out other analyses which understand @@ -1841,8 +2048,8 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { if (auto *BO = dyn_cast<BinaryOperator>(V)) BO->copyIRFlags(&Inst); Module *M = Inst.getModule(); - Function *F = Intrinsic::getDeclaration( - M, Intrinsic::experimental_vector_reverse, V->getType()); + Function *F = + Intrinsic::getDeclaration(M, Intrinsic::vector_reverse, V->getType()); return CallInst::Create(F, V); }; @@ -2010,7 +2217,7 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { Value *Y, *OtherOp; if (!match(LHS, m_OneUse(m_Shuffle(m_Value(X), m_Undef(), m_Mask(MaskC)))) || - !match(MaskC, m_SplatOrUndefMask(SplatIndex)) || + !match(MaskC, m_SplatOrPoisonMask(SplatIndex)) || X->getType() != Inst.getType() || !match(RHS, m_OneUse(m_BinOp(Opcode, m_Value(Y), m_Value(OtherOp))))) return nullptr; @@ -2103,12 +2310,7 @@ Instruction *InstCombinerImpl::narrowMathIfNoOverflow(BinaryOperator &BO) { } static bool isMergedGEPInBounds(GEPOperator &GEP1, GEPOperator &GEP2) { - // At least one GEP must be inbounds. - if (!GEP1.isInBounds() && !GEP2.isInBounds()) - return false; - - return (GEP1.isInBounds() || GEP1.hasAllZeroIndices()) && - (GEP2.isInBounds() || GEP2.hasAllZeroIndices()); + return GEP1.isInBounds() && GEP2.isInBounds(); } /// Thread a GEP operation with constant indices through the constant true/false @@ -2130,13 +2332,50 @@ static Instruction *foldSelectGEP(GetElementPtrInst &GEP, // Propagate 'inbounds' and metadata from existing instructions. // Note: using IRBuilder to create the constants for efficiency. SmallVector<Value *, 4> IndexC(GEP.indices()); - bool IsInBounds = GEP.isInBounds(); + GEPNoWrapFlags NW = GEP.getNoWrapFlags(); Type *Ty = GEP.getSourceElementType(); - Value *NewTrueC = Builder.CreateGEP(Ty, TrueC, IndexC, "", IsInBounds); - Value *NewFalseC = Builder.CreateGEP(Ty, FalseC, IndexC, "", IsInBounds); + Value *NewTrueC = Builder.CreateGEP(Ty, TrueC, IndexC, "", NW); + Value *NewFalseC = Builder.CreateGEP(Ty, FalseC, IndexC, "", NW); return SelectInst::Create(Cond, NewTrueC, NewFalseC, "", nullptr, Sel); } +// Canonicalization: +// gep T, (gep i8, base, C1), (Index + C2) into +// gep T, (gep i8, base, C1 + C2 * sizeof(T)), Index +static Instruction *canonicalizeGEPOfConstGEPI8(GetElementPtrInst &GEP, + GEPOperator *Src, + InstCombinerImpl &IC) { + if (GEP.getNumIndices() != 1) + return nullptr; + auto &DL = IC.getDataLayout(); + Value *Base; + const APInt *C1; + if (!match(Src, m_PtrAdd(m_Value(Base), m_APInt(C1)))) + return nullptr; + Value *VarIndex; + const APInt *C2; + Type *PtrTy = Src->getType()->getScalarType(); + unsigned IndexSizeInBits = DL.getIndexTypeSizeInBits(PtrTy); + if (!match(GEP.getOperand(1), m_AddLike(m_Value(VarIndex), m_APInt(C2)))) + return nullptr; + if (C1->getBitWidth() != IndexSizeInBits || + C2->getBitWidth() != IndexSizeInBits) + return nullptr; + Type *BaseType = GEP.getSourceElementType(); + if (isa<ScalableVectorType>(BaseType)) + return nullptr; + APInt TypeSize(IndexSizeInBits, DL.getTypeAllocSize(BaseType)); + APInt NewOffset = TypeSize * *C2 + *C1; + if (NewOffset.isZero() || + (Src->hasOneUse() && GEP.getOperand(1)->hasOneUse())) { + Value *GEPConst = + IC.Builder.CreatePtrAdd(Base, IC.Builder.getInt(NewOffset)); + return GetElementPtrInst::Create(BaseType, GEPConst, VarIndex); + } + + return nullptr; +} + Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, GEPOperator *Src) { // Combine Indices - If the source pointer to this getelementptr 
instruction @@ -2145,6 +2384,9 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, if (!shouldMergeGEPs(*cast<GEPOperator>(&GEP), *Src)) return nullptr; + if (auto *I = canonicalizeGEPOfConstGEPI8(GEP, Src, *this)) + return I; + // For constant GEPs, use a more general offset-based folding approach. Type *PtrTy = Src->getType()->getScalarType(); if (GEP.hasAllConstantIndices() && @@ -2357,11 +2599,13 @@ Value *InstCombiner::getFreelyInvertedImpl(Value *V, bool WillInvertAllUses, !shouldAvoidAbsorbingNotIntoSelect(*cast<SelectInst>(V)); // Selects/min/max with invertible operands are freely invertible if (IsSelect || match(V, m_MaxOrMin(m_Value(A), m_Value(B)))) { + bool LocalDoesConsume = DoesConsume; if (!getFreelyInvertedImpl(B, B->hasOneUse(), /*Builder*/ nullptr, - DoesConsume, Depth)) + LocalDoesConsume, Depth)) return nullptr; if (Value *NotA = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, - DoesConsume, Depth)) { + LocalDoesConsume, Depth)) { + DoesConsume = LocalDoesConsume; if (Builder != nullptr) { Value *NotB = getFreelyInvertedImpl(B, B->hasOneUse(), Builder, DoesConsume, Depth); @@ -2376,6 +2620,89 @@ Value *InstCombiner::getFreelyInvertedImpl(Value *V, bool WillInvertAllUses, } } + if (PHINode *PN = dyn_cast<PHINode>(V)) { + bool LocalDoesConsume = DoesConsume; + SmallVector<std::pair<Value *, BasicBlock *>, 8> IncomingValues; + for (Use &U : PN->operands()) { + BasicBlock *IncomingBlock = PN->getIncomingBlock(U); + Value *NewIncomingVal = getFreelyInvertedImpl( + U.get(), /*WillInvertAllUses=*/false, + /*Builder=*/nullptr, LocalDoesConsume, MaxAnalysisRecursionDepth - 1); + if (NewIncomingVal == nullptr) + return nullptr; + // Make sure that we can safely erase the original PHI node. + if (NewIncomingVal == V) + return nullptr; + if (Builder != nullptr) + IncomingValues.emplace_back(NewIncomingVal, IncomingBlock); + } + + DoesConsume = LocalDoesConsume; + if (Builder != nullptr) { + IRBuilderBase::InsertPointGuard Guard(*Builder); + Builder->SetInsertPoint(PN); + PHINode *NewPN = + Builder->CreatePHI(PN->getType(), PN->getNumIncomingValues()); + for (auto [Val, Pred] : IncomingValues) + NewPN->addIncoming(Val, Pred); + return NewPN; + } + return NonNull; + } + + if (match(V, m_SExtLike(m_Value(A)))) { + if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? Builder->CreateSExt(AV, V->getType()) : NonNull; + return nullptr; + } + + if (match(V, m_Trunc(m_Value(A)))) { + if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? Builder->CreateTrunc(AV, V->getType()) : NonNull; + return nullptr; + } + + // De Morgan's Laws: + // (~(A | B)) -> (~A & ~B) + // (~(A & B)) -> (~A | ~B) + auto TryInvertAndOrUsingDeMorgan = [&](Instruction::BinaryOps Opcode, + bool IsLogical, Value *A, + Value *B) -> Value * { + bool LocalDoesConsume = DoesConsume; + if (!getFreelyInvertedImpl(B, B->hasOneUse(), /*Builder=*/nullptr, + LocalDoesConsume, Depth)) + return nullptr; + if (auto *NotA = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + LocalDoesConsume, Depth)) { + auto *NotB = getFreelyInvertedImpl(B, B->hasOneUse(), Builder, + LocalDoesConsume, Depth); + DoesConsume = LocalDoesConsume; + if (IsLogical) + return Builder ? Builder->CreateLogicalOp(Opcode, NotA, NotB) : NonNull; + return Builder ? 
Builder->CreateBinOp(Opcode, NotA, NotB) : NonNull; + } + + return nullptr; + }; + + if (match(V, m_Or(m_Value(A), m_Value(B)))) + return TryInvertAndOrUsingDeMorgan(Instruction::And, /*IsLogical=*/false, A, + B); + + if (match(V, m_And(m_Value(A), m_Value(B)))) + return TryInvertAndOrUsingDeMorgan(Instruction::Or, /*IsLogical=*/false, A, + B); + + if (match(V, m_LogicalOr(m_Value(A), m_Value(B)))) + return TryInvertAndOrUsingDeMorgan(Instruction::And, /*IsLogical=*/true, A, + B); + + if (match(V, m_LogicalAnd(m_Value(A), m_Value(B)))) + return TryInvertAndOrUsingDeMorgan(Instruction::Or, /*IsLogical=*/true, A, + B); + return nullptr; } @@ -2384,9 +2711,9 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { SmallVector<Value *, 8> Indices(GEP.indices()); Type *GEPType = GEP.getType(); Type *GEPEltType = GEP.getSourceElementType(); - bool IsGEPSrcEleScalable = GEPEltType->isScalableTy(); - if (Value *V = simplifyGEPInst(GEPEltType, PtrOp, Indices, GEP.isInBounds(), - SQ.getWithInstruction(&GEP))) + if (Value *V = + simplifyGEPInst(GEPEltType, PtrOp, Indices, GEP.getNoWrapFlags(), + SQ.getWithInstruction(&GEP))) return replaceInstUsesWith(GEP, V); // For vector geps, use the generic demanded vector support. @@ -2451,6 +2778,30 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (MadeChange) return &GEP; + // Canonicalize constant GEPs to i8 type. + if (!GEPEltType->isIntegerTy(8) && GEP.hasAllConstantIndices()) { + APInt Offset(DL.getIndexTypeSizeInBits(GEPType), 0); + if (GEP.accumulateConstantOffset(DL, Offset)) + return replaceInstUsesWith( + GEP, Builder.CreatePtrAdd(PtrOp, Builder.getInt(Offset), "", + GEP.getNoWrapFlags())); + } + + // Canonicalize + // - scalable GEPs to an explicit offset using the llvm.vscale intrinsic. + // This has better support in BasicAA. + // - gep i32 p, mul(O, C) -> gep i8, p, mul(O, C*4) to fold the two + // multiplies together. + if (GEPEltType->isScalableTy() || + (!GEPEltType->isIntegerTy(8) && GEP.getNumIndices() == 1 && + match(GEP.getOperand(1), + m_OneUse(m_CombineOr(m_Mul(m_Value(), m_ConstantInt()), + m_Shl(m_Value(), m_ConstantInt())))))) { + Value *Offset = EmitGEPOffset(cast<GEPOperator>(&GEP)); + return replaceInstUsesWith( + GEP, Builder.CreatePtrAdd(PtrOp, Offset, "", GEP.getNoWrapFlags())); + } + // Check to see if the inputs to the PHI node are getelementptr instructions. if (auto *PN = dyn_cast<PHINode>(PtrOp)) { auto *Op1 = dyn_cast<GetElementPtrInst>(PN->getOperand(0)); @@ -2560,9 +2911,7 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (Instruction *I = visitGEPOfGEP(GEP, Src)) return I; - // Skip if GEP source element type is scalable. The type alloc size is unknown - // at compile-time. - if (GEP.getNumIndices() == 1 && !IsGEPSrcEleScalable) { + if (GEP.getNumIndices() == 1) { unsigned AS = GEP.getPointerAddressSpace(); if (GEP.getOperand(1)->getType()->getScalarSizeInBits() == DL.getIndexSizeInBits(AS)) { @@ -2590,19 +2939,57 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { }); return Changed ? 
&GEP : nullptr; } - } else { + } else if (auto *ExactIns = + dyn_cast<PossiblyExactOperator>(GEP.getOperand(1))) { // Canonicalize (gep T* X, V / sizeof(T)) to (gep i8* X, V) Value *V; - if ((has_single_bit(TyAllocSize) && - match(GEP.getOperand(1), - m_Exact(m_Shr(m_Value(V), - m_SpecificInt(countr_zero(TyAllocSize)))))) || - match(GEP.getOperand(1), - m_Exact(m_IDiv(m_Value(V), m_SpecificInt(TyAllocSize))))) { - GetElementPtrInst *NewGEP = GetElementPtrInst::Create( - Builder.getInt8Ty(), GEP.getPointerOperand(), V); - NewGEP->setIsInBounds(GEP.isInBounds()); - return NewGEP; + if (ExactIns->isExact()) { + if ((has_single_bit(TyAllocSize) && + match(GEP.getOperand(1), + m_Shr(m_Value(V), + m_SpecificInt(countr_zero(TyAllocSize))))) || + match(GEP.getOperand(1), + m_IDiv(m_Value(V), m_SpecificInt(TyAllocSize)))) { + return GetElementPtrInst::Create(Builder.getInt8Ty(), + GEP.getPointerOperand(), V, + GEP.getNoWrapFlags()); + } + } + if (ExactIns->isExact() && ExactIns->hasOneUse()) { + // Try to canonicalize non-i8 element type to i8 if the index is an + // exact instruction. If the index is an exact instruction (div/shr) + // with a constant RHS, we can fold the non-i8 element scale into the + // div/shr (similiar to the mul case, just inverted). + const APInt *C; + std::optional<APInt> NewC; + if (has_single_bit(TyAllocSize) && + match(ExactIns, m_Shr(m_Value(V), m_APInt(C))) && + C->uge(countr_zero(TyAllocSize))) + NewC = *C - countr_zero(TyAllocSize); + else if (match(ExactIns, m_UDiv(m_Value(V), m_APInt(C)))) { + APInt Quot; + uint64_t Rem; + APInt::udivrem(*C, TyAllocSize, Quot, Rem); + if (Rem == 0) + NewC = Quot; + } else if (match(ExactIns, m_SDiv(m_Value(V), m_APInt(C)))) { + APInt Quot; + int64_t Rem; + APInt::sdivrem(*C, TyAllocSize, Quot, Rem); + // For sdiv we need to make sure we arent creating INT_MIN / -1. + if (!Quot.isAllOnes() && Rem == 0) + NewC = Quot; + } + + if (NewC.has_value()) { + Value *NewOp = Builder.CreateBinOp( + static_cast<Instruction::BinaryOps>(ExactIns->getOpcode()), V, + ConstantInt::get(V->getType(), *NewC)); + cast<BinaryOperator>(NewOp)->setIsExact(); + return GetElementPtrInst::Create(Builder.getInt8Ty(), + GEP.getPointerOperand(), NewOp, + GEP.getNoWrapFlags()); + } } } } @@ -2612,6 +2999,14 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { return nullptr; if (GEP.getNumIndices() == 1) { + // We can only preserve inbounds if the original gep is inbounds, the add + // is nsw, and the add operands are non-negative. + auto CanPreserveInBounds = [&](bool AddIsNSW, Value *Idx1, Value *Idx2) { + SimplifyQuery Q = SQ.getWithInstruction(&GEP); + return GEP.isInBounds() && AddIsNSW && isKnownNonNegative(Idx1, Q) && + isKnownNonNegative(Idx2, Q); + }; + // Try to replace ADD + GEP with GEP + GEP. 
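A sketch of the inbounds condition above, with value names assumed for illustration:
 // Sketch: given %i = add nsw i64 %a, %b with %a and %b known non-negative,
 //   getelementptr inbounds i32, ptr %p, i64 %i
 // can be split into
 //   %q = getelementptr inbounds i32, ptr %p, i64 %a
 //   getelementptr inbounds i32, ptr %q, i64 %b
 // If the add may wrap, or either index may be negative, the two new GEPs
 // are emitted without inbounds.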
Value *Idx1, *Idx2; if (match(GEP.getOperand(1), @@ -2621,10 +3016,15 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { // as: // %newptr = getelementptr i32, ptr %ptr, i64 %idx1 // %newgep = getelementptr i32, ptr %newptr, i64 %idx2 - auto *NewPtr = Builder.CreateGEP(GEP.getResultElementType(), - GEP.getPointerOperand(), Idx1); - return GetElementPtrInst::Create(GEP.getResultElementType(), NewPtr, - Idx2); + bool IsInBounds = CanPreserveInBounds( + cast<OverflowingBinaryOperator>(GEP.getOperand(1))->hasNoSignedWrap(), + Idx1, Idx2); + auto *NewPtr = + Builder.CreateGEP(GEP.getSourceElementType(), GEP.getPointerOperand(), + Idx1, "", IsInBounds); + return replaceInstUsesWith( + GEP, Builder.CreateGEP(GEP.getSourceElementType(), NewPtr, Idx2, "", + IsInBounds)); } ConstantInt *C; if (match(GEP.getOperand(1), m_OneUse(m_SExtLike(m_OneUse(m_NSWAdd( @@ -2635,12 +3035,17 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { // as: // %newptr = getelementptr i32, ptr %ptr, i32 %idx1 // %newgep = getelementptr i32, ptr %newptr, i32 idx2 + bool IsInBounds = CanPreserveInBounds( + /*IsNSW=*/true, Idx1, C); auto *NewPtr = Builder.CreateGEP( - GEP.getResultElementType(), GEP.getPointerOperand(), - Builder.CreateSExt(Idx1, GEP.getOperand(1)->getType())); - return GetElementPtrInst::Create( - GEP.getResultElementType(), NewPtr, - Builder.CreateSExt(C, GEP.getOperand(1)->getType())); + GEP.getSourceElementType(), GEP.getPointerOperand(), + Builder.CreateSExt(Idx1, GEP.getOperand(1)->getType()), "", + IsInBounds); + return replaceInstUsesWith( + GEP, + Builder.CreateGEP(GEP.getSourceElementType(), NewPtr, + Builder.CreateSExt(C, GEP.getOperand(1)->getType()), + "", IsInBounds)); } } @@ -2846,10 +3251,10 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { // If we are removing an alloca with a dbg.declare, insert dbg.value calls // before each store. SmallVector<DbgVariableIntrinsic *, 8> DVIs; - SmallVector<DPValue *, 8> DPVs; + SmallVector<DbgVariableRecord *, 8> DVRs; std::unique_ptr<DIBuilder> DIB; if (isa<AllocaInst>(MI)) { - findDbgUsers(DVIs, &MI, &DPVs); + findDbgUsers(DVIs, &MI, &DVRs); DIB.reset(new DIBuilder(*MI.getModule(), /*AllowUnresolved=*/false)); } @@ -2889,9 +3294,9 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { for (auto *DVI : DVIs) if (DVI->isAddressOfVariable()) ConvertDebugDeclareToDebugValue(DVI, SI, *DIB); - for (auto *DPV : DPVs) - if (DPV->isAddressOfVariable()) - ConvertDebugDeclareToDebugValue(DPV, SI, *DIB); + for (auto *DVR : DVRs) + if (DVR->isAddressOfVariable()) + ConvertDebugDeclareToDebugValue(DVR, SI, *DIB); } else { // Casts, GEP, or anything else: we're about to delete this instruction, // so it can not have any valid uses. @@ -2936,9 +3341,9 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { for (auto *DVI : DVIs) if (DVI->isAddressOfVariable() || DVI->getExpression()->startsWithDeref()) DVI->eraseFromParent(); - for (auto *DPV : DPVs) - if (DPV->isAddressOfVariable() || DPV->getExpression()->startsWithDeref()) - DPV->eraseFromParent(); + for (auto *DVR : DVRs) + if (DVR->isAddressOfVariable() || DVR->getExpression()->startsWithDeref()) + DVR->eraseFromParent(); return eraseInstFromFunction(MI); } @@ -3085,8 +3490,22 @@ Instruction *InstCombinerImpl::visitFree(CallInst &FI, Value *Op) { } Instruction *InstCombinerImpl::visitReturnInst(ReturnInst &RI) { - // Nothing for now. 
- return nullptr; + Value *RetVal = RI.getReturnValue(); + if (!RetVal || !AttributeFuncs::isNoFPClassCompatibleType(RetVal->getType())) + return nullptr; + + Function *F = RI.getFunction(); + FPClassTest ReturnClass = F->getAttributes().getRetNoFPClass(); + if (ReturnClass == fcNone) + return nullptr; + + KnownFPClass KnownClass; + Value *Simplified = + SimplifyDemandedUseFPClass(RetVal, ~ReturnClass, KnownClass, 0, &RI); + if (!Simplified) + return nullptr; + + return ReturnInst::Create(RI.getContext(), Simplified); } // WARNING: keep in sync with SimplifyCFGOpt::simplifyUnreachable()! @@ -3183,14 +3602,17 @@ void InstCombinerImpl::handleUnreachableFrom( if (Inst.isEHPad() || Inst.getType()->isTokenTy()) continue; // RemoveDIs: erase debug-info on this instruction manually. - Inst.dropDbgValues(); + Inst.dropDbgRecords(); eraseInstFromFunction(Inst); MadeIRChange = true; } - // RemoveDIs: to match behaviour in dbg.value mode, drop debug-info on - // terminator too. - BB->getTerminator()->dropDbgValues(); + SmallVector<Value *> Changed; + if (handleUnreachableTerminator(BB->getTerminator(), Changed)) { + MadeIRChange = true; + for (Value *V : Changed) + addToWorklist(cast<Instruction>(V)); + } // Handle potentially dead successors. for (BasicBlock *Succ : successors(BB)) @@ -3234,6 +3656,8 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) { if (match(Cond, m_Not(m_Value(X))) && !isa<Constant>(X)) { // Swap Destinations and condition... BI.swapSuccessors(); + if (BPI) + BPI->swapSuccEdgesProbabilities(BI.getParent()); return replaceOperand(BI, 0, X); } @@ -3247,6 +3671,8 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) { Value *NotX = Builder.CreateNot(X, "not." + X->getName()); Value *Or = Builder.CreateLogicalOr(NotX, Y); BI.swapSuccessors(); + if (BPI) + BPI->swapSuccEdgesProbabilities(BI.getParent()); return replaceOperand(BI, 0, Or); } @@ -3263,6 +3689,8 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) { auto *Cmp = cast<CmpInst>(Cond); Cmp->setPredicate(CmpInst::getInversePredicate(Pred)); BI.swapSuccessors(); + if (BPI) + BPI->swapSuccEdgesProbabilities(BI.getParent()); Worklist.push(Cmp); return &BI; } @@ -3281,6 +3709,38 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) { return nullptr; } +// Replaces (switch (select cond, X, C)/(select cond, C, X)) with (switch X) if +// we can prove that both (switch C) and (switch X) go to the default when cond +// is false/true. +static Value *simplifySwitchOnSelectUsingRanges(SwitchInst &SI, + SelectInst *Select, + bool IsTrueArm) { + unsigned CstOpIdx = IsTrueArm ? 
1 : 2; + auto *C = dyn_cast<ConstantInt>(Select->getOperand(CstOpIdx)); + if (!C) + return nullptr; + + BasicBlock *CstBB = SI.findCaseValue(C)->getCaseSuccessor(); + if (CstBB != SI.getDefaultDest()) + return nullptr; + Value *X = Select->getOperand(3 - CstOpIdx); + ICmpInst::Predicate Pred; + const APInt *RHSC; + if (!match(Select->getCondition(), + m_ICmp(Pred, m_Specific(X), m_APInt(RHSC)))) + return nullptr; + if (IsTrueArm) + Pred = ICmpInst::getInversePredicate(Pred); + + // See whether we can replace the select with X + ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *RHSC); + for (auto Case : SI.cases()) + if (!CR.contains(Case.getCaseValue()->getValue())) + return nullptr; + + return X; +} + Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) { Value *Cond = SI.getCondition(); Value *Op0; @@ -3354,6 +3814,16 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) { } } + // Fold switch(select cond, X, Y) into switch(X/Y) if possible + if (auto *Select = dyn_cast<SelectInst>(Cond)) { + if (Value *V = + simplifySwitchOnSelectUsingRanges(SI, Select, /*IsTrueArm=*/true)) + return replaceOperand(SI, 0, V); + if (Value *V = + simplifySwitchOnSelectUsingRanges(SI, Select, /*IsTrueArm=*/false)) + return replaceOperand(SI, 0, V); + } + KnownBits Known = computeKnownBits(Cond, 0, &SI); unsigned LeadingKnownZeros = Known.countMinLeadingZeros(); unsigned LeadingKnownOnes = Known.countMinLeadingOnes(); @@ -3407,7 +3877,7 @@ InstCombinerImpl::foldExtractOfOverflowIntrinsic(ExtractValueInst &EV) { Intrinsic::ID OvID = WO->getIntrinsicID(); const APInt *C = nullptr; - if (match(WO->getRHS(), m_APIntAllowUndef(C))) { + if (match(WO->getRHS(), m_APIntAllowPoison(C))) { if (*EV.idx_begin() == 0 && (OvID == Intrinsic::smul_with_overflow || OvID == Intrinsic::umul_with_overflow)) { // extractvalue (any_mul_with_overflow X, -1), 0 --> -X @@ -3451,6 +3921,17 @@ InstCombinerImpl::foldExtractOfOverflowIntrinsic(ExtractValueInst &EV) { WO->getLHS()->getType()->isIntOrIntVectorTy(1)) return BinaryOperator::CreateAnd(WO->getLHS(), WO->getRHS()); + // extractvalue (umul_with_overflow X, X), 1 -> X u> 2^(N/2)-1 + if (OvID == Intrinsic::umul_with_overflow && WO->getLHS() == WO->getRHS()) { + unsigned BitWidth = WO->getLHS()->getType()->getScalarSizeInBits(); + // Only handle even bitwidths for performance reasons. + if (BitWidth % 2 == 0) + return new ICmpInst( + ICmpInst::ICMP_UGT, WO->getLHS(), + ConstantInt::get(WO->getLHS()->getType(), + APInt::getLowBitsSet(BitWidth, BitWidth / 2))); + } + // If only the overflow result is used, and the right hand side is a // constant (or constant splat), we can remove the intrinsic by directly // checking for overflow. @@ -3577,6 +4058,12 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { if (Instruction *Res = foldOpIntoPhi(EV, PN)) return Res; + // Canonicalize extract (select Cond, TV, FV) + // -> select cond, (extract TV), (extract FV) + if (auto *SI = dyn_cast<SelectInst>(Agg)) + if (Instruction *R = FoldOpIntoSelect(EV, SI, /*FoldWithMultiUse=*/true)) + return R; + // We could simplify extracts from other values. 
Note that nested extracts may // already be simplified implicitly by the above: extract (extract (insert) ) // will be translated into extract ( insert ( extract ) ) first and then just @@ -3612,6 +4099,7 @@ static bool isCatchAll(EHPersonality Personality, Constant *TypeInfo) { case EHPersonality::CoreCLR: case EHPersonality::Wasm_CXX: case EHPersonality::XL_CXX: + case EHPersonality::ZOS_CXX: return TypeInfo->isNullValue(); } llvm_unreachable("invalid enum"); @@ -3911,8 +4399,8 @@ Instruction *InstCombinerImpl::visitLandingPadInst(LandingPadInst &LI) { if (MakeNewInstruction) { LandingPadInst *NLI = LandingPadInst::Create(LI.getType(), NewClauses.size()); - for (unsigned i = 0, e = NewClauses.size(); i != e; ++i) - NLI->addClause(NewClauses[i]); + for (Constant *C : NewClauses) + NLI->addClause(C); // A landing pad with no clauses must have the cleanup flag set. It is // theoretically possible, though highly unlikely, that we eliminated all // clauses. If so, force the cleanup flag to true. @@ -3980,7 +4468,7 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) { return nullptr; } - OrigOpInst->dropPoisonGeneratingFlagsAndMetadata(); + OrigOpInst->dropPoisonGeneratingAnnotations(); // If all operands are guaranteed to be non-poison, we can drop freeze. if (!MaybePoisonOperand) @@ -4051,7 +4539,7 @@ Instruction *InstCombinerImpl::foldFreezeIntoRecurrence(FreezeInst &FI, } for (Instruction *I : DropFlags) - I->dropPoisonGeneratingFlagsAndMetadata(); + I->dropPoisonGeneratingAnnotations(); if (StartNeedsFreeze) { Builder.SetInsertPoint(StartBB->getTerminator()); @@ -4305,12 +4793,32 @@ bool InstCombinerImpl::tryToSinkInstruction(Instruction *I, // mark the location undef: we know it was supposed to receive a new location // here, but that computation has been sunk. SmallVector<DbgVariableIntrinsic *, 2> DbgUsers; - findDbgUsers(DbgUsers, I); + SmallVector<DbgVariableRecord *, 2> DbgVariableRecords; + findDbgUsers(DbgUsers, I, &DbgVariableRecords); + if (!DbgUsers.empty()) + tryToSinkInstructionDbgValues(I, InsertPos, SrcBlock, DestBlock, DbgUsers); + if (!DbgVariableRecords.empty()) + tryToSinkInstructionDbgVariableRecords(I, InsertPos, SrcBlock, DestBlock, + DbgVariableRecords); + + // PS: there are numerous flaws with this behaviour, not least that right now + // assignments can be re-ordered past other assignments to the same variable + // if they use different Values. Creating more undef assignements can never be + // undone. And salvaging all users outside of this block can un-necessarily + // alter the lifetime of the live-value that the variable refers to. + // Some of these things can be resolved by tolerating debug use-before-defs in + // LLVM-IR, however it depends on the instruction-referencing CodeGen backend + // being used for more architectures. + + return true; +} +void InstCombinerImpl::tryToSinkInstructionDbgValues( + Instruction *I, BasicBlock::iterator InsertPos, BasicBlock *SrcBlock, + BasicBlock *DestBlock, SmallVectorImpl<DbgVariableIntrinsic *> &DbgUsers) { // For all debug values in the destination block, the sunk instruction // will still be available, so they do not need to be dropped. 
SmallVector<DbgVariableIntrinsic *, 2> DbgUsersToSalvage; - SmallVector<DPValue *, 2> DPValuesToSalvage; for (auto &DbgUser : DbgUsers) if (DbgUser->getParent() != DestBlock) DbgUsersToSalvage.push_back(DbgUser); @@ -4354,10 +4862,7 @@ bool InstCombinerImpl::tryToSinkInstruction(Instruction *I, // Perform salvaging without the clones, then sink the clones. if (!DIIClones.empty()) { - // RemoveDIs: pass in empty vector of DPValues until we get to instrumenting - // this pass. - SmallVector<DPValue *, 1> DummyDPValues; - salvageDebugInfoForDbgValues(*I, DbgUsersToSalvage, DummyDPValues); + salvageDebugInfoForDbgValues(*I, DbgUsersToSalvage, {}); // The clones are in reverse order of original appearance, reverse again to // maintain the original order. for (auto &DIIClone : llvm::reverse(DIIClones)) { @@ -4365,8 +4870,134 @@ bool InstCombinerImpl::tryToSinkInstruction(Instruction *I, LLVM_DEBUG(dbgs() << "SINK: " << *DIIClone << '\n'); } } +} - return true; +void InstCombinerImpl::tryToSinkInstructionDbgVariableRecords( + Instruction *I, BasicBlock::iterator InsertPos, BasicBlock *SrcBlock, + BasicBlock *DestBlock, + SmallVectorImpl<DbgVariableRecord *> &DbgVariableRecords) { + // Implementation of tryToSinkInstructionDbgValues, but for the + // DbgVariableRecord of variable assignments rather than dbg.values. + + // Fetch all DbgVariableRecords not already in the destination. + SmallVector<DbgVariableRecord *, 2> DbgVariableRecordsToSalvage; + for (auto &DVR : DbgVariableRecords) + if (DVR->getParent() != DestBlock) + DbgVariableRecordsToSalvage.push_back(DVR); + + // Fetch a second collection, of DbgVariableRecords in the source block that + // we're going to sink. + SmallVector<DbgVariableRecord *> DbgVariableRecordsToSink; + for (DbgVariableRecord *DVR : DbgVariableRecordsToSalvage) + if (DVR->getParent() == SrcBlock) + DbgVariableRecordsToSink.push_back(DVR); + + // Sort DbgVariableRecords according to their position in the block. This is a + // partial order: DbgVariableRecords attached to different instructions will + // be ordered by the instruction order, but DbgVariableRecords attached to the + // same instruction won't have an order. + auto Order = [](DbgVariableRecord *A, DbgVariableRecord *B) -> bool { + return B->getInstruction()->comesBefore(A->getInstruction()); + }; + llvm::stable_sort(DbgVariableRecordsToSink, Order); + + // If there are two assignments to the same variable attached to the same + // instruction, the ordering between the two assignments is important. Scan + // for this (rare) case and establish which is the last assignment. + using InstVarPair = std::pair<const Instruction *, DebugVariable>; + SmallDenseMap<InstVarPair, DbgVariableRecord *> FilterOutMap; + if (DbgVariableRecordsToSink.size() > 1) { + SmallDenseMap<InstVarPair, unsigned> CountMap; + // Count how many assignments to each variable there is per instruction. + for (DbgVariableRecord *DVR : DbgVariableRecordsToSink) { + DebugVariable DbgUserVariable = + DebugVariable(DVR->getVariable(), DVR->getExpression(), + DVR->getDebugLoc()->getInlinedAt()); + CountMap[std::make_pair(DVR->getInstruction(), DbgUserVariable)] += 1; + } + + // If there are any instructions with two assignments, add them to the + // FilterOutMap to record that they need extra filtering. 
+ SmallPtrSet<const Instruction *, 4> DupSet; + for (auto It : CountMap) { + if (It.second > 1) { + FilterOutMap[It.first] = nullptr; + DupSet.insert(It.first.first); + } + } + + // For all instruction/variable pairs needing extra filtering, find the + // latest assignment. + for (const Instruction *Inst : DupSet) { + for (DbgVariableRecord &DVR : + llvm::reverse(filterDbgVars(Inst->getDbgRecordRange()))) { + DebugVariable DbgUserVariable = + DebugVariable(DVR.getVariable(), DVR.getExpression(), + DVR.getDebugLoc()->getInlinedAt()); + auto FilterIt = + FilterOutMap.find(std::make_pair(Inst, DbgUserVariable)); + if (FilterIt == FilterOutMap.end()) + continue; + if (FilterIt->second != nullptr) + continue; + FilterIt->second = &DVR; + } + } + } + + // Perform cloning of the DbgVariableRecords that we plan on sinking, filter + // out any duplicate assignments identified above. + SmallVector<DbgVariableRecord *, 2> DVRClones; + SmallSet<DebugVariable, 4> SunkVariables; + for (DbgVariableRecord *DVR : DbgVariableRecordsToSink) { + if (DVR->Type == DbgVariableRecord::LocationType::Declare) + continue; + + DebugVariable DbgUserVariable = + DebugVariable(DVR->getVariable(), DVR->getExpression(), + DVR->getDebugLoc()->getInlinedAt()); + + // For any variable where there were multiple assignments in the same place, + // ignore all but the last assignment. + if (!FilterOutMap.empty()) { + InstVarPair IVP = std::make_pair(DVR->getInstruction(), DbgUserVariable); + auto It = FilterOutMap.find(IVP); + + // Filter out. + if (It != FilterOutMap.end() && It->second != DVR) + continue; + } + + if (!SunkVariables.insert(DbgUserVariable).second) + continue; + + if (DVR->isDbgAssign()) + continue; + + DVRClones.emplace_back(DVR->clone()); + LLVM_DEBUG(dbgs() << "CLONE: " << *DVRClones.back() << '\n'); + } + + // Perform salvaging without the clones, then sink the clones. + if (DVRClones.empty()) + return; + + salvageDebugInfoForDbgValues(*I, {}, DbgVariableRecordsToSalvage); + + // The clones are in reverse order of original appearance. Assert that the + // head bit is set on the iterator as we _should_ have received it via + // getFirstInsertionPt. Inserting like this will reverse the clone order as + // we'll repeatedly insert at the head, such as: + // DVR-3 (third insertion goes here) + // DVR-2 (second insertion goes here) + // DVR-1 (first insertion goes here) + // Any-Prior-DVRs + // InsertPtInst + assert(InsertPos.getHeadBit()); + for (DbgVariableRecord *DVRClone : DVRClones) { + InsertPos->getParent()->insertDbgRecordBefore(DVRClone, InsertPos); + LLVM_DEBUG(dbgs() << "SINK: " << *DVRClone << '\n'); + } } bool InstCombinerImpl::run() { @@ -4412,31 +5043,24 @@ bool InstCombinerImpl::run() { BasicBlock *UserParent = nullptr; unsigned NumUsers = 0; - for (auto *U : I->users()) { - if (U->isDroppable()) + for (Use &U : I->uses()) { + User *User = U.getUser(); + if (User->isDroppable()) continue; if (NumUsers > MaxSinkNumUsers) return std::nullopt; - Instruction *UserInst = cast<Instruction>(U); + Instruction *UserInst = cast<Instruction>(User); // Special handling for Phi nodes - get the block the use occurs in. - if (PHINode *PN = dyn_cast<PHINode>(UserInst)) { - for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) { - if (PN->getIncomingValue(i) == I) { - // Bail out if we have uses in different blocks. We don't do any - // sophisticated analysis (i.e finding NearestCommonDominator of - // these use blocks). 
- if (UserParent && UserParent != PN->getIncomingBlock(i)) - return std::nullopt; - UserParent = PN->getIncomingBlock(i); - } - } - assert(UserParent && "expected to find user block!"); - } else { - if (UserParent && UserParent != UserInst->getParent()) - return std::nullopt; - UserParent = UserInst->getParent(); - } + BasicBlock *UserBB = UserInst->getParent(); + if (PHINode *PN = dyn_cast<PHINode>(UserInst)) + UserBB = PN->getIncomingBlock(U); + // Bail out if we have uses in different blocks. We don't do any + // sophisticated analysis (i.e finding NearestCommonDominator of these + // use blocks). + if (UserParent && UserParent != UserBB) + return std::nullopt; + UserParent = UserBB; // Make sure these checks are done only once, naturally we do the checks // the first time we get the userparent, this will save compile time. @@ -4495,7 +5119,7 @@ bool InstCombinerImpl::run() { #ifndef NDEBUG std::string OrigI; #endif - LLVM_DEBUG(raw_string_ostream SS(OrigI); I->print(SS); OrigI = SS.str();); + LLVM_DEBUG(raw_string_ostream SS(OrigI); I->print(SS);); LLVM_DEBUG(dbgs() << "IC: Visiting: " << OrigI << '\n'); if (Instruction *Result = visit(*I)) { @@ -4755,8 +5379,9 @@ static bool combineInstructionsOverFunction( Function &F, InstructionWorklist &Worklist, AliasAnalysis *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, TargetTransformInfo &TTI, DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, - ProfileSummaryInfo *PSI, LoopInfo *LI, const InstCombineOptions &Opts) { - auto &DL = F.getParent()->getDataLayout(); + BranchProbabilityInfo *BPI, ProfileSummaryInfo *PSI, LoopInfo *LI, + const InstCombineOptions &Opts) { + auto &DL = F.getDataLayout(); /// Builder - This is an IRBuilder that automatically inserts new /// instructions into the worklist when they are created. @@ -4793,7 +5418,7 @@ static bool combineInstructionsOverFunction( << F.getName() << "\n"); InstCombinerImpl IC(Worklist, Builder, F.hasMinSize(), AA, AC, TLI, TTI, DT, - ORE, BFI, PSI, DL, LI); + ORE, BFI, BPI, PSI, DL, LI); IC.MaxArraySizeForCombine = MaxArraySize; bool MadeChangeInThisIteration = IC.prepareWorklist(F, RPOT); MadeChangeInThisIteration |= IC.run(); @@ -4804,7 +5429,8 @@ static bool combineInstructionsOverFunction( if (Iteration > Opts.MaxIterations) { report_fatal_error( "Instruction Combining did not reach a fixpoint after " + - Twine(Opts.MaxIterations) + " iterations"); + Twine(Opts.MaxIterations) + " iterations", + /*GenCrashDiag=*/false); } } @@ -4853,9 +5479,10 @@ PreservedAnalyses InstCombinePass::run(Function &F, MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); auto *BFI = (PSI && PSI->hasProfileSummary()) ? &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; + auto *BPI = AM.getCachedResult<BranchProbabilityAnalysis>(F); if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE, - BFI, PSI, LI, Options)) + BFI, BPI, PSI, LI, Options)) // No changes, all analyses are preserved. return PreservedAnalyses::all(); @@ -4902,9 +5529,14 @@ bool InstructionCombiningPass::runOnFunction(Function &F) { (PSI && PSI->hasProfileSummary()) ? 
&getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() : nullptr; + BranchProbabilityInfo *BPI = nullptr; + if (auto *WrapperPass = + getAnalysisIfAvailable<BranchProbabilityInfoWrapperPass>()) + BPI = &WrapperPass->getBPI(); return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE, - BFI, PSI, LI, InstCombineOptions()); + BFI, BPI, PSI, LI, + InstCombineOptions()); } char InstructionCombiningPass::ID = 0; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp index caab98c732ee..9fb1df7ab2b7 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -43,6 +43,7 @@ #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/EHPersonalities.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalValue.h" @@ -107,7 +108,7 @@ static const uint64_t kMIPS32_ShadowOffset32 = 0x0aaa0000; static const uint64_t kMIPS64_ShadowOffset64 = 1ULL << 37; static const uint64_t kAArch64_ShadowOffset64 = 1ULL << 36; static const uint64_t kLoongArch64_ShadowOffset64 = 1ULL << 46; -static const uint64_t kRISCV64_ShadowOffset64 = 0xd55550000; +static const uint64_t kRISCV64_ShadowOffset64 = kDynamicShadowSentinel; static const uint64_t kFreeBSD_ShadowOffset32 = 1ULL << 30; static const uint64_t kFreeBSD_ShadowOffset64 = 1ULL << 46; static const uint64_t kFreeBSDAArch64_ShadowOffset64 = 1ULL << 47; @@ -642,6 +643,73 @@ static uint64_t GetCtorAndDtorPriority(Triple &TargetTriple) { } namespace { +/// Helper RAII class to post-process inserted asan runtime calls during a +/// pass on a single Function. Upon end of scope, detects and applies the +/// required funclet OpBundle. +class RuntimeCallInserter { + Function *OwnerFn = nullptr; + bool TrackInsertedCalls = false; + SmallVector<CallInst *> InsertedCalls; + +public: + RuntimeCallInserter(Function &Fn) : OwnerFn(&Fn) { + if (Fn.hasPersonalityFn()) { + auto Personality = classifyEHPersonality(Fn.getPersonalityFn()); + if (isScopedEHPersonality(Personality)) + TrackInsertedCalls = true; + } + } + + ~RuntimeCallInserter() { + if (InsertedCalls.empty()) + return; + assert(TrackInsertedCalls && "Calls were wrongly tracked"); + + DenseMap<BasicBlock *, ColorVector> BlockColors = colorEHFunclets(*OwnerFn); + for (CallInst *CI : InsertedCalls) { + BasicBlock *BB = CI->getParent(); + assert(BB && "Instruction doesn't belong to a BasicBlock"); + assert(BB->getParent() == OwnerFn && + "Instruction doesn't belong to the expected Function!"); + + ColorVector &Colors = BlockColors[BB]; + // funclet opbundles are only valid in monochromatic BBs. + // Note that unreachable BBs are seen as colorless by colorEHFunclets() + // and will be DCE'ed later. 
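
When the destructor does find a tracked call sitting inside a funclet, it rebuilds the call with a "funclet" operand bundle naming the funclet's EH pad. Isolated into a helper (hypothetical name; the surrounding per-block lookup via colorEHFunclets is as shown in this class), that step is approximately:

    #include "llvm/IR/InstrTypes.h"
    #include "llvm/IR/Instructions.h"
    #include "llvm/IR/LLVMContext.h"

    using namespace llvm;

    // Re-issue a runtime call with an added "funclet" bundle so EH-aware
    // passes and the verifier know which funclet pad the call belongs to.
    // EHPad is expected to be the pad instruction of the enclosing funclet.
    static CallBase *addFuncletBundle(CallInst *CI, Instruction *EHPad) {
      OperandBundleDef OB("funclet", EHPad);
      CallBase *NewCall =
          CallBase::addOperandBundle(CI, LLVMContext::OB_funclet, OB, CI);
      NewCall->copyMetadata(*CI);
      CI->replaceAllUsesWith(NewCall);
      CI->eraseFromParent();
      return NewCall;
    }
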
+ if (Colors.empty()) + continue; + if (Colors.size() != 1) { + OwnerFn->getContext().emitError( + "Instruction's BasicBlock is not monochromatic"); + continue; + } + + BasicBlock *Color = Colors.front(); + Instruction *EHPad = Color->getFirstNonPHI(); + + if (EHPad && EHPad->isEHPad()) { + // Replace CI with a clone with an added funclet OperandBundle + OperandBundleDef OB("funclet", EHPad); + auto *NewCall = + CallBase::addOperandBundle(CI, LLVMContext::OB_funclet, OB, CI); + NewCall->copyMetadata(*CI); + CI->replaceAllUsesWith(NewCall); + CI->eraseFromParent(); + } + } + } + + CallInst *createRuntimeCall(IRBuilder<> &IRB, FunctionCallee Callee, + ArrayRef<Value *> Args = {}, + const Twine &Name = "") { + assert(IRB.GetInsertBlock()->getParent() == OwnerFn); + + CallInst *Inst = IRB.CreateCall(Callee, Args, Name, nullptr); + if (TrackInsertedCalls) + InsertedCalls.push_back(Inst); + return Inst; + } +}; /// AddressSanitizer: instrument the code in module to find memory bugs. struct AddressSanitizer { @@ -679,7 +747,7 @@ struct AddressSanitizer { } TypeSize getAllocaSizeInBytes(const AllocaInst &AI) const { - return *AI.getAllocationSize(AI.getModule()->getDataLayout()); + return *AI.getAllocationSize(AI.getDataLayout()); } /// Check if we want (and can) handle this alloca. @@ -691,12 +759,14 @@ struct AddressSanitizer { void instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis, InterestingMemoryOperand &O, bool UseCalls, - const DataLayout &DL); - void instrumentPointerComparisonOrSubtraction(Instruction *I); + const DataLayout &DL, RuntimeCallInserter &RTCI); + void instrumentPointerComparisonOrSubtraction(Instruction *I, + RuntimeCallInserter &RTCI); void instrumentAddress(Instruction *OrigIns, Instruction *InsertBefore, Value *Addr, MaybeAlign Alignment, uint32_t TypeStoreSize, bool IsWrite, - Value *SizeArgument, bool UseCalls, uint32_t Exp); + Value *SizeArgument, bool UseCalls, uint32_t Exp, + RuntimeCallInserter &RTCI); Instruction *instrumentAMDGPUAddress(Instruction *OrigIns, Instruction *InsertBefore, Value *Addr, uint32_t TypeStoreSize, bool IsWrite, @@ -707,20 +777,22 @@ struct AddressSanitizer { Instruction *InsertBefore, Value *Addr, TypeSize TypeStoreSize, bool IsWrite, Value *SizeArgument, bool UseCalls, - uint32_t Exp); + uint32_t Exp, + RuntimeCallInserter &RTCI); void instrumentMaskedLoadOrStore(AddressSanitizer *Pass, const DataLayout &DL, Type *IntptrTy, Value *Mask, Value *EVL, Value *Stride, Instruction *I, Value *Addr, MaybeAlign Alignment, unsigned Granularity, Type *OpType, bool IsWrite, Value *SizeArgument, bool UseCalls, - uint32_t Exp); + uint32_t Exp, RuntimeCallInserter &RTCI); Value *createSlowPathCmp(IRBuilder<> &IRB, Value *AddrLong, Value *ShadowValue, uint32_t TypeStoreSize); Instruction *generateCrashCode(Instruction *InsertBefore, Value *Addr, bool IsWrite, size_t AccessSizeIndex, - Value *SizeArgument, uint32_t Exp); - void instrumentMemIntrinsic(MemIntrinsic *MI); + Value *SizeArgument, uint32_t Exp, + RuntimeCallInserter &RTCI); + void instrumentMemIntrinsic(MemIntrinsic *MI, RuntimeCallInserter &RTCI); Value *memToShadow(Value *Shadow, IRBuilder<> &IRB); bool suppressInstrumentationSiteForDebug(int &Instrumented); bool instrumentFunction(Function &F, const TargetLibraryInfo *TLI); @@ -912,6 +984,7 @@ private: struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { Function &F; AddressSanitizer &ASan; + RuntimeCallInserter &RTCI; DIBuilder DIB; LLVMContext *C; Type *IntptrTy; @@ -948,10 +1021,12 @@ struct FunctionStackPoisoner : 
public InstVisitor<FunctionStackPoisoner> { bool HasReturnsTwiceCall = false; bool PoisonStack; - FunctionStackPoisoner(Function &F, AddressSanitizer &ASan) - : F(F), ASan(ASan), DIB(*F.getParent(), /*AllowUnresolved*/ false), - C(ASan.C), IntptrTy(ASan.IntptrTy), - IntptrPtrTy(PointerType::get(IntptrTy, 0)), Mapping(ASan.Mapping), + FunctionStackPoisoner(Function &F, AddressSanitizer &ASan, + RuntimeCallInserter &RTCI) + : F(F), ASan(ASan), RTCI(RTCI), + DIB(*F.getParent(), /*AllowUnresolved*/ false), C(ASan.C), + IntptrTy(ASan.IntptrTy), IntptrPtrTy(PointerType::get(IntptrTy, 0)), + Mapping(ASan.Mapping), PoisonStack(ClStack && !Triple(F.getParent()->getTargetTriple()).isAMDGPU()) {} @@ -1034,8 +1109,8 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { DynamicAreaOffset); } - IRB.CreateCall( - AsanAllocasUnpoisonFunc, + RTCI.createRuntimeCall( + IRB, AsanAllocasUnpoisonFunc, {IRB.CreateLoad(IntptrTy, DynamicAllocaLayout), DynamicAreaPtr}); } @@ -1064,8 +1139,10 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { /// Collect Alloca instructions we want (and can) handle. void visitAllocaInst(AllocaInst &AI) { // FIXME: Handle scalable vectors instead of ignoring them. - if (!ASan.isInterestingAlloca(AI) || - isa<ScalableVectorType>(AI.getAllocatedType())) { + const Type *AllocaType = AI.getAllocatedType(); + const auto *STy = dyn_cast<StructType>(AllocaType); + if (!ASan.isInterestingAlloca(AI) || isa<ScalableVectorType>(AllocaType) || + (STy && STy->containsHomogeneousScalableVectorTypes())) { if (AI.isStaticAlloca()) { // Skip over allocas that are present *before* the first instrumented // alloca, we don't want to move those around. @@ -1251,16 +1328,19 @@ Value *AddressSanitizer::memToShadow(Value *Shadow, IRBuilder<> &IRB) { } // Instrument memset/memmove/memcpy -void AddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) { +void AddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI, + RuntimeCallInserter &RTCI) { InstrumentationIRBuilder IRB(MI); if (isa<MemTransferInst>(MI)) { - IRB.CreateCall(isa<MemMoveInst>(MI) ? AsanMemmove : AsanMemcpy, - {MI->getOperand(0), MI->getOperand(1), - IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); + RTCI.createRuntimeCall( + IRB, isa<MemMoveInst>(MI) ? AsanMemmove : AsanMemcpy, + {IRB.CreateAddrSpaceCast(MI->getOperand(0), PtrTy), + IRB.CreateAddrSpaceCast(MI->getOperand(1), PtrTy), + IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); } else if (isa<MemSetInst>(MI)) { - IRB.CreateCall( - AsanMemset, - {MI->getOperand(0), + RTCI.createRuntimeCall( + IRB, AsanMemset, + {IRB.CreateAddrSpaceCast(MI->getOperand(0), PtrTy), IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false), IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); } @@ -1497,7 +1577,7 @@ bool AddressSanitizer::GlobalIsLinkerInitialized(GlobalVariable *G) { } void AddressSanitizer::instrumentPointerComparisonOrSubtraction( - Instruction *I) { + Instruction *I, RuntimeCallInserter &RTCI) { IRBuilder<> IRB(I); FunctionCallee F = isa<ICmpInst>(I) ? 
AsanPtrCmpFunction : AsanPtrSubFunction; Value *Param[2] = {I->getOperand(0), I->getOperand(1)}; @@ -1505,7 +1585,7 @@ void AddressSanitizer::instrumentPointerComparisonOrSubtraction( if (i->getType()->isPointerTy()) i = IRB.CreatePointerCast(i, IntptrTy); } - IRB.CreateCall(F, Param); + RTCI.createRuntimeCall(IRB, F, Param); } static void doInstrumentAddress(AddressSanitizer *Pass, Instruction *I, @@ -1513,7 +1593,7 @@ static void doInstrumentAddress(AddressSanitizer *Pass, Instruction *I, MaybeAlign Alignment, unsigned Granularity, TypeSize TypeStoreSize, bool IsWrite, Value *SizeArgument, bool UseCalls, - uint32_t Exp) { + uint32_t Exp, RuntimeCallInserter &RTCI) { // Instrument a 1-, 2-, 4-, 8-, or 16- byte access with one check // if the data is properly aligned. if (!TypeStoreSize.isScalable()) { @@ -1528,18 +1608,19 @@ static void doInstrumentAddress(AddressSanitizer *Pass, Instruction *I, *Alignment >= FixedSize / 8) return Pass->instrumentAddress(I, InsertBefore, Addr, Alignment, FixedSize, IsWrite, nullptr, UseCalls, - Exp); + Exp, RTCI); } } Pass->instrumentUnusualSizeOrAlignment(I, InsertBefore, Addr, TypeStoreSize, - IsWrite, nullptr, UseCalls, Exp); + IsWrite, nullptr, UseCalls, Exp, RTCI); } void AddressSanitizer::instrumentMaskedLoadOrStore( AddressSanitizer *Pass, const DataLayout &DL, Type *IntptrTy, Value *Mask, Value *EVL, Value *Stride, Instruction *I, Value *Addr, MaybeAlign Alignment, unsigned Granularity, Type *OpType, bool IsWrite, - Value *SizeArgument, bool UseCalls, uint32_t Exp) { + Value *SizeArgument, bool UseCalls, uint32_t Exp, + RuntimeCallInserter &RTCI) { auto *VTy = cast<VectorType>(OpType); TypeSize ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType()); auto Zero = ConstantInt::get(IntptrTy, 0); @@ -1594,15 +1675,16 @@ void AddressSanitizer::instrumentMaskedLoadOrStore( } else { InstrumentedAddress = IRB.CreateGEP(VTy, Addr, {Zero, Index}); } - doInstrumentAddress(Pass, I, &*IRB.GetInsertPoint(), - InstrumentedAddress, Alignment, Granularity, - ElemTypeSize, IsWrite, SizeArgument, UseCalls, Exp); + doInstrumentAddress(Pass, I, &*IRB.GetInsertPoint(), InstrumentedAddress, + Alignment, Granularity, ElemTypeSize, IsWrite, + SizeArgument, UseCalls, Exp, RTCI); }); } void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis, InterestingMemoryOperand &O, bool UseCalls, - const DataLayout &DL) { + const DataLayout &DL, + RuntimeCallInserter &RTCI) { Value *Addr = O.getPtr(); // Optimization experiments. @@ -1648,11 +1730,11 @@ void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis, instrumentMaskedLoadOrStore(this, DL, IntptrTy, O.MaybeMask, O.MaybeEVL, O.MaybeStride, O.getInsn(), Addr, O.Alignment, Granularity, O.OpType, O.IsWrite, nullptr, - UseCalls, Exp); + UseCalls, Exp, RTCI); } else { doInstrumentAddress(this, O.getInsn(), O.getInsn(), Addr, O.Alignment, - Granularity, O.TypeStoreSize, O.IsWrite, nullptr, UseCalls, - Exp); + Granularity, O.TypeStoreSize, O.IsWrite, nullptr, + UseCalls, Exp, RTCI); } } @@ -1660,24 +1742,25 @@ Instruction *AddressSanitizer::generateCrashCode(Instruction *InsertBefore, Value *Addr, bool IsWrite, size_t AccessSizeIndex, Value *SizeArgument, - uint32_t Exp) { + uint32_t Exp, + RuntimeCallInserter &RTCI) { InstrumentationIRBuilder IRB(InsertBefore); Value *ExpVal = Exp == 0 ? 
nullptr : ConstantInt::get(IRB.getInt32Ty(), Exp); CallInst *Call = nullptr; if (SizeArgument) { if (Exp == 0) - Call = IRB.CreateCall(AsanErrorCallbackSized[IsWrite][0], - {Addr, SizeArgument}); + Call = RTCI.createRuntimeCall(IRB, AsanErrorCallbackSized[IsWrite][0], + {Addr, SizeArgument}); else - Call = IRB.CreateCall(AsanErrorCallbackSized[IsWrite][1], - {Addr, SizeArgument, ExpVal}); + Call = RTCI.createRuntimeCall(IRB, AsanErrorCallbackSized[IsWrite][1], + {Addr, SizeArgument, ExpVal}); } else { if (Exp == 0) - Call = - IRB.CreateCall(AsanErrorCallback[IsWrite][0][AccessSizeIndex], Addr); + Call = RTCI.createRuntimeCall( + IRB, AsanErrorCallback[IsWrite][0][AccessSizeIndex], Addr); else - Call = IRB.CreateCall(AsanErrorCallback[IsWrite][1][AccessSizeIndex], - {Addr, ExpVal}); + Call = RTCI.createRuntimeCall( + IRB, AsanErrorCallback[IsWrite][1][AccessSizeIndex], {Addr, ExpVal}); } Call->setCannotMerge(); @@ -1736,7 +1819,7 @@ Instruction *AddressSanitizer::genAMDGPUReportBlock(IRBuilder<> &IRB, auto *Trm = SplitBlockAndInsertIfThen(ReportCond, &*IRB.GetInsertPoint(), false, - MDBuilder(*C).createBranchWeights(1, 100000)); + MDBuilder(*C).createUnlikelyBranchWeights()); Trm->getParent()->setName("asan.report"); if (Recover) @@ -1753,7 +1836,8 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, MaybeAlign Alignment, uint32_t TypeStoreSize, bool IsWrite, Value *SizeArgument, bool UseCalls, - uint32_t Exp) { + uint32_t Exp, + RuntimeCallInserter &RTCI) { if (TargetTriple.isAMDGPU()) { InsertBefore = instrumentAMDGPUAddress(OrigIns, InsertBefore, Addr, TypeStoreSize, IsWrite, SizeArgument); @@ -1778,11 +1862,12 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, Value *AddrLong = IRB.CreatePointerCast(Addr, IntptrTy); if (UseCalls) { if (Exp == 0) - IRB.CreateCall(AsanMemoryAccessCallback[IsWrite][0][AccessSizeIndex], - AddrLong); + RTCI.createRuntimeCall( + IRB, AsanMemoryAccessCallback[IsWrite][0][AccessSizeIndex], AddrLong); else - IRB.CreateCall(AsanMemoryAccessCallback[IsWrite][1][AccessSizeIndex], - {AddrLong, ConstantInt::get(IRB.getInt32Ty(), Exp)}); + RTCI.createRuntimeCall( + IRB, AsanMemoryAccessCallback[IsWrite][1][AccessSizeIndex], + {AddrLong, ConstantInt::get(IRB.getInt32Ty(), Exp)}); return; } @@ -1811,7 +1896,7 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, // We use branch weights for the slow path check, to indicate that the slow // path is rarely taken. This seems to be the case for SPEC benchmarks. Instruction *CheckTerm = SplitBlockAndInsertIfThen( - Cmp, InsertBefore, false, MDBuilder(*C).createBranchWeights(1, 100000)); + Cmp, InsertBefore, false, MDBuilder(*C).createUnlikelyBranchWeights()); assert(cast<BranchInst>(CheckTerm)->isUnconditional()); BasicBlock *NextBB = CheckTerm->getSuccessor(0); IRB.SetInsertPoint(CheckTerm); @@ -1829,8 +1914,8 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, CrashTerm = SplitBlockAndInsertIfThen(Cmp, InsertBefore, !Recover); } - Instruction *Crash = generateCrashCode(CrashTerm, AddrLong, IsWrite, - AccessSizeIndex, SizeArgument, Exp); + Instruction *Crash = generateCrashCode( + CrashTerm, AddrLong, IsWrite, AccessSizeIndex, SizeArgument, Exp, RTCI); if (OrigIns->getDebugLoc()) Crash->setDebugLoc(OrigIns->getDebugLoc()); } @@ -1840,8 +1925,9 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, // and the last bytes. We call __asan_report_*_n(addr, real_size) to be able // to report the actual access size. 
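
For reference, the inline fast path emitted by instrumentAddress in the hunks above (shadow load, zero test, slow-path compare, unlikely-weighted report branch) implements the usual ASan shadow check. In plain C++ terms, and assuming the default 8-byte granule and 3-bit shadow scale rather than anything specific to this patch, the decision is approximately:

    #include <cstdint>

    // Approximate logic of the inlined check for an access of AccessSize
    // bytes (AccessSize <= 8). ShadowByte is what the generated code loads
    // from (Addr >> 3) + Offset; 0 means the whole granule is addressable,
    // a positive value k means only the first k bytes are, and negative
    // values mark redzones and other poisoned memory.
    static bool asanWouldReport(uintptr_t Addr, unsigned AccessSize,
                                int8_t ShadowByte) {
      if (ShadowByte == 0)
        return false;                     // fast path: fully addressable
      // Slow path: report if the access touches any byte at or beyond the
      // first unaddressable byte of the granule.
      return static_cast<int8_t>((Addr & 7) + AccessSize - 1) >= ShadowByte;
    }
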
void AddressSanitizer::instrumentUnusualSizeOrAlignment( - Instruction *I, Instruction *InsertBefore, Value *Addr, TypeSize TypeStoreSize, - bool IsWrite, Value *SizeArgument, bool UseCalls, uint32_t Exp) { + Instruction *I, Instruction *InsertBefore, Value *Addr, + TypeSize TypeStoreSize, bool IsWrite, Value *SizeArgument, bool UseCalls, + uint32_t Exp, RuntimeCallInserter &RTCI) { InstrumentationIRBuilder IRB(InsertBefore); Value *NumBits = IRB.CreateTypeSize(IntptrTy, TypeStoreSize); Value *Size = IRB.CreateLShr(NumBits, ConstantInt::get(IntptrTy, 3)); @@ -1849,19 +1935,21 @@ void AddressSanitizer::instrumentUnusualSizeOrAlignment( Value *AddrLong = IRB.CreatePointerCast(Addr, IntptrTy); if (UseCalls) { if (Exp == 0) - IRB.CreateCall(AsanMemoryAccessCallbackSized[IsWrite][0], - {AddrLong, Size}); + RTCI.createRuntimeCall(IRB, AsanMemoryAccessCallbackSized[IsWrite][0], + {AddrLong, Size}); else - IRB.CreateCall(AsanMemoryAccessCallbackSized[IsWrite][1], - {AddrLong, Size, ConstantInt::get(IRB.getInt32Ty(), Exp)}); + RTCI.createRuntimeCall( + IRB, AsanMemoryAccessCallbackSized[IsWrite][1], + {AddrLong, Size, ConstantInt::get(IRB.getInt32Ty(), Exp)}); } else { Value *SizeMinusOne = IRB.CreateSub(Size, ConstantInt::get(IntptrTy, 1)); Value *LastByte = IRB.CreateIntToPtr( IRB.CreateAdd(AddrLong, SizeMinusOne), Addr->getType()); - instrumentAddress(I, InsertBefore, Addr, {}, 8, IsWrite, Size, false, Exp); + instrumentAddress(I, InsertBefore, Addr, {}, 8, IsWrite, Size, false, Exp, + RTCI); instrumentAddress(I, InsertBefore, LastByte, {}, 8, IsWrite, Size, false, - Exp); + Exp, RTCI); } } @@ -1878,7 +1966,7 @@ void ModuleAddressSanitizer::poisonOneInitializer(Function &GlobalInit, // Add calls to unpoison all globals before each return instruction. for (auto &BB : GlobalInit) if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) - CallInst::Create(AsanUnpoisonGlobals, "", RI); + CallInst::Create(AsanUnpoisonGlobals, "", RI->getIterator()); } void ModuleAddressSanitizer::createInitializerPoisonCalls( @@ -1956,6 +2044,10 @@ bool ModuleAddressSanitizer::shouldInstrumentGlobal(GlobalVariable *G) const { // On COFF, don't instrument non-ODR linkages. if (G->isInterposable()) return false; + // If the global has AvailableExternally linkage, then it is not in this + // module, which means it does not need to be instrumented. + if (G->hasAvailableExternallyLinkage()) + return false; } // If a comdat is present, it must have a selection kind that implies ODR @@ -2855,6 +2947,8 @@ bool AddressSanitizer::instrumentFunction(Function &F, if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) return false; if (!ClDebugFunc.empty() && ClDebugFunc == F.getName()) return false; if (F.getName().starts_with("__asan_")) return false; + if (F.isPresplitCoroutine()) + return false; bool FunctionModified = false; @@ -2876,6 +2970,8 @@ bool AddressSanitizer::instrumentFunction(Function &F, FunctionStateRAII CleanupObj(this); + RuntimeCallInserter RTCI(F); + FunctionModified |= maybeInsertDynamicShadowAtFunctionEntry(F); // We can't instrument allocas used with llvm.localescape. 
Only static allocas @@ -2948,7 +3044,7 @@ bool AddressSanitizer::instrumentFunction(Function &F, bool UseCalls = (InstrumentationWithCallsThreshold >= 0 && OperandsToInstrument.size() + IntrinToInstrument.size() > (unsigned)InstrumentationWithCallsThreshold); - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); ObjectSizeOpts ObjSizeOpts; ObjSizeOpts.RoundToAlign = true; ObjectSizeOffsetVisitor ObjSizeVis(DL, TLI, F.getContext(), ObjSizeOpts); @@ -2958,27 +3054,27 @@ bool AddressSanitizer::instrumentFunction(Function &F, for (auto &Operand : OperandsToInstrument) { if (!suppressInstrumentationSiteForDebug(NumInstrumented)) instrumentMop(ObjSizeVis, Operand, UseCalls, - F.getParent()->getDataLayout()); + F.getDataLayout(), RTCI); FunctionModified = true; } for (auto *Inst : IntrinToInstrument) { if (!suppressInstrumentationSiteForDebug(NumInstrumented)) - instrumentMemIntrinsic(Inst); + instrumentMemIntrinsic(Inst, RTCI); FunctionModified = true; } - FunctionStackPoisoner FSP(F, *this); + FunctionStackPoisoner FSP(F, *this, RTCI); bool ChangedStack = FSP.runOnFunction(); // We must unpoison the stack before NoReturn calls (throw, _exit, etc). // See e.g. https://github.com/google/sanitizers/issues/37 for (auto *CI : NoReturnCalls) { IRBuilder<> IRB(CI); - IRB.CreateCall(AsanHandleNoReturnFunc, {}); + RTCI.createRuntimeCall(IRB, AsanHandleNoReturnFunc, {}); } for (auto *Inst : PointerComparisonsOrSubtracts) { - instrumentPointerComparisonOrSubtraction(Inst); + instrumentPointerComparisonOrSubtraction(Inst, RTCI); FunctionModified = true; } @@ -3054,7 +3150,7 @@ void FunctionStackPoisoner::copyToShadowInline(ArrayRef<uint8_t> ShadowMask, const size_t LargestStoreSizeInBytes = std::min<size_t>(sizeof(uint64_t), ASan.LongSize / 8); - const bool IsLittleEndian = F.getParent()->getDataLayout().isLittleEndian(); + const bool IsLittleEndian = F.getDataLayout().isLittleEndian(); // Poison given range in shadow using larges store size with out leading and // trailing zeros in ShadowMask. 
Zeros never change, so they need neither @@ -3123,9 +3219,10 @@ void FunctionStackPoisoner::copyToShadow(ArrayRef<uint8_t> ShadowMask, if (j - i >= ASan.MaxInlinePoisoningSize) { copyToShadowInline(ShadowMask, ShadowBytes, Done, i, IRB, ShadowBase); - IRB.CreateCall(AsanSetShadowFunc[Val], - {IRB.CreateAdd(ShadowBase, ConstantInt::get(IntptrTy, i)), - ConstantInt::get(IntptrTy, j - i)}); + RTCI.createRuntimeCall( + IRB, AsanSetShadowFunc[Val], + {IRB.CreateAdd(ShadowBase, ConstantInt::get(IntptrTy, i)), + ConstantInt::get(IntptrTy, j - i)}); Done = j; } } @@ -3151,7 +3248,7 @@ void FunctionStackPoisoner::copyArgsPassedByValToAllocas() { assert(CopyInsertPoint); } IRBuilder<> IRB(CopyInsertPoint); - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); for (Argument &Arg : F.args()) { if (Arg.hasByValAttr()) { Type *Ty = Arg.getParamByValType(); @@ -3412,8 +3509,8 @@ void FunctionStackPoisoner::processStaticAllocas() { StackMallocIdx = StackMallocSizeClass(LocalStackSize); assert(StackMallocIdx <= kMaxAsanStackMallocSizeClass); Value *FakeStackValue = - IRBIf.CreateCall(AsanStackMallocFunc[StackMallocIdx], - ConstantInt::get(IntptrTy, LocalStackSize)); + RTCI.createRuntimeCall(IRBIf, AsanStackMallocFunc[StackMallocIdx], + ConstantInt::get(IntptrTy, LocalStackSize)); IRB.SetInsertPoint(InsBefore); FakeStack = createPHI(IRB, UseAfterReturnIsEnabled, FakeStackValue, Term, ConstantInt::get(IntptrTy, 0)); @@ -3423,7 +3520,8 @@ void FunctionStackPoisoner::processStaticAllocas() { // void *LocalStackBase = (FakeStack) ? FakeStack : // alloca(LocalStackSize); StackMallocIdx = StackMallocSizeClass(LocalStackSize); - FakeStack = IRB.CreateCall(AsanStackMallocFunc[StackMallocIdx], + FakeStack = + RTCI.createRuntimeCall(IRB, AsanStackMallocFunc[StackMallocIdx], ConstantInt::get(IntptrTy, LocalStackSize)); } Value *NoFakeStack = @@ -3558,8 +3656,8 @@ void FunctionStackPoisoner::processStaticAllocas() { IRBPoison.CreateIntToPtr(SavedFlagPtr, IRBPoison.getPtrTy())); } else { // For larger frames call __asan_stack_free_*. - IRBPoison.CreateCall( - AsanStackFreeFunc[StackMallocIdx], + RTCI.createRuntimeCall( + IRBPoison, AsanStackFreeFunc[StackMallocIdx], {FakeStack, ConstantInt::get(IntptrTy, LocalStackSize)}); } @@ -3580,8 +3678,8 @@ void FunctionStackPoisoner::poisonAlloca(Value *V, uint64_t Size, // For now just insert the call to ASan runtime. Value *AddrArg = IRB.CreatePointerCast(V, IntptrTy); Value *SizeArg = ConstantInt::get(IntptrTy, Size); - IRB.CreateCall( - DoPoison ? AsanPoisonStackMemoryFunc : AsanUnpoisonStackMemoryFunc, + RTCI.createRuntimeCall( + IRB, DoPoison ? AsanPoisonStackMemoryFunc : AsanUnpoisonStackMemoryFunc, {AddrArg, SizeArg}); } @@ -3608,7 +3706,7 @@ void FunctionStackPoisoner::handleDynamicAllocaCall(AllocaInst *AI) { // ElementSize size, get allocated memory size in bytes by // OldSize * ElementSize. const unsigned ElementSize = - F.getParent()->getDataLayout().getTypeAllocSize(AI->getAllocatedType()); + F.getDataLayout().getTypeAllocSize(AI->getAllocatedType()); Value *OldSize = IRB.CreateMul(IRB.CreateIntCast(AI->getArraySize(), IntptrTy, false), ConstantInt::get(IntptrTy, ElementSize)); @@ -3642,7 +3740,7 @@ void FunctionStackPoisoner::handleDynamicAllocaCall(AllocaInst *AI) { ConstantInt::get(IntptrTy, Alignment.value())); // Insert __asan_alloca_poison call for new created alloca. 
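
Just above, handleDynamicAllocaCall converts the alloca's element count into a byte size that the __asan_alloca_poison call below receives. Pulled out into a standalone helper (hypothetical name), that computation is:

    #include "llvm/IR/Constants.h"
    #include "llvm/IR/DataLayout.h"
    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/Instructions.h"

    using namespace llvm;

    // Bytes occupied by a (possibly dynamic) alloca:
    //   element count, widened to intptr, times the element's alloc size.
    static Value *allocaSizeInBytes(IRBuilder<> &IRB, AllocaInst *AI,
                                    Type *IntptrTy, const DataLayout &DL) {
      uint64_t ElemSize = DL.getTypeAllocSize(AI->getAllocatedType());
      Value *NumElems =
          IRB.CreateIntCast(AI->getArraySize(), IntptrTy, /*isSigned=*/false);
      return IRB.CreateMul(NumElems, ConstantInt::get(IntptrTy, ElemSize));
    }
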
- IRB.CreateCall(AsanAllocaPoisonFunc, {NewAddress, OldSize}); + RTCI.createRuntimeCall(IRB, AsanAllocaPoisonFunc, {NewAddress, OldSize}); // Store the last alloca's address to DynamicAllocaLayout. We'll need this // for unpoisoning stuff. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp index cfa8ae26c625..618b6fe1aea4 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -144,7 +144,7 @@ static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI, if (F.hasFnAttribute(Attribute::NoSanitizeBounds)) return false; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); ObjectSizeOpts EvalOpts; EvalOpts.RoundToAlign = true; EvalOpts.EvalMode = ObjectSizeOpts::Mode::ExactUnderlyingSizeAndOffset; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/CGProfile.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/CGProfile.cpp index c322d0abd6bc..ebd7dae25ed3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/CGProfile.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/CGProfile.cpp @@ -14,6 +14,7 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Constants.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/ProfileData/InstrProf.h" #include "llvm/Transforms/Instrumentation.h" @@ -78,16 +79,11 @@ static bool runCGProfilePass(Module &M, FunctionAnalysisManager &FAM, if (!CB) continue; if (CB->isIndirectCall()) { - InstrProfValueData ValueData[8]; - uint32_t ActualNumValueData; uint64_t TotalC; - if (!getValueProfDataFromInst(*CB, IPVK_IndirectCallTarget, 8, - ValueData, ActualNumValueData, TotalC)) - continue; - for (const auto &VD : - ArrayRef<InstrProfValueData>(ValueData, ActualNumValueData)) { + auto ValueData = + getValueProfDataFromInst(*CB, IPVK_IndirectCallTarget, 8, TotalC); + for (const auto &VD : ValueData) UpdateCounts(TTI, &F, Symtab.getFunction(VD.Value), VD.Count); - } continue; } UpdateCounts(TTI, &F, CB->getCalledFunction(), *BBCount); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp index 0a3d8d6000cf..c2affafa49e5 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp @@ -28,6 +28,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/BranchProbability.h" @@ -1878,7 +1879,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope, static_cast<uint32_t>(CHRBranchBias.scale(1000)), static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)), }; - setBranchWeights(*MergedBR, Weights); + setBranchWeights(*MergedBR, Weights, /*IsExpected=*/false); CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1] << "\n"); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp index 2ba127bba6f6..113d39b4f2af 100644 --- 
a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp @@ -357,7 +357,7 @@ public: /// useful for updating calls of the old function to the new type. struct TransformedFunction { TransformedFunction(FunctionType *OriginalType, FunctionType *TransformedType, - std::vector<unsigned> ArgumentIndexMapping) + const std::vector<unsigned> &ArgumentIndexMapping) : OriginalType(OriginalType), TransformedType(TransformedType), ArgumentIndexMapping(ArgumentIndexMapping) {} @@ -516,10 +516,12 @@ class DataFlowSanitizer { const MemoryMapParams *MapParams; Value *getShadowOffset(Value *Addr, IRBuilder<> &IRB); - Value *getShadowAddress(Value *Addr, Instruction *Pos); - Value *getShadowAddress(Value *Addr, Instruction *Pos, Value *ShadowOffset); - std::pair<Value *, Value *> - getShadowOriginAddress(Value *Addr, Align InstAlignment, Instruction *Pos); + Value *getShadowAddress(Value *Addr, BasicBlock::iterator Pos); + Value *getShadowAddress(Value *Addr, BasicBlock::iterator Pos, + Value *ShadowOffset); + std::pair<Value *, Value *> getShadowOriginAddress(Value *Addr, + Align InstAlignment, + BasicBlock::iterator Pos); bool isInstrumented(const Function *F); bool isInstrumented(const GlobalAlias *GA); bool isForceZeroLabels(const Function *F); @@ -536,7 +538,7 @@ class DataFlowSanitizer { /// Advances \p OriginAddr to point to the next 32-bit origin and then loads /// from it. Returns the origin's loaded value. - Value *loadNextOrigin(Instruction *Pos, Align OriginAlign, + Value *loadNextOrigin(BasicBlock::iterator Pos, Align OriginAlign, Value **OriginAddr); /// Returns whether the given load byte size is amenable to inlined @@ -647,18 +649,18 @@ struct DFSanFunction { /// When Zero is nullptr, it uses ZeroPrimitiveShadow. Otherwise it can be /// zeros with other bitwidths. Value *combineOrigins(const std::vector<Value *> &Shadows, - const std::vector<Value *> &Origins, Instruction *Pos, - ConstantInt *Zero = nullptr); + const std::vector<Value *> &Origins, + BasicBlock::iterator Pos, ConstantInt *Zero = nullptr); Value *getShadow(Value *V); void setShadow(Instruction *I, Value *Shadow); /// Generates IR to compute the union of the two given shadows, inserting it /// before Pos. The combined value is with primitive type. - Value *combineShadows(Value *V1, Value *V2, Instruction *Pos); + Value *combineShadows(Value *V1, Value *V2, BasicBlock::iterator Pos); /// Combines the shadow values of V1 and V2, then converts the combined value /// with primitive type into a shadow value with the original type T. Value *combineShadowsThenConvert(Type *T, Value *V1, Value *V2, - Instruction *Pos); + BasicBlock::iterator Pos); Value *combineOperandShadows(Instruction *Inst); /// Generates IR to load shadow and origin corresponding to bytes [\p @@ -670,11 +672,11 @@ struct DFSanFunction { /// current stack if the returned shadow is tainted. std::pair<Value *, Value *> loadShadowOrigin(Value *Addr, uint64_t Size, Align InstAlignment, - Instruction *Pos); + BasicBlock::iterator Pos); void storePrimitiveShadowOrigin(Value *Addr, uint64_t Size, Align InstAlignment, Value *PrimitiveShadow, - Value *Origin, Instruction *Pos); + Value *Origin, BasicBlock::iterator Pos); /// Applies PrimitiveShadow to all primitive subtypes of T, returning /// the expanded shadow value. 
/// @@ -682,7 +684,7 @@ struct DFSanFunction { /// EFP([n x T], PS) = [n x EFP(T,PS)] /// EFP(other types, PS) = PS Value *expandFromPrimitiveShadow(Type *T, Value *PrimitiveShadow, - Instruction *Pos); + BasicBlock::iterator Pos); /// Collapses Shadow into a single primitive shadow value, unioning all /// primitive shadow values in the process. Returns the final primitive /// shadow value. @@ -690,10 +692,10 @@ struct DFSanFunction { /// CTP({V1,V2, ...}) = UNION(CFP(V1,PS),CFP(V2,PS),...) /// CTP([V1,V2,...]) = UNION(CFP(V1,PS),CFP(V2,PS),...) /// CTP(other types, PS) = PS - Value *collapseToPrimitiveShadow(Value *Shadow, Instruction *Pos); + Value *collapseToPrimitiveShadow(Value *Shadow, BasicBlock::iterator Pos); void storeZeroPrimitiveShadow(Value *Addr, uint64_t Size, Align ShadowAlign, - Instruction *Pos); + BasicBlock::iterator Pos); Align getShadowAlign(Align InstAlignment); @@ -724,7 +726,7 @@ private: std::pair<Value *, Value *> loadShadowFast(Value *ShadowAddr, Value *OriginAddr, uint64_t Size, Align ShadowAlign, Align OriginAlign, Value *FirstOrigin, - Instruction *Pos); + BasicBlock::iterator Pos); Align getOriginAlign(Align InstAlignment); @@ -760,8 +762,9 @@ private: /// for untainted sinks. /// * Use __dfsan_maybe_store_origin if there are too many origin store /// instrumentations. - void storeOrigin(Instruction *Pos, Value *Addr, uint64_t Size, Value *Shadow, - Value *Origin, Value *StoreOriginAddr, Align InstAlignment); + void storeOrigin(BasicBlock::iterator Pos, Value *Addr, uint64_t Size, + Value *Shadow, Value *Origin, Value *StoreOriginAddr, + Align InstAlignment); /// Convert a scalar value to an i1 by comparing with 0. Value *convertToBool(Value *V, IRBuilder<> &IRB, const Twine &Name = ""); @@ -774,7 +777,8 @@ private: /// shadow always has primitive type. std::pair<Value *, Value *> loadShadowOriginSansLoadTracking(Value *Addr, uint64_t Size, - Align InstAlignment, Instruction *Pos); + Align InstAlignment, + BasicBlock::iterator Pos); int NumOriginStores = 0; }; @@ -785,7 +789,7 @@ public: DFSanVisitor(DFSanFunction &DFSF) : DFSF(DFSF) {} const DataLayout &getDataLayout() const { - return DFSF.F->getParent()->getDataLayout(); + return DFSF.F->getDataLayout(); } // Combines shadow values and origins for all of I's operands. 
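
Most of the DataFlowSanitizer hunks that follow perform one mechanical migration: insertion points are passed around as BasicBlock::iterator instead of Instruction *, and each IRBuilder is re-anchored with the (block, iterator) constructor. Side by side, the two idioms look like this (sketch only, not code from this patch):

    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/Instruction.h"

    using namespace llvm;

    // Pre-migration idiom: anchor the builder on an Instruction*;
    // anything created through IRB lands immediately before *I.
    static void anchorOnInstruction(Instruction *I) {
      IRBuilder<> IRB(I);
      (void)IRB;
    }

    // Post-migration idiom: carry a BasicBlock::iterator and pass the owning
    // block explicitly. An iterator can also express "just after I"
    // (std::next(I->getIterator())) or the very top of a block.
    static void anchorOnIterator(BasicBlock::iterator Pos) {
      IRBuilder<> IRB(Pos->getParent(), Pos);
      (void)IRB;
    }
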
@@ -975,7 +979,7 @@ bool DFSanFunction::shouldInstrumentWithCall() { } Value *DFSanFunction::expandFromPrimitiveShadow(Type *T, Value *PrimitiveShadow, - Instruction *Pos) { + BasicBlock::iterator Pos) { Type *ShadowTy = DFS.getShadowTy(T); if (!isa<ArrayType>(ShadowTy) && !isa<StructType>(ShadowTy)) @@ -984,7 +988,7 @@ Value *DFSanFunction::expandFromPrimitiveShadow(Type *T, Value *PrimitiveShadow, if (DFS.isZeroShadow(PrimitiveShadow)) return DFS.getZeroShadow(ShadowTy); - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); SmallVector<unsigned, 4> Indices; Value *Shadow = UndefValue::get(ShadowTy); Shadow = expandFromPrimitiveShadowRecursive(Shadow, Indices, ShadowTy, @@ -1025,7 +1029,7 @@ Value *DFSanFunction::collapseToPrimitiveShadow(Value *Shadow, } Value *DFSanFunction::collapseToPrimitiveShadow(Value *Shadow, - Instruction *Pos) { + BasicBlock::iterator Pos) { Type *ShadowTy = Shadow->getType(); if (!isa<ArrayType>(ShadowTy) && !isa<StructType>(ShadowTy)) return Shadow; @@ -1035,7 +1039,7 @@ Value *DFSanFunction::collapseToPrimitiveShadow(Value *Shadow, if (CS && DT.dominates(CS, Pos)) return CS; - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); Value *PrimitiveShadow = collapseToPrimitiveShadow(Shadow, IRB); // Caches the converted primitive shadow value. CS = PrimitiveShadow; @@ -1225,8 +1229,8 @@ bool DataFlowSanitizer::initializeModule(Module &M) { FunctionType::get(Type::getVoidTy(*Ctx), DFSanMemTransferCallbackArgs, /*isVarArg=*/false); - ColdCallWeights = MDBuilder(*Ctx).createBranchWeights(1, 1000); - OriginStoreWeights = MDBuilder(*Ctx).createBranchWeights(1, 1000); + ColdCallWeights = MDBuilder(*Ctx).createUnlikelyBranchWeights(); + OriginStoreWeights = MDBuilder(*Ctx).createUnlikelyBranchWeights(); return true; } @@ -1542,7 +1546,8 @@ bool DataFlowSanitizer::runImpl( SmallPtrSet<Constant *, 1> PersonalityFns; for (Function &F : M) if (!F.isIntrinsic() && !DFSanRuntimeFunctions.contains(&F) && - !LibAtomicFunction(F)) { + !LibAtomicFunction(F) && + !F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation)) { FnsToInstrument.push_back(&F); if (F.hasPersonalityFn()) PersonalityFns.insert(F.getPersonalityFn()->stripPointerCasts()); @@ -1760,14 +1765,14 @@ bool DataFlowSanitizer::runImpl( // instrumentation. 
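
The ClDebugNonzeroLabels hunk below recomputes its insertion point with iterators. Extracted into a helper (hypothetical name), the placement logic amounts to: start right after the defining instruction, or at the top of the entry block for non-instruction values, then skip past any PHIs and allocas:

    #include <iterator>
    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/Function.h"
    #include "llvm/IR/Instructions.h"

    using namespace llvm;

    // Where to insert the nonzero-label debug check for value V of function F.
    // Assumes V is not a block terminator, so the skip loop stops at the
    // terminator at the latest.
    static BasicBlock::iterator debugCheckInsertPt(Function &F, Value *V) {
      BasicBlock::iterator Pos;
      if (auto *I = dyn_cast<Instruction>(V))
        Pos = std::next(I->getIterator());
      else
        Pos = F.getEntryBlock().begin();   // e.g. V is a function argument
      while (isa<PHINode>(&*Pos) || isa<AllocaInst>(&*Pos))
        Pos = std::next(Pos);
      return Pos;
    }
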
if (ClDebugNonzeroLabels) { for (Value *V : DFSF.NonZeroChecks) { - Instruction *Pos; + BasicBlock::iterator Pos; if (Instruction *I = dyn_cast<Instruction>(V)) - Pos = I->getNextNode(); + Pos = std::next(I->getIterator()); else - Pos = &DFSF.F->getEntryBlock().front(); + Pos = DFSF.F->getEntryBlock().begin(); while (isa<PHINode>(Pos) || isa<AllocaInst>(Pos)) - Pos = Pos->getNextNode(); - IRBuilder<> IRB(Pos); + Pos = std::next(Pos->getIterator()); + IRBuilder<> IRB(Pos->getParent(), Pos); Value *PrimitiveShadow = DFSF.collapseToPrimitiveShadow(V, Pos); Value *Ne = IRB.CreateICmpNE(PrimitiveShadow, DFSF.DFS.ZeroPrimitiveShadow); @@ -1799,8 +1804,8 @@ Value *DFSanFunction::getRetvalTLS(Type *T, IRBuilder<> &IRB) { Value *DFSanFunction::getRetvalOriginTLS() { return DFS.RetvalOriginTLS; } Value *DFSanFunction::getArgOriginTLS(unsigned ArgNo, IRBuilder<> &IRB) { - return IRB.CreateConstGEP2_64(DFS.ArgOriginTLSTy, DFS.ArgOriginTLS, 0, ArgNo, - "_dfsarg_o"); + return IRB.CreateConstInBoundsGEP2_64(DFS.ArgOriginTLSTy, DFS.ArgOriginTLS, 0, + ArgNo, "_dfsarg_o"); } Value *DFSanFunction::getOrigin(Value *V) { @@ -1838,7 +1843,7 @@ void DFSanFunction::setOrigin(Instruction *I, Value *Origin) { Value *DFSanFunction::getShadowForTLSArgument(Argument *A) { unsigned ArgOffset = 0; - const DataLayout &DL = F->getParent()->getDataLayout(); + const DataLayout &DL = F->getDataLayout(); for (auto &FArg : F->args()) { if (!FArg.getType()->isSized()) { if (A == &FArg) @@ -1912,9 +1917,9 @@ Value *DataFlowSanitizer::getShadowOffset(Value *Addr, IRBuilder<> &IRB) { std::pair<Value *, Value *> DataFlowSanitizer::getShadowOriginAddress(Value *Addr, Align InstAlignment, - Instruction *Pos) { + BasicBlock::iterator Pos) { // Returns ((Addr & shadow_mask) + origin_base - shadow_base) & ~4UL - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); Value *ShadowOffset = getShadowOffset(Addr, IRB); Value *ShadowLong = ShadowOffset; uint64_t ShadowBase = MapParams->ShadowBase; @@ -1944,27 +1949,30 @@ DataFlowSanitizer::getShadowOriginAddress(Value *Addr, Align InstAlignment, return std::make_pair(ShadowPtr, OriginPtr); } -Value *DataFlowSanitizer::getShadowAddress(Value *Addr, Instruction *Pos, +Value *DataFlowSanitizer::getShadowAddress(Value *Addr, + BasicBlock::iterator Pos, Value *ShadowOffset) { - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); return IRB.CreateIntToPtr(ShadowOffset, PrimitiveShadowPtrTy); } -Value *DataFlowSanitizer::getShadowAddress(Value *Addr, Instruction *Pos) { - IRBuilder<> IRB(Pos); +Value *DataFlowSanitizer::getShadowAddress(Value *Addr, + BasicBlock::iterator Pos) { + IRBuilder<> IRB(Pos->getParent(), Pos); Value *ShadowOffset = getShadowOffset(Addr, IRB); return getShadowAddress(Addr, Pos, ShadowOffset); } Value *DFSanFunction::combineShadowsThenConvert(Type *T, Value *V1, Value *V2, - Instruction *Pos) { + BasicBlock::iterator Pos) { Value *PrimitiveValue = combineShadows(V1, V2, Pos); return expandFromPrimitiveShadow(T, PrimitiveValue, Pos); } // Generates IR to compute the union of the two given shadows, inserting it // before Pos. The combined value is with primitive type. 
-Value *DFSanFunction::combineShadows(Value *V1, Value *V2, Instruction *Pos) { +Value *DFSanFunction::combineShadows(Value *V1, Value *V2, + BasicBlock::iterator Pos) { if (DFS.isZeroShadow(V1)) return collapseToPrimitiveShadow(V2, Pos); if (DFS.isZeroShadow(V2)) @@ -2002,7 +2010,7 @@ Value *DFSanFunction::combineShadows(Value *V1, Value *V2, Instruction *Pos) { Value *PV1 = collapseToPrimitiveShadow(V1, Pos); Value *PV2 = collapseToPrimitiveShadow(V2, Pos); - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); CCS.Block = Pos->getParent(); CCS.Shadow = IRB.CreateOr(PV1, PV2); @@ -2031,9 +2039,11 @@ Value *DFSanFunction::combineOperandShadows(Instruction *Inst) { Value *Shadow = getShadow(Inst->getOperand(0)); for (unsigned I = 1, N = Inst->getNumOperands(); I < N; ++I) - Shadow = combineShadows(Shadow, getShadow(Inst->getOperand(I)), Inst); + Shadow = combineShadows(Shadow, getShadow(Inst->getOperand(I)), + Inst->getIterator()); - return expandFromPrimitiveShadow(Inst->getType(), Shadow, Inst); + return expandFromPrimitiveShadow(Inst->getType(), Shadow, + Inst->getIterator()); } void DFSanVisitor::visitInstOperands(Instruction &I) { @@ -2044,7 +2054,8 @@ void DFSanVisitor::visitInstOperands(Instruction &I) { Value *DFSanFunction::combineOrigins(const std::vector<Value *> &Shadows, const std::vector<Value *> &Origins, - Instruction *Pos, ConstantInt *Zero) { + BasicBlock::iterator Pos, + ConstantInt *Zero) { assert(Shadows.size() == Origins.size()); size_t Size = Origins.size(); if (Size == 0) @@ -2063,7 +2074,7 @@ Value *DFSanFunction::combineOrigins(const std::vector<Value *> &Shadows, } Value *OpShadow = Shadows[I]; Value *PrimitiveShadow = collapseToPrimitiveShadow(OpShadow, Pos); - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); Value *Cond = IRB.CreateICmpNE(PrimitiveShadow, Zero); Origin = IRB.CreateSelect(Cond, OpOrigin, Origin); } @@ -2078,7 +2089,7 @@ Value *DFSanFunction::combineOperandOrigins(Instruction *Inst) { Shadows[I] = getShadow(Inst->getOperand(I)); Origins[I] = getOrigin(Inst->getOperand(I)); } - return combineOrigins(Shadows, Origins, Inst); + return combineOrigins(Shadows, Origins, Inst->getIterator()); } void DFSanVisitor::visitInstOperandOrigins(Instruction &I) { @@ -2129,9 +2140,10 @@ bool DFSanFunction::useCallbackLoadLabelAndOrigin(uint64_t Size, return Alignment < MinOriginAlignment || !DFS.hasLoadSizeForFastPath(Size); } -Value *DataFlowSanitizer::loadNextOrigin(Instruction *Pos, Align OriginAlign, +Value *DataFlowSanitizer::loadNextOrigin(BasicBlock::iterator Pos, + Align OriginAlign, Value **OriginAddr) { - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); *OriginAddr = IRB.CreateGEP(OriginTy, *OriginAddr, ConstantInt::get(IntptrTy, 1)); return IRB.CreateAlignedLoad(OriginTy, *OriginAddr, OriginAlign); @@ -2139,7 +2151,7 @@ Value *DataFlowSanitizer::loadNextOrigin(Instruction *Pos, Align OriginAlign, std::pair<Value *, Value *> DFSanFunction::loadShadowFast( Value *ShadowAddr, Value *OriginAddr, uint64_t Size, Align ShadowAlign, - Align OriginAlign, Value *FirstOrigin, Instruction *Pos) { + Align OriginAlign, Value *FirstOrigin, BasicBlock::iterator Pos) { const bool ShouldTrackOrigins = DFS.shouldTrackOrigins(); const uint64_t ShadowSize = Size * DFS.ShadowWidthBytes; @@ -2163,7 +2175,7 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowFast( Type *WideShadowTy = ShadowSize == 4 ? 
Type::getInt32Ty(*DFS.Ctx) : Type::getInt64Ty(*DFS.Ctx); - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); Value *CombinedWideShadow = IRB.CreateAlignedLoad(WideShadowTy, ShadowAddr, ShadowAlign); @@ -2225,14 +2237,14 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowFast( } std::pair<Value *, Value *> DFSanFunction::loadShadowOriginSansLoadTracking( - Value *Addr, uint64_t Size, Align InstAlignment, Instruction *Pos) { + Value *Addr, uint64_t Size, Align InstAlignment, BasicBlock::iterator Pos) { const bool ShouldTrackOrigins = DFS.shouldTrackOrigins(); // Non-escaped loads. if (AllocaInst *AI = dyn_cast<AllocaInst>(Addr)) { const auto SI = AllocaShadowMap.find(AI); if (SI != AllocaShadowMap.end()) { - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); Value *ShadowLI = IRB.CreateLoad(DFS.PrimitiveShadowTy, SI->second); const auto OI = AllocaOriginMap.find(AI); assert(!ShouldTrackOrigins || OI != AllocaOriginMap.end()); @@ -2267,7 +2279,7 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowOriginSansLoadTracking( // tracking. if (ShouldTrackOrigins && useCallbackLoadLabelAndOrigin(Size, InstAlignment)) { - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); CallInst *Call = IRB.CreateCall(DFS.DFSanLoadLabelAndOriginFn, {Addr, ConstantInt::get(DFS.IntptrTy, Size)}); @@ -2286,7 +2298,7 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowOriginSansLoadTracking( const Align OriginAlign = getOriginAlign(InstAlignment); Value *Origin = nullptr; if (ShouldTrackOrigins) { - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); Origin = IRB.CreateAlignedLoad(DFS.OriginTy, OriginAddr, OriginAlign); } @@ -2299,7 +2311,7 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowOriginSansLoadTracking( return {LI, Origin}; } case 2: { - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); Value *ShadowAddr1 = IRB.CreateGEP(DFS.PrimitiveShadowTy, ShadowAddr, ConstantInt::get(DFS.IntptrTy, 1)); Value *Load = @@ -2315,23 +2327,22 @@ std::pair<Value *, Value *> DFSanFunction::loadShadowOriginSansLoadTracking( return loadShadowFast(ShadowAddr, OriginAddr, Size, ShadowAlign, OriginAlign, Origin, Pos); - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); CallInst *FallbackCall = IRB.CreateCall( DFS.DFSanUnionLoadFn, {ShadowAddr, ConstantInt::get(DFS.IntptrTy, Size)}); FallbackCall->addRetAttr(Attribute::ZExt); return {FallbackCall, Origin}; } -std::pair<Value *, Value *> DFSanFunction::loadShadowOrigin(Value *Addr, - uint64_t Size, - Align InstAlignment, - Instruction *Pos) { +std::pair<Value *, Value *> +DFSanFunction::loadShadowOrigin(Value *Addr, uint64_t Size, Align InstAlignment, + BasicBlock::iterator Pos) { Value *PrimitiveShadow, *Origin; std::tie(PrimitiveShadow, Origin) = loadShadowOriginSansLoadTracking(Addr, Size, InstAlignment, Pos); if (DFS.shouldTrackOrigins()) { if (ClTrackOrigins == 2) { - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); auto *ConstantShadow = dyn_cast<Constant>(PrimitiveShadow); if (!ConstantShadow || !ConstantShadow->isZeroValue()) Origin = updateOriginIfTainted(PrimitiveShadow, Origin, IRB); @@ -2381,7 +2392,7 @@ Value *StripPointerGEPsAndCasts(Value *V) { } void DFSanVisitor::visitLoadInst(LoadInst &LI) { - auto &DL = LI.getModule()->getDataLayout(); + auto &DL = LI.getDataLayout(); uint64_t Size = DL.getTypeStoreSize(LI.getType()); if (Size == 0) { DFSF.setShadow(&LI, DFSF.DFS.getZeroShadow(&LI)); @@ -2397,8 +2408,11 @@ void DFSanVisitor::visitLoadInst(LoadInst &LI) { 
if (LI.isAtomic()) LI.setOrdering(addAcquireOrdering(LI.getOrdering())); - Instruction *AfterLi = LI.getNextNode(); - Instruction *Pos = LI.isAtomic() ? LI.getNextNode() : &LI; + BasicBlock::iterator AfterLi = std::next(LI.getIterator()); + BasicBlock::iterator Pos = LI.getIterator(); + if (LI.isAtomic()) + Pos = std::next(Pos); + std::vector<Value *> Shadows; std::vector<Value *> Origins; Value *PrimitiveShadow, *Origin; @@ -2431,14 +2445,14 @@ void DFSanVisitor::visitLoadInst(LoadInst &LI) { } if (ClEventCallbacks) { - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); Value *Addr = LI.getPointerOperand(); CallInst *CI = IRB.CreateCall(DFSF.DFS.DFSanLoadCallbackFn, {PrimitiveShadow, Addr}); CI->addParamAttr(0, Attribute::ZExt); } - IRBuilder<> IRB(AfterLi); + IRBuilder<> IRB(AfterLi->getParent(), AfterLi); DFSF.addReachesFunctionCallbacksIfEnabled(IRB, LI, &LI); } @@ -2456,7 +2470,7 @@ Value *DFSanFunction::updateOrigin(Value *V, IRBuilder<> &IRB) { Value *DFSanFunction::originToIntptr(IRBuilder<> &IRB, Value *Origin) { const unsigned OriginSize = DataFlowSanitizer::OriginWidthBytes; - const DataLayout &DL = F->getParent()->getDataLayout(); + const DataLayout &DL = F->getDataLayout(); unsigned IntptrSize = DL.getTypeStoreSize(DFS.IntptrTy); if (IntptrSize == OriginSize) return Origin; @@ -2469,7 +2483,7 @@ void DFSanFunction::paintOrigin(IRBuilder<> &IRB, Value *Origin, Value *StoreOriginAddr, uint64_t StoreOriginSize, Align Alignment) { const unsigned OriginSize = DataFlowSanitizer::OriginWidthBytes; - const DataLayout &DL = F->getParent()->getDataLayout(); + const DataLayout &DL = F->getDataLayout(); const Align IntptrAlignment = DL.getABITypeAlign(DFS.IntptrTy); unsigned IntptrSize = DL.getTypeStoreSize(DFS.IntptrTy); assert(IntptrAlignment >= MinOriginAlignment); @@ -2510,14 +2524,14 @@ Value *DFSanFunction::convertToBool(Value *V, IRBuilder<> &IRB, return IRB.CreateICmpNE(V, ConstantInt::get(VTy, 0), Name); } -void DFSanFunction::storeOrigin(Instruction *Pos, Value *Addr, uint64_t Size, - Value *Shadow, Value *Origin, +void DFSanFunction::storeOrigin(BasicBlock::iterator Pos, Value *Addr, + uint64_t Size, Value *Shadow, Value *Origin, Value *StoreOriginAddr, Align InstAlignment) { // Do not write origins for zero shadows because we do not trace origins for // untainted sinks. 
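
storeOrigin, continued below, only writes an origin when the collapsed shadow can be nonzero; for a non-constant shadow that means branching on shadow != 0 under the same unlikely branch weights used elsewhere in this patch. A reduced sketch of that guard, with an illustrative helper name and without the call-threshold and constant-shadow special cases:

    #include "llvm/IR/Constants.h"
    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/MDBuilder.h"
    #include "llvm/Transforms/Utils/BasicBlockUtils.h"

    using namespace llvm;

    // Emit `*OriginAddr = Origin;` only on the unlikely path where the
    // (integer-typed) Shadow is nonzero. Assumes IRB is currently positioned
    // before an existing instruction.
    static void storeOriginIfTainted(IRBuilder<> &IRB, Value *Shadow,
                                     Value *Origin, Value *OriginAddr) {
      Value *Zero = ConstantInt::get(Shadow->getType(), 0);
      Value *Tainted = IRB.CreateICmpNE(Shadow, Zero);
      MDNode *Weights =
          MDBuilder(IRB.getContext()).createUnlikelyBranchWeights();
      Instruction *Then = SplitBlockAndInsertIfThen(
          Tainted, &*IRB.GetInsertPoint(), /*Unreachable=*/false, Weights);
      IRB.SetInsertPoint(Then);          // inside the "tainted" block
      IRB.CreateStore(Origin, OriginAddr);
    }
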
const Align OriginAlignment = getOriginAlign(InstAlignment); Value *CollapsedShadow = collapseToPrimitiveShadow(Shadow, Pos); - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); if (auto *ConstantShadow = dyn_cast<Constant>(CollapsedShadow)) { if (!ConstantShadow->isZeroValue()) paintOrigin(IRB, updateOrigin(Origin, IRB), StoreOriginAddr, Size, @@ -2543,8 +2557,8 @@ void DFSanFunction::storeOrigin(Instruction *Pos, Value *Addr, uint64_t Size, void DFSanFunction::storeZeroPrimitiveShadow(Value *Addr, uint64_t Size, Align ShadowAlign, - Instruction *Pos) { - IRBuilder<> IRB(Pos); + BasicBlock::iterator Pos) { + IRBuilder<> IRB(Pos->getParent(), Pos); IntegerType *ShadowTy = IntegerType::get(*DFS.Ctx, Size * DFS.ShadowWidthBits); Value *ExtZeroShadow = ConstantInt::get(ShadowTy, 0); @@ -2558,13 +2572,13 @@ void DFSanFunction::storePrimitiveShadowOrigin(Value *Addr, uint64_t Size, Align InstAlignment, Value *PrimitiveShadow, Value *Origin, - Instruction *Pos) { + BasicBlock::iterator Pos) { const bool ShouldTrackOrigins = DFS.shouldTrackOrigins() && Origin; if (AllocaInst *AI = dyn_cast<AllocaInst>(Addr)) { const auto SI = AllocaShadowMap.find(AI); if (SI != AllocaShadowMap.end()) { - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); IRB.CreateStore(PrimitiveShadow, SI->second); // Do not write origins for 0 shadows because we do not trace origins for @@ -2584,7 +2598,7 @@ void DFSanFunction::storePrimitiveShadowOrigin(Value *Addr, uint64_t Size, return; } - IRBuilder<> IRB(Pos); + IRBuilder<> IRB(Pos->getParent(), Pos); Value *ShadowAddr, *OriginAddr; std::tie(ShadowAddr, OriginAddr) = DFS.getShadowOriginAddress(Addr, InstAlignment, Pos); @@ -2645,7 +2659,7 @@ static AtomicOrdering addReleaseOrdering(AtomicOrdering AO) { } void DFSanVisitor::visitStoreInst(StoreInst &SI) { - auto &DL = SI.getModule()->getDataLayout(); + auto &DL = SI.getDataLayout(); Value *Val = SI.getValueOperand(); uint64_t Size = DL.getTypeStoreSize(Val->getType()); if (Size == 0) @@ -2679,15 +2693,15 @@ void DFSanVisitor::visitStoreInst(StoreInst &SI) { Shadows.push_back(PtrShadow); Origins.push_back(DFSF.getOrigin(SI.getPointerOperand())); } - PrimitiveShadow = DFSF.combineShadows(Shadow, PtrShadow, &SI); + PrimitiveShadow = DFSF.combineShadows(Shadow, PtrShadow, SI.getIterator()); } else { - PrimitiveShadow = DFSF.collapseToPrimitiveShadow(Shadow, &SI); + PrimitiveShadow = DFSF.collapseToPrimitiveShadow(Shadow, SI.getIterator()); } Value *Origin = nullptr; if (ShouldTrackOrigins) - Origin = DFSF.combineOrigins(Shadows, Origins, &SI); + Origin = DFSF.combineOrigins(Shadows, Origins, SI.getIterator()); DFSF.storePrimitiveShadowOrigin(SI.getPointerOperand(), Size, SI.getAlign(), - PrimitiveShadow, Origin, &SI); + PrimitiveShadow, Origin, SI.getIterator()); if (ClEventCallbacks) { IRBuilder<> IRB(&SI); Value *Addr = SI.getPointerOperand(); @@ -2701,7 +2715,7 @@ void DFSanVisitor::visitCASOrRMW(Align InstAlignment, Instruction &I) { assert(isa<AtomicRMWInst>(I) || isa<AtomicCmpXchgInst>(I)); Value *Val = I.getOperand(1); - const auto &DL = I.getModule()->getDataLayout(); + const auto &DL = I.getDataLayout(); uint64_t Size = DL.getTypeStoreSize(Val->getType()); if (Size == 0) return; @@ -2711,7 +2725,7 @@ void DFSanVisitor::visitCASOrRMW(Align InstAlignment, Instruction &I) { IRBuilder<> IRB(&I); Value *Addr = I.getOperand(0); const Align ShadowAlign = DFSF.getShadowAlign(InstAlignment); - DFSF.storeZeroPrimitiveShadow(Addr, Size, ShadowAlign, &I); + DFSF.storeZeroPrimitiveShadow(Addr, Size, 
ShadowAlign, I.getIterator()); DFSF.setShadow(&I, DFSF.DFS.getZeroShadow(&I)); DFSF.setOrigin(&I, DFSF.DFS.ZeroOrigin); } @@ -2866,7 +2880,7 @@ void DFSanVisitor::visitSelectInst(SelectInst &I) { if (isa<VectorType>(I.getCondition()->getType())) { ShadowSel = DFSF.combineShadowsThenConvert(I.getType(), TrueShadow, - FalseShadow, &I); + FalseShadow, I.getIterator()); if (ShouldTrackOrigins) { Shadows.push_back(TrueShadow); Shadows.push_back(FalseShadow); @@ -2881,25 +2895,25 @@ void DFSanVisitor::visitSelectInst(SelectInst &I) { Origins.push_back(TrueOrigin); } } else { - ShadowSel = - SelectInst::Create(I.getCondition(), TrueShadow, FalseShadow, "", &I); + ShadowSel = SelectInst::Create(I.getCondition(), TrueShadow, FalseShadow, + "", I.getIterator()); if (ShouldTrackOrigins) { Shadows.push_back(ShadowSel); Origins.push_back(SelectInst::Create(I.getCondition(), TrueOrigin, - FalseOrigin, "", &I)); + FalseOrigin, "", I.getIterator())); } } } - DFSF.setShadow(&I, ClTrackSelectControlFlow - ? DFSF.combineShadowsThenConvert( - I.getType(), CondShadow, ShadowSel, &I) - : ShadowSel); + DFSF.setShadow(&I, ClTrackSelectControlFlow ? DFSF.combineShadowsThenConvert( + I.getType(), CondShadow, + ShadowSel, I.getIterator()) + : ShadowSel); if (ShouldTrackOrigins) { if (ClTrackSelectControlFlow) { Shadows.push_back(CondShadow); Origins.push_back(DFSF.getOrigin(I.getCondition())); } - DFSF.setOrigin(&I, DFSF.combineOrigins(Shadows, Origins, &I)); + DFSF.setOrigin(&I, DFSF.combineOrigins(Shadows, Origins, I.getIterator())); } } @@ -2926,8 +2940,8 @@ void DFSanVisitor::visitMemTransferInst(MemTransferInst &I) { IRB.CreateIntCast(I.getArgOperand(2), DFSF.DFS.IntptrTy, false)}); } - Value *DestShadow = DFSF.DFS.getShadowAddress(I.getDest(), &I); - Value *SrcShadow = DFSF.DFS.getShadowAddress(I.getSource(), &I); + Value *DestShadow = DFSF.DFS.getShadowAddress(I.getDest(), I.getIterator()); + Value *SrcShadow = DFSF.DFS.getShadowAddress(I.getSource(), I.getIterator()); Value *LenShadow = IRB.CreateMul(I.getLength(), ConstantInt::get(I.getLength()->getType(), DFSF.DFS.ShadowWidthBytes)); @@ -2996,7 +3010,8 @@ void DFSanVisitor::addShadowArguments(Function &F, CallBase &CB, // Adds non-variable argument shadows. for (unsigned N = FT->getNumParams(); N != 0; ++I, --N) - Args.push_back(DFSF.collapseToPrimitiveShadow(DFSF.getShadow(*I), &CB)); + Args.push_back( + DFSF.collapseToPrimitiveShadow(DFSF.getShadow(*I), CB.getIterator())); // Adds variable argument shadows. 
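
For variadic callees with custom wrappers, the shadows of the trailing arguments are not appended to the call individually; they are packed into a caller-side stack array whose address is what gets passed, as the allocas in this hunk show. Condensed into a helper (hypothetical name and parameters), the packing looks like:

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/IR/DerivedTypes.h"
    #include "llvm/IR/Function.h"
    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/Instructions.h"

    using namespace llvm;

    // Store one shadow per variadic argument into a stack array allocated in
    // the caller's entry block, and return the address of its first element,
    // which is what gets appended to the wrapper call's argument list.
    static Value *packVarArgShadows(Function &Caller, IRBuilder<> &IRB,
                                    Type *ShadowTy, ArrayRef<Value *> Shadows,
                                    unsigned AllocaAS) {
      ArrayType *ArrTy = ArrayType::get(ShadowTy, Shadows.size());
      auto *Arr = new AllocaInst(ArrTy, AllocaAS, "labelva",
                                 Caller.getEntryBlock().begin());
      for (unsigned N = 0; N < Shadows.size(); ++N)
        IRB.CreateStore(Shadows[N], IRB.CreateStructGEP(ArrTy, Arr, N));
      return IRB.CreateStructGEP(ArrTy, Arr, 0);
    }
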
if (FT->isVarArg()) { @@ -3004,12 +3019,13 @@ void DFSanVisitor::addShadowArguments(Function &F, CallBase &CB, CB.arg_size() - FT->getNumParams()); auto *LabelVAAlloca = new AllocaInst(LabelVATy, getDataLayout().getAllocaAddrSpace(), - "labelva", &DFSF.F->getEntryBlock().front()); + "labelva", DFSF.F->getEntryBlock().begin()); for (unsigned N = 0; I != CB.arg_end(); ++I, ++N) { auto *LabelVAPtr = IRB.CreateStructGEP(LabelVATy, LabelVAAlloca, N); - IRB.CreateStore(DFSF.collapseToPrimitiveShadow(DFSF.getShadow(*I), &CB), - LabelVAPtr); + IRB.CreateStore( + DFSF.collapseToPrimitiveShadow(DFSF.getShadow(*I), CB.getIterator()), + LabelVAPtr); } Args.push_back(IRB.CreateStructGEP(LabelVATy, LabelVAAlloca, 0)); @@ -3020,7 +3036,7 @@ void DFSanVisitor::addShadowArguments(Function &F, CallBase &CB, if (!DFSF.LabelReturnAlloca) { DFSF.LabelReturnAlloca = new AllocaInst( DFSF.DFS.PrimitiveShadowTy, getDataLayout().getAllocaAddrSpace(), - "labelreturn", &DFSF.F->getEntryBlock().front()); + "labelreturn", DFSF.F->getEntryBlock().begin()); } Args.push_back(DFSF.LabelReturnAlloca); } @@ -3043,7 +3059,7 @@ void DFSanVisitor::addOriginArguments(Function &F, CallBase &CB, ArrayType::get(DFSF.DFS.OriginTy, CB.arg_size() - FT->getNumParams()); auto *OriginVAAlloca = new AllocaInst(OriginVATy, getDataLayout().getAllocaAddrSpace(), - "originva", &DFSF.F->getEntryBlock().front()); + "originva", DFSF.F->getEntryBlock().begin()); for (unsigned N = 0; I != CB.arg_end(); ++I, ++N) { auto *OriginVAPtr = IRB.CreateStructGEP(OriginVATy, OriginVAAlloca, N); @@ -3058,7 +3074,7 @@ void DFSanVisitor::addOriginArguments(Function &F, CallBase &CB, if (!DFSF.OriginReturnAlloca) { DFSF.OriginReturnAlloca = new AllocaInst( DFSF.DFS.OriginTy, getDataLayout().getAllocaAddrSpace(), - "originreturn", &DFSF.F->getEntryBlock().front()); + "originreturn", DFSF.F->getEntryBlock().begin()); } Args.push_back(DFSF.OriginReturnAlloca); } @@ -3155,8 +3171,9 @@ bool DFSanVisitor::visitWrappedCallBase(Function &F, CallBase &CB) { if (!FT->getReturnType()->isVoidTy()) { LoadInst *LabelLoad = IRB.CreateLoad(DFSF.DFS.PrimitiveShadowTy, DFSF.LabelReturnAlloca); - DFSF.setShadow(CustomCI, DFSF.expandFromPrimitiveShadow( - FT->getReturnType(), LabelLoad, &CB)); + DFSF.setShadow(CustomCI, + DFSF.expandFromPrimitiveShadow( + FT->getReturnType(), LabelLoad, CB.getIterator())); if (ShouldTrackOrigins) { LoadInst *OriginLoad = IRB.CreateLoad(DFSF.DFS.OriginTy, DFSF.OriginReturnAlloca); @@ -3185,8 +3202,7 @@ Value *DFSanVisitor::makeAddAcquireOrderingTable(IRBuilder<> &IRB) { OrderingTable[(int)AtomicOrderingCABI::seq_cst] = (int)AtomicOrderingCABI::seq_cst; - return ConstantDataVector::get(IRB.getContext(), - ArrayRef(OrderingTable, NumOrderings)); + return ConstantDataVector::get(IRB.getContext(), OrderingTable); } void DFSanVisitor::visitLibAtomicLoad(CallBase &CB) { @@ -3229,8 +3245,7 @@ Value *DFSanVisitor::makeAddReleaseOrderingTable(IRBuilder<> &IRB) { OrderingTable[(int)AtomicOrderingCABI::seq_cst] = (int)AtomicOrderingCABI::seq_cst; - return ConstantDataVector::get(IRB.getContext(), - ArrayRef(OrderingTable, NumOrderings)); + return ConstantDataVector::get(IRB.getContext(), OrderingTable); } void DFSanVisitor::visitLibAtomicStore(CallBase &CB) { @@ -3433,8 +3448,8 @@ void DFSanVisitor::visitCallBase(CallBase &CB) { void DFSanVisitor::visitPHINode(PHINode &PN) { Type *ShadowTy = DFSF.DFS.getShadowTy(&PN); - PHINode *ShadowPN = - PHINode::Create(ShadowTy, PN.getNumIncomingValues(), "", &PN); + PHINode *ShadowPN = PHINode::Create(ShadowTy, 
PN.getNumIncomingValues(), "", + PN.getIterator()); // Give the shadow phi node valid predecessors to fool SplitEdge into working. Value *UndefShadow = UndefValue::get(ShadowTy); @@ -3445,8 +3460,8 @@ void DFSanVisitor::visitPHINode(PHINode &PN) { PHINode *OriginPN = nullptr; if (DFSF.DFS.shouldTrackOrigins()) { - OriginPN = - PHINode::Create(DFSF.DFS.OriginTy, PN.getNumIncomingValues(), "", &PN); + OriginPN = PHINode::Create(DFSF.DFS.OriginTy, PN.getNumIncomingValues(), "", + PN.getIterator()); Value *UndefOrigin = UndefValue::get(DFSF.DFS.OriginTy); for (BasicBlock *BB : PN.blocks()) OriginPN->addIncoming(UndefOrigin, BB); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp index efb621cde906..a0e63bf12400 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -15,11 +15,15 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/StackSafetyAnalysis.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -49,6 +53,8 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/MD5.h" +#include "llvm/Support/RandomNumberGenerator.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Instrumentation/AddressSanitizerCommon.h" @@ -58,6 +64,7 @@ #include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <optional> +#include <random> using namespace llvm; @@ -177,6 +184,18 @@ static cl::opt<bool> ClWithTls( "platforms that support this"), cl::Hidden, cl::init(true)); +static cl::opt<int> ClHotPercentileCutoff("hwasan-percentile-cutoff-hot", + cl::desc("Hot percentile cuttoff.")); + +static cl::opt<float> + ClRandomSkipRate("hwasan-random-rate", + cl::desc("Probability value in the range [0.0, 1.0] " + "to keep instrumentation of a function.")); + +STATISTIC(NumTotalFuncs, "Number of total funcs"); +STATISTIC(NumInstrumentedFuncs, "Number of instrumented funcs"); +STATISTIC(NumNoProfileSummaryFuncs, "Number of funcs without PS"); + // Mode for selecting how to insert frame record info into the stack ring // buffer. enum RecordStackHistoryMode { @@ -236,6 +255,10 @@ static cl::opt<bool> ClUsePageAliases("hwasan-experimental-use-page-aliases", namespace { +template <typename T> T optOr(cl::opt<T> &Opt, T Other) { + return Opt.getNumOccurrences() ? Opt : Other; +} + bool shouldUsePageAliases(const Triple &TargetTriple) { return ClUsePageAliases && TargetTriple.getArch() == Triple::x86_64; } @@ -245,14 +268,11 @@ bool shouldInstrumentStack(const Triple &TargetTriple) { } bool shouldInstrumentWithCalls(const Triple &TargetTriple) { - return ClInstrumentWithCalls.getNumOccurrences() - ? 
ClInstrumentWithCalls - : TargetTriple.getArch() == Triple::x86_64; + return optOr(ClInstrumentWithCalls, TargetTriple.getArch() == Triple::x86_64); } bool mightUseStackSafetyAnalysis(bool DisableOptimization) { - return ClUseStackSafety.getNumOccurrences() ? ClUseStackSafety - : !DisableOptimization; + return optOr(ClUseStackSafety, !DisableOptimization); } bool shouldUseStackSafetyAnalysis(const Triple &TargetTriple, @@ -272,10 +292,10 @@ public: HWAddressSanitizer(Module &M, bool CompileKernel, bool Recover, const StackSafetyGlobalInfo *SSI) : M(M), SSI(SSI) { - this->Recover = ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover; - this->CompileKernel = ClEnableKhwasan.getNumOccurrences() > 0 - ? ClEnableKhwasan - : CompileKernel; + this->Recover = optOr(ClRecover, Recover); + this->CompileKernel = optOr(ClEnableKhwasan, CompileKernel); + this->Rng = ClRandomSkipRate.getNumOccurrences() ? M.createRNG(DEBUG_TYPE) + : nullptr; initializeModule(); } @@ -290,8 +310,9 @@ private: Value *PtrTag = nullptr; Value *MemTag = nullptr; }; - void setSSI(const StackSafetyGlobalInfo *S) { SSI = S; } + bool selectiveInstrumentationShouldSkip(Function &F, + FunctionAnalysisManager &FAM) const; void initializeModule(); void createHwasanCtorComdat(); @@ -316,13 +337,17 @@ private: unsigned AccessSizeIndex, Instruction *InsertBefore, DomTreeUpdater &DTU, LoopInfo *LI); - bool ignoreMemIntrinsic(MemIntrinsic *MI); + bool ignoreMemIntrinsic(OptimizationRemarkEmitter &ORE, MemIntrinsic *MI); void instrumentMemIntrinsic(MemIntrinsic *MI); bool instrumentMemAccess(InterestingMemoryOperand &O, DomTreeUpdater &DTU, LoopInfo *LI); - bool ignoreAccess(Instruction *Inst, Value *Ptr); + bool ignoreAccessWithoutRemark(Instruction *Inst, Value *Ptr); + bool ignoreAccess(OptimizationRemarkEmitter &ORE, Instruction *Inst, + Value *Ptr); + void getInterestingMemoryOperands( - Instruction *I, const TargetLibraryInfo &TLI, + OptimizationRemarkEmitter &ORE, Instruction *I, + const TargetLibraryInfo &TLI, SmallVectorImpl<InterestingMemoryOperand> &Interesting); void tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *Tag, size_t Size); @@ -331,14 +356,13 @@ private: bool instrumentStack(memtag::StackInfo &Info, Value *StackTag, Value *UARTag, const DominatorTree &DT, const PostDominatorTree &PDT, const LoopInfo &LI); - Value *readRegister(IRBuilder<> &IRB, StringRef Name); bool instrumentLandingPads(SmallVectorImpl<Instruction *> &RetVec); Value *getNextTagWithCall(IRBuilder<> &IRB); Value *getStackBaseTag(IRBuilder<> &IRB); Value *getAllocaTag(IRBuilder<> &IRB, Value *StackTag, unsigned AllocaNo); Value *getUARTag(IRBuilder<> &IRB); - Value *getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty); + Value *getHwasanThreadSlotPtr(IRBuilder<> &IRB); Value *applyTagMask(IRBuilder<> &IRB, Value *OldTag); unsigned retagMask(unsigned AllocaNo); @@ -347,8 +371,7 @@ private: void instrumentGlobal(GlobalVariable *GV, uint8_t Tag); void instrumentGlobals(); - Value *getPC(IRBuilder<> &IRB); - Value *getSP(IRBuilder<> &IRB); + Value *getCachedFP(IRBuilder<> &IRB); Value *getFrameRecordInfo(IRBuilder<> &IRB); void instrumentPersonalityFunctions(); @@ -357,6 +380,7 @@ private: Module &M; const StackSafetyGlobalInfo *SSI; Triple TargetTriple; + std::unique_ptr<RandomNumberGenerator> Rng; /// This struct defines the shadow mapping using the rule: /// shadow = (mem >> Scale) + Offset. 
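// Illustrative sketch, not part of the upstream diff: a scalar restatement of
// the mapping rule documented above, shadow = (mem >> Scale) + Offset. Scale
// and Offset are placeholders here; the pass derives the real values from the
// target and command-line options elsewhere.
#include <cstdint>

static uint64_t shadowAddressFor(uint64_t Mem, unsigned Scale,
                                 uint64_t Offset) {
  // E.g. with 16-byte granules (Scale == 4), sixteen bytes of application
  // memory share the single shadow byte at (Mem >> 4) + Offset.
  return (Mem >> Scale) + Offset;
}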
@@ -383,10 +407,10 @@ private: ShadowMapping Mapping; Type *VoidTy = Type::getVoidTy(M.getContext()); - Type *IntptrTy; - PointerType *PtrTy; - Type *Int8Ty; - Type *Int32Ty; + Type *IntptrTy = M.getDataLayout().getIntPtrType(M.getContext()); + PointerType *PtrTy = PointerType::getUnqual(M.getContext()); + Type *Int8Ty = Type::getInt8Ty(M.getContext()); + Type *Int32Ty = Type::getInt32Ty(M.getContext()); Type *Int64Ty = Type::getInt64Ty(M.getContext()); bool CompileKernel; @@ -397,6 +421,7 @@ private: bool InstrumentLandingPads; bool InstrumentWithCalls; bool InstrumentStack; + bool InstrumentGlobals; bool DetectUseAfterScope; bool UsePageAliases; bool UseMatchAllCallback; @@ -422,7 +447,7 @@ private: Value *ShadowBase = nullptr; Value *StackBaseTag = nullptr; - Value *CachedSP = nullptr; + Value *CachedFP = nullptr; GlobalValue *ThreadPtrGlobal = nullptr; }; @@ -567,8 +592,6 @@ void HWAddressSanitizer::createHwasanCtorComdat() { /// inserts a call to __hwasan_init to the module's constructor list. void HWAddressSanitizer::initializeModule() { LLVM_DEBUG(dbgs() << "Init " << M.getName() << "\n"); - auto &DL = M.getDataLayout(); - TargetTriple = Triple(M.getTargetTriple()); // x86_64 currently has two modes: @@ -586,10 +609,6 @@ void HWAddressSanitizer::initializeModule() { C = &(M.getContext()); IRBuilder<> IRB(*C); - IntptrTy = IRB.getIntPtrTy(DL); - PtrTy = IRB.getPtrTy(); - Int8Ty = IRB.getInt8Ty(); - Int32Ty = IRB.getInt32Ty(); HwasanCtorFunction = nullptr; @@ -599,19 +618,14 @@ void HWAddressSanitizer::initializeModule() { bool NewRuntime = !TargetTriple.isAndroid() || !TargetTriple.isAndroidVersionLT(30); - UseShortGranules = - ClUseShortGranules.getNumOccurrences() ? ClUseShortGranules : NewRuntime; - OutlinedChecks = - (TargetTriple.isAArch64() || TargetTriple.isRISCV64()) && - TargetTriple.isOSBinFormatELF() && - (ClInlineAllChecks.getNumOccurrences() ? !ClInlineAllChecks : !Recover); + UseShortGranules = optOr(ClUseShortGranules, NewRuntime); + OutlinedChecks = (TargetTriple.isAArch64() || TargetTriple.isRISCV64()) && + TargetTriple.isOSBinFormatELF() && + !optOr(ClInlineAllChecks, Recover); - InlineFastPath = - (ClInlineFastPathChecks.getNumOccurrences() - ? ClInlineFastPathChecks - : !(TargetTriple.isAndroid() || - TargetTriple.isOSFuchsia())); // These platforms may prefer less - // inlining to reduce binary size. + // These platforms may prefer less inlining to reduce binary size. + InlineFastPath = optOr(ClInlineFastPathChecks, !(TargetTriple.isAndroid() || + TargetTriple.isOSFuchsia())); if (ClMatchAllTag.getNumOccurrences()) { if (ClMatchAllTag != -1) { @@ -623,22 +637,19 @@ void HWAddressSanitizer::initializeModule() { UseMatchAllCallback = !CompileKernel && MatchAllTag.has_value(); // If we don't have personality function support, fall back to landing pads. - InstrumentLandingPads = ClInstrumentLandingPads.getNumOccurrences() - ? ClInstrumentLandingPads - : !NewRuntime; + InstrumentLandingPads = optOr(ClInstrumentLandingPads, !NewRuntime); + + InstrumentGlobals = + !CompileKernel && !UsePageAliases && optOr(ClGlobals, NewRuntime); if (!CompileKernel) { createHwasanCtorComdat(); - bool InstrumentGlobals = - ClGlobals.getNumOccurrences() ? ClGlobals : NewRuntime; - if (InstrumentGlobals && !UsePageAliases) + if (InstrumentGlobals) instrumentGlobals(); bool InstrumentPersonalityFunctions = - ClInstrumentPersonalityFunctions.getNumOccurrences() - ? 
ClInstrumentPersonalityFunctions - : NewRuntime; + optOr(ClInstrumentPersonalityFunctions, NewRuntime); if (InstrumentPersonalityFunctions) instrumentPersonalityFunctions(); } @@ -758,7 +769,8 @@ Value *HWAddressSanitizer::getShadowNonTls(IRBuilder<> &IRB) { return IRB.CreateLoad(PtrTy, GlobalDynamicAddress); } -bool HWAddressSanitizer::ignoreAccess(Instruction *Inst, Value *Ptr) { +bool HWAddressSanitizer::ignoreAccessWithoutRemark(Instruction *Inst, + Value *Ptr) { // Do not instrument accesses from different address spaces; we cannot deal // with them. Type *PtrTy = cast<PointerType>(Ptr->getType()->getScalarType()); @@ -778,11 +790,33 @@ bool HWAddressSanitizer::ignoreAccess(Instruction *Inst, Value *Ptr) { if (SSI && SSI->stackAccessIsSafe(*Inst)) return true; } + + if (isa<GlobalVariable>(getUnderlyingObject(Ptr))) { + if (!InstrumentGlobals) + return true; + // TODO: Optimize inbound global accesses, like Asan `instrumentMop`. + } + return false; } +bool HWAddressSanitizer::ignoreAccess(OptimizationRemarkEmitter &ORE, + Instruction *Inst, Value *Ptr) { + bool Ignored = ignoreAccessWithoutRemark(Inst, Ptr); + if (Ignored) { + ORE.emit( + [&]() { return OptimizationRemark(DEBUG_TYPE, "ignoreAccess", Inst); }); + } else { + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "ignoreAccess", Inst); + }); + } + return Ignored; +} + void HWAddressSanitizer::getInterestingMemoryOperands( - Instruction *I, const TargetLibraryInfo &TLI, + OptimizationRemarkEmitter &ORE, Instruction *I, + const TargetLibraryInfo &TLI, SmallVectorImpl<InterestingMemoryOperand> &Interesting) { // Skip memory accesses inserted by another instrumentation. if (I->hasMetadata(LLVMContext::MD_nosanitize)) @@ -793,22 +827,22 @@ void HWAddressSanitizer::getInterestingMemoryOperands( return; if (LoadInst *LI = dyn_cast<LoadInst>(I)) { - if (!ClInstrumentReads || ignoreAccess(I, LI->getPointerOperand())) + if (!ClInstrumentReads || ignoreAccess(ORE, I, LI->getPointerOperand())) return; Interesting.emplace_back(I, LI->getPointerOperandIndex(), false, LI->getType(), LI->getAlign()); } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) { - if (!ClInstrumentWrites || ignoreAccess(I, SI->getPointerOperand())) + if (!ClInstrumentWrites || ignoreAccess(ORE, I, SI->getPointerOperand())) return; Interesting.emplace_back(I, SI->getPointerOperandIndex(), true, SI->getValueOperand()->getType(), SI->getAlign()); } else if (AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(I)) { - if (!ClInstrumentAtomics || ignoreAccess(I, RMW->getPointerOperand())) + if (!ClInstrumentAtomics || ignoreAccess(ORE, I, RMW->getPointerOperand())) return; Interesting.emplace_back(I, RMW->getPointerOperandIndex(), true, RMW->getValOperand()->getType(), std::nullopt); } else if (AtomicCmpXchgInst *XCHG = dyn_cast<AtomicCmpXchgInst>(I)) { - if (!ClInstrumentAtomics || ignoreAccess(I, XCHG->getPointerOperand())) + if (!ClInstrumentAtomics || ignoreAccess(ORE, I, XCHG->getPointerOperand())) return; Interesting.emplace_back(I, XCHG->getPointerOperandIndex(), true, XCHG->getCompareOperand()->getType(), @@ -816,7 +850,7 @@ void HWAddressSanitizer::getInterestingMemoryOperands( } else if (auto *CI = dyn_cast<CallInst>(I)) { for (unsigned ArgNo = 0; ArgNo < CI->arg_size(); ArgNo++) { if (!ClInstrumentByval || !CI->isByValArgument(ArgNo) || - ignoreAccess(I, CI->getArgOperand(ArgNo))) + ignoreAccess(ORE, I, CI->getArgOperand(ArgNo))) continue; Type *Ty = CI->getParamByValType(ArgNo); Interesting.emplace_back(I, ArgNo, false, Ty, Align(1)); @@ -898,7 +932,7 @@ 
HWAddressSanitizer::insertShadowTagCheck(Value *Ptr, Instruction *InsertBefore, R.TagMismatchTerm = SplitBlockAndInsertIfThen( TagMismatch, InsertBefore, false, - MDBuilder(*C).createBranchWeights(1, 100000), &DTU, LI); + MDBuilder(*C).createUnlikelyBranchWeights(), &DTU, LI); return R; } @@ -917,11 +951,33 @@ void HWAddressSanitizer::instrumentMemAccessOutline(Value *Ptr, bool IsWrite, IRBuilder<> IRB(InsertBefore); Module *M = IRB.GetInsertBlock()->getParent()->getParent(); - IRB.CreateCall(Intrinsic::getDeclaration( - M, UseShortGranules - ? Intrinsic::hwasan_check_memaccess_shortgranules - : Intrinsic::hwasan_check_memaccess), - {ShadowBase, Ptr, ConstantInt::get(Int32Ty, AccessInfo)}); + bool useFixedShadowIntrinsic = false; + // The memaccess fixed shadow intrinsic is only supported on AArch64, + // which allows a 16-bit immediate to be left-shifted by 32. + // Since kShadowBaseAlignment == 32, and Linux by default will not + // mmap above 48-bits, practically any valid shadow offset is + // representable. + // In particular, an offset of 4TB (1024 << 32) is representable, and + // ought to be good enough for anybody. + if (TargetTriple.isAArch64() && Mapping.Offset != kDynamicShadowSentinel) { + uint16_t offset_shifted = Mapping.Offset >> 32; + useFixedShadowIntrinsic = (uint64_t)offset_shifted << 32 == Mapping.Offset; + } + + if (useFixedShadowIntrinsic) + IRB.CreateCall( + Intrinsic::getDeclaration( + M, UseShortGranules + ? Intrinsic::hwasan_check_memaccess_shortgranules_fixedshadow + : Intrinsic::hwasan_check_memaccess_fixedshadow), + {Ptr, ConstantInt::get(Int32Ty, AccessInfo), + ConstantInt::get(Int64Ty, Mapping.Offset)}); + else + IRB.CreateCall(Intrinsic::getDeclaration( + M, UseShortGranules + ? Intrinsic::hwasan_check_memaccess_shortgranules + : Intrinsic::hwasan_check_memaccess), + {ShadowBase, Ptr, ConstantInt::get(Int32Ty, AccessInfo)}); } void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, @@ -939,7 +995,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, IRB.CreateICmpUGT(TCI.MemTag, ConstantInt::get(Int8Ty, 15)); Instruction *CheckFailTerm = SplitBlockAndInsertIfThen( OutOfShortGranuleTagRange, TCI.TagMismatchTerm, !Recover, - MDBuilder(*C).createBranchWeights(1, 100000), &DTU, LI); + MDBuilder(*C).createUnlikelyBranchWeights(), &DTU, LI); IRB.SetInsertPoint(TCI.TagMismatchTerm); Value *PtrLowBits = IRB.CreateTrunc(IRB.CreateAnd(TCI.PtrLong, 15), Int8Ty); @@ -947,7 +1003,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, PtrLowBits, ConstantInt::get(Int8Ty, (1 << AccessSizeIndex) - 1)); Value *PtrLowBitsOOB = IRB.CreateICmpUGE(PtrLowBits, TCI.MemTag); SplitBlockAndInsertIfThen(PtrLowBitsOOB, TCI.TagMismatchTerm, false, - MDBuilder(*C).createBranchWeights(1, 100000), &DTU, + MDBuilder(*C).createUnlikelyBranchWeights(), &DTU, LI, CheckFailTerm->getParent()); IRB.SetInsertPoint(TCI.TagMismatchTerm); @@ -956,7 +1012,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, Value *InlineTag = IRB.CreateLoad(Int8Ty, InlineTagAddr); Value *InlineTagMismatch = IRB.CreateICmpNE(TCI.PtrTag, InlineTag); SplitBlockAndInsertIfThen(InlineTagMismatch, TCI.TagMismatchTerm, false, - MDBuilder(*C).createBranchWeights(1, 100000), &DTU, + MDBuilder(*C).createUnlikelyBranchWeights(), &DTU, LI, CheckFailTerm->getParent()); IRB.SetInsertPoint(CheckFailTerm); @@ -999,13 +1055,14 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *Ptr, bool IsWrite, ->setSuccessor(0, 
TCI.TagMismatchTerm->getParent()); } -bool HWAddressSanitizer::ignoreMemIntrinsic(MemIntrinsic *MI) { +bool HWAddressSanitizer::ignoreMemIntrinsic(OptimizationRemarkEmitter &ORE, + MemIntrinsic *MI) { if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { - return (!ClInstrumentWrites || ignoreAccess(MTI, MTI->getDest())) && - (!ClInstrumentReads || ignoreAccess(MTI, MTI->getSource())); + return (!ClInstrumentWrites || ignoreAccess(ORE, MTI, MTI->getDest())) && + (!ClInstrumentReads || ignoreAccess(ORE, MTI, MTI->getSource())); } if (isa<MemSetInst>(MI)) - return !ClInstrumentWrites || ignoreAccess(MI, MI->getDest()); + return !ClInstrumentWrites || ignoreAccess(ORE, MI, MI->getDest()); return false; } @@ -1148,10 +1205,10 @@ Value *HWAddressSanitizer::getStackBaseTag(IRBuilder<> &IRB) { // Extract some entropy from the stack pointer for the tags. // Take bits 20..28 (ASLR entropy) and xor with bits 0..8 (these differ // between functions). - Value *StackPointerLong = getSP(IRB); + Value *FramePointerLong = getCachedFP(IRB); Value *StackTag = - applyTagMask(IRB, IRB.CreateXor(StackPointerLong, - IRB.CreateLShr(StackPointerLong, 20))); + applyTagMask(IRB, IRB.CreateXor(FramePointerLong, + IRB.CreateLShr(FramePointerLong, 20))); StackTag->setName("hwasan.stack.base.tag"); return StackTag; } @@ -1165,9 +1222,9 @@ Value *HWAddressSanitizer::getAllocaTag(IRBuilder<> &IRB, Value *StackTag, } Value *HWAddressSanitizer::getUARTag(IRBuilder<> &IRB) { - Value *StackPointerLong = getSP(IRB); + Value *FramePointerLong = getCachedFP(IRB); Value *UARTag = - applyTagMask(IRB, IRB.CreateLShr(StackPointerLong, PointerTagShift)); + applyTagMask(IRB, IRB.CreateLShr(FramePointerLong, PointerTagShift)); UARTag->setName("hwasan.uar.tag"); return UARTag; @@ -1210,57 +1267,37 @@ Value *HWAddressSanitizer::untagPointer(IRBuilder<> &IRB, Value *PtrLong) { return UntaggedPtrLong; } -Value *HWAddressSanitizer::getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty) { - Module *M = IRB.GetInsertBlock()->getParent()->getParent(); - if (TargetTriple.isAArch64() && TargetTriple.isAndroid()) { - // Android provides a fixed TLS slot for sanitizers. See TLS_SLOT_SANITIZER - // in Bionic's libc/private/bionic_tls.h. - Function *ThreadPointerFunc = - Intrinsic::getDeclaration(M, Intrinsic::thread_pointer); - return IRB.CreateConstGEP1_32(Int8Ty, IRB.CreateCall(ThreadPointerFunc), - 0x30); - } - if (ThreadPtrGlobal) - return ThreadPtrGlobal; - - return nullptr; -} - -Value *HWAddressSanitizer::getPC(IRBuilder<> &IRB) { - if (TargetTriple.getArch() == Triple::aarch64) - return readRegister(IRB, "pc"); - return IRB.CreatePtrToInt(IRB.GetInsertBlock()->getParent(), IntptrTy); +Value *HWAddressSanitizer::getHwasanThreadSlotPtr(IRBuilder<> &IRB) { + // Android provides a fixed TLS slot for sanitizers. See TLS_SLOT_SANITIZER + // in Bionic's libc/platform/bionic/tls_defines.h. + constexpr int SanitizerSlot = 6; + if (TargetTriple.isAArch64() && TargetTriple.isAndroid()) + return memtag::getAndroidSlotPtr(IRB, SanitizerSlot); + return ThreadPtrGlobal; } -Value *HWAddressSanitizer::getSP(IRBuilder<> &IRB) { - if (!CachedSP) { - // FIXME: use addressofreturnaddress (but implement it in aarch64 backend - // first). 
- Function *F = IRB.GetInsertBlock()->getParent(); - Module *M = F->getParent(); - auto *GetStackPointerFn = Intrinsic::getDeclaration( - M, Intrinsic::frameaddress, - IRB.getPtrTy(M->getDataLayout().getAllocaAddrSpace())); - CachedSP = IRB.CreatePtrToInt( - IRB.CreateCall(GetStackPointerFn, {Constant::getNullValue(Int32Ty)}), - IntptrTy); - } - return CachedSP; +Value *HWAddressSanitizer::getCachedFP(IRBuilder<> &IRB) { + if (!CachedFP) + CachedFP = memtag::getFP(IRB); + return CachedFP; } Value *HWAddressSanitizer::getFrameRecordInfo(IRBuilder<> &IRB) { // Prepare ring buffer data. - Value *PC = getPC(IRB); - Value *SP = getSP(IRB); + Value *PC = memtag::getPC(TargetTriple, IRB); + Value *FP = getCachedFP(IRB); - // Mix SP and PC. + // Mix FP and PC. // Assumptions: // PC is 0x0000PPPPPPPPPPPP (48 bits are meaningful, others are zero) - // SP is 0xsssssssssssSSSS0 (4 lower bits are zero) - // We only really need ~20 lower non-zero bits (SSSS), so we mix like this: - // 0xSSSSPPPPPPPPPPPP - SP = IRB.CreateShl(SP, 44); - return IRB.CreateOr(PC, SP); + // FP is 0xfffffffffffFFFF0 (4 lower bits are zero) + // We only really need ~20 lower non-zero bits (FFFF), so we mix like this: + // 0xFFFFPPPPPPPPPPPP + // + // FP works because in AArch64FrameLowering::getFrameIndexReference, we + // prefer FP-relative offsets for functions compiled with HWASan. + FP = IRB.CreateShl(FP, 44); + return IRB.CreateOr(PC, FP); } void HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord) { @@ -1278,7 +1315,7 @@ void HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord) { auto getThreadLongMaybeUntagged = [&]() { if (!SlotPtr) - SlotPtr = getHwasanThreadSlotPtr(IRB, IntptrTy); + SlotPtr = getHwasanThreadSlotPtr(IRB); if (!ThreadLong) ThreadLong = IRB.CreateLoad(IntptrTy, SlotPtr); // Extract the address field from ThreadLong. Unnecessary on AArch64 with @@ -1314,6 +1351,22 @@ void HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord) { // The use of AShr instead of LShr is due to // https://bugs.llvm.org/show_bug.cgi?id=39030 // Runtime library makes sure not to use the highest bit. + // + // Mechanical proof of this address calculation can be found at: + // https://github.com/google/sanitizers/blob/master/hwaddress-sanitizer/prove_hwasanwrap.smt2 + // + // Example of the wrap case for N = 1 + // Pointer: 0x01AAAAAAAAAAAFF8 + // + + // 0x0000000000000008 + // = + // 0x01AAAAAAAAAAB000 + // & + // WrapMask: 0xFFFFFFFFFFFFF000 + // = + // 0x01AAAAAAAAAAA000 + // + // Then the WrapMask will be a no-op until the next wrap case. 
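// Illustrative sketch, not part of the upstream diff: scalar restatements of
// the two computations described in this hunk's comments. The frame-record
// packing mirrors the CreateShl/CreateOr sequence above; the ring-buffer wrap
// mirrors the WrapMask IR emitted just below, under the assumption that the
// top byte of ThreadLong selects the buffer size and one record slot is
// 8 bytes (the "+ 8" bump itself is outside this hunk).
#include <cstdint>

static uint64_t packFrameRecord(uint64_t PC, uint64_t FP) {
  // PC occupies the low 48 bits; the interesting low bits of FP land in the
  // top bits, giving the 0xFFFFPPPPPPPPPPPP layout described above.
  return PC | (FP << 44);
}

static uint64_t bumpAndWrapRingBufferPtr(uint64_t ThreadLong) {
  // WrapMask == ~(N << 12), where N is the sign-extended top byte, so the
  // bumped pointer folds back to the start of the buffer when it overflows.
  uint64_t WrapMask = ~(uint64_t)(((int64_t)ThreadLong >> 56) << 12);
  return (ThreadLong + 8) & WrapMask;
}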
Value *WrapMask = IRB.CreateXor( IRB.CreateShl(IRB.CreateAShr(ThreadLong, 56), 12, "", true, true), ConstantInt::get(IntptrTy, (uint64_t)-1)); @@ -1345,32 +1398,18 @@ void HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord) { } } -Value *HWAddressSanitizer::readRegister(IRBuilder<> &IRB, StringRef Name) { - Module *M = IRB.GetInsertBlock()->getParent()->getParent(); - Function *ReadRegister = - Intrinsic::getDeclaration(M, Intrinsic::read_register, IntptrTy); - MDNode *MD = MDNode::get(*C, {MDString::get(*C, Name)}); - Value *Args[] = {MetadataAsValue::get(*C, MD)}; - return IRB.CreateCall(ReadRegister, Args); -} - bool HWAddressSanitizer::instrumentLandingPads( SmallVectorImpl<Instruction *> &LandingPadVec) { for (auto *LP : LandingPadVec) { - IRBuilder<> IRB(LP->getNextNode()); + IRBuilder<> IRB(LP->getNextNonDebugInstruction()); IRB.CreateCall( HwasanHandleVfork, - {readRegister(IRB, (TargetTriple.getArch() == Triple::x86_64) ? "rsp" - : "sp")}); + {memtag::readRegister( + IRB, (TargetTriple.getArch() == Triple::x86_64) ? "rsp" : "sp")}); } return true; } -static bool isLifetimeIntrinsic(Value *V) { - auto *II = dyn_cast<IntrinsicInst>(V); - return II && II->isLifetimeStartOrEnd(); -} - bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo, Value *StackTag, Value *UARTag, const DominatorTree &DT, @@ -1387,7 +1426,7 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo, auto N = I++; auto *AI = KV.first; memtag::AllocaInfo &Info = KV.second; - IRBuilder<> IRB(AI->getNextNode()); + IRBuilder<> IRB(AI->getNextNonDebugInstruction()); // Replace uses of the alloca with tagged address. Value *Tag = getAllocaTag(IRB, StackTag, N); @@ -1422,20 +1461,11 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo, AI->replaceUsesWithIf(Replacement, [AICast, AILong](const Use &U) { auto *User = U.getUser(); - return User != AILong && User != AICast && !isLifetimeIntrinsic(User); + return User != AILong && User != AICast && + !memtag::isLifetimeIntrinsic(User); }); - for (auto *DDI : Info.DbgVariableIntrinsics) { - // Prepend "tag_offset, N" to the dwarf expression. - // Tag offset logically applies to the alloca pointer, and it makes sense - // to put it at the beginning of the expression. - SmallVector<uint64_t, 8> NewOps = {dwarf::DW_OP_LLVM_tag_offset, - retagMask(N)}; - for (size_t LocNo = 0; LocNo < DDI->getNumVariableLocationOps(); ++LocNo) - if (DDI->getVariableLocationOp(LocNo) == AI) - DDI->setExpression(DIExpression::appendOpsToArg(DDI->getExpression(), - NewOps, LocNo)); - } + memtag::annotateDebugRecords(Info, retagMask(N)); auto TagEnd = [&](Instruction *Node) { IRB.SetInsertPoint(Node); @@ -1450,10 +1480,10 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo, // function return. Work around this by always untagging at every return // statement if return_twice functions are called. 
bool StandardLifetime = + !SInfo.CallsReturnTwice && SInfo.UnrecognizedLifetimes.empty() && memtag::isStandardLifetime(Info.LifetimeStart, Info.LifetimeEnd, &DT, - &LI, ClMaxLifetimes) && - !SInfo.CallsReturnTwice; + &LI, ClMaxLifetimes); if (DetectUseAfterScope && StandardLifetime) { IntrinsicInst *Start = Info.LifetimeStart[0]; IRB.SetInsertPoint(Start->getNextNode()); @@ -1481,6 +1511,44 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo, return true; } +static void emitRemark(const Function &F, OptimizationRemarkEmitter &ORE, + bool Skip) { + if (Skip) { + ORE.emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "Skip", &F) + << "Skipped: F=" << ore::NV("Function", &F); + }); + } else { + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "Sanitize", &F) + << "Sanitized: F=" << ore::NV("Function", &F); + }); + } +} + +bool HWAddressSanitizer::selectiveInstrumentationShouldSkip( + Function &F, FunctionAnalysisManager &FAM) const { + bool Skip = [&]() { + if (ClRandomSkipRate.getNumOccurrences()) { + std::bernoulli_distribution D(ClRandomSkipRate); + return !D(*Rng); + } + if (!ClHotPercentileCutoff.getNumOccurrences()) + return false; + auto &MAMProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F); + ProfileSummaryInfo *PSI = + MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); + if (!PSI || !PSI->hasProfileSummary()) { + ++NumNoProfileSummaryFuncs; + return false; + } + return PSI->isFunctionHotInCallGraphNthPercentile( + ClHotPercentileCutoff, &F, FAM.getResult<BlockFrequencyAnalysis>(F)); + }(); + emitRemark(F, FAM.getResult<OptimizationRemarkEmitterAnalysis>(F), Skip); + return Skip; +} + void HWAddressSanitizer::sanitizeFunction(Function &F, FunctionAnalysisManager &FAM) { if (&F == HwasanCtorFunction) @@ -1489,6 +1557,19 @@ void HWAddressSanitizer::sanitizeFunction(Function &F, if (!F.hasFnAttribute(Attribute::SanitizeHWAddress)) return; + if (F.empty()) + return; + + NumTotalFuncs++; + + OptimizationRemarkEmitter &ORE = + FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); + + if (selectiveInstrumentationShouldSkip(F, FAM)) + return; + + NumInstrumentedFuncs++; + LLVM_DEBUG(dbgs() << "Function: " << F.getName() << "\n"); SmallVector<InterestingMemoryOperand, 16> OperandsToInstrument; @@ -1505,10 +1586,10 @@ void HWAddressSanitizer::sanitizeFunction(Function &F, if (InstrumentLandingPads && isa<LandingPadInst>(Inst)) LandingPadVec.push_back(&Inst); - getInterestingMemoryOperands(&Inst, TLI, OperandsToInstrument); + getInterestingMemoryOperands(ORE, &Inst, TLI, OperandsToInstrument); if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(&Inst)) - if (!ignoreMemIntrinsic(MI)) + if (!ignoreMemIntrinsic(ORE, MI)) IntrinToInstrument.push_back(MI); } @@ -1532,8 +1613,16 @@ void HWAddressSanitizer::sanitizeFunction(Function &F, assert(!ShadowBase); - Instruction *InsertPt = &*F.getEntryBlock().begin(); - IRBuilder<> EntryIRB(InsertPt); + // Remove memory attributes that are about to become invalid. + // HWASan checks read from shadow, which invalidates memory(argmem: *) + // Short granule checks on function arguments read from the argument memory + // (last byte of the granule), which invalidates writeonly. 
+ F.removeFnAttr(llvm::Attribute::Memory); + for (auto &A : F.args()) + A.removeAttr(llvm::Attribute::WriteOnly); + + BasicBlock::iterator InsertPt = F.getEntryBlock().begin(); + IRBuilder<> EntryIRB(&F.getEntryBlock(), InsertPt); emitPrologue(EntryIRB, /*WithFrameRecord*/ ClRecordStackHistory != none && Mapping.WithFrameRecord && @@ -1552,12 +1641,12 @@ void HWAddressSanitizer::sanitizeFunction(Function &F, // entry block back into the entry block so that they aren't treated as // dynamic allocas. if (EntryIRB.GetInsertBlock() != &F.getEntryBlock()) { - InsertPt = &*F.getEntryBlock().begin(); + InsertPt = F.getEntryBlock().begin(); for (Instruction &I : llvm::make_early_inc_range(*EntryIRB.GetInsertBlock())) { if (auto *AI = dyn_cast<AllocaInst>(&I)) if (isa<ConstantInt>(AI->getArraySize())) - I.moveBefore(InsertPt); + I.moveBefore(F.getEntryBlock(), InsertPt); } } @@ -1576,7 +1665,7 @@ void HWAddressSanitizer::sanitizeFunction(Function &F, ShadowBase = nullptr; StackBaseTag = nullptr; - CachedSP = nullptr; + CachedFP = nullptr; } void HWAddressSanitizer::instrumentGlobal(GlobalVariable *GV, uint8_t Tag) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index 7344fea17517..0d1f50698637 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -13,13 +13,16 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/IndirectCallPromotionAnalysis.h" #include "llvm/Analysis/IndirectCallVisitor.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" @@ -40,7 +43,9 @@ #include <cassert> #include <cstdint> #include <memory> +#include <set> #include <string> +#include <unordered_map> #include <utility> #include <vector> @@ -51,6 +56,12 @@ using namespace llvm; STATISTIC(NumOfPGOICallPromotion, "Number of indirect call promotions."); STATISTIC(NumOfPGOICallsites, "Number of indirect call candidate sites."); +extern cl::opt<unsigned> MaxNumVTableAnnotations; + +namespace llvm { +extern cl::opt<bool> EnableVTableProfileUse; +} + // Command line option to disable indirect-call promotion with the default as // false. This is for debug purpose. static cl::opt<bool> DisableICP("disable-icp", cl::init(false), cl::Hidden, @@ -103,13 +114,196 @@ static cl::opt<bool> ICPDUMPAFTER("icp-dumpafter", cl::init(false), cl::Hidden, cl::desc("Dump IR after transformation happens")); +// Indirect call promotion pass will fall back to function-based comparison if +// vtable-count / function-count is smaller than this threshold. +static cl::opt<float> ICPVTablePercentageThreshold( + "icp-vtable-percentage-threshold", cl::init(0.99), cl::Hidden, + cl::desc("The percentage threshold of vtable-count / function-count for " + "cost-benefit analysis.")); + +// Although comparing vtables can save a vtable load, we may need to compare +// vtable pointer with multiple vtable address points due to class inheritance. 
+// Comparing with multiple vtables inserts additional instructions on hot code +// path, and doing so for an earlier candidate delays the comparisons for later +// candidates. For the last candidate, only the fallback path is affected. +// We allow multiple vtable comparison for the last function candidate and use +// the option below to cap the number of vtables. +static cl::opt<int> ICPMaxNumVTableLastCandidate( + "icp-max-num-vtable-last-candidate", cl::init(1), cl::Hidden, + cl::desc("The maximum number of vtable for the last candidate.")); + namespace { +// The key is a vtable global variable, and the value is a map. +// In the inner map, the key represents address point offsets and the value is a +// constant for this address point. +using VTableAddressPointOffsetValMap = + SmallDenseMap<const GlobalVariable *, std::unordered_map<int, Constant *>>; + +// A struct to collect type information for a virtual call site. +struct VirtualCallSiteInfo { + // The offset from the address point to virtual function in the vtable. + uint64_t FunctionOffset; + // The instruction that computes the address point of vtable. + Instruction *VPtr; + // The compatible type used in LLVM type intrinsics. + StringRef CompatibleTypeStr; +}; + +// The key is a virtual call, and value is its type information. +using VirtualCallSiteTypeInfoMap = + SmallDenseMap<const CallBase *, VirtualCallSiteInfo>; + +// The key is vtable GUID, and value is its value profile count. +using VTableGUIDCountsMap = SmallDenseMap<uint64_t, uint64_t, 16>; + +// Return the address point offset of the given compatible type. +// +// Type metadata of a vtable specifies the types that can contain a pointer to +// this vtable, for example, `Base*` can be a pointer to an derived type +// but not vice versa. See also https://llvm.org/docs/TypeMetadata.html +static std::optional<uint64_t> +getAddressPointOffset(const GlobalVariable &VTableVar, + StringRef CompatibleType) { + SmallVector<MDNode *> Types; + VTableVar.getMetadata(LLVMContext::MD_type, Types); + + for (MDNode *Type : Types) + if (auto *TypeId = dyn_cast<MDString>(Type->getOperand(1).get()); + TypeId && TypeId->getString() == CompatibleType) + return cast<ConstantInt>( + cast<ConstantAsMetadata>(Type->getOperand(0))->getValue()) + ->getZExtValue(); + + return std::nullopt; +} + +// Return a constant representing the vtable's address point specified by the +// offset. +static Constant *getVTableAddressPointOffset(GlobalVariable *VTable, + uint32_t AddressPointOffset) { + Module &M = *VTable->getParent(); + LLVMContext &Context = M.getContext(); + assert(AddressPointOffset < + M.getDataLayout().getTypeAllocSize(VTable->getValueType()) && + "Out-of-bound access"); + + return ConstantExpr::getInBoundsGetElementPtr( + Type::getInt8Ty(Context), VTable, + llvm::ConstantInt::get(Type::getInt32Ty(Context), AddressPointOffset)); +} + +// Return the basic block in which Use `U` is used via its `UserInst`. +static BasicBlock *getUserBasicBlock(Use &U, Instruction *UserInst) { + if (PHINode *PN = dyn_cast<PHINode>(UserInst)) + return PN->getIncomingBlock(U); + + return UserInst->getParent(); +} + +// `DestBB` is a suitable basic block to sink `Inst` into when `Inst` have users +// and all users are in `DestBB`. The caller guarantees that `Inst->getParent()` +// is the sole predecessor of `DestBB` and `DestBB` is dominated by +// `Inst->getParent()`. +static bool isDestBBSuitableForSink(Instruction *Inst, BasicBlock *DestBB) { + // 'BB' is used only by assert. 
+ [[maybe_unused]] BasicBlock *BB = Inst->getParent(); + + assert(BB != DestBB && BB->getTerminator()->getNumSuccessors() == 2 && + DestBB->getUniquePredecessor() == BB && + "Guaranteed by ICP transformation"); + + BasicBlock *UserBB = nullptr; + for (Use &Use : Inst->uses()) { + User *User = Use.getUser(); + // Do checked cast since IR verifier guarantees that the user of an + // instruction must be an instruction. See `Verifier::visitInstruction`. + Instruction *UserInst = cast<Instruction>(User); + // We can sink debug or pseudo instructions together with Inst. + if (UserInst->isDebugOrPseudoInst()) + continue; + UserBB = getUserBasicBlock(Use, UserInst); + // Do not sink if Inst is used in a basic block that is not DestBB. + // TODO: Sink to the common dominator of all user blocks. + if (UserBB != DestBB) + return false; + } + return UserBB != nullptr; +} + +// For the virtual call dispatch sequence, try to sink vtable load instructions +// to the cold indirect call fallback. +// FIXME: Move the sink eligibility check below to a utility function in +// Transforms/Utils/ directory. +static bool tryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { + if (!isDestBBSuitableForSink(I, DestBlock)) + return false; + + // Do not move control-flow-involving, volatile loads, vaarg, alloca + // instructions, etc. + if (isa<PHINode>(I) || I->isEHPad() || I->mayThrow() || !I->willReturn() || + isa<AllocaInst>(I)) + return false; + + // Do not sink convergent call instructions. + if (const auto *C = dyn_cast<CallBase>(I)) + if (C->isInlineAsm() || C->cannotMerge() || C->isConvergent()) + return false; + + // Do not move an instruction that may write to memory. + if (I->mayWriteToMemory()) + return false; + + // We can only sink load instructions if there is nothing between the load and + // the end of block that could change the value. + if (I->mayReadFromMemory()) { + // We already know that SrcBlock is the unique predecessor of DestBlock. + for (BasicBlock::iterator Scan = std::next(I->getIterator()), + E = I->getParent()->end(); + Scan != E; ++Scan) { + // Note analysis analysis can tell whether two pointers can point to the + // same object in memory or not thereby find further opportunities to + // sink. + if (Scan->mayWriteToMemory()) + return false; + } + } + + BasicBlock::iterator InsertPos = DestBlock->getFirstInsertionPt(); + I->moveBefore(*DestBlock, InsertPos); + + // TODO: Sink debug intrinsic users of I to 'DestBlock'. + // 'InstCombinerImpl::tryToSinkInstructionDbgValues' and + // 'InstCombinerImpl::tryToSinkInstructionDbgVariableRecords' already have + // the core logic to do this. + return true; +} + +// Try to sink instructions after VPtr to the indirect call fallback. +// Return the number of sunk IR instructions. +static int tryToSinkInstructions(BasicBlock *OriginalBB, + BasicBlock *IndirectCallBB) { + int SinkCount = 0; + // Do not sink across a critical edge for simplicity. + if (IndirectCallBB->getUniquePredecessor() != OriginalBB) + return SinkCount; + // Sink all eligible instructions in OriginalBB in reverse order. + for (Instruction &I : + llvm::make_early_inc_range(llvm::drop_begin(llvm::reverse(*OriginalBB)))) + if (tryToSinkInstruction(&I, IndirectCallBB)) + SinkCount++; + + return SinkCount; +} + // Promote indirect calls to conditional direct calls, keeping track of // thresholds. 
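// Illustrative sketch, not part of the upstream diff: the source-level shape
// of the if-then-else promotion driven by the IndirectCallPromoter class
// declared just below. The names Callback, hot_target and callIndirect are
// hypothetical; the real transformation is performed on IR by
// promoteCallWithIfThenElse using profile-derived branch weights.
using Callback = int (*)(int);

static int hot_target(int X) { return X + 1; } // hypothetical hottest callee

static int callIndirect(Callback Fp, int Arg) {
  // Original code: return Fp(Arg);
  if (Fp == &hot_target)
    return hot_target(Arg); // promoted: direct, inlinable, profiled hot
  return Fp(Arg);           // fallback keeps the remaining value profile
}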
class IndirectCallPromoter { private: Function &F; + Module &M; + + ProfileSummaryInfo *PSI = nullptr; // Symtab that maps indirect call profile values to function names and // defines. @@ -117,6 +311,11 @@ private: const bool SamplePGO; + // A map from a virtual call to its type information. + const VirtualCallSiteTypeInfoMap &VirtualCSInfo; + + VTableAddressPointOffsetValMap &VTableAddressPointOffsetVal; + OptimizationRemarkEmitter &ORE; // A struct that records the direct target and it's call count. @@ -124,6 +323,16 @@ private: Function *const TargetFunction; const uint64_t Count; + // The following fields only exists for promotion candidates with vtable + // information. + // + // Due to class inheritance, one virtual call candidate can come from + // multiple vtables. `VTableGUIDAndCounts` tracks the vtable GUIDs and + // counts for 'TargetFunction'. `AddressPoints` stores the vtable address + // points for comparison. + VTableGUIDCountsMap VTableGUIDAndCounts; + SmallVector<Constant *> AddressPoints; + PromotionCandidate(Function *F, uint64_t C) : TargetFunction(F), Count(C) {} }; @@ -133,19 +342,63 @@ private: // TotalCount is the total profiled count of call executions, and // NumCandidates is the number of candidate entries in ValueDataRef. std::vector<PromotionCandidate> getPromotionCandidatesForCallSite( - const CallBase &CB, const ArrayRef<InstrProfValueData> &ValueDataRef, + const CallBase &CB, ArrayRef<InstrProfValueData> ValueDataRef, uint64_t TotalCount, uint32_t NumCandidates); - // Promote a list of targets for one indirect-call callsite. Return - // the number of promotions. - uint32_t tryToPromote(CallBase &CB, - const std::vector<PromotionCandidate> &Candidates, - uint64_t &TotalCount); + // Promote a list of targets for one indirect-call callsite by comparing + // indirect callee with functions. Return true if there are IR + // transformations and false otherwise. + bool tryToPromoteWithFuncCmp(CallBase &CB, Instruction *VPtr, + ArrayRef<PromotionCandidate> Candidates, + uint64_t TotalCount, + ArrayRef<InstrProfValueData> ICallProfDataRef, + uint32_t NumCandidates, + VTableGUIDCountsMap &VTableGUIDCounts); + + // Promote a list of targets for one indirect call by comparing vtables with + // functions. Return true if there are IR transformations and false + // otherwise. + bool tryToPromoteWithVTableCmp( + CallBase &CB, Instruction *VPtr, ArrayRef<PromotionCandidate> Candidates, + uint64_t TotalFuncCount, uint32_t NumCandidates, + MutableArrayRef<InstrProfValueData> ICallProfDataRef, + VTableGUIDCountsMap &VTableGUIDCounts); + + // Return true if it's profitable to compare vtables for the callsite. + bool isProfitableToCompareVTables(const CallBase &CB, + ArrayRef<PromotionCandidate> Candidates, + uint64_t TotalCount); + + // Given an indirect callsite and the list of function candidates, compute + // the following vtable information in output parameters and return vtable + // pointer if type profiles exist. + // - Populate `VTableGUIDCounts` with <vtable-guid, count> using !prof + // metadata attached on the vtable pointer. + // - For each function candidate, finds out the vtables from which it gets + // called and stores the <vtable-guid, count> in promotion candidate. 
+ Instruction *computeVTableInfos(const CallBase *CB, + VTableGUIDCountsMap &VTableGUIDCounts, + std::vector<PromotionCandidate> &Candidates); + + Constant *getOrCreateVTableAddressPointVar(GlobalVariable *GV, + uint64_t AddressPointOffset); + + void updateFuncValueProfiles(CallBase &CB, ArrayRef<InstrProfValueData> VDs, + uint64_t Sum, uint32_t MaxMDCount); + + void updateVPtrValueProfiles(Instruction *VPtr, + VTableGUIDCountsMap &VTableGUIDCounts); public: - IndirectCallPromoter(Function &Func, InstrProfSymtab *Symtab, bool SamplePGO, - OptimizationRemarkEmitter &ORE) - : F(Func), Symtab(Symtab), SamplePGO(SamplePGO), ORE(ORE) {} + IndirectCallPromoter( + Function &Func, Module &M, ProfileSummaryInfo *PSI, + InstrProfSymtab *Symtab, bool SamplePGO, + const VirtualCallSiteTypeInfoMap &VirtualCSInfo, + VTableAddressPointOffsetValMap &VTableAddressPointOffsetVal, + OptimizationRemarkEmitter &ORE) + : F(Func), M(M), PSI(PSI), Symtab(Symtab), SamplePGO(SamplePGO), + VirtualCSInfo(VirtualCSInfo), + VTableAddressPointOffsetVal(VTableAddressPointOffsetVal), ORE(ORE) {} IndirectCallPromoter(const IndirectCallPromoter &) = delete; IndirectCallPromoter &operator=(const IndirectCallPromoter &) = delete; @@ -158,7 +411,7 @@ public: // the count. Stop at the first target that is not promoted. std::vector<IndirectCallPromoter::PromotionCandidate> IndirectCallPromoter::getPromotionCandidatesForCallSite( - const CallBase &CB, const ArrayRef<InstrProfValueData> &ValueDataRef, + const CallBase &CB, ArrayRef<InstrProfValueData> ValueDataRef, uint64_t TotalCount, uint32_t NumCandidates) { std::vector<PromotionCandidate> Ret; @@ -241,24 +494,126 @@ IndirectCallPromoter::getPromotionCandidatesForCallSite( return Ret; } +Constant *IndirectCallPromoter::getOrCreateVTableAddressPointVar( + GlobalVariable *GV, uint64_t AddressPointOffset) { + auto [Iter, Inserted] = + VTableAddressPointOffsetVal[GV].try_emplace(AddressPointOffset, nullptr); + if (Inserted) + Iter->second = getVTableAddressPointOffset(GV, AddressPointOffset); + return Iter->second; +} + +Instruction *IndirectCallPromoter::computeVTableInfos( + const CallBase *CB, VTableGUIDCountsMap &GUIDCountsMap, + std::vector<PromotionCandidate> &Candidates) { + if (!EnableVTableProfileUse) + return nullptr; + + // Take the following code sequence as an example, here is how the code works + // @vtable1 = {[n x ptr] [... ptr @func1]} + // @vtable2 = {[m x ptr] [... ptr @func2]} + // + // %vptr = load ptr, ptr %d, !prof !0 + // %0 = tail call i1 @llvm.type.test(ptr %vptr, metadata !"vtable1") + // tail call void @llvm.assume(i1 %0) + // %vfn = getelementptr inbounds ptr, ptr %vptr, i64 1 + // %1 = load ptr, ptr %vfn + // call void %1(ptr %d), !prof !1 + // + // !0 = !{!"VP", i32 2, i64 100, i64 123, i64 50, i64 456, i64 50} + // !1 = !{!"VP", i32 0, i64 100, i64 789, i64 50, i64 579, i64 50} + // + // Step 1. Find out the %vptr instruction for indirect call and use its !prof + // to populate `GUIDCountsMap`. + // Step 2. For each vtable-guid, look up its definition from symtab. LTO can + // make vtable definitions visible across modules. + // Step 3. Compute the byte offset of the virtual call, by adding vtable + // address point offset and function's offset relative to vtable address + // point. For each function candidate, this step tells us the vtable from + // which it comes from, and the vtable address point to compare %vptr with. + + // Only virtual calls have virtual call site info. 
+ auto Iter = VirtualCSInfo.find(CB); + if (Iter == VirtualCSInfo.end()) + return nullptr; + + LLVM_DEBUG(dbgs() << "\nComputing vtable infos for callsite #" + << NumOfPGOICallsites << "\n"); + + const auto &VirtualCallInfo = Iter->second; + Instruction *VPtr = VirtualCallInfo.VPtr; + + SmallDenseMap<Function *, int, 4> CalleeIndexMap; + for (size_t I = 0; I < Candidates.size(); I++) + CalleeIndexMap[Candidates[I].TargetFunction] = I; + + uint64_t TotalVTableCount = 0; + auto VTableValueDataArray = + getValueProfDataFromInst(*VirtualCallInfo.VPtr, IPVK_VTableTarget, + MaxNumVTableAnnotations, TotalVTableCount); + if (VTableValueDataArray.empty()) + return VPtr; + + // Compute the functions and counts from by each vtable. + for (const auto &V : VTableValueDataArray) { + uint64_t VTableVal = V.Value; + GUIDCountsMap[VTableVal] = V.Count; + GlobalVariable *VTableVar = Symtab->getGlobalVariable(VTableVal); + if (!VTableVar) { + LLVM_DEBUG(dbgs() << " Cannot find vtable definition for " << VTableVal + << "; maybe the vtable isn't imported\n"); + continue; + } + + std::optional<uint64_t> MaybeAddressPointOffset = + getAddressPointOffset(*VTableVar, VirtualCallInfo.CompatibleTypeStr); + if (!MaybeAddressPointOffset) + continue; + + const uint64_t AddressPointOffset = *MaybeAddressPointOffset; + + Function *Callee = nullptr; + std::tie(Callee, std::ignore) = getFunctionAtVTableOffset( + VTableVar, AddressPointOffset + VirtualCallInfo.FunctionOffset, M); + if (!Callee) + continue; + auto CalleeIndexIter = CalleeIndexMap.find(Callee); + if (CalleeIndexIter == CalleeIndexMap.end()) + continue; + + auto &Candidate = Candidates[CalleeIndexIter->second]; + // There shouldn't be duplicate GUIDs in one !prof metadata (except + // duplicated zeros), so assign counters directly won't cause overwrite or + // counter loss. + Candidate.VTableGUIDAndCounts[VTableVal] = V.Count; + Candidate.AddressPoints.push_back( + getOrCreateVTableAddressPointVar(VTableVar, AddressPointOffset)); + } + + return VPtr; +} + +// Creates 'branch_weights' prof metadata using TrueWeight and FalseWeight. +// Scales uint64_t counters down to uint32_t if necessary to prevent overflow. +static MDNode *createBranchWeights(LLVMContext &Context, uint64_t TrueWeight, + uint64_t FalseWeight) { + MDBuilder MDB(Context); + uint64_t Scale = calculateCountScale(std::max(TrueWeight, FalseWeight)); + return MDB.createBranchWeights(scaleBranchCount(TrueWeight, Scale), + scaleBranchCount(FalseWeight, Scale)); +} + CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee, uint64_t Count, uint64_t TotalCount, bool AttachProfToDirectCall, OptimizationRemarkEmitter *ORE) { + CallBase &NewInst = promoteCallWithIfThenElse( + CB, DirectCallee, + createBranchWeights(CB.getContext(), Count, TotalCount - Count)); - uint64_t ElseCount = TotalCount - Count; - uint64_t MaxCount = (Count >= ElseCount ? 
Count : ElseCount); - uint64_t Scale = calculateCountScale(MaxCount); - MDBuilder MDB(CB.getContext()); - MDNode *BranchWeights = MDB.createBranchWeights( - scaleBranchCount(Count, Scale), scaleBranchCount(ElseCount, Scale)); - - CallBase &NewInst = - promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights); - - if (AttachProfToDirectCall) { - setBranchWeights(NewInst, {static_cast<uint32_t>(Count)}); - } + if (AttachProfToDirectCall) + setBranchWeights(NewInst, {static_cast<uint32_t>(Count)}, + /*IsExpected=*/false); using namespace ore; @@ -273,21 +628,176 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee, } // Promote indirect-call to conditional direct-call for one callsite. -uint32_t IndirectCallPromoter::tryToPromote( - CallBase &CB, const std::vector<PromotionCandidate> &Candidates, - uint64_t &TotalCount) { +bool IndirectCallPromoter::tryToPromoteWithFuncCmp( + CallBase &CB, Instruction *VPtr, ArrayRef<PromotionCandidate> Candidates, + uint64_t TotalCount, ArrayRef<InstrProfValueData> ICallProfDataRef, + uint32_t NumCandidates, VTableGUIDCountsMap &VTableGUIDCounts) { uint32_t NumPromoted = 0; for (const auto &C : Candidates) { - uint64_t Count = C.Count; - pgo::promoteIndirectCall(CB, C.TargetFunction, Count, TotalCount, SamplePGO, - &ORE); - assert(TotalCount >= Count); - TotalCount -= Count; + uint64_t FuncCount = C.Count; + pgo::promoteIndirectCall(CB, C.TargetFunction, FuncCount, TotalCount, + SamplePGO, &ORE); + assert(TotalCount >= FuncCount); + TotalCount -= FuncCount; NumOfPGOICallPromotion++; NumPromoted++; + + if (!EnableVTableProfileUse || C.VTableGUIDAndCounts.empty()) + continue; + + // After a virtual call candidate gets promoted, update the vtable's counts + // proportionally. Each vtable-guid in `C.VTableGUIDAndCounts` represents + // a vtable from which the virtual call is loaded. Compute the sum and use + // 128-bit APInt to improve accuracy. + uint64_t SumVTableCount = 0; + for (const auto &[GUID, VTableCount] : C.VTableGUIDAndCounts) + SumVTableCount += VTableCount; + + for (const auto &[GUID, VTableCount] : C.VTableGUIDAndCounts) { + APInt APFuncCount((unsigned)128, FuncCount, false /*signed*/); + APFuncCount *= VTableCount; + VTableGUIDCounts[GUID] -= APFuncCount.udiv(SumVTableCount).getZExtValue(); + } + } + if (NumPromoted == 0) + return false; + + assert(NumPromoted <= ICallProfDataRef.size() && + "Number of promoted functions should not be greater than the number " + "of values in profile metadata"); + + // Update value profiles on the indirect call. + updateFuncValueProfiles(CB, ICallProfDataRef.slice(NumPromoted), TotalCount, + NumCandidates); + updateVPtrValueProfiles(VPtr, VTableGUIDCounts); + return true; +} + +void IndirectCallPromoter::updateFuncValueProfiles( + CallBase &CB, ArrayRef<InstrProfValueData> CallVDs, uint64_t TotalCount, + uint32_t MaxMDCount) { + // First clear the existing !prof. + CB.setMetadata(LLVMContext::MD_prof, nullptr); + // Annotate the remaining value profiles if counter is not zero. 
+ if (TotalCount != 0) + annotateValueSite(M, CB, CallVDs, TotalCount, IPVK_IndirectCallTarget, + MaxMDCount); +} + +void IndirectCallPromoter::updateVPtrValueProfiles( + Instruction *VPtr, VTableGUIDCountsMap &VTableGUIDCounts) { + if (!EnableVTableProfileUse || VPtr == nullptr || + !VPtr->getMetadata(LLVMContext::MD_prof)) + return; + VPtr->setMetadata(LLVMContext::MD_prof, nullptr); + std::vector<InstrProfValueData> VTableValueProfiles; + uint64_t TotalVTableCount = 0; + for (auto [GUID, Count] : VTableGUIDCounts) { + if (Count == 0) + continue; + + VTableValueProfiles.push_back({GUID, Count}); + TotalVTableCount += Count; } - return NumPromoted; + llvm::sort(VTableValueProfiles, + [](const InstrProfValueData &LHS, const InstrProfValueData &RHS) { + return LHS.Count > RHS.Count; + }); + + annotateValueSite(M, *VPtr, VTableValueProfiles, TotalVTableCount, + IPVK_VTableTarget, VTableValueProfiles.size()); +} + +bool IndirectCallPromoter::tryToPromoteWithVTableCmp( + CallBase &CB, Instruction *VPtr, ArrayRef<PromotionCandidate> Candidates, + uint64_t TotalFuncCount, uint32_t NumCandidates, + MutableArrayRef<InstrProfValueData> ICallProfDataRef, + VTableGUIDCountsMap &VTableGUIDCounts) { + SmallVector<uint64_t, 4> PromotedFuncCount; + + for (const auto &Candidate : Candidates) { + for (auto &[GUID, Count] : Candidate.VTableGUIDAndCounts) + VTableGUIDCounts[GUID] -= Count; + + // 'OriginalBB' is the basic block of indirect call. After each candidate + // is promoted, a new basic block is created for the indirect fallback basic + // block and indirect call `CB` is moved into this new BB. + BasicBlock *OriginalBB = CB.getParent(); + promoteCallWithVTableCmp( + CB, VPtr, Candidate.TargetFunction, Candidate.AddressPoints, + createBranchWeights(CB.getContext(), Candidate.Count, + TotalFuncCount - Candidate.Count)); + + int SinkCount = tryToSinkInstructions(OriginalBB, CB.getParent()); + + ORE.emit([&]() { + OptimizationRemark Remark(DEBUG_TYPE, "Promoted", &CB); + + const auto &VTableGUIDAndCounts = Candidate.VTableGUIDAndCounts; + Remark << "Promote indirect call to " + << ore::NV("DirectCallee", Candidate.TargetFunction) + << " with count " << ore::NV("Count", Candidate.Count) + << " out of " << ore::NV("TotalCount", TotalFuncCount) << ", sink " + << ore::NV("SinkCount", SinkCount) + << " instruction(s) and compare " + << ore::NV("VTable", VTableGUIDAndCounts.size()) + << " vtable(s): {"; + + // Sort GUIDs so remark message is deterministic. + std::set<uint64_t> GUIDSet; + for (auto [GUID, Count] : VTableGUIDAndCounts) + GUIDSet.insert(GUID); + for (auto Iter = GUIDSet.begin(); Iter != GUIDSet.end(); Iter++) { + if (Iter != GUIDSet.begin()) + Remark << ", "; + Remark << ore::NV("VTable", Symtab->getGlobalVariable(*Iter)); + } + + Remark << "}"; + + return Remark; + }); + + PromotedFuncCount.push_back(Candidate.Count); + + assert(TotalFuncCount >= Candidate.Count && + "Within one prof metadata, total count is the sum of counts from " + "individual <target, count> pairs"); + // Use std::min since 'TotalFuncCount' is the saturated sum of individual + // counts, see + // https://github.com/llvm/llvm-project/blob/abedb3b8356d5d56f1c575c4f7682fba2cb19787/llvm/lib/ProfileData/InstrProf.cpp#L1281-L1288 + TotalFuncCount -= std::min(TotalFuncCount, Candidate.Count); + NumOfPGOICallPromotion++; + } + + if (PromotedFuncCount.empty()) + return false; + + // Update value profiles for 'CB' and 'VPtr', assuming that each 'CB' has a + // a distinct 'VPtr'. 
+ // FIXME: When Clang `-fstrict-vtable-pointers` is enabled, a vtable might be + // used to load multiple virtual functions. The vtable profiles needs to be + // updated properly in that case (e.g, for each indirect call annotate both + // type profiles and function profiles in one !prof). + for (size_t I = 0; I < PromotedFuncCount.size(); I++) + ICallProfDataRef[I].Count -= + std::max(PromotedFuncCount[I], ICallProfDataRef[I].Count); + // Sort value profiles by count in descending order. + llvm::stable_sort(ICallProfDataRef, [](const InstrProfValueData &LHS, + const InstrProfValueData &RHS) { + return LHS.Count > RHS.Count; + }); + // Drop the <target-value, count> pair if count is zero. + ArrayRef<InstrProfValueData> VDs( + ICallProfDataRef.begin(), + llvm::upper_bound(ICallProfDataRef, 0U, + [](uint64_t Count, const InstrProfValueData &ProfData) { + return ProfData.Count <= Count; + })); + updateFuncValueProfiles(CB, VDs, TotalFuncCount, NumCandidates); + updateVPtrValueProfiles(VPtr, VTableGUIDCounts); + return true; } // Traverse all the indirect-call callsite and get the value profile @@ -296,32 +806,158 @@ bool IndirectCallPromoter::processFunction(ProfileSummaryInfo *PSI) { bool Changed = false; ICallPromotionAnalysis ICallAnalysis; for (auto *CB : findIndirectCalls(F)) { - uint32_t NumVals, NumCandidates; + uint32_t NumCandidates; uint64_t TotalCount; auto ICallProfDataRef = ICallAnalysis.getPromotionCandidatesForInstruction( - CB, NumVals, TotalCount, NumCandidates); + CB, TotalCount, NumCandidates); if (!NumCandidates || (PSI && PSI->hasProfileSummary() && !PSI->isHotCount(TotalCount))) continue; + auto PromotionCandidates = getPromotionCandidatesForCallSite( *CB, ICallProfDataRef, TotalCount, NumCandidates); - uint32_t NumPromoted = tryToPromote(*CB, PromotionCandidates, TotalCount); - if (NumPromoted == 0) - continue; - Changed = true; - // Adjust the MD.prof metadata. First delete the old one. - CB->setMetadata(LLVMContext::MD_prof, nullptr); - // If all promoted, we don't need the MD.prof metadata. - if (TotalCount == 0 || NumPromoted == NumVals) - continue; - // Otherwise we need update with the un-promoted records back. - annotateValueSite(*F.getParent(), *CB, ICallProfDataRef.slice(NumPromoted), - TotalCount, IPVK_IndirectCallTarget, NumCandidates); + VTableGUIDCountsMap VTableGUIDCounts; + Instruction *VPtr = + computeVTableInfos(CB, VTableGUIDCounts, PromotionCandidates); + + if (isProfitableToCompareVTables(*CB, PromotionCandidates, TotalCount)) + Changed |= tryToPromoteWithVTableCmp(*CB, VPtr, PromotionCandidates, + TotalCount, NumCandidates, + ICallProfDataRef, VTableGUIDCounts); + else + Changed |= tryToPromoteWithFuncCmp(*CB, VPtr, PromotionCandidates, + TotalCount, ICallProfDataRef, + NumCandidates, VTableGUIDCounts); } return Changed; } +// TODO: Return false if the function addressing and vtable load instructions +// cannot sink to indirect fallback. 
+bool IndirectCallPromoter::isProfitableToCompareVTables( + const CallBase &CB, ArrayRef<PromotionCandidate> Candidates, + uint64_t TotalCount) { + if (!EnableVTableProfileUse || Candidates.empty()) + return false; + LLVM_DEBUG(dbgs() << "\nEvaluating vtable profitability for callsite #" + << NumOfPGOICallsites << CB << "\n"); + uint64_t RemainingVTableCount = TotalCount; + const size_t CandidateSize = Candidates.size(); + for (size_t I = 0; I < CandidateSize; I++) { + auto &Candidate = Candidates[I]; + auto &VTableGUIDAndCounts = Candidate.VTableGUIDAndCounts; + + LLVM_DEBUG(dbgs() << " Candidate " << I << " FunctionCount: " + << Candidate.Count << ", VTableCounts:"); + // Add [[maybe_unused]] since <GUID, Count> are only used by LLVM_DEBUG. + for ([[maybe_unused]] auto &[GUID, Count] : VTableGUIDAndCounts) + LLVM_DEBUG(dbgs() << " {" << Symtab->getGlobalVariable(GUID)->getName() + << ", " << Count << "}"); + LLVM_DEBUG(dbgs() << "\n"); + + uint64_t CandidateVTableCount = 0; + for (auto &[GUID, Count] : VTableGUIDAndCounts) + CandidateVTableCount += Count; + + if (CandidateVTableCount < Candidate.Count * ICPVTablePercentageThreshold) { + LLVM_DEBUG( + dbgs() << " function count " << Candidate.Count + << " and its vtable sum count " << CandidateVTableCount + << " have discrepancies. Bail out vtable comparison.\n"); + return false; + } + + RemainingVTableCount -= Candidate.Count; + + // 'MaxNumVTable' limits the number of vtables to make vtable comparison + // profitable. Comparing multiple vtables for one function candidate will + // insert additional instructions on the hot path, and allowing more than + // one vtable for non last candidates may or may not elongate the dependency + // chain for the subsequent candidates. Set its value to 1 for non-last + // candidate and allow option to override it for the last candidate. + int MaxNumVTable = 1; + if (I == CandidateSize - 1) + MaxNumVTable = ICPMaxNumVTableLastCandidate; + + if ((int)Candidate.AddressPoints.size() > MaxNumVTable) { + LLVM_DEBUG(dbgs() << " allow at most " << MaxNumVTable << " and got " + << Candidate.AddressPoints.size() + << " vtables. Bail out for vtable comparison.\n"); + return false; + } + } + + // If the indirect fallback is not cold, don't compare vtables. + if (PSI && PSI->hasProfileSummary() && + !PSI->isColdCount(RemainingVTableCount)) { + LLVM_DEBUG(dbgs() << " Indirect fallback basic block is not cold. Bail " + "out for vtable comparison.\n"); + return false; + } + + return true; +} + +// For virtual calls in the module, collect per-callsite information which will +// be used to associate an ICP candidate with a vtable and a specific function +// in the vtable. With type intrinsics (llvm.type.test), we can find virtual +// calls in a compile-time efficient manner (by iterating its users) and more +// importantly use the compatible type later to figure out the function byte +// offset relative to the start of vtables. +static void +computeVirtualCallSiteTypeInfoMap(Module &M, ModuleAnalysisManager &MAM, + VirtualCallSiteTypeInfoMap &VirtualCSInfo) { + // Right now only llvm.type.test is used to find out virtual call sites. + // With ThinLTO and whole-program-devirtualization, llvm.type.test and + // llvm.public.type.test are emitted, and llvm.public.type.test is either + // refined to llvm.type.test or dropped before indirect-call-promotion pass. + // + // FIXME: For fullLTO with VFE, `llvm.type.checked.load intrinsic` is emitted. 
+ // Find out virtual calls by looking at users of llvm.type.checked.load in + // that case. + Function *TypeTestFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_test)); + if (!TypeTestFunc || TypeTestFunc->use_empty()) + return; + + auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & { + return FAM.getResult<DominatorTreeAnalysis>(F); + }; + // Iterate all type.test calls to find all indirect calls. + for (Use &U : llvm::make_early_inc_range(TypeTestFunc->uses())) { + auto *CI = dyn_cast<CallInst>(U.getUser()); + if (!CI) + continue; + auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(1)); + if (!TypeMDVal) + continue; + auto *CompatibleTypeId = dyn_cast<MDString>(TypeMDVal->getMetadata()); + if (!CompatibleTypeId) + continue; + + // Find out all devirtualizable call sites given a llvm.type.test + // intrinsic call. + SmallVector<DevirtCallSite, 1> DevirtCalls; + SmallVector<CallInst *, 1> Assumes; + auto &DT = LookupDomTree(*CI->getFunction()); + findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT); + + for (auto &DevirtCall : DevirtCalls) { + CallBase &CB = DevirtCall.CB; + // Given an indirect call, try find the instruction which loads a + // pointer to virtual table. + Instruction *VTablePtr = + PGOIndirectCallVisitor::tryGetVTableInstruction(&CB); + if (!VTablePtr) + continue; + VirtualCSInfo[&CB] = {DevirtCall.Offset, VTablePtr, + CompatibleTypeId->getString()}; + } + } +} + // A wrapper function that does the actual work. static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, bool InLTO, bool SamplePGO, ModuleAnalysisManager &MAM) { @@ -334,6 +970,20 @@ static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, bool InLTO, return false; } bool Changed = false; + VirtualCallSiteTypeInfoMap VirtualCSInfo; + + if (EnableVTableProfileUse) + computeVirtualCallSiteTypeInfoMap(M, MAM, VirtualCSInfo); + + // VTableAddressPointOffsetVal stores the vtable address points. The vtable + // address point of a given <vtable, address point offset> is static (doesn't + // change after being computed once). + // IndirectCallPromoter::getOrCreateVTableAddressPointVar creates the map + // entry the first time a <vtable, offset> pair is seen, as + // promoteIndirectCalls processes an IR module and calls IndirectCallPromoter + // repeatedly on each function. 
+ VTableAddressPointOffsetValMap VTableAddressPointOffsetVal; + for (auto &F : M) { if (F.isDeclaration() || F.hasOptNone()) continue; @@ -342,7 +992,9 @@ static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, bool InLTO, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); - IndirectCallPromoter CallPromoter(F, &Symtab, SamplePGO, ORE); + IndirectCallPromoter CallPromoter(F, M, PSI, &Symtab, SamplePGO, + VirtualCSInfo, + VTableAddressPointOffsetVal, ORE); bool FuncChanged = CallPromoter.processFunction(PSI); if (ICPDUMPAFTER && FuncChanged) { LLVM_DEBUG(dbgs() << "\n== IR Dump After =="; F.print(dbgs())); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp index a19b14087254..d1396071c0ba 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -38,6 +38,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/InitializePasses.h" @@ -64,6 +65,9 @@ using namespace llvm; #define DEBUG_TYPE "instrprof" namespace llvm { +// Command line option to enable vtable value profiling. Defined in +// ProfileData/InstrProf.cpp: -enable-vtable-value-profiling= +extern cl::opt<bool> EnableVTableValueProfiling; // TODO: Remove -debug-info-correlate in next LLVM release, in favor of // -profile-correlate=debug-info. cl::opt<bool> DebugInfoCorrelate( @@ -166,15 +170,59 @@ cl::opt<bool> SkipRetExitBlock( "skip-ret-exit-block", cl::init(true), cl::desc("Suppress counter promotion if exit blocks contain ret.")); +static cl::opt<bool> SampledInstr("sampled-instrumentation", cl::ZeroOrMore, + cl::init(false), + cl::desc("Do PGO instrumentation sampling")); + +static cl::opt<unsigned> SampledInstrPeriod( + "sampled-instr-period", + cl::desc("Set the profile instrumentation sample period. For each sample " + "period, a fixed number of consecutive samples will be recorded. " + "The number is controlled by 'sampled-instr-burst-duration' flag. " + "The default sample period of 65535 is optimized for generating " + "efficient code that leverages unsigned integer wrapping in " + "overflow."), + cl::init(65535)); + +static cl::opt<unsigned> SampledInstrBurstDuration( + "sampled-instr-burst-duration", + cl::desc("Set the profile instrumentation burst duration, which can range " + "from 0 to one less than the value of 'sampled-instr-period'. " + "This number of samples will be recorded for each " + "'sampled-instr-period' count update. Setting to 1 enables " + "simple sampling, in which case it is recommended to set " + "'sampled-instr-period' to a prime number."), + cl::init(200)); + using LoadStorePair = std::pair<Instruction *, Instruction *>; +static uint64_t getIntModuleFlagOrZero(const Module &M, StringRef Flag) { + auto *MD = dyn_cast_or_null<ConstantAsMetadata>(M.getModuleFlag(Flag)); + if (!MD) + return 0; + + // If the flag is a ConstantAsMetadata, it should be an integer representable + // in 64-bits. 
+ return cast<ConstantInt>(MD->getValue())->getZExtValue(); +} + +static bool enablesValueProfiling(const Module &M) { + return isIRPGOFlagSet(&M) || + getIntModuleFlagOrZero(M, "EnableValueProfiling") != 0; +} + +// Conservatively returns true if value profiling is enabled. +static bool profDataReferencedByCode(const Module &M) { + return enablesValueProfiling(M); +} + class InstrLowerer final { public: InstrLowerer(Module &M, const InstrProfOptions &Options, std::function<const TargetLibraryInfo &(Function &F)> GetTLI, bool IsCS) : M(M), Options(Options), TT(Triple(M.getTargetTriple())), IsCS(IsCS), - GetTLI(GetTLI) {} + GetTLI(GetTLI), DataReferencedByCode(profDataReferencedByCode(M)) {} bool lower(); @@ -186,6 +234,9 @@ private: const bool IsCS; std::function<const TargetLibraryInfo &(Function &F)> GetTLI; + + const bool DataReferencedByCode; + struct PerFunctionProfileData { uint32_t NumValueSites[IPVK_Last + 1] = {}; GlobalVariable *RegionCounters = nullptr; @@ -196,15 +247,25 @@ private: PerFunctionProfileData() = default; }; DenseMap<GlobalVariable *, PerFunctionProfileData> ProfileDataMap; + // Key is virtual table variable, value is 'VTableProfData' in the form of + // GlobalVariable. + DenseMap<GlobalVariable *, GlobalVariable *> VTableDataMap; /// If runtime relocation is enabled, this maps functions to the load /// instruction that produces the profile relocation bias. DenseMap<const Function *, LoadInst *> FunctionToProfileBiasMap; std::vector<GlobalValue *> CompilerUsedVars; std::vector<GlobalValue *> UsedVars; std::vector<GlobalVariable *> ReferencedNames; + // The list of virtual table variables of which the VTableProfData is + // collected. + std::vector<GlobalVariable *> ReferencedVTables; GlobalVariable *NamesVar = nullptr; size_t NamesSize = 0; + /// The instance of [[alwaysinline]] rmw_or(ptr, i8). + /// This is name-insensitive. + Function *RMWOrFunc = nullptr; + // vector of counter load/store pairs to be register promoted. std::vector<LoadStorePair> PromotionCandidates; @@ -223,6 +284,9 @@ private: /// Returns true if profile counter update register promotion is enabled. bool isCounterPromotionEnabled() const; + /// Return true if profile sampling is enabled. + bool isSamplingEnabled() const; + /// Count the number of instrumented value sites for the function. void computeNumValueSiteCounts(InstrProfValueProfileInst *Ins); @@ -246,14 +310,17 @@ private: /// using the index represented by the a temp value into a bitmap. void lowerMCDCTestVectorBitmapUpdate(InstrProfMCDCTVBitmapUpdate *Ins); - /// Replace instrprof.mcdc.temp.update with a shift and or instruction using - /// the corresponding condition ID. - void lowerMCDCCondBitmapUpdate(InstrProfMCDCCondBitmapUpdate *Ins); + /// Get the Bias value for data to access mmap-ed area. + /// Create it if it hasn't been seen. + GlobalVariable *getOrCreateBiasVar(StringRef VarName); /// Compute the address of the counter value that this profiling instruction /// acts on. Value *getCounterAddress(InstrProfCntrInstBase *I); + /// Lower the incremental instructions under profile sampling predicates. + void doSampling(Instruction *I); + /// Get the region counters for an increment, creating them if necessary. /// /// If the counter array doesn't yet exist, the profile data variables @@ -265,6 +332,14 @@ private: StringRef Name, GlobalValue::LinkageTypes Linkage); + /// Create [[alwaysinline]] rmw_or(ptr, i8). + /// This doesn't update `RMWOrFunc`. + Function *createRMWOrFunc(); + + /// Get the call to `rmw_or`. 
+ /// Create the instance if it is unknown. + CallInst *getRMWOrCall(Value *Addr, Value *Val); + /// Compute the address of the test vector bitmap that this profiling /// instruction acts on. Value *getBitmapAddress(InstrProfMCDCTVBitmapUpdate *I); @@ -285,7 +360,7 @@ private: GlobalValue::LinkageTypes Linkage); /// Set Comdat property of GV, if required. - void maybeSetComdat(GlobalVariable *GV, Function *Fn, StringRef VarName); + void maybeSetComdat(GlobalVariable *GV, GlobalObject *GO, StringRef VarName); /// Setup the sections into which counters and bitmaps are allocated. GlobalVariable *setupProfileSection(InstrProfInstBase *Inc, @@ -294,9 +369,15 @@ private: /// Create INSTR_PROF_DATA variable for counters and bitmaps. void createDataVariable(InstrProfCntrInstBase *Inc); + /// Get the counters for virtual table values, creating them if necessary. + void getOrCreateVTableProfData(GlobalVariable *GV); + /// Emit the section with compressed function names. void emitNameData(); + /// Emit the section with compressed vtable names. + void emitVTableNames(); + /// Emit value nodes section for value profiling. void emitVNodes(); @@ -584,36 +665,169 @@ PreservedAnalyses InstrProfilingLoweringPass::run(Module &M, return PreservedAnalyses::none(); } +// +// Perform instrumentation sampling. +// +// There are 3 favors of sampling: +// (1) Full burst sampling: We transform: +// Increment_Instruction; +// to: +// if (__llvm_profile_sampling__ < SampledInstrBurstDuration) { +// Increment_Instruction; +// } +// __llvm_profile_sampling__ += 1; +// if (__llvm_profile_sampling__ >= SampledInstrPeriod) { +// __llvm_profile_sampling__ = 0; +// } +// +// "__llvm_profile_sampling__" is a thread-local global shared by all PGO +// counters (value-instrumentation and edge instrumentation). +// +// (2) Fast burst sampling: +// "__llvm_profile_sampling__" variable is an unsigned type, meaning it will +// wrap around to zero when overflows. In this case, the second check is +// unnecessary, so we won't generate check2 when the SampledInstrPeriod is +// set to 65535 (64K - 1). The code after: +// if (__llvm_profile_sampling__ < SampledInstrBurstDuration) { +// Increment_Instruction; +// } +// __llvm_profile_sampling__ += 1; +// +// (3) Simple sampling: +// When SampledInstrBurstDuration sets to 1, we do a simple sampling: +// __llvm_profile_sampling__ += 1; +// if (__llvm_profile_sampling__ >= SampledInstrPeriod) { +// __llvm_profile_sampling__ = 0; +// Increment_Instruction; +// } +// +// Note that, the code snippet after the transformation can still be counter +// promoted. However, with sampling enabled, counter updates are expected to +// be infrequent, making the benefits of counter promotion negligible. +// Moreover, counter promotion can potentially cause issues in server +// applications, particularly when the counters are dumped without a clean +// exit. To mitigate this risk, counter promotion is disabled by default when +// sampling is enabled. This behavior can be overridden using the internal +// option. 
+void InstrLowerer::doSampling(Instruction *I) { + if (!isSamplingEnabled()) + return; + + unsigned SampledBurstDuration = SampledInstrBurstDuration.getValue(); + unsigned SampledPeriod = SampledInstrPeriod.getValue(); + if (SampledBurstDuration >= SampledPeriod) { + report_fatal_error( + "SampledPeriod needs to be greater than SampledBurstDuration"); + } + bool UseShort = (SampledPeriod <= USHRT_MAX); + bool IsSimpleSampling = (SampledBurstDuration == 1); + // If (SampledBurstDuration == 1 && SampledPeriod == 65535), generate + // the simple sampling style code. + bool IsFastSampling = (!IsSimpleSampling && SampledPeriod == 65535); + + auto GetConstant = [UseShort](IRBuilder<> &Builder, uint32_t C) { + if (UseShort) + return Builder.getInt16(C); + else + return Builder.getInt32(C); + }; + + IntegerType *SamplingVarTy; + if (UseShort) + SamplingVarTy = Type::getInt16Ty(M.getContext()); + else + SamplingVarTy = Type::getInt32Ty(M.getContext()); + auto *SamplingVar = + M.getGlobalVariable(INSTR_PROF_QUOTE(INSTR_PROF_PROFILE_SAMPLING_VAR)); + assert(SamplingVar && "SamplingVar not set properly"); + + // Create the condition for checking the burst duration. + Instruction *SamplingVarIncr; + Value *NewSamplingVarVal; + MDBuilder MDB(I->getContext()); + MDNode *BranchWeight; + IRBuilder<> CondBuilder(I); + auto *LoadSamplingVar = CondBuilder.CreateLoad(SamplingVarTy, SamplingVar); + if (IsSimpleSampling) { + // For the simple sampling, just create the load and increments. + IRBuilder<> IncBuilder(I); + NewSamplingVarVal = + IncBuilder.CreateAdd(LoadSamplingVar, GetConstant(IncBuilder, 1)); + SamplingVarIncr = IncBuilder.CreateStore(NewSamplingVarVal, SamplingVar); + } else { + // For the bust-sampling, create the conditonal update. + auto *DurationCond = CondBuilder.CreateICmpULE( + LoadSamplingVar, GetConstant(CondBuilder, SampledBurstDuration)); + BranchWeight = MDB.createBranchWeights( + SampledBurstDuration, SampledPeriod + 1 - SampledBurstDuration); + Instruction *ThenTerm = SplitBlockAndInsertIfThen( + DurationCond, I, /* Unreachable */ false, BranchWeight); + IRBuilder<> IncBuilder(I); + NewSamplingVarVal = + IncBuilder.CreateAdd(LoadSamplingVar, GetConstant(IncBuilder, 1)); + SamplingVarIncr = IncBuilder.CreateStore(NewSamplingVarVal, SamplingVar); + I->moveBefore(ThenTerm); + } + + if (IsFastSampling) + return; + + // Create the condtion for checking the period. + Instruction *ThenTerm, *ElseTerm; + IRBuilder<> PeriodCondBuilder(SamplingVarIncr); + auto *PeriodCond = PeriodCondBuilder.CreateICmpUGE( + NewSamplingVarVal, GetConstant(PeriodCondBuilder, SampledPeriod)); + BranchWeight = MDB.createBranchWeights(1, SampledPeriod); + SplitBlockAndInsertIfThenElse(PeriodCond, SamplingVarIncr, &ThenTerm, + &ElseTerm, BranchWeight); + + // For the simple sampling, the counter update happens in sampling var reset. + if (IsSimpleSampling) + I->moveBefore(ThenTerm); + + IRBuilder<> ResetBuilder(ThenTerm); + ResetBuilder.CreateStore(GetConstant(ResetBuilder, 0), SamplingVar); + SamplingVarIncr->moveBefore(ElseTerm); +} + bool InstrLowerer::lowerIntrinsics(Function *F) { bool MadeChange = false; PromotionCandidates.clear(); + SmallVector<InstrProfInstBase *, 8> InstrProfInsts; + + // To ensure compatibility with sampling, we save the intrinsics into + // a buffer to prevent potential breakage of the iterator (as the + // intrinsics will be moved to a different BB). 
for (BasicBlock &BB : *F) { for (Instruction &Instr : llvm::make_early_inc_range(BB)) { - if (auto *IPIS = dyn_cast<InstrProfIncrementInstStep>(&Instr)) { - lowerIncrement(IPIS); - MadeChange = true; - } else if (auto *IPI = dyn_cast<InstrProfIncrementInst>(&Instr)) { - lowerIncrement(IPI); - MadeChange = true; - } else if (auto *IPC = dyn_cast<InstrProfTimestampInst>(&Instr)) { - lowerTimestamp(IPC); - MadeChange = true; - } else if (auto *IPC = dyn_cast<InstrProfCoverInst>(&Instr)) { - lowerCover(IPC); - MadeChange = true; - } else if (auto *IPVP = dyn_cast<InstrProfValueProfileInst>(&Instr)) { - lowerValueProfileInst(IPVP); - MadeChange = true; - } else if (auto *IPMP = dyn_cast<InstrProfMCDCBitmapParameters>(&Instr)) { - IPMP->eraseFromParent(); - MadeChange = true; - } else if (auto *IPBU = dyn_cast<InstrProfMCDCTVBitmapUpdate>(&Instr)) { - lowerMCDCTestVectorBitmapUpdate(IPBU); - MadeChange = true; - } else if (auto *IPTU = dyn_cast<InstrProfMCDCCondBitmapUpdate>(&Instr)) { - lowerMCDCCondBitmapUpdate(IPTU); - MadeChange = true; - } + if (auto *IP = dyn_cast<InstrProfInstBase>(&Instr)) + InstrProfInsts.push_back(IP); + } + } + + for (auto *Instr : InstrProfInsts) { + doSampling(Instr); + if (auto *IPIS = dyn_cast<InstrProfIncrementInstStep>(Instr)) { + lowerIncrement(IPIS); + MadeChange = true; + } else if (auto *IPI = dyn_cast<InstrProfIncrementInst>(Instr)) { + lowerIncrement(IPI); + MadeChange = true; + } else if (auto *IPC = dyn_cast<InstrProfTimestampInst>(Instr)) { + lowerTimestamp(IPC); + MadeChange = true; + } else if (auto *IPC = dyn_cast<InstrProfCoverInst>(Instr)) { + lowerCover(IPC); + MadeChange = true; + } else if (auto *IPVP = dyn_cast<InstrProfValueProfileInst>(Instr)) { + lowerValueProfileInst(IPVP); + MadeChange = true; + } else if (auto *IPMP = dyn_cast<InstrProfMCDCBitmapParameters>(Instr)) { + IPMP->eraseFromParent(); + MadeChange = true; + } else if (auto *IPBU = dyn_cast<InstrProfMCDCTVBitmapUpdate>(Instr)) { + lowerMCDCTestVectorBitmapUpdate(IPBU); + MadeChange = true; } } @@ -636,6 +850,12 @@ bool InstrLowerer::isRuntimeCounterRelocationEnabled() const { return TT.isOSFuchsia(); } +bool InstrLowerer::isSamplingEnabled() const { + if (SampledInstr.getNumOccurrences() > 0) + return SampledInstr; + return Options.Sampling; +} + bool InstrLowerer::isCounterPromotionEnabled() const { if (DoCounterPromotion.getNumOccurrences() > 0) return DoCounterPromotion; @@ -706,6 +926,9 @@ bool InstrLowerer::lower() { if (NeedsRuntimeHook) MadeChange = emitRuntimeHook(); + if (!IsCS && isSamplingEnabled()) + createProfileSamplingVar(M); + bool ContainsProfiling = containsProfilingIntrinsics(M); GlobalVariable *CoverageNamesVar = M.getNamedGlobal(getCoverageUnusedNamesVarName()); @@ -740,6 +963,12 @@ bool InstrLowerer::lower() { } } + if (EnableVTableValueProfiling) + for (GlobalVariable &GV : M.globals()) + // Global variables with type metadata are virtual table variables. 
+ if (GV.hasMetadata(LLVMContext::MD_type)) + getOrCreateVTableProfData(&GV); + for (Function &F : M) MadeChange |= lowerIntrinsics(&F); @@ -753,6 +982,7 @@ bool InstrLowerer::lower() { emitVNodes(); emitNameData(); + emitVTableNames(); // Emit runtime hook for the cases where the target does not unconditionally // require pulling in profile runtime, and coverage is enabled on code that is @@ -847,6 +1077,29 @@ void InstrLowerer::lowerValueProfileInst(InstrProfValueProfileInst *Ind) { Ind->eraseFromParent(); } +GlobalVariable *InstrLowerer::getOrCreateBiasVar(StringRef VarName) { + GlobalVariable *Bias = M.getGlobalVariable(VarName); + if (Bias) + return Bias; + + Type *Int64Ty = Type::getInt64Ty(M.getContext()); + + // Compiler must define this variable when runtime counter relocation + // is being used. Runtime has a weak external reference that is used + // to check whether that's the case or not. + Bias = new GlobalVariable(M, Int64Ty, false, GlobalValue::LinkOnceODRLinkage, + Constant::getNullValue(Int64Ty), VarName); + Bias->setVisibility(GlobalVariable::HiddenVisibility); + // A definition that's weak (linkonce_odr) without being in a COMDAT + // section wouldn't lead to link errors, but it would lead to a dead + // data word from every TU but one. Putting it in COMDAT ensures there + // will be exactly one data slot in the link. + if (TT.supportsCOMDAT()) + Bias->setComdat(M.getOrInsertComdat(VarName)); + + return Bias; +} + Value *InstrLowerer::getCounterAddress(InstrProfCntrInstBase *I) { auto *Counters = getOrCreateRegionCounters(I); IRBuilder<> Builder(I); @@ -865,35 +1118,82 @@ Value *InstrLowerer::getCounterAddress(InstrProfCntrInstBase *I) { LoadInst *&BiasLI = FunctionToProfileBiasMap[Fn]; if (!BiasLI) { IRBuilder<> EntryBuilder(&Fn->getEntryBlock().front()); - auto *Bias = M.getGlobalVariable(getInstrProfCounterBiasVarName()); - if (!Bias) { - // Compiler must define this variable when runtime counter relocation - // is being used. Runtime has a weak external reference that is used - // to check whether that's the case or not. - Bias = new GlobalVariable( - M, Int64Ty, false, GlobalValue::LinkOnceODRLinkage, - Constant::getNullValue(Int64Ty), getInstrProfCounterBiasVarName()); - Bias->setVisibility(GlobalVariable::HiddenVisibility); - // A definition that's weak (linkonce_odr) without being in a COMDAT - // section wouldn't lead to link errors, but it would lead to a dead - // data word from every TU but one. Putting it in COMDAT ensures there - // will be exactly one data slot in the link. - if (TT.supportsCOMDAT()) - Bias->setComdat(M.getOrInsertComdat(Bias->getName())); - } - BiasLI = EntryBuilder.CreateLoad(Int64Ty, Bias); + auto *Bias = getOrCreateBiasVar(getInstrProfCounterBiasVarName()); + BiasLI = EntryBuilder.CreateLoad(Int64Ty, Bias, "profc_bias"); + // Bias doesn't change after startup. 
+ BiasLI->setMetadata(LLVMContext::MD_invariant_load, + MDNode::get(M.getContext(), std::nullopt)); } auto *Add = Builder.CreateAdd(Builder.CreatePtrToInt(Addr, Int64Ty), BiasLI); return Builder.CreateIntToPtr(Add, Addr->getType()); } +/// Create `void [[alwaysinline]] rmw_or(uint8_t *ArgAddr, uint8_t ArgVal)` +/// "Basic" sequence is `*ArgAddr |= ArgVal` +Function *InstrLowerer::createRMWOrFunc() { + auto &Ctx = M.getContext(); + auto *Int8Ty = Type::getInt8Ty(Ctx); + Function *Fn = Function::Create( + FunctionType::get(Type::getVoidTy(Ctx), + {PointerType::getUnqual(Ctx), Int8Ty}, false), + Function::LinkageTypes::PrivateLinkage, "rmw_or", M); + Fn->addFnAttr(Attribute::AlwaysInline); + auto *ArgAddr = Fn->getArg(0); + auto *ArgVal = Fn->getArg(1); + IRBuilder<> Builder(BasicBlock::Create(Ctx, "", Fn)); + + // Load profile bitmap byte. + // %mcdc.bits = load i8, ptr %4, align 1 + auto *Bitmap = Builder.CreateLoad(Int8Ty, ArgAddr, "mcdc.bits"); + + if (Options.Atomic || AtomicCounterUpdateAll) { + // If ((Bitmap & Val) != Val), then execute atomic (Bitmap |= Val). + // Note, just-loaded Bitmap might not be up-to-date. Use it just for + // early testing. + auto *Masked = Builder.CreateAnd(Bitmap, ArgVal); + auto *ShouldStore = Builder.CreateICmpNE(Masked, ArgVal); + auto *ThenTerm = BasicBlock::Create(Ctx, "", Fn); + auto *ElseTerm = BasicBlock::Create(Ctx, "", Fn); + // Assume updating will be rare. + auto *Unlikely = MDBuilder(Ctx).createUnlikelyBranchWeights(); + Builder.CreateCondBr(ShouldStore, ThenTerm, ElseTerm, Unlikely); + + IRBuilder<> ThenBuilder(ThenTerm); + ThenBuilder.CreateAtomicRMW(AtomicRMWInst::Or, ArgAddr, ArgVal, + MaybeAlign(), AtomicOrdering::Monotonic); + ThenBuilder.CreateRetVoid(); + + IRBuilder<> ElseBuilder(ElseTerm); + ElseBuilder.CreateRetVoid(); + + return Fn; + } + + // Perform logical OR of profile bitmap byte and shifted bit offset. + // %8 = or i8 %mcdc.bits, %7 + auto *Result = Builder.CreateOr(Bitmap, ArgVal); + + // Store the updated profile bitmap byte. + // store i8 %8, ptr %3, align 1 + Builder.CreateStore(Result, ArgAddr); + + // Terminator + Builder.CreateRetVoid(); + + return Fn; +} + +CallInst *InstrLowerer::getRMWOrCall(Value *Addr, Value *Val) { + if (!RMWOrFunc) + RMWOrFunc = createRMWOrFunc(); + + return CallInst::Create(RMWOrFunc, {Addr, Val}); +} + Value *InstrLowerer::getBitmapAddress(InstrProfMCDCTVBitmapUpdate *I) { auto *Bitmaps = getOrCreateRegionBitmaps(I); IRBuilder<> Builder(I); - auto *Addr = Builder.CreateConstInBoundsGEP2_32( - Bitmaps->getValueType(), Bitmaps, 0, I->getBitmapIndex()->getZExtValue()); - if (isRuntimeCounterRelocationEnabled()) { LLVMContext &Ctx = M.getContext(); Ctx.diagnose(DiagnosticInfoPGOProfile( @@ -903,7 +1203,7 @@ Value *InstrLowerer::getBitmapAddress(InstrProfMCDCTVBitmapUpdate *I) { DS_Warning)); } - return Addr; + return Bitmaps; } void InstrLowerer::lowerCover(InstrProfCoverInst *CoverInstruction) { @@ -969,30 +1269,24 @@ void InstrLowerer::lowerMCDCTestVectorBitmapUpdate( InstrProfMCDCTVBitmapUpdate *Update) { IRBuilder<> Builder(Update); auto *Int8Ty = Type::getInt8Ty(M.getContext()); - auto *Int8PtrTy = PointerType::getUnqual(M.getContext()); auto *Int32Ty = Type::getInt32Ty(M.getContext()); - auto *Int64Ty = Type::getInt64Ty(M.getContext()); auto *MCDCCondBitmapAddr = Update->getMCDCCondBitmapAddr(); auto *BitmapAddr = getBitmapAddress(Update); - // Load Temp Val. + // Load Temp Val + BitmapIdx. 
// %mcdc.temp = load i32, ptr %mcdc.addr, align 4 - auto *Temp = Builder.CreateLoad(Int32Ty, MCDCCondBitmapAddr, "mcdc.temp"); + auto *Temp = Builder.CreateAdd( + Builder.CreateLoad(Int32Ty, MCDCCondBitmapAddr, "mcdc.temp"), + Update->getBitmapIndex()); // Calculate byte offset using div8. // %1 = lshr i32 %mcdc.temp, 3 auto *BitmapByteOffset = Builder.CreateLShr(Temp, 0x3); // Add byte offset to section base byte address. - // %2 = zext i32 %1 to i64 - // %3 = add i64 ptrtoint (ptr @__profbm_test to i64), %2 + // %4 = getelementptr inbounds i8, ptr @__profbm_test, i32 %1 auto *BitmapByteAddr = - Builder.CreateAdd(Builder.CreatePtrToInt(BitmapAddr, Int64Ty), - Builder.CreateZExtOrBitCast(BitmapByteOffset, Int64Ty)); - - // Convert to a pointer. - // %4 = inttoptr i32 %3 to ptr - BitmapByteAddr = Builder.CreateIntToPtr(BitmapByteAddr, Int8PtrTy); + Builder.CreateInBoundsPtrAdd(BitmapAddr, BitmapByteOffset); // Calculate bit offset into bitmap byte by using div8 remainder (AND ~8) // %5 = and i32 %mcdc.temp, 7 @@ -1003,45 +1297,7 @@ void InstrLowerer::lowerMCDCTestVectorBitmapUpdate( // %7 = shl i8 1, %6 auto *ShiftedVal = Builder.CreateShl(Builder.getInt8(0x1), BitToSet); - // Load profile bitmap byte. - // %mcdc.bits = load i8, ptr %4, align 1 - auto *Bitmap = Builder.CreateLoad(Int8Ty, BitmapByteAddr, "mcdc.bits"); - - // Perform logical OR of profile bitmap byte and shifted bit offset. - // %8 = or i8 %mcdc.bits, %7 - auto *Result = Builder.CreateOr(Bitmap, ShiftedVal); - - // Store the updated profile bitmap byte. - // store i8 %8, ptr %3, align 1 - Builder.CreateStore(Result, BitmapByteAddr); - Update->eraseFromParent(); -} - -void InstrLowerer::lowerMCDCCondBitmapUpdate( - InstrProfMCDCCondBitmapUpdate *Update) { - IRBuilder<> Builder(Update); - auto *Int32Ty = Type::getInt32Ty(M.getContext()); - auto *MCDCCondBitmapAddr = Update->getMCDCCondBitmapAddr(); - - // Load the MCDC temporary value from the stack. - // %mcdc.temp = load i32, ptr %mcdc.addr, align 4 - auto *Temp = Builder.CreateLoad(Int32Ty, MCDCCondBitmapAddr, "mcdc.temp"); - - // Zero-extend the evaluated condition boolean value (0 or 1) by 32bits. - // %1 = zext i1 %tobool to i32 - auto *CondV_32 = Builder.CreateZExt(Update->getCondBool(), Int32Ty); - - // Shift the boolean value left (by the condition's ID) to form a bitmap. - // %2 = shl i32 %1, <Update->getCondID()> - auto *ShiftedVal = Builder.CreateShl(CondV_32, Update->getCondID()); - - // Perform logical OR of the bitmap against the loaded MCDC temporary value. - // %3 = or i32 %mcdc.temp, %2 - auto *Result = Builder.CreateOr(Temp, ShiftedVal); - - // Store the updated temporary value back to the stack. - // store i32 %3, ptr %mcdc.addr, align 4 - Builder.CreateStore(Result, MCDCCondBitmapAddr); + Builder.Insert(getRMWOrCall(BitmapByteAddr, ShiftedVal)); Update->eraseFromParent(); } @@ -1065,26 +1321,6 @@ static std::string getVarName(InstrProfInstBase *Inc, StringRef Prefix, return (Prefix + Name + "." + Twine(FuncHash)).str(); } -static uint64_t getIntModuleFlagOrZero(const Module &M, StringRef Flag) { - auto *MD = dyn_cast_or_null<ConstantAsMetadata>(M.getModuleFlag(Flag)); - if (!MD) - return 0; - - // If the flag is a ConstantAsMetadata, it should be an integer representable - // in 64-bits. 
- return cast<ConstantInt>(MD->getValue())->getZExtValue(); -} - -static bool enablesValueProfiling(const Module &M) { - return isIRPGOFlagSet(&M) || - getIntModuleFlagOrZero(M, "EnableValueProfiling") != 0; -} - -// Conservatively returns true if data variables may be referenced by code. -static bool profDataReferencedByCode(const Module &M) { - return enablesValueProfiling(M); -} - static inline bool shouldRecordFunctionAddr(Function *F) { // Only record function addresses if IR PGO is enabled or if clang value // profiling is enabled. Recording function addresses greatly increases object @@ -1198,20 +1434,42 @@ static bool needsRuntimeRegistrationOfSectionRange(const Triple &TT) { return true; } -void InstrLowerer::maybeSetComdat(GlobalVariable *GV, Function *Fn, - StringRef VarName) { - bool DataReferencedByCode = profDataReferencedByCode(M); - bool NeedComdat = needsComdatForCounter(*Fn, M); +void InstrLowerer::maybeSetComdat(GlobalVariable *GV, GlobalObject *GO, + StringRef CounterGroupName) { + // Place lowered global variables in a comdat group if the associated function + // or global variable is a COMDAT. This will make sure that only one copy of + // global variable (e.g. function counters) of the COMDAT function will be + // emitted after linking. + bool NeedComdat = needsComdatForCounter(*GO, M); bool UseComdat = (NeedComdat || TT.isOSBinFormatELF()); if (!UseComdat) return; - StringRef GroupName = - TT.isOSBinFormatCOFF() && DataReferencedByCode ? GV->getName() : VarName; + // Keep in mind that this pass may run before the inliner, so we need to + // create a new comdat group (for counters, profiling data, etc). If we use + // the comdat of the parent function, that will result in relocations against + // discarded sections. + // + // If the data variable is referenced by code, non-counter variables (notably + // profiling data) and counters have to be in different comdats for COFF + // because the Visual C++ linker will report duplicate symbol errors if there + // are multiple external symbols with the same name marked + // IMAGE_COMDAT_SELECT_ASSOCIATIVE. + StringRef GroupName = TT.isOSBinFormatCOFF() && DataReferencedByCode + ? GV->getName() + : CounterGroupName; Comdat *C = M.getOrInsertComdat(GroupName); - if (!NeedComdat) + + if (!NeedComdat) { + // Object file format must be ELF since `UseComdat && !NeedComdat` is true. + // + // For ELF, when not using COMDAT, put counters, data and values into a + // nodeduplicate COMDAT which is lowered to a zero-flag section group. This + // allows -z start-stop-gc to discard the entire group when the function is + // discarded. C->setSelectionKind(Comdat::NoDeduplicate); + } GV->setComdat(C); // COFF doesn't allow the comdat group leader to have private linkage, so // upgrade private linkage to internal linkage to produce a symbol table @@ -1220,6 +1478,104 @@ void InstrLowerer::maybeSetComdat(GlobalVariable *GV, Function *Fn, GV->setLinkage(GlobalValue::InternalLinkage); } +static inline bool shouldRecordVTableAddr(GlobalVariable *GV) { + if (!profDataReferencedByCode(*GV->getParent())) + return false; + + if (!GV->hasLinkOnceLinkage() && !GV->hasLocalLinkage() && + !GV->hasAvailableExternallyLinkage()) + return true; + + // This avoids the profile data from referencing internal symbols in + // COMDAT. + if (GV->hasLocalLinkage() && GV->hasComdat()) + return false; + + return true; +} + +// FIXME: Introduce an internal alias like what's done for functions to reduce +// the number of relocation entries. 
+static inline Constant *getVTableAddrForProfData(GlobalVariable *GV) { + auto *Int8PtrTy = PointerType::getUnqual(GV->getContext()); + + // Store a nullptr in __profvt_ if a real address shouldn't be used. + if (!shouldRecordVTableAddr(GV)) + return ConstantPointerNull::get(Int8PtrTy); + + return ConstantExpr::getBitCast(GV, Int8PtrTy); +} + +void InstrLowerer::getOrCreateVTableProfData(GlobalVariable *GV) { + assert(!DebugInfoCorrelate && + "Value profiling is not supported with lightweight instrumentation"); + if (GV->isDeclaration() || GV->hasAvailableExternallyLinkage()) + return; + + // Skip llvm internal global variable or __prof variables. + if (GV->getName().starts_with("llvm.") || + GV->getName().starts_with("__llvm") || + GV->getName().starts_with("__prof")) + return; + + // VTableProfData already created + auto It = VTableDataMap.find(GV); + if (It != VTableDataMap.end() && It->second) + return; + + GlobalValue::LinkageTypes Linkage = GV->getLinkage(); + GlobalValue::VisibilityTypes Visibility = GV->getVisibility(); + + // This is to keep consistent with per-function profile data + // for correctness. + if (TT.isOSBinFormatXCOFF()) { + Linkage = GlobalValue::InternalLinkage; + Visibility = GlobalValue::DefaultVisibility; + } + + LLVMContext &Ctx = M.getContext(); + Type *DataTypes[] = { +#define INSTR_PROF_VTABLE_DATA(Type, LLVMType, Name, Init) LLVMType, +#include "llvm/ProfileData/InstrProfData.inc" +#undef INSTR_PROF_VTABLE_DATA + }; + + auto *DataTy = StructType::get(Ctx, ArrayRef(DataTypes)); + + // Used by INSTR_PROF_VTABLE_DATA MACRO + Constant *VTableAddr = getVTableAddrForProfData(GV); + const std::string PGOVTableName = getPGOName(*GV); + // Record the length of the vtable. This is needed since vtable pointers + // loaded from C++ objects might be from the middle of a vtable definition. + uint32_t VTableSizeVal = + M.getDataLayout().getTypeAllocSize(GV->getValueType()); + + Constant *DataVals[] = { +#define INSTR_PROF_VTABLE_DATA(Type, LLVMType, Name, Init) Init, +#include "llvm/ProfileData/InstrProfData.inc" +#undef INSTR_PROF_VTABLE_DATA + }; + + auto *Data = + new GlobalVariable(M, DataTy, /*constant=*/false, Linkage, + ConstantStruct::get(DataTy, DataVals), + getInstrProfVTableVarPrefix() + PGOVTableName); + + Data->setVisibility(Visibility); + Data->setSection(getInstrProfSectionName(IPSK_vtab, TT.getObjectFormat())); + Data->setAlignment(Align(8)); + + maybeSetComdat(Data, GV, Data->getName()); + + VTableDataMap[GV] = Data; + + ReferencedVTables.push_back(GV); + + // VTable <Hash, Addr> is used by runtime but not referenced by other + // sections. Conservatively mark it linker retained. + UsedVars.push_back(Data); +} + GlobalVariable *InstrLowerer::setupProfileSection(InstrProfInstBase *Inc, InstrProfSectKind IPSK) { GlobalVariable *NamePtr = Inc->getName(); @@ -1245,23 +1601,7 @@ GlobalVariable *InstrLowerer::setupProfileSection(InstrProfInstBase *Inc, Linkage = GlobalValue::PrivateLinkage; Visibility = GlobalValue::DefaultVisibility; } - // Move the name variable to the right section. Place them in a COMDAT group - // if the associated function is a COMDAT. This will make sure that only one - // copy of counters of the COMDAT function will be emitted after linking. Keep - // in mind that this pass may run before the inliner, so we need to create a - // new comdat group for the counters and profiling data. If we use the comdat - // of the parent function, that will result in relocations against discarded - // sections. 
- // - // If the data variable is referenced by code, counters and data have to be - // in different comdats for COFF because the Visual C++ linker will report - // duplicate symbol errors if there are multiple external symbols with the - // same name marked IMAGE_COMDAT_SELECT_ASSOCIATIVE. - // - // For ELF, when not using COMDAT, put counters, data and values into a - // nodeduplicate COMDAT which is lowered to a zero-flag section group. This - // allows -z start-stop-gc to discard the entire group when the function is - // discarded. + // Move the name variable to the right section. bool Renamed; GlobalVariable *Ptr; StringRef VarPrefix; @@ -1294,7 +1634,7 @@ GlobalVariable * InstrLowerer::createRegionBitmaps(InstrProfMCDCBitmapInstBase *Inc, StringRef Name, GlobalValue::LinkageTypes Linkage) { - uint64_t NumBytes = Inc->getNumBitmapBytes()->getZExtValue(); + uint64_t NumBytes = Inc->getNumBitmapBytes(); auto *BitmapTy = ArrayType::get(Type::getInt8Ty(M.getContext()), NumBytes); auto GV = new GlobalVariable(M, BitmapTy, false, Linkage, Constant::getNullValue(BitmapTy), Name); @@ -1313,7 +1653,7 @@ InstrLowerer::getOrCreateRegionBitmaps(InstrProfMCDCBitmapInstBase *Inc) { // the corresponding profile section. auto *BitmapPtr = setupProfileSection(Inc, IPSK_bitmap); PD.RegionBitmaps = BitmapPtr; - PD.NumBitmapBytes = Inc->getNumBitmapBytes()->getZExtValue(); + PD.NumBitmapBytes = Inc->getNumBitmapBytes(); return PD.RegionBitmaps; } @@ -1426,7 +1766,6 @@ void InstrLowerer::createDataVariable(InstrProfCntrInstBase *Inc) { Visibility = GlobalValue::DefaultVisibility; } - bool DataReferencedByCode = profDataReferencedByCode(M); bool NeedComdat = needsComdatForCounter(*Fn, M); bool Renamed; @@ -1633,6 +1972,31 @@ void InstrLowerer::emitNameData() { NamePtr->eraseFromParent(); } +void InstrLowerer::emitVTableNames() { + if (!EnableVTableValueProfiling || ReferencedVTables.empty()) + return; + + // Collect the PGO names of referenced vtables and compress them. + std::string CompressedVTableNames; + if (Error E = collectVTableStrings(ReferencedVTables, CompressedVTableNames, + DoInstrProfNameCompression)) { + report_fatal_error(Twine(toString(std::move(E))), false); + } + + auto &Ctx = M.getContext(); + auto *VTableNamesVal = ConstantDataArray::getString( + Ctx, StringRef(CompressedVTableNames), false /* AddNull */); + GlobalVariable *VTableNamesVar = + new GlobalVariable(M, VTableNamesVal->getType(), true /* constant */, + GlobalValue::PrivateLinkage, VTableNamesVal, + getInstrProfVTableNamesVarName()); + VTableNamesVar->setSection( + getInstrProfSectionName(IPSK_vname, TT.getObjectFormat())); + VTableNamesVar->setAlignment(Align(1)); + // Make VTableNames linker retained. + UsedVars.push_back(VTableNamesVar); +} + void InstrLowerer::emitRegistration() { if (!needsRuntimeRegistrationOfSectionRange(TT)) return; @@ -1727,7 +2091,7 @@ void InstrLowerer::emitUses() { // and ensure this GC property as well. Otherwise, we have to conservatively // make all of the sections retained by the linker. if (TT.isOSBinFormatELF() || TT.isOSBinFormatMachO() || - (TT.isOSBinFormatCOFF() && !profDataReferencedByCode(M))) + (TT.isOSBinFormatCOFF() && !DataReferencedByCode)) appendToCompilerUsed(M, CompilerUsedVars); else appendToUsed(M, CompilerUsedVars); @@ -1766,3 +2130,29 @@ void InstrLowerer::emitInitialization() { appendToGlobalCtors(M, F, 0); } + +namespace llvm { +// Create the variable for profile sampling. 
+void createProfileSamplingVar(Module &M) { + const StringRef VarName(INSTR_PROF_QUOTE(INSTR_PROF_PROFILE_SAMPLING_VAR)); + IntegerType *SamplingVarTy; + Constant *ValueZero; + if (SampledInstrPeriod.getValue() <= USHRT_MAX) { + SamplingVarTy = Type::getInt16Ty(M.getContext()); + ValueZero = Constant::getIntegerValue(SamplingVarTy, APInt(16, 0)); + } else { + SamplingVarTy = Type::getInt32Ty(M.getContext()); + ValueZero = Constant::getIntegerValue(SamplingVarTy, APInt(32, 0)); + } + auto SamplingVar = new GlobalVariable( + M, SamplingVarTy, false, GlobalValue::WeakAnyLinkage, ValueZero, VarName); + SamplingVar->setVisibility(GlobalValue::DefaultVisibility); + SamplingVar->setThreadLocal(true); + Triple TT(M.getTargetTriple()); + if (TT.supportsCOMDAT()) { + SamplingVar->setLinkage(GlobalValue::ExternalLinkage); + SamplingVar->setComdat(M.getOrInsertComdat(VarName)); + } + appendToCompilerUsed(M, SamplingVar); +} +} // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/KCFI.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/KCFI.cpp index b1a26880c701..28dc1c02b661 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/KCFI.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/KCFI.cpp @@ -71,8 +71,7 @@ PreservedAnalyses KCFIPass::run(Function &F, FunctionAnalysisManager &AM) { "compatible with -fsanitize=kcfi on this target")); IntegerType *Int32Ty = Type::getInt32Ty(Ctx); - MDNode *VeryUnlikelyWeights = - MDBuilder(Ctx).createBranchWeights(1, (1U << 20) - 1); + MDNode *VeryUnlikelyWeights = MDBuilder(Ctx).createUnlikelyBranchWeights(); Triple T(M.getTargetTriple()); for (CallInst *CI : KCFICalls) { @@ -82,8 +81,8 @@ PreservedAnalyses KCFIPass::run(Function &F, FunctionAnalysisManager &AM) { ->getZExtValue(); // Drop the KCFI operand bundle. - CallBase *Call = - CallBase::removeOperandBundle(CI, LLVMContext::OB_kcfi, CI); + CallBase *Call = CallBase::removeOperandBundle(CI, LLVMContext::OB_kcfi, + CI->getIterator()); assert(Call != CI); Call->copyMetadata(*CI); CI->replaceAllUsesWith(Call); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/LowerAllowCheckPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/LowerAllowCheckPass.cpp new file mode 100644 index 000000000000..0115809e939e --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/LowerAllowCheckPass.cpp @@ -0,0 +1,147 @@ +//===- LowerAllowCheckPass.cpp ----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Instrumentation/LowerAllowCheckPass.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/RandomNumberGenerator.h" +#include <memory> +#include <random> + +using namespace llvm; + +#define DEBUG_TYPE "lower-allow-check" + +static cl::opt<int> + HotPercentileCutoff("lower-allow-check-percentile-cutoff-hot", + cl::desc("Hot percentile cuttoff.")); + +static cl::opt<float> + RandomRate("lower-allow-check-random-rate", + cl::desc("Probability value in the range [0.0, 1.0] of " + "unconditional pseudo-random checks.")); + +STATISTIC(NumChecksTotal, "Number of checks"); +STATISTIC(NumChecksRemoved, "Number of removed checks"); + +struct RemarkInfo { + ore::NV Kind; + ore::NV F; + ore::NV BB; + explicit RemarkInfo(IntrinsicInst *II) + : Kind("Kind", II->getArgOperand(0)), + F("Function", II->getParent()->getParent()), + BB("Block", II->getParent()->getName()) {} +}; + +static void emitRemark(IntrinsicInst *II, OptimizationRemarkEmitter &ORE, + bool Removed) { + if (Removed) { + ORE.emit([&]() { + RemarkInfo Info(II); + return OptimizationRemark(DEBUG_TYPE, "Removed", II) + << "Removed check: Kind=" << Info.Kind << " F=" << Info.F + << " BB=" << Info.BB; + }); + } else { + ORE.emit([&]() { + RemarkInfo Info(II); + return OptimizationRemarkMissed(DEBUG_TYPE, "Allowed", II) + << "Allowed check: Kind=" << Info.Kind << " F=" << Info.F + << " BB=" << Info.BB; + }); + } +} + +static bool removeUbsanTraps(Function &F, const BlockFrequencyInfo &BFI, + const ProfileSummaryInfo *PSI, + OptimizationRemarkEmitter &ORE) { + SmallVector<std::pair<IntrinsicInst *, bool>, 16> ReplaceWithValue; + std::unique_ptr<RandomNumberGenerator> Rng; + + auto ShouldRemove = [&](bool IsHot) { + if (!RandomRate.getNumOccurrences()) + return IsHot; + if (!Rng) + Rng = F.getParent()->createRNG(F.getName()); + std::bernoulli_distribution D(RandomRate); + return !D(*Rng); + }; + + for (BasicBlock &BB : F) { + for (Instruction &I : BB) { + IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); + if (!II) + continue; + auto ID = II->getIntrinsicID(); + switch (ID) { + case Intrinsic::allow_ubsan_check: + case Intrinsic::allow_runtime_check: { + ++NumChecksTotal; + + bool IsHot = false; + if (PSI) { + uint64_t Count = BFI.getBlockProfileCount(&BB).value_or(0); + IsHot = PSI->isHotCountNthPercentile(HotPercentileCutoff, Count); + } + + bool ToRemove = ShouldRemove(IsHot); + ReplaceWithValue.push_back({ + II, + ToRemove, + }); + if (ToRemove) + ++NumChecksRemoved; + emitRemark(II, ORE, ToRemove); + break; + } + default: + break; + } + } + } + + for (auto [I, V] : ReplaceWithValue) { + I->replaceAllUsesWith(ConstantInt::getBool(I->getType(), !V)); + I->eraseFromParent(); + } + + return !ReplaceWithValue.empty(); +} + +PreservedAnalyses LowerAllowCheckPass::run(Function &F, + FunctionAnalysisManager &AM) { + if (F.isDeclaration()) + return PreservedAnalyses::all(); + auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); + 
ProfileSummaryInfo *PSI = + MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); + BlockFrequencyInfo &BFI = AM.getResult<BlockFrequencyAnalysis>(F); + OptimizationRemarkEmitter &ORE = + AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + + return removeUbsanTraps(F, BFI, PSI, ORE) ? PreservedAnalyses::none() + : PreservedAnalyses::all(); +} + +bool LowerAllowCheckPass::IsRequested() { + return RandomRate.getNumOccurrences() || + HotPercentileCutoff.getNumOccurrences(); +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp index 2236e9cd44c5..2c5d749d4a67 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemProfiler.cpp @@ -59,7 +59,7 @@ extern cl::opt<bool> NoPGOWarnMismatchComdatWeak; constexpr int LLVM_MEM_PROFILER_VERSION = 1; // Size of memory mapped to a single shadow location. -constexpr uint64_t DefaultShadowGranularity = 64; +constexpr uint64_t DefaultMemGranularity = 64; // Scale from granularity down to shadow size. constexpr uint64_t DefaultShadowScale = 3; @@ -77,6 +77,8 @@ constexpr char MemProfShadowMemoryDynamicAddress[] = constexpr char MemProfFilenameVar[] = "__memprof_profile_filename"; +constexpr char MemProfHistogramFlagVar[] = "__memprof_histogram"; + // Command-line flags. static cl::opt<bool> ClInsertVersionCheck( @@ -120,7 +122,7 @@ static cl::opt<int> ClMappingScale("memprof-mapping-scale", static cl::opt<int> ClMappingGranularity("memprof-mapping-granularity", cl::desc("granularity of memprof shadow mapping"), - cl::Hidden, cl::init(DefaultShadowGranularity)); + cl::Hidden, cl::init(DefaultMemGranularity)); static cl::opt<bool> ClStack("memprof-instrument-stack", cl::desc("Instrument scalar stack variables"), @@ -140,11 +142,48 @@ static cl::opt<int> ClDebugMin("memprof-debug-min", cl::desc("Debug min inst"), static cl::opt<int> ClDebugMax("memprof-debug-max", cl::desc("Debug max inst"), cl::Hidden, cl::init(-1)); +// By default disable matching of allocation profiles onto operator new that +// already explicitly pass a hot/cold hint, since we don't currently +// override these hints anyway. 
+static cl::opt<bool> ClMemProfMatchHotColdNew( + "memprof-match-hot-cold-new", + cl::desc( + "Match allocation profiles onto existing hot/cold operator new calls"), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> ClHistogram("memprof-histogram", + cl::desc("Collect access count histograms"), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> + ClPrintMemProfMatchInfo("memprof-print-match-info", + cl::desc("Print matching stats for each allocation " + "context in this module's profiles"), + cl::Hidden, cl::init(false)); + +extern cl::opt<bool> MemProfReportHintedSizes; + +// Instrumentation statistics STATISTIC(NumInstrumentedReads, "Number of instrumented reads"); STATISTIC(NumInstrumentedWrites, "Number of instrumented writes"); STATISTIC(NumSkippedStackReads, "Number of non-instrumented stack reads"); STATISTIC(NumSkippedStackWrites, "Number of non-instrumented stack writes"); + +// Matching statistics STATISTIC(NumOfMemProfMissing, "Number of functions without memory profile."); +STATISTIC(NumOfMemProfMismatch, + "Number of functions having mismatched memory profile hash."); +STATISTIC(NumOfMemProfFunc, "Number of functions having valid memory profile."); +STATISTIC(NumOfMemProfAllocContextProfiles, + "Number of alloc contexts in memory profile."); +STATISTIC(NumOfMemProfCallSiteProfiles, + "Number of callsites in memory profile."); +STATISTIC(NumOfMemProfMatchedAllocContexts, + "Number of matched memory profile alloc contexts."); +STATISTIC(NumOfMemProfMatchedAllocs, + "Number of matched memory profile allocs."); +STATISTIC(NumOfMemProfMatchedCallSites, + "Number of matched memory profile callsites."); namespace { @@ -171,7 +210,6 @@ struct InterestingMemoryAccess { Value *Addr = nullptr; bool IsWrite; Type *AccessTy; - uint64_t TypeSize; Value *MaybeMask = nullptr; }; @@ -194,7 +232,7 @@ public: void instrumentMop(Instruction *I, const DataLayout &DL, InterestingMemoryAccess &Access); void instrumentAddress(Instruction *OrigIns, Instruction *InsertBefore, - Value *Addr, uint32_t TypeSize, bool IsWrite); + Value *Addr, bool IsWrite); void instrumentMaskedLoadOrStore(const DataLayout &DL, Value *Mask, Instruction *I, Value *Addr, Type *AccessTy, bool IsWrite); @@ -215,7 +253,6 @@ private: // These arrays is indexed by AccessIsWrite FunctionCallee MemProfMemoryAccessCallback[2]; - FunctionCallee MemProfMemoryAccessCallbackSized[2]; FunctionCallee MemProfMemmove, MemProfMemcpy, MemProfMemset; Value *DynamicShadowOffset = nullptr; @@ -250,6 +287,11 @@ ModuleMemProfilerPass::ModuleMemProfilerPass() = default; PreservedAnalyses ModuleMemProfilerPass::run(Module &M, AnalysisManager<Module> &AM) { + + assert((!ClHistogram || (ClHistogram && ClUseCalls)) && + "Cannot use -memprof-histogram without Callbacks. 
Set " + "memprof-use-callbacks"); + ModuleMemProfiler Profiler(M); if (Profiler.instrumentModule(M)) return PreservedAnalyses::none(); @@ -374,8 +416,6 @@ MemProfiler::isInterestingMemoryAccess(Instruction *I) const { return std::nullopt; } - const DataLayout &DL = I->getModule()->getDataLayout(); - Access.TypeSize = DL.getTypeStoreSizeInBits(Access.AccessTy); return Access; } @@ -383,7 +423,6 @@ void MemProfiler::instrumentMaskedLoadOrStore(const DataLayout &DL, Value *Mask, Instruction *I, Value *Addr, Type *AccessTy, bool IsWrite) { auto *VTy = cast<FixedVectorType>(AccessTy); - uint64_t ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType()); unsigned Num = VTy->getNumElements(); auto *Zero = ConstantInt::get(IntptrTy, 0); for (unsigned Idx = 0; Idx < Num; ++Idx) { @@ -408,8 +447,7 @@ void MemProfiler::instrumentMaskedLoadOrStore(const DataLayout &DL, Value *Mask, IRBuilder<> IRB(InsertBefore); InstrumentedAddress = IRB.CreateGEP(VTy, Addr, {Zero, ConstantInt::get(IntptrTy, Idx)}); - instrumentAddress(I, InsertBefore, InstrumentedAddress, ElemTypeSize, - IsWrite); + instrumentAddress(I, InsertBefore, InstrumentedAddress, IsWrite); } } @@ -436,13 +474,13 @@ void MemProfiler::instrumentMop(Instruction *I, const DataLayout &DL, // Since the access counts will be accumulated across the entire allocation, // we only update the shadow access count for the first location and thus // don't need to worry about alignment and type size. - instrumentAddress(I, I, Access.Addr, Access.TypeSize, Access.IsWrite); + instrumentAddress(I, I, Access.Addr, Access.IsWrite); } } void MemProfiler::instrumentAddress(Instruction *OrigIns, Instruction *InsertBefore, Value *Addr, - uint32_t TypeSize, bool IsWrite) { + bool IsWrite) { IRBuilder<> IRB(InsertBefore); Value *AddrLong = IRB.CreatePointerCast(Addr, IntptrTy); @@ -483,7 +521,24 @@ void createProfileFileNameVar(Module &M) { } } +// Set MemprofHistogramFlag as a Global veriable in IR. This makes it accessible +// to the runtime, changing shadow count behavior. +void createMemprofHistogramFlagVar(Module &M) { + const StringRef VarName(MemProfHistogramFlagVar); + Type *IntTy1 = Type::getInt1Ty(M.getContext()); + auto MemprofHistogramFlag = new GlobalVariable( + M, IntTy1, true, GlobalValue::WeakAnyLinkage, + Constant::getIntegerValue(IntTy1, APInt(1, ClHistogram)), VarName); + Triple TT(M.getTargetTriple()); + if (TT.supportsCOMDAT()) { + MemprofHistogramFlag->setLinkage(GlobalValue::ExternalLinkage); + MemprofHistogramFlag->setComdat(M.getOrInsertComdat(VarName)); + } + appendToCompilerUsed(M, MemprofHistogramFlag); +} + bool ModuleMemProfiler::instrumentModule(Module &M) { + // Create a module constructor. std::string MemProfVersion = std::to_string(LLVM_MEM_PROFILER_VERSION); std::string VersionCheckName = @@ -499,6 +554,8 @@ bool ModuleMemProfiler::instrumentModule(Module &M) { createProfileFileNameVar(M); + createMemprofHistogramFlagVar(M); + return true; } @@ -507,16 +564,12 @@ void MemProfiler::initializeCallbacks(Module &M) { for (size_t AccessIsWrite = 0; AccessIsWrite <= 1; AccessIsWrite++) { const std::string TypeStr = AccessIsWrite ? "store" : "load"; + const std::string HistPrefix = ClHistogram ? 
"hist_" : ""; - SmallVector<Type *, 3> Args2 = {IntptrTy, IntptrTy}; SmallVector<Type *, 2> Args1{1, IntptrTy}; - MemProfMemoryAccessCallbackSized[AccessIsWrite] = - M.getOrInsertFunction(ClMemoryAccessCallbackPrefix + TypeStr + "N", - FunctionType::get(IRB.getVoidTy(), Args2, false)); - - MemProfMemoryAccessCallback[AccessIsWrite] = - M.getOrInsertFunction(ClMemoryAccessCallbackPrefix + TypeStr, - FunctionType::get(IRB.getVoidTy(), Args1, false)); + MemProfMemoryAccessCallback[AccessIsWrite] = M.getOrInsertFunction( + ClMemoryAccessCallbackPrefix + HistPrefix + TypeStr, + FunctionType::get(IRB.getVoidTy(), Args1, false)); } MemProfMemmove = M.getOrInsertFunction( ClMemoryAccessCallbackPrefix + "memmove", PtrTy, PtrTy, PtrTy, IntptrTy); @@ -601,7 +654,7 @@ bool MemProfiler::instrumentFunction(Function &F) { std::optional<InterestingMemoryAccess> Access = isInterestingMemoryAccess(Inst); if (Access) - instrumentMop(Inst, F.getParent()->getDataLayout(), *Access); + instrumentMop(Inst, F.getDataLayout(), *Access); else instrumentMemIntrinsic(cast<MemIntrinsic>(Inst)); } @@ -639,15 +692,35 @@ static uint64_t computeStackId(const memprof::Frame &Frame) { return computeStackId(Frame.Function, Frame.LineOffset, Frame.Column); } -static void addCallStack(CallStackTrie &AllocTrie, - const AllocationInfo *AllocInfo) { +// Helper to generate a single hash id for a given callstack, used for emitting +// matching statistics and useful for uniquing such statistics across modules. +static uint64_t +computeFullStackId(const std::vector<memprof::Frame> &CallStack) { + llvm::HashBuilder<llvm::TruncatedBLAKE3<8>, llvm::endianness::little> + HashBuilder; + for (auto &F : CallStack) + HashBuilder.add(F.Function, F.LineOffset, F.Column); + llvm::BLAKE3Result<8> Hash = HashBuilder.final(); + uint64_t Id; + std::memcpy(&Id, Hash.data(), sizeof(Hash)); + return Id; +} + +static AllocationType addCallStack(CallStackTrie &AllocTrie, + const AllocationInfo *AllocInfo) { SmallVector<uint64_t> StackIds; for (const auto &StackFrame : AllocInfo->CallStack) StackIds.push_back(computeStackId(StackFrame)); auto AllocType = getAllocType(AllocInfo->Info.getTotalLifetimeAccessDensity(), AllocInfo->Info.getAllocCount(), AllocInfo->Info.getTotalLifetime()); - AllocTrie.addCallStack(AllocType, StackIds); + uint64_t TotalSize = 0; + if (MemProfReportHintedSizes) { + TotalSize = AllocInfo->Info.getTotalSize(); + assert(TotalSize); + } + AllocTrie.addCallStack(AllocType, StackIds, TotalSize); + return AllocType; } // Helper to compare the InlinedCallStack computed from an instruction's debug @@ -672,9 +745,47 @@ stackFrameIncludesInlinedCallStack(ArrayRef<Frame> ProfileCallStack, return InlCallStackIter == InlinedCallStack.end(); } -static void readMemprof(Module &M, Function &F, - IndexedInstrProfReader *MemProfReader, - const TargetLibraryInfo &TLI) { +static bool isNewWithHotColdVariant(Function *Callee, + const TargetLibraryInfo &TLI) { + if (!Callee) + return false; + LibFunc Func; + if (!TLI.getLibFunc(*Callee, Func)) + return false; + switch (Func) { + case LibFunc_Znwm: + case LibFunc_ZnwmRKSt9nothrow_t: + case LibFunc_ZnwmSt11align_val_t: + case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t: + case LibFunc_Znam: + case LibFunc_ZnamRKSt9nothrow_t: + case LibFunc_ZnamSt11align_val_t: + case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t: + return true; + case LibFunc_Znwm12__hot_cold_t: + case LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t: + case LibFunc_ZnwmSt11align_val_t12__hot_cold_t: + case 
LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t: + case LibFunc_Znam12__hot_cold_t: + case LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t: + case LibFunc_ZnamSt11align_val_t12__hot_cold_t: + case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t: + return ClMemProfMatchHotColdNew; + default: + return false; + } +} + +struct AllocMatchInfo { + uint64_t TotalSize = 0; + AllocationType AllocType = AllocationType::None; + bool Matched = false; +}; + +static void +readMemprof(Module &M, Function &F, IndexedInstrProfReader *MemProfReader, + const TargetLibraryInfo &TLI, + std::map<uint64_t, AllocMatchInfo> &FullStackIdToAllocMatchInfo) { auto &Ctx = M.getContext(); // Previously we used getIRPGOFuncName() here. If F is local linkage, // getIRPGOFuncName() returns FuncName with prefix 'FileName;'. But @@ -698,6 +809,7 @@ static void readMemprof(Module &M, Function &F, SkipWarning = !PGOWarnMissing; LLVM_DEBUG(dbgs() << "unknown function"); } else if (Err == instrprof_error::hash_mismatch) { + NumOfMemProfMismatch++; SkipWarning = NoPGOWarnMismatch || (NoPGOWarnMismatchComdatWeak && @@ -719,6 +831,8 @@ static void readMemprof(Module &M, Function &F, return; } + NumOfMemProfFunc++; + // Detect if there are non-zero column numbers in the profile. If not, // treat all column numbers as 0 when matching (i.e. ignore any non-zero // columns in the IR). The profiled binary might have been built with @@ -730,9 +844,10 @@ static void readMemprof(Module &M, Function &F, std::map<uint64_t, std::set<const AllocationInfo *>> LocHashToAllocInfo; // For the callsites we need to record the index of the associated frame in // the frame array (see comments below where the map entries are added). - std::map<uint64_t, std::set<std::pair<const SmallVector<Frame> *, unsigned>>> + std::map<uint64_t, std::set<std::pair<const std::vector<Frame> *, unsigned>>> LocHashToCallSites; for (auto &AI : MemProfRec->AllocSites) { + NumOfMemProfAllocContextProfiles++; // Associate the allocation info with the leaf frame. The later matching // code will match any inlined call sequences in the IR with a longer prefix // of call stack frames. @@ -741,6 +856,7 @@ static void readMemprof(Module &M, Function &F, ProfileHasColumns |= AI.CallStack[0].Column; } for (auto &CS : MemProfRec->CallSites) { + NumOfMemProfCallSiteProfiles++; // Need to record all frames from leaf up to and including this function, // as any of these may or may not have been inlined at this point. unsigned Idx = 0; @@ -761,7 +877,7 @@ static void readMemprof(Module &M, Function &F, }; // Now walk the instructions, looking up the associated profile data using - // dbug locations. + // debug locations. for (auto &BB : F) { for (auto &I : BB) { if (I.isDebugOrPseudoInst()) @@ -786,7 +902,7 @@ static void readMemprof(Module &M, Function &F, // and another callsite). std::map<uint64_t, std::set<const AllocationInfo *>>::iterator AllocInfoIter; - std::map<uint64_t, std::set<std::pair<const SmallVector<Frame> *, + std::map<uint64_t, std::set<std::pair<const std::vector<Frame> *, unsigned>>>::iterator CallSitesIter; for (const DILocation *DIL = I.getDebugLoc(); DIL != nullptr; DIL = DIL->getInlinedAt()) { @@ -823,7 +939,7 @@ static void readMemprof(Module &M, Function &F, if (AllocInfoIter != LocHashToAllocInfo.end()) { // Only consider allocations via new, to reduce unnecessary metadata, // since those are the only allocations that will be targeted initially. 
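// (Illustrative note, not part of this change: the LibFunc names matched by
// isNewWithHotColdVariant() demangle to the usual allocation functions, e.g.
// _Znwm is "operator new(unsigned long)" and _Znam12__hot_cold_t is
// "operator new[](unsigned long, __hot_cold_t)"; the __hot_cold_t overloads
// are only accepted when -memprof-match-hot-cold-new is enabled.)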
- if (!isNewLikeFn(CI, &TLI)) + if (!isNewWithHotColdVariant(CI->getCalledFunction(), TLI)) continue; // We may match this instruction's location list to multiple MIB // contexts. Add them to a Trie specialized for trimming the contexts to @@ -834,13 +950,23 @@ static void readMemprof(Module &M, Function &F, // If we found and thus matched all frames on the call, include // this MIB. if (stackFrameIncludesInlinedCallStack(AllocInfo->CallStack, - InlinedCallStack)) - addCallStack(AllocTrie, AllocInfo); + InlinedCallStack)) { + NumOfMemProfMatchedAllocContexts++; + auto AllocType = addCallStack(AllocTrie, AllocInfo); + // Record information about the allocation if match info printing + // was requested. + if (ClPrintMemProfMatchInfo) { + auto FullStackId = computeFullStackId(AllocInfo->CallStack); + FullStackIdToAllocMatchInfo[FullStackId] = { + AllocInfo->Info.getTotalSize(), AllocType, /*Matched=*/true}; + } + } } // We might not have matched any to the full inlined call stack. // But if we did, create and attach metadata, or a function attribute if // all contexts have identical profiled behavior. if (!AllocTrie.empty()) { + NumOfMemProfMatchedAllocs++; // MemprofMDAttached will be false if a function attribute was // attached. bool MemprofMDAttached = AllocTrie.buildAndAttachMIBMetadata(CI); @@ -849,7 +975,7 @@ static void readMemprof(Module &M, Function &F, // Add callsite metadata for the instruction's location list so that // it simpler later on to identify which part of the MIB contexts // are from this particular instruction (including during inlining, - // when the callsite metdata will be updated appropriately). + // when the callsite metadata will be updated appropriately). // FIXME: can this be changed to strip out the matching stack // context ids from the MIB contexts and not add any callsite // metadata here to save space? @@ -868,6 +994,7 @@ static void readMemprof(Module &M, Function &F, // attach call stack metadata. if (stackFrameIncludesInlinedCallStack( *CallStackIdx.first, InlinedCallStack, CallStackIdx.second)) { + NumOfMemProfMatchedCallSites++; addCallsiteMetadata(I, InlinedCallStack, Ctx); // Only need to find one with a matching call stack and add a single // callsite metadata. @@ -913,12 +1040,25 @@ PreservedAnalyses MemProfUsePass::run(Module &M, ModuleAnalysisManager &AM) { auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + // Map from the stack has of each allocation context in the function profiles + // to the total profiled size (bytes), allocation type, and whether we matched + // it to an allocation in the IR. + std::map<uint64_t, AllocMatchInfo> FullStackIdToAllocMatchInfo; + for (auto &F : M) { if (F.isDeclaration()) continue; const TargetLibraryInfo &TLI = FAM.getResult<TargetLibraryAnalysis>(F); - readMemprof(M, F, MemProfReader.get(), TLI); + readMemprof(M, F, MemProfReader.get(), TLI, FullStackIdToAllocMatchInfo); + } + + if (ClPrintMemProfMatchInfo) { + for (const auto &[Id, Info] : FullStackIdToAllocMatchInfo) + errs() << "MemProf " << getAllocTypeAttributeString(Info.AllocType) + << " context with id " << Id << " has total profiled size " + << Info.TotalSize << (Info.Matched ? 
" is" : " not") + << " matched\n"; } return PreservedAnalyses::none(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index 15bca538860d..c979e81ac1a3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -11,11 +11,11 @@ /// reads. /// /// The algorithm of the tool is similar to Memcheck -/// (http://goo.gl/QKbem). We associate a few shadow bits with every -/// byte of the application memory, poison the shadow of the malloc-ed -/// or alloca-ed memory, load the shadow bits on every memory read, -/// propagate the shadow bits through some of the arithmetic -/// instruction (including MOV), store the shadow bits on every memory +/// (https://static.usenix.org/event/usenix05/tech/general/full_papers/seward/seward_html/usenix2005.html) +/// We associate a few shadow bits with every byte of the application memory, +/// poison the shadow of the malloc-ed or alloca-ed memory, load the shadow, +/// bits on every memory read, propagate the shadow bits through some of the +/// arithmetic instruction (including MOV), store the shadow bits on every memory /// write, report a bug on some other instructions (e.g. JMP) if the /// associated shadow is poisoned. /// @@ -124,8 +124,9 @@ /// __msan_metadata_ptr_for_store_n(ptr, size); /// Note that the sanitizer code has to deal with how shadow/origin pairs /// returned by the these functions are represented in different ABIs. In -/// the X86_64 ABI they are returned in RDX:RAX, and in the SystemZ ABI they -/// are written to memory pointed to by a hidden parameter. +/// the X86_64 ABI they are returned in RDX:RAX, in PowerPC64 they are +/// returned in r3 and r4, and in the SystemZ ABI they are written to memory +/// pointed to by a hidden parameter. /// - TLS variables are stored in a single per-task struct. A call to a /// function __msan_get_context_state() returning a pointer to that struct /// is inserted into every instrumented function before the entry block; @@ -139,7 +140,8 @@ /// Also, KMSAN currently ignores uninitialized memory passed into inline asm /// calls, making sure we're on the safe side wrt. possible false positives. /// -/// KernelMemorySanitizer only supports X86_64 and SystemZ at the moment. +/// KernelMemorySanitizer only supports X86_64, SystemZ and PowerPC64 at the +/// moment. 
/// // // FIXME: This sanitizer does not yet handle scalable vectors @@ -152,6 +154,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" @@ -178,6 +181,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsAArch64.h" #include "llvm/IR/IntrinsicsX86.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" @@ -212,6 +216,9 @@ using namespace llvm; DEBUG_COUNTER(DebugInsertCheck, "msan-insert-check", "Controls which checks to insert"); +DEBUG_COUNTER(DebugInstrumentInstruction, "msan-instrument-instruction", + "Controls which instruction to instrument"); + static const unsigned kOriginSize = 4; static const Align kMinOriginAlignment = Align(4); static const Align kShadowTLSAlignment = Align(8); @@ -284,9 +291,6 @@ static cl::opt<bool> ClHandleLifetimeIntrinsics( // passed into an assembly call. Note that this may cause false positives. // Because it's impossible to figure out the array sizes, we can only unpoison // the first sizeof(type) bytes for each type* pointer. -// The instrumentation is only enabled in KMSAN builds, and only if -// -msan-handle-asm-conservative is on. This is done because we may want to -// quickly disable assembly instrumentation when it breaks. static cl::opt<bool> ClHandleAsmConservative( "msan-handle-asm-conservative", cl::desc("conservative handling of inline assembly"), cl::Hidden, @@ -1043,8 +1047,8 @@ void MemorySanitizer::initializeModule(Module &M) { OriginTy = IRB.getInt32Ty(); PtrTy = IRB.getPtrTy(); - ColdCallWeights = MDBuilder(*C).createBranchWeights(1, 1000); - OriginStoreWeights = MDBuilder(*C).createBranchWeights(1, 1000); + ColdCallWeights = MDBuilder(*C).createUnlikelyBranchWeights(); + OriginStoreWeights = MDBuilder(*C).createUnlikelyBranchWeights(); if (!CompileKernel) { if (TrackOrigins) @@ -1134,6 +1138,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { std::unique_ptr<VarArgHelper> VAHelper; const TargetLibraryInfo *TLI; Instruction *FnPrologueEnd; + SmallVector<Instruction *, 16> Instructions; // The following flags disable parts of MSan instrumentation based on // exclusion list contents and command-line options. @@ -1214,7 +1219,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } Value *originToIntptr(IRBuilder<> &IRB, Value *Origin) { - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); unsigned IntptrSize = DL.getTypeStoreSize(MS.IntptrTy); if (IntptrSize == kOriginSize) return Origin; @@ -1226,7 +1231,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// Fill memory range with the given origin value. void paintOrigin(IRBuilder<> &IRB, Value *Origin, Value *OriginPtr, TypeSize TS, Align Alignment) { - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); const Align IntptrAlignment = DL.getABITypeAlign(MS.IntptrTy); unsigned IntptrSize = DL.getTypeStoreSize(MS.IntptrTy); assert(IntptrAlignment >= kMinOriginAlignment); @@ -1235,9 +1240,11 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // Note: The loop based formation works for fixed length vectors too, // however we prefer to unroll and specialize alignment below. 
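// Worked example (illustrative): with kOriginSize == 4, a scalable store whose
// size turns out to be 24 bytes at run time computes RoundUp = 24 + 3 = 27 and
// End = 27 / 4 = 6, so the emitted loop writes six 4-byte origin slots.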
if (TS.isScalable()) { - Value *Size = IRB.CreateTypeSize(IRB.getInt32Ty(), TS); - Value *RoundUp = IRB.CreateAdd(Size, IRB.getInt32(kOriginSize - 1)); - Value *End = IRB.CreateUDiv(RoundUp, IRB.getInt32(kOriginSize)); + Value *Size = IRB.CreateTypeSize(MS.IntptrTy, TS); + Value *RoundUp = + IRB.CreateAdd(Size, ConstantInt::get(MS.IntptrTy, kOriginSize - 1)); + Value *End = + IRB.CreateUDiv(RoundUp, ConstantInt::get(MS.IntptrTy, kOriginSize)); auto [InsertPt, Index] = SplitBlockAndInsertSimpleForLoop(End, &*IRB.GetInsertPoint()); IRB.SetInsertPoint(InsertPt); @@ -1274,9 +1281,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void storeOrigin(IRBuilder<> &IRB, Value *Addr, Value *Shadow, Value *Origin, Value *OriginPtr, Align Alignment) { - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); const Align OriginAlignment = std::max(kMinOriginAlignment, Alignment); TypeSize StoreSize = DL.getTypeStoreSize(Shadow->getType()); + // ZExt cannot convert between vector and scalar Value *ConvertedShadow = convertShadowToScalar(Shadow, IRB); if (auto *ConstantShadow = dyn_cast<Constant>(ConvertedShadow)) { if (!ClCheckConstantShadow || ConstantShadow->isZeroValue()) { @@ -1339,7 +1347,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } } - // Returns true if Debug Location curresponds to multiple warnings. + // Returns true if Debug Location corresponds to multiple warnings. bool shouldDisambiguateWarningLocation(const DebugLoc &DebugLoc) { if (MS.TrackOrigins < 2) return false; @@ -1386,12 +1394,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void materializeOneCheck(IRBuilder<> &IRB, Value *ConvertedShadow, Value *Origin) { - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); TypeSize TypeSizeInBits = DL.getTypeSizeInBits(ConvertedShadow->getType()); unsigned SizeIndex = TypeSizeToSizeIndex(TypeSizeInBits); if (instrumentWithCalls(ConvertedShadow) && SizeIndex < kNumberOfAccessSizes && !MS.CompileKernel) { FunctionCallee Fn = MS.MaybeWarningFn[SizeIndex]; + // ZExt cannot convert between vector and scalar + ConvertedShadow = convertShadowToScalar(ConvertedShadow, IRB); Value *ConvertedShadow2 = IRB.CreateZExt(ConvertedShadow, IRB.getIntNTy(8 * (1 << SizeIndex))); CallBase *CB = IRB.CreateCall( @@ -1413,7 +1423,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void materializeInstructionChecks( ArrayRef<ShadowOriginAndInsertPoint> InstructionChecks) { - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); // Disable combining in some cases. TrackOrigins checks each shadow to pick // correct origin. bool Combine = !MS.TrackOrigins; @@ -1464,19 +1474,21 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } void materializeChecks() { - llvm::stable_sort(InstrumentationList, - [](const ShadowOriginAndInsertPoint &L, - const ShadowOriginAndInsertPoint &R) { - return L.OrigIns < R.OrigIns; - }); +#ifndef NDEBUG + // For assert below. + SmallPtrSet<Instruction *, 16> Done; +#endif for (auto I = InstrumentationList.begin(); I != InstrumentationList.end();) { - auto J = - std::find_if(I + 1, InstrumentationList.end(), - [L = I->OrigIns](const ShadowOriginAndInsertPoint &R) { - return L != R.OrigIns; - }); + auto OrigIns = I->OrigIns; + // Checks are grouped by the original instruction. 
We call all + // `insertShadowCheck` for an instruction at once. + assert(Done.insert(OrigIns).second); + auto J = std::find_if(I + 1, InstrumentationList.end(), + [OrigIns](const ShadowOriginAndInsertPoint &R) { + return OrigIns != R.OrigIns; + }); // Process all checks of instruction at once. materializeInstructionChecks(ArrayRef<ShadowOriginAndInsertPoint>(I, J)); I = J; @@ -1517,6 +1529,11 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { for (BasicBlock *BB : depth_first(FnPrologueEnd->getParent())) visit(*BB); + // `visit` above only collects instructions. Process them after iterating + // CFG to avoid requirement on CFG transformations. + for (Instruction *I : Instructions) + InstVisitor<MemorySanitizerVisitor>::visit(*I); + // Finalize PHI nodes. for (PHINode *PN : ShadowPHINodes) { PHINode *PNS = cast<PHINode>(getShadow(PN)); @@ -1566,7 +1583,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // This may return weird-sized types like i1. if (IntegerType *IT = dyn_cast<IntegerType>(OrigTy)) return IT; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); if (VectorType *VT = dyn_cast<VectorType>(OrigTy)) { uint32_t EltSize = DL.getTypeSizeInBits(VT->getElementType()); return VectorType::get(IntegerType::get(*MS.C, EltSize), @@ -1762,7 +1779,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Type *ShadowTy, bool isStore) { Value *ShadowOriginPtrs; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); TypeSize Size = DL.getTypeStoreSize(ShadowTy); FunctionCallee Getter = MS.getKmsanShadowOriginAccessFn(isStore, Size); @@ -1950,10 +1967,17 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Function *F = A->getParent(); IRBuilder<> EntryIRB(FnPrologueEnd); unsigned ArgOffset = 0; - const DataLayout &DL = F->getParent()->getDataLayout(); + const DataLayout &DL = F->getDataLayout(); for (auto &FArg : F->args()) { - if (!FArg.getType()->isSized()) { - LLVM_DEBUG(dbgs() << "Arg is not sized\n"); + if (!FArg.getType()->isSized() || FArg.getType()->isScalableTy()) { + LLVM_DEBUG(dbgs() << (FArg.getType()->isScalableTy() + ? "vscale not fully supported\n" + : "Arg is not sized\n")); + if (A == &FArg) { + ShadowPtr = getCleanShadow(V); + setOrigin(A, getCleanOrigin()); + break; + } continue; } @@ -2134,8 +2158,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { OrderingTable[(int)AtomicOrderingCABI::seq_cst] = (int)AtomicOrderingCABI::seq_cst; - return ConstantDataVector::get(IRB.getContext(), - ArrayRef(OrderingTable, NumOrderings)); + return ConstantDataVector::get(IRB.getContext(), OrderingTable); } AtomicOrdering addAcquireOrdering(AtomicOrdering a) { @@ -2169,8 +2192,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { OrderingTable[(int)AtomicOrderingCABI::seq_cst] = (int)AtomicOrderingCABI::seq_cst; - return ConstantDataVector::get(IRB.getContext(), - ArrayRef(OrderingTable, NumOrderings)); + return ConstantDataVector::get(IRB.getContext(), OrderingTable); } // ------------------- Visitors. 
@@ -2181,7 +2203,15 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // Don't want to visit if we're in the prologue if (isInPrologue(I)) return; - InstVisitor<MemorySanitizerVisitor>::visit(I); + if (!DebugCounter::shouldExecute(DebugInstrumentInstruction)) { + LLVM_DEBUG(dbgs() << "Skipping instruction: " << I << "\n"); + // We still need to set the shadow and origin to clean values. + setShadow(&I, getCleanShadow(&I)); + setOrigin(&I, getCleanOrigin()); + return; + } + + Instructions.push_back(&I); } /// Instrument LoadInst @@ -2469,6 +2499,15 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { MSV->setOrigin(I, Origin); } } + + /// Store the current combined value at the specified origin + /// location. + void DoneAndStoreOrigin(TypeSize TS, Value *OriginPtr) { + if (MSV->MS.TrackOrigins) { + assert(Origin); + MSV->paintOrigin(IRB, Origin, OriginPtr, TS, kMinOriginAlignment); + } + } }; using ShadowAndOriginCombiner = Combiner<true>; @@ -2498,6 +2537,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *CreateShadowCast(IRBuilder<> &IRB, Value *V, Type *dstTy, bool Signed = false) { Type *srcTy = V->getType(); + if (srcTy == dstTy) + return V; size_t srcSizeInBits = VectorOrPrimitiveTypeSizeInBits(srcTy); size_t dstSizeInBits = VectorOrPrimitiveTypeSizeInBits(dstTy); if (srcSizeInBits > 1 && dstSizeInBits == 1) @@ -3261,6 +3302,106 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOriginForNaryOp(I); } + // Convert `Mask` into `<n x i1>`. + Constant *createDppMask(unsigned Width, unsigned Mask) { + SmallVector<Constant *, 4> R(Width); + for (auto &M : R) { + M = ConstantInt::getBool(F.getContext(), Mask & 1); + Mask >>= 1; + } + return ConstantVector::get(R); + } + + // Calculate output shadow as array of booleans `<n x i1>`, assuming if any + // arg is poisoned, entire dot product is poisoned. + Value *findDppPoisonedOutput(IRBuilder<> &IRB, Value *S, unsigned SrcMask, + unsigned DstMask) { + const unsigned Width = + cast<FixedVectorType>(S->getType())->getNumElements(); + + S = IRB.CreateSelect(createDppMask(Width, SrcMask), S, + Constant::getNullValue(S->getType())); + Value *SElem = IRB.CreateOrReduce(S); + Value *IsClean = IRB.CreateIsNull(SElem, "_msdpp"); + Value *DstMaskV = createDppMask(Width, DstMask); + + return IRB.CreateSelect( + IsClean, Constant::getNullValue(DstMaskV->getType()), DstMaskV); + } + + // See `Intel Intrinsics Guide` for `_dp_p*` instructions. + // + // 2 and 4 element versions produce single scalar of dot product, and then + // puts it into elements of output vector, selected by 4 lowest bits of the + // mask. Top 4 bits of the mask control which elements of input to use for dot + // product. + // + // 8 element version mask still has only 4 bit for input, and 4 bit for output + // mask. According to the spec it just operates as 4 element version on first + // 4 elements of inputs and output, and then on last 4 elements of inputs and + // output. 
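// Worked example (illustrative, not from this change): for
//   %r = call <4 x float> @llvm.x86.sse41.dpps(<4 x float> %a, <4 x float> %b,
//                                              i8 113)   ; imm = 0x71
// SrcMask = 0x7 means input elements 0..2 take part in the dot product, and
// DstMask = 0x1 means only output element 0 receives the sum (the rest are
// zeroed). For the shadow, if any of %a[0..2] or %b[0..2] is poisoned, the
// handler below poisons exactly the output elements selected by DstMask,
// i.e. element 0.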
+ void handleDppIntrinsic(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + + Value *S0 = getShadow(&I, 0); + Value *S1 = getShadow(&I, 1); + Value *S = IRB.CreateOr(S0, S1); + + const unsigned Width = + cast<FixedVectorType>(S->getType())->getNumElements(); + assert(Width == 2 || Width == 4 || Width == 8); + + const unsigned Mask = cast<ConstantInt>(I.getArgOperand(2))->getZExtValue(); + const unsigned SrcMask = Mask >> 4; + const unsigned DstMask = Mask & 0xf; + + // Calculate shadow as `<n x i1>`. + Value *SI1 = findDppPoisonedOutput(IRB, S, SrcMask, DstMask); + if (Width == 8) { + // First 4 elements of shadow are already calculated. `makeDppShadow` + // operats on 32 bit masks, so we can just shift masks, and repeat. + SI1 = IRB.CreateOr( + SI1, findDppPoisonedOutput(IRB, S, SrcMask << 4, DstMask << 4)); + } + // Extend to real size of shadow, poisoning either all or none bits of an + // element. + S = IRB.CreateSExt(SI1, S->getType(), "_msdpp"); + + setShadow(&I, S); + setOriginForNaryOp(I); + } + + Value *convertBlendvToSelectMask(IRBuilder<> &IRB, Value *C) { + C = CreateAppToShadowCast(IRB, C); + FixedVectorType *FVT = cast<FixedVectorType>(C->getType()); + unsigned ElSize = FVT->getElementType()->getPrimitiveSizeInBits(); + C = IRB.CreateAShr(C, ElSize - 1); + FVT = FixedVectorType::get(IRB.getInt1Ty(), FVT->getNumElements()); + return IRB.CreateTrunc(C, FVT); + } + + // `blendv(f, t, c)` is effectively `select(c[top_bit], t, f)`. + void handleBlendvIntrinsic(IntrinsicInst &I) { + Value *C = I.getOperand(2); + Value *T = I.getOperand(1); + Value *F = I.getOperand(0); + + Value *Sc = getShadow(&I, 2); + Value *Oc = MS.TrackOrigins ? getOrigin(C) : nullptr; + + { + IRBuilder<> IRB(&I); + // Extract top bit from condition and its shadow. + C = convertBlendvToSelectMask(IRB, C); + Sc = convertBlendvToSelectMask(IRB, Sc); + + setShadow(C, Sc); + setOrigin(C, Oc); + } + + handleSelectLikeInst(I, C, T, F); + } + // Instrument sum-of-absolute-differences intrinsic. void handleVectorSadIntrinsic(IntrinsicInst &I) { const unsigned SignificantBitsPerResultElement = 16; @@ -3548,7 +3689,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (!MS.TrackOrigins) return; - auto &DL = F.getParent()->getDataLayout(); + auto &DL = F.getDataLayout(); paintOrigin(IRB, getOrigin(V), OriginPtr, DL.getTypeStoreSize(Shadow->getType()), std::max(Alignment, kMinOriginAlignment)); @@ -3616,7 +3757,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOriginForNaryOp(I); } - SmallVector<int, 8> getPclmulMask(unsigned Width, bool OddElements) { + static SmallVector<int, 8> getPclmulMask(unsigned Width, bool OddElements) { SmallVector<int, 8> Mask; for (unsigned X = OddElements ? 1 : 0; X < Width; X += 2) { Mask.append(2, X); @@ -3718,8 +3859,95 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOrigin(&I, getOrigin(&I, 0)); } + void handleArithmeticWithOverflow(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + Value *Shadow0 = getShadow(&I, 0); + Value *Shadow1 = getShadow(&I, 1); + Value *ShadowElt0 = IRB.CreateOr(Shadow0, Shadow1); + Value *ShadowElt1 = + IRB.CreateICmpNE(ShadowElt0, getCleanShadow(ShadowElt0)); + + Value *Shadow = PoisonValue::get(getShadowTy(&I)); + Shadow = IRB.CreateInsertValue(Shadow, ShadowElt0, 0); + Shadow = IRB.CreateInsertValue(Shadow, ShadowElt1, 1); + + setShadow(&I, Shadow); + setOriginForNaryOp(I); + } + + /// Handle Arm NEON vector store intrinsics (vst{2,3,4}). 
+ /// + /// Arm NEON vector store intrinsics have the output address (pointer) as the + /// last argument, with the initial arguments being the inputs. They return + /// void. + void handleNEONVectorStoreIntrinsic(IntrinsicInst &I) { + IRBuilder<> IRB(&I); + + // Don't use getNumOperands() because it includes the callee + int numArgOperands = I.arg_size(); + assert(numArgOperands >= 1); + + // The last arg operand is the output + Value *Addr = I.getArgOperand(numArgOperands - 1); + assert(Addr->getType()->isPointerTy()); + + if (ClCheckAccessAddress) + insertShadowCheck(Addr, &I); + + // Every arg operand, other than the last one, is an input vector + IntrinsicInst *ShadowI = cast<IntrinsicInst>(I.clone()); + for (int i = 0; i < numArgOperands - 1; i++) { + assert(isa<FixedVectorType>(I.getArgOperand(i)->getType())); + ShadowI->setArgOperand(i, getShadow(&I, i)); + } + + // MSan's GetShadowTy assumes the LHS is the type we want the shadow for + // e.g., for: + // [[TMP5:%.*]] = bitcast <16 x i8> [[TMP2]] to i128 + // we know the type of the output (and its shadow) is <16 x i8>. + // + // Arm NEON VST is unusual because the last argument is the output address: + // define void @st2_16b(<16 x i8> %A, <16 x i8> %B, ptr %P) { + // call void @llvm.aarch64.neon.st2.v16i8.p0 + // (<16 x i8> [[A]], <16 x i8> [[B]], ptr [[P]]) + // and we have no type information about P's operand. We must manually + // compute the type (<16 x i8> x 2). + FixedVectorType *OutputVectorTy = FixedVectorType::get( + cast<FixedVectorType>(I.getArgOperand(0)->getType())->getElementType(), + cast<FixedVectorType>(I.getArgOperand(0)->getType())->getNumElements() * + (numArgOperands - 1)); + Type *ShadowTy = getShadowTy(OutputVectorTy); + Value *ShadowPtr, *OriginPtr; + // AArch64 NEON does not need alignment (unless OS requires it) + std::tie(ShadowPtr, OriginPtr) = + getShadowOriginPtr(Addr, IRB, ShadowTy, Align(1), /*isStore*/ true); + ShadowI->setArgOperand(numArgOperands - 1, ShadowPtr); + ShadowI->insertAfter(&I); + + if (MS.TrackOrigins) { + // TODO: if we modelled the vst* instruction more precisely, we could + // more accurately track the origins (e.g., if both inputs are + // uninitialized for vst2, we currently blame the second input, even + // though part of the output depends only on the first input). 
+ OriginCombiner OC(this, IRB); + for (int i = 0; i < numArgOperands - 1; i++) + OC.Add(I.getArgOperand(i)); + + const DataLayout &DL = F.getDataLayout(); + OC.DoneAndStoreOrigin(DL.getTypeStoreSize(OutputVectorTy), OriginPtr); + } + } + void visitIntrinsicInst(IntrinsicInst &I) { switch (I.getIntrinsicID()) { + case Intrinsic::uadd_with_overflow: + case Intrinsic::sadd_with_overflow: + case Intrinsic::usub_with_overflow: + case Intrinsic::ssub_with_overflow: + case Intrinsic::umul_with_overflow: + case Intrinsic::smul_with_overflow: + handleArithmeticWithOverflow(I); + break; case Intrinsic::abs: handleAbsIntrinsic(I); break; @@ -3908,6 +4136,21 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { handleVectorPackIntrinsic(I); break; + case Intrinsic::x86_sse41_pblendvb: + case Intrinsic::x86_sse41_blendvpd: + case Intrinsic::x86_sse41_blendvps: + case Intrinsic::x86_avx_blendv_pd_256: + case Intrinsic::x86_avx_blendv_ps_256: + case Intrinsic::x86_avx2_pblendvb: + handleBlendvIntrinsic(I); + break; + + case Intrinsic::x86_avx_dp_ps_256: + case Intrinsic::x86_sse41_dppd: + case Intrinsic::x86_sse41_dpps: + handleDppIntrinsic(I); + break; + case Intrinsic::x86_mmx_packsswb: case Intrinsic::x86_mmx_packuswb: handleVectorPackIntrinsic(I, 16); @@ -4034,6 +4277,13 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOrigin(&I, getCleanOrigin()); break; + case Intrinsic::aarch64_neon_st2: + case Intrinsic::aarch64_neon_st3: + case Intrinsic::aarch64_neon_st4: { + handleNEONVectorStoreIntrinsic(I); + break; + } + default: if (!handleUnknownIntrinsic(I)) visitInstruction(I); @@ -4103,11 +4353,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // do the usual thing: check argument shadow and mark all outputs as // clean. Note that any side effects of the inline asm that are not // immediately visible in its constraints are not handled. - // For now, handle inline asm by default for KMSAN. - bool HandleAsm = ClHandleAsmConservative.getNumOccurrences() - ? ClHandleAsmConservative - : MS.CompileKernel; - if (HandleAsm) + if (ClHandleAsmConservative) visitAsmInstruction(CB); else visitInstruction(CB); @@ -4168,8 +4414,16 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { LLVM_DEBUG(dbgs() << "Arg " << i << " is not sized: " << CB << "\n"); continue; } + + if (A->getType()->isScalableTy()) { + LLVM_DEBUG(dbgs() << "Arg " << i << " is vscale: " << CB << "\n"); + // Handle as noundef, but don't reserve tls slots. 
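// Illustrative note: for an argument such as <vscale x 4 x i32> the shadow
// size is unknown at compile time, so rather than copying shadow into a
// fixed-size TLS slot the argument is checked eagerly here, as if it were
// noundef.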
+ insertShadowCheck(A, &CB); + continue; + } + unsigned Size = 0; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); bool ByVal = CB.paramHasAttr(i, Attribute::ByVal); bool NoUndef = CB.paramHasAttr(i, Attribute::NoUndef); @@ -4405,9 +4659,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (!InsPoint) InsPoint = &I; NextNodeIRBuilder IRB(InsPoint); - const DataLayout &DL = F.getParent()->getDataLayout(); - uint64_t TypeSize = DL.getTypeAllocSize(I.getAllocatedType()); - Value *Len = ConstantInt::get(MS.IntptrTy, TypeSize); + const DataLayout &DL = F.getDataLayout(); + TypeSize TS = DL.getTypeAllocSize(I.getAllocatedType()); + Value *Len = IRB.CreateTypeSize(MS.IntptrTy, TS); if (I.isArrayAllocation()) Len = IRB.CreateMul(Len, IRB.CreateZExtOrTrunc(I.getArraySize(), MS.IntptrTy)); @@ -4427,15 +4681,25 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } void visitSelectInst(SelectInst &I) { - IRBuilder<> IRB(&I); // a = select b, c, d Value *B = I.getCondition(); Value *C = I.getTrueValue(); Value *D = I.getFalseValue(); + + handleSelectLikeInst(I, B, C, D); + } + + void handleSelectLikeInst(Instruction &I, Value *B, Value *C, Value *D) { + IRBuilder<> IRB(&I); + Value *Sb = getShadow(B); Value *Sc = getShadow(C); Value *Sd = getShadow(D); + Value *Ob = MS.TrackOrigins ? getOrigin(B) : nullptr; + Value *Oc = MS.TrackOrigins ? getOrigin(C) : nullptr; + Value *Od = MS.TrackOrigins ? getOrigin(D) : nullptr; + // Result shadow if condition shadow is 0. Value *Sa0 = IRB.CreateSelect(B, Sc, Sd); Value *Sa1; @@ -4468,10 +4732,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } // a = select b, c, d // Oa = Sb ? Ob : (b ? Oc : Od) - setOrigin( - &I, IRB.CreateSelect(Sb, getOrigin(I.getCondition()), - IRB.CreateSelect(B, getOrigin(I.getTrueValue()), - getOrigin(I.getFalseValue())))); + setOrigin(&I, IRB.CreateSelect(Sb, Ob, IRB.CreateSelect(B, Oc, Od))); } } @@ -4559,16 +4820,22 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } if (!ElemTy->isSized()) return; - Value *SizeVal = - IRB.CreateTypeSize(MS.IntptrTy, DL.getTypeStoreSize(ElemTy)); + auto Size = DL.getTypeStoreSize(ElemTy); + Value *SizeVal = IRB.CreateTypeSize(MS.IntptrTy, Size); if (MS.CompileKernel) { IRB.CreateCall(MS.MsanInstrumentAsmStoreFn, {Operand, SizeVal}); } else { // ElemTy, derived from elementtype(), does not encode the alignment of // the pointer. Conservatively assume that the shadow memory is unaligned. + // When Size is large, avoid StoreInst as it would expand to many + // instructions. auto [ShadowPtr, _] = getShadowOriginPtrUserspace(Operand, IRB, IRB.getInt8Ty(), Align(1)); - IRB.CreateAlignedStore(getCleanShadow(ElemTy), ShadowPtr, Align(1)); + if (Size <= 32) + IRB.CreateAlignedStore(getCleanShadow(ElemTy), ShadowPtr, Align(1)); + else + IRB.CreateMemSet(ShadowPtr, ConstantInt::getNullValue(IRB.getInt8Ty()), + SizeVal, Align(1)); } } @@ -4614,7 +4881,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // The total number of asm() arguments in the source is nR+nO+nI, and the // corresponding CallInst has nO+nI+1 operands (the last operand is the // function to be called). 
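// Worked example (illustrative): for
//   asm("..." : "=r"(ret), "=m"(out) : "r"(in));
// nR = 1 (the "=r" result becomes the call's return value), nO = 1 and nI = 1,
// so the corresponding CallInst has nO + nI + 1 = 3 operands: the address of
// `out`, the value of `in`, and the inline-asm callee itself.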
- const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); CallBase *CB = cast<CallBase>(&I); IRBuilder<> IRB(&I); InlineAsm *IA = cast<InlineAsm>(CB->getCalledOperand()); @@ -4802,7 +5069,7 @@ struct VarArgAMD64Helper : public VarArgHelperBase { unsigned GpOffset = 0; unsigned FpOffset = AMD64GpEndOffset; unsigned OverflowOffset = AMD64FpEndOffset; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); for (const auto &[ArgNo, A] : llvm::enumerate(CB.args())) { bool IsFixed = ArgNo < CB.getFunctionType()->getNumParams(); @@ -4930,8 +5197,7 @@ struct VarArgAMD64Helper : public VarArgHelperBase { // Instrument va_start. // Copy va_list shadow from the backup copy of the TLS contents. - for (size_t i = 0, n = VAStartInstrumentationList.size(); i < n; i++) { - CallInst *OrigInst = VAStartInstrumentationList[i]; + for (CallInst *OrigInst : VAStartInstrumentationList) { NextNodeIRBuilder IRB(OrigInst); Value *VAListTag = OrigInst->getArgOperand(0); @@ -4989,7 +5255,7 @@ struct VarArgMIPS64Helper : public VarArgHelperBase { void visitCallBase(CallBase &CB, IRBuilder<> &IRB) override { unsigned VAArgOffset = 0; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); for (Value *A : llvm::drop_begin(CB.args(), CB.getFunctionType()->getNumParams())) { Triple TargetTriple(F.getParent()->getTargetTriple()); @@ -5040,8 +5306,7 @@ struct VarArgMIPS64Helper : public VarArgHelperBase { // Instrument va_start. // Copy va_list shadow from the backup copy of the TLS contents. - for (size_t i = 0, n = VAStartInstrumentationList.size(); i < n; i++) { - CallInst *OrigInst = VAStartInstrumentationList[i]; + for (CallInst *OrigInst : VAStartInstrumentationList) { NextNodeIRBuilder IRB(OrigInst); Value *VAListTag = OrigInst->getArgOperand(0); Type *RegSaveAreaPtrTy = PointerType::getUnqual(*MS.C); // i64* @@ -5120,7 +5385,7 @@ struct VarArgAArch64Helper : public VarArgHelperBase { unsigned VrOffset = AArch64VrBegOffset; unsigned OverflowOffset = AArch64VAEndOffset; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); for (const auto &[ArgNo, A] : llvm::enumerate(CB.args())) { bool IsFixed = ArgNo < CB.getFunctionType()->getNumParams(); auto [AK, RegNum] = classifyArgument(A->getType()); @@ -5215,8 +5480,7 @@ struct VarArgAArch64Helper : public VarArgHelperBase { // Instrument va_start, copy va_list shadow from the backup copy of // the TLS contents. - for (size_t i = 0, n = VAStartInstrumentationList.size(); i < n; i++) { - CallInst *OrigInst = VAStartInstrumentationList[i]; + for (CallInst *OrigInst : VAStartInstrumentationList) { NextNodeIRBuilder IRB(OrigInst); Value *VAListTag = OrigInst->getArgOperand(0); @@ -5331,7 +5595,7 @@ struct VarArgPowerPC64Helper : public VarArgHelperBase { else VAArgBase = 32; unsigned VAArgOffset = VAArgBase; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); for (const auto &[ArgNo, A] : llvm::enumerate(CB.args())) { bool IsFixed = ArgNo < CB.getFunctionType()->getNumParams(); bool IsByVal = CB.paramHasAttr(ArgNo, Attribute::ByVal); @@ -5426,8 +5690,7 @@ struct VarArgPowerPC64Helper : public VarArgHelperBase { // Instrument va_start. // Copy va_list shadow from the backup copy of the TLS contents. 
- for (size_t i = 0, n = VAStartInstrumentationList.size(); i < n; i++) { - CallInst *OrigInst = VAStartInstrumentationList[i]; + for (CallInst *OrigInst : VAStartInstrumentationList) { NextNodeIRBuilder IRB(OrigInst); Value *VAListTag = OrigInst->getArgOperand(0); Type *RegSaveAreaPtrTy = PointerType::getUnqual(*MS.C); // i64* @@ -5522,7 +5785,7 @@ struct VarArgSystemZHelper : public VarArgHelperBase { unsigned FpOffset = SystemZFpOffset; unsigned VrIndex = 0; unsigned OverflowOffset = SystemZOverflowOffset; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); for (const auto &[ArgNo, A] : llvm::enumerate(CB.args())) { bool IsFixed = ArgNo < CB.getFunctionType()->getNumParams(); // SystemZABIInfo does not produce ByVal parameters. @@ -5723,9 +5986,7 @@ struct VarArgSystemZHelper : public VarArgHelperBase { // Instrument va_start. // Copy va_list shadow from the backup copy of the TLS contents. - for (size_t VaStartNo = 0, VaStartNum = VAStartInstrumentationList.size(); - VaStartNo < VaStartNum; VaStartNo++) { - CallInst *OrigInst = VAStartInstrumentationList[VaStartNo]; + for (CallInst *OrigInst : VAStartInstrumentationList) { NextNodeIRBuilder IRB(OrigInst); Value *VAListTag = OrigInst->getArgOperand(0); copyRegSaveArea(IRB, VAListTag); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/NumericalStabilitySanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/NumericalStabilitySanitizer.cpp new file mode 100644 index 000000000000..99b1c779f316 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/NumericalStabilitySanitizer.cpp @@ -0,0 +1,2168 @@ +//===-- NumericalStabilitySanitizer.cpp -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the instrumentation pass for the numerical sanitizer. +// Conceptually the pass injects shadow computations using higher precision +// types and inserts consistency checks. For details see the paper +// https://arxiv.org/abs/2102.12782. 
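// A minimal sketch of the idea (illustrative, not code from this file): for
// application code such as
//   float y = a / b;
//   return y;
// the pass conceptually maintains a higher-precision shadow computation
//   double y_s = a_s / b_s;    // under the default "dqq" mapping, float -> double
// and, where y is observed (stores, returns, some comparisons), emits a call
// into the nsan runtime that compares y against (float)y_s and reports when
// the two have diverged.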
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Instrumentation/NumericalStabilitySanitizer.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/InitializePasses.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/Regex.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/EscapeEnumerator.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +#include <cstdint> + +using namespace llvm; + +#define DEBUG_TYPE "nsan" + +STATISTIC(NumInstrumentedFTLoads, + "Number of instrumented floating-point loads"); + +STATISTIC(NumInstrumentedFTCalls, + "Number of instrumented floating-point calls"); +STATISTIC(NumInstrumentedFTRets, + "Number of instrumented floating-point returns"); +STATISTIC(NumInstrumentedFTStores, + "Number of instrumented floating-point stores"); +STATISTIC(NumInstrumentedNonFTStores, + "Number of instrumented non floating-point stores"); +STATISTIC( + NumInstrumentedNonFTMemcpyStores, + "Number of instrumented non floating-point stores with memcpy semantics"); +STATISTIC(NumInstrumentedFCmp, "Number of instrumented fcmps"); + +// Using smaller shadow types types can help improve speed. For example, `dlq` +// is 3x slower to 5x faster in opt mode and 2-6x faster in dbg mode compared to +// `dqq`. +static cl::opt<std::string> ClShadowMapping( + "nsan-shadow-type-mapping", cl::init("dqq"), + cl::desc("One shadow type id for each of `float`, `double`, `long double`. " + "`d`,`l`,`q`,`e` mean double, x86_fp80, fp128 (quad) and " + "ppc_fp128 (extended double) respectively. The default is to " + "shadow `float` as `double`, and `double` and `x86_fp80` as " + "`fp128`"), + cl::Hidden); + +static cl::opt<bool> + ClInstrumentFCmp("nsan-instrument-fcmp", cl::init(true), + cl::desc("Instrument floating-point comparisons"), + cl::Hidden); + +static cl::opt<std::string> ClCheckFunctionsFilter( + "check-functions-filter", + cl::desc("Only emit checks for arguments of functions " + "whose names match the given regular expression"), + cl::value_desc("regex")); + +static cl::opt<bool> ClTruncateFCmpEq( + "nsan-truncate-fcmp-eq", cl::init(true), + cl::desc( + "This flag controls the behaviour of fcmp equality comparisons." + "For equality comparisons such as `x == 0.0f`, we can perform the " + "shadow check in the shadow (`x_shadow == 0.0) == (x == 0.0f)`) or app " + " domain (`(trunc(x_shadow) == 0.0f) == (x == 0.0f)`). 
This helps " + "catch the case when `x_shadow` is accurate enough (and therefore " + "close enough to zero) so that `trunc(x_shadow)` is zero even though " + "both `x` and `x_shadow` are not"), + cl::Hidden); + +// When there is external, uninstrumented code writing to memory, the shadow +// memory can get out of sync with the application memory. Enabling this flag +// emits consistency checks for loads to catch this situation. +// When everything is instrumented, this is not strictly necessary because any +// load should have a corresponding store, but can help debug cases when the +// framework did a bad job at tracking shadow memory modifications by failing on +// load rather than store. +// TODO: provide a way to resume computations from the FT value when the load +// is inconsistent. This ensures that further computations are not polluted. +static cl::opt<bool> ClCheckLoads("nsan-check-loads", + cl::desc("Check floating-point load"), + cl::Hidden); + +static cl::opt<bool> ClCheckStores("nsan-check-stores", cl::init(true), + cl::desc("Check floating-point stores"), + cl::Hidden); + +static cl::opt<bool> ClCheckRet("nsan-check-ret", cl::init(true), + cl::desc("Check floating-point return values"), + cl::Hidden); + +// LLVM may store constant floats as bitcasted ints. +// It's not really necessary to shadow such stores, +// if the shadow value is unknown the framework will re-extend it on load +// anyway. Moreover, because of size collisions (e.g. bf16 vs f16) it is +// impossible to determine the floating-point type based on the size. +// However, for debugging purposes it can be useful to model such stores. +static cl::opt<bool> ClPropagateNonFTConstStoresAsFT( + "nsan-propagate-non-ft-const-stores-as-ft", + cl::desc( + "Propagate non floating-point const stores as floating point values." + "For debugging purposes only"), + cl::Hidden); + +constexpr StringLiteral kNsanModuleCtorName("nsan.module_ctor"); +constexpr StringLiteral kNsanInitName("__nsan_init"); + +// The following values must be kept in sync with the runtime. +constexpr int kShadowScale = 2; +constexpr int kMaxVectorWidth = 8; +constexpr int kMaxNumArgs = 128; +constexpr int kMaxShadowTypeSizeBytes = 16; // fp128 + +namespace { + +// Defines the characteristics (type id, type, and floating-point semantics) +// attached for all possible shadow types. +class ShadowTypeConfig { +public: + static std::unique_ptr<ShadowTypeConfig> fromNsanTypeId(char TypeId); + + // The LLVM Type corresponding to the shadow type. + virtual Type *getType(LLVMContext &Context) const = 0; + + // The nsan type id of the shadow type (`d`, `l`, `q`, ...). + virtual char getNsanTypeId() const = 0; + + virtual ~ShadowTypeConfig() = default; +}; + +template <char NsanTypeId> +class ShadowTypeConfigImpl : public ShadowTypeConfig { +public: + char getNsanTypeId() const override { return NsanTypeId; } + static constexpr const char kNsanTypeId = NsanTypeId; +}; + +// `double` (`d`) shadow type. +class F64ShadowConfig : public ShadowTypeConfigImpl<'d'> { + Type *getType(LLVMContext &Context) const override { + return Type::getDoubleTy(Context); + } +}; + +// `x86_fp80` (`l`) shadow type: X86 long double. +class F80ShadowConfig : public ShadowTypeConfigImpl<'l'> { + Type *getType(LLVMContext &Context) const override { + return Type::getX86_FP80Ty(Context); + } +}; + +// `fp128` (`q`) shadow type. 
+class F128ShadowConfig : public ShadowTypeConfigImpl<'q'> { + Type *getType(LLVMContext &Context) const override { + return Type::getFP128Ty(Context); + } +}; + +// `ppc_fp128` (`e`) shadow type: IBM extended double with 106 bits of mantissa. +class PPC128ShadowConfig : public ShadowTypeConfigImpl<'e'> { + Type *getType(LLVMContext &Context) const override { + return Type::getPPC_FP128Ty(Context); + } +}; + +// Creates a ShadowTypeConfig given its type id. +std::unique_ptr<ShadowTypeConfig> +ShadowTypeConfig::fromNsanTypeId(const char TypeId) { + switch (TypeId) { + case F64ShadowConfig::kNsanTypeId: + return std::make_unique<F64ShadowConfig>(); + case F80ShadowConfig::kNsanTypeId: + return std::make_unique<F80ShadowConfig>(); + case F128ShadowConfig::kNsanTypeId: + return std::make_unique<F128ShadowConfig>(); + case PPC128ShadowConfig::kNsanTypeId: + return std::make_unique<PPC128ShadowConfig>(); + } + report_fatal_error("nsan: invalid shadow type id '" + Twine(TypeId) + "'"); +} + +// An enum corresponding to shadow value types. Used as indices in arrays, so +// not an `enum class`. +enum FTValueType { kFloat, kDouble, kLongDouble, kNumValueTypes }; + +// If `FT` corresponds to a primitive FTValueType, return it. +static std::optional<FTValueType> ftValueTypeFromType(Type *FT) { + if (FT->isFloatTy()) + return kFloat; + if (FT->isDoubleTy()) + return kDouble; + if (FT->isX86_FP80Ty()) + return kLongDouble; + return {}; +} + +// Returns the LLVM type for an FTValueType. +static Type *typeFromFTValueType(FTValueType VT, LLVMContext &Context) { + switch (VT) { + case kFloat: + return Type::getFloatTy(Context); + case kDouble: + return Type::getDoubleTy(Context); + case kLongDouble: + return Type::getX86_FP80Ty(Context); + case kNumValueTypes: + return nullptr; + } + llvm_unreachable("Unhandled FTValueType enum"); +} + +// Returns the type name for an FTValueType. +static const char *typeNameFromFTValueType(FTValueType VT) { + switch (VT) { + case kFloat: + return "float"; + case kDouble: + return "double"; + case kLongDouble: + return "longdouble"; + case kNumValueTypes: + return nullptr; + } + llvm_unreachable("Unhandled FTValueType enum"); +} + +// A specific mapping configuration of application type to shadow type for nsan +// (see -nsan-shadow-mapping flag). +class MappingConfig { +public: + explicit MappingConfig(LLVMContext &C) : Context(C) { + if (ClShadowMapping.size() != 3) + report_fatal_error("Invalid nsan mapping: " + Twine(ClShadowMapping)); + unsigned ShadowTypeSizeBits[kNumValueTypes]; + for (int VT = 0; VT < kNumValueTypes; ++VT) { + auto Config = ShadowTypeConfig::fromNsanTypeId(ClShadowMapping[VT]); + if (!Config) + report_fatal_error("Failed to get ShadowTypeConfig for " + + Twine(ClShadowMapping[VT])); + const unsigned AppTypeSize = + typeFromFTValueType(static_cast<FTValueType>(VT), Context) + ->getScalarSizeInBits(); + const unsigned ShadowTypeSize = + Config->getType(Context)->getScalarSizeInBits(); + // Check that the shadow type size is at most kShadowScale times the + // application type size, so that shadow memory compoutations are valid. + if (ShadowTypeSize > kShadowScale * AppTypeSize) + report_fatal_error("Invalid nsan mapping f" + Twine(AppTypeSize) + + "->f" + Twine(ShadowTypeSize) + + ": The shadow type size should be at most " + + Twine(kShadowScale) + + " times the application type size"); + ShadowTypeSizeBits[VT] = ShadowTypeSize; + Configs[VT] = std::move(Config); + } + + // Check that the mapping is monotonous. 
This is required because if one + // does an fpextend of `float->long double` in application code, nsan is + // going to do an fpextend of `shadow(float) -> shadow(long double)` in + // shadow code. This will fail in `qql` mode, since nsan would be + // fpextending `f128->long`, which is invalid. + // TODO: Relax this. + if (ShadowTypeSizeBits[kFloat] > ShadowTypeSizeBits[kDouble] || + ShadowTypeSizeBits[kDouble] > ShadowTypeSizeBits[kLongDouble]) + report_fatal_error("Invalid nsan mapping: { float->f" + + Twine(ShadowTypeSizeBits[kFloat]) + "; double->f" + + Twine(ShadowTypeSizeBits[kDouble]) + + "; long double->f" + + Twine(ShadowTypeSizeBits[kLongDouble]) + " }"); + } + + const ShadowTypeConfig &byValueType(FTValueType VT) const { + assert(VT < FTValueType::kNumValueTypes && "invalid value type"); + return *Configs[VT]; + } + + // Returns the extended shadow type for a given application type. + Type *getExtendedFPType(Type *FT) const { + if (const auto VT = ftValueTypeFromType(FT)) + return Configs[*VT]->getType(Context); + if (FT->isVectorTy()) { + auto *VecTy = cast<VectorType>(FT); + // TODO: add support for scalable vector types. + if (VecTy->isScalableTy()) + return nullptr; + Type *ExtendedScalar = getExtendedFPType(VecTy->getElementType()); + return ExtendedScalar + ? VectorType::get(ExtendedScalar, VecTy->getElementCount()) + : nullptr; + } + return nullptr; + } + +private: + LLVMContext &Context; + std::unique_ptr<ShadowTypeConfig> Configs[FTValueType::kNumValueTypes]; +}; + +// The memory extents of a type specifies how many elements of a given +// FTValueType needs to be stored when storing this type. +struct MemoryExtents { + FTValueType ValueType; + uint64_t NumElts; +}; + +static MemoryExtents getMemoryExtentsOrDie(Type *FT) { + if (const auto VT = ftValueTypeFromType(FT)) + return {*VT, 1}; + if (auto *VecTy = dyn_cast<VectorType>(FT)) { + const auto ScalarExtents = getMemoryExtentsOrDie(VecTy->getElementType()); + return {ScalarExtents.ValueType, + ScalarExtents.NumElts * VecTy->getElementCount().getFixedValue()}; + } + llvm_unreachable("invalid value type"); +} + +// The location of a check. Passed as parameters to runtime checking functions. +class CheckLoc { +public: + // Creates a location that references an application memory location. + static CheckLoc makeStore(Value *Address) { + CheckLoc Result(kStore); + Result.Address = Address; + return Result; + } + static CheckLoc makeLoad(Value *Address) { + CheckLoc Result(kLoad); + Result.Address = Address; + return Result; + } + + // Creates a location that references an argument, given by id. + static CheckLoc makeArg(int ArgId) { + CheckLoc Result(kArg); + Result.ArgId = ArgId; + return Result; + } + + // Creates a location that references the return value of a function. + static CheckLoc makeRet() { return CheckLoc(kRet); } + + // Creates a location that references a vector insert. + static CheckLoc makeInsert() { return CheckLoc(kInsert); } + + // Returns the CheckType of location this refers to, as an integer-typed LLVM + // IR value. + Value *getType(LLVMContext &C) const { + return ConstantInt::get(Type::getInt32Ty(C), static_cast<int>(CheckTy)); + } + + // Returns a CheckType-specific value representing details of the location + // (e.g. application address for loads or stores), as an `IntptrTy`-typed LLVM + // IR value. 
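// For example (illustrative): a check emitted for a store through %p passes
// { kStore, ptrtoint(%p) } to the runtime, while a check on function argument
// ArgId passes { kArg, ArgId }; the runtime uses the pair to tell what kind of
// location triggered the check.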
+ Value *getValue(Type *IntptrTy, IRBuilder<> &Builder) const { + switch (CheckTy) { + case kUnknown: + llvm_unreachable("unknown type"); + case kRet: + case kInsert: + return ConstantInt::get(IntptrTy, 0); + case kArg: + return ConstantInt::get(IntptrTy, ArgId); + case kLoad: + case kStore: + return Builder.CreatePtrToInt(Address, IntptrTy); + } + llvm_unreachable("Unhandled CheckType enum"); + } + +private: + // Must be kept in sync with the runtime, + // see compiler-rt/lib/nsan/nsan_stats.h + enum CheckType { + kUnknown = 0, + kRet, + kArg, + kLoad, + kStore, + kInsert, + }; + explicit CheckLoc(CheckType CheckTy) : CheckTy(CheckTy) {} + + Value *Address = nullptr; + const CheckType CheckTy; + int ArgId = -1; +}; + +// A map of LLVM IR values to shadow LLVM IR values. +class ValueToShadowMap { +public: + explicit ValueToShadowMap(const MappingConfig &Config) : Config(Config) {} + + ValueToShadowMap(const ValueToShadowMap &) = delete; + ValueToShadowMap &operator=(const ValueToShadowMap &) = delete; + + // Sets the shadow value for a value. Asserts that the value does not already + // have a value. + void setShadow(Value &V, Value &Shadow) { + [[maybe_unused]] const bool Inserted = Map.try_emplace(&V, &Shadow).second; + LLVM_DEBUG({ + if (!Inserted) { + if (auto *I = dyn_cast<Instruction>(&V)) + errs() << I->getFunction()->getName() << ": "; + errs() << "duplicate shadow (" << &V << "): "; + V.dump(); + } + }); + assert(Inserted && "duplicate shadow"); + } + + // Returns true if the value already has a shadow (including if the value is a + // constant). If true, calling getShadow() is valid. + bool hasShadow(Value *V) const { + return isa<Constant>(V) || (Map.find(V) != Map.end()); + } + + // Returns the shadow value for a given value. Asserts that the value has + // a shadow value. Lazily creates shadows for constant values. + Value *getShadow(Value *V) const { + if (Constant *C = dyn_cast<Constant>(V)) + return getShadowConstant(C); + return Map.find(V)->second; + } + + bool empty() const { return Map.empty(); } + +private: + // Extends a constant application value to its shadow counterpart. + APFloat extendConstantFP(APFloat CV, const fltSemantics &To) const { + bool LosesInfo = false; + CV.convert(To, APFloatBase::rmTowardZero, &LosesInfo); + return CV; + } + + // Returns the shadow constant for the given application constant. + Constant *getShadowConstant(Constant *C) const { + if (UndefValue *U = dyn_cast<UndefValue>(C)) { + return UndefValue::get(Config.getExtendedFPType(U->getType())); + } + if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) { + // Floating-point constants. + Type *Ty = Config.getExtendedFPType(CFP->getType()); + return ConstantFP::get( + Ty, extendConstantFP(CFP->getValueAPF(), Ty->getFltSemantics())); + } + // Vector, array, or aggregate constants. + if (C->getType()->isVectorTy()) { + SmallVector<Constant *, 8> Elements; + for (int I = 0, E = cast<VectorType>(C->getType()) + ->getElementCount() + .getFixedValue(); + I < E; ++I) + Elements.push_back(getShadowConstant(C->getAggregateElement(I))); + return ConstantVector::get(Elements); + } + llvm_unreachable("unimplemented"); + } + + const MappingConfig &Config; + DenseMap<Value *, Value *> Map; +}; + +/// Instantiating NumericalStabilitySanitizer inserts the nsan runtime library +/// API function declarations into the module if they don't exist already. +/// Instantiating ensures the __nsan_init function is in the list of global +/// constructors for the module. 
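// Illustrative sketch (assumed IR shape, not from this patch) of the
// constructor registration mentioned above. The ctor symbol is spelled out
// here only for illustration; the pass uses kNsanModuleCtorName and registers
// it with priority 0 via appendToGlobalCtors:
//
//   define internal void @nsan.module_ctor() {
//     call void @__nsan_init()
//     ret void
//   }
//   @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }]
//                        [{ i32 0, ptr @nsan.module_ctor, ptr null }]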
+class NumericalStabilitySanitizer { +public: + NumericalStabilitySanitizer(Module &M); + bool sanitizeFunction(Function &F, const TargetLibraryInfo &TLI); + +private: + bool instrumentMemIntrinsic(MemIntrinsic *MI); + void maybeAddSuffixForNsanInterface(CallBase *CI); + bool addrPointsToConstantData(Value *Addr); + void maybeCreateShadowValue(Instruction &Root, const TargetLibraryInfo &TLI, + ValueToShadowMap &Map); + Value *createShadowValueWithOperandsAvailable(Instruction &Inst, + const TargetLibraryInfo &TLI, + const ValueToShadowMap &Map); + PHINode *maybeCreateShadowPhi(PHINode &Phi, const TargetLibraryInfo &TLI); + void createShadowArguments(Function &F, const TargetLibraryInfo &TLI, + ValueToShadowMap &Map); + + void populateShadowStack(CallBase &CI, const TargetLibraryInfo &TLI, + const ValueToShadowMap &Map); + + void propagateShadowValues(Instruction &Inst, const TargetLibraryInfo &TLI, + const ValueToShadowMap &Map); + Value *emitCheck(Value *V, Value *ShadowV, IRBuilder<> &Builder, + CheckLoc Loc); + Value *emitCheckInternal(Value *V, Value *ShadowV, IRBuilder<> &Builder, + CheckLoc Loc); + void emitFCmpCheck(FCmpInst &FCmp, const ValueToShadowMap &Map); + + // Value creation handlers. + Value *handleLoad(LoadInst &Load, Type *VT, Type *ExtendedVT); + Value *handleCallBase(CallBase &Call, Type *VT, Type *ExtendedVT, + const TargetLibraryInfo &TLI, + const ValueToShadowMap &Map, IRBuilder<> &Builder); + Value *maybeHandleKnownCallBase(CallBase &Call, Type *VT, Type *ExtendedVT, + const TargetLibraryInfo &TLI, + const ValueToShadowMap &Map, + IRBuilder<> &Builder); + Value *handleTrunc(const FPTruncInst &Trunc, Type *VT, Type *ExtendedVT, + const ValueToShadowMap &Map, IRBuilder<> &Builder); + Value *handleExt(const FPExtInst &Ext, Type *VT, Type *ExtendedVT, + const ValueToShadowMap &Map, IRBuilder<> &Builder); + + // Value propagation handlers. + void propagateFTStore(StoreInst &Store, Type *VT, Type *ExtendedVT, + const ValueToShadowMap &Map); + void propagateNonFTStore(StoreInst &Store, Type *VT, + const ValueToShadowMap &Map); + + const DataLayout &DL; + LLVMContext &Context; + MappingConfig Config; + IntegerType *IntptrTy = nullptr; + FunctionCallee NsanGetShadowPtrForStore[FTValueType::kNumValueTypes] = {}; + FunctionCallee NsanGetShadowPtrForLoad[FTValueType::kNumValueTypes] = {}; + FunctionCallee NsanCheckValue[FTValueType::kNumValueTypes] = {}; + FunctionCallee NsanFCmpFail[FTValueType::kNumValueTypes] = {}; + FunctionCallee NsanCopyValues; + FunctionCallee NsanSetValueUnknown; + FunctionCallee NsanGetRawShadowTypePtr; + FunctionCallee NsanGetRawShadowPtr; + GlobalValue *NsanShadowRetTag = nullptr; + + Type *NsanShadowRetType = nullptr; + GlobalValue *NsanShadowRetPtr = nullptr; + + GlobalValue *NsanShadowArgsTag = nullptr; + + Type *NsanShadowArgsType = nullptr; + GlobalValue *NsanShadowArgsPtr = nullptr; + + std::optional<Regex> CheckFunctionsFilter; +}; +} // end anonymous namespace + +PreservedAnalyses +NumericalStabilitySanitizerPass::run(Module &M, ModuleAnalysisManager &MAM) { + getOrCreateSanitizerCtorAndInitFunctions( + M, kNsanModuleCtorName, kNsanInitName, /*InitArgTypes=*/{}, + /*InitArgs=*/{}, + // This callback is invoked when the functions are created the first + // time. 
Hook them into the global ctors list in that case: + [&](Function *Ctor, FunctionCallee) { appendToGlobalCtors(M, Ctor, 0); }); + + NumericalStabilitySanitizer Nsan(M); + auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + for (Function &F : M) + Nsan.sanitizeFunction(F, FAM.getResult<TargetLibraryAnalysis>(F)); + + return PreservedAnalyses::none(); +} + +static GlobalValue *createThreadLocalGV(const char *Name, Module &M, Type *Ty) { + return dyn_cast<GlobalValue>(M.getOrInsertGlobal(Name, Ty, [&M, Ty, Name] { + return new GlobalVariable(M, Ty, false, GlobalVariable::ExternalLinkage, + nullptr, Name, nullptr, + GlobalVariable::InitialExecTLSModel); + })); +} + +NumericalStabilitySanitizer::NumericalStabilitySanitizer(Module &M) + : DL(M.getDataLayout()), Context(M.getContext()), Config(Context) { + IntptrTy = DL.getIntPtrType(Context); + Type *PtrTy = PointerType::getUnqual(Context); + Type *Int32Ty = Type::getInt32Ty(Context); + Type *Int1Ty = Type::getInt1Ty(Context); + Type *VoidTy = Type::getVoidTy(Context); + + AttributeList Attr; + Attr = Attr.addFnAttribute(Context, Attribute::NoUnwind); + // Initialize the runtime values (functions and global variables). + for (int I = 0; I < kNumValueTypes; ++I) { + const FTValueType VT = static_cast<FTValueType>(I); + const char *VTName = typeNameFromFTValueType(VT); + Type *VTTy = typeFromFTValueType(VT, Context); + + // Load/store. + const std::string GetterPrefix = + std::string("__nsan_get_shadow_ptr_for_") + VTName; + NsanGetShadowPtrForStore[VT] = M.getOrInsertFunction( + GetterPrefix + "_store", Attr, PtrTy, PtrTy, IntptrTy); + NsanGetShadowPtrForLoad[VT] = M.getOrInsertFunction( + GetterPrefix + "_load", Attr, PtrTy, PtrTy, IntptrTy); + + // Check. + const auto &ShadowConfig = Config.byValueType(VT); + Type *ShadowTy = ShadowConfig.getType(Context); + NsanCheckValue[VT] = + M.getOrInsertFunction(std::string("__nsan_internal_check_") + VTName + + "_" + ShadowConfig.getNsanTypeId(), + Attr, Int32Ty, VTTy, ShadowTy, Int32Ty, IntptrTy); + NsanFCmpFail[VT] = M.getOrInsertFunction( + std::string("__nsan_fcmp_fail_") + VTName + "_" + + ShadowConfig.getNsanTypeId(), + Attr, VoidTy, VTTy, VTTy, ShadowTy, ShadowTy, Int32Ty, Int1Ty, Int1Ty); + } + + NsanCopyValues = M.getOrInsertFunction("__nsan_copy_values", Attr, VoidTy, + PtrTy, PtrTy, IntptrTy); + NsanSetValueUnknown = M.getOrInsertFunction("__nsan_set_value_unknown", Attr, + VoidTy, PtrTy, IntptrTy); + + // TODO: Add attributes nofree, nosync, readnone, readonly, + NsanGetRawShadowTypePtr = M.getOrInsertFunction( + "__nsan_internal_get_raw_shadow_type_ptr", Attr, PtrTy, PtrTy); + NsanGetRawShadowPtr = M.getOrInsertFunction( + "__nsan_internal_get_raw_shadow_ptr", Attr, PtrTy, PtrTy); + + NsanShadowRetTag = createThreadLocalGV("__nsan_shadow_ret_tag", M, IntptrTy); + + NsanShadowRetType = ArrayType::get(Type::getInt8Ty(Context), + kMaxVectorWidth * kMaxShadowTypeSizeBytes); + NsanShadowRetPtr = + createThreadLocalGV("__nsan_shadow_ret_ptr", M, NsanShadowRetType); + + NsanShadowArgsTag = + createThreadLocalGV("__nsan_shadow_args_tag", M, IntptrTy); + + NsanShadowArgsType = + ArrayType::get(Type::getInt8Ty(Context), + kMaxVectorWidth * kMaxNumArgs * kMaxShadowTypeSizeBytes); + + NsanShadowArgsPtr = + createThreadLocalGV("__nsan_shadow_args_ptr", M, NsanShadowArgsType); + + if (!ClCheckFunctionsFilter.empty()) { + Regex R = Regex(ClCheckFunctionsFilter); + std::string RegexError; + assert(R.isValid(RegexError)); + CheckFunctionsFilter = std::move(R); + } +} + +// 
Returns true if the given LLVM Value points to constant data (typically, a +// global variable reference). +bool NumericalStabilitySanitizer::addrPointsToConstantData(Value *Addr) { + // If this is a GEP, just analyze its pointer operand. + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Addr)) + Addr = GEP->getPointerOperand(); + + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) + return GV->isConstant(); + return false; +} + +// This instruments the function entry to create shadow arguments. +// Pseudocode: +// if (this_fn_ptr == __nsan_shadow_args_tag) { +// s(arg0) = LOAD<sizeof(arg0)>(__nsan_shadow_args); +// s(arg1) = LOAD<sizeof(arg1)>(__nsan_shadow_args + sizeof(arg0)); +// ... +// __nsan_shadow_args_tag = 0; +// } else { +// s(arg0) = fext(arg0); +// s(arg1) = fext(arg1); +// ... +// } +void NumericalStabilitySanitizer::createShadowArguments( + Function &F, const TargetLibraryInfo &TLI, ValueToShadowMap &Map) { + assert(!F.getIntrinsicID() && "found a definition of an intrinsic"); + + // Do not bother if there are no FP args. + if (all_of(F.args(), [this](const Argument &Arg) { + return Config.getExtendedFPType(Arg.getType()) == nullptr; + })) + return; + + IRBuilder<> Builder(F.getEntryBlock().getFirstNonPHI()); + // The function has shadow args if the shadow args tag matches the function + // address. + Value *HasShadowArgs = Builder.CreateICmpEQ( + Builder.CreateLoad(IntptrTy, NsanShadowArgsTag, /*isVolatile=*/false), + Builder.CreatePtrToInt(&F, IntptrTy)); + + unsigned ShadowArgsOffsetBytes = 0; + for (Argument &Arg : F.args()) { + Type *VT = Arg.getType(); + Type *ExtendedVT = Config.getExtendedFPType(VT); + if (ExtendedVT == nullptr) + continue; // Not an FT value. + Value *L = Builder.CreateAlignedLoad( + ExtendedVT, + Builder.CreateConstGEP2_64(NsanShadowArgsType, NsanShadowArgsPtr, 0, + ShadowArgsOffsetBytes), + Align(1), /*isVolatile=*/false); + Value *Shadow = Builder.CreateSelect(HasShadowArgs, L, + Builder.CreateFPExt(&Arg, ExtendedVT)); + Map.setShadow(Arg, *Shadow); + TypeSize SlotSize = DL.getTypeStoreSize(ExtendedVT); + assert(!SlotSize.isScalable() && "unsupported"); + ShadowArgsOffsetBytes += SlotSize; + } + Builder.CreateStore(ConstantInt::get(IntptrTy, 0), NsanShadowArgsTag); +} + +// Returns true if the instrumentation should emit code to check arguments +// before a function call. +static bool shouldCheckArgs(CallBase &CI, const TargetLibraryInfo &TLI, + const std::optional<Regex> &CheckFunctionsFilter) { + + Function *Fn = CI.getCalledFunction(); + + if (CheckFunctionsFilter) { + // Skip checking args of indirect calls. + if (Fn == nullptr) + return false; + if (CheckFunctionsFilter->match(Fn->getName())) + return true; + return false; + } + + if (Fn == nullptr) + return true; // Always check args of indirect calls. + + // Never check nsan functions, the user called them for a reason. + if (Fn->getName().starts_with("__nsan_")) + return false; + + const auto ID = Fn->getIntrinsicID(); + LibFunc LFunc = LibFunc::NumLibFuncs; + // Always check args of unknown functions. + if (ID == Intrinsic::ID() && !TLI.getLibFunc(*Fn, LFunc)) + return true; + + // Do not check args of an `fabs` call that is used for a comparison. + // This is typically used for `fabs(a-b) < tolerance`, where what matters is + // the result of the comparison, which is already caught be the fcmp checks. 
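  // Illustrative source-level example (not from this patch) of the pattern the
  // exception below is aimed at: only the comparison result matters, and that
  // is already covered by emitFCmpCheck, so checking the fabs argument would
  // only add noise.
  //
  //   bool nearlyEqual(double A, double B, double Tolerance) {
  //     return fabs(A - B) < Tolerance; // fabs arg intentionally not checked
  //   }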
+ if (ID == Intrinsic::fabs || LFunc == LibFunc_fabsf || + LFunc == LibFunc_fabs || LFunc == LibFunc_fabsl) + for (const auto &U : CI.users()) + if (isa<CmpInst>(U)) + return false; + + return true; // Default is check. +} + +// Populates the shadow call stack (which contains shadow values for every +// floating-point parameter to the function). +void NumericalStabilitySanitizer::populateShadowStack( + CallBase &CI, const TargetLibraryInfo &TLI, const ValueToShadowMap &Map) { + // Do not create a shadow stack for inline asm. + if (CI.isInlineAsm()) + return; + + // Do not bother if there are no FP args. + if (all_of(CI.operands(), [this](const Value *Arg) { + return Config.getExtendedFPType(Arg->getType()) == nullptr; + })) + return; + + IRBuilder<> Builder(&CI); + SmallVector<Value *, 8> ArgShadows; + const bool ShouldCheckArgs = shouldCheckArgs(CI, TLI, CheckFunctionsFilter); + for (auto [ArgIdx, Arg] : enumerate(CI.operands())) { + if (Config.getExtendedFPType(Arg->getType()) == nullptr) + continue; // Not an FT value. + Value *ArgShadow = Map.getShadow(Arg); + ArgShadows.push_back(ShouldCheckArgs ? emitCheck(Arg, ArgShadow, Builder, + CheckLoc::makeArg(ArgIdx)) + : ArgShadow); + } + + // Do not create shadow stacks for intrinsics/known lib funcs. + if (Function *Fn = CI.getCalledFunction()) { + LibFunc LFunc; + if (Fn->isIntrinsic() || TLI.getLibFunc(*Fn, LFunc)) + return; + } + + // Set the shadow stack tag. + Builder.CreateStore(CI.getCalledOperand(), NsanShadowArgsTag); + TypeSize ShadowArgsOffsetBytes = TypeSize::getFixed(0); + + unsigned ShadowArgId = 0; + for (const Value *Arg : CI.operands()) { + Type *VT = Arg->getType(); + Type *ExtendedVT = Config.getExtendedFPType(VT); + if (ExtendedVT == nullptr) + continue; // Not an FT value. + Builder.CreateAlignedStore( + ArgShadows[ShadowArgId++], + Builder.CreateConstGEP2_64(NsanShadowArgsType, NsanShadowArgsPtr, 0, + ShadowArgsOffsetBytes), + Align(1), /*isVolatile=*/false); + TypeSize SlotSize = DL.getTypeStoreSize(ExtendedVT); + assert(!SlotSize.isScalable() && "unsupported"); + ShadowArgsOffsetBytes += SlotSize; + } +} + +// Internal part of emitCheck(). Returns a value that indicates whether +// computation should continue with the shadow or resume by re-fextending the +// value. +enum class ContinuationType { // Keep in sync with runtime. + ContinueWithShadow = 0, + ResumeFromValue = 1, +}; + +Value *NumericalStabilitySanitizer::emitCheckInternal(Value *V, Value *ShadowV, + IRBuilder<> &Builder, + CheckLoc Loc) { + // Do not emit checks for constant values, this is redundant. + if (isa<Constant>(V)) + return ConstantInt::get( + Builder.getInt32Ty(), + static_cast<int>(ContinuationType::ContinueWithShadow)); + + Type *Ty = V->getType(); + if (const auto VT = ftValueTypeFromType(Ty)) + return Builder.CreateCall( + NsanCheckValue[*VT], + {V, ShadowV, Loc.getType(Context), Loc.getValue(IntptrTy, Builder)}); + + if (Ty->isVectorTy()) { + auto *VecTy = cast<VectorType>(Ty); + // We currently skip scalable vector types in MappingConfig, + // thus we should not encounter any such types here. + assert(!VecTy->isScalableTy() && + "Scalable vector types are not supported yet"); + Value *CheckResult = nullptr; + for (int I = 0, E = VecTy->getElementCount().getFixedValue(); I < E; ++I) { + // We resume if any element resumes. Another option would be to create a + // vector shuffle with the array of ContinueWithShadow, but that is too + // complex. 
+ Value *ExtractV = Builder.CreateExtractElement(V, I); + Value *ExtractShadowV = Builder.CreateExtractElement(ShadowV, I); + Value *ComponentCheckResult = + emitCheckInternal(ExtractV, ExtractShadowV, Builder, Loc); + CheckResult = CheckResult + ? Builder.CreateOr(CheckResult, ComponentCheckResult) + : ComponentCheckResult; + } + return CheckResult; + } + if (Ty->isArrayTy()) { + Value *CheckResult = nullptr; + for (auto I : seq(Ty->getArrayNumElements())) { + Value *ExtractV = Builder.CreateExtractElement(V, I); + Value *ExtractShadowV = Builder.CreateExtractElement(ShadowV, I); + Value *ComponentCheckResult = + emitCheckInternal(ExtractV, ExtractShadowV, Builder, Loc); + CheckResult = CheckResult + ? Builder.CreateOr(CheckResult, ComponentCheckResult) + : ComponentCheckResult; + } + return CheckResult; + } + if (Ty->isStructTy()) { + Value *CheckResult = nullptr; + for (auto I : seq(Ty->getStructNumElements())) { + if (Config.getExtendedFPType(Ty->getStructElementType(I)) == nullptr) + continue; // Only check FT values. + Value *ExtractV = Builder.CreateExtractValue(V, I); + Value *ExtractShadowV = Builder.CreateExtractElement(ShadowV, I); + Value *ComponentCheckResult = + emitCheckInternal(ExtractV, ExtractShadowV, Builder, Loc); + CheckResult = CheckResult + ? Builder.CreateOr(CheckResult, ComponentCheckResult) + : ComponentCheckResult; + } + if (!CheckResult) + return ConstantInt::get( + Builder.getInt32Ty(), + static_cast<int>(ContinuationType::ContinueWithShadow)); + return CheckResult; + } + + llvm_unreachable("not implemented"); +} + +// Inserts a runtime check of V against its shadow value ShadowV. +// We check values whenever they escape: on return, call, stores, and +// insertvalue. +// Returns the shadow value that should be used to continue the computations, +// depending on the answer from the runtime. +// TODO: Should we check on select ? phi ? +Value *NumericalStabilitySanitizer::emitCheck(Value *V, Value *ShadowV, + IRBuilder<> &Builder, + CheckLoc Loc) { + // Do not emit checks for constant values, this is redundant. + if (isa<Constant>(V)) + return ShadowV; + + if (Instruction *Inst = dyn_cast<Instruction>(V)) { + Function *F = Inst->getFunction(); + if (CheckFunctionsFilter && !CheckFunctionsFilter->match(F->getName())) { + return ShadowV; + } + } + + Value *CheckResult = emitCheckInternal(V, ShadowV, Builder, Loc); + Value *ICmpEQ = Builder.CreateICmpEQ( + CheckResult, + ConstantInt::get(Builder.getInt32Ty(), + static_cast<int>(ContinuationType::ResumeFromValue))); + return Builder.CreateSelect( + ICmpEQ, Builder.CreateFPExt(V, Config.getExtendedFPType(V->getType())), + ShadowV); +} + +// Inserts a check that fcmp on shadow values are consistent with that on base +// values. +void NumericalStabilitySanitizer::emitFCmpCheck(FCmpInst &FCmp, + const ValueToShadowMap &Map) { + if (!ClInstrumentFCmp) + return; + + Function *F = FCmp.getFunction(); + if (CheckFunctionsFilter && !CheckFunctionsFilter->match(F->getName())) + return; + + Value *LHS = FCmp.getOperand(0); + if (Config.getExtendedFPType(LHS->getType()) == nullptr) + return; + Value *RHS = FCmp.getOperand(1); + + // Split the basic block. On mismatch, we'll jump to the new basic block with + // a call to the runtime for error reporting. + BasicBlock *FCmpBB = FCmp.getParent(); + BasicBlock *NextBB = FCmpBB->splitBasicBlock(FCmp.getNextNode()); + // Remove the newly created terminator unconditional branch. 
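  // Illustrative sketch (not from this patch) of the control flow built below,
  // assuming a scalar double compare and a mapping where the shadow of double
  // is fp128 (nsan type id 'q'):
  //
  //   %cmp  = fcmp olt double %a, %b
  //   %scmp = fcmp olt fp128 %sa, %sb
  //   %same = icmp eq i1 %cmp, %scmp
  //   br i1 %same, label %next, label %fail   ; "likely" branch weights
  // fail:
  //   call void @__nsan_fcmp_fail_double_q(double %a, double %b, fp128 %sa,
  //                                        fp128 %sb, i32 4, i1 %cmp, i1 %scmp)
  //   br label %next
  // next:
  //   ...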
+  FCmpBB->back().eraseFromParent();
+  BasicBlock *FailBB =
+      BasicBlock::Create(Context, "", FCmpBB->getParent(), NextBB);
+
+  // Create the shadow fcmp and comparison between the fcmps.
+  IRBuilder<> FCmpBuilder(FCmpBB);
+  FCmpBuilder.SetCurrentDebugLocation(FCmp.getDebugLoc());
+  Value *ShadowLHS = Map.getShadow(LHS);
+  Value *ShadowRHS = Map.getShadow(RHS);
+  // See comment on ClTruncateFCmpEq.
+  if (FCmp.isEquality() && ClTruncateFCmpEq) {
+    Type *Ty = ShadowLHS->getType();
+    ShadowLHS = FCmpBuilder.CreateFPExt(
+        FCmpBuilder.CreateFPTrunc(ShadowLHS, LHS->getType()), Ty);
+    ShadowRHS = FCmpBuilder.CreateFPExt(
+        FCmpBuilder.CreateFPTrunc(ShadowRHS, RHS->getType()), Ty);
+  }
+  Value *ShadowFCmp =
+      FCmpBuilder.CreateFCmp(FCmp.getPredicate(), ShadowLHS, ShadowRHS);
+  Value *OriginalAndShadowFcmpMatch =
+      FCmpBuilder.CreateICmpEQ(&FCmp, ShadowFCmp);
+
+  if (OriginalAndShadowFcmpMatch->getType()->isVectorTy()) {
+    // If we have a vector type, `OriginalAndShadowFcmpMatch` is a vector of i1,
+    // where an element is true if the corresponding elements in original and
+    // shadow are the same. We want all elements to be 1.
+    OriginalAndShadowFcmpMatch =
+        FCmpBuilder.CreateAndReduce(OriginalAndShadowFcmpMatch);
+  }
+
+  // Use MDBuilder(Context).createLikelyBranchWeights() because "match" is the
+  // common case.
+  FCmpBuilder.CreateCondBr(OriginalAndShadowFcmpMatch, NextBB, FailBB,
+                           MDBuilder(Context).createLikelyBranchWeights());
+
+  // Fill in FailBB.
+  IRBuilder<> FailBuilder(FailBB);
+  FailBuilder.SetCurrentDebugLocation(FCmp.getDebugLoc());
+
+  const auto EmitFailCall = [this, &FCmp, &FCmpBuilder,
+                             &FailBuilder](Value *L, Value *R, Value *ShadowL,
+                                           Value *ShadowR, Value *Result,
+                                           Value *ShadowResult) {
+    Type *FT = L->getType();
+    FunctionCallee *Callee = nullptr;
+    if (FT->isFloatTy()) {
+      Callee = &(NsanFCmpFail[kFloat]);
+    } else if (FT->isDoubleTy()) {
+      Callee = &(NsanFCmpFail[kDouble]);
+    } else if (FT->isX86_FP80Ty()) {
+      // TODO: make NsanFCmpFailLongDouble work.
+      Callee = &(NsanFCmpFail[kDouble]);
+      L = FailBuilder.CreateFPTrunc(L, Type::getDoubleTy(Context));
+      R = FailBuilder.CreateFPTrunc(R, Type::getDoubleTy(Context));
+    } else {
+      llvm_unreachable("not implemented");
+    }
+    FailBuilder.CreateCall(*Callee, {L, R, ShadowL, ShadowR,
+                                     ConstantInt::get(FCmpBuilder.getInt32Ty(),
+                                                      FCmp.getPredicate()),
+                                     Result, ShadowResult});
+  };
+  if (LHS->getType()->isVectorTy()) {
+    for (int I = 0, E = cast<VectorType>(LHS->getType())
+                            ->getElementCount()
+                            .getFixedValue();
+         I < E; ++I) {
+      Value *ExtractLHS = FailBuilder.CreateExtractElement(LHS, I);
+      Value *ExtractRHS = FailBuilder.CreateExtractElement(RHS, I);
+      Value *ExtractShadowLHS = FailBuilder.CreateExtractElement(ShadowLHS, I);
+      Value *ExtractShadowRHS = FailBuilder.CreateExtractElement(ShadowRHS, I);
+      Value *ExtractFCmp = FailBuilder.CreateExtractElement(&FCmp, I);
+      Value *ExtractShadowFCmp =
+          FailBuilder.CreateExtractElement(ShadowFCmp, I);
+      EmitFailCall(ExtractLHS, ExtractRHS, ExtractShadowLHS, ExtractShadowRHS,
+                   ExtractFCmp, ExtractShadowFCmp);
+    }
+  } else {
+    EmitFailCall(LHS, RHS, ShadowLHS, ShadowRHS, &FCmp, ShadowFCmp);
+  }
+  FailBuilder.CreateBr(NextBB);
+
+  ++NumInstrumentedFCmp;
+}
+
+// Creates a shadow phi value for any phi that defines a value of FT type.
+PHINode *NumericalStabilitySanitizer::maybeCreateShadowPhi(
+    PHINode &Phi, const TargetLibraryInfo &TLI) {
+  Type *VT = Phi.getType();
+  Type *ExtendedVT = Config.getExtendedFPType(VT);
+  if (ExtendedVT == nullptr)
+    return nullptr; // Not an FT value.
+ // The phi operands are shadow values and are not available when the phi is + // created. They will be populated in a final phase, once all shadow values + // have been created. + PHINode *Shadow = PHINode::Create(ExtendedVT, Phi.getNumIncomingValues()); + Shadow->insertAfter(&Phi); + return Shadow; +} + +Value *NumericalStabilitySanitizer::handleLoad(LoadInst &Load, Type *VT, + Type *ExtendedVT) { + IRBuilder<> Builder(Load.getNextNode()); + Builder.SetCurrentDebugLocation(Load.getDebugLoc()); + if (addrPointsToConstantData(Load.getPointerOperand())) { + // No need to look into the shadow memory, the value is a constant. Just + // convert from FT to 2FT. + return Builder.CreateFPExt(&Load, ExtendedVT); + } + + // if (%shadowptr == &) + // %shadow = fpext %v + // else + // %shadow = load (ptrcast %shadow_ptr)) + // Considered options here: + // - Have `NsanGetShadowPtrForLoad` return a fixed address + // &__nsan_unknown_value_shadow_address that is valid to load from, and + // use a select. This has the advantage that the generated IR is simpler. + // - Have `NsanGetShadowPtrForLoad` return nullptr. Because `select` does + // not short-circuit, dereferencing the returned pointer is no longer an + // option, have to split and create a separate basic block. This has the + // advantage of being easier to debug because it crashes if we ever mess + // up. + + const auto Extents = getMemoryExtentsOrDie(VT); + Value *ShadowPtr = Builder.CreateCall( + NsanGetShadowPtrForLoad[Extents.ValueType], + {Load.getPointerOperand(), ConstantInt::get(IntptrTy, Extents.NumElts)}); + ++NumInstrumentedFTLoads; + + // Split the basic block. + BasicBlock *LoadBB = Load.getParent(); + BasicBlock *NextBB = LoadBB->splitBasicBlock(Builder.GetInsertPoint()); + // Create the two options for creating the shadow value. + BasicBlock *ShadowLoadBB = + BasicBlock::Create(Context, "", LoadBB->getParent(), NextBB); + BasicBlock *FExtBB = + BasicBlock::Create(Context, "", LoadBB->getParent(), NextBB); + + // Replace the newly created terminator unconditional branch by a conditional + // branch to one of the options. + { + LoadBB->back().eraseFromParent(); + IRBuilder<> LoadBBBuilder(LoadBB); // The old builder has been invalidated. + LoadBBBuilder.SetCurrentDebugLocation(Load.getDebugLoc()); + LoadBBBuilder.CreateCondBr(LoadBBBuilder.CreateIsNull(ShadowPtr), FExtBB, + ShadowLoadBB); + } + + // Fill in ShadowLoadBB. + IRBuilder<> ShadowLoadBBBuilder(ShadowLoadBB); + ShadowLoadBBBuilder.SetCurrentDebugLocation(Load.getDebugLoc()); + Value *ShadowLoad = ShadowLoadBBBuilder.CreateAlignedLoad( + ExtendedVT, ShadowPtr, Align(1), Load.isVolatile()); + if (ClCheckLoads) { + ShadowLoad = emitCheck(&Load, ShadowLoad, ShadowLoadBBBuilder, + CheckLoc::makeLoad(Load.getPointerOperand())); + } + ShadowLoadBBBuilder.CreateBr(NextBB); + + // Fill in FExtBB. + IRBuilder<> FExtBBBuilder(FExtBB); + FExtBBBuilder.SetCurrentDebugLocation(Load.getDebugLoc()); + Value *FExt = FExtBBBuilder.CreateFPExt(&Load, ExtendedVT); + FExtBBBuilder.CreateBr(NextBB); + + // The shadow value come from any of the options. 
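  // Illustrative sketch (not from this patch) of the shape produced here for a
  // `load double`, assuming the shadow type of double is fp128 and a 64-bit
  // target (so IntptrTy is i64):
  //
  //   %sptr   = call ptr @__nsan_get_shadow_ptr_for_double_load(ptr %p, i64 1)
  //   %isnull = icmp eq ptr %sptr, null
  //   br i1 %isnull, label %fext, label %sload
  // sload:
  //   %s0 = load fp128, ptr %sptr, align 1
  //   br label %next
  // fext:
  //   %s1 = fpext double %v to fp128
  //   br label %next
  // next:
  //   %shadow = phi fp128 [ %s0, %sload ], [ %s1, %fext ]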
+ IRBuilder<> NextBBBuilder(&*NextBB->begin()); + NextBBBuilder.SetCurrentDebugLocation(Load.getDebugLoc()); + PHINode *ShadowPhi = NextBBBuilder.CreatePHI(ExtendedVT, 2); + ShadowPhi->addIncoming(ShadowLoad, ShadowLoadBB); + ShadowPhi->addIncoming(FExt, FExtBB); + return ShadowPhi; +} + +Value *NumericalStabilitySanitizer::handleTrunc(const FPTruncInst &Trunc, + Type *VT, Type *ExtendedVT, + const ValueToShadowMap &Map, + IRBuilder<> &Builder) { + Value *OrigSource = Trunc.getOperand(0); + Type *OrigSourceTy = OrigSource->getType(); + Type *ExtendedSourceTy = Config.getExtendedFPType(OrigSourceTy); + + // When truncating: + // - (A) If the source has a shadow, we truncate from the shadow, else we + // truncate from the original source. + // - (B) If the shadow of the source is larger than the shadow of the dest, + // we still need a truncate. Else, the shadow of the source is the same + // type as the shadow of the dest (because mappings are non-decreasing), so + // we don't need to emit a truncate. + // Examples, + // with a mapping of {f32->f64;f64->f80;f80->f128} + // fptrunc double %1 to float -> fptrunc x86_fp80 s(%1) to double + // fptrunc x86_fp80 %1 to float -> fptrunc fp128 s(%1) to double + // fptrunc fp128 %1 to float -> fptrunc fp128 %1 to double + // fptrunc x86_fp80 %1 to double -> x86_fp80 s(%1) + // fptrunc fp128 %1 to double -> fptrunc fp128 %1 to x86_fp80 + // fptrunc fp128 %1 to x86_fp80 -> fp128 %1 + // with a mapping of {f32->f64;f64->f128;f80->f128} + // fptrunc double %1 to float -> fptrunc fp128 s(%1) to double + // fptrunc x86_fp80 %1 to float -> fptrunc fp128 s(%1) to double + // fptrunc fp128 %1 to float -> fptrunc fp128 %1 to double + // fptrunc x86_fp80 %1 to double -> fp128 %1 + // fptrunc fp128 %1 to double -> fp128 %1 + // fptrunc fp128 %1 to x86_fp80 -> fp128 %1 + // with a mapping of {f32->f32;f64->f32;f80->f64} + // fptrunc double %1 to float -> float s(%1) + // fptrunc x86_fp80 %1 to float -> fptrunc double s(%1) to float + // fptrunc fp128 %1 to float -> fptrunc fp128 %1 to float + // fptrunc x86_fp80 %1 to double -> fptrunc double s(%1) to float + // fptrunc fp128 %1 to double -> fptrunc fp128 %1 to float + // fptrunc fp128 %1 to x86_fp80 -> fptrunc fp128 %1 to double + + // See (A) above. + Value *Source = ExtendedSourceTy ? Map.getShadow(OrigSource) : OrigSource; + Type *SourceTy = ExtendedSourceTy ? ExtendedSourceTy : OrigSourceTy; + // See (B) above. + if (SourceTy == ExtendedVT) + return Source; + + return Builder.CreateFPTrunc(Source, ExtendedVT); +} + +Value *NumericalStabilitySanitizer::handleExt(const FPExtInst &Ext, Type *VT, + Type *ExtendedVT, + const ValueToShadowMap &Map, + IRBuilder<> &Builder) { + Value *OrigSource = Ext.getOperand(0); + Type *OrigSourceTy = OrigSource->getType(); + Type *ExtendedSourceTy = Config.getExtendedFPType(OrigSourceTy); + // When extending: + // - (A) If the source has a shadow, we extend from the shadow, else we + // extend from the original source. + // - (B) If the shadow of the dest is larger than the shadow of the source, + // we still need an extend. Else, the shadow of the source is the same + // type as the shadow of the dest (because mappings are non-decreasing), so + // we don't need to emit an extend. 
+ // Examples, + // with a mapping of {f32->f64;f64->f80;f80->f128} + // fpext half %1 to float -> fpext half %1 to double + // fpext half %1 to double -> fpext half %1 to x86_fp80 + // fpext half %1 to x86_fp80 -> fpext half %1 to fp128 + // fpext float %1 to double -> double s(%1) + // fpext float %1 to x86_fp80 -> fpext double s(%1) to fp128 + // fpext double %1 to x86_fp80 -> fpext x86_fp80 s(%1) to fp128 + // with a mapping of {f32->f64;f64->f128;f80->f128} + // fpext half %1 to float -> fpext half %1 to double + // fpext half %1 to double -> fpext half %1 to fp128 + // fpext half %1 to x86_fp80 -> fpext half %1 to fp128 + // fpext float %1 to double -> fpext double s(%1) to fp128 + // fpext float %1 to x86_fp80 -> fpext double s(%1) to fp128 + // fpext double %1 to x86_fp80 -> fp128 s(%1) + // with a mapping of {f32->f32;f64->f32;f80->f64} + // fpext half %1 to float -> fpext half %1 to float + // fpext half %1 to double -> fpext half %1 to float + // fpext half %1 to x86_fp80 -> fpext half %1 to double + // fpext float %1 to double -> s(%1) + // fpext float %1 to x86_fp80 -> fpext float s(%1) to double + // fpext double %1 to x86_fp80 -> fpext float s(%1) to double + + // See (A) above. + Value *Source = ExtendedSourceTy ? Map.getShadow(OrigSource) : OrigSource; + Type *SourceTy = ExtendedSourceTy ? ExtendedSourceTy : OrigSourceTy; + // See (B) above. + if (SourceTy == ExtendedVT) + return Source; + + return Builder.CreateFPExt(Source, ExtendedVT); +} + +namespace { +// TODO: This should be tablegen-ed. +struct KnownIntrinsic { + struct WidenedIntrinsic { + const char *NarrowName; + Intrinsic::ID ID; // wide id. + using FnTypeFactory = FunctionType *(*)(LLVMContext &); + FnTypeFactory MakeFnTy; + }; + + static const char *get(LibFunc LFunc); + + // Given an intrinsic with an `FT` argument, try to find a wider intrinsic + // that applies the same operation on the shadow argument. + // Options are: + // - pass in the ID and full function type, + // - pass in the name, which includes the function type through mangling. 
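  // Illustrative example (not from this patch): with the kWidenedIntrinsics
  // table below, widen("llvm.sqrt.f32") yields {Intrinsic::sqrt, double(double)},
  // so, assuming the shadow type of float is double, the application call
  //   %r = call float @llvm.sqrt.f32(float %x)
  // gets the shadow computation
  //   %sr = call double @llvm.sqrt.f64(double %sx)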
+ static const WidenedIntrinsic *widen(StringRef Name); + +private: + struct LFEntry { + LibFunc LFunc; + const char *IntrinsicName; + }; + static const LFEntry kLibfuncIntrinsics[]; + + static const WidenedIntrinsic kWidenedIntrinsics[]; +}; +} // namespace + +static FunctionType *makeDoubleDouble(LLVMContext &C) { + return FunctionType::get(Type::getDoubleTy(C), {Type::getDoubleTy(C)}, false); +} + +static FunctionType *makeX86FP80X86FP80(LLVMContext &C) { + return FunctionType::get(Type::getX86_FP80Ty(C), {Type::getX86_FP80Ty(C)}, + false); +} + +static FunctionType *makeDoubleDoubleI32(LLVMContext &C) { + return FunctionType::get(Type::getDoubleTy(C), + {Type::getDoubleTy(C), Type::getInt32Ty(C)}, false); +} + +static FunctionType *makeX86FP80X86FP80I32(LLVMContext &C) { + return FunctionType::get(Type::getX86_FP80Ty(C), + {Type::getX86_FP80Ty(C), Type::getInt32Ty(C)}, + false); +} + +static FunctionType *makeDoubleDoubleDouble(LLVMContext &C) { + return FunctionType::get(Type::getDoubleTy(C), + {Type::getDoubleTy(C), Type::getDoubleTy(C)}, false); +} + +static FunctionType *makeX86FP80X86FP80X86FP80(LLVMContext &C) { + return FunctionType::get(Type::getX86_FP80Ty(C), + {Type::getX86_FP80Ty(C), Type::getX86_FP80Ty(C)}, + false); +} + +static FunctionType *makeDoubleDoubleDoubleDouble(LLVMContext &C) { + return FunctionType::get( + Type::getDoubleTy(C), + {Type::getDoubleTy(C), Type::getDoubleTy(C), Type::getDoubleTy(C)}, + false); +} + +static FunctionType *makeX86FP80X86FP80X86FP80X86FP80(LLVMContext &C) { + return FunctionType::get( + Type::getX86_FP80Ty(C), + {Type::getX86_FP80Ty(C), Type::getX86_FP80Ty(C), Type::getX86_FP80Ty(C)}, + false); +} + +const KnownIntrinsic::WidenedIntrinsic KnownIntrinsic::kWidenedIntrinsics[] = { + // TODO: Right now we ignore vector intrinsics. + // This is hard because we have to model the semantics of the intrinsics, + // e.g. llvm.x86.sse2.min.sd means extract first element, min, insert back. + // Intrinsics that take any non-vector FT types: + // NOTE: Right now because of + // https://github.com/llvm/llvm-project/issues/44744 + // for f128 we need to use makeX86FP80X86FP80 (go to a lower precision and + // come back). 
+ {"llvm.sqrt.f32", Intrinsic::sqrt, makeDoubleDouble}, + {"llvm.sqrt.f64", Intrinsic::sqrt, makeX86FP80X86FP80}, + {"llvm.sqrt.f80", Intrinsic::sqrt, makeX86FP80X86FP80}, + {"llvm.powi.f32", Intrinsic::powi, makeDoubleDoubleI32}, + {"llvm.powi.f64", Intrinsic::powi, makeX86FP80X86FP80I32}, + {"llvm.powi.f80", Intrinsic::powi, makeX86FP80X86FP80I32}, + {"llvm.sin.f32", Intrinsic::sin, makeDoubleDouble}, + {"llvm.sin.f64", Intrinsic::sin, makeX86FP80X86FP80}, + {"llvm.sin.f80", Intrinsic::sin, makeX86FP80X86FP80}, + {"llvm.cos.f32", Intrinsic::cos, makeDoubleDouble}, + {"llvm.cos.f64", Intrinsic::cos, makeX86FP80X86FP80}, + {"llvm.cos.f80", Intrinsic::cos, makeX86FP80X86FP80}, + {"llvm.pow.f32", Intrinsic::pow, makeDoubleDoubleDouble}, + {"llvm.pow.f64", Intrinsic::pow, makeX86FP80X86FP80X86FP80}, + {"llvm.pow.f80", Intrinsic::pow, makeX86FP80X86FP80X86FP80}, + {"llvm.exp.f32", Intrinsic::exp, makeDoubleDouble}, + {"llvm.exp.f64", Intrinsic::exp, makeX86FP80X86FP80}, + {"llvm.exp.f80", Intrinsic::exp, makeX86FP80X86FP80}, + {"llvm.exp2.f32", Intrinsic::exp2, makeDoubleDouble}, + {"llvm.exp2.f64", Intrinsic::exp2, makeX86FP80X86FP80}, + {"llvm.exp2.f80", Intrinsic::exp2, makeX86FP80X86FP80}, + {"llvm.log.f32", Intrinsic::log, makeDoubleDouble}, + {"llvm.log.f64", Intrinsic::log, makeX86FP80X86FP80}, + {"llvm.log.f80", Intrinsic::log, makeX86FP80X86FP80}, + {"llvm.log10.f32", Intrinsic::log10, makeDoubleDouble}, + {"llvm.log10.f64", Intrinsic::log10, makeX86FP80X86FP80}, + {"llvm.log10.f80", Intrinsic::log10, makeX86FP80X86FP80}, + {"llvm.log2.f32", Intrinsic::log2, makeDoubleDouble}, + {"llvm.log2.f64", Intrinsic::log2, makeX86FP80X86FP80}, + {"llvm.log2.f80", Intrinsic::log2, makeX86FP80X86FP80}, + {"llvm.fma.f32", Intrinsic::fma, makeDoubleDoubleDoubleDouble}, + + {"llvm.fmuladd.f32", Intrinsic::fmuladd, makeDoubleDoubleDoubleDouble}, + + {"llvm.fma.f64", Intrinsic::fma, makeX86FP80X86FP80X86FP80X86FP80}, + + {"llvm.fmuladd.f64", Intrinsic::fma, makeX86FP80X86FP80X86FP80X86FP80}, + + {"llvm.fma.f80", Intrinsic::fma, makeX86FP80X86FP80X86FP80X86FP80}, + {"llvm.fabs.f32", Intrinsic::fabs, makeDoubleDouble}, + {"llvm.fabs.f64", Intrinsic::fabs, makeX86FP80X86FP80}, + {"llvm.fabs.f80", Intrinsic::fabs, makeX86FP80X86FP80}, + {"llvm.minnum.f32", Intrinsic::minnum, makeDoubleDoubleDouble}, + {"llvm.minnum.f64", Intrinsic::minnum, makeX86FP80X86FP80X86FP80}, + {"llvm.minnum.f80", Intrinsic::minnum, makeX86FP80X86FP80X86FP80}, + {"llvm.maxnum.f32", Intrinsic::maxnum, makeDoubleDoubleDouble}, + {"llvm.maxnum.f64", Intrinsic::maxnum, makeX86FP80X86FP80X86FP80}, + {"llvm.maxnum.f80", Intrinsic::maxnum, makeX86FP80X86FP80X86FP80}, + {"llvm.minimum.f32", Intrinsic::minimum, makeDoubleDoubleDouble}, + {"llvm.minimum.f64", Intrinsic::minimum, makeX86FP80X86FP80X86FP80}, + {"llvm.minimum.f80", Intrinsic::minimum, makeX86FP80X86FP80X86FP80}, + {"llvm.maximum.f32", Intrinsic::maximum, makeDoubleDoubleDouble}, + {"llvm.maximum.f64", Intrinsic::maximum, makeX86FP80X86FP80X86FP80}, + {"llvm.maximum.f80", Intrinsic::maximum, makeX86FP80X86FP80X86FP80}, + {"llvm.copysign.f32", Intrinsic::copysign, makeDoubleDoubleDouble}, + {"llvm.copysign.f64", Intrinsic::copysign, makeX86FP80X86FP80X86FP80}, + {"llvm.copysign.f80", Intrinsic::copysign, makeX86FP80X86FP80X86FP80}, + {"llvm.floor.f32", Intrinsic::floor, makeDoubleDouble}, + {"llvm.floor.f64", Intrinsic::floor, makeX86FP80X86FP80}, + {"llvm.floor.f80", Intrinsic::floor, makeX86FP80X86FP80}, + {"llvm.ceil.f32", Intrinsic::ceil, makeDoubleDouble}, + 
{"llvm.ceil.f64", Intrinsic::ceil, makeX86FP80X86FP80}, + {"llvm.ceil.f80", Intrinsic::ceil, makeX86FP80X86FP80}, + {"llvm.trunc.f32", Intrinsic::trunc, makeDoubleDouble}, + {"llvm.trunc.f64", Intrinsic::trunc, makeX86FP80X86FP80}, + {"llvm.trunc.f80", Intrinsic::trunc, makeX86FP80X86FP80}, + {"llvm.rint.f32", Intrinsic::rint, makeDoubleDouble}, + {"llvm.rint.f64", Intrinsic::rint, makeX86FP80X86FP80}, + {"llvm.rint.f80", Intrinsic::rint, makeX86FP80X86FP80}, + {"llvm.nearbyint.f32", Intrinsic::nearbyint, makeDoubleDouble}, + {"llvm.nearbyint.f64", Intrinsic::nearbyint, makeX86FP80X86FP80}, + {"llvm.nearbyin80f64", Intrinsic::nearbyint, makeX86FP80X86FP80}, + {"llvm.round.f32", Intrinsic::round, makeDoubleDouble}, + {"llvm.round.f64", Intrinsic::round, makeX86FP80X86FP80}, + {"llvm.round.f80", Intrinsic::round, makeX86FP80X86FP80}, + {"llvm.lround.f32", Intrinsic::lround, makeDoubleDouble}, + {"llvm.lround.f64", Intrinsic::lround, makeX86FP80X86FP80}, + {"llvm.lround.f80", Intrinsic::lround, makeX86FP80X86FP80}, + {"llvm.llround.f32", Intrinsic::llround, makeDoubleDouble}, + {"llvm.llround.f64", Intrinsic::llround, makeX86FP80X86FP80}, + {"llvm.llround.f80", Intrinsic::llround, makeX86FP80X86FP80}, + {"llvm.lrint.f32", Intrinsic::lrint, makeDoubleDouble}, + {"llvm.lrint.f64", Intrinsic::lrint, makeX86FP80X86FP80}, + {"llvm.lrint.f80", Intrinsic::lrint, makeX86FP80X86FP80}, + {"llvm.llrint.f32", Intrinsic::llrint, makeDoubleDouble}, + {"llvm.llrint.f64", Intrinsic::llrint, makeX86FP80X86FP80}, + {"llvm.llrint.f80", Intrinsic::llrint, makeX86FP80X86FP80}, +}; + +const KnownIntrinsic::LFEntry KnownIntrinsic::kLibfuncIntrinsics[] = { + {LibFunc_sqrtf, "llvm.sqrt.f32"}, + {LibFunc_sqrt, "llvm.sqrt.f64"}, + {LibFunc_sqrtl, "llvm.sqrt.f80"}, + {LibFunc_sinf, "llvm.sin.f32"}, + {LibFunc_sin, "llvm.sin.f64"}, + {LibFunc_sinl, "llvm.sin.f80"}, + {LibFunc_cosf, "llvm.cos.f32"}, + {LibFunc_cos, "llvm.cos.f64"}, + {LibFunc_cosl, "llvm.cos.f80"}, + {LibFunc_powf, "llvm.pow.f32"}, + {LibFunc_pow, "llvm.pow.f64"}, + {LibFunc_powl, "llvm.pow.f80"}, + {LibFunc_expf, "llvm.exp.f32"}, + {LibFunc_exp, "llvm.exp.f64"}, + {LibFunc_expl, "llvm.exp.f80"}, + {LibFunc_exp2f, "llvm.exp2.f32"}, + {LibFunc_exp2, "llvm.exp2.f64"}, + {LibFunc_exp2l, "llvm.exp2.f80"}, + {LibFunc_logf, "llvm.log.f32"}, + {LibFunc_log, "llvm.log.f64"}, + {LibFunc_logl, "llvm.log.f80"}, + {LibFunc_log10f, "llvm.log10.f32"}, + {LibFunc_log10, "llvm.log10.f64"}, + {LibFunc_log10l, "llvm.log10.f80"}, + {LibFunc_log2f, "llvm.log2.f32"}, + {LibFunc_log2, "llvm.log2.f64"}, + {LibFunc_log2l, "llvm.log2.f80"}, + {LibFunc_fabsf, "llvm.fabs.f32"}, + {LibFunc_fabs, "llvm.fabs.f64"}, + {LibFunc_fabsl, "llvm.fabs.f80"}, + {LibFunc_copysignf, "llvm.copysign.f32"}, + {LibFunc_copysign, "llvm.copysign.f64"}, + {LibFunc_copysignl, "llvm.copysign.f80"}, + {LibFunc_floorf, "llvm.floor.f32"}, + {LibFunc_floor, "llvm.floor.f64"}, + {LibFunc_floorl, "llvm.floor.f80"}, + {LibFunc_fmaxf, "llvm.maxnum.f32"}, + {LibFunc_fmax, "llvm.maxnum.f64"}, + {LibFunc_fmaxl, "llvm.maxnum.f80"}, + {LibFunc_fminf, "llvm.minnum.f32"}, + {LibFunc_fmin, "llvm.minnum.f64"}, + {LibFunc_fminl, "llvm.minnum.f80"}, + {LibFunc_ceilf, "llvm.ceil.f32"}, + {LibFunc_ceil, "llvm.ceil.f64"}, + {LibFunc_ceill, "llvm.ceil.f80"}, + {LibFunc_truncf, "llvm.trunc.f32"}, + {LibFunc_trunc, "llvm.trunc.f64"}, + {LibFunc_truncl, "llvm.trunc.f80"}, + {LibFunc_rintf, "llvm.rint.f32"}, + {LibFunc_rint, "llvm.rint.f64"}, + {LibFunc_rintl, "llvm.rint.f80"}, + {LibFunc_nearbyintf, "llvm.nearbyint.f32"}, + 
{LibFunc_nearbyint, "llvm.nearbyint.f64"}, + {LibFunc_nearbyintl, "llvm.nearbyint.f80"}, + {LibFunc_roundf, "llvm.round.f32"}, + {LibFunc_round, "llvm.round.f64"}, + {LibFunc_roundl, "llvm.round.f80"}, +}; + +const char *KnownIntrinsic::get(LibFunc LFunc) { + for (const auto &E : kLibfuncIntrinsics) { + if (E.LFunc == LFunc) + return E.IntrinsicName; + } + return nullptr; +} + +const KnownIntrinsic::WidenedIntrinsic *KnownIntrinsic::widen(StringRef Name) { + for (const auto &E : kWidenedIntrinsics) { + if (E.NarrowName == Name) + return &E; + } + return nullptr; +} + +// Returns the name of the LLVM intrinsic corresponding to the given function. +static const char *getIntrinsicFromLibfunc(Function &Fn, Type *VT, + const TargetLibraryInfo &TLI) { + LibFunc LFunc; + if (!TLI.getLibFunc(Fn, LFunc)) + return nullptr; + + if (const char *Name = KnownIntrinsic::get(LFunc)) + return Name; + + LLVM_DEBUG(errs() << "TODO: LibFunc: " << TLI.getName(LFunc) << "\n"); + return nullptr; +} + +// Try to handle a known function call. +Value *NumericalStabilitySanitizer::maybeHandleKnownCallBase( + CallBase &Call, Type *VT, Type *ExtendedVT, const TargetLibraryInfo &TLI, + const ValueToShadowMap &Map, IRBuilder<> &Builder) { + Function *Fn = Call.getCalledFunction(); + if (Fn == nullptr) + return nullptr; + + Intrinsic::ID WidenedId = Intrinsic::ID(); + FunctionType *WidenedFnTy = nullptr; + if (const auto ID = Fn->getIntrinsicID()) { + const auto *Widened = KnownIntrinsic::widen(Fn->getName()); + if (Widened) { + WidenedId = Widened->ID; + WidenedFnTy = Widened->MakeFnTy(Context); + } else { + // If we don't know how to widen the intrinsic, we have no choice but to + // call the non-wide version on a truncated shadow and extend again + // afterwards. + WidenedId = ID; + WidenedFnTy = Fn->getFunctionType(); + } + } else if (const char *Name = getIntrinsicFromLibfunc(*Fn, VT, TLI)) { + // We might have a call to a library function that we can replace with a + // wider Intrinsic. + const auto *Widened = KnownIntrinsic::widen(Name); + assert(Widened && "make sure KnownIntrinsic entries are consistent"); + WidenedId = Widened->ID; + WidenedFnTy = Widened->MakeFnTy(Context); + } else { + // This is not a known library function or intrinsic. + return nullptr; + } + + // Check that the widened intrinsic is valid. + SmallVector<Intrinsic::IITDescriptor, 8> Table; + getIntrinsicInfoTableEntries(WidenedId, Table); + SmallVector<Type *, 4> ArgTys; + ArrayRef<Intrinsic::IITDescriptor> TableRef = Table; + [[maybe_unused]] Intrinsic::MatchIntrinsicTypesResult MatchResult = + Intrinsic::matchIntrinsicSignature(WidenedFnTy, TableRef, ArgTys); + assert(MatchResult == Intrinsic::MatchIntrinsicTypes_Match && + "invalid widened intrinsic"); + // For known intrinsic functions, we create a second call to the same + // intrinsic with a different type. + SmallVector<Value *, 4> Args; + // The last operand is the intrinsic itself, skip it. + for (unsigned I = 0, E = Call.getNumOperands() - 1; I < E; ++I) { + Value *Arg = Call.getOperand(I); + Type *OrigArgTy = Arg->getType(); + Type *IntrinsicArgTy = WidenedFnTy->getParamType(I); + if (OrigArgTy == IntrinsicArgTy) { + Args.push_back(Arg); // The arg is passed as is. + continue; + } + Type *ShadowArgTy = Config.getExtendedFPType(Arg->getType()); + assert(ShadowArgTy && + "don't know how to get the shadow value for a non-FT"); + Value *Shadow = Map.getShadow(Arg); + if (ShadowArgTy == IntrinsicArgTy) { + // The shadow is the right type for the intrinsic. 
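      // Illustrative example (not from this patch), assuming the shadow type
      // of double is fp128: for
      //   %r = call double @llvm.powi.f64(double %x, i32 %n)
      // the widened form is llvm.powi.f80 (makeX86FP80X86FP80I32 above), so the
      // fp128 shadow of %x is fptrunc'd to x86_fp80, %n is passed through
      // unchanged, and the x86_fp80 result is fpext'd back to fp128 below.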
+ assert(Shadow->getType() == ShadowArgTy); + Args.push_back(Shadow); + continue; + } + // There is no intrinsic with his level of precision, truncate the shadow. + Args.push_back(Builder.CreateFPTrunc(Shadow, IntrinsicArgTy)); + } + Value *IntrinsicCall = Builder.CreateIntrinsic(WidenedId, ArgTys, Args); + return WidenedFnTy->getReturnType() == ExtendedVT + ? IntrinsicCall + : Builder.CreateFPExt(IntrinsicCall, ExtendedVT); +} + +// Handle a CallBase, i.e. a function call, an inline asm sequence, or an +// invoke. +Value *NumericalStabilitySanitizer::handleCallBase(CallBase &Call, Type *VT, + Type *ExtendedVT, + const TargetLibraryInfo &TLI, + const ValueToShadowMap &Map, + IRBuilder<> &Builder) { + // We cannot look inside inline asm, just expand the result again. + if (Call.isInlineAsm()) + return Builder.CreateFPExt(&Call, ExtendedVT); + + // Intrinsics and library functions (e.g. sin, exp) are handled + // specifically, because we know their semantics and can do better than + // blindly calling them (e.g. compute the sinus in the actual shadow domain). + if (Value *V = + maybeHandleKnownCallBase(Call, VT, ExtendedVT, TLI, Map, Builder)) + return V; + + // If the return tag matches that of the called function, read the extended + // return value from the shadow ret ptr. Else, just extend the return value. + Value *L = + Builder.CreateLoad(IntptrTy, NsanShadowRetTag, /*isVolatile=*/false); + Value *HasShadowRet = Builder.CreateICmpEQ( + L, Builder.CreatePtrToInt(Call.getCalledOperand(), IntptrTy)); + + Value *ShadowRetVal = Builder.CreateLoad( + ExtendedVT, + Builder.CreateConstGEP2_64(NsanShadowRetType, NsanShadowRetPtr, 0, 0), + /*isVolatile=*/false); + Value *Shadow = Builder.CreateSelect(HasShadowRet, ShadowRetVal, + Builder.CreateFPExt(&Call, ExtendedVT)); + ++NumInstrumentedFTCalls; + return Shadow; +} + +// Creates a shadow value for the given FT value. At that point all operands are +// guaranteed to be available. +Value *NumericalStabilitySanitizer::createShadowValueWithOperandsAvailable( + Instruction &Inst, const TargetLibraryInfo &TLI, + const ValueToShadowMap &Map) { + Type *VT = Inst.getType(); + Type *ExtendedVT = Config.getExtendedFPType(VT); + assert(ExtendedVT != nullptr && "trying to create a shadow for a non-FT"); + + if (auto *Load = dyn_cast<LoadInst>(&Inst)) + return handleLoad(*Load, VT, ExtendedVT); + + if (auto *Call = dyn_cast<CallInst>(&Inst)) { + // Insert after the call. + BasicBlock::iterator It(Inst); + IRBuilder<> Builder(Call->getParent(), ++It); + Builder.SetCurrentDebugLocation(Call->getDebugLoc()); + return handleCallBase(*Call, VT, ExtendedVT, TLI, Map, Builder); + } + + if (auto *Invoke = dyn_cast<InvokeInst>(&Inst)) { + // The Invoke terminates the basic block, create a new basic block in + // between the successful invoke and the next block. 
+ BasicBlock *InvokeBB = Invoke->getParent(); + BasicBlock *NextBB = Invoke->getNormalDest(); + BasicBlock *NewBB = + BasicBlock::Create(Context, "", NextBB->getParent(), NextBB); + Inst.replaceSuccessorWith(NextBB, NewBB); + + IRBuilder<> Builder(NewBB); + Builder.SetCurrentDebugLocation(Invoke->getDebugLoc()); + Value *Shadow = handleCallBase(*Invoke, VT, ExtendedVT, TLI, Map, Builder); + Builder.CreateBr(NextBB); + NewBB->replaceSuccessorsPhiUsesWith(InvokeBB, NewBB); + return Shadow; + } + + IRBuilder<> Builder(Inst.getNextNode()); + Builder.SetCurrentDebugLocation(Inst.getDebugLoc()); + + if (auto *Trunc = dyn_cast<FPTruncInst>(&Inst)) + return handleTrunc(*Trunc, VT, ExtendedVT, Map, Builder); + if (auto *Ext = dyn_cast<FPExtInst>(&Inst)) + return handleExt(*Ext, VT, ExtendedVT, Map, Builder); + + if (auto *UnaryOp = dyn_cast<UnaryOperator>(&Inst)) + return Builder.CreateUnOp(UnaryOp->getOpcode(), + Map.getShadow(UnaryOp->getOperand(0))); + + if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst)) + return Builder.CreateBinOp(BinOp->getOpcode(), + Map.getShadow(BinOp->getOperand(0)), + Map.getShadow(BinOp->getOperand(1))); + + if (isa<UIToFPInst>(&Inst) || isa<SIToFPInst>(&Inst)) { + auto *Cast = dyn_cast<CastInst>(&Inst); + return Builder.CreateCast(Cast->getOpcode(), Cast->getOperand(0), + ExtendedVT); + } + + if (auto *S = dyn_cast<SelectInst>(&Inst)) + return Builder.CreateSelect(S->getCondition(), + Map.getShadow(S->getTrueValue()), + Map.getShadow(S->getFalseValue())); + + if (auto *Extract = dyn_cast<ExtractElementInst>(&Inst)) + return Builder.CreateExtractElement( + Map.getShadow(Extract->getVectorOperand()), Extract->getIndexOperand()); + + if (auto *Insert = dyn_cast<InsertElementInst>(&Inst)) + return Builder.CreateInsertElement(Map.getShadow(Insert->getOperand(0)), + Map.getShadow(Insert->getOperand(1)), + Insert->getOperand(2)); + + if (auto *Shuffle = dyn_cast<ShuffleVectorInst>(&Inst)) + return Builder.CreateShuffleVector(Map.getShadow(Shuffle->getOperand(0)), + Map.getShadow(Shuffle->getOperand(1)), + Shuffle->getShuffleMask()); + // TODO: We could make aggregate object first class citizens. For now we + // just extend the extracted value. + if (auto *Extract = dyn_cast<ExtractValueInst>(&Inst)) + return Builder.CreateFPExt(Extract, ExtendedVT); + + if (auto *BC = dyn_cast<BitCastInst>(&Inst)) + return Builder.CreateFPExt(BC, ExtendedVT); + + report_fatal_error("Unimplemented support for " + + Twine(Inst.getOpcodeName())); +} + +// Creates a shadow value for an instruction that defines a value of FT type. +// FT operands that do not already have shadow values are created recursively. +// The DFS is guaranteed to not loop as phis and arguments already have +// shadows. +void NumericalStabilitySanitizer::maybeCreateShadowValue( + Instruction &Root, const TargetLibraryInfo &TLI, ValueToShadowMap &Map) { + Type *VT = Root.getType(); + Type *ExtendedVT = Config.getExtendedFPType(VT); + if (ExtendedVT == nullptr) + return; // Not an FT value. + + if (Map.hasShadow(&Root)) + return; // Shadow already exists. + + assert(!isa<PHINode>(Root) && "phi nodes should already have shadows"); + + std::vector<Instruction *> DfsStack(1, &Root); + while (!DfsStack.empty()) { + // Ensure that all operands to the instruction have shadows before + // proceeding. + Instruction *I = DfsStack.back(); + // The shadow for the instruction might have been created deeper in the DFS, + // see `forward_use_with_two_uses` test. 
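    // Illustrative walk-through (not from this patch): for
    //   %t = fmul float %a, %b
    //   %r = fadd float %t, %t
    // with %a and %b being function arguments (their shadows already exist),
    // starting the DFS at %r pushes %t twice (once per use). The first copy
    // gets its shadow created; the `hasShadow` check below then just pops the
    // second copy instead of creating a duplicate.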
+ if (Map.hasShadow(I)) { + DfsStack.pop_back(); + continue; + } + + bool MissingShadow = false; + for (Value *Op : I->operands()) { + Type *VT = Op->getType(); + if (!Config.getExtendedFPType(VT)) + continue; // Not an FT value. + if (Map.hasShadow(Op)) + continue; // Shadow is already available. + MissingShadow = true; + DfsStack.push_back(cast<Instruction>(Op)); + } + if (MissingShadow) + continue; // Process operands and come back to this instruction later. + + // All operands have shadows. Create a shadow for the current value. + Value *Shadow = createShadowValueWithOperandsAvailable(*I, TLI, Map); + Map.setShadow(*I, *Shadow); + DfsStack.pop_back(); + } +} + +// A floating-point store needs its value and type written to shadow memory. +void NumericalStabilitySanitizer::propagateFTStore( + StoreInst &Store, Type *VT, Type *ExtendedVT, const ValueToShadowMap &Map) { + Value *StoredValue = Store.getValueOperand(); + IRBuilder<> Builder(&Store); + Builder.SetCurrentDebugLocation(Store.getDebugLoc()); + const auto Extents = getMemoryExtentsOrDie(VT); + Value *ShadowPtr = Builder.CreateCall( + NsanGetShadowPtrForStore[Extents.ValueType], + {Store.getPointerOperand(), ConstantInt::get(IntptrTy, Extents.NumElts)}); + + Value *StoredShadow = Map.getShadow(StoredValue); + if (!Store.getParent()->getParent()->hasOptNone()) { + // Only check stores when optimizing, because non-optimized code generates + // too many stores to the stack, creating false positives. + if (ClCheckStores) { + StoredShadow = emitCheck(StoredValue, StoredShadow, Builder, + CheckLoc::makeStore(Store.getPointerOperand())); + ++NumInstrumentedFTStores; + } + } + + Builder.CreateAlignedStore(StoredShadow, ShadowPtr, Align(1), + Store.isVolatile()); +} + +// A non-ft store needs to invalidate shadow memory. Exceptions are: +// - memory transfers of floating-point data through other pointer types (llvm +// optimization passes transform `*(float*)a = *(float*)b` into +// `*(i32*)a = *(i32*)b` ). These have the same semantics as memcpy. +// - Writes of FT-sized constants. LLVM likes to do float stores as bitcasted +// ints. Note that this is not really necessary because if the value is +// unknown the framework will re-extend it on load anyway. It just felt +// easier to debug tests with vectors of FTs. +void NumericalStabilitySanitizer::propagateNonFTStore( + StoreInst &Store, Type *VT, const ValueToShadowMap &Map) { + Value *PtrOp = Store.getPointerOperand(); + IRBuilder<> Builder(Store.getNextNode()); + Builder.SetCurrentDebugLocation(Store.getDebugLoc()); + Value *Dst = PtrOp; + TypeSize SlotSize = DL.getTypeStoreSize(VT); + assert(!SlotSize.isScalable() && "unsupported"); + const auto LoadSizeBytes = SlotSize.getFixedValue(); + Value *ValueSize = Constant::getIntegerValue( + IntptrTy, APInt(IntptrTy->getPrimitiveSizeInBits(), LoadSizeBytes)); + + ++NumInstrumentedNonFTStores; + Value *StoredValue = Store.getValueOperand(); + if (LoadInst *Load = dyn_cast<LoadInst>(StoredValue)) { + // TODO: Handle the case when the value is from a phi. + // This is a memory transfer with memcpy semantics. Copy the type and + // value from the source. Note that we cannot use __nsan_copy_values() + // here, because that will not work when there is a write to memory in + // between the load and the store, e.g. in the case of a swap. 
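    // Illustrative example (not from this patch) of the pattern handled here:
    // the optimizer may lower
    //   *(double *)Dst = *(double *)Src;
    // to an i64 load/store pair. The code below keeps the shadow coherent by
    // copying the raw shadow bytes read at the load point, conceptually:
    //   type  = *__nsan_internal_get_raw_shadow_type_ptr(Src); // FT-sized
    //   value = *__nsan_internal_get_raw_shadow_ptr(Src);      // 2x FT-sized
    //   *__nsan_internal_get_raw_shadow_type_ptr(Dst) = type;
    //   *__nsan_internal_get_raw_shadow_ptr(Dst)      = value;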
+ Type *ShadowTypeIntTy = Type::getIntNTy(Context, 8 * LoadSizeBytes); + Type *ShadowValueIntTy = + Type::getIntNTy(Context, 8 * kShadowScale * LoadSizeBytes); + IRBuilder<> LoadBuilder(Load->getNextNode()); + Builder.SetCurrentDebugLocation(Store.getDebugLoc()); + Value *LoadSrc = Load->getPointerOperand(); + // Read the shadow type and value at load time. The type has the same size + // as the FT value, the value has twice its size. + // TODO: cache them to avoid re-creating them when a load is used by + // several stores. Maybe create them like the FT shadows when a load is + // encountered. + Value *RawShadowType = LoadBuilder.CreateAlignedLoad( + ShadowTypeIntTy, + LoadBuilder.CreateCall(NsanGetRawShadowTypePtr, {LoadSrc}), Align(1), + /*isVolatile=*/false); + Value *RawShadowValue = LoadBuilder.CreateAlignedLoad( + ShadowValueIntTy, + LoadBuilder.CreateCall(NsanGetRawShadowPtr, {LoadSrc}), Align(1), + /*isVolatile=*/false); + + // Write back the shadow type and value at store time. + Builder.CreateAlignedStore( + RawShadowType, Builder.CreateCall(NsanGetRawShadowTypePtr, {Dst}), + Align(1), + /*isVolatile=*/false); + Builder.CreateAlignedStore(RawShadowValue, + Builder.CreateCall(NsanGetRawShadowPtr, {Dst}), + Align(1), + /*isVolatile=*/false); + + ++NumInstrumentedNonFTMemcpyStores; + return; + } + // ClPropagateNonFTConstStoresAsFT is by default false. + if (Constant *C; ClPropagateNonFTConstStoresAsFT && + (C = dyn_cast<Constant>(StoredValue))) { + // This might be a fp constant stored as an int. Bitcast and store if it has + // appropriate size. + Type *BitcastTy = nullptr; // The FT type to bitcast to. + if (auto *CInt = dyn_cast<ConstantInt>(C)) { + switch (CInt->getType()->getScalarSizeInBits()) { + case 32: + BitcastTy = Type::getFloatTy(Context); + break; + case 64: + BitcastTy = Type::getDoubleTy(Context); + break; + case 80: + BitcastTy = Type::getX86_FP80Ty(Context); + break; + default: + break; + } + } else if (auto *CDV = dyn_cast<ConstantDataVector>(C)) { + const int NumElements = + cast<VectorType>(CDV->getType())->getElementCount().getFixedValue(); + switch (CDV->getType()->getScalarSizeInBits()) { + case 32: + BitcastTy = + VectorType::get(Type::getFloatTy(Context), NumElements, false); + break; + case 64: + BitcastTy = + VectorType::get(Type::getDoubleTy(Context), NumElements, false); + break; + case 80: + BitcastTy = + VectorType::get(Type::getX86_FP80Ty(Context), NumElements, false); + break; + default: + break; + } + } + if (BitcastTy) { + const MemoryExtents Extents = getMemoryExtentsOrDie(BitcastTy); + Value *ShadowPtr = Builder.CreateCall( + NsanGetShadowPtrForStore[Extents.ValueType], + {PtrOp, ConstantInt::get(IntptrTy, Extents.NumElts)}); + // Bitcast the integer value to the appropriate FT type and extend to 2FT. + Type *ExtVT = Config.getExtendedFPType(BitcastTy); + Value *Shadow = + Builder.CreateFPExt(Builder.CreateBitCast(C, BitcastTy), ExtVT); + Builder.CreateAlignedStore(Shadow, ShadowPtr, Align(1), + Store.isVolatile()); + return; + } + } + // All other stores just reset the shadow value to unknown. 
+ Builder.CreateCall(NsanSetValueUnknown, {Dst, ValueSize}); +} + +void NumericalStabilitySanitizer::propagateShadowValues( + Instruction &Inst, const TargetLibraryInfo &TLI, + const ValueToShadowMap &Map) { + if (auto *Store = dyn_cast<StoreInst>(&Inst)) { + Value *StoredValue = Store->getValueOperand(); + Type *VT = StoredValue->getType(); + Type *ExtendedVT = Config.getExtendedFPType(VT); + if (ExtendedVT == nullptr) + return propagateNonFTStore(*Store, VT, Map); + return propagateFTStore(*Store, VT, ExtendedVT, Map); + } + + if (auto *FCmp = dyn_cast<FCmpInst>(&Inst)) { + emitFCmpCheck(*FCmp, Map); + return; + } + + if (auto *CB = dyn_cast<CallBase>(&Inst)) { + maybeAddSuffixForNsanInterface(CB); + if (CallInst *CI = dyn_cast<CallInst>(&Inst)) + maybeMarkSanitizerLibraryCallNoBuiltin(CI, &TLI); + if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(&Inst)) { + instrumentMemIntrinsic(MI); + return; + } + populateShadowStack(*CB, TLI, Map); + return; + } + + if (auto *RetInst = dyn_cast<ReturnInst>(&Inst)) { + if (!ClCheckRet) + return; + + Value *RV = RetInst->getReturnValue(); + if (RV == nullptr) + return; // This is a `ret void`. + Type *VT = RV->getType(); + Type *ExtendedVT = Config.getExtendedFPType(VT); + if (ExtendedVT == nullptr) + return; // Not an FT ret. + Value *RVShadow = Map.getShadow(RV); + IRBuilder<> Builder(RetInst); + + RVShadow = emitCheck(RV, RVShadow, Builder, CheckLoc::makeRet()); + ++NumInstrumentedFTRets; + // Store tag. + Value *FnAddr = + Builder.CreatePtrToInt(Inst.getParent()->getParent(), IntptrTy); + Builder.CreateStore(FnAddr, NsanShadowRetTag); + // Store value. + Value *ShadowRetValPtr = + Builder.CreateConstGEP2_64(NsanShadowRetType, NsanShadowRetPtr, 0, 0); + Builder.CreateStore(RVShadow, ShadowRetValPtr); + return; + } + + if (InsertValueInst *Insert = dyn_cast<InsertValueInst>(&Inst)) { + Value *V = Insert->getOperand(1); + Type *VT = V->getType(); + Type *ExtendedVT = Config.getExtendedFPType(VT); + if (ExtendedVT == nullptr) + return; + IRBuilder<> Builder(Insert); + emitCheck(V, Map.getShadow(V), Builder, CheckLoc::makeInsert()); + return; + } +} + +// Moves fast math flags from the function to individual instructions, and +// removes the attribute from the function. +// TODO: Make this controllable with a flag. +static void moveFastMathFlags(Function &F, + std::vector<Instruction *> &Instructions) { + FastMathFlags FMF; +#define MOVE_FLAG(attr, setter) \ + if (F.getFnAttribute(attr).getValueAsString() == "true") { \ + F.removeFnAttr(attr); \ + FMF.set##setter(); \ + } + MOVE_FLAG("unsafe-fp-math", Fast) + MOVE_FLAG("no-infs-fp-math", NoInfs) + MOVE_FLAG("no-nans-fp-math", NoNaNs) + MOVE_FLAG("no-signed-zeros-fp-math", NoSignedZeros) +#undef MOVE_FLAG + + for (Instruction *I : Instructions) + if (isa<FPMathOperator>(I)) + I->setFastMathFlags(FMF); +} + +bool NumericalStabilitySanitizer::sanitizeFunction( + Function &F, const TargetLibraryInfo &TLI) { + if (!F.hasFnAttribute(Attribute::SanitizeNumericalStability)) + return false; + + // This is required to prevent instrumenting call to __nsan_init from within + // the module constructor. + if (F.getName() == kNsanModuleCtorName) + return false; + SmallVector<Instruction *, 8> AllLoadsAndStores; + SmallVector<Instruction *, 8> LocalLoadsAndStores; + + // The instrumentation maintains: + // - for each IR value `v` of floating-point (or vector floating-point) type + // FT, a shadow IR value `s(v)` with twice the precision 2FT (e.g. + // double for float and f128 for double). 
+  //  - A shadow memory, which stores `s(v)` for any `v` that has been stored,
+  //    along with a shadow memory tag, which stores whether the value in the
+  //    corresponding shadow memory is valid. Note that this might be
+  //    incorrect if a non-instrumented function stores to memory, or if
+  //    memory is stored to through a char pointer.
+  //  - A shadow stack, which holds `s(v)` for any floating-point argument `v`
+  //    of a call to an instrumented function. This allows
+  //    instrumented functions to retrieve the shadow values for their
+  //    arguments.
+  //    Because instrumented functions can be called from non-instrumented
+  //    functions, the stack needs to include a tag so that the instrumented
+  //    function knows whether shadow values are available for its
+  //    parameters (i.e. whether it was called by an instrumented function).
+  //    When shadow arguments are not available, they have to be recreated by
+  //    extending the precision of the non-shadow arguments to the non-shadow
+  //    value. Non-instrumented functions do not modify (or even know about)
+  //    the shadow stack. The shadow stack pointer is __nsan_shadow_args. The
+  //    shadow stack tag is __nsan_shadow_args_tag. The tag is any unique
+  //    identifier for the function (we use the address of the function).
+  //    Both variables are thread local.
+  //    Example:
+  //      calls                              shadow stack tag     shadow stack
+  //      =====================================================================
+  //      non_instrumented_1()               0                    0
+  //        |
+  //        v
+  //      instrumented_2(float a)            0                    0
+  //        |
+  //        v
+  //      instrumented_3(float b, double c)  &instrumented_3      s(b),s(c)
+  //        |
+  //        v
+  //      instrumented_4(float d)            &instrumented_4      s(d)
+  //        |
+  //        v
+  //      non_instrumented_5(float e)        &non_instrumented_5  s(e)
+  //        |
+  //        v
+  //      instrumented_6(float f)            &non_instrumented_5  s(e)
+  //
+  //    On entry, instrumented_2 checks whether the tag corresponds to its
+  //    function ptr.
+  //    Note that functions reset the tag to 0 after reading shadow parameters.
+  //    This ensures that the function does not erroneously read invalid data
+  //    if called twice in the same stack, once from an instrumented function
+  //    and once from an uninstrumented one. For example, in the following
+  //    call chain, resetting the tag in (A) ensures that (B) does not reuse
+  //    the same shadow arguments (which would be incorrect).
+  //      instrumented_1(float a)
+  //        |
+  //        v
+  //      instrumented_2(float b)   (A)
+  //        |
+  //        v
+  //      non_instrumented_3()
+  //        |
+  //        v
+  //      instrumented_2(float b)   (B)
+  //
+  //  - A shadow return slot. Any function that returns a floating-point value
+  //    places a shadow return value in __nsan_shadow_ret_val. Again, because
+  //    we might be calling non-instrumented functions, this value is guarded
+  //    by the __nsan_shadow_ret_tag marker indicating which instrumented
+  //    function placed the value in __nsan_shadow_ret_val, so that the caller
+  //    can check that this corresponds to the callee. Both variables are
+  //    thread local.
+  //
+  //    For example, in the following call chain, the instrumentation in
+  //    `instrumented_1` rejects the shadow return value from `instrumented_3`
+  //    because it is not tagged as expected (`&instrumented_3` instead of
+  //    `non_instrumented_2`):
+  //
+  //      instrumented_1()
+  //        |
+  //        v
+  //      float non_instrumented_2()
+  //        |
+  //        v
+  //      float instrumented_3()
+  //
+  // Calls of known math functions (sin, cos, exp, ...) are duplicated to call
+  // their overload on the shadow type.
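+  //
+  // In pseudo-code, the shadow-argument handling described above looks
+  // roughly as follows (illustration only; the actual IR for this is emitted
+  // by createShadowArguments):
+  //
+  //   if (__nsan_shadow_args_tag == (uintptr_t)&this_function)
+  //     s(arg) = next slot from __nsan_shadow_args  // caller was instrumented
+  //   else
+  //     s(arg) = fpext(arg)                         // re-extend the app value
+  //   __nsan_shadow_args_tag = 0                    // reset, see (A)/(B) above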
+ + // Collect all instructions before processing, as creating shadow values + // creates new instructions inside the function. + std::vector<Instruction *> OriginalInstructions; + for (BasicBlock &BB : F) + for (Instruction &Inst : BB) + OriginalInstructions.emplace_back(&Inst); + + moveFastMathFlags(F, OriginalInstructions); + ValueToShadowMap ValueToShadow(Config); + + // In the first pass, we create shadow values for all FT function arguments + // and all phis. This ensures that the DFS of the next pass does not have + // any loops. + std::vector<PHINode *> OriginalPhis; + createShadowArguments(F, TLI, ValueToShadow); + for (Instruction *I : OriginalInstructions) { + if (PHINode *Phi = dyn_cast<PHINode>(I)) { + if (PHINode *Shadow = maybeCreateShadowPhi(*Phi, TLI)) { + OriginalPhis.push_back(Phi); + ValueToShadow.setShadow(*Phi, *Shadow); + } + } + } + + // Create shadow values for all instructions creating FT values. + for (Instruction *I : OriginalInstructions) + maybeCreateShadowValue(*I, TLI, ValueToShadow); + + // Propagate shadow values across stores, calls and rets. + for (Instruction *I : OriginalInstructions) + propagateShadowValues(*I, TLI, ValueToShadow); + + // The last pass populates shadow phis with shadow values. + for (PHINode *Phi : OriginalPhis) { + PHINode *ShadowPhi = dyn_cast<PHINode>(ValueToShadow.getShadow(Phi)); + for (unsigned I : seq(Phi->getNumOperands())) { + Value *V = Phi->getOperand(I); + Value *Shadow = ValueToShadow.getShadow(V); + BasicBlock *IncomingBB = Phi->getIncomingBlock(I); + // For some instructions (e.g. invoke), we create the shadow in a separate + // block, different from the block where the original value is created. + // In that case, the shadow phi might need to refer to this block instead + // of the original block. + // Note that this can only happen for instructions as constant shadows are + // always created in the same block. + ShadowPhi->addIncoming(Shadow, IncomingBB); + } + } + + return !ValueToShadow.empty(); +} + +// Instrument the memory intrinsics so that they properly modify the shadow +// memory. 
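+// Concretely (see the body below):
+//   memset(p, v, n)           -> NsanSetValueUnknown(p, n)
+//   memcpy/memmove(d, s, n)   -> NsanCopyValues(d, s, n)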
+bool NumericalStabilitySanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) {
+  IRBuilder<> Builder(MI);
+  if (auto *M = dyn_cast<MemSetInst>(MI)) {
+    Builder.CreateCall(
+        NsanSetValueUnknown,
+        {/*Address=*/M->getArgOperand(0),
+         /*Size=*/Builder.CreateIntCast(M->getArgOperand(2), IntptrTy, false)});
+  } else if (auto *M = dyn_cast<MemTransferInst>(MI)) {
+    Builder.CreateCall(
+        NsanCopyValues,
+        {/*Destination=*/M->getArgOperand(0),
+         /*Source=*/M->getArgOperand(1),
+         /*Size=*/Builder.CreateIntCast(M->getArgOperand(2), IntptrTy, false)});
+  }
+  return false;
+}
+
+void NumericalStabilitySanitizer::maybeAddSuffixForNsanInterface(CallBase *CI) {
+  Function *Fn = CI->getCalledFunction();
+  if (Fn == nullptr)
+    return;
+
+  if (!Fn->getName().starts_with("__nsan_"))
+    return;
+
+  if (Fn->getName() == "__nsan_dump_shadow_mem") {
+    assert(CI->arg_size() == 4 &&
+           "invalid prototype for __nsan_dump_shadow_mem");
+    // __nsan_dump_shadow_mem requires an extra parameter with the dynamic
+    // configuration:
+    // (shadow_type_id_for_long_double << 16) | (shadow_type_id_for_double << 8)
+    // | shadow_type_id_for_float
+    const uint64_t shadow_value_type_ids =
+        (static_cast<size_t>(Config.byValueType(kLongDouble).getNsanTypeId())
+         << 16) |
+        (static_cast<size_t>(Config.byValueType(kDouble).getNsanTypeId())
+         << 8) |
+        static_cast<size_t>(Config.byValueType(kFloat).getNsanTypeId());
+    CI->setArgOperand(3, ConstantInt::get(IntptrTy, shadow_value_type_ids));
+  }
+}
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
new file mode 100644
index 000000000000..de1d4d2381c0
--- /dev/null
+++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
@@ -0,0 +1,351 @@
+//===- PGOCtxProfLowering.cpp - Contextual PGO Instr. Lowering ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+
+#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/Analysis.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Support/CommandLine.h"
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "ctx-instr-lower"
+
+static cl::list<std::string> ContextRoots(
+    "profile-context-root", cl::Hidden,
+    cl::desc(
+        "A function name, assumed to be global, which will be treated as the "
+        "root of an interesting graph, which will be profiled independently "
+        "from other similar graphs."));
+
+bool PGOCtxProfLoweringPass::isContextualIRPGOEnabled() {
+  return !ContextRoots.empty();
+}
+
+// The names of symbols we expect in compiler-rt. Using a namespace for
+// readability.
+namespace CompilerRtAPINames { +static auto StartCtx = "__llvm_ctx_profile_start_context"; +static auto ReleaseCtx = "__llvm_ctx_profile_release_context"; +static auto GetCtx = "__llvm_ctx_profile_get_context"; +static auto ExpectedCalleeTLS = "__llvm_ctx_profile_expected_callee"; +static auto CallsiteTLS = "__llvm_ctx_profile_callsite"; +} // namespace CompilerRtAPINames + +namespace { +// The lowering logic and state. +class CtxInstrumentationLowerer final { + Module &M; + ModuleAnalysisManager &MAM; + Type *ContextNodeTy = nullptr; + Type *ContextRootTy = nullptr; + + DenseMap<const Function *, Constant *> ContextRootMap; + Function *StartCtx = nullptr; + Function *GetCtx = nullptr; + Function *ReleaseCtx = nullptr; + GlobalVariable *ExpectedCalleeTLS = nullptr; + GlobalVariable *CallsiteInfoTLS = nullptr; + +public: + CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM); + // return true if lowering happened (i.e. a change was made) + bool lowerFunction(Function &F); +}; + +// llvm.instrprof.increment[.step] captures the total number of counters as one +// of its parameters, and llvm.instrprof.callsite captures the total number of +// callsites. Those values are the same for instances of those intrinsics in +// this function. Find the first instance of each and return them. +std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) { + uint32_t NrCounters = 0; + uint32_t NrCallsites = 0; + for (const auto &BB : F) { + for (const auto &I : BB) { + if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) { + uint32_t V = + static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue()); + assert((!NrCounters || V == NrCounters) && + "expected all llvm.instrprof.increment[.step] intrinsics to " + "have the same total nr of counters parameter"); + NrCounters = V; + } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) { + uint32_t V = + static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue()); + assert((!NrCallsites || V == NrCallsites) && + "expected all llvm.instrprof.callsite intrinsics to have the " + "same total nr of callsites parameter"); + NrCallsites = V; + } +#if NDEBUG + if (NrCounters && NrCallsites) + return std::make_pair(NrCounters, NrCallsites); +#endif + } + } + return {NrCounters, NrCallsites}; +} +} // namespace + +// set up tie-in with compiler-rt. +// NOTE!!! +// These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h +CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M, + ModuleAnalysisManager &MAM) + : M(M), MAM(MAM) { + auto *PointerTy = PointerType::get(M.getContext(), 0); + auto *SanitizerMutexType = Type::getInt8Ty(M.getContext()); + auto *I32Ty = Type::getInt32Ty(M.getContext()); + auto *I64Ty = Type::getInt64Ty(M.getContext()); + + // The ContextRoot type + ContextRootTy = + StructType::get(M.getContext(), { + PointerTy, /*FirstNode*/ + PointerTy, /*FirstMemBlock*/ + PointerTy, /*CurrentMem*/ + SanitizerMutexType, /*Taken*/ + }); + // The Context header. + ContextNodeTy = StructType::get(M.getContext(), { + I64Ty, /*Guid*/ + PointerTy, /*Next*/ + I32Ty, /*NrCounters*/ + I32Ty, /*NrCallsites*/ + }); + + // Define a global for each entrypoint. We'll reuse the entrypoint's name as + // prefix. We assume the entrypoint names to be unique. 
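+  // For example, `-profile-context-root=foo` (an arbitrary root name) results
+  // in a zero-initialized global `foo_ctx_root` of type ContextRoot, which
+  // lowerFunction later passes to __llvm_ctx_profile_start_context on entry
+  // to foo.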
+ for (const auto &Fname : ContextRoots) { + if (const auto *F = M.getFunction(Fname)) { + if (F->isDeclaration()) + continue; + auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy); + cast<GlobalVariable>(G)->setInitializer( + Constant::getNullValue(ContextRootTy)); + ContextRootMap.insert(std::make_pair(F, G)); + for (const auto &BB : *F) + for (const auto &I : BB) + if (const auto *CB = dyn_cast<CallBase>(&I)) + if (CB->isMustTailCall()) { + M.getContext().emitError( + "The function " + Fname + + " was indicated as a context root, but it features musttail " + "calls, which is not supported."); + } + } + } + + // Declare the functions we will call. + StartCtx = cast<Function>( + M.getOrInsertFunction( + CompilerRtAPINames::StartCtx, + FunctionType::get(ContextNodeTy->getPointerTo(), + {ContextRootTy->getPointerTo(), /*ContextRoot*/ + I64Ty, /*Guid*/ I32Ty, + /*NrCounters*/ I32Ty /*NrCallsites*/}, + false)) + .getCallee()); + GetCtx = cast<Function>( + M.getOrInsertFunction(CompilerRtAPINames::GetCtx, + FunctionType::get(ContextNodeTy->getPointerTo(), + {PointerTy, /*Callee*/ + I64Ty, /*Guid*/ + I32Ty, /*NrCounters*/ + I32Ty}, /*NrCallsites*/ + false)) + .getCallee()); + ReleaseCtx = cast<Function>( + M.getOrInsertFunction( + CompilerRtAPINames::ReleaseCtx, + FunctionType::get(Type::getVoidTy(M.getContext()), + { + ContextRootTy->getPointerTo(), /*ContextRoot*/ + }, + false)) + .getCallee()); + + // Declare the TLSes we will need to use. + CallsiteInfoTLS = + new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage, + nullptr, CompilerRtAPINames::CallsiteTLS); + CallsiteInfoTLS->setThreadLocal(true); + CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility); + ExpectedCalleeTLS = + new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage, + nullptr, CompilerRtAPINames::ExpectedCalleeTLS); + ExpectedCalleeTLS->setThreadLocal(true); + ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility); +} + +PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M, + ModuleAnalysisManager &MAM) { + CtxInstrumentationLowerer Lowerer(M, MAM); + bool Changed = false; + for (auto &F : M) + Changed |= Lowerer.lowerFunction(F); + return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} + +bool CtxInstrumentationLowerer::lowerFunction(Function &F) { + if (F.isDeclaration()) + return false; + auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); + + Value *Guid = nullptr; + auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(F); + + Value *Context = nullptr; + Value *RealContext = nullptr; + + StructType *ThisContextType = nullptr; + Value *TheRootContext = nullptr; + Value *ExpectedCalleeTLSAddr = nullptr; + Value *CallsiteInfoTLSAddr = nullptr; + + auto &Head = F.getEntryBlock(); + for (auto &I : Head) { + // Find the increment intrinsic in the entry basic block. + if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) { + assert(Mark->getIndex()->isZero()); + + IRBuilder<> Builder(Mark); + // FIXME(mtrofin): use InstrProfSymtab::getCanonicalName + Guid = Builder.getInt64(F.getGUID()); + // The type of the context of this function is now knowable since we have + // NrCallsites and NrCounters. We delcare it here because it's more + // convenient - we have the Builder. 
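+      // The layout below is, roughly:
+      //   { ContextNode header, i64 Counters[NrCounters], ptr Callsites[NrCallsites] }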
+ ThisContextType = StructType::get( + F.getContext(), + {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters), + ArrayType::get(Builder.getPtrTy(), NrCallsites)}); + // Figure out which way we obtain the context object for this function - + // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the + // former case, we also set TheRootContext since we need to release it + // at the end (plus it can be used to know if we have an entrypoint or a + // regular function) + auto Iter = ContextRootMap.find(&F); + if (Iter != ContextRootMap.end()) { + TheRootContext = Iter->second; + Context = Builder.CreateCall(StartCtx, {TheRootContext, Guid, + Builder.getInt32(NrCounters), + Builder.getInt32(NrCallsites)}); + ORE.emit( + [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); }); + } else { + Context = + Builder.CreateCall(GetCtx, {&F, Guid, Builder.getInt32(NrCounters), + Builder.getInt32(NrCallsites)}); + ORE.emit([&] { + return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F); + }); + } + // The context could be scratch. + auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty()); + if (NrCallsites > 0) { + // Figure out which index of the TLS 2-element buffers to use. + // Scratch context => we use index == 1. Real contexts => index == 0. + auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1)); + // The GEPs corresponding to that index, in the respective TLS. + ExpectedCalleeTLSAddr = Builder.CreateGEP( + Builder.getInt8Ty()->getPointerTo(), + Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index}); + CallsiteInfoTLSAddr = Builder.CreateGEP( + Builder.getInt32Ty(), + Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index}); + } + // Because the context pointer may have LSB set (to indicate scratch), + // clear it for the value we use as base address for the counter vector. + // This way, if later we want to have "real" (not clobbered) buffers + // acting as scratch, the lowering (at least this part of it that deals + // with counters) stays the same. + RealContext = Builder.CreateIntToPtr( + Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)), + ThisContextType->getPointerTo()); + I.eraseFromParent(); + break; + } + } + if (!Context) { + ORE.emit([&] { + return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F) + << "Function doesn't have instrumentation, skipping"; + }); + return false; + } + + bool ContextWasReleased = false; + for (auto &BB : F) { + for (auto &I : llvm::make_early_inc_range(BB)) { + if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) { + IRBuilder<> Builder(Instr); + switch (Instr->getIntrinsicID()) { + case llvm::Intrinsic::instrprof_increment: + case llvm::Intrinsic::instrprof_increment_step: { + // Increments (or increment-steps) are just a typical load - increment + // - store in the RealContext. + auto *AsStep = cast<InstrProfIncrementInst>(Instr); + auto *GEP = Builder.CreateGEP( + ThisContextType, RealContext, + {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()}); + Builder.CreateStore( + Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP), + AsStep->getStep()), + GEP); + } break; + case llvm::Intrinsic::instrprof_callsite: + // callsite lowering: write the called value in the expected callee + // TLS we treat the TLS as volatile because of signal handlers and to + // avoid these being moved away from the callsite they decorate. 
+ auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr); + Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr, + true); + // write the GEP of the slot in the sub-contexts portion of the + // context in TLS. Now, here, we use the actual Context value - as + // returned from compiler-rt - which may have the LSB set if the + // Context was scratch. Since the header of the context object and + // then the values are all 8-aligned (or, really, insofar as we care, + // they are even) - if the context is scratch (meaning, an odd value), + // so will the GEP. This is important because this is then visible to + // compiler-rt which will produce scratch contexts for callers that + // have a scratch context. + Builder.CreateStore( + Builder.CreateGEP(ThisContextType, Context, + {Builder.getInt32(0), Builder.getInt32(2), + CSIntrinsic->getIndex()}), + CallsiteInfoTLSAddr, true); + break; + } + I.eraseFromParent(); + } else if (TheRootContext && isa<ReturnInst>(I)) { + // Remember to release the context if we are an entrypoint. + IRBuilder<> Builder(&I); + Builder.CreateCall(ReleaseCtx, {TheRootContext}); + ContextWasReleased = true; + } + } + } + // FIXME: This would happen if the entrypoint tailcalls. A way to fix would be + // to disallow this, (so this then stays as an error), another is to detect + // that and then do a wrapper or disallow the tail call. This only affects + // instrumentation, when we want to detect the call graph. + if (TheRootContext && !ContextWasReleased) + F.getContext().emitError( + "[ctx_prof] An entrypoint was instrumented but it has no `ret` " + "instructions above which to release the context: " + + F.getName()); + return true; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOForceFunctionAttrs.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOForceFunctionAttrs.cpp new file mode 100644 index 000000000000..450c191a896d --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOForceFunctionAttrs.cpp @@ -0,0 +1,65 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Instrumentation/PGOForceFunctionAttrs.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace llvm; + +static bool shouldRunOnFunction(Function &F, ProfileSummaryInfo &PSI, + FunctionAnalysisManager &FAM) { + if (F.isDeclaration()) + return false; + // Respect existing attributes. 
+ if (F.hasOptNone() || F.hasOptSize() || F.hasMinSize()) + return false; + if (F.hasFnAttribute(Attribute::Cold)) + return true; + if (!PSI.hasProfileSummary()) + return false; + BlockFrequencyInfo &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); + return PSI.isFunctionColdInCallGraph(&F, BFI); +} + +PreservedAnalyses PGOForceFunctionAttrsPass::run(Module &M, + ModuleAnalysisManager &AM) { + if (ColdType == PGOOptions::ColdFuncOpt::Default) + return PreservedAnalyses::all(); + ProfileSummaryInfo &PSI = AM.getResult<ProfileSummaryAnalysis>(M); + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + bool MadeChange = false; + for (Function &F : M) { + if (!shouldRunOnFunction(F, PSI, FAM)) + continue; + switch (ColdType) { + case PGOOptions::ColdFuncOpt::Default: + llvm_unreachable("bailed out for default above"); + break; + case PGOOptions::ColdFuncOpt::OptSize: + F.addFnAttr(Attribute::OptimizeForSize); + break; + case PGOOptions::ColdFuncOpt::MinSize: + F.addFnAttr(Attribute::MinSize); + break; + case PGOOptions::ColdFuncOpt::OptNone: + // alwaysinline is incompatible with optnone. + if (F.hasFnAttribute(Attribute::AlwaysInline)) + continue; + F.addFnAttr(Attribute::OptimizeNone); + F.addFnAttr(Attribute::NoInline); + break; + } + MadeChange = true; + } + return MadeChange ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index c20fc942eaf0..4924d5a31747 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -110,6 +110,7 @@ #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Instrumentation/BlockCoverageInference.h" #include "llvm/Transforms/Instrumentation/CFGMST.h" +#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/MisExpect.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -119,6 +120,7 @@ #include <memory> #include <numeric> #include <optional> +#include <stack> #include <string> #include <unordered_map> #include <utility> @@ -318,6 +320,8 @@ static cl::opt<unsigned> PGOFunctionCriticalEdgeThreshold( cl::desc("Do not instrument functions with the number of critical edges " " greater than this threshold.")); +extern cl::opt<unsigned> MaxNumVTableAnnotations; + namespace llvm { // Command line option to turn on CFG dot dump after profile annotation. // Defined in Analysis/BlockFrequencyInfo.cpp: -pgo-view-counts @@ -327,9 +331,27 @@ extern cl::opt<PGOViewCountsType> PGOViewCounts; // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name= extern cl::opt<std::string> ViewBlockFreqFuncName; +// Command line option to enable vtable value profiling. Defined in +// ProfileData/InstrProf.cpp: -enable-vtable-value-profiling= +extern cl::opt<bool> EnableVTableValueProfiling; +extern cl::opt<bool> EnableVTableProfileUse; extern cl::opt<InstrProfCorrelator::ProfCorrelatorKind> ProfileCorrelate; } // namespace llvm +bool shouldInstrumentEntryBB() { + return PGOInstrumentEntry || + PGOCtxProfLoweringPass::isContextualIRPGOEnabled(); +} + +// FIXME(mtrofin): re-enable this for ctx profiling, for non-indirect calls. Ctx +// profiling implicitly captures indirect call cases, but not other values. 
+// Supporting other values is relatively straight-forward - just another counter +// range within the context. +bool isValueProfilingDisabled() { + return DisableValueProfiling || + PGOCtxProfLoweringPass::isContextualIRPGOEnabled(); +} + // Return a string describing the branch condition that can be // used in static branch probability heuristics: static std::string getBranchCondString(Instruction *TI) { @@ -376,7 +398,7 @@ static GlobalVariable *createIRLevelProfileFlagVar(Module &M, bool IsCS) { uint64_t ProfileVersion = (INSTR_PROF_RAW_VERSION | VARIANT_MASK_IR_PROF); if (IsCS) ProfileVersion |= VARIANT_MASK_CSIR_PROF; - if (PGOInstrumentEntry) + if (shouldInstrumentEntryBB()) ProfileVersion |= VARIANT_MASK_INSTR_ENTRY; if (DebugInfoCorrelate || ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO) ProfileVersion |= VARIANT_MASK_DBG_CORRELATE; @@ -581,6 +603,8 @@ public: NumOfPGOMemIntrinsics += ValueSites[IPVK_MemOPSize].size(); NumOfPGOBB += MST.bbInfoSize(); ValueSites[IPVK_IndirectCallTarget] = VPC.get(IPVK_IndirectCallTarget); + if (EnableVTableValueProfiling) + ValueSites[IPVK_VTableTarget] = VPC.get(IPVK_VTableTarget); } else { NumOfCSPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); NumOfCSPGOMemIntrinsics += ValueSites[IPVK_MemOPSize].size(); @@ -856,7 +880,7 @@ static void instrumentOneFunc( } FuncPGOInstrumentation<PGOEdge, PGOBBInfo> FuncInfo( - F, TLI, ComdatMembers, true, BPI, BFI, IsCS, PGOInstrumentEntry, + F, TLI, ComdatMembers, true, BPI, BFI, IsCS, shouldInstrumentEntryBB(), PGOBlockCoverage); auto Name = FuncInfo.FuncNameVar; @@ -878,6 +902,43 @@ static void instrumentOneFunc( unsigned NumCounters = InstrumentBBs.size() + FuncInfo.SIVisitor.getNumOfSelectInsts(); + if (PGOCtxProfLoweringPass::isContextualIRPGOEnabled()) { + auto *CSIntrinsic = + Intrinsic::getDeclaration(M, Intrinsic::instrprof_callsite); + // We want to count the instrumentable callsites, then instrument them. This + // is because the llvm.instrprof.callsite intrinsic has an argument (like + // the other instrprof intrinsics) capturing the total number of + // instrumented objects (counters, or callsites, in this case). In this + // case, we want that value so we can readily pass it to the compiler-rt + // APIs that may have to allocate memory based on the nr of callsites. + // The traversal logic is the same for both counting and instrumentation, + // just needs to be done in succession. + auto Visit = [&](llvm::function_ref<void(CallBase * CB)> Visitor) { + for (auto &BB : F) + for (auto &Instr : BB) + if (auto *CS = dyn_cast<CallBase>(&Instr)) { + if ((CS->getCalledFunction() && + CS->getCalledFunction()->isIntrinsic()) || + dyn_cast<InlineAsm>(CS->getCalledOperand())) + continue; + Visitor(CS); + } + }; + // First, count callsites. + uint32_t TotalNrCallsites = 0; + Visit([&TotalNrCallsites](auto *) { ++TotalNrCallsites; }); + + // Now instrument. + uint32_t CallsiteIndex = 0; + Visit([&](auto *CB) { + IRBuilder<> Builder(CB); + Builder.CreateCall(CSIntrinsic, + {Name, CFGHash, Builder.getInt32(TotalNrCallsites), + Builder.getInt32(CallsiteIndex++), + CB->getCalledOperand()}); + }); + } + uint32_t I = 0; if (PGOTemporalInstrumentation) { NumCounters += PGOBlockCoverage ? 
8 : 1; @@ -909,7 +970,7 @@ static void instrumentOneFunc( FuncInfo.FunctionHash); assert(I == NumCounters); - if (DisableValueProfiling) + if (isValueProfilingDisabled()) return; NumOfPGOICall += FuncInfo.ValueSites[IPVK_IndirectCallTarget].size(); @@ -920,7 +981,7 @@ static void instrumentOneFunc( // on the instrumentation call based on the funclet coloring. DenseMap<BasicBlock *, ColorVector> BlockColors; if (F.hasPersonalityFn() && - isFuncletEHPersonality(classifyEHPersonality(F.getPersonalityFn()))) + isScopedEHPersonality(classifyEHPersonality(F.getPersonalityFn()))) BlockColors = colorEHFunclets(F); // For each VP Kind, walk the VP candidates and instrument each one. @@ -961,21 +1022,16 @@ namespace { struct PGOUseEdge : public PGOEdge { using PGOEdge::PGOEdge; - bool CountValid = false; - uint64_t CountValue = 0; + std::optional<uint64_t> Count; // Set edge count value - void setEdgeCount(uint64_t Value) { - CountValue = Value; - CountValid = true; - } + void setEdgeCount(uint64_t Value) { Count = Value; } // Return the information string for this object. std::string infoString() const { - if (!CountValid) + if (!Count) return PGOEdge::infoString(); - return (Twine(PGOEdge::infoString()) + " Count=" + Twine(CountValue)) - .str(); + return (Twine(PGOEdge::infoString()) + " Count=" + Twine(*Count)).str(); } }; @@ -983,27 +1039,22 @@ using DirectEdges = SmallVector<PGOUseEdge *, 2>; // This class stores the auxiliary information for each BB. struct PGOUseBBInfo : public PGOBBInfo { - uint64_t CountValue = 0; - bool CountValid; + std::optional<uint64_t> Count; int32_t UnknownCountInEdge = 0; int32_t UnknownCountOutEdge = 0; DirectEdges InEdges; DirectEdges OutEdges; - PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX), CountValid(false) {} + PGOUseBBInfo(unsigned IX) : PGOBBInfo(IX) {} // Set the profile count value for this BB. - void setBBInfoCount(uint64_t Value) { - CountValue = Value; - CountValid = true; - } + void setBBInfoCount(uint64_t Value) { Count = Value; } // Return the information string of this object. std::string infoString() const { - if (!CountValid) + if (!Count) return PGOBBInfo::infoString(); - return (Twine(PGOBBInfo::infoString()) + " Count=" + Twine(CountValue)) - .str(); + return (Twine(PGOBBInfo::infoString()) + " Count=" + Twine(*Count)).str(); } // Add an OutEdge and update the edge count. @@ -1027,7 +1078,8 @@ static uint64_t sumEdgeCount(const ArrayRef<PGOUseEdge *> Edges) { for (const auto &E : Edges) { if (E->Removed) continue; - Total += E->CountValue; + if (E->Count) + Total += *E->Count; } return Total; } @@ -1044,7 +1096,7 @@ public: : F(Func), M(Modu), BFI(BFIin), PSI(PSI), FuncInfo(Func, TLI, ComdatMembers, false, BPI, BFIin, IsCS, InstrumentFuncEntry, HasSingleByteCoverage), - FreqAttr(FFA_Normal), IsCS(IsCS) {} + FreqAttr(FFA_Normal), IsCS(IsCS), VPC(Func, TLI) {} void handleInstrProfError(Error Err, uint64_t MismatchedFuncSum); @@ -1126,6 +1178,8 @@ private: // Is to use the context sensitive profile. bool IsCS; + ValueProfileCollector VPC; + // Find the Instrumented BB and set the value. Return false on error. bool setInstrumentedCounts(const std::vector<uint64_t> &CountFromProfile); @@ -1216,17 +1270,17 @@ bool PGOUseFunc::setInstrumentedCounts( // If only one out-edge, the edge profile count should be the same as BB // profile count. 
-    if (SrcInfo.CountValid && SrcInfo.OutEdges.size() == 1)
-      setEdgeCount(E.get(), SrcInfo.CountValue);
+    if (SrcInfo.Count && SrcInfo.OutEdges.size() == 1)
+      setEdgeCount(E.get(), *SrcInfo.Count);
     else {
       const BasicBlock *DestBB = E->DestBB;
       PGOUseBBInfo &DestInfo = getBBInfo(DestBB);
       // If only one in-edge, the edge profile count should be the same as BB
       // profile count.
-      if (DestInfo.CountValid && DestInfo.InEdges.size() == 1)
-        setEdgeCount(E.get(), DestInfo.CountValue);
+      if (DestInfo.Count && DestInfo.InEdges.size() == 1)
+        setEdgeCount(E.get(), *DestInfo.Count);
     }
-    if (E->CountValid)
+    if (E->Count)
       continue;
     // E's count should have been set from profile. If not, this means E skips
     // the instrumentation. We set the count to 0.
@@ -1239,7 +1293,7 @@ bool PGOUseFunc::setInstrumentedCounts(
 // unknown edge in Edges vector.
 void PGOUseFunc::setEdgeCount(DirectEdges &Edges, uint64_t Value) {
   for (auto &E : Edges) {
-    if (E->CountValid)
+    if (E->Count)
       continue;
     E->setEdgeCount(Value);
@@ -1371,6 +1425,7 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
     handleInstrProfError(std::move(Err), MismatchedFuncSum);
     return;
   }
+  IsCS ? NumOfCSPGOFunc++ : NumOfPGOFunc++;
   std::vector<uint64_t> &CountsFromProfile = Result.get().Counts;
   DenseMap<const BasicBlock *, bool> Coverage;
@@ -1425,7 +1480,8 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
     for (auto *Succ : successors(&BB))
       Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
     if (Weights.size() >= 2)
-      llvm::setBranchWeights(*BB.getTerminator(), Weights);
+      llvm::setBranchWeights(*BB.getTerminator(), Weights,
+                             /*IsExpected=*/false);
   }
   unsigned NumCorruptCoverage = 0;
@@ -1481,38 +1537,36 @@ void PGOUseFunc::populateCounters() {
     // For efficient traversal, it's better to start from the end as most
     // of the instrumented edges are at the end.
     for (auto &BB : reverse(F)) {
-      PGOUseBBInfo *Count = findBBInfo(&BB);
-      if (Count == nullptr)
+      PGOUseBBInfo *UseBBInfo = findBBInfo(&BB);
+      if (UseBBInfo == nullptr)
         continue;
-      if (!Count->CountValid) {
-        if (Count->UnknownCountOutEdge == 0) {
-          Count->CountValue = sumEdgeCount(Count->OutEdges);
-          Count->CountValid = true;
+      if (!UseBBInfo->Count) {
+        if (UseBBInfo->UnknownCountOutEdge == 0) {
+          UseBBInfo->Count = sumEdgeCount(UseBBInfo->OutEdges);
           Changes = true;
-        } else if (Count->UnknownCountInEdge == 0) {
-          Count->CountValue = sumEdgeCount(Count->InEdges);
-          Count->CountValid = true;
+        } else if (UseBBInfo->UnknownCountInEdge == 0) {
+          UseBBInfo->Count = sumEdgeCount(UseBBInfo->InEdges);
           Changes = true;
         }
       }
-      if (Count->CountValid) {
-        if (Count->UnknownCountOutEdge == 1) {
+      if (UseBBInfo->Count) {
+        if (UseBBInfo->UnknownCountOutEdge == 1) {
           uint64_t Total = 0;
-          uint64_t OutSum = sumEdgeCount(Count->OutEdges);
+          uint64_t OutSum = sumEdgeCount(UseBBInfo->OutEdges);
           // If one of the successor blocks can early terminate (no-return),
           // we can end up with a situation where the out-edge sum count is
           // larger, as the source BB's count is collected by a post-dominated
           // block.
- if (Count->CountValue > OutSum) - Total = Count->CountValue - OutSum; - setEdgeCount(Count->OutEdges, Total); + if (*UseBBInfo->Count > OutSum) + Total = *UseBBInfo->Count - OutSum; + setEdgeCount(UseBBInfo->OutEdges, Total); Changes = true; } - if (Count->UnknownCountInEdge == 1) { + if (UseBBInfo->UnknownCountInEdge == 1) { uint64_t Total = 0; - uint64_t InSum = sumEdgeCount(Count->InEdges); - if (Count->CountValue > InSum) - Total = Count->CountValue - InSum; - setEdgeCount(Count->InEdges, Total); + uint64_t InSum = sumEdgeCount(UseBBInfo->InEdges); + if (*UseBBInfo->Count > InSum) + Total = *UseBBInfo->Count - InSum; + setEdgeCount(UseBBInfo->InEdges, Total); Changes = true; } } @@ -1527,16 +1581,16 @@ void PGOUseFunc::populateCounters() { auto BI = findBBInfo(&BB); if (BI == nullptr) continue; - assert(BI->CountValid && "BB count is not valid"); + assert(BI->Count && "BB count is not valid"); } #endif - uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue; + uint64_t FuncEntryCount = *getBBInfo(&*F.begin()).Count; uint64_t FuncMaxCount = FuncEntryCount; for (auto &BB : F) { auto BI = findBBInfo(&BB); if (BI == nullptr) continue; - FuncMaxCount = std::max(FuncMaxCount, BI->CountValue); + FuncMaxCount = std::max(FuncMaxCount, *BI->Count); } // Fix the obviously inconsistent entry count. @@ -1566,22 +1620,28 @@ void PGOUseFunc::setBranchWeights() { isa<CallBrInst>(TI))) continue; - if (getBBInfo(&BB).CountValue == 0) + const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB); + if (!*BBCountInfo.Count) continue; // We have a non-zero Branch BB. - const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB); - unsigned Size = BBCountInfo.OutEdges.size(); - SmallVector<uint64_t, 2> EdgeCounts(Size, 0); + + // SuccessorCount can be greater than OutEdgesCount, because + // removed edges don't appear in OutEdges. + unsigned OutEdgesCount = BBCountInfo.OutEdges.size(); + unsigned SuccessorCount = BB.getTerminator()->getNumSuccessors(); + assert(OutEdgesCount <= SuccessorCount); + + SmallVector<uint64_t, 2> EdgeCounts(SuccessorCount, 0); uint64_t MaxCount = 0; - for (unsigned s = 0; s < Size; s++) { - const PGOUseEdge *E = BBCountInfo.OutEdges[s]; + for (unsigned It = 0; It < OutEdgesCount; It++) { + const PGOUseEdge *E = BBCountInfo.OutEdges[It]; const BasicBlock *SrcBB = E->SrcBB; const BasicBlock *DestBB = E->DestBB; if (DestBB == nullptr) continue; unsigned SuccNum = GetSuccessorNumber(SrcBB, DestBB); - uint64_t EdgeCount = E->CountValue; + uint64_t EdgeCount = *E->Count; if (EdgeCount > MaxCount) MaxCount = EdgeCount; EdgeCounts[SuccNum] = EdgeCount; @@ -1622,7 +1682,7 @@ void PGOUseFunc::annotateIrrLoopHeaderWeights() { if (BFI->isIrrLoopHeader(&BB) || isIndirectBrTarget(&BB)) { Instruction *TI = BB.getTerminator(); const PGOUseBBInfo &BBCountInfo = getBBInfo(&BB); - setIrrLoopHeaderMetadata(M, TI, BBCountInfo.CountValue); + setIrrLoopHeaderMetadata(M, TI, *BBCountInfo.Count); } } } @@ -1649,7 +1709,7 @@ void SelectInstVisitor::annotateOneSelectInst(SelectInst &SI) { uint64_t TotalCount = 0; auto BI = UseFunc->findBBInfo(SI.getParent()); if (BI != nullptr) - TotalCount = BI->CountValue; + TotalCount = *BI->Count; // False Count SCounts[1] = (TotalCount > SCounts[0] ? 
TotalCount - SCounts[0] : 0); uint64_t MaxCount = std::max(SCounts[0], SCounts[1]); @@ -1679,9 +1739,17 @@ void SelectInstVisitor::visitSelectInst(SelectInst &SI) { llvm_unreachable("Unknown visiting mode"); } +static uint32_t getMaxNumAnnotations(InstrProfValueKind ValueProfKind) { + if (ValueProfKind == IPVK_MemOPSize) + return MaxNumMemOPAnnotations; + if (ValueProfKind == llvm::IPVK_VTableTarget) + return MaxNumVTableAnnotations; + return MaxNumAnnotations; +} + // Traverse all valuesites and annotate the instructions for all value kind. void PGOUseFunc::annotateValueSites() { - if (DisableValueProfiling) + if (isValueProfilingDisabled()) return; // Create the PGOFuncName meta data. @@ -1695,8 +1763,23 @@ void PGOUseFunc::annotateValueSites() { void PGOUseFunc::annotateValueSites(uint32_t Kind) { assert(Kind <= IPVK_Last); unsigned ValueSiteIndex = 0; - auto &ValueSites = FuncInfo.ValueSites[Kind]; + unsigned NumValueSites = ProfileRecord.getNumValueSites(Kind); + + // Since there isn't a reliable or fast way for profile reader to tell if a + // profile is generated with `-enable-vtable-value-profiling` on, we run the + // value profile collector over the function IR to find the instrumented sites + // iff function profile records shows the number of instrumented vtable sites + // is not zero. Function cfg already takes the number of instrumented + // indirect call sites into account so it doesn't hash the number of + // instrumented vtables; as a side effect it makes it easier to enable + // profiling and profile use in two steps if needed. + // TODO: Remove this if/when -enable-vtable-value-profiling is on by default. + if (NumValueSites > 0 && Kind == IPVK_VTableTarget && + NumValueSites != FuncInfo.ValueSites[IPVK_VTableTarget].size() && + MaxNumVTableAnnotations != 0) + FuncInfo.ValueSites[IPVK_VTableTarget] = VPC.get(IPVK_VTableTarget); + auto &ValueSites = FuncInfo.ValueSites[Kind]; if (NumValueSites != ValueSites.size()) { auto &Ctx = M->getContext(); Ctx.diagnose(DiagnosticInfoPGOProfile( @@ -1713,10 +1796,10 @@ void PGOUseFunc::annotateValueSites(uint32_t Kind) { LLVM_DEBUG(dbgs() << "Read one value site profile (kind = " << Kind << "): Index = " << ValueSiteIndex << " out of " << NumValueSites << "\n"); - annotateValueSite(*M, *I.AnnotatedInst, ProfileRecord, - static_cast<InstrProfValueKind>(Kind), ValueSiteIndex, - Kind == IPVK_MemOPSize ? MaxNumMemOPAnnotations - : MaxNumAnnotations); + annotateValueSite( + *M, *I.AnnotatedInst, ProfileRecord, + static_cast<InstrProfValueKind>(Kind), ValueSiteIndex, + getMaxNumAnnotations(static_cast<InstrProfValueKind>(Kind))); ValueSiteIndex++; } } @@ -1784,8 +1867,17 @@ static bool InstrumentAllFunctions( function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, bool IsCS) { // For the context-sensitve instrumentation, we should have a separated pass // (before LTO/ThinLTO linking) to create these variables. 
- if (!IsCS) + if (!IsCS && !PGOCtxProfLoweringPass::isContextualIRPGOEnabled()) createIRLevelProfileFlagVar(M, /*IsCS=*/false); + + Triple TT(M.getTargetTriple()); + LLVMContext &Ctx = M.getContext(); + if (!TT.isOSBinFormatELF() && EnableVTableValueProfiling) + Ctx.diagnose(DiagnosticInfoPGOProfile( + M.getName().data(), + Twine("VTable value profiling is presently not " + "supported for non-ELF object formats"), + DS_Warning)); std::unordered_multimap<Comdat *, GlobalValue *> ComdatMembers; collectComdatMembers(M, ComdatMembers); @@ -1806,6 +1898,8 @@ PGOInstrumentationGenCreateVar::run(Module &M, ModuleAnalysisManager &MAM) { // The variable in a comdat may be discarded by LTO. Ensure the declaration // will be retained. appendToCompilerUsed(M, createIRLevelProfileFlagVar(M, /*IsCS=*/true)); + if (ProfileSampling) + createProfileSamplingVar(M); PreservedAnalyses PA; PA.preserve<FunctionAnalysisManagerModuleProxy>(); PA.preserveSet<AllAnalysesOn<Function>>(); @@ -1850,7 +1944,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI, if (!Func.findBBInfo(&BBI)) continue; auto BFICount = NBFI.getBlockProfileCount(&BBI); - CountValue = Func.getBBInfo(&BBI).CountValue; + CountValue = *Func.getBBInfo(&BBI).Count; BFICountValue = *BFICount; SumCount.add(APFloat(CountValue * 1.0), APFloat::rmNearestTiesToEven); SumBFICount.add(APFloat(BFICountValue * 1.0), APFloat::rmNearestTiesToEven); @@ -1866,7 +1960,7 @@ static void fixFuncEntryCount(PGOUseFunc &Func, LoopInfo &LI, if (Scale < 1.001 && Scale > 0.999) return; - uint64_t FuncEntryCount = Func.getBBInfo(&*F.begin()).CountValue; + uint64_t FuncEntryCount = *Func.getBBInfo(&*F.begin()).Count; uint64_t NewEntryCount = 0.5 + FuncEntryCount * Scale; if (NewEntryCount == 0) NewEntryCount = 1; @@ -1896,8 +1990,7 @@ static void verifyFuncBFI(PGOUseFunc &Func, LoopInfo &LI, uint64_t CountValue = 0; uint64_t BFICountValue = 0; - if (Func.getBBInfo(&BBI).CountValid) - CountValue = Func.getBBInfo(&BBI).CountValue; + CountValue = Func.getBBInfo(&BBI).Count.value_or(CountValue); BBNum++; if (CountValue) @@ -1997,6 +2090,16 @@ static bool annotateAllFunctions( return false; } + if (EnableVTableProfileUse) { + for (GlobalVariable &G : M.globals()) { + if (!G.hasName() || !G.hasMetadata(LLVMContext::MD_type)) + continue; + + // Create the PGOFuncName meta data. + createPGONameMetadata(G, getPGOName(G, false /* InLTO*/)); + } + } + // Add the profile summary (read from the header of the indexed summary) here // so that we can use it below when reading counters (which checks if the // function should be marked with a cold or inlinehint attribute). 
@@ -2015,6 +2118,8 @@ static bool annotateAllFunctions( bool InstrumentFuncEntry = PGOReader->instrEntryBBEnabled(); if (PGOInstrumentEntry.getNumOccurrences() > 0) InstrumentFuncEntry = PGOInstrumentEntry; + InstrumentFuncEntry |= PGOCtxProfLoweringPass::isContextualIRPGOEnabled(); + bool HasSingleByteCoverage = PGOReader->hasSingleByteCoverage(); for (auto &F : M) { if (skipPGOUse(F)) @@ -2068,7 +2173,7 @@ static bool annotateAllFunctions( HotFunctions.push_back(&F); if (PGOViewCounts != PGOVCT_None && (ViewBlockFreqFuncName.empty() || - F.getName().equals(ViewBlockFreqFuncName))) { + F.getName() == ViewBlockFreqFuncName)) { LoopInfo LI{DominatorTree(F)}; std::unique_ptr<BranchProbabilityInfo> NewBPI = std::make_unique<BranchProbabilityInfo>(F, LI); @@ -2083,7 +2188,7 @@ static bool annotateAllFunctions( } if (PGOViewRawCounts != PGOVCT_None && (ViewBlockFreqFuncName.empty() || - F.getName().equals(ViewBlockFreqFuncName))) { + F.getName() == ViewBlockFreqFuncName)) { if (PGOViewRawCounts == PGOVCT_Graph) if (ViewBlockFreqFuncName.empty()) WriteGraph(&Func, Twine("PGORawCounts_") + Func.getFunc().getName()); @@ -2170,7 +2275,6 @@ PreservedAnalyses PGOInstrumentationUse::run(Module &M, }; auto *PSI = &MAM.getResult<ProfileSummaryAnalysis>(M); - if (!annotateAllFunctions(M, ProfileFileName, ProfileRemappingFileName, *FS, LookupTLI, LookupBPI, LookupBFI, PSI, IsCS)) return PreservedAnalyses::all(); @@ -2185,7 +2289,7 @@ static std::string getSimpleNodeName(const BasicBlock *Node) { std::string SimpleNodeName; raw_string_ostream OS(SimpleNodeName); Node->printAsOperand(OS, false); - return OS.str(); + return SimpleNodeName; } void llvm::setProfMetadata(Module *M, Instruction *TI, @@ -2203,7 +2307,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI, misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false); - setBranchWeights(*TI, Weights); + setBranchWeights(*TI, Weights, /*IsExpected=*/false); if (EmitBranchProbability) { std::string BrCondStr = getBranchCondString(TI); if (BrCondStr.empty()) @@ -2279,8 +2383,8 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits { OS << getSimpleNodeName(Node) << ":\\l"; PGOUseBBInfo *BI = Graph->findBBInfo(Node); OS << "Count : "; - if (BI && BI->CountValid) - OS << BI->CountValue << "\\l"; + if (BI && BI->Count) + OS << *BI->Count << "\\l"; else OS << "Unknown\\l"; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp index fd0f69eca96e..dc51c564fbe0 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp @@ -177,10 +177,7 @@ public: MemOPSizeOpt(Function &Func, BlockFrequencyInfo &BFI, OptimizationRemarkEmitter &ORE, DominatorTree *DT, TargetLibraryInfo &TLI) - : Func(Func), BFI(BFI), ORE(ORE), DT(DT), TLI(TLI), Changed(false) { - ValueDataArray = - std::make_unique<InstrProfValueData[]>(INSTR_PROF_NUM_BUCKETS); - } + : Func(Func), BFI(BFI), ORE(ORE), DT(DT), TLI(TLI), Changed(false) {} bool isChanged() const { return Changed; } void perform() { WorkList.clear(); @@ -222,8 +219,6 @@ private: TargetLibraryInfo &TLI; bool Changed; std::vector<MemOp> WorkList; - // The space to read the profile annotation. 
- std::unique_ptr<InstrProfValueData[]> ValueDataArray; bool perform(MemOp MO); }; @@ -252,10 +247,11 @@ bool MemOPSizeOpt::perform(MemOp MO) { if (!MemOPOptMemcmpBcmp && (MO.isMemcmp(TLI) || MO.isBcmp(TLI))) return false; - uint32_t NumVals, MaxNumVals = INSTR_PROF_NUM_BUCKETS; + uint32_t MaxNumVals = INSTR_PROF_NUM_BUCKETS; uint64_t TotalCount; - if (!getValueProfDataFromInst(*MO.I, IPVK_MemOPSize, MaxNumVals, - ValueDataArray.get(), NumVals, TotalCount)) + auto VDs = + getValueProfDataFromInst(*MO.I, IPVK_MemOPSize, MaxNumVals, TotalCount); + if (VDs.empty()) return false; uint64_t ActualCount = TotalCount; @@ -267,7 +263,6 @@ bool MemOPSizeOpt::perform(MemOp MO) { ActualCount = *BBEdgeCount; } - ArrayRef<InstrProfValueData> VDs(ValueDataArray.get(), NumVals); LLVM_DEBUG(dbgs() << "Read one memory intrinsic profile with count " << ActualCount << "\n"); LLVM_DEBUG( @@ -400,11 +395,10 @@ bool MemOPSizeOpt::perform(MemOp MO) { // Clear the value profile data. MO.I->setMetadata(LLVMContext::MD_prof, nullptr); // If all promoted, we don't need the MD.prof metadata. - if (SavedRemainCount > 0 || Version != NumVals) { + if (SavedRemainCount > 0 || Version != VDs.size()) { // Otherwise we need update with the un-promoted records back. - ArrayRef<InstrProfValueData> RemVDs(RemainingVDs); - annotateValueSite(*Func.getParent(), *MO.I, RemVDs, SavedRemainCount, - IPVK_MemOPSize, NumVals); + annotateValueSite(*Func.getParent(), *MO.I, RemainingVDs, SavedRemainCount, + IPVK_MemOPSize, VDs.size()); } LLVM_DEBUG(dbgs() << "\n\n== Basic Block After==\n"); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp index 42e7cd80374d..e094acdc3178 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp @@ -62,6 +62,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp index 230bb8b0a5dc..e326f30ad88e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerBinaryMetadata.cpp @@ -14,6 +14,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/CaptureTracking.h" @@ -130,7 +131,7 @@ public: std::unique_ptr<SpecialCaseList> Ignorelist) : Mod(M), Options(transformOptionsFromCl(std::move(Opts))), Ignorelist(std::move(Ignorelist)), TargetTriple(M.getTargetTriple()), - IRB(M.getContext()) { + VersionStr(utostr(getVersion())), IRB(M.getContext()) { // FIXME: Make it work with other formats. assert(TargetTriple.isOSBinFormatELF() && "ELF only"); assert(!(TargetTriple.isNVPTX() || TargetTriple.isAMDGPU()) && @@ -167,10 +168,10 @@ private: StringRef getSectionName(StringRef SectionSuffix); // Returns the section start marker name. - Twine getSectionStart(StringRef SectionSuffix); + StringRef getSectionStart(StringRef SectionSuffix); // Returns the section end marker name. 
- Twine getSectionEnd(StringRef SectionSuffix); + StringRef getSectionEnd(StringRef SectionSuffix); // Returns true if the access to the address should be considered "atomic". bool pretendAtomicAccess(const Value *Addr); @@ -179,6 +180,7 @@ private: const SanitizerBinaryMetadataOptions Options; std::unique_ptr<SpecialCaseList> Ignorelist; const Triple TargetTriple; + const std::string VersionStr; IRBuilder<> IRB; BumpPtrAllocator Alloc; UniqueStringSaver StringPool{Alloc}; @@ -209,19 +211,25 @@ bool SanitizerBinaryMetadata::run() { getSectionMarker(getSectionStart(MI->SectionSuffix), PtrTy), getSectionMarker(getSectionEnd(MI->SectionSuffix), PtrTy), }; + + // Calls to the initialization functions with different versions cannot be + // merged. Give the structors unique names based on the version, which will + // also be used as the COMDAT key. + const std::string StructorPrefix = (MI->FunctionPrefix + VersionStr).str(); + // We declare the _add and _del functions as weak, and only call them if // there is a valid symbol linked. This allows building binaries with // semantic metadata, but without having callbacks. When a tool that wants // the metadata is linked which provides the callbacks, they will be called. Function *Ctor = createSanitizerCtorAndInitFunctions( - Mod, (MI->FunctionPrefix + ".module_ctor").str(), + Mod, StructorPrefix + ".module_ctor", (MI->FunctionPrefix + "_add").str(), InitTypes, InitArgs, /*VersionCheckName=*/StringRef(), /*Weak=*/ClWeakCallbacks) .first; Function *Dtor = createSanitizerCtorAndInitFunctions( - Mod, (MI->FunctionPrefix + ".module_dtor").str(), + Mod, StructorPrefix + ".module_dtor", (MI->FunctionPrefix + "_del").str(), InitTypes, InitArgs, /*VersionCheckName=*/StringRef(), /*Weak=*/ClWeakCallbacks) .first; @@ -454,15 +462,19 @@ SanitizerBinaryMetadata::getSectionMarker(const Twine &MarkerName, Type *Ty) { StringRef SanitizerBinaryMetadata::getSectionName(StringRef SectionSuffix) { // FIXME: Other TargetTriples. // Request ULEB128 encoding for all integer constants. - return StringPool.save(SectionSuffix + "!C"); + return StringPool.save(SectionSuffix + VersionStr + "!C"); } -Twine SanitizerBinaryMetadata::getSectionStart(StringRef SectionSuffix) { - return "__start_" + SectionSuffix; +StringRef SanitizerBinaryMetadata::getSectionStart(StringRef SectionSuffix) { + // Twine only concatenates 2 strings; with >2 strings, concatenating them + // creates Twine temporaries, and returning the final Twine no longer works + // because we'd end up with a stack-use-after-return. So here we also use the + // StringPool to store the new string. 
+ return StringPool.save("__start_" + SectionSuffix + VersionStr); } -Twine SanitizerBinaryMetadata::getSectionEnd(StringRef SectionSuffix) { - return "__stop_" + SectionSuffix; +StringRef SanitizerBinaryMetadata::getSectionEnd(StringRef SectionSuffix) { + return StringPool.save("__stop_" + SectionSuffix + VersionStr); } } // namespace diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index 17c1c4423842..6a89cee9aaf6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -25,6 +25,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/Support/CommandLine.h" @@ -88,15 +89,14 @@ static cl::opt<int> ClCoverageLevel( "sanitizer-coverage-level", cl::desc("Sanitizer Coverage. 0: none, 1: entry block, 2: all blocks, " "3: all blocks and critical edges"), - cl::Hidden, cl::init(0)); + cl::Hidden); static cl::opt<bool> ClTracePC("sanitizer-coverage-trace-pc", - cl::desc("Experimental pc tracing"), cl::Hidden, - cl::init(false)); + cl::desc("Experimental pc tracing"), cl::Hidden); static cl::opt<bool> ClTracePCGuard("sanitizer-coverage-trace-pc-guard", cl::desc("pc tracing with a guard"), - cl::Hidden, cl::init(false)); + cl::Hidden); // If true, we create a global variable that contains PCs of all instrumented // BBs, put this global into a named section, and pass this section's bounds @@ -106,38 +106,38 @@ static cl::opt<bool> ClTracePCGuard("sanitizer-coverage-trace-pc-guard", // inline-bool-flag. 
static cl::opt<bool> ClCreatePCTable("sanitizer-coverage-pc-table", cl::desc("create a static PC table"), - cl::Hidden, cl::init(false)); + cl::Hidden); static cl::opt<bool> ClInline8bitCounters("sanitizer-coverage-inline-8bit-counters", cl::desc("increments 8-bit counter for every edge"), - cl::Hidden, cl::init(false)); + cl::Hidden); static cl::opt<bool> ClInlineBoolFlag("sanitizer-coverage-inline-bool-flag", - cl::desc("sets a boolean flag for every edge"), cl::Hidden, - cl::init(false)); + cl::desc("sets a boolean flag for every edge"), + cl::Hidden); static cl::opt<bool> ClCMPTracing("sanitizer-coverage-trace-compares", cl::desc("Tracing of CMP and similar instructions"), - cl::Hidden, cl::init(false)); + cl::Hidden); static cl::opt<bool> ClDIVTracing("sanitizer-coverage-trace-divs", cl::desc("Tracing of DIV instructions"), - cl::Hidden, cl::init(false)); + cl::Hidden); static cl::opt<bool> ClLoadTracing("sanitizer-coverage-trace-loads", cl::desc("Tracing of load instructions"), - cl::Hidden, cl::init(false)); + cl::Hidden); static cl::opt<bool> ClStoreTracing("sanitizer-coverage-trace-stores", cl::desc("Tracing of store instructions"), - cl::Hidden, cl::init(false)); + cl::Hidden); static cl::opt<bool> ClGEPTracing("sanitizer-coverage-trace-geps", cl::desc("Tracing of GEP instructions"), - cl::Hidden, cl::init(false)); + cl::Hidden); static cl::opt<bool> ClPruneBlocks("sanitizer-coverage-prune-blocks", @@ -146,12 +146,11 @@ static cl::opt<bool> static cl::opt<bool> ClStackDepth("sanitizer-coverage-stack-depth", cl::desc("max stack depth tracing"), - cl::Hidden, cl::init(false)); + cl::Hidden); static cl::opt<bool> ClCollectCF("sanitizer-coverage-control-flow", - cl::desc("collect control flow for each function"), cl::Hidden, - cl::init(false)); + cl::desc("collect control flow for each function"), cl::Hidden); namespace { @@ -203,25 +202,25 @@ SanitizerCoverageOptions OverrideFromCL(SanitizerCoverageOptions Options) { return Options; } -using DomTreeCallback = function_ref<const DominatorTree *(Function &F)>; -using PostDomTreeCallback = - function_ref<const PostDominatorTree *(Function &F)>; - class ModuleSanitizerCoverage { public: - ModuleSanitizerCoverage( - const SanitizerCoverageOptions &Options = SanitizerCoverageOptions(), - const SpecialCaseList *Allowlist = nullptr, - const SpecialCaseList *Blocklist = nullptr) - : Options(OverrideFromCL(Options)), Allowlist(Allowlist), - Blocklist(Blocklist) {} - bool instrumentModule(Module &M, DomTreeCallback DTCallback, - PostDomTreeCallback PDTCallback); + using DomTreeCallback = function_ref<const DominatorTree &(Function &F)>; + using PostDomTreeCallback = + function_ref<const PostDominatorTree &(Function &F)>; + + ModuleSanitizerCoverage(Module &M, DomTreeCallback DTCallback, + PostDomTreeCallback PDTCallback, + const SanitizerCoverageOptions &Options, + const SpecialCaseList *Allowlist, + const SpecialCaseList *Blocklist) + : M(M), DTCallback(DTCallback), PDTCallback(PDTCallback), + Options(Options), Allowlist(Allowlist), Blocklist(Blocklist) {} + + bool instrumentModule(); private: void createFunctionControlFlow(Function &F); - void instrumentFunction(Function &F, DomTreeCallback DTCallback, - PostDomTreeCallback PDTCallback); + void instrumentFunction(Function &F); void InjectCoverageForIndirectCalls(Function &F, ArrayRef<Instruction *> IndirCalls); void InjectTraceForCmp(Function &F, ArrayRef<Instruction *> CmpTraceTargets); @@ -251,6 +250,11 @@ private: std::string getSectionName(const std::string &Section) const; std::string 
getSectionStart(const std::string &Section) const; std::string getSectionEnd(const std::string &Section) const; + + Module &M; + DomTreeCallback DTCallback; + PostDomTreeCallback PDTCallback; + FunctionCallee SanCovTracePCIndir; FunctionCallee SanCovTracePC, SanCovTracePCGuard; std::array<FunctionCallee, 4> SanCovTraceCmpFunction; @@ -285,16 +289,17 @@ private: PreservedAnalyses SanitizerCoveragePass::run(Module &M, ModuleAnalysisManager &MAM) { - ModuleSanitizerCoverage ModuleSancov(Options, Allowlist.get(), - Blocklist.get()); auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); - auto DTCallback = [&FAM](Function &F) -> const DominatorTree * { - return &FAM.getResult<DominatorTreeAnalysis>(F); + auto DTCallback = [&FAM](Function &F) -> const DominatorTree & { + return FAM.getResult<DominatorTreeAnalysis>(F); }; - auto PDTCallback = [&FAM](Function &F) -> const PostDominatorTree * { - return &FAM.getResult<PostDominatorTreeAnalysis>(F); + auto PDTCallback = [&FAM](Function &F) -> const PostDominatorTree & { + return FAM.getResult<PostDominatorTreeAnalysis>(F); }; - if (!ModuleSancov.instrumentModule(M, DTCallback, PDTCallback)) + ModuleSanitizerCoverage ModuleSancov(M, DTCallback, PDTCallback, + OverrideFromCL(Options), Allowlist.get(), + Blocklist.get()); + if (!ModuleSancov.instrumentModule()) return PreservedAnalyses::all(); PreservedAnalyses PA = PreservedAnalyses::none(); @@ -365,8 +370,7 @@ Function *ModuleSanitizerCoverage::CreateInitCallsForSections( return CtorFunc; } -bool ModuleSanitizerCoverage::instrumentModule( - Module &M, DomTreeCallback DTCallback, PostDomTreeCallback PDTCallback) { +bool ModuleSanitizerCoverage::instrumentModule() { if (Options.CoverageType == SanitizerCoverageOptions::SCK_None) return false; if (Allowlist && @@ -479,7 +483,7 @@ bool ModuleSanitizerCoverage::instrumentModule( M.getOrInsertFunction(SanCovTracePCGuardName, VoidTy, PtrTy); for (auto &F : M) - instrumentFunction(F, DTCallback, PDTCallback); + instrumentFunction(F); Function *Ctor = nullptr; @@ -518,29 +522,29 @@ bool ModuleSanitizerCoverage::instrumentModule( } // True if block has successors and it dominates all of them. -static bool isFullDominator(const BasicBlock *BB, const DominatorTree *DT) { +static bool isFullDominator(const BasicBlock *BB, const DominatorTree &DT) { if (succ_empty(BB)) return false; return llvm::all_of(successors(BB), [&](const BasicBlock *SUCC) { - return DT->dominates(BB, SUCC); + return DT.dominates(BB, SUCC); }); } // True if block has predecessors and it postdominates all of them. static bool isFullPostDominator(const BasicBlock *BB, - const PostDominatorTree *PDT) { + const PostDominatorTree &PDT) { if (pred_empty(BB)) return false; return llvm::all_of(predecessors(BB), [&](const BasicBlock *PRED) { - return PDT->dominates(BB, PRED); + return PDT.dominates(BB, PRED); }); } static bool shouldInstrumentBlock(const Function &F, const BasicBlock *BB, - const DominatorTree *DT, - const PostDominatorTree *PDT, + const DominatorTree &DT, + const PostDominatorTree &PDT, const SanitizerCoverageOptions &Options) { // Don't insert coverage for blocks containing nothing but unreachable: we // will never call __sanitizer_cov() for them, so counting them in @@ -568,17 +572,16 @@ static bool shouldInstrumentBlock(const Function &F, const BasicBlock *BB, && !(isFullPostDominator(BB, PDT) && !BB->getSinglePredecessor()); } - // Returns true iff From->To is a backedge. 
// A twist here is that we treat From->To as a backedge if // * To dominates From or // * To->UniqueSuccessor dominates From static bool IsBackEdge(BasicBlock *From, BasicBlock *To, - const DominatorTree *DT) { - if (DT->dominates(To, From)) + const DominatorTree &DT) { + if (DT.dominates(To, From)) return true; if (auto Next = To->getUniqueSuccessor()) - if (DT->dominates(Next, From)) + if (DT.dominates(Next, From)) return true; return false; } @@ -588,7 +591,7 @@ static bool IsBackEdge(BasicBlock *From, BasicBlock *To, // // Note that Cmp pruning is controlled by the same flag as the // BB pruning. -static bool IsInterestingCmp(ICmpInst *CMP, const DominatorTree *DT, +static bool IsInterestingCmp(ICmpInst *CMP, const DominatorTree &DT, const SanitizerCoverageOptions &Options) { if (!Options.NoPrune) if (CMP->hasOneUse()) @@ -599,8 +602,7 @@ static bool IsInterestingCmp(ICmpInst *CMP, const DominatorTree *DT, return true; } -void ModuleSanitizerCoverage::instrumentFunction( - Function &F, DomTreeCallback DTCallback, PostDomTreeCallback PDTCallback) { +void ModuleSanitizerCoverage::instrumentFunction(Function &F) { if (F.empty()) return; if (F.getName().contains(".module_ctor")) @@ -629,8 +631,12 @@ void ModuleSanitizerCoverage::instrumentFunction( return; if (F.hasFnAttribute(Attribute::NoSanitizeCoverage)) return; - if (Options.CoverageType >= SanitizerCoverageOptions::SCK_Edge) - SplitAllCriticalEdges(F, CriticalEdgeSplittingOptions().setIgnoreUnreachableDests()); + if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation)) + return; + if (Options.CoverageType >= SanitizerCoverageOptions::SCK_Edge) { + SplitAllCriticalEdges( + F, CriticalEdgeSplittingOptions().setIgnoreUnreachableDests()); + } SmallVector<Instruction *, 8> IndirCalls; SmallVector<BasicBlock *, 16> BlocksToInstrument; SmallVector<Instruction *, 8> CmpTraceTargets; @@ -640,8 +646,8 @@ void ModuleSanitizerCoverage::instrumentFunction( SmallVector<LoadInst *, 8> Loads; SmallVector<StoreInst *, 8> Stores; - const DominatorTree *DT = DTCallback(F); - const PostDominatorTree *PDT = PDTCallback(F); + const DominatorTree &DT = DTCallback(F); + const PostDominatorTree &PDT = PDTCallback(F); bool IsLeafFunc = true; for (auto &BB : F) { @@ -979,8 +985,9 @@ void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function &F, BasicBlock &BB, FunctionBoolArray->getValueType(), FunctionBoolArray, {ConstantInt::get(IntptrTy, 0), ConstantInt::get(IntptrTy, Idx)}); auto Load = IRB.CreateLoad(Int1Ty, FlagPtr); - auto ThenTerm = - SplitBlockAndInsertIfThen(IRB.CreateIsNull(Load), &*IP, false); + auto ThenTerm = SplitBlockAndInsertIfThen( + IRB.CreateIsNull(Load), &*IP, false, + MDBuilder(IRB.getContext()).createUnlikelyBranchWeights()); IRBuilder<> ThenIRB(ThenTerm); auto Store = ThenIRB.CreateStore(ConstantInt::getTrue(Int1Ty), FlagPtr); Load->setNoSanitizeMetadata(); @@ -997,7 +1004,9 @@ void ModuleSanitizerCoverage::InjectCoverageAtBlock(Function &F, BasicBlock &BB, auto FrameAddrInt = IRB.CreatePtrToInt(FrameAddrPtr, IntptrTy); auto LowestStack = IRB.CreateLoad(IntptrTy, SanCovLowestStack); auto IsStackLower = IRB.CreateICmpULT(FrameAddrInt, LowestStack); - auto ThenTerm = SplitBlockAndInsertIfThen(IsStackLower, &*IP, false); + auto ThenTerm = SplitBlockAndInsertIfThen( + IsStackLower, &*IP, false, + MDBuilder(IRB.getContext()).createUnlikelyBranchWeights()); IRBuilder<> ThenIRB(ThenTerm); auto Store = ThenIRB.CreateStore(FrameAddrInt, SanCovLowestStack); LowestStack->setNoSanitizeMetadata(); diff --git 
a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index 0f42ff790869..92e533d2281a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -43,7 +43,6 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/EscapeEnumerator.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -512,7 +511,7 @@ bool ThreadSanitizer::sanitizeFunction(Function &F, bool Res = false; bool HasCalls = false; bool SanitizeFunction = F.hasFnAttribute(Attribute::SanitizeThread); - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); // Traverse all instructions, collect loads/stores/returns, check for calls. for (auto &BB : F) { @@ -738,8 +737,8 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { Value *Args[] = {Addr, IRB.CreateBitOrPointerCast(SI->getValueOperand(), Ty), createOrdering(&IRB, SI->getOrdering())}; - CallInst *C = CallInst::Create(TsanAtomicStore[Idx], Args); - ReplaceInstWithInst(I, C); + IRB.CreateCall(TsanAtomicStore[Idx], Args); + SI->eraseFromParent(); } else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) { Value *Addr = RMWI->getPointerOperand(); int Idx = @@ -795,8 +794,8 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { FunctionCallee F = FI->getSyncScopeID() == SyncScope::SingleThread ? TsanAtomicSignalFence : TsanAtomicThreadFence; - CallInst *C = CallInst::Create(F, Args); - ReplaceInstWithInst(I, C); + IRB.CreateCall(F, Args); + FI->eraseFromParent(); } return true; } @@ -804,6 +803,10 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { int ThreadSanitizer::getMemoryAccessFuncIndex(Type *OrigTy, Value *Addr, const DataLayout &DL) { assert(OrigTy->isSized()); + if (OrigTy->isScalableTy()) { + // FIXME: support vscale. + return -1; + } uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy); if (TypeSize != 8 && TypeSize != 16 && TypeSize != 32 && TypeSize != 64 && TypeSize != 128) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfilePlugins.inc b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfilePlugins.inc index 3a129de1acd0..b47ef8523ea1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfilePlugins.inc +++ b/contrib/llvm-project/llvm/lib/Transforms/Instrumentation/ValueProfilePlugins.inc @@ -90,9 +90,38 @@ public: } }; +///--------------------- VirtualTableValueProfilingPlugin -------------------- +class VTableProfilingPlugin { + Function &F; + +public: + static constexpr InstrProfValueKind Kind = IPVK_VTableTarget; + + VTableProfilingPlugin(Function &Fn, TargetLibraryInfo &TLI) : F(Fn) {} + + void run(std::vector<CandidateInfo> &Candidates) { + std::vector<Instruction *> Result = findVTableAddrs(F); + for (Instruction *I : Result) { + Instruction *InsertPt = I->getNextNonDebugInstruction(); + // When finding an insertion point, keep PHI and EH pad instructions + // before vp intrinsics. This is similar to + // `BasicBlock::getFirstInsertionPt`. 
+ while (InsertPt && (dyn_cast<PHINode>(InsertPt) || InsertPt->isEHPad())) + InsertPt = InsertPt->getNextNonDebugInstruction(); + // Skip instrumentating the value if InsertPt is the last instruction. + // FIXME: Set InsertPt to the end of basic block to instrument the value + // if InsertPt is the last instruction. + if (InsertPt == nullptr) + continue; + + Instruction *AnnotatedInst = I; + Candidates.emplace_back(CandidateInfo{I, InsertPt, AnnotatedInst}); + } + } +}; + ///----------------------- Registration of the plugins ------------------------- /// For now, registering a plugin with the ValueProfileCollector is done by /// adding the plugin type to the VP_PLUGIN_LIST macro. -#define VP_PLUGIN_LIST \ - MemIntrinsicPlugin, \ - IndirectCallPromotionPlugin +#define VP_PLUGIN_LIST \ + MemIntrinsicPlugin, IndirectCallPromotionPlugin, VTableProfilingPlugin diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.cpp b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.cpp index 7af9c39f8236..b4cc00033e72 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/DependencyAnalysis.cpp @@ -220,16 +220,13 @@ static bool findDependencies(DependenceKind Flavor, const Value *Arg, BasicBlock::iterator StartBBBegin = LocalStartBB->begin(); for (;;) { if (LocalStartPos == StartBBBegin) { - pred_iterator PI(LocalStartBB), PE(LocalStartBB, false); - if (PI == PE) + if (pred_empty(LocalStartBB)) // Return if we've reached the function entry. return false; // Add the predecessors to the worklist. - do { - BasicBlock *PredBB = *PI; + for (BasicBlock *PredBB : predecessors(LocalStartBB)) if (Visited.insert(PredBB).second) Worklist.push_back(std::make_pair(PredBB, PredBB->end())); - } while (++PI != PE); break; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp index 02f9db719e26..33870d7ea192 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.cpp @@ -23,7 +23,7 @@ using namespace llvm::objcarc; CallInst *objcarc::createCallInstWithColors( FunctionCallee Func, ArrayRef<Value *> Args, const Twine &NameStr, - Instruction *InsertBefore, + BasicBlock::iterator InsertBefore, const DenseMap<BasicBlock *, ColorVector> &BlockColors) { FunctionType *FTy = Func.getFunctionType(); Value *Callee = Func.getCallee(); @@ -64,23 +64,23 @@ BundledRetainClaimRVs::insertAfterInvokes(Function &F, DominatorTree *DT) { // We don't have to call insertRVCallWithColors since DestBB is the normal // destination of the invoke. 
- insertRVCall(&*DestBB->getFirstInsertionPt(), I); + insertRVCall(DestBB->getFirstInsertionPt(), I); Changed = true; } return std::make_pair(Changed, CFGChanged); } -CallInst *BundledRetainClaimRVs::insertRVCall(Instruction *InsertPt, +CallInst *BundledRetainClaimRVs::insertRVCall(BasicBlock::iterator InsertPt, CallBase *AnnotatedCall) { DenseMap<BasicBlock *, ColorVector> BlockColors; return insertRVCallWithColors(InsertPt, AnnotatedCall, BlockColors); } CallInst *BundledRetainClaimRVs::insertRVCallWithColors( - Instruction *InsertPt, CallBase *AnnotatedCall, + BasicBlock::iterator InsertPt, CallBase *AnnotatedCall, const DenseMap<BasicBlock *, ColorVector> &BlockColors) { - IRBuilder<> Builder(InsertPt); + IRBuilder<> Builder(InsertPt->getParent(), InsertPt); Function *Func = *objcarc::getAttachedARCFunction(AnnotatedCall); assert(Func && "operand isn't a Function"); Type *ParamTy = Func->getArg(0)->getType(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.h b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.h index 9e68bd574851..f4d7c92d499c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.h +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARC.h @@ -99,7 +99,7 @@ static inline MDString *getRVInstMarker(Module &M) { /// going to be removed from the IR before WinEHPrepare. CallInst *createCallInstWithColors( FunctionCallee Func, ArrayRef<Value *> Args, const Twine &NameStr, - Instruction *InsertBefore, + BasicBlock::iterator InsertBefore, const DenseMap<BasicBlock *, ColorVector> &BlockColors); class BundledRetainClaimRVs { @@ -113,11 +113,12 @@ public: std::pair<bool, bool> insertAfterInvokes(Function &F, DominatorTree *DT); /// Insert a retainRV/claimRV call. - CallInst *insertRVCall(Instruction *InsertPt, CallBase *AnnotatedCall); + CallInst *insertRVCall(BasicBlock::iterator InsertPt, + CallBase *AnnotatedCall); /// Insert a retainRV/claimRV call with colors. CallInst *insertRVCallWithColors( - Instruction *InsertPt, CallBase *AnnotatedCall, + BasicBlock::iterator InsertPt, CallBase *AnnotatedCall, const DenseMap<BasicBlock *, ColorVector> &BlockColors); /// See if an instruction is a bundled retainRV/claimRV call. 
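The ObjCARC hunks above and below switch the insertion-point parameters from raw Instruction pointers to BasicBlock::iterator, which keeps debug records attached at the insertion position in place. A minimal sketch of that pattern, using a hypothetical helper name that is not part of this change:

    // Hypothetical helper (not in the patch) showing iterator-based insertion.
    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/IRBuilder.h"
    using namespace llvm;

    static CallInst *emitCallAt(FunctionCallee Fn, ArrayRef<Value *> Args,
                                BasicBlock::iterator InsertPt) {
      // Constructing the builder from (parent block, iterator) preserves the
      // iterator's debug-record position, unlike passing a bare Instruction *.
      IRBuilder<> Builder(InsertPt->getParent(), InsertPt);
      return Builder.CreateCall(Fn, Args);
    }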
@@ -140,7 +141,8 @@ public: } auto *NewCall = CallBase::removeOperandBundle( - It->second, LLVMContext::OB_clang_arc_attachedcall, It->second); + It->second, LLVMContext::OB_clang_arc_attachedcall, + It->second->getIterator()); NewCall->copyMetadata(*It->second); It->second->replaceAllUsesWith(NewCall); It->second->eraseFromParent(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp index c397ab63f388..0d0f5c72928a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp @@ -382,12 +382,12 @@ void ObjCARCContract::tryToContractReleaseIntoStoreStrong( Value *Args[] = { Load->getPointerOperand(), New }; if (Args[0]->getType() != I8XX) - Args[0] = new BitCastInst(Args[0], I8XX, "", Store); + Args[0] = new BitCastInst(Args[0], I8XX, "", Store->getIterator()); if (Args[1]->getType() != I8X) - Args[1] = new BitCastInst(Args[1], I8X, "", Store); + Args[1] = new BitCastInst(Args[1], I8X, "", Store->getIterator()); Function *Decl = EP.get(ARCRuntimeEntryPointKind::StoreStrong); - CallInst *StoreStrong = - objcarc::createCallInstWithColors(Decl, Args, "", Store, BlockColors); + CallInst *StoreStrong = objcarc::createCallInstWithColors( + Decl, Args, "", Store->getIterator(), BlockColors); StoreStrong->setDoesNotThrow(); StoreStrong->setDebugLoc(Store->getDebugLoc()); @@ -472,8 +472,8 @@ bool ObjCARCContract::tryToPeepholeInstruction( RVInstMarker->getString(), /*Constraints=*/"", /*hasSideEffects=*/true); - objcarc::createCallInstWithColors(IA, std::nullopt, "", Inst, - BlockColors); + objcarc::createCallInstWithColors(IA, std::nullopt, "", + Inst->getIterator(), BlockColors); } decline_rv_optimization: return false; @@ -484,7 +484,7 @@ bool ObjCARCContract::tryToPeepholeInstruction( if (IsNullOrUndef(CI->getArgOperand(1))) { Value *Null = ConstantPointerNull::get(cast<PointerType>(CI->getType())); Changed = true; - new StoreInst(Null, CI->getArgOperand(0), CI); + new StoreInst(Null, CI->getArgOperand(0), CI->getIterator()); LLVM_DEBUG(dbgs() << "OBJCARCContract: Old = " << *CI << "\n" << " New = " << *Null << "\n"); @@ -575,7 +575,7 @@ bool ObjCARCContract::run(Function &F, AAResults *A, DominatorTree *D) { if (auto *CI = dyn_cast<CallInst>(Inst)) if (objcarc::hasAttachedCallOpBundle(CI)) { - BundledInsts->insertRVCallWithColors(&*I, CI, BlockColors); + BundledInsts->insertRVCallWithColors(I->getIterator(), CI, BlockColors); --I; Changed = true; } @@ -631,8 +631,8 @@ bool ObjCARCContract::run(Function &F, AAResults *A, DominatorTree *D) { assert(DT->dominates(Inst, &InsertBB->back()) && "Invalid insertion point for bitcast"); - Replacement = - new BitCastInst(Replacement, UseTy, "", &InsertBB->back()); + Replacement = new BitCastInst(Replacement, UseTy, "", + InsertBB->back().getIterator()); } // While we're here, rewrite all edges for this PHI, rather @@ -649,8 +649,9 @@ bool ObjCARCContract::run(Function &F, AAResults *A, DominatorTree *D) { } } else { if (Replacement->getType() != UseTy) - Replacement = new BitCastInst(Replacement, UseTy, "", - cast<Instruction>(U.getUser())); + Replacement = + new BitCastInst(Replacement, UseTy, "", + cast<Instruction>(U.getUser())->getIterator()); U.set(Replacement); } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp index b51e4d46bffe..72e860d7dcfa 100644 --- 
a/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp @@ -693,8 +693,9 @@ bool ObjCARCOpt::OptimizeInlinedAutoreleaseRVCall( // AutoreleaseRV and RetainRV cancel out, replace UnsafeClaimRV with Release. assert(Class == ARCInstKind::UnsafeClaimRV); Value *CallArg = cast<CallInst>(Inst)->getArgOperand(0); - CallInst *Release = CallInst::Create( - EP.get(ARCRuntimeEntryPointKind::Release), CallArg, "", Inst); + CallInst *Release = + CallInst::Create(EP.get(ARCRuntimeEntryPointKind::Release), CallArg, "", + Inst->getIterator()); assert(IsAlwaysTail(ARCInstKind::UnsafeClaimRV) && "Expected UnsafeClaimRV to be safe to tail call"); Release->setTailCall(); @@ -808,7 +809,7 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { if (auto *CI = dyn_cast<CallInst>(Inst)) if (objcarc::hasAttachedCallOpBundle(CI)) { - BundledInsts->insertRVCall(&*I, CI); + BundledInsts->insertRVCall(I->getIterator(), CI); Changed = true; } @@ -934,7 +935,7 @@ void ObjCARCOpt::OptimizeIndividualCallImpl(Function &F, Instruction *Inst, Changed = true; new StoreInst(ConstantInt::getTrue(CI->getContext()), PoisonValue::get(PointerType::getUnqual(CI->getContext())), - CI); + CI->getIterator()); Value *NewValue = PoisonValue::get(CI->getType()); LLVM_DEBUG( dbgs() << "A null pointer-to-weak-pointer is undefined behavior." @@ -954,7 +955,7 @@ void ObjCARCOpt::OptimizeIndividualCallImpl(Function &F, Instruction *Inst, Changed = true; new StoreInst(ConstantInt::getTrue(CI->getContext()), PoisonValue::get(PointerType::getUnqual(CI->getContext())), - CI); + CI->getIterator()); Value *NewValue = PoisonValue::get(CI->getType()); LLVM_DEBUG( @@ -990,8 +991,8 @@ void ObjCARCOpt::OptimizeIndividualCallImpl(Function &F, Instruction *Inst, LLVMContext &C = Inst->getContext(); Function *Decl = EP.get(ARCRuntimeEntryPointKind::Release); - CallInst *NewCall = - CallInst::Create(Decl, Call->getArgOperand(0), "", Call); + CallInst *NewCall = CallInst::Create(Decl, Call->getArgOperand(0), "", + Call->getIterator()); NewCall->setMetadata(MDKindCache.get(ARCMDKindID::ImpreciseRelease), MDNode::get(C, std::nullopt)); @@ -1143,7 +1144,8 @@ void ObjCARCOpt::OptimizeIndividualCallImpl(Function &F, Instruction *Inst, if (IsNullOrUndef(Incoming)) continue; Value *Op = PN->getIncomingValue(i); - Instruction *InsertPos = &PN->getIncomingBlock(i)->back(); + BasicBlock::iterator InsertPos = + PN->getIncomingBlock(i)->back().getIterator(); SmallVector<OperandBundleDef, 1> OpBundles; cloneOpBundlesIf(CInst, OpBundles, [](const OperandBundleUse &B) { return B.getTagID() != LLVMContext::OB_funclet; @@ -1153,7 +1155,7 @@ void ObjCARCOpt::OptimizeIndividualCallImpl(Function &F, Instruction *Inst, if (Op->getType() != ParamTy) Op = new BitCastInst(Op, ParamTy, "", InsertPos); Clone->setArgOperand(0, Op); - Clone->insertBefore(InsertPos); + Clone->insertBefore(*InsertPos->getParent(), InsertPos); LLVM_DEBUG(dbgs() << "Cloning " << *CInst << "\n" "And inserting clone at " @@ -1768,12 +1770,14 @@ void ObjCARCOpt::MoveCalls(Value *Arg, RRInfo &RetainsToMove, // Insert the new retain and release calls. for (Instruction *InsertPt : ReleasesToMove.ReverseInsertPts) { - Value *MyArg = ArgTy == ParamTy ? Arg : - new BitCastInst(Arg, ParamTy, "", InsertPt); + Value *MyArg = ArgTy == ParamTy ? 
Arg + : new BitCastInst(Arg, ParamTy, "", + InsertPt->getIterator()); Function *Decl = EP.get(ARCRuntimeEntryPointKind::Retain); SmallVector<OperandBundleDef, 1> BundleList; addOpBundleForFunclet(InsertPt->getParent(), BundleList); - CallInst *Call = CallInst::Create(Decl, MyArg, BundleList, "", InsertPt); + CallInst *Call = + CallInst::Create(Decl, MyArg, BundleList, "", InsertPt->getIterator()); Call->setDoesNotThrow(); Call->setTailCall(); @@ -1783,12 +1787,14 @@ void ObjCARCOpt::MoveCalls(Value *Arg, RRInfo &RetainsToMove, << *InsertPt << "\n"); } for (Instruction *InsertPt : RetainsToMove.ReverseInsertPts) { - Value *MyArg = ArgTy == ParamTy ? Arg : - new BitCastInst(Arg, ParamTy, "", InsertPt); + Value *MyArg = ArgTy == ParamTy ? Arg + : new BitCastInst(Arg, ParamTy, "", + InsertPt->getIterator()); Function *Decl = EP.get(ARCRuntimeEntryPointKind::Release); SmallVector<OperandBundleDef, 1> BundleList; addOpBundleForFunclet(InsertPt->getParent(), BundleList); - CallInst *Call = CallInst::Create(Decl, MyArg, BundleList, "", InsertPt); + CallInst *Call = + CallInst::Create(Decl, MyArg, BundleList, "", InsertPt->getIterator()); // Attach a clang.imprecise_release metadata tag, if appropriate. if (MDNode *M = ReleasesToMove.ReleaseMetadata) Call->setMetadata(MDKindCache.get(ARCMDKindID::ImpreciseRelease), M); @@ -2125,7 +2131,8 @@ void ObjCARCOpt::OptimizeWeakCalls(Function &F) { // If the load has a builtin retain, insert a plain retain for it. if (Class == ARCInstKind::LoadWeakRetained) { Function *Decl = EP.get(ARCRuntimeEntryPointKind::Retain); - CallInst *CI = CallInst::Create(Decl, EarlierCall, "", Call); + CallInst *CI = + CallInst::Create(Decl, EarlierCall, "", Call->getIterator()); CI->setTailCall(); } // Zap the fully redundant load. @@ -2154,7 +2161,8 @@ void ObjCARCOpt::OptimizeWeakCalls(Function &F) { // If the load has a builtin retain, insert a plain retain for it. if (Class == ARCInstKind::LoadWeakRetained) { Function *Decl = EP.get(ARCRuntimeEntryPointKind::Retain); - CallInst *CI = CallInst::Create(Decl, EarlierCall, "", Call); + CallInst *CI = + CallInst::Create(Decl, EarlierCall, "", Call->getIterator()); CI->setTailCall(); } // Zap the fully redundant load. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ADCE.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ADCE.cpp index 90b544c89226..5f0a9b22c3ee 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ADCE.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ADCE.cpp @@ -350,7 +350,7 @@ bool AggressiveDeadCodeElimination::isInstrumentsConstant(Instruction &I) { // TODO -- move this test into llvm::isInstructionTriviallyDead if (CallInst *CI = dyn_cast<CallInst>(&I)) if (Function *Callee = CI->getCalledFunction()) - if (Callee->getName().equals(getInstrProfValueProfFuncName())) + if (Callee->getName() == getInstrProfValueProfFuncName()) if (isa<Constant>(CI->getArgOperand(0))) return true; return false; @@ -544,19 +544,20 @@ ADCEChanged AggressiveDeadCodeElimination::removeDeadInstructions() { // value of the function, and may therefore be deleted safely. // NOTE: We reuse the Worklist vector here for memory efficiency. for (Instruction &I : llvm::reverse(instructions(F))) { - // With "RemoveDIs" debug-info stored in DPValue objects, debug-info - // attached to this instruction, and drop any for scopes that aren't alive, - // like the rest of this loop does. Extending support to assignment tracking - // is future work. 
- for (DPValue &DPV : make_early_inc_range(I.getDbgValueRange())) { - // Avoid removing a DPV that is linked to instructions because it holds + // With "RemoveDIs" debug-info stored in DbgVariableRecord objects, + // debug-info attached to this instruction, and drop any for scopes that + // aren't alive, like the rest of this loop does. Extending support to + // assignment tracking is future work. + for (DbgRecord &DR : make_early_inc_range(I.getDbgRecordRange())) { + // Avoid removing a DVR that is linked to instructions because it holds // information about an existing store. - if (DPV.isDbgAssign()) - if (!at::getAssignmentInsts(&DPV).empty()) + if (DbgVariableRecord *DVR = dyn_cast<DbgVariableRecord>(&DR); + DVR && DVR->isDbgAssign()) + if (!at::getAssignmentInsts(DVR).empty()) continue; - if (AliveScopes.count(DPV.getDebugLoc()->getScope())) + if (AliveScopes.count(DR.getDebugLoc()->getScope())) continue; - I.dropOneDbgValue(&DPV); + I.dropOneDbgRecord(&DR); } // Check if the instruction is alive. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp index b182f46cc515..5d9a7bca7efe 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/AnnotationRemarks.cpp @@ -33,7 +33,7 @@ static void tryEmitAutoInitRemark(ArrayRef<Instruction *> Instructions, continue; Function &F = *I->getParent()->getParent(); - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); AutoInitRemark Remark(ORE, REMARK_PASS, DL, TLI); Remark.visit(I); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/BDCE.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/BDCE.cpp index 1fa2c75b0f42..d96dbca30fdb 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/BDCE.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/BDCE.cpp @@ -23,10 +23,13 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/Local.h" + using namespace llvm; +using namespace PatternMatch; #define DEBUG_TYPE "bdce" @@ -42,15 +45,17 @@ static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) { assert(I->getType()->isIntOrIntVectorTy() && "Trivializing a non-integer value?"); + // If all bits of a user are demanded, then we know that nothing below that + // in the def-use chain needs to be changed. + if (DB.getDemandedBits(I).isAllOnes()) + return; + // Initialize the worklist with eligible direct users. SmallPtrSet<Instruction *, 16> Visited; SmallVector<Instruction *, 16> WorkList; for (User *JU : I->users()) { - // If all bits of a user are demanded, then we know that nothing below that - // in the def-use chain needs to be changed. - auto *J = dyn_cast<Instruction>(JU); - if (J && J->getType()->isIntOrIntVectorTy() && - !DB.getDemandedBits(J).isAllOnes()) { + auto *J = cast<Instruction>(JU); + if (J->getType()->isIntOrIntVectorTy()) { Visited.insert(J); WorkList.push_back(J); } @@ -70,18 +75,19 @@ static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) { Instruction *J = WorkList.pop_back_val(); // NSW, NUW, and exact are based on operands that might have changed. 
- J->dropPoisonGeneratingFlags(); + J->dropPoisonGeneratingAnnotations(); - // We do not have to worry about llvm.assume or range metadata: - // 1. llvm.assume demands its operand, so trivializing can't change it. - // 2. range metadata only applies to memory accesses which demand all bits. + // We do not have to worry about llvm.assume, because it demands its + // operand, so trivializing can't change it. + + // If all bits of a user are demanded, then we know that nothing below + // that in the def-use chain needs to be changed. + if (DB.getDemandedBits(J).isAllOnes()) + continue; for (User *KU : J->users()) { - // If all bits of a user are demanded, then we know that nothing below - // that in the def-use chain needs to be changed. - auto *K = dyn_cast<Instruction>(KU); - if (K && Visited.insert(K).second && K->getType()->isIntOrIntVectorTy() && - !DB.getDemandedBits(K).isAllOnes()) + auto *K = cast<Instruction>(KU); + if (Visited.insert(K).second && K->getType()->isIntOrIntVectorTy()) WorkList.push_back(K); } } @@ -125,6 +131,38 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) { } } + // Simplify and, or, xor when their mask does not affect the demanded bits. + if (auto *BO = dyn_cast<BinaryOperator>(&I)) { + APInt Demanded = DB.getDemandedBits(BO); + if (!Demanded.isAllOnes()) { + const APInt *Mask; + if (match(BO->getOperand(1), m_APInt(Mask))) { + bool CanBeSimplified = false; + switch (BO->getOpcode()) { + case Instruction::Or: + case Instruction::Xor: + CanBeSimplified = !Demanded.intersects(*Mask); + break; + case Instruction::And: + CanBeSimplified = Demanded.isSubsetOf(*Mask); + break; + default: + // TODO: Handle more cases here. + break; + } + + if (CanBeSimplified) { + clearAssumptionsOfUsers(BO, DB); + BO->replaceAllUsesWith(BO->getOperand(0)); + Worklist.push_back(BO); + ++NumSimplified; + Changed = true; + continue; + } + } + } + } + for (Use &U : I.operands()) { // DemandedBits only detects dead integer uses. if (!U->getType()->isIntOrIntVectorTy()) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp index 47f663fa0cf0..b8571ba07489 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp @@ -403,7 +403,7 @@ static void splitCallSite(CallBase &CB, NewPN->insertBefore(*TailBB, TailBB->begin()); CurrentI->replaceAllUsesWith(NewPN); } - CurrentI->dropDbgValues(); + CurrentI->dropDbgRecords(); CurrentI->eraseFromParent(); // We are done once we handled the first original instruction in TailBB. 
if (CurrentI == OriginalBeginInst) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp index 49f8761a1392..4a6dedc93d30 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -43,6 +43,7 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" @@ -162,27 +163,27 @@ bool ConstantHoistingLegacyPass::runOnFunction(Function &Fn) { void ConstantHoistingPass::collectMatInsertPts( const RebasedConstantListType &RebasedConstants, - SmallVectorImpl<Instruction *> &MatInsertPts) const { + SmallVectorImpl<BasicBlock::iterator> &MatInsertPts) const { for (const RebasedConstantInfo &RCI : RebasedConstants) for (const ConstantUser &U : RCI.Uses) MatInsertPts.emplace_back(findMatInsertPt(U.Inst, U.OpndIdx)); } /// Find the constant materialization insertion point. -Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst, - unsigned Idx) const { +BasicBlock::iterator ConstantHoistingPass::findMatInsertPt(Instruction *Inst, + unsigned Idx) const { // If the operand is a cast instruction, then we have to materialize the // constant before the cast instruction. if (Idx != ~0U) { Value *Opnd = Inst->getOperand(Idx); if (auto CastInst = dyn_cast<Instruction>(Opnd)) if (CastInst->isCast()) - return CastInst; + return CastInst->getIterator(); } // The simple and common case. This also includes constant expressions. if (!isa<PHINode>(Inst) && !Inst->isEHPad()) - return Inst; + return Inst->getIterator(); // We can't insert directly before a phi node or an eh pad. Insert before // the terminator of the incoming or dominating block. @@ -191,7 +192,7 @@ Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst, if (Idx != ~0U && isa<PHINode>(Inst)) { InsertionBlock = cast<PHINode>(Inst)->getIncomingBlock(Idx); if (!InsertionBlock->isEHPad()) { - return InsertionBlock->getTerminator(); + return InsertionBlock->getTerminator()->getIterator(); } } else { InsertionBlock = Inst->getParent(); @@ -206,7 +207,7 @@ Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst, IDom = IDom->getIDom(); } - return IDom->getBlock()->getTerminator(); + return IDom->getBlock()->getTerminator()->getIterator(); } /// Given \p BBs as input, find another set of BBs which collectively @@ -314,26 +315,27 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, } /// Find an insertion point that dominates all uses. -SetVector<Instruction *> ConstantHoistingPass::findConstantInsertionPoint( +SetVector<BasicBlock::iterator> +ConstantHoistingPass::findConstantInsertionPoint( const ConstantInfo &ConstInfo, - const ArrayRef<Instruction *> MatInsertPts) const { + const ArrayRef<BasicBlock::iterator> MatInsertPts) const { assert(!ConstInfo.RebasedConstants.empty() && "Invalid constant info entry."); // Collect all basic blocks. 
SetVector<BasicBlock *> BBs; - SetVector<Instruction *> InsertPts; + SetVector<BasicBlock::iterator> InsertPts; - for (Instruction *MatInsertPt : MatInsertPts) + for (BasicBlock::iterator MatInsertPt : MatInsertPts) BBs.insert(MatInsertPt->getParent()); if (BBs.count(Entry)) { - InsertPts.insert(&Entry->front()); + InsertPts.insert(Entry->begin()); return InsertPts; } if (BFI) { findBestInsertionSet(*DT, *BFI, Entry, BBs); for (BasicBlock *BB : BBs) - InsertPts.insert(&*BB->getFirstInsertionPt()); + InsertPts.insert(BB->getFirstInsertionPt()); return InsertPts; } @@ -343,7 +345,7 @@ SetVector<Instruction *> ConstantHoistingPass::findConstantInsertionPoint( BB2 = BBs.pop_back_val(); BB = DT->findNearestCommonDominator(BB1, BB2); if (BB == Entry) { - InsertPts.insert(&Entry->front()); + InsertPts.insert(Entry->begin()); return InsertPts; } BBs.insert(BB); @@ -363,6 +365,9 @@ SetVector<Instruction *> ConstantHoistingPass::findConstantInsertionPoint( void ConstantHoistingPass::collectConstantCandidates( ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx, ConstantInt *ConstInt) { + if (ConstInt->getType()->isVectorTy()) + return; + InstructionCost Cost; // Ask the target about the cost of materializing the constant for the given // instruction and operand index. @@ -761,11 +766,13 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base, Mat = GetElementPtrInst::Create(Type::getInt8Ty(*Ctx), Base, Adj->Offset, "mat_gep", Adj->MatInsertPt); // Hide it behind a bitcast. - Mat = new BitCastInst(Mat, Adj->Ty, "mat_bitcast", Adj->MatInsertPt); + Mat = new BitCastInst(Mat, Adj->Ty, "mat_bitcast", + Adj->MatInsertPt->getIterator()); } else // Constant being rebased is a ConstantInt. - Mat = BinaryOperator::Create(Instruction::Add, Base, Adj->Offset, - "const_mat", Adj->MatInsertPt); + Mat = + BinaryOperator::Create(Instruction::Add, Base, Adj->Offset, + "const_mat", Adj->MatInsertPt->getIterator()); LLVM_DEBUG(dbgs() << "Materialize constant (" << *Base->getOperand(0) << " + " << *Adj->Offset << ") in BB " @@ -816,7 +823,8 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base, // Aside from constant GEPs, only constant cast expressions are collected. assert(ConstExpr->isCast() && "ConstExpr should be a cast"); - Instruction *ConstExprInst = ConstExpr->getAsInstruction(Adj->MatInsertPt); + Instruction *ConstExprInst = ConstExpr->getAsInstruction(); + ConstExprInst->insertBefore(Adj->MatInsertPt); ConstExprInst->setOperand(0, Mat); // Use the same debug location as the instruction we are about to update. @@ -842,9 +850,9 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) { SmallVectorImpl<consthoist::ConstantInfo> &ConstInfoVec = BaseGV ? ConstGEPInfoMap[BaseGV] : ConstIntInfoVec; for (const consthoist::ConstantInfo &ConstInfo : ConstInfoVec) { - SmallVector<Instruction *, 4> MatInsertPts; + SmallVector<BasicBlock::iterator, 4> MatInsertPts; collectMatInsertPts(ConstInfo.RebasedConstants, MatInsertPts); - SetVector<Instruction *> IPSet = + SetVector<BasicBlock::iterator> IPSet = findConstantInsertionPoint(ConstInfo, MatInsertPts); // We can have an empty set if the function contains unreachable blocks. if (IPSet.empty()) @@ -853,7 +861,7 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) { unsigned UsesNum = 0; unsigned ReBasesNum = 0; unsigned NotRebasedNum = 0; - for (Instruction *IP : IPSet) { + for (const BasicBlock::iterator &IP : IPSet) { // First, collect constants depending on this IP of the base. 
UsesNum = 0; SmallVector<UserAdjustment, 4> ToBeRebased; @@ -861,7 +869,7 @@ bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) { for (auto const &RCI : ConstInfo.RebasedConstants) { UsesNum += RCI.Uses.size(); for (auto const &U : RCI.Uses) { - Instruction *MatInsertPt = MatInsertPts[MatCtr++]; + const BasicBlock::iterator &MatInsertPt = MatInsertPts[MatCtr++]; BasicBlock *OrigMatInsertBB = MatInsertPt->getParent(); // If Base constant is to be inserted in multiple places, // generate rebase for U using the Base dominating U. @@ -941,7 +949,7 @@ bool ConstantHoistingPass::runImpl(Function &Fn, TargetTransformInfo &TTI, this->TTI = &TTI; this->DT = &DT; this->BFI = BFI; - this->DL = &Fn.getParent()->getDataLayout(); + this->DL = &Fn.getDataLayout(); this->Ctx = &Fn.getContext(); this->Entry = &Entry; this->PSI = PSI; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp index 7b672e89b67a..d1c80aa67124 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp @@ -29,6 +29,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Verifier.h" #include "llvm/Pass.h" @@ -231,8 +232,8 @@ struct ConstraintTy { ConstraintTy(SmallVector<int64_t, 8> Coefficients, bool IsSigned, bool IsEq, bool IsNe) - : Coefficients(Coefficients), IsSigned(IsSigned), IsEq(IsEq), IsNe(IsNe) { - } + : Coefficients(std::move(Coefficients)), IsSigned(IsSigned), IsEq(IsEq), + IsNe(IsNe) {} unsigned size() const { return Coefficients.size(); } @@ -461,7 +462,7 @@ static Decomposition decomposeGEP(GEPOperator &GEP, // If Op0 is signed non-negative, the GEP is increasing monotonically and // can be de-composed. - if (!isKnownNonNegative(Index, DL, /*Depth=*/MaxAnalysisRecursionDepth - 1)) + if (!isKnownNonNegative(Index, DL)) Preconditions.emplace_back(CmpInst::ICMP_SGE, Index, ConstantInt::get(Index->getType(), 0)); } @@ -499,6 +500,8 @@ static Decomposition decompose(Value *V, if (!Ty->isIntegerTy() || Ty->getIntegerBitWidth() > 64) return V; + bool IsKnownNonNegative = false; + // Decompose \p V used with a signed predicate. 
if (IsSigned) { if (auto *CI = dyn_cast<ConstantInt>(V)) { @@ -507,6 +510,14 @@ static Decomposition decompose(Value *V, } Value *Op0; Value *Op1; + + if (match(V, m_SExt(m_Value(Op0)))) + V = Op0; + else if (match(V, m_NNegZExt(m_Value(Op0)))) { + V = Op0; + IsKnownNonNegative = true; + } + if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) return MergeResults(Op0, Op1, IsSigned); @@ -529,7 +540,7 @@ static Decomposition decompose(Value *V, } } - return V; + return {V, IsKnownNonNegative}; } if (auto *CI = dyn_cast<ConstantInt>(V)) { @@ -539,22 +550,27 @@ static Decomposition decompose(Value *V, } Value *Op0; - bool IsKnownNonNegative = false; if (match(V, m_ZExt(m_Value(Op0)))) { IsKnownNonNegative = true; V = Op0; } + if (match(V, m_SExt(m_Value(Op0)))) { + V = Op0; + Preconditions.emplace_back(CmpInst::ICMP_SGE, Op0, + ConstantInt::get(Op0->getType(), 0)); + } + Value *Op1; ConstantInt *CI; if (match(V, m_NUWAdd(m_Value(Op0), m_Value(Op1)))) { return MergeResults(Op0, Op1, IsSigned); } if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) { - if (!isKnownNonNegative(Op0, DL, /*Depth=*/MaxAnalysisRecursionDepth - 1)) + if (!isKnownNonNegative(Op0, DL)) Preconditions.emplace_back(CmpInst::ICMP_SGE, Op0, ConstantInt::get(Op0->getType(), 0)); - if (!isKnownNonNegative(Op1, DL, /*Depth=*/MaxAnalysisRecursionDepth - 1)) + if (!isKnownNonNegative(Op1, DL)) Preconditions.emplace_back(CmpInst::ICMP_SGE, Op1, ConstantInt::get(Op1->getType(), 0)); @@ -1016,6 +1032,23 @@ void State::addInfoForInductions(BasicBlock &BB) { WorkList.push_back(FactOrCheck::getConditionFact( DTN, CmpInst::ICMP_SLT, PN, B, ConditionTy(CmpInst::ICMP_SLE, StartValue, B))); + + // Try to add condition from header to the dedicated exit blocks. When exiting + // either with EQ or NE in the header, we know that the induction value must + // be u<= B, as other exits may only exit earlier. + assert(!StepOffset.isNegative() && "induction must be increasing"); + assert((Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE) && + "unsupported predicate"); + ConditionTy Precond = {CmpInst::ICMP_ULE, StartValue, B}; + SmallVector<BasicBlock *> ExitBBs; + L->getExitBlocks(ExitBBs); + for (BasicBlock *EB : ExitBBs) { + // Bail out on non-dedicated exits. + if (DT.dominates(&BB, EB)) { + WorkList.emplace_back(FactOrCheck::getConditionFact( + DT.getNode(EB), CmpInst::ICMP_ULE, A, B, Precond)); + } + } } void State::addInfoFor(BasicBlock &BB) { @@ -1057,6 +1090,8 @@ void State::addInfoFor(BasicBlock &BB) { } // Enqueue ssub_with_overflow for simplification. case Intrinsic::ssub_with_overflow: + case Intrinsic::ucmp: + case Intrinsic::scmp: WorkList.push_back( FactOrCheck::getCheck(DT.getNode(&BB), cast<CallInst>(&I))); break; @@ -1065,6 +1100,9 @@ void State::addInfoFor(BasicBlock &BB) { case Intrinsic::umax: case Intrinsic::smin: case Intrinsic::smax: + // TODO: handle llvm.abs as well + WorkList.push_back( + FactOrCheck::getCheck(DT.getNode(&BB), cast<CallInst>(&I))); // TODO: Check if it is possible to instead only added the min/max facts // when simplifying uses of the min/max intrinsics. if (!isGuaranteedNotToBePoison(&I)) @@ -1395,6 +1433,48 @@ static bool checkAndReplaceCondition( return false; } +static bool checkAndReplaceMinMax(MinMaxIntrinsic *MinMax, ConstraintInfo &Info, + SmallVectorImpl<Instruction *> &ToRemove) { + auto ReplaceMinMaxWithOperand = [&](MinMaxIntrinsic *MinMax, bool UseLHS) { + // TODO: generate reproducer for min/max. + MinMax->replaceAllUsesWith(MinMax->getOperand(UseLHS ? 
0 : 1)); + ToRemove.push_back(MinMax); + return true; + }; + + ICmpInst::Predicate Pred = + ICmpInst::getNonStrictPredicate(MinMax->getPredicate()); + if (auto ImpliedCondition = checkCondition( + Pred, MinMax->getOperand(0), MinMax->getOperand(1), MinMax, Info)) + return ReplaceMinMaxWithOperand(MinMax, *ImpliedCondition); + if (auto ImpliedCondition = checkCondition( + Pred, MinMax->getOperand(1), MinMax->getOperand(0), MinMax, Info)) + return ReplaceMinMaxWithOperand(MinMax, !*ImpliedCondition); + return false; +} + +static bool checkAndReplaceCmp(CmpIntrinsic *I, ConstraintInfo &Info, + SmallVectorImpl<Instruction *> &ToRemove) { + Value *LHS = I->getOperand(0); + Value *RHS = I->getOperand(1); + if (checkCondition(I->getGTPredicate(), LHS, RHS, I, Info).value_or(false)) { + I->replaceAllUsesWith(ConstantInt::get(I->getType(), 1)); + ToRemove.push_back(I); + return true; + } + if (checkCondition(I->getLTPredicate(), LHS, RHS, I, Info).value_or(false)) { + I->replaceAllUsesWith(ConstantInt::getSigned(I->getType(), -1)); + ToRemove.push_back(I); + return true; + } + if (checkCondition(ICmpInst::ICMP_EQ, LHS, RHS, I, Info).value_or(false)) { + I->replaceAllUsesWith(ConstantInt::get(I->getType(), 0)); + ToRemove.push_back(I); + return true; + } + return false; +} + static void removeEntryFromStack(const StackEntry &E, ConstraintInfo &Info, Module *ReproducerModule, @@ -1602,7 +1682,7 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI, SmallVector<Value *> FunctionArgs; for (Value &Arg : F.args()) FunctionArgs.push_back(&Arg); - ConstraintInfo Info(F.getParent()->getDataLayout(), FunctionArgs); + ConstraintInfo Info(F.getDataLayout(), FunctionArgs); State S(DT, LI, SE); std::unique_ptr<Module> ReproducerModule( DumpReproducers ? new Module(F.getName(), F.getContext()) : nullptr); @@ -1695,6 +1775,10 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI, ReproducerCondStack, DFSInStack); } Changed |= Simplified; + } else if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(Inst)) { + Changed |= checkAndReplaceMinMax(MinMax, Info, ToRemove); + } else if (auto *CmpIntr = dyn_cast<CmpIntrinsic>(Inst)) { + Changed |= checkAndReplaceCmp(CmpIntr, Info, ToRemove); } continue; } @@ -1730,7 +1814,10 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI, if (!CB.isConditionFact()) { Value *X; if (match(CB.Inst, m_Intrinsic<Intrinsic::abs>(m_Value(X)))) { - // TODO: Add CB.Inst >= 0 fact. + // If is_int_min_poison is true then we may assume llvm.abs >= 0. 
+ if (cast<ConstantInt>(CB.Inst->getOperand(1))->isOne()) + AddFact(CmpInst::ICMP_SGE, CB.Inst, + ConstantInt::get(CB.Inst->getType(), 0)); AddFact(CmpInst::ICMP_SGE, CB.Inst, X); continue; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 9235850de92f..95de8eceb6be 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -33,6 +33,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" @@ -47,11 +48,6 @@ using namespace llvm; #define DEBUG_TYPE "correlated-value-propagation" -static cl::opt<bool> CanonicalizeICmpPredicatesToUnsigned( - "canonicalize-icmp-predicates-to-unsigned", cl::init(true), cl::Hidden, - cl::desc("Enables canonicalization of signed relational predicates to " - "unsigned (e.g. sgt => ugt)")); - STATISTIC(NumPhis, "Number of phis propagated"); STATISTIC(NumPhiCommon, "Number of phis deleted via common incoming value"); STATISTIC(NumSelects, "Number of selects propagated"); @@ -67,6 +63,7 @@ STATISTIC(NumAShrsConverted, "Number of ashr converted to lshr"); STATISTIC(NumAShrsRemoved, "Number of ashr removed"); STATISTIC(NumSRems, "Number of srem converted to urem"); STATISTIC(NumSExt, "Number of sext converted to zext"); +STATISTIC(NumSIToFP, "Number of sitofp converted to uitofp"); STATISTIC(NumSICmps, "Number of signed icmp preds simplified to unsigned"); STATISTIC(NumAnd, "Number of ands removed"); STATISTIC(NumNW, "Number of no-wrap deductions"); @@ -89,10 +86,13 @@ STATISTIC(NumOverflows, "Number of overflow checks removed"); STATISTIC(NumSaturating, "Number of saturating arithmetics converted to normal arithmetics"); STATISTIC(NumNonNull, "Number of function pointer arguments marked non-null"); +STATISTIC(NumCmpIntr, "Number of llvm.[us]cmp intrinsics removed"); STATISTIC(NumMinMax, "Number of llvm.[us]{min,max} intrinsics removed"); +STATISTIC(NumSMinMax, + "Number of llvm.s{min,max} intrinsics simplified to unsigned"); STATISTIC(NumUDivURemsNarrowedExpanded, "Number of bound udiv's/urem's expanded"); -STATISTIC(NumZExt, "Number of non-negative deductions"); +STATISTIC(NumNNeg, "Number of zext/uitofp non-negative deductions"); static Constant *getConstantAt(Value *V, Instruction *At, LazyValueInfo *LVI) { if (Constant *C = LVI->getConstant(V, At)) @@ -109,14 +109,8 @@ static Constant *getConstantAt(Value *V, Instruction *At, LazyValueInfo *LVI) { if (!Op1) return nullptr; - LazyValueInfo::Tristate Result = LVI->getPredicateAt( - C->getPredicate(), Op0, Op1, At, /*UseBlockValue=*/false); - if (Result == LazyValueInfo::Unknown) - return nullptr; - - return (Result == LazyValueInfo::True) - ? 
ConstantInt::getTrue(C->getContext()) - : ConstantInt::getFalse(C->getContext()); + return LVI->getPredicateAt(C->getPredicate(), Op0, Op1, At, + /*UseBlockValue=*/false); } static bool processSelect(SelectInst *S, LazyValueInfo *LVI) { @@ -243,15 +237,17 @@ static Value *getValueOnEdge(LazyValueInfo *LVI, Value *Incoming, // The "false" case if (auto *C = dyn_cast<Constant>(SI->getFalseValue())) - if (LVI->getPredicateOnEdge(ICmpInst::ICMP_EQ, SI, C, From, To, CxtI) == - LazyValueInfo::False) + if (auto *Res = dyn_cast_or_null<ConstantInt>( + LVI->getPredicateOnEdge(ICmpInst::ICMP_EQ, SI, C, From, To, CxtI)); + Res && Res->isZero()) return SI->getTrueValue(); // The "true" case, // similar to the select "false" case, but try the select "true" value if (auto *C = dyn_cast<Constant>(SI->getTrueValue())) - if (LVI->getPredicateOnEdge(ICmpInst::ICMP_EQ, SI, C, From, To, CxtI) == - LazyValueInfo::False) + if (auto *Res = dyn_cast_or_null<ConstantInt>( + LVI->getPredicateOnEdge(ICmpInst::ICMP_EQ, SI, C, From, To, CxtI)); + Res && Res->isZero()) return SI->getFalseValue(); return nullptr; @@ -289,12 +285,8 @@ static bool processPHI(PHINode *P, LazyValueInfo *LVI, DominatorTree *DT, } static bool processICmp(ICmpInst *Cmp, LazyValueInfo *LVI) { - if (!CanonicalizeICmpPredicatesToUnsigned) - return false; - - // Only for signed relational comparisons of scalar integers. - if (Cmp->getType()->isVectorTy() || - !Cmp->getOperand(0)->getType()->isIntegerTy()) + // Only for signed relational comparisons of integers. + if (!Cmp->getOperand(0)->getType()->isIntOrIntVectorTy()) return false; if (!Cmp->isSigned()) @@ -324,16 +316,13 @@ static bool processICmp(ICmpInst *Cmp, LazyValueInfo *LVI) { static bool constantFoldCmp(CmpInst *Cmp, LazyValueInfo *LVI) { Value *Op0 = Cmp->getOperand(0); Value *Op1 = Cmp->getOperand(1); - LazyValueInfo::Tristate Result = - LVI->getPredicateAt(Cmp->getPredicate(), Op0, Op1, Cmp, - /*UseBlockValue=*/true); - if (Result == LazyValueInfo::Unknown) + Constant *Res = LVI->getPredicateAt(Cmp->getPredicate(), Op0, Op1, Cmp, + /*UseBlockValue=*/true); + if (!Res) return false; ++NumCmps; - Constant *TorF = - ConstantInt::get(CmpInst::makeCmpResultType(Op0->getType()), Result); - Cmp->replaceAllUsesWith(TorF); + Cmp->replaceAllUsesWith(Res); Cmp->eraseFromParent(); return true; } @@ -371,14 +360,15 @@ static bool processSwitch(SwitchInst *I, LazyValueInfo *LVI, { // Scope for SwitchInstProfUpdateWrapper. It must not live during // ConstantFoldTerminator() as the underlying SwitchInst can be changed. SwitchInstProfUpdateWrapper SI(*I); + unsigned ReachableCaseCount = 0; for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) { ConstantInt *Case = CI->getCaseValue(); - LazyValueInfo::Tristate State = + auto *Res = dyn_cast_or_null<ConstantInt>( LVI->getPredicateAt(CmpInst::ICMP_EQ, Cond, Case, I, - /* UseBlockValue */ true); + /* UseBlockValue */ true)); - if (State == LazyValueInfo::False) { + if (Res && Res->isZero()) { // This case never fires - remove it. BasicBlock *Succ = CI->getCaseSuccessor(); Succ->removePredecessor(BB); @@ -395,7 +385,7 @@ static bool processSwitch(SwitchInst *I, LazyValueInfo *LVI, DTU.applyUpdatesPermissive({{DominatorTree::Delete, BB, Succ}}); continue; } - if (State == LazyValueInfo::True) { + if (Res && Res->isOne()) { // This case always fires. Arrange for the switch to be turned into an // unconditional branch by replacing the switch condition with the case // value. 
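In the CorrelatedValuePropagation hunks around this point, LazyValueInfo::getPredicateAt now returns a Constant * (null when the predicate is unknown) instead of a LazyValueInfo::Tristate. A condensed sketch of the new calling convention, modeled on constantFoldCmp above; the function name here is illustrative only:

    // Sketch, not a verbatim excerpt: null replaces LazyValueInfo::Unknown,
    // and a non-null result is the i1 (or vector-of-i1) true/false constant.
    static bool foldKnownCmp(CmpInst *Cmp, LazyValueInfo *LVI) {
      Constant *Res =
          LVI->getPredicateAt(Cmp->getPredicate(), Cmp->getOperand(0),
                              Cmp->getOperand(1), Cmp, /*UseBlockValue=*/true);
      if (!Res)
        return false;                 // unknown: leave the compare alone
      Cmp->replaceAllUsesWith(Res);   // known true/false: fold all uses
      Cmp->eraseFromParent();
      return true;
    }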
@@ -407,6 +397,31 @@ static bool processSwitch(SwitchInst *I, LazyValueInfo *LVI, // Increment the case iterator since we didn't delete it. ++CI; + ++ReachableCaseCount; + } + + BasicBlock *DefaultDest = SI->getDefaultDest(); + if (ReachableCaseCount > 1 && + !isa<UnreachableInst>(DefaultDest->getFirstNonPHIOrDbg())) { + ConstantRange CR = LVI->getConstantRangeAtUse(I->getOperandUse(0), + /*UndefAllowed*/ false); + // The default dest is unreachable if all cases are covered. + if (!CR.isSizeLargerThan(ReachableCaseCount)) { + BasicBlock *NewUnreachableBB = + BasicBlock::Create(BB->getContext(), "default.unreachable", + BB->getParent(), DefaultDest); + new UnreachableInst(BB->getContext(), NewUnreachableBB); + + DefaultDest->removePredecessor(BB); + SI->setDefaultDest(NewUnreachableBB); + + if (SuccessorsCount[DefaultDest] == 1) + DTU.applyUpdates({{DominatorTree::Delete, BB, DefaultDest}}); + DTU.applyUpdates({{DominatorTree::Insert, BB, NewUnreachableBB}}); + + ++NumDeadCases; + Changed = true; + } } } @@ -483,12 +498,8 @@ static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI); // because it is negation-invariant. static bool processAbsIntrinsic(IntrinsicInst *II, LazyValueInfo *LVI) { Value *X = II->getArgOperand(0); - Type *Ty = X->getType(); - if (!Ty->isIntegerTy()) - return false; - bool IsIntMinPoison = cast<ConstantInt>(II->getArgOperand(1))->isOne(); - APInt IntMin = APInt::getSignedMinValue(Ty->getScalarSizeInBits()); + APInt IntMin = APInt::getSignedMinValue(X->getType()->getScalarSizeInBits()); ConstantRange Range = LVI->getConstantRangeAtUse( II->getOperandUse(0), /*UndefAllowed*/ IsIntMinPoison); @@ -503,7 +514,7 @@ static bool processAbsIntrinsic(IntrinsicInst *II, LazyValueInfo *LVI) { // Is X in [IntMin, 0]? NOTE: INT_MIN is fine! if (Range.getSignedMax().isNonPositive()) { IRBuilder<> B(II); - Value *NegX = B.CreateNeg(X, II->getName(), /*HasNUW=*/false, + Value *NegX = B.CreateNeg(X, II->getName(), /*HasNSW=*/IsIntMinPoison); ++NumAbs; II->replaceAllUsesWith(NegX); @@ -527,18 +538,69 @@ static bool processAbsIntrinsic(IntrinsicInst *II, LazyValueInfo *LVI) { return false; } +static bool processCmpIntrinsic(CmpIntrinsic *CI, LazyValueInfo *LVI) { + ConstantRange LHS_CR = + LVI->getConstantRangeAtUse(CI->getOperandUse(0), /*UndefAllowed*/ false); + ConstantRange RHS_CR = + LVI->getConstantRangeAtUse(CI->getOperandUse(1), /*UndefAllowed*/ false); + + if (LHS_CR.icmp(CI->getGTPredicate(), RHS_CR)) { + ++NumCmpIntr; + CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 1)); + CI->eraseFromParent(); + return true; + } + if (LHS_CR.icmp(CI->getLTPredicate(), RHS_CR)) { + ++NumCmpIntr; + CI->replaceAllUsesWith(ConstantInt::getSigned(CI->getType(), -1)); + CI->eraseFromParent(); + return true; + } + if (LHS_CR.icmp(ICmpInst::ICMP_EQ, RHS_CR)) { + ++NumCmpIntr; + CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0)); + CI->eraseFromParent(); + return true; + } + + return false; +} + // See if this min/max intrinsic always picks it's one specific operand. 
+// If not, check whether we can canonicalize signed minmax into unsigned version static bool processMinMaxIntrinsic(MinMaxIntrinsic *MM, LazyValueInfo *LVI) { CmpInst::Predicate Pred = CmpInst::getNonStrictPredicate(MM->getPredicate()); - LazyValueInfo::Tristate Result = LVI->getPredicateAt( - Pred, MM->getLHS(), MM->getRHS(), MM, /*UseBlockValue=*/true); - if (Result == LazyValueInfo::Unknown) - return false; + ConstantRange LHS_CR = LVI->getConstantRangeAtUse(MM->getOperandUse(0), + /*UndefAllowed*/ false); + ConstantRange RHS_CR = LVI->getConstantRangeAtUse(MM->getOperandUse(1), + /*UndefAllowed*/ false); + if (LHS_CR.icmp(Pred, RHS_CR)) { + ++NumMinMax; + MM->replaceAllUsesWith(MM->getLHS()); + MM->eraseFromParent(); + return true; + } + if (RHS_CR.icmp(Pred, LHS_CR)) { + ++NumMinMax; + MM->replaceAllUsesWith(MM->getRHS()); + MM->eraseFromParent(); + return true; + } - ++NumMinMax; - MM->replaceAllUsesWith(MM->getOperand(!Result)); - MM->eraseFromParent(); - return true; + if (MM->isSigned() && + ConstantRange::areInsensitiveToSignednessOfICmpPredicate(LHS_CR, + RHS_CR)) { + ++NumSMinMax; + IRBuilder<> B(MM); + MM->replaceAllUsesWith(B.CreateBinaryIntrinsic( + MM->getIntrinsicID() == Intrinsic::smin ? Intrinsic::umin + : Intrinsic::umax, + MM->getLHS(), MM->getRHS())); + MM->eraseFromParent(); + return true; + } + + return false; } // Rewrite this with.overflow intrinsic as non-overflowing. @@ -573,7 +635,7 @@ static bool processSaturatingInst(SaturatingInst *SI, LazyValueInfo *LVI) { bool NSW = SI->isSigned(); bool NUW = !SI->isSigned(); BinaryOperator *BinOp = BinaryOperator::Create( - Opcode, SI->getLHS(), SI->getRHS(), SI->getName(), SI); + Opcode, SI->getLHS(), SI->getRHS(), SI->getName(), SI->getIterator()); BinOp->setDebugLoc(SI->getDebugLoc()); setDeducedOverflowingFlags(BinOp, Opcode, NSW, NUW); @@ -595,20 +657,22 @@ static bool processCallSite(CallBase &CB, LazyValueInfo *LVI) { return processAbsIntrinsic(&cast<IntrinsicInst>(CB), LVI); } + if (auto *CI = dyn_cast<CmpIntrinsic>(&CB)) { + return processCmpIntrinsic(CI, LVI); + } + if (auto *MM = dyn_cast<MinMaxIntrinsic>(&CB)) { return processMinMaxIntrinsic(MM, LVI); } if (auto *WO = dyn_cast<WithOverflowInst>(&CB)) { - if (WO->getLHS()->getType()->isIntegerTy() && willNotOverflow(WO, LVI)) { + if (willNotOverflow(WO, LVI)) return processOverflowIntrinsic(WO, LVI); - } } if (auto *SI = dyn_cast<SaturatingInst>(&CB)) { - if (SI->getType()->isIntegerTy() && willNotOverflow(SI, LVI)) { + if (willNotOverflow(SI, LVI)) return processSaturatingInst(SI, LVI); - } } bool Changed = false; @@ -643,11 +707,12 @@ static bool processCallSite(CallBase &CB, LazyValueInfo *LVI) { // relatively expensive analysis for constants which are obviously either // null or non-null to start with. 
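Stepping back to the processMinMaxIntrinsic change above, here is a minimal sketch (plain C++, not LLVM code) of why the signed-to-unsigned min/max rewrite is sound in the simplest situation the range check covers, namely when both operands are known non-negative.

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  // Value analysis proved both operands are non-negative, so the signed and
  // unsigned orderings agree and an smin can become a umin.
  int32_t A = 7, B = 300;
  uint32_t U = std::min(uint32_t(A), uint32_t(B)); // umin
  assert(int32_t(U) == std::min(A, B));            // same result as smin
}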
if (Type && !CB.paramHasAttr(ArgNo, Attribute::NonNull) && - !isa<Constant>(V) && - LVI->getPredicateAt(ICmpInst::ICMP_EQ, V, - ConstantPointerNull::get(Type), &CB, - /*UseBlockValue=*/false) == LazyValueInfo::False) - ArgNos.push_back(ArgNo); + !isa<Constant>(V)) + if (auto *Res = dyn_cast_or_null<ConstantInt>(LVI->getPredicateAt( + ICmpInst::ICMP_EQ, V, ConstantPointerNull::get(Type), &CB, + /*UseBlockValue=*/false)); + Res && Res->isZero()) + ArgNos.push_back(ArgNo); ArgNo++; } @@ -682,11 +747,10 @@ static bool narrowSDivOrSRem(BinaryOperator *Instr, const ConstantRange &LCR, const ConstantRange &RCR) { assert(Instr->getOpcode() == Instruction::SDiv || Instr->getOpcode() == Instruction::SRem); - assert(!Instr->getType()->isVectorTy()); // Find the smallest power of two bitwidth that's sufficient to hold Instr's // operands. - unsigned OrigWidth = Instr->getType()->getIntegerBitWidth(); + unsigned OrigWidth = Instr->getType()->getScalarSizeInBits(); // What is the smallest bit width that can accommodate the entire value ranges // of both of the operands? @@ -709,7 +773,7 @@ static bool narrowSDivOrSRem(BinaryOperator *Instr, const ConstantRange &LCR, ++NumSDivSRemsNarrowed; IRBuilder<> B{Instr}; - auto *TruncTy = Type::getIntNTy(Instr->getContext(), NewWidth); + auto *TruncTy = Instr->getType()->getWithNewBitWidth(NewWidth); auto *LHS = B.CreateTruncOrBitCast(Instr->getOperand(0), TruncTy, Instr->getName() + ".lhs.trunc"); auto *RHS = B.CreateTruncOrBitCast(Instr->getOperand(1), TruncTy, @@ -730,7 +794,6 @@ static bool expandUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR, Type *Ty = Instr->getType(); assert(Instr->getOpcode() == Instruction::UDiv || Instr->getOpcode() == Instruction::URem); - assert(!Ty->isVectorTy()); bool IsRem = Instr->getOpcode() == Instruction::URem; Value *X = Instr->getOperand(0); @@ -788,9 +851,12 @@ static bool expandUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR, Value *FrozenX = X; if (!isGuaranteedNotToBeUndef(X)) FrozenX = B.CreateFreeze(X, X->getName() + ".frozen"); - auto *AdjX = B.CreateNUWSub(FrozenX, Y, Instr->getName() + ".urem"); - auto *Cmp = - B.CreateICmp(ICmpInst::ICMP_ULT, FrozenX, Y, Instr->getName() + ".cmp"); + Value *FrozenY = Y; + if (!isGuaranteedNotToBeUndef(Y)) + FrozenY = B.CreateFreeze(Y, Y->getName() + ".frozen"); + auto *AdjX = B.CreateNUWSub(FrozenX, FrozenY, Instr->getName() + ".urem"); + auto *Cmp = B.CreateICmp(ICmpInst::ICMP_ULT, FrozenX, FrozenY, + Instr->getName() + ".cmp"); ExpandedOp = B.CreateSelect(Cmp, FrozenX, AdjX); } else { auto *Cmp = @@ -810,7 +876,6 @@ static bool narrowUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR, const ConstantRange &YCR) { assert(Instr->getOpcode() == Instruction::UDiv || Instr->getOpcode() == Instruction::URem); - assert(!Instr->getType()->isVectorTy()); // Find the smallest power of two bitwidth that's sufficient to hold Instr's // operands. @@ -823,12 +888,12 @@ static bool narrowUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR, // NewWidth might be greater than OrigWidth if OrigWidth is not a power of // two. 
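A small standalone illustration of the narrowing these functions perform (plain C++, not LLVM code), assuming value analysis proved both operands of a 64-bit udiv fit in 9 bits.

#include <cassert>
#include <cstdint>

int main() {
  // Both operands fit in 9 bits, so the divide can be done in the next
  // power-of-two width (16 bits) and the result zero-extended back.
  uint64_t X = 300, Y = 7;
  uint16_t Q = uint16_t(X) / uint16_t(Y); // narrowed divide
  assert(uint64_t(Q) == X / Y);           // zext(Q) equals the wide result
}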
- if (NewWidth >= Instr->getType()->getIntegerBitWidth()) + if (NewWidth >= Instr->getType()->getScalarSizeInBits()) return false; ++NumUDivURemsNarrowed; IRBuilder<> B{Instr}; - auto *TruncTy = Type::getIntNTy(Instr->getContext(), NewWidth); + auto *TruncTy = Instr->getType()->getWithNewBitWidth(NewWidth); auto *LHS = B.CreateTruncOrBitCast(Instr->getOperand(0), TruncTy, Instr->getName() + ".lhs.trunc"); auto *RHS = B.CreateTruncOrBitCast(Instr->getOperand(1), TruncTy, @@ -847,9 +912,6 @@ static bool narrowUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR, static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) { assert(Instr->getOpcode() == Instruction::UDiv || Instr->getOpcode() == Instruction::URem); - if (Instr->getType()->isVectorTy()) - return false; - ConstantRange XCR = LVI->getConstantRangeAtUse(Instr->getOperandUse(0), /*UndefAllowed*/ false); // Allow undef for RHS, as we can assume it is division by zero UB. @@ -864,7 +926,6 @@ static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) { static bool processSRem(BinaryOperator *SDI, const ConstantRange &LCR, const ConstantRange &RCR, LazyValueInfo *LVI) { assert(SDI->getOpcode() == Instruction::SRem); - assert(!SDI->getType()->isVectorTy()); if (LCR.abs().icmp(CmpInst::ICMP_ULT, RCR.abs())) { SDI->replaceAllUsesWith(SDI->getOperand(0)); @@ -888,21 +949,22 @@ static bool processSRem(BinaryOperator *SDI, const ConstantRange &LCR, for (Operand &Op : Ops) { if (Op.D == Domain::NonNegative) continue; - auto *BO = - BinaryOperator::CreateNeg(Op.V, Op.V->getName() + ".nonneg", SDI); + auto *BO = BinaryOperator::CreateNeg(Op.V, Op.V->getName() + ".nonneg", + SDI->getIterator()); BO->setDebugLoc(SDI->getDebugLoc()); Op.V = BO; } - auto *URem = - BinaryOperator::CreateURem(Ops[0].V, Ops[1].V, SDI->getName(), SDI); + auto *URem = BinaryOperator::CreateURem(Ops[0].V, Ops[1].V, SDI->getName(), + SDI->getIterator()); URem->setDebugLoc(SDI->getDebugLoc()); auto *Res = URem; // If the divident was non-positive, we need to negate the result. if (Ops[0].D == Domain::NonPositive) { - Res = BinaryOperator::CreateNeg(Res, Res->getName() + ".neg", SDI); + Res = BinaryOperator::CreateNeg(Res, Res->getName() + ".neg", + SDI->getIterator()); Res->setDebugLoc(SDI->getDebugLoc()); } @@ -923,7 +985,6 @@ static bool processSRem(BinaryOperator *SDI, const ConstantRange &LCR, static bool processSDiv(BinaryOperator *SDI, const ConstantRange &LCR, const ConstantRange &RCR, LazyValueInfo *LVI) { assert(SDI->getOpcode() == Instruction::SDiv); - assert(!SDI->getType()->isVectorTy()); // Check whether the division folds to a constant. ConstantRange DivCR = LCR.sdiv(RCR); @@ -949,14 +1010,14 @@ static bool processSDiv(BinaryOperator *SDI, const ConstantRange &LCR, for (Operand &Op : Ops) { if (Op.D == Domain::NonNegative) continue; - auto *BO = - BinaryOperator::CreateNeg(Op.V, Op.V->getName() + ".nonneg", SDI); + auto *BO = BinaryOperator::CreateNeg(Op.V, Op.V->getName() + ".nonneg", + SDI->getIterator()); BO->setDebugLoc(SDI->getDebugLoc()); Op.V = BO; } - auto *UDiv = - BinaryOperator::CreateUDiv(Ops[0].V, Ops[1].V, SDI->getName(), SDI); + auto *UDiv = BinaryOperator::CreateUDiv(Ops[0].V, Ops[1].V, SDI->getName(), + SDI->getIterator()); UDiv->setDebugLoc(SDI->getDebugLoc()); UDiv->setIsExact(SDI->isExact()); @@ -964,7 +1025,8 @@ static bool processSDiv(BinaryOperator *SDI, const ConstantRange &LCR, // If the operands had two different domains, we need to negate the result. 
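The comment above is the heart of the sdiv-to-udiv rewrite; a standalone check of the arithmetic (plain C++, not LLVM code), assuming the dividend is known non-positive and the divisor non-negative.

#include <cassert>
#include <cstdint>

int main() {
  int32_t X = -20, Y = 6;                  // NonPositive / NonNegative domains
  uint32_t Q = uint32_t(-X) / uint32_t(Y); // udiv on the negated dividend
  assert(-int32_t(Q) == X / Y);            // negate because the domains differ
}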
if (Ops[0].D != Ops[1].D) { - Res = BinaryOperator::CreateNeg(Res, Res->getName() + ".neg", SDI); + Res = BinaryOperator::CreateNeg(Res, Res->getName() + ".neg", + SDI->getIterator()); Res->setDebugLoc(SDI->getDebugLoc()); } @@ -980,9 +1042,6 @@ static bool processSDiv(BinaryOperator *SDI, const ConstantRange &LCR, static bool processSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) { assert(Instr->getOpcode() == Instruction::SDiv || Instr->getOpcode() == Instruction::SRem); - if (Instr->getType()->isVectorTy()) - return false; - ConstantRange LCR = LVI->getConstantRangeAtUse(Instr->getOperandUse(0), /*AllowUndef*/ false); // Allow undef for RHS, as we can assume it is division by zero UB. @@ -1001,12 +1060,9 @@ static bool processSDivOrSRem(BinaryOperator *Instr, LazyValueInfo *LVI) { } static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) { - if (SDI->getType()->isVectorTy()) - return false; - ConstantRange LRange = LVI->getConstantRangeAtUse(SDI->getOperandUse(0), /*UndefAllowed*/ false); - unsigned OrigWidth = SDI->getType()->getIntegerBitWidth(); + unsigned OrigWidth = SDI->getType()->getScalarSizeInBits(); ConstantRange NegOneOrZero = ConstantRange(APInt(OrigWidth, (uint64_t)-1, true), APInt(OrigWidth, 1)); if (NegOneOrZero.contains(LRange)) { @@ -1022,7 +1078,7 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) { ++NumAShrsConverted; auto *BO = BinaryOperator::CreateLShr(SDI->getOperand(0), SDI->getOperand(1), - "", SDI); + "", SDI->getIterator()); BO->takeName(SDI); BO->setDebugLoc(SDI->getDebugLoc()); BO->setIsExact(SDI->isExact()); @@ -1033,16 +1089,14 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) { } static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) { - if (SDI->getType()->isVectorTy()) - return false; - const Use &Base = SDI->getOperandUse(0); if (!LVI->getConstantRangeAtUse(Base, /*UndefAllowed*/ false) .isAllNonNegative()) return false; ++NumSExt; - auto *ZExt = CastInst::CreateZExtOrBitCast(Base, SDI->getType(), "", SDI); + auto *ZExt = CastInst::CreateZExtOrBitCast(Base, SDI->getType(), "", + SDI->getIterator()); ZExt->takeName(SDI); ZExt->setDebugLoc(SDI->getDebugLoc()); ZExt->setNonNeg(); @@ -1052,20 +1106,43 @@ static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) { return true; } -static bool processZExt(ZExtInst *ZExt, LazyValueInfo *LVI) { - if (ZExt->getType()->isVectorTy()) +static bool processPossibleNonNeg(PossiblyNonNegInst *I, LazyValueInfo *LVI) { + if (I->hasNonNeg()) return false; - if (ZExt->hasNonNeg()) + const Use &Base = I->getOperandUse(0); + if (!LVI->getConstantRangeAtUse(Base, /*UndefAllowed*/ false) + .isAllNonNegative()) return false; - const Use &Base = ZExt->getOperandUse(0); + ++NumNNeg; + I->setNonNeg(); + + return true; +} + +static bool processZExt(ZExtInst *ZExt, LazyValueInfo *LVI) { + return processPossibleNonNeg(cast<PossiblyNonNegInst>(ZExt), LVI); +} + +static bool processUIToFP(UIToFPInst *UIToFP, LazyValueInfo *LVI) { + return processPossibleNonNeg(cast<PossiblyNonNegInst>(UIToFP), LVI); +} + +static bool processSIToFP(SIToFPInst *SIToFP, LazyValueInfo *LVI) { + const Use &Base = SIToFP->getOperandUse(0); if (!LVI->getConstantRangeAtUse(Base, /*UndefAllowed*/ false) .isAllNonNegative()) return false; - ++NumZExt; - ZExt->setNonNeg(); + ++NumSIToFP; + auto *UIToFP = CastInst::Create(Instruction::UIToFP, Base, SIToFP->getType(), + "", SIToFP->getIterator()); + UIToFP->takeName(SIToFP); + UIToFP->setDebugLoc(SIToFP->getDebugLoc()); + UIToFP->setNonNeg(); + 
SIToFP->replaceAllUsesWith(UIToFP); + SIToFP->eraseFromParent(); return true; } @@ -1073,22 +1150,16 @@ static bool processZExt(ZExtInst *ZExt, LazyValueInfo *LVI) { static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) { using OBO = OverflowingBinaryOperator; - if (BinOp->getType()->isVectorTy()) - return false; - bool NSW = BinOp->hasNoSignedWrap(); bool NUW = BinOp->hasNoUnsignedWrap(); if (NSW && NUW) return false; Instruction::BinaryOps Opcode = BinOp->getOpcode(); - Value *LHS = BinOp->getOperand(0); - Value *RHS = BinOp->getOperand(1); - - ConstantRange LRange = - LVI->getConstantRange(LHS, BinOp, /*UndefAllowed*/ false); - ConstantRange RRange = - LVI->getConstantRange(RHS, BinOp, /*UndefAllowed*/ false); + ConstantRange LRange = LVI->getConstantRangeAtUse(BinOp->getOperandUse(0), + /*UndefAllowed=*/false); + ConstantRange RRange = LVI->getConstantRangeAtUse(BinOp->getOperandUse(1), + /*UndefAllowed=*/false); bool Changed = false; bool NewNUW = false, NewNSW = false; @@ -1111,21 +1182,20 @@ static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) { } static bool processAnd(BinaryOperator *BinOp, LazyValueInfo *LVI) { - if (BinOp->getType()->isVectorTy()) - return false; + using namespace llvm::PatternMatch; // Pattern match (and lhs, C) where C includes a superset of bits which might // be set in lhs. This is a common truncation idiom created by instcombine. const Use &LHS = BinOp->getOperandUse(0); - ConstantInt *RHS = dyn_cast<ConstantInt>(BinOp->getOperand(1)); - if (!RHS || !RHS->getValue().isMask()) + const APInt *RHS; + if (!match(BinOp->getOperand(1), m_LowBitMask(RHS))) return false; // We can only replace the AND with LHS based on range info if the range does // not include undef. ConstantRange LRange = LVI->getConstantRangeAtUse(LHS, /*UndefAllowed=*/false); - if (!LRange.getUnsignedMax().ule(RHS->getValue())) + if (!LRange.getUnsignedMax().ule(*RHS)) return false; BinOp->replaceAllUsesWith(LHS); @@ -1177,6 +1247,12 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT, case Instruction::ZExt: BBChanged |= processZExt(cast<ZExtInst>(&II), LVI); break; + case Instruction::UIToFP: + BBChanged |= processUIToFP(cast<UIToFPInst>(&II), LVI); + break; + case Instruction::SIToFP: + BBChanged |= processSIToFP(cast<SIToFPInst>(&II), LVI); + break; case Instruction::Add: case Instruction::Sub: case Instruction::Mul: @@ -1227,6 +1303,12 @@ CorrelatedValuePropagationPass::run(Function &F, FunctionAnalysisManager &AM) { if (!Changed) { PA = PreservedAnalyses::all(); } else { +#if defined(EXPENSIVE_CHECKS) + assert(DT->verify(DominatorTree::VerificationLevel::Full)); +#else + assert(DT->verify(DominatorTree::VerificationLevel::Fast)); +#endif // EXPENSIVE_CHECKS + PA.preserve<DominatorTreeAnalysis>(); PA.preserve<LazyValueAnalysis>(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp index 85d4065286e4..4371b821eae6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp @@ -65,6 +65,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/DomTreeUpdater.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/CFG.h" @@ -95,6 +96,11 @@ static cl::opt<bool> cl::desc("View the CFG before 
DFA Jump Threading"), cl::Hidden, cl::init(false)); +static cl::opt<bool> EarlyExitHeuristic( + "dfa-early-exit-heuristic", + cl::desc("Exit early if an unpredictable value come from the same loop"), + cl::Hidden, cl::init(true)); + static cl::opt<unsigned> MaxPathLength( "dfa-max-path-length", cl::desc("Max number of blocks searched to find a threading path"), @@ -125,17 +131,18 @@ public: explicit operator bool() const { return SI && SIUse; } }; -void unfold(DomTreeUpdater *DTU, SelectInstToUnfold SIToUnfold, +void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold, std::vector<SelectInstToUnfold> *NewSIsToUnfold, std::vector<BasicBlock *> *NewBBs); class DFAJumpThreading { public: - DFAJumpThreading(AssumptionCache *AC, DominatorTree *DT, + DFAJumpThreading(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI, TargetTransformInfo *TTI, OptimizationRemarkEmitter *ORE) - : AC(AC), DT(DT), TTI(TTI), ORE(ORE) {} + : AC(AC), DT(DT), LI(LI), TTI(TTI), ORE(ORE) {} bool run(Function &F); + bool LoopInfoBroken; private: void @@ -151,7 +158,7 @@ private: std::vector<SelectInstToUnfold> NewSIsToUnfold; std::vector<BasicBlock *> NewBBs; - unfold(&DTU, SIToUnfold, &NewSIsToUnfold, &NewBBs); + unfold(&DTU, LI, SIToUnfold, &NewSIsToUnfold, &NewBBs); // Put newly discovered select instructions into the work list. for (const SelectInstToUnfold &NewSIToUnfold : NewSIsToUnfold) @@ -161,6 +168,7 @@ private: AssumptionCache *AC; DominatorTree *DT; + LoopInfo *LI; TargetTransformInfo *TTI; OptimizationRemarkEmitter *ORE; }; @@ -194,7 +202,7 @@ void createBasicBlockAndSinkSelectInst( /// created basic blocks into \p NewBBs. /// /// TODO: merge it with CodeGenPrepare::optimizeSelectInst() if possible. -void unfold(DomTreeUpdater *DTU, SelectInstToUnfold SIToUnfold, +void unfold(DomTreeUpdater *DTU, LoopInfo *LI, SelectInstToUnfold SIToUnfold, std::vector<SelectInstToUnfold> *NewSIsToUnfold, std::vector<BasicBlock *> *NewBBs) { SelectInst *SI = SIToUnfold.getInst(); @@ -300,6 +308,12 @@ void unfold(DomTreeUpdater *DTU, SelectInstToUnfold SIToUnfold, DTU->applyUpdates({{DominatorTree::Insert, StartBlock, TT}, {DominatorTree::Insert, StartBlock, FT}}); + // Preserve loop info + if (Loop *L = LI->getLoopFor(SI->getParent())) { + for (BasicBlock *NewBB : *NewBBs) + L->addBasicBlockToLoop(NewBB, *LI); + } + // The select is now dead. assert(SI->use_empty() && "Select must be dead now"); SI->eraseFromParent(); @@ -378,7 +392,8 @@ inline raw_ostream &operator<<(raw_ostream &OS, const ThreadingPath &TPath) { #endif struct MainSwitch { - MainSwitch(SwitchInst *SI, OptimizationRemarkEmitter *ORE) { + MainSwitch(SwitchInst *SI, LoopInfo *LI, OptimizationRemarkEmitter *ORE) + : LI(LI) { if (isCandidate(SI)) { Instr = SI; } else { @@ -402,7 +417,7 @@ private: /// /// Also, collect select instructions to unfold. bool isCandidate(const SwitchInst *SI) { - std::deque<Value *> Q; + std::deque<std::pair<Value *, BasicBlock *>> Q; SmallSet<Value *, 16> SeenValues; SelectInsts.clear(); @@ -411,22 +426,29 @@ private: if (!isa<PHINode>(SICond)) return false; - addToQueue(SICond, Q, SeenValues); + // The switch must be in a loop. 
+ const Loop *L = LI->getLoopFor(SI->getParent()); + if (!L) + return false; + + addToQueue(SICond, nullptr, Q, SeenValues); while (!Q.empty()) { - Value *Current = Q.front(); + Value *Current = Q.front().first; + BasicBlock *CurrentIncomingBB = Q.front().second; Q.pop_front(); if (auto *Phi = dyn_cast<PHINode>(Current)) { - for (Value *Incoming : Phi->incoming_values()) { - addToQueue(Incoming, Q, SeenValues); + for (BasicBlock *IncomingBB : Phi->blocks()) { + Value *Incoming = Phi->getIncomingValueForBlock(IncomingBB); + addToQueue(Incoming, IncomingBB, Q, SeenValues); } LLVM_DEBUG(dbgs() << "\tphi: " << *Phi << "\n"); } else if (SelectInst *SelI = dyn_cast<SelectInst>(Current)) { if (!isValidSelectInst(SelI)) return false; - addToQueue(SelI->getTrueValue(), Q, SeenValues); - addToQueue(SelI->getFalseValue(), Q, SeenValues); + addToQueue(SelI->getTrueValue(), CurrentIncomingBB, Q, SeenValues); + addToQueue(SelI->getFalseValue(), CurrentIncomingBB, Q, SeenValues); LLVM_DEBUG(dbgs() << "\tselect: " << *SelI << "\n"); if (auto *SelIUse = dyn_cast<PHINode>(SelI->user_back())) SelectInsts.push_back(SelectInstToUnfold(SelI, SelIUse)); @@ -439,6 +461,18 @@ private: // initial switch values that can be ignored (they will hit the // unthreaded switch) but this assumption will get checked later after // paths have been enumerated (in function getStateDefMap). + + // If the unpredictable value comes from the same inner loop it is + // likely that it will also be on the enumerated paths, causing us to + // exit after we have enumerated all the paths. This heuristic save + // compile time because a search for all the paths can become expensive. + if (EarlyExitHeuristic && + L->contains(LI->getLoopFor(CurrentIncomingBB))) { + LLVM_DEBUG(dbgs() + << "\tExiting early due to unpredictability heuristic.\n"); + return false; + } + continue; } } @@ -446,11 +480,12 @@ private: return true; } - void addToQueue(Value *Val, std::deque<Value *> &Q, + void addToQueue(Value *Val, BasicBlock *BB, + std::deque<std::pair<Value *, BasicBlock *>> &Q, SmallSet<Value *, 16> &SeenValues) { if (SeenValues.contains(Val)) return; - Q.push_back(Val); + Q.push_back({Val, BB}); SeenValues.insert(Val); } @@ -488,14 +523,16 @@ private: return true; } + LoopInfo *LI; SwitchInst *Instr = nullptr; SmallVector<SelectInstToUnfold, 4> SelectInsts; }; struct AllSwitchPaths { - AllSwitchPaths(const MainSwitch *MSwitch, OptimizationRemarkEmitter *ORE) - : Switch(MSwitch->getInstr()), SwitchBlock(Switch->getParent()), - ORE(ORE) {} + AllSwitchPaths(const MainSwitch *MSwitch, OptimizationRemarkEmitter *ORE, + LoopInfo *LI) + : Switch(MSwitch->getInstr()), SwitchBlock(Switch->getParent()), ORE(ORE), + LI(LI) {} std::vector<ThreadingPath> &getThreadingPaths() { return TPaths; } unsigned getNumThreadingPaths() { return TPaths.size(); } @@ -516,7 +553,7 @@ struct AllSwitchPaths { return; } - for (PathType Path : LoopPaths) { + for (const PathType &Path : LoopPaths) { ThreadingPath TPath; const BasicBlock *PrevBB = Path.back(); @@ -567,6 +604,12 @@ private: Visited.insert(BB); + // Stop if we have reached the BB out of loop, since its successors have no + // impact on the DFA. + // TODO: Do we need to stop exploring if BB is the outer loop of the switch? 
+ if (!LI->getLoopFor(BB)) + return Res; + // Some blocks have multiple edges to the same successor, and this set // is used to prevent a duplicate path from being generated SmallSet<BasicBlock *, 4> Successors; @@ -708,6 +751,7 @@ private: BasicBlock *SwitchBlock; OptimizationRemarkEmitter *ORE; std::vector<ThreadingPath> TPaths; + LoopInfo *LI; }; struct TransformDFA { @@ -783,7 +827,8 @@ private: return false; } - if (Metrics.convergent) { + // FIXME: Allow jump threading with controlled convergence. + if (Metrics.Convergence != ConvergenceKind::None) { LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, contains " << "convergent instructions.\n"); ORE->emit([&]() { @@ -1254,6 +1299,7 @@ bool DFAJumpThreading::run(Function &F) { SmallVector<AllSwitchPaths, 2> ThreadableLoops; bool MadeChanges = false; + LoopInfoBroken = false; for (BasicBlock &BB : F) { auto *SI = dyn_cast<SwitchInst>(BB.getTerminator()); @@ -1262,7 +1308,7 @@ bool DFAJumpThreading::run(Function &F) { LLVM_DEBUG(dbgs() << "\nCheck if SwitchInst in BB " << BB.getName() << " is a candidate\n"); - MainSwitch Switch(SI, ORE); + MainSwitch Switch(SI, LI, ORE); if (!Switch.getInstr()) continue; @@ -1275,7 +1321,7 @@ bool DFAJumpThreading::run(Function &F) { if (!Switch.getSelectInsts().empty()) MadeChanges = true; - AllSwitchPaths SwitchPaths(&Switch, ORE); + AllSwitchPaths SwitchPaths(&Switch, ORE, LI); SwitchPaths.run(); if (SwitchPaths.getNumThreadingPaths() > 0) { @@ -1286,10 +1332,15 @@ bool DFAJumpThreading::run(Function &F) { // strict requirement but it can cause buggy behavior if there is an // overlap of blocks in different opportunities. There is a lot of room to // experiment with catching more opportunities here. + // NOTE: To release this contraint, we must handle LoopInfo invalidation break; } } +#ifdef NDEBUG + LI->verify(*DT); +#endif + SmallPtrSet<const Value *, 32> EphValues; if (ThreadableLoops.size() > 0) CodeMetrics::collectEphemeralValues(&F, AC, EphValues); @@ -1298,6 +1349,7 @@ bool DFAJumpThreading::run(Function &F) { TransformDFA Transform(&SwitchPaths, DT, AC, TTI, ORE, EphValues); Transform.run(); MadeChanges = true; + LoopInfoBroken = true; } #ifdef EXPENSIVE_CHECKS @@ -1315,13 +1367,16 @@ PreservedAnalyses DFAJumpThreadingPass::run(Function &F, FunctionAnalysisManager &AM) { AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F); DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + LoopInfo &LI = AM.getResult<LoopAnalysis>(F); TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F); OptimizationRemarkEmitter ORE(&F); - - if (!DFAJumpThreading(&AC, &DT, &TTI, &ORE).run(F)) + DFAJumpThreading ThreadImpl(&AC, &DT, &LI, &TTI, &ORE); + if (!ThreadImpl.run(F)) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserve<DominatorTreeAnalysis>(); + if (!ThreadImpl.LoopInfoBroken) + PA.preserve<LoopAnalysis>(); return PA; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp index 380d65836553..931606c6f8fe 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -484,7 +484,7 @@ memoryIsNotModifiedBetween(Instruction *FirstI, Instruction *SecondI, static void shortenAssignment(Instruction *Inst, Value *OriginalDest, uint64_t OldOffsetInBits, uint64_t OldSizeInBits, uint64_t NewSizeInBits, bool IsOverwriteEnd) { - const DataLayout &DL = 
Inst->getModule()->getDataLayout(); + const DataLayout &DL = Inst->getDataLayout(); uint64_t DeadSliceSizeInBits = OldSizeInBits - NewSizeInBits; uint64_t DeadSliceOffsetInBits = OldOffsetInBits + (IsOverwriteEnd ? NewSizeInBits : 0); @@ -526,7 +526,8 @@ static void shortenAssignment(Instruction *Inst, Value *OriginalDest, // returned by getAssignmentMarkers so save a copy of the markers to iterate // over. auto LinkedRange = at::getAssignmentMarkers(Inst); - SmallVector<DPValue *> LinkedDPVAssigns = at::getDPVAssignmentMarkers(Inst); + SmallVector<DbgVariableRecord *> LinkedDVRAssigns = + at::getDVRAssignmentMarkers(Inst); SmallVector<DbgAssignIntrinsic *> Linked(LinkedRange.begin(), LinkedRange.end()); auto InsertAssignForOverlap = [&](auto *Assign) { @@ -554,7 +555,7 @@ static void shortenAssignment(Instruction *Inst, Value *OriginalDest, NewAssign->setKillAddress(); }; for_each(Linked, InsertAssignForOverlap); - for_each(LinkedDPVAssigns, InsertAssignForOverlap); + for_each(LinkedDVRAssigns, InsertAssignForOverlap); } static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart, @@ -634,7 +635,8 @@ static bool tryToShorten(Instruction *DeadI, int64_t &DeadStart, Value *Indices[1] = { ConstantInt::get(DeadWriteLength->getType(), ToRemoveSize)}; Instruction *NewDestGEP = GetElementPtrInst::CreateInBounds( - Type::getInt8Ty(DeadIntrinsic->getContext()), OrigDest, Indices, "", DeadI); + Type::getInt8Ty(DeadIntrinsic->getContext()), OrigDest, Indices, "", + DeadI->getIterator()); NewDestGEP->setDebugLoc(DeadIntrinsic->getDebugLoc()); DeadIntrinsic->setDest(NewDestGEP); } @@ -868,7 +870,7 @@ struct DSEState { PostDominatorTree &PDT, const TargetLibraryInfo &TLI, const LoopInfo &LI) : F(F), AA(AA), EI(DT, &LI), BatchAA(AA, &EI), MSSA(MSSA), DT(DT), - PDT(PDT), TLI(TLI), DL(F.getParent()->getDataLayout()), LI(LI) { + PDT(PDT), TLI(TLI), DL(F.getDataLayout()), LI(LI) { // Collect blocks with throwing instructions not modeled in MemorySSA and // alloc-like objects. unsigned PO = 0; @@ -900,6 +902,16 @@ struct DSEState { }); } + static void pushMemUses(MemoryAccess *Acc, + SmallVectorImpl<MemoryAccess *> &WorkList, + SmallPtrSetImpl<MemoryAccess *> &Visited) { + for (Use &U : Acc->uses()) { + auto *MA = cast<MemoryAccess>(U.getUser()); + if (Visited.insert(MA).second) + WorkList.push_back(MA); + } + }; + LocationSize strengthenLocationSize(const Instruction *I, LocationSize Size) const { if (auto *CB = dyn_cast<CallBase>(I)) { @@ -1155,26 +1167,14 @@ struct DSEState { } /// Returns true if \p Def is not read before returning from the function. - bool isWriteAtEndOfFunction(MemoryDef *Def) { + bool isWriteAtEndOfFunction(MemoryDef *Def, const MemoryLocation &DefLoc) { LLVM_DEBUG(dbgs() << " Check if def " << *Def << " (" << *Def->getMemoryInst() << ") is at the end the function \n"); - - auto MaybeLoc = getLocForWrite(Def->getMemoryInst()); - if (!MaybeLoc) { - LLVM_DEBUG(dbgs() << " ... could not get location for write.\n"); - return false; - } - SmallVector<MemoryAccess *, 4> WorkList; SmallPtrSet<MemoryAccess *, 8> Visited; - auto PushMemUses = [&WorkList, &Visited](MemoryAccess *Acc) { - if (!Visited.insert(Acc).second) - return; - for (Use &U : Acc->uses()) - WorkList.push_back(cast<MemoryAccess>(U.getUser())); - }; - PushMemUses(Def); + + pushMemUses(Def, WorkList, Visited); for (unsigned I = 0; I < WorkList.size(); I++) { if (WorkList.size() >= MemorySSAScanLimit) { LLVM_DEBUG(dbgs() << " ... 
hit exploration limit.\n"); @@ -1186,22 +1186,22 @@ struct DSEState { // AliasAnalysis does not account for loops. Limit elimination to // candidates for which we can guarantee they always store to the same // memory location. - if (!isGuaranteedLoopInvariant(MaybeLoc->Ptr)) + if (!isGuaranteedLoopInvariant(DefLoc.Ptr)) return false; - PushMemUses(cast<MemoryPhi>(UseAccess)); + pushMemUses(cast<MemoryPhi>(UseAccess), WorkList, Visited); continue; } // TODO: Checking for aliasing is expensive. Consider reducing the amount // of times this is called and/or caching it. Instruction *UseInst = cast<MemoryUseOrDef>(UseAccess)->getMemoryInst(); - if (isReadClobber(*MaybeLoc, UseInst)) { + if (isReadClobber(DefLoc, UseInst)) { LLVM_DEBUG(dbgs() << " ... hit read clobber " << *UseInst << ".\n"); return false; } if (MemoryDef *UseDef = dyn_cast<MemoryDef>(UseAccess)) - PushMemUses(UseDef); + pushMemUses(UseDef, WorkList, Visited); } return true; } @@ -1503,12 +1503,9 @@ struct DSEState { LLVM_DEBUG(dbgs() << " Checking for reads of " << *MaybeDeadAccess << " (" << *MaybeDeadI << ")\n"); - SmallSetVector<MemoryAccess *, 32> WorkList; - auto PushMemUses = [&WorkList](MemoryAccess *Acc) { - for (Use &U : Acc->uses()) - WorkList.insert(cast<MemoryAccess>(U.getUser())); - }; - PushMemUses(MaybeDeadAccess); + SmallVector<MemoryAccess *, 32> WorkList; + SmallPtrSet<MemoryAccess *, 32> Visited; + pushMemUses(MaybeDeadAccess, WorkList, Visited); // Check if DeadDef may be read. for (unsigned I = 0; I < WorkList.size(); I++) { @@ -1532,7 +1529,7 @@ struct DSEState { continue; } LLVM_DEBUG(dbgs() << "\n ... adding PHI uses\n"); - PushMemUses(UseAccess); + pushMemUses(UseAccess, WorkList, Visited); continue; } @@ -1557,7 +1554,7 @@ struct DSEState { if (isNoopIntrinsic(cast<MemoryUseOrDef>(UseAccess)->getMemoryInst())) { LLVM_DEBUG(dbgs() << " ... adding uses of intrinsic\n"); - PushMemUses(UseAccess); + pushMemUses(UseAccess, WorkList, Visited); continue; } @@ -1616,7 +1613,7 @@ struct DSEState { return std::nullopt; } } else - PushMemUses(UseDef); + pushMemUses(UseDef, WorkList, Visited); } } @@ -1697,7 +1694,9 @@ struct DSEState { /// Delete dead memory defs and recursively add their operands to ToRemove if /// they became dead. - void deleteDeadInstruction(Instruction *SI) { + void + deleteDeadInstruction(Instruction *SI, + SmallPtrSetImpl<MemoryAccess *> *Deleted = nullptr) { MemorySSAUpdater Updater(&MSSA); SmallVector<Instruction *, 32> NowDeadInsts; NowDeadInsts.push_back(SI); @@ -1718,6 +1717,8 @@ struct DSEState { if (IsMemDef) { auto *MD = cast<MemoryDef>(MA); SkipStores.insert(MD); + if (Deleted) + Deleted->insert(MD); if (auto *SI = dyn_cast<StoreInst>(MD->getMemoryInst())) { if (SI->getValueOperand()->getType()->isPointerTy()) { const Value *UO = getUnderlyingObject(SI->getValueOperand()); @@ -1815,8 +1816,11 @@ struct DSEState { Instruction *DefI = Def->getMemoryInst(); auto DefLoc = getLocForWrite(DefI); - if (!DefLoc || !isRemovable(DefI)) + if (!DefLoc || !isRemovable(DefI)) { + LLVM_DEBUG(dbgs() << " ... could not get location for write or " + "instruction not removable.\n"); continue; + } // NOTE: Currently eliminating writes at the end of a function is // limited to MemoryDefs with a single underlying object, to save @@ -1827,7 +1831,7 @@ struct DSEState { if (!isInvisibleToCallerAfterRet(UO)) continue; - if (isWriteAtEndOfFunction(Def)) { + if (isWriteAtEndOfFunction(Def, *DefLoc)) { // See through pointer-to-pointer bitcasts LLVM_DEBUG(dbgs() << " ... 
MemoryDef is not accessed until the end " "of the function\n"); @@ -1919,6 +1923,57 @@ struct DSEState { return true; } + // Check if there is a dominating condition, that implies that the value + // being stored in a ptr is already present in the ptr. + bool dominatingConditionImpliesValue(MemoryDef *Def) { + auto *StoreI = cast<StoreInst>(Def->getMemoryInst()); + BasicBlock *StoreBB = StoreI->getParent(); + Value *StorePtr = StoreI->getPointerOperand(); + Value *StoreVal = StoreI->getValueOperand(); + + DomTreeNode *IDom = DT.getNode(StoreBB)->getIDom(); + if (!IDom) + return false; + + auto *BI = dyn_cast<BranchInst>(IDom->getBlock()->getTerminator()); + if (!BI || !BI->isConditional()) + return false; + + // In case both blocks are the same, it is not possible to determine + // if optimization is possible. (We would not want to optimize a store + // in the FalseBB if condition is true and vice versa.) + if (BI->getSuccessor(0) == BI->getSuccessor(1)) + return false; + + Instruction *ICmpL; + ICmpInst::Predicate Pred; + if (!match(BI->getCondition(), + m_c_ICmp(Pred, + m_CombineAnd(m_Load(m_Specific(StorePtr)), + m_Instruction(ICmpL)), + m_Specific(StoreVal))) || + !ICmpInst::isEquality(Pred)) + return false; + + // In case the else blocks also branches to the if block or the other way + // around it is not possible to determine if the optimization is possible. + if (Pred == ICmpInst::ICMP_EQ && + !DT.dominates(BasicBlockEdge(BI->getParent(), BI->getSuccessor(0)), + StoreBB)) + return false; + + if (Pred == ICmpInst::ICMP_NE && + !DT.dominates(BasicBlockEdge(BI->getParent(), BI->getSuccessor(1)), + StoreBB)) + return false; + + MemoryAccess *LoadAcc = MSSA.getMemoryAccess(ICmpL); + MemoryAccess *ClobAcc = + MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(Def, BatchAA); + + return MSSA.dominates(ClobAcc, LoadAcc); + } + /// \returns true if \p Def is a no-op store, either because it /// directly stores back a loaded value or stores zero to a calloced object. bool storeIsNoop(MemoryDef *Def, const Value *DefUO) { @@ -1949,6 +2004,9 @@ struct DSEState { if (!Store) return false; + if (dominatingConditionImpliesValue(Def)) + return true; + if (auto *LoadI = dyn_cast<LoadInst>(Store->getOperand(0))) { if (LoadI->getPointerOperand() == Store->getOperand(1)) { // Get the defining access for the load. @@ -2049,10 +2107,12 @@ struct DSEState { if (auto *MemSetI = dyn_cast<MemSetInst>(UpperInst)) { if (auto *SI = dyn_cast<StoreInst>(DefInst)) { // MemSetInst must have a write location. - MemoryLocation UpperLoc = *getLocForWrite(UpperInst); + auto UpperLoc = getLocForWrite(UpperInst); + if (!UpperLoc) + return false; int64_t InstWriteOffset = 0; int64_t DepWriteOffset = 0; - auto OR = isOverwrite(UpperInst, DefInst, UpperLoc, *MaybeDefLoc, + auto OR = isOverwrite(UpperInst, DefInst, *UpperLoc, *MaybeDefLoc, InstWriteOffset, DepWriteOffset); Value *StoredByte = isBytewiseValue(SI->getValueOperand(), DL); return StoredByte && StoredByte == MemSetI->getOperand(1) && @@ -2111,7 +2171,12 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit; unsigned PartialLimit = MemorySSAPartialStoreLimit; // Worklist of MemoryAccesses that may be killed by KillingDef. - SetVector<MemoryAccess *> ToCheck; + SmallSetVector<MemoryAccess *, 8> ToCheck; + // Track MemoryAccesses that have been deleted in the loop below, so we can + // skip them. Don't use SkipStores for this, which may contain reused + // MemoryAccess addresses. 
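Returning to dominatingConditionImpliesValue above, a plain C++ sketch (illustrative only) of the store it proves dead: the branch condition loads the same pointer and compares it for equality with the stored value, so on the taken edge the store writes back what the location already holds.

void keep(int *P, int V) {
  if (*P == V) {
    // ... code with no intervening writes to *P ...
    *P = V; // no-op: the dominating condition already implies *P == V
  }
}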
+ SmallPtrSet<MemoryAccess *, 8> Deleted; + [[maybe_unused]] unsigned OrigNumSkipStores = State.SkipStores.size(); ToCheck.insert(KillingDef->getDefiningAccess()); bool Shortend = false; @@ -2119,7 +2184,7 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, // Check if MemoryAccesses in the worklist are killed by KillingDef. for (unsigned I = 0; I < ToCheck.size(); I++) { MemoryAccess *Current = ToCheck[I]; - if (State.SkipStores.count(Current)) + if (Deleted.contains(Current)) continue; std::optional<MemoryAccess *> MaybeDeadAccess = State.getDomMemoryDef( @@ -2166,7 +2231,7 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, continue; LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *DeadI << "\n KILLER: " << *KillingI << '\n'); - State.deleteDeadInstruction(DeadI); + State.deleteDeadInstruction(DeadI, &Deleted); ++NumFastStores; MadeChange = true; } else { @@ -2203,7 +2268,7 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, Shortend = true; // Remove killing store and remove any outstanding overlap // intervals for the updated store. - State.deleteDeadInstruction(KillingSI); + State.deleteDeadInstruction(KillingSI, &Deleted); auto I = State.IOLs.find(DeadSI->getParent()); if (I != State.IOLs.end()) I->second.erase(DeadSI); @@ -2215,13 +2280,16 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, if (OR == OW_Complete) { LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *DeadI << "\n KILLER: " << *KillingI << '\n'); - State.deleteDeadInstruction(DeadI); + State.deleteDeadInstruction(DeadI, &Deleted); ++NumFastStores; MadeChange = true; } } } + assert(State.SkipStores.size() - OrigNumSkipStores == Deleted.size() && + "SkipStores and Deleted out of sync?"); + // Check if the store is a no-op. if (!Shortend && State.storeIsNoop(KillingDef, KillingUndObj)) { LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: " << *KillingI diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/DivRemPairs.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/DivRemPairs.cpp index 57d3f312186e..d8aea1e810e9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/DivRemPairs.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/DivRemPairs.cpp @@ -215,6 +215,7 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, RemInst = RealRem; // And replace the original instruction with the new one. OrigRemInst->replaceAllUsesWith(RealRem); + RealRem->setDebugLoc(OrigRemInst->getDebugLoc()); OrigRemInst->eraseFromParent(); NumRecomposed++; // Note that we have left ((X / Y) * Y) around. @@ -366,7 +367,9 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, if (!DivDominates) DivInst->moveBefore(RemInst); Mul->insertAfter(RemInst); + Mul->setDebugLoc(RemInst->getDebugLoc()); Sub->insertAfter(Mul); + Sub->setDebugLoc(RemInst->getDebugLoc()); // If DivInst has the exact flag, remove it. Otherwise this optimization // may replace a well-defined value 'X % Y' with poison. @@ -381,16 +384,19 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, // %mul = mul %div, 1 // %mul = undef // %rem = sub %x, %mul // %rem = undef - undef = undef // If X is not frozen, %rem becomes undef after transformation. 
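The recomposition above relies on the identity X % Y == X - (X / Y) * Y for nonzero Y; the freezes inserted below exist because the identity only holds when X and Y stand for single fixed values, which undef does not guarantee. A standalone check of the identity (plain C++, not LLVM code):

#include <cassert>

int main() {
  int X = -17, Y = 5;
  assert(X % Y == X - (X / Y) * Y); // holds for any fixed X and nonzero Y
}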
- // TODO: We need a undef-specific checking function in ValueTracking - if (!isGuaranteedNotToBeUndefOrPoison(X, nullptr, DivInst, &DT)) { - auto *FrX = new FreezeInst(X, X->getName() + ".frozen", DivInst); + if (!isGuaranteedNotToBeUndef(X, nullptr, DivInst, &DT)) { + auto *FrX = + new FreezeInst(X, X->getName() + ".frozen", DivInst->getIterator()); + FrX->setDebugLoc(DivInst->getDebugLoc()); DivInst->setOperand(0, FrX); Sub->setOperand(0, FrX); } // Same for Y. If X = 1 and Y = (undef | 1), %rem in src is either 1 or 0, // but %rem in tgt can be one of many integer values. - if (!isGuaranteedNotToBeUndefOrPoison(Y, nullptr, DivInst, &DT)) { - auto *FrY = new FreezeInst(Y, Y->getName() + ".frozen", DivInst); + if (!isGuaranteedNotToBeUndef(Y, nullptr, DivInst, &DT)) { + auto *FrY = + new FreezeInst(Y, Y->getName() + ".frozen", DivInst->getIterator()); + FrY->setDebugLoc(DivInst->getDebugLoc()); DivInst->setOperand(1, FrY); Mul->setOperand(1, FrY); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/EarlyCSE.cpp index f736d429cb63..cf11f5bc885a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/EarlyCSE.cpp @@ -1833,7 +1833,7 @@ PreservedAnalyses EarlyCSEPass::run(Function &F, auto *MSSA = UseMemorySSA ? &AM.getResult<MemorySSAAnalysis>(F).getMSSA() : nullptr; - EarlyCSE CSE(F.getParent()->getDataLayout(), TLI, TTI, DT, AC, MSSA); + EarlyCSE CSE(F.getDataLayout(), TLI, TTI, DT, AC, MSSA); if (!CSE.run()) return PreservedAnalyses::all(); @@ -1887,7 +1887,7 @@ public: auto *MSSA = UseMemorySSA ? &getAnalysis<MemorySSAWrapperPass>().getMSSA() : nullptr; - EarlyCSE CSE(F.getParent()->getDataLayout(), TLI, TTI, DT, AC, MSSA); + EarlyCSE CSE(F.getDataLayout(), TLI, TTI, DT, AC, MSSA); return CSE.run(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp index ad2041cd4253..213d0f389c2e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp @@ -21,7 +21,7 @@ using namespace llvm; -#define DEBUG_TYPE "flattencfg" +#define DEBUG_TYPE "flatten-cfg" namespace { struct FlattenCFGLegacyPass : public FunctionPass { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Float2Int.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Float2Int.cpp index ccca8bcc1a56..a4a1438dbe41 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Float2Int.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Float2Int.cpp @@ -311,7 +311,7 @@ void Float2IntPass::walkForwards() { } // If there is a valid transform to be done, do it. -bool Float2IntPass::validateAndTransform() { +bool Float2IntPass::validateAndTransform(const DataLayout &DL) { bool MadeChange = false; // Iterate over every disjoint partition of the def-use graph. @@ -359,9 +359,7 @@ bool Float2IntPass::validateAndTransform() { // The number of bits required is the maximum of the upper and // lower limits, plus one so it can be signed. 
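A small worked instance of that bit-width computation (plain C++, not LLVM code), assuming the value range was proved to be [0, 350]; std::bit_width here plays the role of counting magnitude bits.

#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  uint64_t Hi = 350;                       // range proved to be [0, 350]
  unsigned MinBW = std::bit_width(Hi) + 1; // 9 magnitude bits + 1 sign bit
  assert(MinBW == 10);                     // small enough for a 16-bit type
}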
- unsigned MinBW = std::max(R.getLower().getSignificantBits(), - R.getUpper().getSignificantBits()) + - 1; + unsigned MinBW = R.getMinSignedBits() + 1; LLVM_DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n"); // If we've run off the realms of the exactly representable integers, @@ -376,15 +374,23 @@ bool Float2IntPass::validateAndTransform() { LLVM_DEBUG(dbgs() << "F2I: Value not guaranteed to be representable!\n"); continue; } - if (MinBW > 64) { - LLVM_DEBUG( - dbgs() << "F2I: Value requires more than 64 bits to represent!\n"); - continue; - } - // OK, R is known to be representable. Now pick a type for it. - // FIXME: Pick the smallest legal type that will fit. - Type *Ty = (MinBW > 32) ? Type::getInt64Ty(*Ctx) : Type::getInt32Ty(*Ctx); + // OK, R is known to be representable. + // Pick the smallest legal type that will fit. + Type *Ty = DL.getSmallestLegalIntType(*Ctx, MinBW); + if (!Ty) { + // Every supported target supports 64-bit and 32-bit integers, + // so fallback to a 32 or 64-bit integer if the value fits. + if (MinBW <= 32) { + Ty = Type::getInt32Ty(*Ctx); + } else if (MinBW <= 64) { + Ty = Type::getInt64Ty(*Ctx); + } else { + LLVM_DEBUG(dbgs() << "F2I: Value requires more bits to represent than " + "the target supports!\n"); + continue; + } + } for (auto MI = ECs.member_begin(It), ME = ECs.member_end(); MI != ME; ++MI) @@ -491,7 +497,8 @@ bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) { walkBackwards(); walkForwards(); - bool Modified = validateAndTransform(); + const DataLayout &DL = F.getDataLayout(); + bool Modified = validateAndTransform(DL); if (Modified) cleanup(); return Modified; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVN.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVN.cpp index e36578f3de7a..db39d8621d07 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVN.cpp @@ -33,6 +33,7 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionPrecedenceTracking.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" @@ -419,7 +420,7 @@ GVNPass::ValueTable::createExtractvalueExpr(ExtractValueInst *EI) { GVNPass::Expression GVNPass::ValueTable::createGEPExpr(GetElementPtrInst *GEP) { Expression E; Type *PtrTy = GEP->getType()->getScalarType(); - const DataLayout &DL = GEP->getModule()->getDataLayout(); + const DataLayout &DL = GEP->getDataLayout(); unsigned BitWidth = DL.getIndexTypeSizeInBits(PtrTy); MapVector<Value *, APInt> VariableOffsets; APInt ConstantOffset(BitWidth, 0); @@ -725,6 +726,69 @@ void GVNPass::ValueTable::verifyRemoved(const Value *V) const { } //===----------------------------------------------------------------------===// +// LeaderMap External Functions +//===----------------------------------------------------------------------===// + +/// Push a new Value to the LeaderTable onto the list for its value number. 
+void GVNPass::LeaderMap::insert(uint32_t N, Value *V, const BasicBlock *BB) { + LeaderListNode &Curr = NumToLeaders[N]; + if (!Curr.Entry.Val) { + Curr.Entry.Val = V; + Curr.Entry.BB = BB; + return; + } + + LeaderListNode *Node = TableAllocator.Allocate<LeaderListNode>(); + Node->Entry.Val = V; + Node->Entry.BB = BB; + Node->Next = Curr.Next; + Curr.Next = Node; +} + +/// Scan the list of values corresponding to a given +/// value number, and remove the given instruction if encountered. +void GVNPass::LeaderMap::erase(uint32_t N, Instruction *I, + const BasicBlock *BB) { + LeaderListNode *Prev = nullptr; + LeaderListNode *Curr = &NumToLeaders[N]; + + while (Curr && (Curr->Entry.Val != I || Curr->Entry.BB != BB)) { + Prev = Curr; + Curr = Curr->Next; + } + + if (!Curr) + return; + + if (Prev) { + Prev->Next = Curr->Next; + } else { + if (!Curr->Next) { + Curr->Entry.Val = nullptr; + Curr->Entry.BB = nullptr; + } else { + LeaderListNode *Next = Curr->Next; + Curr->Entry.Val = Next->Entry.Val; + Curr->Entry.BB = Next->Entry.BB; + Curr->Next = Next->Next; + } + } +} + +void GVNPass::LeaderMap::verifyRemoved(const Value *V) const { + // Walk through the value number scope to make sure the instruction isn't + // ferreted away in it. + for (const auto &I : NumToLeaders) { + (void)I; + assert(I.second.Entry.Val != V && "Inst still in value numbering scope!"); + assert( + std::none_of(leader_iterator(&I.second), leader_iterator(nullptr), + [=](const LeaderTableEntry &E) { return E.Val == V; }) && + "Inst still in value numbering scope!"); + } +} + +//===----------------------------------------------------------------------===// // GVN Pass //===----------------------------------------------------------------------===// @@ -1008,7 +1072,7 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *Load, GVNPass &gvn) const { Value *Res; Type *LoadTy = Load->getType(); - const DataLayout &DL = Load->getModule()->getDataLayout(); + const DataLayout &DL = Load->getDataLayout(); if (isSimpleValue()) { Res = getSimpleValue(); if (Res->getType() != LoadTy) { @@ -1056,7 +1120,8 @@ Value *AvailableValue::MaterializeAdjustedValue(LoadInst *Load, // Introduce a new value select for a load from an eligible pointer select. 
SelectInst *Sel = getSelectValue(); assert(V1 && V2 && "both value operands of the select must be present"); - Res = SelectInst::Create(Sel->getCondition(), V1, V2, "", Sel); + Res = + SelectInst::Create(Sel->getCondition(), V1, V2, "", Sel->getIterator()); } else { llvm_unreachable("Should not materialize value from dead block"); } @@ -1173,7 +1238,7 @@ GVNPass::AnalyzeLoadAvailability(LoadInst *Load, MemDepResult DepInfo, Instruction *DepInst = DepInfo.getInst(); - const DataLayout &DL = Load->getModule()->getDataLayout(); + const DataLayout &DL = Load->getDataLayout(); if (DepInfo.isClobber()) { // If the dependence is to a store that writes to a superset of the bits // read by the load, we can extract the bits we need for the load from the @@ -1412,10 +1477,10 @@ void GVNPass::eliminatePartiallyRedundantLoad( BasicBlock *UnavailableBlock = AvailableLoad.first; Value *LoadPtr = AvailableLoad.second; - auto *NewLoad = - new LoadInst(Load->getType(), LoadPtr, Load->getName() + ".pre", - Load->isVolatile(), Load->getAlign(), Load->getOrdering(), - Load->getSyncScopeID(), UnavailableBlock->getTerminator()); + auto *NewLoad = new LoadInst( + Load->getType(), LoadPtr, Load->getName() + ".pre", Load->isVolatile(), + Load->getAlign(), Load->getOrdering(), Load->getSyncScopeID(), + UnavailableBlock->getTerminator()->getIterator()); NewLoad->setDebugLoc(Load->getDebugLoc()); if (MSSAU) { auto *NewAccess = MSSAU->createMemoryAccessInBB( @@ -1465,7 +1530,7 @@ void GVNPass::eliminatePartiallyRedundantLoad( OldLoad->replaceAllUsesWith(NewLoad); replaceValuesPerBlockEntry(ValuesPerBlock, OldLoad, NewLoad); if (uint32_t ValNo = VN.lookup(OldLoad, false)) - removeFromLeaderTable(ValNo, OldLoad, OldLoad->getParent()); + LeaderTable.erase(ValNo, OldLoad, OldLoad->getParent()); VN.erase(OldLoad); removeInstruction(OldLoad); } @@ -1658,7 +1723,7 @@ bool GVNPass::PerformLoadPRE(LoadInst *Load, AvailValInBlkVect &ValuesPerBlock, // Check if the load can safely be moved to all the unavailable predecessors. bool CanDoPRE = true; - const DataLayout &DL = Load->getModule()->getDataLayout(); + const DataLayout &DL = Load->getDataLayout(); SmallVector<Instruction*, 8> NewInsts; for (auto &PredLoad : PredLoads) { BasicBlock *UnavailablePred = PredLoad.first; @@ -1994,8 +2059,9 @@ bool GVNPass::processAssumeIntrinsic(AssumeInst *IntrinsicI) { // Insert a new store to null instruction before the load to indicate that // this code is not reachable. FIXME: We could insert unreachable // instruction directly because we can modify the CFG. - auto *NewS = new StoreInst(PoisonValue::get(Int8Ty), - Constant::getNullValue(PtrTy), IntrinsicI); + auto *NewS = + new StoreInst(PoisonValue::get(Int8Ty), Constant::getNullValue(PtrTy), + IntrinsicI->getIterator()); if (MSSAU) { const MemoryUseOrDef *FirstNonDom = nullptr; const auto *AL = @@ -2201,10 +2267,9 @@ GVNPass::ValueTable::assignExpNewValueNum(Expression &Exp) { /// defined in \p BB. bool GVNPass::ValueTable::areAllValsInBB(uint32_t Num, const BasicBlock *BB, GVNPass &Gvn) { - LeaderTableEntry *Vals = &Gvn.LeaderTable[Num]; - while (Vals && Vals->BB == BB) - Vals = Vals->Next; - return !Vals; + return all_of( + Gvn.LeaderTable.getLeaders(Num), + [=](const LeaderMap::LeaderTableEntry &L) { return L.BB == BB; }); } /// Wrap phiTranslateImpl to provide caching functionality. 
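Before the next hunk, a plain C++ illustration (not from the sources) of the partial redundancy that PerformLoadPRE and eliminatePartiallyRedundantLoad remove: the load is available on one path into a merge point but not the other, so a copy is inserted on the missing path and both paths feed a single value afterwards.

int f(int *P, bool C) {
  int T = 0;
  if (C)
    T = *P + 1;  // *P is already loaded on this path
  return T + *P; // fully redundant once the other path also loads *P
}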
@@ -2226,12 +2291,11 @@ bool GVNPass::ValueTable::areCallValsEqual(uint32_t Num, uint32_t NewNum, const BasicBlock *PhiBlock, GVNPass &Gvn) { CallInst *Call = nullptr; - LeaderTableEntry *Vals = &Gvn.LeaderTable[Num]; - while (Vals) { - Call = dyn_cast<CallInst>(Vals->Val); + auto Leaders = Gvn.LeaderTable.getLeaders(Num); + for (const auto &Entry : Leaders) { + Call = dyn_cast<CallInst>(Entry.Val); if (Call && Call->getParent() == PhiBlock) break; - Vals = Vals->Next; } if (AA->doesNotAccessMemory(Call)) @@ -2324,23 +2388,17 @@ void GVNPass::ValueTable::eraseTranslateCacheEntry( // question. This is fast because dominator tree queries consist of only // a few comparisons of DFS numbers. Value *GVNPass::findLeader(const BasicBlock *BB, uint32_t num) { - LeaderTableEntry Vals = LeaderTable[num]; - if (!Vals.Val) return nullptr; + auto Leaders = LeaderTable.getLeaders(num); + if (Leaders.empty()) + return nullptr; Value *Val = nullptr; - if (DT->dominates(Vals.BB, BB)) { - Val = Vals.Val; - if (isa<Constant>(Val)) return Val; - } - - LeaderTableEntry* Next = Vals.Next; - while (Next) { - if (DT->dominates(Next->BB, BB)) { - if (isa<Constant>(Next->Val)) return Next->Val; - if (!Val) Val = Next->Val; + for (const auto &Entry : Leaders) { + if (DT->dominates(Entry.BB, BB)) { + Val = Entry.Val; + if (isa<Constant>(Val)) + return Val; } - - Next = Next->Next; } return Val; @@ -2417,6 +2475,10 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS, if (isa<Constant>(LHS) || (isa<Argument>(LHS) && !isa<Constant>(RHS))) std::swap(LHS, RHS); assert((isa<Argument>(LHS) || isa<Instruction>(LHS)) && "Unexpected value!"); + const DataLayout &DL = + isa<Argument>(LHS) + ? cast<Argument>(LHS)->getParent()->getDataLayout() + : cast<Instruction>(LHS)->getDataLayout(); // If there is no obvious reason to prefer the left-hand side over the // right-hand side, ensure the longest lived term is on the right-hand side, @@ -2443,23 +2505,32 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS, // using the leader table is about compiling faster, not optimizing better). // The leader table only tracks basic blocks, not edges. Only add to if we // have the simple case where the edge dominates the end. - if (RootDominatesEnd && !isa<Instruction>(RHS)) - addToLeaderTable(LVN, RHS, Root.getEnd()); + if (RootDominatesEnd && !isa<Instruction>(RHS) && + canReplacePointersIfEqual(LHS, RHS, DL)) + LeaderTable.insert(LVN, RHS, Root.getEnd()); // Replace all occurrences of 'LHS' with 'RHS' everywhere in the scope. As // LHS always has at least one use that is not dominated by Root, this will // never do anything if LHS has only one use. if (!LHS->hasOneUse()) { + // Create a callback that captures the DL. + auto canReplacePointersCallBack = [&DL](const Use &U, const Value *To) { + return canReplacePointersInUseIfEqual(U, To, DL); + }; unsigned NumReplacements = DominatesByEdge - ? replaceDominatedUsesWith(LHS, RHS, *DT, Root) - : replaceDominatedUsesWith(LHS, RHS, *DT, Root.getStart()); - - Changed |= NumReplacements > 0; - NumGVNEqProp += NumReplacements; - // Cached information for anything that uses LHS will be invalid. - if (MD) - MD->invalidateCachedPointerInfo(LHS); + ? replaceDominatedUsesWithIf(LHS, RHS, *DT, Root, + canReplacePointersCallBack) + : replaceDominatedUsesWithIf(LHS, RHS, *DT, Root.getStart(), + canReplacePointersCallBack); + + if (NumReplacements > 0) { + Changed = true; + NumGVNEqProp += NumReplacements; + // Cached information for anything that uses LHS will be invalid. 
+ if (MD) + MD->invalidateCachedPointerInfo(LHS); + } } // Now try to deduce additional equalities from this one. For example, if @@ -2530,7 +2601,7 @@ bool GVNPass::propagateEquality(Value *LHS, Value *RHS, // The leader table only tracks basic blocks, not edges. Only add to if we // have the simple case where the edge dominates the end. if (RootDominatesEnd) - addToLeaderTable(Num, NotVal, Root.getEnd()); + LeaderTable.insert(Num, NotVal, Root.getEnd()); continue; } @@ -2550,7 +2621,7 @@ bool GVNPass::processInstruction(Instruction *I) { // to value numbering it. Value numbering often exposes redundancies, for // example if it determines that %y is equal to %x then the instruction // "%z = and i32 %x, %y" becomes "%z = and i32 %x, %x" which we now simplify. - const DataLayout &DL = I->getModule()->getDataLayout(); + const DataLayout &DL = I->getDataLayout(); if (Value *V = simplifyInstruction(I, {DL, TLI, DT, AC})) { bool Changed = false; if (!I->use_empty()) { @@ -2580,7 +2651,7 @@ bool GVNPass::processInstruction(Instruction *I) { return true; unsigned Num = VN.lookupOrAdd(Load); - addToLeaderTable(Num, Load, Load->getParent()); + LeaderTable.insert(Num, Load, Load->getParent()); return false; } @@ -2622,8 +2693,8 @@ bool GVNPass::processInstruction(Instruction *I) { // Remember how many outgoing edges there are to every successor. SmallDenseMap<BasicBlock *, unsigned, 16> SwitchEdges; - for (unsigned i = 0, n = SI->getNumSuccessors(); i != n; ++i) - ++SwitchEdges[SI->getSuccessor(i)]; + for (BasicBlock *Succ : successors(Parent)) + ++SwitchEdges[Succ]; for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i) { @@ -2648,7 +2719,7 @@ bool GVNPass::processInstruction(Instruction *I) { // Allocations are always uniquely numbered, so we can save time and memory // by fast failing them. if (isa<AllocaInst>(I) || I->isTerminator() || isa<PHINode>(I)) { - addToLeaderTable(Num, I, I->getParent()); + LeaderTable.insert(Num, I, I->getParent()); return false; } @@ -2656,7 +2727,7 @@ bool GVNPass::processInstruction(Instruction *I) { // need to do a lookup to see if the number already exists // somewhere in the domtree: it can't! if (Num >= NextNum) { - addToLeaderTable(Num, I, I->getParent()); + LeaderTable.insert(Num, I, I->getParent()); return false; } @@ -2665,7 +2736,7 @@ bool GVNPass::processInstruction(Instruction *I) { Value *Repl = findLeader(I->getParent(), Num); if (!Repl) { // Failure, just remember this instance for future use. - addToLeaderTable(Num, I, I->getParent()); + LeaderTable.insert(Num, I, I->getParent()); return false; } @@ -2706,7 +2777,7 @@ bool GVNPass::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, bool Changed = false; bool ShouldContinue = true; - DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); // Merge unconditional branches, allowing PRE to catch more // optimization opportunities. for (BasicBlock &BB : llvm::make_early_inc_range(F)) { @@ -2716,6 +2787,7 @@ bool GVNPass::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, Changed |= removedBlock; } + DTU.flush(); unsigned Iteration = 0; while (ShouldContinue) { @@ -2859,7 +2931,7 @@ bool GVNPass::performScalarPREInsertion(Instruction *Instr, BasicBlock *Pred, VN.add(Instr, Num); // Update the availability map to include the new instruction. 
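// Illustrative aside: in the runImpl hunk above, the pass now constructs its
// DomTreeUpdater with the Lazy strategy, so the CFG edges removed while
// merging unconditional branches are queued rather than applied one at a
// time, and a single DTU.flush() brings the dominator tree up to date before
// the main GVN iteration queries it. The general pattern is roughly:
//
//   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
//   DTU.applyUpdates({{DominatorTree::Delete, Pred, Dead}}); // queued only
//   // ... more CFG surgery ...
//   DTU.flush(); // apply pending updates before relying on DT again
//
// (Pred and Dead are placeholder blocks for the example, not names from this
// file.)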
- addToLeaderTable(Num, Instr, Pred); + LeaderTable.insert(Num, Instr, Pred); return true; } @@ -3010,13 +3082,13 @@ bool GVNPass::performScalarPRE(Instruction *CurInst) { // After creating a new PHI for ValNo, the phi translate result for ValNo will // be changed, so erase the related stale entries in phi translate cache. VN.eraseTranslateCacheEntry(ValNo, *CurrentBlock); - addToLeaderTable(ValNo, Phi, CurrentBlock); + LeaderTable.insert(ValNo, Phi, CurrentBlock); Phi->setDebugLoc(CurInst->getDebugLoc()); CurInst->replaceAllUsesWith(Phi); if (MD && Phi->getType()->isPtrOrPtrVectorTy()) MD->invalidateCachedPointerInfo(Phi); VN.erase(CurInst); - removeFromLeaderTable(ValNo, CurInst, CurrentBlock); + LeaderTable.erase(ValNo, CurInst, CurrentBlock); LLVM_DEBUG(dbgs() << "GVN PRE removed: " << *CurInst << '\n'); removeInstruction(CurInst); @@ -3110,7 +3182,6 @@ void GVNPass::cleanupGlobalSets() { VN.clear(); LeaderTable.clear(); BlockRPONumber.clear(); - TableAllocator.Reset(); ICF->clear(); InvalidBlockRPONumbers = true; } @@ -3130,18 +3201,7 @@ void GVNPass::removeInstruction(Instruction *I) { /// internal data structures. void GVNPass::verifyRemoved(const Instruction *Inst) const { VN.verifyRemoved(Inst); - - // Walk through the value number scope to make sure the instruction isn't - // ferreted away in it. - for (const auto &I : LeaderTable) { - const LeaderTableEntry *Node = &I.second; - assert(Node->Val != Inst && "Inst still in value numbering scope!"); - - while (Node->Next) { - Node = Node->Next; - assert(Node->Val != Inst && "Inst still in value numbering scope!"); - } - } + LeaderTable.verifyRemoved(Inst); } /// BB is declared dead, which implied other blocks become dead as well. This @@ -3268,7 +3328,7 @@ void GVNPass::assignValNumForDeadCode() { for (BasicBlock *BB : DeadBlocks) { for (Instruction &Inst : *BB) { unsigned ValNum = VN.lookupOrAdd(&Inst); - addToLeaderTable(ValNum, &Inst, BB); + LeaderTable.insert(ValNum, &Inst, BB); } } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNHoist.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNHoist.cpp index b564f00eb9d1..b5333c532280 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNHoist.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNHoist.cpp @@ -238,18 +238,6 @@ public: const VNtoInsns &getStoreVNTable() const { return VNtoCallsStores; } }; -static void combineKnownMetadata(Instruction *ReplInst, Instruction *I) { - static const unsigned KnownIDs[] = {LLVMContext::MD_tbaa, - LLVMContext::MD_alias_scope, - LLVMContext::MD_noalias, - LLVMContext::MD_range, - LLVMContext::MD_fpmath, - LLVMContext::MD_invariant_load, - LLVMContext::MD_invariant_group, - LLVMContext::MD_access_group}; - combineMetadata(ReplInst, I, KnownIDs, true); -} - // This pass hoists common computations across branches sharing common // dominator. The primary goal is to reduce the code size, and in some // cases reduce critical path (by exposing more ILP). @@ -951,6 +939,14 @@ void GVNHoist::makeGepsAvailable(Instruction *Repl, BasicBlock *HoistPt, OtherGep = cast<GetElementPtrInst>( cast<StoreInst>(OtherInst)->getPointerOperand()); ClonedGep->andIRFlags(OtherGep); + + // Merge debug locations of GEPs, because the hoisted GEP replaces those + // in branches. When cloning, ClonedGep preserves the debug location of + // Gepd, so Gep is skipped to avoid merging it twice. 
+ if (OtherGep != Gep) { + ClonedGep->applyMergedLocation(ClonedGep->getDebugLoc(), + OtherGep->getDebugLoc()); + } } // Replace uses of Gep with ClonedGep in Repl. @@ -988,8 +984,8 @@ unsigned GVNHoist::rauw(const SmallVecInsn &Candidates, Instruction *Repl, MSSAUpdater->removeMemoryAccess(OldMA); } + combineMetadataForCSE(Repl, I, true); Repl->andIRFlags(I); - combineKnownMetadata(Repl, I); I->replaceAllUsesWith(Repl); // Also invalidate the Alias Analysis cache. MD->removeInstruction(I); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNSink.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNSink.cpp index 2b38831139a5..3dfa2dd9df27 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNSink.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GVNSink.cpp @@ -132,7 +132,7 @@ public: ActiveBlocks.remove(BB); continue; } - Insts.push_back(BB->getTerminator()->getPrevNode()); + Insts.push_back(BB->getTerminator()->getPrevNonDebugInstruction()); } if (Insts.empty()) Fail = true; @@ -168,7 +168,7 @@ public: if (Inst == &Inst->getParent()->front()) ActiveBlocks.remove(Inst->getParent()); else - NewInsts.push_back(Inst->getPrevNode()); + NewInsts.push_back(Inst->getPrevNonDebugInstruction()); } if (NewInsts.empty()) { Fail = true; @@ -226,12 +226,22 @@ class ModelledPHI { public: ModelledPHI() = default; - ModelledPHI(const PHINode *PN) { - // BasicBlock comes first so we sort by basic block pointer order, then by value pointer order. - SmallVector<std::pair<BasicBlock *, Value *>, 4> Ops; + ModelledPHI(const PHINode *PN, + const DenseMap<const BasicBlock *, unsigned> &BlockOrder) { + // BasicBlock comes first so we sort by basic block pointer order, + // then by value pointer order. No need to call `verifyModelledPHI` + // As the Values and Blocks are populated in a deterministic order. + using OpsType = std::pair<BasicBlock *, Value *>; + SmallVector<OpsType, 4> Ops; for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) Ops.push_back({PN->getIncomingBlock(I), PN->getIncomingValue(I)}); - llvm::sort(Ops); + + auto ComesBefore = [BlockOrder](OpsType O1, OpsType O2) { + return BlockOrder.lookup(O1.first) < BlockOrder.lookup(O2.first); + }; + // Sort in a deterministic order. + llvm::sort(Ops, ComesBefore); + for (auto &P : Ops) { Blocks.push_back(P.first); Values.push_back(P.second); @@ -247,16 +257,38 @@ public: return M; } + void + verifyModelledPHI(const DenseMap<const BasicBlock *, unsigned> &BlockOrder) { + assert(Values.size() > 1 && Blocks.size() > 1 && + "Modelling PHI with less than 2 values"); + auto ComesBefore = [BlockOrder](const BasicBlock *BB1, + const BasicBlock *BB2) { + return BlockOrder.lookup(BB1) < BlockOrder.lookup(BB2); + }; + assert(llvm::is_sorted(Blocks, ComesBefore)); + int C = 0; + for (const Value *V : Values) { + if (!isa<UndefValue>(V)) { + assert(cast<Instruction>(V)->getParent() == Blocks[C]); + (void)C; + } + C++; + } + } /// Create a PHI from an array of incoming values and incoming blocks. - template <typename VArray, typename BArray> - ModelledPHI(const VArray &V, const BArray &B) { + ModelledPHI(SmallVectorImpl<Instruction *> &V, + SmallSetVector<BasicBlock *, 4> &B, + const DenseMap<const BasicBlock *, unsigned> &BlockOrder) { + // The order of Values and Blocks are already ordered by the caller. llvm::copy(V, std::back_inserter(Values)); llvm::copy(B, std::back_inserter(Blocks)); + verifyModelledPHI(BlockOrder); } /// Create a PHI from [I[OpNum] for I in Insts]. 
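// Illustrative aside: the ModelledPHI changes above, together with the
// RPOTOrder map introduced further down, replace sorting by raw BasicBlock
// pointer values with sorting by a reverse-post-order block numbering, so the
// modelled PHIs (and therefore GVNSink's output) no longer depend on
// allocation addresses. Conceptually:
//
//   DenseMap<const BasicBlock *, unsigned> RPOTOrder; // BB -> RPO index
//   llvm::sort(Ops, [&](const OpsType &A, const OpsType &B) {
//     return RPOTOrder.lookup(A.first) < RPOTOrder.lookup(B.first);
//   });
//
// Any fixed, deterministic numbering would do; RPO is simply convenient here
// because the pass already computes an RPO traversal of the function.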
- template <typename BArray> - ModelledPHI(ArrayRef<Instruction *> Insts, unsigned OpNum, const BArray &B) { + /// TODO: Figure out a way to verifyModelledPHI in this constructor. + ModelledPHI(ArrayRef<Instruction *> Insts, unsigned OpNum, + SmallSetVector<BasicBlock *, 4> &B) { llvm::copy(B, std::back_inserter(Blocks)); for (auto *I : Insts) Values.push_back(I->getOperand(OpNum)); @@ -297,7 +329,8 @@ public: // Hash functor unsigned hash() const { - return (unsigned)hash_combine_range(Values.begin(), Values.end()); + // Is deterministic because Values are saved in a specific order. + return (unsigned)hash_combine_range(Values.begin(), Values.end()); } bool operator==(const ModelledPHI &Other) const { @@ -566,7 +599,7 @@ public: class GVNSink { public: - GVNSink() = default; + GVNSink() {} bool run(Function &F) { LLVM_DEBUG(dbgs() << "GVNSink: running on function @" << F.getName() @@ -575,6 +608,16 @@ public: unsigned NumSunk = 0; ReversePostOrderTraversal<Function*> RPOT(&F); VN.setReachableBBs(BasicBlocksSet(RPOT.begin(), RPOT.end())); + // Populate reverse post-order to order basic blocks in deterministic + // order. Any arbitrary ordering will work in this case as long as they are + // deterministic. The node ordering of newly created basic blocks + // are irrelevant because RPOT(for computing sinkable candidates) is also + // obtained ahead of time and only their order are relevant for this pass. + unsigned NodeOrdering = 0; + RPOTOrder[*RPOT.begin()] = ++NodeOrdering; + for (auto *BB : RPOT) + if (!pred_empty(BB)) + RPOTOrder[BB] = ++NodeOrdering; for (auto *N : RPOT) NumSunk += sinkBB(N); @@ -583,6 +626,7 @@ public: private: ValueTable VN; + DenseMap<const BasicBlock *, unsigned> RPOTOrder; bool shouldAvoidSinkingInstruction(Instruction *I) { // These instructions may change or break semantics if moved. @@ -603,7 +647,7 @@ private: void analyzeInitialPHIs(BasicBlock *BB, ModelledPHISet &PHIs, SmallPtrSetImpl<Value *> &PHIContents) { for (PHINode &PN : BB->phis()) { - auto MPHI = ModelledPHI(&PN); + auto MPHI = ModelledPHI(&PN, RPOTOrder); PHIs.insert(MPHI); for (auto *V : MPHI.getValues()) PHIContents.insert(V); @@ -655,8 +699,7 @@ GVNSink::analyzeInstructionForSinking(LockstepReverseIterator &LRI, return std::nullopt; VNums[N]++; } - unsigned VNumToSink = - std::max_element(VNums.begin(), VNums.end(), llvm::less_second())->first; + unsigned VNumToSink = llvm::max_element(VNums, llvm::less_second())->first; if (VNums[VNumToSink] == 1) // Can't sink anything! @@ -692,7 +735,7 @@ GVNSink::analyzeInstructionForSinking(LockstepReverseIterator &LRI, } // The sunk instruction's results. - ModelledPHI NewPHI(NewInsts, ActivePreds); + ModelledPHI NewPHI(NewInsts, ActivePreds, RPOTOrder); // Does sinking this instruction render previous PHIs redundant? if (NeededPHIs.erase(NewPHI)) @@ -720,12 +763,11 @@ GVNSink::analyzeInstructionForSinking(LockstepReverseIterator &LRI, // try and continue making progress. Instruction *I0 = NewInsts[0]; - // If all instructions that are going to participate don't have the same - // number of operands, we can't do any useful PHI analysis for all operands. 
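// Note on the hunk below: the guard for operand-wise PHI analysis is
// tightened. Instead of only requiring the candidate instructions to have the
// same number of operands, they must now satisfy
// Instruction::isSameOperationAs, i.e. match in opcode, operand count and
// types (and the other per-instruction properties that affect the operation),
// which is the condition the subsequent per-operand modelling actually
// relies on.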
- auto hasDifferentNumOperands = [&I0](Instruction *I) { - return I->getNumOperands() != I0->getNumOperands(); + auto isNotSameOperation = [&I0](Instruction *I) { + return !I0->isSameOperationAs(I); }; - if (any_of(NewInsts, hasDifferentNumOperands)) + + if (any_of(NewInsts, isNotSameOperation)) return std::nullopt; for (unsigned OpNum = 0, E = I0->getNumOperands(); OpNum != E; ++OpNum) { @@ -767,6 +809,9 @@ unsigned GVNSink::sinkBB(BasicBlock *BBEnd) { BBEnd->printAsOperand(dbgs()); dbgs() << "\n"); SmallVector<BasicBlock *, 4> Preds; for (auto *B : predecessors(BBEnd)) { + // Bailout on basic blocks without predecessor(PR42346). + if (!RPOTOrder.count(B)) + return 0; auto *T = B->getTerminator(); if (isa<BranchInst>(T) || isa<SwitchInst>(T)) Preds.push_back(B); @@ -775,7 +820,11 @@ unsigned GVNSink::sinkBB(BasicBlock *BBEnd) { } if (Preds.size() < 2) return 0; - llvm::sort(Preds); + auto ComesBefore = [this](const BasicBlock *BB1, const BasicBlock *BB2) { + return RPOTOrder.lookup(BB1) < RPOTOrder.lookup(BB2); + }; + // Sort in a deterministic order. + llvm::sort(Preds, ComesBefore); unsigned NumOrigPreds = Preds.size(); // We can only sink instructions through unconditional branches. @@ -834,7 +883,7 @@ void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, BasicBlock *BBEnd) { SmallVector<Instruction *, 4> Insts; for (BasicBlock *BB : Blocks) - Insts.push_back(BB->getTerminator()->getPrevNode()); + Insts.push_back(BB->getTerminator()->getPrevNonDebugInstruction()); Instruction *I0 = Insts.front(); SmallVector<Value *, 4> NewOperands; @@ -872,8 +921,10 @@ void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, } for (auto *I : Insts) - if (I != I0) + if (I != I0) { I->replaceAllUsesWith(I0); + I0->applyMergedLocation(I0->getDebugLoc(), I->getDebugLoc()); + } foldPointlessPHINodes(BBEnd); // Finally nuke all instructions apart from the common instruction. @@ -890,5 +941,6 @@ PreservedAnalyses GVNSinkPass::run(Function &F, FunctionAnalysisManager &AM) { GVNSink G; if (!G.run(F)) return PreservedAnalyses::all(); + return PreservedAnalyses::none(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GuardWidening.cpp index 3bbf6642a90c..e7ff2a14469c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/GuardWidening.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/GuardWidening.cpp @@ -52,6 +52,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -121,12 +122,13 @@ static void eliminateGuard(Instruction *GuardInst, MemorySSAUpdater *MSSAU) { /// condition should stay invariant. Otherwise there can be a miscompile, like /// the one described at https://github.com/llvm/llvm-project/issues/60234. The /// safest way to do it is to expand the new condition at WC's block. 
-static Instruction *findInsertionPointForWideCondition(Instruction *WCOrGuard) { +static std::optional<BasicBlock::iterator> +findInsertionPointForWideCondition(Instruction *WCOrGuard) { if (isGuard(WCOrGuard)) - return WCOrGuard; + return WCOrGuard->getIterator(); if (auto WC = extractWidenableCondition(WCOrGuard)) - return cast<Instruction>(WC); - return nullptr; + return cast<Instruction>(WC)->getIterator(); + return std::nullopt; } class GuardWideningImpl { @@ -182,30 +184,30 @@ class GuardWideningImpl { /// into \p WideningPoint. WideningScore computeWideningScore(Instruction *DominatedInstr, Instruction *ToWiden, - Instruction *WideningPoint, + BasicBlock::iterator WideningPoint, SmallVectorImpl<Value *> &ChecksToHoist, SmallVectorImpl<Value *> &ChecksToWiden); /// Helper to check if \p V can be hoisted to \p InsertPos. - bool canBeHoistedTo(const Value *V, const Instruction *InsertPos) const { + bool canBeHoistedTo(const Value *V, BasicBlock::iterator InsertPos) const { SmallPtrSet<const Instruction *, 8> Visited; return canBeHoistedTo(V, InsertPos, Visited); } - bool canBeHoistedTo(const Value *V, const Instruction *InsertPos, + bool canBeHoistedTo(const Value *V, BasicBlock::iterator InsertPos, SmallPtrSetImpl<const Instruction *> &Visited) const; bool canBeHoistedTo(const SmallVectorImpl<Value *> &Checks, - const Instruction *InsertPos) const { + BasicBlock::iterator InsertPos) const { return all_of(Checks, [&](const Value *V) { return canBeHoistedTo(V, InsertPos); }); } /// Helper to hoist \p V to \p InsertPos. Guaranteed to succeed if \c /// canBeHoistedTo returned true. - void makeAvailableAt(Value *V, Instruction *InsertPos) const; + void makeAvailableAt(Value *V, BasicBlock::iterator InsertPos) const; void makeAvailableAt(const SmallVectorImpl<Value *> &Checks, - Instruction *InsertPos) const { + BasicBlock::iterator InsertPos) const { for (Value *V : Checks) makeAvailableAt(V, InsertPos); } @@ -217,18 +219,19 @@ class GuardWideningImpl { /// InsertPt is true then actually generate the resulting expression, make it /// available at \p InsertPt and return it in \p Result (else no change to the /// IR is made). - std::optional<Value *> mergeChecks(SmallVectorImpl<Value *> &ChecksToHoist, - SmallVectorImpl<Value *> &ChecksToWiden, - Instruction *InsertPt); + std::optional<Value *> + mergeChecks(SmallVectorImpl<Value *> &ChecksToHoist, + SmallVectorImpl<Value *> &ChecksToWiden, + std::optional<BasicBlock::iterator> InsertPt); /// Generate the logical AND of \p ChecksToHoist and \p OldCondition and make /// it available at InsertPt Value *hoistChecks(SmallVectorImpl<Value *> &ChecksToHoist, - Value *OldCondition, Instruction *InsertPt); + Value *OldCondition, BasicBlock::iterator InsertPt); /// Adds freeze to Orig and push it as far as possible very aggressively. /// Also replaces all uses of frozen instruction with frozen version. - Value *freezeAndPush(Value *Orig, Instruction *InsertPt); + Value *freezeAndPush(Value *Orig, BasicBlock::iterator InsertPt); /// Represents a range check of the form \c Base + \c Offset u< \c Length, /// with the constraint that \c Length is not negative. \c CheckInst is the @@ -294,7 +297,7 @@ class GuardWideningImpl { /// for the price of computing only one of the set of expressions? 
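// Note on the GuardWidening changes in this file: the widening/insertion
// point changes type from Instruction * (with nullptr meaning "no point") to
// BasicBlock::iterator wrapped in std::optional (with std::nullopt as the
// "no point" sentinel), and insertions go through the iterator overloads,
// roughly:
//
//   FI->insertBefore(*InsertPt->getParent(), InsertPt);
//   Inst->moveBefore(*Loc->getParent(), Loc);
//
// This appears to follow the wider LLVM move toward iterator-based insertion
// positions, which can distinguish the very start of a block (before or after
// any attached debug records) in a way a bare Instruction pointer cannot, so
// the pass is migrated wholesale rather than converting at each call site.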
bool isWideningCondProfitable(SmallVectorImpl<Value *> &ChecksToHoist, SmallVectorImpl<Value *> &ChecksToWiden) { - return mergeChecks(ChecksToHoist, ChecksToWiden, /*InsertPt=*/nullptr) + return mergeChecks(ChecksToHoist, ChecksToWiden, /*InsertPt=*/std::nullopt) .has_value(); } @@ -302,11 +305,11 @@ class GuardWideningImpl { void widenGuard(SmallVectorImpl<Value *> &ChecksToHoist, SmallVectorImpl<Value *> &ChecksToWiden, Instruction *ToWiden) { - Instruction *InsertPt = findInsertionPointForWideCondition(ToWiden); + auto InsertPt = findInsertionPointForWideCondition(ToWiden); auto MergedCheck = mergeChecks(ChecksToHoist, ChecksToWiden, InsertPt); Value *Result = MergedCheck ? *MergedCheck : hoistChecks(ChecksToHoist, - getCondition(ToWiden), InsertPt); + getCondition(ToWiden), *InsertPt); if (isGuardAsWidenableBranch(ToWiden)) { setWidenableBranchCond(cast<BranchInst>(ToWiden), Result); @@ -417,12 +420,12 @@ bool GuardWideningImpl::eliminateInstrViaWidening( assert((i == (e - 1)) == (Instr->getParent() == CurBB) && "Bad DFS?"); for (auto *Candidate : make_range(I, E)) { - auto *WideningPoint = findInsertionPointForWideCondition(Candidate); + auto WideningPoint = findInsertionPointForWideCondition(Candidate); if (!WideningPoint) continue; SmallVector<Value *> CandidateChecks; parseWidenableGuard(Candidate, CandidateChecks); - auto Score = computeWideningScore(Instr, Candidate, WideningPoint, + auto Score = computeWideningScore(Instr, Candidate, *WideningPoint, ChecksToHoist, CandidateChecks); LLVM_DEBUG(dbgs() << "Score between " << *Instr << " and " << *Candidate << " is " << scoreTypeToString(Score) << "\n"); @@ -456,7 +459,7 @@ bool GuardWideningImpl::eliminateInstrViaWidening( GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( Instruction *DominatedInstr, Instruction *ToWiden, - Instruction *WideningPoint, SmallVectorImpl<Value *> &ChecksToHoist, + BasicBlock::iterator WideningPoint, SmallVectorImpl<Value *> &ChecksToHoist, SmallVectorImpl<Value *> &ChecksToWiden) { Loop *DominatedInstrLoop = LI.getLoopFor(DominatedInstr->getParent()); Loop *DominatingGuardLoop = LI.getLoopFor(WideningPoint->getParent()); @@ -559,7 +562,7 @@ GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( } bool GuardWideningImpl::canBeHoistedTo( - const Value *V, const Instruction *Loc, + const Value *V, BasicBlock::iterator Loc, SmallPtrSetImpl<const Instruction *> &Visited) const { auto *Inst = dyn_cast<Instruction>(V); if (!Inst || DT.dominates(Inst, Loc) || Visited.count(Inst)) @@ -580,7 +583,8 @@ bool GuardWideningImpl::canBeHoistedTo( [&](Value *Op) { return canBeHoistedTo(Op, Loc, Visited); }); } -void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) const { +void GuardWideningImpl::makeAvailableAt(Value *V, + BasicBlock::iterator Loc) const { auto *Inst = dyn_cast<Instruction>(V); if (!Inst || DT.dominates(Inst, Loc)) return; @@ -592,7 +596,7 @@ void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) const { for (Value *Op : Inst->operands()) makeAvailableAt(Op, Loc); - Inst->moveBefore(Loc); + Inst->moveBefore(*Loc->getParent(), Loc); } // Return Instruction before which we can insert freeze for the value V as close @@ -621,14 +625,15 @@ getFreezeInsertPt(Value *V, const DominatorTree &DT) { return Res; } -Value *GuardWideningImpl::freezeAndPush(Value *Orig, Instruction *InsertPt) { +Value *GuardWideningImpl::freezeAndPush(Value *Orig, + BasicBlock::iterator InsertPt) { if (isGuaranteedNotToBePoison(Orig, nullptr, InsertPt, &DT)) 
return Orig; std::optional<BasicBlock::iterator> InsertPtAtDef = getFreezeInsertPt(Orig, DT); if (!InsertPtAtDef) { FreezeInst *FI = new FreezeInst(Orig, "gw.freeze"); - FI->insertBefore(InsertPt); + FI->insertBefore(*InsertPt->getParent(), InsertPt); return FI; } if (isa<Constant>(Orig) || isa<GlobalValue>(Orig)) { @@ -695,7 +700,7 @@ Value *GuardWideningImpl::freezeAndPush(Value *Orig, Instruction *InsertPt) { Worklist.push_back(U.get()); } for (Instruction *I : DropPoisonFlags) - I->dropPoisonGeneratingFlagsAndMetadata(); + I->dropPoisonGeneratingAnnotations(); Value *Result = Orig; for (Value *V : NeedFreeze) { @@ -715,7 +720,7 @@ Value *GuardWideningImpl::freezeAndPush(Value *Orig, Instruction *InsertPt) { std::optional<Value *> GuardWideningImpl::mergeChecks(SmallVectorImpl<Value *> &ChecksToHoist, SmallVectorImpl<Value *> &ChecksToWiden, - Instruction *InsertPt) { + std::optional<BasicBlock::iterator> InsertPt) { using namespace llvm::PatternMatch; Value *Result = nullptr; @@ -747,10 +752,10 @@ GuardWideningImpl::mergeChecks(SmallVectorImpl<Value *> &ChecksToHoist, if (Intersect->getEquivalentICmp(Pred, NewRHSAP)) { if (InsertPt) { ConstantInt *NewRHS = - ConstantInt::get(InsertPt->getContext(), NewRHSAP); - assert(canBeHoistedTo(LHS, InsertPt) && "must be"); - makeAvailableAt(LHS, InsertPt); - Result = new ICmpInst(InsertPt, Pred, LHS, NewRHS, "wide.chk"); + ConstantInt::get((*InsertPt)->getContext(), NewRHSAP); + assert(canBeHoistedTo(LHS, *InsertPt) && "must be"); + makeAvailableAt(LHS, *InsertPt); + Result = new ICmpInst(*InsertPt, Pred, LHS, NewRHS, "wide.chk"); } return Result; } @@ -765,16 +770,16 @@ GuardWideningImpl::mergeChecks(SmallVectorImpl<Value *> &ChecksToHoist, combineRangeChecks(Checks, CombinedChecks)) { if (InsertPt) { for (auto &RC : CombinedChecks) { - makeAvailableAt(RC.getCheckInst(), InsertPt); + makeAvailableAt(RC.getCheckInst(), *InsertPt); if (Result) Result = BinaryOperator::CreateAnd(RC.getCheckInst(), Result, "", - InsertPt); + *InsertPt); else Result = RC.getCheckInst(); } assert(Result && "Failed to find result value"); Result->setName("wide.chk"); - Result = freezeAndPush(Result, InsertPt); + Result = freezeAndPush(Result, *InsertPt); } return Result; } @@ -786,9 +791,9 @@ GuardWideningImpl::mergeChecks(SmallVectorImpl<Value *> &ChecksToHoist, Value *GuardWideningImpl::hoistChecks(SmallVectorImpl<Value *> &ChecksToHoist, Value *OldCondition, - Instruction *InsertPt) { + BasicBlock::iterator InsertPt) { assert(!ChecksToHoist.empty()); - IRBuilder<> Builder(InsertPt); + IRBuilder<> Builder(InsertPt->getParent(), InsertPt); makeAvailableAt(ChecksToHoist, InsertPt); makeAvailableAt(OldCondition, InsertPt); Value *Result = Builder.CreateAnd(ChecksToHoist); @@ -812,7 +817,7 @@ bool GuardWideningImpl::parseRangeChecks( if (IC->getPredicate() == ICmpInst::ICMP_UGT) std::swap(CmpLHS, CmpRHS); - auto &DL = IC->getModule()->getDataLayout(); + auto &DL = IC->getDataLayout(); GuardWideningImpl::RangeCheck Check( CmpLHS, cast<ConstantInt>(ConstantInt::getNullValue(CmpRHS->getType())), diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index 41c4d6236173..5e2131b0b180 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -70,6 +70,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include 
"llvm/Transforms/Scalar/SimpleLoopUnswitch.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -137,6 +138,8 @@ class IndVarSimplify { SmallVector<WeakTrackingVH, 16> DeadInsts; bool WidenIndVars; + bool RunUnswitching = false; + bool handleFloatingPointIV(Loop *L, PHINode *PH); bool rewriteNonIntegerIVs(Loop *L); @@ -170,6 +173,8 @@ public: } bool run(Loop *L); + + bool runUnswitching() const { return RunUnswitching; } }; } // end anonymous namespace @@ -350,18 +355,22 @@ bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { IntegerType *Int32Ty = Type::getInt32Ty(PN->getContext()); // Insert new integer induction variable. - PHINode *NewPHI = PHINode::Create(Int32Ty, 2, PN->getName()+".int", PN); + PHINode *NewPHI = + PHINode::Create(Int32Ty, 2, PN->getName() + ".int", PN->getIterator()); NewPHI->addIncoming(ConstantInt::get(Int32Ty, InitValue), PN->getIncomingBlock(IncomingEdge)); + NewPHI->setDebugLoc(PN->getDebugLoc()); - Value *NewAdd = - BinaryOperator::CreateAdd(NewPHI, ConstantInt::get(Int32Ty, IncValue), - Incr->getName()+".int", Incr); + Instruction *NewAdd = + BinaryOperator::CreateAdd(NewPHI, ConstantInt::get(Int32Ty, IncValue), + Incr->getName() + ".int", Incr->getIterator()); + NewAdd->setDebugLoc(Incr->getDebugLoc()); NewPHI->addIncoming(NewAdd, PN->getIncomingBlock(BackEdge)); - ICmpInst *NewCompare = new ICmpInst(TheBr, NewPred, NewAdd, - ConstantInt::get(Int32Ty, ExitValue), - Compare->getName()); + ICmpInst *NewCompare = + new ICmpInst(TheBr->getIterator(), NewPred, NewAdd, + ConstantInt::get(Int32Ty, ExitValue), Compare->getName()); + NewCompare->setDebugLoc(Compare->getDebugLoc()); // In the following deletions, PN may become dead and may be deleted. // Use a WeakTrackingVH to observe whether this happens. @@ -385,8 +394,9 @@ bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { // We give preference to sitofp over uitofp because it is faster on most // platforms. if (WeakPH) { - Value *Conv = new SIToFPInst(NewPHI, PN->getType(), "indvar.conv", - &*PN->getParent()->getFirstInsertionPt()); + Instruction *Conv = new SIToFPInst(NewPHI, PN->getType(), "indvar.conv", + PN->getParent()->getFirstInsertionPt()); + Conv->setDebugLoc(PN->getDebugLoc()); PN->replaceAllUsesWith(Conv); RecursivelyDeleteTriviallyDeadInstructions(PN, TLI, MSSAU.get()); } @@ -508,7 +518,7 @@ static void visitIVCast(CastInst *Cast, WideIVInfo &WI, Type *Ty = Cast->getType(); uint64_t Width = SE->getTypeSizeInBits(Ty); - if (!Cast->getModule()->getDataLayout().isLegalInteger(Width)) + if (!Cast->getDataLayout().isLegalInteger(Width)) return; // Check that `Cast` actually extends the induction variable (we rely on this @@ -614,9 +624,11 @@ bool IndVarSimplify::simplifyAndExtend(Loop *L, // Information about sign/zero extensions of CurrIV. 
IndVarSimplifyVisitor Visitor(CurrIV, SE, TTI, DT); - Changed |= simplifyUsersOfIV(CurrIV, SE, DT, LI, TTI, DeadInsts, Rewriter, - &Visitor); + const auto &[C, U] = simplifyUsersOfIV(CurrIV, SE, DT, LI, TTI, DeadInsts, + Rewriter, &Visitor); + Changed |= C; + RunUnswitching |= U; if (Visitor.WI.WidestNativeType) { WideIVs.push_back(Visitor.WI); } @@ -833,7 +845,7 @@ static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB, const SCEV *BestInit = nullptr; BasicBlock *LatchBlock = L->getLoopLatch(); assert(LatchBlock && "Must be in simplified form"); - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + const DataLayout &DL = L->getHeader()->getDataLayout(); for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) { PHINode *Phi = cast<PHINode>(I); @@ -1220,7 +1232,7 @@ static void replaceLoopPHINodesWithPreheaderValues( if (!L->contains(I)) continue; - Value *Res = simplifyInstruction(I, I->getModule()->getDataLayout()); + Value *Res = simplifyInstruction(I, I->getDataLayout()); if (Res && LI->replacementPreservesLCSSAForm(I, Res)) { for (User *U : I->users()) Worklist.push_back(cast<Instruction>(U)); @@ -1451,7 +1463,7 @@ bool IndVarSimplify::canonicalizeExitCondition(Loop *L) { if (!match(LHS, m_ZExt(m_Value(LHSOp))) || !ICmp->isSigned()) continue; - const DataLayout &DL = ExitingBB->getModule()->getDataLayout(); + const DataLayout &DL = ExitingBB->getDataLayout(); const unsigned InnerBitWidth = DL.getTypeSizeInBits(LHSOp->getType()); const unsigned OuterBitWidth = DL.getTypeSizeInBits(RHS->getType()); auto FullCR = ConstantRange::getFull(InnerBitWidth); @@ -1516,9 +1528,9 @@ bool IndVarSimplify::canonicalizeExitCondition(Loop *L) { // loop varying work to loop-invariant work. auto doRotateTransform = [&]() { assert(ICmp->isUnsigned() && "must have proven unsigned already"); - auto *NewRHS = - CastInst::Create(Instruction::Trunc, RHS, LHSOp->getType(), "", - L->getLoopPreheader()->getTerminator()); + auto *NewRHS = CastInst::Create( + Instruction::Trunc, RHS, LHSOp->getType(), "", + L->getLoopPreheader()->getTerminator()->getIterator()); ICmp->setOperand(Swapped ? 1 : 0, LHSOp); ICmp->setOperand(Swapped ? 
0 : 1, NewRHS); if (LHS->use_empty()) @@ -1526,7 +1538,7 @@ bool IndVarSimplify::canonicalizeExitCondition(Loop *L) { }; - const DataLayout &DL = ExitingBB->getModule()->getDataLayout(); + const DataLayout &DL = ExitingBB->getDataLayout(); const unsigned InnerBitWidth = DL.getTypeSizeInBits(LHSOp->getType()); const unsigned OuterBitWidth = DL.getTypeSizeInBits(RHS->getType()); auto FullCR = ConstantRange::getFull(InnerBitWidth); @@ -1873,6 +1885,7 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { if (OldCond->use_empty()) DeadInsts.emplace_back(OldCond); Changed = true; + RunUnswitching = true; } return Changed; @@ -2049,7 +2062,7 @@ PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { Function *F = L.getHeader()->getParent(); - const DataLayout &DL = F->getParent()->getDataLayout(); + const DataLayout &DL = F->getDataLayout(); IndVarSimplify IVS(&AR.LI, &AR.SE, &AR.DT, DL, &AR.TLI, &AR.TTI, AR.MSSA, WidenIndVars && AllowIVWidening); @@ -2058,6 +2071,11 @@ PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, auto PA = getLoopPassPreservedAnalyses(); PA.preserveSet<CFGAnalyses>(); + if (IVS.runUnswitching()) { + AM.getResult<ShouldRunExtraSimpleLoopUnswitch>(L, AR); + PA.preserve<ShouldRunExtraSimpleLoopUnswitch>(); + } + if (AR.MSSA) PA.preserve<MemorySSAAnalysis>(); return PA; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 9df28747570c..104e8ceb7967 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -279,6 +279,9 @@ bool InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, Value *LHS = ICI->getOperand(0); Value *RHS = ICI->getOperand(1); + if (!LHS->getType()->isIntegerTy()) + return false; + // Canonicalize to the `Index Pred Invariant` comparison if (IsLoopInvariant(LHS)) { std::swap(LHS, RHS); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp index 1bf50d79e533..6b9566f1ae46 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -642,6 +642,7 @@ Value *InferAddressSpacesImpl::cloneInstructionWithNewAddressSpace( Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(I->getType(), AS); auto *NewI = new AddrSpaceCastInst(I, NewPtrTy); NewI->insertAfter(I); + NewI->setDebugLoc(I->getDebugLoc()); return NewI; } @@ -821,7 +822,7 @@ unsigned InferAddressSpacesImpl::joinAddressSpaces(unsigned AS1, } bool InferAddressSpacesImpl::run(Function &F) { - DL = &F.getParent()->getDataLayout(); + DL = &F.getDataLayout(); if (AssumeDefaultIsFlatAddressSpace) FlatAddrSpace = 0; @@ -1221,6 +1222,7 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( Value::use_iterator I, E, Next; for (I = V->use_begin(), E = V->use_end(); I != E;) { Use &U = *I; + User *CurUser = U.getUser(); // Some users may see the same pointer operand in multiple operands. Skip // to the next instruction. @@ -1235,7 +1237,6 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( continue; } - User *CurUser = U.getUser(); // Skip if the current user is the new value itself. 
if (CurUser == NewV) continue; @@ -1311,10 +1312,13 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces( while (isa<PHINode>(InsertPos)) ++InsertPos; - U.set(new AddrSpaceCastInst(NewV, V->getType(), "", &*InsertPos)); + // This instruction may contain multiple uses of V, update them all. + CurUser->replaceUsesOfWith( + V, new AddrSpaceCastInst(NewV, V->getType(), "", InsertPos)); } else { - U.set(ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), - V->getType())); + CurUser->replaceUsesOfWith( + V, ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), + V->getType())); } } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAlignment.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAlignment.cpp index b75b8d486fbb..6e0c206bd198 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAlignment.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InferAlignment.cpp @@ -48,7 +48,7 @@ static bool tryToImproveAlign( } bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT) { - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); bool Changed = false; // Enforce preferred type alignment if possible. We do this as a separate diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp index ee9452ce1c7d..326849a4eb39 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp @@ -99,7 +99,7 @@ struct InstSimplifyLegacyPass : public FunctionPass { &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); AssumptionCache *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); const SimplifyQuery SQ(DL, TLI, DT, AC); return runImpl(F, SQ); } @@ -125,7 +125,7 @@ PreservedAnalyses InstSimplifyPass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); const SimplifyQuery SQ(DL, &TLI, &DT, &AC); bool Changed = runImpl(F, SQ); if (!Changed) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp new file mode 100644 index 000000000000..2a4f68e12525 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp @@ -0,0 +1,190 @@ +//===- JumpTableToSwitch.cpp ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/JumpTableToSwitch.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/DomTreeUpdater.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; + +static cl::opt<unsigned> + JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden, + cl::desc("Only split jump tables with size less or " + "equal than JumpTableSizeThreshold."), + cl::init(10)); + +// TODO: Consider adding a cost model for profitability analysis of this +// transformation. Currently we replace a jump table with a switch if all the +// functions in the jump table are smaller than the provided threshold. +static cl::opt<unsigned> FunctionSizeThreshold( + "jump-table-to-switch-function-size-threshold", cl::Hidden, + cl::desc("Only split jump tables containing functions whose sizes are less " + "or equal than this threshold."), + cl::init(50)); + +#define DEBUG_TYPE "jump-table-to-switch" + +namespace { +struct JumpTableTy { + Value *Index; + SmallVector<Function *, 10> Funcs; +}; +} // anonymous namespace + +static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP, + PointerType *PtrTy) { + Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand()); + if (!Ptr) + return std::nullopt; + + GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr); + if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) + return std::nullopt; + + Function &F = *GEP->getParent()->getParent(); + const DataLayout &DL = F.getDataLayout(); + const unsigned BitWidth = + DL.getIndexSizeInBits(GEP->getPointerAddressSpace()); + MapVector<Value *, APInt> VariableOffsets; + APInt ConstantOffset(BitWidth, 0); + if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset)) + return std::nullopt; + if (VariableOffsets.size() != 1) + return std::nullopt; + // TODO: consider supporting more general patterns + if (!ConstantOffset.isZero()) + return std::nullopt; + APInt StrideBytes = VariableOffsets.front().second; + const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(GV->getValueType()); + if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0) + return std::nullopt; + const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue(); + if (N > JumpTableSizeThreshold) + return std::nullopt; + + JumpTableTy JumpTable; + JumpTable.Index = VariableOffsets.front().first; + JumpTable.Funcs.reserve(N); + for (uint64_t Index = 0; Index < N; ++Index) { + // ConstantOffset is zero. 
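// Illustrative aside: the new JumpTableToSwitch pass recognizes an indirect
// call whose callee is loaded from a small constant global table of known
// functions,
//
//   fn = table[idx];   // table is a constant array of function pointers
//   fn(args...);       // indirect call
//
// and expandToSwitch() rewrites it into a switch over direct calls:
//
//   switch (idx) {
//   case 0: f0(args...); break;
//   case 1: f1(args...); break;
//   default: /* unreachable */;
//   }
//
// Each callee then becomes a direct call that later passes (inlining and
// other interprocedural optimizations) can reason about. The table size and
// per-callee size limits are controlled by the two cl::opt thresholds
// declared above.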
+ APInt Offset = Index * StrideBytes; + Constant *C = + ConstantFoldLoadFromConst(GV->getInitializer(), PtrTy, Offset, DL); + auto *Func = dyn_cast_or_null<Function>(C); + if (!Func || Func->isDeclaration() || + Func->getInstructionCount() > FunctionSizeThreshold) + return std::nullopt; + JumpTable.Funcs.push_back(Func); + } + return JumpTable; +} + +static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT, + DomTreeUpdater &DTU, + OptimizationRemarkEmitter &ORE) { + const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext()); + + SmallVector<DominatorTree::UpdateType, 8> DTUpdates; + BasicBlock *BB = CB->getParent(); + BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr, + BB->getName() + Twine(".tail")); + DTUpdates.push_back({DominatorTree::Delete, BB, Tail}); + BB->getTerminator()->eraseFromParent(); + + Function &F = *BB->getParent(); + BasicBlock *BBUnreachable = BasicBlock::Create( + F.getContext(), "default.switch.case.unreachable", &F, Tail); + IRBuilder<> BuilderUnreachable(BBUnreachable); + BuilderUnreachable.CreateUnreachable(); + + IRBuilder<> Builder(BB); + SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable); + DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable}); + + IRBuilder<> BuilderTail(CB); + PHINode *PHI = + IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size()); + + for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) { + BasicBlock *B = BasicBlock::Create(Func->getContext(), + "call." + Twine(Index), &F, Tail); + DTUpdates.push_back({DominatorTree::Insert, BB, B}); + DTUpdates.push_back({DominatorTree::Insert, B, Tail}); + + CallBase *Call = cast<CallBase>(CB->clone()); + Call->setCalledFunction(Func); + Call->insertInto(B, B->end()); + Switch->addCase( + cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B); + BranchInst::Create(Tail, B); + if (PHI) + PHI->addIncoming(Call, B); + } + DTU.applyUpdates(DTUpdates); + ORE.emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB) + << "expanded indirect call into switch"; + }); + if (PHI) + CB->replaceAllUsesWith(PHI); + CB->eraseFromParent(); + return Tail; +} + +PreservedAnalyses JumpTableToSwitchPass::run(Function &F, + FunctionAnalysisManager &AM) { + OptimizationRemarkEmitter &ORE = + AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + DominatorTree *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); + PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F); + DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy); + bool Changed = false; + for (BasicBlock &BB : make_early_inc_range(F)) { + BasicBlock *CurrentBB = &BB; + while (CurrentBB) { + BasicBlock *SplittedOutTail = nullptr; + for (Instruction &I : make_early_inc_range(*CurrentBB)) { + auto *Call = dyn_cast<CallInst>(&I); + if (!Call || Call->getCalledFunction() || Call->isMustTailCall()) + continue; + auto *L = dyn_cast<LoadInst>(Call->getCalledOperand()); + // Skip atomic or volatile loads. + if (!L || !L->isSimple()) + continue; + auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand()); + if (!GEP) + continue; + auto *PtrTy = dyn_cast<PointerType>(L->getType()); + assert(PtrTy && "call operand must be a pointer"); + std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy); + if (!JumpTable) + continue; + SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE); + Changed = true; + break; + } + CurrentBB = SplittedOutTail ? 
SplittedOutTail : nullptr; + } + } + + if (!Changed) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + if (DT) + PA.preserve<DominatorTreeAnalysis>(); + if (PDT) + PA.preserve<PostDominatorTreeAnalysis>(); + return PA; +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpThreading.cpp index 87c01ead634f..7a0b661a0779 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -231,7 +231,7 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { Weights[0] = BP.getCompl().getNumerator(); Weights[1] = BP.getNumerator(); } - setBranchWeights(*PredBr, Weights); + setBranchWeights(*PredBr, Weights, hasBranchWeightOrigin(*PredBr)); } } @@ -401,8 +401,8 @@ static bool replaceFoldableUses(Instruction *Cond, Value *ToVal, Changed |= replaceNonLocalUsesWith(Cond, ToVal); for (Instruction &I : reverse(*KnownAtEndOfBB)) { // Replace any debug-info record users of Cond with ToVal. - for (DPValue &DPV : I.getDbgValueRange()) - DPV.replaceVariableLocationOp(Cond, ToVal, true); + for (DbgVariableRecord &DVR : filterDbgVars(I.getDbgRecordRange())) + DVR.replaceVariableLocationOp(Cond, ToVal, true); // Reached the Cond whose uses we are trying to replace, so there are no // more uses. @@ -558,9 +558,9 @@ static Constant *getKnownConstant(Value *Val, ConstantPreference Preference) { /// This returns true if there were any known values. bool JumpThreadingPass::computeValueKnownInPredecessorsImpl( Value *V, BasicBlock *BB, PredValueInfo &Result, - ConstantPreference Preference, DenseSet<Value *> &RecursionSet, + ConstantPreference Preference, SmallPtrSet<Value *, 4> &RecursionSet, Instruction *CxtI) { - const DataLayout &DL = BB->getModule()->getDataLayout(); + const DataLayout &DL = BB->getDataLayout(); // This method walks up use-def chains recursively. Because of this, we could // get into an infinite loop going around loops in the use-def chain. To @@ -596,11 +596,8 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl( CmpInst::Predicate Pred; Value *Val; Constant *Cst; - if (!PredCst && match(V, m_Cmp(Pred, m_Value(Val), m_Constant(Cst)))) { - auto Res = LVI->getPredicateOnEdge(Pred, Val, Cst, P, BB, CxtI); - if (Res != LazyValueInfo::Unknown) - PredCst = ConstantInt::getBool(V->getContext(), Res); - } + if (!PredCst && match(V, m_Cmp(Pred, m_Value(Val), m_Constant(Cst)))) + PredCst = LVI->getPredicateOnEdge(Pred, Val, Cst, P, BB, CxtI); if (Constant *KC = getKnownConstant(PredCst, Preference)) Result.emplace_back(KC, P); } @@ -757,7 +754,7 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl( // may result in comparison of values from two different loop iterations. // FIXME: This check is broken if LoopHeaders is not populated. if (PN && PN->getParent() == BB && !LoopHeaders.contains(BB)) { - const DataLayout &DL = PN->getModule()->getDataLayout(); + const DataLayout &DL = PN->getDataLayout(); // We can do this simplification if any comparisons fold to true or false. // See if any do. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { @@ -780,13 +777,8 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl( if (LHSInst && LHSInst->getParent() == BB) continue; - LazyValueInfo::Tristate - ResT = LVI->getPredicateOnEdge(Pred, LHS, - cast<Constant>(RHS), PredBB, BB, - CxtI ? 
CxtI : Cmp); - if (ResT == LazyValueInfo::Unknown) - continue; - Res = ConstantInt::get(Type::getInt1Ty(LHS->getContext()), ResT); + Res = LVI->getPredicateOnEdge(Pred, LHS, cast<Constant>(RHS), PredBB, + BB, CxtI ? CxtI : Cmp); } if (Constant *KC = getKnownConstant(Res, WantInteger)) @@ -806,14 +798,10 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl( for (BasicBlock *P : predecessors(BB)) { // If the value is known by LazyValueInfo to be a constant in a // predecessor, use that information to try to thread this block. - LazyValueInfo::Tristate Res = - LVI->getPredicateOnEdge(Pred, CmpLHS, - CmpConst, P, BB, CxtI ? CxtI : Cmp); - if (Res == LazyValueInfo::Unknown) - continue; - - Constant *ResC = ConstantInt::get(CmpType, Res); - Result.emplace_back(ResC, P); + Constant *Res = LVI->getPredicateOnEdge(Pred, CmpLHS, CmpConst, P, BB, + CxtI ? CxtI : Cmp); + if (Constant *KC = getKnownConstant(Res, WantInteger)) + Result.emplace_back(KC, P); } return !Result.empty(); @@ -868,7 +856,8 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl( for (const auto &LHSVal : LHSVals) { Constant *V = LHSVal.first; - Constant *Folded = ConstantExpr::getCompare(Pred, V, CmpConst); + Constant *Folded = + ConstantFoldCompareInstOperands(Pred, V, CmpConst, DL); if (Constant *KC = getKnownConstant(Folded, WantInteger)) Result.emplace_back(KC, LHSVal.second); } @@ -1007,7 +996,7 @@ bool JumpThreadingPass::processBlock(BasicBlock *BB) { // constant. if (Instruction *I = dyn_cast<Instruction>(Condition)) { Value *SimpleVal = - ConstantFoldInstruction(I, BB->getModule()->getDataLayout(), TLI); + ConstantFoldInstruction(I, BB->getDataLayout(), TLI); if (SimpleVal) { I->replaceAllUsesWith(SimpleVal); if (isInstructionTriviallyDead(I, TLI)) @@ -1037,7 +1026,8 @@ bool JumpThreadingPass::processBlock(BasicBlock *BB) { LLVM_DEBUG(dbgs() << " In block '" << BB->getName() << "' folding undef terminator: " << *BBTerm << '\n'); - BranchInst::Create(BBTerm->getSuccessor(BestSucc), BBTerm); + Instruction *NewBI = BranchInst::Create(BBTerm->getSuccessor(BestSucc), BBTerm->getIterator()); + NewBI->setDebugLoc(BBTerm->getDebugLoc()); ++NumFolds; BBTerm->eraseFromParent(); DTU->applyUpdatesPermissive(Updates); @@ -1080,11 +1070,11 @@ bool JumpThreadingPass::processBlock(BasicBlock *BB) { // it's value at the branch instruction. We only handle comparisons // against a constant at this time. if (Constant *CondConst = dyn_cast<Constant>(CondCmp->getOperand(1))) { - LazyValueInfo::Tristate Ret = + Constant *Res = LVI->getPredicateAt(CondCmp->getPredicate(), CondCmp->getOperand(0), CondConst, BB->getTerminator(), /*UseBlockValue=*/false); - if (Ret != LazyValueInfo::Unknown) { + if (Res) { // We can safely replace *some* uses of the CondInst if it has // exactly one value as returned by LVI. RAUW is incorrect in the // presence of guards and assumes, that have the `Cond` as the use. This @@ -1092,10 +1082,7 @@ bool JumpThreadingPass::processBlock(BasicBlock *BB) { // at the end of block, but RAUW unconditionally replaces all uses // including the guards/assumes themselves and the uses before the // guard/assume. - auto *CI = Ret == LazyValueInfo::True ? 
- ConstantInt::getTrue(CondCmp->getType()) : - ConstantInt::getFalse(CondCmp->getType()); - if (replaceFoldableUses(CondCmp, CI, BB)) + if (replaceFoldableUses(CondCmp, Res, BB)) return true; } @@ -1177,7 +1164,7 @@ bool JumpThreadingPass::processImpliedCondition(BasicBlock *BB) { BasicBlock *CurrentPred = BB->getSinglePredecessor(); unsigned Iter = 0; - auto &DL = BB->getModule()->getDataLayout(); + auto &DL = BB->getDataLayout(); while (CurrentPred && Iter++ < ImplicationSearchThreshold) { auto *PBI = dyn_cast<BranchInst>(CurrentPred->getTerminator()); @@ -1202,7 +1189,7 @@ bool JumpThreadingPass::processImpliedCondition(BasicBlock *BB) { BasicBlock *KeepSucc = BI->getSuccessor(*Implication ? 0 : 1); BasicBlock *RemoveSucc = BI->getSuccessor(*Implication ? 1 : 0); RemoveSucc->removePredecessor(BB); - BranchInst *UncondBI = BranchInst::Create(KeepSucc, BI); + BranchInst *UncondBI = BranchInst::Create(KeepSucc, BI->getIterator()); UncondBI->setDebugLoc(BI->getDebugLoc()); ++NumFolds; BI->eraseFromParent(); @@ -1278,9 +1265,11 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) { // only happen in dead loops. if (AvailableVal == LoadI) AvailableVal = PoisonValue::get(LoadI->getType()); - if (AvailableVal->getType() != LoadI->getType()) + if (AvailableVal->getType() != LoadI->getType()) { AvailableVal = CastInst::CreateBitOrPointerCast( - AvailableVal, LoadI->getType(), "", LoadI); + AvailableVal, LoadI->getType(), "", LoadI->getIterator()); + cast<Instruction>(AvailableVal)->setDebugLoc(LoadI->getDebugLoc()); + } LoadI->replaceAllUsesWith(AvailableVal); LoadI->eraseFromParent(); return true; @@ -1321,7 +1310,7 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) { // If this is a load on a phi pointer, phi-translate it and search // for available load/store to the pointer in predecessors. Type *AccessTy = LoadI->getType(); - const auto &DL = LoadI->getModule()->getDataLayout(); + const auto &DL = LoadI->getDataLayout(); MemoryLocation Loc(LoadedPtr->DoPHITranslation(LoadBB, PredBB), LocationSize::precise(DL.getTypeStoreSize(AccessTy)), AATags); @@ -1421,7 +1410,7 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) { LoadI->getType(), LoadedPtr->DoPHITranslation(LoadBB, UnavailablePred), LoadI->getName() + ".pr", false, LoadI->getAlign(), LoadI->getOrdering(), LoadI->getSyncScopeID(), - UnavailablePred->getTerminator()); + UnavailablePred->getTerminator()->getIterator()); NewVal->setDebugLoc(LoadI->getDebugLoc()); if (AATags) NewVal->setAAMetadata(AATags); @@ -1434,16 +1423,14 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) { array_pod_sort(AvailablePreds.begin(), AvailablePreds.end()); // Create a PHI node at the start of the block for the PRE'd load value. - pred_iterator PB = pred_begin(LoadBB), PE = pred_end(LoadBB); - PHINode *PN = PHINode::Create(LoadI->getType(), std::distance(PB, PE), ""); + PHINode *PN = PHINode::Create(LoadI->getType(), pred_size(LoadBB), ""); PN->insertBefore(LoadBB->begin()); PN->takeName(LoadI); PN->setDebugLoc(LoadI->getDebugLoc()); // Insert new entries into the PHI for each predecessor. A single block may // have multiple entries here. 
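// Note on the JumpThreading hunks above: they track a LazyValueInfo API
// change. getPredicateOnEdge / getPredicateAt now return a Constant *
// (nullptr when the predicate cannot be decided) instead of a
// LazyValueInfo::Tristate, so callers feed the result directly through
// getKnownConstant / replaceFoldableUses rather than translating a Tristate
// into ConstantInt::getTrue/getFalse themselves.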
- for (pred_iterator PI = PB; PI != PE; ++PI) { - BasicBlock *P = *PI; + for (BasicBlock *P : predecessors(LoadBB)) { AvailablePredsTy::iterator I = llvm::lower_bound(AvailablePreds, std::make_pair(P, (Value *)nullptr)); @@ -1456,8 +1443,8 @@ bool JumpThreadingPass::simplifyPartiallyRedundantLoad(LoadInst *LoadI) { // predecessor use the same bitcast. Value *&PredV = I->second; if (PredV->getType() != LoadI->getType()) - PredV = CastInst::CreateBitOrPointerCast(PredV, LoadI->getType(), "", - P->getTerminator()); + PredV = CastInst::CreateBitOrPointerCast( + PredV, LoadI->getType(), "", P->getTerminator()->getIterator()); PN->addIncoming(PredV, I->first); } @@ -1490,7 +1477,7 @@ findMostPopularDest(BasicBlock *BB, // Populate DestPopularity with the successors in the order they appear in the // successor list. This way, we ensure determinism by iterating it in the - // same order in std::max_element below. We map nullptr to 0 so that we can + // same order in llvm::max_element below. We map nullptr to 0 so that we can // return nullptr when PredToDestList contains nullptr only. DestPopularity[nullptr] = 0; for (auto *SuccBB : successors(BB)) @@ -1501,8 +1488,7 @@ findMostPopularDest(BasicBlock *BB, DestPopularity[PredToDest.second]++; // Find the most popular dest. - auto MostPopular = std::max_element( - DestPopularity.begin(), DestPopularity.end(), llvm::less_second()); + auto MostPopular = llvm::max_element(DestPopularity, llvm::less_second()); // Okay, we have finally picked the most popular destination. return MostPopular->first; @@ -1512,7 +1498,8 @@ findMostPopularDest(BasicBlock *BB, // BB->getSinglePredecessor() and then on to BB. Constant *JumpThreadingPass::evaluateOnPredecessorEdge(BasicBlock *BB, BasicBlock *PredPredBB, - Value *V) { + Value *V, + const DataLayout &DL) { BasicBlock *PredBB = BB->getSinglePredecessor(); assert(PredBB && "Expected a single predecessor"); @@ -1537,11 +1524,12 @@ Constant *JumpThreadingPass::evaluateOnPredecessorEdge(BasicBlock *BB, if (CmpInst *CondCmp = dyn_cast<CmpInst>(V)) { if (CondCmp->getParent() == BB) { Constant *Op0 = - evaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(0)); + evaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(0), DL); Constant *Op1 = - evaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(1)); + evaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(1), DL); if (Op0 && Op1) { - return ConstantExpr::getCompare(CondCmp->getPredicate(), Op0, Op1); + return ConstantFoldCompareInstOperands(CondCmp->getPredicate(), Op0, + Op1, DL); } } return nullptr; @@ -1655,7 +1643,8 @@ bool JumpThreadingPass::processThreadableEdges(Value *Cond, BasicBlock *BB, // Finally update the terminator. Instruction *Term = BB->getTerminator(); - BranchInst::Create(OnlyDest, Term); + Instruction *NewBI = BranchInst::Create(OnlyDest, Term->getIterator()); + NewBI->setDebugLoc(Term->getDebugLoc()); ++NumFolds; Term->eraseFromParent(); DTU->applyUpdatesPermissive(Updates); @@ -1879,7 +1868,7 @@ bool JumpThreadingPass::processBranchOnXOR(BinaryOperator *BO) { static void addPHINodeEntriesForMappedBlock(BasicBlock *PHIBB, BasicBlock *OldPred, BasicBlock *NewPred, - DenseMap<Instruction*, Value*> &ValueMap) { + ValueToValueMapTy &ValueMap) { for (PHINode &PN : PHIBB->phis()) { // Ok, we have a PHI node. Figure out what the incoming value was for the // DestBlock. @@ -1887,7 +1876,7 @@ static void addPHINodeEntriesForMappedBlock(BasicBlock *PHIBB, // Remap the value if necessary. 
if (Instruction *Inst = dyn_cast<Instruction>(IV)) { - DenseMap<Instruction*, Value*>::iterator I = ValueMap.find(Inst); + ValueToValueMapTy::iterator I = ValueMap.find(Inst); if (I != ValueMap.end()) IV = I->second; } @@ -1948,9 +1937,8 @@ bool JumpThreadingPass::maybeMergeBasicBlockIntoOnlyPred(BasicBlock *BB) { /// Update the SSA form. NewBB contains instructions that are copied from BB. /// ValueMapping maps old values in BB to new ones in NewBB. -void JumpThreadingPass::updateSSA( - BasicBlock *BB, BasicBlock *NewBB, - DenseMap<Instruction *, Value *> &ValueMapping) { +void JumpThreadingPass::updateSSA(BasicBlock *BB, BasicBlock *NewBB, + ValueToValueMapTy &ValueMapping) { // If there were values defined in BB that are used outside the block, then we // now have to update all uses of the value to use either the original value, // the cloned value, or some PHI derived value. This can require arbitrary @@ -1958,7 +1946,7 @@ void JumpThreadingPass::updateSSA( SSAUpdater SSAUpdate; SmallVector<Use *, 16> UsesToRename; SmallVector<DbgValueInst *, 4> DbgValues; - SmallVector<DPValue *, 4> DPValues; + SmallVector<DbgVariableRecord *, 4> DbgVariableRecords; for (Instruction &I : *BB) { // Scan all uses of this instruction to see if it is used outside of its @@ -1975,16 +1963,16 @@ void JumpThreadingPass::updateSSA( } // Find debug values outside of the block - findDbgValues(DbgValues, &I, &DPValues); + findDbgValues(DbgValues, &I, &DbgVariableRecords); llvm::erase_if(DbgValues, [&](const DbgValueInst *DbgVal) { return DbgVal->getParent() == BB; }); - llvm::erase_if(DPValues, [&](const DPValue *DPVal) { - return DPVal->getParent() == BB; + llvm::erase_if(DbgVariableRecords, [&](const DbgVariableRecord *DbgVarRec) { + return DbgVarRec->getParent() == BB; }); // If there are no uses outside the block, we're done with this instruction. - if (UsesToRename.empty() && DbgValues.empty() && DPValues.empty()) + if (UsesToRename.empty() && DbgValues.empty() && DbgVariableRecords.empty()) continue; LLVM_DEBUG(dbgs() << "JT: Renaming non-local uses of: " << I << "\n"); @@ -1997,11 +1985,11 @@ void JumpThreadingPass::updateSSA( while (!UsesToRename.empty()) SSAUpdate.RewriteUse(*UsesToRename.pop_back_val()); - if (!DbgValues.empty() || !DPValues.empty()) { + if (!DbgValues.empty() || !DbgVariableRecords.empty()) { SSAUpdate.UpdateDebugValues(&I, DbgValues); - SSAUpdate.UpdateDebugValues(&I, DPValues); + SSAUpdate.UpdateDebugValues(&I, DbgVariableRecords); DbgValues.clear(); - DPValues.clear(); + DbgVariableRecords.clear(); } LLVM_DEBUG(dbgs() << "\n"); @@ -2011,14 +1999,15 @@ void JumpThreadingPass::updateSSA( /// Clone instructions in range [BI, BE) to NewBB. For PHI nodes, we only clone /// arguments that come from PredBB. Return the map from the variables in the /// source basic block to the variables in the newly created basic block. -DenseMap<Instruction *, Value *> -JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI, - BasicBlock::iterator BE, BasicBlock *NewBB, - BasicBlock *PredBB) { + +void JumpThreadingPass::cloneInstructions(ValueToValueMapTy &ValueMapping, + BasicBlock::iterator BI, + BasicBlock::iterator BE, + BasicBlock *NewBB, + BasicBlock *PredBB) { // We are going to have to map operands from the source basic block to the new // copy of the block 'NewBB'. If there are PHI nodes in the source basic // block, evaluate them to account for entry from PredBB. - DenseMap<Instruction *, Value *> ValueMapping; // Retargets llvm.dbg.value to any renamed variables. 
auto RetargetDbgValueIfPossible = [&](Instruction *NewInst) -> bool { @@ -2044,11 +2033,11 @@ JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI, return true; }; - // Duplicate implementation of the above dbg.value code, using DPValues - // instead. - auto RetargetDPValueIfPossible = [&](DPValue *DPV) { + // Duplicate implementation of the above dbg.value code, using + // DbgVariableRecords instead. + auto RetargetDbgVariableRecordIfPossible = [&](DbgVariableRecord *DVR) { SmallSet<std::pair<Value *, Value *>, 16> OperandsToRemap; - for (auto *Op : DPV->location_ops()) { + for (auto *Op : DVR->location_ops()) { Instruction *OpInst = dyn_cast<Instruction>(Op); if (!OpInst) continue; @@ -2059,7 +2048,7 @@ JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI, } for (auto &[OldOp, MappedOp] : OperandsToRemap) - DPV->replaceVariableLocationOp(OldOp, MappedOp); + DVR->replaceVariableLocationOp(OldOp, MappedOp); }; BasicBlock *RangeBB = BI->getParent(); @@ -2083,9 +2072,9 @@ JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI, cloneNoAliasScopes(NoAliasScopes, ClonedScopes, "thread", Context); auto CloneAndRemapDbgInfo = [&](Instruction *NewInst, Instruction *From) { - auto DPVRange = NewInst->cloneDebugInfoFrom(From); - for (DPValue &DPV : DPVRange) - RetargetDPValueIfPossible(&DPV); + auto DVRRange = NewInst->cloneDebugInfoFrom(From); + for (DbgVariableRecord &DVR : filterDbgVars(DVRRange)) + RetargetDbgVariableRecordIfPossible(&DVR); }; // Clone the non-phi instructions of the source basic block into NewBB, @@ -2106,24 +2095,24 @@ JumpThreadingPass::cloneInstructions(BasicBlock::iterator BI, // Remap operands to patch up intra-block references. for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i) if (Instruction *Inst = dyn_cast<Instruction>(New->getOperand(i))) { - DenseMap<Instruction *, Value *>::iterator I = ValueMapping.find(Inst); + ValueToValueMapTy::iterator I = ValueMapping.find(Inst); if (I != ValueMapping.end()) New->setOperand(i, I->second); } } - // There may be DPValues on the terminator, clone directly from marker - // to marker as there isn't an instruction there. - if (BE != RangeBB->end() && BE->hasDbgValues()) { + // There may be DbgVariableRecords on the terminator, clone directly from + // marker to marker as there isn't an instruction there. + if (BE != RangeBB->end() && BE->hasDbgRecords()) { // Dump them at the end. - DPMarker *Marker = RangeBB->getMarker(BE); - DPMarker *EndMarker = NewBB->createMarker(NewBB->end()); - auto DPVRange = EndMarker->cloneDebugInfoFrom(Marker, std::nullopt); - for (DPValue &DPV : DPVRange) - RetargetDPValueIfPossible(&DPV); + DbgMarker *Marker = RangeBB->getMarker(BE); + DbgMarker *EndMarker = NewBB->createMarker(NewBB->end()); + auto DVRRange = EndMarker->cloneDebugInfoFrom(Marker, std::nullopt); + for (DbgVariableRecord &DVR : filterDbgVars(DVRRange)) + RetargetDbgVariableRecordIfPossible(&DVR); } - return ValueMapping; + return; } /// Attempt to thread through two successive basic blocks. @@ -2194,12 +2183,13 @@ bool JumpThreadingPass::maybethreadThroughTwoBasicBlocks(BasicBlock *BB, unsigned OneCount = 0; BasicBlock *ZeroPred = nullptr; BasicBlock *OnePred = nullptr; + const DataLayout &DL = BB->getDataLayout(); for (BasicBlock *P : predecessors(PredBB)) { // If PredPred ends with IndirectBrInst, we can't handle it. 
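A source-level sketch of the threading that evaluateOnPredecessorEdge enables: when the branch condition of a block is fully determined by a grand-predecessor, the intermediate re-test can be bypassed on that path. The functions below are illustrative only, not taken from the pass.

    #include <cstdio>

    // Before: the second test is fully determined by the first branch.
    int beforeThread(bool c) {
      int x;
      if (c)
        x = 0;
      else
        x = 1;
      if (x == 0)   // re-tests information already known from 'c'
        return 10;
      return 20;
    }

    // After: the intermediate test is threaded away on both paths.
    int afterThread(bool c) {
      return c ? 10 : 20;
    }

    int main() {
      std::printf("%d %d %d %d\n", beforeThread(true), afterThread(true),
                  beforeThread(false), afterThread(false));
    }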
if (isa<IndirectBrInst>(P->getTerminator())) continue; if (ConstantInt *CI = dyn_cast_or_null<ConstantInt>( - evaluateOnPredecessorEdge(BB, P, Cond))) { + evaluateOnPredecessorEdge(BB, P, Cond, DL))) { if (CI->isZero()) { ZeroCount++; ZeroPred = P; @@ -2298,8 +2288,9 @@ void JumpThreadingPass::threadThroughTwoBasicBlocks(BasicBlock *PredPredBB, // We are going to have to map operands from the original BB block to the new // copy of the block 'NewBB'. If there are PHI nodes in PredBB, evaluate them // to account for entry from PredPredBB. - DenseMap<Instruction *, Value *> ValueMapping = - cloneInstructions(PredBB->begin(), PredBB->end(), NewBB, PredPredBB); + ValueToValueMapTy ValueMapping; + cloneInstructions(ValueMapping, PredBB->begin(), PredBB->end(), NewBB, + PredPredBB); // Copy the edge probabilities from PredBB to NewBB. if (BPI) @@ -2422,8 +2413,9 @@ void JumpThreadingPass::threadEdge(BasicBlock *BB, } // Copy all the instructions from BB to NewBB except the terminator. - DenseMap<Instruction *, Value *> ValueMapping = - cloneInstructions(BB->begin(), std::prev(BB->end()), NewBB, PredBB); + ValueToValueMapTy ValueMapping; + cloneInstructions(ValueMapping, BB->begin(), std::prev(BB->end()), NewBB, + PredBB); // We didn't copy the terminator from BB over to NewBB, because there is now // an unconditional jump to SuccBB. Insert the unconditional jump. @@ -2555,8 +2547,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB, BBSuccFreq.push_back(SuccFreq.getFrequency()); } - uint64_t MaxBBSuccFreq = - *std::max_element(BBSuccFreq.begin(), BBSuccFreq.end()); + uint64_t MaxBBSuccFreq = *llvm::max_element(BBSuccFreq); SmallVector<BranchProbability, 4> BBSuccProbs; if (MaxBBSuccFreq == 0) @@ -2614,7 +2605,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB, Weights.push_back(Prob.getNumerator()); auto TI = BB->getTerminator(); - setBranchWeights(*TI, Weights); + setBranchWeights(*TI, Weights, hasBranchWeightOrigin(*TI)); } } @@ -2679,7 +2670,7 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred( // We are going to have to map operands from the original BB block into the // PredBB block. Evaluate PHI nodes in BB. - DenseMap<Instruction*, Value*> ValueMapping; + ValueToValueMapTy ValueMapping; BasicBlock::iterator BI = BB->begin(); for (; PHINode *PN = dyn_cast<PHINode>(BI); ++BI) @@ -2693,17 +2684,20 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred( // Remap operands to patch up intra-block references. for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i) if (Instruction *Inst = dyn_cast<Instruction>(New->getOperand(i))) { - DenseMap<Instruction*, Value*>::iterator I = ValueMapping.find(Inst); + ValueToValueMapTy::iterator I = ValueMapping.find(Inst); if (I != ValueMapping.end()) New->setOperand(i, I->second); } + // Remap debug variable operands. + remapDebugVariable(ValueMapping, New); + // If this instruction can be simplified after the operands are updated, // just use the simplified value instead. This frequently happens due to // phi translation. if (Value *IV = simplifyInstruction( New, - {BB->getModule()->getDataLayout(), TLI, nullptr, nullptr, New})) { + {BB->getDataLayout(), TLI, nullptr, nullptr, New})) { ValueMapping[&*BI] = IV; if (!New->mayHaveSideEffects()) { New->eraseFromParent(); @@ -2882,15 +2876,13 @@ bool JumpThreadingPass::tryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) { // Now check if one of the select values would allow us to constant fold the // terminator in BB. 
We don't do the transform if both sides fold, those // cases will be threaded in any case. - LazyValueInfo::Tristate LHSFolds = + Constant *LHSRes = LVI->getPredicateOnEdge(CondCmp->getPredicate(), SI->getOperand(1), CondRHS, Pred, BB, CondCmp); - LazyValueInfo::Tristate RHSFolds = + Constant *RHSRes = LVI->getPredicateOnEdge(CondCmp->getPredicate(), SI->getOperand(2), CondRHS, Pred, BB, CondCmp); - if ((LHSFolds != LazyValueInfo::Unknown || - RHSFolds != LazyValueInfo::Unknown) && - LHSFolds != RHSFolds) { + if ((LHSRes || RHSRes) && LHSRes != RHSRes) { unfoldSelectInstr(Pred, BB, SI, CondLHS, I); return true; } @@ -2973,15 +2965,16 @@ bool JumpThreadingPass::tryToUnfoldSelectInCurrBB(BasicBlock *BB) { // Expand the select. Value *Cond = SI->getCondition(); if (!isGuaranteedNotToBeUndefOrPoison(Cond, nullptr, SI)) - Cond = new FreezeInst(Cond, "cond.fr", SI); + Cond = new FreezeInst(Cond, "cond.fr", SI->getIterator()); MDNode *BranchWeights = getBranchWeightMDNode(*SI); Instruction *Term = SplitBlockAndInsertIfThen(Cond, SI, false, BranchWeights); BasicBlock *SplitBB = SI->getParent(); BasicBlock *NewBB = Term->getParent(); - PHINode *NewPN = PHINode::Create(SI->getType(), 2, "", SI); + PHINode *NewPN = PHINode::Create(SI->getType(), 2, "", SI->getIterator()); NewPN->addIncoming(SI->getTrueValue(), Term->getParent()); NewPN->addIncoming(SI->getFalseValue(), BB); + NewPN->setDebugLoc(SI->getDebugLoc()); SI->replaceAllUsesWith(NewPN); SI->eraseFromParent(); // NewBB and SplitBB are newly created blocks which require insertion. @@ -3063,7 +3056,7 @@ bool JumpThreadingPass::threadGuard(BasicBlock *BB, IntrinsicInst *Guard, BasicBlock *TrueDest = BI->getSuccessor(0); BasicBlock *FalseDest = BI->getSuccessor(1); - auto &DL = BB->getModule()->getDataLayout(); + auto &DL = BB->getDataLayout(); bool TrueDestIsSafe = false; bool FalseDestIsSafe = false; @@ -3119,10 +3112,11 @@ bool JumpThreadingPass::threadGuard(BasicBlock *BB, IntrinsicInst *Guard, PHINode *NewPN = PHINode::Create(Inst->getType(), 2); NewPN->addIncoming(UnguardedMapping[Inst], UnguardedBlock); NewPN->addIncoming(GuardedMapping[Inst], GuardedBlock); + NewPN->setDebugLoc(Inst->getDebugLoc()); NewPN->insertBefore(InsertionPoint); Inst->replaceAllUsesWith(NewPN); } - Inst->dropDbgValues(); + Inst->dropDbgRecords(); Inst->eraseFromParent(); } return true; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LICM.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LICM.cpp index f3e40a5cb809..ca03eff7a4e2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LICM.cpp @@ -110,6 +110,9 @@ STATISTIC(NumAddSubHoisted, "Number of add/subtract expressions reassociated " "and hoisted out of the loop"); STATISTIC(NumFPAssociationsHoisted, "Number of invariant FP expressions " "reassociated and hoisted out of the loop"); +STATISTIC(NumIntAssociationsHoisted, + "Number of invariant int expressions " + "reassociated and hoisted out of the loop"); /// Memory promotion is enabled by default. 
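As a rough sketch of the select unfolding touched in the tryToUnfoldSelectInCurrBB hunk above (at IR level the pass also freezes the condition and gives the new PHI the select's debug location), the source-level effect is a branch plus a merge value. Names below are illustrative.

    #include <cstdio>

    int selectForm(bool c, int a, int b) {
      return c ? a : b;   // the "select"
    }

    int unfoldedForm(bool c, int a, int b) {
      int r;              // 'r' corresponds to the new PHI node
      if (c)              // explicit branch on the (frozen) condition
        r = a;
      else
        r = b;
      return r;
    }

    int main() {
      std::printf("%d %d\n", selectForm(true, 1, 2), unfoldedForm(false, 1, 2));
    }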
static cl::opt<bool> @@ -135,6 +138,12 @@ static cl::opt<unsigned> FPAssociationUpperLimit( "Set upper limit for the number of transformations performed " "during a single round of hoisting the reassociated expressions.")); +cl::opt<unsigned> IntAssociationUpperLimit( + "licm-max-num-int-reassociations", cl::init(5U), cl::Hidden, + cl::desc( + "Set upper limit for the number of transformations performed " + "during a single round of hoisting the reassociated expressions.")); + // Experimental option to allow imprecision in LICM in pathological cases, in // exchange for faster compile. This is to be removed if MemorySSA starts to // address the same issue. LICM calls MemorySSAWalker's @@ -924,12 +933,14 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, ReciprocalDivisor->setFastMathFlags(I.getFastMathFlags()); SafetyInfo->insertInstructionTo(ReciprocalDivisor, I.getParent()); ReciprocalDivisor->insertBefore(&I); + ReciprocalDivisor->setDebugLoc(I.getDebugLoc()); auto Product = BinaryOperator::CreateFMul(I.getOperand(0), ReciprocalDivisor); Product->setFastMathFlags(I.getFastMathFlags()); SafetyInfo->insertInstructionTo(Product, I.getParent()); Product->insertAfter(&I); + Product->setDebugLoc(I.getDebugLoc()); I.replaceAllUsesWith(Product); eraseInstruction(I, *SafetyInfo, MSSAU); @@ -1041,7 +1052,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI, static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT, Loop *CurLoop) { Value *Addr = LI->getPointerOperand(); - const DataLayout &DL = LI->getModule()->getDataLayout(); + const DataLayout &DL = LI->getDataLayout(); const TypeSize LocSizeInBits = DL.getTypeSizeInBits(LI->getType()); // It is not currently possible for clang to generate an invariant.start @@ -1208,6 +1219,14 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, if (CI->isConvergent()) return false; + // FIXME: Current LLVM IR semantics don't work well with coroutines and + // thread local globals. We currently treat getting the address of a thread + // local global as not accessing memory, even though it may not be a + // constant throughout a function with coroutines. Remove this check after + // we better model semantics of thread local globals. + if (CI->getFunction()->isPresplitCoroutine()) + return false; + using namespace PatternMatch; if (match(CI, m_Intrinsic<Intrinsic::assume>())) // Assumes don't actually alias anything or throw @@ -1216,14 +1235,6 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, // Handle simple cases by querying alias analysis. MemoryEffects Behavior = AA->getMemoryEffects(CI); - // FIXME: we don't handle the semantics of thread local well. So that the - // address of thread locals are fake constants in coroutines. So We forbid - // to treat onlyReadsMemory call in coroutines as constants now. Note that - // it is possible to hide a thread local access in a onlyReadsMemory call. - // Remove this check after we handle the semantics of thread locals well. 
- if (Behavior.onlyReadsMemory() && CI->getFunction()->isPresplitCoroutine()) - return false; - if (Behavior.doesNotAccessMemory()) return true; if (Behavior.onlyReadsMemory()) { @@ -1442,6 +1453,7 @@ static Instruction *cloneInstructionInExitBlock( } New = CallInst::Create(CI, OpBundles); + New->copyMetadata(*CI); } else { New = I.clone(); } @@ -1452,8 +1464,11 @@ static Instruction *cloneInstructionInExitBlock( if (MSSAU.getMemorySSA()->getMemoryAccess(&I)) { // Create a new MemoryAccess and let MemorySSA set its defining access. + // After running some passes, MemorySSA might be outdated, and the + // instruction `I` may have become a non-memory touching instruction. MemoryAccess *NewMemAcc = MSSAU.createMemoryAccessInBB( - New, nullptr, New->getParent(), MemorySSA::Beginning); + New, nullptr, New->getParent(), MemorySSA::Beginning, + /*CreationMustSucceed=*/false); if (NewMemAcc) { if (auto *MemDef = dyn_cast<MemoryDef>(NewMemAcc)) MSSAU.insertDef(MemDef, /*RenameUses=*/true); @@ -2031,7 +2046,7 @@ bool llvm::promoteLoopAccessesToScalars( bool SawNotAtomic = false; AAMDNodes AATags; - const DataLayout &MDL = Preheader->getModule()->getDataLayout(); + const DataLayout &MDL = Preheader->getDataLayout(); // If there are reads outside the promoted set, then promoting stores is // definitely not safe. @@ -2225,7 +2240,7 @@ bool llvm::promoteLoopAccessesToScalars( if (FoundLoadToPromote || !StoreIsGuanteedToExecute) { PreheaderLoad = new LoadInst(AccessTy, SomePtr, SomePtr->getName() + ".promoted", - Preheader->getTerminator()); + Preheader->getTerminator()->getIterator()); if (SawUnorderedAtomic) PreheaderLoad->setOrdering(AtomicOrdering::Unordered); PreheaderLoad->setAlignment(Alignment); @@ -2494,7 +2509,7 @@ static bool hoistGEP(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo, // The swapped GEPs are inbounds if both original GEPs are inbounds // and the sign of the offsets is the same. For simplicity, only // handle both offsets being non-negative. - const DataLayout &DL = GEP->getModule()->getDataLayout(); + const DataLayout &DL = GEP->getDataLayout(); auto NonNegative = [&](Value *V) { return isKnownNonNegative(V, SimplifyQuery(DL, DT, AC, GEP)); }; @@ -2544,7 +2559,7 @@ static bool hoistAdd(ICmpInst::Predicate Pred, Value *VariantLHS, // freely move values from left side of inequality to right side (just as in // normal linear arithmetics). Overflows make things much more complicated, so // we want to avoid this. - auto &DL = L.getHeader()->getModule()->getDataLayout(); + auto &DL = L.getHeader()->getDataLayout(); bool ProvedNoOverflowAfterReassociate = computeOverflowForSignedSub(InvariantRHS, InvariantOp, SimplifyQuery(DL, DT, AC, &ICmp)) == @@ -2597,7 +2612,7 @@ static bool hoistSub(ICmpInst::Predicate Pred, Value *VariantLHS, // normal linear arithmetics). Overflows make things much more complicated, so // we want to avoid this. Likewise, for "C1 - LV < C2" we need to prove that // "C1 - C2" does not overflow. 
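A source-level illustration of the hoistAdd/hoistSub reassociation described here: with loop-invariant inv1 and inv2, the comparison iv + inv1 < inv2 is rewritten as iv < inv2 - inv1 and the subtraction is hoisted to the preheader, which is only legal when that subtraction provably does not overflow. The function names are made up for the example.

    #include <cstdio>

    long countBefore(long n, long inv1, long inv2) {
      long cnt = 0;
      for (long iv = 0; iv < n; ++iv)
        if (iv + inv1 < inv2)     // invariant add re-evaluated every iteration
          ++cnt;
      return cnt;
    }

    long countAfter(long n, long inv1, long inv2) {
      long bound = inv2 - inv1;   // hoisted; valid only if this subtraction
      long cnt = 0;               // cannot overflow
      for (long iv = 0; iv < n; ++iv)
        if (iv < bound)
          ++cnt;
      return cnt;
    }

    int main() {
      std::printf("%ld %ld\n", countBefore(10, 3, 7), countAfter(10, 3, 7));
    }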
- auto &DL = L.getHeader()->getModule()->getDataLayout(); + auto &DL = L.getHeader()->getDataLayout(); SimplifyQuery SQ(DL, DT, AC, &ICmp); if (VariantSubtracted) { // C1 - LV < C2 --> LV > C1 - C2 @@ -2661,21 +2676,29 @@ static bool hoistAddSub(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo, return false; } +static bool isReassociableOp(Instruction *I, unsigned IntOpcode, + unsigned FPOpcode) { + if (I->getOpcode() == IntOpcode) + return true; + if (I->getOpcode() == FPOpcode && I->hasAllowReassoc() && + I->hasNoSignedZeros()) + return true; + return false; +} + /// Try to reassociate expressions like ((A1 * B1) + (A2 * B2) + ...) * C where /// A1, A2, ... and C are loop invariants into expressions like /// ((A1 * C * B1) + (A2 * C * B2) + ...) and hoist the (A1 * C), (A2 * C), ... /// invariant expressions. This functions returns true only if any hoisting has /// actually occured. -static bool hoistFPAssociation(Instruction &I, Loop &L, - ICFLoopSafetyInfo &SafetyInfo, - MemorySSAUpdater &MSSAU, AssumptionCache *AC, - DominatorTree *DT) { - using namespace PatternMatch; - Value *VariantOp = nullptr, *InvariantOp = nullptr; - - if (!match(&I, m_FMul(m_Value(VariantOp), m_Value(InvariantOp))) || - !I.hasAllowReassoc() || !I.hasNoSignedZeros()) +static bool hoistMulAddAssociation(Instruction &I, Loop &L, + ICFLoopSafetyInfo &SafetyInfo, + MemorySSAUpdater &MSSAU, AssumptionCache *AC, + DominatorTree *DT) { + if (!isReassociableOp(&I, Instruction::Mul, Instruction::FMul)) return false; + Value *VariantOp = I.getOperand(0); + Value *InvariantOp = I.getOperand(1); if (L.isLoopInvariant(VariantOp)) std::swap(VariantOp, InvariantOp); if (L.isLoopInvariant(VariantOp) || !L.isLoopInvariant(InvariantOp)) @@ -2684,20 +2707,24 @@ static bool hoistFPAssociation(Instruction &I, Loop &L, // First, we need to make sure we should do the transformation. SmallVector<Use *> Changes; + SmallVector<BinaryOperator *> Adds; SmallVector<BinaryOperator *> Worklist; if (BinaryOperator *VariantBinOp = dyn_cast<BinaryOperator>(VariantOp)) Worklist.push_back(VariantBinOp); while (!Worklist.empty()) { BinaryOperator *BO = Worklist.pop_back_val(); - if (!BO->hasOneUse() || !BO->hasAllowReassoc() || !BO->hasNoSignedZeros()) + if (!BO->hasOneUse()) return false; - BinaryOperator *Op0, *Op1; - if (match(BO, m_FAdd(m_BinOp(Op0), m_BinOp(Op1)))) { - Worklist.push_back(Op0); - Worklist.push_back(Op1); + if (isReassociableOp(BO, Instruction::Add, Instruction::FAdd) && + isa<BinaryOperator>(BO->getOperand(0)) && + isa<BinaryOperator>(BO->getOperand(1))) { + Worklist.push_back(cast<BinaryOperator>(BO->getOperand(0))); + Worklist.push_back(cast<BinaryOperator>(BO->getOperand(1))); + Adds.push_back(BO); continue; } - if (BO->getOpcode() != Instruction::FMul || L.isLoopInvariant(BO)) + if (!isReassociableOp(BO, Instruction::Mul, Instruction::FMul) || + L.isLoopInvariant(BO)) return false; Use &U0 = BO->getOperandUse(0); Use &U1 = BO->getOperandUse(1); @@ -2707,21 +2734,49 @@ static bool hoistFPAssociation(Instruction &I, Loop &L, Changes.push_back(&U1); else return false; - if (Changes.size() > FPAssociationUpperLimit) + unsigned Limit = I.getType()->isIntOrIntVectorTy() + ? IntAssociationUpperLimit + : FPAssociationUpperLimit; + if (Changes.size() > Limit) return false; } if (Changes.empty()) return false; + // Drop the poison flags for any adds we looked through. 
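A source-level sketch of the reassociation performed by the hoistMulAddAssociation hunks, shown for the integer case now supported: A1, A2 and C are loop-invariant, b1[i] and b2[i] vary, and the invariant products A1*C and A2*C are hoisted (the nsw/nuw flags are dropped because the reassociated form may wrap differently). Function and variable names are illustrative only.

    #include <cstdio>

    long sumBefore(const long *b1, const long *b2, long n,
                   long A1, long A2, long C) {
      long s = 0;
      for (long i = 0; i < n; ++i)
        s += (A1 * b1[i] + A2 * b2[i]) * C;   // multiply by C inside the loop
      return s;
    }

    long sumAfter(const long *b1, const long *b2, long n,
                  long A1, long A2, long C) {
      long A1C = A1 * C;                      // hoisted invariant products
      long A2C = A2 * C;
      long s = 0;
      for (long i = 0; i < n; ++i)
        s += A1C * b1[i] + A2C * b2[i];
      return s;
    }

    int main() {
      long b1[] = {1, 2, 3}, b2[] = {4, 5, 6};
      std::printf("%ld %ld\n", sumBefore(b1, b2, 3, 2, 3, 5),
                  sumAfter(b1, b2, 3, 2, 3, 5));
    }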
+ if (I.getType()->isIntOrIntVectorTy()) { + for (auto *Add : Adds) + Add->dropPoisonGeneratingFlags(); + } + // We know we should do it so let's do the transformation. auto *Preheader = L.getLoopPreheader(); assert(Preheader && "Loop is not in simplify form?"); IRBuilder<> Builder(Preheader->getTerminator()); for (auto *U : Changes) { assert(L.isLoopInvariant(U->get())); - Instruction *Ins = cast<Instruction>(U->getUser()); - U->set(Builder.CreateFMulFMF(U->get(), Factor, Ins, "factor.op.fmul")); + auto *Ins = cast<BinaryOperator>(U->getUser()); + Value *Mul; + if (I.getType()->isIntOrIntVectorTy()) { + Mul = Builder.CreateMul(U->get(), Factor, "factor.op.mul"); + // Drop the poison flags on the original multiply. + Ins->dropPoisonGeneratingFlags(); + } else + Mul = Builder.CreateFMulFMF(U->get(), Factor, Ins, "factor.op.fmul"); + + // Rewrite the reassociable instruction. + unsigned OpIdx = U->getOperandNo(); + auto *LHS = OpIdx == 0 ? Mul : Ins->getOperand(0); + auto *RHS = OpIdx == 1 ? Mul : Ins->getOperand(1); + auto *NewBO = BinaryOperator::Create(Ins->getOpcode(), LHS, RHS, + Ins->getName() + ".reass", Ins); + NewBO->copyIRFlags(Ins); + if (VariantOp == Ins) + VariantOp = NewBO; + Ins->replaceAllUsesWith(NewBO); + eraseInstruction(*Ins, SafetyInfo, MSSAU); } + I.replaceAllUsesWith(VariantOp); eraseInstruction(I, SafetyInfo, MSSAU); return true; @@ -2754,9 +2809,13 @@ static bool hoistArithmetics(Instruction &I, Loop &L, return true; } - if (hoistFPAssociation(I, L, SafetyInfo, MSSAU, AC, DT)) { + bool IsInt = I.getType()->isIntOrIntVectorTy(); + if (hoistMulAddAssociation(I, L, SafetyInfo, MSSAU, AC, DT)) { ++NumHoisted; - ++NumFPAssociationsHoisted; + if (IsInt) + ++NumIntAssociationsHoisted; + else + ++NumFPAssociationsHoisted; return true; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp index 9a27a08c86eb..6092cd1bc08b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp @@ -405,7 +405,7 @@ static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI, : SE.getUMinExpr(NewBoundSCEV, SplitBoundSCEV); SCEVExpander Expander( - SE, L.getHeader()->getParent()->getParent()->getDataLayout(), "split"); + SE, L.getHeader()->getDataLayout(), "split"); Instruction *InsertPt = SplitLoopPH->getTerminator(); Value *NewBoundValue = Expander.expandCodeFor(NewBoundSCEV, NewBoundSCEV->getType(), InsertPt); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp index cc1f56014eee..d85166e518f1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -391,7 +391,7 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { continue; BasicBlock *BB = P.InsertPt->getParent(); - SCEVExpander SCEVE(*SE, BB->getModule()->getDataLayout(), "prefaddr"); + SCEVExpander SCEVE(*SE, BB->getDataLayout(), "prefaddr"); const SCEV *NextLSCEV = SE->getAddExpr(P.LSCEVAddRec, SE->getMulExpr( SE->getConstant(P.LSCEVAddRec->getType(), ItersAhead), P.LSCEVAddRec->getStepRecurrence(*SE))); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDeletion.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDeletion.cpp index bfe9374cf2f8..b0b7ae60da98 100644 --- 
a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDeletion.cpp @@ -273,9 +273,9 @@ static bool canProveExitOnFirstIteration(Loop *L, DominatorTree &DT, if (LiveEdges.count({ Pred, BB })) { HasLivePreds = true; Value *Incoming = PN.getIncomingValueForBlock(Pred); - // Skip undefs. If they are present, we can assume they are equal to - // the non-undef input. - if (isa<UndefValue>(Incoming)) + // Skip poison. If they are present, we can assume they are equal to + // the non-poison input. + if (isa<PoisonValue>(Incoming)) continue; // Two inputs. if (OnlyInput && OnlyInput != Incoming) @@ -284,8 +284,8 @@ static bool canProveExitOnFirstIteration(Loop *L, DominatorTree &DT, } assert(HasLivePreds && "No live predecessors?"); - // If all incoming live value were undefs, return undef. - return OnlyInput ? OnlyInput : UndefValue::get(PN.getType()); + // If all incoming live value were poison, return poison. + return OnlyInput ? OnlyInput : PoisonValue::get(PN.getType()); }; DenseMap<Value *, Value *> FirstIterValue; @@ -299,7 +299,7 @@ static bool canProveExitOnFirstIteration(Loop *L, DominatorTree &DT, // iteration, mark this successor live. // 3b. If we cannot prove it, conservatively assume that all successors are // live. - auto &DL = Header->getModule()->getDataLayout(); + auto &DL = Header->getDataLayout(); const SimplifyQuery SQ(DL); for (auto *BB : RPOT) { Visited.insert(BB); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDistribute.cpp index 626888c74bad..c84e419c2a24 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -26,7 +26,7 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" @@ -120,7 +120,7 @@ namespace { /// Maintains the set of instructions of the loop for a partition before /// cloning. After cloning, it hosts the new loop. class InstPartition { - using InstructionSet = SmallPtrSet<Instruction *, 8>; + using InstructionSet = SmallSetVector<Instruction *, 8>; public: InstPartition(Instruction *I, Loop *L, bool DepCycle = false) @@ -166,7 +166,7 @@ public: // Insert instructions from the loop that we depend on. for (Value *V : I->operand_values()) { auto *I = dyn_cast<Instruction>(V); - if (I && OrigLoop->contains(I->getParent()) && Set.insert(I).second) + if (I && OrigLoop->contains(I->getParent()) && Set.insert(I)) Worklist.push_back(I); } } @@ -231,17 +231,16 @@ public: } } - void print() const { - if (DepCycle) - dbgs() << " (cycle)\n"; + void print(raw_ostream &OS) const { + OS << (DepCycle ? " (cycle)\n" : "\n"); for (auto *I : Set) // Prefix with the block name. 
- dbgs() << " " << I->getParent()->getName() << ":" << *I << "\n"; + OS << " " << I->getParent()->getName() << ":" << *I << "\n"; } - void printBlocks() const { + void printBlocks(raw_ostream &OS) const { for (auto *BB : getDistributedLoop()->getBlocks()) - dbgs() << *BB; + OS << *BB; } private: @@ -368,11 +367,11 @@ public: std::tie(LoadToPart, NewElt) = LoadToPartition.insert(std::make_pair(Inst, PartI)); if (!NewElt) { - LLVM_DEBUG(dbgs() - << "Merging partitions due to this load in multiple " - << "partitions: " << PartI << ", " << LoadToPart->second - << "\n" - << *Inst << "\n"); + LLVM_DEBUG( + dbgs() + << "LDist: Merging partitions due to this load in multiple " + << "partitions: " << PartI << ", " << LoadToPart->second << "\n" + << *Inst << "\n"); auto PartJ = I; do { @@ -530,8 +529,8 @@ public: void print(raw_ostream &OS) const { unsigned Index = 0; for (const auto &P : PartitionContainer) { - OS << "Partition " << Index++ << " (" << &P << "):\n"; - P.print(); + OS << "LDist: Partition " << Index++ << ":"; + P.print(OS); } } @@ -545,11 +544,11 @@ public: } #endif - void printBlocks() const { + void printBlocks(raw_ostream &OS) const { unsigned Index = 0; for (const auto &P : PartitionContainer) { - dbgs() << "\nPartition " << Index++ << " (" << &P << "):\n"; - P.printBlocks(); + OS << "LDist: Partition " << Index++ << ":"; + P.printBlocks(OS); } } @@ -628,7 +627,7 @@ public: const SmallVectorImpl<Dependence> &Dependences) { Accesses.append(Instructions.begin(), Instructions.end()); - LLVM_DEBUG(dbgs() << "Backward dependences:\n"); + LLVM_DEBUG(dbgs() << "LDist: Backward dependences:\n"); for (const auto &Dep : Dependences) if (Dep.isPossiblyBackward()) { // Note that the designations source and destination follow the program @@ -659,9 +658,9 @@ public: bool processLoop() { assert(L->isInnermost() && "Only process inner loops."); - LLVM_DEBUG(dbgs() << "\nLDist: In \"" - << L->getHeader()->getParent()->getName() - << "\" checking " << *L << "\n"); + LLVM_DEBUG(dbgs() << "\nLDist: Checking a loop in '" + << L->getHeader()->getParent()->getName() << "' from " + << L->getLocStr() << "\n"); // Having a single exit block implies there's also one exiting block. if (!L->getExitBlock()) @@ -686,6 +685,9 @@ public: if (!Dependences || Dependences->empty()) return fail("NoUnsafeDeps", "no unsafe dependences to isolate"); + LLVM_DEBUG(dbgs() << "LDist: Found a candidate loop: " + << L->getHeader()->getName() << "\n"); + InstPartitionContainer Partitions(L, LI, DT); // First, go through each memory operation and assign them to consecutive @@ -735,7 +737,7 @@ public: for (auto *Inst : DefsUsedOutside) Partitions.addToNewNonCyclicPartition(Inst); - LLVM_DEBUG(dbgs() << "Seeded partitions:\n" << Partitions); + LLVM_DEBUG(dbgs() << "LDist: Seeded partitions:\n" << Partitions); if (Partitions.getSize() < 2) return fail("CantIsolateUnsafeDeps", "cannot isolate unsafe dependencies"); @@ -743,19 +745,19 @@ public: // Run the merge heuristics: Merge non-cyclic adjacent partitions since we // should be able to vectorize these together. Partitions.mergeBeforePopulating(); - LLVM_DEBUG(dbgs() << "\nMerged partitions:\n" << Partitions); + LLVM_DEBUG(dbgs() << "LDist: Merged partitions:\n" << Partitions); if (Partitions.getSize() < 2) return fail("CantIsolateUnsafeDeps", "cannot isolate unsafe dependencies"); // Now, populate the partitions with non-memory operations. 
Partitions.populateUsedSet(); - LLVM_DEBUG(dbgs() << "\nPopulated partitions:\n" << Partitions); + LLVM_DEBUG(dbgs() << "LDist: Populated partitions:\n" << Partitions); // In order to preserve original lexical order for loads, keep them in the // partition that we set up in the MemoryInstructionDependences loop. if (Partitions.mergeToAvoidDuplicatedLoads()) { - LLVM_DEBUG(dbgs() << "\nPartitions merged to ensure unique loads:\n" + LLVM_DEBUG(dbgs() << "LDist: Partitions merged to ensure unique loads:\n" << Partitions); if (Partitions.getSize() < 2) return fail("CantIsolateUnsafeDeps", @@ -779,7 +781,8 @@ public: if (!IsForced.value_or(false) && hasDisableAllTransformsHint(L)) return fail("HeuristicDisabled", "distribution heuristic disabled"); - LLVM_DEBUG(dbgs() << "\nDistributing loop: " << *L << "\n"); + LLVM_DEBUG(dbgs() << "LDist: Distributing loop: " + << L->getHeader()->getName() << "\n"); // We're done forming the partitions set up the reverse mapping from // instructions to partitions. Partitions.setupPartitionIdOnInstructions(); @@ -807,7 +810,7 @@ public: MDNode *OrigLoopID = L->getLoopID(); - LLVM_DEBUG(dbgs() << "\nPointers:\n"); + LLVM_DEBUG(dbgs() << "LDist: Pointers:\n"); LLVM_DEBUG(LAI->getRuntimePointerChecking()->printChecks(dbgs(), Checks)); LoopVersioning LVer(*LAI, Checks, L, LI, DT, SE); LVer.versionLoop(DefsUsedOutside); @@ -830,8 +833,8 @@ public: // Now, we remove the instruction from each loop that don't belong to that // partition. Partitions.removeUnusedInsts(); - LLVM_DEBUG(dbgs() << "\nAfter removing unused Instrs:\n"); - LLVM_DEBUG(Partitions.printBlocks()); + LLVM_DEBUG(dbgs() << "LDist: After removing unused Instrs:\n"); + LLVM_DEBUG(Partitions.printBlocks(dbgs())); if (LDistVerify) { LI->verify(*DT); @@ -853,7 +856,7 @@ public: LLVMContext &Ctx = F->getContext(); bool Forced = isForced().value_or(false); - LLVM_DEBUG(dbgs() << "Skipping; " << Message << "\n"); + LLVM_DEBUG(dbgs() << "LDist: Skipping; " << Message << "\n"); // With Rpass-missed report that distribution failed. ORE->emit([&]() { @@ -962,11 +965,10 @@ private: } // end anonymous namespace -/// Shared implementation between new and old PMs. static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT, ScalarEvolution *SE, OptimizationRemarkEmitter *ORE, LoopAccessInfoManager &LAIs) { - // Build up a worklist of inner-loops to vectorize. This is necessary as the + // Build up a worklist of inner-loops to distribute. This is necessary as the // act of distributing a loop creates new loops and can invalidate iterators // across the loops. 
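For context, the kind of rewrite LoopDistribute performs can be sketched at the source level: the statement carrying an unsafe backward dependence is isolated into its own loop so the remaining, independent work can be vectorized. The helper names below are illustrative.

    #include <cstdio>

    void beforeDist(int *a, const int *b, int *c, const int *d, int n) {
      for (int i = 1; i < n; ++i) {
        a[i] = a[i - 1] + b[i];   // backward (loop-carried) dependence
        c[i] = b[i] * d[i];       // independent, vectorizable work
      }
    }

    void afterDist(int *a, const int *b, int *c, const int *d, int n) {
      for (int i = 1; i < n; ++i)   // partition isolating the unsafe dependence
        a[i] = a[i - 1] + b[i];
      for (int i = 1; i < n; ++i)   // partition the vectorizer can handle
        c[i] = b[i] * d[i];
    }

    int main() {
      int a1[4] = {1, 0, 0, 0}, a2[4] = {1, 0, 0, 0};
      int b[4] = {1, 1, 1, 1}, d[4] = {2, 2, 2, 2};
      int c1[4] = {0}, c2[4] = {0};
      beforeDist(a1, b, c1, d, 4);
      afterDist(a2, b, c2, d, 4);
      std::printf("%d %d %d %d\n", a1[3], a2[3], c1[3], c2[3]);
    }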
SmallVector<Loop *, 8> Worklist; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp index 533cefaf1061..d5e91d3c1dec 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -70,6 +70,7 @@ #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/LoopVersioning.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include "llvm/Transforms/Utils/SimplifyIndVar.h" #include <optional> @@ -97,6 +98,10 @@ static cl::opt<bool> cl::desc("Widen the loop induction variables, if possible, so " "overflow checks won't reject flattening")); +static cl::opt<bool> + VersionLoops("loop-flatten-version-loops", cl::Hidden, cl::init(true), + cl::desc("Version loops if flattened loop could overflow")); + namespace { // We require all uses of both induction variables to match this pattern: // @@ -141,6 +146,8 @@ struct FlattenInfo { // has been applied. Used to skip // checks on phi nodes. + Value *NewTripCount = nullptr; // The tripcount of the flattened loop. + FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){}; bool isNarrowInductionPhi(PHINode *Phi) { @@ -637,7 +644,7 @@ static bool checkIVUsers(FlattenInfo &FI) { static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT, AssumptionCache *AC) { Function *F = FI.OuterLoop->getHeader()->getParent(); - const DataLayout &DL = F->getParent()->getDataLayout(); + const DataLayout &DL = F->getDataLayout(); // For debugging/testing. if (AssumeNoOverflow) @@ -752,11 +759,13 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ORE.emit(Remark); } - Value *NewTripCount = BinaryOperator::CreateMul( - FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount", - FI.OuterLoop->getLoopPreheader()->getTerminator()); - LLVM_DEBUG(dbgs() << "Created new trip count in preheader: "; - NewTripCount->dump()); + if (!FI.NewTripCount) { + FI.NewTripCount = BinaryOperator::CreateMul( + FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount", + FI.OuterLoop->getLoopPreheader()->getTerminator()->getIterator()); + LLVM_DEBUG(dbgs() << "Created new trip count in preheader: "; + FI.NewTripCount->dump()); + } // Fix up PHI nodes that take values from the inner loop back-edge, which // we are about to remove. @@ -769,13 +778,15 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, // Modify the trip count of the outer loop to be the product of the two // trip counts. - cast<User>(FI.OuterBranch->getCondition())->setOperand(1, NewTripCount); + cast<User>(FI.OuterBranch->getCondition())->setOperand(1, FI.NewTripCount); // Replace the inner loop backedge with an unconditional branch to the exit. BasicBlock *InnerExitBlock = FI.InnerLoop->getExitBlock(); BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock(); - InnerExitingBlock->getTerminator()->eraseFromParent(); - BranchInst::Create(InnerExitBlock, InnerExitingBlock); + Instruction *Term = InnerExitingBlock->getTerminator(); + Instruction *BI = BranchInst::Create(InnerExitBlock, InnerExitingBlock); + BI->setDebugLoc(Term->getDebugLoc()); + Term->eraseFromParent(); // Update the DomTree and MemorySSA. 
DT->deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader()); @@ -799,8 +810,10 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, // we need to insert the new GEP where the old GEP was. if (!DT->dominates(Base, &*Builder.GetInsertPoint())) Builder.SetInsertPoint(cast<Instruction>(V)); - OuterValue = Builder.CreateGEP(GEP->getSourceElementType(), Base, - OuterValue, "flatten." + V->getName()); + OuterValue = + Builder.CreateGEP(GEP->getSourceElementType(), Base, OuterValue, + "flatten." + V->getName(), + GEP->isInBounds() && InnerGEP->isInBounds()); } LLVM_DEBUG(dbgs() << "Replacing: "; V->dump(); dbgs() << "with: "; @@ -891,7 +904,8 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, const TargetTransformInfo *TTI, LPMUpdater *U, - MemorySSAUpdater *MSSAU) { + MemorySSAUpdater *MSSAU, + const LoopAccessInfo &LAI) { LLVM_DEBUG( dbgs() << "Loop flattening running on outer loop " << FI.OuterLoop->getHeader()->getName() << " and inner loop " @@ -926,18 +940,55 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, // variable might overflow. In this case, we need to version the loop, and // select the original version at runtime if the iteration space is too // large. - // TODO: We currently don't version the loop. OverflowResult OR = checkOverflow(FI, DT, AC); if (OR == OverflowResult::AlwaysOverflowsHigh || OR == OverflowResult::AlwaysOverflowsLow) { LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n"); return false; } else if (OR == OverflowResult::MayOverflow) { - LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n"); - return false; + Module *M = FI.OuterLoop->getHeader()->getParent()->getParent(); + const DataLayout &DL = M->getDataLayout(); + if (!VersionLoops) { + LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n"); + return false; + } else if (!DL.isLegalInteger( + FI.OuterTripCount->getType()->getScalarSizeInBits())) { + // If the trip count type isn't legal then it won't be possible to check + // for overflow using only a single multiply instruction, so don't + // flatten. + LLVM_DEBUG( + dbgs() << "Can't check overflow efficiently, not flattening\n"); + return false; + } + LLVM_DEBUG(dbgs() << "Multiply might overflow, versioning loop\n"); + + // Version the loop. The overflow check isn't a runtime pointer check, so we + // pass an empty list of runtime pointer checks, causing LoopVersioning to + // emit 'false' as the branch condition, and add our own check afterwards. + BasicBlock *CheckBlock = FI.OuterLoop->getLoopPreheader(); + ArrayRef<RuntimePointerCheck> Checks(nullptr, nullptr); + LoopVersioning LVer(LAI, Checks, FI.OuterLoop, LI, DT, SE); + LVer.versionLoop(); + + // Check for overflow by calculating the new tripcount using + // umul_with_overflow and then checking if it overflowed. 
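Before the IR-level code that follows, a source-level sketch of the versioning scheme may help: if computing the flattened trip count M * N would overflow, the original nest is executed instead. The GCC/Clang builtin __builtin_mul_overflow stands in for llvm.umul.with.overflow here, and the function name is made up for the example.

    #include <cstdio>

    long sumFlattened(const long *a, unsigned M, unsigned N) {
      unsigned flat;
      if (__builtin_mul_overflow(M, N, &flat)) {   // overflow check
        long s = 0;                                // versioned: original nest
        for (unsigned i = 0; i < M; ++i)
          for (unsigned j = 0; j < N; ++j)
            s += a[i * N + j];
        return s;
      }
      long s = 0;                                  // flattened single loop
      for (unsigned k = 0; k < flat; ++k)
        s += a[k];
      return s;
    }

    int main() {
      long a[6] = {1, 2, 3, 4, 5, 6};
      std::printf("%ld\n", sumFlattened(a, 2, 3));
    }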
+ BranchInst *Br = cast<BranchInst>(CheckBlock->getTerminator()); + assert(Br->isConditional() && + "Expected LoopVersioning to generate a conditional branch"); + assert(match(Br->getCondition(), m_Zero()) && + "Expected branch condition to be false"); + IRBuilder<> Builder(Br); + Function *F = Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow, + FI.OuterTripCount->getType()); + Value *Call = Builder.CreateCall(F, {FI.OuterTripCount, FI.InnerTripCount}, + "flatten.mul"); + FI.NewTripCount = Builder.CreateExtractValue(Call, 0, "flatten.tripcount"); + Value *Overflow = Builder.CreateExtractValue(Call, 1, "flatten.overflow"); + Br->setCondition(Overflow); + } else { + LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n"); } - LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n"); return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU); } @@ -958,13 +1009,15 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, // in simplified form, and also needs LCSSA. Running // this pass will simplify all loops that contain inner loops, // regardless of whether anything ends up being flattened. + LoopAccessInfoManager LAIM(AR.SE, AR.AA, AR.DT, AR.LI, &AR.TTI, nullptr); for (Loop *InnerLoop : LN.getLoops()) { auto *OuterLoop = InnerLoop->getParentLoop(); if (!OuterLoop) continue; FlattenInfo FI(OuterLoop, InnerLoop); - Changed |= FlattenLoopPair(FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U, - MSSAU ? &*MSSAU : nullptr); + Changed |= + FlattenLoopPair(FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U, + MSSAU ? &*MSSAU : nullptr, LAIM.getInfo(*OuterLoop)); } if (!Changed) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFuse.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFuse.cpp index e0b224d5ef73..8512b2accbe7 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFuse.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopFuse.cpp @@ -1684,7 +1684,7 @@ private: PHINode::Create(LCV->getType(), 2, LCPHI->getName() + ".afterFC0"); L1HeaderPHI->insertBefore(L1HeaderIP); L1HeaderPHI->addIncoming(LCV, FC0.Latch); - L1HeaderPHI->addIncoming(UndefValue::get(LCV->getType()), + L1HeaderPHI->addIncoming(PoisonValue::get(LCV->getType()), FC0.ExitingBlock); LCPHI->setIncomingValue(L1LatchBBIdx, L1HeaderPHI); @@ -2072,7 +2072,7 @@ PreservedAnalyses LoopFusePass::run(Function &F, FunctionAnalysisManager &AM) { auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); const TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F); - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); // Ensure loops are in simplifed form which is a pre-requisite for loop fusion // pass. Added only for new PM since the legacy PM has already added diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 3721564890dd..0ee1afa76a82 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -22,8 +22,6 @@ // // Future loop memory idioms to recognize: // memcmp, strlen, etc. -// Future floating point idioms to recognize in -ffast-math mode: -// fpowi // // This could recognize common matrix multiplies and dot product idioms and // replace them with calls to BLAS (if linked in??). 
@@ -233,12 +231,19 @@ private: bool recognizePopcount(); void transformLoopToPopcount(BasicBlock *PreCondBB, Instruction *CntInst, PHINode *CntPhi, Value *Var); + bool isProfitableToInsertFFS(Intrinsic::ID IntrinID, Value *InitX, + bool ZeroCheck, size_t CanonicalSize); + bool insertFFSIfProfitable(Intrinsic::ID IntrinID, Value *InitX, + Instruction *DefX, PHINode *CntPhi, + Instruction *CntInst); bool recognizeAndInsertFFS(); /// Find First Set: ctlz or cttz + bool recognizeShiftUntilLessThan(); void transformLoopToCountable(Intrinsic::ID IntrinID, BasicBlock *PreCondBB, Instruction *CntInst, PHINode *CntPhi, Value *Var, Instruction *DefX, const DebugLoc &DL, bool ZeroCheck, - bool IsCntPhiUsedOutsideLoop); + bool IsCntPhiUsedOutsideLoop, + bool InsertSub = false); bool recognizeShiftUntilBitTest(); bool recognizeShiftUntilZero(); @@ -253,7 +258,7 @@ PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM, if (DisableLIRP::All) return PreservedAnalyses::all(); - const auto *DL = &L.getHeader()->getModule()->getDataLayout(); + const auto *DL = &L.getHeader()->getDataLayout(); // For the new PM, we also can't use OptimizationRemarkEmitter as an analysis // pass. Function analyses need to be preserved across loop transformations @@ -1107,7 +1112,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( GV->setAlignment(Align(16)); Value *PatternPtr = GV; NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes}); - + // Set the TBAA info if present. if (AATags.TBAA) NewCall->setMetadata(LLVMContext::MD_tbaa, AATags.TBAA); @@ -1117,7 +1122,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( if (AATags.NoAlias) NewCall->setMetadata(LLVMContext::MD_noalias, AATags.NoAlias); - } + } NewCall->setDebugLoc(TheStore->getDebugLoc()); @@ -1484,7 +1489,8 @@ bool LoopIdiomRecognize::runOnNoncountableLoop() { << CurLoop->getHeader()->getName() << "\n"); return recognizePopcount() || recognizeAndInsertFFS() || - recognizeShiftUntilBitTest() || recognizeShiftUntilZero(); + recognizeShiftUntilBitTest() || recognizeShiftUntilZero() || + recognizeShiftUntilLessThan(); } /// Check if the given conditional branch is based on the comparison between @@ -1519,6 +1525,34 @@ static Value *matchCondition(BranchInst *BI, BasicBlock *LoopEntry, return nullptr; } +/// Check if the given conditional branch is based on an unsigned less-than +/// comparison between a variable and a constant, and if the comparison is false +/// the control yields to the loop entry. If the branch matches the behaviour, +/// the variable involved in the comparison is returned. +static Value *matchShiftULTCondition(BranchInst *BI, BasicBlock *LoopEntry, + APInt &Threshold) { + if (!BI || !BI->isConditional()) + return nullptr; + + ICmpInst *Cond = dyn_cast<ICmpInst>(BI->getCondition()); + if (!Cond) + return nullptr; + + ConstantInt *CmpConst = dyn_cast<ConstantInt>(Cond->getOperand(1)); + if (!CmpConst) + return nullptr; + + BasicBlock *FalseSucc = BI->getSuccessor(1); + ICmpInst::Predicate Pred = Cond->getPredicate(); + + if (Pred == ICmpInst::ICMP_ULT && FalseSucc == LoopEntry) { + Threshold = CmpConst->getValue(); + return Cond->getOperand(0); + } + + return nullptr; +} + // Check if the recurrence variable `VarX` is in the right form to create // the idiom. Returns the value coerced to a PHINode if so. 
static PHINode *getRecurrenceVar(Value *VarX, Instruction *DefX, @@ -1530,6 +1564,107 @@ static PHINode *getRecurrenceVar(Value *VarX, Instruction *DefX, return nullptr; } +/// Return true if the idiom is detected in the loop. +/// +/// Additionally: +/// 1) \p CntInst is set to the instruction Counting Leading Zeros (CTLZ) +/// or nullptr if there is no such. +/// 2) \p CntPhi is set to the corresponding phi node +/// or nullptr if there is no such. +/// 3) \p InitX is set to the value whose CTLZ could be used. +/// 4) \p DefX is set to the instruction calculating Loop exit condition. +/// 5) \p Threshold is set to the constant involved in the unsigned less-than +/// comparison. +/// +/// The core idiom we are trying to detect is: +/// \code +/// if (x0 < 2) +/// goto loop-exit // the precondition of the loop +/// cnt0 = init-val +/// do { +/// x = phi (x0, x.next); //PhiX +/// cnt = phi (cnt0, cnt.next) +/// +/// cnt.next = cnt + 1; +/// ... +/// x.next = x >> 1; // DefX +/// } while (x >= 4) +/// loop-exit: +/// \endcode +static bool detectShiftUntilLessThanIdiom(Loop *CurLoop, const DataLayout &DL, + Intrinsic::ID &IntrinID, + Value *&InitX, Instruction *&CntInst, + PHINode *&CntPhi, Instruction *&DefX, + APInt &Threshold) { + BasicBlock *LoopEntry; + + DefX = nullptr; + CntInst = nullptr; + CntPhi = nullptr; + LoopEntry = *(CurLoop->block_begin()); + + // step 1: Check if the loop-back branch is in desirable form. + if (Value *T = matchShiftULTCondition( + dyn_cast<BranchInst>(LoopEntry->getTerminator()), LoopEntry, + Threshold)) + DefX = dyn_cast<Instruction>(T); + else + return false; + + // step 2: Check the recurrence of variable X + if (!DefX || !isa<PHINode>(DefX)) + return false; + + PHINode *VarPhi = cast<PHINode>(DefX); + int Idx = VarPhi->getBasicBlockIndex(LoopEntry); + if (Idx == -1) + return false; + + DefX = dyn_cast<Instruction>(VarPhi->getIncomingValue(Idx)); + if (!DefX || DefX->getNumOperands() == 0 || DefX->getOperand(0) != VarPhi) + return false; + + // step 3: detect instructions corresponding to "x.next = x >> 1" + if (DefX->getOpcode() != Instruction::LShr) + return false; + + IntrinID = Intrinsic::ctlz; + ConstantInt *Shft = dyn_cast<ConstantInt>(DefX->getOperand(1)); + if (!Shft || !Shft->isOne()) + return false; + + InitX = VarPhi->getIncomingValueForBlock(CurLoop->getLoopPreheader()); + + // step 4: Find the instruction which count the CTLZ: cnt.next = cnt + 1 + // or cnt.next = cnt + -1. + // TODO: We can skip the step. If loop trip count is known (CTLZ), + // then all uses of "cnt.next" could be optimized to the trip count + // plus "cnt0". Currently it is not optimized. + // This step could be used to detect POPCNT instruction: + // cnt.next = cnt + (x.next & 1) + for (Instruction &Inst : llvm::make_range( + LoopEntry->getFirstNonPHI()->getIterator(), LoopEntry->end())) { + if (Inst.getOpcode() != Instruction::Add) + continue; + + ConstantInt *Inc = dyn_cast<ConstantInt>(Inst.getOperand(1)); + if (!Inc || (!Inc->isOne() && !Inc->isMinusOne())) + continue; + + PHINode *Phi = getRecurrenceVar(Inst.getOperand(0), &Inst, LoopEntry); + if (!Phi) + continue; + + CntInst = &Inst; + CntPhi = Phi; + break; + } + if (!CntInst) + return false; + + return true; +} + /// Return true iff the idiom is detected in the loop. 
/// /// Additionally: @@ -1758,27 +1893,35 @@ static bool detectShiftUntilZeroIdiom(Loop *CurLoop, const DataLayout &DL, return true; } -/// Recognize CTLZ or CTTZ idiom in a non-countable loop and convert the loop -/// to countable (with CTLZ / CTTZ trip count). If CTLZ / CTTZ inserted as a new -/// trip count returns true; otherwise, returns false. -bool LoopIdiomRecognize::recognizeAndInsertFFS() { - // Give up if the loop has multiple blocks or multiple backedges. - if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1) - return false; +// Check if CTLZ / CTTZ intrinsic is profitable. Assume it is always +// profitable if we delete the loop. +bool LoopIdiomRecognize::isProfitableToInsertFFS(Intrinsic::ID IntrinID, + Value *InitX, bool ZeroCheck, + size_t CanonicalSize) { + const Value *Args[] = {InitX, + ConstantInt::getBool(InitX->getContext(), ZeroCheck)}; - Intrinsic::ID IntrinID; - Value *InitX; - Instruction *DefX = nullptr; - PHINode *CntPhi = nullptr; - Instruction *CntInst = nullptr; - // Help decide if transformation is profitable. For ShiftUntilZero idiom, - // this is always 6. - size_t IdiomCanonicalSize = 6; + // @llvm.dbg doesn't count as they have no semantic effect. + auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug(); + uint32_t HeaderSize = + std::distance(InstWithoutDebugIt.begin(), InstWithoutDebugIt.end()); - if (!detectShiftUntilZeroIdiom(CurLoop, *DL, IntrinID, InitX, - CntInst, CntPhi, DefX)) + IntrinsicCostAttributes Attrs(IntrinID, InitX->getType(), Args); + InstructionCost Cost = TTI->getIntrinsicInstrCost( + Attrs, TargetTransformInfo::TCK_SizeAndLatency); + if (HeaderSize != CanonicalSize && Cost > TargetTransformInfo::TCC_Basic) return false; + return true; +} + +/// Convert CTLZ / CTTZ idiom loop into countable loop. +/// If CTLZ / CTTZ inserted as a new trip count returns true; otherwise, +/// returns false. +bool LoopIdiomRecognize::insertFFSIfProfitable(Intrinsic::ID IntrinID, + Value *InitX, Instruction *DefX, + PHINode *CntPhi, + Instruction *CntInst) { bool IsCntPhiUsedOutsideLoop = false; for (User *U : CntPhi->users()) if (!CurLoop->contains(cast<Instruction>(U))) { @@ -1820,35 +1963,107 @@ bool LoopIdiomRecognize::recognizeAndInsertFFS() { ZeroCheck = true; } - // Check if CTLZ / CTTZ intrinsic is profitable. Assume it is always - // profitable if we delete the loop. - - // the loop has only 6 instructions: + // FFS idiom loop has only 6 instructions: // %n.addr.0 = phi [ %n, %entry ], [ %shr, %while.cond ] // %i.0 = phi [ %i0, %entry ], [ %inc, %while.cond ] // %shr = ashr %n.addr.0, 1 // %tobool = icmp eq %shr, 0 // %inc = add nsw %i.0, 1 // br i1 %tobool + size_t IdiomCanonicalSize = 6; + if (!isProfitableToInsertFFS(IntrinID, InitX, ZeroCheck, IdiomCanonicalSize)) + return false; - const Value *Args[] = {InitX, - ConstantInt::getBool(InitX->getContext(), ZeroCheck)}; + transformLoopToCountable(IntrinID, PH, CntInst, CntPhi, InitX, DefX, + DefX->getDebugLoc(), ZeroCheck, + IsCntPhiUsedOutsideLoop); + return true; +} - // @llvm.dbg doesn't count as they have no semantic effect. - auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug(); - uint32_t HeaderSize = - std::distance(InstWithoutDebugIt.begin(), InstWithoutDebugIt.end()); +/// Recognize CTLZ or CTTZ idiom in a non-countable loop and convert the loop +/// to countable (with CTLZ / CTTZ trip count). If CTLZ / CTTZ inserted as a new +/// trip count returns true; otherwise, returns false. 
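As a source-level sketch of the floor-log2 idiom targeted by the new recognizeShiftUntilLessThan path (defined below): a shift-until-less-than loop of this shape computes floor(log2(x)), which can instead be computed as bitwidth - 1 - clz(x). The exact placement of the +1 depends on the PHI structure the pass detects; the names below, and the use of the GCC/Clang builtin __builtin_clz in place of the ctlz intrinsic, are illustrative only.

    #include <cassert>
    #include <cstdio>

    unsigned floorLog2Loop(unsigned x) {
      if (x < 2)               // loop precondition, as in the idiom
        return 0;
      unsigned cnt = 0;
      while (x >= 4) {         // "shift until less than" guard
        x >>= 1;
        ++cnt;
      }
      return cnt + 1;
    }

    unsigned floorLog2Clz(unsigned x) {
      if (x < 2)
        return 0;
      return 31u - (unsigned)__builtin_clz(x);   // w - 1 - clz(x)
    }

    int main() {
      for (unsigned x = 1; x < 100; ++x)
        assert(floorLog2Loop(x) == floorLog2Clz(x));
      std::printf("%u %u\n", floorLog2Loop(64), floorLog2Clz(64));
    }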
+bool LoopIdiomRecognize::recognizeAndInsertFFS() { + // Give up if the loop has multiple blocks or multiple backedges. + if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1) + return false; - IntrinsicCostAttributes Attrs(IntrinID, InitX->getType(), Args); - InstructionCost Cost = - TTI->getIntrinsicInstrCost(Attrs, TargetTransformInfo::TCK_SizeAndLatency); - if (HeaderSize != IdiomCanonicalSize && - Cost > TargetTransformInfo::TCC_Basic) + Intrinsic::ID IntrinID; + Value *InitX; + Instruction *DefX = nullptr; + PHINode *CntPhi = nullptr; + Instruction *CntInst = nullptr; + + if (!detectShiftUntilZeroIdiom(CurLoop, *DL, IntrinID, InitX, CntInst, CntPhi, + DefX)) + return false; + + return insertFFSIfProfitable(IntrinID, InitX, DefX, CntPhi, CntInst); +} + +bool LoopIdiomRecognize::recognizeShiftUntilLessThan() { + // Give up if the loop has multiple blocks or multiple backedges. + if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1) + return false; + + Intrinsic::ID IntrinID; + Value *InitX; + Instruction *DefX = nullptr; + PHINode *CntPhi = nullptr; + Instruction *CntInst = nullptr; + + APInt LoopThreshold; + if (!detectShiftUntilLessThanIdiom(CurLoop, *DL, IntrinID, InitX, CntInst, + CntPhi, DefX, LoopThreshold)) + return false; + + if (LoopThreshold == 2) { + // Treat as regular FFS. + return insertFFSIfProfitable(IntrinID, InitX, DefX, CntPhi, CntInst); + } + + // Look for Floor Log2 Idiom. + if (LoopThreshold != 4) + return false; + + // Abort if CntPhi is used outside of the loop. + for (User *U : CntPhi->users()) + if (!CurLoop->contains(cast<Instruction>(U))) + return false; + + // It is safe to assume Preheader exist as it was checked in + // parent function RunOnLoop. + BasicBlock *PH = CurLoop->getLoopPreheader(); + auto *PreCondBB = PH->getSinglePredecessor(); + if (!PreCondBB) + return false; + auto *PreCondBI = dyn_cast<BranchInst>(PreCondBB->getTerminator()); + if (!PreCondBI) + return false; + + APInt PreLoopThreshold; + if (matchShiftULTCondition(PreCondBI, PH, PreLoopThreshold) != InitX || + PreLoopThreshold != 2) return false; + bool ZeroCheck = true; + + // the loop has only 6 instructions: + // %n.addr.0 = phi [ %n, %entry ], [ %shr, %while.cond ] + // %i.0 = phi [ %i0, %entry ], [ %inc, %while.cond ] + // %shr = ashr %n.addr.0, 1 + // %tobool = icmp ult %n.addr.0, C + // %inc = add nsw %i.0, 1 + // br i1 %tobool + size_t IdiomCanonicalSize = 6; + if (!isProfitableToInsertFFS(IntrinID, InitX, ZeroCheck, IdiomCanonicalSize)) + return false; + + // log2(x) = w − 1 − clz(x) transformLoopToCountable(IntrinID, PH, CntInst, CntPhi, InitX, DefX, DefX->getDebugLoc(), ZeroCheck, - IsCntPhiUsedOutsideLoop); + /*IsCntPhiUsedOutsideLoop=*/false, + /*InsertSub=*/true); return true; } @@ -1963,7 +2178,7 @@ static CallInst *createFFSIntrinsic(IRBuilder<> &IRBuilder, Value *Val, void LoopIdiomRecognize::transformLoopToCountable( Intrinsic::ID IntrinID, BasicBlock *Preheader, Instruction *CntInst, PHINode *CntPhi, Value *InitX, Instruction *DefX, const DebugLoc &DL, - bool ZeroCheck, bool IsCntPhiUsedOutsideLoop) { + bool ZeroCheck, bool IsCntPhiUsedOutsideLoop, bool InsertSub) { BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator()); // Step 1: Insert the CTLZ/CTTZ instruction at the end of the preheader block @@ -1993,6 +2208,8 @@ void LoopIdiomRecognize::transformLoopToCountable( Type *CountTy = Count->getType(); Count = Builder.CreateSub( ConstantInt::get(CountTy, CountTy->getIntegerBitWidth()), Count); + if (InsertSub) + Count = 
Builder.CreateSub(Count, ConstantInt::get(CountTy, 1)); Value *NewCount = Count; if (IsCntPhiUsedOutsideLoop) Count = Builder.CreateAdd(Count, ConstantInt::get(CountTy, 1)); @@ -2409,15 +2626,15 @@ bool LoopIdiomRecognize::recognizeShiftUntilBitTest() { if (!isGuaranteedNotToBeUndefOrPoison(BitPos)) { // BitMask may be computed from BitPos, Freeze BitPos so we can increase // it's use count. - Instruction *InsertPt = nullptr; + std::optional<BasicBlock::iterator> InsertPt = std::nullopt; if (auto *BitPosI = dyn_cast<Instruction>(BitPos)) - InsertPt = &**BitPosI->getInsertionPointAfterDef(); + InsertPt = BitPosI->getInsertionPointAfterDef(); else - InsertPt = &*DT->getRoot()->getFirstNonPHIOrDbgOrAlloca(); + InsertPt = DT->getRoot()->getFirstNonPHIOrDbgOrAlloca(); if (!InsertPt) return false; FreezeInst *BitPosFrozen = - new FreezeInst(BitPos, BitPos->getName() + ".fr", InsertPt); + new FreezeInst(BitPos, BitPos->getName() + ".fr", *InsertPt); BitPos->replaceUsesWithIf(BitPosFrozen, [BitPosFrozen](Use &U) { return U.getUser() != BitPosFrozen; }); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp index cfe069d00bce..270c2120365c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -45,7 +45,7 @@ STATISTIC(NumSimplified, "Number of redundant instructions simplified"); static bool simplifyLoopInst(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, const TargetLibraryInfo &TLI, MemorySSAUpdater *MSSAU) { - const DataLayout &DL = L.getHeader()->getModule()->getDataLayout(); + const DataLayout &DL = L.getHeader()->getDataLayout(); SimplifyQuery SQ(DL, &TLI, &DT, &AC); // On the first pass over the loop body we try to simplify every instruction. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index 277f530ee25f..400973fd9fc9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -976,7 +976,7 @@ bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId, } if (!findInductions(InnerLoop, InnerLoopInductions)) { - LLVM_DEBUG(dbgs() << "Cound not find inner loop induction variables.\n"); + LLVM_DEBUG(dbgs() << "Could not find inner loop induction variables.\n"); return false; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp index 5ec387300aac..489f12e689d3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -96,7 +96,7 @@ struct StoreToLoadForwardingCandidate { Value *LoadPtr = Load->getPointerOperand(); Value *StorePtr = Store->getPointerOperand(); Type *LoadType = getLoadStoreType(Load); - auto &DL = Load->getParent()->getModule()->getDataLayout(); + auto &DL = Load->getDataLayout(); assert(LoadPtr->getType()->getPointerAddressSpace() == StorePtr->getType()->getPointerAddressSpace() && @@ -126,8 +126,10 @@ struct StoreToLoadForwardingCandidate { // We don't need to check non-wrapping here because forward/backward // dependence wouldn't be valid if these weren't monotonic accesses. 
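To make the pointer-distance condition in the hunk just below concrete, here is a hedged C++ sketch of a store-to-load forwarding candidate; the array names are invented for illustration.

// Illustrative candidate for loop load elimination: the store to a[i + 1] in
// iteration i provides the value loaded from the same address in iteration
// i + 1, so StorePtr - LoadPtr is exactly one element, i.e. TypeByteSize *
// StrideLoad in the check below. Assumes a, b and c do not alias (which the
// pass establishes through dependence analysis) and that a has n + 1 elements.
void storeToLoadCandidate(int *a, const int *b, int *c, int n) {
  for (int i = 0; i < n; ++i) {
    a[i + 1] = b[i] + 1;   // store, one element ahead of the load below
    c[i] = a[i] * 2;       // load; its value was stored one iteration earlier
  }
}

The change in the hunk that follows only hardens that check: when the distance is not a compile-time SCEV constant, the candidate is now skipped instead of tripping the old cast<SCEVConstant>.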
- auto *Dist = cast<SCEVConstant>( + auto *Dist = dyn_cast<SCEVConstant>( PSE.getSE()->getMinusSCEV(StorePtrSCEV, LoadPtrSCEV)); + if (!Dist) + return false; const APInt &Val = Dist->getAPInt(); return Val == TypeByteSize * StrideLoad; } @@ -181,7 +183,8 @@ public: findStoreToLoadDependences(const LoopAccessInfo &LAI) { std::forward_list<StoreToLoadForwardingCandidate> Candidates; - const auto *Deps = LAI.getDepChecker().getDependences(); + const auto &DepChecker = LAI.getDepChecker(); + const auto *Deps = DepChecker.getDependences(); if (!Deps) return Candidates; @@ -192,8 +195,8 @@ public: SmallPtrSet<Instruction *, 4> LoadsWithUnknownDepedence; for (const auto &Dep : *Deps) { - Instruction *Source = Dep.getSource(LAI); - Instruction *Destination = Dep.getDestination(LAI); + Instruction *Source = Dep.getSource(DepChecker); + Instruction *Destination = Dep.getDestination(DepChecker); if (Dep.Type == MemoryDepChecker::Dependence::Unknown || Dep.Type == MemoryDepChecker::Dependence::IndirectUnsafe) { @@ -222,7 +225,7 @@ public: // Only propagate if the stored values are bit/pointer castable. if (!CastInst::isBitOrNoopPointerCastable( getLoadStoreType(Store), getLoadStoreType(Load), - Store->getParent()->getModule()->getDataLayout())) + Store->getDataLayout())) continue; Candidates.emplace_front(Load, Store); @@ -349,19 +352,20 @@ public: // ld0. LoadInst *LastLoad = - std::max_element(Candidates.begin(), Candidates.end(), - [&](const StoreToLoadForwardingCandidate &A, - const StoreToLoadForwardingCandidate &B) { - return getInstrIndex(A.Load) < getInstrIndex(B.Load); - }) + llvm::max_element(Candidates, + [&](const StoreToLoadForwardingCandidate &A, + const StoreToLoadForwardingCandidate &B) { + return getInstrIndex(A.Load) < + getInstrIndex(B.Load); + }) ->Load; StoreInst *FirstStore = - std::min_element(Candidates.begin(), Candidates.end(), - [&](const StoreToLoadForwardingCandidate &A, - const StoreToLoadForwardingCandidate &B) { - return getInstrIndex(A.Store) < - getInstrIndex(B.Store); - }) + llvm::min_element(Candidates, + [&](const StoreToLoadForwardingCandidate &A, + const StoreToLoadForwardingCandidate &B) { + return getInstrIndex(A.Store) < + getInstrIndex(B.Store); + }) ->Store; // We're looking for stores after the first forwarding store until the end @@ -440,9 +444,14 @@ public: assert(PH && "Preheader should exist!"); Value *InitialPtr = SEE.expandCodeFor(PtrSCEV->getStart(), Ptr->getType(), PH->getTerminator()); - Value *Initial = new LoadInst( - Cand.Load->getType(), InitialPtr, "load_initial", - /* isVolatile */ false, Cand.Load->getAlign(), PH->getTerminator()); + Value *Initial = + new LoadInst(Cand.Load->getType(), InitialPtr, "load_initial", + /* isVolatile */ false, Cand.Load->getAlign(), + PH->getTerminator()->getIterator()); + // We don't give any debug location to Initial, because it is inserted + // into the loop's preheader. A debug location inside the loop will cause + // a misleading stepping when debugging. The test update-debugloc-store + // -forwarded.ll checks this. 
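Following on from the candidate sketched above, this is roughly the shape propagateStoredValueToLoadUsers produces, written out by hand in C++; the local scalar plays the role of the "store_forwarded" PHI and the up-front load the role of "load_initial". It is a sketch under the same no-aliasing assumption, not literal pass output.

// Hand-written equivalent of the rewritten loop: the in-loop load of a[i] is
// replaced by a value carried across iterations.
void storeToLoadRewritten(int *a, const int *b, int *c, int n) {
  int forwarded = (n > 0) ? a[0] : 0;  // "load_initial": emitted once before the
                                       // loop, deliberately without a debug
                                       // location (see the comment above)
  for (int i = 0; i < n; ++i) {
    int v = b[i] + 1;
    a[i + 1] = v;           // the store is kept; code after the loop may read it
    c[i] = forwarded * 2;   // uses the forwarded value; the reload of a[i] is gone
    forwarded = v;          // next iteration's value of the "store_forwarded" PHI
  }
}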
PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded"); PHI->insertBefore(L->getHeader()->begin()); @@ -450,20 +459,27 @@ public: Type *LoadType = Initial->getType(); Type *StoreType = Cand.Store->getValueOperand()->getType(); - auto &DL = Cand.Load->getParent()->getModule()->getDataLayout(); + auto &DL = Cand.Load->getDataLayout(); (void)DL; assert(DL.getTypeSizeInBits(LoadType) == DL.getTypeSizeInBits(StoreType) && "The type sizes should match!"); Value *StoreValue = Cand.Store->getValueOperand(); - if (LoadType != StoreType) - StoreValue = CastInst::CreateBitOrPointerCast( - StoreValue, LoadType, "store_forward_cast", Cand.Store); + if (LoadType != StoreType) { + StoreValue = CastInst::CreateBitOrPointerCast(StoreValue, LoadType, + "store_forward_cast", + Cand.Store->getIterator()); + // Because it casts the old `load` value and is used by the new `phi` + // which replaces the old `load`, we give the `load`'s debug location + // to it. + cast<Instruction>(StoreValue)->setDebugLoc(Cand.Load->getDebugLoc()); + } PHI->addIncoming(StoreValue, L->getLoopLatch()); Cand.Load->replaceAllUsesWith(PHI); + PHI->setDebugLoc(Cand.Load->getDebugLoc()); } /// Top-level driver for each loop: find store->load forwarding @@ -601,7 +617,7 @@ public: // Next, propagate the value stored by the store to the users of the load. // Also for the first iteration, generate the initial value of the load. - SCEVExpander SEE(*PSE.getSE(), L->getHeader()->getModule()->getDataLayout(), + SCEVExpander SEE(*PSE.getSE(), L->getHeader()->getDataLayout(), "storeforward"); for (const auto &Cand : Candidates) propagateStoredValueToLoadUsers(Cand, SEE); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp deleted file mode 100644 index 7f62526a4f6d..000000000000 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp +++ /dev/null @@ -1,1679 +0,0 @@ -//===- LoopReroll.cpp - Loop rerolling pass -------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This pass implements a simple loop reroller. 
-// -//===----------------------------------------------------------------------===// - -#include "llvm/ADT/APInt.h" -#include "llvm/ADT/BitVector.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/MapVector.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/AliasSetTracker.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instruction.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Use.h" -#include "llvm/IR/User.h" -#include "llvm/IR/Value.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar/LoopReroll.h" -#include "llvm/Transforms/Utils.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" -#include "llvm/Transforms/Utils/LoopUtils.h" -#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" -#include <cassert> -#include <cstddef> -#include <cstdint> -#include <iterator> -#include <map> -#include <utility> - -using namespace llvm; - -#define DEBUG_TYPE "loop-reroll" - -STATISTIC(NumRerolledLoops, "Number of rerolled loops"); - -static cl::opt<unsigned> -NumToleratedFailedMatches("reroll-num-tolerated-failed-matches", cl::init(400), - cl::Hidden, - cl::desc("The maximum number of failures to tolerate" - " during fuzzy matching. (default: 400)")); - -// This loop re-rolling transformation aims to transform loops like this: -// -// int foo(int a); -// void bar(int *x) { -// for (int i = 0; i < 500; i += 3) { -// foo(i); -// foo(i+1); -// foo(i+2); -// } -// } -// -// into a loop like this: -// -// void bar(int *x) { -// for (int i = 0; i < 500; ++i) -// foo(i); -// } -// -// It does this by looking for loops that, besides the latch code, are composed -// of isomorphic DAGs of instructions, with each DAG rooted at some increment -// to the induction variable, and where each DAG is isomorphic to the DAG -// rooted at the induction variable (excepting the sub-DAGs which root the -// other induction-variable increments). In other words, we're looking for loop -// bodies of the form: -// -// %iv = phi [ (preheader, ...), (body, %iv.next) ] -// f(%iv) -// %iv.1 = add %iv, 1 <-- a root increment -// f(%iv.1) -// %iv.2 = add %iv, 2 <-- a root increment -// f(%iv.2) -// %iv.scale_m_1 = add %iv, scale-1 <-- a root increment -// f(%iv.scale_m_1) -// ... -// %iv.next = add %iv, scale -// %cmp = icmp(%iv, ...) -// br %cmp, header, exit -// -// where each f(i) is a set of instructions that, collectively, are a function -// only of i (and other loop-invariant values). 
-// -// As a special case, we can also reroll loops like this: -// -// int foo(int); -// void bar(int *x) { -// for (int i = 0; i < 500; ++i) { -// x[3*i] = foo(0); -// x[3*i+1] = foo(0); -// x[3*i+2] = foo(0); -// } -// } -// -// into this: -// -// void bar(int *x) { -// for (int i = 0; i < 1500; ++i) -// x[i] = foo(0); -// } -// -// in which case, we're looking for inputs like this: -// -// %iv = phi [ (preheader, ...), (body, %iv.next) ] -// %scaled.iv = mul %iv, scale -// f(%scaled.iv) -// %scaled.iv.1 = add %scaled.iv, 1 -// f(%scaled.iv.1) -// %scaled.iv.2 = add %scaled.iv, 2 -// f(%scaled.iv.2) -// %scaled.iv.scale_m_1 = add %scaled.iv, scale-1 -// f(%scaled.iv.scale_m_1) -// ... -// %iv.next = add %iv, 1 -// %cmp = icmp(%iv, ...) -// br %cmp, header, exit - -namespace { - - enum IterationLimits { - /// The maximum number of iterations that we'll try and reroll. - IL_MaxRerollIterations = 32, - /// The bitvector index used by loop induction variables and other - /// instructions that belong to all iterations. - IL_All, - IL_End - }; - - class LoopReroll { - public: - LoopReroll(AliasAnalysis *AA, LoopInfo *LI, ScalarEvolution *SE, - TargetLibraryInfo *TLI, DominatorTree *DT, bool PreserveLCSSA) - : AA(AA), LI(LI), SE(SE), TLI(TLI), DT(DT), - PreserveLCSSA(PreserveLCSSA) {} - bool runOnLoop(Loop *L); - - protected: - AliasAnalysis *AA; - LoopInfo *LI; - ScalarEvolution *SE; - TargetLibraryInfo *TLI; - DominatorTree *DT; - bool PreserveLCSSA; - - using SmallInstructionVector = SmallVector<Instruction *, 16>; - using SmallInstructionSet = SmallPtrSet<Instruction *, 16>; - using TinyInstructionVector = SmallVector<Instruction *, 1>; - - // Map between induction variable and its increment - DenseMap<Instruction *, int64_t> IVToIncMap; - - // For loop with multiple induction variables, remember the ones used only to - // control the loop. - TinyInstructionVector LoopControlIVs; - - // A chain of isomorphic instructions, identified by a single-use PHI - // representing a reduction. Only the last value may be used outside the - // loop. - struct SimpleLoopReduction { - SimpleLoopReduction(Instruction *P, Loop *L) : Instructions(1, P) { - assert(isa<PHINode>(P) && "First reduction instruction must be a PHI"); - add(L); - } - - bool valid() const { - return Valid; - } - - Instruction *getPHI() const { - assert(Valid && "Using invalid reduction"); - return Instructions.front(); - } - - Instruction *getReducedValue() const { - assert(Valid && "Using invalid reduction"); - return Instructions.back(); - } - - Instruction *get(size_t i) const { - assert(Valid && "Using invalid reduction"); - return Instructions[i+1]; - } - - Instruction *operator [] (size_t i) const { return get(i); } - - // The size, ignoring the initial PHI. 
- size_t size() const { - assert(Valid && "Using invalid reduction"); - return Instructions.size()-1; - } - - using iterator = SmallInstructionVector::iterator; - using const_iterator = SmallInstructionVector::const_iterator; - - iterator begin() { - assert(Valid && "Using invalid reduction"); - return std::next(Instructions.begin()); - } - - const_iterator begin() const { - assert(Valid && "Using invalid reduction"); - return std::next(Instructions.begin()); - } - - iterator end() { return Instructions.end(); } - const_iterator end() const { return Instructions.end(); } - - protected: - bool Valid = false; - SmallInstructionVector Instructions; - - void add(Loop *L); - }; - - // The set of all reductions, and state tracking of possible reductions - // during loop instruction processing. - struct ReductionTracker { - using SmallReductionVector = SmallVector<SimpleLoopReduction, 16>; - - // Add a new possible reduction. - void addSLR(SimpleLoopReduction &SLR) { PossibleReds.push_back(SLR); } - - // Setup to track possible reductions corresponding to the provided - // rerolling scale. Only reductions with a number of non-PHI instructions - // that is divisible by the scale are considered. Three instructions sets - // are filled in: - // - A set of all possible instructions in eligible reductions. - // - A set of all PHIs in eligible reductions - // - A set of all reduced values (last instructions) in eligible - // reductions. - void restrictToScale(uint64_t Scale, - SmallInstructionSet &PossibleRedSet, - SmallInstructionSet &PossibleRedPHISet, - SmallInstructionSet &PossibleRedLastSet) { - PossibleRedIdx.clear(); - PossibleRedIter.clear(); - Reds.clear(); - - for (unsigned i = 0, e = PossibleReds.size(); i != e; ++i) - if (PossibleReds[i].size() % Scale == 0) { - PossibleRedLastSet.insert(PossibleReds[i].getReducedValue()); - PossibleRedPHISet.insert(PossibleReds[i].getPHI()); - - PossibleRedSet.insert(PossibleReds[i].getPHI()); - PossibleRedIdx[PossibleReds[i].getPHI()] = i; - for (Instruction *J : PossibleReds[i]) { - PossibleRedSet.insert(J); - PossibleRedIdx[J] = i; - } - } - } - - // The functions below are used while processing the loop instructions. - - // Are the two instructions both from reductions, and furthermore, from - // the same reduction? - bool isPairInSame(Instruction *J1, Instruction *J2) { - DenseMap<Instruction *, int>::iterator J1I = PossibleRedIdx.find(J1); - if (J1I != PossibleRedIdx.end()) { - DenseMap<Instruction *, int>::iterator J2I = PossibleRedIdx.find(J2); - if (J2I != PossibleRedIdx.end() && J1I->second == J2I->second) - return true; - } - - return false; - } - - // The two provided instructions, the first from the base iteration, and - // the second from iteration i, form a matched pair. If these are part of - // a reduction, record that fact. - void recordPair(Instruction *J1, Instruction *J2, unsigned i) { - if (PossibleRedIdx.count(J1)) { - assert(PossibleRedIdx.count(J2) && - "Recording reduction vs. non-reduction instruction?"); - - PossibleRedIter[J1] = 0; - PossibleRedIter[J2] = i; - - int Idx = PossibleRedIdx[J1]; - assert(Idx == PossibleRedIdx[J2] && - "Recording pair from different reductions?"); - Reds.insert(Idx); - } - } - - // The functions below can be called after we've finished processing all - // instructions in the loop, and we know which reductions were selected. - - bool validateSelected(); - void replaceSelected(); - - protected: - // The vector of all possible reductions (for any scale). 
- SmallReductionVector PossibleReds; - - DenseMap<Instruction *, int> PossibleRedIdx; - DenseMap<Instruction *, int> PossibleRedIter; - DenseSet<int> Reds; - }; - - // A DAGRootSet models an induction variable being used in a rerollable - // loop. For example, - // - // x[i*3+0] = y1 - // x[i*3+1] = y2 - // x[i*3+2] = y3 - // - // Base instruction -> i*3 - // +---+----+ - // / | \ - // ST[y1] +1 +2 <-- Roots - // | | - // ST[y2] ST[y3] - // - // There may be multiple DAGRoots, for example: - // - // x[i*2+0] = ... (1) - // x[i*2+1] = ... (1) - // x[i*2+4] = ... (2) - // x[i*2+5] = ... (2) - // x[(i+1234)*2+5678] = ... (3) - // x[(i+1234)*2+5679] = ... (3) - // - // The loop will be rerolled by adding a new loop induction variable, - // one for the Base instruction in each DAGRootSet. - // - struct DAGRootSet { - Instruction *BaseInst; - SmallInstructionVector Roots; - - // The instructions between IV and BaseInst (but not including BaseInst). - SmallInstructionSet SubsumedInsts; - }; - - // The set of all DAG roots, and state tracking of all roots - // for a particular induction variable. - struct DAGRootTracker { - DAGRootTracker(LoopReroll *Parent, Loop *L, Instruction *IV, - ScalarEvolution *SE, AliasAnalysis *AA, - TargetLibraryInfo *TLI, DominatorTree *DT, LoopInfo *LI, - bool PreserveLCSSA, - DenseMap<Instruction *, int64_t> &IncrMap, - TinyInstructionVector LoopCtrlIVs) - : Parent(Parent), L(L), SE(SE), AA(AA), TLI(TLI), DT(DT), LI(LI), - PreserveLCSSA(PreserveLCSSA), IV(IV), IVToIncMap(IncrMap), - LoopControlIVs(LoopCtrlIVs) {} - - /// Stage 1: Find all the DAG roots for the induction variable. - bool findRoots(); - - /// Stage 2: Validate if the found roots are valid. - bool validate(ReductionTracker &Reductions); - - /// Stage 3: Assuming validate() returned true, perform the - /// replacement. - /// @param BackedgeTakenCount The backedge-taken count of L. - void replace(const SCEV *BackedgeTakenCount); - - protected: - using UsesTy = MapVector<Instruction *, BitVector>; - - void findRootsRecursive(Instruction *IVU, - SmallInstructionSet SubsumedInsts); - bool findRootsBase(Instruction *IVU, SmallInstructionSet SubsumedInsts); - bool collectPossibleRoots(Instruction *Base, - std::map<int64_t,Instruction*> &Roots); - bool validateRootSet(DAGRootSet &DRS); - - bool collectUsedInstructions(SmallInstructionSet &PossibleRedSet); - void collectInLoopUserSet(const SmallInstructionVector &Roots, - const SmallInstructionSet &Exclude, - const SmallInstructionSet &Final, - DenseSet<Instruction *> &Users); - void collectInLoopUserSet(Instruction *Root, - const SmallInstructionSet &Exclude, - const SmallInstructionSet &Final, - DenseSet<Instruction *> &Users); - - UsesTy::iterator nextInstr(int Val, UsesTy &In, - const SmallInstructionSet &Exclude, - UsesTy::iterator *StartI=nullptr); - bool isBaseInst(Instruction *I); - bool isRootInst(Instruction *I); - bool instrDependsOn(Instruction *I, - UsesTy::iterator Start, - UsesTy::iterator End); - void replaceIV(DAGRootSet &DRS, const SCEV *Start, const SCEV *IncrExpr); - - LoopReroll *Parent; - - // Members of Parent, replicated here for brevity. - Loop *L; - ScalarEvolution *SE; - AliasAnalysis *AA; - TargetLibraryInfo *TLI; - DominatorTree *DT; - LoopInfo *LI; - bool PreserveLCSSA; - - // The loop induction variable. - Instruction *IV; - - // Loop step amount. - int64_t Inc; - - // Loop reroll count; if Inc == 1, this records the scaling applied - // to the indvar: a[i*2+0] = ...; a[i*2+1] = ... ; - // If Inc is not 1, Scale = Inc. 
- uint64_t Scale; - - // The roots themselves. - SmallVector<DAGRootSet,16> RootSets; - - // All increment instructions for IV. - SmallInstructionVector LoopIncs; - - // Map of all instructions in the loop (in order) to the iterations - // they are used in (or specially, IL_All for instructions - // used in the loop increment mechanism). - UsesTy Uses; - - // Map between induction variable and its increment - DenseMap<Instruction *, int64_t> &IVToIncMap; - - TinyInstructionVector LoopControlIVs; - }; - - // Check if it is a compare-like instruction whose user is a branch - bool isCompareUsedByBranch(Instruction *I) { - auto *TI = I->getParent()->getTerminator(); - if (!isa<BranchInst>(TI) || !isa<CmpInst>(I)) - return false; - return I->hasOneUse() && TI->getOperand(0) == I; - }; - - bool isLoopControlIV(Loop *L, Instruction *IV); - void collectPossibleIVs(Loop *L, SmallInstructionVector &PossibleIVs); - void collectPossibleReductions(Loop *L, - ReductionTracker &Reductions); - bool reroll(Instruction *IV, Loop *L, BasicBlock *Header, - const SCEV *BackedgeTakenCount, ReductionTracker &Reductions); - }; - -} // end anonymous namespace - -// Returns true if the provided instruction is used outside the given loop. -// This operates like Instruction::isUsedOutsideOfBlock, but considers PHIs in -// non-loop blocks to be outside the loop. -static bool hasUsesOutsideLoop(Instruction *I, Loop *L) { - for (User *U : I->users()) { - if (!L->contains(cast<Instruction>(U))) - return true; - } - return false; -} - -// Check if an IV is only used to control the loop. There are two cases: -// 1. It only has one use which is loop increment, and the increment is only -// used by comparison and the PHI (could has sext with nsw in between), and the -// comparison is only used by branch. -// 2. It is used by loop increment and the comparison, the loop increment is -// only used by the PHI, and the comparison is used only by the branch. -bool LoopReroll::isLoopControlIV(Loop *L, Instruction *IV) { - unsigned IVUses = IV->getNumUses(); - if (IVUses != 2 && IVUses != 1) - return false; - - for (auto *User : IV->users()) { - int32_t IncOrCmpUses = User->getNumUses(); - bool IsCompInst = isCompareUsedByBranch(cast<Instruction>(User)); - - // User can only have one or two uses. - if (IncOrCmpUses != 2 && IncOrCmpUses != 1) - return false; - - // Case 1 - if (IVUses == 1) { - // The only user must be the loop increment. - // The loop increment must have two uses. 
- if (IsCompInst || IncOrCmpUses != 2) - return false; - } - - // Case 2 - if (IVUses == 2 && IncOrCmpUses != 1) - return false; - - // The users of the IV must be a binary operation or a comparison - if (auto *BO = dyn_cast<BinaryOperator>(User)) { - if (BO->getOpcode() == Instruction::Add) { - // Loop Increment - // User of Loop Increment should be either PHI or CMP - for (auto *UU : User->users()) { - if (PHINode *PN = dyn_cast<PHINode>(UU)) { - if (PN != IV) - return false; - } - // Must be a CMP or an ext (of a value with nsw) then CMP - else { - auto *UUser = cast<Instruction>(UU); - // Skip SExt if we are extending an nsw value - // TODO: Allow ZExt too - if (BO->hasNoSignedWrap() && UUser->hasOneUse() && - isa<SExtInst>(UUser)) - UUser = cast<Instruction>(*(UUser->user_begin())); - if (!isCompareUsedByBranch(UUser)) - return false; - } - } - } else - return false; - // Compare : can only have one use, and must be branch - } else if (!IsCompInst) - return false; - } - return true; -} - -// Collect the list of loop induction variables with respect to which it might -// be possible to reroll the loop. -void LoopReroll::collectPossibleIVs(Loop *L, - SmallInstructionVector &PossibleIVs) { - for (Instruction &IV : L->getHeader()->phis()) { - if (!IV.getType()->isIntegerTy() && !IV.getType()->isPointerTy()) - continue; - - if (const SCEVAddRecExpr *PHISCEV = - dyn_cast<SCEVAddRecExpr>(SE->getSCEV(&IV))) { - if (PHISCEV->getLoop() != L) - continue; - if (!PHISCEV->isAffine()) - continue; - const auto *IncSCEV = dyn_cast<SCEVConstant>(PHISCEV->getStepRecurrence(*SE)); - if (IncSCEV) { - IVToIncMap[&IV] = IncSCEV->getValue()->getSExtValue(); - LLVM_DEBUG(dbgs() << "LRR: Possible IV: " << IV << " = " << *PHISCEV - << "\n"); - - if (isLoopControlIV(L, &IV)) { - LoopControlIVs.push_back(&IV); - LLVM_DEBUG(dbgs() << "LRR: Loop control only IV: " << IV - << " = " << *PHISCEV << "\n"); - } else - PossibleIVs.push_back(&IV); - } - } - } -} - -// Add the remainder of the reduction-variable chain to the instruction vector -// (the initial PHINode has already been added). If successful, the object is -// marked as valid. -void LoopReroll::SimpleLoopReduction::add(Loop *L) { - assert(!Valid && "Cannot add to an already-valid chain"); - - // The reduction variable must be a chain of single-use instructions - // (including the PHI), except for the last value (which is used by the PHI - // and also outside the loop). - Instruction *C = Instructions.front(); - if (C->user_empty()) - return; - - do { - C = cast<Instruction>(*C->user_begin()); - if (C->hasOneUse()) { - if (!C->isBinaryOp()) - return; - - if (!(isa<PHINode>(Instructions.back()) || - C->isSameOperationAs(Instructions.back()))) - return; - - Instructions.push_back(C); - } - } while (C->hasOneUse()); - - if (Instructions.size() < 2 || - !C->isSameOperationAs(Instructions.back()) || - C->use_empty()) - return; - - // C is now the (potential) last instruction in the reduction chain. - for (User *U : C->users()) { - // The only in-loop user can be the initial PHI. - if (L->contains(cast<Instruction>(U))) - if (cast<Instruction>(U) != Instructions.front()) - return; - } - - Instructions.push_back(C); - Valid = true; -} - -// Collect the vector of possible reduction variables. 
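As an aside on the reduction bookkeeping in the file being removed here, the following C++ sketch (identifiers invented) shows the kind of chain SimpleLoopReduction modelled and restrictToScale filtered by chain length.

// A reduction whose chain length (three adds per iteration) is divisible by
// the reroll scale of 3, so restrictToScale() would keep it as a candidate.
// The reduction PHI is s; only the final value is used outside the loop.
int sumByThrees(const int *a, int n) {   // assumes n is a multiple of 3
  int s = 0;
  for (int i = 0; i < n; i += 3) {
    s += a[i];        // chain entry matched against the base iteration
    s += a[i + 1];    // ... against root increment +1
    s += a[i + 2];    // ... against root increment +2
  }
  return s;           // the reduced value, the chain's only use outside the loop
}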
-void LoopReroll::collectPossibleReductions(Loop *L, - ReductionTracker &Reductions) { - BasicBlock *Header = L->getHeader(); - for (BasicBlock::iterator I = Header->begin(), - IE = Header->getFirstInsertionPt(); I != IE; ++I) { - if (!isa<PHINode>(I)) - continue; - if (!I->getType()->isSingleValueType()) - continue; - - SimpleLoopReduction SLR(&*I, L); - if (!SLR.valid()) - continue; - - LLVM_DEBUG(dbgs() << "LRR: Possible reduction: " << *I << " (with " - << SLR.size() << " chained instructions)\n"); - Reductions.addSLR(SLR); - } -} - -// Collect the set of all users of the provided root instruction. This set of -// users contains not only the direct users of the root instruction, but also -// all users of those users, and so on. There are two exceptions: -// -// 1. Instructions in the set of excluded instructions are never added to the -// use set (even if they are users). This is used, for example, to exclude -// including root increments in the use set of the primary IV. -// -// 2. Instructions in the set of final instructions are added to the use set -// if they are users, but their users are not added. This is used, for -// example, to prevent a reduction update from forcing all later reduction -// updates into the use set. -void LoopReroll::DAGRootTracker::collectInLoopUserSet( - Instruction *Root, const SmallInstructionSet &Exclude, - const SmallInstructionSet &Final, - DenseSet<Instruction *> &Users) { - SmallInstructionVector Queue(1, Root); - while (!Queue.empty()) { - Instruction *I = Queue.pop_back_val(); - if (!Users.insert(I).second) - continue; - - if (!Final.count(I)) - for (Use &U : I->uses()) { - Instruction *User = cast<Instruction>(U.getUser()); - if (PHINode *PN = dyn_cast<PHINode>(User)) { - // Ignore "wrap-around" uses to PHIs of this loop's header. - if (PN->getIncomingBlock(U) == L->getHeader()) - continue; - } - - if (L->contains(User) && !Exclude.count(User)) { - Queue.push_back(User); - } - } - - // We also want to collect single-user "feeder" values. - for (Use &U : I->operands()) { - if (Instruction *Op = dyn_cast<Instruction>(U)) - if (Op->hasOneUse() && L->contains(Op) && !Exclude.count(Op) && - !Final.count(Op)) - Queue.push_back(Op); - } - } -} - -// Collect all of the users of all of the provided root instructions (combined -// into a single set). -void LoopReroll::DAGRootTracker::collectInLoopUserSet( - const SmallInstructionVector &Roots, - const SmallInstructionSet &Exclude, - const SmallInstructionSet &Final, - DenseSet<Instruction *> &Users) { - for (Instruction *Root : Roots) - collectInLoopUserSet(Root, Exclude, Final, Users); -} - -static bool isUnorderedLoadStore(Instruction *I) { - if (LoadInst *LI = dyn_cast<LoadInst>(I)) - return LI->isUnordered(); - if (StoreInst *SI = dyn_cast<StoreInst>(I)) - return SI->isUnordered(); - if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) - return !MI->isVolatile(); - return false; -} - -/// Return true if IVU is a "simple" arithmetic operation. -/// This is used for narrowing the search space for DAGRoots; only arithmetic -/// and GEPs can be part of a DAGRoot. 
-static bool isSimpleArithmeticOp(User *IVU) { - if (Instruction *I = dyn_cast<Instruction>(IVU)) { - switch (I->getOpcode()) { - default: return false; - case Instruction::Add: - case Instruction::Sub: - case Instruction::Mul: - case Instruction::Shl: - case Instruction::AShr: - case Instruction::LShr: - case Instruction::GetElementPtr: - case Instruction::Trunc: - case Instruction::ZExt: - case Instruction::SExt: - return true; - } - } - return false; -} - -static bool isLoopIncrement(User *U, Instruction *IV) { - BinaryOperator *BO = dyn_cast<BinaryOperator>(U); - - if ((BO && BO->getOpcode() != Instruction::Add) || - (!BO && !isa<GetElementPtrInst>(U))) - return false; - - for (auto *UU : U->users()) { - PHINode *PN = dyn_cast<PHINode>(UU); - if (PN && PN == IV) - return true; - } - return false; -} - -bool LoopReroll::DAGRootTracker:: -collectPossibleRoots(Instruction *Base, std::map<int64_t,Instruction*> &Roots) { - SmallInstructionVector BaseUsers; - - for (auto *I : Base->users()) { - ConstantInt *CI = nullptr; - - if (isLoopIncrement(I, IV)) { - LoopIncs.push_back(cast<Instruction>(I)); - continue; - } - - // The root nodes must be either GEPs, ORs or ADDs. - if (auto *BO = dyn_cast<BinaryOperator>(I)) { - if (BO->getOpcode() == Instruction::Add || - BO->getOpcode() == Instruction::Or) - CI = dyn_cast<ConstantInt>(BO->getOperand(1)); - } else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { - Value *LastOperand = GEP->getOperand(GEP->getNumOperands()-1); - CI = dyn_cast<ConstantInt>(LastOperand); - } - - if (!CI) { - if (Instruction *II = dyn_cast<Instruction>(I)) { - BaseUsers.push_back(II); - continue; - } else { - LLVM_DEBUG(dbgs() << "LRR: Aborting due to non-instruction: " << *I - << "\n"); - return false; - } - } - - int64_t V = std::abs(CI->getValue().getSExtValue()); - if (Roots.find(V) != Roots.end()) - // No duplicates, please. - return false; - - Roots[V] = cast<Instruction>(I); - } - - // Make sure we have at least two roots. - if (Roots.empty() || (Roots.size() == 1 && BaseUsers.empty())) - return false; - - // If we found non-loop-inc, non-root users of Base, assume they are - // for the zeroth root index. This is because "add %a, 0" gets optimized - // away. - if (BaseUsers.size()) { - if (Roots.find(0) != Roots.end()) { - LLVM_DEBUG(dbgs() << "LRR: Multiple roots found for base - aborting!\n"); - return false; - } - Roots[0] = Base; - } - - // Calculate the number of users of the base, or lowest indexed, iteration. - unsigned NumBaseUses = BaseUsers.size(); - if (NumBaseUses == 0) - NumBaseUses = Roots.begin()->second->getNumUses(); - - // Check that every node has the same number of users. - for (auto &KV : Roots) { - if (KV.first == 0) - continue; - if (!KV.second->hasNUses(NumBaseUses)) { - LLVM_DEBUG(dbgs() << "LRR: Aborting - Root and Base #users not the same: " - << "#Base=" << NumBaseUses - << ", #Root=" << KV.second->getNumUses() << "\n"); - return false; - } - } - - return true; -} - -void LoopReroll::DAGRootTracker:: -findRootsRecursive(Instruction *I, SmallInstructionSet SubsumedInsts) { - // Does the user look like it could be part of a root set? - // All its users must be simple arithmetic ops. 
- if (I->hasNUsesOrMore(IL_MaxRerollIterations + 1)) - return; - - if (I != IV && findRootsBase(I, SubsumedInsts)) - return; - - SubsumedInsts.insert(I); - - for (User *V : I->users()) { - Instruction *I = cast<Instruction>(V); - if (is_contained(LoopIncs, I)) - continue; - - if (!isSimpleArithmeticOp(I)) - continue; - - // The recursive call makes a copy of SubsumedInsts. - findRootsRecursive(I, SubsumedInsts); - } -} - -bool LoopReroll::DAGRootTracker::validateRootSet(DAGRootSet &DRS) { - if (DRS.Roots.empty()) - return false; - - // If the value of the base instruction is used outside the loop, we cannot - // reroll the loop. Check for other root instructions is unnecessary because - // they don't match any base instructions if their values are used outside. - if (hasUsesOutsideLoop(DRS.BaseInst, L)) - return false; - - // Consider a DAGRootSet with N-1 roots (so N different values including - // BaseInst). - // Define d = Roots[0] - BaseInst, which should be the same as - // Roots[I] - Roots[I-1] for all I in [1..N). - // Define D = BaseInst@J - BaseInst@J-1, where "@J" means the value at the - // loop iteration J. - // - // Now, For the loop iterations to be consecutive: - // D = d * N - const auto *ADR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(DRS.BaseInst)); - if (!ADR) - return false; - - // Check that the first root is evenly spaced. - unsigned N = DRS.Roots.size() + 1; - const SCEV *StepSCEV = SE->getMinusSCEV(SE->getSCEV(DRS.Roots[0]), ADR); - if (isa<SCEVCouldNotCompute>(StepSCEV) || StepSCEV->getType()->isPointerTy()) - return false; - const SCEV *ScaleSCEV = SE->getConstant(StepSCEV->getType(), N); - if (ADR->getStepRecurrence(*SE) != SE->getMulExpr(StepSCEV, ScaleSCEV)) - return false; - - // Check that the remainling roots are evenly spaced. - for (unsigned i = 1; i < N - 1; ++i) { - const SCEV *NewStepSCEV = SE->getMinusSCEV(SE->getSCEV(DRS.Roots[i]), - SE->getSCEV(DRS.Roots[i-1])); - if (NewStepSCEV != StepSCEV) - return false; - } - - return true; -} - -bool LoopReroll::DAGRootTracker:: -findRootsBase(Instruction *IVU, SmallInstructionSet SubsumedInsts) { - // The base of a RootSet must be an AddRec, so it can be erased. - const auto *IVU_ADR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(IVU)); - if (!IVU_ADR || IVU_ADR->getLoop() != L) - return false; - - std::map<int64_t, Instruction*> V; - if (!collectPossibleRoots(IVU, V)) - return false; - - // If we didn't get a root for index zero, then IVU must be - // subsumed. - if (V.find(0) == V.end()) - SubsumedInsts.insert(IVU); - - // Partition the vector into monotonically increasing indexes. - DAGRootSet DRS; - DRS.BaseInst = nullptr; - - SmallVector<DAGRootSet, 16> PotentialRootSets; - - for (auto &KV : V) { - if (!DRS.BaseInst) { - DRS.BaseInst = KV.second; - DRS.SubsumedInsts = SubsumedInsts; - } else if (DRS.Roots.empty()) { - DRS.Roots.push_back(KV.second); - } else if (V.find(KV.first - 1) != V.end()) { - DRS.Roots.push_back(KV.second); - } else { - // Linear sequence terminated. - if (!validateRootSet(DRS)) - return false; - - // Construct a new DAGRootSet with the next sequence. 
- PotentialRootSets.push_back(DRS); - DRS.BaseInst = KV.second; - DRS.Roots.clear(); - } - } - - if (!validateRootSet(DRS)) - return false; - - PotentialRootSets.push_back(DRS); - - RootSets.append(PotentialRootSets.begin(), PotentialRootSets.end()); - - return true; -} - -bool LoopReroll::DAGRootTracker::findRoots() { - Inc = IVToIncMap[IV]; - - assert(RootSets.empty() && "Unclean state!"); - if (std::abs(Inc) == 1) { - for (auto *IVU : IV->users()) { - if (isLoopIncrement(IVU, IV)) - LoopIncs.push_back(cast<Instruction>(IVU)); - } - findRootsRecursive(IV, SmallInstructionSet()); - LoopIncs.push_back(IV); - } else { - if (!findRootsBase(IV, SmallInstructionSet())) - return false; - } - - // Ensure all sets have the same size. - if (RootSets.empty()) { - LLVM_DEBUG(dbgs() << "LRR: Aborting because no root sets found!\n"); - return false; - } - for (auto &V : RootSets) { - if (V.Roots.empty() || V.Roots.size() != RootSets[0].Roots.size()) { - LLVM_DEBUG( - dbgs() - << "LRR: Aborting because not all root sets have the same size\n"); - return false; - } - } - - Scale = RootSets[0].Roots.size() + 1; - - if (Scale > IL_MaxRerollIterations) { - LLVM_DEBUG(dbgs() << "LRR: Aborting - too many iterations found. " - << "#Found=" << Scale - << ", #Max=" << IL_MaxRerollIterations << "\n"); - return false; - } - - LLVM_DEBUG(dbgs() << "LRR: Successfully found roots: Scale=" << Scale - << "\n"); - - return true; -} - -bool LoopReroll::DAGRootTracker::collectUsedInstructions(SmallInstructionSet &PossibleRedSet) { - // Populate the MapVector with all instructions in the block, in order first, - // so we can iterate over the contents later in perfect order. - for (auto &I : *L->getHeader()) { - Uses[&I].resize(IL_End); - } - - SmallInstructionSet Exclude; - for (auto &DRS : RootSets) { - Exclude.insert(DRS.Roots.begin(), DRS.Roots.end()); - Exclude.insert(DRS.SubsumedInsts.begin(), DRS.SubsumedInsts.end()); - Exclude.insert(DRS.BaseInst); - } - Exclude.insert(LoopIncs.begin(), LoopIncs.end()); - - for (auto &DRS : RootSets) { - DenseSet<Instruction*> VBase; - collectInLoopUserSet(DRS.BaseInst, Exclude, PossibleRedSet, VBase); - for (auto *I : VBase) { - Uses[I].set(0); - } - - unsigned Idx = 1; - for (auto *Root : DRS.Roots) { - DenseSet<Instruction*> V; - collectInLoopUserSet(Root, Exclude, PossibleRedSet, V); - - // While we're here, check the use sets are the same size. - if (V.size() != VBase.size()) { - LLVM_DEBUG(dbgs() << "LRR: Aborting - use sets are different sizes\n"); - return false; - } - - for (auto *I : V) { - Uses[I].set(Idx); - } - ++Idx; - } - - // Make sure our subsumed instructions are remembered too. - for (auto *I : DRS.SubsumedInsts) { - Uses[I].set(IL_All); - } - } - - // Make sure the loop increments are also accounted for. - - Exclude.clear(); - for (auto &DRS : RootSets) { - Exclude.insert(DRS.Roots.begin(), DRS.Roots.end()); - Exclude.insert(DRS.SubsumedInsts.begin(), DRS.SubsumedInsts.end()); - Exclude.insert(DRS.BaseInst); - } - - DenseSet<Instruction*> V; - collectInLoopUserSet(LoopIncs, Exclude, PossibleRedSet, V); - for (auto *I : V) { - if (I->mayHaveSideEffects()) { - LLVM_DEBUG(dbgs() << "LRR: Aborting - " - << "An instruction which does not belong to any root " - << "sets must not have side effects: " << *I); - return false; - } - Uses[I].set(IL_All); - } - - return true; -} - -/// Get the next instruction in "In" that is a member of set Val. -/// Start searching from StartI, and do not return anything in Exclude. -/// If StartI is not given, start from In.begin(). 
-LoopReroll::DAGRootTracker::UsesTy::iterator -LoopReroll::DAGRootTracker::nextInstr(int Val, UsesTy &In, - const SmallInstructionSet &Exclude, - UsesTy::iterator *StartI) { - UsesTy::iterator I = StartI ? *StartI : In.begin(); - while (I != In.end() && (I->second.test(Val) == 0 || - Exclude.contains(I->first))) - ++I; - return I; -} - -bool LoopReroll::DAGRootTracker::isBaseInst(Instruction *I) { - for (auto &DRS : RootSets) { - if (DRS.BaseInst == I) - return true; - } - return false; -} - -bool LoopReroll::DAGRootTracker::isRootInst(Instruction *I) { - for (auto &DRS : RootSets) { - if (is_contained(DRS.Roots, I)) - return true; - } - return false; -} - -/// Return true if instruction I depends on any instruction between -/// Start and End. -bool LoopReroll::DAGRootTracker::instrDependsOn(Instruction *I, - UsesTy::iterator Start, - UsesTy::iterator End) { - for (auto *U : I->users()) { - for (auto It = Start; It != End; ++It) - if (U == It->first) - return true; - } - return false; -} - -static bool isIgnorableInst(const Instruction *I) { - if (isa<DbgInfoIntrinsic>(I)) - return true; - const IntrinsicInst* II = dyn_cast<IntrinsicInst>(I); - if (!II) - return false; - switch (II->getIntrinsicID()) { - default: - return false; - case Intrinsic::annotation: - case Intrinsic::ptr_annotation: - case Intrinsic::var_annotation: - // TODO: the following intrinsics may also be allowed: - // lifetime_start, lifetime_end, invariant_start, invariant_end - return true; - } - return false; -} - -bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { - // We now need to check for equivalence of the use graph of each root with - // that of the primary induction variable (excluding the roots). Our goal - // here is not to solve the full graph isomorphism problem, but rather to - // catch common cases without a lot of work. As a result, we will assume - // that the relative order of the instructions in each unrolled iteration - // is the same (although we will not make an assumption about how the - // different iterations are intermixed). Note that while the order must be - // the same, the instructions may not be in the same basic block. - - // An array of just the possible reductions for this scale factor. When we - // collect the set of all users of some root instructions, these reduction - // instructions are treated as 'final' (their uses are not considered). - // This is important because we don't want the root use set to search down - // the reduction chain. - SmallInstructionSet PossibleRedSet; - SmallInstructionSet PossibleRedLastSet; - SmallInstructionSet PossibleRedPHISet; - Reductions.restrictToScale(Scale, PossibleRedSet, - PossibleRedPHISet, PossibleRedLastSet); - - // Populate "Uses" with where each instruction is used. - if (!collectUsedInstructions(PossibleRedSet)) - return false; - - // Make sure we mark the reduction PHIs as used in all iterations. - for (auto *I : PossibleRedPHISet) { - Uses[I].set(IL_All); - } - - // Make sure we mark loop-control-only PHIs as used in all iterations. See - // comment above LoopReroll::isLoopControlIV for more information. 
- BasicBlock *Header = L->getHeader(); - for (Instruction *LoopControlIV : LoopControlIVs) { - for (auto *U : LoopControlIV->users()) { - Instruction *IVUser = dyn_cast<Instruction>(U); - // IVUser could be loop increment or compare - Uses[IVUser].set(IL_All); - for (auto *UU : IVUser->users()) { - Instruction *UUser = dyn_cast<Instruction>(UU); - // UUser could be compare, PHI or branch - Uses[UUser].set(IL_All); - // Skip SExt - if (isa<SExtInst>(UUser)) { - UUser = dyn_cast<Instruction>(*(UUser->user_begin())); - Uses[UUser].set(IL_All); - } - // Is UUser a compare instruction? - if (UU->hasOneUse()) { - Instruction *BI = dyn_cast<BranchInst>(*UUser->user_begin()); - if (BI == cast<BranchInst>(Header->getTerminator())) - Uses[BI].set(IL_All); - } - } - } - } - - // Make sure all instructions in the loop are in one and only one - // set. - for (auto &KV : Uses) { - if (KV.second.count() != 1 && !isIgnorableInst(KV.first)) { - LLVM_DEBUG( - dbgs() << "LRR: Aborting - instruction is not used in 1 iteration: " - << *KV.first << " (#uses=" << KV.second.count() << ")\n"); - return false; - } - } - - LLVM_DEBUG(for (auto &KV - : Uses) { - dbgs() << "LRR: " << KV.second.find_first() << "\t" << *KV.first << "\n"; - }); - - BatchAAResults BatchAA(*AA); - for (unsigned Iter = 1; Iter < Scale; ++Iter) { - // In addition to regular aliasing information, we need to look for - // instructions from later (future) iterations that have side effects - // preventing us from reordering them past other instructions with side - // effects. - bool FutureSideEffects = false; - AliasSetTracker AST(BatchAA); - // The map between instructions in f(%iv.(i+1)) and f(%iv). - DenseMap<Value *, Value *> BaseMap; - - // Compare iteration Iter to the base. - SmallInstructionSet Visited; - auto BaseIt = nextInstr(0, Uses, Visited); - auto RootIt = nextInstr(Iter, Uses, Visited); - auto LastRootIt = Uses.begin(); - - while (BaseIt != Uses.end() && RootIt != Uses.end()) { - Instruction *BaseInst = BaseIt->first; - Instruction *RootInst = RootIt->first; - - // Skip over the IV or root instructions; only match their users. - bool Continue = false; - if (isBaseInst(BaseInst)) { - Visited.insert(BaseInst); - BaseIt = nextInstr(0, Uses, Visited); - Continue = true; - } - if (isRootInst(RootInst)) { - LastRootIt = RootIt; - Visited.insert(RootInst); - RootIt = nextInstr(Iter, Uses, Visited); - Continue = true; - } - if (Continue) continue; - - if (!BaseInst->isSameOperationAs(RootInst)) { - // Last chance saloon. We don't try and solve the full isomorphism - // problem, but try and at least catch the case where two instructions - // *of different types* are round the wrong way. We won't be able to - // efficiently tell, given two ADD instructions, which way around we - // should match them, but given an ADD and a SUB, we can at least infer - // which one is which. - // - // This should allow us to deal with a greater subset of the isomorphism - // problem. It does however change a linear algorithm into a quadratic - // one, so limit the number of probes we do. - auto TryIt = RootIt; - unsigned N = NumToleratedFailedMatches; - while (TryIt != Uses.end() && - !BaseInst->isSameOperationAs(TryIt->first) && - N--) { - ++TryIt; - TryIt = nextInstr(Iter, Uses, Visited, &TryIt); - } - - if (TryIt == Uses.end() || TryIt == RootIt || - instrDependsOn(TryIt->first, RootIt, TryIt)) { - LLVM_DEBUG(dbgs() << "LRR: iteration root match failed at " - << *BaseInst << " vs. 
" << *RootInst << "\n"); - return false; - } - - RootIt = TryIt; - RootInst = TryIt->first; - } - - // All instructions between the last root and this root - // may belong to some other iteration. If they belong to a - // future iteration, then they're dangerous to alias with. - // - // Note that because we allow a limited amount of flexibility in the order - // that we visit nodes, LastRootIt might be *before* RootIt, in which - // case we've already checked this set of instructions so we shouldn't - // do anything. - for (; LastRootIt < RootIt; ++LastRootIt) { - Instruction *I = LastRootIt->first; - if (LastRootIt->second.find_first() < (int)Iter) - continue; - if (I->mayWriteToMemory()) - AST.add(I); - // Note: This is specifically guarded by a check on isa<PHINode>, - // which while a valid (somewhat arbitrary) micro-optimization, is - // needed because otherwise isSafeToSpeculativelyExecute returns - // false on PHI nodes. - if (!isa<PHINode>(I) && !isUnorderedLoadStore(I) && - !isSafeToSpeculativelyExecute(I)) - // Intervening instructions cause side effects. - FutureSideEffects = true; - } - - // Make sure that this instruction, which is in the use set of this - // root instruction, does not also belong to the base set or the set of - // some other root instruction. - if (RootIt->second.count() > 1) { - LLVM_DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst - << " vs. " << *RootInst << " (prev. case overlap)\n"); - return false; - } - - // Make sure that we don't alias with any instruction in the alias set - // tracker. If we do, then we depend on a future iteration, and we - // can't reroll. - if (RootInst->mayReadFromMemory()) { - for (auto &K : AST) { - if (isModOrRefSet(K.aliasesUnknownInst(RootInst, BatchAA))) { - LLVM_DEBUG(dbgs() << "LRR: iteration root match failed at " - << *BaseInst << " vs. " << *RootInst - << " (depends on future store)\n"); - return false; - } - } - } - - // If we've past an instruction from a future iteration that may have - // side effects, and this instruction might also, then we can't reorder - // them, and this matching fails. As an exception, we allow the alias - // set tracker to handle regular (unordered) load/store dependencies. - if (FutureSideEffects && ((!isUnorderedLoadStore(BaseInst) && - !isSafeToSpeculativelyExecute(BaseInst)) || - (!isUnorderedLoadStore(RootInst) && - !isSafeToSpeculativelyExecute(RootInst)))) { - LLVM_DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst - << " vs. " << *RootInst - << " (side effects prevent reordering)\n"); - return false; - } - - // For instructions that are part of a reduction, if the operation is - // associative, then don't bother matching the operands (because we - // already know that the instructions are isomorphic, and the order - // within the iteration does not matter). For non-associative reductions, - // we do need to match the operands, because we need to reject - // out-of-order instructions within an iteration! 
- // For example (assume floating-point addition), we need to reject this: - // x += a[i]; x += b[i]; - // x += a[i+1]; x += b[i+1]; - // x += b[i+2]; x += a[i+2]; - bool InReduction = Reductions.isPairInSame(BaseInst, RootInst); - - if (!(InReduction && BaseInst->isAssociative())) { - bool Swapped = false, SomeOpMatched = false; - for (unsigned j = 0; j < BaseInst->getNumOperands(); ++j) { - Value *Op2 = RootInst->getOperand(j); - - // If this is part of a reduction (and the operation is not - // associatve), then we match all operands, but not those that are - // part of the reduction. - if (InReduction) - if (Instruction *Op2I = dyn_cast<Instruction>(Op2)) - if (Reductions.isPairInSame(RootInst, Op2I)) - continue; - - DenseMap<Value *, Value *>::iterator BMI = BaseMap.find(Op2); - if (BMI != BaseMap.end()) { - Op2 = BMI->second; - } else { - for (auto &DRS : RootSets) { - if (DRS.Roots[Iter-1] == (Instruction*) Op2) { - Op2 = DRS.BaseInst; - break; - } - } - } - - if (BaseInst->getOperand(Swapped ? unsigned(!j) : j) != Op2) { - // If we've not already decided to swap the matched operands, and - // we've not already matched our first operand (note that we could - // have skipped matching the first operand because it is part of a - // reduction above), and the instruction is commutative, then try - // the swapped match. - if (!Swapped && BaseInst->isCommutative() && !SomeOpMatched && - BaseInst->getOperand(!j) == Op2) { - Swapped = true; - } else { - LLVM_DEBUG(dbgs() - << "LRR: iteration root match failed at " << *BaseInst - << " vs. " << *RootInst << " (operand " << j << ")\n"); - return false; - } - } - - SomeOpMatched = true; - } - } - - if ((!PossibleRedLastSet.count(BaseInst) && - hasUsesOutsideLoop(BaseInst, L)) || - (!PossibleRedLastSet.count(RootInst) && - hasUsesOutsideLoop(RootInst, L))) { - LLVM_DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst - << " vs. " << *RootInst << " (uses outside loop)\n"); - return false; - } - - Reductions.recordPair(BaseInst, RootInst, Iter); - BaseMap.insert(std::make_pair(RootInst, BaseInst)); - - LastRootIt = RootIt; - Visited.insert(BaseInst); - Visited.insert(RootInst); - BaseIt = nextInstr(0, Uses, Visited); - RootIt = nextInstr(Iter, Uses, Visited); - } - assert(BaseIt == Uses.end() && RootIt == Uses.end() && - "Mismatched set sizes!"); - } - - LLVM_DEBUG(dbgs() << "LRR: Matched all iteration increments for " << *IV - << "\n"); - - return true; -} - -void LoopReroll::DAGRootTracker::replace(const SCEV *BackedgeTakenCount) { - BasicBlock *Header = L->getHeader(); - - // Compute the start and increment for each BaseInst before we start erasing - // instructions. - SmallVector<const SCEV *, 8> StartExprs; - SmallVector<const SCEV *, 8> IncrExprs; - for (auto &DRS : RootSets) { - const SCEVAddRecExpr *IVSCEV = - cast<SCEVAddRecExpr>(SE->getSCEV(DRS.BaseInst)); - StartExprs.push_back(IVSCEV->getStart()); - IncrExprs.push_back(SE->getMinusSCEV(SE->getSCEV(DRS.Roots[0]), IVSCEV)); - } - - // Remove instructions associated with non-base iterations. - for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*Header))) { - unsigned I = Uses[&Inst].find_first(); - if (I > 0 && I < IL_All) { - LLVM_DEBUG(dbgs() << "LRR: removing: " << Inst << "\n"); - Inst.eraseFromParent(); - } - } - - // Rewrite each BaseInst using SCEV. - for (size_t i = 0, e = RootSets.size(); i != e; ++i) - // Insert the new induction variable. - replaceIV(RootSets[i], StartExprs[i], IncrExprs[i]); - - { // Limit the lifetime of SCEVExpander. 
- BranchInst *BI = cast<BranchInst>(Header->getTerminator()); - const DataLayout &DL = Header->getModule()->getDataLayout(); - SCEVExpander Expander(*SE, DL, "reroll"); - auto Zero = SE->getZero(BackedgeTakenCount->getType()); - auto One = SE->getOne(BackedgeTakenCount->getType()); - auto NewIVSCEV = SE->getAddRecExpr(Zero, One, L, SCEV::FlagAnyWrap); - Value *NewIV = - Expander.expandCodeFor(NewIVSCEV, BackedgeTakenCount->getType(), - Header->getFirstNonPHIOrDbg()); - // FIXME: This arithmetic can overflow. - auto TripCount = SE->getAddExpr(BackedgeTakenCount, One); - auto ScaledTripCount = SE->getMulExpr( - TripCount, SE->getConstant(BackedgeTakenCount->getType(), Scale)); - auto ScaledBECount = SE->getMinusSCEV(ScaledTripCount, One); - Value *TakenCount = - Expander.expandCodeFor(ScaledBECount, BackedgeTakenCount->getType(), - Header->getFirstNonPHIOrDbg()); - Value *Cond = - new ICmpInst(BI, CmpInst::ICMP_EQ, NewIV, TakenCount, "exitcond"); - BI->setCondition(Cond); - - if (BI->getSuccessor(1) != Header) - BI->swapSuccessors(); - } - - SimplifyInstructionsInBlock(Header, TLI); - DeleteDeadPHIs(Header, TLI); -} - -void LoopReroll::DAGRootTracker::replaceIV(DAGRootSet &DRS, - const SCEV *Start, - const SCEV *IncrExpr) { - BasicBlock *Header = L->getHeader(); - Instruction *Inst = DRS.BaseInst; - - const SCEV *NewIVSCEV = - SE->getAddRecExpr(Start, IncrExpr, L, SCEV::FlagAnyWrap); - - { // Limit the lifetime of SCEVExpander. - const DataLayout &DL = Header->getModule()->getDataLayout(); - SCEVExpander Expander(*SE, DL, "reroll"); - Value *NewIV = Expander.expandCodeFor(NewIVSCEV, Inst->getType(), - Header->getFirstNonPHIOrDbg()); - - for (auto &KV : Uses) - if (KV.second.find_first() == 0) - KV.first->replaceUsesOfWith(Inst, NewIV); - } -} - -// Validate the selected reductions. All iterations must have an isomorphic -// part of the reduction chain and, for non-associative reductions, the chain -// entries must appear in order. -bool LoopReroll::ReductionTracker::validateSelected() { - // For a non-associative reduction, the chain entries must appear in order. - for (int i : Reds) { - int PrevIter = 0, BaseCount = 0, Count = 0; - for (Instruction *J : PossibleReds[i]) { - // Note that all instructions in the chain must have been found because - // all instructions in the function must have been assigned to some - // iteration. - int Iter = PossibleRedIter[J]; - if (Iter != PrevIter && Iter != PrevIter + 1 && - !PossibleReds[i].getReducedValue()->isAssociative()) { - LLVM_DEBUG(dbgs() << "LRR: Out-of-order non-associative reduction: " - << J << "\n"); - return false; - } - - if (Iter != PrevIter) { - if (Count != BaseCount) { - LLVM_DEBUG(dbgs() - << "LRR: Iteration " << PrevIter << " reduction use count " - << Count << " is not equal to the base use count " - << BaseCount << "\n"); - return false; - } - - Count = 0; - } - - ++Count; - if (Iter == 0) - ++BaseCount; - - PrevIter = Iter; - } - } - - return true; -} - -// For all selected reductions, remove all parts except those in the first -// iteration (and the PHI). Replace outside uses of the reduced value with uses -// of the first-iteration reduced value (in other words, reroll the selected -// reductions). -void LoopReroll::ReductionTracker::replaceSelected() { - // Fixup reductions to refer to the last instruction associated with the - // first iteration (not the last). 
- for (int i : Reds) { - int j = 0; - for (int e = PossibleReds[i].size(); j != e; ++j) - if (PossibleRedIter[PossibleReds[i][j]] != 0) { - --j; - break; - } - - // Replace users with the new end-of-chain value. - SmallInstructionVector Users; - for (User *U : PossibleReds[i].getReducedValue()->users()) { - Users.push_back(cast<Instruction>(U)); - } - - for (Instruction *User : Users) - User->replaceUsesOfWith(PossibleReds[i].getReducedValue(), - PossibleReds[i][j]); - } -} - -// Reroll the provided loop with respect to the provided induction variable. -// Generally, we're looking for a loop like this: -// -// %iv = phi [ (preheader, ...), (body, %iv.next) ] -// f(%iv) -// %iv.1 = add %iv, 1 <-- a root increment -// f(%iv.1) -// %iv.2 = add %iv, 2 <-- a root increment -// f(%iv.2) -// %iv.scale_m_1 = add %iv, scale-1 <-- a root increment -// f(%iv.scale_m_1) -// ... -// %iv.next = add %iv, scale -// %cmp = icmp(%iv, ...) -// br %cmp, header, exit -// -// Notably, we do not require that f(%iv), f(%iv.1), etc. be isolated groups of -// instructions. In other words, the instructions in f(%iv), f(%iv.1), etc. can -// be intermixed with eachother. The restriction imposed by this algorithm is -// that the relative order of the isomorphic instructions in f(%iv), f(%iv.1), -// etc. be the same. -// -// First, we collect the use set of %iv, excluding the other increment roots. -// This gives us f(%iv). Then we iterate over the loop instructions (scale-1) -// times, having collected the use set of f(%iv.(i+1)), during which we: -// - Ensure that the next unmatched instruction in f(%iv) is isomorphic to -// the next unmatched instruction in f(%iv.(i+1)). -// - Ensure that both matched instructions don't have any external users -// (with the exception of last-in-chain reduction instructions). -// - Track the (aliasing) write set, and other side effects, of all -// instructions that belong to future iterations that come before the matched -// instructions. If the matched instructions read from that write set, then -// f(%iv) or f(%iv.(i+1)) has some dependency on instructions in -// f(%iv.(j+1)) for some j > i, and we cannot reroll the loop. Similarly, -// if any of these future instructions had side effects (could not be -// speculatively executed), and so do the matched instructions, when we -// cannot reorder those side-effect-producing instructions, and rerolling -// fails. -// -// Finally, we make sure that all loop instructions are either loop increment -// roots, belong to simple latch code, parts of validated reductions, part of -// f(%iv) or part of some f(%iv.i). If all of that is true (and all reductions -// have been validated), then we reroll the loop. -bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header, - const SCEV *BackedgeTakenCount, - ReductionTracker &Reductions) { - DAGRootTracker DAGRoots(this, L, IV, SE, AA, TLI, DT, LI, PreserveLCSSA, - IVToIncMap, LoopControlIVs); - - if (!DAGRoots.findRoots()) - return false; - LLVM_DEBUG(dbgs() << "LRR: Found all root induction increments for: " << *IV - << "\n"); - - if (!DAGRoots.validate(Reductions)) - return false; - if (!Reductions.validateSelected()) - return false; - // At this point, we've validated the rerolling, and we're committed to - // making changes! 
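The block comment above sketches the pattern the removed pass matched: a body that repeats the same work for %iv, %iv+1, ..., %iv+scale-1 and then steps the induction variable by scale. As a rough source-level illustration (array names invented here, not taken from the patch), rerolling with scale = 3 rewrites

    void unrolledBy3(int *a, const int *b) {
      // What the pass looked for: three isomorphic copies, step 3.
      for (int i = 0; i < 1500; i += 3) {
        a[i]     += b[i];
        a[i + 1] += b[i + 1];
        a[i + 2] += b[i + 2];
      }
    }

into the equivalent single-copy loop

    void rerolled(int *a, const int *b) {
      // Same elements touched, step 1, three times as many iterations.
      for (int i = 0; i < 1500; ++i)
        a[i] += b[i];
    }

provided the copies keep their relative order, have no uses outside the loop, and any reductions validate, which is what the matching code above checks.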
- - Reductions.replaceSelected(); - DAGRoots.replace(BackedgeTakenCount); - - ++NumRerolledLoops; - return true; -} - -bool LoopReroll::runOnLoop(Loop *L) { - BasicBlock *Header = L->getHeader(); - LLVM_DEBUG(dbgs() << "LRR: F[" << Header->getParent()->getName() << "] Loop %" - << Header->getName() << " (" << L->getNumBlocks() - << " block(s))\n"); - - // For now, we'll handle only single BB loops. - if (L->getNumBlocks() > 1) - return false; - - if (!SE->hasLoopInvariantBackedgeTakenCount(L)) - return false; - - const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); - LLVM_DEBUG(dbgs() << "\n Before Reroll:\n" << *(L->getHeader()) << "\n"); - LLVM_DEBUG(dbgs() << "LRR: backedge-taken count = " << *BackedgeTakenCount - << "\n"); - - // First, we need to find the induction variable with respect to which we can - // reroll (there may be several possible options). - SmallInstructionVector PossibleIVs; - IVToIncMap.clear(); - LoopControlIVs.clear(); - collectPossibleIVs(L, PossibleIVs); - - if (PossibleIVs.empty()) { - LLVM_DEBUG(dbgs() << "LRR: No possible IVs found\n"); - return false; - } - - ReductionTracker Reductions; - collectPossibleReductions(L, Reductions); - bool Changed = false; - - // For each possible IV, collect the associated possible set of 'root' nodes - // (i+1, i+2, etc.). - for (Instruction *PossibleIV : PossibleIVs) - if (reroll(PossibleIV, L, Header, BackedgeTakenCount, Reductions)) { - Changed = true; - break; - } - LLVM_DEBUG(dbgs() << "\n After Reroll:\n" << *(L->getHeader()) << "\n"); - - // Trip count of L has changed so SE must be re-evaluated. - if (Changed) - SE->forgetLoop(L); - - return Changed; -} - -PreservedAnalyses LoopRerollPass::run(Loop &L, LoopAnalysisManager &AM, - LoopStandardAnalysisResults &AR, - LPMUpdater &U) { - return LoopReroll(&AR.AA, &AR.LI, &AR.SE, &AR.TLI, &AR.DT, true).runOnLoop(&L) - ? getLoopPassPreservedAnalyses() - : PreservedAnalyses::all(); -} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRotation.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRotation.cpp index eee855058706..acb79e94d087 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRotation.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopRotation.cpp @@ -64,11 +64,12 @@ PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM, // Vectorization requires loop-rotation. Use default threshold for loops the // user explicitly marked for vectorization, even when header duplication is // disabled. - int Threshold = EnableHeaderDuplication || - hasVectorizeTransformation(&L) == TM_ForcedByUser - ? DefaultRotationThreshold - : 0; - const DataLayout &DL = L.getHeader()->getModule()->getDataLayout(); + int Threshold = + (EnableHeaderDuplication && !L.getHeader()->getParent()->hasMinSize()) || + hasVectorizeTransformation(&L) == TM_ForcedByUser + ? 
DefaultRotationThreshold + : 0; + const DataLayout &DL = L.getHeader()->getDataLayout(); const SimplifyQuery SQ = getBestSimplifyQuery(AR, DL); std::optional<MemorySSAUpdater> MSSAU; @@ -89,79 +90,3 @@ PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM, PA.preserve<MemorySSAAnalysis>(); return PA; } - -namespace { - -class LoopRotateLegacyPass : public LoopPass { - unsigned MaxHeaderSize; - bool PrepareForLTO; - -public: - static char ID; // Pass ID, replacement for typeid - LoopRotateLegacyPass(int SpecifiedMaxHeaderSize = -1, - bool PrepareForLTO = false) - : LoopPass(ID), PrepareForLTO(PrepareForLTO) { - initializeLoopRotateLegacyPassPass(*PassRegistry::getPassRegistry()); - if (SpecifiedMaxHeaderSize == -1) - MaxHeaderSize = DefaultRotationThreshold; - else - MaxHeaderSize = unsigned(SpecifiedMaxHeaderSize); - } - - // LCSSA form makes instruction renaming easier. - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - getLoopAnalysisUsage(AU); - - // Lazy BFI and BPI are marked as preserved here so LoopRotate - // can remain part of the same loop pass manager as LICM. - AU.addPreserved<LazyBlockFrequencyInfoPass>(); - AU.addPreserved<LazyBranchProbabilityInfoPass>(); - } - - bool runOnLoop(Loop *L, LPPassManager &LPM) override { - if (skipLoop(L)) - return false; - Function &F = *L->getHeader()->getParent(); - - auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - const SimplifyQuery SQ = getBestSimplifyQuery(*this, F); - std::optional<MemorySSAUpdater> MSSAU; - // Not requiring MemorySSA and getting it only if available will split - // the loop pass pipeline when LoopRotate is being run first. - auto *MSSAA = getAnalysisIfAvailable<MemorySSAWrapperPass>(); - if (MSSAA) - MSSAU = MemorySSAUpdater(&MSSAA->getMSSA()); - // Vectorization requires loop-rotation. Use default threshold for loops the - // user explicitly marked for vectorization, even when header duplication is - // disabled. - int Threshold = hasVectorizeTransformation(L) == TM_ForcedByUser - ? DefaultRotationThreshold - : MaxHeaderSize; - - return LoopRotation(L, LI, TTI, AC, &DT, &SE, MSSAU ? 
&*MSSAU : nullptr, SQ, - false, Threshold, false, - PrepareForLTO || PrepareForLTOOption); - } -}; -} // end namespace - -char LoopRotateLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(LoopRotateLegacyPass, "loop-rotate", "Rotate Loops", - false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) -INITIALIZE_PASS_END(LoopRotateLegacyPass, "loop-rotate", "Rotate Loops", false, - false) - -Pass *llvm::createLoopRotatePass(int MaxHeaderSize, bool PrepareForLTO) { - return new LoopRotateLegacyPass(MaxHeaderSize, PrepareForLTO); -} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index 028a487ecdbc..ae9103d0608a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -16,7 +16,6 @@ #include "llvm/Transforms/Scalar/LoopSimplifyCFG.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 7ebc5da8b25a..91461d1ed275 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -193,10 +193,18 @@ static cl::opt<cl::boolOrDefault> AllowTerminatingConditionFoldingAfterLSR( "lsr-term-fold", cl::Hidden, cl::desc("Attempt to replace primary IV with other IV.")); -static cl::opt<bool> AllowDropSolutionIfLessProfitable( - "lsr-drop-solution", cl::Hidden, cl::init(false), +static cl::opt<cl::boolOrDefault> AllowDropSolutionIfLessProfitable( + "lsr-drop-solution", cl::Hidden, cl::desc("Attempt to drop solution if it is less profitable")); +static cl::opt<bool> EnableVScaleImmediates( + "lsr-enable-vscale-immediates", cl::Hidden, cl::init(true), + cl::desc("Enable analysis of vscale-relative immediates in LSR")); + +static cl::opt<bool> DropScaledForVScale( + "lsr-drop-scaled-reg-for-vscale", cl::Hidden, cl::init(true), + cl::desc("Avoid using scaled registers with vscale-relative addressing")); + STATISTIC(NumTermFold, "Number of terminating condition fold recognized and performed"); @@ -247,6 +255,126 @@ public: void dump() const; }; +// An offset from an address that is either scalable or fixed. Used for +// per-target optimizations of addressing modes. 
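Here "scalable" means vscale-relative: the offset stands for MinVal * vscale rather than a plain fixed byte count, which is typically how addresses into scalable vector types (e.g. <vscale x 4 x i32>, whose size is 16 * vscale bytes) show up in SCEV. A minimal sketch of the intended meaning, using a hypothetical helper that is not part of the patch:

    #include <cstdint>

    // Value an LSR immediate denotes once the runtime vscale is known.
    inline int64_t materializedOffset(int64_t MinVal, bool Scalable,
                                      uint64_t VScale) {
      return Scalable ? MinVal * (int64_t)VScale : MinVal;
    }

The class added just below packages that (MinVal, Scalable) pair and provides getSCEV()/getNegativeSCEV() helpers which emit a constant multiplied by vscale whenever the scalable flag is set.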
+class Immediate : public details::FixedOrScalableQuantity<Immediate, int64_t> { + constexpr Immediate(ScalarTy MinVal, bool Scalable) + : FixedOrScalableQuantity(MinVal, Scalable) {} + + constexpr Immediate(const FixedOrScalableQuantity<Immediate, int64_t> &V) + : FixedOrScalableQuantity(V) {} + +public: + constexpr Immediate() = delete; + + static constexpr Immediate getFixed(ScalarTy MinVal) { + return {MinVal, false}; + } + static constexpr Immediate getScalable(ScalarTy MinVal) { + return {MinVal, true}; + } + static constexpr Immediate get(ScalarTy MinVal, bool Scalable) { + return {MinVal, Scalable}; + } + static constexpr Immediate getZero() { return {0, false}; } + static constexpr Immediate getFixedMin() { + return {std::numeric_limits<int64_t>::min(), false}; + } + static constexpr Immediate getFixedMax() { + return {std::numeric_limits<int64_t>::max(), false}; + } + static constexpr Immediate getScalableMin() { + return {std::numeric_limits<int64_t>::min(), true}; + } + static constexpr Immediate getScalableMax() { + return {std::numeric_limits<int64_t>::max(), true}; + } + + constexpr bool isLessThanZero() const { return Quantity < 0; } + + constexpr bool isGreaterThanZero() const { return Quantity > 0; } + + constexpr bool isCompatibleImmediate(const Immediate &Imm) const { + return isZero() || Imm.isZero() || Imm.Scalable == Scalable; + } + + constexpr bool isMin() const { + return Quantity == std::numeric_limits<ScalarTy>::min(); + } + + constexpr bool isMax() const { + return Quantity == std::numeric_limits<ScalarTy>::max(); + } + + // Arithmetic 'operators' that cast to unsigned types first. + constexpr Immediate addUnsigned(const Immediate &RHS) const { + assert(isCompatibleImmediate(RHS) && "Incompatible Immediates"); + ScalarTy Value = (uint64_t)Quantity + RHS.getKnownMinValue(); + return {Value, Scalable || RHS.isScalable()}; + } + + constexpr Immediate subUnsigned(const Immediate &RHS) const { + assert(isCompatibleImmediate(RHS) && "Incompatible Immediates"); + ScalarTy Value = (uint64_t)Quantity - RHS.getKnownMinValue(); + return {Value, Scalable || RHS.isScalable()}; + } + + // Scale the quantity by a constant without caring about runtime scalability. + constexpr Immediate mulUnsigned(const ScalarTy RHS) const { + ScalarTy Value = (uint64_t)Quantity * RHS; + return {Value, Scalable}; + } + + // Helpers for generating SCEVs with vscale terms where needed. + const SCEV *getSCEV(ScalarEvolution &SE, Type *Ty) const { + const SCEV *S = SE.getConstant(Ty, Quantity); + if (Scalable) + S = SE.getMulExpr(S, SE.getVScale(S->getType())); + return S; + } + + const SCEV *getNegativeSCEV(ScalarEvolution &SE, Type *Ty) const { + const SCEV *NegS = SE.getConstant(Ty, -(uint64_t)Quantity); + if (Scalable) + NegS = SE.getMulExpr(NegS, SE.getVScale(NegS->getType())); + return NegS; + } + + const SCEV *getUnknownSCEV(ScalarEvolution &SE, Type *Ty) const { + const SCEV *SU = SE.getUnknown(ConstantInt::getSigned(Ty, Quantity)); + if (Scalable) + SU = SE.getMulExpr(SU, SE.getVScale(SU->getType())); + return SU; + } +}; + +// This is needed for the Compare type of std::map when Immediate is used +// as a key. We don't need it to be fully correct against any value of vscale, +// just to make sure that vscale-related terms in the map are considered against +// each other rather than being mixed up and potentially missing opportunities. 
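As the comment above says, the comparator only has to give a strict weak ordering that keeps fixed and vscale-relative keys from being interleaved; it makes no attempt to rank a fixed offset against a scalable one numerically, since that would require knowing vscale. A self-contained analogue with a stand-in struct (the names below are invented; the real code keys the map on the Immediate class above):

    #include <cstdint>
    #include <map>

    struct Imm { int64_t MinVal; bool Scalable; };

    // Same ordering scheme as KeyOrderTargetImmediate below: every fixed key
    // sorts before every scalable key, and ties within a group break on the
    // known-minimum value.
    struct ImmLess {
      bool operator()(const Imm &L, const Imm &R) const {
        if (L.Scalable != R.Scalable)
          return !L.Scalable; // fixed < scalable
        return L.MinVal < R.MinVal;
      }
    };

    // A map keyed this way iterates {4,false}, {8,false}, {16,true}: fixed
    // offsets are visited together, vscale-relative ones together.
    using ImmMap = std::map<Imm, const char *, ImmLess>;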
+struct KeyOrderTargetImmediate { + bool operator()(const Immediate &LHS, const Immediate &RHS) const { + if (LHS.isScalable() && !RHS.isScalable()) + return false; + if (!LHS.isScalable() && RHS.isScalable()) + return true; + return LHS.getKnownMinValue() < RHS.getKnownMinValue(); + } +}; + +// This would be nicer if we could be generic instead of directly using size_t, +// but there doesn't seem to be a type trait for is_orderable or +// is_lessthan_comparable or similar. +struct KeyOrderSizeTAndImmediate { + bool operator()(const std::pair<size_t, Immediate> &LHS, + const std::pair<size_t, Immediate> &RHS) const { + size_t LSize = LHS.first; + size_t RSize = RHS.first; + if (LSize != RSize) + return LSize < RSize; + return KeyOrderTargetImmediate()(LHS.second, RHS.second); + } +}; } // end anonymous namespace #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -357,7 +485,7 @@ struct Formula { GlobalValue *BaseGV = nullptr; /// Base offset for complex addressing. - int64_t BaseOffset = 0; + Immediate BaseOffset = Immediate::getZero(); /// Whether any complex addressing has a base register. bool HasBaseReg = false; @@ -388,7 +516,7 @@ struct Formula { /// An additional constant offset which added near the use. This requires a /// temporary register, but the offset itself can live in an add immediate /// field rather than a register. - int64_t UnfoldedOffset = 0; + Immediate UnfoldedOffset = Immediate::getZero(); Formula() = default; @@ -628,7 +756,7 @@ void Formula::print(raw_ostream &OS) const { if (!First) OS << " + "; else First = false; BaseGV->printAsOperand(OS, /*PrintType=*/false); } - if (BaseOffset != 0) { + if (BaseOffset.isNonZero()) { if (!First) OS << " + "; else First = false; OS << BaseOffset; } @@ -652,7 +780,7 @@ void Formula::print(raw_ostream &OS) const { OS << "<unknown>"; OS << ')'; } - if (UnfoldedOffset != 0) { + if (UnfoldedOffset.isNonZero()) { if (!First) OS << " + "; OS << "imm(" << UnfoldedOffset << ')'; } @@ -798,28 +926,36 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS, /// If S involves the addition of a constant integer value, return that integer /// value, and mutate S to point to a new SCEV with that value excluded. 
-static int64_t ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) { +static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) { if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) { if (C->getAPInt().getSignificantBits() <= 64) { S = SE.getConstant(C->getType(), 0); - return C->getValue()->getSExtValue(); + return Immediate::getFixed(C->getValue()->getSExtValue()); } } else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { SmallVector<const SCEV *, 8> NewOps(Add->operands()); - int64_t Result = ExtractImmediate(NewOps.front(), SE); - if (Result != 0) + Immediate Result = ExtractImmediate(NewOps.front(), SE); + if (Result.isNonZero()) S = SE.getAddExpr(NewOps); return Result; } else if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { SmallVector<const SCEV *, 8> NewOps(AR->operands()); - int64_t Result = ExtractImmediate(NewOps.front(), SE); - if (Result != 0) + Immediate Result = ExtractImmediate(NewOps.front(), SE); + if (Result.isNonZero()) S = SE.getAddRecExpr(NewOps, AR->getLoop(), // FIXME: AR->getNoWrapFlags(SCEV::FlagNW) SCEV::FlagAnyWrap); return Result; + } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) { + if (EnableVScaleImmediates && M->getNumOperands() == 2) { + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0))) + if (isa<SCEVVScale>(M->getOperand(1))) { + S = SE.getConstant(M->getType(), 0); + return Immediate::getScalable(C->getValue()->getSExtValue()); + } + } } - return 0; + return Immediate::getZero(); } /// If S involves the addition of a GlobalValue address, return that symbol, and @@ -1134,7 +1270,7 @@ struct LSRFixup { /// A constant offset to be added to the LSRUse expression. This allows /// multiple fixups to share the same LSRUse with different offsets, for /// example in an unrolled loop. - int64_t Offset = 0; + Immediate Offset = Immediate::getZero(); LSRFixup() = default; @@ -1197,8 +1333,8 @@ public: SmallVector<LSRFixup, 8> Fixups; /// Keep track of the min and max offsets of the fixups. - int64_t MinOffset = std::numeric_limits<int64_t>::max(); - int64_t MaxOffset = std::numeric_limits<int64_t>::min(); + Immediate MinOffset = Immediate::getFixedMax(); + Immediate MaxOffset = Immediate::getFixedMin(); /// This records whether all of the fixups using this LSRUse are outside of /// the loop, in which case some special-case heuristics may be used. @@ -1234,9 +1370,9 @@ public: void pushFixup(LSRFixup &f) { Fixups.push_back(f); - if (f.Offset > MaxOffset) + if (Immediate::isKnownGT(f.Offset, MaxOffset)) MaxOffset = f.Offset; - if (f.Offset < MinOffset) + if (Immediate::isKnownLT(f.Offset, MinOffset)) MinOffset = f.Offset; } @@ -1254,7 +1390,7 @@ public: static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, LSRUse::KindType Kind, MemAccessTy AccessTy, - GlobalValue *BaseGV, int64_t BaseOffset, + GlobalValue *BaseGV, Immediate BaseOffset, bool HasBaseReg, int64_t Scale, Instruction *Fixup = nullptr); @@ -1308,9 +1444,9 @@ void Cost::RateRegister(const Formula &F, const SCEV *Reg, // If the step size matches the base offset, we could use pre-indexed // addressing. 
- if (AMK == TTI::AMK_PreIndexed) { + if (AMK == TTI::AMK_PreIndexed && F.BaseOffset.isFixed()) { if (auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE))) - if (Step->getAPInt() == F.BaseOffset) + if (Step->getAPInt() == F.BaseOffset.getFixedValue()) LoopCost = 0; } else if (AMK == TTI::AMK_PostIndexed) { const SCEV *LoopStep = AR->getStepRecurrence(*SE); @@ -1401,27 +1537,32 @@ void Cost::RateFormula(const Formula &F, // allows to fold 2 registers. C.NumBaseAdds += NumBaseParts - (1 + (F.Scale && isAMCompletelyFolded(*TTI, LU, F))); - C.NumBaseAdds += (F.UnfoldedOffset != 0); + C.NumBaseAdds += (F.UnfoldedOffset.isNonZero()); // Accumulate non-free scaling amounts. C.ScaleCost += *getScalingFactorCost(*TTI, LU, F, *L).getValue(); // Tally up the non-zero immediates. for (const LSRFixup &Fixup : LU.Fixups) { - int64_t O = Fixup.Offset; - int64_t Offset = (uint64_t)O + F.BaseOffset; - if (F.BaseGV) - C.ImmCost += 64; // Handle symbolic values conservatively. - // TODO: This should probably be the pointer size. - else if (Offset != 0) - C.ImmCost += APInt(64, Offset, true).getSignificantBits(); - - // Check with target if this offset with this instruction is - // specifically not supported. - if (LU.Kind == LSRUse::Address && Offset != 0 && - !isAMCompletelyFolded(*TTI, LSRUse::Address, LU.AccessTy, F.BaseGV, - Offset, F.HasBaseReg, F.Scale, Fixup.UserInst)) - C.NumBaseAdds++; + if (Fixup.Offset.isCompatibleImmediate(F.BaseOffset)) { + Immediate Offset = Fixup.Offset.addUnsigned(F.BaseOffset); + if (F.BaseGV) + C.ImmCost += 64; // Handle symbolic values conservatively. + // TODO: This should probably be the pointer size. + else if (Offset.isNonZero()) + C.ImmCost += + APInt(64, Offset.getKnownMinValue(), true).getSignificantBits(); + + // Check with target if this offset with this instruction is + // specifically not supported. + if (LU.Kind == LSRUse::Address && Offset.isNonZero() && + !isAMCompletelyFolded(*TTI, LSRUse::Address, LU.AccessTy, F.BaseGV, + Offset, F.HasBaseReg, F.Scale, Fixup.UserInst)) + C.NumBaseAdds++; + } else { + // Incompatible immediate type, increase cost to avoid using + C.ImmCost += 2048; + } } // If we don't count instruction cost exit here. @@ -1546,7 +1687,7 @@ void LSRFixup::print(raw_ostream &OS) const { PIL->getHeader()->printAsOperand(OS, /*PrintType=*/false); } - if (Offset != 0) + if (Offset.isNonZero()) OS << ", Offset=" << Offset; } @@ -1673,14 +1814,19 @@ LLVM_DUMP_METHOD void LSRUse::dump() const { static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, LSRUse::KindType Kind, MemAccessTy AccessTy, - GlobalValue *BaseGV, int64_t BaseOffset, + GlobalValue *BaseGV, Immediate BaseOffset, bool HasBaseReg, int64_t Scale, - Instruction *Fixup/*= nullptr*/) { + Instruction *Fixup /* = nullptr */) { switch (Kind) { - case LSRUse::Address: - return TTI.isLegalAddressingMode(AccessTy.MemTy, BaseGV, BaseOffset, - HasBaseReg, Scale, AccessTy.AddrSpace, Fixup); - + case LSRUse::Address: { + int64_t FixedOffset = + BaseOffset.isScalable() ? 0 : BaseOffset.getFixedValue(); + int64_t ScalableOffset = + BaseOffset.isScalable() ? BaseOffset.getKnownMinValue() : 0; + return TTI.isLegalAddressingMode(AccessTy.MemTy, BaseGV, FixedOffset, + HasBaseReg, Scale, AccessTy.AddrSpace, + Fixup, ScalableOffset); + } case LSRUse::ICmpZero: // There's not even a target hook for querying whether it would be legal to // fold a GV into an ICmp. 
@@ -1688,7 +1834,7 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, return false; // ICmp only has two operands; don't allow more than two non-trivial parts. - if (Scale != 0 && HasBaseReg && BaseOffset != 0) + if (Scale != 0 && HasBaseReg && BaseOffset.isNonZero()) return false; // ICmp only supports no scale or a -1 scale, as we can "fold" a -1 scale by @@ -1698,7 +1844,12 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, // If we have low-level target information, ask the target if it can fold an // integer immediate on an icmp. - if (BaseOffset != 0) { + if (BaseOffset.isNonZero()) { + // We don't have an interface to query whether the target supports + // icmpzero against scalable quantities yet. + if (BaseOffset.isScalable()) + return false; + // We have one of: // ICmpZero BaseReg + BaseOffset => ICmp BaseReg, -BaseOffset // ICmpZero -1*ScaleReg + BaseOffset => ICmp ScaleReg, BaseOffset @@ -1706,8 +1857,8 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, if (Scale == 0) // The cast does the right thing with // std::numeric_limits<int64_t>::min(). - BaseOffset = -(uint64_t)BaseOffset; - return TTI.isLegalICmpImmediate(BaseOffset); + BaseOffset = BaseOffset.getFixed(-(uint64_t)BaseOffset.getFixedValue()); + return TTI.isLegalICmpImmediate(BaseOffset.getFixedValue()); } // ICmpZero BaseReg + -1*ScaleReg => ICmp BaseReg, ScaleReg @@ -1715,30 +1866,35 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, case LSRUse::Basic: // Only handle single-register values. - return !BaseGV && Scale == 0 && BaseOffset == 0; + return !BaseGV && Scale == 0 && BaseOffset.isZero(); case LSRUse::Special: // Special case Basic to handle -1 scales. - return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0; + return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset.isZero(); } llvm_unreachable("Invalid LSRUse Kind!"); } static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, - int64_t MinOffset, int64_t MaxOffset, + Immediate MinOffset, Immediate MaxOffset, LSRUse::KindType Kind, MemAccessTy AccessTy, - GlobalValue *BaseGV, int64_t BaseOffset, + GlobalValue *BaseGV, Immediate BaseOffset, bool HasBaseReg, int64_t Scale) { + if (BaseOffset.isNonZero() && + (BaseOffset.isScalable() != MinOffset.isScalable() || + BaseOffset.isScalable() != MaxOffset.isScalable())) + return false; // Check for overflow. 
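The test that follows performs the addition in unsigned arithmetic, where wraparound is well defined, and then checks that the result moved in the direction the signed addend says it should; if it did not, the signed sum overflowed. A standalone restatement with a hypothetical helper name (not part of the patch):

    #include <cstdint>

    // True iff Base + Delta overflows int64_t; mirrors the
    // ((int64_t)((uint64_t)Base + Min) > Base) != (Min > 0) check used here.
    inline bool signedAddWraps(int64_t Base, int64_t Delta) {
      int64_t Sum = (int64_t)((uint64_t)Base + (uint64_t)Delta);
      return (Sum > Base) != (Delta > 0);
    }

The same test is applied to both the minimum and the maximum fixup offset before they are folded into the base offset.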
- if (((int64_t)((uint64_t)BaseOffset + MinOffset) > BaseOffset) != - (MinOffset > 0)) + int64_t Base = BaseOffset.getKnownMinValue(); + int64_t Min = MinOffset.getKnownMinValue(); + int64_t Max = MaxOffset.getKnownMinValue(); + if (((int64_t)((uint64_t)Base + Min) > Base) != (Min > 0)) return false; - MinOffset = (uint64_t)BaseOffset + MinOffset; - if (((int64_t)((uint64_t)BaseOffset + MaxOffset) > BaseOffset) != - (MaxOffset > 0)) + MinOffset = Immediate::get((uint64_t)Base + Min, MinOffset.isScalable()); + if (((int64_t)((uint64_t)Base + Max) > Base) != (Max > 0)) return false; - MaxOffset = (uint64_t)BaseOffset + MaxOffset; + MaxOffset = Immediate::get((uint64_t)Base + Max, MaxOffset.isScalable()); return isAMCompletelyFolded(TTI, Kind, AccessTy, BaseGV, MinOffset, HasBaseReg, Scale) && @@ -1747,7 +1903,7 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, } static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, - int64_t MinOffset, int64_t MaxOffset, + Immediate MinOffset, Immediate MaxOffset, LSRUse::KindType Kind, MemAccessTy AccessTy, const Formula &F, const Loop &L) { // For the purpose of isAMCompletelyFolded either having a canonical formula @@ -1763,10 +1919,10 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, } /// Test whether we know how to expand the current formula. -static bool isLegalUse(const TargetTransformInfo &TTI, int64_t MinOffset, - int64_t MaxOffset, LSRUse::KindType Kind, +static bool isLegalUse(const TargetTransformInfo &TTI, Immediate MinOffset, + Immediate MaxOffset, LSRUse::KindType Kind, MemAccessTy AccessTy, GlobalValue *BaseGV, - int64_t BaseOffset, bool HasBaseReg, int64_t Scale) { + Immediate BaseOffset, bool HasBaseReg, int64_t Scale) { // We know how to expand completely foldable formulae. return isAMCompletelyFolded(TTI, MinOffset, MaxOffset, Kind, AccessTy, BaseGV, BaseOffset, HasBaseReg, Scale) || @@ -1777,13 +1933,21 @@ static bool isLegalUse(const TargetTransformInfo &TTI, int64_t MinOffset, BaseGV, BaseOffset, true, 0)); } -static bool isLegalUse(const TargetTransformInfo &TTI, int64_t MinOffset, - int64_t MaxOffset, LSRUse::KindType Kind, +static bool isLegalUse(const TargetTransformInfo &TTI, Immediate MinOffset, + Immediate MaxOffset, LSRUse::KindType Kind, MemAccessTy AccessTy, const Formula &F) { return isLegalUse(TTI, MinOffset, MaxOffset, Kind, AccessTy, F.BaseGV, F.BaseOffset, F.HasBaseReg, F.Scale); } +static bool isLegalAddImmediate(const TargetTransformInfo &TTI, + Immediate Offset) { + if (Offset.isScalable()) + return TTI.isLegalAddScalableImmediate(Offset.getKnownMinValue()); + + return TTI.isLegalAddImmediate(Offset.getFixedValue()); +} + static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, const LSRUse &LU, const Formula &F) { // Target may want to look at the user instructions. @@ -1816,12 +1980,20 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI, switch (LU.Kind) { case LSRUse::Address: { // Check the scaling factor cost with both the min and max offsets. 
+ int64_t ScalableMin = 0, ScalableMax = 0, FixedMin = 0, FixedMax = 0; + if (F.BaseOffset.isScalable()) { + ScalableMin = (F.BaseOffset + LU.MinOffset).getKnownMinValue(); + ScalableMax = (F.BaseOffset + LU.MaxOffset).getKnownMinValue(); + } else { + FixedMin = (F.BaseOffset + LU.MinOffset).getFixedValue(); + FixedMax = (F.BaseOffset + LU.MaxOffset).getFixedValue(); + } InstructionCost ScaleCostMinOffset = TTI.getScalingFactorCost( - LU.AccessTy.MemTy, F.BaseGV, F.BaseOffset + LU.MinOffset, F.HasBaseReg, - F.Scale, LU.AccessTy.AddrSpace); + LU.AccessTy.MemTy, F.BaseGV, StackOffset::get(FixedMin, ScalableMin), + F.HasBaseReg, F.Scale, LU.AccessTy.AddrSpace); InstructionCost ScaleCostMaxOffset = TTI.getScalingFactorCost( - LU.AccessTy.MemTy, F.BaseGV, F.BaseOffset + LU.MaxOffset, F.HasBaseReg, - F.Scale, LU.AccessTy.AddrSpace); + LU.AccessTy.MemTy, F.BaseGV, StackOffset::get(FixedMax, ScalableMax), + F.HasBaseReg, F.Scale, LU.AccessTy.AddrSpace); assert(ScaleCostMinOffset.isValid() && ScaleCostMaxOffset.isValid() && "Legal addressing mode has an illegal cost!"); @@ -1840,10 +2012,11 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI, static bool isAlwaysFoldable(const TargetTransformInfo &TTI, LSRUse::KindType Kind, MemAccessTy AccessTy, - GlobalValue *BaseGV, int64_t BaseOffset, + GlobalValue *BaseGV, Immediate BaseOffset, bool HasBaseReg) { // Fast-path: zero is always foldable. - if (BaseOffset == 0 && !BaseGV) return true; + if (BaseOffset.isZero() && !BaseGV) + return true; // Conservatively, create an address with an immediate and a // base and a scale. @@ -1856,13 +2029,22 @@ static bool isAlwaysFoldable(const TargetTransformInfo &TTI, HasBaseReg = true; } + // FIXME: Try with + without a scale? Maybe based on TTI? + // I think basereg + scaledreg + immediateoffset isn't a good 'conservative' + // default for many architectures, not just AArch64 SVE. More investigation + // needed later to determine if this should be used more widely than just + // on scalable types. + if (HasBaseReg && BaseOffset.isNonZero() && Kind != LSRUse::ICmpZero && + AccessTy.MemTy && AccessTy.MemTy->isScalableTy() && DropScaledForVScale) + Scale = 0; + return isAMCompletelyFolded(TTI, Kind, AccessTy, BaseGV, BaseOffset, HasBaseReg, Scale); } static bool isAlwaysFoldable(const TargetTransformInfo &TTI, - ScalarEvolution &SE, int64_t MinOffset, - int64_t MaxOffset, LSRUse::KindType Kind, + ScalarEvolution &SE, Immediate MinOffset, + Immediate MaxOffset, LSRUse::KindType Kind, MemAccessTy AccessTy, const SCEV *S, bool HasBaseReg) { // Fast-path: zero is always foldable. @@ -1870,14 +2052,18 @@ static bool isAlwaysFoldable(const TargetTransformInfo &TTI, // Conservatively, create an address with an immediate and a // base and a scale. - int64_t BaseOffset = ExtractImmediate(S, SE); + Immediate BaseOffset = ExtractImmediate(S, SE); GlobalValue *BaseGV = ExtractSymbol(S, SE); // If there's anything else involved, it's not foldable. if (!S->isZero()) return false; // Fast-path: zero is always foldable. - if (BaseOffset == 0 && !BaseGV) return true; + if (BaseOffset.isZero() && !BaseGV) + return true; + + if (BaseOffset.isScalable()) + return false; // Conservatively, create an address with an immediate and a // base and a scale. 
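The DropScaledForVScale fallback above (clearing Scale when the access type is scalable and there is already a base register plus a nonzero offset) reflects the point made in the FIXME: on targets with scalable vectors, a memory access usually gets either a vscale-scaled immediate or a scaled index register, not both, so assuming that base + scaled reg + immediate folds is too optimistic. Illustrative AArch64 SVE forms (an assumption for illustration, not taken from the patch):

    // ld1w { z0.s }, p0/z, [x0, #1, mul vl]   ; base + vl-scaled immediate
    // ld1w { z0.s }, p0/z, [x0, x1, lsl #2]   ; base + scaled index register
    // No contiguous-load form combines an immediate with a scaled register.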
@@ -2026,11 +2212,11 @@ class LSRInstance { using UseMapTy = DenseMap<LSRUse::SCEVUseKindPair, size_t>; UseMapTy UseMap; - bool reconcileNewOffset(LSRUse &LU, int64_t NewOffset, bool HasBaseReg, + bool reconcileNewOffset(LSRUse &LU, Immediate NewOffset, bool HasBaseReg, LSRUse::KindType Kind, MemAccessTy AccessTy); - std::pair<size_t, int64_t> getUse(const SCEV *&Expr, LSRUse::KindType Kind, - MemAccessTy AccessTy); + std::pair<size_t, Immediate> getUse(const SCEV *&Expr, LSRUse::KindType Kind, + MemAccessTy AccessTy); void DeleteUse(LSRUse &LU, size_t LUIdx); @@ -2056,7 +2242,7 @@ class LSRInstance { void GenerateSymbolicOffsets(LSRUse &LU, unsigned LUIdx, Formula Base); void GenerateConstantOffsetsImpl(LSRUse &LU, unsigned LUIdx, const Formula &Base, - const SmallVectorImpl<int64_t> &Worklist, + const SmallVectorImpl<Immediate> &Worklist, size_t Idx, bool IsScaledReg = false); void GenerateConstantOffsets(LSRUse &LU, unsigned LUIdx, Formula Base); void GenerateICmpZeroScales(LSRUse &LU, unsigned LUIdx, Formula Base); @@ -2215,17 +2401,20 @@ void LSRInstance::OptimizeShadowIV() { // Ignore negative constants, as the code below doesn't handle them // correctly. TODO: Remove this restriction. - if (!C->getValue().isStrictlyPositive()) continue; + if (!C->getValue().isStrictlyPositive()) + continue; /* Add new PHINode. */ - PHINode *NewPH = PHINode::Create(DestTy, 2, "IV.S.", PH); + PHINode *NewPH = PHINode::Create(DestTy, 2, "IV.S.", PH->getIterator()); + NewPH->setDebugLoc(PH->getDebugLoc()); /* create new increment. '++d' in above example. */ Constant *CFP = ConstantFP::get(DestTy, C->getZExtValue()); - BinaryOperator *NewIncr = - BinaryOperator::Create(Incr->getOpcode() == Instruction::Add ? - Instruction::FAdd : Instruction::FSub, - NewPH, CFP, "IV.S.next.", Incr); + BinaryOperator *NewIncr = BinaryOperator::Create( + Incr->getOpcode() == Instruction::Add ? Instruction::FAdd + : Instruction::FSub, + NewPH, CFP, "IV.S.next.", Incr->getIterator()); + NewIncr->setDebugLoc(Incr->getDebugLoc()); NewPH->addIncoming(NewInit, PH->getIncomingBlock(Entry)); NewPH->addIncoming(NewIncr, PH->getIncomingBlock(Latch)); @@ -2395,8 +2584,8 @@ ICmpInst *LSRInstance::OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse) { // Ok, everything looks ok to change the condition into an SLT or SGE and // delete the max calculation. - ICmpInst *NewCond = - new ICmpInst(Cond, Pred, Cond->getOperand(0), NewRHS, "scmp"); + ICmpInst *NewCond = new ICmpInst(Cond->getIterator(), Pred, + Cond->getOperand(0), NewRHS, "scmp"); // Delete the max calculation instructions. NewCond->setDebugLoc(Cond->getDebugLoc()); @@ -2563,11 +2752,11 @@ LSRInstance::OptimizeLoopTermCond() { /// Determine if the given use can accommodate a fixup at the given offset and /// other details. If so, update the use and return true. -bool LSRInstance::reconcileNewOffset(LSRUse &LU, int64_t NewOffset, +bool LSRInstance::reconcileNewOffset(LSRUse &LU, Immediate NewOffset, bool HasBaseReg, LSRUse::KindType Kind, MemAccessTy AccessTy) { - int64_t NewMinOffset = LU.MinOffset; - int64_t NewMaxOffset = LU.MaxOffset; + Immediate NewMinOffset = LU.MinOffset; + Immediate NewMaxOffset = LU.MaxOffset; MemAccessTy NewAccessTy = AccessTy; // Check for a mismatched kind. It's tempting to collapse mismatched kinds to @@ -2587,18 +2776,25 @@ bool LSRInstance::reconcileNewOffset(LSRUse &LU, int64_t NewOffset, } // Conservatively assume HasBaseReg is true for now. 
- if (NewOffset < LU.MinOffset) { + if (Immediate::isKnownLT(NewOffset, LU.MinOffset)) { if (!isAlwaysFoldable(TTI, Kind, NewAccessTy, /*BaseGV=*/nullptr, LU.MaxOffset - NewOffset, HasBaseReg)) return false; NewMinOffset = NewOffset; - } else if (NewOffset > LU.MaxOffset) { + } else if (Immediate::isKnownGT(NewOffset, LU.MaxOffset)) { if (!isAlwaysFoldable(TTI, Kind, NewAccessTy, /*BaseGV=*/nullptr, NewOffset - LU.MinOffset, HasBaseReg)) return false; NewMaxOffset = NewOffset; } + // FIXME: We should be able to handle some level of scalable offset support + // for 'void', but in order to get basic support up and running this is + // being left out. + if (NewAccessTy.MemTy && NewAccessTy.MemTy->isVoidTy() && + (NewMinOffset.isScalable() || NewMaxOffset.isScalable())) + return false; + // Update the use. LU.MinOffset = NewMinOffset; LU.MaxOffset = NewMaxOffset; @@ -2609,17 +2805,17 @@ bool LSRInstance::reconcileNewOffset(LSRUse &LU, int64_t NewOffset, /// Return an LSRUse index and an offset value for a fixup which needs the given /// expression, with the given kind and optional access type. Either reuse an /// existing use or create a new one, as needed. -std::pair<size_t, int64_t> LSRInstance::getUse(const SCEV *&Expr, - LSRUse::KindType Kind, - MemAccessTy AccessTy) { +std::pair<size_t, Immediate> LSRInstance::getUse(const SCEV *&Expr, + LSRUse::KindType Kind, + MemAccessTy AccessTy) { const SCEV *Copy = Expr; - int64_t Offset = ExtractImmediate(Expr, SE); + Immediate Offset = ExtractImmediate(Expr, SE); // Basic uses can't accept any offset, for example. if (!isAlwaysFoldable(TTI, Kind, AccessTy, /*BaseGV=*/ nullptr, Offset, /*HasBaseReg=*/ true)) { Expr = Copy; - Offset = 0; + Offset = Immediate::getFixed(0); } std::pair<UseMapTy::iterator, bool> P = @@ -2680,7 +2876,7 @@ LSRInstance::FindUseWithSimilarFormula(const Formula &OrigF, F.BaseGV == OrigF.BaseGV && F.Scale == OrigF.Scale && F.UnfoldedOffset == OrigF.UnfoldedOffset) { - if (F.BaseOffset == 0) + if (F.BaseOffset.isZero()) return &LU; // This is the formula where all the registers and symbols matched; // there aren't going to be any others. Since we declined it, we @@ -3162,14 +3358,27 @@ void LSRInstance::FinalizeChain(IVChain &Chain) { static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst, Value *Operand, const TargetTransformInfo &TTI) { const SCEVConstant *IncConst = dyn_cast<SCEVConstant>(IncExpr); - if (!IncConst || !isAddressUse(TTI, UserInst, Operand)) - return false; + Immediate IncOffset = Immediate::getZero(); + if (IncConst) { + if (IncConst && IncConst->getAPInt().getSignificantBits() > 64) + return false; + IncOffset = Immediate::getFixed(IncConst->getValue()->getSExtValue()); + } else { + // Look for mul(vscale, constant), to detect a scalable offset. 
+ auto *IncVScale = dyn_cast<SCEVMulExpr>(IncExpr); + if (!IncVScale || IncVScale->getNumOperands() != 2 || + !isa<SCEVVScale>(IncVScale->getOperand(1))) + return false; + auto *Scale = dyn_cast<SCEVConstant>(IncVScale->getOperand(0)); + if (!Scale || Scale->getType()->getScalarSizeInBits() > 64) + return false; + IncOffset = Immediate::getScalable(Scale->getValue()->getSExtValue()); + } - if (IncConst->getAPInt().getSignificantBits() > 64) + if (!isAddressUse(TTI, UserInst, Operand)) return false; MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand); - int64_t IncOffset = IncConst->getValue()->getSExtValue(); if (!isAlwaysFoldable(TTI, LSRUse::Address, AccessTy, /*BaseGV=*/nullptr, IncOffset, /*HasBaseReg=*/false)) return false; @@ -3217,6 +3426,10 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, Type *IVTy = IVSrc->getType(); Type *IntTy = SE.getEffectiveSCEVType(IVTy); const SCEV *LeftOverExpr = nullptr; + const SCEV *Accum = SE.getZero(IntTy); + SmallVector<std::pair<const SCEV *, Value *>> Bases; + Bases.emplace_back(Accum, IVSrc); + for (const IVInc &Inc : Chain) { Instruction *InsertPt = Inc.UserInst; if (isa<PHINode>(InsertPt)) @@ -3229,10 +3442,31 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, // IncExpr was the result of subtraction of two narrow values, so must // be signed. const SCEV *IncExpr = SE.getNoopOrSignExtend(Inc.IncExpr, IntTy); + Accum = SE.getAddExpr(Accum, IncExpr); LeftOverExpr = LeftOverExpr ? SE.getAddExpr(LeftOverExpr, IncExpr) : IncExpr; } - if (LeftOverExpr && !LeftOverExpr->isZero()) { + + // Look through each base to see if any can produce a nice addressing mode. + bool FoundBase = false; + for (auto [MapScev, MapIVOper] : reverse(Bases)) { + const SCEV *Remainder = SE.getMinusSCEV(Accum, MapScev); + if (canFoldIVIncExpr(Remainder, Inc.UserInst, Inc.IVOperand, TTI)) { + if (!Remainder->isZero()) { + Rewriter.clearPostInc(); + Value *IncV = Rewriter.expandCodeFor(Remainder, IntTy, InsertPt); + const SCEV *IVOperExpr = + SE.getAddExpr(SE.getUnknown(MapIVOper), SE.getUnknown(IncV)); + IVOper = Rewriter.expandCodeFor(IVOperExpr, IVTy, InsertPt); + } else { + IVOper = MapIVOper; + } + + FoundBase = true; + break; + } + } + if (!FoundBase && LeftOverExpr && !LeftOverExpr->isZero()) { // Expand the IV increment. Rewriter.clearPostInc(); Value *IncV = Rewriter.expandCodeFor(LeftOverExpr, IntTy, InsertPt); @@ -3243,6 +3477,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, // If an IV increment can't be folded, use it as the next IV value. if (!canFoldIVIncExpr(LeftOverExpr, Inc.UserInst, Inc.IVOperand, TTI)) { assert(IVTy == IVOper->getType() && "inconsistent IV increment type"); + Bases.emplace_back(Accum, IVOper); IVSrc = IVOper; LeftOverExpr = nullptr; } @@ -3377,9 +3612,9 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { } // Get or create an LSRUse. - std::pair<size_t, int64_t> P = getUse(S, Kind, AccessTy); + std::pair<size_t, Immediate> P = getUse(S, Kind, AccessTy); size_t LUIdx = P.first; - int64_t Offset = P.second; + Immediate Offset = P.second; LSRUse &LU = Uses[LUIdx]; // Record the fixup. 
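The GenerateIVChain hunks above record every value that becomes a chain base (the Bases vector, seeded with IVSrc at offset zero) along with its accumulated distance from the chain head. For each user, the rewritten loop walks those bases newest-first and, when the distance from some base to the current accumulated offset is itself foldable for this access (canFoldIVIncExpr on the Remainder), it addresses relative to that base instead of falling back to the LeftOverExpr path. A purely illustrative, self-contained sketch of that selection (helper and container names invented here):

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Pick the newest recorded base whose distance to the current accumulated
    // offset would fold into the addressing mode; -1 means fall back to
    // expanding the leftover increment as before.
    inline int pickBase(const std::vector<std::pair<int64_t, int>> &Bases,
                        int64_t Accum, bool (*offsetFolds)(int64_t)) {
      for (int I = (int)Bases.size() - 1; I >= 0; --I)
        if (offsetFolds(Accum - Bases[I].first))
          return I; // use Bases[I] value + (Accum - Bases[I].first)
      return -1;
    }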
@@ -3569,10 +3804,10 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() { continue; } - std::pair<size_t, int64_t> P = getUse( - S, LSRUse::Basic, MemAccessTy()); + std::pair<size_t, Immediate> P = + getUse(S, LSRUse::Basic, MemAccessTy()); size_t LUIdx = P.first; - int64_t Offset = P.second; + Immediate Offset = P.second; LSRUse &LU = Uses[LUIdx]; LSRFixup &LF = LU.getNewFixup(); LF.UserInst = const_cast<Instruction *>(UserInst); @@ -3728,13 +3963,17 @@ void LSRInstance::GenerateReassociationsImpl(LSRUse &LU, unsigned LUIdx, continue; Formula F = Base; + if (F.UnfoldedOffset.isNonZero() && F.UnfoldedOffset.isScalable()) + continue; + // Add the remaining pieces of the add back into the new formula. const SCEVConstant *InnerSumSC = dyn_cast<SCEVConstant>(InnerSum); if (InnerSumSC && SE.getTypeSizeInBits(InnerSumSC->getType()) <= 64 && - TTI.isLegalAddImmediate((uint64_t)F.UnfoldedOffset + + TTI.isLegalAddImmediate((uint64_t)F.UnfoldedOffset.getFixedValue() + InnerSumSC->getValue()->getZExtValue())) { F.UnfoldedOffset = - (uint64_t)F.UnfoldedOffset + InnerSumSC->getValue()->getZExtValue(); + Immediate::getFixed((uint64_t)F.UnfoldedOffset.getFixedValue() + + InnerSumSC->getValue()->getZExtValue()); if (IsScaledReg) F.ScaledReg = nullptr; else @@ -3747,10 +3986,11 @@ void LSRInstance::GenerateReassociationsImpl(LSRUse &LU, unsigned LUIdx, // Add J as its own register, or an unfolded immediate. const SCEVConstant *SC = dyn_cast<SCEVConstant>(*J); if (SC && SE.getTypeSizeInBits(SC->getType()) <= 64 && - TTI.isLegalAddImmediate((uint64_t)F.UnfoldedOffset + + TTI.isLegalAddImmediate((uint64_t)F.UnfoldedOffset.getFixedValue() + SC->getValue()->getZExtValue())) F.UnfoldedOffset = - (uint64_t)F.UnfoldedOffset + SC->getValue()->getZExtValue(); + Immediate::getFixed((uint64_t)F.UnfoldedOffset.getFixedValue() + + SC->getValue()->getZExtValue()); else F.BaseRegs.push_back(*J); // We may have changed the number of register in base regs, adjust the @@ -3791,7 +4031,8 @@ void LSRInstance::GenerateCombinations(LSRUse &LU, unsigned LUIdx, Formula Base) { // This method is only interesting on a plurality of registers. if (Base.BaseRegs.size() + (Base.Scale == 1) + - (Base.UnfoldedOffset != 0) <= 1) + (Base.UnfoldedOffset.isNonZero()) <= + 1) return; // Flatten the representation, i.e., reg1 + 1*reg2 => reg1 + reg2, before @@ -3840,11 +4081,11 @@ void LSRInstance::GenerateCombinations(LSRUse &LU, unsigned LUIdx, // If we have an unfolded offset, generate a formula combining it with the // registers collected. - if (NewBase.UnfoldedOffset) { + if (NewBase.UnfoldedOffset.isNonZero() && NewBase.UnfoldedOffset.isFixed()) { assert(CombinedIntegerType && "Missing a type for the unfolded offset"); - Ops.push_back(SE.getConstant(CombinedIntegerType, NewBase.UnfoldedOffset, - true)); - NewBase.UnfoldedOffset = 0; + Ops.push_back(SE.getConstant(CombinedIntegerType, + NewBase.UnfoldedOffset.getFixedValue(), true)); + NewBase.UnfoldedOffset = Immediate::getFixed(0); GenerateFormula(SE.getAddExpr(Ops)); } } @@ -3884,15 +4125,18 @@ void LSRInstance::GenerateSymbolicOffsets(LSRUse &LU, unsigned LUIdx, /// Helper function for LSRInstance::GenerateConstantOffsets. 
void LSRInstance::GenerateConstantOffsetsImpl( LSRUse &LU, unsigned LUIdx, const Formula &Base, - const SmallVectorImpl<int64_t> &Worklist, size_t Idx, bool IsScaledReg) { + const SmallVectorImpl<Immediate> &Worklist, size_t Idx, bool IsScaledReg) { - auto GenerateOffset = [&](const SCEV *G, int64_t Offset) { + auto GenerateOffset = [&](const SCEV *G, Immediate Offset) { Formula F = Base; - F.BaseOffset = (uint64_t)Base.BaseOffset - Offset; + if (!Base.BaseOffset.isCompatibleImmediate(Offset)) + return; + F.BaseOffset = Base.BaseOffset.subUnsigned(Offset); if (isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, F)) { // Add the offset to the base register. - const SCEV *NewG = SE.getAddExpr(SE.getConstant(G->getType(), Offset), G); + const SCEV *NewOffset = Offset.getSCEV(SE, G->getType()); + const SCEV *NewG = SE.getAddExpr(NewOffset, G); // If it cancelled out, drop the base register, otherwise update it. if (NewG->isZero()) { if (IsScaledReg) { @@ -3928,21 +4172,24 @@ void LSRInstance::GenerateConstantOffsetsImpl( int64_t Step = StepInt.isNegative() ? StepInt.getSExtValue() : StepInt.getZExtValue(); - for (int64_t Offset : Worklist) { - Offset -= Step; - GenerateOffset(G, Offset); + for (Immediate Offset : Worklist) { + if (Offset.isFixed()) { + Offset = Immediate::getFixed(Offset.getFixedValue() - Step); + GenerateOffset(G, Offset); + } } } } } - for (int64_t Offset : Worklist) + for (Immediate Offset : Worklist) GenerateOffset(G, Offset); - int64_t Imm = ExtractImmediate(G, SE); - if (G->isZero() || Imm == 0) + Immediate Imm = ExtractImmediate(G, SE); + if (G->isZero() || Imm.isZero() || + !Base.BaseOffset.isCompatibleImmediate(Imm)) return; Formula F = Base; - F.BaseOffset = (uint64_t)F.BaseOffset + Imm; + F.BaseOffset = F.BaseOffset.addUnsigned(Imm); if (!isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, F)) return; if (IsScaledReg) { @@ -3961,7 +4208,7 @@ void LSRInstance::GenerateConstantOffsets(LSRUse &LU, unsigned LUIdx, Formula Base) { // TODO: For now, just add the min and max offset, because it usually isn't // worthwhile looking at everything inbetween. - SmallVector<int64_t, 2> Worklist; + SmallVector<Immediate, 2> Worklist; Worklist.push_back(LU.MinOffset); if (LU.MaxOffset != LU.MinOffset) Worklist.push_back(LU.MaxOffset); @@ -4001,27 +4248,31 @@ void LSRInstance::GenerateICmpZeroScales(LSRUse &LU, unsigned LUIdx, if (!ConstantInt::isValueValidForType(IntTy, Factor)) continue; // Check that the multiplication doesn't overflow. - if (Base.BaseOffset == std::numeric_limits<int64_t>::min() && Factor == -1) + if (Base.BaseOffset.isMin() && Factor == -1) + continue; + // Not supporting scalable immediates. + if (Base.BaseOffset.isNonZero() && Base.BaseOffset.isScalable()) continue; - int64_t NewBaseOffset = (uint64_t)Base.BaseOffset * Factor; + Immediate NewBaseOffset = Base.BaseOffset.mulUnsigned(Factor); assert(Factor != 0 && "Zero factor not expected!"); - if (NewBaseOffset / Factor != Base.BaseOffset) + if (NewBaseOffset.getFixedValue() / Factor != + Base.BaseOffset.getFixedValue()) continue; // If the offset will be truncated at this use, check that it is in bounds. if (!IntTy->isPointerTy() && - !ConstantInt::isValueValidForType(IntTy, NewBaseOffset)) + !ConstantInt::isValueValidForType(IntTy, NewBaseOffset.getFixedValue())) continue; // Check that multiplying with the use offset doesn't overflow. 
- int64_t Offset = LU.MinOffset; - if (Offset == std::numeric_limits<int64_t>::min() && Factor == -1) + Immediate Offset = LU.MinOffset; + if (Offset.isMin() && Factor == -1) continue; - Offset = (uint64_t)Offset * Factor; - if (Offset / Factor != LU.MinOffset) + Offset = Offset.mulUnsigned(Factor); + if (Offset.getFixedValue() / Factor != LU.MinOffset.getFixedValue()) continue; // If the offset will be truncated at this use, check that it is in bounds. if (!IntTy->isPointerTy() && - !ConstantInt::isValueValidForType(IntTy, Offset)) + !ConstantInt::isValueValidForType(IntTy, Offset.getFixedValue())) continue; Formula F = Base; @@ -4032,7 +4283,7 @@ void LSRInstance::GenerateICmpZeroScales(LSRUse &LU, unsigned LUIdx, continue; // Compensate for the use having MinOffset built into it. - F.BaseOffset = (uint64_t)F.BaseOffset + Offset - LU.MinOffset; + F.BaseOffset = F.BaseOffset.addUnsigned(Offset).subUnsigned(LU.MinOffset); const SCEV *FactorS = SE.getConstant(IntTy, Factor); @@ -4051,16 +4302,16 @@ void LSRInstance::GenerateICmpZeroScales(LSRUse &LU, unsigned LUIdx, } // Check that multiplying with the unfolded offset doesn't overflow. - if (F.UnfoldedOffset != 0) { - if (F.UnfoldedOffset == std::numeric_limits<int64_t>::min() && - Factor == -1) + if (F.UnfoldedOffset.isNonZero()) { + if (F.UnfoldedOffset.isMin() && Factor == -1) continue; - F.UnfoldedOffset = (uint64_t)F.UnfoldedOffset * Factor; - if (F.UnfoldedOffset / Factor != Base.UnfoldedOffset) + F.UnfoldedOffset = F.UnfoldedOffset.mulUnsigned(Factor); + if (F.UnfoldedOffset.getFixedValue() / Factor != + Base.UnfoldedOffset.getFixedValue()) continue; // If the offset will be truncated, check that it is in bounds. - if (!IntTy->isPointerTy() && - !ConstantInt::isValueValidForType(IntTy, F.UnfoldedOffset)) + if (!IntTy->isPointerTy() && !ConstantInt::isValueValidForType( + IntTy, F.UnfoldedOffset.getFixedValue())) continue; } @@ -4103,8 +4354,8 @@ void LSRInstance::GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base) { } // For an ICmpZero, negating a solitary base register won't lead to // new solutions. - if (LU.Kind == LSRUse::ICmpZero && - !Base.HasBaseReg && Base.BaseOffset == 0 && !Base.BaseGV) + if (LU.Kind == LSRUse::ICmpZero && !Base.HasBaseReg && + Base.BaseOffset.isZero() && !Base.BaseGV) continue; // For each addrec base reg, if its loop is current loop, apply the scale. for (size_t i = 0, e = Base.BaseRegs.size(); i != e; ++i) { @@ -4230,10 +4481,10 @@ namespace { /// structures moving underneath it. struct WorkItem { size_t LUIdx; - int64_t Imm; + Immediate Imm; const SCEV *OrigReg; - WorkItem(size_t LI, int64_t I, const SCEV *R) + WorkItem(size_t LI, Immediate I, const SCEV *R) : LUIdx(LI), Imm(I), OrigReg(R) {} void print(raw_ostream &OS) const; @@ -4257,14 +4508,14 @@ LLVM_DUMP_METHOD void WorkItem::dump() const { /// opportunities between them. void LSRInstance::GenerateCrossUseConstantOffsets() { // Group the registers by their value without any added constant offset. - using ImmMapTy = std::map<int64_t, const SCEV *>; + using ImmMapTy = std::map<Immediate, const SCEV *, KeyOrderTargetImmediate>; DenseMap<const SCEV *, ImmMapTy> Map; DenseMap<const SCEV *, SmallBitVector> UsedByIndicesMap; SmallVector<const SCEV *, 8> Sequence; for (const SCEV *Use : RegUses) { const SCEV *Reg = Use; // Make a copy for ExtractImmediate to modify. 
- int64_t Imm = ExtractImmediate(Reg, SE); + Immediate Imm = ExtractImmediate(Reg, SE); auto Pair = Map.insert(std::make_pair(Reg, ImmMapTy())); if (Pair.second) Sequence.push_back(Reg); @@ -4276,7 +4527,8 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { // a list of work to do and do the work in a separate step so that we're // not adding formulae and register counts while we're searching. SmallVector<WorkItem, 32> WorkItems; - SmallSet<std::pair<size_t, int64_t>, 32> UniqueItems; + SmallSet<std::pair<size_t, Immediate>, 32, KeyOrderSizeTAndImmediate> + UniqueItems; for (const SCEV *Reg : Sequence) { const ImmMapTy &Imms = Map.find(Reg)->second; @@ -4295,7 +4547,7 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { J != JE; ++J) { const SCEV *OrigReg = J->second; - int64_t JImm = J->first; + Immediate JImm = J->first; const SmallBitVector &UsedByIndices = RegUses.getUsedByIndices(OrigReg); if (!isa<SCEVConstant>(OrigReg) && @@ -4307,22 +4559,34 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { // Conservatively examine offsets between this orig reg a few selected // other orig regs. - int64_t First = Imms.begin()->first; - int64_t Last = std::prev(Imms.end())->first; + Immediate First = Imms.begin()->first; + Immediate Last = std::prev(Imms.end())->first; + if (!First.isCompatibleImmediate(Last)) { + LLVM_DEBUG(dbgs() << "Skipping cross-use reuse for " << *OrigReg + << "\n"); + continue; + } + // Only scalable if both terms are scalable, or if one is scalable and + // the other is 0. + bool Scalable = First.isScalable() || Last.isScalable(); + int64_t FI = First.getKnownMinValue(); + int64_t LI = Last.getKnownMinValue(); // Compute (First + Last) / 2 without overflow using the fact that // First + Last = 2 * (First + Last) + (First ^ Last). - int64_t Avg = (First & Last) + ((First ^ Last) >> 1); - // If the result is negative and First is odd and Last even (or vice versa), + int64_t Avg = (FI & LI) + ((FI ^ LI) >> 1); + // If the result is negative and FI is odd and LI even (or vice versa), // we rounded towards -inf. Add 1 in that case, to round towards 0. - Avg = Avg + ((First ^ Last) & ((uint64_t)Avg >> 63)); + Avg = Avg + ((FI ^ LI) & ((uint64_t)Avg >> 63)); ImmMapTy::const_iterator OtherImms[] = { Imms.begin(), std::prev(Imms.end()), - Imms.lower_bound(Avg)}; + Imms.lower_bound(Immediate::get(Avg, Scalable))}; for (const auto &M : OtherImms) { if (M == J || M == JE) continue; + if (!JImm.isCompatibleImmediate(M->first)) + continue; // Compute the difference between the two. - int64_t Imm = (uint64_t)JImm - M->first; + Immediate Imm = JImm.subUnsigned(M->first); for (unsigned LUIdx : UsedByIndices.set_bits()) // Make a memo of this use, offset, and register tuple. if (UniqueItems.insert(std::make_pair(LUIdx, Imm)).second) @@ -4340,11 +4604,11 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { for (const WorkItem &WI : WorkItems) { size_t LUIdx = WI.LUIdx; LSRUse &LU = Uses[LUIdx]; - int64_t Imm = WI.Imm; + Immediate Imm = WI.Imm; const SCEV *OrigReg = WI.OrigReg; Type *IntTy = SE.getEffectiveSCEVType(OrigReg->getType()); - const SCEV *NegImmS = SE.getSCEV(ConstantInt::get(IntTy, -(uint64_t)Imm)); + const SCEV *NegImmS = Imm.getNegativeSCEV(SE, IntTy); unsigned BitWidth = SE.getTypeSizeInBits(IntTy); // TODO: Use a more targeted data structure. @@ -4357,10 +4621,12 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { F.unscale(); // Use the immediate in the scaled register. 
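Earlier in this hunk the cross-use search adds a third probe at roughly the average of the smallest and largest offsets recorded for the register, computed as (FI & LI) + ((FI ^ LI) >> 1) so that FI + LI is never formed and cannot overflow. The identity behind it is FI + LI = 2 * (FI & LI) + (FI ^ LI): the AND collects the bit positions that carry, the XOR the ones that do not. A standalone restatement (hypothetical helper, not from the patch):

    #include <cstdint>

    // floor((FI + LI) / 2) without ever computing FI + LI; the hunk above then
    // nudges a negative result up by one to round toward zero instead of -inf.
    inline int64_t overflowFreeMidpoint(int64_t FI, int64_t LI) {
      return (FI & LI) + ((FI ^ LI) >> 1);
    }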
if (F.ScaledReg == OrigReg) { - int64_t Offset = (uint64_t)F.BaseOffset + Imm * (uint64_t)F.Scale; + if (!F.BaseOffset.isCompatibleImmediate(Imm)) + continue; + Immediate Offset = F.BaseOffset.addUnsigned(Imm.mulUnsigned(F.Scale)); // Don't create 50 + reg(-50). - if (F.referencesReg(SE.getSCEV( - ConstantInt::get(IntTy, -(uint64_t)Offset)))) + const SCEV *S = Offset.getNegativeSCEV(SE, IntTy); + if (F.referencesReg(S)) continue; Formula NewF = F; NewF.BaseOffset = Offset; @@ -4372,11 +4638,18 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { // If the new scale is a constant in a register, and adding the constant // value to the immediate would produce a value closer to zero than the // immediate itself, then the formula isn't worthwhile. - if (const SCEVConstant *C = dyn_cast<SCEVConstant>(NewF.ScaledReg)) - if (C->getValue()->isNegative() != (NewF.BaseOffset < 0) && + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(NewF.ScaledReg)) { + // FIXME: Do we need to do something for scalable immediates here? + // A scalable SCEV won't be constant, but we might still have + // something in the offset? Bail out for now to be safe. + if (NewF.BaseOffset.isNonZero() && NewF.BaseOffset.isScalable()) + continue; + if (C->getValue()->isNegative() != + (NewF.BaseOffset.isLessThanZero()) && (C->getAPInt().abs() * APInt(BitWidth, F.Scale)) - .ule(std::abs(NewF.BaseOffset))) + .ule(std::abs(NewF.BaseOffset.getFixedValue()))) continue; + } // OK, looks good. NewF.canonicalize(*this->L); @@ -4388,16 +4661,21 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { if (BaseReg != OrigReg) continue; Formula NewF = F; - NewF.BaseOffset = (uint64_t)NewF.BaseOffset + Imm; + if (!NewF.BaseOffset.isCompatibleImmediate(Imm) || + !NewF.UnfoldedOffset.isCompatibleImmediate(Imm) || + !NewF.BaseOffset.isCompatibleImmediate(NewF.UnfoldedOffset)) + continue; + NewF.BaseOffset = NewF.BaseOffset.addUnsigned(Imm); if (!isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, NewF)) { if (AMK == TTI::AMK_PostIndexed && mayUsePostIncMode(TTI, LU, OrigReg, this->L, SE)) continue; - if (!TTI.isLegalAddImmediate((uint64_t)NewF.UnfoldedOffset + Imm)) + Immediate NewUnfoldedOffset = NewF.UnfoldedOffset.addUnsigned(Imm); + if (!isLegalAddImmediate(TTI, NewUnfoldedOffset)) continue; NewF = F; - NewF.UnfoldedOffset = (uint64_t)NewF.UnfoldedOffset + Imm; + NewF.UnfoldedOffset = NewUnfoldedOffset; } NewF.BaseRegs[N] = SE.getAddExpr(NegImmS, BaseReg); @@ -4405,13 +4683,18 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { // constant value to the immediate would produce a value closer to // zero than the immediate itself, then the formula isn't worthwhile. for (const SCEV *NewReg : NewF.BaseRegs) - if (const SCEVConstant *C = dyn_cast<SCEVConstant>(NewReg)) - if ((C->getAPInt() + NewF.BaseOffset) + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(NewReg)) { + if (NewF.BaseOffset.isNonZero() && NewF.BaseOffset.isScalable()) + goto skip_formula; + if ((C->getAPInt() + NewF.BaseOffset.getFixedValue()) .abs() - .slt(std::abs(NewF.BaseOffset)) && - (C->getAPInt() + NewF.BaseOffset).countr_zero() >= - (unsigned)llvm::countr_zero<uint64_t>(NewF.BaseOffset)) + .slt(std::abs(NewF.BaseOffset.getFixedValue())) && + (C->getAPInt() + NewF.BaseOffset.getFixedValue()) + .countr_zero() >= + (unsigned)llvm::countr_zero<uint64_t>( + NewF.BaseOffset.getFixedValue())) goto skip_formula; + } // Ok, looks good. 
NewF.canonicalize(*this->L); @@ -4595,6 +4878,8 @@ void LSRInstance::NarrowSearchSpaceByDetectingSupersets() { bool Any = false; for (size_t i = 0, e = LU.Formulae.size(); i != e; ++i) { Formula &F = LU.Formulae[i]; + if (F.BaseOffset.isNonZero() && F.BaseOffset.isScalable()) + continue; // Look for a formula with a constant or GV in a register. If the use // also has a formula with that same value in an immediate field, // delete the one that uses a register. @@ -4604,7 +4889,9 @@ void LSRInstance::NarrowSearchSpaceByDetectingSupersets() { Formula NewF = F; //FIXME: Formulas should store bitwidth to do wrapping properly. // See PR41034. - NewF.BaseOffset += (uint64_t)C->getValue()->getSExtValue(); + NewF.BaseOffset = + Immediate::getFixed(NewF.BaseOffset.getFixedValue() + + (uint64_t)C->getValue()->getSExtValue()); NewF.BaseRegs.erase(NewF.BaseRegs.begin() + (I - F.BaseRegs.begin())); if (LU.HasFormulaWithSameRegs(NewF)) { @@ -4660,7 +4947,7 @@ void LSRInstance::NarrowSearchSpaceByCollapsingUnrolledCode() { for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { LSRUse &LU = Uses[LUIdx]; for (const Formula &F : LU.Formulae) { - if (F.BaseOffset == 0 || (F.Scale != 0 && F.Scale != 1)) + if (F.BaseOffset.isZero() || (F.Scale != 0 && F.Scale != 1)) continue; LSRUse *LUThatHas = FindUseWithSimilarFormula(F, LU); @@ -5247,10 +5534,20 @@ void LSRInstance::Solve(SmallVectorImpl<const Formula *> &Solution) const { assert(Solution.size() == Uses.size() && "Malformed solution!"); + const bool EnableDropUnprofitableSolution = [&] { + switch (AllowDropSolutionIfLessProfitable) { + case cl::BOU_TRUE: + return true; + case cl::BOU_FALSE: + return false; + case cl::BOU_UNSET: + return TTI.shouldDropLSRSolutionIfLessProfitable(); + } + llvm_unreachable("Unhandled cl::boolOrDefault enum"); + }(); + if (BaselineCost.isLess(SolutionCost)) { - LLVM_DEBUG(dbgs() << "The baseline solution requires "; - BaselineCost.print(dbgs()); dbgs() << "\n"); - if (!AllowDropSolutionIfLessProfitable) + if (!EnableDropUnprofitableSolution) LLVM_DEBUG( dbgs() << "Baseline is more profitable than chosen solution, " "add option 'lsr-drop-solution' to drop LSR solution.\n"); @@ -5485,31 +5782,36 @@ Value *LSRInstance::Expand(const LSRUse &LU, const LSRFixup &LF, Ops.push_back(SE.getUnknown(FullV)); } + // FIXME: Are we sure we won't get a mismatch here? Is there a way to bail + // out at this point, or should we generate a SCEV adding together mixed + // offsets? + assert(F.BaseOffset.isCompatibleImmediate(LF.Offset) && + "Expanding mismatched offsets\n"); // Expand the immediate portion. - int64_t Offset = (uint64_t)F.BaseOffset + LF.Offset; - if (Offset != 0) { + Immediate Offset = F.BaseOffset.addUnsigned(LF.Offset); + if (Offset.isNonZero()) { if (LU.Kind == LSRUse::ICmpZero) { // The other interesting way of "folding" with an ICmpZero is to use a // negated immediate. if (!ICmpScaledV) - ICmpScaledV = ConstantInt::get(IntTy, -(uint64_t)Offset); + ICmpScaledV = + ConstantInt::get(IntTy, -(uint64_t)Offset.getFixedValue()); else { Ops.push_back(SE.getUnknown(ICmpScaledV)); - ICmpScaledV = ConstantInt::get(IntTy, Offset); + ICmpScaledV = ConstantInt::get(IntTy, Offset.getFixedValue()); } } else { // Just add the immediate values. These again are expected to be matched // as part of the address. - Ops.push_back(SE.getUnknown(ConstantInt::getSigned(IntTy, Offset))); + Ops.push_back(Offset.getUnknownSCEV(SE, IntTy)); } } // Expand the unfolded offset portion. 
- int64_t UnfoldedOffset = F.UnfoldedOffset; - if (UnfoldedOffset != 0) { + Immediate UnfoldedOffset = F.UnfoldedOffset; + if (UnfoldedOffset.isNonZero()) { // Just add the immediate values. - Ops.push_back(SE.getUnknown(ConstantInt::getSigned(IntTy, - UnfoldedOffset))); + Ops.push_back(UnfoldedOffset.getUnknownSCEV(SE, IntTy)); } // Emit instructions summing all the operands. @@ -5532,10 +5834,9 @@ Value *LSRInstance::Expand(const LSRUse &LU, const LSRFixup &LF, "a scale at the same time!"); if (F.Scale == -1) { if (ICmpScaledV->getType() != OpTy) { - Instruction *Cast = - CastInst::Create(CastInst::getCastOpcode(ICmpScaledV, false, - OpTy, false), - ICmpScaledV, OpTy, "tmp", CI); + Instruction *Cast = CastInst::Create( + CastInst::getCastOpcode(ICmpScaledV, false, OpTy, false), + ICmpScaledV, OpTy, "tmp", CI->getIterator()); ICmpScaledV = Cast; } CI->setOperand(1, ICmpScaledV); @@ -5546,11 +5847,11 @@ Value *LSRInstance::Expand(const LSRUse &LU, const LSRFixup &LF, "ICmp does not support folding a global value and " "a scale at the same time!"); Constant *C = ConstantInt::getSigned(SE.getEffectiveSCEVType(OpTy), - -(uint64_t)Offset); + -(uint64_t)Offset.getFixedValue()); if (C->getType() != OpTy) { C = ConstantFoldCastOperand( CastInst::getCastOpcode(C, false, OpTy, false), C, OpTy, - CI->getModule()->getDataLayout()); + CI->getDataLayout()); assert(C && "Cast of ConstantInt should have folded"); } @@ -5635,11 +5936,10 @@ void LSRInstance::RewriteForPHI( // If this is reuse-by-noop-cast, insert the noop cast. Type *OpTy = LF.OperandValToReplace->getType(); if (FullV->getType() != OpTy) - FullV = - CastInst::Create(CastInst::getCastOpcode(FullV, false, - OpTy, false), - FullV, LF.OperandValToReplace->getType(), - "tmp", BB->getTerminator()); + FullV = CastInst::Create( + CastInst::getCastOpcode(FullV, false, OpTy, false), FullV, + LF.OperandValToReplace->getType(), "tmp", + BB->getTerminator()->getIterator()); // If the incoming block for this value is not in the loop, it means the // current PHI is not in a loop exit, so we must create a LCSSA PHI for @@ -5657,8 +5957,8 @@ void LSRInstance::RewriteForPHI( // formulae will not be implemented completely and some instructions // will not be eliminated. if (needUpdateFixups) { - for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) - for (LSRFixup &Fixup : Uses[LUIdx].Fixups) + for (LSRUse &LU : Uses) + for (LSRFixup &Fixup : LU.Fixups) // If fixup is supposed to rewrite some operand in the phi // that was just updated, it may be already moved to // another phi node. Such fixup requires update. @@ -5711,8 +6011,8 @@ void LSRInstance::Rewrite(const LSRUse &LU, const LSRFixup &LF, Type *OpTy = LF.OperandValToReplace->getType(); if (FullV->getType() != OpTy) { Instruction *Cast = - CastInst::Create(CastInst::getCastOpcode(FullV, false, OpTy, false), - FullV, OpTy, "tmp", LF.UserInst); + CastInst::Create(CastInst::getCastOpcode(FullV, false, OpTy, false), + FullV, OpTy, "tmp", LF.UserInst->getIterator()); FullV = Cast; } @@ -5856,7 +6156,7 @@ LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, MSSAU(MSSAU), AMK(PreferredAddresingMode.getNumOccurrences() > 0 ? PreferredAddresingMode : TTI.getPreferredAddressingMode(L, &SE)), - Rewriter(SE, L->getHeader()->getModule()->getDataLayout(), "lsr", false), + Rewriter(SE, L->getHeader()->getDataLayout(), "lsr", false), BaselineCost(L, SE, TTI, AMK) { // If LoopSimplify form is not available, stay out of trouble. 
if (!L->isLoopSimplifyForm()) @@ -5930,6 +6230,8 @@ LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, LLVM_DEBUG(dbgs() << "LSR found " << Uses.size() << " uses:\n"; print_uses(dbgs())); + LLVM_DEBUG(dbgs() << "The baseline solution requires "; + BaselineCost.print(dbgs()); dbgs() << "\n"); // Now use the reuse data to generate a bunch of interesting ways // to formulate the values needed for the uses. @@ -6368,10 +6670,10 @@ struct DVIRecoveryRec { DVIRecoveryRec(DbgValueInst *DbgValue) : DbgRef(DbgValue), Expr(DbgValue->getExpression()), HadLocationArgList(false) {} - DVIRecoveryRec(DPValue *DPV) - : DbgRef(DPV), Expr(DPV->getExpression()), HadLocationArgList(false) {} + DVIRecoveryRec(DbgVariableRecord *DVR) + : DbgRef(DVR), Expr(DVR->getExpression()), HadLocationArgList(false) {} - PointerUnion<DbgValueInst *, DPValue *> DbgRef; + PointerUnion<DbgValueInst *, DbgVariableRecord *> DbgRef; DIExpression *Expr; bool HadLocationArgList; SmallVector<WeakVH, 2> LocationOps; @@ -6467,7 +6769,7 @@ static void UpdateDbgValueInst(DVIRecoveryRec &DVIRec, if (isa<DbgValueInst *>(DVIRec.DbgRef)) UpdateDbgValueInstImpl(cast<DbgValueInst *>(DVIRec.DbgRef)); else - UpdateDbgValueInstImpl(cast<DPValue *>(DVIRec.DbgRef)); + UpdateDbgValueInstImpl(cast<DbgVariableRecord *>(DVIRec.DbgRef)); } /// Cached location ops may be erased during LSR, in which case a poison is @@ -6513,7 +6815,7 @@ static void restorePreTransformState(DVIRecoveryRec &DVIRec) { if (isa<DbgValueInst *>(DVIRec.DbgRef)) RestorePreTransformStateImpl(cast<DbgValueInst *>(DVIRec.DbgRef)); else - RestorePreTransformStateImpl(cast<DPValue *>(DVIRec.DbgRef)); + RestorePreTransformStateImpl(cast<DbgVariableRecord *>(DVIRec.DbgRef)); } static bool SalvageDVI(llvm::Loop *L, ScalarEvolution &SE, @@ -6523,7 +6825,7 @@ static bool SalvageDVI(llvm::Loop *L, ScalarEvolution &SE, if (isa<DbgValueInst *>(DVIRec.DbgRef) ? 
!cast<DbgValueInst *>(DVIRec.DbgRef)->isKillLocation() - : !cast<DPValue *>(DVIRec.DbgRef)->isKillLocation()) + : !cast<DbgVariableRecord *>(DVIRec.DbgRef)->isKillLocation()) return false; // LSR may have caused several changes to the dbg.value in the failed salvage @@ -6621,7 +6923,7 @@ static bool SalvageDVI(llvm::Loop *L, ScalarEvolution &SE, << *cast<DbgValueInst *>(DVIRec.DbgRef) << "\n"); else LLVM_DEBUG(dbgs() << "scev-salvage: Updated DVI: " - << *cast<DPValue *>(DVIRec.DbgRef) << "\n"); + << *cast<DbgVariableRecord *>(DVIRec.DbgRef) << "\n"); return true; } @@ -6712,9 +7014,9 @@ static void DbgGatherSalvagableDVI( SalvageableDVISCEVs.push_back(std::move(NewRec)); return true; }; - for (auto &DPV : I.getDbgValueRange()) { - if (DPV.isDbgValue() || DPV.isDbgAssign()) - ProcessDbgValue(&DPV); + for (DbgVariableRecord &DVR : filterDbgVars(I.getDbgRecordRange())) { + if (DVR.isDbgValue() || DVR.isDbgAssign()) + ProcessDbgValue(&DVR); } auto DVI = dyn_cast<DbgValueInst>(&I); if (!DVI) @@ -6762,7 +7064,7 @@ static llvm::PHINode *GetInductionVariable(const Loop &L, ScalarEvolution &SE, static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>> canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, - const LoopInfo &LI) { + const LoopInfo &LI, const TargetTransformInfo &TTI) { if (!L->isInnermost()) { LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n"); return std::nullopt; @@ -6808,18 +7110,35 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep)) return std::nullopt; + // Ensure the simple recurrence is a part of the current loop. + if (ToFold->getParent() != L->getHeader()) + return std::nullopt; + // If that IV isn't dead after we rewrite the exit condition in terms of // another IV, there's no point in doing the transform. if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond)) return std::nullopt; + // Inserting instructions in the preheader has a runtime cost, scale + // the allowed cost with the loops trip count as best we can. + const unsigned ExpansionBudget = [&]() { + unsigned Budget = 2 * SCEVCheapExpansionBudget; + if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L)) + return std::min(Budget, SmallTC); + if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L)) + return std::min(Budget, *SmallTC); + // Unknown trip count, assume long running by default. + return Budget; + }(); + const SCEV *BECount = SE.getBackedgeTakenCount(L); - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + const DataLayout &DL = L->getHeader()->getDataLayout(); SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); PHINode *ToHelpFold = nullptr; const SCEV *TermValueS = nullptr; bool MustDropPoison = false; + auto InsertPt = L->getLoopPreheader()->getTerminator(); for (PHINode &PN : L->getHeader()->phis()) { if (ToFold == &PN) continue; @@ -6861,6 +7180,14 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, continue; } + if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, + &TTI, InsertPt)) { + LLVM_DEBUG( + dbgs() << "Is too expensive to expand terminating value for phi node" + << PN << "\n"); + continue; + } + // The candidate IV may have been otherwise dead and poison from the // very first iteration. If we can't disprove that, we can't use the IV. 
if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) { @@ -6941,12 +7268,13 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, Changed |= DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); if (EnablePhiElim && L->isLoopSimplifyForm()) { SmallVector<WeakTrackingVH, 16> DeadInsts; - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + const DataLayout &DL = L->getHeader()->getDataLayout(); SCEVExpander Rewriter(SE, DL, "lsr", false); #ifndef NDEBUG Rewriter.setDebugType(DEBUG_TYPE); #endif unsigned numFolded = Rewriter.replaceCongruentIVs(L, &DT, DeadInsts, &TTI); + Rewriter.clear(); if (numFolded) { Changed = true; RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts, &TLI, @@ -6961,10 +7289,11 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, // skip the updates in each loop iteration. if (L->isRecursivelyLCSSAForm(DT, LI) && L->getExitBlock()) { SmallVector<WeakTrackingVH, 16> DeadInsts; - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + const DataLayout &DL = L->getHeader()->getDataLayout(); SCEVExpander Rewriter(SE, DL, "lsr", true); int Rewrites = rewriteLoopExitValues(L, &LI, &TLI, &SE, &TTI, Rewriter, &DT, UnusedIndVarInLoop, DeadInsts); + Rewriter.clear(); if (Rewrites) { Changed = true; RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts, &TLI, @@ -6986,7 +7315,7 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, }(); if (EnableFormTerm) { - if (auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI)) { + if (auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI)) { auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt; Changed = true; @@ -7010,9 +7339,8 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags(); // SCEVExpander for both use in preheader and latch - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + const DataLayout &DL = L->getHeader()->getDataLayout(); SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); - SCEVExpanderCleaner ExpCleaner(Expander); assert(Expander.isSafeToExpand(TermValueS) && "Terminating value was checked safe in canFoldTerminatingCondition"); @@ -7043,10 +7371,9 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, BI->setCondition(NewTermCond); + Expander.clear(); OldTermCond->eraseFromParent(); DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); - - ExpCleaner.markResultUsed(); } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp index 7b4c54370e48..f8e2f1f28088 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp @@ -327,8 +327,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, UnrollCostEstimator OuterUCE(L, TTI, EphValues, UP.BEInsns); if (!InnerUCE.canUnroll() || !OuterUCE.canUnroll()) { - LLVM_DEBUG(dbgs() << " Not unrolling loop which contains instructions" - << " which cannot be duplicated or have invalid cost.\n"); + LLVM_DEBUG(dbgs() << " Loop not considered unrollable\n"); return LoopUnrollResult::Unmodified; } @@ -341,7 +340,10 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, LLVM_DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n"); return LoopUnrollResult::Unmodified; } - if (InnerUCE.Convergent || 
OuterUCE.Convergent) { + // FIXME: The call to canUnroll() allows some controlled convergent + // operations, but we block them here for future changes. + if (InnerUCE.Convergence != ConvergenceKind::None || + OuterUCE.Convergence != ConvergenceKind::None) { LLVM_DEBUG( dbgs() << " Not unrolling loop with convergent instructions.\n"); return LoopUnrollResult::Unmodified; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp index 7cfeb019af97..cbc35b6dd429 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -27,6 +28,7 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/LoopUnrollAnalyzer.h" +#include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -173,6 +175,10 @@ static cl::opt<unsigned> cl::desc("Default threshold (max size of unrolled " "loop), used in all but O3 optimizations")); +static cl::opt<unsigned> PragmaUnrollFullMaxIterations( + "pragma-unroll-full-max-iterations", cl::init(1'000'000), cl::Hidden, + cl::desc("Maximum allowed iterations to unroll under pragma unroll full.")); + /// A magic value for use with the Threshold parameter to indicate /// that the loop unroll should be performed regardless of how much /// code expansion would result. @@ -446,7 +452,15 @@ static std::optional<EstimatedUnrollCost> analyzeLoopUnrollCost( // First accumulate the cost of this instruction. if (!Cost.IsFree) { - UnrolledCost += TTI.getInstructionCost(I, CostKind); + // Consider simplified operands in instruction cost. + SmallVector<Value *, 4> Operands; + transform(I->operands(), std::back_inserter(Operands), + [&](Value *Op) { + if (auto Res = SimplifiedValues.lookup(Op)) + return Res; + return Op; + }); + UnrolledCost += TTI.getInstructionCost(I, Operands, CostKind); LLVM_DEBUG(dbgs() << "Adding cost of instruction (iteration " << Iteration << "): "); LLVM_DEBUG(I->dump()); @@ -670,11 +684,15 @@ UnrollCostEstimator::UnrollCostEstimator( const SmallPtrSetImpl<const Value *> &EphValues, unsigned BEInsns) { CodeMetrics Metrics; for (BasicBlock *BB : L->blocks()) - Metrics.analyzeBasicBlock(BB, TTI, EphValues); + Metrics.analyzeBasicBlock(BB, TTI, EphValues, /* PrepareForLTO= */ false, + L); NumInlineCandidates = Metrics.NumInlineCandidates; NotDuplicatable = Metrics.notDuplicatable; - Convergent = Metrics.convergent; + Convergence = Metrics.Convergence; LoopSize = Metrics.NumInsts; + ConvergenceAllowsRuntime = + Metrics.Convergence != ConvergenceKind::Uncontrolled && + !getLoopConvergenceHeart(L); // Don't allow an estimate of size zero. 
This would allows unrolling of loops // with huge iteration counts, which is a compile time problem even if it's @@ -687,6 +705,25 @@ UnrollCostEstimator::UnrollCostEstimator( LoopSize = BEInsns + 1; } +bool UnrollCostEstimator::canUnroll() const { + switch (Convergence) { + case ConvergenceKind::ExtendedLoop: + LLVM_DEBUG(dbgs() << " Convergence prevents unrolling.\n"); + return false; + default: + break; + } + if (!LoopSize.isValid()) { + LLVM_DEBUG(dbgs() << " Invalid loop size prevents unrolling.\n"); + return false; + } + if (NotDuplicatable) { + LLVM_DEBUG(dbgs() << " Non-duplicatable blocks prevent unrolling.\n"); + return false; + } + return true; +} + uint64_t UnrollCostEstimator::getUnrolledLoopSize( const TargetTransformInfo::UnrollingPreferences &UP, unsigned CountOverwrite) const { @@ -776,8 +813,17 @@ shouldPragmaUnroll(Loop *L, const PragmaInfo &PInfo, return PInfo.PragmaCount; } - if (PInfo.PragmaFullUnroll && TripCount != 0) + if (PInfo.PragmaFullUnroll && TripCount != 0) { + // Certain cases with UBSAN can cause trip count to be calculated as + // INT_MAX, Block full unrolling at a reasonable limit so that the compiler + // doesn't hang trying to unroll the loop. See PR77842 + if (TripCount > PragmaUnrollFullMaxIterations) { + LLVM_DEBUG(dbgs() << "Won't unroll; trip count is too large\n"); + return std::nullopt; + } + return TripCount; + } if (PInfo.PragmaEnableUnroll && !TripCount && MaxTripCount && MaxTripCount <= UP.MaxUpperBound) @@ -1119,7 +1165,8 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, std::optional<bool> ProvidedUpperBound, std::optional<bool> ProvidedAllowPeeling, std::optional<bool> ProvidedAllowProfileBasedPeeling, - std::optional<unsigned> ProvidedFullUnrollMaxCount) { + std::optional<unsigned> ProvidedFullUnrollMaxCount, + AAResults *AA = nullptr) { LLVM_DEBUG(dbgs() << "Loop Unroll: F[" << L->getHeader()->getParent()->getName() << "] Loop %" @@ -1182,8 +1229,7 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns); if (!UCE.canUnroll()) { - LLVM_DEBUG(dbgs() << " Not unrolling loop which contains instructions" - << " which cannot be duplicated or have invalid cost.\n"); + LLVM_DEBUG(dbgs() << " Loop not considered unrollable.\n"); return LoopUnrollResult::Unmodified; } @@ -1230,15 +1276,9 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, // is unsafe -- it adds a control-flow dependency to the convergent // operation. Therefore restrict remainder loop (try unrolling without). // - // TODO: This is quite conservative. In practice, convergent_op() - // is likely to be called unconditionally in the loop. In this - // case, the program would be ill-formed (on most architectures) - // unless n were the same on all threads in a thread group. - // Assuming n is the same on all threads, any kind of unrolling is - // safe. But currently llvm's notion of convergence isn't powerful - // enough to express this. - if (UCE.Convergent) - UP.AllowRemainder = false; + // TODO: This is somewhat conservative; we could allow the remainder if the + // trip count is uniform. + UP.AllowRemainder &= UCE.ConvergenceAllowsRuntime; // Try to find the trip count upper bound if we cannot find the exact trip // count. 
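A minimal sketch of the cap introduced in the hunk above (the free-function form and names are illustrative, not part of the patch): a pragma-requested full unroll is honored only when the exact trip count is known and does not exceed the configurable limit, which defaults to 1'000'000 iterations.

#include <cstdint>
#include <optional>

std::optional<uint64_t> pragmaFullUnrollCount(uint64_t ExactTripCount,
                                              uint64_t MaxIterations) {
  if (ExactTripCount == 0)            // exact trip count unknown
    return std::nullopt;
  if (ExactTripCount > MaxIterations) // degenerate counts such as INT_MAX
    return std::nullopt;              // fall back to the other heuristics
  return ExactTripCount;
}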
@@ -1258,6 +1298,8 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, if (!UP.Count) return LoopUnrollResult::Unmodified; + UP.Runtime &= UCE.ConvergenceAllowsRuntime; + if (PP.PeelCount) { assert(UP.Count == 1 && "Cannot perform peel and unroll in the same step"); LLVM_DEBUG(dbgs() << "PEELING loop %" << L->getHeader()->getName() @@ -1271,7 +1313,7 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, ValueToValueMapTy VMap; if (peelLoop(L, PP.PeelCount, LI, &SE, DT, &AC, PreserveLCSSA, VMap)) { - simplifyLoopAfterUnroll(L, true, LI, &SE, &DT, &AC, &TTI); + simplifyLoopAfterUnroll(L, true, LI, &SE, &DT, &AC, &TTI, nullptr); // If the loop was peeled, we already "used up" the profile information // we had, so we don't want to unroll or peel again. if (PP.PeelProfiledIterations) @@ -1282,7 +1324,7 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, } // Do not attempt partial/runtime unrolling in FullLoopUnrolling - if (OnlyFullUnroll && !(UP.Count >= MaxTripCount)) { + if (OnlyFullUnroll && (UP.Count < TripCount || UP.Count < MaxTripCount)) { LLVM_DEBUG( dbgs() << "Not attempting partial/runtime unroll in FullLoopUnroll.\n"); return LoopUnrollResult::Unmodified; @@ -1300,11 +1342,16 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, // Unroll the loop. Loop *RemainderLoop = nullptr; + UnrollLoopOptions ULO; + ULO.Count = UP.Count; + ULO.Force = UP.Force; + ULO.AllowExpensiveTripCount = UP.AllowExpensiveTripCount; + ULO.UnrollRemainder = UP.UnrollRemainder; + ULO.Runtime = UP.Runtime; + ULO.ForgetAllSCEV = ForgetAllSCEV; + ULO.Heart = getLoopConvergenceHeart(L); LoopUnrollResult UnrollResult = UnrollLoop( - L, - {UP.Count, UP.Force, UP.Runtime, UP.AllowExpensiveTripCount, - UP.UnrollRemainder, ForgetAllSCEV}, - LI, &SE, &DT, &AC, &TTI, &ORE, PreserveLCSSA, &RemainderLoop); + L, ULO, LI, &SE, &DT, &AC, &TTI, &ORE, PreserveLCSSA, &RemainderLoop, AA); if (UnrollResult == LoopUnrollResult::Unmodified) return LoopUnrollResult::Unmodified; @@ -1551,6 +1598,7 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + AAResults &AA = AM.getResult<AAManager>(F); LoopAnalysisManager *LAM = nullptr; if (auto *LAMProxy = AM.getCachedResult<LoopAnalysisManagerFunctionProxy>(F)) @@ -1606,7 +1654,8 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, /*Count*/ std::nullopt, /*Threshold*/ std::nullopt, UnrollOpts.AllowPartial, UnrollOpts.AllowRuntime, UnrollOpts.AllowUpperBound, LocalAllowPeeling, - UnrollOpts.AllowProfileBasedPeeling, UnrollOpts.FullUnrollMaxCount); + UnrollOpts.AllowProfileBasedPeeling, UnrollOpts.FullUnrollMaxCount, + &AA); Changed |= Result != LoopUnrollResult::Unmodified; // The parent must not be damaged by unrolling! 
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp index f39c24484840..663715948241 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -582,7 +582,7 @@ PreservedAnalyses LoopVersioningLICMPass::run(Loop &L, LoopAnalysisManager &AM, const Function *F = L.getHeader()->getParent(); OptimizationRemarkEmitter ORE(F); - LoopAccessInfoManager LAIs(*SE, *AA, *DT, LAR.LI, nullptr); + LoopAccessInfoManager LAIs(*SE, *AA, *DT, LAR.LI, nullptr, nullptr); if (!LoopVersioningLICM(AA, SE, &ORE, LAIs, LAR.LI, &L).run(DT)) return PreservedAnalyses::all(); return getLoopPassPreservedAnalyses(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerAtomicPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerAtomicPass.cpp index 6aba913005d0..b42d3b2bc09a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerAtomicPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerAtomicPass.cpp @@ -20,7 +20,7 @@ #include "llvm/Transforms/Utils/LowerAtomic.h" using namespace llvm; -#define DEBUG_TYPE "loweratomic" +#define DEBUG_TYPE "lower-atomic" static bool LowerFenceInst(FenceInst *FI) { FI->eraseFromParent(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp index b167120a906d..bd7895feb64a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp @@ -85,8 +85,11 @@ static bool replaceConditionalBranchesOnConstant(Instruction *II, if (Target && Target != Other) { BasicBlock *Source = BI->getParent(); Other->removePredecessor(Source); + + Instruction *NewBI = BranchInst::Create(Target, Source); + NewBI->setDebugLoc(BI->getDebugLoc()); BI->eraseFromParent(); - BranchInst::Create(Target, Source); + if (DTU) DTU->applyUpdates({{DominatorTree::Delete, Source, Other}}); if (pred_empty(Other)) @@ -103,7 +106,7 @@ static bool lowerConstantIntrinsics(Function &F, const TargetLibraryInfo &TLI, DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy); bool HasDeadBlocks = false; - const auto &DL = F.getParent()->getDataLayout(); + const auto &DL = F.getDataLayout(); SmallVector<WeakTrackingVH, 8> Worklist; ReversePostOrderTraversal<Function *> RPOT(&F); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp index 6f87e4d91d2c..17c5a4ee1fd0 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -102,7 +102,7 @@ static bool handleSwitchExpect(SwitchInst &SI) { misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true); SI.setCondition(ArgValue); - setBranchWeights(SI, Weights); + setBranchWeights(SI, Weights, /*IsExpected=*/true); return true; } @@ -262,11 +262,13 @@ static void handlePhiDef(CallInst *Expect) { if (IsOpndComingFromSuccessor(BI->getSuccessor(1))) BI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(LikelyBranchWeightVal, - UnlikelyBranchWeightVal)); + UnlikelyBranchWeightVal, + /*IsExpected=*/true)); else if (IsOpndComingFromSuccessor(BI->getSuccessor(0))) BI->setMetadata(LLVMContext::MD_prof, 
MDB.createBranchWeights(UnlikelyBranchWeightVal, - LikelyBranchWeightVal)); + LikelyBranchWeightVal, + /*IsExpected=*/true)); } } @@ -331,12 +333,12 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) { SmallVector<uint32_t, 4> ExpectedWeights; if ((ExpectedValue->getZExtValue() == ValueComparedTo) == (Predicate == CmpInst::ICMP_EQ)) { - Node = - MDB.createBranchWeights(LikelyBranchWeightVal, UnlikelyBranchWeightVal); + Node = MDB.createBranchWeights( + LikelyBranchWeightVal, UnlikelyBranchWeightVal, /*IsExpected=*/true); ExpectedWeights = {LikelyBranchWeightVal, UnlikelyBranchWeightVal}; } else { - Node = - MDB.createBranchWeights(UnlikelyBranchWeightVal, LikelyBranchWeightVal); + Node = MDB.createBranchWeights(UnlikelyBranchWeightVal, + LikelyBranchWeightVal, /*IsExpected=*/true); ExpectedWeights = {UnlikelyBranchWeightVal, LikelyBranchWeightVal}; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 72b9db1e73d7..6a681fd93397 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -19,6 +19,7 @@ #include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -192,6 +193,109 @@ Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, return VecStart; } +namespace { +struct ShapeInfo { + unsigned NumRows; + unsigned NumColumns; + + bool IsColumnMajor; + + ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) + : NumRows(NumRows), NumColumns(NumColumns), + IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} + + ShapeInfo(Value *NumRows, Value *NumColumns) + : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), + cast<ConstantInt>(NumColumns)->getZExtValue()) {} + + bool operator==(const ShapeInfo &other) { + return NumRows == other.NumRows && NumColumns == other.NumColumns; + } + bool operator!=(const ShapeInfo &other) { return !(*this == other); } + + /// Returns true if shape-information is defined, meaning both dimensions + /// are != 0. + operator bool() const { + assert(NumRows == 0 || NumColumns != 0); + return NumRows != 0; + } + + unsigned getStride() const { + if (IsColumnMajor) + return NumRows; + return NumColumns; + } + + unsigned getNumVectors() const { + if (IsColumnMajor) + return NumColumns; + return NumRows; + } + + /// Returns the transposed shape. + ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); } +}; +} // namespace + +static bool isUniformShape(Value *V) { + Instruction *I = dyn_cast<Instruction>(V); + if (!I) + return true; + + switch (I->getOpcode()) { + case Instruction::FAdd: + case Instruction::FSub: + case Instruction::FMul: // Scalar multiply. + case Instruction::FNeg: + case Instruction::Add: + case Instruction::Mul: + case Instruction::Sub: + return true; + default: + return false; + } +} + +/// Return the ShapeInfo for the result of \p I, it it can be determined. 
+static std::optional<ShapeInfo> +computeShapeInfoForInst(Instruction *I, + const ValueMap<Value *, ShapeInfo> &ShapeMap) { + Value *M; + Value *N; + Value *K; + if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>( + m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K)))) + return ShapeInfo(M, K); + if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M), + m_Value(N)))) { + // Flip dimensions. + return ShapeInfo(N, M); + } + if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>( + m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M), + m_Value(N)))) + return ShapeInfo(N, M); + if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>( + m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N)))) + return ShapeInfo(M, N); + Value *MatrixA; + if (match(I, m_Store(m_Value(MatrixA), m_Value()))) { + auto OpShape = ShapeMap.find(MatrixA); + if (OpShape != ShapeMap.end()) + return OpShape->second; + } + + if (isUniformShape(I)) { + // Find the first operand that has a known shape and use that. + for (auto &Op : I->operands()) { + auto OpShape = ShapeMap.find(Op.get()); + if (OpShape != ShapeMap.end()) + return OpShape->second; + } + } + return std::nullopt; +} + /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. /// /// Currently, the lowering for each matrix intrinsic is done as follows: @@ -383,48 +487,6 @@ class LowerMatrixIntrinsics { } }; - struct ShapeInfo { - unsigned NumRows; - unsigned NumColumns; - - bool IsColumnMajor; - - ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) - : NumRows(NumRows), NumColumns(NumColumns), - IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} - - ShapeInfo(Value *NumRows, Value *NumColumns) - : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), - cast<ConstantInt>(NumColumns)->getZExtValue()) {} - - bool operator==(const ShapeInfo &other) { - return NumRows == other.NumRows && NumColumns == other.NumColumns; - } - bool operator!=(const ShapeInfo &other) { return !(*this == other); } - - /// Returns true if shape-information is defined, meaning both dimensions - /// are != 0. - operator bool() const { - assert(NumRows == 0 || NumColumns != 0); - return NumRows != 0; - } - - unsigned getStride() const { - if (IsColumnMajor) - return NumRows; - return NumColumns; - } - - unsigned getNumVectors() const { - if (IsColumnMajor) - return NumColumns; - return NumRows; - } - - /// Returns the transposed shape. - ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); } - }; - /// Maps instructions to their shape information. The shape information /// describes the shape to be used while lowering. This matches the shape of /// the result value of the instruction, with the only exceptions being store @@ -459,7 +521,7 @@ public: LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI, OptimizationRemarkEmitter *ORE) - : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), + : Func(F), DL(F.getDataLayout()), TTI(TTI), AA(AA), DT(DT), LI(LI), ORE(ORE) {} unsigned getNumOps(Type *VT) { @@ -554,25 +616,6 @@ public: return true; } - bool isUniformShape(Value *V) { - Instruction *I = dyn_cast<Instruction>(V); - if (!I) - return true; - - switch (I->getOpcode()) { - case Instruction::FAdd: - case Instruction::FSub: - case Instruction::FMul: // Scalar multiply. 
- case Instruction::FNeg: - case Instruction::Add: - case Instruction::Mul: - case Instruction::Sub: - return true; - default: - return false; - } - } - /// Returns true if shape information can be used for \p V. The supported /// instructions must match the instructions that can be lowered by this pass. bool supportsShapeInfo(Value *V) { @@ -610,43 +653,8 @@ public: // New entry, set the value and insert operands bool Propagate = false; - - Value *MatrixA; - Value *MatrixB; - Value *M; - Value *N; - Value *K; - if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>( - m_Value(MatrixA), m_Value(MatrixB), m_Value(M), - m_Value(N), m_Value(K)))) { - Propagate = setShapeInfo(Inst, {M, K}); - } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>( - m_Value(MatrixA), m_Value(M), m_Value(N)))) { - // Flip dimensions. - Propagate = setShapeInfo(Inst, {N, M}); - } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>( - m_Value(MatrixA), m_Value(), m_Value(), - m_Value(), m_Value(M), m_Value(N)))) { - Propagate = setShapeInfo(Inst, {N, M}); - } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( - m_Value(), m_Value(), m_Value(), m_Value(M), - m_Value(N)))) { - Propagate = setShapeInfo(Inst, {M, N}); - } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) { - auto OpShape = ShapeMap.find(MatrixA); - if (OpShape != ShapeMap.end()) - setShapeInfo(Inst, OpShape->second); - continue; - } else if (isUniformShape(Inst)) { - // Find the first operand that has a known shape and use that. - for (auto &Op : Inst->operands()) { - auto OpShape = ShapeMap.find(Op.get()); - if (OpShape != ShapeMap.end()) { - Propagate |= setShapeInfo(Inst, OpShape->second); - break; - } - } - } + if (auto SI = computeShapeInfoForInst(Inst, ShapeMap)) + Propagate = setShapeInfo(Inst, *SI); if (Propagate) { NewWorkList.push_back(Inst); @@ -891,20 +899,28 @@ public: updateShapeAndReplaceAllUsesWith(I, NewInst); CleanupBinOp(I, A, B); } - // A^t + B ^t -> (A + B)^t + // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If + // the shape of the second transpose is different, there's a shape conflict + // which gets resolved by picking the shape of the first operand. 
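For reference, the rewrite implemented just below leans on the usual transpose identity

    (A + B)^T = A^T + B^T

read right to left: a floating-point add of two transposed operands is re-emitted as a single transpose of their sum, so two matrix_transpose calls plus one fadd become one fadd plus one matrix_transpose.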
else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) && match(A, m_Intrinsic<Intrinsic::matrix_transpose>( m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) && match(B, m_Intrinsic<Intrinsic::matrix_transpose>( - m_Value(BT), m_ConstantInt(R), m_ConstantInt(C)))) { + m_Value(BT), m_ConstantInt(), m_ConstantInt()))) { IRBuilder<> Builder(&I); - Value *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd")); - setShapeInfo(Add, {C, R}); + auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd")); + setShapeInfo(Add, {R, C}); MatrixBuilder MBuilder(Builder); Instruction *NewInst = MBuilder.CreateMatrixTranspose( - Add, C->getZExtValue(), R->getZExtValue(), "mfadd_t"); + Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t"); updateShapeAndReplaceAllUsesWith(I, NewInst); + assert(computeShapeInfoForInst(NewInst, ShapeMap) == + computeShapeInfoForInst(&I, ShapeMap) && + "Shape of new instruction doesn't match original shape."); CleanupBinOp(I, A, B); + assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) == + ShapeMap[Add] && + "Shape of updated addition doesn't match cached shape."); } } @@ -975,12 +991,15 @@ public: bool Changed = false; SmallVector<CallInst *, 16> MaybeFusableInsts; SmallVector<Instruction *, 16> MatrixInsts; + SmallVector<IntrinsicInst *, 16> LifetimeEnds; // First, collect all instructions with shape information and candidates for // fusion (currently only matrix multiplies). ReversePostOrderTraversal<Function *> RPOT(&Func); for (auto *BB : RPOT) for (Instruction &I : *BB) { + if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>())) + LifetimeEnds.push_back(cast<IntrinsicInst>(&I)); if (ShapeMap.find(&I) == ShapeMap.end()) continue; if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) @@ -995,7 +1014,7 @@ public: // Third, try to fuse candidates. for (CallInst *CI : MaybeFusableInsts) - LowerMatrixMultiplyFused(CI, FusedInsts); + LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds); Changed = !FusedInsts.empty(); @@ -1332,8 +1351,8 @@ public: if (!IsIntVec && !FMF.allowReassoc()) return; - auto CanBeFlattened = [this](Value *Op) { - if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) + auto CanBeFlattened = [](Value *Op) { + if (match(Op, m_BinOp())) return true; return match( Op, m_OneUse(m_CombineOr( @@ -1346,6 +1365,9 @@ public: // the returned cost is < 0, the argument is cheaper to use in the // dot-product lowering. auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) { + if (ShapeMap.find(Op) == ShapeMap.end()) + return InstructionCost::getInvalid(); + if (!isa<Instruction>(Op)) return InstructionCost(0); @@ -1356,7 +1378,7 @@ public: InstructionCost EmbedCost(0); // Roughly estimate the cost for embedding the columns into a vector. for (unsigned I = 1; I < N; ++I) - EmbedCost -= + EmbedCost += TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), std::nullopt, TTI::TCK_RecipThroughput); return EmbedCost; @@ -1378,7 +1400,7 @@ public: // vector. InstructionCost EmbedCost(0); for (unsigned I = 1; I < N; ++I) - EmbedCost += + EmbedCost -= TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1), std::nullopt, TTI::TCK_RecipThroughput); return EmbedCost; @@ -1391,7 +1413,29 @@ public: return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) - N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0); }; - auto LHSCost = GetCostForArg(LHS, LShape.NumColumns); + + // Iterate over LHS and operations feeding LHS and check if it is profitable + // to flatten the visited ops. 
For each op, we compute the difference + // between the flattened and matrix versions. + SmallPtrSet<Value *, 4> Seen; + SmallVector<Value *> WorkList; + SmallVector<Value *> ToFlatten; + WorkList.push_back(LHS); + InstructionCost LHSCost(0); + while (!WorkList.empty()) { + Value *Op = WorkList.pop_back_val(); + if (!Seen.insert(Op).second) + continue; + + InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns); + if (OpCost + LHSCost >= LHSCost) + continue; + + LHSCost += OpCost; + ToFlatten.push_back(Op); + if (auto *I = dyn_cast<Instruction>(Op)) + WorkList.append(I->op_begin(), I->op_end()); + } // We compare the costs of a vector.reduce.add to sequential add. int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd; @@ -1412,16 +1456,16 @@ public: FusedInsts.insert(MatMul); IRBuilder<> Builder(MatMul); auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened, - this](Value *Op) -> Value * { + this](Value *Op) { // Matmul must be the only user of loads because we don't use LowerLoad // for row vectors (LowerLoad results in scalar loads and shufflevectors // instead of single vector load). if (!CanBeFlattened(Op)) - return Op; + return; if (match(Op, m_BinOp()) && ShapeMap.find(Op) != ShapeMap.end()) { ShapeMap[Op] = ShapeMap[Op].t(); - return Op; + return; } FusedInsts.insert(cast<Instruction>(Op)); @@ -1432,16 +1476,19 @@ public: auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg); Op->replaceAllUsesWith(NewLoad); cast<Instruction>(Op)->eraseFromParent(); - return NewLoad; + return; } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>( m_Value(Arg)))) { ToRemove.push_back(cast<Instruction>(Op)); - return Arg; + Op->replaceAllUsesWith(Arg); + return; } - - return Op; }; - LHS = FlattenArg(LHS); + + for (auto *V : ToFlatten) + FlattenArg(V); + + LHS = MatMul->getArgOperand(0); // Insert mul/fmul and llvm.vector.reduce.fadd Value *Mul = @@ -1594,7 +1641,7 @@ public: IRBuilder<> Builder(MatMul); Check0->getTerminator()->eraseFromParent(); Builder.SetInsertPoint(Check0); - Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout()); + Type *IntPtrTy = Builder.getIntPtrTy(Load->getDataLayout()); Value *StoreBegin = Builder.CreatePtrToInt( const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); Value *StoreEnd = Builder.CreateAdd( @@ -1813,8 +1860,10 @@ public: /// /// Call finalizeLowering on lowered instructions. Instructions that are /// completely eliminated by fusion are added to \p FusedInsts. - void LowerMatrixMultiplyFused(CallInst *MatMul, - SmallPtrSetImpl<Instruction *> &FusedInsts) { + void + LowerMatrixMultiplyFused(CallInst *MatMul, + SmallPtrSetImpl<Instruction *> &FusedInsts, + SmallVector<IntrinsicInst *, 16> &LifetimeEnds) { if (!FuseMatrix || !DT) return; @@ -1903,6 +1952,55 @@ public: for (Instruction *I : ToHoist) I->moveBefore(MatMul); + // Deal with lifetime.end calls that might be between Load0/Load1 and the + // store. To avoid introducing loads to dead objects (i.e. after the + // lifetime has been termined by @llvm.lifetime.end), either sink them + // after the store if in the same block, or remove the lifetime.end marker + // otherwise. This might pessimize further optimizations, by extending the + // lifetime of the object until the function returns, but should be + // conservatively correct. 
+ MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0); + MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1); + BasicBlock *StoreParent = Store->getParent(); + bool FusableOpsInSameBlock = LoadOp0->getParent() == StoreParent && + LoadOp1->getParent() == StoreParent; + for (unsigned Idx = 0; Idx != LifetimeEnds.size();) { + IntrinsicInst *End = LifetimeEnds[Idx]; + auto Inc = make_scope_exit([&Idx]() { Idx++; }); + // If the lifetime.end is guaranteed to be before the loads or after the + // store, it won't interfere with fusion. + if (DT->dominates(End, LoadOp0) && DT->dominates(End, LoadOp1)) + continue; + if (DT->dominates(Store, End)) + continue; + // If all fusable ops are in the same block and the lifetime.end is in a + // different block, it won't interfere with fusion. + if (FusableOpsInSameBlock && End->getParent() != StoreParent) + continue; + + // If the loads don't alias the lifetime.end, it won't interfere with + // fusion. + MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr); + if (!EndLoc.Ptr) + continue; + if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc)) + continue; + + // If both lifetime.end and the store are in the same block, extend the + // lifetime until after the store, so the new lifetime covers the loads + // we introduce later. + if (End->getParent() == StoreParent) { + End->moveAfter(Store); + continue; + } + + // Otherwise remove the conflicting lifetime.end marker. + ToRemove.push_back(End); + std::swap(LifetimeEnds[Idx], LifetimeEnds.back()); + LifetimeEnds.pop_back(); + Inc.release(); + } + emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); return; } @@ -2364,7 +2462,7 @@ public: RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix, OptimizationRemarkEmitter &ORE, Function &Func) : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func), - DL(Func.getParent()->getDataLayout()) {} + DL(Func.getDataLayout()) {} /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are /// instructions in Inst2Matrix returning void or without any users in diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp index 78e474f925b5..aea17aa82a88 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MakeGuardsExplicit.cpp @@ -36,6 +36,7 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Transforms/Utils/GuardUtils.h" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 805bbe40bd7c..cee34f0a6da1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/Scalar/MemCpyOptimizer.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/iterator_range.h" @@ -99,7 +100,7 @@ struct MemsetRange { MaybeAlign Alignment; /// TheStores - The actual stores that make up this range. 
- SmallVector<Instruction*, 16> TheStores; + SmallVector<Instruction *, 16> TheStores; bool isProfitableToUseMemset(const DataLayout &DL) const; }; @@ -108,10 +109,12 @@ struct MemsetRange { bool MemsetRange::isProfitableToUseMemset(const DataLayout &DL) const { // If we found more than 4 stores to merge or 16 bytes, use memset. - if (TheStores.size() >= 4 || End-Start >= 16) return true; + if (TheStores.size() >= 4 || End - Start >= 16) + return true; // If there is nothing to merge, don't do anything. - if (TheStores.size() < 2) return false; + if (TheStores.size() < 2) + return false; // If any of the stores are a memset, then it is always good to extend the // memset. @@ -121,7 +124,8 @@ bool MemsetRange::isProfitableToUseMemset(const DataLayout &DL) const { // Assume that the code generator is capable of merging pairs of stores // together if it wants to. - if (TheStores.size() == 2) return false; + if (TheStores.size() == 2) + return false; // If we have fewer than 8 stores, it can still be worthwhile to do this. // For example, merging 4 i8 stores into an i32 store is useful almost always. @@ -133,7 +137,7 @@ bool MemsetRange::isProfitableToUseMemset(const DataLayout &DL) const { // the maximum GPR width is the same size as the largest legal integer // size. If so, check to see whether we will end up actually reducing the // number of stores used. - unsigned Bytes = unsigned(End-Start); + unsigned Bytes = unsigned(End - Start); unsigned MaxIntSize = DL.getLargestLegalIntTypeSizeInBits() / 8; if (MaxIntSize == 0) MaxIntSize = 1; @@ -145,7 +149,7 @@ bool MemsetRange::isProfitableToUseMemset(const DataLayout &DL) const { // If we will reduce the # stores (according to this heuristic), do the // transformation. This encourages merging 4 x i8 -> i32 and 2 x i16 -> i32 // etc. - return TheStores.size() > NumPointerStores+NumByteStores; + return TheStores.size() > NumPointerStores + NumByteStores; } namespace { @@ -197,7 +201,7 @@ public: /// existing ranges as appropriate. void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr, MaybeAlign Alignment, Instruction *Inst) { - int64_t End = Start+Size; + int64_t End = Start + Size; range_iterator I = partition_point( Ranges, [=](const MemsetRange &O) { return O.End < Start; }); @@ -207,10 +211,10 @@ void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr, // to insert a new range. Handle this now. if (I == Ranges.end() || End < I->Start) { MemsetRange &R = *Ranges.insert(I, MemsetRange()); - R.Start = Start; - R.End = End; - R.StartPtr = Ptr; - R.Alignment = Alignment; + R.Start = Start; + R.End = End; + R.StartPtr = Ptr; + R.Alignment = Alignment; R.TheStores.push_back(Inst); return; } @@ -354,7 +358,7 @@ static void combineAAMetadata(Instruction *ReplInst, Instruction *I) { Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, Value *StartPtr, Value *ByteVal) { - const DataLayout &DL = StartInst->getModule()->getDataLayout(); + const DataLayout &DL = StartInst->getDataLayout(); // We can't track scalable types if (auto *SI = dyn_cast<StoreInst>(StartInst)) @@ -397,7 +401,8 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, if (auto *NextStore = dyn_cast<StoreInst>(BI)) { // If this is a store, see if we can merge it in. 
- if (!NextStore->isSimple()) break; + if (!NextStore->isSimple()) + break; Value *StoredVal = NextStore->getValueOperand(); @@ -460,7 +465,8 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, // emit memset's for anything big enough to be worthwhile. Instruction *AMemSet = nullptr; for (const MemsetRange &Range : Ranges) { - if (Range.TheStores.size() == 1) continue; + if (Range.TheStores.size() == 1) + continue; // If it is profitable to lower this range to memset, do so now. if (!Range.isProfitableToUseMemset(DL)) @@ -481,12 +487,10 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, if (!Range.TheStores.empty()) AMemSet->setDebugLoc(Range.TheStores[0]->getDebugLoc()); - auto *NewDef = - cast<MemoryDef>(MemInsertPoint->getMemoryInst() == &*BI - ? MSSAU->createMemoryAccessBefore( - AMemSet, nullptr, MemInsertPoint) - : MSSAU->createMemoryAccessAfter( - AMemSet, nullptr, MemInsertPoint)); + auto *NewDef = cast<MemoryDef>( + MemInsertPoint->getMemoryInst() == &*BI + ? MSSAU->createMemoryAccessBefore(AMemSet, nullptr, MemInsertPoint) + : MSSAU->createMemoryAccessAfter(AMemSet, nullptr, MemInsertPoint)); MSSAU->insertDef(NewDef, /*RenameUses=*/true); MemInsertPoint = NewDef; @@ -512,12 +516,13 @@ bool MemCpyOptPass::moveUp(StoreInst *SI, Instruction *P, const LoadInst *LI) { // Keep track of the arguments of all instruction we plan to lift // so we can make sure to lift them as well if appropriate. - DenseSet<Instruction*> Args; + DenseSet<Instruction *> Args; auto AddArg = [&](Value *Arg) { auto *I = dyn_cast<Instruction>(Arg); if (I && I->getParent() == SI->getParent()) { // Cannot hoist user of P above P - if (I == P) return false; + if (I == P) + return false; Args.insert(I); } return true; @@ -630,8 +635,7 @@ bool MemCpyOptPass::moveUp(StoreInst *SI, Instruction *P, const LoadInst *LI) { bool MemCpyOptPass::processStoreOfLoad(StoreInst *SI, LoadInst *LI, const DataLayout &DL, BasicBlock::iterator &BBI) { - if (!LI->isSimple() || !LI->hasOneUse() || - LI->getParent() != SI->getParent()) + if (!LI->isSimple() || !LI->hasOneUse() || LI->getParent() != SI->getParent()) return false; auto *T = LI->getType(); @@ -677,22 +681,21 @@ bool MemCpyOptPass::processStoreOfLoad(StoreInst *SI, LoadInst *LI, if (isModSet(AA->getModRefInfo(SI, LoadLoc))) UseMemMove = true; - uint64_t Size = DL.getTypeStoreSize(T); - IRBuilder<> Builder(P); + Value *Size = + Builder.CreateTypeSize(Builder.getInt64Ty(), DL.getTypeStoreSize(T)); Instruction *M; if (UseMemMove) - M = Builder.CreateMemMove( - SI->getPointerOperand(), SI->getAlign(), - LI->getPointerOperand(), LI->getAlign(), Size); + M = Builder.CreateMemMove(SI->getPointerOperand(), SI->getAlign(), + LI->getPointerOperand(), LI->getAlign(), + Size); else - M = Builder.CreateMemCpy( - SI->getPointerOperand(), SI->getAlign(), - LI->getPointerOperand(), LI->getAlign(), Size); + M = Builder.CreateMemCpy(SI->getPointerOperand(), SI->getAlign(), + LI->getPointerOperand(), LI->getAlign(), Size); M->copyMetadata(*SI, LLVMContext::MD_DIAssignID); - LLVM_DEBUG(dbgs() << "Promoting " << *LI << " to " << *SI << " => " - << *M << "\n"); + LLVM_DEBUG(dbgs() << "Promoting " << *LI << " to " << *SI << " => " << *M + << "\n"); auto *LastDef = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(SI)); @@ -755,7 +758,8 @@ bool MemCpyOptPass::processStoreOfLoad(StoreInst *SI, LoadInst *LI, } bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { - if (!SI->isSimple()) return false; + if (!SI->isSimple()) + 
return false; // Avoid merging nontemporal stores since the resulting // memcpy/memset would not be able to preserve the nontemporal hint. @@ -766,7 +770,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { if (SI->getMetadata(LLVMContext::MD_nontemporal)) return false; - const DataLayout &DL = SI->getModule()->getDataLayout(); + const DataLayout &DL = SI->getDataLayout(); Value *StoredVal = SI->getValueOperand(); @@ -794,8 +798,8 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { // 0xA0A0A0A0 and 0.0. auto *V = SI->getOperand(0); if (Value *ByteVal = isBytewiseValue(V, DL)) { - if (Instruction *I = tryMergingIntoMemset(SI, SI->getPointerOperand(), - ByteVal)) { + if (Instruction *I = + tryMergingIntoMemset(SI, SI->getPointerOperand(), ByteVal)) { BBI = I->getIterator(); // Don't invalidate iterator. return true; } @@ -816,8 +820,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { // The newly inserted memset is immediately overwritten by the original // store, so we do not need to rename uses. auto *StoreDef = cast<MemoryDef>(MSSA->getMemoryAccess(SI)); - auto *NewAccess = MSSAU->createMemoryAccessBefore( - M, nullptr, StoreDef); + auto *NewAccess = MSSAU->createMemoryAccessBefore(M, nullptr, StoreDef); MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/false); eraseInstruction(SI); @@ -836,8 +839,8 @@ bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) { // See if there is another memset or store neighboring this memset which // allows us to widen out the memset to do a single larger store. if (isa<ConstantInt>(MSI->getLength()) && !MSI->isVolatile()) - if (Instruction *I = tryMergingIntoMemset(MSI, MSI->getDest(), - MSI->getValue())) { + if (Instruction *I = + tryMergingIntoMemset(MSI, MSI->getDest(), MSI->getValue())) { BBI = I->getIterator(); // Don't invalidate iterator. return true; } @@ -850,7 +853,8 @@ bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) { bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, Instruction *cpyStore, Value *cpyDest, Value *cpySrc, TypeSize cpySize, - Align cpyDestAlign, BatchAAResults &BAA, + Align cpyDestAlign, + BatchAAResults &BAA, std::function<CallInst *()> GetC) { // The general transformation to keep in mind is // @@ -879,7 +883,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, if (!srcArraySize) return false; - const DataLayout &DL = cpyLoad->getModule()->getDataLayout(); + const DataLayout &DL = cpyLoad->getDataLayout(); TypeSize SrcAllocaSize = DL.getTypeAllocSize(srcAlloca->getAllocatedType()); // We can't optimize scalable types. if (SrcAllocaSize.isScalable()) @@ -898,15 +902,15 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, if (F->isIntrinsic() && F->getIntrinsicID() == Intrinsic::lifetime_start) return false; - if (C->getParent() != cpyStore->getParent()) { LLVM_DEBUG(dbgs() << "Call Slot: block local restriction\n"); return false; } - MemoryLocation DestLoc = isa<StoreInst>(cpyStore) ? - MemoryLocation::get(cpyStore) : - MemoryLocation::getForDest(cast<MemCpyInst>(cpyStore)); + MemoryLocation DestLoc = + isa<StoreInst>(cpyStore) + ? MemoryLocation::get(cpyStore) + : MemoryLocation::getForDest(cast<MemCpyInst>(cpyStore)); // Check that nothing touches the dest of the copy between // the call and the store/memcpy. 
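For context, a rough source-level picture of the call-slot transformation this function guards (the helper name and buffer size are hypothetical): a call that fills a temporary which is then copied wholesale into the real destination can be redirected to write into the destination directly, provided the destination is large enough and nothing touches or aliases it between the call and the copy.

#include <cstring>

void produce(char *out); // assumed to write exactly 64 bytes and not capture
                         // or retain the pointer

void beforeCallSlot(char *dest) {
  char tmp[64];
  produce(tmp);
  std::memcpy(dest, tmp, sizeof(tmp)); // copy of the whole temporary
}

void afterCallSlot(char *dest) {
  produce(dest); // the temporary and the memcpy are gone
}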
@@ -980,10 +984,8 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, append_range(srcUseList, U->users()); continue; } - if (const auto *G = dyn_cast<GetElementPtrInst>(U)) { - if (!G->hasAllZeroIndices()) - return false; - + if (const auto *G = dyn_cast<GetElementPtrInst>(U); + G && G->hasAllZeroIndices()) { append_range(srcUseList, U->users()); continue; } @@ -991,8 +993,10 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, if (IT->isLifetimeStartOrEnd()) continue; - if (U != C && U != cpyLoad) + if (U != C && U != cpyLoad) { + LLVM_DEBUG(dbgs() << "Call slot: Source accessed by " << *U << "\n"); return false; + } } // Check whether src is captured by the called function, in which case there @@ -1121,28 +1125,79 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpyLoad, bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep, BatchAAResults &BAA) { - // We can only transforms memcpy's where the dest of one is the source of the - // other. - if (M->getSource() != MDep->getDest() || MDep->isVolatile()) - return false; - // If dep instruction is reading from our current input, then it is a noop - // transfer and substituting the input won't change this instruction. Just - // ignore the input and let someone else zap MDep. This handles cases like: + // transfer and substituting the input won't change this instruction. Just + // ignore the input and let someone else zap MDep. This handles cases like: // memcpy(a <- a) // memcpy(b <- a) if (M->getSource() == MDep->getSource()) return false; - // Second, the length of the memcpy's must be the same, or the preceding one + // We can only optimize non-volatile memcpy's. + if (MDep->isVolatile()) + return false; + + int64_t MForwardOffset = 0; + const DataLayout &DL = M->getModule()->getDataLayout(); + // We can only transforms memcpy's where the dest of one is the source of the + // other, or they have an offset in a range. + if (M->getSource() != MDep->getDest()) { + std::optional<int64_t> Offset = + M->getSource()->getPointerOffsetFrom(MDep->getDest(), DL); + if (!Offset || *Offset < 0) + return false; + MForwardOffset = *Offset; + } + + // The length of the memcpy's must be the same, or the preceding one // must be larger than the following one. - if (MDep->getLength() != M->getLength()) { + if (MForwardOffset != 0 || MDep->getLength() != M->getLength()) { auto *MDepLen = dyn_cast<ConstantInt>(MDep->getLength()); auto *MLen = dyn_cast<ConstantInt>(M->getLength()); - if (!MDepLen || !MLen || MDepLen->getZExtValue() < MLen->getZExtValue()) + if (!MDepLen || !MLen || + MDepLen->getZExtValue() < MLen->getZExtValue() + MForwardOffset) return false; } + IRBuilder<> Builder(M); + auto *CopySource = MDep->getSource(); + Instruction *NewCopySource = nullptr; + auto CleanupOnRet = llvm::make_scope_exit([&NewCopySource] { + if (NewCopySource && NewCopySource->use_empty()) + // Safety: It's safe here because we will only allocate more instructions + // after finishing all BatchAA queries, but we have to be careful if we + // want to do something like this in another place. Then we'd probably + // have to delay instruction removal until all transforms on an + // instruction finished. + NewCopySource->eraseFromParent(); + }); + MaybeAlign CopySourceAlign = MDep->getSourceAlign(); + // We just need to calculate the actual size of the copy. 
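
[Editor's illustration, not part of the patch] The source-use screening above feeds the call-slot optimization; a rough sketch of the shape it looks for, under assumptions of my own (the helper fill and the 64-byte size are invented for illustration).

#include <cstring>

void fill(char *out) { std::memset(out, 'x', 64); } // stands in for any callee that writes its argument

void copy_out(char *dest) {
  char tmp[64];
  fill(tmp);                   // callee writes a local temporary...
  std::memcpy(dest, tmp, 64);  // ...which is then copied to the real destination.
  // Call-slot optimization can redirect the callee to write 'dest' directly
  // and leave the copy dead, subject to the alias, capture and lifetime
  // checks performed above.
}
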
+ auto MCopyLoc = MemoryLocation::getForSource(MDep).getWithNewSize( + MemoryLocation::getForSource(M).Size); + + // When the forwarding offset is greater than 0, we transform + // memcpy(d1 <- s1) + // memcpy(d2 <- d1+o) + // to + // memcpy(d2 <- s1+o) + if (MForwardOffset > 0) { + // The copy destination of `M` maybe can serve as the source of copying. + std::optional<int64_t> MDestOffset = + M->getRawDest()->getPointerOffsetFrom(MDep->getRawSource(), DL); + if (MDestOffset == MForwardOffset) + CopySource = M->getDest(); + else { + CopySource = Builder.CreateInBoundsPtrAdd( + CopySource, Builder.getInt64(MForwardOffset)); + NewCopySource = dyn_cast<Instruction>(CopySource); + } + // We need to update `MCopyLoc` if an offset exists. + MCopyLoc = MCopyLoc.getWithNewPtr(CopySource); + if (CopySourceAlign) + CopySourceAlign = commonAlignment(*CopySourceAlign, MForwardOffset); + } + // Verify that the copied-from memory doesn't change in between the two // transfers. For example, in: // memcpy(a <- b) @@ -1152,12 +1207,18 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, // // TODO: If the code between M and MDep is transparent to the destination "c", // then we could still perform the xform by moving M up to the first memcpy. - // TODO: It would be sufficient to check the MDep source up to the memcpy - // size of M, rather than MDep. - if (writtenBetween(MSSA, BAA, MemoryLocation::getForSource(MDep), - MSSA->getMemoryAccess(MDep), MSSA->getMemoryAccess(M))) + if (writtenBetween(MSSA, BAA, MCopyLoc, MSSA->getMemoryAccess(MDep), + MSSA->getMemoryAccess(M))) return false; + // No need to create `memcpy(a <- a)`. + if (BAA.isMustAlias(M->getDest(), CopySource)) { + // Remove the instruction we're replacing. + eraseInstruction(M); + ++NumMemCpyInstr; + return true; + } + // If the dest of the second might alias the source of the first, then the // source and dest might overlap. In addition, if the source of the first // points to constant memory, they won't overlap by definition. Otherwise, we @@ -1175,27 +1236,27 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, // If all checks passed, then we can transform M. LLVM_DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy->memcpy src:\n" - << *MDep << '\n' << *M << '\n'); + << *MDep << '\n' + << *M << '\n'); // TODO: Is this worth it if we're creating a less aligned memcpy? For // example we could be moving from movaps -> movq on x86. - IRBuilder<> Builder(M); Instruction *NewM; if (UseMemMove) - NewM = Builder.CreateMemMove(M->getRawDest(), M->getDestAlign(), - MDep->getRawSource(), MDep->getSourceAlign(), - M->getLength(), M->isVolatile()); + NewM = + Builder.CreateMemMove(M->getDest(), M->getDestAlign(), CopySource, + CopySourceAlign, M->getLength(), M->isVolatile()); else if (isa<MemCpyInlineInst>(M)) { // llvm.memcpy may be promoted to llvm.memcpy.inline, but the converse is // never allowed since that would allow the latter to be lowered as a call // to an external function. 
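
[Editor's illustration, not part of the patch] A small sketch of the offset-forwarding case described in the comment above; buffer names and sizes are hypothetical.

#include <cstring>

void forward(char *d2, const char *s1) {
  char d1[64];
  std::memcpy(d1, s1, 64);       // memcpy(d1 <- s1)
  std::memcpy(d2, d1 + 16, 32);  // memcpy(d2 <- d1 + 16)
  // With the new MForwardOffset handling, the second copy can read from the
  // original source instead: memcpy(d2 <- s1 + 16), provided the first copy
  // is large enough (len(MDep) >= len(M) + offset) and the copied-from bytes
  // are not clobbered in between.
}
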
- NewM = Builder.CreateMemCpyInline( - M->getRawDest(), M->getDestAlign(), MDep->getRawSource(), - MDep->getSourceAlign(), M->getLength(), M->isVolatile()); + NewM = Builder.CreateMemCpyInline(M->getDest(), M->getDestAlign(), + CopySource, CopySourceAlign, + M->getLength(), M->isVolatile()); } else - NewM = Builder.CreateMemCpy(M->getRawDest(), M->getDestAlign(), - MDep->getRawSource(), MDep->getSourceAlign(), - M->getLength(), M->isVolatile()); + NewM = + Builder.CreateMemCpy(M->getDest(), M->getDestAlign(), CopySource, + CopySourceAlign, M->getLength(), M->isVolatile()); NewM->copyMetadata(*M, LLVMContext::MD_DIAssignID); assert(isa<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(M))); @@ -1235,6 +1296,15 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, if (!BAA.isMustAlias(MemSet->getDest(), MemCpy->getDest())) return false; + // Don't perform the transform if src_size may be zero. In that case, the + // transform is essentially a complex no-op and may lead to an infinite + // loop if BasicAA is smart enough to understand that dst and dst + src_size + // are still MustAlias after the transform. + Value *SrcSize = MemCpy->getLength(); + if (!isKnownNonZero(SrcSize, + SimplifyQuery(MemCpy->getDataLayout(), DT, AC, MemCpy))) + return false; + // Check that src and dst of the memcpy aren't the same. While memcpy // operands cannot partially overlap, exact equality is allowed. if (isModSet(BAA.getModRefInfo(MemCpy, MemoryLocation::getForSource(MemCpy)))) @@ -1251,7 +1321,6 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, // Use the same i8* dest as the memcpy, killing the memset dest if different. Value *Dest = MemCpy->getRawDest(); Value *DestSize = MemSet->getLength(); - Value *SrcSize = MemCpy->getLength(); if (mayBeVisibleThroughUnwinding(Dest, MemSet, MemCpy)) return false; @@ -1307,8 +1376,8 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, // memcpy's defining access is the memset about to be removed. auto *LastDef = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(MemCpy)); - auto *NewAccess = MSSAU->createMemoryAccessBefore( - NewMemSet, nullptr, LastDef); + auto *NewAccess = + MSSAU->createMemoryAccessBefore(NewMemSet, nullptr, LastDef); MSSAU->insertDef(cast<MemoryDef>(NewAccess), /*RenameUses=*/true); eraseInstruction(MemSet); @@ -1338,7 +1407,7 @@ static bool hasUndefContents(MemorySSA *MSSA, BatchAAResults &AA, Value *V, // The size also doesn't matter, as an out-of-bounds access would be UB. if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(V))) { if (getUnderlyingObject(II->getArgOperand(1)) == Alloca) { - const DataLayout &DL = Alloca->getModule()->getDataLayout(); + const DataLayout &DL = Alloca->getDataLayout(); if (std::optional<TypeSize> AllocaSize = Alloca->getAllocationSize(DL)) if (*AllocaSize == LTSize->getValue()) @@ -1384,7 +1453,7 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, return false; // A known memcpy size is also required. - auto *CCopySize = dyn_cast<ConstantInt>(CopySize); + auto *CCopySize = dyn_cast<ConstantInt>(CopySize); if (!CCopySize) return false; if (CCopySize->getZExtValue() > CMemSetSize->getZExtValue()) { @@ -1445,7 +1514,7 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store, } // Check that copy is full with static size. 
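
[Editor's illustration, not part of the patch] For the memset/memcpy dependence handled earlier in this hunk, a hedged sketch of the rewrite; names and sizes are illustrative, and the guard added above additionally requires src_size to be known non-zero.

#include <cstring>

void init(char *dst, const char *src) {
  std::memset(dst, 0, 128);   // memset(dst, c, dst_size)
  std::memcpy(dst, src, 40);  // memcpy(dst, src, src_size), src_size <= dst_size
  // processMemSetMemCpyDependence keeps the memcpy and shrinks the memset so
  // it only initializes the tail the copy did not write:
  //   memcpy(dst, src, 40);
  //   memset(dst + 40, 0, 128 - 40);
}
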
- const DataLayout &DL = DestAlloca->getModule()->getDataLayout(); + const DataLayout &DL = DestAlloca->getDataLayout(); std::optional<TypeSize> SrcSize = SrcAlloca->getAllocationSize(DL); if (!SrcSize || Size != *SrcSize) { LLVM_DEBUG(dbgs() << "Stack Move: Source alloca size mismatch\n"); @@ -1640,7 +1709,7 @@ bool MemCpyOptPass::performStackMoveOptzn(Instruction *Load, Instruction *Store, static bool isZeroSize(Value *Size) { if (auto *I = dyn_cast<Instruction>(Size)) - if (auto *Res = simplifyInstruction(I, I->getModule()->getDataLayout())) + if (auto *Res = simplifyInstruction(I, I->getDataLayout())) Size = Res; // Treat undef/poison size like zero. if (auto *C = dyn_cast<Constant>(Size)) @@ -1655,7 +1724,8 @@ static bool isZeroSize(Value *Size) { /// altogether. bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { // We can only optimize non-volatile memcpy's. - if (M->isVolatile()) return false; + if (M->isVolatile()) + return false; // If the source and destination of the memcpy are the same, then zap it. if (M->getSource() == M->getDest()) { @@ -1664,8 +1734,7 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { return true; } - // If the size is zero, remove the memcpy. This also prevents infinite loops - // in processMemSetMemCpyDependence, which is a no-op for zero-length memcpys. + // If the size is zero, remove the memcpy. if (isZeroSize(M->getLength())) { ++BBI; eraseInstruction(M); @@ -1681,7 +1750,7 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) { if (auto *GV = dyn_cast<GlobalVariable>(M->getSource())) if (GV->isConstant() && GV->hasDefinitiveInitializer()) if (Value *ByteVal = isBytewiseValue(GV->getInitializer(), - M->getModule()->getDataLayout())) { + M->getDataLayout())) { IRBuilder<> Builder(M); Instruction *NewM = Builder.CreateMemSet( M->getRawDest(), ByteVal, M->getLength(), M->getDestAlign(), false); @@ -1796,11 +1865,10 @@ bool MemCpyOptPass::processMemMove(MemMoveInst *M) { << "\n"); // If not, then we know we can transform this. - Type *ArgTys[3] = { M->getRawDest()->getType(), - M->getRawSource()->getType(), - M->getLength()->getType() }; - M->setCalledFunction(Intrinsic::getDeclaration(M->getModule(), - Intrinsic::memcpy, ArgTys)); + Type *ArgTys[3] = {M->getRawDest()->getType(), M->getRawSource()->getType(), + M->getLength()->getType()}; + M->setCalledFunction( + Intrinsic::getDeclaration(M->getModule(), Intrinsic::memcpy, ArgTys)); // For MemorySSA nothing really changes (except that memcpy may imply stricter // aliasing guarantees). @@ -1811,7 +1879,7 @@ bool MemCpyOptPass::processMemMove(MemMoveInst *M) { /// This is called on every byval argument in call sites. bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) { - const DataLayout &DL = CB.getCaller()->getParent()->getDataLayout(); + const DataLayout &DL = CB.getDataLayout(); // Find out what feeds this byval argument. Value *ByValArg = CB.getArgOperand(ArgNo); Type *ByValTy = CB.getParamByValType(ArgNo); @@ -1843,7 +1911,8 @@ bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) { // Get the alignment of the byval. If the call doesn't specify the alignment, // then it is some target specific value that we can't know. MaybeAlign ByValAlign = CB.getParamAlign(ArgNo); - if (!ByValAlign) return false; + if (!ByValAlign) + return false; // If it is greater than the memcpy, then we check to see if we can force the // source of the memcpy to the alignment we need. If we fail, we bail out. 
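
[Editor's illustration, not part of the patch] The processMemCpy hunk above also forwards copies from bytewise-constant globals; a minimal sketch (the global kFill and its size are assumptions).

#include <cstring>

static const char kFill[32] = {}; // constant global with a bytewise-zero initializer

void fill_from_global(char *dst) {
  std::memcpy(dst, kFill, sizeof kFill);
  // Because isBytewiseValue on the initializer yields the byte 0, the copy
  // can be rewritten as: std::memset(dst, 0, sizeof kFill);
}
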
@@ -1897,7 +1966,7 @@ bool MemCpyOptPass::processImmutArgument(CallBase &CB, unsigned ArgNo) { if (!(CB.paramHasAttr(ArgNo, Attribute::NoAlias) && CB.paramHasAttr(ArgNo, Attribute::NoCapture))) return false; - const DataLayout &DL = CB.getCaller()->getParent()->getDataLayout(); + const DataLayout &DL = CB.getDataLayout(); Value *ImmutArg = CB.getArgOperand(ArgNo); // 2. Check that arg is alloca @@ -1987,7 +2056,7 @@ bool MemCpyOptPass::iterateOnFunction(Function &F) { continue; for (BasicBlock::iterator BI = BB.begin(), BE = BB.end(); BI != BE;) { - // Avoid invalidating the iterator. + // Avoid invalidating the iterator. Instruction *I = &*BI++; bool RepeatInstruction = false; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergeICmps.cpp index 1e0906717549..4291f3aee0cd 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergeICmps.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergeICmps.cpp @@ -74,7 +74,7 @@ namespace { struct BCEAtom { BCEAtom() = default; BCEAtom(GetElementPtrInst *GEP, LoadInst *LoadI, int BaseId, APInt Offset) - : GEP(GEP), LoadI(LoadI), BaseId(BaseId), Offset(Offset) {} + : GEP(GEP), LoadI(LoadI), BaseId(BaseId), Offset(std::move(Offset)) {} BCEAtom(const BCEAtom &) = delete; BCEAtom &operator=(const BCEAtom &) = delete; @@ -151,7 +151,7 @@ BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { LLVM_DEBUG(dbgs() << "from non-zero AddressSpace\n"); return {}; } - const auto &DL = LoadI->getModule()->getDataLayout(); + const auto &DL = LoadI->getDataLayout(); if (!isDereferenceablePointer(Addr, LoadI->getType(), DL)) { LLVM_DEBUG(dbgs() << "not dereferenceable\n"); // We need to make sure that we can do comparison in any order, so we @@ -325,7 +325,7 @@ std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI, auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1), BaseId); if (!Rhs.BaseId) return std::nullopt; - const auto &DL = CmpI->getModule()->getDataLayout(); + const auto &DL = CmpI->getDataLayout(); return BCECmp(std::move(Lhs), std::move(Rhs), DL.getTypeSizeInBits(CmpI->getOperand(0)->getType()), CmpI); } @@ -658,7 +658,7 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, unsigned IntBits = TLI.getIntSize(); // Create memcmp() == 0. 
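
[Editor's illustration, not part of the patch] Switching to MergeICmps: the BCEAtom/BCECmp machinery above merges chains of load/compare blocks; a hypothetical example of the pattern it can convert into a single memcmp call.

struct Pair { int a; int b; };  // hypothetical, contiguous fields

bool equal(const Pair &x, const Pair &y) {
  // Two load + icmp(eq) blocks over adjacent, dereferenceable offsets.
  // MergeICmps can merge them into a single call:
  //   memcmp(&x, &y, sizeof(Pair)) == 0
  return x.a == y.a && x.b == y.b;
}
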
- const auto &DL = Phi.getModule()->getDataLayout(); + const auto &DL = Phi.getDataLayout(); Value *const MemCmpCall = emitMemCmp( Lhs, Rhs, ConstantInt::get(Builder.getIntNTy(SizeTBits), TotalSizeBits / 8), diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index d65054a6ff9d..299239fb7020 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -199,7 +199,7 @@ StoreInst *MergedLoadStoreMotion::canSinkFromBlock(BasicBlock *BB1, CastInst::isBitOrNoopPointerCastable( Store0->getValueOperand()->getType(), Store1->getValueOperand()->getType(), - Store0->getModule()->getDataLayout())) + Store0->getDataLayout())) return Store1; } return nullptr; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/NaryReassociate.cpp index 7fe1a222021e..c00c71fcb0b4 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/NaryReassociate.cpp @@ -205,7 +205,7 @@ bool NaryReassociatePass::runImpl(Function &F, AssumptionCache *AC_, SE = SE_; TLI = TLI_; TTI = TTI_; - DL = &F.getParent()->getDataLayout(); + DL = &F.getDataLayout(); bool Changed = false, ChangedInThisIteration; do { @@ -511,14 +511,15 @@ Instruction *NaryReassociatePass::tryReassociatedBinaryOp(const SCEV *LHSExpr, Instruction *NewI = nullptr; switch (I->getOpcode()) { case Instruction::Add: - NewI = BinaryOperator::CreateAdd(LHS, RHS, "", I); + NewI = BinaryOperator::CreateAdd(LHS, RHS, "", I->getIterator()); break; case Instruction::Mul: - NewI = BinaryOperator::CreateMul(LHS, RHS, "", I); + NewI = BinaryOperator::CreateMul(LHS, RHS, "", I->getIterator()); break; default: llvm_unreachable("Unexpected instruction."); } + NewI->setDebugLoc(I->getDebugLoc()); NewI->takeName(I); return NewI; } @@ -564,14 +565,24 @@ NaryReassociatePass::findClosestMatchingDominator(const SCEV *CandidateExpr, // optimization makes the algorithm O(n). while (!Candidates.empty()) { // Candidates stores WeakTrackingVHs, so a candidate can be nullptr if it's - // removed - // during rewriting. - if (Value *Candidate = Candidates.back()) { + // removed during rewriting. + if (Value *Candidate = Candidates.pop_back_val()) { Instruction *CandidateInstruction = cast<Instruction>(Candidate); - if (DT->dominates(CandidateInstruction, Dominatee)) - return CandidateInstruction; + if (!DT->dominates(CandidateInstruction, Dominatee)) + continue; + + // Make sure that the instruction is safe to reuse without introducing + // poison. + SmallVector<Instruction *> DropPoisonGeneratingInsts; + if (!SE->canReuseInstruction(CandidateExpr, CandidateInstruction, + DropPoisonGeneratingInsts)) + continue; + + for (Instruction *I : DropPoisonGeneratingInsts) + I->dropPoisonGeneratingAnnotations(); + + return CandidateInstruction; } - Candidates.pop_back(); } return nullptr; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/NewGVN.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/NewGVN.cpp index 19ac9526b5f8..fc0b31c43396 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -529,7 +529,11 @@ class NewGVN { // IR. 
SmallPtrSet<const Instruction *, 8> PHINodeUses; - DenseMap<const Value *, bool> OpSafeForPHIOfOps; + // The cached results, in general, are only valid for the specific block where + // they were computed. The unsigned part of the key is a unique block + // identifier + DenseMap<std::pair<const Value *, unsigned>, bool> OpSafeForPHIOfOps; + unsigned CacheIdx; // Map a temporary instruction we created to a parent block. DenseMap<const Value *, BasicBlock *> TempToBlock; @@ -892,7 +896,7 @@ private: // Debug counter info. When verifying, we have to reset the value numbering // debug counter to the same state it started in to get the same results. - int64_t StartingVNCounter = 0; + DebugCounter::CounterState StartingVNCounter; }; } // end anonymous namespace @@ -1199,7 +1203,7 @@ NewGVN::ExprResult NewGVN::createExpression(Instruction *I) const { } else if (auto *GEPI = dyn_cast<GetElementPtrInst>(I)) { Value *V = simplifyGEPInst(GEPI->getSourceElementType(), *E->op_begin(), ArrayRef(std::next(E->op_begin()), E->op_end()), - GEPI->isInBounds(), Q); + GEPI->getNoWrapFlags(), Q); if (auto Simplified = checkExprResults(E, I, V)) return Simplified; } else if (AllConstant) { @@ -2525,18 +2529,14 @@ void NewGVN::processOutgoingEdges(Instruction *TI, BasicBlock *B) { BasicBlock *TargetBlock = Case.getCaseSuccessor(); updateReachableEdge(B, TargetBlock); } else { - for (unsigned i = 0, e = SI->getNumSuccessors(); i != e; ++i) { - BasicBlock *TargetBlock = SI->getSuccessor(i); + for (BasicBlock *TargetBlock : successors(SI->getParent())) updateReachableEdge(B, TargetBlock); - } } } else { // Otherwise this is either unconditional, or a type we have no // idea about. Just mark successors as reachable. - for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { - BasicBlock *TargetBlock = TI->getSuccessor(i); + for (BasicBlock *TargetBlock : successors(TI->getParent())) updateReachableEdge(B, TargetBlock); - } // This also may be a memory defining terminator, in which case, set it // equivalent only to itself. @@ -2600,19 +2600,19 @@ bool NewGVN::OpIsSafeForPHIOfOps(Value *V, const BasicBlock *PHIBlock, if (!isa<Instruction>(I)) continue; - auto OISIt = OpSafeForPHIOfOps.find(I); + auto OISIt = OpSafeForPHIOfOps.find({I, CacheIdx}); if (OISIt != OpSafeForPHIOfOps.end()) return OISIt->second; // Keep walking until we either dominate the phi block, or hit a phi, or run // out of things to check. if (DT->properlyDominates(getBlockForValue(I), PHIBlock)) { - OpSafeForPHIOfOps.insert({I, true}); + OpSafeForPHIOfOps.insert({{I, CacheIdx}, true}); continue; } // PHI in the same block. if (isa<PHINode>(I) && getBlockForValue(I) == PHIBlock) { - OpSafeForPHIOfOps.insert({I, false}); + OpSafeForPHIOfOps.insert({{I, CacheIdx}, false}); return false; } @@ -2631,10 +2631,10 @@ bool NewGVN::OpIsSafeForPHIOfOps(Value *V, const BasicBlock *PHIBlock, if (!isa<Instruction>(Op)) continue; // Stop now if we find an unsafe operand. 
- auto OISIt = OpSafeForPHIOfOps.find(OrigI); + auto OISIt = OpSafeForPHIOfOps.find({OrigI, CacheIdx}); if (OISIt != OpSafeForPHIOfOps.end()) { if (!OISIt->second) { - OpSafeForPHIOfOps.insert({I, false}); + OpSafeForPHIOfOps.insert({{I, CacheIdx}, false}); return false; } continue; @@ -2644,7 +2644,7 @@ bool NewGVN::OpIsSafeForPHIOfOps(Value *V, const BasicBlock *PHIBlock, Worklist.push_back(cast<Instruction>(Op)); } } - OpSafeForPHIOfOps.insert({V, true}); + OpSafeForPHIOfOps.insert({{V, CacheIdx}, true}); return true; } @@ -3278,7 +3278,7 @@ void NewGVN::verifyIterationSettled(Function &F) { #ifndef NDEBUG LLVM_DEBUG(dbgs() << "Beginning iteration verification\n"); if (DebugCounter::isCounterSet(VNCounter)) - DebugCounter::setCounterValue(VNCounter, StartingVNCounter); + DebugCounter::setCounterState(VNCounter, StartingVNCounter); // Note that we have to store the actual classes, as we may change existing // classes during iteration. This is because our memory iteration propagation @@ -3297,6 +3297,7 @@ void NewGVN::verifyIterationSettled(Function &F) { TouchedInstructions.set(); TouchedInstructions.reset(0); OpSafeForPHIOfOps.clear(); + CacheIdx = 0; iterateTouchedInstructions(); DenseSet<std::pair<const CongruenceClass *, const CongruenceClass *>> EqualClasses; @@ -3400,6 +3401,8 @@ void NewGVN::iterateTouchedInstructions() { << " because it is unreachable\n"); continue; } + // Use the appropriate cache for "OpIsSafeForPHIOfOps". + CacheIdx = RPOOrdering.lookup(DT->getNode(CurrBlock)) - 1; updateProcessedCount(CurrBlock); } // Reset after processing (because we may mark ourselves as touched when @@ -3423,7 +3426,7 @@ void NewGVN::iterateTouchedInstructions() { // This is the main transformation entry point. bool NewGVN::runGVN() { if (DebugCounter::isCounterSet(VNCounter)) - StartingVNCounter = DebugCounter::getCounterValue(VNCounter); + StartingVNCounter = DebugCounter::getCounterState(VNCounter); bool Changed = false; NumFuncArgs = F.arg_size(); MSSAWalker = MSSA->getWalker(); @@ -3479,6 +3482,8 @@ bool NewGVN::runGVN() { LLVM_DEBUG(dbgs() << "Block " << getBlockName(&F.getEntryBlock()) << " marked reachable\n"); ReachableBlocks.insert(&F.getEntryBlock()); + // Use index corresponding to entry block. + CacheIdx = 0; iterateTouchedInstructions(); verifyMemoryCongruency(); @@ -3721,7 +3726,7 @@ void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) { new StoreInst( PoisonValue::get(Int8Ty), Constant::getNullValue(PointerType::getUnqual(BB->getContext())), - BB->getTerminator()); + BB->getTerminator()->getIterator()); } void NewGVN::markInstructionForDeletion(Instruction *I) { @@ -4019,7 +4024,7 @@ bool NewGVN::eliminateInstructions(Function &F) { // dominated defs as dead. if (Def) { // For anything in this case, what and how we value number - // guarantees that any side-effets that would have occurred (ie + // guarantees that any side-effects that would have occurred (ie // throwing, etc) can be proven to either still occur (because it's // dominated by something that has the same side-effects), or never // occur. 
Otherwise, we would not have been able to prove it value @@ -4237,7 +4242,7 @@ PreservedAnalyses NewGVNPass::run(Function &F, AnalysisManager<Function> &AM) { auto &AA = AM.getResult<AAManager>(F); auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA(); bool Changed = - NewGVN(F, &DT, &AC, &TLI, &AA, &MSSA, F.getParent()->getDataLayout()) + NewGVN(F, &DT, &AC, &TLI, &AA, &MSSA, F.getDataLayout()) .runGVN(); if (!Changed) return PreservedAnalyses::all(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp index 0266eb1a9f50..77d67a2ce0f3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp @@ -60,6 +60,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" #include "llvm/IR/Statepoint.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -190,7 +191,7 @@ static bool enableBackedgeSafepoints(Function &F); static bool enableCallSafepoints(Function &F); static void -InsertSafepointPoll(Instruction *InsertBefore, +InsertSafepointPoll(BasicBlock::iterator InsertBefore, std::vector<CallBase *> &ParsePointsNeeded /*rval*/, const TargetLibraryInfo &TLI); @@ -288,6 +289,8 @@ bool PlaceSafepointsPass::runImpl(Function &F, const TargetLibraryInfo &TLI) { // with for the moment. legacy::FunctionPassManager FPM(F.getParent()); bool CanAssumeCallSafepoints = enableCallSafepoints(F); + + FPM.add(new TargetLibraryInfoWrapperPass(TLI)); auto *PBS = new PlaceBackedgeSafepointsLegacyPass(CanAssumeCallSafepoints); FPM.add(PBS); FPM.run(F); @@ -308,8 +311,7 @@ bool PlaceSafepointsPass::runImpl(Function &F, const TargetLibraryInfo &TLI) { // We can sometimes end up with duplicate poll locations. This happens if // a single loop is visited more than once. The fact this happens seems // wrong, but it does happen for the split-backedge.ll test case. - PollLocations.erase(std::unique(PollLocations.begin(), PollLocations.end()), - PollLocations.end()); + PollLocations.erase(llvm::unique(PollLocations), PollLocations.end()); // Insert a poll at each point the analysis pass identified // The poll location must be the terminator of a loop latch block. @@ -368,7 +370,7 @@ bool PlaceSafepointsPass::runImpl(Function &F, const TargetLibraryInfo &TLI) { // safepoint polls themselves. for (Instruction *PollLocation : PollsNeeded) { std::vector<CallBase *> RuntimeCalls; - InsertSafepointPoll(PollLocation, RuntimeCalls, TLI); + InsertSafepointPoll(PollLocation->getIterator(), RuntimeCalls, TLI); llvm::append_range(ParsePointNeeded, RuntimeCalls); } @@ -517,7 +519,7 @@ static bool doesNotRequireEntrySafepointBefore(CallBase *Call) { switch (II->getIntrinsicID()) { case Intrinsic::experimental_gc_statepoint: case Intrinsic::experimental_patchpoint_void: - case Intrinsic::experimental_patchpoint_i64: + case Intrinsic::experimental_patchpoint: // The can wrap an actual call which may grow the stack by an unbounded // amount or run forever. 
return false; @@ -591,7 +593,7 @@ static Instruction *findLocationForEntrySafepoint(Function &F, const char GCSafepointPollName[] = "gc.safepoint_poll"; static bool isGCSafepointPoll(Function &F) { - return F.getName().equals(GCSafepointPollName); + return F.getName() == GCSafepointPollName; } /// Returns true if this function should be rewritten to include safepoint @@ -619,7 +621,7 @@ static bool enableCallSafepoints(Function &F) { return !NoCall; } // not handle the parsability of state at the runtime call, that's the // callers job. static void -InsertSafepointPoll(Instruction *InsertBefore, +InsertSafepointPoll(BasicBlock::iterator InsertBefore, std::vector<CallBase *> &ParsePointsNeeded /*rval*/, const TargetLibraryInfo &TLI) { BasicBlock *OrigBB = InsertBefore->getParent(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reassociate.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reassociate.cpp index 818c7b40d489..e742d2ed12af 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -246,7 +246,8 @@ void ReassociatePass::canonicalizeOperands(Instruction *I) { } static BinaryOperator *CreateAdd(Value *S1, Value *S2, const Twine &Name, - Instruction *InsertBefore, Value *FlagsOp) { + BasicBlock::iterator InsertBefore, + Value *FlagsOp) { if (S1->getType()->isIntOrIntVectorTy()) return BinaryOperator::CreateAdd(S1, S2, Name, InsertBefore); else { @@ -258,7 +259,8 @@ static BinaryOperator *CreateAdd(Value *S1, Value *S2, const Twine &Name, } static BinaryOperator *CreateMul(Value *S1, Value *S2, const Twine &Name, - Instruction *InsertBefore, Value *FlagsOp) { + BasicBlock::iterator InsertBefore, + Value *FlagsOp) { if (S1->getType()->isIntOrIntVectorTy()) return BinaryOperator::CreateMul(S1, S2, Name, InsertBefore); else { @@ -270,7 +272,8 @@ static BinaryOperator *CreateMul(Value *S1, Value *S2, const Twine &Name, } static Instruction *CreateNeg(Value *S1, const Twine &Name, - Instruction *InsertBefore, Value *FlagsOp) { + BasicBlock::iterator InsertBefore, + Value *FlagsOp) { if (S1->getType()->isIntOrIntVectorTy()) return BinaryOperator::CreateNeg(S1, Name, InsertBefore); @@ -290,7 +293,8 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) { Constant *NegOne = Ty->isIntOrIntVectorTy() ? ConstantInt::getAllOnesValue(Ty) : ConstantFP::get(Ty, -1.0); - BinaryOperator *Res = CreateMul(Neg->getOperand(OpNo), NegOne, "", Neg, Neg); + BinaryOperator *Res = + CreateMul(Neg->getOperand(OpNo), NegOne, "", Neg->getIterator(), Neg); Neg->setOperand(OpNo, Constant::getNullValue(Ty)); // Drop use of op. Res->takeName(Neg); Neg->replaceAllUsesWith(Res); @@ -298,98 +302,7 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) { return Res; } -/// Returns k such that lambda(2^Bitwidth) = 2^k, where lambda is the Carmichael -/// function. This means that x^(2^k) === 1 mod 2^Bitwidth for -/// every odd x, i.e. x^(2^k) = 1 for every odd x in Bitwidth-bit arithmetic. -/// Note that 0 <= k < Bitwidth, and if Bitwidth > 3 then x^(2^k) = 0 for every -/// even x in Bitwidth-bit arithmetic. -static unsigned CarmichaelShift(unsigned Bitwidth) { - if (Bitwidth < 3) - return Bitwidth - 1; - return Bitwidth - 2; -} - -/// Add the extra weight 'RHS' to the existing weight 'LHS', -/// reducing the combined weight using any special properties of the operation. -/// The existing weight LHS represents the computation X op X op ... op X where -/// X occurs LHS times. 
The combined weight represents X op X op ... op X with -/// X occurring LHS + RHS times. If op is "Xor" for example then the combined -/// operation is equivalent to X if LHS + RHS is odd, or 0 if LHS + RHS is even; -/// the routine returns 1 in LHS in the first case, and 0 in LHS in the second. -static void IncorporateWeight(APInt &LHS, const APInt &RHS, unsigned Opcode) { - // If we were working with infinite precision arithmetic then the combined - // weight would be LHS + RHS. But we are using finite precision arithmetic, - // and the APInt sum LHS + RHS may not be correct if it wraps (it is correct - // for nilpotent operations and addition, but not for idempotent operations - // and multiplication), so it is important to correctly reduce the combined - // weight back into range if wrapping would be wrong. - - // If RHS is zero then the weight didn't change. - if (RHS.isMinValue()) - return; - // If LHS is zero then the combined weight is RHS. - if (LHS.isMinValue()) { - LHS = RHS; - return; - } - // From this point on we know that neither LHS nor RHS is zero. - - if (Instruction::isIdempotent(Opcode)) { - // Idempotent means X op X === X, so any non-zero weight is equivalent to a - // weight of 1. Keeping weights at zero or one also means that wrapping is - // not a problem. - assert(LHS == 1 && RHS == 1 && "Weights not reduced!"); - return; // Return a weight of 1. - } - if (Instruction::isNilpotent(Opcode)) { - // Nilpotent means X op X === 0, so reduce weights modulo 2. - assert(LHS == 1 && RHS == 1 && "Weights not reduced!"); - LHS = 0; // 1 + 1 === 0 modulo 2. - return; - } - if (Opcode == Instruction::Add || Opcode == Instruction::FAdd) { - // TODO: Reduce the weight by exploiting nsw/nuw? - LHS += RHS; - return; - } - - assert((Opcode == Instruction::Mul || Opcode == Instruction::FMul) && - "Unknown associative operation!"); - unsigned Bitwidth = LHS.getBitWidth(); - // If CM is the Carmichael number then a weight W satisfying W >= CM+Bitwidth - // can be replaced with W-CM. That's because x^W=x^(W-CM) for every Bitwidth - // bit number x, since either x is odd in which case x^CM = 1, or x is even in - // which case both x^W and x^(W - CM) are zero. By subtracting off multiples - // of CM like this weights can always be reduced to the range [0, CM+Bitwidth) - // which by a happy accident means that they can always be represented using - // Bitwidth bits. - // TODO: Reduce the weight by exploiting nsw/nuw? (Could do much better than - // the Carmichael number). - if (Bitwidth > 3) { - /// CM - The value of Carmichael's lambda function. - APInt CM = APInt::getOneBitSet(Bitwidth, CarmichaelShift(Bitwidth)); - // Any weight W >= Threshold can be replaced with W - CM. - APInt Threshold = CM + Bitwidth; - assert(LHS.ult(Threshold) && RHS.ult(Threshold) && "Weights not reduced!"); - // For Bitwidth 4 or more the following sum does not overflow. - LHS += RHS; - while (LHS.uge(Threshold)) - LHS -= CM; - } else { - // To avoid problems with overflow do everything the same as above but using - // a larger type. 
- unsigned CM = 1U << CarmichaelShift(Bitwidth); - unsigned Threshold = CM + Bitwidth; - assert(LHS.getZExtValue() < Threshold && RHS.getZExtValue() < Threshold && - "Weights not reduced!"); - unsigned Total = LHS.getZExtValue() + RHS.getZExtValue(); - while (Total >= Threshold) - Total -= CM; - LHS = Total; - } -} - -using RepeatedValue = std::pair<Value*, APInt>; +using RepeatedValue = std::pair<Value *, uint64_t>; /// Given an associative binary expression, return the leaf /// nodes in Ops along with their weights (how many times the leaf occurs). The @@ -467,11 +380,10 @@ using RepeatedValue = std::pair<Value*, APInt>; static bool LinearizeExprTree(Instruction *I, SmallVectorImpl<RepeatedValue> &Ops, ReassociatePass::OrderedSet &ToRedo, - bool &HasNUW) { + reassociate::OverflowTracking &Flags) { assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) && "Expected a UnaryOperator or BinaryOperator!"); LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n'); - unsigned Bitwidth = I->getType()->getScalarType()->getPrimitiveSizeInBits(); unsigned Opcode = I->getOpcode(); assert(I->isAssociative() && I->isCommutative() && "Expected an associative and commutative operation!"); @@ -486,8 +398,8 @@ static bool LinearizeExprTree(Instruction *I, // with their weights, representing a certain number of paths to the operator. // If an operator occurs in the worklist multiple times then we found multiple // ways to get to it. - SmallVector<std::pair<Instruction*, APInt>, 8> Worklist; // (Op, Weight) - Worklist.push_back(std::make_pair(I, APInt(Bitwidth, 1))); + SmallVector<std::pair<Instruction *, uint64_t>, 8> Worklist; // (Op, Weight) + Worklist.push_back(std::make_pair(I, 1)); bool Changed = false; // Leaves of the expression are values that either aren't the right kind of @@ -505,23 +417,25 @@ static bool LinearizeExprTree(Instruction *I, // Leaves - Keeps track of the set of putative leaves as well as the number of // paths to each leaf seen so far. - using LeafMap = DenseMap<Value *, APInt>; + using LeafMap = DenseMap<Value *, uint64_t>; LeafMap Leaves; // Leaf -> Total weight so far. SmallVector<Value *, 8> LeafOrder; // Ensure deterministic leaf output order. + const DataLayout DL = I->getDataLayout(); #ifndef NDEBUG SmallPtrSet<Value *, 8> Visited; // For checking the iteration scheme. #endif while (!Worklist.empty()) { - std::pair<Instruction*, APInt> P = Worklist.pop_back_val(); - I = P.first; // We examine the operands of this binary operator. + // We examine the operands of this binary operator. + auto [I, Weight] = Worklist.pop_back_val(); - if (isa<OverflowingBinaryOperator>(I)) - HasNUW &= I->hasNoUnsignedWrap(); + if (isa<OverflowingBinaryOperator>(I)) { + Flags.HasNUW &= I->hasNoUnsignedWrap(); + Flags.HasNSW &= I->hasNoSignedWrap(); + } for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands. Value *Op = I->getOperand(OpIdx); - APInt Weight = P.second; // Number of paths to this operand. LLVM_DEBUG(dbgs() << "OPERAND: " << *Op << " (" << Weight << ")\n"); assert(!Op->use_empty() && "No uses, so how did we get to it?!"); @@ -555,26 +469,8 @@ static bool LinearizeExprTree(Instruction *I, "In leaf map but not visited!"); // Update the number of paths to the leaf. - IncorporateWeight(It->second, Weight, Opcode); - -#if 0 // TODO: Re-enable once PR13021 is fixed. - // The leaf already has one use from inside the expression. As we want - // exactly one such use, drop this new use of the leaf. 
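
[Editor's illustration, not part of the patch] To make the weight bookkeeping above concrete: LinearizeExprTree reports each leaf with the number of paths to it; a tiny illustrative expression (function name is hypothetical).

int weights_demo(int a, int b) {
  // For the add tree ((a + a) + b), LinearizeExprTree returns the leaves
  // with their multiplicities: {a, 2} and {b, 1} -- now plain uint64_t
  // weights instead of APInt.
  return (a + a) + b;
}
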
- assert(!Op->hasOneUse() && "Only one use, but we got here twice!"); - I->setOperand(OpIdx, UndefValue::get(I->getType())); - Changed = true; - - // If the leaf is a binary operation of the right kind and we now see - // that its multiple original uses were in fact all by nodes belonging - // to the expression, then no longer consider it to be a leaf and add - // its operands to the expression. - if (BinaryOperator *BO = isReassociableOp(Op, Opcode)) { - LLVM_DEBUG(dbgs() << "UNLEAF: " << *Op << " (" << It->second << ")\n"); - Worklist.push_back(std::make_pair(BO, It->second)); - Leaves.erase(It); - continue; - } -#endif + It->second += Weight; + assert(It->second >= Weight && "Weight overflows"); // If we still have uses that are not accounted for by the expression // then it is not safe to modify the value. @@ -637,13 +533,22 @@ static bool LinearizeExprTree(Instruction *I, // Node initially thought to be a leaf wasn't. continue; assert(!isReassociableOp(V, Opcode) && "Shouldn't be a leaf!"); - APInt Weight = It->second; - if (Weight.isMinValue()) - // Leaf already output or weight reduction eliminated it. - continue; + uint64_t Weight = It->second; // Ensure the leaf is only output once. It->second = 0; Ops.push_back(std::make_pair(V, Weight)); + if (Opcode == Instruction::Add && Flags.AllKnownNonNegative && Flags.HasNSW) + Flags.AllKnownNonNegative &= isKnownNonNegative(V, SimplifyQuery(DL)); + else if (Opcode == Instruction::Mul) { + // To preserve NUW we need all inputs non-zero. + // To preserve NSW we need all inputs strictly positive. + if (Flags.AllKnownNonZero && + (Flags.HasNUW || (Flags.HasNSW && Flags.AllKnownNonNegative))) { + Flags.AllKnownNonZero &= isKnownNonZero(V, SimplifyQuery(DL)); + if (Flags.HasNSW && Flags.AllKnownNonNegative) + Flags.AllKnownNonNegative &= isKnownNonNegative(V, SimplifyQuery(DL)); + } + } } // For nilpotent operations or addition there may be no operands, for example @@ -652,7 +557,7 @@ static bool LinearizeExprTree(Instruction *I, if (Ops.empty()) { Constant *Identity = ConstantExpr::getBinOpIdentity(Opcode, I->getType()); assert(Identity && "Associative operation without identity!"); - Ops.emplace_back(Identity, APInt(Bitwidth, 1)); + Ops.emplace_back(Identity, 1); } return Changed; @@ -662,7 +567,7 @@ static bool LinearizeExprTree(Instruction *I, /// linearized and optimized, emit them in-order. void ReassociatePass::RewriteExprTree(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops, - bool HasNUW) { + OverflowTracking Flags) { assert(Ops.size() > 1 && "Single values should be used directly!"); // Since our optimizations should never increase the number of operations, the @@ -691,8 +596,8 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, /// of leaf nodes as inner nodes cannot occur by remembering all of the future /// leaves and refusing to reuse any of them as inner nodes. SmallPtrSet<Value*, 8> NotRewritable; - for (unsigned i = 0, e = Ops.size(); i != e; ++i) - NotRewritable.insert(Ops[i].Op); + for (const ValueEntry &Op : Ops) + NotRewritable.insert(Op.Op); // ExpressionChangedStart - Non-null if the rewritten expression differs from // the original in some non-trivial way, requiring the clearing of optional @@ -792,9 +697,9 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, // stupid, create a new node if there are none left. 
BinaryOperator *NewOp; if (NodesToRewrite.empty()) { - Constant *Undef = UndefValue::get(I->getType()); - NewOp = BinaryOperator::Create(Instruction::BinaryOps(Opcode), - Undef, Undef, "", I); + Constant *Poison = PoisonValue::get(I->getType()); + NewOp = BinaryOperator::Create(Instruction::BinaryOps(Opcode), Poison, + Poison, "", I->getIterator()); if (isa<FPMathOperator>(NewOp)) NewOp->setFastMathFlags(I->getFastMathFlags()); } else { @@ -827,11 +732,14 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, ExpressionChangedStart->setFastMathFlags(Flags); } else { ExpressionChangedStart->clearSubclassOptionalData(); - // Note that it doesn't hold for mul if one of the operands is zero. - // TODO: We can preserve NUW flag if we prove that all mul operands - // are non-zero. - if (HasNUW && ExpressionChangedStart->getOpcode() == Instruction::Add) - ExpressionChangedStart->setHasNoUnsignedWrap(); + if (ExpressionChangedStart->getOpcode() == Instruction::Add || + (ExpressionChangedStart->getOpcode() == Instruction::Mul && + Flags.AllKnownNonZero)) { + if (Flags.HasNUW) + ExpressionChangedStart->setHasNoUnsignedWrap(); + if (Flags.HasNSW && (Flags.AllKnownNonNegative || Flags.HasNUW)) + ExpressionChangedStart->setHasNoSignedWrap(); + } } } @@ -854,8 +762,8 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, } // Throw away any left over nodes from the original expression. - for (unsigned i = 0, e = NodesToRewrite.size(); i != e; ++i) - RedoInsts.insert(NodesToRewrite[i]); + for (BinaryOperator *BO : NodesToRewrite) + RedoInsts.insert(BO); } /// Insert instructions before the instruction pointed to by BI, @@ -868,7 +776,7 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, static Value *NegateValue(Value *V, Instruction *BI, ReassociatePass::OrderedSet &ToRedo) { if (auto *C = dyn_cast<Constant>(V)) { - const DataLayout &DL = BI->getModule()->getDataLayout(); + const DataLayout &DL = BI->getDataLayout(); Constant *Res = C->getType()->isFPOrFPVectorTy() ? ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL) : ConstantExpr::getNeg(C); @@ -945,7 +853,13 @@ static Value *NegateValue(Value *V, Instruction *BI, ->getIterator(); } + // Check that if TheNeg is moved out of its parent block, we drop its + // debug location to avoid extra coverage. + // See test dropping_debugloc_the_neg.ll for a detailed example. + if (TheNeg->getParent() != InsertPt->getParent()) + TheNeg->dropLocation(); TheNeg->moveBefore(*InsertPt->getParent(), InsertPt); + if (TheNeg->getOpcode() == Instruction::Sub) { TheNeg->setHasNoUnsignedWrap(false); TheNeg->setHasNoSignedWrap(false); @@ -958,7 +872,8 @@ static Value *NegateValue(Value *V, Instruction *BI, // Insert a 'neg' instruction that subtracts the value from zero to get the // negation. - Instruction *NewNeg = CreateNeg(V, V->getName() + ".neg", BI, BI); + Instruction *NewNeg = + CreateNeg(V, V->getName() + ".neg", BI->getIterator(), BI); ToRedo.insert(NewNeg); return NewNeg; } @@ -1044,8 +959,8 @@ static bool shouldConvertOrWithNoCommonBitsToAdd(Instruction *Or) { /// transform this into (X+Y) to allow arithmetics reassociation. static BinaryOperator *convertOrWithNoCommonBitsToAdd(Instruction *Or) { // Convert an or into an add. 
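
[Editor's illustration, not part of the patch] A short example of the disjoint-bits 'or' that convertOrWithNoCommonBitsToAdd rewrites; mask values are illustrative.

unsigned pack_nibbles(unsigned lo, unsigned hi) {
  // The two operands have no common set bits (low nibble vs. bits 4-7), so
  // the 'or' is equivalent to an add and is rewritten as one, letting the
  // add-oriented reassociation apply.
  return (lo & 0x0Fu) | ((hi & 0x0Fu) << 4);  // == (lo & 0x0F) + ((hi & 0x0F) << 4)
}
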
- BinaryOperator *New = - CreateAdd(Or->getOperand(0), Or->getOperand(1), "", Or, Or); + BinaryOperator *New = CreateAdd(Or->getOperand(0), Or->getOperand(1), "", + Or->getIterator(), Or); New->setHasNoSignedWrap(); New->setHasNoUnsignedWrap(); New->takeName(Or); @@ -1097,7 +1012,8 @@ static BinaryOperator *BreakUpSubtract(Instruction *Sub, // Calculate the negative value of Operand 1 of the sub instruction, // and set it as the RHS of the add instruction we just made. Value *NegVal = NegateValue(Sub->getOperand(1), Sub, ToRedo); - BinaryOperator *New = CreateAdd(Sub->getOperand(0), NegVal, "", Sub, Sub); + BinaryOperator *New = + CreateAdd(Sub->getOperand(0), NegVal, "", Sub->getIterator(), Sub); Sub->setOperand(0, Constant::getNullValue(Sub->getType())); // Drop use of op. Sub->setOperand(1, Constant::getNullValue(Sub->getType())); // Drop use of op. New->takeName(Sub); @@ -1115,10 +1031,11 @@ static BinaryOperator *BreakUpSubtract(Instruction *Sub, static BinaryOperator *ConvertShiftToMul(Instruction *Shl) { Constant *MulCst = ConstantInt::get(Shl->getType(), 1); auto *SA = cast<ConstantInt>(Shl->getOperand(1)); - MulCst = ConstantExpr::getShl(MulCst, SA); + MulCst = ConstantFoldBinaryInstruction(Instruction::Shl, MulCst, SA); + assert(MulCst && "Constant folding of immediate constants failed"); - BinaryOperator *Mul = - BinaryOperator::CreateMul(Shl->getOperand(0), MulCst, "", Shl); + BinaryOperator *Mul = BinaryOperator::CreateMul(Shl->getOperand(0), MulCst, + "", Shl->getIterator()); Shl->setOperand(0, PoisonValue::get(Shl->getType())); // Drop use of op. Mul->takeName(Shl); @@ -1168,13 +1085,13 @@ static unsigned FindInOperandList(const SmallVectorImpl<ValueEntry> &Ops, /// Emit a tree of add instructions, summing Ops together /// and returning the result. Insert the tree before I. -static Value *EmitAddTreeOfValues(Instruction *I, +static Value *EmitAddTreeOfValues(BasicBlock::iterator It, SmallVectorImpl<WeakTrackingVH> &Ops) { if (Ops.size() == 1) return Ops.back(); Value *V1 = Ops.pop_back_val(); - Value *V2 = EmitAddTreeOfValues(I, Ops); - return CreateAdd(V2, V1, "reass.add", I, I); + Value *V2 = EmitAddTreeOfValues(It, Ops); + return CreateAdd(V2, V1, "reass.add", It, &*It); } /// If V is an expression tree that is a multiplication sequence, @@ -1186,14 +1103,13 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { return nullptr; SmallVector<RepeatedValue, 8> Tree; - bool HasNUW = true; - MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, HasNUW); + OverflowTracking Flags; + MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, Flags); SmallVector<ValueEntry, 8> Factors; Factors.reserve(Tree.size()); for (unsigned i = 0, e = Tree.size(); i != e; ++i) { RepeatedValue E = Tree[i]; - Factors.append(E.second.getZExtValue(), - ValueEntry(getRank(E.first), E.first)); + Factors.append(E.second, ValueEntry(getRank(E.first), E.first)); } bool FoundFactor = false; @@ -1229,7 +1145,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { if (!FoundFactor) { // Make sure to restore the operands to the expression tree. 
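
[Editor's illustration, not part of the patch] Similarly, ConvertShiftToMul above canonicalizes a constant left shift into a multiply so it can participate in reassociation; a one-line sketch.

unsigned times_eight(unsigned x) {
  // Reassociate treats this as x * 8 (the shift amount folded into a
  // multiply constant), so it can mix with other multiplies.
  return x << 3;
}
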
- RewriteExprTree(BO, Factors, HasNUW); + RewriteExprTree(BO, Factors, Flags); return nullptr; } @@ -1241,12 +1157,12 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { RedoInsts.insert(BO); V = Factors[0].Op; } else { - RewriteExprTree(BO, Factors, HasNUW); + RewriteExprTree(BO, Factors, Flags); V = BO; } if (NeedsNegate) - V = CreateNeg(V, "neg", &*InsertPt, BO); + V = CreateNeg(V, "neg", InsertPt, BO); return V; } @@ -1321,7 +1237,7 @@ static Value *OptimizeAndOrXor(unsigned Opcode, /// instruction. There are two special cases: 1) if the constant operand is 0, /// it will return NULL. 2) if the constant is ~0, the symbolic operand will /// be returned. -static Value *createAndInstr(Instruction *InsertBefore, Value *Opnd, +static Value *createAndInstr(BasicBlock::iterator InsertBefore, Value *Opnd, const APInt &ConstOpnd) { if (ConstOpnd.isZero()) return nullptr; @@ -1342,7 +1258,7 @@ static Value *createAndInstr(Instruction *InsertBefore, Value *Opnd, // If it was successful, true is returned, and the "R" and "C" is returned // via "Res" and "ConstOpnd", respectively; otherwise, false is returned, // and both "Res" and "ConstOpnd" remain unchanged. -bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, +bool ReassociatePass::CombineXorOpnd(BasicBlock::iterator It, XorOpnd *Opnd1, APInt &ConstOpnd, Value *&Res) { // Xor-Rule 1: (x | c1) ^ c2 = (x | c1) ^ (c1 ^ c1) ^ c2 // = ((x | c1) ^ c1) ^ (c1 ^ c2) @@ -1359,7 +1275,7 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, return false; Value *X = Opnd1->getSymbolicPart(); - Res = createAndInstr(I, X, ~C1); + Res = createAndInstr(It, X, ~C1); // ConstOpnd was C2, now C1 ^ C2. ConstOpnd ^= C1; @@ -1376,7 +1292,7 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, // via "Res" and "ConstOpnd", respectively (If the entire expression is // evaluated to a constant, the Res is set to NULL); otherwise, false is // returned, and both "Res" and "ConstOpnd" remain unchanged. -bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, +bool ReassociatePass::CombineXorOpnd(BasicBlock::iterator It, XorOpnd *Opnd1, XorOpnd *Opnd2, APInt &ConstOpnd, Value *&Res) { Value *X = Opnd1->getSymbolicPart(); @@ -1411,7 +1327,7 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, return false; } - Res = createAndInstr(I, X, C3); + Res = createAndInstr(It, X, C3); ConstOpnd ^= C1; } else if (Opnd1->isOrExpr()) { // Xor-Rule 3: (x | c1) ^ (x | c2) = (x & c3) ^ c3 where c3 = c1 ^ c2 @@ -1427,7 +1343,7 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, return false; } - Res = createAndInstr(I, X, C3); + Res = createAndInstr(It, X, C3); ConstOpnd ^= C3; } else { // Xor-Rule 4: (x & c1) ^ (x & c2) = (x & (c1^c2)) @@ -1435,7 +1351,7 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, const APInt &C1 = Opnd1->getConstPart(); const APInt &C2 = Opnd2->getConstPart(); APInt C3 = C1 ^ C2; - Res = createAndInstr(I, X, C3); + Res = createAndInstr(It, X, C3); } // Put the original operands in the Redo list; hope they will be deleted @@ -1483,8 +1399,8 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, // the "OpndPtrs" as well. For the similar reason, do not fuse this loop // with the previous loop --- the iterator of the "Opnds" may be invalidated // when new elements are added to the vector. 
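
[Editor's illustration, not part of the patch] The Xor rules used by CombineXorOpnd are easy to sanity-check numerically; a small brute-force check of Xor-Rule 3, with constants chosen purely for illustration.

#include <cassert>
#include <cstdint>

int main() {
  // Xor-Rule 3: (x | c1) ^ (x | c2) == (x & c3) ^ c3, where c3 = c1 ^ c2.
  const uint32_t c1 = 0xF0F0u, c2 = 0x0FF0u, c3 = c1 ^ c2;
  for (uint32_t x = 0; x <= 0xFFFFu; ++x)
    assert(((x | c1) ^ (x | c2)) == ((x & c3) ^ c3));
  return 0;
}
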
- for (unsigned i = 0, e = Opnds.size(); i != e; ++i) - OpndPtrs.push_back(&Opnds[i]); + for (XorOpnd &Op : Opnds) + OpndPtrs.push_back(&Op); // Step 2: Sort the Xor-Operands in a way such that the operands containing // the same symbolic value cluster together. For instance, the input operand @@ -1512,7 +1428,8 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, Value *CV; // Step 3.1: Try simplifying "CurrOpnd ^ ConstOpnd" - if (!ConstOpnd.isZero() && CombineXorOpnd(I, CurrOpnd, ConstOpnd, CV)) { + if (!ConstOpnd.isZero() && + CombineXorOpnd(I->getIterator(), CurrOpnd, ConstOpnd, CV)) { Changed = true; if (CV) *CurrOpnd = XorOpnd(CV); @@ -1529,7 +1446,7 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, // step 3.2: When previous and current operands share the same symbolic // value, try to simplify "PrevOpnd ^ CurrOpnd ^ ConstOpnd" - if (CombineXorOpnd(I, CurrOpnd, PrevOpnd, ConstOpnd, CV)) { + if (CombineXorOpnd(I->getIterator(), CurrOpnd, PrevOpnd, ConstOpnd, CV)) { // Remove previous operand PrevOpnd->Invalidate(); if (CV) { @@ -1600,7 +1517,7 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, Type *Ty = TheOp->getType(); Constant *C = Ty->isIntOrIntVectorTy() ? ConstantInt::get(Ty, NumFound) : ConstantFP::get(Ty, NumFound); - Instruction *Mul = CreateMul(TheOp, C, "factor", I, I); + Instruction *Mul = CreateMul(TheOp, C, "factor", I->getIterator(), I); // Now that we have inserted a multiply, optimize it. This allows us to // handle cases that require multiple factoring steps, such as this: @@ -1764,7 +1681,7 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, DummyInst->deleteValue(); unsigned NumAddedValues = NewMulOps.size(); - Value *V = EmitAddTreeOfValues(I, NewMulOps); + Value *V = EmitAddTreeOfValues(I->getIterator(), NewMulOps); // Now that we have inserted the add tree, optimize it. This allows us to // handle cases that require multiple factoring steps, such as this: @@ -1775,7 +1692,7 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, RedoInsts.insert(VI); // Create the multiply. - Instruction *V2 = CreateMul(V, MaxOccVal, "reass.mul", I, I); + Instruction *V2 = CreateMul(V, MaxOccVal, "reass.mul", I->getIterator(), I); // Rerun associate on the multiply in case the inner expression turned into // a multiply. We want to make sure that we keep things in canonical form. @@ -1914,10 +1831,10 @@ ReassociatePass::buildMinimalMultiplyDAG(IRBuilderBase &Builder, } // Unique factors with equal powers -- we've folded them into the first one's // base. - Factors.erase(std::unique(Factors.begin(), Factors.end(), - [](const Factor &LHS, const Factor &RHS) { - return LHS.Power == RHS.Power; - }), + Factors.erase(llvm::unique(Factors, + [](const Factor &LHS, const Factor &RHS) { + return LHS.Power == RHS.Power; + }), Factors.end()); // Iteratively collect the base of each factor with an add power into the @@ -1974,7 +1891,7 @@ Value *ReassociatePass::OptimizeExpression(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops) { // Now that we have the linearized expression tree, try to optimize it. // Start by folding any constants that we found. - const DataLayout &DL = I->getModule()->getDataLayout(); + const DataLayout &DL = I->getDataLayout(); Constant *Cst = nullptr; unsigned Opcode = I->getOpcode(); while (!Ops.empty()) { @@ -2071,8 +1988,8 @@ void ReassociatePass::EraseInst(Instruction *I) { I->eraseFromParent(); // Optimize its operands. SmallPtrSet<Instruction *, 8> Visited; // Detect self-referential nodes. 
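
[Editor's illustration, not part of the patch] And for the OptimizeAdd path above, which counts repeated addends and emits a "factor" multiply, an illustrative input.

int triple(int x) {
  // OptimizeAdd sees the same operand three times in the add tree and
  // replaces the sum with a multiply, x * 3, which is then re-optimized.
  return x + x + x;
}
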
- for (unsigned i = 0, e = Ops.size(); i != e; ++i) - if (Instruction *Op = dyn_cast<Instruction>(Ops[i])) { + for (Value *V : Ops) + if (Instruction *Op = dyn_cast<Instruction>(V)) { // If this is a node in an expression tree, climb to the expression root // and add that since that's where optimization actually happens. unsigned Opcode = Op->getOpcode(); @@ -2270,7 +2187,7 @@ void ReassociatePass::OptimizeInst(Instruction *I) { shouldConvertOrWithNoCommonBitsToAdd(I) && !isLoadCombineCandidate(I) && (cast<PossiblyDisjointInst>(I)->isDisjoint() || haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), - SimplifyQuery(I->getModule()->getDataLayout(), + SimplifyQuery(I->getDataLayout(), /*DT=*/nullptr, /*AC=*/nullptr, I)))) { Instruction *NI = convertOrWithNoCommonBitsToAdd(I); RedoInsts.insert(I); @@ -2366,12 +2283,12 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { // First, walk the expression tree, linearizing the tree, collecting the // operand information. SmallVector<RepeatedValue, 8> Tree; - bool HasNUW = true; - MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, HasNUW); + OverflowTracking Flags; + MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, Flags); SmallVector<ValueEntry, 8> Ops; Ops.reserve(Tree.size()); for (const RepeatedValue &E : Tree) - Ops.append(E.second.getZExtValue(), ValueEntry(getRank(E.first), E.first)); + Ops.append(E.second, ValueEntry(getRank(E.first), E.first)); LLVM_DEBUG(dbgs() << "RAIn:\t"; PrintOps(I, Ops); dbgs() << '\n'); @@ -2560,7 +2477,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { dbgs() << '\n'); // Now that we ordered and optimized the expressions, splat them back into // the expression tree, removing any unneeded nodes. - RewriteExprTree(I, Ops, HasNUW); + RewriteExprTree(I, Ops, Flags); } void diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reg2Mem.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reg2Mem.cpp index 6c2b3e9bd4a7..ebc5075aa36f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reg2Mem.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Reg2Mem.cpp @@ -64,7 +64,7 @@ static bool runPass(Function &F) { CastInst *AllocaInsertionPoint = new BitCastInst( Constant::getNullValue(Type::getInt32Ty(F.getContext())), - Type::getInt32Ty(F.getContext()), "reg2mem alloca point", &*I); + Type::getInt32Ty(F.getContext()), "reg2mem alloca point", I); // Find the escaped instructions. But don't create stack slots for // allocas in entry block. 
@@ -76,7 +76,7 @@ static bool runPass(Function &F) { // Demote escaped instructions NumRegsDemoted += WorkList.size(); for (Instruction *I : WorkList) - DemoteRegToStack(*I, false, AllocaInsertionPoint); + DemoteRegToStack(*I, false, AllocaInsertionPoint->getIterator()); WorkList.clear(); @@ -88,7 +88,7 @@ static bool runPass(Function &F) { // Demote phi nodes NumPhisDemoted += WorkList.size(); for (Instruction *I : WorkList) - DemotePHIToStack(cast<PHINode>(I), AllocaInsertionPoint); + DemotePHIToStack(cast<PHINode>(I), AllocaInsertionPoint->getIterator()); return true; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index 45ce3bf3ceae..2b99e28acb4e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -1143,7 +1143,8 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache, assert(Base && "Can't be null"); // The cast is needed since base traversal may strip away bitcasts if (Base->getType() != Input->getType() && InsertPt) - Base = new BitCastInst(Base, Input->getType(), "cast", InsertPt); + Base = new BitCastInst(Base, Input->getType(), "cast", + InsertPt->getIterator()); return Base; }; @@ -1251,7 +1252,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache, // get the data layout to compare the sizes of base/derived pointer values [[maybe_unused]] auto &DL = - cast<llvm::Instruction>(Def)->getModule()->getDataLayout(); + cast<llvm::Instruction>(Def)->getDataLayout(); // Cache all of our results so we can cheaply reuse them // NOTE: This is actually two caches: one of the base defining value // relation and one of the base pointer relation! FIXME @@ -1322,7 +1323,7 @@ static void findBasePointers(DominatorTree &DT, DefiningValueMapTy &DVCache, IsKnownBaseMapTy &KnownBases) { StatepointLiveSetTy PotentiallyDerivedPointers = result.LiveSet; // We assume that all pointers passed to deopt are base pointers; as an - // optimization, we can use this to avoid seperately materializing the base + // optimization, we can use this to avoid separately materializing the base // pointer graph. This is only relevant since we're very conservative about // generating new conflict nodes during base pointer insertion. If we were // smarter there, this would be irrelevant. @@ -1612,7 +1613,7 @@ public: // Note: we've inserted instructions, so the call to llvm.deoptimize may // not necessarily be followed by the matching return. auto *RI = cast<ReturnInst>(OldI->getParent()->getTerminator()); - new UnreachableInst(RI->getContext(), RI); + new UnreachableInst(RI->getContext(), RI->getIterator()); RI->eraseFromParent(); } @@ -1684,10 +1685,10 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ // Pass through the requested lowering if any. The default is live-through. StringRef DeoptLowering = getDeoptLowering(Call); - if (DeoptLowering.equals("live-in")) + if (DeoptLowering == "live-in") Flags |= uint32_t(StatepointFlags::DeoptLiveIn); else { - assert(DeoptLowering.equals("live-through") && "Unsupported value!"); + assert(DeoptLowering == "live-through" && "Unsupported value!"); } FunctionCallee CallTarget(Call->getFunctionType(), Call->getCalledOperand()); @@ -1733,7 +1734,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */ // memcpy(dest_derived, source_derived, ...) 
=> // memcpy(dest_base, dest_offset, source_base, source_offset, ...) auto &Context = Call->getContext(); - auto &DL = Call->getModule()->getDataLayout(); + auto &DL = Call->getDataLayout(); auto GetBaseAndOffset = [&](Value *Derived) { Value *Base = nullptr; // Optimizations in unreachable code might substitute the real pointer @@ -1976,7 +1977,7 @@ insertRelocationStores(iterator_range<Value::user_iterator> GCRelocs, // Emit store into the related alloca. assert(Relocate->getNextNode() && "Should always have one since it's not a terminator"); - new StoreInst(Relocate, Alloca, Relocate->getNextNode()); + new StoreInst(Relocate, Alloca, std::next(Relocate->getIterator())); #ifndef NDEBUG VisitedLiveValues.insert(OriginalValue); @@ -1999,7 +2000,7 @@ static void insertRematerializationStores( Value *Alloca = AllocaMap[OriginalValue]; new StoreInst(RematerializedValue, Alloca, - RematerializedValue->getNextNode()); + std::next(RematerializedValue->getIterator())); #ifndef NDEBUG VisitedLiveValues.insert(OriginalValue); @@ -2029,11 +2030,11 @@ static void relocationViaAlloca( // Emit alloca for "LiveValue" and record it in "allocaMap" and // "PromotableAllocas" - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); auto emitAllocaFor = [&](Value *LiveValue) { - AllocaInst *Alloca = new AllocaInst(LiveValue->getType(), - DL.getAllocaAddrSpace(), "", - F.getEntryBlock().getFirstNonPHI()); + AllocaInst *Alloca = + new AllocaInst(LiveValue->getType(), DL.getAllocaAddrSpace(), "", + F.getEntryBlock().getFirstNonPHIIt()); AllocaMap[LiveValue] = Alloca; PromotableAllocas.push_back(Alloca); }; @@ -2100,7 +2101,7 @@ static void relocationViaAlloca( ToClobber.push_back(Alloca); } - auto InsertClobbersAt = [&](Instruction *IP) { + auto InsertClobbersAt = [&](BasicBlock::iterator IP) { for (auto *AI : ToClobber) { auto AT = AI->getAllocatedType(); Constant *CPN; @@ -2115,10 +2116,11 @@ static void relocationViaAlloca( // Insert the clobbering stores. These may get intermixed with the // gc.results and gc.relocates, but that's fine. 
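A hedged sketch of what the InsertClobbersAt lambda above does for the simple scalar-pointer case (the real code also handles vectors of pointers when choosing the null constant CPN): each spill alloca whose value is not live across this statepoint gets a null pointer stored into it, so that accidentally reloading a stale, unrelocated pointer is easy to spot.

    auto InsertClobbersAt = [&](BasicBlock::iterator IP) {
      for (AllocaInst *AI : ToClobber) {
        // Assumes the spill slot holds a plain pointer value.
        auto *PtrTy = cast<PointerType>(AI->getAllocatedType());
        new StoreInst(ConstantPointerNull::get(PtrTy), AI, IP);
      }
    };
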
if (auto II = dyn_cast<InvokeInst>(Statepoint)) { - InsertClobbersAt(&*II->getNormalDest()->getFirstInsertionPt()); - InsertClobbersAt(&*II->getUnwindDest()->getFirstInsertionPt()); + InsertClobbersAt(II->getNormalDest()->getFirstInsertionPt()); + InsertClobbersAt(II->getUnwindDest()->getFirstInsertionPt()); } else { - InsertClobbersAt(cast<Instruction>(Statepoint)->getNextNode()); + InsertClobbersAt( + std::next(cast<Instruction>(Statepoint)->getIterator())); } } } @@ -2146,7 +2148,7 @@ static void relocationViaAlloca( } llvm::sort(Uses); - auto Last = std::unique(Uses.begin(), Uses.end()); + auto Last = llvm::unique(Uses); Uses.erase(Last, Uses.end()); for (Instruction *Use : Uses) { @@ -2154,15 +2156,15 @@ static void relocationViaAlloca( PHINode *Phi = cast<PHINode>(Use); for (unsigned i = 0; i < Phi->getNumIncomingValues(); i++) { if (Def == Phi->getIncomingValue(i)) { - LoadInst *Load = - new LoadInst(Alloca->getAllocatedType(), Alloca, "", - Phi->getIncomingBlock(i)->getTerminator()); + LoadInst *Load = new LoadInst( + Alloca->getAllocatedType(), Alloca, "", + Phi->getIncomingBlock(i)->getTerminator()->getIterator()); Phi->setIncomingValue(i, Load); } } } else { - LoadInst *Load = - new LoadInst(Alloca->getAllocatedType(), Alloca, "", Use); + LoadInst *Load = new LoadInst(Alloca->getAllocatedType(), Alloca, "", + Use->getIterator()); Use->replaceUsesOfWith(Def, Load); } } @@ -2229,16 +2231,16 @@ static void insertUseHolderAfter(CallBase *Call, const ArrayRef<Value *> Values, if (isa<CallInst>(Call)) { // For call safepoints insert dummy calls right after safepoint Holders.push_back( - CallInst::Create(Func, Values, "", &*++Call->getIterator())); + CallInst::Create(Func, Values, "", std::next(Call->getIterator()))); return; } // For invoke safepooints insert dummy calls both in normal and // exceptional destination blocks auto *II = cast<InvokeInst>(Call); Holders.push_back(CallInst::Create( - Func, Values, "", &*II->getNormalDest()->getFirstInsertionPt())); + Func, Values, "", II->getNormalDest()->getFirstInsertionPt())); Holders.push_back(CallInst::Create( - Func, Values, "", &*II->getUnwindDest()->getFirstInsertionPt())); + Func, Values, "", II->getUnwindDest()->getFirstInsertionPt())); } static void findLiveReferences( @@ -2269,7 +2271,7 @@ static Value* findRematerializableChainToBasePointer( } if (CastInst *CI = dyn_cast<CastInst>(CurrentValue)) { - if (!CI->isNoopCast(CI->getModule()->getDataLayout())) + if (!CI->isNoopCast(CI->getDataLayout())) return CI; ChainToBase.push_back(CI); @@ -2291,7 +2293,7 @@ chainToBasePointerCost(SmallVectorImpl<Instruction *> &Chain, for (Instruction *Instr : Chain) { if (CastInst *CI = dyn_cast<CastInst>(Instr)) { - assert(CI->isNoopCast(CI->getModule()->getDataLayout()) && + assert(CI->isNoopCast(CI->getDataLayout()) && "non noop cast is found during rematerialization"); Type *SrcTy = CI->getOperand(0)->getType(); @@ -2599,7 +2601,7 @@ static bool inlineGetBaseAndOffset(Function &F, DefiningValueMapTy &DVCache, IsKnownBaseMapTy &KnownBases) { auto &Context = F.getContext(); - auto &DL = F.getParent()->getDataLayout(); + auto &DL = F.getDataLayout(); bool Changed = false; for (auto *Callsite : Intrinsics) @@ -3044,8 +3046,7 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, // which doesn't know how to produce a proper deopt state. So if we see a // non-leaf memcpy/memmove without deopt state just treat it as a leaf // copy and don't produce a statepoint. 
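On the std::unique(Uses.begin(), Uses.end()) -> llvm::unique(Uses) change in relocationViaAlloca above: llvm::unique is the range-based counterpart of std::unique, so the begin/end pair can be dropped. A self-contained sketch of the sort-then-dedup pattern on a plain vector:

    #include "llvm/ADT/STLExtras.h"
    #include <vector>

    void sortAndDedup(std::vector<int> &Values) {
      llvm::sort(Values);                               // range-based std::sort
      Values.erase(llvm::unique(Values), Values.end()); // range-based std::unique
    }
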
- if (!AllowStatepointWithNoDeoptInfo && - !Call->getOperandBundle(LLVMContext::OB_deopt)) { + if (!AllowStatepointWithNoDeoptInfo && !Call->hasDeoptState()) { assert((isa<AtomicMemCpyInst>(Call) || isa<AtomicMemMoveInst>(Call)) && "Don't expect any other calls here!"); return false; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SCCP.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SCCP.cpp index 8a491e74b91c..ce45c58e624e 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SCCP.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SCCP.cpp @@ -119,7 +119,7 @@ static bool runSCCP(Function &F, const DataLayout &DL, } PreservedAnalyses SCCPPass::run(Function &F, FunctionAnalysisManager &AM) { - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp index 17a94f9381bf..c738a2a6f39a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp @@ -116,10 +116,6 @@ STATISTIC( STATISTIC(NumDeleted, "Number of instructions deleted"); STATISTIC(NumVectorized, "Number of vectorized aggregates"); -/// Hidden option to experiment with completely strict handling of inbounds -/// GEPs. -static cl::opt<bool> SROAStrictInbounds("sroa-strict-inbounds", cl::init(false), - cl::Hidden); /// Disable running mem2reg during SROA in order to test or debug SROA. static cl::opt<bool> SROASkipMem2Reg("sroa-skip-mem2reg", cl::init(false), cl::Hidden); @@ -293,7 +289,7 @@ calculateFragment(DILocalVariable *Variable, if (!CurrentFragment) { if (auto Size = Variable->getSizeInBits()) { // Treat the current fragment as covering the whole variable. - CurrentFragment = DIExpression::FragmentInfo(*Size, 0); + CurrentFragment = DIExpression::FragmentInfo(*Size, 0); if (Target == CurrentFragment) return UseNoFrag; } @@ -319,28 +315,21 @@ static DebugVariable getAggregateVariable(DbgVariableIntrinsic *DVI) { return DebugVariable(DVI->getVariable(), std::nullopt, DVI->getDebugLoc().getInlinedAt()); } -static DebugVariable getAggregateVariable(DPValue *DPV) { - return DebugVariable(DPV->getVariable(), std::nullopt, - DPV->getDebugLoc().getInlinedAt()); +static DebugVariable getAggregateVariable(DbgVariableRecord *DVR) { + return DebugVariable(DVR->getVariable(), std::nullopt, + DVR->getDebugLoc().getInlinedAt()); } -static DPValue *createLinkedAssign(DPValue *, DIBuilder &DIB, - Instruction *LinkedInstr, Value *NewValue, - DILocalVariable *Variable, - DIExpression *Expression, Value *Address, - DIExpression *AddressExpression, - const DILocation *DI) { - (void)DIB; - return DPValue::createLinkedDPVAssign(LinkedInstr, NewValue, Variable, - Expression, Address, AddressExpression, - DI); +/// Helpers for handling new and old debug info modes in migrateDebugInfo. +/// These overloads unwrap a DbgInstPtr {Instruction* | DbgRecord*} union based +/// on the \p Unused parameter type. 
+DbgVariableRecord *UnwrapDbgInstPtr(DbgInstPtr P, DbgVariableRecord *Unused) { + (void)Unused; + return static_cast<DbgVariableRecord *>(cast<DbgRecord *>(P)); } -static DbgAssignIntrinsic *createLinkedAssign( - DbgAssignIntrinsic *, DIBuilder &DIB, Instruction *LinkedInstr, - Value *NewValue, DILocalVariable *Variable, DIExpression *Expression, - Value *Address, DIExpression *AddressExpression, const DILocation *DI) { - return DIB.insertDbgAssign(LinkedInstr, NewValue, Variable, Expression, - Address, AddressExpression, DI); +DbgAssignIntrinsic *UnwrapDbgInstPtr(DbgInstPtr P, DbgAssignIntrinsic *Unused) { + (void)Unused; + return static_cast<DbgAssignIntrinsic *>(cast<Instruction *>(P)); } /// Find linked dbg.assign and generate a new one with the correct @@ -363,9 +352,9 @@ static void migrateDebugInfo(AllocaInst *OldAlloca, bool IsSplit, Instruction *Inst, Value *Dest, Value *Value, const DataLayout &DL) { auto MarkerRange = at::getAssignmentMarkers(OldInst); - auto DPVAssignMarkerRange = at::getDPVAssignmentMarkers(OldInst); + auto DVRAssignMarkerRange = at::getDVRAssignmentMarkers(OldInst); // Nothing to do if OldInst has no linked dbg.assign intrinsics. - if (MarkerRange.empty() && DPVAssignMarkerRange.empty()) + if (MarkerRange.empty() && DVRAssignMarkerRange.empty()) return; LLVM_DEBUG(dbgs() << " migrateDebugInfo\n"); @@ -386,9 +375,9 @@ static void migrateDebugInfo(AllocaInst *OldAlloca, bool IsSplit, for (auto *DAI : at::getAssignmentMarkers(OldAlloca)) BaseFragments[getAggregateVariable(DAI)] = DAI->getExpression()->getFragmentInfo(); - for (auto *DPV : at::getDPVAssignmentMarkers(OldAlloca)) - BaseFragments[getAggregateVariable(DPV)] = - DPV->getExpression()->getFragmentInfo(); + for (auto *DVR : at::getDVRAssignmentMarkers(OldAlloca)) + BaseFragments[getAggregateVariable(DVR)] = + DVR->getExpression()->getFragmentInfo(); // The new inst needs a DIAssignID unique metadata tag (if OldInst has // one). It shouldn't already have one: assert this assumption. @@ -398,7 +387,7 @@ static void migrateDebugInfo(AllocaInst *OldAlloca, bool IsSplit, DIBuilder DIB(*OldInst->getModule(), /*AllowUnresolved*/ false); assert(OldAlloca->isStaticAlloca()); - auto MigrateDbgAssign = [&](auto DbgAssign) { + auto MigrateDbgAssign = [&](auto *DbgAssign) { LLVM_DEBUG(dbgs() << " existing dbg.assign is: " << *DbgAssign << "\n"); auto *Expr = DbgAssign->getExpression(); @@ -452,10 +441,12 @@ static void migrateDebugInfo(AllocaInst *OldAlloca, bool IsSplit, } ::Value *NewValue = Value ? Value : DbgAssign->getValue(); - auto *NewAssign = createLinkedAssign( - DbgAssign, DIB, Inst, NewValue, DbgAssign->getVariable(), Expr, Dest, - DIExpression::get(Expr->getContext(), std::nullopt), - DbgAssign->getDebugLoc()); + auto *NewAssign = UnwrapDbgInstPtr( + DIB.insertDbgAssign(Inst, NewValue, DbgAssign->getVariable(), Expr, + Dest, + DIExpression::get(Expr->getContext(), std::nullopt), + DbgAssign->getDebugLoc()), + DbgAssign); // If we've updated the value but the original dbg.assign has an arglist // then kill it now - we can't use the requested new value. 
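The two UnwrapDbgInstPtr overloads above rely on tag dispatch: the Unused argument is never read, its static type alone selects the overload and therefore the pointer type handed back, which is what lets the generic MigrateDbgAssign lambda below work in both debug-info modes. A plain-C++ illustration of the idiom (names invented, not LLVM code):

    struct OldMode {};
    struct NewMode {};
    // The second parameter carries no data; only its type picks the overload.
    long decode(long Raw, OldMode * /*Unused*/) { return Raw; }
    int  decode(long Raw, NewMode * /*Unused*/) { return static_cast<int>(Raw); }
    // decode(X, static_cast<OldMode *>(nullptr)) yields a long,
    // decode(X, static_cast<NewMode *>(nullptr)) yields an int.
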
@@ -493,7 +484,7 @@ static void migrateDebugInfo(AllocaInst *OldAlloca, bool IsSplit, }; for_each(MarkerRange, MigrateDbgAssign); - for_each(DPVAssignMarkerRange, MigrateDbgAssign); + for_each(DVRAssignMarkerRange, MigrateDbgAssign); } namespace { @@ -510,9 +501,9 @@ class IRBuilderPrefixedInserter final : public IRBuilderDefaultInserter { public: void SetNamePrefix(const Twine &P) { Prefix = P.str(); } - void InsertHelper(Instruction *I, const Twine &Name, BasicBlock *BB, + void InsertHelper(Instruction *I, const Twine &Name, BasicBlock::iterator InsertPt) const override { - IRBuilderDefaultInserter::InsertHelper(I, getNameWithPrefix(Name), BB, + IRBuilderDefaultInserter::InsertHelper(I, getNameWithPrefix(Name), InsertPt); } }; @@ -635,7 +626,7 @@ public: int OldSize = Slices.size(); Slices.append(NewSlices.begin(), NewSlices.end()); auto SliceI = Slices.begin() + OldSize; - llvm::sort(SliceI, Slices.end()); + std::stable_sort(SliceI, Slices.end()); std::inplace_merge(Slices.begin(), SliceI, Slices.end()); } @@ -1100,45 +1091,6 @@ private: if (GEPI.use_empty()) return markAsDead(GEPI); - if (SROAStrictInbounds && GEPI.isInBounds()) { - // FIXME: This is a manually un-factored variant of the basic code inside - // of GEPs with checking of the inbounds invariant specified in the - // langref in a very strict sense. If we ever want to enable - // SROAStrictInbounds, this code should be factored cleanly into - // PtrUseVisitor, but it is easier to experiment with SROAStrictInbounds - // by writing out the code here where we have the underlying allocation - // size readily available. - APInt GEPOffset = Offset; - const DataLayout &DL = GEPI.getModule()->getDataLayout(); - for (gep_type_iterator GTI = gep_type_begin(GEPI), - GTE = gep_type_end(GEPI); - GTI != GTE; ++GTI) { - ConstantInt *OpC = dyn_cast<ConstantInt>(GTI.getOperand()); - if (!OpC) - break; - - // Handle a struct index, which adds its field offset to the pointer. - if (StructType *STy = GTI.getStructTypeOrNull()) { - unsigned ElementIdx = OpC->getZExtValue(); - const StructLayout *SL = DL.getStructLayout(STy); - GEPOffset += - APInt(Offset.getBitWidth(), SL->getElementOffset(ElementIdx)); - } else { - // For array or vector indices, scale the index by the size of the - // type. - APInt Index = OpC->getValue().sextOrTrunc(Offset.getBitWidth()); - GEPOffset += Index * APInt(Offset.getBitWidth(), - GTI.getSequentialElementStride(DL)); - } - - // If this index has computed an intermediate pointer which is not - // inbounds, then the result of the GEP is a poison value and we can - // delete it and all uses. - if (GEPOffset.ugt(AllocSize)) - return markAsDead(GEPI); - } - } - return Base::visitGetElementPtrInst(GEPI); } @@ -1213,8 +1165,9 @@ private: if (!IsOffsetKnown) return PI.setAborted(&II); - insertUse(II, Offset, Length ? Length->getLimitedValue() - : AllocSize - Offset.getLimitedValue(), + insertUse(II, Offset, + Length ? Length->getLimitedValue() + : AllocSize - Offset.getLimitedValue(), (bool)Length); } @@ -1327,7 +1280,7 @@ private: SmallVector<std::pair<Instruction *, Instruction *>, 4> Uses; Visited.insert(Root); Uses.push_back(std::make_pair(cast<Instruction>(*U), Root)); - const DataLayout &DL = Root->getModule()->getDataLayout(); + const DataLayout &DL = Root->getDataLayout(); // If there are no loads or stores, the access is dead. We mark that as // a size zero access. 
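On the llvm::sort -> std::stable_sort change in the slice-insertion hunk above: only the freshly appended tail is sorted, stably, and then merged into the already-sorted prefix, presumably so that slices which compare equal keep a deterministic relative order. The same pattern on a plain vector:

    #include <algorithm>
    #include <vector>

    void appendSorted(std::vector<int> &Sorted, const std::vector<int> &New) {
      auto OldSize = Sorted.size();
      Sorted.insert(Sorted.end(), New.begin(), New.end());
      auto Mid = Sorted.begin() + OldSize;
      std::stable_sort(Mid, Sorted.end());                   // sort the new tail
      std::inplace_merge(Sorted.begin(), Mid, Sorted.end()); // merge with prefix
    }
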
Size = 0; @@ -1574,7 +1527,7 @@ findCommonType(AllocaSlices::const_iterator B, AllocaSlices::const_iterator E, /// FIXME: This should be hoisted into a generic utility, likely in /// Transforms/Util/Local.h static bool isSafePHIToSpeculate(PHINode &PN) { - const DataLayout &DL = PN.getModule()->getDataLayout(); + const DataLayout &DL = PN.getDataLayout(); // For now, we can only do this promotion if the load is in the same block // as the PHI, and if there are no stores between the phi and load. @@ -1669,7 +1622,7 @@ static void speculatePHINodeLoads(IRBuilderTy &IRB, PHINode &PN) { } // Inject loads into all of the pred blocks. - DenseMap<BasicBlock*, Value*> InjectedLoads; + DenseMap<BasicBlock *, Value *> InjectedLoads; for (unsigned Idx = 0, Num = PN.getNumIncomingValues(); Idx != Num; ++Idx) { BasicBlock *Pred = PN.getIncomingBlock(Idx); Value *InVal = PN.getIncomingValue(Idx); @@ -1678,7 +1631,7 @@ static void speculatePHINodeLoads(IRBuilderTy &IRB, PHINode &PN) { // basic block, as long as the value is the same. So if we already injected // a load in the predecessor, then we should reuse the same load for all // duplicated entries. - if (Value* V = InjectedLoads.lookup(Pred)) { + if (Value *V = InjectedLoads.lookup(Pred)) { NewPN->addIncoming(V, Pred); continue; } @@ -1732,7 +1685,7 @@ isSafeLoadOfSelectToSpeculate(LoadInst &LI, SelectInst &SI, bool PreserveCFG) { assert(LI.isSimple() && "Only for simple loads"); SelectHandSpeculativity Spec; - const DataLayout &DL = SI.getModule()->getDataLayout(); + const DataLayout &DL = SI.getDataLayout(); for (Value *Value : {SI.getTrueValue(), SI.getFalseValue()}) if (isSafeToLoadUnconditionally(Value, LI.getType(), LI.getAlign(), DL, &LI)) @@ -1852,7 +1805,7 @@ static void rewriteMemOpOfSelect(SelectInst &SI, T &I, Tail->setName(Head->getName() + ".cont"); PHINode *PN; if (isa<LoadInst>(I)) - PN = PHINode::Create(I.getType(), 2, "", &I); + PN = PHINode::Create(I.getType(), 2, "", I.getIterator()); for (BasicBlock *SuccBB : successors(Head)) { bool IsThen = SuccBB == HeadBI->getSuccessor(0); int SuccIdx = IsThen ? 0 : 1; @@ -2077,8 +2030,7 @@ static bool isVectorPromotionViableForSlice(Partition &P, const Slice &S, if (BeginIndex * ElementSize != BeginOffset || BeginIndex >= cast<FixedVectorType>(Ty)->getNumElements()) return false; - uint64_t EndOffset = - std::min(S.endOffset(), P.endOffset()) - P.beginOffset(); + uint64_t EndOffset = std::min(S.endOffset(), P.endOffset()) - P.beginOffset(); uint64_t EndIndex = EndOffset / ElementSize; if (EndIndex * ElementSize != EndOffset || EndIndex > cast<FixedVectorType>(Ty)->getNumElements()) @@ -2226,8 +2178,7 @@ checkVectorTypesForPromotion(Partition &P, const DataLayout &DL, cast<FixedVectorType>(LHSTy)->getNumElements(); }; llvm::sort(CandidateTys, RankVectorTypesComp); - CandidateTys.erase(std::unique(CandidateTys.begin(), CandidateTys.end(), - RankVectorTypesEq), + CandidateTys.erase(llvm::unique(CandidateTys, RankVectorTypesEq), CandidateTys.end()); } else { // The only way to have the same element type in every vector type is to @@ -2780,8 +2731,8 @@ public: Instruction *OldUserI = cast<Instruction>(OldUse->getUser()); IRB.SetInsertPoint(OldUserI); IRB.SetCurrentDebugLocation(OldUserI->getDebugLoc()); - IRB.getInserter().SetNamePrefix( - Twine(NewAI.getName()) + "." + Twine(BeginOffset) + "."); + IRB.getInserter().SetNamePrefix(Twine(NewAI.getName()) + "." 
+ + Twine(BeginOffset) + "."); CanSROA &= visit(cast<Instruction>(OldUse->getUser())); if (VecTy || IntTy) @@ -2834,7 +2785,7 @@ private: #else Twine() #endif - ); + ); } /// Compute suitable alignment to access this slice of the *new* @@ -2940,7 +2891,8 @@ private: // Do this after copyMetadataForLoad() to preserve the TBAA shift. if (AATags) - NewLI->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); + NewLI->setAAMetadata(AATags.adjustForAccess( + NewBeginOffset - BeginOffset, NewLI->getType(), DL)); // Try to preserve nonnull metadata V = NewLI; @@ -2961,8 +2913,11 @@ private: LoadInst *NewLI = IRB.CreateAlignedLoad(TargetTy, getNewAllocaSlicePtr(IRB, LTy), getSliceAlign(), LI.isVolatile(), LI.getName()); + if (AATags) - NewLI->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); + NewLI->setAAMetadata(AATags.adjustForAccess( + NewBeginOffset - BeginOffset, NewLI->getType(), DL)); + if (LI.isVolatile()) NewLI->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); NewLI->copyMetadata(LI, {LLVMContext::MD_mem_parallel_loop_access, @@ -2982,7 +2937,12 @@ private: assert(DL.typeSizeEqualsStoreSize(LI.getType()) && "Non-byte-multiple bit width"); // Move the insertion point just past the load so that we can refer to it. - IRB.SetInsertPoint(&*std::next(BasicBlock::iterator(&LI))); + BasicBlock::iterator LIIt = std::next(LI.getIterator()); + // Ensure the insertion point comes before any debug-info immediately + // after the load, so that variable values referring to the load are + // dominated by it. + LIIt.setHeadBit(true); + IRB.SetInsertPoint(LI.getParent(), LIIt); // Create a placeholder value with the same type as LI to use as the // basis for the new value. This allows us to replace the uses of LI with // the computed value, and then replace the placeholder with LI, leaving @@ -3032,7 +2992,8 @@ private: Store->copyMetadata(SI, {LLVMContext::MD_mem_parallel_loop_access, LLVMContext::MD_access_group}); if (AATags) - Store->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); + Store->setAAMetadata(AATags.adjustForAccess(NewBeginOffset - BeginOffset, + V->getType(), DL)); Pass.DeadInsts.push_back(&SI); // NOTE: Careful to use OrigV rather than V. @@ -3059,7 +3020,8 @@ private: Store->copyMetadata(SI, {LLVMContext::MD_mem_parallel_loop_access, LLVMContext::MD_access_group}); if (AATags) - Store->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); + Store->setAAMetadata(AATags.adjustForAccess(NewBeginOffset - BeginOffset, + V->getType(), DL)); migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &SI, Store, Store->getPointerOperand(), @@ -3119,7 +3081,8 @@ private: NewSI->copyMetadata(SI, {LLVMContext::MD_mem_parallel_loop_access, LLVMContext::MD_access_group}); if (AATags) - NewSI->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); + NewSI->setAAMetadata(AATags.adjustForAccess(NewBeginOffset - BeginOffset, + V->getType(), DL)); if (SI.isVolatile()) NewSI->setAtomic(SI.getOrdering(), SI.getSyncScopeID()); if (NewSI->isAtomic()) @@ -3188,7 +3151,7 @@ private: // emit dbg.assign intrinsics for mem intrinsics storing through non- // constant geps, or storing a variable number of bytes. 
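The setHeadBit(true) call above is a good illustration of why insert positions are being migrated to iterators throughout this patch: an iterator can additionally say whether new instructions go before or after the debug records attached at that position. A minimal sketch of the same idiom, assuming LI is some instruction and Builder an IRBuilder<>:

    // Position the builder immediately after LI, but ahead of any debug
    // records that follow it, matching the rewritten code above.
    BasicBlock::iterator It = std::next(LI.getIterator());
    It.setHeadBit(true);
    Builder.SetInsertPoint(LI.getParent(), It);
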
assert(at::getAssignmentMarkers(&II).empty() && - at::getDPVAssignmentMarkers(&II).empty() && + at::getDVRAssignmentMarkers(&II).empty() && "AT: Unexpected link to non-const GEP"); deleteIfTriviallyDead(OldPtr); return false; @@ -3203,8 +3166,7 @@ private: const bool CanContinue = [&]() { if (VecTy || IntTy) return true; - if (BeginOffset > NewAllocaBeginOffset || - EndOffset < NewAllocaEndOffset) + if (BeginOffset > NewAllocaBeginOffset || EndOffset < NewAllocaEndOffset) return false; // Length must be in range for FixedVectorType. auto *C = cast<ConstantInt>(II.getLength()); @@ -3221,12 +3183,14 @@ private: // a single value type, just emit a memset. if (!CanContinue) { Type *SizeTy = II.getLength()->getType(); - Constant *Size = ConstantInt::get(SizeTy, NewEndOffset - NewBeginOffset); + unsigned Sz = NewEndOffset - NewBeginOffset; + Constant *Size = ConstantInt::get(SizeTy, Sz); MemIntrinsic *New = cast<MemIntrinsic>(IRB.CreateMemSet( getNewAllocaSlicePtr(IRB, OldPtr->getType()), II.getValue(), Size, MaybeAlign(getSliceAlign()), II.isVolatile())); if (AATags) - New->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); + New->setAAMetadata( + AATags.adjustForAccess(NewBeginOffset - BeginOffset, Sz)); migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &II, New, New->getRawDest(), nullptr, DL); @@ -3302,7 +3266,8 @@ private: New->copyMetadata(II, {LLVMContext::MD_mem_parallel_loop_access, LLVMContext::MD_access_group}); if (AATags) - New->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); + New->setAAMetadata(AATags.adjustForAccess(NewBeginOffset - BeginOffset, + V->getType(), DL)); migrateDebugInfo(&OldAI, IsSplit, NewBeginOffset * 8, SliceSize * 8, &II, New, New->getPointerOperand(), V, DL); @@ -3341,7 +3306,7 @@ private: DbgAssign->replaceVariableLocationOp(II.getDest(), AdjustedPtr); }; for_each(at::getAssignmentMarkers(&II), UpdateAssignAddress); - for_each(at::getDPVAssignmentMarkers(&II), UpdateAssignAddress); + for_each(at::getDVRAssignmentMarkers(&II), UpdateAssignAddress); II.setDest(AdjustedPtr); II.setDestAlignment(SliceAlign); } else { @@ -3507,7 +3472,8 @@ private: Load->copyMetadata(II, {LLVMContext::MD_mem_parallel_loop_access, LLVMContext::MD_access_group}); if (AATags) - Load->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); + Load->setAAMetadata(AATags.adjustForAccess(NewBeginOffset - BeginOffset, + Load->getType(), DL)); Src = Load; } @@ -3529,7 +3495,8 @@ private: Store->copyMetadata(II, {LLVMContext::MD_mem_parallel_loop_access, LLVMContext::MD_access_group}); if (AATags) - Store->setAAMetadata(AATags.shift(NewBeginOffset - BeginOffset)); + Store->setAAMetadata(AATags.adjustForAccess(NewBeginOffset - BeginOffset, + Src->getType(), DL)); APInt Offset(DL.getIndexTypeSizeInBits(DstPtr->getType()), 0); if (IsDest) { @@ -3857,7 +3824,8 @@ private: DL.getIndexSizeInBits(Ptr->getType()->getPointerAddressSpace()), 0); if (AATags && GEPOperator::accumulateConstantOffset(BaseTy, GEPIndices, DL, Offset)) - Load->setAAMetadata(AATags.shift(Offset.getZExtValue())); + Load->setAAMetadata( + AATags.adjustForAccess(Offset.getZExtValue(), Load->getType(), DL)); Agg = IRB.CreateInsertValue(Agg, Load, Indices, Name + ".insert"); LLVM_DEBUG(dbgs() << " to: " << *Load << "\n"); @@ -3908,8 +3876,10 @@ private: APInt Offset( DL.getIndexSizeInBits(Ptr->getType()->getPointerAddressSpace()), 0); GEPOperator::accumulateConstantOffset(BaseTy, GEPIndices, DL, Offset); - if (AATags) - Store->setAAMetadata(AATags.shift(Offset.getZExtValue())); + if (AATags) 
{ + Store->setAAMetadata(AATags.adjustForAccess( + Offset.getZExtValue(), ExtractValue->getType(), DL)); + } // migrateDebugInfo requires the base Alloca. Walk to it from this gep. // If we cannot (because there's an intervening non-const or unbounded @@ -3925,7 +3895,7 @@ private: DL); } else { assert(at::getAssignmentMarkers(Store).empty() && - at::getDPVAssignmentMarkers(Store).empty() && + at::getDVRAssignmentMarkers(Store).empty() && "AT: unexpected debug.assign linked to store through " "unbounded GEP"); } @@ -3963,30 +3933,62 @@ private: return false; } - // Fold gep (select cond, ptr1, ptr2) => select cond, gep(ptr1), gep(ptr2) - bool foldGEPSelect(GetElementPtrInst &GEPI) { - if (!GEPI.hasAllConstantIndices()) + // Unfold gep (select cond, ptr1, ptr2), idx + // => select cond, gep(ptr1, idx), gep(ptr2, idx) + // and gep ptr, (select cond, idx1, idx2) + // => select cond, gep(ptr, idx1), gep(ptr, idx2) + bool unfoldGEPSelect(GetElementPtrInst &GEPI) { + // Check whether the GEP has exactly one select operand and all indices + // will become constant after the transform. + SelectInst *Sel = dyn_cast<SelectInst>(GEPI.getPointerOperand()); + for (Value *Op : GEPI.indices()) { + if (auto *SI = dyn_cast<SelectInst>(Op)) { + if (Sel) + return false; + + Sel = SI; + if (!isa<ConstantInt>(Sel->getTrueValue()) || + !isa<ConstantInt>(Sel->getFalseValue())) + return false; + continue; + } + + if (!isa<ConstantInt>(Op)) + return false; + } + + if (!Sel) return false; - SelectInst *Sel = cast<SelectInst>(GEPI.getPointerOperand()); + LLVM_DEBUG(dbgs() << " Rewriting gep(select) -> select(gep):\n"; + dbgs() << " original: " << *Sel << "\n"; + dbgs() << " " << GEPI << "\n";); + + auto GetNewOps = [&](Value *SelOp) { + SmallVector<Value *> NewOps; + for (Value *Op : GEPI.operands()) + if (Op == Sel) + NewOps.push_back(SelOp); + else + NewOps.push_back(Op); + return NewOps; + }; - LLVM_DEBUG(dbgs() << " Rewriting gep(select) -> select(gep):" - << "\n original: " << *Sel - << "\n " << GEPI); + Value *True = Sel->getTrueValue(); + Value *False = Sel->getFalseValue(); + SmallVector<Value *> TrueOps = GetNewOps(True); + SmallVector<Value *> FalseOps = GetNewOps(False); IRB.SetInsertPoint(&GEPI); - SmallVector<Value *, 4> Index(GEPI.indices()); - bool IsInBounds = GEPI.isInBounds(); + GEPNoWrapFlags NW = GEPI.getNoWrapFlags(); Type *Ty = GEPI.getSourceElementType(); - Value *True = Sel->getTrueValue(); - Value *NTrue = IRB.CreateGEP(Ty, True, Index, True->getName() + ".sroa.gep", - IsInBounds); - - Value *False = Sel->getFalseValue(); + Value *NTrue = IRB.CreateGEP(Ty, TrueOps[0], ArrayRef(TrueOps).drop_front(), + True->getName() + ".sroa.gep", NW); - Value *NFalse = IRB.CreateGEP(Ty, False, Index, - False->getName() + ".sroa.gep", IsInBounds); + Value *NFalse = + IRB.CreateGEP(Ty, FalseOps[0], ArrayRef(FalseOps).drop_front(), + False->getName() + ".sroa.gep", NW); Value *NSel = IRB.CreateSelect(Sel->getCondition(), NTrue, NFalse, Sel->getName() + ".sroa.sel"); @@ -3997,75 +3999,114 @@ private: Visited.insert(NSelI); enqueueUsers(*NSelI); - LLVM_DEBUG(dbgs() << "\n to: " << *NTrue - << "\n " << *NFalse - << "\n " << *NSel << '\n'); + LLVM_DEBUG(dbgs() << " to: " << *NTrue << "\n"; + dbgs() << " " << *NFalse << "\n"; + dbgs() << " " << *NSel << "\n";); return true; } - // Fold gep (phi ptr1, ptr2) => phi gep(ptr1), gep(ptr2) - bool foldGEPPhi(GetElementPtrInst &GEPI) { - if (!GEPI.hasAllConstantIndices()) - return false; + // Unfold gep (phi ptr1, ptr2), idx + // => phi ((gep ptr1, idx), (gep ptr2, idx)) 
+ // and gep ptr, (phi idx1, idx2) + // => phi ((gep ptr, idx1), (gep ptr, idx2)) + bool unfoldGEPPhi(GetElementPtrInst &GEPI) { + // To prevent infinitely expanding recursive phis, bail if the GEP pointer + // operand (looking through the phi if it is the phi we want to unfold) is + // an instruction besides a static alloca. + PHINode *Phi = dyn_cast<PHINode>(GEPI.getPointerOperand()); + auto IsInvalidPointerOperand = [](Value *V) { + if (!isa<Instruction>(V)) + return false; + if (auto *AI = dyn_cast<AllocaInst>(V)) + return !AI->isStaticAlloca(); + return true; + }; + if (Phi) { + if (any_of(Phi->operands(), IsInvalidPointerOperand)) + return false; + } else { + if (IsInvalidPointerOperand(GEPI.getPointerOperand())) + return false; + } + // Check whether the GEP has exactly one phi operand (including the pointer + // operand) and all indices will become constant after the transform. + for (Value *Op : GEPI.indices()) { + if (auto *SI = dyn_cast<PHINode>(Op)) { + if (Phi) + return false; + + Phi = SI; + if (!all_of(Phi->incoming_values(), + [](Value *V) { return isa<ConstantInt>(V); })) + return false; + continue; + } + + if (!isa<ConstantInt>(Op)) + return false; + } - PHINode *PHI = cast<PHINode>(GEPI.getPointerOperand()); - if (GEPI.getParent() != PHI->getParent() || - llvm::any_of(PHI->incoming_values(), [](Value *In) - { Instruction *I = dyn_cast<Instruction>(In); - return !I || isa<GetElementPtrInst>(I) || isa<PHINode>(I) || - succ_empty(I->getParent()) || - !I->getParent()->isLegalToHoistInto(); - })) + if (!Phi) return false; - LLVM_DEBUG(dbgs() << " Rewriting gep(phi) -> phi(gep):" - << "\n original: " << *PHI - << "\n " << GEPI - << "\n to: "); - - SmallVector<Value *, 4> Index(GEPI.indices()); - bool IsInBounds = GEPI.isInBounds(); - IRB.SetInsertPoint(GEPI.getParent(), GEPI.getParent()->getFirstNonPHIIt()); - PHINode *NewPN = IRB.CreatePHI(GEPI.getType(), PHI->getNumIncomingValues(), - PHI->getName() + ".sroa.phi"); - for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) { - BasicBlock *B = PHI->getIncomingBlock(I); - Value *NewVal = nullptr; - int Idx = NewPN->getBasicBlockIndex(B); - if (Idx >= 0) { - NewVal = NewPN->getIncomingValue(Idx); - } else { - Instruction *In = cast<Instruction>(PHI->getIncomingValue(I)); + LLVM_DEBUG(dbgs() << " Rewriting gep(phi) -> phi(gep):\n"; + dbgs() << " original: " << *Phi << "\n"; + dbgs() << " " << GEPI << "\n";); + + auto GetNewOps = [&](Value *PhiOp) { + SmallVector<Value *> NewOps; + for (Value *Op : GEPI.operands()) + if (Op == Phi) + NewOps.push_back(PhiOp); + else + NewOps.push_back(Op); + return NewOps; + }; - IRB.SetInsertPoint(In->getParent(), std::next(In->getIterator())); - Type *Ty = GEPI.getSourceElementType(); - NewVal = IRB.CreateGEP(Ty, In, Index, In->getName() + ".sroa.gep", - IsInBounds); + IRB.SetInsertPoint(Phi); + PHINode *NewPhi = IRB.CreatePHI(GEPI.getType(), Phi->getNumIncomingValues(), + Phi->getName() + ".sroa.phi"); + + Type *SourceTy = GEPI.getSourceElementType(); + // We only handle arguments, constants, and static allocas here, so we can + // insert GEPs at the end of the entry block. 
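A source-level picture of the new cases the unfold transforms above can handle (illustrative only): previously the GEP's pointer operand had to be the select/phi and every index had to be constant; now a select or phi feeding an index is also unfolded, as long as each index becomes a constant afterwards.

    // With control flow, Idx becomes a phi of the constants 0 and 1, so the
    // address is roughly "gep Buf, (phi 0, 1)". Unfolding turns it into a phi
    // of two constant-index GEPs (&Buf[0], &Buf[1]), which SROA can then
    // split into scalars.
    int pickElement(bool TakeFirst) {
      int Buf[2] = {10, 20};
      int Idx;
      if (TakeFirst)
        Idx = 0;
      else
        Idx = 1;
      return Buf[Idx];
    }
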
+ IRB.SetInsertPoint(GEPI.getFunction()->getEntryBlock().getTerminator()); + for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) { + Value *Op = Phi->getIncomingValue(I); + BasicBlock *BB = Phi->getIncomingBlock(I); + Value *NewGEP; + if (int NI = NewPhi->getBasicBlockIndex(BB); NI >= 0) { + NewGEP = NewPhi->getIncomingValue(NI); + } else { + SmallVector<Value *> NewOps = GetNewOps(Op); + NewGEP = + IRB.CreateGEP(SourceTy, NewOps[0], ArrayRef(NewOps).drop_front(), + Phi->getName() + ".sroa.gep", GEPI.getNoWrapFlags()); } - NewPN->addIncoming(NewVal, B); + NewPhi->addIncoming(NewGEP, BB); } Visited.erase(&GEPI); - GEPI.replaceAllUsesWith(NewPN); + GEPI.replaceAllUsesWith(NewPhi); GEPI.eraseFromParent(); - Visited.insert(NewPN); - enqueueUsers(*NewPN); + Visited.insert(NewPhi); + enqueueUsers(*NewPhi); - LLVM_DEBUG(for (Value *In : NewPN->incoming_values()) - dbgs() << "\n " << *In; - dbgs() << "\n " << *NewPN << '\n'); + LLVM_DEBUG(dbgs() << " to: "; + for (Value *In + : NewPhi->incoming_values()) dbgs() + << "\n " << *In; + dbgs() << "\n " << *NewPhi << '\n'); return true; } bool visitGetElementPtrInst(GetElementPtrInst &GEPI) { - if (isa<SelectInst>(GEPI.getPointerOperand()) && - foldGEPSelect(GEPI)) + if (unfoldGEPSelect(GEPI)) return true; - if (isa<PHINode>(GEPI.getPointerOperand()) && - foldGEPPhi(GEPI)) + if (unfoldGEPPhi(GEPI)) return true; enqueueUsers(GEPI); @@ -4137,17 +4178,17 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, return nullptr; if (isa<ArrayType>(Ty) || isa<VectorType>(Ty)) { - Type *ElementTy; - uint64_t TyNumElements; - if (auto *AT = dyn_cast<ArrayType>(Ty)) { - ElementTy = AT->getElementType(); - TyNumElements = AT->getNumElements(); - } else { - // FIXME: This isn't right for vectors with non-byte-sized or - // non-power-of-two sized elements. - auto *VT = cast<FixedVectorType>(Ty); - ElementTy = VT->getElementType(); - TyNumElements = VT->getNumElements(); + Type *ElementTy; + uint64_t TyNumElements; + if (auto *AT = dyn_cast<ArrayType>(Ty)) { + ElementTy = AT->getElementType(); + TyNumElements = AT->getNumElements(); + } else { + // FIXME: This isn't right for vectors with non-byte-sized or + // non-power-of-two sized elements. + auto *VT = cast<FixedVectorType>(Ty); + ElementTy = VT->getElementType(); + TyNumElements = VT->getNumElements(); } uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedValue(); uint64_t NumSkippedElements = Offset / ElementSize; @@ -4458,7 +4499,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { // them to the alloca slices. 
SmallDenseMap<LoadInst *, std::vector<LoadInst *>, 1> SplitLoadsMap; std::vector<LoadInst *> SplitLoads; - const DataLayout &DL = AI.getModule()->getDataLayout(); + const DataLayout &DL = AI.getDataLayout(); for (LoadInst *LI : Loads) { SplitLoads.clear(); @@ -4532,6 +4573,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { Value *StoreBasePtr = SI->getPointerOperand(); IRB.SetInsertPoint(SI); + AAMDNodes AATags = SI->getAAMetadata(); LLVM_DEBUG(dbgs() << " Splitting store of load: " << *SI << "\n"); @@ -4551,6 +4593,10 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { PStore->copyMetadata(*SI, {LLVMContext::MD_mem_parallel_loop_access, LLVMContext::MD_access_group, LLVMContext::MD_DIAssignID}); + + if (AATags) + PStore->setAAMetadata( + AATags.adjustForAccess(PartOffset, PLoad->getType(), DL)); LLVM_DEBUG(dbgs() << " +" << PartOffset << ":" << *PStore << "\n"); } @@ -4747,7 +4793,7 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, // or an i8 array of an appropriate size. Type *SliceTy = nullptr; VectorType *SliceVecTy = nullptr; - const DataLayout &DL = AI.getModule()->getDataLayout(); + const DataLayout &DL = AI.getDataLayout(); std::pair<Type *, IntegerType *> CommonUseTy = findCommonType(P.begin(), P.end(), P.endOffset()); // Do all uses operate on the same type? @@ -4817,15 +4863,15 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, NewAI = new AllocaInst( SliceTy, AI.getAddressSpace(), nullptr, IsUnconstrained ? DL.getPrefTypeAlign(SliceTy) : Alignment, - AI.getName() + ".sroa." + Twine(P.begin() - AS.begin()), &AI); + AI.getName() + ".sroa." + Twine(P.begin() - AS.begin()), + AI.getIterator()); // Copy the old AI debug location over to the new one. NewAI->setDebugLoc(AI.getDebugLoc()); ++NumNewAllocas; } - LLVM_DEBUG(dbgs() << "Rewriting alloca partition " - << "[" << P.beginOffset() << "," << P.endOffset() - << ") to: " << *NewAI << "\n"); + LLVM_DEBUG(dbgs() << "Rewriting alloca partition " << "[" << P.beginOffset() + << "," << P.endOffset() << ") to: " << *NewAI << "\n"); // Track the high watermark on the worklist as it is only relevant for // promoted allocas. We will reset it to this point if the alloca is not in @@ -4921,45 +4967,236 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, return NewAI; } -static void insertNewDbgInst(DIBuilder &DIB, DbgDeclareInst *Orig, - AllocaInst *NewAddr, DIExpression *NewFragmentExpr, - Instruction *BeforeInst) { - DIB.insertDeclare(NewAddr, Orig->getVariable(), NewFragmentExpr, +// There isn't a shared interface to get the "address" parts out of a +// dbg.declare and dbg.assign, so provide some wrappers now for +// both debug intrinsics and records. 
+const Value *getAddress(const DbgVariableIntrinsic *DVI) { + if (const auto *DAI = dyn_cast<DbgAssignIntrinsic>(DVI)) + return DAI->getAddress(); + return cast<DbgDeclareInst>(DVI)->getAddress(); +} + +const Value *getAddress(const DbgVariableRecord *DVR) { + assert(DVR->getType() == DbgVariableRecord::LocationType::Declare || + DVR->getType() == DbgVariableRecord::LocationType::Assign); + return DVR->getAddress(); +} + +bool isKillAddress(const DbgVariableIntrinsic *DVI) { + if (const auto *DAI = dyn_cast<DbgAssignIntrinsic>(DVI)) + return DAI->isKillAddress(); + return cast<DbgDeclareInst>(DVI)->isKillLocation(); +} + +bool isKillAddress(const DbgVariableRecord *DVR) { + assert(DVR->getType() == DbgVariableRecord::LocationType::Declare || + DVR->getType() == DbgVariableRecord::LocationType::Assign); + if (DVR->getType() == DbgVariableRecord::LocationType::Assign) + return DVR->isKillAddress(); + return DVR->isKillLocation(); +} + +const DIExpression *getAddressExpression(const DbgVariableIntrinsic *DVI) { + if (const auto *DAI = dyn_cast<DbgAssignIntrinsic>(DVI)) + return DAI->getAddressExpression(); + return cast<DbgDeclareInst>(DVI)->getExpression(); +} + +const DIExpression *getAddressExpression(const DbgVariableRecord *DVR) { + assert(DVR->getType() == DbgVariableRecord::LocationType::Declare || + DVR->getType() == DbgVariableRecord::LocationType::Assign); + if (DVR->getType() == DbgVariableRecord::LocationType::Assign) + return DVR->getAddressExpression(); + return DVR->getExpression(); +} + +/// Create or replace an existing fragment in a DIExpression with \p Frag. +/// If the expression already contains a DW_OP_LLVM_extract_bits_[sz]ext +/// operation, add \p BitExtractOffset to the offset part. +/// +/// Returns the new expression, or nullptr if this fails (see details below). +/// +/// This function is similar to DIExpression::createFragmentExpression except +/// for 3 important distinctions: +/// 1. The new fragment isn't relative to an existing fragment. +/// 2. It assumes the computed location is a memory location. This means we +/// don't need to perform checks that creating the fragment preserves the +/// expression semantics. +/// 3. Existing extract_bits are modified independently of fragment changes +/// using \p BitExtractOffset. A change to the fragment offset or size +/// may affect a bit extract. But a bit extract offset can change +/// independently of the fragment dimensions. +/// +/// Returns the new expression, or nullptr if one couldn't be created. +/// Ideally this is only used to signal that a bit-extract has become +/// zero-sized (and thus the new debug record has no size and can be +/// dropped), however, it fails for other reasons too - see the FIXME below. +/// +/// FIXME: To keep the change that introduces this function NFC it bails +/// in some situations unecessarily, e.g. when fragment and bit extract +/// sizes differ. 
+static DIExpression *createOrReplaceFragment(const DIExpression *Expr, + DIExpression::FragmentInfo Frag, + int64_t BitExtractOffset) { + SmallVector<uint64_t, 8> Ops; + bool HasFragment = false; + bool HasBitExtract = false; + + for (auto &Op : Expr->expr_ops()) { + if (Op.getOp() == dwarf::DW_OP_LLVM_fragment) { + HasFragment = true; + continue; + } + if (Op.getOp() == dwarf::DW_OP_LLVM_extract_bits_zext || + Op.getOp() == dwarf::DW_OP_LLVM_extract_bits_sext) { + HasBitExtract = true; + int64_t ExtractOffsetInBits = Op.getArg(0); + int64_t ExtractSizeInBits = Op.getArg(1); + + // DIExpression::createFragmentExpression doesn't know how to handle + // a fragment that is smaller than the extract. Copy the behaviour + // (bail) to avoid non-NFC changes. + // FIXME: Don't do this. + if (Frag.SizeInBits < uint64_t(ExtractSizeInBits)) + return nullptr; + + assert(BitExtractOffset <= 0); + int64_t AdjustedOffset = ExtractOffsetInBits + BitExtractOffset; + + // DIExpression::createFragmentExpression doesn't know what to do + // if the new extract starts "outside" the existing one. Copy the + // behaviour (bail) to avoid non-NFC changes. + // FIXME: Don't do this. + if (AdjustedOffset < 0) + return nullptr; + + Ops.push_back(Op.getOp()); + Ops.push_back(std::max<int64_t>(0, AdjustedOffset)); + Ops.push_back(ExtractSizeInBits); + continue; + } + Op.appendToVector(Ops); + } + + // Unsupported by createFragmentExpression, so don't support it here yet to + // preserve NFC-ness. + if (HasFragment && HasBitExtract) + return nullptr; + + if (!HasBitExtract) { + Ops.push_back(dwarf::DW_OP_LLVM_fragment); + Ops.push_back(Frag.OffsetInBits); + Ops.push_back(Frag.SizeInBits); + } + return DIExpression::get(Expr->getContext(), Ops); +} + +/// Insert a new dbg.declare. +/// \p Orig Original to copy debug loc and variable from. +/// \p NewAddr Location's new base address. +/// \p NewAddrExpr New expression to apply to address. +/// \p BeforeInst Insert position. +/// \p NewFragment New fragment (absolute, non-relative). +/// \p BitExtractAdjustment Offset to apply to any extract_bits op. +static void +insertNewDbgInst(DIBuilder &DIB, DbgDeclareInst *Orig, AllocaInst *NewAddr, + DIExpression *NewAddrExpr, Instruction *BeforeInst, + std::optional<DIExpression::FragmentInfo> NewFragment, + int64_t BitExtractAdjustment) { + if (NewFragment) + NewAddrExpr = createOrReplaceFragment(NewAddrExpr, *NewFragment, + BitExtractAdjustment); + if (!NewAddrExpr) + return; + + DIB.insertDeclare(NewAddr, Orig->getVariable(), NewAddrExpr, Orig->getDebugLoc(), BeforeInst); } -static void insertNewDbgInst(DIBuilder &DIB, DbgAssignIntrinsic *Orig, - AllocaInst *NewAddr, DIExpression *NewFragmentExpr, - Instruction *BeforeInst) { + +/// Insert a new dbg.assign. +/// \p Orig Original to copy debug loc, variable, value and value expression +/// from. +/// \p NewAddr Location's new base address. +/// \p NewAddrExpr New expression to apply to address. +/// \p BeforeInst Insert position. +/// \p NewFragment New fragment (absolute, non-relative). +/// \p BitExtractAdjustment Offset to apply to any extract_bits op. +static void +insertNewDbgInst(DIBuilder &DIB, DbgAssignIntrinsic *Orig, AllocaInst *NewAddr, + DIExpression *NewAddrExpr, Instruction *BeforeInst, + std::optional<DIExpression::FragmentInfo> NewFragment, + int64_t BitExtractAdjustment) { + // DIBuilder::insertDbgAssign will insert the #dbg_assign after NewAddr. (void)BeforeInst; + + // A dbg.assign puts fragment info in the value expression only. 
The address + // expression has already been built: NewAddrExpr. + DIExpression *NewFragmentExpr = Orig->getExpression(); + if (NewFragment) + NewFragmentExpr = createOrReplaceFragment(NewFragmentExpr, *NewFragment, + BitExtractAdjustment); + if (!NewFragmentExpr) + return; + + // Apply a DIAssignID to the store if it doesn't already have it. if (!NewAddr->hasMetadata(LLVMContext::MD_DIAssignID)) { NewAddr->setMetadata(LLVMContext::MD_DIAssignID, DIAssignID::getDistinct(NewAddr->getContext())); } - auto *NewAssign = DIB.insertDbgAssign( - NewAddr, Orig->getValue(), Orig->getVariable(), NewFragmentExpr, NewAddr, - Orig->getAddressExpression(), Orig->getDebugLoc()); + + Instruction *NewAssign = + DIB.insertDbgAssign(NewAddr, Orig->getValue(), Orig->getVariable(), + NewFragmentExpr, NewAddr, NewAddrExpr, + Orig->getDebugLoc()) + .get<Instruction *>(); LLVM_DEBUG(dbgs() << "Created new assign intrinsic: " << *NewAssign << "\n"); (void)NewAssign; } -static void insertNewDbgInst(DIBuilder &DIB, DPValue *Orig, AllocaInst *NewAddr, - DIExpression *NewFragmentExpr, - Instruction *BeforeInst) { + +/// Insert a new DbgRecord. +/// \p Orig Original to copy record type, debug loc and variable from, and +/// additionally value and value expression for dbg_assign records. +/// \p NewAddr Location's new base address. +/// \p NewAddrExpr New expression to apply to address. +/// \p BeforeInst Insert position. +/// \p NewFragment New fragment (absolute, non-relative). +/// \p BitExtractAdjustment Offset to apply to any extract_bits op. +static void +insertNewDbgInst(DIBuilder &DIB, DbgVariableRecord *Orig, AllocaInst *NewAddr, + DIExpression *NewAddrExpr, Instruction *BeforeInst, + std::optional<DIExpression::FragmentInfo> NewFragment, + int64_t BitExtractAdjustment) { (void)DIB; + + // A dbg_assign puts fragment info in the value expression only. The address + // expression has already been built: NewAddrExpr. A dbg_declare puts the + // new fragment info into NewAddrExpr (as it only has one expression). + DIExpression *NewFragmentExpr = + Orig->isDbgAssign() ? Orig->getExpression() : NewAddrExpr; + if (NewFragment) + NewFragmentExpr = createOrReplaceFragment(NewFragmentExpr, *NewFragment, + BitExtractAdjustment); + if (!NewFragmentExpr) + return; + if (Orig->isDbgDeclare()) { - DPValue *DPV = DPValue::createDPVDeclare( + DbgVariableRecord *DVR = DbgVariableRecord::createDVRDeclare( NewAddr, Orig->getVariable(), NewFragmentExpr, Orig->getDebugLoc()); - BeforeInst->getParent()->insertDPValueBefore(DPV, - BeforeInst->getIterator()); + BeforeInst->getParent()->insertDbgRecordBefore(DVR, + BeforeInst->getIterator()); return; } + + // Apply a DIAssignID to the store if it doesn't already have it. 
if (!NewAddr->hasMetadata(LLVMContext::MD_DIAssignID)) { NewAddr->setMetadata(LLVMContext::MD_DIAssignID, DIAssignID::getDistinct(NewAddr->getContext())); } - auto *NewAssign = DPValue::createLinkedDPVAssign( + + DbgVariableRecord *NewAssign = DbgVariableRecord::createLinkedDVRAssign( NewAddr, Orig->getValue(), Orig->getVariable(), NewFragmentExpr, NewAddr, - Orig->getAddressExpression(), Orig->getDebugLoc()); - LLVM_DEBUG(dbgs() << "Created new DPVAssign: " << *NewAssign << "\n"); + NewAddrExpr, Orig->getDebugLoc()); + LLVM_DEBUG(dbgs() << "Created new DVRAssign: " << *NewAssign << "\n"); (void)NewAssign; } @@ -5010,8 +5247,7 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { IsSorted = false; } } - } - else { + } else { // We only allow whole-alloca splittable loads and stores // for a large alloca to avoid creating too large BitVector. for (Slice &S : AS) { @@ -5030,7 +5266,7 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { } if (!IsSorted) - llvm::sort(AS); + llvm::stable_sort(AS); /// Describes the allocas introduced by rewritePartition in order to migrate /// the debug info. @@ -5039,7 +5275,7 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { uint64_t Offset; uint64_t Size; Fragment(AllocaInst *AI, uint64_t O, uint64_t S) - : Alloca(AI), Offset(O), Size(S) {} + : Alloca(AI), Offset(O), Size(S) {} }; SmallVector<Fragment, 4> Fragments; @@ -5053,7 +5289,8 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { DL.getTypeSizeInBits(NewAI->getAllocatedType()).getFixedValue(); // Don't include any padding. uint64_t Size = std::min(AllocaSize, P.size() * SizeOfByte); - Fragments.push_back(Fragment(NewAI, P.beginOffset() * SizeOfByte, Size)); + Fragments.push_back( + Fragment(NewAI, P.beginOffset() * SizeOfByte, Size)); } } ++NumPartitions; @@ -5065,54 +5302,78 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { // Migrate debug information from the old alloca to the new alloca(s) // and the individual partitions. auto MigrateOne = [&](auto *DbgVariable) { - auto *Expr = DbgVariable->getExpression(); - DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false); - uint64_t AllocaSize = - DL.getTypeSizeInBits(AI.getAllocatedType()).getFixedValue(); - for (auto Fragment : Fragments) { - // Create a fragment expression describing the new partition or reuse AI's - // expression if there is only one partition. - auto *FragmentExpr = Expr; - if (Fragment.Size < AllocaSize || Expr->isFragment()) { - // If this alloca is already a scalar replacement of a larger aggregate, - // Fragment.Offset describes the offset inside the scalar. - auto ExprFragment = Expr->getFragmentInfo(); - uint64_t Offset = ExprFragment ? ExprFragment->OffsetInBits : 0; - uint64_t Start = Offset + Fragment.Offset; - uint64_t Size = Fragment.Size; - if (ExprFragment) { - uint64_t AbsEnd = - ExprFragment->OffsetInBits + ExprFragment->SizeInBits; - if (Start >= AbsEnd) { - // No need to describe a SROAed padding. - continue; - } - Size = std::min(Size, AbsEnd - Start); - } - // The new, smaller fragment is stenciled out from the old fragment. - if (auto OrigFragment = FragmentExpr->getFragmentInfo()) { - assert(Start >= OrigFragment->OffsetInBits && - "new fragment is outside of original fragment"); - Start -= OrigFragment->OffsetInBits; - } + // Can't overlap with undef memory. + if (isKillAddress(DbgVariable)) + return; - // The alloca may be larger than the variable. 
- auto VarSize = DbgVariable->getVariable()->getSizeInBits(); - if (VarSize) { - if (Size > *VarSize) - Size = *VarSize; - if (Size == 0 || Start + Size > *VarSize) - continue; - } + const Value *DbgPtr = getAddress(DbgVariable); + DIExpression::FragmentInfo VarFrag = + DbgVariable->getFragmentOrEntireVariable(); + // Get the address expression constant offset if one exists and the ops + // that come after it. + int64_t CurrentExprOffsetInBytes = 0; + SmallVector<uint64_t> PostOffsetOps; + if (!getAddressExpression(DbgVariable) + ->extractLeadingOffset(CurrentExprOffsetInBytes, PostOffsetOps)) + return; // Couldn't interpret this DIExpression - drop the var. + + // Offset defined by a DW_OP_LLVM_extract_bits_[sz]ext. + int64_t ExtractOffsetInBits = 0; + for (auto Op : getAddressExpression(DbgVariable)->expr_ops()) { + if (Op.getOp() == dwarf::DW_OP_LLVM_extract_bits_zext || + Op.getOp() == dwarf::DW_OP_LLVM_extract_bits_sext) { + ExtractOffsetInBits = Op.getArg(0); + break; + } + } - // Avoid creating a fragment expression that covers the entire variable. - if (!VarSize || *VarSize != Size) { - if (auto E = - DIExpression::createFragmentExpression(Expr, Start, Size)) - FragmentExpr = *E; - else - continue; - } + DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false); + for (auto Fragment : Fragments) { + int64_t OffsetFromLocationInBits; + std::optional<DIExpression::FragmentInfo> NewDbgFragment; + // Find the variable fragment that the new alloca slice covers. + // Drop debug info for this variable fragment if we can't compute an + // intersect between it and the alloca slice. + if (!DIExpression::calculateFragmentIntersect( + DL, &AI, Fragment.Offset, Fragment.Size, DbgPtr, + CurrentExprOffsetInBytes * 8, ExtractOffsetInBits, VarFrag, + NewDbgFragment, OffsetFromLocationInBits)) + continue; // Do not migrate this fragment to this slice. + + // Zero sized fragment indicates there's no intersect between the variable + // fragment and the alloca slice. Skip this slice for this variable + // fragment. + if (NewDbgFragment && !NewDbgFragment->SizeInBits) + continue; // Do not migrate this fragment to this slice. + + // No fragment indicates DbgVariable's variable or fragment exactly + // overlaps the slice; copy its fragment (or nullopt if there isn't one). + if (!NewDbgFragment) + NewDbgFragment = DbgVariable->getFragment(); + + // Reduce the new expression offset by the bit-extract offset since + // we'll be keeping that. + int64_t OffestFromNewAllocaInBits = + OffsetFromLocationInBits - ExtractOffsetInBits; + // We need to adjust an existing bit extract if the offset expression + // can't eat the slack (i.e., if the new offset would be negative). + int64_t BitExtractOffset = + std::min<int64_t>(0, OffestFromNewAllocaInBits); + // The magnitude of a negative value indicates the number of bits into + // the existing variable fragment that the memory region begins. The new + // variable fragment already excludes those bits - the new DbgPtr offset + // only needs to be applied if it's positive. + OffestFromNewAllocaInBits = + std::max(int64_t(0), OffestFromNewAllocaInBits); + + // Rebuild the expression: + // {Offset(OffestFromNewAllocaInBits), PostOffsetOps, NewDbgFragment} + // Add NewDbgFragment later, because dbg.assigns don't want it in the + // address expression but the value expression instead. 
+ DIExpression *NewExpr = DIExpression::get(AI.getContext(), PostOffsetOps); + if (OffestFromNewAllocaInBits > 0) { + int64_t OffsetInBytes = (OffestFromNewAllocaInBits + 7) / 8; + NewExpr = DIExpression::prepend(NewExpr, /*flags=*/0, OffsetInBytes); } // Remove any existing intrinsics on the new alloca describing @@ -5127,18 +5388,19 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { OldDII->eraseFromParent(); }; for_each(findDbgDeclares(Fragment.Alloca), RemoveOne); - for_each(findDPVDeclares(Fragment.Alloca), RemoveOne); + for_each(findDVRDeclares(Fragment.Alloca), RemoveOne); - insertNewDbgInst(DIB, DbgVariable, Fragment.Alloca, FragmentExpr, &AI); + insertNewDbgInst(DIB, DbgVariable, Fragment.Alloca, NewExpr, &AI, + NewDbgFragment, BitExtractOffset); } }; // Migrate debug information from the old alloca to the new alloca(s) // and the individual partitions. for_each(findDbgDeclares(&AI), MigrateOne); - for_each(findDPVDeclares(&AI), MigrateOne); + for_each(findDVRDeclares(&AI), MigrateOne); for_each(at::getAssignmentMarkers(&AI), MigrateOne); - for_each(at::getDPVAssignmentMarkers(&AI), MigrateOne); + for_each(at::getDVRAssignmentMarkers(&AI), MigrateOne); return Changed; } @@ -5177,7 +5439,7 @@ SROA::runOnAlloca(AllocaInst &AI) { Changed = true; return {Changed, CFGChanged}; } - const DataLayout &DL = AI.getModule()->getDataLayout(); + const DataLayout &DL = AI.getDataLayout(); // Skip alloca forms that this analysis can't handle. auto *AT = AI.getAllocatedType(); @@ -5262,7 +5524,7 @@ bool SROA::deleteDeadInstructions( DeletedAllocas.insert(AI); for (DbgDeclareInst *OldDII : findDbgDeclares(AI)) OldDII->eraseFromParent(); - for (DPValue *OldDII : findDPVDeclares(AI)) + for (DbgVariableRecord *OldDII : findDVRDeclares(AI)) OldDII->eraseFromParent(); } @@ -5309,7 +5571,7 @@ bool SROA::promoteAllocas(Function &F) { std::pair<bool /*Changed*/, bool /*CFGChanged*/> SROA::runSROA(Function &F) { LLVM_DEBUG(dbgs() << "SROA function: " << F.getName() << "\n"); - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); BasicBlock &EntryBB = F.getEntryBlock(); for (BasicBlock::iterator I = EntryBB.begin(), E = std::prev(EntryBB.end()); I != E; ++I) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalar.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalar.cpp index 4ce6ce93be33..cb1456b14632 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalar.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalar.cpp @@ -29,7 +29,6 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeInstSimplifyLegacyPassPass(Registry); initializeLegacyLICMPassPass(Registry); initializeLoopDataPrefetchLegacyPassPass(Registry); - initializeLoopRotateLegacyPassPass(Registry); initializeLoopStrengthReducePass(Registry); initializeLoopUnrollPass(Registry); initializeLowerAtomicLegacyPassPass(Registry); @@ -49,4 +48,5 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeSpeculativeExecutionLegacyPassPass(Registry); initializeStraightLineStrengthReduceLegacyPassPass(Registry); initializePlaceBackedgeSafepointsLegacyPassPass(Registry); + initializePostInlineEntryExitInstrumenterPass(Registry); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp index c01d03f64472..8eadf8900020 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp +++ 
b/contrib/llvm-project/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp @@ -627,6 +627,7 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI, Value *Ptr = CI->getArgOperand(0); Value *Mask = CI->getArgOperand(1); Value *PassThru = CI->getArgOperand(2); + Align Alignment = CI->getParamAlign(0).valueOrOne(); auto *VecType = cast<FixedVectorType>(CI->getType()); @@ -644,6 +645,10 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI, // The result vector Value *VResult = PassThru; + // Adjust alignment for the scalar instruction. + const Align AdjustedAlignment = + commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8); + // Shorten the way if the mask is a vector of constants. // Create a build_vector pattern, with loads/poisons as necessary and then // shuffle blend with the pass through value. @@ -659,7 +664,7 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI, } else { Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex); - InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1), + InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, AdjustedAlignment, "Load" + Twine(Idx)); ShuffleMask[Idx] = Idx; ++MemIndex; @@ -713,7 +718,7 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI, CondBlock->setName("cond.load"); Builder.SetInsertPoint(CondBlock->getTerminator()); - LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1)); + LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, AdjustedAlignment); Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx); // Move the pointer if there are more blocks to come. @@ -755,6 +760,7 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI, Value *Src = CI->getArgOperand(0); Value *Ptr = CI->getArgOperand(1); Value *Mask = CI->getArgOperand(2); + Align Alignment = CI->getParamAlign(1).valueOrOne(); auto *VecType = cast<FixedVectorType>(Src->getType()); @@ -767,6 +773,10 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI, Type *EltTy = VecType->getElementType(); + // Adjust alignment for the scalar instruction. + const Align AdjustedAlignment = + commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8); + unsigned VectorWidth = VecType->getNumElements(); // Shorten the way if the mask is a vector of constants. @@ -778,7 +788,7 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI, Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex); - Builder.CreateAlignedStore(OneElt, NewPtr, Align(1)); + Builder.CreateAlignedStore(OneElt, NewPtr, AdjustedAlignment); ++MemIndex; } CI->eraseFromParent(); @@ -824,7 +834,7 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI, Builder.SetInsertPoint(CondBlock->getTerminator()); Value *OneElt = Builder.CreateExtractElement(Src, Idx); - Builder.CreateAlignedStore(OneElt, Ptr, Align(1)); + Builder.CreateAlignedStore(OneElt, Ptr, AdjustedAlignment); // Move the pointer if there are more blocks to come. 
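
The alignment adjustment introduced above follows the usual rule that a scalar element access may only assume the greatest power of two dividing both the intrinsic's pointer alignment and the element size. A self-contained sketch of that computation (LLVM's commonAlignment/MinAlign do the equivalent):

#include <cassert>
#include <cstdint>

// Greatest power of two dividing both A and B, where A is the alignment taken
// from the masked intrinsic's pointer parameter and B is the element size in
// bytes. This is what the scalarized per-element loads/stores may assume.
uint64_t commonPow2Divisor(uint64_t A, uint64_t B) {
  assert(A && B && "expected non-zero inputs");
  // The lowest set bit of (A | B) is the largest power of two dividing both.
  uint64_t X = A | B;
  return X & (~X + 1);
}

// Example: a compressstore of <4 x i32> through a pointer that is only
// 2-byte aligned may use 2-byte aligned scalar stores:
//   commonPow2Divisor(2, 4) == 2.
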
Value *NewPtr; @@ -852,6 +862,69 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI, ModifiedDT = true; } +static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI, + DomTreeUpdater *DTU, + bool &ModifiedDT) { + // If we extend histogram to return a result someday (like the updated vector) + // then we'll need to support it here. + assert(CI->getType()->isVoidTy() && "Histogram with non-void return."); + Value *Ptrs = CI->getArgOperand(0); + Value *Inc = CI->getArgOperand(1); + Value *Mask = CI->getArgOperand(2); + + auto *AddrType = cast<FixedVectorType>(Ptrs->getType()); + Type *EltTy = Inc->getType(); + + IRBuilder<> Builder(CI->getContext()); + Instruction *InsertPt = CI; + Builder.SetInsertPoint(InsertPt); + + Builder.SetCurrentDebugLocation(CI->getDebugLoc()); + + // FIXME: Do we need to add an alignment parameter to the intrinsic? + unsigned VectorWidth = AddrType->getNumElements(); + + // Shorten the way if the mask is a vector of constants. + if (isConstantIntVector(Mask)) { + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) + continue; + Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); + LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx)); + Value *Add = Builder.CreateAdd(Load, Inc); + Builder.CreateStore(Add, Ptr); + } + CI->eraseFromParent(); + return; + } + + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + Value *Predicate = + Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); + + Instruction *ThenTerm = + SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false, + /*BranchWeights=*/nullptr, DTU); + + BasicBlock *CondBlock = ThenTerm->getParent(); + CondBlock->setName("cond.histogram.update"); + + Builder.SetInsertPoint(CondBlock->getTerminator()); + Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); + LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx)); + Value *Add = Builder.CreateAdd(Load, Inc); + Builder.CreateStore(Add, Ptr); + + // Create "else" block, fill it in the next iteration + BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0); + NewIfBlock->setName("else"); + Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin()); + } + + CI->eraseFromParent(); + ModifiedDT = true; +} + static bool runImpl(Function &F, const TargetTransformInfo &TTI, DominatorTree *DT) { std::optional<DomTreeUpdater> DTU; @@ -860,7 +933,7 @@ static bool runImpl(Function &F, const TargetTransformInfo &TTI, bool EverMadeChange = false; bool MadeChange = true; - auto &DL = F.getParent()->getDataLayout(); + auto &DL = F.getDataLayout(); while (MadeChange) { MadeChange = false; for (BasicBlock &BB : llvm::make_early_inc_range(F)) { @@ -928,6 +1001,12 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, switch (II->getIntrinsicID()) { default: break; + case Intrinsic::experimental_vector_histogram_add: + if (TTI.isLegalMaskedVectorHistogram(CI->getArgOperand(0)->getType(), + CI->getArgOperand(1)->getType())) + return false; + scalarizeMaskedVectorHistogram(DL, CI, DTU, ModifiedDT); + return true; case Intrinsic::masked_load: // Scalarize unsupported vector masked load if (TTI.isLegalMaskedLoad( @@ -969,12 +1048,16 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, return true; } case Intrinsic::masked_expandload: - if (TTI.isLegalMaskedExpandLoad(CI->getType())) + if (TTI.isLegalMaskedExpandLoad( + CI->getType(), + 
CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne())) return false; scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT); return true; case Intrinsic::masked_compressstore: - if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType())) + if (TTI.isLegalMaskedCompressStore( + CI->getArgOperand(0)->getType(), + CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne())) return false; scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT); return true; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalarizer.cpp index 3eca9ac7c267..2bed3480da1c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -523,8 +523,8 @@ void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV) { SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; Op->getAllMetadataOtherThanDebugLoc(MDs); - for (unsigned I = 0, E = CV.size(); I != E; ++I) { - if (Instruction *New = dyn_cast<Instruction>(CV[I])) { + for (Value *V : CV) { + if (Instruction *New = dyn_cast<Instruction>(V)) { for (const auto &MD : MDs) if (canTransferMetadata(MD.first)) New->setMetadata(MD.first, MD.second); @@ -1107,7 +1107,7 @@ bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) { return false; std::optional<VectorLayout> Layout = getVectorLayout( - LI.getType(), LI.getAlign(), LI.getModule()->getDataLayout()); + LI.getType(), LI.getAlign(), LI.getDataLayout()); if (!Layout) return false; @@ -1133,7 +1133,7 @@ bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) { Value *FullValue = SI.getValueOperand(); std::optional<VectorLayout> Layout = getVectorLayout( - FullValue->getType(), SI.getAlign(), SI.getModule()->getDataLayout()); + FullValue->getType(), SI.getAlign(), SI.getDataLayout()); if (!Layout) return false; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index 17c466f38c9c..73e3ff296cf1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -57,7 +57,7 @@ // // base = gep a, 0, x, y // load base -// laod base + 1 * sizeof(float) +// load base + 1 * sizeof(float) // load base + 32 * sizeof(float) // load base + 33 * sizeof(float) // @@ -174,6 +174,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -235,18 +236,16 @@ public: /// \p UserChainTail Outputs the tail of UserChain so that we can /// garbage-collect unused instructions in UserChain. static Value *Extract(Value *Idx, GetElementPtrInst *GEP, - User *&UserChainTail, const DominatorTree *DT); + User *&UserChainTail); /// Looks for a constant offset from the given GEP index without extracting /// it. It returns the numeric value of the extracted constant offset (0 if /// failed). The meaning of the arguments are the same as Extract. 
- static int64_t Find(Value *Idx, GetElementPtrInst *GEP, - const DominatorTree *DT); + static int64_t Find(Value *Idx, GetElementPtrInst *GEP); private: - ConstantOffsetExtractor(Instruction *InsertionPt, const DominatorTree *DT) - : IP(InsertionPt), DL(InsertionPt->getModule()->getDataLayout()), DT(DT) { - } + ConstantOffsetExtractor(BasicBlock::iterator InsertionPt) + : IP(InsertionPt), DL(InsertionPt->getDataLayout()) {} /// Searches the expression that computes V for a non-zero constant C s.t. /// V can be reassociated into the form V' + C. If the searching is @@ -333,10 +332,9 @@ private: SmallVector<CastInst *, 16> ExtInsts; /// Insertion position of cloned instructions. - Instruction *IP; + BasicBlock::iterator IP; const DataLayout &DL; - const DominatorTree *DT; }; /// A pass that tries to split every GEP in the function into a variadic @@ -393,6 +391,11 @@ private: /// and returns true if the splitting succeeds. bool splitGEP(GetElementPtrInst *GEP); + /// Tries to reorder the given GEP with the GEP that produces the base if + /// doing so results in producing a constant offset as the outermost + /// index. + bool reorderGEP(GetElementPtrInst *GEP, TargetTransformInfo &TTI); + /// Lower a GEP with multiple indices into multiple GEPs with a single index. /// Function splitGEP already split the original GEP into a variadic part and /// a constant offset (i.e., AccumulativeByteOffset). This function lowers the @@ -519,12 +522,10 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended, } Value *LHS = BO->getOperand(0), *RHS = BO->getOperand(1); - // Do not trace into "or" unless it is equivalent to "add". If LHS and RHS - // don't have common bits, (LHS | RHS) is equivalent to (LHS + RHS). - // FIXME: this does not appear to be covered by any tests - // (with x86/aarch64 backends at least) + // Do not trace into "or" unless it is equivalent to "add". + // This is the case if the or's disjoint flag is set. if (BO->getOpcode() == Instruction::Or && - !haveNoCommonBitsSet(LHS, RHS, SimplifyQuery(DL, DT, /*AC*/ nullptr, BO))) + !cast<PossiblyDisjointInst>(BO)->isDisjoint()) return false; // FIXME: We don't currently support constants from the RHS of subs, @@ -669,7 +670,7 @@ Value *ConstantOffsetExtractor::applyExts(Value *V) { Instruction *Ext = I->clone(); Ext->setOperand(0, Current); - Ext->insertBefore(IP); + Ext->insertBefore(*IP->getParent(), IP); Current = Ext; } return Current; @@ -778,9 +779,8 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) { } Value *ConstantOffsetExtractor::Extract(Value *Idx, GetElementPtrInst *GEP, - User *&UserChainTail, - const DominatorTree *DT) { - ConstantOffsetExtractor Extractor(GEP, DT); + User *&UserChainTail) { + ConstantOffsetExtractor Extractor(GEP->getIterator()); // Find a non-zero constant offset first. APInt ConstantOffset = Extractor.find(Idx, /* SignExtended */ false, /* ZeroExtended */ false, @@ -795,10 +795,9 @@ Value *ConstantOffsetExtractor::Extract(Value *Idx, GetElementPtrInst *GEP, return IdxWithoutConstOffset; } -int64_t ConstantOffsetExtractor::Find(Value *Idx, GetElementPtrInst *GEP, - const DominatorTree *DT) { +int64_t ConstantOffsetExtractor::Find(Value *Idx, GetElementPtrInst *GEP) { // If Idx is an index of an inbound GEP, Idx is guaranteed to be non-negative. 
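
The rationale for tracing through a disjoint "or" is that when the two operands share no set bits there are no carries, so the or and the sum produce the same value and the extractor may treat the or like an add. A quick self-contained check of that equivalence:

#include <cassert>
#include <cstdint>

// When A and B have no common set bits, A | B == A + B.
bool disjointOrEqualsAdd(uint64_t A, uint64_t B) {
  if ((A & B) != 0)
    return false; // not disjoint; the equivalence is not guaranteed
  return (A | B) == A + B;
}

int main() {
  assert(disjointOrEqualsAdd(0xF0, 0x0F));  // 0xF0 | 0x0F == 0xF0 + 0x0F == 0xFF
  assert(!disjointOrEqualsAdd(0x18, 0x08)); // common bit set; treated as "don't know"
  return 0;
}
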
- return ConstantOffsetExtractor(GEP, DT) + return ConstantOffsetExtractor(GEP->getIterator()) .find(Idx, /* SignExtended */ false, /* ZeroExtended */ false, GEP->isInBounds()) .getSExtValue(); @@ -814,7 +813,8 @@ bool SeparateConstOffsetFromGEP::canonicalizeArrayIndicesToIndexSize( // Skip struct member indices which must be i32. if (GTI.isSequential()) { if ((*I)->getType() != PtrIdxTy) { - *I = CastInst::CreateIntegerCast(*I, PtrIdxTy, true, "idxprom", GEP); + *I = CastInst::CreateIntegerCast(*I, PtrIdxTy, true, "idxprom", + GEP->getIterator()); Changed = true; } } @@ -836,7 +836,7 @@ SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP, // Tries to extract a constant offset from this GEP index. int64_t ConstantOffset = - ConstantOffsetExtractor::Find(GEP->getOperand(I), GEP, DT); + ConstantOffsetExtractor::Find(GEP->getOperand(I), GEP); if (ConstantOffset != 0) { NeedsExtraction = true; // A GEP may have multiple indices. We accumulate the extracted @@ -970,6 +970,49 @@ SeparateConstOffsetFromGEP::lowerToArithmetics(GetElementPtrInst *Variadic, Variadic->eraseFromParent(); } +bool SeparateConstOffsetFromGEP::reorderGEP(GetElementPtrInst *GEP, + TargetTransformInfo &TTI) { + auto PtrGEP = dyn_cast<GetElementPtrInst>(GEP->getPointerOperand()); + if (!PtrGEP) + return false; + + bool NestedNeedsExtraction; + int64_t NestedByteOffset = + accumulateByteOffset(PtrGEP, NestedNeedsExtraction); + if (!NestedNeedsExtraction) + return false; + + unsigned AddrSpace = PtrGEP->getPointerAddressSpace(); + if (!TTI.isLegalAddressingMode(GEP->getResultElementType(), + /*BaseGV=*/nullptr, NestedByteOffset, + /*HasBaseReg=*/true, /*Scale=*/0, AddrSpace)) + return false; + + bool GEPInBounds = GEP->isInBounds(); + bool PtrGEPInBounds = PtrGEP->isInBounds(); + bool IsChainInBounds = GEPInBounds && PtrGEPInBounds; + if (IsChainInBounds) { + auto IsKnownNonNegative = [this](Value *V) { + return isKnownNonNegative(V, *DL); + }; + IsChainInBounds &= all_of(GEP->indices(), IsKnownNonNegative); + if (IsChainInBounds) + IsChainInBounds &= all_of(PtrGEP->indices(), IsKnownNonNegative); + } + + IRBuilder<> Builder(GEP); + // For trivial GEP chains, we can swap the indices. + Value *NewSrc = Builder.CreateGEP( + GEP->getSourceElementType(), PtrGEP->getPointerOperand(), + SmallVector<Value *, 4>(GEP->indices()), "", IsChainInBounds); + Value *NewGEP = Builder.CreateGEP(PtrGEP->getSourceElementType(), NewSrc, + SmallVector<Value *, 4>(PtrGEP->indices()), + "", IsChainInBounds); + GEP->replaceAllUsesWith(NewGEP); + RecursivelyDeleteTriviallyDeadInstructions(GEP); + return true; +} + bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { // Skip vector GEPs. if (GEP->getType()->isVectorTy()) @@ -985,11 +1028,13 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { bool NeedsExtraction; int64_t AccumulativeByteOffset = accumulateByteOffset(GEP, NeedsExtraction); - if (!NeedsExtraction) - return Changed; - TargetTransformInfo &TTI = GetTTI(*GEP->getFunction()); + if (!NeedsExtraction) { + Changed |= reorderGEP(GEP, TTI); + return Changed; + } + // If LowerGEP is disabled, before really splitting the GEP, check whether the // backend supports the addressing mode we are about to produce. If no, this // splitting probably won't be beneficial. 
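
A byte-level sketch of the intent behind reorderGEP, ignoring types and wrap flags: applying the variable offset first and the constant offset second computes the same address as the original nesting, but leaves the constant as the outermost step, where the backend can usually fold it into an addressing mode (hence the isLegalAddressingMode check).

#include <cassert>
#include <cstddef>

// Original nesting: the inner GEP contributes the constant offset.
char *original(char *Base, ptrdiff_t VarOff, ptrdiff_t ConstOff) {
  return (Base + ConstOff) + VarOff;
}

// Reordered nesting: the constant offset is applied last (outermost).
char *reordered(char *Base, ptrdiff_t VarOff, ptrdiff_t ConstOff) {
  return (Base + VarOff) + ConstOff;
}

int main() {
  char Buf[64];
  // Same final address either way; only the intermediate values differ.
  assert(original(Buf, 5, 16) == reordered(Buf, 5, 16));
  return 0;
}
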
@@ -1026,7 +1071,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { Value *OldIdx = GEP->getOperand(I); User *UserChainTail; Value *NewIdx = - ConstantOffsetExtractor::Extract(OldIdx, GEP, UserChainTail, DT); + ConstantOffsetExtractor::Extract(OldIdx, GEP, UserChainTail); if (NewIdx != nullptr) { // Switches to the index with the constant offset removed. GEP->setOperand(I, NewIdx); @@ -1057,8 +1102,9 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { // // TODO(jingyue): do some range analysis to keep as many inbounds as // possible. GEPs with inbounds are more friendly to alias analysis. + // TODO(gep_nowrap): Preserve nuw at least. bool GEPWasInBounds = GEP->isInBounds(); - GEP->setIsInBounds(false); + GEP->setNoWrapFlags(GEPNoWrapFlags::none()); // Lowers a GEP to either GEPs with a single index or arithmetic operations. if (LowerGEP) { @@ -1133,7 +1179,7 @@ bool SeparateConstOffsetFromGEP::run(Function &F) { if (DisableSeparateConstOffsetFromGEP) return false; - DL = &F.getParent()->getDataLayout(); + DL = &F.getDataLayout(); bool Changed = false; for (BasicBlock &B : F) { if (!DT->isReachableFromEntry(&B)) @@ -1188,9 +1234,11 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) { if (LHS->getType() == RHS->getType()) { ExprKey Key = createNormalizedCommutablePair(LHS, RHS); if (auto *Dom = findClosestMatchingDominator(Key, I, DominatingAdds)) { - Instruction *NewSExt = new SExtInst(Dom, I->getType(), "", I); + Instruction *NewSExt = + new SExtInst(Dom, I->getType(), "", I->getIterator()); NewSExt->takeName(I); I->replaceAllUsesWith(NewSExt); + NewSExt->setDebugLoc(I->getDebugLoc()); RecursivelyDeleteTriviallyDeadInstructions(I); return true; } @@ -1199,9 +1247,11 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) { if (LHS->getType() == RHS->getType()) { if (auto *Dom = findClosestMatchingDominator({LHS, RHS}, I, DominatingSubs)) { - Instruction *NewSExt = new SExtInst(Dom, I->getType(), "", I); + Instruction *NewSExt = + new SExtInst(Dom, I->getType(), "", I->getIterator()); NewSExt->takeName(I); I->replaceAllUsesWith(NewSExt); + NewSExt->setDebugLoc(I->getDebugLoc()); RecursivelyDeleteTriviallyDeadInstructions(I); return true; } @@ -1321,7 +1371,7 @@ void SeparateConstOffsetFromGEP::swapGEPOperand(GetElementPtrInst *First, Second->setOperand(1, Offset1); // We changed p+o+c to p+c+o, p+c may not be inbound anymore. - const DataLayout &DAL = First->getModule()->getDataLayout(); + const DataLayout &DAL = First->getDataLayout(); APInt Offset(DAL.getIndexSizeInBits( cast<PointerType>(First->getType())->getAddressSpace()), 0); @@ -1330,8 +1380,9 @@ void SeparateConstOffsetFromGEP::swapGEPOperand(GetElementPtrInst *First, uint64_t ObjectSize; if (!getObjectSize(NewBase, ObjectSize, DAL, TLI) || Offset.ugt(ObjectSize)) { - First->setIsInBounds(false); - Second->setIsInBounds(false); + // TODO(gep_nowrap): Make flag preservation more precise. 
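
Dropping the no-wrap/inbounds flags after splitting or swapping operands is needed because the reassociated form can produce an intermediate pointer outside the object even when the final address is fine; an inbounds GEP yielding that intermediate would be poison. A simplified sketch of the kind of check swapGEPOperand makes with getObjectSize before it dares to keep inbounds (names and the unsigned-offset model are illustrative):

#include <cstdint>

// Decide whether rewriting "p + VarOff + ConstOff" as "p + ConstOff + VarOff"
// can keep inbounds: the new intermediate (p + ConstOff) must still fall
// inside the underlying object.
bool intermediateStaysInBounds(uint64_t BaseOffsetInObject, int64_t ConstOff,
                               uint64_t ObjectSize) {
  int64_t NewIntermediate =
      static_cast<int64_t>(BaseOffsetInObject) + ConstOff;
  return NewIntermediate >= 0 &&
         static_cast<uint64_t>(NewIntermediate) <= ObjectSize;
}

// Example: an 8-byte object with the base pointer at offset 6.
//   intermediateStaysInBounds(6, +4, 8) == false -> flags must be dropped.
//   intermediateStaysInBounds(6, -4, 8) == true  -> inbounds may survive.
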
+ First->setNoWrapFlags(GEPNoWrapFlags::none()); + Second->setNoWrapFlags(GEPNoWrapFlags::none()); } else First->setIsInBounds(true); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index 7eb0ba1c2c17..f99f4487c554 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -41,6 +41,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Use.h" @@ -133,6 +134,7 @@ static cl::opt<unsigned> InjectInvariantConditionHotnesThreshold( "not-taken 1/<this option> times or less."), cl::init(16)); +AnalysisKey ShouldRunExtraSimpleLoopUnswitch::Key; namespace { struct CompareDesc { BranchInst *Term; @@ -630,7 +632,8 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, } else { // Create a new unconditional branch that will continue the loop as a new // terminator. - BranchInst::Create(ContinueBB, ParentBB); + Instruction *NewBI = BranchInst::Create(ContinueBB, ParentBB); + NewBI->setDebugLoc(BI.getDebugLoc()); } BI.setSuccessor(LoopExitSuccIdx, UnswitchedBB); BI.setSuccessor(1 - LoopExitSuccIdx, NewPH); @@ -664,10 +667,12 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, // Finish updating dominator tree and memory ssa for full unswitch. if (FullUnswitch) { if (MSSAU) { - // Remove the cloned branch instruction. - ParentBB->getTerminator()->eraseFromParent(); - // Create unconditional branch now. - BranchInst::Create(ContinueBB, ParentBB); + Instruction *Term = ParentBB->getTerminator(); + // Remove the cloned branch instruction and create unconditional branch + // now. + Instruction *NewBI = BranchInst::Create(ContinueBB, ParentBB); + NewBI->setDebugLoc(Term->getDebugLoc()); + Term->eraseFromParent(); MSSAU->removeEdge(ParentBB, LoopExitBB); } DT.deleteEdge(ParentBB, LoopExitBB); @@ -859,8 +864,11 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, BasicBlock *NewPH = SplitEdge(OldPH, L.getHeader(), &DT, &LI, MSSAU); OldPH->getTerminator()->eraseFromParent(); - // Now add the unswitched switch. + // Now add the unswitched switch. This new switch instruction inherits the + // debug location of the old switch, because it semantically replace the old + // one. auto *NewSI = SwitchInst::Create(LoopCond, NewPH, ExitCases.size(), OldPH); + NewSI->setDebugLoc(SIW->getDebugLoc()); SwitchInstProfUpdateWrapper NewSIW(*NewSI); // Rewrite the IR for the unswitched basic blocks. This requires two steps. @@ -970,8 +978,9 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, /*KeepOneInputPHIs*/ true); } // Now nuke the switch and replace it with a direct branch. + Instruction *NewBI = BranchInst::Create(CommonSuccBB, BB); + NewBI->setDebugLoc(SIW->getDebugLoc()); SIW.eraseFromParent(); - BranchInst::Create(CommonSuccBB, BB); } else if (DefaultExitBB) { assert(SI.getNumCases() > 0 && "If we had no cases we'd have a common successor!"); @@ -1240,12 +1249,16 @@ static BasicBlock *buildClonedLoopBlocks( assert(VMap.lookup(&I) == &ClonedI && "Mismatch in the value map!"); // Forget SCEVs based on exit phis in case SCEV looked through the phi. 
- if (SE && isa<PHINode>(I)) - SE->forgetValue(&I); + if (SE) + if (auto *PN = dyn_cast<PHINode>(&I)) + SE->forgetLcssaPhiWithNewPredecessor(&L, PN); + + BasicBlock::iterator InsertPt = MergeBB->getFirstInsertionPt(); auto *MergePN = PHINode::Create(I.getType(), /*NumReservedValues*/ 2, ".us-phi"); - MergePN->insertBefore(MergeBB->getFirstInsertionPt()); + MergePN->insertBefore(InsertPt); + MergePN->setDebugLoc(InsertPt->getDebugLoc()); I.replaceAllUsesWith(MergePN); MergePN->addIncoming(&I, ExitBB); MergePN->addIncoming(&ClonedI, ClonedExitBB); @@ -1260,8 +1273,8 @@ static BasicBlock *buildClonedLoopBlocks( Module *M = ClonedPH->getParent()->getParent(); for (auto *ClonedBB : NewBlocks) for (Instruction &I : *ClonedBB) { - RemapDPValueRange(M, I.getDbgValueRange(), VMap, - RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + RemapDbgRecordRange(M, I.getDbgRecordRange(), VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); RemapInstruction(&I, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); if (auto *II = dyn_cast<AssumeInst>(&I)) @@ -1304,8 +1317,9 @@ static BasicBlock *buildClonedLoopBlocks( else if (auto *SI = dyn_cast<SwitchInst>(ClonedTerminator)) ClonedConditionToErase = SI->getCondition(); + Instruction *BI = BranchInst::Create(ClonedSuccBB, ClonedParentBB); + BI->setDebugLoc(ClonedTerminator->getDebugLoc()); ClonedTerminator->eraseFromParent(); - BranchInst::Create(ClonedSuccBB, ClonedParentBB); if (ClonedConditionToErase) RecursivelyDeleteTriviallyDeadInstructions(ClonedConditionToErase, nullptr, @@ -2332,23 +2346,27 @@ static void unswitchNontrivialInvariants( // nuke the initial terminator placed in the split block. SplitBB->getTerminator()->eraseFromParent(); if (FullUnswitch) { - // Splice the terminator from the original loop and rewrite its - // successors. - TI.moveBefore(*SplitBB, SplitBB->end()); - // Keep a clone of the terminator for MSSA updates. Instruction *NewTI = TI.clone(); NewTI->insertInto(ParentBB, ParentBB->end()); + // Splice the terminator from the original loop and rewrite its + // successors. + TI.moveBefore(*SplitBB, SplitBB->end()); + TI.dropLocation(); + // First wire up the moved terminator to the preheaders. if (BI) { BasicBlock *ClonedPH = ClonedPHs.begin()->second; BI->setSuccessor(ClonedSucc, ClonedPH); BI->setSuccessor(1 - ClonedSucc, LoopPH); Value *Cond = skipTrivialSelect(BI->getCondition()); - if (InsertFreeze) - Cond = new FreezeInst( - Cond, Cond->getName() + ".fr", BI); + if (InsertFreeze) { + // We don't give any debug location to the new freeze, because the + // BI (`dyn_cast<BranchInst>(TI)`) is an in-loop instruction hoisted + // out of the loop. 
+ Cond = new FreezeInst(Cond, Cond->getName() + ".fr", BI->getIterator()); + } BI->setCondition(Cond); DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); } else { @@ -2365,8 +2383,9 @@ static void unswitchNontrivialInvariants( Case.setSuccessor(ClonedPHs.find(Case.getCaseSuccessor())->second); if (InsertFreeze) - SI->setCondition(new FreezeInst( - SI->getCondition(), SI->getCondition()->getName() + ".fr", SI)); + SI->setCondition(new FreezeInst(SI->getCondition(), + SI->getCondition()->getName() + ".fr", + SI->getIterator())); // We need to use the set to populate domtree updates as even when there // are multiple cases pointing at the same successor we only want to @@ -2430,12 +2449,13 @@ static void unswitchNontrivialInvariants( DTUpdates.push_back({DominatorTree::Delete, ParentBB, SuccBB}); } - // After MSSAU update, remove the cloned terminator instruction NewTI. - ParentBB->getTerminator()->eraseFromParent(); - // Create a new unconditional branch to the continuing block (as opposed to // the one cloned). - BranchInst::Create(RetainedSuccBB, ParentBB); + Instruction *NewBI = BranchInst::Create(RetainedSuccBB, ParentBB); + NewBI->setDebugLoc(NewTI->getDebugLoc()); + + // After MSSAU update, remove the cloned terminator instruction NewTI. + NewTI->eraseFromParent(); } else { assert(BI && "Only branches have partial unswitching."); assert(UnswitchedSuccBBs.size() == 1 && @@ -2704,9 +2724,11 @@ static BranchInst *turnSelectIntoBranch(SelectInst *SI, DominatorTree &DT, if (MSSAU) MSSAU->moveAllAfterSpliceBlocks(HeadBB, TailBB, SI); - PHINode *Phi = PHINode::Create(SI->getType(), 2, "unswitched.select", SI); + PHINode *Phi = + PHINode::Create(SI->getType(), 2, "unswitched.select", SI->getIterator()); Phi->addIncoming(SI->getTrueValue(), ThenBB); Phi->addIncoming(SI->getFalseValue(), HeadBB); + Phi->setDebugLoc(SI->getDebugLoc()); SI->replaceAllUsesWith(Phi); SI->eraseFromParent(); @@ -3092,7 +3114,7 @@ injectPendingInvariantConditions(NonTrivialUnswitchCandidate Candidate, Loop &L, // unswitching will break. Better optimize it away later. auto *InjectedCond = ICmpInst::Create(Instruction::ICmp, Pred, LHS, RHS, "injected.cond", - Preheader->getTerminator()); + Preheader->getTerminator()->getIterator()); BasicBlock *CheckBlock = BasicBlock::Create(Ctx, BB->getName() + ".check", BB->getParent(), InLoopSucc); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp index 7017f6adf3a2..11de37f7a7c1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -77,6 +77,9 @@ static cl::opt<bool> UserSinkCommonInsts( "sink-common-insts", cl::Hidden, cl::init(false), cl::desc("Sink common instructions (default = false)")); +static cl::opt<bool> UserSpeculateUnpredictables( + "speculate-unpredictables", cl::Hidden, cl::init(false), + cl::desc("Speculate unpredictable branches (default = false)")); STATISTIC(NumSimpl, "Number of blocks simplified"); @@ -142,8 +145,10 @@ performBlockTailMerging(Function &F, ArrayRef<BasicBlock *> BBs, // And turn BB into a block that just unconditionally branches // to the canonical block. 
+ Instruction *BI = BranchInst::Create(CanonicalBB, BB); + BI->setDebugLoc(Term->getDebugLoc()); Term->eraseFromParent(); - BranchInst::Create(CanonicalBB, BB); + if (Updates) Updates->push_back({DominatorTree::Insert, BB, CanonicalBB}); } @@ -323,6 +328,8 @@ static void applyCommandLineOverridesToOptions(SimplifyCFGOptions &Options) { Options.HoistCommonInsts = UserHoistCommonInsts; if (UserSinkCommonInsts.getNumOccurrences()) Options.SinkCommonInsts = UserSinkCommonInsts; + if (UserSpeculateUnpredictables.getNumOccurrences()) + Options.SpeculateUnpredictables = UserSpeculateUnpredictables; } SimplifyCFGPass::SimplifyCFGPass() { @@ -349,7 +356,9 @@ void SimplifyCFGPass::printPipeline( OS << (Options.HoistCommonInsts ? "" : "no-") << "hoist-common-insts;"; OS << (Options.SinkCommonInsts ? "" : "no-") << "sink-common-insts;"; OS << (Options.SpeculateBlocks ? "" : "no-") << "speculate-blocks;"; - OS << (Options.SimplifyCondBranch ? "" : "no-") << "simplify-cond-branch"; + OS << (Options.SimplifyCondBranch ? "" : "no-") << "simplify-cond-branch;"; + OS << (Options.SpeculateUnpredictables ? "" : "no-") + << "speculate-unpredictables"; OS << '>'; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp index 7a5318d4404c..ed9c1828ce06 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp @@ -260,36 +260,47 @@ static InstructionCost ComputeSpeculationCost(const Instruction *I, } } +// Do not hoist any debug info intrinsics. +// ... +// if (cond) { +// x = y * z; +// foo(); +// } +// ... +// -------- Which then becomes: +// ... +// if.then: +// %x = mul i32 %y, %z +// call void @llvm.dbg.value(%x, !"x", !DIExpression()) +// call void foo() +// +// SpeculativeExecution might decide to hoist the 'y * z' calculation +// out of the 'if' block, because it is more efficient that way, so the +// '%x = mul i32 %y, %z' moves to the block above. But it might also +// decide to hoist the 'llvm.dbg.value' call. +// This is incorrect, because even if we've moved the calculation of +// 'y * z', we should not see the value of 'x' change unless we +// actually go inside the 'if' block. + bool SpeculativeExecutionPass::considerHoistingFromTo( BasicBlock &FromBlock, BasicBlock &ToBlock) { SmallPtrSet<const Instruction *, 8> NotHoisted; - const auto AllPrecedingUsesFromBlockHoisted = [&NotHoisted](const User *U) { - // Debug variable has special operand to check it's not hoisted. - if (const auto *DVI = dyn_cast<DbgVariableIntrinsic>(U)) { - return all_of(DVI->location_ops(), [&NotHoisted](Value *V) { - if (const auto *I = dyn_cast_or_null<Instruction>(V)) { - if (!NotHoisted.contains(I)) - return true; - } - return false; - }); - } - - // Usially debug label intrinsic corresponds to label in LLVM IR. In these - // cases we should not move it here. - // TODO: Possible special processing needed to detect it is related to a - // hoisted instruction. 
- if (isa<DbgLabelInst>(U)) - return false; - - for (const Value *V : U->operand_values()) { - if (const Instruction *I = dyn_cast<Instruction>(V)) { + auto HasNoUnhoistedInstr = [&NotHoisted](auto Values) { + for (const Value *V : Values) { + if (const auto *I = dyn_cast_or_null<Instruction>(V)) if (NotHoisted.contains(I)) return false; - } } return true; }; + auto AllPrecedingUsesFromBlockHoisted = + [&HasNoUnhoistedInstr](const User *U) { + // Do not hoist any debug info intrinsics. + if (isa<DbgInfoIntrinsic>(U)) + return false; + + return HasNoUnhoistedInstr(U->operand_values()); + }; InstructionCost TotalSpeculationCost = 0; unsigned NotHoistedInstCount = 0; @@ -316,7 +327,8 @@ bool SpeculativeExecutionPass::considerHoistingFromTo( auto Current = I; ++I; if (!NotHoisted.count(&*Current)) { - Current->moveBeforePreserving(ToBlock.getTerminator()); + Current->moveBefore(ToBlock.getTerminator()); + Current->dropLocation(); } } return true; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp index 75910d7b698a..75585fcc8026 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -425,14 +425,12 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd( // Returns true if A matches B + C where C is constant. static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C) { - return (match(A, m_Add(m_Value(B), m_ConstantInt(C))) || - match(A, m_Add(m_ConstantInt(C), m_Value(B)))); + return match(A, m_c_Add(m_Value(B), m_ConstantInt(C))); } // Returns true if A matches B | C where C is constant. static bool matchesOr(Value *A, Value *&B, ConstantInt *&C) { - return (match(A, m_Or(m_Value(B), m_ConstantInt(C))) || - match(A, m_Or(m_ConstantInt(C), m_Value(B)))); + return match(A, m_c_Or(m_Value(B), m_ConstantInt(C))); } void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul( @@ -715,7 +713,7 @@ namespace llvm { PreservedAnalyses StraightLineStrengthReducePass::run(Function &F, FunctionAnalysisManager &AM) { - const DataLayout *DL = &F.getParent()->getDataLayout(); + const DataLayout *DL = &F.getDataLayout(); auto *DT = &AM.getResult<DominatorTreeAnalysis>(F); auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F); auto *TTI = &AM.getResult<TargetIRAnalysis>(F); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp index 7d96a3478858..9c711ec18382 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -772,7 +772,7 @@ void StructurizeCFG::simplifyAffectedPhis() { bool Changed; do { Changed = false; - SimplifyQuery Q(Func->getParent()->getDataLayout()); + SimplifyQuery Q(Func->getDataLayout()); Q.DT = DT; // Setting CanUseUndef to true might extend value liveness, set it to false // to achieve better register pressure. 
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp index c6e8505d5ab4..1b3e6d9549b8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -349,7 +349,7 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) { // does not write to memory and the load provably won't trap. // Writes to memory only matter if they may alias the pointer // being loaded from. - const DataLayout &DL = L->getModule()->getDataLayout(); + const DataLayout &DL = L->getDataLayout(); if (isModSet(AA->getModRefInfo(CI, MemoryLocation::get(L))) || !isSafeToLoadUnconditionally(L->getPointerOperand(), L->getType(), L->getAlign(), DL, L)) @@ -509,8 +509,10 @@ void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) { BasicBlock *NewEntry = BasicBlock::Create(F.getContext(), "", &F, HeaderBB); NewEntry->takeName(HeaderBB); HeaderBB->setName("tailrecurse"); - BranchInst *BI = BranchInst::Create(HeaderBB, NewEntry); - BI->setDebugLoc(CI->getDebugLoc()); + BranchInst::Create(HeaderBB, NewEntry); + // If the new branch preserves the debug location of CI, it could result in + // misleading stepping, if CI is located in a conditional branch. + // So, here we don't give any debug location to the new branch. // Move all fixed sized allocas from HeaderBB to NewEntry. for (BasicBlock::iterator OEBI = HeaderBB->begin(), E = HeaderBB->end(), @@ -592,7 +594,7 @@ void TailRecursionEliminator::copyByValueOperandIntoLocalTemp(CallInst *CI, int OpndIdx) { Type *AggTy = CI->getParamByValType(OpndIdx); assert(AggTy); - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); // Get alignment of byVal operand. Align Alignment(CI->getParamAlign(OpndIdx).valueOrOne()); @@ -601,7 +603,7 @@ void TailRecursionEliminator::copyByValueOperandIntoLocalTemp(CallInst *CI, // Put alloca into the entry block. Value *NewAlloca = new AllocaInst( AggTy, DL.getAllocaAddrSpace(), nullptr, Alignment, - CI->getArgOperand(OpndIdx)->getName(), &*F.getEntryBlock().begin()); + CI->getArgOperand(OpndIdx)->getName(), F.getEntryBlock().begin()); IRBuilder<> Builder(CI); Value *Size = Builder.getInt64(DL.getTypeAllocSize(AggTy)); @@ -619,7 +621,7 @@ void TailRecursionEliminator::copyLocalTempOfByValueOperandIntoArguments( CallInst *CI, int OpndIdx) { Type *AggTy = CI->getParamByValType(OpndIdx); assert(AggTy); - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = F.getDataLayout(); // Get alignment of byVal operand. Align Alignment(CI->getParamAlign(OpndIdx).valueOrOne()); @@ -714,8 +716,9 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) { // We found a return value we want to use, insert a select instruction to // select it if we don't already know what our return value will be and // store the result in our return value PHI node. 
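
For readers less familiar with the pass, tail recursion elimination conceptually rewrites a self tail call into a jump back to a loop header, with the call arguments becoming loop-carried values; the PHI nodes and the new unconditional branch created by createTailRecurseLoopHeader express exactly that. A source-level sketch:

// Before: the recursive call is in tail position.
static int factAcc(int N, int Acc) {
  if (N <= 1)
    return Acc;
  return factAcc(N - 1, Acc * N); // tail call
}

// After TRE (conceptually): the call becomes a back edge to "tailrecurse" and
// the arguments become loop-carried updates.
static int factAccTre(int N, int Acc) {
tailrecurse:
  if (N <= 1)
    return Acc;
  Acc = Acc * N;
  N = N - 1;
  goto tailrecurse;
}
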
- SelectInst *SI = SelectInst::Create( - RetKnownPN, RetPN, Ret->getReturnValue(), "current.ret.tr", Ret); + SelectInst *SI = + SelectInst::Create(RetKnownPN, RetPN, Ret->getReturnValue(), + "current.ret.tr", Ret->getIterator()); RetSelects.push_back(SI); RetPN->addIncoming(SI, BB); @@ -728,7 +731,7 @@ bool TailRecursionEliminator::eliminateCall(CallInst *CI) { // Now that all of the PHI nodes are in place, remove the call and // ret instructions, replacing them with an unconditional branch. - BranchInst *NewBI = BranchInst::Create(HeaderBB, Ret); + BranchInst *NewBI = BranchInst::Create(HeaderBB, Ret->getIterator()); NewBI->setDebugLoc(CI->getDebugLoc()); Ret->eraseFromParent(); // Remove return. @@ -746,7 +749,7 @@ void TailRecursionEliminator::cleanupAndFinalize() { // call. for (PHINode *PN : ArgumentPHIs) { // If the PHI Node is a dynamic constant, replace it with the value it is. - if (Value *PNV = simplifyInstruction(PN, F.getParent()->getDataLayout())) { + if (Value *PNV = simplifyInstruction(PN, F.getDataLayout())) { PN->replaceAllUsesWith(PNV); PN->eraseFromParent(); } @@ -776,6 +779,7 @@ void TailRecursionEliminator::cleanupAndFinalize() { AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN, RI->getOperand(0)); AccRecInstrNew->insertBefore(RI); + AccRecInstrNew->dropLocation(); RI->setOperand(0, AccRecInstrNew); } } @@ -787,8 +791,9 @@ void TailRecursionEliminator::cleanupAndFinalize() { if (!RI) continue; - SelectInst *SI = SelectInst::Create( - RetKnownPN, RetPN, RI->getOperand(0), "current.ret.tr", RI); + SelectInst *SI = + SelectInst::Create(RetKnownPN, RetPN, RI->getOperand(0), + "current.ret.tr", RI->getIterator()); RetSelects.push_back(SI); RI->setOperand(0, SI); } @@ -803,6 +808,7 @@ void TailRecursionEliminator::cleanupAndFinalize() { AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN, SI->getFalseValue()); AccRecInstrNew->insertBefore(SI); + AccRecInstrNew->dropLocation(); SI->setFalseValue(AccRecInstrNew); } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp index 6ca737df49b9..a25632acbfcc 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/SparseBitVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Module.h" #include "llvm/Support/DataExtractor.h" #include "llvm/Support/MD5.h" #include "llvm/Support/MathExtras.h" @@ -153,12 +154,11 @@ static Value *getStrlenWithNull(IRBuilder<> &Builder, Value *Str) { static Value *callAppendStringN(IRBuilder<> &Builder, Value *Desc, Value *Str, Value *Length, bool isLast) { auto Int64Ty = Builder.getInt64Ty(); - auto PtrTy = Builder.getPtrTy(); - auto Int32Ty = Builder.getInt32Ty(); + auto IsLastInt32 = Builder.getInt32(isLast); auto M = Builder.GetInsertBlock()->getModule(); auto Fn = M->getOrInsertFunction("__ockl_printf_append_string_n", Int64Ty, - Int64Ty, PtrTy, Int64Ty, Int32Ty); - auto IsLastInt32 = Builder.getInt32(isLast); + Desc->getType(), Str->getType(), + Length->getType(), IsLastInt32->getType()); return Builder.CreateCall(Fn, {Desc, Str, Length, IsLastInt32}); } @@ -351,7 +351,7 @@ static void processConstantStringArg(StringData *SD, IRBuilder<> &Builder, } static Value *processNonStringArg(Value *Arg, IRBuilder<> &Builder) { - const DataLayout &DL = 
Builder.GetInsertBlock()->getModule()->getDataLayout(); + const DataLayout &DL = Builder.GetInsertBlock()->getDataLayout(); auto Ty = Arg->getType(); if (auto IntTy = dyn_cast<IntegerType>(Ty)) { @@ -408,9 +408,7 @@ callBufferedPrintfArgPush(IRBuilder<> &Builder, ArrayRef<Value *> Args, WhatToStore.push_back(processNonStringArg(Args[i], Builder)); } - for (unsigned I = 0, E = WhatToStore.size(); I != E; ++I) { - Value *toStore = WhatToStore[I]; - + for (Value *toStore : WhatToStore) { StoreInst *StBuff = Builder.CreateStore(toStore, PtrToStore); LLVM_DEBUG(dbgs() << "inserting store to printf buffer:" << *StBuff << '\n'); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp index efa8e874b955..3cf68e07da5b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp @@ -320,7 +320,7 @@ RetainedKnowledge llvm::simplifyRetainedKnowledge(AssumeInst *Assume, AssumptionCache *AC, DominatorTree *DT) { AssumeBuilderState Builder(Assume->getModule(), Assume, AC, DT); - RK = canonicalizedKnowledge(RK, Assume->getModule()->getDataLayout()); + RK = canonicalizedKnowledge(RK, Assume->getDataLayout()); if (!Builder.isKnowledgeWorthPreserving(RK)) return RetainedKnowledge::none(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp index ec0482ac2cde..79911bf563ea 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -382,14 +382,23 @@ bool llvm::MergeBlockSuccessorsIntoGivenBlocks( /// - Check fully overlapping fragments and not only identical fragments. /// - Support dbg.declare. dbg.label, and possibly other meta instructions being /// part of the sequence of consecutive instructions. -static bool DPValuesRemoveRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) { - SmallVector<DPValue *, 8> ToBeRemoved; +static bool +DbgVariableRecordsRemoveRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) { + SmallVector<DbgVariableRecord *, 8> ToBeRemoved; SmallDenseSet<DebugVariable> VariableSet; for (auto &I : reverse(*BB)) { - for (DPValue &DPV : reverse(I.getDbgValueRange())) { + for (DbgRecord &DR : reverse(I.getDbgRecordRange())) { + if (isa<DbgLabelRecord>(DR)) { + // Emulate existing behaviour (see comment below for dbg.declares). + // FIXME: Don't do this. + VariableSet.clear(); + continue; + } + + DbgVariableRecord &DVR = cast<DbgVariableRecord>(DR); // Skip declare-type records, as the debug intrinsic method only works // on dbg.value intrinsics. - if (DPV.getType() == DPValue::LocationType::Declare) { + if (DVR.getType() == DbgVariableRecord::LocationType::Declare) { // The debug intrinsic method treats dbg.declares are "non-debug" // instructions (i.e., a break in a consecutive range of debug // intrinsics). Emulate that to create identical outputs. See @@ -399,8 +408,8 @@ static bool DPValuesRemoveRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) { continue; } - DebugVariable Key(DPV.getVariable(), DPV.getExpression(), - DPV.getDebugLoc()->getInlinedAt()); + DebugVariable Key(DVR.getVariable(), DVR.getExpression(), + DVR.getDebugLoc()->getInlinedAt()); auto R = VariableSet.insert(Key); // If the same variable fragment is described more than once it is enough // to keep the last one (i.e. 
the first found since we for reverse @@ -408,14 +417,14 @@ static bool DPValuesRemoveRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) { if (R.second) continue; - if (DPV.isDbgAssign()) { + if (DVR.isDbgAssign()) { // Don't delete dbg.assign intrinsics that are linked to instructions. - if (!at::getAssignmentInsts(&DPV).empty()) + if (!at::getAssignmentInsts(&DVR).empty()) continue; // Unlinked dbg.assign intrinsics can be treated like dbg.values. } - ToBeRemoved.push_back(&DPV); + ToBeRemoved.push_back(&DVR); continue; } // Sequence with consecutive dbg.value instrs ended. Clear the map to @@ -424,15 +433,15 @@ static bool DPValuesRemoveRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) { VariableSet.clear(); } - for (auto &DPV : ToBeRemoved) - DPV->eraseFromParent(); + for (auto &DVR : ToBeRemoved) + DVR->eraseFromParent(); return !ToBeRemoved.empty(); } static bool removeRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) { if (BB->IsNewDbgInfoFormat) - return DPValuesRemoveRedundantDbgInstrsUsingBackwardScan(BB); + return DbgVariableRecordsRemoveRedundantDbgInstrsUsingBackwardScan(BB); SmallVector<DbgValueInst *, 8> ToBeRemoved; SmallDenseSet<DebugVariable> VariableSet; @@ -491,29 +500,30 @@ static bool removeRedundantDbgInstrsUsingBackwardScan(BasicBlock *BB) { /// /// Possible improvements: /// - Keep track of non-overlapping fragments. -static bool DPValuesRemoveRedundantDbgInstrsUsingForwardScan(BasicBlock *BB) { - SmallVector<DPValue *, 8> ToBeRemoved; +static bool +DbgVariableRecordsRemoveRedundantDbgInstrsUsingForwardScan(BasicBlock *BB) { + SmallVector<DbgVariableRecord *, 8> ToBeRemoved; DenseMap<DebugVariable, std::pair<SmallVector<Value *, 4>, DIExpression *>> VariableMap; for (auto &I : *BB) { - for (DPValue &DPV : I.getDbgValueRange()) { - if (DPV.getType() == DPValue::LocationType::Declare) + for (DbgVariableRecord &DVR : filterDbgVars(I.getDbgRecordRange())) { + if (DVR.getType() == DbgVariableRecord::LocationType::Declare) continue; - DebugVariable Key(DPV.getVariable(), std::nullopt, - DPV.getDebugLoc()->getInlinedAt()); + DebugVariable Key(DVR.getVariable(), std::nullopt, + DVR.getDebugLoc()->getInlinedAt()); auto VMI = VariableMap.find(Key); // A dbg.assign with no linked instructions can be treated like a // dbg.value (i.e. can be deleted). bool IsDbgValueKind = - (!DPV.isDbgAssign() || at::getAssignmentInsts(&DPV).empty()); + (!DVR.isDbgAssign() || at::getAssignmentInsts(&DVR).empty()); // Update the map if we found a new value/expression describing the // variable, or if the variable wasn't mapped already. - SmallVector<Value *, 4> Values(DPV.location_ops()); + SmallVector<Value *, 4> Values(DVR.location_ops()); if (VMI == VariableMap.end() || VMI->second.first != Values || - VMI->second.second != DPV.getExpression()) { + VMI->second.second != DVR.getExpression()) { if (IsDbgValueKind) - VariableMap[Key] = {Values, DPV.getExpression()}; + VariableMap[Key] = {Values, DVR.getExpression()}; else VariableMap[Key] = {Values, nullptr}; continue; @@ -522,55 +532,56 @@ static bool DPValuesRemoveRedundantDbgInstrsUsingForwardScan(BasicBlock *BB) { if (!IsDbgValueKind) continue; // Found an identical mapping. Remember the instruction for later removal. 
- ToBeRemoved.push_back(&DPV); + ToBeRemoved.push_back(&DVR); } } - for (auto *DPV : ToBeRemoved) - DPV->eraseFromParent(); + for (auto *DVR : ToBeRemoved) + DVR->eraseFromParent(); return !ToBeRemoved.empty(); } -static bool DPValuesRemoveUndefDbgAssignsFromEntryBlock(BasicBlock *BB) { +static bool +DbgVariableRecordsRemoveUndefDbgAssignsFromEntryBlock(BasicBlock *BB) { assert(BB->isEntryBlock() && "expected entry block"); - SmallVector<DPValue *, 8> ToBeRemoved; + SmallVector<DbgVariableRecord *, 8> ToBeRemoved; DenseSet<DebugVariable> SeenDefForAggregate; // Returns the DebugVariable for DVI with no fragment info. - auto GetAggregateVariable = [](const DPValue &DPV) { - return DebugVariable(DPV.getVariable(), std::nullopt, - DPV.getDebugLoc().getInlinedAt()); + auto GetAggregateVariable = [](const DbgVariableRecord &DVR) { + return DebugVariable(DVR.getVariable(), std::nullopt, + DVR.getDebugLoc().getInlinedAt()); }; // Remove undef dbg.assign intrinsics that are encountered before // any non-undef intrinsics from the entry block. for (auto &I : *BB) { - for (DPValue &DPV : I.getDbgValueRange()) { - if (!DPV.isDbgValue() && !DPV.isDbgAssign()) + for (DbgVariableRecord &DVR : filterDbgVars(I.getDbgRecordRange())) { + if (!DVR.isDbgValue() && !DVR.isDbgAssign()) continue; bool IsDbgValueKind = - (DPV.isDbgValue() || at::getAssignmentInsts(&DPV).empty()); - DebugVariable Aggregate = GetAggregateVariable(DPV); + (DVR.isDbgValue() || at::getAssignmentInsts(&DVR).empty()); + DebugVariable Aggregate = GetAggregateVariable(DVR); if (!SeenDefForAggregate.contains(Aggregate)) { - bool IsKill = DPV.isKillLocation() && IsDbgValueKind; + bool IsKill = DVR.isKillLocation() && IsDbgValueKind; if (!IsKill) { SeenDefForAggregate.insert(Aggregate); - } else if (DPV.isDbgAssign()) { - ToBeRemoved.push_back(&DPV); + } else if (DVR.isDbgAssign()) { + ToBeRemoved.push_back(&DVR); } } } } - for (DPValue *DPV : ToBeRemoved) - DPV->eraseFromParent(); + for (DbgVariableRecord *DVR : ToBeRemoved) + DVR->eraseFromParent(); return !ToBeRemoved.empty(); } static bool removeRedundantDbgInstrsUsingForwardScan(BasicBlock *BB) { if (BB->IsNewDbgInfoFormat) - return DPValuesRemoveRedundantDbgInstrsUsingForwardScan(BB); + return DbgVariableRecordsRemoveRedundantDbgInstrsUsingForwardScan(BB); SmallVector<DbgValueInst *, 8> ToBeRemoved; DenseMap<DebugVariable, std::pair<SmallVector<Value *, 4>, DIExpression *>> @@ -634,7 +645,7 @@ static bool removeRedundantDbgInstrsUsingForwardScan(BasicBlock *BB) { /// - Keep track of non-overlapping fragments. static bool removeUndefDbgAssignsFromEntryBlock(BasicBlock *BB) { if (BB->IsNewDbgInfoFormat) - return DPValuesRemoveUndefDbgAssignsFromEntryBlock(BB); + return DbgVariableRecordsRemoveUndefDbgAssignsFromEntryBlock(BB); assert(BB->isEntryBlock() && "expected entry block"); SmallVector<DbgAssignIntrinsic *, 8> ToBeRemoved; @@ -773,7 +784,7 @@ BasicBlock *llvm::SplitEdge(BasicBlock *BB, BasicBlock *Succ, DominatorTree *DT, // If the successor only has a single pred, split the top of the successor // block. assert(SP == BB && "CFG broken"); - SP = nullptr; + (void)SP; return SplitBlock(Succ, &Succ->front(), DT, LI, MSSAU, BBName, /*Before=*/true); } @@ -1130,6 +1141,7 @@ BasicBlock *llvm::splitBlockBefore(BasicBlock *Old, BasicBlock::iterator SplitPt } /// Update DominatorTree, LoopInfo, and LCCSA analysis information. +/// Invalidates DFS Numbering when DTU or DT is provided. 
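
At its core, the forward scan above is a "last mapping wins" deduplication: walking the block top-down, a debug record is redundant if its variable is already mapped to the same location and expression. A generic sketch of that bookkeeping on a simplified record type (hypothetical names, not the DbgVariableRecord API):

#include <cstdint>
#include <map>
#include <vector>

struct ToyDbgRecord {
  uint64_t VariableID; // stands in for (DILocalVariable, inlined-at)
  uint64_t LocationID; // stands in for the location operands + DIExpression
  bool Redundant = false;
};

// Mark records that restate the mapping already in effect for their variable.
void markRedundantForwardScan(std::vector<ToyDbgRecord> &Block) {
  std::map<uint64_t, uint64_t> LastMapping; // variable -> current location
  for (ToyDbgRecord &R : Block) {
    auto It = LastMapping.find(R.VariableID);
    if (It != LastMapping.end() && It->second == R.LocationID) {
      R.Redundant = true; // identical to what is already known; drop it
      continue;
    }
    LastMapping[R.VariableID] = R.LocationID;
  }
}
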
static void UpdateAnalysisInformation(BasicBlock *OldBB, BasicBlock *NewBB, ArrayRef<BasicBlock *> Preds, DomTreeUpdater *DTU, DominatorTree *DT, @@ -1289,7 +1301,7 @@ static void UpdatePHINodes(BasicBlock *OrigBB, BasicBlock *NewBB, // PHI. // Create the new PHI node, insert it into NewBB at the end of the block PHINode *NewPHI = - PHINode::Create(PN->getType(), Preds.size(), PN->getName() + ".ph", BI); + PHINode::Create(PN->getType(), Preds.size(), PN->getName() + ".ph", BI->getIterator()); // NOTE! This loop walks backwards for a reason! First off, this minimizes // the cost of removal if we end up removing a large number of values, and @@ -1390,13 +1402,13 @@ SplitBlockPredecessorsImpl(BasicBlock *BB, ArrayRef<BasicBlock *> Preds, if (OldLatch) { BasicBlock *NewLatch = L->getLoopLatch(); if (NewLatch != OldLatch) { - MDNode *MD = OldLatch->getTerminator()->getMetadata("llvm.loop"); - NewLatch->getTerminator()->setMetadata("llvm.loop", MD); + MDNode *MD = OldLatch->getTerminator()->getMetadata(LLVMContext::MD_loop); + NewLatch->getTerminator()->setMetadata(LLVMContext::MD_loop, MD); // It's still possible that OldLatch is the latch of another inner loop, // in which case we do not remove the metadata. Loop *IL = LI->getLoopFor(OldLatch); if (IL && IL->getLoopLatch() != OldLatch) - OldLatch->getTerminator()->setMetadata("llvm.loop", nullptr); + OldLatch->getTerminator()->setMetadata(LLVMContext::MD_loop, nullptr); } } @@ -1509,7 +1521,7 @@ static void SplitLandingPadPredecessorsImpl( assert(!LPad->getType()->isTokenTy() && "Split cannot be applied if LPad is token type. Otherwise an " "invalid PHINode of token type would be created."); - PHINode *PN = PHINode::Create(LPad->getType(), 2, "lpad.phi", LPad); + PHINode *PN = PHINode::Create(LPad->getType(), 2, "lpad.phi", LPad->getIterator()); PN->addIncoming(Clone1, NewBB1); PN->addIncoming(Clone2, NewBB2); LPad->replaceAllUsesWith(PN); @@ -1722,7 +1734,7 @@ llvm::SplitBlockAndInsertSimpleForLoop(Value *End, Instruction *SplitBefore) { BasicBlock *LoopExit = SplitBlock(SplitBefore->getParent(), SplitBefore); auto *Ty = End->getType(); - auto &DL = SplitBefore->getModule()->getDataLayout(); + auto &DL = SplitBefore->getDataLayout(); const unsigned Bitwidth = DL.getTypeSizeInBits(Ty); IRBuilder<> Builder(LoopBody->getTerminator()); @@ -1896,7 +1908,7 @@ static void reconnectPhis(BasicBlock *Out, BasicBlock *GuardBlock, auto Phi = cast<PHINode>(I); auto NewPhi = PHINode::Create(Phi->getType(), Incoming.size(), - Phi->getName() + ".moved", &FirstGuardBlock->front()); + Phi->getName() + ".moved", FirstGuardBlock->begin()); for (auto *In : Incoming) { Value *V = UndefValue::get(Phi->getType()); if (In == Out) { @@ -2015,7 +2027,7 @@ static void calcPredicateUsingInteger( Value *Id1 = ConstantInt::get(Type::getInt32Ty(Context), std::distance(Outgoing.begin(), Succ1Iter)); IncomingId = SelectInst::Create(Condition, Id0, Id1, "target.bb.idx", - In->getTerminator()); + In->getTerminator()->getIterator()); } else { // Get the index of the non-null successor. auto SuccIter = Succ0 ? 
find(Outgoing, Succ0) : find(Outgoing, Succ1); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp index 5fb796cc3db6..4606514cbc71 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -344,12 +344,9 @@ bool llvm::SplitIndirectBrCriticalEdges(Function &F, // this lowers the common case's overhead to O(Blocks) instead of O(Edges). SmallSetVector<BasicBlock *, 16> Targets; for (auto &BB : F) { - auto *IBI = dyn_cast<IndirectBrInst>(BB.getTerminator()); - if (!IBI) - continue; - - for (unsigned Succ = 0, E = IBI->getNumSuccessors(); Succ != E; ++Succ) - Targets.insert(IBI->getSuccessor(Succ)); + if (isa<IndirectBrInst>(BB.getTerminator())) + for (BasicBlock *Succ : successors(&BB)) + Targets.insert(Succ); } if (Targets.empty()) @@ -423,7 +420,7 @@ bool llvm::SplitIndirectBrCriticalEdges(Function &F, // (b) Leave that as the only edge in the "Indirect" PHI. // (c) Merge the two in the body block. BasicBlock::iterator Indirect = Target->begin(), - End = Target->getFirstNonPHI()->getIterator(); + End = Target->getFirstNonPHIIt(); BasicBlock::iterator Direct = DirectSucc->begin(); BasicBlock::iterator MergeInsert = BodyBlock->getFirstInsertionPt(); @@ -433,6 +430,7 @@ bool llvm::SplitIndirectBrCriticalEdges(Function &F, while (Indirect != End) { PHINode *DirPHI = cast<PHINode>(Direct); PHINode *IndPHI = cast<PHINode>(Indirect); + BasicBlock::iterator InsertPt = Indirect; // Now, clean up - the direct block shouldn't get the indirect value, // and vice versa. @@ -443,7 +441,7 @@ bool llvm::SplitIndirectBrCriticalEdges(Function &F, // PHI is erased. Indirect++; - PHINode *NewIndPHI = PHINode::Create(IndPHI->getType(), 1, "ind", IndPHI); + PHINode *NewIndPHI = PHINode::Create(IndPHI->getType(), 1, "ind", InsertPt); NewIndPHI->addIncoming(IndPHI->getIncomingValueForBlock(IBRPred), IBRPred); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/BuildLibCalls.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/BuildLibCalls.cpp index 12741dc5af5a..0c45bd886af9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/BuildLibCalls.cpp @@ -1098,6 +1098,11 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, case LibFunc_ldexpl: Changed |= setWillReturn(F); break; + case LibFunc_remquo: + case LibFunc_remquof: + case LibFunc_remquol: + Changed |= setDoesNotCapture(F, 2); + [[fallthrough]]; case LibFunc_abs: case LibFunc_acos: case LibFunc_acosf: @@ -1137,6 +1142,9 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, case LibFunc_cosl: case LibFunc_cospi: case LibFunc_cospif: + case LibFunc_erf: + case LibFunc_erff: + case LibFunc_erfl: case LibFunc_exp: case LibFunc_expf: case LibFunc_expl: @@ -1192,6 +1200,9 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, case LibFunc_pow: case LibFunc_powf: case LibFunc_powl: + case LibFunc_remainder: + case LibFunc_remainderf: + case LibFunc_remainderl: case LibFunc_rint: case LibFunc_rintf: case LibFunc_rintl: @@ -1252,7 +1263,7 @@ static void setRetExtAttr(Function &F, } // Modeled after X86TargetLowering::markLibCallAttributes. 
-static void markRegisterParameterAttributes(Function *F) { +void llvm::markRegisterParameterAttributes(Function *F) { if (!F->arg_size() || F->isVarArg()) return; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp index 73a50b793e6d..41031ae69c40 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp @@ -233,7 +233,7 @@ ValueRange FastDivInsertionTask::getValueRange(Value *V, assert(LongLen > ShortLen && "Value type must be wider than BypassType"); unsigned HiBits = LongLen - ShortLen; - const DataLayout &DL = SlowDivOrRem->getModule()->getDataLayout(); + const DataLayout &DL = SlowDivOrRem->getDataLayout(); KnownBits Known(LongLen); computeKnownBits(V, Known, DL); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp index d0b9884aa909..3b6fce578ffc 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CallGraphUpdater.cpp @@ -13,9 +13,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/CallGraphUpdater.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Analysis/CallGraph.h" -#include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/IR/Constants.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -28,53 +25,31 @@ bool CallGraphUpdater::finalize() { DeadFunctionsInComdats.end()); } - if (CG) { - // First remove all references, e.g., outgoing via called functions. This is - // necessary as we can delete functions that have circular references. - for (Function *DeadFn : DeadFunctions) { - DeadFn->removeDeadConstantUsers(); - CallGraphNode *DeadCGN = (*CG)[DeadFn]; - DeadCGN->removeAllCalledFunctions(); - CG->getExternalCallingNode()->removeAnyCallEdgeTo(DeadCGN); - DeadFn->replaceAllUsesWith(PoisonValue::get(DeadFn->getType())); - } - - // Then remove the node and function from the module. - for (Function *DeadFn : DeadFunctions) { - CallGraphNode *DeadCGN = CG->getOrInsertFunction(DeadFn); - assert(DeadCGN->getNumReferences() == 0 && - "References should have been handled by now"); - delete CG->removeFunctionFromModule(DeadCGN); - } - } else { - // This is the code path for the new lazy call graph and for the case were - // no call graph was provided. - for (Function *DeadFn : DeadFunctions) { - DeadFn->removeDeadConstantUsers(); - DeadFn->replaceAllUsesWith(PoisonValue::get(DeadFn->getType())); - - if (LCG && !ReplacedFunctions.count(DeadFn)) { - // Taken mostly from the inliner: - LazyCallGraph::Node &N = LCG->get(*DeadFn); - auto *DeadSCC = LCG->lookupSCC(N); - assert(DeadSCC && DeadSCC->size() == 1 && - &DeadSCC->begin()->getFunction() == DeadFn); - auto &DeadRC = DeadSCC->getOuterRefSCC(); - - FunctionAnalysisManager &FAM = - AM->getResult<FunctionAnalysisManagerCGSCCProxy>(*DeadSCC, *LCG) - .getManager(); - - FAM.clear(*DeadFn, DeadFn->getName()); - AM->clear(*DeadSCC, DeadSCC->getName()); - LCG->removeDeadFunction(*DeadFn); - - // Mark the relevant parts of the call graph as invalid so we don't - // visit them. - UR->InvalidatedSCCs.insert(DeadSCC); - UR->InvalidatedRefSCCs.insert(&DeadRC); - } - + // This is the code path for the new lazy call graph and for the case were + // no call graph was provided. 
+ for (Function *DeadFn : DeadFunctions) { + DeadFn->removeDeadConstantUsers(); + DeadFn->replaceAllUsesWith(PoisonValue::get(DeadFn->getType())); + + if (LCG && !ReplacedFunctions.count(DeadFn)) { + // Taken mostly from the inliner: + LazyCallGraph::Node &N = LCG->get(*DeadFn); + auto *DeadSCC = LCG->lookupSCC(N); + assert(DeadSCC && DeadSCC->size() == 1 && + &DeadSCC->begin()->getFunction() == DeadFn); + + FAM->clear(*DeadFn, DeadFn->getName()); + AM->clear(*DeadSCC, DeadSCC->getName()); + LCG->markDeadFunction(*DeadFn); + + // Mark the relevant parts of the call graph as invalid so we don't + // visit them. + UR->InvalidatedSCCs.insert(LCG->lookupSCC(N)); + UR->DeadFunctions.push_back(DeadFn); + } else { + // The CGSCC infrastructure batch deletes functions at the end of the + // call graph walk, so only erase the function if we're not using that + // infrastructure. // The function is now really dead and de-attached from everything. DeadFn->eraseFromParent(); } @@ -87,11 +62,7 @@ bool CallGraphUpdater::finalize() { } void CallGraphUpdater::reanalyzeFunction(Function &Fn) { - if (CG) { - CallGraphNode *OldCGN = CG->getOrInsertFunction(&Fn); - OldCGN->removeAllCalledFunctions(); - CG->populateCallGraphNode(OldCGN); - } else if (LCG) { + if (LCG) { LazyCallGraph::Node &N = LCG->get(Fn); LazyCallGraph::SCC *C = LCG->lookupSCC(N); updateCGAndAnalysisManagerForCGSCCPass(*LCG, *C, N, *AM, *UR, *FAM); @@ -100,9 +71,7 @@ void CallGraphUpdater::reanalyzeFunction(Function &Fn) { void CallGraphUpdater::registerOutlinedFunction(Function &OriginalFn, Function &NewFn) { - if (CG) - CG->addToCallGraph(&NewFn); - else if (LCG) + if (LCG) LCG->addSplitFunction(OriginalFn, NewFn); } @@ -114,12 +83,6 @@ void CallGraphUpdater::removeFunction(Function &DeadFn) { else DeadFunctions.push_back(&DeadFn); - // For the old call graph we remove the function from the SCC right away. - if (CG && !ReplacedFunctions.count(&DeadFn)) { - CallGraphNode *DeadCGN = (*CG)[&DeadFn]; - DeadCGN->removeAllCalledFunctions(); - CGSCC->DeleteNode(DeadCGN); - } if (FAM) FAM->clear(DeadFn, DeadFn.getName()); } @@ -127,46 +90,10 @@ void CallGraphUpdater::removeFunction(Function &DeadFn) { void CallGraphUpdater::replaceFunctionWith(Function &OldFn, Function &NewFn) { OldFn.removeDeadConstantUsers(); ReplacedFunctions.insert(&OldFn); - if (CG) { - // Update the call graph for the newly promoted function. - CallGraphNode *OldCGN = (*CG)[&OldFn]; - CallGraphNode *NewCGN = CG->getOrInsertFunction(&NewFn); - NewCGN->stealCalledFunctionsFrom(OldCGN); - CG->ReplaceExternalCallEdge(OldCGN, NewCGN); - - // And update the SCC we're iterating as well. - CGSCC->ReplaceNode(OldCGN, NewCGN); - } else if (LCG) { + if (LCG) { // Directly substitute the functions in the call graph. LazyCallGraph::Node &OldLCGN = LCG->get(OldFn); SCC->getOuterRefSCC().replaceNodeFunction(OldLCGN, NewFn); } removeFunction(OldFn); } - -bool CallGraphUpdater::replaceCallSite(CallBase &OldCS, CallBase &NewCS) { - // This is only necessary in the (old) CG. - if (!CG) - return true; - - Function *Caller = OldCS.getCaller(); - CallGraphNode *NewCalleeNode = - CG->getOrInsertFunction(NewCS.getCalledFunction()); - CallGraphNode *CallerNode = (*CG)[Caller]; - if (llvm::none_of(*CallerNode, [&OldCS](const CallGraphNode::CallRecord &CR) { - return CR.first && *CR.first == &OldCS; - })) - return false; - CallerNode->replaceCallEdge(OldCS, NewCS, NewCalleeNode); - return true; -} - -void CallGraphUpdater::removeCallSite(CallBase &CS) { - // This is only necessary in the (old) CG. 
- if (!CG) - return; - - Function *Caller = CS.getCaller(); - CallGraphNode *CallerNode = (*CG)[Caller]; - CallerNode->removeCallEdgeFor(CS); -} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp index e42cdab64446..90dc727cde16 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -12,11 +12,14 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/CallPromotionUtils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/AttributeMask.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; @@ -168,12 +171,12 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) { // Determine an appropriate location to create the bitcast for the return // value. The location depends on if we have a call or invoke instruction. - Instruction *InsertBefore = nullptr; + BasicBlock::iterator InsertBefore; if (auto *Invoke = dyn_cast<InvokeInst>(&CB)) InsertBefore = - &SplitEdge(Invoke->getParent(), Invoke->getNormalDest())->front(); + SplitEdge(Invoke->getParent(), Invoke->getNormalDest())->begin(); else - InsertBefore = &*std::next(CB.getIterator()); + InsertBefore = std::next(CB.getIterator()); // Bitcast the return value to the correct type. auto *Cast = CastInst::CreateBitOrPointerCast(&CB, RetTy, "", InsertBefore); @@ -188,10 +191,9 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) { /// Predicate and clone the given call site. /// /// This function creates an if-then-else structure at the location of the call -/// site. The "if" condition compares the call site's called value to the given -/// callee. The original call site is moved into the "else" block, and a clone -/// of the call site is placed in the "then" block. The cloned instruction is -/// returned. +/// site. The "if" condition is specified by `Cond`. +/// The original call site is moved into the "else" block, and a clone of the +/// call site is placed in the "then" block. The cloned instruction is returned. /// /// For example, the call instruction below: /// @@ -202,7 +204,7 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) { /// Is replace by the following: /// /// orig_bb: -/// %cond = icmp eq i32 ()* %ptr, @func +/// %cond = Cond /// br i1 %cond, %then_bb, %else_bb /// /// then_bb: @@ -232,7 +234,7 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) { /// Is replace by the following: /// /// orig_bb: -/// %cond = icmp eq i32 ()* %ptr, @func +/// %cond = Cond /// br i1 %cond, %then_bb, %else_bb /// /// then_bb: @@ -267,7 +269,7 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) { /// Is replaced by the following: /// /// cond_bb: -/// %cond = icmp eq i32 ()* %ptr, @func +/// %cond = Cond /// br i1 %cond, %then_bb, %orig_bb /// /// then_bb: @@ -280,19 +282,13 @@ static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) { /// ; The original call instruction stays in its original block. 
/// %t0 = musttail call i32 %ptr() /// ret %t0 -CallBase &llvm::versionCallSite(CallBase &CB, Value *Callee, - MDNode *BranchWeights) { +static CallBase &versionCallSiteWithCond(CallBase &CB, Value *Cond, + MDNode *BranchWeights) { IRBuilder<> Builder(&CB); CallBase *OrigInst = &CB; BasicBlock *OrigBlock = OrigInst->getParent(); - // Create the compare. The called value and callee must have the same type to - // be compared. - if (CB.getCalledOperand()->getType() != Callee->getType()) - Callee = Builder.CreateBitCast(Callee, CB.getCalledOperand()->getType()); - auto *Cond = Builder.CreateICmpEQ(CB.getCalledOperand(), Callee); - if (OrigInst->isMustTailCall()) { // Create an if-then structure. The original instruction stays in its block, // and a clone of the original instruction is placed in the "then" block. @@ -380,11 +376,27 @@ CallBase &llvm::versionCallSite(CallBase &CB, Value *Callee, return *NewInst; } +// Predicate and clone the given call site using condition `CB.callee == +// Callee`. See the comment `versionCallSiteWithCond` for the transformation. +CallBase &llvm::versionCallSite(CallBase &CB, Value *Callee, + MDNode *BranchWeights) { + + IRBuilder<> Builder(&CB); + + // Create the compare. The called value and callee must have the same type to + // be compared. + if (CB.getCalledOperand()->getType() != Callee->getType()) + Callee = Builder.CreateBitCast(Callee, CB.getCalledOperand()->getType()); + auto *Cond = Builder.CreateICmpEQ(CB.getCalledOperand(), Callee); + + return versionCallSiteWithCond(CB, Cond, BranchWeights); +} + bool llvm::isLegalToPromote(const CallBase &CB, Function *Callee, const char **FailureReason) { assert(!CB.getCalledFunction() && "Only indirect call sites can be promoted"); - auto &DL = Callee->getParent()->getDataLayout(); + auto &DL = Callee->getDataLayout(); // Check the return type. The callee's return value type must be bitcast // compatible with the call site's type. @@ -509,7 +521,8 @@ CallBase &llvm::promoteCall(CallBase &CB, Function *Callee, Type *FormalTy = CalleeType->getParamType(ArgNo); Type *ActualTy = Arg->getType(); if (FormalTy != ActualTy) { - auto *Cast = CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", &CB); + auto *Cast = + CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", CB.getIterator()); CB.setArgOperand(ArgNo, Cast); // Remove any incompatible attributes for the argument. @@ -559,6 +572,27 @@ CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee, return promoteCall(NewInst, Callee); } +CallBase &llvm::promoteCallWithVTableCmp(CallBase &CB, Instruction *VPtr, + Function *Callee, + ArrayRef<Constant *> AddressPoints, + MDNode *BranchWeights) { + assert(!AddressPoints.empty() && "Caller should guarantee"); + IRBuilder<> Builder(&CB); + SmallVector<Value *, 2> ICmps; + for (auto &AddressPoint : AddressPoints) + ICmps.push_back(Builder.CreateICmpEQ(VPtr, AddressPoint)); + + // TODO: Perform tree height reduction if the number of ICmps is high. + Value *Cond = Builder.CreateOr(ICmps); + + // Version the indirect call site. If Cond is true, 'NewInst' will be + // executed, otherwise the original call site will be executed. + CallBase &NewInst = versionCallSiteWithCond(CB, Cond, BranchWeights); + + // Promote 'NewInst' so that it directly calls the desired function. 
+ return promoteCall(NewInst, Callee); +} + bool llvm::tryPromoteCall(CallBase &CB) { assert(!CB.getCalledFunction()); Module *M = CB.getCaller()->getParent(); @@ -597,16 +631,13 @@ bool llvm::tryPromoteCall(CallBase &CB) { // Not in the form of a global constant variable with an initializer. return false; - Constant *VTableGVInitializer = GV->getInitializer(); APInt VTableGVOffset = VTableOffsetGVBase + VTableOffset; if (!(VTableGVOffset.getActiveBits() <= 64)) return false; // Out of range. - Constant *Ptr = getPointerAtOffset(VTableGVInitializer, - VTableGVOffset.getZExtValue(), - *M); - if (!Ptr) - return false; // No constant (function) pointer found. - Function *DirectCallee = dyn_cast<Function>(Ptr->stripPointerCasts()); + + Function *DirectCallee = nullptr; + std::tie(DirectCallee, std::ignore) = + getFunctionAtVTableOffset(GV, VTableGVOffset.getZExtValue(), *M); if (!DirectCallee) return false; // No function pointer found. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp index c24b6ed70405..7330e59d25ee 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeAliases.cpp @@ -31,6 +31,7 @@ #include "llvm/Transforms/Utils/CanonicalizeAliases.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/Module.h" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp index 282c44563466..40010aee9c11 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CanonicalizeFreezeInLoops.cpp @@ -144,7 +144,7 @@ void CanonicalizeFreezeInLoopsImpl::InsertFreezeAndForgetFromSCEV(Use &U) { LLVM_DEBUG(dbgs() << "\tOperand: " << *U.get() << "\n"); U.set(new FreezeInst(ValueToFr, ValueToFr->getName() + ".frozen", - PH->getTerminator())); + PH->getTerminator()->getIterator())); SE.forgetValue(UserI); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneFunction.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneFunction.cpp index c0f333364fa5..47e3c03288d9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneFunction.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneFunction.cpp @@ -14,9 +14,11 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/AttributeMask.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfo.h" @@ -276,8 +278,8 @@ void llvm::CloneFunctionInto(Function *NewFunc, const Function *OldFunc, // attached debug-info records. for (Instruction &II : *BB) { RemapInstruction(&II, VMap, RemapFlag, TypeMapper, Materializer); - RemapDPValueRange(II.getModule(), II.getDbgValueRange(), VMap, RemapFlag, - TypeMapper, Materializer); + RemapDbgRecordRange(II.getModule(), II.getDbgRecordRange(), VMap, + RemapFlag, TypeMapper, Materializer); } // Only update !llvm.dbg.cu for DifferentModule (not CloneModule). 
In the @@ -384,18 +386,6 @@ public: }; } // namespace -static bool hasRoundingModeOperand(Intrinsic::ID CIID) { - switch (CIID) { -#define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC) \ - case Intrinsic::INTRINSIC: \ - return ROUND_MODE == 1; -#define FUNCTION INSTRUCTION -#include "llvm/IR/ConstrainedOps.def" - default: - llvm_unreachable("Unexpected constrained intrinsic id"); - } -} - Instruction * PruningFunctionCloner::cloneInstruction(BasicBlock::const_iterator II) { const Instruction &OldInst = *II; @@ -453,7 +443,7 @@ PruningFunctionCloner::cloneInstruction(BasicBlock::const_iterator II) { // The last arguments of a constrained intrinsic are metadata that // represent rounding mode (absents in some intrinsics) and exception // behavior. The inlined function uses default settings. - if (hasRoundingModeOperand(CIID)) + if (Intrinsic::hasConstrainedFPRoundingModeOperand(CIID)) Args.push_back( MetadataAsValue::get(Ctx, MDString::get(Ctx, "round.tonearest"))); Args.push_back( @@ -540,18 +530,13 @@ void PruningFunctionCloner::CloneBlock( RemapInstruction(NewInst, VMap, ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges); - // If we can simplify this instruction to some other value, simply add - // a mapping to that value rather than inserting a new instruction into - // the basic block. - if (Value *V = - simplifyInstruction(NewInst, BB->getModule()->getDataLayout())) { - // On the off-chance that this simplifies to an instruction in the old - // function, map it back into the new function. - if (NewFunc != OldFunc) - if (Value *MappedV = VMap.lookup(V)) - V = MappedV; - - if (!NewInst->mayHaveSideEffects()) { + // Eagerly constant fold the newly cloned instruction. If successful, add + // a mapping to the new value. Non-constant operands may be incomplete at + // this stage, thus instruction simplification is performed after + // processing phi-nodes. + if (Value *V = ConstantFoldInstruction( + NewInst, BB->getDataLayout())) { + if (isInstructionTriviallyDead(NewInst)) { VMap[&*II] = V; NewInst->eraseFromParent(); continue; @@ -641,8 +626,8 @@ void PruningFunctionCloner::CloneBlock( // Recursively clone any reachable successor blocks. append_range(ToClone, successors(BB->getTerminator())); } else { - // If we didn't create a new terminator, clone DPValues from the old - // terminator onto the new terminator. + // If we didn't create a new terminator, clone DbgVariableRecords from the + // old terminator onto the new terminator. Instruction *NewInst = NewBB->getTerminator(); assert(NewInst); @@ -823,54 +808,40 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, } } - // Make a second pass over the PHINodes now that all of them have been - // remapped into the new function, simplifying the PHINode and performing any - // recursive simplifications exposed. This will transparently update the - // WeakTrackingVH in the VMap. Notably, we rely on that so that if we coalesce - // two PHINodes, the iteration over the old PHIs remains valid, and the - // mapping will just map us to the new node (which may not even be a PHI - // node). - const DataLayout &DL = NewFunc->getParent()->getDataLayout(); - SmallSetVector<const Value *, 8> Worklist; - for (unsigned Idx = 0, Size = PHIToResolve.size(); Idx != Size; ++Idx) - if (isa<PHINode>(VMap[PHIToResolve[Idx]])) - Worklist.insert(PHIToResolve[Idx]); - - // Note that we must test the size on each iteration, the worklist can grow. 
- for (unsigned Idx = 0; Idx != Worklist.size(); ++Idx) { - const Value *OrigV = Worklist[Idx]; - auto *I = dyn_cast_or_null<Instruction>(VMap.lookup(OrigV)); - if (!I) - continue; - - // Skip over non-intrinsic callsites, we don't want to remove any nodes from - // the CGSCC. - CallBase *CB = dyn_cast<CallBase>(I); - if (CB && CB->getCalledFunction() && - !CB->getCalledFunction()->isIntrinsic()) - continue; - - // See if this instruction simplifies. - Value *SimpleV = simplifyInstruction(I, DL); - if (!SimpleV) - continue; - - // Stash away all the uses of the old instruction so we can check them for - // recursive simplifications after a RAUW. This is cheaper than checking all - // uses of To on the recursive step in most cases. - for (const User *U : OrigV->users()) - Worklist.insert(cast<Instruction>(U)); + // Drop all incompatible return attributes that cannot be applied to NewFunc + // during cloning, so as to allow instruction simplification to reason on the + // old state of the function. The original attributes are restored later. + AttributeMask IncompatibleAttrs = + AttributeFuncs::typeIncompatible(OldFunc->getReturnType()); + AttributeList Attrs = NewFunc->getAttributes(); + NewFunc->removeRetAttrs(IncompatibleAttrs); + + // As phi-nodes have been now remapped, allow incremental simplification of + // newly-cloned instructions. + const DataLayout &DL = NewFunc->getDataLayout(); + for (const auto &BB : *OldFunc) { + for (const auto &I : BB) { + auto *NewI = dyn_cast_or_null<Instruction>(VMap.lookup(&I)); + if (!NewI) + continue; - // Replace the instruction with its simplified value. - I->replaceAllUsesWith(SimpleV); + if (Value *V = simplifyInstruction(NewI, DL)) { + NewI->replaceAllUsesWith(V); - // If the original instruction had no side effects, remove it. - if (isInstructionTriviallyDead(I)) - I->eraseFromParent(); - else - VMap[OrigV] = I; + if (isInstructionTriviallyDead(NewI)) { + NewI->eraseFromParent(); + } else { + // Did not erase it? Restore the new instruction into VMap previously + // dropped by `ValueIsRAUWd`. + VMap[&I] = NewI; + } + } + } } + // Restore attributes. + NewFunc->setAttributes(Attrs); + // Remap debug intrinsic operands now that all values have been mapped. // Doing this now (late) preserves use-before-defs in debug intrinsics. If // we didn't do this, ValueAsMetadata(use-before-def) operands would be @@ -884,14 +855,15 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, TypeMapper, Materializer); } - // Do the same for DPValues, touching all the instructions in the cloned - // range of blocks. + // Do the same for DbgVariableRecords, touching all the instructions in the + // cloned range of blocks. Function::iterator Begin = cast<BasicBlock>(VMap[StartingBB])->getIterator(); for (BasicBlock &BB : make_range(Begin, NewFunc->end())) { for (Instruction &I : BB) { - RemapDPValueRange(I.getModule(), I.getDbgValueRange(), VMap, - ModuleLevelChanges ? RF_None : RF_NoModuleLevelChanges, - TypeMapper, Materializer); + RemapDbgRecordRange(I.getModule(), I.getDbgRecordRange(), VMap, + ModuleLevelChanges ? RF_None + : RF_NoModuleLevelChanges, + TypeMapper, Materializer); } } @@ -990,8 +962,8 @@ void llvm::remapInstructionsInBlocks(ArrayRef<BasicBlock *> Blocks, // Rewrite the code to refer to itself. 
for (auto *BB : Blocks) { for (auto &Inst : *BB) { - RemapDPValueRange(Inst.getModule(), Inst.getDbgValueRange(), VMap, - RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + RemapDbgRecordRange(Inst.getModule(), Inst.getDbgRecordRange(), VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); RemapInstruction(&Inst, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); } @@ -1129,6 +1101,9 @@ BasicBlock *llvm::DuplicateInstructionsInSplitBetween( if (I != ValueMapping.end()) New->setOperand(i, I->second); } + + // Remap debug variable operands. + remapDebugVariable(ValueMapping, New); } return NewBB; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneModule.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneModule.cpp index 00e40fe73d90..cabc2ab7933a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneModule.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CloneModule.cpp @@ -208,8 +208,8 @@ std::unique_ptr<Module> llvm::CloneModule( // And named metadata.... for (const NamedMDNode &NMD : M.named_metadata()) { NamedMDNode *NewNMD = New->getOrInsertNamedMetadata(NMD.getName()); - for (unsigned i = 0, e = NMD.getNumOperands(); i != e; ++i) - NewNMD->addOperand(MapMetadata(NMD.getOperand(i), VMap)); + for (const MDNode *N : NMD.operands()) + NewNMD->addOperand(MapMetadata(N, VMap)); } return New; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp index 278111883459..5bca5cf8ff91 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -570,7 +570,7 @@ void CodeExtractor::findAllocas(const CodeExtractorAnalysisCache &CEAC, LLVMContext &Ctx = M->getContext(); auto *Int8PtrTy = PointerType::getUnqual(Ctx); CastInst *CastI = - CastInst::CreatePointerCast(AI, Int8PtrTy, "lt.cast", I); + CastInst::CreatePointerCast(AI, Int8PtrTy, "lt.cast", I->getIterator()); I->replaceUsesOfWith(I->getOperand(1), CastI); } @@ -745,7 +745,7 @@ void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) { /// and other with remaining incoming blocks; then first PHIs are placed in /// outlined region. void CodeExtractor::severSplitPHINodesOfExits( - const SmallPtrSetImpl<BasicBlock *> &Exits) { + const SetVector<BasicBlock *> &Exits) { for (BasicBlock *ExitBB : Exits) { BasicBlock *NewBB = nullptr; @@ -932,6 +932,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::DisableSanitizerInstrumentation: case Attribute::FnRetThunkExtern: case Attribute::Hot: + case Attribute::HybridPatchable: case Attribute::NoRecurse: case Attribute::InlineHint: case Attribute::MinSize: @@ -954,6 +955,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::ShadowCallStack: case Attribute::SanitizeAddress: case Attribute::SanitizeMemory: + case Attribute::SanitizeNumericalStability: case Attribute::SanitizeThread: case Attribute::SanitizeHWAddress: case Attribute::SanitizeMemTag: @@ -999,6 +1001,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::WriteOnly: case Attribute::Writable: case Attribute::DeadOnUnwind: + case Attribute::Range: + case Attribute::Initializes: // These are not really attributes. 
case Attribute::None: case Attribute::EndAttrKinds: @@ -1009,6 +1013,18 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, newFunction->addFnAttr(Attr); } + + if (NumExitBlocks == 0) { + // Mark the new function `noreturn` if applicable. Terminators which resume + // exception propagation are treated as returning instructions. This is to + // avoid inserting traps after calls to outlined functions which unwind. + if (none_of(Blocks, [](const BasicBlock *BB) { + const Instruction *Term = BB->getTerminator(); + return isa<ReturnInst>(Term) || isa<ResumeInst>(Term); + })) + newFunction->setDoesNotReturn(); + } + newFunction->insert(newFunction->end(), newRootNode); // Create scalar and aggregate iterators to name all of the arguments we @@ -1024,7 +1040,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, Value *Idx[2]; Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext())); Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx); - Instruction *TI = newFunction->begin()->getTerminator(); + BasicBlock::iterator TI = newFunction->begin()->getTerminator()->getIterator(); GetElementPtrInst *GEP = GetElementPtrInst::Create( StructTy, &*AggAI, Idx, "gep_" + inputs[i]->getName(), TI); RewriteVal = new LoadInst(StructTy->getElementType(aggIdx), GEP, @@ -1173,7 +1189,7 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, AllocaInst *alloca = new AllocaInst(output->getType(), DL.getAllocaAddrSpace(), nullptr, output->getName() + ".loc", - &codeReplacer->getParent()->front().front()); + codeReplacer->getParent()->front().begin()); ReloadOutputs.push_back(alloca); params.push_back(alloca); ++ScalarOutputArgNo; @@ -1192,8 +1208,8 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, StructArgTy = StructType::get(newFunction->getContext(), ArgTypes); Struct = new AllocaInst( StructArgTy, DL.getAllocaAddrSpace(), nullptr, "structArg", - AllocationBlock ? &*AllocationBlock->getFirstInsertionPt() - : &codeReplacer->getParent()->front().front()); + AllocationBlock ? 
AllocationBlock->getFirstInsertionPt() + : codeReplacer->getParent()->front().begin()); if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) { auto *StructSpaceCast = new AddrSpaceCastInst( @@ -1358,9 +1374,8 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, else InsertPt = std::next(OutI->getIterator()); - Instruction *InsertBefore = &*InsertPt; - assert((InsertBefore->getFunction() == newFunction || - Blocks.count(InsertBefore->getParent())) && + assert((InsertPt->getFunction() == newFunction || + Blocks.count(InsertPt->getParent())) && "InsertPt should be in new function"); if (AggregateArgs && StructValues.contains(outputs[i])) { assert(AggOutputArgBegin != newFunction->arg_end() && @@ -1371,8 +1386,8 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx); GetElementPtrInst *GEP = GetElementPtrInst::Create( StructArgTy, &*AggOutputArgBegin, Idx, "gep_" + outputs[i]->getName(), - InsertBefore); - new StoreInst(outputs[i], GEP, InsertBefore); + InsertPt); + new StoreInst(outputs[i], GEP, InsertPt); ++aggIdx; // Since there should be only one struct argument aggregating // all the output values, we shouldn't increment AggOutputArgBegin, which @@ -1381,7 +1396,7 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, assert(ScalarOutputArgBegin != newFunction->arg_end() && "Number of scalar output arguments should match " "the number of defined values"); - new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertBefore); + new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertPt); ++ScalarOutputArgBegin; } } @@ -1392,19 +1407,23 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, case 0: // There are no successors (the block containing the switch itself), which // means that previously this was the last part of the function, and hence - // this should be rewritten as a `ret' - - // Check if the function should return a value - if (OldFnRetTy->isVoidTy()) { - ReturnInst::Create(Context, nullptr, TheSwitch); // Return void + // this should be rewritten as a `ret` or `unreachable`. + if (newFunction->doesNotReturn()) { + // If fn is no return, end with an unreachable terminator. + (void)new UnreachableInst(Context, TheSwitch->getIterator()); + } else if (OldFnRetTy->isVoidTy()) { + // We have no return value. + ReturnInst::Create(Context, nullptr, + TheSwitch->getIterator()); // Return void } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) { // return what we have - ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch); + ReturnInst::Create(Context, TheSwitch->getCondition(), + TheSwitch->getIterator()); } else { // Otherwise we must have code extracted an unwind or something, just // return whatever we want. - ReturnInst::Create(Context, - Constant::getNullValue(OldFnRetTy), TheSwitch); + ReturnInst::Create(Context, Constant::getNullValue(OldFnRetTy), + TheSwitch->getIterator()); } TheSwitch->eraseFromParent(); @@ -1412,12 +1431,12 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, case 1: // Only a single destination, change the switch into an unconditional // branch. 
- BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch); + BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getIterator()); TheSwitch->eraseFromParent(); break; case 2: BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2), - call, TheSwitch); + call, TheSwitch->getIterator()); TheSwitch->eraseFromParent(); break; default: @@ -1508,14 +1527,14 @@ void CodeExtractor::calculateNewCallTerminatorWeights( static void eraseDebugIntrinsicsWithNonLocalRefs(Function &F) { for (Instruction &I : instructions(F)) { SmallVector<DbgVariableIntrinsic *, 4> DbgUsers; - SmallVector<DPValue *, 4> DPValues; - findDbgUsers(DbgUsers, &I, &DPValues); + SmallVector<DbgVariableRecord *, 4> DbgVariableRecords; + findDbgUsers(DbgUsers, &I, &DbgVariableRecords); for (DbgVariableIntrinsic *DVI : DbgUsers) if (DVI->getFunction() != &F) DVI->eraseFromParent(); - for (DPValue *DPV : DPValues) - if (DPV->getFunction() != &F) - DPV->eraseFromParent(); + for (DbgVariableRecord *DVR : DbgVariableRecords) + if (DVR->getFunction() != &F) + DVR->eraseFromParent(); } } @@ -1569,7 +1588,7 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, // point to a variable in the wrong scope. SmallDenseMap<DINode *, DINode *> RemappedMetadata; SmallVector<Instruction *, 4> DebugIntrinsicsToDelete; - SmallVector<DPValue *, 4> DPVsToDelete; + SmallVector<DbgVariableRecord *, 4> DVRsToDelete; DenseMap<const MDNode *, MDNode *> Cache; auto GetUpdatedDIVariable = [&](DILocalVariable *OldVar) { @@ -1585,27 +1604,47 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, return cast<DILocalVariable>(NewVar); }; - auto UpdateDPValuesOnInst = [&](Instruction &I) -> void { - for (auto &DPV : I.getDbgValueRange()) { + auto UpdateDbgLabel = [&](auto *LabelRecord) { + // Point the label record to a fresh label within the new function if + // the record was not inlined from some other function. + if (LabelRecord->getDebugLoc().getInlinedAt()) + return; + DILabel *OldLabel = LabelRecord->getLabel(); + DINode *&NewLabel = RemappedMetadata[OldLabel]; + if (!NewLabel) { + DILocalScope *NewScope = DILocalScope::cloneScopeForSubprogram( + *OldLabel->getScope(), *NewSP, Ctx, Cache); + NewLabel = DILabel::get(Ctx, NewScope, OldLabel->getName(), + OldLabel->getFile(), OldLabel->getLine()); + } + LabelRecord->setLabel(cast<DILabel>(NewLabel)); + }; + + auto UpdateDbgRecordsOnInst = [&](Instruction &I) -> void { + for (DbgRecord &DR : I.getDbgRecordRange()) { + if (DbgLabelRecord *DLR = dyn_cast<DbgLabelRecord>(&DR)) { + UpdateDbgLabel(DLR); + continue; + } + + DbgVariableRecord &DVR = cast<DbgVariableRecord>(DR); // Apply the two updates that dbg.values get: invalid operands, and // variable metadata fixup. 
- if (any_of(DPV.location_ops(), IsInvalidLocation)) { - DPVsToDelete.push_back(&DPV); + if (any_of(DVR.location_ops(), IsInvalidLocation)) { + DVRsToDelete.push_back(&DVR); continue; } - if (DPV.isDbgAssign() && IsInvalidLocation(DPV.getAddress())) { - DPVsToDelete.push_back(&DPV); + if (DVR.isDbgAssign() && IsInvalidLocation(DVR.getAddress())) { + DVRsToDelete.push_back(&DVR); continue; } - if (!DPV.getDebugLoc().getInlinedAt()) - DPV.setVariable(GetUpdatedDIVariable(DPV.getVariable())); - DPV.setDebugLoc(DebugLoc::replaceInlinedAtSubprogram(DPV.getDebugLoc(), - *NewSP, Ctx, Cache)); + if (!DVR.getDebugLoc().getInlinedAt()) + DVR.setVariable(GetUpdatedDIVariable(DVR.getVariable())); } }; for (Instruction &I : instructions(NewFunc)) { - UpdateDPValuesOnInst(I); + UpdateDbgRecordsOnInst(I); auto *DII = dyn_cast<DbgInfoIntrinsic>(&I); if (!DII) @@ -1614,17 +1653,7 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, // Point the intrinsic to a fresh label within the new function if the // intrinsic was not inlined from some other function. if (auto *DLI = dyn_cast<DbgLabelInst>(&I)) { - if (DLI->getDebugLoc().getInlinedAt()) - continue; - DILabel *OldLabel = DLI->getLabel(); - DINode *&NewLabel = RemappedMetadata[OldLabel]; - if (!NewLabel) { - DILocalScope *NewScope = DILocalScope::cloneScopeForSubprogram( - *OldLabel->getScope(), *NewSP, Ctx, Cache); - NewLabel = DILabel::get(Ctx, NewScope, OldLabel->getName(), - OldLabel->getFile(), OldLabel->getLine()); - } - DLI->setArgOperand(0, MetadataAsValue::get(Ctx, NewLabel)); + UpdateDbgLabel(DLI); continue; } @@ -1648,16 +1677,20 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, for (auto *DII : DebugIntrinsicsToDelete) DII->eraseFromParent(); - for (auto *DPV : DPVsToDelete) - DPV->getMarker()->MarkedInstr->dropOneDbgValue(DPV); + for (auto *DVR : DVRsToDelete) + DVR->getMarker()->MarkedInstr->dropOneDbgRecord(DVR); DIB.finalizeSubprogram(NewSP); - // Fix up the scope information attached to the line locations in the new - // function. + // Fix up the scope information attached to the line locations and the + // debug assignment metadata in the new function. + DenseMap<DIAssignID *, DIAssignID *> AssignmentIDMap; for (Instruction &I : instructions(NewFunc)) { if (const DebugLoc &DL = I.getDebugLoc()) I.setDebugLoc( DebugLoc::replaceInlinedAtSubprogram(DL, *NewSP, Ctx, Cache)); + for (DbgRecord &DR : I.getDbgRecordRange()) + DR.setDebugLoc(DebugLoc::replaceInlinedAtSubprogram(DR.getDebugLoc(), + *NewSP, Ctx, Cache)); // Loop info metadata may contain line locations. Fix them up. auto updateLoopInfoLoc = [&Ctx, &Cache, NewSP](Metadata *MD) -> Metadata * { @@ -1666,6 +1699,7 @@ static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, return MD; }; updateLoopMetadataDebugLocations(I, updateLoopInfoLoc); + at::remapAssignID(AssignmentIDMap, I); } if (!TheCall.getDebugLoc()) TheCall.setDebugLoc(DILocation::get(Ctx, 0, 0, OldSP)); @@ -1722,7 +1756,7 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, // Calculate the exit blocks for the extracted region and the total exit // weights for each of those blocks. 
DenseMap<BasicBlock *, BlockFrequency> ExitWeights; - SmallPtrSet<BasicBlock *, 1> ExitBlocks; + SetVector<BasicBlock *> ExitBlocks; for (BasicBlock *Block : Blocks) { for (BasicBlock *Succ : successors(Block)) { if (!Blocks.count(Succ)) { @@ -1769,6 +1803,10 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, return any_of(*BB, [&BranchI](const Instruction &I) { if (!I.getDebugLoc()) return false; + // Don't use source locations attached to debug-intrinsics: they could + // be from completely unrelated scopes. + if (isa<DbgInfoIntrinsic>(I)) + return false; BranchI->setDebugLoc(I.getDebugLoc()); return true; }); @@ -1878,16 +1916,6 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, fixupDebugInfoPostExtraction(*oldFunction, *newFunction, *TheCall); - // Mark the new function `noreturn` if applicable. Terminators which resume - // exception propagation are treated as returning instructions. This is to - // avoid inserting traps after calls to outlined functions which unwind. - bool doesNotReturn = none_of(*newFunction, [](const BasicBlock &BB) { - const Instruction *Term = BB.getTerminator(); - return isa<ReturnInst>(Term) || isa<ResumeInst>(Term); - }); - if (doesNotReturn) - newFunction->setDoesNotReturn(); - LLVM_DEBUG(if (verifyFunction(*newFunction, &errs())) { newFunction->dump(); report_fatal_error("verification of newFunction failed!"); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp index 6a2dae5bab68..ac106e4aa2a3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CodeMoverUtils.cpp @@ -336,9 +336,22 @@ bool llvm::isSafeToMoveBefore(Instruction &I, Instruction &InsertPoint, if (isReachedBefore(&I, &InsertPoint, &DT, PDT)) for (const Use &U : I.uses()) - if (auto *UserInst = dyn_cast<Instruction>(U.getUser())) - if (UserInst != &InsertPoint && !DT.dominates(&InsertPoint, U)) + if (auto *UserInst = dyn_cast<Instruction>(U.getUser())) { + // If InsertPoint is in a BB that comes after I, then we cannot move if + // I is used in the terminator of the current BB. + if (I.getParent() == InsertPoint.getParent() && + UserInst == I.getParent()->getTerminator()) return false; + if (UserInst != &InsertPoint && !DT.dominates(&InsertPoint, U)) { + // If UserInst is an instruction that appears later in the same BB as + // I, then it is okay to move since I will still be available when + // UserInst is executed. 
+ if (CheckForEntireBlock && I.getParent() == UserInst->getParent() && + DT.dominates(&I, UserInst)) + continue; + return false; + } + } if (isReachedBefore(&InsertPoint, &I, &DT, PDT)) for (const Value *Op : I.operands()) if (auto *OpInst = dyn_cast<Instruction>(Op)) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/CountVisits.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/CountVisits.cpp index 4faded8fc656..f22880bc9a66 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/CountVisits.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/CountVisits.cpp @@ -8,6 +8,7 @@ #include "llvm/Transforms/Utils/CountVisits.h" #include "llvm/ADT/Statistic.h" +#include "llvm/IR/Function.h" #include "llvm/IR/PassManager.h" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/DXILResource.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/DXILResource.cpp new file mode 100644 index 000000000000..de2b6512a6d1 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/DXILResource.cpp @@ -0,0 +1,370 @@ +//===- DXILResource.cpp - Tools to translate DXIL resources ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/DXILResource.h" +#include "llvm/ADT/APInt.h" +#include "llvm/IR/DerivedTypes.h" + +using namespace llvm; +using namespace dxil; + +bool ResourceInfo::isUAV() const { return RC == ResourceClass::UAV; } + +bool ResourceInfo::isCBuffer() const { return RC == ResourceClass::CBuffer; } + +bool ResourceInfo::isSampler() const { return RC == ResourceClass::Sampler; } + +bool ResourceInfo::isStruct() const { + return Kind == ResourceKind::StructuredBuffer; +} + +bool ResourceInfo::isTyped() const { + switch (Kind) { + case ResourceKind::Texture1D: + case ResourceKind::Texture2D: + case ResourceKind::Texture2DMS: + case ResourceKind::Texture3D: + case ResourceKind::TextureCube: + case ResourceKind::Texture1DArray: + case ResourceKind::Texture2DArray: + case ResourceKind::Texture2DMSArray: + case ResourceKind::TextureCubeArray: + case ResourceKind::TypedBuffer: + return true; + case ResourceKind::RawBuffer: + case ResourceKind::StructuredBuffer: + case ResourceKind::FeedbackTexture2D: + case ResourceKind::FeedbackTexture2DArray: + case ResourceKind::CBuffer: + case ResourceKind::Sampler: + case ResourceKind::TBuffer: + case ResourceKind::RTAccelerationStructure: + return false; + case ResourceKind::Invalid: + case ResourceKind::NumEntries: + llvm_unreachable("Invalid resource kind"); + } + llvm_unreachable("Unhandled ResourceKind enum"); +} + +bool ResourceInfo::isFeedback() const { + return Kind == ResourceKind::FeedbackTexture2D || + Kind == ResourceKind::FeedbackTexture2DArray; +} + +bool ResourceInfo::isMultiSample() const { + return Kind == ResourceKind::Texture2DMS || + Kind == ResourceKind::Texture2DMSArray; +} + +ResourceInfo ResourceInfo::SRV(Value *Symbol, StringRef Name, + ResourceBinding Binding, uint32_t UniqueID, + ElementType ElementTy, uint32_t ElementCount, + ResourceKind Kind) { + ResourceInfo RI(ResourceClass::SRV, Kind, Symbol, Name, Binding, UniqueID); + assert(RI.isTyped() && !(RI.isStruct() || RI.isMultiSample()) && + "Invalid ResourceKind for SRV constructor."); + RI.Typed.ElementTy = ElementTy; + RI.Typed.ElementCount = 
ElementCount; + return RI; +} + +ResourceInfo ResourceInfo::RawBuffer(Value *Symbol, StringRef Name, + ResourceBinding Binding, + uint32_t UniqueID) { + ResourceInfo RI(ResourceClass::SRV, ResourceKind::RawBuffer, Symbol, Name, + Binding, UniqueID); + return RI; +} + +ResourceInfo ResourceInfo::StructuredBuffer(Value *Symbol, StringRef Name, + ResourceBinding Binding, + uint32_t UniqueID, uint32_t Stride, + Align Alignment) { + ResourceInfo RI(ResourceClass::SRV, ResourceKind::StructuredBuffer, Symbol, + Name, Binding, UniqueID); + RI.Struct.Stride = Stride; + RI.Struct.Alignment = Alignment; + return RI; +} + +ResourceInfo ResourceInfo::Texture2DMS(Value *Symbol, StringRef Name, + ResourceBinding Binding, + uint32_t UniqueID, ElementType ElementTy, + uint32_t ElementCount, + uint32_t SampleCount) { + ResourceInfo RI(ResourceClass::SRV, ResourceKind::Texture2DMS, Symbol, Name, + Binding, UniqueID); + RI.Typed.ElementTy = ElementTy; + RI.Typed.ElementCount = ElementCount; + RI.MultiSample.Count = SampleCount; + return RI; +} + +ResourceInfo ResourceInfo::Texture2DMSArray( + Value *Symbol, StringRef Name, ResourceBinding Binding, uint32_t UniqueID, + ElementType ElementTy, uint32_t ElementCount, uint32_t SampleCount) { + ResourceInfo RI(ResourceClass::SRV, ResourceKind::Texture2DMSArray, Symbol, + Name, Binding, UniqueID); + RI.Typed.ElementTy = ElementTy; + RI.Typed.ElementCount = ElementCount; + RI.MultiSample.Count = SampleCount; + return RI; +} + +ResourceInfo ResourceInfo::UAV(Value *Symbol, StringRef Name, + ResourceBinding Binding, uint32_t UniqueID, + ElementType ElementTy, uint32_t ElementCount, + bool GloballyCoherent, bool IsROV, + ResourceKind Kind) { + ResourceInfo RI(ResourceClass::UAV, Kind, Symbol, Name, Binding, UniqueID); + assert(RI.isTyped() && !(RI.isStruct() || RI.isMultiSample()) && + "Invalid ResourceKind for UAV constructor."); + RI.Typed.ElementTy = ElementTy; + RI.Typed.ElementCount = ElementCount; + RI.UAVFlags.GloballyCoherent = GloballyCoherent; + RI.UAVFlags.IsROV = IsROV; + RI.UAVFlags.HasCounter = false; + return RI; +} + +ResourceInfo ResourceInfo::RWRawBuffer(Value *Symbol, StringRef Name, + ResourceBinding Binding, + uint32_t UniqueID, bool GloballyCoherent, + bool IsROV) { + ResourceInfo RI(ResourceClass::UAV, ResourceKind::RawBuffer, Symbol, Name, + Binding, UniqueID); + RI.UAVFlags.GloballyCoherent = GloballyCoherent; + RI.UAVFlags.IsROV = IsROV; + RI.UAVFlags.HasCounter = false; + return RI; +} + +ResourceInfo ResourceInfo::RWStructuredBuffer(Value *Symbol, StringRef Name, + ResourceBinding Binding, + uint32_t UniqueID, + uint32_t Stride, Align Alignment, + bool GloballyCoherent, bool IsROV, + bool HasCounter) { + ResourceInfo RI(ResourceClass::UAV, ResourceKind::StructuredBuffer, Symbol, + Name, Binding, UniqueID); + RI.Struct.Stride = Stride; + RI.Struct.Alignment = Alignment; + RI.UAVFlags.GloballyCoherent = GloballyCoherent; + RI.UAVFlags.IsROV = IsROV; + RI.UAVFlags.HasCounter = HasCounter; + return RI; +} + +ResourceInfo +ResourceInfo::RWTexture2DMS(Value *Symbol, StringRef Name, + ResourceBinding Binding, uint32_t UniqueID, + ElementType ElementTy, uint32_t ElementCount, + uint32_t SampleCount, bool GloballyCoherent) { + ResourceInfo RI(ResourceClass::UAV, ResourceKind::Texture2DMS, Symbol, Name, + Binding, UniqueID); + RI.Typed.ElementTy = ElementTy; + RI.Typed.ElementCount = ElementCount; + RI.UAVFlags.GloballyCoherent = GloballyCoherent; + RI.UAVFlags.IsROV = false; + RI.UAVFlags.HasCounter = false; + RI.MultiSample.Count = SampleCount; + 
return RI; +} + +ResourceInfo +ResourceInfo::RWTexture2DMSArray(Value *Symbol, StringRef Name, + ResourceBinding Binding, uint32_t UniqueID, + ElementType ElementTy, uint32_t ElementCount, + uint32_t SampleCount, bool GloballyCoherent) { + ResourceInfo RI(ResourceClass::UAV, ResourceKind::Texture2DMSArray, Symbol, + Name, Binding, UniqueID); + RI.Typed.ElementTy = ElementTy; + RI.Typed.ElementCount = ElementCount; + RI.UAVFlags.GloballyCoherent = GloballyCoherent; + RI.UAVFlags.IsROV = false; + RI.UAVFlags.HasCounter = false; + RI.MultiSample.Count = SampleCount; + return RI; +} + +ResourceInfo ResourceInfo::FeedbackTexture2D(Value *Symbol, StringRef Name, + ResourceBinding Binding, + uint32_t UniqueID, + SamplerFeedbackType FeedbackTy) { + ResourceInfo RI(ResourceClass::UAV, ResourceKind::FeedbackTexture2D, Symbol, + Name, Binding, UniqueID); + RI.UAVFlags.GloballyCoherent = false; + RI.UAVFlags.IsROV = false; + RI.UAVFlags.HasCounter = false; + RI.Feedback.Type = FeedbackTy; + return RI; +} + +ResourceInfo +ResourceInfo::FeedbackTexture2DArray(Value *Symbol, StringRef Name, + ResourceBinding Binding, uint32_t UniqueID, + SamplerFeedbackType FeedbackTy) { + ResourceInfo RI(ResourceClass::UAV, ResourceKind::FeedbackTexture2DArray, + Symbol, Name, Binding, UniqueID); + RI.UAVFlags.GloballyCoherent = false; + RI.UAVFlags.IsROV = false; + RI.UAVFlags.HasCounter = false; + RI.Feedback.Type = FeedbackTy; + return RI; +} + +ResourceInfo ResourceInfo::CBuffer(Value *Symbol, StringRef Name, + ResourceBinding Binding, uint32_t UniqueID, + uint32_t Size) { + ResourceInfo RI(ResourceClass::CBuffer, ResourceKind::CBuffer, Symbol, Name, + Binding, UniqueID); + RI.CBufferSize = Size; + return RI; +} + +ResourceInfo ResourceInfo::Sampler(Value *Symbol, StringRef Name, + ResourceBinding Binding, uint32_t UniqueID, + SamplerType SamplerTy) { + ResourceInfo RI(ResourceClass::Sampler, ResourceKind::Sampler, Symbol, Name, + Binding, UniqueID); + RI.SamplerTy = SamplerTy; + return RI; +} + +bool ResourceInfo::operator==(const ResourceInfo &RHS) const { + if (std::tie(Symbol, Name, Binding, UniqueID, RC, Kind) != + std::tie(RHS.Symbol, RHS.Name, RHS.Binding, RHS.UniqueID, RHS.RC, + RHS.Kind)) + return false; + if (isCBuffer()) + return CBufferSize == RHS.CBufferSize; + if (isSampler()) + return SamplerTy == RHS.SamplerTy; + if (isUAV() && UAVFlags != RHS.UAVFlags) + return false; + + if (isStruct()) + return Struct == RHS.Struct; + if (isFeedback()) + return Feedback == RHS.Feedback; + if (isTyped() && Typed != RHS.Typed) + return false; + + if (isMultiSample()) + return MultiSample == RHS.MultiSample; + + assert((Kind == ResourceKind::RawBuffer) && "Unhandled resource kind"); + return true; +} + +MDTuple *ResourceInfo::getAsMetadata(LLVMContext &Ctx) const { + SmallVector<Metadata *, 11> MDVals; + + Type *I32Ty = Type::getInt32Ty(Ctx); + Type *I1Ty = Type::getInt1Ty(Ctx); + auto getIntMD = [&I32Ty](uint32_t V) { + return ConstantAsMetadata::get( + Constant::getIntegerValue(I32Ty, APInt(32, V))); + }; + auto getBoolMD = [&I1Ty](uint32_t V) { + return ConstantAsMetadata::get( + Constant::getIntegerValue(I1Ty, APInt(1, V))); + }; + + MDVals.push_back(getIntMD(UniqueID)); + MDVals.push_back(ValueAsMetadata::get(Symbol)); + MDVals.push_back(MDString::get(Ctx, Name)); + MDVals.push_back(getIntMD(Binding.Space)); + MDVals.push_back(getIntMD(Binding.LowerBound)); + MDVals.push_back(getIntMD(Binding.Size)); + + if (isCBuffer()) { + MDVals.push_back(getIntMD(CBufferSize)); + MDVals.push_back(nullptr); + } else if 
(isSampler()) { + MDVals.push_back(getIntMD(llvm::to_underlying(SamplerTy))); + MDVals.push_back(nullptr); + } else { + MDVals.push_back(getIntMD(llvm::to_underlying(Kind))); + + if (isUAV()) { + MDVals.push_back(getBoolMD(UAVFlags.GloballyCoherent)); + MDVals.push_back(getBoolMD(UAVFlags.HasCounter)); + MDVals.push_back(getBoolMD(UAVFlags.IsROV)); + } else { + // All SRVs include sample count in the metadata, but it's only meaningful + // for multi-sampled textured. Also, UAVs can be multisampled in SM6.7+, + // but this just isn't reflected in the metadata at all. + uint32_t SampleCount = isMultiSample() ? MultiSample.Count : 0; + MDVals.push_back(getIntMD(SampleCount)); + } + + // Further properties are attached to a metadata list of tag-value pairs. + SmallVector<Metadata *> Tags; + if (isStruct()) { + Tags.push_back( + getIntMD(llvm::to_underlying(ExtPropTags::StructuredBufferStride))); + Tags.push_back(getIntMD(Struct.Stride)); + } else if (isTyped()) { + Tags.push_back(getIntMD(llvm::to_underlying(ExtPropTags::ElementType))); + Tags.push_back(getIntMD(llvm::to_underlying(Typed.ElementTy))); + } else if (isFeedback()) { + Tags.push_back( + getIntMD(llvm::to_underlying(ExtPropTags::SamplerFeedbackKind))); + Tags.push_back(getIntMD(llvm::to_underlying(Feedback.Type))); + } + MDVals.push_back(Tags.empty() ? nullptr : MDNode::get(Ctx, Tags)); + } + + return MDNode::get(Ctx, MDVals); +} + +std::pair<uint32_t, uint32_t> ResourceInfo::getAnnotateProps() const { + uint32_t ResourceKind = llvm::to_underlying(Kind); + uint32_t AlignLog2 = isStruct() ? Log2(Struct.Alignment) : 0; + bool IsUAV = isUAV(); + bool IsROV = IsUAV && UAVFlags.IsROV; + bool IsGloballyCoherent = IsUAV && UAVFlags.GloballyCoherent; + uint8_t SamplerCmpOrHasCounter = 0; + if (IsUAV) + SamplerCmpOrHasCounter = UAVFlags.HasCounter; + else if (isSampler()) + SamplerCmpOrHasCounter = SamplerTy == SamplerType::Comparison; + + // TODO: Document this format. Currently the only reference is the + // implementation of dxc's DxilResourceProperties struct. + uint32_t Word0 = 0; + Word0 |= ResourceKind & 0xFF; + Word0 |= (AlignLog2 & 0xF) << 8; + Word0 |= (IsUAV & 1) << 12; + Word0 |= (IsROV & 1) << 13; + Word0 |= (IsGloballyCoherent & 1) << 14; + Word0 |= (SamplerCmpOrHasCounter & 1) << 15; + + uint32_t Word1 = 0; + if (isStruct()) + Word1 = Struct.Stride; + else if (isCBuffer()) + Word1 = CBufferSize; + else if (isFeedback()) + Word1 = llvm::to_underlying(Feedback.Type); + else if (isTyped()) { + uint32_t CompType = llvm::to_underlying(Typed.ElementTy); + uint32_t CompCount = Typed.ElementCount; + uint32_t SampleCount = isMultiSample() ? MultiSample.Count : 0; + + Word1 |= (CompType & 0xFF) << 0; + Word1 |= (CompCount & 0xFF) << 8; + Word1 |= (SampleCount & 0xFF) << 16; + } + + return {Word0, Word1}; +} + +#define DEBUG_TYPE "dxil-resource" diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/Debugify.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/Debugify.cpp index d0cc603426d2..fcc82eadac36 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/Debugify.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/Debugify.cpp @@ -338,20 +338,23 @@ bool llvm::collectDebugInfoMetadata(Module &M, // Cllect dbg.values and dbg.declare. if (DebugifyLevel > Level::Locations) { - if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I)) { + auto HandleDbgVariable = [&](auto *DbgVar) { if (!SP) - continue; + return; // Skip inlined variables. 
- if (I.getDebugLoc().getInlinedAt()) - continue; + if (DbgVar->getDebugLoc().getInlinedAt()) + return; // Skip undef values. - if (DVI->isKillLocation()) - continue; + if (DbgVar->isKillLocation()) + return; - auto *Var = DVI->getVariable(); + auto *Var = DbgVar->getVariable(); DebugInfoBeforePass.DIVariables[Var]++; - continue; - } + }; + for (DbgVariableRecord &DVR : filterDbgVars(I.getDbgRecordRange())) + HandleDbgVariable(&DVR); + if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I)) + HandleDbgVariable(DVI); } // Skip debug instructions other than dbg.value and dbg.declare. @@ -581,20 +584,23 @@ bool llvm::checkDebugInfoMetadata(Module &M, // Collect dbg.values and dbg.declares. if (DebugifyLevel > Level::Locations) { - if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I)) { + auto HandleDbgVariable = [&](auto *DbgVar) { if (!SP) - continue; + return; // Skip inlined variables. - if (I.getDebugLoc().getInlinedAt()) - continue; + if (DbgVar->getDebugLoc().getInlinedAt()) + return; // Skip undef values. - if (DVI->isKillLocation()) - continue; + if (DbgVar->isKillLocation()) + return; - auto *Var = DVI->getVariable(); + auto *Var = DbgVar->getVariable(); DebugInfoAfterPass.DIVariables[Var]++; - continue; - } + }; + for (DbgVariableRecord &DVR : filterDbgVars(I.getDbgRecordRange())) + HandleDbgVariable(&DVR); + if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I)) + HandleDbgVariable(DVI); } // Skip debug instructions other than dbg.value and dbg.declare. @@ -662,8 +668,9 @@ bool llvm::checkDebugInfoMetadata(Module &M, } namespace { -/// Return true if a mis-sized diagnostic is issued for \p DVI. -bool diagnoseMisSizedDbgValue(Module &M, DbgValueInst *DVI) { +/// Return true if a mis-sized diagnostic is issued for \p DbgVal. +template <typename DbgValTy> +bool diagnoseMisSizedDbgValue(Module &M, DbgValTy *DbgVal) { // The size of a dbg.value's value operand should match the size of the // variable it corresponds to. // @@ -672,22 +679,22 @@ bool diagnoseMisSizedDbgValue(Module &M, DbgValueInst *DVI) { // For now, don't try to interpret anything more complicated than an empty // DIExpression. Eventually we should try to handle OP_deref and fragments. - if (DVI->getExpression()->getNumElements()) + if (DbgVal->getExpression()->getNumElements()) return false; - Value *V = DVI->getVariableLocationOp(0); + Value *V = DbgVal->getVariableLocationOp(0); if (!V) return false; Type *Ty = V->getType(); uint64_t ValueOperandSize = getAllocSizeInBits(M, Ty); - std::optional<uint64_t> DbgVarSize = DVI->getFragmentSizeInBits(); + std::optional<uint64_t> DbgVarSize = DbgVal->getFragmentSizeInBits(); if (!ValueOperandSize || !DbgVarSize) return false; bool HasBadSize = false; if (Ty->isIntegerTy()) { - auto Signedness = DVI->getVariable()->getSignedness(); + auto Signedness = DbgVal->getVariable()->getSignedness(); if (Signedness && *Signedness == DIBasicType::Signedness::Signed) HasBadSize = ValueOperandSize < *DbgVarSize; } else { @@ -697,7 +704,7 @@ bool diagnoseMisSizedDbgValue(Module &M, DbgValueInst *DVI) { if (HasBadSize) { dbg() << "ERROR: dbg.value operand has size " << ValueOperandSize << ", but its variable has size " << *DbgVarSize << ": "; - DVI->print(dbg()); + DbgVal->print(dbg()); dbg() << "\n"; } return HasBadSize; @@ -755,18 +762,23 @@ bool checkDebugifyMetadata(Module &M, } // Find missing variables and mis-sized debug values. 
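  // As a concrete (hypothetical) illustration of what "mis-sized" means here:
  // a signed 32-bit synthetic variable that ends up described by an i16 value
  // operand, e.g.
  //   #dbg_value(i16 %t, !DILocalVariable(name: "2", ...), !DIExpression(), ...)
  // is flagged, because the 16-bit operand cannot cover all 32 bits of the
  // variable; an i32 or wider operand for the same variable is accepted.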
- for (Instruction &I : instructions(F)) { - auto *DVI = dyn_cast<DbgValueInst>(&I); - if (!DVI) - continue; - + auto CheckForMisSized = [&](auto *DbgVal) { unsigned Var = ~0U; - (void)to_integer(DVI->getVariable()->getName(), Var, 10); + (void)to_integer(DbgVal->getVariable()->getName(), Var, 10); assert(Var <= OriginalNumVars && "Unexpected name for DILocalVariable"); - bool HasBadSize = diagnoseMisSizedDbgValue(M, DVI); + bool HasBadSize = diagnoseMisSizedDbgValue(M, DbgVal); if (!HasBadSize) MissingVars.reset(Var - 1); HasErrors |= HasBadSize; + }; + for (Instruction &I : instructions(F)) { + for (DbgVariableRecord &DVR : filterDbgVars(I.getDbgRecordRange())) + if (DVR.isDbgValue() || DVR.isDbgAssign()) + CheckForMisSized(&DVR); + auto *DVI = dyn_cast<DbgValueInst>(&I); + if (!DVI) + continue; + CheckForMisSized(DVI); } } @@ -791,24 +803,19 @@ bool checkDebugifyMetadata(Module &M, dbg() << ": " << (HasErrors ? "FAIL" : "PASS") << '\n'; // Strip debugify metadata if required. + bool Ret = false; if (Strip) - return stripDebugifyMetadata(M); + Ret = stripDebugifyMetadata(M); - return false; + return Ret; } /// ModulePass for attaching synthetic debug info to everything, used with the /// legacy module pass manager. struct DebugifyModulePass : public ModulePass { bool runOnModule(Module &M) override { - bool NewDebugMode = M.IsNewDbgInfoFormat; - if (NewDebugMode) - M.convertFromNewDbgValues(); - - bool Result = applyDebugify(M, Mode, DebugInfoBeforePass, NameOfWrappedPass); - - if (NewDebugMode) - M.convertToNewDbgValues(); + bool Result = + applyDebugify(M, Mode, DebugInfoBeforePass, NameOfWrappedPass); return Result; } @@ -834,14 +841,8 @@ private: /// single function, used with the legacy module pass manager. struct DebugifyFunctionPass : public FunctionPass { bool runOnFunction(Function &F) override { - bool NewDebugMode = F.IsNewDbgInfoFormat; - if (NewDebugMode) - F.convertFromNewDbgValues(); - - bool Result = applyDebugify(F, Mode, DebugInfoBeforePass, NameOfWrappedPass); - - if (NewDebugMode) - F.convertToNewDbgValues(); + bool Result = + applyDebugify(F, Mode, DebugInfoBeforePass, NameOfWrappedPass); return Result; } @@ -868,10 +869,6 @@ private: /// legacy module pass manager. struct CheckDebugifyModulePass : public ModulePass { bool runOnModule(Module &M) override { - bool NewDebugMode = M.IsNewDbgInfoFormat; - if (NewDebugMode) - M.convertFromNewDbgValues(); - bool Result; if (Mode == DebugifyMode::SyntheticDebugInfo) Result = checkDebugifyMetadata(M, M.functions(), NameOfWrappedPass, @@ -882,9 +879,6 @@ struct CheckDebugifyModulePass : public ModulePass { "CheckModuleDebugify (original debuginfo)", NameOfWrappedPass, OrigDIVerifyBugsReportFilePath); - if (NewDebugMode) - M.convertToNewDbgValues(); - return Result; } @@ -918,10 +912,6 @@ private: /// with the legacy module pass manager. 
struct CheckDebugifyFunctionPass : public FunctionPass { bool runOnFunction(Function &F) override { - bool NewDebugMode = F.IsNewDbgInfoFormat; - if (NewDebugMode) - F.convertFromNewDbgValues(); - Module &M = *F.getParent(); auto FuncIt = F.getIterator(); bool Result; @@ -935,8 +925,6 @@ struct CheckDebugifyFunctionPass : public FunctionPass { "CheckFunctionDebugify (original debuginfo)", NameOfWrappedPass, OrigDIVerifyBugsReportFilePath); - if (NewDebugMode) - F.convertToNewDbgValues(); return Result; } @@ -1009,10 +997,6 @@ createDebugifyFunctionPass(enum DebugifyMode Mode, } PreservedAnalyses NewPMDebugifyPass::run(Module &M, ModuleAnalysisManager &) { - bool NewDebugMode = M.IsNewDbgInfoFormat; - if (NewDebugMode) - M.convertFromNewDbgValues(); - if (Mode == DebugifyMode::SyntheticDebugInfo) applyDebugifyMetadata(M, M.functions(), "ModuleDebugify: ", /*ApplyToMF*/ nullptr); @@ -1021,9 +1005,6 @@ PreservedAnalyses NewPMDebugifyPass::run(Module &M, ModuleAnalysisManager &) { "ModuleDebugify (original debuginfo)", NameOfWrappedPass); - if (NewDebugMode) - M.convertToNewDbgValues(); - PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); return PA; @@ -1055,10 +1036,6 @@ FunctionPass *createCheckDebugifyFunctionPass( PreservedAnalyses NewPMCheckDebugifyPass::run(Module &M, ModuleAnalysisManager &) { - bool NewDebugMode = M.IsNewDbgInfoFormat; - if (NewDebugMode) - M.convertFromNewDbgValues(); - if (Mode == DebugifyMode::SyntheticDebugInfo) checkDebugifyMetadata(M, M.functions(), NameOfWrappedPass, "CheckModuleDebugify", Strip, StatsMap); @@ -1068,9 +1045,6 @@ PreservedAnalyses NewPMCheckDebugifyPass::run(Module &M, "CheckModuleDebugify (original debuginfo)", NameOfWrappedPass, OrigDIVerifyBugsReportFilePath); - if (NewDebugMode) - M.convertToNewDbgValues(); - return PreservedAnalyses::all(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp index c894afee68a2..3a33b591d355 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/DemoteRegToStack.cpp @@ -8,6 +8,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/Analysis/CFG.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -20,23 +21,23 @@ using namespace llvm; /// invalidating the SSA information for the value. It returns the pointer to /// the alloca inserted to create a stack slot for I. AllocaInst *llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads, - Instruction *AllocaPoint) { + std::optional<BasicBlock::iterator> AllocaPoint) { if (I.use_empty()) { I.eraseFromParent(); return nullptr; } Function *F = I.getParent()->getParent(); - const DataLayout &DL = F->getParent()->getDataLayout(); + const DataLayout &DL = F->getDataLayout(); // Create a stack slot to hold the value. 
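  // Rough sketch of the effect (illustrative IR, hypothetical names): for an
  // ordinary instruction, demoting
  //   %x = add i32 %a, %b            ; defined in one block
  //   %y = mul i32 %x, 2             ; used in another block
  // yields approximately
  //   %x.reg2mem = alloca i32        ; in the entry block (or at AllocaPoint)
  //   %x = add i32 %a, %b
  //   store i32 %x, ptr %x.reg2mem   ; right after the definition
  //   %x.reload = load i32, ptr %x.reg2mem
  //   %y = mul i32 %x.reload, 2      ; each use reads from the stack slot
  // so no SSA value of %x stays live across blocks.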
AllocaInst *Slot; if (AllocaPoint) { Slot = new AllocaInst(I.getType(), DL.getAllocaAddrSpace(), nullptr, - I.getName()+".reg2mem", AllocaPoint); + I.getName()+".reg2mem", *AllocaPoint); } else { Slot = new AllocaInst(I.getType(), DL.getAllocaAddrSpace(), nullptr, - I.getName() + ".reg2mem", &F->getEntryBlock().front()); + I.getName() + ".reg2mem", F->getEntryBlock().begin()); } // We cannot demote invoke instructions to the stack if their normal edge @@ -50,6 +51,15 @@ AllocaInst *llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads, assert(BB && "Unable to split critical edge."); (void)BB; } + } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(&I)) { + for (unsigned i = 0; i < CBI->getNumSuccessors(); i++) { + auto *Succ = CBI->getSuccessor(i); + if (!Succ->getSinglePredecessor()) { + assert(isCriticalEdge(II, i) && "Expected a critical edge!"); + [[maybe_unused]] BasicBlock *BB = SplitCriticalEdge(II, i); + assert(BB && "Unable to split critical edge."); + } + } } // Change all of the users of the instruction to read from the stack slot. @@ -73,7 +83,7 @@ AllocaInst *llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads, // Insert the load into the predecessor block V = new LoadInst(I.getType(), Slot, I.getName() + ".reload", VolatileLoads, - PN->getIncomingBlock(i)->getTerminator()); + PN->getIncomingBlock(i)->getTerminator()->getIterator()); Loads[PN->getIncomingBlock(i)] = V; } PN->setIncomingValue(i, V); @@ -82,7 +92,7 @@ AllocaInst *llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads, } else { // If this is a normal instruction, just insert a load. Value *V = new LoadInst(I.getType(), Slot, I.getName() + ".reload", - VolatileLoads, U); + VolatileLoads, U->getIterator()); U->replaceUsesOfWith(&I, V); } } @@ -99,39 +109,44 @@ AllocaInst *llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads, break; if (isa<CatchSwitchInst>(InsertPt)) { for (BasicBlock *Handler : successors(&*InsertPt)) - new StoreInst(&I, Slot, &*Handler->getFirstInsertionPt()); + new StoreInst(&I, Slot, Handler->getFirstInsertionPt()); return Slot; } + } else if (InvokeInst *II = dyn_cast<InvokeInst>(&I)) { + InsertPt = II->getNormalDest()->getFirstInsertionPt(); + } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(&I)) { + for (BasicBlock *Succ : successors(CBI)) + new StoreInst(CBI, Slot, Succ->getFirstInsertionPt()); + return Slot; } else { - InvokeInst &II = cast<InvokeInst>(I); - InsertPt = II.getNormalDest()->getFirstInsertionPt(); + llvm_unreachable("Unsupported terminator for Reg2Mem"); } - new StoreInst(&I, Slot, &*InsertPt); + new StoreInst(&I, Slot, InsertPt); return Slot; } /// DemotePHIToStack - This function takes a virtual register computed by a PHI /// node and replaces it with a slot in the stack frame allocated via alloca. /// The PHI node is deleted. It returns the pointer to the alloca inserted. -AllocaInst *llvm::DemotePHIToStack(PHINode *P, Instruction *AllocaPoint) { +AllocaInst *llvm::DemotePHIToStack(PHINode *P, std::optional<BasicBlock::iterator> AllocaPoint) { if (P->use_empty()) { P->eraseFromParent(); return nullptr; } - const DataLayout &DL = P->getModule()->getDataLayout(); + const DataLayout &DL = P->getDataLayout(); // Create a stack slot to hold the value. 
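  // For example (illustrative IR): demoting
  //   %p = phi i32 [ %a, %left ], [ %b, %right ]
  // places "store i32 %a, ptr %p.reg2mem" before %left's terminator and
  // "store i32 %b, ptr %p.reg2mem" before %right's terminator, then replaces
  // the phi itself with
  //   %p.reload = load i32, ptr %p.reg2mem
  // at the merge block's first insertion point.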
AllocaInst *Slot; if (AllocaPoint) { Slot = new AllocaInst(P->getType(), DL.getAllocaAddrSpace(), nullptr, - P->getName()+".reg2mem", AllocaPoint); + P->getName()+".reg2mem", *AllocaPoint); } else { Function *F = P->getParent()->getParent(); Slot = new AllocaInst(P->getType(), DL.getAllocaAddrSpace(), nullptr, P->getName() + ".reg2mem", - &F->getEntryBlock().front()); + F->getEntryBlock().begin()); } // Iterate over each operand inserting a store in each predecessor. @@ -141,7 +156,7 @@ AllocaInst *llvm::DemotePHIToStack(PHINode *P, Instruction *AllocaPoint) { "Invoke edge not supported yet"); (void)II; } new StoreInst(P->getIncomingValue(i), Slot, - P->getIncomingBlock(i)->getTerminator()); + P->getIncomingBlock(i)->getTerminator()->getIterator()); } // Insert a load in place of the PHI and replace all uses. @@ -159,12 +174,12 @@ AllocaInst *llvm::DemotePHIToStack(PHINode *P, Instruction *AllocaPoint) { } for (Instruction *User : Users) { Value *V = - new LoadInst(P->getType(), Slot, P->getName() + ".reload", User); + new LoadInst(P->getType(), Slot, P->getName() + ".reload", User->getIterator()); User->replaceUsesOfWith(P, V); } } else { Value *V = - new LoadInst(P->getType(), Slot, P->getName() + ".reload", &*InsertPt); + new LoadInst(P->getType(), Slot, P->getName() + ".reload", InsertPt); P->replaceAllUsesWith(V); } // Delete PHI. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp index 092f1799755d..d12c540f9a4d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/EntryExitInstrumenter.cpp @@ -15,12 +15,15 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" +#include "llvm/InitializePasses.h" #include "llvm/TargetParser/Triple.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Utils.h" using namespace llvm; static void insertCall(Function &CurFn, StringRef Func, - Instruction *InsertionPt, DebugLoc DL) { + BasicBlock::iterator InsertionPt, DebugLoc DL) { Module &M = *InsertionPt->getParent()->getParent()->getParent(); LLVMContext &C = InsertionPt->getParent()->getContext(); @@ -105,7 +108,7 @@ static bool runOnFunction(Function &F, bool PostInlining) { if (auto SP = F.getSubprogram()) DL = DILocation::get(SP->getContext(), SP->getScopeLine(), 0, SP); - insertCall(F, EntryFunc, &*F.begin()->getFirstInsertionPt(), DL); + insertCall(F, EntryFunc, F.begin()->getFirstInsertionPt(), DL); Changed = true; F.removeFnAttr(EntryAttr); } @@ -126,7 +129,7 @@ static bool runOnFunction(Function &F, bool PostInlining) { else if (auto SP = F.getSubprogram()) DL = DILocation::get(SP->getContext(), 0, 0, SP); - insertCall(F, ExitFunc, T, DL); + insertCall(F, ExitFunc, T->getIterator(), DL); Changed = true; } F.removeFnAttr(ExitAttr); @@ -135,9 +138,42 @@ static bool runOnFunction(Function &F, bool PostInlining) { return Changed; } +namespace { +struct PostInlineEntryExitInstrumenter : public FunctionPass { + static char ID; + PostInlineEntryExitInstrumenter() : FunctionPass(ID) { + initializePostInlineEntryExitInstrumenterPass( + *PassRegistry::getPassRegistry()); + } + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + } + bool runOnFunction(Function &F) override { return ::runOnFunction(F, true); } +}; +char PostInlineEntryExitInstrumenter::ID = 0; +} + 
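// The net effect of this instrumentation, sketched on hypothetical IR: a
// function whose entry attribute names "mcount" roughly becomes
//   define void @foo() {
//   entry:
//     call void @mcount()
//     ...
//     ret void
//   }
// and when an exit function such as __cyg_profile_func_exit is requested,
// the matching call (passing the function address and the return address)
// is placed immediately before each return.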
+INITIALIZE_PASS_BEGIN( + PostInlineEntryExitInstrumenter, "post-inline-ee-instrument", + "Instrument function entry/exit with calls to e.g. mcount() " + "(post inlining)", + false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END( + PostInlineEntryExitInstrumenter, "post-inline-ee-instrument", + "Instrument function entry/exit with calls to e.g. mcount() " + "(post inlining)", + false, false) + +FunctionPass *llvm::createPostInlineEntryExitInstrumenterPass() { + return new PostInlineEntryExitInstrumenter(); +} + PreservedAnalyses llvm::EntryExitInstrumenterPass::run(Function &F, FunctionAnalysisManager &AM) { - runOnFunction(F, PostInlining); + if (!runOnFunction(F, PostInlining)) + return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); return PA; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/FlattenCFG.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/FlattenCFG.cpp index c5cb3748a52f..16b4bb1981d8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/FlattenCFG.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/FlattenCFG.cpp @@ -28,7 +28,7 @@ using namespace llvm; -#define DEBUG_TYPE "flattencfg" +#define DEBUG_TYPE "flatten-cfg" namespace { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/FunctionComparator.cpp index 09e19be0d293..47d4e167b1c8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/FunctionComparator.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/FunctionComparator.cpp @@ -143,6 +143,18 @@ int FunctionComparator::cmpAttrs(const AttributeList L, if (int Res = cmpNumbers((uint64_t)TyL, (uint64_t)TyR)) return Res; continue; + } else if (LA.isConstantRangeAttribute() && + RA.isConstantRangeAttribute()) { + if (LA.getKindAsEnum() != RA.getKindAsEnum()) + return cmpNumbers(LA.getKindAsEnum(), RA.getKindAsEnum()); + + const ConstantRange &LCR = LA.getRange(); + const ConstantRange &RCR = RA.getRange(); + if (int Res = cmpAPInts(LCR.getLower(), RCR.getLower())) + return Res; + if (int Res = cmpAPInts(LCR.getUpper(), RCR.getUpper())) + return Res; + continue; } if (LA < RA) return -1; @@ -416,19 +428,27 @@ int FunctionComparator::cmpConstants(const Constant *L, cast<Constant>(RE->getOperand(i)))) return Res; } - if (LE->isCompare()) - if (int Res = cmpNumbers(LE->getPredicate(), RE->getPredicate())) - return Res; if (auto *GEPL = dyn_cast<GEPOperator>(LE)) { auto *GEPR = cast<GEPOperator>(RE); if (int Res = cmpTypes(GEPL->getSourceElementType(), GEPR->getSourceElementType())) return Res; - if (int Res = cmpNumbers(GEPL->isInBounds(), GEPR->isInBounds())) - return Res; - if (int Res = cmpNumbers(GEPL->getInRangeIndex().value_or(unsigned(-1)), - GEPR->getInRangeIndex().value_or(unsigned(-1)))) + if (int Res = cmpNumbers(GEPL->getNoWrapFlags().getRaw(), + GEPR->getNoWrapFlags().getRaw())) return Res; + + std::optional<ConstantRange> InRangeL = GEPL->getInRange(); + std::optional<ConstantRange> InRangeR = GEPR->getInRange(); + if (InRangeL) { + if (!InRangeR) + return 1; + if (int Res = cmpAPInts(InRangeL->getLower(), InRangeR->getLower())) + return Res; + if (int Res = cmpAPInts(InRangeL->getUpper(), InRangeR->getUpper())) + return Res; + } else if (InRangeR) { + return -1; + } } if (auto *OBOL = dyn_cast<OverflowingBinaryOperator>(LE)) { auto *OBOR = cast<OverflowingBinaryOperator>(RE); @@ -504,7 +524,7 @@ int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const { PointerType *PTyL = 
dyn_cast<PointerType>(TyL); PointerType *PTyR = dyn_cast<PointerType>(TyR); - const DataLayout &DL = FnL->getParent()->getDataLayout(); + const DataLayout &DL = FnL->getDataLayout(); if (PTyL && PTyL->getAddressSpace() == 0) TyL = DL.getIntPtrType(TyL); if (PTyR && PTyR->getAddressSpace() == 0) @@ -785,7 +805,7 @@ int FunctionComparator::cmpGEPs(const GEPOperator *GEPL, // When we have target data, we can reduce the GEP down to the value in bytes // added to the address. - const DataLayout &DL = FnL->getParent()->getDataLayout(); + const DataLayout &DL = FnL->getDataLayout(); unsigned OffsetBitWidth = DL.getIndexSizeInBits(ASL); APInt OffsetL(OffsetBitWidth, 0), OffsetR(OffsetBitWidth, 0); if (GEPL->accumulateConstantOffset(DL, OffsetL) && diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/GlobalStatus.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/GlobalStatus.cpp index c5aded3c45f4..b177e048faae 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/GlobalStatus.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/GlobalStatus.cpp @@ -172,9 +172,14 @@ static bool analyzeGlobalAux(const Value *V, GlobalStatus &GS, return true; GS.StoredType = GlobalStatus::Stored; } else if (const auto *CB = dyn_cast<CallBase>(I)) { - if (!CB->isCallee(&U)) - return true; - GS.IsLoaded = true; + if (CB->getIntrinsicID() == Intrinsic::threadlocal_address) { + if (analyzeGlobalAux(I, GS, VisitedUsers)) + return true; + } else { + if (!CB->isCallee(&U)) + return true; + GS.IsLoaded = true; + } } else { return true; // Any other non-load instruction might take address! } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/HelloWorld.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/HelloWorld.cpp index 7019e9e4451b..1098281fa0e1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/HelloWorld.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/HelloWorld.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/HelloWorld.h" +#include "llvm/IR/Function.h" using namespace llvm; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/InlineFunction.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/InlineFunction.cpp index d4d4bf5ebdf3..fda1c22cc1fb 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -23,6 +23,7 @@ #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/IndirectCallVisitor.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryProfileInfo.h" #include "llvm/Analysis/ObjCARCAnalysisUtils.h" @@ -30,11 +31,12 @@ #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" -#include "llvm/IR/AttributeMask.h" #include "llvm/IR/Argument.h" +#include "llvm/IR/AttributeMask.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfo.h" @@ -55,6 +57,7 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" @@ -689,7 +692,7 @@ static void HandleInlinedEHPad(InvokeInst *II, BasicBlock 
*FirstNewBlock, if (auto *CRI = dyn_cast<CleanupReturnInst>(BB->getTerminator())) { if (CRI->unwindsToCaller()) { auto *CleanupPad = CRI->getCleanupPad(); - CleanupReturnInst::Create(CleanupPad, UnwindDest, CRI); + CleanupReturnInst::Create(CleanupPad, UnwindDest, CRI->getIterator()); CRI->eraseFromParent(); UpdatePHINodes(&*BB); // Finding a cleanupret with an unwind destination would confuse @@ -737,7 +740,7 @@ static void HandleInlinedEHPad(InvokeInst *II, BasicBlock *FirstNewBlock, auto *NewCatchSwitch = CatchSwitchInst::Create( CatchSwitch->getParentPad(), UnwindDest, CatchSwitch->getNumHandlers(), CatchSwitch->getName(), - CatchSwitch); + CatchSwitch->getIterator()); for (BasicBlock *PadBB : CatchSwitch->handlers()) NewCatchSwitch->addHandler(PadBB); // Propagate info for the old catchswitch over to the new one in @@ -972,7 +975,7 @@ static void PropagateOperandBundles(Function::iterator InlinedBB, I->getOperandBundlesAsDefs(OpBundles); OpBundles.emplace_back("funclet", CallSiteEHPad); - Instruction *NewInst = CallBase::Create(I, OpBundles, I); + Instruction *NewInst = CallBase::Create(I, OpBundles, I->getIterator()); NewInst->takeName(I); I->replaceAllUsesWith(NewInst); I->eraseFromParent(); @@ -1220,7 +1223,6 @@ static void AddAliasScopeMetadata(CallBase &CB, ValueToValueMapTy &VMap, SmallPtrSet<const Value *, 4> ObjSet; SmallVector<Metadata *, 4> Scopes, NoAliases; - SmallSetVector<const Argument *, 4> NAPtrArgs; for (const Value *V : PtrArgs) { SmallVector<const Value *, 4> Objects; getUnderlyingObjects(V, Objects, /* LI = */ nullptr); @@ -1344,6 +1346,89 @@ static bool MayContainThrowingOrExitingCallAfterCB(CallBase *Begin, ++BeginIt, End->getIterator(), InlinerAttributeWindow + 1); } +// Add attributes from CB params and Fn attributes that can always be propagated +// to the corresponding argument / inner callbases. +static void AddParamAndFnBasicAttributes(const CallBase &CB, + ValueToValueMapTy &VMap, + ClonedCodeInfo &InlinedFunctionInfo) { + auto *CalledFunction = CB.getCalledFunction(); + auto &Context = CalledFunction->getContext(); + + // Collect valid attributes for all params. + SmallVector<AttrBuilder> ValidParamAttrs; + bool HasAttrToPropagate = false; + + for (unsigned I = 0, E = CB.arg_size(); I < E; ++I) { + ValidParamAttrs.emplace_back(AttrBuilder{CB.getContext()}); + // Access attributes can be propagated to any param with the same underlying + // object as the argument. + if (CB.paramHasAttr(I, Attribute::ReadNone)) + ValidParamAttrs.back().addAttribute(Attribute::ReadNone); + if (CB.paramHasAttr(I, Attribute::ReadOnly)) + ValidParamAttrs.back().addAttribute(Attribute::ReadOnly); + HasAttrToPropagate |= ValidParamAttrs.back().hasAttributes(); + } + + // Won't be able to propagate anything. + if (!HasAttrToPropagate) + return; + + for (BasicBlock &BB : *CalledFunction) { + for (Instruction &Ins : BB) { + const auto *InnerCB = dyn_cast<CallBase>(&Ins); + if (!InnerCB) + continue; + auto *NewInnerCB = dyn_cast_or_null<CallBase>(VMap.lookup(InnerCB)); + if (!NewInnerCB) + continue; + // The InnerCB might have be simplified during the inlining + // process which can make propagation incorrect. + if (InlinedFunctionInfo.isSimplified(InnerCB, NewInnerCB)) + continue; + + AttributeList AL = NewInnerCB->getAttributes(); + for (unsigned I = 0, E = InnerCB->arg_size(); I < E; ++I) { + // Check if the underlying value for the parameter is an argument. 
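        // Illustrative example (hypothetical functions): for a call site
        //   call void @callee(ptr readonly %p)
        // where @callee internally passes its own parameter (or a pointer
        // derived from it) to another call,
        //   call void @inner(ptr %q)   ; %q is based on @callee's parameter
        // the cloned call to @inner may also mark that operand readonly,
        // because the outer call site already promises that no writes happen
        // through %p.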
+ const Value *UnderlyingV = + getUnderlyingObject(InnerCB->getArgOperand(I)); + const Argument *Arg = dyn_cast<Argument>(UnderlyingV); + if (!Arg) + continue; + + if (NewInnerCB->paramHasAttr(I, Attribute::ByVal)) + // It's unsound to propagate memory attributes to byval arguments. + // Even if CalledFunction doesn't e.g. write to the argument, + // the call to NewInnerCB may write to its by-value copy. + continue; + + unsigned ArgNo = Arg->getArgNo(); + // If so, propagate its access attributes. + AL = AL.addParamAttributes(Context, I, ValidParamAttrs[ArgNo]); + // We can have conflicting attributes from the inner callsite and + // to-be-inlined callsite. In that case, choose the most + // restrictive. + + // readonly + writeonly means we can never deref so make readnone. + if (AL.hasParamAttr(I, Attribute::ReadOnly) && + AL.hasParamAttr(I, Attribute::WriteOnly)) + AL = AL.addParamAttribute(Context, I, Attribute::ReadNone); + + // If have readnone, need to clear readonly/writeonly + if (AL.hasParamAttr(I, Attribute::ReadNone)) { + AL = AL.removeParamAttribute(Context, I, Attribute::ReadOnly); + AL = AL.removeParamAttribute(Context, I, Attribute::WriteOnly); + } + + // Writable cannot exist in conjunction w/ readonly/readnone + if (AL.hasParamAttr(I, Attribute::ReadOnly) || + AL.hasParamAttr(I, Attribute::ReadNone)) + AL = AL.removeParamAttribute(Context, I, Attribute::Writable); + } + NewInnerCB->setAttributes(AL); + } + } +} + // Only allow these white listed attributes to be propagated back to the // callee. This is because other attributes may only be valid on the call // itself, i.e. attributes such as signext and zeroext. @@ -1371,10 +1456,13 @@ static AttrBuilder IdentifyValidPoisonGeneratingAttributes(CallBase &CB) { Valid.addAttribute(Attribute::NonNull); if (CB.hasRetAttr(Attribute::Alignment)) Valid.addAlignmentAttr(CB.getRetAlign()); + if (std::optional<ConstantRange> Range = CB.getRange()) + Valid.addRangeAttr(*Range); return Valid; } -static void AddReturnAttributes(CallBase &CB, ValueToValueMapTy &VMap) { +static void AddReturnAttributes(CallBase &CB, ValueToValueMapTy &VMap, + ClonedCodeInfo &InlinedFunctionInfo) { AttrBuilder ValidUB = IdentifyValidUBGeneratingAttributes(CB); AttrBuilder ValidPG = IdentifyValidPoisonGeneratingAttributes(CB); if (!ValidUB.hasAttributes() && !ValidPG.hasAttributes()) @@ -1393,6 +1481,11 @@ static void AddReturnAttributes(CallBase &CB, ValueToValueMapTy &VMap) { auto *NewRetVal = dyn_cast_or_null<CallBase>(VMap.lookup(RetVal)); if (!NewRetVal) continue; + + // The RetVal might have be simplified during the inlining + // process which can make propagation incorrect. + if (InlinedFunctionInfo.isSimplified(RetVal, NewRetVal)) + continue; // Backward propagation of attributes to the returned value may be incorrect // if it is control flow dependent. // Consider: @@ -1462,6 +1555,14 @@ static void AddReturnAttributes(CallBase &CB, ValueToValueMapTy &VMap) { if (ValidPG.getAlignment().valueOrOne() < AL.getRetAlignment().valueOrOne()) ValidPG.removeAttribute(Attribute::Alignment); if (ValidPG.hasAttributes()) { + Attribute CBRange = ValidPG.getAttribute(Attribute::Range); + if (CBRange.isValid()) { + Attribute NewRange = AL.getRetAttr(Attribute::Range); + if (NewRange.isValid()) { + ValidPG.addRangeAttr( + CBRange.getRange().intersectWith(NewRange.getRange())); + } + } // Three checks. // If the callsite has `noundef`, then a poison due to violating the // return attribute will create UB anyways so we can always propagate. 
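    // Hypothetical illustration of the first check: if the original call site
    // is
    //   %r = call noundef nonnull ptr @f()
    // then propagating nonnull onto the inlined call that produces the return
    // value is always sound, since a null (poison) result would already have
    // been immediate UB at the noundef call site. Likewise, when both the
    // call site and the inner call carry range attributes, the intersection
    // computed above is used, e.g. range(i32 0, 100) combined with
    // range(i32 10, 200) becomes range(i32 10, 100).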
@@ -1488,7 +1589,7 @@ static void AddAlignmentAssumptions(CallBase &CB, InlineFunctionInfo &IFI) { return; AssumptionCache *AC = &IFI.GetAssumptionCache(*CB.getCaller()); - auto &DL = CB.getCaller()->getParent()->getDataLayout(); + auto &DL = CB.getDataLayout(); // To avoid inserting redundant assumptions, we should check for assumptions // already in the caller. To do this, we might need a DT of the caller. @@ -1551,7 +1652,7 @@ static Value *HandleByValArgument(Type *ByValType, Value *Arg, InlineFunctionInfo &IFI, MaybeAlign ByValAlignment) { Function *Caller = TheCall->getFunction(); - const DataLayout &DL = Caller->getParent()->getDataLayout(); + const DataLayout &DL = Caller->getDataLayout(); // If the called function is readonly, then it could not mutate the caller's // copy of the byval'd memory. In this case, it is safe to elide the copy and @@ -1585,8 +1686,9 @@ static Value *HandleByValArgument(Type *ByValType, Value *Arg, if (ByValAlignment) Alignment = std::max(Alignment, *ByValAlignment); - AllocaInst *NewAlloca = new AllocaInst(ByValType, DL.getAllocaAddrSpace(), - nullptr, Alignment, Arg->getName()); + AllocaInst *NewAlloca = + new AllocaInst(ByValType, Arg->getType()->getPointerAddressSpace(), + nullptr, Alignment, Arg->getName()); NewAlloca->insertBefore(Caller->begin()->begin()); IFI.StaticAllocas.push_back(NewAlloca); @@ -1710,26 +1812,25 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, }; // Helper-util for updating debug-info records attached to instructions. - auto UpdateDPV = [&](DPValue *DPV) { - assert(DPV->getDebugLoc() && "Debug Value must have debug loc"); + auto UpdateDVR = [&](DbgRecord *DVR) { + assert(DVR->getDebugLoc() && "Debug Value must have debug loc"); if (NoInlineLineTables) { - DPV->setDebugLoc(TheCallDL); + DVR->setDebugLoc(TheCallDL); return; } - DebugLoc DL = DPV->getDebugLoc(); + DebugLoc DL = DVR->getDebugLoc(); DebugLoc IDL = inlineDebugLoc(DL, InlinedAtNode, - DPV->getMarker()->getParent()->getContext(), IANodes); - DPV->setDebugLoc(IDL); + DVR->getMarker()->getParent()->getContext(), IANodes); + DVR->setDebugLoc(IDL); }; // Iterate over all instructions, updating metadata and debug-info records. for (; FI != Fn->end(); ++FI) { - for (BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE; - ++BI) { - UpdateInst(*BI); - for (DPValue &DPV : BI->getDbgValueRange()) { - UpdateDPV(&DPV); + for (Instruction &I : *FI) { + UpdateInst(I); + for (DbgRecord &DVR : I.getDbgRecordRange()) { + UpdateDVR(&DVR); } } @@ -1741,7 +1842,7 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, BI = BI->eraseFromParent(); continue; } else { - BI->dropDbgValues(); + BI->dropDbgRecords(); } ++BI; } @@ -1797,7 +1898,7 @@ static at::StorageToVarsMap collectEscapedLocals(const DataLayout &DL, EscapedLocals[Base].insert(at::VarRecord(DbgAssign)); }; for_each(at::getAssignmentMarkers(Base), CollectAssignsForStorage); - for_each(at::getDPVAssignmentMarkers(Base), CollectAssignsForStorage); + for_each(at::getDVRAssignmentMarkers(Base), CollectAssignsForStorage); } return EscapedLocals; } @@ -1815,29 +1916,12 @@ static void trackInlinedStores(Function::iterator Start, Function::iterator End, /// otherwise a function inlined more than once into the same function /// will cause DIAssignID to be shared by many instructions. static void fixupAssignments(Function::iterator Start, Function::iterator End) { - // Map {Old, New} metadata. Not used directly - use GetNewID. 
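  // Hypothetical scenario: @callee contains a store and its #dbg_assign
  // record that share !42 = distinct !DIAssignID(). If @callee is inlined
  // twice into the same caller, each inlined copy must receive a fresh
  // DIAssignID through the map below; otherwise both copies of the store
  // would look like the same assignment to assignment tracking.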
DenseMap<DIAssignID *, DIAssignID *> Map; - auto GetNewID = [&Map](Metadata *Old) { - DIAssignID *OldID = cast<DIAssignID>(Old); - if (DIAssignID *NewID = Map.lookup(OldID)) - return NewID; - DIAssignID *NewID = DIAssignID::getDistinct(OldID->getContext()); - Map[OldID] = NewID; - return NewID; - }; // Loop over all the inlined instructions. If we find a DIAssignID // attachment or use, replace it with a new version. for (auto BBI = Start; BBI != End; ++BBI) { - for (Instruction &I : *BBI) { - for (DPValue &DPV : I.getDbgValueRange()) { - if (DPV.isDbgAssign()) - DPV.setAssignId(GetNewID(DPV.getAssignID())); - } - if (auto *ID = I.getMetadata(LLVMContext::MD_DIAssignID)) - I.setMetadata(LLVMContext::MD_DIAssignID, GetNewID(ID)); - else if (auto *DAI = dyn_cast<DbgAssignIntrinsic>(&I)) - DAI->setAssignId(GetNewID(DAI->getAssignID())); - } + for (Instruction &I : *BBI) + at::remapAssignID(Map, I); } } #undef DEBUG_TYPE @@ -1906,13 +1990,29 @@ void llvm::updateProfileCallee( ? 0 : PriorEntryCount + EntryDelta; + auto updateVTableProfWeight = [](CallBase *CB, const uint64_t NewEntryCount, + const uint64_t PriorEntryCount) { + Instruction *VPtr = PGOIndirectCallVisitor::tryGetVTableInstruction(CB); + if (VPtr) + scaleProfData(*VPtr, NewEntryCount, PriorEntryCount); + }; + // During inlining ? if (VMap) { uint64_t CloneEntryCount = PriorEntryCount - NewEntryCount; - for (auto Entry : *VMap) + for (auto Entry : *VMap) { if (isa<CallInst>(Entry.first)) - if (auto *CI = dyn_cast_or_null<CallInst>(Entry.second)) + if (auto *CI = dyn_cast_or_null<CallInst>(Entry.second)) { CI->updateProfWeight(CloneEntryCount, PriorEntryCount); + updateVTableProfWeight(CI, CloneEntryCount, PriorEntryCount); + } + + if (isa<InvokeInst>(Entry.first)) + if (auto *II = dyn_cast_or_null<InvokeInst>(Entry.second)) { + II->updateProfWeight(CloneEntryCount, PriorEntryCount); + updateVTableProfWeight(II, CloneEntryCount, PriorEntryCount); + } + } } if (EntryDelta) { @@ -1921,9 +2021,16 @@ void llvm::updateProfileCallee( for (BasicBlock &BB : *Callee) // No need to update the callsite if it is pruned during inlining. if (!VMap || VMap->count(&BB)) - for (Instruction &I : BB) - if (CallInst *CI = dyn_cast<CallInst>(&I)) + for (Instruction &I : BB) { + if (CallInst *CI = dyn_cast<CallInst>(&I)) { CI->updateProfWeight(NewEntryCount, PriorEntryCount); + updateVTableProfWeight(CI, NewEntryCount, PriorEntryCount); + } + if (InvokeInst *II = dyn_cast<InvokeInst>(&I)) { + II->updateProfWeight(NewEntryCount, PriorEntryCount); + updateVTableProfWeight(II, NewEntryCount, PriorEntryCount); + } + } } } @@ -2002,7 +2109,7 @@ inlineRetainOrClaimRVCalls(CallBase &CB, objcarc::ARCInstKind RVCallKind, Value *BundleArgs[] = {*objcarc::getAttachedARCFunction(&CB)}; OperandBundleDef OB("clang.arc.attachedcall", BundleArgs); auto *NewCall = CallBase::addOperandBundle( - CI, LLVMContext::OB_clang_arc_attachedcall, OB, CI); + CI, LLVMContext::OB_clang_arc_attachedcall, OB, CI->getIterator()); NewCall->copyMetadata(*CI); CI->replaceAllUsesWith(NewCall); CI->eraseFromParent(); @@ -2103,13 +2210,6 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, BasicBlock *OrigBB = CB.getParent(); Function *Caller = OrigBB->getParent(); - // Do not inline strictfp function into non-strictfp one. It would require - // conversion of all FP operations in host function to constrained intrinsics. 
- if (CalledFunc->getAttributes().hasFnAttr(Attribute::StrictFP) && - !Caller->getAttributes().hasFnAttr(Attribute::StrictFP)) { - return InlineResult::failure("incompatible strictfp attributes"); - } - // GC poses two hazards to inlining, which only occur when the callee has GC: // 1. If the caller has no GC, then the callee's GC must be propagated to the // caller. @@ -2223,7 +2323,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // callee. ScopedAliasMetadataDeepCloner SAMetadataCloner(CB.getCalledFunction()); - auto &DL = Caller->getParent()->getDataLayout(); + auto &DL = Caller->getDataLayout(); // Calculate the vector of arguments to pass into the function cloner, which // matches up the formal to the actual argument values. @@ -2333,7 +2433,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, OpDefs.emplace_back("deopt", std::move(MergedDeoptArgs)); } - Instruction *NewI = CallBase::Create(ICS, OpDefs, ICS); + Instruction *NewI = CallBase::Create(ICS, OpDefs, ICS->getIterator()); // Note: the RAUW does the appropriate fixup in VMap, so we need to do // this even if the call returns void. @@ -2368,7 +2468,11 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // Clone return attributes on the callsite into the calls within the inlined // function which feed into its return value. - AddReturnAttributes(CB, VMap); + AddReturnAttributes(CB, VMap, InlinedFunctionInfo); + + // Clone attributes on the params of the callsite to calls within the + // inlined function which use the same param. + AddParamAndFnBasicAttributes(CB, VMap, InlinedFunctionInfo); propagateMemProfMetadata(CalledFunc, CB, InlinedFunctionInfo.ContainsMemProfMetadata, VMap); @@ -2486,7 +2590,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, SmallVector<Value *, 6> Params(CI->args()); Params.append(VarArgsToForward.begin(), VarArgsToForward.end()); CallInst *NewCI = CallInst::Create( - CI->getFunctionType(), CI->getCalledOperand(), Params, "", CI); + CI->getFunctionType(), CI->getCalledOperand(), Params, "", CI->getIterator()); NewCI->setDebugLoc(CI->getDebugLoc()); NewCI->setAttributes(Attrs); NewCI->setCallingConv(CI->getCallingConv()); @@ -2538,8 +2642,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, if ((InsertLifetime || Caller->isPresplitCoroutine()) && !IFI.StaticAllocas.empty()) { IRBuilder<> builder(&*FirstNewBlock, FirstNewBlock->begin()); - for (unsigned ai = 0, ae = IFI.StaticAllocas.size(); ai != ae; ++ai) { - AllocaInst *AI = IFI.StaticAllocas[ai]; + for (AllocaInst *AI : IFI.StaticAllocas) { // Don't mark swifterror allocas. They can't have bitcast uses. if (AI->isSwiftError()) continue; @@ -2553,7 +2656,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, ConstantInt *AllocaSize = nullptr; if (ConstantInt *AIArraySize = dyn_cast<ConstantInt>(AI->getArraySize())) { - auto &DL = Caller->getParent()->getDataLayout(); + auto &DL = Caller->getDataLayout(); Type *AllocaType = AI->getAllocatedType(); TypeSize AllocaTypeSize = DL.getTypeAllocSize(AllocaType); uint64_t AllocaArraySize = AIArraySize->getLimitedValue(); @@ -2783,7 +2886,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // If the call site was an invoke instruction, add a branch to the normal // destination. 
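  // Illustrative example (hypothetical IR): when
  //   invoke void @callee() to label %normal unwind label %lpad
  // is inlined and the inlined body turns out not to throw, the inlined
  // instructions are followed by an explicit
  //   br label %normal
  // at the old call position; %lpad keeps its landingpad but may become
  // unreachable and get cleaned up later.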
if (InvokeInst *II = dyn_cast<InvokeInst>(&CB)) { - BranchInst *NewBr = BranchInst::Create(II->getNormalDest(), &CB); + BranchInst *NewBr = BranchInst::Create(II->getNormalDest(), CB.getIterator()); NewBr->setDebugLoc(Returns[0]->getDebugLoc()); } @@ -2820,7 +2923,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, if (InvokeInst *II = dyn_cast<InvokeInst>(&CB)) { // Add an unconditional branch to make this look like the CallInst case... - CreatedBranchToNormalDest = BranchInst::Create(II->getNormalDest(), &CB); + CreatedBranchToNormalDest = BranchInst::Create(II->getNormalDest(), CB.getIterator()); // Split the basic block. This guarantees that no PHI nodes will have to be // updated due to new incoming edges, and make the invoke case more @@ -2876,8 +2979,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // Loop over all of the return instructions adding entries to the PHI node // as appropriate. if (PHI) { - for (unsigned i = 0, e = Returns.size(); i != e; ++i) { - ReturnInst *RI = Returns[i]; + for (ReturnInst *RI : Returns) { assert(RI->getReturnValue()->getType() == PHI->getType() && "Ret value not consistent in function!"); PHI->addIncoming(RI->getReturnValue(), RI->getParent()); @@ -2886,9 +2988,8 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, // Add a branch to the merge points and remove return instructions. DebugLoc Loc; - for (unsigned i = 0, e = Returns.size(); i != e; ++i) { - ReturnInst *RI = Returns[i]; - BranchInst* BI = BranchInst::Create(AfterCallBB, RI); + for (ReturnInst *RI : Returns) { + BranchInst *BI = BranchInst::Create(AfterCallBB, RI->getIterator()); Loc = RI->getDebugLoc(); BI->setDebugLoc(Loc); RI->eraseFromParent(); @@ -2959,7 +3060,7 @@ llvm::InlineResult llvm::InlineFunction(CallBase &CB, InlineFunctionInfo &IFI, if (PHI) { AssumptionCache *AC = IFI.GetAssumptionCache ? &IFI.GetAssumptionCache(*Caller) : nullptr; - auto &DL = Caller->getParent()->getDataLayout(); + auto &DL = Caller->getDataLayout(); if (Value *V = simplifyInstruction(PHI, {DL, nullptr, nullptr, AC})) { PHI->replaceAllUsesWith(V); PHI->eraseFromParent(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LCSSA.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LCSSA.cpp index 5e0c312fe149..ab1edf47d8db 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LCSSA.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LCSSA.cpp @@ -242,8 +242,8 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, } SmallVector<DbgValueInst *, 4> DbgValues; - SmallVector<DPValue *, 4> DPValues; - llvm::findDbgValues(DbgValues, I, &DPValues); + SmallVector<DbgVariableRecord *, 4> DbgVariableRecords; + llvm::findDbgValues(DbgValues, I, &DbgVariableRecords); // Update pre-existing debug value uses that reside outside the loop. for (auto *DVI : DbgValues) { @@ -261,8 +261,8 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, // RemoveDIs: copy-paste of block above, using non-instruction debug-info // records. 
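    // For example (illustrative): if %v is defined inside the loop and an
    // exit block contains #dbg_value(i32 %v, ...), the record is rewritten to
    // refer to the LCSSA phi created for %v in that exit block, or to
    // whatever value SSAUpdater resolves for that block when several phis
    // were inserted.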
- for (DPValue *DPV : DPValues) { - BasicBlock *UserBB = DPV->getMarker()->getParent(); + for (DbgVariableRecord *DVR : DbgVariableRecords) { + BasicBlock *UserBB = DVR->getMarker()->getParent(); if (InstBB == UserBB || L->contains(UserBB)) continue; // We currently only handle debug values residing in blocks that were @@ -271,7 +271,7 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, Value *V = AddedPHIs.size() == 1 ? AddedPHIs[0] : SSAUpdate.FindValueForBlock(UserBB); if (V) - DPV->replaceVariableLocationOp(I, V); + DVR->replaceVariableLocationOp(I, V); } // SSAUpdater might have inserted phi-nodes inside other loops. We'll need diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp index 6220f8509309..9fe655e548c2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp @@ -467,7 +467,7 @@ Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI, void LibCallsShrinkWrap::shrinkWrapCI(CallInst *CI, Value *Cond) { assert(Cond != nullptr && "ShrinkWrapCI is not expecting an empty call inst"); MDNode *BranchWeights = - MDBuilder(CI->getContext()).createBranchWeights(1, 2000); + MDBuilder(CI->getContext()).createUnlikelyBranchWeights(); Instruction *NewInst = SplitBlockAndInsertIfThen(Cond, CI, false, BranchWeights, &DTU); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/Local.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/Local.cpp index a1c6bbc52fd0..f68cbf62b982 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/Local.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/Local.cpp @@ -59,6 +59,7 @@ #include "llvm/IR/IntrinsicsWebAssembly.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/MemoryModelRelaxationAnnotations.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" @@ -230,7 +231,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, // Remove weight for this case. std::swap(Weights[Idx + 1], Weights.back()); Weights.pop_back(); - setBranchWeights(*SI, Weights); + setBranchWeights(*SI, Weights, hasBranchWeightOrigin(MD)); } // Remove this entry. BasicBlock *ParentBB = SI->getParent(); @@ -475,6 +476,12 @@ bool llvm::wouldInstructionBeTriviallyDead(const Instruction *I, II->getIntrinsicID() == Intrinsic::launder_invariant_group) return true; + // Intrinsics declare sideeffects to prevent them from moving, but they are + // nops without users. + if (II->getIntrinsicID() == Intrinsic::allow_runtime_check || + II->getIntrinsicID() == Intrinsic::allow_ubsan_check) + return true; + if (II->isLifetimeStartOrEnd()) { auto *Arg = II->getArgOperand(1); // Lifetime intrinsics are dead when their right-hand is undef. 
@@ -609,12 +616,12 @@ void llvm::RecursivelyDeleteTriviallyDeadInstructions( bool llvm::replaceDbgUsesWithUndef(Instruction *I) { SmallVector<DbgVariableIntrinsic *, 1> DbgUsers; - SmallVector<DPValue *, 1> DPUsers; + SmallVector<DbgVariableRecord *, 1> DPUsers; findDbgUsers(DbgUsers, I, &DPUsers); for (auto *DII : DbgUsers) DII->setKillLocation(); - for (auto *DPV : DPUsers) - DPV->setKillLocation(); + for (auto *DVR : DPUsers) + DVR->setKillLocation(); return !DbgUsers.empty() || !DPUsers.empty(); } @@ -724,7 +731,7 @@ simplifyAndDCEInstruction(Instruction *I, bool llvm::SimplifyInstructionsInBlock(BasicBlock *BB, const TargetLibraryInfo *TLI) { bool MadeChange = false; - const DataLayout &DL = BB->getModule()->getDataLayout(); + const DataLayout &DL = BB->getDataLayout(); #ifndef NDEBUG // In debug builds, ensure that the terminator of the block is never replaced @@ -1021,7 +1028,13 @@ CanRedirectPredsOfEmptyBBToSucc(BasicBlock *BB, BasicBlock *Succ, if (!BB->hasNPredecessorsOrMore(2)) return false; - // Get single common predecessors of both BB and Succ + if (any_of(BBPreds, [](const BasicBlock *Pred) { + return isa<IndirectBrInst>(Pred->getTerminator()); + })) + return false; + + // Get the single common predecessor of both BB and Succ. Return false + // when there are more than one common predecessors. for (BasicBlock *SuccPred : SuccPreds) { if (BBPreds.count(SuccPred)) { if (CommonPred) @@ -1088,11 +1101,9 @@ static void redirectValuesFromPredecessorsToPhi(BasicBlock *BB, PN->addIncoming(OldValPN->getIncomingValueForBlock(CommonPred), BB); } else { - for (unsigned i = 0, e = BBPreds.size(); i != e; ++i) { + for (BasicBlock *PredBB : BBPreds) { // Update existing incoming values in PN for this // predecessor of BB. - BasicBlock *PredBB = BBPreds[i]; - if (PredBB == CommonPred) continue; @@ -1128,7 +1139,7 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, bool BBKillable = CanPropagatePredecessorsForPHIs(BB, Succ, BBPreds); - // Even if we can not fold bB into Succ, we may be able to redirect the + // Even if we can not fold BB into Succ, we may be able to redirect the // predecessors of BB to Succ. bool BBPhisMergeable = BBKillable || @@ -1570,16 +1581,16 @@ static bool PhiHasDebugValue(DILocalVariable *DIVar, // is removed by LowerDbgDeclare(), we need to make sure that we are // not inserting the same dbg.value intrinsic over and over. SmallVector<DbgValueInst *, 1> DbgValues; - SmallVector<DPValue *, 1> DPValues; - findDbgValues(DbgValues, APN, &DPValues); + SmallVector<DbgVariableRecord *, 1> DbgVariableRecords; + findDbgValues(DbgValues, APN, &DbgVariableRecords); for (auto *DVI : DbgValues) { assert(is_contained(DVI->getValues(), APN)); if ((DVI->getVariable() == DIVar) && (DVI->getExpression() == DIExpr)) return true; } - for (auto *DPV : DPValues) { - assert(is_contained(DPV->location_ops(), APN)); - if ((DPV->getVariable() == DIVar) && (DPV->getExpression() == DIExpr)) + for (auto *DVR : DbgVariableRecords) { + assert(is_contained(DVR->location_ops(), APN)); + if ((DVR->getVariable() == DIVar) && (DVR->getExpression() == DIExpr)) return true; } return false; @@ -1594,9 +1605,10 @@ static bool PhiHasDebugValue(DILocalVariable *DIVar, /// value when doing the comparison. E.g. an i1 value will be identified as /// covering an n-bit fragment, if the store size of i1 is at least n bits. 
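/// When the variable's size cannot be computed directly (e.g. a VLA), the
/// check can fall back to the allocation size of the alloca named by the
/// dbg.declare: against "alloca i32", an i64 value still covers the variable,
/// while an i16 value does not.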
static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) { - const DataLayout &DL = DII->getModule()->getDataLayout(); + const DataLayout &DL = DII->getDataLayout(); TypeSize ValueSize = DL.getTypeAllocSizeInBits(ValTy); - if (std::optional<uint64_t> FragmentSize = DII->getFragmentSizeInBits()) + if (std::optional<uint64_t> FragmentSize = + DII->getExpression()->getActiveBits(DII->getVariable())) return TypeSize::isKnownGE(ValueSize, TypeSize::getFixed(*FragmentSize)); // We can't always calculate the size of the DI variable (e.g. if it is a @@ -1617,23 +1629,24 @@ static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) { // Could not determine size of variable. Conservatively return false. return false; } -// RemoveDIs: duplicate implementation of the above, using DPValues, the -// replacement for dbg.values. -static bool valueCoversEntireFragment(Type *ValTy, DPValue *DPV) { - const DataLayout &DL = DPV->getModule()->getDataLayout(); +// RemoveDIs: duplicate implementation of the above, using DbgVariableRecords, +// the replacement for dbg.values. +static bool valueCoversEntireFragment(Type *ValTy, DbgVariableRecord *DVR) { + const DataLayout &DL = DVR->getModule()->getDataLayout(); TypeSize ValueSize = DL.getTypeAllocSizeInBits(ValTy); - if (std::optional<uint64_t> FragmentSize = DPV->getFragmentSizeInBits()) + if (std::optional<uint64_t> FragmentSize = + DVR->getExpression()->getActiveBits(DVR->getVariable())) return TypeSize::isKnownGE(ValueSize, TypeSize::getFixed(*FragmentSize)); // We can't always calculate the size of the DI variable (e.g. if it is a // VLA). Try to use the size of the alloca that the dbg intrinsic describes // intead. - if (DPV->isAddressOfVariable()) { - // DPV should have exactly 1 location when it is an address. - assert(DPV->getNumVariableLocationOps() == 1 && + if (DVR->isAddressOfVariable()) { + // DVR should have exactly 1 location when it is an address. + assert(DVR->getNumVariableLocationOps() == 1 && "address of variable must have exactly 1 location operand."); if (auto *AI = - dyn_cast_or_null<AllocaInst>(DPV->getVariableLocationOp(0))) { + dyn_cast_or_null<AllocaInst>(DVR->getVariableLocationOp(0))) { if (std::optional<TypeSize> FragmentSize = AI->getAllocationSizeInBits(DL)) { return TypeSize::isKnownGE(ValueSize, *FragmentSize); } @@ -1643,39 +1656,39 @@ static bool valueCoversEntireFragment(Type *ValTy, DPValue *DPV) { return false; } -static void insertDbgValueOrDPValue(DIBuilder &Builder, Value *DV, - DILocalVariable *DIVar, - DIExpression *DIExpr, - const DebugLoc &NewLoc, - BasicBlock::iterator Instr) { +static void insertDbgValueOrDbgVariableRecord(DIBuilder &Builder, Value *DV, + DILocalVariable *DIVar, + DIExpression *DIExpr, + const DebugLoc &NewLoc, + BasicBlock::iterator Instr) { if (!UseNewDbgInfoFormat) { - auto *DbgVal = Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, - (Instruction *)nullptr); - DbgVal->insertBefore(Instr); + auto DbgVal = Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, + (Instruction *)nullptr); + DbgVal.get<Instruction *>()->insertBefore(Instr); } else { // RemoveDIs: if we're using the new debug-info format, allocate a - // DPValue directly instead of a dbg.value intrinsic. + // DbgVariableRecord directly instead of a dbg.value intrinsic. 
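    // For comparison, the two encodings of the same information (illustrative):
    //   old form, an instruction:
    //     call void @llvm.dbg.value(metadata i32 %v, metadata !var,
    //                               metadata !DIExpression()), !dbg !loc
    //   new form, a record attached to the next instruction, printed roughly as:
    //     #dbg_value(i32 %v, !var, !DIExpression(), !loc)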
ValueAsMetadata *DVAM = ValueAsMetadata::get(DV); - DPValue *DV = new DPValue(DVAM, DIVar, DIExpr, NewLoc.get()); - Instr->getParent()->insertDPValueBefore(DV, Instr); + DbgVariableRecord *DV = + new DbgVariableRecord(DVAM, DIVar, DIExpr, NewLoc.get()); + Instr->getParent()->insertDbgRecordBefore(DV, Instr); } } -static void insertDbgValueOrDPValueAfter(DIBuilder &Builder, Value *DV, - DILocalVariable *DIVar, - DIExpression *DIExpr, - const DebugLoc &NewLoc, - BasicBlock::iterator Instr) { +static void insertDbgValueOrDbgVariableRecordAfter( + DIBuilder &Builder, Value *DV, DILocalVariable *DIVar, DIExpression *DIExpr, + const DebugLoc &NewLoc, BasicBlock::iterator Instr) { if (!UseNewDbgInfoFormat) { - auto *DbgVal = Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, - (Instruction *)nullptr); - DbgVal->insertAfter(&*Instr); + auto DbgVal = Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, NewLoc, + (Instruction *)nullptr); + DbgVal.get<Instruction *>()->insertAfter(&*Instr); } else { // RemoveDIs: if we're using the new debug-info format, allocate a - // DPValue directly instead of a dbg.value intrinsic. + // DbgVariableRecord directly instead of a dbg.value intrinsic. ValueAsMetadata *DVAM = ValueAsMetadata::get(DV); - DPValue *DV = new DPValue(DVAM, DIVar, DIExpr, NewLoc.get()); - Instr->getParent()->insertDPValueAfter(DV, &*Instr); + DbgVariableRecord *DV = + new DbgVariableRecord(DVAM, DIVar, DIExpr, NewLoc.get()); + Instr->getParent()->insertDbgRecordAfter(DV, &*Instr); } } @@ -1707,8 +1720,8 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, DIExpr->isDeref() || (!DIExpr->startsWithDeref() && valueCoversEntireFragment(DV->getType(), DII)); if (CanConvert) { - insertDbgValueOrDPValue(Builder, DV, DIVar, DIExpr, NewLoc, - SI->getIterator()); + insertDbgValueOrDbgVariableRecord(Builder, DV, DIVar, DIExpr, NewLoc, + SI->getIterator()); return; } @@ -1720,8 +1733,8 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, // know which part) we insert an dbg.value intrinsic to indicate that we // know nothing about the variable's content. DV = UndefValue::get(DV->getType()); - insertDbgValueOrDPValue(Builder, DV, DIVar, DIExpr, NewLoc, - SI->getIterator()); + insertDbgValueOrDbgVariableRecord(Builder, DV, DIVar, DIExpr, NewLoc, + SI->getIterator()); } /// Inserts a llvm.dbg.value intrinsic before a load of an alloca'd value @@ -1747,19 +1760,19 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, // future if multi-location support is added to the IR, it might be // preferable to keep tracking both the loaded value and the original // address in case the alloca can not be elided. 
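  // Illustrative net effect once the declare is lowered (hypothetical IR):
  //   #dbg_declare(ptr %x.addr, !x, !DIExpression(), ...)
  //   %v = load i32, ptr %x.addr
  // becomes
  //   %v = load i32, ptr %x.addr
  //   #dbg_value(i32 %v, !x, !DIExpression(), ...)
  // so the variable is tracked by the loaded value instead of its stack slot.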
- insertDbgValueOrDPValueAfter(Builder, LI, DIVar, DIExpr, NewLoc, - LI->getIterator()); + insertDbgValueOrDbgVariableRecordAfter(Builder, LI, DIVar, DIExpr, NewLoc, + LI->getIterator()); } -void llvm::ConvertDebugDeclareToDebugValue(DPValue *DPV, StoreInst *SI, - DIBuilder &Builder) { - assert(DPV->isAddressOfVariable() || DPV->isDbgAssign()); - auto *DIVar = DPV->getVariable(); +void llvm::ConvertDebugDeclareToDebugValue(DbgVariableRecord *DVR, + StoreInst *SI, DIBuilder &Builder) { + assert(DVR->isAddressOfVariable() || DVR->isDbgAssign()); + auto *DIVar = DVR->getVariable(); assert(DIVar && "Missing variable"); - auto *DIExpr = DPV->getExpression(); + auto *DIExpr = DVR->getExpression(); Value *DV = SI->getValueOperand(); - DebugLoc NewLoc = getDebugValueLoc(DPV); + DebugLoc NewLoc = getDebugValueLoc(DVR); // If the alloca describes the variable itself, i.e. the expression in the // dbg.declare doesn't start with a dereference, we can perform the @@ -1775,16 +1788,16 @@ void llvm::ConvertDebugDeclareToDebugValue(DPValue *DPV, StoreInst *SI, // deref expression. bool CanConvert = DIExpr->isDeref() || (!DIExpr->startsWithDeref() && - valueCoversEntireFragment(DV->getType(), DPV)); + valueCoversEntireFragment(DV->getType(), DVR)); if (CanConvert) { - insertDbgValueOrDPValue(Builder, DV, DIVar, DIExpr, NewLoc, - SI->getIterator()); + insertDbgValueOrDbgVariableRecord(Builder, DV, DIVar, DIExpr, NewLoc, + SI->getIterator()); return; } // FIXME: If storing to a part of the variable described by the dbg.declare, // then we want to insert a dbg.value for the corresponding fragment. - LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to dbg.value: " << *DPV + LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to dbg.value: " << *DVR << '\n'); assert(UseNewDbgInfoFormat); @@ -1793,8 +1806,9 @@ void llvm::ConvertDebugDeclareToDebugValue(DPValue *DPV, StoreInst *SI, // know nothing about the variable's content. DV = UndefValue::get(DV->getType()); ValueAsMetadata *DVAM = ValueAsMetadata::get(DV); - DPValue *NewDPV = new DPValue(DVAM, DIVar, DIExpr, NewLoc.get()); - SI->getParent()->insertDPValueBefore(NewDPV, SI->getIterator()); + DbgVariableRecord *NewDVR = + new DbgVariableRecord(DVAM, DIVar, DIExpr, NewLoc.get()); + SI->getParent()->insertDbgRecordBefore(NewDVR, SI->getIterator()); } /// Inserts a llvm.dbg.value intrinsic after a phi that has an associated @@ -1826,26 +1840,27 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, // insertion point. // FIXME: Insert dbg.value markers in the successors when appropriate. if (InsertionPt != BB->end()) { - insertDbgValueOrDPValue(Builder, APN, DIVar, DIExpr, NewLoc, InsertionPt); + insertDbgValueOrDbgVariableRecord(Builder, APN, DIVar, DIExpr, NewLoc, + InsertionPt); } } -void llvm::ConvertDebugDeclareToDebugValue(DPValue *DPV, LoadInst *LI, +void llvm::ConvertDebugDeclareToDebugValue(DbgVariableRecord *DVR, LoadInst *LI, DIBuilder &Builder) { - auto *DIVar = DPV->getVariable(); - auto *DIExpr = DPV->getExpression(); + auto *DIVar = DVR->getVariable(); + auto *DIExpr = DVR->getExpression(); assert(DIVar && "Missing variable"); - if (!valueCoversEntireFragment(LI->getType(), DPV)) { + if (!valueCoversEntireFragment(LI->getType(), DVR)) { // FIXME: If only referring to a part of the variable described by the - // dbg.declare, then we want to insert a DPValue for the corresponding - // fragment. 
- LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to DPValue: " << *DPV - << '\n'); + // dbg.declare, then we want to insert a DbgVariableRecord for the + // corresponding fragment. + LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to DbgVariableRecord: " + << *DVR << '\n'); return; } - DebugLoc NewLoc = getDebugValueLoc(DPV); + DebugLoc NewLoc = getDebugValueLoc(DVR); // We are now tracking the loaded value instead of the address. In the // future if multi-location support is added to the IR, it might be @@ -1853,10 +1868,11 @@ void llvm::ConvertDebugDeclareToDebugValue(DPValue *DPV, LoadInst *LI, // address in case the alloca can not be elided. assert(UseNewDbgInfoFormat); - // Create a DPValue directly and insert. + // Create a DbgVariableRecord directly and insert. ValueAsMetadata *LIVAM = ValueAsMetadata::get(LI); - DPValue *DV = new DPValue(LIVAM, DIVar, DIExpr, NewLoc.get()); - LI->getParent()->insertDPValueAfter(DV, LI); + DbgVariableRecord *DV = + new DbgVariableRecord(LIVAM, DIVar, DIExpr, NewLoc.get()); + LI->getParent()->insertDbgRecordAfter(DV, LI); } /// Determine whether this alloca is either a VLA or an array. @@ -1869,34 +1885,35 @@ static bool isArray(AllocaInst *AI) { static bool isStructure(AllocaInst *AI) { return AI->getAllocatedType() && AI->getAllocatedType()->isStructTy(); } -void llvm::ConvertDebugDeclareToDebugValue(DPValue *DPV, PHINode *APN, +void llvm::ConvertDebugDeclareToDebugValue(DbgVariableRecord *DVR, PHINode *APN, DIBuilder &Builder) { - auto *DIVar = DPV->getVariable(); - auto *DIExpr = DPV->getExpression(); + auto *DIVar = DVR->getVariable(); + auto *DIExpr = DVR->getExpression(); assert(DIVar && "Missing variable"); if (PhiHasDebugValue(DIVar, DIExpr, APN)) return; - if (!valueCoversEntireFragment(APN->getType(), DPV)) { + if (!valueCoversEntireFragment(APN->getType(), DVR)) { // FIXME: If only referring to a part of the variable described by the - // dbg.declare, then we want to insert a DPValue for the corresponding - // fragment. - LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to DPValue: " << *DPV - << '\n'); + // dbg.declare, then we want to insert a DbgVariableRecord for the + // corresponding fragment. + LLVM_DEBUG(dbgs() << "Failed to convert dbg.declare to DbgVariableRecord: " + << *DVR << '\n'); return; } BasicBlock *BB = APN->getParent(); auto InsertionPt = BB->getFirstInsertionPt(); - DebugLoc NewLoc = getDebugValueLoc(DPV); + DebugLoc NewLoc = getDebugValueLoc(DVR); // The block may be a catchswitch block, which does not have a valid // insertion point. - // FIXME: Insert DPValue markers in the successors when appropriate. + // FIXME: Insert DbgVariableRecord markers in the successors when appropriate. 
if (InsertionPt != BB->end()) { - insertDbgValueOrDPValue(Builder, APN, DIVar, DIExpr, NewLoc, InsertionPt); + insertDbgValueOrDbgVariableRecord(Builder, APN, DIVar, DIExpr, NewLoc, + InsertionPt); } } @@ -1906,19 +1923,19 @@ bool llvm::LowerDbgDeclare(Function &F) { bool Changed = false; DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false); SmallVector<DbgDeclareInst *, 4> Dbgs; - SmallVector<DPValue *> DPVs; + SmallVector<DbgVariableRecord *> DVRs; for (auto &FI : F) { for (Instruction &BI : FI) { if (auto *DDI = dyn_cast<DbgDeclareInst>(&BI)) Dbgs.push_back(DDI); - for (DPValue &DPV : BI.getDbgValueRange()) { - if (DPV.getType() == DPValue::LocationType::Declare) - DPVs.push_back(&DPV); + for (DbgVariableRecord &DVR : filterDbgVars(BI.getDbgRecordRange())) { + if (DVR.getType() == DbgVariableRecord::LocationType::Declare) + DVRs.push_back(&DVR); } } } - if (Dbgs.empty() && DPVs.empty()) + if (Dbgs.empty() && DVRs.empty()) return Changed; auto LowerOne = [&](auto *DDI) { @@ -1962,8 +1979,9 @@ bool llvm::LowerDbgDeclare(Function &F) { DebugLoc NewLoc = getDebugValueLoc(DDI); auto *DerefExpr = DIExpression::append(DDI->getExpression(), dwarf::DW_OP_deref); - insertDbgValueOrDPValue(DIB, AI, DDI->getVariable(), DerefExpr, - NewLoc, CI->getIterator()); + insertDbgValueOrDbgVariableRecord(DIB, AI, DDI->getVariable(), + DerefExpr, NewLoc, + CI->getIterator()); } } else if (BitCastInst *BI = dyn_cast<BitCastInst>(U)) { if (BI->getType()->isPointerTy()) @@ -1976,7 +1994,7 @@ bool llvm::LowerDbgDeclare(Function &F) { }; for_each(Dbgs, LowerOne); - for_each(DPVs, LowerOne); + for_each(DVRs, LowerOne); if (Changed) for (BasicBlock &BB : F) @@ -1986,35 +2004,39 @@ bool llvm::LowerDbgDeclare(Function &F) { } // RemoveDIs: re-implementation of insertDebugValuesForPHIs, but which pulls the -// debug-info out of the block's DPValues rather than dbg.value intrinsics. -static void insertDPValuesForPHIs(BasicBlock *BB, - SmallVectorImpl<PHINode *> &InsertedPHIs) { - assert(BB && "No BasicBlock to clone DPValue(s) from."); +// debug-info out of the block's DbgVariableRecords rather than dbg.value +// intrinsics. +static void +insertDbgVariableRecordsForPHIs(BasicBlock *BB, + SmallVectorImpl<PHINode *> &InsertedPHIs) { + assert(BB && "No BasicBlock to clone DbgVariableRecord(s) from."); if (InsertedPHIs.size() == 0) return; - // Map existing PHI nodes to their DPValues. - DenseMap<Value *, DPValue *> DbgValueMap; + // Map existing PHI nodes to their DbgVariableRecords. + DenseMap<Value *, DbgVariableRecord *> DbgValueMap; for (auto &I : *BB) { - for (auto &DPV : I.getDbgValueRange()) { - for (Value *V : DPV.location_ops()) + for (DbgVariableRecord &DVR : filterDbgVars(I.getDbgRecordRange())) { + for (Value *V : DVR.location_ops()) if (auto *Loc = dyn_cast_or_null<PHINode>(V)) - DbgValueMap.insert({Loc, &DPV}); + DbgValueMap.insert({Loc, &DVR}); } } if (DbgValueMap.size() == 0) return; - // Map a pair of the destination BB and old DPValue to the new DPValue, - // so that if a DPValue is being rewritten to use more than one of the - // inserted PHIs in the same destination BB, we can update the same DPValue - // with all the new PHIs instead of creating one copy for each. 
- MapVector<std::pair<BasicBlock *, DPValue *>, DPValue *> NewDbgValueMap; + // Map a pair of the destination BB and old DbgVariableRecord to the new + // DbgVariableRecord, so that if a DbgVariableRecord is being rewritten to use + // more than one of the inserted PHIs in the same destination BB, we can + // update the same DbgVariableRecord with all the new PHIs instead of creating + // one copy for each. + MapVector<std::pair<BasicBlock *, DbgVariableRecord *>, DbgVariableRecord *> + NewDbgValueMap; // Then iterate through the new PHIs and look to see if they use one of the - // previously mapped PHIs. If so, create a new DPValue that will propagate - // the info through the new PHI. If we use more than one new PHI in a single - // destination BB with the same old dbg.value, merge the updates so that we - // get a single new DPValue with all the new PHIs. + // previously mapped PHIs. If so, create a new DbgVariableRecord that will + // propagate the info through the new PHI. If we use more than one new PHI in + // a single destination BB with the same old dbg.value, merge the updates so + // that we get a single new DbgVariableRecord with all the new PHIs. for (auto PHI : InsertedPHIs) { BasicBlock *Parent = PHI->getParent(); // Avoid inserting a debug-info record into an EH block. @@ -2023,13 +2045,13 @@ static void insertDPValuesForPHIs(BasicBlock *BB, for (auto VI : PHI->operand_values()) { auto V = DbgValueMap.find(VI); if (V != DbgValueMap.end()) { - DPValue *DbgII = cast<DPValue>(V->second); + DbgVariableRecord *DbgII = cast<DbgVariableRecord>(V->second); auto NewDI = NewDbgValueMap.find({Parent, DbgII}); if (NewDI == NewDbgValueMap.end()) { - DPValue *NewDbgII = DbgII->clone(); + DbgVariableRecord *NewDbgII = DbgII->clone(); NewDI = NewDbgValueMap.insert({{Parent, DbgII}, NewDbgII}).first; } - DPValue *NewDbgII = NewDI->second; + DbgVariableRecord *NewDbgII = NewDI->second; // If PHI contains VI as an operand more than once, we may // replaced it in NewDbgII; confirm that it is present. if (is_contained(NewDbgII->location_ops(), VI)) @@ -2037,14 +2059,14 @@ static void insertDPValuesForPHIs(BasicBlock *BB, } } } - // Insert the new DPValues into their destination blocks. + // Insert the new DbgVariableRecords into their destination blocks. for (auto DI : NewDbgValueMap) { BasicBlock *Parent = DI.first.first; - DPValue *NewDbgII = DI.second; + DbgVariableRecord *NewDbgII = DI.second; auto InsertionPt = Parent->getFirstInsertionPt(); assert(InsertionPt != Parent->end() && "Ill-formed basic block"); - InsertionPt->DbgMarker->insertDPValue(NewDbgII, true); + Parent->insertDbgRecordBefore(NewDbgII, InsertionPt); } } @@ -2055,7 +2077,7 @@ void llvm::insertDebugValuesForPHIs(BasicBlock *BB, if (InsertedPHIs.size() == 0) return; - insertDPValuesForPHIs(BB, InsertedPHIs); + insertDbgVariableRecordsForPHIs(BB, InsertedPHIs); // Map existing PHI nodes to their dbg.values. 
ValueToValueMapTy DbgValueMap; @@ -2117,7 +2139,7 @@ bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, DIBuilder &Builder, uint8_t DIExprFlags, int Offset) { TinyPtrVector<DbgDeclareInst *> DbgDeclares = findDbgDeclares(Address); - TinyPtrVector<DPValue *> DPVDeclares = findDPVDeclares(Address); + TinyPtrVector<DbgVariableRecord *> DVRDeclares = findDVRDeclares(Address); auto ReplaceOne = [&](auto *DII) { assert(DII->getVariable() && "Missing variable"); @@ -2128,21 +2150,22 @@ bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, }; for_each(DbgDeclares, ReplaceOne); - for_each(DPVDeclares, ReplaceOne); + for_each(DVRDeclares, ReplaceOne); - return !DbgDeclares.empty() || !DPVDeclares.empty(); + return !DbgDeclares.empty() || !DVRDeclares.empty(); } static void updateOneDbgValueForAlloca(const DebugLoc &Loc, DILocalVariable *DIVar, DIExpression *DIExpr, Value *NewAddress, - DbgValueInst *DVI, DPValue *DPV, + DbgValueInst *DVI, + DbgVariableRecord *DVR, DIBuilder &Builder, int Offset) { assert(DIVar && "Missing variable"); - // This is an alloca-based dbg.value/DPValue. The first thing it should do - // with the alloca pointer is dereference it. Otherwise we don't know how to - // handle it and give up. + // This is an alloca-based dbg.value/DbgVariableRecord. The first thing it + // should do with the alloca pointer is dereference it. Otherwise we don't + // know how to handle it and give up. if (!DIExpr || DIExpr->getNumElements() < 1 || DIExpr->getElement(0) != dwarf::DW_OP_deref) return; @@ -2155,16 +2178,16 @@ static void updateOneDbgValueForAlloca(const DebugLoc &Loc, DVI->setExpression(DIExpr); DVI->replaceVariableLocationOp(0u, NewAddress); } else { - assert(DPV); - DPV->setExpression(DIExpr); - DPV->replaceVariableLocationOp(0u, NewAddress); + assert(DVR); + DVR->setExpression(DIExpr); + DVR->replaceVariableLocationOp(0u, NewAddress); } } void llvm::replaceDbgValueForAlloca(AllocaInst *AI, Value *NewAllocaAddress, DIBuilder &Builder, int Offset) { SmallVector<DbgValueInst *, 1> DbgUsers; - SmallVector<DPValue *, 1> DPUsers; + SmallVector<DbgVariableRecord *, 1> DPUsers; findDbgValues(DbgUsers, AI, &DPUsers); // Attempt to replace dbg.values that use this alloca. @@ -2173,18 +2196,18 @@ void llvm::replaceDbgValueForAlloca(AllocaInst *AI, Value *NewAllocaAddress, DVI->getExpression(), NewAllocaAddress, DVI, nullptr, Builder, Offset); - // Replace any DPValues that use this alloca. - for (DPValue *DPV : DPUsers) - updateOneDbgValueForAlloca(DPV->getDebugLoc(), DPV->getVariable(), - DPV->getExpression(), NewAllocaAddress, nullptr, - DPV, Builder, Offset); + // Replace any DbgVariableRecords that use this alloca. + for (DbgVariableRecord *DVR : DPUsers) + updateOneDbgValueForAlloca(DVR->getDebugLoc(), DVR->getVariable(), + DVR->getExpression(), NewAllocaAddress, nullptr, + DVR, Builder, Offset); } /// Where possible to salvage debug information for \p I do so. /// If not possible mark undef. void llvm::salvageDebugInfo(Instruction &I) { SmallVector<DbgVariableIntrinsic *, 1> DbgUsers; - SmallVector<DPValue *, 1> DPUsers; + SmallVector<DbgVariableRecord *, 1> DPUsers; findDbgUsers(DbgUsers, &I, &DPUsers); salvageDebugInfoForDbgValues(I, DbgUsers, DPUsers); } @@ -2213,6 +2236,8 @@ template <typename T> static void salvageDbgAssignAddress(T *Assign) { assert(!SalvagedExpr->getFragmentInfo().has_value() && "address-expression shouldn't have fragment info"); + SalvagedExpr = SalvagedExpr->foldConstantMath(); + // Salvage succeeds if no additional values are required. 
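// Illustrative sketch, not part of the upstream diff: the usual caller-side
// pattern for the salvage machinery defined above. Before an instruction is
// erased, its debug users (intrinsics or DbgVariableRecords) are rewritten in
// terms of the instruction's operands, or marked as killed when that is not
// possible. Assumes `using namespace llvm;`; the helper name is hypothetical.
static void eraseAndSalvage(Instruction *I) {
  salvageDebugInfo(*I); // forwards to salvageDebugInfoForDbgValues above
  I->eraseFromParent();
}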
if (AdditionalValues.empty()) { Assign->setAddress(NewV); @@ -2224,7 +2249,7 @@ template <typename T> static void salvageDbgAssignAddress(T *Assign) { void llvm::salvageDebugInfoForDbgValues( Instruction &I, ArrayRef<DbgVariableIntrinsic *> DbgUsers, - ArrayRef<DPValue *> DPUsers) { + ArrayRef<DbgVariableRecord *> DPUsers) { // These are arbitrary chosen limits on the maximum number of values and the // maximum size of a debug expression we can salvage up to, used for // performance reasons. @@ -2273,6 +2298,7 @@ void llvm::salvageDebugInfoForDbgValues( if (!Op0) break; + SalvagedExpr = SalvagedExpr->foldConstantMath(); DII->replaceVariableLocationOp(&I, Op0); bool IsValidSalvageExpr = SalvagedExpr->getNumElements() <= MaxExpressionSize; if (AdditionalValues.empty() && IsValidSalvageExpr) { @@ -2290,67 +2316,69 @@ void llvm::salvageDebugInfoForDbgValues( LLVM_DEBUG(dbgs() << "SALVAGE: " << *DII << '\n'); Salvaged = true; } - // Duplicate of above block for DPValues. - for (auto *DPV : DPUsers) { - if (DPV->isDbgAssign()) { - if (DPV->getAddress() == &I) { - salvageDbgAssignAddress(DPV); + // Duplicate of above block for DbgVariableRecords. + for (auto *DVR : DPUsers) { + if (DVR->isDbgAssign()) { + if (DVR->getAddress() == &I) { + salvageDbgAssignAddress(DVR); Salvaged = true; } - if (DPV->getValue() != &I) + if (DVR->getValue() != &I) continue; } // Do not add DW_OP_stack_value for DbgDeclare and DbgAddr, because they // are implicitly pointing out the value as a DWARF memory location // description. - bool StackValue = DPV->getType() != DPValue::LocationType::Declare; - auto DPVLocation = DPV->location_ops(); + bool StackValue = + DVR->getType() != DbgVariableRecord::LocationType::Declare; + auto DVRLocation = DVR->location_ops(); assert( - is_contained(DPVLocation, &I) && + is_contained(DVRLocation, &I) && "DbgVariableIntrinsic must use salvaged instruction as its location"); SmallVector<Value *, 4> AdditionalValues; - // 'I' may appear more than once in DPV's location ops, and each use of 'I' + // 'I' may appear more than once in DVR's location ops, and each use of 'I' // must be updated in the DIExpression and potentially have additional // values added; thus we call salvageDebugInfoImpl for each 'I' instance in - // DPVLocation. + // DVRLocation. Value *Op0 = nullptr; - DIExpression *SalvagedExpr = DPV->getExpression(); - auto LocItr = find(DPVLocation, &I); - while (SalvagedExpr && LocItr != DPVLocation.end()) { + DIExpression *SalvagedExpr = DVR->getExpression(); + auto LocItr = find(DVRLocation, &I); + while (SalvagedExpr && LocItr != DVRLocation.end()) { SmallVector<uint64_t, 16> Ops; - unsigned LocNo = std::distance(DPVLocation.begin(), LocItr); + unsigned LocNo = std::distance(DVRLocation.begin(), LocItr); uint64_t CurrentLocOps = SalvagedExpr->getNumLocationOperands(); Op0 = salvageDebugInfoImpl(I, CurrentLocOps, Ops, AdditionalValues); if (!Op0) break; SalvagedExpr = DIExpression::appendOpsToArg(SalvagedExpr, Ops, LocNo, StackValue); - LocItr = std::find(++LocItr, DPVLocation.end(), &I); + LocItr = std::find(++LocItr, DVRLocation.end(), &I); } // salvageDebugInfoImpl should fail on examining the first element of // DbgUsers, or none of them. 
if (!Op0) break; - DPV->replaceVariableLocationOp(&I, Op0); + SalvagedExpr = SalvagedExpr->foldConstantMath(); + DVR->replaceVariableLocationOp(&I, Op0); bool IsValidSalvageExpr = SalvagedExpr->getNumElements() <= MaxExpressionSize; if (AdditionalValues.empty() && IsValidSalvageExpr) { - DPV->setExpression(SalvagedExpr); - } else if (DPV->getType() != DPValue::LocationType::Declare && + DVR->setExpression(SalvagedExpr); + } else if (DVR->getType() != DbgVariableRecord::LocationType::Declare && IsValidSalvageExpr && - DPV->getNumVariableLocationOps() + AdditionalValues.size() <= + DVR->getNumVariableLocationOps() + AdditionalValues.size() <= MaxDebugArgs) { - DPV->addVariableLocationOps(AdditionalValues, SalvagedExpr); + DVR->addVariableLocationOps(AdditionalValues, SalvagedExpr); } else { // Do not salvage using DIArgList for dbg.addr/dbg.declare, as it is // currently only valid for stack value expressions. // Also do not salvage if the resulting DIArgList would contain an // unreasonably large number of values. - DPV->setKillLocation(); + DVR->setKillLocation(); } - LLVM_DEBUG(dbgs() << "SALVAGE: " << DPV << '\n'); + LLVM_DEBUG(dbgs() << "SALVAGE: " << DVR << '\n'); Salvaged = true; } @@ -2360,8 +2388,8 @@ void llvm::salvageDebugInfoForDbgValues( for (auto *DII : DbgUsers) DII->setKillLocation(); - for (auto *DPV : DPUsers) - DPV->setKillLocation(); + for (auto *DVR : DPUsers) + DVR->setKillLocation(); } Value *getSalvageOpsForGEP(GetElementPtrInst *GEP, const DataLayout &DL, @@ -2577,10 +2605,10 @@ using DbgValReplacement = std::optional<DIExpression *>; static bool rewriteDebugUsers( Instruction &From, Value &To, Instruction &DomPoint, DominatorTree &DT, function_ref<DbgValReplacement(DbgVariableIntrinsic &DII)> RewriteExpr, - function_ref<DbgValReplacement(DPValue &DPV)> RewriteDPVExpr) { + function_ref<DbgValReplacement(DbgVariableRecord &DVR)> RewriteDVRExpr) { // Find debug users of From. SmallVector<DbgVariableIntrinsic *, 1> Users; - SmallVector<DPValue *, 1> DPUsers; + SmallVector<DbgVariableRecord *, 1> DPUsers; findDbgUsers(Users, &From, &DPUsers); if (Users.empty() && DPUsers.empty()) return false; @@ -2589,7 +2617,7 @@ static bool rewriteDebugUsers( bool Changed = false; SmallPtrSet<DbgVariableIntrinsic *, 1> UndefOrSalvage; - SmallPtrSet<DPValue *, 1> UndefOrSalvageDPV; + SmallPtrSet<DbgVariableRecord *, 1> UndefOrSalvageDVR; if (isa<Instruction>(&To)) { bool DomPointAfterFrom = From.getNextNonDebugInstruction() == &DomPoint; @@ -2608,22 +2636,22 @@ static bool rewriteDebugUsers( } } - // DPValue implementation of the above. - for (auto *DPV : DPUsers) { - Instruction *MarkedInstr = DPV->getMarker()->MarkedInstr; + // DbgVariableRecord implementation of the above. + for (auto *DVR : DPUsers) { + Instruction *MarkedInstr = DVR->getMarker()->MarkedInstr; Instruction *NextNonDebug = MarkedInstr; // The next instruction might still be a dbg.declare, skip over it. if (isa<DbgVariableIntrinsic>(NextNonDebug)) NextNonDebug = NextNonDebug->getNextNonDebugInstruction(); if (DomPointAfterFrom && NextNonDebug == &DomPoint) { - LLVM_DEBUG(dbgs() << "MOVE: " << *DPV << '\n'); - DPV->removeFromParent(); + LLVM_DEBUG(dbgs() << "MOVE: " << *DVR << '\n'); + DVR->removeFromParent(); // Ensure there's a marker. 
- DomPoint.getParent()->insertDPValueAfter(DPV, &DomPoint); + DomPoint.getParent()->insertDbgRecordAfter(DVR, &DomPoint); Changed = true; } else if (!DT.dominates(&DomPoint, MarkedInstr)) { - UndefOrSalvageDPV.insert(DPV); + UndefOrSalvageDVR.insert(DVR); } } } @@ -2633,30 +2661,30 @@ static bool rewriteDebugUsers( if (UndefOrSalvage.count(DII)) continue; - DbgValReplacement DVR = RewriteExpr(*DII); - if (!DVR) + DbgValReplacement DVRepl = RewriteExpr(*DII); + if (!DVRepl) continue; DII->replaceVariableLocationOp(&From, &To); - DII->setExpression(*DVR); + DII->setExpression(*DVRepl); LLVM_DEBUG(dbgs() << "REWRITE: " << *DII << '\n'); Changed = true; } - for (auto *DPV : DPUsers) { - if (UndefOrSalvageDPV.count(DPV)) + for (auto *DVR : DPUsers) { + if (UndefOrSalvageDVR.count(DVR)) continue; - DbgValReplacement DVR = RewriteDPVExpr(*DPV); - if (!DVR) + DbgValReplacement DVRepl = RewriteDVRExpr(*DVR); + if (!DVRepl) continue; - DPV->replaceVariableLocationOp(&From, &To); - DPV->setExpression(*DVR); - LLVM_DEBUG(dbgs() << "REWRITE: " << DPV << '\n'); + DVR->replaceVariableLocationOp(&From, &To); + DVR->setExpression(*DVRepl); + LLVM_DEBUG(dbgs() << "REWRITE: " << DVR << '\n'); Changed = true; } - if (!UndefOrSalvage.empty() || !UndefOrSalvageDPV.empty()) { + if (!UndefOrSalvage.empty() || !UndefOrSalvageDVR.empty()) { // Try to salvage the remaining debug users. salvageDebugInfo(From); Changed = true; @@ -2704,15 +2732,15 @@ bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To, auto Identity = [&](DbgVariableIntrinsic &DII) -> DbgValReplacement { return DII.getExpression(); }; - auto IdentityDPV = [&](DPValue &DPV) -> DbgValReplacement { - return DPV.getExpression(); + auto IdentityDVR = [&](DbgVariableRecord &DVR) -> DbgValReplacement { + return DVR.getExpression(); }; // Handle no-op conversions. Module &M = *From.getModule(); const DataLayout &DL = M.getDataLayout(); if (isBitCastSemanticsPreserving(DL, FromTy, ToTy)) - return rewriteDebugUsers(From, To, DomPoint, DT, Identity, IdentityDPV); + return rewriteDebugUsers(From, To, DomPoint, DT, Identity, IdentityDVR); // Handle integer-to-integer widening and narrowing. // FIXME: Use DW_OP_convert when it's available everywhere. @@ -2724,7 +2752,7 @@ bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To, // When the width of the result grows, assume that a debugger will only // access the low `FromBits` bits when inspecting the source variable. if (FromBits < ToBits) - return rewriteDebugUsers(From, To, DomPoint, DT, Identity, IdentityDPV); + return rewriteDebugUsers(From, To, DomPoint, DT, Identity, IdentityDVR); // The width of the result has shrunk. Use sign/zero extension to describe // the source variable's high bits. @@ -2740,10 +2768,10 @@ bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To, return DIExpression::appendExt(DII.getExpression(), ToBits, FromBits, Signed); }; - // RemoveDIs: duplicate implementation working on DPValues rather than on - // dbg.value intrinsics. - auto SignOrZeroExtDPV = [&](DPValue &DPV) -> DbgValReplacement { - DILocalVariable *Var = DPV.getVariable(); + // RemoveDIs: duplicate implementation working on DbgVariableRecords rather + // than on dbg.value intrinsics. + auto SignOrZeroExtDVR = [&](DbgVariableRecord &DVR) -> DbgValReplacement { + DILocalVariable *Var = DVR.getVariable(); // Without knowing signedness, sign/zero extension isn't possible. 
auto Signedness = Var->getSignedness(); @@ -2751,17 +2779,34 @@ bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To, return std::nullopt; bool Signed = *Signedness == DIBasicType::Signedness::Signed; - return DIExpression::appendExt(DPV.getExpression(), ToBits, FromBits, + return DIExpression::appendExt(DVR.getExpression(), ToBits, FromBits, Signed); }; return rewriteDebugUsers(From, To, DomPoint, DT, SignOrZeroExt, - SignOrZeroExtDPV); + SignOrZeroExtDVR); } // TODO: Floating-point conversions, vectors. return false; } +bool llvm::handleUnreachableTerminator( + Instruction *I, SmallVectorImpl<Value *> &PoisonedValues) { + bool Changed = false; + // RemoveDIs: erase debug-info on this instruction manually. + I->dropDbgRecords(); + for (Use &U : I->operands()) { + Value *Op = U.get(); + if (isa<Instruction>(Op) && !Op->getType()->isTokenTy()) { + U.set(PoisonValue::get(Op->getType())); + PoisonedValues.push_back(Op); + Changed = true; + } + } + + return Changed; +} + std::pair<unsigned, unsigned> llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) { unsigned NumDeadInst = 0; @@ -2769,17 +2814,18 @@ llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) { // Delete the instructions backwards, as it has a reduced likelihood of // having to update as many def-use and use-def chains. Instruction *EndInst = BB->getTerminator(); // Last not to be deleted. - // RemoveDIs: erasing debug-info must be done manually. - EndInst->dropDbgValues(); + SmallVector<Value *> Uses; + handleUnreachableTerminator(EndInst, Uses); + while (EndInst != &BB->front()) { // Delete the next to last instruction. Instruction *Inst = &*--EndInst->getIterator(); if (!Inst->use_empty() && !Inst->getType()->isTokenTy()) Inst->replaceAllUsesWith(PoisonValue::get(Inst->getType())); if (Inst->isEHPad() || Inst->getType()->isTokenTy()) { - // EHPads can't have DPValues attached to them, but it might be possible - // for things with token type. - Inst->dropDbgValues(); + // EHPads can't have DbgVariableRecords attached to them, but it might be + // possible for things with token type. + Inst->dropDbgRecords(); EndInst = Inst; continue; } @@ -2788,7 +2834,7 @@ llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) { else ++NumDeadInst; // RemoveDIs: erasing debug-info must be done manually. - Inst->dropDbgValues(); + Inst->dropDbgRecords(); Inst->eraseFromParent(); } return {NumDeadInst, NumDeadDbgInst}; @@ -2811,7 +2857,7 @@ unsigned llvm::changeToUnreachable(Instruction *I, bool PreserveLCSSA, if (DTU) UniqueSuccessors.insert(Successor); } - auto *UI = new UnreachableInst(I->getContext(), I); + auto *UI = new UnreachableInst(I->getContext(), I->getIterator()); UI->setDebugLoc(I->getDebugLoc()); // All instructions after this are dead. @@ -2830,7 +2876,7 @@ unsigned llvm::changeToUnreachable(Instruction *I, bool PreserveLCSSA, Updates.push_back({DominatorTree::Delete, BB, UniqueSuccessor}); DTU->applyUpdates(Updates); } - BB->flushTerminatorDbgValues(); + BB->flushTerminatorDbgRecords(); return NumInstrsRemoved; } @@ -2868,7 +2914,7 @@ CallInst *llvm::changeToCall(InvokeInst *II, DomTreeUpdater *DTU) { // Follow the call by a branch to the normal destination. BasicBlock *NormalDestBB = II->getNormalDest(); - BranchInst::Create(NormalDestBB, II); + BranchInst::Create(NormalDestBB, II->getIterator()); // Update PHI nodes in the unwind destination BasicBlock *BB = II->getParent(); @@ -3048,7 +3094,7 @@ static bool markAliveBlocks(Function &F, // jump to the normal destination branch. 
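// Illustrative sketch, not from the patch itself: intended use of the new
// handleUnreachableTerminator helper defined above. It drops debug records on
// the terminator and replaces its instruction operands with poison, returning
// the replaced values so the caller can post-process them. The salvage loop
// below is just one hypothetical follow-up; assumes `using namespace llvm;`.
static void neuterUnreachableTerminator(BasicBlock *BB) {
  SmallVector<Value *> PoisonedValues;
  if (handleUnreachableTerminator(BB->getTerminator(), PoisonedValues))
    for (Value *V : PoisonedValues)
      if (auto *DeadI = dyn_cast<Instruction>(V))
        salvageDebugInfo(*DeadI);
}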
BasicBlock *NormalDestBB = II->getNormalDest(); BasicBlock *UnwindDestBB = II->getUnwindDest(); - BranchInst::Create(NormalDestBB, II); + BranchInst::Create(NormalDestBB, II->getIterator()); UnwindDestBB->removePredecessor(II->getParent()); II->eraseFromParent(); if (DTU) @@ -3131,12 +3177,12 @@ Instruction *llvm::removeUnwindEdge(BasicBlock *BB, DomTreeUpdater *DTU) { BasicBlock *UnwindDest; if (auto *CRI = dyn_cast<CleanupReturnInst>(TI)) { - NewTI = CleanupReturnInst::Create(CRI->getCleanupPad(), nullptr, CRI); + NewTI = CleanupReturnInst::Create(CRI->getCleanupPad(), nullptr, CRI->getIterator()); UnwindDest = CRI->getUnwindDest(); } else if (auto *CatchSwitch = dyn_cast<CatchSwitchInst>(TI)) { auto *NewCatchSwitch = CatchSwitchInst::Create( CatchSwitch->getParentPad(), nullptr, CatchSwitch->getNumHandlers(), - CatchSwitch->getName(), CatchSwitch); + CatchSwitch->getName(), CatchSwitch->getIterator()); for (BasicBlock *PadBB : CatchSwitch->handlers()) NewCatchSwitch->addHandler(PadBB); @@ -3249,6 +3295,9 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, case LLVMContext::MD_invariant_group: // Preserve !invariant.group in K. break; + case LLVMContext::MD_mmra: + // Combine MMRAs + break; case LLVMContext::MD_align: if (DoesKMove || !K->hasMetadata(LLVMContext::MD_noundef)) K->setMetadata( @@ -3287,6 +3336,16 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, if (auto *JMD = J->getMetadata(LLVMContext::MD_invariant_group)) if (isa<LoadInst>(K) || isa<StoreInst>(K)) K->setMetadata(LLVMContext::MD_invariant_group, JMD); + + // Merge MMRAs. + // This is handled separately because we also want to handle cases where K + // doesn't have tags but J does. + auto JMMRA = J->getMetadata(LLVMContext::MD_mmra); + auto KMMRA = K->getMetadata(LLVMContext::MD_mmra); + if (JMMRA || KMMRA) { + K->setMetadata(LLVMContext::MD_mmra, + MMRAMetadata::combine(K->getContext(), JMMRA, KMMRA)); + } } void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J, @@ -3306,7 +3365,8 @@ void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J, LLVMContext::MD_preserve_access_index, LLVMContext::MD_prof, LLVMContext::MD_nontemporal, - LLVMContext::MD_noundef}; + LLVMContext::MD_noundef, + LLVMContext::MD_mmra}; combineMetadata(K, J, KnownIDs, KDominatesJ); } @@ -3315,7 +3375,7 @@ void llvm::copyMetadataForLoad(LoadInst &Dest, const LoadInst &Source) { Source.getAllMetadata(MD); MDBuilder MDB(Dest.getContext()); Type *NewType = Dest.getType(); - const DataLayout &DL = Source.getModule()->getDataLayout(); + const DataLayout &DL = Source.getDataLayout(); for (const auto &MDPair : MD) { unsigned ID = MDPair.first; MDNode *N = MDPair.second; @@ -3394,15 +3454,15 @@ void llvm::patchReplacementInstruction(Instruction *I, Value *Repl) { combineMetadataForCSE(ReplInst, I, false); } -template <typename RootType, typename DominatesFn> +template <typename RootType, typename ShouldReplaceFn> static unsigned replaceDominatedUsesWith(Value *From, Value *To, const RootType &Root, - const DominatesFn &Dominates) { + const ShouldReplaceFn &ShouldReplace) { assert(From->getType() == To->getType()); unsigned Count = 0; for (Use &U : llvm::make_early_inc_range(From->uses())) { - if (!Dominates(Root, U)) + if (!ShouldReplace(Root, U)) continue; LLVM_DEBUG(dbgs() << "Replace dominated use of '"; From->printAsOperand(dbgs()); @@ -3446,6 +3506,26 @@ unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To, return ::replaceDominatedUsesWith(From, To, BB, Dominates); } +unsigned 
llvm::replaceDominatedUsesWithIf( + Value *From, Value *To, DominatorTree &DT, const BasicBlockEdge &Root, + function_ref<bool(const Use &U, const Value *To)> ShouldReplace) { + auto DominatesAndShouldReplace = + [&DT, &ShouldReplace, To](const BasicBlockEdge &Root, const Use &U) { + return DT.dominates(Root, U) && ShouldReplace(U, To); + }; + return ::replaceDominatedUsesWith(From, To, Root, DominatesAndShouldReplace); +} + +unsigned llvm::replaceDominatedUsesWithIf( + Value *From, Value *To, DominatorTree &DT, const BasicBlock *BB, + function_ref<bool(const Use &U, const Value *To)> ShouldReplace) { + auto DominatesAndShouldReplace = [&DT, &ShouldReplace, + To](const BasicBlock *BB, const Use &U) { + return DT.dominates(BB, U) && ShouldReplace(U, To); + }; + return ::replaceDominatedUsesWith(From, To, BB, DominatesAndShouldReplace); +} + bool llvm::callsGCLeafFunction(const CallBase *Call, const TargetLibraryInfo &TLI) { // Check if the function is specifically marked as a gc leaf function. @@ -3526,12 +3606,12 @@ void llvm::copyRangeMetadata(const DataLayout &DL, const LoadInst &OldLI, void llvm::dropDebugUsers(Instruction &I) { SmallVector<DbgVariableIntrinsic *, 1> DbgUsers; - SmallVector<DPValue *, 1> DPUsers; + SmallVector<DbgVariableRecord *, 1> DPUsers; findDbgUsers(DbgUsers, &I, &DPUsers); for (auto *DII : DbgUsers) DII->eraseFromParent(); - for (auto *DPV : DPUsers) - DPV->eraseFromParent(); + for (auto *DVR : DPUsers) + DVR->eraseFromParent(); } void llvm::hoistAllInstructionsInto(BasicBlock *DomBlock, Instruction *InsertPt, @@ -3564,7 +3644,7 @@ void llvm::hoistAllInstructionsInto(BasicBlock *DomBlock, Instruction *InsertPt, if (I->isUsedByMetadata()) dropDebugUsers(*I); // RemoveDIs: drop debug-info too as the following code does. - I->dropDbgValues(); + I->dropDbgRecords(); if (I->isDebugOrPseudoInst()) { // Remove DbgInfo and pseudo probe Intrinsics. 
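// Illustrative sketch, not part of the upstream diff: calling the new
// replaceDominatedUsesWithIf overloads defined above. Only uses of Cond that
// are dominated by Edge AND accepted by the predicate are rewritten; the
// filter shown is a hypothetical example. Assumes `using namespace llvm;` and
// that Cond/DT/BranchBB/TakenSucc come from the surrounding pass.
static unsigned foldCondOnEdge(Value *Cond, DominatorTree &DT,
                               BasicBlock *BranchBB, BasicBlock *TakenSucc) {
  BasicBlockEdge Edge(BranchBB, TakenSucc);
  return replaceDominatedUsesWithIf(
      Cond, ConstantInt::getTrue(Cond->getContext()), DT, Edge,
      [](const Use &U, const Value *To) {
        return !isa<PHINode>(U.getUser()); // hypothetical extra condition
      });
}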
II = I->eraseFromParent(); @@ -3592,10 +3672,12 @@ DIExpression *llvm::getExpressionForConstant(DIBuilder &DIB, const Constant &C, return createIntegerExpression(C); auto *FP = dyn_cast<ConstantFP>(&C); - if (FP && (Ty.isFloatTy() || Ty.isDoubleTy())) { + if (FP && Ty.isFloatingPointTy() && Ty.getScalarSizeInBits() <= 64) { const APFloat &APF = FP->getValueAPF(); - return DIB.createConstantValueExpression( - APF.bitcastToAPInt().getZExtValue()); + APInt const &API = APF.bitcastToAPInt(); + if (auto Temp = API.getZExtValue()) + return DIB.createConstantValueExpression(static_cast<uint64_t>(Temp)); + return DIB.createConstantValueExpression(*API.getRawData()); } if (!Ty.isPointerTy()) @@ -3613,6 +3695,30 @@ DIExpression *llvm::getExpressionForConstant(DIBuilder &DIB, const Constant &C, return nullptr; } +void llvm::remapDebugVariable(ValueToValueMapTy &Mapping, Instruction *Inst) { + auto RemapDebugOperands = [&Mapping](auto *DV, auto Set) { + for (auto *Op : Set) { + auto I = Mapping.find(Op); + if (I != Mapping.end()) + DV->replaceVariableLocationOp(Op, I->second, /*AllowEmpty=*/true); + } + }; + auto RemapAssignAddress = [&Mapping](auto *DA) { + auto I = Mapping.find(DA->getAddress()); + if (I != Mapping.end()) + DA->setAddress(I->second); + }; + if (auto DVI = dyn_cast<DbgVariableIntrinsic>(Inst)) + RemapDebugOperands(DVI, DVI->location_ops()); + if (auto DAI = dyn_cast<DbgAssignIntrinsic>(Inst)) + RemapAssignAddress(DAI); + for (DbgVariableRecord &DVR : filterDbgVars(Inst->getDbgRecordRange())) { + RemapDebugOperands(&DVR, DVR.location_ops()); + if (DVR.isDbgAssign()) + RemapAssignAddress(&DVR); + } +} + namespace { /// A potential constituent of a bitreverse or bswap expression. See @@ -3972,23 +4078,23 @@ bool llvm::recognizeBSwapOrBitReverseIdiom( // We may need to truncate the provider. if (DemandedTy != Provider->getType()) { auto *Trunc = - CastInst::CreateIntegerCast(Provider, DemandedTy, false, "trunc", I); + CastInst::CreateIntegerCast(Provider, DemandedTy, false, "trunc", I->getIterator()); InsertedInsts.push_back(Trunc); Provider = Trunc; } - Instruction *Result = CallInst::Create(F, Provider, "rev", I); + Instruction *Result = CallInst::Create(F, Provider, "rev", I->getIterator()); InsertedInsts.push_back(Result); if (!DemandedMask.isAllOnes()) { auto *Mask = ConstantInt::get(DemandedTy, DemandedMask); - Result = BinaryOperator::Create(Instruction::And, Result, Mask, "mask", I); + Result = BinaryOperator::Create(Instruction::And, Result, Mask, "mask", I->getIterator()); InsertedInsts.push_back(Result); } // We may need to zeroextend back to the result type. if (ITy != Result->getType()) { - auto *ExtInst = CastInst::CreateIntegerCast(Result, ITy, false, "zext", I); + auto *ExtInst = CastInst::CreateIntegerCast(Result, ITy, false, "zext", I->getIterator()); InsertedInsts.push_back(ExtInst); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp index ea6d952cfa7d..4ae2baca327c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp @@ -42,8 +42,11 @@ static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV, ICmpInst::Predicate BoundPred = IsSigned ? 
CmpInst::ICMP_SGT : CmpInst::ICMP_UGT; + auto StartLG = SE.applyLoopGuards(Start, L); + auto BoundLG = SE.applyLoopGuards(BoundSCEV, L); + if (LatchBrExitIdx == 1) - return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV); + return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, BoundLG); assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1"); @@ -54,10 +57,10 @@ static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV, const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne); const SCEV *MinusOne = - SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType())); + SE.getMinusSCEV(BoundLG, SE.getOne(BoundLG->getType())); - return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) && - SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit); + return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, MinusOne) && + SE.isLoopEntryGuardedByCond(L, BoundPred, BoundLG, Limit); } /// Given a loop with an increasing induction variable, is it possible to @@ -86,8 +89,11 @@ static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV, ICmpInst::Predicate BoundPred = IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT; + auto StartLG = SE.applyLoopGuards(Start, L); + auto BoundLG = SE.applyLoopGuards(BoundSCEV, L); + if (LatchBrExitIdx == 1) - return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV); + return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, BoundLG); assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1"); @@ -97,9 +103,9 @@ static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV, : APInt::getMaxValue(BitWidth); const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne); - return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start, - SE.getAddExpr(BoundSCEV, Step)) && - SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit)); + return (SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, + SE.getAddExpr(BoundLG, Step)) && + SE.isLoopEntryGuardedByCond(L, BoundPred, BoundLG, Limit)); } /// Returns estimate for max latch taken count of the loop of the narrowest @@ -391,7 +397,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L, BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx); assert(!L.contains(LatchExit) && "expected an exit block!"); - const DataLayout &DL = Preheader->getModule()->getDataLayout(); + const DataLayout &DL = Preheader->getDataLayout(); SCEVExpander Expander(SE, DL, "loop-constrainer"); Instruction *Ins = Preheader->getTerminator(); @@ -644,7 +650,7 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( // value of the same PHI nodes if/when we continue execution. 
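// Illustrative sketch, not part of the upstream diff: the pattern the change
// to isSafeDecreasingBound/isSafeIncreasingBound applies above. Rewriting the
// operands with applyLoopGuards folds in facts established by the conditions
// guarding loop entry (e.g. a preheader check that the bound is positive),
// which can let isLoopEntryGuardedByCond prove a relation the raw SCEVs do
// not. Assumes `using namespace llvm;`; Start/Bound/L/SE come from context.
static bool entryGuardedSLT(const SCEV *Start, const SCEV *Bound, Loop *L,
                            ScalarEvolution &SE) {
  const SCEV *StartG = SE.applyLoopGuards(Start, L);
  const SCEV *BoundG = SE.applyLoopGuards(Bound, L);
  return SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SLT, StartG, BoundG);
}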
for (PHINode &PN : LS.Header->phis()) { PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy", - BranchToContinuation); + BranchToContinuation->getIterator()); NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader); NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch), @@ -653,7 +659,7 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( } RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end", - BranchToContinuation); + BranchToContinuation->getIterator()); RRI.IndVarEnd->addIncoming(IndVarStart, Preheader); RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector); @@ -727,7 +733,7 @@ bool LoopConstrainer::run() { bool Increasing = MainLoopStructure.IndVarIncreasing; IntegerType *IVTy = cast<IntegerType>(RangeTy); - SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "loop-constrainer"); + SCEVExpander Expander(SE, F.getDataLayout(), "loop-constrainer"); Instruction *InsertPt = OriginalPreheader->getTerminator(); // It would have been better to make `PreLoop' and `PostLoop' diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopPeel.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopPeel.cpp index f76fa3bb6c61..760f1619e030 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopPeel.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopPeel.cpp @@ -298,7 +298,7 @@ static unsigned peelToTurnInvariantLoadsDerefencebale(Loop &L, BasicBlock *Header = L.getHeader(); BasicBlock *Latch = L.getLoopLatch(); SmallPtrSet<Value *, 8> LoadUsers; - const DataLayout &DL = L.getHeader()->getModule()->getDataLayout(); + const DataLayout &DL = L.getHeader()->getDataLayout(); for (BasicBlock *BB : L.blocks()) { for (Instruction &I : *BB) { if (I.mayWriteToMemory()) @@ -351,6 +351,21 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount, MaxPeelCount = std::min((unsigned)SC->getAPInt().getLimitedValue() - 1, MaxPeelCount); + // Increase PeelCount while (IterVal Pred BoundSCEV) condition is satisfied; + // return true if inversed condition become known before reaching the + // MaxPeelCount limit. + auto PeelWhilePredicateIsKnown = + [&](unsigned &PeelCount, const SCEV *&IterVal, const SCEV *BoundSCEV, + const SCEV *Step, ICmpInst::Predicate Pred) { + while (PeelCount < MaxPeelCount && + SE.isKnownPredicate(Pred, IterVal, BoundSCEV)) { + IterVal = SE.getAddExpr(IterVal, Step); + ++PeelCount; + } + return SE.isKnownPredicate(ICmpInst::getInversePredicate(Pred), IterVal, + BoundSCEV); + }; + const unsigned MaxDepth = 4; std::function<void(Value *, unsigned)> ComputePeelCount = [&](Value *Condition, unsigned Depth) -> void { @@ -411,48 +426,73 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount, Pred = ICmpInst::getInversePredicate(Pred); const SCEV *Step = LeftAR->getStepRecurrence(SE); - const SCEV *NextIterVal = SE.getAddExpr(IterVal, Step); - auto PeelOneMoreIteration = [&IterVal, &NextIterVal, &SE, Step, - &NewPeelCount]() { - IterVal = NextIterVal; - NextIterVal = SE.getAddExpr(IterVal, Step); - NewPeelCount++; - }; - - auto CanPeelOneMoreIteration = [&NewPeelCount, &MaxPeelCount]() { - return NewPeelCount < MaxPeelCount; - }; - - while (CanPeelOneMoreIteration() && - SE.isKnownPredicate(Pred, IterVal, RightSCEV)) - PeelOneMoreIteration(); - - // With *that* peel count, does the predicate !Pred become known in the - // first iteration of the loop body after peeling? 
- if (!SE.isKnownPredicate(ICmpInst::getInversePredicate(Pred), IterVal, - RightSCEV)) - return; // If not, give up. + if (!PeelWhilePredicateIsKnown(NewPeelCount, IterVal, RightSCEV, Step, + Pred)) + return; // However, for equality comparisons, that isn't always sufficient to // eliminate the comparsion in loop body, we may need to peel one more // iteration. See if that makes !Pred become unknown again. + const SCEV *NextIterVal = SE.getAddExpr(IterVal, Step); if (ICmpInst::isEquality(Pred) && !SE.isKnownPredicate(ICmpInst::getInversePredicate(Pred), NextIterVal, RightSCEV) && !SE.isKnownPredicate(Pred, IterVal, RightSCEV) && SE.isKnownPredicate(Pred, NextIterVal, RightSCEV)) { - if (!CanPeelOneMoreIteration()) + if (NewPeelCount >= MaxPeelCount) return; // Need to peel one more iteration, but can't. Give up. - PeelOneMoreIteration(); // Great! + ++NewPeelCount; // Great! } DesiredPeelCount = std::max(DesiredPeelCount, NewPeelCount); }; + auto ComputePeelCountMinMax = [&](MinMaxIntrinsic *MinMax) { + if (!MinMax->getType()->isIntegerTy()) + return; + Value *LHS = MinMax->getLHS(), *RHS = MinMax->getRHS(); + const SCEV *BoundSCEV, *IterSCEV; + if (L.isLoopInvariant(LHS)) { + BoundSCEV = SE.getSCEV(LHS); + IterSCEV = SE.getSCEV(RHS); + } else if (L.isLoopInvariant(RHS)) { + BoundSCEV = SE.getSCEV(RHS); + IterSCEV = SE.getSCEV(LHS); + } else + return; + const auto *AddRec = dyn_cast<SCEVAddRecExpr>(IterSCEV); + // For simplicity, we support only affine recurrences. + if (!AddRec || !AddRec->isAffine() || AddRec->getLoop() != &L) + return; + const SCEV *Step = AddRec->getStepRecurrence(SE); + bool IsSigned = MinMax->isSigned(); + // To minimize number of peeled iterations, we use strict relational + // predicates here. + ICmpInst::Predicate Pred; + if (SE.isKnownPositive(Step)) + Pred = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; + else if (SE.isKnownNegative(Step)) + Pred = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; + else + return; + // Check that AddRec is not wrapping. + if (!(IsSigned ? AddRec->hasNoSignedWrap() : AddRec->hasNoUnsignedWrap())) + return; + unsigned NewPeelCount = DesiredPeelCount; + const SCEV *IterVal = AddRec->evaluateAtIteration( + SE.getConstant(AddRec->getType(), NewPeelCount), SE); + if (!PeelWhilePredicateIsKnown(NewPeelCount, IterVal, BoundSCEV, Step, + Pred)) + return; + DesiredPeelCount = NewPeelCount; + }; + for (BasicBlock *BB : L.blocks()) { for (Instruction &I : *BB) { if (SelectInst *SI = dyn_cast<SelectInst>(&I)) ComputePeelCount(SI->getCondition(), 0); + if (MinMaxIntrinsic *MinMax = dyn_cast<MinMaxIntrinsic>(&I)) + ComputePeelCountMinMax(MinMax); } auto *BI = dyn_cast<BranchInst>(BB->getTerminator()); @@ -640,7 +680,7 @@ struct WeightInfo { /// To avoid dealing with division rounding we can just multiple both part /// of weights to E and use weight as (F - I * E, E). 
static void updateBranchWeights(Instruction *Term, WeightInfo &Info) { - setBranchWeights(*Term, Info.Weights); + setBranchWeights(*Term, Info.Weights, /*IsExpected=*/false); for (auto [Idx, SubWeight] : enumerate(Info.SubWeights)) if (SubWeight != 0) // Don't set the probability of taking the edge from latch to loop header @@ -819,7 +859,7 @@ static void cloneLoopBlocks( if (LatchInst && L->contains(LatchInst)) LatchVal = VMap[LatchVal]; PHI.addIncoming(LatchVal, cast<BasicBlock>(VMap[Edge.first])); - SE.forgetValue(&PHI); + SE.forgetLcssaPhiWithNewPredecessor(L, &PHI); } // LastValueMap is updated with the values for the current loop @@ -1033,7 +1073,7 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, } for (const auto &[Term, Info] : Weights) { - setBranchWeights(*Term, Info.Weights); + setBranchWeights(*Term, Info.Weights, /*IsExpected=*/false); } // Update Metadata for count of peeled off iterations. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp index 504f4430dc2c..04042e71a2b8 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp @@ -159,8 +159,8 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, // Replace MetadataAsValue(ValueAsMetadata(OrigHeaderVal)) uses in debug // intrinsics. SmallVector<DbgValueInst *, 1> DbgValues; - SmallVector<DPValue *, 1> DPValues; - llvm::findDbgValues(DbgValues, OrigHeaderVal, &DPValues); + SmallVector<DbgVariableRecord *, 1> DbgVariableRecords; + llvm::findDbgValues(DbgValues, OrigHeaderVal, &DbgVariableRecords); for (auto &DbgValue : DbgValues) { // The original users in the OrigHeader are already using the original // definitions. @@ -183,11 +183,11 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, } // RemoveDIs: duplicate implementation for non-instruction debug-info - // storage in DPValues. - for (DPValue *DPV : DPValues) { + // storage in DbgVariableRecords. + for (DbgVariableRecord *DVR : DbgVariableRecords) { // The original users in the OrigHeader are already using the original // definitions. - BasicBlock *UserBB = DPV->getMarker()->getParent(); + BasicBlock *UserBB = DVR->getMarker()->getParent(); if (UserBB == OrigHeader) continue; @@ -202,7 +202,7 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, NewVal = SSA.GetValueInMiddleOfBlock(UserBB); else NewVal = UndefValue::get(OrigHeaderVal->getType()); - DPV->replaceVariableLocationOp(OrigHeaderVal, NewVal); + DVR->replaceVariableLocationOp(OrigHeaderVal, NewVal); } } } @@ -287,7 +287,7 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI, return; SmallVector<uint32_t, 2> Weights; - extractFromBranchWeightMD(WeightMD, Weights); + extractFromBranchWeightMD32(WeightMD, Weights); if (Weights.size() != 2) return; uint32_t OrigLoopExitWeight = Weights[0]; @@ -347,9 +347,19 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI, // probabilities as if there are only 0-trip and 1-trip cases. ExitWeight0 = OrigLoopExitWeight - OrigLoopBackedgeWeight; } + } else { + // Theoretically, if the loop body must be executed at least once, the + // backedge count must be not less than exit count. However the branch + // weight collected by sampling-based PGO may be not very accurate due to + // sampling. 
Therefore this workaround is required here to avoid underflow + // of unsigned in following update of branch weight. + if (OrigLoopExitWeight > OrigLoopBackedgeWeight) + OrigLoopBackedgeWeight = OrigLoopExitWeight; } + assert(OrigLoopExitWeight >= ExitWeight0 && "Bad branch weight"); ExitWeight1 = OrigLoopExitWeight - ExitWeight0; EnterWeight = ExitWeight1; + assert(OrigLoopBackedgeWeight >= EnterWeight && "Bad branch weight"); LoopBackWeight = OrigLoopBackedgeWeight - EnterWeight; } else if (OrigLoopExitWeight == 0) { if (OrigLoopBackedgeWeight == 0) { @@ -380,13 +390,13 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI, SuccsSwapped ? LoopBackWeight : ExitWeight1, SuccsSwapped ? ExitWeight1 : LoopBackWeight, }; - setBranchWeights(LoopBI, LoopBIWeights); + setBranchWeights(LoopBI, LoopBIWeights, /*IsExpected=*/false); if (HasConditionalPreHeader) { const uint32_t PreHeaderBIWeights[] = { SuccsSwapped ? EnterWeight : ExitWeight0, SuccsSwapped ? ExitWeight0 : EnterWeight, }; - setBranchWeights(PreHeaderBI, PreHeaderBIWeights); + setBranchWeights(PreHeaderBI, PreHeaderBIWeights, /*IsExpected=*/false); } } @@ -450,7 +460,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { L->dump()); return Rotated; } - if (Metrics.convergent) { + if (Metrics.Convergence != ConvergenceKind::None) { LLVM_DEBUG(dbgs() << "LoopRotation: NOT rotating - contains convergent " "instructions: "; L->dump()); @@ -552,20 +562,22 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { for (Instruction &I : llvm::drop_begin(llvm::reverse(*OrigPreheader))) { if (auto *DII = dyn_cast<DbgVariableIntrinsic>(&I)) { DbgIntrinsics.insert(makeHash(DII)); - // Until RemoveDIs supports dbg.declares in DPValue format, we'll need - // to collect DPValues attached to any other debug intrinsics. - for (const DPValue &DPV : DII->getDbgValueRange()) - DbgIntrinsics.insert(makeHash(&DPV)); + // Until RemoveDIs supports dbg.declares in DbgVariableRecord format, + // we'll need to collect DbgVariableRecords attached to any other debug + // intrinsics. + for (const DbgVariableRecord &DVR : + filterDbgVars(DII->getDbgRecordRange())) + DbgIntrinsics.insert(makeHash(&DVR)); } else { break; } } - // Build DPValue hashes for DPValues attached to the terminator, which isn't - // considered in the loop above. - for (const DPValue &DPV : - OrigPreheader->getTerminator()->getDbgValueRange()) - DbgIntrinsics.insert(makeHash(&DPV)); + // Build DbgVariableRecord hashes for DbgVariableRecords attached to the + // terminator, which isn't considered in the loop above. + for (const DbgVariableRecord &DVR : + filterDbgVars(OrigPreheader->getTerminator()->getDbgRecordRange())) + DbgIntrinsics.insert(makeHash(&DVR)); // Remember the local noalias scope declarations in the header. After the // rotation, they must be duplicated and the scope must be cloned. This @@ -577,26 +589,29 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { Module *M = OrigHeader->getModule(); - // Track the next DPValue to clone. If we have a sequence where an + // Track the next DbgRecord to clone. 
If we have a sequence where an // instruction is hoisted instead of being cloned: - // DPValue blah + // DbgRecord blah // %foo = add i32 0, 0 - // DPValue xyzzy + // DbgRecord xyzzy // %bar = call i32 @foobar() - // where %foo is hoisted, then the DPValue "blah" will be seen twice, once + // where %foo is hoisted, then the DbgRecord "blah" will be seen twice, once // attached to %foo, then when %foo his hoisted it will "fall down" onto the // function call: - // DPValue blah - // DPValue xyzzy + // DbgRecord blah + // DbgRecord xyzzy // %bar = call i32 @foobar() // causing it to appear attached to the call too. // // To avoid this, cloneDebugInfoFrom takes an optional "start cloning from - // here" position to account for this behaviour. We point it at any DPValues - // on the next instruction, here labelled xyzzy, before we hoist %foo. - // Later, we only only clone DPValues from that position (xyzzy) onwards, - // which avoids cloning DPValue "blah" multiple times. - std::optional<DPValue::self_iterator> NextDbgInst = std::nullopt; + // here" position to account for this behaviour. We point it at any + // DbgRecords on the next instruction, here labelled xyzzy, before we hoist + // %foo. Later, we only only clone DbgRecords from that position (xyzzy) + // onwards, which avoids cloning DbgRecord "blah" multiple times. (Stored as + // a range because it gives us a natural way of testing whether + // there were DbgRecords on the next instruction before we hoisted things). + iterator_range<DbgRecord::self_iterator> NextDbgInsts = + (I != E) ? I->getDbgRecordRange() : DbgMarker::getEmptyDbgRecordRange(); while (I != E) { Instruction *Inst = &*I++; @@ -609,20 +624,32 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // memory (without proving that the loop doesn't write). if (L->hasLoopInvariantOperands(Inst) && !Inst->mayReadFromMemory() && !Inst->mayWriteToMemory() && !Inst->isTerminator() && - !isa<DbgInfoIntrinsic>(Inst) && !isa<AllocaInst>(Inst)) { - - if (LoopEntryBranch->getParent()->IsNewDbgInfoFormat) { + !isa<DbgInfoIntrinsic>(Inst) && !isa<AllocaInst>(Inst) && + // It is not safe to hoist the value of these instructions in + // coroutines, as the addresses of otherwise eligible variables (e.g. + // thread-local variables and errno) may change if the coroutine is + // resumed in a different thread.Therefore, we disable this + // optimization for correctness. However, this may block other correct + // optimizations. + // FIXME: This should be reverted once we have a better model for + // memory access in coroutines. + !Inst->getFunction()->isPresplitCoroutine()) { + + if (LoopEntryBranch->getParent()->IsNewDbgInfoFormat && + !NextDbgInsts.empty()) { auto DbgValueRange = - LoopEntryBranch->cloneDebugInfoFrom(Inst, NextDbgInst); - RemapDPValueRange(M, DbgValueRange, ValueMap, - RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + LoopEntryBranch->cloneDebugInfoFrom(Inst, NextDbgInsts.begin()); + RemapDbgRecordRange(M, DbgValueRange, ValueMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); // Erase anything we've seen before. 
- for (DPValue &DPV : make_early_inc_range(DbgValueRange)) - if (DbgIntrinsics.count(makeHash(&DPV))) - DPV.eraseFromParent(); + for (DbgVariableRecord &DVR : + make_early_inc_range(filterDbgVars(DbgValueRange))) + if (DbgIntrinsics.count(makeHash(&DVR))) + DVR.eraseFromParent(); } - NextDbgInst = I->getDbgValueRange().begin(); + NextDbgInsts = I->getDbgRecordRange(); + Inst->moveBefore(LoopEntryBranch); ++NumInstrsHoisted; @@ -635,15 +662,17 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { ++NumInstrsDuplicated; - if (LoopEntryBranch->getParent()->IsNewDbgInfoFormat) { - auto Range = C->cloneDebugInfoFrom(Inst, NextDbgInst); - RemapDPValueRange(M, Range, ValueMap, - RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); - NextDbgInst = std::nullopt; + if (LoopEntryBranch->getParent()->IsNewDbgInfoFormat && + !NextDbgInsts.empty()) { + auto Range = C->cloneDebugInfoFrom(Inst, NextDbgInsts.begin()); + RemapDbgRecordRange(M, Range, ValueMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + NextDbgInsts = DbgMarker::getEmptyDbgRecordRange(); // Erase anything we've seen before. - for (DPValue &DPV : make_early_inc_range(Range)) - if (DbgIntrinsics.count(makeHash(&DPV))) - DPV.eraseFromParent(); + for (DbgVariableRecord &DVR : + make_early_inc_range(filterDbgVars(Range))) + if (DbgIntrinsics.count(makeHash(&DVR))) + DVR.eraseFromParent(); } // Eagerly remap the operands of the instruction. @@ -761,7 +790,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // OrigPreHeader's old terminator (the original branch into the loop), and // remove the corresponding incoming values from the PHI nodes in OrigHeader. LoopEntryBranch->eraseFromParent(); - OrigPreheader->flushTerminatorDbgValues(); + OrigPreheader->flushTerminatorDbgRecords(); // Update MemorySSA before the rewrite call below changes the 1:1 // instruction:cloned_instruction_or_value mapping. @@ -858,7 +887,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // We can fold the conditional branch in the preheader, this makes things // simpler. The first step is to remove the extra edge to the Exit block. Exit->removePredecessor(OrigPreheader, true /*preserve LCSSA*/); - BranchInst *NewBI = BranchInst::Create(NewHeader, PHBI); + BranchInst *NewBI = BranchInst::Create(NewHeader, PHBI->getIterator()); NewBI->setDebugLoc(PHBI->getDebugLoc()); PHBI->eraseFromParent(); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopSimplify.cpp index 07e622b1577f..a764fef57491 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopSimplify.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopSimplify.cpp @@ -83,8 +83,8 @@ static void placeSplitBlockCarefully(BasicBlock *NewBB, Loop *L) { // Check to see if NewBB is already well placed. Function::iterator BBI = --NewBB->getIterator(); - for (unsigned i = 0, e = SplitPreds.size(); i != e; ++i) { - if (&*BBI == SplitPreds[i]) + for (BasicBlock *Pred : SplitPreds) { + if (&*BBI == Pred) return; } @@ -95,10 +95,10 @@ static void placeSplitBlockCarefully(BasicBlock *NewBB, // Figure out *which* outside block to put this after. Prefer an outside // block that neighbors a BB actually in the loop. 
BasicBlock *FoundBB = nullptr; - for (unsigned i = 0, e = SplitPreds.size(); i != e; ++i) { - Function::iterator BBI = SplitPreds[i]->getIterator(); + for (BasicBlock *Pred : SplitPreds) { + Function::iterator BBI = Pred->getIterator(); if (++BBI != NewBB->getParent()->end() && L->contains(&*BBI)) { - FoundBB = SplitPreds[i]; + FoundBB = Pred; break; } } @@ -172,7 +172,7 @@ static void addBlockAndPredsToSet(BasicBlock *InputBB, BasicBlock *StopBlock, /// us how to partition the loops. static PHINode *findPHIToPartitionLoops(Loop *L, DominatorTree *DT, AssumptionCache *AC) { - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + const DataLayout &DL = L->getHeader()->getDataLayout(); for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ) { PHINode *PN = cast<PHINode>(I); ++I; @@ -399,7 +399,7 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) { PHINode *PN = cast<PHINode>(I); PHINode *NewPN = PHINode::Create(PN->getType(), BackedgeBlocks.size(), - PN->getName()+".be", BETerminator); + PN->getName()+".be", BETerminator->getIterator()); // Loop over the PHI node, moving all entries except the one for the // preheader over to the new PHI node. @@ -588,7 +588,7 @@ ReprocessLoop: if (MSSAU && VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + const DataLayout &DL = L->getHeader()->getDataLayout(); // Scan over the PHI nodes in the loop header. Since they now have only two // incoming values (the loop is canonicalized), we may have simplified the PHI @@ -630,8 +630,7 @@ ReprocessLoop: return true; }; if (HasUniqueExitBlock()) { - for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { - BasicBlock *ExitingBlock = ExitingBlocks[i]; + for (BasicBlock *ExitingBlock : ExitingBlocks) { if (!ExitingBlock->getSinglePredecessor()) continue; BranchInst *BI = dyn_cast<BranchInst>(ExitingBlock->getTerminator()); if (!BI || !BI->isConditional()) continue; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnroll.cpp index ee6f7b35750a..a0406111ecbf 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -18,17 +18,20 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/ilist_iterator.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/BasicBlock.h" @@ -209,13 +212,140 @@ static bool isEpilogProfitable(Loop *L) { return false; } +struct LoadValue { + Instruction *DefI = nullptr; + unsigned Generation = 0; + LoadValue() = default; + LoadValue(Instruction *Inst, unsigned Generation) + : DefI(Inst), Generation(Generation) {} +}; + +class StackNode { + ScopedHashTable<const SCEV *, LoadValue>::ScopeTy LoadScope; + unsigned 
CurrentGeneration; + unsigned ChildGeneration; + DomTreeNode *Node; + DomTreeNode::const_iterator ChildIter; + DomTreeNode::const_iterator EndIter; + bool Processed = false; + +public: + StackNode(ScopedHashTable<const SCEV *, LoadValue> &AvailableLoads, + unsigned cg, DomTreeNode *N, DomTreeNode::const_iterator Child, + DomTreeNode::const_iterator End) + : LoadScope(AvailableLoads), CurrentGeneration(cg), ChildGeneration(cg), + Node(N), ChildIter(Child), EndIter(End) {} + // Accessors. + unsigned currentGeneration() const { return CurrentGeneration; } + unsigned childGeneration() const { return ChildGeneration; } + void childGeneration(unsigned generation) { ChildGeneration = generation; } + DomTreeNode *node() { return Node; } + DomTreeNode::const_iterator childIter() const { return ChildIter; } + + DomTreeNode *nextChild() { + DomTreeNode *Child = *ChildIter; + ++ChildIter; + return Child; + } + + DomTreeNode::const_iterator end() const { return EndIter; } + bool isProcessed() const { return Processed; } + void process() { Processed = true; } +}; + +Value *getMatchingValue(LoadValue LV, LoadInst *LI, unsigned CurrentGeneration, + BatchAAResults &BAA, + function_ref<MemorySSA *()> GetMSSA) { + if (!LV.DefI) + return nullptr; + if (LV.DefI->getType() != LI->getType()) + return nullptr; + if (LV.Generation != CurrentGeneration) { + MemorySSA *MSSA = GetMSSA(); + if (!MSSA) + return nullptr; + auto *EarlierMA = MSSA->getMemoryAccess(LV.DefI); + MemoryAccess *LaterDef = + MSSA->getWalker()->getClobberingMemoryAccess(LI, BAA); + if (!MSSA->dominates(LaterDef, EarlierMA)) + return nullptr; + } + return LV.DefI; +} + +void loadCSE(Loop *L, DominatorTree &DT, ScalarEvolution &SE, LoopInfo &LI, + BatchAAResults &BAA, function_ref<MemorySSA *()> GetMSSA) { + ScopedHashTable<const SCEV *, LoadValue> AvailableLoads; + SmallVector<std::unique_ptr<StackNode>> NodesToProcess; + DomTreeNode *HeaderD = DT.getNode(L->getHeader()); + NodesToProcess.emplace_back(new StackNode(AvailableLoads, 0, HeaderD, + HeaderD->begin(), HeaderD->end())); + + unsigned CurrentGeneration = 0; + while (!NodesToProcess.empty()) { + StackNode *NodeToProcess = &*NodesToProcess.back(); + + CurrentGeneration = NodeToProcess->currentGeneration(); + + if (!NodeToProcess->isProcessed()) { + // Process the node. + + // If this block has a single predecessor, then the predecessor is the + // parent + // of the domtree node and all of the live out memory values are still + // current in this block. If this block has multiple predecessors, then + // they could have invalidated the live-out memory values of our parent + // value. For now, just be conservative and invalidate memory if this + // block has multiple predecessors. 
+ if (!NodeToProcess->node()->getBlock()->getSinglePredecessor()) + ++CurrentGeneration; + for (auto &I : make_early_inc_range(*NodeToProcess->node()->getBlock())) { + + auto *Load = dyn_cast<LoadInst>(&I); + if (!Load || !Load->isSimple()) { + if (I.mayWriteToMemory()) + CurrentGeneration++; + continue; + } + + const SCEV *PtrSCEV = SE.getSCEV(Load->getPointerOperand()); + LoadValue LV = AvailableLoads.lookup(PtrSCEV); + if (Value *M = + getMatchingValue(LV, Load, CurrentGeneration, BAA, GetMSSA)) { + if (LI.replacementPreservesLCSSAForm(Load, M)) { + Load->replaceAllUsesWith(M); + Load->eraseFromParent(); + } + } else { + AvailableLoads.insert(PtrSCEV, LoadValue(Load, CurrentGeneration)); + } + } + NodeToProcess->childGeneration(CurrentGeneration); + NodeToProcess->process(); + } else if (NodeToProcess->childIter() != NodeToProcess->end()) { + // Push the next child onto the stack. + DomTreeNode *Child = NodeToProcess->nextChild(); + if (!L->contains(Child->getBlock())) + continue; + NodesToProcess.emplace_back( + new StackNode(AvailableLoads, NodeToProcess->childGeneration(), Child, + Child->begin(), Child->end())); + } else { + // It has been processed, and there are no more children to process, + // so delete it and pop it off the stack. + NodesToProcess.pop_back(); + } + } +} + /// Perform some cleanup and simplifications on loops after unrolling. It is /// useful to simplify the IV's in the new loop, as well as do a quick /// simplify/dce pass of the instructions. void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC, - const TargetTransformInfo *TTI) { + const TargetTransformInfo *TTI, + AAResults *AA) { using namespace llvm::PatternMatch; // Simplify any new induction variables in the partially unrolled loop. @@ -230,13 +360,27 @@ void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI, if (Instruction *Inst = dyn_cast_or_null<Instruction>(V)) RecursivelyDeleteTriviallyDeadInstructions(Inst); } + + if (AA) { + std::unique_ptr<MemorySSA> MSSA = nullptr; + BatchAAResults BAA(*AA); + loadCSE(L, *DT, *SE, *LI, BAA, [L, AA, DT, &MSSA]() -> MemorySSA * { + if (!MSSA) + MSSA.reset(new MemorySSA(*L, AA, DT)); + return &*MSSA; + }); + } } // At this point, the code is well formed. Perform constprop, instsimplify, // and dce. - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + const DataLayout &DL = L->getHeader()->getDataLayout(); SmallVector<WeakTrackingVH, 16> DeadInsts; for (BasicBlock *BB : L->getBlocks()) { + // Remove repeated debug instructions after loop unrolling. + if (BB->getParent()->getSubprogram()) + RemoveRedundantDbgInstrs(BB); + for (Instruction &Inst : llvm::make_early_inc_range(*BB)) { if (Value *V = simplifyInstruction(&Inst, {DL, nullptr, DT, AC})) if (LI->replacementPreservesLCSSAForm(&Inst, V)) @@ -275,6 +419,26 @@ void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI, } } +// Loops containing convergent instructions that are uncontrolled or controlled +// from outside the loop must have a count that divides their TripMultiple. +LLVM_ATTRIBUTE_USED +static bool canHaveUnrollRemainder(const Loop *L) { + if (getLoopConvergenceHeart(L)) + return false; + + // Check for uncontrolled convergent operations. 
+ for (auto &BB : L->blocks()) { + for (auto &I : *BB) { + if (isa<ConvergenceControlInst>(I)) + return true; + if (auto *CB = dyn_cast<CallBase>(&I)) + if (CB->isConvergent()) + return CB->getConvergenceControlToken(); + } + } + return true; +} + /// Unroll the given loop by Count. The loop must be in LCSSA form. Unrolling /// can only fail when the loop's latch block is not terminated by a conditional /// branch instruction. However, if the trip count (and multiple) are not known, @@ -292,12 +456,11 @@ void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI, /// /// If RemainderLoop is non-null, it will receive the remainder loop (if /// required and not fully unrolled). -LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, - ScalarEvolution *SE, DominatorTree *DT, - AssumptionCache *AC, - const TargetTransformInfo *TTI, - OptimizationRemarkEmitter *ORE, - bool PreserveLCSSA, Loop **RemainderLoop) { +LoopUnrollResult +llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, + ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC, + const TargetTransformInfo *TTI, OptimizationRemarkEmitter *ORE, + bool PreserveLCSSA, Loop **RemainderLoop, AAResults *AA) { assert(DT && "DomTree is required"); if (!L->getLoopPreheader()) { @@ -421,19 +584,8 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, return LoopUnrollResult::Unmodified; } - // Loops containing convergent instructions cannot use runtime unrolling, - // as the prologue/epilogue may add additional control-dependencies to - // convergent operations. - LLVM_DEBUG( - { - bool HasConvergent = false; - for (auto &BB : L->blocks()) - for (auto &I : *BB) - if (auto *CB = dyn_cast<CallBase>(&I)) - HasConvergent |= CB->isConvergent(); - assert((!HasConvergent || !ULO.Runtime) && - "Can't runtime unroll if loop contains a convergent operation."); - }); + assert((!ULO.Runtime || canHaveUnrollRemainder(L)) && + "Can't runtime unroll if loop contains a convergent operation."); bool EpilogProfitability = UnrollRuntimeEpilog.getNumOccurrences() ? UnrollRuntimeEpilog @@ -579,7 +731,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, if (OldLoop) LoopsToSimplify.insert(NewLoops[OldLoop]); - if (*BB == Header) + if (*BB == Header) { // Loop over all of the PHI nodes in the block, changing them to use // the incoming values from the previous block. for (PHINode *OrigPHI : OrigPHINode) { @@ -592,6 +744,16 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, NewPHI->eraseFromParent(); } + // Eliminate copies of the loop heart intrinsic, if any. + if (ULO.Heart) { + auto it = VMap.find(ULO.Heart); + assert(it != VMap.end()); + Instruction *heartCopy = cast<Instruction>(it->second); + heartCopy->eraseFromParent(); + VMap.erase(it); + } + } + // Update our running map of newest clones LastValueMap[*BB] = New; for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end(); @@ -721,7 +883,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, DeadSucc->removePredecessor(Src, /* KeepOneInputPHIs */ true); // Replace the conditional branch with an unconditional one. - BranchInst::Create(Dest, Term); + BranchInst::Create(Dest, Term->getIterator()); Term->eraseFromParent(); DTUpdates.emplace_back(DominatorTree::Delete, Src, DeadSucc); @@ -852,7 +1014,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, // At this point, the code is well formed. 
We now simplify the unrolled loop, // doing constant propagation and dead code elimination as we go. simplifyLoopAfterUnroll(L, !CompletelyUnroll && ULO.Count > 1, LI, SE, DT, AC, - TTI); + TTI, AA); NumCompletelyUnrolled += CompletelyUnroll; ++NumUnrolled; @@ -929,8 +1091,8 @@ MDNode *llvm::GetUnrollMetadata(MDNode *LoopID, StringRef Name) { assert(LoopID->getNumOperands() > 0 && "requires at least one operand"); assert(LoopID->getOperand(0) == LoopID && "invalid loop id"); - for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) { - MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); + for (const MDOperand &MDO : llvm::drop_begin(LoopID->operands())) { + MDNode *MD = dyn_cast<MDNode>(MDO); if (!MD) continue; @@ -938,7 +1100,7 @@ MDNode *llvm::GetUnrollMetadata(MDNode *LoopID, StringRef Name) { if (!S) continue; - if (Name.equals(S->getString())) + if (Name == S->getString()) return MD; } return nullptr; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp index 3c06a6e47a30..c7b88d3c48a6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp @@ -473,9 +473,9 @@ llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount, }; // Move all the phis from Src into Dest auto movePHIs = [](BasicBlock *Src, BasicBlock *Dest) { - Instruction *insertPoint = Dest->getFirstNonPHI(); + BasicBlock::iterator insertPoint = Dest->getFirstNonPHIIt(); while (PHINode *Phi = dyn_cast<PHINode>(Src->begin())) - Phi->moveBefore(insertPoint); + Phi->moveBefore(*Dest, insertPoint); }; // Update the PHI values outside the loop to point to the last block @@ -522,7 +522,7 @@ llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount, // unconditional one to this one BranchInst *SubTerm = cast<BranchInst>(SubLoopBlocksLast[It - 1]->getTerminator()); - BranchInst::Create(SubLoopBlocksFirst[It], SubTerm); + BranchInst::Create(SubLoopBlocksFirst[It], SubTerm->getIterator()); SubTerm->eraseFromParent(); SubLoopBlocksFirst[It]->replacePhiUsesWith(ForeBlocksLast[It], @@ -535,7 +535,7 @@ llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount, // Aft blocks successors and phis BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast.back()->getTerminator()); if (CompletelyUnroll) { - BranchInst::Create(LoopExit, AftTerm); + BranchInst::Create(LoopExit, AftTerm->getIterator()); AftTerm->eraseFromParent(); } else { AftTerm->setSuccessor(!ContinueOnTrue, ForeBlocksFirst[0]); @@ -550,7 +550,7 @@ llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount, // unconditional one to this one BranchInst *AftTerm = cast<BranchInst>(AftBlocksLast[It - 1]->getTerminator()); - BranchInst::Create(AftBlocksFirst[It], AftTerm); + BranchInst::Create(AftBlocksFirst[It], AftTerm->getIterator()); AftTerm->eraseFromParent(); AftBlocksFirst[It]->replacePhiUsesWith(SubLoopBlocksLast[It], diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index 612f69970881..56aa96e550d9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -126,7 +126,7 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, PreHeader); } else { // Succ is LatchExit. 
- NewPN->addIncoming(UndefValue::get(PN.getType()), PreHeader); + NewPN->addIncoming(PoisonValue::get(PN.getType()), PreHeader); } Value *V = PN.getIncomingValueForBlock(Latch); @@ -253,7 +253,7 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, assert(EpilogPN->getParent() == Exit && "EpilogPN should be in Exit block"); // Add incoming PreHeader from branch around the Loop - PN.addIncoming(UndefValue::get(PN.getType()), PreHeader); + PN.addIncoming(PoisonValue::get(PN.getType()), PreHeader); SE.forgetValue(&PN); Value *V = PN.getIncomingValueForBlock(Latch); @@ -272,7 +272,7 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, NewExit); // Now PHIs should look like: // NewExit: - // PN = PHI [I, Latch], [undef, PreHeader] + // PN = PHI [I, Latch], [poison, PreHeader] // ... // Exit: // EpilogPN = PHI [PN, NewExit], [VMap[I], EpilogLatch] @@ -670,7 +670,7 @@ bool llvm::UnrollRuntimeLoopRemainder( BasicBlock *PreHeader = L->getLoopPreheader(); BranchInst *PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator()); - const DataLayout &DL = Header->getModule()->getDataLayout(); + const DataLayout &DL = Header->getDataLayout(); SCEVExpander Expander(*SE, DL, "loop-unroll"); if (!AllowExpensiveTripCount && Expander.isHighCostExpansion(TripCountSC, L, SCEVCheapExpansionBudget, @@ -776,7 +776,7 @@ bool llvm::UnrollRuntimeLoopRemainder( !isGuaranteedNotToBeUndefOrPoison(TripCount, AC, PreHeaderBR, DT)) { TripCount = B.CreateFreeze(TripCount); BECount = - B.CreateAdd(TripCount, ConstantInt::get(TripCount->getType(), -1)); + B.CreateAdd(TripCount, Constant::getAllOnesValue(TripCount->getType())); } else { // If we don't need to freeze, use SCEVExpander for BECount as well, to // allow slightly better value reuse. @@ -849,7 +849,7 @@ bool llvm::UnrollRuntimeLoopRemainder( for (unsigned i = 0; i < oldNumOperands; i++){ auto *PredBB =PN.getIncomingBlock(i); if (PredBB == Latch) - // The latch exit is handled seperately, see connectX + // The latch exit is handled separately, see connectX continue; if (!L->contains(PredBB)) // Even if we had dedicated exits, the code above inserted an @@ -917,8 +917,8 @@ bool llvm::UnrollRuntimeLoopRemainder( for (Instruction &I : *BB) { RemapInstruction(&I, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); - RemapDPValueRange(M, I.getDbgValueRange(), VMap, - RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + RemapDbgRecordRange(M, I.getDbgRecordRange(), VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); } } @@ -977,7 +977,7 @@ bool llvm::UnrollRuntimeLoopRemainder( remainderLoop = nullptr; // Simplify loop values after breaking the backedge - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + const DataLayout &DL = L->getHeader()->getDataLayout(); SmallVector<WeakTrackingVH, 16> DeadInsts; for (BasicBlock *BB : RemainderBlocks) { for (Instruction &Inst : llvm::make_early_inc_range(*BB)) { @@ -1016,12 +1016,17 @@ bool llvm::UnrollRuntimeLoopRemainder( auto UnrollResult = LoopUnrollResult::Unmodified; if (remainderLoop && UnrollRemainder) { LLVM_DEBUG(dbgs() << "Unrolling remainder loop\n"); - UnrollResult = - UnrollLoop(remainderLoop, - {/*Count*/ Count - 1, /*Force*/ false, /*Runtime*/ false, - /*AllowExpensiveTripCount*/ false, - /*UnrollRemainder*/ false, ForgetAllSCEV}, - LI, SE, DT, AC, TTI, /*ORE*/ nullptr, PreserveLCSSA); + UnrollLoopOptions ULO; + ULO.Count = Count - 1; + ULO.Force = false; + ULO.Runtime = false; + ULO.AllowExpensiveTripCount = false; + ULO.UnrollRemainder = false; 
+ ULO.ForgetAllSCEV = ForgetAllSCEV; + assert(!getLoopConvergenceHeart(L) && + "A loop with a convergence heart does not allow runtime unrolling."); + UnrollResult = UnrollLoop(remainderLoop, ULO, LI, SE, DT, AC, TTI, + /*ORE*/ nullptr, PreserveLCSSA); } if (ResultLoop && UnrollResult != LoopUnrollResult::FullyUnrolled) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp index 59485126b280..0abf6d77496d 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -222,7 +222,7 @@ void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *StringMD, // If it is of form key = value, try to parse it. if (Node->getNumOperands() == 2) { MDString *S = dyn_cast<MDString>(Node->getOperand(0)); - if (S && S->getString().equals(StringMD)) { + if (S && S->getString() == StringMD) { ConstantInt *IntMD = mdconst::extract_or_null<ConstantInt>(Node->getOperand(1)); if (IntMD && IntMD->getSExtValue() == V) @@ -468,6 +468,7 @@ llvm::collectChildrenInLoop(DomTreeNode *N, const Loop *CurLoop) { bool llvm::isAlmostDeadIV(PHINode *PN, BasicBlock *LatchBlock, Value *Cond) { int LatchIdx = PN->getBasicBlockIndex(LatchBlock); + assert(LatchIdx != -1 && "LatchBlock is not a case in this PHINode"); Value *IncV = PN->getIncomingValue(LatchIdx); for (User *U : PN->users()) @@ -604,7 +605,7 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, // Use a map to unique and a vector to guarantee deterministic ordering. llvm::SmallDenseSet<DebugVariable, 4> DeadDebugSet; llvm::SmallVector<DbgVariableIntrinsic *, 4> DeadDebugInst; - llvm::SmallVector<DPValue *, 4> DeadDPValues; + llvm::SmallVector<DbgVariableRecord *, 4> DeadDbgVariableRecords; if (ExitBlock) { // Given LCSSA form is satisfied, we should not have users of instructions @@ -630,17 +631,17 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, U.set(Poison); } - // RemoveDIs: do the same as below for DPValues. + // RemoveDIs: do the same as below for DbgVariableRecords. if (Block->IsNewDbgInfoFormat) { - for (DPValue &DPV : - llvm::make_early_inc_range(I.getDbgValueRange())) { - DebugVariable Key(DPV.getVariable(), DPV.getExpression(), - DPV.getDebugLoc().get()); + for (DbgVariableRecord &DVR : llvm::make_early_inc_range( + filterDbgVars(I.getDbgRecordRange()))) { + DebugVariable Key(DVR.getVariable(), DVR.getExpression(), + DVR.getDebugLoc().get()); if (!DeadDebugSet.insert(Key).second) continue; - // Unlinks the DPV from it's container, for later insertion. - DPV.removeFromParent(); - DeadDPValues.push_back(&DPV); + // Unlinks the DVR from it's container, for later insertion. + DVR.removeFromParent(); + DeadDbgVariableRecords.push_back(&DVR); } } @@ -672,11 +673,11 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, DVI->moveBefore(*ExitBlock, InsertDbgValueBefore); // Due to the "head" bit in BasicBlock::iterator, we're going to insert - // each DPValue right at the start of the block, wheras dbg.values would be - // repeatedly inserted before the first instruction. To replicate this - // behaviour, do it backwards. - for (DPValue *DPV : llvm::reverse(DeadDPValues)) - ExitBlock->insertDPValueBefore(DPV, InsertDbgValueBefore); + // each DbgVariableRecord right at the start of the block, wheras dbg.values + // would be repeatedly inserted before the first instruction. To replicate + // this behaviour, do it backwards. 
+ for (DbgVariableRecord *DVR : llvm::reverse(DeadDbgVariableRecords)) + ExitBlock->insertDbgRecordBefore(DVR, InsertDbgValueBefore); } // Remove the block from the reference counting scheme, so that we can @@ -917,6 +918,96 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop, return true; } +constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) { + switch (RK) { + default: + llvm_unreachable("Unexpected recurrence kind"); + case RecurKind::Add: + return Intrinsic::vector_reduce_add; + case RecurKind::Mul: + return Intrinsic::vector_reduce_mul; + case RecurKind::And: + return Intrinsic::vector_reduce_and; + case RecurKind::Or: + return Intrinsic::vector_reduce_or; + case RecurKind::Xor: + return Intrinsic::vector_reduce_xor; + case RecurKind::FMulAdd: + case RecurKind::FAdd: + return Intrinsic::vector_reduce_fadd; + case RecurKind::FMul: + return Intrinsic::vector_reduce_fmul; + case RecurKind::SMax: + return Intrinsic::vector_reduce_smax; + case RecurKind::SMin: + return Intrinsic::vector_reduce_smin; + case RecurKind::UMax: + return Intrinsic::vector_reduce_umax; + case RecurKind::UMin: + return Intrinsic::vector_reduce_umin; + case RecurKind::FMax: + return Intrinsic::vector_reduce_fmax; + case RecurKind::FMin: + return Intrinsic::vector_reduce_fmin; + case RecurKind::FMaximum: + return Intrinsic::vector_reduce_fmaximum; + case RecurKind::FMinimum: + return Intrinsic::vector_reduce_fminimum; + } +} + +unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) { + switch (RdxID) { + case Intrinsic::vector_reduce_fadd: + return Instruction::FAdd; + case Intrinsic::vector_reduce_fmul: + return Instruction::FMul; + case Intrinsic::vector_reduce_add: + return Instruction::Add; + case Intrinsic::vector_reduce_mul: + return Instruction::Mul; + case Intrinsic::vector_reduce_and: + return Instruction::And; + case Intrinsic::vector_reduce_or: + return Instruction::Or; + case Intrinsic::vector_reduce_xor: + return Instruction::Xor; + case Intrinsic::vector_reduce_smax: + case Intrinsic::vector_reduce_smin: + case Intrinsic::vector_reduce_umax: + case Intrinsic::vector_reduce_umin: + return Instruction::ICmp; + case Intrinsic::vector_reduce_fmax: + case Intrinsic::vector_reduce_fmin: + return Instruction::FCmp; + default: + llvm_unreachable("Unexpected ID"); + } +} + +Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID) { + switch (RdxID) { + default: + llvm_unreachable("Unknown min/max recurrence kind"); + case Intrinsic::vector_reduce_umin: + return Intrinsic::umin; + case Intrinsic::vector_reduce_umax: + return Intrinsic::umax; + case Intrinsic::vector_reduce_smin: + return Intrinsic::smin; + case Intrinsic::vector_reduce_smax: + return Intrinsic::smax; + case Intrinsic::vector_reduce_fmin: + return Intrinsic::minnum; + case Intrinsic::vector_reduce_fmax: + return Intrinsic::maxnum; + case Intrinsic::vector_reduce_fminimum: + return Intrinsic::minimum; + case Intrinsic::vector_reduce_fmaximum: + return Intrinsic::maximum; + } +} + Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(RecurKind RK) { switch (RK) { default: @@ -940,6 +1031,25 @@ Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(RecurKind RK) { } } +RecurKind llvm::getMinMaxReductionRecurKind(Intrinsic::ID RdxID) { + switch (RdxID) { + case Intrinsic::vector_reduce_smax: + return RecurKind::SMax; + case Intrinsic::vector_reduce_smin: + return RecurKind::SMin; + case Intrinsic::vector_reduce_umax: + return RecurKind::UMax; + case Intrinsic::vector_reduce_umin: + return 
RecurKind::UMin; + case Intrinsic::vector_reduce_fmax: + return RecurKind::FMax; + case Intrinsic::vector_reduce_fmin: + return RecurKind::FMin; + default: + return RecurKind::None; + } +} + CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) { switch (RK) { default: @@ -962,15 +1072,6 @@ CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) { } } -Value *llvm::createAnyOfOp(IRBuilderBase &Builder, Value *StartVal, - RecurKind RK, Value *Left, Value *Right) { - if (auto VTy = dyn_cast<VectorType>(Left->getType())) - StartVal = Builder.CreateVectorSplat(VTy->getElementCount(), StartVal); - Value *Cmp = - Builder.CreateCmp(CmpInst::ICMP_NE, Left, StartVal, "rdx.select.cmp"); - return Builder.CreateSelect(Cmp, Left, Right, "rdx.select"); -} - Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left, Value *Right) { Type *Ty = Left->getType(); @@ -1014,7 +1115,9 @@ Value *llvm::getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src, // Helper to generate a log2 shuffle reduction. Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, - unsigned Op, RecurKind RdxKind) { + unsigned Op, + TargetTransformInfo::ReductionShuffle RS, + RecurKind RdxKind) { unsigned VF = cast<FixedVectorType>(Src->getType())->getNumElements(); // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles // and vector ops, reducing the set of values being computed by half each @@ -1028,18 +1131,10 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, // will never be relevant here. Note that it would be generally unsound to // propagate these from an intrinsic call to the expansion anyways as we/ // change the order of operations. - Value *TmpVec = Src; - SmallVector<int, 32> ShuffleMask(VF); - for (unsigned i = VF; i != 1; i >>= 1) { - // Move the upper half of the vector to the lower half. - for (unsigned j = 0; j != i / 2; ++j) - ShuffleMask[j] = i / 2 + j; - - // Fill the rest of the mask with undef. - std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), -1); - + auto BuildShuffledOp = [&Builder, &Op, + &RdxKind](SmallVectorImpl<int> &ShuffleMask, + Value *&TmpVec) -> void { Value *Shuf = Builder.CreateShuffleVector(TmpVec, ShuffleMask, "rdx.shuf"); - if (Op != Instruction::ICmp && Op != Instruction::FCmp) { TmpVec = Builder.CreateBinOp((Instruction::BinaryOps)Op, TmpVec, Shuf, "bin.rdx"); @@ -1048,6 +1143,30 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, "Invalid min/max"); TmpVec = createMinMaxOp(Builder, RdxKind, TmpVec, Shuf); } + }; + + Value *TmpVec = Src; + if (TargetTransformInfo::ReductionShuffle::Pairwise == RS) { + SmallVector<int, 32> ShuffleMask(VF); + for (unsigned stride = 1; stride < VF; stride <<= 1) { + // Initialise the mask with undef. + std::fill(ShuffleMask.begin(), ShuffleMask.end(), -1); + for (unsigned j = 0; j < VF; j += stride << 1) { + ShuffleMask[j] = j + stride; + } + BuildShuffledOp(ShuffleMask, TmpVec); + } + } else { + SmallVector<int, 32> ShuffleMask(VF); + for (unsigned i = VF; i != 1; i >>= 1) { + // Move the upper half of the vector to the lower half. + for (unsigned j = 0; j != i / 2; ++j) + ShuffleMask[j] = i / 2 + j; + + // Fill the rest of the mask with undef. + std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), -1); + BuildShuffledOp(ShuffleMask, TmpVec); + } } // The result is in the first element of the vector. 
return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); @@ -1079,16 +1198,13 @@ Value *llvm::createAnyOfTargetReduction(IRBuilderBase &Builder, Value *Src, NewVal = SI->getTrueValue(); } - // Create a splat vector with the new value and compare this to the vector - // we want to reduce. - ElementCount EC = cast<VectorType>(Src->getType())->getElementCount(); - Value *Right = Builder.CreateVectorSplat(EC, InitVal); - Value *Cmp = - Builder.CreateCmp(CmpInst::ICMP_NE, Src, Right, "rdx.select.cmp"); - // If any predicate is true it means that we want to select the new value. - Cmp = Builder.CreateOrReduce(Cmp); - return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select"); + Value *AnyOf = + Src->getType()->isVectorTy() ? Builder.CreateOrReduce(Src) : Src; + // The compares in the loop may yield poison, which propagates through the + // bitwise ORs. Freeze it here before the condition is used. + AnyOf = Builder.CreateFreeze(AnyOf); + return Builder.CreateSelect(AnyOf, NewVal, InitVal, "rdx.select"); } Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *Src, @@ -1132,6 +1248,20 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *Src, } } +Value *llvm::createSimpleTargetReduction(VectorBuilder &VBuilder, Value *Src, + const RecurrenceDescriptor &Desc) { + RecurKind Kind = Desc.getRecurrenceKind(); + assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) && + "AnyOf reduction is not supported."); + Intrinsic::ID Id = getReductionIntrinsicID(Kind); + auto *SrcTy = cast<VectorType>(Src->getType()); + Type *SrcEltTy = SrcTy->getElementType(); + Value *Iden = + Desc.getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags()); + Value *Ops[] = {Iden, Src}; + return VBuilder.createSimpleTargetReduction(Id, SrcTy, Ops); +} + Value *llvm::createTargetReduction(IRBuilderBase &B, const RecurrenceDescriptor &Desc, Value *Src, PHINode *OrigPhi) { @@ -1160,6 +1290,21 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B, return B.CreateFAddReduce(Start, Src); } +Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, + const RecurrenceDescriptor &Desc, + Value *Src, Value *Start) { + assert((Desc.getRecurrenceKind() == RecurKind::FAdd || + Desc.getRecurrenceKind() == RecurKind::FMulAdd) && + "Unexpected reduction kind"); + assert(Src->getType()->isVectorTy() && "Expected a vector type"); + assert(!Start->getType()->isVectorTy() && "Expected a scalar type"); + + Intrinsic::ID Id = getReductionIntrinsicID(RecurKind::FAdd); + auto *SrcTy = cast<VectorType>(Src->getType()); + Value *Ops[] = {Start, Src}; + return VBuilder.createSimpleTargetReduction(Id, SrcTy, Ops); +} + void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue, bool IncludeWrapFlags) { auto *VecOp = dyn_cast<Instruction>(I); @@ -1683,16 +1828,16 @@ static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG, auto *HighAR = cast<SCEVAddRecExpr>(High); auto *LowAR = cast<SCEVAddRecExpr>(Low); const Loop *OuterLoop = TheLoop->getParentLoop(); - const SCEV *Recur = LowAR->getStepRecurrence(*Exp.getSE()); - if (Recur == HighAR->getStepRecurrence(*Exp.getSE()) && + ScalarEvolution &SE = *Exp.getSE(); + const SCEV *Recur = LowAR->getStepRecurrence(SE); + if (Recur == HighAR->getStepRecurrence(SE) && HighAR->getLoop() == OuterLoop && LowAR->getLoop() == OuterLoop) { BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch(); - const SCEV *OuterExitCount = - Exp.getSE()->getExitCount(OuterLoop, OuterLoopLatch); + const SCEV *OuterExitCount = 
SE.getExitCount(OuterLoop, OuterLoopLatch); if (!isa<SCEVCouldNotCompute>(OuterExitCount) && OuterExitCount->getType()->isIntegerTy()) { - const SCEV *NewHigh = cast<SCEVAddRecExpr>(High)->evaluateAtIteration( - OuterExitCount, *Exp.getSE()); + const SCEV *NewHigh = + cast<SCEVAddRecExpr>(High)->evaluateAtIteration(OuterExitCount, SE); if (!isa<SCEVCouldNotCompute>(NewHigh)) { LLVM_DEBUG(dbgs() << "LAA: Expanded RT check for range to include " "outer loop in order to permit hoisting\n"); @@ -1700,7 +1845,8 @@ static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG, Low = cast<SCEVAddRecExpr>(Low)->getStart(); // If there is a possibility that the stride is negative then we have // to generate extra checks to ensure the stride is positive. - if (!Exp.getSE()->isKnownNonNegative(Recur)) { + if (!SE.isKnownNonNegative( + SE.applyLoopGuards(Recur, HighAR->getLoop()))) { Stride = Recur; LLVM_DEBUG(dbgs() << "LAA: ... but need to check stride is " "positive: " @@ -1756,13 +1902,12 @@ Value *llvm::addRuntimeChecks( LLVMContext &Ctx = Loc->getContext(); IRBuilder<InstSimplifyFolder> ChkBuilder(Ctx, - Loc->getModule()->getDataLayout()); + Loc->getDataLayout()); ChkBuilder.SetInsertPoint(Loc); // Our instructions might fold to a constant. Value *MemoryRuntimeCheck = nullptr; - for (const auto &Check : ExpandedChecks) { - const PointerBounds &A = Check.first, &B = Check.second; + for (const auto &[A, B] : ExpandedChecks) { // Check if two pointers (A and B) conflict where conflict is computed as: // start(A) <= end(B) && start(B) <= end(A) @@ -1811,7 +1956,7 @@ Value *llvm::addDiffRuntimeChecks( LLVMContext &Ctx = Loc->getContext(); IRBuilder<InstSimplifyFolder> ChkBuilder(Ctx, - Loc->getModule()->getDataLayout()); + Loc->getDataLayout()); ChkBuilder.SetInsertPoint(Loc); // Our instructions might fold to a constant. Value *MemoryRuntimeCheck = nullptr; @@ -1820,14 +1965,14 @@ Value *llvm::addDiffRuntimeChecks( // Map to keep track of created compares, The key is the pair of operands for // the compare, to allow detecting and re-using redundant compares. DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares; - for (const auto &C : Checks) { - Type *Ty = C.SinkStart->getType(); + for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze] : Checks) { + Type *Ty = SinkStart->getType(); // Compute VF * IC * AccessSize. auto *VFTimesUFTimesSize = ChkBuilder.CreateMul(GetVF(ChkBuilder, Ty->getScalarSizeInBits()), - ConstantInt::get(Ty, IC * C.AccessSize)); - Value *Diff = Expander.expandCodeFor( - SE.getMinusSCEV(C.SinkStart, C.SrcStart), Ty, Loc); + ConstantInt::get(Ty, IC * AccessSize)); + Value *Diff = + Expander.expandCodeFor(SE.getMinusSCEV(SinkStart, SrcStart), Ty, Loc); // Check if the same compare has already been created earlier. In that case, // there is no need to check it again. @@ -1838,7 +1983,7 @@ Value *llvm::addDiffRuntimeChecks( IsConflict = ChkBuilder.CreateICmpULT(Diff, VFTimesUFTimesSize, "diff.check"); SeenCompares.insert({{Diff, VFTimesUFTimesSize}, IsConflict}); - if (C.NeedsFreeze) + if (NeedsFreeze) IsConflict = ChkBuilder.CreateFreeze(IsConflict, IsConflict->getName() + ".fr"); if (MemoryRuntimeCheck) { @@ -1858,10 +2003,12 @@ llvm::hasPartialIVCondition(const Loop &L, unsigned MSSAThreshold, if (!TI || !TI->isConditional()) return {}; - auto *CondI = dyn_cast<CmpInst>(TI->getCondition()); + auto *CondI = dyn_cast<Instruction>(TI->getCondition()); // The case with the condition outside the loop should already be handled // earlier. 
- if (!CondI || !L.contains(CondI)) + // Allow CmpInst and TruncInsts as they may be users of load instructions + // and have potential for partial unswitching + if (!CondI || !isa<CmpInst, TruncInst>(CondI) || !L.contains(CondI)) return {}; SmallVector<Instruction *> InstToDuplicate; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp index 548b0f3c55f0..c43c92a6b4d5 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -62,19 +62,19 @@ void LoopVersioning::versionLoop( const auto &RtPtrChecking = *LAI.getRuntimePointerChecking(); SCEVExpander Exp2(*RtPtrChecking.getSE(), - VersionedLoop->getHeader()->getModule()->getDataLayout(), + VersionedLoop->getHeader()->getDataLayout(), "induction"); MemRuntimeCheck = addRuntimeChecks(RuntimeCheckBB->getTerminator(), VersionedLoop, AliasChecks, Exp2); - SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(), + SCEVExpander Exp(*SE, RuntimeCheckBB->getDataLayout(), "scev.check"); SCEVRuntimeCheck = Exp.expandCodeForPredicate(&Preds, RuntimeCheckBB->getTerminator()); IRBuilder<InstSimplifyFolder> Builder( RuntimeCheckBB->getContext(), - InstSimplifyFolder(RuntimeCheckBB->getModule()->getDataLayout())); + InstSimplifyFolder(RuntimeCheckBB->getDataLayout())); if (MemRuntimeCheck && SCEVRuntimeCheck) { Builder.SetInsertPoint(RuntimeCheckBB->getTerminator()); RuntimeCheck = diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp index 4908535cba54..55f9400d93d7 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerGlobalDtors.cpp @@ -20,6 +20,7 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Transforms/Utils.h" @@ -207,7 +208,7 @@ static bool runImpl(Module &M) { Value *Null = ConstantPointerNull::get(VoidStar); Value *Args[] = {CallDtors, Null, DsoHandle}; Value *Res = CallInst::Create(AtExit, Args, "call", EntryBB); - Value *Cmp = new ICmpInst(*EntryBB, ICmpInst::ICMP_NE, Res, + Value *Cmp = new ICmpInst(EntryBB, ICmpInst::ICMP_NE, Res, Constant::getNullValue(Res->getType())); BranchInst::Create(FailBB, RetBB, Cmp, EntryBB); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerInvoke.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerInvoke.cpp index 6d788857c1ea..ff2ab3c6dce9 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerInvoke.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerInvoke.cpp @@ -22,7 +22,7 @@ #include "llvm/Transforms/Utils.h" using namespace llvm; -#define DEBUG_TYPE "lowerinvoke" +#define DEBUG_TYPE "lower-invoke" STATISTIC(NumInvokes, "Number of invokes replaced"); @@ -52,7 +52,7 @@ static bool runImpl(Function &F) { // Insert a normal call instruction... 
CallInst *NewCall = CallInst::Create(II->getFunctionType(), II->getCalledOperand(), - CallArgs, OpBundles, "", II); + CallArgs, OpBundles, "", II->getIterator()); NewCall->takeName(II); NewCall->setCallingConv(II->getCallingConv()); NewCall->setAttributes(II->getAttributes()); @@ -60,7 +60,7 @@ static bool runImpl(Function &F) { II->replaceAllUsesWith(NewCall); // Insert an unconditional branch to the normal destination. - BranchInst::Create(II->getNormalDest(), II); + BranchInst::Create(II->getNormalDest(), II->getIterator()); // Remove any PHI node entries from the exception destination. II->getUnwindDest()->removePredecessor(&BB); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp index c75de8687879..b38db412f786 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp @@ -13,6 +13,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MDBuilder.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include <optional> @@ -33,7 +34,7 @@ void llvm::createMemCpyLoopKnownSize( BasicBlock *PostLoopBB = nullptr; Function *ParentFunc = PreLoopBB->getParent(); LLVMContext &Ctx = PreLoopBB->getContext(); - const DataLayout &DL = ParentFunc->getParent()->getDataLayout(); + const DataLayout &DL = ParentFunc->getDataLayout(); MDBuilder MDB(Ctx); MDNode *NewDomain = MDB.createAnonymousAliasScopeDomain("MemCopyDomain"); StringRef Name = "MemCopyAliasScope"; @@ -155,6 +156,26 @@ void llvm::createMemCpyLoopKnownSize( "Bytes copied should match size in the call!"); } +// \returns \p Len udiv \p OpSize, checking for optimization opportunities. +static Value *getRuntimeLoopCount(const DataLayout &DL, IRBuilderBase &B, + Value *Len, Value *OpSize, + unsigned OpSizeVal) { + // For powers of 2, we can lshr by log2 instead of using udiv. + if (isPowerOf2_32(OpSizeVal)) + return B.CreateLShr(Len, Log2_32(OpSizeVal)); + return B.CreateUDiv(Len, OpSize); +} + +// \returns \p Len urem \p OpSize, checking for optimization opportunities. +static Value *getRuntimeLoopRemainder(const DataLayout &DL, IRBuilderBase &B, + Value *Len, Value *OpSize, + unsigned OpSizeVal) { + // For powers of 2, we can and by (OpSizeVal - 1) instead of using urem. + if (isPowerOf2_32(OpSizeVal)) + return B.CreateAnd(Len, OpSizeVal - 1); + return B.CreateURem(Len, OpSize); +} + void llvm::createMemCpyLoopUnknownSize( Instruction *InsertBefore, Value *SrcAddr, Value *DstAddr, Value *CopyLen, Align SrcAlign, Align DstAlign, bool SrcIsVolatile, bool DstIsVolatile, @@ -165,7 +186,7 @@ void llvm::createMemCpyLoopUnknownSize( PreLoopBB->splitBasicBlock(InsertBefore, "post-loop-memcpy-expansion"); Function *ParentFunc = PreLoopBB->getParent(); - const DataLayout &DL = ParentFunc->getParent()->getDataLayout(); + const DataLayout &DL = ParentFunc->getDataLayout(); LLVMContext &Ctx = PreLoopBB->getContext(); MDBuilder MDB(Ctx); MDNode *NewDomain = MDB.createAnonymousAliasScopeDomain("MemCopyDomain"); @@ -194,9 +215,11 @@ void llvm::createMemCpyLoopUnknownSize( Type *Int8Type = Type::getInt8Ty(Ctx); bool LoopOpIsInt8 = LoopOpType == Int8Type; ConstantInt *CILoopOpSize = ConstantInt::get(ILengthType, LoopOpSize); - Value *RuntimeLoopCount = LoopOpIsInt8 ? - CopyLen : - PLBuilder.CreateUDiv(CopyLen, CILoopOpSize); + Value *RuntimeLoopCount = LoopOpIsInt8 + ? 
CopyLen + : getRuntimeLoopCount(DL, PLBuilder, CopyLen, + CILoopOpSize, LoopOpSize); + BasicBlock *LoopBB = BasicBlock::Create(Ctx, "loop-memcpy-expansion", ParentFunc, PostLoopBB); IRBuilder<> LoopBuilder(LoopBB); @@ -239,8 +262,11 @@ void llvm::createMemCpyLoopUnknownSize( assert((ResLoopOpSize == AtomicElementSize ? *AtomicElementSize : 1) && "Store size is expected to match type size"); - // Add in the - Value *RuntimeResidual = PLBuilder.CreateURem(CopyLen, CILoopOpSize); + Align ResSrcAlign(commonAlignment(PartSrcAlign, ResLoopOpSize)); + Align ResDstAlign(commonAlignment(PartDstAlign, ResLoopOpSize)); + + Value *RuntimeResidual = getRuntimeLoopRemainder(DL, PLBuilder, CopyLen, + CILoopOpSize, LoopOpSize); Value *RuntimeBytesCopied = PLBuilder.CreateSub(CopyLen, RuntimeResidual); // Loop body for the residual copy. @@ -280,7 +306,7 @@ void llvm::createMemCpyLoopUnknownSize( Value *SrcGEP = ResBuilder.CreateInBoundsGEP(ResLoopOpType, SrcAddr, FullOffset); LoadInst *Load = ResBuilder.CreateAlignedLoad(ResLoopOpType, SrcGEP, - PartSrcAlign, SrcIsVolatile); + ResSrcAlign, SrcIsVolatile); if (!CanOverlap) { // Set alias scope for loads. Load->setMetadata(LLVMContext::MD_alias_scope, @@ -288,8 +314,8 @@ void llvm::createMemCpyLoopUnknownSize( } Value *DstGEP = ResBuilder.CreateInBoundsGEP(ResLoopOpType, DstAddr, FullOffset); - StoreInst *Store = ResBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign, - DstIsVolatile); + StoreInst *Store = + ResBuilder.CreateAlignedStore(Load, DstGEP, ResDstAlign, DstIsVolatile); if (!CanOverlap) { // Indicate that stores don't overlap loads. Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope)); @@ -351,7 +377,7 @@ static void createMemMoveLoop(Instruction *InsertBefore, Value *SrcAddr, Type *TypeOfCopyLen = CopyLen->getType(); BasicBlock *OrigBB = InsertBefore->getParent(); Function *F = OrigBB->getParent(); - const DataLayout &DL = F->getParent()->getDataLayout(); + const DataLayout &DL = F->getDataLayout(); // TODO: Use different element type if possible? Type *EltTy = Type::getInt8Ty(F->getContext()); @@ -361,10 +387,10 @@ static void createMemMoveLoop(Instruction *InsertBefore, Value *SrcAddr, // SplitBlockAndInsertIfThenElse conveniently creates the basic if-then-else // structure. Its block terminators (unconditional branches) are replaced by // the appropriate conditional branches when the loop is built. - ICmpInst *PtrCompare = new ICmpInst(InsertBefore, ICmpInst::ICMP_ULT, + ICmpInst *PtrCompare = new ICmpInst(InsertBefore->getIterator(), ICmpInst::ICMP_ULT, SrcAddr, DstAddr, "compare_src_dst"); Instruction *ThenTerm, *ElseTerm; - SplitBlockAndInsertIfThenElse(PtrCompare, InsertBefore, &ThenTerm, + SplitBlockAndInsertIfThenElse(PtrCompare, InsertBefore->getIterator(), &ThenTerm, &ElseTerm); // Each part of the function consists of two blocks: @@ -386,7 +412,7 @@ static void createMemMoveLoop(Instruction *InsertBefore, Value *SrcAddr, // Initial comparison of n == 0 that lets us skip the loops altogether. Shared // between both backwards and forward copy clauses. ICmpInst *CompareN = - new ICmpInst(OrigBB->getTerminator(), ICmpInst::ICMP_EQ, CopyLen, + new ICmpInst(OrigBB->getTerminator()->getIterator(), ICmpInst::ICMP_EQ, CopyLen, ConstantInt::get(TypeOfCopyLen, 0), "compare_n_to_0"); // Copying backwards. 
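[Editorial sketch, not part of the patch.] The hunks above and below touch createMemMoveLoop, which lowers a memmove of unknown length into a runtime source/destination comparison ("compare_src_dst") followed by either a backward or a forward byte-copy loop, guarded by an n == 0 check ("compare_n_to_0") because the emitted loops are in rotated do-while form. A minimal C++ sketch of the control flow being materialized — function and variable names are invented for illustration, and the element type is fixed to i8 as in the current lowering:

#include <cstddef>
#include <cstdint>

// Hypothetical helper mirroring the expanded control flow of the lowering.
static void memmove_lowered(unsigned char *Dst, const unsigned char *Src,
                            std::size_t N) {
  if (reinterpret_cast<std::uintptr_t>(Src) <
      reinterpret_cast<std::uintptr_t>(Dst)) {
    // copy_backwards: tail-first copy so overlapping source bytes are read
    // before they are overwritten.
    for (std::size_t I = N; I != 0; --I)
      Dst[I - 1] = Src[I - 1];
  } else {
    // copy_forward: head-first copy is safe when Dst does not start inside
    // the source region.
    for (std::size_t I = 0; I != N; ++I)
      Dst[I] = Src[I];
  }
}

In the C sketch the N == 0 case falls out of the loop conditions; the IR version needs the explicit "compare_n_to_0" branch because each loop body executes once before its exit test.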
@@ -399,16 +425,16 @@ static void createMemMoveLoop(Instruction *InsertBefore, Value *SrcAddr, LoopPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_ptr"); Value *Element = LoopBuilder.CreateAlignedLoad( EltTy, LoopBuilder.CreateInBoundsGEP(EltTy, SrcAddr, IndexPtr), - PartSrcAlign, "element"); + PartSrcAlign, SrcIsVolatile, "element"); LoopBuilder.CreateAlignedStore( Element, LoopBuilder.CreateInBoundsGEP(EltTy, DstAddr, IndexPtr), - PartDstAlign); + PartDstAlign, DstIsVolatile); LoopBuilder.CreateCondBr( LoopBuilder.CreateICmpEQ(IndexPtr, ConstantInt::get(TypeOfCopyLen, 0)), ExitBB, LoopBB); LoopPhi->addIncoming(IndexPtr, LoopBB); LoopPhi->addIncoming(CopyLen, CopyBackwardsBB); - BranchInst::Create(ExitBB, LoopBB, CompareN, ThenTerm); + BranchInst::Create(ExitBB, LoopBB, CompareN, ThenTerm->getIterator()); ThenTerm->eraseFromParent(); // Copying forward. @@ -417,10 +443,11 @@ static void createMemMoveLoop(Instruction *InsertBefore, Value *SrcAddr, IRBuilder<> FwdLoopBuilder(FwdLoopBB); PHINode *FwdCopyPhi = FwdLoopBuilder.CreatePHI(TypeOfCopyLen, 0, "index_ptr"); Value *SrcGEP = FwdLoopBuilder.CreateInBoundsGEP(EltTy, SrcAddr, FwdCopyPhi); - Value *FwdElement = - FwdLoopBuilder.CreateAlignedLoad(EltTy, SrcGEP, PartSrcAlign, "element"); + Value *FwdElement = FwdLoopBuilder.CreateAlignedLoad( + EltTy, SrcGEP, PartSrcAlign, SrcIsVolatile, "element"); Value *DstGEP = FwdLoopBuilder.CreateInBoundsGEP(EltTy, DstAddr, FwdCopyPhi); - FwdLoopBuilder.CreateAlignedStore(FwdElement, DstGEP, PartDstAlign); + FwdLoopBuilder.CreateAlignedStore(FwdElement, DstGEP, PartDstAlign, + DstIsVolatile); Value *FwdIndexPtr = FwdLoopBuilder.CreateAdd( FwdCopyPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_increment"); FwdLoopBuilder.CreateCondBr(FwdLoopBuilder.CreateICmpEQ(FwdIndexPtr, CopyLen), @@ -428,7 +455,7 @@ static void createMemMoveLoop(Instruction *InsertBefore, Value *SrcAddr, FwdCopyPhi->addIncoming(FwdIndexPtr, FwdLoopBB); FwdCopyPhi->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), CopyForwardBB); - BranchInst::Create(ExitBB, FwdLoopBB, CompareN, ElseTerm); + BranchInst::Create(ExitBB, FwdLoopBB, CompareN, ElseTerm->getIterator()); ElseTerm->eraseFromParent(); } @@ -438,7 +465,7 @@ static void createMemSetLoop(Instruction *InsertBefore, Value *DstAddr, Type *TypeOfCopyLen = CopyLen->getType(); BasicBlock *OrigBB = InsertBefore->getParent(); Function *F = OrigBB->getParent(); - const DataLayout &DL = F->getParent()->getDataLayout(); + const DataLayout &DL = F->getDataLayout(); BasicBlock *NewBB = OrigBB->splitBasicBlock(InsertBefore, "split"); BasicBlock *LoopBB diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp index 4131d36b572d..b5c4e93be574 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LowerSwitch.cpp @@ -165,20 +165,20 @@ BasicBlock *NewLeafBlock(CaseRange &Leaf, Value *Val, ConstantInt *LowerBound, if (Leaf.Low == Leaf.High) { // Make the seteq instruction... 
Comp = - new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, Leaf.Low, "SwitchLeaf"); + new ICmpInst(NewLeaf, ICmpInst::ICMP_EQ, Val, Leaf.Low, "SwitchLeaf"); } else { // Make range comparison if (Leaf.Low == LowerBound) { // Val >= Min && Val <= Hi --> Val <= Hi - Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, + Comp = new ICmpInst(NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, "SwitchLeaf"); } else if (Leaf.High == UpperBound) { // Val <= Max && Val >= Lo --> Val >= Lo - Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SGE, Val, Leaf.Low, + Comp = new ICmpInst(NewLeaf, ICmpInst::ICMP_SGE, Val, Leaf.Low, "SwitchLeaf"); } else if (Leaf.Low->isZero()) { // Val >= 0 && Val <= Hi --> Val <=u Hi - Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, + Comp = new ICmpInst(NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, "SwitchLeaf"); } else { // Emit V-Lo <=u Hi-Lo @@ -186,7 +186,7 @@ BasicBlock *NewLeafBlock(CaseRange &Leaf, Value *Val, ConstantInt *LowerBound, Instruction *Add = BinaryOperator::CreateAdd( Val, NegLo, Val->getName() + ".off", NewLeaf); Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); - Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound, + Comp = new ICmpInst(NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound, "SwitchLeaf"); } } @@ -208,7 +208,7 @@ BasicBlock *NewLeafBlock(CaseRange &Leaf, Value *Val, ConstantInt *LowerBound, PHINode *PN = cast<PHINode>(I); // Remove all but one incoming entries from the cluster APInt Range = Leaf.High->getValue() - Leaf.Low->getValue(); - for (APInt j(Range.getBitWidth(), 0, true); j.slt(Range); ++j) { + for (APInt j(Range.getBitWidth(), 0, false); j.ult(Range); ++j) { PN->removeIncomingValue(OrigBlock); } @@ -369,7 +369,7 @@ void ProcessSwitchInst(SwitchInst *SI, const unsigned NumSimpleCases = Clusterify(Cases, SI); IntegerType *IT = cast<IntegerType>(SI->getCondition()->getType()); const unsigned BitWidth = IT->getBitWidth(); - // Explictly use higher precision to prevent unsigned overflow where + // Explicitly use higher precision to prevent unsigned overflow where // `UnsignedMax - 0 + 1 == 0` APInt UnsignedZero(BitWidth + 1, 0); APInt UnsignedMax = APInt::getMaxValue(BitWidth); @@ -407,7 +407,7 @@ void ProcessSwitchInst(SwitchInst *SI, // 2. even if limited to icmp instructions only, it will have to process // roughly C icmp's per switch, where C is the number of cases in the // switch, while LowerSwitch only needs to call LVI once per switch. - const DataLayout &DL = F->getParent()->getDataLayout(); + const DataLayout &DL = F->getDataLayout(); KnownBits Known = computeKnownBits(Val, DL, /*Depth=*/0, AC, SI); // TODO Shouldn't this create a signed range? 
ConstantRange KnownBitsRange = diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/MatrixUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/MatrixUtils.cpp index e218773cf5da..7866d6434c11 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/MatrixUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/MatrixUtils.cpp @@ -36,7 +36,7 @@ BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit, BranchInst::Create(Body, Header); BranchInst::Create(Latch, Body); PHINode *IV = - PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator()); + PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator()->getIterator()); IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader); B.SetInsertPoint(Latch); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/Mem2Reg.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/Mem2Reg.cpp index fbc6dd7613de..5ad7aeb463ec 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/Mem2Reg.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/Mem2Reg.cpp @@ -74,19 +74,15 @@ namespace { struct PromoteLegacyPass : public FunctionPass { // Pass identification, replacement for typeid static char ID; - bool ForcePass; /// If true, forces pass to execute, instead of skipping. - PromoteLegacyPass() : FunctionPass(ID), ForcePass(false) { - initializePromoteLegacyPassPass(*PassRegistry::getPassRegistry()); - } - PromoteLegacyPass(bool IsForced) : FunctionPass(ID), ForcePass(IsForced) { + PromoteLegacyPass() : FunctionPass(ID) { initializePromoteLegacyPassPass(*PassRegistry::getPassRegistry()); } // runOnFunction - To run this pass, first we calculate the alloca // instructions that are safe for promotion, then we promote each one. bool runOnFunction(Function &F) override { - if (!ForcePass && skipFunction(F)) + if (skipFunction(F)) return false; DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); @@ -115,6 +111,6 @@ INITIALIZE_PASS_END(PromoteLegacyPass, "mem2reg", "Promote Memory to Register", false, false) // createPromoteMemoryToRegister - Provide an entry point to create this pass. 
-FunctionPass *llvm::createPromoteMemoryToRegisterPass(bool IsForced) { - return new PromoteLegacyPass(IsForced); +FunctionPass *llvm::createPromoteMemoryToRegisterPass() { + return new PromoteLegacyPass(); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp index d671a9373bf0..8f55d7bbd318 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/MemoryOpRemark.cpp @@ -332,7 +332,7 @@ void MemoryOpRemark::visitVariable(const Value *V, } }; for_each(findDbgDeclares(const_cast<Value *>(V)), FindDI); - for_each(findDPVDeclares(const_cast<Value *>(V)), FindDI); + for_each(findDVRDeclares(const_cast<Value *>(V)), FindDI); if (FoundDI) { assert(!Result.empty()); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp index f94047633022..1472302b6ca3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp @@ -12,12 +12,16 @@ #include "llvm/Transforms/Utils/MemoryTaggingSupport.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/StackSafetyAnalysis.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/BinaryFormat/Dwarf.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" namespace llvm { @@ -69,14 +73,12 @@ bool forAllReachableExits(const DominatorTree &DT, const PostDominatorTree &PDT, ++NumCoveredExits; } } - // If there's a mix of covered and non-covered exits, just put the untag - // on exits, so we avoid the redundancy of untagging twice. if (NumCoveredExits == ReachableRetVec.size()) { - for (auto *End : Ends) - Callback(End); + for_each(Ends, Callback); } else { - for (auto *RI : ReachableRetVec) - Callback(RI); + // If there's a mix of covered and non-covered exits, just put the untag + // on exits, so we avoid the redundancy of untagging twice. + for_each(ReachableRetVec, Callback); // We may have inserted untag outside of the lifetime interval. // Signal the caller to remove the lifetime end call for this alloca. return false; @@ -110,6 +112,24 @@ Instruction *getUntagLocationIfFunctionExit(Instruction &Inst) { } void StackInfoBuilder::visit(Instruction &Inst) { + // Visit non-intrinsic debug-info records attached to Inst. 
+ for (DbgVariableRecord &DVR : filterDbgVars(Inst.getDbgRecordRange())) { + auto AddIfInteresting = [&](Value *V) { + if (auto *AI = dyn_cast_or_null<AllocaInst>(V)) { + if (!isInterestingAlloca(*AI)) + return; + AllocaInfo &AInfo = Info.AllocasToInstrument[AI]; + auto &DVRVec = AInfo.DbgVariableRecords; + if (DVRVec.empty() || DVRVec.back() != &DVR) + DVRVec.push_back(&DVR); + } + }; + + for_each(DVR.location_ops(), AddIfInteresting); + if (DVR.isDbgAssign()) + AddIfInteresting(DVR.getAddress()); + } + if (CallInst *CI = dyn_cast<CallInst>(&Inst)) { if (CI->canReturnTwice()) { Info.CallsReturnTwice = true; @@ -138,17 +158,21 @@ void StackInfoBuilder::visit(Instruction &Inst) { return; } if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&Inst)) { - for (Value *V : DVI->location_ops()) { + auto AddIfInteresting = [&](Value *V) { if (auto *AI = dyn_cast_or_null<AllocaInst>(V)) { if (!isInterestingAlloca(*AI)) - continue; + return; AllocaInfo &AInfo = Info.AllocasToInstrument[AI]; auto &DVIVec = AInfo.DbgVariableIntrinsics; if (DVIVec.empty() || DVIVec.back() != DVI) DVIVec.push_back(DVI); } - } + }; + for_each(DVI->location_ops(), AddIfInteresting); + if (auto *DAI = dyn_cast<DbgAssignIntrinsic>(DVI)) + AddIfInteresting(DAI->getAddress()); } + Instruction *ExitUntag = getUntagLocationIfFunctionExit(Inst); if (ExitUntag) Info.RetVec.push_back(ExitUntag); @@ -156,6 +180,8 @@ void StackInfoBuilder::visit(Instruction &Inst) { bool StackInfoBuilder::isInterestingAlloca(const AllocaInst &AI) { return (AI.getAllocatedType()->isSized() && + // FIXME: support vscale. + !AI.getAllocatedType()->isScalableTy() && // FIXME: instrument dynamic allocas, too AI.isStaticAlloca() && // alloca() may be called with 0 size, ignore it. @@ -173,7 +199,7 @@ bool StackInfoBuilder::isInterestingAlloca(const AllocaInst &AI) { } uint64_t getAllocaSizeInBytes(const AllocaInst &AI) { - auto DL = AI.getModule()->getDataLayout(); + auto DL = AI.getDataLayout(); return *AI.getAllocationSize(DL); } @@ -197,7 +223,7 @@ void alignAndPadAlloca(memtag::AllocaInfo &Info, llvm::Align Alignment) { Type *PaddingType = ArrayType::get(Type::getInt8Ty(Ctx), AlignedSize - Size); Type *TypeWithPadding = StructType::get(AllocatedType, PaddingType); auto *NewAI = new AllocaInst(TypeWithPadding, Info.AI->getAddressSpace(), - nullptr, "", Info.AI); + nullptr, "", Info.AI->getIterator()); NewAI->takeName(Info.AI); NewAI->setAlignment(Info.AI->getAlign()); NewAI->setUsedWithInAlloca(Info.AI->isUsedWithInAlloca()); @@ -208,12 +234,89 @@ void alignAndPadAlloca(memtag::AllocaInfo &Info, llvm::Align Alignment) { // TODO: Remove when typed pointers dropped if (Info.AI->getType() != NewAI->getType()) - NewPtr = new BitCastInst(NewAI, Info.AI->getType(), "", Info.AI); + NewPtr = new BitCastInst(NewAI, Info.AI->getType(), "", Info.AI->getIterator()); Info.AI->replaceAllUsesWith(NewPtr); Info.AI->eraseFromParent(); Info.AI = NewAI; } +bool isLifetimeIntrinsic(Value *V) { + auto *II = dyn_cast<IntrinsicInst>(V); + return II && II->isLifetimeStartOrEnd(); +} + +Value *readRegister(IRBuilder<> &IRB, StringRef Name) { + Module *M = IRB.GetInsertBlock()->getParent()->getParent(); + Function *ReadRegister = Intrinsic::getDeclaration( + M, Intrinsic::read_register, IRB.getIntPtrTy(M->getDataLayout())); + MDNode *MD = + MDNode::get(M->getContext(), {MDString::get(M->getContext(), Name)}); + Value *Args[] = {MetadataAsValue::get(M->getContext(), MD)}; + return IRB.CreateCall(ReadRegister, Args); +} + +Value *getPC(const Triple &TargetTriple, IRBuilder<> &IRB) { + 
Module *M = IRB.GetInsertBlock()->getParent()->getParent(); + if (TargetTriple.getArch() == Triple::aarch64) + return memtag::readRegister(IRB, "pc"); + return IRB.CreatePtrToInt(IRB.GetInsertBlock()->getParent(), + IRB.getIntPtrTy(M->getDataLayout())); +} + +Value *getFP(IRBuilder<> &IRB) { + Function *F = IRB.GetInsertBlock()->getParent(); + Module *M = F->getParent(); + auto *GetStackPointerFn = Intrinsic::getDeclaration( + M, Intrinsic::frameaddress, + IRB.getPtrTy(M->getDataLayout().getAllocaAddrSpace())); + return IRB.CreatePtrToInt( + IRB.CreateCall(GetStackPointerFn, + {Constant::getNullValue(IRB.getInt32Ty())}), + IRB.getIntPtrTy(M->getDataLayout())); +} + +Value *getAndroidSlotPtr(IRBuilder<> &IRB, int Slot) { + Module *M = IRB.GetInsertBlock()->getParent()->getParent(); + // Android provides a fixed TLS slot for sanitizers. See TLS_SLOT_SANITIZER + // in Bionic's libc/private/bionic_tls.h. + Function *ThreadPointerFunc = + Intrinsic::getDeclaration(M, Intrinsic::thread_pointer); + return IRB.CreateConstGEP1_32(IRB.getInt8Ty(), + IRB.CreateCall(ThreadPointerFunc), 8 * Slot); +} + +static DbgAssignIntrinsic *DynCastToDbgAssign(DbgVariableIntrinsic *DVI) { + return dyn_cast<DbgAssignIntrinsic>(DVI); +} + +static DbgVariableRecord *DynCastToDbgAssign(DbgVariableRecord *DVR) { + return DVR->isDbgAssign() ? DVR : nullptr; +} + +void annotateDebugRecords(AllocaInfo &Info, unsigned int Tag) { + // Helper utility for adding DW_OP_LLVM_tag_offset to debug-info records, + // abstracted over whether they're intrinsic-stored or DbgVariableRecord + // stored. + auto AnnotateDbgRecord = [&](auto *DPtr) { + // Prepend "tag_offset, N" to the dwarf expression. + // Tag offset logically applies to the alloca pointer, and it makes sense + // to put it at the beginning of the expression. 
+ SmallVector<uint64_t, 8> NewOps = {dwarf::DW_OP_LLVM_tag_offset, Tag}; + for (size_t LocNo = 0; LocNo < DPtr->getNumVariableLocationOps(); ++LocNo) + if (DPtr->getVariableLocationOp(LocNo) == Info.AI) + DPtr->setExpression( + DIExpression::appendOpsToArg(DPtr->getExpression(), NewOps, LocNo)); + if (auto *DAI = DynCastToDbgAssign(DPtr)) { + if (DAI->getAddress() == Info.AI) + DAI->setAddressExpression( + DIExpression::prependOpcodes(DAI->getAddressExpression(), NewOps)); + } + }; + + llvm::for_each(Info.DbgVariableIntrinsics, AnnotateDbgRecord); + llvm::for_each(Info.DbgVariableRecords, AnnotateDbgRecord); +} + } // namespace memtag } // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/MisExpect.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/MisExpect.cpp index 6f5a25a26821..aef9d82db042 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/MisExpect.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/MisExpect.cpp @@ -59,9 +59,10 @@ static cl::opt<bool> PGOWarnMisExpect( cl::desc("Use this option to turn on/off " "warnings about incorrect usage of llvm.expect intrinsics.")); +// Command line option for setting the diagnostic tolerance threshold static cl::opt<uint32_t> MisExpectTolerance( "misexpect-tolerance", cl::init(0), - cl::desc("Prevents emiting diagnostics when profile counts are " + cl::desc("Prevents emitting diagnostics when profile counts are " "within N% of the threshold..")); } // namespace llvm @@ -150,15 +151,9 @@ void verifyMisExpect(Instruction &I, ArrayRef<uint32_t> RealWeights, uint64_t TotalBranchWeight = LikelyBranchWeight + (UnlikelyBranchWeight * NumUnlikelyTargets); - // FIXME: When we've addressed sample profiling, restore the assertion - // - // We cannot calculate branch probability if either of these invariants aren't - // met. However, MisExpect diagnostics should not prevent code from compiling, - // so we simply forgo emitting diagnostics here, and return early. - // assert((TotalBranchWeight >= LikelyBranchWeight) && (TotalBranchWeight > 0) - // && "TotalBranchWeight is less than the Likely branch weight"); - if ((TotalBranchWeight == 0) || (TotalBranchWeight <= LikelyBranchWeight)) - return; + // Failing this assert means that we have corrupted metadata. + assert((TotalBranchWeight >= LikelyBranchWeight) && (TotalBranchWeight > 0) && + "TotalBranchWeight is less than the Likely branch weight"); // To determine our threshold value we need to obtain the branch probability // for the weights added by llvm.expect and use that proportion to calculate @@ -185,6 +180,13 @@ void verifyMisExpect(Instruction &I, ArrayRef<uint32_t> RealWeights, void checkBackendInstrumentation(Instruction &I, const ArrayRef<uint32_t> RealWeights) { + // Backend checking assumes any existing weight comes from an `llvm.expect` + // intrinsic. However, SampleProfiling + ThinLTO add branch weights multiple + // times, leading to an invalid assumption in our checking. Backend checks + // should only operate on branch weights that carry the "!expected" field, + // since they are guaranteed to be added by the LowerExpectIntrinsic pass. 
+ if (!hasBranchWeightOrigin(I))
+ return;
 SmallVector<uint32_t> ExpectedWeights;
 if (!extractBranchWeights(I, ExpectedWeights))
 return;
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/ModuleUtils.cpp
index 209a6a34a3c9..95bf9f06bc33 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/ModuleUtils.cpp
@@ -18,6 +18,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/Support/MD5.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Support/xxhash.h"
@@ -160,11 +161,13 @@ void llvm::setKCFIType(Module &M, Function &F, StringRef MangledType) {
 // Matches CodeGenModule::CreateKCFITypeId in Clang.
 LLVMContext &Ctx = M.getContext();
 MDBuilder MDB(Ctx);
- F.setMetadata(
- LLVMContext::MD_kcfi_type,
- MDNode::get(Ctx, MDB.createConstant(ConstantInt::get(
- Type::getInt32Ty(Ctx),
- static_cast<uint32_t>(xxHash64(MangledType))))));
+ std::string Type = MangledType.str();
+ if (M.getModuleFlag("cfi-normalize-integers"))
+ Type += ".normalized";
+ F.setMetadata(LLVMContext::MD_kcfi_type,
+ MDNode::get(Ctx, MDB.createConstant(ConstantInt::get(
+ Type::getInt32Ty(Ctx),
+ static_cast<uint32_t>(xxHash64(Type))))));
 // If the module was compiled with -fpatchable-function-entry, ensure
 // we use the same patchable-function-prefix.
 if (auto *MD = mdconst::extract_or_null<ConstantInt>(
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/PredicateInfo.cpp
index 902977b08d15..186e17e166ba 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/Utils/PredicateInfo.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/PredicateInfo.cpp
@@ -478,10 +478,8 @@ void PredicateInfoBuilder::processSwitch(
 // Remember how many outgoing edges there are to every successor.
 SmallDenseMap<BasicBlock *, unsigned, 16> SwitchEdges;
- for (unsigned i = 0, e = SI->getNumSuccessors(); i != e; ++i) {
- BasicBlock *TargetBlock = SI->getSuccessor(i);
+ for (BasicBlock *TargetBlock : successors(BranchBB))
 ++SwitchEdges[TargetBlock];
- }
 // Now propagate info for each case value
 for (auto C : SI->cases()) {
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
index 88b05aab8db4..546a6cd56b25 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp
@@ -41,6 +41,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/User.h"
 #include "llvm/Support/Casting.h"
@@ -101,21 +102,22 @@ bool llvm::isAllocaPromotable(const AllocaInst *AI) {
 namespace {
-static DPValue *createDebugValue(DIBuilder &DIB, Value *NewValue,
- DILocalVariable *Variable,
- DIExpression *Expression, const DILocation *DI,
- DPValue *InsertBefore) {
+static void createDebugValue(DIBuilder &DIB, Value *NewValue,
+ DILocalVariable *Variable,
+ DIExpression *Expression, const DILocation *DI,
+ DbgVariableRecord *InsertBefore) {
+ // FIXME: Merge these two functions now that DIBuilder supports
+ // DbgVariableRecords. We need the API to accept DbgVariableRecords as an
+ // insert point for that to work.
(void)DIB; - return DPValue::createDPValue(NewValue, Variable, Expression, DI, - *InsertBefore); + DbgVariableRecord::createDbgVariableRecord(NewValue, Variable, Expression, DI, + *InsertBefore); } -static DbgValueInst *createDebugValue(DIBuilder &DIB, Value *NewValue, - DILocalVariable *Variable, - DIExpression *Expression, - const DILocation *DI, - Instruction *InsertBefore) { - return static_cast<DbgValueInst *>(DIB.insertDbgValueIntrinsic( - NewValue, Variable, Expression, DI, InsertBefore)); +static void createDebugValue(DIBuilder &DIB, Value *NewValue, + DILocalVariable *Variable, + DIExpression *Expression, const DILocation *DI, + Instruction *InsertBefore) { + DIB.insertDbgValueIntrinsic(NewValue, Variable, Expression, DI, InsertBefore); } /// Helper for updating assignment tracking debug info when promoting allocas. @@ -124,7 +126,7 @@ class AssignmentTrackingInfo { /// fragment. (i.e. not be a comprehensive set if there are multiple /// dbg.assigns for one variable fragment). SmallVector<DbgVariableIntrinsic *> DbgAssigns; - SmallVector<DPValue *> DPVAssigns; + SmallVector<DbgVariableRecord *> DVRAssigns; public: void init(AllocaInst *AI) { @@ -133,21 +135,21 @@ public: if (Vars.insert(DebugVariable(DAI)).second) DbgAssigns.push_back(DAI); } - for (DPValue *DPV : at::getDPVAssignmentMarkers(AI)) { - if (Vars.insert(DebugVariable(DPV)).second) - DPVAssigns.push_back(DPV); + for (DbgVariableRecord *DVR : at::getDVRAssignmentMarkers(AI)) { + if (Vars.insert(DebugVariable(DVR)).second) + DVRAssigns.push_back(DVR); } } /// Update assignment tracking debug info given for the to-be-deleted store /// \p ToDelete that stores to this alloca. - void - updateForDeletedStore(StoreInst *ToDelete, DIBuilder &DIB, - SmallSet<DbgAssignIntrinsic *, 8> *DbgAssignsToDelete, - SmallSet<DPValue *, 8> *DPVAssignsToDelete) const { + void updateForDeletedStore( + StoreInst *ToDelete, DIBuilder &DIB, + SmallSet<DbgAssignIntrinsic *, 8> *DbgAssignsToDelete, + SmallSet<DbgVariableRecord *, 8> *DVRAssignsToDelete) const { // There's nothing to do if the alloca doesn't have any variables using // assignment tracking. - if (DbgAssigns.empty() && DPVAssigns.empty()) + if (DbgAssigns.empty() && DVRAssigns.empty()) return; // Insert a dbg.value where the linked dbg.assign is and remember to delete @@ -166,8 +168,8 @@ public: }; for (auto *Assign : at::getAssignmentMarkers(ToDelete)) InsertValueForAssign(Assign, DbgAssignsToDelete); - for (auto *Assign : at::getDPVAssignmentMarkers(ToDelete)) - InsertValueForAssign(Assign, DPVAssignsToDelete); + for (auto *Assign : at::getDVRAssignmentMarkers(ToDelete)) + InsertValueForAssign(Assign, DVRAssignsToDelete); // It's possible for variables using assignment tracking to have no // dbg.assign linked to this store. These are variables in DbgAssigns that @@ -183,7 +185,7 @@ public: ConvertDebugDeclareToDebugValue(Assign, ToDelete, DIB); }; for_each(DbgAssigns, ConvertUnlinkedAssignToValue); - for_each(DPVAssigns, ConvertUnlinkedAssignToValue); + for_each(DVRAssigns, ConvertUnlinkedAssignToValue); } /// Update assignment tracking debug info given for the newly inserted PHI \p @@ -194,20 +196,20 @@ public: // debug-phi. 
for (auto *DAI : DbgAssigns) ConvertDebugDeclareToDebugValue(DAI, NewPhi, DIB); - for (auto *DPV : DPVAssigns) - ConvertDebugDeclareToDebugValue(DPV, NewPhi, DIB); + for (auto *DVR : DVRAssigns) + ConvertDebugDeclareToDebugValue(DVR, NewPhi, DIB); } void clear() { DbgAssigns.clear(); - DPVAssigns.clear(); + DVRAssigns.clear(); } - bool empty() { return DbgAssigns.empty() && DPVAssigns.empty(); } + bool empty() { return DbgAssigns.empty() && DVRAssigns.empty(); } }; struct AllocaInfo { using DbgUserVec = SmallVector<DbgVariableIntrinsic *, 1>; - using DPUserVec = SmallVector<DPValue *, 1>; + using DPUserVec = SmallVector<DbgVariableRecord *, 1>; SmallVector<BasicBlock *, 32> DefiningBlocks; SmallVector<BasicBlock *, 32> UsingBlocks; @@ -263,7 +265,7 @@ struct AllocaInfo { } } DbgUserVec AllDbgUsers; - SmallVector<DPValue *> AllDPUsers; + SmallVector<DbgVariableRecord *> AllDPUsers; findDbgUsers(AllDbgUsers, AI, &AllDPUsers); std::copy_if(AllDbgUsers.begin(), AllDbgUsers.end(), std::back_inserter(DbgUsers), [](DbgVariableIntrinsic *DII) { @@ -271,7 +273,7 @@ struct AllocaInfo { }); std::copy_if(AllDPUsers.begin(), AllDPUsers.end(), std::back_inserter(DPUsers), - [](DPValue *DPV) { return !DPV->isDbgAssign(); }); + [](DbgVariableRecord *DVR) { return !DVR->isDbgAssign(); }); AssignmentTracking.init(AI); } }; @@ -379,7 +381,7 @@ struct PromoteMem2Reg { /// A set of dbg.assigns to delete because they've been demoted to /// dbg.values. Call cleanUpDbgAssigns to delete them. SmallSet<DbgAssignIntrinsic *, 8> DbgAssignsToDelete; - SmallSet<DPValue *, 8> DPVAssignsToDelete; + SmallSet<DbgVariableRecord *, 8> DVRAssignsToDelete; /// The set of basic blocks the renamer has already visited. SmallPtrSet<BasicBlock *, 16> Visited; @@ -391,12 +393,15 @@ struct PromoteMem2Reg { /// Lazily compute the number of predecessors a block has. DenseMap<const BasicBlock *, unsigned> BBNumPreds; + /// Whether the function has the no-signed-zeros-fp-math attribute set. + bool NoSignedZeros = false; + public: PromoteMem2Reg(ArrayRef<AllocaInst *> Allocas, DominatorTree &DT, AssumptionCache *AC) : Allocas(Allocas.begin(), Allocas.end()), DT(DT), DIB(*DT.getRoot()->getParent()->getParent(), /*AllowUnresolved*/ false), - AC(AC), SQ(DT.getRoot()->getParent()->getParent()->getDataLayout(), + AC(AC), SQ(DT.getRoot()->getDataLayout(), nullptr, &DT, AC) {} void run(); @@ -429,9 +434,9 @@ private: for (auto *DAI : DbgAssignsToDelete) DAI->eraseFromParent(); DbgAssignsToDelete.clear(); - for (auto *DPV : DPVAssignsToDelete) - DPV->eraseFromParent(); - DPVAssignsToDelete.clear(); + for (auto *DVR : DVRAssignsToDelete) + DVR->eraseFromParent(); + DVRAssignsToDelete.clear(); } }; @@ -452,13 +457,22 @@ static void addAssumeNonNull(AssumptionCache *AC, LoadInst *LI) { static void convertMetadataToAssumes(LoadInst *LI, Value *Val, const DataLayout &DL, AssumptionCache *AC, const DominatorTree *DT) { + if (isa<UndefValue>(Val) && LI->hasMetadata(LLVMContext::MD_noundef)) { + // Insert non-terminator unreachable. + LLVMContext &Ctx = LI->getContext(); + new StoreInst(ConstantInt::getTrue(Ctx), + PoisonValue::get(PointerType::getUnqual(Ctx)), + /*isVolatile=*/false, Align(1), LI); + return; + } + // If the load was marked as nonnull we don't want to lose that information // when we erase this Load. So we preserve it with an assume. As !nonnull // returns poison while assume violations are immediate undefined behavior, // we can only do this if the value is known non-poison. 
if (AC && LI->getMetadata(LLVMContext::MD_nonnull) && LI->getMetadata(LLVMContext::MD_noundef) && - !isKnownNonZero(Val, DL, 0, AC, LI, DT)) + !isKnownNonZero(Val, SimplifyQuery(DL, DT, AC, LI))) addAssumeNonNull(AC, LI); } @@ -509,9 +523,15 @@ rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, LargeBlockInfo &LBI, const DataLayout &DL, DominatorTree &DT, AssumptionCache *AC, SmallSet<DbgAssignIntrinsic *, 8> *DbgAssignsToDelete, - SmallSet<DPValue *, 8> *DPVAssignsToDelete) { + SmallSet<DbgVariableRecord *, 8> *DVRAssignsToDelete) { StoreInst *OnlyStore = Info.OnlyStore; - bool StoringGlobalVal = !isa<Instruction>(OnlyStore->getOperand(0)); + Value *ReplVal = OnlyStore->getOperand(0); + // Loads may either load the stored value or uninitialized memory (undef). + // If the stored value may be poison, then replacing an uninitialized memory + // load with it would be incorrect. If the store dominates the load, we know + // it is always initialized. + bool RequireDominatingStore = + isa<Instruction>(ReplVal) || !isGuaranteedNotToBePoison(ReplVal); BasicBlock *StoreBB = OnlyStore->getParent(); int StoreIndex = -1; @@ -528,7 +548,7 @@ rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, LargeBlockInfo &LBI, // only value stored to the alloca. We can do this if the value is // dominated by the store. If not, we use the rest of the mem2reg machinery // to insert the phi nodes as needed. - if (!StoringGlobalVal) { // Non-instructions are always dominated. + if (RequireDominatingStore) { if (LI->getParent() == StoreBB) { // If we have a use that is in the same block as the store, compare the // indices of the two instructions to see which one came first. If the @@ -551,7 +571,6 @@ rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, LargeBlockInfo &LBI, } // Otherwise, we *can* safely rewrite this load. - Value *ReplVal = OnlyStore->getOperand(0); // If the replacement value is the load, this must occur in unreachable // code. if (ReplVal == LI) @@ -570,7 +589,7 @@ rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, LargeBlockInfo &LBI, DIBuilder DIB(*AI->getModule(), /*AllowUnresolved*/ false); // Update assignment tracking info for the store we're going to delete. Info.AssignmentTracking.updateForDeletedStore( - Info.OnlyStore, DIB, DbgAssignsToDelete, DPVAssignsToDelete); + Info.OnlyStore, DIB, DbgAssignsToDelete, DVRAssignsToDelete); // Record debuginfo for the store and remove the declaration's // debuginfo. @@ -619,7 +638,7 @@ promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, LargeBlockInfo &LBI, const DataLayout &DL, DominatorTree &DT, AssumptionCache *AC, SmallSet<DbgAssignIntrinsic *, 8> *DbgAssignsToDelete, - SmallSet<DPValue *, 8> *DPVAssignsToDelete) { + SmallSet<DbgVariableRecord *, 8> *DVRAssignsToDelete) { // The trickiest case to handle is when we have large blocks. Because of this, // this code is optimized assuming that large blocks happen. This does not // significantly pessimize the small block case. This uses LargeBlockInfo to @@ -684,7 +703,7 @@ promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, StoreInst *SI = cast<StoreInst>(AI->user_back()); // Update assignment tracking info for the store we're going to delete. Info.AssignmentTracking.updateForDeletedStore(SI, DIB, DbgAssignsToDelete, - DPVAssignsToDelete); + DVRAssignsToDelete); // Record debuginfo for the store before removing it. 
auto DbgUpdateForStore = [&](auto &Container) { for (auto *DbgItem : Container) { @@ -729,6 +748,8 @@ void PromoteMem2Reg::run() { LargeBlockInfo LBI; ForwardIDFCalculator IDF(DT); + NoSignedZeros = F.getFnAttribute("no-signed-zeros-fp-math").getValueAsBool(); + for (unsigned AllocaNum = 0; AllocaNum != Allocas.size(); ++AllocaNum) { AllocaInst *AI = Allocas[AllocaNum]; @@ -756,7 +777,7 @@ void PromoteMem2Reg::run() { // it that are directly dominated by the definition with the value stored. if (Info.DefiningBlocks.size() == 1) { if (rewriteSingleStoreAlloca(AI, Info, LBI, SQ.DL, DT, AC, - &DbgAssignsToDelete, &DPVAssignsToDelete)) { + &DbgAssignsToDelete, &DVRAssignsToDelete)) { // The alloca has been processed, move on. RemoveFromAllocasList(AllocaNum); ++NumSingleStore; @@ -768,7 +789,7 @@ void PromoteMem2Reg::run() { // linear sweep over the block to eliminate it. if (Info.OnlyUsedInOneBlock && promoteSingleBlockAlloca(AI, Info, LBI, SQ.DL, DT, AC, - &DbgAssignsToDelete, &DPVAssignsToDelete)) { + &DbgAssignsToDelete, &DVRAssignsToDelete)) { // The alloca has been processed, move on. RemoveFromAllocasList(AllocaNum); continue; @@ -1112,6 +1133,14 @@ NextIteration: for (unsigned i = 0; i != NumEdges; ++i) APN->addIncoming(IncomingVals[AllocaNo], Pred); + // For the sequence `return X > 0.0 ? X : -X`, it is expected that this + // results in fabs intrinsic. However, without no-signed-zeros(nsz) flag + // on the phi node generated at this stage, fabs folding does not + // happen. So, we try to infer nsz flag from the function attributes to + // enable this fabs folding. + if (isa<FPMathOperator>(APN) && NoSignedZeros) + APN->setHasNoSignedZeros(true); + // The currently active variable for this block is now the PHI. IncomingVals[AllocaNo] = APN; AllocaATInfo[AllocaNo].updateForNewPhi(APN, DIB); @@ -1175,7 +1204,7 @@ NextIteration: // Record debuginfo for the store before removing it. IncomingLocs[AllocaNo] = SI->getDebugLoc(); AllocaATInfo[AllocaNo].updateForDeletedStore(SI, DIB, &DbgAssignsToDelete, - &DPVAssignsToDelete); + &DVRAssignsToDelete); auto ConvertDbgDeclares = [&](auto &Container) { for (auto *DbgItem : Container) if (DbgItem->isAddressOfVariable()) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp index ea628d7c3d7d..6e84965370b2 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp @@ -100,10 +100,10 @@ static GlobalVariable *createRelLookupTable(Function &Func, ArrayType::get(Type::getInt32Ty(M.getContext()), NumElts); GlobalVariable *RelLookupTable = new GlobalVariable( - M, IntArrayTy, LookupTable.isConstant(), LookupTable.getLinkage(), - nullptr, "reltable." 
+ Func.getName(), &LookupTable, - LookupTable.getThreadLocalMode(), LookupTable.getAddressSpace(), - LookupTable.isExternallyInitialized()); + M, IntArrayTy, LookupTable.isConstant(), LookupTable.getLinkage(), + nullptr, LookupTable.getName() + ".rel", &LookupTable, + LookupTable.getThreadLocalMode(), LookupTable.getAddressSpace(), + LookupTable.isExternallyInitialized()); uint64_t Idx = 0; SmallVector<Constant *, 64> RelLookupTableContents(NumElts); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SCCPSolver.cpp index 3dc6016a0a37..2336466a25a1 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -42,14 +42,6 @@ static ValueLatticeElement::MergeOptions getMaxWidenStepsOpts() { MaxNumRangeExtensions); } -static ConstantRange getConstantRange(const ValueLatticeElement &LV, Type *Ty, - bool UndefAllowed = true) { - assert(Ty->isIntOrIntVectorTy() && "Should be int or int vector"); - if (LV.isConstantRange(UndefAllowed)) - return LV.getConstantRange(); - return ConstantRange::getFull(Ty->getScalarSizeInBits()); -} - namespace llvm { bool SCCPSolver::isConstant(const ValueLatticeElement &LV) { @@ -109,17 +101,20 @@ static bool refineInstruction(SCCPSolver &Solver, Instruction &Inst) { bool Changed = false; auto GetRange = [&Solver, &InsertedValues](Value *Op) { - if (auto *Const = dyn_cast<ConstantInt>(Op)) - return ConstantRange(Const->getValue()); - if (isa<Constant>(Op) || InsertedValues.contains(Op)) { + if (auto *Const = dyn_cast<Constant>(Op)) + return Const->toConstantRange(); + if (InsertedValues.contains(Op)) { unsigned Bitwidth = Op->getType()->getScalarSizeInBits(); return ConstantRange::getFull(Bitwidth); } - return getConstantRange(Solver.getLatticeValueFor(Op), Op->getType(), - /*UndefAllowed=*/false); + return Solver.getLatticeValueFor(Op).asConstantRange( + Op->getType(), /*UndefAllowed=*/false); }; if (isa<OverflowingBinaryOperator>(Inst)) { + if (Inst.hasNoSignedWrap() && Inst.hasNoUnsignedWrap()) + return false; + auto RangeA = GetRange(Inst.getOperand(0)); auto RangeB = GetRange(Inst.getOperand(1)); if (!Inst.hasNoUnsignedWrap()) { @@ -140,12 +135,30 @@ static bool refineInstruction(SCCPSolver &Solver, Changed = true; } } - } else if (isa<ZExtInst>(Inst) && !Inst.hasNonNeg()) { + } else if (isa<PossiblyNonNegInst>(Inst) && !Inst.hasNonNeg()) { auto Range = GetRange(Inst.getOperand(0)); if (Range.isAllNonNegative()) { Inst.setNonNeg(); Changed = true; } + } else if (TruncInst *TI = dyn_cast<TruncInst>(&Inst)) { + if (TI->hasNoSignedWrap() && TI->hasNoUnsignedWrap()) + return false; + + auto Range = GetRange(Inst.getOperand(0)); + uint64_t DestWidth = TI->getDestTy()->getScalarSizeInBits(); + if (!TI->hasNoUnsignedWrap()) { + if (Range.getActiveBits() <= DestWidth) { + TI->setHasNoUnsignedWrap(true); + Changed = true; + } + } + if (!TI->hasNoSignedWrap()) { + if (Range.getMinSignedBits() <= DestWidth) { + TI->setHasNoSignedWrap(true); + Changed = true; + } + } } return Changed; @@ -170,14 +183,16 @@ static bool replaceSignedInst(SCCPSolver &Solver, Instruction *NewInst = nullptr; switch (Inst.getOpcode()) { - // Note: We do not fold sitofp -> uitofp here because that could be more - // expensive in codegen and may not be reversible in the backend. + case Instruction::SIToFP: case Instruction::SExt: { - // If the source value is not negative, this is a zext. 
+ // If the source value is not negative, this is a zext/uitofp. Value *Op0 = Inst.getOperand(0); if (InsertedValues.count(Op0) || !isNonNegative(Op0)) return false; - NewInst = new ZExtInst(Op0, Inst.getType(), "", &Inst); + NewInst = CastInst::Create(Inst.getOpcode() == Instruction::SExt + ? Instruction::ZExt + : Instruction::UIToFP, + Op0, Inst.getType(), "", Inst.getIterator()); NewInst->setNonNeg(); break; } @@ -186,7 +201,7 @@ static bool replaceSignedInst(SCCPSolver &Solver, Value *Op0 = Inst.getOperand(0); if (InsertedValues.count(Op0) || !isNonNegative(Op0)) return false; - NewInst = BinaryOperator::CreateLShr(Op0, Inst.getOperand(1), "", &Inst); + NewInst = BinaryOperator::CreateLShr(Op0, Inst.getOperand(1), "", Inst.getIterator()); NewInst->setIsExact(Inst.isExact()); break; } @@ -199,7 +214,7 @@ static bool replaceSignedInst(SCCPSolver &Solver, return false; auto NewOpcode = Inst.getOpcode() == Instruction::SDiv ? Instruction::UDiv : Instruction::URem; - NewInst = BinaryOperator::Create(NewOpcode, Op0, Op1, "", &Inst); + NewInst = BinaryOperator::Create(NewOpcode, Op0, Op1, "", Inst.getIterator()); if (Inst.getOpcode() == Instruction::SDiv) NewInst->setIsExact(Inst.isExact()); break; @@ -213,6 +228,7 @@ static bool replaceSignedInst(SCCPSolver &Solver, NewInst->takeName(&Inst); InsertedValues.insert(NewInst); Inst.replaceAllUsesWith(NewInst); + NewInst->setDebugLoc(Inst.getDebugLoc()); Solver.removeLatticeValueFor(&Inst); Inst.eraseFromParent(); return true; @@ -292,7 +308,8 @@ bool SCCPSolver::removeNonFeasibleEdges(BasicBlock *BB, DomTreeUpdater &DTU, Updates.push_back({DominatorTree::Delete, BB, Succ}); } - BranchInst::Create(OnlyFeasibleSuccessor, BB); + Instruction *BI = BranchInst::Create(OnlyFeasibleSuccessor, BB); + BI->setDebugLoc(TI->getDebugLoc()); TI->eraseFromParent(); DTU.applyUpdatesPermissive(Updates); } else if (FeasibleSuccessors.size() > 1) { @@ -428,6 +445,13 @@ private: return markConstant(ValueState[V], V, C); } + /// markConstantRange - Mark the object as constant range with \p CR. If the + /// object is not a constant range with the range \p CR, add it to the + /// instruction work list so that the users of the instruction are updated + /// later. + bool markConstantRange(ValueLatticeElement &IV, Value *V, + const ConstantRange &CR); + // markOverdefined - Make a value be marked as "overdefined". If the // value is not already overdefined, add it to the overdefined instruction // work list so that the users of the instruction are updated later. @@ -788,6 +812,17 @@ public: markOverdefined(ValueState[V], V); } + void trackValueOfArgument(Argument *A) { + if (A->getType()->isIntOrIntVectorTy()) { + if (std::optional<ConstantRange> Range = A->getRange()) { + markConstantRange(ValueState[A], A, *Range); + return; + } + } + // Assume nothing about the incoming arguments without range. 
+ markOverdefined(A); + } + bool isStructLatticeConstant(Function *F, StructType *STy); Constant *getConstant(const ValueLatticeElement &LV, Type *Ty) const; @@ -873,6 +908,15 @@ bool SCCPInstVisitor::markConstant(ValueLatticeElement &IV, Value *V, return true; } +bool SCCPInstVisitor::markConstantRange(ValueLatticeElement &IV, Value *V, + const ConstantRange &CR) { + if (!IV.markConstantRange(CR)) + return false; + LLVM_DEBUG(dbgs() << "markConstantRange: " << CR << ": " << *V << '\n'); + pushToWorkList(IV, V); + return true; +} + bool SCCPInstVisitor::markOverdefined(ValueLatticeElement &IV, Value *V) { if (!IV.markOverdefined()) return false; @@ -1245,23 +1289,17 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) { return (void)markConstant(&I, C); } - if (I.getDestTy()->isIntegerTy() && I.getSrcTy()->isIntOrIntVectorTy()) { + // Ignore bitcasts, as they may change the number of vector elements. + if (I.getDestTy()->isIntOrIntVectorTy() && + I.getSrcTy()->isIntOrIntVectorTy() && + I.getOpcode() != Instruction::BitCast) { auto &LV = getValueState(&I); - ConstantRange OpRange = getConstantRange(OpSt, I.getSrcTy()); + ConstantRange OpRange = + OpSt.asConstantRange(I.getSrcTy(), /*UndefAllowed=*/false); Type *DestTy = I.getDestTy(); - // Vectors where all elements have the same known constant range are treated - // as a single constant range in the lattice. When bitcasting such vectors, - // there is a mis-match between the width of the lattice value (single - // constant range) and the original operands (vector). Go to overdefined in - // that case. - if (I.getOpcode() == Instruction::BitCast && - I.getOperand(0)->getType()->isVectorTy() && - OpRange.getBitWidth() < DL.getTypeSizeInBits(DestTy)) - return (void)markOverdefined(&I); - ConstantRange Res = - OpRange.castOp(I.getOpcode(), DL.getTypeSizeInBits(DestTy)); + OpRange.castOp(I.getOpcode(), DestTy->getScalarSizeInBits()); mergeInValue(LV, &I, ValueLatticeElement::getRange(Res)); } else markOverdefined(&I); @@ -1279,8 +1317,8 @@ void SCCPInstVisitor::handleExtractOfWithOverflow(ExtractValueInst &EVI, return; // Wait to resolve. Type *Ty = LHS->getType(); - ConstantRange LR = getConstantRange(L, Ty); - ConstantRange RR = getConstantRange(R, Ty); + ConstantRange LR = L.asConstantRange(Ty, /*UndefAllowed=*/false); + ConstantRange RR = R.asConstantRange(Ty, /*UndefAllowed=*/false); if (Idx == 0) { ConstantRange Res = LR.binaryOp(WO->getBinaryOp(), RR); mergeInValue(&EVI, ValueLatticeElement::getRange(Res)); @@ -1480,13 +1518,21 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) { } // Only use ranges for binary operators on integers. - if (!I.getType()->isIntegerTy()) + if (!I.getType()->isIntOrIntVectorTy()) return markOverdefined(&I); // Try to simplify to a constant range. 
- ConstantRange A = getConstantRange(V1State, I.getType()); - ConstantRange B = getConstantRange(V2State, I.getType()); - ConstantRange R = A.binaryOp(cast<BinaryOperator>(&I)->getOpcode(), B); + ConstantRange A = + V1State.asConstantRange(I.getType(), /*UndefAllowed=*/false); + ConstantRange B = + V2State.asConstantRange(I.getType(), /*UndefAllowed=*/false); + + auto *BO = cast<BinaryOperator>(&I); + ConstantRange R = ConstantRange::getEmpty(I.getType()->getScalarSizeInBits()); + if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(BO)) + R = A.overflowingBinaryOp(BO->getOpcode(), B, OBO->getNoWrapKind()); + else + R = A.binaryOp(BO->getOpcode(), B); mergeInValue(&I, ValueLatticeElement::getRange(R)); // TODO: Currently we do not exploit special values that produce something @@ -1575,10 +1621,15 @@ void SCCPInstVisitor::visitStoreInst(StoreInst &SI) { } static ValueLatticeElement getValueFromMetadata(const Instruction *I) { - if (MDNode *Ranges = I->getMetadata(LLVMContext::MD_range)) - if (I->getType()->isIntegerTy()) + if (I->getType()->isIntOrIntVectorTy()) { + if (MDNode *Ranges = I->getMetadata(LLVMContext::MD_range)) return ValueLatticeElement::getRange( getConstantRangeFromMetadata(*Ranges)); + + if (const auto *CB = dyn_cast<CallBase>(I)) + if (std::optional<ConstantRange> Range = CB->getRange()) + return ValueLatticeElement::getRange(*Range); + } if (I->hasMetadata(LLVMContext::MD_nonnull)) return ValueLatticeElement::getNot( ConstantPointerNull::get(cast<PointerType>(I->getType()))); @@ -1757,7 +1808,11 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { // Combine range info for the original value with the new range from the // condition. - auto CopyOfCR = getConstantRange(CopyOfVal, CopyOf->getType()); + auto CopyOfCR = CopyOfVal.asConstantRange(CopyOf->getType(), + /*UndefAllowed=*/true); + // Treat an unresolved input like a full range. + if (CopyOfCR.isEmptySet()) + CopyOfCR = ConstantRange::getFull(CopyOfCR.getBitWidth()); auto NewCR = ImposedCR.intersectWith(CopyOfCR); // If the existing information is != x, do not use the information from // a chained predicate, as the != x information is more likely to be @@ -1802,7 +1857,8 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { const ValueLatticeElement &State = getValueState(Op); if (State.isUnknownOrUndef()) return; - OpRanges.push_back(getConstantRange(State, Op->getType())); + OpRanges.push_back( + State.asConstantRange(Op->getType(), /*UndefAllowed=*/false)); } ConstantRange Result = @@ -2084,6 +2140,10 @@ const SmallPtrSet<Function *, 16> SCCPSolver::getMRVFunctionsTracked() { void SCCPSolver::markOverdefined(Value *V) { Visitor->markOverdefined(V); } +void SCCPSolver::trackValueOfArgument(Argument *V) { + Visitor->trackValueOfArgument(V); +} + bool SCCPSolver::isStructLatticeConstant(Function *F, StructType *STy) { return Visitor->isStructLatticeConstant(F, STy); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SSAUpdater.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SSAUpdater.cpp index fc21fb552137..7fd3e51e141f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SSAUpdater.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SSAUpdater.cpp @@ -136,9 +136,9 @@ Value *SSAUpdater::GetValueInMiddleOfBlock(BasicBlock *BB) { } } - // If there are no predecessors, just return undef. + // If there are no predecessors, just return poison. 
if (PredValues.empty()) - return UndefValue::get(ProtoType); + return PoisonValue::get(ProtoType); // Otherwise, if all the merged values are the same, just use it. if (SingularValue) @@ -167,7 +167,7 @@ Value *SSAUpdater::GetValueInMiddleOfBlock(BasicBlock *BB) { // See if the PHI node can be merged to a single value. This can happen in // loop cases when we get a PHI of itself and one other value. if (Value *V = - simplifyInstruction(InsertedPHI, BB->getModule()->getDataLayout())) { + simplifyInstruction(InsertedPHI, BB->getDataLayout())) { InsertedPHI->eraseFromParent(); return V; } @@ -199,17 +199,17 @@ void SSAUpdater::RewriteUse(Use &U) { void SSAUpdater::UpdateDebugValues(Instruction *I) { SmallVector<DbgValueInst *, 4> DbgValues; - SmallVector<DPValue *, 4> DPValues; - llvm::findDbgValues(DbgValues, I, &DPValues); + SmallVector<DbgVariableRecord *, 4> DbgVariableRecords; + llvm::findDbgValues(DbgValues, I, &DbgVariableRecords); for (auto &DbgValue : DbgValues) { if (DbgValue->getParent() == I->getParent()) continue; UpdateDebugValue(I, DbgValue); } - for (auto &DPV : DPValues) { - if (DPV->getParent() == I->getParent()) + for (auto &DVR : DbgVariableRecords) { + if (DVR->getParent() == I->getParent()) continue; - UpdateDebugValue(I, DPV); + UpdateDebugValue(I, DVR); } } @@ -220,10 +220,10 @@ void SSAUpdater::UpdateDebugValues(Instruction *I, } } -void SSAUpdater::UpdateDebugValues(Instruction *I, - SmallVectorImpl<DPValue *> &DPValues) { - for (auto &DPV : DPValues) { - UpdateDebugValue(I, DPV); +void SSAUpdater::UpdateDebugValues( + Instruction *I, SmallVectorImpl<DbgVariableRecord *> &DbgVariableRecords) { + for (auto &DVR : DbgVariableRecords) { + UpdateDebugValue(I, DVR); } } @@ -236,13 +236,13 @@ void SSAUpdater::UpdateDebugValue(Instruction *I, DbgValueInst *DbgValue) { DbgValue->setKillLocation(); } -void SSAUpdater::UpdateDebugValue(Instruction *I, DPValue *DPV) { - BasicBlock *UserBB = DPV->getParent(); +void SSAUpdater::UpdateDebugValue(Instruction *I, DbgVariableRecord *DVR) { + BasicBlock *UserBB = DVR->getParent(); if (HasValueForBlock(UserBB)) { Value *NewVal = GetValueAtEndOfBlock(UserBB); - DPV->replaceVariableLocationOp(I, NewVal); + DVR->replaceVariableLocationOp(I, NewVal); } else - DPV->setKillLocation(); + DVR->setKillLocation(); } void SSAUpdater::RewriteUseAfterInsertions(Use &U) { @@ -307,10 +307,10 @@ public: append_range(*Preds, predecessors(BB)); } - /// GetUndefVal - Get an undefined value of the same type as the value + /// GetPoisonVal - Get a poison value of the same type as the value /// being handled. - static Value *GetUndefVal(BasicBlock *BB, SSAUpdater *Updater) { - return UndefValue::get(Updater->ProtoType); + static Value *GetPoisonVal(BasicBlock *BB, SSAUpdater *Updater) { + return PoisonValue::get(Updater->ProtoType); } /// CreateEmptyPHI - Create a new PHI instruction in the specified block. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SampleProfileInference.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SampleProfileInference.cpp index 101b70d8def4..54d46117729c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SampleProfileInference.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SampleProfileInference.cpp @@ -1061,7 +1061,7 @@ void initializeNetwork(const ProfiParams &Params, MinCostMaxFlow &Network, assert(NumJumps > 0 && "Too few jumps in a function"); // Introducing dummy source/sink pairs to allow flow circulation. 
- // The nodes corresponding to blocks of the function have indicies in
+ // The nodes corresponding to blocks of the function have indices in
 // the range [0 .. 2 * NumBlocks); the dummy sources/sinks are indexed by the
 // next four values.
 uint64_t S = 2 * NumBlocks;
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index a3951fdf8a15..c7d758aa575e 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -43,6 +43,45 @@ cl::opt<unsigned> llvm::SCEVCheapExpansionBudget(
 using namespace PatternMatch;
+PoisonFlags::PoisonFlags(const Instruction *I) {
+ NUW = false;
+ NSW = false;
+ Exact = false;
+ Disjoint = false;
+ NNeg = false;
+ if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(I)) {
+ NUW = OBO->hasNoUnsignedWrap();
+ NSW = OBO->hasNoSignedWrap();
+ }
+ if (auto *PEO = dyn_cast<PossiblyExactOperator>(I))
+ Exact = PEO->isExact();
+ if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
+ Disjoint = PDI->isDisjoint();
+ if (auto *PNI = dyn_cast<PossiblyNonNegInst>(I))
+ NNeg = PNI->hasNonNeg();
+ if (auto *TI = dyn_cast<TruncInst>(I)) {
+ NUW = TI->hasNoUnsignedWrap();
+ NSW = TI->hasNoSignedWrap();
+ }
+}
+
+void PoisonFlags::apply(Instruction *I) {
+ if (isa<OverflowingBinaryOperator>(I)) {
+ I->setHasNoUnsignedWrap(NUW);
+ I->setHasNoSignedWrap(NSW);
+ }
+ if (isa<PossiblyExactOperator>(I))
+ I->setIsExact(Exact);
+ if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
+ PDI->setIsDisjoint(Disjoint);
+ if (auto *PNI = dyn_cast<PossiblyNonNegInst>(I))
+ PNI->setNonNeg(NNeg);
+ if (isa<TruncInst>(I)) {
+ I->setHasNoUnsignedWrap(NUW);
+ I->setHasNoSignedWrap(NSW);
+ }
+}
+
 /// ReuseOrCreateCast - Arrange for there to be a cast of V to Ty at IP,
 /// reusing an existing cast if a suitable one (= dominating IP) exists, or
 /// creating a new one.
@@ -452,6 +491,16 @@ public:
 }
 Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
+ // Recognize the canonical representation of an unsimplified urem.
+ const SCEV *URemLHS = nullptr;
+ const SCEV *URemRHS = nullptr;
+ if (SE.matchURem(S, URemLHS, URemRHS)) {
+ Value *LHS = expand(URemLHS);
+ Value *RHS = expand(URemRHS);
+ return InsertBinop(Instruction::URem, LHS, RHS, SCEV::FlagAnyWrap,
+ /*IsSafeToHoist*/ false);
+ }
+
 // Collect all the add operands in a loop, along with their associated loops.
 // Iterate in reverse so that constants are emitted last, all else equal, and
 // so that pointer operands are inserted first, which the code below relies on
@@ -724,6 +773,7 @@ bool SCEVExpander::hoistIVInc(Instruction *IncV, Instruction *InsertPos,
 auto FixupPoisonFlags = [this](Instruction *I) {
 // Drop flags that are potentially inferred from old context and infer flags
 // in new context.
+ rememberFlags(I); I->dropPoisonGeneratingFlags(); if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(I)) if (auto Flags = SE.getStrengthenedNoWrapFlagsFromBinOp(OBO)) { @@ -771,6 +821,15 @@ bool SCEVExpander::hoistIVInc(Instruction *IncV, Instruction *InsertPos, return true; } +bool SCEVExpander::canReuseFlagsFromOriginalIVInc(PHINode *OrigPhi, + PHINode *WidePhi, + Instruction *OrigInc, + Instruction *WideInc) { + return match(OrigInc, m_c_BinOp(m_Specific(OrigPhi), m_Value())) && + match(WideInc, m_c_BinOp(m_Specific(WidePhi), m_Value())) && + OrigInc->getOpcode() == WideInc->getOpcode(); +} + /// Determine if this cyclic phi is in a form that would have been generated by /// LSR. We don't care if the phi was actually expanded in this pass, as long /// as it is in a low-cost form, for example, no implied multiplication. This @@ -795,7 +854,8 @@ Value *SCEVExpander::expandIVInc(PHINode *PN, Value *StepV, const Loop *L, Value *IncV; // If the PHI is a pointer, use a GEP, otherwise use an add or sub. if (PN->getType()->isPointerTy()) { - IncV = expandAddToGEP(SE.getSCEV(StepV), PN); + // TODO: Change name to IVName.iv.next. + IncV = Builder.CreatePtrAdd(PN, StepV, "scevgep"); } else { IncV = useSubtract ? Builder.CreateSub(PN, StepV, Twine(IVName) + ".iv.next") : @@ -1010,14 +1070,11 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, // Create the PHI. BasicBlock *Header = L->getHeader(); Builder.SetInsertPoint(Header, Header->begin()); - pred_iterator HPB = pred_begin(Header), HPE = pred_end(Header); - PHINode *PN = Builder.CreatePHI(ExpandTy, std::distance(HPB, HPE), - Twine(IVName) + ".iv"); + PHINode *PN = + Builder.CreatePHI(ExpandTy, pred_size(Header), Twine(IVName) + ".iv"); // Create the step instructions and populate the PHI. - for (pred_iterator HPI = HPB; HPI != HPE; ++HPI) { - BasicBlock *Pred = *HPI; - + for (BasicBlock *Pred : predecessors(Header)) { // Add a start value. if (!L->contains(Pred)) { PN->addIncoming(StartV, Pred); @@ -1228,7 +1285,7 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { // corresponding to the back-edge. Instruction *Add = BinaryOperator::CreateAdd(CanonicalIV, One, "indvar.next", - HP->getTerminator()); + HP->getTerminator()->getIterator()); Add->setDebugLoc(HP->getTerminator()->getDebugLoc()); rememberInstruction(Add); CanonicalIV->addIncoming(Add, HP); @@ -1474,7 +1531,8 @@ Value *SCEVExpander::expand(const SCEV *S) { V = fixupLCSSAFormFor(V); } else { for (Instruction *I : DropPoisonGeneratingInsts) { - I->dropPoisonGeneratingFlagsAndMetadata(); + rememberFlags(I); + I->dropPoisonGeneratingAnnotations(); // See if we can re-infer from first principles any of the flags we just // dropped. if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(I)) @@ -1514,6 +1572,99 @@ void SCEVExpander::rememberInstruction(Value *I) { DoInsert(I); } +void SCEVExpander::rememberFlags(Instruction *I) { + // If we already have flags for the instruction, keep the existing ones. 
+ OrigFlags.try_emplace(I, PoisonFlags(I)); +} + +void SCEVExpander::replaceCongruentIVInc( + PHINode *&Phi, PHINode *&OrigPhi, Loop *L, const DominatorTree *DT, + SmallVectorImpl<WeakTrackingVH> &DeadInsts) { + BasicBlock *LatchBlock = L->getLoopLatch(); + if (!LatchBlock) + return; + + Instruction *OrigInc = + dyn_cast<Instruction>(OrigPhi->getIncomingValueForBlock(LatchBlock)); + Instruction *IsomorphicInc = + dyn_cast<Instruction>(Phi->getIncomingValueForBlock(LatchBlock)); + if (!OrigInc || !IsomorphicInc) + return; + + // If this phi has the same width but is more canonical, replace the + // original with it. As part of the "more canonical" determination, + // respect a prior decision to use an IV chain. + if (OrigPhi->getType() == Phi->getType() && + !(ChainedPhis.count(Phi) || + isExpandedAddRecExprPHI(OrigPhi, OrigInc, L)) && + (ChainedPhis.count(Phi) || + isExpandedAddRecExprPHI(Phi, IsomorphicInc, L))) { + std::swap(OrigPhi, Phi); + std::swap(OrigInc, IsomorphicInc); + } + + // Replacing the congruent phi is sufficient because acyclic + // redundancy elimination, CSE/GVN, should handle the + // rest. However, once SCEV proves that a phi is congruent, + // it's often the head of an IV user cycle that is isomorphic + // with the original phi. It's worth eagerly cleaning up the + // common case of a single IV increment so that DeleteDeadPHIs + // can remove cycles that had postinc uses. + // Because we may potentially introduce a new use of OrigIV that didn't + // exist before at this point, its poison flags need readjustment. + const SCEV *TruncExpr = + SE.getTruncateOrNoop(SE.getSCEV(OrigInc), IsomorphicInc->getType()); + if (OrigInc == IsomorphicInc || TruncExpr != SE.getSCEV(IsomorphicInc) || + !SE.LI.replacementPreservesLCSSAForm(IsomorphicInc, OrigInc)) + return; + + bool BothHaveNUW = false; + bool BothHaveNSW = false; + auto *OBOIncV = dyn_cast<OverflowingBinaryOperator>(OrigInc); + auto *OBOIsomorphic = dyn_cast<OverflowingBinaryOperator>(IsomorphicInc); + if (OBOIncV && OBOIsomorphic) { + BothHaveNUW = + OBOIncV->hasNoUnsignedWrap() && OBOIsomorphic->hasNoUnsignedWrap(); + BothHaveNSW = + OBOIncV->hasNoSignedWrap() && OBOIsomorphic->hasNoSignedWrap(); + } + + if (!hoistIVInc(OrigInc, IsomorphicInc, + /*RecomputePoisonFlags*/ true)) + return; + + // We are replacing with a wider increment. If both OrigInc and IsomorphicInc + // are NUW/NSW, then we can preserve them on the wider increment; the narrower + // IsomorphicInc would wrap before the wider OrigInc, so the replacement won't + // make IsomorphicInc's uses more poisonous. 
+ assert(OrigInc->getType()->getScalarSizeInBits() >= + IsomorphicInc->getType()->getScalarSizeInBits() && + "Should only replace an increment with a wider one."); + if (BothHaveNUW || BothHaveNSW) { + OrigInc->setHasNoUnsignedWrap(OBOIncV->hasNoUnsignedWrap() || BothHaveNUW); + OrigInc->setHasNoSignedWrap(OBOIncV->hasNoSignedWrap() || BothHaveNSW); + } + + SCEV_DEBUG_WITH_TYPE(DebugType, + dbgs() << "INDVARS: Eliminated congruent iv.inc: " + << *IsomorphicInc << '\n'); + Value *NewInc = OrigInc; + if (OrigInc->getType() != IsomorphicInc->getType()) { + BasicBlock::iterator IP; + if (PHINode *PN = dyn_cast<PHINode>(OrigInc)) + IP = PN->getParent()->getFirstInsertionPt(); + else + IP = OrigInc->getNextNonDebugInstruction()->getIterator(); + + IRBuilder<> Builder(IP->getParent(), IP); + Builder.SetCurrentDebugLocation(IsomorphicInc->getDebugLoc()); + NewInc = + Builder.CreateTruncOrBitCast(OrigInc, IsomorphicInc->getType(), IVName); + } + IsomorphicInc->replaceAllUsesWith(NewInc); + DeadInsts.emplace_back(IsomorphicInc); +} + /// replaceCongruentIVs - Check for congruent phis in this loop header and /// replace them with their most canonical representative. Return the number of /// phis eliminated. @@ -1599,60 +1750,7 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, if (OrigPhiRef->getType()->isPointerTy() != Phi->getType()->isPointerTy()) continue; - if (BasicBlock *LatchBlock = L->getLoopLatch()) { - Instruction *OrigInc = dyn_cast<Instruction>( - OrigPhiRef->getIncomingValueForBlock(LatchBlock)); - Instruction *IsomorphicInc = - dyn_cast<Instruction>(Phi->getIncomingValueForBlock(LatchBlock)); - - if (OrigInc && IsomorphicInc) { - // If this phi has the same width but is more canonical, replace the - // original with it. As part of the "more canonical" determination, - // respect a prior decision to use an IV chain. - if (OrigPhiRef->getType() == Phi->getType() && - !(ChainedPhis.count(Phi) || - isExpandedAddRecExprPHI(OrigPhiRef, OrigInc, L)) && - (ChainedPhis.count(Phi) || - isExpandedAddRecExprPHI(Phi, IsomorphicInc, L))) { - std::swap(OrigPhiRef, Phi); - std::swap(OrigInc, IsomorphicInc); - } - // Replacing the congruent phi is sufficient because acyclic - // redundancy elimination, CSE/GVN, should handle the - // rest. However, once SCEV proves that a phi is congruent, - // it's often the head of an IV user cycle that is isomorphic - // with the original phi. It's worth eagerly cleaning up the - // common case of a single IV increment so that DeleteDeadPHIs - // can remove cycles that had postinc uses. - // Because we may potentially introduce a new use of OrigIV that didn't - // exist before at this point, its poison flags need readjustment. 
- const SCEV *TruncExpr = - SE.getTruncateOrNoop(SE.getSCEV(OrigInc), IsomorphicInc->getType()); - if (OrigInc != IsomorphicInc && - TruncExpr == SE.getSCEV(IsomorphicInc) && - SE.LI.replacementPreservesLCSSAForm(IsomorphicInc, OrigInc) && - hoistIVInc(OrigInc, IsomorphicInc, /*RecomputePoisonFlags*/ true)) { - SCEV_DEBUG_WITH_TYPE( - DebugType, dbgs() << "INDVARS: Eliminated congruent iv.inc: " - << *IsomorphicInc << '\n'); - Value *NewInc = OrigInc; - if (OrigInc->getType() != IsomorphicInc->getType()) { - BasicBlock::iterator IP; - if (PHINode *PN = dyn_cast<PHINode>(OrigInc)) - IP = PN->getParent()->getFirstInsertionPt(); - else - IP = OrigInc->getNextNonDebugInstruction()->getIterator(); - - IRBuilder<> Builder(IP->getParent(), IP); - Builder.SetCurrentDebugLocation(IsomorphicInc->getDebugLoc()); - NewInc = Builder.CreateTruncOrBitCast( - OrigInc, IsomorphicInc->getType(), IVName); - } - IsomorphicInc->replaceAllUsesWith(NewInc); - DeadInsts.emplace_back(IsomorphicInc); - } - } - } + replaceCongruentIVInc(Phi, OrigPhiRef, L, DT, DeadInsts); SCEV_DEBUG_WITH_TYPE(DebugType, dbgs() << "INDVARS: Eliminated congruent iv: " << *Phi << '\n'); @@ -1991,7 +2089,7 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, // FIXME: It is highly suspicious that we're ignoring the predicates here. SmallVector<const SCEVPredicate *, 4> Pred; const SCEV *ExitCount = - SE.getPredicatedBackedgeTakenCount(AR->getLoop(), Pred); + SE.getPredicatedSymbolicMaxBackedgeTakenCount(AR->getLoop(), Pred); assert(!isa<SCEVCouldNotCompute>(ExitCount) && "Invalid loop count"); @@ -2152,7 +2250,7 @@ Value *SCEVExpander::fixupLCSSAFormFor(Value *V) { if (!PreserveLCSSA || !DefI) return V; - Instruction *InsertPt = &*Builder.GetInsertPoint(); + BasicBlock::iterator InsertPt = Builder.GetInsertPoint(); Loop *DefLoop = SE.LI.getLoopFor(DefI->getParent()); Loop *UseLoop = SE.LI.getLoopFor(InsertPt->getParent()); if (!DefLoop || UseLoop == DefLoop || DefLoop->contains(UseLoop)) @@ -2276,6 +2374,10 @@ void SCEVExpanderCleaner::cleanup() { if (ResultUsed) return; + // Restore original poison flags. + for (auto [I, Flags] : Expander.OrigFlags) + Flags.apply(I); + auto InsertedInstructions = Expander.getAllInsertedInstructions(); #ifndef NDEBUG SmallPtrSet<Instruction *, 8> InsertedSet(InsertedInstructions.begin(), diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index f95dae1842fe..f23e28888931 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -51,6 +51,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/MemoryModelRelaxationAnnotations.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/NoFolder.h" @@ -860,26 +861,28 @@ static bool ValuesOverlap(std::vector<ValueEqualityComparisonCase> &C1, // Set branch weights on SwitchInst. This sets the metadata if there is at // least one non-zero weight. -static void setBranchWeights(SwitchInst *SI, ArrayRef<uint32_t> Weights) { +static void setBranchWeights(SwitchInst *SI, ArrayRef<uint32_t> Weights, + bool IsExpected) { // Check that there is at least one non-zero weight. Otherwise, pass // nullptr to setMetadata which will erase the existing metadata. 
 MDNode *N = nullptr;
 if (llvm::any_of(Weights, [](uint32_t W) { return W != 0; }))
- N = MDBuilder(SI->getParent()->getContext()).createBranchWeights(Weights);
+ N = MDBuilder(SI->getParent()->getContext())
+ .createBranchWeights(Weights, IsExpected);
 SI->setMetadata(LLVMContext::MD_prof, N);
 }
 // Similar to the above, but for branch and select instructions that take
 // exactly 2 weights.
 static void setBranchWeights(Instruction *I, uint32_t TrueWeight,
- uint32_t FalseWeight) {
+ uint32_t FalseWeight, bool IsExpected) {
 assert(isa<BranchInst>(I) || isa<SelectInst>(I));
 // Check that there is at least one non-zero weight. Otherwise, pass
 // nullptr to setMetadata which will erase the existing metadata.
 MDNode *N = nullptr;
 if (TrueWeight || FalseWeight)
 N = MDBuilder(I->getParent()->getContext())
- .createBranchWeights(TrueWeight, FalseWeight);
+ .createBranchWeights(TrueWeight, FalseWeight, IsExpected);
 I->setMetadata(LLVMContext::MD_prof, N);
 }
@@ -1065,11 +1068,8 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1,
 static void GetBranchWeights(Instruction *TI,
 SmallVectorImpl<uint64_t> &Weights) {
 MDNode *MD = TI->getMetadata(LLVMContext::MD_prof);
- assert(MD);
- for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
- ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(i));
- Weights.push_back(CI->getValue().getZExtValue());
- }
+ assert(MD && "Invalid branch-weight metadata");
+ extractFromBranchWeightMD64(MD, Weights);
 // If TI is a conditional eq, the default case is the false case,
 // and the corresponding branch-weight data is at index 2. We swap the
@@ -1084,7 +1084,7 @@ static void GetBranchWeights(Instruction *TI,
 /// Keep halving the weights until all can fit in uint32_t.
 static void FitWeights(MutableArrayRef<uint64_t> Weights) {
- uint64_t Max = *std::max_element(Weights.begin(), Weights.end());
+ uint64_t Max = *llvm::max_element(Weights);
 if (Max > UINT_MAX) {
 unsigned Offset = 32 - llvm::countl_zero(Max);
 for (uint64_t &I : Weights)
@@ -1126,8 +1126,8 @@ static void CloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(
 NewBonusInst->insertInto(PredBlock, PTI->getIterator());
 auto Range = NewBonusInst->cloneDebugInfoFrom(&BonusInst);
- RemapDPValueRange(NewBonusInst->getModule(), Range, VMap,
- RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+ RemapDbgRecordRange(NewBonusInst->getModule(), Range, VMap,
+ RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
 if (isa<DbgInfoIntrinsic>(BonusInst))
 continue;
@@ -1340,7 +1340,7 @@ bool SimplifyCFGOpt::PerformValueComparisonIntoPredecessorFolding(
 SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end());
- setBranchWeights(NewSI, MDWeights);
+ setBranchWeights(NewSI, MDWeights, /*IsExpected=*/false);
 }
 EraseTerminatorAndDCECond(PTI);
@@ -1526,6 +1526,63 @@ static bool shouldHoistCommonInstructions(Instruction *I1, Instruction *I2,
 return true;
 }
+/// Hoists DbgVariableRecords from \p I1 and \p OtherInstrs that are identical
+/// in lock-step to \p TI. This matches how dbg.* intrinsics are hoisted in
+/// hoistCommonCodeFromSuccessors. e.g. The input:
+/// I1 DVRs: { x, z },
+/// OtherInsts: { I2 DVRs: { x, y, z } }
+/// would result in hoisting only DbgVariableRecord x.
+static void hoistLockstepIdenticalDbgVariableRecords(
+ Instruction *TI, Instruction *I1,
+ SmallVectorImpl<Instruction *> &OtherInsts) {
+ if (!I1->hasDbgRecords())
+ return;
+ using CurrentAndEndIt =
+ std::pair<DbgRecord::self_iterator, DbgRecord::self_iterator>;
+ // Vector of {Current, End} iterators.
+ SmallVector<CurrentAndEndIt> Itrs;
+ Itrs.reserve(OtherInsts.size() + 1);
+ // Helper lambdas for lock-step checks:
+ // Return true if this Current == End.
+ auto atEnd = [](const CurrentAndEndIt &Pair) {
+ return Pair.first == Pair.second;
+ };
+ // Return true if all Current are identical.
+ auto allIdentical = [](const SmallVector<CurrentAndEndIt> &Itrs) {
+ return all_of(make_first_range(ArrayRef(Itrs).drop_front()),
+ [&](DbgRecord::self_iterator I) {
+ return Itrs[0].first->isIdenticalToWhenDefined(*I);
+ });
+ };
+
+ // Collect the iterators.
+ Itrs.push_back(
+ {I1->getDbgRecordRange().begin(), I1->getDbgRecordRange().end()});
+ for (Instruction *Other : OtherInsts) {
+ if (!Other->hasDbgRecords())
+ return;
+ Itrs.push_back(
+ {Other->getDbgRecordRange().begin(), Other->getDbgRecordRange().end()});
+ }
+
+ // Iterate in lock-step until any of the DbgRecord lists are exhausted. If
+ // the lock-step DbgRecords are identical, hoist all of them to TI.
+ // This replicates the dbg.* intrinsic behaviour in
+ // hoistCommonCodeFromSuccessors.
+ while (none_of(Itrs, atEnd)) {
+ bool HoistDVRs = allIdentical(Itrs);
+ for (CurrentAndEndIt &Pair : Itrs) {
+ // Increment Current iterator now as we may be about to move the
+ // DbgRecord.
+ DbgRecord &DR = *Pair.first++;
+ if (HoistDVRs) {
+ DR.removeFromParent();
+ TI->getParent()->insertDbgRecordBefore(&DR, TI->getIterator());
+ }
+ }
+ }
+}
+
 /// Hoist any common code in the successor blocks up into the block. This
 /// function guarantees that BB dominates all successors. If EqTermsOnly is
 /// given, only perform hoisting in case both blocks only contain a terminator.
@@ -1598,7 +1655,6 @@ bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(BasicBlock *BB,
 auto OtherSuccIterRange = make_first_range(OtherSuccIterPairRange);
 Instruction *I1 = &*BB1ItrPair.first;
- auto *BB1 = I1->getParent();
 // Skip debug info if it is not identical.
 bool AllDbgInstsAreIdentical = all_of(OtherSuccIterRange, [I1](auto &Iter) {
@@ -1620,22 +1676,28 @@ bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(BasicBlock *BB,
 for (auto &SuccIter : OtherSuccIterRange) {
 Instruction *I2 = &*SuccIter;
 HasTerminator |= I2->isTerminator();
- if (AllInstsAreIdentical && !I1->isIdenticalToWhenDefined(I2))
+ if (AllInstsAreIdentical && (!I1->isIdenticalToWhenDefined(I2) ||
+ MMRAMetadata(*I1) != MMRAMetadata(*I2)))
 AllInstsAreIdentical = false;
 }
+ SmallVector<Instruction *, 8> OtherInsts;
+ for (auto &SuccIter : OtherSuccIterRange)
+ OtherInsts.push_back(&*SuccIter);
+
 // If we are hoisting the terminator instruction, don't move one (making a
 // broken BB), instead clone it, and remove BI.
 if (HasTerminator) {
 // Even if BB, which contains only one unreachable instruction, is ignored
 // at the beginning of the loop, we can hoist the terminator instruction.
 // If any instructions remain in the block, we cannot hoist terminators.
- if (NumSkipped || !AllInstsAreIdentical) + if (NumSkipped || !AllInstsAreIdentical) { + hoistLockstepIdenticalDbgVariableRecords(TI, I1, OtherInsts); return Changed; - SmallVector<Instruction *, 8> Insts; - for (auto &SuccIter : OtherSuccIterRange) - Insts.push_back(&*SuccIter); - return hoistSuccIdenticalTerminatorToSwitchOrIf(TI, I1, Insts) || Changed; + } + + return hoistSuccIdenticalTerminatorToSwitchOrIf(TI, I1, OtherInsts) || + Changed; } if (AllInstsAreIdentical) { @@ -1660,18 +1722,25 @@ bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(BasicBlock *BB, // The debug location is an integral part of a debug info intrinsic // and can't be separated from it or replaced. Instead of attempting // to merge locations, simply hoist both copies of the intrinsic. - I1->moveBeforePreserving(TI); + hoistLockstepIdenticalDbgVariableRecords(TI, I1, OtherInsts); + // We've just hoisted DbgVariableRecords; move I1 after them (before TI) + // and leave any that were not hoisted behind (by calling moveBefore + // rather than moveBeforePreserving). + I1->moveBefore(TI); for (auto &SuccIter : OtherSuccIterRange) { auto *I2 = &*SuccIter++; assert(isa<DbgInfoIntrinsic>(I2)); - I2->moveBeforePreserving(TI); + I2->moveBefore(TI); } } else { // For a normal instruction, we just move one to right before the // branch, then replace all uses of the other with the first. Finally, // we remove the now redundant second instruction. - I1->moveBeforePreserving(TI); - BB->splice(TI->getIterator(), BB1, I1->getIterator()); + hoistLockstepIdenticalDbgVariableRecords(TI, I1, OtherInsts); + // We've just hoisted DbgVariableRecords; move I1 after them (before TI) + // and leave any that were not hoisted behind (by calling moveBefore + // rather than moveBeforePreserving). + I1->moveBefore(TI); for (auto &SuccIter : OtherSuccIterRange) { Instruction *I2 = &*SuccIter++; assert(I2 != I1); @@ -1690,8 +1759,10 @@ bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(BasicBlock *BB, Changed = true; NumHoistCommonInstrs += SuccIterPairs.size(); } else { - if (NumSkipped >= HoistCommonSkipLimit) + if (NumSkipped >= HoistCommonSkipLimit) { + hoistLockstepIdenticalDbgVariableRecords(TI, I1, OtherInsts); return Changed; + } // We are about to skip over a pair of non-identical instructions. Record // if any have characteristics that would prevent reordering instructions // across them. @@ -1752,7 +1823,10 @@ bool SimplifyCFGOpt::hoistSuccIdenticalTerminatorToSwitchOrIf( } } - // Okay, it is safe to hoist the terminator. + // Hoist DbgVariableRecords attached to the terminator to match dbg.* + // intrinsic hoisting behaviour in hoistCommonCodeFromSuccessors. + hoistLockstepIdenticalDbgVariableRecords(TI, I1, OtherSuccTIs); + // Clone the terminator and hoist it into the pred, without any debug info. Instruction *NT = I1->clone(); NT->insertInto(TIParent, TI->getIterator()); if (!NT->getType()->isVoidTy()) { @@ -1770,11 +1844,6 @@ bool SimplifyCFGOpt::hoistSuccIdenticalTerminatorToSwitchOrIf( Locs.push_back(I1->getDebugLoc()); for (auto *OtherSuccTI : OtherSuccTIs) Locs.push_back(OtherSuccTI->getDebugLoc()); - // Also clone DPValues from the existing terminator, and all others (to - // duplicate existing hoisting behaviour). - NT->cloneDebugInfoFrom(I1); - for (Instruction *OtherSuccTI : OtherSuccTIs) - NT->cloneDebugInfoFrom(OtherSuccTI); NT->setDebugLoc(DILocation::getMergedLocations(Locs)); // PHIs created below will adopt NT's merged DebugLoc. 
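In the setBranchWeights changes earlier in this file's diff, the new IsExpected argument is simply forwarded to MDBuilder::createBranchWeights, which marks the emitted !prof branch weights as coming from an llvm.expect-style annotation rather than from measured profile data. A minimal sketch of that call shape, assuming BI is the branch being annotated and with illustrative weight values:

    MDBuilder MDB(BI->getContext());
    // Tag the weights as "expected" so later passes can tell they originate
    // from a programmer hint rather than from profiling.
    MDNode *Prof = MDB.createBranchWeights(/*TrueWeight=*/2000,
                                           /*FalseWeight=*/1,
                                           /*IsExpected=*/true);
    BI->setMetadata(LLVMContext::MD_prof, Prof);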
@@ -1863,11 +1932,10 @@ static bool replacingOperandWithVariableIsCheap(const Instruction *I, // PHI node (because an operand varies in each input block), add to PHIOperands. static bool canSinkInstructions( ArrayRef<Instruction *> Insts, - DenseMap<Instruction *, SmallVector<Value *, 4>> &PHIOperands) { + DenseMap<const Use *, SmallVector<Value *, 4>> &PHIOperands) { // Prune out obviously bad instructions to move. Each instruction must have - // exactly zero or one use, and we check later that use is by a single, common - // PHI instruction in the successor. - bool HasUse = !Insts.front()->user_empty(); + // the same number of uses, and we check later that the uses are consistent. + std::optional<unsigned> NumUses; for (auto *I : Insts) { // These instructions may change or break semantics if moved. if (isa<PHINode>(I) || I->isEHPad() || isa<AllocaInst>(I) || @@ -1887,14 +1955,14 @@ static bool canSinkInstructions( if (C->isInlineAsm() || C->cannotMerge() || C->isConvergent()) return false; - // Each instruction must have zero or one use. - if (HasUse && !I->hasOneUse()) - return false; - if (!HasUse && !I->user_empty()) + if (!NumUses) + NumUses = I->getNumUses(); + else if (NumUses != I->getNumUses()) return false; } const Instruction *I0 = Insts.front(); + const auto I0MMRA = MMRAMetadata(*I0); for (auto *I : Insts) { if (!I->isSameOperationAs(I0)) return false; @@ -1906,22 +1974,23 @@ static bool canSinkInstructions( return false; if (isa<LoadInst>(I) && I->getOperand(0)->isSwiftError()) return false; + + // Treat MMRAs conservatively. This pass can be quite aggressive and + // could drop a lot of MMRAs otherwise. + if (MMRAMetadata(*I) != I0MMRA) + return false; } - // All instructions in Insts are known to be the same opcode. If they have a - // use, check that the only user is a PHI or in the same block as the - // instruction, because if a user is in the same block as an instruction we're - // contemplating sinking, it must already be determined to be sinkable. - if (HasUse) { - auto *PNUse = dyn_cast<PHINode>(*I0->user_begin()); - auto *Succ = I0->getParent()->getTerminator()->getSuccessor(0); - if (!all_of(Insts, [&PNUse,&Succ](const Instruction *I) -> bool { - auto *U = cast<Instruction>(*I->user_begin()); - return (PNUse && - PNUse->getParent() == Succ && - PNUse->getIncomingValueForBlock(I->getParent()) == I) || - U->getParent() == I->getParent(); - })) + // Uses must be consistent: If I0 is used in a phi node in the sink target, + // then the other phi operands must match the instructions from Insts. This + // also has to hold true for any phi nodes that would be created as a result + // of sinking. Both of these cases are represented by PhiOperands. + for (const Use &U : I0->uses()) { + auto It = PHIOperands.find(&U); + if (It == PHIOperands.end()) + // There may be uses in other blocks when sinking into a loop header. + return false; + if (!equal(Insts, It->second)) return false; } @@ -1988,8 +2057,9 @@ static bool canSinkInstructions( !canReplaceOperandWithVariable(I0, OI)) // We can't create a PHI from this GEP. return false; + auto &Ops = PHIOperands[&I0->getOperandUse(OI)]; for (auto *I : Insts) - PHIOperands[I].push_back(I->getOperand(OI)); + Ops.push_back(I->getOperand(OI)); } } return true; @@ -1998,7 +2068,7 @@ static bool canSinkInstructions( // Assuming canSinkInstructions(Blocks) has returned true, sink the last // instruction of every block in Blocks to their common successor, commoning // into one instruction. 
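For intuition, a source-level sketch of the sinking performed here (the function and variable names are illustrative, not from the patch): every predecessor ends in the same operation applied to different operands, so the differing operands are PHI-merged and a single copy of the operation is placed in the common successor.

    int sink_example(bool c, int x, int y) {
      int r;
      if (c)
        r = x + 1;   // last instruction of predecessor 1
      else
        r = y + 1;   // last instruction of predecessor 2
      return r;      // before sinking: a PHI merges the two adds
    }
    // After sinking, conceptually: r = phi(x, y) + 1 in the common successor;
    // the PHI that merged the two adds is replaced by the single sunk add.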
-static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) { +static void sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) { auto *BBEnd = Blocks[0]->getTerminator()->getSuccessor(0); // canSinkInstructions returning true guarantees that every block has at @@ -2013,23 +2083,10 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) { Insts.push_back(I); } - // The only checking we need to do now is that all users of all instructions - // are the same PHI node. canSinkInstructions should have checked this but - // it is slightly over-aggressive - it gets confused by commutative - // instructions so double-check it here. - Instruction *I0 = Insts.front(); - if (!I0->user_empty()) { - auto *PNUse = dyn_cast<PHINode>(*I0->user_begin()); - if (!all_of(Insts, [&PNUse](const Instruction *I) -> bool { - auto *U = cast<Instruction>(*I->user_begin()); - return U == PNUse; - })) - return false; - } - // We don't need to do any more checking here; canSinkInstructions should // have done it all for us. SmallVector<Value*, 4> NewOperands; + Instruction *I0 = Insts.front(); for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O) { // This check is different to that in canSinkInstructions. There, we // cared about the global view once simplifycfg (and instcombine) have @@ -2078,11 +2135,11 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) { I0->andIRFlags(I); } - if (!I0->user_empty()) { - // canSinkLastInstruction checked that all instructions were used by - // one and only one PHI node. Find that now, RAUW it to our common - // instruction and nuke it. - auto *PN = cast<PHINode>(*I0->user_begin()); + for (User *U : make_early_inc_range(I0->users())) { + // canSinkLastInstruction checked that all instructions are only used by + // phi nodes in a way that allows replacing the phi node with the common + // instruction. + auto *PN = cast<PHINode>(U); PN->replaceAllUsesWith(I0); PN->eraseFromParent(); } @@ -2097,8 +2154,6 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) { I->replaceAllUsesWith(I0); I->eraseFromParent(); } - - return true; } namespace { @@ -2239,9 +2294,19 @@ static bool SinkCommonCodeFromPredecessors(BasicBlock *BB, // carry on. If we can sink an instruction but need to PHI-merge some operands // (because they're not identical in each instruction) we add these to // PHIOperands. + // We prepopulate PHIOperands with the phis that already exist in BB. + DenseMap<const Use *, SmallVector<Value *, 4>> PHIOperands; + for (PHINode &PN : BB->phis()) { + SmallDenseMap<BasicBlock *, const Use *, 4> IncomingVals; + for (const Use &U : PN.incoming_values()) + IncomingVals.insert({PN.getIncomingBlock(U), &U}); + auto &Ops = PHIOperands[IncomingVals[UnconditionalPreds[0]]]; + for (BasicBlock *Pred : UnconditionalPreds) + Ops.push_back(*IncomingVals[Pred]); + } + int ScanIdx = 0; SmallPtrSet<Value*,4> InstructionsToSink; - DenseMap<Instruction*, SmallVector<Value*,4>> PHIOperands; LockstepReverseIterator LRI(UnconditionalPreds); while (LRI.isValid() && canSinkInstructions(*LRI, PHIOperands)) { @@ -2263,20 +2328,19 @@ static bool SinkCommonCodeFromPredecessors(BasicBlock *BB, // actually sink before encountering instruction that is unprofitable to // sink? 
auto ProfitableToSinkInstruction = [&](LockstepReverseIterator &LRI) { - unsigned NumPHIdValues = 0; - for (auto *I : *LRI) - for (auto *V : PHIOperands[I]) { - if (!InstructionsToSink.contains(V)) - ++NumPHIdValues; + unsigned NumPHIInsts = 0; + for (Use &U : (*LRI)[0]->operands()) { + auto It = PHIOperands.find(&U); + if (It != PHIOperands.end() && !all_of(It->second, [&](Value *V) { + return InstructionsToSink.contains(V); + })) { + ++NumPHIInsts; // FIXME: this check is overly optimistic. We may end up not sinking // said instruction, due to the very same profitability check. // See @creating_too_many_phis in sink-common-code.ll. } - LLVM_DEBUG(dbgs() << "SINK: #phid values: " << NumPHIdValues << "\n"); - unsigned NumPHIInsts = NumPHIdValues / UnconditionalPreds.size(); - if ((NumPHIdValues % UnconditionalPreds.size()) != 0) - NumPHIInsts++; - + } + LLVM_DEBUG(dbgs() << "SINK: #phi insts: " << NumPHIInsts << "\n"); return NumPHIInsts <= 1; }; @@ -2401,13 +2465,7 @@ static bool SinkCommonCodeFromPredecessors(BasicBlock *BB, // sink is always at index 0. LRI.reset(); - if (!sinkLastInstruction(UnconditionalPreds)) { - LLVM_DEBUG( - dbgs() - << "SINK: stopping here, failed to actually sink instruction!\n"); - break; - } - + sinkLastInstruction(UnconditionalPreds); NumSinkCommonInstrs++; Changed = true; } @@ -2643,7 +2701,7 @@ static void MergeCompatibleInvokesImpl(ArrayRef<InvokeInst *> Invokes, // Form a PHI out of all the data ops under this index. PHINode *PN = PHINode::Create( - U->getType(), /*NumReservedValues=*/Invokes.size(), "", MergedInvoke); + U->getType(), /*NumReservedValues=*/Invokes.size(), "", MergedInvoke->getIterator()); for (InvokeInst *II : Invokes) PN->addIncoming(II->getOperand(U.getOperandNo()), II->getParent()); @@ -2819,7 +2877,8 @@ static Value *isSafeToSpeculateStore(Instruction *I, BasicBlock *BrBB, // simple, to avoid introducing a spurious non-atomic write after an // atomic write. if (SI->getPointerOperand() == StorePtr && - SI->getValueOperand()->getType() == StoreTy && SI->isSimple()) + SI->getValueOperand()->getType() == StoreTy && SI->isSimple() && + SI->getAlign() >= StoreToHoist->getAlign()) // Found the previous store, return its value operand. return SI->getValueOperand(); return nullptr; // Unknown store. @@ -2827,7 +2886,7 @@ static Value *isSafeToSpeculateStore(Instruction *I, BasicBlock *BrBB, if (auto *LI = dyn_cast<LoadInst>(&CurI)) { if (LI->getPointerOperand() == StorePtr && LI->getType() == StoreTy && - LI->isSimple()) { + LI->isSimple() && LI->getAlign() >= StoreToHoist->getAlign()) { // Local objects (created by an `alloca` instruction) are always // writable, so once we are past a read from a location it is valid to // also write to that same location. @@ -3111,7 +3170,7 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, DbgAssign->replaceVariableLocationOp(OrigV, S); }; for_each(at::getAssignmentMarkers(SpeculatedStore), replaceVariable); - for_each(at::getDPVAssignmentMarkers(SpeculatedStore), replaceVariable); + for_each(at::getDVRAssignmentMarkers(SpeculatedStore), replaceVariable); } // Metadata can be dependent on the condition we are hoisting above. @@ -3136,13 +3195,16 @@ bool SimplifyCFGOpt::SpeculativelyExecuteBB(BranchInst *BI, } // Hoist the instructions. - // In "RemoveDIs" non-instr debug-info mode, drop DPValues attached to these - // instructions, in the same way that dbg.value intrinsics are dropped at the - // end of this block. 
+ // In "RemoveDIs" non-instr debug-info mode, drop DbgVariableRecords attached + // to these instructions, in the same way that dbg.value intrinsics are + // dropped at the end of this block. for (auto &It : make_range(ThenBB->begin(), ThenBB->end())) - for (DPValue &DPV : make_early_inc_range(It.getDbgValueRange())) - if (!DPV.isDbgAssign()) - It.dropOneDbgValue(&DPV); + for (DbgRecord &DR : make_early_inc_range(It.getDbgRecordRange())) + // Drop all records except assign-kind DbgVariableRecords (dbg.assign + // equivalent). + if (DbgVariableRecord *DVR = dyn_cast<DbgVariableRecord>(&DR); + !DVR || !DVR->isDbgAssign()) + It.dropOneDbgRecord(&DR); BB->splice(BI->getIterator(), ThenBB, ThenBB->begin(), std::prev(ThenBB->end())); @@ -3316,7 +3378,7 @@ FoldCondBranchOnValueKnownInPredecessorImpl(BranchInst *BI, DomTreeUpdater *DTU, TranslateMap[Cond] = CB; // RemoveDIs: track instructions that we optimise away while folding, so - // that we can copy DPValues from them later. + // that we can copy DbgVariableRecords from them later. BasicBlock::iterator SrcDbgCursor = BB->begin(); for (BasicBlock::iterator BBI = BB->begin(); &*BBI != BI; ++BBI) { if (PHINode *PN = dyn_cast<PHINode>(BBI)) { @@ -3414,7 +3476,8 @@ static bool FoldCondBranchOnValueKnownInPredecessor(BranchInst *BI, /// Given a BB that starts with the specified two-entry PHI node, /// see if we can eliminate it. static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, - DomTreeUpdater *DTU, const DataLayout &DL) { + DomTreeUpdater *DTU, const DataLayout &DL, + bool SpeculateUnpredictables) { // Ok, this is a two entry PHI node. Check to see if this is a simple "if // statement", which has a very simple dominance structure. Basically, we // are trying to find the condition that is being branched on, which @@ -3446,7 +3509,8 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, // jump to one specific 'then' block (if we have two of them). // It isn't beneficial to speculatively execute the code // from the block that we know is predictably not entered. - if (!DomBI->getMetadata(LLVMContext::MD_unpredictable)) { + bool IsUnpredictable = DomBI->getMetadata(LLVMContext::MD_unpredictable); + if (!IsUnpredictable) { uint64_t TWeight, FWeight; if (extractBranchWeights(*DomBI, TWeight, FWeight) && (TWeight + FWeight) != 0) { @@ -3489,6 +3553,8 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, InstructionCost Cost = 0; InstructionCost Budget = TwoEntryPHINodeFoldingThreshold * TargetTransformInfo::TCC_Basic; + if (SpeculateUnpredictables && IsUnpredictable) + Budget += TTI.getBranchMispredictPenalty(); bool Changed = false; for (BasicBlock::iterator II = BB->begin(); isa<PHINode>(II);) { @@ -3558,8 +3624,9 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, [](BasicBlock *IfBlock) { return IfBlock->hasAddressTaken(); })) return Changed; - LLVM_DEBUG(dbgs() << "FOUND IF CONDITION! " << *IfCond - << " T: " << IfTrue->getName() + LLVM_DEBUG(dbgs() << "FOUND IF CONDITION! 
" << *IfCond; + if (IsUnpredictable) dbgs() << " (unpredictable)"; + dbgs() << " T: " << IfTrue->getName() << " F: " << IfFalse->getName() << "\n"); // If we can still promote the PHI nodes after this gauntlet of tests, @@ -3754,7 +3821,7 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI, FitWeights(NewWeights); SmallVector<uint32_t, 8> MDWeights(NewWeights.begin(), NewWeights.end()); - setBranchWeights(PBI, MDWeights[0], MDWeights[1]); + setBranchWeights(PBI, MDWeights[0], MDWeights[1], /*IsExpected=*/false); // TODO: If BB is reachable from all paths through PredBlock, then we // could replace PBI's branch probabilities with BI's. @@ -3780,9 +3847,10 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI, if (PredBlock->IsNewDbgInfoFormat) { PredBlock->getTerminator()->cloneDebugInfoFrom(BB->getTerminator()); - for (DPValue &DPV : PredBlock->getTerminator()->getDbgValueRange()) { - RemapDPValue(M, &DPV, VMap, - RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + for (DbgVariableRecord &DVR : + filterDbgVars(PredBlock->getTerminator()->getDbgRecordRange())) { + RemapDbgRecord(M, &DVR, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); } } @@ -4490,7 +4558,7 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, // Halve the weights if any of them cannot fit in an uint32_t FitWeights(NewWeights); - setBranchWeights(PBI, NewWeights[0], NewWeights[1]); + setBranchWeights(PBI, NewWeights[0], NewWeights[1], /*IsExpected=*/false); } // OtherDest may have phi nodes. If so, add an entry from PBI's @@ -4526,7 +4594,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, FitWeights(NewWeights); - setBranchWeights(NV, NewWeights[0], NewWeights[1]); + setBranchWeights(NV, NewWeights[0], NewWeights[1], + /*IsExpected=*/false); } } } @@ -4589,12 +4658,12 @@ bool SimplifyCFGOpt::SimplifyTerminatorOnSelect(Instruction *OldTerm, // Create a conditional branch sharing the condition of the select. BranchInst *NewBI = Builder.CreateCondBr(Cond, TrueBB, FalseBB); if (TrueWeight != FalseWeight) - setBranchWeights(NewBI, TrueWeight, FalseWeight); + setBranchWeights(NewBI, TrueWeight, FalseWeight, /*IsExpected=*/false); } } else if (KeepEdge1 && (KeepEdge2 || TrueBB == FalseBB)) { // Neither of the selected blocks were successors, so this // terminator must be unreachable. - new UnreachableInst(OldTerm->getContext(), OldTerm); + new UnreachableInst(OldTerm->getContext(), OldTerm->getIterator()); } else { // One of the selected values was a successor, but the other wasn't. // Insert an unconditional branch to the one that was found; @@ -4838,7 +4907,7 @@ bool SimplifyCFGOpt::SimplifyBranchOnICmpChain(BranchInst *BI, // There might be duplicate constants in the list, which the switch // instruction can't handle, remove them now. array_pod_sort(Values.begin(), Values.end(), ConstantIntSortPredicate); - Values.erase(std::unique(Values.begin(), Values.end()), Values.end()); + Values.erase(llvm::unique(Values), Values.end()); // If Extra was used, we require at least two switch values to do the // transformation. A switch with one value is just a conditional branch. @@ -5235,11 +5304,11 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) { // Ensure that any debug-info records that used to occur after the Unreachable // are moved to in front of it -- otherwise they'll "dangle" at the end of // the block. 
- BB->flushTerminatorDbgValues(); + BB->flushTerminatorDbgRecords(); // Debug-info records on the unreachable inst itself should be deleted, as // below we delete everything past the final executable instruction. - UI->dropDbgValues(); + UI->dropDbgRecords(); // If there are any instructions immediately before the unreachable that can // be removed, do so. @@ -5257,9 +5326,9 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) { // block will be the unwind edges of Invoke/CatchSwitch/CleanupReturn, // and we can therefore guarantee this block will be erased. - // If we're deleting this, we're deleting any subsequent dbg.values, so - // delete DPValue records of variable information. - BBI->dropDbgValues(); + // If we're deleting this, we're deleting any subsequent debug info, so + // delete DbgRecords. + BBI->dropDbgRecords(); // Delete this instruction (any uses are guaranteed to be dead) BBI->replaceAllUsesWith(PoisonValue::get(BBI->getType())); @@ -5275,8 +5344,7 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) { std::vector<DominatorTree::UpdateType> Updates; SmallSetVector<BasicBlock *, 8> Preds(pred_begin(BB), pred_end(BB)); - for (unsigned i = 0, e = Preds.size(); i != e; ++i) { - auto *Predecessor = Preds[i]; + for (BasicBlock *Predecessor : Preds) { Instruction *TI = Predecessor->getTerminator(); IRBuilder<> Builder(TI); if (auto *BI = dyn_cast<BranchInst>(TI)) { @@ -5284,7 +5352,7 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) { // or a degenerate conditional branch with matching destinations. if (all_of(BI->successors(), [BB](auto *Successor) { return Successor == BB; })) { - new UnreachableInst(TI->getContext(), TI); + new UnreachableInst(TI->getContext(), TI->getIterator()); TI->eraseFromParent(); Changed = true; } else { @@ -5383,7 +5451,7 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) { removeUnwindEdge(EHPred, DTU); } // The catchswitch is no longer reachable. 
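A recurring mechanical change throughout this file's diff is that insertion positions are now passed as iterators (for example CSI->getIterator() just below) rather than as raw Instruction pointers; the iterator form also carries the position relative to any debug records attached at that point, in line with the RemoveDIs-related changes elsewhere in the patch. A minimal sketch of the idiom, assuming I is an existing instruction:

    // Old style: insert before I, position given as an Instruction*.
    //   auto *UI = new UnreachableInst(I->getContext(), I);
    // New style: position given as a BasicBlock::iterator.
    auto *UI = new UnreachableInst(I->getContext(), I->getIterator());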
- new UnreachableInst(CSI->getContext(), CSI); + new UnreachableInst(CSI->getContext(), CSI->getIterator()); CSI->eraseFromParent(); Changed = true; } @@ -5393,7 +5461,7 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) { "Expected to always have an unwind to BB."); if (DTU) Updates.push_back({DominatorTree::Delete, Predecessor, BB}); - new UnreachableInst(TI->getContext(), TI); + new UnreachableInst(TI->getContext(), TI->getIterator()); TI->eraseFromParent(); Changed = true; } @@ -5423,11 +5491,13 @@ static bool CasesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) { } static void createUnreachableSwitchDefault(SwitchInst *Switch, - DomTreeUpdater *DTU) { + DomTreeUpdater *DTU, + bool RemoveOrigDefaultBlock = true) { LLVM_DEBUG(dbgs() << "SimplifyCFG: switch default is dead.\n"); auto *BB = Switch->getParent(); auto *OrigDefaultBlock = Switch->getDefaultDest(); - OrigDefaultBlock->removePredecessor(BB); + if (RemoveOrigDefaultBlock) + OrigDefaultBlock->removePredecessor(BB); BasicBlock *NewDefaultBlock = BasicBlock::Create( BB->getContext(), BB->getName() + ".unreachabledefault", BB->getParent(), OrigDefaultBlock); @@ -5436,7 +5506,8 @@ static void createUnreachableSwitchDefault(SwitchInst *Switch, if (DTU) { SmallVector<DominatorTree::UpdateType, 2> Updates; Updates.push_back({DominatorTree::Insert, BB, &*NewDefaultBlock}); - if (!is_contained(successors(BB), OrigDefaultBlock)) + if (RemoveOrigDefaultBlock && + !is_contained(successors(BB), OrigDefaultBlock)) Updates.push_back({DominatorTree::Delete, BB, &*OrigDefaultBlock}); DTU->applyUpdates(Updates); } @@ -5536,7 +5607,7 @@ bool SimplifyCFGOpt::TurnSwitchRangeIntoICmp(SwitchInst *SI, TrueWeight /= 2; FalseWeight /= 2; } - setBranchWeights(NewBI, TrueWeight, FalseWeight); + setBranchWeights(NewBI, TrueWeight, FalseWeight, /*IsExpected=*/false); } } @@ -5618,10 +5689,33 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, Known.getBitWidth() - (Known.Zero | Known.One).popcount(); assert(NumUnknownBits <= Known.getBitWidth()); if (HasDefault && DeadCases.empty() && - NumUnknownBits < 64 /* avoid overflow */ && - SI->getNumCases() == (1ULL << NumUnknownBits)) { - createUnreachableSwitchDefault(SI, DTU); - return true; + NumUnknownBits < 64 /* avoid overflow */) { + uint64_t AllNumCases = 1ULL << NumUnknownBits; + if (SI->getNumCases() == AllNumCases) { + createUnreachableSwitchDefault(SI, DTU); + return true; + } + // When only one case value is missing, replace default with that case. + // Eliminating the default branch will provide more opportunities for + // optimization, such as lookup tables. 
+ if (SI->getNumCases() == AllNumCases - 1) { + assert(NumUnknownBits > 1 && "Should be canonicalized to a branch"); + IntegerType *CondTy = cast<IntegerType>(Cond->getType()); + if (CondTy->getIntegerBitWidth() > 64 || + !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth())) + return false; + + uint64_t MissingCaseVal = 0; + for (const auto &Case : SI->cases()) + MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue(); + auto *MissingCase = + cast<ConstantInt>(ConstantInt::get(Cond->getType(), MissingCaseVal)); + SwitchInstProfUpdateWrapper SIW(*SI); + SIW.addCase(MissingCase, SI->getDefaultDest(), SIW.getSuccessorWeight(0)); + createUnreachableSwitchDefault(SI, DTU, /*RemoveOrigDefaultBlock*/ false); + SIW.setSuccessorWeight(0, 0); + return true; + } } if (DeadCases.empty()) @@ -5728,7 +5822,8 @@ static bool ForwardSwitchConditionToPHI(SwitchInst *SI) { for (auto &ForwardingNode : ForwardingNodes) { PHINode *Phi = ForwardingNode.first; SmallVectorImpl<int> &Indexes = ForwardingNode.second; - if (Indexes.size() < 2) + // Check if it helps to fold PHI. + if (Indexes.size() < 2 && !llvm::is_contained(Phi->incoming_values(), SI->getCondition())) continue; for (int Index : Indexes) @@ -6502,16 +6597,17 @@ static void reuseTableCompare( Constant *FalseConst = ConstantInt::getFalse(RangeCmp->getType()); // Check if the compare with the default value is constant true or false. - Constant *DefaultConst = ConstantExpr::getICmp(CmpInst->getPredicate(), - DefaultValue, CmpOp1, true); + const DataLayout &DL = PhiBlock->getDataLayout(); + Constant *DefaultConst = ConstantFoldCompareInstOperands( + CmpInst->getPredicate(), DefaultValue, CmpOp1, DL); if (DefaultConst != TrueConst && DefaultConst != FalseConst) return; // Check if the compare with the case values is distinct from the default // compare result. for (auto ValuePair : Values) { - Constant *CaseConst = ConstantExpr::getICmp(CmpInst->getPredicate(), - ValuePair.second, CmpOp1, true); + Constant *CaseConst = ConstantFoldCompareInstOperands( + CmpInst->getPredicate(), ValuePair.second, CmpOp1, DL); if (!CaseConst || CaseConst == DefaultConst || (CaseConst != TrueConst && CaseConst != FalseConst)) return; @@ -6535,7 +6631,7 @@ static void reuseTableCompare( // The compare yields the same result, just inverted. We can replace it. Value *InvertedTableCmp = BinaryOperator::CreateXor( RangeCmp, ConstantInt::get(RangeCmp->getType(), 1), "inverted.cmp", - RangeCheckBranch); + RangeCheckBranch->getIterator()); CmpInst->replaceAllUsesWith(InvertedTableCmp); ++NumTableCmpReuses; } @@ -6638,8 +6734,25 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, TableSize = (MaxCaseVal->getValue() - MinCaseVal->getValue()).getLimitedValue() + 1; + // If the default destination is unreachable, or if the lookup table covers + // all values of the conditional variable, branch directly to the lookup table + // BB. Otherwise, check that the condition is within the case range. + bool DefaultIsReachable = !SI->defaultDestUndefined(); + bool TableHasHoles = (NumResults < TableSize); - bool NeedMask = (TableHasHoles && !HasDefaultResults); + + // If the table has holes but the default destination doesn't produce any + // constant results, the lookup table entries corresponding to the holes will + // contain undefined values. 
+ bool AllHolesAreUndefined = TableHasHoles && !HasDefaultResults; + + // If the default destination doesn't produce a constant result but is still + // reachable, and the lookup table has holes, we need to use a mask to + // determine if the current index should load from the lookup table or jump + // to the default case. + // The mask is unnecessary if the table has holes but the default destination + // is unreachable, as in that case the holes must also be unreachable. + bool NeedMask = AllHolesAreUndefined && DefaultIsReachable; if (NeedMask) { // As an extra penalty for the validity test we require more cases. if (SI->getNumCases() < 4) // FIXME: Find best threshold value (benchmark). @@ -6661,12 +6774,6 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, "It is impossible for a switch to have more entries than the max " "representable value of its input integer type's size."); - // If the default destination is unreachable, or if the lookup table covers - // all values of the conditional variable, branch directly to the lookup table - // BB. Otherwise, check that the condition is within the case range. - bool DefaultIsReachable = - !isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg()); - // Create the BB that does the lookups. Module &Mod = *CommonDest->getParent()->getParent(); BasicBlock *LookupBB = BasicBlock::Create( @@ -6790,8 +6897,9 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, for (PHINode *PHI : PHIs) { const ResultListTy &ResultList = ResultLists[PHI]; - // If using a bitmask, use any value to fill the lookup table holes. - Constant *DV = NeedMask ? ResultLists[PHI][0].second : DefaultResults[PHI]; + // Use any value to fill the lookup table holes. + Constant *DV = + AllHolesAreUndefined ? ResultLists[PHI][0].second : DefaultResults[PHI]; StringRef FuncName = Fn->getName(); SwitchLookupTable Table(Mod, TableSize, TableIndexOffset, ResultList, DV, DL, FuncName); @@ -7088,14 +7196,14 @@ bool SimplifyCFGOpt::simplifyIndirectBr(IndirectBrInst *IBI) { if (IBI->getNumDestinations() == 0) { // If the indirectbr has no successors, change it to unreachable. - new UnreachableInst(IBI->getContext(), IBI); + new UnreachableInst(IBI->getContext(), IBI->getIterator()); EraseTerminatorAndDCECond(IBI); return true; } if (IBI->getNumDestinations() == 1) { // If the indirectbr has one successor, change it to a direct branch. - BranchInst::Create(IBI->getDestination(0), IBI); + BranchInst::Create(IBI->getDestination(0), IBI->getIterator()); EraseTerminatorAndDCECond(IBI); return true; } @@ -7257,6 +7365,95 @@ static BasicBlock *allPredecessorsComeFromSameSource(BasicBlock *BB) { return PredPred; } +/// Fold the following pattern: +/// bb0: +/// br i1 %cond1, label %bb1, label %bb2 +/// bb1: +/// br i1 %cond2, label %bb3, label %bb4 +/// bb2: +/// br i1 %cond2, label %bb4, label %bb3 +/// bb3: +/// ... +/// bb4: +/// ... +/// into +/// bb0: +/// %cond = xor i1 %cond1, %cond2 +/// br i1 %cond, label %bb4, label %bb3 +/// bb3: +/// ... +/// bb4: +/// ... +/// NOTE: %cond2 always dominates the terminator of bb0. 
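A source-level sketch of the same fold (A and B are placeholder functions, not taken from the patch): because both inner blocks branch on the same %cond2 with swapped targets, the final destination depends only on whether the two conditions differ, which is exactly the xor built below.

    int A();
    int B();
    int nested(bool c1, bool c2) {
      if (c1)
        return c2 ? A() : B();   // bb1: br %cond2, label %bb3, label %bb4
      else
        return c2 ? B() : A();   // bb2: br %cond2, label %bb4, label %bb3
    }
    // After the fold this is equivalent to:
    //   return (c1 ^ c2) ? B() : A();   // br (xor %cond1, %cond2), %bb4, %bb3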
+static bool mergeNestedCondBranch(BranchInst *BI, DomTreeUpdater *DTU) { + BasicBlock *BB = BI->getParent(); + BasicBlock *BB1 = BI->getSuccessor(0); + BasicBlock *BB2 = BI->getSuccessor(1); + auto IsSimpleSuccessor = [BB](BasicBlock *Succ, BranchInst *&SuccBI) { + if (Succ == BB) + return false; + if (&Succ->front() != Succ->getTerminator()) + return false; + SuccBI = dyn_cast<BranchInst>(Succ->getTerminator()); + if (!SuccBI || !SuccBI->isConditional()) + return false; + BasicBlock *Succ1 = SuccBI->getSuccessor(0); + BasicBlock *Succ2 = SuccBI->getSuccessor(1); + return Succ1 != Succ && Succ2 != Succ && Succ1 != BB && Succ2 != BB && + !isa<PHINode>(Succ1->front()) && !isa<PHINode>(Succ2->front()); + }; + BranchInst *BB1BI, *BB2BI; + if (!IsSimpleSuccessor(BB1, BB1BI) || !IsSimpleSuccessor(BB2, BB2BI)) + return false; + + if (BB1BI->getCondition() != BB2BI->getCondition() || + BB1BI->getSuccessor(0) != BB2BI->getSuccessor(1) || + BB1BI->getSuccessor(1) != BB2BI->getSuccessor(0)) + return false; + + BasicBlock *BB3 = BB1BI->getSuccessor(0); + BasicBlock *BB4 = BB1BI->getSuccessor(1); + IRBuilder<> Builder(BI); + BI->setCondition( + Builder.CreateXor(BI->getCondition(), BB1BI->getCondition())); + BB1->removePredecessor(BB); + BI->setSuccessor(0, BB4); + BB2->removePredecessor(BB); + BI->setSuccessor(1, BB3); + if (DTU) { + SmallVector<DominatorTree::UpdateType, 4> Updates; + Updates.push_back({DominatorTree::Delete, BB, BB1}); + Updates.push_back({DominatorTree::Insert, BB, BB4}); + Updates.push_back({DominatorTree::Delete, BB, BB2}); + Updates.push_back({DominatorTree::Insert, BB, BB3}); + + DTU->applyUpdates(Updates); + } + bool HasWeight = false; + uint64_t BBTWeight, BBFWeight; + if (extractBranchWeights(*BI, BBTWeight, BBFWeight)) + HasWeight = true; + else + BBTWeight = BBFWeight = 1; + uint64_t BB1TWeight, BB1FWeight; + if (extractBranchWeights(*BB1BI, BB1TWeight, BB1FWeight)) + HasWeight = true; + else + BB1TWeight = BB1FWeight = 1; + uint64_t BB2TWeight, BB2FWeight; + if (extractBranchWeights(*BB2BI, BB2TWeight, BB2FWeight)) + HasWeight = true; + else + BB2TWeight = BB2FWeight = 1; + if (HasWeight) { + uint64_t Weights[2] = {BBTWeight * BB1FWeight + BBFWeight * BB2TWeight, + BBTWeight * BB1TWeight + BBFWeight * BB2FWeight}; + FitWeights(Weights); + setBranchWeights(BI, Weights[0], Weights[1], /*IsExpected=*/false); + } + return true; +} + bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { assert( !isa<ConstantInt>(BI->getCondition()) && @@ -7364,6 +7561,10 @@ bool SimplifyCFGOpt::simplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (mergeConditionalStores(PBI, BI, DTU, DL, TTI)) return requestResimplify(); + // Look for nested conditional branches. + if (mergeNestedCondBranch(BI, DTU)) + return requestResimplify(); + return false; } @@ -7377,11 +7578,31 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I, bool PtrValu return false; if (C->isNullValue() || isa<UndefValue>(C)) { - // Only look at the first use, avoid hurting compile time with long uselists - auto *Use = cast<Instruction>(*I->user_begin()); + // Only look at the first use we can handle, avoid hurting compile time with + // long uselists + auto FindUse = llvm::find_if(I->users(), [](auto *U) { + auto *Use = cast<Instruction>(U); + // Change this list when we want to add new instructions. 
+ switch (Use->getOpcode()) { + default: + return false; + case Instruction::GetElementPtr: + case Instruction::Ret: + case Instruction::BitCast: + case Instruction::Load: + case Instruction::Store: + case Instruction::Call: + case Instruction::CallBr: + case Instruction::Invoke: + return true; + } + }); + if (FindUse == I->user_end()) + return false; + auto *Use = cast<Instruction>(*FindUse); // Bail out if Use is not in the same BB as I or Use == I or Use comes - // before I in the block. The latter two can be the case if Use is a PHI - // node. + // before I in the block. The latter two can be the case if Use is a + // PHI node. if (Use->getParent() != I->getParent() || Use == I || Use->comesBefore(I)) return false; @@ -7397,11 +7618,34 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I, bool PtrValu // Look through GEPs. A load from a GEP derived from NULL is still undefined if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Use)) if (GEP->getPointerOperand() == I) { - if (!GEP->isInBounds() || !GEP->hasAllZeroIndices()) + // The current base address is null, there are four cases to consider: + // getelementptr (TY, null, 0) -> null + // getelementptr (TY, null, not zero) -> may be modified + // getelementptr inbounds (TY, null, 0) -> null + // getelementptr inbounds (TY, null, not zero) -> poison iff null is + // undefined? + if (!GEP->hasAllZeroIndices() && + (!GEP->isInBounds() || + NullPointerIsDefined(GEP->getFunction(), + GEP->getPointerAddressSpace()))) PtrValueMayBeModified = true; return passingValueIsAlwaysUndefined(V, GEP, PtrValueMayBeModified); } + // Look through return. + if (ReturnInst *Ret = dyn_cast<ReturnInst>(Use)) { + bool HasNoUndefAttr = + Ret->getFunction()->hasRetAttribute(Attribute::NoUndef); + // Return undefined to a noundef return value is undefined. + if (isa<UndefValue>(C) && HasNoUndefAttr) + return true; + // Return null to a nonnull+noundef return value is undefined. + if (C->isNullValue() && HasNoUndefAttr && + Ret->getFunction()->hasRetAttribute(Attribute::NonNull)) { + return !PtrValueMayBeModified; + } + } + // Look through bitcasts. if (BitCastInst *BC = dyn_cast<BitCastInst>(Use)) return passingValueIsAlwaysUndefined(V, BC, PtrValueMayBeModified); @@ -7419,6 +7663,13 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I, bool PtrValu SI->getPointerAddressSpace())) && SI->getPointerOperand() == I; + // llvm.assume(false/undef) always triggers immediate UB. + if (auto *Assume = dyn_cast<AssumeInst>(Use)) { + // Ignore assume operand bundles. + if (I == Assume->getArgOperand(0)) + return true; + } + if (auto *CB = dyn_cast<CallBase>(Use)) { if (C->isNullValue() && NullPointerIsDefined(CB->getFunction())) return false; @@ -7568,7 +7819,8 @@ bool SimplifyCFGOpt::simplifyOnce(BasicBlock *BB) { // eliminate it, do so now. 
if (auto *PN = dyn_cast<PHINode>(BB->begin())) if (PN->getNumIncomingValues() == 2) - if (FoldTwoEntryPHINode(PN, TTI, DTU, DL)) + if (FoldTwoEntryPHINode(PN, TTI, DTU, DL, + Options.SpeculateUnpredictables)) return true; } @@ -7616,7 +7868,7 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { bool llvm::simplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, DomTreeUpdater *DTU, const SimplifyCFGOptions &Options, ArrayRef<WeakVH> LoopHeaders) { - return SimplifyCFGOpt(TTI, DTU, BB->getModule()->getDataLayout(), LoopHeaders, + return SimplifyCFGOpt(TTI, DTU, BB->getDataLayout(), LoopHeaders, Options) .run(BB); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp index 1b142f14d811..0b4a75e0bc52 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -25,6 +25,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" using namespace llvm; @@ -59,6 +60,7 @@ namespace { SmallVectorImpl<WeakTrackingVH> &DeadInsts; bool Changed = false; + bool RunUnswitching = false; public: SimplifyIndvar(Loop *Loop, ScalarEvolution *SE, DominatorTree *DT, @@ -71,12 +73,18 @@ namespace { } bool hasChanged() const { return Changed; } + bool runUnswitching() const { return RunUnswitching; } /// Iteratively perform simplification on a worklist of users of the /// specified induction variable. This is the top-level driver that applies /// all simplifications to users of an IV. void simplifyUsers(PHINode *CurrIV, IVVisitor *V = nullptr); + void pushIVUsers(Instruction *Def, + SmallPtrSet<Instruction *, 16> &Simplified, + SmallVectorImpl<std::pair<Instruction *, Instruction *>> + &SimpleIVUsers); + Value *foldIVUser(Instruction *UseInst, Instruction *IVOperand); bool eliminateIdentitySCEV(Instruction *UseInst, Instruction *IVOperand); @@ -232,6 +240,7 @@ bool SimplifyIndvar::makeIVComparisonInvariant(ICmpInst *ICmp, ICmp->setPredicate(InvariantPredicate); ICmp->setOperand(0, NewLHS); ICmp->setOperand(1, NewRHS); + RunUnswitching = true; return true; } @@ -300,9 +309,10 @@ bool SimplifyIndvar::eliminateSDiv(BinaryOperator *SDiv) { if (SE->isKnownNonNegative(N) && SE->isKnownNonNegative(D)) { auto *UDiv = BinaryOperator::Create( BinaryOperator::UDiv, SDiv->getOperand(0), SDiv->getOperand(1), - SDiv->getName() + ".udiv", SDiv); + SDiv->getName() + ".udiv", SDiv->getIterator()); UDiv->setIsExact(SDiv->isExact()); SDiv->replaceAllUsesWith(UDiv); + UDiv->setDebugLoc(SDiv->getDebugLoc()); LLVM_DEBUG(dbgs() << "INDVARS: Simplified sdiv: " << *SDiv << '\n'); ++NumSimplifiedSDiv; Changed = true; @@ -317,8 +327,9 @@ bool SimplifyIndvar::eliminateSDiv(BinaryOperator *SDiv) { void SimplifyIndvar::replaceSRemWithURem(BinaryOperator *Rem) { auto *N = Rem->getOperand(0), *D = Rem->getOperand(1); auto *URem = BinaryOperator::Create(BinaryOperator::URem, N, D, - Rem->getName() + ".urem", Rem); + Rem->getName() + ".urem", Rem->getIterator()); Rem->replaceAllUsesWith(URem); + URem->setDebugLoc(Rem->getDebugLoc()); LLVM_DEBUG(dbgs() << "INDVARS: Simplified srem: " << *Rem << '\n'); ++NumSimplifiedSRem; Changed = true; @@ -338,10 +349,11 @@ void SimplifyIndvar::replaceRemWithNumerator(BinaryOperator *Rem) { void SimplifyIndvar::replaceRemWithNumeratorOrZero(BinaryOperator *Rem) { auto *T = 
Rem->getType(); auto *N = Rem->getOperand(0), *D = Rem->getOperand(1); - ICmpInst *ICmp = new ICmpInst(Rem, ICmpInst::ICMP_EQ, N, D); + ICmpInst *ICmp = new ICmpInst(Rem->getIterator(), ICmpInst::ICMP_EQ, N, D); SelectInst *Sel = - SelectInst::Create(ICmp, ConstantInt::get(T, 0), N, "iv.rem", Rem); + SelectInst::Create(ICmp, ConstantInt::get(T, 0), N, "iv.rem", Rem->getIterator()); Rem->replaceAllUsesWith(Sel); + Sel->setDebugLoc(Rem->getDebugLoc()); LLVM_DEBUG(dbgs() << "INDVARS: Simplified rem: " << *Rem << '\n'); ++NumElimRem; Changed = true; @@ -410,7 +422,7 @@ bool SimplifyIndvar::eliminateOverflowIntrinsic(WithOverflowInst *WO) { // intrinsic as well. BinaryOperator *NewResult = BinaryOperator::Create( - WO->getBinaryOp(), WO->getLHS(), WO->getRHS(), "", WO); + WO->getBinaryOp(), WO->getLHS(), WO->getRHS(), "", WO->getIterator()); if (WO->isSigned()) NewResult->setHasNoSignedWrap(true); @@ -426,6 +438,7 @@ bool SimplifyIndvar::eliminateOverflowIntrinsic(WithOverflowInst *WO) { else { assert(EVI->getIndices()[0] == 0 && "Only two possibilities!"); EVI->replaceAllUsesWith(NewResult); + NewResult->setDebugLoc(EVI->getDebugLoc()); } ToDelete.push_back(EVI); } @@ -448,13 +461,14 @@ bool SimplifyIndvar::eliminateSaturatingIntrinsic(SaturatingInst *SI) { return false; BinaryOperator *BO = BinaryOperator::Create( - SI->getBinaryOp(), SI->getLHS(), SI->getRHS(), SI->getName(), SI); + SI->getBinaryOp(), SI->getLHS(), SI->getRHS(), SI->getName(), SI->getIterator()); if (SI->isSigned()) BO->setHasNoSignedWrap(); else BO->setHasNoUnsignedWrap(); SI->replaceAllUsesWith(BO); + BO->setDebugLoc(SI->getDebugLoc()); DeadInsts.emplace_back(SI); Changed = true; return true; @@ -643,10 +657,21 @@ bool SimplifyIndvar::replaceIVUserWithLoopInvariant(Instruction *I) { } auto *Invariant = Rewriter.expandCodeFor(S, I->getType(), IP); + bool NeedToEmitLCSSAPhis = false; + if (!LI->replacementPreservesLCSSAForm(I, Invariant)) + NeedToEmitLCSSAPhis = true; I->replaceAllUsesWith(Invariant); LLVM_DEBUG(dbgs() << "INDVARS: Replace IV user: " << *I << " with loop invariant: " << *S << '\n'); + + if (NeedToEmitLCSSAPhis) { + SmallVector<Instruction *, 1> NeedsLCSSAPhis; + NeedsLCSSAPhis.push_back(cast<Instruction>(Invariant)); + formLCSSAForInstructions(NeedsLCSSAPhis, *DT, *LI, SE); + LLVM_DEBUG(dbgs() << " INDVARS: Replacement breaks LCSSA form" + << " inserting LCSSA Phis" << '\n'); + } ++NumFoldedUser; Changed = true; DeadInsts.emplace_back(I); @@ -753,7 +778,7 @@ bool SimplifyIndvar::eliminateIdentitySCEV(Instruction *UseInst, return false; for (Instruction *I : DropPoisonGeneratingInsts) - I->dropPoisonGeneratingFlagsAndMetadata(); + I->dropPoisonGeneratingAnnotations(); } LLVM_DEBUG(dbgs() << "INDVARS: Eliminated identity: " << *UseInst << '\n'); @@ -824,11 +849,9 @@ bool SimplifyIndvar::strengthenRightShift(BinaryOperator *BO, } /// Add all uses of Def to the current IV's worklist. -static void pushIVUsers( - Instruction *Def, Loop *L, - SmallPtrSet<Instruction*,16> &Simplified, - SmallVectorImpl< std::pair<Instruction*,Instruction*> > &SimpleIVUsers) { - +void SimplifyIndvar::pushIVUsers( + Instruction *Def, SmallPtrSet<Instruction *, 16> &Simplified, + SmallVectorImpl<std::pair<Instruction *, Instruction *>> &SimpleIVUsers) { for (User *U : Def->users()) { Instruction *UI = cast<Instruction>(U); @@ -898,7 +921,7 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { // Push users of the current LoopPhi. In rare cases, pushIVUsers may be // called multiple times for the same LoopPhi. 
This is the proper thing to
 // do for loop header phis that use each other.
- pushIVUsers(CurrIV, L, Simplified, SimpleIVUsers);
+ pushIVUsers(CurrIV, Simplified, SimpleIVUsers);

 while (!SimpleIVUsers.empty()) {
 std::pair<Instruction*, Instruction*> UseOper =
@@ -945,7 +968,7 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) {
 continue;

 if (eliminateIVUser(UseInst, IVOperand)) {
- pushIVUsers(IVOperand, L, Simplified, SimpleIVUsers);
+ pushIVUsers(IVOperand, Simplified, SimpleIVUsers);
 continue;
 }
@@ -953,14 +976,14 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) {
 if (strengthenBinaryOp(BO, IVOperand)) {
 // re-queue uses of the now modified binary operator and fall
 // through to the checks that remain.
- pushIVUsers(IVOperand, L, Simplified, SimpleIVUsers);
+ pushIVUsers(IVOperand, Simplified, SimpleIVUsers);
 }
 }

 // Try to use integer induction for FPToSI of float induction directly.
 if (replaceFloatIVWithIntegerIV(UseInst)) {
 // Re-queue the potentially new direct uses of IVOperand.
- pushIVUsers(IVOperand, L, Simplified, SimpleIVUsers);
+ pushIVUsers(IVOperand, Simplified, SimpleIVUsers);
 continue;
 }
@@ -970,7 +993,7 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) {
 continue;
 }
 if (isSimpleIVUser(UseInst, L, SE)) {
- pushIVUsers(UseInst, L, Simplified, SimpleIVUsers);
+ pushIVUsers(UseInst, Simplified, SimpleIVUsers);
 }
 }
 }
@@ -981,14 +1004,18 @@ void IVVisitor::anchor() { }

 /// Simplify instructions that use this induction variable
 /// by using ScalarEvolution to analyze the IV's recurrence.
-bool simplifyUsersOfIV(PHINode *CurrIV, ScalarEvolution *SE, DominatorTree *DT,
- LoopInfo *LI, const TargetTransformInfo *TTI,
- SmallVectorImpl<WeakTrackingVH> &Dead,
- SCEVExpander &Rewriter, IVVisitor *V) {
+/// Returns a pair where the first entry indicates that the function makes
+/// changes and the second entry indicates that it introduced new opportunities
+/// for loop unswitching.
+std::pair<bool, bool> simplifyUsersOfIV(PHINode *CurrIV, ScalarEvolution *SE,
+ DominatorTree *DT, LoopInfo *LI,
+ const TargetTransformInfo *TTI,
+ SmallVectorImpl<WeakTrackingVH> &Dead,
+ SCEVExpander &Rewriter, IVVisitor *V) {
 SimplifyIndvar SIV(LI->getLoopFor(CurrIV->getParent()), SE, DT, LI, TTI,
 Rewriter, Dead);
 SIV.simplifyUsers(CurrIV, V);
- return SIV.hasChanged();
+ return {SIV.hasChanged(), SIV.runUnswitching()};
 }

 /// Simplify users of induction variables within this
@@ -1002,8 +1029,9 @@ bool simplifyLoopIVs(Loop *L, ScalarEvolution *SE, DominatorTree *DT,
 #endif
 bool Changed = false;
 for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) {
- Changed |=
+ const auto &[C, _] =
 simplifyUsersOfIV(cast<PHINode>(I), SE, DT, LI, TTI, Dead, Rewriter);
+ Changed |= C;
 }
 return Changed;
 }
@@ -1131,7 +1159,9 @@ protected:
 const SCEV *getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS,
 unsigned OpCode) const;

- Instruction *widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter);
+ Instruction *widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter,
+ PHINode *OrigPhi, PHINode *WidePhi);
+ void truncateIVUse(NarrowIVDefUse DU);

 bool widenLoopCompare(NarrowIVDefUse DU);
 bool widenWithVariantUse(NarrowIVDefUse DU);
@@ -1368,6 +1398,77 @@ const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS,
 };
 }

+namespace {
+
+// Represents an interesting integer binary operation for
+// getExtendedOperandRecurrence. This may be a shl that is being treated as a
+// multiply or an 'or disjoint' that is being treated as 'add nsw nuw'. 
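Both rewrites this struct models can be checked with ordinary unsigned arithmetic; a small self-contained sketch (the function is illustrative, not part of the patch):

    #include <cassert>
    #include <cstdint>

    void binop_equivalences(uint32_t X, uint32_t A, uint32_t B) {
      // A left shift by an in-range constant C equals a multiply by (1 << C),
      // which is why a shl can be fed to the add/sub/mul handling below.
      assert((X << 3) == X * (uint32_t(1) << 3));
      // An 'or' whose operands share no set bits ("or disjoint") produces no
      // carries, so it computes the same value as 'add' and cannot wrap.
      if ((A & B) == 0)
        assert((A | B) == A + B);
    }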
+struct BinaryOp { + unsigned Opcode; + std::array<Value *, 2> Operands; + bool IsNSW = false; + bool IsNUW = false; + + explicit BinaryOp(Instruction *Op) + : Opcode(Op->getOpcode()), + Operands({Op->getOperand(0), Op->getOperand(1)}) { + if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) { + IsNSW = OBO->hasNoSignedWrap(); + IsNUW = OBO->hasNoUnsignedWrap(); + } + } + + explicit BinaryOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, + bool IsNSW = false, bool IsNUW = false) + : Opcode(Opcode), Operands({LHS, RHS}), IsNSW(IsNSW), IsNUW(IsNUW) {} +}; + +} // end anonymous namespace + +static std::optional<BinaryOp> matchBinaryOp(Instruction *Op) { + switch (Op->getOpcode()) { + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + return BinaryOp(Op); + case Instruction::Or: { + // Convert or disjoint into add nuw nsw. + if (cast<PossiblyDisjointInst>(Op)->isDisjoint()) + return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1), + /*IsNSW=*/true, /*IsNUW=*/true); + break; + } + case Instruction::Shl: { + if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) { + unsigned BitWidth = cast<IntegerType>(SA->getType())->getBitWidth(); + + // If the shift count is not less than the bitwidth, the result of + // the shift is undefined. Don't try to analyze it, because the + // resolution chosen here may differ from the resolution chosen in + // other parts of the compiler. + if (SA->getValue().ult(BitWidth)) { + // We can safely preserve the nuw flag in all cases. It's also safe to + // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation + // requires special handling. It can be preserved as long as we're not + // left shifting by bitwidth - 1. + bool IsNUW = Op->hasNoUnsignedWrap(); + bool IsNSW = Op->hasNoSignedWrap() && + (IsNUW || SA->getValue().ult(BitWidth - 1)); + + ConstantInt *X = + ConstantInt::get(Op->getContext(), + APInt::getOneBitSet(BitWidth, SA->getZExtValue())); + return BinaryOp(Instruction::Mul, Op->getOperand(0), X, IsNSW, IsNUW); + } + } + + break; + } + } + + return std::nullopt; +} + /// No-wrap operations can transfer sign extension of their result to their /// operands. Generate the SCEV value for the widened operation without /// actually modifying the IR yet. If the expression after extending the @@ -1375,24 +1476,22 @@ const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS, /// extension used. WidenIV::WidenedRecTy WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) { - // Handle the common case of add<nsw/nuw> - const unsigned OpCode = DU.NarrowUse->getOpcode(); - // Only Add/Sub/Mul instructions supported yet. - if (OpCode != Instruction::Add && OpCode != Instruction::Sub && - OpCode != Instruction::Mul) + auto Op = matchBinaryOp(DU.NarrowUse); + if (!Op) return {nullptr, ExtendKind::Unknown}; + assert((Op->Opcode == Instruction::Add || Op->Opcode == Instruction::Sub || + Op->Opcode == Instruction::Mul) && + "Unexpected opcode"); + // One operand (NarrowDef) has already been extended to WideDef. Now determine // if extending the other will lead to a recurrence. - const unsigned ExtendOperIdx = - DU.NarrowUse->getOperand(0) == DU.NarrowDef ? 1 : 0; - assert(DU.NarrowUse->getOperand(1-ExtendOperIdx) == DU.NarrowDef && "bad DU"); + const unsigned ExtendOperIdx = Op->Operands[0] == DU.NarrowDef ? 
1 : 0; + assert(Op->Operands[1 - ExtendOperIdx] == DU.NarrowDef && "bad DU"); - const OverflowingBinaryOperator *OBO = - cast<OverflowingBinaryOperator>(DU.NarrowUse); ExtendKind ExtKind = getExtendKind(DU.NarrowDef); - if (!(ExtKind == ExtendKind::Sign && OBO->hasNoSignedWrap()) && - !(ExtKind == ExtendKind::Zero && OBO->hasNoUnsignedWrap())) { + if (!(ExtKind == ExtendKind::Sign && Op->IsNSW) && + !(ExtKind == ExtendKind::Zero && Op->IsNUW)) { ExtKind = ExtendKind::Unknown; // For a non-negative NarrowDef, we can choose either type of @@ -1400,16 +1499,15 @@ WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) { // (see above), and we only hit this code if we need to check // the opposite case. if (DU.NeverNegative) { - if (OBO->hasNoSignedWrap()) { + if (Op->IsNSW) { ExtKind = ExtendKind::Sign; - } else if (OBO->hasNoUnsignedWrap()) { + } else if (Op->IsNUW) { ExtKind = ExtendKind::Zero; } } } - const SCEV *ExtendOperExpr = - SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx)); + const SCEV *ExtendOperExpr = SE->getSCEV(Op->Operands[ExtendOperIdx]); if (ExtKind == ExtendKind::Sign) ExtendOperExpr = SE->getSignExtendExpr(ExtendOperExpr, WideType); else if (ExtKind == ExtendKind::Zero) @@ -1430,7 +1528,7 @@ WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) { if (ExtendOperIdx == 0) std::swap(lhs, rhs); const SCEVAddRecExpr *AddRec = - dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode(lhs, rhs, OpCode)); + dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode(lhs, rhs, Op->Opcode)); if (!AddRec || AddRec->getLoop() != L) return {nullptr, ExtendKind::Unknown}; @@ -1480,15 +1578,18 @@ WidenIV::WidenedRecTy WidenIV::getWideRecurrence(WidenIV::NarrowIVDefUse DU) { /// This IV user cannot be widened. Replace this use of the original narrow IV /// with a truncation of the new wide IV to isolate and eliminate the narrow IV. -static void truncateIVUse(WidenIV::NarrowIVDefUse DU, DominatorTree *DT, - LoopInfo *LI) { +void WidenIV::truncateIVUse(NarrowIVDefUse DU) { auto *InsertPt = getInsertPointForUses(DU.NarrowUse, DU.NarrowDef, DT, LI); if (!InsertPt) return; LLVM_DEBUG(dbgs() << "INDVARS: Truncate IV " << *DU.WideDef << " for user " << *DU.NarrowUse << "\n"); + ExtendKind ExtKind = getExtendKind(DU.NarrowDef); IRBuilder<> Builder(InsertPt); - Value *Trunc = Builder.CreateTrunc(DU.WideDef, DU.NarrowDef->getType()); + Value *Trunc = + Builder.CreateTrunc(DU.WideDef, DU.NarrowDef->getType(), "", + DU.NeverNegative || ExtKind == ExtendKind::Zero, + DU.NeverNegative || ExtKind == ExtendKind::Sign); DU.NarrowUse->replaceUsesOfWith(DU.NarrowDef, Trunc); } @@ -1731,10 +1832,19 @@ bool WidenIV::widenWithVariantUse(WidenIV::NarrowIVDefUse DU) { /// Determine whether an individual user of the narrow IV can be widened. If so, /// return the wide clone of the user. -Instruction *WidenIV::widenIVUse(WidenIV::NarrowIVDefUse DU, SCEVExpander &Rewriter) { +Instruction *WidenIV::widenIVUse(WidenIV::NarrowIVDefUse DU, + SCEVExpander &Rewriter, PHINode *OrigPhi, + PHINode *WidePhi) { assert(ExtendKindMap.count(DU.NarrowDef) && "Should already know the kind of extension used to widen NarrowDef"); + // This narrow use can be widened by a sext if it's non-negative or its narrow + // def was widened by a sext. Same for zext. + bool CanWidenBySExt = + DU.NeverNegative || getExtendKind(DU.NarrowDef) == ExtendKind::Sign; + bool CanWidenByZExt = + DU.NeverNegative || getExtendKind(DU.NarrowDef) == ExtendKind::Zero; + // Stop traversing the def-use chain at inner-loop phis or post-loop phis. 
if (PHINode *UsePhi = dyn_cast<PHINode>(DU.NarrowUse)) { if (LI->getLoopFor(UsePhi->getParent()) != L) { @@ -1742,7 +1852,7 @@ Instruction *WidenIV::widenIVUse(WidenIV::NarrowIVDefUse DU, SCEVExpander &Rewri // After SimplifyCFG most loop exit targets have a single predecessor. // Otherwise fall back to a truncate within the loop. if (UsePhi->getNumOperands() != 1) - truncateIVUse(DU, DT, LI); + truncateIVUse(DU); else { // Widening the PHI requires us to insert a trunc. The logical place // for this trunc is in the same BB as the PHI. This is not possible if @@ -1752,11 +1862,12 @@ Instruction *WidenIV::widenIVUse(WidenIV::NarrowIVDefUse DU, SCEVExpander &Rewri PHINode *WidePhi = PHINode::Create(DU.WideDef->getType(), 1, UsePhi->getName() + ".wide", - UsePhi); + UsePhi->getIterator()); WidePhi->addIncoming(DU.WideDef, UsePhi->getIncomingBlock(0)); BasicBlock *WidePhiBB = WidePhi->getParent(); IRBuilder<> Builder(WidePhiBB, WidePhiBB->getFirstInsertionPt()); - Value *Trunc = Builder.CreateTrunc(WidePhi, DU.NarrowDef->getType()); + Value *Trunc = Builder.CreateTrunc(WidePhi, DU.NarrowDef->getType(), "", + CanWidenByZExt, CanWidenBySExt); UsePhi->replaceAllUsesWith(Trunc); DeadInsts.emplace_back(UsePhi); LLVM_DEBUG(dbgs() << "INDVARS: Widen lcssa phi " << *UsePhi << " to " @@ -1766,18 +1877,9 @@ Instruction *WidenIV::widenIVUse(WidenIV::NarrowIVDefUse DU, SCEVExpander &Rewri } } - // This narrow use can be widened by a sext if it's non-negative or its narrow - // def was widened by a sext. Same for zext. - auto canWidenBySExt = [&]() { - return DU.NeverNegative || getExtendKind(DU.NarrowDef) == ExtendKind::Sign; - }; - auto canWidenByZExt = [&]() { - return DU.NeverNegative || getExtendKind(DU.NarrowDef) == ExtendKind::Zero; - }; - // Our raison d'etre! Eliminate sign and zero extension. - if ((match(DU.NarrowUse, m_SExtLike(m_Value())) && canWidenBySExt()) || - (isa<ZExtInst>(DU.NarrowUse) && canWidenByZExt())) { + if ((match(DU.NarrowUse, m_SExtLike(m_Value())) && CanWidenBySExt) || + (isa<ZExtInst>(DU.NarrowUse) && CanWidenByZExt)) { Value *NewDef = DU.WideDef; if (DU.NarrowUse->getType() != WideType) { unsigned CastWidth = SE->getTypeSizeInBits(DU.NarrowUse->getType()); @@ -1785,7 +1887,8 @@ Instruction *WidenIV::widenIVUse(WidenIV::NarrowIVDefUse DU, SCEVExpander &Rewri if (CastWidth < IVWidth) { // The cast isn't as wide as the IV, so insert a Trunc. IRBuilder<> Builder(DU.NarrowUse); - NewDef = Builder.CreateTrunc(DU.WideDef, DU.NarrowUse->getType()); + NewDef = Builder.CreateTrunc(DU.WideDef, DU.NarrowUse->getType(), "", + CanWidenByZExt, CanWidenBySExt); } else { // A wider extend was hidden behind a narrower one. This may induce @@ -1825,11 +1928,24 @@ Instruction *WidenIV::widenIVUse(WidenIV::NarrowIVDefUse DU, SCEVExpander &Rewri if (!WideAddRec.first) return nullptr; - // Reuse the IV increment that SCEVExpander created as long as it dominates - // NarrowUse. + auto CanUseWideInc = [&]() { + if (!WideInc) + return false; + // Reuse the IV increment that SCEVExpander created. Recompute flags, + // unless the flags for both increments agree and it is safe to use the + // ones from the original inc. In that case, the new use of the wide + // increment won't be more poisonous. 
+ bool NeedToRecomputeFlags = + !SCEVExpander::canReuseFlagsFromOriginalIVInc( + OrigPhi, WidePhi, DU.NarrowUse, WideInc) || + DU.NarrowUse->hasNoUnsignedWrap() != WideInc->hasNoUnsignedWrap() || + DU.NarrowUse->hasNoSignedWrap() != WideInc->hasNoSignedWrap(); + return WideAddRec.first == WideIncExpr && + Rewriter.hoistIVInc(WideInc, DU.NarrowUse, NeedToRecomputeFlags); + }; + Instruction *WideUse = nullptr; - if (WideAddRec.first == WideIncExpr && - Rewriter.hoistIVInc(WideInc, DU.NarrowUse)) + if (CanUseWideInc()) WideUse = WideInc; else { WideUse = cloneIVUser(DU, WideAddRec.first); @@ -1877,7 +1993,7 @@ Instruction *WidenIV::widenIVUse(WidenIV::NarrowIVDefUse DU, SCEVExpander &Rewri // This user does not evaluate to a recurrence after widening, so don't // follow it. Instead insert a Trunc to kill off the original use, // eventually isolating the original narrow IV so it can be removed. - truncateIVUse(DU, DT, LI); + truncateIVUse(DU); return nullptr; } @@ -1985,7 +2101,26 @@ PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) { // increment to the new (widened) increment. auto *OrigInc = cast<Instruction>(OrigPhi->getIncomingValueForBlock(LatchBlock)); + WideInc->setDebugLoc(OrigInc->getDebugLoc()); + // We are replacing a narrow IV increment with a wider IV increment. If + // the original (narrow) increment did not wrap, the wider increment one + // should not wrap either. Set the flags to be the union of both wide + // increment and original increment; this ensures we preserve flags SCEV + // could infer for the wider increment. Limit this only to cases where + // both increments directly increment the corresponding PHI nodes and have + // the same opcode. It is not safe to re-use the flags from the original + // increment, if it is more complex and SCEV expansion may have yielded a + // more simplified wider increment. + if (SCEVExpander::canReuseFlagsFromOriginalIVInc(OrigPhi, WidePhi, + OrigInc, WideInc) && + isa<OverflowingBinaryOperator>(OrigInc) && + isa<OverflowingBinaryOperator>(WideInc)) { + WideInc->setHasNoUnsignedWrap(WideInc->hasNoUnsignedWrap() || + OrigInc->hasNoUnsignedWrap()); + WideInc->setHasNoSignedWrap(WideInc->hasNoSignedWrap() || + OrigInc->hasNoSignedWrap()); + } } } @@ -2003,7 +2138,7 @@ PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) { // Process a def-use edge. This may replace the use, so don't hold a // use_iterator across it. - Instruction *WideUse = widenIVUse(DU, Rewriter); + Instruction *WideUse = widenIVUse(DU, Rewriter, OrigPhi, WidePhi); // Follow all def-use edges from the previous narrow use. 
if (WideUse) diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 52eef9ab58a4..89c8c5bf0895 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -27,6 +27,7 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" @@ -52,6 +53,10 @@ static cl::opt<bool> static cl::opt<bool> OptimizeHotColdNew("optimize-hot-cold-new", cl::Hidden, cl::init(false), cl::desc("Enable hot/cold operator new library calls")); +static cl::opt<bool> OptimizeExistingHotColdNew( + "optimize-existing-hot-cold-new", cl::Hidden, cl::init(false), + cl::desc( + "Enable optimization of existing hot/cold operator new library calls")); namespace { @@ -81,6 +86,10 @@ struct HotColdHintParser : public cl::parser<unsigned> { static cl::opt<unsigned, false, HotColdHintParser> ColdNewHintValue( "cold-new-hint-value", cl::Hidden, cl::init(1), cl::desc("Value to pass to hot/cold operator new for cold allocation")); +static cl::opt<unsigned, false, HotColdHintParser> + NotColdNewHintValue("notcold-new-hint-value", cl::Hidden, cl::init(128), + cl::desc("Value to pass to hot/cold operator new for " + "notcold (warm) allocation")); static cl::opt<unsigned, false, HotColdHintParser> HotNewHintValue( "hot-new-hint-value", cl::Hidden, cl::init(254), cl::desc("Value to pass to hot/cold operator new for hot allocation")); @@ -1722,45 +1731,122 @@ Value *LibCallSimplifier::optimizeNew(CallInst *CI, IRBuilderBase &B, uint8_t HotCold; if (CI->getAttributes().getFnAttr("memprof").getValueAsString() == "cold") HotCold = ColdNewHintValue; + else if (CI->getAttributes().getFnAttr("memprof").getValueAsString() == + "notcold") + HotCold = NotColdNewHintValue; else if (CI->getAttributes().getFnAttr("memprof").getValueAsString() == "hot") HotCold = HotNewHintValue; else return nullptr; + // For calls that already pass a hot/cold hint, only update the hint if + // directed by OptimizeExistingHotColdNew. For other calls to new, add a hint + // if cold or hot, and leave as-is for default handling if "notcold" aka warm. + // Note that in cases where we decide it is "notcold", it might be slightly + // better to replace the hinted call with a non hinted call, to avoid the + // extra paramter and the if condition check of the hint value in the + // allocator. This can be considered in the future. 
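// ---- Editorial sketch (not part of the upstream patch) ----
// The hint selection above, restated as a tiny standalone function. The
// numeric values are the defaults of the -cold-new-hint-value,
// -notcold-new-hint-value and -hot-new-hint-value options shown in this diff
// (1, 128 and 254); the function name and the tri-state return are
// illustrative only, standing in for the early "return nullptr" when the call
// carries no memprof attribute.
#include <cstdint>
#include <optional>
#include <string_view>
std::optional<uint8_t> selectHotColdHint(std::string_view MemProfAttr) {
  if (MemProfAttr == "cold")
    return uint8_t{1};   // ColdNewHintValue default
  if (MemProfAttr == "notcold")
    return uint8_t{128}; // NotColdNewHintValue default (warm)
  if (MemProfAttr == "hot")
    return uint8_t{254}; // HotNewHintValue default
  return std::nullopt;   // no hint: leave the call alone
}
// ---- End editorial sketch ----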
switch (Func) { + case LibFunc_Znwm12__hot_cold_t: + if (OptimizeExistingHotColdNew) + return emitHotColdNew(CI->getArgOperand(0), B, TLI, + LibFunc_Znwm12__hot_cold_t, HotCold); + break; case LibFunc_Znwm: - return emitHotColdNew(CI->getArgOperand(0), B, TLI, - LibFunc_Znwm12__hot_cold_t, HotCold); + if (HotCold != NotColdNewHintValue) + return emitHotColdNew(CI->getArgOperand(0), B, TLI, + LibFunc_Znwm12__hot_cold_t, HotCold); + break; + case LibFunc_Znam12__hot_cold_t: + if (OptimizeExistingHotColdNew) + return emitHotColdNew(CI->getArgOperand(0), B, TLI, + LibFunc_Znam12__hot_cold_t, HotCold); + break; case LibFunc_Znam: - return emitHotColdNew(CI->getArgOperand(0), B, TLI, - LibFunc_Znam12__hot_cold_t, HotCold); + if (HotCold != NotColdNewHintValue) + return emitHotColdNew(CI->getArgOperand(0), B, TLI, + LibFunc_Znam12__hot_cold_t, HotCold); + break; + case LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t: + if (OptimizeExistingHotColdNew) + return emitHotColdNewNoThrow( + CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, + LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t, HotCold); + break; case LibFunc_ZnwmRKSt9nothrow_t: - return emitHotColdNewNoThrow(CI->getArgOperand(0), CI->getArgOperand(1), B, - TLI, LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t, - HotCold); + if (HotCold != NotColdNewHintValue) + return emitHotColdNewNoThrow( + CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, + LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t, HotCold); + break; + case LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t: + if (OptimizeExistingHotColdNew) + return emitHotColdNewNoThrow( + CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, + LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t, HotCold); + break; case LibFunc_ZnamRKSt9nothrow_t: - return emitHotColdNewNoThrow(CI->getArgOperand(0), CI->getArgOperand(1), B, - TLI, LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t, - HotCold); + if (HotCold != NotColdNewHintValue) + return emitHotColdNewNoThrow( + CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, + LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t, HotCold); + break; + case LibFunc_ZnwmSt11align_val_t12__hot_cold_t: + if (OptimizeExistingHotColdNew) + return emitHotColdNewAligned( + CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, + LibFunc_ZnwmSt11align_val_t12__hot_cold_t, HotCold); + break; case LibFunc_ZnwmSt11align_val_t: - return emitHotColdNewAligned(CI->getArgOperand(0), CI->getArgOperand(1), B, - TLI, LibFunc_ZnwmSt11align_val_t12__hot_cold_t, - HotCold); + if (HotCold != NotColdNewHintValue) + return emitHotColdNewAligned( + CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, + LibFunc_ZnwmSt11align_val_t12__hot_cold_t, HotCold); + break; + case LibFunc_ZnamSt11align_val_t12__hot_cold_t: + if (OptimizeExistingHotColdNew) + return emitHotColdNewAligned( + CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, + LibFunc_ZnamSt11align_val_t12__hot_cold_t, HotCold); + break; case LibFunc_ZnamSt11align_val_t: - return emitHotColdNewAligned(CI->getArgOperand(0), CI->getArgOperand(1), B, - TLI, LibFunc_ZnamSt11align_val_t12__hot_cold_t, - HotCold); + if (HotCold != NotColdNewHintValue) + return emitHotColdNewAligned( + CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, + LibFunc_ZnamSt11align_val_t12__hot_cold_t, HotCold); + break; + case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t: + if (OptimizeExistingHotColdNew) + return emitHotColdNewAlignedNoThrow( + CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, + TLI, LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t, + HotCold); + break; case 
LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t: - return emitHotColdNewAlignedNoThrow( - CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, - TLI, LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t, HotCold); + if (HotCold != NotColdNewHintValue) + return emitHotColdNewAlignedNoThrow( + CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, + TLI, LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t, + HotCold); + break; + case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t: + if (OptimizeExistingHotColdNew) + return emitHotColdNewAlignedNoThrow( + CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, + TLI, LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t, + HotCold); + break; case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t: - return emitHotColdNewAlignedNoThrow( - CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, - TLI, LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t, HotCold); + if (HotCold != NotColdNewHintValue) + return emitHotColdNewAlignedNoThrow( + CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, + TLI, LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t, + HotCold); + break; default: return nullptr; } + return nullptr; } //===----------------------------------------------------------------------===// @@ -1770,14 +1856,7 @@ Value *LibCallSimplifier::optimizeNew(CallInst *CI, IRBuilderBase &B, // Replace a libcall \p CI with a call to intrinsic \p IID static Value *replaceUnaryCall(CallInst *CI, IRBuilderBase &B, Intrinsic::ID IID) { - // Propagate fast-math flags from the existing call to the new call. - IRBuilderBase::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - Module *M = CI->getModule(); - Value *V = CI->getArgOperand(0); - Function *F = Intrinsic::getDeclaration(M, IID, CI->getType()); - CallInst *NewCall = B.CreateCall(F, V); + CallInst *NewCall = B.CreateUnaryIntrinsic(IID, CI->getArgOperand(0), CI); NewCall->takeName(CI); return copyFlags(*CI, NewCall); } @@ -1880,69 +1959,59 @@ static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilderBase &B, // cabs(z) -> sqrt((creal(z)*creal(z)) + (cimag(z)*cimag(z))) Value *LibCallSimplifier::optimizeCAbs(CallInst *CI, IRBuilderBase &B) { - if (!CI->isFast()) - return nullptr; - - // Propagate fast-math flags from the existing call to new instructions. 
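// ---- Editorial sketch (not part of the upstream patch) ----
// Numeric sanity check of the identities optimizeCAbs() relies on, including
// the new special case in the hunk just below where one component is a
// constant zero. The real transform is additionally guarded by fast-math
// flags; the sample values are chosen so the math is exact.
#include <cassert>
#include <cmath>
#include <complex>
int main() {
  double Re = 3.0, Im = -4.0;
  // cabs(z) == sqrt(creal(z)*creal(z) + cimag(z)*cimag(z))
  assert(std::abs(std::abs(std::complex<double>(Re, Im)) -
                  std::sqrt(Re * Re + Im * Im)) < 1e-12);
  // With a zero real (or imaginary) part, cabs collapses to fabs of the
  // other component.
  assert(std::abs(std::abs(std::complex<double>(0.0, Im)) - std::fabs(Im)) <
         1e-12);
  assert(std::abs(std::abs(std::complex<double>(Re, 0.0)) - std::fabs(Re)) <
         1e-12);
  return 0;
}
// ---- End editorial sketch ----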
- IRBuilderBase::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - Value *Real, *Imag; + if (CI->arg_size() == 1) { + + if (!CI->isFast()) + return nullptr; + Value *Op = CI->getArgOperand(0); assert(Op->getType()->isArrayTy() && "Unexpected signature for cabs!"); + Real = B.CreateExtractValue(Op, 0, "real"); Imag = B.CreateExtractValue(Op, 1, "imag"); + } else { assert(CI->arg_size() == 2 && "Unexpected signature for cabs!"); + Real = CI->getArgOperand(0); Imag = CI->getArgOperand(1); - } - Value *RealReal = B.CreateFMul(Real, Real); - Value *ImagImag = B.CreateFMul(Imag, Imag); + // if real or imaginary part is zero, simplify to abs(cimag(z)) + // or abs(creal(z)) + Value *AbsOp = nullptr; + if (ConstantFP *ConstReal = dyn_cast<ConstantFP>(Real)) { + if (ConstReal->isZero()) + AbsOp = Imag; - Function *FSqrt = Intrinsic::getDeclaration(CI->getModule(), Intrinsic::sqrt, - CI->getType()); - return copyFlags( - *CI, B.CreateCall(FSqrt, B.CreateFAdd(RealReal, ImagImag), "cabs")); -} + } else if (ConstantFP *ConstImag = dyn_cast<ConstantFP>(Imag)) { + if (ConstImag->isZero()) + AbsOp = Real; + } -static Value *optimizeTrigReflections(CallInst *Call, LibFunc Func, - IRBuilderBase &B) { - if (!isa<FPMathOperator>(Call)) - return nullptr; + if (AbsOp) { + IRBuilderBase::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); - IRBuilderBase::FastMathFlagGuard Guard(B); - B.setFastMathFlags(Call->getFastMathFlags()); + return copyFlags( + *CI, B.CreateUnaryIntrinsic(Intrinsic::fabs, AbsOp, nullptr, "cabs")); + } - // TODO: Can this be shared to also handle LLVM intrinsics? - Value *X; - switch (Func) { - case LibFunc_sin: - case LibFunc_sinf: - case LibFunc_sinl: - case LibFunc_tan: - case LibFunc_tanf: - case LibFunc_tanl: - // sin(-X) --> -sin(X) - // tan(-X) --> -tan(X) - if (match(Call->getArgOperand(0), m_OneUse(m_FNeg(m_Value(X))))) - return B.CreateFNeg( - copyFlags(*Call, B.CreateCall(Call->getCalledFunction(), X))); - break; - case LibFunc_cos: - case LibFunc_cosf: - case LibFunc_cosl: - // cos(-X) --> cos(X) - if (match(Call->getArgOperand(0), m_FNeg(m_Value(X)))) - return copyFlags(*Call, - B.CreateCall(Call->getCalledFunction(), X, "cos")); - break; - default: - break; + if (!CI->isFast()) + return nullptr; } - return nullptr; + + // Propagate fast-math flags from the existing call to new instructions. + IRBuilderBase::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + + Value *RealReal = B.CreateFMul(Real, Real); + Value *ImagImag = B.CreateFMul(Imag, Imag); + + return copyFlags(*CI, B.CreateUnaryIntrinsic(Intrinsic::sqrt, + B.CreateFAdd(RealReal, ImagImag), + nullptr, "cabs")); } // Return a properly extended integer (DstWidth bits wide) if the operation is @@ -1952,11 +2021,12 @@ static Value *getIntToFPVal(Value *I2F, IRBuilderBase &B, unsigned DstWidth) { Value *Op = cast<Instruction>(I2F)->getOperand(0); // Make sure that the exponent fits inside an "int" of size DstWidth, // thus avoiding any range issues that FP has not. - unsigned BitWidth = Op->getType()->getPrimitiveSizeInBits(); - if (BitWidth < DstWidth || - (BitWidth == DstWidth && isa<SIToFPInst>(I2F))) - return isa<SIToFPInst>(I2F) ? B.CreateSExt(Op, B.getIntNTy(DstWidth)) - : B.CreateZExt(Op, B.getIntNTy(DstWidth)); + unsigned BitWidth = Op->getType()->getScalarSizeInBits(); + if (BitWidth < DstWidth || (BitWidth == DstWidth && isa<SIToFPInst>(I2F))) { + Type *IntTy = Op->getType()->getWithNewBitWidth(DstWidth); + return isa<SIToFPInst>(I2F) ? 
B.CreateSExt(Op, IntTy) + : B.CreateZExt(Op, IntTy); + } } return nullptr; @@ -1968,7 +2038,6 @@ static Value *getIntToFPVal(Value *I2F, IRBuilderBase &B, unsigned DstWidth) { Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { Module *M = Pow->getModule(); Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); - Module *Mod = Pow->getModule(); Type *Ty = Pow->getType(); bool Ignored; @@ -2025,11 +2094,10 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { // Create new exp{,2}() with the product as its argument. Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), Expo, "mul"); ExpFn = BaseFn->doesNotAccessMemory() - ? B.CreateCall(Intrinsic::getDeclaration(Mod, ID, Ty), - FMul, ExpName) - : emitUnaryFloatFnCall(FMul, TLI, LibFnDouble, LibFnFloat, - LibFnLongDouble, B, - BaseFn->getAttributes()); + ? B.CreateUnaryIntrinsic(ID, FMul, nullptr, ExpName) + : emitUnaryFloatFnCall(FMul, TLI, LibFnDouble, LibFnFloat, + LibFnLongDouble, B, + BaseFn->getAttributes()); // Since the new exp{,2}() is different from the original one, dead code // elimination cannot be trusted to remove it, since it may have side @@ -2043,21 +2111,34 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { // Evaluate special cases related to a constant base. const APFloat *BaseF; - if (!match(Pow->getArgOperand(0), m_APFloat(BaseF))) + if (!match(Base, m_APFloat(BaseF))) return nullptr; AttributeList NoAttrs; // Attributes are only meaningful on the original call + const bool UseIntrinsic = Pow->doesNotAccessMemory(); + // pow(2.0, itofp(x)) -> ldexp(1.0, x) - // TODO: This does not work for vectors because there is no ldexp intrinsic. - if (!Ty->isVectorTy() && match(Base, m_SpecificFP(2.0)) && + if ((UseIntrinsic || !Ty->isVectorTy()) && BaseF->isExactlyValue(2.0) && (isa<SIToFPInst>(Expo) || isa<UIToFPInst>(Expo)) && - hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { - if (Value *ExpoI = getIntToFPVal(Expo, B, TLI->getIntSize())) - return copyFlags(*Pow, - emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), ExpoI, - TLI, LibFunc_ldexp, LibFunc_ldexpf, - LibFunc_ldexpl, B, NoAttrs)); + (UseIntrinsic || + hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl))) { + + // TODO: Shouldn't really need to depend on getIntToFPVal for intrinsic. Can + // just directly use the original integer type. + if (Value *ExpoI = getIntToFPVal(Expo, B, TLI->getIntSize())) { + Constant *One = ConstantFP::get(Ty, 1.0); + + if (UseIntrinsic) { + return copyFlags(*Pow, B.CreateIntrinsic(Intrinsic::ldexp, + {Ty, ExpoI->getType()}, + {One, ExpoI}, Pow, "exp2")); + } + + return copyFlags(*Pow, emitBinaryFloatFnCall( + One, ExpoI, TLI, LibFunc_ldexp, LibFunc_ldexpf, + LibFunc_ldexpl, B, NoAttrs)); + } } // pow(2.0 ** n, x) -> exp2(n * x) @@ -2075,9 +2156,8 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { double N = NI.logBase2() * (IsReciprocal ? 
-1.0 : 1.0); Value *FMul = B.CreateFMul(Expo, ConstantFP::get(Ty, N), "mul"); if (Pow->doesNotAccessMemory()) - return copyFlags(*Pow, B.CreateCall(Intrinsic::getDeclaration( - Mod, Intrinsic::exp2, Ty), - FMul, "exp2")); + return copyFlags(*Pow, B.CreateUnaryIntrinsic(Intrinsic::exp2, FMul, + nullptr, "exp2")); else return copyFlags(*Pow, emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2, LibFunc_exp2f, @@ -2086,12 +2166,19 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { } // pow(10.0, x) -> exp10(x) - // TODO: There is no exp10() intrinsic yet, but some day there shall be one. - if (match(Base, m_SpecificFP(10.0)) && - hasFloatFn(M, TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) + if (BaseF->isExactlyValue(10.0) && + hasFloatFn(M, TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) { + + if (Pow->doesNotAccessMemory()) { + CallInst *NewExp10 = + B.CreateIntrinsic(Intrinsic::exp10, {Ty}, {Expo}, Pow, "exp10"); + return copyFlags(*Pow, NewExp10); + } + return copyFlags(*Pow, emitUnaryFloatFnCall(Expo, TLI, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l, B, NoAttrs)); + } // pow(x, y) -> exp2(log2(x) * y) if (Pow->hasApproxFunc() && Pow->hasNoNaNs() && BaseF->isFiniteNonZero() && @@ -2110,9 +2197,8 @@ Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) { if (Log) { Value *FMul = B.CreateFMul(Log, Expo, "mul"); if (Pow->doesNotAccessMemory()) - return copyFlags(*Pow, B.CreateCall(Intrinsic::getDeclaration( - Mod, Intrinsic::exp2, Ty), - FMul, "exp2")); + return copyFlags(*Pow, B.CreateUnaryIntrinsic(Intrinsic::exp2, FMul, + nullptr, "exp2")); else if (hasFloatFn(M, TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) return copyFlags(*Pow, emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2, @@ -2128,11 +2214,8 @@ static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno, Module *M, IRBuilderBase &B, const TargetLibraryInfo *TLI) { // If errno is never set, then use the intrinsic for sqrt(). - if (NoErrno) { - Function *SqrtFn = - Intrinsic::getDeclaration(M, Intrinsic::sqrt, V->getType()); - return B.CreateCall(SqrtFn, V, "sqrt"); - } + if (NoErrno) + return B.CreateUnaryIntrinsic(Intrinsic::sqrt, V, nullptr, "sqrt"); // Otherwise, use the libcall for sqrt(). if (hasFloatFn(M, TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf, @@ -2167,7 +2250,8 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B) { // pow(-Inf, 0.5) is optionally required to have a result of +Inf (not setting // errno), but sqrt(-Inf) is required by various standards to set errno. if (!Pow->doesNotAccessMemory() && !Pow->hasNoInfs() && - !isKnownNeverInfinity(Base, DL, TLI, 0, AC, Pow)) + !isKnownNeverInfinity(Base, 0, + SimplifyQuery(DL, TLI, /*DT=*/nullptr, AC, Pow))) return nullptr; Sqrt = getSqrtCall(Base, AttributeList(), Pow->doesNotAccessMemory(), Mod, B, @@ -2176,10 +2260,8 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B) { return nullptr; // Handle signed zero base by expanding to fabs(sqrt(x)). 
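// ---- Editorial sketch (not part of the upstream patch) ----
// Standalone checks of the pow() rewrites handled above, using the libm
// functions the transformed code ends up calling. The last two asserts show
// why replacePowWithSqrt() wraps the result in fabs() when signed zeros
// matter: pow(-0.0, 0.5) is +0.0, while sqrt(-0.0) is -0.0.
#include <cassert>
#include <cmath>
int main() {
  // pow(2.0, (double)k) -> ldexp(1.0, k)
  for (int K = -4; K <= 4; ++K)
    assert(std::abs(std::pow(2.0, static_cast<double>(K)) -
                    std::ldexp(1.0, K)) < 1e-12);
  // pow(2^n, x) -> exp2(n * x), here with n == 3
  double X = 1.75;
  assert(std::abs(std::pow(8.0, X) - std::exp2(3.0 * X)) < 1e-12);
  // pow(x, 0.5) -> sqrt(x) needs fabs() to preserve the sign of zero.
  assert(!std::signbit(std::pow(-0.0, 0.5)));
  assert(std::signbit(std::sqrt(-0.0)));
  return 0;
}
// ---- End editorial sketch ----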
- if (!Pow->hasNoSignedZeros()) { - Function *FAbsFn = Intrinsic::getDeclaration(Mod, Intrinsic::fabs, Ty); - Sqrt = B.CreateCall(FAbsFn, Sqrt, "abs"); - } + if (!Pow->hasNoSignedZeros()) + Sqrt = B.CreateUnaryIntrinsic(Intrinsic::fabs, Sqrt, nullptr, "abs"); Sqrt = copyFlags(*Pow, Sqrt); @@ -2203,8 +2285,7 @@ static Value *createPowWithIntegerExponent(Value *Base, Value *Expo, Module *M, IRBuilderBase &B) { Value *Args[] = {Base, Expo}; Type *Types[] = {Base->getType(), Expo->getType()}; - Function *F = Intrinsic::getDeclaration(M, Intrinsic::powi, Types); - return B.CreateCall(F, Args); + return B.CreateIntrinsic(Intrinsic::powi, Types, Args); } Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilderBase &B) { @@ -2328,24 +2409,38 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilderBase &B) { hasFloatVersion(M, Name)) Ret = optimizeUnaryDoubleFP(CI, B, TLI, true); + // If we have an llvm.exp2 intrinsic, emit the llvm.ldexp intrinsic. If we + // have the libcall, emit the libcall. + // + // TODO: In principle we should be able to just always use the intrinsic for + // any doesNotAccessMemory callsite. + + const bool UseIntrinsic = Callee->isIntrinsic(); // Bail out for vectors because the code below only expects scalars. - // TODO: This could be allowed if we had a ldexp intrinsic (D14327). Type *Ty = CI->getType(); - if (Ty->isVectorTy()) + if (!UseIntrinsic && Ty->isVectorTy()) return Ret; // exp2(sitofp(x)) -> ldexp(1.0, sext(x)) if sizeof(x) <= IntSize // exp2(uitofp(x)) -> ldexp(1.0, zext(x)) if sizeof(x) < IntSize Value *Op = CI->getArgOperand(0); if ((isa<SIToFPInst>(Op) || isa<UIToFPInst>(Op)) && - hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl)) { + (UseIntrinsic || + hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl))) { if (Value *Exp = getIntToFPVal(Op, B, TLI->getIntSize())) { + Constant *One = ConstantFP::get(Ty, 1.0); + + if (UseIntrinsic) { + return copyFlags(*CI, B.CreateIntrinsic(Intrinsic::ldexp, + {Ty, Exp->getType()}, + {One, Exp}, CI)); + } + IRBuilderBase::FastMathFlagGuard Guard(B); B.setFastMathFlags(CI->getFastMathFlags()); - return copyFlags( - *CI, emitBinaryFloatFnCall(ConstantFP::get(Ty, 1.0), Exp, TLI, - LibFunc_ldexp, LibFunc_ldexpf, - LibFunc_ldexpl, B, AttributeList())); + return copyFlags(*CI, emitBinaryFloatFnCall( + One, Exp, TLI, LibFunc_ldexp, LibFunc_ldexpf, + LibFunc_ldexpl, B, AttributeList())); } } @@ -2377,9 +2472,8 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilderBase &B) { Intrinsic::ID IID = Callee->getName().starts_with("fmin") ? Intrinsic::minnum : Intrinsic::maxnum; - Function *F = Intrinsic::getDeclaration(CI->getModule(), IID, CI->getType()); - return copyFlags( - *CI, B.CreateCall(F, {CI->getArgOperand(0), CI->getArgOperand(1)})); + return copyFlags(*CI, B.CreateBinaryIntrinsic(IID, CI->getArgOperand(0), + CI->getArgOperand(1))); } Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { @@ -2498,8 +2592,7 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { if (ArgLb == PowLb || ArgID == Intrinsic::pow || ArgID == Intrinsic::powi) { Value *LogX = Log->doesNotAccessMemory() - ? B.CreateCall(Intrinsic::getDeclaration(Mod, LogID, Ty), - Arg->getOperand(0), "log") + ? B.CreateUnaryIntrinsic(LogID, Arg->getOperand(0), nullptr, "log") : emitUnaryFloatFnCall(Arg->getOperand(0), TLI, LogNm, B, NoAttrs); Value *Y = Arg->getArgOperand(1); // Cast exponent to FP if integer. 
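// ---- Editorial sketch (not part of the upstream patch) ----
// The exp2 and log rewrites in the hunks above, checked numerically with the
// corresponding libm calls. In the compiler these folds are additionally
// gated on fast-math flags and on the exponent fitting the target's int.
#include <cassert>
#include <cmath>
int main() {
  // exp2(sitofp(k)) -> ldexp(1.0, k)
  for (int K = -8; K <= 8; ++K)
    assert(std::abs(std::exp2(static_cast<double>(K)) - std::ldexp(1.0, K)) <
           1e-12);
  // log(pow(x, y)) -> y * log(x)
  double X = 2.5, Y = 3.0;
  assert(std::abs(std::log(std::pow(X, Y)) - Y * std::log(X)) < 1e-12);
  // log2(exp2(y)) -> y
  assert(std::abs(std::log2(std::exp2(Y)) - Y) < 1e-12);
  return 0;
}
// ---- End editorial sketch ----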
@@ -2525,8 +2618,7 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { else Eul = ConstantFP::get(Log->getType(), 10.0); Value *LogE = Log->doesNotAccessMemory() - ? B.CreateCall(Intrinsic::getDeclaration(Mod, LogID, Ty), - Eul, "log") + ? B.CreateUnaryIntrinsic(LogID, Eul, nullptr, "log") : emitUnaryFloatFnCall(Eul, TLI, LogNm, B, NoAttrs); Value *MulY = B.CreateFMul(Arg->getArgOperand(0), LogE, "mul"); // Since exp() may have side effects, e.g. errno, @@ -2538,6 +2630,70 @@ Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { return Ret; } +// sqrt(exp(X)) -> exp(X * 0.5) +Value *LibCallSimplifier::mergeSqrtToExp(CallInst *CI, IRBuilderBase &B) { + if (!CI->hasAllowReassoc()) + return nullptr; + + Function *SqrtFn = CI->getCalledFunction(); + CallInst *Arg = dyn_cast<CallInst>(CI->getArgOperand(0)); + if (!Arg || !Arg->hasAllowReassoc() || !Arg->hasOneUse()) + return nullptr; + Intrinsic::ID ArgID = Arg->getIntrinsicID(); + LibFunc ArgLb = NotLibFunc; + TLI->getLibFunc(*Arg, ArgLb); + + LibFunc SqrtLb, ExpLb, Exp2Lb, Exp10Lb; + + if (TLI->getLibFunc(SqrtFn->getName(), SqrtLb)) + switch (SqrtLb) { + case LibFunc_sqrtf: + ExpLb = LibFunc_expf; + Exp2Lb = LibFunc_exp2f; + Exp10Lb = LibFunc_exp10f; + break; + case LibFunc_sqrt: + ExpLb = LibFunc_exp; + Exp2Lb = LibFunc_exp2; + Exp10Lb = LibFunc_exp10; + break; + case LibFunc_sqrtl: + ExpLb = LibFunc_expl; + Exp2Lb = LibFunc_exp2l; + Exp10Lb = LibFunc_exp10l; + break; + default: + return nullptr; + } + else if (SqrtFn->getIntrinsicID() == Intrinsic::sqrt) { + if (CI->getType()->getScalarType()->isFloatTy()) { + ExpLb = LibFunc_expf; + Exp2Lb = LibFunc_exp2f; + Exp10Lb = LibFunc_exp10f; + } else if (CI->getType()->getScalarType()->isDoubleTy()) { + ExpLb = LibFunc_exp; + Exp2Lb = LibFunc_exp2; + Exp10Lb = LibFunc_exp10; + } else + return nullptr; + } else + return nullptr; + + if (ArgLb != ExpLb && ArgLb != Exp2Lb && ArgLb != Exp10Lb && + ArgID != Intrinsic::exp && ArgID != Intrinsic::exp2) + return nullptr; + + IRBuilderBase::InsertPointGuard Guard(B); + B.SetInsertPoint(Arg); + auto *ExpOperand = Arg->getOperand(0); + auto *FMul = + B.CreateFMulFMF(ExpOperand, ConstantFP::get(ExpOperand->getType(), 0.5), + CI, "merged.sqrt"); + + Arg->setOperand(0, FMul); + return Arg; +} + Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); @@ -2550,6 +2706,9 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { Callee->getIntrinsicID() == Intrinsic::sqrt)) Ret = optimizeUnaryDoubleFP(CI, B, TLI, true); + if (Value *Opt = mergeSqrtToExp(CI, B)) + return Opt; + if (!CI->isFast()) return Ret; @@ -2593,27 +2752,29 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { // If we found a repeated factor, hoist it out of the square root and // replace it with the fabs of that factor. - Type *ArgType = I->getType(); - Function *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType); - Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs"); + Value *FabsCall = + B.CreateUnaryIntrinsic(Intrinsic::fabs, RepeatOp, nullptr, "fabs"); if (OtherOp) { // If we found a non-repeated factor, we still need to get its square // root. We then multiply that by the value that was simplified out // of the square root calculation. 
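// ---- Editorial sketch (not part of the upstream patch) ----
// Numeric checks of the two sqrt rewrites implemented above: the new
// mergeSqrtToExp() fold sqrt(exp(x)) -> exp(x * 0.5), and the existing
// repeated-factor rewrite sqrt(x*x*y) -> fabs(x) * sqrt(y). Both are only
// performed under the appropriate fast-math flags; the sample values here are
// just chosen to make the identities easy to verify.
#include <cassert>
#include <cmath>
int main() {
  double X = 1.3;
  assert(std::abs(std::sqrt(std::exp(X)) - std::exp(X * 0.5)) < 1e-12);
  double Y = 2.0, Z = -3.0;
  assert(std::abs(std::sqrt(Z * Z * Y) - std::fabs(Z) * std::sqrt(Y)) < 1e-12);
  return 0;
}
// ---- End editorial sketch ----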
- Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType); - Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt"); + Value *SqrtCall = + B.CreateUnaryIntrinsic(Intrinsic::sqrt, OtherOp, nullptr, "sqrt"); return copyFlags(*CI, B.CreateFMul(FabsCall, SqrtCall)); } return copyFlags(*CI, FabsCall); } -// TODO: Generalize to handle any trig function and its inverse. -Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) { +Value *LibCallSimplifier::optimizeTrigInversionPairs(CallInst *CI, + IRBuilderBase &B) { Module *M = CI->getModule(); Function *Callee = CI->getCalledFunction(); Value *Ret = nullptr; StringRef Name = Callee->getName(); - if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(M, Name)) + if (UnsafeFPShrink && + (Name == "tan" || Name == "atanh" || Name == "sinh" || Name == "cosh" || + Name == "asinh") && + hasFloatVersion(M, Name)) Ret = optimizeUnaryDoubleFP(CI, B, TLI, true); Value *Op1 = CI->getArgOperand(0); @@ -2626,16 +2787,34 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilderBase &B) { return Ret; // tan(atan(x)) -> x - // tanf(atanf(x)) -> x - // tanl(atanl(x)) -> x + // atanh(tanh(x)) -> x + // sinh(asinh(x)) -> x + // asinh(sinh(x)) -> x + // cosh(acosh(x)) -> x LibFunc Func; Function *F = OpC->getCalledFunction(); if (F && TLI->getLibFunc(F->getName(), Func) && - isLibFuncEmittable(M, TLI, Func) && - ((Func == LibFunc_atan && Callee->getName() == "tan") || - (Func == LibFunc_atanf && Callee->getName() == "tanf") || - (Func == LibFunc_atanl && Callee->getName() == "tanl"))) - Ret = OpC->getArgOperand(0); + isLibFuncEmittable(M, TLI, Func)) { + LibFunc inverseFunc = llvm::StringSwitch<LibFunc>(Callee->getName()) + .Case("tan", LibFunc_atan) + .Case("atanh", LibFunc_tanh) + .Case("sinh", LibFunc_asinh) + .Case("cosh", LibFunc_acosh) + .Case("tanf", LibFunc_atanf) + .Case("atanhf", LibFunc_tanhf) + .Case("sinhf", LibFunc_asinhf) + .Case("coshf", LibFunc_acoshf) + .Case("tanl", LibFunc_atanl) + .Case("atanhl", LibFunc_tanhl) + .Case("sinhl", LibFunc_asinhl) + .Case("coshl", LibFunc_acoshl) + .Case("asinh", LibFunc_sinh) + .Case("asinhf", LibFunc_sinhf) + .Case("asinhl", LibFunc_sinhl) + .Default(NumLibFuncs); // Used as error value + if (Func == inverseFunc) + Ret = OpC->getArgOperand(0); + } return Ret; } @@ -2702,6 +2881,63 @@ static bool insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg, return true; } +static Value *optimizeSymmetricCall(CallInst *CI, bool IsEven, + IRBuilderBase &B) { + Value *X; + Value *Src = CI->getArgOperand(0); + + if (match(Src, m_OneUse(m_FNeg(m_Value(X))))) { + IRBuilderBase::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + + auto *CallInst = copyFlags(*CI, B.CreateCall(CI->getCalledFunction(), {X})); + if (IsEven) { + // Even function: f(-x) = f(x) + return CallInst; + } + // Odd function: f(-x) = -f(x) + return B.CreateFNeg(CallInst); + } + + // Even function: f(abs(x)) = f(x), f(copysign(x, y)) = f(x) + if (IsEven && (match(Src, m_FAbs(m_Value(X))) || + match(Src, m_CopySign(m_Value(X), m_Value())))) { + IRBuilderBase::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + + auto *CallInst = copyFlags(*CI, B.CreateCall(CI->getCalledFunction(), {X})); + return CallInst; + } + + return nullptr; +} + +Value *LibCallSimplifier::optimizeSymmetric(CallInst *CI, LibFunc Func, + IRBuilderBase &B) { + switch (Func) { + case LibFunc_cos: + case LibFunc_cosf: + case LibFunc_cosl: + return optimizeSymmetricCall(CI, /*IsEven*/ true, B); + 
+ case LibFunc_sin: + case LibFunc_sinf: + case LibFunc_sinl: + + case LibFunc_tan: + case LibFunc_tanf: + case LibFunc_tanl: + + case LibFunc_erf: + case LibFunc_erff: + case LibFunc_erfl: + return optimizeSymmetricCall(CI, /*IsEven*/ false, B); + + default: + return nullptr; + } +} + Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B) { // Make sure the prototype is as expected, otherwise the rest of the // function is probably invalid and likely to abort. @@ -2792,9 +3028,8 @@ Value *LibCallSimplifier::optimizeFFS(CallInst *CI, IRBuilderBase &B) { Type *RetType = CI->getType(); Value *Op = CI->getArgOperand(0); Type *ArgType = Op->getType(); - Function *F = Intrinsic::getDeclaration(CI->getCalledFunction()->getParent(), - Intrinsic::cttz, ArgType); - Value *V = B.CreateCall(F, {Op, B.getTrue()}, "cttz"); + Value *V = B.CreateIntrinsic(Intrinsic::cttz, {ArgType}, {Op, B.getTrue()}, + nullptr, "cttz"); V = B.CreateAdd(V, ConstantInt::get(V->getType(), 1)); V = B.CreateIntCast(V, RetType, false); @@ -2807,9 +3042,8 @@ Value *LibCallSimplifier::optimizeFls(CallInst *CI, IRBuilderBase &B) { // fls{,l,ll}(x) -> (int)(sizeInBits(x) - llvm.ctlz(x, false)) Value *Op = CI->getArgOperand(0); Type *ArgType = Op->getType(); - Function *F = Intrinsic::getDeclaration(CI->getCalledFunction()->getParent(), - Intrinsic::ctlz, ArgType); - Value *V = B.CreateCall(F, {Op, B.getFalse()}, "ctlz"); + Value *V = B.CreateIntrinsic(Intrinsic::ctlz, {ArgType}, {Op, B.getFalse()}, + nullptr, "ctlz"); V = B.CreateSub(ConstantInt::get(V->getType(), ArgType->getIntegerBitWidth()), V); return B.CreateIntCast(V, CI->getType(), false); @@ -3566,6 +3800,14 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, case LibFunc_ZnamRKSt9nothrow_t: case LibFunc_ZnamSt11align_val_t: case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t: + case LibFunc_Znwm12__hot_cold_t: + case LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t: + case LibFunc_ZnwmSt11align_val_t12__hot_cold_t: + case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t: + case LibFunc_Znam12__hot_cold_t: + case LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t: + case LibFunc_ZnamSt11align_val_t12__hot_cold_t: + case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t: return optimizeNew(CI, Builder, Func); default: break; @@ -3583,7 +3825,7 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, if (CI->isStrictFP()) return nullptr; - if (Value *V = optimizeTrigReflections(CI, Func, Builder)) + if (Value *V = optimizeSymmetric(CI, Func, Builder)) return V; switch (Func) { @@ -3628,7 +3870,19 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_tan: case LibFunc_tanf: case LibFunc_tanl: - return optimizeTan(CI, Builder); + case LibFunc_sinh: + case LibFunc_sinhf: + case LibFunc_sinhl: + case LibFunc_asinh: + case LibFunc_asinhf: + case LibFunc_asinhl: + case LibFunc_cosh: + case LibFunc_coshf: + case LibFunc_coshl: + case LibFunc_atanh: + case LibFunc_atanhf: + case LibFunc_atanhl: + return optimizeTrigInversionPairs(CI, Builder); case LibFunc_ceil: return replaceUnaryCall(CI, Builder, Intrinsic::ceil); case LibFunc_floor: @@ -3646,17 +3900,13 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_acos: case LibFunc_acosh: case LibFunc_asin: - case LibFunc_asinh: case LibFunc_atan: - case LibFunc_atanh: case LibFunc_cbrt: - case LibFunc_cosh: case LibFunc_exp: case LibFunc_exp10: case LibFunc_expm1: case LibFunc_cos: case LibFunc_sin: - case 
LibFunc_sinh: case LibFunc_tanh: if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName())) return optimizeUnaryDoubleFP(CI, Builder, TLI, true); @@ -3938,7 +4188,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeMemPCpyChk(CallInst *CI, IRBuilderBase &B) { - const DataLayout &DL = CI->getModule()->getDataLayout(); + const DataLayout &DL = CI->getDataLayout(); if (isFortifiedCallFoldable(CI, 3, 2)) if (Value *Call = emitMemPCpy(CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, DL, TLI)) { @@ -3950,7 +4200,7 @@ Value *FortifiedLibCallSimplifier::optimizeMemPCpyChk(CallInst *CI, Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI, IRBuilderBase &B, LibFunc Func) { - const DataLayout &DL = CI->getModule()->getDataLayout(); + const DataLayout &DL = CI->getDataLayout(); Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1), *ObjSize = CI->getArgOperand(2); @@ -3998,7 +4248,7 @@ Value *FortifiedLibCallSimplifier::optimizeStrLenChk(CallInst *CI, IRBuilderBase &B) { if (isFortifiedCallFoldable(CI, 1, std::nullopt, 0)) return copyFlags(*CI, emitStrLen(CI->getArgOperand(0), B, - CI->getModule()->getDataLayout(), TLI)); + CI->getDataLayout(), TLI)); return nullptr; } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SplitModule.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SplitModule.cpp index 9c39c26d8b7a..a30afadf0365 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SplitModule.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SplitModule.cpp @@ -55,6 +55,18 @@ using ClusterMapType = EquivalenceClasses<const GlobalValue *>; using ComdatMembersType = DenseMap<const Comdat *, const GlobalValue *>; using ClusterIDMapType = DenseMap<const GlobalValue *, unsigned>; +bool compareClusters(const std::pair<unsigned, unsigned> &A, + const std::pair<unsigned, unsigned> &B) { + if (A.second || B.second) + return A.second > B.second; + return A.first > B.first; +} + +using BalancingQueueType = + std::priority_queue<std::pair<unsigned, unsigned>, + std::vector<std::pair<unsigned, unsigned>>, + decltype(compareClusters) *>; + } // end anonymous namespace static void addNonConstUser(ClusterMapType &GVtoClusterMap, @@ -105,7 +117,8 @@ static void findPartitions(Module &M, ClusterIDMapType &ClusterIDMap, // At this point module should have the proper mix of globals and locals. // As we attempt to partition this module, we must not change any // locals to globals. - LLVM_DEBUG(dbgs() << "Partition module with (" << M.size() << ")functions\n"); + LLVM_DEBUG(dbgs() << "Partition module with (" << M.size() + << ") functions\n"); ClusterMapType GVtoClusterMap; ComdatMembersType ComdatMembers; @@ -153,21 +166,10 @@ static void findPartitions(Module &M, ClusterIDMapType &ClusterIDMap, // Assigned all GVs to merged clusters while balancing number of objects in // each. - auto CompareClusters = [](const std::pair<unsigned, unsigned> &a, - const std::pair<unsigned, unsigned> &b) { - if (a.second || b.second) - return a.second > b.second; - else - return a.first > b.first; - }; - - std::priority_queue<std::pair<unsigned, unsigned>, - std::vector<std::pair<unsigned, unsigned>>, - decltype(CompareClusters)> - BalancinQueue(CompareClusters); + BalancingQueueType BalancingQueue(compareClusters); // Pre-populate priority queue with N slot blanks. 
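// ---- Editorial sketch (not part of the upstream patch) ----
// A standalone demonstration of the balancing queue introduced above: with
// compareClusters as the comparator, top() is always the least-populated
// cluster (pairs are <cluster ID, number of entries>; among still-empty
// clusters the lowest ID wins). findPartitions pops that cluster, assigns a
// set of globals to it, and pushes it back with its new size.
#include <cassert>
#include <queue>
#include <utility>
#include <vector>
static bool compareClusters(const std::pair<unsigned, unsigned> &A,
                            const std::pair<unsigned, unsigned> &B) {
  if (A.second || B.second)
    return A.second > B.second;
  return A.first > B.first;
}
int main() {
  std::priority_queue<std::pair<unsigned, unsigned>,
                      std::vector<std::pair<unsigned, unsigned>>,
                      decltype(compareClusters) *>
      Queue(compareClusters);
  Queue.push({0, 3});
  Queue.push({1, 0});
  Queue.push({2, 5});
  auto Top = Queue.top();
  Queue.pop();
  assert(Top.first == 1);         // the empty cluster is chosen first
  Queue.push({Top.first, 4});     // it just received four globals
  assert(Queue.top().first == 0); // next pick: the least-loaded cluster
  return 0;
}
// ---- End editorial sketch ----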
for (unsigned i = 0; i < N; ++i) - BalancinQueue.push(std::make_pair(i, 0)); + BalancingQueue.push(std::make_pair(i, 0)); using SortType = std::pair<unsigned, ClusterMapType::iterator>; @@ -177,11 +179,13 @@ static void findPartitions(Module &M, ClusterIDMapType &ClusterIDMap, // To guarantee determinism, we have to sort SCC according to size. // When size is the same, use leader's name. for (ClusterMapType::iterator I = GVtoClusterMap.begin(), - E = GVtoClusterMap.end(); I != E; ++I) + E = GVtoClusterMap.end(); + I != E; ++I) if (I->isLeader()) Sets.push_back( std::make_pair(std::distance(GVtoClusterMap.member_begin(I), - GVtoClusterMap.member_end()), I)); + GVtoClusterMap.member_end()), + I)); llvm::sort(Sets, [](const SortType &a, const SortType &b) { if (a.first == b.first) @@ -191,9 +195,9 @@ static void findPartitions(Module &M, ClusterIDMapType &ClusterIDMap, }); for (auto &I : Sets) { - unsigned CurrentClusterID = BalancinQueue.top().first; - unsigned CurrentClusterSize = BalancinQueue.top().second; - BalancinQueue.pop(); + unsigned CurrentClusterID = BalancingQueue.top().first; + unsigned CurrentClusterSize = BalancingQueue.top().second; + BalancingQueue.pop(); LLVM_DEBUG(dbgs() << "Root[" << CurrentClusterID << "] cluster_size(" << I.first << ") ----> " << I.second->getData()->getName() @@ -211,7 +215,7 @@ static void findPartitions(Module &M, ClusterIDMapType &ClusterIDMap, CurrentClusterSize++; } // Add this set size to the number of entries in this cluster. - BalancinQueue.push(std::make_pair(CurrentClusterID, CurrentClusterSize)); + BalancingQueue.push(std::make_pair(CurrentClusterID, CurrentClusterSize)); } } @@ -251,7 +255,7 @@ static bool isInPartition(const GlobalValue *GV, unsigned I, unsigned N) { void llvm::SplitModule( Module &M, unsigned N, function_ref<void(std::unique_ptr<Module> MPart)> ModuleCallback, - bool PreserveLocals) { + bool PreserveLocals, bool RoundRobin) { if (!PreserveLocals) { for (Function &F : M) externalize(&F); @@ -268,6 +272,41 @@ void llvm::SplitModule( ClusterIDMapType ClusterIDMap; findPartitions(M, ClusterIDMap, N); + // Find functions not mapped to modules in ClusterIDMap and count functions + // per module. Map unmapped functions using round-robin so that they skip + // being distributed by isInPartition() based on function name hashes below. + // This provides better uniformity of distribution of functions to modules + // in some cases - for example when the number of functions equals to N. + if (RoundRobin) { + DenseMap<unsigned, unsigned> ModuleFunctionCount; + SmallVector<const GlobalValue *> UnmappedFunctions; + for (const auto &F : M.functions()) { + if (F.isDeclaration() || + F.getLinkage() != GlobalValue::LinkageTypes::ExternalLinkage) + continue; + auto It = ClusterIDMap.find(&F); + if (It == ClusterIDMap.end()) + UnmappedFunctions.push_back(&F); + else + ++ModuleFunctionCount[It->second]; + } + BalancingQueueType BalancingQueue(compareClusters); + for (unsigned I = 0; I < N; ++I) { + if (auto It = ModuleFunctionCount.find(I); + It != ModuleFunctionCount.end()) + BalancingQueue.push(*It); + else + BalancingQueue.push({I, 0}); + } + for (const auto *const F : UnmappedFunctions) { + const unsigned I = BalancingQueue.top().first; + const unsigned Count = BalancingQueue.top().second; + BalancingQueue.pop(); + ClusterIDMap.insert({F, I}); + BalancingQueue.push({I, Count + 1}); + } + } + // FIXME: We should be able to reuse M as the last partition instead of // cloning it. 
Note that the callers at the moment expect the module to // be preserved, so will need some adjustments as well. @@ -275,8 +314,8 @@ void llvm::SplitModule( ValueToValueMapTy VMap; std::unique_ptr<Module> MPart( CloneModule(M, VMap, [&](const GlobalValue *GV) { - if (ClusterIDMap.count(GV)) - return (ClusterIDMap[GV] == I); + if (auto It = ClusterIDMap.find(GV); It != ClusterIDMap.end()) + return It->second == I; else return isInPartition(GV, I, N); })); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/StripGCRelocates.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/StripGCRelocates.cpp index 6094f36a77f4..3ae76ffd5eca 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/StripGCRelocates.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/StripGCRelocates.cpp @@ -42,7 +42,7 @@ static bool stripGCRelocates(Function &F) { // All gc_relocates are i8 addrspace(1)* typed, we need a bitcast from i8 // addrspace(1)* to the type of the OrigPtr, if the are not the same. if (GCRel->getType() != OrigPtr->getType()) - ReplaceGCRel = new BitCastInst(OrigPtr, GCRel->getType(), "cast", GCRel); + ReplaceGCRel = new BitCastInst(OrigPtr, GCRel->getType(), "cast", GCRel->getIterator()); // Replace all uses of gc.relocate and delete the gc.relocate // There maybe unncessary bitcasts back to the OrigPtr type, an instcombine diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/SymbolRewriter.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/SymbolRewriter.cpp index 8b4f34209e85..d52d52a9b7d3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/SymbolRewriter.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/SymbolRewriter.cpp @@ -308,11 +308,11 @@ bool RewriteMapParser::parseEntry(yaml::Stream &YS, yaml::KeyValueNode &Entry, } RewriteType = Key->getValue(KeyStorage); - if (RewriteType.equals("function")) + if (RewriteType == "function") return parseRewriteFunctionDescriptor(YS, Key, Value, DL); - else if (RewriteType.equals("global variable")) + else if (RewriteType == "global variable") return parseRewriteGlobalVariableDescriptor(YS, Key, Value, DL); - else if (RewriteType.equals("global alias")) + else if (RewriteType == "global alias") return parseRewriteGlobalAliasDescriptor(YS, Key, Value, DL); YS.printError(Entry.getKey(), "unknown rewrite type"); @@ -348,7 +348,7 @@ parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, } KeyValue = Key->getValue(KeyStorage); - if (KeyValue.equals("source")) { + if (KeyValue == "source") { std::string Error; Source = std::string(Value->getValue(ValueStorage)); @@ -356,11 +356,11 @@ parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, YS.printError(Field.getKey(), "invalid regex: " + Error); return false; } - } else if (KeyValue.equals("target")) { + } else if (KeyValue == "target") { Target = std::string(Value->getValue(ValueStorage)); - } else if (KeyValue.equals("transform")) { + } else if (KeyValue == "transform") { Transform = std::string(Value->getValue(ValueStorage)); - } else if (KeyValue.equals("naked")) { + } else if (KeyValue == "naked") { std::string Undecorated; Undecorated = std::string(Value->getValue(ValueStorage)); @@ -417,7 +417,7 @@ parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, } KeyValue = Key->getValue(KeyStorage); - if (KeyValue.equals("source")) { + if (KeyValue == "source") { std::string Error; Source = std::string(Value->getValue(ValueStorage)); @@ -425,9 +425,9 @@ parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode 
*K, YS.printError(Field.getKey(), "invalid regex: " + Error); return false; } - } else if (KeyValue.equals("target")) { + } else if (KeyValue == "target") { Target = std::string(Value->getValue(ValueStorage)); - } else if (KeyValue.equals("transform")) { + } else if (KeyValue == "transform") { Transform = std::string(Value->getValue(ValueStorage)); } else { YS.printError(Field.getKey(), "unknown Key for Global Variable"); @@ -480,7 +480,7 @@ parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, } KeyValue = Key->getValue(KeyStorage); - if (KeyValue.equals("source")) { + if (KeyValue == "source") { std::string Error; Source = std::string(Value->getValue(ValueStorage)); @@ -488,9 +488,9 @@ parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, YS.printError(Field.getKey(), "invalid regex: " + Error); return false; } - } else if (KeyValue.equals("target")) { + } else if (KeyValue == "target") { Target = std::string(Value->getValue(ValueStorage)); - } else if (KeyValue.equals("transform")) { + } else if (KeyValue == "transform") { Transform = std::string(Value->getValue(ValueStorage)); } else { YS.printError(Field.getKey(), "unknown key for Global Alias"); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp index 2f37f7f972cb..1d51f61351fe 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp @@ -119,7 +119,7 @@ static void restoreSSA(const DominatorTree &DT, const Loop *L, LLVM_DEBUG(dbgs() << "externally used: " << Def->getName() << "\n"); auto NewPhi = PHINode::Create(Def->getType(), Incoming.size(), - Def->getName() + ".moved", &LoopExitBlock->front()); + Def->getName() + ".moved", LoopExitBlock->begin()); for (auto *In : Incoming) { LLVM_DEBUG(dbgs() << "predecessor " << In->getName() << ": "); if (Def->getParent() == In || DT.dominates(Def, In)) { diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/ValueMapper.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/ValueMapper.cpp index 380541ffdd49..1696e9c72673 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/ValueMapper.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/ValueMapper.cpp @@ -146,7 +146,7 @@ public: Value *mapValue(const Value *V); void remapInstruction(Instruction *I); void remapFunction(Function &F); - void remapDPValue(DPValue &DPV); + void remapDbgRecord(DbgRecord &DVR); Constant *mapConstant(const Constant *C) { return cast_or_null<Constant>(mapValue(C)); @@ -537,12 +537,21 @@ Value *Mapper::mapValue(const Value *V) { return getVM()[V] = ConstantPointerNull::get(cast<PointerType>(NewTy)); } -void Mapper::remapDPValue(DPValue &V) { - // Remap variables and DILocations. +void Mapper::remapDbgRecord(DbgRecord &DR) { + // Remap DILocations. + auto *MappedDILoc = mapMetadata(DR.getDebugLoc()); + DR.setDebugLoc(DebugLoc(cast<DILocation>(MappedDILoc))); + + if (DbgLabelRecord *DLR = dyn_cast<DbgLabelRecord>(&DR)) { + // Remap labels. + DLR->setLabel(cast<DILabel>(mapMetadata(DLR->getLabel()))); + return; + } + + DbgVariableRecord &V = cast<DbgVariableRecord>(DR); + // Remap variables. 
auto *MappedVar = mapMetadata(V.getVariable()); - auto *MappedDILoc = mapMetadata(V.getDebugLoc()); V.setVariable(cast<DILocalVariable>(MappedVar)); - V.setDebugLoc(DebugLoc(cast<DILocation>(MappedDILoc))); bool IgnoreMissingLocals = Flags & RF_IgnoreMissingLocals; @@ -552,6 +561,7 @@ void Mapper::remapDPValue(DPValue &V) { V.setKillAddress(); else if (NewAddr) V.setAddress(NewAddr); + V.setAssignId(cast<DIAssignID>(mapMetadata(V.getAssignID()))); } // Find Value operands and remap those. @@ -1056,9 +1066,13 @@ void Mapper::remapFunction(Function &F) { A.mutateType(TypeMapper->remapType(A.getType())); // Remap the instructions. - for (BasicBlock &BB : F) - for (Instruction &I : BB) + for (BasicBlock &BB : F) { + for (Instruction &I : BB) { remapInstruction(&I); + for (DbgRecord &DR : I.getDbgRecordRange()) + remapDbgRecord(DR); + } + } } void Mapper::mapAppendingVariable(GlobalVariable &GV, Constant *InitPrefix, @@ -1222,14 +1236,14 @@ void ValueMapper::remapInstruction(Instruction &I) { FlushingMapper(pImpl)->remapInstruction(&I); } -void ValueMapper::remapDPValue(Module *M, DPValue &V) { - FlushingMapper(pImpl)->remapDPValue(V); +void ValueMapper::remapDbgRecord(Module *M, DbgRecord &DR) { + FlushingMapper(pImpl)->remapDbgRecord(DR); } -void ValueMapper::remapDPValueRange( - Module *M, iterator_range<DPValue::self_iterator> Range) { - for (DPValue &DPV : Range) { - remapDPValue(M, DPV); +void ValueMapper::remapDbgRecordRange( + Module *M, iterator_range<DbgRecord::self_iterator> Range) { + for (DbgRecord &DR : Range) { + remapDbgRecord(M, DR); } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index 1f11d4894f77..c91911ecad74 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -252,7 +252,7 @@ public: Vectorizer(Function &F, AliasAnalysis &AA, AssumptionCache &AC, DominatorTree &DT, ScalarEvolution &SE, TargetTransformInfo &TTI) : F(F), AA(AA), AC(AC), DT(DT), SE(SE), TTI(TTI), - DL(F.getParent()->getDataLayout()), Builder(SE.getContext()) {} + DL(F.getDataLayout()), Builder(SE.getContext()) {} bool run(); @@ -892,7 +892,7 @@ bool Vectorizer::vectorizeChain(Chain &C) { // Loads get hoisted to the location of the first load in the chain. We may // also need to hoist the (transitive) operands of the loads. Builder.SetInsertPoint( - std::min_element(C.begin(), C.end(), [](const auto &A, const auto &B) { + llvm::min_element(C, [](const auto &A, const auto &B) { return A.Inst->comesBefore(B.Inst); })->Inst); @@ -944,10 +944,9 @@ bool Vectorizer::vectorizeChain(Chain &C) { reorder(VecInst); } else { // Stores get sunk to the location of the last store in the chain. - Builder.SetInsertPoint( - std::max_element(C.begin(), C.end(), [](auto &A, auto &B) { - return A.Inst->comesBefore(B.Inst); - })->Inst); + Builder.SetInsertPoint(llvm::max_element(C, [](auto &A, auto &B) { + return A.Inst->comesBefore(B.Inst); + })->Inst); // Build the vector to store. 
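// ---- Editorial sketch (not part of the upstream patch) ----
// Source-level intuition for the store side of vectorizeChain(): four
// adjacent 32-bit stores like the ones below form a chain that can be merged
// into a single <4 x i32> store, and that store is emitted at the position of
// the last store in program order, which is what the llvm::max_element call
// above computes. Function and parameter names are illustrative.
#include <cstdint>
void storeFour(uint32_t *P, uint32_t A, uint32_t B, uint32_t C, uint32_t D) {
  P[0] = A;
  P[1] = B;
  P[2] = C;
  P[3] = D; // conceptually, the single vector store lands here
}
// ---- End editorial sketch ----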
Value *Vec = PoisonValue::get(VecTy); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp new file mode 100644 index 000000000000..64e04cae2773 --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp @@ -0,0 +1,943 @@ +//===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements a pass that recognizes certain loop idioms and +// transforms them into more optimized versions of the same loop. In cases +// where this happens, it can be a significant performance win. +// +// We currently only recognize one loop that finds the first mismatched byte +// in an array and returns the index, i.e. something like: +// +// while (++i != n) { +// if (a[i] != b[i]) +// break; +// } +// +// In this example we can actually vectorize the loop despite the early exit, +// although the loop vectorizer does not support it. It requires some extra +// checks to deal with the possibility of faulting loads when crossing page +// boundaries. However, even with these checks it is still profitable to do the +// transformation. +// +//===----------------------------------------------------------------------===// +// +// NOTE: This Pass matches a really specific loop pattern because it's only +// supposed to be a temporary solution until our LoopVectorizer is powerful +// enought to vectorize it automatically. +// +// TODO List: +// +// * Add support for the inverse case where we scan for a matching element. +// * Permit 64-bit induction variable types. +// * Recognize loops that increment the IV *after* comparing bytes. +// * Allow 32-bit sign-extends of the IV used by the GEP. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h" +#include "llvm/Analysis/DomTreeUpdater.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; +using namespace PatternMatch; + +#define DEBUG_TYPE "loop-idiom-vectorize" + +static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden, + cl::init(false), + cl::desc("Disable Loop Idiom Vectorize Pass.")); + +static cl::opt<LoopIdiomVectorizeStyle> + LITVecStyle("loop-idiom-vectorize-style", cl::Hidden, + cl::desc("The vectorization style for loop idiom transform."), + cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked", + "Use masked vector intrinsics"), + clEnumValN(LoopIdiomVectorizeStyle::Predicated, + "predicated", "Use VP intrinsics")), + cl::init(LoopIdiomVectorizeStyle::Masked)); + +static cl::opt<bool> + DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden, + cl::init(false), + cl::desc("Proceed with Loop Idiom Vectorize Pass, but do " + "not convert byte-compare loop(s).")); + +static cl::opt<unsigned> + ByteCmpVF("loop-idiom-vectorize-bytecmp-vf", cl::Hidden, + cl::desc("The vectorization factor for byte-compare patterns."), + cl::init(16)); + +static cl::opt<bool> + VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false), + cl::desc("Verify loops generated Loop Idiom Vectorize Pass.")); + +namespace { +class LoopIdiomVectorize { + LoopIdiomVectorizeStyle VectorizeStyle; + unsigned ByteCompareVF; + Loop *CurLoop = nullptr; + DominatorTree *DT; + LoopInfo *LI; + const TargetTransformInfo *TTI; + const DataLayout *DL; + + // Blocks that will be used for inserting vectorized code. 
+ BasicBlock *EndBlock = nullptr; + BasicBlock *VectorLoopPreheaderBlock = nullptr; + BasicBlock *VectorLoopStartBlock = nullptr; + BasicBlock *VectorLoopMismatchBlock = nullptr; + BasicBlock *VectorLoopIncBlock = nullptr; + +public: + LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT, + LoopInfo *LI, const TargetTransformInfo *TTI, + const DataLayout *DL) + : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) { + } + + bool run(Loop *L); + +private: + /// \name Countable Loop Idiom Handling + /// @{ + + bool runOnCountableLoop(); + bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount, + SmallVectorImpl<BasicBlock *> &ExitBlocks); + + bool recognizeByteCompare(); + + Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, + GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, + Instruction *Index, Value *Start, Value *MaxLen); + + Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, + GetElementPtrInst *GEPA, + GetElementPtrInst *GEPB, Value *ExtStart, + Value *ExtEnd); + Value *createPredicatedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, + GetElementPtrInst *GEPA, + GetElementPtrInst *GEPB, Value *ExtStart, + Value *ExtEnd); + + void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, + PHINode *IndPhi, Value *MaxLen, Instruction *Index, + Value *Start, bool IncIdx, BasicBlock *FoundBB, + BasicBlock *EndBB); + /// @} +}; +} // anonymous namespace + +PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + if (DisableAll) + return PreservedAnalyses::all(); + + const auto *DL = &L.getHeader()->getDataLayout(); + + LoopIdiomVectorizeStyle VecStyle = VectorizeStyle; + if (LITVecStyle.getNumOccurrences()) + VecStyle = LITVecStyle; + + unsigned BCVF = ByteCompareVF; + if (ByteCmpVF.getNumOccurrences()) + BCVF = ByteCmpVF; + + LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL); + if (!LIV.run(&L)) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} + +//===----------------------------------------------------------------------===// +// +// Implementation of LoopIdiomVectorize +// +//===----------------------------------------------------------------------===// + +bool LoopIdiomVectorize::run(Loop *L) { + CurLoop = L; + + Function &F = *L->getHeader()->getParent(); + if (DisableAll || F.hasOptSize()) + return false; + + if (F.hasFnAttribute(Attribute::NoImplicitFloat)) { + LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName() + << " due to its NoImplicitFloat attribute"); + return false; + } + + // If the loop could not be converted to canonical form, it must have an + // indirectbr in it, just give up. + if (!L->getLoopPreheader()) + return false; + + LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %" + << CurLoop->getHeader()->getName() << "\n"); + + return recognizeByteCompare(); +} + +bool LoopIdiomVectorize::recognizeByteCompare() { + // Currently the transformation only works on scalable vector types, although + // there is no fundamental reason why it cannot be made to work for fixed + // width too. + + // We also need to know the minimum page size for the target in order to + // generate runtime memory checks to ensure the vector version won't fault. 
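// ---- Editorial sketch (not part of the upstream patch) ----
// The scalar shape recognizeByteCompare() is looking for, written out as a
// plain function (it follows the loop quoted in the file header: a
// pre-incremented 32-bit index scanning two loop-invariant byte arrays for
// the first mismatch). The vectorized replacement must add the page-boundary
// checks mentioned above because it may load bytes past the mismatch. The
// function name is illustrative, and Start < N is assumed.
#include <cstdint>
uint32_t firstMismatch(const uint8_t *A, const uint8_t *B, uint32_t I,
                       uint32_t N) {
  while (++I != N)
    if (A[I] != B[I])
      break;
  return I;
}
// ---- End editorial sketch ----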
+ if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() || + DisableByteCmp) + return false; + + BasicBlock *Header = CurLoop->getHeader(); + + // In LoopIdiomVectorize::run we have already checked that the loop + // has a preheader so we can assume it's in a canonical form. + if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2) + return false; + + PHINode *PN = dyn_cast<PHINode>(&Header->front()); + if (!PN || PN->getNumIncomingValues() != 2) + return false; + + auto LoopBlocks = CurLoop->getBlocks(); + // The first block in the loop should contain only 4 instructions, e.g. + // + // while.cond: + // %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ] + // %inc = add i32 %res.phi, 1 + // %cmp.not = icmp eq i32 %inc, %n + // br i1 %cmp.not, label %while.end, label %while.body + // + if (LoopBlocks[0]->sizeWithoutDebug() > 4) + return false; + + // The second block should contain 7 instructions, e.g. + // + // while.body: + // %idx = zext i32 %inc to i64 + // %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx + // %load.a = load i8, ptr %idx.a + // %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx + // %load.b = load i8, ptr %idx.b + // %cmp.not.ld = icmp eq i8 %load.a, %load.b + // br i1 %cmp.not.ld, label %while.cond, label %while.end + // + if (LoopBlocks[1]->sizeWithoutDebug() > 7) + return false; + + // The incoming value to the PHI node from the loop should be an add of 1. + Value *StartIdx = nullptr; + Instruction *Index = nullptr; + if (!CurLoop->contains(PN->getIncomingBlock(0))) { + StartIdx = PN->getIncomingValue(0); + Index = dyn_cast<Instruction>(PN->getIncomingValue(1)); + } else { + StartIdx = PN->getIncomingValue(1); + Index = dyn_cast<Instruction>(PN->getIncomingValue(0)); + } + + // Limit to 32-bit types for now + if (!Index || !Index->getType()->isIntegerTy(32) || + !match(Index, m_c_Add(m_Specific(PN), m_One()))) + return false; + + // If we match the pattern, PN and Index will be replaced with the result of + // the cttz.elts intrinsic. If any other instructions are used outside of + // the loop, we cannot replace it. + for (BasicBlock *BB : LoopBlocks) + for (Instruction &I : *BB) + if (&I != PN && &I != Index) + for (User *U : I.users()) + if (!CurLoop->contains(cast<Instruction>(U))) + return false; + + // Match the branch instruction for the header + ICmpInst::Predicate Pred; + Value *MaxLen; + BasicBlock *EndBB, *WhileBB; + if (!match(Header->getTerminator(), + m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)), + m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) || + Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB)) + return false; + + // WhileBB should contain the pattern of load & compare instructions. Match + // the pattern and find the GEP instructions used by the loads. 
+ ICmpInst::Predicate WhilePred; + BasicBlock *FoundBB; + BasicBlock *TrueBB; + Value *LoadA, *LoadB; + if (!match(WhileBB->getTerminator(), + m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)), + m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) || + WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB)) + return false; + + Value *A, *B; + if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B)))) + return false; + + LoadInst *LoadAI = cast<LoadInst>(LoadA); + LoadInst *LoadBI = cast<LoadInst>(LoadB); + if (!LoadAI->isSimple() || !LoadBI->isSimple()) + return false; + + GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A); + GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B); + + if (!GEPA || !GEPB) + return false; + + Value *PtrA = GEPA->getPointerOperand(); + Value *PtrB = GEPB->getPointerOperand(); + + // Check we are loading i8 values from two loop invariant pointers + if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) || + !GEPA->getResultElementType()->isIntegerTy(8) || + !GEPB->getResultElementType()->isIntegerTy(8) || + !LoadAI->getType()->isIntegerTy(8) || + !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB) + return false; + + // Check that the index to the GEPs is the index we found earlier + if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1) + return false; + + Value *IdxA = GEPA->getOperand(GEPA->getNumIndices()); + Value *IdxB = GEPB->getOperand(GEPB->getNumIndices()); + if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index)))) + return false; + + // We only ever expect the pre-incremented index value to be used inside the + // loop. + if (!PN->hasOneUse()) + return false; + + // Ensure that when the Found and End blocks are identical the PHIs have the + // supported format. We don't currently allow cases like this: + // while.cond: + // ... + // br i1 %cmp.not, label %while.end, label %while.body + // + // while.body: + // ... + // br i1 %cmp.not2, label %while.cond, label %while.end + // + // while.end: + // %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ] + // + // Where the incoming values for %final_ptr are unique and from each of the + // loop blocks, but not actually defined in the loop. This requires extra + // work setting up the byte.compare block, i.e. by introducing a select to + // choose the correct value. + // TODO: We could add support for this in future. + if (FoundBB == EndBB) { + for (PHINode &EndPN : EndBB->phis()) { + Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header); + Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB); + + // The value of the index when leaving the while.cond block is always the + // same as the end value (MaxLen) so we permit either. The value when + // leaving the while.body block should only be the index. Otherwise for + // any other values we only allow ones that are same for both blocks. + if (WhileCondVal != WhileBodyVal && + ((WhileCondVal != Index && WhileCondVal != MaxLen) || + (WhileBodyVal != Index))) + return false; + } + } + + LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n" + << *(EndBB->getParent()) << "\n\n"); + + // The index is incremented before the GEP/Load pair so we need to + // add 1 to the start value. 
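At the source level, the idiom accepted by the checks above is a find-first-mismatch loop over two i8 arrays, with the index incremented before the loads (which is why the transform below adds 1 to the start value). A minimal C++ sketch of that shape, with illustrative names only:

  // Returns the first index after Start at which A and B differ, or N if the
  // counter reaches N first.
  static unsigned firstMismatch(const unsigned char *A, const unsigned char *B,
                                unsigned Start, unsigned N) {
    unsigned I = Start;
    for (;;) {
      ++I;               // %inc = add i32 %res.phi, 1  (while.cond)
      if (I == N)        // %cmp.not = icmp eq i32 %inc, %n
        return I;        // no mismatch found
      if (A[I] != B[I])  // i8 loads + icmp eq          (while.body)
        return I;        // first mismatching index
    }
  }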
+ transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true, + FoundBB, EndBB); + return true; +} + +Value *LoopIdiomVectorize::createMaskedFindMismatch( + IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, + GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) { + Type *I64Type = Builder.getInt64Ty(); + Type *ResType = Builder.getInt32Ty(); + Type *LoadType = Builder.getInt8Ty(); + Value *PtrA = GEPA->getPointerOperand(); + Value *PtrB = GEPB->getPointerOperand(); + + ScalableVectorType *PredVTy = + ScalableVectorType::get(Builder.getInt1Ty(), ByteCompareVF); + + Value *InitialPred = Builder.CreateIntrinsic( + Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd}); + + Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {}); + VecLen = + Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "", + /*HasNUW=*/true, /*HasNSW=*/true); + + Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(), + Builder.getInt1(false)); + + BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock); + Builder.Insert(JumpToVectorLoop); + + DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock, + VectorLoopStartBlock}}); + + // Set up the first vector loop block by creating the PHIs, doing the vector + // loads and comparing the vectors. + Builder.SetInsertPoint(VectorLoopStartBlock); + PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred"); + LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock); + PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index"); + VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock); + Type *VectorLoadType = + ScalableVectorType::get(Builder.getInt8Ty(), ByteCompareVF); + Value *Passthru = ConstantInt::getNullValue(VectorLoadType); + + Value *VectorLhsGep = + Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds()); + Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep, + Align(1), LoopPred, Passthru); + + Value *VectorRhsGep = + Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds()); + Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep, + Align(1), LoopPred, Passthru); + + Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad); + VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse); + Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp); + BranchInst *VectorEarlyExit = BranchInst::Create( + VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes); + Builder.Insert(VectorEarlyExit); + + DTU.applyUpdates( + {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock}, + {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}}); + + // Increment the index counter and calculate the predicate for the next + // iteration of the loop. We branch back to the start of the loop if there + // is at least one active lane. 
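As a scalar model of the vector body built above: each masked step covers ByteCompareVF lanes, using the usual semantics of get_active_lane_mask (lane L is active iff Index + L < End). This is only a sketch of the behaviour, not the emitted code, and the names are illustrative:

  // One masked iteration starting at Index; returns the absolute index of the
  // first active mismatching lane, or -1 if this step found none.
  static long long maskedStep(const unsigned char *A, const unsigned char *B,
                              unsigned long long Index, unsigned long long End,
                              unsigned VF) {
    for (unsigned L = 0; L != VF; ++L) {
      bool Active = Index + L < End;       // get_active_lane_mask, lane L
      if (Active && A[Index + L] != B[Index + L])
        return (long long)(Index + L);     // or-reduce fires; cttz.elts + add
    }
    return -1;                             // fall through to the increment block
  }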
+ Builder.SetInsertPoint(VectorLoopIncBlock); + Value *NewVectorIndexPhi = + Builder.CreateAdd(VectorIndexPhi, VecLen, "", + /*HasNUW=*/true, /*HasNSW=*/true); + VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock); + Value *NewPred = + Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, + {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd}); + LoopPred->addIncoming(NewPred, VectorLoopIncBlock); + + Value *PredHasActiveLanes = + Builder.CreateExtractElement(NewPred, uint64_t(0)); + BranchInst *VectorLoopBranchBack = + BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes); + Builder.Insert(VectorLoopBranchBack); + + DTU.applyUpdates( + {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock}, + {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}}); + + // If we found a mismatch then we need to calculate which lane in the vector + // had a mismatch and add that on to the current loop index. + Builder.SetInsertPoint(VectorLoopMismatchBlock); + PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred"); + FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock); + PHINode *LastLoopPred = + Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred"); + LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock); + PHINode *VectorFoundIndex = + Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index"); + VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock); + + Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred); + Value *Ctz = Builder.CreateIntrinsic( + Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()}, + {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)}); + Ctz = Builder.CreateZExt(Ctz, I64Type); + Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "", + /*HasNUW=*/true, /*HasNSW=*/true); + return Builder.CreateTrunc(VectorLoopRes64, ResType); +} + +Value *LoopIdiomVectorize::createPredicatedFindMismatch( + IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, + GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) { + Type *I64Type = Builder.getInt64Ty(); + Type *I32Type = Builder.getInt32Ty(); + Type *ResType = I32Type; + Type *LoadType = Builder.getInt8Ty(); + Value *PtrA = GEPA->getPointerOperand(); + Value *PtrB = GEPB->getPointerOperand(); + + auto *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock); + Builder.Insert(JumpToVectorLoop); + + DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock, + VectorLoopStartBlock}}); + + // Set up the first Vector loop block by creating the PHIs, doing the vector + // loads and comparing the vectors. 
+ Builder.SetInsertPoint(VectorLoopStartBlock); + auto *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vector_index"); + VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock); + + // Calculate AVL by subtracting the vector loop index from the trip count + Value *AVL = Builder.CreateSub(ExtEnd, VectorIndexPhi, "avl", /*HasNUW=*/true, + /*HasNSW=*/true); + + auto *VectorLoadType = ScalableVectorType::get(LoadType, ByteCompareVF); + auto *VF = ConstantInt::get(I32Type, ByteCompareVF); + + Value *VL = Builder.CreateIntrinsic(Intrinsic::experimental_get_vector_length, + {I64Type}, {AVL, VF, Builder.getTrue()}); + Value *GepOffset = VectorIndexPhi; + + Value *VectorLhsGep = + Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds()); + VectorType *TrueMaskTy = + VectorType::get(Builder.getInt1Ty(), VectorLoadType->getElementCount()); + Value *AllTrueMask = Constant::getAllOnesValue(TrueMaskTy); + Value *VectorLhsLoad = Builder.CreateIntrinsic( + Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()}, + {VectorLhsGep, AllTrueMask, VL}, nullptr, "lhs.load"); + + Value *VectorRhsGep = + Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds()); + Value *VectorRhsLoad = Builder.CreateIntrinsic( + Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()}, + {VectorRhsGep, AllTrueMask, VL}, nullptr, "rhs.load"); + + StringRef PredicateStr = CmpInst::getPredicateName(CmpInst::ICMP_NE); + auto *PredicateMDS = MDString::get(VectorLhsLoad->getContext(), PredicateStr); + Value *Pred = MetadataAsValue::get(VectorLhsLoad->getContext(), PredicateMDS); + Value *VectorMatchCmp = Builder.CreateIntrinsic( + Intrinsic::vp_icmp, {VectorLhsLoad->getType()}, + {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr, + "mismatch.cmp"); + Value *CTZ = Builder.CreateIntrinsic( + Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType()}, + {VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(false), AllTrueMask, + VL}); + Value *MismatchFound = Builder.CreateICmpNE(CTZ, VL); + auto *VectorEarlyExit = BranchInst::Create(VectorLoopMismatchBlock, + VectorLoopIncBlock, MismatchFound); + Builder.Insert(VectorEarlyExit); + + DTU.applyUpdates( + {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock}, + {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}}); + + // Increment the index counter and calculate the predicate for the next + // iteration of the loop. We branch back to the start of the loop if there + // is at least one active lane. + Builder.SetInsertPoint(VectorLoopIncBlock); + Value *VL64 = Builder.CreateZExt(VL, I64Type); + Value *NewVectorIndexPhi = + Builder.CreateAdd(VectorIndexPhi, VL64, "", + /*HasNUW=*/true, /*HasNSW=*/true); + VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock); + Value *ExitCond = Builder.CreateICmpNE(NewVectorIndexPhi, ExtEnd); + auto *VectorLoopBranchBack = + BranchInst::Create(VectorLoopStartBlock, EndBlock, ExitCond); + Builder.Insert(VectorLoopBranchBack); + + DTU.applyUpdates( + {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock}, + {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}}); + + // If we found a mismatch then we need to calculate which lane in the vector + // had a mismatch and add that on to the current loop index. + Builder.SetInsertPoint(VectorLoopMismatchBlock); + + // Add LCSSA phis for CTZ and VectorIndexPhi. 
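The predicated variant above expresses the same step with VP intrinsics and an explicit vector length. A scalar sketch, assuming get_vector_length simply returns min(AVL, VF) (targets are allowed to return less); names are illustrative:

  // One EVL iteration: process VL elements from Index; a mismatch is detected
  // when vp.cttz.elts over the compare mask is smaller than VL.
  static long long evlStep(const unsigned char *A, const unsigned char *B,
                           unsigned long long Index, unsigned long long End,
                           unsigned long long VF) {
    unsigned long long AVL = End - Index;
    unsigned long long VL = AVL < VF ? AVL : VF; // get_vector_length, simplified
    for (unsigned long long L = 0; L != VL; ++L)
      if (A[Index + L] != B[Index + L])
        return (long long)(Index + L);           // CTZ != VL -> mismatch block
    return -1;                                   // advance Index by VL and loop
  }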
+ auto *CTZLCSSAPhi = Builder.CreatePHI(CTZ->getType(), 1, "ctz"); + CTZLCSSAPhi->addIncoming(CTZ, VectorLoopStartBlock); + auto *VectorIndexLCSSAPhi = + Builder.CreatePHI(VectorIndexPhi->getType(), 1, "mismatch_vector_index"); + VectorIndexLCSSAPhi->addIncoming(VectorIndexPhi, VectorLoopStartBlock); + + Value *CTZI64 = Builder.CreateZExt(CTZLCSSAPhi, I64Type); + Value *VectorLoopRes64 = Builder.CreateAdd(VectorIndexLCSSAPhi, CTZI64, "", + /*HasNUW=*/true, /*HasNSW=*/true); + return Builder.CreateTrunc(VectorLoopRes64, ResType); +} + +Value *LoopIdiomVectorize::expandFindMismatch( + IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, + GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) { + Value *PtrA = GEPA->getPointerOperand(); + Value *PtrB = GEPB->getPointerOperand(); + + // Get the arguments and types for the intrinsic. + BasicBlock *Preheader = CurLoop->getLoopPreheader(); + BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator()); + LLVMContext &Ctx = PHBranch->getContext(); + Type *LoadType = Type::getInt8Ty(Ctx); + Type *ResType = Builder.getInt32Ty(); + + // Split block in the original loop preheader. + EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end"); + + // Create the blocks that we're going to need: + // 1. A block for checking the zero-extended length exceeds 0 + // 2. A block to check that the start and end addresses of a given array + // lie on the same page. + // 3. The vector loop preheader. + // 4. The first vector loop block. + // 5. The vector loop increment block. + // 6. A block we can jump to from the vector loop when a mismatch is found. + // 7. The first block of the scalar loop itself, containing PHIs , loads + // and cmp. + // 8. A scalar loop increment block to increment the PHIs and go back + // around the loop. + + BasicBlock *MinItCheckBlock = BasicBlock::Create( + Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock); + + // Update the terminator added by SplitBlock to branch to the first block + Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock); + + BasicBlock *MemCheckBlock = BasicBlock::Create( + Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock); + + VectorLoopPreheaderBlock = BasicBlock::Create( + Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock); + + VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop", + EndBlock->getParent(), EndBlock); + + VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc", + EndBlock->getParent(), EndBlock); + + VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found", + EndBlock->getParent(), EndBlock); + + BasicBlock *LoopPreHeaderBlock = BasicBlock::Create( + Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock); + + BasicBlock *LoopStartBlock = + BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock); + + BasicBlock *LoopIncBlock = BasicBlock::Create( + Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock); + + DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock}, + {DominatorTree::Delete, Preheader, EndBlock}}); + + // Update LoopInfo with the new vector & scalar loops. 
+ auto VectorLoop = LI->AllocateLoop(); + auto ScalarLoop = LI->AllocateLoop(); + + if (CurLoop->getParentLoop()) { + CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI); + CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI); + CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopPreheaderBlock, + *LI); + CurLoop->getParentLoop()->addChildLoop(VectorLoop); + CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopMismatchBlock, *LI); + CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI); + CurLoop->getParentLoop()->addChildLoop(ScalarLoop); + } else { + LI->addTopLevelLoop(VectorLoop); + LI->addTopLevelLoop(ScalarLoop); + } + + // Add the new basic blocks to their associated loops. + VectorLoop->addBasicBlockToLoop(VectorLoopStartBlock, *LI); + VectorLoop->addBasicBlockToLoop(VectorLoopIncBlock, *LI); + + ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI); + ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI); + + // Set up some types and constants that we intend to reuse. + Type *I64Type = Builder.getInt64Ty(); + + // Check the zero-extended iteration count > 0 + Builder.SetInsertPoint(MinItCheckBlock); + Value *ExtStart = Builder.CreateZExt(Start, I64Type); + Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type); + // This check doesn't really cost us very much. + + Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen); + BranchInst *MinItCheckBr = + BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck); + MinItCheckBr->setMetadata( + LLVMContext::MD_prof, + MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1)); + Builder.Insert(MinItCheckBr); + + DTU.applyUpdates( + {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock}, + {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}}); + + // For each of the arrays, check the start/end addresses are on the same + // page. + Builder.SetInsertPoint(MemCheckBlock); + + // The early exit in the original loop means that when performing vector + // loads we are potentially reading ahead of the early exit. So we could + // fault if crossing a page boundary. Therefore, we create runtime memory + // checks based on the minimum page size as follows: + // 1. Calculate the addresses of the first memory accesses in the loop, + // i.e. LhsStart and RhsStart. + // 2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd. + // 3. Determine which pages correspond to all the memory accesses, i.e + // LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage. + // 4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then + // we know we won't cross any page boundaries in the loop so we can + // enter the vector loop! Otherwise we fall back on the scalar loop. 
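The page check described in steps 1-4 above boils down to comparing page numbers obtained by shifting addresses right by log2 of the minimum page size (assumed to be a power of two, as the use of Log2_64 below implies). A self-contained sketch with illustrative names:

  // True iff StartAddr and EndAddr fall on the same page of size MinPageSize.
  static bool onSamePage(unsigned long long StartAddr,
                         unsigned long long EndAddr,
                         unsigned long long MinPageSize) {
    unsigned Shift = 0;
    while ((1ULL << Shift) < MinPageSize) // stand-in for llvm::Log2_64
      ++Shift;
    return (StartAddr >> Shift) == (EndAddr >> Shift);
  }

The vector preheader is only entered when this holds for both arrays; otherwise control falls through to the scalar fallback loop.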
+ Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart); + Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart); + Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type); + Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type); + Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd); + Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd); + Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type); + Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type); + + const uint64_t MinPageSize = TTI->getMinPageSize().value(); + const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize); + Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt); + Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt); + Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt); + Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt); + Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage); + Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage); + + Value *CombinedPageCmp = Builder.CreateOr(LhsPageCmp, RhsPageCmp); + BranchInst *CombinedPageCmpCmpBr = BranchInst::Create( + LoopPreHeaderBlock, VectorLoopPreheaderBlock, CombinedPageCmp); + CombinedPageCmpCmpBr->setMetadata( + LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext()) + .createBranchWeights(10, 90)); + Builder.Insert(CombinedPageCmpCmpBr); + + DTU.applyUpdates( + {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock}, + {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}}); + + // Set up the vector loop preheader, i.e. calculate initial loop predicate, + // zero-extend MaxLen to 64-bits, determine the number of vector elements + // processed in each iteration, etc. + Builder.SetInsertPoint(VectorLoopPreheaderBlock); + + // At this point we know two things must be true: + // 1. Start <= End + // 2. ExtMaxLen <= MinPageSize due to the page checks. + // Therefore, we know that we can use a 64-bit induction variable that + // starts from 0 -> ExtMaxLen and it will not overflow. + Value *VectorLoopRes = nullptr; + switch (VectorizeStyle) { + case LoopIdiomVectorizeStyle::Masked: + VectorLoopRes = + createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd); + break; + case LoopIdiomVectorizeStyle::Predicated: + VectorLoopRes = createPredicatedFindMismatch(Builder, DTU, GEPA, GEPB, + ExtStart, ExtEnd); + break; + } + + Builder.Insert(BranchInst::Create(EndBlock)); + + DTU.applyUpdates( + {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}}); + + // Generate code for scalar loop. + Builder.SetInsertPoint(LoopPreHeaderBlock); + Builder.Insert(BranchInst::Create(LoopStartBlock)); + + DTU.applyUpdates( + {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}}); + + Builder.SetInsertPoint(LoopStartBlock); + PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index"); + IndexPhi->addIncoming(Start, LoopPreHeaderBlock); + + // Otherwise compare the values + // Load bytes from each array and compare them. + Value *GepOffset = Builder.CreateZExt(IndexPhi, I64Type); + + Value *LhsGep = + Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds()); + Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep); + + Value *RhsGep = + Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds()); + Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep); + + Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad); + // If we have a mismatch then exit the loop ... 
+  BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp);
+  Builder.Insert(MatchCmpBr);
+
+  DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock},
+                    {DominatorTree::Insert, LoopStartBlock, EndBlock}});
+
+  // Have we reached the maximum permitted length for the loop?
+  Builder.SetInsertPoint(LoopIncBlock);
+  Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "",
+                                    /*HasNUW=*/Index->hasNoUnsignedWrap(),
+                                    /*HasNSW=*/Index->hasNoSignedWrap());
+  IndexPhi->addIncoming(PhiInc, LoopIncBlock);
+  Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen);
+  BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp);
+  Builder.Insert(IVCmpBr);
+
+  DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock},
+                    {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}});
+
+  // In the end block we need to insert a PHI node to deal with four cases:
+  //  1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
+  //  2. We exited the scalar loop early due to a mismatch and need to return
+  //  the index that we found.
+  //  3. We didn't find a mismatch in the vector loop, so we return MaxLen.
+  //  4. We exited the vector loop early due to a mismatch and need to return
+  //  the index that we found.
+  Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt());
+  PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result");
+  ResPhi->addIncoming(MaxLen, LoopIncBlock);
+  ResPhi->addIncoming(IndexPhi, LoopStartBlock);
+  ResPhi->addIncoming(MaxLen, VectorLoopIncBlock);
+  ResPhi->addIncoming(VectorLoopRes, VectorLoopMismatchBlock);
+
+  Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType);
+
+  if (VerifyLoops) {
+    ScalarLoop->verifyLoop();
+    VectorLoop->verifyLoop();
+    if (!VectorLoop->isRecursivelyLCSSAForm(*DT, *LI))
+      report_fatal_error("Loops must remain in LCSSA form!");
+    if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI))
+      report_fatal_error("Loops must remain in LCSSA form!");
+  }
+
+  return FinalRes;
+}
+
+void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
+                                              GetElementPtrInst *GEPB,
+                                              PHINode *IndPhi, Value *MaxLen,
+                                              Instruction *Index, Value *Start,
+                                              bool IncIdx, BasicBlock *FoundBB,
+                                              BasicBlock *EndBB) {
+
+  // Insert the byte compare code at the end of the preheader block
+  BasicBlock *Preheader = CurLoop->getLoopPreheader();
+  BasicBlock *Header = CurLoop->getHeader();
+  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
+  IRBuilder<> Builder(PHBranch);
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());
+
+  // Increment the start index if this was done before the loads in the loop.
+  if (IncIdx)
+    Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1));
+
+  Value *ByteCmpRes =
+      expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);
+
+  // Replace uses of the index & induction Phi with the intrinsic result (we
+  // already checked that the first instruction of Header is the Phi above).
+  assert(IndPhi->hasOneUse() && "Index phi node has more than one use!");
+  Index->replaceAllUsesWith(ByteCmpRes);
+
+  assert(PHBranch->isUnconditional() &&
+         "Expected preheader to terminate with an unconditional branch.");
+
+  // If no mismatch was found, we can jump to the end block. Create a
+  // new basic block for the compare instruction.
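For orientation, the net effect of the rewrite completed below on a typical user loop is roughly the following at the source level (illustrative only; the actual transform works on IR and deliberately keeps a reference to the original, now dead, loop until later passes delete it):

  // Before: the loop matched by recognizeByteCompare.
  static bool prefixEqualLoop(const unsigned char *A, const unsigned char *B,
                              unsigned N) {
    unsigned I = 0;
    while (++I != N)
      if (A[I] != B[I])
        return false;       // FoundBB
    return true;            // EndBB
  }

  // After: one find-mismatch expansion plus the compare emitted in CmpBB.
  static bool prefixEqualRewritten(const unsigned char *A,
                                   const unsigned char *B, unsigned N) {
    unsigned Res = 1;                       // Start + 1 (IncIdx)
    for (; Res != N; ++Res)                 // expandFindMismatch result
      if (A[Res] != B[Res])
        break;
    return Res == N;                        // FoundCmp in byte.compare
  }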
+  auto *CmpBB = BasicBlock::Create(Preheader->getContext(), "byte.compare",
+                                   Preheader->getParent());
+  CmpBB->moveBefore(EndBB);
+
+  // Replace the branch in the preheader with an always-true conditional branch.
+  // This ensures there is still a reference to the original loop.
+  Builder.CreateCondBr(Builder.getTrue(), CmpBB, Header);
+  PHBranch->eraseFromParent();
+
+  BasicBlock *MismatchEnd = cast<Instruction>(ByteCmpRes)->getParent();
+  DTU.applyUpdates({{DominatorTree::Insert, MismatchEnd, CmpBB}});
+
+  // Create the branch to either the end or found block depending on the value
+  // returned by the intrinsic.
+  Builder.SetInsertPoint(CmpBB);
+  if (FoundBB != EndBB) {
+    Value *FoundCmp = Builder.CreateICmpEQ(ByteCmpRes, MaxLen);
+    Builder.CreateCondBr(FoundCmp, EndBB, FoundBB);
+    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB},
+                      {DominatorTree::Insert, CmpBB, EndBB}});
+  } else {
+    Builder.CreateBr(FoundBB);
+    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}});
+  }
+
+  auto fixSuccessorPhis = [&](BasicBlock *SuccBB) {
+    for (PHINode &PN : SuccBB->phis()) {
+      // At this point we've already replaced all uses of the result from the
+      // loop with ByteCmp. Look through the incoming values to find ByteCmp,
+      // meaning this is a Phi collecting the results of the byte compare.
+      bool ResPhi = false;
+      for (Value *Op : PN.incoming_values())
+        if (Op == ByteCmpRes) {
+          ResPhi = true;
+          break;
+        }
+
+      // Any PHI that depended upon the result of the byte compare needs a new
+      // incoming value from CmpBB. This is because the original loop will get
+      // deleted.
+      if (ResPhi)
+        PN.addIncoming(ByteCmpRes, CmpBB);
+      else {
+        // There should be no other outside uses of other values in the
+        // original loop. Any incoming values should either:
+        // 1. Be for blocks outside the loop, which aren't interesting. Or ...
+        // 2. Be from blocks in the loop with values defined outside the
+        //    loop. We should add a similar incoming value from CmpBB.
+        for (BasicBlock *BB : PN.blocks())
+          if (CurLoop->contains(BB)) {
+            PN.addIncoming(PN.getIncomingValueForBlock(BB), CmpBB);
+            break;
+          }
+      }
+    }
+  };
+
+  // Ensure all Phis in the successors of CmpBB have an incoming value from it.
+  fixSuccessorPhis(EndBB);
+  if (EndBB != FoundBB)
+    fixSuccessorPhis(FoundBB);
+
+  // The new CmpBB block isn't part of the loop, but will need to be added to
+  // the outer loop if there is one.
+ if (!CurLoop->isOutermost()) + CurLoop->getParentLoop()->addBasicBlockToLoop(CmpBB, *LI); + + if (VerifyLoops && CurLoop->getParentLoop()) { + CurLoop->getParentLoop()->verifyLoop(); + if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI)) + report_fatal_error("Loops must remain in LCSSA form!"); + } +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp index 37a356c43e29..cafec165f6d6 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -261,20 +261,20 @@ void LoopVectorizeHints::getHintsFromMetadata() { assert(LoopID->getNumOperands() > 0 && "requires at least one operand"); assert(LoopID->getOperand(0) == LoopID && "invalid loop id"); - for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { + for (const MDOperand &MDO : llvm::drop_begin(LoopID->operands())) { const MDString *S = nullptr; SmallVector<Metadata *, 4> Args; // The expected hint is either a MDString or a MDNode with the first // operand a MDString. - if (const MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i))) { + if (const MDNode *MD = dyn_cast<MDNode>(MDO)) { if (!MD || MD->getNumOperands() == 0) continue; S = dyn_cast<MDString>(MD->getOperand(0)); for (unsigned i = 1, ie = MD->getNumOperands(); i < ie; ++i) Args.push_back(MD->getOperand(i)); } else { - S = dyn_cast<MDString>(LoopID->getOperand(i)); + S = dyn_cast<MDString>(MDO); assert(Args.size() == 0 && "too many arguments for MDString"); } @@ -692,7 +692,7 @@ void LoopVectorizationLegality::addInductionPhi( InductionCastsToIgnore.insert(*Casts.begin()); Type *PhiTy = Phi->getType(); - const DataLayout &DL = Phi->getModule()->getDataLayout(); + const DataLayout &DL = Phi->getDataLayout(); // Get the widest type. if (!PhiTy->isFloatingPointTy()) { @@ -1067,6 +1067,15 @@ bool LoopVectorizationLegality::canVectorizeMemory() { if (!LAI->canVectorizeMemory()) return false; + if (LAI->hasLoadStoreDependenceInvolvingLoopInvariantAddress()) { + reportVectorizationFailure("We don't allow storing to uniform addresses", + "write to a loop invariant address could not " + "be vectorized", + "CantVectorizeStoreToLoopInvariantAddress", ORE, + TheLoop); + return false; + } + // We can vectorize stores to invariant address when final reduction value is // guaranteed to be stored at the end of the loop. Also, if decision to // vectorize loop is made, runtime checks are added so as to make sure that @@ -1102,13 +1111,12 @@ bool LoopVectorizationLegality::canVectorizeMemory() { } } - if (LAI->hasDependenceInvolvingLoopInvariantAddress()) { + if (LAI->hasStoreStoreDependenceInvolvingLoopInvariantAddress()) { // For each invariant address, check its last stored value is the result // of one of our reductions. // - // We do not check if dependence with loads exists because they are - // currently rejected earlier in LoopAccessInfo::analyzeLoop. In case this - // behaviour changes we have to modify this code. + // We do not check if dependence with loads exists because that is already + // checked via hasLoadStoreDependenceInvolvingLoopInvariantAddress. 
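For reference, the kind of loop this new early bail-out is aimed at stores through a loop-invariant address and also loads from it on every iteration, e.g. a reduction kept in memory. A hedged example with illustrative names (whether a particular loop reaches exactly this check also depends on the earlier LoopAccessInfo analysis):

  // *Sum is read and written each iteration through an invariant address, so
  // there is a load/store dependence involving a loop-invariant address.
  void accumulate(int *Sum, const int *A, int N) {
    for (int I = 0; I < N; ++I)
      *Sum += A[I];
  }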
ScalarEvolution *SE = PSE.getSE(); SmallVector<StoreInst *, 4> UnhandledStores; for (StoreInst *SI : LAI->getStoresToInvariantAddresses()) { @@ -1498,6 +1506,16 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) { return false; } + if (isa<SCEVCouldNotCompute>(PSE.getBackedgeTakenCount())) { + reportVectorizationFailure("could not determine number of loop iterations", + "could not determine number of loop iterations", + "CantComputeNumberOfIterations", ORE, TheLoop); + if (DoExtraAnalysis) + Result = false; + else + return false; + } + LLVM_DEBUG(dbgs() << "LV: We can vectorize this loop" << (LAI->getRuntimePointerChecking()->Need ? " (with a runtime bound check)" @@ -1525,7 +1543,7 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) { return Result; } -bool LoopVectorizationLegality::prepareToFoldTailByMasking() { +bool LoopVectorizationLegality::canFoldTailByMasking() const { LLVM_DEBUG(dbgs() << "LV: checking if tail can be folded by masking.\n"); @@ -1552,26 +1570,47 @@ bool LoopVectorizationLegality::prepareToFoldTailByMasking() { } } + for (const auto &Entry : getInductionVars()) { + PHINode *OrigPhi = Entry.first; + for (User *U : OrigPhi->users()) { + auto *UI = cast<Instruction>(U); + if (!TheLoop->contains(UI)) { + LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking, loop IV has an " + "outside user for " + << *UI << "\n"); + return false; + } + } + } + // The list of pointers that we can safely read and write to remains empty. SmallPtrSet<Value *, 8> SafePointers; - // Collect masked ops in temporary set first to avoid partially populating - // MaskedOp if a block cannot be predicated. + // Check all blocks for predication, including those that ordinarily do not + // need predication such as the header block. SmallPtrSet<const Instruction *, 8> TmpMaskedOp; - - // Check and mark all blocks for predication, including those that ordinarily - // do not need predication such as the header block. for (BasicBlock *BB : TheLoop->blocks()) { if (!blockCanBePredicated(BB, SafePointers, TmpMaskedOp)) { - LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking as requested.\n"); + LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking.\n"); return false; } } LLVM_DEBUG(dbgs() << "LV: can fold tail by masking.\n"); - MaskedOp.insert(TmpMaskedOp.begin(), TmpMaskedOp.end()); return true; } +void LoopVectorizationLegality::prepareToFoldTailByMasking() { + // The list of pointers that we can safely read and write to remains empty. + SmallPtrSet<Value *, 8> SafePointers; + + // Mark all blocks for predication, including those that ordinarily do not + // need predication such as the header block. 
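The canVectorize check added above gives up when the backedge-taken count is SCEVCouldNotCompute. A typical loop shape that hits it is a search whose exit depends on the data it loads; a hedged example with illustrative names:

  // The exit condition depends on loaded data, so the backedge-taken count
  // cannot be computed and vectorization is rejected up front.
  int findSentinel(const int *A) {
    int I = 0;
    while (A[I] != -1)
      ++I;
    return I;
  }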
+ for (BasicBlock *BB : TheLoop->blocks()) { + [[maybe_unused]] bool R = blockCanBePredicated(BB, SafePointers, MaskedOp); + assert(R && "Must be able to predicate block when tail-folding."); + } +} + } // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h index a7ebf78e54ce..c63cf0c37f2f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h @@ -68,6 +68,7 @@ class VPBuilder { public: VPBuilder() = default; VPBuilder(VPBasicBlock *InsertBB) { setInsertPoint(InsertBB); } + VPBuilder(VPRecipeBase *InsertPt) { setInsertPoint(InsertPt); } /// Clear the insertion point: created instructions will not be inserted into /// a block. @@ -79,6 +80,13 @@ public: VPBasicBlock *getInsertBlock() const { return BB; } VPBasicBlock::iterator getInsertPoint() const { return InsertPt; } + /// Create a VPBuilder to insert after \p R. + static VPBuilder getToInsertAfter(VPRecipeBase *R) { + VPBuilder B; + B.setInsertPoint(R->getParent(), std::next(R->getIterator())); + return B; + } + /// InsertPoint - A saved insertion point. class VPInsertPoint { VPBasicBlock *Block = nullptr; @@ -131,8 +139,9 @@ public: /// Create an N-ary operation with \p Opcode, \p Operands and set \p Inst as /// its underlying Instruction. - VPValue *createNaryOp(unsigned Opcode, ArrayRef<VPValue *> Operands, - Instruction *Inst = nullptr, const Twine &Name = "") { + VPInstruction *createNaryOp(unsigned Opcode, ArrayRef<VPValue *> Operands, + Instruction *Inst = nullptr, + const Twine &Name = "") { DebugLoc DL; if (Inst) DL = Inst->getDebugLoc(); @@ -140,8 +149,8 @@ public: NewVPInst->setUnderlyingValue(Inst); return NewVPInst; } - VPValue *createNaryOp(unsigned Opcode, ArrayRef<VPValue *> Operands, - DebugLoc DL, const Twine &Name = "") { + VPInstruction *createNaryOp(unsigned Opcode, ArrayRef<VPValue *> Operands, + DebugLoc DL, const Twine &Name = "") { return createInstruction(Opcode, Operands, DL, Name); } @@ -164,7 +173,16 @@ public: VPValue *createOr(VPValue *LHS, VPValue *RHS, DebugLoc DL = {}, const Twine &Name = "") { - return createInstruction(Instruction::BinaryOps::Or, {LHS, RHS}, DL, Name); + + return tryInsertInstruction(new VPInstruction( + Instruction::BinaryOps::Or, {LHS, RHS}, + VPRecipeWithIRFlags::DisjointFlagsTy(false), DL, Name)); + } + + VPValue *createLogicalAnd(VPValue *LHS, VPValue *RHS, DebugLoc DL = {}, + const Twine &Name = "") { + return tryInsertInstruction( + new VPInstruction(VPInstruction::LogicalAnd, {LHS, RHS}, DL, Name)); } VPValue *createSelect(VPValue *Cond, VPValue *TrueVal, VPValue *FalseVal, @@ -208,7 +226,7 @@ public: /// TODO: The following VectorizationFactor was pulled out of /// LoopVectorizationCostModel class. LV also deals with -/// VectorizerParams::VectorizationFactor and VectorizationCostTy. +/// VectorizerParams::VectorizationFactor. /// We need to streamline them. /// Information about vectorization costs. @@ -244,16 +262,6 @@ struct VectorizationFactor { } }; -/// ElementCountComparator creates a total ordering for ElementCount -/// for the purposes of using it in a set structure. 
-struct ElementCountComparator { - bool operator()(const ElementCount &LHS, const ElementCount &RHS) const { - return std::make_tuple(LHS.isScalable(), LHS.getKnownMinValue()) < - std::make_tuple(RHS.isScalable(), RHS.getKnownMinValue()); - } -}; -using ElementCountSet = SmallSet<ElementCount, 16, ElementCountComparator>; - /// A class that represents two vectorization factors (initialized with 0 by /// default). One for fixed-width vectorization and one for scalable /// vectorization. This can be used by the vectorizer to choose from a range of @@ -326,6 +334,16 @@ class LoopVectorizationPlanner { /// A builder used to construct the current plan. VPBuilder Builder; + /// Computes the cost of \p Plan for vectorization factor \p VF. + /// + /// The current implementation requires access to the + /// LoopVectorizationLegality to handle inductions and reductions, which is + /// why it is kept separate from the VPlan-only cost infrastructure. + /// + /// TODO: Move to VPlan::cost once the use of LoopVectorizationLegality has + /// been retired. + InstructionCost cost(VPlan &Plan, ElementCount VF) const; + public: LoopVectorizationPlanner( Loop *L, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, @@ -347,6 +365,9 @@ public: /// Return the best VPlan for \p VF. VPlan &getBestPlanFor(ElementCount VF) const; + /// Return the most profitable plan and fix its VF to the most profitable one. + VPlan &getBestPlan() const; + /// Generate the IR code for the vectorized loop captured in VPlan \p BestPlan /// according to the best selected \p VF and \p UF. /// @@ -420,14 +441,16 @@ private: // converted to reductions, with one operand being vector and the other being // the scalar reduction chain. For other reductions, a select is introduced // between the phi and live-out recipes when folding the tail. - void adjustRecipesForReductions(VPBasicBlock *LatchVPBB, VPlanPtr &Plan, + void adjustRecipesForReductions(VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF); - /// \return The most profitable vectorization factor and the cost of that VF. - /// This method checks every VF in \p CandidateVFs. - VectorizationFactor - selectVectorizationFactor(const ElementCountSet &CandidateVFs); + /// \return The most profitable vectorization factor for the available VPlans + /// and the cost of that VF. + /// This is now only used to verify the decisions by the new VPlan-based + /// cost-model and will be retired once the VPlan-based cost-model is + /// stabilized. + VectorizationFactor selectVectorizationFactor(); /// Returns true if the per-lane cost of VectorizationFactor A is lower than /// that of B. 
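Now that candidate selection works directly over the available VPlans, the explicit ElementCountSet ordering above is gone. For reference, the removed comparator sorted VFs with fixed-width before scalable and then by known minimum lane count, equivalent to this standalone sketch (ECKey is a stand-in for ElementCount, not an LLVM type):

  #include <tuple>

  struct ECKey { bool Scalable; unsigned MinLanes; };
  // Matches the ordering the removed ElementCountComparator provided.
  static bool lessEC(const ECKey &L, const ECKey &R) {
    return std::make_tuple(L.Scalable, L.MinLanes) <
           std::make_tuple(R.Scalable, R.MinLanes);
  }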
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index dd596c567cd4..68363abdb817 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -59,7 +59,9 @@ #include "VPlan.h" #include "VPlanAnalysis.h" #include "VPlanHCFGBuilder.h" +#include "VPlanPatternMatch.h" #include "VPlanTransforms.h" +#include "VPlanVerifier.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -123,6 +125,7 @@ #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/IR/VectorBuilder.h" #include "llvm/IR/Verifier.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" @@ -202,6 +205,11 @@ static cl::opt<unsigned> VectorizeMemoryCheckThreshold( "vectorize-memory-check-threshold", cl::init(128), cl::Hidden, cl::desc("The maximum allowed number of runtime memory checks")); +static cl::opt<bool> UseLegacyCostModel( + "vectorize-use-legacy-cost-model", cl::init(true), cl::Hidden, + cl::desc("Use the legacy cost model instead of the VPlan-based cost model. " + "This option will be removed in the future.")); + // Option prefer-predicate-over-epilogue indicates that an epilogue is undesired, // that predication is preferred, and this lists all options. I.e., the // vectorizer will try to fold the tail-loop (epilogue) into the vector body @@ -247,10 +255,12 @@ static cl::opt<TailFoldingStyle> ForceTailFoldingStyle( clEnumValN(TailFoldingStyle::DataAndControlFlow, "data-and-control", "Create lane mask using active.lane.mask intrinsic, and use " "it for both data and control flow"), - clEnumValN( - TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck, - "data-and-control-without-rt-check", - "Similar to data-and-control, but remove the runtime check"))); + clEnumValN(TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck, + "data-and-control-without-rt-check", + "Similar to data-and-control, but remove the runtime check"), + clEnumValN(TailFoldingStyle::DataWithEVL, "data-with-evl", + "Use predicated EVL instructions for tail folding. If EVL " + "is unsupported, fallback to data-without-lane-mask."))); static cl::opt<bool> MaximizeBandwidth( "vectorizer-maximize-bandwidth", cl::init(false), cl::Hidden, @@ -267,11 +277,6 @@ static cl::opt<bool> EnableMaskedInterleavedMemAccesses( "enable-masked-interleaved-mem-accesses", cl::init(false), cl::Hidden, cl::desc("Enable vectorization on masked interleaved memory accesses in a loop")); -static cl::opt<unsigned> TinyTripCountInterleaveThreshold( - "tiny-trip-count-interleave-threshold", cl::init(128), cl::Hidden, - cl::desc("We don't interleave loops with a estimated constant trip count " - "below this number")); - static cl::opt<unsigned> ForceTargetNumScalarRegs( "force-target-num-scalar-regs", cl::init(0), cl::Hidden, cl::desc("A flag that overrides the target's number of scalar registers.")); @@ -290,7 +295,7 @@ static cl::opt<unsigned> ForceTargetMaxVectorInterleaveFactor( cl::desc("A flag that overrides the target's max interleave factor for " "vectorized loops.")); -static cl::opt<unsigned> ForceTargetInstructionCost( +cl::opt<unsigned> ForceTargetInstructionCost( "force-target-instruction-cost", cl::init(0), cl::Hidden, cl::desc("A flag that overrides the target's expected cost for " "an instruction to a single constant value. 
Mostly " @@ -319,12 +324,6 @@ static cl::opt<bool> EnableLoadStoreRuntimeInterleave( cl::desc( "Enable runtime interleaving until load/store ports are saturated")); -/// Interleave small loops with scalar reductions. -static cl::opt<bool> InterleaveSmallLoopScalarReduction( - "interleave-small-loop-scalar-reduction", cl::init(false), cl::Hidden, - cl::desc("Enable interleaving for loops with small iteration counts that " - "contain scalar reductions to expose ILP.")); - /// The number of stores in a loop that are allowed to need predication. static cl::opt<unsigned> NumberOfStoresToPredicate( "vectorize-num-stores-pred", cl::init(1), cl::Hidden, @@ -418,14 +417,6 @@ static bool hasIrregularType(Type *Ty, const DataLayout &DL) { return DL.getTypeAllocSizeInBits(Ty) != DL.getTypeSizeInBits(Ty); } -/// A helper function that returns the reciprocal of the block probability of -/// predicated blocks. If we return X, we are assuming the predicated block -/// will execute once for every X iterations of the loop header. -/// -/// TODO: We should use actual block probability here, if available. Currently, -/// we always assume predicated blocks have a 50% chance of executing. -static unsigned getReciprocalPredBlockProb() { return 2; } - /// Returns "best known" trip count for the specified loop \p L as defined by /// the following procedure: /// 1) Returns exact trip count if it is known. @@ -450,37 +441,6 @@ static std::optional<unsigned> getSmallBestKnownTC(ScalarEvolution &SE, return std::nullopt; } -/// Return a vector containing interleaved elements from multiple -/// smaller input vectors. -static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals, - const Twine &Name) { - unsigned Factor = Vals.size(); - assert(Factor > 1 && "Tried to interleave invalid number of vectors"); - - VectorType *VecTy = cast<VectorType>(Vals[0]->getType()); -#ifndef NDEBUG - for (Value *Val : Vals) - assert(Val->getType() == VecTy && "Tried to interleave mismatched types"); -#endif - - // Scalable vectors cannot use arbitrary shufflevectors (only splats), so - // must use intrinsics to interleave. - if (VecTy->isScalableTy()) { - VectorType *WideVecTy = VectorType::getDoubleElementsVectorType(VecTy); - return Builder.CreateIntrinsic( - WideVecTy, Intrinsic::experimental_vector_interleave2, Vals, - /*FMFSource=*/nullptr, Name); - } - - // Fixed length. Start by concatenating all vectors into a wide vector. - Value *WideVec = concatenateVectors(Builder, Vals); - - // Interleave the elements into the wide vector. - const unsigned NumElts = VecTy->getElementCount().getFixedValue(); - return Builder.CreateShuffleVector( - WideVec, createInterleaveMask(NumElts, Factor), Name); -} - namespace { // Forward declare GeneratedRTChecks. class GeneratedRTChecks; @@ -552,11 +512,6 @@ public: // Return true if any runtime check is added. bool areSafetyChecksAdded() { return AddedSafetyChecks; } - /// A type for vectorized values in the new loop. Each value from the - /// original loop, when vectorized, is represented by UF vector values in the - /// new unrolled loop, where UF is the unroll factor. - using VectorParts = SmallVector<Value *, 2>; - /// A helper function to scalarize a single Instruction in the innermost loop. 
/// Generates a sequence of scalar instances for each lane between \p MinLane /// and \p MaxLane, times each part between \p MinPart and \p MaxPart, @@ -567,23 +522,9 @@ public: const VPIteration &Instance, VPTransformState &State); - /// Try to vectorize interleaved access group \p Group with the base address - /// given in \p Addr, optionally masking the vector operations if \p - /// BlockInMask is non-null. Use \p State to translate given VPValues to IR - /// values in the vectorized loop. - void vectorizeInterleaveGroup(const InterleaveGroup<Instruction> *Group, - ArrayRef<VPValue *> VPDefs, - VPTransformState &State, VPValue *Addr, - ArrayRef<VPValue *> StoredValues, - VPValue *BlockInMask, bool NeedsMaskForGaps); - /// Fix the non-induction PHIs in \p Plan. void fixNonInductionPHIs(VPlan &Plan, VPTransformState &State); - /// Returns true if the reordering of FP operations is not allowed, but we are - /// able to vectorize with strict in-order reductions for the given RdxDesc. - bool useOrderedReductions(const RecurrenceDescriptor &RdxDesc); - /// Create a new phi node for the induction variable \p OrigPhi to resume /// iteration count in the scalar epilogue, from where the vectorized loop /// left off. \p Step is the SCEV-expanded induction step to use. In cases @@ -622,14 +563,6 @@ protected: BasicBlock *MiddleBlock, BasicBlock *VectorHeader, VPlan &Plan, VPTransformState &State); - /// Create the exit value of first order recurrences in the middle block and - /// update their users. - void fixFixedOrderRecurrence(VPFirstOrderRecurrencePHIRecipe *PhiR, - VPTransformState &State); - - /// Create code for the loop exit value of the reduction. - void fixReduction(VPReductionPHIRecipe *Phi, VPTransformState &State); - /// Iteratively sink the scalarized operands of a predicated instruction into /// the block that was created for it. void sinkScalarOperands(Instruction *PredInst); @@ -637,11 +570,6 @@ protected: /// Returns (and creates if needed) the trip count of the widened loop. Value *getOrCreateVectorTripCount(BasicBlock *InsertBlock); - /// Returns a bitcasted value to the requested vector type. - /// Also handles bitcasts of vector<float> <-> vector<pointer> types. - Value *createBitOrPointerCast(Value *V, VectorType *DstVTy, - const DataLayout &DL); - /// Emit a bypass check to see if the vector trip count is zero, including if /// it overflows. void emitIterationCountCheck(BasicBlock *Bypass); @@ -675,17 +603,6 @@ protected: /// running the verifier. Return the preheader of the completed vector loop. BasicBlock *completeLoopSkeleton(); - /// Collect poison-generating recipes that may generate a poison value that is - /// used after vectorization, even when their operands are not poison. Those - /// recipes meet the following conditions: - /// * Contribute to the address computation of a recipe generating a widen - /// memory load/store (VPWidenMemoryInstructionRecipe or - /// VPInterleaveRecipe). - /// * Such a widen memory load/store has at least one underlying Instruction - /// that is in a basic block that needs predication and after vectorization - /// the generated instruction won't be predicated. - void collectPoisonGeneratingRecipes(VPTransformState &State); - /// Allow subclasses to override and print debug traces before/after vplan /// execution, when trace information is requested. 
virtual void printDebugTracesAtStart(){}; @@ -1028,9 +945,12 @@ void reportVectorizationFailure(const StringRef DebugMsg, << "loop not vectorized: " << OREMsg); } -void reportVectorizationInfo(const StringRef Msg, const StringRef ORETag, +/// Reports an informative message: print \p Msg for debugging purposes as well +/// as an optimization remark. Uses either \p I as location of the remark, or +/// otherwise \p TheLoop. +static void reportVectorizationInfo(const StringRef Msg, const StringRef ORETag, OptimizationRemarkEmitter *ORE, Loop *TheLoop, - Instruction *I) { + Instruction *I = nullptr) { LLVM_DEBUG(debugVectorizationMessage("", Msg, I)); LoopVectorizeHints Hints(TheLoop, true /* doesn't matter */, *ORE); ORE->emit( @@ -1057,108 +977,6 @@ static void reportVectorization(OptimizationRemarkEmitter *ORE, Loop *TheLoop, } // end namespace llvm -#ifndef NDEBUG -/// \return string containing a file name and a line # for the given loop. -static std::string getDebugLocString(const Loop *L) { - std::string Result; - if (L) { - raw_string_ostream OS(Result); - if (const DebugLoc LoopDbgLoc = L->getStartLoc()) - LoopDbgLoc.print(OS); - else - // Just print the module name. - OS << L->getHeader()->getParent()->getParent()->getModuleIdentifier(); - OS.flush(); - } - return Result; -} -#endif - -void InnerLoopVectorizer::collectPoisonGeneratingRecipes( - VPTransformState &State) { - - // Collect recipes in the backward slice of `Root` that may generate a poison - // value that is used after vectorization. - SmallPtrSet<VPRecipeBase *, 16> Visited; - auto collectPoisonGeneratingInstrsInBackwardSlice([&](VPRecipeBase *Root) { - SmallVector<VPRecipeBase *, 16> Worklist; - Worklist.push_back(Root); - - // Traverse the backward slice of Root through its use-def chain. - while (!Worklist.empty()) { - VPRecipeBase *CurRec = Worklist.back(); - Worklist.pop_back(); - - if (!Visited.insert(CurRec).second) - continue; - - // Prune search if we find another recipe generating a widen memory - // instruction. Widen memory instructions involved in address computation - // will lead to gather/scatter instructions, which don't need to be - // handled. - if (isa<VPWidenMemoryInstructionRecipe>(CurRec) || - isa<VPInterleaveRecipe>(CurRec) || - isa<VPScalarIVStepsRecipe>(CurRec) || - isa<VPCanonicalIVPHIRecipe>(CurRec) || - isa<VPActiveLaneMaskPHIRecipe>(CurRec)) - continue; - - // This recipe contributes to the address computation of a widen - // load/store. If the underlying instruction has poison-generating flags, - // drop them directly. - if (auto *RecWithFlags = dyn_cast<VPRecipeWithIRFlags>(CurRec)) { - RecWithFlags->dropPoisonGeneratingFlags(); - } else { - Instruction *Instr = dyn_cast_or_null<Instruction>( - CurRec->getVPSingleValue()->getUnderlyingValue()); - (void)Instr; - assert((!Instr || !Instr->hasPoisonGeneratingFlags()) && - "found instruction with poison generating flags not covered by " - "VPRecipeWithIRFlags"); - } - - // Add new definitions to the worklist. - for (VPValue *operand : CurRec->operands()) - if (VPRecipeBase *OpDef = operand->getDefiningRecipe()) - Worklist.push_back(OpDef); - } - }); - - // Traverse all the recipes in the VPlan and collect the poison-generating - // recipes in the backward slice starting at the address of a VPWidenRecipe or - // VPInterleaveRecipe. 
- auto Iter = vp_depth_first_deep(State.Plan->getEntry()); - for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Iter)) { - for (VPRecipeBase &Recipe : *VPBB) { - if (auto *WidenRec = dyn_cast<VPWidenMemoryInstructionRecipe>(&Recipe)) { - Instruction &UnderlyingInstr = WidenRec->getIngredient(); - VPRecipeBase *AddrDef = WidenRec->getAddr()->getDefiningRecipe(); - if (AddrDef && WidenRec->isConsecutive() && - Legal->blockNeedsPredication(UnderlyingInstr.getParent())) - collectPoisonGeneratingInstrsInBackwardSlice(AddrDef); - } else if (auto *InterleaveRec = dyn_cast<VPInterleaveRecipe>(&Recipe)) { - VPRecipeBase *AddrDef = InterleaveRec->getAddr()->getDefiningRecipe(); - if (AddrDef) { - // Check if any member of the interleave group needs predication. - const InterleaveGroup<Instruction> *InterGroup = - InterleaveRec->getInterleaveGroup(); - bool NeedPredication = false; - for (int I = 0, NumMembers = InterGroup->getNumMembers(); - I < NumMembers; ++I) { - Instruction *Member = InterGroup->getMember(I); - if (Member) - NeedPredication |= - Legal->blockNeedsPredication(Member->getParent()); - } - - if (NeedPredication) - collectPoisonGeneratingInstrsInBackwardSlice(AddrDef); - } - } - } - } -} - namespace llvm { // Loop vectorization cost-model hints how the scalar epilogue loop should be @@ -1222,7 +1040,7 @@ public: bool selectUserVectorizationFactor(ElementCount UserVF) { collectUniformsAndScalars(UserVF); collectInstsToScalarize(UserVF); - return expectedCost(UserVF).first.isValid(); + return expectedCost(UserVF).isValid(); } /// \return The size (in bits) of the smallest and widest types in the code @@ -1298,11 +1116,9 @@ public: bool isProfitableToScalarize(Instruction *I, ElementCount VF) const { assert(VF.isVector() && "Profitable to scalarize relevant only for VF > 1."); - - // Cost model is not run in the VPlan-native path - return conservative - // result until this changes. - if (EnableVPlanNativePath) - return false; + assert( + TheLoop->isInnermost() && + "cost-model should not be used for outer loops (in VPlan-native path)"); auto Scalars = InstsToScalarize.find(VF); assert(Scalars != InstsToScalarize.end() && @@ -1312,6 +1128,9 @@ public: /// Returns true if \p I is known to be uniform after vectorization. bool isUniformAfterVectorization(Instruction *I, ElementCount VF) const { + assert( + TheLoop->isInnermost() && + "cost-model should not be used for outer loops (in VPlan-native path)"); // Pseudo probe needs to be duplicated for each unrolled iteration and // vector lane so that profiled loop trip count can be accurately // accumulated instead of being under counted. @@ -1321,11 +1140,6 @@ public: if (VF.isScalar()) return true; - // Cost model is not run in the VPlan-native path - return conservative - // result until this changes. - if (EnableVPlanNativePath) - return false; - auto UniformsPerVF = Uniforms.find(VF); assert(UniformsPerVF != Uniforms.end() && "VF not yet analyzed for uniformity"); @@ -1334,14 +1148,12 @@ public: /// Returns true if \p I is known to be scalar after vectorization. bool isScalarAfterVectorization(Instruction *I, ElementCount VF) const { + assert( + TheLoop->isInnermost() && + "cost-model should not be used for outer loops (in VPlan-native path)"); if (VF.isScalar()) return true; - // Cost model is not run in the VPlan-native path - return conservative - // result until this changes. 
- if (EnableVPlanNativePath) - return false; - auto ScalarsPerVF = Scalars.find(VF); assert(ScalarsPerVF != Scalars.end() && "Scalar values are not calculated for VF"); @@ -1399,10 +1211,9 @@ public: /// through the cost modeling. InstWidening getWideningDecision(Instruction *I, ElementCount VF) const { assert(VF.isVector() && "Expected VF to be a vector VF"); - // Cost model is not run in the VPlan-native path - return conservative - // result until this changes. - if (EnableVPlanNativePath) - return CM_GatherScatter; + assert( + TheLoop->isInnermost() && + "cost-model should not be used for outer loops (in VPlan-native path)"); std::pair<Instruction *, ElementCount> InstOnVF = std::make_pair(I, VF); auto Itr = WideningDecisions.find(InstOnVF); @@ -1570,29 +1381,40 @@ public: /// Returns true if \p I is a memory instruction in an interleaved-group /// of memory accesses that can be vectorized with wide vector loads/stores /// and shuffles. - bool interleavedAccessCanBeWidened(Instruction *I, ElementCount VF); + bool interleavedAccessCanBeWidened(Instruction *I, ElementCount VF) const; /// Check if \p Instr belongs to any interleaved access group. - bool isAccessInterleaved(Instruction *Instr) { + bool isAccessInterleaved(Instruction *Instr) const { return InterleaveInfo.isInterleaved(Instr); } /// Get the interleaved access group that \p Instr belongs to. const InterleaveGroup<Instruction> * - getInterleavedAccessGroup(Instruction *Instr) { + getInterleavedAccessGroup(Instruction *Instr) const { return InterleaveInfo.getInterleaveGroup(Instr); } /// Returns true if we're required to use a scalar epilogue for at least /// the final iteration of the original loop. bool requiresScalarEpilogue(bool IsVectorizing) const { - if (!isScalarEpilogueAllowed()) + if (!isScalarEpilogueAllowed()) { + LLVM_DEBUG(dbgs() << "LV: Loop does not require scalar epilogue\n"); return false; + } // If we might exit from anywhere but the latch, must run the exiting // iteration in scalar form. - if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) + if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) { + LLVM_DEBUG( + dbgs() << "LV: Loop requires scalar epilogue: multiple exits\n"); + return true; + } + if (IsVectorizing && InterleaveInfo.requiresScalarEpilogue()) { + LLVM_DEBUG(dbgs() << "LV: Loop requires scalar epilogue: " + "interleaved group requires scalar epilogue\n"); return true; - return IsVectorizing && InterleaveInfo.requiresScalarEpilogue(); + } + LLVM_DEBUG(dbgs() << "LV: Loop does not require scalar epilogue\n"); + return false; } /// Returns true if we're required to use a scalar epilogue for at least @@ -1617,19 +1439,67 @@ public: } /// Returns the TailFoldingStyle that is best for the current loop. - TailFoldingStyle - getTailFoldingStyle(bool IVUpdateMayOverflow = true) const { - if (!CanFoldTailByMasking) + TailFoldingStyle getTailFoldingStyle(bool IVUpdateMayOverflow = true) const { + if (!ChosenTailFoldingStyle) return TailFoldingStyle::None; + return IVUpdateMayOverflow ? ChosenTailFoldingStyle->first + : ChosenTailFoldingStyle->second; + } + + /// Selects and saves TailFoldingStyle for 2 options - if IV update may + /// overflow or not. + /// \param IsScalableVF true if scalable vector factors enabled. + /// \param UserIC User specific interleave count. 
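  /// A minimal illustration of the resulting pair (assumed scenarios, not
  /// taken from any particular target):
  /// \code
  ///   // Tail folding not legal:             {None, None}
  ///   // No forced style:                    {TTI preferred w/ IV overflow,
  ///   //                                      TTI preferred w/o IV overflow}
  ///   // Forced DataWithEVL, but UserIC > 1: EVL is rejected and both
  ///   //                                      entries become DataWithoutLaneMask
  /// \endcode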
+ void setTailFoldingStyles(bool IsScalableVF, unsigned UserIC) { + assert(!ChosenTailFoldingStyle && "Tail folding must not be selected yet."); + if (!Legal->canFoldTailByMasking()) { + ChosenTailFoldingStyle = + std::make_pair(TailFoldingStyle::None, TailFoldingStyle::None); + return; + } - if (ForceTailFoldingStyle.getNumOccurrences()) - return ForceTailFoldingStyle; + if (!ForceTailFoldingStyle.getNumOccurrences()) { + ChosenTailFoldingStyle = std::make_pair( + TTI.getPreferredTailFoldingStyle(/*IVUpdateMayOverflow=*/true), + TTI.getPreferredTailFoldingStyle(/*IVUpdateMayOverflow=*/false)); + return; + } - return TTI.getPreferredTailFoldingStyle(IVUpdateMayOverflow); + // Set styles when forced. + ChosenTailFoldingStyle = std::make_pair(ForceTailFoldingStyle.getValue(), + ForceTailFoldingStyle.getValue()); + if (ForceTailFoldingStyle != TailFoldingStyle::DataWithEVL) + return; + // Override forced styles if needed. + // FIXME: use actual opcode/data type for analysis here. + // FIXME: Investigate opportunity for fixed vector factor. + bool EVLIsLegal = + IsScalableVF && UserIC <= 1 && + TTI.hasActiveVectorLength(0, nullptr, Align()) && + !EnableVPlanNativePath && + // FIXME: implement support for max safe dependency distance. + Legal->isSafeForAnyVectorWidth(); + if (!EVLIsLegal) { + // If for some reason EVL mode is unsupported, fallback to + // DataWithoutLaneMask to try to vectorize the loop with folded tail + // in a generic way. + ChosenTailFoldingStyle = + std::make_pair(TailFoldingStyle::DataWithoutLaneMask, + TailFoldingStyle::DataWithoutLaneMask); + LLVM_DEBUG( + dbgs() + << "LV: Preference for VP intrinsics indicated. Will " + "not try to generate VP Intrinsics " + << (UserIC > 1 + ? "since interleave count specified is greater than 1.\n" + : "due to non-interleaving reasons.\n")); + } } /// Returns true if all loop blocks should be masked to fold tail loop. bool foldTailByMasking() const { + // TODO: check if it is possible to check for None style independent of + // IVUpdateMayOverflow flag in getTailFoldingStyle. return getTailFoldingStyle() != TailFoldingStyle::None; } @@ -1640,6 +1510,12 @@ public: return foldTailByMasking() || Legal->blockNeedsPredication(BB); } + /// Returns true if VP intrinsics with explicit vector length support should + /// be generated in the tail folded loop. + bool foldTailWithEVL() const { + return getTailFoldingStyle() == TailFoldingStyle::DataWithEVL; + } + /// Returns true if the Phi is part of an inloop reduction. bool isInLoopReduction(PHINode *Phi) const { return InLoopReductions.contains(Phi); @@ -1663,20 +1539,13 @@ public: Scalars.clear(); } - /// The vectorization cost is a combination of the cost itself and a boolean - /// indicating whether any of the contributing operations will actually - /// operate on vector values after type legalization in the backend. If this - /// latter value is false, then all operations will be scalarized (i.e. no - /// vectorization has actually taken place). - using VectorizationCostTy = std::pair<InstructionCost, bool>; - /// Returns the expected execution cost. The unit of the cost does /// not matter because we use the 'cost' units to compare different /// vector widths. The cost that is returned is *not* normalized by /// the factor width. If \p Invalid is not nullptr, this function /// will add a pair(Instruction*, ElementCount) to \p Invalid for /// each instruction that has an Invalid cost for the given VF. 
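  /// Illustrative call pattern after this change (it mirrors the use in
  /// selectVectorizationFactor further below; nothing new is introduced):
  /// \code
  ///   InstructionCost C = CM.expectedCost(VF, &InvalidCosts);
  ///   // The former VectorizationCostTy bool is gone; whether a VF will
  ///   // actually produce vector instructions is now a separate query,
  ///   // willGenerateVectors(Plan, VF, TTI).
  /// \endcode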
- VectorizationCostTy + InstructionCost expectedCost(ElementCount VF, SmallVectorImpl<InstructionVFPair> *Invalid = nullptr); @@ -1687,6 +1556,16 @@ public: /// \p VF is the vectorization factor chosen for the original loop. bool isEpilogueVectorizationProfitable(const ElementCount VF) const; + /// Returns the execution time cost of an instruction for a given vector + /// width. Vector width of one means scalar. + InstructionCost getInstructionCost(Instruction *I, ElementCount VF); + + /// Return the cost of instructions in an inloop reduction pattern, if I is + /// part of that pattern. + std::optional<InstructionCost> + getReductionPatternCost(Instruction *I, ElementCount VF, Type *VectorTy, + TTI::TargetCostKind CostKind) const; + private: unsigned NumPredStores = 0; @@ -1708,25 +1587,14 @@ private: ElementCount MaxSafeVF, bool FoldTailByMasking); + /// Checks if scalable vectorization is supported and enabled. Caches the + /// result to avoid repeated debug dumps for repeated queries. + bool isScalableVectorizationAllowed(); + /// \return the maximum legal scalable VF, based on the safe max number /// of elements. ElementCount getMaxLegalScalableVF(unsigned MaxSafeElements); - /// Returns the execution time cost of an instruction for a given vector - /// width. Vector width of one means scalar. - VectorizationCostTy getInstructionCost(Instruction *I, ElementCount VF); - - /// The cost-computation logic from getInstructionCost which provides - /// the vector type as an output parameter. - InstructionCost getInstructionCost(Instruction *I, ElementCount VF, - Type *&VectorTy); - - /// Return the cost of instructions in an inloop reduction pattern, if I is - /// part of that pattern. - std::optional<InstructionCost> - getReductionPatternCost(Instruction *I, ElementCount VF, Type *VectorTy, - TTI::TargetCostKind CostKind) const; - /// Calculate vectorization cost of memory instruction \p I. InstructionCost getMemoryInstructionCost(Instruction *I, ElementCount VF); @@ -1782,8 +1650,13 @@ private: /// iterations to execute in the scalar loop. ScalarEpilogueLowering ScalarEpilogueStatus = CM_ScalarEpilogueAllowed; - /// All blocks of loop are to be masked to fold tail of scalar iterations. - bool CanFoldTailByMasking = false; + /// Control finally chosen tail folding style. The first element is used if + /// the IV update may overflow, the second element - if it does not. + std::optional<std::pair<TailFoldingStyle, TailFoldingStyle>> + ChosenTailFoldingStyle; + + /// true if scalable vectorization is supported and enabled. + std::optional<bool> IsScalableVectorizationAllowed; /// A map holding scalar costs for different vectorization factors. The /// presence of a cost for an instruction in the mapping indicates that the @@ -2118,16 +1991,18 @@ public: BestTripCount = *EstimatedTC; } + BestTripCount = std::max(BestTripCount, 1U); InstructionCost NewMemCheckCost = MemCheckCost / BestTripCount; // Let's ensure the cost is always at least 1. NewMemCheckCost = std::max(*NewMemCheckCost.getValue(), (InstructionCost::CostType)1); - LLVM_DEBUG(dbgs() - << "We expect runtime memory checks to be hoisted " - << "out of the outer loop. Cost reduced from " - << MemCheckCost << " to " << NewMemCheckCost << '\n'); + if (BestTripCount > 1) + LLVM_DEBUG(dbgs() + << "We expect runtime memory checks to be hoisted " + << "out of the outer loop. 
Cost reduced from " + << MemCheckCost << " to " << NewMemCheckCost << '\n'); MemCheckCost = NewMemCheckCost; } @@ -2207,7 +2082,7 @@ public: BranchInst &BI = *BranchInst::Create(Bypass, LoopVectorPreHeader, Cond); if (AddBranchWeights) - setBranchWeights(BI, SCEVCheckBypassWeights); + setBranchWeights(BI, SCEVCheckBypassWeights, /*IsExpected=*/false); ReplaceInstWithInst(SCEVCheckBlock->getTerminator(), &BI); return SCEVCheckBlock; } @@ -2235,7 +2110,7 @@ public: BranchInst &BI = *BranchInst::Create(Bypass, LoopVectorPreHeader, MemRuntimeCheckCond); if (AddBranchWeights) { - setBranchWeights(BI, MemCheckBypassWeights); + setBranchWeights(BI, MemCheckBypassWeights, /*IsExpected=*/false); } ReplaceInstWithInst(MemCheckBlock->getTerminator(), &BI); MemCheckBlock->getTerminator()->setDebugLoc( @@ -2472,276 +2347,6 @@ static bool useMaskedInterleavedAccesses(const TargetTransformInfo &TTI) { return TTI.enableMaskedInterleavedAccessVectorization(); } -// Try to vectorize the interleave group that \p Instr belongs to. -// -// E.g. Translate following interleaved load group (factor = 3): -// for (i = 0; i < N; i+=3) { -// R = Pic[i]; // Member of index 0 -// G = Pic[i+1]; // Member of index 1 -// B = Pic[i+2]; // Member of index 2 -// ... // do something to R, G, B -// } -// To: -// %wide.vec = load <12 x i32> ; Read 4 tuples of R,G,B -// %R.vec = shuffle %wide.vec, poison, <0, 3, 6, 9> ; R elements -// %G.vec = shuffle %wide.vec, poison, <1, 4, 7, 10> ; G elements -// %B.vec = shuffle %wide.vec, poison, <2, 5, 8, 11> ; B elements -// -// Or translate following interleaved store group (factor = 3): -// for (i = 0; i < N; i+=3) { -// ... do something to R, G, B -// Pic[i] = R; // Member of index 0 -// Pic[i+1] = G; // Member of index 1 -// Pic[i+2] = B; // Member of index 2 -// } -// To: -// %R_G.vec = shuffle %R.vec, %G.vec, <0, 1, 2, ..., 7> -// %B_U.vec = shuffle %B.vec, poison, <0, 1, 2, 3, u, u, u, u> -// %interleaved.vec = shuffle %R_G.vec, %B_U.vec, -// <0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11> ; Interleave R,G,B elements -// store <12 x i32> %interleaved.vec ; Write 4 tuples of R,G,B -void InnerLoopVectorizer::vectorizeInterleaveGroup( - const InterleaveGroup<Instruction> *Group, ArrayRef<VPValue *> VPDefs, - VPTransformState &State, VPValue *Addr, ArrayRef<VPValue *> StoredValues, - VPValue *BlockInMask, bool NeedsMaskForGaps) { - Instruction *Instr = Group->getInsertPos(); - const DataLayout &DL = Instr->getModule()->getDataLayout(); - - // Prepare for the vector type of the interleaved load/store. - Type *ScalarTy = getLoadStoreType(Instr); - unsigned InterleaveFactor = Group->getFactor(); - auto *VecTy = VectorType::get(ScalarTy, VF * InterleaveFactor); - - // Prepare for the new pointers. - SmallVector<Value *, 2> AddrParts; - unsigned Index = Group->getIndex(Instr); - - // TODO: extend the masked interleaved-group support to reversed access. - assert((!BlockInMask || !Group->isReverse()) && - "Reversed masked interleave-group not supported."); - - Value *Idx; - // If the group is reverse, adjust the index to refer to the last vector lane - // instead of the first. We adjust the index from the first vector lane, - // rather than directly getting the pointer for lane VF - 1, because the - // pointer operand of the interleaved access is supposed to be uniform. For - // uniform instructions, we're only required to generate a value for the - // first vector lane in each unroll iteration. 
- if (Group->isReverse()) { - Value *RuntimeVF = getRuntimeVF(Builder, Builder.getInt32Ty(), VF); - Idx = Builder.CreateSub(RuntimeVF, Builder.getInt32(1)); - Idx = Builder.CreateMul(Idx, Builder.getInt32(Group->getFactor())); - Idx = Builder.CreateAdd(Idx, Builder.getInt32(Index)); - Idx = Builder.CreateNeg(Idx); - } else - Idx = Builder.getInt32(-Index); - - for (unsigned Part = 0; Part < UF; Part++) { - Value *AddrPart = State.get(Addr, VPIteration(Part, 0)); - if (auto *I = dyn_cast<Instruction>(AddrPart)) - State.setDebugLocFrom(I->getDebugLoc()); - - // Notice current instruction could be any index. Need to adjust the address - // to the member of index 0. - // - // E.g. a = A[i+1]; // Member of index 1 (Current instruction) - // b = A[i]; // Member of index 0 - // Current pointer is pointed to A[i+1], adjust it to A[i]. - // - // E.g. A[i+1] = a; // Member of index 1 - // A[i] = b; // Member of index 0 - // A[i+2] = c; // Member of index 2 (Current instruction) - // Current pointer is pointed to A[i+2], adjust it to A[i]. - - bool InBounds = false; - if (auto *gep = dyn_cast<GetElementPtrInst>(AddrPart->stripPointerCasts())) - InBounds = gep->isInBounds(); - AddrPart = Builder.CreateGEP(ScalarTy, AddrPart, Idx, "", InBounds); - AddrParts.push_back(AddrPart); - } - - State.setDebugLocFrom(Instr->getDebugLoc()); - Value *PoisonVec = PoisonValue::get(VecTy); - - auto CreateGroupMask = [this, &BlockInMask, &State, &InterleaveFactor]( - unsigned Part, Value *MaskForGaps) -> Value * { - if (VF.isScalable()) { - assert(!MaskForGaps && "Interleaved groups with gaps are not supported."); - assert(InterleaveFactor == 2 && - "Unsupported deinterleave factor for scalable vectors"); - auto *BlockInMaskPart = State.get(BlockInMask, Part); - SmallVector<Value *, 2> Ops = {BlockInMaskPart, BlockInMaskPart}; - auto *MaskTy = - VectorType::get(Builder.getInt1Ty(), VF.getKnownMinValue() * 2, true); - return Builder.CreateIntrinsic( - MaskTy, Intrinsic::experimental_vector_interleave2, Ops, - /*FMFSource=*/nullptr, "interleaved.mask"); - } - - if (!BlockInMask) - return MaskForGaps; - - Value *BlockInMaskPart = State.get(BlockInMask, Part); - Value *ShuffledMask = Builder.CreateShuffleVector( - BlockInMaskPart, - createReplicatedMask(InterleaveFactor, VF.getKnownMinValue()), - "interleaved.mask"); - return MaskForGaps ? Builder.CreateBinOp(Instruction::And, ShuffledMask, - MaskForGaps) - : ShuffledMask; - }; - - // Vectorize the interleaved load group. - if (isa<LoadInst>(Instr)) { - Value *MaskForGaps = nullptr; - if (NeedsMaskForGaps) { - MaskForGaps = - createBitMaskForGaps(Builder, VF.getKnownMinValue(), *Group); - assert(MaskForGaps && "Mask for Gaps is required but it is null"); - } - - // For each unroll part, create a wide load for the group. 
- SmallVector<Value *, 2> NewLoads; - for (unsigned Part = 0; Part < UF; Part++) { - Instruction *NewLoad; - if (BlockInMask || MaskForGaps) { - assert(useMaskedInterleavedAccesses(*TTI) && - "masked interleaved groups are not allowed."); - Value *GroupMask = CreateGroupMask(Part, MaskForGaps); - NewLoad = - Builder.CreateMaskedLoad(VecTy, AddrParts[Part], Group->getAlign(), - GroupMask, PoisonVec, "wide.masked.vec"); - } - else - NewLoad = Builder.CreateAlignedLoad(VecTy, AddrParts[Part], - Group->getAlign(), "wide.vec"); - Group->addMetadata(NewLoad); - NewLoads.push_back(NewLoad); - } - - if (VecTy->isScalableTy()) { - assert(InterleaveFactor == 2 && - "Unsupported deinterleave factor for scalable vectors"); - - for (unsigned Part = 0; Part < UF; ++Part) { - // Scalable vectors cannot use arbitrary shufflevectors (only splats), - // so must use intrinsics to deinterleave. - Value *DI = Builder.CreateIntrinsic( - Intrinsic::experimental_vector_deinterleave2, VecTy, NewLoads[Part], - /*FMFSource=*/nullptr, "strided.vec"); - unsigned J = 0; - for (unsigned I = 0; I < InterleaveFactor; ++I) { - Instruction *Member = Group->getMember(I); - - if (!Member) - continue; - - Value *StridedVec = Builder.CreateExtractValue(DI, I); - // If this member has different type, cast the result type. - if (Member->getType() != ScalarTy) { - VectorType *OtherVTy = VectorType::get(Member->getType(), VF); - StridedVec = createBitOrPointerCast(StridedVec, OtherVTy, DL); - } - - if (Group->isReverse()) - StridedVec = Builder.CreateVectorReverse(StridedVec, "reverse"); - - State.set(VPDefs[J], StridedVec, Part); - ++J; - } - } - - return; - } - - // For each member in the group, shuffle out the appropriate data from the - // wide loads. - unsigned J = 0; - for (unsigned I = 0; I < InterleaveFactor; ++I) { - Instruction *Member = Group->getMember(I); - - // Skip the gaps in the group. - if (!Member) - continue; - - auto StrideMask = - createStrideMask(I, InterleaveFactor, VF.getKnownMinValue()); - for (unsigned Part = 0; Part < UF; Part++) { - Value *StridedVec = Builder.CreateShuffleVector( - NewLoads[Part], StrideMask, "strided.vec"); - - // If this member has different type, cast the result type. - if (Member->getType() != ScalarTy) { - assert(!VF.isScalable() && "VF is assumed to be non scalable."); - VectorType *OtherVTy = VectorType::get(Member->getType(), VF); - StridedVec = createBitOrPointerCast(StridedVec, OtherVTy, DL); - } - - if (Group->isReverse()) - StridedVec = Builder.CreateVectorReverse(StridedVec, "reverse"); - - State.set(VPDefs[J], StridedVec, Part); - } - ++J; - } - return; - } - - // The sub vector type for current instruction. - auto *SubVT = VectorType::get(ScalarTy, VF); - - // Vectorize the interleaved store group. - Value *MaskForGaps = - createBitMaskForGaps(Builder, VF.getKnownMinValue(), *Group); - assert((!MaskForGaps || useMaskedInterleavedAccesses(*TTI)) && - "masked interleaved groups are not allowed."); - assert((!MaskForGaps || !VF.isScalable()) && - "masking gaps for scalable vectors is not yet supported."); - for (unsigned Part = 0; Part < UF; Part++) { - // Collect the stored vector from each member. - SmallVector<Value *, 4> StoredVecs; - unsigned StoredIdx = 0; - for (unsigned i = 0; i < InterleaveFactor; i++) { - assert((Group->getMember(i) || MaskForGaps) && - "Fail to get a member from an interleaved store group"); - Instruction *Member = Group->getMember(i); - - // Skip the gaps in the group. 
- if (!Member) { - Value *Undef = PoisonValue::get(SubVT); - StoredVecs.push_back(Undef); - continue; - } - - Value *StoredVec = State.get(StoredValues[StoredIdx], Part); - ++StoredIdx; - - if (Group->isReverse()) - StoredVec = Builder.CreateVectorReverse(StoredVec, "reverse"); - - // If this member has different type, cast it to a unified type. - - if (StoredVec->getType() != SubVT) - StoredVec = createBitOrPointerCast(StoredVec, SubVT, DL); - - StoredVecs.push_back(StoredVec); - } - - // Interleave all the smaller vectors into one wider vector. - Value *IVec = interleaveVectors(Builder, StoredVecs, "interleaved.vec"); - Instruction *NewStoreInstr; - if (BlockInMask || MaskForGaps) { - Value *GroupMask = CreateGroupMask(Part, MaskForGaps); - NewStoreInstr = Builder.CreateMaskedStore(IVec, AddrParts[Part], - Group->getAlign(), GroupMask); - } else - NewStoreInstr = - Builder.CreateAlignedStore(IVec, AddrParts[Part], Group->getAlign()); - - Group->addMetadata(NewStoreInstr); - } -} - void InnerLoopVectorizer::scalarizeInstruction(const Instruction *Instr, VPReplicateRecipe *RepRecipe, const VPIteration &Instance, @@ -2822,9 +2427,8 @@ InnerLoopVectorizer::getOrCreateVectorTripCount(BasicBlock *InsertBlock) { if (Cost->foldTailByMasking()) { assert(isPowerOf2_32(VF.getKnownMinValue() * UF) && "VF*UF must be a power of 2 when folding tail by masking"); - Value *NumLanes = getRuntimeVF(Builder, Ty, VF * UF); - TC = Builder.CreateAdd( - TC, Builder.CreateSub(NumLanes, ConstantInt::get(Ty, 1)), "n.rnd.up"); + TC = Builder.CreateAdd(TC, Builder.CreateSub(Step, ConstantInt::get(Ty, 1)), + "n.rnd.up"); } // Now we need to generate the expression for the part of the loop that the @@ -2850,37 +2454,6 @@ InnerLoopVectorizer::getOrCreateVectorTripCount(BasicBlock *InsertBlock) { return VectorTripCount; } -Value *InnerLoopVectorizer::createBitOrPointerCast(Value *V, VectorType *DstVTy, - const DataLayout &DL) { - // Verify that V is a vector type with same number of elements as DstVTy. - auto *DstFVTy = cast<VectorType>(DstVTy); - auto VF = DstFVTy->getElementCount(); - auto *SrcVecTy = cast<VectorType>(V->getType()); - assert(VF == SrcVecTy->getElementCount() && "Vector dimensions do not match"); - Type *SrcElemTy = SrcVecTy->getElementType(); - Type *DstElemTy = DstFVTy->getElementType(); - assert((DL.getTypeSizeInBits(SrcElemTy) == DL.getTypeSizeInBits(DstElemTy)) && - "Vector elements must have same size"); - - // Do a direct cast if element types are castable. - if (CastInst::isBitOrNoopPointerCastable(SrcElemTy, DstElemTy, DL)) { - return Builder.CreateBitOrPointerCast(V, DstFVTy); - } - // V cannot be directly casted to desired vector type. - // May happen when V is a floating point vector but DstVTy is a vector of - // pointers or vice-versa. Handle this using a two-step bitcast using an - // intermediate Integer type for the bitcast i.e. Ptr <-> Int <-> Float. 
- assert((DstElemTy->isPointerTy() != SrcElemTy->isPointerTy()) && - "Only one type should be a pointer type"); - assert((DstElemTy->isFloatingPointTy() != SrcElemTy->isFloatingPointTy()) && - "Only one type should be a floating point type"); - Type *IntTy = - IntegerType::getIntNTy(V->getContext(), DL.getTypeSizeInBits(SrcElemTy)); - auto *VecIntTy = VectorType::get(IntTy, VF); - Value *CastVal = Builder.CreateBitOrPointerCast(V, VecIntTy); - return Builder.CreateBitOrPointerCast(CastVal, DstFVTy); -} - void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) { Value *Count = getTripCount(); // Reuse existing vector loop preheader for TC checks. @@ -2943,16 +2516,10 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) { // Update dominator for Bypass & LoopExit (if needed). DT->changeImmediateDominator(Bypass, TCCheckBlock); - if (!Cost->requiresScalarEpilogue(VF.isVector())) - // If there is an epilogue which must run, there's no edge from the - // middle block to exit blocks and thus no need to update the immediate - // dominator of the exit blocks. - DT->changeImmediateDominator(LoopExitBlock, TCCheckBlock); - BranchInst &BI = *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters); if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) - setBranchWeights(BI, MinItersBypassWeights); + setBranchWeights(BI, MinItersBypassWeights, /*IsExpected=*/false); ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI); LoopBypassBlocks.push_back(TCCheckBlock); } @@ -3034,33 +2601,6 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) { LoopScalarPreHeader = SplitBlock(LoopMiddleBlock, LoopMiddleBlock->getTerminator(), DT, LI, nullptr, Twine(Prefix) + "scalar.ph"); - - auto *ScalarLatchTerm = OrigLoop->getLoopLatch()->getTerminator(); - - // Set up the middle block terminator. Two cases: - // 1) If we know that we must execute the scalar epilogue, emit an - // unconditional branch. - // 2) Otherwise, we must have a single unique exit block (due to how we - // implement the multiple exit case). In this case, set up a conditional - // branch from the middle block to the loop scalar preheader, and the - // exit block. completeLoopSkeleton will update the condition to use an - // iteration check, if required to decide whether to execute the remainder. - BranchInst *BrInst = - Cost->requiresScalarEpilogue(VF.isVector()) - ? BranchInst::Create(LoopScalarPreHeader) - : BranchInst::Create(LoopExitBlock, LoopScalarPreHeader, - Builder.getTrue()); - BrInst->setDebugLoc(ScalarLatchTerm->getDebugLoc()); - ReplaceInstWithInst(LoopMiddleBlock->getTerminator(), BrInst); - - // Update dominator for loop exit. During skeleton creation, only the vector - // pre-header and the middle block are created. The vector loop is entirely - // created during VPlan exection. - if (!Cost->requiresScalarEpilogue(VF.isVector())) - // If there is an epilogue which must run, there's no edge from the - // middle block to exit blocks and thus no need to update the immediate - // dominator of the exit blocks. - DT->changeImmediateDominator(LoopExitBlock, LoopMiddleBlock); } PHINode *InnerLoopVectorizer::createInductionResumeValue( @@ -3100,7 +2640,7 @@ PHINode *InnerLoopVectorizer::createInductionResumeValue( // Create phi nodes to merge from the backedge-taken check block. 
PHINode *BCResumeVal = PHINode::Create(OrigPhi->getType(), 3, "bc.resume.val", - LoopScalarPreHeader->getTerminator()); + LoopScalarPreHeader->getFirstNonPHI()); // Copy original phi DL over to the new one. BCResumeVal->setDebugLoc(OrigPhi->getDebugLoc()); @@ -3157,51 +2697,6 @@ void InnerLoopVectorizer::createInductionResumeValues( } } -BasicBlock *InnerLoopVectorizer::completeLoopSkeleton() { - // The trip counts should be cached by now. - Value *Count = getTripCount(); - Value *VectorTripCount = getOrCreateVectorTripCount(LoopVectorPreHeader); - - auto *ScalarLatchTerm = OrigLoop->getLoopLatch()->getTerminator(); - - // Add a check in the middle block to see if we have completed - // all of the iterations in the first vector loop. Three cases: - // 1) If we require a scalar epilogue, there is no conditional branch as - // we unconditionally branch to the scalar preheader. Do nothing. - // 2) If (N - N%VF) == N, then we *don't* need to run the remainder. - // Thus if tail is to be folded, we know we don't need to run the - // remainder and we can use the previous value for the condition (true). - // 3) Otherwise, construct a runtime check. - if (!Cost->requiresScalarEpilogue(VF.isVector()) && - !Cost->foldTailByMasking()) { - // Here we use the same DebugLoc as the scalar loop latch terminator instead - // of the corresponding compare because they may have ended up with - // different line numbers and we want to avoid awkward line stepping while - // debugging. Eg. if the compare has got a line number inside the loop. - // TODO: At the moment, CreateICmpEQ will simplify conditions with constant - // operands. Perform simplification directly on VPlan once the branch is - // modeled there. - IRBuilder<> B(LoopMiddleBlock->getTerminator()); - B.SetCurrentDebugLocation(ScalarLatchTerm->getDebugLoc()); - Value *CmpN = B.CreateICmpEQ(Count, VectorTripCount, "cmp.n"); - BranchInst &BI = *cast<BranchInst>(LoopMiddleBlock->getTerminator()); - BI.setCondition(CmpN); - if (hasBranchWeightMD(*ScalarLatchTerm)) { - // Assume that `Count % VectorTripCount` is equally distributed. - unsigned TripCount = UF * VF.getKnownMinValue(); - assert(TripCount > 0 && "trip count should not be zero"); - const uint32_t Weights[] = {1, TripCount - 1}; - setBranchWeights(BI, Weights); - } - } - -#ifdef EXPENSIVE_CHECKS - assert(DT->verify(DominatorTree::VerificationLevel::Fast)); -#endif - - return LoopVectorPreHeader; -} - std::pair<BasicBlock *, Value *> InnerLoopVectorizer::createVectorizedLoopSkeleton( const SCEV2ValueTy &ExpandedSCEVs) { @@ -3224,17 +2719,18 @@ InnerLoopVectorizer::createVectorizedLoopSkeleton( | [ ]_| <-- vector loop (created during VPlan execution). | | | v - \ -[ ] <--- middle-block. + \ -[ ] <--- middle-block (wrapped in VPIRBasicBlock with the branch to + | | successors created during VPlan execution) \/ | /\ v - | ->[ ] <--- new preheader. + | ->[ ] <--- new preheader (wrapped in VPIRBasicBlock). | | (opt) v <-- edge from middle to exit iff epilogue is not required. | [ ] \ | [ ]_| <-- old scalar loop to handle remainder (scalar epilogue). \ | \ v - >[ ] <-- exit block(s). + >[ ] <-- exit block(s). (wrapped in VPIRBasicBlock) ... */ @@ -3261,7 +2757,7 @@ InnerLoopVectorizer::createVectorizedLoopSkeleton( // Emit phis for the new starting index of the scalar loop. createInductionResumeValues(ExpandedSCEVs); - return {completeLoopSkeleton(), nullptr}; + return {LoopVectorPreHeader, nullptr}; } // Fix up external users of the induction variable. 
At this point, we are @@ -3447,37 +2943,12 @@ LoopVectorizationCostModel::getVectorIntrinsicCost(CallInst *CI, TargetTransformInfo::TCK_RecipThroughput); } -static Type *smallestIntegerVectorType(Type *T1, Type *T2) { - auto *I1 = cast<IntegerType>(cast<VectorType>(T1)->getElementType()); - auto *I2 = cast<IntegerType>(cast<VectorType>(T2)->getElementType()); - return I1->getBitWidth() < I2->getBitWidth() ? T1 : T2; -} - -static Type *largestIntegerVectorType(Type *T1, Type *T2) { - auto *I1 = cast<IntegerType>(cast<VectorType>(T1)->getElementType()); - auto *I2 = cast<IntegerType>(cast<VectorType>(T2)->getElementType()); - return I1->getBitWidth() > I2->getBitWidth() ? T1 : T2; -} - void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State, VPlan &Plan) { // Fix widened non-induction PHIs by setting up the PHI operands. if (EnableVPlanNativePath) fixNonInductionPHIs(Plan, State); - // At this point every instruction in the original loop is widened to a - // vector form. Now we need to fix the recurrences in the loop. These PHI - // nodes are currently empty because we did not want to introduce cycles. - // This is the second stage of vectorizing recurrences. Note that fixing - // reduction phis are already modeled in VPlan. - // TODO: Also model fixing fixed-order recurrence phis in VPlan. - VPRegionBlock *VectorRegion = State.Plan->getVectorLoopRegion(); - VPBasicBlock *HeaderVPBB = VectorRegion->getEntryBasicBlock(); - for (VPRecipeBase &R : HeaderVPBB->phis()) { - if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R)) - fixFixedOrderRecurrence(FOR, State); - } - // Forget the original basic block. PSE.getSE()->forgetLoop(OrigLoop); PSE.getSE()->forgetBlockAndLoopDispositions(); @@ -3491,6 +2962,7 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State, for (PHINode &PN : Exit->phis()) PSE.getSE()->forgetLcssaPhiWithNewPredecessor(OrigLoop, &PN); + VPRegionBlock *VectorRegion = State.Plan->getVectorLoopRegion(); VPBasicBlock *LatchVPBB = VectorRegion->getExitingBasicBlock(); Loop *VectorLoop = LI->getLoopFor(State.CFG.VPBB2IRBB[LatchVPBB]); if (Cost->requiresScalarEpilogue(VF.isVector())) { @@ -3513,10 +2985,7 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State, VectorLoop->getHeader(), Plan, State); } - // Fix LCSSA phis not already fixed earlier. Extracts may need to be generated - // in the exit block, so update the builder. - State.Builder.SetInsertPoint(State.CFG.ExitBB, - State.CFG.ExitBB->getFirstNonPHIIt()); + // Fix live-out phis not already fixed earlier. for (const auto &KV : Plan.getLiveOuts()) KV.second->fixPhi(Plan, State); @@ -3544,125 +3013,6 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State, VF.getKnownMinValue() * UF); } -void InnerLoopVectorizer::fixFixedOrderRecurrence( - VPFirstOrderRecurrencePHIRecipe *PhiR, VPTransformState &State) { - // This is the second phase of vectorizing first-order recurrences. An - // overview of the transformation is described below. Suppose we have the - // following loop. - // - // for (int i = 0; i < n; ++i) - // b[i] = a[i] - a[i - 1]; - // - // There is a first-order recurrence on "a". For this loop, the shorthand - // scalar IR looks like: - // - // scalar.ph: - // s_init = a[-1] - // br scalar.body - // - // scalar.body: - // i = phi [0, scalar.ph], [i+1, scalar.body] - // s1 = phi [s_init, scalar.ph], [s2, scalar.body] - // s2 = a[i] - // b[i] = s2 - s1 - // br cond, scalar.body, ... 
- // - // In this example, s1 is a recurrence because it's value depends on the - // previous iteration. In the first phase of vectorization, we created a - // vector phi v1 for s1. We now complete the vectorization and produce the - // shorthand vector IR shown below (for VF = 4, UF = 1). - // - // vector.ph: - // v_init = vector(..., ..., ..., a[-1]) - // br vector.body - // - // vector.body - // i = phi [0, vector.ph], [i+4, vector.body] - // v1 = phi [v_init, vector.ph], [v2, vector.body] - // v2 = a[i, i+1, i+2, i+3]; - // v3 = vector(v1(3), v2(0, 1, 2)) - // b[i, i+1, i+2, i+3] = v2 - v3 - // br cond, vector.body, middle.block - // - // middle.block: - // x = v2(3) - // br scalar.ph - // - // scalar.ph: - // s_init = phi [x, middle.block], [a[-1], otherwise] - // br scalar.body - // - // After execution completes the vector loop, we extract the next value of - // the recurrence (x) to use as the initial value in the scalar loop. - - // Extract the last vector element in the middle block. This will be the - // initial value for the recurrence when jumping to the scalar loop. - VPValue *PreviousDef = PhiR->getBackedgeValue(); - Value *Incoming = State.get(PreviousDef, UF - 1); - auto *ExtractForScalar = Incoming; - auto *IdxTy = Builder.getInt32Ty(); - Value *RuntimeVF = nullptr; - if (VF.isVector()) { - auto *One = ConstantInt::get(IdxTy, 1); - Builder.SetInsertPoint(LoopMiddleBlock->getTerminator()); - RuntimeVF = getRuntimeVF(Builder, IdxTy, VF); - auto *LastIdx = Builder.CreateSub(RuntimeVF, One); - ExtractForScalar = - Builder.CreateExtractElement(Incoming, LastIdx, "vector.recur.extract"); - } - - auto RecurSplice = cast<VPInstruction>(*PhiR->user_begin()); - assert(PhiR->getNumUsers() == 1 && - RecurSplice->getOpcode() == - VPInstruction::FirstOrderRecurrenceSplice && - "recurrence phi must have a single user: FirstOrderRecurrenceSplice"); - SmallVector<VPLiveOut *> LiveOuts; - for (VPUser *U : RecurSplice->users()) - if (auto *LiveOut = dyn_cast<VPLiveOut>(U)) - LiveOuts.push_back(LiveOut); - - if (!LiveOuts.empty()) { - // Extract the second last element in the middle block if the - // Phi is used outside the loop. We need to extract the phi itself - // and not the last element (the phi update in the current iteration). This - // will be the value when jumping to the exit block from the - // LoopMiddleBlock, when the scalar loop is not run at all. - Value *ExtractForPhiUsedOutsideLoop = nullptr; - if (VF.isVector()) { - auto *Idx = Builder.CreateSub(RuntimeVF, ConstantInt::get(IdxTy, 2)); - ExtractForPhiUsedOutsideLoop = Builder.CreateExtractElement( - Incoming, Idx, "vector.recur.extract.for.phi"); - } else { - assert(UF > 1 && "VF and UF cannot both be 1"); - // When loop is unrolled without vectorizing, initialize - // ExtractForPhiUsedOutsideLoop with the value just prior to unrolled - // value of `Incoming`. This is analogous to the vectorized case above: - // extracting the second last element when VF > 1. - ExtractForPhiUsedOutsideLoop = State.get(PreviousDef, UF - 2); - } - - for (VPLiveOut *LiveOut : LiveOuts) { - assert(!Cost->requiresScalarEpilogue(VF.isVector())); - PHINode *LCSSAPhi = LiveOut->getPhi(); - LCSSAPhi->addIncoming(ExtractForPhiUsedOutsideLoop, LoopMiddleBlock); - State.Plan->removeLiveOut(LCSSAPhi); - } - } - - // Fix the initial value of the original recurrence in the scalar loop. 
- Builder.SetInsertPoint(LoopScalarPreHeader, LoopScalarPreHeader->begin()); - PHINode *Phi = cast<PHINode>(PhiR->getUnderlyingValue()); - auto *Start = Builder.CreatePHI(Phi->getType(), 2, "scalar.recur.init"); - auto *ScalarInit = PhiR->getStartValue()->getLiveInIRValue(); - for (auto *BB : predecessors(LoopScalarPreHeader)) { - auto *Incoming = BB == LoopMiddleBlock ? ExtractForScalar : ScalarInit; - Start->addIncoming(Incoming, BB); - } - - Phi->setIncomingValueForBlock(LoopScalarPreHeader, Start); - Phi->setName("scalar.recur"); -} - void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) { // The basic block and loop containing the predicated instruction. auto *PredBB = PredInst->getParent(); @@ -3758,11 +3108,6 @@ void InnerLoopVectorizer::fixNonInductionPHIs(VPlan &Plan, } } -bool InnerLoopVectorizer::useOrderedReductions( - const RecurrenceDescriptor &RdxDesc) { - return Cost->useOrderedReductions(RdxDesc); -} - void LoopVectorizationCostModel::collectLoopScalars(ElementCount VF) { // We should not collect Scalars more than once per VF. Right now, this // function is called from collectUniformsAndScalars(), which already does @@ -3928,6 +3273,13 @@ void LoopVectorizationCostModel::collectLoopScalars(ElementCount VF) { if (!ScalarInd) continue; + // If the induction variable update is a fixed-order recurrence, neither the + // induction variable or its update should be marked scalar after + // vectorization. + auto *IndUpdatePhi = dyn_cast<PHINode>(IndUpdate); + if (IndUpdatePhi && Legal->isFixedOrderRecurrence(IndUpdatePhi)) + continue; + // Determine if all users of the induction variable update instruction are // scalar after vectorization. auto ScalarIndUpdate = @@ -4100,7 +3452,7 @@ LoopVectorizationCostModel::getDivRemSpeculationCost(Instruction *I, } bool LoopVectorizationCostModel::interleavedAccessCanBeWidened( - Instruction *I, ElementCount VF) { + Instruction *I, ElementCount VF) const { assert(isAccessInterleaved(I) && "Expecting interleaved access."); assert(getWideningDecision(I, VF) == CM_Unknown && "Decision should not be set yet."); @@ -4109,7 +3461,7 @@ bool LoopVectorizationCostModel::interleavedAccessCanBeWidened( // If the instruction's allocated size doesn't equal it's type size, it // requires padding and will be scalarized. - auto &DL = I->getModule()->getDataLayout(); + auto &DL = I->getDataLayout(); auto *ScalarTy = getLoadStoreType(I); if (hasIrregularType(ScalarTy, DL)) return false; @@ -4187,7 +3539,7 @@ bool LoopVectorizationCostModel::memoryInstructionCanBeWidened( // If the instruction's allocated size doesn't equal it's type size, it // requires padding and will be scalarized. - auto &DL = I->getModule()->getDataLayout(); + auto &DL = I->getDataLayout(); if (hasIrregularType(ScalarTy, DL)) return false; @@ -4220,34 +3572,37 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) { // Worklist containing uniform instructions demanding lane 0. SetVector<Instruction *> Worklist; - BasicBlock *Latch = TheLoop->getLoopLatch(); // Add uniform instructions demanding lane 0 to the worklist. Instructions - // that are scalar with predication must not be considered uniform after + // that require predication must not be considered uniform after // vectorization, because that would create an erroneous replicating region // where only a single instance out of VF should be formed. - // TODO: optimize such seldom cases if found important, see PR40816. 
auto addToWorklistIfAllowed = [&](Instruction *I) -> void { if (isOutOfScope(I)) { LLVM_DEBUG(dbgs() << "LV: Found not uniform due to scope: " << *I << "\n"); return; } - if (isScalarWithPredication(I, VF)) { - LLVM_DEBUG(dbgs() << "LV: Found not uniform being ScalarWithPredication: " - << *I << "\n"); + if (isPredicatedInst(I)) { + LLVM_DEBUG( + dbgs() << "LV: Found not uniform due to requiring predication: " << *I + << "\n"); return; } LLVM_DEBUG(dbgs() << "LV: Found uniform instruction: " << *I << "\n"); Worklist.insert(I); }; - // Start with the conditional branch. If the branch condition is an - // instruction contained in the loop that is only used by the branch, it is - // uniform. - auto *Cmp = dyn_cast<Instruction>(Latch->getTerminator()->getOperand(0)); - if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse()) - addToWorklistIfAllowed(Cmp); + // Start with the conditional branches exiting the loop. If the branch + // condition is an instruction contained in the loop that is only used by the + // branch, it is uniform. + SmallVector<BasicBlock *> Exiting; + TheLoop->getExitingBlocks(Exiting); + for (BasicBlock *E : Exiting) { + auto *Cmp = dyn_cast<Instruction>(E->getTerminator()->getOperand(0)); + if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse()) + addToWorklistIfAllowed(Cmp); + } auto PrevVF = VF.divideCoefficientBy(2); // Return true if all lanes perform the same memory operation, and we can @@ -4388,6 +3743,7 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) { // nodes separately. An induction variable will remain uniform if all users // of the induction variable and induction variable update remain uniform. // The code below handles both pointer and non-pointer induction variables. + BasicBlock *Latch = TheLoop->getLoopLatch(); for (const auto &Induction : Legal->getInductionVars()) { auto *Ind = Induction.first; auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); @@ -4454,15 +3810,18 @@ bool LoopVectorizationCostModel::runtimeChecksRequired() { return false; } -ElementCount -LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) { +bool LoopVectorizationCostModel::isScalableVectorizationAllowed() { + if (IsScalableVectorizationAllowed) + return *IsScalableVectorizationAllowed; + + IsScalableVectorizationAllowed = false; if (!TTI.supportsScalableVectors() && !ForceTargetSupportsScalableVectors) - return ElementCount::getScalable(0); + return false; if (Hints->isScalableVectorizationDisabled()) { reportVectorizationInfo("Scalable vectorization is explicitly disabled", "ScalableVectorizationDisabled", ORE, TheLoop); - return ElementCount::getScalable(0); + return false; } LLVM_DEBUG(dbgs() << "LV: Scalable vectorization is available\n"); @@ -4482,7 +3841,7 @@ LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) { "Scalable vectorization not supported for the reduction " "operations found in this loop.", "ScalableVFUnfeasible", ORE, TheLoop); - return ElementCount::getScalable(0); + return false; } // Disable scalable vectorization if the loop contains any instructions @@ -4494,17 +3853,33 @@ LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) { reportVectorizationInfo("Scalable vectorization is not supported " "for all element types found in this loop.", "ScalableVFUnfeasible", ORE, TheLoop); - return ElementCount::getScalable(0); + return false; + } + + if (!Legal->isSafeForAnyVectorWidth() && !getMaxVScale(*TheFunction, TTI)) { + reportVectorizationInfo("The 
target does not provide maximum vscale value " + "for safe distance analysis.", + "ScalableVFUnfeasible", ORE, TheLoop); + return false; } + IsScalableVectorizationAllowed = true; + return true; +} + +ElementCount +LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) { + if (!isScalableVectorizationAllowed()) + return ElementCount::getScalable(0); + + auto MaxScalableVF = ElementCount::getScalable( + std::numeric_limits<ElementCount::ScalarTy>::max()); if (Legal->isSafeForAnyVectorWidth()) return MaxScalableVF; + std::optional<unsigned> MaxVScale = getMaxVScale(*TheFunction, TTI); // Limit MaxScalableVF by the maximum safe dependence distance. - if (std::optional<unsigned> MaxVScale = getMaxVScale(*TheFunction, TTI)) - MaxScalableVF = ElementCount::getScalable(MaxSafeElements / *MaxVScale); - else - MaxScalableVF = ElementCount::getScalable(0); + MaxScalableVF = ElementCount::getScalable(MaxSafeElements / *MaxVScale); if (!MaxScalableVF) reportVectorizationInfo( @@ -4738,8 +4113,22 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) { // found modulo the vectorization factor is not zero, try to fold the tail // by masking. // FIXME: look for a smaller MaxVF that does divide TC rather than masking. - if (Legal->prepareToFoldTailByMasking()) { - CanFoldTailByMasking = true; + setTailFoldingStyles(MaxFactors.ScalableVF.isScalable(), UserIC); + if (foldTailByMasking()) { + if (getTailFoldingStyle() == TailFoldingStyle::DataWithEVL) { + LLVM_DEBUG( + dbgs() + << "LV: tail is folded with EVL, forcing unroll factor to be 1. Will " + "try to generate VP Intrinsics with scalable vector " + "factors only.\n"); + // Tail folded loop using VP intrinsics restricts the VF to be scalable + // for now. + // TODO: extend it for fixed vectors, if required. + assert(MaxFactors.ScalableVF.isScalable() && + "Expected scalable vector factor."); + + MaxFactors.FixedVF = ElementCount::getFixed(1); + } return MaxFactors; } @@ -4860,15 +4249,12 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget( // Select the largest VF which doesn't require more registers than existing // ones. - for (int i = RUs.size() - 1; i >= 0; --i) { - bool Selected = true; - for (auto &pair : RUs[i].MaxLocalUsers) { - unsigned TargetNumRegisters = TTI.getNumberOfRegisters(pair.first); - if (pair.second > TargetNumRegisters) - Selected = false; - } - if (Selected) { - MaxVF = VFs[i]; + for (int I = RUs.size() - 1; I >= 0; --I) { + const auto &MLU = RUs[I].MaxLocalUsers; + if (all_of(MLU, [&](decltype(MLU.front()) &LU) { + return LU.second <= TTI.getNumberOfRegisters(LU.first); + })) { + MaxVF = VFs[I]; break; } } @@ -4913,28 +4299,6 @@ bool LoopVectorizationPlanner::isMoreProfitable( unsigned MaxTripCount = PSE.getSE()->getSmallConstantMaxTripCount(OrigLoop); - if (!A.Width.isScalable() && !B.Width.isScalable() && MaxTripCount) { - // If the trip count is a known (possibly small) constant, the trip count - // will be rounded up to an integer number of iterations under - // FoldTailByMasking. The total cost in that case will be - // VecCost*ceil(TripCount/VF). When not folding the tail, the total - // cost will be VecCost*floor(TC/VF) + ScalarCost*(TC%VF). There will be - // some extra overheads, but for the purpose of comparing the costs of - // different VFs we can use this to compare the total loop-body cost - // expected after vectorization. 
- auto GetCostForTC = [MaxTripCount, this](unsigned VF, - InstructionCost VectorCost, - InstructionCost ScalarCost) { - return CM.foldTailByMasking() ? VectorCost * divideCeil(MaxTripCount, VF) - : VectorCost * (MaxTripCount / VF) + - ScalarCost * (MaxTripCount % VF); - }; - auto RTCostA = GetCostForTC(A.Width.getFixedValue(), CostA, A.ScalarCost); - auto RTCostB = GetCostForTC(B.Width.getFixedValue(), CostB, B.ScalarCost); - - return RTCostA < RTCostB; - } - // Improve estimate for the vector width if it is scalable. unsigned EstimatedWidthA = A.Width.getKnownMinValue(); unsigned EstimatedWidthB = B.Width.getKnownMinValue(); @@ -4948,13 +4312,39 @@ bool LoopVectorizationPlanner::isMoreProfitable( // Assume vscale may be larger than 1 (or the value being tuned for), // so that scalable vectorization is slightly favorable over fixed-width // vectorization. - if (A.Width.isScalable() && !B.Width.isScalable()) - return (CostA * B.Width.getFixedValue()) <= (CostB * EstimatedWidthA); + bool PreferScalable = !TTI.preferFixedOverScalableIfEqualCost() && + A.Width.isScalable() && !B.Width.isScalable(); + + auto CmpFn = [PreferScalable](const InstructionCost &LHS, + const InstructionCost &RHS) { + return PreferScalable ? LHS <= RHS : LHS < RHS; + }; // To avoid the need for FP division: - // (CostA / A.Width) < (CostB / B.Width) - // <=> (CostA * B.Width) < (CostB * A.Width) - return (CostA * EstimatedWidthB) < (CostB * EstimatedWidthA); + // (CostA / EstimatedWidthA) < (CostB / EstimatedWidthB) + // <=> (CostA * EstimatedWidthB) < (CostB * EstimatedWidthA) + if (!MaxTripCount) + return CmpFn(CostA * EstimatedWidthB, CostB * EstimatedWidthA); + + auto GetCostForTC = [MaxTripCount, this](unsigned VF, + InstructionCost VectorCost, + InstructionCost ScalarCost) { + // If the trip count is a known (possibly small) constant, the trip count + // will be rounded up to an integer number of iterations under + // FoldTailByMasking. The total cost in that case will be + // VecCost*ceil(TripCount/VF). When not folding the tail, the total + // cost will be VecCost*floor(TC/VF) + ScalarCost*(TC%VF). There will be + // some extra overheads, but for the purpose of comparing the costs of + // different VFs we can use this to compare the total loop-body cost + // expected after vectorization. 
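    // Worked example with made-up costs, purely for illustration: with
    // MaxTripCount = 10, VF = 4, VectorCost = 8 and ScalarCost = 3, folding
    // the tail costs 8 * ceil(10/4) = 24, while not folding costs
    // 8 * (10/4) + 3 * (10 % 4) = 16 + 6 = 22.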
+ if (CM.foldTailByMasking()) + return VectorCost * divideCeil(MaxTripCount, VF); + return VectorCost * (MaxTripCount / VF) + ScalarCost * (MaxTripCount % VF); + }; + + auto RTCostA = GetCostForTC(EstimatedWidthA, CostA, A.ScalarCost); + auto RTCostB = GetCostForTC(EstimatedWidthB, CostB, B.ScalarCost); + return CmpFn(RTCostA, RTCostB); } static void emitInvalidCostRemarks(SmallVector<InstructionVFPair> InvalidCosts, @@ -4977,8 +4367,10 @@ static void emitInvalidCostRemarks(SmallVector<InstructionVFPair> InvalidCosts, sort(InvalidCosts, [&Numbering](InstructionVFPair &A, InstructionVFPair &B) { if (Numbering[A.first] != Numbering[B.first]) return Numbering[A.first] < Numbering[B.first]; - ElementCountComparator ECC; - return ECC(A.second, B.second); + const auto &LHS = A.second; + const auto &RHS = B.second; + return std::make_tuple(LHS.isScalable(), LHS.getKnownMinValue()) < + std::make_tuple(RHS.isScalable(), RHS.getKnownMinValue()); }); // For a list of ordered instruction-vf pairs: @@ -5021,13 +4413,111 @@ static void emitInvalidCostRemarks(SmallVector<InstructionVFPair> InvalidCosts, } while (!Tail.empty()); } -VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor( - const ElementCountSet &VFCandidates) { - InstructionCost ExpectedCost = - CM.expectedCost(ElementCount::getFixed(1)).first; +/// Check if any recipe of \p Plan will generate a vector value, which will be +/// assigned a vector register. +static bool willGenerateVectors(VPlan &Plan, ElementCount VF, + const TargetTransformInfo &TTI) { + assert(VF.isVector() && "Checking a scalar VF?"); + VPTypeAnalysis TypeInfo(Plan.getCanonicalIV()->getScalarType(), + Plan.getCanonicalIV()->getScalarType()->getContext()); + DenseSet<VPRecipeBase *> EphemeralRecipes; + collectEphemeralRecipesForVPlan(Plan, EphemeralRecipes); + // Set of already visited types. + DenseSet<Type *> Visited; + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( + vp_depth_first_shallow(Plan.getVectorLoopRegion()->getEntry()))) { + for (VPRecipeBase &R : *VPBB) { + if (EphemeralRecipes.contains(&R)) + continue; + // Continue early if the recipe is considered to not produce a vector + // result. Note that this includes VPInstruction where some opcodes may + // produce a vector, to preserve existing behavior as VPInstructions model + // aspects not directly mapped to existing IR instructions. 
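      // Illustration of the legalization check further below (hypothetical
      // 128-bit-vector target, not a claim about any specific backend): for
      // VF = 4 and an i64 element, <4 x i64> legalizes into 2 parts, and
      // 2 < 4 means the type is treated as genuinely vectorized; if the same
      // type instead legalized into 4 scalar parts, 4 < 4 fails and this VF
      // is considered to generate no vector instructions for it.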
+ switch (R.getVPDefID()) { + case VPDef::VPDerivedIVSC: + case VPDef::VPScalarIVStepsSC: + case VPDef::VPScalarCastSC: + case VPDef::VPReplicateSC: + case VPDef::VPInstructionSC: + case VPDef::VPCanonicalIVPHISC: + case VPDef::VPVectorPointerSC: + case VPDef::VPExpandSCEVSC: + case VPDef::VPEVLBasedIVPHISC: + case VPDef::VPPredInstPHISC: + case VPDef::VPBranchOnMaskSC: + continue; + case VPDef::VPReductionSC: + case VPDef::VPActiveLaneMaskPHISC: + case VPDef::VPWidenCallSC: + case VPDef::VPWidenCanonicalIVSC: + case VPDef::VPWidenCastSC: + case VPDef::VPWidenGEPSC: + case VPDef::VPWidenSC: + case VPDef::VPWidenSelectSC: + case VPDef::VPBlendSC: + case VPDef::VPFirstOrderRecurrencePHISC: + case VPDef::VPWidenPHISC: + case VPDef::VPWidenIntOrFpInductionSC: + case VPDef::VPWidenPointerInductionSC: + case VPDef::VPReductionPHISC: + case VPDef::VPInterleaveSC: + case VPDef::VPWidenLoadEVLSC: + case VPDef::VPWidenLoadSC: + case VPDef::VPWidenStoreEVLSC: + case VPDef::VPWidenStoreSC: + break; + default: + llvm_unreachable("unhandled recipe"); + } + + auto WillWiden = [&TTI, VF](Type *ScalarTy) { + Type *VectorTy = ToVectorTy(ScalarTy, VF); + unsigned NumLegalParts = TTI.getNumberOfParts(VectorTy); + if (!NumLegalParts) + return false; + if (VF.isScalable()) { + // <vscale x 1 x iN> is assumed to be profitable over iN because + // scalable registers are a distinct register class from scalar + // ones. If we ever find a target which wants to lower scalable + // vectors back to scalars, we'll need to update this code to + // explicitly ask TTI about the register class uses for each part. + return NumLegalParts <= VF.getKnownMinValue(); + } + // Two or more parts that share a register - are vectorized. + return NumLegalParts < VF.getKnownMinValue(); + }; + + // If no def nor is a store, e.g., branches, continue - no value to check. + if (R.getNumDefinedValues() == 0 && + !isa<VPWidenStoreRecipe, VPWidenStoreEVLRecipe, VPInterleaveRecipe>( + &R)) + continue; + // For multi-def recipes, currently only interleaved loads, suffice to + // check first def only. + // For stores check their stored value; for interleaved stores suffice + // the check first stored value only. In all cases this is the second + // operand. + VPValue *ToCheck = + R.getNumDefinedValues() >= 1 ? R.getVPValue(0) : R.getOperand(1); + Type *ScalarTy = TypeInfo.inferScalarType(ToCheck); + if (!Visited.insert({ScalarTy}).second) + continue; + if (WillWiden(ScalarTy)) + return true; + } + } + + return false; +} + +VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor() { + InstructionCost ExpectedCost = CM.expectedCost(ElementCount::getFixed(1)); LLVM_DEBUG(dbgs() << "LV: Scalar loop costs: " << ExpectedCost << ".\n"); assert(ExpectedCost.isValid() && "Unexpected invalid cost for scalar loop"); - assert(VFCandidates.count(ElementCount::getFixed(1)) && + assert(any_of(VPlans, + [](std::unique_ptr<VPlan> &P) { + return P->hasVF(ElementCount::getFixed(1)); + }) && "Expected Scalar VF to be a candidate"); const VectorizationFactor ScalarCost(ElementCount::getFixed(1), ExpectedCost, @@ -5035,7 +4525,8 @@ VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor( VectorizationFactor ChosenFactor = ScalarCost; bool ForceVectorization = Hints.getForce() == LoopVectorizeHints::FK_Enabled; - if (ForceVectorization && VFCandidates.size() > 1) { + if (ForceVectorization && + (VPlans.size() > 1 || !VPlans[0]->hasScalarVFOnly())) { // Ignore scalar width, because the user explicitly wants vectorization. 
// Initialize cost to max so that VF = 2 is, at least, chosen during cost // evaluation. @@ -5043,43 +4534,45 @@ VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor( } SmallVector<InstructionVFPair> InvalidCosts; - for (const auto &i : VFCandidates) { - // The cost for scalar VF=1 is already calculated, so ignore it. - if (i.isScalar()) - continue; + for (auto &P : VPlans) { + for (ElementCount VF : P->vectorFactors()) { + // The cost for scalar VF=1 is already calculated, so ignore it. + if (VF.isScalar()) + continue; - LoopVectorizationCostModel::VectorizationCostTy C = - CM.expectedCost(i, &InvalidCosts); - VectorizationFactor Candidate(i, C.first, ScalarCost.ScalarCost); + InstructionCost C = CM.expectedCost(VF, &InvalidCosts); + VectorizationFactor Candidate(VF, C, ScalarCost.ScalarCost); #ifndef NDEBUG - unsigned AssumedMinimumVscale = - getVScaleForTuning(OrigLoop, TTI).value_or(1); - unsigned Width = - Candidate.Width.isScalable() - ? Candidate.Width.getKnownMinValue() * AssumedMinimumVscale - : Candidate.Width.getFixedValue(); - LLVM_DEBUG(dbgs() << "LV: Vector loop of width " << i - << " costs: " << (Candidate.Cost / Width)); - if (i.isScalable()) - LLVM_DEBUG(dbgs() << " (assuming a minimum vscale of " - << AssumedMinimumVscale << ")"); - LLVM_DEBUG(dbgs() << ".\n"); + unsigned AssumedMinimumVscale = + getVScaleForTuning(OrigLoop, TTI).value_or(1); + unsigned Width = + Candidate.Width.isScalable() + ? Candidate.Width.getKnownMinValue() * AssumedMinimumVscale + : Candidate.Width.getFixedValue(); + LLVM_DEBUG(dbgs() << "LV: Vector loop of width " << VF + << " costs: " << (Candidate.Cost / Width)); + if (VF.isScalable()) + LLVM_DEBUG(dbgs() << " (assuming a minimum vscale of " + << AssumedMinimumVscale << ")"); + LLVM_DEBUG(dbgs() << ".\n"); #endif - if (!C.second && !ForceVectorization) { - LLVM_DEBUG( - dbgs() << "LV: Not considering vector loop of width " << i - << " because it will not generate any vector instructions.\n"); - continue; - } + if (!ForceVectorization && !willGenerateVectors(*P, VF, TTI)) { + LLVM_DEBUG( + dbgs() + << "LV: Not considering vector loop of width " << VF + << " because it will not generate any vector instructions.\n"); + continue; + } - // If profitable add it to ProfitableVF list. - if (isMoreProfitable(Candidate, ScalarCost)) - ProfitableVFs.push_back(Candidate); + // If profitable add it to ProfitableVF list. + if (isMoreProfitable(Candidate, ScalarCost)) + ProfitableVFs.push_back(Candidate); - if (isMoreProfitable(Candidate, ChosenFactor)) - ChosenFactor = Candidate; + if (isMoreProfitable(Candidate, ChosenFactor)) + ChosenFactor = Candidate; + } } emitInvalidCostRemarks(InvalidCosts, ORE, OrigLoop); @@ -5258,7 +4751,7 @@ std::pair<unsigned, unsigned> LoopVectorizationCostModel::getSmallestAndWidestTypes() { unsigned MinWidth = -1U; unsigned MaxWidth = 8; - const DataLayout &DL = TheFunction->getParent()->getDataLayout(); + const DataLayout &DL = TheFunction->getDataLayout(); // For in-loop reductions, no element types are added to ElementTypesInLoop // if there are no loads/stores in the loop. In this case, check through the // reduction variables to determine the maximum width. @@ -5349,25 +4842,24 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, if (!isScalarEpilogueAllowed()) return 1; + // Do not interleave if EVL is preferred and no User IC is specified. + if (foldTailWithEVL()) { + LLVM_DEBUG(dbgs() << "LV: Preference for VP intrinsics indicated. 
" + "Unroll factor forced to be 1.\n"); + return 1; + } + // We used the distance for the interleave count. if (!Legal->isSafeForAnyVectorWidth()) return 1; auto BestKnownTC = getSmallBestKnownTC(*PSE.getSE(), TheLoop); const bool HasReductions = !Legal->getReductionVars().empty(); - // Do not interleave loops with a relatively small known or estimated trip - // count. But we will interleave when InterleaveSmallLoopScalarReduction is - // enabled, and the code has scalar reductions(HasReductions && VF = 1), - // because with the above conditions interleaving can expose ILP and break - // cross iteration dependences for reductions. - if (BestKnownTC && (*BestKnownTC < TinyTripCountInterleaveThreshold) && - !(InterleaveSmallLoopScalarReduction && HasReductions && VF.isScalar())) - return 1; // If we did not calculate the cost for VF (because the user selected the VF) // then we calculate the cost of VF here. if (LoopCost == 0) { - LoopCost = expectedCost(VF).first; + LoopCost = expectedCost(VF); assert(LoopCost.isValid() && "Expected to have chosen a VF with valid cost"); // Loop body is free and there is no need for interleaving. @@ -5443,7 +4935,12 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, assert(EstimatedVF >= 1 && "Estimated VF shouldn't be less than 1"); unsigned KnownTC = PSE.getSE()->getSmallConstantTripCount(TheLoop); - if (KnownTC) { + if (KnownTC > 0) { + // At least one iteration must be scalar when this constraint holds. So the + // maximum available iterations for interleaving is one less. + unsigned AvailableTC = + requiresScalarEpilogue(VF.isVector()) ? KnownTC - 1 : KnownTC; + // If trip count is known we select between two prospective ICs, where // 1) the aggressive IC is capped by the trip count divided by VF // 2) the conservative IC is capped by the trip count divided by (VF * 2) @@ -5453,27 +4950,35 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, // we run the vector loop at least twice. unsigned InterleaveCountUB = bit_floor( - std::max(1u, std::min(KnownTC / EstimatedVF, MaxInterleaveCount))); + std::max(1u, std::min(AvailableTC / EstimatedVF, MaxInterleaveCount))); unsigned InterleaveCountLB = bit_floor(std::max( - 1u, std::min(KnownTC / (EstimatedVF * 2), MaxInterleaveCount))); + 1u, std::min(AvailableTC / (EstimatedVF * 2), MaxInterleaveCount))); MaxInterleaveCount = InterleaveCountLB; if (InterleaveCountUB != InterleaveCountLB) { - unsigned TailTripCountUB = (KnownTC % (EstimatedVF * InterleaveCountUB)); - unsigned TailTripCountLB = (KnownTC % (EstimatedVF * InterleaveCountLB)); + unsigned TailTripCountUB = + (AvailableTC % (EstimatedVF * InterleaveCountUB)); + unsigned TailTripCountLB = + (AvailableTC % (EstimatedVF * InterleaveCountLB)); // If both produce same scalar tail, maximize the IC to do the same work // in fewer vector loop iterations if (TailTripCountUB == TailTripCountLB) MaxInterleaveCount = InterleaveCountUB; } - } else if (BestKnownTC) { + } else if (BestKnownTC && *BestKnownTC > 0) { + // At least one iteration must be scalar when this constraint holds. So the + // maximum available iterations for interleaving is one less. + unsigned AvailableTC = requiresScalarEpilogue(VF.isVector()) + ? (*BestKnownTC) - 1 + : *BestKnownTC; + // If trip count is an estimated compile time constant, limit the // IC to be capped by the trip count divided by VF * 2, such that the vector // loop runs at least twice to make interleaving seem profitable when there // is an epilogue loop present. 
Since exact Trip count is not known we // choose to be conservative in our IC estimate. MaxInterleaveCount = bit_floor(std::max( - 1u, std::min(*BestKnownTC / (EstimatedVF * 2), MaxInterleaveCount))); + 1u, std::min(AvailableTC / (EstimatedVF * 2), MaxInterleaveCount))); } assert(MaxInterleaveCount > 0 && @@ -5577,8 +5082,7 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF, // If there are scalar reductions and TTI has enabled aggressive // interleaving for reductions, we will interleave to expose ILP. - if (InterleaveSmallLoopScalarReduction && VF.isScalar() && - AggressivelyInterleaveReductions) { + if (VF.isScalar() && AggressivelyInterleaveReductions) { LLVM_DEBUG(dbgs() << "LV: Interleaving to expose ILP.\n"); // Interleave no less than SmallIC but not as aggressive as the normal IC // to satisfy the rare situation when resources are too limited. @@ -5845,15 +5349,21 @@ void LoopVectorizationCostModel::collectInstsToScalarize(ElementCount VF) { for (Instruction &I : *BB) if (isScalarWithPredication(&I, VF)) { ScalarCostsTy ScalarCosts; - // Do not apply discount if scalable, because that would lead to - // invalid scalarization costs. - // Do not apply discount logic if hacked cost is needed - // for emulated masked memrefs. - if (!VF.isScalable() && !useEmulatedMaskMemRefHack(&I, VF) && + // Do not apply discount logic for: + // 1. Scalars after vectorization, as there will only be a single copy + // of the instruction. + // 2. Scalable VF, as that would lead to invalid scalarization costs. + // 3. Emulated masked memrefs, if a hacked cost is needed. + if (!isScalarAfterVectorization(&I, VF) && !VF.isScalable() && + !useEmulatedMaskMemRefHack(&I, VF) && computePredInstDiscount(&I, ScalarCosts, VF) >= 0) ScalarCostsVF.insert(ScalarCosts.begin(), ScalarCosts.end()); // Remember that BB will remain after vectorization. PredicatedBBsAfterVectorization[VF].insert(BB); + for (auto *Pred : predecessors(BB)) { + if (Pred->getSingleSuccessor() == BB) + PredicatedBBsAfterVectorization[VF].insert(Pred); + } } } } @@ -5920,15 +5430,14 @@ InstructionCost LoopVectorizationCostModel::computePredInstDiscount( // Compute the cost of the vector instruction. Note that this cost already // includes the scalarization overhead of the predicated instruction. - InstructionCost VectorCost = getInstructionCost(I, VF).first; + InstructionCost VectorCost = getInstructionCost(I, VF); // Compute the cost of the scalarized instruction. This cost is the cost of // the instruction as if it wasn't if-converted and instead remained in the // predicated block. We will scale this cost by block probability after // computing the scalarization overhead. InstructionCost ScalarCost = - VF.getFixedValue() * - getInstructionCost(I, ElementCount::getFixed(1)).first; + VF.getFixedValue() * getInstructionCost(I, ElementCount::getFixed(1)); // Compute the scalarization overhead of needed insertelement instructions // and phi nodes. @@ -5972,14 +5481,13 @@ InstructionCost LoopVectorizationCostModel::computePredInstDiscount( return Discount; } -LoopVectorizationCostModel::VectorizationCostTy -LoopVectorizationCostModel::expectedCost( +InstructionCost LoopVectorizationCostModel::expectedCost( ElementCount VF, SmallVectorImpl<InstructionVFPair> *Invalid) { - VectorizationCostTy Cost; + InstructionCost Cost; // For each block. for (BasicBlock *BB : TheLoop->blocks()) { - VectorizationCostTy BlockCost; + InstructionCost BlockCost; // For each instruction in the old loop. 
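
Editor's sketch of the cost accumulation in LoopVectorizationCostModel::expectedCost (the hunk continuing below): per-instruction costs are summed per block, and for scalar VF a predicated block's cost is scaled down by its execution probability. The instruction costs and the reciprocal probability of 2 are assumptions for illustration.

#include <cstdio>
#include <vector>

int main() {
  std::vector<unsigned> HeaderCosts = {4, 3};    // costs of header instructions
  std::vector<unsigned> PredBlockCosts = {6, 2}; // costs in a predicated block
  unsigned ReciprocalPredBlockProb = 2;          // assume the block runs half the time
  unsigned Cost = 0, BlockCost = 0;
  for (unsigned C : HeaderCosts)
    BlockCost += C;
  Cost += BlockCost;
  BlockCost = 0;
  for (unsigned C : PredBlockCosts)
    BlockCost += C;
  BlockCost /= ReciprocalPredBlockProb;          // scale the predicated block's cost
  Cost += BlockCost;
  std::printf("expected loop cost = %u\n", Cost); // 7 + 8/2 = 11
}
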
for (Instruction &I : BB->instructionsWithoutDebug()) { @@ -5988,22 +5496,19 @@ LoopVectorizationCostModel::expectedCost( (VF.isVector() && VecValuesToIgnore.count(&I))) continue; - VectorizationCostTy C = getInstructionCost(&I, VF); + InstructionCost C = getInstructionCost(&I, VF); // Check if we should override the cost. - if (C.first.isValid() && - ForceTargetInstructionCost.getNumOccurrences() > 0) - C.first = InstructionCost(ForceTargetInstructionCost); + if (C.isValid() && ForceTargetInstructionCost.getNumOccurrences() > 0) + C = InstructionCost(ForceTargetInstructionCost); // Keep a list of instructions with invalid costs. - if (Invalid && !C.first.isValid()) + if (Invalid && !C.isValid()) Invalid->emplace_back(&I, VF); - BlockCost.first += C.first; - BlockCost.second |= C.second; - LLVM_DEBUG(dbgs() << "LV: Found an estimated cost of " << C.first - << " for VF " << VF << " For instruction: " << I - << '\n'); + BlockCost += C; + LLVM_DEBUG(dbgs() << "LV: Found an estimated cost of " << C << " for VF " + << VF << " For instruction: " << I << '\n'); } // If we are vectorizing a predicated block, it will have been @@ -6014,10 +5519,9 @@ LoopVectorizationCostModel::expectedCost( // cost by the probability of executing it. blockNeedsPredication from // Legal is used so as to not include all blocks in tail folded loops. if (VF.isScalar() && Legal->blockNeedsPredication(BB)) - BlockCost.first /= getReciprocalPredBlockProb(); + BlockCost /= getReciprocalPredBlockProb(); - Cost.first += BlockCost.first; - Cost.second |= BlockCost.second; + Cost += BlockCost; } return Cost; @@ -6273,12 +5777,20 @@ LoopVectorizationCostModel::getReductionPatternCost( const RecurrenceDescriptor &RdxDesc = Legal->getReductionVars().find(cast<PHINode>(ReductionPhi))->second; - InstructionCost BaseCost = TTI.getArithmeticReductionCost( - RdxDesc.getOpcode(), VectorTy, RdxDesc.getFastMathFlags(), CostKind); + InstructionCost BaseCost; + RecurKind RK = RdxDesc.getRecurrenceKind(); + if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RK)) { + Intrinsic::ID MinMaxID = getMinMaxReductionIntrinsicOp(RK); + BaseCost = TTI.getMinMaxReductionCost(MinMaxID, VectorTy, + RdxDesc.getFastMathFlags(), CostKind); + } else { + BaseCost = TTI.getArithmeticReductionCost( + RdxDesc.getOpcode(), VectorTy, RdxDesc.getFastMathFlags(), CostKind); + } // For a call to the llvm.fmuladd intrinsic we need to add the cost of a // normal fmul instruction to the cost of the fadd reduction. - if (RdxDesc.getRecurrenceKind() == RecurKind::FMulAdd) + if (RK == RecurKind::FMulAdd) BaseCost += TTI.getArithmeticInstrCost(Instruction::FMul, VectorTy, CostKind); @@ -6416,49 +5928,6 @@ LoopVectorizationCostModel::getMemoryInstructionCost(Instruction *I, return getWideningCost(I, VF); } -LoopVectorizationCostModel::VectorizationCostTy -LoopVectorizationCostModel::getInstructionCost(Instruction *I, - ElementCount VF) { - // If we know that this instruction will remain uniform, check the cost of - // the scalar version. - if (isUniformAfterVectorization(I, VF)) - VF = ElementCount::getFixed(1); - - if (VF.isVector() && isProfitableToScalarize(I, VF)) - return VectorizationCostTy(InstsToScalarize[VF][I], false); - - // Forced scalars do not have any scalarization overhead. 
- auto ForcedScalar = ForcedScalars.find(VF); - if (VF.isVector() && ForcedScalar != ForcedScalars.end()) { - auto InstSet = ForcedScalar->second; - if (InstSet.count(I)) - return VectorizationCostTy( - (getInstructionCost(I, ElementCount::getFixed(1)).first * - VF.getKnownMinValue()), - false); - } - - Type *VectorTy; - InstructionCost C = getInstructionCost(I, VF, VectorTy); - - bool TypeNotScalarized = false; - if (VF.isVector() && VectorTy->isVectorTy()) { - if (unsigned NumParts = TTI.getNumberOfParts(VectorTy)) { - if (VF.isScalable()) - // <vscale x 1 x iN> is assumed to be profitable over iN because - // scalable registers are a distinct register class from scalar ones. - // If we ever find a target which wants to lower scalable vectors - // back to scalars, we'll need to update this code to explicitly - // ask TTI about the register class uses for each part. - TypeNotScalarized = NumParts <= VF.getKnownMinValue(); - else - TypeNotScalarized = NumParts < VF.getKnownMinValue(); - } else - C = InstructionCost::getInvalid(); - } - return VectorizationCostTy(C, TypeNotScalarized); -} - InstructionCost LoopVectorizationCostModel::getScalarizationOverhead( Instruction *I, ElementCount VF, TTI::TargetCostKind CostKind) const { @@ -6849,8 +6318,25 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) { } InstructionCost -LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, - Type *&VectorTy) { +LoopVectorizationCostModel::getInstructionCost(Instruction *I, + ElementCount VF) { + // If we know that this instruction will remain uniform, check the cost of + // the scalar version. + if (isUniformAfterVectorization(I, VF)) + VF = ElementCount::getFixed(1); + + if (VF.isVector() && isProfitableToScalarize(I, VF)) + return InstsToScalarize[VF][I]; + + // Forced scalars do not have any scalarization overhead. + auto ForcedScalar = ForcedScalars.find(VF); + if (VF.isVector() && ForcedScalar != ForcedScalars.end()) { + auto InstSet = ForcedScalar->second; + if (InstSet.count(I)) + return getInstructionCost(I, ElementCount::getFixed(1)) * + VF.getKnownMinValue(); + } + Type *RetTy = I->getType(); if (canTruncateToMinimalBitwidth(I, VF)) RetTy = IntegerType::get(RetTy->getContext(), MinBWs[I]); @@ -6873,6 +6359,7 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, }; (void) hasSingleCopyAfterVectorization; + Type *VectorTy; if (isScalarAfterVectorization(I, VF)) { // With the exception of GEPs and PHIs, after scalarization there should // only be one copy of the instruction generated in the loop. This is @@ -6888,6 +6375,10 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, } else VectorTy = ToVectorTy(RetTy, VF); + if (VF.isVector() && VectorTy->isVectorTy() && + !TTI.getNumberOfParts(VectorTy)) + return InstructionCost::getInvalid(); + // TODO: We need to estimate the cost of intrinsic calls. switch (I->getOpcode()) { case Instruction::GetElementPtr: @@ -6900,11 +6391,15 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, // In cases of scalarized and predicated instructions, there will be VF // predicated blocks in the vectorized loop. Each branch around these // blocks requires also an extract of its vector compare i1 element. + // Note that the conditional branch from the loop latch will be replaced by + // a single branch controlling the loop, so there is no extra overhead from + // scalarization. 
bool ScalarPredicatedBB = false; BranchInst *BI = cast<BranchInst>(I); if (VF.isVector() && BI->isConditional() && (PredicatedBBsAfterVectorization[VF].count(BI->getSuccessor(0)) || - PredicatedBBsAfterVectorization[VF].count(BI->getSuccessor(1)))) + PredicatedBBsAfterVectorization[VF].count(BI->getSuccessor(1))) && + BI->getParent() != TheLoop->getLoopLatch()) ScalarPredicatedBB = true; if (ScalarPredicatedBB) { @@ -6934,6 +6429,11 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, // First-order recurrences are replaced by vector shuffles inside the loop. if (VF.isVector() && Legal->isFixedOrderRecurrence(Phi)) { + // For <vscale x 1 x i64>, if vscale = 1 we are unable to extract the + // penultimate value of the recurrence. + // TODO: Consider vscale_range info. + if (VF.isScalable() && VF.getKnownMinValue() == 1) + return InstructionCost::getInvalid(); SmallVector<int> Mask(VF.getKnownMinValue()); std::iota(Mask.begin(), Mask.end(), VF.getKnownMinValue() - 1); return TTI.getShuffleCost(TargetTransformInfo::SK_Splice, @@ -6999,25 +6499,10 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, Op2Info.Kind = TargetTransformInfo::OK_UniformValue; SmallVector<const Value *, 4> Operands(I->operand_values()); - auto InstrCost = TTI.getArithmeticInstrCost( + return TTI.getArithmeticInstrCost( I->getOpcode(), VectorTy, CostKind, {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None}, - Op2Info, Operands, I); - - // Some targets can replace frem with vector library calls. - InstructionCost VecCallCost = InstructionCost::getInvalid(); - if (I->getOpcode() == Instruction::FRem) { - LibFunc Func; - if (TLI->getLibFunc(I->getOpcode(), I->getType(), Func) && - TLI->isFunctionVectorizable(TLI->getName(Func), VF)) { - SmallVector<Type *, 4> OpTypes; - for (auto &Op : I->operands()) - OpTypes.push_back(Op->getType()); - VecCallCost = - TTI.getCallInstrCost(nullptr, VectorTy, OpTypes, CostKind); - } - } - return std::min(InstrCost, VecCallCost); + Op2Info, Operands, I, TLI); } case Instruction::FNeg: { return TTI.getArithmeticInstrCost( @@ -7157,25 +6642,20 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF, return *RedCost; Type *SrcScalarTy = I->getOperand(0)->getType(); + Instruction *Op0AsInstruction = dyn_cast<Instruction>(I->getOperand(0)); + if (canTruncateToMinimalBitwidth(Op0AsInstruction, VF)) + SrcScalarTy = + IntegerType::get(SrcScalarTy->getContext(), MinBWs[Op0AsInstruction]); Type *SrcVecTy = VectorTy->isVectorTy() ? ToVectorTy(SrcScalarTy, VF) : SrcScalarTy; + if (canTruncateToMinimalBitwidth(I, VF)) { - // This cast is going to be shrunk. This may remove the cast or it might - // turn it into slightly different cast. For example, if MinBW == 16, - // "zext i8 %1 to i32" becomes "zext i8 %1 to i16". - // - // Calculate the modified src and dest types. - Type *MinVecTy = VectorTy; - if (Opcode == Instruction::Trunc) { - SrcVecTy = smallestIntegerVectorType(SrcVecTy, MinVecTy); - VectorTy = - largestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy); - } else if (Opcode == Instruction::ZExt || Opcode == Instruction::SExt) { - // Leave SrcVecTy unchanged - we only shrink the destination element - // type. - VectorTy = - smallestIntegerVectorType(ToVectorTy(I->getType(), VF), MinVecTy); - } + // If the result type is <= the source type, there will be no extend + // after truncating the users to the minimal required bitwidth. 
+ if (VectorTy->getScalarSizeInBits() <= SrcVecTy->getScalarSizeInBits() && + (I->getOpcode() == Instruction::ZExt || + I->getOpcode() == Instruction::SExt)) + return 0; } return TTI.getCastInstrCost(Opcode, VectorTy, SrcVecTy, CCH, CostKind, I); @@ -7200,16 +6680,43 @@ void LoopVectorizationCostModel::collectValuesToIgnore() { // Ignore ephemeral values. CodeMetrics::collectEphemeralValues(TheLoop, AC, ValuesToIgnore); - // Find all stores to invariant variables. Since they are going to sink - // outside the loop we do not need calculate cost for them. + SmallVector<Value *, 4> DeadInterleavePointerOps; for (BasicBlock *BB : TheLoop->blocks()) for (Instruction &I : *BB) { + // Find all stores to invariant variables. Since they are going to sink + // outside the loop we do not need calculate cost for them. StoreInst *SI; if ((SI = dyn_cast<StoreInst>(&I)) && Legal->isInvariantAddressOfReduction(SI->getPointerOperand())) ValuesToIgnore.insert(&I); + + // For interleave groups, we only create a pointer for the start of the + // interleave group. Queue up addresses of group members except the insert + // position for further processing. + if (isAccessInterleaved(&I)) { + auto *Group = getInterleavedAccessGroup(&I); + if (Group->getInsertPos() == &I) + continue; + Value *PointerOp = getLoadStorePointerOperand(&I); + DeadInterleavePointerOps.push_back(PointerOp); + } } + // Mark ops feeding interleave group members as free, if they are only used + // by other dead computations. + for (unsigned I = 0; I != DeadInterleavePointerOps.size(); ++I) { + auto *Op = dyn_cast<Instruction>(DeadInterleavePointerOps[I]); + if (!Op || !TheLoop->contains(Op) || any_of(Op->users(), [this](User *U) { + Instruction *UI = cast<Instruction>(U); + return !VecValuesToIgnore.contains(U) && + (!isAccessInterleaved(UI) || + getInterleavedAccessGroup(UI)->getInsertPos() == UI); + })) + continue; + VecValuesToIgnore.insert(Op); + DeadInterleavePointerOps.append(Op->op_begin(), Op->op_end()); + } + // Ignore type-promoting instructions we identified during reduction // detection. for (const auto &Reduction : Legal->getReductionVars()) { @@ -7370,6 +6877,9 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { CM.invalidateCostModelingDecisions(); } + if (CM.foldTailByMasking()) + Legal->prepareToFoldTailByMasking(); + ElementCount MaxUserVF = UserVF.isScalable() ? MaxFactors.ScalableVF : MaxFactors.FixedVF; bool UserVFIsLegal = ElementCount::isKnownLE(UserVF, MaxUserVF); @@ -7395,14 +6905,14 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { "InvalidCost", ORE, OrigLoop); } - // Populate the set of Vectorization Factor Candidates. - ElementCountSet VFCandidates; + // Collect the Vectorization Factor Candidates. 
+ SmallVector<ElementCount> VFCandidates; for (auto VF = ElementCount::getFixed(1); ElementCount::isKnownLE(VF, MaxFactors.FixedVF); VF *= 2) - VFCandidates.insert(VF); + VFCandidates.push_back(VF); for (auto VF = ElementCount::getScalable(1); ElementCount::isKnownLE(VF, MaxFactors.ScalableVF); VF *= 2) - VFCandidates.insert(VF); + VFCandidates.push_back(VF); CM.collectInLoopReductions(); for (const auto &VF : VFCandidates) { @@ -7419,11 +6929,17 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { buildVPlansWithVPRecipes(ElementCount::getScalable(1), MaxFactors.ScalableVF); LLVM_DEBUG(printPlans(dbgs())); - if (!MaxFactors.hasVector()) + if (VPlans.empty()) + return std::nullopt; + if (all_of(VPlans, + [](std::unique_ptr<VPlan> &P) { return P->hasScalarVFOnly(); })) return VectorizationFactor::Disabled(); - // Select the optimal vectorization factor. - VectorizationFactor VF = selectVectorizationFactor(VFCandidates); + // Select the optimal vectorization factor according to the legacy cost-model. + // This is now only used to verify the decisions by the new VPlan-based + // cost-model and will be retired once the VPlan-based cost-model is + // stabilized. + VectorizationFactor VF = selectVectorizationFactor(); assert((VF.Width.isScalar() || VF.ScalarCost > 0) && "when vectorizing, the scalar cost must be non-zero."); if (!hasPlanWithVF(VF.Width)) { LLVM_DEBUG(dbgs() << "LV: No VPlan could be built for " << VF.Width @@ -7433,6 +6949,212 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) { return VF; } +InstructionCost VPCostContext::getLegacyCost(Instruction *UI, + ElementCount VF) const { + return CM.getInstructionCost(UI, VF); +} + +bool VPCostContext::skipCostComputation(Instruction *UI, bool IsVector) const { + return CM.ValuesToIgnore.contains(UI) || + (IsVector && CM.VecValuesToIgnore.contains(UI)) || + SkipCostComputation.contains(UI); +} + +InstructionCost LoopVectorizationPlanner::cost(VPlan &Plan, + ElementCount VF) const { + InstructionCost Cost = 0; + LLVMContext &LLVMCtx = OrigLoop->getHeader()->getContext(); + VPCostContext CostCtx(CM.TTI, Legal->getWidestInductionType(), LLVMCtx, CM); + + // Cost modeling for inductions is inaccurate in the legacy cost model + // compared to the recipes that are generated. To match here initially during + // VPlan cost model bring up directly use the induction costs from the legacy + // cost model. Note that we do this as pre-processing; the VPlan may not have + // any recipes associated with the original induction increment instruction + // and may replace truncates with VPWidenIntOrFpInductionRecipe. We precompute + // the cost of induction phis and increments (both that are represented by + // recipes and those that are not), to avoid distinguishing between them here, + // and skip all recipes that represent induction phis and increments (the + // former case) later on, if they exist, to avoid counting them twice. + // Similarly we pre-compute the cost of any optimized truncates. + // TODO: Switch to more accurate costing based on VPlan. 
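
Editor's sketch of the "pre-compute, then skip" scheme described in the comment above: induction instructions are costed once up front and recorded in a skip set so the later per-recipe walk does not count them a second time. The container names and cost values are hypothetical stand-ins for the planner's data structures.

#include <cstdio>
#include <map>
#include <set>
#include <vector>

int main() {
  std::map<int, unsigned> LegacyCost = {{0, 2}, {1, 1}, {2, 3}, {3, 4}};
  std::vector<int> InductionInsts = {0, 1}; // e.g. the IV phi and its increment
  std::set<int> SkipCostComputation;
  unsigned Cost = 0;
  for (int I : InductionInsts)
    if (SkipCostComputation.insert(I).second) // count each instruction at most once
      Cost += LegacyCost[I];
  for (const auto &[I, C] : LegacyCost)       // the later walk skips precomputed entries
    if (!SkipCostComputation.count(I))
      Cost += C;
  std::printf("total cost = %u\n", Cost);     // (2 + 1) precomputed + (3 + 4) from the walk = 10
}
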
+ for (const auto &[IV, IndDesc] : Legal->getInductionVars()) { + Instruction *IVInc = cast<Instruction>( + IV->getIncomingValueForBlock(OrigLoop->getLoopLatch())); + SmallVector<Instruction *> IVInsts = {IV, IVInc}; + for (User *U : IV->users()) { + auto *CI = cast<Instruction>(U); + if (!CostCtx.CM.isOptimizableIVTruncate(CI, VF)) + continue; + IVInsts.push_back(CI); + } + for (Instruction *IVInst : IVInsts) { + if (!CostCtx.SkipCostComputation.insert(IVInst).second) + continue; + InstructionCost InductionCost = CostCtx.getLegacyCost(IVInst, VF); + LLVM_DEBUG({ + dbgs() << "Cost of " << InductionCost << " for VF " << VF + << ": induction instruction " << *IVInst << "\n"; + }); + Cost += InductionCost; + } + } + + /// Compute the cost of all exiting conditions of the loop using the legacy + /// cost model. This is to match the legacy behavior, which adds the cost of + /// all exit conditions. Note that this over-estimates the cost, as there will + /// be a single condition to control the vector loop. + SmallVector<BasicBlock *> Exiting; + CM.TheLoop->getExitingBlocks(Exiting); + SetVector<Instruction *> ExitInstrs; + // Collect all exit conditions. + for (BasicBlock *EB : Exiting) { + auto *Term = dyn_cast<BranchInst>(EB->getTerminator()); + if (!Term) + continue; + if (auto *CondI = dyn_cast<Instruction>(Term->getOperand(0))) { + ExitInstrs.insert(CondI); + } + } + // Compute the cost of all instructions only feeding the exit conditions. + for (unsigned I = 0; I != ExitInstrs.size(); ++I) { + Instruction *CondI = ExitInstrs[I]; + if (!OrigLoop->contains(CondI) || + !CostCtx.SkipCostComputation.insert(CondI).second) + continue; + Cost += CostCtx.getLegacyCost(CondI, VF); + for (Value *Op : CondI->operands()) { + auto *OpI = dyn_cast<Instruction>(Op); + if (!OpI || any_of(OpI->users(), [&ExitInstrs, this](User *U) { + return OrigLoop->contains(cast<Instruction>(U)->getParent()) && + !ExitInstrs.contains(cast<Instruction>(U)); + })) + continue; + ExitInstrs.insert(OpI); + } + } + + // The legacy cost model has special logic to compute the cost of in-loop + // reductions, which may be smaller than the sum of all instructions involved + // in the reduction. For AnyOf reductions, VPlan codegen may remove the select + // which the legacy cost model uses to assign cost. Pre-compute their costs + // for now. + // TODO: Switch to costing based on VPlan once the logic has been ported. + for (const auto &[RedPhi, RdxDesc] : Legal->getReductionVars()) { + if (!CM.isInLoopReduction(RedPhi) && + !RecurrenceDescriptor::isAnyOfRecurrenceKind( + RdxDesc.getRecurrenceKind())) + continue; + + // AnyOf reduction codegen may remove the select. To match the legacy cost + // model, pre-compute the cost for AnyOf reductions here. 
+ if (RecurrenceDescriptor::isAnyOfRecurrenceKind( + RdxDesc.getRecurrenceKind())) { + auto *Select = cast<SelectInst>(*find_if( + RedPhi->users(), [](User *U) { return isa<SelectInst>(U); })); + assert(!CostCtx.SkipCostComputation.contains(Select) && + "reduction op visited multiple times"); + CostCtx.SkipCostComputation.insert(Select); + auto ReductionCost = CostCtx.getLegacyCost(Select, VF); + LLVM_DEBUG(dbgs() << "Cost of " << ReductionCost << " for VF " << VF + << ":\n any-of reduction " << *Select << "\n"); + Cost += ReductionCost; + continue; + } + + const auto &ChainOps = RdxDesc.getReductionOpChain(RedPhi, OrigLoop); + SetVector<Instruction *> ChainOpsAndOperands(ChainOps.begin(), + ChainOps.end()); + // Also include the operands of instructions in the chain, as the cost-model + // may mark extends as free. + for (auto *ChainOp : ChainOps) { + for (Value *Op : ChainOp->operands()) { + if (auto *I = dyn_cast<Instruction>(Op)) + ChainOpsAndOperands.insert(I); + } + } + + // Pre-compute the cost for I, if it has a reduction pattern cost. + for (Instruction *I : ChainOpsAndOperands) { + auto ReductionCost = CM.getReductionPatternCost( + I, VF, ToVectorTy(I->getType(), VF), TTI::TCK_RecipThroughput); + if (!ReductionCost) + continue; + + assert(!CostCtx.SkipCostComputation.contains(I) && + "reduction op visited multiple times"); + CostCtx.SkipCostComputation.insert(I); + LLVM_DEBUG(dbgs() << "Cost of " << ReductionCost << " for VF " << VF + << ":\n in-loop reduction " << *I << "\n"); + Cost += *ReductionCost; + } + } + + // Pre-compute the costs for branches except for the backedge, as the number + // of replicate regions in a VPlan may not directly match the number of + // branches, which would lead to different decisions. + // TODO: Compute cost of branches for each replicate region in the VPlan, + // which is more accurate than the legacy cost model. + for (BasicBlock *BB : OrigLoop->blocks()) { + if (BB == OrigLoop->getLoopLatch()) + continue; + CostCtx.SkipCostComputation.insert(BB->getTerminator()); + auto BranchCost = CostCtx.getLegacyCost(BB->getTerminator(), VF); + Cost += BranchCost; + } + // Now compute and add the VPlan-based cost. + Cost += Plan.cost(VF, CostCtx); + LLVM_DEBUG(dbgs() << "Cost for VF " << VF << ": " << Cost << "\n"); + return Cost; +} + +VPlan &LoopVectorizationPlanner::getBestPlan() const { + // If there is a single VPlan with a single VF, return it directly. + VPlan &FirstPlan = *VPlans[0]; + if (VPlans.size() == 1 && size(FirstPlan.vectorFactors()) == 1) + return FirstPlan; + + VPlan *BestPlan = &FirstPlan; + ElementCount ScalarVF = ElementCount::getFixed(1); + assert(hasPlanWithVF(ScalarVF) && + "More than a single plan/VF w/o any plan having scalar VF"); + + // TODO: Compute scalar cost using VPlan-based cost model. + InstructionCost ScalarCost = CM.expectedCost(ScalarVF); + VectorizationFactor BestFactor(ScalarVF, ScalarCost, ScalarCost); + + bool ForceVectorization = Hints.getForce() == LoopVectorizeHints::FK_Enabled; + if (ForceVectorization) { + // Ignore scalar width, because the user explicitly wants vectorization. + // Initialize cost to max so that VF = 2 is, at least, chosen during cost + // evaluation. 
+ BestFactor.Cost = InstructionCost::getMax(); + } + + for (auto &P : VPlans) { + for (ElementCount VF : P->vectorFactors()) { + if (VF.isScalar()) + continue; + if (!ForceVectorization && !willGenerateVectors(*P, VF, TTI)) { + LLVM_DEBUG( + dbgs() + << "LV: Not considering vector loop of width " << VF + << " because it will not generate any vector instructions.\n"); + continue; + } + + InstructionCost Cost = cost(*P, VF); + VectorizationFactor CurrentFactor(VF, Cost, ScalarCost); + if (isMoreProfitable(CurrentFactor, BestFactor)) { + BestFactor = CurrentFactor; + BestPlan = &*P; + } + } + } + BestPlan->setVF(BestFactor.Width); + return *BestPlan; +} + VPlan &LoopVectorizationPlanner::getBestPlanFor(ElementCount VF) const { assert(count_if(VPlans, [VF](const VPlanPtr &Plan) { return Plan->hasVF(VF); }) == @@ -7485,7 +7207,8 @@ static void AddRuntimeUnrollDisableMetaData(Loop *L) { static void createAndCollectMergePhiForReduction( VPInstruction *RedResult, DenseMap<const RecurrenceDescriptor *, Value *> &ReductionResumeValues, - VPTransformState &State, Loop *OrigLoop, BasicBlock *LoopMiddleBlock) { + VPTransformState &State, Loop *OrigLoop, BasicBlock *LoopMiddleBlock, + bool VectorizingEpilogue) { if (!RedResult || RedResult->getOpcode() != VPInstruction::ComputeReductionResult) return; @@ -7493,19 +7216,29 @@ static void createAndCollectMergePhiForReduction( auto *PhiR = cast<VPReductionPHIRecipe>(RedResult->getOperand(0)); const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor(); - TrackingVH<Value> ReductionStartValue = RdxDesc.getRecurrenceStartValue(); Value *FinalValue = State.get(RedResult, VPIteration(State.UF - 1, VPLane::getFirstLane())); auto *ResumePhi = dyn_cast<PHINode>(PhiR->getStartValue()->getUnderlyingValue()); + if (VectorizingEpilogue && RecurrenceDescriptor::isAnyOfRecurrenceKind( + RdxDesc.getRecurrenceKind())) { + auto *Cmp = cast<ICmpInst>(PhiR->getStartValue()->getUnderlyingValue()); + assert(Cmp->getPredicate() == CmpInst::ICMP_NE); + assert(Cmp->getOperand(1) == RdxDesc.getRecurrenceStartValue()); + ResumePhi = cast<PHINode>(Cmp->getOperand(0)); + } + assert((!VectorizingEpilogue || ResumePhi) && + "when vectorizing the epilogue loop, we need a resume phi from main " + "vector loop"); // TODO: bc.merge.rdx should not be created here, instead it should be // modeled in VPlan. BasicBlock *LoopScalarPreHeader = OrigLoop->getLoopPreheader(); // Create a phi node that merges control-flow from the backedge-taken check // block and the middle block. - auto *BCBlockPhi = PHINode::Create(FinalValue->getType(), 2, "bc.merge.rdx", - LoopScalarPreHeader->getTerminator()); + auto *BCBlockPhi = + PHINode::Create(FinalValue->getType(), 2, "bc.merge.rdx", + LoopScalarPreHeader->getTerminator()->getIterator()); // If we are fixing reductions in the epilogue loop then we should already // have created a bc.merge.rdx Phi after the main vector body. 
Ensure that @@ -7517,7 +7250,7 @@ static void createAndCollectMergePhiForReduction( BCBlockPhi->addIncoming(ResumePhi->getIncomingValueForBlock(Incoming), Incoming); else - BCBlockPhi->addIncoming(ReductionStartValue, Incoming); + BCBlockPhi->addIncoming(RdxDesc.getRecurrenceStartValue(), Incoming); } auto *OrigPhi = cast<PHINode>(PhiR->getUnderlyingValue()); @@ -7549,12 +7282,14 @@ LoopVectorizationPlanner::executePlan( assert( (IsEpilogueVectorization || !ExpandedSCEVs) && "expanded SCEVs to reuse can only be used during epilogue vectorization"); + (void)IsEpilogueVectorization; - LLVM_DEBUG(dbgs() << "Executing best plan with VF=" << BestVF << ", UF=" << BestUF - << '\n'); + VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE); - if (!IsEpilogueVectorization) - VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE); + LLVM_DEBUG(dbgs() << "Executing best plan with VF=" << BestVF + << ", UF=" << BestUF << '\n'); + BestVPlan.setName("Final VPlan"); + LLVM_DEBUG(BestVPlan.dump()); // Perform the actual loop transformation. VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan, @@ -7579,6 +7314,9 @@ LoopVectorizationPlanner::executePlan( std::tie(State.CFG.PrevBB, CanonicalIVStartValue) = ILV.createVectorizedLoopSkeleton(ExpandedSCEVs ? *ExpandedSCEVs : State.ExpandedSCEVs); +#ifdef EXPENSIVE_CHECKS + assert(DT->verify(DominatorTree::VerificationLevel::Fast)); +#endif // Only use noalias metadata when using memory checks guaranteeing no overlap // across all iterations. @@ -7598,8 +7336,6 @@ LoopVectorizationPlanner::executePlan( State.LVer->prepareNoAliasMetadata(); } - ILV.collectPoisonGeneratingRecipes(State); - ILV.printDebugTracesAtStart(); //===------------------------------------------------===// @@ -7622,9 +7358,9 @@ LoopVectorizationPlanner::executePlan( auto *ExitVPBB = cast<VPBasicBlock>(BestVPlan.getVectorLoopRegion()->getSingleSuccessor()); for (VPRecipeBase &R : *ExitVPBB) { - createAndCollectMergePhiForReduction(dyn_cast<VPInstruction>(&R), - ReductionResumeValues, State, OrigLoop, - State.CFG.VPBB2IRBB[ExitVPBB]); + createAndCollectMergePhiForReduction( + dyn_cast<VPInstruction>(&R), ReductionResumeValues, State, OrigLoop, + State.CFG.VPBB2IRBB[ExitVPBB], ExpandedSCEVs); } // 2.6. Maintain Loop Hints @@ -7661,6 +7397,18 @@ LoopVectorizationPlanner::executePlan( ILV.printDebugTracesAtEnd(); + // 4. Adjust branch weight of the branch in the middle block. + auto *MiddleTerm = + cast<BranchInst>(State.CFG.VPBB2IRBB[ExitVPBB]->getTerminator()); + if (MiddleTerm->isConditional() && + hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) { + // Assume that `Count % VectorTripCount` is equally distributed. + unsigned TripCount = State.UF * State.VF.getKnownMinValue(); + assert(TripCount > 0 && "trip count should not be zero"); + const uint32_t Weights[] = {1, TripCount - 1}; + setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false); + } + return {State.ExpandedSCEVs, ReductionResumeValues}; } @@ -7717,7 +7465,7 @@ EpilogueVectorizerMainLoop::createEpilogueVectorizedLoopSkeleton( // inductions in the epilogue loop are created before executing the plan for // the epilogue loop. 
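
Editor's illustration of the middle-block branch-weight adjustment added above, with assumed VF = 4 and UF = 2. Under the patch's "remainder is equally distributed" assumption, a zero remainder occurs for roughly one of VF * UF possible remainders, hence weights {1, VF * UF - 1}; which successor gets which weight is an assumption here, not stated in the hunk.

#include <cstdint>
#include <cstdio>

int main() {
  unsigned VF = 4, UF = 2;
  unsigned Step = VF * UF;             // iterations retired per vector-loop iteration
  uint32_t Weights[2] = {1, Step - 1}; // {remainder == 0, remainder != 0}
  std::printf("middle-block branch weights = {%u, %u}\n", Weights[0], Weights[1]);
}
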
- return {completeLoopSkeleton(), nullptr}; + return {LoopVectorPreHeader, nullptr}; } void EpilogueVectorizerMainLoop::printDebugTracesAtStart() { @@ -7772,14 +7520,8 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass, DT->getNode(Bypass)->getIDom()) && "TC check is expected to dominate Bypass"); - // Update dominator for Bypass & LoopExit. + // Update dominator for Bypass. DT->changeImmediateDominator(Bypass, TCCheckBlock); - if (!Cost->requiresScalarEpilogue(EPI.EpilogueVF.isVector())) - // For loops with multiple exits, there's no edge from the middle block - // to exit blocks (as the epilogue must run) and thus no need to update - // the immediate dominator of the exit blocks. - DT->changeImmediateDominator(LoopExitBlock, TCCheckBlock); - LoopBypassBlocks.push_back(TCCheckBlock); // Save the trip count so we don't have to regenerate it in the @@ -7791,7 +7533,7 @@ EpilogueVectorizerMainLoop::emitIterationCountCheck(BasicBlock *Bypass, BranchInst &BI = *BranchInst::Create(Bypass, LoopVectorPreHeader, CheckMinIters); if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) - setBranchWeights(BI, MinItersBypassWeights); + setBranchWeights(BI, MinItersBypassWeights, /*IsExpected=*/false); ReplaceInstWithInst(TCCheckBlock->getTerminator(), &BI); return TCCheckBlock; @@ -7810,11 +7552,10 @@ EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton( // Now, compare the remaining count and if there aren't enough iterations to // execute the vectorized epilogue skip to the scalar part. - BasicBlock *VecEpilogueIterationCountCheck = LoopVectorPreHeader; - VecEpilogueIterationCountCheck->setName("vec.epilog.iter.check"); - LoopVectorPreHeader = - SplitBlock(LoopVectorPreHeader, LoopVectorPreHeader->getTerminator(), DT, - LI, nullptr, "vec.epilog.ph"); + LoopVectorPreHeader->setName("vec.epilog.ph"); + BasicBlock *VecEpilogueIterationCountCheck = + SplitBlock(LoopVectorPreHeader, LoopVectorPreHeader->begin(), DT, LI, + nullptr, "vec.epilog.iter.check", true); emitMinimumVectorEpilogueIterCountCheck(LoopScalarPreHeader, VecEpilogueIterationCountCheck); @@ -7907,7 +7648,7 @@ EpilogueVectorizerEpilogueLoop::createEpilogueVectorizedLoopSkeleton( {VecEpilogueIterationCountCheck, EPI.VectorTripCount} /* AdditionalBypass */); - return {completeLoopSkeleton(), EPResumeVal}; + return {LoopVectorPreHeader, EPResumeVal}; } BasicBlock * @@ -7949,10 +7690,9 @@ EpilogueVectorizerEpilogueLoop::emitMinimumVectorEpilogueIterCountCheck( unsigned EstimatedSkipCount = std::min(MainLoopStep, EpilogueLoopStep); const uint32_t Weights[] = {EstimatedSkipCount, MainLoopStep - EstimatedSkipCount}; - setBranchWeights(BI, Weights); + setBranchWeights(BI, Weights, /*IsExpected=*/false); } ReplaceInstWithInst(Insert->getTerminator(), &BI); - LoopBypassBlocks.push_back(Insert); return Insert; } @@ -8000,8 +7740,19 @@ void LoopVectorizationPlanner::buildVPlans(ElementCount MinVF, } } -VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst, - VPlan &Plan) { +iterator_range<mapped_iterator<Use *, std::function<VPValue *(Value *)>>> +VPRecipeBuilder::mapToVPValues(User::op_range Operands) { + std::function<VPValue *(Value *)> Fn = [this](Value *Op) { + if (auto *I = dyn_cast<Instruction>(Op)) { + if (auto *R = Ingredient2Recipe.lookup(I)) + return R->getVPSingleValue(); + } + return Plan.getOrAddLiveIn(Op); + }; + return map_range(Operands, Fn); +} + +VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) { assert(is_contained(predecessors(Dst), 
Src) && "Invalid edge"); // Look for cached value. @@ -8025,27 +7776,34 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst, if (OrigLoop->isLoopExiting(Src)) return EdgeMaskCache[Edge] = SrcMask; - VPValue *EdgeMask = Plan.getVPValueOrAddLiveIn(BI->getCondition()); + VPValue *EdgeMask = getVPValueOrAddLiveIn(BI->getCondition(), Plan); assert(EdgeMask && "No Edge Mask found for condition"); if (BI->getSuccessor(0) != Dst) EdgeMask = Builder.createNot(EdgeMask, BI->getDebugLoc()); if (SrcMask) { // Otherwise block in-mask is all-one, no need to AND. - // The condition is 'SrcMask && EdgeMask', which is equivalent to - // 'select i1 SrcMask, i1 EdgeMask, i1 false'. - // The select version does not introduce new UB if SrcMask is false and - // EdgeMask is poison. Using 'and' here introduces undefined behavior. - VPValue *False = Plan.getVPValueOrAddLiveIn( - ConstantInt::getFalse(BI->getCondition()->getType())); - EdgeMask = - Builder.createSelect(SrcMask, EdgeMask, False, BI->getDebugLoc()); + // The bitwise 'And' of SrcMask and EdgeMask introduces new UB if SrcMask + // is false and EdgeMask is poison. Avoid that by using 'LogicalAnd' + // instead which generates 'select i1 SrcMask, i1 EdgeMask, i1 false'. + EdgeMask = Builder.createLogicalAnd(SrcMask, EdgeMask, BI->getDebugLoc()); } return EdgeMaskCache[Edge] = EdgeMask; } -void VPRecipeBuilder::createHeaderMask(VPlan &Plan) { +VPValue *VPRecipeBuilder::getEdgeMask(BasicBlock *Src, BasicBlock *Dst) const { + assert(is_contained(predecessors(Dst), Src) && "Invalid edge"); + + // Look for cached value. + std::pair<BasicBlock *, BasicBlock *> Edge(Src, Dst); + EdgeMaskCacheTy::const_iterator ECEntryIt = EdgeMaskCache.find(Edge); + assert(ECEntryIt != EdgeMaskCache.end() && + "looking up mask for edge which has not been created"); + return ECEntryIt->second; +} + +void VPRecipeBuilder::createHeaderMask() { BasicBlock *Header = OrigLoop->getHeader(); // When not folding the tail, use nullptr to model all-true mask. @@ -8080,7 +7838,7 @@ VPValue *VPRecipeBuilder::getBlockInMask(BasicBlock *BB) const { return BCEntryIt->second; } -void VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlan &Plan) { +void VPRecipeBuilder::createBlockInMask(BasicBlock *BB) { assert(OrigLoop->contains(BB) && "Block is not a part of a loop"); assert(BlockMaskCache.count(BB) == 0 && "Mask for block already computed"); assert(OrigLoop->getHeader() != BB && @@ -8091,7 +7849,7 @@ void VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlan &Plan) { VPValue *BlockMask = nullptr; // This is the block mask. We OR all incoming edges. for (auto *Predecessor : predecessors(BB)) { - VPValue *EdgeMask = createEdgeMask(Predecessor, BB, Plan); + VPValue *EdgeMask = createEdgeMask(Predecessor, BB); if (!EdgeMask) { // Mask of predecessor is all-one so mask of block is too. 
BlockMaskCache[BB] = EdgeMask; return; @@ -8108,10 +7866,9 @@ void VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlan &Plan) { BlockMaskCache[BB] = BlockMask; } -VPRecipeBase *VPRecipeBuilder::tryToWidenMemory(Instruction *I, - ArrayRef<VPValue *> Operands, - VFRange &Range, - VPlanPtr &Plan) { +VPWidenMemoryRecipe * +VPRecipeBuilder::tryToWidenMemory(Instruction *I, ArrayRef<VPValue *> Operands, + VFRange &Range) { assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && "Must be called with either a load or store"); @@ -8154,12 +7911,12 @@ VPRecipeBase *VPRecipeBuilder::tryToWidenMemory(Instruction *I, Ptr = VectorPtr; } if (LoadInst *Load = dyn_cast<LoadInst>(I)) - return new VPWidenMemoryInstructionRecipe(*Load, Ptr, Mask, Consecutive, - Reverse); + return new VPWidenLoadRecipe(*Load, Ptr, Mask, Consecutive, Reverse, + I->getDebugLoc()); StoreInst *Store = cast<StoreInst>(I); - return new VPWidenMemoryInstructionRecipe(*Store, Ptr, Operands[0], Mask, - Consecutive, Reverse); + return new VPWidenStoreRecipe(*Store, Ptr, Operands[0], Mask, Consecutive, + Reverse, I->getDebugLoc()); } /// Creates a VPWidenIntOrFpInductionRecpipe for \p Phi. If needed, it will also @@ -8167,8 +7924,7 @@ VPRecipeBase *VPRecipeBuilder::tryToWidenMemory(Instruction *I, static VPWidenIntOrFpInductionRecipe * createWidenInductionRecipes(PHINode *Phi, Instruction *PhiOrTrunc, VPValue *Start, const InductionDescriptor &IndDesc, - VPlan &Plan, ScalarEvolution &SE, Loop &OrigLoop, - VFRange &Range) { + VPlan &Plan, ScalarEvolution &SE, Loop &OrigLoop) { assert(IndDesc.getStartValue() == Phi->getIncomingValueForBlock(OrigLoop.getLoopPreheader())); assert(SE.isLoopInvariant(IndDesc.getStep(), &OrigLoop) && @@ -8183,14 +7939,14 @@ createWidenInductionRecipes(PHINode *Phi, Instruction *PhiOrTrunc, return new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, IndDesc); } -VPRecipeBase *VPRecipeBuilder::tryToOptimizeInductionPHI( - PHINode *Phi, ArrayRef<VPValue *> Operands, VPlan &Plan, VFRange &Range) { +VPHeaderPHIRecipe *VPRecipeBuilder::tryToOptimizeInductionPHI( + PHINode *Phi, ArrayRef<VPValue *> Operands, VFRange &Range) { // Check if this is an integer or fp induction. If so, build the recipe that // produces its scalar and vector values. if (auto *II = Legal->getIntOrFpInductionDescriptor(Phi)) return createWidenInductionRecipes(Phi, Phi, Operands[0], *II, Plan, - *PSE.getSE(), *OrigLoop, Range); + *PSE.getSE(), *OrigLoop); // Check if this is pointer induction. If so, build the recipe for it. if (auto *II = Legal->getPointerInductionDescriptor(Phi)) { @@ -8208,7 +7964,7 @@ VPRecipeBase *VPRecipeBuilder::tryToOptimizeInductionPHI( } VPWidenIntOrFpInductionRecipe *VPRecipeBuilder::tryToOptimizeInductionTruncate( - TruncInst *I, ArrayRef<VPValue *> Operands, VFRange &Range, VPlan &Plan) { + TruncInst *I, ArrayRef<VPValue *> Operands, VFRange &Range) { // Optimize the special case where the source is a constant integer // induction variable. 
Notice that we can only optimize the 'trunc' case // because (a) FP conversions lose precision, (b) sext/zext may wrap, and @@ -8228,62 +7984,46 @@ VPWidenIntOrFpInductionRecipe *VPRecipeBuilder::tryToOptimizeInductionTruncate( auto *Phi = cast<PHINode>(I->getOperand(0)); const InductionDescriptor &II = *Legal->getIntOrFpInductionDescriptor(Phi); - VPValue *Start = Plan.getVPValueOrAddLiveIn(II.getStartValue()); + VPValue *Start = Plan.getOrAddLiveIn(II.getStartValue()); return createWidenInductionRecipes(Phi, I, Start, II, Plan, *PSE.getSE(), - *OrigLoop, Range); + *OrigLoop); } return nullptr; } -VPRecipeOrVPValueTy VPRecipeBuilder::tryToBlend(PHINode *Phi, - ArrayRef<VPValue *> Operands, - VPlanPtr &Plan) { - // If all incoming values are equal, the incoming VPValue can be used directly - // instead of creating a new VPBlendRecipe. - if (llvm::all_equal(Operands)) - return Operands[0]; - +VPBlendRecipe *VPRecipeBuilder::tryToBlend(PHINode *Phi, + ArrayRef<VPValue *> Operands) { unsigned NumIncoming = Phi->getNumIncomingValues(); - // For in-loop reductions, we do not need to create an additional select. - VPValue *InLoopVal = nullptr; - for (unsigned In = 0; In < NumIncoming; In++) { - PHINode *PhiOp = - dyn_cast_or_null<PHINode>(Operands[In]->getUnderlyingValue()); - if (PhiOp && CM.isInLoopReduction(PhiOp)) { - assert(!InLoopVal && "Found more than one in-loop reduction!"); - InLoopVal = Operands[In]; - } - } - - assert((!InLoopVal || NumIncoming == 2) && - "Found an in-loop reduction for PHI with unexpected number of " - "incoming values"); - if (InLoopVal) - return Operands[Operands[0] == InLoopVal ? 1 : 0]; // We know that all PHIs in non-header blocks are converted into selects, so // we don't have to worry about the insertion order and we can just use the // builder. At this point we generate the predication tree. There may be // duplications since this is a simple recursive scan, but future // optimizations will clean it up. + // TODO: At the moment the first mask is always skipped, but it would be + // better to skip the most expensive mask. SmallVector<VPValue *, 2> OperandsWithMask; for (unsigned In = 0; In < NumIncoming; In++) { - VPValue *EdgeMask = - createEdgeMask(Phi->getIncomingBlock(In), Phi->getParent(), *Plan); - assert((EdgeMask || NumIncoming == 1) && - "Multiple predecessors with one having a full mask"); OperandsWithMask.push_back(Operands[In]); - if (EdgeMask) - OperandsWithMask.push_back(EdgeMask); + VPValue *EdgeMask = + getEdgeMask(Phi->getIncomingBlock(In), Phi->getParent()); + if (!EdgeMask) { + assert(In == 0 && "Both null and non-null edge masks found"); + assert(all_equal(Operands) && + "Distinct incoming values with one having a full mask"); + break; + } + if (In == 0) + continue; + OperandsWithMask.push_back(EdgeMask); } - return toVPRecipeResult(new VPBlendRecipe(Phi, OperandsWithMask)); + return new VPBlendRecipe(Phi, OperandsWithMask); } VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI, ArrayRef<VPValue *> Operands, - VFRange &Range, - VPlanPtr &Plan) { + VFRange &Range) { bool IsPredicated = LoopVectorizationPlanner::getDecisionAndClampRange( [this, CI](ElementCount VF) { return CM.isScalarWithPredication(CI, VF); @@ -8301,6 +8041,7 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI, return nullptr; SmallVector<VPValue *, 4> Ops(Operands.take_front(CI->arg_size())); + Ops.push_back(Operands.back()); // Is it beneficial to perform intrinsic call compared to lib call? 
bool ShouldUseVectorIntrinsic = @@ -8311,7 +8052,7 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI, }, Range); if (ShouldUseVectorIntrinsic) - return new VPWidenCallRecipe(*CI, make_range(Ops.begin(), Ops.end()), ID, + return new VPWidenCallRecipe(CI, make_range(Ops.begin(), Ops.end()), ID, CI->getDebugLoc()); Function *Variant = nullptr; @@ -8358,13 +8099,13 @@ VPWidenCallRecipe *VPRecipeBuilder::tryToWidenCall(CallInst *CI, if (Legal->isMaskRequired(CI)) Mask = getBlockInMask(CI->getParent()); else - Mask = Plan->getVPValueOrAddLiveIn(ConstantInt::getTrue( + Mask = Plan.getOrAddLiveIn(ConstantInt::getTrue( IntegerType::getInt1Ty(Variant->getFunctionType()->getContext()))); Ops.insert(Ops.begin() + *MaskPos, Mask); } - return new VPWidenCallRecipe(*CI, make_range(Ops.begin(), Ops.end()), + return new VPWidenCallRecipe(CI, make_range(Ops.begin(), Ops.end()), Intrinsic::not_intrinsic, CI->getDebugLoc(), Variant); } @@ -8386,9 +8127,9 @@ bool VPRecipeBuilder::shouldWiden(Instruction *I, VFRange &Range) const { Range); } -VPRecipeBase *VPRecipeBuilder::tryToWiden(Instruction *I, - ArrayRef<VPValue *> Operands, - VPBasicBlock *VPBB, VPlanPtr &Plan) { +VPWidenRecipe *VPRecipeBuilder::tryToWiden(Instruction *I, + ArrayRef<VPValue *> Operands, + VPBasicBlock *VPBB) { switch (I->getOpcode()) { default: return nullptr; @@ -8401,12 +8142,9 @@ VPRecipeBase *VPRecipeBuilder::tryToWiden(Instruction *I, if (CM.isPredicatedInst(I)) { SmallVector<VPValue *> Ops(Operands.begin(), Operands.end()); VPValue *Mask = getBlockInMask(I->getParent()); - VPValue *One = Plan->getVPValueOrAddLiveIn( - ConstantInt::get(I->getType(), 1u, false)); - auto *SafeRHS = - new VPInstruction(Instruction::Select, {Mask, Ops[1], One}, - I->getDebugLoc()); - VPBB->appendRecipe(SafeRHS); + VPValue *One = + Plan.getOrAddLiveIn(ConstantInt::get(I->getType(), 1u, false)); + auto *SafeRHS = Builder.createSelect(Mask, Ops[1], One, I->getDebugLoc()); Ops[1] = SafeRHS; return new VPWidenRecipe(*I, make_range(Ops.begin(), Ops.end())); } @@ -8445,9 +8183,8 @@ void VPRecipeBuilder::fixHeaderPhis() { } } -VPRecipeOrVPValueTy VPRecipeBuilder::handleReplication(Instruction *I, - VFRange &Range, - VPlan &Plan) { +VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I, + VFRange &Range) { bool IsUniform = LoopVectorizationPlanner::getDecisionAndClampRange( [&](ElementCount VF) { return CM.isUniformAfterVectorization(I, VF); }, Range); @@ -8497,29 +8234,30 @@ VPRecipeOrVPValueTy VPRecipeBuilder::handleReplication(Instruction *I, BlockInMask = getBlockInMask(I->getParent()); } - auto *Recipe = new VPReplicateRecipe(I, Plan.mapToVPValues(I->operands()), + // Note that there is some custom logic to mark some intrinsics as uniform + // manually above for scalable vectors, which this assert needs to account for + // as well. + assert((Range.Start.isScalar() || !IsUniform || !IsPredicated || + (Range.Start.isScalable() && isa<IntrinsicInst>(I))) && + "Should not predicate a uniform recipe"); + auto *Recipe = new VPReplicateRecipe(I, mapToVPValues(I->operands()), IsUniform, BlockInMask); - return toVPRecipeResult(Recipe); + return Recipe; } -VPRecipeOrVPValueTy +VPRecipeBase * VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, ArrayRef<VPValue *> Operands, - VFRange &Range, VPBasicBlock *VPBB, - VPlanPtr &Plan) { + VFRange &Range, VPBasicBlock *VPBB) { // First, check for specific widening recipes that deal with inductions, Phi // nodes, calls and memory operations. 
VPRecipeBase *Recipe; if (auto Phi = dyn_cast<PHINode>(Instr)) { if (Phi->getParent() != OrigLoop->getHeader()) - return tryToBlend(Phi, Operands, Plan); + return tryToBlend(Phi, Operands); - // Always record recipes for header phis. Later first-order recurrence phis - // can have earlier phis as incoming values. - recordRecipeOf(Phi); - - if ((Recipe = tryToOptimizeInductionPHI(Phi, Operands, *Plan, Range))) - return toVPRecipeResult(Recipe); + if ((Recipe = tryToOptimizeInductionPHI(Phi, Operands, Range))) + return Recipe; VPHeaderPHIRecipe *PhiRecipe = nullptr; assert((Legal->isReductionVariable(Phi) || @@ -8542,22 +8280,13 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, PhiRecipe = new VPFirstOrderRecurrencePHIRecipe(Phi, *StartV); } - // Record the incoming value from the backedge, so we can add the incoming - // value from the backedge after all recipes have been created. - auto *Inc = cast<Instruction>( - Phi->getIncomingValueForBlock(OrigLoop->getLoopLatch())); - auto RecipeIter = Ingredient2Recipe.find(Inc); - if (RecipeIter == Ingredient2Recipe.end()) - recordRecipeOf(Inc); - PhisToFix.push_back(PhiRecipe); - return toVPRecipeResult(PhiRecipe); + return PhiRecipe; } - if (isa<TruncInst>(Instr) && - (Recipe = tryToOptimizeInductionTruncate(cast<TruncInst>(Instr), Operands, - Range, *Plan))) - return toVPRecipeResult(Recipe); + if (isa<TruncInst>(Instr) && (Recipe = tryToOptimizeInductionTruncate( + cast<TruncInst>(Instr), Operands, Range))) + return Recipe; // All widen recipes below deal only with VF > 1. if (LoopVectorizationPlanner::getDecisionAndClampRange( @@ -8565,29 +8294,29 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr, return nullptr; if (auto *CI = dyn_cast<CallInst>(Instr)) - return toVPRecipeResult(tryToWidenCall(CI, Operands, Range, Plan)); + return tryToWidenCall(CI, Operands, Range); if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr)) - return toVPRecipeResult(tryToWidenMemory(Instr, Operands, Range, Plan)); + return tryToWidenMemory(Instr, Operands, Range); if (!shouldWiden(Instr, Range)) return nullptr; if (auto GEP = dyn_cast<GetElementPtrInst>(Instr)) - return toVPRecipeResult(new VPWidenGEPRecipe( - GEP, make_range(Operands.begin(), Operands.end()))); + return new VPWidenGEPRecipe(GEP, + make_range(Operands.begin(), Operands.end())); if (auto *SI = dyn_cast<SelectInst>(Instr)) { - return toVPRecipeResult(new VPWidenSelectRecipe( - *SI, make_range(Operands.begin(), Operands.end()))); + return new VPWidenSelectRecipe( + *SI, make_range(Operands.begin(), Operands.end())); } if (auto *CI = dyn_cast<CastInst>(Instr)) { - return toVPRecipeResult(new VPWidenCastRecipe(CI->getOpcode(), Operands[0], - CI->getType(), *CI)); + return new VPWidenCastRecipe(CI->getOpcode(), Operands[0], CI->getType(), + *CI); } - return toVPRecipeResult(tryToWiden(Instr, Operands, VPBB, Plan)); + return tryToWiden(Instr, Operands, VPBB); } void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, @@ -8603,7 +8332,12 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, VPlanTransforms::truncateToMinimalBitwidths( *Plan, CM.getMinimalBitwidths(), PSE.getSE()->getContext()); VPlanTransforms::optimize(*Plan, *PSE.getSE()); - assert(VPlanVerifier::verifyPlanIsValid(*Plan) && "VPlan is invalid"); + // TODO: try to put it close to addActiveLaneMask(). 
+ // Discard the plan if it is not EVL-compatible + if (CM.foldTailWithEVL() && + !VPlanTransforms::tryAddExplicitVectorLength(*Plan)) + break; + assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid"); VPlans.push_back(std::move(Plan)); } VF = SubRange.End; @@ -8615,7 +8349,7 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF, static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, bool HasNUW, DebugLoc DL) { Value *StartIdx = ConstantInt::get(IdxTy, 0); - auto *StartV = Plan.getVPValueOrAddLiveIn(StartIdx); + auto *StartV = Plan.getOrAddLiveIn(StartIdx); // Add a VPCanonicalIVPHIRecipe starting at 0 to the header. auto *CanonicalIVPHI = new VPCanonicalIVPHIRecipe(StartV, DL); @@ -8623,27 +8357,22 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, bool HasNUW, VPBasicBlock *Header = TopRegion->getEntryBasicBlock(); Header->insert(CanonicalIVPHI, Header->begin()); - // Add a CanonicalIVIncrement{NUW} VPInstruction to increment the scalar - // IV by VF * UF. - auto *CanonicalIVIncrement = - new VPInstruction(Instruction::Add, {CanonicalIVPHI, &Plan.getVFxUF()}, - {HasNUW, false}, DL, "index.next"); + VPBuilder Builder(TopRegion->getExitingBasicBlock()); + // Add a VPInstruction to increment the scalar canonical IV by VF * UF. + auto *CanonicalIVIncrement = Builder.createOverflowingOp( + Instruction::Add, {CanonicalIVPHI, &Plan.getVFxUF()}, {HasNUW, false}, DL, + "index.next"); CanonicalIVPHI->addOperand(CanonicalIVIncrement); - VPBasicBlock *EB = TopRegion->getExitingBasicBlock(); - EB->appendRecipe(CanonicalIVIncrement); - // Add the BranchOnCount VPInstruction to the latch. - VPInstruction *BranchBack = - new VPInstruction(VPInstruction::BranchOnCount, - {CanonicalIVIncrement, &Plan.getVectorTripCount()}, DL); - EB->appendRecipe(BranchBack); + Builder.createNaryOp(VPInstruction::BranchOnCount, + {CanonicalIVIncrement, &Plan.getVectorTripCount()}, DL); } // Add exit values to \p Plan. VPLiveOuts are added for each LCSSA phi in the // original exit block. static void addUsersInExitBlock(VPBasicBlock *HeaderVPBB, Loop *OrigLoop, - VPlan &Plan) { + VPRecipeBuilder &Builder, VPlan &Plan) { BasicBlock *ExitBB = OrigLoop->getUniqueExitBlock(); BasicBlock *ExitingBB = OrigLoop->getExitingBlock(); // Only handle single-exit loops with unique exit blocks for now. @@ -8654,17 +8383,115 @@ static void addUsersInExitBlock(VPBasicBlock *HeaderVPBB, Loop *OrigLoop, for (PHINode &ExitPhi : ExitBB->phis()) { Value *IncomingValue = ExitPhi.getIncomingValueForBlock(ExitingBB); - VPValue *V = Plan.getVPValueOrAddLiveIn(IncomingValue); + VPValue *V = Builder.getVPValueOrAddLiveIn(IncomingValue, Plan); + // Exit values for inductions are computed and updated outside of VPlan and + // independent of induction recipes. + // TODO: Compute induction exit values in VPlan, use VPLiveOuts to update + // live-outs. + if ((isa<VPWidenIntOrFpInductionRecipe>(V) && + !cast<VPWidenIntOrFpInductionRecipe>(V)->getTruncInst()) || + isa<VPWidenPointerInductionRecipe>(V)) + continue; Plan.addLiveOut(&ExitPhi, V); } } +/// Feed a resume value for every FOR from the vector loop to the scalar loop, +/// if middle block branches to scalar preheader, by introducing ExtractFromEnd +/// and ResumePhi recipes in each, respectively, and a VPLiveOut which uses the +/// latter and corresponds to the scalar header. 
+static void addLiveOutsForFirstOrderRecurrences(VPlan &Plan) { + VPRegionBlock *VectorRegion = Plan.getVectorLoopRegion(); + + // Start by finding out if middle block branches to scalar preheader, which is + // not a VPIRBasicBlock, unlike Exit block - the other possible successor of + // middle block. + // TODO: Should be replaced by + // Plan->getScalarLoopRegion()->getSinglePredecessor() in the future once the + // scalar region is modeled as well. + VPBasicBlock *ScalarPHVPBB = nullptr; + auto *MiddleVPBB = cast<VPBasicBlock>(VectorRegion->getSingleSuccessor()); + for (VPBlockBase *Succ : MiddleVPBB->getSuccessors()) { + if (isa<VPIRBasicBlock>(Succ)) + continue; + assert(!ScalarPHVPBB && "Two candidates for ScalarPHVPBB?"); + ScalarPHVPBB = cast<VPBasicBlock>(Succ); + } + if (!ScalarPHVPBB) + return; + + VPBuilder ScalarPHBuilder(ScalarPHVPBB); + VPBuilder MiddleBuilder(MiddleVPBB); + // Reset insert point so new recipes are inserted before terminator and + // condition, if there is either the former or both. + if (auto *Terminator = MiddleVPBB->getTerminator()) { + auto *Condition = dyn_cast<VPInstruction>(Terminator->getOperand(0)); + assert((!Condition || Condition->getParent() == MiddleVPBB) && + "Condition expected in MiddleVPBB"); + MiddleBuilder.setInsertPoint(Condition ? Condition : Terminator); + } + VPValue *OneVPV = Plan.getOrAddLiveIn( + ConstantInt::get(Plan.getCanonicalIV()->getScalarType(), 1)); + + for (auto &HeaderPhi : VectorRegion->getEntryBasicBlock()->phis()) { + auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&HeaderPhi); + if (!FOR) + continue; + + // Extract the resume value and create a new VPLiveOut for it. + auto *Resume = MiddleBuilder.createNaryOp(VPInstruction::ExtractFromEnd, + {FOR->getBackedgeValue(), OneVPV}, + {}, "vector.recur.extract"); + auto *ResumePhiRecipe = ScalarPHBuilder.createNaryOp( + VPInstruction::ResumePhi, {Resume, FOR->getStartValue()}, {}, + "scalar.recur.init"); + Plan.addLiveOut(cast<PHINode>(FOR->getUnderlyingInstr()), ResumePhiRecipe); + } +} + VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { SmallPtrSet<const InterleaveGroup<Instruction> *, 1> InterleaveGroups; - VPRecipeBuilder RecipeBuilder(OrigLoop, TLI, Legal, CM, PSE, Builder); + // --------------------------------------------------------------------------- + // Build initial VPlan: Scan the body of the loop in a topological order to + // visit each basic block after having visited its predecessor basic blocks. + // --------------------------------------------------------------------------- + + // Create initial VPlan skeleton, having a basic block for the pre-header + // which contains SCEV expansions that need to happen before the CFG is + // modified; a basic block for the vector pre-header, followed by a region for + // the vector loop, followed by the middle basic block. The skeleton vector + // loop region contains a header and latch basic blocks. + + bool RequiresScalarEpilogueCheck = + LoopVectorizationPlanner::getDecisionAndClampRange( + [this](ElementCount VF) { + return !CM.requiresScalarEpilogue(VF.isVector()); + }, + Range); + VPlanPtr Plan = VPlan::createInitialVPlan( + createTripCountSCEV(Legal->getWidestInductionType(), PSE, OrigLoop), + *PSE.getSE(), RequiresScalarEpilogueCheck, CM.foldTailByMasking(), + OrigLoop); + + // Don't use getDecisionAndClampRange here, because we don't know the UF + // so this function is better to be conservative, rather than to split + // it up into different VPlans. 
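In scalar terms, the resume value built here (the ExtractFromEnd feeding the ResumePhi) is simply the last value the recurrence produced in the vector loop. A minimal plain C++ sketch, with invented names and no VPlan machinery, of the kind of loop this serves:

#include <cstddef>
#include <vector>

// y[i] = x[i] - x[i-1], a first-order recurrence. If a vector loop handles the
// leading iterations, the scalar remainder loop must resume with the last x
// value the vector loop produced, which is what "vector.recur.extract" feeds
// into "scalar.recur.init" above. Y is assumed to be at least as long as X.
void diffAdjacent(const std::vector<int> &X, std::vector<int> &Y, int Start) {
  int Prev = Start; // scalar.recur.init: start value, or the extracted element
                    // when entering the remainder loop after the vector loop.
  for (std::size_t I = 0; I < X.size(); ++I) {
    Y[I] = X[I] - Prev;
    Prev = X[I]; // backedge value of the recurrence phi
  }
}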
+ // TODO: Consider using getDecisionAndClampRange here to split up VPlans. + bool IVUpdateMayOverflow = false; + for (ElementCount VF : Range) + IVUpdateMayOverflow |= !isIndvarOverflowCheckKnownFalse(&CM, VF); + + DebugLoc DL = getDebugLocFromInstOrOperands(Legal->getPrimaryInduction()); + TailFoldingStyle Style = CM.getTailFoldingStyle(IVUpdateMayOverflow); + // When not folding the tail, we know that the induction increment will not + // overflow. + bool HasNUW = Style == TailFoldingStyle::None; + addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL); + + VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder); // --------------------------------------------------------------------------- // Pre-construction: record ingredients whose recipes we'll need to further @@ -8690,55 +8517,26 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { if (!getDecisionAndClampRange(applyIG, Range)) continue; InterleaveGroups.insert(IG); - for (unsigned i = 0; i < IG->getFactor(); i++) - if (Instruction *Member = IG->getMember(i)) - RecipeBuilder.recordRecipeOf(Member); }; // --------------------------------------------------------------------------- - // Build initial VPlan: Scan the body of the loop in a topological order to - // visit each basic block after having visited its predecessor basic blocks. + // Construct recipes for the instructions in the loop // --------------------------------------------------------------------------- - // Create initial VPlan skeleton, having a basic block for the pre-header - // which contains SCEV expansions that need to happen before the CFG is - // modified; a basic block for the vector pre-header, followed by a region for - // the vector loop, followed by the middle basic block. The skeleton vector - // loop region contains a header and latch basic blocks. - VPlanPtr Plan = VPlan::createInitialVPlan( - createTripCountSCEV(Legal->getWidestInductionType(), PSE, OrigLoop), - *PSE.getSE()); - VPBasicBlock *HeaderVPBB = new VPBasicBlock("vector.body"); - VPBasicBlock *LatchVPBB = new VPBasicBlock("vector.latch"); - VPBlockUtils::insertBlockAfter(LatchVPBB, HeaderVPBB); - Plan->getVectorLoopRegion()->setEntry(HeaderVPBB); - Plan->getVectorLoopRegion()->setExiting(LatchVPBB); - - // Don't use getDecisionAndClampRange here, because we don't know the UF - // so this function is better to be conservative, rather than to split - // it up into different VPlans. - // TODO: Consider using getDecisionAndClampRange here to split up VPlans. - bool IVUpdateMayOverflow = false; - for (ElementCount VF : Range) - IVUpdateMayOverflow |= !isIndvarOverflowCheckKnownFalse(&CM, VF); - - DebugLoc DL = getDebugLocFromInstOrOperands(Legal->getPrimaryInduction()); - TailFoldingStyle Style = CM.getTailFoldingStyle(IVUpdateMayOverflow); - // When not folding the tail, we know that the induction increment will not - // overflow. - bool HasNUW = Style == TailFoldingStyle::None; - addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL); - // Scan the body of the loop in a topological order to visit each basic block // after having visited its predecessor basic blocks. 
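The canonical IV recipes added above model nothing more than a counter stepped by VF * UF and compared against the vector trip count. A rough stand-alone C++ sketch (illustrative only; VF and UF stand in for the chosen vectorization and unroll factors):

// Returns how many vector iterations the skeleton executes.
unsigned vectorLoopSkeleton(unsigned VectorTripCount, unsigned VF, unsigned UF) {
  unsigned Iterations = 0;
  // "index.next" = Index + VF * UF, then BranchOnCount against the vector
  // trip count, mirroring addCanonicalIVRecipes above.
  for (unsigned Index = 0; Index != VectorTripCount; Index += VF * UF)
    ++Iterations; // widened body for lanes [Index, Index + VF * UF) goes here
  return Iterations;
}

Without tail folding the vector trip count is rounded down to a multiple of VF * UF, so the increment can never step past it and cannot wrap, which is why HasNUW is set in that case.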
LoopBlocksDFS DFS(OrigLoop); DFS.perform(LI); + VPBasicBlock *HeaderVPBB = Plan->getVectorLoopRegion()->getEntryBasicBlock(); VPBasicBlock *VPBB = HeaderVPBB; - bool NeedsMasks = CM.foldTailByMasking() || - any_of(OrigLoop->blocks(), [this](BasicBlock *BB) { - return Legal->blockNeedsPredication(BB); - }); + BasicBlock *HeaderBB = OrigLoop->getHeader(); + bool NeedsMasks = + CM.foldTailByMasking() || + any_of(OrigLoop->blocks(), [this, HeaderBB](BasicBlock *BB) { + bool NeedsBlends = BB != HeaderBB && !BB->phis().empty(); + return Legal->blockNeedsPredication(BB) || NeedsBlends; + }); for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) { // Relevant instructions from basic block BB will be grouped into VPRecipe // ingredients and fill a new VPBasicBlock. @@ -8747,9 +8545,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { Builder.setInsertPoint(VPBB); if (VPBB == HeaderVPBB) - RecipeBuilder.createHeaderMask(*Plan); + RecipeBuilder.createHeaderMask(); else if (NeedsMasks) - RecipeBuilder.createBlockInMask(BB, *Plan); + RecipeBuilder.createBlockInMask(BB); // Introduce each ingredient into VPlan. // TODO: Model and preserve debug intrinsics in VPlan. @@ -8757,11 +8555,11 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { Instruction *Instr = &I; SmallVector<VPValue *, 4> Operands; auto *Phi = dyn_cast<PHINode>(Instr); - if (Phi && Phi->getParent() == OrigLoop->getHeader()) { - Operands.push_back(Plan->getVPValueOrAddLiveIn( + if (Phi && Phi->getParent() == HeaderBB) { + Operands.push_back(Plan->getOrAddLiveIn( Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()))); } else { - auto OpRange = Plan->mapToVPValues(Instr->operands()); + auto OpRange = RecipeBuilder.mapToVPValues(Instr->operands()); Operands = {OpRange.begin(), OpRange.end()}; } @@ -8772,26 +8570,10 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { Legal->isInvariantAddressOfReduction(SI->getPointerOperand())) continue; - auto RecipeOrValue = RecipeBuilder.tryToCreateWidenRecipe( - Instr, Operands, Range, VPBB, Plan); - if (!RecipeOrValue) - RecipeOrValue = RecipeBuilder.handleReplication(Instr, Range, *Plan); - // If Instr can be simplified to an existing VPValue, use it. - if (isa<VPValue *>(RecipeOrValue)) { - auto *VPV = cast<VPValue *>(RecipeOrValue); - Plan->addVPValue(Instr, VPV); - // If the re-used value is a recipe, register the recipe for the - // instruction, in case the recipe for Instr needs to be recorded. - if (VPRecipeBase *R = VPV->getDefiningRecipe()) - RecipeBuilder.setRecipe(Instr, R); - continue; - } - // Otherwise, add the new recipe. - VPRecipeBase *Recipe = cast<VPRecipeBase *>(RecipeOrValue); - for (auto *Def : Recipe->definedValues()) { - auto *UV = Def->getUnderlyingValue(); - Plan->addVPValue(UV, Def); - } + VPRecipeBase *Recipe = + RecipeBuilder.tryToCreateWidenRecipe(Instr, Operands, Range, VPBB); + if (!Recipe) + Recipe = RecipeBuilder.handleReplication(Instr, Range); RecipeBuilder.setRecipe(Instr, Recipe); if (isa<VPHeaderPHIRecipe>(Recipe)) { @@ -8823,7 +8605,7 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { // and there is nothing to fix from vector loop; phis should have incoming // from scalar loop only. 
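The new NeedsBlends term requests masks whenever a non-header block merges values through phis, even if nothing in the block strictly requires predication, because those merges become per-lane blends. A plain C++ picture of such a merge (names invented):

#include <cstddef>
#include <vector>

void clampNegatives(std::vector<int> &A) {
  for (std::size_t I = 0; I < A.size(); ++I) {
    int V;
    if (A[I] < 0)
      V = 0;    // value coming from the predicated block
    else
      V = A[I]; // value coming from the other incoming block
    // The merge phi for V is what a blend recipe models: per vector lane,
    // V = (A[I] < 0) ? 0 : A[I], driven by the block-in mask.
    A[I] = V;
  }
}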
} else - addUsersInExitBlock(HeaderVPBB, OrigLoop, *Plan); + addUsersInExitBlock(HeaderVPBB, OrigLoop, RecipeBuilder, *Plan); assert(isa<VPRegionBlock>(Plan->getVectorLoopRegion()) && !Plan->getVectorLoopRegion()->getEntryBasicBlock()->empty() && @@ -8831,30 +8613,33 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { "VPBasicBlock"); RecipeBuilder.fixHeaderPhis(); + addLiveOutsForFirstOrderRecurrences(*Plan); + // --------------------------------------------------------------------------- // Transform initial VPlan: Apply previously taken decisions, in order, to // bring the VPlan to its final state. // --------------------------------------------------------------------------- // Adjust the recipes for any inloop reductions. - adjustRecipesForReductions(LatchVPBB, Plan, RecipeBuilder, Range.Start); + adjustRecipesForReductions(Plan, RecipeBuilder, Range.Start); // Interleave memory: for each Interleave Group we marked earlier as relevant // for this VPlan, replace the Recipes widening its memory instructions with a // single VPInterleaveRecipe at its insertion point. for (const auto *IG : InterleaveGroups) { - auto *Recipe = cast<VPWidenMemoryInstructionRecipe>( - RecipeBuilder.getRecipe(IG->getInsertPos())); + auto *Recipe = + cast<VPWidenMemoryRecipe>(RecipeBuilder.getRecipe(IG->getInsertPos())); SmallVector<VPValue *, 4> StoredValues; for (unsigned i = 0; i < IG->getFactor(); ++i) if (auto *SI = dyn_cast_or_null<StoreInst>(IG->getMember(i))) { - auto *StoreR = - cast<VPWidenMemoryInstructionRecipe>(RecipeBuilder.getRecipe(SI)); + auto *StoreR = cast<VPWidenStoreRecipe>(RecipeBuilder.getRecipe(SI)); StoredValues.push_back(StoreR->getStoredValue()); } bool NeedsMaskForGaps = IG->requiresScalarEpilogue() && !CM.isScalarEpilogueAllowed(); + assert((!NeedsMaskForGaps || useMaskedInterleavedAccesses(CM.TTI)) && + "masked interleaved groups are not allowed."); auto *VPIG = new VPInterleaveRecipe(IG, Recipe->getAddr(), StoredValues, Recipe->getMask(), NeedsMaskForGaps); VPIG->insertBefore(Recipe); @@ -8883,17 +8668,31 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) { // Only handle constant strides for now. if (!ScevStride) continue; - Constant *CI = ConstantInt::get(Stride->getType(), ScevStride->getAPInt()); - auto *ConstVPV = Plan->getVPValueOrAddLiveIn(CI); - // The versioned value may not be used in the loop directly, so just add a - // new live-in in those cases. - Plan->getVPValueOrAddLiveIn(StrideV)->replaceAllUsesWith(ConstVPV); + auto *CI = Plan->getOrAddLiveIn( + ConstantInt::get(Stride->getType(), ScevStride->getAPInt())); + if (VPValue *StrideVPV = Plan->getLiveIn(StrideV)) + StrideVPV->replaceAllUsesWith(CI); + + // The versioned value may not be used in the loop directly but through a + // sext/zext. Add new live-ins in those cases. + for (Value *U : StrideV->users()) { + if (!isa<SExtInst, ZExtInst>(U)) + continue; + VPValue *StrideVPV = Plan->getLiveIn(U); + if (!StrideVPV) + continue; + unsigned BW = U->getType()->getScalarSizeInBits(); + APInt C = isa<SExtInst>(U) ? ScevStride->getAPInt().sext(BW) + : ScevStride->getAPInt().zext(BW); + VPValue *CI = Plan->getOrAddLiveIn(ConstantInt::get(U->getType(), C)); + StrideVPV->replaceAllUsesWith(CI); + } } - // From this point onwards, VPlan-to-VPlan transformations may change the plan - // in ways that accessing values using original IR values is incorrect. 
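The added walk over StrideV's users covers loops that never read the versioned stride directly and only use a sign or zero extended copy of it, so the versioned constant also has to replace those extended live-ins. A hypothetical example in plain C++ (Stride assumed positive):

#include <cstddef>
#include <cstdint>
#include <vector>

void scaleStrided(std::vector<double> &A, int32_t Stride) {
  // Versioning on Stride == 1 only helps if the constant also reaches the
  // extended copy the body actually indexes with.
  int64_t StrideExt = static_cast<int64_t>(Stride); // sext of the 32-bit stride
  for (int64_t I = 0; I * StrideExt < static_cast<int64_t>(A.size()); ++I)
    A[static_cast<std::size_t>(I * StrideExt)] *= 2.0;
}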
- Plan->disableValue2VPValue(); + VPlanTransforms::dropPoisonGeneratingRecipes(*Plan, [this](BasicBlock *BB) { + return Legal->blockNeedsPredication(BB); + }); // Sink users of fixed-order recurrence past the recipe defining the previous // value and introduce FirstOrderRecurrenceSplice VPInstructions. @@ -8923,7 +8722,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { // Create new empty VPlan auto Plan = VPlan::createInitialVPlan( createTripCountSCEV(Legal->getWidestInductionType(), PSE, OrigLoop), - *PSE.getSE()); + *PSE.getSE(), true, false, OrigLoop); // Build hierarchical CFG VPlanHCFGBuilder HCFGBuilder(OrigLoop, LI, *Plan); @@ -8948,6 +8747,7 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { bool HasNUW = true; addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DebugLoc()); + assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid"); return Plan; } @@ -8960,9 +8760,12 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) { // A ComputeReductionResult recipe is added to the middle block, also for // in-loop reductions which compute their result in-loop, because generating // the subsequent bc.merge.rdx phi is driven by ComputeReductionResult recipes. +// +// Adjust AnyOf reductions; replace the reduction phi for the selected value +// with a boolean reduction phi node to check if the condition is true in any +// iteration. The final value is selected by the final ComputeReductionResult. void LoopVectorizationPlanner::adjustRecipesForReductions( - VPBasicBlock *LatchVPBB, VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, - ElementCount MinVF) { + VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder, ElementCount MinVF) { VPRegionBlock *VectorLoopRegion = Plan->getVectorLoopRegion(); VPBasicBlock *Header = VectorLoopRegion->getEntryBasicBlock(); // Gather all VPReductionPHIRecipe and sort them so that Intermediate stores @@ -9034,7 +8837,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( // the phi until LoopExitValue. We keep track of the previous item // (PreviousLink) to tell which of the two operands of a Link will remain // scalar and which will be reduced. For minmax by select(cmp), Link will be - // the select instructions. + // the select instructions. Blend recipes of in-loop reduction phi's will + // get folded to their non-phi operand, as the reduction recipe handles the + // condition directly. VPSingleDefRecipe *PreviousLink = PhiR; // Aka Worklist[0]. 
for (VPSingleDefRecipe *CurrentLink : Worklist.getArrayRef().drop_front()) { Instruction *CurrentLinkI = CurrentLink->getUnderlyingInstr(); @@ -9065,6 +8870,20 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( LinkVPBB->insert(FMulRecipe, CurrentLink->getIterator()); VecOp = FMulRecipe; } else { + auto *Blend = dyn_cast<VPBlendRecipe>(CurrentLink); + if (PhiR->isInLoop() && Blend) { + assert(Blend->getNumIncomingValues() == 2 && + "Blend must have 2 incoming values"); + if (Blend->getIncomingValue(0) == PhiR) + Blend->replaceAllUsesWith(Blend->getIncomingValue(1)); + else { + assert(Blend->getIncomingValue(1) == PhiR && + "PhiR must be an operand of the blend"); + Blend->replaceAllUsesWith(Blend->getIncomingValue(0)); + } + continue; + } + if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) { if (isa<VPWidenRecipe>(CurrentLink)) { assert(isa<CmpInst>(CurrentLinkI) && @@ -9095,14 +8914,12 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( BasicBlock *BB = CurrentLinkI->getParent(); VPValue *CondOp = nullptr; - if (CM.blockNeedsPredicationForAnyReason(BB)) { - VPBuilder::InsertPointGuard Guard(Builder); - Builder.setInsertPoint(CurrentLink); + if (CM.blockNeedsPredicationForAnyReason(BB)) CondOp = RecipeBuilder.getBlockInMask(BB); - } - VPReductionRecipe *RedRecipe = new VPReductionRecipe( - RdxDesc, CurrentLinkI, PreviousLink, VecOp, CondOp); + VPReductionRecipe *RedRecipe = + new VPReductionRecipe(RdxDesc, CurrentLinkI, PreviousLink, VecOp, + CondOp, CM.useOrderedReductions(RdxDesc)); // Append the recipe to the end of the VPBasicBlock because we need to // ensure that it comes after all of it's inputs, including CondOp. // Note that this transformation may leave over dead recipes (including @@ -9112,7 +8929,11 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( PreviousLink = RedRecipe; } } + VPBasicBlock *LatchVPBB = VectorLoopRegion->getExitingBasicBlock(); Builder.setInsertPoint(&*LatchVPBB->begin()); + VPBasicBlock *MiddleVPBB = + cast<VPBasicBlock>(VectorLoopRegion->getSingleSuccessor()); + VPBasicBlock::iterator IP = MiddleVPBB->getFirstNonPhi(); for (VPRecipeBase &R : Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) { VPReductionPHIRecipe *PhiR = dyn_cast<VPReductionPHIRecipe>(&R); @@ -9120,6 +8941,41 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( continue; const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor(); + // Adjust AnyOf reductions; replace the reduction phi for the selected value + // with a boolean reduction phi node to check if the condition is true in + // any iteration. The final value is selected by the final + // ComputeReductionResult. + if (RecurrenceDescriptor::isAnyOfRecurrenceKind( + RdxDesc.getRecurrenceKind())) { + auto *Select = cast<VPRecipeBase>(*find_if(PhiR->users(), [](VPUser *U) { + return isa<VPWidenSelectRecipe>(U) || + (isa<VPReplicateRecipe>(U) && + cast<VPReplicateRecipe>(U)->getUnderlyingInstr()->getOpcode() == + Instruction::Select); + })); + VPValue *Cmp = Select->getOperand(0); + // If the compare is checking the reduction PHI node, adjust it to check + // the start value. 
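The comments above describe the AnyOf adjustment: the reduction phi that carries a selected value is replaced by a boolean phi that ors the select condition across iterations, and the selected value is only materialized once by the final ComputeReductionResult. In plain scalar C++, with invented values, the two forms below are equivalent:

#include <vector>

int findFlagOriginal(const std::vector<int> &A, int Start) {
  int R = Start;
  for (int V : A)
    R = (V == 42) ? 3 : R; // select feeding the reduction phi
  return R;
}

int findFlagRewritten(const std::vector<int> &A, int Start) {
  bool Any = false;         // boolean reduction phi, started at false
  for (int V : A)
    Any = Any || (V == 42); // or-reduction replaces the select chain
  return Any ? 3 : Start;   // final select, in the spirit of ComputeReductionResult
}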
+ if (VPRecipeBase *CmpR = Cmp->getDefiningRecipe()) { + for (unsigned I = 0; I != CmpR->getNumOperands(); ++I) + if (CmpR->getOperand(I) == PhiR) + CmpR->setOperand(I, PhiR->getStartValue()); + } + VPBuilder::InsertPointGuard Guard(Builder); + Builder.setInsertPoint(Select); + + // If the true value of the select is the reduction phi, the new value is + // selected if the negated condition is true in any iteration. + if (Select->getOperand(1) == PhiR) + Cmp = Builder.createNot(Cmp); + VPValue *Or = Builder.createOr(PhiR, Cmp); + Select->getVPSingleValue()->replaceAllUsesWith(Or); + + // Convert the reduction phi to operate on bools. + PhiR->setOperand(0, Plan->getOrAddLiveIn(ConstantInt::getFalse( + OrigLoop->getHeader()->getContext()))); + } + // If tail is folded by masking, introduce selects between the phi // and the live-out instruction of each reduction, at the beginning of the // dedicated latch block. @@ -9152,7 +9008,9 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( // then extend the loop exit value to enable InstCombine to evaluate the // entire expression in the smaller type. Type *PhiTy = PhiR->getStartValue()->getLiveInIRValue()->getType(); - if (MinVF.isVector() && PhiTy != RdxDesc.getRecurrenceType()) { + if (MinVF.isVector() && PhiTy != RdxDesc.getRecurrenceType() && + !RecurrenceDescriptor::isAnyOfRecurrenceKind( + RdxDesc.getRecurrenceKind())) { assert(!PhiR->isInLoop() && "Unexpected truncated inloop reduction!"); Type *RdxTy = RdxDesc.getRecurrenceType(); auto *Trunc = @@ -9184,8 +9042,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( // also modeled in VPlan. auto *FinalReductionResult = new VPInstruction( VPInstruction::ComputeReductionResult, {PhiR, NewExitingVPV}, ExitDL); - cast<VPBasicBlock>(VectorLoopRegion->getSingleSuccessor()) - ->appendRecipe(FinalReductionResult); + FinalReductionResult->insertBefore(*MiddleVPBB, IP); OrigExitingVPV->replaceUsesWithIf( FinalReductionResult, [](VPUser &User, unsigned) { return isa<VPLiveOut>(&User); }); @@ -9194,91 +9051,29 @@ void LoopVectorizationPlanner::adjustRecipesForReductions( VPlanTransforms::clearReductionWrapFlags(*Plan); } -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void VPInterleaveRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { - O << Indent << "INTERLEAVE-GROUP with factor " << IG->getFactor() << " at "; - IG->getInsertPos()->printAsOperand(O, false); - O << ", "; - getAddr()->printAsOperand(O, SlotTracker); - VPValue *Mask = getMask(); - if (Mask) { - O << ", "; - Mask->printAsOperand(O, SlotTracker); - } - - unsigned OpIdx = 0; - for (unsigned i = 0; i < IG->getFactor(); ++i) { - if (!IG->getMember(i)) - continue; - if (getNumStoreOperands() > 0) { - O << "\n" << Indent << " store "; - getOperand(1 + OpIdx)->printAsOperand(O, SlotTracker); - O << " to index " << i; - } else { - O << "\n" << Indent << " "; - getVPValue(OpIdx)->printAsOperand(O, SlotTracker); - O << " = load from index " << i; - } - ++OpIdx; - } -} -#endif - void VPWidenPointerInductionRecipe::execute(VPTransformState &State) { assert(IndDesc.getKind() == InductionDescriptor::IK_PtrInduction && "Not a pointer induction according to InductionDescriptor!"); assert(cast<PHINode>(getUnderlyingInstr())->getType()->isPointerTy() && "Unexpected type."); + assert(!onlyScalarsGenerated(State.VF.isScalable()) && + "Recipe should have been replaced"); auto *IVR = getParent()->getPlan()->getCanonicalIV(); - PHINode *CanonicalIV = cast<PHINode>(State.get(IVR, 0)); - - if 
(onlyScalarsGenerated(State.VF)) { - // This is the normalized GEP that starts counting at zero. - Value *PtrInd = State.Builder.CreateSExtOrTrunc( - CanonicalIV, IndDesc.getStep()->getType()); - // Determine the number of scalars we need to generate for each unroll - // iteration. If the instruction is uniform, we only need to generate the - // first lane. Otherwise, we generate all VF values. - bool IsUniform = vputils::onlyFirstLaneUsed(this); - assert((IsUniform || !State.VF.isScalable()) && - "Cannot scalarize a scalable VF"); - unsigned Lanes = IsUniform ? 1 : State.VF.getFixedValue(); - - for (unsigned Part = 0; Part < State.UF; ++Part) { - Value *PartStart = - createStepForVF(State.Builder, PtrInd->getType(), State.VF, Part); - - for (unsigned Lane = 0; Lane < Lanes; ++Lane) { - Value *Idx = State.Builder.CreateAdd( - PartStart, ConstantInt::get(PtrInd->getType(), Lane)); - Value *GlobalIdx = State.Builder.CreateAdd(PtrInd, Idx); - - Value *Step = State.get(getOperand(1), VPIteration(Part, Lane)); - Value *SclrGep = emitTransformedIndex( - State.Builder, GlobalIdx, IndDesc.getStartValue(), Step, - IndDesc.getKind(), IndDesc.getInductionBinOp()); - SclrGep->setName("next.gep"); - State.set(this, SclrGep, VPIteration(Part, Lane)); - } - } - return; - } - + PHINode *CanonicalIV = cast<PHINode>(State.get(IVR, 0, /*IsScalar*/ true)); Type *PhiType = IndDesc.getStep()->getType(); // Build a pointer phi Value *ScalarStartValue = getStartValue()->getLiveInIRValue(); Type *ScStValueType = ScalarStartValue->getType(); - PHINode *NewPointerPhi = - PHINode::Create(ScStValueType, 2, "pointer.phi", CanonicalIV); + PHINode *NewPointerPhi = PHINode::Create(ScStValueType, 2, "pointer.phi", + CanonicalIV->getIterator()); BasicBlock *VectorPH = State.CFG.getPreheaderBBFor(this); NewPointerPhi->addIncoming(ScalarStartValue, VectorPH); // A pointer induction, performed by using a gep - Instruction *InductionLoc = &*State.Builder.GetInsertPoint(); + BasicBlock::iterator InductionLoc = State.Builder.GetInsertPoint(); Value *ScalarStepValue = State.get(getOperand(1), VPIteration(0, 0)); Value *RuntimeVF = getRuntimeVF(State.Builder, PhiType, State.VF); @@ -9329,84 +9124,21 @@ void VPDerivedIVRecipe::execute(VPTransformState &State) { State.Builder.setFastMathFlags(FPBinOp->getFastMathFlags()); Value *Step = State.get(getStepValue(), VPIteration(0, 0)); - Value *CanonicalIV = State.get(getCanonicalIV(), VPIteration(0, 0)); + Value *CanonicalIV = State.get(getOperand(1), VPIteration(0, 0)); Value *DerivedIV = emitTransformedIndex( State.Builder, CanonicalIV, getStartValue()->getLiveInIRValue(), Step, Kind, cast_if_present<BinaryOperator>(FPBinOp)); DerivedIV->setName("offset.idx"); - if (TruncResultTy) { - assert(TruncResultTy != DerivedIV->getType() && - Step->getType()->isIntegerTy() && - "Truncation requires an integer step"); - DerivedIV = State.Builder.CreateTrunc(DerivedIV, TruncResultTy); - } assert(DerivedIV != CanonicalIV && "IV didn't need transforming?"); State.set(this, DerivedIV, VPIteration(0, 0)); } -void VPInterleaveRecipe::execute(VPTransformState &State) { - assert(!State.Instance && "Interleave group being replicated."); - State.ILV->vectorizeInterleaveGroup(IG, definedValues(), State, getAddr(), - getStoredValues(), getMask(), - NeedsMaskForGaps); -} - -void VPReductionRecipe::execute(VPTransformState &State) { - assert(!State.Instance && "Reduction being replicated."); - Value *PrevInChain = State.get(getChainOp(), 0); - RecurKind Kind = RdxDesc.getRecurrenceKind(); - bool IsOrdered = 
State.ILV->useOrderedReductions(RdxDesc); - // Propagate the fast-math flags carried by the underlying instruction. - IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder); - State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags()); - for (unsigned Part = 0; Part < State.UF; ++Part) { - Value *NewVecOp = State.get(getVecOp(), Part); - if (VPValue *Cond = getCondOp()) { - Value *NewCond = State.VF.isVector() ? State.get(Cond, Part) - : State.get(Cond, {Part, 0}); - VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType()); - Type *ElementTy = VecTy ? VecTy->getElementType() : NewVecOp->getType(); - Value *Iden = RdxDesc.getRecurrenceIdentity(Kind, ElementTy, - RdxDesc.getFastMathFlags()); - if (State.VF.isVector()) { - Iden = - State.Builder.CreateVectorSplat(VecTy->getElementCount(), Iden); - } - - Value *Select = State.Builder.CreateSelect(NewCond, NewVecOp, Iden); - NewVecOp = Select; - } - Value *NewRed; - Value *NextInChain; - if (IsOrdered) { - if (State.VF.isVector()) - NewRed = createOrderedReduction(State.Builder, RdxDesc, NewVecOp, - PrevInChain); - else - NewRed = State.Builder.CreateBinOp( - (Instruction::BinaryOps)RdxDesc.getOpcode(Kind), PrevInChain, - NewVecOp); - PrevInChain = NewRed; - } else { - PrevInChain = State.get(getChainOp(), Part); - NewRed = createTargetReduction(State.Builder, RdxDesc, NewVecOp); - } - if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) { - NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(), - NewRed, PrevInChain); - } else if (IsOrdered) - NextInChain = NewRed; - else - NextInChain = State.Builder.CreateBinOp( - (Instruction::BinaryOps)RdxDesc.getOpcode(Kind), NewRed, PrevInChain); - State.set(this, NextInChain, Part); - } -} - void VPReplicateRecipe::execute(VPTransformState &State) { Instruction *UI = getUnderlyingInstr(); if (State.Instance) { // Generate a single instance. + assert((State.VF.isScalar() || !isUniform()) && + "uniform recipe shouldn't be predicated"); assert(!State.VF.isScalable() && "Can't scalarize a scalable vector"); State.ILV->scalarizeInstruction(UI, this, *State.Instance, State); // Insert scalar instance packing it into a vector. @@ -9464,98 +9196,180 @@ void VPReplicateRecipe::execute(VPTransformState &State) { State.ILV->scalarizeInstruction(UI, this, VPIteration(Part, Lane), State); } -void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { - VPValue *StoredValue = isStore() ? getStoredValue() : nullptr; - - // Attempt to issue a wide load. - LoadInst *LI = dyn_cast<LoadInst>(&Ingredient); - StoreInst *SI = dyn_cast<StoreInst>(&Ingredient); - - assert((LI || SI) && "Invalid Load/Store instruction"); - assert((!SI || StoredValue) && "No stored value provided for widened store"); - assert((!LI || !StoredValue) && "Stored value provided for widened load"); +void VPWidenLoadRecipe::execute(VPTransformState &State) { + auto *LI = cast<LoadInst>(&Ingredient); Type *ScalarDataTy = getLoadStoreType(&Ingredient); - auto *DataTy = VectorType::get(ScalarDataTy, State.VF); const Align Alignment = getLoadStoreAlignment(&Ingredient); - bool CreateGatherScatter = !isConsecutive(); + bool CreateGather = !isConsecutive(); auto &Builder = State.Builder; - InnerLoopVectorizer::VectorParts BlockInMaskParts(State.UF); - bool isMaskRequired = getMask(); - if (isMaskRequired) { - // Mask reversal is only needed for non-all-one (null) masks, as reverse of - // a null all-one mask is a null mask. 
- for (unsigned Part = 0; Part < State.UF; ++Part) { - Value *Mask = State.get(getMask(), Part); + State.setDebugLocFrom(getDebugLoc()); + for (unsigned Part = 0; Part < State.UF; ++Part) { + Value *NewLI; + Value *Mask = nullptr; + if (auto *VPMask = getMask()) { + // Mask reversal is only needed for non-all-one (null) masks, as reverse + // of a null all-one mask is a null mask. + Mask = State.get(VPMask, Part); if (isReverse()) Mask = Builder.CreateVectorReverse(Mask, "reverse"); - BlockInMaskParts[Part] = Mask; } + + Value *Addr = State.get(getAddr(), Part, /*IsScalar*/ !CreateGather); + if (CreateGather) { + NewLI = Builder.CreateMaskedGather(DataTy, Addr, Alignment, Mask, nullptr, + "wide.masked.gather"); + } else if (Mask) { + NewLI = Builder.CreateMaskedLoad(DataTy, Addr, Alignment, Mask, + PoisonValue::get(DataTy), + "wide.masked.load"); + } else { + NewLI = Builder.CreateAlignedLoad(DataTy, Addr, Alignment, "wide.load"); + } + // Add metadata to the load, but setVectorValue to the reverse shuffle. + State.addMetadata(NewLI, LI); + if (Reverse) + NewLI = Builder.CreateVectorReverse(NewLI, "reverse"); + State.set(this, NewLI, Part); } +} - // Handle Stores: - if (SI) { - State.setDebugLocFrom(SI->getDebugLoc()); +/// Use all-true mask for reverse rather than actual mask, as it avoids a +/// dependence w/o affecting the result. +static Instruction *createReverseEVL(IRBuilderBase &Builder, Value *Operand, + Value *EVL, const Twine &Name) { + VectorType *ValTy = cast<VectorType>(Operand->getType()); + Value *AllTrueMask = + Builder.CreateVectorSplat(ValTy->getElementCount(), Builder.getTrue()); + return Builder.CreateIntrinsic(ValTy, Intrinsic::experimental_vp_reverse, + {Operand, AllTrueMask, EVL}, nullptr, Name); +} - for (unsigned Part = 0; Part < State.UF; ++Part) { - Instruction *NewSI = nullptr; - Value *StoredVal = State.get(StoredValue, Part); - if (CreateGatherScatter) { - Value *MaskPart = isMaskRequired ? BlockInMaskParts[Part] : nullptr; - Value *VectorGep = State.get(getAddr(), Part); - NewSI = Builder.CreateMaskedScatter(StoredVal, VectorGep, Alignment, - MaskPart); - } else { - if (isReverse()) { - // If we store to reverse consecutive memory locations, then we need - // to reverse the order of elements in the stored value. - StoredVal = Builder.CreateVectorReverse(StoredVal, "reverse"); - // We don't want to update the value in the map as it might be used in - // another expression. So don't call resetVectorValue(StoredVal). 
- } - auto *VecPtr = State.get(getAddr(), Part); - if (isMaskRequired) - NewSI = Builder.CreateMaskedStore(StoredVal, VecPtr, Alignment, - BlockInMaskParts[Part]); - else - NewSI = Builder.CreateAlignedStore(StoredVal, VecPtr, Alignment); - } - State.addMetadata(NewSI, SI); - } - return; +void VPWidenLoadEVLRecipe::execute(VPTransformState &State) { + assert(State.UF == 1 && "Expected only UF == 1 when vectorizing with " + "explicit vector length."); + auto *LI = cast<LoadInst>(&Ingredient); + + Type *ScalarDataTy = getLoadStoreType(&Ingredient); + auto *DataTy = VectorType::get(ScalarDataTy, State.VF); + const Align Alignment = getLoadStoreAlignment(&Ingredient); + bool CreateGather = !isConsecutive(); + + auto &Builder = State.Builder; + State.setDebugLocFrom(getDebugLoc()); + CallInst *NewLI; + Value *EVL = State.get(getEVL(), VPIteration(0, 0)); + Value *Addr = State.get(getAddr(), 0, !CreateGather); + Value *Mask = nullptr; + if (VPValue *VPMask = getMask()) { + Mask = State.get(VPMask, 0); + if (isReverse()) + Mask = createReverseEVL(Builder, Mask, EVL, "vp.reverse.mask"); + } else { + Mask = Builder.CreateVectorSplat(State.VF, Builder.getTrue()); + } + + if (CreateGather) { + NewLI = + Builder.CreateIntrinsic(DataTy, Intrinsic::vp_gather, {Addr, Mask, EVL}, + nullptr, "wide.masked.gather"); + } else { + VectorBuilder VBuilder(Builder); + VBuilder.setEVL(EVL).setMask(Mask); + NewLI = cast<CallInst>(VBuilder.createVectorInstruction( + Instruction::Load, DataTy, Addr, "vp.op.load")); } + NewLI->addParamAttr( + 0, Attribute::getWithAlignment(NewLI->getContext(), Alignment)); + State.addMetadata(NewLI, LI); + Instruction *Res = NewLI; + if (isReverse()) + Res = createReverseEVL(Builder, Res, EVL, "vp.reverse"); + State.set(this, Res, 0); +} + +void VPWidenStoreRecipe::execute(VPTransformState &State) { + auto *SI = cast<StoreInst>(&Ingredient); + + VPValue *StoredVPValue = getStoredValue(); + bool CreateScatter = !isConsecutive(); + const Align Alignment = getLoadStoreAlignment(&Ingredient); + + auto &Builder = State.Builder; + State.setDebugLocFrom(getDebugLoc()); - // Handle loads. - assert(LI && "Must have a load instruction"); - State.setDebugLocFrom(LI->getDebugLoc()); for (unsigned Part = 0; Part < State.UF; ++Part) { - Value *NewLI; - if (CreateGatherScatter) { - Value *MaskPart = isMaskRequired ? BlockInMaskParts[Part] : nullptr; - Value *VectorGep = State.get(getAddr(), Part); - NewLI = Builder.CreateMaskedGather(DataTy, VectorGep, Alignment, MaskPart, - nullptr, "wide.masked.gather"); - State.addMetadata(NewLI, LI); - } else { - auto *VecPtr = State.get(getAddr(), Part); - if (isMaskRequired) - NewLI = Builder.CreateMaskedLoad( - DataTy, VecPtr, Alignment, BlockInMaskParts[Part], - PoisonValue::get(DataTy), "wide.masked.load"); - else - NewLI = - Builder.CreateAlignedLoad(DataTy, VecPtr, Alignment, "wide.load"); + Instruction *NewSI = nullptr; + Value *Mask = nullptr; + if (auto *VPMask = getMask()) { + // Mask reversal is only needed for non-all-one (null) masks, as reverse + // of a null all-one mask is a null mask. + Mask = State.get(VPMask, Part); + if (isReverse()) + Mask = Builder.CreateVectorReverse(Mask, "reverse"); + } - // Add metadata to the load, but setVectorValue to the reverse shuffle. 
- State.addMetadata(NewLI, LI); - if (Reverse) - NewLI = Builder.CreateVectorReverse(NewLI, "reverse"); + Value *StoredVal = State.get(StoredVPValue, Part); + if (isReverse()) { + // If we store to reverse consecutive memory locations, then we need + // to reverse the order of elements in the stored value. + StoredVal = Builder.CreateVectorReverse(StoredVal, "reverse"); + // We don't want to update the value in the map as it might be used in + // another expression. So don't call resetVectorValue(StoredVal). } + Value *Addr = State.get(getAddr(), Part, /*IsScalar*/ !CreateScatter); + if (CreateScatter) + NewSI = Builder.CreateMaskedScatter(StoredVal, Addr, Alignment, Mask); + else if (Mask) + NewSI = Builder.CreateMaskedStore(StoredVal, Addr, Alignment, Mask); + else + NewSI = Builder.CreateAlignedStore(StoredVal, Addr, Alignment); + State.addMetadata(NewSI, SI); + } +} - State.set(getVPSingleValue(), NewLI, Part); +void VPWidenStoreEVLRecipe::execute(VPTransformState &State) { + assert(State.UF == 1 && "Expected only UF == 1 when vectorizing with " + "explicit vector length."); + auto *SI = cast<StoreInst>(&Ingredient); + + VPValue *StoredValue = getStoredValue(); + bool CreateScatter = !isConsecutive(); + const Align Alignment = getLoadStoreAlignment(&Ingredient); + + auto &Builder = State.Builder; + State.setDebugLocFrom(getDebugLoc()); + + CallInst *NewSI = nullptr; + Value *StoredVal = State.get(StoredValue, 0); + Value *EVL = State.get(getEVL(), VPIteration(0, 0)); + if (isReverse()) + StoredVal = createReverseEVL(Builder, StoredVal, EVL, "vp.reverse"); + Value *Mask = nullptr; + if (VPValue *VPMask = getMask()) { + Mask = State.get(VPMask, 0); + if (isReverse()) + Mask = createReverseEVL(Builder, Mask, EVL, "vp.reverse.mask"); + } else { + Mask = Builder.CreateVectorSplat(State.VF, Builder.getTrue()); } + Value *Addr = State.get(getAddr(), 0, !CreateScatter); + if (CreateScatter) { + NewSI = Builder.CreateIntrinsic(Type::getVoidTy(EVL->getContext()), + Intrinsic::vp_scatter, + {StoredVal, Addr, Mask, EVL}); + } else { + VectorBuilder VBuilder(Builder); + VBuilder.setEVL(EVL).setMask(Mask); + NewSI = cast<CallInst>(VBuilder.createVectorInstruction( + Instruction::Store, Type::getVoidTy(EVL->getContext()), + {StoredVal, Addr})); + } + NewSI->addParamAttr( + 1, Attribute::getWithAlignment(NewSI->getContext(), Alignment)); + State.addMetadata(NewSI, SI); } // Determine how to lower the scalar epilogue, which depends on 1) optimising @@ -9658,7 +9472,7 @@ static bool processLoopInVPlanNativePath( bool AddBranchWeights = hasBranchWeightMD(*L->getLoopLatch()->getTerminator()); GeneratedRTChecks Checks(*PSE.getSE(), DT, LI, TTI, - F->getParent()->getDataLayout(), AddBranchWeights); + F->getDataLayout(), AddBranchWeights); InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, ORE, VF.Width, VF.Width, 1, LVL, &CM, BFI, PSI, Checks); LLVM_DEBUG(dbgs() << "Vectorizing outer loop in \"" @@ -9741,7 +9555,7 @@ static bool areRuntimeChecksProfitable(GeneratedRTChecks &Checks, } // The scalar cost should only be 0 when vectorizing with a user specified VF/IC. In those cases, runtime checks should always be generated. 
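The EVL load and store recipes above lean on the explicit vector length operand of the VP intrinsics instead of a tail mask: each vector iteration only touches the first EVL lanes, where EVL is the number of elements still left. A conceptual plain C++ model (not the emitted code; VF assumed non-zero):

#include <algorithm>
#include <cstddef>
#include <vector>

void copyWithEVL(const std::vector<int> &Src, std::vector<int> &Dst,
                 std::size_t VF) {
  std::size_t N = Src.size();
  for (std::size_t Base = 0; Base < N; Base += VF) {
    std::size_t EVL = std::min(VF, N - Base);      // value given to vp.load/vp.store
    for (std::size_t Lane = 0; Lane < EVL; ++Lane) // lanes >= EVL are not accessed
      Dst[Base + Lane] = Src[Base + Lane];
  }
}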
- double ScalarC = *VF.ScalarCost.getValue(); + uint64_t ScalarC = *VF.ScalarCost.getValue(); if (ScalarC == 0) return true; @@ -9768,7 +9582,7 @@ static bool areRuntimeChecksProfitable(GeneratedRTChecks &Checks, // RtC + VecC * (TC / VF) + EpiC < ScalarC * TC // // Now we can compute the minimum required trip count TC as - // (RtC + EpiC) / (ScalarC - (VecC / VF)) < TC + // VF * (RtC + EpiC) / (ScalarC * VF - VecC) < TC // // For now we assume the epilogue cost EpiC = 0 for simplicity. Note that // the computations are performed on doubles, not integers and the result @@ -9780,9 +9594,9 @@ static bool areRuntimeChecksProfitable(GeneratedRTChecks &Checks, AssumedMinimumVscale = *VScale; IntVF *= AssumedMinimumVscale; } - double VecCOverVF = double(*VF.Cost.getValue()) / IntVF; - double RtC = *CheckCost.getValue(); - double MinTC1 = RtC / (ScalarC - VecCOverVF); + uint64_t RtC = *CheckCost.getValue(); + uint64_t Div = ScalarC * IntVF - *VF.Cost.getValue(); + uint64_t MinTC1 = Div == 0 ? 0 : divideCeil(RtC * IntVF, Div); // Second, compute a minimum iteration count so that the cost of the // runtime checks is only a fraction of the total scalar loop cost. This @@ -9791,12 +9605,12 @@ static bool areRuntimeChecksProfitable(GeneratedRTChecks &Checks, // * TC. To bound the runtime check to be a fraction 1/X of the scalar // cost, compute // RtC < ScalarC * TC * (1 / X) ==> RtC * X / ScalarC < TC - double MinTC2 = RtC * 10 / ScalarC; + uint64_t MinTC2 = divideCeil(RtC * 10, ScalarC); // Now pick the larger minimum. If it is not a multiple of VF and a scalar // epilogue is allowed, choose the next closest multiple of VF. This should // partly compensate for ignoring the epilogue cost. - uint64_t MinTC = std::ceil(std::max(MinTC1, MinTC2)); + uint64_t MinTC = std::max(MinTC1, MinTC2); if (SEL == CM_ScalarEpilogueAllowed) MinTC = alignTo(MinTC, IntVF); VF.MinProfitableTripCount = ElementCount::getFixed(MinTC); @@ -9831,13 +9645,9 @@ bool LoopVectorizePass::processLoop(Loop *L) { assert((EnableVPlanNativePath || L->isInnermost()) && "VPlan-native path is not enabled. Only process inner loops."); -#ifndef NDEBUG - const std::string DebugLocStr = getDebugLocString(L); -#endif /* NDEBUG */ - LLVM_DEBUG(dbgs() << "\nLV: Checking a loop in '" << L->getHeader()->getParent()->getName() << "' from " - << DebugLocStr << "\n"); + << L->getLocStr() << "\n"); LoopVectorizeHints Hints(L, InterleaveOnlyWhenForced, *ORE, TTI); @@ -10006,7 +9816,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { bool AddBranchWeights = hasBranchWeightMD(*L->getLoopLatch()->getTerminator()); GeneratedRTChecks Checks(*PSE.getSE(), DT, LI, TTI, - F->getParent()->getDataLayout(), AddBranchWeights); + F->getDataLayout(), AddBranchWeights); if (MaybeVF) { VF = *MaybeVF; // Select the interleave count. 
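The profitability bound is now evaluated purely in integer arithmetic. The self-contained sketch below mirrors the two bounds from the comments in the hunk; divCeil is a local stand-in for the ceiling division used there, the cost parameters are illustrative rather than the pass's actual types, and ScalarC is assumed non-zero to match the early return above:

#include <algorithm>
#include <cstdint>

static uint64_t divCeil(uint64_t Num, uint64_t Den) {
  return (Num + Den - 1) / Den;
}

uint64_t minProfitableTripCount(uint64_t RtC, uint64_t ScalarC, uint64_t VecC,
                                uint64_t IntVF) {
  // Bound 1: vector loop plus runtime checks beat the scalar loop, i.e.
  // VF * (RtC + EpiC) / (ScalarC * VF - VecC) < TC, with EpiC taken as 0.
  uint64_t Div = ScalarC * IntVF - VecC;
  uint64_t MinTC1 = Div == 0 ? 0 : divCeil(RtC * IntVF, Div);
  // Bound 2: runtime checks stay below roughly a tenth of the scalar cost,
  // i.e. RtC * 10 / ScalarC < TC.
  uint64_t MinTC2 = divCeil(RtC * 10, ScalarC);
  return std::max(MinTC1, MinTC2);
}

When a scalar epilogue is allowed, the pass then rounds the chosen minimum up to a multiple of the (vscale-adjusted) VF, which this sketch leaves out.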
@@ -10107,7 +9917,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { }); } else if (VectorizeLoop && !InterleaveLoop) { LLVM_DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width - << ") in " << DebugLocStr << '\n'); + << ") in " << L->getLocStr() << '\n'); ORE->emit([&]() { return OptimizationRemarkAnalysis(LV_NAME, IntDiagMsg.first, L->getStartLoc(), L->getHeader()) @@ -10115,7 +9925,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { }); } else if (VectorizeLoop && InterleaveLoop) { LLVM_DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width - << ") in " << DebugLocStr << '\n'); + << ") in " << L->getLocStr() << '\n'); LLVM_DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); } @@ -10130,7 +9940,10 @@ bool LoopVectorizePass::processLoop(Loop *L) { InnerLoopUnroller Unroller(L, PSE, LI, DT, TLI, TTI, AC, ORE, IC, &LVL, &CM, BFI, PSI, Checks); - VPlan &BestPlan = LVP.getBestPlanFor(VF.Width); + VPlan &BestPlan = + UseLegacyCostModel ? LVP.getBestPlanFor(VF.Width) : LVP.getBestPlan(); + assert((UseLegacyCostModel || BestPlan.hasScalarVFOnly()) && + "VPlan cost model and legacy cost model disagreed"); LVP.executePlan(VF.Width, IC, BestPlan, Unroller, DT, false); ORE->emit([&]() { @@ -10154,9 +9967,10 @@ bool LoopVectorizePass::processLoop(Loop *L) { EpilogueVectorizerMainLoop MainILV(L, PSE, LI, DT, TLI, TTI, AC, ORE, EPI, &LVL, &CM, BFI, PSI, Checks); - VPlan &BestMainPlan = LVP.getBestPlanFor(EPI.MainLoopVF); + std::unique_ptr<VPlan> BestMainPlan( + LVP.getBestPlanFor(EPI.MainLoopVF).duplicate()); const auto &[ExpandedSCEVs, ReductionResumeValues] = LVP.executePlan( - EPI.MainLoopVF, EPI.MainLoopUF, BestMainPlan, MainILV, DT, true); + EPI.MainLoopVF, EPI.MainLoopUF, *BestMainPlan, MainILV, DT, true); ++LoopsVectorized; // Second pass vectorizes the epilogue and adjusts the control flow @@ -10181,9 +9995,11 @@ bool LoopVectorizePass::processLoop(Loop *L) { EpilogILV.setTripCount(MainILV.getTripCount()); for (auto &R : make_early_inc_range(*BestEpiPlan.getPreheader())) { auto *ExpandR = cast<VPExpandSCEVRecipe>(&R); - auto *ExpandedVal = BestEpiPlan.getVPValueOrAddLiveIn( + auto *ExpandedVal = BestEpiPlan.getOrAddLiveIn( ExpandedSCEVs.find(ExpandR->getSCEV())->second); ExpandR->replaceAllUsesWith(ExpandedVal); + if (BestEpiPlan.getTripCount() == ExpandR) + BestEpiPlan.resetTripCount(ExpandedVal); ExpandR->eraseFromParent(); } @@ -10197,9 +10013,19 @@ bool LoopVectorizePass::processLoop(Loop *L) { Value *ResumeV = nullptr; // TODO: Move setting of resume values to prepareToExecute. if (auto *ReductionPhi = dyn_cast<VPReductionPHIRecipe>(&R)) { - ResumeV = ReductionResumeValues - .find(&ReductionPhi->getRecurrenceDescriptor()) - ->second; + const RecurrenceDescriptor &RdxDesc = + ReductionPhi->getRecurrenceDescriptor(); + RecurKind RK = RdxDesc.getRecurrenceKind(); + ResumeV = ReductionResumeValues.find(&RdxDesc)->second; + if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) { + // VPReductionPHIRecipes for AnyOf reductions expect a boolean as + // start value; compare the final value from the main vector loop + // to the start value. 
+ IRBuilder<> Builder( + cast<Instruction>(ResumeV)->getParent()->getFirstNonPHI()); + ResumeV = Builder.CreateICmpNE(ResumeV, + RdxDesc.getRecurrenceStartValue()); + } } else { // Create induction resume values for both widened pointer and // integer/fp inductions and update the start value of the induction @@ -10220,10 +10046,12 @@ bool LoopVectorizePass::processLoop(Loop *L) { {EPI.MainLoopIterationCountCheck}); } assert(ResumeV && "Must have a resume value"); - VPValue *StartVal = BestEpiPlan.getVPValueOrAddLiveIn(ResumeV); + VPValue *StartVal = BestEpiPlan.getOrAddLiveIn(ResumeV); cast<VPHeaderPHIRecipe>(&R)->setStartValue(StartVal); } + assert(DT->verify(DominatorTree::VerificationLevel::Fast) && + "DT not preserved correctly"); LVP.executePlan(EPI.EpilogueVF, EPI.EpilogueUF, BestEpiPlan, EpilogILV, DT, true, &ExpandedSCEVs); ++LoopsEpilogueVectorized; @@ -10231,12 +10059,22 @@ bool LoopVectorizePass::processLoop(Loop *L) { if (!MainILV.areSafetyChecksAdded()) DisableRuntimeUnroll = true; } else { - InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, ORE, VF.Width, + ElementCount Width = VF.Width; + VPlan &BestPlan = + UseLegacyCostModel ? LVP.getBestPlanFor(Width) : LVP.getBestPlan(); + if (!UseLegacyCostModel) { + assert(size(BestPlan.vectorFactors()) == 1 && + "Plan should have a single VF"); + Width = *BestPlan.vectorFactors().begin(); + LLVM_DEBUG(dbgs() + << "VF picked by VPlan cost model: " << Width << "\n"); + assert(VF.Width == Width && + "VPlan cost model and legacy cost model disagreed"); + } + InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, ORE, Width, VF.MinProfitableTripCount, IC, &LVL, &CM, BFI, PSI, Checks); - - VPlan &BestPlan = LVP.getBestPlanFor(VF.Width); - LVP.executePlan(VF.Width, IC, BestPlan, LB, DT, false); + LVP.executePlan(Width, IC, BestPlan, LB, DT, false); ++LoopsVectorized; // Add metadata to disable runtime unrolling a scalar loop when there @@ -10376,15 +10214,10 @@ PreservedAnalyses LoopVectorizePass::run(Function &F, RemoveRedundantDbgInstrs(&BB); } - // We currently do not preserve loopinfo/dominator analyses with outer loop - // vectorization. Until this is addressed, mark these analyses as preserved - // only for non-VPlan-native path. - // TODO: Preserve Loop and Dominator analyses for VPlan-native path. - if (!EnableVPlanNativePath) { - PA.preserve<LoopAnalysis>(); - PA.preserve<DominatorTreeAnalysis>(); - PA.preserve<ScalarEvolutionAnalysis>(); - } + PA.preserve<LoopAnalysis>(); + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<ScalarEvolutionAnalysis>(); + PA.preserve<LoopAccessAnalysis>(); if (Result.MadeCFGChange) { // Making CFG changes likely means a loop got vectorized. 
Indicate that diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 1fbd69e38eae..fd08d5d9d755 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -21,6 +21,7 @@ #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/PriorityQueue.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallBitVector.h" @@ -87,6 +88,7 @@ #include "llvm/Transforms/Utils/InjectTLIMappings.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include <algorithm> #include <cassert> #include <cstdint> @@ -111,11 +113,20 @@ static cl::opt<bool> RunSLPVectorization("vectorize-slp", cl::init(true), cl::Hidden, cl::desc("Run the SLP vectorization passes")); +static cl::opt<bool> + SLPReVec("slp-revec", cl::init(false), cl::Hidden, + cl::desc("Enable vectorization for wider vector utilization")); + static cl::opt<int> SLPCostThreshold("slp-threshold", cl::init(0), cl::Hidden, cl::desc("Only vectorize if you gain more than this " "number ")); +static cl::opt<bool> SLPSkipEarlyProfitabilityCheck( + "slp-skip-early-profitability-check", cl::init(false), cl::Hidden, + cl::desc("When true, SLP vectorizer bypasses profitability checks based on " + "heuristics and makes vectorization decision via cost modeling.")); + static cl::opt<bool> ShouldVectorizeHor("slp-vectorize-hor", cl::init(true), cl::Hidden, cl::desc("Attempt to vectorize horizontal reductions")); @@ -175,14 +186,31 @@ static cl::opt<int> RootLookAheadMaxDepth( "slp-max-root-look-ahead-depth", cl::init(2), cl::Hidden, cl::desc("The maximum look-ahead depth for searching best rooting option")); +static cl::opt<unsigned> MinProfitableStridedLoads( + "slp-min-strided-loads", cl::init(2), cl::Hidden, + cl::desc("The minimum number of loads, which should be considered strided, " + "if the stride is > 1 or is runtime value")); + +static cl::opt<unsigned> MaxProfitableLoadStride( + "slp-max-stride", cl::init(8), cl::Hidden, + cl::desc("The maximum stride, considered to be profitable.")); + static cl::opt<bool> ViewSLPTree("view-slp-tree", cl::Hidden, cl::desc("Display the SLP trees with Graphviz")); +static cl::opt<bool> VectorizeNonPowerOf2( + "slp-vectorize-non-power-of-2", cl::init(false), cl::Hidden, + cl::desc("Try to vectorize with non-power-of-2 number of elements.")); + // Limit the number of alias checks. The limit is chosen so that // it has no negative effect on the llvm benchmarks. static const unsigned AliasedCheckLimit = 10; +// Limit of the number of uses for potentially transformed instructions/values, +// used in checks to avoid compile-time explode. +static constexpr int UsesLimit = 64; + // Another limit for the alias checks: The maximum distance between load/store // instructions where alias checks are done. // This limit is useful for very large basic blocks. @@ -192,6 +220,9 @@ static const unsigned MaxMemDepDistance = 160; /// regions to be handled. static const int MinScheduleRegionSize = 16; +/// Maximum allowed number of operands in the PHI nodes. +static const unsigned MaxPHINumOperands = 128; + /// Predicate for the element types that the SLP vectorizer supports. 
/// /// The most important thing to filter here are types which are invalid in LLVM @@ -200,10 +231,28 @@ static const int MinScheduleRegionSize = 16; /// avoids spending time checking the cost model and realizing that they will /// be inevitably scalarized. static bool isValidElementType(Type *Ty) { + // TODO: Support ScalableVectorType. + if (SLPReVec && isa<FixedVectorType>(Ty)) + Ty = Ty->getScalarType(); return VectorType::isValidElementType(Ty) && !Ty->isX86_FP80Ty() && !Ty->isPPC_FP128Ty(); } +/// \returns the number of elements for Ty. +static unsigned getNumElements(Type *Ty) { + assert(!isa<ScalableVectorType>(Ty) && + "ScalableVectorType is not supported."); + if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) + return VecTy->getNumElements(); + return 1; +} + +/// \returns the vector type of ScalarTy based on vectorization factor. +static FixedVectorType *getWidenedType(Type *ScalarTy, unsigned VF) { + return FixedVectorType::get(ScalarTy->getScalarType(), + VF * getNumElements(ScalarTy)); +} + /// \returns True if the value is a constant (but not globals/constant /// expressions). static bool isConstant(Value *V) { @@ -228,6 +277,21 @@ static bool isVectorLikeInstWithConstOps(Value *V) { return isConstant(I->getOperand(2)); } +/// Returns power-of-2 number of elements in a single register (part), given the +/// total number of elements \p Size and number of registers (parts) \p +/// NumParts. +static unsigned getPartNumElems(unsigned Size, unsigned NumParts) { + return PowerOf2Ceil(divideCeil(Size, NumParts)); +} + +/// Returns correct remaining number of elements, considering total amount \p +/// Size, (power-of-2 number) of elements in a single register \p PartNumElems +/// and current register (part) \p Part. +static unsigned getNumElems(unsigned Size, unsigned PartNumElems, + unsigned Part) { + return std::min<unsigned>(PartNumElems, Size - Part * PartNumElems); +} + #if !defined(NDEBUG) /// Print a short descriptor of the instruction bundle suitable for debug output. static std::string shortBundleName(ArrayRef<Value *> VL) { @@ -290,19 +354,43 @@ static bool isCommutative(Instruction *I) { if (auto *Cmp = dyn_cast<CmpInst>(I)) return Cmp->isCommutative(); if (auto *BO = dyn_cast<BinaryOperator>(I)) - return BO->isCommutative(); - // TODO: This should check for generic Instruction::isCommutative(), but - // we need to confirm that the caller code correctly handles Intrinsics - // for example (does not have 2 operands). - return false; + return BO->isCommutative() || + (BO->getOpcode() == Instruction::Sub && + !BO->hasNUsesOrMore(UsesLimit) && + all_of( + BO->uses(), + [](const Use &U) { + // Commutative, if icmp eq/ne sub, 0 + ICmpInst::Predicate Pred; + if (match(U.getUser(), + m_ICmp(Pred, m_Specific(U.get()), m_Zero())) && + (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)) + return true; + // Commutative, if abs(sub nsw, true) or abs(sub, false). + ConstantInt *Flag; + return match(U.getUser(), + m_Intrinsic<Intrinsic::abs>( + m_Specific(U.get()), m_ConstantInt(Flag))) && + (!cast<Instruction>(U.get())->hasNoSignedWrap() || + Flag->isOne()); + })) || + (BO->getOpcode() == Instruction::FSub && + !BO->hasNUsesOrMore(UsesLimit) && + all_of(BO->uses(), [](const Use &U) { + return match(U.getUser(), + m_Intrinsic<Intrinsic::fabs>(m_Specific(U.get()))); + })); + return I->isCommutative(); } -/// \returns inserting index of InsertElement or InsertValue instruction, -/// using Offset as base offset for index. 
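The widened isCommutative above accepts a subtraction when every user only compares the result against zero or takes its absolute value, since swapping the operands then cannot change what those users observe. A small standalone check of that property (assuming the subtraction itself does not overflow, which is what the nsw and abs poison-flag conditions in the hunk guard):

#include <cassert>
#include <cstdlib>

void subIsCommutativeForTheseUsers(int A, int B) {
  // icmp eq/ne (sub A, B), 0 gives the same answer with the operands swapped...
  assert(((A - B) == 0) == ((B - A) == 0));
  // ...and so does abs(sub A, B), absent signed overflow.
  assert(std::abs(A - B) == std::abs(B - A));
}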
-static std::optional<unsigned> getInsertIndex(const Value *InsertInst, - unsigned Offset = 0) { +template <typename T> +static std::optional<unsigned> getInsertExtractIndex(const Value *Inst, + unsigned Offset) { + static_assert(std::is_same_v<T, InsertElementInst> || + std::is_same_v<T, ExtractElementInst>, + "unsupported T"); int Index = Offset; - if (const auto *IE = dyn_cast<InsertElementInst>(InsertInst)) { + if (const auto *IE = dyn_cast<T>(Inst)) { const auto *VT = dyn_cast<FixedVectorType>(IE->getType()); if (!VT) return std::nullopt; @@ -315,8 +403,25 @@ static std::optional<unsigned> getInsertIndex(const Value *InsertInst, Index += CI->getZExtValue(); return Index; } + return std::nullopt; +} + +/// \returns inserting or extracting index of InsertElement, ExtractElement or +/// InsertValue instruction, using Offset as base offset for index. +/// \returns std::nullopt if the index is not an immediate. +static std::optional<unsigned> getElementIndex(const Value *Inst, + unsigned Offset = 0) { + if (auto Index = getInsertExtractIndex<InsertElementInst>(Inst, Offset)) + return Index; + if (auto Index = getInsertExtractIndex<ExtractElementInst>(Inst, Offset)) + return Index; + + int Index = Offset; + + const auto *IV = dyn_cast<InsertValueInst>(Inst); + if (!IV) + return std::nullopt; - const auto *IV = cast<InsertValueInst>(InsertInst); Type *CurrentType = IV->getType(); for (unsigned I : IV->indices()) { if (const auto *ST = dyn_cast<StructType>(CurrentType)) { @@ -390,7 +495,7 @@ static SmallBitVector isUndefVector(const Value *V, Base = II->getOperand(0); if (isa<T>(II->getOperand(1))) continue; - std::optional<unsigned> Idx = getInsertIndex(II); + std::optional<unsigned> Idx = getElementIndex(II); if (!Idx) { Res.reset(); return Res; @@ -443,17 +548,31 @@ static SmallBitVector isUndefVector(const Value *V, /// ShuffleVectorInst/getShuffleCost? static std::optional<TargetTransformInfo::ShuffleKind> isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) { - const auto *It = - find_if(VL, [](Value *V) { return isa<ExtractElementInst>(V); }); + const auto *It = find_if(VL, IsaPred<ExtractElementInst>); if (It == VL.end()) return std::nullopt; - auto *EI0 = cast<ExtractElementInst>(*It); - if (isa<ScalableVectorType>(EI0->getVectorOperandType())) - return std::nullopt; unsigned Size = - cast<FixedVectorType>(EI0->getVectorOperandType())->getNumElements(); + std::accumulate(VL.begin(), VL.end(), 0u, [](unsigned S, Value *V) { + auto *EI = dyn_cast<ExtractElementInst>(V); + if (!EI) + return S; + auto *VTy = dyn_cast<FixedVectorType>(EI->getVectorOperandType()); + if (!VTy) + return S; + return std::max(S, VTy->getNumElements()); + }); + Value *Vec1 = nullptr; Value *Vec2 = nullptr; + bool HasNonUndefVec = any_of(VL, [](Value *V) { + auto *EE = dyn_cast<ExtractElementInst>(V); + if (!EE) + return false; + Value *Vec = EE->getVectorOperand(); + if (isa<UndefValue>(Vec)) + return false; + return isGuaranteedNotToBePoison(Vec); + }); enum ShuffleMode { Unknown, Select, Permute }; ShuffleMode CommonShuffleMode = Unknown; Mask.assign(VL.size(), PoisonMaskElem); @@ -466,21 +585,25 @@ isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) { return std::nullopt; auto *Vec = EI->getVectorOperand(); // We can extractelement from undef or poison vector. - if (isUndefVector(Vec).all()) + if (isUndefVector</*isPoisonOnly=*/true>(Vec).all()) continue; // All vector operands must have the same number of vector elements. 
- if (cast<FixedVectorType>(Vec->getType())->getNumElements() != Size) - return std::nullopt; - if (isa<UndefValue>(EI->getIndexOperand())) - continue; - auto *Idx = dyn_cast<ConstantInt>(EI->getIndexOperand()); - if (!Idx) - return std::nullopt; - // Undefined behavior if Idx is negative or >= Size. - if (Idx->getValue().uge(Size)) + if (isa<UndefValue>(Vec)) { + Mask[I] = I; + } else { + if (isa<UndefValue>(EI->getIndexOperand())) + continue; + auto *Idx = dyn_cast<ConstantInt>(EI->getIndexOperand()); + if (!Idx) + return std::nullopt; + // Undefined behavior if Idx is negative or >= Size. + if (Idx->getValue().uge(Size)) + continue; + unsigned IntIdx = Idx->getValue().getZExtValue(); + Mask[I] = IntIdx; + } + if (isUndefVector(Vec).all() && HasNonUndefVec) continue; - unsigned IntIdx = Idx->getValue().getZExtValue(); - Mask[I] = IntIdx; // For correct shuffling we have to have at most 2 different vector operands // in all extractelement instructions. if (!Vec1 || Vec1 == Vec) { @@ -495,7 +618,7 @@ isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) { continue; // If the extract index is not the same as the operation number, it is a // permutation. - if (IntIdx != I) { + if (Mask[I] % Size != I) { CommonShuffleMode = Permute; continue; } @@ -644,6 +767,29 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL, unsigned AltOpcode = Opcode; unsigned AltIndex = BaseIndex; + bool SwappedPredsCompatible = [&]() { + if (!IsCmpOp) + return false; + SetVector<unsigned> UniquePreds, UniqueNonSwappedPreds; + UniquePreds.insert(BasePred); + UniqueNonSwappedPreds.insert(BasePred); + for (Value *V : VL) { + auto *I = dyn_cast<CmpInst>(V); + if (!I) + return false; + CmpInst::Predicate CurrentPred = I->getPredicate(); + CmpInst::Predicate SwappedCurrentPred = + CmpInst::getSwappedPredicate(CurrentPred); + UniqueNonSwappedPreds.insert(CurrentPred); + if (!UniquePreds.contains(CurrentPred) && + !UniquePreds.contains(SwappedCurrentPred)) + UniquePreds.insert(CurrentPred); + } + // Total number of predicates > 2, but if consider swapped predicates + // compatible only 2, consider swappable predicates as compatible opcodes, + // not alternate. + return UniqueNonSwappedPreds.size() > 2 && UniquePreds.size() == 2; + }(); // Check for one alternate opcode from another BinaryOperator. // TODO - generalize to support all operators (types, calls etc.). 
auto *IBase = cast<Instruction>(VL[BaseIndex]); @@ -696,7 +842,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL, CmpInst::Predicate SwappedCurrentPred = CmpInst::getSwappedPredicate(CurrentPred); - if (E == 2 && + if ((E == 2 || SwappedPredsCompatible) && (BasePred == CurrentPred || BasePred == SwappedCurrentPred)) continue; @@ -734,11 +880,11 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL, auto *CallBase = cast<CallInst>(IBase); if (Call->getCalledFunction() != CallBase->getCalledFunction()) return InstructionsState(VL[BaseIndex], nullptr, nullptr); - if (Call->hasOperandBundles() && + if (Call->hasOperandBundles() && (!CallBase->hasOperandBundles() || !std::equal(Call->op_begin() + Call->getBundleOperandsStartIndex(), Call->op_begin() + Call->getBundleOperandsEndIndex(), CallBase->op_begin() + - CallBase->getBundleOperandsStartIndex())) + CallBase->getBundleOperandsStartIndex()))) return InstructionsState(VL[BaseIndex], nullptr, nullptr); Intrinsic::ID ID = getVectorIntrinsicIDForCall(Call, &TLI); if (ID != BaseID) @@ -858,7 +1004,7 @@ static void addMask(SmallVectorImpl<int> &Mask, ArrayRef<int> SubMask, /// values 3 and 7 respectively: /// before: 6 9 5 4 9 2 1 0 /// after: 6 3 5 4 7 2 1 0 -static void fixupOrderingIndices(SmallVectorImpl<unsigned> &Order) { +static void fixupOrderingIndices(MutableArrayRef<unsigned> Order) { const unsigned Sz = Order.size(); SmallBitVector UnusedIndices(Sz, /*t=*/true); SmallBitVector MaskedIndices(Sz); @@ -882,6 +1028,17 @@ static void fixupOrderingIndices(SmallVectorImpl<unsigned> &Order) { } } +/// \returns a bitset for selecting opcodes. false for Opcode0 and true for +/// Opcode1. +SmallBitVector getAltInstrMask(ArrayRef<Value *> VL, unsigned Opcode0, + unsigned Opcode1) { + SmallBitVector OpcodeMask(VL.size(), false); + for (unsigned Lane : seq<unsigned>(VL.size())) + if (cast<Instruction>(VL[Lane])->getOpcode() == Opcode1) + OpcodeMask.set(Lane); + return OpcodeMask; +} + namespace llvm { static void inversePermutation(ArrayRef<unsigned> Indices, @@ -898,7 +1055,7 @@ static void reorderScalars(SmallVectorImpl<Value *> &Scalars, ArrayRef<int> Mask) { assert(!Mask.empty() && "Expected non-empty mask."); SmallVector<Value *> Prev(Scalars.size(), - UndefValue::get(Scalars.front()->getType())); + PoisonValue::get(Scalars.front()->getType())); Prev.swap(Scalars); for (unsigned I = 0, E = Prev.size(); I < E; ++I) if (Mask[I] != PoisonMaskElem) @@ -931,7 +1088,6 @@ static bool isUsedOutsideBlock(Value *V) { if (!I) return true; // Limits the number of uses to save compile time. - constexpr int UsesLimit = 8; return !I->mayReadOrWriteMemory() && !I->hasNUsesOrMore(UsesLimit) && all_of(I->users(), [I](User *U) { auto *IU = dyn_cast<Instruction>(U); @@ -967,6 +1123,14 @@ class BoUpSLP { class ShuffleInstructionBuilder; public: + /// Tracks the state we can represent the loads in the given sequence. 
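A small self-contained illustration of the alternate-opcode bitmask produced by the getAltInstrMask helper introduced above: lanes whose opcode matches Opcode1 get their bit set, Opcode0 lanes stay clear. The Opcode enum and the add/sub bundle below are invented stand-ins, not LLVM's real instruction enum.

#include <cstdio>
#include <vector>

int main() {
  enum Opcode { Add = 0, Sub = 1 };
  std::vector<Opcode> Lanes = {Add, Sub, Add, Sub}; // a toy add/sub alternate bundle
  Opcode Opcode1 = Sub;
  std::vector<bool> OpcodeMask(Lanes.size(), false);
  for (unsigned Lane = 0; Lane < Lanes.size(); ++Lane)
    if (Lanes[Lane] == Opcode1)
      OpcodeMask[Lane] = true; // lane uses the alternate opcode
  for (bool B : OpcodeMask)
    std::printf("%d", B); // prints 0101
  std::printf("\n");
}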
+ enum class LoadsState { + Gather, + Vectorize, + ScatterVectorize, + StridedVectorize + }; + using ValueList = SmallVector<Value *, 8>; using InstrList = SmallVector<Instruction *, 16>; using ValueSet = SmallPtrSet<Value *, 16>; @@ -979,8 +1143,9 @@ public: TargetLibraryInfo *TLi, AAResults *Aa, LoopInfo *Li, DominatorTree *Dt, AssumptionCache *AC, DemandedBits *DB, const DataLayout *DL, OptimizationRemarkEmitter *ORE) - : BatchAA(*Aa), F(Func), SE(Se), TTI(Tti), TLI(TLi), LI(Li), - DT(Dt), AC(AC), DB(DB), DL(DL), ORE(ORE), Builder(Se->getContext()) { + : BatchAA(*Aa), F(Func), SE(Se), TTI(Tti), TLI(TLi), LI(Li), DT(Dt), + AC(AC), DB(DB), DL(DL), ORE(ORE), + Builder(Se->getContext(), TargetFolder(*DL)) { CodeMetrics::collectEphemeralValues(F, AC, EphValues); // Use the vector register size specified by the target unless overridden // by a command-line option. @@ -1043,6 +1208,12 @@ public: return VectorizableTree.front()->Scalars; } + /// Checks if the root graph node can be emitted with narrower bitwidth at + /// codegen and returns it signedness, if so. + bool isSignedMinBitwidthRootNode() const { + return MinBWs.at(VectorizableTree.front().get()).second; + } + /// Builds external uses of the vectorized scalars, i.e. the list of /// vectorized scalars to be extracted, their lanes and their scalar users. \p /// ExternallyUsedValues contains additional list of external uses to handle @@ -1050,19 +1221,27 @@ public: void buildExternalUses(const ExtraValueToDebugLocsMap &ExternallyUsedValues = {}); + /// Transforms graph nodes to target specific representations, if profitable. + void transformNodes(); + /// Clear the internal data structures that are created by 'buildTree'. void deleteTree() { VectorizableTree.clear(); ScalarToTreeEntry.clear(); MultiNodeScalars.clear(); MustGather.clear(); + NonScheduledFirst.clear(); EntryToLastInstruction.clear(); ExternalUses.clear(); + ExternalUsesAsGEPs.clear(); for (auto &Iter : BlocksSchedules) { BlockScheduling *BS = Iter.second.get(); BS->clear(); } MinBWs.clear(); + ReductionBitWidth = 0; + CastMaxMinBWSizes.reset(); + ExtraBitWidthNodes.clear(); InstrElementSize.clear(); UserIgnoreList = nullptr; PostponedGathers.clear(); @@ -1169,7 +1348,20 @@ public: /// effectively impossible for the backend to undo. /// TODO: If load combining is allowed in the IR optimizer, this analysis /// may not be necessary. - bool isLoadCombineCandidate() const; + bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const; + + /// Checks if the given array of loads can be represented as a vectorized, + /// scatter or just simple gather. + /// \param VL list of loads. + /// \param VL0 main load value. + /// \param Order returned order of load instructions. + /// \param PointerOps returned list of pointer operands. + /// \param TryRecursiveCheck used to check if long masked gather can be + /// represented as a serie of loads/insert subvector, if profitable. + LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, + SmallVectorImpl<unsigned> &Order, + SmallVectorImpl<Value *> &PointerOps, + bool TryRecursiveCheck = true) const; OptimizationRemarkEmitter *getORE() { return ORE; } @@ -1275,8 +1467,7 @@ public: // Retruns true if the users of V1 and V2 won't need to be extracted. auto AllUsersAreInternal = [U1, U2, this](Value *V1, Value *V2) { // Bail out if we have too many uses to save compilation time. 
- static constexpr unsigned Limit = 8; - if (V1->hasNUsesOrMore(Limit) || V2->hasNUsesOrMore(Limit)) + if (V1->hasNUsesOrMore(UsesLimit) || V2->hasNUsesOrMore(UsesLimit)) return false; auto AllUsersVectorized = [U1, U2, this](Value *V) { @@ -1296,12 +1487,19 @@ public: return LookAheadHeuristics::ScoreSplat; } + auto CheckSameEntryOrFail = [&]() { + if (const TreeEntry *TE1 = R.getTreeEntry(V1); + TE1 && TE1 == R.getTreeEntry(V2)) + return LookAheadHeuristics::ScoreSplatLoads; + return LookAheadHeuristics::ScoreFail; + }; + auto *LI1 = dyn_cast<LoadInst>(V1); auto *LI2 = dyn_cast<LoadInst>(V2); if (LI1 && LI2) { if (LI1->getParent() != LI2->getParent() || !LI1->isSimple() || !LI2->isSimple()) - return LookAheadHeuristics::ScoreFail; + return CheckSameEntryOrFail(); std::optional<int> Dist = getPointersDiff( LI1->getType(), LI1->getPointerOperand(), LI2->getType(), @@ -1310,10 +1508,9 @@ public: if (getUnderlyingObject(LI1->getPointerOperand()) == getUnderlyingObject(LI2->getPointerOperand()) && R.TTI->isLegalMaskedGather( - FixedVectorType::get(LI1->getType(), NumLanes), - LI1->getAlign())) + getWidenedType(LI1->getType(), NumLanes), LI1->getAlign())) return LookAheadHeuristics::ScoreMaskedGatherCandidate; - return LookAheadHeuristics::ScoreFail; + return CheckSameEntryOrFail(); } // The distance is too large - still may be profitable to use masked // loads/gathers. @@ -1370,14 +1567,14 @@ public: } return LookAheadHeuristics::ScoreAltOpcodes; } - return LookAheadHeuristics::ScoreFail; + return CheckSameEntryOrFail(); } auto *I1 = dyn_cast<Instruction>(V1); auto *I2 = dyn_cast<Instruction>(V2); if (I1 && I2) { if (I1->getParent() != I2->getParent()) - return LookAheadHeuristics::ScoreFail; + return CheckSameEntryOrFail(); SmallVector<Value *, 4> Ops(MainAltOps.begin(), MainAltOps.end()); Ops.push_back(I1); Ops.push_back(I2); @@ -1398,7 +1595,7 @@ public: if (isa<UndefValue>(V2)) return LookAheadHeuristics::ScoreUndef; - return LookAheadHeuristics::ScoreFail; + return CheckSameEntryOrFail(); } /// Go through the operands of \p LHS and \p RHS recursively until @@ -1561,6 +1758,7 @@ public: const DataLayout &DL; ScalarEvolution &SE; const BoUpSLP &R; + const Loop *L = nullptr; /// \returns the operand data at \p OpIdx and \p Lane. OperandData &getData(unsigned OpIdx, unsigned Lane) { @@ -1729,8 +1927,9 @@ public: // Track if the operand must be marked as used. If the operand is set to // Score 1 explicitly (because of non power-of-2 unique scalars, we may // want to reestimate the operands again on the following iterations). - bool IsUsed = - RMode == ReorderingMode::Splat || RMode == ReorderingMode::Constant; + bool IsUsed = RMode == ReorderingMode::Splat || + RMode == ReorderingMode::Constant || + RMode == ReorderingMode::Load; // Iterate through all unused operands and look for the best. for (unsigned Idx = 0; Idx != NumOperands; ++Idx) { // Get the operand at Idx and Lane. @@ -1751,23 +1950,44 @@ public: // Look for an operand that matches the current mode. switch (RMode) { case ReorderingMode::Load: - case ReorderingMode::Constant: case ReorderingMode::Opcode: { bool LeftToRight = Lane > LastLane; Value *OpLeft = (LeftToRight) ? OpLastLane : Op; Value *OpRight = (LeftToRight) ? 
Op : OpLastLane; int Score = getLookAheadScore(OpLeft, OpRight, MainAltOps, Lane, OpIdx, Idx, IsUsed); - if (Score > static_cast<int>(BestOp.Score)) { + if (Score > static_cast<int>(BestOp.Score) || + (Score > 0 && Score == static_cast<int>(BestOp.Score) && + Idx == OpIdx)) { BestOp.Idx = Idx; BestOp.Score = Score; BestScoresPerLanes[std::make_pair(OpIdx, Lane)] = Score; } break; } + case ReorderingMode::Constant: + if (isa<Constant>(Op) || + (!BestOp.Score && L && L->isLoopInvariant(Op))) { + BestOp.Idx = Idx; + if (isa<Constant>(Op)) { + BestOp.Score = LookAheadHeuristics::ScoreConstants; + BestScoresPerLanes[std::make_pair(OpIdx, Lane)] = + LookAheadHeuristics::ScoreConstants; + } + if (isa<UndefValue>(Op) || !isa<Constant>(Op)) + IsUsed = false; + } + break; case ReorderingMode::Splat: - if (Op == OpLastLane) + if (Op == OpLastLane || (!BestOp.Score && isa<Constant>(Op))) { + IsUsed = Op == OpLastLane; + if (Op == OpLastLane) { + BestOp.Score = LookAheadHeuristics::ScoreSplat; + BestScoresPerLanes[std::make_pair(OpIdx, Lane)] = + LookAheadHeuristics::ScoreSplat; + } BestOp.Idx = Idx; + } break; case ReorderingMode::Failed: llvm_unreachable("Not expected Failed reordering mode."); @@ -1915,6 +2135,9 @@ public: "Expected same number of lanes"); assert(isa<Instruction>(VL[0]) && "Expected instruction"); unsigned NumOperands = cast<Instruction>(VL[0])->getNumOperands(); + constexpr unsigned IntrinsicNumOperands = 2; + if (isa<IntrinsicInst>(VL[0])) + NumOperands = IntrinsicNumOperands; OpsVec.resize(NumOperands); unsigned NumLanes = VL.size(); for (unsigned OpIdx = 0; OpIdx != NumOperands; ++OpIdx) { @@ -1957,10 +2180,12 @@ public: void clear() { OpsVec.clear(); } /// \Returns true if there are enough operands identical to \p Op to fill - /// the whole vector. + /// the whole vector (it is mixed with constants or loop invariant values). /// Note: This modifies the 'IsUsed' flag, so a cleanUsed() must follow. bool shouldBroadcast(Value *Op, unsigned OpIdx, unsigned Lane) { bool OpAPO = getData(OpIdx, Lane).APO; + bool IsInvariant = L && L->isLoopInvariant(Op); + unsigned Cnt = 0; for (unsigned Ln = 0, Lns = getNumLanes(); Ln != Lns; ++Ln) { if (Ln == Lane) continue; @@ -1970,23 +2195,72 @@ public: OperandData &Data = getData(OpI, Ln); if (Data.APO != OpAPO || Data.IsUsed) continue; - if (Data.V == Op) { + Value *OpILane = getValue(OpI, Lane); + bool IsConstantOp = isa<Constant>(OpILane); + // Consider the broadcast candidate if: + // 1. Same value is found in one of the operands. + if (Data.V == Op || + // 2. The operand in the given lane is not constant but there is a + // constant operand in another lane (which can be moved to the + // given lane). In this case we can represent it as a simple + // permutation of constant and broadcast. + (!IsConstantOp && + ((Lns > 2 && isa<Constant>(Data.V)) || + // 2.1. If we have only 2 lanes, need to check that value in the + // next lane does not build same opcode sequence. + (Lns == 2 && + !getSameOpcode({Op, getValue((OpI + 1) % OpE, Ln)}, TLI) + .getOpcode() && + isa<Constant>(Data.V)))) || + // 3. The operand in the current lane is loop invariant (can be + // hoisted out) and another operand is also a loop invariant + // (though not a constant). In this case the whole vector can be + // hoisted out. + // FIXME: need to teach the cost model about this case for better + // estimation. 
+ (IsInvariant && !isa<Constant>(Data.V) && + !getSameOpcode({Op, Data.V}, TLI).getOpcode() && + L->isLoopInvariant(Data.V))) { FoundCandidate = true; - Data.IsUsed = true; + Data.IsUsed = Data.V == Op; + if (Data.V == Op) + ++Cnt; break; } } if (!FoundCandidate) return false; } - return true; + return getNumLanes() == 2 || Cnt > 1; + } + + /// Checks if there is at least single compatible operand in lanes other + /// than \p Lane, compatible with the operand \p Op. + bool canBeVectorized(Instruction *Op, unsigned OpIdx, unsigned Lane) const { + bool OpAPO = getData(OpIdx, Lane).APO; + for (unsigned Ln = 0, Lns = getNumLanes(); Ln != Lns; ++Ln) { + if (Ln == Lane) + continue; + if (any_of(seq<unsigned>(getNumOperands()), [&](unsigned OpI) { + const OperandData &Data = getData(OpI, Ln); + if (Data.APO != OpAPO || Data.IsUsed) + return true; + Value *OpILn = getValue(OpI, Ln); + return (L && L->isLoopInvariant(OpILn)) || + (getSameOpcode({Op, OpILn}, TLI).getOpcode() && + Op->getParent() == cast<Instruction>(OpILn)->getParent()); + })) + return true; + } + return false; } public: /// Initialize with all the operands of the instruction vector \p RootVL. - VLOperands(ArrayRef<Value *> RootVL, const TargetLibraryInfo &TLI, - const DataLayout &DL, ScalarEvolution &SE, const BoUpSLP &R) - : TLI(TLI), DL(DL), SE(SE), R(R) { + VLOperands(ArrayRef<Value *> RootVL, const BoUpSLP &R) + : TLI(*R.TLI), DL(*R.DL), SE(*R.SE), R(R), + L(R.LI->getLoopFor( + (cast<Instruction>(RootVL.front())->getParent()))) { // Append all the operands of RootVL. appendOperandsOfVL(RootVL); } @@ -2036,14 +2310,14 @@ public: // side. if (isa<LoadInst>(OpLane0)) ReorderingModes[OpIdx] = ReorderingMode::Load; - else if (isa<Instruction>(OpLane0)) { + else if (auto *OpILane0 = dyn_cast<Instruction>(OpLane0)) { // Check if OpLane0 should be broadcast. - if (shouldBroadcast(OpLane0, OpIdx, FirstLane)) + if (shouldBroadcast(OpLane0, OpIdx, FirstLane) || + !canBeVectorized(OpILane0, OpIdx, FirstLane)) ReorderingModes[OpIdx] = ReorderingMode::Splat; else ReorderingModes[OpIdx] = ReorderingMode::Opcode; - } - else if (isa<Constant>(OpLane0)) + } else if (isa<Constant>(OpLane0)) ReorderingModes[OpIdx] = ReorderingMode::Constant; else if (isa<Argument>(OpLane0)) // Our best hope is a Splat. It may save some cost in some cases. @@ -2118,8 +2392,6 @@ public: // getBestOperand(). swap(OpIdx, *BestIdx, Lane); } else { - // We failed to find a best operand, set mode to 'Failed'. - ReorderingModes[OpIdx] = ReorderingMode::Failed; // Enable the second pass. StrategyFailed = true; } @@ -2201,7 +2473,7 @@ public: /// of the cost, considered to be good enough score. std::optional<int> findBestRootPair(ArrayRef<std::pair<Value *, Value *>> Candidates, - int Limit = LookAheadHeuristics::ScoreFail) { + int Limit = LookAheadHeuristics::ScoreFail) const { LookAheadHeuristics LookAhead(*TLI, *DL, *SE, *this, /*NumLanes=*/2, RootLookAheadMaxDepth); int BestScore = Limit; @@ -2229,6 +2501,91 @@ public: DeletedInstructions.insert(I); } + /// Remove instructions from the parent function and clear the operands of \p + /// DeadVals instructions, marking for deletion trivially dead operands. 
+ template <typename T> + void removeInstructionsAndOperands(ArrayRef<T *> DeadVals) { + SmallVector<WeakTrackingVH> DeadInsts; + for (T *V : DeadVals) { + auto *I = cast<Instruction>(V); + DeletedInstructions.insert(I); + } + DenseSet<Value *> Processed; + for (T *V : DeadVals) { + if (!V || !Processed.insert(V).second) + continue; + auto *I = cast<Instruction>(V); + salvageDebugInfo(*I); + SmallVector<const TreeEntry *> Entries; + if (const TreeEntry *Entry = getTreeEntry(I)) { + Entries.push_back(Entry); + auto It = MultiNodeScalars.find(I); + if (It != MultiNodeScalars.end()) + Entries.append(It->second.begin(), It->second.end()); + } + for (Use &U : I->operands()) { + if (auto *OpI = dyn_cast_if_present<Instruction>(U.get()); + OpI && !DeletedInstructions.contains(OpI) && OpI->hasOneUser() && + wouldInstructionBeTriviallyDead(OpI, TLI) && + (Entries.empty() || none_of(Entries, [&](const TreeEntry *Entry) { + return Entry->VectorizedValue == OpI; + }))) + DeadInsts.push_back(OpI); + } + I->dropAllReferences(); + } + for (T *V : DeadVals) { + auto *I = cast<Instruction>(V); + if (!I->getParent()) + continue; + assert((I->use_empty() || all_of(I->uses(), + [&](Use &U) { + return isDeleted( + cast<Instruction>(U.getUser())); + })) && + "trying to erase instruction with users."); + I->removeFromParent(); + SE->forgetValue(I); + } + // Process the dead instruction list until empty. + while (!DeadInsts.empty()) { + Value *V = DeadInsts.pop_back_val(); + Instruction *VI = cast_or_null<Instruction>(V); + if (!VI || !VI->getParent()) + continue; + assert(isInstructionTriviallyDead(VI, TLI) && + "Live instruction found in dead worklist!"); + assert(VI->use_empty() && "Instructions with uses are not dead."); + + // Don't lose the debug info while deleting the instructions. + salvageDebugInfo(*VI); + + // Null out all of the instruction's operands to see if any operand + // becomes dead as we go. + for (Use &OpU : VI->operands()) { + Value *OpV = OpU.get(); + if (!OpV) + continue; + OpU.set(nullptr); + + if (!OpV->use_empty()) + continue; + + // If the operand is an instruction that became dead as we nulled out + // the operand, and if it is 'trivially' dead, delete it in a future + // loop iteration. + if (auto *OpI = dyn_cast<Instruction>(OpV)) + if (!DeletedInstructions.contains(OpI) && + isInstructionTriviallyDead(OpI, TLI)) + DeadInsts.push_back(OpI); + } + + VI->removeFromParent(); + DeletedInstructions.insert(VI); + SE->forgetValue(VI); + } + } + /// Checks if the instruction was already analyzed for being possible /// reduction root. bool isAnalyzedReductionRoot(Instruction *I) const { @@ -2253,11 +2610,20 @@ public: void clearReductionData() { AnalyzedReductionsRoots.clear(); AnalyzedReductionVals.clear(); + AnalyzedMinBWVals.clear(); } /// Checks if the given value is gathered in one of the nodes. bool isAnyGathered(const SmallDenseSet<Value *> &Vals) const { return any_of(MustGather, [&](Value *V) { return Vals.contains(V); }); } + /// Checks if the given value is gathered in one of the nodes. + bool isGathered(const Value *V) const { + return MustGather.contains(V); + } + /// Checks if the specified value was not schedule. + bool isNotScheduled(const Value *V) const { + return NonScheduledFirst.contains(V); + } /// Check if the value is vectorized in the tree. bool isVectorized(Value *V) const { return getTreeEntry(V); } @@ -2265,17 +2631,17 @@ public: ~BoUpSLP(); private: - /// Determine if a vectorized value \p V in can be demoted to - /// a smaller type with a truncation. 
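The worklist in removeInstructionsAndOperands above follows the usual cascading dead-code pattern: when a value is erased, operands that lose their last user become trivially dead and are queued as well. Below is a toy, self-contained model of that pattern only; the Node type and the three-node chain are invented for illustration and are not LLVM classes.

#include <cstdio>
#include <vector>

struct Node {
  const char *Name;
  std::vector<Node *> Operands;
  int NumUses = 0;
};

int main() {
  Node A{"A"}, B{"B", {&A}}, C{"C", {&B}};
  A.NumUses = 1; // used only by B
  B.NumUses = 1; // used only by C
  std::vector<Node *> DeadInsts = {&C}; // the value we delete first
  while (!DeadInsts.empty()) {
    Node *N = DeadInsts.back();
    DeadInsts.pop_back();
    std::printf("erasing %s\n", N->Name);
    for (Node *Op : N->Operands)
      if (--Op->NumUses == 0) // operand just became trivially dead
        DeadInsts.push_back(Op);
    N->Operands.clear();
  }
  // Prints: erasing C, erasing B, erasing A.
}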
We collect the values that will be - /// demoted in ToDemote and additional roots that require investigating in - /// Roots. - /// \param DemotedConsts list of Instruction/OperandIndex pairs that are - /// constant and to be demoted. Required to correctly identify constant nodes - /// to be demoted. - bool collectValuesToDemote( - Value *V, SmallVectorImpl<Value *> &ToDemote, - DenseMap<Instruction *, SmallVector<unsigned>> &DemotedConsts, - SmallVectorImpl<Value *> &Roots, DenseSet<Value *> &Visited) const; + /// Determine if a node \p E in can be demoted to a smaller type with a + /// truncation. We collect the entries that will be demoted in ToDemote. + /// \param E Node for analysis + /// \param ToDemote indices of the nodes to be demoted. + bool collectValuesToDemote(const TreeEntry &E, bool IsProfitableToDemoteRoot, + unsigned &BitWidth, + SmallVectorImpl<unsigned> &ToDemote, + DenseSet<const TreeEntry *> &Visited, + unsigned &MaxDepthLevel, + bool &IsProfitableToDemote, + bool IsTruncRoot) const; /// Check if the operands on the edges \p Edges of the \p UserTE allows /// reordering (i.e. the operands can be reordered because they have only one @@ -2341,6 +2707,10 @@ private: /// \ returns the graph entry for the \p Idx operand of the \p E entry. const TreeEntry *getOperandEntry(const TreeEntry *E, unsigned Idx) const; + /// \returns Cast context for the given graph node. + TargetTransformInfo::CastContextHint + getCastContextHint(const TreeEntry &TE) const; + /// \returns the cost of the vectorizable entry. InstructionCost getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, @@ -2376,12 +2746,12 @@ private: /// which exploits values reused across lanes, and arranges the inserts /// for ease of later optimization. template <typename BVTy, typename ResTy, typename... Args> - ResTy processBuildVector(const TreeEntry *E, Args &...Params); + ResTy processBuildVector(const TreeEntry *E, Type *ScalarTy, Args &...Params); /// Create a new vector from a list of scalar values. Produces a sequence /// which exploits values reused across lanes, and arranges the inserts /// for ease of later optimization. - Value *createBuildVector(const TreeEntry *E); + Value *createBuildVector(const TreeEntry *E, Type *ScalarTy); /// Returns the instruction in the bundle, which can be used as a base point /// for scheduling. Usually it is the last instruction in the bundle, except @@ -2413,18 +2783,25 @@ private: /// \param TE Tree entry checked for permutation. /// \param VL List of scalars (a subset of the TE scalar), checked for /// permutations. Must form single-register vector. + /// \param ForOrder Tries to fetch the best candidates for ordering info. Also + /// commands to build the mask using the original vector value, without + /// relying on the potential reordering. /// \returns ShuffleKind, if gathered values can be represented as shuffles of /// previous tree entries. \p Part of \p Mask is filled with the shuffle mask. std::optional<TargetTransformInfo::ShuffleKind> isGatherShuffledSingleRegisterEntry( const TreeEntry *TE, ArrayRef<Value *> VL, MutableArrayRef<int> Mask, - SmallVectorImpl<const TreeEntry *> &Entries, unsigned Part); + SmallVectorImpl<const TreeEntry *> &Entries, unsigned Part, + bool ForOrder); /// Checks if the gathered \p VL can be represented as multi-register /// shuffle(s) of previous tree entries. /// \param TE Tree entry checked for permutation. /// \param VL List of scalars (a subset of the TE scalar), checked for /// permutations. 
+ /// \param ForOrder Tries to fetch the best candidates for ordering info. Also + /// commands to build the mask using the original vector value, without + /// relying on the potential reordering. /// \returns per-register series of ShuffleKind, if gathered values can be /// represented as shuffles of previous tree entries. \p Mask is filled with /// the shuffle mask (also on per-register base). @@ -2432,13 +2809,14 @@ private: isGatherShuffledEntry( const TreeEntry *TE, ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask, SmallVectorImpl<SmallVector<const TreeEntry *>> &Entries, - unsigned NumParts); + unsigned NumParts, bool ForOrder = false); /// \returns the scalarization cost for this list of values. Assuming that /// this subtree gets vectorized, we may need to extract the values from the /// roots. This method calculates the cost of extracting the values. /// \param ForPoisonSrc true if initial vector is poison, false otherwise. - InstructionCost getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc) const; + InstructionCost getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc, + Type *ScalarTy) const; /// Set the Builder insert point to one after the last instruction in /// the bundle @@ -2446,7 +2824,7 @@ private: /// \returns a vector from a collection of scalars in \p VL. if \p Root is not /// specified, the starting vector value is poison. - Value *gather(ArrayRef<Value *> VL, Value *Root); + Value *gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy); /// \returns whether the VectorizableTree is fully vectorizable and will /// be beneficial even the tree height is tiny. @@ -2454,10 +2832,10 @@ private: /// Reorder commutative or alt operands to get better probability of /// generating vectorized code. - static void reorderInputsAccordingToOpcode( - ArrayRef<Value *> VL, SmallVectorImpl<Value *> &Left, - SmallVectorImpl<Value *> &Right, const TargetLibraryInfo &TLI, - const DataLayout &DL, ScalarEvolution &SE, const BoUpSLP &R); + static void reorderInputsAccordingToOpcode(ArrayRef<Value *> VL, + SmallVectorImpl<Value *> &Left, + SmallVectorImpl<Value *> &Right, + const BoUpSLP &R); /// Helper for `findExternalStoreUsersReorderIndices()`. It iterates over the /// users of \p TE and collects the stores. It returns the map from the store @@ -2524,8 +2902,7 @@ private: } bool isOperandGatherNode(const EdgeInfo &UserEI) const { - return State == TreeEntry::NeedToGather && - UserTreeIndices.front().EdgeIdx == UserEI.EdgeIdx && + return isGather() && UserTreeIndices.front().EdgeIdx == UserEI.EdgeIdx && UserTreeIndices.front().UserTE == UserEI.UserTE; } @@ -2560,6 +2937,9 @@ private: return Scalars.size(); }; + /// Checks if the current node is a gather node. + bool isGather() const {return State == NeedToGather; } + /// A vector of scalars. ValueList Scalars; @@ -2575,7 +2955,7 @@ private: enum EntryState { Vectorize, ScatterVectorize, - PossibleStridedVectorize, + StridedVectorize, NeedToGather }; EntryState State; @@ -2733,6 +3113,14 @@ private: SmallVectorImpl<Value *> *OpScalars = nullptr, SmallVectorImpl<Value *> *AltScalars = nullptr) const; + /// Return true if this is a non-power-of-2 node. + bool isNonPowOf2Vec() const { + bool IsNonPowerOf2 = !isPowerOf2_32(Scalars.size()); + assert((!IsNonPowerOf2 || ReuseShuffleIndices.empty()) && + "Reshuffling not supported with non-power-of-2 vectors yet."); + return IsNonPowerOf2; + } + #ifndef NDEBUG /// Debug printer. 
LLVM_DUMP_METHOD void dump() const { @@ -2753,8 +3141,8 @@ private: case ScatterVectorize: dbgs() << "ScatterVectorize\n"; break; - case PossibleStridedVectorize: - dbgs() << "PossibleStridedVectorize\n"; + case StridedVectorize: + dbgs() << "StridedVectorize\n"; break; case NeedToGather: dbgs() << "NeedToGather\n"; @@ -2854,7 +3242,7 @@ private: Last->setOperations(S); Last->ReorderIndices.append(ReorderIndices.begin(), ReorderIndices.end()); } - if (Last->State != TreeEntry::NeedToGather) { + if (!Last->isGather()) { for (Value *V : VL) { const TreeEntry *TE = getTreeEntry(V); assert((!TE || TE == Last || doesNotNeedToBeScheduled(V)) && @@ -2884,16 +3272,25 @@ private: } assert(!BundleMember && "Bundle and VL out of sync"); } else { - MustGather.insert(VL.begin(), VL.end()); // Build a map for gathered scalars to the nodes where they are used. + bool AllConstsOrCasts = true; for (Value *V : VL) - if (!isConstant(V)) + if (!isConstant(V)) { + auto *I = dyn_cast<CastInst>(V); + AllConstsOrCasts &= I && I->getType()->isIntegerTy(); ValueToGatherNodes.try_emplace(V).first->getSecond().insert(Last); + } + if (AllConstsOrCasts) + CastMaxMinBWSizes = + std::make_pair(std::numeric_limits<unsigned>::max(), 1); + MustGather.insert(VL.begin(), VL.end()); } - if (UserTreeIdx.UserTE) + if (UserTreeIdx.UserTE) { Last->UserTreeIndices.push_back(UserTreeIdx); - + assert((!Last->isNonPowOf2Vec() || Last->ReorderIndices.empty()) && + "Reordering isn't implemented for non-power-of-2 nodes yet"); + } return Last; } @@ -2917,6 +3314,15 @@ private: return ScalarToTreeEntry.lookup(V); } + /// Check that the operand node of alternate node does not generate + /// buildvector sequence. If it is, then probably not worth it to build + /// alternate shuffle, if number of buildvector operands + alternate + /// instruction > than the number of buildvector instructions. + /// \param S the instructions state of the analyzed values. + /// \param VL list of the instructions with alternate opcodes. + bool areAltOperandsProfitable(const InstructionsState &S, + ArrayRef<Value *> VL) const; + /// Checks if the specified list of the instructions/values can be vectorized /// and fills required data before actual scheduling of the instructions. TreeEntry::EntryState getScalarsVectorizationState( @@ -2936,6 +3342,9 @@ private: /// A list of scalars that we found that we need to keep as scalars. ValueSet MustGather; + /// A set of first non-schedulable values. + ValueSet NonScheduledFirst; + /// A map between the vectorized entries and the last instructions in the /// bundles. The bundles are built in use order, not in the def order of the /// instructions. So, we cannot rely directly on the last instruction in the @@ -3013,12 +3422,20 @@ private: /// Set of hashes for the list of reduction values already being analyzed. DenseSet<size_t> AnalyzedReductionVals; + /// Values, already been analyzed for mininmal bitwidth and found to be + /// non-profitable. + DenseSet<Value *> AnalyzedMinBWVals; + /// A list of values that need to extracted out of the tree. /// This list holds pairs of (Internal Scalar : External User). External User /// can be nullptr, it means that this Internal Scalar will be used later, /// after vectorization. UserList ExternalUses; + /// A list of GEPs which can be reaplced by scalar GEPs instead of + /// extractelement instructions. + SmallPtrSet<Value *, 4> ExternalUsesAsGEPs; + /// Values used only by @llvm.assume calls. 
SmallPtrSet<const Value *, 32> EphValues; @@ -3336,10 +3753,11 @@ private: // immediates do not affect scheduler behavior this is considered // okay. auto *In = BundleMember->Inst; - assert(In && - (isa<ExtractValueInst, ExtractElementInst>(In) || - In->getNumOperands() == TE->getNumOperands()) && - "Missed TreeEntry operands?"); + assert( + In && + (isa<ExtractValueInst, ExtractElementInst, IntrinsicInst>(In) || + In->getNumOperands() == TE->getNumOperands()) && + "Missed TreeEntry operands?"); (void)In; // fake use to avoid build failure when assertions disabled for (unsigned OpIdx = 0, NumOperands = TE->getNumOperands(); @@ -3580,7 +3998,7 @@ private: unsigned MinVecRegSize; // Set by cl::opt (default: 128). /// Instruction builder to construct the vectorized tree. - IRBuilder<> Builder; + IRBuilder<TargetFolder> Builder; /// A map of scalar integer values to the smallest bit width with which they /// can legally be represented. The values map to (width, signed) pairs, @@ -3588,6 +4006,19 @@ private: /// value must be signed-extended, rather than zero-extended, back to its /// original width. DenseMap<const TreeEntry *, std::pair<uint64_t, bool>> MinBWs; + + /// Final size of the reduced vector, if the current graph represents the + /// input for the reduction and it was possible to narrow the size of the + /// reduction. + unsigned ReductionBitWidth = 0; + + /// If the tree contains any zext/sext/trunc nodes, contains max-min pair of + /// type sizes, used in the tree. + std::optional<std::pair<unsigned, unsigned>> CastMaxMinBWSizes; + + /// Indices of the vectorized nodes, which supposed to be the roots of the new + /// bitwidth analysis attempt, like trunc, IToFP or ICmp. + DenseSet<unsigned> ExtraBitWidthNodes; }; } // end namespace slpvectorizer @@ -3677,10 +4108,10 @@ template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits { static std::string getNodeAttributes(const TreeEntry *Entry, const BoUpSLP *) { - if (Entry->State == TreeEntry::NeedToGather) + if (Entry->isGather()) return "color=red"; if (Entry->State == TreeEntry::ScatterVectorize || - Entry->State == TreeEntry::PossibleStridedVectorize) + Entry->State == TreeEntry::StridedVectorize) return "color=blue"; return ""; } @@ -3691,6 +4122,17 @@ template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits { BoUpSLP::~BoUpSLP() { SmallVector<WeakTrackingVH> DeadInsts; for (auto *I : DeletedInstructions) { + if (!I->getParent()) { + // Temporarily insert instruction back to erase them from parent and + // memory later. + if (isa<PHINode>(I)) + // Phi nodes must be the very first instructions in the block. + I->insertBefore(F->getEntryBlock(), + F->getEntryBlock().getFirstNonPHIIt()); + else + I->insertBefore(F->getEntryBlock().getTerminator()); + continue; + } for (Use &U : I->operands()) { auto *Op = dyn_cast<Instruction>(U.get()); if (Op && !DeletedInstructions.count(Op) && Op->hasOneUser() && @@ -3732,22 +4174,45 @@ static void reorderReuses(SmallVectorImpl<int> &Reuses, ArrayRef<int> Mask) { /// the original order of the scalars. Procedure transforms the provided order /// in accordance with the given \p Mask. If the resulting \p Order is just an /// identity order, \p Order is cleared. 
-static void reorderOrder(SmallVectorImpl<unsigned> &Order, ArrayRef<int> Mask) { +static void reorderOrder(SmallVectorImpl<unsigned> &Order, ArrayRef<int> Mask, + bool BottomOrder = false) { assert(!Mask.empty() && "Expected non-empty mask."); + unsigned Sz = Mask.size(); + if (BottomOrder) { + SmallVector<unsigned> PrevOrder; + if (Order.empty()) { + PrevOrder.resize(Sz); + std::iota(PrevOrder.begin(), PrevOrder.end(), 0); + } else { + PrevOrder.swap(Order); + } + Order.assign(Sz, Sz); + for (unsigned I = 0; I < Sz; ++I) + if (Mask[I] != PoisonMaskElem) + Order[I] = PrevOrder[Mask[I]]; + if (all_of(enumerate(Order), [&](const auto &Data) { + return Data.value() == Sz || Data.index() == Data.value(); + })) { + Order.clear(); + return; + } + fixupOrderingIndices(Order); + return; + } SmallVector<int> MaskOrder; if (Order.empty()) { - MaskOrder.resize(Mask.size()); + MaskOrder.resize(Sz); std::iota(MaskOrder.begin(), MaskOrder.end(), 0); } else { inversePermutation(Order, MaskOrder); } reorderReuses(MaskOrder, Mask); - if (ShuffleVectorInst::isIdentityMask(MaskOrder, MaskOrder.size())) { + if (ShuffleVectorInst::isIdentityMask(MaskOrder, Sz)) { Order.clear(); return; } - Order.assign(Mask.size(), Mask.size()); - for (unsigned I = 0, E = Mask.size(); I < E; ++I) + Order.assign(Sz, Sz); + for (unsigned I = 0; I < Sz; ++I) if (MaskOrder[I] != PoisonMaskElem) Order[MaskOrder[I]] = I; fixupOrderingIndices(Order); @@ -3755,78 +4220,168 @@ static void reorderOrder(SmallVectorImpl<unsigned> &Order, ArrayRef<int> Mask) { std::optional<BoUpSLP::OrdersType> BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) { - assert(TE.State == TreeEntry::NeedToGather && "Expected gather node only."); - unsigned NumScalars = TE.Scalars.size(); + assert(TE.isGather() && "Expected gather node only."); + // Try to find subvector extract/insert patterns and reorder only such + // patterns. + SmallVector<Value *> GatheredScalars(TE.Scalars.begin(), TE.Scalars.end()); + Type *ScalarTy = GatheredScalars.front()->getType(); + int NumScalars = GatheredScalars.size(); + if (!isValidElementType(ScalarTy)) + return std::nullopt; + auto *VecTy = getWidenedType(ScalarTy, NumScalars); + int NumParts = TTI->getNumberOfParts(VecTy); + if (NumParts == 0 || NumParts >= NumScalars) + NumParts = 1; + SmallVector<int> ExtractMask; + SmallVector<int> Mask; + SmallVector<SmallVector<const TreeEntry *>> Entries; + SmallVector<std::optional<TargetTransformInfo::ShuffleKind>> ExtractShuffles = + tryToGatherExtractElements(GatheredScalars, ExtractMask, NumParts); + SmallVector<std::optional<TargetTransformInfo::ShuffleKind>> GatherShuffles = + isGatherShuffledEntry(&TE, GatheredScalars, Mask, Entries, NumParts, + /*ForOrder=*/true); + // No shuffled operands - ignore. + if (GatherShuffles.empty() && ExtractShuffles.empty()) + return std::nullopt; OrdersType CurrentOrder(NumScalars, NumScalars); - SmallVector<int> Positions; - SmallBitVector UsedPositions(NumScalars); - const TreeEntry *STE = nullptr; - // Try to find all gathered scalars that are gets vectorized in other - // vectorize node. Here we can have only one single tree vector node to - // correctly identify order of the gathered scalars. - for (unsigned I = 0; I < NumScalars; ++I) { - Value *V = TE.Scalars[I]; - if (!isa<LoadInst, ExtractElementInst, ExtractValueInst>(V)) - continue; - if (const auto *LocalSTE = getTreeEntry(V)) { - if (!STE) - STE = LocalSTE; - else if (STE != LocalSTE) - // Take the order only from the single vector node. 
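A small sketch of the BottomOrder path of reorderOrder above: the new order is the previous order composed with the shuffle mask, with the vector size acting as a "no element" sentinel for poison lanes. The order and mask values are made up for the example; the real routine additionally clears identity orders and runs fixupOrderingIndices on the result.

#include <cstdio>
#include <vector>

int main() {
  const int Poison = -1;
  std::vector<unsigned> PrevOrder = {1, 0, 3, 2}; // existing ordering (illustrative)
  std::vector<int> Mask = {2, Poison, 0, 1};      // shuffle applied on top
  unsigned Sz = Mask.size();
  std::vector<unsigned> Order(Sz, Sz); // Sz == 4 marks an unset lane
  for (unsigned I = 0; I < Sz; ++I)
    if (Mask[I] != Poison)
      Order[I] = PrevOrder[Mask[I]];
  for (unsigned V : Order)
    std::printf("%u ", V); // prints "3 4 1 0": lane 1 keeps the sentinel 4
  std::printf("\n");
  // fixupOrderingIndices would then fill sentinel lanes with the still-unused
  // indices (here index 2), or the order is cleared if it ends up identity.
}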
- return std::nullopt; - unsigned Lane = - std::distance(STE->Scalars.begin(), find(STE->Scalars, V)); - if (Lane >= NumScalars) - return std::nullopt; - if (CurrentOrder[Lane] != NumScalars) { - if (Lane != I) + if (GatherShuffles.size() == 1 && + *GatherShuffles.front() == TTI::SK_PermuteSingleSrc && + Entries.front().front()->isSame(TE.Scalars)) { + // Perfect match in the graph, will reuse the previously vectorized + // node. Cost is 0. + std::iota(CurrentOrder.begin(), CurrentOrder.end(), 0); + return CurrentOrder; + } + auto IsSplatMask = [](ArrayRef<int> Mask) { + int SingleElt = PoisonMaskElem; + return all_of(Mask, [&](int I) { + if (SingleElt == PoisonMaskElem && I != PoisonMaskElem) + SingleElt = I; + return I == PoisonMaskElem || I == SingleElt; + }); + }; + // Exclusive broadcast mask - ignore. + if ((ExtractShuffles.empty() && IsSplatMask(Mask) && + (Entries.size() != 1 || + Entries.front().front()->ReorderIndices.empty())) || + (GatherShuffles.empty() && IsSplatMask(ExtractMask))) + return std::nullopt; + SmallBitVector ShuffledSubMasks(NumParts); + auto TransformMaskToOrder = [&](MutableArrayRef<unsigned> CurrentOrder, + ArrayRef<int> Mask, int PartSz, int NumParts, + function_ref<unsigned(unsigned)> GetVF) { + for (int I : seq<int>(0, NumParts)) { + if (ShuffledSubMasks.test(I)) + continue; + const int VF = GetVF(I); + if (VF == 0) + continue; + unsigned Limit = getNumElems(CurrentOrder.size(), PartSz, I); + MutableArrayRef<unsigned> Slice = CurrentOrder.slice(I * PartSz, Limit); + // Shuffle of at least 2 vectors - ignore. + if (any_of(Slice, [&](int I) { return I != NumScalars; })) { + std::fill(Slice.begin(), Slice.end(), NumScalars); + ShuffledSubMasks.set(I); + continue; + } + // Try to include as much elements from the mask as possible. + int FirstMin = INT_MAX; + int SecondVecFound = false; + for (int K : seq<int>(Limit)) { + int Idx = Mask[I * PartSz + K]; + if (Idx == PoisonMaskElem) { + Value *V = GatheredScalars[I * PartSz + K]; + if (isConstant(V) && !isa<PoisonValue>(V)) { + SecondVecFound = true; + break; + } continue; - UsedPositions.reset(CurrentOrder[Lane]); + } + if (Idx < VF) { + if (FirstMin > Idx) + FirstMin = Idx; + } else { + SecondVecFound = true; + break; + } } - // The partial identity (where only some elements of the gather node are - // in the identity order) is good. - CurrentOrder[Lane] = I; - UsedPositions.set(I); - } - } - // Need to keep the order if we have a vector entry and at least 2 scalars or - // the vectorized entry has just 2 scalars. - if (STE && (UsedPositions.count() > 1 || STE->Scalars.size() == 2)) { - auto &&IsIdentityOrder = [NumScalars](ArrayRef<unsigned> CurrentOrder) { - for (unsigned I = 0; I < NumScalars; ++I) - if (CurrentOrder[I] != I && CurrentOrder[I] != NumScalars) - return false; - return true; - }; - if (IsIdentityOrder(CurrentOrder)) - return OrdersType(); - auto *It = CurrentOrder.begin(); - for (unsigned I = 0; I < NumScalars;) { - if (UsedPositions.test(I)) { - ++I; + FirstMin = (FirstMin / PartSz) * PartSz; + // Shuffle of at least 2 vectors - ignore. 
+ if (SecondVecFound) { + std::fill(Slice.begin(), Slice.end(), NumScalars); + ShuffledSubMasks.set(I); continue; } - if (*It == NumScalars) { - *It = I; - ++I; + for (int K : seq<int>(Limit)) { + int Idx = Mask[I * PartSz + K]; + if (Idx == PoisonMaskElem) + continue; + Idx -= FirstMin; + if (Idx >= PartSz) { + SecondVecFound = true; + break; + } + if (CurrentOrder[I * PartSz + Idx] > + static_cast<unsigned>(I * PartSz + K) && + CurrentOrder[I * PartSz + Idx] != + static_cast<unsigned>(I * PartSz + Idx)) + CurrentOrder[I * PartSz + Idx] = I * PartSz + K; + } + // Shuffle of at least 2 vectors - ignore. + if (SecondVecFound) { + std::fill(Slice.begin(), Slice.end(), NumScalars); + ShuffledSubMasks.set(I); + continue; } - ++It; } - return std::move(CurrentOrder); + }; + int PartSz = getPartNumElems(NumScalars, NumParts); + if (!ExtractShuffles.empty()) + TransformMaskToOrder( + CurrentOrder, ExtractMask, PartSz, NumParts, [&](unsigned I) { + if (!ExtractShuffles[I]) + return 0U; + unsigned VF = 0; + unsigned Sz = getNumElems(TE.getVectorFactor(), PartSz, I); + for (unsigned Idx : seq<unsigned>(Sz)) { + int K = I * PartSz + Idx; + if (ExtractMask[K] == PoisonMaskElem) + continue; + if (!TE.ReuseShuffleIndices.empty()) + K = TE.ReuseShuffleIndices[K]; + if (!TE.ReorderIndices.empty()) + K = std::distance(TE.ReorderIndices.begin(), + find(TE.ReorderIndices, K)); + auto *EI = dyn_cast<ExtractElementInst>(TE.Scalars[K]); + if (!EI) + continue; + VF = std::max(VF, cast<VectorType>(EI->getVectorOperandType()) + ->getElementCount() + .getKnownMinValue()); + } + return VF; + }); + // Check special corner case - single shuffle of the same entry. + if (GatherShuffles.size() == 1 && NumParts != 1) { + if (ShuffledSubMasks.any()) + return std::nullopt; + PartSz = NumScalars; + NumParts = 1; } - return std::nullopt; + if (!Entries.empty()) + TransformMaskToOrder(CurrentOrder, Mask, PartSz, NumParts, [&](unsigned I) { + if (!GatherShuffles[I]) + return 0U; + return std::max(Entries[I].front()->getVectorFactor(), + Entries[I].back()->getVectorFactor()); + }); + int NumUndefs = + count_if(CurrentOrder, [&](int Idx) { return Idx == NumScalars; }); + if (ShuffledSubMasks.all() || (NumScalars > 2 && NumUndefs >= NumScalars / 2)) + return std::nullopt; + return std::move(CurrentOrder); } -namespace { -/// Tracks the state we can represent the loads in the given sequence. -enum class LoadsState { - Gather, - Vectorize, - ScatterVectorize, - PossibleStridedVectorize -}; -} // anonymous namespace - static bool arePointersCompatible(Value *Ptr1, Value *Ptr2, const TargetLibraryInfo &TLI, bool CompareOpcodes = true) { @@ -3846,14 +4401,151 @@ static bool arePointersCompatible(Value *Ptr1, Value *Ptr2, .getOpcode()); } -/// Checks if the given array of loads can be represented as a vectorized, -/// scatter or just simple gather. -static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, - const TargetTransformInfo &TTI, - const DataLayout &DL, ScalarEvolution &SE, - LoopInfo &LI, const TargetLibraryInfo &TLI, - SmallVectorImpl<unsigned> &Order, - SmallVectorImpl<Value *> &PointerOps) { +/// Calculates minimal alignment as a common alignment. +template <typename T> +static Align computeCommonAlignment(ArrayRef<Value *> VL) { + Align CommonAlignment = cast<T>(VL.front())->getAlign(); + for (Value *V : VL.drop_front()) + CommonAlignment = std::min(CommonAlignment, cast<T>(V)->getAlign()); + return CommonAlignment; +} + +/// Check if \p Order represents reverse order. 
+static bool isReverseOrder(ArrayRef<unsigned> Order) { + unsigned Sz = Order.size(); + return !Order.empty() && all_of(enumerate(Order), [&](const auto &Pair) { + return Pair.value() == Sz || Sz - Pair.index() - 1 == Pair.value(); + }); +} + +/// Checks if the provided list of pointers \p Pointers represents the strided +/// pointers for type ElemTy. If they are not, std::nullopt is returned. +/// Otherwise, if \p Inst is not specified, just initialized optional value is +/// returned to show that the pointers represent strided pointers. If \p Inst +/// specified, the runtime stride is materialized before the given \p Inst. +/// \returns std::nullopt if the pointers are not pointers with the runtime +/// stride, nullptr or actual stride value, otherwise. +static std::optional<Value *> +calculateRtStride(ArrayRef<Value *> PointerOps, Type *ElemTy, + const DataLayout &DL, ScalarEvolution &SE, + SmallVectorImpl<unsigned> &SortedIndices, + Instruction *Inst = nullptr) { + SmallVector<const SCEV *> SCEVs; + const SCEV *PtrSCEVLowest = nullptr; + const SCEV *PtrSCEVHighest = nullptr; + // Find lower/upper pointers from the PointerOps (i.e. with lowest and highest + // addresses). + for (Value *Ptr : PointerOps) { + const SCEV *PtrSCEV = SE.getSCEV(Ptr); + if (!PtrSCEV) + return std::nullopt; + SCEVs.push_back(PtrSCEV); + if (!PtrSCEVLowest && !PtrSCEVHighest) { + PtrSCEVLowest = PtrSCEVHighest = PtrSCEV; + continue; + } + const SCEV *Diff = SE.getMinusSCEV(PtrSCEV, PtrSCEVLowest); + if (isa<SCEVCouldNotCompute>(Diff)) + return std::nullopt; + if (Diff->isNonConstantNegative()) { + PtrSCEVLowest = PtrSCEV; + continue; + } + const SCEV *Diff1 = SE.getMinusSCEV(PtrSCEVHighest, PtrSCEV); + if (isa<SCEVCouldNotCompute>(Diff1)) + return std::nullopt; + if (Diff1->isNonConstantNegative()) { + PtrSCEVHighest = PtrSCEV; + continue; + } + } + // Dist = PtrSCEVHighest - PtrSCEVLowest; + const SCEV *Dist = SE.getMinusSCEV(PtrSCEVHighest, PtrSCEVLowest); + if (isa<SCEVCouldNotCompute>(Dist)) + return std::nullopt; + int Size = DL.getTypeStoreSize(ElemTy); + auto TryGetStride = [&](const SCEV *Dist, + const SCEV *Multiplier) -> const SCEV * { + if (const auto *M = dyn_cast<SCEVMulExpr>(Dist)) { + if (M->getOperand(0) == Multiplier) + return M->getOperand(1); + if (M->getOperand(1) == Multiplier) + return M->getOperand(0); + return nullptr; + } + if (Multiplier == Dist) + return SE.getConstant(Dist->getType(), 1); + return SE.getUDivExactExpr(Dist, Multiplier); + }; + // Stride_in_elements = Dist / element_size * (num_elems - 1). + const SCEV *Stride = nullptr; + if (Size != 1 || SCEVs.size() > 2) { + const SCEV *Sz = SE.getConstant(Dist->getType(), Size * (SCEVs.size() - 1)); + Stride = TryGetStride(Dist, Sz); + if (!Stride) + return std::nullopt; + } + if (!Stride || isa<SCEVConstant>(Stride)) + return std::nullopt; + // Iterate through all pointers and check if all distances are + // unique multiple of Stride. 
+ using DistOrdPair = std::pair<int64_t, int>; + auto Compare = llvm::less_first(); + std::set<DistOrdPair, decltype(Compare)> Offsets(Compare); + int Cnt = 0; + bool IsConsecutive = true; + for (const SCEV *PtrSCEV : SCEVs) { + unsigned Dist = 0; + if (PtrSCEV != PtrSCEVLowest) { + const SCEV *Diff = SE.getMinusSCEV(PtrSCEV, PtrSCEVLowest); + const SCEV *Coeff = TryGetStride(Diff, Stride); + if (!Coeff) + return std::nullopt; + const auto *SC = dyn_cast<SCEVConstant>(Coeff); + if (!SC || isa<SCEVCouldNotCompute>(SC)) + return std::nullopt; + if (!SE.getMinusSCEV(PtrSCEV, SE.getAddExpr(PtrSCEVLowest, + SE.getMulExpr(Stride, SC))) + ->isZero()) + return std::nullopt; + Dist = SC->getAPInt().getZExtValue(); + } + // If the strides are not the same or repeated, we can't vectorize. + if ((Dist / Size) * Size != Dist || (Dist / Size) >= SCEVs.size()) + return std::nullopt; + auto Res = Offsets.emplace(Dist, Cnt); + if (!Res.second) + return std::nullopt; + // Consecutive order if the inserted element is the last one. + IsConsecutive = IsConsecutive && std::next(Res.first) == Offsets.end(); + ++Cnt; + } + if (Offsets.size() != SCEVs.size()) + return std::nullopt; + SortedIndices.clear(); + if (!IsConsecutive) { + // Fill SortedIndices array only if it is non-consecutive. + SortedIndices.resize(PointerOps.size()); + Cnt = 0; + for (const std::pair<int64_t, int> &Pair : Offsets) { + SortedIndices[Cnt] = Pair.second; + ++Cnt; + } + } + if (!Inst) + return nullptr; + SCEVExpander Expander(SE, DL, "strided-load-vec"); + return Expander.expandCodeFor(Stride, Stride->getType(), Inst); +} + +static std::pair<InstructionCost, InstructionCost> +getGEPCosts(const TargetTransformInfo &TTI, ArrayRef<Value *> Ptrs, + Value *BasePtr, unsigned Opcode, TTI::TargetCostKind CostKind, + Type *ScalarTy, VectorType *VecTy); + +BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads( + ArrayRef<Value *> VL, const Value *VL0, SmallVectorImpl<unsigned> &Order, + SmallVectorImpl<Value *> &PointerOps, bool TryRecursiveCheck) const { // Check that a vectorized load would load the same memory as a scalar // load. For example, we don't want to vectorize loads that are smaller // than 8-bit. Even though we have a packed struct {<i2, i2, i2, i2>} LLVM @@ -3862,13 +4554,14 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, // unvectorized version. Type *ScalarTy = VL0->getType(); - if (DL.getTypeSizeInBits(ScalarTy) != DL.getTypeAllocSizeInBits(ScalarTy)) + if (DL->getTypeSizeInBits(ScalarTy) != DL->getTypeAllocSizeInBits(ScalarTy)) return LoadsState::Gather; // Make sure all loads in the bundle are simple - we can't vectorize // atomic or volatile loads. PointerOps.clear(); - PointerOps.resize(VL.size()); + const unsigned Sz = VL.size(); + PointerOps.resize(Sz); auto *POIter = PointerOps.begin(); for (Value *V : VL) { auto *L = cast<LoadInst>(V); @@ -3879,12 +4572,24 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, } Order.clear(); + auto *VecTy = getWidenedType(ScalarTy, Sz); // Check the order of pointer operands or that all pointers are the same. - bool IsSorted = sortPtrAccesses(PointerOps, ScalarTy, DL, SE, Order); + bool IsSorted = sortPtrAccesses(PointerOps, ScalarTy, *DL, *SE, Order); + // FIXME: Reordering isn't implemented for non-power-of-2 nodes yet. 
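An illustrative, self-contained sketch of the core idea behind calculateRtStride above: the byte offset of every pointer from the lowest one must be a distinct multiple of a single common stride, and the sorted multiples give the lane order (non-consecutive orders are recorded as SortedIndices). The offsets and stride below are invented for the example and do not come from real IR.

#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

int main() {
  std::vector<int64_t> Offsets = {16, 0, 24, 8}; // lane -> byte offset (illustrative)
  int64_t Stride = 8;                            // assumed common stride in bytes
  std::map<int64_t, int> SlotToLane;             // multiple-of-stride -> original lane
  for (int Lane = 0; Lane < (int)Offsets.size(); ++Lane) {
    if (Offsets[Lane] % Stride != 0 ||
        !SlotToLane.emplace(Offsets[Lane] / Stride, Lane).second) {
      std::puts("not strided"); // bail out, like the real code returning nullopt
      return 0;
    }
  }
  for (const auto &[Slot, Lane] : SlotToLane)
    std::printf("slot %lld takes lane %d\n", (long long)Slot, Lane);
  // Prints slots 0..3 mapped to lanes 1, 3, 0, 2: the sorted load order.
}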
+ if (!Order.empty() && !isPowerOf2_32(VL.size())) { + assert(VectorizeNonPowerOf2 && "non-power-of-2 number of loads only " + "supported with VectorizeNonPowerOf2"); + return LoadsState::Gather; + } + + Align CommonAlignment = computeCommonAlignment<LoadInst>(VL); + if (!IsSorted && Sz > MinProfitableStridedLoads && TTI->isTypeLegal(VecTy) && + TTI->isLegalStridedLoadStore(VecTy, CommonAlignment) && + calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order)) + return LoadsState::StridedVectorize; if (IsSorted || all_of(PointerOps, [&](Value *P) { - return arePointersCompatible(P, PointerOps.front(), TLI); + return arePointersCompatible(P, PointerOps.front(), *TLI); })) { - bool IsPossibleStrided = false; if (IsSorted) { Value *Ptr0; Value *PtrN; @@ -3896,35 +4601,184 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0, PtrN = PointerOps[Order.back()]; } std::optional<int> Diff = - getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, DL, SE); + getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, *DL, *SE); // Check that the sorted loads are consecutive. - if (static_cast<unsigned>(*Diff) == VL.size() - 1) + if (static_cast<unsigned>(*Diff) == Sz - 1) return LoadsState::Vectorize; // Simple check if not a strided access - clear order. - IsPossibleStrided = *Diff % (VL.size() - 1) == 0; + bool IsPossibleStrided = *Diff % (Sz - 1) == 0; + // Try to generate strided load node if: + // 1. Target with strided load support is detected. + // 2. The number of loads is greater than MinProfitableStridedLoads, + // or the potential stride <= MaxProfitableLoadStride and the + // potential stride is power-of-2 (to avoid perf regressions for the very + // small number of loads) and max distance > number of loads, or potential + // stride is -1. + // 3. The loads are ordered, or number of unordered loads <= + // MaxProfitableUnorderedLoads, or loads are in reversed order. + // (this check is to avoid extra costs for very expensive shuffles). + if (IsPossibleStrided && (((Sz > MinProfitableStridedLoads || + (static_cast<unsigned>(std::abs(*Diff)) <= + MaxProfitableLoadStride * Sz && + isPowerOf2_32(std::abs(*Diff)))) && + static_cast<unsigned>(std::abs(*Diff)) > Sz) || + *Diff == -(static_cast<int>(Sz) - 1))) { + int Stride = *Diff / static_cast<int>(Sz - 1); + if (*Diff == Stride * static_cast<int>(Sz - 1)) { + Align Alignment = + cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()]) + ->getAlign(); + if (TTI->isLegalStridedLoadStore(VecTy, Alignment)) { + // Iterate through all pointers and check if all distances are + // unique multiple of Dist. + SmallSet<int, 4> Dists; + for (Value *Ptr : PointerOps) { + int Dist = 0; + if (Ptr == PtrN) + Dist = *Diff; + else if (Ptr != Ptr0) + Dist = + *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE); + // If the strides are not the same or repeated, we can't + // vectorize. 
+ if (((Dist / Stride) * Stride) != Dist || + !Dists.insert(Dist).second) + break; + } + if (Dists.size() == Sz) + return LoadsState::StridedVectorize; + } + } + } } + auto CheckForShuffledLoads = [&, &TTI = *TTI](Align CommonAlignment) { + unsigned Sz = DL->getTypeSizeInBits(ScalarTy); + unsigned MinVF = getMinVF(Sz); + unsigned MaxVF = std::max<unsigned>(bit_floor(VL.size() / 2), MinVF); + MaxVF = std::min(getMaximumVF(Sz, Instruction::Load), MaxVF); + for (unsigned VF = MaxVF; VF >= MinVF; VF /= 2) { + unsigned VectorizedCnt = 0; + SmallVector<LoadsState> States; + for (unsigned Cnt = 0, End = VL.size(); Cnt + VF <= End; + Cnt += VF, ++VectorizedCnt) { + ArrayRef<Value *> Slice = VL.slice(Cnt, VF); + SmallVector<unsigned> Order; + SmallVector<Value *> PointerOps; + LoadsState LS = + canVectorizeLoads(Slice, Slice.front(), Order, PointerOps, + /*TryRecursiveCheck=*/false); + // Check that the sorted loads are consecutive. + if (LS == LoadsState::Gather) + break; + // If need the reorder - consider as high-cost masked gather for now. + if ((LS == LoadsState::Vectorize || + LS == LoadsState::StridedVectorize) && + !Order.empty() && !isReverseOrder(Order)) + LS = LoadsState::ScatterVectorize; + States.push_back(LS); + } + // Can be vectorized later as a serie of loads/insertelements. + if (VectorizedCnt == VL.size() / VF) { + // Compare masked gather cost and loads + insersubvector costs. + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + auto [ScalarGEPCost, VectorGEPCost] = getGEPCosts( + TTI, PointerOps, PointerOps.front(), Instruction::GetElementPtr, + CostKind, ScalarTy, VecTy); + InstructionCost MaskedGatherCost = + TTI.getGatherScatterOpCost( + Instruction::Load, VecTy, + cast<LoadInst>(VL0)->getPointerOperand(), + /*VariableMask=*/false, CommonAlignment, CostKind) + + VectorGEPCost - ScalarGEPCost; + InstructionCost VecLdCost = 0; + auto *SubVecTy = getWidenedType(ScalarTy, VF); + for (auto [I, LS] : enumerate(States)) { + auto *LI0 = cast<LoadInst>(VL[I * VF]); + switch (LS) { + case LoadsState::Vectorize: { + auto [ScalarGEPCost, VectorGEPCost] = + getGEPCosts(TTI, ArrayRef(PointerOps).slice(I * VF, VF), + LI0->getPointerOperand(), Instruction::Load, + CostKind, ScalarTy, SubVecTy); + VecLdCost += TTI.getMemoryOpCost( + Instruction::Load, SubVecTy, LI0->getAlign(), + LI0->getPointerAddressSpace(), CostKind, + TTI::OperandValueInfo()) + + VectorGEPCost - ScalarGEPCost; + break; + } + case LoadsState::StridedVectorize: { + auto [ScalarGEPCost, VectorGEPCost] = + getGEPCosts(TTI, ArrayRef(PointerOps).slice(I * VF, VF), + LI0->getPointerOperand(), Instruction::Load, + CostKind, ScalarTy, SubVecTy); + VecLdCost += + TTI.getStridedMemoryOpCost( + Instruction::Load, SubVecTy, LI0->getPointerOperand(), + /*VariableMask=*/false, CommonAlignment, CostKind) + + VectorGEPCost - ScalarGEPCost; + break; + } + case LoadsState::ScatterVectorize: { + auto [ScalarGEPCost, VectorGEPCost] = getGEPCosts( + TTI, ArrayRef(PointerOps).slice(I * VF, VF), + LI0->getPointerOperand(), Instruction::GetElementPtr, + CostKind, ScalarTy, SubVecTy); + VecLdCost += + TTI.getGatherScatterOpCost( + Instruction::Load, SubVecTy, LI0->getPointerOperand(), + /*VariableMask=*/false, CommonAlignment, CostKind) + + VectorGEPCost - ScalarGEPCost; + break; + } + case LoadsState::Gather: + llvm_unreachable( + "Expected only consecutive, strided or masked gather loads."); + } + SmallVector<int> ShuffleMask(VL.size()); + for (int Idx : seq<int>(0, VL.size())) + ShuffleMask[Idx] = Idx / VF == I ? 
VL.size() + Idx % VF : Idx; + VecLdCost += + TTI.getShuffleCost(TTI::SK_InsertSubvector, VecTy, ShuffleMask, + CostKind, I * VF, SubVecTy); + } + // If masked gather cost is higher - better to vectorize, so + // consider it as a gather node. It will be better estimated + // later. + if (MaskedGatherCost >= VecLdCost) + return true; + } + } + return false; + }; // TODO: need to improve analysis of the pointers, if not all of them are // GEPs or have > 2 operands, we end up with a gather node, which just // increases the cost. - Loop *L = LI.getLoopFor(cast<LoadInst>(VL0)->getParent()); + Loop *L = LI->getLoopFor(cast<LoadInst>(VL0)->getParent()); bool ProfitableGatherPointers = + L && Sz > 2 && static_cast<unsigned>(count_if(PointerOps, [L](Value *V) { - return L && L->isLoopInvariant(V); - })) <= VL.size() / 2 && VL.size() > 2; + return L->isLoopInvariant(V); + })) <= Sz / 2; if (ProfitableGatherPointers || all_of(PointerOps, [IsSorted](Value *P) { auto *GEP = dyn_cast<GetElementPtrInst>(P); return (IsSorted && !GEP && doesNotNeedToBeScheduled(P)) || - (GEP && GEP->getNumOperands() == 2); + (GEP && GEP->getNumOperands() == 2 && + isa<Constant, Instruction>(GEP->getOperand(1))); })) { - Align CommonAlignment = cast<LoadInst>(VL0)->getAlign(); - for (Value *V : VL) - CommonAlignment = - std::min(CommonAlignment, cast<LoadInst>(V)->getAlign()); - auto *VecTy = FixedVectorType::get(ScalarTy, VL.size()); - if (TTI.isLegalMaskedGather(VecTy, CommonAlignment) && - !TTI.forceScalarizeMaskedGather(VecTy, CommonAlignment)) - return IsPossibleStrided ? LoadsState::PossibleStridedVectorize - : LoadsState::ScatterVectorize; + Align CommonAlignment = computeCommonAlignment<LoadInst>(VL); + if (TTI->isLegalMaskedGather(VecTy, CommonAlignment) && + !TTI->forceScalarizeMaskedGather(VecTy, CommonAlignment)) { + // Check if potential masked gather can be represented as series + // of loads + insertsubvectors. + if (TryRecursiveCheck && CheckForShuffledLoads(CommonAlignment)) { + // If masked gather cost is higher - better to vectorize, so + // consider it as a gather node. It will be better estimated + // later. 
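// (Reaching this point means CheckForShuffledLoads found that VF-sized
// sub-loads plus insert-subvector shuffles are no more expensive than a
// single masked gather of the whole bundle, so the node is gathered now
// and is expected to be re-vectorized as smaller pieces later.)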
+ return LoadsState::Gather; + } + return LoadsState::ScatterVectorize; + } } } @@ -4000,7 +4854,7 @@ static bool clusterSortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy, std::optional<BoUpSLP::OrdersType> BoUpSLP::findPartiallyOrderedLoads(const BoUpSLP::TreeEntry &TE) { - assert(TE.State == TreeEntry::NeedToGather && "Expected gather node only."); + assert(TE.isGather() && "Expected gather node only."); Type *ScalarTy = TE.Scalars[0]->getType(); SmallVector<Value *> Ptrs; @@ -4033,8 +4887,8 @@ static bool areTwoInsertFromSameBuildVector( return false; auto *IE1 = VU; auto *IE2 = V; - std::optional<unsigned> Idx1 = getInsertIndex(IE1); - std::optional<unsigned> Idx2 = getInsertIndex(IE2); + std::optional<unsigned> Idx1 = getElementIndex(IE1); + std::optional<unsigned> Idx2 = getElementIndex(IE2); if (Idx1 == std::nullopt || Idx2 == std::nullopt) return false; // Go through the vector operand of insertelement instructions trying to find @@ -4049,7 +4903,7 @@ static bool areTwoInsertFromSameBuildVector( if (IE1 == V && !IE2) return V->hasOneUse(); if (IE1 && IE1 != V) { - unsigned Idx1 = getInsertIndex(IE1).value_or(*Idx2); + unsigned Idx1 = getElementIndex(IE1).value_or(*Idx2); IsReusedIdx |= ReusedIdx.test(Idx1); ReusedIdx.set(Idx1); if ((IE1 != VU && !IE1->hasOneUse()) || IsReusedIdx) @@ -4058,7 +4912,7 @@ static bool areTwoInsertFromSameBuildVector( IE1 = dyn_cast_or_null<InsertElementInst>(GetBaseOperand(IE1)); } if (IE2 && IE2 != VU) { - unsigned Idx2 = getInsertIndex(IE2).value_or(*Idx1); + unsigned Idx2 = getElementIndex(IE2).value_or(*Idx1); IsReusedIdx |= ReusedIdx.test(Idx2); ReusedIdx.set(Idx2); if ((IE2 != V && !IE2->hasOneUse()) || IsReusedIdx) @@ -4072,9 +4926,15 @@ static bool areTwoInsertFromSameBuildVector( std::optional<BoUpSLP::OrdersType> BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) { + // FIXME: Vectorizing is not supported yet for non-power-of-2 ops. + if (TE.isNonPowOf2Vec()) + return std::nullopt; + // No need to reorder if need to shuffle reuses, still need to shuffle the // node. if (!TE.ReuseShuffleIndices.empty()) { + if (isSplat(TE.Scalars)) + return std::nullopt; // Check if reuse shuffle indices can be improved by reordering. // For this, check that reuse mask is "clustered", i.e. each scalar values // is used once in each submask of size <number_of_scalars>. @@ -4083,9 +4943,60 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) { // 0, 1, 2, 3, 3, 3, 1, 0 - not clustered, because // element 3 is used twice in the second submask. 
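A minimal standalone illustration of the "clustered" reuse-mask property described in the comment above, written in plain C++ and kept separate from the patch; the helper name isClusteredMask is illustrative only. Every consecutive submask of size Sz must use each value in [0, Sz) exactly once.

#include <cstddef>
#include <vector>

// {0,1,2,3, 0,3,2,1} with Sz = 4 is clustered; {0,1,2,3, 3,3,1,0} is not,
// because 3 is used twice in the second submask.
static bool isClusteredMask(const std::vector<int> &Mask, unsigned Sz) {
  if (Sz == 0 || Mask.size() % Sz != 0)
    return false;
  for (std::size_t Begin = 0; Begin < Mask.size(); Begin += Sz) {
    std::vector<bool> Seen(Sz, false);
    for (std::size_t I = Begin; I < Begin + Sz; ++I) {
      int V = Mask[I];
      // Out-of-range or repeated value within this submask: not clustered.
      if (V < 0 || static_cast<unsigned>(V) >= Sz || Seen[V])
        return false;
      Seen[V] = true;
    }
  }
  return true;
}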
unsigned Sz = TE.Scalars.size(); - if (!ShuffleVectorInst::isOneUseSingleSourceMask(TE.ReuseShuffleIndices, - Sz)) + if (TE.isGather()) { + if (std::optional<OrdersType> CurrentOrder = + findReusedOrderedScalars(TE)) { + SmallVector<int> Mask; + fixupOrderingIndices(*CurrentOrder); + inversePermutation(*CurrentOrder, Mask); + ::addMask(Mask, TE.ReuseShuffleIndices); + OrdersType Res(TE.getVectorFactor(), TE.getVectorFactor()); + unsigned Sz = TE.Scalars.size(); + for (int K = 0, E = TE.getVectorFactor() / Sz; K < E; ++K) { + for (auto [I, Idx] : enumerate(ArrayRef(Mask).slice(K * Sz, Sz))) + if (Idx != PoisonMaskElem) + Res[Idx + K * Sz] = I + K * Sz; + } + return std::move(Res); + } + } + if (Sz == 2 && TE.getVectorFactor() == 4 && + TTI->getNumberOfParts(getWidenedType(TE.Scalars.front()->getType(), + 2 * TE.getVectorFactor())) == 1) return std::nullopt; + if (!ShuffleVectorInst::isOneUseSingleSourceMask(TE.ReuseShuffleIndices, + Sz)) { + SmallVector<int> ReorderMask(Sz, PoisonMaskElem); + if (TE.ReorderIndices.empty()) + std::iota(ReorderMask.begin(), ReorderMask.end(), 0); + else + inversePermutation(TE.ReorderIndices, ReorderMask); + ::addMask(ReorderMask, TE.ReuseShuffleIndices); + unsigned VF = ReorderMask.size(); + OrdersType ResOrder(VF, VF); + unsigned NumParts = divideCeil(VF, Sz); + SmallBitVector UsedVals(NumParts); + for (unsigned I = 0; I < VF; I += Sz) { + int Val = PoisonMaskElem; + unsigned UndefCnt = 0; + unsigned Limit = std::min(Sz, VF - I); + if (any_of(ArrayRef(ReorderMask).slice(I, Limit), + [&](int Idx) { + if (Val == PoisonMaskElem && Idx != PoisonMaskElem) + Val = Idx; + if (Idx == PoisonMaskElem) + ++UndefCnt; + return Idx != PoisonMaskElem && Idx != Val; + }) || + Val >= static_cast<int>(NumParts) || UsedVals.test(Val) || + UndefCnt > Sz / 2) + return std::nullopt; + UsedVals.set(Val); + for (unsigned K = 0; K < NumParts; ++K) + ResOrder[Val + Sz * K] = I + K; + } + return std::move(ResOrder); + } unsigned VF = TE.getVectorFactor(); // Try build correct order for extractelement instructions. SmallVector<int> ReusedMask(TE.ReuseShuffleIndices.begin(), @@ -4123,13 +5034,21 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) { transform(CurrentOrder, It, [K](unsigned Pos) { return Pos + K; }); std::advance(It, Sz); } - if (all_of(enumerate(ResOrder), - [](const auto &Data) { return Data.index() == Data.value(); })) + if (TE.isGather() && all_of(enumerate(ResOrder), [](const auto &Data) { + return Data.index() == Data.value(); + })) return std::nullopt; // No need to reorder. 
return std::move(ResOrder); } + if (TE.State == TreeEntry::StridedVectorize && !TopToBottom && + any_of(TE.UserTreeIndices, + [](const EdgeInfo &EI) { + return !Instruction::isBinaryOp(EI.UserTE->getOpcode()); + }) && + (TE.ReorderIndices.empty() || isReverseOrder(TE.ReorderIndices))) + return std::nullopt; if ((TE.State == TreeEntry::Vectorize || - TE.State == TreeEntry::PossibleStridedVectorize) && + TE.State == TreeEntry::StridedVectorize) && (isa<LoadInst, ExtractElementInst, ExtractValueInst>(TE.getMainOp()) || (TopToBottom && isa<StoreInst, InsertElementInst>(TE.getMainOp()))) && !TE.isAltShuffle()) @@ -4138,9 +5057,11 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) { auto PHICompare = [&](unsigned I1, unsigned I2) { Value *V1 = TE.Scalars[I1]; Value *V2 = TE.Scalars[I2]; - if (V1 == V2) + if (V1 == V2 || (V1->getNumUses() == 0 && V2->getNumUses() == 0)) return false; - if (!V1->hasOneUse() || !V2->hasOneUse()) + if (V1->getNumUses() < V2->getNumUses()) + return true; + if (V1->getNumUses() > V2->getNumUses()) return false; auto *FirstUserOfPhi1 = cast<Instruction>(*V1->user_begin()); auto *FirstUserOfPhi2 = cast<Instruction>(*V2->user_begin()); @@ -4149,24 +5070,16 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) { if (!areTwoInsertFromSameBuildVector( IE1, IE2, [](InsertElementInst *II) { return II->getOperand(0); })) - return false; - std::optional<unsigned> Idx1 = getInsertIndex(IE1); - std::optional<unsigned> Idx2 = getInsertIndex(IE2); - if (Idx1 == std::nullopt || Idx2 == std::nullopt) - return false; - return *Idx1 < *Idx2; + return I1 < I2; + return getElementIndex(IE1) < getElementIndex(IE2); } if (auto *EE1 = dyn_cast<ExtractElementInst>(FirstUserOfPhi1)) if (auto *EE2 = dyn_cast<ExtractElementInst>(FirstUserOfPhi2)) { if (EE1->getOperand(0) != EE2->getOperand(0)) - return false; - std::optional<unsigned> Idx1 = getExtractIndex(EE1); - std::optional<unsigned> Idx2 = getExtractIndex(EE2); - if (Idx1 == std::nullopt || Idx2 == std::nullopt) - return false; - return *Idx1 < *Idx2; + return I1 < I2; + return getElementIndex(EE1) < getElementIndex(EE2); } - return false; + return I1 < I2; }; auto IsIdentityOrder = [](const OrdersType &Order) { for (unsigned Idx : seq<unsigned>(0, Order.size())) @@ -4189,33 +5102,23 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) { return std::nullopt; // No need to reorder. return std::move(ResOrder); } - if (TE.State == TreeEntry::NeedToGather) { + if (TE.isGather() && !TE.isAltShuffle() && allSameType(TE.Scalars)) { // TODO: add analysis of other gather nodes with extractelement // instructions and other values/instructions, not only undefs. 
- if (((TE.getOpcode() == Instruction::ExtractElement && - !TE.isAltShuffle()) || - (all_of(TE.Scalars, - [](Value *V) { - return isa<UndefValue, ExtractElementInst>(V); - }) && - any_of(TE.Scalars, - [](Value *V) { return isa<ExtractElementInst>(V); }))) && - all_of(TE.Scalars, - [](Value *V) { - auto *EE = dyn_cast<ExtractElementInst>(V); - return !EE || isa<FixedVectorType>(EE->getVectorOperandType()); - }) && - allSameType(TE.Scalars)) { + if ((TE.getOpcode() == Instruction::ExtractElement || + (all_of(TE.Scalars, IsaPred<UndefValue, ExtractElementInst>) && + any_of(TE.Scalars, IsaPred<ExtractElementInst>))) && + all_of(TE.Scalars, [](Value *V) { + auto *EE = dyn_cast<ExtractElementInst>(V); + return !EE || isa<FixedVectorType>(EE->getVectorOperandType()); + })) { // Check that gather of extractelements can be represented as // just a shuffle of a single vector. OrdersType CurrentOrder; bool Reuse = canReuseExtract(TE.Scalars, TE.getMainOp(), CurrentOrder, /*ResizeAllowed=*/true); - if (Reuse || !CurrentOrder.empty()) { - if (!CurrentOrder.empty()) - fixupOrderingIndices(CurrentOrder); + if (Reuse || !CurrentOrder.empty()) return std::move(CurrentOrder); - } } // If the gather node is <undef, v, .., poison> and // insertelement poison, v, 0 [+ permute] @@ -4225,12 +5128,12 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) { // might be transformed. int Sz = TE.Scalars.size(); if (isSplat(TE.Scalars) && !allConstant(TE.Scalars) && - count_if(TE.Scalars, UndefValue::classof) == Sz - 1) { + count_if(TE.Scalars, IsaPred<UndefValue>) == Sz - 1) { const auto *It = find_if(TE.Scalars, [](Value *V) { return !isConstant(V); }); if (It == TE.Scalars.begin()) return OrdersType(); - auto *Ty = FixedVectorType::get(TE.Scalars.front()->getType(), Sz); + auto *Ty = getWidenedType(TE.Scalars.front()->getType(), Sz); if (It != TE.Scalars.end()) { OrdersType Order(Sz, Sz); unsigned Idx = std::distance(TE.Scalars.begin(), It); @@ -4248,15 +5151,20 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) { InstructionCost InsertIdxCost = TTI->getVectorInstrCost( Instruction::InsertElement, Ty, TTI::TCK_RecipThroughput, Idx, PoisonValue::get(Ty), *It); - if (InsertFirstCost + PermuteCost < InsertIdxCost) + if (InsertFirstCost + PermuteCost < InsertIdxCost) { + OrdersType Order(Sz, Sz); + Order[Idx] = 0; return std::move(Order); + } } } - if (std::optional<OrdersType> CurrentOrder = findReusedOrderedScalars(TE)) - return CurrentOrder; + if (isSplat(TE.Scalars)) + return std::nullopt; if (TE.Scalars.size() >= 4) if (std::optional<OrdersType> Order = findPartiallyOrderedLoads(TE)) return Order; + if (std::optional<OrdersType> CurrentOrder = findReusedOrderedScalars(TE)) + return CurrentOrder; } return std::nullopt; } @@ -4281,7 +5189,7 @@ void BoUpSLP::reorderNodeWithReuses(TreeEntry &TE, ArrayRef<int> Mask) const { reorderReuses(TE.ReuseShuffleIndices, Mask); const unsigned Sz = TE.Scalars.size(); // For vectorized and non-clustered reused no need to do anything else. 
- if (TE.State != TreeEntry::NeedToGather || + if (!TE.isGather() || !ShuffleVectorInst::isOneUseSingleSourceMask(TE.ReuseShuffleIndices, Sz) || !isRepeatedNonIdentityClusteredMask(TE.ReuseShuffleIndices, Sz)) @@ -4303,6 +5211,28 @@ void BoUpSLP::reorderNodeWithReuses(TreeEntry &TE, ArrayRef<int> Mask) const { std::iota(It, std::next(It, Sz), 0); } +static void combineOrders(MutableArrayRef<unsigned> Order, + ArrayRef<unsigned> SecondaryOrder) { + assert((SecondaryOrder.empty() || Order.size() == SecondaryOrder.size()) && + "Expected same size of orders"); + unsigned Sz = Order.size(); + SmallBitVector UsedIndices(Sz); + for (unsigned Idx : seq<unsigned>(0, Sz)) { + if (Order[Idx] != Sz) + UsedIndices.set(Order[Idx]); + } + if (SecondaryOrder.empty()) { + for (unsigned Idx : seq<unsigned>(0, Sz)) + if (Order[Idx] == Sz && !UsedIndices.test(Idx)) + Order[Idx] = Idx; + } else { + for (unsigned Idx : seq<unsigned>(0, Sz)) + if (SecondaryOrder[Idx] != Sz && Order[Idx] == Sz && + !UsedIndices.test(SecondaryOrder[Idx])) + Order[Idx] = SecondaryOrder[Idx]; + } +} + void BoUpSLP::reorderTopToBottom() { // Maps VF to the graph nodes. DenseMap<unsigned, SetVector<TreeEntry *>> VFToOrderedEntries; @@ -4320,14 +5250,10 @@ void BoUpSLP::reorderTopToBottom() { // Maps a TreeEntry to the reorder indices of external users. DenseMap<const TreeEntry *, SmallVector<OrdersType, 1>> ExternalUserReorderMap; - // FIXME: Workaround for syntax error reported by MSVC buildbots. - TargetTransformInfo &TTIRef = *TTI; // Find all reorderable nodes with the given VF. // Currently the are vectorized stores,loads,extracts + some gathering of // extracts. - for_each(VectorizableTree, [this, &TTIRef, &VFToOrderedEntries, - &GathersToOrders, &ExternalUserReorderMap, - &AltShufflesToOrders, &PhisToOrders]( + for_each(VectorizableTree, [&, &TTIRef = *TTI]( const std::unique_ptr<TreeEntry> &TE) { // Look for external users that will probably be vectorized. SmallVector<OrdersType, 1> ExternalUserReorderIndices = @@ -4343,14 +5269,10 @@ void BoUpSLP::reorderTopToBottom() { // to take into account their order when looking for the most used order. if (TE->isAltShuffle()) { VectorType *VecTy = - FixedVectorType::get(TE->Scalars[0]->getType(), TE->Scalars.size()); + getWidenedType(TE->Scalars[0]->getType(), TE->Scalars.size()); unsigned Opcode0 = TE->getOpcode(); unsigned Opcode1 = TE->getAltOpcode(); - // The opcode mask selects between the two opcodes. - SmallBitVector OpcodeMask(TE->Scalars.size(), false); - for (unsigned Lane : seq<unsigned>(0, TE->Scalars.size())) - if (cast<Instruction>(TE->Scalars[Lane])->getOpcode() == Opcode1) - OpcodeMask.set(Lane); + SmallBitVector OpcodeMask(getAltInstrMask(TE->Scalars, Opcode0, Opcode1)); // If this pattern is supported by the target then we consider the order. if (TTIRef.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask)) { VFToOrderedEntries[TE->getVectorFactor()].insert(TE.get()); @@ -4383,7 +5305,7 @@ void BoUpSLP::reorderTopToBottom() { } VFToOrderedEntries[TE->getVectorFactor()].insert(TE.get()); if (!(TE->State == TreeEntry::Vectorize || - TE->State == TreeEntry::PossibleStridedVectorize) || + TE->State == TreeEntry::StridedVectorize) || !TE->ReuseShuffleIndices.empty()) GathersToOrders.try_emplace(TE.get(), *CurrentOrder); if (TE->State == TreeEntry::Vectorize && @@ -4407,9 +5329,6 @@ void BoUpSLP::reorderTopToBottom() { MapVector<OrdersType, unsigned, DenseMap<OrdersType, unsigned, OrdersTypeDenseMapInfo>> OrdersUses; - // Last chance orders - scatter vectorize. 
Try to use their orders if no - // other orders or the order is counted already. - SmallVector<OrdersType> StridedVectorizeOrders; SmallPtrSet<const TreeEntry *, 4> VisitedOps; for (const TreeEntry *OpTE : OrderedEntries) { // No need to reorder this nodes, still need to extend and to use shuffle, @@ -4419,8 +5338,7 @@ void BoUpSLP::reorderTopToBottom() { // Count number of orders uses. const auto &Order = [OpTE, &GathersToOrders, &AltShufflesToOrders, &PhisToOrders]() -> const OrdersType & { - if (OpTE->State == TreeEntry::NeedToGather || - !OpTE->ReuseShuffleIndices.empty()) { + if (OpTE->isGather() || !OpTE->ReuseShuffleIndices.empty()) { auto It = GathersToOrders.find(OpTE); if (It != GathersToOrders.end()) return It->second; @@ -4456,11 +5374,6 @@ void BoUpSLP::reorderTopToBottom() { if (Order.empty()) continue; } - // Postpone scatter orders. - if (OpTE->State == TreeEntry::PossibleStridedVectorize) { - StridedVectorizeOrders.push_back(Order); - continue; - } // Stores actually store the mask, not the order, need to invert. if (OpTE->State == TreeEntry::Vectorize && !OpTE->isAltShuffle() && OpTE->getOpcode() == Instruction::Store && !Order.empty()) { @@ -4477,34 +5390,48 @@ void BoUpSLP::reorderTopToBottom() { ++OrdersUses.insert(std::make_pair(Order, 0)).first->second; } } - // Set order of the user node. - if (OrdersUses.empty()) { - if (StridedVectorizeOrders.empty()) - continue; - // Add (potentially!) strided vectorize orders. - for (OrdersType &Order : StridedVectorizeOrders) - ++OrdersUses.insert(std::make_pair(Order, 0)).first->second; - } else { - // Account (potentially!) strided vectorize orders only if it was used - // already. - for (OrdersType &Order : StridedVectorizeOrders) { - auto *It = OrdersUses.find(Order); - if (It != OrdersUses.end()) - ++It->second; - } - } + if (OrdersUses.empty()) + continue; + auto IsIdentityOrder = [](ArrayRef<unsigned> Order) { + const unsigned Sz = Order.size(); + for (unsigned Idx : seq<unsigned>(0, Sz)) + if (Idx != Order[Idx] && Order[Idx] != Sz) + return false; + return true; + }; // Choose the most used order. - ArrayRef<unsigned> BestOrder = OrdersUses.front().first; - unsigned Cnt = OrdersUses.front().second; - for (const auto &Pair : drop_begin(OrdersUses)) { - if (Cnt < Pair.second || (Cnt == Pair.second && Pair.first.empty())) { + unsigned IdentityCnt = 0; + unsigned FilledIdentityCnt = 0; + OrdersType IdentityOrder(VF, VF); + for (auto &Pair : OrdersUses) { + if (Pair.first.empty() || IsIdentityOrder(Pair.first)) { + if (!Pair.first.empty()) + FilledIdentityCnt += Pair.second; + IdentityCnt += Pair.second; + combineOrders(IdentityOrder, Pair.first); + } + } + MutableArrayRef<unsigned> BestOrder = IdentityOrder; + unsigned Cnt = IdentityCnt; + for (auto &Pair : OrdersUses) { + // Prefer identity order. But, if filled identity found (non-empty order) + // with same number of uses, as the new candidate order, we can choose + // this candidate order. + if (Cnt < Pair.second || + (Cnt == IdentityCnt && IdentityCnt == FilledIdentityCnt && + Cnt == Pair.second && !BestOrder.empty() && + IsIdentityOrder(BestOrder))) { + combineOrders(Pair.first, BestOrder); BestOrder = Pair.first; Cnt = Pair.second; + } else { + combineOrders(BestOrder, Pair.first); } } // Set order of the user node. 
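For reference, a small standalone restatement in plain C++ of the identity-order convention used above when choosing the most used order; it is not taken verbatim from the patch and the name isIdentityOrder is illustrative. Entries equal to the order's size Sz act as unset slots and are ignored.

#include <vector>

// With Sz = 4, {0, 4, 2, 3} counts as identity (slot 1 is unset), while
// {1, 0, 2, 3} does not.
static bool isIdentityOrder(const std::vector<unsigned> &Order) {
  const unsigned Sz = Order.size();
  for (unsigned Idx = 0; Idx < Sz; ++Idx)
    if (Order[Idx] != Idx && Order[Idx] != Sz)
      return false;
  return true;
}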
- if (BestOrder.empty()) + if (IsIdentityOrder(BestOrder)) continue; + fixupOrderingIndices(BestOrder); SmallVector<int> Mask; inversePermutation(BestOrder, Mask); SmallVector<int> MaskOrder(BestOrder.size(), PoisonMaskElem); @@ -4534,7 +5461,7 @@ void BoUpSLP::reorderTopToBottom() { continue; } if ((TE->State == TreeEntry::Vectorize || - TE->State == TreeEntry::PossibleStridedVectorize) && + TE->State == TreeEntry::StridedVectorize) && isa<ExtractElementInst, ExtractValueInst, LoadInst, StoreInst, InsertElementInst>(TE->getMainOp()) && !TE->isAltShuffle()) { @@ -4568,17 +5495,18 @@ bool BoUpSLP::canReorderOperands( TreeEntry *UserTE, SmallVectorImpl<std::pair<unsigned, TreeEntry *>> &Edges, ArrayRef<TreeEntry *> ReorderableGathers, SmallVectorImpl<TreeEntry *> &GatherOps) { + // FIXME: Reordering isn't implemented for non-power-of-2 nodes yet. + if (UserTE->isNonPowOf2Vec()) + return false; + for (unsigned I = 0, E = UserTE->getNumOperands(); I < E; ++I) { if (any_of(Edges, [I](const std::pair<unsigned, TreeEntry *> &OpData) { return OpData.first == I && - OpData.second->State == TreeEntry::Vectorize; + (OpData.second->State == TreeEntry::Vectorize || + OpData.second->State == TreeEntry::StridedVectorize); })) continue; if (TreeEntry *TE = getVectorizedOperand(UserTE, I)) { - // FIXME: Do not reorder (possible!) strided vectorized nodes, they - // require reordering of the operands, which is not implemented yet. - if (TE->State == TreeEntry::PossibleStridedVectorize) - return false; // Do not reorder if operand node is used by many user nodes. if (any_of(TE->UserTreeIndices, [UserTE](const EdgeInfo &EI) { return EI.UserTE != UserTE; })) @@ -4592,6 +5520,7 @@ bool BoUpSLP::canReorderOperands( // If there are reused scalars, process this node as a regular vectorize // node, just reorder reuses mask. if (TE->State != TreeEntry::Vectorize && + TE->State != TreeEntry::StridedVectorize && TE->ReuseShuffleIndices.empty() && TE->ReorderIndices.empty()) GatherOps.push_back(TE); continue; @@ -4600,6 +5529,7 @@ bool BoUpSLP::canReorderOperands( if (count_if(ReorderableGathers, [&Gather, UserTE, I](TreeEntry *TE) { assert(TE->State != TreeEntry::Vectorize && + TE->State != TreeEntry::StridedVectorize && "Only non-vectorized nodes are expected."); if (any_of(TE->UserTreeIndices, [UserTE, I](const EdgeInfo &EI) { @@ -4622,22 +5552,22 @@ bool BoUpSLP::canReorderOperands( void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { SetVector<TreeEntry *> OrderedEntries; - DenseMap<const TreeEntry *, OrdersType> GathersToOrders; + DenseSet<const TreeEntry *> GathersToOrders; // Find all reorderable leaf nodes with the given VF. // Currently the are vectorized loads,extracts without alternate operands + // some gathering of extracts. 
SmallVector<TreeEntry *> NonVectorized; for (const std::unique_ptr<TreeEntry> &TE : VectorizableTree) { if (TE->State != TreeEntry::Vectorize && - TE->State != TreeEntry::PossibleStridedVectorize) + TE->State != TreeEntry::StridedVectorize) NonVectorized.push_back(TE.get()); if (std::optional<OrdersType> CurrentOrder = getReorderingData(*TE, /*TopToBottom=*/false)) { OrderedEntries.insert(TE.get()); if (!(TE->State == TreeEntry::Vectorize || - TE->State == TreeEntry::PossibleStridedVectorize) || + TE->State == TreeEntry::StridedVectorize) || !TE->ReuseShuffleIndices.empty()) - GathersToOrders.try_emplace(TE.get(), *CurrentOrder); + GathersToOrders.insert(TE.get()); } } @@ -4653,9 +5583,8 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { SmallVector<TreeEntry *> Filtered; for (TreeEntry *TE : OrderedEntries) { if (!(TE->State == TreeEntry::Vectorize || - TE->State == TreeEntry::PossibleStridedVectorize || - (TE->State == TreeEntry::NeedToGather && - GathersToOrders.count(TE))) || + TE->State == TreeEntry::StridedVectorize || + (TE->isGather() && GathersToOrders.contains(TE))) || TE->UserTreeIndices.empty() || !TE->ReuseShuffleIndices.empty() || !all_of(drop_begin(TE->UserTreeIndices), [TE](const EdgeInfo &EI) { @@ -4698,9 +5627,6 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { MapVector<OrdersType, unsigned, DenseMap<OrdersType, unsigned, OrdersTypeDenseMapInfo>> OrdersUses; - // Last chance orders - scatter vectorize. Try to use their orders if no - // other orders or the order is counted already. - SmallVector<std::pair<OrdersType, unsigned>> StridedVectorizeOrders; // Do the analysis for each tree entry only once, otherwise the order of // the same node my be considered several times, though might be not // profitable. @@ -4712,21 +5638,20 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { continue; if (!OpTE->ReuseShuffleIndices.empty() && !GathersToOrders.count(OpTE)) continue; - const auto &Order = [OpTE, &GathersToOrders]() -> const OrdersType & { - if (OpTE->State == TreeEntry::NeedToGather || - !OpTE->ReuseShuffleIndices.empty()) - return GathersToOrders.find(OpTE)->second; + const auto Order = [&]() -> const OrdersType { + if (OpTE->isGather() || !OpTE->ReuseShuffleIndices.empty()) + return getReorderingData(*OpTE, /*TopToBottom=*/false) + .value_or(OrdersType(1)); return OpTE->ReorderIndices; }(); + // The order is partially ordered, skip it in favor of fully non-ordered + // orders. + if (Order.size() == 1) + continue; unsigned NumOps = count_if( Data.second, [OpTE](const std::pair<unsigned, TreeEntry *> &P) { return P.second == OpTE; }); - // Postpone scatter orders. - if (OpTE->State == TreeEntry::PossibleStridedVectorize) { - StridedVectorizeOrders.emplace_back(Order, NumOps); - continue; - } // Stores actually store the mask, not the order, need to invert. if (OpTE->State == TreeEntry::Vectorize && !OpTE->isAltShuffle() && OpTE->getOpcode() == Instruction::Store && !Order.empty()) { @@ -4744,16 +5669,19 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { OrdersUses.insert(std::make_pair(Order, 0)).first->second += NumOps; } auto Res = OrdersUses.insert(std::make_pair(OrdersType(), 0)); - const auto &&AllowsReordering = [IgnoreReorder, &GathersToOrders]( - const TreeEntry *TE) { + const auto AllowsReordering = [&](const TreeEntry *TE) { + // FIXME: Reordering isn't implemented for non-power-of-2 nodes yet. 
+ if (TE->isNonPowOf2Vec()) + return false; if (!TE->ReorderIndices.empty() || !TE->ReuseShuffleIndices.empty() || (TE->State == TreeEntry::Vectorize && TE->isAltShuffle()) || (IgnoreReorder && TE->Idx == 0)) return true; - if (TE->State == TreeEntry::NeedToGather) { - auto It = GathersToOrders.find(TE); - if (It != GathersToOrders.end()) - return !It->second.empty(); + if (TE->isGather()) { + if (GathersToOrders.contains(TE)) + return !getReorderingData(*TE, /*TopToBottom=*/false) + .value_or(OrdersType(1)) + .empty(); return true; } return false; @@ -4785,45 +5713,49 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { ++Res.first->second; } } - // If no orders - skip current nodes and jump to the next one, if any. if (OrdersUses.empty()) { - if (StridedVectorizeOrders.empty() || - (Data.first->ReorderIndices.empty() && - Data.first->ReuseShuffleIndices.empty() && - !(IgnoreReorder && - Data.first == VectorizableTree.front().get()))) { - for (const std::pair<unsigned, TreeEntry *> &Op : Data.second) - OrderedEntries.remove(Op.second); - continue; - } - // Add (potentially!) strided vectorize orders. - for (std::pair<OrdersType, unsigned> &Pair : StridedVectorizeOrders) - OrdersUses.insert(std::make_pair(Pair.first, 0)).first->second += - Pair.second; - } else { - // Account (potentially!) strided vectorize orders only if it was used - // already. - for (std::pair<OrdersType, unsigned> &Pair : StridedVectorizeOrders) { - auto *It = OrdersUses.find(Pair.first); - if (It != OrdersUses.end()) - It->second += Pair.second; + for (const std::pair<unsigned, TreeEntry *> &Op : Data.second) + OrderedEntries.remove(Op.second); + continue; + } + auto IsIdentityOrder = [](ArrayRef<unsigned> Order) { + const unsigned Sz = Order.size(); + for (unsigned Idx : seq<unsigned>(0, Sz)) + if (Idx != Order[Idx] && Order[Idx] != Sz) + return false; + return true; + }; + // Choose the most used order. + unsigned IdentityCnt = 0; + unsigned VF = Data.second.front().second->getVectorFactor(); + OrdersType IdentityOrder(VF, VF); + for (auto &Pair : OrdersUses) { + if (Pair.first.empty() || IsIdentityOrder(Pair.first)) { + IdentityCnt += Pair.second; + combineOrders(IdentityOrder, Pair.first); } } - // Choose the best order. - ArrayRef<unsigned> BestOrder = OrdersUses.front().first; - unsigned Cnt = OrdersUses.front().second; - for (const auto &Pair : drop_begin(OrdersUses)) { - if (Cnt < Pair.second || (Cnt == Pair.second && Pair.first.empty())) { + MutableArrayRef<unsigned> BestOrder = IdentityOrder; + unsigned Cnt = IdentityCnt; + for (auto &Pair : OrdersUses) { + // Prefer identity order. But, if filled identity found (non-empty + // order) with same number of uses, as the new candidate order, we can + // choose this candidate order. + if (Cnt < Pair.second) { + combineOrders(Pair.first, BestOrder); BestOrder = Pair.first; Cnt = Pair.second; + } else { + combineOrders(BestOrder, Pair.first); } } - // Set order of the user node (reordering of operands and user nodes). - if (BestOrder.empty()) { + // Set order of the user node. + if (IsIdentityOrder(BestOrder)) { for (const std::pair<unsigned, TreeEntry *> &Op : Data.second) OrderedEntries.remove(Op.second); continue; } + fixupOrderingIndices(BestOrder); // Erase operands from OrderedEntries list and adjust their orders. VisitedOps.clear(); SmallVector<int> Mask; @@ -4844,7 +5776,7 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { } // Gathers are processed separately. 
if (TE->State != TreeEntry::Vectorize && - TE->State != TreeEntry::PossibleStridedVectorize && + TE->State != TreeEntry::StridedVectorize && (TE->State != TreeEntry::ScatterVectorize || TE->ReorderIndices.empty())) continue; @@ -4876,9 +5808,10 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { Data.first->reorderOperands(Mask); if (!isa<InsertElementInst, StoreInst>(Data.first->getMainOp()) || Data.first->isAltShuffle() || - Data.first->State == TreeEntry::PossibleStridedVectorize) { + Data.first->State == TreeEntry::StridedVectorize) { reorderScalars(Data.first->Scalars, Mask); - reorderOrder(Data.first->ReorderIndices, MaskOrder); + reorderOrder(Data.first->ReorderIndices, MaskOrder, + /*BottomOrder=*/true); if (Data.first->ReuseShuffleIndices.empty() && !Data.first->ReorderIndices.empty() && !Data.first->isAltShuffle()) { @@ -4899,12 +5832,13 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) { void BoUpSLP::buildExternalUses( const ExtraValueToDebugLocsMap &ExternallyUsedValues) { + DenseMap<Value *, unsigned> ScalarToExtUses; // Collect the values that we need to extract from the tree. for (auto &TEPtr : VectorizableTree) { TreeEntry *Entry = TEPtr.get(); // No need to handle users of gathered values. - if (Entry->State == TreeEntry::NeedToGather) + if (Entry->isGather()) continue; // For each lane: @@ -4912,14 +5846,20 @@ void BoUpSLP::buildExternalUses( Value *Scalar = Entry->Scalars[Lane]; if (!isa<Instruction>(Scalar)) continue; - int FoundLane = Entry->findLaneForValue(Scalar); + // All uses must be replaced already? No need to do it again. + auto It = ScalarToExtUses.find(Scalar); + if (It != ScalarToExtUses.end() && !ExternalUses[It->second].User) + continue; // Check if the scalar is externally used as an extra arg. const auto *ExtI = ExternallyUsedValues.find(Scalar); if (ExtI != ExternallyUsedValues.end()) { + int FoundLane = Entry->findLaneForValue(Scalar); LLVM_DEBUG(dbgs() << "SLP: Need to extract: Extra arg from lane " - << Lane << " from " << *Scalar << ".\n"); + << FoundLane << " from " << *Scalar << ".\n"); + ScalarToExtUses.try_emplace(Scalar, ExternalUses.size()); ExternalUses.emplace_back(Scalar, nullptr, FoundLane); + continue; } for (User *U : Scalar->users()) { LLVM_DEBUG(dbgs() << "SLP: Checking user:" << *U << ".\n"); @@ -4938,21 +5878,30 @@ void BoUpSLP::buildExternalUses( // instructions. If that is the case, the one in FoundLane will // be used. 
if (UseEntry->State == TreeEntry::ScatterVectorize || - UseEntry->State == TreeEntry::PossibleStridedVectorize || !doesInTreeUserNeedToExtract( Scalar, cast<Instruction>(UseEntry->Scalars.front()), TLI)) { LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U << ".\n"); - assert(UseEntry->State != TreeEntry::NeedToGather && "Bad state"); + assert(!UseEntry->isGather() && "Bad state"); continue; } U = nullptr; + if (It != ScalarToExtUses.end()) { + ExternalUses[It->second].User = nullptr; + break; + } } + if (U && Scalar->hasNUsesOrMore(UsesLimit)) + U = nullptr; + int FoundLane = Entry->findLaneForValue(Scalar); LLVM_DEBUG(dbgs() << "SLP: Need to extract:" << *UserInst - << " from lane " << Lane << " from " << *Scalar + << " from lane " << FoundLane << " from " << *Scalar << ".\n"); + It = ScalarToExtUses.try_emplace(Scalar, ExternalUses.size()).first; ExternalUses.emplace_back(Scalar, U, FoundLane); + if (!U) + break; } } } @@ -4964,8 +5913,7 @@ BoUpSLP::collectUserStores(const BoUpSLP::TreeEntry *TE) const { for (unsigned Lane : seq<unsigned>(0, TE->Scalars.size())) { Value *V = TE->Scalars[Lane]; // To save compilation time we don't visit if we have too many users. - static constexpr unsigned UsersLimit = 4; - if (V->hasNUsesOrMore(UsersLimit)) + if (V->hasNUsesOrMore(UsesLimit)) break; // Collect stores per pointer object. @@ -5231,6 +6179,113 @@ static bool isAlternateInstruction(const Instruction *I, const Instruction *AltOp, const TargetLibraryInfo &TLI); +bool BoUpSLP::areAltOperandsProfitable(const InstructionsState &S, + ArrayRef<Value *> VL) const { + unsigned Opcode0 = S.getOpcode(); + unsigned Opcode1 = S.getAltOpcode(); + SmallBitVector OpcodeMask(getAltInstrMask(VL, Opcode0, Opcode1)); + // If this pattern is supported by the target then consider it profitable. + if (TTI->isLegalAltInstr(getWidenedType(S.MainOp->getType(), VL.size()), + Opcode0, Opcode1, OpcodeMask)) + return true; + SmallVector<ValueList> Operands; + for (unsigned I : seq<unsigned>(0, S.MainOp->getNumOperands())) { + Operands.emplace_back(); + // Prepare the operand vector. + for (Value *V : VL) + Operands.back().push_back(cast<Instruction>(V)->getOperand(I)); + } + if (Operands.size() == 2) { + // Try find best operands candidates. + for (unsigned I : seq<unsigned>(0, VL.size() - 1)) { + SmallVector<std::pair<Value *, Value *>> Candidates(3); + Candidates[0] = std::make_pair(Operands[0][I], Operands[0][I + 1]); + Candidates[1] = std::make_pair(Operands[0][I], Operands[1][I + 1]); + Candidates[2] = std::make_pair(Operands[1][I], Operands[0][I + 1]); + std::optional<int> Res = findBestRootPair(Candidates); + switch (Res.value_or(0)) { + case 0: + break; + case 1: + std::swap(Operands[0][I + 1], Operands[1][I + 1]); + break; + case 2: + std::swap(Operands[0][I], Operands[1][I]); + break; + default: + llvm_unreachable("Unexpected index."); + } + } + } + DenseSet<unsigned> UniqueOpcodes; + constexpr unsigned NumAltInsts = 3; // main + alt + shuffle. + unsigned NonInstCnt = 0; + // Estimate number of instructions, required for the vectorized node and for + // the buildvector node. + unsigned UndefCnt = 0; + // Count the number of extra shuffles, required for vector nodes. + unsigned ExtraShuffleInsts = 0; + // Check that operands do not contain same values and create either perfect + // diamond match or shuffled match. + if (Operands.size() == 2) { + // Do not count same operands twice. 
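// (A "perfect diamond match" means both operand lists are identical, so a
// single vector operand can be reused; a "shuffled match" has the same
// values in a different order and is charged one extra shuffle below.)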
+ if (Operands.front() == Operands.back()) { + Operands.erase(Operands.begin()); + } else if (!allConstant(Operands.front()) && + all_of(Operands.front(), [&](Value *V) { + return is_contained(Operands.back(), V); + })) { + Operands.erase(Operands.begin()); + ++ExtraShuffleInsts; + } + } + const Loop *L = LI->getLoopFor(S.MainOp->getParent()); + // Vectorize node, if: + // 1. at least single operand is constant or splat. + // 2. Operands have many loop invariants (the instructions are not loop + // invariants). + // 3. At least single unique operands is supposed to vectorized. + return none_of(Operands, + [&](ArrayRef<Value *> Op) { + if (allConstant(Op) || + (!isSplat(Op) && allSameBlock(Op) && allSameType(Op) && + getSameOpcode(Op, *TLI).MainOp)) + return false; + DenseMap<Value *, unsigned> Uniques; + for (Value *V : Op) { + if (isa<Constant, ExtractElementInst>(V) || + getTreeEntry(V) || (L && L->isLoopInvariant(V))) { + if (isa<UndefValue>(V)) + ++UndefCnt; + continue; + } + auto Res = Uniques.try_emplace(V, 0); + // Found first duplicate - need to add shuffle. + if (!Res.second && Res.first->second == 1) + ++ExtraShuffleInsts; + ++Res.first->getSecond(); + if (auto *I = dyn_cast<Instruction>(V)) + UniqueOpcodes.insert(I->getOpcode()); + else if (Res.second) + ++NonInstCnt; + } + return none_of(Uniques, [&](const auto &P) { + return P.first->hasNUsesOrMore(P.second + 1) && + none_of(P.first->users(), [&](User *U) { + return getTreeEntry(U) || Uniques.contains(U); + }); + }); + }) || + // Do not vectorize node, if estimated number of vector instructions is + // more than estimated number of buildvector instructions. Number of + // vector operands is number of vector instructions + number of vector + // instructions for operands (buildvectors). Number of buildvector + // instructions is just number_of_operands * number_of_scalars. + (UndefCnt < (VL.size() - 1) * S.MainOp->getNumOperands() && + (UniqueOpcodes.size() + NonInstCnt + ExtraShuffleInsts + + NumAltInsts) < S.MainOp->getNumOperands() * VL.size()); +} + BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState( InstructionsState &S, ArrayRef<Value *> VL, bool IsScatterVectorizeUserTE, OrdersType &CurrentOrder, SmallVectorImpl<Value *> &PointerOps) const { @@ -5241,6 +6296,9 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState( auto *VL0 = cast<Instruction>(S.OpValue); switch (ShuffleOrOp) { case Instruction::PHI: { + // Too many operands - gather, most probably won't be vectorized. + if (VL0->getNumOperands() > MaxPHINumOperands) + return TreeEntry::NeedToGather; // Check for terminator values (e.g. invoke). for (Value *V : VL) for (Value *Incoming : cast<PHINode>(V)->incoming_values()) { @@ -5257,6 +6315,9 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState( case Instruction::ExtractValue: case Instruction::ExtractElement: { bool Reuse = canReuseExtract(VL, VL0, CurrentOrder); + // FIXME: Vectorizing is not supported yet for non-power-of-2 ops. 
+ if (!isPowerOf2_32(VL.size())) + return TreeEntry::NeedToGather; if (Reuse || !CurrentOrder.empty()) return TreeEntry::Vectorize; LLVM_DEBUG(dbgs() << "SLP: Gather extract sequence.\n"); @@ -5268,7 +6329,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState( ValueSet SourceVectors; for (Value *V : VL) { SourceVectors.insert(cast<Instruction>(V)->getOperand(0)); - assert(getInsertIndex(V) != std::nullopt && + assert(getElementIndex(V) != std::nullopt && "Non-constant or undef index?"); } @@ -5290,14 +6351,13 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState( // treats loading/storing it as an i8 struct. If we vectorize loads/stores // from such a struct, we read/write packed bits disagreeing with the // unvectorized version. - switch (canVectorizeLoads(VL, VL0, *TTI, *DL, *SE, *LI, *TLI, CurrentOrder, - PointerOps)) { + switch (canVectorizeLoads(VL, VL0, CurrentOrder, PointerOps)) { case LoadsState::Vectorize: return TreeEntry::Vectorize; case LoadsState::ScatterVectorize: return TreeEntry::ScatterVectorize; - case LoadsState::PossibleStridedVectorize: - return TreeEntry::PossibleStridedVectorize; + case LoadsState::StridedVectorize: + return TreeEntry::StridedVectorize; case LoadsState::Gather: #ifndef NDEBUG Type *ScalarTy = VL0->getType(); @@ -5529,6 +6589,14 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState( LLVM_DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n"); return TreeEntry::NeedToGather; } + if (!SLPSkipEarlyProfitabilityCheck && !areAltOperandsProfitable(S, VL)) { + LLVM_DEBUG( + dbgs() + << "SLP: ShuffleVector not vectorized, operands are buildvector and " + "the whole alt sequence is not profitable.\n"); + return TreeEntry::NeedToGather; + } + return TreeEntry::Vectorize; } default: @@ -5537,11 +6605,90 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState( } } +namespace { +/// Allows to correctly handle operands of the phi nodes based on the \p Main +/// PHINode order of incoming basic blocks/values. +class PHIHandler { + DominatorTree &DT; + PHINode *Main = nullptr; + SmallVector<Value *> Phis; + SmallVector<SmallVector<Value *>> Operands; + +public: + PHIHandler() = delete; + PHIHandler(DominatorTree &DT, PHINode *Main, ArrayRef<Value *> Phis) + : DT(DT), Main(Main), Phis(Phis), + Operands(Main->getNumIncomingValues(), + SmallVector<Value *>(Phis.size(), nullptr)) {} + void buildOperands() { + constexpr unsigned FastLimit = 4; + if (Main->getNumIncomingValues() <= FastLimit) { + for (unsigned I : seq<unsigned>(0, Main->getNumIncomingValues())) { + BasicBlock *InBB = Main->getIncomingBlock(I); + if (!DT.isReachableFromEntry(InBB)) { + Operands[I].assign(Phis.size(), PoisonValue::get(Main->getType())); + continue; + } + // Prepare the operand vector. 
+ for (auto [Idx, V] : enumerate(Phis)) { + auto *P = cast<PHINode>(V); + if (P->getIncomingBlock(I) == InBB) + Operands[I][Idx] = P->getIncomingValue(I); + else + Operands[I][Idx] = P->getIncomingValueForBlock(InBB); + } + } + return; + } + SmallDenseMap<BasicBlock *, SmallVector<unsigned>, 4> Blocks; + for (unsigned I : seq<unsigned>(0, Main->getNumIncomingValues())) { + BasicBlock *InBB = Main->getIncomingBlock(I); + if (!DT.isReachableFromEntry(InBB)) { + Operands[I].assign(Phis.size(), PoisonValue::get(Main->getType())); + continue; + } + Blocks.try_emplace(InBB).first->second.push_back(I); + } + for (auto [Idx, V] : enumerate(Phis)) { + auto *P = cast<PHINode>(V); + for (unsigned I : seq<unsigned>(0, P->getNumIncomingValues())) { + BasicBlock *InBB = P->getIncomingBlock(I); + if (InBB == Main->getIncomingBlock(I)) { + if (isa_and_nonnull<PoisonValue>(Operands[I][Idx])) + continue; + Operands[I][Idx] = P->getIncomingValue(I); + continue; + } + auto It = Blocks.find(InBB); + if (It == Blocks.end()) + continue; + Operands[It->second.front()][Idx] = P->getIncomingValue(I); + } + } + for (const auto &P : Blocks) { + if (P.getSecond().size() <= 1) + continue; + unsigned BasicI = P.getSecond().front(); + for (unsigned I : ArrayRef(P.getSecond()).drop_front()) { + assert(all_of(enumerate(Operands[I]), + [&](const auto &Data) { + return !Data.value() || + Data.value() == Operands[BasicI][Data.index()]; + }) && + "Expected empty operands list."); + Operands[I] = Operands[BasicI]; + } + } + } + ArrayRef<Value *> getOperands(unsigned I) const { return Operands[I]; } +}; +} // namespace + void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, const EdgeInfo &UserTreeIdx) { assert((allConstant(VL) || allSameType(VL)) && "Invalid types!"); - SmallVector<int> ReuseShuffleIndicies; + SmallVector<int> ReuseShuffleIndices; SmallVector<Value *> UniqueValues; SmallVector<Value *> NonUniqueValueVL; auto TryToFindDuplicates = [&](const InstructionsState &S, @@ -5550,20 +6697,27 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, DenseMap<Value *, unsigned> UniquePositions(VL.size()); for (Value *V : VL) { if (isConstant(V)) { - ReuseShuffleIndicies.emplace_back( + ReuseShuffleIndices.emplace_back( isa<UndefValue>(V) ? PoisonMaskElem : UniqueValues.size()); UniqueValues.emplace_back(V); continue; } auto Res = UniquePositions.try_emplace(V, UniqueValues.size()); - ReuseShuffleIndicies.emplace_back(Res.first->second); + ReuseShuffleIndices.emplace_back(Res.first->second); if (Res.second) UniqueValues.emplace_back(V); } size_t NumUniqueScalarValues = UniqueValues.size(); if (NumUniqueScalarValues == VL.size()) { - ReuseShuffleIndicies.clear(); + ReuseShuffleIndices.clear(); } else { + // FIXME: Reshuffing scalars is not supported yet for non-power-of-2 ops. 
+ if (UserTreeIdx.UserTE && UserTreeIdx.UserTE->isNonPowOf2Vec()) { + LLVM_DEBUG(dbgs() << "SLP: Reshuffling scalars not yet supported " + "for nodes with padding.\n"); + newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); + return false; + } LLVM_DEBUG(dbgs() << "SLP: Shuffle for reused scalars.\n"); if (NumUniqueScalarValues <= 1 || (UniquePositions.size() == 1 && all_of(UniqueValues, @@ -5581,7 +6735,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, })) { unsigned PWSz = PowerOf2Ceil(UniqueValues.size()); if (PWSz == VL.size()) { - ReuseShuffleIndicies.clear(); + ReuseShuffleIndices.clear(); } else { NonUniqueValueVL.assign(UniqueValues.begin(), UniqueValues.end()); NonUniqueValueVL.append(PWSz - UniqueValues.size(), @@ -5628,7 +6782,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: Gathering due to max recursion depth.\n"); if (TryToFindDuplicates(S)) newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + ReuseShuffleIndices); return; } @@ -5639,12 +6793,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: Gathering due to scalable vector type.\n"); if (TryToFindDuplicates(S)) newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + ReuseShuffleIndices); return; } // Don't handle vectors. - if (S.OpValue->getType()->isVectorTy() && + if (!SLPReVec && S.OpValue->getType()->isVectorTy() && !isa<InsertElementInst>(S.OpValue)) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to vector type.\n"); newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); @@ -5652,7 +6806,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } if (StoreInst *SI = dyn_cast<StoreInst>(S.OpValue)) - if (SI->getValueOperand()->getType()->isVectorTy()) { + if (!SLPReVec && SI->getValueOperand()->getType()->isVectorTy()) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to store vector type.\n"); newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx); return; @@ -5718,11 +6872,10 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, BasicBlock *BB = nullptr; bool IsScatterVectorizeUserTE = UserTreeIdx.UserTE && - (UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize || - UserTreeIdx.UserTE->State == TreeEntry::PossibleStridedVectorize); - bool AreAllSameInsts = - (S.getOpcode() && allSameBlock(VL)) || - (S.OpValue->getType()->isPointerTy() && IsScatterVectorizeUserTE && + UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize; + bool AreAllSameBlock = S.getOpcode() && allSameBlock(VL); + bool AreScatterAllGEPSameBlock = + (IsScatterVectorizeUserTE && S.OpValue->getType()->isPointerTy() && VL.size() > 2 && all_of(VL, [&BB](Value *V) { @@ -5736,6 +6889,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, BB && sortPtrAccesses(VL, UserTreeIdx.UserTE->getMainOp()->getType(), *DL, *SE, SortedIndices)); + bool AreAllSameInsts = AreAllSameBlock || AreScatterAllGEPSameBlock; if (!AreAllSameInsts || allConstant(VL) || isSplat(VL) || (isa<InsertElementInst, ExtractValueInst, ExtractElementInst>( S.OpValue) && @@ -5744,7 +6898,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O, small shuffle. 
\n"); if (TryToFindDuplicates(S)) newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + ReuseShuffleIndices); return; } @@ -5772,7 +6926,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n"); if (TryToFindDuplicates(S)) newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + ReuseShuffleIndices); return; } } else { @@ -5795,7 +6949,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, << ") is already in tree.\n"); if (TryToFindDuplicates(S)) newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + ReuseShuffleIndices); return; } } @@ -5807,7 +6961,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LLVM_DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); if (TryToFindDuplicates(S)) newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + ReuseShuffleIndices); return; } } @@ -5815,16 +6969,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Special processing for sorted pointers for ScatterVectorize node with // constant indeces only. - if (AreAllSameInsts && UserTreeIdx.UserTE && - (UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize || - UserTreeIdx.UserTE->State == TreeEntry::PossibleStridedVectorize) && - !(S.getOpcode() && allSameBlock(VL))) { + if (!AreAllSameBlock && AreScatterAllGEPSameBlock) { assert(S.OpValue->getType()->isPointerTy() && - count_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }) >= - 2 && + count_if(VL, IsaPred<GetElementPtrInst>) >= 2 && "Expected pointers only."); // Reset S to make it GetElementPtr kind of node. - const auto *It = find_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }); + const auto *It = find_if(VL, IsaPred<GetElementPtrInst>); assert(It != VL.end() && "Expected at least one GEP."); S = getSameOpcode(*It, *TLI); } @@ -5862,7 +7012,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, S, VL, IsScatterVectorizeUserTE, CurrentOrder, PointerOps); if (State == TreeEntry::NeedToGather) { newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + ReuseShuffleIndices); return; } @@ -5884,7 +7034,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, !BS.getScheduleData(VL0)->isPartOfBundle()) && "tryScheduleBundle should cancelScheduling on failure"); newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + ReuseShuffleIndices); + NonScheduledFirst.insert(VL.front()); return; } LLVM_DEBUG(dbgs() << "SLP: We are able to schedule this bundle.\n"); @@ -5896,55 +7047,36 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, auto *PH = cast<PHINode>(VL0); TreeEntry *TE = - newTreeEntry(VL, Bundle, S, UserTreeIdx, ReuseShuffleIndicies); + newTreeEntry(VL, Bundle, S, UserTreeIdx, ReuseShuffleIndices); LLVM_DEBUG(dbgs() << "SLP: added a vector of PHINodes.\n"); // Keeps the reordered operands to avoid code duplication. - SmallVector<ValueList, 2> OperandsVec; - for (unsigned I = 0, E = PH->getNumIncomingValues(); I < E; ++I) { - if (!DT->isReachableFromEntry(PH->getIncomingBlock(I))) { - ValueList Operands(VL.size(), PoisonValue::get(PH->getType())); - TE->setOperand(I, Operands); - OperandsVec.push_back(Operands); - continue; - } - ValueList Operands; - // Prepare the operand vector. 
- for (Value *V : VL) - Operands.push_back(cast<PHINode>(V)->getIncomingValueForBlock( - PH->getIncomingBlock(I))); - TE->setOperand(I, Operands); - OperandsVec.push_back(Operands); - } - for (unsigned OpIdx = 0, OpE = OperandsVec.size(); OpIdx != OpE; ++OpIdx) - buildTree_rec(OperandsVec[OpIdx], Depth + 1, {TE, OpIdx}); + PHIHandler Handler(*DT, PH, VL); + Handler.buildOperands(); + for (unsigned I : seq<unsigned>(0, PH->getNumOperands())) + TE->setOperand(I, Handler.getOperands(I)); + for (unsigned I : seq<unsigned>(0, PH->getNumOperands())) + buildTree_rec(Handler.getOperands(I), Depth + 1, {TE, I}); return; } case Instruction::ExtractValue: case Instruction::ExtractElement: { if (CurrentOrder.empty()) { LLVM_DEBUG(dbgs() << "SLP: Reusing or shuffling extract sequence.\n"); - newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - // This is a special case, as it does not gather, but at the same time - // we are not extending buildTree_rec() towards the operands. - ValueList Op0; - Op0.assign(VL.size(), VL0->getOperand(0)); - VectorizableTree.back()->setOperand(0, Op0); - return; + } else { + LLVM_DEBUG({ + dbgs() << "SLP: Reusing or shuffling of reordered extract sequence " + "with order"; + for (unsigned Idx : CurrentOrder) + dbgs() << " " << Idx; + dbgs() << "\n"; + }); + fixupOrderingIndices(CurrentOrder); } - LLVM_DEBUG({ - dbgs() << "SLP: Reusing or shuffling of reordered extract sequence " - "with order"; - for (unsigned Idx : CurrentOrder) - dbgs() << " " << Idx; - dbgs() << "\n"; - }); - fixupOrderingIndices(CurrentOrder); // Insert new order with initial value 0, if it does not exist, // otherwise return the iterator to the existing one. newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies, CurrentOrder); + ReuseShuffleIndices, CurrentOrder); // This is a special case, as it does not gather, but at the same time // we are not extending buildTree_rec() towards the operands. 
ValueList Op0; @@ -5953,7 +7085,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, return; } case Instruction::InsertElement: { - assert(ReuseShuffleIndicies.empty() && "All inserts should be unique"); + assert(ReuseShuffleIndices.empty() && "All inserts should be unique"); auto OrdCompare = [](const std::pair<int, int> &P1, const std::pair<int, int> &P2) { @@ -5963,7 +7095,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, decltype(OrdCompare)> Indices(OrdCompare); for (int I = 0, E = VL.size(); I < E; ++I) { - unsigned Idx = *getInsertIndex(VL[I]); + unsigned Idx = *getElementIndex(VL[I]); Indices.emplace(Idx, I); } OrdersType CurrentOrder(VL.size(), VL.size()); @@ -5979,15 +7111,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, std::nullopt, CurrentOrder); LLVM_DEBUG(dbgs() << "SLP: added inserts bundle.\n"); - constexpr int NumOps = 2; - ValueList VectorOperands[NumOps]; - for (int I = 0; I < NumOps; ++I) { - for (Value *V : VL) - VectorOperands[I].push_back(cast<Instruction>(V)->getOperand(I)); - - TE->setOperand(I, VectorOperands[I]); - } - buildTree_rec(VectorOperands[NumOps - 1], Depth + 1, {TE, NumOps - 1}); + TE->setOperandsInOrder(); + buildTree_rec(TE->getOperand(1), Depth + 1, {TE, 1}); return; } case Instruction::Load: { @@ -6001,36 +7126,25 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, fixupOrderingIndices(CurrentOrder); switch (State) { case TreeEntry::Vectorize: - if (CurrentOrder.empty()) { - // Original loads are consecutive and does not require reordering. - TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndices, CurrentOrder); + if (CurrentOrder.empty()) LLVM_DEBUG(dbgs() << "SLP: added a vector of loads.\n"); - } else { - // Need to reorder. - TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies, CurrentOrder); + else LLVM_DEBUG(dbgs() << "SLP: added a vector of jumbled loads.\n"); - } TE->setOperandsInOrder(); break; - case TreeEntry::PossibleStridedVectorize: + case TreeEntry::StridedVectorize: // Vectorizing non-consecutive loads with `llvm.masked.gather`. - if (CurrentOrder.empty()) { - TE = newTreeEntry(VL, TreeEntry::PossibleStridedVectorize, Bundle, S, - UserTreeIdx, ReuseShuffleIndicies); - } else { - TE = newTreeEntry(VL, TreeEntry::PossibleStridedVectorize, Bundle, S, - UserTreeIdx, ReuseShuffleIndicies, CurrentOrder); - } + TE = newTreeEntry(VL, TreeEntry::StridedVectorize, Bundle, S, + UserTreeIdx, ReuseShuffleIndices, CurrentOrder); TE->setOperandsInOrder(); - buildTree_rec(PointerOps, Depth + 1, {TE, 0}); - LLVM_DEBUG(dbgs() << "SLP: added a vector of non-consecutive loads.\n"); + LLVM_DEBUG(dbgs() << "SLP: added a vector of strided loads.\n"); break; case TreeEntry::ScatterVectorize: // Vectorizing non-consecutive loads with `llvm.masked.gather`. 
TE = newTreeEntry(VL, TreeEntry::ScatterVectorize, Bundle, S, - UserTreeIdx, ReuseShuffleIndicies); + UserTreeIdx, ReuseShuffleIndices); TE->setOperandsInOrder(); buildTree_rec(PointerOps, Depth + 1, {TE, 0}); LLVM_DEBUG(dbgs() << "SLP: added a vector of non-consecutive loads.\n"); @@ -6052,19 +7166,44 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, case Instruction::Trunc: case Instruction::FPTrunc: case Instruction::BitCast: { + auto [PrevMaxBW, PrevMinBW] = CastMaxMinBWSizes.value_or( + std::make_pair(std::numeric_limits<unsigned>::min(), + std::numeric_limits<unsigned>::max())); + if (ShuffleOrOp == Instruction::ZExt || + ShuffleOrOp == Instruction::SExt) { + CastMaxMinBWSizes = std::make_pair( + std::max<unsigned>(DL->getTypeSizeInBits(VL0->getType()), + PrevMaxBW), + std::min<unsigned>( + DL->getTypeSizeInBits(VL0->getOperand(0)->getType()), + PrevMinBW)); + } else if (ShuffleOrOp == Instruction::Trunc) { + CastMaxMinBWSizes = std::make_pair( + std::max<unsigned>( + DL->getTypeSizeInBits(VL0->getOperand(0)->getType()), + PrevMaxBW), + std::min<unsigned>(DL->getTypeSizeInBits(VL0->getType()), + PrevMinBW)); + ExtraBitWidthNodes.insert(VectorizableTree.size() + 1); + } else if (ShuffleOrOp == Instruction::SIToFP || + ShuffleOrOp == Instruction::UIToFP) { + unsigned NumSignBits = + ComputeNumSignBits(VL0->getOperand(0), *DL, 0, AC, nullptr, DT); + if (auto *OpI = dyn_cast<Instruction>(VL0->getOperand(0))) { + APInt Mask = DB->getDemandedBits(OpI); + NumSignBits = std::max(NumSignBits, Mask.countl_zero()); + } + if (NumSignBits * 2 >= + DL->getTypeSizeInBits(VL0->getOperand(0)->getType())) + ExtraBitWidthNodes.insert(VectorizableTree.size() + 1); + } TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + ReuseShuffleIndices); LLVM_DEBUG(dbgs() << "SLP: added a vector of casts.\n"); TE->setOperandsInOrder(); - for (unsigned I : seq<unsigned>(0, VL0->getNumOperands())) { - ValueList Operands; - // Prepare the operand vector. - for (Value *V : VL) - Operands.push_back(cast<Instruction>(V)->getOperand(I)); - - buildTree_rec(Operands, Depth + 1, {TE, I}); - } + for (unsigned I : seq<unsigned>(0, VL0->getNumOperands())) + buildTree_rec(TE->getOperand(I), Depth + 1, {TE, I}); return; } case Instruction::ICmp: @@ -6072,7 +7211,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Check that all of the compares have the same predicate. CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate(); TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + ReuseShuffleIndices); LLVM_DEBUG(dbgs() << "SLP: added a vector of compares.\n"); ValueList Left, Right; @@ -6081,7 +7220,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // so that each side is more likely to have the same opcode. assert(P0 == CmpInst::getSwappedPredicate(P0) && "Commutative Predicate mismatch"); - reorderInputsAccordingToOpcode(VL, Left, Right, *TLI, *DL, *SE, *this); + reorderInputsAccordingToOpcode(VL, Left, Right, *this); } else { // Collect operands - commute if it uses the swapped predicate. 
for (Value *V : VL) { @@ -6098,6 +7237,18 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, TE->setOperand(1, Right); buildTree_rec(Left, Depth + 1, {TE, 0}); buildTree_rec(Right, Depth + 1, {TE, 1}); + if (ShuffleOrOp == Instruction::ICmp) { + unsigned NumSignBits0 = + ComputeNumSignBits(VL0->getOperand(0), *DL, 0, AC, nullptr, DT); + if (NumSignBits0 * 2 >= + DL->getTypeSizeInBits(VL0->getOperand(0)->getType())) + ExtraBitWidthNodes.insert(getOperandEntry(TE, 0)->Idx); + unsigned NumSignBits1 = + ComputeNumSignBits(VL0->getOperand(1), *DL, 0, AC, nullptr, DT); + if (NumSignBits1 * 2 >= + DL->getTypeSizeInBits(VL0->getOperand(1)->getType())) + ExtraBitWidthNodes.insert(getOperandEntry(TE, 1)->Idx); + } return; } case Instruction::Select: @@ -6121,14 +7272,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, case Instruction::Or: case Instruction::Xor: { TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + ReuseShuffleIndices); LLVM_DEBUG(dbgs() << "SLP: added a vector of un/bin op.\n"); // Sort operands of the instructions so that each side is more likely to // have the same opcode. - if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) { + if (isa<BinaryOperator>(VL0) && isCommutative(VL0)) { ValueList Left, Right; - reorderInputsAccordingToOpcode(VL, Left, Right, *TLI, *DL, *SE, *this); + reorderInputsAccordingToOpcode(VL, Left, Right, *this); TE->setOperand(0, Left); TE->setOperand(1, Right); buildTree_rec(Left, Depth + 1, {TE, 0}); @@ -6137,19 +7288,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } TE->setOperandsInOrder(); - for (unsigned I : seq<unsigned>(0, VL0->getNumOperands())) { - ValueList Operands; - // Prepare the operand vector. - for (Value *V : VL) - Operands.push_back(cast<Instruction>(V)->getOperand(I)); - - buildTree_rec(Operands, Depth + 1, {TE, I}); - } + for (unsigned I : seq<unsigned>(0, VL0->getNumOperands())) + buildTree_rec(TE->getOperand(I), Depth + 1, {TE, I}); return; } case Instruction::GetElementPtr: { TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + ReuseShuffleIndices); LLVM_DEBUG(dbgs() << "SLP: added a vector of GEPs.\n"); SmallVector<ValueList, 2> Operands(2); // Prepare the operand vector for pointer operands. @@ -6203,30 +7348,17 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, return; } case Instruction::Store: { - // Check if the stores are consecutive or if we need to swizzle them. - ValueList Operands(VL.size()); - auto *OIter = Operands.begin(); - for (Value *V : VL) { - auto *SI = cast<StoreInst>(V); - *OIter = SI->getValueOperand(); - ++OIter; - } - // Check that the sorted pointer operands are consecutive. - if (CurrentOrder.empty()) { - // Original stores are consecutive and does not require reordering. 
- TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); - TE->setOperandsInOrder(); - buildTree_rec(Operands, Depth + 1, {TE, 0}); - LLVM_DEBUG(dbgs() << "SLP: added a vector of stores.\n"); - } else { + bool Consecutive = CurrentOrder.empty(); + if (!Consecutive) fixupOrderingIndices(CurrentOrder); - TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies, CurrentOrder); - TE->setOperandsInOrder(); - buildTree_rec(Operands, Depth + 1, {TE, 0}); + TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, + ReuseShuffleIndices, CurrentOrder); + TE->setOperandsInOrder(); + buildTree_rec(TE->getOperand(0), Depth + 1, {TE, 0}); + if (Consecutive) + LLVM_DEBUG(dbgs() << "SLP: added a vector of stores.\n"); + else LLVM_DEBUG(dbgs() << "SLP: added a vector of jumbled stores.\n"); - } return; } case Instruction::Call: { @@ -6236,7 +7368,34 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + ReuseShuffleIndices); + // Sort operands of the instructions so that each side is more likely to + // have the same opcode. + if (isCommutative(VL0)) { + ValueList Left, Right; + reorderInputsAccordingToOpcode(VL, Left, Right, *this); + TE->setOperand(0, Left); + TE->setOperand(1, Right); + SmallVector<ValueList> Operands; + for (unsigned I : seq<unsigned>(2, CI->arg_size())) { + Operands.emplace_back(); + if (isVectorIntrinsicWithScalarOpAtArg(ID, I)) + continue; + for (Value *V : VL) { + auto *CI2 = cast<CallInst>(V); + Operands.back().push_back(CI2->getArgOperand(I)); + } + TE->setOperand(I, Operands.back()); + } + buildTree_rec(Left, Depth + 1, {TE, 0}); + buildTree_rec(Right, Depth + 1, {TE, 1}); + for (unsigned I : seq<unsigned>(2, CI->arg_size())) { + if (Operands[I - 2].empty()) + continue; + buildTree_rec(Operands[I - 2], Depth + 1, {TE, I}); + } + return; + } TE->setOperandsInOrder(); for (unsigned I : seq<unsigned>(0, CI->arg_size())) { // For scalar operands no need to create an entry since no need to @@ -6255,7 +7414,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } case Instruction::ShuffleVector: { TreeEntry *TE = newTreeEntry(VL, Bundle /*vectorized*/, S, UserTreeIdx, - ReuseShuffleIndicies); + ReuseShuffleIndices); LLVM_DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n"); // Reorder operands if reordering would enable vectorization. @@ -6265,8 +7424,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (!CI || all_of(VL, [](Value *V) { return cast<CmpInst>(V)->isCommutative(); })) { - reorderInputsAccordingToOpcode(VL, Left, Right, *TLI, *DL, *SE, - *this); + reorderInputsAccordingToOpcode(VL, Left, Right, *this); } else { auto *MainCI = cast<CmpInst>(S.MainOp); auto *AltCI = cast<CmpInst>(S.AltOp); @@ -6300,14 +7458,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } TE->setOperandsInOrder(); - for (unsigned I : seq<unsigned>(0, VL0->getNumOperands())) { - ValueList Operands; - // Prepare the operand vector. 
- for (Value *V : VL) - Operands.push_back(cast<Instruction>(V)->getOperand(I)); - - buildTree_rec(Operands, Depth + 1, {TE, I}); - } + for (unsigned I : seq<unsigned>(0, VL0->getNumOperands())) + buildTree_rec(TE->getOperand(I), Depth + 1, {TE, I}); return; } default: @@ -6340,7 +7492,7 @@ unsigned BoUpSLP::canMapToVector(Type *T) const { if (!isValidElementType(EltTy)) return 0; - uint64_t VTSize = DL->getTypeStoreSizeInBits(FixedVectorType::get(EltTy, N)); + uint64_t VTSize = DL->getTypeStoreSizeInBits(getWidenedType(EltTy, N)); if (VTSize < MinVecRegSize || VTSize > MaxVecRegSize || VTSize != DL->getTypeStoreSizeInBits(T)) return 0; @@ -6350,17 +7502,12 @@ unsigned BoUpSLP::canMapToVector(Type *T) const { bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue, SmallVectorImpl<unsigned> &CurrentOrder, bool ResizeAllowed) const { - const auto *It = find_if(VL, [](Value *V) { - return isa<ExtractElementInst, ExtractValueInst>(V); - }); + const auto *It = find_if(VL, IsaPred<ExtractElementInst, ExtractValueInst>); assert(It != VL.end() && "Expected at least one extract instruction."); auto *E0 = cast<Instruction>(*It); - assert(all_of(VL, - [](Value *V) { - return isa<UndefValue, ExtractElementInst, ExtractValueInst>( - V); - }) && - "Invalid opcode"); + assert( + all_of(VL, IsaPred<UndefValue, ExtractElementInst, ExtractValueInst>) && + "Invalid opcode"); // Check if all of the extracts come from the same vector and from the // correct offset. Value *Vec = E0->getOperand(0); @@ -6449,19 +7596,16 @@ bool BoUpSLP::areAllUsersVectorized( static std::pair<InstructionCost, InstructionCost> getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy, - TargetTransformInfo *TTI, TargetLibraryInfo *TLI) { + TargetTransformInfo *TTI, TargetLibraryInfo *TLI, + ArrayRef<Type *> ArgTys) { Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); // Calculate the cost of the scalar and vector calls. - SmallVector<Type *, 4> VecTys; - for (Use &Arg : CI->args()) - VecTys.push_back( - FixedVectorType::get(Arg->getType(), VecTy->getNumElements())); FastMathFlags FMF; if (auto *FPCI = dyn_cast<FPMathOperator>(CI)) FMF = FPCI->getFastMathFlags(); SmallVector<const Value *> Arguments(CI->args()); - IntrinsicCostAttributes CostAttrs(ID, VecTy, Arguments, VecTys, FMF, + IntrinsicCostAttributes CostAttrs(ID, VecTy, Arguments, ArgTys, FMF, dyn_cast<IntrinsicInst>(CI)); auto IntrinsicCost = TTI->getIntrinsicInstrCost(CostAttrs, TTI::TCK_RecipThroughput); @@ -6474,8 +7618,8 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy, if (!CI->isNoBuiltin() && VecFunc) { // Calculate the cost of the vector library call. // If the corresponding vector call is cheaper, return its cost. - LibCost = TTI->getCallInstrCost(nullptr, VecTy, VecTys, - TTI::TCK_RecipThroughput); + LibCost = + TTI->getCallInstrCost(nullptr, VecTy, ArgTys, TTI::TCK_RecipThroughput); } return {IntrinsicCost, LibCost}; } @@ -6913,12 +8057,162 @@ getShuffleCost(const TargetTransformInfo &TTI, TTI::ShuffleKind Kind, Index + NumSrcElts <= static_cast<int>(Mask.size())) return TTI.getShuffleCost( TTI::SK_InsertSubvector, - FixedVectorType::get(Tp->getElementType(), Mask.size()), std::nullopt, + getWidenedType(Tp->getElementType(), Mask.size()), Mask, TTI::TCK_RecipThroughput, Index, Tp); } return TTI.getShuffleCost(Kind, Tp, Mask, CostKind, Index, SubTp, Args); } +/// Calculate the scalar and the vector costs from vectorizing set of GEPs. 
+static std::pair<InstructionCost, InstructionCost> +getGEPCosts(const TargetTransformInfo &TTI, ArrayRef<Value *> Ptrs, + Value *BasePtr, unsigned Opcode, TTI::TargetCostKind CostKind, + Type *ScalarTy, VectorType *VecTy) { + InstructionCost ScalarCost = 0; + InstructionCost VecCost = 0; + // Here we differentiate two cases: (1) when Ptrs represent a regular + // vectorization tree node (as they are pointer arguments of scattered + // loads) or (2) when Ptrs are the arguments of loads or stores being + // vectorized as plane wide unit-stride load/store since all the + // loads/stores are known to be from/to adjacent locations. + if (Opcode == Instruction::Load || Opcode == Instruction::Store) { + // Case 2: estimate costs for pointer related costs when vectorizing to + // a wide load/store. + // Scalar cost is estimated as a set of pointers with known relationship + // between them. + // For vector code we will use BasePtr as argument for the wide load/store + // but we also need to account all the instructions which are going to + // stay in vectorized code due to uses outside of these scalar + // loads/stores. + ScalarCost = TTI.getPointersChainCost( + Ptrs, BasePtr, TTI::PointersChainInfo::getUnitStride(), ScalarTy, + CostKind); + + SmallVector<const Value *> PtrsRetainedInVecCode; + for (Value *V : Ptrs) { + if (V == BasePtr) { + PtrsRetainedInVecCode.push_back(V); + continue; + } + auto *Ptr = dyn_cast<GetElementPtrInst>(V); + // For simplicity assume Ptr to stay in vectorized code if it's not a + // GEP instruction. We don't care since it's cost considered free. + // TODO: We should check for any uses outside of vectorizable tree + // rather than just single use. + if (!Ptr || !Ptr->hasOneUse()) + PtrsRetainedInVecCode.push_back(V); + } + + if (PtrsRetainedInVecCode.size() == Ptrs.size()) { + // If all pointers stay in vectorized code then we don't have + // any savings on that. + return std::make_pair(TTI::TCC_Free, TTI::TCC_Free); + } + VecCost = TTI.getPointersChainCost(PtrsRetainedInVecCode, BasePtr, + TTI::PointersChainInfo::getKnownStride(), + VecTy, CostKind); + } else { + // Case 1: Ptrs are the arguments of loads that we are going to transform + // into masked gather load intrinsic. + // All the scalar GEPs will be removed as a result of vectorization. + // For any external uses of some lanes extract element instructions will + // be generated (which cost is estimated separately). + TTI::PointersChainInfo PtrsInfo = + all_of(Ptrs, + [](const Value *V) { + auto *Ptr = dyn_cast<GetElementPtrInst>(V); + return Ptr && !Ptr->hasAllConstantIndices(); + }) + ? 
TTI::PointersChainInfo::getUnknownStride() + : TTI::PointersChainInfo::getKnownStride(); + + ScalarCost = + TTI.getPointersChainCost(Ptrs, BasePtr, PtrsInfo, ScalarTy, CostKind); + auto *BaseGEP = dyn_cast<GEPOperator>(BasePtr); + if (!BaseGEP) { + auto *It = find_if(Ptrs, IsaPred<GEPOperator>); + if (It != Ptrs.end()) + BaseGEP = cast<GEPOperator>(*It); + } + if (BaseGEP) { + SmallVector<const Value *> Indices(BaseGEP->indices()); + VecCost = TTI.getGEPCost(BaseGEP->getSourceElementType(), + BaseGEP->getPointerOperand(), Indices, VecTy, + CostKind); + } + } + + return std::make_pair(ScalarCost, VecCost); +} + +void BoUpSLP::transformNodes() { + constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + for (std::unique_ptr<TreeEntry> &TE : VectorizableTree) { + TreeEntry &E = *TE; + switch (E.getOpcode()) { + case Instruction::Load: { + // No need to reorder masked gather loads, just reorder the scalar + // operands. + if (E.State != TreeEntry::Vectorize) + break; + Type *ScalarTy = E.getMainOp()->getType(); + auto *VecTy = getWidenedType(ScalarTy, E.Scalars.size()); + Align CommonAlignment = computeCommonAlignment<LoadInst>(E.Scalars); + // Check if profitable to represent consecutive load + reverse as strided + // load with stride -1. + if (isReverseOrder(E.ReorderIndices) && + TTI->isLegalStridedLoadStore(VecTy, CommonAlignment)) { + SmallVector<int> Mask; + inversePermutation(E.ReorderIndices, Mask); + auto *BaseLI = cast<LoadInst>(E.Scalars.back()); + InstructionCost OriginalVecCost = + TTI->getMemoryOpCost(Instruction::Load, VecTy, BaseLI->getAlign(), + BaseLI->getPointerAddressSpace(), CostKind, + TTI::OperandValueInfo()) + + ::getShuffleCost(*TTI, TTI::SK_Reverse, VecTy, Mask, CostKind); + InstructionCost StridedCost = TTI->getStridedMemoryOpCost( + Instruction::Load, VecTy, BaseLI->getPointerOperand(), + /*VariableMask=*/false, CommonAlignment, CostKind, BaseLI); + if (StridedCost < OriginalVecCost) + // Strided load is more profitable than consecutive load + reverse - + // transform the node to strided load. + E.State = TreeEntry::StridedVectorize; + } + break; + } + case Instruction::Store: { + Type *ScalarTy = + cast<StoreInst>(E.getMainOp())->getValueOperand()->getType(); + auto *VecTy = getWidenedType(ScalarTy, E.Scalars.size()); + Align CommonAlignment = computeCommonAlignment<StoreInst>(E.Scalars); + // Check if profitable to represent consecutive load + reverse as strided + // load with stride -1. + if (isReverseOrder(E.ReorderIndices) && + TTI->isLegalStridedLoadStore(VecTy, CommonAlignment)) { + SmallVector<int> Mask; + inversePermutation(E.ReorderIndices, Mask); + auto *BaseSI = cast<StoreInst>(E.Scalars.back()); + InstructionCost OriginalVecCost = + TTI->getMemoryOpCost(Instruction::Store, VecTy, BaseSI->getAlign(), + BaseSI->getPointerAddressSpace(), CostKind, + TTI::OperandValueInfo()) + + ::getShuffleCost(*TTI, TTI::SK_Reverse, VecTy, Mask, CostKind); + InstructionCost StridedCost = TTI->getStridedMemoryOpCost( + Instruction::Store, VecTy, BaseSI->getPointerOperand(), + /*VariableMask=*/false, CommonAlignment, CostKind, BaseSI); + if (StridedCost < OriginalVecCost) + // Strided load is more profitable than consecutive load + reverse - + // transform the node to strided load. + E.State = TreeEntry::StridedVectorize; + } + break; + } + default: + break; + } + } +} + /// Merges shuffle masks and emits final shuffle instruction, if required. It /// supports shuffling of 2 input vectors. 
It implements lazy shuffles emission, /// when the actual shuffle instruction is generated only if this is actually @@ -6929,6 +8223,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { bool IsFinalized = false; SmallVector<int> CommonMask; SmallVector<PointerUnion<Value *, const TreeEntry *>, 2> InVectors; + Type *ScalarTy = nullptr; const TargetTransformInfo &TTI; InstructionCost Cost = 0; SmallDenseSet<Value *> VectorizedVals; @@ -6956,15 +8251,15 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { } InstructionCost getBuildVectorCost(ArrayRef<Value *> VL, Value *Root) { - if ((!Root && allConstant(VL)) || all_of(VL, UndefValue::classof)) + if ((!Root && allConstant(VL)) || all_of(VL, IsaPred<UndefValue>)) return TTI::TCC_Free; - auto *VecTy = FixedVectorType::get(VL.front()->getType(), VL.size()); + auto *VecTy = getWidenedType(ScalarTy, VL.size()); InstructionCost GatherCost = 0; SmallVector<Value *> Gathers(VL.begin(), VL.end()); // Improve gather cost for gather of loads, if we can group some of the // loads into vector loads. InstructionsState S = getSameOpcode(VL, *R.TLI); - const unsigned Sz = R.DL->getTypeSizeInBits(VL.front()->getType()); + const unsigned Sz = R.DL->getTypeSizeInBits(ScalarTy); unsigned MinVF = R.getMinVF(2 * Sz); if (VL.size() > 2 && ((S.getOpcode() == Instruction::Load && !S.isAltShuffle()) || @@ -6978,9 +8273,10 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { }))) && !all_of(Gathers, [&](Value *V) { return R.getTreeEntry(V); }) && !isSplat(Gathers)) { + InstructionCost BaseCost = R.getGatherCost(Gathers, !Root, ScalarTy); SetVector<Value *> VectorizedLoads; - SmallVector<LoadInst *> VectorizedStarts; - SmallVector<std::pair<unsigned, unsigned>> ScatterVectorized; + SmallVector<std::pair<unsigned, LoadsState>> VectorizedStarts; + SmallVector<unsigned> ScatterVectorized; unsigned StartIdx = 0; unsigned VF = VL.size() / 2; for (; VF >= MinVF; VF /= 2) { @@ -6997,20 +8293,23 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { !VectorizedLoads.count(Slice.back()) && allSameBlock(Slice)) { SmallVector<Value *> PointerOps; OrdersType CurrentOrder; - LoadsState LS = - canVectorizeLoads(Slice, Slice.front(), TTI, *R.DL, *R.SE, - *R.LI, *R.TLI, CurrentOrder, PointerOps); + LoadsState LS = R.canVectorizeLoads(Slice, Slice.front(), + CurrentOrder, PointerOps); switch (LS) { case LoadsState::Vectorize: case LoadsState::ScatterVectorize: - case LoadsState::PossibleStridedVectorize: + case LoadsState::StridedVectorize: // Mark the vectorized loads so that we don't vectorize them // again. // TODO: better handling of loads with reorders. - if (LS == LoadsState::Vectorize && CurrentOrder.empty()) - VectorizedStarts.push_back(cast<LoadInst>(Slice.front())); + if (((LS == LoadsState::Vectorize || + LS == LoadsState::StridedVectorize) && + CurrentOrder.empty()) || + (LS == LoadsState::StridedVectorize && + isReverseOrder(CurrentOrder))) + VectorizedStarts.emplace_back(Cnt, LS); else - ScatterVectorized.emplace_back(Cnt, VF); + ScatterVectorized.push_back(Cnt); VectorizedLoads.insert(Slice.begin(), Slice.end()); // If we vectorized initial block, no need to try to vectorize // it again. 
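
The strided-load handling added above (LoadsState::StridedVectorize in the load grouping, and the transformNodes() rewrite of a consecutive load whose lanes are used in reverse into a stride of -1) reduces to a simple cost comparison. Below is a minimal standalone sketch of that decision, not LLVM code: the cost numbers are hypothetical placeholders rather than TTI queries, and pickLoadLowering/NodeState are illustrative names.

#include <cstdio>
#include <vector>

// Mirrors isReverseOrder(): the reorder indices must be N-1, N-2, ..., 0.
static bool isReverseOrder(const std::vector<unsigned> &Order) {
  unsigned Sz = Order.size();
  for (unsigned I = 0; I < Sz; ++I)
    if (Order[I] != Sz - 1 - I)
      return false;
  return true;
}

enum class NodeState { Vectorize, StridedVectorize };

// Keep the consecutive wide load unless reading the lanes back-to-front with
// a stride of -1 is estimated to be strictly cheaper than load + reverse.
static NodeState pickLoadLowering(const std::vector<unsigned> &Order,
                                  int WideLoadCost, int ReverseShuffleCost,
                                  int StridedLoadCost) {
  if (!isReverseOrder(Order))
    return NodeState::Vectorize;
  return StridedLoadCost < WideLoadCost + ReverseShuffleCost
             ? NodeState::StridedVectorize
             : NodeState::Vectorize;
}

int main() {
  std::vector<unsigned> Order{3, 2, 1, 0}; // lanes consumed in reverse
  NodeState S = pickLoadLowering(Order, /*WideLoadCost=*/4,
                                 /*ReverseShuffleCost=*/2,
                                 /*StridedLoadCost=*/5);
  std::printf("strided chosen: %d\n", S == NodeState::StridedVectorize);
  return 0;
}
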
@@ -7037,7 +8336,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { for (unsigned I = 0, End = VL.size(); I < End; I += VF) { if (VectorizedLoads.contains(VL[I])) continue; - GatherCost += getBuildVectorCost(VL.slice(I, VF), Root); + GatherCost += + getBuildVectorCost(VL.slice(I, std::min(End - I, VF)), Root); } // Exclude potentially vectorized loads from list of gathered // scalars. @@ -7051,57 +8351,104 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { LI->getAlign(), LI->getPointerAddressSpace(), CostKind, TTI::OperandValueInfo(), LI); } - auto *LoadTy = FixedVectorType::get(VL.front()->getType(), VF); - for (LoadInst *LI : VectorizedStarts) { + auto *LoadTy = getWidenedType(VL.front()->getType(), VF); + for (const std::pair<unsigned, LoadsState> &P : VectorizedStarts) { + auto *LI = cast<LoadInst>(VL[P.first]); Align Alignment = LI->getAlign(); GatherCost += - TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, - LI->getPointerAddressSpace(), CostKind, - TTI::OperandValueInfo(), LI); + P.second == LoadsState::Vectorize + ? TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, + LI->getPointerAddressSpace(), CostKind, + TTI::OperandValueInfo(), LI) + : TTI.getStridedMemoryOpCost( + Instruction::Load, LoadTy, LI->getPointerOperand(), + /*VariableMask=*/false, Alignment, CostKind, LI); + // Estimate GEP cost. + SmallVector<Value *> PointerOps(VF); + for (auto [I, V] : enumerate(VL.slice(P.first, VF))) + PointerOps[I] = cast<LoadInst>(V)->getPointerOperand(); + auto [ScalarGEPCost, VectorGEPCost] = + getGEPCosts(TTI, PointerOps, LI->getPointerOperand(), + Instruction::Load, CostKind, LI->getType(), LoadTy); + GatherCost += VectorGEPCost - ScalarGEPCost; } - for (std::pair<unsigned, unsigned> P : ScatterVectorized) { - auto *LI0 = cast<LoadInst>(VL[P.first]); - Align CommonAlignment = LI0->getAlign(); - for (Value *V : VL.slice(P.first + 1, VF - 1)) - CommonAlignment = - std::min(CommonAlignment, cast<LoadInst>(V)->getAlign()); + for (unsigned P : ScatterVectorized) { + auto *LI0 = cast<LoadInst>(VL[P]); + ArrayRef<Value *> Slice = VL.slice(P, VF); + Align CommonAlignment = computeCommonAlignment<LoadInst>(Slice); GatherCost += TTI.getGatherScatterOpCost( Instruction::Load, LoadTy, LI0->getPointerOperand(), /*VariableMask=*/false, CommonAlignment, CostKind, LI0); + // Estimate GEP cost. + SmallVector<Value *> PointerOps(VF); + for (auto [I, V] : enumerate(Slice)) + PointerOps[I] = cast<LoadInst>(V)->getPointerOperand(); + OrdersType Order; + if (sortPtrAccesses(PointerOps, LI0->getType(), *R.DL, *R.SE, + Order)) { + // TODO: improve checks if GEPs can be vectorized. + Value *Ptr0 = PointerOps.front(); + Type *ScalarTy = Ptr0->getType(); + auto *VecTy = getWidenedType(ScalarTy, VF); + auto [ScalarGEPCost, VectorGEPCost] = + getGEPCosts(TTI, PointerOps, Ptr0, Instruction::GetElementPtr, + CostKind, ScalarTy, VecTy); + GatherCost += VectorGEPCost - ScalarGEPCost; + if (!Order.empty()) { + SmallVector<int> Mask; + inversePermutation(Order, Mask); + GatherCost += ::getShuffleCost(TTI, TTI::SK_PermuteSingleSrc, + VecTy, Mask, CostKind); + } + } else { + GatherCost += R.getGatherCost(PointerOps, /*ForPoisonSrc=*/true, + PointerOps.front()->getType()); + } } if (NeedInsertSubvectorAnalysis) { // Add the cost for the subvectors insert. 
- for (int I = VF, E = VL.size(); I < E; I += VF) + SmallVector<int> ShuffleMask(VL.size()); + for (unsigned I = VF, E = VL.size(); I < E; I += VF) { + for (unsigned Idx : seq<unsigned>(0, E)) + ShuffleMask[Idx] = Idx / VF == I ? E + Idx % VF : Idx; GatherCost += TTI.getShuffleCost(TTI::SK_InsertSubvector, VecTy, - std::nullopt, CostKind, I, LoadTy); + ShuffleMask, CostKind, I, LoadTy); + } } GatherCost -= ScalarsCost; } + GatherCost = std::min(BaseCost, GatherCost); } else if (!Root && isSplat(VL)) { // Found the broadcasting of the single scalar, calculate the cost as // the broadcast. - const auto *It = - find_if(VL, [](Value *V) { return !isa<UndefValue>(V); }); + const auto *It = find_if_not(VL, IsaPred<UndefValue>); assert(It != VL.end() && "Expected at least one non-undef value."); // Add broadcast for non-identity shuffle only. bool NeedShuffle = count(VL, *It) > 1 && - (VL.front() != *It || !all_of(VL.drop_front(), UndefValue::classof)); - InstructionCost InsertCost = TTI.getVectorInstrCost( - Instruction::InsertElement, VecTy, CostKind, - NeedShuffle ? 0 : std::distance(VL.begin(), It), - PoisonValue::get(VecTy), *It); - return InsertCost + - (NeedShuffle ? TTI.getShuffleCost( - TargetTransformInfo::SK_Broadcast, VecTy, - /*Mask=*/std::nullopt, CostKind, /*Index=*/0, - /*SubTp=*/nullptr, /*Args=*/*It) - : TTI::TCC_Free); + (VL.front() != *It || !all_of(VL.drop_front(), IsaPred<UndefValue>)); + if (!NeedShuffle) + return TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, + CostKind, std::distance(VL.begin(), It), + PoisonValue::get(VecTy), *It); + + SmallVector<int> ShuffleMask(VL.size(), PoisonMaskElem); + transform(VL, ShuffleMask.begin(), [](Value *V) { + return isa<PoisonValue>(V) ? PoisonMaskElem : 0; + }); + InstructionCost InsertCost = + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0, + PoisonValue::get(VecTy), *It); + return InsertCost + TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, + VecTy, ShuffleMask, CostKind, + /*Index=*/0, /*SubTp=*/nullptr, + /*Args=*/*It); } return GatherCost + - (all_of(Gathers, UndefValue::classof) + (all_of(Gathers, IsaPred<UndefValue>) ? TTI::TCC_Free - : R.getGatherCost(Gathers, !Root && VL.equals(Gathers))); + : R.getGatherCost(Gathers, !Root && VL.equals(Gathers), + ScalarTy)); }; /// Compute the cost of creating a vector containing the extracted values from @@ -7116,34 +8463,64 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { auto *EE = dyn_cast<ExtractElementInst>(V); if (!EE) return Sz; - auto *VecTy = cast<FixedVectorType>(EE->getVectorOperandType()); + auto *VecTy = dyn_cast<FixedVectorType>(EE->getVectorOperandType()); + if (!VecTy) + return Sz; return std::max(Sz, VecTy->getNumElements()); }); - unsigned NumSrcRegs = TTI.getNumberOfParts( - FixedVectorType::get(VL.front()->getType(), NumElts)); - if (NumSrcRegs == 0) - NumSrcRegs = 1; // FIXME: this must be moved to TTI for better estimation. 
- unsigned EltsPerVector = PowerOf2Ceil(std::max( - divideCeil(VL.size(), NumParts), divideCeil(NumElts, NumSrcRegs))); - auto CheckPerRegistersShuffle = - [&](MutableArrayRef<int> Mask) -> std::optional<TTI::ShuffleKind> { + unsigned EltsPerVector = getPartNumElems(VL.size(), NumParts); + auto CheckPerRegistersShuffle = [&](MutableArrayRef<int> Mask, + SmallVectorImpl<unsigned> &Indices) + -> std::optional<TTI::ShuffleKind> { + if (NumElts <= EltsPerVector) + return std::nullopt; + int OffsetReg0 = + alignDown(std::accumulate(Mask.begin(), Mask.end(), INT_MAX, + [](int S, int I) { + if (I == PoisonMaskElem) + return S; + return std::min(S, I); + }), + EltsPerVector); + int OffsetReg1 = OffsetReg0; DenseSet<int> RegIndices; // Check that if trying to permute same single/2 input vectors. TTI::ShuffleKind ShuffleKind = TTI::SK_PermuteSingleSrc; int FirstRegId = -1; - for (int &I : Mask) { + Indices.assign(1, OffsetReg0); + for (auto [Pos, I] : enumerate(Mask)) { if (I == PoisonMaskElem) continue; - int RegId = (I / NumElts) * NumParts + (I % NumElts) / EltsPerVector; + int Idx = I - OffsetReg0; + int RegId = + (Idx / NumElts) * NumParts + (Idx % NumElts) / EltsPerVector; if (FirstRegId < 0) FirstRegId = RegId; RegIndices.insert(RegId); if (RegIndices.size() > 2) return std::nullopt; - if (RegIndices.size() == 2) + if (RegIndices.size() == 2) { ShuffleKind = TTI::SK_PermuteTwoSrc; - I = (I % NumElts) % EltsPerVector + + if (Indices.size() == 1) { + OffsetReg1 = alignDown( + std::accumulate( + std::next(Mask.begin(), Pos), Mask.end(), INT_MAX, + [&](int S, int I) { + if (I == PoisonMaskElem) + return S; + int RegId = ((I - OffsetReg0) / NumElts) * NumParts + + ((I - OffsetReg0) % NumElts) / EltsPerVector; + if (RegId == FirstRegId) + return S; + return std::min(S, I); + }), + EltsPerVector); + Indices.push_back(OffsetReg1 % NumElts); + } + Idx = I - OffsetReg1; + } + I = (Idx % NumElts) % EltsPerVector + (RegId == FirstRegId ? 0 : EltsPerVector); } return ShuffleKind; @@ -7153,31 +8530,48 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { // Process extracts in blocks of EltsPerVector to check if the source vector // operand can be re-used directly. If not, add the cost of creating a // shuffle to extract the values into a vector register. - for (unsigned Part = 0; Part < NumParts; ++Part) { + for (unsigned Part : seq<unsigned>(NumParts)) { if (!ShuffleKinds[Part]) continue; - ArrayRef<int> MaskSlice = - Mask.slice(Part * EltsPerVector, - (Part == NumParts - 1 && Mask.size() % EltsPerVector != 0) - ? 
Mask.size() % EltsPerVector - : EltsPerVector); + ArrayRef<int> MaskSlice = Mask.slice( + Part * EltsPerVector, getNumElems(Mask.size(), EltsPerVector, Part)); SmallVector<int> SubMask(EltsPerVector, PoisonMaskElem); copy(MaskSlice, SubMask.begin()); + SmallVector<unsigned, 2> Indices; std::optional<TTI::ShuffleKind> RegShuffleKind = - CheckPerRegistersShuffle(SubMask); + CheckPerRegistersShuffle(SubMask, Indices); if (!RegShuffleKind) { - Cost += ::getShuffleCost( - TTI, *ShuffleKinds[Part], - FixedVectorType::get(VL.front()->getType(), NumElts), MaskSlice); + if (*ShuffleKinds[Part] != TTI::SK_PermuteSingleSrc || + !ShuffleVectorInst::isIdentityMask( + MaskSlice, std::max<unsigned>(NumElts, MaskSlice.size()))) + Cost += + ::getShuffleCost(TTI, *ShuffleKinds[Part], + getWidenedType(ScalarTy, NumElts), MaskSlice); continue; } if (*RegShuffleKind != TTI::SK_PermuteSingleSrc || !ShuffleVectorInst::isIdentityMask(SubMask, EltsPerVector)) { - Cost += ::getShuffleCost( - TTI, *RegShuffleKind, - FixedVectorType::get(VL.front()->getType(), EltsPerVector), - SubMask); + Cost += + ::getShuffleCost(TTI, *RegShuffleKind, + getWidenedType(ScalarTy, EltsPerVector), SubMask); } + for (unsigned Idx : Indices) { + assert((Idx + EltsPerVector) <= alignTo(NumElts, EltsPerVector) && + "SK_ExtractSubvector index out of range"); + Cost += ::getShuffleCost( + TTI, TTI::SK_ExtractSubvector, + getWidenedType(ScalarTy, alignTo(NumElts, EltsPerVector)), + std::nullopt, CostKind, Idx, + getWidenedType(ScalarTy, EltsPerVector)); + } + // Second attempt to check, if just a permute is better estimated than + // subvector extract. + SubMask.assign(NumElts, PoisonMaskElem); + copy(MaskSlice, SubMask.begin()); + InstructionCost OriginalCost = ::getShuffleCost( + TTI, *ShuffleKinds[Part], getWidenedType(ScalarTy, NumElts), SubMask); + if (OriginalCost < Cost) + Cost = OriginalCost; } return Cost; } @@ -7205,11 +8599,11 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { InVectors.front().get<const TreeEntry *>() == &E1 && InVectors.back().get<const TreeEntry *>() == E2) || (!E2 && InVectors.front().get<const TreeEntry *>() == &E1)) { - assert(all_of(ArrayRef(CommonMask).slice(Part * SliceSize, SliceSize), + unsigned Limit = getNumElems(Mask.size(), SliceSize, Part); + assert(all_of(ArrayRef(CommonMask).slice(Part * SliceSize, Limit), [](int Idx) { return Idx == PoisonMaskElem; }) && "Expected all poisoned elements."); - ArrayRef<int> SubMask = - ArrayRef(Mask).slice(Part * SliceSize, SliceSize); + ArrayRef<int> SubMask = ArrayRef(Mask).slice(Part * SliceSize, Limit); copy(SubMask, std::next(CommonMask.begin(), SliceSize * Part)); return; } @@ -7221,8 +8615,24 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { transformMaskAfterShuffle(CommonMask, CommonMask); } SameNodesEstimated = false; - Cost += createShuffle(&E1, E2, Mask); - transformMaskAfterShuffle(CommonMask, Mask); + if (!E2 && InVectors.size() == 1) { + unsigned VF = E1.getVectorFactor(); + if (Value *V1 = InVectors.front().dyn_cast<Value *>()) { + VF = std::max(VF, + cast<FixedVectorType>(V1->getType())->getNumElements()); + } else { + const auto *E = InVectors.front().get<const TreeEntry *>(); + VF = std::max(VF, E->getVectorFactor()); + } + for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) + if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem) + CommonMask[Idx] = Mask[Idx] + VF; + Cost += createShuffle(InVectors.front(), &E1, CommonMask); + transformMaskAfterShuffle(CommonMask, CommonMask); + 
} else { + Cost += createShuffle(&E1, E2, Mask); + transformMaskAfterShuffle(CommonMask, Mask); + } } class ShuffleCostBuilder { @@ -7277,6 +8687,47 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { SmallVector<int> CommonMask(Mask.begin(), Mask.end()); Value *V1 = P1.dyn_cast<Value *>(), *V2 = P2.dyn_cast<Value *>(); unsigned CommonVF = Mask.size(); + InstructionCost ExtraCost = 0; + auto GetNodeMinBWAffectedCost = [&](const TreeEntry &E, + unsigned VF) -> InstructionCost { + if (E.isGather() && allConstant(E.Scalars)) + return TTI::TCC_Free; + Type *EScalarTy = E.Scalars.front()->getType(); + bool IsSigned = true; + if (auto It = R.MinBWs.find(&E); It != R.MinBWs.end()) { + EScalarTy = IntegerType::get(EScalarTy->getContext(), It->second.first); + IsSigned = It->second.second; + } + if (EScalarTy != ScalarTy) { + unsigned CastOpcode = Instruction::Trunc; + unsigned DstSz = R.DL->getTypeSizeInBits(ScalarTy); + unsigned SrcSz = R.DL->getTypeSizeInBits(EScalarTy); + if (DstSz > SrcSz) + CastOpcode = IsSigned ? Instruction::SExt : Instruction::ZExt; + return TTI.getCastInstrCost(CastOpcode, getWidenedType(ScalarTy, VF), + getWidenedType(EScalarTy, VF), + TTI::CastContextHint::None, CostKind); + } + return TTI::TCC_Free; + }; + auto GetValueMinBWAffectedCost = [&](const Value *V) -> InstructionCost { + if (isa<Constant>(V)) + return TTI::TCC_Free; + auto *VecTy = cast<VectorType>(V->getType()); + Type *EScalarTy = VecTy->getElementType(); + if (EScalarTy != ScalarTy) { + bool IsSigned = !isKnownNonNegative(V, SimplifyQuery(*R.DL)); + unsigned CastOpcode = Instruction::Trunc; + unsigned DstSz = R.DL->getTypeSizeInBits(ScalarTy); + unsigned SrcSz = R.DL->getTypeSizeInBits(EScalarTy); + if (DstSz > SrcSz) + CastOpcode = IsSigned ? Instruction::SExt : Instruction::ZExt; + return TTI.getCastInstrCost( + CastOpcode, VectorType::get(ScalarTy, VecTy->getElementCount()), + VecTy, TTI::CastContextHint::None, CostKind); + } + return TTI::TCC_Free; + }; if (!V1 && !V2 && !P2.isNull()) { // Shuffle 2 entry nodes. const TreeEntry *E = P1.get<const TreeEntry *>(); @@ -7303,11 +8754,14 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { } } CommonVF = E->Scalars.size(); + ExtraCost += GetNodeMinBWAffectedCost(*E, CommonVF) + + GetNodeMinBWAffectedCost(*E2, CommonVF); + } else { + ExtraCost += GetNodeMinBWAffectedCost(*E, E->getVectorFactor()) + + GetNodeMinBWAffectedCost(*E2, E2->getVectorFactor()); } - V1 = Constant::getNullValue( - FixedVectorType::get(E->Scalars.front()->getType(), CommonVF)); - V2 = getAllOnesValue( - *R.DL, FixedVectorType::get(E->Scalars.front()->getType(), CommonVF)); + V1 = Constant::getNullValue(getWidenedType(ScalarTy, CommonVF)); + V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF)); } else if (!V1 && P2.isNull()) { // Shuffle single entry node. const TreeEntry *E = P1.get<const TreeEntry *>(); @@ -7326,10 +8780,25 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { } CommonVF = E->Scalars.size(); } - V1 = Constant::getNullValue( - FixedVectorType::get(E->Scalars.front()->getType(), CommonVF)); + ExtraCost += GetNodeMinBWAffectedCost(*E, CommonVF); + V1 = Constant::getNullValue(getWidenedType(ScalarTy, CommonVF)); + // Not identity/broadcast? Try to see if the original vector is better. 
+ if (!E->ReorderIndices.empty() && CommonVF == E->ReorderIndices.size() && + CommonVF == CommonMask.size() && + any_of(enumerate(CommonMask), + [](const auto &&P) { + return P.value() != PoisonMaskElem && + static_cast<unsigned>(P.value()) != P.index(); + }) && + any_of(CommonMask, + [](int Idx) { return Idx != PoisonMaskElem && Idx != 0; })) { + SmallVector<int> ReorderMask; + inversePermutation(E->ReorderIndices, ReorderMask); + ::addMask(CommonMask, ReorderMask); + } } else if (V1 && P2.isNull()) { // Shuffle single vector. + ExtraCost += GetValueMinBWAffectedCost(V1); CommonVF = cast<FixedVectorType>(V1->getType())->getNumElements(); assert( all_of(Mask, @@ -7356,11 +8825,11 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { } CommonVF = VF; } - V1 = Constant::getNullValue( - FixedVectorType::get(E2->Scalars.front()->getType(), CommonVF)); - V2 = getAllOnesValue( - *R.DL, - FixedVectorType::get(E2->Scalars.front()->getType(), CommonVF)); + ExtraCost += GetValueMinBWAffectedCost(V1); + V1 = Constant::getNullValue(getWidenedType(ScalarTy, CommonVF)); + ExtraCost += GetNodeMinBWAffectedCost( + *E2, std::min(CommonVF, E2->getVectorFactor())); + V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF)); } else if (!V1 && V2) { // Shuffle vector and tree node. unsigned VF = cast<FixedVectorType>(V2->getType())->getNumElements(); @@ -7384,11 +8853,11 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { } CommonVF = VF; } - V1 = Constant::getNullValue( - FixedVectorType::get(E1->Scalars.front()->getType(), CommonVF)); - V2 = getAllOnesValue( - *R.DL, - FixedVectorType::get(E1->Scalars.front()->getType(), CommonVF)); + ExtraCost += GetNodeMinBWAffectedCost( + *E1, std::min(CommonVF, E1->getVectorFactor())); + V1 = Constant::getNullValue(getWidenedType(ScalarTy, CommonVF)); + ExtraCost += GetValueMinBWAffectedCost(V2); + V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF)); } else { assert(V1 && V2 && "Expected both vectors."); unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements(); @@ -7399,30 +8868,33 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis { return Idx < 2 * static_cast<int>(CommonVF); }) && "All elements in mask must be less than 2 * CommonVF."); + ExtraCost += + GetValueMinBWAffectedCost(V1) + GetValueMinBWAffectedCost(V2); if (V1->getType() != V2->getType()) { - V1 = Constant::getNullValue(FixedVectorType::get( - cast<FixedVectorType>(V1->getType())->getElementType(), CommonVF)); - V2 = getAllOnesValue( - *R.DL, FixedVectorType::get( - cast<FixedVectorType>(V1->getType())->getElementType(), - CommonVF)); + V1 = Constant::getNullValue(getWidenedType(ScalarTy, CommonVF)); + V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF)); + } else { + if (cast<VectorType>(V1->getType())->getElementType() != ScalarTy) + V1 = Constant::getNullValue(getWidenedType(ScalarTy, CommonVF)); + if (cast<VectorType>(V2->getType())->getElementType() != ScalarTy) + V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF)); } } - InVectors.front() = Constant::getNullValue(FixedVectorType::get( - cast<FixedVectorType>(V1->getType())->getElementType(), - CommonMask.size())); + InVectors.front() = + Constant::getNullValue(getWidenedType(ScalarTy, CommonMask.size())); if (InVectors.size() == 2) InVectors.pop_back(); - return BaseShuffleAnalysis::createShuffle<InstructionCost>( - V1, V2, CommonMask, Builder); + return ExtraCost + BaseShuffleAnalysis::createShuffle<InstructionCost>( + V1, V2, CommonMask, 
Builder); } public: - ShuffleCostEstimator(TargetTransformInfo &TTI, + ShuffleCostEstimator(Type *ScalarTy, TargetTransformInfo &TTI, ArrayRef<Value *> VectorizedVals, BoUpSLP &R, SmallPtrSetImpl<Value *> &CheckedExtracts) - : TTI(TTI), VectorizedVals(VectorizedVals.begin(), VectorizedVals.end()), - R(R), CheckedExtracts(CheckedExtracts) {} + : ScalarTy(ScalarTy), TTI(TTI), + VectorizedVals(VectorizedVals.begin(), VectorizedVals.end()), R(R), + CheckedExtracts(CheckedExtracts) {} Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask, ArrayRef<std::optional<TTI::ShuffleKind>> ShuffleKinds, unsigned NumParts, bool &UseVecBaseAsInput) { @@ -7441,7 +8913,7 @@ public: [&](const std::unique_ptr<TreeEntry> &TE) { return ((!TE->isAltShuffle() && TE->getOpcode() == Instruction::ExtractElement) || - TE->State == TreeEntry::NeedToGather) && + TE->isGather()) && all_of(enumerate(TE->Scalars), [&](auto &&Data) { return VL.size() > Data.index() && (Mask[Data.index()] == PoisonMaskElem || @@ -7450,10 +8922,11 @@ public: }); }); SmallPtrSet<Value *, 4> UniqueBases; - unsigned SliceSize = VL.size() / NumParts; - for (unsigned Part = 0; Part < NumParts; ++Part) { - ArrayRef<int> SubMask = Mask.slice(Part * SliceSize, SliceSize); - for (auto [I, V] : enumerate(VL.slice(Part * SliceSize, SliceSize))) { + unsigned SliceSize = getPartNumElems(VL.size(), NumParts); + for (unsigned Part : seq<unsigned>(NumParts)) { + unsigned Limit = getNumElems(VL.size(), SliceSize, Part); + ArrayRef<int> SubMask = Mask.slice(Part * SliceSize, Limit); + for (auto [I, V] : enumerate(VL.slice(Part * SliceSize, Limit))) { // Ignore non-extractelement scalars. if (isa<UndefValue>(V) || (!SubMask.empty() && SubMask[I] == PoisonMaskElem)) @@ -7470,6 +8943,12 @@ public: const TreeEntry *VE = R.getTreeEntry(V); if (!CheckedExtracts.insert(V).second || !R.areAllUsersVectorized(cast<Instruction>(V), &VectorizedVals) || + any_of(EE->users(), + [&](User *U) { + return isa<GetElementPtrInst>(U) && + !R.areAllUsersVectorized(cast<Instruction>(U), + &VectorizedVals); + }) || (VE && VE != E)) continue; std::optional<unsigned> EEIdx = getExtractIndex(EE); @@ -7479,9 +8958,8 @@ public: // Take credit for instruction that will become dead. if (EE->hasOneUse() || !PrevNodeFound) { Instruction *Ext = EE->user_back(); - if (isa<SExtInst, ZExtInst>(Ext) && all_of(Ext->users(), [](User *U) { - return isa<GetElementPtrInst>(U); - })) { + if (isa<SExtInst, ZExtInst>(Ext) && + all_of(Ext->users(), IsaPred<GetElementPtrInst>)) { // Use getExtractWithExtendCost() to calculate the cost of // extractelement/ext pair. 
Cost -= @@ -7512,8 +8990,8 @@ public: SameNodesEstimated = false; if (NumParts != 1 && UniqueBases.size() != 1) { UseVecBaseAsInput = true; - VecBase = Constant::getNullValue( - FixedVectorType::get(VL.front()->getType(), CommonMask.size())); + VecBase = + Constant::getNullValue(getWidenedType(ScalarTy, CommonMask.size())); } return VecBase; } @@ -7541,12 +9019,11 @@ public: return; } assert(!CommonMask.empty() && "Expected non-empty common mask."); - auto *MaskVecTy = - FixedVectorType::get(E1.Scalars.front()->getType(), Mask.size()); + auto *MaskVecTy = getWidenedType(ScalarTy, Mask.size()); unsigned NumParts = TTI.getNumberOfParts(MaskVecTy); if (NumParts == 0 || NumParts >= Mask.size()) NumParts = 1; - unsigned SliceSize = Mask.size() / NumParts; + unsigned SliceSize = getPartNumElems(Mask.size(), NumParts); const auto *It = find_if(Mask, [](int Idx) { return Idx != PoisonMaskElem; }); unsigned Part = std::distance(Mask.begin(), It) / SliceSize; @@ -7559,12 +9036,11 @@ public: return; } assert(!CommonMask.empty() && "Expected non-empty common mask."); - auto *MaskVecTy = - FixedVectorType::get(E1.Scalars.front()->getType(), Mask.size()); + auto *MaskVecTy = getWidenedType(ScalarTy, Mask.size()); unsigned NumParts = TTI.getNumberOfParts(MaskVecTy); if (NumParts == 0 || NumParts >= Mask.size()) NumParts = 1; - unsigned SliceSize = Mask.size() / NumParts; + unsigned SliceSize = getPartNumElems(Mask.size(), NumParts); const auto *It = find_if(Mask, [](int Idx) { return Idx != PoisonMaskElem; }); unsigned Part = std::distance(Mask.begin(), It) / SliceSize; @@ -7660,7 +9136,7 @@ public: return ConstantVector::getSplat( ElementCount::getFixed( cast<FixedVectorType>(Root->getType())->getNumElements()), - getAllOnesValue(*R.DL, VL.front()->getType())); + getAllOnesValue(*R.DL, ScalarTy)); } InstructionCost createFreeze(InstructionCost Cost) { return Cost; } /// Finalize emission of the shuffles. 
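
The extract-element costing above slices the gather mask into register-sized parts and, per part, asks whether the lanes can be served by one or two source registers (CheckPerRegistersShuffle). A simplified, self-contained sketch of that classification follows; it assumes a single concatenated source numbering, drops the OffsetReg re-basing done in the real code, and uses illustrative names (classifySlice, Kind).

#include <optional>
#include <set>
#include <vector>

enum class Kind { SingleSrc, TwoSrc };
constexpr int PoisonMaskElem = -1; // same sentinel the mask code uses

static std::optional<Kind> classifySlice(const std::vector<int> &MaskSlice,
                                         unsigned EltsPerVector) {
  std::set<int> Regs;
  for (int I : MaskSlice) {
    if (I == PoisonMaskElem)
      continue; // undefined lanes constrain nothing
    Regs.insert(I / static_cast<int>(EltsPerVector)); // source register of lane
    if (Regs.size() > 2)
      return std::nullopt; // needs more than two registers: no cheap permute
  }
  return Regs.size() <= 1 ? Kind::SingleSrc : Kind::TwoSrc;
}

// E.g. with EltsPerVector = 4, {0, 1, 5, PoisonMaskElem} touches registers
// 0 and 1 and classifies as TwoSrc; {8, 9, 10, 11} stays SingleSrc.
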
@@ -7704,7 +9180,7 @@ const BoUpSLP::TreeEntry *BoUpSLP::getOperandEntry(const TreeEntry *E, unsigned Idx) const { Value *Op = E->getOperand(Idx).front(); if (const TreeEntry *TE = getTreeEntry(Op)) { - if (find_if(E->UserTreeIndices, [&](const EdgeInfo &EI) { + if (find_if(TE->UserTreeIndices, [&](const EdgeInfo &EI) { return EI.EdgeIdx == Idx && EI.UserTE == E; }) != TE->UserTreeIndices.end()) return TE; @@ -7720,7 +9196,7 @@ const BoUpSLP::TreeEntry *BoUpSLP::getOperandEntry(const TreeEntry *E, } const auto *It = find_if(VectorizableTree, [&](const std::unique_ptr<TreeEntry> &TE) { - return TE->State == TreeEntry::NeedToGather && + return TE->isGather() && find_if(TE->UserTreeIndices, [&](const EdgeInfo &EI) { return EI.EdgeIdx == Idx && EI.UserTE == E; }) != TE->UserTreeIndices.end(); @@ -7729,13 +9205,53 @@ const BoUpSLP::TreeEntry *BoUpSLP::getOperandEntry(const TreeEntry *E, return It->get(); } +TTI::CastContextHint BoUpSLP::getCastContextHint(const TreeEntry &TE) const { + if (TE.State == TreeEntry::ScatterVectorize || + TE.State == TreeEntry::StridedVectorize) + return TTI::CastContextHint::GatherScatter; + if (TE.State == TreeEntry::Vectorize && TE.getOpcode() == Instruction::Load && + !TE.isAltShuffle()) { + if (TE.ReorderIndices.empty()) + return TTI::CastContextHint::Normal; + SmallVector<int> Mask; + inversePermutation(TE.ReorderIndices, Mask); + if (ShuffleVectorInst::isReverseMask(Mask, Mask.size())) + return TTI::CastContextHint::Reversed; + } + return TTI::CastContextHint::None; +} + +/// Builds the arguments types vector for the given call instruction with the +/// given \p ID for the specified vector factor. +static SmallVector<Type *> buildIntrinsicArgTypes(const CallInst *CI, + const Intrinsic::ID ID, + const unsigned VF, + unsigned MinBW) { + SmallVector<Type *> ArgTys; + for (auto [Idx, Arg] : enumerate(CI->args())) { + if (ID != Intrinsic::not_intrinsic) { + if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) { + ArgTys.push_back(Arg->getType()); + continue; + } + if (MinBW > 0) { + ArgTys.push_back( + getWidenedType(IntegerType::get(CI->getContext(), MinBW), VF)); + continue; + } + } + ArgTys.push_back(getWidenedType(Arg->getType(), VF)); + } + return ArgTys; +} + InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, SmallPtrSetImpl<Value *> &CheckedExtracts) { ArrayRef<Value *> VL = E->Scalars; Type *ScalarTy = VL[0]->getType(); - if (E->State != TreeEntry::NeedToGather) { + if (!E->isGather()) { if (auto *SI = dyn_cast<StoreInst>(VL[0])) ScalarTy = SI->getValueOperand()->getType(); else if (auto *CI = dyn_cast<CmpInst>(VL[0])) @@ -7743,34 +9259,34 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, else if (auto *IE = dyn_cast<InsertElementInst>(VL[0])) ScalarTy = IE->getOperand(1)->getType(); } - if (!FixedVectorType::isValidElementType(ScalarTy)) + if (!isValidElementType(ScalarTy)) return InstructionCost::getInvalid(); - auto *VecTy = FixedVectorType::get(ScalarTy, VL.size()); TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; // If we have computed a smaller type for the expression, update VecTy so // that the costs will be accurate. 
auto It = MinBWs.find(E); - if (It != MinBWs.end()) { + Type *OrigScalarTy = ScalarTy; + if (It != MinBWs.end()) ScalarTy = IntegerType::get(F->getContext(), It->second.first); - VecTy = FixedVectorType::get(ScalarTy, VL.size()); - } + auto *VecTy = getWidenedType(ScalarTy, VL.size()); unsigned EntryVF = E->getVectorFactor(); - auto *FinalVecTy = FixedVectorType::get(ScalarTy, EntryVF); + auto *FinalVecTy = getWidenedType(ScalarTy, EntryVF); bool NeedToShuffleReuses = !E->ReuseShuffleIndices.empty(); - if (E->State == TreeEntry::NeedToGather) { + if (E->isGather()) { if (allConstant(VL)) return 0; if (isa<InsertElementInst>(VL[0])) return InstructionCost::getInvalid(); return processBuildVector<ShuffleCostEstimator, InstructionCost>( - E, *TTI, VectorizedVals, *this, CheckedExtracts); + E, ScalarTy, *TTI, VectorizedVals, *this, CheckedExtracts); } InstructionCost CommonCost = 0; SmallVector<int> Mask; + bool IsReverseOrder = isReverseOrder(E->ReorderIndices); if (!E->ReorderIndices.empty() && - E->State != TreeEntry::PossibleStridedVectorize) { + (E->State != TreeEntry::StridedVectorize || !IsReverseOrder)) { SmallVector<int> NewMask; if (E->getOpcode() == Instruction::Store) { // For stores the order is actually a mask. @@ -7788,7 +9304,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, TTI->getShuffleCost(TTI::SK_PermuteSingleSrc, FinalVecTy, Mask); assert((E->State == TreeEntry::Vectorize || E->State == TreeEntry::ScatterVectorize || - E->State == TreeEntry::PossibleStridedVectorize) && + E->State == TreeEntry::StridedVectorize) && "Unhandled state"); assert(E->getOpcode() && ((allSameType(VL) && allSameBlock(VL)) || @@ -7807,23 +9323,11 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, UsedScalars.set(I); } auto GetCastContextHint = [&](Value *V) { - if (const TreeEntry *OpTE = getTreeEntry(V)) { - if (OpTE->State == TreeEntry::ScatterVectorize) - return TTI::CastContextHint::GatherScatter; - if (OpTE->State == TreeEntry::Vectorize && - OpTE->getOpcode() == Instruction::Load && !OpTE->isAltShuffle()) { - if (OpTE->ReorderIndices.empty()) - return TTI::CastContextHint::Normal; - SmallVector<int> Mask; - inversePermutation(OpTE->ReorderIndices, Mask); - if (ShuffleVectorInst::isReverseMask(Mask, Mask.size())) - return TTI::CastContextHint::Reversed; - } - } else { - InstructionsState SrcState = getSameOpcode(E->getOperand(0), *TLI); - if (SrcState.getOpcode() == Instruction::Load && !SrcState.isAltShuffle()) - return TTI::CastContextHint::GatherScatter; - } + if (const TreeEntry *OpTE = getTreeEntry(V)) + return getCastContextHint(*OpTE); + InstructionsState SrcState = getSameOpcode(E->getOperand(0), *TLI); + if (SrcState.getOpcode() == Instruction::Load && !SrcState.isAltShuffle()) + return TTI::CastContextHint::GatherScatter; return TTI::CastContextHint::None; }; auto GetCostDiff = @@ -7831,7 +9335,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, function_ref<InstructionCost(InstructionCost)> VectorCost) { // Calculate the cost of this instruction. InstructionCost ScalarCost = 0; - if (isa<CastInst, CmpInst, SelectInst, CallInst>(VL0)) { + if (isa<CastInst, CallInst>(VL0)) { // For some of the instructions no need to calculate cost for each // particular instruction, we can use the cost of the single // instruction x total number of scalar instructions. 
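
getEntryCost() now defers to getCastContextHint() (added in the previous hunk) to describe how a load feeding a cast is materialized, which in turn selects the TTI cast context. The sketch below is a standalone restatement of that decision table, with simplified stand-ins for the entry state and the reorder mask (the real code first inverts ReorderIndices into a mask before testing for a reversal).

#include <cstddef>
#include <vector>

enum class CastCtx { None, Normal, Reversed, GatherScatter };
enum class State { Vectorize, ScatterVectorize, StridedVectorize, Gather };

static bool isReverseMask(const std::vector<int> &Mask) {
  std::size_t E = Mask.size();
  for (std::size_t I = 0; I < E; ++I)
    if (Mask[I] != static_cast<int>(E - 1 - I))
      return false;
  return true;
}

static CastCtx castContextHint(State S, bool IsLoad, bool IsAltShuffle,
                               const std::vector<int> &ReorderMask) {
  // Masked gather / strided accesses use the GatherScatter context.
  if (S == State::ScatterVectorize || S == State::StridedVectorize)
    return CastCtx::GatherScatter;
  // Plain vectorized loads: Normal when in order, Reversed for a pure
  // reversal, otherwise no specific context.
  if (S == State::Vectorize && IsLoad && !IsAltShuffle) {
    if (ReorderMask.empty())
      return CastCtx::Normal;
    if (isReverseMask(ReorderMask))
      return CastCtx::Reversed;
  }
  return CastCtx::None;
}
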
@@ -7862,19 +9366,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, unsigned BWSz = DL->getTypeSizeInBits(ScalarTy); unsigned SrcBWSz = DL->getTypeSizeInBits(UserScalarTy); unsigned VecOpcode; - auto *SrcVecTy = - FixedVectorType::get(UserScalarTy, E->getVectorFactor()); + auto *UserVecTy = + getWidenedType(UserScalarTy, E->getVectorFactor()); if (BWSz > SrcBWSz) VecOpcode = Instruction::Trunc; else VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt; TTI::CastContextHint CCH = GetCastContextHint(VL0); - VecCost += TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, + VecCost += TTI->getCastInstrCost(VecOpcode, UserVecTy, VecTy, CCH, CostKind); - ScalarCost += - Sz * TTI->getCastInstrCost(VecOpcode, ScalarTy, UserScalarTy, - CCH, CostKind); } } } @@ -7885,78 +9386,13 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, // Calculate cost difference from vectorizing set of GEPs. // Negative value means vectorizing is profitable. auto GetGEPCostDiff = [=](ArrayRef<Value *> Ptrs, Value *BasePtr) { + assert((E->State == TreeEntry::Vectorize || + E->State == TreeEntry::StridedVectorize) && + "Entry state expected to be Vectorize or StridedVectorize here."); InstructionCost ScalarCost = 0; InstructionCost VecCost = 0; - // Here we differentiate two cases: (1) when Ptrs represent a regular - // vectorization tree node (as they are pointer arguments of scattered - // loads) or (2) when Ptrs are the arguments of loads or stores being - // vectorized as plane wide unit-stride load/store since all the - // loads/stores are known to be from/to adjacent locations. - assert(E->State == TreeEntry::Vectorize && - "Entry state expected to be Vectorize here."); - if (isa<LoadInst, StoreInst>(VL0)) { - // Case 2: estimate costs for pointer related costs when vectorizing to - // a wide load/store. - // Scalar cost is estimated as a set of pointers with known relationship - // between them. - // For vector code we will use BasePtr as argument for the wide load/store - // but we also need to account all the instructions which are going to - // stay in vectorized code due to uses outside of these scalar - // loads/stores. - ScalarCost = TTI->getPointersChainCost( - Ptrs, BasePtr, TTI::PointersChainInfo::getUnitStride(), ScalarTy, - CostKind); - - SmallVector<const Value *> PtrsRetainedInVecCode; - for (Value *V : Ptrs) { - if (V == BasePtr) { - PtrsRetainedInVecCode.push_back(V); - continue; - } - auto *Ptr = dyn_cast<GetElementPtrInst>(V); - // For simplicity assume Ptr to stay in vectorized code if it's not a - // GEP instruction. We don't care since it's cost considered free. - // TODO: We should check for any uses outside of vectorizable tree - // rather than just single use. - if (!Ptr || !Ptr->hasOneUse()) - PtrsRetainedInVecCode.push_back(V); - } - - if (PtrsRetainedInVecCode.size() == Ptrs.size()) { - // If all pointers stay in vectorized code then we don't have - // any savings on that. - LLVM_DEBUG(dumpTreeCosts(E, 0, ScalarCost, ScalarCost, - "Calculated GEPs cost for Tree")); - return InstructionCost{TTI::TCC_Free}; - } - VecCost = TTI->getPointersChainCost( - PtrsRetainedInVecCode, BasePtr, - TTI::PointersChainInfo::getKnownStride(), VecTy, CostKind); - } else { - // Case 1: Ptrs are the arguments of loads that we are going to transform - // into masked gather load intrinsic. - // All the scalar GEPs will be removed as a result of vectorization. 
- // For any external uses of some lanes extract element instructions will - // be generated (which cost is estimated separately). - TTI::PointersChainInfo PtrsInfo = - all_of(Ptrs, - [](const Value *V) { - auto *Ptr = dyn_cast<GetElementPtrInst>(V); - return Ptr && !Ptr->hasAllConstantIndices(); - }) - ? TTI::PointersChainInfo::getUnknownStride() - : TTI::PointersChainInfo::getKnownStride(); - - ScalarCost = TTI->getPointersChainCost(Ptrs, BasePtr, PtrsInfo, ScalarTy, - CostKind); - if (auto *BaseGEP = dyn_cast<GEPOperator>(BasePtr)) { - SmallVector<const Value *> Indices(BaseGEP->indices()); - VecCost = TTI->getGEPCost(BaseGEP->getSourceElementType(), - BaseGEP->getPointerOperand(), Indices, VecTy, - CostKind); - } - } - + std::tie(ScalarCost, VecCost) = getGEPCosts( + *TTI, Ptrs, BasePtr, E->getOpcode(), CostKind, OrigScalarTy, VecTy); LLVM_DEBUG(dumpTreeCosts(E, 0, VecCost, ScalarCost, "Calculated GEPs cost for Tree")); @@ -8003,13 +9439,12 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, NumElts = ATy->getNumElements(); else NumElts = AggregateTy->getStructNumElements(); - SrcVecTy = FixedVectorType::get(ScalarTy, NumElts); + SrcVecTy = getWidenedType(OrigScalarTy, NumElts); } if (I->hasOneUse()) { Instruction *Ext = I->user_back(); if ((isa<SExtInst>(Ext) || isa<ZExtInst>(Ext)) && - all_of(Ext->users(), - [](User *U) { return isa<GetElementPtrInst>(U); })) { + all_of(Ext->users(), IsaPred<GetElementPtrInst>)) { // Use getExtractWithExtendCost() to calculate the cost of // extractelement/ext pair. InstructionCost Cost = TTI->getExtractWithExtendCost( @@ -8037,11 +9472,11 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, unsigned NumOfParts = TTI->getNumberOfParts(SrcVecTy); SmallVector<int> InsertMask(NumElts, PoisonMaskElem); - unsigned OffsetBeg = *getInsertIndex(VL.front()); + unsigned OffsetBeg = *getElementIndex(VL.front()); unsigned OffsetEnd = OffsetBeg; InsertMask[OffsetBeg] = 0; for (auto [I, V] : enumerate(VL.drop_front())) { - unsigned Idx = *getInsertIndex(V); + unsigned Idx = *getElementIndex(V); if (OffsetBeg > Idx) OffsetBeg = Idx; else if (OffsetEnd < Idx) @@ -8082,7 +9517,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, SmallVector<int> PrevMask(InsertVecSz, PoisonMaskElem); Mask.swap(PrevMask); for (unsigned I = 0; I < NumScalars; ++I) { - unsigned InsertIdx = *getInsertIndex(VL[PrevMask[I]]); + unsigned InsertIdx = *getElementIndex(VL[PrevMask[I]]); DemandedElts.setBit(InsertIdx); IsIdentity &= InsertIdx - OffsetBeg == I; Mask[InsertIdx - OffsetBeg] = I; @@ -8098,7 +9533,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, // need to shift the vector. // Do not calculate the cost if the actual size is the register size and // we can merge this shuffle with the following SK_Select. 
- auto *InsertVecTy = FixedVectorType::get(ScalarTy, InsertVecSz); + auto *InsertVecTy = getWidenedType(ScalarTy, InsertVecSz); if (!IsIdentity) Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, InsertVecTy, Mask); @@ -8114,7 +9549,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, buildUseMask(NumElts, InsertMask, UseMask::UndefsAsMask)); if (!InMask.all() && NumScalars != NumElts && !IsWholeSubvector) { if (InsertVecSz != VecSz) { - auto *ActualVecTy = FixedVectorType::get(ScalarTy, VecSz); + auto *ActualVecTy = getWidenedType(ScalarTy, VecSz); Cost += TTI->getShuffleCost(TTI::SK_InsertSubvector, ActualVecTy, std::nullopt, CostKind, OffsetBeg - Offset, InsertVecTy); @@ -8148,7 +9583,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, case Instruction::BitCast: { auto SrcIt = MinBWs.find(getOperandEntry(E, 0)); Type *SrcScalarTy = VL0->getOperand(0)->getType(); - auto *SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size()); + auto *SrcVecTy = getWidenedType(SrcScalarTy, VL.size()); unsigned Opcode = ShuffleOrOp; unsigned VecOpcode = Opcode; if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() && @@ -8158,7 +9593,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, if (SrcIt != MinBWs.end()) { SrcBWSz = SrcIt->second.first; SrcScalarTy = IntegerType::get(F->getContext(), SrcBWSz); - SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size()); + SrcVecTy = getWidenedType(SrcScalarTy, VL.size()); } unsigned BWSz = DL->getTypeSizeInBits(ScalarTy); if (BWSz == SrcBWSz) { @@ -8168,16 +9603,17 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, } else if (It != MinBWs.end()) { assert(BWSz > SrcBWSz && "Invalid cast!"); VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt; + } else if (SrcIt != MinBWs.end()) { + assert(BWSz > SrcBWSz && "Invalid cast!"); + VecOpcode = + SrcIt->second.second ? Instruction::SExt : Instruction::ZExt; } + } else if (VecOpcode == Instruction::SIToFP && SrcIt != MinBWs.end() && + !SrcIt->second.second) { + VecOpcode = Instruction::UIToFP; } auto GetScalarCost = [&](unsigned Idx) -> InstructionCost { - // Do not count cost here if minimum bitwidth is in effect and it is just - // a bitcast (here it is just a noop). - if (VecOpcode != Opcode && VecOpcode == Instruction::BitCast) - return TTI::TCC_Free; - auto *VI = VL0->getOpcode() == Opcode - ? cast<Instruction>(UniqueValues[Idx]) - : nullptr; + auto *VI = cast<Instruction>(UniqueValues[Idx]); return TTI->getCastInstrCost(Opcode, VL0->getType(), VL0->getOperand(0)->getType(), TTI::getCastContextHint(VI), CostKind, VI); @@ -8220,29 +9656,61 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, ? 
CmpInst::BAD_FCMP_PREDICATE : CmpInst::BAD_ICMP_PREDICATE; - return TTI->getCmpSelInstrCost(E->getOpcode(), ScalarTy, - Builder.getInt1Ty(), CurrentPred, CostKind, - VI); + InstructionCost ScalarCost = TTI->getCmpSelInstrCost( + E->getOpcode(), OrigScalarTy, Builder.getInt1Ty(), CurrentPred, + CostKind, VI); + auto [MinMaxID, SelectOnly] = canConvertToMinOrMaxIntrinsic(VI); + if (MinMaxID != Intrinsic::not_intrinsic) { + Type *CanonicalType = OrigScalarTy; + if (CanonicalType->isPtrOrPtrVectorTy()) + CanonicalType = CanonicalType->getWithNewType(IntegerType::get( + CanonicalType->getContext(), + DL->getTypeSizeInBits(CanonicalType->getScalarType()))); + + IntrinsicCostAttributes CostAttrs(MinMaxID, CanonicalType, + {CanonicalType, CanonicalType}); + InstructionCost IntrinsicCost = + TTI->getIntrinsicInstrCost(CostAttrs, CostKind); + // If the selects are the only uses of the compares, they will be + // dead and we can adjust the cost by removing their cost. + if (SelectOnly) { + auto *CI = cast<CmpInst>(VI->getOperand(0)); + IntrinsicCost -= TTI->getCmpSelInstrCost( + CI->getOpcode(), OrigScalarTy, Builder.getInt1Ty(), + CI->getPredicate(), CostKind, CI); + } + ScalarCost = std::min(ScalarCost, IntrinsicCost); + } + + return ScalarCost; }; auto GetVectorCost = [&](InstructionCost CommonCost) { - auto *MaskTy = FixedVectorType::get(Builder.getInt1Ty(), VL.size()); + auto *MaskTy = getWidenedType(Builder.getInt1Ty(), VL.size()); InstructionCost VecCost = TTI->getCmpSelInstrCost( E->getOpcode(), VecTy, MaskTy, VecPred, CostKind, VL0); // Check if it is possible and profitable to use min/max for selects // in VL. // - auto IntrinsicAndUse = canConvertToMinOrMaxIntrinsic(VL); - if (IntrinsicAndUse.first != Intrinsic::not_intrinsic) { - IntrinsicCostAttributes CostAttrs(IntrinsicAndUse.first, VecTy, - {VecTy, VecTy}); + auto [MinMaxID, SelectOnly] = canConvertToMinOrMaxIntrinsic(VL); + if (MinMaxID != Intrinsic::not_intrinsic) { + Type *CanonicalType = VecTy; + if (CanonicalType->isPtrOrPtrVectorTy()) + CanonicalType = CanonicalType->getWithNewType(IntegerType::get( + CanonicalType->getContext(), + DL->getTypeSizeInBits(CanonicalType->getScalarType()))); + IntrinsicCostAttributes CostAttrs(MinMaxID, CanonicalType, + {CanonicalType, CanonicalType}); InstructionCost IntrinsicCost = TTI->getIntrinsicInstrCost(CostAttrs, CostKind); // If the selects are the only uses of the compares, they will be // dead and we can adjust the cost by removing their cost. 
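[Editorial sketch, not part of the diff.] Both the scalar and the vector cost lambdas in this hunk apply the same preference: if the compare/select pair can be rewritten as a min/max intrinsic, cost both forms, credit the compare when the select is its only user (it becomes dead), and keep the cheaper result. A minimal stand-alone sketch of that accounting is below; the cost constants are placeholders, whereas in the pass they come from TTI->getCmpSelInstrCost and TTI->getIntrinsicInstrCost.

#include <algorithm>
#include <cstdio>

int main() {
  int SelectCost  = 2;    // cost of the compare/select instruction being costed
  int MinMaxCost  = 1;    // cost of the equivalent smin/smax/umin/umax call
  int CmpCost     = 1;    // cost of the feeding compare
  bool SelectOnly = true; // compare has no users besides the select

  // If rewriting to min/max kills the compare, subtract its cost from the
  // intrinsic alternative before picking the cheaper form.
  if (SelectOnly)
    MinMaxCost -= CmpCost;
  int Cost = std::min(SelectCost, MinMaxCost);

  std::printf("chosen cost = %d\n", Cost);
  return 0;
}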
- if (IntrinsicAndUse.second) - IntrinsicCost -= TTI->getCmpSelInstrCost(Instruction::ICmp, VecTy, + if (SelectOnly) { + auto *CI = + cast<CmpInst>(cast<Instruction>(VL.front())->getOperand(0)); + IntrinsicCost -= TTI->getCmpSelInstrCost(CI->getOpcode(), VecTy, MaskTy, VecPred, CostKind); + } VecCost = std::min(VecCost, IntrinsicCost); } return VecCost + CommonCost; @@ -8275,15 +9743,25 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, TTI::OperandValueInfo Op2Info = TTI::getOperandInfo(VI->getOperand(OpIdx)); SmallVector<const Value *> Operands(VI->operand_values()); - return TTI->getArithmeticInstrCost(ShuffleOrOp, ScalarTy, CostKind, + return TTI->getArithmeticInstrCost(ShuffleOrOp, OrigScalarTy, CostKind, Op1Info, Op2Info, Operands, VI); }; auto GetVectorCost = [=](InstructionCost CommonCost) { + if (ShuffleOrOp == Instruction::And && It != MinBWs.end()) { + for (unsigned I : seq<unsigned>(0, E->getNumOperands())) { + ArrayRef<Value *> Ops = E->getOperand(I); + if (all_of(Ops, [&](Value *Op) { + auto *CI = dyn_cast<ConstantInt>(Op); + return CI && CI->getValue().countr_one() >= It->second.first; + })) + return CommonCost; + } + } unsigned OpIdx = isa<UnaryOperator>(VL0) ? 0 : 1; TTI::OperandValueInfo Op1Info = getOperandInfo(E->getOperand(0)); TTI::OperandValueInfo Op2Info = getOperandInfo(E->getOperand(OpIdx)); return TTI->getArithmeticInstrCost(ShuffleOrOp, VecTy, CostKind, Op1Info, - Op2Info) + + Op2Info, std::nullopt, nullptr, TLI) + CommonCost; }; return GetCostDiff(GetScalarCost, GetVectorCost); @@ -8294,9 +9772,9 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, case Instruction::Load: { auto GetScalarCost = [&](unsigned Idx) { auto *VI = cast<LoadInst>(UniqueValues[Idx]); - return TTI->getMemoryOpCost(Instruction::Load, ScalarTy, VI->getAlign(), - VI->getPointerAddressSpace(), CostKind, - TTI::OperandValueInfo(), VI); + return TTI->getMemoryOpCost(Instruction::Load, OrigScalarTy, + VI->getAlign(), VI->getPointerAddressSpace(), + CostKind, TTI::OperandValueInfo(), VI); }; auto *LI0 = cast<LoadInst>(VL0); auto GetVectorCost = [&](InstructionCost CommonCost) { @@ -8305,14 +9783,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, VecLdCost = TTI->getMemoryOpCost( Instruction::Load, VecTy, LI0->getAlign(), LI0->getPointerAddressSpace(), CostKind, TTI::OperandValueInfo()); + } else if (E->State == TreeEntry::StridedVectorize) { + Align CommonAlignment = + computeCommonAlignment<LoadInst>(UniqueValues.getArrayRef()); + VecLdCost = TTI->getStridedMemoryOpCost( + Instruction::Load, VecTy, LI0->getPointerOperand(), + /*VariableMask=*/false, CommonAlignment, CostKind); } else { - assert((E->State == TreeEntry::ScatterVectorize || - E->State == TreeEntry::PossibleStridedVectorize) && - "Unknown EntryState"); - Align CommonAlignment = LI0->getAlign(); - for (Value *V : UniqueValues) - CommonAlignment = - std::min(CommonAlignment, cast<LoadInst>(V)->getAlign()); + assert(E->State == TreeEntry::ScatterVectorize && "Unknown EntryState"); + Align CommonAlignment = + computeCommonAlignment<LoadInst>(UniqueValues.getArrayRef()); VecLdCost = TTI->getGatherScatterOpCost( Instruction::Load, VecTy, LI0->getPointerOperand(), /*VariableMask=*/false, CommonAlignment, CostKind); @@ -8323,8 +9803,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, InstructionCost Cost = GetCostDiff(GetScalarCost, GetVectorCost); // If this node generates masked gather load then it is not a terminal node. 
// Hence address operand cost is estimated separately. - if (E->State == TreeEntry::ScatterVectorize || - E->State == TreeEntry::PossibleStridedVectorize) + if (E->State == TreeEntry::ScatterVectorize) return Cost; // Estimate cost of GEPs since this tree node is a terminator. @@ -8338,19 +9817,30 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, auto GetScalarCost = [=](unsigned Idx) { auto *VI = cast<StoreInst>(VL[Idx]); TTI::OperandValueInfo OpInfo = TTI::getOperandInfo(VI->getValueOperand()); - return TTI->getMemoryOpCost(Instruction::Store, ScalarTy, VI->getAlign(), - VI->getPointerAddressSpace(), CostKind, - OpInfo, VI); + return TTI->getMemoryOpCost(Instruction::Store, OrigScalarTy, + VI->getAlign(), VI->getPointerAddressSpace(), + CostKind, OpInfo, VI); }; auto *BaseSI = cast<StoreInst>(IsReorder ? VL[E->ReorderIndices.front()] : VL0); auto GetVectorCost = [=](InstructionCost CommonCost) { // We know that we can merge the stores. Calculate the cost. - TTI::OperandValueInfo OpInfo = getOperandInfo(E->getOperand(0)); - return TTI->getMemoryOpCost(Instruction::Store, VecTy, BaseSI->getAlign(), - BaseSI->getPointerAddressSpace(), CostKind, - OpInfo) + - CommonCost; + InstructionCost VecStCost; + if (E->State == TreeEntry::StridedVectorize) { + Align CommonAlignment = + computeCommonAlignment<StoreInst>(UniqueValues.getArrayRef()); + VecStCost = TTI->getStridedMemoryOpCost( + Instruction::Store, VecTy, BaseSI->getPointerOperand(), + /*VariableMask=*/false, CommonAlignment, CostKind); + } else { + assert(E->State == TreeEntry::Vectorize && + "Expected either strided or consecutive stores."); + TTI::OperandValueInfo OpInfo = getOperandInfo(E->getOperand(0)); + VecStCost = TTI->getMemoryOpCost( + Instruction::Store, VecTy, BaseSI->getAlign(), + BaseSI->getPointerAddressSpace(), CostKind, OpInfo); + } + return VecStCost + CommonCost; }; SmallVector<Value *> PointerOps(VL.size()); for (auto [I, V] : enumerate(VL)) { @@ -8375,7 +9865,11 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, }; auto GetVectorCost = [=](InstructionCost CommonCost) { auto *CI = cast<CallInst>(VL0); - auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI); + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); + SmallVector<Type *> ArgTys = + buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(), + It != MinBWs.end() ? It->second.first : 0); + auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys); return std::min(VecCallCosts.first, VecCallCosts.second) + CommonCost; }; return GetCostDiff(GetScalarCost, GetVectorCost); @@ -8410,11 +9904,9 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, (void)E; return TTI->getInstructionCost(VI, CostKind); }; - // FIXME: Workaround for syntax error reported by MSVC buildbots. - TargetTransformInfo &TTIRef = *TTI; // Need to clear CommonCost since the final shuffle cost is included into // vector cost. - auto GetVectorCost = [&](InstructionCost) { + auto GetVectorCost = [&, &TTIRef = *TTI](InstructionCost) { // VecCost is equal to sum of the cost of creating 2 vectors // and the cost of creating shuffle. 
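[Editorial sketch, not part of the diff.] Every case in this part of getEntryCost ends with GetCostDiff(GetScalarCost, GetVectorCost): a per-lane scalar cost callback and a vector cost callback that folds in a shared CommonCost. The sketch below assumes the natural semantics (vector cost minus the sum of scalar costs, negative meaning vectorization is profitable); the callbacks and numbers are illustrative only.

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> ScalarCosts = {1, 1, 1, 1}; // one entry per unique scalar
  auto GetScalarCost = [&](unsigned Idx) { return ScalarCosts[Idx]; };
  auto GetVectorCost = [](int CommonCost) { return 2 + CommonCost; };

  int CommonCost = 1; // e.g. a final reshuffle shared by all code paths
  int ScalarSum = 0;
  for (unsigned I = 0; I < ScalarCosts.size(); ++I)
    ScalarSum += GetScalarCost(I);
  int Diff = GetVectorCost(CommonCost) - ScalarSum; // negative => profitable

  std::printf("cost diff = %d\n", Diff);
  return 0;
}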
InstructionCost VecCost = 0; @@ -8431,7 +9923,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, VecCost += TTIRef.getArithmeticInstrCost(E->getAltOpcode(), VecTy, CostKind); } else if (auto *CI0 = dyn_cast<CmpInst>(VL0)) { - auto *MaskTy = FixedVectorType::get(Builder.getInt1Ty(), VL.size()); + auto *MaskTy = getWidenedType(Builder.getInt1Ty(), VL.size()); VecCost = TTIRef.getCmpSelInstrCost(E->getOpcode(), VecTy, MaskTy, CI0->getPredicate(), CostKind, VL0); VecCost += TTIRef.getCmpSelInstrCost( @@ -8439,14 +9931,35 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, cast<CmpInst>(E->getAltOp())->getPredicate(), CostKind, E->getAltOp()); } else { - Type *Src0SclTy = E->getMainOp()->getOperand(0)->getType(); - Type *Src1SclTy = E->getAltOp()->getOperand(0)->getType(); - auto *Src0Ty = FixedVectorType::get(Src0SclTy, VL.size()); - auto *Src1Ty = FixedVectorType::get(Src1SclTy, VL.size()); - VecCost = TTIRef.getCastInstrCost(E->getOpcode(), VecTy, Src0Ty, + Type *SrcSclTy = E->getMainOp()->getOperand(0)->getType(); + auto *SrcTy = getWidenedType(SrcSclTy, VL.size()); + if (SrcSclTy->isIntegerTy() && ScalarTy->isIntegerTy()) { + auto SrcIt = MinBWs.find(getOperandEntry(E, 0)); + unsigned BWSz = DL->getTypeSizeInBits(ScalarTy); + unsigned SrcBWSz = + DL->getTypeSizeInBits(E->getMainOp()->getOperand(0)->getType()); + if (SrcIt != MinBWs.end()) { + SrcBWSz = SrcIt->second.first; + SrcSclTy = IntegerType::get(SrcSclTy->getContext(), SrcBWSz); + SrcTy = getWidenedType(SrcSclTy, VL.size()); + } + if (BWSz <= SrcBWSz) { + if (BWSz < SrcBWSz) + VecCost = + TTIRef.getCastInstrCost(Instruction::Trunc, VecTy, SrcTy, + TTI::CastContextHint::None, CostKind); + LLVM_DEBUG({ + dbgs() + << "SLP: alternate extension, which should be truncated.\n"; + E->dump(); + }); + return VecCost; + } + } + VecCost = TTIRef.getCastInstrCost(E->getOpcode(), VecTy, SrcTy, TTI::CastContextHint::None, CostKind); VecCost += - TTIRef.getCastInstrCost(E->getAltOpcode(), VecTy, Src1Ty, + TTIRef.getCastInstrCost(E->getAltOpcode(), VecTy, SrcTy, TTI::CastContextHint::None, CostKind); } SmallVector<int> Mask; @@ -8464,11 +9977,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals, // order. unsigned Opcode0 = E->getOpcode(); unsigned Opcode1 = E->getAltOpcode(); - // The opcode mask selects between the two opcodes. - SmallBitVector OpcodeMask(E->Scalars.size(), false); - for (unsigned Lane : seq<unsigned>(0, E->Scalars.size())) - if (cast<Instruction>(E->Scalars[Lane])->getOpcode() == Opcode1) - OpcodeMask.set(Lane); + SmallBitVector OpcodeMask(getAltInstrMask(E->Scalars, Opcode0, Opcode1)); // If this pattern is supported by the target then we consider the // order. 
if (TTIRef.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask)) { @@ -8492,19 +10001,16 @@ bool BoUpSLP::isFullyVectorizableTinyTree(bool ForReduction) const { auto &&AreVectorizableGathers = [this](const TreeEntry *TE, unsigned Limit) { SmallVector<int> Mask; - return TE->State == TreeEntry::NeedToGather && + return TE->isGather() && !any_of(TE->Scalars, [this](Value *V) { return EphValues.contains(V); }) && (allConstant(TE->Scalars) || isSplat(TE->Scalars) || TE->Scalars.size() < Limit || ((TE->getOpcode() == Instruction::ExtractElement || - all_of(TE->Scalars, - [](Value *V) { - return isa<ExtractElementInst, UndefValue>(V); - })) && + all_of(TE->Scalars, IsaPred<ExtractElementInst, UndefValue>)) && isFixedVectorShuffle(TE->Scalars, Mask)) || - (TE->State == TreeEntry::NeedToGather && - TE->getOpcode() == Instruction::Load && !TE->isAltShuffle())); + (TE->isGather() && TE->getOpcode() == Instruction::Load && + !TE->isAltShuffle())); }; // We only handle trees of heights 1 and 2. @@ -8530,10 +10036,10 @@ bool BoUpSLP::isFullyVectorizableTinyTree(bool ForReduction) const { return true; // Gathering cost would be too much for tiny trees. - if (VectorizableTree[0]->State == TreeEntry::NeedToGather || - (VectorizableTree[1]->State == TreeEntry::NeedToGather && + if (VectorizableTree[0]->isGather() || + (VectorizableTree[1]->isGather() && VectorizableTree[0]->State != TreeEntry::ScatterVectorize && - VectorizableTree[0]->State != TreeEntry::PossibleStridedVectorize)) + VectorizableTree[0]->State != TreeEntry::StridedVectorize)) return false; return true; @@ -8589,11 +10095,11 @@ bool BoUpSLP::isLoadCombineReductionCandidate(RecurKind RdxKind) const { /* MatchOr */ false); } -bool BoUpSLP::isLoadCombineCandidate() const { +bool BoUpSLP::isLoadCombineCandidate(ArrayRef<Value *> Stores) const { // Peek through a final sequence of stores and check if all operations are // likely to be load-combined. - unsigned NumElts = VectorizableTree[0]->Scalars.size(); - for (Value *Scalar : VectorizableTree[0]->Scalars) { + unsigned NumElts = Stores.size(); + for (Value *Scalar : Stores) { Value *X; if (!match(Scalar, m_Store(m_Value(X), m_Value())) || !isLoadCombineCandidateImpl(X, NumElts, TTI, /* MatchOr */ true)) @@ -8606,7 +10112,7 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const { // No need to vectorize inserts of gathered values. 
if (VectorizableTree.size() == 2 && isa<InsertElementInst>(VectorizableTree[0]->Scalars[0]) && - VectorizableTree[1]->State == TreeEntry::NeedToGather && + VectorizableTree[1]->isGather() && (VectorizableTree[1]->getVectorFactor() <= 2 || !(isSplat(VectorizableTree[1]->Scalars) || allConstant(VectorizableTree[1]->Scalars)))) @@ -8620,11 +10126,9 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const { if (!ForReduction && !SLPCostThreshold.getNumOccurrences() && !VectorizableTree.empty() && all_of(VectorizableTree, [&](const std::unique_ptr<TreeEntry> &TE) { - return (TE->State == TreeEntry::NeedToGather && + return (TE->isGather() && TE->getOpcode() != Instruction::ExtractElement && - count_if(TE->Scalars, - [](Value *V) { return isa<ExtractElementInst>(V); }) <= - Limit) || + count_if(TE->Scalars, IsaPred<ExtractElementInst>) <= Limit) || TE->getOpcode() == Instruction::PHI; })) return true; @@ -8639,6 +10143,25 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const { if (isFullyVectorizableTinyTree(ForReduction)) return false; + // Check if any of the gather node forms an insertelement buildvector + // somewhere. + bool IsAllowedSingleBVNode = + VectorizableTree.size() > 1 || + (VectorizableTree.size() == 1 && VectorizableTree.front()->getOpcode() && + !VectorizableTree.front()->isAltShuffle() && + VectorizableTree.front()->getOpcode() != Instruction::PHI && + VectorizableTree.front()->getOpcode() != Instruction::GetElementPtr && + allSameBlock(VectorizableTree.front()->Scalars)); + if (any_of(VectorizableTree, [&](const std::unique_ptr<TreeEntry> &TE) { + return TE->isGather() && all_of(TE->Scalars, [&](Value *V) { + return isa<ExtractElementInst, UndefValue>(V) || + (IsAllowedSingleBVNode && + !V->hasNUsesOrMore(UsesLimit) && + any_of(V->users(), IsaPred<InsertElementInst>)); + }); + })) + return false; + assert(VectorizableTree.empty() ? 
ExternalUses.empty() : true && "We shouldn't have any external users"); @@ -8754,7 +10277,7 @@ InstructionCost BoUpSLP::getSpillCost() const { auto *ScalarTy = II->getType(); if (auto *VectorTy = dyn_cast<FixedVectorType>(ScalarTy)) ScalarTy = VectorTy->getElementType(); - V.push_back(FixedVectorType::get(ScalarTy, BundleWidth)); + V.push_back(getWidenedType(ScalarTy, BundleWidth)); } Cost += NumCalls * TTI->getCostOfKeepingLiveOverCall(V); } @@ -8775,8 +10298,8 @@ static bool isFirstInsertElement(const InsertElementInst *IE1, const auto *I2 = IE2; const InsertElementInst *PrevI1; const InsertElementInst *PrevI2; - unsigned Idx1 = *getInsertIndex(IE1); - unsigned Idx2 = *getInsertIndex(IE2); + unsigned Idx1 = *getElementIndex(IE1); + unsigned Idx2 = *getElementIndex(IE2); do { if (I2 == IE1) return true; @@ -8785,10 +10308,10 @@ static bool isFirstInsertElement(const InsertElementInst *IE1, PrevI1 = I1; PrevI2 = I2; if (I1 && (I1 == IE1 || I1->hasOneUse()) && - getInsertIndex(I1).value_or(Idx2) != Idx2) + getElementIndex(I1).value_or(Idx2) != Idx2) I1 = dyn_cast<InsertElementInst>(I1->getOperand(0)); if (I2 && ((I2 == IE2 || I2->hasOneUse())) && - getInsertIndex(I2).value_or(Idx1) != Idx1) + getElementIndex(I2).value_or(Idx1) != Idx1) I2 = dyn_cast<InsertElementInst>(I2->getOperand(0)); } while ((I1 && PrevI1 != I1) || (I2 && PrevI2 != I2)); llvm_unreachable("Two different buildvectors not expected."); @@ -8929,7 +10452,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { SmallPtrSet<Value *, 4> CheckedExtracts; for (unsigned I = 0, E = VectorizableTree.size(); I < E; ++I) { TreeEntry &TE = *VectorizableTree[I]; - if (TE.State == TreeEntry::NeedToGather) { + if (TE.isGather()) { if (const TreeEntry *E = getTreeEntry(TE.getMainOp()); E && E->getVectorFactor() == TE.getVectorFactor() && E->isSame(TE.Scalars)) { @@ -8955,7 +10478,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { SmallVector<std::pair<Value *, const TreeEntry *>> FirstUsers; SmallVector<APInt> DemandedElts; SmallDenseSet<Value *, 4> UsedInserts; - DenseSet<Value *> VectorCasts; + DenseSet<std::pair<const TreeEntry *, Type *>> VectorCasts; + std::optional<DenseMap<Value *, unsigned>> ValueToExtUses; for (ExternalUser &EU : ExternalUses) { // We only add extract cost once for the same scalar. if (!isa_and_nonnull<InsertElementInst>(EU.User) && @@ -8974,11 +10498,12 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { // If found user is an insertelement, do not calculate extract cost but try // to detect it as a final shuffled/identity match. - if (auto *VU = dyn_cast_or_null<InsertElementInst>(EU.User)) { + if (auto *VU = dyn_cast_or_null<InsertElementInst>(EU.User); + VU && VU->getOperand(1) == EU.Scalar) { if (auto *FTy = dyn_cast<FixedVectorType>(VU->getType())) { if (!UsedInserts.insert(VU).second) continue; - std::optional<unsigned> InsertIdx = getInsertIndex(VU); + std::optional<unsigned> InsertIdx = getElementIndex(VU); if (InsertIdx) { const TreeEntry *ScalarTE = getTreeEntry(EU.Scalar); auto *It = find_if( @@ -9004,14 +10529,14 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { while (auto *IEBase = dyn_cast<InsertElementInst>(Base)) { if (IEBase != EU.User && (!IEBase->hasOneUse() || - getInsertIndex(IEBase).value_or(*InsertIdx) == *InsertIdx)) + getElementIndex(IEBase).value_or(*InsertIdx) == *InsertIdx)) break; // Build the mask for the vectorized insertelement instructions. 
if (const TreeEntry *E = getTreeEntry(IEBase)) { VU = IEBase; do { IEBase = cast<InsertElementInst>(Base); - int Idx = *getInsertIndex(IEBase); + int Idx = *getElementIndex(IEBase); assert(Mask[Idx] == PoisonMaskElem && "InsertElementInstruction used already."); Mask[Idx] = Idx; @@ -9025,11 +10550,14 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { DemandedElts.push_back(APInt::getZero(FTy->getNumElements())); VecId = FirstUsers.size() - 1; auto It = MinBWs.find(ScalarTE); - if (It != MinBWs.end() && VectorCasts.insert(EU.Scalar).second) { - unsigned BWSz = It->second.second; - unsigned SrcBWSz = DL->getTypeSizeInBits(FTy->getElementType()); + if (It != MinBWs.end() && + VectorCasts + .insert(std::make_pair(ScalarTE, FTy->getElementType())) + .second) { + unsigned BWSz = It->second.first; + unsigned DstBWSz = DL->getTypeSizeInBits(FTy->getElementType()); unsigned VecOpcode; - if (BWSz < SrcBWSz) + if (DstBWSz < BWSz) VecOpcode = Instruction::Trunc; else VecOpcode = @@ -9037,9 +10565,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; InstructionCost C = TTI->getCastInstrCost( VecOpcode, FTy, - FixedVectorType::get( - IntegerType::get(FTy->getContext(), It->second.first), - FTy->getNumElements()), + getWidenedType(IntegerType::get(FTy->getContext(), BWSz), + FTy->getNumElements()), TTI::CastContextHint::None, CostKind); LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C << " for extending externally used vector with " @@ -9061,18 +10588,46 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { } } } + // Leave the GEPs as is, they are free in most cases and better to keep them + // as GEPs. + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + if (auto *GEP = dyn_cast<GetElementPtrInst>(EU.Scalar)) { + if (!ValueToExtUses) { + ValueToExtUses.emplace(); + for_each(enumerate(ExternalUses), [&](const auto &P) { + ValueToExtUses->try_emplace(P.value().Scalar, P.index()); + }); + } + // Can use original GEP, if no operands vectorized or they are marked as + // externally used already. + bool CanBeUsedAsGEP = all_of(GEP->operands(), [&](Value *V) { + if (!getTreeEntry(V)) + return true; + auto It = ValueToExtUses->find(V); + if (It != ValueToExtUses->end()) { + // Replace all uses to avoid compiler crash. + ExternalUses[It->second].User = nullptr; + return true; + } + return false; + }); + if (CanBeUsedAsGEP) { + ExtractCost += TTI->getInstructionCost(GEP, CostKind); + ExternalUsesAsGEPs.insert(EU.Scalar); + continue; + } + } // If we plan to rewrite the tree in a smaller type, we will need to sign // extend the extracted value back to the original type. Here, we account // for the extract and the added cost of the sign extend if needed. - auto *VecTy = FixedVectorType::get(EU.Scalar->getType(), BundleWidth); - TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + auto *VecTy = getWidenedType(EU.Scalar->getType(), BundleWidth); auto It = MinBWs.find(getTreeEntry(EU.Scalar)); if (It != MinBWs.end()) { auto *MinTy = IntegerType::get(F->getContext(), It->second.first); unsigned Extend = It->second.second ? 
Instruction::SExt : Instruction::ZExt; - VecTy = FixedVectorType::get(MinTy, BundleWidth); + VecTy = getWidenedType(MinTy, BundleWidth); ExtractCost += TTI->getExtractWithExtendCost(Extend, EU.Scalar->getType(), VecTy, EU.Lane); } else { @@ -9082,17 +10637,22 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { } // Add reduced value cost, if resized. if (!VectorizedVals.empty()) { - auto BWIt = MinBWs.find(VectorizableTree.front().get()); + const TreeEntry &Root = *VectorizableTree.front(); + auto BWIt = MinBWs.find(&Root); if (BWIt != MinBWs.end()) { - Type *DstTy = VectorizableTree.front()->Scalars.front()->getType(); + Type *DstTy = Root.Scalars.front()->getType(); unsigned OriginalSz = DL->getTypeSizeInBits(DstTy); - unsigned Opcode = Instruction::Trunc; - if (OriginalSz < BWIt->second.first) - Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt; - Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first); - Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy, - TTI::CastContextHint::None, - TTI::TCK_RecipThroughput); + unsigned SrcSz = + ReductionBitWidth == 0 ? BWIt->second.first : ReductionBitWidth; + if (OriginalSz != SrcSz) { + unsigned Opcode = Instruction::Trunc; + if (OriginalSz > SrcSz) + Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt; + Type *SrcTy = IntegerType::get(DstTy->getContext(), SrcSz); + Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy, + TTI::CastContextHint::None, + TTI::TCK_RecipThroughput); + } } } @@ -9109,9 +10669,9 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { SmallVector<int> OrigMask(VecVF, PoisonMaskElem); std::copy(Mask.begin(), std::next(Mask.begin(), std::min(VF, VecVF)), OrigMask.begin()); - C = TTI->getShuffleCost( - TTI::SK_PermuteSingleSrc, - FixedVectorType::get(TE->getMainOp()->getType(), VecVF), OrigMask); + C = TTI->getShuffleCost(TTI::SK_PermuteSingleSrc, + getWidenedType(TE->getMainOp()->getType(), VecVF), + OrigMask); LLVM_DEBUG( dbgs() << "SLP: Adding cost " << C << " for final shuffle of insertelement external users.\n"; @@ -9133,8 +10693,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { if (TEs.size() == 1) { if (VF == 0) VF = TEs.front()->getVectorFactor(); - auto *FTy = - FixedVectorType::get(TEs.back()->Scalars.front()->getType(), VF); + auto *FTy = getWidenedType(TEs.back()->Scalars.front()->getType(), VF); if (!ShuffleVectorInst::isIdentityMask(Mask, VF) && !all_of(enumerate(Mask), [=](const auto &Data) { return Data.value() == PoisonMaskElem || @@ -9158,8 +10717,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { else VF = Mask.size(); } - auto *FTy = - FixedVectorType::get(TEs.back()->Scalars.front()->getType(), VF); + auto *FTy = getWidenedType(TEs.back()->Scalars.front()->getType(), VF); InstructionCost C = ::getShuffleCost(*TTI, TTI::SK_PermuteTwoSrc, FTy, Mask); LLVM_DEBUG(dbgs() << "SLP: Adding cost " << C @@ -9182,6 +10740,44 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) { Cost -= InsertCost; } + // Add the cost for reduced value resize (if required). 
+ if (ReductionBitWidth != 0) { + assert(UserIgnoreList && "Expected reduction tree."); + const TreeEntry &E = *VectorizableTree.front(); + auto It = MinBWs.find(&E); + if (It != MinBWs.end() && It->second.first != ReductionBitWidth) { + unsigned SrcSize = It->second.first; + unsigned DstSize = ReductionBitWidth; + unsigned Opcode = Instruction::Trunc; + if (SrcSize < DstSize) + Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt; + auto *SrcVecTy = + getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor()); + auto *DstVecTy = + getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor()); + TTI::CastContextHint CCH = getCastContextHint(E); + InstructionCost CastCost; + switch (E.getOpcode()) { + case Instruction::SExt: + case Instruction::ZExt: + case Instruction::Trunc: { + const TreeEntry *OpTE = getOperandEntry(&E, 0); + CCH = getCastContextHint(*OpTE); + break; + } + default: + break; + } + CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH, + TTI::TCK_RecipThroughput); + Cost += CastCost; + LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost + << " for final resize for reduction from " << SrcVecTy + << " to " << DstVecTy << "\n"; + dbgs() << "SLP: Current total cost = " << Cost << "\n"); + } + } + #ifndef NDEBUG SmallString<256> Str; { @@ -9235,36 +10831,20 @@ BoUpSLP::tryToGatherSingleRegisterExtractElements( VectorOpToIdx[EI->getVectorOperand()].push_back(I); } // Sort the vector operands by the maximum number of uses in extractelements. - MapVector<unsigned, SmallVector<Value *>> VFToVector; - for (const auto &Data : VectorOpToIdx) - VFToVector[cast<FixedVectorType>(Data.first->getType())->getNumElements()] - .push_back(Data.first); - for (auto &Data : VFToVector) { - stable_sort(Data.second, [&VectorOpToIdx](Value *V1, Value *V2) { - return VectorOpToIdx.find(V1)->second.size() > - VectorOpToIdx.find(V2)->second.size(); - }); - } - // Find the best pair of the vectors with the same number of elements or a - // single vector. + SmallVector<std::pair<Value *, SmallVector<int>>> Vectors = + VectorOpToIdx.takeVector(); + stable_sort(Vectors, [](const auto &P1, const auto &P2) { + return P1.second.size() > P2.second.size(); + }); + // Find the best pair of the vectors or a single vector. 
const int UndefSz = UndefVectorExtracts.size(); unsigned SingleMax = 0; - Value *SingleVec = nullptr; unsigned PairMax = 0; - std::pair<Value *, Value *> PairVec(nullptr, nullptr); - for (auto &Data : VFToVector) { - Value *V1 = Data.second.front(); - if (SingleMax < VectorOpToIdx[V1].size() + UndefSz) { - SingleMax = VectorOpToIdx[V1].size() + UndefSz; - SingleVec = V1; - } - Value *V2 = nullptr; - if (Data.second.size() > 1) - V2 = *std::next(Data.second.begin()); - if (V2 && PairMax < VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() + - UndefSz) { - PairMax = VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() + UndefSz; - PairVec = std::make_pair(V1, V2); + if (!Vectors.empty()) { + SingleMax = Vectors.front().second.size() + UndefSz; + if (Vectors.size() > 1) { + auto *ItNext = std::next(Vectors.begin()); + PairMax = SingleMax + ItNext->second.size(); } } if (SingleMax == 0 && PairMax == 0 && UndefSz == 0) @@ -9275,11 +10855,11 @@ BoUpSLP::tryToGatherSingleRegisterExtractElements( SmallVector<Value *> GatheredExtracts( VL.size(), PoisonValue::get(VL.front()->getType())); if (SingleMax >= PairMax && SingleMax) { - for (int Idx : VectorOpToIdx[SingleVec]) + for (int Idx : Vectors.front().second) std::swap(GatheredExtracts[Idx], VL[Idx]); - } else { - for (Value *V : {PairVec.first, PairVec.second}) - for (int Idx : VectorOpToIdx[V]) + } else if (!Vectors.empty()) { + for (unsigned Idx : {0, 1}) + for (int Idx : Vectors[Idx].second) std::swap(GatheredExtracts[Idx], VL[Idx]); } // Add extracts from undefs too. @@ -9324,12 +10904,12 @@ BoUpSLP::tryToGatherExtractElements(SmallVectorImpl<Value *> &VL, assert(NumParts > 0 && "NumParts expected be greater than or equal to 1."); SmallVector<std::optional<TTI::ShuffleKind>> ShufflesRes(NumParts); Mask.assign(VL.size(), PoisonMaskElem); - unsigned SliceSize = VL.size() / NumParts; - for (unsigned Part = 0; Part < NumParts; ++Part) { + unsigned SliceSize = getPartNumElems(VL.size(), NumParts); + for (unsigned Part : seq<unsigned>(NumParts)) { // Scan list of gathered scalars for extractelements that can be represented // as shuffles. - MutableArrayRef<Value *> SubVL = - MutableArrayRef(VL).slice(Part * SliceSize, SliceSize); + MutableArrayRef<Value *> SubVL = MutableArrayRef(VL).slice( + Part * SliceSize, getNumElems(VL.size(), SliceSize, Part)); SmallVector<int> SubMask; std::optional<TTI::ShuffleKind> Res = tryToGatherSingleRegisterExtractElements(SubVL, SubMask); @@ -9346,7 +10926,7 @@ BoUpSLP::tryToGatherExtractElements(SmallVectorImpl<Value *> &VL, std::optional<TargetTransformInfo::ShuffleKind> BoUpSLP::isGatherShuffledSingleRegisterEntry( const TreeEntry *TE, ArrayRef<Value *> VL, MutableArrayRef<int> Mask, - SmallVectorImpl<const TreeEntry *> &Entries, unsigned Part) { + SmallVectorImpl<const TreeEntry *> &Entries, unsigned Part, bool ForOrder) { Entries.clear(); // TODO: currently checking only for Scalars in the tree entry, need to count // reused elements too for better cost estimation. 
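[Editorial sketch, not part of the diff.] The gather/extract loops above now slice the scalar list into NumParts register-sized chunks through the getPartNumElems/getNumElems helpers this patch introduces; their definitions are not in this hunk, so the sketch below assumes the natural semantics (ceil-divide for the per-part slice size, clamping for a short tail slice) and defines hypothetical local stand-ins to show how the last, possibly partial, part is handled.

#include <algorithm>
#include <cstdio>

// Assumed behaviour: elements per part, rounded up.
static unsigned partNumElems(unsigned Size, unsigned NumParts) {
  return (Size + NumParts - 1) / NumParts;
}
// Assumed behaviour: elements actually present in part `Part`.
static unsigned numElems(unsigned Size, unsigned SliceSize, unsigned Part) {
  unsigned Begin = Part * SliceSize;
  return Begin >= Size ? 0 : std::min(SliceSize, Size - Begin);
}

int main() {
  unsigned Size = 10, NumParts = 3;
  unsigned SliceSize = partNumElems(Size, NumParts);
  for (unsigned Part = 0; Part < NumParts; ++Part)
    std::printf("part %u: offset %u, %u elems\n", Part, Part * SliceSize,
                numElems(Size, SliceSize, Part));
  return 0;
}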
@@ -9361,6 +10941,8 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry( } else { TEInsertBlock = TEInsertPt->getParent(); } + if (!DT->isReachableFromEntry(TEInsertBlock)) + return std::nullopt; auto *NodeUI = DT->getNode(TEInsertBlock); assert(NodeUI && "Should only process reachable instructions"); SmallPtrSet<Value *, 4> GatheredScalars(VL.begin(), VL.end()); @@ -9443,14 +11025,24 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry( VToTEs.insert(TEPtr); } if (const TreeEntry *VTE = getTreeEntry(V)) { + if (ForOrder) { + if (VTE->State != TreeEntry::Vectorize) { + auto It = MultiNodeScalars.find(V); + if (It == MultiNodeScalars.end()) + continue; + VTE = *It->getSecond().begin(); + // Iterate through all vectorized nodes. + auto *MIt = find_if(It->getSecond(), [](const TreeEntry *MTE) { + return MTE->State == TreeEntry::Vectorize; + }); + if (MIt == It->getSecond().end()) + continue; + VTE = *MIt; + } + } Instruction &LastBundleInst = getLastInstructionInBundle(VTE); if (&LastBundleInst == TEInsertPt || !CheckOrdering(&LastBundleInst)) continue; - auto It = MinBWs.find(VTE); - // If vectorize node is demoted - do not match. - if (It != MinBWs.end() && - It->second.first != DL->getTypeSizeInBits(V->getType())) - continue; VToTEs.insert(VTE); } if (VToTEs.empty()) @@ -9565,18 +11157,17 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry( // No 2 source vectors with the same vector factor - just choose 2 with max // index. if (Entries.empty()) { - Entries.push_back( - *std::max_element(UsedTEs.front().begin(), UsedTEs.front().end(), - [](const TreeEntry *TE1, const TreeEntry *TE2) { - return TE1->Idx < TE2->Idx; - })); + Entries.push_back(*llvm::max_element( + UsedTEs.front(), [](const TreeEntry *TE1, const TreeEntry *TE2) { + return TE1->Idx < TE2->Idx; + })); Entries.push_back(SecondEntries.front()); VF = std::max(Entries.front()->getVectorFactor(), Entries.back()->getVectorFactor()); } } - bool IsSplatOrUndefs = isSplat(VL) || all_of(VL, UndefValue::classof); + bool IsSplatOrUndefs = isSplat(VL) || all_of(VL, IsaPred<UndefValue>); // Checks if the 2 PHIs are compatible in terms of high possibility to be // vectorized. auto AreCompatiblePHIs = [&](Value *V, Value *V1) { @@ -9676,8 +11267,12 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry( // scalar in the list. for (const std::pair<unsigned, int> &Pair : EntryLanes) { unsigned Idx = Part * VL.size() + Pair.second; - Mask[Idx] = Pair.first * VF + - Entries[Pair.first]->findLaneForValue(VL[Pair.second]); + Mask[Idx] = + Pair.first * VF + + (ForOrder ? std::distance( + Entries[Pair.first]->Scalars.begin(), + find(Entries[Pair.first]->Scalars, VL[Pair.second])) + : Entries[Pair.first]->findLaneForValue(VL[Pair.second])); IsIdentity &= Mask[Idx] == Pair.second; } switch (Entries.size()) { @@ -9702,26 +11297,31 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry( SmallVector<std::optional<TargetTransformInfo::ShuffleKind>> BoUpSLP::isGatherShuffledEntry( const TreeEntry *TE, ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask, - SmallVectorImpl<SmallVector<const TreeEntry *>> &Entries, - unsigned NumParts) { + SmallVectorImpl<SmallVector<const TreeEntry *>> &Entries, unsigned NumParts, + bool ForOrder) { assert(NumParts > 0 && NumParts < VL.size() && "Expected positive number of registers."); Entries.clear(); // No need to check for the topmost gather node. if (TE == VectorizableTree.front().get()) return {}; + // FIXME: Gathering for non-power-of-2 nodes not implemented yet. 
+ if (TE->isNonPowOf2Vec()) + return {}; Mask.assign(VL.size(), PoisonMaskElem); assert(TE->UserTreeIndices.size() == 1 && "Expected only single user of the gather node."); assert(VL.size() % NumParts == 0 && "Number of scalars must be divisible by NumParts."); - unsigned SliceSize = VL.size() / NumParts; + unsigned SliceSize = getPartNumElems(VL.size(), NumParts); SmallVector<std::optional<TTI::ShuffleKind>> Res; - for (unsigned Part = 0; Part < NumParts; ++Part) { - ArrayRef<Value *> SubVL = VL.slice(Part * SliceSize, SliceSize); + for (unsigned Part : seq<unsigned>(NumParts)) { + ArrayRef<Value *> SubVL = + VL.slice(Part * SliceSize, getNumElems(VL.size(), SliceSize, Part)); SmallVectorImpl<const TreeEntry *> &SubEntries = Entries.emplace_back(); std::optional<TTI::ShuffleKind> SubRes = - isGatherShuffledSingleRegisterEntry(TE, SubVL, Mask, SubEntries, Part); + isGatherShuffledSingleRegisterEntry(TE, SubVL, Mask, SubEntries, Part, + ForOrder); if (!SubRes) SubEntries.clear(); Res.push_back(SubRes); @@ -9751,60 +11351,68 @@ BoUpSLP::isGatherShuffledEntry( return Res; } -InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, - bool ForPoisonSrc) const { - // Find the type of the operands in VL. - Type *ScalarTy = VL[0]->getType(); - if (StoreInst *SI = dyn_cast<StoreInst>(VL[0])) - ScalarTy = SI->getValueOperand()->getType(); - auto *VecTy = FixedVectorType::get(ScalarTy, VL.size()); +InstructionCost BoUpSLP::getGatherCost(ArrayRef<Value *> VL, bool ForPoisonSrc, + Type *ScalarTy) const { + auto *VecTy = getWidenedType(ScalarTy, VL.size()); bool DuplicateNonConst = false; // Find the cost of inserting/extracting values from the vector. // Check if the same elements are inserted several times and count them as // shuffle candidates. APInt ShuffledElements = APInt::getZero(VL.size()); - DenseSet<Value *> UniqueElements; + DenseMap<Value *, unsigned> UniqueElements; constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; InstructionCost Cost; auto EstimateInsertCost = [&](unsigned I, Value *V) { + if (V->getType() != ScalarTy) { + Cost += TTI->getCastInstrCost(Instruction::Trunc, ScalarTy, V->getType(), + TTI::CastContextHint::None, CostKind); + V = nullptr; + } if (!ForPoisonSrc) Cost += TTI->getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, I, Constant::getNullValue(VecTy), V); }; + SmallVector<int> ShuffleMask(VL.size(), PoisonMaskElem); for (unsigned I = 0, E = VL.size(); I < E; ++I) { Value *V = VL[I]; // No need to shuffle duplicates for constants. if ((ForPoisonSrc && isConstant(V)) || isa<UndefValue>(V)) { ShuffledElements.setBit(I); + ShuffleMask[I] = isa<PoisonValue>(V) ? PoisonMaskElem : I; continue; } - if (!UniqueElements.insert(V).second) { - DuplicateNonConst = true; - ShuffledElements.setBit(I); + + auto Res = UniqueElements.try_emplace(V, I); + if (Res.second) { + EstimateInsertCost(I, V); + ShuffleMask[I] = I; continue; } - EstimateInsertCost(I, V); + + DuplicateNonConst = true; + ShuffledElements.setBit(I); + ShuffleMask[I] = Res.first->second; } if (ForPoisonSrc) Cost = TTI->getScalarizationOverhead(VecTy, ~ShuffledElements, /*Insert*/ true, /*Extract*/ false, CostKind); if (DuplicateNonConst) - Cost += - TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy); + Cost += TTI->getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, + VecTy, ShuffleMask); return Cost; } // Perform operand reordering on the instructions in VL and return the reordered // operands in Left and Right. 
-void BoUpSLP::reorderInputsAccordingToOpcode( - ArrayRef<Value *> VL, SmallVectorImpl<Value *> &Left, - SmallVectorImpl<Value *> &Right, const TargetLibraryInfo &TLI, - const DataLayout &DL, ScalarEvolution &SE, const BoUpSLP &R) { +void BoUpSLP::reorderInputsAccordingToOpcode(ArrayRef<Value *> VL, + SmallVectorImpl<Value *> &Left, + SmallVectorImpl<Value *> &Right, + const BoUpSLP &R) { if (VL.empty()) return; - VLOperands Ops(VL, TLI, DL, SE, R); + VLOperands Ops(VL, R); // Reorder the operands in place. Ops.reorder(); Left = Ops.getVL(0); @@ -9903,16 +11511,21 @@ Instruction &BoUpSLP::getLastInstructionInBundle(const TreeEntry *E) { // Set the insert point to the beginning of the basic block if the entry // should not be scheduled. if (doesNotNeedToSchedule(E->Scalars) || - (E->State != TreeEntry::NeedToGather && - all_of(E->Scalars, isVectorLikeInstWithConstOps))) { + (!E->isGather() && all_of(E->Scalars, isVectorLikeInstWithConstOps))) { if ((E->getOpcode() == Instruction::GetElementPtr && any_of(E->Scalars, [](Value *V) { return !isa<GetElementPtrInst>(V) && isa<Instruction>(V); })) || - all_of(E->Scalars, [](Value *V) { - return !isVectorLikeInstWithConstOps(V) && isUsedOutsideBlock(V); - })) + all_of(E->Scalars, + [](Value *V) { + return !isVectorLikeInstWithConstOps(V) && + isUsedOutsideBlock(V); + }) || + (E->isGather() && E->Idx == 0 && all_of(E->Scalars, [](Value *V) { + return isa<ExtractElementInst, UndefValue>(V) || + areAllOperandsNonInsts(V); + }))) Res.second = FindLastInst(); else Res.second = FindFirstInst(); @@ -9967,8 +11580,7 @@ void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) { bool IsPHI = isa<PHINode>(LastInst); if (IsPHI) LastInstIt = LastInst->getParent()->getFirstNonPHIIt(); - if (IsPHI || (E->State != TreeEntry::NeedToGather && - doesNotNeedToSchedule(E->Scalars))) { + if (IsPHI || (!E->isGather() && doesNotNeedToSchedule(E->Scalars))) { Builder.SetInsertPoint(LastInst->getParent(), LastInstIt); } else { // Set the insertion point after the last instruction in the bundle. Set the @@ -9980,7 +11592,7 @@ void BoUpSLP::setInsertPointAfterBundle(const TreeEntry *E) { Builder.SetCurrentDebugLocation(Front->getDebugLoc()); } -Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root) { +Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy) { // List of instructions/lanes from current block and/or the blocks which are // part of the current loop. 
These instructions will be inserted at the end to // make it possible to optimize loops and hoist invariant instructions out of @@ -10003,8 +11615,25 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root) { PostponedInsts.emplace_back(Inst, I); } - auto &&CreateInsertElement = [this](Value *Vec, Value *V, unsigned Pos) { - Vec = Builder.CreateInsertElement(Vec, V, Builder.getInt32(Pos)); + auto &&CreateInsertElement = [this](Value *Vec, Value *V, unsigned Pos, + Type *Ty) { + Value *Scalar = V; + if (Scalar->getType() != Ty) { + assert(Scalar->getType()->isIntegerTy() && Ty->isIntegerTy() && + "Expected integer types only."); + Value *V = Scalar; + if (auto *CI = dyn_cast<CastInst>(Scalar); + isa_and_nonnull<SExtInst, ZExtInst>(CI)) { + Value *Op = CI->getOperand(0); + if (auto *IOp = dyn_cast<Instruction>(Op); + !IOp || !(isDeleted(IOp) || getTreeEntry(IOp))) + V = Op; + } + Scalar = Builder.CreateIntCast( + V, Ty, !isKnownNonNegative(Scalar, SimplifyQuery(*DL))); + } + + Vec = Builder.CreateInsertElement(Vec, Scalar, Builder.getInt32(Pos)); auto *InsElt = dyn_cast<InsertElementInst>(Vec); if (!InsElt) return Vec; @@ -10014,15 +11643,22 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root) { if (isa<Instruction>(V)) { if (TreeEntry *Entry = getTreeEntry(V)) { // Find which lane we need to extract. - unsigned FoundLane = Entry->findLaneForValue(V); - ExternalUses.emplace_back(V, InsElt, FoundLane); + User *UserOp = nullptr; + if (Scalar != V) { + if (auto *SI = dyn_cast<Instruction>(Scalar)) + UserOp = SI; + } else { + UserOp = InsElt; + } + if (UserOp) { + unsigned FoundLane = Entry->findLaneForValue(V); + ExternalUses.emplace_back(V, UserOp, FoundLane); + } } } return Vec; }; - Value *Val0 = - isa<StoreInst>(VL[0]) ? cast<StoreInst>(VL[0])->getValueOperand() : VL[0]; - FixedVectorType *VecTy = FixedVectorType::get(Val0->getType(), VL.size()); + auto *VecTy = getWidenedType(ScalarTy, VL.size()); Value *Vec = Root ? Root : PoisonValue::get(VecTy); SmallVector<int> NonConsts; // Insert constant values at first. @@ -10045,15 +11681,15 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root) { continue; } } - Vec = CreateInsertElement(Vec, VL[I], I); + Vec = CreateInsertElement(Vec, VL[I], I, ScalarTy); } // Insert non-constant values. for (int I : NonConsts) - Vec = CreateInsertElement(Vec, VL[I], I); + Vec = CreateInsertElement(Vec, VL[I], I, ScalarTy); // Append instructions, which are/may be part of the loop, in the end to make // it possible to hoist non-loop-based instructions. for (const std::pair<Value *, unsigned> &Pair : PostponedInsts) - Vec = CreateInsertElement(Vec, Pair.first, Pair.second); + Vec = CreateInsertElement(Vec, Pair.first, Pair.second, ScalarTy); return Vec; } @@ -10101,6 +11737,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis { /// resulting shuffle and the second operand sets to be the newly added /// operand. The \p CommonMask is transformed in the proper way after that. SmallVector<Value *, 2> InVectors; + Type *ScalarTy = nullptr; IRBuilderBase &Builder; BoUpSLP &R; @@ -10110,16 +11747,35 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis { SetVector<Instruction *> &GatherShuffleExtractSeq; /// A list of blocks that we are going to CSE. DenseSet<BasicBlock *> &CSEBlocks; + /// Data layout. 
+ const DataLayout &DL; public: ShuffleIRBuilder(IRBuilderBase &Builder, SetVector<Instruction *> &GatherShuffleExtractSeq, - DenseSet<BasicBlock *> &CSEBlocks) + DenseSet<BasicBlock *> &CSEBlocks, const DataLayout &DL) : Builder(Builder), GatherShuffleExtractSeq(GatherShuffleExtractSeq), - CSEBlocks(CSEBlocks) {} + CSEBlocks(CSEBlocks), DL(DL) {} ~ShuffleIRBuilder() = default; /// Creates shufflevector for the 2 operands with the given mask. Value *createShuffleVector(Value *V1, Value *V2, ArrayRef<int> Mask) { + if (V1->getType() != V2->getType()) { + assert(V1->getType()->isIntOrIntVectorTy() && + V1->getType()->isIntOrIntVectorTy() && + "Expected integer vector types only."); + if (V1->getType() != V2->getType()) { + if (cast<VectorType>(V2->getType()) + ->getElementType() + ->getIntegerBitWidth() < cast<VectorType>(V1->getType()) + ->getElementType() + ->getIntegerBitWidth()) + V2 = Builder.CreateIntCast( + V2, V1->getType(), !isKnownNonNegative(V2, SimplifyQuery(DL))); + else + V1 = Builder.CreateIntCast( + V1, V2->getType(), !isKnownNonNegative(V1, SimplifyQuery(DL))); + } + } Value *Vec = Builder.CreateShuffleVector(V1, V2, Mask); if (auto *I = dyn_cast<Instruction>(Vec)) { GatherShuffleExtractSeq.insert(I); @@ -10145,7 +11801,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis { } Value *createIdentity(Value *V) { return V; } Value *createPoison(Type *Ty, unsigned VF) { - return PoisonValue::get(FixedVectorType::get(Ty, VF)); + return PoisonValue::get(getWidenedType(Ty, VF)); } /// Resizes 2 input vector to match the sizes, if the they are not equal /// yet. The smallest vector is resized to the size of the larger vector. @@ -10178,7 +11834,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis { Value *createShuffle(Value *V1, Value *V2, ArrayRef<int> Mask) { assert(V1 && "Expected at least one vector value."); ShuffleIRBuilder ShuffleBuilder(Builder, R.GatherShuffleExtractSeq, - R.CSEBlocks); + R.CSEBlocks, *R.DL); return BaseShuffleAnalysis::createShuffle<Value *>(V1, V2, Mask, ShuffleBuilder); } @@ -10192,9 +11848,22 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis { CommonMask[Idx] = Idx; } + /// Cast value \p V to the vector type with the same number of elements, but + /// the base type \p ScalarTy. + Value *castToScalarTyElem(Value *V, + std::optional<bool> IsSigned = std::nullopt) { + auto *VecTy = cast<VectorType>(V->getType()); + assert(getNumElements(VecTy) % getNumElements(ScalarTy) == 0); + if (VecTy->getElementType() == ScalarTy->getScalarType()) + return V; + return Builder.CreateIntCast( + V, VectorType::get(ScalarTy->getScalarType(), VecTy->getElementCount()), + IsSigned.value_or(!isKnownNonNegative(V, SimplifyQuery(*R.DL)))); + } + public: - ShuffleInstructionBuilder(IRBuilderBase &Builder, BoUpSLP &R) - : Builder(Builder), R(R) {} + ShuffleInstructionBuilder(Type *ScalarTy, IRBuilderBase &Builder, BoUpSLP &R) + : ScalarTy(ScalarTy), Builder(Builder), R(R) {} /// Adjusts extractelements after reusing them. 
Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask, @@ -10219,6 +11888,8 @@ public: any_of(EI->users(), [&](User *U) { const TreeEntry *UTE = R.getTreeEntry(U); return !UTE || R.MultiNodeScalars.contains(U) || + (isa<GetElementPtrInst>(U) && + !R.areAllUsersVectorized(cast<Instruction>(U))) || count_if(R.VectorizableTree, [&](const std::unique_ptr<TreeEntry> &TE) { return any_of(TE->UserTreeIndices, @@ -10231,8 +11902,10 @@ public: continue; R.eraseInstruction(EI); } - if (NumParts == 1 || UniqueBases.size() == 1) - return VecBase; + if (NumParts == 1 || UniqueBases.size() == 1) { + assert(VecBase && "Expected vectorized value."); + return castToScalarTyElem(VecBase); + } UseVecBaseAsInput = true; auto TransformToIdentity = [](MutableArrayRef<int> Mask) { for (auto [I, Idx] : enumerate(Mask)) @@ -10245,31 +11918,37 @@ public: // into a long virtual vector register, forming the original vector. Value *Vec = nullptr; SmallVector<int> VecMask(Mask.size(), PoisonMaskElem); - unsigned SliceSize = E->Scalars.size() / NumParts; - for (unsigned Part = 0; Part < NumParts; ++Part) { + unsigned SliceSize = getPartNumElems(E->Scalars.size(), NumParts); + for (unsigned Part : seq<unsigned>(NumParts)) { + unsigned Limit = getNumElems(E->Scalars.size(), SliceSize, Part); ArrayRef<Value *> VL = - ArrayRef(E->Scalars).slice(Part * SliceSize, SliceSize); - MutableArrayRef<int> SubMask = Mask.slice(Part * SliceSize, SliceSize); + ArrayRef(E->Scalars).slice(Part * SliceSize, Limit); + MutableArrayRef<int> SubMask = Mask.slice(Part * SliceSize, Limit); constexpr int MaxBases = 2; SmallVector<Value *, MaxBases> Bases(MaxBases); -#ifndef NDEBUG - int PrevSize = 0; -#endif // NDEBUG - for (const auto [I, V]: enumerate(VL)) { - if (SubMask[I] == PoisonMaskElem) + auto VLMask = zip(VL, SubMask); + const unsigned VF = std::accumulate( + VLMask.begin(), VLMask.end(), 0U, [&](unsigned S, const auto &D) { + if (std::get<1>(D) == PoisonMaskElem) + return S; + Value *VecOp = + cast<ExtractElementInst>(std::get<0>(D))->getVectorOperand(); + if (const TreeEntry *TE = R.getTreeEntry(VecOp)) + VecOp = TE->VectorizedValue; + assert(VecOp && "Expected vectorized value."); + const unsigned Size = + cast<FixedVectorType>(VecOp->getType())->getNumElements(); + return std::max(S, Size); + }); + for (const auto [V, I] : VLMask) { + if (I == PoisonMaskElem) continue; Value *VecOp = cast<ExtractElementInst>(V)->getVectorOperand(); if (const TreeEntry *TE = R.getTreeEntry(VecOp)) VecOp = TE->VectorizedValue; assert(VecOp && "Expected vectorized value."); - const int Size = - cast<FixedVectorType>(VecOp->getType())->getNumElements(); -#ifndef NDEBUG - assert((PrevSize == Size || PrevSize == 0) && - "Expected vectors of the same size."); - PrevSize = Size; -#endif // NDEBUG - Bases[SubMask[I] < Size ? 
0 : 1] = VecOp; + VecOp = castToScalarTyElem(VecOp); + Bases[I / VF] = VecOp; } if (!Bases.front()) continue; @@ -10285,7 +11964,9 @@ public: assert((Part == 0 || all_of(seq<unsigned>(0, Part), [&](unsigned P) { ArrayRef<int> SubMask = - Mask.slice(P * SliceSize, SliceSize); + Mask.slice(P * SliceSize, + getNumElems(Mask.size(), + SliceSize, P)); return all_of(SubMask, [](int Idx) { return Idx == PoisonMaskElem; }); @@ -10293,16 +11974,17 @@ public: "Expected first part or all previous parts masked."); copy(SubMask, std::next(VecMask.begin(), Part * SliceSize)); } else { - unsigned VF = cast<FixedVectorType>(Vec->getType())->getNumElements(); + unsigned NewVF = + cast<FixedVectorType>(Vec->getType())->getNumElements(); if (Vec->getType() != SubVec->getType()) { unsigned SubVecVF = cast<FixedVectorType>(SubVec->getType())->getNumElements(); - VF = std::max(VF, SubVecVF); + NewVF = std::max(NewVF, SubVecVF); } // Adjust SubMask. - for (auto [I, Idx] : enumerate(SubMask)) + for (int &Idx : SubMask) if (Idx != PoisonMaskElem) - Idx += VF; + Idx += NewVF; copy(SubMask, std::next(VecMask.begin(), Part * SliceSize)); Vec = createShuffle(Vec, SubVec, VecMask); TransformToIdentity(VecMask); @@ -10324,25 +12006,45 @@ public: return std::nullopt; // Postpone gather emission, will be emitted after the end of the // process to keep correct order. - auto *VecTy = FixedVectorType::get(E->Scalars.front()->getType(), - E->getVectorFactor()); + auto *ResVecTy = getWidenedType(ScalarTy, E->getVectorFactor()); return Builder.CreateAlignedLoad( - VecTy, PoisonValue::get(PointerType::getUnqual(VecTy->getContext())), + ResVecTy, + PoisonValue::get(PointerType::getUnqual(ScalarTy->getContext())), MaybeAlign()); } /// Adds 2 input vectors (in form of tree entries) and the mask for their /// shuffling. void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) { - add(E1.VectorizedValue, E2.VectorizedValue, Mask); + Value *V1 = E1.VectorizedValue; + if (V1->getType()->isIntOrIntVectorTy()) + V1 = castToScalarTyElem(V1, any_of(E1.Scalars, [&](Value *V) { + return !isKnownNonNegative( + V, SimplifyQuery(*R.DL)); + })); + Value *V2 = E2.VectorizedValue; + if (V2->getType()->isIntOrIntVectorTy()) + V2 = castToScalarTyElem(V2, any_of(E2.Scalars, [&](Value *V) { + return !isKnownNonNegative( + V, SimplifyQuery(*R.DL)); + })); + add(V1, V2, Mask); } /// Adds single input vector (in form of tree entry) and the mask for its /// shuffling. void add(const TreeEntry &E1, ArrayRef<int> Mask) { - add(E1.VectorizedValue, Mask); + Value *V1 = E1.VectorizedValue; + if (V1->getType()->isIntOrIntVectorTy()) + V1 = castToScalarTyElem(V1, any_of(E1.Scalars, [&](Value *V) { + return !isKnownNonNegative( + V, SimplifyQuery(*R.DL)); + })); + add(V1, Mask); } /// Adds 2 input vectors and the mask for their shuffling. void add(Value *V1, Value *V2, ArrayRef<int> Mask) { assert(V1 && V2 && !Mask.empty() && "Expected non-empty input vectors."); + V1 = castToScalarTyElem(V1); + V2 = castToScalarTyElem(V2); if (InVectors.empty()) { InVectors.push_back(V1); InVectors.push_back(V2); @@ -10370,6 +12072,7 @@ public: } /// Adds another one input vector and the mask for the shuffling. 
void add(Value *V1, ArrayRef<int> Mask, bool = false) { + V1 = castToScalarTyElem(V1); if (InVectors.empty()) { if (!isa<FixedVectorType>(V1->getType())) { V1 = createShuffle(V1, nullptr, CommonMask); @@ -10433,7 +12136,7 @@ public: } Value *gather(ArrayRef<Value *> VL, unsigned MaskVF = 0, Value *Root = nullptr) { - return R.gather(VL, Root); + return R.gather(VL, Root, ScalarTy); } Value *createFreeze(Value *V) { return Builder.CreateFreeze(V); } /// Finalize emission of the shuffles. @@ -10496,17 +12199,11 @@ public: Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx, bool PostponedPHIs) { ValueList &VL = E->getOperand(NodeIdx); - if (E->State == TreeEntry::PossibleStridedVectorize && - !E->ReorderIndices.empty()) { - SmallVector<int> Mask(E->ReorderIndices.begin(), E->ReorderIndices.end()); - reorderScalars(VL, Mask); - } const unsigned VF = VL.size(); InstructionsState S = getSameOpcode(VL, *TLI); // Special processing for GEPs bundle, which may include non-gep values. if (!S.getOpcode() && VL.front()->getType()->isPointerTy()) { - const auto *It = - find_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }); + const auto *It = find_if(VL, IsaPred<GetElementPtrInst>); if (It != VL.end()) S = getSameOpcode(*It, *TLI); } @@ -10539,12 +12236,14 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx, } if (IsSameVE) { auto FinalShuffle = [&](Value *V, ArrayRef<int> Mask) { - ShuffleInstructionBuilder ShuffleBuilder(Builder, *this); + ShuffleInstructionBuilder ShuffleBuilder( + cast<VectorType>(V->getType())->getElementType(), Builder, *this); ShuffleBuilder.add(V, Mask); return ShuffleBuilder.finalize(std::nullopt); }; Value *V = vectorizeTree(VE, PostponedPHIs); - if (VF != cast<FixedVectorType>(V->getType())->getNumElements()) { + if (VF * getNumElements(VL[0]->getType()) != + cast<FixedVectorType>(V->getType())->getNumElements()) { if (!VE->ReuseShuffleIndices.empty()) { // Reshuffle to get only unique values. // If some of the scalars are duplicated in the vectorization @@ -10564,19 +12263,13 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx, // ... (use %2) // %shuffle = shuffle <2 x> %2, poison, <2 x> {2, 0} // br %block - SmallVector<int> UniqueIdxs(VF, PoisonMaskElem); - SmallSet<int, 4> UsedIdxs; - int Pos = 0; - for (int Idx : VE->ReuseShuffleIndices) { - if (Idx != static_cast<int>(VF) && Idx != PoisonMaskElem && - UsedIdxs.insert(Idx).second) - UniqueIdxs[Idx] = Pos; - ++Pos; + SmallVector<int> Mask(VF, PoisonMaskElem); + for (auto [I, V] : enumerate(VL)) { + if (isa<PoisonValue>(V)) + continue; + Mask[I] = VE->findLaneForValue(V); } - assert(VF >= UsedIdxs.size() && "Expected vectorization factor " - "less than original vector size."); - UniqueIdxs.append(VF - UsedIdxs.size(), PoisonMaskElem); - V = FinalShuffle(V, UniqueIdxs); + V = FinalShuffle(V, Mask); } else { assert(VF < cast<FixedVectorType>(V->getType())->getNumElements() && "Expected vectorization factor less " @@ -10594,7 +12287,7 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx, }) == VE->UserTreeIndices.end()) { auto *It = find_if( VectorizableTree, [&](const std::unique_ptr<TreeEntry> &TE) { - return TE->State == TreeEntry::NeedToGather && + return TE->isGather() && TE->UserTreeIndices.front().UserTE == E && TE->UserTreeIndices.front().EdgeIdx == NodeIdx; }); @@ -10620,13 +12313,14 @@ Value *BoUpSLP::vectorizeOperand(TreeEntry *E, unsigned NodeIdx, } template <typename BVTy, typename ResTy, typename... 
Args> -ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { - assert(E->State == TreeEntry::NeedToGather && "Expected gather node."); +ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy, + Args &...Params) { + assert(E->isGather() && "Expected gather node."); unsigned VF = E->getVectorFactor(); bool NeedFreeze = false; - SmallVector<int> ReuseShuffleIndicies(E->ReuseShuffleIndices.begin(), - E->ReuseShuffleIndices.end()); + SmallVector<int> ReuseShuffleIndices(E->ReuseShuffleIndices.begin(), + E->ReuseShuffleIndices.end()); SmallVector<Value *> GatheredScalars(E->Scalars.begin(), E->Scalars.end()); // Build a mask out of the reorder indices and reorder scalars per this // mask. @@ -10658,17 +12352,23 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { Idx == 0) || (Mask.size() == InputVF && ShuffleVectorInst::isIdentityMask(Mask, Mask.size()))) { - std::iota(std::next(Mask.begin(), I * SliceSize), - std::next(Mask.begin(), (I + 1) * SliceSize), 0); + std::iota( + std::next(Mask.begin(), I * SliceSize), + std::next(Mask.begin(), + I * SliceSize + getNumElems(Mask.size(), SliceSize, I)), + 0); } else { unsigned IVal = *find_if_not(Mask, [](int Idx) { return Idx == PoisonMaskElem; }); - std::fill(std::next(Mask.begin(), I * SliceSize), - std::next(Mask.begin(), (I + 1) * SliceSize), IVal); + std::fill( + std::next(Mask.begin(), I * SliceSize), + std::next(Mask.begin(), + I * SliceSize + getNumElems(Mask.size(), SliceSize, I)), + IVal); } return true; }; - BVTy ShuffleBuilder(Params...); + BVTy ShuffleBuilder(ScalarTy, Params...); ResTy Res = ResTy(); SmallVector<int> Mask; SmallVector<int> ExtractMask(GatheredScalars.size(), PoisonMaskElem); @@ -10677,12 +12377,12 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { bool UseVecBaseAsInput = false; SmallVector<std::optional<TargetTransformInfo::ShuffleKind>> GatherShuffles; SmallVector<SmallVector<const TreeEntry *>> Entries; - Type *ScalarTy = GatheredScalars.front()->getType(); - auto *VecTy = FixedVectorType::get(ScalarTy, GatheredScalars.size()); + Type *OrigScalarTy = GatheredScalars.front()->getType(); + auto *VecTy = getWidenedType(ScalarTy, GatheredScalars.size()); unsigned NumParts = TTI->getNumberOfParts(VecTy); if (NumParts == 0 || NumParts >= GatheredScalars.size()) NumParts = 1; - if (!all_of(GatheredScalars, UndefValue::classof)) { + if (!all_of(GatheredScalars, IsaPred<UndefValue>)) { // Check for gathered extracts. bool Resized = false; ExtractShuffles = @@ -10712,7 +12412,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { GatheredScalars.size() != VF) { Resized = true; GatheredScalars.append(VF - GatheredScalars.size(), - PoisonValue::get(ScalarTy)); + PoisonValue::get(OrigScalarTy)); } } } @@ -10772,12 +12472,12 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { }); })) GatheredScalars.append(VF - GatheredScalars.size(), - PoisonValue::get(ScalarTy)); + PoisonValue::get(OrigScalarTy)); } // Remove shuffled elements from list of gathers. for (int I = 0, Sz = Mask.size(); I < Sz; ++I) { if (Mask[I] != PoisonMaskElem) - GatheredScalars[I] = PoisonValue::get(ScalarTy); + GatheredScalars[I] = PoisonValue::get(OrigScalarTy); } } } @@ -10788,7 +12488,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { // such sequences. 
bool IsSplat = IsRootPoison && isSplat(Scalars) && (Scalars.size() > 2 || Scalars.front() == Scalars.back()); - Scalars.append(VF - Scalars.size(), PoisonValue::get(ScalarTy)); + Scalars.append(VF - Scalars.size(), PoisonValue::get(OrigScalarTy)); SmallVector<int> UndefPos; DenseMap<Value *, unsigned> UniquePositions; // Gather unique non-const values and all constant values. @@ -10810,7 +12510,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { ++NumNonConsts; SinglePos = I; Value *OrigV = V; - Scalars[I] = PoisonValue::get(ScalarTy); + Scalars[I] = PoisonValue::get(OrigScalarTy); if (IsSplat) { Scalars.front() = OrigV; ReuseMask[I] = 0; @@ -10826,7 +12526,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { ReuseMask.assign(VF, PoisonMaskElem); std::swap(Scalars.front(), Scalars[SinglePos]); if (!UndefPos.empty() && UndefPos.front() == 0) - Scalars.front() = UndefValue::get(ScalarTy); + Scalars.front() = UndefValue::get(OrigScalarTy); } ReuseMask[SinglePos] = SinglePos; } else if (!UndefPos.empty() && IsSplat) { @@ -10856,7 +12556,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { // Replace the undef by the poison, in the mask it is replaced by // non-poisoned scalar already. if (I != Pos) - Scalars[I] = PoisonValue::get(ScalarTy); + Scalars[I] = PoisonValue::get(OrigScalarTy); } } else { // Replace undefs by the poisons, emit broadcast and then emit @@ -10864,7 +12564,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { for (int I : UndefPos) { ReuseMask[I] = PoisonMaskElem; if (isa<UndefValue>(Scalars[I])) - Scalars[I] = PoisonValue::get(ScalarTy); + Scalars[I] = PoisonValue::get(OrigScalarTy); } NeedFreeze = true; } @@ -10898,8 +12598,8 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { VecOp = TE->VectorizedValue; if (!Vec1) { Vec1 = VecOp; - } else if (Vec1 != EI->getVectorOperand()) { - assert((!Vec2 || Vec2 == EI->getVectorOperand()) && + } else if (Vec1 != VecOp) { + assert((!Vec2 || Vec2 == VecOp) && "Expected only 1 or 2 vectors shuffle."); Vec2 = VecOp; } @@ -10919,13 +12619,12 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { IsNonPoisoned &= isGuaranteedNotToBePoison(Vec1); } else { IsUsedInExpr = false; - ShuffleBuilder.add(PoisonValue::get(FixedVectorType::get( - ScalarTy, GatheredScalars.size())), - ExtractMask, /*ForExtracts=*/true); + ShuffleBuilder.add(PoisonValue::get(VecTy), ExtractMask, + /*ForExtracts=*/true); } } if (!GatherShuffles.empty()) { - unsigned SliceSize = E->Scalars.size() / NumParts; + unsigned SliceSize = getPartNumElems(E->Scalars.size(), NumParts); SmallVector<int> VecMask(Mask.size(), PoisonMaskElem); for (const auto [I, TEs] : enumerate(Entries)) { if (TEs.empty()) { @@ -10935,12 +12634,13 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { } assert((TEs.size() == 1 || TEs.size() == 2) && "Expected shuffle of 1 or 2 entries."); - auto SubMask = ArrayRef(Mask).slice(I * SliceSize, SliceSize); + unsigned Limit = getNumElems(Mask.size(), SliceSize, I); + auto SubMask = ArrayRef(Mask).slice(I * SliceSize, Limit); VecMask.assign(VecMask.size(), PoisonMaskElem); copy(SubMask, std::next(VecMask.begin(), I * SliceSize)); if (TEs.size() == 1) { - IsUsedInExpr &= - FindReusedSplat(VecMask, TEs.front()->getVectorFactor(), I, SliceSize); + IsUsedInExpr &= FindReusedSplat( + VecMask, TEs.front()->getVectorFactor(), I, SliceSize); ShuffleBuilder.add(*TEs.front(), VecMask); if 
(TEs.front()->VectorizedValue) IsNonPoisoned &= @@ -11002,12 +12702,12 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { // contains only constant to build final vector and then shuffle. for (int I = 0, Sz = GatheredScalars.size(); I < Sz; ++I) { if (EnoughConstsForShuffle && isa<Constant>(GatheredScalars[I])) - NonConstants[I] = PoisonValue::get(ScalarTy); + NonConstants[I] = PoisonValue::get(OrigScalarTy); else - GatheredScalars[I] = PoisonValue::get(ScalarTy); + GatheredScalars[I] = PoisonValue::get(OrigScalarTy); } // Generate constants for final shuffle and build a mask for them. - if (!all_of(GatheredScalars, PoisonValue::classof)) { + if (!all_of(GatheredScalars, IsaPred<PoisonValue>)) { SmallVector<int> BVMask(GatheredScalars.size(), PoisonMaskElem); TryPackScalars(GatheredScalars, BVMask, /*IsRootPoison=*/true); Value *BV = ShuffleBuilder.gather(GatheredScalars, BVMask.size()); @@ -11050,13 +12750,13 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) { return Res; } -Value *BoUpSLP::createBuildVector(const TreeEntry *E) { - return processBuildVector<ShuffleInstructionBuilder, Value *>(E, Builder, - *this); +Value *BoUpSLP::createBuildVector(const TreeEntry *E, Type *ScalarTy) { + return processBuildVector<ShuffleInstructionBuilder, Value *>(E, ScalarTy, + Builder, *this); } Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { - IRBuilder<>::InsertPointGuard Guard(Builder); + IRBuilderBase::InsertPointGuard Guard(Builder); if (E->VectorizedValue && (E->State != TreeEntry::Vectorize || E->getOpcode() != Instruction::PHI || @@ -11065,26 +12765,35 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { return E->VectorizedValue; } - if (E->State == TreeEntry::NeedToGather) { + Value *V = E->Scalars.front(); + Type *ScalarTy = V->getType(); + if (auto *Store = dyn_cast<StoreInst>(V)) + ScalarTy = Store->getValueOperand()->getType(); + else if (auto *IE = dyn_cast<InsertElementInst>(V)) + ScalarTy = IE->getOperand(1)->getType(); + auto It = MinBWs.find(E); + if (It != MinBWs.end()) + ScalarTy = IntegerType::get(F->getContext(), It->second.first); + auto *VecTy = getWidenedType(ScalarTy, E->Scalars.size()); + if (E->isGather()) { // Set insert point for non-reduction initial nodes. 
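When MinBWs records a narrowed width for the entry, the element type chosen above for the widened vector is that smaller integer type rather than the scalars' original type, and the value is extended back to the original width at the tree boundary (see the casts further down). A small standalone illustration of why this is value-preserving once the analysis has shown the dropped bits are not demanded, assuming a 32-bit add whose lanes fit in 16 bits (plain C++; the widths and sample values are invented for the example):

// --- illustrative sketch, not part of the diff ---
#include <cstdint>
#include <cstdio>

int main() {
  uint32_t A = 1000, B = 2345;                            // both provably fit in 16 bits
  uint32_t Wide = A + B;                                  // original 32-bit arithmetic
  uint16_t Narrow = uint16_t(uint16_t(A) + uint16_t(B));  // demoted 16-bit arithmetic
  uint32_t Widened = Narrow;                              // zero-extend back at the boundary
  std::printf("%u %u equal=%d\n", Wide, Widened, int(Wide == Widened));
  return 0;
}
// --- end of illustrative sketch ---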
if (E->getMainOp() && E->Idx == 0 && !UserIgnoreList) setInsertPointAfterBundle(E); - Value *Vec = createBuildVector(E); + Value *Vec = createBuildVector(E, ScalarTy); E->VectorizedValue = Vec; return Vec; } - auto FinalShuffle = [&](Value *V, const TreeEntry *E, VectorType *VecTy, - bool IsSigned) { - if (V->getType() != VecTy) - V = Builder.CreateIntCast(V, VecTy, IsSigned); - ShuffleInstructionBuilder ShuffleBuilder(Builder, *this); - if (E->getOpcode() == Instruction::Store) { + bool IsReverseOrder = isReverseOrder(E->ReorderIndices); + auto FinalShuffle = [&](Value *V, const TreeEntry *E, VectorType *VecTy) { + ShuffleInstructionBuilder ShuffleBuilder(ScalarTy, Builder, *this); + if (E->getOpcode() == Instruction::Store && + E->State == TreeEntry::Vectorize) { ArrayRef<int> Mask = ArrayRef(reinterpret_cast<const int *>(E->ReorderIndices.begin()), E->ReorderIndices.size()); ShuffleBuilder.add(V, Mask); - } else if (E->State == TreeEntry::PossibleStridedVectorize) { + } else if (E->State == TreeEntry::StridedVectorize && IsReverseOrder) { ShuffleBuilder.addOrdered(V, std::nullopt); } else { ShuffleBuilder.addOrdered(V, E->ReorderIndices); @@ -11094,26 +12803,26 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { assert((E->State == TreeEntry::Vectorize || E->State == TreeEntry::ScatterVectorize || - E->State == TreeEntry::PossibleStridedVectorize) && + E->State == TreeEntry::StridedVectorize) && "Unhandled state"); unsigned ShuffleOrOp = E->isAltShuffle() ? (unsigned)Instruction::ShuffleVector : E->getOpcode(); Instruction *VL0 = E->getMainOp(); - Type *ScalarTy = VL0->getType(); - if (auto *Store = dyn_cast<StoreInst>(VL0)) - ScalarTy = Store->getValueOperand()->getType(); - else if (auto *IE = dyn_cast<InsertElementInst>(VL0)) - ScalarTy = IE->getOperand(1)->getType(); - bool IsSigned = false; - auto It = MinBWs.find(E); - if (It != MinBWs.end()) { - ScalarTy = IntegerType::get(F->getContext(), It->second.first); - IsSigned = It->second.second; - } - auto *VecTy = FixedVectorType::get(ScalarTy, E->Scalars.size()); + auto GetOperandSignedness = [&](unsigned Idx) { + const TreeEntry *OpE = getOperandEntry(E, Idx); + bool IsSigned = false; + auto It = MinBWs.find(OpE); + if (It != MinBWs.end()) + IsSigned = It->second.second; + else + IsSigned = any_of(OpE->Scalars, [&](Value *R) { + return !isKnownNonNegative(R, SimplifyQuery(*DL)); + }); + return IsSigned; + }; switch (ShuffleOrOp) { case Instruction::PHI: { - assert((E->ReorderIndices.empty() || + assert((E->ReorderIndices.empty() || !E->ReuseShuffleIndices.empty() || E != VectorizableTree.front().get() || !E->UserTreeIndices.empty()) && "PHI reordering is free."); @@ -11133,7 +12842,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { PH->getParent()->getFirstInsertionPt()); Builder.SetCurrentDebugLocation(PH->getDebugLoc()); - V = FinalShuffle(V, E, VecTy, IsSigned); + V = FinalShuffle(V, E, VecTy); E->VectorizedValue = V; if (PostponedPHIs) @@ -11167,9 +12876,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { Builder.SetCurrentDebugLocation(PH->getDebugLoc()); Value *Vec = vectorizeOperand(E, I, /*PostponedPHIs=*/true); if (VecTy != Vec->getType()) { - assert(MinBWs.contains(getOperandEntry(E, I)) && + assert((It != MinBWs.end() || getOperandEntry(E, I)->isGather() || + MinBWs.contains(getOperandEntry(E, I))) && "Expected item in MinBWs."); - Vec = Builder.CreateIntCast(Vec, VecTy, It->second.second); + Vec = Builder.CreateIntCast(Vec, VecTy, GetOperandSignedness(I)); } 
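GetOperandSignedness above decides whether such a widening cast must be a sign extension or a zero extension: an operand whose narrowed lanes may hold negative values (its MinBWs entry says signed, or some scalar is not provably non-negative) has to be sign-extended, or its value changes. A standalone illustration, assuming an 8-bit lane widened to 32 bits (plain C++; the widths and the value -5 are invented for the example):

// --- illustrative sketch, not part of the diff ---
#include <cstdint>
#include <cstdio>

int main() {
  int8_t Narrow = -5;             // a lane that may be negative
  int32_t SExt = Narrow;          // sign extension preserves -5
  int32_t ZExt = uint8_t(Narrow); // zero extension of the same bits gives 251
  std::printf("sext=%d zext=%d\n", SExt, ZExt);
  return 0;
}
// --- end of illustrative sketch ---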
NewPhi->addIncoming(Vec, IBB); } @@ -11184,7 +12894,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { if (const TreeEntry *TE = getTreeEntry(V)) V = TE->VectorizedValue; setInsertPointAfterBundle(E); - V = FinalShuffle(V, E, VecTy, IsSigned); + V = FinalShuffle(V, E, VecTy); E->VectorizedValue = V; return V; } @@ -11194,7 +12904,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { Value *Ptr = LI->getPointerOperand(); LoadInst *V = Builder.CreateAlignedLoad(VecTy, Ptr, LI->getAlign()); Value *NewV = propagateMetadata(V, E->Scalars); - NewV = FinalShuffle(NewV, E, VecTy, IsSigned); + NewV = FinalShuffle(NewV, E, VecTy); E->VectorizedValue = NewV; return NewV; } @@ -11210,7 +12920,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { assert(Res.first > 0 && "Expected item in MinBWs."); V = Builder.CreateIntCast( V, - FixedVectorType::get( + getWidenedType( ScalarTy, cast<FixedVectorType>(V->getType())->getNumElements()), Res.second); @@ -11224,7 +12934,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { cast<FixedVectorType>(FirstInsert->getType())->getNumElements(); const unsigned NumScalars = E->Scalars.size(); - unsigned Offset = *getInsertIndex(VL0); + unsigned Offset = *getElementIndex(VL0); assert(Offset < NumElts && "Failed to find vector index offset"); // Create shuffle to resize vector @@ -11242,7 +12952,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { Mask.swap(PrevMask); for (unsigned I = 0; I < NumScalars; ++I) { Value *Scalar = E->Scalars[PrevMask[I]]; - unsigned InsertIdx = *getInsertIndex(Scalar); + unsigned InsertIdx = *getElementIndex(Scalar); IsIdentity &= InsertIdx - Offset == I; Mask[InsertIdx - Offset] = I; } @@ -11255,7 +12965,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { // sequence. InsertElementInst *Ins = cast<InsertElementInst>(VL0); do { - std::optional<unsigned> InsertIdx = getInsertIndex(Ins); + std::optional<unsigned> InsertIdx = getElementIndex(Ins); if (!InsertIdx) break; if (InsertMask[*InsertIdx] == PoisonMaskElem) @@ -11380,10 +13090,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { auto *CI = cast<CastInst>(VL0); Instruction::CastOps VecOpcode = CI->getOpcode(); - Type *SrcScalarTy = VL0->getOperand(0)->getType(); + Type *SrcScalarTy = cast<VectorType>(InVec->getType())->getElementType(); auto SrcIt = MinBWs.find(getOperandEntry(E, 0)); if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() && - (SrcIt != MinBWs.end() || It != MinBWs.end())) { + (SrcIt != MinBWs.end() || It != MinBWs.end() || + SrcScalarTy != CI->getOperand(0)->getType())) { // Check if the values are candidates to demote. unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy); if (SrcIt != MinBWs.end()) @@ -11396,12 +13107,19 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { } else if (It != MinBWs.end()) { assert(BWSz > SrcBWSz && "Invalid cast!"); VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt; + } else if (SrcIt != MinBWs.end()) { + assert(BWSz > SrcBWSz && "Invalid cast!"); + VecOpcode = + SrcIt->second.second ? Instruction::SExt : Instruction::ZExt; } + } else if (VecOpcode == Instruction::SIToFP && SrcIt != MinBWs.end() && + !SrcIt->second.second) { + VecOpcode = Instruction::UIToFP; } Value *V = (VecOpcode != ShuffleOrOp && VecOpcode == Instruction::BitCast) ? 
InVec : Builder.CreateCast(VecOpcode, InVec, VecTy); - V = FinalShuffle(V, E, VecTy, IsSigned); + V = FinalShuffle(V, E, VecTy); E->VectorizedValue = V; ++NumVectorInstructions; @@ -11422,11 +13140,22 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { return E->VectorizedValue; } if (L->getType() != R->getType()) { - assert((MinBWs.contains(getOperandEntry(E, 0)) || + assert((getOperandEntry(E, 0)->isGather() || + getOperandEntry(E, 1)->isGather() || + MinBWs.contains(getOperandEntry(E, 0)) || MinBWs.contains(getOperandEntry(E, 1))) && "Expected item in MinBWs."); - L = Builder.CreateIntCast(L, VecTy, IsSigned); - R = Builder.CreateIntCast(R, VecTy, IsSigned); + if (cast<VectorType>(L->getType()) + ->getElementType() + ->getIntegerBitWidth() < cast<VectorType>(R->getType()) + ->getElementType() + ->getIntegerBitWidth()) { + Type *CastTy = R->getType(); + L = Builder.CreateIntCast(L, CastTy, GetOperandSignedness(0)); + } else { + Type *CastTy = L->getType(); + R = Builder.CreateIntCast(R, CastTy, GetOperandSignedness(1)); + } } CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate(); @@ -11434,7 +13163,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { propagateIRFlags(V, E->Scalars, VL0); // Do not cast for cmps. VecTy = cast<FixedVectorType>(V->getType()); - V = FinalShuffle(V, E, VecTy, IsSigned); + V = FinalShuffle(V, E, VecTy); E->VectorizedValue = V; ++NumVectorInstructions; @@ -11458,16 +13187,20 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } - if (True->getType() != False->getType()) { - assert((MinBWs.contains(getOperandEntry(E, 1)) || + if (True->getType() != VecTy || False->getType() != VecTy) { + assert((It != MinBWs.end() || getOperandEntry(E, 1)->isGather() || + getOperandEntry(E, 2)->isGather() || + MinBWs.contains(getOperandEntry(E, 1)) || MinBWs.contains(getOperandEntry(E, 2))) && "Expected item in MinBWs."); - True = Builder.CreateIntCast(True, VecTy, IsSigned); - False = Builder.CreateIntCast(False, VecTy, IsSigned); + if (True->getType() != VecTy) + True = Builder.CreateIntCast(True, VecTy, GetOperandSignedness(1)); + if (False->getType() != VecTy) + False = Builder.CreateIntCast(False, VecTy, GetOperandSignedness(2)); } Value *V = Builder.CreateSelect(Cond, True, False); - V = FinalShuffle(V, E, VecTy, IsSigned); + V = FinalShuffle(V, E, VecTy); E->VectorizedValue = V; ++NumVectorInstructions; @@ -11489,7 +13222,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { if (auto *I = dyn_cast<Instruction>(V)) V = propagateMetadata(I, E->Scalars); - V = FinalShuffle(V, E, VecTy, IsSigned); + V = FinalShuffle(V, E, VecTy); E->VectorizedValue = V; ++NumVectorInstructions; @@ -11526,22 +13259,47 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } - if (LHS->getType() != RHS->getType()) { - assert((MinBWs.contains(getOperandEntry(E, 0)) || + if (ShuffleOrOp == Instruction::And && It != MinBWs.end()) { + for (unsigned I : seq<unsigned>(0, E->getNumOperands())) { + ArrayRef<Value *> Ops = E->getOperand(I); + if (all_of(Ops, [&](Value *Op) { + auto *CI = dyn_cast<ConstantInt>(Op); + return CI && CI->getValue().countr_one() >= It->second.first; + })) { + V = FinalShuffle(I == 0 ? 
RHS : LHS, E, VecTy); + E->VectorizedValue = V; + ++NumVectorInstructions; + return V; + } + } + } + if (LHS->getType() != VecTy || RHS->getType() != VecTy) { + assert((It != MinBWs.end() || getOperandEntry(E, 0)->isGather() || + getOperandEntry(E, 1)->isGather() || + MinBWs.contains(getOperandEntry(E, 0)) || MinBWs.contains(getOperandEntry(E, 1))) && "Expected item in MinBWs."); - LHS = Builder.CreateIntCast(LHS, VecTy, IsSigned); - RHS = Builder.CreateIntCast(RHS, VecTy, IsSigned); + if (LHS->getType() != VecTy) + LHS = Builder.CreateIntCast(LHS, VecTy, GetOperandSignedness(0)); + if (RHS->getType() != VecTy) + RHS = Builder.CreateIntCast(RHS, VecTy, GetOperandSignedness(1)); } Value *V = Builder.CreateBinOp( static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS, RHS); - propagateIRFlags(V, E->Scalars, VL0, !MinBWs.contains(E)); - if (auto *I = dyn_cast<Instruction>(V)) + propagateIRFlags(V, E->Scalars, VL0, It == MinBWs.end()); + if (auto *I = dyn_cast<Instruction>(V)) { V = propagateMetadata(I, E->Scalars); + // Drop nuw flags for abs(sub(commutative), true). + if (!MinBWs.contains(E) && ShuffleOrOp == Instruction::Sub && + any_of(E->Scalars, [](Value *V) { + return isCommutative(cast<Instruction>(V)); + })) + I->setHasNoUnsignedWrap(/*b=*/false); + } - V = FinalShuffle(V, E, VecTy, IsSigned); + V = FinalShuffle(V, E, VecTy); E->VectorizedValue = V; ++NumVectorInstructions; @@ -11558,25 +13316,61 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { Value *PO = LI->getPointerOperand(); if (E->State == TreeEntry::Vectorize) { NewLI = Builder.CreateAlignedLoad(VecTy, PO, LI->getAlign()); + } else if (E->State == TreeEntry::StridedVectorize) { + Value *Ptr0 = cast<LoadInst>(E->Scalars.front())->getPointerOperand(); + Value *PtrN = cast<LoadInst>(E->Scalars.back())->getPointerOperand(); + PO = IsReverseOrder ? PtrN : Ptr0; + std::optional<int> Diff = getPointersDiff( + VL0->getType(), Ptr0, VL0->getType(), PtrN, *DL, *SE); + Type *StrideTy = DL->getIndexType(PO->getType()); + Value *StrideVal; + if (Diff) { + int Stride = *Diff / (static_cast<int>(E->Scalars.size()) - 1); + StrideVal = + ConstantInt::get(StrideTy, (IsReverseOrder ? -1 : 1) * Stride * + DL->getTypeAllocSize(ScalarTy)); + } else { + SmallVector<Value *> PointerOps(E->Scalars.size(), nullptr); + transform(E->Scalars, PointerOps.begin(), [](Value *V) { + return cast<LoadInst>(V)->getPointerOperand(); + }); + OrdersType Order; + std::optional<Value *> Stride = + calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order, + &*Builder.GetInsertPoint()); + Value *NewStride = + Builder.CreateIntCast(*Stride, StrideTy, /*isSigned=*/true); + StrideVal = Builder.CreateMul( + NewStride, + ConstantInt::get( + StrideTy, + (IsReverseOrder ? 
-1 : 1) * + static_cast<int>(DL->getTypeAllocSize(ScalarTy)))); + } + Align CommonAlignment = computeCommonAlignment<LoadInst>(E->Scalars); + auto *Inst = Builder.CreateIntrinsic( + Intrinsic::experimental_vp_strided_load, + {VecTy, PO->getType(), StrideTy}, + {PO, StrideVal, Builder.getAllOnesMask(VecTy->getElementCount()), + Builder.getInt32(E->Scalars.size())}); + Inst->addParamAttr( + /*ArgNo=*/0, + Attribute::getWithAlignment(Inst->getContext(), CommonAlignment)); + NewLI = Inst; } else { - assert((E->State == TreeEntry::ScatterVectorize || - E->State == TreeEntry::PossibleStridedVectorize) && - "Unhandled state"); + assert(E->State == TreeEntry::ScatterVectorize && "Unhandled state"); Value *VecPtr = vectorizeOperand(E, 0, PostponedPHIs); if (E->VectorizedValue) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } // Use the minimum alignment of the gathered loads. - Align CommonAlignment = LI->getAlign(); - for (Value *V : E->Scalars) - CommonAlignment = - std::min(CommonAlignment, cast<LoadInst>(V)->getAlign()); + Align CommonAlignment = computeCommonAlignment<LoadInst>(E->Scalars); NewLI = Builder.CreateMaskedGather(VecTy, VecPtr, CommonAlignment); } Value *V = propagateMetadata(NewLI, E->Scalars); - V = FinalShuffle(V, E, VecTy, IsSigned); + V = FinalShuffle(V, E, VecTy); E->VectorizedValue = V; ++NumVectorInstructions; return V; @@ -11587,11 +13381,37 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { setInsertPointAfterBundle(E); Value *VecValue = vectorizeOperand(E, 0, PostponedPHIs); - VecValue = FinalShuffle(VecValue, E, VecTy, IsSigned); + if (VecValue->getType() != VecTy) + VecValue = + Builder.CreateIntCast(VecValue, VecTy, GetOperandSignedness(0)); + VecValue = FinalShuffle(VecValue, E, VecTy); Value *Ptr = SI->getPointerOperand(); - StoreInst *ST = - Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign()); + Instruction *ST; + if (E->State == TreeEntry::Vectorize) { + ST = Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign()); + } else { + assert(E->State == TreeEntry::StridedVectorize && + "Expected either strided or consecutive stores."); + if (!E->ReorderIndices.empty()) { + SI = cast<StoreInst>(E->Scalars[E->ReorderIndices.front()]); + Ptr = SI->getPointerOperand(); + } + Align CommonAlignment = computeCommonAlignment<StoreInst>(E->Scalars); + Type *StrideTy = DL->getIndexType(SI->getPointerOperandType()); + auto *Inst = Builder.CreateIntrinsic( + Intrinsic::experimental_vp_strided_store, + {VecTy, Ptr->getType(), StrideTy}, + {VecValue, Ptr, + ConstantInt::get( + StrideTy, -static_cast<int>(DL->getTypeAllocSize(ScalarTy))), + Builder.getAllOnesMask(VecTy->getElementCount()), + Builder.getInt32(E->Scalars.size())}); + Inst->addParamAttr( + /*ArgNo=*/1, + Attribute::getWithAlignment(Inst->getContext(), CommonAlignment)); + ST = Inst; + } Value *V = propagateMetadata(ST, E->Scalars); @@ -11629,7 +13449,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { V = propagateMetadata(I, GEPs); } - V = FinalShuffle(V, E, VecTy, IsSigned); + V = FinalShuffle(V, E, VecTy); E->VectorizedValue = V; ++NumVectorInstructions; @@ -11642,7 +13462,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); - auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI); + SmallVector<Type *> ArgTys = + buildIntrinsicArgTypes(CI, ID, VecTy->getNumElements(), + It != MinBWs.end() ? 
It->second.first : 0); + auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, ArgTys); bool UseIntrinsic = ID != Intrinsic::not_intrinsic && VecCallCosts.first <= VecCallCosts.second; @@ -11651,8 +13474,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { SmallVector<Type *, 2> TysForDecl; // Add return type if intrinsic is overloaded on it. if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)) - TysForDecl.push_back( - FixedVectorType::get(CI->getType(), E->Scalars.size())); + TysForDecl.push_back(VecTy); auto *CEI = cast<CallInst>(VL0); for (unsigned I : seq<unsigned>(0, CI->arg_size())) { ValueList OpVL; @@ -11660,7 +13482,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { // vectorized. if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I)) { ScalarArg = CEI->getArgOperand(I); - OpVecs.push_back(CEI->getArgOperand(I)); + // If we decided to reduce the bitwidth of the abs intrinsic, its second + // argument must be set to false (do not return poison if the value is signed min). + if (ID == Intrinsic::abs && It != MinBWs.end() && + It->second.first < DL->getTypeSizeInBits(CEI->getType())) + ScalarArg = Builder.getFalse(); + OpVecs.push_back(ScalarArg); if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) TysForDecl.push_back(ScalarArg->getType()); continue; @@ -11671,24 +13498,15 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } - auto GetOperandSignedness = [&](unsigned Idx) { - const TreeEntry *OpE = getOperandEntry(E, Idx); - bool IsSigned = false; - auto It = MinBWs.find(OpE); - if (It != MinBWs.end()) - IsSigned = It->second.second; - else - IsSigned = any_of(OpE->Scalars, [&](Value *R) { - return !isKnownNonNegative(R, SimplifyQuery(*DL)); - }); - return IsSigned; - }; ScalarArg = CEI->getArgOperand(I); if (cast<VectorType>(OpVec->getType())->getElementType() != - ScalarArg->getType()) { - auto *CastTy = FixedVectorType::get(ScalarArg->getType(), - VecTy->getNumElements()); + ScalarArg->getType()->getScalarType() && + It == MinBWs.end()) { + auto *CastTy = + getWidenedType(ScalarArg->getType(), VecTy->getNumElements()); + OpVec = Builder.CreateIntCast(OpVec, CastTy, GetOperandSignedness(I)); + } else if (It != MinBWs.end()) { + OpVec = Builder.CreateIntCast(OpVec, VecTy, GetOperandSignedness(I)); } LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n"); OpVecs.push_back(OpVec); @@ -11713,7 +13531,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { Value *V = Builder.CreateCall(CF, OpVecs, OpBundles); propagateIRFlags(V, E->Scalars, VL0); - V = FinalShuffle(V, E, VecTy, IsSigned); + V = FinalShuffle(V, E, VecTy); E->VectorizedValue = V; ++NumVectorInstructions; @@ -11745,12 +13563,30 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n"); return E->VectorizedValue; } - if (LHS && RHS && LHS->getType() != RHS->getType()) { - assert((MinBWs.contains(getOperandEntry(E, 0)) || + if (LHS && RHS && + ((Instruction::isBinaryOp(E->getOpcode()) && + (LHS->getType() != VecTy || RHS->getType() != VecTy)) || + (isa<CmpInst>(VL0) && LHS->getType() != RHS->getType()))) { + assert((It != MinBWs.end() || getOperandEntry(E, 0)->isGather() || + getOperandEntry(E, 1)->isGather() || + MinBWs.contains(getOperandEntry(E, 0)) || MinBWs.contains(getOperandEntry(E, 1))) && "Expected item in MinBWs."); - LHS = Builder.CreateIntCast(LHS, 
VecTy, IsSigned); - RHS = Builder.CreateIntCast(RHS, VecTy, IsSigned); + Type *CastTy = VecTy; + if (isa<CmpInst>(VL0) && LHS->getType() != RHS->getType()) { + if (cast<VectorType>(LHS->getType()) + ->getElementType() + ->getIntegerBitWidth() < cast<VectorType>(RHS->getType()) + ->getElementType() + ->getIntegerBitWidth()) + CastTy = RHS->getType(); + else + CastTy = LHS->getType(); + } + if (LHS->getType() != CastTy) + LHS = Builder.CreateIntCast(LHS, CastTy, GetOperandSignedness(0)); + if (RHS->getType() != CastTy) + RHS = Builder.CreateIntCast(RHS, CastTy, GetOperandSignedness(1)); } Value *V0, *V1; @@ -11765,6 +13601,21 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { CmpInst::Predicate AltPred = AltCI->getPredicate(); V1 = Builder.CreateCmp(AltPred, LHS, RHS); } else { + if (LHS->getType()->isIntOrIntVectorTy() && ScalarTy->isIntegerTy()) { + unsigned SrcBWSz = DL->getTypeSizeInBits( + cast<VectorType>(LHS->getType())->getElementType()); + unsigned BWSz = DL->getTypeSizeInBits(ScalarTy); + if (BWSz <= SrcBWSz) { + if (BWSz < SrcBWSz) + LHS = Builder.CreateIntCast(LHS, VecTy, It->second.first); + assert(LHS->getType() == VecTy && "Expected same type as operand."); + if (auto *I = dyn_cast<Instruction>(LHS)) + LHS = propagateMetadata(I, E->Scalars); + E->VectorizedValue = LHS; + ++NumVectorInstructions; + return LHS; + } + } V0 = Builder.CreateCast( static_cast<Instruction::CastOps>(E->getOpcode()), LHS, VecTy); V1 = Builder.CreateCast( @@ -11792,8 +13643,21 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { }, Mask, &OpScalars, &AltScalars); - propagateIRFlags(V0, OpScalars); - propagateIRFlags(V1, AltScalars); + propagateIRFlags(V0, OpScalars, E->getMainOp(), It == MinBWs.end()); + propagateIRFlags(V1, AltScalars, E->getAltOp(), It == MinBWs.end()); + auto DropNuwFlag = [&](Value *Vec, unsigned Opcode) { + // Drop nuw flags for abs(sub(commutative), true). + if (auto *I = dyn_cast<Instruction>(Vec); + I && Opcode == Instruction::Sub && !MinBWs.contains(E) && + any_of(E->Scalars, [](Value *V) { + auto *IV = cast<Instruction>(V); + return IV->getOpcode() == Instruction::Sub && + isCommutative(cast<Instruction>(IV)); + })) + I->setHasNoUnsignedWrap(/*b=*/false); + }; + DropNuwFlag(V0, E->getOpcode()); + DropNuwFlag(V1, E->getAltOpcode()); Value *V = Builder.CreateShuffleVector(V0, V1, Mask); if (auto *I = dyn_cast<Instruction>(V)) { @@ -11802,9 +13666,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { CSEBlocks.insert(I->getParent()); } - if (V->getType() != VecTy && !isa<CmpInst>(VL0)) - V = Builder.CreateIntCast( - V, FixedVectorType::get(ScalarTy, E->getVectorFactor()), IsSigned); E->VectorizedValue = V; ++NumVectorInstructions; @@ -11866,7 +13727,8 @@ Value *BoUpSLP::vectorizeTree( auto *TE = const_cast<TreeEntry *>(E); if (auto *VecTE = getTreeEntry(TE->Scalars.front())) if (VecTE->isSame(TE->UserTreeIndices.front().UserTE->getOperand( - TE->UserTreeIndices.front().EdgeIdx))) + TE->UserTreeIndices.front().EdgeIdx)) && + VecTE->isSame(TE->Scalars)) // Found gather node which is absolutely the same as one of the // vectorized nodes. It may happen after reordering. 
continue; @@ -11900,6 +13762,60 @@ Value *BoUpSLP::vectorizeTree( } Builder.SetCurrentDebugLocation(UserI->getDebugLoc()); Value *Vec = vectorizeTree(TE, /*PostponedPHIs=*/false); + if (Vec->getType() != PrevVec->getType()) { + assert(Vec->getType()->isIntOrIntVectorTy() && + PrevVec->getType()->isIntOrIntVectorTy() && + "Expected integer vector types only."); + std::optional<bool> IsSigned; + for (Value *V : TE->Scalars) { + if (const TreeEntry *BaseTE = getTreeEntry(V)) { + auto It = MinBWs.find(BaseTE); + if (It != MinBWs.end()) { + IsSigned = IsSigned.value_or(false) || It->second.second; + if (*IsSigned) + break; + } + for (const TreeEntry *MNTE : MultiNodeScalars.lookup(V)) { + auto It = MinBWs.find(MNTE); + if (It != MinBWs.end()) { + IsSigned = IsSigned.value_or(false) || It->second.second; + if (*IsSigned) + break; + } + } + if (IsSigned.value_or(false)) + break; + // Scan through gather nodes. + for (const TreeEntry *BVE : ValueToGatherNodes.lookup(V)) { + auto It = MinBWs.find(BVE); + if (It != MinBWs.end()) { + IsSigned = IsSigned.value_or(false) || It->second.second; + if (*IsSigned) + break; + } + } + if (IsSigned.value_or(false)) + break; + if (auto *EE = dyn_cast<ExtractElementInst>(V)) { + IsSigned = + IsSigned.value_or(false) || + !isKnownNonNegative(EE->getVectorOperand(), SimplifyQuery(*DL)); + continue; + } + if (IsSigned.value_or(false)) + break; + } + } + if (IsSigned.value_or(false)) { + // Final attempt - check user node. + auto It = MinBWs.find(TE->UserTreeIndices.front().UserTE); + if (It != MinBWs.end()) + IsSigned = It->second.second; + } + assert(IsSigned && + "Expected user node or perfect diamond match in MinBWs."); + Vec = Builder.CreateIntCast(Vec, PrevVec->getType(), *IsSigned); + } PrevVec->replaceAllUsesWith(Vec); PostponedValues.try_emplace(Vec).first->second.push_back(TE); // Replace the stub vector node, if it was used before for one of the @@ -11920,9 +13836,11 @@ Value *BoUpSLP::vectorizeTree( DenseMap<Value *, InsertElementInst *> VectorToInsertElement; // Maps extract Scalar to the corresponding extractelement instruction in the // basic block. Only one extractelement per block should be emitted. - DenseMap<Value *, DenseMap<BasicBlock *, Instruction *>> ScalarToEEs; + DenseMap<Value *, + DenseMap<BasicBlock *, std::pair<Instruction *, Instruction *>>> + ScalarToEEs; SmallDenseSet<Value *, 4> UsedInserts; - DenseMap<Value *, Value *> VectorCasts; + DenseMap<std::pair<Value *, Type *>, Value *> VectorCasts; SmallDenseSet<Value *, 4> ScalarsWithNullptrUser; // Extract all of the elements with the external uses. for (const auto &ExternalUse : ExternalUses) { @@ -11935,8 +13853,7 @@ Value *BoUpSLP::vectorizeTree( continue; TreeEntry *E = getTreeEntry(Scalar); assert(E && "Invalid scalar"); - assert(E->State != TreeEntry::NeedToGather && - "Extracting from a gather list"); + assert(!E->isGather() && "Extracting from a gather list"); // Non-instruction pointers are not deleted, just skip them. if (E->getOpcode() == Instruction::GetElementPtr && !isa<GetElementPtrInst>(Scalar)) @@ -11949,18 +13866,25 @@ Value *BoUpSLP::vectorizeTree( auto ExtractAndExtendIfNeeded = [&](Value *Vec) { if (Scalar->getType() != Vec->getType()) { Value *Ex = nullptr; + Value *ExV = nullptr; + auto *GEP = dyn_cast<GetElementPtrInst>(Scalar); + bool ReplaceGEP = GEP && ExternalUsesAsGEPs.contains(GEP); auto It = ScalarToEEs.find(Scalar); if (It != ScalarToEEs.end()) { // No need to emit many extracts, just move the only one in the // current block. 
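ScalarToEEs above memoizes the extract (and its optional extending cast) per basic block, so each externally used scalar is extracted at most once in a given block; the lookup that follows reuses the cached instruction and merely moves it before the current insertion point when needed. A standalone sketch of the caching pattern, with strings standing in for the scalar and block keys (names invented for the example; plain C++, not the pass's actual code):

// --- illustrative sketch, not part of the diff ---
#include <cstdio>
#include <map>
#include <string>
#include <utility>

// Return a cached "extract" id for (Scalar, Block), creating one only on the
// first use within that block.
static int getExtract(const std::string &Scalar, const std::string &Block) {
  static std::map<std::pair<std::string, std::string>, int> Cache;
  static int NextId = 0;
  auto [It, Inserted] = Cache.try_emplace({Scalar, Block}, NextId);
  if (Inserted)
    ++NextId; // this is where a real extractelement would be emitted
  return It->second;
}

int main() {
  std::printf("%d\n", getExtract("%x", "entry")); // 0: emitted
  std::printf("%d\n", getExtract("%x", "entry")); // 0: reused in the same block
  std::printf("%d\n", getExtract("%x", "loop"));  // 1: new block, emitted again
  return 0;
}
// --- end of illustrative sketch ---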
auto EEIt = It->second.find(Builder.GetInsertBlock()); if (EEIt != It->second.end()) { - Instruction *I = EEIt->second; + Instruction *I = EEIt->second.first; if (Builder.GetInsertPoint() != Builder.GetInsertBlock()->end() && - Builder.GetInsertPoint()->comesBefore(I)) + Builder.GetInsertPoint()->comesBefore(I)) { I->moveBefore(*Builder.GetInsertPoint()->getParent(), Builder.GetInsertPoint()); + if (auto *CI = EEIt->second.second) + CI->moveAfter(I); + } Ex = I; + ExV = EEIt->second.second ? EEIt->second.second : Ex; } } if (!Ex) { @@ -11970,11 +13894,31 @@ Value *BoUpSLP::vectorizeTree( if (const TreeEntry *ETE = getTreeEntry(V)) V = ETE->VectorizedValue; Ex = Builder.CreateExtractElement(V, ES->getIndexOperand()); + } else if (ReplaceGEP) { + // Leave the GEPs as is, they are free in most cases and better to + // keep them as GEPs. + auto *CloneGEP = GEP->clone(); + if (isa<Instruction>(Vec)) + CloneGEP->insertBefore(*Builder.GetInsertBlock(), + Builder.GetInsertPoint()); + else + CloneGEP->insertBefore(GEP); + if (GEP->hasName()) + CloneGEP->takeName(GEP); + Ex = CloneGEP; } else { Ex = Builder.CreateExtractElement(Vec, Lane); } + // If necessary, sign-extend or zero-extend ScalarRoot + // to the larger type. + ExV = Ex; + if (Scalar->getType() != Ex->getType()) + ExV = Builder.CreateIntCast(Ex, Scalar->getType(), + MinBWs.find(E)->second.second); if (auto *I = dyn_cast<Instruction>(Ex)) - ScalarToEEs[Scalar].try_emplace(Builder.GetInsertBlock(), I); + ScalarToEEs[Scalar].try_emplace( + Builder.GetInsertBlock(), + std::make_pair(I, cast<Instruction>(ExV))); } // The then branch of the previous if may produce constants, since 0 // operand might be a constant. @@ -11982,12 +13926,7 @@ Value *BoUpSLP::vectorizeTree( GatherShuffleExtractSeq.insert(ExI); CSEBlocks.insert(ExI->getParent()); } - // If necessary, sign-extend or zero-extend ScalarRoot - // to the larger type. - if (Scalar->getType() != Ex->getType()) - return Builder.CreateIntCast(Ex, Scalar->getType(), - MinBWs.find(E)->second.second); - return Ex; + return ExV; } assert(isa<FixedVectorType>(Scalar->getType()) && isa<InsertElementInst>(Scalar) && @@ -12003,12 +13942,18 @@ Value *BoUpSLP::vectorizeTree( if (!ScalarsWithNullptrUser.insert(Scalar).second) continue; assert((ExternallyUsedValues.count(Scalar) || + Scalar->hasNUsesOrMore(UsesLimit) || any_of(Scalar->users(), [&](llvm::User *U) { + if (ExternalUsesAsGEPs.contains(U)) + return true; TreeEntry *UseEntry = getTreeEntry(U); return UseEntry && - UseEntry->State == TreeEntry::Vectorize && - E->State == TreeEntry::Vectorize && + (UseEntry->State == TreeEntry::Vectorize || + UseEntry->State == + TreeEntry::StridedVectorize) && + (E->State == TreeEntry::Vectorize || + E->State == TreeEntry::StridedVectorize) && doesInTreeUserNeedToExtract( Scalar, cast<Instruction>(UseEntry->Scalars.front()), @@ -12034,7 +13979,8 @@ Value *BoUpSLP::vectorizeTree( continue; } - if (auto *VU = dyn_cast<InsertElementInst>(User)) { + if (auto *VU = dyn_cast<InsertElementInst>(User); + VU && VU->getOperand(1) == Scalar) { // Skip if the scalar is another vector op or Vec is not an instruction. if (!Scalar->getType()->isVectorTy() && isa<Instruction>(Vec)) { if (auto *FTy = dyn_cast<FixedVectorType>(User->getType())) { @@ -12043,24 +13989,29 @@ Value *BoUpSLP::vectorizeTree( // Need to use original vector, if the root is truncated. 
auto BWIt = MinBWs.find(E); if (BWIt != MinBWs.end() && Vec->getType() != VU->getType()) { - auto VecIt = VectorCasts.find(Scalar); + auto *ScalarTy = FTy->getElementType(); + auto Key = std::make_pair(Vec, ScalarTy); + auto VecIt = VectorCasts.find(Key); if (VecIt == VectorCasts.end()) { - IRBuilder<>::InsertPointGuard Guard(Builder); - if (auto *IVec = dyn_cast<Instruction>(Vec)) + IRBuilderBase::InsertPointGuard Guard(Builder); + if (auto *IVec = dyn_cast<PHINode>(Vec)) + Builder.SetInsertPoint( + IVec->getParent()->getFirstNonPHIOrDbgOrLifetime()); + else if (auto *IVec = dyn_cast<Instruction>(Vec)) Builder.SetInsertPoint(IVec->getNextNonDebugInstruction()); Vec = Builder.CreateIntCast( Vec, - FixedVectorType::get( - cast<VectorType>(VU->getType())->getElementType(), + getWidenedType( + ScalarTy, cast<FixedVectorType>(Vec->getType())->getNumElements()), BWIt->second.second); - VectorCasts.try_emplace(Scalar, Vec); + VectorCasts.try_emplace(Key, Vec); } else { Vec = VecIt->second; } } - std::optional<unsigned> InsertIdx = getInsertIndex(VU); + std::optional<unsigned> InsertIdx = getElementIndex(VU); if (InsertIdx) { auto *It = find_if(ShuffledInserts, [VU](const ShuffledInsertData &Data) { @@ -12083,14 +14034,14 @@ Value *BoUpSLP::vectorizeTree( while (auto *IEBase = dyn_cast<InsertElementInst>(Base)) { if (IEBase != User && (!IEBase->hasOneUse() || - getInsertIndex(IEBase).value_or(Idx) == Idx)) + getElementIndex(IEBase).value_or(Idx) == Idx)) break; // Build the mask for the vectorized insertelement instructions. if (const TreeEntry *E = getTreeEntry(IEBase)) { do { IEBase = cast<InsertElementInst>(Base); - int IEIdx = *getInsertIndex(IEBase); - assert(Mask[Idx] == PoisonMaskElem && + int IEIdx = *getElementIndex(IEBase); + assert(Mask[IEIdx] == PoisonMaskElem && "InsertElementInstruction used already."); Mask[IEIdx] = IEIdx; Base = IEBase->getOperand(0); @@ -12159,7 +14110,8 @@ Value *BoUpSLP::vectorizeTree( else CombinedMask2[I] = Mask[I] - VF; } - ShuffleInstructionBuilder ShuffleBuilder(Builder, *this); + ShuffleInstructionBuilder ShuffleBuilder( + cast<VectorType>(V1->getType())->getElementType(), Builder, *this); ShuffleBuilder.add(V1, CombinedMask1); if (V2) ShuffleBuilder.add(V2, CombinedMask2); @@ -12260,7 +14212,7 @@ Value *BoUpSLP::vectorizeTree( TreeEntry *Entry = TEPtr.get(); // No need to handle users of gathered values. - if (Entry->State == TreeEntry::NeedToGather) + if (Entry->isGather()) continue; assert(Entry->VectorizedValue && "Can't find vectorizable value"); @@ -12288,11 +14240,8 @@ Value *BoUpSLP::vectorizeTree( } #endif LLVM_DEBUG(dbgs() << "SLP: \tErasing scalar:" << *Scalar << ".\n"); - eraseInstruction(cast<Instruction>(Scalar)); - // Retain to-be-deleted instructions for some debug-info - // bookkeeping. NOTE: eraseInstruction only marks the instruction for - // deletion - instructions are not deleted until later. - RemovedInsts.push_back(cast<Instruction>(Scalar)); + auto *I = cast<Instruction>(Scalar); + RemovedInsts.push_back(I); } } @@ -12301,10 +14250,54 @@ Value *BoUpSLP::vectorizeTree( if (auto *V = dyn_cast<Instruction>(VectorizableTree[0]->VectorizedValue)) V->mergeDIAssignID(RemovedInsts); + // Clear up reduction references, if any. + if (UserIgnoreList) { + for (Instruction *I : RemovedInsts) { + if (getTreeEntry(I)->Idx != 0) + continue; + SmallVector<SelectInst *> LogicalOpSelects; + I->replaceUsesWithIf(PoisonValue::get(I->getType()), [&](Use &U) { + // Do not replace condition of the logical op in form select <cond>. 
+ bool IsPoisoningLogicalOp = isa<SelectInst>(U.getUser()) && + (match(U.getUser(), m_LogicalAnd()) || + match(U.getUser(), m_LogicalOr())) && + U.getOperandNo() == 0; + if (IsPoisoningLogicalOp) { + LogicalOpSelects.push_back(cast<SelectInst>(U.getUser())); + return false; + } + return UserIgnoreList->contains(U.getUser()); + }); + // Replace conditions of the poisoning logical ops with the non-poison + // constant value. + for (SelectInst *SI : LogicalOpSelects) + SI->setCondition(Constant::getNullValue(SI->getCondition()->getType())); + } + } + // Retain to-be-deleted instructions for some debug-info bookkeeping and alias + // cache correctness. + // NOTE: removeInstructionAndOperands only marks the instruction for deletion + // - instructions are not deleted until later. + removeInstructionsAndOperands(ArrayRef(RemovedInsts)); + Builder.ClearInsertionPoint(); InstrElementSize.clear(); - return VectorizableTree[0]->VectorizedValue; + const TreeEntry &RootTE = *VectorizableTree.front(); + Value *Vec = RootTE.VectorizedValue; + if (auto It = MinBWs.find(&RootTE); ReductionBitWidth != 0 && + It != MinBWs.end() && + ReductionBitWidth != It->second.first) { + IRBuilder<>::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(ReductionRoot->getParent(), + ReductionRoot->getIterator()); + Vec = Builder.CreateIntCast( + Vec, + VectorType::get(Builder.getIntNTy(ReductionBitWidth), + cast<VectorType>(Vec->getType())->getElementCount()), + It->second.second); + } + return Vec; } void BoUpSLP::optimizeGatherSequence() { @@ -12396,8 +14389,8 @@ void BoUpSLP::optimizeGatherSequence() { return SM1.size() - LastUndefsCnt > 1 && TTI->getNumberOfParts(SI1->getType()) == TTI->getNumberOfParts( - FixedVectorType::get(SI1->getType()->getElementType(), - SM1.size() - LastUndefsCnt)); + getWidenedType(SI1->getType()->getElementType(), + SM1.size() - LastUndefsCnt)); }; // Perform O(N^2) search over the gather/shuffle sequences and merge identical // instructions. TODO: We can further optimize this scan if we split the @@ -13063,26 +15056,29 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) { // that feed it. The type of the loaded value may indicate a more suitable // width than V's type. We want to base the vector element size on the width // of memory operations where possible. - SmallVector<std::pair<Instruction *, BasicBlock *>, 16> Worklist; + SmallVector<std::tuple<Instruction *, BasicBlock *, unsigned>> Worklist; SmallPtrSet<Instruction *, 16> Visited; if (auto *I = dyn_cast<Instruction>(V)) { - Worklist.emplace_back(I, I->getParent()); + Worklist.emplace_back(I, I->getParent(), 0); Visited.insert(I); } // Traverse the expression tree in bottom-up order looking for loads. If we // encounter an instruction we don't yet handle, we give up. auto Width = 0u; + Value *FirstNonBool = nullptr; while (!Worklist.empty()) { - Instruction *I; - BasicBlock *Parent; - std::tie(I, Parent) = Worklist.pop_back_val(); + auto [I, Parent, Level] = Worklist.pop_back_val(); // We should only be looking at scalar instructions here. If the current // instruction has a vector type, skip. auto *Ty = I->getType(); if (isa<VectorType>(Ty)) continue; + if (Ty != Builder.getInt1Ty() && !FirstNonBool) + FirstNonBool = I; + if (Level > RecursionMaxDepth) + continue; // If the current instruction is a load, update MaxWidth to reflect the // width of the loaded value. @@ -13095,11 +15091,16 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) { // user or the use is a PHI node, we add it to the worklist. 
else if (isa<PHINode, CastInst, GetElementPtrInst, CmpInst, SelectInst, BinaryOperator, UnaryOperator>(I)) { - for (Use &U : I->operands()) + for (Use &U : I->operands()) { if (auto *J = dyn_cast<Instruction>(U.get())) if (Visited.insert(J).second && - (isa<PHINode>(I) || J->getParent() == Parent)) - Worklist.emplace_back(J, J->getParent()); + (isa<PHINode>(I) || J->getParent() == Parent)) { + Worklist.emplace_back(J, J->getParent(), Level + 1); + continue; + } + if (!FirstNonBool && U.get()->getType() != Builder.getInt1Ty()) + FirstNonBool = U.get(); + } } else { break; } @@ -13109,8 +15110,8 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) { // gave up for some reason, just return the width of V. Otherwise, return the // maximum width we found. if (!Width) { - if (auto *CI = dyn_cast<CmpInst>(V)) - V = CI->getOperand(0); + if (V->getType() == Builder.getInt1Ty() && FirstNonBool) + V = FirstNonBool; Width = DL->getTypeSizeInBits(V->getType()); } @@ -13120,41 +15121,192 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) { return Width; } -// Determine if a value V in a vectorizable expression Expr can be demoted to a -// smaller type with a truncation. We collect the values that will be demoted -// in ToDemote and additional roots that require investigating in Roots. bool BoUpSLP::collectValuesToDemote( - Value *V, SmallVectorImpl<Value *> &ToDemote, - DenseMap<Instruction *, SmallVector<unsigned>> &DemotedConsts, - SmallVectorImpl<Value *> &Roots, DenseSet<Value *> &Visited) const { + const TreeEntry &E, bool IsProfitableToDemoteRoot, unsigned &BitWidth, + SmallVectorImpl<unsigned> &ToDemote, DenseSet<const TreeEntry *> &Visited, + unsigned &MaxDepthLevel, bool &IsProfitableToDemote, + bool IsTruncRoot) const { // We can always demote constants. - if (isa<Constant>(V)) + if (all_of(E.Scalars, IsaPred<Constant>)) return true; + unsigned OrigBitWidth = DL->getTypeSizeInBits(E.Scalars.front()->getType()); + if (OrigBitWidth == BitWidth) { + MaxDepthLevel = 1; + return true; + } + // If the value is not a vectorized instruction in the expression and not used // by the insertelement instruction and not used in multiple vector nodes, it // cannot be demoted. - auto *I = dyn_cast<Instruction>(V); - if (!I || !getTreeEntry(I) || MultiNodeScalars.contains(I) || - !Visited.insert(I).second || all_of(I->users(), [&](User *U) { - return isa<InsertElementInst>(U) && !getTreeEntry(U); + bool IsSignedNode = any_of(E.Scalars, [&](Value *R) { + return !isKnownNonNegative(R, SimplifyQuery(*DL)); + }); + auto IsPotentiallyTruncated = [&](Value *V, unsigned &BitWidth) -> bool { + if (MultiNodeScalars.contains(V)) + return false; + // For lat shuffle of sext/zext with many uses need to check the extra bit + // for unsigned values, otherwise may have incorrect casting for reused + // scalars. 
+ bool IsSignedVal = !isKnownNonNegative(V, SimplifyQuery(*DL)); + if ((!IsSignedNode || IsSignedVal) && OrigBitWidth > BitWidth) { + APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); + if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL))) + return true; + } + unsigned NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT); + unsigned BitWidth1 = OrigBitWidth - NumSignBits; + if (IsSignedNode) + ++BitWidth1; + if (auto *I = dyn_cast<Instruction>(V)) { + APInt Mask = DB->getDemandedBits(I); + unsigned BitWidth2 = + std::max<unsigned>(1, Mask.getBitWidth() - Mask.countl_zero()); + while (!IsSignedNode && BitWidth2 < OrigBitWidth) { + APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth2 - 1); + if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL))) + break; + BitWidth2 *= 2; + } + BitWidth1 = std::min(BitWidth1, BitWidth2); + } + BitWidth = std::max(BitWidth, BitWidth1); + return BitWidth > 0 && OrigBitWidth >= (BitWidth * 2); + }; + using namespace std::placeholders; + auto FinalAnalysis = [&]() { + if (!IsProfitableToDemote) + return false; + bool Res = all_of( + E.Scalars, std::bind(IsPotentiallyTruncated, _1, std::ref(BitWidth))); + // Demote gathers. + if (Res && E.isGather()) { + // Check possible extractelement instructions bases and final vector + // length. + SmallPtrSet<Value *, 4> UniqueBases; + for (Value *V : E.Scalars) { + auto *EE = dyn_cast<ExtractElementInst>(V); + if (!EE) + continue; + UniqueBases.insert(EE->getVectorOperand()); + } + const unsigned VF = E.Scalars.size(); + Type *OrigScalarTy = E.Scalars.front()->getType(); + if (UniqueBases.size() <= 2 || + TTI->getNumberOfParts(getWidenedType(OrigScalarTy, VF)) == + TTI->getNumberOfParts(getWidenedType( + IntegerType::get(OrigScalarTy->getContext(), BitWidth), VF))) + ToDemote.push_back(E.Idx); + } + return Res; + }; + if (E.isGather() || !Visited.insert(&E).second || + any_of(E.Scalars, [&](Value *V) { + return all_of(V->users(), [&](User *U) { + return isa<InsertElementInst>(U) && !getTreeEntry(U); + }); + })) + return FinalAnalysis(); + + if (any_of(E.Scalars, [&](Value *V) { + return !all_of(V->users(), [=](User *U) { + return getTreeEntry(U) || + (E.Idx == 0 && UserIgnoreList && + UserIgnoreList->contains(U)) || + (!isa<CmpInst>(U) && U->getType()->isSized() && + !U->getType()->isScalableTy() && + DL->getTypeSizeInBits(U->getType()) <= BitWidth); + }) && !IsPotentiallyTruncated(V, BitWidth); })) return false; - unsigned Start = 0; - unsigned End = I->getNumOperands(); - switch (I->getOpcode()) { + auto ProcessOperands = [&](ArrayRef<const TreeEntry *> Operands, + bool &NeedToExit) { + NeedToExit = false; + unsigned InitLevel = MaxDepthLevel; + for (const TreeEntry *Op : Operands) { + unsigned Level = InitLevel; + if (!collectValuesToDemote(*Op, IsProfitableToDemoteRoot, BitWidth, + ToDemote, Visited, Level, IsProfitableToDemote, + IsTruncRoot)) { + if (!IsProfitableToDemote) + return false; + NeedToExit = true; + if (!FinalAnalysis()) + return false; + continue; + } + MaxDepthLevel = std::max(MaxDepthLevel, Level); + } + return true; + }; + auto AttemptCheckBitwidth = + [&](function_ref<bool(unsigned, unsigned)> Checker, bool &NeedToExit) { + // Try all bitwidth < OrigBitWidth. 
+ NeedToExit = false; + unsigned BestFailBitwidth = 0; + for (; BitWidth < OrigBitWidth; BitWidth *= 2) { + if (Checker(BitWidth, OrigBitWidth)) + return true; + if (BestFailBitwidth == 0 && FinalAnalysis()) + BestFailBitwidth = BitWidth; + } + if (BitWidth >= OrigBitWidth) { + if (BestFailBitwidth == 0) { + BitWidth = OrigBitWidth; + return false; + } + MaxDepthLevel = 1; + BitWidth = BestFailBitwidth; + NeedToExit = true; + return true; + } + return false; + }; + auto TryProcessInstruction = + [&](unsigned &BitWidth, + ArrayRef<const TreeEntry *> Operands = std::nullopt, + function_ref<bool(unsigned, unsigned)> Checker = {}) { + if (Operands.empty()) { + if (!IsTruncRoot) + MaxDepthLevel = 1; + (void)for_each(E.Scalars, std::bind(IsPotentiallyTruncated, _1, + std::ref(BitWidth))); + } else { + // Several vectorized uses? Check if we can truncate it, otherwise - + // exit. + if (E.UserTreeIndices.size() > 1 && + !all_of(E.Scalars, std::bind(IsPotentiallyTruncated, _1, + std::ref(BitWidth)))) + return false; + bool NeedToExit = false; + if (Checker && !AttemptCheckBitwidth(Checker, NeedToExit)) + return false; + if (NeedToExit) + return true; + if (!ProcessOperands(Operands, NeedToExit)) + return false; + if (NeedToExit) + return true; + } + + ++MaxDepthLevel; + // Record the entry that we can demote. + ToDemote.push_back(E.Idx); + return IsProfitableToDemote; + }; + switch (E.getOpcode()) { // We can always demote truncations and extensions. Since truncations can // seed additional demotion, we save the truncated value. case Instruction::Trunc: - Roots.push_back(I->getOperand(0)); - break; + if (IsProfitableToDemoteRoot) + IsProfitableToDemote = true; + return TryProcessInstruction(BitWidth); case Instruction::ZExt: case Instruction::SExt: - if (isa<ExtractElementInst, InsertElementInst>(I->getOperand(0))) - return false; - break; + IsProfitableToDemote = true; + return TryProcessInstruction(BitWidth); // We can demote certain binary operations if we can demote both of their // operands. @@ -13163,184 +15315,491 @@ bool BoUpSLP::collectValuesToDemote( case Instruction::Mul: case Instruction::And: case Instruction::Or: - case Instruction::Xor: - if (!collectValuesToDemote(I->getOperand(0), ToDemote, DemotedConsts, Roots, - Visited) || - !collectValuesToDemote(I->getOperand(1), ToDemote, DemotedConsts, Roots, - Visited)) - return false; - break; + case Instruction::Xor: { + return TryProcessInstruction( + BitWidth, {getOperandEntry(&E, 0), getOperandEntry(&E, 1)}); + } + case Instruction::Shl: { + // If we are truncating the result of this SHL, and if it's a shift of an + // inrange amount, we can always perform a SHL in a smaller type. + auto ShlChecker = [&](unsigned BitWidth, unsigned) { + return all_of(E.Scalars, [&](Value *V) { + auto *I = cast<Instruction>(V); + KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL); + return AmtKnownBits.getMaxValue().ult(BitWidth); + }); + }; + return TryProcessInstruction( + BitWidth, {getOperandEntry(&E, 0), getOperandEntry(&E, 1)}, ShlChecker); + } + case Instruction::LShr: { + // If this is a truncate of a logical shr, we can truncate it to a smaller + // lshr iff we know that the bits we would otherwise be shifting in are + // already zeros. 
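The LShrChecker that follows implements exactly the rule stated in the comment above, over all scalars of the node: the shift amount must be known to be smaller than the narrow width, and the bits a narrow lshr would shift in must already be zero in the operand. As a standalone illustration of why that condition is sufficient, assuming a 32-bit value truncated to 16 bits (plain C++; the widths and sample value are invented for the example):

// --- illustrative sketch, not part of the diff ---
#include <cstdint>
#include <cstdio>

int main() {
  uint32_t X = 0x00001234; // the upper 16 bits are already zero
  unsigned Amt = 3;        // known to be < 16
  uint16_t ShiftThenTrunc = uint16_t(X >> Amt);           // original order
  uint16_t TruncThenShift = uint16_t(uint16_t(X) >> Amt); // demoted 16-bit lshr
  std::printf("0x%x 0x%x equal=%d\n", unsigned(ShiftThenTrunc),
              unsigned(TruncThenShift), int(ShiftThenTrunc == TruncThenShift));
  return 0;
}
// --- end of illustrative sketch ---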
+ auto LShrChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) { + return all_of(E.Scalars, [&](Value *V) { + auto *I = cast<Instruction>(V); + KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL); + APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); + return AmtKnownBits.getMaxValue().ult(BitWidth) && + MaskedValueIsZero(I->getOperand(0), ShiftedBits, + SimplifyQuery(*DL)); + }); + }; + return TryProcessInstruction( + BitWidth, {getOperandEntry(&E, 0), getOperandEntry(&E, 1)}, + LShrChecker); + } + case Instruction::AShr: { + // If this is a truncate of an arithmetic shr, we can truncate it to a + // smaller ashr iff we know that all the bits from the sign bit of the + // original type and the sign bit of the truncate type are similar. + auto AShrChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) { + return all_of(E.Scalars, [&](Value *V) { + auto *I = cast<Instruction>(V); + KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL); + unsigned ShiftedBits = OrigBitWidth - BitWidth; + return AmtKnownBits.getMaxValue().ult(BitWidth) && + ShiftedBits < ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, + nullptr, DT); + }); + }; + return TryProcessInstruction( + BitWidth, {getOperandEntry(&E, 0), getOperandEntry(&E, 1)}, + AShrChecker); + } + case Instruction::UDiv: + case Instruction::URem: { + // UDiv and URem can be truncated if all the truncated bits are zero. + auto Checker = [&](unsigned BitWidth, unsigned OrigBitWidth) { + assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!"); + return all_of(E.Scalars, [&](Value *V) { + auto *I = cast<Instruction>(V); + APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); + return MaskedValueIsZero(I->getOperand(0), Mask, SimplifyQuery(*DL)) && + MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL)); + }); + }; + return TryProcessInstruction( + BitWidth, {getOperandEntry(&E, 0), getOperandEntry(&E, 1)}, Checker); + } // We can demote selects if we can demote their true and false values. case Instruction::Select: { - Start = 1; - SelectInst *SI = cast<SelectInst>(I); - if (!collectValuesToDemote(SI->getTrueValue(), ToDemote, DemotedConsts, - Roots, Visited) || - !collectValuesToDemote(SI->getFalseValue(), ToDemote, DemotedConsts, - Roots, Visited)) - return false; - break; + return TryProcessInstruction( + BitWidth, {getOperandEntry(&E, 1), getOperandEntry(&E, 2)}); } // We can demote phis if we can demote all their incoming operands. Note that // we don't need to worry about cycles since we ensure single use above. 
case Instruction::PHI: { - PHINode *PN = cast<PHINode>(I); - for (Value *IncValue : PN->incoming_values()) - if (!collectValuesToDemote(IncValue, ToDemote, DemotedConsts, Roots, - Visited)) - return false; - break; + const unsigned NumOps = E.getNumOperands(); + SmallVector<const TreeEntry *> Ops(NumOps); + transform(seq<unsigned>(0, NumOps), Ops.begin(), + std::bind(&BoUpSLP::getOperandEntry, this, &E, _1)); + + return TryProcessInstruction(BitWidth, Ops); + } + + case Instruction::Call: { + auto *IC = dyn_cast<IntrinsicInst>(E.getMainOp()); + if (!IC) + break; + Intrinsic::ID ID = getVectorIntrinsicIDForCall(IC, TLI); + if (ID != Intrinsic::abs && ID != Intrinsic::smin && + ID != Intrinsic::smax && ID != Intrinsic::umin && ID != Intrinsic::umax) + break; + SmallVector<const TreeEntry *, 2> Operands(1, getOperandEntry(&E, 0)); + function_ref<bool(unsigned, unsigned)> CallChecker; + auto CompChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) { + assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!"); + return all_of(E.Scalars, [&](Value *V) { + auto *I = cast<Instruction>(V); + if (ID == Intrinsic::umin || ID == Intrinsic::umax) { + APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); + return MaskedValueIsZero(I->getOperand(0), Mask, + SimplifyQuery(*DL)) && + MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL)); + } + assert((ID == Intrinsic::smin || ID == Intrinsic::smax) && + "Expected min/max intrinsics only."); + unsigned SignBits = OrigBitWidth - BitWidth; + APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth - 1); + unsigned Op0SignBits = ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, + nullptr, DT); + unsigned Op1SignBits = ComputeNumSignBits(I->getOperand(1), *DL, 0, AC, + nullptr, DT); + return SignBits <= Op0SignBits && + ((SignBits != Op0SignBits && + !isKnownNonNegative(I->getOperand(0), SimplifyQuery(*DL))) || + MaskedValueIsZero(I->getOperand(0), Mask, + SimplifyQuery(*DL))) && + SignBits <= Op1SignBits && + ((SignBits != Op1SignBits && + !isKnownNonNegative(I->getOperand(1), SimplifyQuery(*DL))) || + MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL))); + }); + }; + auto AbsChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) { + assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!"); + return all_of(E.Scalars, [&](Value *V) { + auto *I = cast<Instruction>(V); + unsigned SignBits = OrigBitWidth - BitWidth; + APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth - 1); + unsigned Op0SignBits = + ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, nullptr, DT); + return SignBits <= Op0SignBits && + ((SignBits != Op0SignBits && + !isKnownNonNegative(I->getOperand(0), SimplifyQuery(*DL))) || + MaskedValueIsZero(I->getOperand(0), Mask, SimplifyQuery(*DL))); + }); + }; + if (ID != Intrinsic::abs) { + Operands.push_back(getOperandEntry(&E, 1)); + CallChecker = CompChecker; + } else { + CallChecker = AbsChecker; + } + InstructionCost BestCost = + std::numeric_limits<InstructionCost::CostType>::max(); + unsigned BestBitWidth = BitWidth; + unsigned VF = E.Scalars.size(); + // Choose the best bitwidth based on cost estimations. 
+ auto Checker = [&](unsigned BitWidth, unsigned) { + unsigned MinBW = PowerOf2Ceil(BitWidth); + SmallVector<Type *> ArgTys = buildIntrinsicArgTypes(IC, ID, VF, MinBW); + auto VecCallCosts = getVectorCallCosts( + IC, getWidenedType(IntegerType::get(IC->getContext(), MinBW), VF), + TTI, TLI, ArgTys); + InstructionCost Cost = std::min(VecCallCosts.first, VecCallCosts.second); + if (Cost < BestCost) { + BestCost = Cost; + BestBitWidth = BitWidth; + } + return false; + }; + [[maybe_unused]] bool NeedToExit; + (void)AttemptCheckBitwidth(Checker, NeedToExit); + BitWidth = BestBitWidth; + return TryProcessInstruction(BitWidth, Operands, CallChecker); } // Otherwise, conservatively give up. default: - return false; + break; } - - // Gather demoted constant operands. - for (unsigned Idx : seq<unsigned>(Start, End)) - if (isa<Constant>(I->getOperand(Idx))) - DemotedConsts.try_emplace(I).first->getSecond().push_back(Idx); - // Record the value that we can demote. - ToDemote.push_back(V); - return true; + MaxDepthLevel = 1; + return FinalAnalysis(); } +static RecurKind getRdxKind(Value *V); + void BoUpSLP::computeMinimumValueSizes() { // We only attempt to truncate integer expressions. - auto &TreeRoot = VectorizableTree[0]->Scalars; - auto *TreeRootIT = dyn_cast<IntegerType>(TreeRoot[0]->getType()); - if (!TreeRootIT || VectorizableTree.front()->State == TreeEntry::NeedToGather) + bool IsStoreOrInsertElt = + VectorizableTree.front()->getOpcode() == Instruction::Store || + VectorizableTree.front()->getOpcode() == Instruction::InsertElement; + if ((IsStoreOrInsertElt || UserIgnoreList) && + ExtraBitWidthNodes.size() <= 1 && + (!CastMaxMinBWSizes || CastMaxMinBWSizes->second == 0 || + CastMaxMinBWSizes->first / CastMaxMinBWSizes->second <= 2)) return; + unsigned NodeIdx = 0; + if (IsStoreOrInsertElt && !VectorizableTree.front()->isGather()) + NodeIdx = 1; + // Ensure the roots of the vectorizable tree don't form a cycle. - if (!VectorizableTree.front()->UserTreeIndices.empty()) + if (VectorizableTree[NodeIdx]->isGather() || + (NodeIdx == 0 && !VectorizableTree[NodeIdx]->UserTreeIndices.empty()) || + (NodeIdx != 0 && any_of(VectorizableTree[NodeIdx]->UserTreeIndices, + [NodeIdx](const EdgeInfo &EI) { + return EI.UserTE->Idx > + static_cast<int>(NodeIdx); + }))) return; - // Conservatively determine if we can actually truncate the roots of the - // expression. Collect the values that can be demoted in ToDemote and - // additional roots that require investigating in Roots. - SmallVector<Value *, 32> ToDemote; - DenseMap<Instruction *, SmallVector<unsigned>> DemotedConsts; - SmallVector<Value *, 4> Roots; - for (auto *Root : TreeRoot) { - DenseSet<Value *> Visited; - if (!collectValuesToDemote(Root, ToDemote, DemotedConsts, Roots, Visited)) - return; - } - - // The maximum bit width required to represent all the values that can be - // demoted without loss of precision. It would be safe to truncate the roots - // of the expression to this width. - auto MaxBitWidth = 1u; - - // We first check if all the bits of the roots are demanded. If they're not, - // we can truncate the roots to this narrower type. - for (auto *Root : TreeRoot) { - auto Mask = DB->getDemandedBits(cast<Instruction>(Root)); - MaxBitWidth = std::max<unsigned>(Mask.getBitWidth() - Mask.countl_zero(), - MaxBitWidth); - } - - // True if the roots can be zero-extended back to their original type, rather - // than sign-extended. We know that if the leading bits are not demanded, we - // can safely zero-extend. 
So we initialize IsKnownPositive to True. - bool IsKnownPositive = true; - - // If all the bits of the roots are demanded, we can try a little harder to - // compute a narrower type. This can happen, for example, if the roots are - // getelementptr indices. InstCombine promotes these indices to the pointer - // width. Thus, all their bits are technically demanded even though the - // address computation might be vectorized in a smaller type. - // - // We start by looking at each entry that can be demoted. We compute the - // maximum bit width required to store the scalar by using ValueTracking to - // compute the number of high-order bits we can truncate. - if (MaxBitWidth == DL->getTypeSizeInBits(TreeRoot[0]->getType()) && - all_of(TreeRoot, [](Value *V) { - return all_of(V->users(), - [](User *U) { return isa<GetElementPtrInst>(U); }); - })) { - MaxBitWidth = 8u; + // The first value node for store/insertelement is sext/zext/trunc? Skip it, + // resize to the final type. + bool IsTruncRoot = false; + bool IsProfitableToDemoteRoot = !IsStoreOrInsertElt; + SmallVector<unsigned> RootDemotes; + if (NodeIdx != 0 && + VectorizableTree[NodeIdx]->State == TreeEntry::Vectorize && + VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc) { + assert(IsStoreOrInsertElt && "Expected store/insertelement seeded graph."); + IsTruncRoot = true; + RootDemotes.push_back(NodeIdx); + IsProfitableToDemoteRoot = true; + ++NodeIdx; + } + + // Analyzed the reduction already and not profitable - exit. + if (AnalyzedMinBWVals.contains(VectorizableTree[NodeIdx]->Scalars.front())) + return; + SmallVector<unsigned> ToDemote; + auto ComputeMaxBitWidth = [&](const TreeEntry &E, bool IsTopRoot, + bool IsProfitableToDemoteRoot, unsigned Opcode, + unsigned Limit, bool IsTruncRoot, + bool IsSignedCmp) -> unsigned { + ToDemote.clear(); + // Check if the root is trunc and the next node is gather/buildvector, then + // keep trunc in scalars, which is free in most cases. + if (E.isGather() && IsTruncRoot && E.UserTreeIndices.size() == 1 && + E.Idx > (IsStoreOrInsertElt ? 
2 : 1) && + all_of(E.Scalars, [&](Value *V) { + return V->hasOneUse() || isa<Constant>(V) || + (!V->hasNUsesOrMore(UsesLimit) && + none_of(V->users(), [&](User *U) { + const TreeEntry *TE = getTreeEntry(U); + const TreeEntry *UserTE = E.UserTreeIndices.back().UserTE; + if (TE == UserTE || !TE) + return false; + if (!isa<CastInst, BinaryOperator, FreezeInst, PHINode, + SelectInst>(U) || + !isa<CastInst, BinaryOperator, FreezeInst, PHINode, + SelectInst>(UserTE->getMainOp())) + return true; + unsigned UserTESz = DL->getTypeSizeInBits( + UserTE->Scalars.front()->getType()); + auto It = MinBWs.find(TE); + if (It != MinBWs.end() && It->second.first > UserTESz) + return true; + return DL->getTypeSizeInBits(U->getType()) > UserTESz; + })); + })) { + ToDemote.push_back(E.Idx); + const TreeEntry *UserTE = E.UserTreeIndices.back().UserTE; + auto It = MinBWs.find(UserTE); + if (It != MinBWs.end()) + return It->second.first; + unsigned MaxBitWidth = + DL->getTypeSizeInBits(UserTE->Scalars.front()->getType()); + MaxBitWidth = bit_ceil(MaxBitWidth); + if (MaxBitWidth < 8 && MaxBitWidth > 1) + MaxBitWidth = 8; + return MaxBitWidth; + } + + unsigned VF = E.getVectorFactor(); + auto *TreeRootIT = dyn_cast<IntegerType>(E.Scalars.front()->getType()); + if (!TreeRootIT || !Opcode) + return 0u; + + if (any_of(E.Scalars, + [&](Value *V) { return AnalyzedMinBWVals.contains(V); })) + return 0u; + + unsigned NumParts = TTI->getNumberOfParts(getWidenedType(TreeRootIT, VF)); + + // The maximum bit width required to represent all the values that can be + // demoted without loss of precision. It would be safe to truncate the roots + // of the expression to this width. + unsigned MaxBitWidth = 1u; + + // True if the roots can be zero-extended back to their original type, + // rather than sign-extended. We know that if the leading bits are not + // demanded, we can safely zero-extend. So we initialize IsKnownPositive to + // True. // Determine if the sign bit of all the roots is known to be zero. If not, // IsKnownPositive is set to False. - IsKnownPositive = llvm::all_of(TreeRoot, [&](Value *R) { + bool IsKnownPositive = !IsSignedCmp && all_of(E.Scalars, [&](Value *R) { KnownBits Known = computeKnownBits(R, *DL); return Known.isNonNegative(); }); - // Determine the maximum number of bits required to store the scalar - // values. - for (auto *Scalar : ToDemote) { - auto NumSignBits = ComputeNumSignBits(Scalar, *DL, 0, AC, nullptr, DT); - auto NumTypeBits = DL->getTypeSizeInBits(Scalar->getType()); - MaxBitWidth = std::max<unsigned>(NumTypeBits - NumSignBits, MaxBitWidth); - } - - // If we can't prove that the sign bit is zero, we must add one to the - // maximum bit width to account for the unknown sign bit. This preserves - // the existing sign bit so we can safely sign-extend the root back to the - // original type. Otherwise, if we know the sign bit is zero, we will - // zero-extend the root instead. - // - // FIXME: This is somewhat suboptimal, as there will be cases where adding - // one to the maximum bit width will yield a larger-than-necessary - // type. In general, we need to add an extra bit only if we can't - // prove that the upper bit of the original type is equal to the - // upper bit of the proposed smaller type. If these two bits are the - // same (either zero or one) we know that sign-extending from the - // smaller type will result in the same value. Here, since we can't - // yet prove this, we are just making the proposed smaller type - // larger to ensure correctness. 
- if (!IsKnownPositive) - ++MaxBitWidth; - } - - // Round MaxBitWidth up to the next power-of-two. - MaxBitWidth = llvm::bit_ceil(MaxBitWidth); - - // If the maximum bit width we compute is less than the with of the roots' - // type, we can proceed with the narrowing. Otherwise, do nothing. - if (MaxBitWidth >= TreeRootIT->getBitWidth()) - return; + // We first check if all the bits of the roots are demanded. If they're not, + // we can truncate the roots to this narrower type. + for (Value *Root : E.Scalars) { + unsigned NumSignBits = ComputeNumSignBits(Root, *DL, 0, AC, nullptr, DT); + TypeSize NumTypeBits = DL->getTypeSizeInBits(Root->getType()); + unsigned BitWidth1 = NumTypeBits - NumSignBits; + // If we can't prove that the sign bit is zero, we must add one to the + // maximum bit width to account for the unknown sign bit. This preserves + // the existing sign bit so we can safely sign-extend the root back to the + // original type. Otherwise, if we know the sign bit is zero, we will + // zero-extend the root instead. + // + // FIXME: This is somewhat suboptimal, as there will be cases where adding + // one to the maximum bit width will yield a larger-than-necessary + // type. In general, we need to add an extra bit only if we can't + // prove that the upper bit of the original type is equal to the + // upper bit of the proposed smaller type. If these two bits are + // the same (either zero or one) we know that sign-extending from + // the smaller type will result in the same value. Here, since we + // can't yet prove this, we are just making the proposed smaller + // type larger to ensure correctness. + if (!IsKnownPositive) + ++BitWidth1; + + APInt Mask = DB->getDemandedBits(cast<Instruction>(Root)); + unsigned BitWidth2 = Mask.getBitWidth() - Mask.countl_zero(); + MaxBitWidth = + std::max<unsigned>(std::min(BitWidth1, BitWidth2), MaxBitWidth); + } + + if (MaxBitWidth < 8 && MaxBitWidth > 1) + MaxBitWidth = 8; + + // If the original type is large, but reduced type does not improve the reg + // use - ignore it. + if (NumParts > 1 && + NumParts == + TTI->getNumberOfParts(getWidenedType( + IntegerType::get(F->getContext(), bit_ceil(MaxBitWidth)), VF))) + return 0u; + + bool IsProfitableToDemote = Opcode == Instruction::Trunc || + Opcode == Instruction::SExt || + Opcode == Instruction::ZExt || NumParts > 1; + // Conservatively determine if we can actually truncate the roots of the + // expression. Collect the values that can be demoted in ToDemote and + // additional roots that require investigating in Roots. + DenseSet<const TreeEntry *> Visited; + unsigned MaxDepthLevel = IsTruncRoot ? Limit : 1; + bool NeedToDemote = IsProfitableToDemote; + + if (!collectValuesToDemote(E, IsProfitableToDemoteRoot, MaxBitWidth, + ToDemote, Visited, MaxDepthLevel, NeedToDemote, + IsTruncRoot) || + (MaxDepthLevel <= Limit && + !(((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) && + (!IsTopRoot || !(IsStoreOrInsertElt || UserIgnoreList) || + DL->getTypeSizeInBits(TreeRootIT) / + DL->getTypeSizeInBits(cast<Instruction>(E.Scalars.front()) + ->getOperand(0) + ->getType()) > + 2))))) + return 0u; + // Round MaxBitWidth up to the next power-of-two. + MaxBitWidth = bit_ceil(MaxBitWidth); + + return MaxBitWidth; + }; // If we can truncate the root, we must collect additional values that might // be demoted as a result. That is, those seeded by truncations we will // modify. 
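The comment block above (and its replacement later in this hunk) describes how a minimum bit width is chosen per scalar: subtract the redundant sign bits from the type width, add one bit when the value is not known non-negative so sign-extension stays correct, clamp sub-byte results up to 8, and round to a power of two. The sketch below restates just that arithmetic for 32-bit constants; numSignBits and requiredWidth are illustrative stand-ins for ComputeNumSignBits and the per-root loop, not LLVM APIs, and the demanded-bits refinement is omitted.

#include <bit>
#include <cassert>
#include <cstdint>

// Redundant sign bits of a 32-bit constant: leading copies of the sign bit,
// counting the sign bit itself (what ComputeNumSignBits reports for it).
static unsigned numSignBits(int32_t V) {
  uint32_t U = static_cast<uint32_t>(V < 0 ? ~V : V);
  return static_cast<unsigned>(std::countl_zero(U));
}

// Width needed so that sign-extending back recovers V: drop the redundant
// sign bits, add one bit if the value may be negative (the extra bit the
// FIXME above talks about), clamp to at least a byte, round to a power of 2.
static unsigned requiredWidth(int32_t V) {
  unsigned BitWidth = 32 - numSignBits(V);
  if (V < 0) // stands in for "not known non-negative"
    ++BitWidth;
  if (BitWidth < 8 && BitWidth > 1)
    BitWidth = 8;
  return std::bit_ceil(BitWidth);
}

int main() {
  assert(requiredWidth(100) == 8);     // 7 value bits, clamped to a byte
  assert(requiredWidth(-100) == 8);    // 7 value bits plus 1 sign bit
  assert(requiredWidth(40000) == 16);
  assert(requiredWidth(-40000) == 32); // 17 bits round up to 32
}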
- while (!Roots.empty()) { - DenseSet<Value *> Visited; - collectValuesToDemote(Roots.pop_back_val(), ToDemote, DemotedConsts, Roots, - Visited); - } + // Add reduction ops sizes, if any. + if (UserIgnoreList && + isa<IntegerType>(VectorizableTree.front()->Scalars.front()->getType())) { + for (Value *V : *UserIgnoreList) { + auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT); + auto NumTypeBits = DL->getTypeSizeInBits(V->getType()); + unsigned BitWidth1 = NumTypeBits - NumSignBits; + if (!isKnownNonNegative(V, SimplifyQuery(*DL))) + ++BitWidth1; + unsigned BitWidth2 = BitWidth1; + if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) { + auto Mask = DB->getDemandedBits(cast<Instruction>(V)); + BitWidth2 = Mask.getBitWidth() - Mask.countl_zero(); + } + ReductionBitWidth = + std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth); + } + if (ReductionBitWidth < 8 && ReductionBitWidth > 1) + ReductionBitWidth = 8; + + ReductionBitWidth = bit_ceil(ReductionBitWidth); + } + bool IsTopRoot = NodeIdx == 0; + while (NodeIdx < VectorizableTree.size() && + VectorizableTree[NodeIdx]->State == TreeEntry::Vectorize && + VectorizableTree[NodeIdx]->getOpcode() == Instruction::Trunc) { + RootDemotes.push_back(NodeIdx); + ++NodeIdx; + IsTruncRoot = true; + } + bool IsSignedCmp = false; + while (NodeIdx < VectorizableTree.size()) { + ArrayRef<Value *> TreeRoot = VectorizableTree[NodeIdx]->Scalars; + unsigned Limit = 2; + unsigned Opcode = VectorizableTree[NodeIdx]->getOpcode(); + if (IsTopRoot && + ReductionBitWidth == + DL->getTypeSizeInBits( + VectorizableTree.front()->Scalars.front()->getType())) + Limit = 3; + unsigned MaxBitWidth = ComputeMaxBitWidth( + *VectorizableTree[NodeIdx], IsTopRoot, IsProfitableToDemoteRoot, Opcode, + Limit, IsTruncRoot, IsSignedCmp); + if (ReductionBitWidth != 0 && (IsTopRoot || !RootDemotes.empty())) { + if (MaxBitWidth != 0 && ReductionBitWidth < MaxBitWidth) + ReductionBitWidth = bit_ceil(MaxBitWidth); + else if (MaxBitWidth == 0) + ReductionBitWidth = 0; + } + + for (unsigned Idx : RootDemotes) { + if (all_of(VectorizableTree[Idx]->Scalars, [&](Value *V) { + uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType()); + if (OrigBitWidth > MaxBitWidth) { + APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, MaxBitWidth); + return MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)); + } + return false; + })) + ToDemote.push_back(Idx); + } + RootDemotes.clear(); + IsTopRoot = false; + IsProfitableToDemoteRoot = true; + + if (ExtraBitWidthNodes.empty()) { + NodeIdx = VectorizableTree.size(); + } else { + unsigned NewIdx = 0; + do { + NewIdx = *ExtraBitWidthNodes.begin(); + ExtraBitWidthNodes.erase(ExtraBitWidthNodes.begin()); + } while (NewIdx <= NodeIdx && !ExtraBitWidthNodes.empty()); + NodeIdx = NewIdx; + IsTruncRoot = + NodeIdx < VectorizableTree.size() && + any_of(VectorizableTree[NodeIdx]->UserTreeIndices, + [](const EdgeInfo &EI) { + return EI.EdgeIdx == 0 && + EI.UserTE->getOpcode() == Instruction::Trunc && + !EI.UserTE->isAltShuffle(); + }); + IsSignedCmp = + NodeIdx < VectorizableTree.size() && + any_of(VectorizableTree[NodeIdx]->UserTreeIndices, + [&](const EdgeInfo &EI) { + return EI.UserTE->getOpcode() == Instruction::ICmp && + any_of(EI.UserTE->Scalars, [&](Value *V) { + auto *IC = dyn_cast<ICmpInst>(V); + return IC && + (IC->isSigned() || + !isKnownNonNegative(IC->getOperand(0), + SimplifyQuery(*DL)) || + !isKnownNonNegative(IC->getOperand(1), + SimplifyQuery(*DL))); + }); + }); + } - // Finally, map the values we can demote to the 
maximum bit with we computed. - for (auto *Scalar : ToDemote) { - auto *TE = getTreeEntry(Scalar); - assert(TE && "Expected vectorized scalar."); - if (MinBWs.contains(TE)) + // If the maximum bit width we compute is less than the with of the roots' + // type, we can proceed with the narrowing. Otherwise, do nothing. + if (MaxBitWidth == 0 || + MaxBitWidth >= + cast<IntegerType>(TreeRoot.front()->getType())->getBitWidth()) { + if (UserIgnoreList) + AnalyzedMinBWVals.insert(TreeRoot.begin(), TreeRoot.end()); continue; - bool IsSigned = any_of(TE->Scalars, [&](Value *R) { - KnownBits Known = computeKnownBits(R, *DL); - return !Known.isNonNegative(); - }); - MinBWs.try_emplace(TE, MaxBitWidth, IsSigned); - const auto *I = cast<Instruction>(Scalar); - auto DCIt = DemotedConsts.find(I); - if (DCIt != DemotedConsts.end()) { - for (unsigned Idx : DCIt->getSecond()) { - // Check that all instructions operands are demoted. - if (all_of(TE->Scalars, [&](Value *V) { - auto SIt = DemotedConsts.find(cast<Instruction>(V)); - return SIt != DemotedConsts.end() && - is_contained(SIt->getSecond(), Idx); - })) { - const TreeEntry *CTE = getOperandEntry(TE, Idx); - MinBWs.try_emplace(CTE, MaxBitWidth, IsSigned); - } - } + } + + // Finally, map the values we can demote to the maximum bit with we + // computed. + for (unsigned Idx : ToDemote) { + TreeEntry *TE = VectorizableTree[Idx].get(); + if (MinBWs.contains(TE)) + continue; + bool IsSigned = any_of(TE->Scalars, [&](Value *R) { + return !isKnownNonNegative(R, SimplifyQuery(*DL)); + }); + MinBWs.try_emplace(TE, MaxBitWidth, IsSigned); } } } @@ -13381,7 +15840,7 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_, DT = DT_; AC = AC_; DB = DB_; - DL = &F.getParent()->getDataLayout(); + DL = &F.getDataLayout(); Stores.clear(); GEPs.clear(); @@ -13444,30 +15903,73 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_, return Changed; } -bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, - unsigned Idx, unsigned MinVF) { +std::optional<bool> +SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, + unsigned Idx, unsigned MinVF, + unsigned &Size) { + Size = 0; LLVM_DEBUG(dbgs() << "SLP: Analyzing a store chain of length " << Chain.size() << "\n"); const unsigned Sz = R.getVectorElementSize(Chain[0]); unsigned VF = Chain.size(); - if (!isPowerOf2_32(Sz) || !isPowerOf2_32(VF) || VF < 2 || VF < MinVF) - return false; + if (!isPowerOf2_32(Sz) || !isPowerOf2_32(VF) || VF < 2 || VF < MinVF) { + // Check if vectorizing with a non-power-of-2 VF should be considered. At + // the moment, only consider cases where VF + 1 is a power-of-2, i.e. almost + // all vector lanes are used. + if (!VectorizeNonPowerOf2 || (VF < MinVF && VF + 1 != MinVF)) + return false; + } LLVM_DEBUG(dbgs() << "SLP: Analyzing " << VF << " stores at offset " << Idx << "\n"); + SetVector<Value *> ValOps; + for (Value *V : Chain) + ValOps.insert(cast<StoreInst>(V)->getValueOperand()); + // Operands are not same/alt opcodes or non-power-of-2 uniques - exit. 
+ InstructionsState S = getSameOpcode(ValOps.getArrayRef(), *TLI); + if (all_of(ValOps, IsaPred<Instruction>) && ValOps.size() > 1) { + DenseSet<Value *> Stores(Chain.begin(), Chain.end()); + bool IsPowerOf2 = + isPowerOf2_32(ValOps.size()) || + (VectorizeNonPowerOf2 && isPowerOf2_32(ValOps.size() + 1)); + if ((!IsPowerOf2 && S.getOpcode() && S.getOpcode() != Instruction::Load && + (!S.MainOp->isSafeToRemove() || + any_of(ValOps.getArrayRef(), + [&](Value *V) { + return !isa<ExtractElementInst>(V) && + (V->getNumUses() > Chain.size() || + any_of(V->users(), [&](User *U) { + return !Stores.contains(U); + })); + }))) || + (ValOps.size() > Chain.size() / 2 && !S.getOpcode())) { + Size = (!IsPowerOf2 && S.getOpcode()) ? 1 : 2; + return false; + } + } + if (R.isLoadCombineCandidate(Chain)) + return true; R.buildTree(Chain); - if (R.isTreeTinyAndNotFullyVectorizable()) - return false; - if (R.isLoadCombineCandidate()) + // Check if tree tiny and store itself or its value is not vectorized. + if (R.isTreeTinyAndNotFullyVectorizable()) { + if (R.isGathered(Chain.front()) || + R.isNotScheduled(cast<StoreInst>(Chain.front())->getValueOperand())) + return std::nullopt; + Size = R.getTreeSize(); return false; + } R.reorderTopToBottom(); R.reorderBottomToTop(); R.buildExternalUses(); R.computeMinimumValueSizes(); + R.transformNodes(); + Size = R.getTreeSize(); + if (S.getOpcode() == Instruction::Load) + Size = 2; // cut off masked gather small trees InstructionCost Cost = R.getTreeCost(); LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost << " for VF=" << VF << "\n"); @@ -13489,17 +15991,45 @@ bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, return false; } -bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, - BoUpSLP &R) { +/// Checks if the quadratic mean deviation is less than 90% of the mean size. +static bool checkTreeSizes(ArrayRef<std::pair<unsigned, unsigned>> Sizes, + bool First) { + unsigned Num = 0; + uint64_t Sum = std::accumulate( + Sizes.begin(), Sizes.end(), static_cast<uint64_t>(0), + [&](uint64_t V, const std::pair<unsigned, unsigned> &Val) { + unsigned Size = First ? Val.first : Val.second; + if (Size == 1) + return V; + ++Num; + return V + Size; + }); + if (Num == 0) + return true; + uint64_t Mean = Sum / Num; + if (Mean == 0) + return true; + uint64_t Dev = std::accumulate( + Sizes.begin(), Sizes.end(), static_cast<uint64_t>(0), + [&](uint64_t V, const std::pair<unsigned, unsigned> &Val) { + unsigned P = First ? Val.first : Val.second; + if (P == 1) + return V; + return V + (P - Mean) * (P - Mean); + }) / + Num; + return Dev * 81 / (Mean * Mean) == 0; +} + +bool SLPVectorizerPass::vectorizeStores( + ArrayRef<StoreInst *> Stores, BoUpSLP &R, + DenseSet<std::tuple<Value *, Value *, Value *, Value *, unsigned>> + &Visited) { // We may run into multiple chains that merge into a single chain. We mark the // stores that we vectorized so that we don't visit the same store twice. BoUpSLP::ValueSet VectorizedStores; bool Changed = false; - // Stores the pair of stores (first_store, last_store) in a range, that were - // already tried to be vectorized. Allows to skip the store ranges that were - // already tried to be vectorized but the attempts were unsuccessful. 
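checkTreeSizes above accepts a run of recorded tree sizes only when they are tightly clustered: entries equal to 1 (no tree was built for that position) are skipped, and the integer test Dev * 81 / (Mean * Mean) == 0 passes exactly when the variance stays below (Mean/9)^2, i.e. a standard deviation of roughly 11% of the mean. A self-contained restatement of the same acceptance test on plain containers (not the pass's code):

#include <cassert>
#include <cstdint>
#include <vector>

// True when the recorded sizes are tightly clustered around their mean.
// Sizes of 1 mean "no tree was built here" and are ignored, as in the pass.
static bool sizesAreUniform(const std::vector<unsigned> &Sizes) {
  uint64_t Sum = 0, Num = 0;
  for (unsigned S : Sizes) {
    if (S == 1)
      continue;
    Sum += S;
    ++Num;
  }
  if (Num == 0)
    return true;
  uint64_t Mean = Sum / Num;
  if (Mean == 0)
    return true;
  uint64_t Dev = 0;
  for (unsigned S : Sizes) {
    if (S == 1)
      continue;
    int64_t D = static_cast<int64_t>(S) - static_cast<int64_t>(Mean);
    Dev += static_cast<uint64_t>(D * D);
  }
  Dev /= Num;
  // Accept only when Dev * 81 < Mean^2, matching the pass's integer check.
  return Dev * 81 / (Mean * Mean) == 0;
}

int main() {
  assert(sizesAreUniform({4, 4, 4, 4, 1})); // identical sizes; the 1 is skipped
  assert(sizesAreUniform({8, 8, 8, 9}));    // small spread passes
  assert(!sizesAreUniform({2, 2, 2, 32}));  // one outlier fails
}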
- DenseSet<std::pair<Value *, Value *>> TriedSequences; struct StoreDistCompare { bool operator()(const std::pair<unsigned, int> &Op1, const std::pair<unsigned, int> &Op2) const { @@ -13521,12 +16051,21 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, if (Idx != Set.size() - 1) continue; } - if (Operands.size() <= 1) { + auto E = make_scope_exit([&, &DataVar = Data]() { Operands.clear(); - Operands.push_back(Stores[Data.first]); - PrevDist = Data.second; + Operands.push_back(Stores[DataVar.first]); + PrevDist = DataVar.second; + }); + + if (Operands.size() <= 1 || + !Visited + .insert({Operands.front(), + cast<StoreInst>(Operands.front())->getValueOperand(), + Operands.back(), + cast<StoreInst>(Operands.back())->getValueOperand(), + Operands.size()}) + .second) continue; - } unsigned MaxVecRegSize = R.getMaxVecRegSize(); unsigned EltSize = R.getVectorElementSize(Operands[0]); @@ -13534,60 +16073,223 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, unsigned MaxVF = std::min(R.getMaximumVF(EltSize, Instruction::Store), MaxElts); + unsigned MaxRegVF = MaxVF; auto *Store = cast<StoreInst>(Operands[0]); Type *StoreTy = Store->getValueOperand()->getType(); Type *ValueTy = StoreTy; if (auto *Trunc = dyn_cast<TruncInst>(Store->getValueOperand())) ValueTy = Trunc->getSrcTy(); - unsigned MinVF = TTI->getStoreMinimumVF( - R.getMinVF(DL->getTypeSizeInBits(ValueTy)), StoreTy, ValueTy); - - if (MaxVF <= MinVF) { + if (ValueTy == StoreTy && + R.getVectorElementSize(Store->getValueOperand()) <= EltSize) + MaxVF = std::min<unsigned>(MaxVF, bit_floor(Operands.size())); + unsigned MinVF = std::max<unsigned>( + 2, PowerOf2Ceil(TTI->getStoreMinimumVF( + R.getMinVF(DL->getTypeStoreSizeInBits(StoreTy)), StoreTy, + ValueTy))); + + if (MaxVF < MinVF) { LLVM_DEBUG(dbgs() << "SLP: Vectorization infeasible as MaxVF (" << MaxVF - << ") <= " + << ") < " << "MinVF (" << MinVF << ")\n"); + continue; } - // FIXME: Is division-by-2 the correct step? Should we assert that the - // register size is a power-of-2? - unsigned StartIdx = 0; - for (unsigned Size = MaxVF; Size >= MinVF; Size /= 2) { - for (unsigned Cnt = StartIdx, E = Operands.size(); Cnt + Size <= E;) { - ArrayRef<Value *> Slice = ArrayRef(Operands).slice(Cnt, Size); - assert( - all_of( - Slice, - [&](Value *V) { - return cast<StoreInst>(V)->getValueOperand()->getType() == - cast<StoreInst>(Slice.front()) - ->getValueOperand() - ->getType(); - }) && - "Expected all operands of same type."); - if (!VectorizedStores.count(Slice.front()) && - !VectorizedStores.count(Slice.back()) && - TriedSequences.insert(std::make_pair(Slice.front(), Slice.back())) - .second && - vectorizeStoreChain(Slice, R, Cnt, MinVF)) { - // Mark the vectorized stores so that we don't vectorize them again. - VectorizedStores.insert(Slice.begin(), Slice.end()); - Changed = true; - // If we vectorized initial block, no need to try to vectorize it - // again. - if (Cnt == StartIdx) - StartIdx += Size; - Cnt += Size; - continue; + unsigned NonPowerOf2VF = 0; + if (VectorizeNonPowerOf2) { + // First try vectorizing with a non-power-of-2 VF. At the moment, only + // consider cases where VF + 1 is a power-of-2, i.e. almost all vector + // lanes are used. + unsigned CandVF = Operands.size(); + if (isPowerOf2_32(CandVF + 1) && CandVF <= MaxRegVF) + NonPowerOf2VF = CandVF; + } + + unsigned Sz = 1 + Log2_32(MaxVF) - Log2_32(MinVF); + SmallVector<unsigned> CandidateVFs(Sz + (NonPowerOf2VF > 0 ? 
1 : 0)); + unsigned Size = MinVF; + for_each(reverse(CandidateVFs), [&](unsigned &VF) { + VF = Size > MaxVF ? NonPowerOf2VF : Size; + Size *= 2; + }); + unsigned End = Operands.size(); + unsigned Repeat = 0; + constexpr unsigned MaxAttempts = 4; + OwningArrayRef<std::pair<unsigned, unsigned>> RangeSizes(Operands.size()); + for_each(RangeSizes, [](std::pair<unsigned, unsigned> &P) { + P.first = P.second = 1; + }); + DenseMap<Value *, std::pair<unsigned, unsigned>> NonSchedulable; + auto IsNotVectorized = [](bool First, + const std::pair<unsigned, unsigned> &P) { + return First ? P.first > 0 : P.second > 0; + }; + auto IsVectorized = [](bool First, + const std::pair<unsigned, unsigned> &P) { + return First ? P.first == 0 : P.second == 0; + }; + auto VFIsProfitable = [](bool First, unsigned Size, + const std::pair<unsigned, unsigned> &P) { + return First ? Size >= P.first : Size >= P.second; + }; + auto FirstSizeSame = [](unsigned Size, + const std::pair<unsigned, unsigned> &P) { + return Size == P.first; + }; + while (true) { + ++Repeat; + bool RepeatChanged = false; + bool AnyProfitableGraph = false; + for (unsigned Size : CandidateVFs) { + AnyProfitableGraph = false; + unsigned StartIdx = std::distance( + RangeSizes.begin(), + find_if(RangeSizes, std::bind(IsNotVectorized, Size >= MaxRegVF, + std::placeholders::_1))); + while (StartIdx < End) { + unsigned EndIdx = + std::distance(RangeSizes.begin(), + find_if(RangeSizes.drop_front(StartIdx), + std::bind(IsVectorized, Size >= MaxRegVF, + std::placeholders::_1))); + unsigned Sz = EndIdx >= End ? End : EndIdx; + for (unsigned Cnt = StartIdx; Cnt + Size <= Sz;) { + if (!checkTreeSizes(RangeSizes.slice(Cnt, Size), + Size >= MaxRegVF)) { + ++Cnt; + continue; + } + ArrayRef<Value *> Slice = ArrayRef(Operands).slice(Cnt, Size); + assert(all_of(Slice, + [&](Value *V) { + return cast<StoreInst>(V) + ->getValueOperand() + ->getType() == + cast<StoreInst>(Slice.front()) + ->getValueOperand() + ->getType(); + }) && + "Expected all operands of same type."); + if (!NonSchedulable.empty()) { + auto [NonSchedSizeMax, NonSchedSizeMin] = + NonSchedulable.lookup(Slice.front()); + if (NonSchedSizeMax > 0 && NonSchedSizeMin <= Size) { + Cnt += NonSchedSizeMax; + continue; + } + } + unsigned TreeSize; + std::optional<bool> Res = + vectorizeStoreChain(Slice, R, Cnt, MinVF, TreeSize); + if (!Res) { + NonSchedulable + .try_emplace(Slice.front(), std::make_pair(Size, Size)) + .first->getSecond() + .second = Size; + } else if (*Res) { + // Mark the vectorized stores so that we don't vectorize them + // again. + VectorizedStores.insert(Slice.begin(), Slice.end()); + // Mark the vectorized stores so that we don't vectorize them + // again. + AnyProfitableGraph = RepeatChanged = Changed = true; + // If we vectorized initial block, no need to try to vectorize + // it again. 
+ for_each(RangeSizes.slice(Cnt, Size), + [](std::pair<unsigned, unsigned> &P) { + P.first = P.second = 0; + }); + if (Cnt < StartIdx + MinVF) { + for_each(RangeSizes.slice(StartIdx, Cnt - StartIdx), + [](std::pair<unsigned, unsigned> &P) { + P.first = P.second = 0; + }); + StartIdx = Cnt + Size; + } + if (Cnt > Sz - Size - MinVF) { + for_each(RangeSizes.slice(Cnt + Size, Sz - (Cnt + Size)), + [](std::pair<unsigned, unsigned> &P) { + P.first = P.second = 0; + }); + if (Sz == End) + End = Cnt; + Sz = Cnt; + } + Cnt += Size; + continue; + } + if (Size > 2 && Res && + !all_of(RangeSizes.slice(Cnt, Size), + std::bind(VFIsProfitable, Size >= MaxRegVF, TreeSize, + std::placeholders::_1))) { + Cnt += Size; + continue; + } + // Check for the very big VFs that we're not rebuilding same + // trees, just with larger number of elements. + if (Size > MaxRegVF && TreeSize > 1 && + all_of(RangeSizes.slice(Cnt, Size), + std::bind(FirstSizeSame, TreeSize, + std::placeholders::_1))) { + Cnt += Size; + while (Cnt != Sz && RangeSizes[Cnt].first == TreeSize) + ++Cnt; + continue; + } + if (TreeSize > 1) + for_each(RangeSizes.slice(Cnt, Size), + [&](std::pair<unsigned, unsigned> &P) { + if (Size >= MaxRegVF) + P.second = std::max(P.second, TreeSize); + else + P.first = std::max(P.first, TreeSize); + }); + ++Cnt; + AnyProfitableGraph = true; + } + if (StartIdx >= End) + break; + if (Sz - StartIdx < Size && Sz - StartIdx >= MinVF) + AnyProfitableGraph = true; + StartIdx = std::distance( + RangeSizes.begin(), + find_if(RangeSizes.drop_front(Sz), + std::bind(IsNotVectorized, Size >= MaxRegVF, + std::placeholders::_1))); } - ++Cnt; + if (!AnyProfitableGraph && Size >= MaxRegVF) + break; } - // Check if the whole array was vectorized already - exit. - if (StartIdx >= Operands.size()) + // All values vectorized - exit. + if (all_of(RangeSizes, [](const std::pair<unsigned, unsigned> &P) { + return P.first == 0 && P.second == 0; + })) + break; + // Check if tried all attempts or no need for the last attempts at all. + if (Repeat >= MaxAttempts || + (Repeat > 1 && (RepeatChanged || !AnyProfitableGraph))) + break; + constexpr unsigned StoresLimit = 64; + const unsigned MaxTotalNum = bit_floor(std::min<unsigned>( + Operands.size(), + static_cast<unsigned>( + End - + std::distance( + RangeSizes.begin(), + find_if(RangeSizes, std::bind(IsNotVectorized, true, + std::placeholders::_1))) + + 1))); + unsigned VF = PowerOf2Ceil(CandidateVFs.front()) * 2; + if (VF > MaxTotalNum || VF >= StoresLimit) break; + for_each(RangeSizes, [&](std::pair<unsigned, unsigned> &P) { + if (P.first != 0) + P.first = std::max(P.second, P.first); + }); + // Last attempt to vectorize max number of elements, if all previous + // attempts were unsuccessful because of the cost issues. + CandidateVFs.clear(); + CandidateVFs.push_back(VF); } - Operands.clear(); - Operands.push_back(Stores[Data.first]); - PrevDist = Data.second; } }; @@ -13679,15 +16381,18 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, Res.first = Idx; Res.second.emplace(Idx, 0); }; - StoreInst *PrevStore = Stores.front(); + Type *PrevValTy = nullptr; for (auto [I, SI] : enumerate(Stores)) { + if (R.isDeleted(SI)) + continue; + if (!PrevValTy) + PrevValTy = SI->getValueOperand()->getType(); // Check that we do not try to vectorize stores of different types. 
- if (PrevStore->getValueOperand()->getType() != - SI->getValueOperand()->getType()) { + if (PrevValTy != SI->getValueOperand()->getType()) { for (auto &Set : SortedStores) TryToVectorize(Set.second); SortedStores.clear(); - PrevStore = SI; + PrevValTy = SI->getValueOperand()->getType(); } FillStoresSet(I, SI); } @@ -13764,7 +16469,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, Ty->print(rso); return OptimizationRemarkMissed(SV_NAME, "UnsupportedType", I0) << "Cannot SLP vectorize list: type " - << rso.str() + " is unsupported by vectorizer"; + << TypeStr + " is unsupported by vectorizer"; }); return false; } @@ -13795,7 +16500,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, // No actual vectorization should happen, if number of parts is the same as // provided vectorization factor (i.e. the scalar type is used for vector // code during codegen). - auto *VecTy = FixedVectorType::get(ScalarTy, VF); + auto *VecTy = getWidenedType(ScalarTy, VF); if (TTI->getNumberOfParts(VecTy) == VF) continue; for (unsigned I = NextInst; I < MaxInst; ++I) { @@ -13830,6 +16535,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, R.buildExternalUses(); R.computeMinimumValueSizes(); + R.transformNodes(); InstructionCost Cost = R.getTreeCost(); CandidateFound = true; MinCost = std::min(MinCost, Cost); @@ -14013,10 +16719,9 @@ class HorizontalReduction { } /// Creates reduction operation with the current opcode. - static Value *createOp(IRBuilder<> &Builder, RecurKind Kind, Value *LHS, + static Value *createOp(IRBuilderBase &Builder, RecurKind Kind, Value *LHS, Value *RHS, const Twine &Name, bool UseSelect) { unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(Kind); - bool IsConstant = isConstant(LHS) && isConstant(RHS); switch (Kind) { case RecurKind::Or: if (UseSelect && @@ -14038,49 +16743,33 @@ class HorizontalReduction { return Builder.CreateBinOp((Instruction::BinaryOps)RdxOpcode, LHS, RHS, Name); case RecurKind::FMax: - if (IsConstant) - return ConstantFP::get(LHS->getType(), - maxnum(cast<ConstantFP>(LHS)->getValueAPF(), - cast<ConstantFP>(RHS)->getValueAPF())); return Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, LHS, RHS); case RecurKind::FMin: - if (IsConstant) - return ConstantFP::get(LHS->getType(), - minnum(cast<ConstantFP>(LHS)->getValueAPF(), - cast<ConstantFP>(RHS)->getValueAPF())); return Builder.CreateBinaryIntrinsic(Intrinsic::minnum, LHS, RHS); case RecurKind::FMaximum: - if (IsConstant) - return ConstantFP::get(LHS->getType(), - maximum(cast<ConstantFP>(LHS)->getValueAPF(), - cast<ConstantFP>(RHS)->getValueAPF())); return Builder.CreateBinaryIntrinsic(Intrinsic::maximum, LHS, RHS); case RecurKind::FMinimum: - if (IsConstant) - return ConstantFP::get(LHS->getType(), - minimum(cast<ConstantFP>(LHS)->getValueAPF(), - cast<ConstantFP>(RHS)->getValueAPF())); return Builder.CreateBinaryIntrinsic(Intrinsic::minimum, LHS, RHS); case RecurKind::SMax: - if (IsConstant || UseSelect) { + if (UseSelect) { Value *Cmp = Builder.CreateICmpSGT(LHS, RHS, Name); return Builder.CreateSelect(Cmp, LHS, RHS, Name); } return Builder.CreateBinaryIntrinsic(Intrinsic::smax, LHS, RHS); case RecurKind::SMin: - if (IsConstant || UseSelect) { + if (UseSelect) { Value *Cmp = Builder.CreateICmpSLT(LHS, RHS, Name); return Builder.CreateSelect(Cmp, LHS, RHS, Name); } return Builder.CreateBinaryIntrinsic(Intrinsic::smin, LHS, RHS); case RecurKind::UMax: - if (IsConstant || UseSelect) { + if (UseSelect) { Value *Cmp = 
Builder.CreateICmpUGT(LHS, RHS, Name); return Builder.CreateSelect(Cmp, LHS, RHS, Name); } return Builder.CreateBinaryIntrinsic(Intrinsic::umax, LHS, RHS); case RecurKind::UMin: - if (IsConstant || UseSelect) { + if (UseSelect) { Value *Cmp = Builder.CreateICmpULT(LHS, RHS, Name); return Builder.CreateSelect(Cmp, LHS, RHS, Name); } @@ -14092,15 +16781,13 @@ class HorizontalReduction { /// Creates reduction operation with the current opcode with the IR flags /// from \p ReductionOps, dropping nuw/nsw flags. - static Value *createOp(IRBuilder<> &Builder, RecurKind RdxKind, Value *LHS, + static Value *createOp(IRBuilderBase &Builder, RecurKind RdxKind, Value *LHS, Value *RHS, const Twine &Name, const ReductionOpsListType &ReductionOps) { - bool UseSelect = - ReductionOps.size() == 2 || - // Logical or/and. - (ReductionOps.size() == 1 && any_of(ReductionOps.front(), [](Value *V) { - return isa<SelectInst>(V); - })); + bool UseSelect = ReductionOps.size() == 2 || + // Logical or/and. + (ReductionOps.size() == 1 && + any_of(ReductionOps.front(), IsaPred<SelectInst>)); assert((!UseSelect || ReductionOps.size() != 2 || isa<SelectInst>(ReductionOps[1][0])) && "Expected cmp + select pairs for reduction"); @@ -14318,9 +17005,8 @@ public: SmallVectorImpl<Value *> &ExtraArgs, SmallVectorImpl<Value *> &PossibleReducedVals, SmallVectorImpl<Instruction *> &ReductionOps) { - for (int I = getFirstOperandIndex(TreeN), - End = getNumberOfOperands(TreeN); - I < End; ++I) { + for (int I : reverse(seq<int>(getFirstOperandIndex(TreeN), + getNumberOfOperands(TreeN)))) { Value *EdgeVal = getRdxOperand(TreeN, I); ReducedValsToOps[EdgeVal].push_back(TreeN); auto *EdgeInst = dyn_cast<Instruction>(EdgeVal); @@ -14339,7 +17025,7 @@ public: !hasRequiredNumberOfUses(IsCmpSelMinMax, EdgeInst) || !isVectorizable(RdxKind, EdgeInst) || (R.isAnalyzedReductionRoot(EdgeInst) && - all_of(EdgeInst->operands(), Constant::classof))) { + all_of(EdgeInst->operands(), IsaPred<Constant>))) { PossibleReducedVals.push_back(EdgeVal); continue; } @@ -14356,7 +17042,6 @@ public: initReductionOps(Root); DenseMap<Value *, SmallVector<LoadInst *>> LoadsMap; SmallSet<size_t, 2> LoadKeyUsed; - SmallPtrSet<Value *, 4> DoNotReverseVals; auto GenerateLoadsSubkey = [&](size_t Key, LoadInst *LI) { Value *Ptr = getUnderlyingObject(LI->getPointerOperand()); @@ -14373,14 +17058,12 @@ public: if (arePointersCompatible(RLI->getPointerOperand(), LI->getPointerOperand(), TLI)) { hash_code SubKey = hash_value(RLI->getPointerOperand()); - DoNotReverseVals.insert(RLI); return SubKey; } } if (LIt->second.size() > 2) { hash_code SubKey = hash_value(LIt->second.back()->getPointerOperand()); - DoNotReverseVals.insert(LIt->second.back()); return SubKey; } } @@ -14445,24 +17128,19 @@ public: }); int NewIdx = -1; for (ArrayRef<Value *> Data : PossibleRedValsVect) { - if (isGoodForReduction(Data) || - (isa<LoadInst>(Data.front()) && NewIdx >= 0 && - isa<LoadInst>(ReducedVals[NewIdx].front()) && - getUnderlyingObject( - cast<LoadInst>(Data.front())->getPointerOperand()) == - getUnderlyingObject(cast<LoadInst>(ReducedVals[NewIdx].front()) - ->getPointerOperand()))) { - if (NewIdx < 0) { - NewIdx = ReducedVals.size(); - ReducedVals.emplace_back(); - } - if (DoNotReverseVals.contains(Data.front())) - ReducedVals[NewIdx].append(Data.begin(), Data.end()); - else - ReducedVals[NewIdx].append(Data.rbegin(), Data.rend()); - } else { - ReducedVals.emplace_back().append(Data.rbegin(), Data.rend()); + if (NewIdx < 0 || + (!isGoodForReduction(Data) && + 
(!isa<LoadInst>(Data.front()) || + !isa<LoadInst>(ReducedVals[NewIdx].front()) || + getUnderlyingObject( + cast<LoadInst>(Data.front())->getPointerOperand()) != + getUnderlyingObject( + cast<LoadInst>(ReducedVals[NewIdx].front()) + ->getPointerOperand())))) { + NewIdx = ReducedVals.size(); + ReducedVals.emplace_back(); } + ReducedVals[NewIdx].append(Data.rbegin(), Data.rend()); } } // Sort the reduced values by number of same/alternate opcode and/or pointer @@ -14474,7 +17152,7 @@ public: } /// Attempt to vectorize the tree found by matchAssociativeReduction. - Value *tryToReduce(BoUpSLP &V, TargetTransformInfo *TTI, + Value *tryToReduce(BoUpSLP &V, const DataLayout &DL, TargetTransformInfo *TTI, const TargetLibraryInfo &TLI) { constexpr int ReductionLimit = 4; constexpr unsigned RegMaxNumber = 4; @@ -14500,7 +17178,9 @@ public: return nullptr; } - IRBuilder<> Builder(cast<Instruction>(ReductionRoot)); + IRBuilder<TargetFolder> Builder(ReductionRoot->getContext(), + TargetFolder(DL)); + Builder.SetInsertPoint(cast<Instruction>(ReductionRoot)); // Track the reduced values in case if they are replaced by extractelement // because of the vectorization. @@ -14586,9 +17266,12 @@ public: Value *VectorizedTree = nullptr; bool CheckForReusedReductionOps = false; // Try to vectorize elements based on their type. + SmallVector<InstructionsState> States; + for (ArrayRef<Value *> RV : ReducedVals) + States.push_back(getSameOpcode(RV, TLI)); for (unsigned I = 0, E = ReducedVals.size(); I < E; ++I) { ArrayRef<Value *> OrigReducedVals = ReducedVals[I]; - InstructionsState S = getSameOpcode(OrigReducedVals, TLI); + InstructionsState S = States[I]; SmallVector<Value *> Candidates; Candidates.reserve(2 * OrigReducedVals.size()); DenseMap<Value *, Value *> TrackedToOrig(2 * OrigReducedVals.size()); @@ -14708,7 +17391,9 @@ public: RegMaxNumber * llvm::bit_floor(MaxVecRegSize / EltSize); unsigned ReduxWidth = std::min<unsigned>( - llvm::bit_floor(NumReducedVals), std::max(RedValsMaxNumber, MaxElts)); + llvm::bit_floor(NumReducedVals), + std::clamp<unsigned>(MaxElts, RedValsMaxNumber, + RegMaxNumber * RedValsMaxNumber)); unsigned Start = 0; unsigned Pos = Start; // Restarts vectorization attempt with lower vector factor. @@ -14841,6 +17526,7 @@ public: V.buildExternalUses(LocalExternallyUsedValues); V.computeMinimumValueSizes(); + V.transformNodes(); // Estimate cost. InstructionCost TreeCost = V.getTreeCost(VL); @@ -14850,7 +17536,7 @@ public: LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost << " for reduction\n"); if (!Cost.isValid()) - return nullptr; + break; if (Cost >= -SLPCostThreshold) { V.getORE()->emit([&]() { return OptimizationRemarkMissed( @@ -14902,22 +17588,18 @@ public: // Emit code to correctly handle reused reduced values, if required. 
if (OptReusedScalars && !SameScaleFactor) { - VectorizedRoot = - emitReusedOps(VectorizedRoot, Builder, V.getRootNodeScalars(), - SameValuesCounter, TrackedToOrig); + VectorizedRoot = emitReusedOps(VectorizedRoot, Builder, V, + SameValuesCounter, TrackedToOrig); } Value *ReducedSubTree = emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI); if (ReducedSubTree->getType() != VL.front()->getType()) { - ReducedSubTree = Builder.CreateIntCast( - ReducedSubTree, VL.front()->getType(), any_of(VL, [&](Value *R) { - KnownBits Known = computeKnownBits( - R, cast<Instruction>(ReductionOps.front().front()) - ->getModule() - ->getDataLayout()); - return !Known.isNonNegative(); - })); + assert(ReducedSubTree->getType() != VL.front()->getType() && + "Expected different reduction type."); + ReducedSubTree = + Builder.CreateIntCast(ReducedSubTree, VL.front()->getType(), + V.isSignedMinBitwidthRootNode()); } // Improved analysis for add/fadd/xor reductions with same scale factor @@ -15050,7 +17732,6 @@ public: // Iterate through all not-vectorized reduction values/extra arguments. bool InitStep = true; while (ExtraReductions.size() > 1) { - VectorizedTree = ExtraReductions.front().second; SmallVector<std::pair<Instruction *, Value *>> NewReds = FinalGen(ExtraReductions, InitStep); ExtraReductions.swap(NewReds); @@ -15080,11 +17761,11 @@ public: } #endif if (!Ignore->use_empty()) { - Value *Undef = UndefValue::get(Ignore->getType()); - Ignore->replaceAllUsesWith(Undef); + Value *P = PoisonValue::get(Ignore->getType()); + Ignore->replaceAllUsesWith(P); } - V.eraseInstruction(cast<Instruction>(Ignore)); } + V.removeInstructionsAndOperands(RdxOps); } } else if (!CheckForReusedReductionOps) { for (ReductionOpsType &RdxOps : ReductionOps) @@ -15102,7 +17783,7 @@ private: FastMathFlags FMF) { TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; Type *ScalarTy = ReducedVals.front()->getType(); - FixedVectorType *VectorTy = FixedVectorType::get(ScalarTy, ReduxWidth); + FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth); InstructionCost VectorCost = 0, ScalarCost; // If all of the reduced values are constant, the vector cost is 0, since // the reduction value can be calculated at the compile time. @@ -15181,7 +17862,7 @@ private: } /// Emit a horizontal reduction of the vectorized value. - Value *emitReduction(Value *VectorizedValue, IRBuilder<> &Builder, + Value *emitReduction(Value *VectorizedValue, IRBuilderBase &Builder, unsigned ReduxWidth, const TargetTransformInfo *TTI) { assert(VectorizedValue && "Need to have a vectorized tree node"); assert(isPowerOf2_32(ReduxWidth) && @@ -15248,24 +17929,19 @@ private: /// Emits actual operation for the scalar identity values, found during /// horizontal reduction analysis. 
Value *emitReusedOps(Value *VectorizedValue, IRBuilderBase &Builder, - ArrayRef<Value *> VL, + BoUpSLP &R, const MapVector<Value *, unsigned> &SameValuesCounter, const DenseMap<Value *, Value *> &TrackedToOrig) { assert(IsSupportedHorRdxIdentityOp && "The optimization of matched scalar identity horizontal reductions " "must be supported."); + ArrayRef<Value *> VL = R.getRootNodeScalars(); auto *VTy = cast<FixedVectorType>(VectorizedValue->getType()); if (VTy->getElementType() != VL.front()->getType()) { VectorizedValue = Builder.CreateIntCast( VectorizedValue, - FixedVectorType::get(VL.front()->getType(), VTy->getNumElements()), - any_of(VL, [&](Value *R) { - KnownBits Known = computeKnownBits( - R, cast<Instruction>(ReductionOps.front().front()) - ->getModule() - ->getDataLayout()); - return !Known.isNonNegative(); - })); + getWidenedType(VL.front()->getType(), VTy->getNumElements()), + R.isSignedMinBitwidthRootNode()); } switch (RdxKind) { case RecurKind::Add: { @@ -15349,6 +18025,10 @@ private: }; } // end anonymous namespace +/// Gets recurrence kind from the specified value. +static RecurKind getRdxKind(Value *V) { + return HorizontalReduction::getRdxKind(V); +} static std::optional<unsigned> getAggregateSize(Instruction *InsertInst) { if (auto *IE = dyn_cast<InsertElementInst>(InsertInst)) return cast<FixedVectorType>(IE->getType())->getNumElements(); @@ -15385,7 +18065,7 @@ static void findBuildAggregate_rec(Instruction *LastInsertInst, do { Value *InsertedOperand = LastInsertInst->getOperand(1); std::optional<unsigned> OperandIndex = - getInsertIndex(LastInsertInst, OperandOffset); + getElementIndex(LastInsertInst, OperandOffset); if (!OperandIndex) return; if (isa<InsertElementInst, InsertValueInst>(InsertedOperand)) { @@ -15596,7 +18276,7 @@ bool SLPVectorizerPass::vectorizeHorReduction( HorizontalReduction HorRdx; if (!HorRdx.matchAssociativeReduction(R, Inst, *SE, *DL, *TLI)) return nullptr; - return HorRdx.tryToReduce(R, TTI, *TLI); + return HorRdx.tryToReduce(R, *DL, TTI, *TLI); }; auto TryAppendToPostponedInsts = [&](Instruction *FutureSeed) { if (TryOperandsAsNewSeeds && FutureSeed == Root) { @@ -15628,6 +18308,8 @@ bool SLPVectorizerPass::vectorizeHorReduction( Stack.emplace(I, Level); continue; } + if (R.isDeleted(Inst)) + continue; } else { // We could not vectorize `Inst` so try to use it as a future seed. if (!TryAppendToPostponedInsts(Inst)) { @@ -15671,7 +18353,8 @@ bool SLPVectorizerPass::tryToVectorize(ArrayRef<WeakTrackingVH> Insts, } bool SLPVectorizerPass::vectorizeInsertValueInst(InsertValueInst *IVI, - BasicBlock *BB, BoUpSLP &R) { + BasicBlock *BB, BoUpSLP &R, + bool MaxVFOnly) { if (!R.canMapToVector(IVI->getType())) return false; @@ -15680,25 +18363,40 @@ bool SLPVectorizerPass::vectorizeInsertValueInst(InsertValueInst *IVI, if (!findBuildAggregate(IVI, TTI, BuildVectorOpds, BuildVectorInsts)) return false; + if (MaxVFOnly && BuildVectorOpds.size() == 2) { + R.getORE()->emit([&]() { + return OptimizationRemarkMissed(SV_NAME, "NotPossible", IVI) + << "Cannot SLP vectorize list: only 2 elements of buildvalue, " + "trying reduction first."; + }); + return false; + } LLVM_DEBUG(dbgs() << "SLP: array mappable to vector: " << *IVI << "\n"); // Aggregate value is unlikely to be processed in vector register. 
- return tryToVectorizeList(BuildVectorOpds, R); + return tryToVectorizeList(BuildVectorOpds, R, MaxVFOnly); } bool SLPVectorizerPass::vectorizeInsertElementInst(InsertElementInst *IEI, - BasicBlock *BB, BoUpSLP &R) { + BasicBlock *BB, BoUpSLP &R, + bool MaxVFOnly) { SmallVector<Value *, 16> BuildVectorInsts; SmallVector<Value *, 16> BuildVectorOpds; SmallVector<int> Mask; if (!findBuildAggregate(IEI, TTI, BuildVectorOpds, BuildVectorInsts) || - (llvm::all_of( - BuildVectorOpds, - [](Value *V) { return isa<ExtractElementInst, UndefValue>(V); }) && + (llvm::all_of(BuildVectorOpds, IsaPred<ExtractElementInst, UndefValue>) && isFixedVectorShuffle(BuildVectorOpds, Mask))) return false; + if (MaxVFOnly && BuildVectorInsts.size() == 2) { + R.getORE()->emit([&]() { + return OptimizationRemarkMissed(SV_NAME, "NotPossible", IEI) + << "Cannot SLP vectorize list: only 2 elements of buildvector, " + "trying reduction first."; + }); + return false; + } LLVM_DEBUG(dbgs() << "SLP: array mappable to vector: " << *IEI << "\n"); - return tryToVectorizeList(BuildVectorInsts, R); + return tryToVectorizeList(BuildVectorInsts, R, MaxVFOnly); } template <typename T> @@ -15713,15 +18411,28 @@ static bool tryToVectorizeSequence( // Try to vectorize elements base on their type. SmallVector<T *> Candidates; - for (auto *IncIt = Incoming.begin(), *E = Incoming.end(); IncIt != E;) { + SmallVector<T *> VL; + for (auto *IncIt = Incoming.begin(), *E = Incoming.end(); IncIt != E; + VL.clear()) { // Look for the next elements with the same type, parent and operand // kinds. + auto *I = dyn_cast<Instruction>(*IncIt); + if (!I || R.isDeleted(I)) { + ++IncIt; + continue; + } auto *SameTypeIt = IncIt; - while (SameTypeIt != E && AreCompatible(*SameTypeIt, *IncIt)) + while (SameTypeIt != E && (!isa<Instruction>(*SameTypeIt) || + R.isDeleted(cast<Instruction>(*SameTypeIt)) || + AreCompatible(*SameTypeIt, *IncIt))) { + auto *I = dyn_cast<Instruction>(*SameTypeIt); ++SameTypeIt; + if (I && !R.isDeleted(I)) + VL.push_back(cast<T>(I)); + } // Try to vectorize them. - unsigned NumElts = (SameTypeIt - IncIt); + unsigned NumElts = VL.size(); LLVM_DEBUG(dbgs() << "SLP: Trying to vectorize starting at nodes (" << NumElts << ")\n"); // The vectorization is a 3-state attempt: @@ -15733,10 +18444,15 @@ static bool tryToVectorizeSequence( // 3. Final attempt to try to vectorize all instructions with the // same/alternate ops only, this may result in some extra final // vectorization. - if (NumElts > 1 && - TryToVectorizeHelper(ArrayRef(IncIt, NumElts), MaxVFOnly)) { + if (NumElts > 1 && TryToVectorizeHelper(ArrayRef(VL), MaxVFOnly)) { // Success start over because instructions might have been changed. Changed = true; + VL.swap(Candidates); + Candidates.clear(); + for (T *V : VL) { + if (auto *I = dyn_cast<Instruction>(V); I && !R.isDeleted(I)) + Candidates.push_back(V); + } } else { /// \Returns the minimum number of elements that we will attempt to /// vectorize. @@ -15747,7 +18463,10 @@ static bool tryToVectorizeSequence( if (NumElts < GetMinNumElements(*IncIt) && (Candidates.empty() || Candidates.front()->getType() == (*IncIt)->getType())) { - Candidates.append(IncIt, std::next(IncIt, NumElts)); + for (T *V : VL) { + if (auto *I = dyn_cast<Instruction>(V); I && !R.isDeleted(I)) + Candidates.push_back(V); + } } } // Final attempt to vectorize instructions with the same types. @@ -15758,13 +18477,26 @@ static bool tryToVectorizeSequence( Changed = true; } else if (MaxVFOnly) { // Try to vectorize using small vectors. 
- for (auto *It = Candidates.begin(), *End = Candidates.end(); - It != End;) { + SmallVector<T *> VL; + for (auto *It = Candidates.begin(), *End = Candidates.end(); It != End; + VL.clear()) { + auto *I = dyn_cast<Instruction>(*It); + if (!I || R.isDeleted(I)) { + ++It; + continue; + } auto *SameTypeIt = It; - while (SameTypeIt != End && AreCompatible(*SameTypeIt, *It)) + while (SameTypeIt != End && + (!isa<Instruction>(*SameTypeIt) || + R.isDeleted(cast<Instruction>(*SameTypeIt)) || + AreCompatible(*SameTypeIt, *It))) { + auto *I = dyn_cast<Instruction>(*SameTypeIt); ++SameTypeIt; - unsigned NumElts = (SameTypeIt - It); - if (NumElts > 1 && TryToVectorizeHelper(ArrayRef(It, NumElts), + if (I && !R.isDeleted(I)) + VL.push_back(cast<T>(I)); + } + unsigned NumElts = VL.size(); + if (NumElts > 1 && TryToVectorizeHelper(ArrayRef(VL), /*MaxVFOnly=*/false)) Changed = true; It = SameTypeIt; @@ -15864,8 +18596,11 @@ bool SLPVectorizerPass::vectorizeCmpInsts(iterator_range<ItT> CmpInsts, if (R.isDeleted(I)) continue; for (Value *Op : I->operands()) - if (auto *RootOp = dyn_cast<Instruction>(Op)) + if (auto *RootOp = dyn_cast<Instruction>(Op)) { Changed |= vectorizeRootInstruction(nullptr, RootOp, BB, R, TTI); + if (R.isDeleted(I)) + break; + } } // Try to vectorize operands as vector bundles. for (CmpInst *I : CmpInsts) { @@ -15914,27 +18649,34 @@ bool SLPVectorizerPass::vectorizeCmpInsts(iterator_range<ItT> CmpInsts, bool SLPVectorizerPass::vectorizeInserts(InstSetVector &Instructions, BasicBlock *BB, BoUpSLP &R) { - assert(all_of(Instructions, - [](auto *I) { - return isa<InsertElementInst, InsertValueInst>(I); - }) && + assert(all_of(Instructions, IsaPred<InsertElementInst, InsertValueInst>) && "This function only accepts Insert instructions"); bool OpsChanged = false; SmallVector<WeakTrackingVH> PostponedInsts; - // pass1 - try to vectorize reductions only for (auto *I : reverse(Instructions)) { + // pass1 - try to match and vectorize a buildvector sequence for MaxVF only. + if (R.isDeleted(I) || isa<CmpInst>(I)) + continue; + if (auto *LastInsertValue = dyn_cast<InsertValueInst>(I)) { + OpsChanged |= + vectorizeInsertValueInst(LastInsertValue, BB, R, /*MaxVFOnly=*/true); + } else if (auto *LastInsertElem = dyn_cast<InsertElementInst>(I)) { + OpsChanged |= + vectorizeInsertElementInst(LastInsertElem, BB, R, /*MaxVFOnly=*/true); + } + // pass2 - try to vectorize reductions only if (R.isDeleted(I)) continue; OpsChanged |= vectorizeHorReduction(nullptr, I, BB, R, TTI, PostponedInsts); - } - // pass2 - try to match and vectorize a buildvector sequence. - for (auto *I : reverse(Instructions)) { if (R.isDeleted(I) || isa<CmpInst>(I)) continue; + // pass3 - try to match and vectorize a buildvector sequence. if (auto *LastInsertValue = dyn_cast<InsertValueInst>(I)) { - OpsChanged |= vectorizeInsertValueInst(LastInsertValue, BB, R); + OpsChanged |= + vectorizeInsertValueInst(LastInsertValue, BB, R, /*MaxVFOnly=*/false); } else if (auto *LastInsertElem = dyn_cast<InsertElementInst>(I)) { - OpsChanged |= vectorizeInsertElementInst(LastInsertElem, BB, R); + OpsChanged |= vectorizeInsertElementInst(LastInsertElem, BB, R, + /*MaxVFOnly=*/false); } } // Now try to vectorize postponed instructions. @@ -15969,22 +18711,11 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { if (Opcodes1.size() > Opcodes2.size()) return false; for (int I = 0, E = Opcodes1.size(); I < E; ++I) { - // Undefs are compatible with any other value. 
- if (isa<UndefValue>(Opcodes1[I]) || isa<UndefValue>(Opcodes2[I])) { - if (isa<Instruction>(Opcodes1[I])) - return true; - if (isa<Instruction>(Opcodes2[I])) - return false; - if (isa<Constant>(Opcodes1[I]) && !isa<UndefValue>(Opcodes1[I])) - return true; - if (isa<Constant>(Opcodes2[I]) && !isa<UndefValue>(Opcodes2[I])) - return false; - if (isa<UndefValue>(Opcodes1[I]) && isa<UndefValue>(Opcodes2[I])) - continue; - return isa<UndefValue>(Opcodes2[I]); - } - if (auto *I1 = dyn_cast<Instruction>(Opcodes1[I])) - if (auto *I2 = dyn_cast<Instruction>(Opcodes2[I])) { + { + // Instructions come first. + auto *I1 = dyn_cast<Instruction>(Opcodes1[I]); + auto *I2 = dyn_cast<Instruction>(Opcodes2[I]); + if (I1 && I2) { DomTreeNodeBase<BasicBlock> *NodeI1 = DT->getNode(I1->getParent()); DomTreeNodeBase<BasicBlock> *NodeI2 = DT->getNode(I2->getParent()); if (!NodeI1) @@ -16001,24 +18732,48 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { continue; return I1->getOpcode() < I2->getOpcode(); } - if (isa<Constant>(Opcodes1[I]) && isa<Constant>(Opcodes2[I])) - return Opcodes1[I]->getValueID() < Opcodes2[I]->getValueID(); - if (isa<Instruction>(Opcodes1[I])) - return true; - if (isa<Instruction>(Opcodes2[I])) - return false; - if (isa<Constant>(Opcodes1[I])) - return true; - if (isa<Constant>(Opcodes2[I])) - return false; - if (Opcodes1[I]->getValueID() < Opcodes2[I]->getValueID()) - return true; - if (Opcodes1[I]->getValueID() > Opcodes2[I]->getValueID()) - return false; + if (I1) + return true; + if (I2) + return false; + } + { + // Non-undef constants come next. + bool C1 = isa<Constant>(Opcodes1[I]) && !isa<UndefValue>(Opcodes1[I]); + bool C2 = isa<Constant>(Opcodes2[I]) && !isa<UndefValue>(Opcodes2[I]); + if (C1 && C2) + continue; + if (C1) + return true; + if (C2) + return false; + } + bool U1 = isa<UndefValue>(Opcodes1[I]); + bool U2 = isa<UndefValue>(Opcodes2[I]); + { + // Non-constant non-instructions come next. + if (!U1 && !U2) { + auto ValID1 = Opcodes1[I]->getValueID(); + auto ValID2 = Opcodes2[I]->getValueID(); + if (ValID1 == ValID2) + continue; + if (ValID1 < ValID2) + return true; + if (ValID1 > ValID2) + return false; + } + if (!U1) + return true; + if (!U2) + return false; + } + // Undefs come last. + assert(U1 && U2 && "The only thing left should be undef & undef."); + continue; } return false; }; - auto AreCompatiblePHIs = [&PHIToOpcodes, this](Value *V1, Value *V2) { + auto AreCompatiblePHIs = [&PHIToOpcodes, this, &R](Value *V1, Value *V2) { if (V1 == V2) return true; if (V1->getType() != V2->getType()) @@ -16033,6 +18788,8 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { continue; if (auto *I1 = dyn_cast<Instruction>(Opcodes1[I])) if (auto *I2 = dyn_cast<Instruction>(Opcodes2[I])) { + if (R.isDeleted(I1) || R.isDeleted(I2)) + return false; if (I1->getParent() != I2->getParent()) return false; InstructionsState S = getSameOpcode({I1, I2}, *TLI); @@ -16053,8 +18810,8 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { // Collect the incoming values from the PHIs. 
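The rewritten comparator above fixes a total bucket order for the values feeding compatible PHIs: instructions sort first (by dominance, then opcode), then non-undef constants, then the remaining values, with undefs last. The sketch below reproduces only that bucket structure over a toy value type; ToyValue, rank, and lessThan are made-up names, and the dominance/ValueID tie-breaks of the real comparator are reduced to an opcode compare.

#include <algorithm>
#include <cassert>
#include <vector>

// A toy stand-in for llvm::Value with just enough state to show the order.
struct ToyValue {
  enum Kind { Instruction, Constant, Argument, Undef };
  Kind K;
  unsigned Opcode; // only meaningful for instructions
};

// Bucket rank: instructions < non-undef constants < other values < undef.
static int rank(const ToyValue &V) {
  switch (V.K) {
  case ToyValue::Instruction: return 0;
  case ToyValue::Constant:    return 1;
  case ToyValue::Argument:    return 2;
  case ToyValue::Undef:       return 3;
  }
  return 3;
}

// Strict weak ordering mirroring the comparator's structure; values in the
// same bucket compare as equivalent except instructions, which here order by
// opcode only.
static bool lessThan(const ToyValue &A, const ToyValue &B) {
  if (rank(A) != rank(B))
    return rank(A) < rank(B);
  if (A.K == ToyValue::Instruction && A.Opcode != B.Opcode)
    return A.Opcode < B.Opcode;
  return false;
}

int main() {
  std::vector<ToyValue> Vals = {{ToyValue::Undef, 0},
                                {ToyValue::Constant, 0},
                                {ToyValue::Instruction, 13},
                                {ToyValue::Argument, 0},
                                {ToyValue::Instruction, 11}};
  std::stable_sort(Vals.begin(), Vals.end(), lessThan);
  assert(Vals.front().K == ToyValue::Instruction && Vals.front().Opcode == 11);
  assert(Vals.back().K == ToyValue::Undef);
}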
Incoming.clear(); for (Instruction &I : *BB) { - PHINode *P = dyn_cast<PHINode>(&I); - if (!P) + auto *P = dyn_cast<PHINode>(&I); + if (!P || P->getNumIncomingValues() > MaxPHINumOperands) break; // No need to analyze deleted, vectorized and non-vectorizable @@ -16097,6 +18854,11 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { }, /*MaxVFOnly=*/true, R); Changed |= HaveVectorizedPhiNodes; + if (HaveVectorizedPhiNodes && any_of(PHIToOpcodes, [&](const auto &P) { + auto *PHI = dyn_cast<PHINode>(P.first); + return !PHI || R.isDeleted(PHI); + })) + PHIToOpcodes.clear(); VisitedInstrs.insert(Incoming.begin(), Incoming.end()); } while (HaveVectorizedPhiNodes); @@ -16169,7 +18931,7 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { } // Try to vectorize the incoming values of the PHI, to catch reductions // that feed into PHIs. - for (unsigned I = 0, E = P->getNumIncomingValues(); I != E; I++) { + for (unsigned I : seq<unsigned>(P->getNumIncomingValues())) { // Skip if the incoming block is the current BB for now. Also, bypass // unreachable IR for efficiency and to avoid crashing. // TODO: Collect the skipped incoming values and try to vectorize them @@ -16181,9 +18943,16 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { // Postponed instructions should not be vectorized here, delay their // vectorization. if (auto *PI = dyn_cast<Instruction>(P->getIncomingValue(I)); - PI && !IsInPostProcessInstrs(PI)) - Changed |= vectorizeRootInstruction(nullptr, PI, + PI && !IsInPostProcessInstrs(PI)) { + bool Res = vectorizeRootInstruction(nullptr, PI, P->getIncomingBlock(I), R, TTI); + Changed |= Res; + if (Res && R.isDeleted(P)) { + It = BB->begin(); + E = BB->end(); + break; + } + } } continue; } @@ -16253,8 +19022,13 @@ bool SLPVectorizerPass::vectorizeGEPIndices(BasicBlock *BB, BoUpSLP &R) { // are trying to vectorize the index computations, so the maximum number of // elements is based on the size of the index expression, rather than the // size of the GEP itself (the target's pointer size). + auto *It = find_if(Entry.second, [&](GetElementPtrInst *GEP) { + return !R.isDeleted(GEP); + }); + if (It == Entry.second.end()) + continue; unsigned MaxVecRegSize = R.getMaxVecRegSize(); - unsigned EltSize = R.getVectorElementSize(*Entry.second[0]->idx_begin()); + unsigned EltSize = R.getVectorElementSize(*(*It)->idx_begin()); if (MaxVecRegSize < EltSize) continue; @@ -16405,6 +19179,7 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { }; // Attempt to sort and vectorize each of the store-groups. 
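The Attempted set introduced just below appears to memoize store groups that have already been tried (keyed by a tuple of characteristic pointers and a size), so repeated sorting passes do not redo failed vectorization attempts. A minimal sketch of that memoization pattern with a simplified key; the key shape here is made up for illustration:
\code
#include <set>
#include <utility>

using AttemptKey = std::pair<const void *, unsigned>; // (group identity, size)

// Returns true only the first time a given key is seen; repeated attempts
// with the same key can be skipped cheaply.
bool shouldAttempt(std::set<AttemptKey> &Attempted, AttemptKey K) {
  return Attempted.insert(K).second;
}
\endcode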
+ DenseSet<std::tuple<Value *, Value *, Value *, Value *, unsigned>> Attempted; for (auto &Pair : Stores) { if (Pair.second.size() < 2) continue; @@ -16422,8 +19197,8 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { Pair.second.rend()); Changed |= tryToVectorizeSequence<StoreInst>( ReversedStores, StoreSorter, AreCompatibleStores, - [this, &R](ArrayRef<StoreInst *> Candidates, bool) { - return vectorizeStores(Candidates, R); + [&](ArrayRef<StoreInst *> Candidates, bool) { + return vectorizeStores(Candidates, R, Attempted); }, /*MaxVFOnly=*/false, R); } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h index 4b3143aead46..b4c7ab02f928 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h @@ -21,10 +21,11 @@ class LoopVectorizationLegality; class LoopVectorizationCostModel; class TargetLibraryInfo; -using VPRecipeOrVPValueTy = PointerUnion<VPRecipeBase *, VPValue *>; - /// Helper class to create VPRecipies from IR instructions. class VPRecipeBuilder { + /// The VPlan new recipes are added to. + VPlan &Plan; + /// The loop that we evaluate. Loop *OrigLoop; @@ -51,9 +52,8 @@ class VPRecipeBuilder { EdgeMaskCacheTy EdgeMaskCache; BlockMaskCacheTy BlockMaskCache; - // VPlan-VPlan transformations support: Hold a mapping from ingredients to - // their recipe. To save on memory, only do so for selected ingredients, - // marked by having a nullptr entry in this map. + // VPlan construction support: Hold a mapping from ingredients to + // their recipe. DenseMap<Instruction *, VPRecipeBase *> Ingredient2Recipe; /// Cross-iteration reduction & first-order recurrence phis for which we need @@ -69,92 +69,78 @@ class VPRecipeBuilder { /// Check if the load or store instruction \p I should widened for \p /// Range.Start and potentially masked. Such instructions are handled by a /// recipe that takes an additional VPInstruction for the mask. - VPRecipeBase *tryToWidenMemory(Instruction *I, ArrayRef<VPValue *> Operands, - VFRange &Range, VPlanPtr &Plan); + VPWidenMemoryRecipe *tryToWidenMemory(Instruction *I, + ArrayRef<VPValue *> Operands, + VFRange &Range); /// Check if an induction recipe should be constructed for \p Phi. If so build /// and return it. If not, return null. - VPRecipeBase *tryToOptimizeInductionPHI(PHINode *Phi, - ArrayRef<VPValue *> Operands, - VPlan &Plan, VFRange &Range); + VPHeaderPHIRecipe *tryToOptimizeInductionPHI(PHINode *Phi, + ArrayRef<VPValue *> Operands, + VFRange &Range); /// Optimize the special case where the operand of \p I is a constant integer /// induction variable. VPWidenIntOrFpInductionRecipe * tryToOptimizeInductionTruncate(TruncInst *I, ArrayRef<VPValue *> Operands, - VFRange &Range, VPlan &Plan); + VFRange &Range); - /// Handle non-loop phi nodes. Return a VPValue, if all incoming values match - /// or a new VPBlendRecipe otherwise. Currently all such phi nodes are turned - /// into a sequence of select instructions as the vectorizer currently - /// performs full if-conversion. - VPRecipeOrVPValueTy tryToBlend(PHINode *Phi, ArrayRef<VPValue *> Operands, - VPlanPtr &Plan); + /// Handle non-loop phi nodes. Return a new VPBlendRecipe otherwise. Currently + /// all such phi nodes are turned into a sequence of select instructions as + /// the vectorizer currently performs full if-conversion. 
+ VPBlendRecipe *tryToBlend(PHINode *Phi, ArrayRef<VPValue *> Operands); /// Handle call instructions. If \p CI can be widened for \p Range.Start, /// return a new VPWidenCallRecipe. Range.End may be decreased to ensure same /// decision from \p Range.Start to \p Range.End. VPWidenCallRecipe *tryToWidenCall(CallInst *CI, ArrayRef<VPValue *> Operands, - VFRange &Range, VPlanPtr &Plan); + VFRange &Range); /// Check if \p I has an opcode that can be widened and return a VPWidenRecipe /// if it can. The function should only be called if the cost-model indicates /// that widening should be performed. - VPRecipeBase *tryToWiden(Instruction *I, ArrayRef<VPValue *> Operands, - VPBasicBlock *VPBB, VPlanPtr &Plan); - - /// Return a VPRecipeOrValueTy with VPRecipeBase * being set. This can be used to force the use as VPRecipeBase* for recipe sub-types that also inherit from VPValue. - VPRecipeOrVPValueTy toVPRecipeResult(VPRecipeBase *R) const { return R; } + VPWidenRecipe *tryToWiden(Instruction *I, ArrayRef<VPValue *> Operands, + VPBasicBlock *VPBB); public: - VPRecipeBuilder(Loop *OrigLoop, const TargetLibraryInfo *TLI, + VPRecipeBuilder(VPlan &Plan, Loop *OrigLoop, const TargetLibraryInfo *TLI, LoopVectorizationLegality *Legal, LoopVectorizationCostModel &CM, PredicatedScalarEvolution &PSE, VPBuilder &Builder) - : OrigLoop(OrigLoop), TLI(TLI), Legal(Legal), CM(CM), PSE(PSE), - Builder(Builder) {} - - /// Check if an existing VPValue can be used for \p Instr or a recipe can be - /// create for \p I withing the given VF \p Range. If an existing VPValue can - /// be used or if a recipe can be created, return it. Otherwise return a - /// VPRecipeOrVPValueTy with nullptr. - VPRecipeOrVPValueTy tryToCreateWidenRecipe(Instruction *Instr, - ArrayRef<VPValue *> Operands, - VFRange &Range, VPBasicBlock *VPBB, - VPlanPtr &Plan); - - /// Set the recipe created for given ingredient. This operation is a no-op for - /// ingredients that were not marked using a nullptr entry in the map. + : Plan(Plan), OrigLoop(OrigLoop), TLI(TLI), Legal(Legal), CM(CM), + PSE(PSE), Builder(Builder) {} + + /// Create and return a widened recipe for \p I if one can be created within + /// the given VF \p Range. + VPRecipeBase *tryToCreateWidenRecipe(Instruction *Instr, + ArrayRef<VPValue *> Operands, + VFRange &Range, VPBasicBlock *VPBB); + + /// Set the recipe created for given ingredient. void setRecipe(Instruction *I, VPRecipeBase *R) { - if (!Ingredient2Recipe.count(I)) - return; - assert(Ingredient2Recipe[I] == nullptr && - "Recipe already set for ingredient"); + assert(!Ingredient2Recipe.contains(I) && + "Cannot reset recipe for instruction."); Ingredient2Recipe[I] = R; } /// Create the mask for the vector loop header block. - void createHeaderMask(VPlan &Plan); + void createHeaderMask(); /// A helper function that computes the predicate of the block BB, assuming /// that the header block of the loop is set to True or the loop mask when /// tail folding. - void createBlockInMask(BasicBlock *BB, VPlan &Plan); + void createBlockInMask(BasicBlock *BB); /// Returns the *entry* mask for the block \p BB. VPValue *getBlockInMask(BasicBlock *BB) const; /// A helper function that computes the predicate of the edge between SRC /// and DST. - VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst, VPlan &Plan); - - /// Mark given ingredient for recording its recipe once one is created for - /// it. 
- void recordRecipeOf(Instruction *I) { - assert((!Ingredient2Recipe.count(I) || Ingredient2Recipe[I] == nullptr) && - "Recipe already set for ingredient"); - Ingredient2Recipe[I] = nullptr; - } + VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst); + + /// A helper that returns the previously computed predicate of the edge + /// between SRC and DST. + VPValue *getEdgeMask(BasicBlock *Src, BasicBlock *Dst) const; /// Return the recipe created for given ingredient. VPRecipeBase *getRecipe(Instruction *I) { @@ -168,12 +154,24 @@ public: /// Build a VPReplicationRecipe for \p I. If it is predicated, add the mask as /// last operand. Range.End may be decreased to ensure same recipe behavior /// from \p Range.Start to \p Range.End. - VPRecipeOrVPValueTy handleReplication(Instruction *I, VFRange &Range, - VPlan &Plan); + VPReplicateRecipe *handleReplication(Instruction *I, VFRange &Range); /// Add the incoming values from the backedge to reduction & first-order /// recurrence cross-iteration phis. void fixHeaderPhis(); + + /// Returns a range mapping the values of the range \p Operands to their + /// corresponding VPValues. + iterator_range<mapped_iterator<Use *, std::function<VPValue *(Value *)>>> + mapToVPValues(User::op_range Operands); + + VPValue *getVPValueOrAddLiveIn(Value *V, VPlan &Plan) { + if (auto *I = dyn_cast<Instruction>(V)) { + if (auto *R = Ingredient2Recipe.lookup(I)) + return R->getVPSingleValue(); + } + return Plan.getOrAddLiveIn(V); + } }; } // end namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.cpp index 3eeb1a6948f2..58de6256900f 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -17,13 +17,16 @@ //===----------------------------------------------------------------------===// #include "VPlan.h" +#include "LoopVectorizationPlanner.h" #include "VPlanCFG.h" #include "VPlanDominatorTree.h" +#include "VPlanPatternMatch.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/Twine.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -46,6 +49,7 @@ #include <vector> using namespace llvm; +using namespace llvm::VPlanPatternMatch; namespace llvm { extern cl::opt<bool> EnableVPlanNativePath; @@ -212,6 +216,14 @@ VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() { return It; } +VPTransformState::VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI, + DominatorTree *DT, IRBuilderBase &Builder, + InnerLoopVectorizer *ILV, VPlan *Plan, + LLVMContext &Ctx) + : VF(VF), UF(UF), CFG(DT), LI(LI), Builder(Builder), ILV(ILV), Plan(Plan), + LVer(nullptr), + TypeAnalysis(Plan->getCanonicalIV()->getScalarType(), Ctx) {} + Value *VPTransformState::get(VPValue *Def, const VPIteration &Instance) { if (Def->isLiveIn()) return Def->getLiveInIRValue(); @@ -220,6 +232,11 @@ Value *VPTransformState::get(VPValue *Def, const VPIteration &Instance) { return Data .PerPartScalars[Def][Instance.Part][Instance.Lane.mapToCacheIndex(VF)]; } + if (!Instance.Lane.isFirstLane() && + vputils::isUniformAfterVectorization(Def) && + hasScalarValue(Def, {Instance.Part, VPLane::getFirstLane()})) { + return Data.PerPartScalars[Def][Instance.Part][0]; + } assert(hasVectorValue(Def, Instance.Part)); auto *VecPart = 
Data.PerPartOutput[Def][Instance.Part]; @@ -234,7 +251,17 @@ Value *VPTransformState::get(VPValue *Def, const VPIteration &Instance) { return Extract; } -Value *VPTransformState::get(VPValue *Def, unsigned Part) { +Value *VPTransformState::get(VPValue *Def, unsigned Part, bool NeedsScalar) { + if (NeedsScalar) { + assert((VF.isScalar() || Def->isLiveIn() || hasVectorValue(Def, Part) || + !vputils::onlyFirstLaneUsed(Def) || + (hasScalarValue(Def, VPIteration(Part, 0)) && + Data.PerPartScalars[Def][Part].size() == 1)) && + "Trying to access a single scalar per part but has multiple scalars " + "per part."); + return get(Def, VPIteration(Part, 0)); + } + // If Values have been set for this Def return the one relevant for \p Part. if (hasVectorValue(Def, Part)) return Data.PerPartOutput[Def][Part]; @@ -339,23 +366,14 @@ void VPTransformState::addNewMetadata(Instruction *To, LVer->annotateInstWithNoAlias(To, Orig); } -void VPTransformState::addMetadata(Instruction *To, Instruction *From) { - // No source instruction to transfer metadata from? - if (!From) - return; - - propagateMetadata(To, From); - addNewMetadata(To, From); -} - -void VPTransformState::addMetadata(ArrayRef<Value *> To, Instruction *From) { +void VPTransformState::addMetadata(Value *To, Instruction *From) { // No source instruction to transfer metadata from? if (!From) return; - for (Value *V : To) { - if (Instruction *I = dyn_cast<Instruction>(V)) - addMetadata(I, From); + if (Instruction *ToI = dyn_cast<Instruction>(To)) { + propagateMetadata(ToI, From); + addNewMetadata(ToI, From); } } @@ -426,10 +444,42 @@ VPBasicBlock::createEmptyBasicBlock(VPTransformState::CFGState &CFG) { "Trying to reset an existing successor block."); TermBr->setSuccessor(idx, NewBB); } + CFG.DTU.applyUpdates({{DominatorTree::Insert, PredBB, NewBB}}); } return NewBB; } +void VPIRBasicBlock::execute(VPTransformState *State) { + assert(getHierarchicalSuccessors().size() <= 2 && + "VPIRBasicBlock can have at most two successors at the moment!"); + State->Builder.SetInsertPoint(getIRBasicBlock()->getTerminator()); + executeRecipes(State, getIRBasicBlock()); + if (getSingleSuccessor()) { + assert(isa<UnreachableInst>(getIRBasicBlock()->getTerminator())); + auto *Br = State->Builder.CreateBr(getIRBasicBlock()); + Br->setOperand(0, nullptr); + getIRBasicBlock()->getTerminator()->eraseFromParent(); + } + + for (VPBlockBase *PredVPBlock : getHierarchicalPredecessors()) { + VPBasicBlock *PredVPBB = PredVPBlock->getExitingBasicBlock(); + BasicBlock *PredBB = State->CFG.VPBB2IRBB[PredVPBB]; + assert(PredBB && "Predecessor basic-block not found building successor."); + LLVM_DEBUG(dbgs() << "LV: draw edge from" << PredBB->getName() << '\n'); + + auto *PredBBTerminator = PredBB->getTerminator(); + auto *TermBr = cast<BranchInst>(PredBBTerminator); + // Set each forward successor here when it is created, excluding + // backedges. A backward successor is set when the branch is created. + const auto &PredVPSuccessors = PredVPBB->getHierarchicalSuccessors(); + unsigned idx = PredVPSuccessors.front() == this ? 
0 : 1; + assert(!TermBr->getSuccessor(idx) && + "Trying to reset an existing successor block."); + TermBr->setSuccessor(idx, IRBB); + State->CFG.DTU.applyUpdates({{DominatorTree::Insert, PredBB, IRBB}}); + } +} + void VPBasicBlock::execute(VPTransformState *State) { bool Replica = State->Instance && !State->Instance->isFirstIteration(); VPBasicBlock *PrevVPBB = State->CFG.PrevVPBB; @@ -441,29 +491,14 @@ void VPBasicBlock::execute(VPTransformState *State) { return R && !R->isReplicator(); }; - // 1. Create an IR basic block, or reuse the last one or ExitBB if possible. - if (getPlan()->getVectorLoopRegion()->getSingleSuccessor() == this) { - // ExitBB can be re-used for the exit block of the Plan. - NewBB = State->CFG.ExitBB; - State->CFG.PrevBB = NewBB; - State->Builder.SetInsertPoint(NewBB->getFirstNonPHI()); - - // Update the branch instruction in the predecessor to branch to ExitBB. - VPBlockBase *PredVPB = getSingleHierarchicalPredecessor(); - VPBasicBlock *ExitingVPBB = PredVPB->getExitingBasicBlock(); - assert(PredVPB->getSingleSuccessor() == this && - "predecessor must have the current block as only successor"); - BasicBlock *ExitingBB = State->CFG.VPBB2IRBB[ExitingVPBB]; - // The Exit block of a loop is always set to be successor 0 of the Exiting - // block. - cast<BranchInst>(ExitingBB->getTerminator())->setSuccessor(0, NewBB); - } else if (PrevVPBB && /* A */ - !((SingleHPred = getSingleHierarchicalPredecessor()) && - SingleHPred->getExitingBasicBlock() == PrevVPBB && - PrevVPBB->getSingleHierarchicalSuccessor() && - (SingleHPred->getParent() == getEnclosingLoopRegion() && - !IsLoopRegion(SingleHPred))) && /* B */ - !(Replica && getPredecessors().empty())) { /* C */ + // 1. Create an IR basic block. + if (PrevVPBB && /* A */ + !((SingleHPred = getSingleHierarchicalPredecessor()) && + SingleHPred->getExitingBasicBlock() == PrevVPBB && + PrevVPBB->getSingleHierarchicalSuccessor() && + (SingleHPred->getParent() == getEnclosingLoopRegion() && + !IsLoopRegion(SingleHPred))) && /* B */ + !(Replica && getPredecessors().empty())) { /* C */ // The last IR basic block is reused, as an optimization, in three cases: // A. the first VPBB reuses the loop pre-header BB - when PrevVPBB is null; // B. when the current VPBB has a single (hierarchical) predecessor which @@ -486,16 +521,7 @@ void VPBasicBlock::execute(VPTransformState *State) { } // 2. Fill the IR basic block with IR instructions. 
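Several of the edits above and below route CFG changes through State->CFG.DTU instead of touching the DominatorTree directly. With the lazy update strategy (see the CFGState constructor later in this patch), edge insertions and deletions are only queued, and the tree is brought up to date on flush() or on the next query. A minimal sketch of that usage pattern, independent of VPlan:
\code
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/IR/Dominators.h"
using namespace llvm;

// Record that Pred's edge was redirected from OldSucc to NewSucc; the IR
// terminator itself is assumed to have been rewired by the caller. With
// DomTreeUpdater::UpdateStrategy::Lazy nothing is recomputed here - the
// pending updates are applied on DTU.flush() or when the DomTree is queried.
void recordRewiredEdge(DomTreeUpdater &DTU, BasicBlock *Pred,
                       BasicBlock *OldSucc, BasicBlock *NewSucc) {
  DTU.applyUpdates({{DominatorTree::Delete, Pred, OldSucc},
                    {DominatorTree::Insert, Pred, NewSucc}});
}
\endcode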
- LLVM_DEBUG(dbgs() << "LV: vectorizing VPBB:" << getName() - << " in BB:" << NewBB->getName() << '\n'); - - State->CFG.VPBB2IRBB[this] = NewBB; - State->CFG.PrevVPBB = this; - - for (VPRecipeBase &Recipe : Recipes) - Recipe.execute(*State); - - LLVM_DEBUG(dbgs() << "LV: filled BB:" << *NewBB); + executeRecipes(State, NewBB); } void VPBasicBlock::dropAllReferences(VPValue *NewValue) { @@ -508,6 +534,19 @@ void VPBasicBlock::dropAllReferences(VPValue *NewValue) { } } +void VPBasicBlock::executeRecipes(VPTransformState *State, BasicBlock *BB) { + LLVM_DEBUG(dbgs() << "LV: vectorizing VPBB:" << getName() + << " in BB:" << BB->getName() << '\n'); + + State->CFG.VPBB2IRBB[this] = BB; + State->CFG.PrevVPBB = this; + + for (VPRecipeBase &Recipe : Recipes) + Recipe.execute(*State); + + LLVM_DEBUG(dbgs() << "LV: filled BB:" << *BB); +} + VPBasicBlock *VPBasicBlock::splitAt(iterator SplitAt) { assert((SplitAt == end() || SplitAt->getParent() == this) && "can only split at a position in the same block"); @@ -552,14 +591,13 @@ static bool hasConditionalTerminator(const VPBasicBlock *VPBB) { } const VPRecipeBase *R = &VPBB->back(); - auto *VPI = dyn_cast<VPInstruction>(R); - bool IsCondBranch = - isa<VPBranchOnMaskRecipe>(R) || - (VPI && (VPI->getOpcode() == VPInstruction::BranchOnCond || - VPI->getOpcode() == VPInstruction::BranchOnCount)); + bool IsCondBranch = isa<VPBranchOnMaskRecipe>(R) || + match(R, m_BranchOnCond(m_VPValue())) || + match(R, m_BranchOnCount(m_VPValue(), m_VPValue())); (void)IsCondBranch; - if (VPBB->getNumSuccessors() >= 2 || VPBB->isExiting()) { + if (VPBB->getNumSuccessors() >= 2 || + (VPBB->isExiting() && !VPBB->getParent()->isReplicator())) { assert(IsCondBranch && "block with multiple successors not terminated by " "conditional branch recipe"); @@ -585,7 +623,7 @@ const VPRecipeBase *VPBasicBlock::getTerminator() const { } bool VPBasicBlock::isExiting() const { - return getParent()->getExitingBasicBlock() == this; + return getParent() && getParent()->getExitingBasicBlock() == this; } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -615,6 +653,72 @@ void VPBasicBlock::print(raw_ostream &O, const Twine &Indent, } #endif +static std::pair<VPBlockBase *, VPBlockBase *> cloneFrom(VPBlockBase *Entry); + +// Clone the CFG for all nodes reachable from \p Entry, this includes cloning +// the blocks and their recipes. Operands of cloned recipes will NOT be updated. +// Remapping of operands must be done separately. Returns a pair with the new +// entry and exiting blocks of the cloned region. If \p Entry isn't part of a +// region, return nullptr for the exiting block. +static std::pair<VPBlockBase *, VPBlockBase *> cloneFrom(VPBlockBase *Entry) { + DenseMap<VPBlockBase *, VPBlockBase *> Old2NewVPBlocks; + VPBlockBase *Exiting = nullptr; + bool InRegion = Entry->getParent(); + // First, clone blocks reachable from Entry. + for (VPBlockBase *BB : vp_depth_first_shallow(Entry)) { + VPBlockBase *NewBB = BB->clone(); + Old2NewVPBlocks[BB] = NewBB; + if (InRegion && BB->getNumSuccessors() == 0) { + assert(!Exiting && "Multiple exiting blocks?"); + Exiting = BB; + } + } + assert((!InRegion || Exiting) && "regions must have a single exiting block"); + + // Second, update the predecessors & successors of the cloned blocks. 
+ for (VPBlockBase *BB : vp_depth_first_shallow(Entry)) { + VPBlockBase *NewBB = Old2NewVPBlocks[BB]; + SmallVector<VPBlockBase *> NewPreds; + for (VPBlockBase *Pred : BB->getPredecessors()) { + NewPreds.push_back(Old2NewVPBlocks[Pred]); + } + NewBB->setPredecessors(NewPreds); + SmallVector<VPBlockBase *> NewSuccs; + for (VPBlockBase *Succ : BB->successors()) { + NewSuccs.push_back(Old2NewVPBlocks[Succ]); + } + NewBB->setSuccessors(NewSuccs); + } + +#if !defined(NDEBUG) + // Verify that the order of predecessors and successors matches in the cloned + // version. + for (const auto &[OldBB, NewBB] : + zip(vp_depth_first_shallow(Entry), + vp_depth_first_shallow(Old2NewVPBlocks[Entry]))) { + for (const auto &[OldPred, NewPred] : + zip(OldBB->getPredecessors(), NewBB->getPredecessors())) + assert(NewPred == Old2NewVPBlocks[OldPred] && "Different predecessors"); + + for (const auto &[OldSucc, NewSucc] : + zip(OldBB->successors(), NewBB->successors())) + assert(NewSucc == Old2NewVPBlocks[OldSucc] && "Different successors"); + } +#endif + + return std::make_pair(Old2NewVPBlocks[Entry], + Exiting ? Old2NewVPBlocks[Exiting] : nullptr); +} + +VPRegionBlock *VPRegionBlock::clone() { + const auto &[NewEntry, NewExiting] = cloneFrom(getEntry()); + auto *NewRegion = + new VPRegionBlock(NewEntry, NewExiting, getName(), isReplicator()); + for (VPBlockBase *Block : vp_depth_first_shallow(NewEntry)) + Block->setParent(NewRegion); + return NewRegion; +} + void VPRegionBlock::dropAllReferences(VPValue *NewValue) { for (VPBlockBase *Block : vp_depth_first_shallow(Entry)) // Drop all references in VPBasicBlocks and replace all uses with @@ -673,6 +777,48 @@ void VPRegionBlock::execute(VPTransformState *State) { State->Instance.reset(); } +InstructionCost VPBasicBlock::cost(ElementCount VF, VPCostContext &Ctx) { + InstructionCost Cost = 0; + for (VPRecipeBase &R : Recipes) + Cost += R.cost(VF, Ctx); + return Cost; +} + +InstructionCost VPRegionBlock::cost(ElementCount VF, VPCostContext &Ctx) { + if (!isReplicator()) { + InstructionCost Cost = 0; + for (VPBlockBase *Block : vp_depth_first_shallow(getEntry())) + Cost += Block->cost(VF, Ctx); + InstructionCost BackedgeCost = + Ctx.TTI.getCFInstrCost(Instruction::Br, TTI::TCK_RecipThroughput); + LLVM_DEBUG(dbgs() << "Cost of " << BackedgeCost << " for VF " << VF + << ": vector loop backedge\n"); + Cost += BackedgeCost; + return Cost; + } + + // Compute the cost of a replicate region. Replicating isn't supported for + // scalable vectors, return an invalid cost for them. + // TODO: Discard scalable VPlans with replicate recipes earlier after + // construction. + if (VF.isScalable()) + return InstructionCost::getInvalid(); + + // First compute the cost of the conditionally executed recipes, followed by + // account for the branching cost, except if the mask is a header mask or + // uniform condition. + using namespace llvm::VPlanPatternMatch; + VPBasicBlock *Then = cast<VPBasicBlock>(getEntry()->getSuccessors()[0]); + InstructionCost ThenCost = Then->cost(VF, Ctx); + + // For the scalar case, we may not always execute the original predicated + // block, Thus, scale the block's cost by the probability of executing it. 
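To make the scaling below concrete: getReciprocalPredBlockProb() (defined later in this patch to return 2, i.e. a 50% execution assumption) divides the predicated block's cost for VF = 1, so a replicate region whose recipes cost 8 contributes 8 / 2 = 4 to the scalar cost estimate. The numbers are illustrative only:
\code
// Mirrors the VF.isScalar() branch below, with made-up costs.
unsigned ThenCost = 8;               // cost of the conditionally executed recipes
unsigned ScalarShare = ThenCost / 2; // == 4, assuming the block runs every other iteration
\endcode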
+ if (VF.isScalar()) + return ThenCost / getReciprocalPredBlockProb(); + + return ThenCost; +} + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPRegionBlock::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { @@ -709,17 +855,61 @@ VPlan::~VPlan() { delete BackedgeTakenCount; } -VPlanPtr VPlan::createInitialVPlan(const SCEV *TripCount, ScalarEvolution &SE) { - VPBasicBlock *Preheader = new VPBasicBlock("ph"); +VPlanPtr VPlan::createInitialVPlan(const SCEV *TripCount, ScalarEvolution &SE, + bool RequiresScalarEpilogueCheck, + bool TailFolded, Loop *TheLoop) { + VPIRBasicBlock *Entry = new VPIRBasicBlock(TheLoop->getLoopPreheader()); VPBasicBlock *VecPreheader = new VPBasicBlock("vector.ph"); - auto Plan = std::make_unique<VPlan>(Preheader, VecPreheader); + auto Plan = std::make_unique<VPlan>(Entry, VecPreheader); Plan->TripCount = vputils::getOrCreateVPValueForSCEVExpr(*Plan, TripCount, SE); - // Create empty VPRegionBlock, to be filled during processing later. - auto *TopRegion = new VPRegionBlock("vector loop", false /*isReplicator*/); + // Create VPRegionBlock, with empty header and latch blocks, to be filled + // during processing later. + VPBasicBlock *HeaderVPBB = new VPBasicBlock("vector.body"); + VPBasicBlock *LatchVPBB = new VPBasicBlock("vector.latch"); + VPBlockUtils::insertBlockAfter(LatchVPBB, HeaderVPBB); + auto *TopRegion = new VPRegionBlock(HeaderVPBB, LatchVPBB, "vector loop", + false /*isReplicator*/); + VPBlockUtils::insertBlockAfter(TopRegion, VecPreheader); VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block"); VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion); + + VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph"); + if (!RequiresScalarEpilogueCheck) { + VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH); + return Plan; + } + + // If needed, add a check in the middle block to see if we have completed + // all of the iterations in the first vector loop. Three cases: + // 1) If (N - N%VF) == N, then we *don't* need to run the remainder. + // Thus if tail is to be folded, we know we don't need to run the + // remainder and we can set the condition to true. + // 2) If we require a scalar epilogue, there is no conditional branch as + // we unconditionally branch to the scalar preheader. Do nothing. + // 3) Otherwise, construct a runtime check. + BasicBlock *IRExitBlock = TheLoop->getUniqueExitBlock(); + auto *VPExitBlock = new VPIRBasicBlock(IRExitBlock); + // The connection order corresponds to the operands of the conditional branch. + VPBlockUtils::insertBlockAfter(VPExitBlock, MiddleVPBB); + VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH); + + auto *ScalarLatchTerm = TheLoop->getLoopLatch()->getTerminator(); + // Here we use the same DebugLoc as the scalar loop latch terminator instead + // of the corresponding compare because they may have ended up with + // different line numbers and we want to avoid awkward line stepping while + // debugging. Eg. if the compare has got a line number inside the loop. + VPBuilder Builder(MiddleVPBB); + VPValue *Cmp = + TailFolded + ? 
Plan->getOrAddLiveIn(ConstantInt::getTrue( + IntegerType::getInt1Ty(TripCount->getType()->getContext()))) + : Builder.createICmp(CmpInst::ICMP_EQ, Plan->getTripCount(), + &Plan->getVectorTripCount(), + ScalarLatchTerm->getDebugLoc(), "cmp.n"); + Builder.createNaryOp(VPInstruction::BranchOnCond, {Cmp}, + ScalarLatchTerm->getDebugLoc()); return Plan; } @@ -732,31 +922,26 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, auto *TCMO = Builder.CreateSub(TripCountV, ConstantInt::get(TripCountV->getType(), 1), "trip.count.minus.1"); - auto VF = State.VF; - Value *VTCMO = - VF.isScalar() ? TCMO : Builder.CreateVectorSplat(VF, TCMO, "broadcast"); - for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part) - State.set(BackedgeTakenCount, VTCMO, Part); + BackedgeTakenCount->setUnderlyingValue(TCMO); } - for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part) - State.set(&VectorTripCount, VectorTripCountV, Part); + VectorTripCount.setUnderlyingValue(VectorTripCountV); IRBuilder<> Builder(State.CFG.PrevBB->getTerminator()); // FIXME: Model VF * UF computation completely in VPlan. - State.set(&VFxUF, - createStepForVF(Builder, TripCountV->getType(), State.VF, State.UF), - 0); + VFxUF.setUnderlyingValue( + createStepForVF(Builder, TripCountV->getType(), State.VF, State.UF)); // When vectorizing the epilogue loop, the canonical induction start value // needs to be changed from zero to the value after the main vector loop. // FIXME: Improve modeling for canonical IV start values in the epilogue loop. if (CanonicalIVStartValue) { - VPValue *VPV = getVPValueOrAddLiveIn(CanonicalIVStartValue); + VPValue *VPV = getOrAddLiveIn(CanonicalIVStartValue); auto *IV = getCanonicalIV(); assert(all_of(IV->users(), [](const VPUser *U) { return isa<VPScalarIVStepsRecipe>(U) || + isa<VPScalarCastRecipe>(U) || isa<VPDerivedIVRecipe>(U) || cast<VPInstruction>(U)->getOpcode() == Instruction::Add; @@ -767,20 +952,69 @@ void VPlan::prepareToExecute(Value *TripCountV, Value *VectorTripCountV, } } +/// Replace \p VPBB with a VPIRBasicBlock wrapping \p IRBB. All recipes from \p +/// VPBB are moved to the newly created VPIRBasicBlock. VPBB must have a single +/// predecessor, which is rewired to the new VPIRBasicBlock. All successors of +/// VPBB, if any, are rewired to the new VPIRBasicBlock. +static void replaceVPBBWithIRVPBB(VPBasicBlock *VPBB, BasicBlock *IRBB) { + VPIRBasicBlock *IRMiddleVPBB = new VPIRBasicBlock(IRBB); + for (auto &R : make_early_inc_range(*VPBB)) + R.moveBefore(*IRMiddleVPBB, IRMiddleVPBB->end()); + VPBlockBase *PredVPBB = VPBB->getSinglePredecessor(); + VPBlockUtils::disconnectBlocks(PredVPBB, VPBB); + VPBlockUtils::connectBlocks(PredVPBB, IRMiddleVPBB); + for (auto *Succ : to_vector(VPBB->getSuccessors())) { + VPBlockUtils::connectBlocks(IRMiddleVPBB, Succ); + VPBlockUtils::disconnectBlocks(VPBB, Succ); + } + delete VPBB; +} + /// Generate the code inside the preheader and body of the vectorized loop. /// Assumes a single pre-header basic-block was created for this. Introduce /// additional basic-blocks as needed, and fill them all. void VPlan::execute(VPTransformState *State) { - // Set the reverse mapping from VPValues to Values for code generation. - for (auto &Entry : Value2VPValue) - State->VPValue2Value[Entry.second] = Entry.first; - // Initialize CFG state. 
  State->CFG.PrevVPBB = nullptr;
  State->CFG.ExitBB = State->CFG.PrevBB->getSingleSuccessor();
  BasicBlock *VectorPreHeader = State->CFG.PrevBB;
  State->Builder.SetInsertPoint(VectorPreHeader->getTerminator());

+  // Disconnect VectorPreHeader from ExitBB in both the CFG and DT.
+  cast<BranchInst>(VectorPreHeader->getTerminator())->setSuccessor(0, nullptr);
+  State->CFG.DTU.applyUpdates(
+      {{DominatorTree::Delete, VectorPreHeader, State->CFG.ExitBB}});
+
+  // Replace regular VPBB's for the middle and scalar preheader blocks with
+  // VPIRBasicBlocks wrapping their IR blocks. The IR blocks are created during
+  // skeleton creation, so we can only create the VPIRBasicBlocks now during
+  // VPlan execution rather than earlier during VPlan construction.
+  BasicBlock *MiddleBB = State->CFG.ExitBB;
+  VPBasicBlock *MiddleVPBB =
+      cast<VPBasicBlock>(getVectorLoopRegion()->getSingleSuccessor());
+  // Find the VPBB for the scalar preheader, relying on the current structure
+  // when creating the middle block and its successors: if there's a single
+  // successor, it must be the scalar preheader. Otherwise, the second
+  // successor is the scalar preheader.
+  BasicBlock *ScalarPh = MiddleBB->getSingleSuccessor();
+  auto &MiddleSuccs = MiddleVPBB->getSuccessors();
+  assert((MiddleSuccs.size() == 1 || MiddleSuccs.size() == 2) &&
+         "middle block has unexpected successors");
+  VPBasicBlock *ScalarPhVPBB = cast<VPBasicBlock>(
+      MiddleSuccs.size() == 1 ? MiddleSuccs[0] : MiddleSuccs[1]);
+  assert(!isa<VPIRBasicBlock>(ScalarPhVPBB) &&
+         "scalar preheader cannot be wrapped already");
+  replaceVPBBWithIRVPBB(ScalarPhVPBB, ScalarPh);
+  replaceVPBBWithIRVPBB(MiddleVPBB, MiddleBB);
+
+  // Disconnect the middle block from its single successor (the scalar loop
+  // header) in both the CFG and DT. The branch will be recreated during VPlan
+  // execution.
+  auto *BrInst = new UnreachableInst(MiddleBB->getContext());
+  BrInst->insertBefore(MiddleBB->getTerminator());
+  MiddleBB->getTerminator()->eraseFromParent();
+  State->CFG.DTU.applyUpdates({{DominatorTree::Delete, MiddleBB, ScalarPh}});
+
  // Generate code in the loop pre-header and body.
  for (VPBlockBase *Block : vp_depth_first_shallow(Entry))
    Block->execute(State);
@@ -803,11 +1037,8 @@ void VPlan::execute(VPTransformState *State) {
      Phi = cast<PHINode>(State->get(R.getVPSingleValue(), 0));
    } else {
      auto *WidenPhi = cast<VPWidenPointerInductionRecipe>(&R);
-      // TODO: Split off the case that all users of a pointer phi are scalar
-      // from the VPWidenPointerInductionRecipe.
-      if (WidenPhi->onlyScalarsGenerated(State->VF))
-        continue;
-
+      assert(!WidenPhi->onlyScalarsGenerated(State->VF.isScalable()) &&
+             "recipe generating only scalars should have been replaced");
      auto *GEP = cast<GetElementPtrInst>(State->get(WidenPhi, 0));
      Phi = cast<PHINode>(GEP->getPointerOperand());
    }
@@ -826,27 +1057,36 @@ void VPlan::execute(VPTransformState *State) {
    // only a single part is generated, which provides the last part from the
    // previous iteration. For non-ordered reductions all UF parts are
    // generated.
- bool SinglePartNeeded = isa<VPCanonicalIVPHIRecipe>(PhiR) || - isa<VPFirstOrderRecurrencePHIRecipe>(PhiR) || - (isa<VPReductionPHIRecipe>(PhiR) && - cast<VPReductionPHIRecipe>(PhiR)->isOrdered()); + bool SinglePartNeeded = + isa<VPCanonicalIVPHIRecipe>(PhiR) || + isa<VPFirstOrderRecurrencePHIRecipe, VPEVLBasedIVPHIRecipe>(PhiR) || + (isa<VPReductionPHIRecipe>(PhiR) && + cast<VPReductionPHIRecipe>(PhiR)->isOrdered()); + bool NeedsScalar = + isa<VPCanonicalIVPHIRecipe, VPEVLBasedIVPHIRecipe>(PhiR) || + (isa<VPReductionPHIRecipe>(PhiR) && + cast<VPReductionPHIRecipe>(PhiR)->isInLoop()); unsigned LastPartForNewPhi = SinglePartNeeded ? 1 : State->UF; for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) { - Value *Phi = State->get(PhiR, Part); - Value *Val = State->get(PhiR->getBackedgeValue(), - SinglePartNeeded ? State->UF - 1 : Part); + Value *Phi = State->get(PhiR, Part, NeedsScalar); + Value *Val = + State->get(PhiR->getBackedgeValue(), + SinglePartNeeded ? State->UF - 1 : Part, NeedsScalar); cast<PHINode>(Phi)->addIncoming(Val, VectorLatchBB); } } - // We do not attempt to preserve DT for outer loop vectorization currently. - if (!EnableVPlanNativePath) { - BasicBlock *VectorHeaderBB = State->CFG.VPBB2IRBB[Header]; - State->DT->addNewBlock(VectorHeaderBB, VectorPreHeader); - updateDominatorTree(State->DT, VectorHeaderBB, VectorLatchBB, - State->CFG.ExitBB); - } + State->CFG.DTU.flush(); + assert(State->CFG.DTU.getDomTree().verify( + DominatorTree::VerificationLevel::Fast) && + "DT not preserved correctly"); +} + +InstructionCost VPlan::cost(ElementCount VF, VPCostContext &Ctx) { + // For now only return the cost of the vector loop region, ignoring any other + // blocks, like the preheader or middle blocks. + return getVectorLoopRegion()->cost(VF, Ctx); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -944,42 +1184,85 @@ void VPlan::addLiveOut(PHINode *PN, VPValue *V) { LiveOuts.insert({PN, new VPLiveOut(PN, V)}); } -void VPlan::updateDominatorTree(DominatorTree *DT, BasicBlock *LoopHeaderBB, - BasicBlock *LoopLatchBB, - BasicBlock *LoopExitBB) { - // The vector body may be more than a single basic-block by this point. - // Update the dominator tree information inside the vector body by propagating - // it from header to latch, expecting only triangular control-flow, if any. - BasicBlock *PostDomSucc = nullptr; - for (auto *BB = LoopHeaderBB; BB != LoopLatchBB; BB = PostDomSucc) { - // Get the list of successors of this block. - std::vector<BasicBlock *> Succs(succ_begin(BB), succ_end(BB)); - assert(Succs.size() <= 2 && - "Basic block in vector loop has more than 2 successors."); - PostDomSucc = Succs[0]; - if (Succs.size() == 1) { - assert(PostDomSucc->getSinglePredecessor() && - "PostDom successor has more than one predecessor."); - DT->addNewBlock(PostDomSucc, BB); - continue; - } - BasicBlock *InterimSucc = Succs[1]; - if (PostDomSucc->getSingleSuccessor() == InterimSucc) { - PostDomSucc = Succs[1]; - InterimSucc = Succs[0]; +static void remapOperands(VPBlockBase *Entry, VPBlockBase *NewEntry, + DenseMap<VPValue *, VPValue *> &Old2NewVPValues) { + // Update the operands of all cloned recipes starting at NewEntry. This + // traverses all reachable blocks. This is done in two steps, to handle cycles + // in PHI recipes. 
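The two-step remapping described above is the usual clone-then-patch idiom: a cloned recipe may use a value whose defining recipe is cloned later, or is reached through a cycle (as with header phis), so all old-to-new def mappings are collected before any operand is rewritten. A self-contained sketch of the idiom on a toy node type; the type and its fields are invented for illustration:
\code
#include <unordered_map>
#include <vector>

struct Node {
  std::vector<Node *> Operands;
};

// Step 1: clone every node, recording the old -> new mapping. Cloned operands
// still point at the originals at this stage.
// Step 2: rewrite operands through the mapping, so uses of nodes cloned later
// (including across cycles) resolve to their clones.
std::vector<Node *> cloneAll(const std::vector<Node *> &Nodes) {
  std::unordered_map<const Node *, Node *> OldToNew;
  std::vector<Node *> Clones;
  for (const Node *N : Nodes) {
    Node *C = new Node(*N);
    OldToNew[N] = C;
    Clones.push_back(C);
  }
  for (Node *C : Clones)
    for (Node *&Op : C->Operands)
      Op = OldToNew.at(Op); // assumes every operand is itself in Nodes
  return Clones;
}
\endcode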
+ ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> + OldDeepRPOT(Entry); + ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> + NewDeepRPOT(NewEntry); + // First, collect all mappings from old to new VPValues defined by cloned + // recipes. + for (const auto &[OldBB, NewBB] : + zip(VPBlockUtils::blocksOnly<VPBasicBlock>(OldDeepRPOT), + VPBlockUtils::blocksOnly<VPBasicBlock>(NewDeepRPOT))) { + assert(OldBB->getRecipeList().size() == NewBB->getRecipeList().size() && + "blocks must have the same number of recipes"); + for (const auto &[OldR, NewR] : zip(*OldBB, *NewBB)) { + assert(OldR.getNumOperands() == NewR.getNumOperands() && + "recipes must have the same number of operands"); + assert(OldR.getNumDefinedValues() == NewR.getNumDefinedValues() && + "recipes must define the same number of operands"); + for (const auto &[OldV, NewV] : + zip(OldR.definedValues(), NewR.definedValues())) + Old2NewVPValues[OldV] = NewV; } - assert(InterimSucc->getSingleSuccessor() == PostDomSucc && - "One successor of a basic block does not lead to the other."); - assert(InterimSucc->getSinglePredecessor() && - "Interim successor has more than one predecessor."); - assert(PostDomSucc->hasNPredecessors(2) && - "PostDom successor has more than two predecessors."); - DT->addNewBlock(InterimSucc, BB); - DT->addNewBlock(PostDomSucc, BB); } - // Latch block is a new dominator for the loop exit. - DT->changeImmediateDominator(LoopExitBB, LoopLatchBB); - assert(DT->verify(DominatorTree::VerificationLevel::Fast)); + + // Update all operands to use cloned VPValues. + for (VPBasicBlock *NewBB : + VPBlockUtils::blocksOnly<VPBasicBlock>(NewDeepRPOT)) { + for (VPRecipeBase &NewR : *NewBB) + for (unsigned I = 0, E = NewR.getNumOperands(); I != E; ++I) { + VPValue *NewOp = Old2NewVPValues.lookup(NewR.getOperand(I)); + NewR.setOperand(I, NewOp); + } + } +} + +VPlan *VPlan::duplicate() { + // Clone blocks. + VPBasicBlock *NewPreheader = Preheader->clone(); + const auto &[NewEntry, __] = cloneFrom(Entry); + + // Create VPlan, clone live-ins and remap operands in the cloned blocks. + auto *NewPlan = new VPlan(NewPreheader, cast<VPBasicBlock>(NewEntry)); + DenseMap<VPValue *, VPValue *> Old2NewVPValues; + for (VPValue *OldLiveIn : VPLiveInsToFree) { + Old2NewVPValues[OldLiveIn] = + NewPlan->getOrAddLiveIn(OldLiveIn->getLiveInIRValue()); + } + Old2NewVPValues[&VectorTripCount] = &NewPlan->VectorTripCount; + Old2NewVPValues[&VFxUF] = &NewPlan->VFxUF; + if (BackedgeTakenCount) { + NewPlan->BackedgeTakenCount = new VPValue(); + Old2NewVPValues[BackedgeTakenCount] = NewPlan->BackedgeTakenCount; + } + assert(TripCount && "trip count must be set"); + if (TripCount->isLiveIn()) + Old2NewVPValues[TripCount] = + NewPlan->getOrAddLiveIn(TripCount->getLiveInIRValue()); + // else NewTripCount will be created and inserted into Old2NewVPValues when + // TripCount is cloned. In any case NewPlan->TripCount is updated below. + + remapOperands(Preheader, NewPreheader, Old2NewVPValues); + remapOperands(Entry, NewEntry, Old2NewVPValues); + + // Clone live-outs. + for (const auto &[_, LO] : LiveOuts) + NewPlan->addLiveOut(LO->getPhi(), Old2NewVPValues[LO->getOperand(0)]); + + // Initialize remaining fields of cloned VPlan. + NewPlan->VFs = VFs; + NewPlan->UFs = UFs; + // TODO: Adjust names. 
+  NewPlan->Name = Name;
+  assert(Old2NewVPValues.contains(TripCount) &&
+         "TripCount must have been added to Old2NewVPValues");
+  NewPlan->TripCount = Old2NewVPValues[TripCount];
+  return NewPlan;
 }

 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -1168,18 +1451,7 @@ void VPValue::replaceUsesWithIf(
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPValue::printAsOperand(raw_ostream &OS, VPSlotTracker &Tracker) const {
-  if (const Value *UV = getUnderlyingValue()) {
-    OS << "ir<";
-    UV->printAsOperand(OS, false);
-    OS << ">";
-    return;
-  }
-
-  unsigned Slot = Tracker.getSlot(this);
-  if (Slot == unsigned(-1))
-    OS << "<badref>";
-  else
-    OS << "vp<%" << Tracker.getSlot(this) << ">";
+  OS << Tracker.getOrCreateName(this);
 }

 void VPUser::printOperands(raw_ostream &O, VPSlotTracker &SlotTracker) const {
@@ -1203,7 +1475,7 @@ void VPInterleavedAccessInfo::visitBlock(VPBlockBase *Block, Old2NewTy &Old2New,
                                          InterleavedAccessInfo &IAI) {
   if (VPBasicBlock *VPBB = dyn_cast<VPBasicBlock>(Block)) {
     for (VPRecipeBase &VPI : *VPBB) {
-      if (isa<VPHeaderPHIRecipe>(&VPI))
+      if (isa<VPWidenPHIRecipe>(&VPI))
        continue;
      assert(isa<VPInstruction>(&VPI) && "Can only handle VPInstructions");
      auto *VPInst = cast<VPInstruction>(&VPI);
@@ -1241,40 +1513,98 @@ VPInterleavedAccessInfo::VPInterleavedAccessInfo(VPlan &Plan,
   visitRegion(Plan.getVectorLoopRegion(), Old2New, IAI);
 }

-void VPSlotTracker::assignSlot(const VPValue *V) {
-  assert(!Slots.contains(V) && "VPValue already has a slot!");
-  Slots[V] = NextSlot++;
+void VPSlotTracker::assignName(const VPValue *V) {
+  assert(!VPValue2Name.contains(V) && "VPValue already has a name!");
+  auto *UV = V->getUnderlyingValue();
+  if (!UV) {
+    VPValue2Name[V] = (Twine("vp<%") + Twine(NextSlot) + ">").str();
+    NextSlot++;
+    return;
+  }
+
+  // Use the name of the underlying Value, wrapped in "ir<>", and versioned by
+  // appending ".Number" to the name if there are multiple uses.
+  std::string Name;
+  raw_string_ostream S(Name);
+  UV->printAsOperand(S, false);
+  assert(!Name.empty() && "Name cannot be empty.");
+  std::string BaseName = (Twine("ir<") + Name + Twine(">")).str();
+
+  // First assign the base name for V.
+  const auto &[A, _] = VPValue2Name.insert({V, BaseName});
+  // Integer or FP constants with different types will result in the same string
+  // due to stripping types.
+  if (V->isLiveIn() && isa<ConstantInt, ConstantFP>(UV))
+    return;
+
+  // If it is already used by C > 0 other VPValues, increase the version counter
+  // C and use it for V.
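As an illustration of the versioning applied below: if three distinct VPValues all wrap an IR value that prints as %x, the first keeps the base name and later ones get numbered suffixes, while live-in integer/FP constants deliberately share the unversioned base name per the early return above.
\code
// Printed operand names, assuming three VPValues wrapping the IR value %x
// (illustrative):
//   ir<%x>     - first VPValue, keeps the base name
//   ir<%x>.1   - second VPValue
//   ir<%x>.2   - third VPValue
\endcode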
+ const auto &[C, UseInserted] = BaseName2Version.insert({BaseName, 0}); + if (!UseInserted) { + C->second++; + A->second = (BaseName + Twine(".") + Twine(C->second)).str(); + } } -void VPSlotTracker::assignSlots(const VPlan &Plan) { +void VPSlotTracker::assignNames(const VPlan &Plan) { if (Plan.VFxUF.getNumUsers() > 0) - assignSlot(&Plan.VFxUF); - assignSlot(&Plan.VectorTripCount); + assignName(&Plan.VFxUF); + assignName(&Plan.VectorTripCount); if (Plan.BackedgeTakenCount) - assignSlot(Plan.BackedgeTakenCount); - assignSlots(Plan.getPreheader()); + assignName(Plan.BackedgeTakenCount); + for (VPValue *LI : Plan.VPLiveInsToFree) + assignName(LI); + assignNames(Plan.getPreheader()); ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<const VPBlockBase *>> RPOT(VPBlockDeepTraversalWrapper<const VPBlockBase *>(Plan.getEntry())); for (const VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<const VPBasicBlock>(RPOT)) - assignSlots(VPBB); + assignNames(VPBB); } -void VPSlotTracker::assignSlots(const VPBasicBlock *VPBB) { +void VPSlotTracker::assignNames(const VPBasicBlock *VPBB) { for (const VPRecipeBase &Recipe : *VPBB) for (VPValue *Def : Recipe.definedValues()) - assignSlot(Def); + assignName(Def); } -bool vputils::onlyFirstLaneUsed(VPValue *Def) { +std::string VPSlotTracker::getOrCreateName(const VPValue *V) const { + std::string Name = VPValue2Name.lookup(V); + if (!Name.empty()) + return Name; + + // If no name was assigned, no VPlan was provided when creating the slot + // tracker or it is not reachable from the provided VPlan. This can happen, + // e.g. when trying to print a recipe that has not been inserted into a VPlan + // in a debugger. + // TODO: Update VPSlotTracker constructor to assign names to recipes & + // VPValues not associated with a VPlan, instead of constructing names ad-hoc + // here. + const VPRecipeBase *DefR = V->getDefiningRecipe(); + (void)DefR; + assert((!DefR || !DefR->getParent() || !DefR->getParent()->getPlan()) && + "VPValue defined by a recipe in a VPlan?"); + + // Use the underlying value's name, if there is one. 
+ if (auto *UV = V->getUnderlyingValue()) { + std::string Name; + raw_string_ostream S(Name); + UV->printAsOperand(S, false); + return (Twine("ir<") + Name + ">").str(); + } + + return "<badref>"; +} + +bool vputils::onlyFirstLaneUsed(const VPValue *Def) { return all_of(Def->users(), - [Def](VPUser *U) { return U->onlyFirstLaneUsed(Def); }); + [Def](const VPUser *U) { return U->onlyFirstLaneUsed(Def); }); } -bool vputils::onlyFirstPartUsed(VPValue *Def) { +bool vputils::onlyFirstPartUsed(const VPValue *Def) { return all_of(Def->users(), - [Def](VPUser *U) { return U->onlyFirstPartUsed(Def); }); + [Def](const VPUser *U) { return U->onlyFirstPartUsed(Def); }); } VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr, @@ -1283,9 +1613,9 @@ VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr, return Expanded; VPValue *Expanded = nullptr; if (auto *E = dyn_cast<SCEVConstant>(Expr)) - Expanded = Plan.getVPValueOrAddLiveIn(E->getValue()); + Expanded = Plan.getOrAddLiveIn(E->getValue()); else if (auto *E = dyn_cast<SCEVUnknown>(Expr)) - Expanded = Plan.getVPValueOrAddLiveIn(E->getValue()); + Expanded = Plan.getOrAddLiveIn(E->getValue()); else { Expanded = new VPExpandSCEVRecipe(Expr, SE); Plan.getPreheader()->appendRecipe(Expanded->getDefiningRecipe()); @@ -1293,3 +1623,23 @@ VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr, Plan.addSCEVExpansion(Expr, Expanded); return Expanded; } + +bool vputils::isHeaderMask(VPValue *V, VPlan &Plan) { + if (isa<VPActiveLaneMaskPHIRecipe>(V)) + return true; + + auto IsWideCanonicalIV = [](VPValue *A) { + return isa<VPWidenCanonicalIVRecipe>(A) || + (isa<VPWidenIntOrFpInductionRecipe>(A) && + cast<VPWidenIntOrFpInductionRecipe>(A)->isCanonical()); + }; + + VPValue *A, *B; + if (match(V, m_ActiveLaneMask(m_VPValue(A), m_VPValue(B)))) + return B == Plan.getTripCount() && + (match(A, m_ScalarIVSteps(m_CanonicalIV(), m_SpecificInt(1))) || + IsWideCanonicalIV(A)); + + return match(V, m_Binary<Instruction::ICmp>(m_VPValue(A), m_VPValue(B))) && + IsWideCanonicalIV(A) && B == Plan.getOrCreateBackedgeTakenCount(); +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.h index 0c6214868d84..0b596e7e4f63 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlan.h @@ -35,12 +35,14 @@ #include "llvm/ADT/Twine.h" #include "llvm/ADT/ilist.h" #include "llvm/ADT/ilist_node.h" +#include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/IVDescriptors.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/FMF.h" #include "llvm/IR/Operator.h" +#include "llvm/Support/InstructionCost.h" #include <algorithm> #include <cassert> #include <cstddef> @@ -63,8 +65,11 @@ class VPlan; class VPReplicateRecipe; class VPlanSlp; class Value; +class LoopVectorizationCostModel; class LoopVersioning; +struct VPCostContext; + namespace Intrinsic { typedef unsigned ID; } @@ -81,6 +86,14 @@ Value *createStepForVF(IRBuilderBase &B, Type *Ty, ElementCount VF, const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE, Loop *CurLoop = nullptr); +/// A helper function that returns the reciprocal of the block probability of +/// predicated blocks. If we return X, we are assuming the predicated block +/// will execute once for every X iterations of the loop header. 
+/// +/// TODO: We should use actual block probability here, if available. Currently, +/// we always assume predicated blocks have a 50% chance of executing. +inline unsigned getReciprocalPredBlockProb() { return 2; } + /// A range of powers-of-2 vectorization factors with fixed start and /// adjustable end. The range includes start and excludes end, e.g.,: /// [1, 16) = {1, 2, 4, 8} @@ -166,8 +179,10 @@ public: static VPLane getFirstLane() { return VPLane(0, VPLane::Kind::First); } - static VPLane getLastLaneForVF(const ElementCount &VF) { - unsigned LaneOffset = VF.getKnownMinValue() - 1; + static VPLane getLaneFromEnd(const ElementCount &VF, unsigned Offset) { + assert(Offset > 0 && Offset <= VF.getKnownMinValue() && + "trying to extract with invalid offset"); + unsigned LaneOffset = VF.getKnownMinValue() - Offset; Kind LaneKind; if (VF.isScalable()) // In this case 'LaneOffset' refers to the offset from the start of the @@ -178,6 +193,10 @@ public: return VPLane(LaneOffset, LaneKind); } + static VPLane getLastLaneForVF(const ElementCount &VF) { + return getLaneFromEnd(VF, 1); + } + /// Returns a compile-time known value for the lane index and asserts if the /// lane can only be calculated at runtime. unsigned getKnownLane() const { @@ -236,9 +255,7 @@ struct VPIteration { struct VPTransformState { VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI, DominatorTree *DT, IRBuilderBase &Builder, - InnerLoopVectorizer *ILV, VPlan *Plan, LLVMContext &Ctx) - : VF(VF), UF(UF), LI(LI), DT(DT), Builder(Builder), ILV(ILV), Plan(Plan), - LVer(nullptr), TypeAnalysis(Ctx) {} + InnerLoopVectorizer *ILV, VPlan *Plan, LLVMContext &Ctx); /// The chosen Vectorization and Unroll Factors of the loop being vectorized. ElementCount VF; @@ -261,12 +278,10 @@ struct VPTransformState { DenseMap<VPValue *, ScalarsPerPartValuesTy> PerPartScalars; } Data; - /// Get the generated Value for a given VPValue and a given Part. Note that - /// as some Defs are still created by ILV and managed in its ValueMap, this - /// method will delegate the call to ILV in such cases in order to provide - /// callers a consistent API. - /// \see set. - Value *get(VPValue *Def, unsigned Part); + /// Get the generated vector Value for a given VPValue \p Def and a given \p + /// Part if \p IsScalar is false, otherwise return the generated scalar + /// for \p Part. \See set. + Value *get(VPValue *Def, unsigned Part, bool IsScalar = false); /// Get the generated Value for a given VPValue and given Part and Lane. Value *get(VPValue *Def, const VPIteration &Instance); @@ -287,14 +302,22 @@ struct VPTransformState { I->second[Instance.Part][CacheIdx]; } - /// Set the generated Value for a given VPValue and a given Part. - void set(VPValue *Def, Value *V, unsigned Part) { + /// Set the generated vector Value for a given VPValue and a given Part, if \p + /// IsScalar is false. If \p IsScalar is true, set the scalar in (Part, 0). + void set(VPValue *Def, Value *V, unsigned Part, bool IsScalar = false) { + if (IsScalar) { + set(Def, V, VPIteration(Part, 0)); + return; + } + assert((VF.isScalar() || V->getType()->isVectorTy()) && + "scalar values must be stored as (Part, 0)"); if (!Data.PerPartOutput.count(Def)) { DataState::PerPartValuesTy Entry(UF); Data.PerPartOutput[Def] = Entry; } Data.PerPartOutput[Def][Part] = V; } + /// Reset an existing vector value for \p Def and a given \p Part. 
void reset(VPValue *Def, Value *V, unsigned Part) { auto Iter = Data.PerPartOutput.find(Def); @@ -307,12 +330,12 @@ struct VPTransformState { void set(VPValue *Def, Value *V, const VPIteration &Instance) { auto Iter = Data.PerPartScalars.insert({Def, {}}); auto &PerPartVec = Iter.first->second; - while (PerPartVec.size() <= Instance.Part) - PerPartVec.emplace_back(); + if (PerPartVec.size() <= Instance.Part) + PerPartVec.resize(Instance.Part + 1); auto &Scalars = PerPartVec[Instance.Part]; unsigned CacheIdx = Instance.Lane.mapToCacheIndex(VF); - while (Scalars.size() <= CacheIdx) - Scalars.push_back(nullptr); + if (Scalars.size() <= CacheIdx) + Scalars.resize(CacheIdx + 1); assert(!Scalars[CacheIdx] && "should overwrite existing value"); Scalars[CacheIdx] = V; } @@ -342,11 +365,7 @@ struct VPTransformState { /// This includes both the original MDs from \p From and additional ones (\see /// addNewMetadata). Use this for *newly created* instructions in the vector /// loop. - void addMetadata(Instruction *To, Instruction *From); - - /// Similar to the previous function but it adds the metadata to a - /// vector of instructions. - void addMetadata(ArrayRef<Value *> To, Instruction *From); + void addMetadata(Value *To, Instruction *From); /// Set the debug location in the builder using the debug location \p DL. void setDebugLocFrom(DebugLoc DL); @@ -372,7 +391,11 @@ struct VPTransformState { /// of replication, maps the BasicBlock of the last replica created. SmallDenseMap<VPBasicBlock *, BasicBlock *> VPBB2IRBB; - CFGState() = default; + /// Updater for the DominatorTree. + DomTreeUpdater DTU; + + CFGState(DominatorTree *DT) + : DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy) {} /// Returns the BasicBlock* mapped to the pre-header of the loop region /// containing \p R. @@ -382,17 +405,9 @@ struct VPTransformState { /// Hold a pointer to LoopInfo to register new basic blocks in the loop. LoopInfo *LI; - /// Hold a pointer to Dominator Tree to register new basic blocks in the loop. - DominatorTree *DT; - /// Hold a reference to the IRBuilder used to generate output IR code. IRBuilderBase &Builder; - VPValue2ValueTy VPValue2Value; - - /// Hold the canonical scalar IV of the vector loop (start=0, step=VF*UF). - Value *CanonicalIV = nullptr; - /// Hold a pointer to InnerLoopVectorizer to reuse its IR generation methods. InnerLoopVectorizer *ILV; @@ -476,7 +491,7 @@ public: /// that are actually instantiated. Values of this enumeration are kept in the /// SubclassID field of the VPBlockBase objects. They are used for concrete /// type identification. - using VPBlockTy = enum { VPBasicBlockSC, VPRegionBlockSC }; + using VPBlockTy = enum { VPRegionBlockSC, VPBasicBlockSC, VPIRBasicBlockSC }; using VPBlocksTy = SmallVectorImpl<VPBlockBase *>; @@ -611,6 +626,15 @@ public: appendPredecessor(Pred); } + /// Set each VPBasicBlock in \p NewSuccss as successor of this VPBlockBase. + /// This VPBlockBase must have no successors. This VPBlockBase is not added + /// as predecessor of any VPBasicBlock in \p NewSuccs. + void setSuccessors(ArrayRef<VPBlockBase *> NewSuccs) { + assert(Successors.empty() && "Block successors already set."); + for (auto *Succ : NewSuccs) + appendSuccessor(Succ); + } + /// Remove all the predecessor of this block. void clearPredecessors() { Predecessors.clear(); } @@ -621,6 +645,9 @@ public: /// VPBlockBase, thereby "executing" the VPlan. virtual void execute(VPTransformState *State) = 0; + /// Return the cost of the block. 
+ virtual InstructionCost cost(ElementCount VF, VPCostContext &Ctx) = 0; + /// Delete all blocks reachable from a given VPBlockBase, inclusive. static void deleteCFG(VPBlockBase *Entry); @@ -662,10 +689,18 @@ public: /// Dump this VPBlockBase to dbgs(). LLVM_DUMP_METHOD void dump() const { print(dbgs()); } #endif + + /// Clone the current block and it's recipes without updating the operands of + /// the cloned recipes, including all blocks in the single-entry single-exit + /// region for VPRegionBlocks. + virtual VPBlockBase *clone() = 0; }; /// A value that is used outside the VPlan. The operand of the user needs to be -/// added to the associated LCSSA phi node. +/// added to the associated phi node. The incoming block from VPlan is +/// determined by where the VPValue is defined: if it is defined by a recipe +/// outside a region, its parent block is used, otherwise the middle block is +/// used. class VPLiveOut : public VPUser { PHINode *Phi; @@ -677,11 +712,10 @@ public: return U->getVPUserID() == VPUser::VPUserID::LiveOut; } - /// Fixup the wrapped LCSSA phi node in the unique exit block. This simply - /// means we need to add the appropriate incoming value from the middle - /// block as exiting edges from the scalar epilogue loop (if present) are - /// already in place, and we exit the vector loop exclusively to the middle - /// block. + /// Fix the wrapped phi node. This means adding an incoming value to exit + /// block phi's from the vector loop via middle block (values from scalar loop + /// already reach these phi's), and updating the value to scalar header phi's + /// from the scalar preheader. void fixPhi(VPlan &Plan, VPTransformState &State); /// Returns true if the VPLiveOut uses scalars of operand \p Op. @@ -699,6 +733,27 @@ public: #endif }; +/// Struct to hold various analysis needed for cost computations. +struct VPCostContext { + const TargetTransformInfo &TTI; + VPTypeAnalysis Types; + LLVMContext &LLVMCtx; + LoopVectorizationCostModel &CM; + SmallPtrSet<Instruction *, 8> SkipCostComputation; + + VPCostContext(const TargetTransformInfo &TTI, Type *CanIVTy, + LLVMContext &LLVMCtx, LoopVectorizationCostModel &CM) + : TTI(TTI), Types(CanIVTy, LLVMCtx), LLVMCtx(LLVMCtx), CM(CM) {} + + /// Return the cost for \p UI with \p VF using the legacy cost model as + /// fallback until computing the cost of all recipes migrates to VPlan. + InstructionCost getLegacyCost(Instruction *UI, ElementCount VF) const; + + /// Return true if the cost for \p UI shouldn't be computed, e.g. because it + /// has already been pre-computed. + bool skipCostComputation(Instruction *UI, bool IsVector) const; +}; + /// VPRecipeBase is a base class modeling a sequence of one or more output IR /// instructions. VPRecipeBase owns the VPValues it defines through VPDef /// and is responsible for deleting its defined values. Single-value @@ -727,6 +782,9 @@ public: : VPDef(SC), VPUser(Operands, VPUser::VPUserID::Recipe), DL(DL) {} virtual ~VPRecipeBase() = default; + /// Clone the current recipe. + virtual VPRecipeBase *clone() = 0; + /// \return the VPBasicBlock which this VPRecipe belongs to. VPBasicBlock *getParent() { return Parent; } const VPBasicBlock *getParent() const { return Parent; } @@ -735,6 +793,11 @@ public: /// this VPRecipe, thereby "executing" the VPlan. virtual void execute(VPTransformState &State) = 0; + /// Return the cost of this recipe, taking into account if the cost + /// computation should be skipped and the ForceTargetInstructionCost flag. 
+ /// Also takes care of printing the cost for debugging. + virtual InstructionCost cost(ElementCount VF, VPCostContext &Ctx); + /// Insert an unlinked recipe into a basic block immediately before /// the specified recipe. void insertBefore(VPRecipeBase *InsertPos); @@ -795,6 +858,11 @@ public: /// Returns the debug location of the recipe. DebugLoc getDebugLoc() const { return DL; } + +protected: + /// Compute the cost of this recipe using the legacy cost model and the + /// underlying instructions. + InstructionCost computeCost(ElementCount VF, VPCostContext &Ctx) const; }; // Helper macro to define common classof implementations for recipes. @@ -838,8 +906,10 @@ public: static inline bool classof(const VPRecipeBase *R) { switch (R->getVPDefID()) { case VPRecipeBase::VPDerivedIVSC: + case VPRecipeBase::VPEVLBasedIVPHISC: case VPRecipeBase::VPExpandSCEVSC: case VPRecipeBase::VPInstructionSC: + case VPRecipeBase::VPReductionEVLSC: case VPRecipeBase::VPReductionSC: case VPRecipeBase::VPReplicateSC: case VPRecipeBase::VPScalarIVStepsSC: @@ -859,10 +929,14 @@ public: case VPRecipeBase::VPWidenIntOrFpInductionSC: case VPRecipeBase::VPWidenPointerInductionSC: case VPRecipeBase::VPReductionPHISC: + case VPRecipeBase::VPScalarCastSC: return true; case VPRecipeBase::VPInterleaveSC: case VPRecipeBase::VPBranchOnMaskSC: - case VPRecipeBase::VPWidenMemoryInstructionSC: + case VPRecipeBase::VPWidenLoadEVLSC: + case VPRecipeBase::VPWidenLoadSC: + case VPRecipeBase::VPWidenStoreEVLSC: + case VPRecipeBase::VPWidenStoreSC: // TODO: Widened stores don't define a value, but widened loads do. Split // the recipes to be able to make widened loads VPSingleDefRecipes. return false; @@ -875,6 +949,8 @@ public: return R && classof(R); } + virtual VPSingleDefRecipe *clone() override = 0; + /// Returns the underlying instruction. Instruction *getUnderlyingInstr() { return cast<Instruction>(getUnderlyingValue()); @@ -905,6 +981,11 @@ public: WrapFlagsTy(bool HasNUW, bool HasNSW) : HasNUW(HasNUW), HasNSW(HasNSW) {} }; + struct DisjointFlagsTy { + char IsDisjoint : 1; + DisjointFlagsTy(bool IsDisjoint) : IsDisjoint(IsDisjoint) {} + }; + protected: struct GEPFlagsTy { char IsInBounds : 1; @@ -912,9 +993,6 @@ protected: }; private: - struct DisjointFlagsTy { - char IsDisjoint : 1; - }; struct ExactFlagsTy { char IsExact : 1; }; @@ -946,6 +1024,12 @@ private: unsigned AllFlags; }; +protected: + void transferFlags(VPRecipeWithIRFlags &Other) { + OpType = Other.OpType; + AllFlags = Other.AllFlags; + } + public: template <typename IterT> VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, DebugLoc DL = {}) @@ -1002,6 +1086,12 @@ public: : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::FPMathOp), FMFs(FMFs) {} + template <typename IterT> + VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, + DisjointFlagsTy DisjointFlags, DebugLoc DL = {}) + : VPSingleDefRecipe(SC, Operands, DL), OpType(OperationType::DisjointOp), + DisjointFlags(DisjointFlags) {} + protected: template <typename IterT> VPRecipeWithIRFlags(const unsigned char SC, IterT Operands, @@ -1019,6 +1109,11 @@ public: R->getVPDefID() == VPRecipeBase::VPVectorPointerSC; } + static inline bool classof(const VPUser *U) { + auto *R = dyn_cast<VPRecipeBase>(U); + return R && classof(R); + } + /// Drop all poison-generating flags. 
void dropPoisonGeneratingFlags() { // NOTE: This needs to be kept in-sync with @@ -1064,7 +1159,10 @@ public: I->setIsExact(ExactFlags.IsExact); break; case OperationType::GEPOp: - cast<GetElementPtrInst>(I)->setIsInBounds(GEPFlags.IsInBounds); + // TODO(gep_nowrap): Track the full GEPNoWrapFlags in VPlan. + cast<GetElementPtrInst>(I)->setNoWrapFlags( + GEPFlags.IsInBounds ? GEPNoWrapFlags::inBounds() + : GEPNoWrapFlags::none()); break; case OperationType::FPMathOp: I->setHasAllowReassoc(FMFs.AllowReassoc); @@ -1113,6 +1211,12 @@ public: return WrapFlags.HasNSW; } + bool isDisjoint() const { + assert(OpType == OperationType::DisjointOp && + "recipe cannot have a disjoing flag"); + return DisjointFlags.IsDisjoint; + } + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void printFlags(raw_ostream &O) const; #endif @@ -1135,12 +1239,29 @@ public: SLPLoad, SLPStore, ActiveLaneMask, + ExplicitVectorLength, + /// Creates a scalar phi in a leaf VPBB with a single predecessor in VPlan. + /// The first operand is the incoming value from the predecessor in VPlan, + /// the second operand is the incoming value for all other predecessors + /// (which are currently not modeled in VPlan). + ResumePhi, CalculateTripCountMinusVF, // Increment the canonical IV separately for each unrolled part. CanonicalIVIncrementForPart, BranchOnCount, BranchOnCond, ComputeReductionResult, + // Takes the VPValue to extract from as first operand and the lane or part + // to extract as second operand, counting from the end starting with 1 for + // last. The second operand must be a positive constant and <= VF when + // extracting from a vector or <= UF when extracting from an unrolled + // scalar. + ExtractFromEnd, + LogicalAnd, // Non-poison propagating logical And. + // Add an offset in bytes (second operand) to a base pointer (first + // operand). Only generates scalar values (either for the first lane only or + // for all lanes, depending on its uses). + PtrAdd, }; private: @@ -1150,11 +1271,28 @@ private: /// An optional name that can be used for the generated IR instruction. const std::string Name; - /// Utility method serving execute(): generates a single instance of the - /// modeled instruction. \returns the generated value for \p Part. - /// In some cases an existing value is returned rather than a generated + /// Returns true if this VPInstruction generates scalar values for all lanes. + /// Most VPInstructions generate a single value per part, either vector or + /// scalar. VPReplicateRecipe takes care of generating multiple (scalar) + /// values per all lanes, stemming from an original ingredient. This method + /// identifies the (rare) cases of VPInstructions that do so as well, w/o an + /// underlying ingredient. + bool doesGeneratePerAllLanes() const; + + /// Returns true if we can generate a scalar for the first lane only if + /// needed. + bool canGenerateScalarForFirstLane() const; + + /// Utility methods serving execute(): generates a single instance of the + /// modeled instruction for a given part. \returns the generated value for \p + /// Part. In some cases an existing value is returned rather than a generated /// one. - Value *generateInstruction(VPTransformState &State, unsigned Part); + Value *generatePerPart(VPTransformState &State, unsigned Part); + + /// Utility methods serving execute(): generates a scalar single instance of + /// the modeled instruction for a given lane. \returns the scalar generated + /// value for lane \p Lane. 
+ Value *generatePerLane(VPTransformState &State, const VPIteration &Lane); #if !defined(NDEBUG) /// Return true if the VPInstruction is a floating point math operation, i.e. @@ -1162,9 +1300,6 @@ private: bool isFPMathOp() const; #endif -protected: - void setUnderlyingInstr(Instruction *I) { setUnderlyingValue(I); } - public: VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands, DebugLoc DL, const Twine &Name = "") @@ -1184,10 +1319,25 @@ public: Opcode(Opcode), Name(Name.str()) {} VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands, + DisjointFlagsTy DisjointFlag, DebugLoc DL = {}, + const Twine &Name = "") + : VPRecipeWithIRFlags(VPDef::VPInstructionSC, Operands, DisjointFlag, DL), + Opcode(Opcode), Name(Name.str()) { + assert(Opcode == Instruction::Or && "only OR opcodes can be disjoint"); + } + + VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands, FastMathFlags FMFs, DebugLoc DL = {}, const Twine &Name = ""); VP_CLASSOF_IMPL(VPDef::VPInstructionSC) + VPInstruction *clone() override { + SmallVector<VPValue *, 2> Operands(operands()); + auto *New = new VPInstruction(Opcode, Operands, getDebugLoc(), Name); + New->transferFlags(*this); + return New; + } + unsigned getOpcode() const { return Opcode; } /// Generate the instruction. @@ -1235,42 +1385,24 @@ public: } /// Returns true if the recipe only uses the first lane of operand \p Op. - bool onlyFirstLaneUsed(const VPValue *Op) const override { - assert(is_contained(operands(), Op) && - "Op must be an operand of the recipe"); - if (getOperand(0) != Op) - return false; - switch (getOpcode()) { - default: - return false; - case VPInstruction::ActiveLaneMask: - case VPInstruction::CalculateTripCountMinusVF: - case VPInstruction::CanonicalIVIncrementForPart: - case VPInstruction::BranchOnCount: - return true; - }; - llvm_unreachable("switch should return"); - } + bool onlyFirstLaneUsed(const VPValue *Op) const override; /// Returns true if the recipe only uses the first part of operand \p Op. - bool onlyFirstPartUsed(const VPValue *Op) const override { - assert(is_contained(operands(), Op) && - "Op must be an operand of the recipe"); - if (getOperand(0) != Op) - return false; - switch (getOpcode()) { - default: - return false; - case VPInstruction::BranchOnCount: - return true; - }; - llvm_unreachable("switch should return"); - } + bool onlyFirstPartUsed(const VPValue *Op) const override; + + /// Returns true if this VPInstruction produces a scalar value from a vector, + /// e.g. by performing a reduction or extracting a lane. + bool isVectorToScalar() const; + + /// Returns true if this VPInstruction's operands are single scalars and the + /// result is also a single scalar. + bool isSingleScalar() const; }; -/// VPWidenRecipe is a recipe for producing a copy of vector type its -/// ingredient. This recipe covers most of the traditional vectorization cases -/// where each ingredient transforms into a vectorized version of itself. +/// VPWidenRecipe is a recipe for producing a widened instruction using the +/// opcode and operands of the recipe. This recipe covers most of the +/// traditional vectorization cases where each recipe transforms into a +/// vectorized version of itself. 
class VPWidenRecipe : public VPRecipeWithIRFlags { unsigned Opcode; @@ -1282,9 +1414,16 @@ public: ~VPWidenRecipe() override = default; + VPWidenRecipe *clone() override { + auto *R = new VPWidenRecipe(*getUnderlyingInstr(), operands()); + R->transferFlags(*this); + return R; + } + VP_CLASSOF_IMPL(VPDef::VPWidenSC) - /// Produce widened copies of all Ingredients. + /// Produce a widened instruction using the opcode and operands of the recipe, + /// processing State.VF elements. void execute(VPTransformState &State) override; unsigned getOpcode() const { return Opcode; } @@ -1311,8 +1450,6 @@ public: ResultTy(ResultTy) { assert(UI.getOpcode() == Opcode && "opcode of underlying cast doesn't match"); - assert(UI.getType() == ResultTy && - "result type of underlying cast doesn't match"); } VPWidenCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy) @@ -1321,6 +1458,14 @@ public: ~VPWidenCastRecipe() override = default; + VPWidenCastRecipe *clone() override { + if (auto *UV = getUnderlyingValue()) + return new VPWidenCastRecipe(Opcode, getOperand(0), ResultTy, + *cast<CastInst>(UV)); + + return new VPWidenCastRecipe(Opcode, getOperand(0), ResultTy); + } + VP_CLASSOF_IMPL(VPDef::VPWidenCastSC) /// Produce widened copies of the cast. @@ -1338,6 +1483,45 @@ public: Type *getResultType() const { return ResultTy; } }; +/// VPScalarCastRecipe is a recipe to create scalar cast instructions. +class VPScalarCastRecipe : public VPSingleDefRecipe { + Instruction::CastOps Opcode; + + Type *ResultTy; + + Value *generate(VPTransformState &State, unsigned Part); + +public: + VPScalarCastRecipe(Instruction::CastOps Opcode, VPValue *Op, Type *ResultTy) + : VPSingleDefRecipe(VPDef::VPScalarCastSC, {Op}), Opcode(Opcode), + ResultTy(ResultTy) {} + + ~VPScalarCastRecipe() override = default; + + VPScalarCastRecipe *clone() override { + return new VPScalarCastRecipe(Opcode, getOperand(0), ResultTy); + } + + VP_CLASSOF_IMPL(VPDef::VPScalarCastSC) + + void execute(VPTransformState &State) override; + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + void print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const override; +#endif + + /// Returns the result type of the cast. + Type *getResultType() const { return ResultTy; } + + bool onlyFirstLaneUsed(const VPValue *Op) const override { + // At the moment, only uniform codegen is implemented. + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + return true; + } +}; + /// A recipe for widening Call instructions. class VPWidenCallRecipe : public VPSingleDefRecipe { /// ID of the vector intrinsic to call when widening the call. 
If set the @@ -1351,19 +1535,39 @@ class VPWidenCallRecipe : public VPSingleDefRecipe { public: template <typename IterT> - VPWidenCallRecipe(CallInst &I, iterator_range<IterT> CallArguments, + VPWidenCallRecipe(Value *UV, iterator_range<IterT> CallArguments, Intrinsic::ID VectorIntrinsicID, DebugLoc DL = {}, Function *Variant = nullptr) - : VPSingleDefRecipe(VPDef::VPWidenCallSC, CallArguments, &I, DL), - VectorIntrinsicID(VectorIntrinsicID), Variant(Variant) {} + : VPSingleDefRecipe(VPDef::VPWidenCallSC, CallArguments, UV, DL), + VectorIntrinsicID(VectorIntrinsicID), Variant(Variant) { + assert( + isa<Function>(getOperand(getNumOperands() - 1)->getLiveInIRValue()) && + "last operand must be the called function"); + } ~VPWidenCallRecipe() override = default; + VPWidenCallRecipe *clone() override { + return new VPWidenCallRecipe(getUnderlyingValue(), operands(), + VectorIntrinsicID, getDebugLoc(), Variant); + } + VP_CLASSOF_IMPL(VPDef::VPWidenCallSC) /// Produce a widened version of the call instruction. void execute(VPTransformState &State) override; + Function *getCalledScalarFunction() const { + return cast<Function>(getOperand(getNumOperands() - 1)->getLiveInIRValue()); + } + + operand_range arg_operands() { + return make_range(op_begin(), op_begin() + getNumOperands() - 1); + } + const_operand_range arg_operands() const { + return make_range(op_begin(), op_begin() + getNumOperands() - 1); + } + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) /// Print the recipe. void print(raw_ostream &O, const Twine &Indent, @@ -1380,6 +1584,11 @@ struct VPWidenSelectRecipe : public VPSingleDefRecipe { ~VPWidenSelectRecipe() override = default; + VPWidenSelectRecipe *clone() override { + return new VPWidenSelectRecipe(*cast<SelectInst>(getUnderlyingInstr()), + operands()); + } + VP_CLASSOF_IMPL(VPDef::VPWidenSelectSC) /// Produce a widened version of the select instruction. @@ -1423,6 +1632,11 @@ public: ~VPWidenGEPRecipe() override = default; + VPWidenGEPRecipe *clone() override { + return new VPWidenGEPRecipe(cast<GetElementPtrInst>(getUnderlyingInstr()), + operands()); + } + VP_CLASSOF_IMPL(VPDef::VPWidenGEPSC) /// Generate the gep nodes. @@ -1459,6 +1673,11 @@ public: return true; } + VPVectorPointerRecipe *clone() override { + return new VPVectorPointerRecipe(getOperand(0), IndexedTy, IsReverse, + isInBounds(), getDebugLoc()); + } + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) /// Print the recipe. void print(raw_ostream &O, const Twine &Indent, @@ -1569,6 +1788,11 @@ public: ~VPWidenIntOrFpInductionRecipe() override = default; + VPWidenIntOrFpInductionRecipe *clone() override { + return new VPWidenIntOrFpInductionRecipe(IV, getStartValue(), + getStepValue(), IndDesc, Trunc); + } + VP_CLASSOF_IMPL(VPDef::VPWidenIntOrFpInductionSC) /// Generate the vectorized and scalarized versions of the phi node as @@ -1610,7 +1834,8 @@ public: const InductionDescriptor &getInductionDescriptor() const { return IndDesc; } /// Returns true if the induction is canonical, i.e. starting at 0 and - /// incremented by UF * VF (= the original IV is incremented by 1). + /// incremented by UF * VF (= the original IV is incremented by 1) and has the + /// same type as the canonical induction. bool isCanonical() const; /// Returns the scalar type of the induction. 
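For context on the reworked VPWidenCallRecipe operand layout in the hunk above (the scalar callee is now carried as the trailing operand), here is a minimal sketch of how a consumer is expected to separate the call arguments from the callee. countWidenedCallArgs is a hypothetical helper written only against the accessors visible in this patch, not something the change itself adds:

static unsigned countWidenedCallArgs(VPWidenCallRecipe &R) {
  // The scalar callee is the last operand, exposed via the new accessor.
  Function *Callee = R.getCalledScalarFunction();
  (void)Callee;
  // arg_operands() deliberately excludes that trailing callee operand.
  unsigned NumArgs = 0;
  for (VPValue *Arg : R.arg_operands()) {
    (void)Arg;
    ++NumArgs;
  }
  assert(NumArgs == R.getNumOperands() - 1 &&
         "callee operand must not be counted as a call argument");
  return NumArgs;
}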
@@ -1639,13 +1864,19 @@ public: ~VPWidenPointerInductionRecipe() override = default; + VPWidenPointerInductionRecipe *clone() override { + return new VPWidenPointerInductionRecipe( + cast<PHINode>(getUnderlyingInstr()), getOperand(0), getOperand(1), + IndDesc, IsScalarAfterVectorization); + } + VP_CLASSOF_IMPL(VPDef::VPWidenPointerInductionSC) /// Generate vector values for the pointer induction. void execute(VPTransformState &State) override; /// Returns true if only scalar values will be generated. - bool onlyScalarsGenerated(ElementCount VF); + bool onlyScalarsGenerated(bool IsScalable); /// Returns the induction descriptor for the recipe. const InductionDescriptor &getInductionDescriptor() const { return IndDesc; } @@ -1657,21 +1888,25 @@ public: #endif }; -/// A recipe for handling header phis that are widened in the vector loop. +/// A recipe for handling phis that are widened in the vector loop. /// In the VPlan native path, all incoming VPValues & VPBasicBlock pairs are /// managed in the recipe directly. -class VPWidenPHIRecipe : public VPHeaderPHIRecipe { +class VPWidenPHIRecipe : public VPSingleDefRecipe { /// List of incoming blocks. Only used in the VPlan native path. SmallVector<VPBasicBlock *, 2> IncomingBlocks; public: /// Create a new VPWidenPHIRecipe for \p Phi with start value \p Start. VPWidenPHIRecipe(PHINode *Phi, VPValue *Start = nullptr) - : VPHeaderPHIRecipe(VPDef::VPWidenPHISC, Phi) { + : VPSingleDefRecipe(VPDef::VPWidenPHISC, ArrayRef<VPValue *>(), Phi) { if (Start) addOperand(Start); } + VPWidenPHIRecipe *clone() override { + llvm_unreachable("cloning not implemented yet"); + } + ~VPWidenPHIRecipe() override = default; VP_CLASSOF_IMPL(VPDef::VPWidenPHISC) @@ -1711,6 +1946,11 @@ struct VPFirstOrderRecurrencePHIRecipe : public VPHeaderPHIRecipe { return R->getVPDefID() == VPDef::VPFirstOrderRecurrencePHISC; } + VPFirstOrderRecurrencePHIRecipe *clone() override { + return new VPFirstOrderRecurrencePHIRecipe( + cast<PHINode>(getUnderlyingInstr()), *getOperand(0)); + } + void execute(VPTransformState &State) override; #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -1746,6 +1986,14 @@ public: ~VPReductionPHIRecipe() override = default; + VPReductionPHIRecipe *clone() override { + auto *R = + new VPReductionPHIRecipe(cast<PHINode>(getUnderlyingInstr()), RdxDesc, + *getOperand(0), IsInLoop, IsOrdered); + R->addOperand(getBackedgeValue()); + return R; + } + VP_CLASSOF_IMPL(VPDef::VPReductionPHISC) static inline bool classof(const VPHeaderPHIRecipe *R) { @@ -1777,27 +2025,35 @@ public: class VPBlendRecipe : public VPSingleDefRecipe { public: /// The blend operation is a User of the incoming values and of their - /// respective masks, ordered [I0, M0, I1, M1, ...]. Note that a single value - /// might be incoming with a full mask for which there is no VPValue. + /// respective masks, ordered [I0, I1, M1, I2, M2, ...]. Note that the first + /// incoming value does not have a mask associated. 
VPBlendRecipe(PHINode *Phi, ArrayRef<VPValue *> Operands) : VPSingleDefRecipe(VPDef::VPBlendSC, Operands, Phi, Phi->getDebugLoc()) { - assert(Operands.size() > 0 && - ((Operands.size() == 1) || (Operands.size() % 2 == 0)) && - "Expected either a single incoming value or a positive even number " - "of operands"); + assert((Operands.size() + 1) % 2 == 0 && + "Expected an odd number of operands"); + } + + VPBlendRecipe *clone() override { + SmallVector<VPValue *> Ops(operands()); + return new VPBlendRecipe(cast<PHINode>(getUnderlyingValue()), Ops); } VP_CLASSOF_IMPL(VPDef::VPBlendSC) - /// Return the number of incoming values, taking into account that a single + /// Return the number of incoming values, taking into account that the first /// incoming value has no mask. unsigned getNumIncomingValues() const { return (getNumOperands() + 1) / 2; } /// Return incoming value number \p Idx. - VPValue *getIncomingValue(unsigned Idx) const { return getOperand(Idx * 2); } + VPValue *getIncomingValue(unsigned Idx) const { + return Idx == 0 ? getOperand(0) : getOperand(Idx * 2 - 1); + } /// Return mask number \p Idx. - VPValue *getMask(unsigned Idx) const { return getOperand(Idx * 2 + 1); } + VPValue *getMask(unsigned Idx) const { + assert(Idx > 0 && "First index has no mask associated."); + return getOperand(Idx * 2); + } /// Generate the phi/select nodes. void execute(VPTransformState &State) override; @@ -1856,6 +2112,11 @@ public: } ~VPInterleaveRecipe() override = default; + VPInterleaveRecipe *clone() override { + return new VPInterleaveRecipe(IG, getAddr(), getStoredValues(), getMask(), + NeedsMaskForGaps); + } + VP_CLASSOF_IMPL(VPDef::VPInterleaveSC) /// Return the address accessed by this recipe. @@ -1902,6 +2163,8 @@ public: "Op must be an operand of the recipe"); return Op == getAddr() && !llvm::is_contained(getStoredValues(), Op); } + + Instruction *getInsertPos() const { return IG->getInsertPos(); } }; /// A recipe to represent inloop reduction operations, performing a reduction on @@ -1910,20 +2173,45 @@ public: class VPReductionRecipe : public VPSingleDefRecipe { /// The recurrence decriptor for the reduction in question. const RecurrenceDescriptor &RdxDesc; + bool IsOrdered; + /// Whether the reduction is conditional. 
+ bool IsConditional = false; -public: - VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I, - VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp) - : VPSingleDefRecipe(VPDef::VPReductionSC, - ArrayRef<VPValue *>({ChainOp, VecOp}), I), - RdxDesc(R) { - if (CondOp) +protected: + VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R, + Instruction *I, ArrayRef<VPValue *> Operands, + VPValue *CondOp, bool IsOrdered) + : VPSingleDefRecipe(SC, Operands, I), RdxDesc(R), IsOrdered(IsOrdered) { + if (CondOp) { + IsConditional = true; addOperand(CondOp); + } } +public: + VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I, + VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp, + bool IsOrdered) + : VPReductionRecipe(VPDef::VPReductionSC, R, I, + ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp, + IsOrdered) {} + ~VPReductionRecipe() override = default; - VP_CLASSOF_IMPL(VPDef::VPReductionSC) + VPReductionRecipe *clone() override { + return new VPReductionRecipe(RdxDesc, getUnderlyingInstr(), getChainOp(), + getVecOp(), getCondOp(), IsOrdered); + } + + static inline bool classof(const VPRecipeBase *R) { + return R->getVPDefID() == VPRecipeBase::VPReductionSC || + R->getVPDefID() == VPRecipeBase::VPReductionEVLSC; + } + + static inline bool classof(const VPUser *U) { + auto *R = dyn_cast<VPRecipeBase>(U); + return R && classof(R); + } /// Generate the reduction in the loop void execute(VPTransformState &State) override; @@ -1934,13 +2222,62 @@ public: VPSlotTracker &SlotTracker) const override; #endif + /// Return the recurrence decriptor for the in-loop reduction. + const RecurrenceDescriptor &getRecurrenceDescriptor() const { + return RdxDesc; + } + /// Return true if the in-loop reduction is ordered. + bool isOrdered() const { return IsOrdered; }; + /// Return true if the in-loop reduction is conditional. + bool isConditional() const { return IsConditional; }; /// The VPValue of the scalar Chain being accumulated. VPValue *getChainOp() const { return getOperand(0); } /// The VPValue of the vector value to be reduced. VPValue *getVecOp() const { return getOperand(1); } /// The VPValue of the condition for the block. VPValue *getCondOp() const { - return getNumOperands() > 2 ? getOperand(2) : nullptr; + return isConditional() ? getOperand(getNumOperands() - 1) : nullptr; + } +}; + +/// A recipe to represent inloop reduction operations with vector-predication +/// intrinsics, performing a reduction on a vector operand with the explicit +/// vector length (EVL) into a scalar value, and adding the result to a chain. +/// The Operands are {ChainOp, VecOp, EVL, [Condition]}. +class VPReductionEVLRecipe : public VPReductionRecipe { +public: + VPReductionEVLRecipe(VPReductionRecipe *R, VPValue *EVL, VPValue *CondOp) + : VPReductionRecipe( + VPDef::VPReductionEVLSC, R->getRecurrenceDescriptor(), + cast_or_null<Instruction>(R->getUnderlyingValue()), + ArrayRef<VPValue *>({R->getChainOp(), R->getVecOp(), EVL}), CondOp, + R->isOrdered()) {} + + ~VPReductionEVLRecipe() override = default; + + VPReductionEVLRecipe *clone() override { + llvm_unreachable("cloning not implemented yet"); + } + + VP_CLASSOF_IMPL(VPDef::VPReductionEVLSC) + + /// Generate the reduction in the loop + void execute(VPTransformState &State) override; + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const override; +#endif + + /// The VPValue of the explicit vector length. 
+ VPValue *getEVL() const { return getOperand(2); } + + /// Returns true if the recipe only uses the first lane of operand \p Op. + bool onlyFirstLaneUsed(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + return Op == getEVL(); } }; @@ -1967,6 +2304,14 @@ public: ~VPReplicateRecipe() override = default; + VPReplicateRecipe *clone() override { + auto *Copy = + new VPReplicateRecipe(getUnderlyingInstr(), operands(), IsUniform, + isPredicated() ? getMask() : nullptr); + Copy->transferFlags(*this); + return Copy; + } + VP_CLASSOF_IMPL(VPDef::VPReplicateSC) /// Generate replicas of the desired Ingredient. Replicas will be generated @@ -2008,6 +2353,8 @@ public: assert(isPredicated() && "Trying to get the mask of a unpredicated recipe"); return getOperand(getNumOperands() - 1); } + + unsigned getOpcode() const { return getUnderlyingInstr()->getOpcode(); } }; /// A recipe for generating conditional branches on the bits of a mask. @@ -2019,6 +2366,10 @@ public: addOperand(BlockInMask); } + VPBranchOnMaskRecipe *clone() override { + return new VPBranchOnMaskRecipe(getOperand(0)); + } + VP_CLASSOF_IMPL(VPDef::VPBranchOnMaskSC) /// Generate the extraction of the appropriate bit from the block mask and the @@ -2066,6 +2417,10 @@ public: : VPSingleDefRecipe(VPDef::VPPredInstPHISC, PredV) {} ~VPPredInstPHIRecipe() override = default; + VPPredInstPHIRecipe *clone() override { + return new VPPredInstPHIRecipe(getOperand(0)); + } + VP_CLASSOF_IMPL(VPDef::VPPredInstPHISC) /// Generates phi nodes for live-outs as needed to retain SSA form. @@ -2085,56 +2440,66 @@ public: } }; -/// A Recipe for widening load/store operations. -/// The recipe uses the following VPValues: -/// - For load: Address, optional mask -/// - For store: Address, stored value, optional mask -/// TODO: We currently execute only per-part unless a specific instance is -/// provided. -class VPWidenMemoryInstructionRecipe : public VPRecipeBase { +/// A common base class for widening memory operations. An optional mask can be +/// provided as the last operand. +class VPWidenMemoryRecipe : public VPRecipeBase { +protected: Instruction &Ingredient; - // Whether the loaded-from / stored-to addresses are consecutive. + /// Whether the accessed addresses are consecutive. bool Consecutive; - // Whether the consecutive loaded/stored addresses are in reverse order. + /// Whether the consecutive accessed addresses are in reverse order. bool Reverse; + /// Whether the memory access is masked. + bool IsMasked = false; + void setMask(VPValue *Mask) { + assert(!IsMasked && "cannot re-set mask"); if (!Mask) return; addOperand(Mask); + IsMasked = true; } - bool isMasked() const { - return isStore() ? 
getNumOperands() == 3 : getNumOperands() == 2; + VPWidenMemoryRecipe(const char unsigned SC, Instruction &I, + std::initializer_list<VPValue *> Operands, + bool Consecutive, bool Reverse, DebugLoc DL) + : VPRecipeBase(SC, Operands, DL), Ingredient(I), Consecutive(Consecutive), + Reverse(Reverse) { + assert((Consecutive || !Reverse) && "Reverse implies consecutive"); } public: - VPWidenMemoryInstructionRecipe(LoadInst &Load, VPValue *Addr, VPValue *Mask, - bool Consecutive, bool Reverse) - : VPRecipeBase(VPDef::VPWidenMemoryInstructionSC, {Addr}), - Ingredient(Load), Consecutive(Consecutive), Reverse(Reverse) { - assert((Consecutive || !Reverse) && "Reverse implies consecutive"); - new VPValue(this, &Load); - setMask(Mask); + VPWidenMemoryRecipe *clone() override { + llvm_unreachable("cloning not supported"); } - VPWidenMemoryInstructionRecipe(StoreInst &Store, VPValue *Addr, - VPValue *StoredValue, VPValue *Mask, - bool Consecutive, bool Reverse) - : VPRecipeBase(VPDef::VPWidenMemoryInstructionSC, {Addr, StoredValue}), - Ingredient(Store), Consecutive(Consecutive), Reverse(Reverse) { - assert((Consecutive || !Reverse) && "Reverse implies consecutive"); - setMask(Mask); + static inline bool classof(const VPRecipeBase *R) { + return R->getVPDefID() == VPRecipeBase::VPWidenLoadSC || + R->getVPDefID() == VPRecipeBase::VPWidenStoreSC || + R->getVPDefID() == VPRecipeBase::VPWidenLoadEVLSC || + R->getVPDefID() == VPRecipeBase::VPWidenStoreEVLSC; + } + + static inline bool classof(const VPUser *U) { + auto *R = dyn_cast<VPRecipeBase>(U); + return R && classof(R); } - VP_CLASSOF_IMPL(VPDef::VPWidenMemoryInstructionSC) + /// Return whether the loaded-from / stored-to addresses are consecutive. + bool isConsecutive() const { return Consecutive; } + + /// Return whether the consecutive loaded/stored addresses are in reverse + /// order. + bool isReverse() const { return Reverse; } /// Return the address accessed by this recipe. - VPValue *getAddr() const { - return getOperand(0); // Address is the 1st, mandatory operand. - } + VPValue *getAddr() const { return getOperand(0); } + + /// Returns true if the recipe is masked. + bool isMasked() const { return IsMasked; } /// Return the mask used by this recipe. Note that a full mask is represented /// by a nullptr. @@ -2143,23 +2508,34 @@ public: return isMasked() ? getOperand(getNumOperands() - 1) : nullptr; } - /// Returns true if this recipe is a store. - bool isStore() const { return isa<StoreInst>(Ingredient); } + /// Generate the wide load/store. + void execute(VPTransformState &State) override { + llvm_unreachable("VPWidenMemoryRecipe should not be instantiated."); + } + + Instruction &getIngredient() const { return Ingredient; } +}; - /// Return the address accessed by this recipe. - VPValue *getStoredValue() const { - assert(isStore() && "Stored value only available for store instructions"); - return getOperand(1); // Stored value is the 2nd, mandatory operand. +/// A recipe for widening load operations, using the address to load from and an +/// optional mask. +struct VPWidenLoadRecipe final : public VPWidenMemoryRecipe, public VPValue { + VPWidenLoadRecipe(LoadInst &Load, VPValue *Addr, VPValue *Mask, + bool Consecutive, bool Reverse, DebugLoc DL) + : VPWidenMemoryRecipe(VPDef::VPWidenLoadSC, Load, {Addr}, Consecutive, + Reverse, DL), + VPValue(this, &Load) { + setMask(Mask); } - // Return whether the loaded-from / stored-to addresses are consecutive. 
- bool isConsecutive() const { return Consecutive; } + VPWidenLoadRecipe *clone() override { + return new VPWidenLoadRecipe(cast<LoadInst>(Ingredient), getAddr(), + getMask(), Consecutive, Reverse, + getDebugLoc()); + } - // Return whether the consecutive loaded/stored addresses are in reverse - // order. - bool isReverse() const { return Reverse; } + VP_CLASSOF_IMPL(VPDef::VPWidenLoadSC); - /// Generate the wide load/store. + /// Generate a wide load or gather. void execute(VPTransformState &State) override; #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -2172,15 +2548,130 @@ public: bool onlyFirstLaneUsed(const VPValue *Op) const override { assert(is_contained(operands(), Op) && "Op must be an operand of the recipe"); + // Widened, consecutive loads operations only demand the first lane of + // their address. + return Op == getAddr() && isConsecutive(); + } +}; + +/// A recipe for widening load operations with vector-predication intrinsics, +/// using the address to load from, the explicit vector length and an optional +/// mask. +struct VPWidenLoadEVLRecipe final : public VPWidenMemoryRecipe, public VPValue { + VPWidenLoadEVLRecipe(VPWidenLoadRecipe *L, VPValue *EVL, VPValue *Mask) + : VPWidenMemoryRecipe(VPDef::VPWidenLoadEVLSC, L->getIngredient(), + {L->getAddr(), EVL}, L->isConsecutive(), + L->isReverse(), L->getDebugLoc()), + VPValue(this, &getIngredient()) { + setMask(Mask); + } + + VP_CLASSOF_IMPL(VPDef::VPWidenLoadEVLSC) + + /// Return the EVL operand. + VPValue *getEVL() const { return getOperand(1); } + + /// Generate the wide load or gather. + void execute(VPTransformState &State) override; + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const override; +#endif + /// Returns true if the recipe only uses the first lane of operand \p Op. + bool onlyFirstLaneUsed(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + // Widened loads only demand the first lane of EVL and consecutive loads + // only demand the first lane of their address. + return Op == getEVL() || (Op == getAddr() && isConsecutive()); + } +}; + +/// A recipe for widening store operations, using the stored value, the address +/// to store to and an optional mask. +struct VPWidenStoreRecipe final : public VPWidenMemoryRecipe { + VPWidenStoreRecipe(StoreInst &Store, VPValue *Addr, VPValue *StoredVal, + VPValue *Mask, bool Consecutive, bool Reverse, DebugLoc DL) + : VPWidenMemoryRecipe(VPDef::VPWidenStoreSC, Store, {Addr, StoredVal}, + Consecutive, Reverse, DL) { + setMask(Mask); + } + + VPWidenStoreRecipe *clone() override { + return new VPWidenStoreRecipe(cast<StoreInst>(Ingredient), getAddr(), + getStoredValue(), getMask(), Consecutive, + Reverse, getDebugLoc()); + } + + VP_CLASSOF_IMPL(VPDef::VPWidenStoreSC); + + /// Return the value stored by this recipe. + VPValue *getStoredValue() const { return getOperand(1); } + + /// Generate a wide store or scatter. + void execute(VPTransformState &State) override; + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const override; +#endif + + /// Returns true if the recipe only uses the first lane of operand \p Op. 
+ bool onlyFirstLaneUsed(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + // Widened, consecutive stores only demand the first lane of their address, + // unless the same operand is also stored. + return Op == getAddr() && isConsecutive() && Op != getStoredValue(); + } +}; + +/// A recipe for widening store operations with vector-predication intrinsics, +/// using the value to store, the address to store to, the explicit vector +/// length and an optional mask. +struct VPWidenStoreEVLRecipe final : public VPWidenMemoryRecipe { + VPWidenStoreEVLRecipe(VPWidenStoreRecipe *S, VPValue *EVL, VPValue *Mask) + : VPWidenMemoryRecipe(VPDef::VPWidenStoreEVLSC, S->getIngredient(), + {S->getAddr(), S->getStoredValue(), EVL}, + S->isConsecutive(), S->isReverse(), + S->getDebugLoc()) { + setMask(Mask); + } + + VP_CLASSOF_IMPL(VPDef::VPWidenStoreEVLSC) + + /// Return the address accessed by this recipe. + VPValue *getStoredValue() const { return getOperand(1); } + + /// Return the EVL operand. + VPValue *getEVL() const { return getOperand(2); } + + /// Generate the wide store or scatter. + void execute(VPTransformState &State) override; + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const override; +#endif + + /// Returns true if the recipe only uses the first lane of operand \p Op. + bool onlyFirstLaneUsed(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + if (Op == getEVL()) { + assert(getStoredValue() != Op && "unexpected store of EVL"); + return true; + } // Widened, consecutive memory operations only demand the first lane of // their address, unless the same operand is also stored. That latter can // happen with opaque pointers. - return Op == getAddr() && isConsecutive() && - (!isStore() || Op != getStoredValue()); + return Op == getAddr() && isConsecutive() && Op != getStoredValue(); } - - Instruction &getIngredient() const { return Ingredient; } }; /// Recipe to expand a SCEV expression. @@ -2194,6 +2685,10 @@ public: ~VPExpandSCEVRecipe() override = default; + VPExpandSCEVRecipe *clone() override { + return new VPExpandSCEVRecipe(Expr, SE); + } + VP_CLASSOF_IMPL(VPDef::VPExpandSCEVSC) /// Generate a canonical vector induction variable of the vector loop, with @@ -2219,6 +2714,12 @@ public: ~VPCanonicalIVPHIRecipe() override = default; + VPCanonicalIVPHIRecipe *clone() override { + auto *R = new VPCanonicalIVPHIRecipe(getOperand(0), getDebugLoc()); + R->addOperand(getBackedgeValue()); + return R; + } + VP_CLASSOF_IMPL(VPDef::VPCanonicalIVPHISC) static inline bool classof(const VPHeaderPHIRecipe *D) { @@ -2254,10 +2755,9 @@ public: } /// Check if the induction described by \p Kind, /p Start and \p Step is - /// canonical, i.e. has the same start, step (of 1), and type as the - /// canonical IV. + /// canonical, i.e. has the same start and step (of 1) as the canonical IV. 
bool isCanonical(InductionDescriptor::InductionKind Kind, VPValue *Start, - VPValue *Step, Type *Ty) const; + VPValue *Step) const; }; /// A recipe for generating the active lane mask for the vector loop that is @@ -2272,6 +2772,10 @@ public: ~VPActiveLaneMaskPHIRecipe() override = default; + VPActiveLaneMaskPHIRecipe *clone() override { + return new VPActiveLaneMaskPHIRecipe(getOperand(0), getDebugLoc()); + } + VP_CLASSOF_IMPL(VPDef::VPActiveLaneMaskPHISC) static inline bool classof(const VPHeaderPHIRecipe *D) { @@ -2288,6 +2792,45 @@ public: #endif }; +/// A recipe for generating the phi node for the current index of elements, +/// adjusted in accordance with EVL value. It starts at the start value of the +/// canonical induction and gets incremented by EVL in each iteration of the +/// vector loop. +class VPEVLBasedIVPHIRecipe : public VPHeaderPHIRecipe { +public: + VPEVLBasedIVPHIRecipe(VPValue *StartIV, DebugLoc DL) + : VPHeaderPHIRecipe(VPDef::VPEVLBasedIVPHISC, nullptr, StartIV, DL) {} + + ~VPEVLBasedIVPHIRecipe() override = default; + + VPEVLBasedIVPHIRecipe *clone() override { + llvm_unreachable("cloning not implemented yet"); + } + + VP_CLASSOF_IMPL(VPDef::VPEVLBasedIVPHISC) + + static inline bool classof(const VPHeaderPHIRecipe *D) { + return D->getVPDefID() == VPDef::VPEVLBasedIVPHISC; + } + + /// Generate phi for handling IV based on EVL over iterations correctly. + /// TODO: investigate if it can share the code with VPCanonicalIVPHIRecipe. + void execute(VPTransformState &State) override; + + /// Returns true if the recipe only uses the first lane of operand \p Op. + bool onlyFirstLaneUsed(const VPValue *Op) const override { + assert(is_contained(operands(), Op) && + "Op must be an operand of the recipe"); + return true; + } + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const override; +#endif +}; + /// A Recipe for widening the canonical induction variable of the vector loop. class VPWidenCanonicalIVRecipe : public VPSingleDefRecipe { public: @@ -2296,6 +2839,11 @@ public: ~VPWidenCanonicalIVRecipe() override = default; + VPWidenCanonicalIVRecipe *clone() override { + return new VPWidenCanonicalIVRecipe( + cast<VPCanonicalIVPHIRecipe>(getOperand(0))); + } + VP_CLASSOF_IMPL(VPDef::VPWidenCanonicalIVSC) /// Generate a canonical vector induction variable of the vector loop, with @@ -2308,22 +2856,12 @@ public: void print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const override; #endif - - /// Returns the scalar type of the induction. - const Type *getScalarType() const { - return cast<VPCanonicalIVPHIRecipe>(getOperand(0)->getDefiningRecipe()) - ->getScalarType(); - } }; -/// A recipe for converting the canonical IV value to the corresponding value of -/// an IV with different start and step values, using Start + CanonicalIV * +/// A recipe for converting the input value \p IV value to the corresponding +/// value of an IV with different start and step values, using Start + IV * /// Step. class VPDerivedIVRecipe : public VPSingleDefRecipe { - /// If not nullptr, the result of the induction will get truncated to - /// TruncResultTy. - Type *TruncResultTy; - /// Kind of the induction. const InductionDescriptor::InductionKind Kind; /// If not nullptr, the floating point induction binary operator. 
Must be set @@ -2332,15 +2870,25 @@ class VPDerivedIVRecipe : public VPSingleDefRecipe { public: VPDerivedIVRecipe(const InductionDescriptor &IndDesc, VPValue *Start, - VPCanonicalIVPHIRecipe *CanonicalIV, VPValue *Step, - Type *TruncResultTy) - : VPSingleDefRecipe(VPDef::VPDerivedIVSC, {Start, CanonicalIV, Step}), - TruncResultTy(TruncResultTy), Kind(IndDesc.getKind()), - FPBinOp(dyn_cast_or_null<FPMathOperator>(IndDesc.getInductionBinOp())) { - } + VPCanonicalIVPHIRecipe *CanonicalIV, VPValue *Step) + : VPDerivedIVRecipe( + IndDesc.getKind(), + dyn_cast_or_null<FPMathOperator>(IndDesc.getInductionBinOp()), + Start, CanonicalIV, Step) {} + + VPDerivedIVRecipe(InductionDescriptor::InductionKind Kind, + const FPMathOperator *FPBinOp, VPValue *Start, VPValue *IV, + VPValue *Step) + : VPSingleDefRecipe(VPDef::VPDerivedIVSC, {Start, IV, Step}), Kind(Kind), + FPBinOp(FPBinOp) {} ~VPDerivedIVRecipe() override = default; + VPDerivedIVRecipe *clone() override { + return new VPDerivedIVRecipe(Kind, FPBinOp, getStartValue(), getOperand(1), + getStepValue()); + } + VP_CLASSOF_IMPL(VPDef::VPDerivedIVSC) /// Generate the transformed value of the induction at offset StartValue (1. @@ -2354,12 +2902,10 @@ public: #endif Type *getScalarType() const { - return TruncResultTy ? TruncResultTy - : getStartValue()->getLiveInIRValue()->getType(); + return getStartValue()->getLiveInIRValue()->getType(); } VPValue *getStartValue() const { return getOperand(0); } - VPValue *getCanonicalIV() const { return getOperand(1); } VPValue *getStepValue() const { return getOperand(2); } /// Returns true if the recipe only uses the first lane of operand \p Op. @@ -2392,6 +2938,12 @@ public: ~VPScalarIVStepsRecipe() override = default; + VPScalarIVStepsRecipe *clone() override { + return new VPScalarIVStepsRecipe( + getOperand(0), getOperand(1), InductionOpcode, + hasFastMathFlags() ? getFastMathFlags() : FastMathFlags()); + } + VP_CLASSOF_IMPL(VPDef::VPScalarIVStepsSC) /// Generate the scalarized versions of the phi node as needed by their users. @@ -2420,10 +2972,13 @@ class VPBasicBlock : public VPBlockBase { public: using RecipeListTy = iplist<VPRecipeBase>; -private: +protected: /// The VPRecipes held in the order of output instructions to generate. RecipeListTy Recipes; + VPBasicBlock(const unsigned char BlockSC, const Twine &Name = "") + : VPBlockBase(BlockSC, Name.str()) {} + public: VPBasicBlock(const Twine &Name = "", VPRecipeBase *Recipe = nullptr) : VPBlockBase(VPBasicBlockSC, Name.str()) { @@ -2472,7 +3027,8 @@ public: /// Method to support type inquiry through isa, cast, and dyn_cast. static inline bool classof(const VPBlockBase *V) { - return V->getVPBlockID() == VPBlockBase::VPBasicBlockSC; + return V->getVPBlockID() == VPBlockBase::VPBasicBlockSC || + V->getVPBlockID() == VPBlockBase::VPIRBasicBlockSC; } void insert(VPRecipeBase *Recipe, iterator InsertPt) { @@ -2490,6 +3046,9 @@ public: /// this VPBasicBlock, thereby "executing" the VPlan. void execute(VPTransformState *State) override; + /// Return the cost of this VPBasicBlock. + InstructionCost cost(ElementCount VF, VPCostContext &Ctx) override; + /// Return the position of the first non-phi node recipe in the block. iterator getFirstNonPhi(); @@ -2526,12 +3085,59 @@ public: /// Returns true if the block is exiting it's parent region. bool isExiting() const; + /// Clone the current block and it's recipes, without updating the operands of + /// the cloned recipes. 
+ VPBasicBlock *clone() override { + auto *NewBlock = new VPBasicBlock(getName()); + for (VPRecipeBase &R : *this) + NewBlock->appendRecipe(R.clone()); + return NewBlock; + } + +protected: + /// Execute the recipes in the IR basic block \p BB. + void executeRecipes(VPTransformState *State, BasicBlock *BB); + private: /// Create an IR BasicBlock to hold the output instructions generated by this /// VPBasicBlock, and return it. Update the CFGState accordingly. BasicBlock *createEmptyBasicBlock(VPTransformState::CFGState &CFG); }; +/// A special type of VPBasicBlock that wraps an existing IR basic block. +/// Recipes of the block get added before the first non-phi instruction in the +/// wrapped block. +/// Note: At the moment, VPIRBasicBlock can only be used to wrap VPlan's +/// preheader block. +class VPIRBasicBlock : public VPBasicBlock { + BasicBlock *IRBB; + +public: + VPIRBasicBlock(BasicBlock *IRBB) + : VPBasicBlock(VPIRBasicBlockSC, + (Twine("ir-bb<") + IRBB->getName() + Twine(">")).str()), + IRBB(IRBB) {} + + ~VPIRBasicBlock() override {} + + static inline bool classof(const VPBlockBase *V) { + return V->getVPBlockID() == VPBlockBase::VPIRBasicBlockSC; + } + + /// The method which generates the output IR instructions that correspond to + /// this VPBasicBlock, thereby "executing" the VPlan. + void execute(VPTransformState *State) override; + + VPIRBasicBlock *clone() override { + auto *NewBlock = new VPIRBasicBlock(IRBB); + for (VPRecipeBase &R : Recipes) + NewBlock->appendRecipe(R.clone()); + return NewBlock; + } + + BasicBlock *getIRBasicBlock() const { return IRBB; } +}; + /// VPRegionBlock represents a collection of VPBasicBlocks and VPRegionBlocks /// which form a Single-Entry-Single-Exiting subgraph of the output IR CFG. /// A VPRegionBlock may indicate that its contents are to be replicated several @@ -2617,6 +3223,9 @@ public: /// this VPRegionBlock, thereby "executing" the VPlan. void execute(VPTransformState *State) override; + // Return the cost of this region. + InstructionCost cost(ElementCount VF, VPCostContext &Ctx) override; + void dropAllReferences(VPValue *NewValue) override; #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -2630,6 +3239,10 @@ public: VPSlotTracker &SlotTracker) const override; using VPBlockBase::print; // Get the print(raw_stream &O) version. #endif + + /// Clone all blocks in the single-entry single-exit region of the block and + /// their recipes without updating the operands of the cloned recipes. + VPRegionBlock *clone() override; }; /// VPlan models a candidate for vectorization, encoding various decisions take @@ -2682,11 +3295,9 @@ class VPlan { /// definitions are VPValues that hold a pointer to their underlying IR. SmallVector<VPValue *, 16> VPLiveInsToFree; - /// Indicates whether it is safe use the Value2VPValue mapping or if the - /// mapping cannot be used any longer, because it is stale. - bool Value2VPValueEnabled = true; - - /// Values used outside the plan. + /// Values used outside the plan. It contains live-outs that need fixing. Any + /// live-out that is fixed outside VPlan needs to be removed. The remaining + /// live-outs are fixed via VPLiveOut::fixPhi. MapVector<PHINode *, VPLiveOut *> LiveOuts; /// Mapping from SCEVs to the VPValues representing their expansions. 
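The clone() hooks introduced above intentionally leave the cloned recipes' operands pointing at the original VPValues. A minimal sketch of the follow-up operand fix-up a caller would perform, assuming a caller-maintained Old2New mapping of original values to their clones (as VPlan::duplicate() is documented to maintain); this is illustrative and not part of the patch:

static void remapClonedOperands(VPBasicBlock *ClonedBB,
                                const DenseMap<VPValue *, VPValue *> &Old2New) {
  for (VPRecipeBase &R : *ClonedBB)
    for (unsigned I = 0, E = R.getNumOperands(); I != E; ++I)
      // clone() copied the operand list verbatim; redirect each operand to
      // its clone if one exists, otherwise keep the shared live-in as-is.
      if (VPValue *New = Old2New.lookup(R.getOperand(I)))
        R.setOperand(I, New);
}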
@@ -2719,13 +3330,17 @@ public: ~VPlan(); - /// Create initial VPlan skeleton, having an "entry" VPBasicBlock (wrapping - /// original scalar pre-header) which contains SCEV expansions that need to - /// happen before the CFG is modified; a VPBasicBlock for the vector + /// Create initial VPlan, having an "entry" VPBasicBlock (wrapping + /// original scalar pre-header ) which contains SCEV expansions that need + /// to happen before the CFG is modified; a VPBasicBlock for the vector /// pre-header, followed by a region for the vector loop, followed by the - /// middle VPBasicBlock. + /// middle VPBasicBlock. If a check is needed to guard executing the scalar + /// epilogue loop, it will be added to the middle block, together with + /// VPBasicBlocks for the scalar preheader and exit blocks. static VPlanPtr createInitialVPlan(const SCEV *TripCount, - ScalarEvolution &PSE); + ScalarEvolution &PSE, + bool RequiresScalarEpilogueCheck, + bool TailFolded, Loop *TheLoop); /// Prepare the plan for execution, setting up the required live-in values. void prepareToExecute(Value *TripCount, Value *VectorTripCount, @@ -2734,6 +3349,9 @@ public: /// Generate the IR code for this VPlan. void execute(VPTransformState *State); + /// Return the cost of this plan. + InstructionCost cost(ElementCount VF, VPCostContext &Ctx); + VPBasicBlock *getEntry() { return Entry; } const VPBasicBlock *getEntry() const { return Entry; } @@ -2743,6 +3361,14 @@ public: return TripCount; } + /// Resets the trip count for the VPlan. The caller must make sure all uses of + /// the original trip count have been replaced. + void resetTripCount(VPValue *NewTripCount) { + assert(TripCount && NewTripCount && TripCount->getNumUsers() == 0 && + "TripCount always must be set"); + TripCount = NewTripCount; + } + /// The backedge taken count of the original loop. VPValue *getOrCreateBackedgeTakenCount() { if (!BackedgeTakenCount) @@ -2756,10 +3382,6 @@ public: /// Returns VF * UF of the vector loop region. VPValue &getVFxUF() { return VFxUF; } - /// Mark the plan to indicate that using Value2VPValue is not safe any - /// longer, because it may be stale. - void disableValue2VPValue() { Value2VPValueEnabled = false; } - void addVF(ElementCount VF) { VFs.insert(VF); } void setVF(ElementCount VF) { @@ -2769,6 +3391,15 @@ public: } bool hasVF(ElementCount VF) { return VFs.count(VF); } + bool hasScalableVF() { + return any_of(VFs, [](ElementCount VF) { return VF.isScalable(); }); + } + + /// Returns an iterator range over all VFs of the plan. + iterator_range<SmallSetVector<ElementCount, 2>::iterator> + vectorFactors() const { + return {VFs.begin(), VFs.end()}; + } bool hasScalarVFOnly() const { return VFs.size() == 1 && VFs[0].isScalar(); } @@ -2785,38 +3416,27 @@ public: void setName(const Twine &newName) { Name = newName.str(); } - void addVPValue(Value *V, VPValue *VPV) { - assert((Value2VPValueEnabled || VPV->isLiveIn()) && - "Value2VPValue mapping may be out of date!"); - assert(V && "Trying to add a null Value to VPlan"); - assert(!Value2VPValue.count(V) && "Value already exists in VPlan"); - Value2VPValue[V] = VPV; - } - - /// Returns the VPValue for \p V. \p OverrideAllowed can be used to disable - /// /// checking whether it is safe to query VPValues using IR Values. 
- VPValue *getVPValue(Value *V, bool OverrideAllowed = false) { - assert(V && "Trying to get the VPValue of a null Value"); - assert(Value2VPValue.count(V) && "Value does not exist in VPlan"); - assert((Value2VPValueEnabled || OverrideAllowed || - Value2VPValue[V]->isLiveIn()) && - "Value2VPValue mapping may be out of date!"); - return Value2VPValue[V]; - } - - /// Gets the VPValue for \p V or adds a new live-in (if none exists yet) for - /// \p V. - VPValue *getVPValueOrAddLiveIn(Value *V) { + /// Gets the live-in VPValue for \p V or adds a new live-in (if none exists + /// yet) for \p V. + VPValue *getOrAddLiveIn(Value *V) { assert(V && "Trying to get or add the VPValue of a null Value"); if (!Value2VPValue.count(V)) { VPValue *VPV = new VPValue(V); VPLiveInsToFree.push_back(VPV); - addVPValue(V, VPV); + assert(VPV->isLiveIn() && "VPV must be a live-in."); + assert(!Value2VPValue.count(V) && "Value already exists in VPlan"); + Value2VPValue[V] = VPV; } - return getVPValue(V); + assert(Value2VPValue.count(V) && "Value does not exist in VPlan"); + assert(Value2VPValue[V]->isLiveIn() && + "Only live-ins should be in mapping"); + return Value2VPValue[V]; } + /// Return the live-in VPValue for \p V, if there is one or nullptr otherwise. + VPValue *getLiveIn(Value *V) const { return Value2VPValue.lookup(V); } + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) /// Print the live-ins of this VPlan to \p O. void printLiveIns(raw_ostream &O) const; @@ -2831,16 +3451,6 @@ public: LLVM_DUMP_METHOD void dump() const; #endif - /// Returns a range mapping the values the range \p Operands to their - /// corresponding VPValues. - iterator_range<mapped_iterator<Use *, std::function<VPValue *(Value *)>>> - mapToVPValues(User::op_range Operands) { - std::function<VPValue *(Value *)> Fn = [this](Value *Op) { - return getVPValueOrAddLiveIn(Op); - }; - return map_range(Operands, Fn); - } - /// Returns the VPRegionBlock of the vector loop. VPRegionBlock *getVectorLoopRegion() { return cast<VPRegionBlock>(getEntry()->getSingleSuccessor()); @@ -2883,12 +3493,9 @@ public: VPBasicBlock *getPreheader() { return Preheader; } const VPBasicBlock *getPreheader() const { return Preheader; } -private: - /// Add to the given dominator tree the header block and every new basic block - /// that was created between it and the latch block, inclusive. - static void updateDominatorTree(DominatorTree *DT, BasicBlock *LoopLatchBB, - BasicBlock *LoopPreHeaderBB, - BasicBlock *LoopExitBB); + /// Clone the current VPlan, update all VPValues of the new VPlan and cloned + /// recipes to refer to the clones, and return it. + VPlan *duplicate(); }; #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -3184,10 +3791,10 @@ public: namespace vputils { /// Returns true if only the first lane of \p Def is used. -bool onlyFirstLaneUsed(VPValue *Def); +bool onlyFirstLaneUsed(const VPValue *Def); /// Returns true if only the first part of \p Def is used. -bool onlyFirstPartUsed(VPValue *Def); +bool onlyFirstPartUsed(const VPValue *Def); /// Get or create a VPValue that corresponds to the expansion of \p Expr. 
If \p /// Expr is a SCEVConstant or SCEVUnknown, return a VPValue wrapping the live-in @@ -3210,9 +3817,12 @@ inline bool isUniformAfterVectorization(VPValue *VPV) { if (auto *GEP = dyn_cast<VPWidenGEPRecipe>(Def)) return all_of(GEP->operands(), isUniformAfterVectorization); if (auto *VPI = dyn_cast<VPInstruction>(Def)) - return VPI->getOpcode() == VPInstruction::ComputeReductionResult; + return VPI->isSingleScalar() || VPI->isVectorToScalar(); return false; } + +/// Return true if \p V is a header mask in \p Plan. +bool isHeaderMask(VPValue *V, VPlan &Plan); } // end namespace vputils } // end namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp index 97a8a1803bbf..6d89ad9fee8a 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp @@ -8,7 +8,10 @@ #include "VPlanAnalysis.h" #include "VPlan.h" +#include "VPlanCFG.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/PatternMatch.h" using namespace llvm; @@ -26,7 +29,24 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPBlendRecipe *R) { } Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) { - switch (R->getOpcode()) { + // Set the result type from the first operand, check if the types for all + // other operands match and cache them. + auto SetResultTyFromOp = [this, R]() { + Type *ResTy = inferScalarType(R->getOperand(0)); + for (unsigned Op = 1; Op != R->getNumOperands(); ++Op) { + VPValue *OtherV = R->getOperand(Op); + assert(inferScalarType(OtherV) == ResTy && + "different types inferred for different operands"); + CachedTypes[OtherV] = ResTy; + } + return ResTy; + }; + + unsigned Opcode = R->getOpcode(); + if (Instruction::isBinaryOp(Opcode) || Instruction::isUnaryOp(Opcode)) + return SetResultTyFromOp(); + + switch (Opcode) { case Instruction::Select: { Type *ResTy = inferScalarType(R->getOperand(1)); VPValue *OtherV = R->getOperand(2); @@ -35,14 +55,26 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) { CachedTypes[OtherV] = ResTy; return ResTy; } - case VPInstruction::FirstOrderRecurrenceSplice: { - Type *ResTy = inferScalarType(R->getOperand(0)); - VPValue *OtherV = R->getOperand(1); - assert(inferScalarType(OtherV) == ResTy && - "different types inferred for different operands"); - CachedTypes[OtherV] = ResTy; - return ResTy; + case Instruction::ICmp: + case VPInstruction::ActiveLaneMask: + return inferScalarType(R->getOperand(1)); + case VPInstruction::FirstOrderRecurrenceSplice: + case VPInstruction::Not: + return SetResultTyFromOp(); + case VPInstruction::ExtractFromEnd: { + Type *BaseTy = inferScalarType(R->getOperand(0)); + if (auto *VecTy = dyn_cast<VectorType>(BaseTy)) + return VecTy->getElementType(); + return BaseTy; } + case VPInstruction::LogicalAnd: + return IntegerType::get(Ctx, 1); + case VPInstruction::PtrAdd: + // Return the type based on the pointer argument (i.e. first operand). 
+ return inferScalarType(R->getOperand(0)); + case VPInstruction::BranchOnCond: + case VPInstruction::BranchOnCount: + return Type::getVoidTy(Ctx); default: break; } @@ -104,9 +136,9 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenCallRecipe *R) { return CI.getType(); } -Type *VPTypeAnalysis::inferScalarTypeForRecipe( - const VPWidenMemoryInstructionRecipe *R) { - assert(!R->isStore() && "Store recipes should not define any values"); +Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R) { + assert((isa<VPWidenLoadRecipe>(R) || isa<VPWidenLoadEVLRecipe>(R)) && + "Store recipes should not define any values"); return cast<LoadInst>(&R->getIngredient())->getType(); } @@ -160,6 +192,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) { case Instruction::ICmp: case Instruction::FCmp: return IntegerType::get(Ctx, 1); + case Instruction::AddrSpaceCast: case Instruction::Alloca: case Instruction::BitCast: case Instruction::Trunc: @@ -201,15 +234,21 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) { if (Type *CachedTy = CachedTypes.lookup(V)) return CachedTy; - if (V->isLiveIn()) - return V->getLiveInIRValue()->getType(); + if (V->isLiveIn()) { + if (auto *IRValue = V->getLiveInIRValue()) + return IRValue->getType(); + // All VPValues without any underlying IR value (like the vector trip count + // or the backedge-taken count) have the same type as the canonical IV. + return CanonicalIVTy; + } Type *ResultTy = TypeSwitch<const VPRecipeBase *, Type *>(V->getDefiningRecipe()) - .Case<VPCanonicalIVPHIRecipe, VPFirstOrderRecurrencePHIRecipe, - VPReductionPHIRecipe, VPWidenPointerInductionRecipe>( + .Case<VPActiveLaneMaskPHIRecipe, VPCanonicalIVPHIRecipe, + VPFirstOrderRecurrencePHIRecipe, VPReductionPHIRecipe, + VPWidenPointerInductionRecipe, VPEVLBasedIVPHIRecipe>( [this](const auto *R) { - // Handle header phi recipes, except VPWienIntOrFpInduction + // Handle header phi recipes, except VPWidenIntOrFpInduction // which needs special handling due it being possibly truncated. // TODO: consider inferring/caching type of siblings, e.g., // backedge value, here and in cases below. @@ -217,21 +256,66 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) { }) .Case<VPWidenIntOrFpInductionRecipe, VPDerivedIVRecipe>( [](const auto *R) { return R->getScalarType(); }) - .Case<VPPredInstPHIRecipe, VPWidenPHIRecipe, VPScalarIVStepsRecipe, - VPWidenGEPRecipe>([this](const VPRecipeBase *R) { + .Case<VPReductionRecipe, VPPredInstPHIRecipe, VPWidenPHIRecipe, + VPScalarIVStepsRecipe, VPWidenGEPRecipe, VPVectorPointerRecipe, + VPWidenCanonicalIVRecipe>([this](const VPRecipeBase *R) { return inferScalarType(R->getOperand(0)); }) .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe, - VPWidenCallRecipe, VPWidenMemoryInstructionRecipe, - VPWidenSelectRecipe>( + VPWidenCallRecipe, VPWidenMemoryRecipe, VPWidenSelectRecipe>( [this](const auto *R) { return inferScalarTypeForRecipe(R); }) .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) { // TODO: Use info from interleave group. 
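// Wrapper live-ins such as the vector trip count or the backedge-taken count
// carry no underlying IR value, so the analysis has to be seeded with the
// canonical IV's type to give them a result here. A rough usage sketch
// (assumes a VPlan named Plan is in scope; the exact call site varies):
Type *CanIVTy = Plan.getCanonicalIV()->getScalarType();
VPTypeAnalysis TypeInfo(CanIVTy, CanIVTy->getContext());
Type *BTCTy = TypeInfo.inferScalarType(Plan.getOrCreateBackedgeTakenCount());
assert(BTCTy == CanIVTy && "IR-less live-ins default to the canonical IV type");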
return V->getUnderlyingValue()->getType(); }) .Case<VPWidenCastRecipe>( - [](const VPWidenCastRecipe *R) { return R->getResultType(); }); + [](const VPWidenCastRecipe *R) { return R->getResultType(); }) + .Case<VPScalarCastRecipe>( + [](const VPScalarCastRecipe *R) { return R->getResultType(); }) + .Case<VPExpandSCEVRecipe>([](const VPExpandSCEVRecipe *R) { + return R->getSCEV()->getType(); + }) + .Case<VPReductionRecipe>([this](const auto *R) { + return inferScalarType(R->getChainOp()); + }); + assert(ResultTy && "could not infer type for the given VPValue"); CachedTypes[V] = ResultTy; return ResultTy; } + +void llvm::collectEphemeralRecipesForVPlan( + VPlan &Plan, DenseSet<VPRecipeBase *> &EphRecipes) { + // First, collect seed recipes which are operands of assumes. + SmallVector<VPRecipeBase *> Worklist; + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( + vp_depth_first_deep(Plan.getVectorLoopRegion()->getEntry()))) { + for (VPRecipeBase &R : *VPBB) { + auto *RepR = dyn_cast<VPReplicateRecipe>(&R); + if (!RepR || !match(RepR->getUnderlyingInstr(), + PatternMatch::m_Intrinsic<Intrinsic::assume>())) + continue; + Worklist.push_back(RepR); + EphRecipes.insert(RepR); + } + } + + // Process operands of candidates in worklist and add them to the set of + // ephemeral recipes, if they don't have side-effects and are only used by + // other ephemeral recipes. + while (!Worklist.empty()) { + VPRecipeBase *Cur = Worklist.pop_back_val(); + for (VPValue *Op : Cur->operands()) { + auto *OpR = Op->getDefiningRecipe(); + if (!OpR || OpR->mayHaveSideEffects() || EphRecipes.contains(OpR)) + continue; + if (any_of(Op->users(), [EphRecipes](VPUser *U) { + auto *UR = dyn_cast<VPRecipeBase>(U); + return !UR || !EphRecipes.contains(UR); + })) + continue; + EphRecipes.insert(OpR); + Worklist.push_back(OpR); + } + } +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h index 7276641551ae..438364efc629 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h @@ -10,6 +10,7 @@ #define LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" namespace llvm { @@ -20,9 +21,11 @@ class VPInstruction; class VPWidenRecipe; class VPWidenCallRecipe; class VPWidenIntOrFpInductionRecipe; -class VPWidenMemoryInstructionRecipe; +class VPWidenMemoryRecipe; struct VPWidenSelectRecipe; class VPReplicateRecipe; +class VPRecipeBase; +class VPlan; class Type; /// An analysis for type-inference for VPValues. @@ -35,6 +38,10 @@ class Type; /// of the previously inferred types. class VPTypeAnalysis { DenseMap<const VPValue *, Type *> CachedTypes; + /// Type of the canonical induction variable. Used for all VPValues without + /// any underlying IR value (like the vector trip count or the backedge-taken + /// count). 
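// collectEphemeralRecipesForVPlan gives VPlan an analogue of the IR-level
// ephemeral-value collection: recipes that only feed assumes can be ignored by
// clients such as a cost model. A possible call pattern (illustrative only):
DenseSet<VPRecipeBase *> EphRecipes;
collectEphemeralRecipesForVPlan(Plan, EphRecipes);
for (VPRecipeBase *R : EphRecipes)
  (void)R; // e.g. skip R when accumulating per-recipe costs for the plan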
+ Type *CanonicalIVTy; LLVMContext &Ctx; Type *inferScalarTypeForRecipe(const VPBlendRecipe *R); @@ -42,12 +49,13 @@ class VPTypeAnalysis { Type *inferScalarTypeForRecipe(const VPWidenCallRecipe *R); Type *inferScalarTypeForRecipe(const VPWidenRecipe *R); Type *inferScalarTypeForRecipe(const VPWidenIntOrFpInductionRecipe *R); - Type *inferScalarTypeForRecipe(const VPWidenMemoryInstructionRecipe *R); + Type *inferScalarTypeForRecipe(const VPWidenMemoryRecipe *R); Type *inferScalarTypeForRecipe(const VPWidenSelectRecipe *R); Type *inferScalarTypeForRecipe(const VPReplicateRecipe *R); public: - VPTypeAnalysis(LLVMContext &Ctx) : Ctx(Ctx) {} + VPTypeAnalysis(Type *CanonicalIVTy, LLVMContext &Ctx) + : CanonicalIVTy(CanonicalIVTy), Ctx(Ctx) {} /// Infer the type of \p V. Returns the scalar type of \p V. Type *inferScalarType(const VPValue *V); @@ -56,6 +64,9 @@ public: LLVMContext &getContext() { return Ctx; } }; +// Collect a VPlan's ephemeral recipes (those used only by an assume). +void collectEphemeralRecipesForVPlan(VPlan &Plan, + DenseSet<VPRecipeBase *> &EphRecipes); } // end namespace llvm #endif // LLVM_TRANSFORMS_VECTORIZE_VPLANANALYSIS_H diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp index 94456bf858d9..6e633739fcc3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp @@ -272,7 +272,7 @@ VPValue *PlainCFGBuilder::getOrCreateVPOperand(Value *IRVal) { // A and B: Create VPValue and add it to the pool of external definitions and // to the Value->VPValue map. - VPValue *NewVPVal = Plan.getVPValueOrAddLiveIn(IRVal); + VPValue *NewVPVal = Plan.getOrAddLiveIn(IRVal); IRDef2VPValue[IRVal] = NewVPVal; return NewVPVal; } @@ -296,8 +296,7 @@ void PlainCFGBuilder::createVPInstructionsForVPBB(VPBasicBlock *VPBB, // recipes. if (Br->isConditional()) { VPValue *Cond = getOrCreateVPOperand(Br->getCondition()); - VPBB->appendRecipe( - new VPInstruction(VPInstruction::BranchOnCond, {Cond})); + VPIRBuilder.createNaryOp(VPInstruction::BranchOnCond, {Cond}, Inst); } // Skip the rest of the Instruction processing for Branch instructions. @@ -347,9 +346,24 @@ void PlainCFGBuilder::buildPlainCFG() { // latter. BB2VPBB[ThePreheaderBB] = VectorPreheaderVPBB; BasicBlock *LoopExitBB = TheLoop->getUniqueExitBlock(); + Loop2Region[LI->getLoopFor(TheLoop->getHeader())] = TheRegion; assert(LoopExitBB && "Loops with multiple exits are not supported."); BB2VPBB[LoopExitBB] = cast<VPBasicBlock>(TheRegion->getSingleSuccessor()); + // The existing vector region's entry and exiting VPBBs correspond to the loop + // header and latch. + VPBasicBlock *VectorHeaderVPBB = TheRegion->getEntryBasicBlock(); + VPBasicBlock *VectorLatchVPBB = TheRegion->getExitingBasicBlock(); + BB2VPBB[TheLoop->getHeader()] = VectorHeaderVPBB; + VectorHeaderVPBB->clearSuccessors(); + VectorLatchVPBB->clearPredecessors(); + if (TheLoop->getHeader() != TheLoop->getLoopLatch()) { + BB2VPBB[TheLoop->getLoopLatch()] = VectorLatchVPBB; + } else { + TheRegion->setExiting(VectorHeaderVPBB); + delete VectorLatchVPBB; + } + // 1. Scan the body of the loop in a topological order to visit each basic // block after having visited its predecessor basic blocks. Create a VPBB for // each BB and link it to its successor and predecessor VPBBs. 
Note that @@ -362,7 +376,7 @@ void PlainCFGBuilder::buildPlainCFG() { for (auto &I : *ThePreheaderBB) { if (I.getType()->isVoidTy()) continue; - IRDef2VPValue[&I] = Plan.getVPValueOrAddLiveIn(&I); + IRDef2VPValue[&I] = Plan.getOrAddLiveIn(&I); } LoopBlocksRPO RPO(TheLoop); @@ -414,10 +428,11 @@ void PlainCFGBuilder::buildPlainCFG() { // of VPBB and it should be set to the exit, i.e., non-header successor, // except for the top region, whose successor was set when creating VPlan's // skeleton. - if (TheRegion != Region) + if (TheRegion != Region) { Region->setOneSuccessor(isHeaderVPBB(Successor0) ? Successor1 : Successor0); - Region->setExiting(VPBB); + Region->setExiting(VPBB); + } } // 2. The whole CFG has been built at this point so all the input Values must @@ -437,9 +452,6 @@ void VPlanHCFGBuilder::buildHierarchicalCFG() { buildPlainCFG(); LLVM_DEBUG(Plan.setName("HCFGBuilder: Plain CFG\n"); dbgs() << Plan); - VPRegionBlock *TopRegion = Plan.getVectorLoopRegion(); - Verifier.verifyHierarchicalCFG(TopRegion); - // Compute plain CFG dom tree for VPLInfo. VPDomTree.recalculate(Plan); LLVM_DEBUG(dbgs() << "Dominator Tree after building the plain CFG.\n"; diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h index 299ae36155cb..9e8f9f3f4002 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanHCFGBuilder.h @@ -25,7 +25,6 @@ #define LLVM_TRANSFORMS_VECTORIZE_VPLAN_VPLANHCFGBUILDER_H #include "VPlanDominatorTree.h" -#include "VPlanVerifier.h" namespace llvm { @@ -49,9 +48,6 @@ private: // The VPlan that will contain the H-CFG we are building. VPlan &Plan; - // VPlan verifier utility. - VPlanVerifier Verifier; - // Dominator analysis for VPlan plain CFG to be used in the // construction of the H-CFG. This analysis is no longer valid once regions // are introduced. diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h new file mode 100644 index 000000000000..9cd7712624ba --- /dev/null +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h @@ -0,0 +1,349 @@ +//===- VPlanPatternMatch.h - Match on VPValues and recipes ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides a simple and efficient mechanism for performing general +// tree-based pattern matches on the VPlan values and recipes, based on +// LLVM's IR pattern matchers. +// +// Currently it provides generic matchers for unary and binary VPInstructions, +// and specialized matchers like m_Not, m_ActiveLaneMask, m_BranchOnCond, +// m_BranchOnCount to match specific VPInstructions. +// TODO: Add missing matchers for additional opcodes and recipes as needed. 
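// As with IR's PatternMatch, the intended style is a single match() call with
// nested matchers that bind sub-values on success. A small sketch of typical
// uses (assumes a VPValue *V and the latch terminator recipe *Term are in
// scope; names are illustrative):
using namespace llvm::VPlanPatternMatch;

// Capture X when V is defined as "mul X, 2", accepting either operand order.
VPValue *X = nullptr;
bool IsDoubled = match(V, m_c_Mul(m_VPValue(X), m_SpecificInt(2)));

// Recognize the canonical latch terminator and capture the trip count.
VPValue *TripCount = nullptr;
bool IsLatch = match(Term, m_BranchOnCount(m_VPValue(), m_VPValue(TripCount)));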
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H +#define LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H + +#include "VPlan.h" + +namespace llvm { +namespace VPlanPatternMatch { + +template <typename Val, typename Pattern> bool match(Val *V, const Pattern &P) { + return const_cast<Pattern &>(P).match(V); +} + +template <typename Class> struct class_match { + template <typename ITy> bool match(ITy *V) { return isa<Class>(V); } +}; + +/// Match an arbitrary VPValue and ignore it. +inline class_match<VPValue> m_VPValue() { return class_match<VPValue>(); } + +template <typename Class> struct bind_ty { + Class *&VR; + + bind_ty(Class *&V) : VR(V) {} + + template <typename ITy> bool match(ITy *V) { + if (auto *CV = dyn_cast<Class>(V)) { + VR = CV; + return true; + } + return false; + } +}; + +/// Match a specified integer value or vector of all elements of that +/// value. \p BitWidth optionally specifies the bitwidth the matched constant +/// must have. If it is 0, the matched constant can have any bitwidth. +template <unsigned BitWidth = 0> struct specific_intval { + APInt Val; + + specific_intval(APInt V) : Val(std::move(V)) {} + + bool match(VPValue *VPV) { + if (!VPV->isLiveIn()) + return false; + Value *V = VPV->getLiveInIRValue(); + const auto *CI = dyn_cast<ConstantInt>(V); + if (!CI && V->getType()->isVectorTy()) + if (const auto *C = dyn_cast<Constant>(V)) + CI = dyn_cast_or_null<ConstantInt>( + C->getSplatValue(/*AllowPoison=*/false)); + if (!CI) + return false; + + assert((BitWidth == 0 || CI->getBitWidth() == BitWidth) && + "Trying the match constant with unexpected bitwidth."); + return APInt::isSameValue(CI->getValue(), Val); + } +}; + +inline specific_intval<0> m_SpecificInt(uint64_t V) { + return specific_intval<0>(APInt(64, V)); +} + +inline specific_intval<1> m_False() { return specific_intval<1>(APInt(64, 0)); } + +/// Matching combinators +template <typename LTy, typename RTy> struct match_combine_or { + LTy L; + RTy R; + + match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} + + template <typename ITy> bool match(ITy *V) { + if (L.match(V)) + return true; + if (R.match(V)) + return true; + return false; + } +}; + +template <typename LTy, typename RTy> +inline match_combine_or<LTy, RTy> m_CombineOr(const LTy &L, const RTy &R) { + return match_combine_or<LTy, RTy>(L, R); +} + +/// Match a VPValue, capturing it if we match. +inline bind_ty<VPValue> m_VPValue(VPValue *&V) { return V; } + +namespace detail { + +/// A helper to match an opcode against multiple recipe types. +template <unsigned Opcode, typename...> struct MatchRecipeAndOpcode {}; + +template <unsigned Opcode, typename RecipeTy> +struct MatchRecipeAndOpcode<Opcode, RecipeTy> { + static bool match(const VPRecipeBase *R) { + auto *DefR = dyn_cast<RecipeTy>(R); + return DefR && DefR->getOpcode() == Opcode; + } +}; + +template <unsigned Opcode, typename RecipeTy, typename... RecipeTys> +struct MatchRecipeAndOpcode<Opcode, RecipeTy, RecipeTys...> { + static bool match(const VPRecipeBase *R) { + return MatchRecipeAndOpcode<Opcode, RecipeTy>::match(R) || + MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R); + } +}; +} // namespace detail + +template <typename Op0_t, unsigned Opcode, typename... 
RecipeTys> +struct UnaryRecipe_match { + Op0_t Op0; + + UnaryRecipe_match(Op0_t Op0) : Op0(Op0) {} + + bool match(const VPValue *V) { + auto *DefR = V->getDefiningRecipe(); + return DefR && match(DefR); + } + + bool match(const VPRecipeBase *R) { + if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R)) + return false; + assert(R->getNumOperands() == 1 && + "recipe with matched opcode does not have 1 operands"); + return Op0.match(R->getOperand(0)); + } +}; + +template <typename Op0_t, unsigned Opcode> +using UnaryVPInstruction_match = + UnaryRecipe_match<Op0_t, Opcode, VPInstruction>; + +template <typename Op0_t, unsigned Opcode> +using AllUnaryRecipe_match = + UnaryRecipe_match<Op0_t, Opcode, VPWidenRecipe, VPReplicateRecipe, + VPWidenCastRecipe, VPInstruction>; + +template <typename Op0_t, typename Op1_t, unsigned Opcode, bool Commutative, + typename... RecipeTys> +struct BinaryRecipe_match { + Op0_t Op0; + Op1_t Op1; + + BinaryRecipe_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {} + + bool match(const VPValue *V) { + auto *DefR = V->getDefiningRecipe(); + return DefR && match(DefR); + } + + bool match(const VPSingleDefRecipe *R) { + return match(static_cast<const VPRecipeBase *>(R)); + } + + bool match(const VPRecipeBase *R) { + if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R)) + return false; + assert(R->getNumOperands() == 2 && + "recipe with matched opcode does not have 2 operands"); + if (Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1))) + return true; + return Commutative && Op0.match(R->getOperand(1)) && + Op1.match(R->getOperand(0)); + } +}; + +template <typename Op0_t, typename Op1_t, unsigned Opcode> +using BinaryVPInstruction_match = + BinaryRecipe_match<Op0_t, Op1_t, Opcode, /*Commutative*/ false, + VPInstruction>; + +template <typename Op0_t, typename Op1_t, unsigned Opcode, + bool Commutative = false> +using AllBinaryRecipe_match = + BinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative, VPWidenRecipe, + VPReplicateRecipe, VPWidenCastRecipe, VPInstruction>; + +template <unsigned Opcode, typename Op0_t> +inline UnaryVPInstruction_match<Op0_t, Opcode> +m_VPInstruction(const Op0_t &Op0) { + return UnaryVPInstruction_match<Op0_t, Opcode>(Op0); +} + +template <unsigned Opcode, typename Op0_t, typename Op1_t> +inline BinaryVPInstruction_match<Op0_t, Op1_t, Opcode> +m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) { + return BinaryVPInstruction_match<Op0_t, Op1_t, Opcode>(Op0, Op1); +} + +template <typename Op0_t> +inline UnaryVPInstruction_match<Op0_t, VPInstruction::Not> +m_Not(const Op0_t &Op0) { + return m_VPInstruction<VPInstruction::Not>(Op0); +} + +template <typename Op0_t> +inline UnaryVPInstruction_match<Op0_t, VPInstruction::BranchOnCond> +m_BranchOnCond(const Op0_t &Op0) { + return m_VPInstruction<VPInstruction::BranchOnCond>(Op0); +} + +template <typename Op0_t, typename Op1_t> +inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::ActiveLaneMask> +m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1) { + return m_VPInstruction<VPInstruction::ActiveLaneMask>(Op0, Op1); +} + +template <typename Op0_t, typename Op1_t> +inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount> +m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) { + return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1); +} + +template <unsigned Opcode, typename Op0_t> +inline AllUnaryRecipe_match<Op0_t, Opcode> m_Unary(const Op0_t &Op0) { + return AllUnaryRecipe_match<Op0_t, Opcode>(Op0); +} + +template <typename Op0_t> 
+inline AllUnaryRecipe_match<Op0_t, Instruction::Trunc> +m_Trunc(const Op0_t &Op0) { + return m_Unary<Instruction::Trunc, Op0_t>(Op0); +} + +template <typename Op0_t> +inline AllUnaryRecipe_match<Op0_t, Instruction::ZExt> m_ZExt(const Op0_t &Op0) { + return m_Unary<Instruction::ZExt, Op0_t>(Op0); +} + +template <typename Op0_t> +inline AllUnaryRecipe_match<Op0_t, Instruction::SExt> m_SExt(const Op0_t &Op0) { + return m_Unary<Instruction::SExt, Op0_t>(Op0); +} + +template <typename Op0_t> +inline match_combine_or<AllUnaryRecipe_match<Op0_t, Instruction::ZExt>, + AllUnaryRecipe_match<Op0_t, Instruction::SExt>> +m_ZExtOrSExt(const Op0_t &Op0) { + return m_CombineOr(m_ZExt(Op0), m_SExt(Op0)); +} + +template <unsigned Opcode, typename Op0_t, typename Op1_t, + bool Commutative = false> +inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative> +m_Binary(const Op0_t &Op0, const Op1_t &Op1) { + return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode, Commutative>(Op0, Op1); +} + +template <typename Op0_t, typename Op1_t> +inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul> +m_Mul(const Op0_t &Op0, const Op1_t &Op1) { + return m_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1); +} + +template <typename Op0_t, typename Op1_t> +inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul, + /* Commutative =*/true> +m_c_Mul(const Op0_t &Op0, const Op1_t &Op1) { + return m_Binary<Instruction::Mul, Op0_t, Op1_t, true>(Op0, Op1); +} + +/// Match a binary OR operation. Note that while conceptually the operands can +/// be matched commutatively, \p Commutative defaults to false in line with the +/// IR-based pattern matching infrastructure. Use m_c_BinaryOr for a commutative +/// version of the matcher. +template <typename Op0_t, typename Op1_t, bool Commutative = false> +inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, Commutative> +m_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) { + return m_Binary<Instruction::Or, Op0_t, Op1_t, Commutative>(Op0, Op1); +} + +template <typename Op0_t, typename Op1_t> +inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Or, + /*Commutative*/ true> +m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) { + return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1); +} + +template <typename Op0_t, typename Op1_t> +inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd> +m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) { + return m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1); +} + +struct VPCanonicalIVPHI_match { + bool match(const VPValue *V) { + auto *DefR = V->getDefiningRecipe(); + return DefR && match(DefR); + } + + bool match(const VPRecipeBase *R) { return isa<VPCanonicalIVPHIRecipe>(R); } +}; + +inline VPCanonicalIVPHI_match m_CanonicalIV() { + return VPCanonicalIVPHI_match(); +} + +template <typename Op0_t, typename Op1_t> struct VPScalarIVSteps_match { + Op0_t Op0; + Op1_t Op1; + + VPScalarIVSteps_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {} + + bool match(const VPValue *V) { + auto *DefR = V->getDefiningRecipe(); + return DefR && match(DefR); + } + + bool match(const VPRecipeBase *R) { + if (!isa<VPScalarIVStepsRecipe>(R)) + return false; + assert(R->getNumOperands() == 2 && + "VPScalarIVSteps must have exactly 2 operands"); + return Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1)); + } +}; + +template <typename Op0_t, typename Op1_t> +inline VPScalarIVSteps_match<Op0_t, Op1_t> m_ScalarIVSteps(const Op0_t &Op0, + const Op1_t &Op1) { + return VPScalarIVSteps_match<Op0_t, Op1_t>(Op0, Op1); 
+} + +} // namespace VPlanPatternMatch +} // namespace llvm + +#endif diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index ae2fc522ba40..1b787d049067 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -39,6 +39,7 @@ using VectorParts = SmallVector<Value *, 2>; namespace llvm { extern cl::opt<bool> EnableVPlanNativePath; } +extern cl::opt<unsigned> ForceTargetInstructionCost; #define LV_NAME "loop-vectorize" #define DEBUG_TYPE LV_NAME @@ -47,23 +48,29 @@ bool VPRecipeBase::mayWriteToMemory() const { switch (getVPDefID()) { case VPInterleaveSC: return cast<VPInterleaveRecipe>(this)->getNumStoreOperands() > 0; - case VPWidenMemoryInstructionSC: { - return cast<VPWidenMemoryInstructionRecipe>(this)->isStore(); - } + case VPWidenStoreEVLSC: + case VPWidenStoreSC: + return true; case VPReplicateSC: - case VPWidenCallSC: return cast<Instruction>(getVPSingleValue()->getUnderlyingValue()) ->mayWriteToMemory(); + case VPWidenCallSC: + return !cast<VPWidenCallRecipe>(this) + ->getCalledScalarFunction() + ->onlyReadsMemory(); case VPBranchOnMaskSC: case VPScalarIVStepsSC: case VPPredInstPHISC: return false; case VPBlendSC: + case VPReductionEVLSC: case VPReductionSC: case VPWidenCanonicalIVSC: case VPWidenCastSC: case VPWidenGEPSC: case VPWidenIntOrFpInductionSC: + case VPWidenLoadEVLSC: + case VPWidenLoadSC: case VPWidenPHISC: case VPWidenSC: case VPWidenSelectSC: { @@ -81,18 +88,24 @@ bool VPRecipeBase::mayWriteToMemory() const { bool VPRecipeBase::mayReadFromMemory() const { switch (getVPDefID()) { - case VPWidenMemoryInstructionSC: { - return !cast<VPWidenMemoryInstructionRecipe>(this)->isStore(); - } + case VPWidenLoadEVLSC: + case VPWidenLoadSC: + return true; case VPReplicateSC: - case VPWidenCallSC: return cast<Instruction>(getVPSingleValue()->getUnderlyingValue()) ->mayReadFromMemory(); + case VPWidenCallSC: + return !cast<VPWidenCallRecipe>(this) + ->getCalledScalarFunction() + ->onlyWritesMemory(); case VPBranchOnMaskSC: - case VPScalarIVStepsSC: case VPPredInstPHISC: + case VPScalarIVStepsSC: + case VPWidenStoreEVLSC: + case VPWidenStoreSC: return false; case VPBlendSC: + case VPReductionEVLSC: case VPReductionSC: case VPWidenCanonicalIVSC: case VPWidenCastSC: @@ -117,6 +130,7 @@ bool VPRecipeBase::mayHaveSideEffects() const { switch (getVPDefID()) { case VPDerivedIVSC: case VPPredInstPHISC: + case VPScalarCastSC: return false; case VPInstructionSC: switch (cast<VPInstruction>(this)->getOpcode()) { @@ -126,14 +140,20 @@ bool VPRecipeBase::mayHaveSideEffects() const { case VPInstruction::Not: case VPInstruction::CalculateTripCountMinusVF: case VPInstruction::CanonicalIVIncrementForPart: + case VPInstruction::ExtractFromEnd: + case VPInstruction::FirstOrderRecurrenceSplice: + case VPInstruction::LogicalAnd: + case VPInstruction::PtrAdd: return false; default: return true; } - case VPWidenCallSC: - return cast<Instruction>(getVPSingleValue()->getUnderlyingValue()) - ->mayHaveSideEffects(); + case VPWidenCallSC: { + Function *Fn = cast<VPWidenCallRecipe>(this)->getCalledScalarFunction(); + return mayWriteToMemory() || !Fn->doesNotThrow() || !Fn->willReturn(); + } case VPBlendSC: + case VPReductionEVLSC: case VPReductionSC: case VPScalarIVStepsSC: case VPWidenCanonicalIVSC: @@ -153,12 +173,15 @@ bool VPRecipeBase::mayHaveSideEffects() const { } case VPInterleaveSC: return 
mayWriteToMemory(); - case VPWidenMemoryInstructionSC: - assert(cast<VPWidenMemoryInstructionRecipe>(this) - ->getIngredient() - .mayHaveSideEffects() == mayWriteToMemory() && - "mayHaveSideffects result for ingredient differs from this " - "implementation"); + case VPWidenLoadEVLSC: + case VPWidenLoadSC: + case VPWidenStoreEVLSC: + case VPWidenStoreSC: + assert( + cast<VPWidenMemoryRecipe>(this)->getIngredient().mayHaveSideEffects() == + mayWriteToMemory() && + "mayHaveSideffects result for ingredient differs from this " + "implementation"); return mayWriteToMemory(); case VPReplicateSC: { auto *R = cast<VPReplicateRecipe>(this); @@ -170,17 +193,28 @@ bool VPRecipeBase::mayHaveSideEffects() const { } void VPLiveOut::fixPhi(VPlan &Plan, VPTransformState &State) { - auto Lane = VPLane::getLastLaneForVF(State.VF); VPValue *ExitValue = getOperand(0); - if (vputils::isUniformAfterVectorization(ExitValue)) - Lane = VPLane::getFirstLane(); + auto Lane = vputils::isUniformAfterVectorization(ExitValue) + ? VPLane::getFirstLane() + : VPLane::getLastLaneForVF(State.VF); VPBasicBlock *MiddleVPBB = cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSingleSuccessor()); - assert(MiddleVPBB->getNumSuccessors() == 0 && - "the middle block must not have any successors"); - BasicBlock *MiddleBB = State.CFG.VPBB2IRBB[MiddleVPBB]; - Phi->addIncoming(State.get(ExitValue, VPIteration(State.UF - 1, Lane)), - MiddleBB); + VPRecipeBase *ExitingRecipe = ExitValue->getDefiningRecipe(); + auto *ExitingVPBB = ExitingRecipe ? ExitingRecipe->getParent() : nullptr; + // Values leaving the vector loop reach live out phi's in the exiting block + // via middle block. + auto *PredVPBB = !ExitingVPBB || ExitingVPBB->getEnclosingLoopRegion() + ? MiddleVPBB + : ExitingVPBB; + BasicBlock *PredBB = State.CFG.VPBB2IRBB[PredVPBB]; + // Set insertion point in PredBB in case an extract needs to be generated. + // TODO: Model extracts explicitly. + State.Builder.SetInsertPoint(PredBB, PredBB->getFirstNonPHIIt()); + Value *V = State.get(ExitValue, VPIteration(State.UF - 1, Lane)); + if (Phi->getBasicBlockIndex(PredBB) != -1) + Phi->setIncomingValueForBlock(PredBB, V); + else + Phi->addIncoming(V, PredBB); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -197,24 +231,21 @@ void VPRecipeBase::insertBefore(VPRecipeBase *InsertPos) { assert(!Parent && "Recipe already in some VPBasicBlock"); assert(InsertPos->getParent() && "Insertion position not in any VPBasicBlock"); - Parent = InsertPos->getParent(); - Parent->getRecipeList().insert(InsertPos->getIterator(), this); + InsertPos->getParent()->insert(this, InsertPos->getIterator()); } void VPRecipeBase::insertBefore(VPBasicBlock &BB, iplist<VPRecipeBase>::iterator I) { assert(!Parent && "Recipe already in some VPBasicBlock"); assert(I == BB.end() || I->getParent() == &BB); - Parent = &BB; - BB.getRecipeList().insert(I, this); + BB.insert(this, I); } void VPRecipeBase::insertAfter(VPRecipeBase *InsertPos) { assert(!Parent && "Recipe already in some VPBasicBlock"); assert(InsertPos->getParent() && "Insertion position not in any VPBasicBlock"); - Parent = InsertPos->getParent(); - Parent->getRecipeList().insertAfter(InsertPos->getIterator(), this); + InsertPos->getParent()->insert(this, std::next(InsertPos->getIterator())); } void VPRecipeBase::removeFromParent() { @@ -239,6 +270,49 @@ void VPRecipeBase::moveBefore(VPBasicBlock &BB, insertBefore(BB, I); } +/// Return the underlying instruction to be used for computing \p R's cost via +/// the legacy cost model. 
Return nullptr if there's no suitable instruction. +static Instruction *getInstructionForCost(const VPRecipeBase *R) { + if (auto *S = dyn_cast<VPSingleDefRecipe>(R)) + return dyn_cast_or_null<Instruction>(S->getUnderlyingValue()); + if (auto *IG = dyn_cast<VPInterleaveRecipe>(R)) + return IG->getInsertPos(); + if (auto *WidenMem = dyn_cast<VPWidenMemoryRecipe>(R)) + return &WidenMem->getIngredient(); + return nullptr; +} + +InstructionCost VPRecipeBase::cost(ElementCount VF, VPCostContext &Ctx) { + if (auto *UI = getInstructionForCost(this)) + if (Ctx.skipCostComputation(UI, VF.isVector())) + return 0; + + InstructionCost RecipeCost = computeCost(VF, Ctx); + if (ForceTargetInstructionCost.getNumOccurrences() > 0 && + RecipeCost.isValid()) + RecipeCost = InstructionCost(ForceTargetInstructionCost); + + LLVM_DEBUG({ + dbgs() << "Cost of " << RecipeCost << " for VF " << VF << ": "; + dump(); + }); + return RecipeCost; +} + +InstructionCost VPRecipeBase::computeCost(ElementCount VF, + VPCostContext &Ctx) const { + // Compute the cost for the recipe falling back to the legacy cost model using + // the underlying instruction. If there is no underlying instruction, returns + // 0. + Instruction *UI = getInstructionForCost(this); + if (UI && isa<VPReplicateRecipe>(this)) { + // VPReplicateRecipe may be cloned as part of an existing VPlan-to-VPlan + // transform, avoid computing their cost multiple times for now. + Ctx.SkipCostComputation.insert(UI); + } + return UI ? Ctx.getLegacyCost(UI, VF) : 0; +} + FastMathFlags VPRecipeWithIRFlags::getFastMathFlags() const { assert(OpType == OperationType::FPMathOp && "recipe doesn't have fast math flags"); @@ -272,17 +346,46 @@ VPInstruction::VPInstruction(unsigned Opcode, assert(isFPMathOp() && "this op can't take fast-math flags"); } -Value *VPInstruction::generateInstruction(VPTransformState &State, - unsigned Part) { +bool VPInstruction::doesGeneratePerAllLanes() const { + return Opcode == VPInstruction::PtrAdd && !vputils::onlyFirstLaneUsed(this); +} + +bool VPInstruction::canGenerateScalarForFirstLane() const { + if (Instruction::isBinaryOp(getOpcode())) + return true; + if (isSingleScalar() || isVectorToScalar()) + return true; + switch (Opcode) { + case Instruction::ICmp: + case VPInstruction::BranchOnCond: + case VPInstruction::BranchOnCount: + case VPInstruction::CalculateTripCountMinusVF: + case VPInstruction::CanonicalIVIncrementForPart: + case VPInstruction::PtrAdd: + case VPInstruction::ExplicitVectorLength: + return true; + default: + return false; + } +} + +Value *VPInstruction::generatePerLane(VPTransformState &State, + const VPIteration &Lane) { IRBuilderBase &Builder = State.Builder; - Builder.SetCurrentDebugLocation(getDebugLoc()); - if (Instruction::isBinaryOp(getOpcode())) { - if (Part != 0 && vputils::onlyFirstPartUsed(this)) - return State.get(this, 0); + assert(getOpcode() == VPInstruction::PtrAdd && + "only PtrAdd opcodes are supported for now"); + return Builder.CreatePtrAdd(State.get(getOperand(0), Lane), + State.get(getOperand(1), Lane), Name); +} - Value *A = State.get(getOperand(0), Part); - Value *B = State.get(getOperand(1), Part); +Value *VPInstruction::generatePerPart(VPTransformState &State, unsigned Part) { + IRBuilderBase &Builder = State.Builder; + + if (Instruction::isBinaryOp(getOpcode())) { + bool OnlyFirstLaneUsed = vputils::onlyFirstLaneUsed(this); + Value *A = State.get(getOperand(0), Part, OnlyFirstLaneUsed); + Value *B = State.get(getOperand(1), Part, OnlyFirstLaneUsed); auto *Res = 
Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name); if (auto *I = dyn_cast<Instruction>(Res)) @@ -296,8 +399,9 @@ Value *VPInstruction::generateInstruction(VPTransformState &State, return Builder.CreateNot(A, Name); } case Instruction::ICmp: { - Value *A = State.get(getOperand(0), Part); - Value *B = State.get(getOperand(1), Part); + bool OnlyFirstLaneUsed = vputils::onlyFirstLaneUsed(this); + Value *A = State.get(getOperand(0), Part, OnlyFirstLaneUsed); + Value *B = State.get(getOperand(1), Part, OnlyFirstLaneUsed); return Builder.CreateCmp(getPredicate(), A, B, Name); } case Instruction::Select: { @@ -312,6 +416,12 @@ Value *VPInstruction::generateInstruction(VPTransformState &State, // Get the original loop tripcount. Value *ScalarTC = State.get(getOperand(1), VPIteration(Part, 0)); + // If this part of the active lane mask is scalar, generate the CMP directly + // to avoid unnecessary extracts. + if (State.VF.isScalar()) + return Builder.CreateCmp(CmpInst::Predicate::ICMP_ULT, VIVElem0, ScalarTC, + Name); + auto *Int1Ty = Type::getInt1Ty(Builder.getContext()); auto *PredTy = VectorType::get(Int1Ty, State.VF); return Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, @@ -340,6 +450,9 @@ Value *VPInstruction::generateInstruction(VPTransformState &State, return Builder.CreateVectorSplice(PartMinus1, V2, -1, Name); } case VPInstruction::CalculateTripCountMinusVF: { + if (Part != 0) + return State.get(this, 0, /*IsScalar*/ true); + Value *ScalarTC = State.get(getOperand(0), {0, 0}); Value *Step = createStepForVF(Builder, ScalarTC->getType(), State.VF, State.UF); @@ -348,6 +461,31 @@ Value *VPInstruction::generateInstruction(VPTransformState &State, Value *Zero = ConstantInt::get(ScalarTC->getType(), 0); return Builder.CreateSelect(Cmp, Sub, Zero); } + case VPInstruction::ExplicitVectorLength: { + // Compute EVL + auto GetEVL = [=](VPTransformState &State, Value *AVL) { + assert(AVL->getType()->isIntegerTy() && + "Requested vector length should be an integer."); + + // TODO: Add support for MaxSafeDist for correct loop emission. + assert(State.VF.isScalable() && "Expected scalable vector factor."); + Value *VFArg = State.Builder.getInt32(State.VF.getKnownMinValue()); + + Value *EVL = State.Builder.CreateIntrinsic( + State.Builder.getInt32Ty(), Intrinsic::experimental_get_vector_length, + {AVL, VFArg, State.Builder.getTrue()}); + return EVL; + }; + // TODO: Restructure this code with an explicit remainder loop, vsetvli can + // be outside of the main loop. + assert(Part == 0 && "No unrolling expected for predicated vectorization."); + // Compute VTC - IV as the AVL (requested vector length). + Value *Index = State.get(getOperand(0), VPIteration(0, 0)); + Value *TripCount = State.get(getOperand(1), VPIteration(0, 0)); + Value *AVL = State.Builder.CreateSub(TripCount, Index); + Value *EVL = GetEVL(State, AVL); + return EVL; + } case VPInstruction::CanonicalIVIncrementForPart: { auto *IV = State.get(getOperand(0), VPIteration(0, 0)); if (Part == 0) @@ -364,28 +502,28 @@ Value *VPInstruction::generateInstruction(VPTransformState &State, return nullptr; Value *Cond = State.get(getOperand(0), VPIteration(Part, 0)); - VPRegionBlock *ParentRegion = getParent()->getParent(); - VPBasicBlock *Header = ParentRegion->getEntryBasicBlock(); - // Replace the temporary unreachable terminator with a new conditional // branch, hooking it up to backward destination for exiting blocks now and // to forward destination(s) later when they are created. 
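// A worked example of the ExplicitVectorLength lowering above (illustrative;
// exact intrinsic mangling depends on the IV type): with an i64 induction
// variable and a scalable minimum VF of 4, each iteration computes
//   %avl = sub i64 %trip.count, %index
//   %evl = call i32 @llvm.experimental.get.vector.length.i64(i64 %avl,
//                                                            i32 4, i1 true)
// and the resulting i32 %evl then drives the vp.load/vp.store/vp.reduce
// recipes generated for that part.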
BranchInst *CondBr = Builder.CreateCondBr(Cond, Builder.GetInsertBlock(), nullptr); - - if (getParent()->isExiting()) - CondBr->setSuccessor(1, State.CFG.VPBB2IRBB[Header]); - CondBr->setSuccessor(0, nullptr); Builder.GetInsertBlock()->getTerminator()->eraseFromParent(); + + if (!getParent()->isExiting()) + return CondBr; + + VPRegionBlock *ParentRegion = getParent()->getParent(); + VPBasicBlock *Header = ParentRegion->getEntryBasicBlock(); + CondBr->setSuccessor(1, State.CFG.VPBB2IRBB[Header]); return CondBr; } case VPInstruction::BranchOnCount: { if (Part != 0) return nullptr; // First create the compare. - Value *IV = State.get(getOperand(0), Part); - Value *TC = State.get(getOperand(1), Part); + Value *IV = State.get(getOperand(0), Part, /*IsScalar*/ true); + Value *TC = State.get(getOperand(1), Part, /*IsScalar*/ true); Value *Cond = Builder.CreateICmpEQ(IV, TC); // Now create the branch. @@ -406,7 +544,7 @@ Value *VPInstruction::generateInstruction(VPTransformState &State, } case VPInstruction::ComputeReductionResult: { if (Part != 0) - return State.get(this, 0); + return State.get(this, 0, /*IsScalar*/ true); // FIXME: The cross-recipe dependency on VPReductionPHIRecipe is temporary // and will be removed by breaking up the recipe further. @@ -417,13 +555,11 @@ Value *VPInstruction::generateInstruction(VPTransformState &State, RecurKind RK = RdxDesc.getRecurrenceKind(); - State.setDebugLocFrom(getDebugLoc()); - VPValue *LoopExitingDef = getOperand(1); Type *PhiTy = OrigPhi->getType(); VectorParts RdxParts(State.UF); for (unsigned Part = 0; Part < State.UF; ++Part) - RdxParts[Part] = State.get(LoopExitingDef, Part); + RdxParts[Part] = State.get(LoopExitingDef, Part, PhiR->isInLoop()); // If the vector reduction can be performed in a smaller type, we truncate // then extend the loop exit value to enable InstCombine to evaluate the @@ -437,6 +573,8 @@ Value *VPInstruction::generateInstruction(VPTransformState &State, // Reduce all of the unrolled parts into a single vector. Value *ReducedPartRdx = RdxParts[0]; unsigned Op = RecurrenceDescriptor::getOpcode(RK); + if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) + Op = Instruction::Or; if (PhiR->isOrdered()) { ReducedPartRdx = RdxParts[State.UF - 1]; @@ -449,19 +587,16 @@ Value *VPInstruction::generateInstruction(VPTransformState &State, if (Op != Instruction::ICmp && Op != Instruction::FCmp) ReducedPartRdx = Builder.CreateBinOp( (Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx"); - else if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) { - TrackingVH<Value> ReductionStartValue = - RdxDesc.getRecurrenceStartValue(); - ReducedPartRdx = createAnyOfOp(Builder, ReductionStartValue, RK, - ReducedPartRdx, RdxPart); - } else + else ReducedPartRdx = createMinMaxOp(Builder, RK, ReducedPartRdx, RdxPart); } } // Create the reduction after the loop. Note that inloop reductions create // the target reduction in the loop using a Reduction recipe. 
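// The BranchOnCount case above reduces to a scalar compare of the bumped
// induction variable against the vector trip count plus the latch branch,
// roughly (illustrative):
//   %cmp.n = icmp eq i64 %index.next, %n.vec
//   br i1 %cmp.n, label %middle.block, label %vector.body
// Both the IV and the trip count are now fetched as scalars (IsScalar=true)
// instead of being broadcast first.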
- if (State.VF.isVector() && !PhiR->isInLoop()) { + if ((State.VF.isVector() || + RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) && + !PhiR->isInLoop()) { ReducedPartRdx = createTargetReduction(Builder, RdxDesc, ReducedPartRdx, OrigPhi); // If the reduction can be performed in a smaller type, we need to extend @@ -482,11 +617,77 @@ Value *VPInstruction::generateInstruction(VPTransformState &State, return ReducedPartRdx; } + case VPInstruction::ExtractFromEnd: { + if (Part != 0) + return State.get(this, 0, /*IsScalar*/ true); + + auto *CI = cast<ConstantInt>(getOperand(1)->getLiveInIRValue()); + unsigned Offset = CI->getZExtValue(); + assert(Offset > 0 && "Offset from end must be positive"); + Value *Res; + if (State.VF.isVector()) { + assert(Offset <= State.VF.getKnownMinValue() && + "invalid offset to extract from"); + // Extract lane VF - Offset from the operand. + Res = State.get( + getOperand(0), + VPIteration(State.UF - 1, VPLane::getLaneFromEnd(State.VF, Offset))); + } else { + assert(Offset <= State.UF && "invalid offset to extract from"); + // When loop is unrolled without vectorizing, retrieve UF - Offset. + Res = State.get(getOperand(0), State.UF - Offset); + } + if (isa<ExtractElementInst>(Res)) + Res->setName(Name); + return Res; + } + case VPInstruction::LogicalAnd: { + Value *A = State.get(getOperand(0), Part); + Value *B = State.get(getOperand(1), Part); + return Builder.CreateLogicalAnd(A, B, Name); + } + case VPInstruction::PtrAdd: { + assert(vputils::onlyFirstLaneUsed(this) && + "can only generate first lane for PtrAdd"); + Value *Ptr = State.get(getOperand(0), Part, /* IsScalar */ true); + Value *Addend = State.get(getOperand(1), Part, /* IsScalar */ true); + return Builder.CreatePtrAdd(Ptr, Addend, Name); + } + case VPInstruction::ResumePhi: { + if (Part != 0) + return State.get(this, 0, /*IsScalar*/ true); + Value *IncomingFromVPlanPred = + State.get(getOperand(0), Part, /* IsScalar */ true); + Value *IncomingFromOtherPreds = + State.get(getOperand(1), Part, /* IsScalar */ true); + auto *NewPhi = + Builder.CreatePHI(IncomingFromOtherPreds->getType(), 2, Name); + BasicBlock *VPlanPred = + State.CFG + .VPBB2IRBB[cast<VPBasicBlock>(getParent()->getSinglePredecessor())]; + NewPhi->addIncoming(IncomingFromVPlanPred, VPlanPred); + for (auto *OtherPred : predecessors(Builder.GetInsertBlock())) { + assert(OtherPred != VPlanPred && + "VPlan predecessors should not be connected yet"); + NewPhi->addIncoming(IncomingFromOtherPreds, OtherPred); + } + return NewPhi; + } + default: llvm_unreachable("Unsupported opcode for instruction"); } } +bool VPInstruction::isVectorToScalar() const { + return getOpcode() == VPInstruction::ExtractFromEnd || + getOpcode() == VPInstruction::ComputeReductionResult; +} + +bool VPInstruction::isSingleScalar() const { + return getOpcode() == VPInstruction::ResumePhi; +} + #if !defined(NDEBUG) bool VPInstruction::isFPMathOp() const { // Inspired by FPMathOperator::classof. 
Notable differences are that we don't @@ -506,15 +707,86 @@ void VPInstruction::execute(VPTransformState &State) { "Recipe not a FPMathOp but has fast-math flags?"); if (hasFastMathFlags()) State.Builder.setFastMathFlags(getFastMathFlags()); + State.setDebugLocFrom(getDebugLoc()); + bool GeneratesPerFirstLaneOnly = canGenerateScalarForFirstLane() && + (vputils::onlyFirstLaneUsed(this) || + isVectorToScalar() || isSingleScalar()); + bool GeneratesPerAllLanes = doesGeneratePerAllLanes(); + bool OnlyFirstPartUsed = vputils::onlyFirstPartUsed(this); for (unsigned Part = 0; Part < State.UF; ++Part) { - Value *GeneratedValue = generateInstruction(State, Part); + if (GeneratesPerAllLanes) { + for (unsigned Lane = 0, NumLanes = State.VF.getKnownMinValue(); + Lane != NumLanes; ++Lane) { + Value *GeneratedValue = generatePerLane(State, VPIteration(Part, Lane)); + assert(GeneratedValue && "generatePerLane must produce a value"); + State.set(this, GeneratedValue, VPIteration(Part, Lane)); + } + continue; + } + + if (Part != 0 && OnlyFirstPartUsed && hasResult()) { + Value *Part0 = State.get(this, 0, /*IsScalar*/ GeneratesPerFirstLaneOnly); + State.set(this, Part0, Part, + /*IsScalar*/ GeneratesPerFirstLaneOnly); + continue; + } + + Value *GeneratedValue = generatePerPart(State, Part); if (!hasResult()) continue; - assert(GeneratedValue && "generateInstruction must produce a value"); - State.set(this, GeneratedValue, Part); + assert(GeneratedValue && "generatePerPart must produce a value"); + assert((GeneratedValue->getType()->isVectorTy() == + !GeneratesPerFirstLaneOnly || + State.VF.isScalar()) && + "scalar value but not only first lane defined"); + State.set(this, GeneratedValue, Part, + /*IsScalar*/ GeneratesPerFirstLaneOnly); } } +bool VPInstruction::onlyFirstLaneUsed(const VPValue *Op) const { + assert(is_contained(operands(), Op) && "Op must be an operand of the recipe"); + if (Instruction::isBinaryOp(getOpcode())) + return vputils::onlyFirstLaneUsed(this); + + switch (getOpcode()) { + default: + return false; + case Instruction::ICmp: + case VPInstruction::PtrAdd: + // TODO: Cover additional opcodes. 
+ return vputils::onlyFirstLaneUsed(this); + case VPInstruction::ActiveLaneMask: + case VPInstruction::ExplicitVectorLength: + case VPInstruction::CalculateTripCountMinusVF: + case VPInstruction::CanonicalIVIncrementForPart: + case VPInstruction::BranchOnCount: + case VPInstruction::BranchOnCond: + case VPInstruction::ResumePhi: + return true; + }; + llvm_unreachable("switch should return"); +} + +bool VPInstruction::onlyFirstPartUsed(const VPValue *Op) const { + assert(is_contained(operands(), Op) && "Op must be an operand of the recipe"); + if (Instruction::isBinaryOp(getOpcode())) + return vputils::onlyFirstPartUsed(this); + + switch (getOpcode()) { + default: + return false; + case Instruction::ICmp: + case Instruction::Select: + return vputils::onlyFirstPartUsed(this); + case VPInstruction::BranchOnCount: + case VPInstruction::BranchOnCond: + case VPInstruction::CanonicalIVIncrementForPart: + return true; + }; + llvm_unreachable("switch should return"); +} + #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPInstruction::dump() const { VPSlotTracker SlotTracker(getParent()->getPlan()); @@ -543,6 +815,12 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent, case VPInstruction::ActiveLaneMask: O << "active lane mask"; break; + case VPInstruction::ResumePhi: + O << "resume-phi"; + break; + case VPInstruction::ExplicitVectorLength: + O << "EXPLICIT-VECTOR-LENGTH"; + break; case VPInstruction::FirstOrderRecurrenceSplice: O << "first-order splice"; break; @@ -558,9 +836,18 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent, case VPInstruction::BranchOnCount: O << "branch-on-count"; break; + case VPInstruction::ExtractFromEnd: + O << "extract-from-end"; + break; case VPInstruction::ComputeReductionResult: O << "compute-reduction-result"; break; + case VPInstruction::LogicalAnd: + O << "logical-and"; + break; + case VPInstruction::PtrAdd: + O << "ptradd"; + break; default: O << Instruction::getOpcodeName(getOpcode()); } @@ -577,8 +864,8 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent, void VPWidenCallRecipe::execute(VPTransformState &State) { assert(State.VF.isVector() && "not widening"); - auto &CI = *cast<CallInst>(getUnderlyingInstr()); - assert(!isa<DbgInfoIntrinsic>(CI) && + Function *CalledScalarFn = getCalledScalarFunction(); + assert(!isDbgInfoIntrinsic(CalledScalarFn->getIntrinsicID()) && "DbgInfoIntrinsic should have been dropped during VPlan construction"); State.setDebugLocFrom(getDebugLoc()); @@ -591,10 +878,10 @@ void VPWidenCallRecipe::execute(VPTransformState &State) { // Add return type if intrinsic is overloaded on it. if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1)) - TysForDecl.push_back( - VectorType::get(CI.getType()->getScalarType(), State.VF)); + TysForDecl.push_back(VectorType::get( + CalledScalarFn->getReturnType()->getScalarType(), State.VF)); SmallVector<Value *, 4> Args; - for (const auto &I : enumerate(operands())) { + for (const auto &I : enumerate(arg_operands())) { // Some intrinsics have a scalar argument - don't replace it with a // vector. 
Value *Arg; @@ -627,15 +914,19 @@ void VPWidenCallRecipe::execute(VPTransformState &State) { VectorF = Variant; } + auto *CI = cast_or_null<CallInst>(getUnderlyingInstr()); SmallVector<OperandBundleDef, 1> OpBundles; - CI.getOperandBundlesAsDefs(OpBundles); + if (CI) + CI->getOperandBundlesAsDefs(OpBundles); + CallInst *V = State.Builder.CreateCall(VectorF, Args, OpBundles); if (isa<FPMathOperator>(V)) - V->copyFastMathFlags(&CI); + V->copyFastMathFlags(CI); - State.set(this, V, Part); - State.addMetadata(V, &CI); + if (!V->getType()->isVoidTy()) + State.set(this, V, Part); + State.addMetadata(V, CI); } } @@ -644,16 +935,18 @@ void VPWidenCallRecipe::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent << "WIDEN-CALL "; - auto *CI = cast<CallInst>(getUnderlyingInstr()); - if (CI->getType()->isVoidTy()) + Function *CalledFn = getCalledScalarFunction(); + if (CalledFn->getReturnType()->isVoidTy()) O << "void "; else { printAsOperand(O, SlotTracker); O << " = "; } - O << "call @" << CI->getCalledFunction()->getName() << "("; - printOperands(O, SlotTracker); + O << "call @" << CalledFn->getName() << "("; + interleaveComma(arg_operands(), O, [&O, &SlotTracker](VPValue *Op) { + Op->printAsOperand(O, SlotTracker); + }); O << ")"; if (VectorIntrinsicID) @@ -1084,7 +1377,9 @@ bool VPWidenIntOrFpInductionRecipe::isCanonical() const { return false; auto *StepC = dyn_cast<ConstantInt>(getStepValue()->getLiveInIRValue()); auto *StartC = dyn_cast<ConstantInt>(getStartValue()->getLiveInIRValue()); - return StartC && StartC->isZero() && StepC && StepC->isOne(); + auto *CanIV = cast<VPCanonicalIVPHIRecipe>(&*getParent()->begin()); + return StartC && StartC->isZero() && StepC && StepC->isOne() && + getScalarType() == CanIV->getScalarType(); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -1095,12 +1390,9 @@ void VPDerivedIVRecipe::print(raw_ostream &O, const Twine &Indent, O << Indent << "= DERIVED-IV "; getStartValue()->printAsOperand(O, SlotTracker); O << " + "; - getCanonicalIV()->printAsOperand(O, SlotTracker); + getOperand(1)->printAsOperand(O, SlotTracker); O << " * "; getStepValue()->printAsOperand(O, SlotTracker); - - if (TruncResultTy) - O << " (truncated to " << *TruncResultTy << ")"; } #endif @@ -1119,13 +1411,7 @@ void VPScalarIVStepsRecipe::execute(VPTransformState &State) { // Ensure step has the same type as that of scalar IV. Type *BaseIVTy = BaseIV->getType()->getScalarType(); - if (BaseIVTy != Step->getType()) { - // TODO: Also use VPDerivedIVRecipe when only the step needs truncating, to - // avoid separate truncate here. - assert(Step->getType()->isIntegerTy() && - "Truncation requires an integer step"); - Step = State.Builder.CreateTrunc(Step, BaseIVTy); - } + assert(BaseIVTy == Step->getType() && "Types of BaseIV and Step must match!"); // We build scalar steps for both integer and floating-point induction // variables. Here, we determine the kind of arithmetic we will perform. @@ -1306,7 +1592,7 @@ void VPVectorPointerRecipe ::execute(VPTransformState &State) { // Use i32 for the gep index type when the value is constant, // or query DataLayout for a more suitable index type otherwise. const DataLayout &DL = - Builder.GetInsertBlock()->getModule()->getDataLayout(); + Builder.GetInsertBlock()->getDataLayout(); Type *IndexTy = State.VF.isScalable() && (IsReverse || Part > 0) ? 
DL.getIndexType(IndexedTy->getPointerTo()) : Builder.getInt32Ty(); @@ -1331,7 +1617,7 @@ void VPVectorPointerRecipe ::execute(VPTransformState &State) { PartPtr = Builder.CreateGEP(IndexedTy, Ptr, Increment, "", InBounds); } - State.set(this, PartPtr, Part); + State.set(this, PartPtr, Part, /*IsScalar*/ true); } } @@ -1367,24 +1653,25 @@ void VPBlendRecipe::execute(VPTransformState &State) { // Note that Mask0 is never used: lanes for which no path reaches this phi and // are essentially undef are taken from In0. VectorParts Entry(State.UF); - for (unsigned In = 0; In < NumIncoming; ++In) { - for (unsigned Part = 0; Part < State.UF; ++Part) { - // We might have single edge PHIs (blocks) - use an identity - // 'select' for the first PHI operand. - Value *In0 = State.get(getIncomingValue(In), Part); - if (In == 0) - Entry[Part] = In0; // Initialize with the first incoming value. - else { - // Select between the current value and the previous incoming edge - // based on the incoming mask. - Value *Cond = State.get(getMask(In), Part); - Entry[Part] = - State.Builder.CreateSelect(Cond, In0, Entry[Part], "predphi"); - } - } - } + bool OnlyFirstLaneUsed = vputils::onlyFirstLaneUsed(this); + for (unsigned In = 0; In < NumIncoming; ++In) { + for (unsigned Part = 0; Part < State.UF; ++Part) { + // We might have single edge PHIs (blocks) - use an identity + // 'select' for the first PHI operand. + Value *In0 = State.get(getIncomingValue(In), Part, OnlyFirstLaneUsed); + if (In == 0) + Entry[Part] = In0; // Initialize with the first incoming value. + else { + // Select between the current value and the previous incoming edge + // based on the incoming mask. + Value *Cond = State.get(getMask(In), Part, OnlyFirstLaneUsed); + Entry[Part] = + State.Builder.CreateSelect(Cond, In0, Entry[Part], "predphi"); + } + } + } for (unsigned Part = 0; Part < State.UF; ++Part) - State.set(this, Entry[Part], Part); + State.set(this, Entry[Part], Part, OnlyFirstLaneUsed); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -1402,12 +1689,105 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent, for (unsigned I = 0, E = getNumIncomingValues(); I < E; ++I) { O << " "; getIncomingValue(I)->printAsOperand(O, SlotTracker); + if (I == 0) + continue; O << "/"; getMask(I)->printAsOperand(O, SlotTracker); } } } +#endif + +void VPReductionRecipe::execute(VPTransformState &State) { + assert(!State.Instance && "Reduction being replicated."); + Value *PrevInChain = State.get(getChainOp(), 0, /*IsScalar*/ true); + RecurKind Kind = RdxDesc.getRecurrenceKind(); + // Propagate the fast-math flags carried by the underlying instruction. + IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder); + State.Builder.setFastMathFlags(RdxDesc.getFastMathFlags()); + for (unsigned Part = 0; Part < State.UF; ++Part) { + Value *NewVecOp = State.get(getVecOp(), Part); + if (VPValue *Cond = getCondOp()) { + Value *NewCond = State.get(Cond, Part, State.VF.isScalar()); + VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType()); + Type *ElementTy = VecTy ? 
VecTy->getElementType() : NewVecOp->getType(); + Value *Iden = RdxDesc.getRecurrenceIdentity(Kind, ElementTy, + RdxDesc.getFastMathFlags()); + if (State.VF.isVector()) { + Iden = State.Builder.CreateVectorSplat(VecTy->getElementCount(), Iden); + } + + Value *Select = State.Builder.CreateSelect(NewCond, NewVecOp, Iden); + NewVecOp = Select; + } + Value *NewRed; + Value *NextInChain; + if (IsOrdered) { + if (State.VF.isVector()) + NewRed = createOrderedReduction(State.Builder, RdxDesc, NewVecOp, + PrevInChain); + else + NewRed = State.Builder.CreateBinOp( + (Instruction::BinaryOps)RdxDesc.getOpcode(Kind), PrevInChain, + NewVecOp); + PrevInChain = NewRed; + } else { + PrevInChain = State.get(getChainOp(), Part, /*IsScalar*/ true); + NewRed = createTargetReduction(State.Builder, RdxDesc, NewVecOp); + } + if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) { + NextInChain = createMinMaxOp(State.Builder, RdxDesc.getRecurrenceKind(), + NewRed, PrevInChain); + } else if (IsOrdered) + NextInChain = NewRed; + else + NextInChain = State.Builder.CreateBinOp( + (Instruction::BinaryOps)RdxDesc.getOpcode(Kind), NewRed, PrevInChain); + State.set(this, NextInChain, Part, /*IsScalar*/ true); + } +} + +void VPReductionEVLRecipe::execute(VPTransformState &State) { + assert(!State.Instance && "Reduction being replicated."); + assert(State.UF == 1 && + "Expected only UF == 1 when vectorizing with explicit vector length."); + + auto &Builder = State.Builder; + // Propagate the fast-math flags carried by the underlying instruction. + IRBuilderBase::FastMathFlagGuard FMFGuard(Builder); + const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor(); + Builder.setFastMathFlags(RdxDesc.getFastMathFlags()); + + RecurKind Kind = RdxDesc.getRecurrenceKind(); + Value *Prev = State.get(getChainOp(), 0, /*IsScalar*/ true); + Value *VecOp = State.get(getVecOp(), 0); + Value *EVL = State.get(getEVL(), VPIteration(0, 0)); + + VectorBuilder VBuilder(Builder); + VBuilder.setEVL(EVL); + Value *Mask; + // TODO: move the all-true mask generation into VectorBuilder. + if (VPValue *CondOp = getCondOp()) + Mask = State.get(CondOp, 0); + else + Mask = Builder.CreateVectorSplat(State.VF, Builder.getTrue()); + VBuilder.setMask(Mask); + + Value *NewRed; + if (isOrdered()) { + NewRed = createOrderedReduction(VBuilder, RdxDesc, VecOp, Prev); + } else { + NewRed = createSimpleTargetReduction(VBuilder, VecOp, RdxDesc); + if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) + NewRed = createMinMaxOp(Builder, Kind, NewRed, Prev); + else + NewRed = Builder.CreateBinOp( + (Instruction::BinaryOps)RdxDesc.getOpcode(Kind), NewRed, Prev); + } + State.set(this, NewRed, 0, /*IsScalar*/ true); +} +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent, VPSlotTracker &SlotTracker) const { O << Indent << "REDUCE "; @@ -1419,7 +1799,31 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent, O << getUnderlyingInstr()->getFastMathFlags(); O << " reduce." 
<< Instruction::getOpcodeName(RdxDesc.getOpcode()) << " ("; getVecOp()->printAsOperand(O, SlotTracker); - if (getCondOp()) { + if (isConditional()) { + O << ", "; + getCondOp()->printAsOperand(O, SlotTracker); + } + O << ")"; + if (RdxDesc.IntermediateStore) + O << " (with final reduction value stored in invariant address sank " + "outside of loop)"; +} + +void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor(); + O << Indent << "REDUCE "; + printAsOperand(O, SlotTracker); + O << " = "; + getChainOp()->printAsOperand(O, SlotTracker); + O << " +"; + if (isa<FPMathOperator>(getUnderlyingInstr())) + O << getUnderlyingInstr()->getFastMathFlags(); + O << " vp.reduce." << Instruction::getOpcodeName(RdxDesc.getOpcode()) << " ("; + getVecOp()->printAsOperand(O, SlotTracker); + O << ", "; + getEVL()->printAsOperand(O, SlotTracker); + if (isConditional()) { O << ", "; getCondOp()->printAsOperand(O, SlotTracker); } @@ -1471,6 +1875,58 @@ void VPReplicateRecipe::print(raw_ostream &O, const Twine &Indent, } #endif +/// Checks if \p C is uniform across all VFs and UFs. It is considered as such +/// if it is either defined outside the vector region or its operand is known to +/// be uniform across all VFs and UFs (e.g. VPDerivedIV or VPCanonicalIVPHI). +/// TODO: Uniformity should be associated with a VPValue and there should be a +/// generic way to check. +static bool isUniformAcrossVFsAndUFs(VPScalarCastRecipe *C) { + return C->isDefinedOutsideVectorRegions() || + isa<VPDerivedIVRecipe>(C->getOperand(0)) || + isa<VPCanonicalIVPHIRecipe>(C->getOperand(0)); +} + +Value *VPScalarCastRecipe ::generate(VPTransformState &State, unsigned Part) { + assert(vputils::onlyFirstLaneUsed(this) && + "Codegen only implemented for first lane."); + switch (Opcode) { + case Instruction::SExt: + case Instruction::ZExt: + case Instruction::Trunc: { + // Note: SExt/ZExt not used yet. + Value *Op = State.get(getOperand(0), VPIteration(Part, 0)); + return State.Builder.CreateCast(Instruction::CastOps(Opcode), Op, ResultTy); + } + default: + llvm_unreachable("opcode not implemented yet"); + } +} + +void VPScalarCastRecipe ::execute(VPTransformState &State) { + bool IsUniformAcrossVFsAndUFs = isUniformAcrossVFsAndUFs(this); + for (unsigned Part = 0; Part != State.UF; ++Part) { + Value *Res; + // Only generate a single instance, if the recipe is uniform across UFs and + // VFs. 
+ if (Part > 0 && IsUniformAcrossVFsAndUFs) + Res = State.get(this, VPIteration(0, 0)); + else + Res = generate(State, Part); + State.set(this, Res, VPIteration(Part, 0)); + } +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPScalarCastRecipe ::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "SCALAR-CAST "; + printAsOperand(O, SlotTracker); + O << " = " << Instruction::getOpcodeName(Opcode) << " "; + printOperands(O, SlotTracker); + O << " to " << *ResultTy; +} +#endif + void VPBranchOnMaskRecipe::execute(VPTransformState &State) { assert(State.Instance && "Branch on Mask works only on single instance."); @@ -1552,18 +2008,400 @@ void VPPredInstPHIRecipe::print(raw_ostream &O, const Twine &Indent, printOperands(O, SlotTracker); } -void VPWidenMemoryInstructionRecipe::print(raw_ostream &O, const Twine &Indent, - VPSlotTracker &SlotTracker) const { +void VPWidenLoadRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { O << Indent << "WIDEN "; + printAsOperand(O, SlotTracker); + O << " = load "; + printOperands(O, SlotTracker); +} - if (!isStore()) { - getVPSingleValue()->printAsOperand(O, SlotTracker); - O << " = "; - } - O << Instruction::getOpcodeName(Ingredient.getOpcode()) << " "; +void VPWidenLoadEVLRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "WIDEN "; + printAsOperand(O, SlotTracker); + O << " = vp.load "; + printOperands(O, SlotTracker); +} +void VPWidenStoreRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "WIDEN store "; printOperands(O, SlotTracker); } + +void VPWidenStoreEVLRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "WIDEN vp.store "; + printOperands(O, SlotTracker); +} +#endif + +static Value *createBitOrPointerCast(IRBuilderBase &Builder, Value *V, + VectorType *DstVTy, const DataLayout &DL) { + // Verify that V is a vector type with same number of elements as DstVTy. + auto VF = DstVTy->getElementCount(); + auto *SrcVecTy = cast<VectorType>(V->getType()); + assert(VF == SrcVecTy->getElementCount() && "Vector dimensions do not match"); + Type *SrcElemTy = SrcVecTy->getElementType(); + Type *DstElemTy = DstVTy->getElementType(); + assert((DL.getTypeSizeInBits(SrcElemTy) == DL.getTypeSizeInBits(DstElemTy)) && + "Vector elements must have same size"); + + // Do a direct cast if element types are castable. + if (CastInst::isBitOrNoopPointerCastable(SrcElemTy, DstElemTy, DL)) { + return Builder.CreateBitOrPointerCast(V, DstVTy); + } + // V cannot be directly casted to desired vector type. + // May happen when V is a floating point vector but DstVTy is a vector of + // pointers or vice-versa. Handle this using a two-step bitcast using an + // intermediate Integer type for the bitcast i.e. Ptr <-> Int <-> Float. + assert((DstElemTy->isPointerTy() != SrcElemTy->isPointerTy()) && + "Only one type should be a pointer type"); + assert((DstElemTy->isFloatingPointTy() != SrcElemTy->isFloatingPointTy()) && + "Only one type should be a floating point type"); + Type *IntTy = + IntegerType::getIntNTy(V->getContext(), DL.getTypeSizeInBits(SrcElemTy)); + auto *VecIntTy = VectorType::get(IntTy, VF); + Value *CastVal = Builder.CreateBitOrPointerCast(V, VecIntTy); + return Builder.CreateBitOrPointerCast(CastVal, DstVTy); +} + +/// Return a vector containing interleaved elements from multiple +/// smaller input vectors. 
+static Value *interleaveVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vals, + const Twine &Name) { + unsigned Factor = Vals.size(); + assert(Factor > 1 && "Tried to interleave invalid number of vectors"); + + VectorType *VecTy = cast<VectorType>(Vals[0]->getType()); +#ifndef NDEBUG + for (Value *Val : Vals) + assert(Val->getType() == VecTy && "Tried to interleave mismatched types"); +#endif + + // Scalable vectors cannot use arbitrary shufflevectors (only splats), so + // must use intrinsics to interleave. + if (VecTy->isScalableTy()) { + VectorType *WideVecTy = VectorType::getDoubleElementsVectorType(VecTy); + return Builder.CreateIntrinsic(WideVecTy, Intrinsic::vector_interleave2, + Vals, + /*FMFSource=*/nullptr, Name); + } + + // Fixed length. Start by concatenating all vectors into a wide vector. + Value *WideVec = concatenateVectors(Builder, Vals); + + // Interleave the elements into the wide vector. + const unsigned NumElts = VecTy->getElementCount().getFixedValue(); + return Builder.CreateShuffleVector( + WideVec, createInterleaveMask(NumElts, Factor), Name); +} + +// Try to vectorize the interleave group that \p Instr belongs to. +// +// E.g. Translate following interleaved load group (factor = 3): +// for (i = 0; i < N; i+=3) { +// R = Pic[i]; // Member of index 0 +// G = Pic[i+1]; // Member of index 1 +// B = Pic[i+2]; // Member of index 2 +// ... // do something to R, G, B +// } +// To: +// %wide.vec = load <12 x i32> ; Read 4 tuples of R,G,B +// %R.vec = shuffle %wide.vec, poison, <0, 3, 6, 9> ; R elements +// %G.vec = shuffle %wide.vec, poison, <1, 4, 7, 10> ; G elements +// %B.vec = shuffle %wide.vec, poison, <2, 5, 8, 11> ; B elements +// +// Or translate following interleaved store group (factor = 3): +// for (i = 0; i < N; i+=3) { +// ... do something to R, G, B +// Pic[i] = R; // Member of index 0 +// Pic[i+1] = G; // Member of index 1 +// Pic[i+2] = B; // Member of index 2 +// } +// To: +// %R_G.vec = shuffle %R.vec, %G.vec, <0, 1, 2, ..., 7> +// %B_U.vec = shuffle %B.vec, poison, <0, 1, 2, 3, u, u, u, u> +// %interleaved.vec = shuffle %R_G.vec, %B_U.vec, +// <0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11> ; Interleave R,G,B elements +// store <12 x i32> %interleaved.vec ; Write 4 tuples of R,G,B +void VPInterleaveRecipe::execute(VPTransformState &State) { + assert(!State.Instance && "Interleave group being replicated."); + const InterleaveGroup<Instruction> *Group = IG; + Instruction *Instr = Group->getInsertPos(); + + // Prepare for the vector type of the interleaved load/store. + Type *ScalarTy = getLoadStoreType(Instr); + unsigned InterleaveFactor = Group->getFactor(); + auto *VecTy = VectorType::get(ScalarTy, State.VF * InterleaveFactor); + + // Prepare for the new pointers. + SmallVector<Value *, 2> AddrParts; + unsigned Index = Group->getIndex(Instr); + + // TODO: extend the masked interleaved-group support to reversed access. + VPValue *BlockInMask = getMask(); + assert((!BlockInMask || !Group->isReverse()) && + "Reversed masked interleave-group not supported."); + + Value *Idx; + // If the group is reverse, adjust the index to refer to the last vector lane + // instead of the first. We adjust the index from the first vector lane, + // rather than directly getting the pointer for lane VF - 1, because the + // pointer operand of the interleaved access is supposed to be uniform. For + // uniform instructions, we're only required to generate a value for the + // first vector lane in each unroll iteration. 
+ if (Group->isReverse()) { + Value *RuntimeVF = + getRuntimeVF(State.Builder, State.Builder.getInt32Ty(), State.VF); + Idx = State.Builder.CreateSub(RuntimeVF, State.Builder.getInt32(1)); + Idx = State.Builder.CreateMul(Idx, + State.Builder.getInt32(Group->getFactor())); + Idx = State.Builder.CreateAdd(Idx, State.Builder.getInt32(Index)); + Idx = State.Builder.CreateNeg(Idx); + } else + Idx = State.Builder.getInt32(-Index); + + VPValue *Addr = getAddr(); + for (unsigned Part = 0; Part < State.UF; Part++) { + Value *AddrPart = State.get(Addr, VPIteration(Part, 0)); + if (auto *I = dyn_cast<Instruction>(AddrPart)) + State.setDebugLocFrom(I->getDebugLoc()); + + // Notice current instruction could be any index. Need to adjust the address + // to the member of index 0. + // + // E.g. a = A[i+1]; // Member of index 1 (Current instruction) + // b = A[i]; // Member of index 0 + // Current pointer is pointed to A[i+1], adjust it to A[i]. + // + // E.g. A[i+1] = a; // Member of index 1 + // A[i] = b; // Member of index 0 + // A[i+2] = c; // Member of index 2 (Current instruction) + // Current pointer is pointed to A[i+2], adjust it to A[i]. + + bool InBounds = false; + if (auto *gep = dyn_cast<GetElementPtrInst>(AddrPart->stripPointerCasts())) + InBounds = gep->isInBounds(); + AddrPart = State.Builder.CreateGEP(ScalarTy, AddrPart, Idx, "", InBounds); + AddrParts.push_back(AddrPart); + } + + State.setDebugLocFrom(Instr->getDebugLoc()); + Value *PoisonVec = PoisonValue::get(VecTy); + + auto CreateGroupMask = [&BlockInMask, &State, &InterleaveFactor]( + unsigned Part, Value *MaskForGaps) -> Value * { + if (State.VF.isScalable()) { + assert(!MaskForGaps && "Interleaved groups with gaps are not supported."); + assert(InterleaveFactor == 2 && + "Unsupported deinterleave factor for scalable vectors"); + auto *BlockInMaskPart = State.get(BlockInMask, Part); + SmallVector<Value *, 2> Ops = {BlockInMaskPart, BlockInMaskPart}; + auto *MaskTy = VectorType::get(State.Builder.getInt1Ty(), + State.VF.getKnownMinValue() * 2, true); + return State.Builder.CreateIntrinsic( + MaskTy, Intrinsic::vector_interleave2, Ops, + /*FMFSource=*/nullptr, "interleaved.mask"); + } + + if (!BlockInMask) + return MaskForGaps; + + Value *BlockInMaskPart = State.get(BlockInMask, Part); + Value *ShuffledMask = State.Builder.CreateShuffleVector( + BlockInMaskPart, + createReplicatedMask(InterleaveFactor, State.VF.getKnownMinValue()), + "interleaved.mask"); + return MaskForGaps ? State.Builder.CreateBinOp(Instruction::And, + ShuffledMask, MaskForGaps) + : ShuffledMask; + }; + + const DataLayout &DL = Instr->getDataLayout(); + // Vectorize the interleaved load group. + if (isa<LoadInst>(Instr)) { + Value *MaskForGaps = nullptr; + if (NeedsMaskForGaps) { + MaskForGaps = createBitMaskForGaps(State.Builder, + State.VF.getKnownMinValue(), *Group); + assert(MaskForGaps && "Mask for Gaps is required but it is null"); + } + + // For each unroll part, create a wide load for the group. 
+ SmallVector<Value *, 2> NewLoads; + for (unsigned Part = 0; Part < State.UF; Part++) { + Instruction *NewLoad; + if (BlockInMask || MaskForGaps) { + Value *GroupMask = CreateGroupMask(Part, MaskForGaps); + NewLoad = State.Builder.CreateMaskedLoad(VecTy, AddrParts[Part], + Group->getAlign(), GroupMask, + PoisonVec, "wide.masked.vec"); + } else + NewLoad = State.Builder.CreateAlignedLoad( + VecTy, AddrParts[Part], Group->getAlign(), "wide.vec"); + Group->addMetadata(NewLoad); + NewLoads.push_back(NewLoad); + } + + ArrayRef<VPValue *> VPDefs = definedValues(); + const DataLayout &DL = State.CFG.PrevBB->getDataLayout(); + if (VecTy->isScalableTy()) { + assert(InterleaveFactor == 2 && + "Unsupported deinterleave factor for scalable vectors"); + + for (unsigned Part = 0; Part < State.UF; ++Part) { + // Scalable vectors cannot use arbitrary shufflevectors (only splats), + // so must use intrinsics to deinterleave. + Value *DI = State.Builder.CreateIntrinsic( + Intrinsic::vector_deinterleave2, VecTy, NewLoads[Part], + /*FMFSource=*/nullptr, "strided.vec"); + unsigned J = 0; + for (unsigned I = 0; I < InterleaveFactor; ++I) { + Instruction *Member = Group->getMember(I); + + if (!Member) + continue; + + Value *StridedVec = State.Builder.CreateExtractValue(DI, I); + // If this member has different type, cast the result type. + if (Member->getType() != ScalarTy) { + VectorType *OtherVTy = VectorType::get(Member->getType(), State.VF); + StridedVec = + createBitOrPointerCast(State.Builder, StridedVec, OtherVTy, DL); + } + + if (Group->isReverse()) + StridedVec = + State.Builder.CreateVectorReverse(StridedVec, "reverse"); + + State.set(VPDefs[J], StridedVec, Part); + ++J; + } + } + + return; + } + + // For each member in the group, shuffle out the appropriate data from the + // wide loads. + unsigned J = 0; + for (unsigned I = 0; I < InterleaveFactor; ++I) { + Instruction *Member = Group->getMember(I); + + // Skip the gaps in the group. + if (!Member) + continue; + + auto StrideMask = + createStrideMask(I, InterleaveFactor, State.VF.getKnownMinValue()); + for (unsigned Part = 0; Part < State.UF; Part++) { + Value *StridedVec = State.Builder.CreateShuffleVector( + NewLoads[Part], StrideMask, "strided.vec"); + + // If this member has different type, cast the result type. + if (Member->getType() != ScalarTy) { + assert(!State.VF.isScalable() && "VF is assumed to be non scalable."); + VectorType *OtherVTy = VectorType::get(Member->getType(), State.VF); + StridedVec = + createBitOrPointerCast(State.Builder, StridedVec, OtherVTy, DL); + } + + if (Group->isReverse()) + StridedVec = State.Builder.CreateVectorReverse(StridedVec, "reverse"); + + State.set(VPDefs[J], StridedVec, Part); + } + ++J; + } + return; + } + + // The sub vector type for current instruction. + auto *SubVT = VectorType::get(ScalarTy, State.VF); + + // Vectorize the interleaved store group. + Value *MaskForGaps = + createBitMaskForGaps(State.Builder, State.VF.getKnownMinValue(), *Group); + assert((!MaskForGaps || !State.VF.isScalable()) && + "masking gaps for scalable vectors is not yet supported."); + ArrayRef<VPValue *> StoredValues = getStoredValues(); + for (unsigned Part = 0; Part < State.UF; Part++) { + // Collect the stored vector from each member. 
+ SmallVector<Value *, 4> StoredVecs; + unsigned StoredIdx = 0; + for (unsigned i = 0; i < InterleaveFactor; i++) { + assert((Group->getMember(i) || MaskForGaps) && + "Fail to get a member from an interleaved store group"); + Instruction *Member = Group->getMember(i); + + // Skip the gaps in the group. + if (!Member) { + Value *Undef = PoisonValue::get(SubVT); + StoredVecs.push_back(Undef); + continue; + } + + Value *StoredVec = State.get(StoredValues[StoredIdx], Part); + ++StoredIdx; + + if (Group->isReverse()) + StoredVec = State.Builder.CreateVectorReverse(StoredVec, "reverse"); + + // If this member has different type, cast it to a unified type. + + if (StoredVec->getType() != SubVT) + StoredVec = createBitOrPointerCast(State.Builder, StoredVec, SubVT, DL); + + StoredVecs.push_back(StoredVec); + } + + // Interleave all the smaller vectors into one wider vector. + Value *IVec = + interleaveVectors(State.Builder, StoredVecs, "interleaved.vec"); + Instruction *NewStoreInstr; + if (BlockInMask || MaskForGaps) { + Value *GroupMask = CreateGroupMask(Part, MaskForGaps); + NewStoreInstr = State.Builder.CreateMaskedStore( + IVec, AddrParts[Part], Group->getAlign(), GroupMask); + } else + NewStoreInstr = State.Builder.CreateAlignedStore(IVec, AddrParts[Part], + Group->getAlign()); + + Group->addMetadata(NewStoreInstr); + } +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPInterleaveRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "INTERLEAVE-GROUP with factor " << IG->getFactor() << " at "; + IG->getInsertPos()->printAsOperand(O, false); + O << ", "; + getAddr()->printAsOperand(O, SlotTracker); + VPValue *Mask = getMask(); + if (Mask) { + O << ", "; + Mask->printAsOperand(O, SlotTracker); + } + + unsigned OpIdx = 0; + for (unsigned i = 0; i < IG->getFactor(); ++i) { + if (!IG->getMember(i)) + continue; + if (getNumStoreOperands() > 0) { + O << "\n" << Indent << " store "; + getOperand(1 + OpIdx)->printAsOperand(O, SlotTracker); + O << " to index " << i; + } else { + O << "\n" << Indent << " "; + getVPValue(OpIdx)->printAsOperand(O, SlotTracker); + O << " = load from index " << i; + } + ++OpIdx; + } +} #endif void VPCanonicalIVPHIRecipe::execute(VPTransformState &State) { @@ -1575,7 +2413,7 @@ void VPCanonicalIVPHIRecipe::execute(VPTransformState &State) { EntryPart->addIncoming(Start, VectorPH); EntryPart->setDebugLoc(getDebugLoc()); for (unsigned Part = 0, UF = State.UF; Part < UF; ++Part) - State.set(this, EntryPart, Part); + State.set(this, EntryPart, Part, /*IsScalar*/ true); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -1589,10 +2427,10 @@ void VPCanonicalIVPHIRecipe::print(raw_ostream &O, const Twine &Indent, #endif bool VPCanonicalIVPHIRecipe::isCanonical( - InductionDescriptor::InductionKind Kind, VPValue *Start, VPValue *Step, - Type *Ty) const { - // The types must match and it must be an integer induction. - if (Ty != getScalarType() || Kind != InductionDescriptor::IK_IntInduction) + InductionDescriptor::InductionKind Kind, VPValue *Start, + VPValue *Step) const { + // Must be an integer induction. + if (Kind != InductionDescriptor::IK_IntInduction) return false; // Start must match the start value of this canonical induction. 
if (Start != getStartValue()) @@ -1606,9 +2444,9 @@ bool VPCanonicalIVPHIRecipe::isCanonical( return StepC && StepC->isOne(); } -bool VPWidenPointerInductionRecipe::onlyScalarsGenerated(ElementCount VF) { +bool VPWidenPointerInductionRecipe::onlyScalarsGenerated(bool IsScalable) { return IsScalarAfterVectorization && - (!VF.isScalable() || vputils::onlyFirstLaneUsed(this)); + (!IsScalable || vputils::onlyFirstLaneUsed(this)); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -1624,7 +2462,7 @@ void VPWidenPointerInductionRecipe::print(raw_ostream &O, const Twine &Indent, void VPExpandSCEVRecipe::execute(VPTransformState &State) { assert(!State.Instance && "cannot be used in per-lane"); - const DataLayout &DL = State.CFG.PrevBB->getModule()->getDataLayout(); + const DataLayout &DL = State.CFG.PrevBB->getDataLayout(); SCEVExpander Exp(SE, DL, "induction"); Value *Res = Exp.expandCodeFor(Expr, Expr->getType(), @@ -1646,7 +2484,7 @@ void VPExpandSCEVRecipe::print(raw_ostream &O, const Twine &Indent, #endif void VPWidenCanonicalIVRecipe::execute(VPTransformState &State) { - Value *CanonicalIV = State.get(getOperand(0), 0); + Value *CanonicalIV = State.get(getOperand(0), 0, /*IsScalar*/ true); Type *STy = CanonicalIV->getType(); IRBuilder<> Builder(State.CFG.PrevBB->getTerminator()); ElementCount VF = State.VF; @@ -1736,7 +2574,7 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) { for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) { Instruction *EntryPart = PHINode::Create(VecTy, 2, "vec.phi"); EntryPart->insertBefore(HeaderBB->getFirstInsertionPt()); - State.set(this, EntryPart, Part); + State.set(this, EntryPart, Part, IsInLoop); } BasicBlock *VectorPH = State.CFG.getPreheaderBBFor(this); @@ -1768,7 +2606,7 @@ void VPReductionPHIRecipe::execute(VPTransformState &State) { } for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) { - Value *EntryPart = State.get(this, Part); + Value *EntryPart = State.get(this, Part, IsInLoop); // Make sure to add the reduction start value only to the // first unroll part. Value *StartVal = (Part == 0) ? 
StartV : Iden; @@ -1842,3 +2680,25 @@ void VPActiveLaneMaskPHIRecipe::print(raw_ostream &O, const Twine &Indent, printOperands(O, SlotTracker); } #endif + +void VPEVLBasedIVPHIRecipe::execute(VPTransformState &State) { + BasicBlock *VectorPH = State.CFG.getPreheaderBBFor(this); + assert(State.UF == 1 && "Expected unroll factor 1 for VP vectorization."); + Value *Start = State.get(getOperand(0), VPIteration(0, 0)); + PHINode *EntryPart = + State.Builder.CreatePHI(Start->getType(), 2, "evl.based.iv"); + EntryPart->addIncoming(Start, VectorPH); + EntryPart->setDebugLoc(getDebugLoc()); + State.set(this, EntryPart, 0, /*IsScalar=*/true); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void VPEVLBasedIVPHIRecipe::print(raw_ostream &O, const Twine &Indent, + VPSlotTracker &SlotTracker) const { + O << Indent << "EXPLICIT-VECTOR-LENGTH-BASED-IV-PHI "; + + printAsOperand(O, SlotTracker); + O << " = phi "; + printOperands(O, SlotTracker); +} +#endif diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp index fbcadba33e67..98ccf2169463 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanSLP.cpp @@ -461,7 +461,6 @@ VPInstruction *VPlanSlp::buildGraph(ArrayRef<VPValue *> Values) { assert(CombinedOperands.size() > 0 && "Need more some operands"); auto *Inst = cast<VPInstruction>(Values[0])->getUnderlyingInstr(); auto *VPI = new VPInstruction(Opcode, CombinedOperands, Inst->getDebugLoc()); - VPI->setUnderlyingInstr(Inst); LLVM_DEBUG(dbgs() << "Create VPInstruction " << *VPI << " " << *cast<VPInstruction>(Values[0]) << "\n"); diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 8e6b48cdb2c8..c91fd0f118e3 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -16,6 +16,7 @@ #include "VPlanAnalysis.h" #include "VPlanCFG.h" #include "VPlanDominatorTree.h" +#include "VPlanPatternMatch.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" @@ -26,8 +27,6 @@ using namespace llvm; -using namespace llvm::PatternMatch; - void VPlanTransforms::VPInstructionsToVPRecipes( VPlanPtr &Plan, function_ref<const InductionDescriptor *(PHINode *)> @@ -35,8 +34,11 @@ void VPlanTransforms::VPInstructionsToVPRecipes( ScalarEvolution &SE, const TargetLibraryInfo &TLI) { ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT( - Plan->getEntry()); + Plan->getVectorLoopRegion()); for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) { + // Skip blocks outside region + if (!VPBB->getParent()) + break; VPRecipeBase *Term = VPBB->getTerminator(); auto EndIter = Term ? Term->getIterator() : VPBB->end(); // Introduce each ingredient into VPlan. 
@@ -49,34 +51,35 @@ void VPlanTransforms::VPInstructionsToVPRecipes( VPRecipeBase *NewRecipe = nullptr; if (auto *VPPhi = dyn_cast<VPWidenPHIRecipe>(&Ingredient)) { auto *Phi = cast<PHINode>(VPPhi->getUnderlyingValue()); - if (const auto *II = GetIntOrFpInductionDescriptor(Phi)) { - VPValue *Start = Plan->getVPValueOrAddLiveIn(II->getStartValue()); - VPValue *Step = - vputils::getOrCreateVPValueForSCEVExpr(*Plan, II->getStep(), SE); - NewRecipe = new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, *II); - } else { - Plan->addVPValue(Phi, VPPhi); + const auto *II = GetIntOrFpInductionDescriptor(Phi); + if (!II) continue; - } + + VPValue *Start = Plan->getOrAddLiveIn(II->getStartValue()); + VPValue *Step = + vputils::getOrCreateVPValueForSCEVExpr(*Plan, II->getStep(), SE); + NewRecipe = new VPWidenIntOrFpInductionRecipe(Phi, Start, Step, *II); } else { assert(isa<VPInstruction>(&Ingredient) && "only VPInstructions expected here"); assert(!isa<PHINode>(Inst) && "phis should be handled above"); - // Create VPWidenMemoryInstructionRecipe for loads and stores. + // Create VPWidenMemoryRecipe for loads and stores. if (LoadInst *Load = dyn_cast<LoadInst>(Inst)) { - NewRecipe = new VPWidenMemoryInstructionRecipe( + NewRecipe = new VPWidenLoadRecipe( *Load, Ingredient.getOperand(0), nullptr /*Mask*/, - false /*Consecutive*/, false /*Reverse*/); + false /*Consecutive*/, false /*Reverse*/, + Ingredient.getDebugLoc()); } else if (StoreInst *Store = dyn_cast<StoreInst>(Inst)) { - NewRecipe = new VPWidenMemoryInstructionRecipe( + NewRecipe = new VPWidenStoreRecipe( *Store, Ingredient.getOperand(1), Ingredient.getOperand(0), - nullptr /*Mask*/, false /*Consecutive*/, false /*Reverse*/); + nullptr /*Mask*/, false /*Consecutive*/, false /*Reverse*/, + Ingredient.getDebugLoc()); } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Inst)) { NewRecipe = new VPWidenGEPRecipe(GEP, Ingredient.operands()); } else if (CallInst *CI = dyn_cast<CallInst>(Inst)) { NewRecipe = new VPWidenCallRecipe( - *CI, drop_end(Ingredient.operands()), - getVectorIntrinsicIDForCall(CI, &TLI), CI->getDebugLoc()); + CI, Ingredient.operands(), getVectorIntrinsicIDForCall(CI, &TLI), + CI->getDebugLoc()); } else if (SelectInst *SI = dyn_cast<SelectInst>(Inst)) { NewRecipe = new VPWidenSelectRecipe(*SI, Ingredient.operands()); } else if (auto *CI = dyn_cast<CastInst>(Inst)) { @@ -157,8 +160,7 @@ static bool sinkScalarOperands(VPlan &Plan) { if (NeedsDuplicating) { if (ScalarVFOnly) continue; - Instruction *I = cast<Instruction>( - cast<VPReplicateRecipe>(SinkCandidate)->getUnderlyingValue()); + Instruction *I = SinkCandidate->getUnderlyingInstr(); auto *Clone = new VPReplicateRecipe(I, SinkCandidate->operands(), true); // TODO: add ".cloned" suffix to name of Clone's VPValue. @@ -276,6 +278,11 @@ static bool mergeReplicateRegionsIntoSuccessors(VPlan &Plan) { return UI && UI->getParent() == Then2; }); + // Remove phi recipes that are unused after merging the regions. + if (Phi1ToMove.getVPSingleValue()->getNumUsers() == 0) { + Phi1ToMove.eraseFromParent(); + continue; + } Phi1ToMove.moveBefore(*Merge2, Merge2->begin()); } @@ -357,25 +364,22 @@ static void addReplicateRegions(VPlan &Plan) { } } -void VPlanTransforms::createAndOptimizeReplicateRegions(VPlan &Plan) { - // Convert masked VPReplicateRecipes to if-then region blocks. 
- addReplicateRegions(Plan); - - bool ShouldSimplify = true; - while (ShouldSimplify) { - ShouldSimplify = sinkScalarOperands(Plan); - ShouldSimplify |= mergeReplicateRegionsIntoSuccessors(Plan); - ShouldSimplify |= VPlanTransforms::mergeBlocksIntoPredecessors(Plan); - } -} -bool VPlanTransforms::mergeBlocksIntoPredecessors(VPlan &Plan) { +/// Remove redundant VPBasicBlocks by merging them into their predecessor if +/// the predecessor has a single successor. +static bool mergeBlocksIntoPredecessors(VPlan &Plan) { SmallVector<VPBasicBlock *> WorkList; for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( vp_depth_first_deep(Plan.getEntry()))) { + // Don't fold the exit block of the Plan into its single predecessor for + // now. + // TODO: Remove restriction once more of the skeleton is modeled in VPlan. + if (VPBB->getNumSuccessors() == 0 && !VPBB->getParent()) + continue; auto *PredVPBB = dyn_cast_or_null<VPBasicBlock>(VPBB->getSinglePredecessor()); - if (PredVPBB && PredVPBB->getNumSuccessors() == 1) - WorkList.push_back(VPBB); + if (!PredVPBB || PredVPBB->getNumSuccessors() != 1) + continue; + WorkList.push_back(VPBB); } for (VPBasicBlock *VPBB : WorkList) { @@ -395,7 +399,25 @@ bool VPlanTransforms::mergeBlocksIntoPredecessors(VPlan &Plan) { return !WorkList.empty(); } -void VPlanTransforms::removeRedundantInductionCasts(VPlan &Plan) { +void VPlanTransforms::createAndOptimizeReplicateRegions(VPlan &Plan) { + // Convert masked VPReplicateRecipes to if-then region blocks. + addReplicateRegions(Plan); + + bool ShouldSimplify = true; + while (ShouldSimplify) { + ShouldSimplify = sinkScalarOperands(Plan); + ShouldSimplify |= mergeReplicateRegionsIntoSuccessors(Plan); + ShouldSimplify |= mergeBlocksIntoPredecessors(Plan); + } +} + +/// Remove redundant casts of inductions. +/// +/// Such redundant casts are casts of induction variables that can be ignored, +/// because we already proved that the casted phi is equal to the uncasted phi +/// in the vectorized loop. There is no need to vectorize the cast - the same +/// value can be used for both the phi and casts in the vector loop. +static void removeRedundantInductionCasts(VPlan &Plan) { for (auto &Phi : Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis()) { auto *IV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&Phi); if (!IV || IV->getTruncInst()) @@ -426,7 +448,9 @@ void VPlanTransforms::removeRedundantInductionCasts(VPlan &Plan) { } } -void VPlanTransforms::removeRedundantCanonicalIVs(VPlan &Plan) { +/// Try to replace VPWidenCanonicalIVRecipes with a widened canonical IV +/// recipe, if it exists. +static void removeRedundantCanonicalIVs(VPlan &Plan) { VPCanonicalIVPHIRecipe *CanonicalIV = Plan.getCanonicalIV(); VPWidenCanonicalIVRecipe *WidenNewIV = nullptr; for (VPUser *U : CanonicalIV->users()) { @@ -442,8 +466,7 @@ void VPlanTransforms::removeRedundantCanonicalIVs(VPlan &Plan) { for (VPRecipeBase &Phi : HeaderVPBB->phis()) { auto *WidenOriginalIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&Phi); - if (!WidenOriginalIV || !WidenOriginalIV->isCanonical() || - WidenOriginalIV->getScalarType() != WidenNewIV->getScalarType()) + if (!WidenOriginalIV || !WidenOriginalIV->isCanonical()) continue; // Replace WidenNewIV with WidenOriginalIV if WidenOriginalIV provides @@ -462,7 +485,27 @@ void VPlanTransforms::removeRedundantCanonicalIVs(VPlan &Plan) { } } -void VPlanTransforms::removeDeadRecipes(VPlan &Plan) { +/// Returns true if \p R is dead and can be removed. 
+static bool isDeadRecipe(VPRecipeBase &R) { + using namespace llvm::PatternMatch; + // Do remove conditional assume instructions as their conditions may be + // flattened. + auto *RepR = dyn_cast<VPReplicateRecipe>(&R); + bool IsConditionalAssume = + RepR && RepR->isPredicated() && + match(RepR->getUnderlyingInstr(), m_Intrinsic<Intrinsic::assume>()); + if (IsConditionalAssume) + return true; + + if (R.mayHaveSideEffects()) + return false; + + // Recipe is dead if no user keeps the recipe alive. + return all_of(R.definedValues(), + [](VPValue *V) { return V->getNumUsers() == 0; }); +} + +static void removeDeadRecipes(VPlan &Plan) { ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT( Plan.getEntry()); @@ -470,50 +513,99 @@ void VPlanTransforms::removeDeadRecipes(VPlan &Plan) { // The recipes in the block are processed in reverse order, to catch chains // of dead recipes. for (VPRecipeBase &R : make_early_inc_range(reverse(*VPBB))) { - // A user keeps R alive: - if (any_of(R.definedValues(), - [](VPValue *V) { return V->getNumUsers(); })) - continue; - - // Having side effects keeps R alive, but do remove conditional assume - // instructions as their conditions may be flattened. - auto *RepR = dyn_cast<VPReplicateRecipe>(&R); - bool IsConditionalAssume = - RepR && RepR->isPredicated() && - match(RepR->getUnderlyingInstr(), m_Intrinsic<Intrinsic::assume>()); - if (R.mayHaveSideEffects() && !IsConditionalAssume) - continue; - - R.eraseFromParent(); + if (isDeadRecipe(R)) + R.eraseFromParent(); } } } -static VPValue *createScalarIVSteps(VPlan &Plan, const InductionDescriptor &ID, - ScalarEvolution &SE, Instruction *TruncI, - Type *IVTy, VPValue *StartV, - VPValue *Step) { +static VPScalarIVStepsRecipe * +createScalarIVSteps(VPlan &Plan, InductionDescriptor::InductionKind Kind, + Instruction::BinaryOps InductionOpcode, + FPMathOperator *FPBinOp, ScalarEvolution &SE, + Instruction *TruncI, VPValue *StartV, VPValue *Step, + VPBasicBlock::iterator IP) { VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); - auto IP = HeaderVPBB->getFirstNonPhi(); VPCanonicalIVPHIRecipe *CanonicalIV = Plan.getCanonicalIV(); - Type *TruncTy = TruncI ? TruncI->getType() : IVTy; - VPValue *BaseIV = CanonicalIV; - if (!CanonicalIV->isCanonical(ID.getKind(), StartV, Step, TruncTy)) { - BaseIV = new VPDerivedIVRecipe(ID, StartV, CanonicalIV, Step, - TruncI ? TruncI->getType() : nullptr); - HeaderVPBB->insert(BaseIV->getDefiningRecipe(), IP); + VPSingleDefRecipe *BaseIV = CanonicalIV; + if (!CanonicalIV->isCanonical(Kind, StartV, Step)) { + BaseIV = new VPDerivedIVRecipe(Kind, FPBinOp, StartV, CanonicalIV, Step); + HeaderVPBB->insert(BaseIV, IP); } - VPScalarIVStepsRecipe *Steps = new VPScalarIVStepsRecipe(ID, BaseIV, Step); + // Truncate base induction if needed. + VPTypeAnalysis TypeInfo(Plan.getCanonicalIV()->getScalarType(), + SE.getContext()); + Type *ResultTy = TypeInfo.inferScalarType(BaseIV); + if (TruncI) { + Type *TruncTy = TruncI->getType(); + assert(ResultTy->getScalarSizeInBits() > TruncTy->getScalarSizeInBits() && + "Not truncating."); + assert(ResultTy->isIntegerTy() && "Truncation requires an integer type"); + BaseIV = new VPScalarCastRecipe(Instruction::Trunc, BaseIV, TruncTy); + HeaderVPBB->insert(BaseIV, IP); + ResultTy = TruncTy; + } + + // Truncate step if needed. 
+ Type *StepTy = TypeInfo.inferScalarType(Step); + if (ResultTy != StepTy) { + assert(StepTy->getScalarSizeInBits() > ResultTy->getScalarSizeInBits() && + "Not truncating."); + assert(StepTy->isIntegerTy() && "Truncation requires an integer type"); + Step = new VPScalarCastRecipe(Instruction::Trunc, Step, ResultTy); + auto *VecPreheader = + cast<VPBasicBlock>(HeaderVPBB->getSingleHierarchicalPredecessor()); + VecPreheader->appendRecipe(Step->getDefiningRecipe()); + } + + VPScalarIVStepsRecipe *Steps = new VPScalarIVStepsRecipe( + BaseIV, Step, InductionOpcode, + FPBinOp ? FPBinOp->getFastMathFlags() : FastMathFlags()); HeaderVPBB->insert(Steps, IP); return Steps; } -void VPlanTransforms::optimizeInductions(VPlan &Plan, ScalarEvolution &SE) { +/// Legalize VPWidenPointerInductionRecipe, by replacing it with a PtrAdd +/// (IndStart, ScalarIVSteps (0, Step)) if only its scalar values are used, as +/// VPWidenPointerInductionRecipe will generate vectors only. If some users +/// require vectors while other require scalars, the scalar uses need to extract +/// the scalars from the generated vectors (Note that this is different to how +/// int/fp inductions are handled). Also optimize VPWidenIntOrFpInductionRecipe, +/// if any of its users needs scalar values, by providing them scalar steps +/// built on the canonical scalar IV and update the original IV's users. This is +/// an optional optimization to reduce the needs of vector extracts. +static void legalizeAndOptimizeInductions(VPlan &Plan, ScalarEvolution &SE) { SmallVector<VPRecipeBase *> ToRemove; VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); bool HasOnlyVectorVFs = !Plan.hasVF(ElementCount::getFixed(1)); + VPBasicBlock::iterator InsertPt = HeaderVPBB->getFirstNonPhi(); for (VPRecipeBase &Phi : HeaderVPBB->phis()) { + // Replace wide pointer inductions which have only their scalars used by + // PtrAdd(IndStart, ScalarIVSteps (0, Step)). + if (auto *PtrIV = dyn_cast<VPWidenPointerInductionRecipe>(&Phi)) { + if (!PtrIV->onlyScalarsGenerated(Plan.hasScalableVF())) + continue; + + const InductionDescriptor &ID = PtrIV->getInductionDescriptor(); + VPValue *StartV = + Plan.getOrAddLiveIn(ConstantInt::get(ID.getStep()->getType(), 0)); + VPValue *StepV = PtrIV->getOperand(1); + VPScalarIVStepsRecipe *Steps = createScalarIVSteps( + Plan, InductionDescriptor::IK_IntInduction, Instruction::Add, nullptr, + SE, nullptr, StartV, StepV, InsertPt); + + auto *Recipe = new VPInstruction(VPInstruction::PtrAdd, + {PtrIV->getStartValue(), Steps}, + PtrIV->getDebugLoc(), "next.gep"); + + Recipe->insertAfter(Steps); + PtrIV->replaceAllUsesWith(Recipe); + continue; + } + + // Replace widened induction with scalar steps for users that only use + // scalars. auto *WideIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&Phi); if (!WideIV) continue; @@ -523,9 +615,11 @@ void VPlanTransforms::optimizeInductions(VPlan &Plan, ScalarEvolution &SE) { continue; const InductionDescriptor &ID = WideIV->getInductionDescriptor(); - VPValue *Steps = createScalarIVSteps( - Plan, ID, SE, WideIV->getTruncInst(), WideIV->getPHINode()->getType(), - WideIV->getStartValue(), WideIV->getStepValue()); + VPScalarIVStepsRecipe *Steps = createScalarIVSteps( + Plan, ID.getKind(), ID.getInductionOpcode(), + dyn_cast_or_null<FPMathOperator>(ID.getInductionBinOp()), SE, + WideIV->getTruncInst(), WideIV->getStartValue(), WideIV->getStepValue(), + InsertPt); // Update scalar users of IV to use Step instead. 
if (!HasOnlyVectorVFs) @@ -537,7 +631,9 @@ void VPlanTransforms::optimizeInductions(VPlan &Plan, ScalarEvolution &SE) { } } -void VPlanTransforms::removeRedundantExpandSCEVRecipes(VPlan &Plan) { +/// Remove redundant EpxandSCEVRecipes in \p Plan's entry block by replacing +/// them with already existing recipes expanding the same SCEV expression. +static void removeRedundantExpandSCEVRecipes(VPlan &Plan) { DenseMap<const SCEV *, VPValue *> SCEV2VPV; for (VPRecipeBase &R : @@ -554,13 +650,23 @@ void VPlanTransforms::removeRedundantExpandSCEVRecipes(VPlan &Plan) { } } -static bool canSimplifyBranchOnCond(VPInstruction *Term) { - VPInstruction *Not = dyn_cast<VPInstruction>(Term->getOperand(0)); - if (!Not || Not->getOpcode() != VPInstruction::Not) - return false; +static void recursivelyDeleteDeadRecipes(VPValue *V) { + SmallVector<VPValue *> WorkList; + SmallPtrSet<VPValue *, 8> Seen; + WorkList.push_back(V); - VPInstruction *ALM = dyn_cast<VPInstruction>(Not->getOperand(0)); - return ALM && ALM->getOpcode() == VPInstruction::ActiveLaneMask; + while (!WorkList.empty()) { + VPValue *Cur = WorkList.pop_back_val(); + if (!Seen.insert(Cur).second) + continue; + VPRecipeBase *R = Cur->getDefiningRecipe(); + if (!R) + continue; + if (!isDeadRecipe(*R)) + continue; + WorkList.append(R->op_begin(), R->op_end()); + R->eraseFromParent(); + } } void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF, @@ -570,32 +676,37 @@ void VPlanTransforms::optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF, assert(Plan.hasUF(BestUF) && "BestUF is not available in Plan"); VPBasicBlock *ExitingVPBB = Plan.getVectorLoopRegion()->getExitingBasicBlock(); - auto *Term = dyn_cast<VPInstruction>(&ExitingVPBB->back()); + auto *Term = &ExitingVPBB->back(); // Try to simplify the branch condition if TC <= VF * UF when preparing to // execute the plan for the main vector loop. We only do this if the // terminator is: // 1. BranchOnCount, or // 2. BranchOnCond where the input is Not(ActiveLaneMask). 
- if (!Term || (Term->getOpcode() != VPInstruction::BranchOnCount && - (Term->getOpcode() != VPInstruction::BranchOnCond || - !canSimplifyBranchOnCond(Term)))) + using namespace llvm::VPlanPatternMatch; + if (!match(Term, m_BranchOnCount(m_VPValue(), m_VPValue())) && + !match(Term, + m_BranchOnCond(m_Not(m_ActiveLaneMask(m_VPValue(), m_VPValue()))))) return; Type *IdxTy = Plan.getCanonicalIV()->getStartValue()->getLiveInIRValue()->getType(); const SCEV *TripCount = createTripCountSCEV(IdxTy, PSE); ScalarEvolution &SE = *PSE.getSE(); - const SCEV *C = - SE.getConstant(TripCount->getType(), BestVF.getKnownMinValue() * BestUF); + ElementCount NumElements = BestVF.multiplyCoefficientBy(BestUF); + const SCEV *C = SE.getElementCount(TripCount->getType(), NumElements); if (TripCount->isZero() || !SE.isKnownPredicate(CmpInst::ICMP_ULE, TripCount, C)) return; LLVMContext &Ctx = SE.getContext(); - auto *BOC = new VPInstruction( - VPInstruction::BranchOnCond, - {Plan.getVPValueOrAddLiveIn(ConstantInt::getTrue(Ctx))}); + auto *BOC = + new VPInstruction(VPInstruction::BranchOnCond, + {Plan.getOrAddLiveIn(ConstantInt::getTrue(Ctx))}); + + SmallVector<VPValue *> PossiblyDead(Term->operands()); Term->eraseFromParent(); + for (VPValue *Op : PossiblyDead) + recursivelyDeleteDeadRecipes(Op); ExitingVPBB->appendRecipe(BOC); Plan.setVF(BestVF); Plan.setUF(BestUF); @@ -705,7 +816,7 @@ sinkRecurrenceUsersAfterPrevious(VPFirstOrderRecurrencePHIRecipe *FOR, } bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan, - VPBuilder &Builder) { + VPBuilder &LoopBuilder) { VPDominatorTree VPDT; VPDT.recalculate(Plan); @@ -715,6 +826,20 @@ bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan, if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R)) RecurrencePhis.push_back(FOR); + VPBasicBlock *MiddleVPBB = + cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSingleSuccessor()); + VPBuilder MiddleBuilder; + // Set insert point so new recipes are inserted before terminator and + // condition, if there is either the former or both. + if (auto *Term = + dyn_cast_or_null<VPInstruction>(MiddleVPBB->getTerminator())) { + if (auto *Cmp = dyn_cast<VPInstruction>(Term->getOperand(0))) + MiddleBuilder.setInsertPoint(Cmp); + else + MiddleBuilder.setInsertPoint(Term); + } else + MiddleBuilder.setInsertPoint(MiddleVPBB); + for (VPFirstOrderRecurrencePHIRecipe *FOR : RecurrencePhis) { SmallPtrSet<VPFirstOrderRecurrencePHIRecipe *, 4> SeenPhis; VPRecipeBase *Previous = FOR->getBackedgeValue()->getDefiningRecipe(); @@ -734,22 +859,115 @@ bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan, // fixed-order recurrence. VPBasicBlock *InsertBlock = Previous->getParent(); if (isa<VPHeaderPHIRecipe>(Previous)) - Builder.setInsertPoint(InsertBlock, InsertBlock->getFirstNonPhi()); + LoopBuilder.setInsertPoint(InsertBlock, InsertBlock->getFirstNonPhi()); else - Builder.setInsertPoint(InsertBlock, std::next(Previous->getIterator())); + LoopBuilder.setInsertPoint(InsertBlock, + std::next(Previous->getIterator())); auto *RecurSplice = cast<VPInstruction>( - Builder.createNaryOp(VPInstruction::FirstOrderRecurrenceSplice, - {FOR, FOR->getBackedgeValue()})); + LoopBuilder.createNaryOp(VPInstruction::FirstOrderRecurrenceSplice, + {FOR, FOR->getBackedgeValue()})); FOR->replaceAllUsesWith(RecurSplice); // Set the first operand of RecurSplice to FOR again, after replacing // all users. RecurSplice->setOperand(0, FOR); + + // This is the second phase of vectorizing first-order recurrences. 
An + // overview of the transformation is described below. Suppose we have the + // following loop with some use after the loop of the last a[i-1], + // + // for (int i = 0; i < n; ++i) { + // t = a[i - 1]; + // b[i] = a[i] - t; + // } + // use t; + // + // There is a first-order recurrence on "a". For this loop, the shorthand + // scalar IR looks like: + // + // scalar.ph: + // s_init = a[-1] + // br scalar.body + // + // scalar.body: + // i = phi [0, scalar.ph], [i+1, scalar.body] + // s1 = phi [s_init, scalar.ph], [s2, scalar.body] + // s2 = a[i] + // b[i] = s2 - s1 + // br cond, scalar.body, exit.block + // + // exit.block: + // use = lcssa.phi [s1, scalar.body] + // + // In this example, s1 is a recurrence because it's value depends on the + // previous iteration. In the first phase of vectorization, we created a + // vector phi v1 for s1. We now complete the vectorization and produce the + // shorthand vector IR shown below (for VF = 4, UF = 1). + // + // vector.ph: + // v_init = vector(..., ..., ..., a[-1]) + // br vector.body + // + // vector.body + // i = phi [0, vector.ph], [i+4, vector.body] + // v1 = phi [v_init, vector.ph], [v2, vector.body] + // v2 = a[i, i+1, i+2, i+3]; + // v3 = vector(v1(3), v2(0, 1, 2)) + // b[i, i+1, i+2, i+3] = v2 - v3 + // br cond, vector.body, middle.block + // + // middle.block: + // s_penultimate = v2(2) = v3(3) + // s_resume = v2(3) + // br cond, scalar.ph, exit.block + // + // scalar.ph: + // s_init' = phi [s_resume, middle.block], [s_init, otherwise] + // br scalar.body + // + // scalar.body: + // i = phi [0, scalar.ph], [i+1, scalar.body] + // s1 = phi [s_init', scalar.ph], [s2, scalar.body] + // s2 = a[i] + // b[i] = s2 - s1 + // br cond, scalar.body, exit.block + // + // exit.block: + // lo = lcssa.phi [s1, scalar.body], [s.penultimate, middle.block] + // + // After execution completes the vector loop, we extract the next value of + // the recurrence (x) to use as the initial value in the scalar loop. This + // is modeled by ExtractFromEnd. + Type *IntTy = Plan.getCanonicalIV()->getScalarType(); + + // Extract the penultimate value of the recurrence and update VPLiveOut + // users of the recurrence splice. Note that the extract of the final value + // used to resume in the scalar loop is created earlier during VPlan + // construction. 
+ auto *Penultimate = cast<VPInstruction>(MiddleBuilder.createNaryOp( + VPInstruction::ExtractFromEnd, + {FOR->getBackedgeValue(), + Plan.getOrAddLiveIn(ConstantInt::get(IntTy, 2))}, + {}, "vector.recur.extract.for.phi")); + RecurSplice->replaceUsesWithIf( + Penultimate, [](VPUser &U, unsigned) { return isa<VPLiveOut>(&U); }); } return true; } +static SmallVector<VPUser *> collectUsersRecursively(VPValue *V) { + SetVector<VPUser *> Users(V->user_begin(), V->user_end()); + for (unsigned I = 0; I != Users.size(); ++I) { + VPRecipeBase *Cur = dyn_cast<VPRecipeBase>(Users[I]); + if (!Cur || isa<VPHeaderPHIRecipe>(Cur)) + continue; + for (VPValue *V : Cur->definedValues()) + Users.insert(V->user_begin(), V->user_end()); + } + return Users.takeVector(); +} + void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) { for (VPRecipeBase &R : Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis()) { @@ -761,68 +979,30 @@ void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) { if (RK != RecurKind::Add && RK != RecurKind::Mul) continue; - SmallSetVector<VPValue *, 8> Worklist; - Worklist.insert(PhiR); - - for (unsigned I = 0; I != Worklist.size(); ++I) { - VPValue *Cur = Worklist[I]; - if (auto *RecWithFlags = - dyn_cast<VPRecipeWithIRFlags>(Cur->getDefiningRecipe())) { + for (VPUser *U : collectUsersRecursively(PhiR)) + if (auto *RecWithFlags = dyn_cast<VPRecipeWithIRFlags>(U)) { RecWithFlags->dropPoisonGeneratingFlags(); } - - for (VPUser *U : Cur->users()) { - auto *UserRecipe = dyn_cast<VPRecipeBase>(U); - if (!UserRecipe) - continue; - for (VPValue *V : UserRecipe->definedValues()) - Worklist.insert(V); - } - } } } -/// Returns true is \p V is constant one. -static bool isConstantOne(VPValue *V) { - if (!V->isLiveIn()) - return false; - auto *C = dyn_cast<ConstantInt>(V->getLiveInIRValue()); - return C && C->isOne(); -} - -/// Returns the llvm::Instruction opcode for \p R. -static unsigned getOpcodeForRecipe(VPRecipeBase &R) { - if (auto *WidenR = dyn_cast<VPWidenRecipe>(&R)) - return WidenR->getUnderlyingInstr()->getOpcode(); - if (auto *WidenC = dyn_cast<VPWidenCastRecipe>(&R)) - return WidenC->getOpcode(); - if (auto *RepR = dyn_cast<VPReplicateRecipe>(&R)) - return RepR->getUnderlyingInstr()->getOpcode(); - if (auto *VPI = dyn_cast<VPInstruction>(&R)) - return VPI->getOpcode(); - return 0; -} - /// Try to simplify recipe \p R. static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) { - switch (getOpcodeForRecipe(R)) { - case Instruction::Mul: { - VPValue *A = R.getOperand(0); - VPValue *B = R.getOperand(1); - if (isConstantOne(A)) - return R.getVPSingleValue()->replaceAllUsesWith(B); - if (isConstantOne(B)) - return R.getVPSingleValue()->replaceAllUsesWith(A); - break; + using namespace llvm::VPlanPatternMatch; + // Try to remove redundant blend recipes. 
+ if (auto *Blend = dyn_cast<VPBlendRecipe>(&R)) { + VPValue *Inc0 = Blend->getIncomingValue(0); + for (unsigned I = 1; I != Blend->getNumIncomingValues(); ++I) + if (Inc0 != Blend->getIncomingValue(I) && + !match(Blend->getMask(I), m_False())) + return; + Blend->replaceAllUsesWith(Inc0); + Blend->eraseFromParent(); + return; } - case Instruction::Trunc: { - VPRecipeBase *Ext = R.getOperand(0)->getDefiningRecipe(); - if (!Ext) - break; - unsigned ExtOpcode = getOpcodeForRecipe(*Ext); - if (ExtOpcode != Instruction::ZExt && ExtOpcode != Instruction::SExt) - break; - VPValue *A = Ext->getOperand(0); + + VPValue *A; + if (match(&R, m_Trunc(m_ZExtOrSExt(m_VPValue(A))))) { VPValue *Trunc = R.getVPSingleValue(); Type *TruncTy = TypeInfo.inferScalarType(Trunc); Type *ATy = TypeInfo.inferScalarType(A); @@ -831,10 +1011,18 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) { } else { // Don't replace a scalarizing recipe with a widened cast. if (isa<VPReplicateRecipe>(&R)) - break; + return; if (ATy->getScalarSizeInBits() < TruncTy->getScalarSizeInBits()) { + + unsigned ExtOpcode = match(R.getOperand(0), m_SExt(m_VPValue())) + ? Instruction::SExt + : Instruction::ZExt; auto *VPC = new VPWidenCastRecipe(Instruction::CastOps(ExtOpcode), A, TruncTy); + if (auto *UnderlyingExt = R.getOperand(0)->getUnderlyingValue()) { + // UnderlyingExt has distinct return type, used to retain legacy cost. + VPC->setUnderlyingValue(UnderlyingExt); + } VPC->insertBefore(&R); Trunc->replaceAllUsesWith(VPC); } else if (ATy->getScalarSizeInBits() > TruncTy->getScalarSizeInBits()) { @@ -846,7 +1034,9 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) { #ifndef NDEBUG // Verify that the cached type info is for both A and its users is still // accurate by comparing it to freshly computed types. - VPTypeAnalysis TypeInfo2(TypeInfo.getContext()); + VPTypeAnalysis TypeInfo2( + R.getParent()->getPlan()->getCanonicalIV()->getScalarType(), + TypeInfo.getContext()); assert(TypeInfo.inferScalarType(A) == TypeInfo2.inferScalarType(A)); for (VPUser *U : A->users()) { auto *R = dyn_cast<VPRecipeBase>(U); @@ -856,18 +1046,30 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) { assert(TypeInfo.inferScalarType(VPV) == TypeInfo2.inferScalarType(VPV)); } #endif - break; } - default: - break; + + // Simplify (X && Y) || (X && !Y) -> X. + // TODO: Split up into simpler, modular combines: (X && Y) || (X && Z) into X + // && (Y || Z) and (X || !X) into true. This requires queuing newly created + // recipes to be visited during simplification. + VPValue *X, *Y, *X1, *Y1; + if (match(&R, + m_c_BinaryOr(m_LogicalAnd(m_VPValue(X), m_VPValue(Y)), + m_LogicalAnd(m_VPValue(X1), m_Not(m_VPValue(Y1))))) && + X == X1 && Y == Y1) { + R.getVPSingleValue()->replaceAllUsesWith(X); + return; } + + if (match(&R, m_c_Mul(m_VPValue(A), m_SpecificInt(1)))) + return R.getVPSingleValue()->replaceAllUsesWith(A); } /// Try to simplify the recipes in \p Plan. 
static void simplifyRecipes(VPlan &Plan, LLVMContext &Ctx) { ReversePostOrderTraversal<VPBlockDeepTraversalWrapper<VPBlockBase *>> RPOT( Plan.getEntry()); - VPTypeAnalysis TypeInfo(Ctx); + VPTypeAnalysis TypeInfo(Plan.getCanonicalIV()->getScalarType(), Ctx); for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(RPOT)) { for (VPRecipeBase &R : make_early_inc_range(*VPBB)) { simplifyRecipe(R, TypeInfo); @@ -888,16 +1090,13 @@ void VPlanTransforms::truncateToMinimalBitwidths( // other uses have different types for their operands, making them invalidly // typed. DenseMap<VPValue *, VPWidenCastRecipe *> ProcessedTruncs; - VPTypeAnalysis TypeInfo(Ctx); + VPTypeAnalysis TypeInfo(Plan.getCanonicalIV()->getScalarType(), Ctx); VPBasicBlock *PH = Plan.getEntry(); for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>( vp_depth_first_deep(Plan.getVectorLoopRegion()))) { for (VPRecipeBase &R : make_early_inc_range(*VPBB)) { if (!isa<VPWidenRecipe, VPWidenCastRecipe, VPReplicateRecipe, - VPWidenSelectRecipe, VPWidenMemoryInstructionRecipe>(&R)) - continue; - if (isa<VPWidenMemoryInstructionRecipe>(&R) && - cast<VPWidenMemoryInstructionRecipe>(&R)->isStore()) + VPWidenSelectRecipe, VPWidenLoadRecipe>(&R)) continue; VPValue *ResultVPV = R.getVPSingleValue(); @@ -943,9 +1142,6 @@ void VPlanTransforms::truncateToMinimalBitwidths( Type *OldResTy = TypeInfo.inferScalarType(ResultVPV); unsigned OldResSizeInBits = OldResTy->getScalarSizeInBits(); assert(OldResTy->isIntegerTy() && "only integer types supported"); - if (OldResSizeInBits == NewResSizeInBits) - continue; - assert(OldResSizeInBits > NewResSizeInBits && "Nothing to shrink?"); (void)OldResSizeInBits; auto *NewResTy = IntegerType::get(Ctx, NewResSizeInBits); @@ -956,16 +1152,24 @@ void VPlanTransforms::truncateToMinimalBitwidths( if (auto *VPW = dyn_cast<VPRecipeWithIRFlags>(&R)) VPW->dropPoisonGeneratingFlags(); - // Extend result to original width. - auto *Ext = new VPWidenCastRecipe(Instruction::ZExt, ResultVPV, OldResTy); - Ext->insertAfter(&R); - ResultVPV->replaceAllUsesWith(Ext); - Ext->setOperand(0, ResultVPV); - - if (isa<VPWidenMemoryInstructionRecipe>(&R)) { - assert(!cast<VPWidenMemoryInstructionRecipe>(&R)->isStore() && "stores cannot be narrowed"); + using namespace llvm::VPlanPatternMatch; + if (OldResSizeInBits != NewResSizeInBits && + !match(&R, m_Binary<Instruction::ICmp>(m_VPValue(), m_VPValue()))) { + // Extend result to original width. + auto *Ext = + new VPWidenCastRecipe(Instruction::ZExt, ResultVPV, OldResTy); + Ext->insertAfter(&R); + ResultVPV->replaceAllUsesWith(Ext); + Ext->setOperand(0, ResultVPV); + assert(OldResSizeInBits > NewResSizeInBits && "Nothing to shrink?"); + } else + assert( + match(&R, m_Binary<Instruction::ICmp>(m_VPValue(), m_VPValue())) && + "Only ICmps should not need extending the result."); + + assert(!isa<VPWidenStoreRecipe>(&R) && "stores cannot be narrowed"); + if (isa<VPWidenLoadRecipe>(&R)) continue; - } // Shrink operands by introducing truncates as needed. unsigned StartIdx = isa<VPWidenSelectRecipe>(&R) ? 
1 : 0; @@ -1009,8 +1213,8 @@ void VPlanTransforms::optimize(VPlan &Plan, ScalarEvolution &SE) { removeRedundantCanonicalIVs(Plan); removeRedundantInductionCasts(Plan); - optimizeInductions(Plan, SE); simplifyRecipes(Plan, SE.getContext()); + legalizeAndOptimizeInductions(Plan, SE); removeDeadRecipes(Plan); createAndOptimizeReplicateRegions(Plan); @@ -1123,6 +1327,51 @@ static VPActiveLaneMaskPHIRecipe *addVPLaneMaskPhiAndUpdateExitBranch( return LaneMaskPhi; } +/// Collect all VPValues representing a header mask through the (ICMP_ULE, +/// WideCanonicalIV, backedge-taken-count) pattern. +/// TODO: Introduce explicit recipe for header-mask instead of searching +/// for the header-mask pattern manually. +static SmallVector<VPValue *> collectAllHeaderMasks(VPlan &Plan) { + SmallVector<VPValue *> WideCanonicalIVs; + auto *FoundWidenCanonicalIVUser = + find_if(Plan.getCanonicalIV()->users(), + [](VPUser *U) { return isa<VPWidenCanonicalIVRecipe>(U); }); + assert(count_if(Plan.getCanonicalIV()->users(), + [](VPUser *U) { return isa<VPWidenCanonicalIVRecipe>(U); }) <= + 1 && + "Must have at most one VPWideCanonicalIVRecipe"); + if (FoundWidenCanonicalIVUser != Plan.getCanonicalIV()->users().end()) { + auto *WideCanonicalIV = + cast<VPWidenCanonicalIVRecipe>(*FoundWidenCanonicalIVUser); + WideCanonicalIVs.push_back(WideCanonicalIV); + } + + // Also include VPWidenIntOrFpInductionRecipes that represent a widened + // version of the canonical induction. + VPBasicBlock *HeaderVPBB = Plan.getVectorLoopRegion()->getEntryBasicBlock(); + for (VPRecipeBase &Phi : HeaderVPBB->phis()) { + auto *WidenOriginalIV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&Phi); + if (WidenOriginalIV && WidenOriginalIV->isCanonical()) + WideCanonicalIVs.push_back(WidenOriginalIV); + } + + // Walk users of wide canonical IVs and collect to all compares of the form + // (ICMP_ULE, WideCanonicalIV, backedge-taken-count). + SmallVector<VPValue *> HeaderMasks; + for (auto *Wide : WideCanonicalIVs) { + for (VPUser *U : SmallVector<VPUser *>(Wide->users())) { + auto *HeaderMask = dyn_cast<VPInstruction>(U); + if (!HeaderMask || !vputils::isHeaderMask(HeaderMask, Plan)) + continue; + + assert(HeaderMask->getOperand(0) == Wide && + "WidenCanonicalIV must be the first operand of the compare"); + HeaderMasks.push_back(HeaderMask); + } + } + return HeaderMasks; +} + void VPlanTransforms::addActiveLaneMask( VPlan &Plan, bool UseActiveLaneMaskForControlFlow, bool DataAndControlFlowWithoutRuntimeCheck) { @@ -1143,27 +1392,233 @@ void VPlanTransforms::addActiveLaneMask( LaneMask = addVPLaneMaskPhiAndUpdateExitBranch( Plan, DataAndControlFlowWithoutRuntimeCheck); } else { - LaneMask = new VPInstruction(VPInstruction::ActiveLaneMask, - {WideCanonicalIV, Plan.getTripCount()}, - nullptr, "active.lane.mask"); - LaneMask->insertAfter(WideCanonicalIV); + VPBuilder B = VPBuilder::getToInsertAfter(WideCanonicalIV); + LaneMask = B.createNaryOp(VPInstruction::ActiveLaneMask, + {WideCanonicalIV, Plan.getTripCount()}, nullptr, + "active.lane.mask"); } // Walk users of WideCanonicalIV and replace all compares of the form // (ICMP_ULE, WideCanonicalIV, backedge-taken-count) with an // active-lane-mask. 
- VPValue *BTC = Plan.getOrCreateBackedgeTakenCount(); - for (VPUser *U : SmallVector<VPUser *>(WideCanonicalIV->users())) { - auto *CompareToReplace = dyn_cast<VPInstruction>(U); - if (!CompareToReplace || - CompareToReplace->getOpcode() != Instruction::ICmp || - CompareToReplace->getPredicate() != CmpInst::ICMP_ULE || - CompareToReplace->getOperand(1) != BTC) - continue; + for (VPValue *HeaderMask : collectAllHeaderMasks(Plan)) + HeaderMask->replaceAllUsesWith(LaneMask); +} - assert(CompareToReplace->getOperand(0) == WideCanonicalIV && - "WidenCanonicalIV must be the first operand of the compare"); - CompareToReplace->replaceAllUsesWith(LaneMask); - CompareToReplace->eraseFromParent(); +/// Add a VPEVLBasedIVPHIRecipe and related recipes to \p Plan and +/// replaces all uses except the canonical IV increment of +/// VPCanonicalIVPHIRecipe with a VPEVLBasedIVPHIRecipe. VPCanonicalIVPHIRecipe +/// is used only for loop iterations counting after this transformation. +/// +/// The function uses the following definitions: +/// %StartV is the canonical induction start value. +/// +/// The function adds the following recipes: +/// +/// vector.ph: +/// ... +/// +/// vector.body: +/// ... +/// %EVLPhi = EXPLICIT-VECTOR-LENGTH-BASED-IV-PHI [ %StartV, %vector.ph ], +/// [ %NextEVLIV, %vector.body ] +/// %VPEVL = EXPLICIT-VECTOR-LENGTH %EVLPhi, original TC +/// ... +/// %NextEVLIV = add IVSize (cast i32 %VPEVVL to IVSize), %EVLPhi +/// ... +/// +bool VPlanTransforms::tryAddExplicitVectorLength(VPlan &Plan) { + VPBasicBlock *Header = Plan.getVectorLoopRegion()->getEntryBasicBlock(); + // The transform updates all users of inductions to work based on EVL, instead + // of the VF directly. At the moment, widened inductions cannot be updated, so + // bail out if the plan contains any. + bool ContainsWidenInductions = any_of(Header->phis(), [](VPRecipeBase &Phi) { + return isa<VPWidenIntOrFpInductionRecipe, VPWidenPointerInductionRecipe>( + &Phi); + }); + // FIXME: Remove this once we can transform (select header_mask, true_value, + // false_value) into vp.merge. + bool ContainsOutloopReductions = + any_of(Header->phis(), [&](VPRecipeBase &Phi) { + auto *R = dyn_cast<VPReductionPHIRecipe>(&Phi); + return R && !R->isInLoop(); + }); + if (ContainsWidenInductions || ContainsOutloopReductions) + return false; + + auto *CanonicalIVPHI = Plan.getCanonicalIV(); + VPValue *StartV = CanonicalIVPHI->getStartValue(); + + // Create the ExplicitVectorLengthPhi recipe in the main loop. + auto *EVLPhi = new VPEVLBasedIVPHIRecipe(StartV, DebugLoc()); + EVLPhi->insertAfter(CanonicalIVPHI); + auto *VPEVL = new VPInstruction(VPInstruction::ExplicitVectorLength, + {EVLPhi, Plan.getTripCount()}); + VPEVL->insertBefore(*Header, Header->getFirstNonPhi()); + + auto *CanonicalIVIncrement = + cast<VPInstruction>(CanonicalIVPHI->getBackedgeValue()); + VPSingleDefRecipe *OpVPEVL = VPEVL; + if (unsigned IVSize = CanonicalIVPHI->getScalarType()->getScalarSizeInBits(); + IVSize != 32) { + OpVPEVL = new VPScalarCastRecipe(IVSize < 32 ? 
Instruction::Trunc + : Instruction::ZExt, + OpVPEVL, CanonicalIVPHI->getScalarType()); + OpVPEVL->insertBefore(CanonicalIVIncrement); + } + auto *NextEVLIV = + new VPInstruction(Instruction::Add, {OpVPEVL, EVLPhi}, + {CanonicalIVIncrement->hasNoUnsignedWrap(), + CanonicalIVIncrement->hasNoSignedWrap()}, + CanonicalIVIncrement->getDebugLoc(), "index.evl.next"); + NextEVLIV->insertBefore(CanonicalIVIncrement); + EVLPhi->addOperand(NextEVLIV); + + for (VPValue *HeaderMask : collectAllHeaderMasks(Plan)) { + for (VPUser *U : collectUsersRecursively(HeaderMask)) { + VPRecipeBase *NewRecipe = nullptr; + auto *CurRecipe = dyn_cast<VPRecipeBase>(U); + if (!CurRecipe) + continue; + + auto GetNewMask = [&](VPValue *OrigMask) -> VPValue * { + assert(OrigMask && "Unmasked recipe when folding tail"); + return HeaderMask == OrigMask ? nullptr : OrigMask; + }; + if (auto *MemR = dyn_cast<VPWidenMemoryRecipe>(CurRecipe)) { + VPValue *NewMask = GetNewMask(MemR->getMask()); + if (auto *L = dyn_cast<VPWidenLoadRecipe>(MemR)) + NewRecipe = new VPWidenLoadEVLRecipe(L, VPEVL, NewMask); + else if (auto *S = dyn_cast<VPWidenStoreRecipe>(MemR)) + NewRecipe = new VPWidenStoreEVLRecipe(S, VPEVL, NewMask); + else + llvm_unreachable("unsupported recipe"); + } else if (auto *RedR = dyn_cast<VPReductionRecipe>(CurRecipe)) { + NewRecipe = new VPReductionEVLRecipe(RedR, VPEVL, + GetNewMask(RedR->getCondOp())); + } + + if (NewRecipe) { + [[maybe_unused]] unsigned NumDefVal = NewRecipe->getNumDefinedValues(); + assert(NumDefVal == CurRecipe->getNumDefinedValues() && + "New recipe must define the same number of values as the " + "original."); + assert( + NumDefVal <= 1 && + "Only supports recipes with a single definition or without users."); + NewRecipe->insertBefore(CurRecipe); + if (isa<VPSingleDefRecipe, VPWidenLoadEVLRecipe>(NewRecipe)) { + VPValue *CurVPV = CurRecipe->getVPSingleValue(); + CurVPV->replaceAllUsesWith(NewRecipe->getVPSingleValue()); + } + CurRecipe->eraseFromParent(); + } + } + recursivelyDeleteDeadRecipes(HeaderMask); + } + // Replace all uses of VPCanonicalIVPHIRecipe by + // VPEVLBasedIVPHIRecipe except for the canonical IV increment. + CanonicalIVPHI->replaceAllUsesWith(EVLPhi); + CanonicalIVIncrement->setOperand(0, CanonicalIVPHI); + // TODO: support unroll factor > 1. + Plan.setUF(1); + return true; +} + +void VPlanTransforms::dropPoisonGeneratingRecipes( + VPlan &Plan, function_ref<bool(BasicBlock *)> BlockNeedsPredication) { + // Collect recipes in the backward slice of `Root` that may generate a poison + // value that is used after vectorization. + SmallPtrSet<VPRecipeBase *, 16> Visited; + auto collectPoisonGeneratingInstrsInBackwardSlice([&](VPRecipeBase *Root) { + SmallVector<VPRecipeBase *, 16> Worklist; + Worklist.push_back(Root); + + // Traverse the backward slice of Root through its use-def chain. + while (!Worklist.empty()) { + VPRecipeBase *CurRec = Worklist.back(); + Worklist.pop_back(); + + if (!Visited.insert(CurRec).second) + continue; + + // Prune search if we find another recipe generating a widen memory + // instruction. Widen memory instructions involved in address computation + // will lead to gather/scatter instructions, which don't need to be + // handled. + if (isa<VPWidenMemoryRecipe>(CurRec) || isa<VPInterleaveRecipe>(CurRec) || + isa<VPScalarIVStepsRecipe>(CurRec) || isa<VPHeaderPHIRecipe>(CurRec)) + continue; + + // This recipe contributes to the address computation of a widen + // load/store. 
If the underlying instruction has poison-generating flags, + // drop them directly. + if (auto *RecWithFlags = dyn_cast<VPRecipeWithIRFlags>(CurRec)) { + VPValue *A, *B; + using namespace llvm::VPlanPatternMatch; + // Dropping disjoint from an OR may yield incorrect results, as some + // analysis may have converted it to an Add implicitly (e.g. SCEV used + // for dependence analysis). Instead, replace it with an equivalent Add. + // This is possible as all users of the disjoint OR only access lanes + // where the operands are disjoint or poison otherwise. + if (match(RecWithFlags, m_BinaryOr(m_VPValue(A), m_VPValue(B))) && + RecWithFlags->isDisjoint()) { + VPBuilder Builder(RecWithFlags); + VPInstruction *New = Builder.createOverflowingOp( + Instruction::Add, {A, B}, {false, false}, + RecWithFlags->getDebugLoc()); + New->setUnderlyingValue(RecWithFlags->getUnderlyingValue()); + RecWithFlags->replaceAllUsesWith(New); + RecWithFlags->eraseFromParent(); + CurRec = New; + } else + RecWithFlags->dropPoisonGeneratingFlags(); + } else { + Instruction *Instr = dyn_cast_or_null<Instruction>( + CurRec->getVPSingleValue()->getUnderlyingValue()); + (void)Instr; + assert((!Instr || !Instr->hasPoisonGeneratingFlags()) && + "found instruction with poison generating flags not covered by " + "VPRecipeWithIRFlags"); + } + + // Add new definitions to the worklist. + for (VPValue *operand : CurRec->operands()) + if (VPRecipeBase *OpDef = operand->getDefiningRecipe()) + Worklist.push_back(OpDef); + } + }); + + // Traverse all the recipes in the VPlan and collect the poison-generating + // recipes in the backward slice starting at the address of a VPWidenRecipe or + // VPInterleaveRecipe. + auto Iter = vp_depth_first_deep(Plan.getEntry()); + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(Iter)) { + for (VPRecipeBase &Recipe : *VPBB) { + if (auto *WidenRec = dyn_cast<VPWidenMemoryRecipe>(&Recipe)) { + Instruction &UnderlyingInstr = WidenRec->getIngredient(); + VPRecipeBase *AddrDef = WidenRec->getAddr()->getDefiningRecipe(); + if (AddrDef && WidenRec->isConsecutive() && + BlockNeedsPredication(UnderlyingInstr.getParent())) + collectPoisonGeneratingInstrsInBackwardSlice(AddrDef); + } else if (auto *InterleaveRec = dyn_cast<VPInterleaveRecipe>(&Recipe)) { + VPRecipeBase *AddrDef = InterleaveRec->getAddr()->getDefiningRecipe(); + if (AddrDef) { + // Check if any member of the interleave group needs predication. + const InterleaveGroup<Instruction> *InterGroup = + InterleaveRec->getInterleaveGroup(); + bool NeedPredication = false; + for (int I = 0, NumMembers = InterGroup->getNumMembers(); + I < NumMembers; ++I) { + Instruction *Member = InterGroup->getMember(I); + if (Member) + NeedPredication |= BlockNeedsPredication(Member->getParent()); + } + + if (NeedPredication) + collectPoisonGeneratingInstrsInBackwardSlice(AddrDef); + } + } + } } } diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.h index 3bf91115debb..96b8a6639723 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanTransforms.h @@ -84,35 +84,28 @@ struct VPlanTransforms { const MapVector<Instruction *, uint64_t> &MinBWs, LLVMContext &Ctx); -private: - /// Remove redundant VPBasicBlocks by merging them into their predecessor if - /// the predecessor has a single successor. 
- static bool mergeBlocksIntoPredecessors(VPlan &Plan); - - /// Remove redundant casts of inductions. - /// - /// Such redundant casts are casts of induction variables that can be ignored, - /// because we already proved that the casted phi is equal to the uncasted phi - /// in the vectorized loop. There is no need to vectorize the cast - the same - /// value can be used for both the phi and casts in the vector loop. - static void removeRedundantInductionCasts(VPlan &Plan); - - /// Try to replace VPWidenCanonicalIVRecipes with a widened canonical IV - /// recipe, if it exists. - static void removeRedundantCanonicalIVs(VPlan &Plan); - - static void removeDeadRecipes(VPlan &Plan); - - /// If any user of a VPWidenIntOrFpInductionRecipe needs scalar values, - /// provide them by building scalar steps off of the canonical scalar IV and - /// update the original IV's users. This is an optional optimization to reduce - /// the needs of vector extracts. - static void optimizeInductions(VPlan &Plan, ScalarEvolution &SE); - - /// Remove redundant EpxandSCEVRecipes in \p Plan's entry block by replacing - /// them with already existing recipes expanding the same SCEV expression. - static void removeRedundantExpandSCEVRecipes(VPlan &Plan); - + /// Drop poison flags from recipes that may generate a poison value that is + /// used after vectorization, even when their operands are not poison. Those + /// recipes meet the following conditions: + /// * Contribute to the address computation of a recipe generating a widen + /// memory load/store (VPWidenMemoryInstructionRecipe or + /// VPInterleaveRecipe). + /// * Such a widen memory load/store has at least one underlying Instruction + /// that is in a basic block that needs predication and after vectorization + /// the generated instruction won't be predicated. + /// Uses \p BlockNeedsPredication to check if a block needs predicating. + /// TODO: Replace BlockNeedsPredication callback with retrieving info from + /// VPlan directly. + static void dropPoisonGeneratingRecipes( + VPlan &Plan, function_ref<bool(BasicBlock *)> BlockNeedsPredication); + + /// Add a VPEVLBasedIVPHIRecipe and related recipes to \p Plan and + /// replaces all uses except the canonical IV increment of + /// VPCanonicalIVPHIRecipe with a VPEVLBasedIVPHIRecipe. + /// VPCanonicalIVPHIRecipe is only used to control the loop after + /// this transformation. + /// \returns true if the transformation succeeds, or false if it doesn't. + static bool tryAddExplicitVectorLength(VPlan &Plan); }; } // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanValue.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanValue.h index 8cc98f4abf93..452c977106a7 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanValue.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanValue.h @@ -23,6 +23,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/TinyPtrVector.h" #include "llvm/ADT/iterator_range.h" @@ -35,7 +36,6 @@ class VPDef; class VPSlotTracker; class VPUser; class VPRecipeBase; -class VPWidenMemoryInstructionRecipe; // This is the base class of the VPlan Def/Use graph, used for modeling the data // flow into, within and out of the VPlan. 
VPValues can stand for live-ins @@ -50,7 +50,6 @@ class VPValue { friend class VPInterleavedAccessInfo; friend class VPSlotTracker; friend class VPRecipeBase; - friend class VPWidenMemoryInstructionRecipe; const unsigned char SubclassID; ///< Subclass identifier (for isa/dyn_cast). @@ -73,16 +72,9 @@ protected: // for multiple underlying IRs (Polly?) by providing a new VPlan front-end, // back-end and analysis information for the new IR. - // Set \p Val as the underlying Value of this VPValue. - void setUnderlyingValue(Value *Val) { - assert(!UnderlyingVal && "Underlying Value is already set."); - UnderlyingVal = Val; - } - public: /// Return the underlying Value attached to this VPValue. - Value *getUnderlyingValue() { return UnderlyingVal; } - const Value *getUnderlyingValue() const { return UnderlyingVal; } + Value *getUnderlyingValue() const { return UnderlyingVal; } /// An enumeration for keeping track of the concrete subclass of VPValue that /// are actually instantiated. @@ -192,6 +184,12 @@ public: /// is a live-in value. /// TODO: Also handle recipes defined in pre-header blocks. bool isDefinedOutsideVectorRegions() const { return !hasDefiningRecipe(); } + + // Set \p Val as the underlying Value of this VPValue. + void setUnderlyingValue(Value *Val) { + assert(!UnderlyingVal && "Underlying Value is already set."); + UnderlyingVal = Val; + } }; typedef DenseMap<Value *, VPValue *> Value2VPValueTy; @@ -262,11 +260,6 @@ public: New->addUser(*this); } - void removeLastOperand() { - VPValue *Op = Operands.pop_back_val(); - Op->removeUser(*this); - } - typedef SmallVectorImpl<VPValue *>::iterator operand_iterator; typedef SmallVectorImpl<VPValue *>::const_iterator const_operand_iterator; typedef iterator_range<operand_iterator> operand_range; @@ -348,32 +341,38 @@ public: VPExpandSCEVSC, VPInstructionSC, VPInterleaveSC, + VPReductionEVLSC, VPReductionSC, VPReplicateSC, + VPScalarCastSC, VPScalarIVStepsSC, VPVectorPointerSC, VPWidenCallSC, VPWidenCanonicalIVSC, VPWidenCastSC, VPWidenGEPSC, - VPWidenMemoryInstructionSC, + VPWidenLoadEVLSC, + VPWidenLoadSC, + VPWidenStoreEVLSC, + VPWidenStoreSC, VPWidenSC, VPWidenSelectSC, - // START: Phi-like recipes. Need to be kept together. VPBlendSC, + // START: Phi-like recipes. Need to be kept together. + VPWidenPHISC, VPPredInstPHISC, // START: SubclassID for recipes that inherit VPHeaderPHIRecipe. // VPHeaderPHIRecipe need to be kept together. VPCanonicalIVPHISC, VPActiveLaneMaskPHISC, + VPEVLBasedIVPHISC, VPFirstOrderRecurrencePHISC, - VPWidenPHISC, VPWidenIntOrFpInductionSC, VPWidenPointerInductionSC, VPReductionPHISC, // END: SubclassID for recipes that inherit VPHeaderPHIRecipe // END: Phi-like recipes - VPFirstPHISC = VPBlendSC, + VPFirstPHISC = VPWidenPHISC, VPFirstHeaderPHISC = VPCanonicalIVPHISC, VPLastHeaderPHISC = VPReductionPHISC, VPLastPHISC = VPReductionPHISC, @@ -441,29 +440,36 @@ public: class VPlan; class VPBasicBlock; -/// This class can be used to assign consecutive numbers to all VPValues in a -/// VPlan and allows querying the numbering for printing, similar to the +/// This class can be used to assign names to VPValues. For VPValues without +/// underlying value, assign consecutive numbers and use those as names (wrapped +/// in vp<>). Otherwise, use the name from the underlying value (wrapped in +/// ir<>), appending a .V version number if there are multiple uses of the same +/// name. Allows querying names for VPValues for printing, similar to the /// ModuleSlotTracker for IR values. 
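A minimal usage sketch of the renamed interface declared below (illustrative only, not part of the patch; Plan and SomeVPV are hypothetical):

  VPSlotTracker SlotTracker(&Plan);
  // Either a numbered vp<> name or a name derived from the underlying IR
  // value, versioned when several VPValues share the same base name.
  std::string Name = SlotTracker.getOrCreateName(SomeVPV);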
class VPSlotTracker { - DenseMap<const VPValue *, unsigned> Slots; + /// Keep track of versioned names assigned to VPValues with underlying IR + /// values. + DenseMap<const VPValue *, std::string> VPValue2Name; + /// Keep track of the next number to use to version the base name. + StringMap<unsigned> BaseName2Version; + + /// Number to assign to the next VPValue without underlying value. unsigned NextSlot = 0; - void assignSlot(const VPValue *V); - void assignSlots(const VPlan &Plan); - void assignSlots(const VPBasicBlock *VPBB); + void assignName(const VPValue *V); + void assignNames(const VPlan &Plan); + void assignNames(const VPBasicBlock *VPBB); public: VPSlotTracker(const VPlan *Plan = nullptr) { if (Plan) - assignSlots(*Plan); + assignNames(*Plan); } - unsigned getSlot(const VPValue *V) const { - auto I = Slots.find(V); - if (I == Slots.end()) - return -1; - return I->second; - } + /// Returns the name assigned to \p V, if there is one, otherwise try to + /// construct one from the underlying value, if there's one; else return + /// <badref>. + std::string getOrCreateName(const VPValue *V) const; }; } // namespace llvm diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp index d6b81543dbc9..765dc983cab4 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp @@ -17,126 +17,50 @@ #include "VPlanCFG.h" #include "VPlanDominatorTree.h" #include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/CommandLine.h" #define DEBUG_TYPE "loop-vectorize" using namespace llvm; -static cl::opt<bool> EnableHCFGVerifier("vplan-verify-hcfg", cl::init(false), - cl::Hidden, - cl::desc("Verify VPlan H-CFG.")); +namespace { +class VPlanVerifier { + const VPDominatorTree &VPDT; -#ifndef NDEBUG -/// Utility function that checks whether \p VPBlockVec has duplicate -/// VPBlockBases. -static bool hasDuplicates(const SmallVectorImpl<VPBlockBase *> &VPBlockVec) { - SmallDenseSet<const VPBlockBase *, 8> VPBlockSet; - for (const auto *Block : VPBlockVec) { - if (VPBlockSet.count(Block)) - return true; - VPBlockSet.insert(Block); - } - return false; -} -#endif + SmallPtrSet<BasicBlock *, 8> WrappedIRBBs; -/// Helper function that verifies the CFG invariants of the VPBlockBases within -/// \p Region. Checks in this function are generic for VPBlockBases. They are -/// not specific for VPBasicBlocks or VPRegionBlocks. -static void verifyBlocksInRegion(const VPRegionBlock *Region) { - for (const VPBlockBase *VPB : vp_depth_first_shallow(Region->getEntry())) { - // Check block's parent. - assert(VPB->getParent() == Region && "VPBlockBase has wrong parent"); - - auto *VPBB = dyn_cast<VPBasicBlock>(VPB); - // Check block's condition bit. - if (VPB->getNumSuccessors() > 1 || (VPBB && VPBB->isExiting())) - assert(VPBB && VPBB->getTerminator() && - "Block has multiple successors but doesn't " - "have a proper branch recipe!"); - else - assert((!VPBB || !VPBB->getTerminator()) && "Unexpected branch recipe!"); - - // Check block's successors. - const auto &Successors = VPB->getSuccessors(); - // There must be only one instance of a successor in block's successor list. - // TODO: This won't work for switch statements. 
- assert(!hasDuplicates(Successors) && - "Multiple instances of the same successor."); - - for (const VPBlockBase *Succ : Successors) { - // There must be a bi-directional link between block and successor. - const auto &SuccPreds = Succ->getPredecessors(); - assert(llvm::is_contained(SuccPreds, VPB) && "Missing predecessor link."); - (void)SuccPreds; - } + // Verify that phi-like recipes are at the beginning of \p VPBB, with no + // other recipes in between. Also check that only header blocks contain + // VPHeaderPHIRecipes. + bool verifyPhiRecipes(const VPBasicBlock *VPBB); - // Check block's predecessors. - const auto &Predecessors = VPB->getPredecessors(); - // There must be only one instance of a predecessor in block's predecessor - // list. - // TODO: This won't work for switch statements. - assert(!hasDuplicates(Predecessors) && - "Multiple instances of the same predecessor."); - - for (const VPBlockBase *Pred : Predecessors) { - // Block and predecessor must be inside the same region. - assert(Pred->getParent() == VPB->getParent() && - "Predecessor is not in the same region."); - - // There must be a bi-directional link between block and predecessor. - const auto &PredSuccs = Pred->getSuccessors(); - assert(llvm::is_contained(PredSuccs, VPB) && "Missing successor link."); - (void)PredSuccs; - } - } -} + bool verifyVPBasicBlock(const VPBasicBlock *VPBB); -/// Verify the CFG invariants of VPRegionBlock \p Region and its nested -/// VPBlockBases. Do not recurse inside nested VPRegionBlocks. -static void verifyRegion(const VPRegionBlock *Region) { - const VPBlockBase *Entry = Region->getEntry(); - const VPBlockBase *Exiting = Region->getExiting(); + bool verifyBlock(const VPBlockBase *VPB); - // Entry and Exiting shouldn't have any predecessor/successor, respectively. - assert(!Entry->getNumPredecessors() && "Region entry has predecessors."); - assert(!Exiting->getNumSuccessors() && - "Region exiting block has successors."); - (void)Entry; - (void)Exiting; + /// Helper function that verifies the CFG invariants of the VPBlockBases + /// within + /// \p Region. Checks in this function are generic for VPBlockBases. They are + /// not specific for VPBasicBlocks or VPRegionBlocks. + bool verifyBlocksInRegion(const VPRegionBlock *Region); - verifyBlocksInRegion(Region); -} + /// Verify the CFG invariants of VPRegionBlock \p Region and its nested + /// VPBlockBases. Do not recurse inside nested VPRegionBlocks. + bool verifyRegion(const VPRegionBlock *Region); -/// Verify the CFG invariants of VPRegionBlock \p Region and its nested -/// VPBlockBases. Recurse inside nested VPRegionBlocks. -static void verifyRegionRec(const VPRegionBlock *Region) { - verifyRegion(Region); - - // Recurse inside nested regions. - for (const VPBlockBase *VPB : make_range( - df_iterator<const VPBlockBase *>::begin(Region->getEntry()), - df_iterator<const VPBlockBase *>::end(Region->getExiting()))) { - if (const auto *SubRegion = dyn_cast<VPRegionBlock>(VPB)) - verifyRegionRec(SubRegion); - } -} + /// Verify the CFG invariants of VPRegionBlock \p Region and its nested + /// VPBlockBases. Recurse inside nested VPRegionBlocks. 
+ bool verifyRegionRec(const VPRegionBlock *Region); -void VPlanVerifier::verifyHierarchicalCFG( - const VPRegionBlock *TopRegion) const { - if (!EnableHCFGVerifier) - return; +public: + VPlanVerifier(VPDominatorTree &VPDT) : VPDT(VPDT) {} - LLVM_DEBUG(dbgs() << "Verifying VPlan H-CFG.\n"); - assert(!TopRegion->getParent() && "VPlan Top Region should have no parent."); - verifyRegionRec(TopRegion); -} + bool verify(const VPlan &Plan); +}; +} // namespace -// Verify that phi-like recipes are at the beginning of \p VPBB, with no -// other recipes in between. Also check that only header blocks contain -// VPHeaderPHIRecipes. -static bool verifyPhiRecipes(const VPBasicBlock *VPBB) { +bool VPlanVerifier::verifyPhiRecipes(const VPBasicBlock *VPBB) { auto RecipeI = VPBB->begin(); auto End = VPBB->end(); unsigned NumActiveLaneMaskPhiRecipes = 0; @@ -147,7 +71,7 @@ static bool verifyPhiRecipes(const VPBasicBlock *VPBB) { if (isa<VPActiveLaneMaskPHIRecipe>(RecipeI)) NumActiveLaneMaskPhiRecipes++; - if (IsHeaderVPBB && !isa<VPHeaderPHIRecipe>(*RecipeI)) { + if (IsHeaderVPBB && !isa<VPHeaderPHIRecipe, VPWidenPHIRecipe>(*RecipeI)) { errs() << "Found non-header PHI recipe in header VPBB"; #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) errs() << ": "; @@ -190,8 +114,7 @@ static bool verifyPhiRecipes(const VPBasicBlock *VPBB) { return true; } -static bool verifyVPBasicBlock(const VPBasicBlock *VPBB, - VPDominatorTree &VPDT) { +bool VPlanVerifier::verifyVPBasicBlock(const VPBasicBlock *VPBB) { if (!verifyPhiRecipes(VPBB)) return false; @@ -207,7 +130,8 @@ static bool verifyVPBasicBlock(const VPBasicBlock *VPBB, for (const VPUser *U : V->users()) { auto *UI = dyn_cast<VPRecipeBase>(U); // TODO: check dominance of incoming values for phis properly. - if (!UI || isa<VPHeaderPHIRecipe>(UI) || isa<VPPredInstPHIRecipe>(UI)) + if (!UI || + isa<VPHeaderPHIRecipe, VPWidenPHIRecipe, VPPredInstPHIRecipe>(UI)) continue; // If the user is in the same block, check it comes after R in the @@ -227,21 +151,157 @@ static bool verifyVPBasicBlock(const VPBasicBlock *VPBB, } } } + + auto *IRBB = dyn_cast<VPIRBasicBlock>(VPBB); + if (!IRBB) + return true; + + if (!WrappedIRBBs.insert(IRBB->getIRBasicBlock()).second) { + errs() << "Same IR basic block used by multiple wrapper blocks!\n"; + return false; + } + + VPBlockBase *MiddleBB = + IRBB->getPlan()->getVectorLoopRegion()->getSingleSuccessor(); + if (IRBB != IRBB->getPlan()->getPreheader() && + IRBB->getSinglePredecessor() != MiddleBB) { + errs() << "VPIRBasicBlock can only be used as pre-header or a successor of " + "middle-block at the moment!\n"; + return false; + } return true; } -bool VPlanVerifier::verifyPlanIsValid(const VPlan &Plan) { - VPDominatorTree VPDT; - VPDT.recalculate(const_cast<VPlan &>(Plan)); +/// Utility function that checks whether \p VPBlockVec has duplicate +/// VPBlockBases. +static bool hasDuplicates(const SmallVectorImpl<VPBlockBase *> &VPBlockVec) { + SmallDenseSet<const VPBlockBase *, 8> VPBlockSet; + for (const auto *Block : VPBlockVec) { + if (VPBlockSet.count(Block)) + return true; + VPBlockSet.insert(Block); + } + return false; +} + +bool VPlanVerifier::verifyBlock(const VPBlockBase *VPB) { + auto *VPBB = dyn_cast<VPBasicBlock>(VPB); + // Check block's condition bit. 
+ if (VPB->getNumSuccessors() > 1 || + (VPBB && VPBB->getParent() && VPBB->isExiting() && + !VPBB->getParent()->isReplicator())) { + if (!VPBB || !VPBB->getTerminator()) { + errs() << "Block has multiple successors but doesn't " + "have a proper branch recipe!\n"; + return false; + } + } else { + if (VPBB && VPBB->getTerminator()) { + errs() << "Unexpected branch recipe!\n"; + return false; + } + } + + // Check block's successors. + const auto &Successors = VPB->getSuccessors(); + // There must be only one instance of a successor in block's successor list. + // TODO: This won't work for switch statements. + if (hasDuplicates(Successors)) { + errs() << "Multiple instances of the same successor.\n"; + return false; + } + + for (const VPBlockBase *Succ : Successors) { + // There must be a bi-directional link between block and successor. + const auto &SuccPreds = Succ->getPredecessors(); + if (!is_contained(SuccPreds, VPB)) { + errs() << "Missing predecessor link.\n"; + return false; + } + } + + // Check block's predecessors. + const auto &Predecessors = VPB->getPredecessors(); + // There must be only one instance of a predecessor in block's predecessor + // list. + // TODO: This won't work for switch statements. + if (hasDuplicates(Predecessors)) { + errs() << "Multiple instances of the same predecessor.\n"; + return false; + } + + for (const VPBlockBase *Pred : Predecessors) { + // Block and predecessor must be inside the same region. + if (Pred->getParent() != VPB->getParent()) { + errs() << "Predecessor is not in the same region.\n"; + return false; + } - auto Iter = vp_depth_first_deep(Plan.getEntry()); - for (const VPBasicBlock *VPBB : - VPBlockUtils::blocksOnly<const VPBasicBlock>(Iter)) { - if (!verifyVPBasicBlock(VPBB, VPDT)) + // There must be a bi-directional link between block and predecessor. + const auto &PredSuccs = Pred->getSuccessors(); + if (!is_contained(PredSuccs, VPB)) { + errs() << "Missing successor link.\n"; return false; + } } + return !VPBB || verifyVPBasicBlock(VPBB); +} + +bool VPlanVerifier::verifyBlocksInRegion(const VPRegionBlock *Region) { + for (const VPBlockBase *VPB : vp_depth_first_shallow(Region->getEntry())) { + // Check block's parent. + if (VPB->getParent() != Region) { + errs() << "VPBlockBase has wrong parent\n"; + return false; + } + + if (!verifyBlock(VPB)) + return false; + } + return true; +} + +bool VPlanVerifier::verifyRegion(const VPRegionBlock *Region) { + const VPBlockBase *Entry = Region->getEntry(); + const VPBlockBase *Exiting = Region->getExiting(); + + // Entry and Exiting shouldn't have any predecessor/successor, respectively. + if (Entry->getNumPredecessors() != 0) { + errs() << "region entry block has predecessors\n"; + return false; + } + if (Exiting->getNumSuccessors() != 0) { + errs() << "region exiting block has successors\n"; + return false; + } + + return verifyBlocksInRegion(Region); +} + +bool VPlanVerifier::verifyRegionRec(const VPRegionBlock *Region) { + // Recurse inside nested regions and check all blocks inside the region. 
+ return verifyRegion(Region) && + all_of(vp_depth_first_shallow(Region->getEntry()), + [this](const VPBlockBase *VPB) { + const auto *SubRegion = dyn_cast<VPRegionBlock>(VPB); + return !SubRegion || verifyRegionRec(SubRegion); + }); +} + +bool VPlanVerifier::verify(const VPlan &Plan) { + if (any_of(vp_depth_first_shallow(Plan.getEntry()), + [this](const VPBlockBase *VPB) { return !verifyBlock(VPB); })) + return false; const VPRegionBlock *TopRegion = Plan.getVectorLoopRegion(); + if (!verifyRegionRec(TopRegion)) + return false; + + if (TopRegion->getParent()) { + errs() << "VPlan Top Region should have no parent.\n"; + return false; + } + const VPBasicBlock *Entry = dyn_cast<VPBasicBlock>(TopRegion->getEntry()); if (!Entry) { errs() << "VPlan entry block is not a VPBasicBlock\n"; @@ -274,19 +334,6 @@ bool VPlanVerifier::verifyPlanIsValid(const VPlan &Plan) { return false; } - for (const VPRegionBlock *Region : - VPBlockUtils::blocksOnly<const VPRegionBlock>( - vp_depth_first_deep(Plan.getEntry()))) { - if (Region->getEntry()->getNumPredecessors() != 0) { - errs() << "region entry block has predecessors\n"; - return false; - } - if (Region->getExiting()->getNumSuccessors() != 0) { - errs() << "region exiting block has successors\n"; - return false; - } - } - for (const auto &KV : Plan.getLiveOuts()) if (KV.second->getNumOperands() != 1) { errs() << "live outs must have a single operand\n"; @@ -295,3 +342,10 @@ bool VPlanVerifier::verifyPlanIsValid(const VPlan &Plan) { return true; } + +bool llvm::verifyVPlanIsValid(const VPlan &Plan) { + VPDominatorTree VPDT; + VPDT.recalculate(const_cast<VPlan &>(Plan)); + VPlanVerifier Verifier(VPDT); + return Verifier.verify(Plan); +} diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.h b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.h index 839c24e2c9f4..3ddc49fda36b 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.h +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VPlanVerifier.h @@ -25,24 +25,16 @@ #define LLVM_TRANSFORMS_VECTORIZE_VPLANVERIFIER_H namespace llvm { -class VPRegionBlock; class VPlan; -/// Struct with utility functions that can be used to check the consistency and -/// invariants of a VPlan, including the components of its H-CFG. -struct VPlanVerifier { - /// Verify the invariants of the H-CFG starting from \p TopRegion. The - /// verification process comprises the following steps: - /// 1. Region/Block verification: Check the Region/Block verification - /// invariants for every region in the H-CFG. - void verifyHierarchicalCFG(const VPRegionBlock *TopRegion) const; +/// Verify invariants for general VPlans. Currently it checks the following: +/// 1. Region/Block verification: Check the Region/Block verification +/// invariants for every region in the H-CFG. +/// 2. all phi-like recipes must be at the beginning of a block, with no other +/// recipes in between. Note that currently there is still an exception for +/// VPBlendRecipes. +bool verifyVPlanIsValid(const VPlan &Plan); - /// Verify invariants for general VPlans. Currently it checks the following: - /// 1. all phi-like recipes must be at the beginning of a block, with no other - /// recipes in between. Note that currently there is still an exception for - /// VPBlendRecipes. 
- static bool verifyPlanIsValid(const VPlan &Plan); -}; } // namespace llvm #endif //LLVM_TRANSFORMS_VECTORIZE_VPLANVERIFIER_H diff --git a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index f18711ba30b7..679934d07e36 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/Vectorize/VectorCombine.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" @@ -29,6 +30,7 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" #include <numeric> #include <queue> @@ -65,8 +67,8 @@ class VectorCombine { public: VectorCombine(Function &F, const TargetTransformInfo &TTI, const DominatorTree &DT, AAResults &AA, AssumptionCache &AC, - bool TryEarlyFoldsOnly) - : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC), + const DataLayout *DL, bool TryEarlyFoldsOnly) + : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC), DL(DL), TryEarlyFoldsOnly(TryEarlyFoldsOnly) {} bool run(); @@ -78,6 +80,7 @@ private: const DominatorTree &DT; AAResults &AA; AssumptionCache &AC; + const DataLayout *DL; /// If true, only perform beneficial early IR transforms. Do not introduce new /// vector operations. @@ -110,7 +113,11 @@ private: bool foldSingleElementStore(Instruction &I); bool scalarizeLoadExtract(Instruction &I); bool foldShuffleOfBinops(Instruction &I); + bool foldShuffleOfCastops(Instruction &I); + bool foldShuffleOfShuffles(Instruction &I); + bool foldShuffleToIdentity(Instruction &I); bool foldShuffleFromReductions(Instruction &I); + bool foldCastFromReductions(Instruction &I); bool foldSelectShuffle(Instruction &I, bool FromReduction = false); void replaceValue(Value &Old, Value &New) { @@ -132,6 +139,14 @@ private: }; } // namespace +/// Return the source operand of a potentially bitcasted value. If there is no +/// bitcast, return the input value itself. +static Value *peekThroughBitcasts(Value *V) { + while (auto *BitCast = dyn_cast<BitCastInst>(V)) + V = BitCast->getOperand(0); + return V; +} + static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) { // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan. // The widened load may load data from dirty regions or create data races @@ -179,7 +194,6 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) { // We use minimal alignment (maximum flexibility) because we only care about // the dereferenceable region. When calculating cost and creating a new op, // we may use a larger value based on alignment attributes. 
- const DataLayout &DL = I.getModule()->getDataLayout(); Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts(); assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type"); @@ -187,15 +201,15 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) { auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false); unsigned OffsetEltIndex = 0; Align Alignment = Load->getAlign(); - if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &AC, + if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC, &DT)) { // It is not safe to load directly from the pointer, but we can still peek // through gep offsets and check if it safe to load from a base address with // updated alignment. If it is, we can shuffle the element(s) into place // after loading. - unsigned OffsetBitWidth = DL.getIndexTypeSizeInBits(SrcPtr->getType()); + unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(SrcPtr->getType()); APInt Offset(OffsetBitWidth, 0); - SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(DL, Offset); + SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(*DL, Offset); // We want to shuffle the result down from a high element of a vector, so // the offset must be positive. @@ -213,7 +227,7 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) { if (OffsetEltIndex >= MinVecNumElts) return false; - if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), DL, Load, &AC, + if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC, &DT)) return false; @@ -225,7 +239,7 @@ bool VectorCombine::vectorizeLoadInsert(Instruction &I) { // Original pattern: insertelt undef, load [free casts of] PtrOp, 0 // Use the greater of the alignment on the load or its source pointer. - Alignment = std::max(SrcPtr->getPointerAlignment(DL), Alignment); + Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment); Type *LoadTy = Load->getType(); unsigned AS = Load->getPointerAddressSpace(); InstructionCost OldCost = @@ -296,14 +310,13 @@ bool VectorCombine::widenSubvectorLoad(Instruction &I) { // the dereferenceable region. When calculating cost and creating a new op, // we may use a larger value based on alignment attributes. auto *Ty = cast<FixedVectorType>(I.getType()); - const DataLayout &DL = I.getModule()->getDataLayout(); Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts(); assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type"); Align Alignment = Load->getAlign(); - if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), DL, Load, &AC, &DT)) + if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), *DL, Load, &AC, &DT)) return false; - Alignment = std::max(SrcPtr->getPointerAlignment(DL), Alignment); + Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment); Type *LoadTy = Load->getType(); unsigned AS = Load->getPointerAddressSpace(); @@ -682,10 +695,10 @@ bool VectorCombine::foldInsExtFNeg(Instruction &I) { /// destination type followed by shuffle. This can enable further transforms by /// moving bitcasts or shuffles together. bool VectorCombine::foldBitcastShuffle(Instruction &I) { - Value *V; + Value *V0, *V1; ArrayRef<int> Mask; - if (!match(&I, m_BitCast( - m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask)))))) + if (!match(&I, m_BitCast(m_OneUse( + m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask)))))) return false; // 1) Do not fold bitcast shuffle for scalable type. 
First, shuffle cost for @@ -694,7 +707,7 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) { // 2) Disallow non-vector casts. // TODO: We could allow any shuffle. auto *DestTy = dyn_cast<FixedVectorType>(I.getType()); - auto *SrcTy = dyn_cast<FixedVectorType>(V->getType()); + auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType()); if (!DestTy || !SrcTy) return false; @@ -703,6 +716,18 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) { if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0) return false; + bool IsUnary = isa<UndefValue>(V1); + + // For binary shuffles, only fold bitcast(shuffle(X,Y)) + // if it won't increase the number of bitcasts. + if (!IsUnary) { + auto *BCTy0 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V0)->getType()); + auto *BCTy1 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V1)->getType()); + if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) && + !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType())) + return false; + } + SmallVector<int, 16> NewMask; if (DestEltSize <= SrcEltSize) { // The bitcast is from wide to narrow/equal elements. The shuffle mask can @@ -722,21 +747,36 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) { // Bitcast the shuffle src - keep its original width but using the destination // scalar type. unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize; - auto *ShuffleTy = FixedVectorType::get(DestTy->getScalarType(), NumSrcElts); - - // The new shuffle must not cost more than the old shuffle. The bitcast is - // moved ahead of the shuffle, so assume that it has the same cost as before. - InstructionCost DestCost = TTI.getShuffleCost( - TargetTransformInfo::SK_PermuteSingleSrc, ShuffleTy, NewMask); + auto *NewShuffleTy = + FixedVectorType::get(DestTy->getScalarType(), NumSrcElts); + auto *OldShuffleTy = + FixedVectorType::get(SrcTy->getScalarType(), Mask.size()); + unsigned NumOps = IsUnary ? 1 : 2; + + // The new shuffle must not cost more than the old shuffle. + TargetTransformInfo::TargetCostKind CK = + TargetTransformInfo::TCK_RecipThroughput; + TargetTransformInfo::ShuffleKind SK = + IsUnary ? 
TargetTransformInfo::SK_PermuteSingleSrc + : TargetTransformInfo::SK_PermuteTwoSrc; + + InstructionCost DestCost = + TTI.getShuffleCost(SK, NewShuffleTy, NewMask, CK) + + (NumOps * TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy, + TargetTransformInfo::CastContextHint::None, + CK)); InstructionCost SrcCost = - TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy, Mask); + TTI.getShuffleCost(SK, SrcTy, Mask, CK) + + TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy, + TargetTransformInfo::CastContextHint::None, CK); if (DestCost > SrcCost || !DestCost.isValid()) return false; - // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC' + // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC' ++NumShufOfBitcast; - Value *CastV = Builder.CreateBitCast(V, ShuffleTy); - Value *Shuf = Builder.CreateShuffleVector(CastV, NewMask); + Value *CastV0 = Builder.CreateBitCast(peekThroughBitcasts(V0), NewShuffleTy); + Value *CastV1 = Builder.CreateBitCast(peekThroughBitcasts(V1), NewShuffleTy); + Value *Shuf = Builder.CreateShuffleVector(CastV0, CastV1, NewMask); replaceValue(I, *Shuf); return true; } @@ -784,9 +824,12 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) { // intrinsic VectorType *VecTy = cast<VectorType>(VPI.getType()); TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + SmallVector<int> Mask; + if (auto *FVTy = dyn_cast<FixedVectorType>(VecTy)) + Mask.resize(FVTy->getNumElements(), 0); InstructionCost SplatCost = TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) + - TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy); + TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, Mask); // Calculate the cost of the VP Intrinsic SmallVector<Type *, 4> Args; @@ -833,7 +876,6 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) { // Scalarize the intrinsic ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount(); Value *EVL = VPI.getArgOperand(3); - const DataLayout &DL = VPI.getModule()->getDataLayout(); // If the VP op might introduce UB or poison, we can scalarize it provided // that we know the EVL > 0: If the EVL is zero, then the original VP op @@ -846,7 +888,8 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) { else SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode( *FunctionalOpcode, &VPI, nullptr, &AC, &DT); - if (!SafeToSpeculate && !isKnownNonZero(EVL, DL, 0, &AC, &VPI, &DT)) + if (!SafeToSpeculate && + !isKnownNonZero(EVL, SimplifyQuery(*DL, &DT, &AC, &VPI))) return false; Value *ScalarVal = @@ -1225,12 +1268,11 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) { if (auto *Load = dyn_cast<LoadInst>(Source)) { auto VecTy = cast<VectorType>(SI->getValueOperand()->getType()); - const DataLayout &DL = I.getModule()->getDataLayout(); Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts(); // Don't optimize for atomic/volatile load or store. Ensure memory is not // modified between, vector type matches store size, and index is inbounds. 
if (!Load->isSimple() || Load->getParent() != SI->getParent() || - !DL.typeSizeEqualsStoreSize(Load->getType()->getScalarType()) || + !DL->typeSizeEqualsStoreSize(Load->getType()->getScalarType()) || SrcAddr != SI->getPointerOperand()->stripPointerCasts()) return false; @@ -1249,7 +1291,7 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) { NSI->copyMetadata(*SI); Align ScalarOpAlignment = computeAlignmentAfterScalarization( std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx, - DL); + *DL); NSI->setAlignment(ScalarOpAlignment); replaceValue(I, *NSI); eraseInstruction(I); @@ -1267,8 +1309,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { auto *VecTy = cast<VectorType>(I.getType()); auto *LI = cast<LoadInst>(&I); - const DataLayout &DL = I.getModule()->getDataLayout(); - if (LI->isVolatile() || !DL.typeSizeEqualsStoreSize(VecTy->getScalarType())) + if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType())) return false; InstructionCost OriginalCost = @@ -1346,7 +1387,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { VecTy->getElementType(), GEP, EI->getName() + ".scalar")); Align ScalarOpAlignment = computeAlignmentAfterScalarization( - LI->getAlign(), VecTy->getElementType(), Idx, DL); + LI->getAlign(), VecTy->getElementType(), Idx, *DL); NewLoad->setAlignment(ScalarOpAlignment); replaceValue(*EI, *NewLoad); @@ -1356,57 +1397,604 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { return true; } -/// Try to convert "shuffle (binop), (binop)" with a shared binop operand into -/// "binop (shuffle), (shuffle)". +/// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)". bool VectorCombine::foldShuffleOfBinops(Instruction &I) { - auto *VecTy = cast<FixedVectorType>(I.getType()); BinaryOperator *B0, *B1; - ArrayRef<int> Mask; + ArrayRef<int> OldMask; if (!match(&I, m_Shuffle(m_OneUse(m_BinOp(B0)), m_OneUse(m_BinOp(B1)), - m_Mask(Mask))) || - B0->getOpcode() != B1->getOpcode() || B0->getType() != VecTy) + m_Mask(OldMask)))) return false; - // Try to replace a binop with a shuffle if the shuffle is not costly. - // The new shuffle will choose from a single, common operand, so it may be - // cheaper than the existing two-operand shuffle. - SmallVector<int> UnaryMask = createUnaryMask(Mask, Mask.size()); + // Don't introduce poison into div/rem. + if (any_of(OldMask, [](int M) { return M == PoisonMaskElem; }) && + B0->isIntDivRem()) + return false; + + // TODO: Add support for addlike etc. Instruction::BinaryOps Opcode = B0->getOpcode(); - InstructionCost BinopCost = TTI.getArithmeticInstrCost(Opcode, VecTy); - InstructionCost ShufCost = TTI.getShuffleCost( - TargetTransformInfo::SK_PermuteSingleSrc, VecTy, UnaryMask); - if (ShufCost > BinopCost) + if (Opcode != B1->getOpcode()) return false; + auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); + auto *BinOpTy = dyn_cast<FixedVectorType>(B0->getType()); + if (!ShuffleDstTy || !BinOpTy) + return false; + + unsigned NumSrcElts = BinOpTy->getNumElements(); + // If we have something like "add X, Y" and "add Z, X", swap ops to match. 
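  // [Editor's note, illustrative] For a commutative opcode such as 'add', in
  //   shuffle (add X, Y), (add Z, X)
  // the value X reappears as the second binop's second operand (X == W), so
  // swapping B0's operands pairs X with itself; the second of the two new
  // shuffles built below then reads from a single source and is typically
  // cheaper than a two-source shuffle.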
Value *X = B0->getOperand(0), *Y = B0->getOperand(1); Value *Z = B1->getOperand(0), *W = B1->getOperand(1); - if (BinaryOperator::isCommutative(Opcode) && X != Z && Y != W) + if (BinaryOperator::isCommutative(Opcode) && X != Z && Y != W && + (X == W || Y == Z)) std::swap(X, Y); - Value *Shuf0, *Shuf1; + auto ConvertToUnary = [NumSrcElts](int &M) { + if (M >= (int)NumSrcElts) + M -= NumSrcElts; + }; + + SmallVector<int> NewMask0(OldMask.begin(), OldMask.end()); + TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc; if (X == Z) { - // shuf (bo X, Y), (bo X, W) --> bo (shuf X), (shuf Y, W) - Shuf0 = Builder.CreateShuffleVector(X, UnaryMask); - Shuf1 = Builder.CreateShuffleVector(Y, W, Mask); - } else if (Y == W) { - // shuf (bo X, Y), (bo Z, Y) --> bo (shuf X, Z), (shuf Y) - Shuf0 = Builder.CreateShuffleVector(X, Z, Mask); - Shuf1 = Builder.CreateShuffleVector(Y, UnaryMask); - } else { - return false; + llvm::for_each(NewMask0, ConvertToUnary); + SK0 = TargetTransformInfo::SK_PermuteSingleSrc; + Z = PoisonValue::get(BinOpTy); } + SmallVector<int> NewMask1(OldMask.begin(), OldMask.end()); + TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc; + if (Y == W) { + llvm::for_each(NewMask1, ConvertToUnary); + SK1 = TargetTransformInfo::SK_PermuteSingleSrc; + W = PoisonValue::get(BinOpTy); + } + + // Try to replace a binop with a shuffle if the shuffle is not costly. + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + + InstructionCost OldCost = + TTI.getArithmeticInstrCost(B0->getOpcode(), BinOpTy, CostKind) + + TTI.getArithmeticInstrCost(B1->getOpcode(), BinOpTy, CostKind) + + TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, BinOpTy, + OldMask, CostKind, 0, nullptr, {B0, B1}, &I); + + InstructionCost NewCost = + TTI.getShuffleCost(SK0, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z}) + + TTI.getShuffleCost(SK1, BinOpTy, NewMask1, CostKind, 0, nullptr, {Y, W}) + + TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind); + + LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I + << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost + << "\n"); + if (NewCost >= OldCost) + return false; + + Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0); + Value *Shuf1 = Builder.CreateShuffleVector(Y, W, NewMask1); Value *NewBO = Builder.CreateBinOp(Opcode, Shuf0, Shuf1); + // Intersect flags from the old binops. if (auto *NewInst = dyn_cast<Instruction>(NewBO)) { NewInst->copyIRFlags(B0); NewInst->andIRFlags(B1); } + + Worklist.pushValue(Shuf0); + Worklist.pushValue(Shuf1); replaceValue(I, *NewBO); return true; } +/// Try to convert "shuffle (castop), (castop)" with a shared castop operand +/// into "castop (shuffle)". +bool VectorCombine::foldShuffleOfCastops(Instruction &I) { + Value *V0, *V1; + ArrayRef<int> OldMask; + if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask)))) + return false; + + auto *C0 = dyn_cast<CastInst>(V0); + auto *C1 = dyn_cast<CastInst>(V1); + if (!C0 || !C1) + return false; + + Instruction::CastOps Opcode = C0->getOpcode(); + if (C0->getSrcTy() != C1->getSrcTy()) + return false; + + // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds. 
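  // [Editor's note, illustrative] This is sound because 'zext nneg' is only
  // defined for non-negative inputs, and zero- and sign-extension of a
  // non-negative value produce the same result; the mixed pair can therefore
  // be treated as two sign-extends and re-emitted as a single sext of the
  // merged shuffle.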
+ if (Opcode != C1->getOpcode()) { + if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value()))) + Opcode = Instruction::SExt; + else + return false; + } + + auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); + auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy()); + auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy()); + if (!ShuffleDstTy || !CastDstTy || !CastSrcTy) + return false; + + unsigned NumSrcElts = CastSrcTy->getNumElements(); + unsigned NumDstElts = CastDstTy->getNumElements(); + assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) && + "Only bitcasts expected to alter src/dst element counts"); + + // Check for bitcasting of unscalable vector types. + // e.g. <32 x i40> -> <40 x i32> + if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 && + (NumDstElts % NumSrcElts) != 0) + return false; + + SmallVector<int, 16> NewMask; + if (NumSrcElts >= NumDstElts) { + // The bitcast is from wide to narrow/equal elements. The shuffle mask can + // always be expanded to the equivalent form choosing narrower elements. + assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask"); + unsigned ScaleFactor = NumSrcElts / NumDstElts; + narrowShuffleMaskElts(ScaleFactor, OldMask, NewMask); + } else { + // The bitcast is from narrow elements to wide elements. The shuffle mask + // must choose consecutive elements to allow casting first. + assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask"); + unsigned ScaleFactor = NumDstElts / NumSrcElts; + if (!widenShuffleMaskElts(ScaleFactor, OldMask, NewMask)) + return false; + } + + auto *NewShuffleDstTy = + FixedVectorType::get(CastSrcTy->getScalarType(), NewMask.size()); + + // Try to replace a castop with a shuffle if the shuffle is not costly. + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + + InstructionCost CostC0 = + TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy, + TTI::CastContextHint::None, CostKind); + InstructionCost CostC1 = + TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy, + TTI::CastContextHint::None, CostKind); + InstructionCost OldCost = CostC0 + CostC1; + OldCost += + TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, CastDstTy, + OldMask, CostKind, 0, nullptr, std::nullopt, &I); + + InstructionCost NewCost = TTI.getShuffleCost( + TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, NewMask, CostKind); + NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy, + TTI::CastContextHint::None, CostKind); + if (!C0->hasOneUse()) + NewCost += CostC0; + if (!C1->hasOneUse()) + NewCost += CostC1; + + LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I + << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost + << "\n"); + if (NewCost > OldCost) + return false; + + Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0), + C1->getOperand(0), NewMask); + Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy); + + // Intersect flags from the old casts. + if (auto *NewInst = dyn_cast<Instruction>(Cast)) { + NewInst->copyIRFlags(C0); + NewInst->andIRFlags(C1); + } + + Worklist.pushValue(Shuf); + replaceValue(I, *Cast); + return true; +} + +/// Try to convert "shuffle (shuffle x, undef), (shuffle y, undef)" +/// into "shuffle x, y". 
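A concrete instance of this fold, illustrative only (4-element vectors; %x and %y are hypothetical values):

  shuffle (shuffle %x, poison, <2,3,0,1>), (shuffle %y, poison, <1,0,3,2>), <0,1,4,5>
    --> shuffle %x, %y, <2,3,5,4>

Outer lanes 0 and 1 look through the first inner shuffle to %x lanes 2 and 3; outer lanes 4 and 5 look through the second inner shuffle to %y lanes 1 and 0, offset by the source width (4) in the merged mask.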
+bool VectorCombine::foldShuffleOfShuffles(Instruction &I) { + Value *V0, *V1; + UndefValue *U0, *U1; + ArrayRef<int> OuterMask, InnerMask0, InnerMask1; + if (!match(&I, m_Shuffle(m_OneUse(m_Shuffle(m_Value(V0), m_UndefValue(U0), + m_Mask(InnerMask0))), + m_OneUse(m_Shuffle(m_Value(V1), m_UndefValue(U1), + m_Mask(InnerMask1))), + m_Mask(OuterMask)))) + return false; + + auto *ShufI0 = dyn_cast<Instruction>(I.getOperand(0)); + auto *ShufI1 = dyn_cast<Instruction>(I.getOperand(1)); + auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType()); + auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(V0->getType()); + auto *ShuffleImmTy = dyn_cast<FixedVectorType>(I.getOperand(0)->getType()); + if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy || + V0->getType() != V1->getType()) + return false; + + unsigned NumSrcElts = ShuffleSrcTy->getNumElements(); + unsigned NumImmElts = ShuffleImmTy->getNumElements(); + + // Bail if either inner masks reference a RHS undef arg. + if ((!isa<PoisonValue>(U0) && + any_of(InnerMask0, [&](int M) { return M >= (int)NumSrcElts; })) || + (!isa<PoisonValue>(U1) && + any_of(InnerMask1, [&](int M) { return M >= (int)NumSrcElts; }))) + return false; + + // Merge shuffles - replace index to the RHS poison arg with PoisonMaskElem, + SmallVector<int, 16> NewMask(OuterMask.begin(), OuterMask.end()); + for (int &M : NewMask) { + if (0 <= M && M < (int)NumImmElts) { + M = (InnerMask0[M] >= (int)NumSrcElts) ? PoisonMaskElem : InnerMask0[M]; + } else if (M >= (int)NumImmElts) { + if (InnerMask1[M - NumImmElts] >= (int)NumSrcElts) + M = PoisonMaskElem; + else + M = InnerMask1[M - NumImmElts] + (V0 == V1 ? 0 : NumSrcElts); + } + } + + // Have we folded to an Identity shuffle? + if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) { + replaceValue(I, *V0); + return true; + } + + // Try to merge the shuffles if the new shuffle is not costly. + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + + InstructionCost OldCost = + TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy, + InnerMask0, CostKind, 0, nullptr, {V0, U0}, ShufI0) + + TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleSrcTy, + InnerMask1, CostKind, 0, nullptr, {V1, U1}, ShufI1) + + TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleImmTy, + OuterMask, CostKind, 0, nullptr, {ShufI0, ShufI1}, &I); + + InstructionCost NewCost = + TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleSrcTy, + NewMask, CostKind, 0, nullptr, {V0, V1}); + + LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I + << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost + << "\n"); + if (NewCost > OldCost) + return false; + + // Clear unused sources to poison. 
+ if (none_of(NewMask, [&](int M) { return 0 <= M && M < (int)NumSrcElts; })) + V0 = PoisonValue::get(ShuffleSrcTy); + if (none_of(NewMask, [&](int M) { return (int)NumSrcElts <= M; })) + V1 = PoisonValue::get(ShuffleSrcTy); + + Value *Shuf = Builder.CreateShuffleVector(V0, V1, NewMask); + replaceValue(I, *Shuf); + return true; +} + +using InstLane = std::pair<Use *, int>; + +static InstLane lookThroughShuffles(Use *U, int Lane) { + while (auto *SV = dyn_cast<ShuffleVectorInst>(U->get())) { + unsigned NumElts = + cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements(); + int M = SV->getMaskValue(Lane); + if (M < 0) + return {nullptr, PoisonMaskElem}; + if (static_cast<unsigned>(M) < NumElts) { + U = &SV->getOperandUse(0); + Lane = M; + } else { + U = &SV->getOperandUse(1); + Lane = M - NumElts; + } + } + return InstLane{U, Lane}; +} + +static SmallVector<InstLane> +generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) { + SmallVector<InstLane> NItem; + for (InstLane IL : Item) { + auto [U, Lane] = IL; + InstLane OpLane = + U ? lookThroughShuffles(&cast<Instruction>(U->get())->getOperandUse(Op), + Lane) + : InstLane{nullptr, PoisonMaskElem}; + NItem.emplace_back(OpLane); + } + return NItem; +} + +/// Detect concat of multiple values into a vector +static bool isFreeConcat(ArrayRef<InstLane> Item, + const TargetTransformInfo &TTI) { + auto *Ty = cast<FixedVectorType>(Item.front().first->get()->getType()); + unsigned NumElts = Ty->getNumElements(); + if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0) + return false; + + // Check that the concat is free, usually meaning that the type will be split + // during legalization. + SmallVector<int, 16> ConcatMask(NumElts * 2); + std::iota(ConcatMask.begin(), ConcatMask.end(), 0); + if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, ConcatMask, + TTI::TCK_RecipThroughput) != 0) + return false; + + unsigned NumSlices = Item.size() / NumElts; + // Currently we generate a tree of shuffles for the concats, which limits us + // to a power2. 
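+ // (e.g. 2, 4 or 8 slices can be matched; 3 slices are rejected by the
+ // power-of-2 check below.)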
+ if (!isPowerOf2_32(NumSlices)) + return false; + for (unsigned Slice = 0; Slice < NumSlices; ++Slice) { + Use *SliceV = Item[Slice * NumElts].first; + if (!SliceV || SliceV->get()->getType() != Ty) + return false; + for (unsigned Elt = 0; Elt < NumElts; ++Elt) { + auto [V, Lane] = Item[Slice * NumElts + Elt]; + if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get()) + return false; + } + } + return true; +} + +static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty, + const SmallPtrSet<Use *, 4> &IdentityLeafs, + const SmallPtrSet<Use *, 4> &SplatLeafs, + const SmallPtrSet<Use *, 4> &ConcatLeafs, + IRBuilder<> &Builder) { + auto [FrontU, FrontLane] = Item.front(); + + if (IdentityLeafs.contains(FrontU)) { + return FrontU->get(); + } + if (SplatLeafs.contains(FrontU)) { + SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane); + return Builder.CreateShuffleVector(FrontU->get(), Mask); + } + if (ConcatLeafs.contains(FrontU)) { + unsigned NumElts = + cast<FixedVectorType>(FrontU->get()->getType())->getNumElements(); + SmallVector<Value *> Values(Item.size() / NumElts, nullptr); + for (unsigned S = 0; S < Values.size(); ++S) + Values[S] = Item[S * NumElts].first->get(); + + while (Values.size() > 1) { + NumElts *= 2; + SmallVector<int, 16> Mask(NumElts, 0); + std::iota(Mask.begin(), Mask.end(), 0); + SmallVector<Value *> NewValues(Values.size() / 2, nullptr); + for (unsigned S = 0; S < NewValues.size(); ++S) + NewValues[S] = + Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask); + Values = NewValues; + } + return Values[0]; + } + + auto *I = cast<Instruction>(FrontU->get()); + auto *II = dyn_cast<IntrinsicInst>(I); + unsigned NumOps = I->getNumOperands() - (II ? 1 : 0); + SmallVector<Value *> Ops(NumOps); + for (unsigned Idx = 0; Idx < NumOps; Idx++) { + if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx)) { + Ops[Idx] = II->getOperand(Idx); + continue; + } + Ops[Idx] = + generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx), Ty, + IdentityLeafs, SplatLeafs, ConcatLeafs, Builder); + } + + SmallVector<Value *, 8> ValueList; + for (const auto &Lane : Item) + if (Lane.first) + ValueList.push_back(Lane.first->get()); + + Type *DstTy = + FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements()); + if (auto *BI = dyn_cast<BinaryOperator>(I)) { + auto *Value = Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(), + Ops[0], Ops[1]); + propagateIRFlags(Value, ValueList); + return Value; + } + if (auto *CI = dyn_cast<CmpInst>(I)) { + auto *Value = Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]); + propagateIRFlags(Value, ValueList); + return Value; + } + if (auto *SI = dyn_cast<SelectInst>(I)) { + auto *Value = Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI); + propagateIRFlags(Value, ValueList); + return Value; + } + if (auto *CI = dyn_cast<CastInst>(I)) { + auto *Value = Builder.CreateCast((Instruction::CastOps)CI->getOpcode(), + Ops[0], DstTy); + propagateIRFlags(Value, ValueList); + return Value; + } + if (II) { + auto *Value = Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops); + propagateIRFlags(Value, ValueList); + return Value; + } + assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate"); + auto *Value = + Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]); + propagateIRFlags(Value, ValueList); + return Value; +} + +// Starting from a shuffle, look up through operands tracking the shuffled index +// of each lane. 
If we can simplify away the shuffles to identities then +// do so. +bool VectorCombine::foldShuffleToIdentity(Instruction &I) { + auto *Ty = dyn_cast<FixedVectorType>(I.getType()); + if (!Ty || I.use_empty()) + return false; + + SmallVector<InstLane> Start(Ty->getNumElements()); + for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M) + Start[M] = lookThroughShuffles(&*I.use_begin(), M); + + SmallVector<SmallVector<InstLane>> Worklist; + Worklist.push_back(Start); + SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs; + unsigned NumVisited = 0; + + while (!Worklist.empty()) { + if (++NumVisited > MaxInstrsToScan) + return false; + + SmallVector<InstLane> Item = Worklist.pop_back_val(); + auto [FrontU, FrontLane] = Item.front(); + + // If we found an undef first lane then bail out to keep things simple. + if (!FrontU) + return false; + + // Helper to peek through bitcasts to the same value. + auto IsEquiv = [&](Value *X, Value *Y) { + return X->getType() == Y->getType() && + peekThroughBitcasts(X) == peekThroughBitcasts(Y); + }; + + // Look for an identity value. + if (FrontLane == 0 && + cast<FixedVectorType>(FrontU->get()->getType())->getNumElements() == + Ty->getNumElements() && + all_of(drop_begin(enumerate(Item)), [IsEquiv, Item](const auto &E) { + Value *FrontV = Item.front().first->get(); + return !E.value().first || (IsEquiv(E.value().first->get(), FrontV) && + E.value().second == (int)E.index()); + })) { + IdentityLeafs.insert(FrontU); + continue; + } + // Look for constants, for the moment only supporting constant splats. + if (auto *C = dyn_cast<Constant>(FrontU); + C && C->getSplatValue() && + all_of(drop_begin(Item), [Item](InstLane &IL) { + Value *FrontV = Item.front().first->get(); + Use *U = IL.first; + return !U || U->get() == FrontV; + })) { + SplatLeafs.insert(FrontU); + continue; + } + // Look for a splat value. + if (all_of(drop_begin(Item), [Item](InstLane &IL) { + auto [FrontU, FrontLane] = Item.front(); + auto [U, Lane] = IL; + return !U || (U->get() == FrontU->get() && Lane == FrontLane); + })) { + SplatLeafs.insert(FrontU); + continue; + } + + // We need each element to be the same type of value, and check that each + // element has a single use. + auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) { + Value *FrontV = Item.front().first->get(); + if (!IL.first) + return true; + Value *V = IL.first->get(); + if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse()) + return false; + if (V->getValueID() != FrontV->getValueID()) + return false; + if (auto *CI = dyn_cast<CmpInst>(V)) + if (CI->getPredicate() != cast<CmpInst>(FrontV)->getPredicate()) + return false; + if (auto *CI = dyn_cast<CastInst>(V)) + if (CI->getSrcTy() != cast<CastInst>(FrontV)->getSrcTy()) + return false; + if (auto *SI = dyn_cast<SelectInst>(V)) + if (!isa<VectorType>(SI->getOperand(0)->getType()) || + SI->getOperand(0)->getType() != + cast<SelectInst>(FrontV)->getOperand(0)->getType()) + return false; + if (isa<CallInst>(V) && !isa<IntrinsicInst>(V)) + return false; + auto *II = dyn_cast<IntrinsicInst>(V); + return !II || (isa<IntrinsicInst>(FrontV) && + II->getIntrinsicID() == + cast<IntrinsicInst>(FrontV)->getIntrinsicID() && + !II->hasOperandBundles()); + }; + if (all_of(drop_begin(Item), CheckLaneIsEquivalentToFirst)) { + // Check the operator is one that we support. + if (isa<BinaryOperator, CmpInst>(FrontU)) { + // We exclude div/rem in case they hit UB from poison lanes. 
+ if (auto *BO = dyn_cast<BinaryOperator>(FrontU);
+ BO && BO->isIntDivRem())
+ return false;
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
+ continue;
+ } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst>(FrontU)) {
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+ continue;
+ } else if (auto *BitCast = dyn_cast<BitCastInst>(FrontU)) {
+ // TODO: Handle vector widening/narrowing bitcasts.
+ auto *DstTy = dyn_cast<FixedVectorType>(BitCast->getDestTy());
+ auto *SrcTy = dyn_cast<FixedVectorType>(BitCast->getSrcTy());
+ if (DstTy && SrcTy &&
+ SrcTy->getNumElements() == DstTy->getNumElements()) {
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+ continue;
+ }
+ } else if (isa<SelectInst>(FrontU)) {
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2));
+ continue;
+ } else if (auto *II = dyn_cast<IntrinsicInst>(FrontU);
+ II && isTriviallyVectorizable(II->getIntrinsicID()) &&
+ !II->hasOperandBundles()) {
+ for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
+ if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) {
+ if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
+ Value *FrontV = Item.front().first->get();
+ Use *U = IL.first;
+ return !U || (cast<Instruction>(U->get())->getOperand(Op) ==
+ cast<Instruction>(FrontV)->getOperand(Op));
+ }))
+ return false;
+ continue;
+ }
+ Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
+ }
+ continue;
+ }
+ }
+
+ if (isFreeConcat(Item, TTI)) {
+ ConcatLeafs.insert(FrontU);
+ continue;
+ }
+
+ return false;
+ }
+
+ if (NumVisited <= 1)
+ return false;
+
+ // If we got this far, we know the shuffles are superfluous and can be
+ // removed. Scan through again and generate the new tree of instructions.
+ Builder.SetInsertPoint(&I);
+ Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs,
+ ConcatLeafs, Builder);
+ replaceValue(I, *V);
+ return true;
+}
+
 /// Given a commutative reduction, the order of the input lanes does not alter
 /// the results. We can use this to remove certain shuffles feeding the
 /// reduction, removing the need to shuffle at all.
@@ -1526,6 +2114,67 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
  return foldSelectShuffle(*Shuffle, true);
 }
 
+/// Determine if it's more efficient to fold:
+/// reduce(trunc(x)) -> trunc(reduce(x)).
+/// reduce(sext(x)) -> sext(reduce(x)).
+/// reduce(zext(x)) -> zext(reduce(x)).
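+///
+/// A hypothetical example (IR names and types are illustrative only, not taken
+/// from this change), showing the zext/or case:
+/// \code
+///   %ext = zext <8 x i8> %x to <8 x i64>
+///   %red = call i64 @llvm.vector.reduce.or.v8i64(<8 x i64> %ext)
+///   ; cost permitting, is rewritten to:
+///   %red.narrow = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> %x)
+///   %red = zext i8 %red.narrow to i64
+/// \endcode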
+bool VectorCombine::foldCastFromReductions(Instruction &I) { + auto *II = dyn_cast<IntrinsicInst>(&I); + if (!II) + return false; + + bool TruncOnly = false; + Intrinsic::ID IID = II->getIntrinsicID(); + switch (IID) { + case Intrinsic::vector_reduce_add: + case Intrinsic::vector_reduce_mul: + TruncOnly = true; + break; + case Intrinsic::vector_reduce_and: + case Intrinsic::vector_reduce_or: + case Intrinsic::vector_reduce_xor: + break; + default: + return false; + } + + unsigned ReductionOpc = getArithmeticReductionInstruction(IID); + Value *ReductionSrc = I.getOperand(0); + + Value *Src; + if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) && + (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src)))))) + return false; + + auto CastOpc = + (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode(); + + auto *SrcTy = cast<VectorType>(Src->getType()); + auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType()); + Type *ResultTy = I.getType(); + + TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; + InstructionCost OldCost = TTI.getArithmeticReductionCost( + ReductionOpc, ReductionSrcTy, std::nullopt, CostKind); + OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy, + TTI::CastContextHint::None, CostKind, + cast<CastInst>(ReductionSrc)); + InstructionCost NewCost = + TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt, + CostKind) + + TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(), + TTI::CastContextHint::None, CostKind); + + if (OldCost <= NewCost || !NewCost.isValid()) + return false; + + Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(), + II->getIntrinsicID(), {Src}); + Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy); + replaceValue(I, *NewCast); + return true; +} + /// This method looks for groups of shuffles acting on binops, of the form: /// %x = shuffle ... /// %y = shuffle ... @@ -1907,7 +2556,10 @@ bool VectorCombine::run() { break; case Instruction::ShuffleVector: MadeChange |= foldShuffleOfBinops(I); + MadeChange |= foldShuffleOfCastops(I); + MadeChange |= foldShuffleOfShuffles(I); MadeChange |= foldSelectShuffle(I); + MadeChange |= foldShuffleToIdentity(I); break; case Instruction::BitCast: MadeChange |= foldBitcastShuffle(I); @@ -1917,6 +2569,7 @@ bool VectorCombine::run() { switch (Opcode) { case Instruction::Call: MadeChange |= foldShuffleFromReductions(I); + MadeChange |= foldCastFromReductions(I); break; case Instruction::ICmp: case Instruction::FCmp: @@ -1966,7 +2619,8 @@ PreservedAnalyses VectorCombinePass::run(Function &F, TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); AAResults &AA = FAM.getResult<AAManager>(F); - VectorCombine Combiner(F, TTI, DT, AA, AC, TryEarlyFoldsOnly); + const DataLayout *DL = &F.getDataLayout(); + VectorCombine Combiner(F, TTI, DT, AA, AC, DL, TryEarlyFoldsOnly); if (!Combiner.run()) return PreservedAnalyses::all(); PreservedAnalyses PA; |